Cómo AWS Lambda SnapStart elimina los arranques en frío para la inferencia de aprendizaje automático sin servidor

Nov 29 2022
Desafíos con el inicio en frío para la inferencia de ML Uno de los principales desafíos con la inferencia de aprendizaje automático sin servidor siempre fue un inicio en frío. Y en el caso de la inferencia de ML, hay varias cosas que contribuyen a ello: Función SnapStart Con la función SnapStart recientemente anunciada para AWS Lambda, el arranque en frío se reemplaza por SnapStart.

Desafíos con el arranque en frío para la inferencia de ML

Uno de los principales desafíos con la inferencia de aprendizaje automático sin servidor siempre fue un comienzo en frío. Y en el caso de la inferencia de ML, hay varias cosas que contribuyen a ello:

  • inicialización del tiempo de ejecución
  • cargando bibliotecas y dependencias
  • cargando el propio modelo (desde S3 o paquete)
  • inicializando el modelo

Función de inicio rápido

Con la función SnapStart recientemente anunciada para AWS Lambda, el arranque en frío se reemplaza por SnapStart. AWS Lambda creará una instantánea cifrada e inmutable de la memoria y el estado del disco, y la almacenará en caché para su reutilización. Esta instantánea tendrá el modelo ML cargado en la memoria y listo para usar.

Cosas a tener en cuenta:

  • [Tiempo de ejecución de Java] Actualmente, SnapStart solo es compatible con el tiempo de ejecución de Java. Eso agrega limitaciones, pero ONNX funciona en Java y es posible ejecutar ONNX con SnapStart.
  • [Carga del modelo] La carga del modelo debe ocurrir en el paso de inicialización, no en el paso de ejecución, y el modelo debe reutilizarse entre ejecuciones. En Java, es un bloque estático. Lo bueno es que no estamos limitados por el tiempo de espera de la función para cargar el modelo y la cantidad máxima de inicialización es de 15 minutos.
  • [Snap-Resilient] SnapStart tiene limitaciones específicas: singularidad ya que SnapStart usa instantáneas. Significa, por ejemplo, que si se define una semilla aleatoria durante la fase de inicio, todas las invocaciones lambda tendrán el mismo generador. Obtenga más información sobre cómo hacer que Lambda sea resistente aquí .

Un ejemplo con ONNX y SnapStart está disponible públicamente aquí y se puede usar con Sam para implementar el punto final de ONNX Inception V3 y probarlo.

Para resaltar la arquitectura de SnapStart en el caso de ONNX:

  • onnxSession: tiene un modelo precargado y se reutiliza entre invocaciones.
  • getOnnxSession: carga el modelo si no se cargó antes y lo omite si se usó cargado antes.
  • bloque estático: ejecute el código durante la creación de SnapStart. Esta es la parte importante: el código del controlador no se ejecutará durante la creación de la instantánea.
  • package onnxsnapstart;
    
    /**
     * Handler for Onnx predictions on Lambda function.
     */
    public class App implements RequestHandler<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> {
    
        // Onnx session with preloaded model which will be reused between invocations and will be
        // initialized as part of snapshot creation
        private static OrtSession onnxSession;
    
        // Returns Onnx session with preloaded model. Reuses existing session if exists.
        private static OrtSession getOnnxSession() {
            String modelPath = "inception_v3.onnx";
            if (onnxSession==null) {
              System.out.println("Start model load");
              try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath");
                OrtSession.SessionOptions options = new SessionOptions()) {
              try {
                OrtSession session = env.createSession(modelPath, options);
                Map<String, NodeInfo> inputInfoList = session.getInputInfo();
                Map<String, NodeInfo> outputInfoList = session.getOutputInfo();
                System.out.println(inputInfoList);
                System.out.println(outputInfoList);
                onnxSession = session;
                return onnxSession;
              }
              catch(OrtException exc) {
                exc.printStackTrace();
              }
            }
            }
            return onnxSession;
        }
    
        // This code runs during snapshot initialization. In the normal lambda that would run in init phase.
        static {
            System.out.println("Start model init");
            getOnnxSession();
            System.out.println("Finished model init");
        }
    
        // Main handler for the Lambda
        public APIGatewayProxyResponseEvent handleRequest(final APIGatewayProxyRequestEvent input, final Context context) {
            Map<String, String> headers = new HashMap<>();
            headers.put("Content-Type", "application/json");
            headers.put("X-Custom-Header", "application/json");
    
    
            float[][][][] testData = new float[1][3][299][299];
    
            try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath")) {
                OnnxTensor test = OnnxTensor.createTensor(env, testData);
                OrtSession session = getOnnxSession();
                String inputName = session.getInputNames().iterator().next();
                Result output = session.run(Collections.singletonMap(inputName, test));
                System.out.println(output);
            }
            catch(OrtException exc) {
                exc.printStackTrace();
            }
    
    
            APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent().withHeaders(headers);
            String output = String.format("{ \"message\": \"made prediction\" }");
    
            return response
                    .withStatusCode(200)
                    .withBody(output);
        }
    }
    

  • Lambda tradicional
  • Picked up JAVA_TOOL_OPTIONS: -XX:+TieredCompilation -XX:TieredStopAtLevel=1
    Start model init
    Start model load
    {x.1=NodeInfo(name=x.1,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 3, 299, 299]))}
    {924=NodeInfo(name=924,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 1000]))}
    Finished model init
    START RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Version: $LATEST
    ai.onnxruntime.OrtSession$Result@e580929
    END RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a
    REPORT RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Duration: 244.99 ms Billed Duration: 245 ms Memory Size: 1769 MB Max Memory Used: 531 MB Init Duration: 8615.62 ms
    

    RESTORE_START Runtime Version: java:11.v15 Runtime Version ARN: arn:aws:lambda:us-east-1::runtime:0a25e3e7a1cc9ce404bc435eeb2ad358d8fa64338e618d0c224fe509403583ca
    RESTORE_REPORT Restore Duration: 571.67 ms
    START RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Version: 1
    ai.onnxruntime.OrtSession$Result@47f6473
    END RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029
    REPORT RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Duration: 496.51 ms Billed Duration: 645 ms Memory Size: 1769 MB Max Memory Used: 342 MB Restore Duration: 571.67 ms
    

Todavía tenemos una latencia adicional debido a la restauración de la instantánea, pero ahora nuestra cola se ha reducido significativamente y no tenemos solicitudes que tarden más de 2,5 segundos.

  • Lambda tradicional
  • Percentage of the requests served within a certain time (ms)
      50%    352
      66%    377
      75%    467
      80%    473
      90%    488
      95%   9719
      98%  10329
      99%  10419
     100%  12825
    

    50%    365
      66%    445
      75%    477
      80%    487
      90%    556
      95%   1392
      98%   2233
      99%   2319
     100%   2589 (longest request)