Example #1
    def __read_meta_graph(self):

        logger.info("Reading meta-graph...")

        # Read the original meta-graph. clear_devices=True discards the
        # device placements saved with the graph, so it can be restored
        # on different hardware.
        tf.train.import_meta_graph(self.meta_graph_filename + '.meta',
                                   clear_devices=True)

        # add custom gunpowder variables
        with tf.variable_scope('gunpowder'):
            self.iteration = tf.get_variable('iteration',
                                             shape=1,
                                             initializer=tf.zeros_initializer,
                                             trainable=False)
            self.iteration_increment = tf.assign(self.iteration,
                                                 self.iteration + 1)

        # Until now, only variables that are part of every checkpoint
        # have been added to the graph. We create a 'basic_saver' for
        # just those variables.
        self.basic_saver = tf.train.Saver(max_to_keep=None)

        # Add a custom optimizer and loss, if requested. This potentially
        # adds more variables, which are not covered by the basic_saver.
        if self.optimizer_func is not None:
            loss, optimizer = self.optimizer_func(self.graph)
            self.loss = loss
            self.optimizer = optimizer

        # We create a 'full_saver' including those variables.
        self.full_saver = tf.train.Saver(max_to_keep=None)

        # find most recent checkpoint
        checkpoint_dir = os.path.dirname(self.meta_graph_filename)
        checkpoint = tf.train.latest_checkpoint(checkpoint_dir)

        if checkpoint:

            try:
                # Try to restore the graph, including the custom optimizer
                # state (if a custom optimizer was used).
                self.__restore_graph(checkpoint, restore_full=True)

            except tf.errors.NotFoundError:

                # If that failed, we just transitioned from an earlier training
                # without the custom optimizer. In this case, restore only the
                # variables of the original meta-graph and 'gunpowder'
                # variables. Custom optimizer variables will be default
                # initialized.
                logger.info("Checkpoint did not contain custom optimizer "
                            "variables")
                self.__restore_graph(checkpoint, restore_full=False)
        else:

            logger.info("No checkpoint found")

            # initialize all variables
            self.session.run(tf.global_variables_initializer())
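
The optimizer_func hook above is called with the imported graph and is
expected to return a (loss, optimizer) pair, matching how its result is
unpacked. A minimal sketch of such a callback using TensorFlow 1.x APIs;
the tensor names 'prediction:0' and 'target:0' are hypothetical and stand
in for whatever the original meta-graph actually defines:

import tensorflow as tf

def my_optimizer_func(graph):

    # Look up tensors by name in the imported meta-graph. These names
    # are placeholders for illustration only.
    prediction = graph.get_tensor_by_name('prediction:0')
    target = graph.get_tensor_by_name('target:0')

    # Define a custom loss and a training op on top of it. AdamOptimizer
    # creates additional slot variables (moment estimates), which is why
    # such variables are covered only by the 'full_saver' created after
    # this callback runs.
    loss = tf.losses.mean_squared_error(target, prediction)
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)

    return loss, optimizer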
Example #2
    def __restore_graph(self, checkpoint, restore_full):

        logger.info("Restoring model from %s", checkpoint)

        if restore_full:

            logger.info("...using a saver for all variables")
            self.full_saver.restore(self.session, checkpoint)

        else:

            # Initialize all variables first, so that the non-basic
            # variables (e.g., custom optimizer state) get default values
            # before the basic variables are restored on top.
            self.session.run(tf.global_variables_initializer())

            logger.info("...using a saver for basic variables only")
            self.basic_saver.restore(self.session, checkpoint)
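
For context, a training step in this class would presumably run the
optimizer together with the iteration counter and periodically checkpoint
with the full saver, so that the custom optimizer state survives a
restart. A sketch under that assumption; the method name, the
checkpoint_every parameter, and the save path are illustrative, only the
attributes shown in the two examples above come from the source:

    def __train_step(self, feed_dict, checkpoint_every=1000):

        # Run one optimization step and bump the persistent iteration
        # counter stored in the 'gunpowder' variable scope.
        _, loss, iteration = self.session.run(
            [self.optimizer, self.loss, self.iteration_increment],
            feed_dict=feed_dict)
        iteration = int(iteration[0])

        # Save with the full saver, so that custom optimizer variables
        # are included in the checkpoint. On restore, __read_meta_graph
        # falls back to the basic saver if they are absent.
        if iteration % checkpoint_every == 0:
            self.full_saver.save(self.session,
                                 self.meta_graph_filename,
                                 global_step=iteration)

        return loss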