コード例 #1
0
ファイル: ris_old.py プロジェクト: ai3DVision/img-count
    def train_loop(step=0):
        """Run the training loop, starting from global step `step`.

        Batches come from a BatchIterator built with cycle=True.  On every
        iteration the periodic hooks fire in this order: validation, sample
        plotting, the training step itself, the model-id reminder and the
        checkpoint save.  The counter is then advanced and the loop exits
        once it exceeds train_opt['num_steps'].
        """

        def _is_due(cur_step, period):
            # True when `cur_step` lands exactly on a multiple of `period`.
            return cur_step % period == 0

        train_batches = BatchIterator(num_ex_train,
                                      batch_size=batch_size,
                                      get_fn=get_batch_train,
                                      cycle=True,
                                      progress_bar=False)

        for batch in train_batches:
            x_cur, y_cur, s_cur = batch

            # Periodic evaluation hooks run before the update for this step.
            if _is_due(step, train_opt['steps_per_valid']):
                run_validation()
            if _is_due(step, train_opt['steps_per_plot']):
                run_samples()

            # One optimisation step on the current mini-batch.
            train_step(step, x_cur, y_cur, s_cur)

            # Remind which model is running, at a tenth of the log frequency.
            if _is_due(step, 10 * train_opt['steps_per_log']):
                log.info('model id {}'.format(model_id))

            # Checkpoint only when enabled on the command line.
            if args.save_ckpt and _is_due(step, train_opt['steps_per_ckpt']):
                saver.save(sess, global_step=step)

            step += 1

            # Stop once the step budget has been exceeded.
            if step > train_opt['num_steps']:
                break
コード例 #2
0
 def __init__(self,
              sess,
              model,
              dataset,
              num_batch,
              train_opt,
              model_opt,
              outputs,
              step=None,
              loggers=None,
              phase_train=True,
              increment_step=False):
     """Set up a runner that feeds shuffled, cycling batches of `dataset`.

     Args:
         sess: forwarded unchanged to the base-class constructor.
         model: forwarded unchanged to the base-class constructor.
         dataset: object providing get_dataset_size(); stored on the
             instance.
         num_batch: forwarded to the base-class constructor.
         train_opt: training options; train_opt['batch_size'] sets the
             mini-batch size of the iterator built here.
         model_opt: model options; stored on the instance.
         outputs: forwarded unchanged to the base-class constructor.
         step: step counter forwarded to the base class.  When omitted, a
             fresh StepCounter(0) is created for this instance.
         loggers: stored on the instance.
         phase_train: forwarded to the base-class constructor.
         increment_step: forwarded to the base-class constructor.
     """
     # A default of StepCounter(0) in the signature would be evaluated once
     # at definition time and shared by every instance that relies on it;
     # create the counter per instance instead.
     if step is None:
         step = StepCounter(0)

     self.dataset = dataset
     self.loggers = loggers
     self.log = logger.get()
     self.model_opt = model_opt
     self.train_opt = train_opt
     # Called after the option dicts are stored on the instance.
     # NOTE(review): presumably get_input_variables() reads them -- confirm.
     self.input_variables = self.get_input_variables()
     num_ex = dataset.get_dataset_size()
     batch_iter = BatchIterator(num_ex,
                                batch_size=train_opt['batch_size'],
                                get_fn=self.get_batch,
                                cycle=True,
                                shuffle=True,
                                log_epoch=-1)
     super(Runner, self).__init__(sess,
                                  model,
                                  batch_iter,
                                  outputs,
                                  num_batch=num_batch,
                                  step=step,
                                  phase_train=phase_train,
                                  increment_step=increment_step)
