Designing a custom Monitored Training Session

The tf.train module defines a helpful utility function, tf.train.MonitoredTrainingSession, that handles a lot of training boilerplate in an elegant way. Recently, though, I've run into limitations in its design, so I decided to build my own version.

In particular, it restricts the client to restoring or initializing all of the trainable variables together, and it only allows restoration from a single checkpoint. For models that require pretrained weights to be restored from one checkpoint while the remaining weights are restored from a separate checkpoint or initialized on their own, these constraints are severely limiting. An example of such a model is an LSTM that accepts the embeddings output by a pretrained CNN: you might want to load the LSTM weights from a previous training session while restoring the CNN weights from their original pretrained checkpoint.
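In TF 1.x, this kind of per-component split is usually expressed by giving each tf.train.Saver a var_list restricted to one variable scope. As a framework-free toy illustration of the idea (the helper name is mine, not part of TensorFlow), grouping variables by their scope prefix might look like this:

```python
def group_by_scope(variable_names):
    """Groups variable names by their top-level scope prefix.

    A toy stand-in for choosing the `var_list` of per-component savers:
    variables under "cnn/" would be restored from the pretrained CNN
    checkpoint, variables under "lstm/" from the LSTM checkpoint.
    """
    groups = {}
    for name in variable_names:
        scope = name.split("/", 1)[0]
        groups.setdefault(scope, []).append(name)
    return groups

names = ["cnn/conv1/kernel", "cnn/conv1/bias", "lstm/cell/kernel"]
groups = group_by_scope(names)
# groups == {"cnn": ["cnn/conv1/kernel", "cnn/conv1/bias"],
#            "lstm": ["lstm/cell/kernel"]}
```

Each group would then back a separate saver, and each saver a separate checkpoint.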

My version offers the following interface:

with utils.MonitoredTrainingSession(
  config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement),
  hooks=session_hooks.values(),
  model_checkpoint_dir=FLAGS.model_checkpoint_dir,  # restores vars from `checkpoint` file
  model_scaffold=None,  # uses default
  pretrained_checkpoint_dirs=pretrained_checkpoint_dirs,
  pretrained_scaffolds=pretrained_scaffolds,
) as sess:
  while not sess.should_stop():
    sess.run(...)
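One detail worth validating at the call site: the pretrained checkpoint dirs and scaffolds are consumed pairwise, so the two lists must line up. A hypothetical helper (not part of the actual implementation) might check this before the session is built:

```python
def pair_scaffolds_with_checkpoints(scaffolds, checkpoint_dirs):
    """Pairs each scaffold with its checkpoint directory, validating lengths.

    The lists are zipped positionally, so a silent length mismatch would
    drop entries; failing fast here makes the error obvious.
    """
    scaffolds = scaffolds or []
    checkpoint_dirs = checkpoint_dirs or []
    if len(scaffolds) != len(checkpoint_dirs):
        raise ValueError(
            "Expected one scaffold per checkpoint dir, got %d scaffolds "
            "and %d checkpoint dirs" % (len(scaffolds), len(checkpoint_dirs)))
    return list(zip(scaffolds, checkpoint_dirs))
```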

This custom function behaves similarly to the one in tf.train, except that it accepts lists of checkpoints and scaffolds instead of a single one, and it allows for selective initialization and restoration. The implementation ended up being pretty straightforward, with the most significant modification in SessionManager:

class SessionManager(tf.train.SessionManager):
  """A SessionManager that accepts lists of checkpoint dirs and scaffolds."""

  def __init__(self,
               checkpoint_dirs=None,
               graph=None,
               ready_op=None,
               scaffolds=None,
               target=""):
    if graph is None:
      graph = tf.get_default_graph()

    self._checkpoint_dirs = checkpoint_dirs
    self._graph = graph
    self._ready_op = ready_op
    self._scaffolds = scaffolds
    self._sess = None
    self._target = target

  def _get_session(self, config):
    # Lazily creates and caches a single session for this manager.
    if self._sess:
      return self._sess

    self._sess = tf.Session(self._target, graph=self._graph, config=config)
    return self._sess

  def _restore_checkpoint(self,
                          saver=None,
                          checkpoint_dir=None,
                          config=None):
    sess = self._get_session(config)
    if not saver or not checkpoint_dir:
      return

    # Don't bother waiting for checkpoint

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if not ckpt or not ckpt.model_checkpoint_path:
      return

    # Loads the checkpoint.
    saver.restore(sess, ckpt.model_checkpoint_path)
    saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)


  def prepare_session(self, config=None):
    """Initializes variables and restores checkpoints for each scaffold.

    For each (scaffold, checkpoint_dir) pair, runs the scaffold's init_op
    and init_fn, then restores the subset of variables covered by the
    scaffold's saver from that checkpoint directory.
    """
    sess = self._get_session(config)
    zipped_items = zip(self._scaffolds, self._checkpoint_dirs)

    for scaffold, checkpoint_dir in zipped_items:
      if scaffold.init_op is not None:
        sess.run(scaffold.init_op, feed_dict=scaffold.init_feed_dict)
      if scaffold.init_fn:
        scaffold.init_fn(sess)

      # Restores a subset of the variables (given by the saver's var list)
      self._restore_checkpoint(
        saver=scaffold.saver,
        checkpoint_dir=checkpoint_dir,
        config=config
      )

    is_ready, msg = self._model_ready(sess)
    if not is_ready:
      raise RuntimeError("Model is not ready after init/restore: %s" % msg)
    return sess

In particular, notice that SessionManager.prepare_session initializes all of the variables and then restores a subset of them from the checkpoint files. This guarantees that every variable is either initialized or restored, while letting the client decide which, on a per-variable basis (via each saver's var list).
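As a framework-free sketch of that guarantee (the dicts stand in for variable values and a checkpoint; nothing here is TensorFlow API):

```python
# Toy model: every variable gets a default value from the init op, then
# the saver overwrites the subset present in the checkpoint.
variables = ["cnn/kernel", "lstm/kernel", "logits/bias"]
checkpoint = {"cnn/kernel": "pretrained"}  # the saver's var_list covers this subset

state = {name: "initialized" for name in variables}   # scaffold.init_op
state.update(checkpoint)                              # saver.restore(...)

# Every variable ends up either initialized or restored -- never undefined.
assert state == {"cnn/kernel": "pretrained",
                 "lstm/kernel": "initialized",
                 "logits/bias": "initialized"}
```

The ordering matters: running the init op after the restore would clobber the restored weights, which is exactly the order-of-operations bug the ready-op check at the end of prepare_session would catch only if some variable were left untouched entirely.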

I haven’t had any issues with it yet!