los gradientes de tensorflow v2 no se muestran en los histogramas de tensorboard

Aug 21 2020

Tengo una red neuronal simple para la que estoy tratando de trazar los gradientes usando tensorboard usando una devolución de llamada como se muestra a continuación:

class GradientCallback(tf.keras.callbacks.Callback):
    """Keras callback meant to log per-weight gradient histograms for TensorBoard.

    NOTE(review): the tf.summary.histogram loop below runs OUTSIDE the
    summary-writer `with` block (the context closes before the loop), so in
    TF2 the calls have no default writer (and no `step`) and are silently
    dropped -- this is why the Histograms tab stays empty.
    """
    console = False   # if True, also echo gradient values to stdout
    count = 0         # running histogram tag counter (shared across epochs)
    run_count = 0     # one "run-N" subdirectory per epoch

    def on_epoch_end(self, epoch, logs=None):
        # NOTE(review): `weights` is computed but never used afterwards.
        weights = [w for w in self.model.trainable_weights if 'dense' in w.name and 'bias' in w.name]
        self.run_count += 1
        # `logdir` is a notebook-level global defined in a later cell.
        run_dir = logdir+"/gradients/run-" + str(self.run_count)
        with tf.summary.create_file_writer(run_dir).as_default(),tf.GradientTape() as g:
          # use test data to calculate the gradients
          _x_batch = test_images_scaled_reshaped[:100]
          _y_batch = test_labels_enc[:100]
          g.watch(_x_batch)
          _y_pred = self.model(_x_batch)  # forward-propagation
          per_sample_losses = tf.keras.losses.categorical_crossentropy(_y_batch, _y_pred) 
          average_loss = tf.reduce_mean(per_sample_losses) # Compute the loss value
          gradients = g.gradient(average_loss, self.model.weights) # Compute the gradient

        # NOTE(review): the writer context has already exited at this point.
        for t in gradients:
          tf.summary.histogram(str(self.count), data=t)
          self.count+=1
          if self.console:
                print('Tensor: {}'.format(t.name))
                print('{}\n'.format(K.get_value(t)[:10]))

# Set up logging
# NOTE: IPython/Colab shell magic -- wipes any previous TensorBoard logs.
!rm -rf ./logs/ # clear old logs
from datetime import datetime
import os
root_logdir = "logs"
# Timestamped run id so every training run gets its own log directory.
run_id = datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join(root_logdir, run_id)


# Register callbacks; these will be passed to fit() for TensorBoard later.
# NOTE(review): `write_grads` was deprecated and is ignored by TF2's
# TensorBoard callback, so it does not produce gradient histograms here.
callbacks = [
    tf.keras.callbacks.TensorBoard( log_dir=logdir, histogram_freq=1, 
                                   write_images=True, write_grads = True ),
    GradientCallback()
]

Luego, uso las devoluciones de llamada durante el ajuste como:

network.fit(train_pipe, epochs = epochs,batch_size = batch_size, validation_data = val_pipe, callbacks=callbacks)

Ahora, cuando reviso el tensorboard, puedo ver gradientes en el filtro del lado izquierdo, pero no aparece nada en la pestaña Histograma:

¿Qué me estoy perdiendo aquí? ¿Estoy registrando los gradientes correctamente?

Respuestas

J.G Feb 09 2021 at 18:16

Parece que el problema es que escribe sus histogramas fuera del contexto del escritor de resumen tf. Cambié tu código en consecuencia. Pero no lo probé.

class GradientCallback(tf.keras.callbacks.Callback):
    """Log per-variable gradient histograms to TensorBoard after every epoch.

    Fix vs. the question's version: the tf.summary.histogram calls are issued
    *inside* the summary-writer context, so they actually reach the event file
    and appear in the Histograms tab. Two further fixes over the posted answer:
    the missing colon after `as_default()` (a SyntaxError as posted), and the
    `step` argument, which TF2 summaries require.
    """
    console = False   # if True, also print gradient values to stdout
    count = 0         # kept for compatibility with the original; unused here
    run_count = 0     # one "run-N" subdirectory per epoch

    def on_epoch_end(self, epoch, logs=None):
        self.run_count += 1
        # `logdir`, `test_images_scaled_reshaped`, `test_labels_enc` are
        # notebook-level globals defined elsewhere.
        run_dir = logdir + "/gradients/run-" + str(self.run_count)
        writer = tf.summary.create_file_writer(run_dir)
        with writer.as_default():  # colon was missing in the posted answer
            with tf.GradientTape() as g:
                # Use a slice of the test data to compute the gradients.
                _x_batch = test_images_scaled_reshaped[:100]
                _y_batch = test_labels_enc[:100]
                g.watch(_x_batch)
                _y_pred = self.model(_x_batch)  # forward propagation
                per_sample_losses = tf.keras.losses.categorical_crossentropy(_y_batch, _y_pred)
                average_loss = tf.reduce_mean(per_sample_losses)  # scalar loss
                gradients = g.gradient(average_loss, self.model.weights)

            # Still inside the writer context, so the histograms pick up the
            # default writer. TF2 requires an explicit `step` for summaries.
            for nr, grad in enumerate(gradients):
                tf.summary.histogram(str(nr), data=grad, step=epoch)
                if self.console:
                    print('Tensor: {}'.format(grad.name))
                    print('{}\n'.format(K.get_value(grad)[:10]))
        writer.flush()  # ensure events hit disk before TensorBoard reads them