def train_once(sess, step, input_text, text, model, optimizer):
  """Run one training step of the seq2seq caption model and log the loss.

  Args:
    sess: tf.Session used only for attaching the summary FileWriter graph.
    step: int, current global step.
    input_text: encoder input batch (torch tensor/Variable — TODO confirm).
    text: decoder target batch, shape (batch, time) — 0 is assumed padding.
    model: callable model; called as model(input_text, text, feed_previous=False).
    optimizer: torch optimizer stepping the model parameters.

  Side effects: accumulates train_once.train_loss across calls and every
  FLAGS.interval_steps steps prints and writes an averaged-loss summary.
  """
  if not hasattr(train_once, 'train_loss'):
    train_once.train_loss = 0.
  # FIX: original checked hasattr(train_once, 'summary_writter') (typo) but
  # assigned 'summary_writer', so a fresh FileWriter was created on every call.
  if not hasattr(train_once, 'summary_writer'):
    log_dir = FLAGS.model_dir
    train_once.summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
  summary = tf.Summary()
  pred = model(input_text, text, feed_previous=False)
  total_loss = 0.
  total_words = 0
  batch_size = len(text)
  time_steps = text.size()[1]
  # Teacher-forced decoding: predict token t+1 from step t, skipping padding
  # (target ids equal to 0) when counting words.
  for time_step in xrange(time_steps - 1):
    y_pred = pred[time_step]
    target = text[:, time_step + 1]
    loss = criterion(y_pred, target)
    total_loss += loss
    total_words += target.data.ne(0).sum()
  # Normalize by non-padding token count rather than batch size.
  total_loss /= total_words
  #total_loss /= batch_size
  optimizer.zero_grad()
  #print('loss', total_loss)
  total_loss.backward()
  optimizer.step()
  #NOTICE! must be .data[0] other wise will consume more and more gpu mem, see
  #https://discuss.pytorch.org/t/cuda-memory-continuously-increases-when-net-images-called-in-every-iteration/501
  #https://discuss.pytorch.org/t/understanding-graphs-and-state/224/1
  train_once.train_loss += total_loss.data[0]
  steps = FLAGS.interval_steps
  if step % steps == 0:
    # FIX: was `step is 0` (identity comparison on an int).
    avg_loss = train_once.train_loss if step == 0 else train_once.train_loss / steps
    print('step:', step, 'train_loss:', avg_loss)
    train_once.train_loss = 0.
    names = melt.adjust_names([avg_loss], None)
    melt.add_summarys(summary, [avg_loss], names,
                      suffix='train_avg%dsteps' % steps)
    train_once.summary_writer.add_summary(summary, step)
    train_once.summary_writer.flush()
