Faz diferença que tipo de um array numpy é para entrada em uma rede neural tensorflow / keras?

Aug 15 2020

Se eu pegar um modelo tensorflow.keras e chamar model.fit(x, y)(where xand yare numpy arrays), importa qual é dtypeo array numpy? É melhor apenas fazer o dtypemenor possível (por exemplo, int8para dados binários) ou isso dá a tensorflow / keras trabalho extra para convertê-lo em um float?

Respostas

1 NicolasGervais Aug 15 2020 at 17:20

Você deve lançar sua entrada para np.float32, esse é o dtype padrão para Keras. Procure:

import tensorflow as tf
tf.keras.backend.floatx()
'float32'

Se você fornecer informações a Keras np.float64, ele reclamará:

import tensorflow as tf
from tensorflow.keras.layers import Dense 
from tensorflow.keras import Model
from sklearn.datasets import load_iris
iris, target = load_iris(return_X_y=True)

X = iris[:, :3]
y = iris[:, 3]

ds = tf.data.Dataset.from_tensor_slices((X, y)).shuffle(25).batch(8)

class MyModel(Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.d0 = Dense(16, activation='relu')
    self.d1 = Dense(32, activation='relu')
    self.d2 = Dense(1, activation='linear')

  def call(self, x):
    x = self.d0(x)
    x = self.d1(x)
    x = self.d2(x)
    return x

model = MyModel()

_ = model(X)

AVISO: tensorflow: a camada my_model está lançando um tensor de entrada de dtype float64 para o dtype de float32 da camada, que é um novo comportamento no TensorFlow 2. A camada tem dtype float32 porque seu padrão é floatx. Se você pretendia executar essa camada em float32, pode ignorar esse aviso com segurança. Em caso de dúvida, este aviso provavelmente é um problema apenas se você estiver portando um modelo TensorFlow 1.X para TensorFlow 2. Para alterar todas as camadas para ter dtype float64 por padrão, chame tf.keras.backend.set_floatx('float64'). Para alterar apenas esta camada, passe dtype = 'float64' para o construtor da camada. Se você for o autor desta camada, você pode desativar o autocast passando autocast = False para o construtor de camada base.

É possível usar o Tensorflow para treinamento com entrada de 8 bits , o que é chamado de quantização. Mas é desafiador e desnecessário na maioria dos casos (ou seja, a menos que você precise implantar seus modelos em dispositivos de borda).

tl; dr mantenha sua entrada np.float32. Veja também este post .