Getting Started with Medical Image Classification Using Deep Learning

A practical tutorial on building your first medical image classifier with modern deep learning tools.

Introduction

Medical image classification is one of the most accessible entry points into healthcare AI. The core task (given a medical image, assign it a diagnostic category) is structurally identical to standard image classification. The difference is that the stakes are higher and the failure modes more subtle. A misclassified cat is a meme. A misclassified pathology slide is a missed diagnosis.

This tutorial walks through building a medical image classifier using PyTorch. It draws from my experience building Histia, a pathology AI system, and focuses on the practical decisions that separate a working medical classifier from a dangerously misleading one. If you know basic Python and have seen a PyTorch tutorial, you have enough to follow along.


Choosing Your Dataset

The dataset determines everything downstream. For a first project, you want something small enough to iterate quickly, well-labeled enough to trust, and medically relevant enough to be interesting.

Dataset   | Domain          | Size                   | Best for
--------- | --------------- | ---------------------- | --------------------------------------
PathMNIST | Colon pathology | 107K images, 9 classes | First project, fast iteration
ISIC 2020 | Dermatology     | 33K images, 2 classes  | Binary classification, class imbalance
CheXpert  | Chest X-ray     | 224K images, 14 labels | Multi-label, uncertainty handling

Start with PathMNIST. Part of the MedMNIST collection, it provides standardized images at several resolutions (the 224x224 variant is used below) with consistent train/val/test splits, eliminating data engineering overhead so you can focus on modeling.

python
from medmnist import PathMNIST
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
val_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# size=224 selects the high-resolution MedMNIST+ variant (requires medmnist >= 3.0)
train_data = PathMNIST(split="train", transform=train_transform, download=True, size=224)
val_data = PathMNIST(split="val", transform=val_transform, download=True, size=224)
test_data = PathMNIST(split="test", transform=val_transform, download=True, size=224)
Key takeaway

Notice what is NOT in the augmentation pipeline: random color jitter. In pathology, color carries diagnostic information. Stain intensity differences between eosinophilic and basophilic tissue are how pathologists distinguish cell types. Randomly shifting hue and saturation destroys clinically meaningful signal.

Preprocessing Pitfalls Specific to Medical Imaging

Medical image preprocessing has domain-specific pitfalls that general CV tutorials never cover.

1. Normalization must match pretraining. Use ImageNet statistics for ImageNet-pretrained models and pathology-specific statistics for pathology-pretrained models. Mismatched normalization silently degrades performance.

2. Augmentation must respect clinical semantics. Flips and rotation are safe for pathology because tissue orientation is arbitrary under a microscope. But elastic deformations can create biologically impossible morphology, and aggressive cropping can remove diagnostically critical regions.

3. Stain normalization is a separate step. If your dataset combines images from multiple scanners, stain variation will dominate your feature space. Macenko or Vahadane normalization standardizes appearance; a minimal Macenko sketch follows this list. Skip this and your model learns scanner identity, not pathology.

4. Patient-level splitting is mandatory. Multiple images from the same patient must all go into the same split. If patient A appears in both training and validation, the model memorizes patient-specific features and reports inflated metrics. This is data leakage, the most common methodological error in medical imaging papers.
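
To make step 3 concrete, here is a minimal Macenko sketch in NumPy. It follows the standard recipe: convert RGB to optical density, estimate the two stain vectors from the extreme angles of the projected OD point cloud, then rescale stain concentrations to a reference profile. The HERef and maxCRef values are the widely circulated defaults from the reference implementation, not numbers tuned for PathMNIST; treat this as a starting point rather than production code.

python
import numpy as np

def macenko_normalize(img, Io=240, alpha=1, beta=0.15):
    """Map an H&E RGB image (uint8, HxWx3) onto a reference stain profile."""
    # Reference stain matrix and max concentrations (standard defaults)
    HERef = np.array([[0.5626, 0.2159],
                      [0.7201, 0.8012],
                      [0.4062, 0.5581]])
    maxCRef = np.array([1.9705, 1.0308])

    h, w, _ = img.shape
    od = -np.log((img.astype(float) + 1) / Io)         # RGB -> optical density
    od_flat = od.reshape(-1, 3)
    tissue = od_flat[~np.any(od_flat < beta, axis=1)]  # drop background pixels

    # The two largest eigenvectors of the OD covariance span the stain plane
    _, eigvecs = np.linalg.eigh(np.cov(tissue.T))
    plane = eigvecs[:, 1:3]

    # Stain vectors sit at the extreme angles of the projected point cloud
    proj = tissue @ plane
    phi = np.arctan2(proj[:, 1], proj[:, 0])
    lo, hi = np.percentile(phi, alpha), np.percentile(phi, 100 - alpha)
    v1 = plane @ np.array([np.cos(lo), np.sin(lo)])
    v2 = plane @ np.array([np.cos(hi), np.sin(hi)])
    HE = np.array([v1, v2]).T if v1[0] > v2[0] else np.array([v2, v1]).T

    # Unmix stain concentrations, rescale to the reference, reconstruct
    C = np.linalg.lstsq(HE, od_flat.T, rcond=None)[0]
    C *= (maxCRef / np.percentile(C, 99, axis=1))[:, None]
    out = Io * np.exp(-HERef @ C)
    return np.clip(out.T.reshape(h, w, 3), 0, 255).astype(np.uint8)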

Transfer Learning: The Two-Phase Approach

Do not train from scratch. Medical image datasets are small by deep learning standards. Transfer learning is not optional. It is required. ResNet-50 pretrained on ImageNet is the baseline: well-understood, fast to fine-tune, and competitive on most medical imaging benchmarks.

python
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
from torch.optim import AdamW

class MedicalClassifier(nn.Module):
    def __init__(self, num_classes=9, freeze_backbone=True):
        super().__init__()
        self.backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Dropout(0.3), nn.Linear(in_features, 512), nn.ReLU(),
            nn.Dropout(0.2), nn.Linear(512, num_classes)
        )

    def forward(self, x):
        return self.backbone(x)

    def unfreeze(self):
        for param in self.backbone.parameters():
            param.requires_grad = True

model = MedicalClassifier(num_classes=9, freeze_backbone=True).cuda()
criterion = nn.CrossEntropyLoss()

# Phase 1: Train classifier head (5-10 epochs, lr=1e-3)

opt = AdamW(model.backbone.fc.parameters(), lr=1e-3, weight_decay=1e-4)

# Phase 2: Unfreeze and fine-tune entire network (10-20 epochs, lr=1e-5)

model.unfreeze()
opt = AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)

Phase one freezes the backbone and trains only the classification head at lr=1e-3, adapting the classifier to your label space without disrupting pretrained features. Phase two unfreezes the backbone and fine-tunes at lr=1e-5, letting the feature extractor adapt to medical image statistics while the classifier head stays stable.
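
For completeness, here is a minimal training loop that serves both phases; only the optimizer and epoch count change between them. It assumes the model, criterion, and train_data objects defined above; the batch size and worker count are arbitrary choices, not tuned values.

python
from torch.utils.data import DataLoader

train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=4)

def train_one_epoch(model, loader, criterion, opt, device="cuda"):
    model.train()
    running_loss = 0.0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.squeeze(-1).long().to(device)  # medmnist labels are (B, 1)
        opt.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        opt.step()
        running_loss += loss.item() * images.size(0)
    return running_loss / len(loader.dataset)

# Phase 1 shown; repeat with the phase-2 optimizer after model.unfreeze()
for epoch in range(5):
    print(f"epoch {epoch}: train loss {train_one_epoch(model, train_loader, criterion, opt):.4f}")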

Vision Transformers (ViT) can outperform CNNs on medical imaging because their attention mechanism naturally captures spatial relationships between tissue structures. But they are more data-hungry. Try ViT-B/16 after you have a ResNet baseline; a minimal swap is sketched below. If ViT does not improve AUROC by at least 2 points, the complexity is not worth it.
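
The sketch below uses torchvision's ViT-B/16, where the classification layer lives at heads.head; the same two-phase recipe applies.

python
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights

vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
# Replace the classification head to match the 9 PathMNIST classes
vit.heads.head = nn.Linear(vit.heads.head.in_features, 9)
# Phase 1: freeze everything except the new head
for param in vit.parameters():
    param.requires_grad = False
for param in vit.heads.head.parameters():
    param.requires_grad = True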

Evaluation: Why Accuracy Lies

Here is where most tutorials fail. In medical classification, accuracy is a misleading metric. A skin lesion classifier where 95% of images are benign achieves 95% accuracy by predicting "benign" for everything while missing every cancer case.
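
To make the failure concrete, here is the degenerate classifier in a few lines of NumPy, using a hypothetical 95/5 benign/malignant split:

python
import numpy as np

labels = np.array([0] * 950 + [1] * 50)         # 0 = benign, 1 = malignant
preds = np.zeros_like(labels)                   # predict "benign" for everything
accuracy = (preds == labels).mean()             # 0.95
sensitivity = (preds[labels == 1] == 1).mean()  # 0.0: every cancer case missed
print(f"accuracy={accuracy:.2f}, sensitivity={sensitivity:.2f}")

The metrics below catch what accuracy hides.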

1. AUROC. Measures the model's ability to distinguish classes across all thresholds. This is the primary metric for medical classifiers and should be reported per-class in multi-class settings.

2. Sensitivity and Specificity. Sensitivity measures what fraction of true positives are caught; specificity measures what fraction of true negatives are correctly ruled out. In screening, sensitivity matters more because missing a disease is worse than a false alarm.

3. Confusion Matrix. Shows exactly which classes the model confuses. Confusing adenocarcinoma with squamous cell carcinoma has different clinical implications than confusing either with normal tissue.

4. Calibration. A model outputting 0.8 probability should be correct 80% of the time at that confidence. Temperature scaling fixes poorly calibrated models post-hoc; a sketch follows the evaluation code below.

The helper below reports per-class AUROC alongside a standard classification report:

python
from sklearn.metrics import roc_auc_score, classification_report
import numpy as np

def evaluate(model, loader, device):
    model.eval()
    all_probs, all_labels = [], []
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            all_probs.append(torch.softmax(outputs, dim=1).cpu().numpy())
            all_labels.append(labels.numpy().ravel())  # medmnist labels are (B, 1)
    probs, labels = np.concatenate(all_probs), np.concatenate(all_labels)
    # One-vs-rest AUROC per class; report each class, not just the mean
    aurocs = np.array([roc_auc_score((labels == c).astype(int), probs[:, c])
                       for c in range(probs.shape[1])])
    print(f"Per-class AUROC: {np.round(aurocs, 4)}")
    print(f"Mean AUROC: {aurocs.mean():.4f}")
    print(classification_report(labels, probs.argmax(1), digits=4))
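
And here is the temperature scaling mentioned in point 4: a one-parameter post-hoc fix that learns a scalar temperature on held-out logits and divides future logits by it before softmax. This reuses the torch and nn imports from earlier; the val_loader name is an assumption, standing in for a DataLoader over val_data.

python
def fit_temperature(model, loader, device):
    """Learn a scalar temperature on validation logits (post-hoc calibration)."""
    logits, labels = [], []
    model.eval()
    with torch.no_grad():
        for images, y in loader:
            logits.append(model(images.to(device)).cpu())
            labels.append(y.view(-1).long())
    logits, labels = torch.cat(logits), torch.cat(labels)
    temperature = nn.Parameter(torch.ones(1))
    opt = torch.optim.LBFGS([temperature], lr=0.01, max_iter=200)

    def closure():
        opt.zero_grad()
        loss = nn.functional.cross_entropy(logits / temperature, labels)
        loss.backward()
        return loss

    opt.step(closure)
    return temperature.item()  # divide future logits by this before softmax

# T = fit_temperature(model, val_loader, "cuda")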

Common Pitfalls That Waste Months

1. Class imbalance. If your dataset has 50,000 normal images and 500 cancer images, the model predicts normal for everything and achieves 99% accuracy while being clinically useless. Fix this with weighted cross-entropy that scales inversely with class frequency, oversampling, or focal loss; a weighted-loss sketch follows this list.

2. Data leakage from image-level splits. If your dataset contains multiple images per patient, splitting at the image level lets the model memorize patient-specific features and report inflated metrics. Always split at the patient level; see the grouped-split sketch below. PathMNIST handles this for you. Other datasets may not.

3. Center-specific overfitting. Models trained on single-hospital data learn scanner characteristics and staining protocols, not pathology. Performance drops dramatically at other institutions. This is distribution shift, and it is the reason most medical AI papers do not replicate.
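
For pitfall 1, inverse-frequency weighting is a two-line change. The counts below are the hypothetical 50,000/500 split from above, and inverse frequency is one common convention among several:

python
import torch
import torch.nn as nn

class_counts = torch.tensor([50_000.0, 500.0])  # hypothetical normal/cancer counts
weights = class_counts.sum() / (len(class_counts) * class_counts)
criterion = nn.CrossEntropyLoss(weight=weights)

For pitfall 2, scikit-learn's GroupShuffleSplit handles the bookkeeping. The image_paths and patient_ids names are hypothetical, aligned NumPy arrays standing in for your own metadata:

python
from sklearn.model_selection import GroupShuffleSplit

splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, val_idx = next(splitter.split(image_paths, groups=patient_ids))
# Sanity check: no patient appears in both splits
assert not set(patient_ids[train_idx]) & set(patient_ids[val_idx])
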
Why this matters

Medical AI is not a Kaggle competition. These metrics map to patient outcomes. A model that optimizes accuracy while ignoring sensitivity on the disease class misses sick patients. Build proper evaluation into your training loop from day one.


Conclusion

Building a medical image classifier is structurally similar to any image classification task. The difference is that every decision carries clinical weight. Random color jitter destroys diagnostic signal. Accuracy hides class imbalance. Image-level splits produce results that do not replicate.

Start with PathMNIST and a pretrained ResNet-50. Get the two-phase transfer learning pipeline working. Evaluate with AUROC and per-class metrics. Split at the patient level. Then iterate. The goal is not to beat a benchmark. It is to build something a clinician would trust with their patients, and that standard is much higher and more interesting than a leaderboard position.

Get involved

If you are building medical imaging tools or want to discuss Histia and the computational pathology work behind it, reach out. This is a field where more builders are needed.
