Bagaimana AWS Lambda SnapStart menghilangkan cold start untuk Inferensi Pembelajaran Mesin Tanpa Server

Nov 29 2022
Tantangan dengan awal yang dingin untuk inferensi ML Salah satu tantangan utama dengan Inferensi Pembelajaran Mesin Tanpa Server selalu merupakan awal yang dingin. Dan dalam kasus inferensi ML, ada beberapa hal yang berkontribusi terhadapnya: Fitur SnapStart Dengan fitur SnapStart yang baru diumumkan untuk cold start AWS Lambda diganti dengan SnapStart.

Tantangan dengan awal yang dingin untuk inferensi ML

Salah satu tantangan utama dengan Inferensi Pembelajaran Mesin Tanpa Server selalu merupakan awal yang dingin. Dan dalam kasus inferensi ML, ada beberapa hal yang berkontribusi terhadapnya:

  • inisialisasi runtime
  • memuat pustaka dan dependensi
  • memuat model itu sendiri (dari S3 atau paket)
  • menginisialisasi model

Fitur SnapStart

Dengan fitur SnapStart yang baru diumumkan untuk cold start AWS Lambda diganti dengan SnapStart. AWS Lambda akan membuat snapshot memori dan status disk yang tidak dapat diubah dan dienkripsi, dan akan meng-cache-nya untuk digunakan kembali. Snapshot ini akan membuat model ML dimuat di memori dan siap digunakan.

Hal-hal yang perlu diingat:

  • [Java runtime] SnapStart saat ini hanya didukung untuk Java runtime. Itu menambah batasan, tetapi ONNX berfungsi di Java dan dimungkinkan untuk menjalankan ONNX dengan SnapStart.
  • [Pemuatan model] Pemuatan model harus terjadi dalam langkah inisialisasi, bukan langkah menjalankan dan model harus digunakan kembali di antara proses. Di java, ini adalah blok statis. Hal baiknya adalah kita tidak dibatasi oleh batas waktu fungsi untuk memuat model dan jumlah maksimum inisialisasi adalah 15 menit.
  • [Snap-Resilient] SnapStart memiliki batasan khusus — keunikan karena SnapStart menggunakan snapshot. Ini berarti misalnya jika seed acak ditentukan selama fase init maka semua pemanggilan lambda akan memiliki generator yang sama. Baca selengkapnya tentang cara membuat Lambda tangguh di sini .

Contoh dengan ONNX dan SnapStart tersedia untuk umum di sini dan dapat digunakan dengan Sam untuk menerapkan titik akhir ONNX Inception V3 dan mengujinya.

Untuk menyorot arsitektur SnapStart dalam kasus ONNX:

  • onnxSession — memiliki model pramuat dan digunakan kembali di antara pemanggilan.
  • getOnnxSession — muat model jika tidak dimuat sebelumnya dan lewati jika digunakan dimuat sebelumnya.
  • blok statis — jalankan kode selama pembuatan SnapStart. Ini adalah bagian yang penting — kode dalam handler tidak akan dijalankan selama pembuatan snapshot.
  • package onnxsnapstart;
    
    /**
     * Handler for Onnx predictions on Lambda function.
     */
    public class App implements RequestHandler<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> {
    
        // Onnx session with preloaded model which will be reused between invocations and will be
        // initialized as part of snapshot creation
        private static OrtSession onnxSession;
    
        // Returns Onnx session with preloaded model. Reuses existing session if exists.
        private static OrtSession getOnnxSession() {
            String modelPath = "inception_v3.onnx";
            if (onnxSession==null) {
              System.out.println("Start model load");
              try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath");
                OrtSession.SessionOptions options = new SessionOptions()) {
              try {
                OrtSession session = env.createSession(modelPath, options);
                Map<String, NodeInfo> inputInfoList = session.getInputInfo();
                Map<String, NodeInfo> outputInfoList = session.getOutputInfo();
                System.out.println(inputInfoList);
                System.out.println(outputInfoList);
                onnxSession = session;
                return onnxSession;
              }
              catch(OrtException exc) {
                exc.printStackTrace();
              }
            }
            }
            return onnxSession;
        }
    
        // This code runs during snapshot initialization. In the normal lambda that would run in init phase.
        static {
            System.out.println("Start model init");
            getOnnxSession();
            System.out.println("Finished model init");
        }
    
        // Main handler for the Lambda
        public APIGatewayProxyResponseEvent handleRequest(final APIGatewayProxyRequestEvent input, final Context context) {
            Map<String, String> headers = new HashMap<>();
            headers.put("Content-Type", "application/json");
            headers.put("X-Custom-Header", "application/json");
    
    
            float[][][][] testData = new float[1][3][299][299];
    
            try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath")) {
                OnnxTensor test = OnnxTensor.createTensor(env, testData);
                OrtSession session = getOnnxSession();
                String inputName = session.getInputNames().iterator().next();
                Result output = session.run(Collections.singletonMap(inputName, test));
                System.out.println(output);
            }
            catch(OrtException exc) {
                exc.printStackTrace();
            }
    
    
            APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent().withHeaders(headers);
            String output = String.format("{ \"message\": \"made prediction\" }");
    
            return response
                    .withStatusCode(200)
                    .withBody(output);
        }
    }
    

  • Lambda tradisional
  • Picked up JAVA_TOOL_OPTIONS: -XX:+TieredCompilation -XX:TieredStopAtLevel=1
    Start model init
    Start model load
    {x.1=NodeInfo(name=x.1,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 3, 299, 299]))}
    {924=NodeInfo(name=924,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 1000]))}
    Finished model init
    START RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Version: $LATEST
    ai.onnxruntime.OrtSession$Result@e580929
    END RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a
    REPORT RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Duration: 244.99 ms Billed Duration: 245 ms Memory Size: 1769 MB Max Memory Used: 531 MB Init Duration: 8615.62 ms
    

    RESTORE_START Runtime Version: java:11.v15 Runtime Version ARN: arn:aws:lambda:us-east-1::runtime:0a25e3e7a1cc9ce404bc435eeb2ad358d8fa64338e618d0c224fe509403583ca
    RESTORE_REPORT Restore Duration: 571.67 ms
    START RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Version: 1
    ai.onnxruntime.OrtSession$Result@47f6473
    END RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029
    REPORT RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Duration: 496.51 ms Billed Duration: 645 ms Memory Size: 1769 MB Max Memory Used: 342 MB Restore Duration: 571.67 ms
    

Kami masih memiliki latensi tambahan karena pemulihan snapshot, tetapi sekarang ekor kami dipersingkat secara signifikan dan kami tidak memiliki permintaan yang membutuhkan waktu lebih dari 2,5 detik.

  • Lambda tradisional
  • Percentage of the requests served within a certain time (ms)
      50%    352
      66%    377
      75%    467
      80%    473
      90%    488
      95%   9719
      98%  10329
      99%  10419
     100%  12825
    

    50%    365
      66%    445
      75%    477
      80%    487
      90%    556
      95%   1392
      98%   2233
      99%   2319
     100%   2589 (longest request)