コード例 #1
0
 def __init__(self,
              sess,
              model,
              dataset,
              num_batch,
              train_opt,
              model_opt,
              outputs,
              step=None,
              loggers=None,
              phase_train=True,
              increment_step=False):
     """Set up a dataset-backed runner.

     Wraps `dataset` in a cycling, shuffling BatchIterator (batch size taken
     from `train_opt['batch_size']`, batches produced by `self.get_batch`)
     and passes it to the base runner.

     Args:
         sess: Session forwarded to the base runner.
         model: Model forwarded to the base runner.
         dataset: Object exposing `get_dataset_size()`; stored as
             `self.dataset`.
         num_batch: Number of batches per run, forwarded to the base runner.
         train_opt: Options dict; `'batch_size'` is read. Stored as
             `self.train_opt`.
         model_opt: Model options, stored as `self.model_opt`.
         outputs: Output names forwarded to the base runner.
         step: Step counter forwarded to the base runner. When omitted, a
             fresh `StepCounter(0)` is created for this instance.
         loggers: Optional loggers, stored as `self.loggers`.
         phase_train: Forwarded to the base runner.
         increment_step: Forwarded to the base runner.
     """
     # A `StepCounter(0)` default would be evaluated once, when the method is
     # defined, and then shared by every instance that omits `step`.
     if step is None:
         step = StepCounter(0)
     self.dataset = dataset
     self.loggers = loggers
     self.log = logger.get()
     self.model_opt = model_opt
     self.train_opt = train_opt
     # Called after the option attributes above are set; it may rely on them.
     self.input_variables = self.get_input_variables()
     num_ex = dataset.get_dataset_size()
     batch_iter = BatchIterator(num_ex,
                                batch_size=train_opt['batch_size'],
                                get_fn=self.get_batch,
                                cycle=True,
                                shuffle=True,
                                log_epoch=-1)
     super(Runner, self).__init__(sess,
                                  model,
                                  batch_iter,
                                  outputs,
                                  num_batch=num_batch,
                                  step=step,
                                  phase_train=phase_train,
                                  increment_step=increment_step)
コード例 #2
0
 def __init__(self,
              sess,
              model,
              dataset,
              train_opt,
              model_opt,
              step=None,
              loggers=None,
              steps_per_log=10):
   """Set up a training runner that fetches 'loss' and 'train_step'.

   Args:
     sess: Session forwarded to the base runner; also used to reset the
       global step when finetuning.
     model: Model forwarded to the base runner; `model['global_step']` is
       read when finetuning.
     dataset, train_opt, model_opt: Forwarded to the base runner.
     step: Step counter forwarded to the base runner. When omitted, a fresh
       `StepCounter(0)` is created for this instance.
     loggers: Optional loggers forwarded to the base runner.
     steps_per_log: Number of batches per run (forwarded as `num_batch`).
   """
   # A `StepCounter(0)` default would be evaluated once at definition time
   # and shared by every instance that omits `step`.
   if step is None:
     step = StepCounter(0)
   outputs = ['loss', 'train_step']
   num_batch = steps_per_log
   self.log = logger.get()
   # model_opt may omit 'finetune'; treat a missing key as "not finetuning"
   # instead of raising KeyError.
   if 'finetune' in model_opt and model_opt['finetune']:
     self.log.warning('Finetuning')
     # Reset the model's global step to 0 for the finetuning run.
     sess.run(tf.assign(model['global_step'], 0))
   super(Trainer, self).__init__(
       sess,
       model,
       dataset,
       num_batch,
       train_opt,
       model_opt,
       outputs,
       step=step,
       loggers=loggers,
       phase_train=True,
       increment_step=True)
