때때로 CNN 모델이 다른 모든 클래스 중에서 하나의 클래스 만 예측하는 이유는 무엇입니까?
나는 딥 러닝 환경에 비교적 익숙하지 않으므로 Reddit만큼 비열하지 마십시오! 일반적인 질문처럼 보이므로 필요하지 않은 것처럼 여기에 내 코드를 제공하지 않을 것입니다 (필요한 경우 여기에 colab 링크가 있습니다 )
데이터에 대한 약간의 정보 : 여기 에서 원본 데이터를 찾을 수 있습니다 . 82GB의 원래 데이터 세트를 축소 한 버전입니다.
이에 대해 CNN을 훈련하면 매번 'No Diabetic Retinopathy'(DR 없음)를 예측하여 73 %의 정확도로 이어집니다. 그 이유는 방대한 양의 No DR 이미지 때문입니까? 모르겠어요! 내가 예측을 위해 가지고있는 5 가지 클래스는 ["Mild", "Moderate", "No DR", "Proliferative DR", "Severe"]
.
아마도 나쁜 코드 일 것입니다.
답변
나는 논평하려고했다 :
보다 엄격한 접근 방식은 데이터 세트 균형 측정을 시작하는 것입니다. 각 클래스의 이미지는 몇 개입니까? 이것은 귀하의 질문에 대한 답을 줄 것입니다.
그러나 당신이 준 링크를 스스로 볼 수는 없었습니다. Kaggle은 이미 데이터 세트에 대한 개요를 제공합니다.
빠른 계산 : 25,812 / 35,126 * 100 = 73%
. 흥미 롭군요 74%
. 정확도가 . 모델이 불균형 데이터 세트에 대해 학습하고 있으며, 첫 번째 클래스가 과도하게 표현되면 25k/35k
엄청납니다. 내 가설은 모델이 첫 번째 클래스를 계속 예측하므로 평균적으로 정확도가 74%
.
당신이 해야 할 것은 데이터 세트의 균형이다. 예를 들어 35,126 - 25,810 = 9,316
첫 번째 클래스의 예제 만 에포크 동안 표시 되도록 허용 합니다. 더 좋은 방법은 각 클래스가 에포크 당 n 번만 나타나도록 모든 클래스에 대해 데이터 세트의 균형을 맞추는 것 입니다.
Ivan이 이미 언급했듯이 클래스 불균형 문제가 있습니다. 이는 다음을 통해 해결할 수 있습니다.
온라인 하드 네거티브 마이닝 : 손실을 계산 한 후 각 반복에서 "DR 없음"클래스에 속하는 배치의 모든 요소를 정렬하고 최악의 항목 만 유지할 수
k
있습니다. 그런 다음 이러한 더 나쁜 k 만 사용하여 기울기를 추정 하고 나머지는 모두 버립니다 .
참조 :
Abhinav Shrivastava, Abhinav Gupta 및 Ross Girshick 교육 지역 기반 물체 탐지기 (온라인 하드 예제 마이닝 포함) (CVPR 2016)초점 손실 : "바닐라"교차 엔트로피 손실에 대한 수정을 사용하여 클래스 불균형을 해결할 수 있습니다.
관련 게시물 this 및 this .