क्या यह मायने रखता है कि dtype एक numpy सरणी इनपुट के लिए एक टेंसरफ़्लो / केरस न्यूरल नेटवर्क में क्या है?
अगर मैं एक टेंसरफ़्लो.केरस मॉडल लेता हूं और कॉल करता हूं model.fit(x, y)
(जहां x
और y
सुन्न सरणियां हैं) क्या इससे कोई फर्क नहीं पड़ता है कि dtype
सुन्न सरणी क्या है? क्या मैं सबसे अच्छा है dtype
कि जितना संभव हो उतना छोटा बना int8
दूं (जैसे कि बाइनरी डेटा के लिए) या क्या यह एक फ्लोट को कास्ट करने के लिए टेंसरफ़्लो / केरस अतिरिक्त काम देता है?
जवाब
आपको अपना इनपुट कास्ट करना चाहिए np.float32
, यही केरस के लिए डिफ़ॉल्ट dtype है। इसे देखो:
import tensorflow as tf
tf.keras.backend.floatx()
'float32'
यदि आप केरस इनपुट देते हैं np.float64
, तो यह शिकायत करेगा:
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn.datasets import load_iris
iris, target = load_iris(return_X_y=True)
X = iris[:, :3]
y = iris[:, 3]
ds = tf.data.Dataset.from_tensor_slices((X, y)).shuffle(25).batch(8)
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self.d0 = Dense(16, activation='relu')
self.d1 = Dense(32, activation='relu')
self.d2 = Dense(1, activation='linear')
def call(self, x):
x = self.d0(x)
x = self.d1(x)
x = self.d2(x)
return x
model = MyModel()
_ = model(X)
चेतावनी: टेंसरफ़्लो: परत my_model dtype float64 से लेयर के फ़्लोट 32 तक लेयर के dtype में एक इनपुट टेन्सर को कास्टिंग कर रहा है, जो TensorFlow 2 में नया व्यवहार है। लेयर में dotype फ़्लोट 32 है, क्योंकि यह फ्लोटेक्स के लिए dtype डिफॉल्ट है। यदि आप फ्लोट 32 में इस परत को चलाने का इरादा रखते हैं, तो आप इस चेतावनी को सुरक्षित रूप से अनदेखा कर सकते हैं। यदि संदेह है, तो यह चेतावनी केवल एक समस्या है यदि आप TensorFlow 1.X मॉडल को TensorFlow 2 में पोर्ट कर रहे हैं
tf.keras.backend.set_floatx('float64')
। डिफ़ॉल्ट रूप से कॉल करने के लिए सभी परतों को dtype float64 में बदलने के लिए । इस परत को बदलने के लिए, dtype = 'float64' को लेयर कंस्ट्रक्टर के पास भेजें। यदि आप इस लेयर के लेखक हैं, तो आप ऑटोकैस्ट = आधार लेयर कंस्ट्रक्टर को गलत तरीके से पास करके ऑटोकास्टिंग को अक्षम कर सकते हैं।
8 बिट इनपुट के साथ प्रशिक्षण के लिए टेंसरफ्लो का उपयोग करना संभव है , जिसे क्वांटिज़ेशन कहा जाता है। लेकिन यह ज्यादातर मामलों में चुनौतीपूर्ण और अनावश्यक है (यानी, जब तक आपको अपने उपकरणों को किनारे वाले उपकरणों पर तैनात करने की आवश्यकता नहीं है)।
tl; अपने इनपुट को अंदर रखें np.float32
। इस पोस्ट को भी देखें ।