Code example #1
File: train.py  Project: tangqiqi123/hasky
def train_flow(ops,
               names=None,
               gen_feed_dict_fn=None,
               deal_results_fn=melt.print_results,
               eval_ops=None,
               eval_names=None,
               gen_eval_feed_dict_fn=None,
               deal_eval_results_fn=melt.print_results,
               optimizer=None,
               learning_rate=0.1,
               num_steps_per_epoch=None,
               model_dir=None,
               metric_eval_fn=None,
               debug=False,
               summary_excls=None,
               init_fn=None,
               sess=None):

    if sess is None:
        sess = melt.get_session()
    if debug:
        sess = tf_debug.LocalCLIDebugWrapperSession(sess)

    logging.info('learning_rate:{}'.format(FLAGS.learning_rate))
    #batch size is not defined here for now, but in app code like input_app.py
    melt.set_global('batch_size', FLAGS.batch_size)
    melt.set_global('num_gpus', max(FLAGS.num_gpus, 1))

    #NOTICE since melt/__init__.py does `from melt.flow import *`, you cannot
    #use melt.flow.train.train_flow, but you can always use
    #`from melt.flow.train.train_flow import train_flow`

    if optimizer is None:
        optimizer = FLAGS.optimizer
    # Set up the training ops.
    #notice name='' only works in tf >= 0.11; 0.10 will always add an OptimizeLoss scope
    #the difference is that 0.10 uses variable_op_scope while 0.11 uses variable_scope
    optimize_scope = None if FLAGS.optimize_has_scope else ''
    #or judge by FLAGS.num_gpus
    if not isinstance(ops[0], (list, tuple)):
        learning_rate, learning_rate_decay_fn = gen_learning_rate()
        train_op = tf.contrib.layers.optimize_loss(
            loss=ops[0],
            global_step=None,
            learning_rate=learning_rate,
            optimizer=melt.util.get_optimizer(optimizer),
            clip_gradients=FLAGS.clip_gradients,
            learning_rate_decay_fn=learning_rate_decay_fn,
            name=optimize_scope)
    else:
        #---as in the cifar10 example, put all but the tower losses on cpu; the wiki says
        #that will be faster, but here I find it is faster without pinning to cpu..
        #https://github.com/tensorflow/tensorflow/issues/4881
        #"I've noticed the same thing on cirrascale GPU machines - putting parameters on gpu:0
        #and using gpu->gpu transfer was a bit faster. I suppose this depends on particular
        #details of hardware -- if you don't have p2p connectivity between your video cards
        #then keeping parameters on CPU:0 gives faster training."
        #err, but my pc has no p2p (PHB connection per nvidia-smi topo -m) and pinning to
        #cpu still hurts.. maybe cpu should not be used here
        #with tf.device('/cpu:0'):
        learning_rate, learning_rate_decay_fn = gen_learning_rate()
        train_op = melt.layers.optimize_loss(
            losses=ops[0],
            num_gpus=FLAGS.num_gpus,
            global_step=None,
            learning_rate=learning_rate,
            optimizer=melt.util.get_optimizer(optimizer),
            clip_gradients=FLAGS.clip_gradients,
            learning_rate_decay_fn=learning_rate_decay_fn,
            name=optimize_scope)
        #set the last tower loss as loss in ops
        ops[0] = ops[0][-1]

    ops.insert(0, train_op)

    #-----------post deal
    save_interval_seconds = FLAGS.save_interval_seconds if FLAGS.save_interval_seconds > 0 \
       else FLAGS.save_interval_hours * 3600

    interval_steps = FLAGS.interval_steps
    eval_interval_steps = FLAGS.eval_interval_steps
    metric_eval_interval_steps = FLAGS.metric_eval_interval_steps
    save_model = FLAGS.save_model
    save_interval_steps = FLAGS.save_interval_steps
    if not save_interval_steps:
        save_interval_steps = 1000000000000

    if FLAGS.work_mode == 'train':
        eval_ops = None
        metric_eval_fn = None
        logging.info('running train only mode')
    elif FLAGS.work_mode == 'train_metric':
        eval_ops = None
        assert metric_eval_fn is not None, 'set metric_eval to 1'
        logging.info('running train+metric mode')
    elif FLAGS.work_mode == 'train_valid':
        metric_eval_fn = None
        logging.info('running train+valid mode')
    elif FLAGS.work_mode == 'test':
        ops = None
        logging.info('running test only mode')
        interval_steps = 0
        eval_interval_steps = 1
        metric_eval_interval_steps /= FLAGS.eval_interval_steps
        save_model = False

    return melt.flow.train_flow(
        ops,
        names=names,
        gen_feed_dict_fn=gen_feed_dict_fn,
        deal_results_fn=deal_results_fn,
        eval_ops=eval_ops,
        eval_names=eval_names,
        gen_eval_feed_dict_fn=gen_eval_feed_dict_fn,
        deal_eval_results_fn=deal_eval_results_fn,
        interval_steps=interval_steps,
        eval_interval_steps=eval_interval_steps,
        num_epochs=FLAGS.num_epochs,
        num_steps=FLAGS.num_steps,
        save_interval_seconds=save_interval_seconds,
        save_interval_steps=save_interval_steps,
        save_model=save_model,
        save_interval_epochs=FLAGS.save_interval_epochs,
        #optimizer=optimizer,
        optimizer=None,  #must be None here since we have already chosen the optimizer above
        learning_rate=learning_rate,
        num_steps_per_epoch=num_steps_per_epoch,
        max_models_keep=FLAGS.max_models_keep,
        model_dir=model_dir,
        restore_from_latest=FLAGS.restore_from_latest,
        metric_eval_fn=metric_eval_fn,
        metric_eval_interval_steps=metric_eval_interval_steps,
        no_log=FLAGS.no_log,
        summary_excls=summary_excls,
        init_fn=init_fn,
        sess=sess)
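
A minimal invocation sketch (not from the project): it assumes the melt/FLAGS globals used above are already configured, and build_graph is a hypothetical helper returning the scalar training loss.

loss = build_graph()  # hypothetical helper: returns the scalar loss tensor
# ops[0] must be the loss; train_flow itself prepends the train_op
train_flow([loss], names=['loss'], model_dir='./model')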
Code example #2
File: train.py  Project: moonlight1776/competitions
def train(Dataset,
          model,
          loss_fn,
          evaluate_fn=None,
          inference_fn=None,
          eval_fn=None,
          write_valid=True,
          valid_names=None,
          infer_names=None,
          infer_debug_names=None,
          valid_write_fn=None,
          infer_write_fn=None,
          valid_suffix='.valid',
          infer_suffix='.infer',
          write_streaming=False,
          optimizer=None,
          param_groups=None,
          init_fn=None,
          dataset=None,
          valid_dataset=None,
          test_dataset=None,
          sep=','):
    if Dataset is None:
        assert dataset
    if FLAGS.torch:
        # https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
        model.to(device)

    input_ = FLAGS.train_input
    inputs = gezi.list_files(input_)
    inputs.sort()

    all_inputs = inputs

    #batch_size = FLAGS.batch_size
    batch_size = melt.batch_size()

    num_gpus = melt.num_gpus()

    #batch_size = max(batch_size, 1)
    #batch_size_ = batch_size if not FLAGS.batch_sizes else int(FLAGS.batch_sizes.split(',')[-1])
    batch_size_ = batch_size

    if FLAGS.fold is not None:
        inputs = [
            x for x in inputs if not x.endswith('%d.record' % FLAGS.fold)
            and not x.endswith('%d.tfrecord' % FLAGS.fold)
        ]
        # if FLAGS.valid_input:
        #   inputs += [x for x in gezi.list_files(FLAGS.valid_input) if not x.endswith('%d.record' % FLAGS.fold)]
    logging.info('inputs', len(inputs), inputs[:100])
    num_folds = FLAGS.num_folds or len(inputs) + 1

    train_dataset_ = dataset or Dataset('train')
    train_dataset = train_dataset_.make_batch(batch_size, inputs)
    num_examples = train_dataset_.num_examples_per_epoch('train')
    num_all_examples = num_examples

    valid_inputs = None
    if FLAGS.valid_input:
        valid_inputs = gezi.list_files(FLAGS.valid_input)
    else:
        if FLAGS.fold is not None:
            #valid_inputs = [x for x in all_inputs if x not in inputs]
            if not FLAGS.test_aug:
                valid_inputs = [
                    x for x in all_inputs if not 'aug' in x and x not in inputs
                ]
            else:
                valid_inputs = [
                    x for x in all_inputs if 'aug' in x and x not in inputs
                ]

    logging.info('valid_inputs', valid_inputs)

    if valid_inputs:
        valid_dataset_ = valid_dataset or Dataset('valid')
        valid_dataset = valid_dataset_.make_batch(batch_size_, valid_inputs)
        valid_dataset2 = valid_dataset_.make_batch(batch_size_,
                                                   valid_inputs,
                                                   repeat=True)
    else:
        valid_dataset = None
        valid_dataset2 = None

    if num_examples:
        if FLAGS.fold is not None:
            num_examples = int(num_examples * (num_folds - 1) / num_folds)
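        # -(-a // b) is ceiling division without math.ceil, e.g. -(-10 // 3) == 4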
        num_steps_per_epoch = -(-num_examples // batch_size)
    else:
        num_steps_per_epoch = None
    logging.info('num_train_examples:', num_examples)

    num_valid_examples = None
    if FLAGS.valid_input:
        num_valid_examples = valid_dataset_.num_examples_per_epoch('valid')
        num_valid_steps_per_epoch = -(
            -num_valid_examples // batch_size_) if num_valid_examples else None
    else:
        if FLAGS.fold is not None:
            if num_examples:
                num_valid_examples = int(num_all_examples * (1 / num_folds))
                num_valid_steps_per_epoch = -(-num_valid_examples //
                                              batch_size_)
            else:
                num_valid_steps_per_epoch = None
    logging.info('num_valid_examples:', num_valid_examples)

    if FLAGS.test_input:
        test_inputs = gezi.list_files(FLAGS.test_input)
        #test_inputs = [x for x in test_inputs if not 'aug' in x]
        logging.info('test_inputs', test_inputs)
    else:
        test_inputs = None

    num_test_examples = None
    if test_inputs:
        test_dataset_ = test_dataset or Dataset('test')
        test_dataset = test_dataset_.make_batch(batch_size_, test_inputs)
        num_test_examples = test_dataset_.num_examples_per_epoch('test')
        num_test_steps_per_epoch = -(
            -num_test_examples // batch_size_) if num_test_examples else None
    else:
        test_dataset = None
    logging.info('num_test_examples:', num_test_examples)

    summary = tf.contrib.summary
    # writer = summary.create_file_writer(FLAGS.log_dir + '/epoch')
    # writer_train = summary.create_file_writer(FLAGS.log_dir + '/train')
    # writer_valid = summary.create_file_writer(FLAGS.log_dir + '/valid')
    writer = summary.create_file_writer(FLAGS.log_dir)
    writer_train = summary.create_file_writer(FLAGS.log_dir)
    writer_valid = summary.create_file_writer(FLAGS.log_dir)
    global_step = tf.train.get_or_create_global_step()

    learning_rate = tfe.Variable(FLAGS.learning_rate, name="learning_rate")

    tf.add_to_collection('learning_rate', learning_rate)

    learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
    try:
        learning_rate_weights = tf.get_collection('learning_rate_weights')[-1]
    except Exception:
        learning_rate_weights = None

    # ckpt dir saves one model per epoch
    ckpt_dir = os.path.join(FLAGS.model_dir, 'ckpt')
    os.system('mkdir -p %s' % ckpt_dir)
    # HACK the ckpt2 dir actually saves mini-epoch checkpoints, e.g. when you set save_interval_epochs=0.1; useful when training on a large dataset
    ckpt_dir2 = os.path.join(FLAGS.model_dir, 'ckpt2')
    os.system('mkdir -p %s' % ckpt_dir2)

    #TODO FIXME for now I just changed the tf code so it does not keep only the latest 5 checkpoints by default
    # refer to https://github.com/tensorflow/tensorflow/issues/22036
    # manager = tf.contrib.checkpoint.CheckpointManager(
    #     checkpoint, directory=ckpt_dir, max_to_keep=5)
    # latest_checkpoint = manager.latest_checkpoint

    latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
    if latest_checkpoint:
        logging.info('Latest checkpoint:', latest_checkpoint)
    else:
        latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir2)
        logging.info('Latest checkpoint:', latest_checkpoint)

    if os.path.exists(FLAGS.model_dir + '.index'):
        latest_checkpoint = FLAGS.model_dir

    if 'test' in FLAGS.work_mode or 'valid' in FLAGS.work_mode:
        #assert not os.path.isdir(FLAGS.model_dir), FLAGS.model_dir
        latest_checkpoint = FLAGS.model_dir
        #assert os.path.exists(latest_checkpoint) and os.path.isfile(latest_checkpoint)

    checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt')
    checkpoint_prefix2 = os.path.join(ckpt_dir2, 'ckpt')

    if not FLAGS.torch:
        try:
            optimizer = optimizer or melt.get_optimizer(
                FLAGS.optimizer)(learning_rate)
        except Exception:
            logging.warning(
                f'Failed to use {FLAGS.optimizer}, using adam instead')
            optimizer = melt.get_optimizer('adam')(learning_rate)

        # TODO...
        if learning_rate_weights is None:
            checkpoint = tf.train.Checkpoint(
                learning_rate=learning_rate,
                learning_rate_weight=learning_rate_weight,
                model=model,
                optimizer=optimizer,
                global_step=global_step)
        else:
            checkpoint = tf.train.Checkpoint(
                learning_rate=learning_rate,
                learning_rate_weight=learning_rate_weight,
                learning_rate_weights=learning_rate_weights,
                model=model,
                optimizer=optimizer,
                global_step=global_step)

        checkpoint.restore(latest_checkpoint)
        checkpoint2 = copy.deepcopy(checkpoint)

        start_epoch = int(
            latest_checkpoint.split('-')
            [-1]) if latest_checkpoint and 'ckpt' in latest_checkpoint else 0
    else:
        # TODO torch with learning rate adjust
        is_dynamic_opt = False  # ensure defined even when an optimizer is passed in
        if optimizer is None:
            import lele
            is_dynamic_opt = True
            if FLAGS.optimizer == 'noam':
                optimizer = lele.training.optimizers.NoamOpt(
                    128, 2, 4000, torch.optim.Adamax(model.parameters(), lr=0))
            elif FLAGS.optimizer == 'bert':
                num_train_steps = int(
                    num_steps_per_epoch *
                    (FLAGS.num_decay_epochs or FLAGS.num_epochs))
                num_warmup_steps = FLAGS.warmup_steps or int(
                    num_train_steps * FLAGS.warmup_proportion)
                logging.info('num_train_steps', num_train_steps,
                             'num_warmup_steps', num_warmup_steps,
                             'warmup_proportion', FLAGS.warmup_proportion)
                optimizer = lele.training.optimizers.BertOpt(
                    FLAGS.learning_rate, FLAGS.min_learning_rate,
                    num_train_steps, num_warmup_steps,
                    torch.optim.Adamax(model.parameters(), lr=0))
            else:
                is_dynamic_opt = False
                optimizer = torch.optim.Adamax(
                    param_groups if param_groups else model.parameters(),
                    lr=FLAGS.learning_rate)

        start_epoch = 0
        latest_path = latest_checkpoint + '.pyt' if latest_checkpoint else os.path.join(
            FLAGS.model_dir, 'latest.pyt')
        if not os.path.exists(latest_path):
            latest_path = os.path.join(FLAGS.model_dir, 'latest.pyt')
        if os.path.exists(latest_path):
            logging.info('loading torch model from', latest_path)
            checkpoint = torch.load(latest_path)
            if not FLAGS.torch_finetune:
                start_epoch = checkpoint['epoch']
                step = checkpoint['step']
                global_step.assign(step + 1)
            load_torch_model(model, latest_path)
            if FLAGS.torch_load_optimizer:
                optimizer.load_state_dict(checkpoint['optimizer'])

        # TODO restarting this way cannot change the learning rate..
        if learning_rate_weights is None:
            checkpoint = tf.train.Checkpoint(
                learning_rate=learning_rate,
                learning_rate_weight=learning_rate_weight,
                global_step=global_step)
        else:
            checkpoint = tf.train.Checkpoint(
                learning_rate=learning_rate,
                learning_rate_weight=learning_rate_weight,
                learning_rate_weights=learning_rate_weights,
                global_step=global_step)

        try:
            checkpoint.restore(latest_checkpoint)
            checkpoint2 = copy.deepcopy(checkpoint)
        except Exception:
            pass

    if FLAGS.torch and is_dynamic_opt:
        optimizer._step = global_step.numpy()

    #model.load_weights(os.path.join(ckpt_dir, 'ckpt-1'))
    #model.save('./weight3.hd5')
    logging.info('optimizer:', optimizer)

    if FLAGS.torch_lr:
        learning_rate.assign(optimizer.rate(1))
    if FLAGS.torch:
        learning_rate.assign(optimizer.param_groups[0]['lr'])
        logging.info('learning rate got from pytorch latest.pyt as',
                     learning_rate)

    learning_rate.assign(learning_rate * FLAGS.learning_rate_start_factor)
    if learning_rate_weights is not None:
        learning_rate_weights.assign(learning_rate_weights *
                                     FLAGS.learning_rate_start_factor)

    # TODO fractional epochs like 0.1 are currently not supported here
    num_epochs = FLAGS.num_epochs if FLAGS.num_epochs != 0 else 1024

    will_valid = valid_dataset and FLAGS.work_mode != 'test' and 'SHOW' not in os.environ and 'QUICK' not in os.environ
    if start_epoch == 0 and 'EVFIRST' not in os.environ and will_valid:
        will_valid = False

    if start_epoch > 0 and will_valid:
        will_valid = True

    if will_valid:
        logging.info('----------valid')
        if FLAGS.torch:
            model.eval()
        names = None
        if evaluate_fn is not None:
            vals, names = evaluate_fn(model, valid_dataset,
                                      tf.train.latest_checkpoint(ckpt_dir),
                                      num_valid_steps_per_epoch)
        elif eval_fn:
            model_path = None if not write_valid else latest_checkpoint
            if valid_names is not None:
                names = valid_names
            elif infer_names:
                names = [infer_names[0]] + [
                    x + '_y' for x in infer_names[1:]
                ] + infer_names[1:]
            else:
                names = None

            logging.info('model_path:', model_path, 'model_dir:',
                         FLAGS.model_dir)
            vals, names = evaluate(model,
                                   valid_dataset,
                                   eval_fn,
                                   model_path,
                                   names,
                                   valid_write_fn,
                                   write_streaming,
                                   num_valid_steps_per_epoch,
                                   suffix=valid_suffix,
                                   sep=sep)
        if names:
            logging.info2(
                'epoch:%d/%d' % (start_epoch, num_epochs),
                ['%s:%.5f' % (name, val) for name, val in zip(names, vals)])

    if FLAGS.work_mode == 'valid':
        exit(0)

    if 'test' in FLAGS.work_mode:
        logging.info('--------test/inference')
        if test_dataset:
            if FLAGS.torch:
                model.eval()
            if inference_fn is None:
                # model_path = FLAGS.model_dir + '.pyt' if not latest_checkpoint else latest_checkpoint
                # logging.info('model_path', model_path)
                assert latest_checkpoint
                inference(model,
                          test_dataset,
                          latest_checkpoint,
                          infer_names,
                          infer_debug_names,
                          infer_write_fn,
                          write_streaming,
                          num_test_steps_per_epoch,
                          suffix=infer_suffix)
            else:
                inference_fn(model, test_dataset,
                             tf.train.latest_checkpoint(ckpt_dir),
                             num_test_steps_per_epoch)
        exit(0)

    if 'SHOW' in os.environ:
        num_epochs = start_epoch + 1

    class PytObj(object):
        def __init__(self, x):
            self.x = x

        def numpy(self):
            return self.x

    class PytMean(object):
        def __init__(self):
            self._val = 0.
            self.count = 0

            self.is_call = True

        def clear(self):
            self._val = 0
            self.count = 0

        def __call__(self, val):
            if not self.is_call:
                self.clear()
                self.is_call = True
            self._val += val.item()
            self.count += 1

        def result(self):
            if self.is_call:
                self.is_call = False
            if not self.count:
                val = 0
            else:
                val = self._val / self.count
            # TODO just for compat with tf ..
            return PytObj(val)
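
    # PytMean mirrors the tfe.metrics.Mean interface (call to accumulate,
    # .result().numpy() to read), so the logging below works for both backends,
    # e.g. mean = Mean(); mean(loss); mean.result().numpy()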

    Mean = tfe.metrics.Mean if not FLAGS.torch else PytMean
    timer = gezi.Timer()
    num_insts = 0

    if FLAGS.learning_rate_decay_factor > 0:
        #assert FLAGS.learning_rate_values is None, 'use exponential_decay or piecewise_constant?'
        #NOTICE if you finetune or do other things which might change batch_size, you'd better directly set num_steps_per_decay,
        #since global_step / decay_steps will no longer map to the correct epoch once num_steps_per_epoch has changed;
        #so if you change batch size you have to reset global step to a fixed step
        assert FLAGS.num_steps_per_decay or (
            FLAGS.num_epochs_per_decay and num_steps_per_epoch
        ), 'must set num_steps_per_decay or (num_epochs_per_decay and num_steps_per_epoch)'
        decay_steps = FLAGS.num_steps_per_decay or int(
            num_steps_per_epoch * FLAGS.num_epochs_per_decay)
        decay_start_step = FLAGS.decay_start_step or int(
            num_steps_per_epoch * FLAGS.decay_start_epoch)
        # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
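        # e.g. learning_rate=0.1, decay_rate=0.5, decay_steps=1000: at step 3000 the decayed lr is 0.1 * 0.5**3 = 0.0125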
        logging.info(
            'learning_rate_decay_factor:{} decay_epochs:{} decay_steps:{} decay_start_epoch:{} decay_start_step:{}'
            .format(FLAGS.learning_rate_decay_factor,
                    FLAGS.num_epochs_per_decay, decay_steps,
                    FLAGS.decay_start_epoch, decay_start_step))

    for epoch in range(start_epoch, num_epochs):
        melt.set_global('epoch', '%.4f' % (epoch))

        if FLAGS.torch:
            model.train()

        epoch_loss_avg = Mean()
        epoch_valid_loss_avg = Mean()

        #for i, (x, y) in tqdm(enumerate(train_dataset), total=num_steps_per_epoch, ascii=True):
        for i, (x, y) in enumerate(train_dataset):
            if FLAGS.torch:
                x, y = to_torch(x, y)
                if is_dynamic_opt:
                    learning_rate.assign(optimizer.rate())

            #print(x, y)

            if not FLAGS.torch:
                loss, grads = melt.eager.grad(model, x, y, loss_fn)
                grads, _ = tf.clip_by_global_norm(grads, FLAGS.clip_gradients)
                optimizer.apply_gradients(zip(grads, model.variables))
            else:
                optimizer.zero_grad()
                if 'training' in inspect.getargspec(loss_fn).args:
                    loss = loss_fn(model, x, y, training=True)
                else:
                    loss = loss_fn(model, x, y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               FLAGS.clip_gradients)
                optimizer.step()

            global_step.assign_add(1)

            epoch_loss_avg(loss)  # add current batch loss

            if FLAGS.torch:
                del loss

            batch_size_ = (list(x.values())[0].shape[FLAGS.batch_size_dim]
                           if isinstance(x, dict) else
                           x.shape[FLAGS.batch_size_dim])
            num_insts += int(batch_size_)
            if global_step.numpy() % FLAGS.interval_steps == 0:
                #checkpoint.save(checkpoint_prefix)
                elapsed = timer.elapsed()
                steps_per_second = FLAGS.interval_steps / elapsed
                instances_per_second = num_insts / elapsed
                num_insts = 0

                if num_steps_per_epoch is None:
                    epoch_time_info = ''
                else:
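                    # elapsed seconds cover FLAGS.interval_steps batches; scale to a full epoch and convert to hours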
                    hours_per_epoch = num_steps_per_epoch / FLAGS.interval_steps * elapsed / 3600
                    epoch_time_info = '1epoch:[{:.2f}h]'.format(
                        hours_per_epoch)

                if valid_dataset2:
                    try:
                        x, y = next(iter(valid_dataset2))
                    except Exception:
                        # TODO FIXME the iterator may stop and need a restart.. hack here for my iterator, see projects/lm/dataset
                        x, y = next(iter(valid_dataset2))

                    if FLAGS.torch:
                        x, y = to_torch(x, y)
                        model.eval()
                    valid_loss = loss_fn(model, x, y)
                    epoch_valid_loss_avg(valid_loss)
                    if FLAGS.torch:
                        model.train()

                    logging.info2(
                        'epoch:%.3f/%d' %
                        ((epoch + i / num_steps_per_epoch), num_epochs),
                        'step:%d' % global_step.numpy(), 'elapsed:[%.3f]' %
                        elapsed, 'batch_size:[%d]' % batch_size_, 'gpus:[%d]' %
                        num_gpus, 'batches/s:[%.2f]' % steps_per_second,
                        'insts/s:[%d]' % instances_per_second, '%s' %
                        epoch_time_info, 'lr:[%.8f]' % learning_rate.numpy(),
                        'train_loss:[%.4f]' % epoch_loss_avg.result().numpy(),
                        'valid_loss:[%.4f]' %
                        epoch_valid_loss_avg.result().numpy())
                    if global_step.numpy() % FLAGS.eval_interval_steps == 0:
                        with writer_valid.as_default(
                        ), summary.always_record_summaries():
                            #summary.scalar('step/loss', epoch_valid_loss_avg.result().numpy())
                            summary.scalar(
                                'loss/eval',
                                epoch_valid_loss_avg.result().numpy())
                            writer_valid.flush()
                else:
                    logging.info2(
                        'epoch:%.3f/%d' %
                        ((epoch + i / num_steps_per_epoch), num_epochs),
                        'step:%d' % global_step.numpy(), 'elapsed:[%.3f]' %
                        elapsed, 'batch_size:[%d]' % batch_size_,
                        'gpus:[%d]' % num_gpus,
                        'batches/s:[%.2f]' % steps_per_second,
                        'insts/s:[%d]' % instances_per_second,
                        '%s' % epoch_time_info,
                        'lr:[%.8f]' % learning_rate.numpy(),
                        'train_loss:[%.4f]' % epoch_loss_avg.result().numpy())

                if global_step.numpy() % FLAGS.eval_interval_steps == 0:
                    with writer_train.as_default(
                    ), summary.always_record_summaries():
                        #summary.scalar('step/loss', epoch_loss_avg.result().numpy())
                        summary.scalar('loss/train_avg',
                                       epoch_loss_avg.result().numpy())
                        summary.scalar('learning_rate', learning_rate.numpy())
                        summary.scalar('batch_size', batch_size_)
                        summary.scalar('epoch', melt.epoch())
                        summary.scalar('steps_per_second', steps_per_second)
                        summary.scalar('instances_per_second',
                                       instances_per_second)
                        writer_train.flush()

                    if FLAGS.log_dir != FLAGS.model_dir:
                        assert FLAGS.log_dir
                        command = 'rsync -l -r -t %s/* %s' % (FLAGS.log_dir,
                                                              FLAGS.model_dir)
                        print(command, file=sys.stderr)
                        os.system(command)

            if valid_dataset and FLAGS.metric_eval_interval_steps and global_step.numpy(
            ) and global_step.numpy() % FLAGS.metric_eval_interval_steps == 0:
                if FLAGS.torch:
                    model.eval()
                vals, names = None, None
                if evaluate_fn is not None:
                    vals, names = evaluate_fn(model, valid_dataset, None,
                                              num_valid_steps_per_epoch)
                elif eval_fn:
                    if valid_names is not None:
                        names = valid_names
                    elif infer_names:
                        names = [infer_names[0]] + [
                            x + '_y' for x in infer_names[1:]
                        ] + infer_names[1:]
                    else:
                        names = None
                    vals, names = evaluate(model,
                                           valid_dataset,
                                           eval_fn,
                                           None,
                                           names,
                                           valid_write_fn,
                                           write_streaming,
                                           num_valid_steps_per_epoch,
                                           sep=sep)
                if vals and names:
                    with writer_valid.as_default(
                    ), summary.always_record_summaries():
                        for name, val in zip(names, vals):
                            summary.scalar(f'step/valid/{name}', val)
                        writer_valid.flush()

                if FLAGS.torch:
                    if not FLAGS.torch_lr:
                        # control learning rate by tensorflow learning rate
                        for param_group in optimizer.param_groups:
                            # important learning rate decay
                            param_group['lr'] = learning_rate.numpy()

                    model.train()

                if names and vals:
                    logging.info2(
                        'epoch:%.3f/%d' %
                        ((epoch + i / num_steps_per_epoch), num_epochs),
                        'valid_step:%d' % global_step.numpy(), 'valid_metrics',
                        [
                            '%s:%.5f' % (name, val)
                            for name, val in zip(names, vals)
                        ])

            # if i == 5:
            #   print(i, '---------------------save')
            #   print(len(model.trainable_variables))
            ## TODO FIXME saving weights this way seems not ok... values not the same as checkpoint save
            #   model.save_weights(os.path.join(ckpt_dir, 'weights'))
            #   checkpoint.save(checkpoint_prefix)
            #   exit(0)

            if global_step.numpy() % FLAGS.save_interval_steps == 0:
                if FLAGS.torch:
                    state = {
                        'epoch':
                        epoch,
                        'step':
                        global_step.numpy(),
                        'state_dict':
                        model.state_dict() if not hasattr(model, 'module') else
                        model.module.state_dict(),
                        'optimizer':
                        optimizer.state_dict(),
                    }
                    torch.save(state,
                               os.path.join(FLAGS.model_dir, 'latest.pyt'))

            # TODO FIXME why is it not ok when both checkpoint2 and checkpoint are used..
            if FLAGS.save_interval_epochs and FLAGS.save_interval_epochs < 1 and global_step.numpy(
            ) % int(num_steps_per_epoch * FLAGS.save_interval_epochs) == 0:
                #if FLAGS.save_interval_epochs and global_step.numpy() % int(num_steps_per_epoch * FLAGS.save_interval_epochs) == 0:
                checkpoint2.save(checkpoint_prefix2)
                if FLAGS.torch:
                    state = {
                        'epoch':
                        epoch,
                        'step':
                        global_step.numpy(),
                        'state_dict':
                        model.state_dict() if not hasattr(model, 'module') else
                        model.module.state_dict(),
                        'optimizer':
                        optimizer.state_dict(),
                    }
                    torch.save(state,
                               tf.train.latest_checkpoint(ckpt_dir2) + '.pyt')

            if FLAGS.learning_rate_decay_factor > 0:
                if global_step.numpy(
                ) >= decay_start_step and global_step.numpy(
                ) % decay_steps == 0:
                    lr = max(
                        learning_rate.numpy() *
                        FLAGS.learning_rate_decay_factor,
                        FLAGS.min_learning_rate)
                    if lr < learning_rate.numpy():
                        learning_rate.assign(lr)
                        if FLAGS.torch:
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = learning_rate.numpy()

            if epoch == start_epoch and i == 0:
                try:
                    if not FLAGS.torch:
                        logging.info(model.summary())
                except Exception:
                    traceback.print_exc()
                    logging.info(
                        'Failed to run model.summary(); maybe you have a layer defined in __init__ but not used in call'
                    )
                if 'SHOW' in os.environ:
                    exit(0)

        logging.info2(
            'epoch:%d/%d' % (epoch + 1, num_epochs),
            'step:%d' % global_step.numpy(), 'batch_size:[%d]' % batch_size,
            'gpus:[%d]' % num_gpus, 'lr:[%.8f]' % learning_rate.numpy(),
            'train_loss:[%.4f]' % epoch_loss_avg.result().numpy(),
            'valid_loss:[%.4f]' % epoch_valid_loss_avg.result().numpy())

        timer = gezi.Timer(
            f'save model to {checkpoint_prefix}-{checkpoint.save_counter.numpy() + 1}',
            False)
        checkpoint.save(checkpoint_prefix)
        if FLAGS.torch and FLAGS.save_interval_epochs == 1:
            state = {
                'epoch':
                epoch + 1,
                'step':
                global_step.numpy(),
                'state_dict':
                model.state_dict()
                if not hasattr(model, 'module') else model.module.state_dict(),
                'optimizer':
                optimizer.state_dict(),
            }
            torch.save(state, tf.train.latest_checkpoint(ckpt_dir) + '.pyt')

        timer.print_elapsed()

        if valid_dataset and (epoch + 1) % FLAGS.valid_interval_epochs == 0:
            if FLAGS.torch:
                model.eval()

            vals, names = None, None
            if evaluate_fn is not None:
                vals, names = evaluate_fn(model, valid_dataset,
                                          tf.train.latest_checkpoint(ckpt_dir),
                                          num_valid_steps_per_epoch)
            elif eval_fn:
                model_path = None if not write_valid else tf.train.latest_checkpoint(
                    ckpt_dir)
                if valid_names is not None:
                    names = valid_names
                elif infer_names:
                    names = [infer_names[0]] + [
                        x + '_y' for x in infer_names[1:]
                    ] + infer_names[1:]
                else:
                    names = None

                vals, names = evaluate(model,
                                       valid_dataset,
                                       eval_fn,
                                       model_path,
                                       names,
                                       valid_write_fn,
                                       write_streaming,
                                       num_valid_steps_per_epoch,
                                       suffix=valid_suffix,
                                       sep=sep)

            if vals and names:
                logging.info2('epoch:%d/%d' % (epoch + 1, num_epochs),
                              'step:%d' % global_step.numpy(),
                              'epoch_valid_metrics', [
                                  '%s:%.5f' % (name, val)
                                  for name, val in zip(names, vals)
                              ])

        with writer.as_default(), summary.always_record_summaries():
            temp = global_step.value()
            global_step.assign(epoch + 1)
            summary.scalar('epoch/train/loss', epoch_loss_avg.result().numpy())
            if valid_dataset:
                if FLAGS.torch:
                    model.eval()
                if vals and names:
                    for name, val in zip(names, vals):
                        summary.scalar(f'epoch/valid/{name}', val)
            writer.flush()
            global_step.assign(temp)

        if test_dataset and (epoch + 1) % FLAGS.inference_interval_epochs == 0:
            if FLAGS.torch:
                model.eval()
            if inference_fn is None:
                inference(model,
                          test_dataset,
                          tf.train.latest_checkpoint(ckpt_dir),
                          infer_names,
                          infer_debug_names,
                          infer_write_fn,
                          write_streaming,
                          num_test_steps_per_epoch,
                          suffix=infer_suffix,
                          sep=sep)
            else:
                inference_fn(model, test_dataset,
                             tf.train.latest_checkpoint(ckpt_dir),
                             num_test_steps_per_epoch)

    if FLAGS.log_dir != FLAGS.model_dir:
        assert FLAGS.log_dir
        command = 'rsync -l -r -t %s/* %s' % (FLAGS.log_dir, FLAGS.model_dir)
        print(command, file=sys.stderr)
        os.system(command)
        command = 'rm -rf %s/latest.pyt.*' % (FLAGS.model_dir)
        print(command, file=sys.stderr)
        os.system(command)
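
A hedged wiring sketch (not from the project): MyDataset, MyModel, my_loss_fn and my_eval_fn are hypothetical placeholders; the real Dataset contract (make_batch, num_examples_per_epoch) comes from the surrounding melt codebase, as used inside train() above.

model = MyModel()  # hypothetical model object (tf.keras Model or torch Module)
# the Dataset class must expose make_batch() and num_examples_per_epoch()
train(MyDataset, model, loss_fn=my_loss_fn, eval_fn=my_eval_fn)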
Code example #3
def train_once(
    sess,
    step,
    ops,
    names=None,
    gen_feed_dict_fn=None,
    deal_results_fn=None,
    interval_steps=100,
    eval_ops=None,
    eval_names=None,
    gen_eval_feed_dict_fn=None,
    deal_eval_results_fn=melt.print_results,
    valid_interval_steps=100,
    print_time=True,
    print_avg_loss=True,
    model_dir=None,
    log_dir=None,
    is_start=False,
    num_steps_per_epoch=None,
    metric_eval_fn=None,
    metric_eval_interval_steps=0,
    summary_excls=None,
    fixed_step=None,  # for epoch only, in case you change batch size
    eval_loops=1,
    learning_rate=None,
    learning_rate_patience=None,
    learning_rate_decay_factor=None,
    num_epochs=None,
    model_path=None,
    use_horovod=False,
):
    use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ  # detect horovod from the MPI env, overriding the argument

    #is_start = False # force not to evaluate at first step
    #print('-----------------global_step', sess.run(tf.train.get_or_create_global_step()))
    timer = gezi.Timer()
    if print_time:
        if not hasattr(train_once, 'timer'):
            train_once.timer = Timer()
            train_once.eval_timer = Timer()
            train_once.metric_eval_timer = Timer()

    melt.set_global('step', step)
    epoch = (fixed_step
             or step) / num_steps_per_epoch if num_steps_per_epoch else -1
    if not num_epochs:
        epoch_str = 'epoch:%.3f' % (epoch) if num_steps_per_epoch else ''
    else:
        epoch_str = 'epoch:%.3f/%d' % (
            epoch, num_epochs) if num_steps_per_epoch else ''
    melt.set_global('epoch', '%.2f' % (epoch))

    info = IO()
    stop = False

    if eval_names is None:
        if names:
            eval_names = list(names)  # the 'eval/' prefix is added below

    if names:
        names = ['train/' + x for x in names]

    if eval_names:
        eval_names = ['eval/' + x for x in eval_names]

    is_eval_step = is_start or valid_interval_steps and step % valid_interval_steps == 0
    summary_str = []

    eval_str = ''
    if is_eval_step:
        # deal with summary
        if log_dir:
            if not hasattr(train_once, 'summary_op'):
                #melt.print_summary_ops()
                if summary_excls is None:
                    train_once.summary_op = tf.summary.merge_all()
                else:
                    summary_ops = []
                    for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
                        if not any(summary_excl in op.name
                                   for summary_excl in summary_excls):
                            summary_ops.append(op)
                    print('filtered summary_ops:')
                    for op in summary_ops:
                        print(op)
                    train_once.summary_op = tf.summary.merge(summary_ops)

                #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
                train_once.summary_writer = tf.summary.FileWriter(
                    log_dir, sess.graph)

                tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
                    train_once.summary_writer, projector_config)

        # if there are eval ops then this should have been rank 0

        if eval_ops:
            #if deal_eval_results_fn is None and eval_names is not None:
            #  deal_eval_results_fn = lambda x: melt.print_results(x, eval_names)
            for i in range(eval_loops):
                eval_feed_dict = {} if gen_eval_feed_dict_fn is None else gen_eval_feed_dict_fn(
                )
                #eval_feed_dict.update(feed_dict)

                # if using horovod let each rank use the same sess.run!
                if not log_dir or train_once.summary_op is None or gezi.env_has(
                        'EVAL_NO_SUMMARY') or use_horovod:
                    #if not log_dir or train_once.summary_op is None:
                    eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
                else:
                    eval_results = sess.run(eval_ops + [train_once.summary_op],
                                            feed_dict=eval_feed_dict)
                    summary_str = eval_results[-1]
                    eval_results = eval_results[:-1]
                eval_loss = gezi.get_singles(eval_results)
                #timer_.print()
                eval_stop = False
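                # allreduce on a dummy constant acts as a barrier so all ranks stay in step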
                if use_horovod:
                    sess.run(hvd.allreduce(tf.constant(0)))

                #if not use_horovod or  hvd.local_rank() == 0:
                # @TODO user print should also use logging as a must ?
                #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, 'eval_metrics:', end='')
                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                #if not use_horovod or hvd.rank() == 0:
                #  logging.info2('{} eval_step:{} eval_metrics:{}'.format(epoch_str, step, melt.parse_results(eval_loss, eval_names_)))
                eval_str = 'valid:{}'.format(
                    melt.parse_results(eval_loss, eval_names_))

                # if deal_eval_results_fn is not None:
                #   eval_stop = deal_eval_results_fn(eval_results)

                assert len(eval_loss) > 0
                if eval_stop is True:
                    stop = True
                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                if not use_horovod or hvd.rank() == 0:
                    melt.set_global('eval_loss',
                                    melt.parse_results(eval_loss, eval_names_))

        elif interval_steps != valid_interval_steps:
            #print()
            pass

    metric_evaluate = False

    # if metric_eval_fn is not None \
    #   and (is_start \
    #     or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
    #          or (metric_eval_interval_steps \
    #              and step % metric_eval_interval_steps == 0)):
    #  metric_evaluate = True

    if metric_eval_fn is not None \
      and ((is_start or metric_eval_interval_steps \
           and step % metric_eval_interval_steps == 0) or model_path):
        metric_evaluate = True

    if 'EVFIRST' in os.environ:
        if os.environ['EVFIRST'] == '0':
            if is_start:
                metric_evaluate = False
        else:
            if is_start:
                metric_evaluate = True

    if step == 0 or 'QUICK' in os.environ:
        metric_evaluate = False

    #print('------------1step', step, 'pre metric_evaluate', metric_evaluate, hvd.rank())
    if metric_evaluate:
        if use_horovod:
            print('------------metric evaluate step', step, model_path,
                  hvd.rank())
        if not model_path or 'model_path' not in inspect.getargspec(
                metric_eval_fn).args:
            metric_eval_fn_ = metric_eval_fn
        else:
            metric_eval_fn_ = lambda: metric_eval_fn(model_path=model_path)

        evaluate_results, evaluate_names, evaluate_summaries = [], [], None
        try:
            l = metric_eval_fn_()
            if isinstance(l, tuple):
                num_returns = len(l)
                if num_returns == 2:
                    evaluate_results, evaluate_names = l
                    evaluate_summaries = None
                else:
                    assert num_returns == 3, 'returning 1, 2 or 3 values is ok; 4+ is not'
                    evaluate_results, evaluate_names, evaluate_summaries = l
            else:  #returned a dict of name -> value
                evaluate_names, evaluate_results = tuple(zip(*l.items()))
                evaluate_summaries = None
        except Exception:
            logging.info('Do nothing for metric eval fn with exception:\n',
                         traceback.format_exc())

        if not use_horovod or hvd.rank() == 0:
            #logging.info2('{} valid_step:{} {}:{}'.format(epoch_str, step, 'valid_metrics' if model_path is None else 'epoch_valid_metrics', melt.parse_results(evaluate_results, evaluate_names)))
            logging.info2('{} valid_step:{} {}:{}'.format(
                epoch_str, step, 'valid_metrics',
                melt.parse_results(evaluate_results, evaluate_names)))

        if learning_rate is not None and (learning_rate_patience
                                          and learning_rate_patience > 0):
            assert learning_rate_decay_factor > 0 and learning_rate_decay_factor < 1
            valid_loss = evaluate_results[0]
            if not hasattr(train_once, 'min_valid_loss'):
                train_once.min_valid_loss = valid_loss
                train_once.decay_steps = []
                train_once.patience = 0
            else:
                if valid_loss < train_once.min_valid_loss:
                    train_once.min_valid_loss = valid_loss
                    train_once.patience = 0
                else:
                    train_once.patience += 1
                    logging.info2('{} valid_step:{} patience:{}'.format(
                        epoch_str, step, train_once.patience))

            if learning_rate_patience and train_once.patience >= learning_rate_patience:
                lr_op = ops[1]
                lr = sess.run(lr_op) * learning_rate_decay_factor
                train_once.decay_steps.append(step)
                logging.info2(
                    '{} valid_step:{} learning_rate_decay by *{}, learning_rate_decay_steps={}'
                    .format(epoch_str, step, learning_rate_decay_factor,
                            ','.join(map(str, train_once.decay_steps))))
                sess.run(tf.assign(lr_op, tf.constant(lr, dtype=tf.float32)))
                train_once.patience = 0
                train_once.min_valid_loss = valid_loss

    if ops is not None:
        #if deal_results_fn is None and names is not None:
        #  deal_results_fn = lambda x: melt.print_results(x, names)

        feed_dict = {} if gen_feed_dict_fn is None else gen_feed_dict_fn()
        # NOTICE ops[2] should be scalar otherwise wrong!! loss should be scalar
        #print('---------------ops', ops)
        if eval_ops is not None or not log_dir or not hasattr(
                train_once,
                'summary_op') or train_once.summary_op is None or use_horovod:
            feed_dict[K.learning_phase()] = 1
            results = sess.run(ops, feed_dict=feed_dict)
        else:
            ## TODO why below ?
            #try:
            feed_dict[K.learning_phase()] = 1
            results = sess.run(ops + [train_once.summary_op],
                               feed_dict=feed_dict)
            summary_str = results[-1]
            results = results[:-1]
            # except Exception:
            #   logging.info('sess.run(ops + [train_once.summary_op], feed_dict=feed_dict) fail')
            #   results = sess.run(ops, feed_dict=feed_dict)

        #print('------------results', results)
        # #--------trace debug
        # if step == 210:
        #   run_metadata = tf.RunMetadata()
        #   results = sess.run(
        #         ops,
        #         feed_dict=feed_dict,
        #         options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
        #         run_metadata=run_metadata)
        #   from tensorflow.python.client import timeline
        #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)

        #   trace_file = open('timeline.ctf.json', 'w')
        #   trace_file.write(trace.generate_chrome_trace_format())

        #results[0] is assumed to be the train_op result, results[1] the learning_rate
        learning_rate = results[1]
        results = results[2:]

        #@TODO should support avg loss and other avg evaluations like test..
        if print_avg_loss:
            if not hasattr(train_once, 'avg_loss'):
                train_once.avg_loss = AvgScore()
            #assume results[0] as train_op return, results[1] as loss
            loss = gezi.get_singles(results)
            train_once.avg_loss.add(loss)

        steps_per_second = None
        instances_per_second = None
        hours_per_epoch = None
        #step += 1
        #if is_start or interval_steps and step % interval_steps == 0:
        interval_ok = not use_horovod or hvd.local_rank() == 0
        if interval_steps and step % interval_steps == 0 and interval_ok:
            train_average_loss = train_once.avg_loss.avg_score()
            if print_time:
                duration = timer.elapsed()
                duration_str = 'duration:{:.2f} '.format(duration)
                melt.set_global('duration', '%.2f' % duration)
                #info.write(duration_str)
                elapsed = train_once.timer.elapsed()
                steps_per_second = interval_steps / elapsed
                batch_size = melt.batch_size()
                num_gpus = melt.num_gpus()
                instances_per_second = interval_steps * batch_size / elapsed
                gpu_info = '' if num_gpus <= 1 else ' gpus:[{}]'.format(
                    num_gpus)
                if num_steps_per_epoch is None:
                    epoch_time_info = ''
                else:
                    hours_per_epoch = num_steps_per_epoch / interval_steps * elapsed / 3600
                    epoch_time_info = '1epoch:[{:.2f}h]'.format(
                        hours_per_epoch)
                info.write(
                    'elapsed:[{:.2f}] batch_size:[{}]{} batches/s:[{:.2f}] insts/s:[{:.2f}] {} lr:[{:.6f}]'
                    .format(elapsed, batch_size, gpu_info, steps_per_second,
                            instances_per_second, epoch_time_info,
                            learning_rate))

            if print_avg_loss:
                #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
                names_ = melt.adjust_names(train_average_loss, names)
                #info.write('train_avg_metric:{} '.format(melt.parse_results(train_average_loss, names_)))
                info.write(' train:{} '.format(
                    melt.parse_results(train_average_loss, names_)))
                #info.write('train_avg_loss: {} '.format(train_average_loss))
            info.write(eval_str)
            #print(gezi.now_time(), epoch_str, 'train_step:%d'%step, info.getvalue(), end=' ')
            logging.info2('{} {} {}'.format(epoch_str, 'step:%d' % step,
                                            info.getvalue()))

            if deal_results_fn is not None:
                stop = deal_results_fn(results)

    summary_strs = gezi.to_list(summary_str)
    if metric_evaluate:
        if evaluate_summaries is not None:
            summary_strs += evaluate_summaries

    if step > 1:
        if is_eval_step:
            # deal with summary
            if log_dir:
                summary = tf.Summary()
                if eval_ops is None:
                    if train_once.summary_op is not None:
                        for summary_str in summary_strs:
                            train_once.summary_writer.add_summary(
                                summary_str, step)
                else:
                    for summary_str in summary_strs:
                        train_once.summary_writer.add_summary(
                            summary_str, step)
                    suffix = 'valid' if not eval_names else ''
                    # loss/valid
                    melt.add_summarys(summary,
                                      eval_results,
                                      eval_names_,
                                      suffix=suffix)

                if ops is not None:
                    try:
                        # loss/train_avg
                        melt.add_summarys(summary,
                                          train_average_loss,
                                          names_,
                                          suffix='train_avg')
                    except Exception:
                        pass
                    ##optimizer has done this also
                    melt.add_summary(summary, learning_rate, 'learning_rate')
                    melt.add_summary(summary,
                                     melt.batch_size(),
                                     'batch_size',
                                     prefix='other')
                    melt.add_summary(summary,
                                     melt.epoch(),
                                     'epoch',
                                     prefix='other')
                    if steps_per_second:
                        melt.add_summary(summary,
                                         steps_per_second,
                                         'steps_per_second',
                                         prefix='perf')
                    if instances_per_second:
                        melt.add_summary(summary,
                                         instances_per_second,
                                         'instances_per_second',
                                         prefix='perf')
                    if hours_per_epoch:
                        melt.add_summary(summary,
                                         hours_per_epoch,
                                         'hours_per_epoch',
                                         prefix='perf')

                if metric_evaluate:
                    #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
                    prefix = 'step_eval'
                    if model_path:
                        prefix = 'eval'
                        if not hasattr(train_once, 'epoch_step'):
                            train_once.epoch_step = 1
                        else:
                            train_once.epoch_step += 1
                        step = train_once.epoch_step
                    # eval/loss eval/auc ..
                    melt.add_summarys(summary,
                                      evaluate_results,
                                      evaluate_names,
                                      prefix=prefix)

                train_once.summary_writer.add_summary(summary, step)
                train_once.summary_writer.flush()
            return stop
        elif metric_evaluate and log_dir:
            summary = tf.Summary()
            for summary_str in summary_strs:
                train_once.summary_writer.add_summary(summary_str, step)
            #summary.ParseFromString(evaluate_summaries)
            summary_writer = train_once.summary_writer
            prefix = 'step_eval'
            if model_path:
                prefix = 'eval'
                if not hasattr(train_once, 'epoch_step'):
                    ## TODO.. restart will get 1 again..
                    #epoch_step = tf.Variable(0, trainable=False, name='epoch_step')
                    #epoch_step += 1
                    #train_once.epoch_step = sess.run(epoch_step)
                    valid_interval_epochs = 1.
                    try:
                        valid_interval_epochs = FLAGS.valid_interval_epochs
                    except Exception:
                        pass
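                    # recover the epoch-step counter after a restart: round
                    # epoch and the validation interval to one decimal and
                    # count how many intervals have elapsed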
                    train_once.epoch_step = 1 if melt.epoch() <= 1 else int(
                        int(melt.epoch() * 10) /
                        int(valid_interval_epochs * 10))
                    logging.info('train_once epoch start step is',
                                 train_once.epoch_step)
                else:
                    #epoch_step += 1
                    train_once.epoch_step += 1
                step = train_once.epoch_step
            #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
            melt.add_summarys(summary,
                              evaluate_results,
                              evaluate_names,
                              prefix=prefix)
            summary_writer.add_summary(summary, step)
            summary_writer.flush()
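
The epoch-step bookkeeping above packs a fractional epoch counter into an integer TensorBoard step. A minimal standalone sketch of that arithmetic, assuming only the standard library (the function name is hypothetical):

def epoch_to_summary_step(epoch, valid_interval_epochs=1.0):
    # Scale both values by 10 and truncate, so that e.g. epoch 2.5 with a
    # 0.5-epoch validation interval maps to summary step 5 (25 // 5).
    if epoch <= 1:
        return 1
    return int(int(epoch * 10) / int(valid_interval_epochs * 10))

assert epoch_to_summary_step(2.5, 0.5) == 5
assert epoch_to_summary_step(0.5, 0.5) == 1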
Code Example #4
def train_once(sess,
               step,
               ops,
               names=None,
               gen_feed_dict=None,
               deal_results=melt.print_results,
               interval_steps=100,
               eval_ops=None,
               eval_names=None,
               gen_eval_feed_dict=None,
               deal_eval_results=melt.print_results,
               eval_interval_steps=100,
               print_time=True,
               print_avg_loss=True,
               model_dir=None,
               log_dir=None,
               is_start=False,
               num_steps_per_epoch=None,
               metric_eval_function=None,
               metric_eval_interval_steps=0):

    timer = gezi.Timer()
    if print_time:
        if not hasattr(train_once, 'timer'):
            train_once.timer = Timer()
            train_once.eval_timer = Timer()
            train_once.metric_eval_timer = Timer()

    melt.set_global('step', step)
    epoch = step / num_steps_per_epoch if num_steps_per_epoch else -1
    epoch_str = 'epoch:%.4f' % (epoch) if num_steps_per_epoch else ''
    melt.set_global('epoch', '%.4f' % (epoch))

    info = BytesIO()
    stop = False

    if ops is not None:
        if deal_results is None and names is not None:
            deal_results = lambda x: melt.print_results(x, names)
        if deal_eval_results is None and eval_names is not None:
            deal_eval_results = lambda x: melt.print_results(x, eval_names)

        if eval_names is None:
            eval_names = names

        feed_dict = {} if gen_feed_dict is None else gen_feed_dict()

        results = sess.run(ops, feed_dict=feed_dict)

        # #--------trace debug
        # if step == 210:
        #   run_metadata = tf.RunMetadata()
        #   results = sess.run(
        #         ops,
        #         feed_dict=feed_dict,
        #         options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
        #         run_metadata=run_metadata)
        #   from tensorflow.python.client import timeline
        #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)

        #   trace_file = open('timeline.ctf.json', 'w')
        #   trace_file.write(trace.generate_chrome_trace_format())

        #results[0] is assumed to be the train_op
        results = results[1:]

        #@TODO should support average loss and other avg evaluations like test..
        if print_avg_loss:
            if not hasattr(train_once, 'avg_loss'):
                train_once.avg_loss = AvgScore()
                if interval_steps != eval_interval_steps:
                    train_once.avg_loss2 = AvgScore()
            #assume results[0] is the train_op return value, results[1] the loss
            loss = gezi.get_singles(results)
            train_once.avg_loss.add(loss)
            if interval_steps != eval_interval_steps:
                train_once.avg_loss2.add(loss)

        if is_start or (interval_steps and step % interval_steps == 0):
            train_average_loss = train_once.avg_loss.avg_score()
            if print_time:
                duration = timer.elapsed()
                duration_str = 'duration:{:.3f} '.format(duration)
                melt.set_global('duration', '%.3f' % duration)
                info.write(duration_str)
                elapsed = train_once.timer.elapsed()
                steps_per_second = interval_steps / elapsed
                batch_size = melt.batch_size()
                num_gpus = melt.num_gpus()
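                # each train step consumes batch_size examples per GPU,
                # hence the num_gpus factor below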
                instances_per_second = interval_steps * batch_size * num_gpus / elapsed
                if num_gpus == 1:
                    info.write(
                        'elapsed:[{:.3f}] batch_size:[{}] batches/s:[{:.2f}] insts/s:[{:.2f}] '
                        .format(elapsed, batch_size, steps_per_second,
                                instances_per_second))
                else:
                    info.write(
                        'elapsed:[{:.3f}] batch_size:[{}] gpus:[{}], batches/s:[{:.2f}] insts/s:[{:.2f}] '
                        .format(elapsed, batch_size, num_gpus,
                                steps_per_second, instances_per_second))

            if print_avg_loss:
                #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
                names_ = melt.adjust_names(train_average_loss, names)
                info.write('train_avg_metrics:{} '.format(
                    melt.parse_results(train_average_loss, names_)))
                #info.write('train_avg_loss: {} '.format(train_average_loss))

            #print(gezi.now_time(), epoch_str, 'train_step:%d'%step, info.getvalue(), end=' ')
            logging.info2('{} {} {}'.format(epoch_str, 'train_step:%d' % step,
                                            info.getvalue()))

            if deal_results is not None:
                stop = deal_results(results)

    metric_evaluate = False
    # if metric_eval_function is not None \
    #   and ( (is_start and (step or ops is None))\
    #     or (step and ((num_steps_per_epoch and step % num_steps_per_epoch == 0) \
    #            or (metric_eval_interval_steps \
    #                and step % metric_eval_interval_steps == 0)))):
    #     metric_evaluate = True
    if metric_eval_function is not None \
      and (is_start \
        or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
             or (metric_eval_interval_steps \
                 and step % metric_eval_interval_steps == 0)):
        metric_evaluate = True

    if metric_evaluate:
        evaluate_results, evaluate_names = metric_eval_function()

    if is_start or (eval_interval_steps and step % eval_interval_steps == 0):
        if ops is not None:
            if interval_steps != eval_interval_steps:
                train_average_loss = train_once.avg_loss2.avg_score()

            info = BytesIO()

            names_ = melt.adjust_names(results, names)

            train_average_loss_str = ''
            if print_avg_loss and interval_steps != eval_interval_steps:
                train_average_loss_str = melt.value_name_list_str(
                    train_average_loss, names_)
                melt.set_global('train_loss', train_average_loss_str)
                train_average_loss_str = 'train_avg_loss:{} '.format(
                    train_average_loss_str)

            if interval_steps != eval_interval_steps:
                #end = '' if eval_ops is None else '\n'
                #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, train_average_loss_str, end=end)
                logging.info2('{} eval_step: {} {}'.format(
                    epoch_str, step, train_average_loss_str))

        if eval_ops is not None:
            eval_feed_dict = {} if gen_eval_feed_dict is None else gen_eval_feed_dict(
            )
            #eval_feed_dict.update(feed_dict)

            #------show how to perf debug
            ##timer_ = gezi.Timer('sess run generate')
            ##sess.run(eval_ops[-2], feed_dict=None)
            ##timer_.print()

            timer_ = gezi.Timer('sess run eval_ops')
            eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
            timer_.print()
            eval_stop = False  # deal_eval_results may be None, so initialize
            if deal_eval_results is not None:
                #@TODO user prints should also go through logging
                #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, 'eval_metrics:', end='')
                logging.info2('{} eval_step: {} eval_metrics:'.format(
                    epoch_str, step))
                eval_stop = deal_eval_results(eval_results)

            eval_loss = gezi.get_singles(eval_results)
            assert len(eval_loss) > 0
            if eval_stop is True: stop = True
            eval_names_ = melt.adjust_names(eval_loss, eval_names)

            melt.set_global('eval_loss',
                            melt.parse_results(eval_loss, eval_names_))
        elif interval_steps != eval_interval_steps:
            #print()
            pass

        if log_dir:
            #timer_ = gezi.Timer('witting log')

            if not hasattr(train_once, 'summary_op'):
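                # tf.summary.merge_all is the TF >= 1.0 API; the except
                # branches below fall back to the pre-1.0 names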
                try:
                    train_once.summary_op = tf.summary.merge_all()
                except Exception:
                    train_once.summary_op = tf.merge_all_summaries()

                melt.print_summary_ops()

                try:
                    train_once.summary_train_op = tf.summary.merge_all(
                        key=melt.MonitorKeys.TRAIN)
                    train_once.summary_writer = tf.summary.FileWriter(
                        log_dir, sess.graph)
                except Exception:
                    train_once.summary_train_op = tf.merge_all_summaries(
                        key=melt.MonitorKeys.TRAIN)
                    train_once.summary_writer = tf.train.SummaryWriter(
                        log_dir, sess.graph)

                tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
                    train_once.summary_writer, projector_config)

            summary = tf.Summary()
            #strategy: on eval_interval_steps, if there is an eval dataset, tensorboard evaluates on it;
            #if there is no eval dataset we evaluate on the train set; with an eval dataset we also monitor the train loss
            if train_once.summary_train_op is not None:
                summary_str = sess.run(train_once.summary_train_op,
                                       feed_dict=feed_dict)
                train_once.summary_writer.add_summary(summary_str, step)

            if eval_ops is None:
                #get the train loss for every train batch
                if train_once.summary_op is not None:
                    #timer2 = gezi.Timer('sess run')
                    summary_str = sess.run(train_once.summary_op,
                                           feed_dict=feed_dict)
                    #timer2.print()
                    train_once.summary_writer.add_summary(summary_str, step)
            else:
                #get the eval loss for every eval batch, then add the train loss averaged over the eval interval
                summary_str = sess.run(
                    train_once.summary_op, feed_dict=eval_feed_dict
                ) if train_once.summary_op is not None else ''
                #all single-value results are added to the summary here without using tf.scalar_summary..
                summary.ParseFromString(summary_str)
                melt.add_summarys(summary,
                                  eval_results,
                                  eval_names_,
                                  suffix='eval')

            melt.add_summarys(summary,
                              train_average_loss,
                              names_,
                              suffix='train_avg%dsteps' % eval_interval_steps)

            if metric_evaluate:
                melt.add_summarys(summary,
                                  evaluate_results,
                                  evaluate_names,
                                  prefix='evaluate')

            train_once.summary_writer.add_summary(summary, step)
            train_once.summary_writer.flush()

            #timer_.print()

        if print_time:
            full_duration = train_once.eval_timer.elapsed()
            if metric_evaluate:
                metric_full_duration = train_once.metric_eval_timer.elapsed()
            full_duration_str = 'elapsed:{:.3f} '.format(full_duration)
            #info.write('duration:{:.3f} '.format(timer.elapsed()))
            duration = timer.elapsed()
            info.write('duration:{:.3f} '.format(duration))
            info.write(full_duration_str)
            info.write('eval_time_ratio:{:.3f} '.format(duration /
                                                        full_duration))
            if metric_evaluate:
                info.write('metric_time_ratio:{:.3f} '.format(
                    duration / metric_full_duration))
        #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, info.getvalue())
        logging.info2('{} {} {}'.format(epoch_str, 'eval_step: %d' % step,
                                        info.getvalue()))

        return stop
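
Throughout these examples, scalar metrics reach TensorBoard by filling a tf.Summary proto by hand rather than through graph-side summary ops. A hedged sketch of what melt.add_summary presumably wraps (the helper name and prefix handling are assumptions; the proto calls are standard TF 1.x API):

import tensorflow as tf

def add_summary(summary, value, tag, prefix=''):
    # append one scalar to a tf.Summary proto under an optional prefix
    if prefix:
        tag = '%s/%s' % (prefix, tag)
    summary.value.add(tag=tag, simple_value=float(value))

summary = tf.Summary()
add_summary(summary, 0.125, 'loss', prefix='eval')
writer = tf.summary.FileWriter('/tmp/log')
writer.add_summary(summary, global_step=10)
writer.flush()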
Code Example #5
def train(Dataset, 
          model, 
          loss_fn, 
          evaluate_fn=None, 
          inference_fn=None,
          eval_fn=None,
          write_valid=True,
          valid_names=None,
          infer_names=None,
          infer_debug_names=None,
          valid_write_fn=None,
          infer_write_fn=None,
          valid_suffix='.valid',
          infer_suffix='.infer',
          write_streaming=False,
          sep=','):
  if FLAGS.torch:
    if torch.cuda.is_available():
      model.cuda()
  
  input_ = FLAGS.train_input
  inputs = gezi.list_files(input_)
  inputs.sort()

  all_inputs = inputs

  batch_size = FLAGS.batch_size

  num_gpus = melt.num_gpus()
  if num_gpus > 1:
    assert False, 'Eager-mode training currently not supported for num_gpus > 1'

  #batch_size_ = batch_size if not FLAGS.batch_sizes else int(FLAGS.batch_sizes.split(',')[-1])
  batch_size_ = batch_size

  if FLAGS.fold is not None:
    inputs = [x for x in inputs if not x.endswith('%d.record' % FLAGS.fold)]

  logging.info('inputs', inputs)

  dataset = Dataset('train')
  num_examples = dataset.num_examples_per_epoch('train') 
  num_all_examples = num_examples

  # if FLAGS.fold is not None:
  #   valid_inputs = [x for x in all_inputs if x not in inputs]
  # else:
  #   valid_inputs = gezi.list_files(FLAGS.valid_input)
  
  # logging.info('valid_inputs', valid_inputs)

  # if valid_inputs:
  #   valid_dataset_ = Dataset('valid')
  #   valid_dataset = valid_dataset_.make_batch(batch_size_, valid_inputs)
  #   valid_dataset2 = valid_dataset_.make_batch(batch_size_, valid_inputs, repeat=True)
  # else:
  #   valid_datsset = None
  #   valid_dataset2 = None

  if num_examples:
    if FLAGS.fold is not None:
      num_examples = int(num_examples * (len(inputs) / (len(inputs) + 1)))
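    # -(-a // b) is ceiling division: the last partial batch still counts
    # as a step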
    num_steps_per_epoch = -(-num_examples // batch_size)
  else:
    num_steps_per_epoch = None

  # if FLAGS.fold is not None:
  #   if num_examples:
  #     num_valid_examples = int(num_all_examples * (1 / (len(inputs) + 1)))
  #     num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_)
  #   else:
  #     num_valid_steps_per_epoch = None
  # else:
  #   num_valid_examples = valid_dataset_.num_examples_per_epoch('valid')
  #   num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_) if num_valid_examples else None

  # test_inputs = gezi.list_files(FLAGS.test_input)
  # logging.info('test_inputs', test_inputs)
  
  # if test_inputs:
  #   test_dataset_ = Dataset('test')
  #   test_dataset = test_dataset_.make_batch(batch_size_, test_inputs) 
  #   num_test_examples = test_dataset_.num_examples_per_epoch('test')
  #   num_test_steps_per_epoch = -(-num_test_examples // batch_size_) if num_test_examples else None
  # else:
  #   test_dataset = None
  
  summary = tf.contrib.summary
  # writer = summary.create_file_writer(FLAGS.model_dir + '/epoch')
  # writer_train = summary.create_file_writer(FLAGS.model_dir + '/train')
  # writer_valid = summary.create_file_writer(FLAGS.model_dir + '/valid')
  writer = summary.create_file_writer(FLAGS.model_dir)
  writer_train = summary.create_file_writer(FLAGS.model_dir)
  writer_valid = summary.create_file_writer(FLAGS.model_dir)
  global_step = tf.train.get_or_create_global_step()

  learning_rate = tfe.Variable(FLAGS.learning_rate, name="learning_rate")
  tf.add_to_collection('learning_rate', learning_rate)

  learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
  try:
    learning_rate_weights = tf.get_collection('learning_rate_weights')[-1]
  except Exception:
    learning_rate_weights = None

  ckpt_dir = FLAGS.model_dir + '/ckpt'

  #TODO FIXME for now I just changed the tf code so it does not, by default, keep only the latest 5 checkpoints
  # refer to https://github.com/tensorflow/tensorflow/issues/22036
  # manager = tf.contrib.checkpoint.CheckpointManager(
  #     checkpoint, directory=ckpt_dir, max_to_keep=5)
  # latest_checkpoint = manager.latest_checkpoint

  latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
  logging.info('Latest checkpoint:', latest_checkpoint)
  checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt')

  if not FLAGS.torch:
    optimizer = melt.get_optimizer(FLAGS.optimizer)(learning_rate)
    
    # TODO...
    if learning_rate_weights is None:
      checkpoint = tf.train.Checkpoint(
            learning_rate=learning_rate, 
            learning_rate_weight=learning_rate_weight,
            model=model,
            optimizer=optimizer,
            global_step=global_step)
    else:
      checkpoint = tf.train.Checkpoint(
            learning_rate=learning_rate, 
            learning_rate_weight=learning_rate_weight,
            learning_rate_weights=learning_rate_weights,
            model=model,
            optimizer=optimizer,
            global_step=global_step)
      
    if os.path.exists(FLAGS.model_dir + '.index'):
      latest_checkpoint = FLAGS.model_dir   

    checkpoint.restore(latest_checkpoint)

    start_epoch = int(latest_checkpoint.split('-')[-1]) if latest_checkpoint else 0
  else:
    # TODO torch with learning rate adjust
    optimizer = torch.optim.Adamax(model.parameters(), lr=FLAGS.learning_rate)

    if latest_checkpoint:
      checkpoint = torch.load(latest_checkpoint + '.pyt')
      start_epoch = checkpoint['epoch']
      model.load_state_dict(checkpoint['state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer'])
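      # eval() puts the restored model in inference mode (dropout etc. off);
      # train() is called again at the top of each epoch below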
      model.eval()
    else:
      start_epoch = 0

    if learning_rate_weights is None:
      checkpoint = tf.train.Checkpoint(
          learning_rate=learning_rate, 
          learning_rate_weight=learning_rate_weight,
          global_step=global_step)
    else:
      checkpoint = tf.train.Checkpoint(
            learning_rate=learning_rate, 
            learning_rate_weight=learning_rate_weight,
            learning_rate_weights=learning_rate_weights,
            global_step=global_step)

  #model.load_weights(os.path.join(ckpt_dir, 'ckpt-1'))
  #model.save('./weight3.hd5')

  # TODO fractional epochs like 0.1 are currently not supported
  num_epochs = FLAGS.num_epochs
  
 
  class PytObj(object):
    def __init__(self, x):
      self.x = x
    def numpy(self):
      return self.x

  class PytMean(object):
    def __init__(self):
      self._val = 0. 
      self.count = 0

      self.is_call = True

    def clear(self):
      self._val = 0
      self.count = 0

    def __call__(self, val):
      if not self.is_call:
        self.clear()
        self.is_call = True
      self._val += val.item()
      self.count += 1

    def result(self):
      if self.is_call:
        self.is_call = False
      if not self.count:
        val = 0
      else:
        val = self._val / self.count
      # TODO just for compact with tf ..
      return PytObj(val)
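
  # PytMean mimics the tfe.metrics.Mean interface used below: call the
  # object to accumulate a loss, read it with .result().numpy(); the
  # is_call flag clears the accumulator on the first call after result()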
      
  # TODO consider multiple gpus for torch

  # renamed to iter_ to avoid shadowing the builtin iter
  iter_ = dataset.make_batch(batch_size, inputs, repeat=False, initializable=False)
  batch = iter_.get_next()
  #x, y = melt.split_batch(batch, batch_size, num_gpus)
  x_, y_ = batch
  
  Mean = tfe.metrics.Mean if not FLAGS.torch else PytMean
  epoch_loss_avg = Mean()
  epoch_valid_loss_avg = Mean()

  sess = melt.get_session(device_count={'GPU': 0})
  global_step = 0
  for epoch in range(start_epoch, num_epochs):
    melt.set_global('epoch', '%.4f' % (epoch))
    sess.run(iter_.initializer)

    model.train()

    #..... still OOM... FIXME TODO
    try:
      for _ in tqdm(range(num_steps_per_epoch), total=num_steps_per_epoch, ascii=True):
        x, y = sess.run([x_, y_])
        x, y = to_torch(x, y)
        
        optimizer.zero_grad()
        loss = loss_fn(model, x, y)
        loss.backward()
        optimizer.step()

        epoch_loss_avg(loss) 

        if global_step % FLAGS.interval_steps == 0:
          print(global_step, epoch_loss_avg.result().numpy())

        global_step += 1
    except tf.errors.OutOfRangeError:
      print('epoch:%d loss:%f' % (epoch, epoch_loss_avg.result().numpy()))
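
The torch branch above restores from latest_checkpoint + '.pyt' and reads the keys 'epoch', 'state_dict' and 'optimizer'. A hedged sketch of the matching save side (the function name is an assumption; the dict layout comes from the restore code):

import torch

def save_torch_checkpoint(path, model, optimizer, epoch):
    # write the dict layout that the restore code above reads back
    torch.save({
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, path + '.pyt')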
Code Example #6
def train_once(
    sess,
    step,
    ops,
    names=None,
    gen_feed_dict_fn=None,
    deal_results_fn=None,
    interval_steps=100,
    eval_ops=None,
    eval_names=None,
    gen_eval_feed_dict_fn=None,
    deal_eval_results_fn=melt.print_results,
    eval_interval_steps=100,
    print_time=True,
    print_avg_loss=True,
    model_dir=None,
    log_dir=None,
    is_start=False,
    num_steps_per_epoch=None,
    metric_eval_fn=None,
    metric_eval_interval_steps=0,
    summary_excls=None,
    fixed_step=None,  # for epoch only, in case you change batch size
    eval_loops=1,
    learning_rate=None,
    learning_rate_patience=None,
    learning_rate_decay_factor=None,
    num_epochs=None,
    model_path=None,
):

    #is_start = False # force not to evaluate at first step
    #print('-----------------global_step', sess.run(tf.train.get_or_create_global_step()))
    timer = gezi.Timer()
    if print_time:
        if not hasattr(train_once, 'timer'):
            train_once.timer = Timer()
            train_once.eval_timer = Timer()
            train_once.metric_eval_timer = Timer()

    melt.set_global('step', step)
    epoch = (fixed_step
             or step) / num_steps_per_epoch if num_steps_per_epoch else -1
    if not num_epochs:
        epoch_str = 'epoch:%.3f' % (epoch) if num_steps_per_epoch else ''
    else:
        epoch_str = 'epoch:%.3f/%d' % (
            epoch, num_epochs) if num_steps_per_epoch else ''
    melt.set_global('epoch', '%.2f' % (epoch))

    info = IO()
    stop = False

    if eval_names is None:
        if names:
            # copy without the prefix here; 'eval/' is added once below
            eval_names = list(names)

    if names:
        names = ['train/' + x for x in names]

    if eval_names:
        eval_names = ['eval/' + x for x in eval_names]

    is_eval_step = is_start or (eval_interval_steps and step % eval_interval_steps == 0)
    summary_str = []

    if is_eval_step:
        # deal with summary
        if log_dir:
            if not hasattr(train_once, 'summary_op'):
                #melt.print_summary_ops()
                if summary_excls is None:
                    train_once.summary_op = tf.summary.merge_all()
                else:
                    summary_ops = []
                    for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
                        # keep an op only if none of the excluded substrings
                        # appear in its name (the original inner loop could
                        # append the same op several times)
                        if all(summary_excl not in op.name
                               for summary_excl in summary_excls):
                            summary_ops.append(op)
                    print('filtered summary_ops:')
                    for op in summary_ops:
                        print(op)
                    train_once.summary_op = tf.summary.merge(summary_ops)

                #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
                train_once.summary_writer = tf.summary.FileWriter(
                    log_dir, sess.graph)

                tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
                    train_once.summary_writer, projector_config)

        if eval_ops is not None:
            #if deal_eval_results_fn is None and eval_names is not None:
            #  deal_eval_results_fn = lambda x: melt.print_results(x, eval_names)
            for i in range(eval_loops):
                eval_feed_dict = {} if gen_eval_feed_dict_fn is None else gen_eval_feed_dict_fn(
                )
                #eval_feed_dict.update(feed_dict)

                # might use EVAL_NO_SUMMARY if some old code has problems TODO CHECK
                if not log_dir or train_once.summary_op is None or gezi.env_has(
                        'EVAL_NO_SUMMARY'):
                    #if not log_dir or train_once.summary_op is None:
                    eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
                else:
                    eval_results = sess.run(eval_ops + [train_once.summary_op],
                                            feed_dict=eval_feed_dict)
                    summary_str = eval_results[-1]
                    eval_results = eval_results[:-1]
                eval_loss = gezi.get_singles(eval_results)
                #timer_.print()
                eval_stop = False

                # @TODO user prints should also go through logging
                #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, 'eval_metrics:', end='')
                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                logging.info2('{} eval_step:{} eval_metrics:{}'.format(
                    epoch_str, step,
                    melt.parse_results(eval_loss, eval_names_)))

                # if deal_eval_results_fn is not None:
                #   eval_stop = deal_eval_results_fn(eval_results)

                assert len(eval_loss) > 0
                if eval_stop is True:
                    stop = True
                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                melt.set_global('eval_loss',
                                melt.parse_results(eval_loss, eval_names_))

        elif interval_steps != eval_interval_steps:
            #print()
            pass

    metric_evaluate = False

    # if metric_eval_fn is not None \
    #   and (is_start \
    #     or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
    #          or (metric_eval_interval_steps \
    #              and step % metric_eval_interval_steps == 0)):
    #  metric_evaluate = True

    if metric_eval_fn is not None \
      and ((is_start or metric_eval_interval_steps \
           and step % metric_eval_interval_steps == 0) or model_path):
        metric_evaluate = True

    #if (is_start or step == 0) and (not 'EVFIRST' in os.environ):
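    # skip the step-0 metric eval unless EVFIRST is set; QUICK or EVFIRST=0
    # force-skip it regardless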
    if ((step == 0) and ('EVFIRST' not in os.environ)) \
        or ('QUICK' in os.environ) \
        or (os.environ.get('EVFIRST') == '0'):
        metric_evaluate = False

    if metric_evaluate:
        # TODO better
        if not model_path or 'model_path' not in inspect.getargspec(
                metric_eval_fn).args:
            l = metric_eval_fn()
            if len(l) == 2:
                evaluate_results, evaluate_names = l
                evaluate_summaries = None
            else:
                evaluate_results, evaluate_names, evaluate_summaries = l
        else:
            try:
                l = metric_eval_fn(model_path=model_path)
                if len(l) == 2:
                    evaluate_results, evaluate_names = l
                    evaluate_summaries = None
                else:
                    evaluate_results, evaluate_names, evaluate_summaries = l
            except Exception:
                logging.info('Do nothing for metric eval fn with exception:\n',
                             traceback.format_exc())
                # initialize defaults so the logging below does not NameError
                evaluate_results, evaluate_names, evaluate_summaries = [], [], None

        logging.info2('{} valid_step:{} {}:{}'.format(
            epoch_str, step,
            'valid_metrics' if model_path is None else 'epoch_valid_metrics',
            melt.parse_results(evaluate_results, evaluate_names)))
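
        # reduce-on-plateau: track the best valid loss and, after
        # learning_rate_patience evaluations without improvement, scale the
        # learning-rate variable (assumed to be ops[1]) by
        # learning_rate_decay_factor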

        if learning_rate is not None and (learning_rate_patience
                                          and learning_rate_patience > 0):
            assert learning_rate_decay_factor > 0 and learning_rate_decay_factor < 1
            valid_loss = evaluate_results[0]
            if not hasattr(train_once, 'min_valid_loss'):
                train_once.min_valid_loss = valid_loss
                train_once.decay_steps = []
                train_once.patience = 0
            else:
                if valid_loss < train_once.min_valid_loss:
                    train_once.min_valid_loss = valid_loss
                    train_once.patience = 0
                else:
                    train_once.patience += 1
                    logging.info2('{} valid_step:{} patience:{}'.format(
                        epoch_str, step, train_once.patience))

            if learning_rate_patience and train_once.patience >= learning_rate_patience:
                lr_op = ops[1]
                lr = sess.run(lr_op) * learning_rate_decay_factor
                train_once.decay_steps.append(step)
                logging.info2(
                    '{} valid_step:{} learning_rate_decay by *{}, learning_rate_decay_steps={}'
                    .format(epoch_str, step, learning_rate_decay_factor,
                            ','.join(map(str, train_once.decay_steps))))
                sess.run(tf.assign(lr_op, tf.constant(lr, dtype=tf.float32)))
                train_once.patience = 0
                train_once.min_valid_loss = valid_loss

    if ops is not None:
        #if deal_results_fn is None and names is not None:
        #  deal_results_fn = lambda x: melt.print_results(x, names)

        feed_dict = {} if gen_feed_dict_fn is None else gen_feed_dict_fn()
        # NOTICE ops[2] should be a scalar, otherwise this is wrong!! the loss must be a scalar
        #print('---------------ops', ops)
        if eval_ops is not None or not log_dir or not hasattr(
                train_once, 'summary_op') or train_once.summary_op is None:
            results = sess.run(ops, feed_dict=feed_dict)
        else:
            #try:
            results = sess.run(ops + [train_once.summary_op],
                               feed_dict=feed_dict)
            summary_str = results[-1]
            results = results[:-1]
            # except Exception:
            #   logging.info('sess.run(ops + [train_once.summary_op], feed_dict=feed_dict) fail')
            #   results = sess.run(ops, feed_dict=feed_dict)

        #print('------------results', results)
        # #--------trace debug
        # if step == 210:
        #   run_metadata = tf.RunMetadata()
        #   results = sess.run(
        #         ops,
        #         feed_dict=feed_dict,
        #         options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
        #         run_metadata=run_metadata)
        #   from tensorflow.python.client import timeline
        #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)

        #   trace_file = open('timeline.ctf.json', 'w')
        #   trace_file.write(trace.generate_chrome_trace_format())

        #results[0] is assumed to be the train_op, results[1] the learning_rate
        learning_rate = results[1]
        results = results[2:]

        #@TODO should support average loss and other avg evaluations like test..
        if print_avg_loss:
            if not hasattr(train_once, 'avg_loss'):
                train_once.avg_loss = AvgScore()
                if interval_steps != eval_interval_steps:
                    train_once.avg_loss2 = AvgScore()
            #assume results[0] is the train_op return value, results[1] the loss
            loss = gezi.get_singles(results)
            train_once.avg_loss.add(loss)
            if interval_steps != eval_interval_steps:
                train_once.avg_loss2.add(loss)

        steps_per_second = None
        instances_per_second = None
        hours_per_epoch = None
        #step += 1
        if is_start or (interval_steps and step % interval_steps == 0):
            train_average_loss = train_once.avg_loss.avg_score()
            if print_time:
                duration = timer.elapsed()
                duration_str = 'duration:{:.3f} '.format(duration)
                melt.set_global('duration', '%.3f' % duration)
                info.write(duration_str)
                elapsed = train_once.timer.elapsed()
                steps_per_second = interval_steps / elapsed
                batch_size = melt.batch_size()
                num_gpus = melt.num_gpus()
                instances_per_second = interval_steps * batch_size / elapsed
                gpu_info = '' if num_gpus <= 1 else ' gpus:[{}]'.format(
                    num_gpus)
                if num_steps_per_epoch is None:
                    epoch_time_info = ''
                else:
                    hours_per_epoch = num_steps_per_epoch / interval_steps * elapsed / 3600
                    epoch_time_info = ' 1epoch:[{:.2f}h]'.format(
                        hours_per_epoch)
                info.write(
                    'elapsed:[{:.3f}] batch_size:[{}]{} batches/s:[{:.2f}] insts/s:[{:.2f}] {} lr:[{:.8f}]'
                    .format(elapsed, batch_size, gpu_info, steps_per_second,
                            instances_per_second, epoch_time_info,
                            learning_rate))

            if print_avg_loss:
                #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
                names_ = melt.adjust_names(train_average_loss, names)
                #info.write('train_avg_metric:{} '.format(melt.parse_results(train_average_loss, names_)))
                info.write(' train:{} '.format(
                    melt.parse_results(train_average_loss, names_)))
                #info.write('train_avg_loss: {} '.format(train_average_loss))

            #print(gezi.now_time(), epoch_str, 'train_step:%d'%step, info.getvalue(), end=' ')
            logging.info2('{} {} {}'.format(epoch_str, 'step:%d' % step,
                                            info.getvalue()))

            if deal_results_fn is not None:
                stop = deal_results_fn(results)

    summary_strs = gezi.to_list(summary_str)
    if metric_evaluate:
        if evaluate_summaries is not None:
            summary_strs += evaluate_summaries

    if step > 1:
        if is_eval_step:
            # deal with summary
            if log_dir:
                # if not hasattr(train_once, 'summary_op'):
                #   melt.print_summary_ops()
                #   if summary_excls is None:
                #     train_once.summary_op = tf.summary.merge_all()
                #   else:
                #     summary_ops = []
                #     for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
                #       for summary_excl in summary_excls:
                #         if not summary_excl in op.name:
                #           summary_ops.append(op)
                #     print('filtered summary_ops:')
                #     for op in summary_ops:
                #       print(op)
                #     train_once.summary_op = tf.summary.merge(summary_ops)

                #   print('-------------summary_op', train_once.summary_op)

                #   #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
                #   train_once.summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

                #   tf.contrib.tensorboard.plugins.projector.visualize_embeddings(train_once.summary_writer, projector_config)

                summary = tf.Summary()
                # #so the strategy is on eval_interval_steps, if has eval dataset, then tensorboard evluate on eval dataset
                # #if not have eval dataset, will evaluate on trainset, but if has eval dataset we will also monitor train loss
                # assert train_once.summary_train_op is None
                # if train_once.summary_train_op is not None:
                #   summary_str = sess.run(train_once.summary_train_op, feed_dict=feed_dict)
                #   train_once.summary_writer.add_summary(summary_str, step)

                if eval_ops is None:
                    # #get train loss, for every batch train
                    # if train_once.summary_op is not None:
                    #   #timer2 = gezi.Timer('sess run')
                    #   try:
                    #     # TODO FIXME so this means one more train batch step without adding to global step counter ?! so should move it earlier
                    #     summary_str = sess.run(train_once.summary_op, feed_dict=feed_dict)
                    #   except Exception:
                    #     if not hasattr(train_once, 'num_summary_errors'):
                    #       logging.warning('summary_str = sess.run(train_once.summary_op, feed_dict=feed_dict) fail')
                    #       train_once.num_summary_errors = 1
                    #       logging.warning(traceback.format_exc())
                    #     summary_str = ''
                    #   # #timer2.print()
                    if train_once.summary_op is not None:
                        for summary_str in summary_strs:
                            train_once.summary_writer.add_summary(
                                summary_str, step)
                else:
                    # #get eval loss for every batch eval, then add train loss for eval step average loss
                    # try:
                    #   summary_str = sess.run(train_once.summary_op, feed_dict=eval_feed_dict) if train_once.summary_op is not None else ''
                    # except Exception:
                    #   if not hasattr(train_once, 'num_summary_errors'):
                    #     logging.warning('summary_str = sess.run(train_once.summary_op, feed_dict=eval_feed_dict) fail')
                    #     train_once.num_summary_errors = 1
                    #     logging.warning(traceback.format_exc())
                    #   summary_str = ''
                    #all single-value results are added to the summary here without using tf.scalar_summary..
                    #summary.ParseFromString(summary_str)
                    for summary_str in summary_strs:
                        train_once.summary_writer.add_summary(
                            summary_str, step)
                    suffix = 'eval' if not eval_names else ''
                    melt.add_summarys(summary,
                                      eval_results,
                                      eval_names_,
                                      suffix=suffix)

                if ops is not None:
                    melt.add_summarys(summary,
                                      train_average_loss,
                                      names_,
                                      suffix='train_avg')
                    ##optimizer has done this also
                    melt.add_summary(summary, learning_rate, 'learning_rate')
                    melt.add_summary(summary, melt.batch_size(), 'batch_size')
                    melt.add_summary(summary, melt.epoch(), 'epoch')
                    if steps_per_second:
                        melt.add_summary(summary, steps_per_second,
                                         'steps_per_second')
                    if instances_per_second:
                        melt.add_summary(summary, instances_per_second,
                                         'instances_per_second')
                    if hours_per_epoch:
                        melt.add_summary(summary, hours_per_epoch,
                                         'hours_per_epoch')

                if metric_evaluate:
                    #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
                    prefix = 'step/valid'
                    if model_path:
                        prefix = 'epoch/valid'
                        if not hasattr(train_once, 'epoch_step'):
                            train_once.epoch_step = 1
                        else:
                            train_once.epoch_step += 1
                        step = train_once.epoch_step

                    melt.add_summarys(summary,
                                      evaluate_results,
                                      evaluate_names,
                                      prefix=prefix)

                train_once.summary_writer.add_summary(summary, step)
                train_once.summary_writer.flush()

                #timer_.print()

            # if print_time:
            #   full_duration = train_once.eval_timer.elapsed()
            #   if metric_evaluate:
            #     metric_full_duration = train_once.metric_eval_timer.elapsed()
            #   full_duration_str = 'elapsed:{:.3f} '.format(full_duration)
            #   #info.write('duration:{:.3f} '.format(timer.elapsed()))
            #   duration = timer.elapsed()
            #   info.write('duration:{:.3f} '.format(duration))
            #   info.write(full_duration_str)
            #   info.write('eval_time_ratio:{:.3f} '.format(duration/full_duration))
            #   if metric_evaluate:
            #     info.write('metric_time_ratio:{:.3f} '.format(duration/metric_full_duration))
            # #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, info.getvalue())
            # logging.info2('{} {} {}'.format(epoch_str, 'eval_step: %d'%step, info.getvalue()))
            return stop
        elif metric_evaluate:
            summary = tf.Summary()
            for summary_str in summary_strs:
                train_once.summary_writer.add_summary(summary_str, step)
            #summary.ParseFromString(evaluate_summaries)
            summary_writer = train_once.summary_writer
            prefix = 'step/valid'
            if model_path:
                prefix = 'epoch/valid'
                if not hasattr(train_once, 'epoch_step'):
                    ## TODO.. restart will get 1 again..
                    #epoch_step = tf.Variable(0, trainable=False, name='epoch_step')
                    #epoch_step += 1
                    #train_once.epoch_step = sess.run(epoch_step)
                    valid_interval_epochs = 1.
                    try:
                        valid_interval_epochs = FLAGS.valid_interval_epochs
                    except Exception:
                        pass
                    train_once.epoch_step = 1 if melt.epoch() <= 1 else int(
                        int(melt.epoch() * 10) /
                        int(valid_interval_epochs * 10))
                    logging.info('train_once epoch start step is',
                                 train_once.epoch_step)
                else:
                    #epoch_step += 1
                    train_once.epoch_step += 1
                step = train_once.epoch_step
            #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
            melt.add_summarys(summary,
                              evaluate_results,
                              evaluate_names,
                              prefix=prefix)
            summary_writer.add_summary(summary, step)
            summary_writer.flush()
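
When patience runs out above, the decay is applied by reading the learning-rate variable, scaling it in Python, and assigning the result back through the graph. A minimal, hedged sketch of that assign pattern in TF 1.x (the variable name here is an assumption):

import tensorflow as tf

lr_var = tf.Variable(0.1, trainable=False, dtype=tf.float32, name='learning_rate')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    lr = sess.run(lr_var) * 0.5  # e.g. learning_rate_decay_factor = 0.5
    sess.run(tf.assign(lr_var, tf.constant(lr, dtype=tf.float32)))
    print(sess.run(lr_var))  # ~0.05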