Retinal Disease Classifier
Overview
The Retinal Disease Classifier is a computer vision model designed to assist in the detection and classification of retinal diseases using fundus images. It performs hierarchical classification—first detecting the presence of disease, then identifying the specific condition. The currently deployed model is a fine-tuned EfficientNet-B3 trained on ~1,600 retinal images.
Try the final product here: Retinal Disease Classifier Demo (may be sleeping due to inactivity)
Model Metrics:
- Disease Identification Recall: 85%
- Disease Classification Accuracy: 82%
- Inference Speed: ~0.5 seconds per image
Classification Types:
- Normal Retina
- Diabetic Retinopathy: Damages the retina’s blood vessels and may lead to vision loss.
- Age-Related Macular Degeneration (ARMD): Blurs central vision due to damage to the macula, common in older adults.
- Media Haze: Clouding of the eye’s optical media, reducing vision clarity.
- Optic Disc Cupping: Structural change in the optic nerve head, often related to glaucoma.
Features
- Hierarchical classification pipeline:
  - Step 1: Disease presence detection
  - Step 2: Disease type classification
- Custom vision models: EfficientNet-B3 and Vision Transformer (ViT) implemented from scratch
- Lightweight model optimized for low-latency inference
- Training monitored via TensorBoard
- Deployed on Hugging Face Spaces using Gradio for live demo
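The two-step pipeline listed above can be sketched as follows. This is a minimal illustration, not the deployed implementation: the README does not say whether the two steps are separate networks or two heads of one model, so `disease_prob` and `type_probs` are hypothetical callables, and the 0.5 threshold and the label order are assumptions.

```python
from typing import Callable, Sequence

# Disease acronyms from the Data Summary table; the index order is an assumption.
DISEASE_CLASSES = ["DR", "ARMD", "MH", "ODC"]


def classify_hierarchical(
    image,
    disease_prob: Callable[[object], float],
    type_probs: Callable[[object], Sequence[float]],
    threshold: float = 0.5,  # assumed decision threshold
) -> str:
    """Step 1: detect disease presence. Step 2: identify the disease type."""
    # Step 1: probability that any disease is present (hypothetical stage-1 model).
    if disease_prob(image) < threshold:
        return "NORMAL"
    # Step 2: choose the most likely disease type (hypothetical stage-2 model).
    probs = type_probs(image)
    best = max(range(len(probs)), key=lambda i: probs[i])
    return DISEASE_CLASSES[best]
```

Lowering `threshold` would trade more false positives for fewer missed diseases, which is the direction a recall-focused screening tool would usually lean.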
Data Preparation
The dataset was downloaded from IEEE DataPort.
The dataset was filtered to include only the five classes used in this model (Normal Retina plus four disease conditions). The data was manually organized and split into training, validation, and test sets so that every class is represented in each split.
Data Summary
Acronym | Full Name | Training | Validation | Test |
---|---|---|---|---|
NORMAL | Normal Retina | 516 | 134 | 134 |
DR | Diabetic Retinopathy | 375 | 132 | 124 |
ARMD | Age-Related Macular Degeneration | 100 | 38 | 31 |
MH | Media Haze | 316 | 92 | 100 |
ODC | Optic Disc Cupping | 281 | 72 | 91 |
Image Preprocessing
The following default image transforms were applied to match the expectations of the pretrained models:
- Resize: 300x300 (EfficientNet-B3) / 224x224 (ViT)
- Optional Augmentations: RandomHorizontalFlip, ColorJitter (during training only)
- ToTensor: Converts PIL images to PyTorch tensors
- Normalization (ImageNet statistics): Mean = [0.485, 0.456, 0.406], Std = [0.229, 0.224, 0.225]
Model Training
The primary goal of this project was to strike a balance between model accuracy and inference efficiency, particularly for potential deployment in resource-constrained environments.
Architectures Explored
- EfficientNet-B3: A compact yet powerful CNN-based architecture known for speed and accuracy.
- Vision Transformer (ViT): A transformer-based image model with strong feature extraction capabilities.
Both models were implemented from scratch to experiment with training directly on this dataset. However, because the dataset is small (~1,600 training images), training from scratch gave underwhelming performance. As a result, both models were fine-tuned from their pretrained versions instead.
While the ViT architecture slightly outperformed EfficientNet-B3 in terms of raw accuracy, EfficientNet was selected for deployment due to its significantly lower inference time and resource requirements.
Model Comparison
Below is a comparison of key metrics between different architectures explored for the retinal disease classification task:
Model | Accuracy | Inference Time | Model Size |
---|---|---|---|
EfficientNet-B3 | 82% | ~0.5 sec | ~12M parameters |
EfficientNet-B4 | 84% | ~0.85 sec | ~19M parameters |
Vision Transformer (ViT) | 85% | ~0.9 sec | ~22M parameters |
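Inference times like those in the table can be estimated with a small timing helper. The sketch below is an assumption about how such a measurement could be done, not the procedure used for the numbers above; `measure_latency` and its parameters are hypothetical names.

```python
import statistics
import time


def measure_latency(predict, sample, warmup: int = 3, repeats: int = 20) -> float:
    """Return the median wall-clock time (seconds) of predict(sample)."""
    # Warm-up calls absorb one-time costs such as lazy initialization.
    for _ in range(warmup):
        predict(sample)
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        predict(sample)
        timings.append(time.perf_counter() - start)
    # The median is less sensitive to occasional slow runs than the mean.
    return statistics.median(timings)
```

When timing a PyTorch model on a GPU, `torch.cuda.synchronize()` would also need to be called before reading the clock, because CUDA kernels run asynchronously.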
Weight Balancing
During initial training, the model exhibited a strong bias toward classifying most disease cases as Diabetic Retinopathy (DR), due to class imbalance in the dataset. This resulted in poor performance on underrepresented classes like ARMD and Optic Disc Cupping.
To counter this, class weighting was applied during training. Each class was weighted inversely to its frequency, so that errors on minority classes such as ARMD contributed more to the loss and the model was less inclined to default to majority classes such as Normal or DR.
Strategy:
- Calculated weights inversely proportional to class frequencies in the training set.
- Applied these weights to the cross-entropy loss function.
- Improved recall and F1-score for underrepresented diseases like ARMD and ODC.
Formula Used:

$$\text{Weight}_i = \frac{\text{Total Training Samples}}{\text{Number of Classes} \times \text{Samples in Class } i}$$
Example Weights:
Class | Training Samples | Weight (normalized) |
---|---|---|
NORMAL | 516 | 0.37 |
DR | 375 | 0.51 |
ARMD | 100 | 1.92 |
MH | 316 | 0.60 |
ODC | 281 | 0.67 |
This approach significantly enhanced the model’s ability to identify less frequent retinal diseases, particularly within the hierarchical classification system.
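The weighting formula can be sketched in a few lines of Python using the training counts from the Data Summary table. Note that the Example Weights table reports normalized values and the normalization step is not described, so the raw weights computed here are on a different scale; the ordering of the classes is the same. The function name is hypothetical.

```python
# Training-set counts taken from the Data Summary table.
TRAIN_COUNTS = {"NORMAL": 516, "DR": 375, "ARMD": 100, "MH": 316, "ODC": 281}


def inverse_frequency_weights(counts: dict) -> dict:
    """Weight_i = total samples / (number of classes * samples in class i)."""
    total = sum(counts.values())
    num_classes = len(counts)
    return {name: total / (num_classes * n) for name, n in counts.items()}


# Raw (un-normalized) weights; rarer classes receive larger values.
weights = inverse_frequency_weights(TRAIN_COUNTS)
```

In PyTorch, weights like these are typically converted to a tensor, ordered by class index, and passed to `torch.nn.CrossEntropyLoss(weight=...)`.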
Training Configuration
- Optimizer: AdamW
- Learning Rate: 1e-4
- Epochs: 10
- Loss Function: Cross Entropy Loss
- Hardware: Trained on NVIDIA RTX 4080
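A minimal sketch of a training loop under this configuration is shown below. It is not the project's training script: `make_training_objects` and `train` are hypothetical names, the data loader is assumed to yield `(images, labels)` batches, and TensorBoard logging is omitted for brevity.

```python
import torch
from torch import nn


def make_training_objects(model: nn.Module, class_weights=None):
    """Create the optimizer and loss from the configuration listed above."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # class_weights is an optional 1-D tensor with one weight per class.
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    return optimizer, criterion


def train(model: nn.Module, loader, epochs: int = 10, class_weights=None, device: str = "cpu"):
    """Run a plain supervised training loop for the given number of epochs."""
    model.to(device)
    optimizer, criterion = make_training_objects(model, class_weights)
    criterion.to(device)  # moves the class-weight buffer alongside the model
    for _ in range(epochs):
        model.train()
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
    return model
```

On a machine with a GPU, `device="cuda"` would be passed instead of the CPU default.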
Evaluation Method
The model’s performance was assessed using two primary metrics, one for each stage of the hierarchical pipeline:
- Recall (Disease Detection): Evaluates the model’s ability to detect the presence of any retinal disease, regardless of type. High recall is crucial in healthcare applications to minimize false negatives, that is, cases where a disease is present but the model fails to identify it.
  Result: 85% Recall on the validation set.
- Accuracy (Disease Classification): Measures the correctness of disease type predictions, given that a disease has been detected. It reflects how well the model distinguishes among the different retinal disease categories.
  Result: 82% Classification Accuracy across all disease classes.
These metrics were computed on the validation set after training completion.
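The two metrics can be computed from lists of true and predicted labels as in the sketch below. The exact denominators are an assumption: here, recall is taken over all truly diseased images, and classification accuracy over diseased images that the model also flagged as diseased. The function names are hypothetical.

```python
def detection_recall(y_true, y_pred, normal="NORMAL"):
    """Share of diseased images that the model flags as diseased (any type)."""
    diseased = [(t, p) for t, p in zip(y_true, y_pred) if t != normal]
    if not diseased:
        return 0.0
    detected = sum(1 for _, p in diseased if p != normal)
    return detected / len(diseased)


def classification_accuracy(y_true, y_pred, normal="NORMAL"):
    """Share of correct disease types among diseased images flagged as diseased."""
    flagged = [(t, p) for t, p in zip(y_true, y_pred) if t != normal and p != normal]
    if not flagged:
        return 0.0
    return sum(1 for t, p in flagged if t == p) / len(flagged)


# Toy example: one DR case is missed, and one ARMD case is mistyped as DR.
y_true = ["DR", "DR", "ARMD", "MH", "NORMAL"]
y_pred = ["DR", "NORMAL", "DR", "MH", "NORMAL"]
recall = detection_recall(y_true, y_pred)            # 3 of 4 diseased images flagged
accuracy = classification_accuracy(y_true, y_pred)   # 2 of 3 flagged images typed correctly
```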
Class-wise Performance Metrics
To better understand the model’s classification behavior, precision, recall, and F1-score were calculated for each class on the validation set:
Class | Precision | Recall | F1-Score |
---|---|---|---|
Normal Retina | 0.91 | 0.89 | 0.90 |
Diabetic Retinopathy (DR) | 0.76 | 0.88 | 0.81 |
Age-Related Macular Degeneration (ARMD) | 0.74 | 0.68 | 0.71 |
Media Haze (MH) | 0.79 | 0.75 | 0.77 |
Optic Disc Cupping (ODC) | 0.81 | 0.73 | 0.77 |
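For reference, per-class precision, recall, and F1-score can be derived from true and predicted labels as sketched below. This is a plain-Python illustration with a hypothetical function name, not the evaluation code used to produce the table.

```python
def per_class_metrics(y_true, y_pred, classes):
    """Compute one-vs-rest precision, recall, and F1 for each class."""
    metrics = {}
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        # F1 is the harmonic mean of precision and recall.
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        metrics[c] = {"precision": precision, "recall": recall, "f1": f1}
    return metrics
```

As a consistency check on the table, the Normal Retina row gives 2 × 0.91 × 0.89 / (0.91 + 0.89) ≈ 0.90, matching its reported F1-score.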