Имеет ли значение, какой dtype массив numpy предназначен для ввода в нейронную сеть tensorflow / keras?
Если я возьму модель tensorflow.keras и вызову model.fit(x, y)
(где x
и y
- массивы numpy), имеет ли значение, что dtype
такое массив numpy? Лучше ли мне просто сделать dtype
как можно меньше (например, int8
для двоичных данных) или это дает тензорному потоку / керасу дополнительную работу, чтобы преобразовать его в поплавок?
Ответы
Вы должны np.float32
указать свой ввод , это dtype по умолчанию для Keras. Поищи это:
import tensorflow as tf
tf.keras.backend.floatx()
'float32'
Если вы введете Keras np.float64
, он будет жаловаться:
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn.datasets import load_iris
iris, target = load_iris(return_X_y=True)
X = iris[:, :3]
y = iris[:, 3]
ds = tf.data.Dataset.from_tensor_slices((X, y)).shuffle(25).batch(8)
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self.d0 = Dense(16, activation='relu')
self.d1 = Dense(32, activation='relu')
self.d2 = Dense(1, activation='linear')
def call(self, x):
x = self.d0(x)
x = self.d1(x)
x = self.d2(x)
return x
model = MyModel()
_ = model(X)
ПРЕДУПРЕЖДЕНИЕ: tensorflow: Layer my_model преобразует входной тензор из dtype float64 в dtype слоя float32, что является новым поведением в TensorFlow 2. Слой имеет dtype float32, потому что его dtype по умолчанию равен floatx. Если вы намеревались запустить этот слой в float32, можете игнорировать это предупреждение. Если вы сомневаетесь, это предупреждение, скорее всего, является проблемой, только если вы переносите модель TensorFlow 1.X на TensorFlow 2. Чтобы изменить все слои на dtype float64 по умолчанию, вызовите
tf.keras.backend.set_floatx('float64')
. Чтобы изменить только этот слой, передайте dtype = 'float64' конструктору слоя. Если вы являетесь автором этого слоя, вы можете отключить автоматическое преобразование, передав autocast = False в конструктор базового слоя.
Можно использовать Tensorflow для обучения с 8- битным вводом , что называется квантованием. Но в большинстве случаев это сложно и ненужно (то есть, если вам не нужно развертывать свои модели на периферийных устройствах).
tl; dr сохраните свой вклад np.float32
. См. Также этот пост .