AWS Lambda SnapStart กำจัด Cold Start สำหรับการอนุมานของ Machine Learning แบบไร้เซิร์ฟเวอร์ได้อย่างไร

Nov 29 2022
ความท้าทายในการเริ่มต้นแบบเย็นสำหรับการอนุมาน ML หนึ่งในความท้าทายหลักกับการอนุมานการเรียนรู้ของเครื่องแบบไร้เซิร์ฟเวอร์คือการเริ่มต้นแบบเย็นเสมอ และในกรณีของการอนุมาน ML มีหลายสิ่งที่สนับสนุน: คุณสมบัติ SnapStart ด้วยคุณสมบัติ SnapStart ที่เพิ่งประกาศใหม่สำหรับ AWS Lambda cold start จะถูกแทนที่ด้วย SnapStart

ความท้าทายในการเริ่มต้นเย็นสำหรับการอนุมาน ML

หนึ่งในความท้าทายหลักของการอนุมานการเรียนรู้ของเครื่องแบบไร้เซิร์ฟเวอร์คือการเริ่มต้นที่เย็นชาเสมอ และในกรณีของการอนุมาน ML มีหลายสิ่งที่สนับสนุน:

  • การเริ่มต้นรันไทม์
  • กำลังโหลดไลบรารีและการขึ้นต่อกัน
  • กำลังโหลดโมเดลเอง (จาก S3 หรือแพ็คเกจ)
  • กำลังเริ่มต้นโมเดล

คุณสมบัติ SnapStart

ด้วย คุณสมบัติ SnapStartที่ประกาศใหม่สำหรับ AWS Lambda Cold Start จะถูกแทนที่ด้วย SnapStart AWS Lambda จะสร้างสแน็ปช็อตของหน่วยความจำและสถานะดิสก์ที่เข้ารหัสซึ่งเปลี่ยนรูปไม่ได้ และจะแคชเพื่อใช้ซ้ำ สแน็ปช็อตนี้จะมีการโหลดโมเดล ML ไว้ในหน่วยความจำและพร้อมใช้งาน

สิ่งที่ควรทราบ:

  • [Java runtime] ขณะนี้ SnapStart รองรับเฉพาะ Java runtime เท่านั้น นั่นเป็นการเพิ่มข้อจำกัด แต่ ONNX ทำงานบน Java และเป็นไปได้ที่จะเรียกใช้ ONNX ด้วย SnapStart
  • [การโหลดแบบจำลอง] การโหลดแบบจำลองจะต้องเกิดขึ้นภายในขั้นตอนการเริ่มต้น ไม่ใช่ขั้นตอนการเรียกใช้ และควรใช้แบบจำลองซ้ำระหว่างการรัน ใน java มันเป็นบล็อกแบบคงที่ สิ่งที่ดีคือเราไม่ถูกจำกัดด้วยฟังก์ชันไทม์เอาต์ในการโหลดโมเดล และจำนวนการเริ่มต้นสูงสุดคือ 15 นาที
  • [Snap-Resilient] SnapStart มีข้อจำกัดเฉพาะ — ความเป็นเอกลักษณ์เนื่องจาก SnapStart ใช้สแน็ปช็อต หมายความว่าหากมีการกำหนดเมล็ดแบบสุ่มในช่วงเริ่มต้น การเรียกใช้แลมบ์ดาทั้งหมดจะมีตัวสร้างเดียวกัน อ่านเพิ่มเติมเกี่ยวกับวิธีการทำให้แลมบ์ดามีความยืดหยุ่นได้ที่นี่

ตัวอย่างของ ONNX และ SnapStart มีให้บริการแบบสาธารณะที่นี่และสามารถใช้กับ Sam เพื่อปรับใช้ตำแหน่งข้อมูล ONNX Inception V3 และทดสอบได้

