GradientTape menghitung arti-penting dalam fungsi kerugian
Saya mencoba membangun jaringan LSTM untuk mengklasifikasikan kalimat dan memberikan penjelasan untuk klasifikasi menggunakan saliency . Jaringan ini harus belajar dari kelas yang sebenarnya y_true
serta dari kata-kata mana yang tidak boleh dia perhatikan Z
(topeng biner).
Makalah ini menginspirasi kami untuk menemukan fungsi kerugian kami. Saya ingin fungsi kerugian saya terlihat seperti ini:
Coût de classification
menerjemahkan ke classification_loss
dan Coût d'explication (saillance)
ke saliency_loss
(yang sama dengan gradien keluaran wrt masukan) dalam kode di bawah ini . Saya mencoba menerapkan ini dengan Model kustom di Keras, dengan Tensorflow sebagai backend:
loss_tracker = metrics.Mean(name="loss")
classification_loss_tracker = metrics.Mean(name="classification_loss")
saliency_loss_tracker = metrics.Mean(name="saliency_loss")
accuracy_tracker = metrics.CategoricalAccuracy(name="accuracy")
class CustomSequentialModel(Sequential):
def _train_test_step(self, data, training):
# Unpack the data
X = data[0]["X"]
Z = data[0]["Z"] # binary mask (1 for important words)
y_true = data[1]
# gradient tape requires "float32" instead of "int32"
# X.shape = (None, MAX_SEQUENCE_LENGTH, EMBEDDING_DIM)
X = tf.cast(X, tf.float32)
# Persitent=True because we call the `gradient` more than once
with GradientTape(persistent=True) as tape:
# The tape will record everything that happens to X
# for automatic differentiation later on (used to compute saliency)
tape.watch(X)
# Forward pass
y_pred = self(X, training=training)
# (1) Compute the classification_loss
classification_loss = K.mean(
categorical_crossentropy(y_true, y_pred)
)
# (2) Compute the saliency loss
# (2.1) Compute the gradient of output wrt the maximum probability
log_prediction_proba = K.log(K.max(y_pred))
# (2.2) Compute the gradient of the output wrt the input
# saliency.shape is (None, MAX_SEQUENCE_LENGTH, None)
# why isn't it (None, MAX_SEQUENCE_LENGTH, EMBEDDING_DIM) ?!
saliency = tape.gradient(log_prediction_proba, X)
# (2.3) Sum along the embedding dimension
saliency = K.sum(saliency, axis=2)
# (2.4) Sum with the binary mask
saliency_loss = K.sum(K.square(saliency)*(1-Z))
# => ValueError: No gradients provided for any variable
loss = classification_loss + saliency_loss
trainable_vars = self.trainable_variables
# ValueError caused by the '+ saliency_loss'
gradients = tape.gradient(loss, trainable_vars)
del tape # garbage collection
if training:
# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# Update metrics
saliency_loss_tracker.update_state(saliency_loss)
classification_loss_tracker.update_state(classification_loss)
loss_tracker.update_state(loss)
accuracy_tracker.update_state(y_true, y_pred)
# Return a dict mapping metric names to current value
return {m.name: m.result() for m in self.metrics}
def train_step(self, data):
return self._train_test_step(data, True)
def test_step(self, data):
return self._train_test_step(data, False)
@property
def metrics(self):
return [
loss_tracker,
classification_loss_tracker,
saliency_loss_tracker,
accuracy_tracker
]
Saya berhasil menghitung classification_loss
serta saliency_loss
dan saya mendapatkan nilai skalar. Namun, ini berhasil: tape.gradient(classification_loss, trainable_vars)
tetapi ini tidaktape.gradient(classification_loss + saliency_loss, trainable_vars)
dan melempar ValueError: No gradients provided for any variable
.
Jawaban
Anda melakukan penghitungan di luar konteks rekaman (setelah gradient
panggilan pertama ) dan kemudian mencoba mengambil lebih banyak gradien setelahnya. Ini tidak berhasil; semua operasi untuk membedakan perlu terjadi di dalam manajer konteks. Saya akan menyarankan untuk merestrukturisasi kode Anda sebagai berikut, menggunakan dua kaset bersarang:
with GradientTape() as loss_tape:
with GradientTape() as saliency_tape:
# The tape will record everything that happens to X
# for automatic differentiation later on (used to compute saliency)
saliency_tape.watch(X)
# Forward pass
y_pred = self(X, training=training)
# (2) Compute the saliency loss
# (2.1) Compute the gradient of output wrt the maximum probability
log_prediction_proba = K.log(K.max(y_pred))
# (2.2) Compute the gradient of the output wrt the input
# saliency.shape is (None, MAX_SEQUENCE_LENGTH, None)
# why isn't it (None, MAX_SEQUENCE_LENGTH, EMBEDDING_DIM) ?!
saliency = saliency_tape.gradient(log_prediction_proba, X)
# (2.3) Sum along the embedding dimension
saliency = K.sum(saliency, axis=2)
# (2.4) Sum with the binary mask
saliency_loss = K.sum(K.square(saliency)*(1-Z))
# (1) Compute the classification_loss
classification_loss = K.mean(
categorical_crossentropy(y_true, y_pred)
)
loss = classification_loss + saliency_loss
trainable_vars = self.trainable_variables
gradients = loss_tape.gradient(loss, trainable_vars)
Sekarang kita memiliki satu pita yang bertanggung jawab untuk menghitung gradien dengan masukan untuk arti-penting. Kami memiliki pita lain di sekitarnya yang melacak operasi tersebut dan kemudian dapat menghitung gradien dari gradien (yaitu gradien arti-penting). Rekaman ini juga menghitung gradien untuk kerugian klasifikasi. Saya memindahkan kerugian klasifikasi dalam konteks pita luar karena pita bagian dalam tidak membutuhkannya. Perhatikan juga bagaimana bahkan penambahan dua kerugian berada di dalam konteks pita luar - semuanya harus terjadi di sana, jika tidak grafik komputasi hilang / tidak lengkap dan gradien tidak dapat dihitung.
Cobalah untuk menghias train_step()
dengan@tf.function