SNRAdam: การปรับปรุง Adam Optimizer

Apr 25 2023
เครื่องมือเพิ่มประสิทธิภาพเป็นเครื่องมือที่จำเป็นในชุดการสร้างแบบจำลอง หนึ่งในเครื่องมือเพิ่มประสิทธิภาพที่ใช้กันอย่างแพร่หลายคือเครื่องมือเพิ่มประสิทธิภาพ Adam ที่แนะนำโดย Kingma และ Ba [เอกสาร]

เครื่องมือเพิ่มประสิทธิภาพเป็นเครื่องมือที่จำเป็นในชุดการสร้างแบบจำลอง หนึ่งในเครื่องมือเพิ่มประสิทธิภาพที่ใช้กันอย่างแพร่หลายคือเครื่องมือเพิ่มประสิทธิภาพ Adam ที่แนะนำโดย Kingma และ Ba [ paper ] เครื่องมือเพิ่มประสิทธิภาพนี้ติดตามการทำงานของค่าเฉลี่ยของการไล่ระดับสี (หรือที่เรียกว่าโมเมนตัม) และโมเมนตัมที่สอง (หรือที่เรียกว่าพลังงาน) ของการไล่ระดับสีโดยใช้ตัวกรองค่าเฉลี่ยเคลื่อนที่แบบเอกซ์โพเนนเชียล (EMA) และใช้รากที่สองของเทอมพลังงานเพื่อทำให้เทอมโมเมนตัมเป็นปกติ ก่อนที่จะก้าว

การนำ Adam ไปใช้จากเอกสาร PyTorch

ดูเหมือนจะเป็นความคิดที่ดีใช่ไหม? สำหรับหนึ่ง ดูเหมือนการประมาณในแนวทแยงของการเพิ่มประสิทธิภาพอันดับสอง และสอง เมื่อการไล่ระดับสีมีเสียงดัง ตัวส่วน (รากที่สองของเทอมพลังงาน) จะมีขนาดใหญ่เมื่อเทียบกับตัวเศษ (เทอมโมเมนตัม) และขั้นตอนมีขนาดเล็ก ในทางกลับกัน เมื่อการไล่ระดับสีสอดคล้องกัน ตัวส่วนจะเท่ากับตัวเศษโดยประมาณ และเราจะใช้ขนาดขั้นคงที่เท่ากับอัตราการเรียนรู้ γ นั่นเป็นเหตุผลที่เครื่องมือเพิ่มประสิทธิภาพนี้เป็นตัวเลือกโดยพฤตินัยในหมู่นักวิจัยและผู้ปฏิบัติงานด้าน ML

มีการเสนอและศึกษาเครื่องมือเพิ่มประสิทธิภาพ Adam หลายรูปแบบ (ดู AdamW จากเอกสาร PyTorchและเครื่องมือเพิ่มประสิทธิภาพเช่น QHAdam จากแพ็คเกจ torch_optimizer ) ที่นี่ เราเสนอรูปแบบที่ไม่เบี่ยงเบนจาก Adam อย่างมีนัยสำคัญสำหรับการไล่ระดับสี "ล่าสุด" ที่มีเสียงดัง แต่ขยายขนาดขั้นตอนจริงอย่างมากสำหรับพารามิเตอร์ที่มีประวัติการไล่ระดับสี "ล่าสุด" ที่สอดคล้องกัน (ล่าสุดขึ้นอยู่กับพารามิเตอร์ EMA β1 และ β2)

การปรับเปลี่ยนของเรานั้นเรียบง่ายแต่ได้ผล: เราแทนที่ตัวกรอง EMA ของคำศัพท์พลังงานไล่ระดับสีด้วยตัวกรอง EMA ของคำศัพท์ความแปรปรวนของการไล่ระดับสี นี่หมายถึงสมการการอัปเดตขั้นสุดท้าย θ(t) = θ(t-1) - γ * sqrt(SNR) * เครื่องหมาย(โมเมนตัม) โดยที่ SNR หมายถึงอัตราส่วนสัญญาณต่อสัญญาณรบกวนของประวัติการไล่ระดับสี "ล่าสุด" (SNR คือ คำที่ยืมมาจากวรรณคดีการประมวลผลสัญญาณและอ้างถึงอัตราส่วนของค่าเฉลี่ยของพลังงานในสัญญาณที่ไม่มีเสียงรบกวนต่อความแปรปรวนของสัญญาณรบกวนในสัญญาณ) ดังนั้น พารามิเตอร์ที่มี SNR สูง เช่น ที่มีประวัติการไล่ระดับสี "ล่าสุด" ที่สอดคล้องกัน จะเห็นขนาดสเต็ปที่ใหญ่กว่ามาก (ใหญ่เท่ากับอินฟินิตี้หากการไล่ระดับสีคงที่) กว่าพารามิเตอร์ที่มีประวัติการไล่ระดับสี "ล่าสุด" ที่มีเสียงดัง (หรือ SNR การไล่ระดับสีต่ำ) . การใช้งานเครื่องมือเพิ่มประสิทธิภาพนี้ทำได้ง่ายและแสดงไว้ด้านล่างเพื่อความสมบูรณ์

