ग्रैडिएंटटेप लॉस फ़ंक्शन में सामर्थ्य की गणना करता है

Dec 14 2020

मैं वाक्यों को वर्गीकृत करने के लिए एक LSTM नेटवर्क बनाने की कोशिश कर रहा हूं और सामर्थ्य का उपयोग करके वर्गीकरण के लिए स्पष्टीकरण प्रदान कर रहा हूं । इस नेटवर्क को सही वर्ग के y_trueसाथ-साथ उन शब्दों से भी सीखना चाहिए जिन पर उसे ध्यान नहीं देना चाहिए Z(बाइनरी मास्क)।

इस पत्र ने हमें अपने नुकसान के कार्य के साथ आने के लिए प्रेरित किया। यहाँ मैं अपने नुकसान की तरह दिखना चाहता हूँ:

Coût de classificationकरने के लिए अनुवाद classification_lossऔर Coût d'explication (saillance)करने के लिए saliency_lossनीचे दिए गए कोड में (जो इनपुट wrt उत्पादन की ढाल के रूप में एक ही है) । मैंने इसे कर्स में एक कस्टम मॉडल के साथ लागू करने की कोशिश की, बैकेंड के रूप में टेन्सरफ्लो के साथ:

loss_tracker = metrics.Mean(name="loss")
classification_loss_tracker = metrics.Mean(name="classification_loss")
saliency_loss_tracker = metrics.Mean(name="saliency_loss")
accuracy_tracker = metrics.CategoricalAccuracy(name="accuracy")

class CustomSequentialModel(Sequential):
        
    def _train_test_step(self, data, training):
        # Unpack the data
        X = data[0]["X"]
        Z = data[0]["Z"] # binary mask (1 for important words)
        y_true = data[1]
        
        # gradient tape requires "float32" instead of "int32"
        # X.shape = (None, MAX_SEQUENCE_LENGTH, EMBEDDING_DIM)
        X = tf.cast(X, tf.float32)

        # Persitent=True because we call the `gradient` more than once
        with GradientTape(persistent=True) as tape:
            # The tape will record everything that happens to X
            # for automatic differentiation later on (used to compute saliency)
            tape.watch(X)
            # Forward pass
            y_pred = self(X, training=training) 
            
            # (1) Compute the classification_loss
            classification_loss = K.mean(
                categorical_crossentropy(y_true, y_pred)
            )
 
            # (2) Compute the saliency loss
            # (2.1) Compute the gradient of output wrt the maximum probability
            log_prediction_proba = K.log(K.max(y_pred))
            
        # (2.2) Compute the gradient of the output wrt the input
        # saliency.shape is (None, MAX_SEQUENCE_LENGTH, None)
        # why isn't it (None, MAX_SEQUENCE_LENGTH, EMBEDDING_DIM) ?!
        saliency = tape.gradient(log_prediction_proba, X)
        # (2.3) Sum along the embedding dimension
        saliency = K.sum(saliency, axis=2)
        # (2.4) Sum with the binary mask
        saliency_loss = K.sum(K.square(saliency)*(1-Z))
        # =>  ValueError: No gradients provided for any variable
        loss = classification_loss + saliency_loss 
        
        trainable_vars = self.trainable_variables
        # ValueError caused by the '+ saliency_loss'
        gradients = tape.gradient(loss, trainable_vars) 
        del tape # garbage collection
        
        if training:
            # Update weights
            self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        
        # Update metrics
        saliency_loss_tracker.update_state(saliency_loss)
        classification_loss_tracker.update_state(classification_loss)
        loss_tracker.update_state(loss)
        accuracy_tracker.update_state(y_true, y_pred)
        
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
    
    def train_step(self, data):
        return self._train_test_step(data, True)
    
    def test_step(self, data):
        return self._train_test_step(data, False)
    
    @property
    def metrics(self):
        return [
            loss_tracker,
            classification_loss_tracker,
            saliency_loss_tracker,
            accuracy_tracker
        ]

मैं और classification_lossसाथ ही गणना करने का प्रबंधन करता saliency_lossहूं और मुझे एक स्केलर मूल्य मिलता है। हालाँकि, यह काम करता है: tape.gradient(classification_loss, trainable_vars)लेकिन यहtape.gradient(classification_loss + saliency_loss, trainable_vars) और नहीं फेंकता है ValueError: No gradients provided for any variable

जवाब

1 xdurch0 Dec 14 2020 at 07:21

आप टेप संदर्भ (पहली gradientकॉल के बाद ) के बाहर गणना कर रहे हैं और फिर बाद में अधिक ग्रेडिएंट लेने की कोशिश कर रहे हैं। यह काम नहीं करता है; अलग-अलग करने के लिए सभी ऑपरेशन संदर्भ प्रबंधक के अंदर होने की जरूरत है। मैं दो नीडिंत टेपों का उपयोग करते हुए आपके कोड को इस प्रकार पुनर्गठन करना चाहूंगा:

with GradientTape() as loss_tape:
    with GradientTape() as saliency_tape:
        # The tape will record everything that happens to X
        # for automatic differentiation later on (used to compute saliency)
        saliency_tape.watch(X)
        # Forward pass
        y_pred = self(X, training=training) 
        
        # (2) Compute the saliency loss
        # (2.1) Compute the gradient of output wrt the maximum probability
        log_prediction_proba = K.log(K.max(y_pred))
        
    # (2.2) Compute the gradient of the output wrt the input
    # saliency.shape is (None, MAX_SEQUENCE_LENGTH, None)
    # why isn't it (None, MAX_SEQUENCE_LENGTH, EMBEDDING_DIM) ?!
    saliency = saliency_tape.gradient(log_prediction_proba, X)
    # (2.3) Sum along the embedding dimension
    saliency = K.sum(saliency, axis=2)
    # (2.4) Sum with the binary mask
    saliency_loss = K.sum(K.square(saliency)*(1-Z))

    # (1) Compute the classification_loss
    classification_loss = K.mean(
        categorical_crossentropy(y_true, y_pred)
    )

    loss = classification_loss + saliency_loss 
    
trainable_vars = self.trainable_variables
gradients = loss_tape.gradient(loss, trainable_vars)

अब हमारे पास एक टेप है जो ग्रैडिएंट्स की गणना करने के लिए जिम्मेदार है, जो सामर्थ्य के लिए इनपुट लिखता है। हमारे पास इसके चारों ओर एक और टेप है, जो उन ऑपरेशनों को ट्रैक करता है और बाद में ग्रेडिएंट के ढाल (यानी ढाल का ढाल) की गणना कर सकता है। यह टेप वर्गीकरण नुकसान के लिए ग्रेडिएंट्स की गणना भी करता है। मैंने बाहरी टेप के संदर्भ में वर्गीकरण को नुकसान पहुंचाया क्योंकि आंतरिक टेप को इसकी आवश्यकता नहीं है। यह भी ध्यान दें कि बाहरी टेप के संदर्भ के अंदर दो नुकसानों के अलावा भी कैसे - सब कुछ वहां होना है, अन्यथा गणना ग्राफ खो गया है / अपूर्ण है और ग्रेडिएंट की गणना नहीं की जा सकती है।

Andrey Dec 14 2020 at 00:31

सजाने के लिए प्रयास करें train_step()साथ@tf.function