コード例 #3
0
 def __init__(self,
              sess,
              model,
              dataset,
              train_opt,
              model_opt,
              logs_folder,
              step=None,
              split='train',
              phase_train=False):
     """Set up a one-batch runner fetching attention/box outputs.

     Args:
         sess, model, dataset, train_opt: Forwarded to the base runner.
         model_opt: Forwarded to the base runner and stored as
             `self.model_opt`.
         logs_folder: Stored as `self.logs_folder`.
         step: Step counter forwarded to the base runner. When omitted, a
             fresh `StepCounter(0)` is created for this instance.
         split: Split name, stored as `self.split`.
         phase_train: Forwarded to the base runner.
     """
     # A `StepCounter(0)` default would be evaluated once at definition time
     # and shared by every instance that omits `step`.
     if step is None:
         step = StepCounter(0)
     outputs = [
         'x_trans', 'y_gt_trans', 'attn_top_left', 'attn_bot_right',
         'attn_top_left_gt', 'attn_bot_right_gt', 'match_box', 's_out',
         'ctrl_rnn_glimpse_map'
     ]
     num_batch = 1
     self.split = split
     self.logs_folder = logs_folder
     self.model_opt = model_opt
     # get_loggers is called after the attributes above are set; it may
     # depend on them.
     loggers = self.get_loggers()
     super(Plotter, self).__init__(sess,
                                   model,
                                   dataset,
                                   num_batch,
                                   train_opt,
                                   model_opt,
                                   outputs,
                                   step=step,
                                   loggers=loggers,
                                   phase_train=phase_train,
                                   increment_step=False)
コード例 #4
0
 def __init__(self,
              sess,
              model,
              dataset,
              train_opt,
              model_opt,
              step=None,
              num_batch=10,
              loggers=None,
              phase_train=True):
   """Set up an evaluation runner over loss, accuracy and dice outputs.

   Args:
     sess, model, dataset, train_opt, model_opt: Forwarded to the base
       runner.
     step: Step counter forwarded to the base runner. When omitted, a fresh
       `StepCounter(0)` is created for this instance.
     num_batch: Number of batches per run.
     loggers: Optional loggers forwarded to the base runner.
     phase_train: Forwarded to the base runner.
   """
   # A `StepCounter(0)` default would be evaluated once at definition time
   # and shared by every instance that omits `step`.
   if step is None:
     step = StepCounter(0)
   outputs = [
       'loss', 'conf_loss', 'segm_loss', 'count_acc', 'dic', 'dic_abs',
       'learn_rate', 'box_loss', 'gt_knob_prob_box', 'gt_knob_prob_segm'
   ]
   super(Evaluator, self).__init__(
       sess,
       model,
       dataset,
       num_batch,
       train_opt,
       model_opt,
       outputs,
       step=step,
       loggers=loggers,
       phase_train=phase_train,
       increment_step=False)
コード例 #5
0
    def __init__(self, name, opt, data_opt=None, model_opt=None, seed=1234):
        """Create or restore an experiment, then build session, model, data.

        Args:
            name: Experiment name, stored as `self.name`.
            opt: Run options. Keys read here: `'restore'` (checkpoint folder
                to restore from; falsy for a fresh run), `'model_id'` and
                `'results'`.
            data_opt: Data options. Required for a fresh run. When restoring,
                overrides the checkpointed data options if not None.
            model_opt: Model options. Required for a fresh run. When
                restoring, only a truthy `'finetune'` entry is honoured.
            seed: Accepted for interface compatibility; this constructor does
                not use it.

        Raises:
            ValueError: On a fresh run when `model_opt` or `data_opt` is None.
        """
        self.opt = opt
        self.name = name
        self.new_model_opt = None
        if self.opt['restore']:
            # Load model/data options, step and model id from the checkpoint.
            self.restore_options(opt, data_opt)
            finetune = (model_opt is not None and 'finetune' in model_opt and
                        model_opt['finetune'])
            if finetune:
                # Finetuning starts a new experiment (fresh step counter, new
                # model id and results folder) from the restored weights.
                self.model_opt['finetune'] = model_opt['finetune']
                self.new_model_opt = model_opt
                self.step.reset()
                self.model_id = self.get_model_id()
                self.exp_folder = os.path.join(self.opt['results'],
                                               self.model_id)
                self.saver = Saver(self.exp_folder,
                                   model_opt=self.model_opt,
                                   data_opt=self.data_opt)
            else:
                # Plain restore keeps working in the original folder. (This
                # assignment used to run unconditionally and clobbered the new
                # finetune folder, splitting it from the new Saver's folder.)
                self.exp_folder = opt['restore']
        else:
            if model_opt is None or data_opt is None:
                raise ValueError(
                    'You need to specify model options and data options')
            if self.opt['model_id']:
                self.model_id = self.opt['model_id']
            else:
                self.model_id = self.get_model_id()
            self.model_opt = model_opt
            self.data_opt = data_opt
            self.step = StepCounter()
            self.exp_folder = os.path.join(self.opt['results'], self.model_id)
            self.saver = Saver(self.exp_folder,
                               model_opt=self.model_opt,
                               data_opt=self.data_opt)

        self.init_cmd_logger()

        self.sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True))

        # Log arguments
        self.log.log_args()

        # Train loop options
        self.log.info('Building model')
        self.model = self.get_model()

        # Load dataset
        self.log.info('Loading dataset')
        self.dataset_name = self.data_opt['dataset']
        self.dataset = self.get_dataset()

        self.init_model()
        self.init_logs()
