Por que às vezes os modelos da CNN prevêem apenas uma classe entre todas as outras?

Jan 18 2021

Eu sou relativamente novo no cenário do aprendizado profundo, então, por favor, não seja tão mau quanto o Reddit! Parece uma questão geral, então não irei fornecer meu código aqui, pois não parece necessário (se for, aqui está o link para colab )

Um pouco sobre os dados: Você pode encontrar os dados originais aqui . É uma versão reduzida do conjunto de dados original de 82 GB.

Depois de treinar minha CNN sobre isso, ela prediz 'Sem Retinopatia Diabética' (Sem RD) todas as vezes, levando a uma precisão de 73%. A razão para isso é apenas a grande quantidade de imagens sem DR ou outra coisa? Eu não faço ideia! As 5 aulas que tenho para previsão são ["Mild", "Moderate", "No DR", "Proliferative DR", "Severe"].

Provavelmente é apenas um código ruim, espero que vocês possam ajudar

Respostas

1 Ivan Jan 18 2021 at 00:21

Eu ia comentar:

Uma abordagem mais rigorosa seria começar a medir o equilíbrio do conjunto de dados: quantas imagens de cada classe você tem? Isso provavelmente dará uma resposta à sua pergunta.

Mas não pude evitar olhar para o link que você forneceu. O Kaggle já oferece uma visão geral do conjunto de dados:

Cálculo rápido: 25,812 / 35,126 * 100 = 73%. É interessante, você disse que tinha uma precisão de 74%. Seu modelo está aprendendo em um conjunto de dados desbalanceado, com a primeira classe sendo representada demais, 25k/35ké enorme. Minha hipótese é que seu modelo continua prevendo a primeira classe, o que significa que, em média, você terá uma precisão de 74%.

O que você deve fazer é equilibrar seu conjunto de dados. Por exemplo, permitindo que apenas 35,126 - 25,810 = 9,316exemplos da primeira classe apareçam durante uma época. Melhor ainda, equilibre seu conjunto de dados em todas as classes de forma que cada classe apareça apenas n vezes cada, por época.

2 Shai Jan 18 2021 at 04:30

Como Ivan já observou, você tem um problema de desequilíbrio de classe. Isso pode ser resolvido através de:

  1. Mineração negativa pesada online: a cada iteração após calcular a perda, você pode classificar todos os elementos no lote pertencentes à classe "sem DR" e manter apenas o pior k. Então você estima o gradiente usando apenas esses k piores e descarta todo o resto.
    ver, por exemplo:
    Abhinav Shrivastava, Abhinav Gupta e Ross Girshick Training Region-based Object Detectors with Online Hard Example Mining (CVPR 2016)

  2. Perda focal: uma modificação para a perda de entropia cruzada "vanilla" pode ser usada para lidar com o desequilíbrio de classes.


Postagens relacionadas isto e aquilo .