Pourquoi les modèles CNN prédisent-ils parfois une seule classe parmi toutes les autres?

Jan 17 2021

Je suis relativement nouveau dans le paysage de l'apprentissage profond, alors ne soyez pas aussi méchant que Reddit! Cela semble être une question générale, donc je ne donnerai pas mon code ici car cela ne semble pas nécessaire (si c'est le cas, voici le lien vers colab )

Un peu sur les données: vous pouvez trouver les données originales ici . Il s'agit d'une version réduite de l'ensemble de données d'origine de 82 Go.

Une fois que j'ai formé mon CNN à ce sujet, il prédit `` Pas de rétinopathie diabétique '' (pas de DR) à chaque fois, conduisant à une précision de 73%. La raison en est-elle simplement la grande quantité d'images No DR ou autre chose? Je n'ai aucune idée! Les 5 classes que j'ai pour la prédiction sont ["Mild", "Moderate", "No DR", "Proliferative DR", "Severe"].

C'est probablement juste un mauvais code, j'espérais que vous pourriez aider

Réponses

1 Ivan Jan 18 2021 at 00:21

J'étais sur le point de commenter:

Une approche plus rigoureuse serait de commencer à mesurer l'équilibre de votre jeu de données: combien d'images de chaque classe avez-vous? Cela donnera probablement une réponse à votre question.

Mais je n'ai pas pu m'empêcher de regarder le lien que vous avez donné. Kaggle vous donne déjà un aperçu de l'ensemble de données:

Calcul rapide: 25,812 / 35,126 * 100 = 73%. C'est intéressant, vous avez dit que vous aviez une précision de 74%. Votre modèle apprend sur un jeu de données déséquilibré, la première classe étant surreprésentée, 25k/35kc'est énorme. Mon hypothèse est que votre modèle continue de prédire la première classe, ce qui signifie qu'en moyenne vous vous retrouverez avec une précision de 74%.

Ce que vous devez faire est de l' équilibre de votre ensemble de données. Par exemple en permettant uniquement aux 35,126 - 25,810 = 9,316exemples de la première classe d'apparaître à une époque. Mieux encore, équilibrez votre ensemble de données sur toutes les classes de sorte que chaque classe n'apparaisse que n fois chacune, par époque.

2 Shai Jan 18 2021 at 04:30

Comme Ivan l'a déjà noté, vous avez un problème de déséquilibre de classe. Cela peut être résolu via:

  1. Minage négatif en ligne: à chaque itération après le calcul de la perte, vous pouvez trier tous les éléments du lot appartenant à la classe "no DR" et ne conserver que le pire k. Ensuite, vous estimez le gradient en utilisant uniquement ces k pires et vous rejetez tout le reste.
    voir, par exemple:
    Abhinav Shrivastava, Abhinav Gupta et Ross Girshick Training Region-based Object Detectors with Online Hard Example Mining (CVPR 2016)

  2. Perte focale: une modification de la perte d'entropie croisée «vanille» peut être utilisée pour lutter contre le déséquilibre de classe.


Articles connexes ceci et cela .