def train_loop(step=0):
    """Main training loop.

    Cycles through training batches forever, interleaving validation,
    sample plotting, checkpointing and a periodic model-id reminder at
    their configured step intervals, until `num_steps` is reached.
    """
    batch_source = BatchIterator(num_ex_train,
                                 batch_size=batch_size,
                                 get_fn=get_batch_train,
                                 cycle=True,
                                 progress_bar=False)
    for xb, yb, sb in batch_source:
        # Periodic validation and sample plotting.
        if step % train_opt['steps_per_valid'] == 0:
            run_validation()
        if step % train_opt['steps_per_plot'] == 0:
            run_samples()

        # One optimization step.
        train_step(step, xb, yb, sb)

        # Remind which model this is, every 10 log intervals.
        if step % (10 * train_opt['steps_per_log']) == 0:
            log.info('model id {}'.format(model_id))

        # Periodic checkpoint.
        if args.save_ckpt and step % train_opt['steps_per_ckpt'] == 0:
            saver.save(sess, global_step=step)

        step += 1
        if step > train_opt['num_steps']:
            break
def __init__(self, sess, model, dataset, num_batch, train_opt, model_opt,
             outputs, step=None, loggers=None, phase_train=True,
             increment_step=False):
    """Construct a Runner.

    Args:
        sess: tensorflow session.
        model: model dictionary/object passed through to the base runner.
        dataset: dataset object providing get_dataset_size().
        num_batch: number of batches to run per invocation.
        train_opt: training options dict; 'batch_size' is read here.
        model_opt: model options dict (stored for subclasses).
        outputs: graph outputs to fetch.
        step: StepCounter instance; a fresh StepCounter(0) is created
            when omitted.
        loggers: optional loggers dict.
        phase_train: whether the graph runs in training phase.
        increment_step: whether the base runner advances the step counter.
    """
    # BUG FIX: the default was `step=StepCounter(0)`, which is evaluated
    # once at function-definition time — every Runner constructed without
    # an explicit step would share (and mutate) the same counter object.
    # Use a None sentinel and build a fresh counter per instance instead.
    if step is None:
        step = StepCounter(0)
    self.dataset = dataset
    self.loggers = loggers
    self.log = logger.get()
    self.model_opt = model_opt
    self.train_opt = train_opt
    self.input_variables = self.get_input_variables()
    num_ex = dataset.get_dataset_size()
    batch_iter = BatchIterator(num_ex,
                               batch_size=train_opt['batch_size'],
                               get_fn=self.get_batch,
                               cycle=True,
                               shuffle=True,
                               log_epoch=-1)
    super(Runner, self).__init__(sess, model, batch_iter, outputs,
                                 num_batch=num_batch, step=step,
                                 phase_train=phase_train,
                                 increment_step=increment_step)
def run_eval(sess, m, dataset, batch_size=10, fname=None, cvppp_test=False):
    """Run evaluation.

    Args:
        sess: tensorflow session
        m: model
        dataset: dataset object
        batch_size: mini-batch size to run
        fname: output report filename
        cvppp_test: whether in test mode of CVPPP dataset
    """
    if not cvppp_test:
        analyzers = [
            StageAnalyzer('IOU', f_ins_iou, fname=fname),
            StageAnalyzer('SBD', f_symmetric_best_dice, fname=fname),
            StageAnalyzer('WT COV', f_wt_coverage, fname=fname),
            StageAnalyzer('UNWT COV', f_unwt_coverage, fname=fname),
            StageAnalyzer('FG DICE', f_fg_dice, fname=fname),
            StageAnalyzer('FG IOU', f_fg_iou, fname=fname),
            StageAnalyzer('COUNT ACC', f_count_acc, fname=fname),
            StageAnalyzer('DIC', f_dic, fname=fname),
            StageAnalyzer('|DIC|', f_dic_abs, fname=fname)
        ]
    else:
        # CVPPP test mode has no instance-level ground truth; only
        # foreground metrics are available.
        analyzers = [
            StageAnalyzer('FG DICE', f_fg_dice, fname=fname),
            StageAnalyzer('FG IOU', f_fg_iou, fname=fname)
        ]
    data = dataset.get_dataset()
    num_ex = data['input'].shape[0]
    # BUG FIX: removed `batch_size = 10`, which unconditionally clobbered
    # the caller-supplied batch_size argument, making the parameter useless.
    batch_iter = BatchIterator(num_ex,
                               batch_size=batch_size,
                               get_fn=get_batch_fn(data),
                               cycle=False,
                               progress_bar=True)
    _run_eval(sess, m, dataset, batch_iter, analyzers)
