Как AWS Lambda SnapStart устраняет холодный запуск для бессерверного машинного обучения

Nov 29 2022
Проблемы с холодным запуском для вывода машинного обучения Одной из основных проблем с бессерверным выводом машинного обучения всегда был холодный запуск. А в случае вывода машинного обучения этому способствуют несколько факторов: Функция SnapStart Недавно анонсированная функция SnapStart для холодного запуска AWS Lambda заменена на SnapStart.

Проблемы с холодным стартом для вывода ML

Одной из основных проблем с Serverless Machine Learning Inference всегда был холодный старт. И в случае вывода ML этому способствуют несколько вещей:

  • инициализация во время выполнения
  • загрузка библиотек и зависимостей
  • загрузка самой модели (из S3 или пакета)
  • инициализация модели

Функция мгновенного запуска

С недавно анонсированной функцией SnapStart для AWS холодный запуск Lambda заменен на SnapStart. AWS Lambda создаст неизменяемый зашифрованный снимок состояния памяти и диска и кэширует его для повторного использования. В этом снимке модель машинного обучения будет загружена в память и готова к использованию.

Что нужно иметь в виду:

  • [Среда выполнения Java] В настоящее время SnapStart поддерживается только для среды выполнения Java. Это добавляет ограничений, но ONNX работает на Java, и ONNX можно запускать с помощью SnapStart.
  • [Загрузка модели] Загрузка модели должна происходить на этапе инициализации, а не на этапе запуска, и модель должна повторно использоваться между запусками. В Java это статический блок. Хорошо, что мы не ограничены временем ожидания функции для загрузки модели, а максимальное время инициализации составляет 15 минут.
  • [Snap-Resilient] SnapStart имеет определенные ограничения — уникальность, так как SnapStart использует снимок. Это означает, например, что если на этапе инициализации определено случайное начальное число, то все вызовы лямбда-выражений будут иметь один и тот же генератор. Подробнее о том, как сделать Lambda устойчивым , читайте здесь .

Пример с ONNX и SnapStart общедоступен здесь, и его можно использовать с Сэмом для развертывания конечной точки ONNX Inception V3 и ее тестирования.

Чтобы выделить архитектуру SnapStart в случае ONNX:

  • onnxSession — имеет предварительно загруженную модель и повторно используется между вызовами.
  • getOnnxSession — загружает модель, если она не была загружена ранее, и пропускает ее, если она была загружена ранее.
  • статический блок — запускать код во время создания SnapStart. Это важная часть — код в обработчике не будет выполняться во время создания снимка.
  • package onnxsnapstart;
    
    /**
     * Handler for Onnx predictions on Lambda function.
     */
    public class App implements RequestHandler<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> {
    
        // Onnx session with preloaded model which will be reused between invocations and will be
        // initialized as part of snapshot creation
        private static OrtSession onnxSession;
    
        // Returns Onnx session with preloaded model. Reuses existing session if exists.
        private static OrtSession getOnnxSession() {
            String modelPath = "inception_v3.onnx";
            if (onnxSession==null) {
              System.out.println("Start model load");
              try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath");
                OrtSession.SessionOptions options = new SessionOptions()) {
              try {
                OrtSession session = env.createSession(modelPath, options);
                Map<String, NodeInfo> inputInfoList = session.getInputInfo();
                Map<String, NodeInfo> outputInfoList = session.getOutputInfo();
                System.out.println(inputInfoList);
                System.out.println(outputInfoList);
                onnxSession = session;
                return onnxSession;
              }
              catch(OrtException exc) {
                exc.printStackTrace();
              }
            }
            }
            return onnxSession;
        }
    
        // This code runs during snapshot initialization. In the normal lambda that would run in init phase.
        static {
            System.out.println("Start model init");
            getOnnxSession();
            System.out.println("Finished model init");
        }
    
        // Main handler for the Lambda
        public APIGatewayProxyResponseEvent handleRequest(final APIGatewayProxyRequestEvent input, final Context context) {
            Map<String, String> headers = new HashMap<>();
            headers.put("Content-Type", "application/json");
            headers.put("X-Custom-Header", "application/json");
    
    
            float[][][][] testData = new float[1][3][299][299];
    
            try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath")) {
                OnnxTensor test = OnnxTensor.createTensor(env, testData);
                OrtSession session = getOnnxSession();
                String inputName = session.getInputNames().iterator().next();
                Result output = session.run(Collections.singletonMap(inputName, test));
                System.out.println(output);
            }
            catch(OrtException exc) {
                exc.printStackTrace();
            }
    
    
            APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent().withHeaders(headers);
            String output = String.format("{ \"message\": \"made prediction\" }");
    
            return response
                    .withStatusCode(200)
                    .withBody(output);
        }
    }
    

  • Традиционная лямбда
  • Picked up JAVA_TOOL_OPTIONS: -XX:+TieredCompilation -XX:TieredStopAtLevel=1
    Start model init
    Start model load
    {x.1=NodeInfo(name=x.1,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 3, 299, 299]))}
    {924=NodeInfo(name=924,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 1000]))}
    Finished model init
    START RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Version: $LATEST
    ai.onnxruntime.OrtSession$Result@e580929
    END RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a
    REPORT RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Duration: 244.99 ms Billed Duration: 245 ms Memory Size: 1769 MB Max Memory Used: 531 MB Init Duration: 8615.62 ms
    

    RESTORE_START Runtime Version: java:11.v15 Runtime Version ARN: arn:aws:lambda:us-east-1::runtime:0a25e3e7a1cc9ce404bc435eeb2ad358d8fa64338e618d0c224fe509403583ca
    RESTORE_REPORT Restore Duration: 571.67 ms
    START RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Version: 1
    ai.onnxruntime.OrtSession$Result@47f6473
    END RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029
    REPORT RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Duration: 496.51 ms Billed Duration: 645 ms Memory Size: 1769 MB Max Memory Used: 342 MB Restore Duration: 571.67 ms
    

У нас по-прежнему есть дополнительная задержка из-за восстановления снапшота, но теперь наш хвост значительно укорачивается, и у нас нет запросов, которые занимали бы более 2,5 секунд.

  • Традиционная лямбда
  • Percentage of the requests served within a certain time (ms)
      50%    352
      66%    377
      75%    467
      80%    473
      90%    488
      95%   9719
      98%  10329
      99%  10419
     100%  12825
    

    50%    365
      66%    445
      75%    477
      80%    487
      90%    556
      95%   1392
      98%   2233
      99%   2319
     100%   2589 (longest request)