AWS Lambda SnapStart กำจัด Cold Start สำหรับการอนุมานของ Machine Learning แบบไร้เซิร์ฟเวอร์ได้อย่างไร
ความท้าทายในการเริ่มต้นเย็นสำหรับการอนุมาน ML
หนึ่งในความท้าทายหลักของการอนุมานการเรียนรู้ของเครื่องแบบไร้เซิร์ฟเวอร์คือการเริ่มต้นที่เย็นชาเสมอ และในกรณีของการอนุมาน ML มีหลายสิ่งที่สนับสนุน:
- การเริ่มต้นรันไทม์
- กำลังโหลดไลบรารีและการขึ้นต่อกัน
- กำลังโหลดโมเดลเอง (จาก S3 หรือแพ็คเกจ)
- กำลังเริ่มต้นโมเดล
คุณสมบัติ SnapStart
ด้วย คุณสมบัติ SnapStartที่ประกาศใหม่สำหรับ AWS Lambda Cold Start จะถูกแทนที่ด้วย SnapStart AWS Lambda จะสร้างสแน็ปช็อตของหน่วยความจำและสถานะดิสก์ที่เข้ารหัสซึ่งเปลี่ยนรูปไม่ได้ และจะแคชเพื่อใช้ซ้ำ สแน็ปช็อตนี้จะมีการโหลดโมเดล ML ไว้ในหน่วยความจำและพร้อมใช้งาน
สิ่งที่ควรทราบ:
- [Java runtime] ขณะนี้ SnapStart รองรับเฉพาะ Java runtime เท่านั้น นั่นเป็นการเพิ่มข้อจำกัด แต่ ONNX ทำงานบน Java และเป็นไปได้ที่จะเรียกใช้ ONNX ด้วย SnapStart
- [การโหลดแบบจำลอง] การโหลดแบบจำลองจะต้องเกิดขึ้นภายในขั้นตอนการเริ่มต้น ไม่ใช่ขั้นตอนการเรียกใช้ และควรใช้แบบจำลองซ้ำระหว่างการรัน ใน java มันเป็นบล็อกแบบคงที่ สิ่งที่ดีคือเราไม่ถูกจำกัดด้วยฟังก์ชันไทม์เอาต์ในการโหลดโมเดล และจำนวนการเริ่มต้นสูงสุดคือ 15 นาที
- [Snap-Resilient] SnapStart มีข้อจำกัดเฉพาะ — ความเป็นเอกลักษณ์เนื่องจาก SnapStart ใช้สแน็ปช็อต หมายความว่าหากมีการกำหนดเมล็ดแบบสุ่มในช่วงเริ่มต้น การเรียกใช้แลมบ์ดาทั้งหมดจะมีตัวสร้างเดียวกัน อ่านเพิ่มเติมเกี่ยวกับวิธีการทำให้แลมบ์ดามีความยืดหยุ่นได้ที่นี่
ตัวอย่างของ ONNX และ SnapStart มีให้บริการแบบสาธารณะที่นี่และสามารถใช้กับ Sam เพื่อปรับใช้ตำแหน่งข้อมูล ONNX Inception V3 และทดสอบได้
หากต้องการเน้นสถาปัตยกรรมสำหรับ SnapStart ในกรณีของ ONNX:
- onnxSession — มีโมเดลที่โหลดไว้ล่วงหน้าและถูกนำมาใช้ซ้ำระหว่างการเรียกใช้
- getOnnxSession — โหลดโมเดลหากไม่เคยโหลดมาก่อน และข้ามไปหากเคยโหลดมาก่อน
- บล็อกแบบคงที่ — เรียกใช้โค้ดระหว่างการสร้าง SnapStart นี่เป็นส่วนที่สำคัญ — โค้ดในตัวจัดการจะไม่ถูกเรียกใช้ในระหว่างการสร้างสแน็ปช็อต
package onnxsnapstart;
/**
* Handler for Onnx predictions on Lambda function.
*/
public class App implements RequestHandler<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> {
// Onnx session with preloaded model which will be reused between invocations and will be
// initialized as part of snapshot creation
private static OrtSession onnxSession;
// Returns Onnx session with preloaded model. Reuses existing session if exists.
private static OrtSession getOnnxSession() {
String modelPath = "inception_v3.onnx";
if (onnxSession==null) {
System.out.println("Start model load");
try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath");
OrtSession.SessionOptions options = new SessionOptions()) {
try {
OrtSession session = env.createSession(modelPath, options);
Map<String, NodeInfo> inputInfoList = session.getInputInfo();
Map<String, NodeInfo> outputInfoList = session.getOutputInfo();
System.out.println(inputInfoList);
System.out.println(outputInfoList);
onnxSession = session;
return onnxSession;
}
catch(OrtException exc) {
exc.printStackTrace();
}
}
}
return onnxSession;
}
// This code runs during snapshot initialization. In the normal lambda that would run in init phase.
static {
System.out.println("Start model init");
getOnnxSession();
System.out.println("Finished model init");
}
// Main handler for the Lambda
public APIGatewayProxyResponseEvent handleRequest(final APIGatewayProxyRequestEvent input, final Context context) {
Map<String, String> headers = new HashMap<>();
headers.put("Content-Type", "application/json");
headers.put("X-Custom-Header", "application/json");
float[][][][] testData = new float[1][3][299][299];
try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath")) {
OnnxTensor test = OnnxTensor.createTensor(env, testData);
OrtSession session = getOnnxSession();
String inputName = session.getInputNames().iterator().next();
Result output = session.run(Collections.singletonMap(inputName, test));
System.out.println(output);
}
catch(OrtException exc) {
exc.printStackTrace();
}
APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent().withHeaders(headers);
String output = String.format("{ \"message\": \"made prediction\" }");
return response
.withStatusCode(200)
.withBody(output);
}
}
Picked up JAVA_TOOL_OPTIONS: -XX:+TieredCompilation -XX:TieredStopAtLevel=1
Start model init
Start model load
{x.1=NodeInfo(name=x.1,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 3, 299, 299]))}
{924=NodeInfo(name=924,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 1000]))}
Finished model init
START RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Version: $LATEST
ai.onnxruntime.OrtSession$Result@e580929
END RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a
REPORT RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Duration: 244.99 ms Billed Duration: 245 ms Memory Size: 1769 MB Max Memory Used: 531 MB Init Duration: 8615.62 ms
RESTORE_START Runtime Version: java:11.v15 Runtime Version ARN: arn:aws:lambda:us-east-1::runtime:0a25e3e7a1cc9ce404bc435eeb2ad358d8fa64338e618d0c224fe509403583ca
RESTORE_REPORT Restore Duration: 571.67 ms
START RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Version: 1
ai.onnxruntime.OrtSession$Result@47f6473
END RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029
REPORT RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Duration: 496.51 ms Billed Duration: 645 ms Memory Size: 1769 MB Max Memory Used: 342 MB Restore Duration: 571.67 ms
เรายังมีเวลาแฝงเพิ่มเติมเนื่องจากการกู้คืนสแนปชอต แต่ตอนนี้ส่วนท้ายของเราสั้นลงมาก และเราไม่มีคำขอที่จะใช้เวลาเกิน 2.5 วินาที
- แลมบ์ดาแบบดั้งเดิม
Percentage of the requests served within a certain time (ms)
50% 352
66% 377
75% 467
80% 473
90% 488
95% 9719
98% 10329
99% 10419
100% 12825
50% 365
66% 445
75% 477
80% 487
90% 556
95% 1392
98% 2233
99% 2319
100% 2589 (longest request)