予測を真の値と比較するTensorflowカスタム正則化項
こんにちは私は(バイナリクロスエントロピー)損失関数に追加するカスタム正則化項が必要です。誰かがこれを実装するためのTensorflow構文を手伝ってくれますか?私はすべてを可能な限り単純化したので、私を助けやすくなりました。
モデルは、18 x 18のバイナリ構成のデータセット10000を入力として受け取り、16x16の構成を出力として設定します。ニューラルネットワークは、2つのConvlutional層のみで構成されています。
私のモデルは次のようになります。
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
EPOCHS = 10
model = models.Sequential()
model.add(layers.Conv2D(1,2,activation='relu',input_shape=[18,18,1]))
model.add(layers.Conv2D(1,2,activation='sigmoid',input_shape=[17,17,1]))
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),loss=tf.keras.losses.BinaryCrossentropy())
model.fit(initial.reshape(10000,18,18,1),target.reshape(10000,16,16,1),batch_size = 1000, epochs=EPOCHS, verbose=1)
output = model(initial).numpy().reshape(10000,16,16)
ここで、正則化項として持つ追加の正則化項として使用したい関数を作成しました。この関数は、真と予測を取ります。基本的に、それは両方のすべての点をその「正しい」隣接点で乗算します。次に、違いが取られます。真の予測項は16x16(10000x16x16ではない)であると想定しました。これは正しいです?
def regularization_term(prediction, true):
order = list(range(1,4))
order.append(0)
deviation = (true*true[:,order]) - (prediction*prediction[:,order])
deviation = abs(deviation)**2
return 0.2 * deviation
ニューラルネットワークがこの「右隣」の相互作用をよりよく訓練するのを助けるための私の損失に正則化用語としてこの関数のようなものを追加することで本当に感謝します。カスタマイズ可能なTensorflow機能をたくさん使用するのに本当に苦労しています。どうもありがとうございました。
回答
とても簡単です。追加の正則化項を定義するカスタム損失を指定する必要があります。このようなもの:
# to minimize!
def regularization_term(true, prediction):
order = list(range(1,4))
order.append(0)
deviation = (true*true[:,order]) - (prediction*prediction[:,order])
deviation = abs(deviation)**2
return 0.2 * deviation
def my_custom_loss(y_true, y_pred):
return tf.keras.losses.BinaryCrossentropy()(y_true, y_pred) + regularization_term(y_true, y_pred)
model.compile(optimizer='Adam', loss=my_custom_loss)
kerasが述べているように:
損失の配列(入力バッチのサンプルの1つ)を返すシグネチャloss_fn(y_true、y_pred)を使用して呼び出すことができるものはすべて、損失としてcompile()に渡すことができます。サンプルの重み付けは、このような損失に対して自動的にサポートされることに注意してください。
したがって、必ず損失の配列を返すようにしてください(編集:これでわかるように、単純なスカラーも返すことができます。たとえば、reduce関数を使用するかどうかは関係ありません)。基本的に、y_trueとy_predictedは、最初の次元としてバッチサイズを持ちます。
ここに詳細: https://keras.io/api/losses/