MNIST Addition

The task in this notebook closely resembles the classic MNIST digit-classification task. However, instead of providing a label for each single digit, we train on pairs of images labeled only with the sum of their two digits. The task was first introduced in Manhaeve et al. (2018).

import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
torch.manual_seed(1234)

We begin by defining our model, taken from the PyTorch MNIST tutorial.

class MNIST_Net(nn.Module):
    def __init__(self, N=10):
        super(MNIST_Net, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1,  6, 5), # 1 28 28 -> 6 24 24
            nn.MaxPool2d(2, 2), # 6 24 24 -> 6 12 12
            nn.ReLU(True),
            nn.Conv2d(6, 16, 5), # 6 12 12 -> 16 8 8
            nn.MaxPool2d(2, 2), # 16 8 8 -> 16 4 4
            nn.ReLU(True)
        )
        self.classifier =  nn.Sequential(
            nn.Linear(16 * 4 * 4, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, N)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(-1, 16 * 4 * 4)
        x = self.classifier(x)
        return x
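The shape comments in the encoder, and the 16 * 4 * 4 input size of the classifier, can be checked with a little arithmetic. The sketch below is plain Python and assumes 28x28 inputs, no padding, and the default stride of 1 for the convolutions; conv_out is a hypothetical helper introduced here for illustration, not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial size after a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                    # MNIST images are 28x28 (assumed input size)
size = conv_out(size, 5)     # Conv2d(1, 6, 5):   28 -> 24
size = conv_out(size, 2, 2)  # MaxPool2d(2, 2):   24 -> 12
size = conv_out(size, 5)     # Conv2d(6, 16, 5):  12 -> 8
size = conv_out(size, 2, 2)  # MaxPool2d(2, 2):    8 -> 4

flat_features = 16 * size * size  # 16 channels of 4x4 feature maps
print(flat_features)  # 256, matching nn.Linear(16 * 4 * 4, 120)
```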

We load the standard MNIST image data, normalizing pixel values from [0, 1] to [-1, 1].

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_train_data = torchvision.datasets.MNIST(root='./MNIST', train=True, download=True,transform=transform)
mnist_test_data = torchvision.datasets.MNIST(root='./MNIST', train=False, download=True,transform=transform)

test_kwargs = {'batch_size': 256}
test_loader = torch.utils.data.DataLoader(mnist_test_data, **test_kwargs)

We load the MNIST addition dataset, generated by pairing random MNIST digits and labeling each pair with the sum of its digits. Each datum has the form (idx1, idx2, summation), where idx1 is the index of the first image, idx2 is the index of the second image, and summation is the sum of their ground-truth labels.

# ---------- train_data ----------
with open('train_data.txt') as f:
    train_data = f.readlines()
    
# Strip new lines
train_data = [d.strip() for d in train_data]

# Convert strings (e.g. "(datum_i, datum_j, sum)") to tuples of ints
train_data = [tuple(int(e) for e in d.strip("()").split(",")) for d in train_data]

# ---------- test data ----------
with open('test_data.txt') as f:
    test_data = f.readlines()
    
# Strip new lines
test_data = [d.strip() for d in test_data]

# Convert strings (e.g. "(datum_i, datum_j, sum)") to tuples of ints
test_data = [tuple(int(e) for e in d.strip("()").split(",")) for d in test_data]

# Tensorize, keeping only the first 9000 training pairs
train_data = torch.tensor(train_data)[:9000]
test_data = torch.tensor(test_data)
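To illustrate the (idx1, idx2, summation) format, here is a minimal sketch of how such a file could be produced and read back. It is an assumption about the generation procedure, not the script that created train_data.txt; make_addition_pairs, parse_line, and toy_labels are hypothetical names used only for this example.

```python
import random

def make_addition_pairs(labels, num_pairs, seed=0):
    """Pair random indices and label each pair with the sum of its digits."""
    rng = random.Random(seed)
    pairs = []
    for _ in range(num_pairs):
        i = rng.randrange(len(labels))
        j = rng.randrange(len(labels))
        pairs.append((i, j, labels[i] + labels[j]))
    return pairs

def parse_line(line):
    """Turn a string such as "(3, 7, 10)" into a tuple of ints."""
    return tuple(int(e) for e in line.strip().strip("()").split(","))

toy_labels = [5, 0, 4, 1, 9]              # hypothetical digit labels
pairs = make_addition_pairs(toy_labels, num_pairs=3)
lines = [str(p) for p in pairs]           # one "(idx1, idx2, summation)" per line
round_trip = [parse_line(line) for line in lines]
print(round_trip == pairs)
```

Note that the individual digit labels never appear in the stored tuples; only the image indices and the sum do.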

We create our model and our optimizer.

model = MNIST_Net()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

Even though we train on pairs of images and their sums, we test in the classic setting, i.e., predicting the label of a single digit.

def test():
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)  # index of the largest logit, i.e. the predicted digit
            correct += pred.eq(target.view_as(pred)).sum().item()


    print('Test set: Accuracy: {}/{} ({:.0f}%)\n'.format(correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In lieu of the traditional cross-entropy loss, we require that the sum of the predicted labels match the ground truth, and we enforce this as a constraint at training time. To do so, we import the constraint module. The last line of the cell below wraps enforce_sum with constraint, declaring it as a constraint to be enforced at training time. Note that our constraint function, enforce_sum, is a plain Python function and does not use any foreign syntax.

# ---------- Set up the constraints ----------
import sys
sys.path.append("../")

from pytorch_constraints.constraint import constraint

def enforce_sum(img1, img2, kwargs):
    return img1 + img2 == kwargs['summation']


enforce_sum_constraint = constraint(enforce_sum)

Finally, we proceed to our usual training loop, where we minimize the constraint loss during training: the closs value returned by the constraint takes the place of a supervised loss.

from tqdm import tqdm

NUM_EPOCHS = 1

for epoch in range(NUM_EPOCHS):

    # train
    for i, batch in enumerate(tqdm(train_data)):
        model.train()  # test() switches to eval mode, so switch back each step
        optimizer.zero_grad()
        idx1, idx2, summation = batch
        X1 = mnist_train_data[idx1][0].unsqueeze(0)
        X2 = mnist_train_data[idx2][0].unsqueeze(0)

        output1 = model(X1)
        output2 = model(X2)

        # Predicted digits (not used by the loss below)
        pred1 = output1.argmax(dim=1, keepdim=False)
        pred2 = output2.argmax(dim=1, keepdim=False)

        closs = enforce_sum_constraint(output1, output2, summation=summation)


        closs.backward()
        optimizer.step()
        
        if i % 1000 == 0 and i != 0:
            test()
        
    test()

 11%|█         | 1003/9000 [00:54<21:14,  6.27it/s]
Test set: Accuracy: 6282/10000 (63%)

 22%|██▏       | 2004/9000 [01:50<25:17,  4.61it/s]
Test set: Accuracy: 9333/10000 (93%)

 33%|███▎      | 3003/9000 [02:46<20:46,  4.81it/s]
Test set: Accuracy: 9581/10000 (96%)

 44%|████▍     | 4003/9000 [03:46<15:12,  5.47it/s]
Test set: Accuracy: 9630/10000 (96%)

 56%|█████▌    | 5002/9000 [04:46<16:40,  4.00it/s]
Test set: Accuracy: 9492/10000 (95%)

 67%|██████▋   | 6002/9000 [05:47<12:38,  3.95it/s]
Test set: Accuracy: 9679/10000 (97%)

 78%|███████▊  | 7004/9000 [06:49<06:48,  4.88it/s]
Test set: Accuracy: 9685/10000 (97%)

 89%|████████▉ | 8004/9000 [07:51<03:23,  4.88it/s]
Test set: Accuracy: 9716/10000 (97%)

100%|██████████| 9000/9000 [08:52<00:00, 16.91it/s]
Test set: Accuracy: 9685/10000 (97%)
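
The accuracies above are reached without ever showing the network a single-digit label. One way to see how a sum constraint can still provide a training signal is to treat the two outputs as independent distributions over digits and penalize the negative log-probability that the digits add up to the target. The sketch below illustrates this idea in plain Python; it is an assumption made for illustration and not necessarily the loss that the constraint module computes. The names softmax, sum_probability, and sum_constraint_loss are hypothetical.

```python
import math

def softmax(logits):
    """Convert a list of logits into probabilities (numerically stable)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sum_probability(p1, p2, target_sum):
    """P(d1 + d2 == target_sum) for independent digit distributions p1, p2."""
    return sum(p1[i] * p2[target_sum - i]
               for i in range(len(p1))
               if 0 <= target_sum - i < len(p2))

def sum_constraint_loss(logits1, logits2, target_sum):
    """Negative log-probability that the two digits add up to target_sum."""
    p = sum_probability(softmax(logits1), softmax(logits2), target_sum)
    return -math.log(p)

uniform = [0.0] * 10                       # an untrained, undecided network
confident_3 = [5.0 if d == 3 else 0.0 for d in range(10)]
confident_4 = [5.0 if d == 4 else 0.0 for d in range(10)]

loss_uniform = sum_constraint_loss(uniform, uniform, 7)
loss_confident = sum_constraint_loss(confident_3, confident_4, 7)
print(loss_uniform > loss_confident)  # True: agreeing with the sum lowers the loss
```

Because the loss is lowest when both images put their probability mass on digits that add up to the target, minimizing it over many pairs pushes each single-digit prediction toward the correct label.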