Perché a volte i modelli della CNN prevedono solo una classe su tutte le altre?

Jan 17 2021

Sono relativamente nuovo nel panorama del deep learning, quindi per favore non essere cattivo come Reddit! Sembra una domanda generale, quindi non fornirò il mio codice qui perché non sembra necessario (se lo è, ecco il link a colab )

Un po 'sui dati: puoi trovare i dati originali qui . È una versione ridotta del set di dati originale di 82 GB.

Una volta che ho addestrato la mia CNN su questo, predice "Nessuna retinopatia diabetica" (No DR) ogni volta, portando a una precisione del 73%. La ragione è solo la grande quantità di immagini No DR o qualcos'altro? Non ne ho idea! Le 5 classi che ho per la previsione sono ["Mild", "Moderate", "No DR", "Proliferative DR", "Severe"].

Probabilmente è solo un codice errato, speravo che poteste aiutarvi

Risposte

1 Ivan Jan 18 2021 at 00:21

Stavo per commentare:

Un approccio più rigoroso consiste nell'iniziare a misurare il bilanciamento del tuo set di dati: quante immagini hai di ogni classe? Questo probabilmente darà una risposta alla tua domanda.

Ma non ho potuto fare a meno di guardare il collegamento che hai fornito. Kaggle ti offre già una panoramica del set di dati:

Rapido calcolo: 25,812 / 35,126 * 100 = 73%. Interessante, hai detto di avere una precisione di 74%. Il tuo modello sta imparando su un set di dati sbilanciato, con la prima classe sovrarappresentata, 25k/35kè enorme. La mia ipotesi è che il tuo modello continui a prevedere la prima classe, il che significa che in media ti ritroverai con una precisione di 74%.

Quello che dovresti fare è bilanciare il tuo set di dati. Ad esempio, consentendo solo agli 35,126 - 25,810 = 9,316esempi della prima classe di apparire durante un'epoca. Ancora meglio, bilancia il tuo set di dati su tutte le classi in modo che ogni classe appaia solo n volte ciascuna, per epoca.

2 Shai Jan 18 2021 at 04:30

Come ha già notato Ivan, hai un problema di squilibrio di classe. Questo può essere risolto tramite:

  1. Estrazione negativa in linea: ad ogni iterazione dopo aver calcolato la perdita, è possibile ordinare tutti gli elementi del batch appartenenti alla classe "no DR" e conservare solo il peggio k. Quindi stimate il gradiente usando solo queste k peggiori e scarti tutto il resto.
    vedere, ad esempio:
    Abhinav Shrivastava, Abhinav Gupta e Ross Girshick Formazione per rilevatori di oggetti basati su regioni con estrazione di esempi difficili in linea (CVPR 2016)

  2. Perdita focale: una modifica per la perdita di entropia incrociata "vaniglia" può essere utilizzata per affrontare lo squilibrio di classe.


Post correlati questo e questo .