コード例 #3
0
def run_eval(sess, m, dataset, batch_size=10, fname=None, cvppp_test=False):
    """Run evaluation over a whole dataset.

    Builds one StageAnalyzer per metric, iterates the dataset once in
    mini-batches (no cycling, with a progress bar) and delegates the actual
    evaluation to _run_eval.

    Args:
        sess: tensorflow session
        m: model
        dataset: dataset object; get_dataset() must return a dict whose
            'input' entry exposes .shape[0] as the number of examples
        batch_size: mini-batch size used for the evaluation pass
        fname: output report filename, passed to every analyzer
        cvppp_test: whether in test mode of CVPPP dataset; in that mode only
            the foreground metrics (FG DICE, FG IOU) are reported
    """
    if cvppp_test:
        metrics = [
            ('FG DICE', f_fg_dice),
            ('FG IOU', f_fg_iou),
        ]
    else:
        metrics = [
            ('IOU', f_ins_iou),
            ('SBD', f_symmetric_best_dice),
            ('WT COV', f_wt_coverage),
            ('UNWT COV', f_unwt_coverage),
            ('FG DICE', f_fg_dice),
            ('FG IOU', f_fg_iou),
            ('COUNT ACC', f_count_acc),
            ('DIC', f_dic),
            ('|DIC|', f_dic_abs),
        ]
    analyzers = [
        StageAnalyzer(label, metric_fn, fname=fname)
        for label, metric_fn in metrics
    ]

    data = dataset.get_dataset()
    num_ex = data['input'].shape[0]
    # Honour the caller's batch_size; it used to be overwritten with a
    # hard-coded 10 here, which made the parameter a no-op.
    batch_iter = BatchIterator(num_ex,
                               batch_size=batch_size,
                               get_fn=get_batch_fn(data),
                               cycle=False,
                               progress_bar=True)
    _run_eval(sess, m, dataset, batch_iter, analyzers)
コード例 #4
0
                                                    'train_ce.csv'),
                                       'train_ce',
                                       buffer_size=25)
    valid_ce_logger = TimeSeriesLogger(os.path.join(exp_logs_folder,
                                                    'valid_ce.csv'),
                                       'valid_ce',
                                       buffer_size=2)
    log.info('Curves can be viewed at: http://{}/visualizer?id={}'.format(
        args.localhost, model_id))

    step = 0
    while step < loop_config['num_steps']:
        # Validation
        valid_ce = 0
        for st, nd in BatchIterator(num_ex_val,
                                    batch_size=64,
                                    progress_bar=False):
            inp_batch = inp_all_val[st:nd]
            lab_seg_batch = lab_seg_all_val[st:nd]
            lab_obj_batch = lab_obj_all_val[st:nd]
            inp_batch, lab_seg_batch, lab_obj_batch = preprocess(
                inp_batch, lab_seg_batch, lab_obj_batch)
            vce = sess.run(m['total_err'],
                           feed_dict={
                               m['inp']: inp_batch,
                               m['segm_gt']: lab_seg_batch,
                               m['obj_gt']: lab_obj_batch
                           })
            valid_ce += vce * inp_batch.shape[0] / float(num_ex_val)

        log.info('{:d} valid ce: {:.4f}'.format(step, valid_ce))
コード例 #5
0
ファイル: ris_box.py プロジェクト: ai3DVision/img-count
    def train_loop(step=0):
        """Run the training loop, starting from global step `step`.

        Three iterators are used: one over the validation split (only when
        train_opt['has_valid'] is set), one over the training split for the
        periodic "train validation" statistics, and one over the training
        split that drives the optimisation steps.  All are built with
        cycle=True.  The loop exits once the step counter exceeds
        train_opt['num_steps'].
        """

        def _cycling_iter(num_ex, get_fn):
            # All iterators here share the same batch size and flags.
            return BatchIterator(num_ex,
                                 batch_size=batch_size,
                                 get_fn=get_fn,
                                 cycle=True,
                                 progress_bar=False)

        def _is_due(cur_step, period):
            # True when `cur_step` lands exactly on a multiple of `period`.
            return cur_step % period == 0

        # Validation resources exist only when a validation split is present.
        if train_opt['has_valid']:
            batch_iter_valid = _cycling_iter(num_ex_valid, get_batch_valid)
            outputs_valid = get_outputs_valid()
        num_batch_valid = trainer.get_num_batch_valid(args.dataset)
        batch_iter_trainval = _cycling_iter(num_ex_train, get_batch_train)
        outputs_trainval = get_outputs_trainval()

        for batch in _cycling_iter(num_ex_train, get_batch_train):
            x_cur, y_cur, s_cur = batch

            # Statistics on the validation split.
            if train_opt['has_valid'] and _is_due(
                    step, train_opt['steps_per_valid']):
                log.info('Running validation')
                trainer.run_stats(step, sess, m, num_batch_valid,
                                  batch_iter_valid, outputs_valid,
                                  write_log_valid(loggers), False)

            # The same statistics on the training split.
            if _is_due(step, train_opt['steps_per_trainval']):
                log.info('Running train validation')
                trainer.run_stats(step, sess, m, num_batch_valid,
                                  batch_iter_trainval, outputs_trainval,
                                  write_log_trainval(loggers), True)

            # Sample plots are skipped at step 0.
            if _is_due(step, train_opt['steps_per_plot']) and step > 0:
                run_samples()

            # One optimisation step on the current mini-batch.
            train_step(step, x_cur, y_cur, s_cur)

            # Remind which model is running, at a tenth of the log frequency.
            if _is_due(step, 10 * train_opt['steps_per_log']):
                log.info('model id {}'.format(model_id))

            # Checkpoint only when enabled on the command line.
            if args.save_ckpt and _is_due(step, train_opt['steps_per_ckpt']):
                saver.save(sess, global_step=step)

            step += 1

            # Stop once the step budget has been exceeded.
            if step > train_opt['num_steps']:
                break
