Machine Learning & Deep Learning


Training a Deep Network

Every concept in this module — the forward pass, the loss function, gradient descent, backpropagation, regularization — converges in the training loop. This lesson takes a practitioner's perspective: we trace through a realistic training run on a real-world-style problem, interpret the signals the model sends back, and make the kinds of decisions that separate a model that works from one that does not. You will learn not just what to do but why, and what to do when things go wrong.

The Training Setup

Problem: classify images of handwritten digits 0-9 (a simplified version of the well-known MNIST dataset).
Dataset: 50,000 training examples, 10,000 validation examples, 10,000 test examples. Each example is a 28x28 grayscale image (784 pixel values).
Architecture: fully connected network.
- Input: 784 features (flattened pixels)
- Hidden layer 1: 256 neurons, ReLU
- Hidden layer 2: 128 neurons, ReLU
- Output: 10 neurons, softmax (probability for each digit class)
Parameter count:
- W^(1): 784*256 = 200,704 weights + 256 biases
- W^(2): 256*128 = 32,768 weights + 128 biases
- W^(3): 128*10 = 1,280 weights + 10 biases
- Total: 234,752 weights + 394 biases = 235,146 parameters.
Loss function: categorical cross-entropy.
Optimizer: mini-batch SGD, batch size 128, learning rate eta = 0.01.
Regularization: weight decay lambda = 1e-4, dropout rate 0.3 on each hidden layer.
Training: 30 epochs. One epoch is one pass through all 50,000 training examples, which takes 391 mini-batches (50,000 / 128 = 390.6, so the last batch is partial).
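The parameter and batch arithmetic above can be checked in a few lines of Python. This is a minimal sketch: the helper name `count_parameters` is hypothetical, and it assumes every layer is fully connected with one bias per output neuron.

```python
import math

def count_parameters(layer_sizes):
    """Count weights and biases in a fully connected network.

    layer_sizes lists the width of every layer, input first.
    Each dense layer has n_in * n_out weights plus n_out biases.
    """
    total = 0
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
        total += n_in * n_out + n_out
    return total

# 784 inputs -> 256 -> 128 -> 10 outputs, as in the setup above.
n_params = count_parameters([784, 256, 128, 10])

# Mini-batches per epoch: the last batch is partial, so round up.
batches_per_epoch = math.ceil(50_000 / 128)

print(n_params)           # 235146
print(batches_per_epoch)  # 391
```

Almost all of the parameters (200,960 of them) sit in the first layer, because it connects every one of the 784 pixels to every hidden neuron.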

Weight Initialization

Weights cannot all start at zero. If every weight in a layer started at the same value, every neuron in that layer would compute the same output and receive the same gradient, so the neurons would remain identical forever (the symmetry problem). Instead, weights are initialized randomly, typically from a distribution scaled to the number of inputs to the layer. A common choice for ReLU networks is He initialization: sample from a normal distribution with mean 0 and standard deviation sqrt(2 / n_in), where n_in is the number of inputs to the layer. Biases are initialized to zero.
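He initialization as described above can be sketched with only the standard library. The function name `he_init` and the row-per-output-neuron layout of `W` are assumptions made for this illustration; deep learning frameworks ship their own initializers.

```python
import math
import random

def he_init(n_in, n_out, rng=random):
    """He (Kaiming) normal initialization for a dense ReLU layer.

    Weights ~ Normal(0, sqrt(2 / n_in)); biases start at zero.
    Returns W as a list of n_out rows, each with n_in entries.
    """
    std = math.sqrt(2.0 / n_in)
    W = [[rng.gauss(0.0, std) for _ in range(n_in)] for _ in range(n_out)]
    b = [0.0] * n_out
    return W, b

rng = random.Random(0)  # fixed seed so the draw is reproducible
W1, b1 = he_init(784, 256, rng)  # first hidden layer of the network above
```

Drawing each weight independently breaks the symmetry between neurons, while the sqrt(2 / n_in) scale keeps the variance of ReLU activations roughly constant from layer to layer. For the first layer the standard deviation is sqrt(2 / 784), about 0.05.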

The training run, epoch by epoch (selected epochs):
- Epoch 1: train_loss=2.31, val_loss=2.29, val_accuracy=15%. Near random: the cross-entropy of a classifier that predicts a uniform distribution over 10 classes is -log(0.1) ≈ 2.30, so this is expected.
- Epoch 3: train_loss=1.45, val_loss=1.48, val_accuracy=58%. Loss is dropping fast; the gradient signal is strong and the learning rate is well chosen.
- Epoch 10: train_loss=0.38, val_loss=0.41, val_accuracy=88%. Train and val loss track closely: no overfitting yet, and the regularization appears to be keeping the gap small.
- Epoch 20: train_loss=0.24, val_loss=0.28, val_accuracy=92%. Continued improvement, with a small, healthy gap between train and val loss.
- Epoch 25: train_loss=0.19, val_loss=0.27, val_accuracy=93%. Val loss improvement has nearly stalled while train loss keeps dropping; the gap is growing slightly.
- Epoch 30: train_loss=0.16, val_loss=0.30, val_accuracy=93%. Val loss has risen slightly since epoch 25. Assuming no epoch in between improved on 0.27, early stopping with patience=5 would trigger here and restore the epoch-25 weights.
Final test accuracy: 92.8%, respectable for a fully connected network on MNIST without convolution.
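The patience rule from the epoch-30 note can be sketched in Python, along with the uniform-classifier loss quoted for epoch 1. This is a minimal illustration, not a framework API: the function name `early_stopping` is hypothetical, and because only selected epochs are logged, it assumes the unlisted epochs did not improve on the best validation loss.

```python
import math

# Selected validation losses from the run above (unlisted epochs are unknown).
val_loss = {1: 2.29, 3: 1.48, 10: 0.41, 20: 0.28, 25: 0.27, 30: 0.30}

def early_stopping(history, patience):
    """Return (stop_epoch, best_epoch) for patience-based early stopping.

    Training stops once `patience` epochs have passed since the last
    improvement in validation loss; the best epoch's weights are kept.
    Assumes epochs missing from `history` did not improve on the best loss.
    Returns stop_epoch=None if the rule never triggers.
    """
    best_epoch, best_loss = None, math.inf
    for epoch in sorted(history):
        if history[epoch] < best_loss:
            best_epoch, best_loss = epoch, history[epoch]
        elif epoch - best_epoch >= patience:
            return epoch, best_epoch
    return None, best_epoch

# Cross-entropy of a classifier that assigns probability 1/10 to every class.
uniform_loss = -math.log(1 / 10)

print(early_stopping(val_loss, patience=5))  # (30, 25)
```

With patience=5 the rule fires at epoch 30 (five epochs after the best value of 0.27) and points back to epoch 25 as the checkpoint to keep. `uniform_loss` evaluates to about 2.3026, the reference value for the epoch-1 loss.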

Diagnosing and Improving

What would we do next to improve this model?

Observation 1: val_accuracy plateaued near 93%. On MNIST, state-of-the-art accuracy is above 99%, but reaching it requires convolutional networks that exploit the spatial structure of images. A fully connected network on flattened pixels has no built-in notion of which pixels are neighbors, so it cannot exploit that structure. This is largely a structural limit, not a hyperparameter issue.

Observation 2: The train-val loss gap grew slightly in epochs 25-30. To reduce it, we could: (a) increase the dropout rate from 0.3 to 0.4-0.5, (b) increase the weight decay lambda, or (c) use early stopping at epoch 25.

Observation 3: What if training had gone wrong? Common failure modes and their signatures:
- Loss not decreasing at all: check the learning rate (too low?), check the weight initialization, and check the loss function and data loading for bugs.
- Loss decreasing, then suddenly spiking: the learning rate is too high, causing an unstable update. Reduce eta or use a learning-rate schedule.
- Validation accuracy stuck near chance level: possible label bug (labels shuffled out of alignment with their images?), data normalization issue (pixel values in the wrong range?), or model capacity too low.
- Training loss near zero, validation accuracy near chance: extreme overfitting. The model has memorized the individual training examples rather than learning from their content. Check for data leakage and increase regularization substantially.
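Observation 3 suggests a learning-rate schedule when the loss spikes. One common choice is step decay, sketched below. The function name and the `drop` and `every` values are illustrative assumptions, not settings from this run; exponential or cosine schedules are equally reasonable alternatives.

```python
def step_decay(eta0, epoch, drop=0.5, every=10):
    """Step-decay learning-rate schedule (epochs counted from 1).

    The rate is multiplied by `drop` once every `every` epochs.
    drop=0.5 and every=10 are illustrative defaults, not tuned values.
    """
    return eta0 * drop ** ((epoch - 1) // every)

# Starting from eta = 0.01 as in the setup above.
for epoch in (1, 10, 11, 20, 21, 30):
    print(epoch, step_decay(0.01, epoch))
```

With these settings the rate is 0.01 for epochs 1-10, 0.005 for epochs 11-20, and 0.0025 for epochs 21-30. Smaller late-stage steps make large, destabilizing updates less likely once the model is near a good solution.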

Flashcards

Baseline First, Improve Second

Always establish a baseline before trying improvements. For classification, a trivial baseline is the majority-class classifier, which predicts the most common class every time. Compute its accuracy. If your network barely beats it, something is probably wrong with the setup (data, labels, or training loop), not the architecture. A result must beat the baseline by a substantial margin before any other claims are credible.
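A majority-class baseline takes only a few lines. In this sketch (the function name and toy labels are made up for illustration, not taken from MNIST or any real dataset), the majority class is chosen from the training labels and the accuracy is measured on held-out labels, mirroring how the network itself is evaluated.

```python
from collections import Counter

def majority_class_baseline(train_labels, eval_labels):
    """Accuracy on eval_labels of always predicting the most common
    class in train_labels."""
    if not train_labels or not eval_labels:
        raise ValueError("label lists must be non-empty")
    majority, _ = Counter(train_labels).most_common(1)[0]
    hits = sum(1 for y in eval_labels if y == majority)
    return hits / len(eval_labels)

# Toy labels for illustration only.
toy_train = [3, 1, 3, 7, 3, 1, 0, 3]
toy_eval = [3, 3, 1, 7, 3]
print(majority_class_baseline(toy_train, toy_eval))  # 0.6
```

For a roughly balanced 10-class problem like MNIST this baseline is close to 10%, so the network above clears it by a wide margin.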

In epoch 1, the training loss is 2.31, close to the expected random-classifier loss of 2.30. What does this confirm about the training setup?

Validation accuracy is 93% but a research paper reports 99% on the same dataset. Why would simply adding more fully connected layers likely not close this gap?

Training Run Critique

  1. Step 1: You are given the following training log for a binary classifier (output: spam/not-spam email):
     Epoch 5: train_loss=0.68, val_loss=0.69, train_acc=57%, val_acc=56%
     Epoch 10: train_loss=0.61, val_loss=0.62, train_acc=65%, val_acc=64%
     Epoch 15: train_loss=0.55, val_loss=0.58, train_acc=71%, val_acc=68%
     Epoch 20: train_loss=0.42, val_loss=0.61, train_acc=80%, val_acc=70%
     Epoch 25: train_loss=0.28, val_loss=0.74, train_acc=89%, val_acc=67%
  2. Step 2: Identify the epoch range where overfitting clearly begins. Cite specific numbers from the log.
  3. Step 3: If early stopping with patience=5 was used (patience = the number of epochs without improvement in val_loss before training stops), which epoch's checkpoint would you keep?
  4. Step 4: List two hyperparameter changes you would try in the next run to improve val_acc without sacrificing test generalization. For each change, state what problem it addresses.
  5. Step 5: The baseline accuracy for this spam dataset (always predict 'not spam') is 60%. Is epoch 15's val_acc of 68% a meaningful improvement? By how many points?