Wie AWS Lambda SnapStart Kaltstarts für Serverless Machine Learning Inference eliminiert
Herausforderungen mit Kaltstart für ML-Inferenz
Eine der größten Herausforderungen bei Serverless Machine Learning Inference war immer ein Kaltstart. Und im Fall der ML-Inferenz tragen mehrere Dinge dazu bei:
- Laufzeitinitialisierung
- Laden von Bibliotheken und Abhängigkeiten
- Laden des Modells selbst (aus S3 oder Paket)
- Initialisieren des Modells
SnapStart-Funktion
Mit der neu angekündigten SnapStart-Funktion für AWS Lambda wird der Kaltstart durch SnapStart ersetzt. AWS Lambda erstellt einen unveränderlichen, verschlüsselten Snapshot des Arbeitsspeichers und des Festplattenstatus und speichert ihn zur Wiederverwendung im Cache. Bei diesem Snapshot ist das ML-Modell in den Arbeitsspeicher geladen und einsatzbereit.
Dinge zu beachten:
- [Java-Laufzeit] SnapStart wird derzeit nur für die Java-Laufzeit unterstützt. Das fügt Einschränkungen hinzu, aber ONNX funktioniert auf Java und es ist möglich, ONNX mit SnapStart auszuführen.
- [Modell laden] Das Laden des Modells muss innerhalb des Initialisierungsschritts erfolgen, nicht während des Ausführungsschritts, und das Modell sollte zwischen den Ausführungen wiederverwendet werden. In Java ist es ein statischer Block. Das Gute ist, dass wir beim Laden des Modells nicht durch das Funktions-Timeout eingeschränkt sind und die maximale Initialisierungsdauer 15 Minuten beträgt.
- [Snap-Resilient] SnapStart hat bestimmte Einschränkungen – Einzigartigkeit, da SnapStart Snapshots verwendet. Dies bedeutet beispielsweise, dass, wenn während der Init-Phase ein zufälliger Seed definiert wird, alle Lambda-Aufrufe denselben Generator haben. Lesen Sie hier mehr darüber, wie Sie Lambda widerstandsfähig machen können .
Ein Beispiel mit ONNX und SnapStart ist hier öffentlich verfügbar und kann mit Sam verwendet werden, um den ONNX Inception V3-Endpunkt bereitzustellen und zu testen.
Um die Architektur für SnapStart im Fall von ONNX hervorzuheben:
- onnxSession – hat ein vorgeladenes Modell und wird zwischen Aufrufen wiederverwendet.
- getOnnxSession — lädt das Modell, wenn es vorher nicht geladen wurde, und überspringt es, wenn es vorher geladen wurde.
- statischer Block – führt den Code während der SnapStart-Erstellung aus. Dies ist der wichtige Teil – der Code im Handler wird während der Erstellung des Snapshots nicht ausgeführt.
package onnxsnapstart;
/**
* Handler for Onnx predictions on Lambda function.
*/
public class App implements RequestHandler<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> {
// Onnx session with preloaded model which will be reused between invocations and will be
// initialized as part of snapshot creation
private static OrtSession onnxSession;
// Returns Onnx session with preloaded model. Reuses existing session if exists.
private static OrtSession getOnnxSession() {
String modelPath = "inception_v3.onnx";
if (onnxSession==null) {
System.out.println("Start model load");
try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath");
OrtSession.SessionOptions options = new SessionOptions()) {
try {
OrtSession session = env.createSession(modelPath, options);
Map<String, NodeInfo> inputInfoList = session.getInputInfo();
Map<String, NodeInfo> outputInfoList = session.getOutputInfo();
System.out.println(inputInfoList);
System.out.println(outputInfoList);
onnxSession = session;
return onnxSession;
}
catch(OrtException exc) {
exc.printStackTrace();
}
}
}
return onnxSession;
}
// This code runs during snapshot initialization. In the normal lambda that would run in init phase.
static {
System.out.println("Start model init");
getOnnxSession();
System.out.println("Finished model init");
}
// Main handler for the Lambda
public APIGatewayProxyResponseEvent handleRequest(final APIGatewayProxyRequestEvent input, final Context context) {
Map<String, String> headers = new HashMap<>();
headers.put("Content-Type", "application/json");
headers.put("X-Custom-Header", "application/json");
float[][][][] testData = new float[1][3][299][299];
try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath")) {
OnnxTensor test = OnnxTensor.createTensor(env, testData);
OrtSession session = getOnnxSession();
String inputName = session.getInputNames().iterator().next();
Result output = session.run(Collections.singletonMap(inputName, test));
System.out.println(output);
}
catch(OrtException exc) {
exc.printStackTrace();
}
APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent().withHeaders(headers);
String output = String.format("{ \"message\": \"made prediction\" }");
return response
.withStatusCode(200)
.withBody(output);
}
}
Picked up JAVA_TOOL_OPTIONS: -XX:+TieredCompilation -XX:TieredStopAtLevel=1
Start model init
Start model load
{x.1=NodeInfo(name=x.1,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 3, 299, 299]))}
{924=NodeInfo(name=924,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 1000]))}
Finished model init
START RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Version: $LATEST
ai.onnxruntime.OrtSession$Result@e580929
END RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a
REPORT RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Duration: 244.99 ms Billed Duration: 245 ms Memory Size: 1769 MB Max Memory Used: 531 MB Init Duration: 8615.62 ms
RESTORE_START Runtime Version: java:11.v15 Runtime Version ARN: arn:aws:lambda:us-east-1::runtime:0a25e3e7a1cc9ce404bc435eeb2ad358d8fa64338e618d0c224fe509403583ca
RESTORE_REPORT Restore Duration: 571.67 ms
START RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Version: 1
ai.onnxruntime.OrtSession$Result@47f6473
END RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029
REPORT RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Duration: 496.51 ms Billed Duration: 645 ms Memory Size: 1769 MB Max Memory Used: 342 MB Restore Duration: 571.67 ms
Wir haben immer noch zusätzliche Latenz aufgrund der Wiederherstellung des Snapshots, aber jetzt ist unser Tail erheblich kurzgeschlossen und wir haben keine Anfragen, die länger als 2,5 Sekunden dauern würden.
- Traditionelles Lambda
Percentage of the requests served within a certain time (ms)
50% 352
66% 377
75% 467
80% 473
90% 488
95% 9719
98% 10329
99% 10419
100% 12825
50% 365
66% 445
75% 477
80% 487
90% 556
95% 1392
98% 2233
99% 2319
100% 2589 (longest request)

![Was ist überhaupt eine verknüpfte Liste? [Teil 1]](https://post.nghiatu.com/assets/images/m/max/724/1*Xokk6XOjWyIGCBujkJsCzQ.jpeg)



































