# Designing a custom Monitored Training Session

The tf.train module defines a helpful utility function tf.train.MonitoredTrainingSession that handles a lot of the training boilerplate in an elegant way. Recently, I’ve found limitations in its design, so I decided to build my own version.

In particular, it restricts the client to either restoring or initializing all of the trainable variables together, and it only allows restoration from a single checkpoint. For models that require pretrained weights to be restored from one checkpoint and model weights to be restored from a separate checkpoint (or initialized from scratch), these constraints are severely limiting. An example of such a model is an LSTM that consumes the embeddings produced by a pretrained CNN: you may want to load the LSTM weights from a previous training session while also loading the CNN's pretrained weights from their own checkpoint.
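
To make the restriction concrete, the stock API ties every variable in the graph to a single scaffold and a single checkpoint directory. A minimal sketch (assuming a graph and a train_op are already defined; FLAGS.checkpoint_dir is a hypothetical flag):

```python
import tensorflow as tf

# One Saver over *all* variables: MonitoredTrainingSession will either
# restore everything from checkpoint_dir or initialize everything.
scaffold = tf.train.Scaffold(saver=tf.train.Saver())

with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.checkpoint_dir,  # the one and only checkpoint
        scaffold=scaffold) as sess:
    while not sess.should_stop():
        sess.run(train_op)
```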

My version offers the following interface:

```python
with utils.MonitoredTrainingSession(
    config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement),
    hooks=session_hooks.values(),
    model_checkpoint_dir=FLAGS.model_checkpoint_dir,  # restores vars from checkpoint file
    model_scaffold=None,  # uses default
    pretrained_checkpoint_dirs=pretrained_checkpoint_dirs,
    pretrained_scaffolds=pretrained_scaffolds,
) as sess:
    while not sess.should_stop():
        sess.run(...)
```
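
For the LSTM-over-CNN example, the pretrained scaffolds and checkpoint directories might be assembled along these lines (a sketch; the "cnn" variable scope and FLAGS.cnn_checkpoint_dir are assumptions, not part of the interface):

```python
# One scaffold per pretrained subnetwork. Each saver covers only the
# variables that should be restored from that subnetwork's checkpoint;
# finalize() fills in the scaffold's default init ops.
cnn_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="cnn")
cnn_saver = tf.train.Saver(var_list=cnn_vars)

pretrained_scaffolds = [tf.train.Scaffold(saver=cnn_saver).finalize()]
pretrained_checkpoint_dirs = [FLAGS.cnn_checkpoint_dir]
```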


This custom function would behave similarly to the one in tf.train, except that it would accept lists of checkpoint directories and scaffolds instead of a single one of each. In addition, it would allow for selective initialization and restoration. The implementation ended up being pretty straightforward, with the most significant modification in SessionManager:

```python
class SessionManager(tf.train.SessionManager):

    def __init__(self,
                 checkpoint_dirs=None,
                 graph=None,
                 scaffolds=None,
                 target=""):
        if graph is None:
            graph = tf.get_default_graph()

        self._checkpoint_dirs = checkpoint_dirs
        self._graph = graph
        self._scaffolds = scaffolds
        self._sess = None
        self._target = target

    def _get_session(self, config):
        # Lazily create a single tf.Session shared by all init/restore calls.
        if self._sess:
            return self._sess

        self._sess = tf.Session(self._target, graph=self._graph, config=config)
        return self._sess

    def _restore_checkpoint(self,
                            saver=None,
                            checkpoint_dir=None,
                            config=None):
        sess = self._get_session(config)
        if not saver or not checkpoint_dir:
            return

        # Don't bother waiting for a checkpoint; if none exists yet,
        # leave the variables as initialized.
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if not ckpt or not ckpt.model_checkpoint_path:
            return

        saver.restore(sess, ckpt.model_checkpoint_path)
        saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)

    def prepare_session(self, config=None):
        """Runs each scaffold's init ops, then restores the subset of
        variables covered by that scaffold's saver from the corresponding
        checkpoint directory.
        """
        sess = self._get_session(config)
        zipped_items = zip(self._scaffolds, self._checkpoint_dirs)

        for scaffold, checkpoint_dir in zipped_items:
            if scaffold.init_op is not None:
                sess.run(scaffold.init_op, feed_dict=scaffold.init_feed_dict)
            if scaffold.init_fn:
                scaffold.init_fn(sess)

            # Restores a subset of the variables (given by the saver's var_list)
            self._restore_checkpoint(
                saver=scaffold.saver,
                checkpoint_dir=checkpoint_dir,
                config=config)

        return sess
```
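
Wiring this into the custom MonitoredTrainingSession is then a matter of handing the manager every scaffold/checkpoint pair. Roughly (a sketch using the hypothetical names from the interface above, with the model pair listed first):

```python
manager = SessionManager(
    checkpoint_dirs=[FLAGS.model_checkpoint_dir] + list(pretrained_checkpoint_dirs),
    scaffolds=[model_scaffold] + list(pretrained_scaffolds))
sess = manager.prepare_session(config=config)
```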


In particular, notice that SessionManager.prepare_session initializes all of the variables first and only then restores a subset of them (those in each saver's var_list) from the checkpoint files. This guarantees that every variable is either initialized or restored, and it lets the client make that choice on a per-variable basis.
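
The ordering is what provides the guarantee. As a toy illustration (with hypothetical variables w and b, unrelated to the model above):

```python
# init_op runs first and initializes both variables; the saver then
# overwrites only w with checkpointed values, leaving b freshly initialized.
w = tf.get_variable("w", shape=[10], initializer=tf.zeros_initializer())
b = tf.get_variable("b", shape=[10], initializer=tf.zeros_initializer())

scaffold = tf.train.Scaffold(
    init_op=tf.variables_initializer([w, b]),
    saver=tf.train.Saver(var_list=[w]))
```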