def train_once(
    sess,
    step,
    ops,
    names=None,
    gen_feed_dict_fn=None,
    deal_results_fn=None,
    interval_steps=100,
    eval_ops=None,
    eval_names=None,
    gen_eval_feed_dict_fn=None,
    deal_eval_results_fn=melt.print_results,
    valid_interval_steps=100,
    print_time=True,
    print_avg_loss=True,
    model_dir=None,
    log_dir=None,
    is_start=False,
    num_steps_per_epoch=None,
    metric_eval_fn=None,
    metric_eval_interval_steps=0,
    summary_excls=None,
    fixed_step=None,  # for epoch only, incase you change batch size
    eval_loops=1,
    learning_rate=None,
    learning_rate_patience=None,
    learning_rate_decay_factor=None,
    num_epochs=None,
    model_path=None,
    use_horovod=False,
):
  """Run one training step plus optional validation / metric evaluation.

  Responsibilities visible in the body:
    * runs `ops` (ops[0] assumed train_op, ops[1] learning rate, ops[2:] losses),
    * every `valid_interval_steps` runs `eval_ops` and collects a summary,
    * optionally calls `metric_eval_fn` and applies learning-rate decay on
      validation-loss patience (mutating ops[1] via tf.assign),
    * writes tf.Summary values to a FileWriter cached on the function object.

  State is cached across calls as attributes on `train_once` itself
  (timer, summary_op, summary_writer, avg_loss, patience, epoch_step, ...).
  Relies on module globals: os, tf, K, gezi, melt, hvd, logging, inspect,
  traceback, FLAGS, IO, Timer, AvgScore, projector_config.

  Returns True-ish `stop` from the eval-step path; NOTE(review): when
  step <= 1 or this is not an eval step the function falls through and
  returns None — looks intentional but confirm with callers.
  """
  # Horovod detection overrides the keyword argument unconditionally.
  use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ
  #is_start = False # force not to evaluate at first step
  #print('-----------------global_step', sess.run(tf.train.get_or_create_global_step()))
  timer = gezi.Timer()
  if print_time:
    if not hasattr(train_once, 'timer'):
      train_once.timer = Timer()
      train_once.eval_timer = Timer()
      train_once.metric_eval_timer = Timer()
  melt.set_global('step', step)
  # fixed_step lets epoch accounting survive a batch-size change.
  epoch = (fixed_step or step) / num_steps_per_epoch if num_steps_per_epoch else -1
  if not num_epochs:
    epoch_str = 'epoch:%.3f' % (epoch) if num_steps_per_epoch else ''
  else:
    epoch_str = 'epoch:%.3f/%d' % (
        epoch, num_epochs) if num_steps_per_epoch else ''
  melt.set_global('epoch', '%.2f' % (epoch))
  info = IO()
  stop = False
  # Prefix display names: train/ for training metrics, eval/ for validation.
  if eval_names is None:
    if names:
      eval_names = ['eval/' + x for x in names]
  if names:
    names = ['train/' + x for x in names]
  if eval_names:
    eval_names = ['eval/' + x for x in eval_names]
  is_eval_step = is_start or valid_interval_steps and step % valid_interval_steps == 0
  summary_str = []
  eval_str = ''
  if is_eval_step:
    # deal with summary
    if log_dir:
      if not hasattr(train_once, 'summary_op'):
        #melt.print_summary_ops()
        if summary_excls is None:
          train_once.summary_op = tf.summary.merge_all()
        else:
          # Keep only summary ops whose names avoid every exclusion pattern.
          summary_ops = []
          for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
            for summary_excl in summary_excls:
              if not summary_excl in op.name:
                summary_ops.append(op)
          print('filtered summary_ops:')
          for op in summary_ops:
            print(op)
          train_once.summary_op = tf.summary.merge(summary_ops)
        #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
        train_once.summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        # NOTE(review): projector_config must be a module global — not defined
        # in this function; confirm it exists wherever this is used.
        tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
            train_once.summary_writer, projector_config)
    # if eval ops then should have bee rank 0
    if eval_ops:
      #if deal_eval_results_fn is None and eval_names is not None:
      #  deal_eval_results_fn = lambda x: melt.print_results(x, eval_names)
      for i in range(eval_loops):
        eval_feed_dict = {} if gen_eval_feed_dict_fn is None else gen_eval_feed_dict_fn()
        #eval_feed_dict.update(feed_dict)
        # if use horovod let each rant use same sess.run!
        if not log_dir or train_once.summary_op is None or gezi.env_has(
            'EVAL_NO_SUMMARY') or use_horovod:
          #if not log_dir or train_once.summary_op is None:
          eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
        else:
          eval_results = sess.run(eval_ops + [train_once.summary_op],
                                  feed_dict=eval_feed_dict)
          summary_str = eval_results[-1]
          eval_results = eval_results[:-1]
        eval_loss = gezi.get_singles(eval_results)
        #timer_.print()
        eval_stop = False
        if use_horovod:
          # Barrier so all ranks stay in lockstep around evaluation.
          sess.run(hvd.allreduce(tf.constant(0)))
        #if not use_horovod or hvd.local_rank() == 0:
        # @TODO user print should also use logging as a must ?
        #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, 'eval_metrics:', end='')
        eval_names_ = melt.adjust_names(eval_loss, eval_names)
        #if not use_horovod or hvd.rank() == 0:
        #  logging.info2('{} eval_step:{} eval_metrics:{}'.format(epoch_str, step, melt.parse_results(eval_loss, eval_names_)))
        eval_str = 'valid:{}'.format(
            melt.parse_results(eval_loss, eval_names_))
        # if deal_eval_results_fn is not None:
        #   eval_stop = deal_eval_results_fn(eval_results)
        assert len(eval_loss) > 0
        if eval_stop is True:
          stop = True
        eval_names_ = melt.adjust_names(eval_loss, eval_names)
        if not use_horovod or hvd.rank() == 0:
          melt.set_global('eval_loss',
                          melt.parse_results(eval_loss, eval_names_))
    elif interval_steps != valid_interval_steps:
      #print()
      pass

  metric_evaluate = False
  # if metric_eval_fn is not None \
  #   and (is_start \
  #     or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
  #     or (metric_eval_interval_steps \
  #          and step % metric_eval_interval_steps == 0)):
  #   metric_evaluate = True
  if metric_eval_fn is not None \
      and ((is_start or metric_eval_interval_steps \
          and step % metric_eval_interval_steps == 0) or model_path):
    metric_evaluate = True

  # EVFIRST env var forces metric evaluation on/off at startup.
  if 'EVFIRST' in os.environ:
    if os.environ['EVFIRST'] == '0':
      if is_start:
        metric_evaluate = False
    else:
      if is_start:
        metric_evaluate = True

  if step == 0 or 'QUICK' in os.environ:
    metric_evaluate = False

  #print('------------1step', step, 'pre metric_evaluate', metric_evaluate, hvd.rank())
  if metric_evaluate:
    if use_horovod:
      print('------------metric evaluate step', step, model_path, hvd.rank())
    # Only pass model_path through when the eval fn actually accepts it.
    if not model_path or 'model_path' not in inspect.getargspec(
        metric_eval_fn).args:
      metric_eval_fn_ = metric_eval_fn
    else:
      metric_eval_fn_ = lambda: metric_eval_fn(model_path=model_path)
    try:
      l = metric_eval_fn_()
      if isinstance(l, tuple):
        num_returns = len(l)
        if num_returns == 2:
          evaluate_results, evaluate_names = l
          evaluate_summaries = None
        else:
          assert num_returns == 3, 'retrun 1,2,3 ok 4.. not ok'
          evaluate_results, evaluate_names, evaluate_summaries = l
      else:  #return dict
        # NOTE(review): `dict.items()` here calls items on the builtin dict
        # class — almost certainly meant `l.items()`; confirm before relying
        # on the dict-return path.
        evaluate_results, evaluate_names = tuple(zip(*dict.items()))
        evaluate_summaries = None
    except Exception:
      # NOTE(review): if the eval fn raises, evaluate_results/names stay
      # unbound and the logging below will raise NameError — TODO confirm.
      logging.info('Do nothing for metric eval fn with exception:\n',
                   traceback.format_exc())

    if not use_horovod or hvd.rank() == 0:
      #logging.info2('{} valid_step:{} {}:{}'.format(epoch_str, step, 'valid_metrics' if model_path is None else 'epoch_valid_metrics', melt.parse_results(evaluate_results, evaluate_names)))
      logging.info2('{} valid_step:{} {}:{}'.format(
          epoch_str, step, 'valid_metrics',
          melt.parse_results(evaluate_results, evaluate_names)))

    # Reduce-on-plateau learning-rate decay keyed on evaluate_results[0].
    if learning_rate is not None and (learning_rate_patience and
                                      learning_rate_patience > 0):
      assert learning_rate_decay_factor > 0 and learning_rate_decay_factor < 1
      valid_loss = evaluate_results[0]
      if not hasattr(train_once, 'min_valid_loss'):
        train_once.min_valid_loss = valid_loss
        train_once.deacy_steps = []
        train_once.patience = 0
      else:
        if valid_loss < train_once.min_valid_loss:
          train_once.min_valid_loss = valid_loss
          train_once.patience = 0
        else:
          train_once.patience += 1
          logging.info2('{} valid_step:{} patience:{}'.format(
              epoch_str, step, train_once.patience))
      if learning_rate_patience and train_once.patience >= learning_rate_patience:
        # ops[1] is assumed to be the learning-rate variable.
        lr_op = ops[1]
        lr = sess.run(lr_op) * learning_rate_decay_factor
        train_once.deacy_steps.append(step)
        logging.info2(
            '{} valid_step:{} learning_rate_decay by *{}, learning_rate_decay_steps={}'
            .format(epoch_str, step, learning_rate_decay_factor,
                    ','.join(map(str, train_once.deacy_steps))))
        sess.run(tf.assign(lr_op, tf.constant(lr, dtype=tf.float32)))
        train_once.patience = 0
        train_once.min_valid_loss = valid_loss

  if ops is not None:
    #if deal_results_fn is None and names is not None:
    #  deal_results_fn = lambda x: melt.print_results(x, names)
    feed_dict = {} if gen_feed_dict_fn is None else gen_feed_dict_fn()
    # NOTICE ops[2] should be scalar otherwise wrong!! loss should be scalar
    #print('---------------ops', ops)
    if eval_ops is not None or not log_dir or not hasattr(
        train_once, 'summary_op') or train_once.summary_op is None or use_horovod:
      feed_dict[K.learning_phase()] = 1
      results = sess.run(ops, feed_dict=feed_dict)
    else:
      ## TODO why below ?
      #try:
      feed_dict[K.learning_phase()] = 1
      results = sess.run(ops + [train_once.summary_op], feed_dict=feed_dict)
      summary_str = results[-1]
      results = results[:-1]
      # except Exception:
      #   logging.info('sess.run(ops + [train_once.summary_op], feed_dict=feed_dict) fail')
      #   results = sess.run(ops, feed_dict=feed_dict)
    #print('------------results', results)
    # #--------trace debug
    # if step == 210:
    #   run_metadata = tf.RunMetadata()
    #   results = sess.run(
    #       ops,
    #       feed_dict=feed_dict,
    #       options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
    #       run_metadata=run_metadata)
    #   from tensorflow.python.client import timeline
    #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    #   trace_file = open('timeline.ctf.json', 'w')
    #   trace_file.write(trace.generate_chrome_trace_format())

    #reults[0] assume to be train_op, results[1] to be learning_rate
    learning_rate = results[1]
    results = results[2:]

    #@TODO should support aver loss and other avg evaluations like test..
    if print_avg_loss:
      if not hasattr(train_once, 'avg_loss'):
        train_once.avg_loss = AvgScore()
      #assume results[0] as train_op return, results[1] as loss
      loss = gezi.get_singles(results)
      train_once.avg_loss.add(loss)

    steps_per_second = None
    instances_per_second = None
    hours_per_epoch = None
    #step += 1
    #if is_start or interval_steps and step % interval_steps == 0:
    # Only rank-0 (per node) logs interval stats under horovod.
    interval_ok = not use_horovod or hvd.local_rank() == 0
    if interval_steps and step % interval_steps == 0 and interval_ok:
      train_average_loss = train_once.avg_loss.avg_score()
      if print_time:
        duration = timer.elapsed()
        duration_str = 'duration:{:.2f} '.format(duration)
        melt.set_global('duration', '%.2f' % duration)
        #info.write(duration_str)
        elapsed = train_once.timer.elapsed()
        steps_per_second = interval_steps / elapsed
        batch_size = melt.batch_size()
        num_gpus = melt.num_gpus()
        instances_per_second = interval_steps * batch_size / elapsed
        gpu_info = '' if num_gpus <= 1 else ' gpus:[{}]'.format(num_gpus)
        if num_steps_per_epoch is None:
          epoch_time_info = ''
        else:
          hours_per_epoch = num_steps_per_epoch / interval_steps * elapsed / 3600
          epoch_time_info = '1epoch:[{:.2f}h]'.format(hours_per_epoch)
        info.write(
            'elapsed:[{:.2f}] batch_size:[{}]{} batches/s:[{:.2f}] insts/s:[{:.2f}] {} lr:[{:.6f}]'
            .format(elapsed, batch_size, gpu_info, steps_per_second,
                    instances_per_second, epoch_time_info, learning_rate))
      if print_avg_loss:
        #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
        names_ = melt.adjust_names(train_average_loss, names)
        #info.write('train_avg_metric:{} '.format(melt.parse_results(train_average_loss, names_)))
        info.write(' train:{} '.format(
            melt.parse_results(train_average_loss, names_)))
        #info.write('train_avg_loss: {} '.format(train_average_loss))
      info.write(eval_str)
      #print(gezi.now_time(), epoch_str, 'train_step:%d'%step, info.getvalue(), end=' ')
      logging.info2('{} {} {}'.format(epoch_str, 'step:%d' % step,
                                      info.getvalue()))
      if deal_results_fn is not None:
        stop = deal_results_fn(results)

  summary_strs = gezi.to_list(summary_str)
  if metric_evaluate:
    if evaluate_summaries is not None:
      summary_strs += evaluate_summaries

  if step > 1:
    if is_eval_step:
      # deal with summary
      if log_dir:
        summary = tf.Summary()
        if eval_ops is None:
          if train_once.summary_op is not None:
            for summary_str in summary_strs:
              train_once.summary_writer.add_summary(summary_str, step)
        else:
          for summary_str in summary_strs:
            train_once.summary_writer.add_summary(summary_str, step)
          suffix = 'valid' if not eval_names else ''
          # loss/valid
          melt.add_summarys(summary, eval_results, eval_names_, suffix=suffix)

        if ops is not None:
          try:
            # loss/train_avg
            melt.add_summarys(summary, train_average_loss, names_,
                              suffix='train_avg')
          except Exception:
            pass
          ##optimizer has done this also
          melt.add_summary(summary, learning_rate, 'learning_rate')
          melt.add_summary(summary, melt.batch_size(), 'batch_size',
                           prefix='other')
          melt.add_summary(summary, melt.epoch(), 'epoch', prefix='other')
          if steps_per_second:
            melt.add_summary(summary, steps_per_second, 'steps_per_second',
                             prefix='perf')
          if instances_per_second:
            melt.add_summary(summary, instances_per_second,
                             'instances_per_second', prefix='perf')
          if hours_per_epoch:
            melt.add_summary(summary, hours_per_epoch, 'hours_per_epoch',
                             prefix='perf')

        if metric_evaluate:
          #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
          prefix = 'step_eval'
          if model_path:
            prefix = 'eval'
            # Epoch-level evaluation uses its own monotonically increasing
            # step counter cached on the function object.
            if not hasattr(train_once, 'epoch_step'):
              train_once.epoch_step = 1
            else:
              train_once.epoch_step += 1
            step = train_once.epoch_step
          # eval/loss eval/auc ..
          melt.add_summarys(summary, evaluate_results, evaluate_names,
                            prefix=prefix)

        train_once.summary_writer.add_summary(summary, step)
        train_once.summary_writer.flush()
      return stop
    elif metric_evaluate and log_dir:
      summary = tf.Summary()
      for summary_str in summary_strs:
        train_once.summary_writer.add_summary(summary_str, step)
      #summary.ParseFromString(evaluate_summaries)
      summary_writer = train_once.summary_writer
      prefix = 'step_eval'
      if model_path:
        prefix = 'eval'
        if not hasattr(train_once, 'epoch_step'):
          ## TODO.. restart will get 1 again..
          #epoch_step = tf.Variable(0, trainable=False, name='epoch_step')
          #epoch_step += 1
          #train_once.epoch_step = sess.run(epoch_step)
          valid_interval_epochs = 1.
          try:
            valid_interval_epochs = FLAGS.valid_interval_epochs
          except Exception:
            pass
          # Derive the starting epoch-step from the current (fractional)
          # epoch and the validation interval.
          train_once.epoch_step = 1 if melt.epoch() <= 1 else int(
              int(melt.epoch() * 10) / int(valid_interval_epochs * 10))
          logging.info('train_once epoch start step is',
                       train_once.epoch_step)
        else:
          #epoch_step += 1
          train_once.epoch_step += 1
        step = train_once.epoch_step
      #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
      melt.add_summarys(summary, evaluate_results, evaluate_names,
                        prefix=prefix)
      summary_writer.add_summary(summary, step)
      summary_writer.flush()
