def train_flow(ops, names=None, gen_feed_dict_fn=None, deal_results_fn=melt.print_results,
               eval_ops=None, eval_names=None, gen_eval_feed_dict_fn=None,
               deal_eval_results_fn=melt.print_results, optimizer=None, learning_rate=0.1,
               num_steps_per_epoch=None, model_dir=None, metric_eval_fn=None, debug=False,
               summary_excls=None, init_fn=None, sess=None):
  if sess is None:
    sess = melt.get_session()
  if debug:
    sess = tf_debug.LocalCLIDebugWrapperSession(sess)
  logging.info('learning_rate:{}'.format(FLAGS.learning_rate))

  # batch_size is not defined here but in app code like input_app.py
  melt.set_global('batch_size', FLAGS.batch_size)
  melt.set_global('num_gpus', max(FLAGS.num_gpus, 1))

  # NOTICE since melt.__init__.py does `from melt.flow import *`, you can not
  # use melt.flow.train.train_flow, but you can always use
  # `from melt.flow.train.train_flow import train_flow`
  if optimizer is None:
    optimizer = FLAGS.optimizer

  # Set up the training ops.
  # notice '' only works in tf >= 0.11; 0.10 will always add an OptimizeLoss scope
  # (the difference: 0.10 uses variable_op_scope while 0.11 uses variable_scope)
  optimize_scope = None if FLAGS.optimize_has_scope else ''
  # or judge by FLAGS.num_gpus
  if not isinstance(ops[0], (list, tuple)):
    learning_rate, learning_rate_decay_fn = gen_learning_rate()
    train_op = tf.contrib.layers.optimize_loss(
        loss=ops[0],
        global_step=None,
        learning_rate=learning_rate,
        optimizer=melt.util.get_optimizer(optimizer),
        clip_gradients=FLAGS.clip_gradients,
        learning_rate_decay_fn=learning_rate_decay_fn,
        name=optimize_scope)
  else:
    # --- as in the cifar10 example, put all but the tower losses on cpu; the wiki says
    # that will be faster, but here I find it is faster *without* pinning to cpu..
    # https://github.com/tensorflow/tensorflow/issues/4881
    # "I've noticed the same thing on cirrascale GPU machines - putting parameters on
    # gpu:0 and using gpu->gpu transfer was a bit faster. I suppose this depends on
    # particular details of hardware -- if you don't have p2p connectivity between your
    # video cards then keeping parameters on CPU:0 gives faster training."
    # err, but for my pc with no p2p (PHB connection per nvidia-smi topo -m) pinning to
    # cpu still hurts.. so maybe we should not put cpu here
    # with tf.device('/cpu:0'):
    learning_rate, learning_rate_decay_fn = gen_learning_rate()
    train_op = melt.layers.optimize_loss(
        losses=ops[0],
        num_gpus=FLAGS.num_gpus,
        global_step=None,
        learning_rate=learning_rate,
        optimizer=melt.util.get_optimizer(optimizer),
        clip_gradients=FLAGS.clip_gradients,
        learning_rate_decay_fn=learning_rate_decay_fn,
        name=optimize_scope)
    # set the last tower loss as loss in ops
    ops[0] = ops[0][-1]

  ops.insert(0, train_op)

  # ----------- post deal
  save_interval_seconds = FLAGS.save_interval_seconds if FLAGS.save_interval_seconds > 0 \
      else FLAGS.save_interval_hours * 3600
  interval_steps = FLAGS.interval_steps
  eval_interval_steps = FLAGS.eval_interval_steps
  metric_eval_interval_steps = FLAGS.metric_eval_interval_steps
  save_model = FLAGS.save_model
  save_interval_steps = FLAGS.save_interval_steps
  if not save_interval_steps:
    save_interval_steps = 1000000000000

  if FLAGS.work_mode == 'train':
    eval_ops = None
    metric_eval_fn = None
    logging.info('running train only mode')
  elif FLAGS.work_mode == 'train_metric':
    eval_ops = None
    assert metric_eval_fn is not None, 'set metric_eval to 1'
    logging.info('running train+metric mode')
  elif FLAGS.work_mode == 'train_valid':
    metric_eval_fn = None
    logging.info('running train+valid mode')
  elif FLAGS.work_mode == 'test':
    ops = None
    logging.info('running test only mode')
    interval_steps = 0
    eval_interval_steps = 1
    metric_eval_interval_steps /= FLAGS.eval_interval_steps
    save_model = False

  return melt.flow.train_flow(
      ops,
      names=names,
      gen_feed_dict_fn=gen_feed_dict_fn,
      deal_results_fn=deal_results_fn,
      eval_ops=eval_ops,
      eval_names=eval_names,
      gen_eval_feed_dict_fn=gen_eval_feed_dict_fn,
      deal_eval_results_fn=deal_eval_results_fn,
      interval_steps=interval_steps,
      eval_interval_steps=eval_interval_steps,
      num_epochs=FLAGS.num_epochs,
      num_steps=FLAGS.num_steps,
      save_interval_seconds=save_interval_seconds,
      save_interval_steps=save_interval_steps,
      save_model=save_model,
      save_interval_epochs=FLAGS.save_interval_epochs,
      #optimizer=optimizer,
      optimizer=None,  # must be None since the optimizer has already been chosen above
      learning_rate=learning_rate,
      num_steps_per_epoch=num_steps_per_epoch,
      max_models_keep=FLAGS.max_models_keep,
      model_dir=model_dir,
      restore_from_latest=FLAGS.restore_from_latest,
      metric_eval_fn=metric_eval_fn,
      metric_eval_interval_steps=metric_eval_interval_steps,
      no_log=FLAGS.no_log,
      summary_excls=summary_excls,
      init_fn=init_fn,
      sess=sess)
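# --------------------------------------------------------------------------
# Illustrative usage sketch (not from the original code): how application code
# might call train_flow above. The toy variable/loss and the flag values are
# hypothetical; ops[0] must be the (scalar) training loss, and any extra ops
# are reported under the matching entries of `names`.
def _example_train_flow_usage():
  w = tf.get_variable('w', initializer=1.0)
  loss = tf.square(w - 3.0)  # toy scalar loss so the sketch is self-contained
  train_flow([loss],
             names=['loss'],
             optimizer='adagrad',  # any key melt.util.get_optimizer accepts
             learning_rate=0.1,
             model_dir='/tmp/model.flow')  # hypothetical path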
def train(Dataset, model, loss_fn, evaluate_fn=None, inference_fn=None, eval_fn=None,
          write_valid=True, valid_names=None, infer_names=None, infer_debug_names=None,
          valid_write_fn=None, infer_write_fn=None, valid_suffix='.valid',
          infer_suffix='.infer', write_streaming=False, optimizer=None, param_groups=None,
          init_fn=None, dataset=None, valid_dataset=None, test_dataset=None, sep=','):
  if Dataset is None:
    assert dataset
  if FLAGS.torch:
    # https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
    if torch.cuda.device_count() > 1:
      model = torch.nn.DataParallel(model)
    model.to(device)

  input_ = FLAGS.train_input
  inputs = gezi.list_files(input_)
  inputs.sort()
  all_inputs = inputs

  #batch_size = FLAGS.batch_size
  batch_size = melt.batch_size()
  num_gpus = melt.num_gpus()
  #batch_size = max(batch_size, 1)
  #batch_size_ = batch_size if not FLAGS.batch_sizes else int(FLAGS.batch_sizes.split(',')[-1])
  batch_size_ = batch_size

  if FLAGS.fold is not None:
    inputs = [x for x in inputs if not x.endswith('%d.record' % FLAGS.fold) and not x.endswith('%d.tfrecord' % FLAGS.fold)]
    # if FLAGS.valid_input:
    #   inputs += [x for x in gezi.list_files(FLAGS.valid_input) if not x.endswith('%d.record' % FLAGS.fold)]
  logging.info('inputs', len(inputs), inputs[:100])
  num_folds = FLAGS.num_folds or len(inputs) + 1

  train_dataset_ = dataset or Dataset('train')
  train_dataset = train_dataset_.make_batch(batch_size, inputs)
  num_examples = train_dataset_.num_examples_per_epoch('train')
  num_all_examples = num_examples

  valid_inputs = None
  if FLAGS.valid_input:
    valid_inputs = gezi.list_files(FLAGS.valid_input)
  else:
    if FLAGS.fold is not None:
      #valid_inputs = [x for x in all_inputs if x not in inputs]
      if not FLAGS.test_aug:
        valid_inputs = [x for x in all_inputs if not 'aug' in x and x not in inputs]
      else:
        valid_inputs = [x for x in all_inputs if 'aug' in x and x not in inputs]
  logging.info('valid_inputs', valid_inputs)

  if valid_inputs:
    valid_dataset_ = valid_dataset or Dataset('valid')
    valid_dataset = valid_dataset_.make_batch(batch_size_, valid_inputs)
    valid_dataset2 = valid_dataset_.make_batch(batch_size_, valid_inputs, repeat=True)
  else:
    valid_dataset = None  # was misspelled valid_datsset, which left the stale value in place
    valid_dataset2 = None

  if num_examples:
    if FLAGS.fold is not None:
      num_examples = int(num_examples * (num_folds - 1) / num_folds)
    num_steps_per_epoch = -(-num_examples // batch_size)
  else:
    num_steps_per_epoch = None
  logging.info('num_train_examples:', num_examples)

  num_valid_examples = None
  if FLAGS.valid_input:
    num_valid_examples = valid_dataset_.num_examples_per_epoch('valid')
    num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_) if num_valid_examples else None
  else:
    if FLAGS.fold is not None:
      if num_examples:
        num_valid_examples = int(num_all_examples * (1 / num_folds))
        num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_)
      else:
        num_valid_steps_per_epoch = None
  logging.info('num_valid_examples:', num_valid_examples)

  if FLAGS.test_input:
    test_inputs = gezi.list_files(FLAGS.test_input)
    #test_inputs = [x for x in test_inputs if not 'aug' in x]
    logging.info('test_inputs', test_inputs)
  else:
    test_inputs = None

  num_test_examples = None
  if test_inputs:
    test_dataset_ = test_dataset or Dataset('test')
    test_dataset = test_dataset_.make_batch(batch_size_, test_inputs)
    num_test_examples = test_dataset_.num_examples_per_epoch('test')
    num_test_steps_per_epoch = -(-num_test_examples // batch_size_) if num_test_examples else None
  else:
    test_dataset = None
  logging.info('num_test_examples:', num_test_examples)

  summary = tf.contrib.summary
  # writer = summary.create_file_writer(FLAGS.log_dir + '/epoch')
  # writer_train = summary.create_file_writer(FLAGS.log_dir + '/train')
  # writer_valid = summary.create_file_writer(FLAGS.log_dir + '/valid')
  writer = summary.create_file_writer(FLAGS.log_dir)
  writer_train = summary.create_file_writer(FLAGS.log_dir)
  writer_valid = summary.create_file_writer(FLAGS.log_dir)
  global_step = tf.train.get_or_create_global_step()

  learning_rate = tfe.Variable(FLAGS.learning_rate, name="learning_rate")
  tf.add_to_collection('learning_rate', learning_rate)
  learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
  try:
    learning_rate_weights = tf.get_collection('learning_rate_weights')[-1]
  except Exception:
    learning_rate_weights = None

  # ckpt dir: saves one model per epoch
  ckpt_dir = os.path.join(FLAGS.model_dir, 'ckpt')
  os.system('mkdir -p %s' % ckpt_dir)
  # HACK ckpt2 dir actually saves mini epochs, e.g. when you set save_interval_epochs=0.1;
  # this is useful when training on a large dataset
  ckpt_dir2 = os.path.join(FLAGS.model_dir, 'ckpt2')
  os.system('mkdir -p %s' % ckpt_dir2)

  #TODO FIXME now I just changed tf code so as not to save only the latest 5 by default
  # refer to https://github.com/tensorflow/tensorflow/issues/22036
  # manager = tf.contrib.checkpoint.CheckpointManager(
  #     checkpoint, directory=ckpt_dir, max_to_keep=5)
  # latest_checkpoint = manager.latest_checkpoint
  latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
  if latest_checkpoint:
    logging.info('Latest checkpoint:', latest_checkpoint)
  else:
    latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir2)
    logging.info('Latest checkpoint:', latest_checkpoint)
  if os.path.exists(FLAGS.model_dir + '.index'):
    latest_checkpoint = FLAGS.model_dir
  if 'test' in FLAGS.work_mode or 'valid' in FLAGS.work_mode:
    #assert not os.path.isdir(FLAGS.model_dir), FLAGS.model_dir
    latest_checkpoint = FLAGS.model_dir
    #assert os.path.exists(latest_checkpoint) and os.path.isfile(latest_checkpoint)

  checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt')
  checkpoint_prefix2 = os.path.join(ckpt_dir2, 'ckpt')

  if not FLAGS.torch:
    try:
      optimizer = optimizer or melt.get_optimizer(FLAGS.optimizer)(learning_rate)
    except Exception:
      logging.warning(f'Failed to use {FLAGS.optimizer}, using adam instead')
      optimizer = melt.get_optimizer('adam')(learning_rate)
    # TODO...
    if learning_rate_weights is None:
      checkpoint = tf.train.Checkpoint(
          learning_rate=learning_rate,
          learning_rate_weight=learning_rate_weight,
          model=model,
          optimizer=optimizer,
          global_step=global_step)
    else:
      checkpoint = tf.train.Checkpoint(
          learning_rate=learning_rate,
          learning_rate_weight=learning_rate_weight,
          learning_rate_weights=learning_rate_weights,
          model=model,
          optimizer=optimizer,
          global_step=global_step)

    checkpoint.restore(latest_checkpoint)
    checkpoint2 = copy.deepcopy(checkpoint)

    start_epoch = int(latest_checkpoint.split('-')[-1]) if latest_checkpoint and 'ckpt' in latest_checkpoint else 0
  else:
    # TODO torch with learning rate adjust
    if optimizer is None:
      import lele
      is_dynamic_opt = True
      if FLAGS.optimizer == 'noam':
        optimizer = lele.training.optimizers.NoamOpt(128, 2, 4000, torch.optim.Adamax(model.parameters(), lr=0))
      elif FLAGS.optimizer == 'bert':
        num_train_steps = int(num_steps_per_epoch * (FLAGS.num_decay_epochs or FLAGS.num_epochs))
        num_warmup_steps = FLAGS.warmup_steps or int(num_train_steps * FLAGS.warmup_proportion)
        logging.info('num_train_steps', num_train_steps, 'num_warmup_steps', num_warmup_steps, 'warmup_proportion', FLAGS.warmup_proportion)
        optimizer = lele.training.optimizers.BertOpt(FLAGS.learning_rate, FLAGS.min_learning_rate, num_train_steps, num_warmup_steps, torch.optim.Adamax(model.parameters(), lr=0))
      else:
        is_dynamic_opt = False
        optimizer = torch.optim.Adamax(param_groups if param_groups else model.parameters(), lr=FLAGS.learning_rate)

    start_epoch = 0
    latest_path = latest_checkpoint + '.pyt' if latest_checkpoint else os.path.join(FLAGS.model_dir, 'latest.pyt')
    if not os.path.exists(latest_path):
      latest_path = os.path.join(FLAGS.model_dir, 'latest.pyt')
    if os.path.exists(latest_path):
      logging.info('loading torch model from', latest_path)
      checkpoint = torch.load(latest_path)
      if not FLAGS.torch_finetune:
        start_epoch = checkpoint['epoch']
        step = checkpoint['step']
        global_step.assign(step + 1)
      load_torch_model(model, latest_path)
      if FLAGS.torch_load_optimizer:
        optimizer.load_state_dict(checkpoint['optimizer'])

    # TODO restarting this way can not change the learning rate..
    if learning_rate_weights is None:
      checkpoint = tf.train.Checkpoint(
          learning_rate=learning_rate,
          learning_rate_weight=learning_rate_weight,
          global_step=global_step)
    else:
      checkpoint = tf.train.Checkpoint(
          learning_rate=learning_rate,
          learning_rate_weight=learning_rate_weight,
          learning_rate_weights=learning_rate_weights,
          global_step=global_step)

    try:
      checkpoint.restore(latest_checkpoint)
      checkpoint2 = copy.deepcopy(checkpoint)
    except Exception:
      pass

  if FLAGS.torch and is_dynamic_opt:
    optimizer._step = global_step.numpy()

  #model.load_weights(os.path.join(ckpt_dir, 'ckpt-1'))
  #model.save('./weight3.hd5')

  logging.info('optimizer:', optimizer)

  if FLAGS.torch_lr:
    learning_rate.assign(optimizer.rate(1))
  if FLAGS.torch:
    learning_rate.assign(optimizer.param_groups[0]['lr'])
    logging.info('learning rate got from pytorch latest.pyt as', learning_rate)

  learning_rate.assign(learning_rate * FLAGS.learning_rate_start_factor)
  if learning_rate_weights is not None:
    learning_rate_weights.assign(learning_rate_weights * FLAGS.learning_rate_start_factor)

  # TODO fractional values like 0.1 epoch are currently not supported here
  num_epochs = FLAGS.num_epochs if FLAGS.num_epochs != 0 else 1024

  will_valid = valid_dataset and not FLAGS.work_mode == 'test' and not 'SHOW' in os.environ and not 'QUICK' in os.environ
  if start_epoch == 0 and not 'EVFIRST' in os.environ and will_valid:
    will_valid = False
  if start_epoch > 0 and will_valid:
    will_valid = True

  if will_valid:
    logging.info('----------valid')
    if FLAGS.torch:
      model.eval()
    names = None
    if evaluate_fn is not None:
      vals, names = evaluate_fn(model, valid_dataset, tf.train.latest_checkpoint(ckpt_dir), num_valid_steps_per_epoch)
    elif eval_fn:
      model_path = None if not write_valid else latest_checkpoint
      names = valid_names if valid_names is not None else [infer_names[0]] + [x + '_y' for x in infer_names[1:]] + infer_names[1:] if infer_names else None
      logging.info('model_path:', model_path, 'model_dir:', FLAGS.model_dir)
      vals, names = evaluate(model, valid_dataset, eval_fn, model_path, names, valid_write_fn, write_streaming, num_valid_steps_per_epoch, suffix=valid_suffix, sep=sep)
    if names:
      logging.info2('epoch:%d/%d' % (start_epoch, num_epochs), ['%s:%.5f' % (name, val) for name, val in zip(names, vals)])

  if FLAGS.work_mode == 'valid':
    exit(0)

  if 'test' in FLAGS.work_mode:
    logging.info('--------test/inference')
    if test_dataset:
      if FLAGS.torch:
        model.eval()
      if inference_fn is None:
        # model_path = FLAGS.model_dir + '.pyt' if not latest_checkpoint else latest_checkpoint
        # logging.info('model_path', model_path)
        assert latest_checkpoint
        inference(model, test_dataset, latest_checkpoint, infer_names, infer_debug_names, infer_write_fn, write_streaming, num_test_steps_per_epoch, suffix=infer_suffix)
      else:
        inference_fn(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir), num_test_steps_per_epoch)
    exit(0)

  if 'SHOW' in os.environ:
    num_epochs = start_epoch + 1

  class PytObj(object):
    def __init__(self, x):
      self.x = x

    def numpy(self):
      return self.x

  class PytMean(object):
    def __init__(self):
      self._val = 0.
      self.count = 0
      self.is_call = True

    def clear(self):
      self._val = 0
      self.count = 0

    def __call__(self, val):
      if not self.is_call:
        self.clear()
        self.is_call = True
      self._val += val.item()
      self.count += 1

    def result(self):
      if self.is_call:
        self.is_call = False
      if not self.count:
        val = 0
      else:
        val = self._val / self.count
      # TODO just for compat with tf ..
      return PytObj(val)

  Mean = tfe.metrics.Mean if not FLAGS.torch else PytMean

  timer = gezi.Timer()
  num_insts = 0

  if FLAGS.learning_rate_decay_factor > 0:
    #assert FLAGS.learning_rate_values is None, 'use exponential_decay or piecewise_constant?'
    # NOTICE if you do finetune or other things which might change batch_size, you'd better
    # directly set num_steps_per_decay, since global_step / decay_steps will no longer map to
    # the correct epoch once num_steps_per_epoch changes; so if you change the batch size you
    # have to reset the global step to a fixed step
    assert FLAGS.num_steps_per_decay or (FLAGS.num_epochs_per_decay and num_steps_per_epoch), 'must set num_steps_per_epoch or num_epochs_per_decay and num_steps_per_epoch'
    decay_steps = FLAGS.num_steps_per_decay or int(num_steps_per_epoch * FLAGS.num_epochs_per_decay)
    decay_start_step = FLAGS.decay_start_step or int(num_steps_per_epoch * FLAGS.decay_start_epoch)
    # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
    logging.info('learning_rate_decay_factor:{} decay_epochs:{} decay_steps:{} decay_start_epoch:{} decay_start_step:{}'.format(FLAGS.learning_rate_decay_factor, FLAGS.num_epochs_per_decay, decay_steps, FLAGS.decay_start_epoch, decay_start_step))

  for epoch in range(start_epoch, num_epochs):
    melt.set_global('epoch', '%.4f' % (epoch))
    if FLAGS.torch:
      model.train()

    epoch_loss_avg = Mean()
    epoch_valid_loss_avg = Mean()

    #for i, (x, y) in tqdm(enumerate(train_dataset), total=num_steps_per_epoch, ascii=True):
    for i, (x, y) in enumerate(train_dataset):
      if FLAGS.torch:
        x, y = to_torch(x, y)
        if is_dynamic_opt:
          learning_rate.assign(optimizer.rate())

      #print(x, y)
      if not FLAGS.torch:
        loss, grads = melt.eager.grad(model, x, y, loss_fn)
        grads, _ = tf.clip_by_global_norm(grads, FLAGS.clip_gradients)
        optimizer.apply_gradients(zip(grads, model.variables))
      else:
        optimizer.zero_grad()
        if 'training' in inspect.getargspec(loss_fn).args:
          loss = loss_fn(model, x, y, training=True)
        else:
          loss = loss_fn(model, x, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), FLAGS.clip_gradients)
        optimizer.step()

      global_step.assign_add(1)
      epoch_loss_avg(loss)  # add current batch loss
      if FLAGS.torch:
        del loss

      batch_size_ = list(x.values())[0].shape[FLAGS.batch_size_dim] if type(x) == type({}) else x.shape[FLAGS.batch_size_dim]
      num_insts += int(batch_size_)

      if global_step.numpy() % FLAGS.interval_steps == 0:
        #checkpoint.save(checkpoint_prefix)
        elapsed = timer.elapsed()
        steps_per_second = FLAGS.interval_steps / elapsed
        instances_per_second = num_insts / elapsed
        num_insts = 0

        if num_steps_per_epoch is None:
          epoch_time_info = ''
        else:
          hours_per_epoch = num_steps_per_epoch / FLAGS.interval_steps * elapsed / 3600
          epoch_time_info = '1epoch:[{:.2f}h]'.format(hours_per_epoch)

        if valid_dataset2:
          try:
            x, y = next(iter(valid_dataset2))
          except Exception:
            # TODO FIXME the iterator stops and needs a restart.. a hack for my iterator, see projects/lm/dataset
            x, y = next(iter(valid_dataset2))
          if FLAGS.torch:
            x, y = to_torch(x, y)
            model.eval()
          valid_loss = loss_fn(model, x, y)
          epoch_valid_loss_avg(valid_loss)
          if FLAGS.torch:
            model.train()

          logging.info2('epoch:%.3f/%d' % ((epoch + i / num_steps_per_epoch), num_epochs),
                        'step:%d' % global_step.numpy(),
                        'elapsed:[%.3f]' % elapsed,
                        'batch_size:[%d]' % batch_size_,
                        'gpus:[%d]' % num_gpus,
                        'batches/s:[%.2f]' % steps_per_second,
                        'insts/s:[%d]' % instances_per_second,
                        '%s' % epoch_time_info,
                        'lr:[%.8f]' % learning_rate.numpy(),
                        'train_loss:[%.4f]' % epoch_loss_avg.result().numpy(),
                        'valid_loss:[%.4f]' % epoch_valid_loss_avg.result().numpy())

          if global_step.numpy() % FLAGS.eval_interval_steps == 0:
            with writer_valid.as_default(), summary.always_record_summaries():
              #summary.scalar('step/loss', epoch_valid_loss_avg.result().numpy())
              summary.scalar('loss/eval', epoch_valid_loss_avg.result().numpy())
              writer_valid.flush()
        else:
          logging.info2('epoch:%.3f/%d' % ((epoch + i / num_steps_per_epoch), num_epochs),
                        'step:%d' % global_step.numpy(),
                        'elapsed:[%.3f]' % elapsed,
                        'batch_size:[%d]' % batch_size_,
                        'gpus:[%d]' % num_gpus,
                        'batches/s:[%.2f]' % steps_per_second,
                        'insts/s:[%d]' % instances_per_second,
                        '%s' % epoch_time_info,
                        'lr:[%.8f]' % learning_rate.numpy(),
                        'train_loss:[%.4f]' % epoch_loss_avg.result().numpy())

        if global_step.numpy() % FLAGS.eval_interval_steps == 0:
          with writer_train.as_default(), summary.always_record_summaries():
            #summary.scalar('step/loss', epoch_loss_avg.result().numpy())
            summary.scalar('loss/train_avg', epoch_loss_avg.result().numpy())
            summary.scalar('learning_rate', learning_rate.numpy())
            summary.scalar('batch_size', batch_size_)
            summary.scalar('epoch', melt.epoch())
            summary.scalar('steps_per_second', steps_per_second)
            summary.scalar('instances_per_second', instances_per_second)
            writer_train.flush()

          if FLAGS.log_dir != FLAGS.model_dir:
            assert FLAGS.log_dir
            command = 'rsync -l -r -t %s/* %s' % (FLAGS.log_dir, FLAGS.model_dir)
            print(command, file=sys.stderr)
            os.system(command)

      if valid_dataset and FLAGS.metric_eval_interval_steps and global_step.numpy() and global_step.numpy() % FLAGS.metric_eval_interval_steps == 0:
        if FLAGS.torch:
          model.eval()
        vals, names = None, None
        if evaluate_fn is not None:
          vals, names = evaluate_fn(model, valid_dataset, None, num_valid_steps_per_epoch)
        elif eval_fn:
          names = valid_names if valid_names is not None else [infer_names[0]] + [x + '_y' for x in infer_names[1:]] + infer_names[1:] if infer_names else None
          vals, names = evaluate(model, valid_dataset, eval_fn, None, names, valid_write_fn, write_streaming, num_valid_steps_per_epoch, sep=sep)
        if vals and names:
          with writer_valid.as_default(), summary.always_record_summaries():
            for name, val in zip(names, vals):
              summary.scalar(f'step/valid/{name}', val)
            writer_valid.flush()

        if FLAGS.torch:
          if not FLAGS.torch_lr:
            # control the learning rate via the tensorflow learning rate variable
            for param_group in optimizer.param_groups:
              # important: learning rate decay
              param_group['lr'] = learning_rate.numpy()
          model.train()

        if names and vals:
          logging.info2('epoch:%.3f/%d' % ((epoch + i / num_steps_per_epoch), num_epochs),
                        'valid_step:%d' % global_step.numpy(),
                        'valid_metrics',
                        ['%s:%.5f' % (name, val) for name, val in zip(names, vals)])

      # if i == 5:
      #   print(i, '---------------------save')
      #   print(len(model.trainable_variables))
      ## TODO FIXME it seems saving weight values this way is not ok... not the same as checkpoint save
      # model.save_weights(os.path.join(ckpt_dir, 'weights'))
      # checkpoint.save(checkpoint_prefix)
      # exit(0)

      if global_step.numpy() % FLAGS.save_interval_steps == 0:
        if FLAGS.torch:
          state = {
              'epoch': epoch,
              'step': global_step.numpy(),
              'state_dict': model.state_dict() if not hasattr(model, 'module') else model.module.state_dict(),
              'optimizer': optimizer.state_dict(),
          }
          torch.save(state, os.path.join(FLAGS.model_dir, 'latest.pyt'))

      # TODO fixme: why is it not ok when both checkpoint2 and checkpoint are used...
      if FLAGS.save_interval_epochs and FLAGS.save_interval_epochs < 1 and global_step.numpy() % int(num_steps_per_epoch * FLAGS.save_interval_epochs) == 0:
        #if FLAGS.save_interval_epochs and global_step.numpy() % int(num_steps_per_epoch * FLAGS.save_interval_epochs) == 0:
        checkpoint2.save(checkpoint_prefix2)
        if FLAGS.torch:
          state = {
              'epoch': epoch,
              'step': global_step.numpy(),
              'state_dict': model.state_dict() if not hasattr(model, 'module') else model.module.state_dict(),
              'optimizer': optimizer.state_dict(),
          }
          torch.save(state, tf.train.latest_checkpoint(ckpt_dir2) + '.pyt')

      if FLAGS.learning_rate_decay_factor > 0:
        if global_step.numpy() >= decay_start_step and global_step.numpy() % decay_steps == 0:
          lr = max(learning_rate.numpy() * FLAGS.learning_rate_decay_factor, FLAGS.min_learning_rate)
          if lr < learning_rate.numpy():
            learning_rate.assign(lr)
            if FLAGS.torch:
              for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate.numpy()

      if epoch == start_epoch and i == 0:
        try:
          if not FLAGS.torch:
            logging.info(model.summary())
        except Exception:
          traceback.print_exc()
          logging.info('Fail to do model.summary(); maybe you have a layer defined in __init__ but not used in call')
        if 'SHOW' in os.environ:
          exit(0)

    logging.info2('epoch:%d/%d' % (epoch + 1, num_epochs),
                  'step:%d' % global_step.numpy(),
                  'batch_size:[%d]' % batch_size,
                  'gpus:[%d]' % num_gpus,
                  'lr:[%.8f]' % learning_rate.numpy(),
                  'train_loss:[%.4f]' % epoch_loss_avg.result().numpy(),
                  'valid_loss:[%.4f]' % epoch_valid_loss_avg.result().numpy())

    timer = gezi.Timer(f'save model to {checkpoint_prefix}-{checkpoint.save_counter.numpy() + 1}', False)
    checkpoint.save(checkpoint_prefix)
    if FLAGS.torch and FLAGS.save_interval_epochs == 1:
      state = {
          'epoch': epoch + 1,
          'step': global_step.numpy(),
          'state_dict': model.state_dict() if not hasattr(model, 'module') else model.module.state_dict(),
          'optimizer': optimizer.state_dict(),
      }
      torch.save(state, tf.train.latest_checkpoint(ckpt_dir) + '.pyt')
    timer.print_elapsed()

    if valid_dataset and (epoch + 1) % FLAGS.valid_interval_epochs == 0:
      if FLAGS.torch:
        model.eval()
      vals, names = None, None
      if evaluate_fn is not None:
        vals, names = evaluate_fn(model, valid_dataset, tf.train.latest_checkpoint(ckpt_dir), num_valid_steps_per_epoch)
      elif eval_fn:
        model_path = None if not write_valid else tf.train.latest_checkpoint(ckpt_dir)
        names = valid_names if valid_names is not None else [infer_names[0]] + [x + '_y' for x in infer_names[1:]] + infer_names[1:] if infer_names else None
        vals, names = evaluate(model, valid_dataset, eval_fn, model_path, names, valid_write_fn, write_streaming, num_valid_steps_per_epoch, suffix=valid_suffix, sep=sep)
      if vals and names:
        logging.info2('epoch:%d/%d' % (epoch + 1, num_epochs),
                      'step:%d' % global_step.numpy(),
                      'epoch_valid_metrics',
                      ['%s:%.5f' % (name, val) for name, val in zip(names, vals)])

    with writer.as_default(), summary.always_record_summaries():
      temp = global_step.value()
      global_step.assign(epoch + 1)
      summary.scalar('epoch/train/loss', epoch_loss_avg.result().numpy())
      if valid_dataset:
        if FLAGS.torch:
          model.eval()
        if vals and names:
          for name, val in zip(names, vals):
            summary.scalar(f'epoch/valid/{name}', val)
      writer.flush()
      global_step.assign(temp)

    if test_dataset and (epoch + 1) % FLAGS.inference_interval_epochs == 0:
      if FLAGS.torch:
        model.eval()
      if inference_fn is None:
        inference(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir), infer_names, infer_debug_names, infer_write_fn, write_streaming, num_test_steps_per_epoch, suffix=infer_suffix, sep=sep)
      else:
        inference_fn(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir), num_test_steps_per_epoch)

  if FLAGS.log_dir != FLAGS.model_dir:
    assert FLAGS.log_dir
    command = 'rsync -l -r -t %s/* %s' % (FLAGS.log_dir, FLAGS.model_dir)
    print(command, file=sys.stderr)
    os.system(command)
    command = 'rm -rf %s/latest.pyt.*' % (FLAGS.model_dir)
    print(command, file=sys.stderr)
    os.system(command)
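# --------------------------------------------------------------------------
# Minimal sketch (not from the original code) of the loss_fn contract assumed
# by the training loop above: loss_fn(model, x, y) returns a scalar loss and
# may optionally accept `training`, which train() detects via
# inspect.getargspec. The mse losses below are stand-ins for a real objective.
def _example_loss_fn(model, x, y, training=False):
  y_ = model(x)
  if FLAGS.torch:
    return torch.nn.functional.mse_loss(y_, y)
  return tf.losses.mean_squared_error(labels=y, predictions=y_)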
def train_once(
    sess,
    step,
    ops,
    names=None,
    gen_feed_dict_fn=None,
    deal_results_fn=None,
    interval_steps=100,
    eval_ops=None,
    eval_names=None,
    gen_eval_feed_dict_fn=None,
    deal_eval_results_fn=melt.print_results,
    valid_interval_steps=100,
    print_time=True,
    print_avg_loss=True,
    model_dir=None,
    log_dir=None,
    is_start=False,
    num_steps_per_epoch=None,
    metric_eval_fn=None,
    metric_eval_interval_steps=0,
    summary_excls=None,
    fixed_step=None,  # for epoch only, in case you change batch size
    eval_loops=1,
    learning_rate=None,
    learning_rate_patience=None,
    learning_rate_decay_factor=None,
    num_epochs=None,
    model_path=None,
    use_horovod=False,
):
  use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ

  #is_start = False  # force not to evaluate at first step
  #print('-----------------global_step', sess.run(tf.train.get_or_create_global_step()))
  timer = gezi.Timer()
  if print_time:
    if not hasattr(train_once, 'timer'):
      train_once.timer = Timer()
      train_once.eval_timer = Timer()
      train_once.metric_eval_timer = Timer()

  melt.set_global('step', step)
  epoch = (fixed_step or step) / num_steps_per_epoch if num_steps_per_epoch else -1
  if not num_epochs:
    epoch_str = 'epoch:%.3f' % (epoch) if num_steps_per_epoch else ''
  else:
    epoch_str = 'epoch:%.3f/%d' % (epoch, num_epochs) if num_steps_per_epoch else ''
  melt.set_global('epoch', '%.2f' % (epoch))

  info = IO()
  stop = False

  if eval_names is None:
    if names:
      eval_names = list(names)  # plain copy; the 'eval/' prefix is added below
  if names:
    names = ['train/' + x for x in names]
  if eval_names:
    eval_names = ['eval/' + x for x in eval_names]

  is_eval_step = is_start or valid_interval_steps and step % valid_interval_steps == 0
  summary_str = []
  eval_str = ''
  if is_eval_step:
    # deal with summary
    if log_dir:
      if not hasattr(train_once, 'summary_op'):
        #melt.print_summary_ops()
        if summary_excls is None:
          train_once.summary_op = tf.summary.merge_all()
        else:
          summary_ops = []
          for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
            for summary_excl in summary_excls:
              if not summary_excl in op.name:
                summary_ops.append(op)
          print('filtered summary_ops:')
          for op in summary_ops:
            print(op)
          train_once.summary_op = tf.summary.merge(summary_ops)
        #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
        train_once.summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        tf.contrib.tensorboard.plugins.projector.visualize_embeddings(train_once.summary_writer, projector_config)

    # if there are eval ops then this should have been rank 0
    if eval_ops:
      #if deal_eval_results_fn is None and eval_names is not None:
      #  deal_eval_results_fn = lambda x: melt.print_results(x, eval_names)
      for i in range(eval_loops):
        eval_feed_dict = {} if gen_eval_feed_dict_fn is None else gen_eval_feed_dict_fn()
        #eval_feed_dict.update(feed_dict)

        # if using horovod, let each rank run the same sess.run!
        if not log_dir or train_once.summary_op is None or gezi.env_has('EVAL_NO_SUMMARY') or use_horovod:
          #if not log_dir or train_once.summary_op is None:
          eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
        else:
          eval_results = sess.run(eval_ops + [train_once.summary_op], feed_dict=eval_feed_dict)
          summary_str = eval_results[-1]
          eval_results = eval_results[:-1]
        eval_loss = gezi.get_singles(eval_results)
        #timer_.print()
        eval_stop = False
        if use_horovod:
          sess.run(hvd.allreduce(tf.constant(0)))

        #if not use_horovod or hvd.local_rank() == 0:
        # @TODO user print should also use logging as a must?
        #print(gezi.now_time(), epoch_str, 'eval_step: %d' % step, 'eval_metrics:', end='')
        eval_names_ = melt.adjust_names(eval_loss, eval_names)
        #if not use_horovod or hvd.rank() == 0:
        #  logging.info2('{} eval_step:{} eval_metrics:{}'.format(epoch_str, step, melt.parse_results(eval_loss, eval_names_)))
        eval_str = 'valid:{}'.format(melt.parse_results(eval_loss, eval_names_))

        # if deal_eval_results_fn is not None:
        #   eval_stop = deal_eval_results_fn(eval_results)
        assert len(eval_loss) > 0
        if eval_stop is True:
          stop = True
        eval_names_ = melt.adjust_names(eval_loss, eval_names)
        if not use_horovod or hvd.rank() == 0:
          melt.set_global('eval_loss', melt.parse_results(eval_loss, eval_names_))
    elif interval_steps != valid_interval_steps:
      #print()
      pass

  metric_evaluate = False
  # if metric_eval_fn is not None \
  #   and (is_start \
  #     or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
  #     or (metric_eval_interval_steps \
  #       and step % metric_eval_interval_steps == 0)):
  #   metric_evaluate = True
  if metric_eval_fn is not None \
      and ((is_start or metric_eval_interval_steps \
          and step % metric_eval_interval_steps == 0) or model_path):
    metric_evaluate = True

  if 'EVFIRST' in os.environ:
    if os.environ['EVFIRST'] == '0':
      if is_start:
        metric_evaluate = False
    else:
      if is_start:
        metric_evaluate = True

  if step == 0 or 'QUICK' in os.environ:
    metric_evaluate = False

  #print('------------1step', step, 'pre metric_evaluate', metric_evaluate, hvd.rank())
  if metric_evaluate:
    if use_horovod:
      print('------------metric evaluate step', step, model_path, hvd.rank())
    if not model_path or 'model_path' not in inspect.getargspec(metric_eval_fn).args:
      metric_eval_fn_ = metric_eval_fn
    else:
      metric_eval_fn_ = lambda: metric_eval_fn(model_path=model_path)

    try:
      l = metric_eval_fn_()
      if isinstance(l, tuple):
        num_returns = len(l)
        if num_returns == 2:
          evaluate_results, evaluate_names = l
          evaluate_summaries = None
        else:
          assert num_returns == 3, 'metric_eval_fn may return 2 or 3 values, not more'
          evaluate_results, evaluate_names, evaluate_summaries = l
      else:
        # returned a dict
        evaluate_results, evaluate_names = tuple(zip(*l.items()))  # was zip(*dict.items()), a bug
        evaluate_summaries = None
    except Exception:
      logging.info('Do nothing for metric eval fn with exception:\n', traceback.format_exc())

    if not use_horovod or hvd.rank() == 0:
      #logging.info2('{} valid_step:{} {}:{}'.format(epoch_str, step, 'valid_metrics' if model_path is None else 'epoch_valid_metrics', melt.parse_results(evaluate_results, evaluate_names)))
      logging.info2('{} valid_step:{} {}:{}'.format(epoch_str, step, 'valid_metrics', melt.parse_results(evaluate_results, evaluate_names)))

    if learning_rate is not None and (learning_rate_patience and learning_rate_patience > 0):
      assert learning_rate_decay_factor > 0 and learning_rate_decay_factor < 1
      valid_loss = evaluate_results[0]
      if not hasattr(train_once, 'min_valid_loss'):
        train_once.min_valid_loss = valid_loss
        train_once.decay_steps = []
        train_once.patience = 0
      else:
        if valid_loss < train_once.min_valid_loss:
          train_once.min_valid_loss = valid_loss
          train_once.patience = 0
        else:
          train_once.patience += 1
          logging.info2('{} valid_step:{} patience:{}'.format(epoch_str, step, train_once.patience))

      if learning_rate_patience and train_once.patience >= learning_rate_patience:
        lr_op = ops[1]
        lr = sess.run(lr_op) * learning_rate_decay_factor
        train_once.decay_steps.append(step)
        logging.info2('{} valid_step:{} learning_rate_decay by *{}, learning_rate_decay_steps={}'.format(epoch_str, step, learning_rate_decay_factor, ','.join(map(str, train_once.decay_steps))))
        sess.run(tf.assign(lr_op, tf.constant(lr, dtype=tf.float32)))
        train_once.patience = 0
        train_once.min_valid_loss = valid_loss

  if ops is not None:
    #if deal_results_fn is None and names is not None:
    #  deal_results_fn = lambda x: melt.print_results(x, names)
    feed_dict = {} if gen_feed_dict_fn is None else gen_feed_dict_fn()

    # NOTICE ops[2] should be scalar otherwise wrong!! the loss should be a scalar
    #print('---------------ops', ops)
    if eval_ops is not None or not log_dir or not hasattr(train_once, 'summary_op') or train_once.summary_op is None or use_horovod:
      feed_dict[K.learning_phase()] = 1
      results = sess.run(ops, feed_dict=feed_dict)
    else:
      ## TODO why below?
      #try:
      feed_dict[K.learning_phase()] = 1
      results = sess.run(ops + [train_once.summary_op], feed_dict=feed_dict)
      summary_str = results[-1]
      results = results[:-1]
      # except Exception:
      #   logging.info('sess.run(ops + [train_once.summary_op], feed_dict=feed_dict) fail')
      #   results = sess.run(ops, feed_dict=feed_dict)

    #print('------------results', results)
    # #--------trace debug
    # if step == 210:
    #   run_metadata = tf.RunMetadata()
    #   results = sess.run(
    #       ops,
    #       feed_dict=feed_dict,
    #       options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
    #       run_metadata=run_metadata)
    #   from tensorflow.python.client import timeline
    #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    #   trace_file = open('timeline.ctf.json', 'w')
    #   trace_file.write(trace.generate_chrome_trace_format())

    # results[0] is assumed to be train_op, results[1] the learning_rate
    learning_rate = results[1]
    results = results[2:]

    #@TODO should support avg loss and other avg evaluations like test..
    if print_avg_loss:
      if not hasattr(train_once, 'avg_loss'):
        train_once.avg_loss = AvgScore()
      #assume results[0] as train_op return, results[1] as loss
      loss = gezi.get_singles(results)
      train_once.avg_loss.add(loss)

  steps_per_second = None
  instances_per_second = None
  hours_per_epoch = None
  #step += 1
  #if is_start or interval_steps and step % interval_steps == 0:
  interval_ok = not use_horovod or hvd.local_rank() == 0
  if interval_steps and step % interval_steps == 0 and interval_ok:
    train_average_loss = train_once.avg_loss.avg_score()
    if print_time:
      duration = timer.elapsed()
      duration_str = 'duration:{:.2f} '.format(duration)
      melt.set_global('duration', '%.2f' % duration)
      #info.write(duration_str)
      elapsed = train_once.timer.elapsed()
      steps_per_second = interval_steps / elapsed
      batch_size = melt.batch_size()
      num_gpus = melt.num_gpus()
      instances_per_second = interval_steps * batch_size / elapsed
      gpu_info = '' if num_gpus <= 1 else ' gpus:[{}]'.format(num_gpus)
      if num_steps_per_epoch is None:
        epoch_time_info = ''
      else:
        hours_per_epoch = num_steps_per_epoch / interval_steps * elapsed / 3600
        epoch_time_info = '1epoch:[{:.2f}h]'.format(hours_per_epoch)
      info.write('elapsed:[{:.2f}] batch_size:[{}]{} batches/s:[{:.2f}] insts/s:[{:.2f}] {} lr:[{:.6f}]'.format(elapsed, batch_size, gpu_info, steps_per_second, instances_per_second, epoch_time_info, learning_rate))

    if print_avg_loss:
      #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
      names_ = melt.adjust_names(train_average_loss, names)
      #info.write('train_avg_metric:{} '.format(melt.parse_results(train_average_loss, names_)))
      info.write(' train:{} '.format(melt.parse_results(train_average_loss, names_)))
      #info.write('train_avg_loss: {} '.format(train_average_loss))
    info.write(eval_str)

    #print(gezi.now_time(), epoch_str, 'train_step:%d' % step, info.getvalue(), end=' ')
    logging.info2('{} {} {}'.format(epoch_str, 'step:%d' % step, info.getvalue()))

    if deal_results_fn is not None:
      stop = deal_results_fn(results)

  summary_strs = gezi.to_list(summary_str)
  if metric_evaluate:
    if evaluate_summaries is not None:
      summary_strs += evaluate_summaries

  if step > 1:
    if is_eval_step:
      # deal with summary
      if log_dir:
        summary = tf.Summary()
        if eval_ops is None:
          if train_once.summary_op is not None:
            for summary_str in summary_strs:
              train_once.summary_writer.add_summary(summary_str, step)
        else:
          for summary_str in summary_strs:
            train_once.summary_writer.add_summary(summary_str, step)
          suffix = 'valid' if not eval_names else ''
          # loss/valid
          melt.add_summarys(summary, eval_results, eval_names_, suffix=suffix)

        if ops is not None:
          try:
            # loss/train_avg
            melt.add_summarys(summary, train_average_loss, names_, suffix='train_avg')
          except Exception:
            pass
          ##optimizer has done this also
          melt.add_summary(summary, learning_rate, 'learning_rate')
          melt.add_summary(summary, melt.batch_size(), 'batch_size', prefix='other')
          melt.add_summary(summary, melt.epoch(), 'epoch', prefix='other')
          if steps_per_second:
            melt.add_summary(summary, steps_per_second, 'steps_per_second', prefix='perf')
          if instances_per_second:
            melt.add_summary(summary, instances_per_second, 'instances_per_second', prefix='perf')
          if hours_per_epoch:
            melt.add_summary(summary, hours_per_epoch, 'hours_per_epoch', prefix='perf')

        if metric_evaluate:
          #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
          prefix = 'step_eval'
          if model_path:
            prefix = 'eval'
            if not hasattr(train_once, 'epoch_step'):
              train_once.epoch_step = 1
            else:
              train_once.epoch_step += 1
            step = train_once.epoch_step
          # eval/loss eval/auc ..
          melt.add_summarys(summary, evaluate_results, evaluate_names, prefix=prefix)

        train_once.summary_writer.add_summary(summary, step)
        train_once.summary_writer.flush()
      return stop
    elif metric_evaluate and log_dir:
      summary = tf.Summary()
      for summary_str in summary_strs:
        train_once.summary_writer.add_summary(summary_str, step)
      #summary.ParseFromString(evaluate_summaries)
      summary_writer = train_once.summary_writer
      prefix = 'step_eval'
      if model_path:
        prefix = 'eval'
        if not hasattr(train_once, 'epoch_step'):
          ## TODO.. restart will get 1 again..
          #epoch_step = tf.Variable(0, trainable=False, name='epoch_step')
          #epoch_step += 1
          #train_once.epoch_step = sess.run(epoch_step)
          valid_interval_epochs = 1.
          try:
            valid_interval_epochs = FLAGS.valid_interval_epochs
          except Exception:
            pass
          train_once.epoch_step = 1 if melt.epoch() <= 1 else int(int(melt.epoch() * 10) / int(valid_interval_epochs * 10))
          logging.info('train_once epoch start step is', train_once.epoch_step)
        else:
          #epoch_step += 1
          train_once.epoch_step += 1
        step = train_once.epoch_step
      #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
      melt.add_summarys(summary, evaluate_results, evaluate_names, prefix=prefix)
      summary_writer.add_summary(summary, step)
      summary_writer.flush()
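# --------------------------------------------------------------------------
# Standalone sketch (not from the original code) of the patience-based decay
# that train_once applies above, shown with a plain float lr instead of the
# graph variable held in ops[1].
def _example_patience_decay(valid_losses, lr=0.1, patience=3, decay_factor=0.5):
  best, waited = float('inf'), 0
  for loss in valid_losses:
    if loss < best:
      best, waited = loss, 0  # an improvement resets the patience counter
    else:
      waited += 1
      if waited >= patience:
        lr *= decay_factor    # decay once patience is exhausted
        best, waited = loss, 0
  return lr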
def train_once(sess, step, ops, names=None, gen_feed_dict=None, deal_results=melt.print_results,
               interval_steps=100, eval_ops=None, eval_names=None, gen_eval_feed_dict=None,
               deal_eval_results=melt.print_results, eval_interval_steps=100, print_time=True,
               print_avg_loss=True, model_dir=None, log_dir=None, is_start=False,
               num_steps_per_epoch=None, metric_eval_function=None, metric_eval_interval_steps=0):
  timer = gezi.Timer()
  if print_time:
    if not hasattr(train_once, 'timer'):
      train_once.timer = Timer()
      train_once.eval_timer = Timer()
      train_once.metric_eval_timer = Timer()

  melt.set_global('step', step)
  epoch = step / num_steps_per_epoch if num_steps_per_epoch else -1
  epoch_str = 'epoch:%.4f' % (epoch) if num_steps_per_epoch else ''
  melt.set_global('epoch', '%.4f' % (epoch))

  info = BytesIO()
  stop = False

  if ops is not None:
    if deal_results is None and names is not None:
      deal_results = lambda x: melt.print_results(x, names)
    if deal_eval_results is None and eval_names is not None:
      deal_eval_results = lambda x: melt.print_results(x, eval_names)
    if eval_names is None:
      eval_names = names

    feed_dict = {} if gen_feed_dict is None else gen_feed_dict()
    results = sess.run(ops, feed_dict=feed_dict)
    # #--------trace debug
    # if step == 210:
    #   run_metadata = tf.RunMetadata()
    #   results = sess.run(
    #       ops,
    #       feed_dict=feed_dict,
    #       options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
    #       run_metadata=run_metadata)
    #   from tensorflow.python.client import timeline
    #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    #   trace_file = open('timeline.ctf.json', 'w')
    #   trace_file.write(trace.generate_chrome_trace_format())

    # results[0] is assumed to be train_op
    results = results[1:]

    #@TODO should support avg loss and other avg evaluations like test..
    if print_avg_loss:
      if not hasattr(train_once, 'avg_loss'):
        train_once.avg_loss = AvgScore()
        if interval_steps != eval_interval_steps:
          train_once.avg_loss2 = AvgScore()
      #assume results[0] as train_op return, results[1] as loss
      loss = gezi.get_singles(results)
      train_once.avg_loss.add(loss)
      if interval_steps != eval_interval_steps:
        train_once.avg_loss2.add(loss)

    if is_start or interval_steps and step % interval_steps == 0:
      train_average_loss = train_once.avg_loss.avg_score()
      if print_time:
        duration = timer.elapsed()
        duration_str = 'duration:{:.3f} '.format(duration)
        melt.set_global('duration', '%.3f' % duration)
        info.write(duration_str)
        elapsed = train_once.timer.elapsed()
        steps_per_second = interval_steps / elapsed
        batch_size = melt.batch_size()
        num_gpus = melt.num_gpus()
        instances_per_second = interval_steps * batch_size * num_gpus / elapsed
        if num_gpus == 1:
          info.write('elapsed:[{:.3f}] batch_size:[{}] batches/s:[{:.2f}] insts/s:[{:.2f}] '.format(elapsed, batch_size, steps_per_second, instances_per_second))
        else:
          info.write('elapsed:[{:.3f}] batch_size:[{}] gpus:[{}], batches/s:[{:.2f}] insts/s:[{:.2f}] '.format(elapsed, batch_size, num_gpus, steps_per_second, instances_per_second))
      if print_avg_loss:
        #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
        names_ = melt.adjust_names(train_average_loss, names)
        info.write('train_avg_metrics:{} '.format(melt.parse_results(train_average_loss, names_)))
        #info.write('train_avg_loss: {} '.format(train_average_loss))
      #print(gezi.now_time(), epoch_str, 'train_step:%d' % step, info.getvalue(), end=' ')
      logging.info2('{} {} {}'.format(epoch_str, 'train_step:%d' % step, info.getvalue()))
      if deal_results is not None:
        stop = deal_results(results)

  metric_evaluate = False
  # if metric_eval_function is not None \
  #   and ( (is_start and (step or ops is None))\
  #     or (step and ((num_steps_per_epoch and step % num_steps_per_epoch == 0) \
  #       or (metric_eval_interval_steps \
  #         and step % metric_eval_interval_steps == 0)))):
  #   metric_evaluate = True
  if metric_eval_function is not None \
      and (is_start \
        or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
        or (metric_eval_interval_steps \
          and step % metric_eval_interval_steps == 0)):
    metric_evaluate = True

  if metric_evaluate:
    evaluate_results, evaluate_names = metric_eval_function()

  if is_start or eval_interval_steps and step % eval_interval_steps == 0:
    if ops is not None:
      if interval_steps != eval_interval_steps:
        train_average_loss = train_once.avg_loss2.avg_score()
      info = BytesIO()
      names_ = melt.adjust_names(results, names)
      train_average_loss_str = ''
      if print_avg_loss and interval_steps != eval_interval_steps:
        train_average_loss_str = melt.value_name_list_str(train_average_loss, names_)
        melt.set_global('train_loss', train_average_loss_str)
        train_average_loss_str = 'train_avg_loss:{} '.format(train_average_loss_str)
      if interval_steps != eval_interval_steps:
        #end = '' if eval_ops is None else '\n'
        #print(gezi.now_time(), epoch_str, 'eval_step: %d' % step, train_average_loss_str, end=end)
        logging.info2('{} eval_step: {} {}'.format(epoch_str, step, train_average_loss_str))

    if eval_ops is not None:
      eval_feed_dict = {} if gen_eval_feed_dict is None else gen_eval_feed_dict()
      #eval_feed_dict.update(feed_dict)

      #------show how to do perf debug
      ##timer_ = gezi.Timer('sess run generate')
      ##sess.run(eval_ops[-2], feed_dict=None)
      ##timer_.print()
      timer_ = gezi.Timer('sess run eval_ops')
      eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
      timer_.print()

      if deal_eval_results is not None:
        # @TODO user print should also use logging as a must?
        #print(gezi.now_time(), epoch_str, 'eval_step: %d' % step, 'eval_metrics:', end='')
        logging.info2('{} eval_step: {} eval_metrics:'.format(epoch_str, step))
        eval_stop = deal_eval_results(eval_results)
        eval_loss = gezi.get_singles(eval_results)
        assert len(eval_loss) > 0
        if eval_stop is True:
          stop = True
        eval_names_ = melt.adjust_names(eval_loss, eval_names)
        melt.set_global('eval_loss', melt.parse_results(eval_loss, eval_names_))
    elif interval_steps != eval_interval_steps:
      #print()
      pass

    if log_dir:
      #timer_ = gezi.Timer('writing log')
      if not hasattr(train_once, 'summary_op'):
        try:
          train_once.summary_op = tf.summary.merge_all()
        except Exception:
          train_once.summary_op = tf.merge_all_summaries()
        melt.print_summary_ops()
        try:
          train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
          train_once.summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        except Exception:
          train_once.summary_train_op = tf.merge_all_summaries(key=melt.MonitorKeys.TRAIN)
          train_once.summary_writer = tf.train.SummaryWriter(log_dir, sess.graph)
        tf.contrib.tensorboard.plugins.projector.visualize_embeddings(train_once.summary_writer, projector_config)

      summary = tf.Summary()
      # the strategy: on eval_interval_steps, if there is an eval dataset then tensorboard
      # evaluates on it; if not, evaluate on the train set. If there is an eval dataset we
      # also monitor the train loss
      if train_once.summary_train_op is not None:
        summary_str = sess.run(train_once.summary_train_op, feed_dict=feed_dict)
        train_once.summary_writer.add_summary(summary_str, step)

      if eval_ops is None:
        # get train loss, for every batch train
        if train_once.summary_op is not None:
          #timer2 = gezi.Timer('sess run')
          summary_str = sess.run(train_once.summary_op, feed_dict=feed_dict)
          #timer2.print()
          train_once.summary_writer.add_summary(summary_str, step)
      else:
        # get eval loss for every batch eval, then add train loss for eval step average loss
        summary_str = sess.run(train_once.summary_op, feed_dict=eval_feed_dict) if train_once.summary_op is not None else ''
        # all single-value results will be added to summary here, not via tf.scalar_summary..
        summary.ParseFromString(summary_str)
        melt.add_summarys(summary, eval_results, eval_names_, suffix='eval')

      melt.add_summarys(summary, train_average_loss, names_, suffix='train_avg%dsteps' % eval_interval_steps)
      if metric_evaluate:
        melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='evaluate')
      train_once.summary_writer.add_summary(summary, step)
      train_once.summary_writer.flush()
      #timer_.print()

    if print_time:
      full_duration = train_once.eval_timer.elapsed()
      if metric_evaluate:
        metric_full_duration = train_once.metric_eval_timer.elapsed()
      full_duration_str = 'elapsed:{:.3f} '.format(full_duration)
      #info.write('duration:{:.3f} '.format(timer.elapsed()))
      duration = timer.elapsed()
      info.write('duration:{:.3f} '.format(duration))
      info.write(full_duration_str)
      info.write('eval_time_ratio:{:.3f} '.format(duration / full_duration))
      if metric_evaluate:
        info.write('metric_time_ratio:{:.3f} '.format(duration / metric_full_duration))
      #print(gezi.now_time(), epoch_str, 'eval_step: %d' % step, info.getvalue())
      logging.info2('{} {} {}'.format(epoch_str, 'eval_step: %d' % step, info.getvalue()))

  return stop
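# --------------------------------------------------------------------------
# Note on the -(-a // b) idiom used for num_steps_per_epoch throughout this
# file: it is ceiling division, shown here as a tiny self-contained helper.
def _example_steps_per_epoch(num_examples, batch_size):
  return -(-num_examples // batch_size)  # e.g. 10 examples, batch 3 -> 4 steps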
def train(Dataset, model, loss_fn, evaluate_fn=None, inference_fn=None, eval_fn=None,
          write_valid=True, valid_names=None, infer_names=None, infer_debug_names=None,
          valid_write_fn=None, infer_write_fn=None, valid_suffix='.valid',
          infer_suffix='.infer', write_streaming=False, sep=','):
  if FLAGS.torch:
    if torch.cuda.is_available():
      model.cuda()

  input_ = FLAGS.train_input
  inputs = gezi.list_files(input_)
  inputs.sort()
  all_inputs = inputs

  batch_size = FLAGS.batch_size
  num_gpus = melt.num_gpus()
  if num_gpus > 1:
    assert False, 'eager mode train is currently not supported for num_gpus > 1'
  #batch_size_ = batch_size if not FLAGS.batch_sizes else int(FLAGS.batch_sizes.split(',')[-1])
  batch_size_ = batch_size

  if FLAGS.fold is not None:
    inputs = [x for x in inputs if not x.endswith('%d.record' % FLAGS.fold)]
  logging.info('inputs', inputs)

  dataset = Dataset('train')
  num_examples = dataset.num_examples_per_epoch('train')
  num_all_examples = num_examples

  # if FLAGS.fold is not None:
  #   valid_inputs = [x for x in all_inputs if x not in inputs]
  # else:
  #   valid_inputs = gezi.list_files(FLAGS.valid_input)
  # logging.info('valid_inputs', valid_inputs)
  # if valid_inputs:
  #   valid_dataset_ = Dataset('valid')
  #   valid_dataset = valid_dataset_.make_batch(batch_size_, valid_inputs)
  #   valid_dataset2 = valid_dataset_.make_batch(batch_size_, valid_inputs, repeat=True)
  # else:
  #   valid_dataset = None
  #   valid_dataset2 = None

  if num_examples:
    if FLAGS.fold is not None:
      num_examples = int(num_examples * (len(inputs) / (len(inputs) + 1)))
    num_steps_per_epoch = -(-num_examples // batch_size)
  else:
    num_steps_per_epoch = None

  # if FLAGS.fold is not None:
  #   if num_examples:
  #     num_valid_examples = int(num_all_examples * (1 / (len(inputs) + 1)))
  #     num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_)
  #   else:
  #     num_valid_steps_per_epoch = None
  # else:
  #   num_valid_examples = valid_dataset_.num_examples_per_epoch('valid')
  #   num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_) if num_valid_examples else None

  # test_inputs = gezi.list_files(FLAGS.test_input)
  # logging.info('test_inputs', test_inputs)
  # if test_inputs:
  #   test_dataset_ = Dataset('test')
  #   test_dataset = test_dataset_.make_batch(batch_size_, test_inputs)
  #   num_test_examples = test_dataset_.num_examples_per_epoch('test')
  #   num_test_steps_per_epoch = -(-num_test_examples // batch_size_) if num_test_examples else None
  # else:
  #   test_dataset = None

  summary = tf.contrib.summary
  # writer = summary.create_file_writer(FLAGS.model_dir + '/epoch')
  # writer_train = summary.create_file_writer(FLAGS.model_dir + '/train')
  # writer_valid = summary.create_file_writer(FLAGS.model_dir + '/valid')
  writer = summary.create_file_writer(FLAGS.model_dir)
  writer_train = summary.create_file_writer(FLAGS.model_dir)
  writer_valid = summary.create_file_writer(FLAGS.model_dir)
  global_step = tf.train.get_or_create_global_step()

  learning_rate = tfe.Variable(FLAGS.learning_rate, name="learning_rate")
  tf.add_to_collection('learning_rate', learning_rate)
  learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
  try:
    learning_rate_weights = tf.get_collection('learning_rate_weights')[-1]
  except Exception:
    learning_rate_weights = None

  ckpt_dir = FLAGS.model_dir + '/ckpt'
  #TODO FIXME now I just changed tf code so as not to save only the latest 5 by default
  # refer to https://github.com/tensorflow/tensorflow/issues/22036
  # manager = tf.contrib.checkpoint.CheckpointManager(
  #     checkpoint, directory=ckpt_dir, max_to_keep=5)
  # latest_checkpoint = manager.latest_checkpoint
  latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
  logging.info('Latest checkpoint:', latest_checkpoint)
  checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt')

  if not FLAGS.torch:
    optimizer = melt.get_optimizer(FLAGS.optimizer)(learning_rate)
    # TODO...
    if learning_rate_weights is None:
      checkpoint = tf.train.Checkpoint(
          learning_rate=learning_rate,
          learning_rate_weight=learning_rate_weight,
          model=model,
          optimizer=optimizer,
          global_step=global_step)
    else:
      checkpoint = tf.train.Checkpoint(
          learning_rate=learning_rate,
          learning_rate_weight=learning_rate_weight,
          learning_rate_weights=learning_rate_weights,
          model=model,
          optimizer=optimizer,
          global_step=global_step)

    if os.path.exists(FLAGS.model_dir + '.index'):
      latest_checkpoint = FLAGS.model_dir

    checkpoint.restore(latest_checkpoint)
    start_epoch = int(latest_checkpoint.split('-')[-1]) if latest_checkpoint else 0
  else:
    # TODO torch with learning rate adjust
    optimizer = torch.optim.Adamax(model.parameters(), lr=FLAGS.learning_rate)

    if latest_checkpoint:
      checkpoint = torch.load(latest_checkpoint + '.pyt')
      start_epoch = checkpoint['epoch']
      model.load_state_dict(checkpoint['state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer'])
      model.eval()
    else:
      start_epoch = 0

    if learning_rate_weights is None:
      checkpoint = tf.train.Checkpoint(
          learning_rate=learning_rate,
          learning_rate_weight=learning_rate_weight,
          global_step=global_step)
    else:
      checkpoint = tf.train.Checkpoint(
          learning_rate=learning_rate,
          learning_rate_weight=learning_rate_weight,
          learning_rate_weights=learning_rate_weights,
          global_step=global_step)

  #model.load_weights(os.path.join(ckpt_dir, 'ckpt-1'))
  #model.save('./weight3.hd5')

  # TODO fractional values like 0.1 epoch are currently not supported here
  num_epochs = FLAGS.num_epochs

  class PytObj(object):
    def __init__(self, x):
      self.x = x

    def numpy(self):
      return self.x

  class PytMean(object):
    def __init__(self):
      self._val = 0.
      self.count = 0
      self.is_call = True

    def clear(self):
      self._val = 0
      self.count = 0

    def __call__(self, val):
      if not self.is_call:
        self.clear()
        self.is_call = True
      self._val += val.item()
      self.count += 1

    def result(self):
      if self.is_call:
        self.is_call = False
      if not self.count:
        val = 0
      else:
        val = self._val / self.count
      # TODO just for compat with tf ..
      return PytObj(val)

  # TODO consider multiple gpu for torch
  iter = dataset.make_batch(batch_size, inputs, repeat=False, initializable=False)
  batch = iter.get_next()
  #x, y = melt.split_batch(batch, batch_size, num_gpus)
  x_, y_ = batch

  Mean = tfe.metrics.Mean if not FLAGS.torch else PytMean
  epoch_loss_avg = Mean()
  epoch_valid_loss_avg = Mean()

  sess = melt.get_session(device_count={'GPU': 0})
  global_step = 0
  for epoch in range(start_epoch, num_epochs):
    melt.set_global('epoch', '%.4f' % (epoch))
    sess.run(iter.initializer)

    model.train()
    #..... still OOM... FIXME TODO
    try:
      for _ in tqdm(range(num_steps_per_epoch), total=num_steps_per_epoch, ascii=True):
        x, y = sess.run([x_, y_])
        x, y = to_torch(x, y)
        optimizer.zero_grad()
        loss = loss_fn(model, x, y)
        loss.backward()
        optimizer.step()

        epoch_loss_avg(loss)

        if global_step % FLAGS.interval_steps == 0:
          print(global_step, epoch_loss_avg.result().numpy())
        global_step += 1
    except tf.errors.OutOfRangeError:
      print('epoch:%d loss:%f' % (epoch, epoch_loss_avg.result().numpy()))
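# --------------------------------------------------------------------------
# Usage sketch (not from the original code) for the PytMean defined inside
# train() above; since that class is function-local, it is passed in here. It
# mimics tfe.metrics.Mean: each call accumulates a value, result() returns the
# mean, and the first call after result() starts a fresh accumulation.
class _FakeLoss(object):
  # stands in for a torch scalar tensor exposing .item()
  def __init__(self, v):
    self.v = v

  def item(self):
    return self.v


def _example_mean_usage(Mean):
  m = Mean()
  m(_FakeLoss(1.0))
  m(_FakeLoss(3.0))
  avg = m.result().numpy()  # 2.0
  m(_FakeLoss(10.0))        # first call after result() clears the accumulator
  return avg, m.result().numpy()  # (2.0, 10.0)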
def train_once( sess, step, ops, names=None, gen_feed_dict_fn=None, deal_results_fn=None, interval_steps=100, eval_ops=None, eval_names=None, gen_eval_feed_dict_fn=None, deal_eval_results_fn=melt.print_results, eval_interval_steps=100, print_time=True, print_avg_loss=True, model_dir=None, log_dir=None, is_start=False, num_steps_per_epoch=None, metric_eval_fn=None, metric_eval_interval_steps=0, summary_excls=None, fixed_step=None, # for epoch only, incase you change batch size eval_loops=1, learning_rate=None, learning_rate_patience=None, learning_rate_decay_factor=None, num_epochs=None, model_path=None, ): #is_start = False # force not to evaluate at first step #print('-----------------global_step', sess.run(tf.train.get_or_create_global_step())) timer = gezi.Timer() if print_time: if not hasattr(train_once, 'timer'): train_once.timer = Timer() train_once.eval_timer = Timer() train_once.metric_eval_timer = Timer() melt.set_global('step', step) epoch = (fixed_step or step) / num_steps_per_epoch if num_steps_per_epoch else -1 if not num_epochs: epoch_str = 'epoch:%.3f' % (epoch) if num_steps_per_epoch else '' else: epoch_str = 'epoch:%.3f/%d' % ( epoch, num_epochs) if num_steps_per_epoch else '' melt.set_global('epoch', '%.2f' % (epoch)) info = IO() stop = False if eval_names is None: if names: eval_names = ['eval/' + x for x in names] if names: names = ['train/' + x for x in names] if eval_names: eval_names = ['eval/' + x for x in eval_names] is_eval_step = is_start or eval_interval_steps and step % eval_interval_steps == 0 summary_str = [] if is_eval_step: # deal with summary if log_dir: if not hasattr(train_once, 'summary_op'): #melt.print_summary_ops() if summary_excls is None: train_once.summary_op = tf.summary.merge_all() else: summary_ops = [] for op in tf.get_collection(tf.GraphKeys.SUMMARIES): for summary_excl in summary_excls: if not summary_excl in op.name: summary_ops.append(op) print('filtered summary_ops:') for op in summary_ops: print(op) train_once.summary_op = tf.summary.merge(summary_ops) #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN) train_once.summary_writer = tf.summary.FileWriter( log_dir, sess.graph) tf.contrib.tensorboard.plugins.projector.visualize_embeddings( train_once.summary_writer, projector_config) if eval_ops is not None: #if deal_eval_results_fn is None and eval_names is not None: # deal_eval_results_fn = lambda x: melt.print_results(x, eval_names) for i in range(eval_loops): eval_feed_dict = {} if gen_eval_feed_dict_fn is None else gen_eval_feed_dict_fn( ) #eval_feed_dict.update(feed_dict) # might use EVAL_NO_SUMMARY if some old code has problem TODO CHECK if not log_dir or train_once.summary_op is None or gezi.env_has( 'EVAL_NO_SUMMARY'): #if not log_dir or train_once.summary_op is None: eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict) else: eval_results = sess.run(eval_ops + [train_once.summary_op], feed_dict=eval_feed_dict) summary_str = eval_results[-1] eval_results = eval_results[:-1] eval_loss = gezi.get_singles(eval_results) #timer_.print() eval_stop = False # @TODO user print should also use logging as a must ? 
  if is_eval_step:
    # deal with summary
    if log_dir:
      if not hasattr(train_once, 'summary_op'):
        #melt.print_summary_ops()
        if summary_excls is None:
          train_once.summary_op = tf.summary.merge_all()
        else:
          summary_ops = []
          for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
            # keep an op only if none of the excluded substrings occur in its name
            if not any(summary_excl in op.name for summary_excl in summary_excls):
              summary_ops.append(op)
          print('filtered summary_ops:')
          for op in summary_ops:
            print(op)
          train_once.summary_op = tf.summary.merge(summary_ops)
        #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
        train_once.summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
            train_once.summary_writer, projector_config)

    if eval_ops is not None:
      #if deal_eval_results_fn is None and eval_names is not None:
      #  deal_eval_results_fn = lambda x: melt.print_results(x, eval_names)
      for i in range(eval_loops):
        eval_feed_dict = {} if gen_eval_feed_dict_fn is None else gen_eval_feed_dict_fn()
        #eval_feed_dict.update(feed_dict)

        # might use EVAL_NO_SUMMARY if some old code has problem TODO CHECK
        if not log_dir or train_once.summary_op is None or gezi.env_has('EVAL_NO_SUMMARY'):
          eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
        else:
          eval_results = sess.run(eval_ops + [train_once.summary_op],
                                  feed_dict=eval_feed_dict)
          summary_str = eval_results[-1]
          eval_results = eval_results[:-1]
        eval_loss = gezi.get_singles(eval_results)
        #timer_.print()
        eval_stop = False

        # @TODO user print should also use logging as a must?
        #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, 'eval_metrics:', end='')
        eval_names_ = melt.adjust_names(eval_loss, eval_names)
        logging.info2('{} eval_step:{} eval_metrics:{}'.format(
            epoch_str, step, melt.parse_results(eval_loss, eval_names_)))

      # if deal_eval_results_fn is not None:
      #   eval_stop = deal_eval_results_fn(eval_results)
      assert len(eval_loss) > 0
      if eval_stop is True:
        stop = True
      eval_names_ = melt.adjust_names(eval_loss, eval_names)
      melt.set_global('eval_loss', melt.parse_results(eval_loss, eval_names_))
  elif interval_steps != eval_interval_steps:
    #print()
    pass

  metric_evaluate = False
  # if metric_eval_fn is not None \
  #     and (is_start \
  #          or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
  #          or (metric_eval_interval_steps and step % metric_eval_interval_steps == 0)):
  #   metric_evaluate = True
  if metric_eval_fn is not None \
      and (is_start
           or (metric_eval_interval_steps and step % metric_eval_interval_steps == 0)
           or model_path):
    metric_evaluate = True

  #if (is_start or step == 0) and (not 'EVFIRST' in os.environ):
  if ((step == 0) and ('EVFIRST' not in os.environ)) \
      or ('QUICK' in os.environ) \
      or ('EVFIRST' in os.environ and os.environ['EVFIRST'] == '0'):
    metric_evaluate = False

  if metric_evaluate:
    # TODO better
    if not model_path or 'model_path' not in inspect.getargspec(metric_eval_fn).args:
      l = metric_eval_fn()
      if len(l) == 2:
        evaluate_results, evaluate_names = l
        evaluate_summaries = None
      else:
        evaluate_results, evaluate_names, evaluate_summaries = l
    else:
      try:
        l = metric_eval_fn(model_path=model_path)
        if len(l) == 2:
          evaluate_results, evaluate_names = l
          evaluate_summaries = None
        else:
          evaluate_results, evaluate_names, evaluate_summaries = l
      except Exception:
        logging.info('Do nothing for metric eval fn with exception:\n',
                     traceback.format_exc())

    logging.info2('{} valid_step:{} {}:{}'.format(
        epoch_str, step,
        'valid_metrics' if model_path is None else 'epoch_valid_metrics',
        melt.parse_results(evaluate_results, evaluate_names)))

    if learning_rate is not None and (learning_rate_patience
                                      and learning_rate_patience > 0):
      assert learning_rate_decay_factor > 0 and learning_rate_decay_factor < 1
      valid_loss = evaluate_results[0]
      if not hasattr(train_once, 'min_valid_loss'):
        train_once.min_valid_loss = valid_loss
        train_once.decay_steps = []
        train_once.patience = 0
      else:
        if valid_loss < train_once.min_valid_loss:
          train_once.min_valid_loss = valid_loss
          train_once.patience = 0
        else:
          train_once.patience += 1
          logging.info2('{} valid_step:{} patience:{}'.format(
              epoch_str, step, train_once.patience))
      if learning_rate_patience and train_once.patience >= learning_rate_patience:
        lr_op = ops[1]
        lr = sess.run(lr_op) * learning_rate_decay_factor
        train_once.decay_steps.append(step)
        logging.info2(
            '{} valid_step:{} learning_rate_decay by *{}, learning_rate_decay_steps={}'
            .format(epoch_str, step, learning_rate_decay_factor,
                    ','.join(map(str, train_once.decay_steps))))
        sess.run(tf.assign(lr_op, tf.constant(lr, dtype=tf.float32)))
        train_once.patience = 0
        train_once.min_valid_loss = valid_loss
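
  # --- illustrative sketch (hypothetical helper, not part of train_once) ---
  # The plateau schedule above in isolation: decay the lr by a constant factor
  # whenever valid loss has not improved for `patience` consecutive evaluations.
  def _plateau_decay(valid_losses, lr, patience, decay_factor):
    best, bad = float('inf'), 0
    for loss in valid_losses:
      if loss < best:
        best, bad = loss, 0
      else:
        bad += 1
      if bad >= patience:
        lr *= decay_factor
        best, bad = loss, 0  # mirrors train_once resetting min_valid_loss after a decay
    return lr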
  if ops is not None:
    #if deal_results_fn is None and names is not None:
    #  deal_results_fn = lambda x: melt.print_results(x, names)
    feed_dict = {} if gen_feed_dict_fn is None else gen_feed_dict_fn()

    # NOTICE ops[2] should be scalar otherwise wrong!! loss should be scalar
    #print('---------------ops', ops)
    if eval_ops is not None or not log_dir \
        or not hasattr(train_once, 'summary_op') or train_once.summary_op is None:
      results = sess.run(ops, feed_dict=feed_dict)
    else:
      results = sess.run(ops + [train_once.summary_op], feed_dict=feed_dict)
      summary_str = results[-1]
      results = results[:-1]
    # except Exception:
    #   logging.info('sess.run(ops + [train_once.summary_op], feed_dict=feed_dict) fail')
    #   results = sess.run(ops, feed_dict=feed_dict)
    #print('------------results', results)

    # #--------trace debug
    # if step == 210:
    #   run_metadata = tf.RunMetadata()
    #   results = sess.run(
    #       ops,
    #       feed_dict=feed_dict,
    #       options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
    #       run_metadata=run_metadata)
    #   from tensorflow.python.client import timeline
    #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    #   trace_file = open('timeline.ctf.json', 'w')
    #   trace_file.write(trace.generate_chrome_trace_format())

    # results[0] assumed to be train_op, results[1] to be learning_rate
    learning_rate = results[1]
    results = results[2:]

    #@TODO should support avg loss and other avg evaluations like test..
    if print_avg_loss:
      if not hasattr(train_once, 'avg_loss'):
        train_once.avg_loss = AvgScore()
        if interval_steps != eval_interval_steps:
          train_once.avg_loss2 = AvgScore()
      # after slicing, results hold the remaining single (loss) values
      loss = gezi.get_singles(results)
      train_once.avg_loss.add(loss)
      if interval_steps != eval_interval_steps:
        train_once.avg_loss2.add(loss)

  steps_per_second = None
  instances_per_second = None
  hours_per_epoch = None
  #step += 1
  if is_start or (interval_steps and step % interval_steps == 0):
    train_average_loss = train_once.avg_loss.avg_score()
    if print_time:
      duration = timer.elapsed()
      duration_str = 'duration:{:.3f} '.format(duration)
      melt.set_global('duration', '%.3f' % duration)
      info.write(duration_str)
      elapsed = train_once.timer.elapsed()
      steps_per_second = interval_steps / elapsed
      batch_size = melt.batch_size()
      num_gpus = melt.num_gpus()
      instances_per_second = interval_steps * batch_size / elapsed
      gpu_info = '' if num_gpus <= 1 else ' gpus:[{}]'.format(num_gpus)
      if num_steps_per_epoch is None:
        epoch_time_info = ''
      else:
        hours_per_epoch = num_steps_per_epoch / interval_steps * elapsed / 3600
        epoch_time_info = ' 1epoch:[{:.2f}h]'.format(hours_per_epoch)
      info.write(
          'elapsed:[{:.3f}] batch_size:[{}]{} batches/s:[{:.2f}] insts/s:[{:.2f}] {} lr:[{:.8f}]'
          .format(elapsed, batch_size, gpu_info, steps_per_second,
                  instances_per_second, epoch_time_info, learning_rate))

    if print_avg_loss:
      #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
      names_ = melt.adjust_names(train_average_loss, names)
      #info.write('train_avg_metric:{} '.format(melt.parse_results(train_average_loss, names_)))
      info.write(' train:{} '.format(melt.parse_results(train_average_loss, names_)))
      #info.write('train_avg_loss: {} '.format(train_average_loss))

    #print(gezi.now_time(), epoch_str, 'train_step:%d'%step, info.getvalue(), end=' ')
    logging.info2('{} {} {}'.format(epoch_str, 'step:%d' % step, info.getvalue()))
    if deal_results_fn is not None:
      stop = deal_results_fn(results)

  summary_strs = gezi.to_list(summary_str)
  if metric_evaluate:
    if evaluate_summaries is not None:
      summary_strs += evaluate_summaries
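
  # --- illustrative sketch (hypothetical helper, not part of train_once) ---
  # The throughput numbers logged above, computed in isolation; elapsed is the
  # wall time since the previous interval boundary.
  def _throughput(interval_steps, elapsed, batch_size, num_steps_per_epoch):
    steps_per_second = interval_steps / elapsed
    instances_per_second = steps_per_second * batch_size
    hours_per_epoch = num_steps_per_epoch / interval_steps * elapsed / 3600
    return steps_per_second, instances_per_second, hours_per_epoch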
  if step > 1:
    if is_eval_step:
      # deal with summary (summary_op and summary_writer were set up above on
      # the first eval step)
      if log_dir:
        summary = tf.Summary()
        # The strategy: on eval_interval_steps, if there is an eval dataset then
        # tensorboard evaluates on it; otherwise we evaluate on the train set.
        # With an eval dataset we still monitor train loss as well.
        if eval_ops is None:
          # TODO FIXME re-running summary_op here used to mean one more train
          # batch step without adding to the global step counter ?! so the
          # summary fetch was folded into the main sess.run above.
          if train_once.summary_op is not None:
            for summary_str in summary_strs:
              train_once.summary_writer.add_summary(summary_str, step)
        else:
          # eval summaries were already fetched together with eval_ops above;
          # all single value results will be added to summary here, not using
          # tf.scalar_summary..
          #summary.ParseFromString(summary_str)
          for summary_str in summary_strs:
            train_once.summary_writer.add_summary(summary_str, step)
          suffix = 'eval' if not eval_names else ''
          melt.add_summarys(summary, eval_results, eval_names_, suffix=suffix)
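
        # --- illustrative sketch (hypothetical helper, not part of the original flow) ---
        # melt.add_summary above hand-builds scalar entries; with plain tf 1.x
        # summary protos the equivalent is roughly:
        def _write_scalar(writer, tag, value, step):
          s = tf.Summary()
          s.value.add(tag=tag, simple_value=float(value))
          writer.add_summary(s, step)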
        if ops is not None:
          melt.add_summarys(summary, train_average_loss, names_, suffix='train_avg')
          ##optimizer has done this also
          melt.add_summary(summary, learning_rate, 'learning_rate')
          melt.add_summary(summary, melt.batch_size(), 'batch_size')
          melt.add_summary(summary, melt.epoch(), 'epoch')
          if steps_per_second:
            melt.add_summary(summary, steps_per_second, 'steps_per_second')
          if instances_per_second:
            melt.add_summary(summary, instances_per_second, 'instances_per_second')
          if hours_per_epoch:
            melt.add_summary(summary, hours_per_epoch, 'hours_per_epoch')

        if metric_evaluate:
          #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
          prefix = 'step/valid'
          if model_path:
            prefix = 'epoch/valid'
            if not hasattr(train_once, 'epoch_step'):
              train_once.epoch_step = 1
            else:
              train_once.epoch_step += 1
            step = train_once.epoch_step
          melt.add_summarys(summary, evaluate_results, evaluate_names, prefix=prefix)

        train_once.summary_writer.add_summary(summary, step)
        train_once.summary_writer.flush()
        #timer_.print()

      # if print_time:
      #   full_duration = train_once.eval_timer.elapsed()
      #   if metric_evaluate:
      #     metric_full_duration = train_once.metric_eval_timer.elapsed()
      #   full_duration_str = 'elapsed:{:.3f} '.format(full_duration)
      #   duration = timer.elapsed()
      #   info.write('duration:{:.3f} '.format(duration))
      #   info.write(full_duration_str)
      #   info.write('eval_time_ratio:{:.3f} '.format(duration / full_duration))
      #   if metric_evaluate:
      #     info.write('metric_time_ratio:{:.3f} '.format(duration / metric_full_duration))
      #   logging.info2('{} {} {}'.format(epoch_str, 'eval_step: %d' % step, info.getvalue()))
      return stop
    elif metric_evaluate:
      summary = tf.Summary()
      for summary_str in summary_strs:
        train_once.summary_writer.add_summary(summary_str, step)
      #summary.ParseFromString(evaluate_summaries)
      summary_writer = train_once.summary_writer
      prefix = 'step/valid'
      if model_path:
        prefix = 'epoch/valid'
        if not hasattr(train_once, 'epoch_step'):
          ## TODO.. restart will get 1 again..
          #epoch_step = tf.Variable(0, trainable=False, name='epoch_step')
          #epoch_step += 1
          #train_once.epoch_step = sess.run(epoch_step)
          valid_interval_epochs = 1.
          try:
            valid_interval_epochs = FLAGS.valid_interval_epochs
          except Exception:
            pass
          train_once.epoch_step = 1 if melt.epoch() <= 1 else int(
              int(melt.epoch() * 10) / int(valid_interval_epochs * 10))
          logging.info('train_once epoch start step is', train_once.epoch_step)
        else:
          #epoch_step += 1
          train_once.epoch_step += 1
        step = train_once.epoch_step
      #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
      melt.add_summarys(summary, evaluate_results, evaluate_names, prefix=prefix)
      summary_writer.add_summary(summary, step)
      summary_writer.flush()
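
# --- illustrative sketch (hypothetical helper, not part of the original file) ---
# How the restart value of train_once.epoch_step above is derived: epochs are
# counted in tenths so that fractional valid_interval_epochs (e.g. 0.5) map to
# a whole number of validation steps already performed.
def _restart_epoch_step(epoch, valid_interval_epochs=1.):
  if epoch <= 1:
    return 1
  return int(int(epoch * 10) / int(valid_interval_epochs * 10))

# e.g. _restart_epoch_step(3.0, 0.5) == 6: six validations have happened by epoch 3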