Como o AWS Lambda SnapStart elimina inicializações a frio para inferência de aprendizado de máquina sem servidor

Desafios com inicialização a frio para inferência de ML
Um dos principais desafios da Serverless Machine Learning Inference sempre foi uma inicialização a frio. E no caso da inferência de ML, há várias coisas que contribuem para isso:
- inicialização do tempo de execução
- carregando bibliotecas e dependências
- carregando o próprio modelo (do S3 ou pacote)
- inicializando o modelo
Recurso SnapStart
Com o recém-anunciado recurso SnapStart para AWS Lambda, o cold start é substituído pelo SnapStart. O AWS Lambda criará um snapshot criptografado e imutável da memória e do estado do disco e o armazenará em cache para reutilização. Este instantâneo terá o modelo de ML carregado na memória e pronto para uso.
Coisas a ter em mente:
- [Java runtime] No momento, o SnapStart é compatível apenas com o Java runtime. Isso adiciona limitações, mas o ONNX funciona em Java e é possível executar o ONNX com o SnapStart.
- [Carga do modelo] A carga do modelo deve ocorrer na etapa de inicialização, não na etapa de execução, e o modelo deve ser reutilizado entre as execuções. Em java, é um bloco estático. O bom é que não estamos limitados pelo timeout da função para carregar o modelo e o tempo máximo de inicialização é de 15 minutos.
- [Snap-Resilient] O SnapStart tem limitações específicas - exclusividade, pois o SnapStart usa instantâneo. Isso significa, por exemplo, que se uma semente aleatória for definida durante a fase de inicialização, todas as invocações de lambda terão o mesmo gerador. Leia mais sobre como tornar o Lambda resiliente aqui .
Um exemplo com ONNX e SnapStart está disponível publicamente aqui e pode ser usado com Sam para implantar o endpoint ONNX Inception V3 e testá-lo.
Para destacar a arquitetura do SnapStart no caso de ONNX:
- onnxSession — tem um modelo pré-carregado e é reutilizado entre invocações.
- getOnnxSession — carrega o modelo se não foi carregado antes e o ignora se foi usado carregado antes.
- bloco estático — execute o código durante a criação do SnapStart. Esta é a parte importante — o código no manipulador não será executado durante a criação do instantâneo.
package onnxsnapstart;
/**
* Handler for Onnx predictions on Lambda function.
*/
public class App implements RequestHandler<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> {
// Onnx session with preloaded model which will be reused between invocations and will be
// initialized as part of snapshot creation
private static OrtSession onnxSession;
// Returns Onnx session with preloaded model. Reuses existing session if exists.
private static OrtSession getOnnxSession() {
String modelPath = "inception_v3.onnx";
if (onnxSession==null) {
System.out.println("Start model load");
try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath");
OrtSession.SessionOptions options = new SessionOptions()) {
try {
OrtSession session = env.createSession(modelPath, options);
Map<String, NodeInfo> inputInfoList = session.getInputInfo();
Map<String, NodeInfo> outputInfoList = session.getOutputInfo();
System.out.println(inputInfoList);
System.out.println(outputInfoList);
onnxSession = session;
return onnxSession;
}
catch(OrtException exc) {
exc.printStackTrace();
}
}
}
return onnxSession;
}
// This code runs during snapshot initialization. In the normal lambda that would run in init phase.
static {
System.out.println("Start model init");
getOnnxSession();
System.out.println("Finished model init");
}
// Main handler for the Lambda
public APIGatewayProxyResponseEvent handleRequest(final APIGatewayProxyRequestEvent input, final Context context) {
Map<String, String> headers = new HashMap<>();
headers.put("Content-Type", "application/json");
headers.put("X-Custom-Header", "application/json");
float[][][][] testData = new float[1][3][299][299];
try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath")) {
OnnxTensor test = OnnxTensor.createTensor(env, testData);
OrtSession session = getOnnxSession();
String inputName = session.getInputNames().iterator().next();
Result output = session.run(Collections.singletonMap(inputName, test));
System.out.println(output);
}
catch(OrtException exc) {
exc.printStackTrace();
}
APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent().withHeaders(headers);
String output = String.format("{ \"message\": \"made prediction\" }");
return response
.withStatusCode(200)
.withBody(output);
}
}
Picked up JAVA_TOOL_OPTIONS: -XX:+TieredCompilation -XX:TieredStopAtLevel=1
Start model init
Start model load
{x.1=NodeInfo(name=x.1,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 3, 299, 299]))}
{924=NodeInfo(name=924,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 1000]))}
Finished model init
START RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Version: $LATEST
ai.onnxruntime.OrtSession$Result@e580929
END RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a
REPORT RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Duration: 244.99 ms Billed Duration: 245 ms Memory Size: 1769 MB Max Memory Used: 531 MB Init Duration: 8615.62 ms
RESTORE_START Runtime Version: java:11.v15 Runtime Version ARN: arn:aws:lambda:us-east-1::runtime:0a25e3e7a1cc9ce404bc435eeb2ad358d8fa64338e618d0c224fe509403583ca
RESTORE_REPORT Restore Duration: 571.67 ms
START RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Version: 1
ai.onnxruntime.OrtSession$Result@47f6473
END RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029
REPORT RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Duration: 496.51 ms Billed Duration: 645 ms Memory Size: 1769 MB Max Memory Used: 342 MB Restore Duration: 571.67 ms
Ainda temos latência adicional devido à restauração do instantâneo, mas agora nossa cauda está significativamente reduzida e não temos solicitações que levariam mais de 2,5 segundos.
- Lambda Tradicional
Percentage of the requests served within a certain time (ms)
50% 352
66% 377
75% 467
80% 473
90% 488
95% 9719
98% 10329
99% 10419
100% 12825
50% 365
66% 445
75% 477
80% 487
90% 556
95% 1392
98% 2233
99% 2319
100% 2589 (longest request)