def train_once(sess,
               step,
               ops,
               names=None,
               gen_feed_dict=None,
               deal_results=melt.print_results,
               interval_steps=100,
               eval_ops=None,
               eval_names=None,
               gen_eval_feed_dict=None,
               deal_eval_results=melt.print_results,
               eval_interval_steps=100,
               print_time=True,
               print_avg_loss=True,
               model_dir=None,
               log_dir=None,
               is_start=False,
               num_steps_per_epoch=None,
               metric_eval_function=None,
               metric_eval_interval_steps=0):
  """Older variant of the single-training-step driver.

  Runs `ops` (ops[0] assumed train_op, ops[1:] losses/metrics), logs an
  averaged train loss every `interval_steps`, runs `eval_ops` and writes
  tf.Summary data every `eval_interval_steps`, and optionally calls
  `metric_eval_function` for epoch-style evaluation.

  Cross-call state lives in attributes on `train_once` itself
  (timer, avg_loss, avg_loss2, summary_op, summary_train_op, summary_writer).
  Relies on module globals: tf, gezi, melt, logging, BytesIO, Timer,
  AvgScore, projector_config.

  Returns: `stop` — truthy when a deal_results/deal_eval_results callback
  requested training to stop.
  """
  timer = gezi.Timer()
  if print_time:
    if not hasattr(train_once, 'timer'):
      train_once.timer = Timer()
      train_once.eval_timer = Timer()
      train_once.metric_eval_timer = Timer()
  melt.set_global('step', step)
  epoch = step / num_steps_per_epoch if num_steps_per_epoch else -1
  epoch_str = 'epoch:%.4f' % (epoch) if num_steps_per_epoch else ''
  melt.set_global('epoch', '%.4f' % (epoch))
  info = BytesIO()
  stop = False

  if ops is not None:
    if deal_results is None and names is not None:
      deal_results = lambda x: melt.print_results(x, names)
    if deal_eval_results is None and eval_names is not None:
      deal_eval_results = lambda x: melt.print_results(x, eval_names)
    if eval_names is None:
      eval_names = names
    feed_dict = {} if gen_feed_dict is None else gen_feed_dict()
    results = sess.run(ops, feed_dict=feed_dict)
    # #--------trace debug
    # if step == 210:
    #   run_metadata = tf.RunMetadata()
    #   results = sess.run(
    #       ops,
    #       feed_dict=feed_dict,
    #       options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
    #       run_metadata=run_metadata)
    #   from tensorflow.python.client import timeline
    #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    #   trace_file = open('timeline.ctf.json', 'w')
    #   trace_file.write(trace.generate_chrome_trace_format())

    #reults[0] assume to be train_op
    results = results[1:]

    #@TODO should support aver loss and other avg evaluations like test..
    if print_avg_loss:
      if not hasattr(train_once, 'avg_loss'):
        train_once.avg_loss = AvgScore()
        # Separate accumulator for the eval-interval average when the two
        # intervals differ.
        if interval_steps != eval_interval_steps:
          train_once.avg_loss2 = AvgScore()
      #assume results[0] as train_op return, results[1] as loss
      loss = gezi.get_singles(results)
      train_once.avg_loss.add(loss)
      if interval_steps != eval_interval_steps:
        train_once.avg_loss2.add(loss)

    if is_start or interval_steps and step % interval_steps == 0:
      train_average_loss = train_once.avg_loss.avg_score()
      if print_time:
        duration = timer.elapsed()
        duration_str = 'duration:{:.3f} '.format(duration)
        melt.set_global('duration', '%.3f' % duration)
        info.write(duration_str)
        elapsed = train_once.timer.elapsed()
        steps_per_second = interval_steps / elapsed
        batch_size = melt.batch_size()
        num_gpus = melt.num_gpus()
        instances_per_second = interval_steps * batch_size * num_gpus / elapsed
        if num_gpus == 1:
          info.write(
              'elapsed:[{:.3f}] batch_size:[{}] batches/s:[{:.2f}] insts/s:[{:.2f}] '
              .format(elapsed, batch_size, steps_per_second,
                      instances_per_second))
        else:
          info.write(
              'elapsed:[{:.3f}] batch_size:[{}] gpus:[{}], batches/s:[{:.2f}] insts/s:[{:.2f}] '
              .format(elapsed, batch_size, num_gpus, steps_per_second,
                      instances_per_second))
      if print_avg_loss:
        #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
        names_ = melt.adjust_names(train_average_loss, names)
        info.write('train_avg_metrics:{} '.format(
            melt.parse_results(train_average_loss, names_)))
        #info.write('train_avg_loss: {} '.format(train_average_loss))
      #print(gezi.now_time(), epoch_str, 'train_step:%d'%step, info.getvalue(), end=' ')
      logging.info2('{} {} {}'.format(epoch_str, 'train_step:%d' % step,
                                      info.getvalue()))
      if deal_results is not None:
        stop = deal_results(results)

  metric_evaluate = False
  # if metric_eval_function is not None \
  #  and ( (is_start and (step or ops is None))\
  #    or (step and ((num_steps_per_epoch and step % num_steps_per_epoch == 0) \
  #           or (metric_eval_interval_steps \
  #             and step % metric_eval_interval_steps == 0)))):
  #   metric_evaluate = True
  if metric_eval_function is not None \
      and (is_start \
        or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
        or (metric_eval_interval_steps \
            and step % metric_eval_interval_steps == 0)):
    metric_evaluate = True

  if metric_evaluate:
    evaluate_results, evaluate_names = metric_eval_function()

  if is_start or eval_interval_steps and step % eval_interval_steps == 0:
    if ops is not None:
      if interval_steps != eval_interval_steps:
        train_average_loss = train_once.avg_loss2.avg_score()
      # Fresh buffer: the eval-step log line is built separately.
      info = BytesIO()
      names_ = melt.adjust_names(results, names)
      train_average_loss_str = ''
      if print_avg_loss and interval_steps != eval_interval_steps:
        train_average_loss_str = melt.value_name_list_str(
            train_average_loss, names_)
        melt.set_global('train_loss', train_average_loss_str)
        train_average_loss_str = 'train_avg_loss:{} '.format(
            train_average_loss_str)
      if interval_steps != eval_interval_steps:
        #end = '' if eval_ops is None else '\n'
        #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, train_average_loss_str, end=end)
        logging.info2('{} eval_step: {} {}'.format(epoch_str, step,
                                                   train_average_loss_str))

    if eval_ops is not None:
      eval_feed_dict = {} if gen_eval_feed_dict is None else gen_eval_feed_dict()
      #eval_feed_dict.update(feed_dict)
      #------show how to perf debug
      ##timer_ = gezi.Timer('sess run generate')
      ##sess.run(eval_ops[-2], feed_dict=None)
      ##timer_.print()
      timer_ = gezi.Timer('sess run eval_ops')
      eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
      timer_.print()
      if deal_eval_results is not None:
        #@TODO user print should also use logging as a must ?
        #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, 'eval_metrics:', end='')
        logging.info2('{} eval_step: {} eval_metrics:'.format(
            epoch_str, step))
        eval_stop = deal_eval_results(eval_results)
        eval_loss = gezi.get_singles(eval_results)
        assert len(eval_loss) > 0
        if eval_stop is True:
          stop = True
        eval_names_ = melt.adjust_names(eval_loss, eval_names)
        melt.set_global('eval_loss',
                        melt.parse_results(eval_loss, eval_names_))
    elif interval_steps != eval_interval_steps:
      #print()
      pass

    if log_dir:
      #timer_ = gezi.Timer('witting log')
      if not hasattr(train_once, 'summary_op'):
        # try/except pairs keep compatibility with pre-1.0 TF summary APIs.
        try:
          train_once.summary_op = tf.summary.merge_all()
        except Exception:
          train_once.summary_op = tf.merge_all_summaries()
        melt.print_summary_ops()
        try:
          train_once.summary_train_op = tf.summary.merge_all(
              key=melt.MonitorKeys.TRAIN)
          train_once.summary_writer = tf.summary.FileWriter(
              log_dir, sess.graph)
        except Exception:
          train_once.summary_train_op = tf.merge_all_summaries(
              key=melt.MonitorKeys.TRAIN)
          train_once.summary_writer = tf.train.SummaryWriter(
              log_dir, sess.graph)
        # NOTE(review): projector_config must be a module global — confirm.
        tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
            train_once.summary_writer, projector_config)

      summary = tf.Summary()
      #so the strategy is on eval_interval_steps, if has eval dataset, then tensorboard evluate on eval dataset
      #if not have eval dataset, will evaluate on trainset, but if has eval dataset we will also monitor train loss
      if train_once.summary_train_op is not None:
        summary_str = sess.run(train_once.summary_train_op,
                               feed_dict=feed_dict)
        train_once.summary_writer.add_summary(summary_str, step)

      if eval_ops is None:
        #get train loss, for every batch train
        if train_once.summary_op is not None:
          #timer2 = gezi.Timer('sess run')
          summary_str = sess.run(train_once.summary_op, feed_dict=feed_dict)
          #timer2.print()
          train_once.summary_writer.add_summary(summary_str, step)
      else:
        #get eval loss for every batch eval, then add train loss for eval step average loss
        summary_str = sess.run(
            train_once.summary_op,
            feed_dict=eval_feed_dict) if train_once.summary_op is not None else ''
        #all single value results will be add to summary here not using tf.scalar_summary..
        summary.ParseFromString(summary_str)
        melt.add_summarys(summary, eval_results, eval_names_, suffix='eval')

      melt.add_summarys(summary, train_average_loss, names_,
                        suffix='train_avg%dsteps' % eval_interval_steps)

      if metric_evaluate:
        melt.add_summarys(summary, evaluate_results, evaluate_names,
                          prefix='evaluate')

      train_once.summary_writer.add_summary(summary, step)
      train_once.summary_writer.flush()
      #timer_.print()

    if print_time:
      full_duration = train_once.eval_timer.elapsed()
      if metric_evaluate:
        metric_full_duration = train_once.metric_eval_timer.elapsed()
      full_duration_str = 'elapsed:{:.3f} '.format(full_duration)
      #info.write('duration:{:.3f} '.format(timer.elapsed()))
      duration = timer.elapsed()
      info.write('duration:{:.3f} '.format(duration))
      info.write(full_duration_str)
      info.write('eval_time_ratio:{:.3f} '.format(duration / full_duration))
      if metric_evaluate:
        info.write('metric_time_ratio:{:.3f} '.format(
            duration / metric_full_duration))
      #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, info.getvalue())
      logging.info2('{} {} {}'.format(epoch_str, 'eval_step: %d' % step,
                                      info.getvalue()))
  return stop
