มันสำคัญหรือไม่ว่า dtype อาร์เรย์ numpy สำหรับอินพุตในโครงข่ายประสาทเทียม tensorflow / keras คืออะไร?

Aug 15 2020

ถ้าฉันใช้แบบจำลอง tensorflow.keras และเรียกmodel.fit(x, y)( อาร์เรย์ numpy อยู่ที่ไหนxและอยู่ที่ไหนy) มันสำคัญว่าdtypeอาร์เรย์ numpy คืออะไร? ฉันควรทำให้dtypeเล็กที่สุดเท่าที่จะเป็นไปได้ (เช่นint8สำหรับข้อมูลไบนารี) หรือไม่หรือให้เทนเซอร์โฟลว์ / เคราส์ทำงานพิเศษเพื่อโยนมันให้ลอย?

คำตอบ

1 NicolasGervais Aug 15 2020 at 17:20

คุณควรส่งข้อมูลที่คุณป้อนnp.float32ซึ่งเป็นค่าเริ่มต้นสำหรับ Keras ค้นดูสิ:

import tensorflow as tf
tf.keras.backend.floatx()
'float32'

หากคุณป้อนข้อมูล Keras เข้าnp.float64มันจะบ่น:

import tensorflow as tf
from tensorflow.keras.layers import Dense 
from tensorflow.keras import Model
from sklearn.datasets import load_iris
iris, target = load_iris(return_X_y=True)

X = iris[:, :3]
y = iris[:, 3]

ds = tf.data.Dataset.from_tensor_slices((X, y)).shuffle(25).batch(8)

class MyModel(Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.d0 = Dense(16, activation='relu')
    self.d1 = Dense(32, activation='relu')
    self.d2 = Dense(1, activation='linear')

  def call(self, x):
    x = self.d0(x)
    x = self.d1(x)
    x = self.d2(x)
    return x

model = MyModel()

_ = model(X)

คำเตือน: tensorflow: Layer my_model กำลังส่งอินพุตเทนเซอร์จาก dtype float64 ไปยัง dtype ของชั้นของ float32 ซึ่งเป็นลักษณะการทำงานใหม่ใน TensorFlow 2 เลเยอร์มี dtype float32 เนื่องจากค่าเริ่มต้นของ dtype เป็น floatx หากคุณตั้งใจจะเรียกใช้เลเยอร์นี้ใน float32 คุณสามารถเพิกเฉยต่อคำเตือนนี้ได้อย่างปลอดภัย หากมีข้อสงสัยคำเตือนนี้มีโอกาสเพียงปัญหาถ้าคุณกำลังย้ายรูปแบบ TensorFlow 1.X เพื่อ TensorFlow 2. การเปลี่ยนทุกชั้นจะมี float64 dtype tf.keras.backend.set_floatx('float64')โดยค่าเริ่มต้นการโทร หากต้องการเปลี่ยนเพียงเลเยอร์นี้ให้ส่ง dtype = 'float64' ไปยังตัวสร้างเลเยอร์ หากคุณเป็นผู้สร้างเลเยอร์นี้คุณสามารถปิดใช้งานการแคสต์อัตโนมัติได้โดยส่ง autocast = False ไปยังตัวสร้างเลเยอร์พื้นฐาน

เป็นไปได้ที่จะใช้ Tensorflow สำหรับการฝึกอบรมด้วยอินพุต 8 บิตซึ่งเรียกว่า quantization แต่ในกรณีส่วนใหญ่เป็นเรื่องที่ท้าทายและไม่จำเป็น (กล่าวคือเว้นแต่คุณจะต้องปรับใช้โมเดลของคุณบนอุปกรณ์ขอบ)

tl; drเก็บข้อมูลของคุณไว้ในnp.float32. ดูโพสต์นี้ด้วย