GradientTape menghitung arti-penting dalam fungsi kerugian

Dec 14 2020

Saya mencoba membangun jaringan LSTM untuk mengklasifikasikan kalimat dan memberikan penjelasan untuk klasifikasi menggunakan saliency . Jaringan ini harus belajar dari kelas yang sebenarnya y_trueserta dari kata-kata mana yang tidak boleh dia perhatikan Z(topeng biner).

Makalah ini menginspirasi kami untuk menemukan fungsi kerugian kami. Saya ingin fungsi kerugian saya terlihat seperti ini:

Coût de classificationmenerjemahkan ke classification_lossdan Coût d'explication (saillance)ke saliency_loss(yang sama dengan gradien keluaran wrt masukan) dalam kode di bawah ini . Saya mencoba menerapkan ini dengan Model kustom di Keras, dengan Tensorflow sebagai backend:

loss_tracker = metrics.Mean(name="loss")
classification_loss_tracker = metrics.Mean(name="classification_loss")
saliency_loss_tracker = metrics.Mean(name="saliency_loss")
accuracy_tracker = metrics.CategoricalAccuracy(name="accuracy")

class CustomSequentialModel(Sequential):
        
    def _train_test_step(self, data, training):
        # Unpack the data
        X = data[0]["X"]
        Z = data[0]["Z"] # binary mask (1 for important words)
        y_true = data[1]
        
        # gradient tape requires "float32" instead of "int32"
        # X.shape = (None, MAX_SEQUENCE_LENGTH, EMBEDDING_DIM)
        X = tf.cast(X, tf.float32)

        # Persitent=True because we call the `gradient` more than once
        with GradientTape(persistent=True) as tape:
            # The tape will record everything that happens to X
            # for automatic differentiation later on (used to compute saliency)
            tape.watch(X)
            # Forward pass
            y_pred = self(X, training=training) 
            
            # (1) Compute the classification_loss
            classification_loss = K.mean(
                categorical_crossentropy(y_true, y_pred)
            )
 
            # (2) Compute the saliency loss
            # (2.1) Compute the gradient of output wrt the maximum probability
            log_prediction_proba = K.log(K.max(y_pred))
            
        # (2.2) Compute the gradient of the output wrt the input
        # saliency.shape is (None, MAX_SEQUENCE_LENGTH, None)
        # why isn't it (None, MAX_SEQUENCE_LENGTH, EMBEDDING_DIM) ?!
        saliency = tape.gradient(log_prediction_proba, X)
        # (2.3) Sum along the embedding dimension
        saliency = K.sum(saliency, axis=2)
        # (2.4) Sum with the binary mask
        saliency_loss = K.sum(K.square(saliency)*(1-Z))
        # =>  ValueError: No gradients provided for any variable
        loss = classification_loss + saliency_loss 
        
        trainable_vars = self.trainable_variables
        # ValueError caused by the '+ saliency_loss'
        gradients = tape.gradient(loss, trainable_vars) 
        del tape # garbage collection
        
        if training:
            # Update weights
            self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        
        # Update metrics
        saliency_loss_tracker.update_state(saliency_loss)
        classification_loss_tracker.update_state(classification_loss)
        loss_tracker.update_state(loss)
        accuracy_tracker.update_state(y_true, y_pred)
        
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
    
    def train_step(self, data):
        return self._train_test_step(data, True)
    
    def test_step(self, data):
        return self._train_test_step(data, False)
    
    @property
    def metrics(self):
        return [
            loss_tracker,
            classification_loss_tracker,
            saliency_loss_tracker,
            accuracy_tracker
        ]

Saya berhasil menghitung classification_lossserta saliency_lossdan saya mendapatkan nilai skalar. Namun, ini berhasil: tape.gradient(classification_loss, trainable_vars)tetapi ini tidaktape.gradient(classification_loss + saliency_loss, trainable_vars) dan melempar ValueError: No gradients provided for any variable.

Jawaban

1 xdurch0 Dec 14 2020 at 07:21

Anda melakukan penghitungan di luar konteks rekaman (setelah gradientpanggilan pertama ) dan kemudian mencoba mengambil lebih banyak gradien setelahnya. Ini tidak berhasil; semua operasi untuk membedakan perlu terjadi di dalam manajer konteks. Saya akan menyarankan untuk merestrukturisasi kode Anda sebagai berikut, menggunakan dua kaset bersarang:

with GradientTape() as loss_tape:
    with GradientTape() as saliency_tape:
        # The tape will record everything that happens to X
        # for automatic differentiation later on (used to compute saliency)
        saliency_tape.watch(X)
        # Forward pass
        y_pred = self(X, training=training) 
        
        # (2) Compute the saliency loss
        # (2.1) Compute the gradient of output wrt the maximum probability
        log_prediction_proba = K.log(K.max(y_pred))
        
    # (2.2) Compute the gradient of the output wrt the input
    # saliency.shape is (None, MAX_SEQUENCE_LENGTH, None)
    # why isn't it (None, MAX_SEQUENCE_LENGTH, EMBEDDING_DIM) ?!
    saliency = saliency_tape.gradient(log_prediction_proba, X)
    # (2.3) Sum along the embedding dimension
    saliency = K.sum(saliency, axis=2)
    # (2.4) Sum with the binary mask
    saliency_loss = K.sum(K.square(saliency)*(1-Z))

    # (1) Compute the classification_loss
    classification_loss = K.mean(
        categorical_crossentropy(y_true, y_pred)
    )

    loss = classification_loss + saliency_loss 
    
trainable_vars = self.trainable_variables
gradients = loss_tape.gradient(loss, trainable_vars)

Sekarang kita memiliki satu pita yang bertanggung jawab untuk menghitung gradien dengan masukan untuk arti-penting. Kami memiliki pita lain di sekitarnya yang melacak operasi tersebut dan kemudian dapat menghitung gradien dari gradien (yaitu gradien arti-penting). Rekaman ini juga menghitung gradien untuk kerugian klasifikasi. Saya memindahkan kerugian klasifikasi dalam konteks pita luar karena pita bagian dalam tidak membutuhkannya. Perhatikan juga bagaimana bahkan penambahan dua kerugian berada di dalam konteks pita luar - semuanya harus terjadi di sana, jika tidak grafik komputasi hilang / tidak lengkap dan gradien tidak dapat dihitung.

Andrey Dec 14 2020 at 00:31

Cobalah untuk menghias train_step()dengan@tf.function