Comment AWS Lambda SnapStart élimine les démarrages à froid pour l'inférence d'apprentissage automatique sans serveur

Nov 29 2022
Défis liés au démarrage à froid pour l'inférence ML L'un des principaux défis de l'inférence d'apprentissage automatique sans serveur a toujours été un démarrage à froid. Et dans le cas de l'inférence ML, plusieurs éléments y contribuent : Fonctionnalité SnapStart Avec la fonctionnalité SnapStart récemment annoncée pour AWS Lambda, le démarrage à froid est remplacé par SnapStart.

Défis liés au démarrage à froid pour l'inférence ML

L'un des principaux défis de l'inférence d'apprentissage automatique sans serveur était toujours un démarrage à froid. Et dans le cas de l'inférence ML, plusieurs éléments y contribuent :

  • initialisation d'exécution
  • chargement des bibliothèques et des dépendances
  • chargement du modèle lui-même (depuis S3 ou package)
  • initialisation du modèle

Fonction SnapStart

Avec la fonctionnalité SnapStart récemment annoncée pour AWS Lambda, le démarrage à froid est remplacé par SnapStart. AWS Lambda créera un instantané immuable et chiffré de l'état de la mémoire et du disque, et le mettra en cache pour réutilisation. Cet instantané aura un modèle ML chargé en mémoire et prêt à être utilisé.

Choses à garder à l'esprit :

  • [Exécution Java] SnapStart n'est actuellement pris en charge que pour l'exécution Java. Cela ajoute des limitations, mais ONNX fonctionne sur Java et il est possible d'exécuter ONNX avec SnapStart.
  • [Chargement du modèle] Le chargement du modèle doit se produire dans l'étape d'initialisation, pas dans l'étape d'exécution, et le modèle doit être réutilisé entre les exécutions. En Java, c'est un bloc statique. La bonne chose est que nous ne sommes pas limités par le délai d'attente de la fonction pour charger le modèle et que la durée maximale d'initialisation est de 15 minutes.
  • [Snap-Resilient] SnapStart a des limitations spécifiques — l'unicité puisque SnapStart utilise un instantané. Cela signifie par exemple que si une graine aléatoire est définie pendant la phase d'initialisation, toutes les invocations lambda auront le même générateur. En savoir plus sur la façon de rendre Lambda résilient ici .

Un exemple avec ONNX et SnapStart est disponible publiquement ici et peut être utilisé avec Sam pour déployer le point de terminaison ONNX Inception V3 et le tester.

Pour mettre en évidence l'architecture du SnapStart en cas d'ONNX :

  • onnxSession — a un modèle préchargé et est réutilisé entre les appels.
  • getOnnxSession - charge le modèle s'il n'a pas été chargé auparavant et l'ignore s'il a été utilisé chargé auparavant.
  • bloc statique — exécute le code lors de la création de SnapStart. C'est la partie importante - le code dans le gestionnaire ne sera pas exécuté lors de la création de l'instantané.
  • package onnxsnapstart;
    
    /**
     * Handler for Onnx predictions on Lambda function.
     */
    public class App implements RequestHandler<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> {
    
        // Onnx session with preloaded model which will be reused between invocations and will be
        // initialized as part of snapshot creation
        private static OrtSession onnxSession;
    
        // Returns Onnx session with preloaded model. Reuses existing session if exists.
        private static OrtSession getOnnxSession() {
            String modelPath = "inception_v3.onnx";
            if (onnxSession==null) {
              System.out.println("Start model load");
              try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath");
                OrtSession.SessionOptions options = new SessionOptions()) {
              try {
                OrtSession session = env.createSession(modelPath, options);
                Map<String, NodeInfo> inputInfoList = session.getInputInfo();
                Map<String, NodeInfo> outputInfoList = session.getOutputInfo();
                System.out.println(inputInfoList);
                System.out.println(outputInfoList);
                onnxSession = session;
                return onnxSession;
              }
              catch(OrtException exc) {
                exc.printStackTrace();
              }
            }
            }
            return onnxSession;
        }
    
        // This code runs during snapshot initialization. In the normal lambda that would run in init phase.
        static {
            System.out.println("Start model init");
            getOnnxSession();
            System.out.println("Finished model init");
        }
    
        // Main handler for the Lambda
        public APIGatewayProxyResponseEvent handleRequest(final APIGatewayProxyRequestEvent input, final Context context) {
            Map<String, String> headers = new HashMap<>();
            headers.put("Content-Type", "application/json");
            headers.put("X-Custom-Header", "application/json");
    
    
            float[][][][] testData = new float[1][3][299][299];
    
            try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath")) {
                OnnxTensor test = OnnxTensor.createTensor(env, testData);
                OrtSession session = getOnnxSession();
                String inputName = session.getInputNames().iterator().next();
                Result output = session.run(Collections.singletonMap(inputName, test));
                System.out.println(output);
            }
            catch(OrtException exc) {
                exc.printStackTrace();
            }
    
    
            APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent().withHeaders(headers);
            String output = String.format("{ \"message\": \"made prediction\" }");
    
            return response
                    .withStatusCode(200)
                    .withBody(output);
        }
    }
    

  • Lambda traditionnel
  • Picked up JAVA_TOOL_OPTIONS: -XX:+TieredCompilation -XX:TieredStopAtLevel=1
    Start model init
    Start model load
    {x.1=NodeInfo(name=x.1,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 3, 299, 299]))}
    {924=NodeInfo(name=924,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 1000]))}
    Finished model init
    START RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Version: $LATEST
    ai.onnxruntime.OrtSession$Result@e580929
    END RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a
    REPORT RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Duration: 244.99 ms Billed Duration: 245 ms Memory Size: 1769 MB Max Memory Used: 531 MB Init Duration: 8615.62 ms
    

    RESTORE_START Runtime Version: java:11.v15 Runtime Version ARN: arn:aws:lambda:us-east-1::runtime:0a25e3e7a1cc9ce404bc435eeb2ad358d8fa64338e618d0c224fe509403583ca
    RESTORE_REPORT Restore Duration: 571.67 ms
    START RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Version: 1
    ai.onnxruntime.OrtSession$Result@47f6473
    END RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029
    REPORT RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Duration: 496.51 ms Billed Duration: 645 ms Memory Size: 1769 MB Max Memory Used: 342 MB Restore Duration: 571.67 ms
    

Nous avons encore une latence supplémentaire due à la restauration de l'instantané, mais maintenant notre queue est considérablement raccourcie et nous n'avons pas de requêtes qui prendraient plus de 2,5 secondes.

  • Lambda traditionnel
  • Percentage of the requests served within a certain time (ms)
      50%    352
      66%    377
      75%    467
      80%    473
      90%    488
      95%   9719
      98%  10329
      99%  10419
     100%  12825
    

    50%    365
      66%    445
      75%    477
      80%    487
      90%    556
      95%   1392
      98%   2233
      99%   2319
     100%   2589 (longest request)