Jak AWS Lambda SnapStart eliminuje zimne uruchamianie dla wnioskowania o uczeniu maszynowym bez serwera

Wyzwania z zimnym startem dla wnioskowania ML
Jednym z głównych wyzwań związanych z bezserwerowym wnioskowaniem o uczeniu maszynowym był zawsze zimny start. A w przypadku wnioskowania ML składa się na to wiele rzeczy:
- inicjalizacja środowiska wykonawczego
- ładowanie bibliotek i zależności
- ładowanie samego modelu (z S3 lub paczki)
- inicjowanie modelu
Funkcja SnapStart
Dzięki nowo ogłoszonej funkcji SnapStart dla AWS Lambda zimny start został zastąpiony SnapStart. AWS Lambda utworzy niezmienną, zaszyfrowaną migawkę stanu pamięci i dysku oraz zapisze ją w pamięci podręcznej do ponownego użycia. Ta migawka będzie miała model ML załadowany do pamięci i gotowy do użycia.
O czym należy pamiętać:
- [Środowisko uruchomieniowe Java] SnapStart jest obecnie obsługiwane tylko w środowisku wykonawczym Java. To dodaje ograniczenia, ale ONNX działa na Javie i możliwe jest uruchomienie ONNX z SnapStart.
- [Ładowanie modelu] Ładowanie modelu musi nastąpić w ramach kroku inicjalizacji, a nie kroku uruchamiania, a model powinien być ponownie używany między uruchomieniami. W Javie jest to blok statyczny. Dobrą rzeczą jest to, że nie jesteśmy ograniczeni limitem czasu funkcji na załadowanie modelu, a maksymalny czas inicjalizacji wynosi 15 minut.
- [Snap-Resilient] SnapStart ma określone ograniczenia — wyjątkowość, ponieważ SnapStart używa migawki. Oznacza to na przykład, że jeśli losowe ziarno zostanie zdefiniowane podczas fazy init, to wszystkie wywołania lambda będą miały ten sam generator. Przeczytaj więcej o tym, jak uczynić Lambdę odporną tutaj .
Przykład z ONNX i SnapStart jest publicznie dostępny tutaj i może być użyty z Samem do wdrożenia punktu końcowego ONNX Inception V3 i przetestowania go.
Aby podkreślić architekturę SnapStart w przypadku ONNX:
- onnxSession — ma wstępnie załadowany model i jest ponownie używany między wywołaniami.
- getOnnxSession — ładuje model, jeśli nie był wcześniej ładowany i pomija go, jeśli był wcześniej ładowany.
- static block — uruchom kod podczas tworzenia SnapStart. To jest ważna część — kod w module obsługi nie zostanie uruchomiony podczas tworzenia migawki.
package onnxsnapstart;
/**
* Handler for Onnx predictions on Lambda function.
*/
public class App implements RequestHandler<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> {
// Onnx session with preloaded model which will be reused between invocations and will be
// initialized as part of snapshot creation
private static OrtSession onnxSession;
// Returns Onnx session with preloaded model. Reuses existing session if exists.
private static OrtSession getOnnxSession() {
String modelPath = "inception_v3.onnx";
if (onnxSession==null) {
System.out.println("Start model load");
try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath");
OrtSession.SessionOptions options = new SessionOptions()) {
try {
OrtSession session = env.createSession(modelPath, options);
Map<String, NodeInfo> inputInfoList = session.getInputInfo();
Map<String, NodeInfo> outputInfoList = session.getOutputInfo();
System.out.println(inputInfoList);
System.out.println(outputInfoList);
onnxSession = session;
return onnxSession;
}
catch(OrtException exc) {
exc.printStackTrace();
}
}
}
return onnxSession;
}
// This code runs during snapshot initialization. In the normal lambda that would run in init phase.
static {
System.out.println("Start model init");
getOnnxSession();
System.out.println("Finished model init");
}
// Main handler for the Lambda
public APIGatewayProxyResponseEvent handleRequest(final APIGatewayProxyRequestEvent input, final Context context) {
Map<String, String> headers = new HashMap<>();
headers.put("Content-Type", "application/json");
headers.put("X-Custom-Header", "application/json");
float[][][][] testData = new float[1][3][299][299];
try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath")) {
OnnxTensor test = OnnxTensor.createTensor(env, testData);
OrtSession session = getOnnxSession();
String inputName = session.getInputNames().iterator().next();
Result output = session.run(Collections.singletonMap(inputName, test));
System.out.println(output);
}
catch(OrtException exc) {
exc.printStackTrace();
}
APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent().withHeaders(headers);
String output = String.format("{ \"message\": \"made prediction\" }");
return response
.withStatusCode(200)
.withBody(output);
}
}
Picked up JAVA_TOOL_OPTIONS: -XX:+TieredCompilation -XX:TieredStopAtLevel=1
Start model init
Start model load
{x.1=NodeInfo(name=x.1,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 3, 299, 299]))}
{924=NodeInfo(name=924,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 1000]))}
Finished model init
START RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Version: $LATEST
ai.onnxruntime.OrtSession$Result@e580929
END RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a
REPORT RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Duration: 244.99 ms Billed Duration: 245 ms Memory Size: 1769 MB Max Memory Used: 531 MB Init Duration: 8615.62 ms
RESTORE_START Runtime Version: java:11.v15 Runtime Version ARN: arn:aws:lambda:us-east-1::runtime:0a25e3e7a1cc9ce404bc435eeb2ad358d8fa64338e618d0c224fe509403583ca
RESTORE_REPORT Restore Duration: 571.67 ms
START RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Version: 1
ai.onnxruntime.OrtSession$Result@47f6473
END RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029
REPORT RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Duration: 496.51 ms Billed Duration: 645 ms Memory Size: 1769 MB Max Memory Used: 342 MB Restore Duration: 571.67 ms
Nadal mamy dodatkowe opóźnienie z powodu przywracania migawki, ale teraz nasz ogon jest znacznie skrócony i nie mamy żądań, które zajęłyby więcej niż 2,5 sekundy.
- Tradycyjna Lambda
Percentage of the requests served within a certain time (ms)
50% 352
66% 377
75% 467
80% 473
90% 488
95% 9719
98% 10329
99% 10419
100% 12825
50% 365
66% 445
75% 477
80% 487
90% 556
95% 1392
98% 2233
99% 2319
100% 2589 (longest request)