def main(_):
  """Score a caption prediction file against references and log/write metrics.

  Reads a tab-separated prediction file (image, caption, all_captions,
  all_scores per line), segments captions with jieba, computes BLEU-1..4,
  CIDEr, METEOR and ROUGE-L against `prepare_refs()`, optionally writes a
  per-image metrics TSV, and writes a tf.Summary keyed by epoch/step parsed
  from the checkpoint-style file name.
  """
  prediction_file = FLAGS.prediction_file or sys.argv[1]
  assert prediction_file
  log_dir = os.path.dirname(prediction_file)
  log_dir = log_dir or './'
  print('prediction_file', prediction_file, 'log_dir', log_dir,
        file=sys.stderr)
  logging.set_logging_path(log_dir)
  sess = tf.Session()
  summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

  refs = prepare_refs()
  tokenizer = prepare_tokenization()
  ##TODO some problem running tokenizer..
  #refs = tokenizer.tokenize(refs)
  # FIX: refs_len was referenced below but never defined (NameError); the
  # commented-out predecessor line shows len(refs) was intended.
  refs_len = len(refs)

  # Length statistics over predicted captions (chars and jieba words).
  min_len = 10000
  min_len_image = None
  min_len_caption = None
  max_len = 0
  max_len_image = None
  max_len_caption = None
  sum_len = 0
  min_words = 10000
  min_words_image = None
  min_words_caption = None
  max_words = 0
  max_words_image = None
  max_words_caption = None
  sum_words = 0

  caption_metrics_file = FLAGS.caption_metrics_file or prediction_file.replace(
      'evaluate-inference', 'caption_metrics')

  imgs = []
  captions = []
  infos = {}
  for line in open(prediction_file):
    l = line.strip().split('\t')
    img, caption, all_caption, all_score = l[0], l[1], l[-2], l[-1]
    # Normalize the image id to always carry exactly one .jpg suffix.
    img = img.replace('.jpg', '')
    img += '.jpg'
    imgs.append(img)
    infos[img] = '%s %s' % (all_caption.replace(' ', '|'),
                            all_score.replace(' ', '|'))
    caption = caption.replace(' ', '').replace('\t', '')
    caption_words = [x.encode('utf-8') for x in jieba.cut(caption)]
    caption_str = ' '.join(caption_words)
    captions.append([caption_str])
    caption_len = len(gezi.get_single_cns(caption))
    num_words = len(caption_words)
    if caption_len < min_len:
      min_len = caption_len
      min_len_image = img
      min_len_caption = caption
    if caption_len > max_len:
      max_len = caption_len
      max_len_image = img
      max_len_caption = caption
    sum_len += caption_len
    if num_words < min_words:
      min_words = num_words
      min_words_image = img
      min_words_caption = caption_str
    if num_words > max_words:
      max_words = num_words
      max_words_image = img
      max_words_caption = caption_str
    sum_words += num_words

  results = dict(zip(imgs, captions))
  #results = tokenizer.tokenize(results)
  selected_results, selected_refs = translation_reorder_keys(results, refs)

  scorers = [
      (Bleu(4), ["bleu_1", "bleu_2", "bleu_3", "bleu_4"]),
      (Cider(), "cider"),
      (Meteor(), "meteor"),
      (Rouge(), "rouge_l")
  ]

  score_list = []
  metric_list = []
  scores_list = []

  print('img&predict&label:{}:{}{}{}'.format(
      selected_results.items()[0][0],
      '|'.join(selected_results.items()[0][1]), '---',
      '|'.join(selected_refs.items()[0][1])),
        file=sys.stderr)
  #print('avg_len:', sum_len / len(refs), 'min_len:', min_len, min_len_image, min_len_caption, 'max_len:', max_len, max_len_image, max_len_caption, file=sys.stderr)
  print('avg_len:', sum_len / refs_len, 'min_len:', min_len, min_len_image,
        min_len_caption, 'max_len:', max_len, max_len_image, max_len_caption,
        file=sys.stderr)
  print('avg_words', sum_words / refs_len, 'min_words:', min_words,
        min_words_image, min_words_caption, 'max_words:', max_words,
        max_words_image, max_words_caption, file=sys.stderr)

  for scorer, method in scorers:
    print('computing %s score...' % (scorer.method()), file=sys.stderr)
    score, scores = scorer.compute_score(selected_refs, selected_results)
    if type(method) == list:
      # Bleu returns one (score, per-image scores) pair per n-gram order.
      for i in range(len(score)):
        score_list.append(score[i])
        metric_list.append(method[i])
        scores_list.append(scores[i])
        print(method[i], score[i], file=sys.stderr)
    else:
      score_list.append(score)
      metric_list.append(method)
      scores_list.append(scores)
      print(method, score, file=sys.stderr)

  # 4 BLEU + cider + meteor + rouge_l
  assert(len(score_list) == 7)
  # 'avg' is the mean of bleu_4, cider, meteor, rouge_l.
  avg_score = np.mean(np.array(score_list[3:]))
  score_list.insert(0, avg_score)
  metric_list.insert(0, 'avg')

  if caption_metrics_file:
    out = open(caption_metrics_file, 'w')
    print('image_id', 'caption', 'ref', '\t'.join(metric_list), 'infos',
          sep='\t', file=out)
    for i in range(len(selected_results)):
      key = selected_results.keys()[i]
      result = selected_results[key][0]
      refs = '|'.join(selected_refs[key])
      bleu_1 = scores_list[0][i]
      bleu_2 = scores_list[1][i]
      bleu_3 = scores_list[2][i]
      bleu_4 = scores_list[3][i]
      cider = scores_list[4][i]
      meteor = scores_list[5][i]
      rouge_l = scores_list[6][i]
      avg = (bleu_4 + cider + meteor + rouge_l) / 4.
      print(key.split('.')[0], result, refs, avg, bleu_1, bleu_2, bleu_3,
            bleu_4, cider, meteor, rouge_l, infos[key], sep='\t', file=out)

  metric_list = ['trans_' + x for x in metric_list]
  metric_score_str = '\t'.join('%s:[%.4f]' % (name, result)
                               for name, result in zip(metric_list, score_list))
  logging.info('%s\t%s' % (metric_score_str,
                           os.path.basename(prediction_file)))

  if caption_metrics_file:
    # FIX: this overall-scores row writes to `out` (and reuses `key` from the
    # loop above), so it must stay inside the caption_metrics_file guard —
    # previously it could raise NameError when no metrics file was requested.
    print(key.split('.')[0], 'None', 'None',
          '\t'.join(map(str, score_list)), 'None', sep='\t', file=out)
    out.close()

  summary = tf.Summary()
  if score_list and 'ckpt' in prediction_file:
    try:
      epoch = float(os.path.basename(prediction_file).split('-')[1])
      #for float epoch like 0.01 0.02 turn it to 1, 2, notice it make epoch 1 to 100
      epoch = int(epoch * 100)
      step = int(float(os.path.basename(prediction_file).split('-')[2].split('.')[0]))
      prefix = 'step' if FLAGS.write_step else 'epoch'
      melt.add_summarys(summary, score_list, metric_list, prefix=prefix)
      step = epoch if not FLAGS.write_step else step
      summary_writer.add_summary(summary, step)
      summary_writer.flush()
    except Exception:
      print(traceback.format_exc(), file=sys.stderr)
