A bug hidden in plain sight

Last week, I started searching for the origin of an odd bug. It turned out to be one of the more rewarding bugs I’ve solved. Unlike a variable misassignment or scoping issue, which is straightforward (and somewhat mechanical) to resolve once you find the particular place that breaks the system, this bug deepened my conceptual understanding of the system I had implemented and of its underlying components. At least in retrospect, I’m glad I made the initial mistake because now I am woke.

TL;DR

def train(total_loss_op, global_step_op, learning_rate):
  optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
  # Batch norm registers its moving-average update ops in the UPDATE_OPS
  # collection; adding them as control dependencies makes them run with
  # every training step.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(
      total_loss_op,
      global_step=global_step_op,
      name="train"
    )
  return train_op

def main():
  ...

  train_op = train(total_loss_op, global_step_op, learning_rate=0.01)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)

Refresher on batch norm
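
Batch norm normalizes the activations flowing into a layer. During training, it subtracts the mini-batch mean and divides by the mini-batch standard deviation, then applies a learned scale γ and shift β. It also keeps exponential moving averages of the mean and variance; during evaluation, those moving averages stand in for the mini-batch statistics. A minimal sketch of the mechanics in numpy (illustrative, not TensorFlow’s implementation):

import numpy as np

def batch_norm(x, gamma, beta, moving_mean, moving_var,
               training, momentum=0.997, eps=1e-5):
  # x has shape (batch_size, channels); gamma and beta are learned.
  if training:
    # Normalize with the statistics of the current mini-batch.
    mean, var = x.mean(axis=0), x.var(axis=0)
    # Track moving averages for use at evaluation time.
    moving_mean = momentum * moving_mean + (1 - momentum) * mean
    moving_var = momentum * moving_var + (1 - momentum) * var
  else:
    # Normalize with the moving averages accumulated during training.
    mean, var = moving_mean, moving_var
  x_hat = (x - mean) / np.sqrt(var + eps)
  return gamma * x_hat + beta, moving_mean, moving_var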

Our architecture

My research collaborator Jake and I are designing models to predict the relationship between image pairs, where each image in a pair is drawn from one of two distributions, A or B. A couple of weeks ago, Jake brought the odd behavior of one of these models to my attention. He had designed a CNN in TensorFlow that read a TFRecord containing three image pairs (think of each pair as consisting of a “left” image and a “right” image). The first image pair contained two images from A, the second contained two images from B, and the third contained a left image from A and a right image from B. He constructed three types of batches, one for each type of image pair, and passed every image in these batches through ResNet-18, reusing the ResNet variables across all of the batches. I’ve distilled the implementation in pseudocode:

def model():
  img_pair_1_batch_op, labels_1_batch_op,\
  img_pair_2_batch_op, labels_2_batch_op,\
  img_pair_3_batch_op, labels_3_batch_op = train_inputs()

  # Extracts embeddings with shape (128,)
  emb_1_left_op = resnet(img_pair_1_batch_op.left)
  emb_1_right_op = resnet(img_pair_1_batch_op.right, reuse=True)
  emb_2_left_op = resnet(img_pair_2_batch_op.left, reuse=True)  
  emb_2_right_op = resnet(img_pair_2_batch_op.right, reuse=True)
  emb_3_left_op = resnet(img_pair_3_batch_op.left, reuse=True)
  emb_3_right_op = resnet(img_pair_3_batch_op.right, reuse=True)

  # Concatenates embeddings into shape (256,) and passes through
  # a fully connected layer
  prediction_1_op = fully_connected(
    tf.concat([emb_1_left_op, emb_1_right_op], axis=1))
  prediction_2_op = fully_connected(
    tf.concat([emb_2_left_op, emb_2_right_op], axis=1))
  prediction_3_op = fully_connected(
    tf.concat([emb_3_left_op, emb_3_right_op], axis=1))

  # Computes losses
  total_loss = mean(l2(prediction_1_op - labels_1_batch_op, axis=1))\
               + mean(hinge(prediction_2_op - labels_2_batch_op, axis=1))\
               + mean(hinge(prediction_3_op - labels_3_batch_op, axis=1))

  return total_loss

The loss was the sum of three individual losses: an L2 loss between the prediction and label for the first image pair (the two images from A), and hinge losses between the predictions and labels for the second and third image pairs. He trained the model down to a loss of 0.15 on this task, a promising result given that a model that “guesses” all 0’s would produce a loss of 1.6. However, evaluating this model produced a loss closer to 0.80, even on the training set. The bug failed silently, which made its cause harder to track down; still, such an astonishing difference in performance on the exact same dataset suggested to me a blatant implementation error, the kind that calls for one of those mechanical debugging sessions. I was somewhat right: in some sense, our mistake was hidden in plain sight.

Diagnosing the symptoms

As I diagnosed our system, I channeled my inner Sherlock and dusted for evidence of the culprit. I started with the one signal I had, that the model performed differently during training and evaluation, and composed a list of places where the training and evaluation logic diverged.

First, I confirmed that the variables had been restored to the same values during training and evaluation. Since I use a custom implementation of tf.train.MonitoredTrainingSession, I initially grew concerned that there existed some edge case that caused weights to load differently during training and evaluation. I randomly sampled about 10 of the weights during training and evaluation runs, and all of them matched.
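
The check itself was quick. A sketch of what I ran in both the training and the evaluation process (assuming a built graph and a live sess; which variables to sample is arbitrary):

# Print a handful of weights and diff the output across the two runs.
for var_op in tf.global_variables()[:10]:
  print(var_op.name, sess.run(var_op).flatten()[:5])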

Then, I noticed that during training, I grew my queue of examples to a minimum size much greater than the batch size and called tf.train.shuffle_batch to create the batch operator; however during evaluation, I grew my queue of examples to a minimum size of the batch size and called tf.train.batch to create the batch operator. Using tf.train.shuffle_batch with a queue containing a greater fraction of the dataset during evaluation didn’t change the results that I observed; the evaluation loss stayed at ~0.80.
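
For concreteness, the two input pipelines looked roughly like this (a sketch; the batch size, capacity, and min_after_dequeue values are illustrative, and example_op stands for the parsed example tensors):

# Training: fill the queue well past the batch size, then shuffle.
train_batch_op = tf.train.shuffle_batch(
  example_op,
  batch_size=32,
  capacity=10000 + 3 * 32,
  min_after_dequeue=10000)

# Evaluation: a plain FIFO batch, no shuffling.
eval_batch_op = tf.train.batch(
  example_op,
  batch_size=32,
  capacity=3 * 32)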

It took me about a half day of staring at the rest of my code to realize the last place that the model differed during training and evaluation: batch norm.

Batch norm uses the mini-batch stats during training but, during evaluation, the moving averages it accumulated over the course of training. Sure enough, passing training=True to tf.layers.batch_normalization during evaluation (forcing it to use mini-batch stats), and evaluating with tf.train.shuffle_batch over a large fraction of the dataset at the same batch size as training, resulted in an evaluation loss on the training dataset that matched the training loss of ~0.20. This made sense, since randomly sampling a large batch from a larger fraction of the dataset produces a more well-mixed batch than simply taking the next few examples in the dataset.
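
In code, the switch is just a flag (a sketch; the is_training placeholder and the conv_bn_relu helper are illustrative, not our exact model):

is_training_op = tf.placeholder(tf.bool, shape=[], name="is_training")

def conv_bn_relu(net, filters):
  net = tf.layers.conv2d(net, filters, kernel_size=3, padding="same")
  net = tf.layers.batch_normalization(
    net,
    momentum=0.997,            # the moving-average decay discussed below
    training=is_training_op)   # mini-batch stats vs. moving averages
  return tf.nn.relu(net)

# Training step:  feed_dict={is_training_op: True}
# Evaluation:     feed_dict={is_training_op: False}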

Why did the published batch norm implementation of training and evaluation produce such poor results, though? I googled around for “tensorflow batch norm eval training accuracy” and similar queries. A GitHub thread suggested that I decrease the batch norm momentum (the moving-average decay) from 0.997 to 0.9. Changing this parameter moved the evaluation loss, but only by +/-0.05. But more on this later.

Anyways, upon realizing this odd behavior resulted from a published feature of batch norm, I passed through the 7 stages of grief: (1) SHOCK AND DENIAL. What?! Could batch norm be broken?! There’s no way?! It’s used in everything?! How could it be broken?! (2) PAIN AND GUILT. Aghh this is totes my fault for using a component without understanding it. (3) ANGER AND BARGAINING. HOW AM I THE ONLY PERSON IN THE WORLD WITH THIS BUG?? (4) DEPRESSION, REFLECTION, LONELINESS. -Crying alone in lab in darkness at 1 a.m. wondering why it’s not working, asking myself how I got here-. (5) THE UPWARD TURN. Well at least I’ve narrowed the bug down to batch norm. (6) RECONSTRUCTION AND WORKING THROUGH. Let me chug a bunch of coffee and figure out why this is happening. (7) ACCEPTANCE AND HOPE. Well I never needed to accept that the bug existed because I FIXED IT, MOTHAtrucker!!!

The path to enlightenment started with a simple question: why does batch norm use mini-batch stats during training as opposed to just the moving averages it tracks? I reasoned that using the moving averages instead of mini-batch stats during training offered a couple of desirable properties: (1) it would guarantee consistency between training loss and evaluation loss and (2) batch norm could better learn to normalize the distributions it would encounter during evaluation. The second property seemed like it would improve the performance of this model in particular, since the errors compounded. To see what I mean by this, consider the following:

In the standard implementation of batch norm, a network learns its weights as a function of mini-batch stats. If the mini-batch stats of the input activations to a specific batch norm layer t differ from the moving averages tracked for that layer, then the normalization applied by layer t during evaluation is completely different from the normalization it performed during training. As a result, its learned parameters γ and β are meaningless. To make matters worse, batch norm layer t has been trained to normalize a specific distribution of inputs according to their mini-batch stats. Now, not only must it apply a different normalization according to the moving averages instead of the mini-batch stats, but it must also apply it to an entirely different distribution, because every layer before it suffers from the same mismatch. This effect compounds per batch norm layer, producing the astonishing chasm between training and evaluation loss.

I asked myself why there was such a huge difference between the mini-batch stats and the moving averages. After all, the moving averages should be an average of a bunch of mini-batch stats. This question produced the breakthrough insight: the moving averages were averages of activations across all of the images from distributions A and B. However, during training, each batch consisted of only a single type of image, from either A or B. So using the mini-batch stats represented a HUGE advantage. To see why, consider a batch norm layer that encounters batches of size 8 with activations of value 1 or 2 and tries to normalize the batch to 0 mean. If the batches are:

(1 1 1 2 2 1 2 2) and (1 1 2 1 2 2 2 1) and moving_mean=1.5 => using mini-batch stats or moving averages produces an effective normalization.

(1 1 1 1 1 1 1 1) and (2 2 2 2 2 2 2 2) and moving_mean=1.5 => using mini-batch stats produces an effective normalization, but using the moving averages does not.
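
A quick numeric check of the second case in plain numpy, using the values from the example above:

import numpy as np

batch_a = np.array([1.] * 8)   # a batch drawn entirely from A
batch_b = np.array([2.] * 8)   # a batch drawn entirely from B
moving_mean = 1.5

# Mini-batch stats: each batch is centered at 0, as intended.
print(batch_a - batch_a.mean())   # [0. 0. ... 0.]
print(batch_b - batch_b.mean())   # [0. 0. ... 0.]

# Moving average: every activation lands at -0.5 or +0.5, never 0.
print(batch_a - moving_mean)      # [-0.5 -0.5 ... -0.5]
print(batch_b - moving_mean)      # [ 0.5  0.5 ...  0.5]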

Our architecture constructs mini-batches that are each drawn entirely from one distribution, either all from A or all from B. For well-mixed data, mini-batch stats offer an insignificant advantage over the moving averages, because the presence of one example in a batch is independent of the presence of another. Our model, however, creates batches with complete dependence: if one image in a batch is sampled from distribution A, then every image in that batch is sampled from distribution A. In some sense, our design represents the worst-case input for batch norm, where the moving averages sit as far as possible from the mini-batch stats for the largest fraction of the dataset.

To fix this issue, we need to either create mini-batches consisting of an equal number of each of the three types of image pairs, or normalize all of the images to a common distribution C.
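
A sketch of the first fix, mixing the three pair types into every batch before they reach ResNet (the concat/split pattern below is illustrative, not our final implementation; the right images would be handled the same way):

# Stack the three pair types so every mini-batch that reaches a batch
# norm layer contains images from both A and B.
left_batch_op = tf.concat([
  img_pair_1_batch_op.left,
  img_pair_2_batch_op.left,
  img_pair_3_batch_op.left], axis=0)

emb_left_op = resnet(left_batch_op)  # mini-batch stats now cover A and B

# Split the embeddings back out to compute the three per-pair-type losses.
emb_1_left_op, emb_2_left_op, emb_3_left_op = tf.split(
  emb_left_op, num_or_size_splits=3, axis=0)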

The aftermath

This bug wasn’t very fun to find, but I’m happy that it existed. In the search for its cause, I deepened my understanding of my topic of study. It caused me to read the batch norm paper from abstract to citations, and then reread it twice over. Even then, I didn’t understand it (at least well enough to solve my bug) until I started to ask questions about the effects of their design decisions on my model. When I finally pinpointed the source of the bug, I experienced a super intense nerd high, which is silly considering it wasn’t even a great scientific discovery. I guess I just enjoy learning for the sake of learning…