Overfitting ใน Linear Regression

Aug 27 2020

ฉันเพิ่งเริ่มต้นกับแมชชีนเลิร์นนิงและฉันมีปัญหาในการทำความเข้าใจว่าการฟิตติ้งมากเกินไปสามารถเกิดขึ้นได้อย่างไรในโมเดลการถดถอยเชิงเส้น

เมื่อพิจารณาว่าเราใช้ตัวแปรคุณลักษณะเพียง 2 ตัวในการฝึกโมเดลระนาบแบนจะติดตั้งชุดจุดข้อมูลมากเกินไปได้อย่างไร

ฉันถือว่าการถดถอยเชิงเส้นใช้เพียงเส้นเดียวเพื่ออธิบายความสัมพันธ์เชิงเส้นระหว่าง 2 ตัวแปรและระนาบแบนเพื่ออธิบายความสัมพันธ์ระหว่าง 3 ตัวแปรฉันมีปัญหาในการทำความเข้าใจ (หรือค่อนข้างจะจินตนาการ) ว่าการใส่เกินในเส้นหรือระนาบจะเกิดขึ้นได้อย่างไร

คำตอบ

20 RobertLong Aug 27 2020 at 17:18

การถดถอยเชิงเส้นเกิดขึ้นเมื่อโมเดล "ซับซ้อนเกินไป" สิ่งนี้มักเกิดขึ้นเมื่อมีพารามิเตอร์จำนวนมากเมื่อเทียบกับจำนวนการสังเกต แบบจำลองดังกล่าวจะไม่สามารถสรุปได้ดีกับข้อมูลใหม่ นั่นคือมันจะทำงานได้ดีกับข้อมูลการฝึกอบรม แต่มีข้อมูลการทดสอบไม่ดี

การจำลองอย่างง่ายสามารถแสดงสิ่งนี้ได้ ที่นี่ฉันใช้ R:

> set.seed(2)
> N <- 4
> X <- 1:N
> Y <- X + rnorm(N, 0, 1)
> 
> (m0 <- lm(Y ~ X)) %>% summary()

Coefficients:
            Estimate Std. Error t value Pr(>|t|)
(Intercept)  -0.2393     1.8568  -0.129    0.909
X             1.0703     0.6780   1.579    0.255

Residual standard error: 1.516 on 2 degrees of freedom
Multiple R-squared:  0.5548,    Adjusted R-squared:  0.3321 
F-statistic: 2.492 on 1 and 2 DF,  p-value: 0.2552

โปรดทราบว่าเราได้ค่าประมาณที่แท้จริงสำหรับค่าสัมประสิทธิ์ของ X โปรดสังเกต R-squared ที่ปรับปรุงแล้วเป็น 0.3321 ซึ่งเป็นตัวบ่งชี้ความพอดีของโมเดล

ตอนนี้เราพอดีกับแบบจำลองกำลังสอง:

> (m1 <- lm(Y ~ X + I(X^2) )) %>% summary()


Coefficients:
            Estimate Std. Error t value Pr(>|t|)
(Intercept)  -4.9893     2.7654  -1.804    0.322
X             5.8202     2.5228   2.307    0.260
I(X^2)       -0.9500     0.4967  -1.913    0.307

Residual standard error: 0.9934 on 1 degrees of freedom
Multiple R-squared:  0.9044,    Adjusted R-squared:  0.7133 
F-statistic: 4.731 on 2 and 1 DF,  p-value: 0.3092

ตอนนี้เรามี Adjusted R-squared ที่สูงขึ้นมาก: 0.7133 ซึ่งอาจทำให้เราคิดว่าโมเดลนั้นดีกว่ามาก แน่นอนว่าถ้าเราวางแผนข้อมูลและค่าที่คาดการณ์ไว้จากทั้งสองแบบเราจะได้รับ:

