Project Goals
3 min
- Load MNIST and understand its shape (28×28 grayscale images).
- Preprocess: normalise pixels, flatten or keep 2-D.
- Train a dense net (and optionally a small CNN) to classify digits.
- Inspect the confusion matrix and the misclassified digits.
Warm-Up · See the Data
5 min
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt

(Xtr, ytr), (Xte, yte) = mnist.load_data()
print(Xtr.shape, Xte.shape)                  # (60000, 28, 28) (10000, 28, 28)
print("pixel range:", Xtr.min(), Xtr.max())  # 0 .. 255

plt.imshow(Xtr[0], cmap="gray")
plt.title(f"label: {ytr[0]}")
plt.savefig("digit.png", dpi=120)
Each image is a 28×28 array of pixel values from 0 to 255 (Lesson 5!). The net's job: map those 784 pixels (28 × 28) to one of 10 digits. It's the same compile/fit/evaluate recipe as before; only the input is bigger and the data is images.
Plan · Prep & Architecture
14 min
Preprocess
# scale pixels to 0-1 (nets like small inputs)
Xtr = Xtr.astype("float32") / 255.0
Xte = Xte.astype("float32") / 255.0
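Here is a quick way to check the scaling without downloading anything: a minimal NumPy sketch on a random batch shaped like MNIST (the random pixels are a stand-in, not real digits). The last step reshapes each 28×28 image into a 784-long vector, which is the same reshaping Option A's Flatten layer does inside the model.

```python
import numpy as np

# Synthetic stand-in for 5 MNIST images: integer pixels in 0..255, stored as uint8
rng = np.random.default_rng(0)
X = rng.integers(0, 256, size=(5, 28, 28)).astype(np.uint8)

# Same preprocessing as above: cast to float32, then scale into [0, 1]
X_scaled = X.astype("float32") / 255.0
print(X_scaled.dtype, X_scaled.min(), X_scaled.max())

# What a Flatten layer does to each image: 28x28 -> 784
X_flat = X_scaled.reshape(len(X_scaled), -1)
print(X_flat.shape)  # (5, 784)
```

Casting to float32 before dividing matters: the uint8 array can only hold whole numbers, so the scaled values need a float type.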
Option A — a dense net (flatten the image)
from tensorflow import keras
from tensorflow.keras import layers

dense = keras.Sequential([
    layers.Input(shape=(28, 28)),
    layers.Flatten(),                        # 28×28 → 784
    layers.Dense(128, activation="relu"),
    layers.Dropout(0.3),
    layers.Dense(10, activation="softmax"),  # 10 digits
])
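It helps to know how big Option A is. A Dense layer with n_in inputs and n_out units has n_in × n_out weights plus n_out biases, while Flatten and Dropout have no trainable parameters. The plain-Python arithmetic below (the helper name `dense_params` is just for illustration) should agree with the total that `dense.summary()` reports.

```python
def dense_params(n_in: int, n_out: int) -> int:
    """Trainable parameters of a Dense layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out

hidden = dense_params(28 * 28, 128)  # Flatten feeds 784 inputs into Dense(128)
output = dense_params(128, 10)       # one output unit per digit
print(hidden, output, hidden + output)  # 100480 1290 101770
```

Almost all of the weights sit in the first Dense layer, because every one of the 784 pixels connects to every hidden unit.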
Option B — a small CNN (keeps 2-D structure)
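cnn = keras.Sequential([
    layers.Input(shape=(28, 28, 1)),          # height, width, 1 grayscale channel
    layers.Conv2D(32, 3, activation="relu"),  # 32 filters, each 3×3
    layers.MaxPooling2D(),                    # roughly halve height and width
    layers.Conv2D(64, 3, activation="relu"),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(64, activation="relu"),
    layers.Dropout(0.3),
    layers.Dense(10, activation="softmax"),
])

# Conv2D expects a channel axis, so give the CNN 4-D data:
Xtr_cnn = Xtr[..., None]   # (60000, 28, 28) -> (60000, 28, 28, 1)
Xte_cnn = Xte[..., None]   # (10000, 28, 28) -> (10000, 28, 28, 1)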
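A Conv2D layer slides small learned filters (here 3×3) across the image to detect local patterns such as edges and curves, so neighbouring pixels stay together instead of being flattened into one long vector. MaxPooling2D then downsamples each feature map by keeping the largest value in every 2×2 window. That makes CNNs much better suited to images than dense nets; we'll lean on pretrained CNNs in Lesson 29.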
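The same bookkeeping for Option B shows how the image shrinks as it moves through the layers. This sketch assumes the Keras defaults the CNN above relies on: Conv2D uses stride 1 and "valid" padding (a 3×3 filter trims one pixel from each edge), and MaxPooling2D uses a 2×2 window with stride 2, rounding down. The helper names are hypothetical.

```python
def conv_valid(size: int, kernel: int = 3) -> int:
    """Output height/width of a stride-1 'valid' convolution."""
    return size - kernel + 1

def pool(size: int, window: int = 2) -> int:
    """Output height/width of non-overlapping max pooling (rounds down)."""
    return size // window

def conv_params(kernel: int, c_in: int, filters: int) -> int:
    """Each filter has kernel*kernel*c_in weights plus one bias."""
    return kernel * kernel * c_in * filters + filters

s = 28
s = conv_valid(s)   # 26 x 26 x 32
s = pool(s)         # 13 x 13 x 32
s = conv_valid(s)   # 11 x 11 x 64
s = pool(s)         # 5 x 5 x 64 (11 // 2 rounds down)
flat = s * s * 64   # Flatten -> 1600 values

total = (
    conv_params(3, 1, 32)       # first Conv2D
    + conv_params(3, 32, 64)    # second Conv2D
    + (flat * 64 + 64)          # Dense(64)
    + (64 * 10 + 10)            # Dense(10)
)
print(s, flat, total)  # 5 1600 121930
```

Most of these parameters live in the Dense(64) layer after Flatten (102,464 of 121,930); the two convolutions together need fewer than 19,000, because each filter is reused at every position in the image.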
Compile (10-class softmax)
# model is either dense or cnn; both end in a 10-unit softmax
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",  # labels are plain integers 0-9
    metrics=["accuracy"],
)
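The "sparse" in sparse_categorical_crossentropy means the labels are plain integers 0-9 (which is how mnist.load_data() returns them) rather than one-hot vectors. For each sample the loss is −log of the probability the softmax gives the true class. A small NumPy sketch with made-up probabilities (not real model output) shows that both label formats give the same number:

```python
import numpy as np

# Made-up softmax outputs for 3 samples over 10 classes (each row sums to 1)
probs = np.full((3, 10), 0.02)
probs[0, 7] = 0.82   # confident and correct
probs[1, 4] = 0.50   # unsure between 4 and 9
probs[1, 9] = 0.34
probs[2, 3] = 0.82   # confident but wrong (true label is 5)
y = np.array([7, 4, 5])  # integer labels, as MNIST provides them

# Sparse form: pick out the true-class probability directly
sparse_ce = -np.log(probs[np.arange(len(y)), y])

# Categorical form: one-hot labels, then sum over classes
one_hot = np.eye(10)[y]
cat_ce = -(one_hot * np.log(probs)).sum(axis=1)

print(sparse_ce.round(3))
print(np.allclose(sparse_ce, cat_ce), sparse_ce.mean())
```

The confident correct sample costs about 0.2, the unsure one about 0.7, and the confidently wrong one about 3.9, so training is pushed hardest by confident mistakes.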
Build · mnist.py
12 min
# mnist.py: train & inspect a digit recogniser
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

(Xtr, ytr), (Xte, yte) = mnist.load_data()
Xtr = Xtr.astype("float32") / 255.0
Xte = Xte.astype("float32") / 255.0

model = keras.Sequential([
    layers.Input(shape=(28, 28)),
    layers.Flatten(),
    layers.Dense(128, activation="relu"),
    layers.Dropout(0.3),
    layers.Dense(10, activation="softmax"),
])
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
model.fit(Xtr, ytr, validation_split=0.1, epochs=10, batch_size=128, verbose=2)

loss, acc = model.evaluate(Xte, yte, verbose=0)
print(f"\ntest accuracy: {acc:.2%}")

# confusion matrix
preds = model.predict(Xte).argmax(axis=1)
ConfusionMatrixDisplay(confusion_matrix(yte, preds)).plot()
plt.savefig("mnist_confusion.png", dpi=130)

# show some mistakes
wrong = np.where(preds != yte)[0][:8]
fig, ax = plt.subplots(2, 4, figsize=(9, 5))
for a, i in zip(ax.flat, wrong):
    a.imshow(Xte[i], cmap="gray")
    a.set_title(f"true {yte[i]}, pred {preds[i]}")
    a.axis("off")
fig.tight_layout()
fig.savefig("mnist_mistakes.png", dpi=130)
Sample output
test accuracy: 97.8%
Read the diff
A plain dense net trains quickly and reaches about 98% test accuracy in 10 epochs. The confusion matrix shows which digits get mixed up; 4↔9, 3↔5 and 7↔1 are the classic pairs. Looking at the actual misclassified images is humbling: many are genuinely ambiguous scrawls that a human would also pause on. A CNN typically pushes accuracy to about 99%.
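To find the most-confused pair programmatically, it helps to recall the layout: in scikit-learn's confusion_matrix, row i is the true digit and column j the predicted digit. This minimal NumPy sketch uses tiny hand-made label arrays in place of yte and preds (they are illustrative, not real results). It builds the matrix the same way, zeroes the diagonal (the correct predictions), and adds the matrix to its transpose so that 4→9 and 9→4 errors count together. The helper names are hypothetical.

```python
import numpy as np

def confusion(y_true, y_pred, n_classes=10):
    """Counts laid out like sklearn's confusion_matrix: row = true, column = predicted."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

def most_confused_pair(cm):
    """Return (a, b, count) for the unordered digit pair with the most mix-ups."""
    errors = cm.copy()
    np.fill_diagonal(errors, 0)   # ignore correct predictions
    sym = errors + errors.T       # a->b and b->a count together
    a, b = np.unravel_index(sym.argmax(), sym.shape)
    return int(min(a, b)), int(max(a, b)), int(sym[a, b])

# Hypothetical labels standing in for yte / preds
y_true = np.array([4, 4, 9, 9, 3, 5, 7, 1, 0, 0])
y_pred = np.array([9, 4, 4, 9, 5, 5, 7, 7, 0, 0])

cm = confusion(y_true, y_pred)
print(most_confused_pair(cm))  # (4, 9, 2)
```

With the real data you could pass confusion_matrix(yte, preds) straight into most_confused_pair, then use np.where((yte == a) & (preds == b))[0] to pull example images for the pair.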
Extensions
13 min
Train both the dense net and the CNN. Compare test accuracy and training time.
From the confusion matrix, find the pair of digits most often mistaken for each other. Show examples.
Draw a digit in any paint app, save as 28×28 grayscale, invert if needed (MNIST is white-on-black), and predict it. Does your model read your handwriting?
Hint
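import numpy as np
from PIL import Image

img = Image.open("my_digit.png").convert("L").resize((28, 28))  # grayscale, 28×28
arr = 255 - np.array(img)              # invert if your bg is white
arr = arr.astype("float32") / 255.0    # same scaling as the training data
print("prediction:", model.predict(arr[None, ...]).argmax())  # dense net: input (1, 28, 28)
# for the CNN, also add a channel axis: arr[None, ..., None]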
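If your drawing is off-centre or fills the whole canvas, the model may struggle even when the digit looks clear to you. According to the MNIST dataset description, the original digits were size-normalised to fit a 20×20 box and then centred in the 28×28 frame by their centre of mass. The sketch below (NumPy only; the helper name `centre_by_mass` is hypothetical) does just the centring step on an already inverted, 0-1 scaled 28×28 array such as `arr` from the hint. It shifts by whole pixels and fills the exposed border with zeros.

```python
import numpy as np

def centre_by_mass(img: np.ndarray) -> np.ndarray:
    """Shift a 2-D image so its intensity centre of mass sits at the frame centre."""
    h, w = img.shape
    total = img.sum()
    if total == 0:
        return img.copy()             # blank image: nothing to centre
    rows, cols = np.indices(img.shape)
    cy = (rows * img).sum() / total   # centre of mass, row coordinate
    cx = (cols * img).sum() / total   # centre of mass, column coordinate
    dy = int(round((h - 1) / 2 - cy))
    dx = int(round((w - 1) / 2 - cx))

    out = np.zeros_like(img)
    # Copy the overlapping region; pixels pushed out of the frame are dropped
    src_r = slice(max(0, -dy), min(h, h - dy))
    dst_r = slice(max(0, dy), min(h, h + dy))
    src_c = slice(max(0, -dx), min(w, w - dx))
    dst_c = slice(max(0, dx), min(w, w + dx))
    out[dst_r, dst_c] = img[src_r, src_c]
    return out

# Demo: a 4x4 blob near the top-left corner of a 28x28 frame
demo = np.zeros((28, 28), dtype="float32")
demo[2:6, 3:7] = 1.0
centred = centre_by_mass(demo)
print(np.argwhere(centred).min(axis=0), np.argwhere(centred).max(axis=0))
```

You would then predict on centre_by_mass(arr)[None, ...] instead of arr[None, ...]. Size-normalising to 20×20 is a separate step this sketch leaves out.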
Stretch · Push Past 99%
8 min
With the CNN, add data augmentation (small rotations/shifts), a second dropout layer, and train longer with early stopping. Aim for 99%+ test accuracy. Document what moved the needle.
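Keras has built-in pieces for this stretch (for example the RandomTranslation and RandomRotation layers and the EarlyStopping callback). To show what those steps actually do, here is a framework-free NumPy sketch of two of the ideas under simple assumptions: random whole-pixel shifts of up to 2 pixels with zero fill (rotations are left out), and a patience rule that stops once validation loss has not improved for `patience` epochs and reports the best epoch. The function names and the loss values are illustrative, not taken from the lesson.

```python
import numpy as np

def random_shift(batch: np.ndarray, max_shift: int = 2, rng=None) -> np.ndarray:
    """Shift each 2-D image by a random (dy, dx) in [-max_shift, max_shift], zero-filling."""
    rng = np.random.default_rng() if rng is None else rng
    n, h, w = batch.shape
    out = np.zeros_like(batch)
    for k in range(n):
        dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
        padded = np.pad(batch[k], max_shift)          # zero border on every side
        top, left = max_shift - dy, max_shift - dx    # crop window moves opposite to the shift
        out[k] = padded[top:top + h, left:left + w]
    return out

def early_stop(val_losses, patience=3):
    """Return (stop_epoch, best_epoch), 0-based, for a simple patience rule."""
    best, best_epoch = float("inf"), 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return epoch, best_epoch
    return len(val_losses) - 1, best_epoch

rng = np.random.default_rng(0)
digits = rng.random((4, 28, 28)).astype("float32")   # stand-in images
aug = random_shift(digits, rng=rng)

losses = [0.30, 0.20, 0.15, 0.16, 0.17, 0.18, 0.14]  # hypothetical val losses
print(aug.shape, early_stop(losses, patience=3))     # (4, 28, 28) (5, 2)
```

Note how the rule stops at epoch 5 and never sees the better loss at epoch 6; a larger patience trades training time for that chance. Keeping the best epoch's weights (Keras's EarlyStopping has a restore_best_weights option for this) avoids ending on a worse model.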
Recap
3 min
MNIST images are 28×28 pixel arrays (0-255 → normalise to 0-1). A dense net hits ~98%; a CNN, which respects 2-D structure, hits ~99%. The confusion matrix + viewing mistakes is the real evaluation; accuracy alone hides which digits confuse the model. Next: classify clothing with the same toolkit.
Homework
4 min
Train an MNIST classifier (dense or CNN), report test accuracy, save the confusion matrix and a grid of 8 mistakes. Write a paragraph: which digits confuse your model and why those make visual sense.
The deliverable is mnist.py + the two figures + a paragraph. 4↔9 and 3↔5 confusions are expected — note how the misclassified images are often genuinely ambiguous.