def __read_meta_graph(self):
    """Load the meta-graph, add gunpowder bookkeeping, and restore weights.

    The steps below run in this order, and the order matters:

    1. Import ``self.meta_graph_filename + '.meta'`` with device placements
       cleared.
    2. Inside the ``gunpowder`` variable scope, add the non-trainable,
       zero-initialized ``iteration`` variable (``shape=1``). Store it on
       ``self.iteration``. Store the op that assigns ``iteration + 1`` to it
       on ``self.iteration_increment``.
    3. Create ``self.basic_saver``. It covers only the variables that exist
       at this point.
    4. If ``self.optimizer_func`` is set, call it with ``self.graph``.
       Store the returned ``(loss, optimizer)`` pair on ``self.loss`` and
       ``self.optimizer``. This may add more variables, which the basic
       saver does not cover.
    5. Create ``self.full_saver``. It covers every variable that exists
       now.
    6. Look up the most recent checkpoint in the directory of
       ``self.meta_graph_filename``.
       - If one is found, try a full restore first. If that raises
         ``tf.errors.NotFoundError``, fall back to restoring only the basic
         variables.
       - If none is found, initialize all global variables in
         ``self.session``.

    Raises:
        tf.errors.NotFoundError: if the fallback restore with the basic
            saver also fails; that second call is not guarded.
    """
    logger.info("Reading meta-graph...")

    # read the original meta-graph
    tf.train.import_meta_graph(
        self.meta_graph_filename + '.meta',
        clear_devices=True)

    # add custom gunpowder variables
    with tf.variable_scope('gunpowder'):
        # Instantiate the initializer explicitly. Since TF 1.0,
        # tf.zeros_initializer is a class. Passing the bare class relies on
        # get_variable instantiating it, which is not guaranteed across
        # TF 1.x releases.
        self.iteration = tf.get_variable(
            'iteration',
            shape=1,
            initializer=tf.zeros_initializer(),
            trainable=False)
        self.iteration_increment = tf.assign(
            self.iteration,
            self.iteration + 1)

    # Until now, only variables have been added to the graph that are part
    # of every checkpoint. We create a 'basic_saver' for only those
    # variables.
    self.basic_saver = tf.train.Saver(max_to_keep=None)

    # Add custom optimizer and loss, if requested. This potentially adds
    # more variables, not covered by the basic_saver.
    if self.optimizer_func is not None:
        # NOTE(review): assumes self.graph is the default graph that the
        # meta-graph was imported into above. Confirm.
        loss, optimizer = self.optimizer_func(self.graph)
        self.loss = loss
        self.optimizer = optimizer

    # We create a 'full_saver' including those variables.
    self.full_saver = tf.train.Saver(max_to_keep=None)

    # find most recent checkpoint
    checkpoint_dir = os.path.dirname(self.meta_graph_filename)
    checkpoint = tf.train.latest_checkpoint(checkpoint_dir)

    if checkpoint:

        try:
            # Try to restore the graph, including the custom optimizer
            # state (if a custom optimizer was used).
            self.__restore_graph(checkpoint, restore_full=True)

        except tf.errors.NotFoundError:
            # If that failed, we just transitioned from an earlier training
            # without the custom optimizer. In this case, restore only the
            # variables of the original meta-graph and 'gunpowder'
            # variables. Custom optimizer variables will be default
            # initialized.
            logger.info("Checkpoint did not contain custom optimizer "
                        "variables")
            self.__restore_graph(checkpoint, restore_full=False)

    else:

        logger.info("No checkpoint found")

        # initialize all variables
        self.session.run(tf.global_variables_initializer())
def __restore_graph(self, checkpoint, restore_full):
    """Restore variable values from ``checkpoint`` into ``self.session``.

    Args:
        checkpoint: checkpoint path handed to ``Saver.restore``.
        restore_full: if true, restore with ``self.full_saver``. Otherwise,
            first initialize all global variables, then restore with
            ``self.basic_saver``. This gives the variables that the basic
            saver does not cover their initial values.
    """
    logger.info("Restoring model from %s", checkpoint)

    if restore_full:
        logger.info("...using a saver for all variables")
        saver = self.full_saver
    else:
        # Run the global initializer before the partial restore so that
        # no variable outside the basic saver is left uninitialized.
        self.session.run(tf.global_variables_initializer())
        logger.info("...using a saver for basic variables only")
        saver = self.basic_saver

    saver.restore(self.session, checkpoint)