コード例 #6
0
ファイル: ris_old.py プロジェクト: ai3DVision/img-count
    def run_validation():
        """Evaluate the model on the validation split and log the averages.

        Every mini-batch contributes to each statistic with weight
        (examples in batch) / num_ex_valid, so the totals are example-weighted
        means over the whole split.  Besides the scalar losses and metrics,
        four batch-norm statistics (bm, bv, em, ev) are fetched for every CNN
        and DCNN layer and averaged per layer.

        The averages are written to the log and, when args.logs is set, to
        the loss / IoU / count-accuracy / batch-norm loggers.
        """
        loss = 0.0
        iou_hard = 0.0
        iou_soft = 0.0
        count_acc = 0.0
        segm_loss = 0.0
        conf_loss = 0.0
        num_cnn = len(model_opt['cnn_filter_size'])
        num_dcnn = len(model_opt['dcnn_filter_size'])
        cnn_bm = [0.0] * num_cnn
        cnn_bv = [0.0] * num_cnn
        cnn_em = [0.0] * num_cnn
        cnn_ev = [0.0] * num_cnn
        dcnn_bm = [0.0] * num_dcnn
        dcnn_bv = [0.0] * num_dcnn
        dcnn_em = [0.0] * num_dcnn
        dcnn_ev = [0.0] * num_dcnn
        log.info('Running validation')

        # The fetch list does not depend on the batch, so build it once.
        # Layout: 6 scalars, then 4 entries per CNN layer, then 4 per DCNN
        # layer.
        results_list = [
            m['loss'], m['segm_loss'], m['conf_loss'], m['iou_soft'],
            m['iou_hard'], m['count_acc']
        ]
        num_scalar = len(results_list)
        for ii in xrange(num_cnn):
            results_list.append(m['cnn_{}_bm'.format(ii)])
            results_list.append(m['cnn_{}_bv'.format(ii)])
            results_list.append(m['cnn_{}_em'.format(ii)])
            results_list.append(m['cnn_{}_ev'.format(ii)])
        for ii in xrange(num_dcnn):
            results_list.append(m['dcnn_{}_bm'.format(ii)])
            results_list.append(m['dcnn_{}_bv'.format(ii)])
            results_list.append(m['dcnn_{}_em'.format(ii)])
            results_list.append(m['dcnn_{}_ev'.format(ii)])

        for _x, _y, _s in BatchIterator(num_ex_valid,
                                        batch_size=batch_size,
                                        get_fn=get_batch_valid,
                                        progress_bar=False):
            results = sess.run(results_list,
                               feed_dict={
                                   m['x']: _x,
                                   m['phase_train']: False,
                                   m['y_gt']: _y,
                                   m['s_gt']: _s
                               })

            num_ex_batch = _x.shape[0]
            loss += results[0] * num_ex_batch / num_ex_valid
            segm_loss += results[1] * num_ex_batch / num_ex_valid
            conf_loss += results[2] * num_ex_batch / num_ex_valid
            iou_soft += results[3] * num_ex_batch / num_ex_valid
            iou_hard += results[4] * num_ex_batch / num_ex_valid
            count_acc += results[5] * num_ex_batch / num_ex_valid

            # Accumulate each layer from its own slice of `results`.
            # Unpacking all layers first and accumulating afterwards would
            # credit every layer with the last layer's statistics.
            offset = num_scalar
            for ii in xrange(num_cnn):
                cnn_bm[ii] += results[offset] * num_ex_batch / num_ex_valid
                cnn_bv[ii] += (results[offset + 1] * num_ex_batch /
                               num_ex_valid)
                cnn_em[ii] += (results[offset + 2] * num_ex_batch /
                               num_ex_valid)
                cnn_ev[ii] += (results[offset + 3] * num_ex_batch /
                               num_ex_valid)
                offset += 4

            for ii in xrange(num_dcnn):
                dcnn_bm[ii] += results[offset] * num_ex_batch / num_ex_valid
                dcnn_bv[ii] += (results[offset + 1] * num_ex_batch /
                                num_ex_valid)
                dcnn_em[ii] += (results[offset + 2] * num_ex_batch /
                                num_ex_valid)
                dcnn_ev[ii] += (results[offset + 3] * num_ex_batch /
                                num_ex_valid)
                offset += 4

        log.info(('{:d} valid loss {:.4f} segm_loss {:.4f} conf_loss {:.4f} '
                  'iou soft {:.4f} iou hard {:.4f} count acc {:.4f}').format(
                      step, loss, segm_loss, conf_loss, iou_soft, iou_hard,
                      count_acc))

        if args.logs:
            # Empty strings are written as-is alongside the validation
            # values.
            # NOTE(review): presumably they stand in for training-side
            # columns of the same CSV rows -- confirm against the logger
            # setup.
            loss_logger.add(step, ['', loss])
            iou_logger.add(step, ['', iou_soft, '', iou_hard])
            count_acc_logger.add(step, ['', count_acc])
            for ii in xrange(num_cnn):
                cnn_bn_loggers[ii].add(
                    step, ['', cnn_bm[ii], '', cnn_bv[ii], '', ''])
            for ii in xrange(num_dcnn):
                dcnn_bn_loggers[ii].add(
                    step, ['', dcnn_bm[ii], '', dcnn_bv[ii], '', ''])
