    def restore(self, sess, checkpoints):
        # assign the pretrained VGG weights into this session
        vgg_network.vgg_assign_from_values_fn()(sess)

        if checkpoints:
            # restore each view's weights from its own checkpoint (one per view)
            if not isinstance(checkpoints, (list, tuple)):
                checkpoints = [checkpoints]
            if len(checkpoints) != self.hparams.num_views:
                raise ValueError(
                    'number of checkpoints should be equal to the number of views'
                )
            savers = []
            for i, checkpoint in enumerate(checkpoints):
                print("creating restore saver from checkpoint %s" % checkpoint)
                restore_scope = 'view%d' % i

                # map a scoped graph variable name (e.g. 'view0/foo/bar:0') back to
                # the unscoped name ('foo/bar') it is stored under in the checkpoint.
                # restore_scope is bound as a default argument so each view's mapping
                # keeps its own scope, avoiding the late-binding pitfall of closures
                # defined inside a loop.
                def restore_to_checkpoint_mapping(name, restore_scope=restore_scope):
                    name = name.split(':')[0]
                    assert name.split('/')[0] == restore_scope
                    name = '/'.join(name.split('/')[1:])
                    return name

                saver, _ = tf_utils.get_checkpoint_restore_saver(
                    checkpoint,
                    restore_to_checkpoint_mapping=restore_to_checkpoint_mapping,
                    restore_scope=restore_scope)
                savers.append(saver)
            # run all of the savers' restore ops in a single session call
            restore_op = [saver.saver_def.restore_op_name for saver in savers]
            sess.run(restore_op)
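For reference, a minimal sketch (assuming the TF1.x API used throughout this listing, and not taken from the repo's tf_utils module) of how a per-view restore saver can key scoped graph variables by their unscoped checkpoint names, which is exactly the mapping restore_to_checkpoint_mapping performs above:

# Minimal sketch, assuming TF1.x; make_scoped_restore_saver is a hypothetical helper,
# not part of the repo's tf_utils module.
import tensorflow as tf

def make_scoped_restore_saver(restore_scope):
    # all variables created under e.g. 'view0/'
    scoped_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)
    # key each variable by its unscoped name ('view0/foo/bar' -> 'foo/bar'),
    # i.e. the name it was saved under in a single-view checkpoint
    var_list = {v.op.name.split('/', 1)[1]: v for v in scoped_vars}
    return tf.train.Saver(var_list=var_list)

# typical use: make_scoped_restore_saver('view0').restore(sess, checkpoint)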
Example #2
    def restore(self, sess, checkpoints):
        vgg_network.vgg_assign_from_values_fn()(sess)

        if checkpoints:
            # possibly restore from multiple checkpoints; useful if subsets of the
            # weights (e.g. generator or discriminator) come from different checkpoints.
            if not isinstance(checkpoints, (list, tuple)):
                checkpoints = [checkpoints]
            # automatically skip global_step if more than one checkpoint is provided
            skip_global_step = len(checkpoints) > 1
            savers = []
            for checkpoint in checkpoints:
                print("creating restore saver from checkpoint %s" % checkpoint)
                saver, _ = tf_utils.get_checkpoint_restore_saver(
                    checkpoint, skip_global_step=skip_global_step)
                savers.append(saver)
            restore_op = [saver.saver_def.restore_op_name for saver in savers]
            sess.run(restore_op)
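The sess.run(restore_op) pattern above works because each saver can carry its own checkpoint path; a hedged sketch of that mechanism together with the skip_global_step filtering (the actual tf_utils helper may differ):

# Hedged sketch, assuming TF1.x: baking the checkpoint path into the Saver lets its
# restore op be run by name without feeding a filename; skip_global_step drops
# global_step so that multiple checkpoints do not fight over its value.
import tensorflow as tf

def build_restore_op(checkpoint, skip_global_step=False):
    var_list = tf.global_variables()
    if skip_global_step:
        var_list = [v for v in var_list if v.op.name != 'global_step']
    saver = tf.train.Saver(var_list=var_list, filename=checkpoint)
    return saver.saver_def.restore_op_name  # an op name string, runnable via sess.run

# e.g. sess.run([build_restore_op(ckpt, skip_global_step=True) for ckpt in checkpoints])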
Example #3
    def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None):
        if checkpoints:
            var_list = self.saveable_variables
            # possibly restore from multiple checkpoints; useful if subsets of the
            # weights (e.g. generator or discriminator) come from different checkpoints.
            if not isinstance(checkpoints, (list, tuple)):
                checkpoints = [checkpoints]
            # automatically skip global_step if more than one checkpoint is provided
            skip_global_step = len(checkpoints) > 1
            savers = []
            for checkpoint in checkpoints:
                print("creating restore saver from checkpoint %s" % checkpoint)
                saver, _ = tf_utils.get_checkpoint_restore_saver(
                    checkpoint, var_list, skip_global_step=skip_global_step,
                    restore_to_checkpoint_mapping=restore_to_checkpoint_mapping)
                savers.append(saver)
            restore_op = [saver.saver_def.restore_op_name for saver in savers]
            sess.run(restore_op)
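Putting the pieces together, a hedged sketch of what a get_checkpoint_restore_saver-style helper could look like; the real tf_utils implementation is not shown in this listing and may differ:

# Hedged sketch only, assuming TF1.x; not the actual tf_utils.get_checkpoint_restore_saver.
import os
import tensorflow as tf

def get_checkpoint_restore_saver(checkpoint, var_list=None, skip_global_step=False,
                                 restore_to_checkpoint_mapping=None, restore_scope=None):
    if os.path.isdir(checkpoint):
        # accept a directory and pick the newest checkpoint inside it
        checkpoint = tf.train.latest_checkpoint(checkpoint)
    if var_list is None:
        var_list = tf.global_variables(scope=restore_scope)
    if skip_global_step:
        var_list = [v for v in var_list if v.op.name != 'global_step']
    # names actually present in the checkpoint, so variables it lacks are skipped
    checkpoint_var_names = set(name for name, _ in tf.train.list_variables(checkpoint))
    name_to_var = {}
    for var in var_list:
        if restore_to_checkpoint_mapping:
            # the mapping receives the full tensor name (e.g. 'view0/foo/bar:0') and
            # returns the name the variable is stored under in the checkpoint
            name = restore_to_checkpoint_mapping(var.name)
        else:
            name = var.op.name
        if name in checkpoint_var_names:
            name_to_var[name] = var
    # baking in the filename allows sess.run(saver.saver_def.restore_op_name), as above
    saver = tf.train.Saver(var_list=name_to_var, filename=checkpoint)
    return saver, checkpoint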