コード例 #6
0
 def restore_options(self, opt, data_opt):
   """Populate experiment state from the checkpoint in `opt['restore']`.

   Sets `self.saver`, `self.ckpt_info`, `self.model_opt`, `self.data_opt`,
   `self.ckpt_fname`, `self.step` and `self.model_id`.

   Args:
     opt: Options dict; `opt['restore']` is the checkpoint folder.
     data_opt: Data options that override the checkpointed ones, or None to
       reuse the data options stored in the checkpoint.
   """
   self.saver = Saver(opt['restore'])
   info = self.saver.get_ckpt_info()
   self.ckpt_info = info
   self.model_opt = info['model_opt']
   self.data_opt = info['data_opt'] if data_opt is None else data_opt
   self.ckpt_fname = info['ckpt_fname']
   self.step = StepCounter(info['step'])
   self.model_id = info['model_id']
コード例 #7
0
ファイル: runner.py プロジェクト: Khoa-NT/rec-attend-public
 def __init__(self,
              sess,
              model,
              batch_iter,
              outputs,
              num_batch=1,
              step=None,
              phase_train=True,
              increment_step=False):
     """Store the state the runner needs; no work is done here.

     Args:
         sess: Session, stored as `self._sess`.
         model: Model, stored as `self._model`.
         batch_iter: Batch iterator, stored as `self._batch_iter`.
         outputs: Output names, stored as `self._outputs`.
         num_batch: Batches per run, stored as `self._num_batch`.
         step: Step counter, stored as `self._step`. When omitted, a fresh
             `StepCounter(0)` is created for this instance.
         phase_train: Stored as `self._phase_train`.
         increment_step: Stored as `self._increment_step`.
     """
     # A `StepCounter(0)` default would be evaluated once at definition time
     # and shared by every runner that omits `step`.
     if step is None:
         step = StepCounter(0)
     self._sess = sess
     self._model = model
     self._batch_iter = batch_iter
     self._num_batch = num_batch
     self._phase_train = phase_train
     self._step = step
     self._outputs = outputs
     self._current_batch = {}
     self._log = logger.get()
     self._increment_step = increment_step
コード例 #8
0
 def __init__(self,
              sess,
              model,
              dataset,
              num_batch,
              train_opt,
              model_opt,
              outputs,
              step=None,
              loggers=None,
              phase_train=True,
              increment_step=False):
     """Set up a dataset-backed runner, optionally with prefetching.

     Builds a cycling, shuffling BatchIterator over `dataset`. When
     `train_opt['prefetch']` is truthy, the iterator is wrapped in a
     ConcurrentBatchIterator sized by `train_opt['queue_size']` and
     `train_opt['num_worker']`.

     Args:
         sess, model: Forwarded to the base runner.
         dataset: Object exposing `get_dataset_size()`; stored as
             `self.dataset`.
         num_batch: Number of batches per run, forwarded to the base runner.
         train_opt: Options dict; reads `'batch_size'`, `'prefetch'` and,
             when prefetching, `'queue_size'` and `'num_worker'`.
         model_opt: Reads `'add_orientation'` and
             `'num_orientation_classes'`.
         outputs: Output names forwarded to the base runner.
         step: Step counter forwarded to the base runner. When omitted, a
             fresh `StepCounter(0)` is created for this instance.
         loggers: Optional loggers, stored as `self.loggers`.
         phase_train: Forwarded to the base runner.
         increment_step: Forwarded to the base runner.
     """
     # A `StepCounter(0)` default would be evaluated once at definition time
     # and shared by every instance that omits `step`.
     if step is None:
         step = StepCounter(0)
     self.dataset = dataset
     self.log = logger.get()
     self.loggers = loggers
     self.add_orientation = model_opt['add_orientation']
     self.num_orientation_classes = model_opt['num_orientation_classes']
     # Called after the attributes above are set; it may rely on them.
     self.input_variables = self.get_input_variables()
     num_ex = dataset.get_dataset_size()
     batch_iter = BatchIterator(num_ex,
                                batch_size=train_opt['batch_size'],
                                get_fn=self.get_batch,
                                cycle=True,
                                progress_bar=False,
                                shuffle=True,
                                log_epoch=-1)
     if train_opt['prefetch']:
         batch_iter = ConcurrentBatchIterator(
             batch_iter,
             max_queue_size=train_opt['queue_size'],
             num_threads=train_opt['num_worker'],
             log_queue=-1)
     super(Runner, self).__init__(sess,
                                  model,
                                  batch_iter,
                                  outputs,
                                  num_batch=num_batch,
                                  step=step,
                                  phase_train=phase_train,
                                  increment_step=increment_step)