'train_ce.csv'), 'train_ce', buffer_size=25) valid_ce_logger = TimeSeriesLogger(os.path.join(exp_logs_folder, 'valid_ce.csv'), 'valid_ce', buffer_size=2) log.info('Curves can be viewed at: http://{}/visualizer?id={}'.format( args.localhost, model_id)) step = 0 while step < loop_config['num_steps']: # Validation valid_ce = 0 for st, nd in BatchIterator(num_ex_val, batch_size=64, progress_bar=False): inp_batch = inp_all_val[st:nd] lab_seg_batch = lab_seg_all_val[st:nd] lab_obj_batch = lab_obj_all_val[st:nd] inp_batch, lab_seg_batch, lab_obj_batch = preprocess( inp_batch, lab_seg_batch, lab_obj_batch) vce = sess.run(m['total_err'], feed_dict={ m['inp']: inp_batch, m['segm_gt']: lab_seg_batch, m['obj_gt']: lab_obj_batch }) valid_ce += vce * inp_batch.shape[0] / float(num_ex_val) log.info('{:d} valid ce: {:.4f}'.format(step, valid_ce))
def train_loop(step=0):
    """Train loop.

    Runs the training batch stream, periodically computing validation and
    train-validation statistics, plotting samples, checkpointing, and
    logging the model id, until `num_steps` is exceeded.
    """
    # Validation stream is only set up when a validation split exists.
    if train_opt['has_valid']:
        batch_iter_valid = BatchIterator(num_ex_valid,
                                         batch_size=batch_size,
                                         get_fn=get_batch_valid,
                                         cycle=True,
                                         progress_bar=False)
        outputs_valid = get_outputs_valid()
    num_batch_valid = trainer.get_num_batch_valid(args.dataset)
    # A second training-data stream, used to measure stats on train data.
    batch_iter_trainval = BatchIterator(num_ex_train,
                                        batch_size=batch_size,
                                        get_fn=get_batch_train,
                                        cycle=True,
                                        progress_bar=False)
    outputs_trainval = get_outputs_trainval()

    train_stream = BatchIterator(num_ex_train,
                                 batch_size=batch_size,
                                 get_fn=get_batch_train,
                                 cycle=True,
                                 progress_bar=False)
    for x_bat, y_bat, s_bat in train_stream:
        # Validation stats.
        if train_opt['has_valid'] and step % train_opt['steps_per_valid'] == 0:
            log.info('Running validation')
            trainer.run_stats(step, sess, m, num_batch_valid,
                              batch_iter_valid, outputs_valid,
                              write_log_valid(loggers), False)

        # Train stats.
        if step % train_opt['steps_per_trainval'] == 0:
            log.info('Running train validation')
            trainer.run_stats(step, sess, m, num_batch_valid,
                              batch_iter_trainval, outputs_trainval,
                              write_log_trainval(loggers), True)

        # Plot samples (skipped at step 0).
        if step % train_opt['steps_per_plot'] == 0 and step > 0:
            run_samples()

        # One optimization step.
        train_step(step, x_bat, y_bat, s_bat)

        # Model ID reminder.
        if step % (10 * train_opt['steps_per_log']) == 0:
            log.info('model id {}'.format(model_id))

        # Periodic checkpoint.
        if args.save_ckpt and step % train_opt['steps_per_ckpt'] == 0:
            saver.save(sess, global_step=step)

        step += 1
        if step > train_opt['num_steps']:
            break