หากต้องการเน้นสถาปัตยกรรมสำหรับ SnapStart ในกรณีของ ONNX:

  • onnxSession — มีโมเดลที่โหลดไว้ล่วงหน้าและถูกนำมาใช้ซ้ำระหว่างการเรียกใช้
  • getOnnxSession — โหลดโมเดลหากไม่เคยโหลดมาก่อน และข้ามไปหากเคยโหลดมาก่อน
  • บล็อกแบบคงที่ — เรียกใช้โค้ดระหว่างการสร้าง SnapStart นี่เป็นส่วนที่สำคัญ — โค้ดในตัวจัดการจะไม่ถูกเรียกใช้ในระหว่างการสร้างสแน็ปช็อต
  • package onnxsnapstart;
    
    /**
     * Handler for Onnx predictions on Lambda function.
     */
    public class App implements RequestHandler<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> {
    
        // Onnx session with preloaded model which will be reused between invocations and will be
        // initialized as part of snapshot creation
        private static OrtSession onnxSession;
    
        // Returns Onnx session with preloaded model. Reuses existing session if exists.
        private static OrtSession getOnnxSession() {
            String modelPath = "inception_v3.onnx";
            if (onnxSession==null) {
              System.out.println("Start model load");
              try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath");
                OrtSession.SessionOptions options = new SessionOptions()) {
              try {
                OrtSession session = env.createSession(modelPath, options);
                Map<String, NodeInfo> inputInfoList = session.getInputInfo();
                Map<String, NodeInfo> outputInfoList = session.getOutputInfo();
                System.out.println(inputInfoList);
                System.out.println(outputInfoList);
                onnxSession = session;
                return onnxSession;
              }
              catch(OrtException exc) {
                exc.printStackTrace();
              }
            }
            }
            return onnxSession;
        }
    
        // This code runs during snapshot initialization. In the normal lambda that would run in init phase.
        static {
            System.out.println("Start model init");
            getOnnxSession();
            System.out.println("Finished model init");
        }
    
        // Main handler for the Lambda
        public APIGatewayProxyResponseEvent handleRequest(final APIGatewayProxyRequestEvent input, final Context context) {
            Map<String, String> headers = new HashMap<>();
            headers.put("Content-Type", "application/json");
            headers.put("X-Custom-Header", "application/json");
    
    
            float[][][][] testData = new float[1][3][299][299];
    
            try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath")) {
                OnnxTensor test = OnnxTensor.createTensor(env, testData);
                OrtSession session = getOnnxSession();
                String inputName = session.getInputNames().iterator().next();
                Result output = session.run(Collections.singletonMap(inputName, test));
                System.out.println(output);
            }
            catch(OrtException exc) {
                exc.printStackTrace();
            }
    
    
            APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent().withHeaders(headers);
            String output = String.format("{ \"message\": \"made prediction\" }");
    
            return response
                    .withStatusCode(200)
                    .withBody(output);
        }
    }
    

  • แลมบ์ดาแบบดั้งเดิม
  • Picked up JAVA_TOOL_OPTIONS: -XX:+TieredCompilation -XX:TieredStopAtLevel=1
    Start model init
    Start model load
    {x.1=NodeInfo(name=x.1,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 3, 299, 299]))}
    {924=NodeInfo(name=924,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 1000]))}
    Finished model init
    START RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Version: $LATEST
    ai.onnxruntime.OrtSession$Result@e580929
    END RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a
    REPORT RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Duration: 244.99 ms Billed Duration: 245 ms Memory Size: 1769 MB Max Memory Used: 531 MB Init Duration: 8615.62 ms
    

    RESTORE_START Runtime Version: java:11.v15 Runtime Version ARN: arn:aws:lambda:us-east-1::runtime:0a25e3e7a1cc9ce404bc435eeb2ad358d8fa64338e618d0c224fe509403583ca
    RESTORE_REPORT Restore Duration: 571.67 ms
    START RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Version: 1
    ai.onnxruntime.OrtSession$Result@47f6473
    END RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029
    REPORT RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Duration: 496.51 ms Billed Duration: 645 ms Memory Size: 1769 MB Max Memory Used: 342 MB Restore Duration: 571.67 ms
    

เรายังมีเวลาแฝงเพิ่มเติมเนื่องจากการกู้คืนสแนปชอต แต่ตอนนี้ส่วนท้ายของเราสั้นลงมาก และเราไม่มีคำขอที่จะใช้เวลาเกิน 2.5 วินาที

  • แลมบ์ดาแบบดั้งเดิม
  • Percentage of the requests served within a certain time (ms)
      50%    352
      66%    377
      75%    467
      80%    473
      90%    488
      95%   9719
      98%  10329
      99%  10419
     100%  12825
    

    50%    365
      66%    445
      75%    477
      80%    487
      90%    556
      95%   1392
      98%   2233
      99%   2319
     100%   2589 (longest request)