コード例 #9
0
 def __init__(self,
              sess,
              model,
              dataset,
              train_opt,
              model_opt,
              step=None,
              loggers=None,
              steps_per_log=10):
     """Set up a training runner that fetches 'loss' and 'train_step'.

     Args:
         sess, model, dataset, train_opt, model_opt: Forwarded to the base
             runner.
         step: Step counter forwarded to the base runner. When omitted, a
             fresh `StepCounter(0)` is created for this instance.
         loggers: Optional loggers forwarded to the base runner.
         steps_per_log: Number of batches per run (forwarded as
             `num_batch`).
     """
     # A `StepCounter(0)` default would be evaluated once at definition time
     # and shared by every instance that omits `step`.
     if step is None:
         step = StepCounter(0)
     outputs = ['loss', 'train_step']
     num_batch = steps_per_log
     super(Trainer, self).__init__(sess,
                                   model,
                                   dataset,
                                   num_batch,
                                   train_opt,
                                   model_opt,
                                   outputs,
                                   step=step,
                                   loggers=loggers,
                                   phase_train=True,
                                   increment_step=True)
コード例 #10
0
 def __init__(self,
              sess,
              model,
              dataset,
              train_opt,
              model_opt,
              step=None,
              num_batch=10,
              loggers=None,
              phase_train=True):
     """Set up an evaluation runner over 'loss', 'box_loss', 'conf_loss'.

     Args:
         sess, model, dataset, train_opt, model_opt: Forwarded to the base
             runner.
         step: Step counter forwarded to the base runner. When omitted, a
             fresh `StepCounter(0)` is created for this instance.
         num_batch: Number of batches per run.
         loggers: Optional loggers forwarded to the base runner.
         phase_train: Forwarded to the base runner.
     """
     # A `StepCounter(0)` default would be evaluated once at definition time
     # and shared by every instance that omits `step`.
     if step is None:
         step = StepCounter(0)
     outputs = ['loss', 'box_loss', 'conf_loss']
     super(Evaluator, self).__init__(sess,
                                     model,
                                     dataset,
                                     num_batch,
                                     train_opt,
                                     model_opt,
                                     outputs,
                                     step=step,
                                     loggers=loggers,
                                     phase_train=phase_train,
                                     increment_step=False)
コード例 #11
0
 def __init__(self,
              sess,
              model,
              dataset,
              train_opt,
              model_opt,
              logs_folder,
              step=None,
              split='train',
              phase_train=False):
   """Set up a one-batch runner fetching segmentation/attention outputs.

   Args:
     sess, model, dataset, train_opt: Forwarded to the base runner.
     model_opt: Forwarded to the base runner and stored as `self.model_opt`.
     logs_folder: Stored as `self.logs_folder`.
     step: Step counter forwarded to the base runner. When omitted, a fresh
       `StepCounter(0)` is created for this instance.
     split: Split name, stored as `self.split`.
     phase_train: Forwarded to the base runner.
   """
   # A `StepCounter(0)` default would be evaluated once at definition time
   # and shared by every instance that omits `step`.
   if step is None:
     step = StepCounter(0)
   # Each output name is listed once ('s_out' used to appear twice, which
   # requested the same output a second time).
   outputs = [
       'x_trans', 'y_gt_trans', 'y_out', 's_out', 'match', 'attn_top_left',
       'attn_bot_right', 'match_box', 'x_patch', 'ctrl_rnn_glimpse_map'
   ]
   num_batch = 1
   self.split = split
   self.logs_folder = logs_folder
   self.model_opt = model_opt
   # get_loggers is called after the attributes above are set; it may depend
   # on them.
   loggers = self.get_loggers()
   # 8-entry RGB palette (uint8), stored for use elsewhere in the class.
   self.color_wheel = np.array(
       [[255, 17, 0], [255, 137, 0], [230, 255, 0], [34, 255, 0],
        [0, 255, 213], [0, 154, 255], [9, 0, 255], [255, 0, 255]],
       dtype='uint8')
   super(Plotter, self).__init__(
       sess,
       model,
       dataset,
       num_batch,
       train_opt,
       model_opt,
       outputs,
       step=step,
       loggers=loggers,
       phase_train=phase_train,
       increment_step=False)