def main(_):
  """Poll the epoch checkpoint directory and evaluate each new checkpoint.

  Every 5 seconds, scans FLAGS.model_dir/epoch for checkpoint data files,
  evaluates any checkpoint not seen before (rank/translation metrics via
  evaluator.evaluate), writes the scores as tf summaries indexed by epoch,
  and persists the visited set to visited.pkl so restarts skip old ones.
  Runs forever.
  """
  print('eval_rank:', FLAGS.eval_rank, 'eval_translation:',
        FLAGS.eval_translation)
  epoch_dir = os.path.join(FLAGS.model_dir, 'epoch')
  logging.set_logging_path(gezi.get_dir(epoch_dir))
  log_dir = epoch_dir
  sess = tf.Session()
  summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

  Predictor = TextPredictor

  image_model = None
  if FLAGS.image_checkpoint_file:
    #feature_name = None, since in show and tell predictor will use gen_features not gen_feature
    image_model = melt.image.ImageModel(FLAGS.image_checkpoint_file,
                                        FLAGS.image_model_name,
                                        feature_name=None)

  evaluator.init(image_model)

  # Load the persisted set of already-evaluated checkpoints, if any,
  # normalizing entries to bare file names.
  visited_path = os.path.join(epoch_dir, 'visited.pkl')
  if not os.path.exists(visited_path):
    visited_checkpoints = set()
  else:
    visited_checkpoints = pickle.load(open(visited_path, 'rb'))
  visited_checkpoints = set([x.split('/')[-1] for x in visited_checkpoints])

  while True:
    suffix = '.data-00000-of-00001'
    # Checkpoints appear as model.ckpt*.data-00000-of-00001; order by mtime
    # so epochs are processed oldest-first, then strip the data-shard suffix.
    files = glob.glob(
        os.path.join(epoch_dir, 'model.ckpt*.data-00000-of-00001'))
    files.sort(key=os.path.getmtime)
    files = [file.replace(suffix, '') for file in files]
    for i, file in enumerate(files):
      if 'best' in file:
        continue
      if FLAGS.start_epoch and i + 1 < FLAGS.start_epoch:
        continue
      file_ = file.split('/')[-1]
      if file_ in visited_checkpoints:
        continue
      visited_checkpoints.add(file_)
      epoch = int(file_.split('-')[-2])
      logging.info('mointor_epoch:%d from %d model files' %
                   (epoch, len(visited_checkpoints)))
      #will use predict_text in eval_translation , predict in eval_rank
      predictor = Predictor(file,
                            image_model=image_model,
                            feature_name=melt.get_features_name(
                                FLAGS.image_model_name))
      summary = tf.Summary()
      scores, metrics = evaluator.evaluate(
          predictor,
          eval_rank=FLAGS.eval_rank,
          eval_translation=FLAGS.eval_translation)
      melt.add_summarys(summary, scores, metrics)
      summary_writer.add_summary(summary, epoch)
      summary_writer.flush()
      # Persist progress immediately so a crash does not re-evaluate.
      pickle.dump(visited_checkpoints, open(visited_path, 'wb'))
    time.sleep(5)