コード例 #7
0
  def reset(self):
    """Stop all fetchers, empty the queue, rewind the iterator and restart.

    The statement order below matters: fetchers are asked to stop first,
    a consumer is started on the queue before the fetchers are joined, and
    the underlying iterator is only reset once the queue has been joined.
    """
    self.info("Resetting concurrent batch iter")
    self.info("Stopping all workers")
    # Signal every fetcher to stop; they are joined further down.
    for f in self.fetchers:
      f.stop()
    self.info("Cleaning queue")
    # NOTE(review): presumably the consumer drains self.q so that fetchers
    # blocked on a full queue can finish -- confirm against BatchConsumer.
    cleaner = BatchConsumer(self.q)
    cleaner.start()
    # Wait for every fetcher to terminate.
    for f in self.fetchers:
      f.join()
    # Block until the queue reports all of its items as processed.
    self.q.join()
    cleaner.stop()
    self.info("Resetting index")
    self.batch_iter.reset()
    self.info("Restarting workers")
    # Drop the old fetcher objects and create a new set.
    self.fetchers = []
    self.init_fetchers()
    # Flag read elsewhere in the class (not visible in this excerpt).
    self.relaunch = True
    pass

  pass


if __name__ == "__main__":
  # Manual smoke run: wrap a plain BatchIterator over 100 examples in a
  # ConcurrentBatchIterator and log every item it yields.
  from batch_iter import BatchIterator
  base_iter = BatchIterator(100, batch_size=6, get_fn=None)
  concurrent_iter = ConcurrentBatchIterator(
      base_iter, max_queue_size=5, num_threads=3)
  for item in concurrent_iter:
    log = logger.get()
    message = ("Final out", item)
    log.info(message)