def run_validation():
    """Run validation over the full validation set.

    Averages loss / IOU / count-accuracy metrics and per-layer batch-norm
    statistics (batch mean/var and EMA mean/var for each CNN and DCNN
    layer), weighted by batch size, then logs and optionally records them.
    """
    loss = 0.0
    iou_hard = 0.0
    iou_soft = 0.0
    count_acc = 0.0
    segm_loss = 0.0
    conf_loss = 0.0
    num_cnn = len(model_opt['cnn_filter_size'])
    num_dcnn = len(model_opt['dcnn_filter_size'])
    # Per-layer batch-norm stats accumulators: batch mean/var, EMA mean/var.
    cnn_bm = [0.0] * num_cnn
    cnn_bv = [0.0] * num_cnn
    cnn_em = [0.0] * num_cnn
    cnn_ev = [0.0] * num_cnn
    dcnn_bm = [0.0] * num_dcnn
    dcnn_bv = [0.0] * num_dcnn
    dcnn_em = [0.0] * num_dcnn
    dcnn_ev = [0.0] * num_dcnn
    log.info('Running validation')
    for _x, _y, _s in BatchIterator(num_ex_valid,
                                    batch_size=batch_size,
                                    get_fn=get_batch_valid,
                                    progress_bar=False):
        results_list = [
            m['loss'], m['segm_loss'], m['conf_loss'], m['iou_soft'],
            m['iou_hard'], m['count_acc']
        ]
        offset = len(results_list)
        for ii in xrange(num_cnn):
            results_list.append(m['cnn_{}_bm'.format(ii)])
            results_list.append(m['cnn_{}_bv'.format(ii)])
            results_list.append(m['cnn_{}_em'.format(ii)])
            results_list.append(m['cnn_{}_ev'.format(ii)])
        for ii in xrange(num_dcnn):
            results_list.append(m['dcnn_{}_bm'.format(ii)])
            results_list.append(m['dcnn_{}_bv'.format(ii)])
            results_list.append(m['dcnn_{}_em'.format(ii)])
            results_list.append(m['dcnn_{}_ev'.format(ii)])
        results = sess.run(results_list,
                           feed_dict={
                               m['x']: _x,
                               m['phase_train']: False,
                               m['y_gt']: _y,
                               m['s_gt']: _s
                           })
        _loss = results[0]
        _segm_loss = results[1]
        _conf_loss = results[2]
        _iou_soft = results[3]
        _iou_hard = results[4]
        _count_acc = results[5]
        num_ex_batch = _x.shape[0]
        # Weighted (by batch size) running average of the scalar metrics.
        loss += _loss * num_ex_batch / num_ex_valid
        segm_loss += _segm_loss * num_ex_batch / num_ex_valid
        conf_loss += _conf_loss * num_ex_batch / num_ex_valid
        iou_soft += _iou_soft * num_ex_batch / num_ex_valid
        iou_hard += _iou_hard * num_ex_batch / num_ex_valid
        count_acc += _count_acc * num_ex_batch / num_ex_valid
        # BUG FIX: the original extracted the BN stats into scalar
        # temporaries (_cnn_bm etc.) in one loop — overwriting them on each
        # iteration — and then accumulated in a second loop, so EVERY layer
        # accumulated the LAST layer's statistics. Accumulate each layer's
        # values directly while walking `offset` through `results`.
        for ii in xrange(num_cnn):
            cnn_bm[ii] += results[offset] * num_ex_batch / num_ex_valid
            cnn_bv[ii] += results[offset + 1] * num_ex_batch / num_ex_valid
            cnn_em[ii] += results[offset + 2] * num_ex_batch / num_ex_valid
            cnn_ev[ii] += results[offset + 3] * num_ex_batch / num_ex_valid
            offset += 4
        for ii in xrange(num_dcnn):
            dcnn_bm[ii] += results[offset] * num_ex_batch / num_ex_valid
            dcnn_bv[ii] += results[offset + 1] * num_ex_batch / num_ex_valid
            dcnn_em[ii] += results[offset + 2] * num_ex_batch / num_ex_valid
            dcnn_ev[ii] += results[offset + 3] * num_ex_batch / num_ex_valid
            offset += 4
    log.info(('{:d} valid loss {:.4f} segm_loss {:.4f} conf_loss {:.4f} '
              'iou soft {:.4f} iou hard {:.4f} count acc {:.4f}').format(
                  step, loss, segm_loss, conf_loss, iou_soft, iou_hard,
                  count_acc))
    if args.logs:
        loss_logger.add(step, ['', loss])
        iou_logger.add(step, ['', iou_soft, '', iou_hard])
        count_acc_logger.add(step, ['', count_acc])
        for ii in xrange(num_cnn):
            cnn_bn_loggers[ii].add(
                step, ['', cnn_bm[ii], '', cnn_bv[ii], '', ''])
        for ii in xrange(num_dcnn):
            dcnn_bn_loggers[ii].add(
                step, ['', dcnn_bm[ii], '', dcnn_bv[ii], '', ''])
def reset(self):
    # Reset the concurrent batch iterator: stop all worker threads, drain
    # any batches they already enqueued, rewind the underlying iterator,
    # then start a fresh set of workers. Statement order matters here —
    # fetchers must be told to stop before the queue is drained, and the
    # queue must be fully drained (q.join()) before the cleaner is stopped.
    self.info("Resetting concurrent batch iter")
    self.info("Stopping all workers")
    for f in self.fetchers:
        f.stop()
    self.info("Cleaning queue")
    # A consumer thread drains whatever the stopping fetchers left behind,
    # so the q.join() below can complete.
    cleaner = BatchConsumer(self.q)
    cleaner.start()
    for f in self.fetchers:
        f.join()
    self.q.join()
    cleaner.stop()
    self.info("Resetting index")
    # Rewind the wrapped (non-concurrent) iterator.
    self.batch_iter.reset()
    self.info("Restarting workers")
    self.fetchers = []
    self.init_fetchers()
    # Flag presumably signals consumers that workers were relaunched —
    # TODO(review): confirm against the rest of ConcurrentBatchIterator.
    self.relaunch = True
    pass

pass


# Smoke test: wrap a plain BatchIterator in the concurrent version and
# iterate it to exhaustion, logging each batch.
if __name__ == "__main__":
    from batch_iter import BatchIterator
    b = BatchIterator(100, batch_size=6, get_fn=None)
    cb = ConcurrentBatchIterator(b, max_queue_size=5, num_threads=3)
    for _batch in cb:
        log = logger.get()
        log.info(("Final out", _batch))