Introduction
Medical image classification is one of the most accessible entry points into healthcare AI. The core task is simple to state: given a medical image, assign it a diagnostic category. Structurally, this is identical to standard image classification; the difference is that the stakes are higher and the failure modes more subtle. A misclassified cat is a meme. A misclassified pathology slide is a missed diagnosis.
This tutorial walks through building a medical image classifier using PyTorch. It draws from my experience building Histia, a pathology AI system, and focuses on the practical decisions that separate a working medical classifier from a dangerously misleading one. If you know basic Python and have seen a PyTorch tutorial, you have enough to follow along.

Choosing Your Dataset
The dataset determines everything downstream. For a first project, you want something small enough to iterate quickly, well-labeled enough to trust, and medically relevant enough to be interesting.
| Dataset | Domain | Size | Best for |
|---|---|---|---|
| PathMNIST | Colon pathology | 107K images, 9 classes | First project, fast iteration |
| ISIC 2020 | Dermatology | 33K images, 2 classes | Binary classification, class imbalance |
| CheXpert | Chest X-ray | 224K images, 14 labels | Multi-label, uncertainty handling |
Start with PathMNIST. Part of the MedMNIST collection, it provides standardized 224x224 images with consistent train/val/test splits, eliminating data engineering overhead so you can focus on modeling.
```python
from medmnist import PathMNIST
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_data = PathMNIST(split="train", transform=train_transform, download=True, size=224)
val_data = PathMNIST(split="val", transform=val_transform, download=True, size=224)
test_data = PathMNIST(split="test", transform=val_transform, download=True, size=224)
```
Notice what is NOT in the augmentation pipeline: random color jitter. In pathology, color carries diagnostic information. Stain intensity differences between eosinophilic and basophilic tissue are how pathologists distinguish cell types. Randomly shifting hue and saturation destroys clinically meaningful signal.
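With the datasets in hand, the remaining step before training is wrapping them in DataLoaders. A minimal sketch, using a synthetic stand-in dataset so it runs without downloading PathMNIST (the shapes and the `DataLoader` settings are the part that carries over):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for PathMNIST: 64 fake RGB images at 224x224
# with integer labels in [0, 9), matching the (N, 1) label shape MedMNIST uses.
images = torch.randn(64, 3, 224, 224)
labels = torch.randint(0, 9, (64, 1))
train_data = TensorDataset(images, labels)

# shuffle=True for training only; tune num_workers/pin_memory to your hardware.
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=0)

batch_images, batch_labels = next(iter(train_loader))
print(batch_images.shape)  # torch.Size([32, 3, 224, 224])
```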
Preprocessing Pitfalls Specific to Medical Imaging
Medical image preprocessing has domain-specific pitfalls that general CV tutorials never cover. The most damaging is leakage through the split: multiple images usually come from the same patient or the same slide, and if they land on both sides of a train/test split, the model learns patient-specific artifacts and reports inflated metrics that collapse on new patients. Always split at the patient level, never at the image level.
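A patient-level split is straightforward with scikit-learn's `GroupShuffleSplit`, which keeps every image from one group on the same side. A minimal sketch with hypothetical patient IDs:

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# Hypothetical setup: 1000 images drawn from 100 patients, several images each.
rng = np.random.default_rng(0)
patient_ids = rng.integers(0, 100, size=1000)  # which patient each image came from
labels = rng.integers(0, 9, size=1000)

# GroupShuffleSplit never puts the same group (patient) in both splits.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
train_idx, test_idx = next(splitter.split(np.zeros(1000), labels, groups=patient_ids))

train_patients = set(patient_ids[train_idx])
test_patients = set(patient_ids[test_idx])
print(train_patients & test_patients)  # set() — no patient appears on both sides
```

For PathMNIST this is handled for you by the fixed splits; the moment you move to your own data, this is the split you want.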
Transfer Learning: The Two-Phase Approach
Do not train from scratch. Medical image datasets are small by deep learning standards. Transfer learning is not optional. It is required. ResNet-50 pretrained on ImageNet is the baseline: well-understood, fast to fine-tune, and competitive on most medical imaging benchmarks.
```python
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
from torch.optim import AdamW

class MedicalClassifier(nn.Module):
    def __init__(self, num_classes=9, freeze_backbone=True):
        super().__init__()
        self.backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False
        # Replace the ImageNet head; the new layers are trainable by default.
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Dropout(0.3), nn.Linear(in_features, 512), nn.ReLU(),
            nn.Dropout(0.2), nn.Linear(512, num_classes)
        )

    def forward(self, x):
        return self.backbone(x)

    def unfreeze(self):
        for param in self.backbone.parameters():
            param.requires_grad = True

model = MedicalClassifier(num_classes=9, freeze_backbone=True).cuda()
criterion = nn.CrossEntropyLoss()

# Phase 1: Train classifier head (5-10 epochs, lr=1e-3)
opt = AdamW(model.backbone.fc.parameters(), lr=1e-3, weight_decay=1e-4)

# Phase 2: Unfreeze and fine-tune entire network (10-20 epochs, lr=1e-5)
model.unfreeze()
opt = AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)
```
Phase one freezes the backbone and trains only the classification head at lr=1e-3, adapting the classifier to your label space without disrupting pretrained features. Phase two unfreezes the backbone and fine-tunes at lr=1e-5, letting the feature extractor adapt to medical image statistics while the classifier head stays stable.
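The freeze/unfreeze mechanics and the fresh-optimizer-per-phase pattern are easy to get subtly wrong, so here is the skeleton in isolation. A minimal sketch using a tiny stand-in model (not the ResNet-50 above) so it runs in seconds on CPU:

```python
import torch
import torch.nn as nn
from torch.optim import AdamW

# Tiny stand-in for backbone + head, so the two-phase logic runs instantly.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.ReLU())
head = nn.Linear(32, 9)
model = nn.Sequential(backbone, head)
criterion = nn.CrossEntropyLoss()

x = torch.randn(4, 3, 8, 8)
y = torch.randint(0, 9, (4,))

# Phase 1: freeze the backbone, train only the head at the higher learning rate.
for p in backbone.parameters():
    p.requires_grad = False
opt = AdamW(head.parameters(), lr=1e-3)
loss = criterion(model(x), y)
opt.zero_grad(); loss.backward(); opt.step()

# Phase 2: unfreeze everything and fine-tune at the much lower learning rate.
# Note the fresh optimizer: the phase-1 optimizer only knows the head's parameters.
for p in backbone.parameters():
    p.requires_grad = True
opt = AdamW(model.parameters(), lr=1e-5)
loss = criterion(model(x), y)
opt.zero_grad(); loss.backward(); opt.step()
```

The detail most people miss is rebuilding the optimizer in phase two; reusing the phase-one optimizer would silently leave the backbone untrained.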
Vision Transformers (ViT) often outperform CNNs on medical imaging because their attention mechanism handles spatial relationships between tissue structures naturally. But they are more data-hungry. Try ViT-B/16 after you have a ResNet baseline. If ViT does not improve AUROC by at least 2 points, the complexity is not worth it.
Evaluation: Why Accuracy Lies
Here is where most tutorials fail. In medical classification, accuracy is a misleading metric. A skin lesion classifier where 95% of images are benign achieves 95% accuracy by predicting "benign" for everything while missing every cancer case.
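The benign/malignant scenario above is worth seeing in numbers. A minimal sketch with synthetic labels showing how accuracy and sensitivity diverge:

```python
import numpy as np
from sklearn.metrics import accuracy_score, recall_score

# 100 lesions: 95 benign (0), 5 malignant (1) — the imbalance from the text.
y_true = np.array([0] * 95 + [1] * 5)
y_pred = np.zeros(100, dtype=int)  # a "classifier" that always predicts benign

print(accuracy_score(y_true, y_pred))             # 0.95 — looks excellent
print(recall_score(y_true, y_pred, pos_label=1))  # 0.0  — misses every cancer
```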
Confusing adenocarcinoma with squamous cell carcinoma has different clinical implications than confusing either with normal tissue, so evaluate with per-class AUROC and a full per-class report rather than a single aggregate number.
```python
from sklearn.metrics import roc_auc_score, classification_report
import numpy as np

def evaluate(model, loader, device):
    model.eval()
    all_probs, all_labels = [], []
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            all_probs.append(torch.softmax(outputs, dim=1).cpu().numpy())
            all_labels.append(labels.squeeze(1).numpy())  # MedMNIST labels are (N, 1)
    probs, labels = np.concatenate(all_probs), np.concatenate(all_labels)
    auroc = roc_auc_score(labels, probs, multi_class="ovr", average=None)
    print(f"Mean AUROC: {auroc.mean():.4f}")
    print(classification_report(labels, probs.argmax(1), digits=4))
```
Common Pitfalls That Waste Months
Medical AI is not a Kaggle competition. These metrics map to patient outcomes. A model that optimizes accuracy while ignoring sensitivity on the disease class misses sick patients. Build proper evaluation into your training loop from day one.
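Tracking sensitivity on the disease class each epoch can be as simple as reading it off the confusion matrix. A minimal sketch with hypothetical predictions over three classes, where class 2 plays the disease class:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical per-epoch validation predictions (3 classes).
y_true = np.array([0, 0, 1, 1, 2, 2, 2, 2])
y_pred = np.array([0, 0, 1, 2, 2, 2, 0, 2])

cm = confusion_matrix(y_true, y_pred)
# Row i = true class i, so per-class sensitivity (recall) = diagonal / row sums.
sensitivity = cm.diagonal() / cm.sum(axis=1)
print(sensitivity[2])  # 0.75 — 3 of 4 true disease cases caught
```

Selecting checkpoints on this number, rather than on accuracy, is how you catch a model that quietly sacrifices the minority class.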

Conclusion
Building a medical image classifier is structurally similar to any image classification task. The difference is that every decision carries clinical weight. Random color jitter destroys diagnostic signal. Accuracy hides class imbalance. Image-level splits leak patient information and inflate results that will not hold up on new patients.
Start with PathMNIST and a pretrained ResNet-50. Get the two-phase transfer learning pipeline working. Evaluate with AUROC and per-class metrics. Split at the patient level. Then iterate. The goal is not to beat a benchmark. It is to build something a clinician would trust with their patients, and that standard is much higher and more interesting than a leaderboard position.
If you are building medical imaging tools or want to discuss Histia and the computational pathology work behind it, reach out. This is a field where more builders are needed.