क्या यह मायने रखता है कि dtype एक numpy सरणी इनपुट के लिए एक टेंसरफ़्लो / केरस न्यूरल नेटवर्क में क्या है?

Aug 15 2020

अगर मैं एक टेंसरफ़्लो.केरस मॉडल लेता हूं और कॉल करता हूं model.fit(x, y)(जहां xऔर yसुन्न सरणियां हैं) क्या इससे कोई फर्क नहीं पड़ता है कि dtypeसुन्न सरणी क्या है? क्या मैं सबसे अच्छा है dtypeकि जितना संभव हो उतना छोटा बना int8दूं (जैसे कि बाइनरी डेटा के लिए) या क्या यह एक फ्लोट को कास्ट करने के लिए टेंसरफ़्लो / केरस अतिरिक्त काम देता है?

जवाब

1 NicolasGervais Aug 15 2020 at 17:20

आपको अपना इनपुट कास्ट करना चाहिए np.float32, यही केरस के लिए डिफ़ॉल्ट dtype है। इसे देखो:

import tensorflow as tf
tf.keras.backend.floatx()
'float32'

यदि आप केरस इनपुट देते हैं np.float64, तो यह शिकायत करेगा:

import tensorflow as tf
from tensorflow.keras.layers import Dense 
from tensorflow.keras import Model
from sklearn.datasets import load_iris
iris, target = load_iris(return_X_y=True)

X = iris[:, :3]
y = iris[:, 3]

ds = tf.data.Dataset.from_tensor_slices((X, y)).shuffle(25).batch(8)

class MyModel(Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.d0 = Dense(16, activation='relu')
    self.d1 = Dense(32, activation='relu')
    self.d2 = Dense(1, activation='linear')

  def call(self, x):
    x = self.d0(x)
    x = self.d1(x)
    x = self.d2(x)
    return x

model = MyModel()

_ = model(X)

चेतावनी: टेंसरफ़्लो: परत my_model dtype float64 से लेयर के फ़्लोट 32 तक लेयर के dtype में एक इनपुट टेन्सर को कास्टिंग कर रहा है, जो TensorFlow 2 में नया व्यवहार है। लेयर में dotype फ़्लोट 32 है, क्योंकि यह फ्लोटेक्स के लिए dtype डिफॉल्ट है। यदि आप फ्लोट 32 में इस परत को चलाने का इरादा रखते हैं, तो आप इस चेतावनी को सुरक्षित रूप से अनदेखा कर सकते हैं। यदि संदेह है, तो यह चेतावनी केवल एक समस्या है यदि आप TensorFlow 1.X मॉडल को TensorFlow 2 में पोर्ट कर रहे हैं tf.keras.backend.set_floatx('float64')। डिफ़ॉल्ट रूप से कॉल करने के लिए सभी परतों को dtype float64 में बदलने के लिए । इस परत को बदलने के लिए, dtype = 'float64' को लेयर कंस्ट्रक्टर के पास भेजें। यदि आप इस लेयर के लेखक हैं, तो आप ऑटोकैस्ट = आधार लेयर कंस्ट्रक्टर को गलत तरीके से पास करके ऑटोकास्टिंग को अक्षम कर सकते हैं।

8 बिट इनपुट के साथ प्रशिक्षण के लिए टेंसरफ्लो का उपयोग करना संभव है , जिसे क्वांटिज़ेशन कहा जाता है। लेकिन यह ज्यादातर मामलों में चुनौतीपूर्ण और अनावश्यक है (यानी, जब तक आपको अपने उपकरणों को किनारे वाले उपकरणों पर तैनात करने की आवश्यकता नहीं है)।

tl; अपने इनपुट को अंदर रखें np.float32। इस पोस्ट को भी देखें ।