コード例 #12
0
 def __init__(self,
              sess,
              model,
              dataset,
              train_opt,
              model_opt,
              logs_folder,
              step=None,
              split='train',
              phase_train=False):
     """Set up a one-batch runner fetching segmentation outputs.

     When `model_opt['add_orientation']` is truthy, 'd_out' and
     'd_gt_trans' are fetched as well.

     Args:
         sess, model, dataset, train_opt: Forwarded to the base runner.
         model_opt: Forwarded to the base runner; `'add_orientation'` is
             read here.
         logs_folder: Stored as `self.logs_folder`.
         step: Step counter forwarded to the base runner. When omitted, a
             fresh `StepCounter(0)` is created for this instance.
         split: Split name, stored as `self.split` and passed to
             `get_loggers`.
         phase_train: Forwarded to the base runner.
     """
     # A `StepCounter(0)` default would be evaluated once at definition time
     # and shared by every instance that omits `step`.
     if step is None:
         step = StepCounter(0)
     outputs = ['x_trans', 'y_gt_trans', 'y_out']
     if model_opt['add_orientation']:
         outputs.extend(['d_out', 'd_gt_trans'])
     num_batch = 1
     self.split = split
     self.logs_folder = logs_folder
     # 8-entry RGB palette (uint8).
     self.ori_color_wheel = np.array(
         [[255, 17, 0], [255, 137, 0], [230, 255, 0], [34, 255, 0],
          [0, 255, 213], [0, 154, 255], [9, 0, 255], [255, 0, 255]],
         dtype='uint8')
     # Same palette with black prepended at index 0 (9 entries, uint8).
     self.sem_color_wheel = np.array(
         [[0, 0, 0], [255, 17, 0], [255, 137, 0], [230, 255, 0],
          [34, 255, 0], [0, 255, 213], [0, 154, 255], [9, 0, 255],
          [255, 0, 255]],
         dtype='uint8')
     loggers = self.get_loggers(model_opt['add_orientation'], split)
     super(Plotter, self).__init__(sess,
                                   model,
                                   dataset,
                                   num_batch,
                                   train_opt,
                                   model_opt,
                                   outputs,
                                   step=step,
                                   loggers=loggers,
                                   phase_train=phase_train,
                                   increment_step=False)
コード例 #13
0
 def __init__(self,
              sess,
              model,
              dataset,
              train_opt,
              model_opt,
              step=None,
              num_batch=10,
              loggers=None,
              phase_train=True):
     """Set up an evaluation runner over IoU and foreground-loss outputs.

     When `model_opt['add_orientation']` is truthy, 'orientation_ce' and
     'orientation_acc' are fetched as well.

     Args:
         sess, model, dataset, train_opt: Forwarded to the base runner.
         model_opt: Forwarded to the base runner; `'add_orientation'` is
             read here.
         step: Step counter forwarded to the base runner. When omitted, a
             fresh `StepCounter(0)` is created for this instance.
         num_batch: Number of batches per run.
         loggers: Optional loggers forwarded to the base runner.
         phase_train: Forwarded to the base runner.
     """
     # A `StepCounter(0)` default would be evaluated once at definition time
     # and shared by every instance that omits `step`.
     if step is None:
         step = StepCounter(0)
     outputs = ['iou_soft', 'iou_hard', 'foreground_loss', 'loss']
     if model_opt['add_orientation']:
         outputs.extend(['orientation_ce', 'orientation_acc'])
     super(Evaluator, self).__init__(sess,
                                     model,
                                     dataset,
                                     num_batch,
                                     train_opt,
                                     model_opt,
                                     outputs,
                                     step=step,
                                     loggers=loggers,
                                     phase_train=phase_train,
                                     increment_step=False)