TensorFlow - การถดถอยเชิงเส้น

ในบทนี้เราจะเน้นไปที่ตัวอย่างพื้นฐานของการนำการถดถอยเชิงเส้นโดยใช้ TensorFlow การถดถอยโลจิสติกส์หรือการถดถอยเชิงเส้นเป็นวิธีการเรียนรู้ของเครื่องที่ได้รับการดูแลสำหรับการจัดหมวดหมู่หมวดหมู่ที่ไม่ต่อเนื่อง เป้าหมายของเราในบทนี้คือการสร้างแบบจำลองที่ผู้ใช้สามารถทำนายความสัมพันธ์ระหว่างตัวแปรทำนายกับตัวแปรอิสระอย่างน้อยหนึ่งตัว

ความสัมพันธ์ระหว่างตัวแปรทั้งสองนี้ถือว่าเป็นเชิงเส้น ถ้า y เป็นตัวแปรตามและ x ถือเป็นตัวแปรอิสระความสัมพันธ์การถดถอยเชิงเส้นของสองตัวแปรจะมีลักษณะเหมือนสมการต่อไปนี้ -

Y = Ax+b

เราจะออกแบบอัลกอริทึมสำหรับการถดถอยเชิงเส้น สิ่งนี้จะช่วยให้เราเข้าใจแนวคิดสำคัญสองประการต่อไปนี้ -

  • ฟังก์ชันต้นทุน
  • อัลกอริทึมการไล่ระดับสี

การแสดงแผนผังของการถดถอยเชิงเส้นแสดงไว้ด้านล่าง -

มุมมองกราฟิกของสมการการถดถอยเชิงเส้นมีการกล่าวถึงด้านล่าง -

ขั้นตอนในการออกแบบอัลกอริทึมสำหรับการถดถอยเชิงเส้น

ตอนนี้เราจะเรียนรู้เกี่ยวกับขั้นตอนที่ช่วยในการออกแบบอัลกอริทึมสำหรับการถดถอยเชิงเส้น

ขั้นตอนที่ 1

สิ่งสำคัญคือต้องนำเข้าโมดูลที่จำเป็นสำหรับการพล็อตโมดูลการถดถอยเชิงเส้น เราเริ่มนำเข้าไลบรารี Python NumPy และ Matplotlib

import numpy as np 
import matplotlib.pyplot as plt

ขั้นตอนที่ 2

กำหนดจำนวนค่าสัมประสิทธิ์ที่จำเป็นสำหรับการถดถอยโลจิสติก

number_of_points = 500 
x_point = [] 
y_point = [] 
a = 0.22 
b = 0.78

ขั้นตอนที่ 3

ทำซ้ำตัวแปรเพื่อสร้างจุดสุ่ม 300 จุดรอบสมการถดถอย -

Y = 0.22x + 0.78

for i in range(number_of_points): 
   x = np.random.normal(0.0,0.5) 
   y = a*x + b +np.random.normal(0.0,0.1) x_point.append([x]) 
   y_point.append([y])

ขั้นตอนที่ 4

ดูจุดที่สร้างขึ้นโดยใช้ Matplotlib

fplt.plot(x_point,y_point, 'o', label = 'Input Data') plt.legend() plt.show()

รหัสที่สมบูรณ์สำหรับการถดถอยโลจิสติกมีดังนี้ -

import numpy as np 
import matplotlib.pyplot as plt 

number_of_points = 500 
x_point = [] 
y_point = [] 
a = 0.22 
b = 0.78 

for i in range(number_of_points): 
   x = np.random.normal(0.0,0.5) 
   y = a*x + b +np.random.normal(0.0,0.1) x_point.append([x]) 
   y_point.append([y]) 
   
plt.plot(x_point,y_point, 'o', label = 'Input Data') plt.legend() 
plt.show()

จำนวนจุดที่นำมาเป็นอินพุตถือเป็นข้อมูลอินพุต