def train_once(
    sess,
    step,
    ops,
    names=None,
    gen_feed_dict_fn=None,
    deal_results_fn=None,
    interval_steps=100,
    eval_ops=None,
    eval_names=None,
    gen_eval_feed_dict_fn=None,
    deal_eval_results_fn=melt.print_results,
    eval_interval_steps=100,
    print_time=True,
    print_avg_loss=True,
    model_dir=None,
    log_dir=None,
    is_start=False,
    num_steps_per_epoch=None,
    metric_eval_fn=None,
    metric_eval_interval_steps=0,
    summary_excls=None,
    fixed_step=None,  # for epoch only, incase you change batch size
    eval_loops=1,
    learning_rate=None,
    learning_rate_patience=None,
    learning_rate_decay_factor=None,
    num_epochs=None,
    model_path=None,
):
  """Run one training iteration and, on schedule, eval / metric-eval / summaries.

  Responsibilities visible in this body:
    * run `ops` (train ops) via `sess.run`, optionally together with a lazily
      built, cached `tf.summary` merge op;
    * every `eval_interval_steps` (or at start), run `eval_ops` and log the
      eval losses;
    * every `metric_eval_interval_steps` (or when `model_path` is given), call
      `metric_eval_fn` and log/summarize its metric results;
    * maintain function-attribute state across calls (`train_once.timer`,
      `train_once.summary_op`, `train_once.summary_writer`,
      `train_once.avg_loss`, `train_once.min_valid_loss`, ...);
    * optionally decay the learning rate when valid loss has not improved for
      `learning_rate_patience` metric evaluations (assumes ops[1] is the
      learning-rate variable — TODO confirm against the caller).

  Returns:
    `stop` (truthy to request training stop) on the eval-step path; falls
    through (returning None) otherwise — callers should treat None as falsy.

  NOTE(review): `deal_eval_results_fn`, `model_dir` and `eval_loops`' per-loop
  results are partially unused / commented out below — presumably historical.
  """
  #is_start = False # force not to evaluate at first step
  #print('-----------------global_step', sess.run(tf.train.get_or_create_global_step()))
  timer = gezi.Timer()
  if print_time:
    # Lazily create persistent timers on the function object (static state).
    if not hasattr(train_once, 'timer'):
      train_once.timer = Timer()
      train_once.eval_timer = Timer()
      train_once.metric_eval_timer = Timer()

  melt.set_global('step', step)
  # fixed_step overrides step for epoch accounting (e.g. after a batch-size
  # change); -1 means "epoch unknown".
  epoch = (fixed_step or step) / num_steps_per_epoch if num_steps_per_epoch else -1
  if not num_epochs:
    epoch_str = 'epoch:%.3f' % (epoch) if num_steps_per_epoch else ''
  else:
    epoch_str = 'epoch:%.3f/%d' % (
        epoch, num_epochs) if num_steps_per_epoch else ''
  melt.set_global('epoch', '%.2f' % (epoch))

  info = IO()
  stop = False

  # Derive eval display names from train names when not given.
  if eval_names is None:
    if names:
      eval_names = ['eval/' + x for x in names]

  if names:
    names = ['train/' + x for x in names]

  # NOTE(review): when eval_names was derived from `names` above, this adds a
  # second 'eval/' prefix ('eval/eval/...') — looks unintended; confirm.
  if eval_names:
    eval_names = ['eval/' + x for x in eval_names]

  is_eval_step = is_start or eval_interval_steps and step % eval_interval_steps == 0
  # summary_str starts as an empty list; may be rebound to a serialized
  # summary string by the sess.run calls below (gezi.to_list handles both).
  summary_str = []

  if is_eval_step:
    # deal with summary
    if log_dir:
      # Build and cache the merged summary op + writer exactly once.
      if not hasattr(train_once, 'summary_op'):
        #melt.print_summary_ops()
        if summary_excls is None:
          train_once.summary_op = tf.summary.merge_all()
        else:
          # Keep only summaries whose name does not contain any excluded token.
          summary_ops = []
          for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
            for summary_excl in summary_excls:
              if not summary_excl in op.name:
                summary_ops.append(op)
          print('filtered summary_ops:')
          for op in summary_ops:
            print(op)
          train_once.summary_op = tf.summary.merge(summary_ops)

        #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
        train_once.summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        # NOTE(review): `projector_config` is a module-level name not visible
        # in this block — presumably defined elsewhere in the file; verify.
        tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
            train_once.summary_writer, projector_config)

    if eval_ops is not None:
      #if deal_eval_results_fn is None and eval_names is not None:
      #  deal_eval_results_fn = lambda x: melt.print_results(x, eval_names)
      for i in range(eval_loops):
        eval_feed_dict = {} if gen_eval_feed_dict_fn is None else gen_eval_feed_dict_fn()
        #eval_feed_dict.update(feed_dict)

        # might use EVAL_NO_SUMMARY if some old code has problem TODO CHECK
        if not log_dir or train_once.summary_op is None or gezi.env_has(
            'EVAL_NO_SUMMARY'):
          #if not log_dir or train_once.summary_op is None:
          eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
        else:
          # Piggy-back the summary op on the eval run; last result is the
          # serialized summary string.
          eval_results = sess.run(eval_ops + [train_once.summary_op],
                                  feed_dict=eval_feed_dict)
          summary_str = eval_results[-1]
          eval_results = eval_results[:-1]
        eval_loss = gezi.get_singles(eval_results)
        #timer_.print()
        eval_stop = False
        # @TODO user print should also use logging as a must ?
        #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, 'eval_metrics:', end='')
        eval_names_ = melt.adjust_names(eval_loss, eval_names)
        logging.info2('{} eval_step:{} eval_metrics:{}'.format(
            epoch_str, step, melt.parse_results(eval_loss, eval_names_)))

      # if deal_eval_results_fn is not None:
      #   eval_stop = deal_eval_results_fn(eval_results)
      assert len(eval_loss) > 0
      if eval_stop is True:
        stop = True
      eval_names_ = melt.adjust_names(eval_loss, eval_names)
      melt.set_global('eval_loss', melt.parse_results(eval_loss, eval_names_))
  elif interval_steps != eval_interval_steps:
    #print()
    pass

  metric_evaluate = False

  # if metric_eval_fn is not None \
  #   and (is_start \
  #     or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
  #     or (metric_eval_interval_steps \
  #       and step % metric_eval_interval_steps == 0)):
  #   metric_evaluate = True

  if metric_eval_fn is not None \
    and ((is_start or metric_eval_interval_steps \
        and step % metric_eval_interval_steps == 0) or model_path):
    metric_evaluate = True

  # Environment-variable overrides: skip metric eval at step 0 unless EVFIRST
  # is set; QUICK or EVFIRST=0 force-skip it.
  #if (is_start or step == 0) and (not 'EVFIRST' in os.environ):
  if ((step == 0) and (not 'EVFIRST' in os.environ)) or ('QUICK' in os.environ) or (
      'EVFIRST' in os.environ and os.environ['EVFIRST'] == '0'):
    metric_evaluate = False

  if metric_evaluate:
    # TODO better
    # metric_eval_fn may return (results, names) or (results, names, summaries).
    if not model_path or 'model_path' not in inspect.getargspec(
        metric_eval_fn).args:
      l = metric_eval_fn()
      if len(l) == 2:
        evaluate_results, evaluate_names = l
        evaluate_summaries = None
      else:
        evaluate_results, evaluate_names, evaluate_summaries = l
    else:
      try:
        l = metric_eval_fn(model_path=model_path)
        if len(l) == 2:
          evaluate_results, evaluate_names = l
          evaluate_summaries = None
        else:
          evaluate_results, evaluate_names, evaluate_summaries = l
      except Exception:
        # NOTE(review): if this fires, evaluate_results/evaluate_names stay
        # unbound and the logging/summary code below raises NameError — the
        # swallow here only defers the failure; confirm intended.
        logging.info('Do nothing for metric eval fn with exception:\n',
                     traceback.format_exc())

    logging.info2('{} valid_step:{} {}:{}'.format(
        epoch_str, step,
        'valid_metrics' if model_path is None else 'epoch_valid_metrics',
        melt.parse_results(evaluate_results, evaluate_names)))

    # Plateau-based learning-rate decay driven by the first valid metric.
    if learning_rate is not None and (learning_rate_patience and
                                      learning_rate_patience > 0):
      assert learning_rate_decay_factor > 0 and learning_rate_decay_factor < 1
      # assumes evaluate_results[0] is the validation loss — TODO confirm
      valid_loss = evaluate_results[0]
      if not hasattr(train_once, 'min_valid_loss'):
        train_once.min_valid_loss = valid_loss
        train_once.deacy_steps = []
        train_once.patience = 0
      else:
        if valid_loss < train_once.min_valid_loss:
          train_once.min_valid_loss = valid_loss
          train_once.patience = 0
        else:
          train_once.patience += 1
          logging.info2('{} valid_step:{} patience:{}'.format(
              epoch_str, step, train_once.patience))

      if learning_rate_patience and train_once.patience >= learning_rate_patience:
        # assumes ops[1] is the learning-rate variable — TODO confirm caller
        lr_op = ops[1]
        lr = sess.run(lr_op) * learning_rate_decay_factor
        train_once.deacy_steps.append(step)
        logging.info2(
            '{} valid_step:{} learning_rate_decay by *{}, learning_rate_decay_steps={}'
            .format(epoch_str, step, learning_rate_decay_factor,
                    ','.join(map(str, train_once.deacy_steps))))
        sess.run(tf.assign(lr_op, tf.constant(lr, dtype=tf.float32)))
        train_once.patience = 0
        train_once.min_valid_loss = valid_loss

  if ops is not None:
    #if deal_results_fn is None and names is not None:
    #  deal_results_fn = lambda x: melt.print_results(x, names)

    feed_dict = {} if gen_feed_dict_fn is None else gen_feed_dict_fn()

    # NOTICE ops[2] should be scalar otherwise wrong!! loss should be scalar
    #print('---------------ops', ops)
    if eval_ops is not None or not log_dir or not hasattr(
        train_once, 'summary_op') or train_once.summary_op is None:
      results = sess.run(ops, feed_dict=feed_dict)
    else:
      #try:
      results = sess.run(ops + [train_once.summary_op], feed_dict=feed_dict)
      summary_str = results[-1]
      results = results[:-1]
      # except Exception:
      #   logging.info('sess.run(ops + [train_once.summary_op], feed_dict=feed_dict) fail')
      #   results = sess.run(ops, feed_dict=feed_dict)

    #print('------------results', results)

    # #--------trace debug
    # if step == 210:
    #   run_metadata = tf.RunMetadata()
    #   results = sess.run(
    #         ops,
    #         feed_dict=feed_dict,
    #         options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
    #         run_metadata=run_metadata)
    #   from tensorflow.python.client import timeline
    #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    #   trace_file = open('timeline.ctf.json', 'w')
    #   trace_file.write(trace.generate_chrome_trace_format())

    #reults[0] assume to be train_op, results[1] to be learning_rate
    learning_rate = results[1]
    results = results[2:]

    #@TODO should support aver loss and other avg evaluations like test..
    if print_avg_loss:
      if not hasattr(train_once, 'avg_loss'):
        train_once.avg_loss = AvgScore()
        if interval_steps != eval_interval_steps:
          train_once.avg_loss2 = AvgScore()
      #assume results[0] as train_op return, results[1] as loss
      loss = gezi.get_singles(results)
      train_once.avg_loss.add(loss)
      if interval_steps != eval_interval_steps:
        train_once.avg_loss2.add(loss)

    steps_per_second = None
    instances_per_second = None
    hours_per_epoch = None
    #step += 1
    if is_start or interval_steps and step % interval_steps == 0:
      train_average_loss = train_once.avg_loss.avg_score()
      if print_time:
        duration = timer.elapsed()
        duration_str = 'duration:{:.3f} '.format(duration)
        melt.set_global('duration', '%.3f' % duration)
        info.write(duration_str)
        # Throughput over the whole reporting interval, not just this step.
        elapsed = train_once.timer.elapsed()
        steps_per_second = interval_steps / elapsed
        batch_size = melt.batch_size()
        num_gpus = melt.num_gpus()
        instances_per_second = interval_steps * batch_size / elapsed
        gpu_info = '' if num_gpus <= 1 else ' gpus:[{}]'.format(num_gpus)
        if num_steps_per_epoch is None:
          epoch_time_info = ''
        else:
          hours_per_epoch = num_steps_per_epoch / interval_steps * elapsed / 3600
          epoch_time_info = ' 1epoch:[{:.2f}h]'.format(hours_per_epoch)
        info.write(
            'elapsed:[{:.3f}] batch_size:[{}]{} batches/s:[{:.2f}] insts/s:[{:.2f}] {} lr:[{:.8f}]'
            .format(elapsed, batch_size, gpu_info, steps_per_second,
                    instances_per_second, epoch_time_info, learning_rate))

      if print_avg_loss:
        #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
        names_ = melt.adjust_names(train_average_loss, names)
        #info.write('train_avg_metric:{} '.format(melt.parse_results(train_average_loss, names_)))
        info.write(' train:{} '.format(
            melt.parse_results(train_average_loss, names_)))
        #info.write('train_avg_loss: {} '.format(train_average_loss))

      #print(gezi.now_time(), epoch_str, 'train_step:%d'%step, info.getvalue(), end=' ')
      logging.info2('{} {} {}'.format(epoch_str, 'step:%d' % step,
                                      info.getvalue()))

      if deal_results_fn is not None:
        stop = deal_results_fn(results)

  # ---- summary writing phase ----
  summary_strs = gezi.to_list(summary_str)
  if metric_evaluate:
    if evaluate_summaries is not None:
      summary_strs += evaluate_summaries

  if step > 1:
    if is_eval_step:
      # deal with summary
      if log_dir:
        # if not hasattr(train_once, 'summary_op'):
        #   melt.print_summary_ops()
        #   if summary_excls is None:
        #     train_once.summary_op = tf.summary.merge_all()
        #   else:
        #     summary_ops = []
        #     for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
        #       for summary_excl in summary_excls:
        #         if not summary_excl in op.name:
        #           summary_ops.append(op)
        #     print('filtered summary_ops:')
        #     for op in summary_ops:
        #       print(op)
        #     train_once.summary_op = tf.summary.merge(summary_ops)
        #   print('-------------summary_op', train_once.summary_op)
        #   #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
        #   train_once.summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        #   tf.contrib.tensorboard.plugins.projector.visualize_embeddings(train_once.summary_writer, projector_config)

        summary = tf.Summary()
        # #so the strategy is on eval_interval_steps, if has eval dataset, then tensorboard evluate on eval dataset
        # #if not have eval dataset, will evaluate on trainset, but if has eval dataset we will also monitor train loss
        # assert train_once.summary_train_op is None
        # if train_once.summary_train_op is not None:
        #   summary_str = sess.run(train_once.summary_train_op, feed_dict=feed_dict)
        #   train_once.summary_writer.add_summary(summary_str, step)

        if eval_ops is None:
          # #get train loss, for every batch train
          # if train_once.summary_op is not None:
          #   #timer2 = gezi.Timer('sess run')
          #   try:
          #     # TODO FIXME so this means one more train batch step without adding to global step counter ?! so should move it earlier
          #     summary_str = sess.run(train_once.summary_op, feed_dict=feed_dict)
          #   except Exception:
          #     if not hasattr(train_once, 'num_summary_errors'):
          #       logging.warning('summary_str = sess.run(train_once.summary_op, feed_dict=feed_dict) fail')
          #       train_once.num_summary_errors = 1
          #       logging.warning(traceback.format_exc())
          #     summary_str = ''
          #   #timer2.print()
          if train_once.summary_op is not None:
            # Write the train-batch summaries collected in the run phase above.
            for summary_str in summary_strs:
              train_once.summary_writer.add_summary(summary_str, step)
        else:
          # #get eval loss for every batch eval, then add train loss for eval step average loss
          # try:
          #   summary_str = sess.run(train_once.summary_op, feed_dict=eval_feed_dict) if train_once.summary_op is not None else ''
          # except Exception:
          #   if not hasattr(train_once, 'num_summary_errors'):
          #     logging.warning('summary_str = sess.run(train_once.summary_op, feed_dict=eval_feed_dict) fail')
          #     train_once.num_summary_errors = 1
          #     logging.warning(traceback.format_exc())
          #   summary_str = ''

          #all single value results will be add to summary here not using tf.scalar_summary..
          #summary.ParseFromString(summary_str)
          for summary_str in summary_strs:
            train_once.summary_writer.add_summary(summary_str, step)
          suffix = 'eval' if not eval_names else ''
          melt.add_summarys(summary, eval_results, eval_names_, suffix=suffix)

        if ops is not None:
          melt.add_summarys(summary,
                            train_average_loss,
                            names_,
                            suffix='train_avg')
          ##optimizer has done this also
          melt.add_summary(summary, learning_rate, 'learning_rate')
          melt.add_summary(summary, melt.batch_size(), 'batch_size')
          melt.add_summary(summary, melt.epoch(), 'epoch')
          if steps_per_second:
            melt.add_summary(summary, steps_per_second, 'steps_per_second')
          if instances_per_second:
            melt.add_summary(summary, instances_per_second,
                             'instances_per_second')
          if hours_per_epoch:
            melt.add_summary(summary, hours_per_epoch, 'hours_per_epoch')

        if metric_evaluate:
          #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
          prefix = 'step/valid'
          if model_path:
            # Epoch-level valid metrics use their own monotonically increasing
            # counter instead of the global step (rebinds local `step`).
            prefix = 'epoch/valid'
            if not hasattr(train_once, 'epoch_step'):
              train_once.epoch_step = 1
            else:
              train_once.epoch_step += 1
            step = train_once.epoch_step
          melt.add_summarys(summary,
                            evaluate_results,
                            evaluate_names,
                            prefix=prefix)

        train_once.summary_writer.add_summary(summary, step)
        train_once.summary_writer.flush()
        #timer_.print()

      # if print_time:
      #   full_duration = train_once.eval_timer.elapsed()
      #   if metric_evaluate:
      #     metric_full_duration = train_once.metric_eval_timer.elapsed()
      #   full_duration_str = 'elapsed:{:.3f} '.format(full_duration)
      #   #info.write('duration:{:.3f} '.format(timer.elapsed()))
      #   duration = timer.elapsed()
      #   info.write('duration:{:.3f} '.format(duration))
      #   info.write(full_duration_str)
      #   info.write('eval_time_ratio:{:.3f} '.format(duration/full_duration))
      #   if metric_evaluate:
      #     info.write('metric_time_ratio:{:.3f} '.format(duration/metric_full_duration))
      #   #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, info.getvalue())
      #   logging.info2('{} {} {}'.format(epoch_str, 'eval_step: %d'%step, info.getvalue()))
      return stop
    elif metric_evaluate:
      # Metric-eval happened on a non-eval step: still write its summaries.
      summary = tf.Summary()
      for summary_str in summary_strs:
        train_once.summary_writer.add_summary(summary_str, step)
      #summary.ParseFromString(evaluate_summaries)
      summary_writer = train_once.summary_writer
      prefix = 'step/valid'
      if model_path:
        prefix = 'epoch/valid'
        if not hasattr(train_once, 'epoch_step'):
          ## TODO.. restart will get 1 again..
          #epoch_step = tf.Variable(0, trainable=False, name='epoch_step')
          #epoch_step += 1
          #train_once.epoch_step = sess.run(epoch_step)
          valid_interval_epochs = 1.
          try:
            valid_interval_epochs = FLAGS.valid_interval_epochs
          except Exception:
            pass
          # Reconstruct the epoch-step counter from the current epoch so a
          # restart continues roughly where it left off.
          train_once.epoch_step = 1 if melt.epoch() <= 1 else int(
              int(melt.epoch() * 10) / int(valid_interval_epochs * 10))
          logging.info('train_once epoch start step is', train_once.epoch_step)
        else:
          #epoch_step += 1
          train_once.epoch_step += 1
        step = train_once.epoch_step

      #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
      melt.add_summarys(summary, evaluate_results, evaluate_names, prefix=prefix)
      summary_writer.add_summary(summary, step)
      summary_writer.flush()