Cómo AWS Lambda SnapStart elimina los arranques en frío para la inferencia de aprendizaje automático sin servidor

Desafíos con el arranque en frío para la inferencia de ML
Uno de los principales desafíos con la inferencia de aprendizaje automático sin servidor siempre fue un comienzo en frío. Y en el caso de la inferencia de ML, hay varias cosas que contribuyen a ello:
- inicialización del tiempo de ejecución
- cargando bibliotecas y dependencias
- cargando el propio modelo (desde S3 o paquete)
- inicializando el modelo
Función de inicio rápido
Con la función SnapStart recientemente anunciada para AWS Lambda, el arranque en frío se reemplaza por SnapStart. AWS Lambda creará una instantánea cifrada e inmutable de la memoria y el estado del disco, y la almacenará en caché para su reutilización. Esta instantánea tendrá el modelo ML cargado en la memoria y listo para usar.
Cosas a tener en cuenta:
- [Tiempo de ejecución de Java] Actualmente, SnapStart solo es compatible con el tiempo de ejecución de Java. Eso agrega limitaciones, pero ONNX funciona en Java y es posible ejecutar ONNX con SnapStart.
- [Carga del modelo] La carga del modelo debe ocurrir en el paso de inicialización, no en el paso de ejecución, y el modelo debe reutilizarse entre ejecuciones. En Java, es un bloque estático. Lo bueno es que no estamos limitados por el tiempo de espera de la función para cargar el modelo y la cantidad máxima de inicialización es de 15 minutos.
- [Snap-Resilient] SnapStart tiene limitaciones específicas: singularidad ya que SnapStart usa instantáneas. Significa, por ejemplo, que si se define una semilla aleatoria durante la fase de inicio, todas las invocaciones lambda tendrán el mismo generador. Obtenga más información sobre cómo hacer que Lambda sea resistente aquí .
Un ejemplo con ONNX y SnapStart está disponible públicamente aquí y se puede usar con Sam para implementar el punto final de ONNX Inception V3 y probarlo.
Para resaltar la arquitectura de SnapStart en el caso de ONNX:
- onnxSession: tiene un modelo precargado y se reutiliza entre invocaciones.
- getOnnxSession: carga el modelo si no se cargó antes y lo omite si se usó cargado antes.
- bloque estático: ejecute el código durante la creación de SnapStart. Esta es la parte importante: el código del controlador no se ejecutará durante la creación de la instantánea.
package onnxsnapstart;
/**
* Handler for Onnx predictions on Lambda function.
*/
public class App implements RequestHandler<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> {
// Onnx session with preloaded model which will be reused between invocations and will be
// initialized as part of snapshot creation
private static OrtSession onnxSession;
// Returns Onnx session with preloaded model. Reuses existing session if exists.
private static OrtSession getOnnxSession() {
String modelPath = "inception_v3.onnx";
if (onnxSession==null) {
System.out.println("Start model load");
try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath");
OrtSession.SessionOptions options = new SessionOptions()) {
try {
OrtSession session = env.createSession(modelPath, options);
Map<String, NodeInfo> inputInfoList = session.getInputInfo();
Map<String, NodeInfo> outputInfoList = session.getOutputInfo();
System.out.println(inputInfoList);
System.out.println(outputInfoList);
onnxSession = session;
return onnxSession;
}
catch(OrtException exc) {
exc.printStackTrace();
}
}
}
return onnxSession;
}
// This code runs during snapshot initialization. In the normal lambda that would run in init phase.
static {
System.out.println("Start model init");
getOnnxSession();
System.out.println("Finished model init");
}
// Main handler for the Lambda
public APIGatewayProxyResponseEvent handleRequest(final APIGatewayProxyRequestEvent input, final Context context) {
Map<String, String> headers = new HashMap<>();
headers.put("Content-Type", "application/json");
headers.put("X-Custom-Header", "application/json");
float[][][][] testData = new float[1][3][299][299];
try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath")) {
OnnxTensor test = OnnxTensor.createTensor(env, testData);
OrtSession session = getOnnxSession();
String inputName = session.getInputNames().iterator().next();
Result output = session.run(Collections.singletonMap(inputName, test));
System.out.println(output);
}
catch(OrtException exc) {
exc.printStackTrace();
}
APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent().withHeaders(headers);
String output = String.format("{ \"message\": \"made prediction\" }");
return response
.withStatusCode(200)
.withBody(output);
}
}
Picked up JAVA_TOOL_OPTIONS: -XX:+TieredCompilation -XX:TieredStopAtLevel=1
Start model init
Start model load
{x.1=NodeInfo(name=x.1,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 3, 299, 299]))}
{924=NodeInfo(name=924,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 1000]))}
Finished model init
START RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Version: $LATEST
ai.onnxruntime.OrtSession$Result@e580929
END RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a
REPORT RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Duration: 244.99 ms Billed Duration: 245 ms Memory Size: 1769 MB Max Memory Used: 531 MB Init Duration: 8615.62 ms
RESTORE_START Runtime Version: java:11.v15 Runtime Version ARN: arn:aws:lambda:us-east-1::runtime:0a25e3e7a1cc9ce404bc435eeb2ad358d8fa64338e618d0c224fe509403583ca
RESTORE_REPORT Restore Duration: 571.67 ms
START RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Version: 1
ai.onnxruntime.OrtSession$Result@47f6473
END RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029
REPORT RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Duration: 496.51 ms Billed Duration: 645 ms Memory Size: 1769 MB Max Memory Used: 342 MB Restore Duration: 571.67 ms
Todavía tenemos una latencia adicional debido a la restauración de la instantánea, pero ahora nuestra cola se ha reducido significativamente y no tenemos solicitudes que tarden más de 2,5 segundos.
- Lambda tradicional
Percentage of the requests served within a certain time (ms)
50% 352
66% 377
75% 467
80% 473
90% 488
95% 9719
98% 10329
99% 10419
100% 12825
50% 365
66% 445
75% 477
80% 487
90% 556
95% 1392
98% 2233
99% 2319
100% 2589 (longest request)