from typing import Tuple

import torch
from torch.optim.optimizer import Optimizer


class SNRAdam(Optimizer):
    r"""Implements the SNRAdam optimization algorithm, which uses std deviation for the denominator rather than
    sqrt(energy) term used in conventional Adam. Why is this a good idea? If gradient stddev for a param is small, we
    should take larger steps as it means the gradient is consistent over time.

    Arguments:
        params: iterable of parameters to optimize or dicts defining
            parameter groups
        lr: learning rate (default: 1e-3)
        betas: coefficients used for computing
            running averages of gradient and its variance (default: (0.9, 0.999))
        eps: term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay: weight decay (L2 penalty) (default: 0)
    """

    def __init__(
            self,
            params,
            lr: float = 1e-3,
            betas: Tuple[float, float] = (0.9, 0.999),
            weight_decay: float = 0.0,
            eps: float = 1e-8,
    ):
        if lr <= 0.0:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if eps < 0.0:
            raise ValueError('Invalid epsilon value: {}'.format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                'Invalid beta parameter at index 0: {}'.format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                'Invalid beta parameter at index 1: {}'.format(betas[1])
            )
        if weight_decay < 0:
            raise ValueError(
                'Invalid weight_decay value: {}'.format(weight_decay)
            )

        defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure: A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            weight_decay = group['weight_decay']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue

                d_p = p.grad.data
                if d_p.is_sparse:
                    raise RuntimeError(
                        'SNRAdam does not support sparse gradients, '
                        'please consider SparseAdam instead'
                    )

                state = self.state[p]

                if weight_decay != 0:
                    p.data.mul_(1 - lr * weight_decay)

                if len(state) == 0:
                    state['iter_'] = 1
                    state['exp_avg'] = torch.zeros_like(
                        p.data, memory_format=torch.preserve_format
                    )
                    state['exp_avg_sq'] = torch.zeros_like(
                        p.data, memory_format=torch.preserve_format
                    )
                iter_ = state['iter_']
                exp_avg = state['exp_avg']
                if iter_ == 1:
                    d_sub_p_sq = d_p - exp_avg
                else:
                    d_sub_p_sq = d_p - exp_avg.mul(1.0 / (1 - beta1 ** (iter_ - 1)))
                d_sub_p_sq.mul_(d_sub_p_sq)

                exp_avg_sq = state['exp_avg_sq']

                exp_avg.mul_(beta1).add_(d_p, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).add_(d_sub_p_sq, alpha=1.0 - beta2)

                p.data.addcdiv_(exp_avg.mul(1.0 / (1 - beta1 ** iter_)),
                                exp_avg_sq.mul(1.0 / (1 - beta2 ** iter_)).sqrt() + eps, value=-lr)
                state['iter_'] += 1

        return loss

SNRAdam สูญเสียการฝึกอย่างรวดเร็วเมื่อเทียบกับ Adam
การสูญเสียการตรวจสอบมาบรรจบกันอย่างรวดเร็วสำหรับ SNRAdam เมื่อเทียบกับ Adam

เพื่อแยกแหล่งที่มาของกำไร เราสร้างขนาดแบตช์ = ∞ และเปรียบเทียบอัลกอริทึมทั้งสอง สิ่งนี้แสดงให้เห็นว่ากำไรมาจาก (i) การแก้ไขสำหรับสุ่มในการไล่ระดับสีแบบ "สุ่ม" เช่น ขนาดแบทช์ที่เล็กกว่าขนาดชุดข้อมูลเต็ม หรือ (ii) สัญญาณรบกวนในการไล่ระดับสีที่มาจากวิถีการปรับให้เหมาะสม (ส่วนการไล่ระดับสีไล่ระดับสี ). เราเห็นว่ากำไรมาจากการชดเชยสำหรับ (ii):

การเปรียบเทียบอัลกอริธึมทั้งสองสำหรับขนาดแบตช์ = ขนาดชุดข้อมูล (รถไฟ)
การเปรียบเทียบอัลกอริธึมทั้งสองสำหรับขนาดแบตช์ = ขนาดชุดข้อมูล (การตรวจสอบความถูกต้อง)