Example #1
    def load_weights(self, log_dir=None, type='latest', **kwargs):
        """
        Load weights into the model
        :param log_dir: experiment directory into which all the checkpoints have been written
        :param type: can be 'latest', 'best_dice', 'best_loss', 'best_ged', or 'iter' (specific iteration,
                     requires passing the 'iteration' argument with a valid step number from the checkpoint files)
        """

        if not log_dir:
            log_dir = self.log_dir

        if type == 'latest':
            init_checkpoint_path = tfutils.get_latest_model_checkpoint_path(
                log_dir, 'model.ckpt')
        elif type == 'best_dice':
            init_checkpoint_path = tfutils.get_latest_model_checkpoint_path(
                log_dir, 'model_best_dice.ckpt')
        elif type == 'best_loss':
            init_checkpoint_path = tfutils.get_latest_model_checkpoint_path(
                log_dir, 'model_best_loss.ckpt')
        elif type == 'best_ged':
            init_checkpoint_path = tfutils.get_latest_model_checkpoint_path(
                log_dir, 'model_best_ged.ckpt')
        elif type == 'iter':
            assert 'iteration' in kwargs, "argument 'iteration' must be provided for type='iter'"
            iteration = kwargs['iteration']
            init_checkpoint_path = os.path.join(log_dir,
                                                'model.ckpt-%d' % iteration)
        else:
            raise ValueError(
                'Argument type=%s is unknown. type can be '
                'latest/best_dice/best_loss/best_ged/iter.' % type)

        self.saver.restore(self.sess, init_checkpoint_path)
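
For reference, a minimal usage sketch of the method above; the model class name, the experiment directory, and the step number are hypothetical placeholders, not taken from the example:

# Hypothetical usage; 'SegModel', the directory, and the step are illustrative only.
model = SegModel(exp_config)

# Restore the most recent checkpoint from the model's own log_dir.
model.load_weights()

# Restore the checkpoint with the best validation Dice score from an explicit directory.
model.load_weights(log_dir='logs/exp1', type='best_dice')

# Restore a specific iteration; requires a 'model.ckpt-42000' checkpoint in log_dir.
model.load_weights(type='iter', iteration=42000)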
Example #2
    def load_weights(self, log_dir=None, type='latest', **kwargs):
        """
        Load weights into the model
        :param log_dir: experiment directory into which all the checkpoints have been written
        :param type: can be 'latest', 'best_wasserstein' (highest validation Wasserstein distance), or 'iter' (specific
                     iteration, requires passing the iteration argument with a valid step number from the checkpoint 
                     files)
        """

        if not log_dir:
            log_dir = self.log_dir

        if type == 'latest':
            init_checkpoint_path = tf_utils.get_latest_model_checkpoint_path(
                log_dir, 'model.ckpt')
        elif type == 'best_wasserstein':
            init_checkpoint_path = tf_utils.get_latest_model_checkpoint_path(
                log_dir, 'model_best_wasserstein.ckpt')
        elif type == 'iter':
            assert 'iteration' in kwargs, "argument 'iteration' must be provided for type='iter'"
            iteration = kwargs['iteration']
            init_checkpoint_path = os.path.join(log_dir,
                                                'model.ckpt-%d' % iteration)
        else:
            raise ValueError(
                'Argument type=%s is unknown. type can be latest/best_wasserstein/iter.'
                % type)

        self.saver.restore(self.sess, init_checkpoint_path)
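
Both examples delegate checkpoint discovery to tf_utils.get_latest_model_checkpoint_path (tfutils in Example #1), whose source is not shown here. The sketch below is an assumption reconstructed from the call sites: it takes (log_dir, checkpoint_name), returns the path of the newest matching TF1 checkpoint, and returns False when none exists (see the 'is not False' check in Example #3):

import glob
import os

def get_latest_model_checkpoint_path(log_dir, name):
    # ASSUMPTION: reconstructed from usage, not the actual tf_utils code.
    # A TF1 Saver writes '<name>-<step>.index' (plus data files) into log_dir.
    index_files = glob.glob(os.path.join(log_dir, name + '-*.index'))
    if not index_files:
        return False  # no checkpoint found; callers test 'is not False'

    def step_of(path):
        # '.../model.ckpt-1234.index' -> 1234
        return int(os.path.basename(path).split('-')[-1].split('.')[0])

    # Saver.restore expects the checkpoint prefix without the '.index' suffix.
    return max(index_files, key=step_of)[:-len('.index')]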
Example #3
    def _setup_log_dir_and_continue_mode(self):

        # Default values
        self.log_dir = os.path.join(sys_config.log_root, 'classifier', self.exp_config.experiment_name)
        self.init_checkpoint_path = None
        self.continue_run = False
        self.init_step = 0

        # If a checkpoint file already exists, enable continue mode
        if tf.gfile.Exists(self.log_dir):
            init_checkpoint_path = tf_utils.get_latest_model_checkpoint_path(self.log_dir, 'model.ckpt')
            # get_latest_model_checkpoint_path returns False when no matching checkpoint exists
            if init_checkpoint_path is not False:

                self.init_checkpoint_path = init_checkpoint_path
                self.continue_run = True
                # Recover the global step from the checkpoint filename (e.g. 'model.ckpt-1234' -> 1234)
                self.init_step = int(self.init_checkpoint_path.split('/')[-1].split('-')[-1])
                # Write the continued run's logs into a separate '<log_dir>_cont' directory
                self.log_dir += '_cont'

                logging.info('--------------------------- Continuing previous run --------------------------------')
                logging.info('Checkpoint path: %s' % self.init_checkpoint_path)
                logging.info('Latest step was: %d' % self.init_step)
                logging.info('------------------------------------------------------------------------------------')

        tf.gfile.MakeDirs(self.log_dir)
        self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

        # Copy experiment config file to log_dir for future reference
        shutil.copy(self.exp_config.__file__, self.log_dir)
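
To show where these fields are consumed, here is a hedged sketch of a training entry point built on the method above; the method name train, the initializer call, and num_steps are hypothetical, and only the attributes set by _setup_log_dir_and_continue_mode are taken from the example:

    def train(self):
        # Hypothetical skeleton (TF1); only self.log_dir, self.continue_run,
        # self.init_checkpoint_path and self.init_step come from the example above.
        self._setup_log_dir_and_continue_mode()

        self.sess.run(tf.global_variables_initializer())

        # Resume from the discovered checkpoint instead of starting fresh.
        if self.continue_run:
            self.saver.restore(self.sess, self.init_checkpoint_path)

        num_steps = 100000  # hypothetical training budget

        # Start counting at init_step so step numbers line up with the previous run.
        for step in range(self.init_step, num_steps):
            ...  # training iterations go here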