Example #1
File: train.py Project: tangqiqi123/hasky
def train_once(sess, step, input_text, text, model, optimizer):
    if not hasattr(train_once, 'train_loss'):
        train_once.train_loss = 0.

    if not hasattr(train_once, 'summary_writer'):
        log_dir = FLAGS.model_dir
        train_once.summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

    summary = tf.Summary()

    pred = model(input_text, text, feed_previous=False)

    total_loss = 0.
    total_words = 0
    batch_size = len(text)
    time_steps = text.size()[1]
    for time_step in xrange(time_steps - 1):
        y_pred = pred[time_step]
        target = text[:, time_step + 1]
        loss = criterion(y_pred, target)
        total_loss += loss
        total_words += target.data.ne(0).sum()

    total_loss /= total_words
    #total_loss /= batch_size
    optimizer.zero_grad()
    #print('loss', total_loss)
    total_loss.backward()
    optimizer.step()
    #NOTICE! must be .data[0], otherwise it will consume more and more GPU mem, see
    #https://discuss.pytorch.org/t/cuda-memory-continuously-increases-when-net-images-called-in-every-iteration/501
    #https://discuss.pytorch.org/t/understanding-graphs-and-state/224/1
    train_once.train_loss += total_loss.data[0]
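    # (in PyTorch >= 0.4 the equivalent is total_loss.item())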

    steps = FLAGS.interval_steps
    if step % steps == 0:
        avg_loss = train_once.train_loss if step == 0 else train_once.train_loss / steps
        print('step:', step, 'train_loss:', avg_loss)
        train_once.train_loss = 0.

        names = melt.adjust_names([avg_loss], None)
        melt.add_summarys(summary, [avg_loss],
                          names,
                          suffix='train_avg%dsteps' % steps)
        train_once.summary_writer.add_summary(summary, step)
        train_once.summary_writer.flush()
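
A minimal sketch of how this train_once might be driven from an outer loop; the batch iterator and the model/optimizer setup are illustrative assumptions, not part of the example above:

def train(sess, model, optimizer, batches):
    # batches is assumed to yield (input_text, text) tensor pairs
    for step, (input_text, text) in enumerate(batches):
        train_once(sess, step, input_text, text, model, optimizer)
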
Example #2
def train_once(
    sess,
    step,
    ops,
    names=None,
    gen_feed_dict_fn=None,
    deal_results_fn=None,
    interval_steps=100,
    eval_ops=None,
    eval_names=None,
    gen_eval_feed_dict_fn=None,
    deal_eval_results_fn=melt.print_results,
    valid_interval_steps=100,
    print_time=True,
    print_avg_loss=True,
    model_dir=None,
    log_dir=None,
    is_start=False,
    num_steps_per_epoch=None,
    metric_eval_fn=None,
    metric_eval_interval_steps=0,
    summary_excls=None,
    fixed_step=None,  # for epoch only, in case you change batch size
    eval_loops=1,
    learning_rate=None,
    learning_rate_patience=None,
    learning_rate_decay_factor=None,
    num_epochs=None,
    model_path=None,
    use_horovod=False,
):
    use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ

    #is_start = False # force not to evaluate at first step
    #print('-----------------global_step', sess.run(tf.train.get_or_create_global_step()))
    timer = gezi.Timer()
    if print_time:
        if not hasattr(train_once, 'timer'):
            train_once.timer = Timer()
            train_once.eval_timer = Timer()
            train_once.metric_eval_timer = Timer()

    melt.set_global('step', step)
    epoch = (fixed_step
             or step) / num_steps_per_epoch if num_steps_per_epoch else -1
    if not num_epochs:
        epoch_str = 'epoch:%.3f' % (epoch) if num_steps_per_epoch else ''
    else:
        epoch_str = 'epoch:%.3f/%d' % (
            epoch, num_epochs) if num_steps_per_epoch else ''
    melt.set_global('epoch', '%.2f' % (epoch))

    info = IO()
    stop = False

    if eval_names is None:
        if names:
            eval_names = ['eval/' + x for x in names]

    if names:
        names = ['train/' + x for x in names]

    if eval_names:
        eval_names = ['eval/' + x for x in eval_names]

    is_eval_step = is_start or valid_interval_steps and step % valid_interval_steps == 0
    summary_str = []

    eval_str = ''
    if is_eval_step:
        # deal with summary
        if log_dir:
            if not hasattr(train_once, 'summary_op'):
                #melt.print_summary_ops()
                if summary_excls is None:
                    train_once.summary_op = tf.summary.merge_all()
                else:
                    summary_ops = []
                    for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
                        for summary_excl in summary_excls:
                            if not summary_excl in op.name:
                                summary_ops.append(op)
                    print('filtered summary_ops:')
                    for op in summary_ops:
                        print(op)
                    train_once.summary_op = tf.summary.merge(summary_ops)

                #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
                train_once.summary_writer = tf.summary.FileWriter(
                    log_dir, sess.graph)

                tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
                    train_once.summary_writer, projector_config)

        # if eval_ops are set, this should be rank 0

        if eval_ops:
            #if deal_eval_results_fn is None and eval_names is not None:
            #  deal_eval_results_fn = lambda x: melt.print_results(x, eval_names)
            for i in range(eval_loops):
                eval_feed_dict = {} if gen_eval_feed_dict_fn is None else gen_eval_feed_dict_fn(
                )
                #eval_feed_dict.update(feed_dict)

                # if using horovod, let each rank use the same sess.run!
                if not log_dir or train_once.summary_op is None or gezi.env_has(
                        'EVAL_NO_SUMMARY') or use_horovod:
                    #if not log_dir or train_once.summary_op is None:
                    eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
                else:
                    eval_results = sess.run(eval_ops + [train_once.summary_op],
                                            feed_dict=eval_feed_dict)
                    summary_str = eval_results[-1]
                    eval_results = eval_results[:-1]
                eval_loss = gezi.get_singles(eval_results)
                #timer_.print()
                eval_stop = False
                if use_horovod:
                    sess.run(hvd.allreduce(tf.constant(0)))

                #if not use_horovod or  hvd.local_rank() == 0:
                # @TODO user print should also use logging as a must ?
                #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, 'eval_metrics:', end='')
                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                #if not use_horovod or hvd.rank() == 0:
                #  logging.info2('{} eval_step:{} eval_metrics:{}'.format(epoch_str, step, melt.parse_results(eval_loss, eval_names_)))
                eval_str = 'valid:{}'.format(
                    melt.parse_results(eval_loss, eval_names_))

                # if deal_eval_results_fn is not None:
                #   eval_stop = deal_eval_results_fn(eval_results)

                assert len(eval_loss) > 0
                if eval_stop is True:
                    stop = True
                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                if not use_horovod or hvd.rank() == 0:
                    melt.set_global('eval_loss',
                                    melt.parse_results(eval_loss, eval_names_))

        elif interval_steps != valid_interval_steps:
            #print()
            pass

    metric_evaluate = False

    # if metric_eval_fn is not None \
    #   and (is_start \
    #     or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
    #          or (metric_eval_interval_steps \
    #              and step % metric_eval_interval_steps == 0)):
    #  metric_evaluate = True

    if metric_eval_fn is not None \
      and ((is_start or metric_eval_interval_steps \
           and step % metric_eval_interval_steps == 0) or model_path):
        metric_evaluate = True

    if 'EVFIRST' in os.environ:
        if os.environ['EVFIRST'] == '0':
            if is_start:
                metric_evaluate = False
        else:
            if is_start:
                metric_evaluate = True

    if step == 0 or 'QUICK' in os.environ:
        metric_evaluate = False

    #print('------------1step', step, 'pre metric_evaluate', metric_evaluate, hvd.rank())
    if metric_evaluate:
        if use_horovod:
            print('------------metric evaluate step', step, model_path,
                  hvd.rank())
        if not model_path or 'model_path' not in inspect.getargspec(
                metric_eval_fn).args:
            metric_eval_fn_ = metric_eval_fn
        else:
            metric_eval_fn_ = lambda: metric_eval_fn(model_path=model_path)

        try:
            l = metric_eval_fn_()
            if isinstance(l, tuple):
                num_returns = len(l)
                if num_returns == 2:
                    evaluate_results, evaluate_names = l
                    evaluate_summaries = None
                else:
                    assert num_returns == 3, 'returning 1, 2 or 3 values is ok, 4+ is not'
                    evaluate_results, evaluate_names, evaluate_summaries = l
            else:  # returned a dict
                evaluate_results, evaluate_names = tuple(zip(*l.items()))
                evaluate_summaries = None
        except Exception:
            logging.info('Do nothing for metric eval fn with exception:\n',
                         traceback.format_exc())

        if not use_horovod or hvd.rank() == 0:
            #logging.info2('{} valid_step:{} {}:{}'.format(epoch_str, step, 'valid_metrics' if model_path is None else 'epoch_valid_metrics', melt.parse_results(evaluate_results, evaluate_names)))
            logging.info2('{} valid_step:{} {}:{}'.format(
                epoch_str, step, 'valid_metrics',
                melt.parse_results(evaluate_results, evaluate_names)))

        if learning_rate is not None and (learning_rate_patience
                                          and learning_rate_patience > 0):
            assert learning_rate_decay_factor > 0 and learning_rate_decay_factor < 1
            valid_loss = evaluate_results[0]
            if not hasattr(train_once, 'min_valid_loss'):
                train_once.min_valid_loss = valid_loss
                train_once.decay_steps = []
                train_once.patience = 0
            else:
                if valid_loss < train_once.min_valid_loss:
                    train_once.min_valid_loss = valid_loss
                    train_once.patience = 0
                else:
                    train_once.patience += 1
                    logging.info2('{} valid_step:{} patience:{}'.format(
                        epoch_str, step, train_once.patience))

            if learning_rate_patience and train_once.patience >= learning_rate_patience:
                lr_op = ops[1]
                lr = sess.run(lr_op) * learning_rate_decay_factor
                train_once.decay_steps.append(step)
                logging.info2(
                    '{} valid_step:{} learning_rate_decay by *{}, learning_rate_decay_steps={}'
                    .format(epoch_str, step, learning_rate_decay_factor,
                            ','.join(map(str, train_once.decay_steps))))
                sess.run(tf.assign(lr_op, tf.constant(lr, dtype=tf.float32)))
                train_once.patience = 0
                train_once.min_valid_loss = valid_loss

    if ops is not None:
        #if deal_results_fn is None and names is not None:
        #  deal_results_fn = lambda x: melt.print_results(x, names)

        feed_dict = {} if gen_feed_dict_fn is None else gen_feed_dict_fn()
        # NOTICE ops[2] should be scalar otherwise wrong!! loss should be scalar
        #print('---------------ops', ops)
        if eval_ops is not None or not log_dir or not hasattr(
                train_once,
                'summary_op') or train_once.summary_op is None or use_horovod:
            feed_dict[K.learning_phase()] = 1
            results = sess.run(ops, feed_dict=feed_dict)
        else:
            ## TODO why below ?
            #try:
            feed_dict[K.learning_phase()] = 1
            results = sess.run(ops + [train_once.summary_op],
                               feed_dict=feed_dict)
            summary_str = results[-1]
            results = results[:-1]
            # except Exception:
            #   logging.info('sess.run(ops + [train_once.summary_op], feed_dict=feed_dict) fail')
            #   results = sess.run(ops, feed_dict=feed_dict)

        #print('------------results', results)
        # #--------trace debug
        # if step == 210:
        #   run_metadata = tf.RunMetadata()
        #   results = sess.run(
        #         ops,
        #         feed_dict=feed_dict,
        #         options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
        #         run_metadata=run_metadata)
        #   from tensorflow.python.client import timeline
        #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)

        #   trace_file = open('timeline.ctf.json', 'w')
        #   trace_file.write(trace.generate_chrome_trace_format())

        # results[0] is assumed to be train_op, results[1] to be learning_rate
        learning_rate = results[1]
        results = results[2:]

        #@TODO should support avg loss and other avg evaluations like test..
        if print_avg_loss:
            if not hasattr(train_once, 'avg_loss'):
                train_once.avg_loss = AvgScore()
            #assume results[0] as train_op return, results[1] as loss
            loss = gezi.get_singles(results)
            train_once.avg_loss.add(loss)

        steps_per_second = None
        instances_per_second = None
        hours_per_epoch = None
        #step += 1
        #if is_start or interval_steps and step % interval_steps == 0:
        interval_ok = not use_horovod or hvd.local_rank() == 0
        if interval_steps and step % interval_steps == 0 and interval_ok:
            train_average_loss = train_once.avg_loss.avg_score()
            if print_time:
                duration = timer.elapsed()
                duration_str = 'duration:{:.2f} '.format(duration)
                melt.set_global('duration', '%.2f' % duration)
                #info.write(duration_str)
                elapsed = train_once.timer.elapsed()
                steps_per_second = interval_steps / elapsed
                batch_size = melt.batch_size()
                num_gpus = melt.num_gpus()
                instances_per_second = interval_steps * batch_size / elapsed
                gpu_info = '' if num_gpus <= 1 else ' gpus:[{}]'.format(
                    num_gpus)
                if num_steps_per_epoch is None:
                    epoch_time_info = ''
                else:
                    hours_per_epoch = num_steps_per_epoch / interval_steps * elapsed / 3600
                    epoch_time_info = '1epoch:[{:.2f}h]'.format(
                        hours_per_epoch)
                info.write(
                    'elapsed:[{:.2f}] batch_size:[{}]{} batches/s:[{:.2f}] insts/s:[{:.2f}] {} lr:[{:.6f}]'
                    .format(elapsed, batch_size, gpu_info, steps_per_second,
                            instances_per_second, epoch_time_info,
                            learning_rate))

            if print_avg_loss:
                #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
                names_ = melt.adjust_names(train_average_loss, names)
                #info.write('train_avg_metric:{} '.format(melt.parse_results(train_average_loss, names_)))
                info.write(' train:{} '.format(
                    melt.parse_results(train_average_loss, names_)))
                #info.write('train_avg_loss: {} '.format(train_average_loss))
            info.write(eval_str)
            #print(gezi.now_time(), epoch_str, 'train_step:%d'%step, info.getvalue(), end=' ')
            logging.info2('{} {} {}'.format(epoch_str, 'step:%d' % step,
                                            info.getvalue()))

            if deal_results_fn is not None:
                stop = deal_results_fn(results)

    summary_strs = gezi.to_list(summary_str)
    if metric_evaluate:
        if evaluate_summaries is not None:
            summary_strs += evaluate_summaries

    if step > 1:
        if is_eval_step:
            # deal with summary
            if log_dir:
                summary = tf.Summary()
                if eval_ops is None:
                    if train_once.summary_op is not None:
                        for summary_str in summary_strs:
                            train_once.summary_writer.add_summary(
                                summary_str, step)
                else:
                    for summary_str in summary_strs:
                        train_once.summary_writer.add_summary(
                            summary_str, step)
                    suffix = 'valid' if not eval_names else ''
                    # loss/valid
                    melt.add_summarys(summary,
                                      eval_results,
                                      eval_names_,
                                      suffix=suffix)

                if ops is not None:
                    try:
                        # loss/train_avg
                        melt.add_summarys(summary,
                                          train_average_loss,
                                          names_,
                                          suffix='train_avg')
                    except Exception:
                        pass
                    ##optimizer has done this also
                    melt.add_summary(summary, learning_rate, 'learning_rate')
                    melt.add_summary(summary,
                                     melt.batch_size(),
                                     'batch_size',
                                     prefix='other')
                    melt.add_summary(summary,
                                     melt.epoch(),
                                     'epoch',
                                     prefix='other')
                    if steps_per_second:
                        melt.add_summary(summary,
                                         steps_per_second,
                                         'steps_per_second',
                                         prefix='perf')
                    if instances_per_second:
                        melt.add_summary(summary,
                                         instances_per_second,
                                         'instances_per_second',
                                         prefix='perf')
                    if hours_per_epoch:
                        melt.add_summary(summary,
                                         hours_per_epoch,
                                         'hours_per_epoch',
                                         prefix='perf')

                if metric_evaluate:
                    #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
                    prefix = 'step_eval'
                    if model_path:
                        prefix = 'eval'
                        if not hasattr(train_once, 'epoch_step'):
                            train_once.epoch_step = 1
                        else:
                            train_once.epoch_step += 1
                        step = train_once.epoch_step
                    # eval/loss eval/auc ..
                    melt.add_summarys(summary,
                                      evaluate_results,
                                      evaluate_names,
                                      prefix=prefix)

                train_once.summary_writer.add_summary(summary, step)
                train_once.summary_writer.flush()
            return stop
        elif metric_evaluate and log_dir:
            summary = tf.Summary()
            for summary_str in summary_strs:
                train_once.summary_writer.add_summary(summary_str, step)
            #summary.ParseFromString(evaluate_summaries)
            summary_writer = train_once.summary_writer
            prefix = 'step_eval'
            if model_path:
                prefix = 'eval'
                if not hasattr(train_once, 'epoch_step'):
                    ## TODO.. restart will get 1 again..
                    #epoch_step = tf.Variable(0, trainable=False, name='epoch_step')
                    #epoch_step += 1
                    #train_once.epoch_step = sess.run(epoch_step)
                    valid_interval_epochs = 1.
                    try:
                        valid_interval_epochs = FLAGS.valid_interval_epochs
                    except Exception:
                        pass
                    train_once.epoch_step = 1 if melt.epoch() <= 1 else int(
                        int(melt.epoch() * 10) /
                        int(valid_interval_epochs * 10))
                    logging.info('train_once epoch start step is',
                                 train_once.epoch_step)
                else:
                    #epoch_step += 1
                    train_once.epoch_step += 1
                step = train_once.epoch_step
            #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
            melt.add_summarys(summary,
                              evaluate_results,
                              evaluate_names,
                              prefix=prefix)
            summary_writer.add_summary(summary, step)
            summary_writer.flush()
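
A hedged sketch of a call matching the conventions this version assumes (ops[0] is the train op, ops[1] the learning-rate tensor, any further scalar entries are treated as losses; every name below is illustrative):

train_once(sess, step,
           ops=[train_op, learning_rate_var, loss],  # ops[0]=train_op, ops[1]=lr, then scalar losses
           names=['loss'],
           interval_steps=100,
           eval_ops=[valid_loss],
           eval_names=['loss'],
           valid_interval_steps=100,
           log_dir=FLAGS.log_dir,
           num_steps_per_epoch=num_steps_per_epoch)
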
Example #3
def train_once(sess,
               step,
               ops,
               names=None,
               gen_feed_dict=None,
               deal_results=melt.print_results,
               interval_steps=100,
               eval_ops=None,
               eval_names=None,
               gen_eval_feed_dict=None,
               deal_eval_results=melt.print_results,
               eval_interval_steps=100,
               print_time=True,
               print_avg_loss=True,
               model_dir=None,
               log_dir=None,
               is_start=False,
               num_steps_per_epoch=None,
               metric_eval_function=None,
               metric_eval_interval_steps=0):

    timer = gezi.Timer()
    if print_time:
        if not hasattr(train_once, 'timer'):
            train_once.timer = Timer()
            train_once.eval_timer = Timer()
            train_once.metric_eval_timer = Timer()

    melt.set_global('step', step)
    epoch = step / num_steps_per_epoch if num_steps_per_epoch else -1
    epoch_str = 'epoch:%.4f' % (epoch) if num_steps_per_epoch else ''
    melt.set_global('epoch', '%.4f' % (epoch))

    info = BytesIO()
    stop = False

    if ops is not None:
        if deal_results is None and names is not None:
            deal_results = lambda x: melt.print_results(x, names)
        if deal_eval_results is None and eval_names is not None:
            deal_eval_results = lambda x: melt.print_results(x, eval_names)

        if eval_names is None:
            eval_names = names

        feed_dict = {} if gen_feed_dict is None else gen_feed_dict()

        results = sess.run(ops, feed_dict=feed_dict)

        # #--------trace debug
        # if step == 210:
        #   run_metadata = tf.RunMetadata()
        #   results = sess.run(
        #         ops,
        #         feed_dict=feed_dict,
        #         options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
        #         run_metadata=run_metadata)
        #   from tensorflow.python.client import timeline
        #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)

        #   trace_file = open('timeline.ctf.json', 'w')
        #   trace_file.write(trace.generate_chrome_trace_format())

        # results[0] is assumed to be train_op
        results = results[1:]

        #@TODO should support avg loss and other avg evaluations like test..
        if print_avg_loss:
            if not hasattr(train_once, 'avg_loss'):
                train_once.avg_loss = AvgScore()
                if interval_steps != eval_interval_steps:
                    train_once.avg_loss2 = AvgScore()
            #assume results[0] as train_op return, results[1] as loss
            loss = gezi.get_singles(results)
            train_once.avg_loss.add(loss)
            if interval_steps != eval_interval_steps:
                train_once.avg_loss2.add(loss)

        if is_start or interval_steps and step % interval_steps == 0:
            train_average_loss = train_once.avg_loss.avg_score()
            if print_time:
                duration = timer.elapsed()
                duration_str = 'duration:{:.3f} '.format(duration)
                melt.set_global('duration', '%.3f' % duration)
                info.write(duration_str)
                elapsed = train_once.timer.elapsed()
                steps_per_second = interval_steps / elapsed
                batch_size = melt.batch_size()
                num_gpus = melt.num_gpus()
                instances_per_second = interval_steps * batch_size * num_gpus / elapsed
                if num_gpus == 1:
                    info.write(
                        'elapsed:[{:.3f}] batch_size:[{}] batches/s:[{:.2f}] insts/s:[{:.2f}] '
                        .format(elapsed, batch_size, steps_per_second,
                                instances_per_second))
                else:
                    info.write(
                        'elapsed:[{:.3f}] batch_size:[{}] gpus:[{}], batches/s:[{:.2f}] insts/s:[{:.2f}] '
                        .format(elapsed, batch_size, num_gpus,
                                steps_per_second, instances_per_second))

            if print_avg_loss:
                #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
                names_ = melt.adjust_names(train_average_loss, names)
                info.write('train_avg_metrics:{} '.format(
                    melt.parse_results(train_average_loss, names_)))
                #info.write('train_avg_loss: {} '.format(train_average_loss))

            #print(gezi.now_time(), epoch_str, 'train_step:%d'%step, info.getvalue(), end=' ')
            logging.info2('{} {} {}'.format(epoch_str, 'train_step:%d' % step,
                                            info.getvalue()))

            if deal_results is not None:
                stop = deal_results(results)

    metric_evaluate = False
    # if metric_eval_function is not None \
    #   and ( (is_start and (step or ops is None))\
    #     or (step and ((num_steps_per_epoch and step % num_steps_per_epoch == 0) \
    #            or (metric_eval_interval_steps \
    #                and step % metric_eval_interval_steps == 0)))):
    #     metric_evaluate = True
    if metric_eval_function is not None \
      and (is_start \
        or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
             or (metric_eval_interval_steps \
                 and step % metric_eval_interval_steps == 0)):
        metric_evaluate = True

    if metric_evaluate:
        evaluate_results, evaluate_names = metric_eval_function()

    if is_start or eval_interval_steps and step % eval_interval_steps == 0:
        if ops is not None:
            if interval_steps != eval_interval_steps:
                train_average_loss = train_once.avg_loss2.avg_score()

            info = BytesIO()

            names_ = melt.adjust_names(results, names)

            train_average_loss_str = ''
            if print_avg_loss and interval_steps != eval_interval_steps:
                train_average_loss_str = melt.value_name_list_str(
                    train_average_loss, names_)
                melt.set_global('train_loss', train_average_loss_str)
                train_average_loss_str = 'train_avg_loss:{} '.format(
                    train_average_loss_str)

            if interval_steps != eval_interval_steps:
                #end = '' if eval_ops is None else '\n'
                #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, train_average_loss_str, end=end)
                logging.info2('{} eval_step: {} {}'.format(
                    epoch_str, step, train_average_loss_str))

        if eval_ops is not None:
            eval_feed_dict = {} if gen_eval_feed_dict is None else gen_eval_feed_dict(
            )
            #eval_feed_dict.update(feed_dict)

            #------show how to perf debug
            ##timer_ = gezi.Timer('sess run generate')
            ##sess.run(eval_ops[-2], feed_dict=None)
            ##timer_.print()

            timer_ = gezi.Timer('sess run eval_ops')
            eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
            timer_.print()
            if deal_eval_results is not None:
                #@TODO user print should also use logging as a must ?
                #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, 'eval_metrics:', end='')
                logging.info2('{} eval_step: {} eval_metrics:'.format(
                    epoch_str, step))
                eval_stop = deal_eval_results(eval_results)

            eval_loss = gezi.get_singles(eval_results)
            assert len(eval_loss) > 0
            if eval_stop is True: stop = True
            eval_names_ = melt.adjust_names(eval_loss, eval_names)

            melt.set_global('eval_loss',
                            melt.parse_results(eval_loss, eval_names_))
        elif interval_steps != eval_interval_steps:
            #print()
            pass

        if log_dir:
            #timer_ = gezi.Timer('writing log')

            if not hasattr(train_once, 'summary_op'):
                try:
                    train_once.summary_op = tf.summary.merge_all()
                except Exception:
                    train_once.summary_op = tf.merge_all_summaries()

                melt.print_summary_ops()

                try:
                    train_once.summary_train_op = tf.summary.merge_all(
                        key=melt.MonitorKeys.TRAIN)
                    train_once.summary_writer = tf.summary.FileWriter(
                        log_dir, sess.graph)
                except Exception:
                    train_once.summary_train_op = tf.merge_all_summaries(
                        key=melt.MonitorKeys.TRAIN)
                    train_once.summary_writer = tf.train.SummaryWriter(
                        log_dir, sess.graph)

                tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
                    train_once.summary_writer, projector_config)

            summary = tf.Summary()
            #strategy: on eval_interval_steps, if there is an eval dataset, tensorboard evaluates on it;
            #if there is no eval dataset, evaluate on the train set; with an eval dataset we also still monitor train loss
            if train_once.summary_train_op is not None:
                summary_str = sess.run(train_once.summary_train_op,
                                       feed_dict=feed_dict)
                train_once.summary_writer.add_summary(summary_str, step)

            if eval_ops is None:
                #get train loss, for every batch train
                if train_once.summary_op is not None:
                    #timer2 = gezi.Timer('sess run')
                    summary_str = sess.run(train_once.summary_op,
                                           feed_dict=feed_dict)
                    #timer2.print()
                    train_once.summary_writer.add_summary(summary_str, step)
            else:
                #get eval loss for every batch eval, then add train loss for eval step average loss
                summary_str = sess.run(
                    train_once.summary_op, feed_dict=eval_feed_dict
                ) if train_once.summary_op is not None else ''
                #all single-value results will be added to the summary here, not using tf.scalar_summary..
                summary.ParseFromString(summary_str)
                melt.add_summarys(summary,
                                  eval_results,
                                  eval_names_,
                                  suffix='eval')

            melt.add_summarys(summary,
                              train_average_loss,
                              names_,
                              suffix='train_avg%dsteps' % eval_interval_steps)

            if metric_evaluate:
                melt.add_summarys(summary,
                                  evaluate_results,
                                  evaluate_names,
                                  prefix='evaluate')

            train_once.summary_writer.add_summary(summary, step)
            train_once.summary_writer.flush()

            #timer_.print()

        if print_time:
            full_duration = train_once.eval_timer.elapsed()
            if metric_evaluate:
                metric_full_duration = train_once.metric_eval_timer.elapsed()
            full_duration_str = 'elapsed:{:.3f} '.format(full_duration)
            #info.write('duration:{:.3f} '.format(timer.elapsed()))
            duration = timer.elapsed()
            info.write('duration:{:.3f} '.format(duration))
            info.write(full_duration_str)
            info.write('eval_time_ratio:{:.3f} '.format(duration /
                                                        full_duration))
            if metric_evaluate:
                info.write('metric_time_ratio:{:.3f} '.format(
                    duration / metric_full_duration))
        #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, info.getvalue())
        logging.info2('{} {} {}'.format(epoch_str, 'eval_step: %d' % step,
                                        info.getvalue()))

        return stop
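
This variant assumes a slightly different ops layout than Example #2: only ops[0] is the train op (results = results[1:]) and there is no learning-rate slot. A minimal illustrative call, with all names made up:

train_once(sess, step,
           ops=[train_op, loss],
           names=['loss'],
           eval_ops=[valid_loss],
           eval_names=['valid_loss'],
           eval_interval_steps=100,
           log_dir=FLAGS.log_dir)
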
def main(_):
  prediction_file = FLAGS.prediction_file or sys.argv[1]

  assert prediction_file

  log_dir = os.path.dirname(prediction_file)
  log_dir = log_dir or './'
  print('prediction_file', prediction_file, 'log_dir', log_dir, file=sys.stderr)
  logging.set_logging_path(log_dir)

  sess = tf.Session()
  summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

  refs = prepare_refs()
  tokenizer = prepare_tokenization()
  ##TODO some problem running tokenizer..
  #refs = tokenizer.tokenize(refs)

  min_len = 10000
  min_len_image = None
  min_len_caption = None
  max_len = 0
  max_len_image = None
  max_len_caption = None
  sum_len = 0

  min_words = 10000
  min_words_image = None
  min_words_caption = None
  max_words = 0
  max_words_image = None
  max_words_caption = None
  sum_words = 0

  caption_metrics_file = FLAGS.caption_metrics_file or prediction_file.replace('evaluate-inference', 'caption_metrics')
  imgs = []
  captions = []
  infos = {}
  for line in open(prediction_file):
    l = line.strip().split('\t')
    img, caption, all_caption, all_score = l[0], l[1], l[-2], l[-1]
    img = img.replace('.jpg', '')
    img += '.jpg'
    imgs.append(img)

    infos[img] = '%s %s' % (all_caption.replace(' ', '|'), all_score.replace(' ', '|'))

    caption = caption.replace(' ', '').replace('\t', '')
    caption_words = [x.encode('utf-8') for x in jieba.cut(caption)]
    caption_str = ' '.join(caption_words)
    captions.append([caption_str])

    caption_len = len(gezi.get_single_cns(caption))
    num_words = len(caption_words)

    if caption_len < min_len:
      min_len = caption_len
      min_len_image = img 
      min_len_caption = caption
    if caption_len > max_len:
      max_len = caption_len
      max_len_image = img 
      max_len_caption = caption
    sum_len += caption_len

    if num_words < min_words:
      min_words = num_words
      min_words_image = img
      min_words_caption = caption_str
    if num_words > max_words:
      max_words = num_words
      max_words_image = img
      max_words_caption = caption_str
    sum_words += num_words

  results = dict(zip(imgs, captions))
  
  #results = tokenizer.tokenize(results)

  selected_results, selected_refs = translation_reorder_keys(results, refs)

  scorers = [
            (Bleu(4), ["bleu_1", "bleu_2", "bleu_3", "bleu_4"]),
            (Cider(), "cider"),
            (Meteor(), "meteor"),
            (Rouge(), "rouge_l")
        ]

  score_list = []
  metric_list = []
  scores_list = []

  print('img&predict&label:{}:{}{}{}'.format(selected_results.items()[0][0], '|'.join(selected_results.items()[0][1]), '---', '|'.join(selected_refs.items()[0][1])), file=sys.stderr)
  #print('avg_len:', sum_len / len(refs), 'min_len:', min_len, min_len_image, min_len_caption, 'max_len:', max_len, max_len_image, max_len_caption, file=sys.stderr)
  print('avg_len:', sum_len / refs_len, 'min_len:', min_len, min_len_image, min_len_caption, 'max_len:', max_len, max_len_image, max_len_caption, file=sys.stderr)
  print('avg_words', sum_words / refs_len, 'min_words:', min_words, min_words_image, min_words_caption, 'max_words:', max_words, max_words_image, max_words_caption, file=sys.stderr)
  
  for scorer, method in scorers:
    print('computing %s score...' % (scorer.method()), file=sys.stderr)
    score, scores = scorer.compute_score(selected_refs, selected_results)
    if type(method) == list:
      for i in range(len(score)):
        score_list.append(score[i])
        metric_list.append(method[i])
        scores_list.append(scores[i])
        print(method[i], score[i], file=sys.stderr)
    else:
      score_list.append(score)
      metric_list.append(method)
      scores_list.append(scores)
      print(method, score, file=sys.stderr)

  assert(len(score_list) == 7)

  avg_score = np.mean(np.array(score_list[3:]))
  score_list.insert(0, avg_score)
  metric_list.insert(0, 'avg')

  if caption_metrics_file:
    out = open(caption_metrics_file, 'w')
    print('image_id', 'caption', 'ref', '\t'.join(metric_list), 'infos', sep='\t', file=out)
    for i in range(len(selected_results)):
      key = selected_results.keys()[i] 
      result = selected_results[key][0]
      refs = '|'.join(selected_refs[key])
      bleu_1 = scores_list[0][i]
      bleu_2 = scores_list[1][i]
      bleu_3 = scores_list[2][i]
      bleu_4 = scores_list[3][i]
      cider = scores_list[4][i]
      meteor = scores_list[5][i]
      rouge_l = scores_list[6][i]
      avg = (bleu_4 + cider + meteor + rouge_l) / 4.
      print(key.split('.')[0], result, refs, avg, bleu_1, bleu_2, bleu_3, bleu_4, cider, meteor, rouge_l, infos[key], sep='\t', file=out)

  metric_list = ['trans_' + x for x in metric_list]
  metric_score_str = '\t'.join('%s:[%.4f]' % (name, result) for name, result in zip(metric_list, score_list))
  logging.info('%s\t%s'%(metric_score_str, os.path.basename(prediction_file)))

  print(key.split('.')[0], 'None', 'None', '\t'.join(map(str, score_list)), 'None', sep='\t', file=out)

  summary = tf.Summary()
  if score_list and 'ckpt' in prediction_file:
    try:
      epoch = float(os.path.basename(prediction_file).split('-')[1])
      #for float epochs like 0.01, 0.02 turn them into 1, 2; note this turns epoch 1 into 100
      epoch = int(epoch * 100)
      step = int(float(os.path.basename(prediction_file).split('-')[2].split('.')[0]))
      prefix = 'step' if FLAGS.write_step else 'epoch'
      melt.add_summarys(summary, score_list, metric_list, prefix=prefix)
      step = epoch if not FLAGS.write_step else step
      summary_writer.add_summary(summary, step)
      summary_writer.flush()
    except Exception:
      print(traceback.format_exc(), file=sys.stderr)
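
For reference, the parsing loop above expects each line of prediction_file to be tab-separated, with the image id in the first column, the chosen caption in the second, and all candidate captions plus their scores in the last two columns; an illustrative (made-up) layout of one line:

# l[0]       l[1]              ...  l[-2]                     l[-1]
# image_id   selected_caption       "cand1 cand2 ... candN"   "s1 s2 ... sN"
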
Example #5
def main(_):
    print('eval_rank:', FLAGS.eval_rank, 'eval_translation:',
          FLAGS.eval_translation)
    epoch_dir = os.path.join(FLAGS.model_dir, 'epoch')

    logging.set_logging_path(gezi.get_dir(epoch_dir))

    log_dir = epoch_dir
    sess = tf.Session()
    summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

    Predictor = TextPredictor

    image_model = None
    if FLAGS.image_checkpoint_file:
        #feature_name=None, since in show-and-tell the predictor will use gen_features, not gen_feature
        image_model = melt.image.ImageModel(FLAGS.image_checkpoint_file,
                                            FLAGS.image_model_name,
                                            feature_name=None)

    evaluator.init(image_model)

    visited_path = os.path.join(epoch_dir, 'visited.pkl')
    if not os.path.exists(visited_path):
        visited_checkpoints = set()
    else:
        visited_checkpoints = pickle.load(open(visited_path, 'rb'))

    visited_checkpoints = set([x.split('/')[-1] for x in visited_checkpoints])

    while True:
        suffix = '.data-00000-of-00001'
        files = glob.glob(
            os.path.join(epoch_dir, 'model.ckpt*.data-00000-of-00001'))
        #from epoch 1, 2, ..
        files.sort(key=os.path.getmtime)
        files = [file.replace(suffix, '') for file in files]
        for i, file in enumerate(files):
            if 'best' in file:
                continue
            if FLAGS.start_epoch and i + 1 < FLAGS.start_epoch:
                continue
            file_ = file.split('/')[-1]
            if file_ not in visited_checkpoints:
                visited_checkpoints.add(file_)
                epoch = int(file_.split('-')[-2])
                logging.info('monitor_epoch:%d from %d model files' %
                             (epoch, len(visited_checkpoints)))
                #will use predict_text in eval_translation, predict in eval_rank
                predictor = Predictor(file,
                                      image_model=image_model,
                                      feature_name=melt.get_features_name(
                                          FLAGS.image_model_name))
                summary = tf.Summary()
                scores, metrics = evaluator.evaluate(
                    predictor,
                    eval_rank=FLAGS.eval_rank,
                    eval_translation=FLAGS.eval_translation)
                melt.add_summarys(summary, scores, metrics)
                summary_writer.add_summary(summary, epoch)
                summary_writer.flush()
                pickle.dump(visited_checkpoints, open(visited_path, 'wb'))
        time.sleep(5)
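
The monitoring loop above assumes epoch checkpoints whose epoch number is the second-to-last '-'-separated field of the checkpoint name; an illustrative (made-up) layout:

# epoch/model.ckpt-3-150000.data-00000-of-00001   -> evaluated as epoch 3
# epoch/model.ckpt-3-150000.index
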
Example #6
def train_once(
    sess,
    step,
    ops,
    names=None,
    gen_feed_dict_fn=None,
    deal_results_fn=None,
    interval_steps=100,
    eval_ops=None,
    eval_names=None,
    gen_eval_feed_dict_fn=None,
    deal_eval_results_fn=melt.print_results,
    eval_interval_steps=100,
    print_time=True,
    print_avg_loss=True,
    model_dir=None,
    log_dir=None,
    is_start=False,
    num_steps_per_epoch=None,
    metric_eval_fn=None,
    metric_eval_interval_steps=0,
    summary_excls=None,
    fixed_step=None,  # for epoch only, in case you change batch size
    eval_loops=1,
    learning_rate=None,
    learning_rate_patience=None,
    learning_rate_decay_factor=None,
    num_epochs=None,
    model_path=None,
):

    #is_start = False # force not to evaluate at first step
    #print('-----------------global_step', sess.run(tf.train.get_or_create_global_step()))
    timer = gezi.Timer()
    if print_time:
        if not hasattr(train_once, 'timer'):
            train_once.timer = Timer()
            train_once.eval_timer = Timer()
            train_once.metric_eval_timer = Timer()

    melt.set_global('step', step)
    epoch = (fixed_step
             or step) / num_steps_per_epoch if num_steps_per_epoch else -1
    if not num_epochs:
        epoch_str = 'epoch:%.3f' % (epoch) if num_steps_per_epoch else ''
    else:
        epoch_str = 'epoch:%.3f/%d' % (
            epoch, num_epochs) if num_steps_per_epoch else ''
    melt.set_global('epoch', '%.2f' % (epoch))

    info = IO()
    stop = False

    if eval_names is None:
        if names:
            eval_names = ['eval/' + x for x in names]

    if names:
        names = ['train/' + x for x in names]

    if eval_names:
        eval_names = ['eval/' + x for x in eval_names]

    is_eval_step = is_start or eval_interval_steps and step % eval_interval_steps == 0
    summary_str = []

    if is_eval_step:
        # deal with summary
        if log_dir:
            if not hasattr(train_once, 'summary_op'):
                #melt.print_summary_ops()
                if summary_excls is None:
                    train_once.summary_op = tf.summary.merge_all()
                else:
                    summary_ops = []
                    for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
                        for summary_excl in summary_excls:
                            if not summary_excl in op.name:
                                summary_ops.append(op)
                    print('filtered summary_ops:')
                    for op in summary_ops:
                        print(op)
                    train_once.summary_op = tf.summary.merge(summary_ops)

                #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
                train_once.summary_writer = tf.summary.FileWriter(
                    log_dir, sess.graph)

                tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
                    train_once.summary_writer, projector_config)

        if eval_ops is not None:
            #if deal_eval_results_fn is None and eval_names is not None:
            #  deal_eval_results_fn = lambda x: melt.print_results(x, eval_names)
            for i in range(eval_loops):
                eval_feed_dict = {} if gen_eval_feed_dict_fn is None else gen_eval_feed_dict_fn(
                )
                #eval_feed_dict.update(feed_dict)

                # might use EVAL_NO_SUMMARY if some old code has problems, TODO CHECK
                if not log_dir or train_once.summary_op is None or gezi.env_has(
                        'EVAL_NO_SUMMARY'):
                    #if not log_dir or train_once.summary_op is None:
                    eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
                else:
                    eval_results = sess.run(eval_ops + [train_once.summary_op],
                                            feed_dict=eval_feed_dict)
                    summary_str = eval_results[-1]
                    eval_results = eval_results[:-1]
                eval_loss = gezi.get_singles(eval_results)
                #timer_.print()
                eval_stop = False

                # @TODO user print should also use logging as a must ?
                #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, 'eval_metrics:', end='')
                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                logging.info2('{} eval_step:{} eval_metrics:{}'.format(
                    epoch_str, step,
                    melt.parse_results(eval_loss, eval_names_)))

                # if deal_eval_results_fn is not None:
                #   eval_stop = deal_eval_results_fn(eval_results)

                assert len(eval_loss) > 0
                if eval_stop is True:
                    stop = True
                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                melt.set_global('eval_loss',
                                melt.parse_results(eval_loss, eval_names_))

        elif interval_steps != eval_interval_steps:
            #print()
            pass

    metric_evaluate = False

    # if metric_eval_fn is not None \
    #   and (is_start \
    #     or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
    #          or (metric_eval_interval_steps \
    #              and step % metric_eval_interval_steps == 0)):
    #  metric_evaluate = True

    if metric_eval_fn is not None \
      and ((is_start or metric_eval_interval_steps \
           and step % metric_eval_interval_steps == 0) or model_path):
        metric_evaluate = True

    #if (is_start or step == 0) and (not 'EVFIRST' in os.environ):
    if ((step == 0) and
        (not 'EVFIRST' in os.environ)) or ('QUICK' in os.environ) or (
            'EVFIRST' in os.environ and os.environ['EVFIRST'] == '0'):
        metric_evaluate = False

    if metric_evaluate:
        # TODO better
        if not model_path or 'model_path' not in inspect.getargspec(
                metric_eval_fn).args:
            l = metric_eval_fn()
            if len(l) == 2:
                evaluate_results, evaluate_names = l
                evaluate_summaries = None
            else:
                evaluate_results, evaluate_names, evaluate_summaries = l
        else:
            try:
                l = metric_eval_fn(model_path=model_path)
                if len(l) == 2:
                    evaluate_results, evaluate_names = l
                    evaluate_summaries = None
                else:
                    evaluate_results, evaluate_names, evaluate_summaries = l
            except Exception:
                logging.info('Do nothing for metric eval fn with exception:\n',
                             traceback.format_exc())

        logging.info2('{} valid_step:{} {}:{}'.format(
            epoch_str, step,
            'valid_metrics' if model_path is None else 'epoch_valid_metrics',
            melt.parse_results(evaluate_results, evaluate_names)))

        if learning_rate is not None and (learning_rate_patience
                                          and learning_rate_patience > 0):
            assert learning_rate_decay_factor > 0 and learning_rate_decay_factor < 1
            valid_loss = evaluate_results[0]
            if not hasattr(train_once, 'min_valid_loss'):
                train_once.min_valid_loss = valid_loss
                train_once.decay_steps = []
                train_once.patience = 0
            else:
                if valid_loss < train_once.min_valid_loss:
                    train_once.min_valid_loss = valid_loss
                    train_once.patience = 0
                else:
                    train_once.patience += 1
                    logging.info2('{} valid_step:{} patience:{}'.format(
                        epoch_str, step, train_once.patience))

            if learning_rate_patience and train_once.patience >= learning_rate_patience:
                lr_op = ops[1]
                lr = sess.run(lr_op) * learning_rate_decay_factor
                train_once.decay_steps.append(step)
                logging.info2(
                    '{} valid_step:{} learning_rate_decay by *{}, learning_rate_decay_steps={}'
                    .format(epoch_str, step, learning_rate_decay_factor,
                            ','.join(map(str, train_once.decay_steps))))
                sess.run(tf.assign(lr_op, tf.constant(lr, dtype=tf.float32)))
                train_once.patience = 0
                train_once.min_valid_loss = valid_loss

    if ops is not None:
        #if deal_results_fn is None and names is not None:
        #  deal_results_fn = lambda x: melt.print_results(x, names)

        feed_dict = {} if gen_feed_dict_fn is None else gen_feed_dict_fn()
        # NOTICE ops[2] should be scalar otherwise wrong!! loss should be scalar
        #print('---------------ops', ops)
        if eval_ops is not None or not log_dir or not hasattr(
                train_once, 'summary_op') or train_once.summary_op is None:
            results = sess.run(ops, feed_dict=feed_dict)
        else:
            #try:
            results = sess.run(ops + [train_once.summary_op],
                               feed_dict=feed_dict)
            summary_str = results[-1]
            results = results[:-1]
            # except Exception:
            #   logging.info('sess.run(ops + [train_once.summary_op], feed_dict=feed_dict) fail')
            #   results = sess.run(ops, feed_dict=feed_dict)

        #print('------------results', results)
        # #--------trace debug
        # if step == 210:
        #   run_metadata = tf.RunMetadata()
        #   results = sess.run(
        #         ops,
        #         feed_dict=feed_dict,
        #         options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
        #         run_metadata=run_metadata)
        #   from tensorflow.python.client import timeline
        #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)

        #   trace_file = open('timeline.ctf.json', 'w')
        #   trace_file.write(trace.generate_chrome_trace_format())

        # results[0] is assumed to be train_op, results[1] to be learning_rate
        learning_rate = results[1]
        results = results[2:]

        #@TODO should support avg loss and other avg evaluations like test..
        if print_avg_loss:
            if not hasattr(train_once, 'avg_loss'):
                train_once.avg_loss = AvgScore()
                if interval_steps != eval_interval_steps:
                    train_once.avg_loss2 = AvgScore()
            #assume results[0] as train_op return, results[1] as loss
            loss = gezi.get_singles(results)
            train_once.avg_loss.add(loss)
            if interval_steps != eval_interval_steps:
                train_once.avg_loss2.add(loss)

        steps_per_second = None
        instances_per_second = None
        hours_per_epoch = None
        #step += 1
        if is_start or interval_steps and step % interval_steps == 0:
            train_average_loss = train_once.avg_loss.avg_score()
            if print_time:
                duration = timer.elapsed()
                duration_str = 'duration:{:.3f} '.format(duration)
                melt.set_global('duration', '%.3f' % duration)
                info.write(duration_str)
                elapsed = train_once.timer.elapsed()
                steps_per_second = interval_steps / elapsed
                batch_size = melt.batch_size()
                num_gpus = melt.num_gpus()
                instances_per_second = interval_steps * batch_size / elapsed
                gpu_info = '' if num_gpus <= 1 else ' gpus:[{}]'.format(
                    num_gpus)
                if num_steps_per_epoch is None:
                    epoch_time_info = ''
                else:
                    hours_per_epoch = num_steps_per_epoch / interval_steps * elapsed / 3600
                    epoch_time_info = ' 1epoch:[{:.2f}h]'.format(
                        hours_per_epoch)
                info.write(
                    'elapsed:[{:.3f}] batch_size:[{}]{} batches/s:[{:.2f}] insts/s:[{:.2f}] {} lr:[{:.8f}]'
                    .format(elapsed, batch_size, gpu_info, steps_per_second,
                            instances_per_second, epoch_time_info,
                            learning_rate))

            if print_avg_loss:
                #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
                names_ = melt.adjust_names(train_average_loss, names)
                #info.write('train_avg_metric:{} '.format(melt.parse_results(train_average_loss, names_)))
                info.write(' train:{} '.format(
                    melt.parse_results(train_average_loss, names_)))
                #info.write('train_avg_loss: {} '.format(train_average_loss))

            #print(gezi.now_time(), epoch_str, 'train_step:%d'%step, info.getvalue(), end=' ')
            logging.info2('{} {} {}'.format(epoch_str, 'step:%d' % step,
                                            info.getvalue()))

            if deal_results_fn is not None:
                stop = deal_results_fn(results)

    summary_strs = gezi.to_list(summary_str)
    if metric_evaluate:
        if evaluate_summaries is not None:
            summary_strs += evaluate_summaries

    if step > 1:
        if is_eval_step:
            # deal with summary
            if log_dir:
                # if not hasattr(train_once, 'summary_op'):
                #   melt.print_summary_ops()
                #   if summary_excls is None:
                #     train_once.summary_op = tf.summary.merge_all()
                #   else:
                #     summary_ops = []
                #     for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
                #       for summary_excl in summary_excls:
                #         if not summary_excl in op.name:
                #           summary_ops.append(op)
                #     print('filtered summary_ops:')
                #     for op in summary_ops:
                #       print(op)
                #     train_once.summary_op = tf.summary.merge(summary_ops)

                #   print('-------------summary_op', train_once.summary_op)

                #   #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
                #   train_once.summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

                #   tf.contrib.tensorboard.plugins.projector.visualize_embeddings(train_once.summary_writer, projector_config)

                summary = tf.Summary()
                # #so the strategy is on eval_interval_steps, if has eval dataset, then tensorboard evluate on eval dataset
                # #if not have eval dataset, will evaluate on trainset, but if has eval dataset we will also monitor train loss
                # assert train_once.summary_train_op is None
                # if train_once.summary_train_op is not None:
                #   summary_str = sess.run(train_once.summary_train_op, feed_dict=feed_dict)
                #   train_once.summary_writer.add_summary(summary_str, step)

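                # Without separate eval ops, only the per-batch train summary strings are written;
                # with eval ops, averaged eval metrics are also added (suffix 'eval' unless
                # eval_names already carry their own prefix).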
                if eval_ops is None:
                    # #get train loss, for every batch train
                    # if train_once.summary_op is not None:
                    #   #timer2 = gezi.Timer('sess run')
                    #   try:
                    #     # TODO FIXME so this means one more train batch step without adding to global step counter ?! so should move it earlier
                    #     summary_str = sess.run(train_once.summary_op, feed_dict=feed_dict)
                    #   except Exception:
                    #     if not hasattr(train_once, 'num_summary_errors'):
                    #       logging.warning('summary_str = sess.run(train_once.summary_op, feed_dict=feed_dict) fail')
                    #       train_once.num_summary_errors = 1
                    #       logging.warning(traceback.format_exc())
                    #     summary_str = ''
                    #   # #timer2.print()
                    # train_once.summary_op may never have been created (its setup above is
                    # commented out), so use getattr to avoid an AttributeError.
                    if getattr(train_once, 'summary_op', None) is not None:
                        for summary_str in summary_strs:
                            train_once.summary_writer.add_summary(
                                summary_str, step)
                else:
                    # #get eval loss for every batch eval, then add train loss for eval step average loss
                    # try:
                    #   summary_str = sess.run(train_once.summary_op, feed_dict=eval_feed_dict) if train_once.summary_op is not None else ''
                    # except Exception:
                    #   if not hasattr(train_once, 'num_summary_errors'):
                    #     logging.warning('summary_str = sess.run(train_once.summary_op, feed_dict=eval_feed_dict) fail')
                    #     train_once.num_summary_errors = 1
                    #     logging.warning(traceback.format_exc())
                    #   summary_str = ''
                    # All single-value results are added to the summary here rather than via tf.scalar_summary ops.
                    #summary.ParseFromString(summary_str)
                    for summary_str in summary_strs:
                        train_once.summary_writer.add_summary(
                            summary_str, step)
                    suffix = 'eval' if not eval_names else ''
                    melt.add_summarys(summary,
                                      eval_results,
                                      eval_names_,
                                      suffix=suffix)

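                # Training-side scalars: interval-averaged losses plus learning rate,
                # batch size, epoch and throughput numbers.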
                if ops is not None:
                    melt.add_summarys(summary,
                                      train_average_loss,
                                      names_,
                                      suffix='train_avg')
                    ## NOTE: the optimizer setup may already add a learning_rate summary as well
                    melt.add_summary(summary, learning_rate, 'learning_rate')
                    melt.add_summary(summary, melt.batch_size(), 'batch_size')
                    melt.add_summary(summary, melt.epoch(), 'epoch')
                    # steps_per_second / instances_per_second / hours_per_epoch are computed in
                    # the print_time timing block above; guard so a disabled timer does not
                    # raise a NameError here.
                    if print_time and steps_per_second:
                        melt.add_summary(summary, steps_per_second,
                                         'steps_per_second')
                    if print_time and instances_per_second:
                        melt.add_summary(summary, instances_per_second,
                                         'instances_per_second')
                    if print_time and num_steps_per_epoch and hours_per_epoch:
                        melt.add_summary(summary, hours_per_epoch,
                                         'hours_per_epoch')

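                # Metric-evaluation results go under 'step/valid', or under 'epoch/valid'
                # with a per-epoch step counter when model_path is set.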
                if metric_evaluate:
                    #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
                    prefix = 'step/valid'
                    if model_path:
                        prefix = 'epoch/valid'
                        if not hasattr(train_once, 'epoch_step'):
                            train_once.epoch_step = 1
                        else:
                            train_once.epoch_step += 1
                        step = train_once.epoch_step

                    melt.add_summarys(summary,
                                      evaluate_results,
                                      evaluate_names,
                                      prefix=prefix)

                train_once.summary_writer.add_summary(summary, step)
                train_once.summary_writer.flush()

                #timer_.print()

            # if print_time:
            #   full_duration = train_once.eval_timer.elapsed()
            #   if metric_evaluate:
            #     metric_full_duration = train_once.metric_eval_timer.elapsed()
            #   full_duration_str = 'elapsed:{:.3f} '.format(full_duration)
            #   #info.write('duration:{:.3f} '.format(timer.elapsed()))
            #   duration = timer.elapsed()
            #   info.write('duration:{:.3f} '.format(duration))
            #   info.write(full_duration_str)
            #   info.write('eval_time_ratio:{:.3f} '.format(duration/full_duration))
            #   if metric_evaluate:
            #     info.write('metric_time_ratio:{:.3f} '.format(duration/metric_full_duration))
            # #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, info.getvalue())
            # logging.info2('{} {} {}'.format(epoch_str, 'eval_step: %d'%step, info.getvalue()))
            return stop
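        # Not a regular eval step, but metric evaluation still ran (e.g. on
        # metric_eval_interval_steps): write its summaries directly.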
        elif metric_evaluate:
            summary = tf.Summary()
            for summary_str in summary_strs:
                train_once.summary_writer.add_summary(summary_str, step)
            #summary.ParseFromString(evaluate_summaries)
            summary_writer = train_once.summary_writer
            prefix = 'step/valid'
            if model_path:
                prefix = 'epoch/valid'
                if not hasattr(train_once, 'epoch_step'):
                    ## TODO.. restart will get 1 again..
                    #epoch_step = tf.Variable(0, trainable=False, name='epoch_step')
                    #epoch_step += 1
                    #train_once.epoch_step = sess.run(epoch_step)
                    valid_interval_epochs = 1.
                    try:
                        valid_interval_epochs = FLAGS.valid_interval_epochs
                    except Exception:
                        pass
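                    # Recover the epoch-step counter from the current epoch, e.g. epoch 3.0 with
                    # valid_interval_epochs 0.5 gives int(30 / 5) = 6; the *10/int dance avoids
                    # float rounding issues for fractional validation intervals.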
                    train_once.epoch_step = 1 if melt.epoch() <= 1 else int(
                        int(melt.epoch() * 10) /
                        int(valid_interval_epochs * 10))
                    logging.info('train_once epoch start step is %s' %
                                 train_once.epoch_step)
                else:
                    #epoch_step += 1
                    train_once.epoch_step += 1
                step = train_once.epoch_step
            #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
            melt.add_summarys(summary,
                              evaluate_results,
                              evaluate_names,
                              prefix=prefix)
            summary_writer.add_summary(summary, step)
            summary_writer.flush()