> fun.linear <- function(x) { coef(m0)[1] + coef(m0)[2] * x  }
> fun.quadratic <- function(x) { coef(m1)[1] + coef(m1)[2] * x  + coef(m1)[3] * x^2}
> 
> ggplot(data.frame(X,Y), aes(y = Y, x = X)) + geom_point()  + stat_function(fun = fun.linear) + stat_function(fun = fun.quadratic)

ดังนั้นบนใบหน้าของมันโมเดลกำลังสองจึงดูดีกว่ามาก

ตอนนี้ถ้าเราจำลองข้อมูลใหม่ แต่ใช้โมเดลเดียวกันในการพล็อตการคาดการณ์เราจะได้

> set.seed(6)
> N <- 4
> X <- 1:N
> Y <- X + rnorm(N, 0, 1)
> ggplot(data.frame(X,Y), aes(y = Y, x = X)) + geom_point()  + stat_function(fun = fun.linear) + stat_function(fun = fun.quadratic)

เห็นได้ชัดว่าแบบจำลองกำลังสองทำได้ไม่ดีในขณะที่แบบจำลองเชิงเส้นยังคงสมเหตุสมผล อย่างไรก็ตามหากเราจำลองข้อมูลเพิ่มเติมด้วยช่วงขยายโดยใช้เมล็ดพันธุ์ดั้งเดิมเพื่อให้จุดข้อมูลเริ่มต้นเหมือนกับในการจำลองครั้งแรกที่เราพบ:

> set.seed(2)
> N <- 10
> X <- 1:N
> Y <- X + rnorm(N, 0, 1)
> ggplot(data.frame(X,Y), aes(y = Y, x = X)) + geom_point()  + stat_function(fun = fun.linear) + stat_function(fun = fun.quadratic)

เห็นได้ชัดว่าแบบจำลองเชิงเส้นยังคงทำงานได้ดี แต่แบบจำลองกำลังสองนั้นสิ้นหวังอยู่นอกช่วงดั้งเดิม เนื่องจากเมื่อเราติดตั้งโมเดลเรามีพารามิเตอร์มากเกินไป (3) เมื่อเทียบกับจำนวนการสังเกต (4)


แก้ไข: เพื่อจัดการกับคำถามในความคิดเห็นของคำตอบนี้เกี่ยวกับโมเดลที่ไม่มีคำสั่งซื้อที่สูงกว่า

สถานการณ์เหมือนกัน: หากจำนวนพารามิเตอร์เข้าใกล้จำนวนการสังเกตโมเดลจะติดตั้งมากเกินไป หากไม่มีเงื่อนไขลำดับที่สูงกว่านี้จะเกิดขึ้นเมื่อจำนวนตัวแปร / คุณลักษณะในแบบจำลองเข้าใกล้จำนวนข้อสังเกต

เราสามารถสาธิตสิ่งนี้ได้อย่างง่ายดายอีกครั้งด้วยการจำลอง:

ที่นี่เราจำลองข้อมูลข้อมูลแบบสุ่มจากการแจกแจงปกติซึ่งเรามีข้อสังเกต 7 ประการและตัวแปร / คุณลักษณะ 5 ประการ:

> set.seed(1)
> n.var <- 5
> n.obs <- 7
> 
> dt <- as.data.frame(matrix(rnorm(n.var * n.obs), ncol = n.var))
> dt$Y <- rnorm(nrow(dt))
> 
> lm(Y ~ . , dt) %>% summary()

Coefficients:
            Estimate Std. Error t value Pr(>|t|)
(Intercept)  -0.6607     0.2337  -2.827    0.216
V1            0.6999     0.1562   4.481    0.140
V2           -0.4751     0.3068  -1.549    0.365
V3            1.2683     0.3423   3.705    0.168
V4            0.3070     0.2823   1.087    0.473
V5            1.2154     0.3687   3.297    0.187

Residual standard error: 0.2227 on 1 degrees of freedom
Multiple R-squared:  0.9771,    Adjusted R-squared:  0.8627 

