Solving the Fashion MNIST with a simple neural network

Can computers recognize shirts from sandals?

Anna Shi
5 min read · Dec 25, 2020

This is part one of a two-part series about the Fashion MNIST dataset. I’ll cover how we can use a basic neural network here, then move on to convolutional neural networks in the next one.

Remember the opening scene of Clueless? It’s an insider’s view into the normal life of a teenage girl: Cher gets up, brushes her teeth, and picks out her school clothes. Well, technically, her closet picks them out.

Unfortunately for us fashionably-challenged, this magical device hasn’t quite hit the market yet. That doesn’t mean it’s not possible, though. With AI’s exponential growth trajectory, someone could probably put one together. In fact, I’ll even help get the ball rolling.

What’s the first step to building a good outfit? Knowing what clothes you have!

Fashion MNIST dataset

With that in mind, I’m going to turn our attention to a well-known machine learning project. AI enthusiasts probably know of the MNIST dataset, a collection of handwritten digits from 0 to 9. It’s the “Hello World” of machine learning, serving as a great first neural network project.

Similarly, the Fashion-MNIST dataset is a collection of clothing items grouped into ten different categories. Modeled after the original MNIST, it contains 70,000 greyscale photos of 28 by 28 pixels. Each article of clothing belongs to one of the following ten groups: T-shirt/top, trouser, pullover, dress, coat, sandal, shirt, sneaker, bag, or ankle boot.

The challenge is to train a computer to recognize different types of clothing from photos.

The first 20 photos and labels of the Fashion MNIST training set.

There are several ways to do this, but the two methods I’ll be using are:

  1. Dense neural networks
  2. Convolutional neural networks

In this article, I’ll talk about the first. Dense neural networks are standard neural networks where the output of each node depends on every single node in the previous layer. All the connections between the layers create an interconnected web. They’re great for classification problems such as the Fashion MNIST, where we need to predict how to categorize each item.

See how each node is connected to every node in the adjacent layers?

This neural network will look at an image as a whole and learn to recognize certain features in specific areas. If it were deciding whether a picture is a shirt, for example, it might look for something like a collar at the top and sleeves on the sides.

Coding the neural network

We’ll start by importing all the libraries we’ll need.

# select TensorFlow 2.x (this magic command only works in Google Colab)
%tensorflow_version 2.x
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

Then, we import the Fashion MNIST dataset. It’s already included in Keras’ datasets. Conveniently, the data’s already split into a training set and a testing set containing 60,000 and 10,000 photos and labels, respectively.

# load the fashion mnist dataset from keras
fashion_mnist = keras.datasets.fashion_mnist
(train_X, train_y),(test_X,test_y) = fashion_mnist.load_data()
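If you want to confirm the split yourself, a quick check of the array shapes does the trick:

# sanity check: 60,000 training and 10,000 testing images, each 28 by 28 pixels
print(train_X.shape, train_y.shape)   # (60000, 28, 28) (60000,)
print(test_X.shape, test_y.shape)     # (10000, 28, 28) (10000,)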

Before we can start building the model, we need to do some data preprocessing. Each pixel in the 28 by 28 photos stores a shade of grey as a number between 0 and 255.

Neural networks typically train more smoothly when inputs fall within a small, consistent range, so that no feature gets distorted and weighed more heavily than another simply because of its scale. Here, every input is already on the same 0 to 255 scale, but we'll still normalize the values down to between 0 and 1.

# data preprocessing, scaling values from 0-255 to 0-1
train_X = train_X/255.0
test_X = test_X/255.0
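As a quick aside (this isn't part of the original walkthrough), you can reproduce a grid of sample photos like the one above with a few lines of matplotlib. The class_names list here is something I'm defining for display purposes, ordered to match the dataset's integer labels 0 through 9:

# human-readable names for the integer labels 0-9, in label order
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# plot the first 20 training images with their labels
plt.figure(figsize=(10, 5))
for i in range(20):
    plt.subplot(4, 5, i + 1)
    plt.imshow(train_X[i], cmap='gray')
    plt.title(class_names[train_y[i]], fontsize=8)
    plt.axis('off')
plt.show()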

Now we can begin building the layers of the network. Since this is a standard neural network and not a CNN or RNN, we can just use a sequential model, which passes information from one layer to the next in order. We'll flatten each 28 by 28 photo into a vector of 784 values, then create one hidden layer with 128 nodes and an output layer with 10 nodes, one for each type of clothing.

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),   # 28x28 image -> vector of 784 values
    tf.keras.layers.Dense(128, activation='relu'),   # hidden layer
    tf.keras.layers.Dense(10, activation='softmax')  # one output node per clothing class
])
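If you're curious about how big this network actually is, model.summary() will list each layer along with its parameter count; the arithmetic in the comments is just my own back-of-the-envelope check:

# Flatten:   0 parameters (it just reshapes 28x28 into a vector of 784)
# Dense 128: 784*128 + 128 = 100,480 parameters (weights + biases)
# Dense 10:  128*10 + 10 = 1,290 parameters
model.summary()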

The model starts off with random weights between the layers, so at first we're going to get some wonky predictions. When we compile the model, we pick a loss function that measures exactly how wrong those predictions are and an optimizer that updates the model's weights accordingly. The weight-updating process is also known as backpropagation, which you can read more about here.

# from_logits=False because the output layer already applies softmax
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])

From there, we can finally start training the model. We can choose how many times to run through the full training set with the number of epochs.

model.fit(train_X, train_y, epochs=12)

Although you may think the more epochs the better, we also have to be careful about overfitting. If you train your model on the same data too many times, it can become too tailored to the training data. This is why it's important to separate the data into training and testing sets, so that you can measure how well your model performs on unfamiliar data.
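One way to keep an eye on overfitting, beyond the plain fit call above, is to hold out a slice of the training data as a validation set. This is just a sketch of how that could look with Keras, not something from the original code:

# hold out 10% of the training data; if training accuracy keeps climbing while
# validation accuracy stalls or drops, the model is starting to overfit
history = model.fit(train_X, train_y, epochs=12, validation_split=0.1)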

loss, acc = model.evaluate(test_X, test_y, verbose=1)
print('\nTest accuracy:', acc)

That's it! You can see the full code here. If you run through it, you can even test the accuracy yourself. Choose a random photo by typing in a number, and then see if the model classified it correctly.
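I won't duplicate the full notebook here, but a rough sketch of that check might look something like this (reusing the class_names list from earlier; the index 42 is just an arbitrary pick):

i = 42  # any number from 0 to 9999 works
prediction = model.predict(test_X[i:i+1])        # shape (1, 10): one probability per class
predicted_label = class_names[np.argmax(prediction)]
true_label = class_names[test_y[i]]
print('Predicted:', predicted_label, '| Actual:', true_label)

plt.imshow(test_X[i], cmap='gray')
plt.title(predicted_label)
plt.show()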

When I run this, I usually get an accuracy in the mid-eighties. Using a convolutional neural network (CNN) could help us to achieve a higher level of accuracy (>90%), but I'll get into that in the next article.

Even without using CNNs, we achieved a respectable accuracy. That's a great start to building an assistant closet. Maybe one day, we can all be fashion icons like Cher Horowitz. With the help of AI, anything's possible.

This code is based on some great tutorials from TensorFlow and FreeCodeCamp. Try them out if you're interested in a beginner's machine learning project. I'd be happy to answer any questions.

If you liked this article, follow me on Medium to see more in the future! Reach out if you have any questions through my LinkedIn or email.

