Czy ma znaczenie, jaki dtype tablica numpy ma być wprowadzana do sieci neuronowej tensorflow / keras?
Jeśli wezmę model tensorflow.keras i wywołam model.fit(x, y)
(gdzie x
i y
są tablice numpy), czy ma znaczenie, jaka dtype
jest tablica numpy? Czy najlepiej jest zrobić dtype
tak małe, jak to tylko możliwe (np. int8
Dla danych binarnych), czy też daje to tensorflow / keras dodatkowej pracy, aby rzucić je na float?
Odpowiedzi
Powinieneś rzucić swoje wejście np.float32
na domyślny typ dla Keras. Sprawdź to:
import tensorflow as tf
tf.keras.backend.floatx()
'float32'
Jeśli dasz Keras wejście np.float64
, będzie narzekał:
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn.datasets import load_iris
iris, target = load_iris(return_X_y=True)
X = iris[:, :3]
y = iris[:, 3]
ds = tf.data.Dataset.from_tensor_slices((X, y)).shuffle(25).batch(8)
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self.d0 = Dense(16, activation='relu')
self.d1 = Dense(32, activation='relu')
self.d2 = Dense(1, activation='linear')
def call(self, x):
x = self.d0(x)
x = self.d1(x)
x = self.d2(x)
return x
model = MyModel()
_ = model(X)
OSTRZEŻENIE: tensorflow: Warstwa my_model rzutuje wejściowy tensor z dtype float64 na dtype warstwy float32, co jest nowym zachowaniem w TensorFlow 2. Warstwa ma dtype float32, ponieważ jej domyślnym typem dtype jest floatx. Jeśli zamierzałeś uruchomić tę warstwę w float32, możesz bezpiecznie zignorować to ostrzeżenie. W razie wątpliwości to ostrzeżenie jest prawdopodobnie problemem tylko w przypadku przenoszenia modelu TensorFlow 1.X do TensorFlow 2. Aby zmienić wszystkie warstwy tak, aby miały domyślnie dtype float64, wywołaj
tf.keras.backend.set_floatx('float64')
. Aby zmienić tylko tę warstwę, przekaż dtype = 'float64' do konstruktora warstwy. Jeśli jesteś autorem tej warstwy, możesz wyłączyć automatyczne przesyłanie, przekazując autocast = False do konstruktora warstwy podstawowej.
Możliwe jest użycie Tensorflow do uczenia z 8- bitowym wejściem , co nazywa się kwantyzacją. Jednak w większości przypadków jest to trudne i niepotrzebne (tj. Chyba, że musisz wdrożyć modele na urządzeniach brzegowych).
tl; dr zachowaj wprowadzone dane np.float32
. Zobacz także ten post .