เราได้รับ R-squared ที่ปรับแล้วเป็น 0.86 ซึ่งบ่งบอกถึงความพอดีของโมเดลที่ดีเยี่ยม บนข้อมูลสุ่มล้วนๆ โมเดลมีการติดตั้งมากเกินไป โดยการเปรียบเทียบถ้าเราเพิ่มจำนวนสิ่งกีดขวางเป็นสองเท่าเป็น 14:

> set.seed(1)
> n.var <- 5
> n.obs <- 14
> dt <- as.data.frame(matrix(rnorm(n.var * n.obs), ncol = n.var))
> dt$Y <- rnorm(nrow(dt))
> lm(Y ~ . , dt) %>% summary()

Coefficients:
            Estimate Std. Error t value Pr(>|t|)  
(Intercept) -0.10391    0.23512  -0.442   0.6702  
V1          -0.62357    0.32421  -1.923   0.0906 .
V2           0.39835    0.27693   1.438   0.1883  
V3          -0.02789    0.31347  -0.089   0.9313  
V4          -0.30869    0.30628  -1.008   0.3430  
V5          -0.38959    0.20767  -1.876   0.0975 .
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 0.7376 on 8 degrees of freedom
Multiple R-squared:  0.4074,    Adjusted R-squared:  0.03707 
F-statistic:   1.1 on 5 and 8 DF,  p-value: 0.4296

.. ปรับ R กำลังสองลดลงเหลือเพียง 0.037

4 Dhanushkumar Aug 28 2020 at 00:10

การใส่อุปกรณ์มากเกินไปจะเกิดขึ้นเมื่อโมเดลทำงานได้ดีกับข้อมูลรถไฟ แต่ทำงานได้ไม่ดีกับข้อมูลทดสอบ นี่เป็นเพราะเส้นที่พอดีที่สุดโดยแบบจำลองการถดถอยเชิงเส้นของคุณไม่ใช่แบบทั่วไป อาจเป็นเพราะปัจจัยต่างๆ บางส่วนของปัจจัยที่พบบ่อยคือ

  • ค่าผิดปกติในข้อมูลรถไฟ
  • ข้อมูลการฝึกและการทดสอบมาจากการแจกแจงที่แตกต่างกัน

ดังนั้นก่อนสร้างแบบจำลองให้แน่ใจว่าคุณได้ตรวจสอบปัจจัยเหล่านี้เพื่อให้ได้โมเดลทั่วไป

2 Peteris Aug 29 2020 at 00:22

พารามิเตอร์จำนวนมากเมื่อเทียบกับจุดข้อมูล

โดยทั่วไปแง่มุมหนึ่งของการใส่อุปกรณ์มากเกินไปคือการพยายาม "ประดิษฐ์ข้อมูลจากความรู้" เมื่อคุณต้องการกำหนดพารามิเตอร์จำนวนมากโดยเปรียบเทียบจากจุดข้อมูลหลักฐานที่แท้จริงจำนวน จำกัด

สำหรับการถดถอยเชิงเส้นอย่างง่ายy = ax + bมีพารามิเตอร์สองตัวดังนั้นสำหรับชุดข้อมูลส่วนใหญ่จะอยู่ภายใต้พารามิเตอร์ไม่ใช่พารามิเตอร์มากเกินไป อย่างไรก็ตามลองดูกรณี (เสื่อม) ของจุดข้อมูลเพียงสองจุด ในสถานการณ์นั้นคุณสามารถหาวิธีแก้ปัญหาการถดถอยเชิงเส้นที่สมบูรณ์แบบได้เสมอ - อย่างไรก็ตามคำตอบนั้นจำเป็นต้องมีความหมายหรือไม่? อาจจะไม่ หากคุณถือว่าการถดถอยเชิงเส้นของจุดข้อมูลสองจุดเป็นวิธีแก้ปัญหาที่เพียงพอนั่นจะเป็นตัวอย่างที่สำคัญของการใส่มากเกินไป

นี่คือตัวอย่างที่ดีของการใส่มากเกินไปด้วยการถดถอยเชิงเส้นโดยRandall Munroe แห่งชื่อเสียง xkcdที่แสดงให้เห็นถึงปัญหานี้: