Designing a custom Monitored Training Session
The tf.train module defines a helpful utility function, tf.train.MonitoredTrainingSession, that handles a lot of the training boilerplate in an elegant way. Recently, though, I ran into limitations in its design, so I decided to build my own version.
In particular, it restricts the client to either restoring or initializing all of the trainable variables together, and it only allows restoration from a single checkpoint. For models that need pretrained weights restored from one checkpoint while the remaining weights are restored from a separate checkpoint or initialized from scratch, these constraints are severely limiting. An example of such a model is an LSTM that consumes the embeddings produced by a pretrained CNN: you may want to load the LSTM weights from a previous training session while also loading the pretrained CNN weights from their own checkpoint.
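For reference, the stock helper only takes a single checkpoint_dir and a single scaffold, so everything is restored (or initialized) as one unit. A minimal sketch of that usage, with a hypothetical checkpoint path and a train_op assumed to be defined elsewhere:

import tensorflow as tf

# The built-in helper: one checkpoint_dir and one scaffold cover every variable.
with tf.train.MonitoredTrainingSession(
        checkpoint_dir="/tmp/train_logs",      # hypothetical path
        scaffold=tf.train.Scaffold(),          # one scaffold for the whole graph
        hooks=[tf.train.StopAtStepHook(last_step=10000)],
        config=tf.ConfigProto(log_device_placement=False)) as sess:
    while not sess.should_stop():
        sess.run(train_op)  # `train_op` assumed to be defined elsewhere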
My version offers the following interface:
with utils.MonitoredTrainingSession(
        config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement),
        hooks=session_hooks.values(),
        model_checkpoint_dir=FLAGS.model_checkpoint_dir,  # restores vars from `checkpoint` file
        model_scaffold=None,  # uses default
        pretrained_checkpoint_dirs=pretrained_checkpoint_dirs,
        pretrained_scaffolds=pretrained_scaffolds,
) as sess:
    while not sess.should_stop():
        sess.run(...)
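The construction of the pretrained scaffolds isn't shown above, but the idea is that each scaffold's saver only covers the variables that should come from its checkpoint. A minimal sketch, assuming the pretrained CNN variables live under a "cnn" variable scope and that FLAGS.cnn_checkpoint_dir exists (both names are hypothetical):

import tensorflow as tf

# Hypothetical scope; use whatever scope your pretrained subgraph lives under.
cnn_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="cnn")

# One scaffold per pretrained checkpoint. The saver's var_list restricts the
# restore to this subset; the init_op covers the same subset so the variables
# are still valid if no checkpoint is found.
pretrained_scaffolds = [
    tf.train.Scaffold(
        init_op=tf.variables_initializer(cnn_vars),
        saver=tf.train.Saver(var_list=cnn_vars)),
]
pretrained_checkpoint_dirs = [FLAGS.cnn_checkpoint_dir]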
This custom function behaves similarly to the one in tf.train, except that it accepts lists of checkpoints and scaffolds instead of a single one and allows for selective initialization and restoration. The implementation ended up being pretty straightforward, with the most significant modification being a custom SessionManager:
class SessionManager(tf.train.SessionManager):

    def __init__(self,
                 checkpoint_dirs=None,
                 graph=None,
                 ready_op=None,
                 scaffolds=None,
                 target=""):
        if graph is None:
            graph = tf.get_default_graph()
        self._checkpoint_dirs = checkpoint_dirs
        self._graph = graph
        self._ready_op = ready_op
        self._scaffolds = scaffolds
        self._sess = None
        self._target = target

    def _get_session(self, config):
        if self._sess:
            return self._sess
        self._sess = tf.Session(self._target, graph=self._graph, config=config)
        return self._sess

    def _restore_checkpoint(self, saver=None, checkpoint_dir=None, config=None):
        sess = self._get_session(config)
        if not saver or not checkpoint_dir:
            return  # Don't bother waiting for a checkpoint.
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if not ckpt or not ckpt.model_checkpoint_path:
            return
        # Loads the checkpoint.
        saver.restore(sess, ckpt.model_checkpoint_path)
        saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)

    def prepare_session(self, config=None):
        """Creates a session, then, for each (scaffold, checkpoint_dir) pair
        given to the constructor, runs the scaffold's init ops and restores the
        variables covered by the scaffold's saver from that checkpoint."""
        sess = self._get_session(config)
        zipped_items = zip(self._scaffolds, self._checkpoint_dirs)
        for scaffold, checkpoint_dir in zipped_items:
            if scaffold.init_op is not None:
                sess.run(scaffold.init_op, feed_dict=scaffold.init_feed_dict)
            if scaffold.init_fn:
                scaffold.init_fn(sess)
            # Restores a subset of the variables (given by the saver's var list).
            self._restore_checkpoint(
                saver=scaffold.saver,
                checkpoint_dir=checkpoint_dir,
                config=config
            )
        is_ready, msg = self._model_ready(sess)
        if not is_ready:
            raise RuntimeError("is_ready is False: %s" % msg)
        return sess
In particular, notice that SessionManager.prepare_session initializes all of the variables and then restores a subset of them from the checkpoint file. This guarantees that every variable is either initialized or restored, and it lets the client decide which happens on a per-variable basis.
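The wrapper that ties this SessionManager to the interface shown earlier isn't included here, but one plausible way to wire it up is with a small tf.train.SessionCreator subclass handed to tf.train.MonitoredSession. The following is only a sketch under those assumptions; MultiCheckpointSessionCreator and the default model scaffold are my own illustrative choices, not part of the implementation above:

import tensorflow as tf

class MultiCheckpointSessionCreator(tf.train.SessionCreator):
    """Hypothetical glue class: creates sessions via the custom SessionManager above."""

    def __init__(self, scaffolds, checkpoint_dirs, config=None):
        self._scaffolds = scaffolds
        self._checkpoint_dirs = checkpoint_dirs
        self._config = config

    def create_session(self):
        manager = SessionManager(
            checkpoint_dirs=self._checkpoint_dirs,
            scaffolds=self._scaffolds,
            ready_op=tf.report_uninitialized_variables())
        return manager.prepare_session(config=self._config)


def MonitoredTrainingSession(config=None, hooks=None, model_checkpoint_dir=None,
                             model_scaffold=None, pretrained_checkpoint_dirs=None,
                             pretrained_scaffolds=None):
    if model_scaffold is None:
        # Default model scaffold: initialize everything and restore all variables
        # from the model checkpoint (assumes that checkpoint, if present, contains
        # the full variable set). It runs first so that the pretrained scaffolds
        # can overwrite their subsets afterwards.
        model_scaffold = tf.train.Scaffold(
            init_op=tf.global_variables_initializer(),
            saver=tf.train.Saver())
    scaffolds = [model_scaffold] + list(pretrained_scaffolds or [])
    checkpoint_dirs = [model_checkpoint_dir] + list(pretrained_checkpoint_dirs or [])
    session_creator = MultiCheckpointSessionCreator(
        scaffolds, checkpoint_dirs, config=config)
    return tf.train.MonitoredSession(session_creator=session_creator, hooks=hooks)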
I haven’t had any issues with it yet!