Mengapa terkadang model CNN memprediksi hanya satu kelas dari yang lainnya?
Saya relatif baru mengenal lanskap pembelajaran mendalam, jadi tolong jangan kejam seperti Reddit! Sepertinya pertanyaan umum jadi saya tidak akan memberikan kode saya di sini karena sepertinya tidak perlu (jika ya, ini tautan ke colab )
Sedikit tentang data: Anda dapat menemukan data asli di sini . Ini adalah versi kecil dari kumpulan data asli 82 GB.
Setelah saya melatih CNN saya tentang hal ini, CNN memprediksi 'No Diabetic Retinopathy' (No DR) setiap saat, menghasilkan akurasi 73%. Apakah alasan untuk ini hanya karena banyaknya gambar Tanpa DR atau yang lainnya? Saya tidak punya ide! 5 kelas yang saya miliki untuk prediksi adalah ["Mild", "Moderate", "No DR", "Proliferative DR", "Severe"]
.
Itu mungkin hanya kode yang buruk, berharap kalian bisa membantu
Jawaban
Saya akan berkomentar:
Pendekatan yang lebih ketat adalah mulai mengukur keseimbangan set data Anda: berapa banyak gambar dari setiap kelas yang Anda miliki? Ini kemungkinan besar akan memberikan jawaban atas pertanyaan Anda.
Tapi tidak bisa menahan diri untuk melihat tautan yang Anda berikan. Kaggle sudah memberi Anda gambaran umum tentang kumpulan data:
Perhitungan cepat: 25,812 / 35,126 * 100 = 73%
. Itu menarik, Anda bilang Anda punya akurasi 74%
. Model Anda belajar pada kumpulan data yang tidak seimbang, dengan kelas pertama diwakili secara berlebihan, 25k/35k
sangatlah besar. Hipotesis saya adalah bahwa model Anda terus memprediksi kelas pertama yang berarti bahwa rata-rata Anda akan mendapatkan keakuratan sebesar 74%
.
Apa yang harus Anda lakukan adalah menyeimbangkan kumpulan data Anda. Misalnya dengan hanya mengizinkan 35,126 - 25,810 = 9,316
contoh dari kelas pertama muncul selama suatu epoch. Lebih baik lagi, seimbangkan kumpulan data Anda di semua kelas sehingga setiap kelas hanya akan muncul masing-masing n kali, per epoch.
Seperti yang sudah Ivan catat, Anda memiliki masalah ketidakseimbangan kelas. Ini dapat diselesaikan melalui:
Penambangan negatif keras online: pada setiap iterasi setelah menghitung kerugian, Anda dapat mengurutkan semua elemen dalam kelompok yang termasuk dalam kelas "tanpa DR" dan hanya menyimpan yang terburuk
k
. Kemudian Anda memperkirakan gradien hanya menggunakan k yang lebih buruk ini dan membuang sisanya.
lihat, misalnya:
Abhinav Shrivastava, Abhinav Gupta dan Ross Girshick Training Region-based Object Detectors dengan Online Hard Example Mining (CVPR 2016)Focal loss: modifikasi untuk kehilangan entropi silang "vanilla" dapat digunakan untuk mengatasi ketidakseimbangan kelas.
Posting terkait ini dan ini .