¿Por qué a veces los modelos de CNN predicen solo una clase de todas las demás?
Soy relativamente nuevo en el panorama del aprendizaje profundo, ¡así que no seas tan malo como Reddit! Parece una pregunta general, por lo que no daré mi código aquí, ya que no parece necesario (si lo es, aquí está el enlace a colab )
Un poco sobre los datos: puede encontrar los datos originales aquí . Es una versión reducida del conjunto de datos original de 82 GB.
Una vez que entrené a mi CNN en esto, predice 'Sin retinopatía diabética' (Sin DR) cada vez, lo que lleva a una precisión del 73%. ¿La razón de esto es solo la gran cantidad de imágenes sin DR o algo más? ¡No tengo idea! Las 5 clases que tengo para la predicción son ["Mild", "Moderate", "No DR", "Proliferative DR", "Severe"]
.
Probablemente sea solo un código incorrecto, esperaba que ustedes pudieran ayudar
Respuestas
Estuve a punto de comentar:
Un enfoque más riguroso sería comenzar a medir el equilibrio de su conjunto de datos: ¿cuántas imágenes de cada clase tiene? Esto probablemente dará una respuesta a su pregunta.
Pero no pude evitar mirar el enlace que me diste. Kaggle ya le brinda una descripción general del conjunto de datos:
Cálculo rápido: 25,812 / 35,126 * 100 = 73%
. Eso es interesante, dijiste que tenías una precisión de 74%
. Su modelo está aprendiendo en un conjunto de datos desequilibrado, con la primera clase sobrerrepresentada, 25k/35k
es enorme. Mi hipótesis es que su modelo sigue prediciendo la primera clase, lo que significa que, en promedio, terminará con una precisión de 74%
.
Lo que debe hacer es equilibrar su conjunto de datos. Por ejemplo, solo permitiendo 35,126 - 25,810 = 9,316
que aparezcan ejemplos de la primera clase durante una época. Aún mejor, equilibre su conjunto de datos en todas las clases de modo que cada clase solo aparezca n veces cada una, por época.
Como Iván ya señaló, tienes un problema de desequilibrio de clases. Esto se puede resolver mediante:
Minería negativa dura en línea: en cada iteración después de calcular la pérdida, puede ordenar todos los elementos del lote que pertenecen a la clase "sin DR" y conservar solo los peores
k
. Luego, estima el gradiente solo usando estos k peores y descarta el resto.
ver, por ejemplo:
Abhinav Shrivastava, Abhinav Gupta y Ross Girshick Training Region-based Object Detectors with Online Hard Example Mining (CVPR 2016)Pérdida focal: se puede utilizar una modificación de la pérdida de entropía cruzada "vainilla" para abordar el desequilibrio de clases.
Publicaciones relacionadas esto y esto .