Wie AWS Lambda SnapStart Kaltstarts für Serverless Machine Learning Inference eliminiert

Nov 29 2022
Herausforderungen beim Kaltstart für ML-Inferenz Eine der größten Herausforderungen bei Serverless Machine Learning Inference war immer ein Kaltstart. Und im Fall der ML-Inferenz tragen mehrere Dinge dazu bei: SnapStart-Funktion Mit der neu angekündigten SnapStart-Funktion für AWS Lambda wird der Kaltstart durch SnapStart ersetzt.

Herausforderungen mit Kaltstart für ML-Inferenz

Eine der größten Herausforderungen bei Serverless Machine Learning Inference war immer ein Kaltstart. Und im Fall der ML-Inferenz tragen mehrere Dinge dazu bei:

  • Laufzeitinitialisierung
  • Laden von Bibliotheken und Abhängigkeiten
  • Laden des Modells selbst (aus S3 oder Paket)
  • Initialisieren des Modells

SnapStart-Funktion

Mit der neu angekündigten SnapStart-Funktion für AWS Lambda wird der Kaltstart durch SnapStart ersetzt. AWS Lambda erstellt einen unveränderlichen, verschlüsselten Snapshot des Arbeitsspeichers und des Festplattenstatus und speichert ihn zur Wiederverwendung im Cache. Bei diesem Snapshot ist das ML-Modell in den Arbeitsspeicher geladen und einsatzbereit.

Dinge zu beachten:

  • [Java-Laufzeit] SnapStart wird derzeit nur für die Java-Laufzeit unterstützt. Das fügt Einschränkungen hinzu, aber ONNX funktioniert auf Java und es ist möglich, ONNX mit SnapStart auszuführen.
  • [Modell laden] Das Laden des Modells muss innerhalb des Initialisierungsschritts erfolgen, nicht während des Ausführungsschritts, und das Modell sollte zwischen den Ausführungen wiederverwendet werden. In Java ist es ein statischer Block. Das Gute ist, dass wir beim Laden des Modells nicht durch das Funktions-Timeout eingeschränkt sind und die maximale Initialisierungsdauer 15 Minuten beträgt.
  • [Snap-Resilient] SnapStart hat bestimmte Einschränkungen – Einzigartigkeit, da SnapStart Snapshots verwendet. Dies bedeutet beispielsweise, dass, wenn während der Init-Phase ein zufälliger Seed definiert wird, alle Lambda-Aufrufe denselben Generator haben. Lesen Sie hier mehr darüber, wie Sie Lambda widerstandsfähig machen können .

Ein Beispiel mit ONNX und SnapStart ist hier öffentlich verfügbar und kann mit Sam verwendet werden, um den ONNX Inception V3-Endpunkt bereitzustellen und zu testen.

Um die Architektur für SnapStart im Fall von ONNX hervorzuheben:

  • onnxSession – hat ein vorgeladenes Modell und wird zwischen Aufrufen wiederverwendet.
  • getOnnxSession — lädt das Modell, wenn es vorher nicht geladen wurde, und überspringt es, wenn es vorher geladen wurde.
  • statischer Block – führt den Code während der SnapStart-Erstellung aus. Dies ist der wichtige Teil – der Code im Handler wird während der Erstellung des Snapshots nicht ausgeführt.
  • package onnxsnapstart;
    
    /**
     * Handler for Onnx predictions on Lambda function.
     */
    public class App implements RequestHandler<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> {
    
        // Onnx session with preloaded model which will be reused between invocations and will be
        // initialized as part of snapshot creation
        private static OrtSession onnxSession;
    
        // Returns Onnx session with preloaded model. Reuses existing session if exists.
        private static OrtSession getOnnxSession() {
            String modelPath = "inception_v3.onnx";
            if (onnxSession==null) {
              System.out.println("Start model load");
              try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath");
                OrtSession.SessionOptions options = new SessionOptions()) {
              try {
                OrtSession session = env.createSession(modelPath, options);
                Map<String, NodeInfo> inputInfoList = session.getInputInfo();
                Map<String, NodeInfo> outputInfoList = session.getOutputInfo();
                System.out.println(inputInfoList);
                System.out.println(outputInfoList);
                onnxSession = session;
                return onnxSession;
              }
              catch(OrtException exc) {
                exc.printStackTrace();
              }
            }
            }
            return onnxSession;
        }
    
        // This code runs during snapshot initialization. In the normal lambda that would run in init phase.
        static {
            System.out.println("Start model init");
            getOnnxSession();
            System.out.println("Finished model init");
        }
    
        // Main handler for the Lambda
        public APIGatewayProxyResponseEvent handleRequest(final APIGatewayProxyRequestEvent input, final Context context) {
            Map<String, String> headers = new HashMap<>();
            headers.put("Content-Type", "application/json");
            headers.put("X-Custom-Header", "application/json");
    
    
            float[][][][] testData = new float[1][3][299][299];
    
            try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath")) {
                OnnxTensor test = OnnxTensor.createTensor(env, testData);
                OrtSession session = getOnnxSession();
                String inputName = session.getInputNames().iterator().next();
                Result output = session.run(Collections.singletonMap(inputName, test));
                System.out.println(output);
            }
            catch(OrtException exc) {
                exc.printStackTrace();
            }
    
    
            APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent().withHeaders(headers);
            String output = String.format("{ \"message\": \"made prediction\" }");
    
            return response
                    .withStatusCode(200)
                    .withBody(output);
        }
    }
    

  • Traditionelles Lambda
  • Picked up JAVA_TOOL_OPTIONS: -XX:+TieredCompilation -XX:TieredStopAtLevel=1
    Start model init
    Start model load
    {x.1=NodeInfo(name=x.1,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 3, 299, 299]))}
    {924=NodeInfo(name=924,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 1000]))}
    Finished model init
    START RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Version: $LATEST
    ai.onnxruntime.OrtSession$Result@e580929
    END RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a
    REPORT RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Duration: 244.99 ms Billed Duration: 245 ms Memory Size: 1769 MB Max Memory Used: 531 MB Init Duration: 8615.62 ms
    

    RESTORE_START Runtime Version: java:11.v15 Runtime Version ARN: arn:aws:lambda:us-east-1::runtime:0a25e3e7a1cc9ce404bc435eeb2ad358d8fa64338e618d0c224fe509403583ca
    RESTORE_REPORT Restore Duration: 571.67 ms
    START RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Version: 1
    ai.onnxruntime.OrtSession$Result@47f6473
    END RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029
    REPORT RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Duration: 496.51 ms Billed Duration: 645 ms Memory Size: 1769 MB Max Memory Used: 342 MB Restore Duration: 571.67 ms
    

Wir haben immer noch zusätzliche Latenz aufgrund der Wiederherstellung des Snapshots, aber jetzt ist unser Tail erheblich kurzgeschlossen und wir haben keine Anfragen, die länger als 2,5 Sekunden dauern würden.

  • Traditionelles Lambda
  • Percentage of the requests served within a certain time (ms)
      50%    352
      66%    377
      75%    467
      80%    473
      90%    488
      95%   9719
      98%  10329
      99%  10419
     100%  12825
    

    50%    365
      66%    445
      75%    477
      80%    487
      90%    556
      95%   1392
      98%   2233
      99%   2319
     100%   2589 (longest request)