Example #1
  def add(self, score):
    import melt.utils.logging as logging

    if not tf.executing_eagerly():
      weight = self.sess.run(self.weight_op)
    else:
      weight = self.weight_op
    #print(weight, score, self.score, self.patience)
    
    if (not self.cmp) and self.score:
      if score > self.score:
        self.cmp = lambda x, y: x > y  
      else:
        self.cmp = lambda x, y: x < y
      logging.info('decay cmp:', self.cmp)

    if not self.score or self.cmp(score, self.score):
      self.score = score 
      self.patience = 0
    else:
      self.patience += 1
      # epoch is set during training loop
      epoch = melt.epoch()
      logging.info('patience:', self.patience)
      if epoch < self.decay_start_epoch:
        return weight
      if self.patience >= self.max_patience:
        self.count += 1
        self.patience = 0
        self.score = score
        decay = self.decay
        pre_weight = weight
        #weight *= decay
        weight = weight * decay
        
        # decay
        if self.min_weight and weight < self.min_weight:
          weight = self.min_weight
          decay = weight / pre_weight
          if decay >  1.:
            decay = 1.

        logging.info('!decay count:', self.count, self.name, 'now:', weight)
        if not tf.executing_eagerly():
          self.sess.run(tf.assign(self.weight_op, tf.constant(weight, dtype=tf.float32)))
        else:
          self.weight_op = weight
        
        if 'learning_rate' in self.name:
          if not tf.executing_eagerly():
            melt.multiply_learning_rate(tf.constant(decay, dtype=tf.float32), self.sess)
          else:
            # TODO need to test eager mode
            #learning_rate =  tf.get_collection('learning_rate')[-1]
            #if learning_rate * decay > self.min_learning_rate:

            #tf.get_collection('learning_rate')[-1] *= decay
            tf.get_collection('learning_rate')[-1].assign(tf.get_collection('learning_rate')[-1] * decay)

    return weight
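
A minimal, framework-free sketch of the patience-based decay idea that the add() method above implements; the class and parameter names here are illustrative, not part of melt:

# Illustrative sketch only: decay a weight after max_patience non-improving scores.
class PatienceDecay:
  def __init__(self, weight, decay=0.5, max_patience=2, min_weight=None):
    self.weight = weight
    self.decay = decay
    self.max_patience = max_patience
    self.min_weight = min_weight
    self.best = None
    self.patience = 0

  def add(self, score):
    # assumes a larger score is better, mirroring the inferred cmp above
    if self.best is None or score > self.best:
      self.best = score
      self.patience = 0
    else:
      self.patience += 1
      if self.patience >= self.max_patience:
        self.patience = 0
        self.best = score
        self.weight *= self.decay
        if self.min_weight is not None:
          self.weight = max(self.weight, self.min_weight)
    return self.weight

decayer = PatienceDecay(weight=1.0, decay=0.5, max_patience=2)
for s in [0.60, 0.65, 0.64, 0.63, 0.66]:
  print(s, decayer.add(s))  # weight halves after two non-improving scores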
Example #2
def print_img(img, i):
    img_url = FLAGS.image_url_prefix + img if not img.startswith(
        "http://") else img
    logging.info(
        img_html.format(img_url, i, img, melt.epoch(), melt.step(),
                        melt.train_loss(), melt.eval_loss(), melt.duration(),
                        gezi.now_time()))
Example #3
  def add(self, scores):
    import melt.utils.logging as logging
    scores = np.array(scores)

    #print(scores.shape, self.scores.shape, len(self.names))
    logging.info('diff:', list(zip(self.names, scores - self.scores)))

    if not tf.executing_eagerly():
      weights = self.sess.run(self.weights_op)
      weights_ = weights
    else:
      weights = self.weights_op
      weights_ = weights.numpy()

    if (not self.cmp) and self.scores:
      if scores[0] > self.scores[0]:
        self.cmp = lambda x, y: x > y  
      else:
        self.cmp = lambda x, y: x < y
      logging.info('decay cmp:', self.cmp)
      
    # epoch is set during training loop
    epoch = melt.epoch()

    for i, score in enumerate(scores):
      if self.scores is None or self.cmp(score, self.scores[i]):
        self.scores[i] = score 
        self.patience[i] = 0
      else:
        self.patience[i] += 1        
        
        logging.info('patience_%s %d' % (self.names[i], self.patience[i]))
        if epoch < self.decay_start_epoch:
          continue

        if self.patience[i] >= self.max_patience:
          self.count[i] += 1
          self.patience[i] = 0
          self.scores[i] = score
          
          decay = self.decay if not isinstance(self.decay, (list, tuple)) else self.decay[i]

          weights_[i] *= decay

          if self.min_weight and weights_[i] < self.min_weight:
            weights_[i] = self.min_weight

          #logging.info('!%s decay count:%d decay ratio:%f lr ratios now:%f' % (self.names[i], self.count[i], self.decay, weights[i]))
          if not tf.executing_eagerly():
            self.sess.run(tf.assign(self.weights_op, tf.constant(weights_, dtype=tf.float32)))
          else:
            self.weights_op.assign(weights_)

    return weights_
          
Example #4
def print_img(img, i):
  img_url = get_img_url(img)
  logging.info(img_html.format(
    img_url, 
    i, 
    img, 
    melt.epoch(), 
    melt.step(), 
    melt.train_loss(), 
    melt.eval_loss(),
    melt.duration(),
    gezi.now_time()))
Example #5
    def rate(self, step=None):
        "Implement `lrate` above"
        if step is None:
            step = self._step

        warmup_percent_done = step / self.warmup
        warmup_learning_rate = self.lr * warmup_percent_done

        # decay by eval value ?

        if melt.epoch() >= 9 and FLAGS.num_epochs > 9:
            self.min_lr = self.ori_min_lr * (
                (FLAGS.num_epochs - melt.epoch()) / (FLAGS.num_epochs - 5))
        #print('-----------------', melt.epoch(), melt.epoch() > 9, self.min_lr, self.ori_min_lr)

        is_warmup = step < self.warmup
        learning_rate = lr_poly(self.lr, step, self.num_train_steps,
                                self.min_lr, 1.)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)
        #print('-----------------', is_warmup, warmup_percent_done, warmup_learning_rate, warmup_learning_rate)
        return learning_rate
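
A small self-contained sketch of the warmup-then-polynomial-decay schedule computed above; the exact lr_poly used by melt is not shown here, so the common polynomial (power=1, i.e. linear) decay formula is assumed:

# Assumed form of lr_poly: decay from base_lr to end_lr over max_steps with the given power.
def lr_poly(base_lr, step, max_steps, end_lr, power):
    step = min(step, max_steps)
    return (base_lr - end_lr) * (1 - step / max_steps) ** power + end_lr

def rate(step, base_lr=1e-3, min_lr=1e-5, warmup=1000, num_train_steps=10000):
    warmup_lr = base_lr * step / warmup  # linear warmup
    decayed_lr = lr_poly(base_lr, step, num_train_steps, min_lr, 1.)
    return warmup_lr if step < warmup else decayed_lr

for s in (0, 500, 1000, 5000, 10000):
    print(s, rate(s))  # ramps up to 1e-3, then decays linearly toward 1e-5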
Example #6
  def unk_aug(self, x, x_mask=None):
    """
    randomly turn about 10% of the words into unk
    TODO: this works, but should it be moved into Dataset so it can be shared by both pyt and tf?
    """
    if not self.training or not FLAGS.unk_aug or melt.epoch() < FLAGS.unk_aug_start_epoch:
      return x 

    if x_mask is None:
      x_mask = x > 0
    x_mask = x_mask.long()

    ratio = np.random.uniform(0, FLAGS.unk_aug_max_ratio)
    mask = torch.cuda.FloatTensor(x.size(0), x.size(1)).uniform_() > ratio
    mask = mask.long()
    rmask = FLAGS.unk_id * (1 - mask)

    x = (x * mask + rmask) * x_mask
    return x
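
A small CPU-only sketch of the masking arithmetic above: tokens are kept with probability 1 - ratio, dropped positions are replaced by unk_id, and padding positions stay zero. The unk_id and ratio values are made up for the example:

import torch

x = torch.tensor([[5, 7, 9, 0, 0]])  # token ids, 0 = padding
x_mask = (x > 0).long()              # 1 where a real token is
unk_id, ratio = 1, 0.3

keep = (torch.rand(x.size(0), x.size(1)) > ratio).long()  # 1 = keep, 0 = replace
rmask = unk_id * (1 - keep)          # unk_id at replaced positions
x_aug = (x * keep + rmask) * x_mask  # padding stays 0
print(x_aug)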
Example #7
    def unk_aug(self, x, x_mask=None):
        """
    randomly make some words as unk
    """
        if not self.training or not FLAGS.unk_aug or melt.epoch(
        ) < FLAGS.unk_aug_start_epoch:
            return x

        if x_mask is None:
            x_mask = x > 0
        x_mask = x_mask.long()

        ratio = np.random.uniform(0, FLAGS.unk_aug_max_ratio)
        mask = torch.cuda.FloatTensor(x.size(0), x.size(1)).uniform_() > ratio
        mask = mask.long()
        rmask = FLAGS.unk_id * (1 - mask)

        x = (x * mask + rmask) * x_mask
        return x
Example #8
def evaluate_scores(predictor, random=False):
    timer = gezi.Timer('evaluate_scores')
    init()
    imgs, img_features = get_image_names_and_features()

    num_metric_eval_examples = min(FLAGS.num_metric_eval_examples, len(imgs))
    step = FLAGS.metric_eval_batch_size

    if random:
        index = np.random.choice(len(imgs),
                                 num_metric_eval_examples,
                                 replace=False)
        imgs = imgs[index]
        img_features = img_features[index]

    text_max_words = all_distinct_texts.shape[1]
    rank_metrics = gezi.rank_metrics.RecallMetrics()

    print('text_max_words:', text_max_words)
    start = 0
    while start < num_metric_eval_examples:
        end = start + step
        if end > num_metric_eval_examples:
            end = num_metric_eval_examples
        print('predicts start:', start, 'end:', end, file=sys.stderr)
        predicts(imgs[start:end], img_features[start:end], predictor,
                 rank_metrics)
        start = end

    melt.logging_results(
        rank_metrics.get_metrics(),
        rank_metrics.get_names(),
        tag='evaluate: epoch:{} step:{} train:{} eval:{}'.format(
            melt.epoch(), melt.step(), melt.train_loss(), melt.eval_loss()))

    timer.print()

    return rank_metrics.get_metrics(), rank_metrics.get_names()
Example #9
def train(Dataset,
          model,
          loss_fn,
          evaluate_fn=None,
          inference_fn=None,
          eval_fn=None,
          write_valid=True,
          valid_names=None,
          infer_names=None,
          infer_debug_names=None,
          valid_write_fn=None,
          infer_write_fn=None,
          valid_suffix='.valid',
          infer_suffix='.infer',
          write_streaming=False,
          optimizer=None,
          param_groups=None,
          init_fn=None,
          dataset=None,
          valid_dataset=None,
          test_dataset=None,
          sep=','):
    if Dataset is None:
        assert dataset
    if FLAGS.torch:
        # https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
        model.to(device)

    input_ = FLAGS.train_input
    inputs = gezi.list_files(input_)
    inputs.sort()

    all_inputs = inputs

    #batch_size = FLAGS.batch_size
    batch_size = melt.batch_size()

    num_gpus = melt.num_gpus()

    #batch_size = max(batch_size, 1)
    #batch_size_ = batch_size if not FLAGS.batch_sizes else int(FLAGS.batch_sizes.split(',')[-1])
    batch_size_ = batch_size

    if FLAGS.fold is not None:
        inputs = [
            x for x in inputs if not x.endswith('%d.record' % FLAGS.fold)
            and not x.endswith('%d.tfrecord' % FLAGS.fold)
        ]
        # if FLAGS.valid_input:
        #   inputs += [x for x in gezi.list_files(FLAGS.valid_input) if not x.endswith('%d.record' % FLAGS.fold)]
    logging.info('inputs', len(inputs), inputs[:100])
    num_folds = FLAGS.num_folds or len(inputs) + 1

    train_dataset_ = dataset or Dataset('train')
    train_dataset = train_dataset_.make_batch(batch_size, inputs)
    num_examples = train_dataset_.num_examples_per_epoch('train')
    num_all_examples = num_examples

    valid_inputs = None
    if FLAGS.valid_input:
        valid_inputs = gezi.list_files(FLAGS.valid_input)
    else:
        if FLAGS.fold is not None:
            #valid_inputs = [x for x in all_inputs if x not in inputs]
            if not FLAGS.test_aug:
                valid_inputs = [
                    x for x in all_inputs if not 'aug' in x and x not in inputs
                ]
            else:
                valid_inputs = [
                    x for x in all_inputs if 'aug' in x and x not in inputs
                ]

    logging.info('valid_inputs', valid_inputs)

    if valid_inputs:
        valid_dataset_ = valid_dataset or Dataset('valid')
        valid_dataset = valid_dataset_.make_batch(batch_size_, valid_inputs)
        valid_dataset2 = valid_dataset_.make_batch(batch_size_,
                                                   valid_inputs,
                                                   repeat=True)
    else:
        valid_dataset = None
        valid_dataset2 = None

    if num_examples:
        if FLAGS.fold is not None:
            num_examples = int(num_examples * (num_folds - 1) / num_folds)
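        # ceil division: -(-a // b) == ceil(a / b)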
        num_steps_per_epoch = -(-num_examples // batch_size)
    else:
        num_steps_per_epoch = None
    logging.info('num_train_examples:', num_examples)

    num_valid_examples = None
    if FLAGS.valid_input:
        num_valid_examples = valid_dataset_.num_examples_per_epoch('valid')
        num_valid_steps_per_epoch = -(
            -num_valid_examples // batch_size_) if num_valid_examples else None
    else:
        if FLAGS.fold is not None:
            if num_examples:
                num_valid_examples = int(num_all_examples * (1 / num_folds))
                num_valid_steps_per_epoch = -(-num_valid_examples //
                                              batch_size_)
            else:
                num_valid_steps_per_epoch = None
    logging.info('num_valid_examples:', num_valid_examples)

    if FLAGS.test_input:
        test_inputs = gezi.list_files(FLAGS.test_input)
        #test_inputs = [x for x in test_inputs if not 'aug' in x]
        logging.info('test_inputs', test_inputs)
    else:
        test_inputs = None

    num_test_examples = None
    if test_inputs:
        test_dataset_ = test_dataset or Dataset('test')
        test_dataset = test_dataset_.make_batch(batch_size_, test_inputs)
        num_test_examples = test_dataset_.num_examples_per_epoch('test')
        num_test_steps_per_epoch = -(
            -num_test_examples // batch_size_) if num_test_examples else None
    else:
        test_dataset = None
    logging.info('num_test_examples:', num_test_examples)

    summary = tf.contrib.summary
    # writer = summary.create_file_writer(FLAGS.log_dir + '/epoch')
    # writer_train = summary.create_file_writer(FLAGS.log_dir + '/train')
    # writer_valid = summary.create_file_writer(FLAGS.log_dir + '/valid')
    writer = summary.create_file_writer(FLAGS.log_dir)
    writer_train = summary.create_file_writer(FLAGS.log_dir)
    writer_valid = summary.create_file_writer(FLAGS.log_dir)
    global_step = tf.train.get_or_create_global_step()

    learning_rate = tfe.Variable(FLAGS.learning_rate, name="learning_rate")

    tf.add_to_collection('learning_rate', learning_rate)

    learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
    try:
        learning_rate_weights = tf.get_collection('learning_rate_weights')[-1]
    except Exception:
        learning_rate_weights = None

    # ckpt dir save models one per epoch
    ckpt_dir = os.path.join(FLAGS.model_dir, 'ckpt')
    os.system('mkdir -p %s' % ckpt_dir)
    # HACK: ckpt2 dir actually saves mini-epoch checkpoints, e.g. when you set save_interval_epochs=0.1; this is useful when training on a large dataset
    ckpt_dir2 = os.path.join(FLAGS.model_dir, 'ckpt2')
    os.system('mkdir -p %s' % ckpt_dir2)

    #TODO FIXME tf code was changed so it does not, by default, keep only the latest 5 checkpoints
    # refer to https://github.com/tensorflow/tensorflow/issues/22036
    # manager = tf.contrib.checkpoint.CheckpointManager(
    #     checkpoint, directory=ckpt_dir, max_to_keep=5)
    # latest_checkpoint = manager.latest_checkpoint

    latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
    if latest_checkpoint:
        logging.info('Latest checkpoint:', latest_checkpoint)
    else:
        latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir2)
        logging.info('Latest checkpoint:', latest_checkpoint)

    if os.path.exists(FLAGS.model_dir + '.index'):
        latest_checkpoint = FLAGS.model_dir

    if 'test' in FLAGS.work_mode or 'valid' in FLAGS.work_mode:
        #assert not os.path.isdir(FLAGS.model_dir), FLAGS.model_dir
        latest_checkpoint = FLAGS.model_dir
        #assert os.path.exists(latest_checkpoint) and os.path.isfile(latest_checkpoint)

    checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt')
    checkpoint_prefix2 = os.path.join(ckpt_dir2, 'ckpt')

    if not FLAGS.torch:
        try:
            optimizer = optimizer or melt.get_optimizer(
                FLAGS.optimizer)(learning_rate)
        except Exception:
            logging.warning(
                f'Failed to use {FLAGS.optimizer}, using adam instead')
            optimizer = melt.get_optimizer('adam')(learning_rate)

        # TODO...
        if learning_rate_weights is None:
            checkpoint = tf.train.Checkpoint(
                learning_rate=learning_rate,
                learning_rate_weight=learning_rate_weight,
                model=model,
                optimizer=optimizer,
                global_step=global_step)
        else:
            checkpoint = tf.train.Checkpoint(
                learning_rate=learning_rate,
                learning_rate_weight=learning_rate_weight,
                learning_rate_weights=learning_rate_weights,
                model=model,
                optimizer=optimizer,
                global_step=global_step)

        checkpoint.restore(latest_checkpoint)
        checkpoint2 = copy.deepcopy(checkpoint)

        start_epoch = int(
            latest_checkpoint.split('-')
            [-1]) if latest_checkpoint and 'ckpt' in latest_checkpoint else 0
    else:
        # TODO torch with learning rate adjust
        is_dynamic_opt = False  # ensure defined even when an optimizer is passed in
        if optimizer is None:
            import lele
            is_dynamic_opt = True
            if FLAGS.optimizer == 'noam':
                optimizer = lele.training.optimizers.NoamOpt(
                    128, 2, 4000, torch.optim.Adamax(model.parameters(), lr=0))
            elif FLAGS.optimizer == 'bert':
                num_train_steps = int(
                    num_steps_per_epoch *
                    (FLAGS.num_decay_epochs or FLAGS.num_epochs))
                num_warmup_steps = FLAGS.warmup_steps or int(
                    num_train_steps * FLAGS.warmup_proportion)
                logging.info('num_train_steps', num_train_steps,
                             'num_warmup_steps', num_warmup_steps,
                             'warmup_proportion', FLAGS.warmup_proportion)
                optimizer = lele.training.optimizers.BertOpt(
                    FLAGS.learning_rate, FLAGS.min_learning_rate,
                    num_train_steps, num_warmup_steps,
                    torch.optim.Adamax(model.parameters(), lr=0))
            else:
                is_dynamic_opt = False
                optimizer = torch.optim.Adamax(
                    param_groups if param_groups else model.parameters(),
                    lr=FLAGS.learning_rate)

        start_epoch = 0
        latest_path = latest_checkpoint + '.pyt' if latest_checkpoint else os.path.join(
            FLAGS.model_dir, 'latest.pyt')
        if not os.path.exists(latest_path):
            latest_path = os.path.join(FLAGS.model_dir, 'latest.pyt')
        if os.path.exists(latest_path):
            logging.info('loading torch model from', latest_path)
            checkpoint = torch.load(latest_path)
            if not FLAGS.torch_finetune:
                start_epoch = checkpoint['epoch']
                step = checkpoint['step']
                global_step.assign(step + 1)
            load_torch_model(model, latest_path)
            if FLAGS.torch_load_optimizer:
                optimizer.load_state_dict(checkpoint['optimizer'])

        # TODO restarting this way cannot change the learning rate..
        if learning_rate_weights is None:
            checkpoint = tf.train.Checkpoint(
                learning_rate=learning_rate,
                learning_rate_weight=learning_rate_weight,
                global_step=global_step)
        else:
            checkpoint = tf.train.Checkpoint(
                learning_rate=learning_rate,
                learning_rate_weight=learning_rate_weight,
                learning_rate_weights=learning_rate_weights,
                global_step=global_step)

        try:
            checkpoint.restore(latest_checkpoint)
            checkpoint2 = copy.deepcopy(checkpoint)
        except Exception:
            pass

    if FLAGS.torch and is_dynamic_opt:
        optimizer._step = global_step.numpy()

    #model.load_weights(os.path.join(ckpt_dir, 'ckpt-1'))
    #model.save('./weight3.hd5')
    logging.info('optimizer:', optimizer)

    if FLAGS.torch_lr:
        learning_rate.assign(optimizer.rate(1))
    if FLAGS.torch:
        learning_rate.assign(optimizer.param_groups[0]['lr'])
        logging.info('learning rate restored from pytorch latest.pyt as',
                     learning_rate)

    learning_rate.assign(learning_rate * FLAGS.learning_rate_start_factor)
    if learning_rate_weights is not None:
        learning_rate_weights.assign(learning_rate_weights *
                                     FLAGS.learning_rate_start_factor)

    # TODO currently does not support fractional epochs like 0.1
    num_epochs = FLAGS.num_epochs if FLAGS.num_epochs != 0 else 1024

    will_valid = valid_dataset and not FLAGS.work_mode == 'test' and not 'SHOW' in os.environ and not 'QUICK' in os.environ
    if start_epoch == 0 and not 'EVFIRST' in os.environ and will_valid:
        will_valid = False

    if start_epoch > 0 and will_valid:
        will_valid = True

    if will_valid:
        logging.info('----------valid')
        if FLAGS.torch:
            model.eval()
        names = None
        if evaluate_fn is not None:
            vals, names = evaluate_fn(model, valid_dataset,
                                      tf.train.latest_checkpoint(ckpt_dir),
                                      num_valid_steps_per_epoch)
        elif eval_fn:
            model_path = None if not write_valid else latest_checkpoint
            names = valid_names if valid_names is not None else [
                infer_names[0]
            ] + [x + '_y' for x in infer_names[1:]
                 ] + infer_names[1:] if infer_names else None

            logging.info('model_path:', model_path, 'model_dir:',
                         FLAGS.model_dir)
            vals, names = evaluate(model,
                                   valid_dataset,
                                   eval_fn,
                                   model_path,
                                   names,
                                   valid_write_fn,
                                   write_streaming,
                                   num_valid_steps_per_epoch,
                                   suffix=valid_suffix,
                                   sep=sep)
        if names:
            logging.info2(
                'epoch:%d/%d' % (start_epoch, num_epochs),
                ['%s:%.5f' % (name, val) for name, val in zip(names, vals)])

    if FLAGS.work_mode == 'valid':
        exit(0)

    if 'test' in FLAGS.work_mode:
        logging.info('--------test/inference')
        if test_dataset:
            if FLAGS.torch:
                model.eval()
            if inference_fn is None:
                # model_path = FLAGS.model_dir + '.pyt' if not latest_checkpoint else latest_checkpoint
                # logging.info('model_path', model_path)
                assert latest_checkpoint
                inference(model,
                          test_dataset,
                          latest_checkpoint,
                          infer_names,
                          infer_debug_names,
                          infer_write_fn,
                          write_streaming,
                          num_test_steps_per_epoch,
                          suffix=infer_suffix)
            else:
                inference_fn(model, test_dataset,
                             tf.train.latest_checkpoint(ckpt_dir),
                             num_test_steps_per_epoch)
        exit(0)

    if 'SHOW' in os.environ:
        num_epochs = start_epoch + 1

    class PytObj(object):
        def __init__(self, x):
            self.x = x

        def numpy(self):
            return self.x

    class PytMean(object):
        def __init__(self):
            self._val = 0.
            self.count = 0

            self.is_call = True

        def clear(self):
            self._val = 0
            self.count = 0

        def __call__(self, val):
            if not self.is_call:
                self.clear()
                self.is_call = True
            self._val += val.item()
            self.count += 1

        def result(self):
            if self.is_call:
                self.is_call = False
            if not self.count:
                val = 0
            else:
                val = self._val / self.count
            # TODO just for compact with tf ..
            return PytObj(val)

    Mean = tfe.metrics.Mean if not FLAGS.torch else PytMean
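    # PytMean/PytObj above mimic the tfe.metrics.Mean interface (call to accumulate,
    # .result().numpy() to read) so the logging code below works for both backends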
    timer = gezi.Timer()
    num_insts = 0

    if FLAGS.learning_rate_decay_factor > 0:
        #assert FLAGS.learning_rate_values is None, 'use exponential_decay or piecewise_constant?'
        #NOTICE if you do finetune or other things which might change batch_size then you'd better directly set num_steps_per_decay
        #since global_step / decay_steps will no longer map to the correct epoch once num_steps per epoch changes
        #so if you change batch size you have to reset global step as a fixed step
        assert FLAGS.num_steps_per_decay or (
            FLAGS.num_epochs_per_decay and num_steps_per_epoch
        ), 'must set num_steps_per_decay or (num_epochs_per_decay and num_steps_per_epoch)'
        decay_steps = FLAGS.num_steps_per_decay or int(
            num_steps_per_epoch * FLAGS.num_epochs_per_decay)
        decay_start_step = FLAGS.decay_start_step or int(
            num_steps_per_epoch * FLAGS.decay_start_epoch)
        # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
        logging.info(
            'learning_rate_decay_factor:{} decay_epochs:{} decay_steps:{} decay_start_epoch:{} decay_start_step:{}'
            .format(FLAGS.learning_rate_decay_factor,
                    FLAGS.num_epochs_per_decay, decay_steps,
                    FLAGS.decay_start_epoch, decay_start_step))

    for epoch in range(start_epoch, num_epochs):
        melt.set_global('epoch', '%.4f' % (epoch))

        if FLAGS.torch:
            model.train()

        epoch_loss_avg = Mean()
        epoch_valid_loss_avg = Mean()

        #for i, (x, y) in tqdm(enumerate(train_dataset), total=num_steps_per_epoch, ascii=True):
        for i, (x, y) in enumerate(train_dataset):
            if FLAGS.torch:
                x, y = to_torch(x, y)
                if is_dynamic_opt:
                    learning_rate.assign(optimizer.rate())

            #print(x, y)

            if not FLAGS.torch:
                loss, grads = melt.eager.grad(model, x, y, loss_fn)
                grads, _ = tf.clip_by_global_norm(grads, FLAGS.clip_gradients)
                optimizer.apply_gradients(zip(grads, model.variables))
            else:
                optimizer.zero_grad()
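                # pass training=True only if loss_fn accepts it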
                if 'training' in inspect.getargspec(loss_fn).args:
                    loss = loss_fn(model, x, y, training=True)
                else:
                    loss = loss_fn(model, x, y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               FLAGS.clip_gradients)
                optimizer.step()

            global_step.assign_add(1)

            epoch_loss_avg(loss)  # add current batch loss

            if FLAGS.torch:
                del loss

            batch_size_ = list(x.values())[0].shape[FLAGS.batch_size_dim] \
                if isinstance(x, dict) else x.shape[FLAGS.batch_size_dim]
            num_insts += int(batch_size_)
            if global_step.numpy() % FLAGS.interval_steps == 0:
                #checkpoint.save(checkpoint_prefix)
                elapsed = timer.elapsed()
                steps_per_second = FLAGS.interval_steps / elapsed
                instances_per_second = num_insts / elapsed
                num_insts = 0

                if num_steps_per_epoch is None:
                    epoch_time_info = ''
                else:
                    hours_per_epoch = num_steps_per_epoch / FLAGS.interval_steps * elapsed / 3600
                    epoch_time_info = '1epoch:[{:.2f}h]'.format(
                        hours_per_epoch)

                if valid_dataset2:
                    try:
                        x, y = next(iter(valid_dataset2))
                    except Exception:
                        # TODO FIXME iterator stopped, restart it; hack for my iterator, see projects/lm/dataset
                        x, y = next(iter(valid_dataset2))

                    if FLAGS.torch:
                        x, y = to_torch(x, y)
                        model.eval()
                    valid_loss = loss_fn(model, x, y)
                    epoch_valid_loss_avg(valid_loss)
                    if FLAGS.torch:
                        model.train()

                    logging.info2(
                        'epoch:%.3f/%d' %
                        ((epoch + i / num_steps_per_epoch), num_epochs),
                        'step:%d' % global_step.numpy(), 'elapsed:[%.3f]' %
                        elapsed, 'batch_size:[%d]' % batch_size_, 'gpus:[%d]' %
                        num_gpus, 'batches/s:[%.2f]' % steps_per_second,
                        'insts/s:[%d]' % instances_per_second, '%s' %
                        epoch_time_info, 'lr:[%.8f]' % learning_rate.numpy(),
                        'train_loss:[%.4f]' % epoch_loss_avg.result().numpy(),
                        'valid_loss:[%.4f]' %
                        epoch_valid_loss_avg.result().numpy())
                    if global_step.numpy() % FLAGS.eval_interval_steps == 0:
                        with writer_valid.as_default(
                        ), summary.always_record_summaries():
                            #summary.scalar('step/loss', epoch_valid_loss_avg.result().numpy())
                            summary.scalar(
                                'loss/eval',
                                epoch_valid_loss_avg.result().numpy())
                            writer_valid.flush()
                else:
                    logging.info2(
                        'epoch:%.3f/%d' %
                        ((epoch + i / num_steps_per_epoch), num_epochs),
                        'step:%d' % global_step.numpy(), 'elapsed:[%.3f]' %
                        elapsed, 'batch_size:[%d]' % batch_size_,
                        'gpus:[%d]' % num_gpus,
                        'batches/s:[%.2f]' % steps_per_second,
                        'insts/s:[%d]' % instances_per_second,
                        '%s' % epoch_time_info,
                        'lr:[%.8f]' % learning_rate.numpy(),
                        'train_loss:[%.4f]' % epoch_loss_avg.result().numpy())

                if global_step.numpy() % FLAGS.eval_interval_steps == 0:
                    with writer_train.as_default(
                    ), summary.always_record_summaries():
                        #summary.scalar('step/loss', epoch_loss_avg.result().numpy())
                        summary.scalar('loss/train_avg',
                                       epoch_loss_avg.result().numpy())
                        summary.scalar('learning_rate', learning_rate.numpy())
                        summary.scalar('batch_size', batch_size_)
                        summary.scalar('epoch', melt.epoch())
                        summary.scalar('steps_per_second', steps_per_second)
                        summary.scalar('instances_per_second',
                                       instances_per_second)
                        writer_train.flush()

                    if FLAGS.log_dir != FLAGS.model_dir:
                        assert FLAGS.log_dir
                        command = 'rsync -l -r -t %s/* %s' % (FLAGS.log_dir,
                                                              FLAGS.model_dir)
                        print(command, file=sys.stderr)
                        os.system(command)

            if valid_dataset and FLAGS.metric_eval_interval_steps and global_step.numpy(
            ) and global_step.numpy() % FLAGS.metric_eval_interval_steps == 0:
                if FLAGS.torch:
                    model.eval()
                vals, names = None, None
                if evaluate_fn is not None:
                    vals, names = evaluate_fn(model, valid_dataset, None,
                                              num_valid_steps_per_epoch)
                elif eval_fn:
                    names = valid_names if valid_names is not None else [
                        infer_names[0]
                    ] + [x + '_y' for x in infer_names[1:]
                         ] + infer_names[1:] if infer_names else None
                    vals, names = evaluate(model,
                                           valid_dataset,
                                           eval_fn,
                                           None,
                                           names,
                                           valid_write_fn,
                                           write_streaming,
                                           num_valid_steps_per_epoch,
                                           sep=sep)
                if vals and names:
                    with writer_valid.as_default(
                    ), summary.always_record_summaries():
                        for name, val in zip(names, vals):
                            summary.scalar(f'step/valid/{name}', val)
                        writer_valid.flush()

                if FLAGS.torch:
                    if not FLAGS.torch_lr:
                        # control learning rate by tensorflow learning rate
                        for param_group in optimizer.param_groups:
                            # important learning rate decay
                            param_group['lr'] = learning_rate.numpy()

                    model.train()

                if names and vals:
                    logging.info2(
                        'epoch:%.3f/%d' %
                        ((epoch + i / num_steps_per_epoch), num_epochs),
                        'valid_step:%d' % global_step.numpy(), 'valid_metrics',
                        [
                            '%s:%.5f' % (name, val)
                            for name, val in zip(names, vals)
                        ])

            # if i == 5:
            #   print(i, '---------------------save')
            #   print(len(model.trainable_variables))
            ## TODO FIXME saving weights this way seems not ok... values differ from checkpoint save
            #   model.save_weights(os.path.join(ckpt_dir, 'weights'))
            #   checkpoint.save(checkpoint_prefix)
            #   exit(0)

            if global_step.numpy() % FLAGS.save_interval_steps == 0:
                if FLAGS.torch:
                    state = {
                        'epoch':
                        epoch,
                        'step':
                        global_step.numpy(),
                        'state_dict':
                        model.state_dict() if not hasattr(model, 'module') else
                        model.module.state_dict(),
                        'optimizer':
                        optimizer.state_dict(),
                    }
                    torch.save(state,
                               os.path.join(FLAGS.model_dir, 'latest.pyt'))

            # TODO fixme why using both checkpoint2 and checkpoint together is not ok..
            if FLAGS.save_interval_epochs and FLAGS.save_interval_epochs < 1 and global_step.numpy(
            ) % int(num_steps_per_epoch * FLAGS.save_interval_epochs) == 0:
                #if FLAGS.save_interval_epochs and global_step.numpy() % int(num_steps_per_epoch * FLAGS.save_interval_epochs) == 0:
                checkpoint2.save(checkpoint_prefix2)
                if FLAGS.torch:
                    state = {
                        'epoch':
                        epoch,
                        'step':
                        global_step.numpy(),
                        'state_dict':
                        model.state_dict() if not hasattr(model, 'module') else
                        model.module.state_dict(),
                        'optimizer':
                        optimizer.state_dict(),
                    }
                    torch.save(state,
                               tf.train.latest_checkpoint(ckpt_dir2) + '.pyt')

            if FLAGS.learning_rate_decay_factor > 0:
                if global_step.numpy(
                ) >= decay_start_step and global_step.numpy(
                ) % decay_steps == 0:
                    lr = max(
                        learning_rate.numpy() *
                        FLAGS.learning_rate_decay_factor,
                        FLAGS.min_learning_rate)
                    if lr < learning_rate.numpy():
                        learning_rate.assign(lr)
                        if FLAGS.torch:
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = learning_rate.numpy()

            if epoch == start_epoch and i == 0:
                try:
                    if not FLAGS.torch:
                        logging.info(model.summary())
                except Exception:
                    traceback.print_exc()
                    logging.info(
                        'Failed to do model.summary(), maybe you have a layer defined in __init__ but not used in call'
                    )
                if 'SHOW' in os.environ:
                    exit(0)

        logging.info2(
            'epoch:%d/%d' % (epoch + 1, num_epochs),
            'step:%d' % global_step.numpy(), 'batch_size:[%d]' % batch_size,
            'gpus:[%d]' % num_gpus, 'lr:[%.8f]' % learning_rate.numpy(),
            'train_loss:[%.4f]' % epoch_loss_avg.result().numpy(),
            'valid_loss:[%.4f]' % epoch_valid_loss_avg.result().numpy())

        timer = gezi.Timer(
            f'save model to {checkpoint_prefix}-{checkpoint.save_counter.numpy() + 1}',
            False)
        checkpoint.save(checkpoint_prefix)
        if FLAGS.torch and FLAGS.save_interval_epochs == 1:
            state = {
                'epoch':
                epoch + 1,
                'step':
                global_step.numpy(),
                'state_dict':
                model.state_dict()
                if not hasattr(model, 'module') else model.module.state_dict(),
                'optimizer':
                optimizer.state_dict(),
            }
            torch.save(state, tf.train.latest_checkpoint(ckpt_dir) + '.pyt')

        timer.print_elapsed()

        if valid_dataset and (epoch + 1) % FLAGS.valid_interval_epochs == 0:
            if FLAGS.torch:
                model.eval()

            vals, names = None, None
            if evaluate_fn is not None:
                vals, names = evaluate_fn(model, valid_dataset,
                                          tf.train.latest_checkpoint(ckpt_dir),
                                          num_valid_steps_per_epoch)
            elif eval_fn:
                model_path = None if not write_valid else tf.train.latest_checkpoint(
                    ckpt_dir)
                names = valid_names if valid_names is not None else [
                    infer_names[0]
                ] + [x + '_y' for x in infer_names[1:]
                     ] + infer_names[1:] if infer_names else None

                vals, names = evaluate(model,
                                       valid_dataset,
                                       eval_fn,
                                       model_path,
                                       names,
                                       valid_write_fn,
                                       write_streaming,
                                       num_valid_steps_per_epoch,
                                       suffix=valid_suffix,
                                       sep=sep)

            if vals and names:
                logging.info2('epoch:%d/%d' % (epoch + 1, num_epochs),
                              'step:%d' % global_step.numpy(),
                              'epoch_valid_metrics', [
                                  '%s:%.5f' % (name, val)
                                  for name, val in zip(names, vals)
                              ])

        with writer.as_default(), summary.always_record_summaries():
            temp = global_step.value()
            global_step.assign(epoch + 1)
            summary.scalar('epoch/train/loss', epoch_loss_avg.result().numpy())
            if valid_dataset:
                if FLAGS.torch:
                    model.eval()
                if vals and names:
                    for name, val in zip(names, vals):
                        summary.scalar(f'epoch/valid/{name}', val)
            writer.flush()
            global_step.assign(temp)

        if test_dataset and (epoch + 1) % FLAGS.inference_interval_epochs == 0:
            if FLAGS.torch:
                model.eval()
            if inference_fn is None:
                inference(model,
                          test_dataset,
                          tf.train.latest_checkpoint(ckpt_dir),
                          infer_names,
                          infer_debug_names,
                          infer_write_fn,
                          write_streaming,
                          num_test_steps_per_epoch,
                          suffix=infer_suffix,
                          sep=sep)
            else:
                inference_fn(model, test_dataset,
                             tf.train.latest_checkpoint(ckpt_dir),
                             num_test_steps_per_epoch)

    if FLAGS.log_dir != FLAGS.model_dir:
        assert FLAGS.log_dir
        command = 'rsync -l -r -t %s/* %s' % (FLAGS.log_dir, FLAGS.model_dir)
        print(command, file=sys.stderr)
        os.system(command)
        command = 'rm -rf %s/latest.pyt.*' % (FLAGS.model_dir)
        print(command, file=sys.stderr)
        os.system(command)
Example #10
def train_once(
    sess,
    step,
    ops,
    names=None,
    gen_feed_dict_fn=None,
    deal_results_fn=None,
    interval_steps=100,
    eval_ops=None,
    eval_names=None,
    gen_eval_feed_dict_fn=None,
    deal_eval_results_fn=melt.print_results,
    valid_interval_steps=100,
    print_time=True,
    print_avg_loss=True,
    model_dir=None,
    log_dir=None,
    is_start=False,
    num_steps_per_epoch=None,
    metric_eval_fn=None,
    metric_eval_interval_steps=0,
    summary_excls=None,
    fixed_step=None,  # for epoch only, in case you change batch size
    eval_loops=1,
    learning_rate=None,
    learning_rate_patience=None,
    learning_rate_decay_factor=None,
    num_epochs=None,
    model_path=None,
    use_horovod=False,
):
    use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ

    #is_start = False # force not to evaluate at first step
    #print('-----------------global_step', sess.run(tf.train.get_or_create_global_step()))
    timer = gezi.Timer()
    if print_time:
        if not hasattr(train_once, 'timer'):
            train_once.timer = Timer()
            train_once.eval_timer = Timer()
            train_once.metric_eval_timer = Timer()

    melt.set_global('step', step)
    epoch = (fixed_step
             or step) / num_steps_per_epoch if num_steps_per_epoch else -1
    if not num_epochs:
        epoch_str = 'epoch:%.3f' % (epoch) if num_steps_per_epoch else ''
    else:
        epoch_str = 'epoch:%.3f/%d' % (
            epoch, num_epochs) if num_steps_per_epoch else ''
    melt.set_global('epoch', '%.2f' % (epoch))

    info = IO()
    stop = False

    if eval_names is None:
        if names:
            # derive from names; the 'eval/' prefix is added once below
            eval_names = list(names)

    if names:
        names = ['train/' + x for x in names]

    if eval_names:
        eval_names = ['eval/' + x for x in eval_names]

    is_eval_step = is_start or valid_interval_steps and step % valid_interval_steps == 0
    summary_str = []

    eval_str = ''
    if is_eval_step:
        # deal with summary
        if log_dir:
            if not hasattr(train_once, 'summary_op'):
                #melt.print_summary_ops()
                if summary_excls is None:
                    train_once.summary_op = tf.summary.merge_all()
                else:
                    summary_ops = []
                    for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
                        for summary_excl in summary_excls:
                            if not summary_excl in op.name:
                                summary_ops.append(op)
                    print('filtered summary_ops:')
                    for op in summary_ops:
                        print(op)
                    train_once.summary_op = tf.summary.merge(summary_ops)

                #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
                train_once.summary_writer = tf.summary.FileWriter(
                    log_dir, sess.graph)

                tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
                    train_once.summary_writer, projector_config)

        # if there are eval ops then this should have been rank 0

        if eval_ops:
            #if deal_eval_results_fn is None and eval_names is not None:
            #  deal_eval_results_fn = lambda x: melt.print_results(x, eval_names)
            for i in range(eval_loops):
                eval_feed_dict = {} if gen_eval_feed_dict_fn is None else gen_eval_feed_dict_fn(
                )
                #eval_feed_dict.update(feed_dict)

                # if using horovod let each rank use the same sess.run!
                if not log_dir or train_once.summary_op is None or gezi.env_has(
                        'EVAL_NO_SUMMARY') or use_horovod:
                    #if not log_dir or train_once.summary_op is None:
                    eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
                else:
                    eval_results = sess.run(eval_ops + [train_once.summary_op],
                                            feed_dict=eval_feed_dict)
                    summary_str = eval_results[-1]
                    eval_results = eval_results[:-1]
                eval_loss = gezi.get_singles(eval_results)
                #timer_.print()
                eval_stop = False
                if use_horovod:
                    sess.run(hvd.allreduce(tf.constant(0)))

                #if not use_horovod or  hvd.local_rank() == 0:
                # @TODO user print should also use logging as a must ?
                #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, 'eval_metrics:', end='')
                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                #if not use_horovod or hvd.rank() == 0:
                #  logging.info2('{} eval_step:{} eval_metrics:{}'.format(epoch_str, step, melt.parse_results(eval_loss, eval_names_)))
                eval_str = 'valid:{}'.format(
                    melt.parse_results(eval_loss, eval_names_))

                # if deal_eval_results_fn is not None:
                #   eval_stop = deal_eval_results_fn(eval_results)

                assert len(eval_loss) > 0
                if eval_stop is True:
                    stop = True
                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                if not use_horovod or hvd.rank() == 0:
                    melt.set_global('eval_loss',
                                    melt.parse_results(eval_loss, eval_names_))

        elif interval_steps != valid_interval_steps:
            #print()
            pass

    metric_evaluate = False

    # if metric_eval_fn is not None \
    #   and (is_start \
    #     or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
    #          or (metric_eval_interval_steps \
    #              and step % metric_eval_interval_steps == 0)):
    #  metric_evaluate = True

    if metric_eval_fn is not None \
      and ((is_start or metric_eval_interval_steps \
           and step % metric_eval_interval_steps == 0) or model_path):
        metric_evaluate = True

    if 'EVFIRST' in os.environ:
        if os.environ['EVFIRST'] == '0':
            if is_start:
                metric_evaluate = False
        else:
            if is_start:
                metric_evaluate = True

    if step == 0 or 'QUICK' in os.environ:
        metric_evaluate = False

    #print('------------1step', step, 'pre metric_evaluate', metric_evaluate, hvd.rank())
    if metric_evaluate:
        if use_horovod:
            print('------------metric evaluate step', step, model_path,
                  hvd.rank())
        if not model_path or 'model_path' not in inspect.getargspec(
                metric_eval_fn).args:
            metric_eval_fn_ = metric_eval_fn
        else:
            metric_eval_fn_ = lambda: metric_eval_fn(model_path=model_path)

        evaluate_results, evaluate_names, evaluate_summaries = [], [], None
        try:
            l = metric_eval_fn_()
            if isinstance(l, tuple):
                num_returns = len(l)
                if num_returns == 2:
                    evaluate_results, evaluate_names = l
                    evaluate_summaries = None
                else:
                    assert num_returns == 3, 'returning 2 or 3 values is ok, more is not'
                    evaluate_results, evaluate_names, evaluate_summaries = l
            else:  # returned a dict of {name: value}
                evaluate_names, evaluate_results = tuple(zip(*l.items()))
                evaluate_summaries = None
        except Exception:
            logging.info('Do nothing for metric eval fn with exception:\n',
                         traceback.format_exc())

        if not use_horovod or hvd.rank() == 0:
            #logging.info2('{} valid_step:{} {}:{}'.format(epoch_str, step, 'valid_metrics' if model_path is None else 'epoch_valid_metrics', melt.parse_results(evaluate_results, evaluate_names)))
            logging.info2('{} valid_step:{} {}:{}'.format(
                epoch_str, step, 'valid_metrics',
                melt.parse_results(evaluate_results, evaluate_names)))

        if learning_rate is not None and (learning_rate_patience
                                          and learning_rate_patience > 0):
            assert learning_rate_decay_factor > 0 and learning_rate_decay_factor < 1
            valid_loss = evaluate_results[0]
            if not hasattr(train_once, 'min_valid_loss'):
                train_once.min_valid_loss = valid_loss
                train_once.decay_steps = []
                train_once.patience = 0
            else:
                if valid_loss < train_once.min_valid_loss:
                    train_once.min_valid_loss = valid_loss
                    train_once.patience = 0
                else:
                    train_once.patience += 1
                    logging.info2('{} valid_step:{} patience:{}'.format(
                        epoch_str, step, train_once.patience))

            if learning_rate_patience and train_once.patience >= learning_rate_patience:
                lr_op = ops[1]
                lr = sess.run(lr_op) * learning_rate_decay_factor
                train_once.decay_steps.append(step)
                logging.info2(
                    '{} valid_step:{} learning_rate_decay by *{}, learning_rate_decay_steps={}'
                    .format(epoch_str, step, learning_rate_decay_factor,
                            ','.join(map(str, train_once.decay_steps))))
                sess.run(tf.assign(lr_op, tf.constant(lr, dtype=tf.float32)))
                train_once.patience = 0
                train_once.min_valid_loss = valid_loss

    if ops is not None:
        #if deal_results_fn is None and names is not None:
        #  deal_results_fn = lambda x: melt.print_results(x, names)

        feed_dict = {} if gen_feed_dict_fn is None else gen_feed_dict_fn()
        # NOTICE ops[2] should be scalar otherwise wrong!! loss should be scalar
        #print('---------------ops', ops)
        if eval_ops is not None or not log_dir or not hasattr(
                train_once,
                'summary_op') or train_once.summary_op is None or use_horovod:
            feed_dict[K.learning_phase()] = 1
            results = sess.run(ops, feed_dict=feed_dict)
        else:
            ## TODO why below ?
            #try:
            feed_dict[K.learning_phase()] = 1
            results = sess.run(ops + [train_once.summary_op],
                               feed_dict=feed_dict)
            summary_str = results[-1]
            results = results[:-1]
            # except Exception:
            #   logging.info('sess.run(ops + [train_once.summary_op], feed_dict=feed_dict) fail')
            #   results = sess.run(ops, feed_dict=feed_dict)

        #print('------------results', results)
        # #--------trace debug
        # if step == 210:
        #   run_metadata = tf.RunMetadata()
        #   results = sess.run(
        #         ops,
        #         feed_dict=feed_dict,
        #         options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
        #         run_metadata=run_metadata)
        #   from tensorflow.python.client import timeline
        #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)

        #   trace_file = open('timeline.ctf.json', 'w')
        #   trace_file.write(trace.generate_chrome_trace_format())

        #results[0] assumed to be train_op, results[1] to be learning_rate
        learning_rate = results[1]
        results = results[2:]

        #@TODO should support avg loss and other avg evaluations like test..
        if print_avg_loss:
            if not hasattr(train_once, 'avg_loss'):
                train_once.avg_loss = AvgScore()
            #assume results[0] as train_op return, results[1] as loss
            loss = gezi.get_singles(results)
            train_once.avg_loss.add(loss)

        steps_per_second = None
        instances_per_second = None
        hours_per_epoch = None
        #step += 1
        #if is_start or interval_steps and step % interval_steps == 0:
        interval_ok = not use_horovod or hvd.local_rank() == 0
        if interval_steps and step % interval_steps == 0 and interval_ok:
            train_average_loss = train_once.avg_loss.avg_score()
            if print_time:
                duration = timer.elapsed()
                duration_str = 'duration:{:.2f} '.format(duration)
                melt.set_global('duration', '%.2f' % duration)
                #info.write(duration_str)
                elapsed = train_once.timer.elapsed()
                steps_per_second = interval_steps / elapsed
                batch_size = melt.batch_size()
                num_gpus = melt.num_gpus()
                instances_per_second = interval_steps * batch_size / elapsed
                gpu_info = '' if num_gpus <= 1 else ' gpus:[{}]'.format(
                    num_gpus)
                if num_steps_per_epoch is None:
                    epoch_time_info = ''
                else:
                    hours_per_epoch = num_steps_per_epoch / interval_steps * elapsed / 3600
                    epoch_time_info = '1epoch:[{:.2f}h]'.format(
                        hours_per_epoch)
                info.write(
                    'elapsed:[{:.2f}] batch_size:[{}]{} batches/s:[{:.2f}] insts/s:[{:.2f}] {} lr:[{:.6f}]'
                    .format(elapsed, batch_size, gpu_info, steps_per_second,
                            instances_per_second, epoch_time_info,
                            learning_rate))

            if print_avg_loss:
                #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
                names_ = melt.adjust_names(train_average_loss, names)
                #info.write('train_avg_metric:{} '.format(melt.parse_results(train_average_loss, names_)))
                info.write(' train:{} '.format(
                    melt.parse_results(train_average_loss, names_)))
                #info.write('train_avg_loss: {} '.format(train_average_loss))
            info.write(eval_str)
            #print(gezi.now_time(), epoch_str, 'train_step:%d'%step, info.getvalue(), end=' ')
            logging.info2('{} {} {}'.format(epoch_str, 'step:%d' % step,
                                            info.getvalue()))

            if deal_results_fn is not None:
                stop = deal_results_fn(results)

    summary_strs = gezi.to_list(summary_str)
    if metric_evaluate:
        if evaluate_summaries is not None:
            summary_strs += evaluate_summaries

    if step > 1:
        if is_eval_step:
            # deal with summary
            if log_dir:
                summary = tf.Summary()
                if eval_ops is None:
                    if train_once.summary_op is not None:
                        for summary_str in summary_strs:
                            train_once.summary_writer.add_summary(
                                summary_str, step)
                else:
                    for summary_str in summary_strs:
                        train_once.summary_writer.add_summary(
                            summary_str, step)
                    suffix = 'valid' if not eval_names else ''
                    # loss/valid
                    melt.add_summarys(summary,
                                      eval_results,
                                      eval_names_,
                                      suffix=suffix)

                if ops is not None:
                    try:
                        # loss/train_avg
                        melt.add_summarys(summary,
                                          train_average_loss,
                                          names_,
                                          suffix='train_avg')
                    except Exception:
                        pass
                    ##optimizer has done this also
                    melt.add_summary(summary, learning_rate, 'learning_rate')
                    melt.add_summary(summary,
                                     melt.batch_size(),
                                     'batch_size',
                                     prefix='other')
                    melt.add_summary(summary,
                                     melt.epoch(),
                                     'epoch',
                                     prefix='other')
                    if steps_per_second:
                        melt.add_summary(summary,
                                         steps_per_second,
                                         'steps_per_second',
                                         prefix='perf')
                    if instances_per_second:
                        melt.add_summary(summary,
                                         instances_per_second,
                                         'instances_per_second',
                                         prefix='perf')
                    if hours_per_epoch:
                        melt.add_summary(summary,
                                         hours_per_epoch,
                                         'hours_per_epoch',
                                         prefix='perf')

                if metric_evaluate:
                    #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
                    prefix = 'step_eval'
                    if model_path:
                        prefix = 'eval'
                        if not hasattr(train_once, 'epoch_step'):
                            train_once.epoch_step = 1
                        else:
                            train_once.epoch_step += 1
                        step = train_once.epoch_step
                    # eval/loss eval/auc ..
                    melt.add_summarys(summary,
                                      evaluate_results,
                                      evaluate_names,
                                      prefix=prefix)

                train_once.summary_writer.add_summary(summary, step)
                train_once.summary_writer.flush()
            return stop
        elif metric_evaluate and log_dir:
            summary = tf.Summary()
            for summary_str in summary_strs:
                train_once.summary_writer.add_summary(summary_str, step)
            #summary.ParseFromString(evaluate_summaries)
            summary_writer = train_once.summary_writer
            prefix = 'step_eval'
            if model_path:
                prefix = 'eval'
                if not hasattr(train_once, 'epoch_step'):
                    ## TODO.. restart will get 1 again..
                    #epoch_step = tf.Variable(0, trainable=False, name='epoch_step')
                    #epoch_step += 1
                    #train_once.epoch_step = sess.run(epoch_step)
                    valid_interval_epochs = 1.
                    try:
                        valid_interval_epochs = FLAGS.valid_interval_epochs
                    except Exception:
                        pass
                    train_once.epoch_step = 1 if melt.epoch() <= 1 else int(
                        int(melt.epoch() * 10) /
                        int(valid_interval_epochs * 10))
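                    # e.g. with valid_interval_epochs=0.5 and melt.epoch()=3.0 (illustrative values)
                    # this gives int(30) / int(5) = 6, so epoch summaries resume at step 6;
                    # the *10 trick avoids float-division artifacts for intervals like 0.1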
                    logging.info('train_once epoch start step is',
                                 train_once.epoch_step)
                else:
                    #epoch_step += 1
                    train_once.epoch_step += 1
                step = train_once.epoch_step
            #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
            melt.add_summarys(summary,
                              evaluate_results,
                              evaluate_names,
                              prefix=prefix)
            summary_writer.add_summary(summary, step)
            summary_writer.flush()
Example #11
0
def get_num_finetune_words():
    if not FLAGS.dynamic_finetune:
        return FLAGS.num_finetune_words
    else:
        return min(int(melt.epoch() * 1000), FLAGS.num_finetune_words)
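        # a worked example (illustrative numbers): with FLAGS.num_finetune_words=5000,
        # epoch 0.5 finetunes min(500, 5000)=500 words, and the full 5000-word cap
        # is reached once melt.epoch() passes 5.0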
Example #12
0
def evaluate_scores(predictor, random=False):
    timer = gezi.Timer('evaluate_scores')
    init()
    if FLAGS.eval_img2text:
        imgs, img_features = get_image_names_and_features()
        num_metric_eval_examples = min(FLAGS.num_metric_eval_examples,
                                       len(imgs))
        step = FLAGS.metric_eval_batch_size

        if random:
            index = np.random.choice(len(imgs),
                                     num_metric_eval_examples,
                                     replace=False)
            imgs = imgs[index]
            img_features = img_features[index]

        rank_metrics = gezi.rank_metrics.RecallMetrics()

        start = 0
        while start < num_metric_eval_examples:
            end = start + step
            if end > num_metric_eval_examples:
                end = num_metric_eval_examples
            print('predicts image start:', start, 'end:', end, file=sys.stderr)
            predicts(imgs[start:end], img_features[start:end], predictor,
                     rank_metrics)
            start = end

        melt.logging_results(
            rank_metrics.get_metrics(),
            rank_metrics.get_names(),
            tag='evaluate: epoch:{} step:{} train:{} eval:{}'.format(
                melt.epoch(), melt.step(), melt.train_loss(),
                melt.eval_loss()))

    if FLAGS.eval_text2img:
        num_metric_eval_examples = min(FLAGS.num_metric_eval_examples,
                                       len(all_distinct_texts))
        step = FLAGS.metric_eval_batch_size
        # default to the full set so text_strs/texts are defined when random is False
        text_strs = all_distinct_text_strs
        texts = all_distinct_texts
        if random:
            index = np.random.choice(len(all_distinct_texts),
                                     num_metric_eval_examples,
                                     replace=False)
            text_strs = all_distinct_text_strs[index]
            texts = all_distinct_texts[index]

        rank_metrics2 = gezi.rank_metrics.RecallMetrics()

        start = 0
        while start < num_metric_eval_examples:
            end = start + step
            if end > num_metric_eval_examples:
                end = num_metric_eval_examples
            print('predicts start:', start, 'end:', end, file=sys.stderr)
            predicts_txt2im(text_strs[start:end], texts[start:end], predictor,
                            rank_metrics2)
            start = end

        melt.logging_results(rank_metrics2.get_metrics(),
                             ['t2i' + x for x in rank_metrics2.get_names()],
                             tag='text2img')

    timer.print()

    if FLAGS.eval_img2text and FLAGS.eval_text2img:
        return (rank_metrics.get_metrics() + rank_metrics2.get_metrics(),
                rank_metrics.get_names() +
                ['t2i' + x for x in rank_metrics2.get_names()])
    elif FLAGS.eval_img2text:
        return rank_metrics.get_metrics(), rank_metrics.get_names()
    else:
        return rank_metrics2.get_metrics(), rank_metrics2.get_names()
Example #13
0
def train(model, 
          loss_fn,
          Dataset=None,  
          dataset=None,
          valid_dataset=None,
          valid_dataset2=None,
          test_dataset=None,
          evaluate_fn=None, 
          inference_fn=None,
          eval_fn=None,
          write_valid=True,
          valid_names=None,
          infer_names=None,
          infer_debug_names=None,
          valid_write_fn=None,
          infer_write_fn=None,
          valid_suffix='.valid',
          infer_suffix='.infer',
          write_streaming=False,
          optimizer=None,
          param_groups=None,
          init_fn=None,
          sep=','):
  use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ

  if Dataset is None:
    assert dataset
  logging.info('Dataset', Dataset, 'dataset', dataset, 'valid_dataset', valid_dataset, 'test_dataset', test_dataset, loss_fn)

  if FLAGS.torch:
    torch.manual_seed(FLAGS.seed or 0)
    if torch.cuda.device_count():
      torch.cuda.manual_seed(FLAGS.seed or 0)
    if use_horovod:
      pass
      # import horovod.torch as hvd
      # hvd.init()
      # #print('-----------------', hvd, hvd.size())
      # assert hvd.mpi_threads_supported()
      # assert hvd.size() == comm.Get_size()
      # torch.cuda.set_device(hvd.local_rank())
    # https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
    else:
      if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.to(device)
    
  input_ =  FLAGS.train_input 
  inputs = gezi.list_files(input_)
  inputs.sort()

  all_inputs = inputs

  #batch_size = FLAGS.batch_size
  batch_size = melt.batch_size()

  num_gpus = melt.num_gpus()

  #batch_size = max(batch_size, 1)
  #batch_size_ = batch_size if not FLAGS.batch_sizes else int(FLAGS.batch_sizes.split(',')[-1])
  batch_size_ = FLAGS.eval_batch_size or batch_size

  if dataset is None:
    if FLAGS.fold is not None:
      inputs = [x for x in inputs if not x.endswith('%d.record' % FLAGS.fold) and not x.endswith('%d.tfrecord' % FLAGS.fold)]
      # if FLAGS.valid_input:
      #   inputs += [x for x in gezi.list_files(FLAGS.valid_input) if not x.endswith('%d.record' % FLAGS.fold)]
    logging.info('inputs', len(inputs), inputs[:100])
  num_folds = FLAGS.num_folds or len(inputs) + 1

  if dataset is None:
    dataset = Dataset('train')
    assert len(inputs) > 0
    train_dataset = dataset.make_batch(batch_size, inputs, simple_parse=FLAGS.simple_parse)
    num_examples = dataset.num_examples_per_epoch('train') 
  else:
    assert FLAGS.torch_only, 'passing a dataset instance (not a Dataset class) is currently only supported in torch_only mode, since we rely on len() of the dataset there'
    train_dataset = dataset
    num_examples = len(train_dataset.dataset)

  num_all_examples = num_examples

  if valid_dataset is None:
    valid_inputs = None
    if FLAGS.valid_input:
      valid_inputs = gezi.list_files(FLAGS.valid_input)
    else:
      if FLAGS.fold is not None:
        #valid_inputs = [x for x in all_inputs if x not in inputs]
        if not FLAGS.test_aug:
          valid_inputs = [x for x in all_inputs if not 'aug' in x and x not in inputs]
        else:
          valid_inputs = [x for x in all_inputs if 'aug' in x and x not in inputs]

    logging.info('valid_inputs', valid_inputs)

  num_valid_examples = None
  if valid_dataset is not None:
    num_valid_examples = len(valid_dataset.dataset)
    num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_) if num_valid_examples else None   
    valid_dataset2_iter = iter(valid_dataset2)
  else:
    if valid_inputs:
      valid_dataset = dataset.make_batch(batch_size_, valid_inputs, subset='valid', hvd_shard=FLAGS.horovod_eval )
      valid_dataset2 = dataset.make_batch(batch_size, valid_inputs, subset='valid', repeat=True, initializable=False, hvd_shard=False)
      valid_dataset2_iter = iter(valid_dataset2)
    else:
      valid_dataset = None
      valid_dataset2 = None

  if num_examples:
    if FLAGS.fold is not None:
      num_examples = int(num_examples * (num_folds - 1) / num_folds)
    num_steps_per_epoch = -(-num_examples // batch_size)
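    # -(-a // b) is ceiling division, e.g. -(-1001 // 32) == 32,
    # so a final partial batch still counts as one training step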
  else:
    num_steps_per_epoch = None
  logging.info('num_train_examples:', num_examples)
  if use_horovod and num_examples:
    num_steps_per_epoch = -(-num_examples // (batch_size * hvd.size()))

  if num_valid_examples is None:
    if FLAGS.valid_input:
      num_valid_examples = dataset.num_examples_per_epoch('valid')
      num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_) if num_valid_examples else None   
    else:
      if FLAGS.fold is not None:
        if num_examples:
          num_valid_examples = int(num_all_examples * (1 / num_folds))
          num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_)
        else:
          num_valid_steps_per_epoch = None
  if use_horovod and FLAGS.horovod_eval and num_valid_examples:
      num_valid_steps_per_epoch = -(-num_valid_examples // (batch_size_ * hvd.size()))
  logging.info('num_valid_examples:', num_valid_examples)

  if test_dataset is None:
    if FLAGS.test_input:
      test_inputs = gezi.list_files(FLAGS.test_input)
      #test_inputs = [x for x in test_inputs if not 'aug' in x]
      logging.info('test_inputs', test_inputs)
    else:
      test_inputs = None
  
  num_test_examples = None
  if test_dataset is not None:
    num_test_examples = len(test_dataset.dataset)
  else:
    if test_inputs:
      test_dataset = dataset.make_batch(batch_size_, test_inputs, subset='test') 
      num_test_examples = dataset.num_examples_per_epoch('test')
    else:
      test_dataset = None
  num_test_steps_per_epoch = -(-num_test_examples // batch_size_) if num_test_examples else None
  if use_horovod and FLAGS.horovod_eval and num_test_examples:
      num_test_steps_per_epoch = -(-num_test_examples // (batch_size_ * hvd.size()))
  logging.info('num_test_examples:', num_test_examples)
  
  summary = tf.contrib.summary
  # writer = summary.create_file_writer(FLAGS.log_dir + '/epoch')
  # writer_train = summary.create_file_writer(FLAGS.log_dir + '/train')
  # writer_valid = summary.create_file_writer(FLAGS.log_dir + '/valid')
  writer = summary.create_file_writer(FLAGS.log_dir)
  writer_train = summary.create_file_writer(FLAGS.log_dir)
  writer_valid = summary.create_file_writer(FLAGS.log_dir)
  global_step = tf.train.get_or_create_global_step()
  ## RuntimeError: tf.summary.FileWriter is not compatible with eager execution. Use tf.contrib.summary instead.
  #logger = gezi.SummaryWriter(FLAGS.log_dir)

  learning_rate = tfe.Variable(FLAGS.learning_rate, name="learning_rate")
  
  tf.add_to_collection('learning_rate', learning_rate)

  learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
  try:
    learning_rate_weights = tf.get_collection('learning_rate_weights')[-1]
  except Exception:
    learning_rate_weights = None

  # ckpt dir save models one per epoch
  ckpt_dir = os.path.join(FLAGS.model_dir, 'ckpt')
  os.system('mkdir -p %s' % ckpt_dir)
  # HACK: this ckpt dir actually saves at mini-epoch granularity, e.g. when you set save_interval_epochs=0.1; this is useful when training on a large dataset
  ckpt_dir2 = os.path.join(FLAGS.model_dir, 'ckpt2')
  os.system('mkdir -p %s' % ckpt_dir2)

  # TODO FIXME: for now I just changed the tf code so it does not keep only the latest 5 checkpoints by default
  # refer to https://github.com/tensorflow/tensorflow/issues/22036
  # manager = tf.contrib.checkpoint.CheckpointManager(
  #     checkpoint, directory=ckpt_dir, max_to_keep=5)
  # latest_checkpoint = manager.latest_checkpoint

  latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
  if latest_checkpoint:
    logging.info('Latest checkpoint:', latest_checkpoint)
  else:
    latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir2)
    logging.info('Latest checkpoint:', latest_checkpoint)
  
  if os.path.exists(FLAGS.model_dir + '.index'):
    latest_checkpoint = FLAGS.model_dir  

  if 'test' in FLAGS.work_mode or 'valid' in FLAGS.work_mode:
    #assert not os.path.isdir(FLAGS.model_dir), FLAGS.model_dir
    latest_checkpoint = FLAGS.model_dir
    #assert os.path.exists(latest_checkpoint) and os.path.isfile(latest_checkpoint)

  checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt')
  checkpoint_prefix2 = os.path.join(ckpt_dir2, 'ckpt')

  if not FLAGS.torch:
    try:
      optimizer = optimizer or melt.get_optimizer(FLAGS.optimizer)(learning_rate)
    except Exception:
      logging.warning(f'Failed to use optimizer {FLAGS.optimizer}, falling back to adam')
      optimizer = melt.get_optimizer('adam')(learning_rate)
    
    # TODO...
    if  learning_rate_weights is None:
      checkpoint = tf.train.Checkpoint(
            learning_rate=learning_rate, 
            learning_rate_weight=learning_rate_weight,
            model=model,
            optimizer=optimizer,
            global_step=global_step)
    else:
      checkpoint = tf.train.Checkpoint(
            learning_rate=learning_rate, 
            learning_rate_weight=learning_rate_weight,
            learning_rate_weights=learning_rate_weights,
            model=model,
            optimizer=optimizer,
            global_step=global_step) 

    checkpoint.restore(latest_checkpoint)
    checkpoint2 = copy.deepcopy(checkpoint)

    start_epoch = int(latest_checkpoint.split('-')[-1]) if latest_checkpoint and 'ckpt' in latest_checkpoint else 0
    start_step = 0 # TODO
  else:
    # TODO torch with learning rate adjustment, see
    # https://github.com/horovod/horovod/blob/master/examples/pytorch_mnist.py
    # TODO full support for pytorch does not work yet
    
    is_dynamic_opt = False
    if optimizer is None:
      import lele
      is_dynamic_opt = True
      if FLAGS.optimizer == 'noam':
        optimizer_ = torch.optim.Adamax(model.parameters(), lr=0)
        if use_horovod:
          optimizer_ = hvd.DistributedOptimizer(optimizer_)
        optimizer = lele.training.optimizers.NoamOpt(128, 2, 4000, optimizer_)
      elif FLAGS.optimizer == 'bert':
        num_train_steps = int(num_steps_per_epoch * (FLAGS.num_decay_epochs or FLAGS.num_epochs))
        if FLAGS.warmup_steps and use_horovod:
          FLAGS.warmup_steps = max(int(FLAGS.warmup_steps / hvd.size()), 1)
        num_warmup_steps = FLAGS.warmup_steps or int(num_steps_per_epoch * FLAGS.warmup_epochs) or int(num_train_steps * FLAGS.warmup_proportion) 
        logging.info('num_train_steps', num_train_steps, 'num_warmup_steps', num_warmup_steps, 'warmup_proportion', FLAGS.warmup_proportion)
        optimizer_ = torch.optim.Adamax(model.parameters(), lr=0)
        if use_horovod:
          optimizer_ = hvd.DistributedOptimizer(optimizer_)
        optimizer = lele.training.optimizers.BertOpt(
                            FLAGS.learning_rate, 
                            FLAGS.min_learning_rate,
                            num_train_steps,
                            num_warmup_steps,
                            optimizer_
                            )
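        # BertOpt presumably implements a BERT-style schedule: warm up from 0 towards
        # FLAGS.learning_rate over num_warmup_steps, then decay towards
        # FLAGS.min_learning_rate by num_train_steps (an assumption about
        # lele.training.optimizers.BertOpt, not stated here)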
      else:
        is_dynamic_opt = False
        optimizer = torch.optim.Adamax(param_groups if param_groups else model.parameters(), lr=FLAGS.learning_rate)
        if use_horovod:
          optimizer = hvd.DistributedOptimizer(optimizer)
          optimizer_ = optimizer
    else:
      if use_horovod:
        optimizer = hvd.DistributedOptimizer(optimizer)
        optimizer_ = optimizer
    
    start_epoch = 0  
    latest_path = latest_checkpoint + '.pyt' if latest_checkpoint else os.path.join(FLAGS.model_dir, 'latest.pyt')
    if not os.path.exists(latest_path):
      latest_path = os.path.join(FLAGS.model_dir, 'latest.pyt')
    if os.path.exists(latest_path):
      logging.info('loading torch model from', latest_path)
      checkpoint = torch.load(latest_path)
      if not FLAGS.torch_finetune:
        start_epoch = checkpoint['epoch']
        step = checkpoint['step']
        global_step.assign(step + 1)
      load_torch_model(model, latest_path)
      if FLAGS.torch_load_optimizer:
        optimizer.load_state_dict(checkpoint['optimizer'])

    # TODO with this approach a restart cannot change the learning rate..
    if learning_rate_weights is None:
      checkpoint = tf.train.Checkpoint(
          learning_rate=learning_rate, 
          learning_rate_weight=learning_rate_weight,
          global_step=global_step)
    else:
      checkpoint = tf.train.Checkpoint(
            learning_rate=learning_rate, 
            learning_rate_weight=learning_rate_weight,
            learning_rate_weights=learning_rate_weights,
            global_step=global_step)

    try:
      checkpoint.restore(latest_checkpoint)
      checkpoint2 = copy.deepcopy(checkpoint)
    except Exception:
      pass

  if FLAGS.torch and is_dynamic_opt:
    optimizer._step = global_step.numpy()
    
  #model.load_weights(os.path.join(ckpt_dir, 'ckpt-1'))
  #model.save('./weight3.hd5')
  logging.info('optimizer:', optimizer)

  if FLAGS.torch_lr:
    learning_rate.assign(optimizer.rate(1))
  if FLAGS.torch:
    learning_rate.assign(optimizer.param_groups[0]['lr'])
    logging.info('learning rate got from pytorch latest.pyt as', learning_rate.numpy())

  learning_rate.assign(learning_rate * FLAGS.learning_rate_start_factor)
  if learning_rate_weights is not None:
    learning_rate_weights.assign(learning_rate_weights * FLAGS.learning_rate_start_factor)

  # TODO fractional epoch settings like 0.1 are currently not supported here
  num_epochs = FLAGS.num_epochs if FLAGS.num_epochs != 0 else 1024

  will_valid = valid_dataset and not FLAGS.work_mode == 'test' and not 'SHOW' in os.environ and not 'QUICK' in os.environ
  if global_step.numpy() == 0 :
    will_valid = False

  if gezi.get_env('EVFIRST') == '1':
    will_valid = True
  
  if gezi.get_env('EVFIRST') == '0':
    will_valid = False

  if will_valid:
    logging.info('----------valid')
    if hasattr(model, 'eval'):
      model.eval()
    names = None 
    if evaluate_fn is not None:
      vals, names = evaluate_fn(model, valid_dataset, tf.train.latest_checkpoint(ckpt_dir), num_valid_steps_per_epoch)
    elif eval_fn:
      model_path = None if not write_valid else latest_checkpoint
      names = valid_names if valid_names is not None else [infer_names[0]] + [x + '_y' for x in infer_names[1:]] + infer_names[1:] if infer_names else None

      logging.info('model_path:', model_path, 'model_dir:', FLAGS.model_dir)
      vals, names = evaluate(model, valid_dataset, eval_fn, model_path, 
                             names, valid_write_fn, write_streaming,
                             num_valid_steps_per_epoch, num_valid_examples,
                             suffix=valid_suffix, sep=sep)
    if names:
      logging.info2('epoch:%.2f/%d step:%d' % (global_step.numpy() / num_steps_per_epoch, num_epochs, global_step.numpy()), 
                    ['%s:%.4f' % (name, val) for name, val in zip(names, vals)])
  
    if FLAGS.work_mode == 'valid' or gezi.get_env('METRIC') == '1':
      exit(0)

  if 'test' in FLAGS.work_mode or gezi.get_env('TEST') == '1' or gezi.get_env('INFER') == '1':
    logging.info('--------test/inference')
    if test_dataset:
      if hasattr(model, 'eval'):
        model.eval()
      if inference_fn is None:
        # model_path = FLAGS.model_dir + '.pyt' if not latest_checkpoint else latest_checkpoint
        # logging.info('model_path', model_path)
        assert latest_checkpoint
        inference(model, test_dataset, latest_checkpoint, 
                  infer_names, infer_debug_names, infer_write_fn, write_streaming,
                  num_test_steps_per_epoch, num_test_examples, suffix=infer_suffix)
      else:
        inference_fn(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir), num_test_steps_per_epoch)
    exit(0)
  
  if 'SHOW' in os.environ:
    num_epochs = start_epoch + 1

  class PytObj(object):
    def __init__(self, x):
      self.x = x
    def numpy(self):
      return self.x

  class PytMean(object):
    def __init__(self):
      self._val = 0. 
      self.count = 0

      self.is_call = True

    def clear(self):
      self._val = 0
      self.count = 0

    def __call__(self, val):
      if not self.is_call:
        self.clear()
        self.is_call = True
      self._val += val.item()
      self.count += 1

    def result(self):
      if self.is_call:
        self.is_call = False
      if not self.count:
        val = 0
      else:
        val = self._val / self.count
      # TODO just for compact with tf ..
      return PytObj(val)
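  # PytMean mimics the small subset of tfe.metrics.Mean used below: calling the object
  # accumulates scalar losses, and result() returns the mean wrapped in PytObj so that
  # .numpy() works for both backends; note the first __call__ after a result() read
  # clears the accumulator, so it reports a per-logging-interval average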
      
  Mean =  tfe.metrics.Mean if not FLAGS.torch else PytMean
  
  num_insts = 0

  if FLAGS.learning_rate_decay_factor > 0:
    #assert FLAGS.learning_rate_values is None, 'use exponential_decay or piecewise_constant?'
    #NOTICE if you finetune or do anything else that changes batch_size, you'd better directly set num_steps_per_decay,
    #since global_step / decay_steps will no longer map to the correct epoch once num_steps_per_epoch has changed;
    #so if you change the batch size you have to reset the global step to a fixed step
    assert FLAGS.num_steps_per_decay or (FLAGS.num_epochs_per_decay and num_steps_per_epoch), 'must set num_steps_per_decay, or num_epochs_per_decay together with num_steps_per_epoch'
    decay_steps = FLAGS.num_steps_per_decay or int(num_steps_per_epoch * FLAGS.num_epochs_per_decay)    
    decay_start_step = FLAGS.decay_start_step or int(num_steps_per_epoch * FLAGS.decay_start_epoch)
    # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
    logging.info('learning_rate_decay_factor:{} decay_epochs:{} decay_steps:{} decay_start_epoch:{} decay_start_step:{}'.format(
        FLAGS.learning_rate_decay_factor, FLAGS.num_epochs_per_decay, decay_steps, FLAGS.decay_start_epoch, decay_start_step))

  #-------------------------start training
  if hasattr(model, 'train'):
    model.train()

  if use_horovod:
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer_, root_rank=0)

  timer = gezi.Timer()
  loss_avg = Mean()
  valid_loss_avg = Mean()

  num_epochs = num_epochs if num_epochs else 0
  loops = min(num_epochs, 1) if FLAGS.torch_only else 1
  for epoch in range(loops):
    for i, (x, y) in enumerate(train_dataset):
      if FLAGS.torch:
        x, y = to_torch(x, y)
        if is_dynamic_opt:
          learning_rate.assign(optimizer.rate())

      def loss_fn_(x, y):
        if not FLAGS.torch and 'training' in inspect.getargspec(model.call).args:
          y_ = model(x, training=True)
        else:
          y_ = model(x)
        if not FLAGS.torch:
          return loss_fn(y, y_)
        else:
          return loss_fn(y_, y)
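      # note the argument order: the tf path calls loss_fn(y_true, y_pred) Keras-style,
      # while the torch path passes (y_pred, y_true) to match the (input, target)
      # convention of torch criteria such as nn.CrossEntropyLoss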
      
      if not FLAGS.torch:
        loss, grads = melt.eager.grad(model, x, y, loss_fn)
        grads, _ = tf.clip_by_global_norm(grads, FLAGS.clip_gradients)
        #optimizer.apply_gradients(zip(grads, model.variables))
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        # https://github.com/horovod/horovod/blob/master/examples/tensorflow_mnist_eager.py
        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        # Note: broadcast should be done after the first gradient step to ensure optimizer
        # initialization.
        # TODO check eager mode
        if use_horovod and epoch == start_epoch and i == 0:
          hvd.broadcast_variables(model.variables, root_rank=0)
          hvd.broadcast_variables(optimizer.variables(), root_rank=0)
      else:
        optimizer.zero_grad()
        loss = loss_fn_(x, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                        FLAGS.clip_gradients)
        optimizer.step()

      global_step.assign_add(1)
      loss_avg(loss)
    
      ## https://discuss.pytorch.org/t/calling-loss-backward-reduce-memory-usage/2735
      # if FLAGS.torch:
      #   del loss

      batch_size_ = list(x.values())[0].shape[FLAGS.batch_size_dim] if type(x) == type({}) else x.shape[FLAGS.batch_size_dim]
      num_insts += int(batch_size_)
      if global_step.numpy() % FLAGS.interval_steps == 0:
        #checkpoint.save(checkpoint_prefix)
        elapsed = timer.elapsed()
        steps_per_second = FLAGS.interval_steps / elapsed
        instances_per_second = num_insts / elapsed
        num_insts = 0

        if num_steps_per_epoch is None:
          epoch_time_info = ''
        else:
          hours_per_epoch = num_steps_per_epoch / FLAGS.interval_steps * elapsed / 3600
          epoch_time_info = '1epoch:[{:.2f}h]'.format(hours_per_epoch)

        if valid_dataset2:
          # try:
          #   x, y = next(iter(valid_dataset2))
          # except Exception:
          #   # TODO FIXME how.. iterate stop restart.., here hack for my iterator see projects/lm/dataset 
          #   x, y = next(iter(valid_dataset2))
          ## valid_dataset2 is a repeated dataset
          ## NOTICE the commented line below would always fetch the first batch,
          ## so the persistent valid_dataset2_iter is used instead
          #x, y = next(iter(valid_dataset2))
          x, y = next(valid_dataset2_iter)
          #print(x['id'][0])
          if FLAGS.torch:
            x, y = to_torch(x, y)
          if hasattr(model, 'eval'):  
            model.eval()
          valid_loss = loss_fn_(x, y)
          valid_loss = valid_loss.numpy() if not FLAGS.torch else valid_loss.item()
          if hasattr(model, 'train'):
            model.train()

          if not use_horovod or hvd.rank() == 0:
            # 'train_loss:[%.4f]' % loss_avg.result().numpy(),
            # 'valid_loss:[%.4f]' % valid_loss_avg.result().numpy()
            logging.info2('epoch:%.2f/%d' % ((global_step.numpy() / num_steps_per_epoch), num_epochs), 
                        'step:%d' % global_step.numpy(), 
                        'elapsed:[%.2f]' % elapsed,
                        'batch_size:[%d]' % batch_size_,
                        'gpus:[%d]' % num_gpus, 
                        'batches/s:[%.2f]' % steps_per_second,
                        'insts/s:[%d]' % instances_per_second,
                        '%s' % epoch_time_info,
                        'lr:[%.6f]' % learning_rate.numpy(),
                        'train_loss:[%.4f]' % loss_avg.result().numpy(),
                        'valid_loss:[%.4f]' % valid_loss
                        )
            if global_step.numpy() % FLAGS.valid_interval_steps == 0:
              with writer_valid.as_default(), summary.always_record_summaries():
                summary.scalar('loss/valid', valid_loss)
                writer_valid.flush()
        else:
          if not use_horovod or hvd.rank() == 0:
            #'train_loss:[%.4f]' % loss_avg.result().numpy()
            logging.info2('epoch:%.2f/%d' % ((epoch + i / num_steps_per_epoch), num_epochs), 
                        'step:%d' % global_step.numpy(), 
                        'elapsed:[%.2f]' % elapsed,
                        'batch_size:[%d]' % batch_size_,
                        'gpus:[%d]' % num_gpus, 
                        'batches/s:[%.2f]' % steps_per_second,
                        'insts/s:[%d]' % instances_per_second,
                        '%s' % epoch_time_info,
                        'lr:[%.6f]' % learning_rate.numpy(),
                        'train_loss:[%.4f]' % loss_avg.result().numpy()
                        )      

        if not use_horovod or hvd.rank() == 0:
          if global_step.numpy() % FLAGS.valid_interval_steps == 0:
            with writer_train.as_default(), summary.always_record_summaries():
              summary.scalar('loss/train_avg', loss_avg.result().numpy())
              summary.scalar('learning_rate', learning_rate.numpy())
              summary.scalar('other/batch_size', batch_size_)
              summary.scalar('other/epoch', melt.epoch())
              summary.scalar('perf/steps_per_second', steps_per_second)
              summary.scalar('perf/instances_per_second', instances_per_second)
              writer_train.flush()

      if valid_dataset and FLAGS.metric_eval_interval_steps and global_step.numpy() and global_step.numpy() % FLAGS.metric_eval_interval_steps == 0:
        if hasattr(model, 'eval'):
          model.eval()
        vals, names = None, None
        if evaluate_fn is not None:
          vals, names = evaluate_fn(model, valid_dataset, None, num_valid_steps_per_epoch)
        elif eval_fn:
          names = valid_names if valid_names is not None else [infer_names[0]] + [x + '_y' for x in infer_names[1:]] + infer_names[1:] if infer_names else None
          vals, names = evaluate(model, valid_dataset, eval_fn, None, 
                                  names, valid_write_fn, write_streaming,
                                  num_valid_steps_per_epoch, num_valid_examples, sep=sep)
        if not use_horovod or hvd.rank() == 0:
          if vals and names:
            with writer_valid.as_default(), summary.always_record_summaries():
              for name, val in zip(names, vals):
                summary.scalar(f'step_eval/{name}', val)
              writer_valid.flush()
      
        if FLAGS.torch:
          if not FLAGS.torch_lr:
            # control learning rate by tensorflow learning rate
            for param_group in optimizer.param_groups:
              # important learning rate decay
              param_group['lr'] = learning_rate.numpy()
        if hasattr(model, 'train'):  
          model.train()
        if not use_horovod or hvd.rank() == 0:
          if names and vals:
            logging.info2('epoch:%.2f/%d' % ((global_step.numpy() / num_steps_per_epoch), num_epochs),  
                          'valid_step:%d' % global_step.numpy(),
                          'valid_metrics',
                          ['%s:%.5f' % (name, val) for name, val in zip(names, vals)])
      
      if not use_horovod or hvd.rank() == 0:
      # TODO save ok ?
        if global_step.numpy() % FLAGS.save_interval_steps == 0:
          if FLAGS.torch:
            state = {
                    'epoch': int(global_step.numpy() / num_steps_per_epoch),
                    'step': global_step.numpy(),
                    'state_dict': model.state_dict() if not hasattr(model, 'module') else model.module.state_dict(),
                    'optimizer' : optimizer.state_dict(),
                  }
            torch.save(state, os.path.join(FLAGS.model_dir, 'latest.pyt'))     

        # TODO FIXME why it is not ok when both checkpoint2 and checkpoint are used..
        if FLAGS.save_interval_epochs and global_step.numpy() % int(num_steps_per_epoch * FLAGS.save_interval_epochs) == 0:
          checkpoint2.save(checkpoint_prefix2) 
          if FLAGS.torch:
            state = {
                    'epoch': int(global_step.numpy() / num_steps_per_epoch),
                    'step': global_step.numpy(),
                    'state_dict': model.state_dict() if not hasattr(model, 'module') else model.module.state_dict(),
                    'optimizer' : optimizer.state_dict(),
                  }
            torch.save(state, tf.train.latest_checkpoint(ckpt_dir2) + '.pyt')

      if FLAGS.learning_rate_decay_factor > 0:
        if global_step.numpy() >= decay_start_step and global_step.numpy() % decay_steps == 0:
          lr = max(learning_rate.numpy() * FLAGS.learning_rate_decay_factor, FLAGS.min_learning_rate)
          if lr < learning_rate.numpy():
            learning_rate.assign(lr)
            if FLAGS.torch:
              for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate.numpy()
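      # e.g. with learning_rate_decay_factor=0.5 and decay_steps=num_steps_per_epoch
      # (illustrative values) the lr halves once per epoch after decay_start_step,
      # but is clipped so it never drops below FLAGS.min_learning_rate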

      if i == 0:
        try:
          if not FLAGS.torch:
            logging.info(model.summary())
            # #tf.keras.utils.plot_model(model, to_file='/home/gezi/model.png', show_shapes=False, show_layer_names=True, rankdir='TB')
            # import keras
            # keras.utils.plot_model(model, to_file='/home/gezi/model.png', show_shapes=False, show_layer_names=True, rankdir='LR', expand_nested=True, dpi=96)
          else:
            logging.info(model)
        except Exception:
          traceback.print_exc()
          logging.info('Failed to run model.summary(); maybe you have a layer defined in __init__ but not used in call')
        if 'SHOW' in os.environ:
          exit(0)
      
      if valid_dataset and  global_step.numpy() % int(num_steps_per_epoch * FLAGS.valid_interval_epochs) == 0:
        if hasattr(model, 'eval'):
          model.eval()

        vals, names = None, None
        if evaluate_fn is not None:
          vals, names = evaluate_fn(model, valid_dataset, tf.train.latest_checkpoint(ckpt_dir), num_valid_steps_per_epoch)
        elif eval_fn:
          model_path = None if not write_valid else tf.train.latest_checkpoint(ckpt_dir)
          print('---------metric evaluate step', global_step.numpy(), 'model_path:', model_path)
          names = valid_names if valid_names is not None else [infer_names[0]] + [x + '_y' for x in infer_names[1:]] + infer_names[1:] if infer_names else None

          vals, names = evaluate(model, valid_dataset, eval_fn, model_path, 
                                names, valid_write_fn, write_streaming,
                                num_valid_steps_per_epoch, num_valid_examples, suffix=valid_suffix, sep=sep)

        if not use_horovod or hvd.rank() == 0:
          if vals and names:
            logging.info2('epoch:%.2f/%d' % (global_step.numpy() / num_steps_per_epoch, num_epochs), 
                          'step:%d' % global_step.numpy(),
                          'valid_metrics',
                          ['%s:%.5f' % (name, val) for name, val in zip(names, vals)])

        if not use_horovod or hvd.rank() == 0:
          with writer.as_default(), summary.always_record_summaries():
            temp = global_step.value()
            global_step.assign(int(global_step.numpy() / int(num_steps_per_epoch * FLAGS.valid_interval_epochs)))
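            # global_step is temporarily rewritten so that eval/* summaries are indexed
            # by validation round (roughly the epoch) instead of the raw batch step;
            # it is restored from `temp` after the writer is flushed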
            if valid_dataset:
              if hasattr(model, 'eval'):
                model.eval()
              if vals and names:
                for name, val in zip(names, vals):
                  summary.scalar(f'eval/{name}', val)
            writer.flush()
            global_step.assign(temp)

      if test_dataset and global_step.numpy() % int(num_steps_per_epoch * FLAGS.inference_interval_epochs) == 0:
        if hasattr(model, 'eval'):
          model.eval()
        if inference_fn is None:
          inference(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir), 
                    infer_names, infer_debug_names, infer_write_fn, write_streaming,
                    num_test_steps_per_epoch, num_test_examples, suffix=infer_suffix, sep=sep)
        else:
          inference_fn(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir), num_test_steps_per_epoch)

      if num_epochs and (global_step.numpy() % num_steps_per_epoch) == 0 and int(global_step.numpy() / num_steps_per_epoch) == num_epochs:
        logging.info(f'Finished training of {num_epochs} epochs')
        exit(0)
Example #14
0
def train_once(
    sess,
    step,
    ops,
    names=None,
    gen_feed_dict_fn=None,
    deal_results_fn=None,
    interval_steps=100,
    eval_ops=None,
    eval_names=None,
    gen_eval_feed_dict_fn=None,
    deal_eval_results_fn=melt.print_results,
    eval_interval_steps=100,
    print_time=True,
    print_avg_loss=True,
    model_dir=None,
    log_dir=None,
    is_start=False,
    num_steps_per_epoch=None,
    metric_eval_fn=None,
    metric_eval_interval_steps=0,
    summary_excls=None,
    fixed_step=None,  # for epoch only, incase you change batch size
    eval_loops=1,
    learning_rate=None,
    learning_rate_patience=None,
    learning_rate_decay_factor=None,
    num_epochs=None,
    model_path=None,
):

    #is_start = False # force not to evaluate at first step
    #print('-----------------global_step', sess.run(tf.train.get_or_create_global_step()))
    timer = gezi.Timer()
    if print_time:
        if not hasattr(train_once, 'timer'):
            train_once.timer = Timer()
            train_once.eval_timer = Timer()
            train_once.metric_eval_timer = Timer()

    melt.set_global('step', step)
    epoch = (fixed_step
             or step) / num_steps_per_epoch if num_steps_per_epoch else -1
    if not num_epochs:
        epoch_str = 'epoch:%.3f' % (epoch) if num_steps_per_epoch else ''
    else:
        epoch_str = 'epoch:%.3f/%d' % (
            epoch, num_epochs) if num_steps_per_epoch else ''
    melt.set_global('epoch', '%.2f' % (epoch))

    info = IO()
    stop = False

    if eval_names is None:
        if names:
            eval_names = ['eval/' + x for x in names]
    else:
        eval_names = ['eval/' + x for x in eval_names]

    if names:
        names = ['train/' + x for x in names]

    is_eval_step = is_start or eval_interval_steps and step % eval_interval_steps == 0
    summary_str = []

    if is_eval_step:
        # deal with summary
        if log_dir:
            if not hasattr(train_once, 'summary_op'):
                #melt.print_summary_ops()
                if summary_excls is None:
                    train_once.summary_op = tf.summary.merge_all()
                else:
                    summary_ops = []
                    for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
                        for summary_excl in summary_excls:
                            if not summary_excl in op.name:
                                summary_ops.append(op)
                    print('filtered summary_ops:')
                    for op in summary_ops:
                        print(op)
                    train_once.summary_op = tf.summary.merge(summary_ops)

                #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
                train_once.summary_writer = tf.summary.FileWriter(
                    log_dir, sess.graph)

                tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
                    train_once.summary_writer, projector_config)

        if eval_ops is not None:
            #if deal_eval_results_fn is None and eval_names is not None:
            #  deal_eval_results_fn = lambda x: melt.print_results(x, eval_names)
            for i in range(eval_loops):
                eval_feed_dict = {} if gen_eval_feed_dict_fn is None else gen_eval_feed_dict_fn(
                )
                #eval_feed_dict.update(feed_dict)

                # might use EVAL_NO_SUMMARY if some old code has problem TODO CHECK
                if not log_dir or train_once.summary_op is None or gezi.env_has(
                        'EVAL_NO_SUMMARY'):
                    #if not log_dir or train_once.summary_op is None:
                    eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
                else:
                    eval_results = sess.run(eval_ops + [train_once.summary_op],
                                            feed_dict=eval_feed_dict)
                    summary_str = eval_results[-1]
                    eval_results = eval_results[:-1]
                eval_loss = gezi.get_singles(eval_results)
                #timer_.print()
                eval_stop = False

                # @TODO user print should also use logging as a must ?
                #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, 'eval_metrics:', end='')
                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                logging.info2('{} eval_step:{} eval_metrics:{}'.format(
                    epoch_str, step,
                    melt.parse_results(eval_loss, eval_names_)))

                # if deal_eval_results_fn is not None:
                #   eval_stop = deal_eval_results_fn(eval_results)

                assert len(eval_loss) > 0
                if eval_stop is True:
                    stop = True
                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                melt.set_global('eval_loss',
                                melt.parse_results(eval_loss, eval_names_))

        elif interval_steps != eval_interval_steps:
            #print()
            pass

    metric_evaluate = False

    # if metric_eval_fn is not None \
    #   and (is_start \
    #     or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
    #          or (metric_eval_interval_steps \
    #              and step % metric_eval_interval_steps == 0)):
    #  metric_evaluate = True

    if metric_eval_fn is not None \
      and ((is_start or metric_eval_interval_steps \
           and step % metric_eval_interval_steps == 0) or model_path):
        metric_evaluate = True

    #if (is_start or step == 0) and (not 'EVFIRST' in os.environ):
    if ((step == 0) and
        (not 'EVFIRST' in os.environ)) or ('QUICK' in os.environ) or (
            'EVFIRST' in os.environ and os.environ['EVFIRST'] == '0'):
        metric_evaluate = False

    if metric_evaluate:
        # TODO better
        if not model_path or 'model_path' not in inspect.getargspec(
                metric_eval_fn).args:
            l = metric_eval_fn()
            if len(l) == 2:
                evaluate_results, evaluate_names = l
                evaluate_summaries = None
            else:
                evaluate_results, evaluate_names, evaluate_summaries = l
        else:
            try:
                l = metric_eval_fn(model_path=model_path)
                if len(l) == 2:
                    evaluate_results, evaluate_names = l
                    evaluate_summaries = None
                else:
                    evaluate_results, evaluate_names, evaluate_summaries = l
            except Exception:
                logging.info('Do nothing for metric eval fn with exception:\n',
                             traceback.format_exc())

        logging.info2('{} valid_step:{} {}:{}'.format(
            epoch_str, step,
            'valid_metrics' if model_path is None else 'epoch_valid_metrics',
            melt.parse_results(evaluate_results, evaluate_names)))

        if learning_rate is not None and (learning_rate_patience
                                          and learning_rate_patience > 0):
            assert learning_rate_decay_factor > 0 and learning_rate_decay_factor < 1
            valid_loss = evaluate_results[0]
            if not hasattr(train_once, 'min_valid_loss'):
                train_once.min_valid_loss = valid_loss
                train_once.decay_steps = []
                train_once.patience = 0
            else:
                if valid_loss < train_once.min_valid_loss:
                    train_once.min_valid_loss = valid_loss
                    train_once.patience = 0
                else:
                    train_once.patience += 1
                    logging.info2('{} valid_step:{} patience:{}'.format(
                        epoch_str, step, train_once.patience))

            if learning_rate_patience and train_once.patience >= learning_rate_patience:
                lr_op = ops[1]
                lr = sess.run(lr_op) * learning_rate_decay_factor
                train_once.decay_steps.append(step)
                logging.info2(
                    '{} valid_step:{} learning_rate_decay by *{}, learning_rate_decay_steps={}'
                    .format(epoch_str, step, learning_rate_decay_factor,
                            ','.join(map(str, train_once.decay_steps))))
                sess.run(tf.assign(lr_op, tf.constant(lr, dtype=tf.float32)))
                train_once.patience = 0
                train_once.min_valid_loss = valid_loss
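                # e.g. with learning_rate_patience=3 and learning_rate_decay_factor=0.8
                # (illustrative values), the lr op is multiplied by 0.8 whenever the
                # validation loss has failed to improve for 3 consecutive metric evals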

    if ops is not None:
        #if deal_results_fn is None and names is not None:
        #  deal_results_fn = lambda x: melt.print_results(x, names)

        feed_dict = {} if gen_feed_dict_fn is None else gen_feed_dict_fn()
        # NOTICE ops[2] should be scalar otherwise wrong!! loss should be scalar
        #print('---------------ops', ops)
        if eval_ops is not None or not log_dir or not hasattr(
                train_once, 'summary_op') or train_once.summary_op is None:
            results = sess.run(ops, feed_dict=feed_dict)
        else:
            #try:
            results = sess.run(ops + [train_once.summary_op],
                               feed_dict=feed_dict)
            summary_str = results[-1]
            results = results[:-1]
            # except Exception:
            #   logging.info('sess.run(ops + [train_once.summary_op], feed_dict=feed_dict) fail')
            #   results = sess.run(ops, feed_dict=feed_dict)

        #print('------------results', results)
        # #--------trace debug
        # if step == 210:
        #   run_metadata = tf.RunMetadata()
        #   results = sess.run(
        #         ops,
        #         feed_dict=feed_dict,
        #         options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
        #         run_metadata=run_metadata)
        #   from tensorflow.python.client import timeline
        #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)

        #   trace_file = open('timeline.ctf.json', 'w')
        #   trace_file.write(trace.generate_chrome_trace_format())

        # results[0] is assumed to be train_op, results[1] the learning_rate
        learning_rate = results[1]
        results = results[2:]

        #@TODO should support aver loss and other avg evaluations like test..
        if print_avg_loss:
            if not hasattr(train_once, 'avg_loss'):
                train_once.avg_loss = AvgScore()
                if interval_steps != eval_interval_steps:
                    train_once.avg_loss2 = AvgScore()
            #assume results[0] as train_op return, results[1] as loss
            loss = gezi.get_singles(results)
            train_once.avg_loss.add(loss)
            if interval_steps != eval_interval_steps:
                train_once.avg_loss2.add(loss)

        steps_per_second = None
        instances_per_second = None
        hours_per_epoch = None
        #step += 1
        if is_start or interval_steps and step % interval_steps == 0:
            train_average_loss = train_once.avg_loss.avg_score()
            if print_time:
                duration = timer.elapsed()
                duration_str = 'duration:{:.3f} '.format(duration)
                melt.set_global('duration', '%.3f' % duration)
                info.write(duration_str)
                elapsed = train_once.timer.elapsed()
                steps_per_second = interval_steps / elapsed
                batch_size = melt.batch_size()
                num_gpus = melt.num_gpus()
                instances_per_second = interval_steps * batch_size / elapsed
                gpu_info = '' if num_gpus <= 1 else ' gpus:[{}]'.format(
                    num_gpus)
                if num_steps_per_epoch is None:
                    epoch_time_info = ''
                else:
                    hours_per_epoch = num_steps_per_epoch / interval_steps * elapsed / 3600
                    epoch_time_info = ' 1epoch:[{:.2f}h]'.format(
                        hours_per_epoch)
                info.write(
                    'elapsed:[{:.3f}] batch_size:[{}]{} batches/s:[{:.2f}] insts/s:[{:.2f}] {} lr:[{:.8f}]'
                    .format(elapsed, batch_size, gpu_info, steps_per_second,
                            instances_per_second, epoch_time_info,
                            learning_rate))

            if print_avg_loss:
                #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
                names_ = melt.adjust_names(train_average_loss, names)
                #info.write('train_avg_metric:{} '.format(melt.parse_results(train_average_loss, names_)))
                info.write(' train:{} '.format(
                    melt.parse_results(train_average_loss, names_)))
                #info.write('train_avg_loss: {} '.format(train_average_loss))

            #print(gezi.now_time(), epoch_str, 'train_step:%d'%step, info.getvalue(), end=' ')
            logging.info2('{} {} {}'.format(epoch_str, 'step:%d' % step,
                                            info.getvalue()))

            if deal_results_fn is not None:
                stop = deal_results_fn(results)

    summary_strs = gezi.to_list(summary_str)
    if metric_evaluate:
        if evaluate_summaries is not None:
            summary_strs += evaluate_summaries

    if step > 1:
        if is_eval_step:
            # deal with summary
            if log_dir:
                # if not hasattr(train_once, 'summary_op'):
                #   melt.print_summary_ops()
                #   if summary_excls is None:
                #     train_once.summary_op = tf.summary.merge_all()
                #   else:
                #     summary_ops = []
                #     for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
                #       for summary_excl in summary_excls:
                #         if not summary_excl in op.name:
                #           summary_ops.append(op)
                #     print('filtered summary_ops:')
                #     for op in summary_ops:
                #       print(op)
                #     train_once.summary_op = tf.summary.merge(summary_ops)

                #   print('-------------summary_op', train_once.summary_op)

                #   #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
                #   train_once.summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

                #   tf.contrib.tensorboard.plugins.projector.visualize_embeddings(train_once.summary_writer, projector_config)

                summary = tf.Summary()
                # # the strategy: on eval_interval_steps, if there is an eval dataset tensorboard evaluates on it;
                # # if there is none we evaluate on the train set, but with an eval dataset we also monitor the train loss
                # assert train_once.summary_train_op is None
                # if train_once.summary_train_op is not None:
                #   summary_str = sess.run(train_once.summary_train_op, feed_dict=feed_dict)
                #   train_once.summary_writer.add_summary(summary_str, step)

                if eval_ops is None:
                    # #get train loss, for every batch train
                    # if train_once.summary_op is not None:
                    #   #timer2 = gezi.Timer('sess run')
                    #   try:
                    #     # TODO FIXME so this means one more train batch step without adding to global step counter ?! so should move it earlier
                    #     summary_str = sess.run(train_once.summary_op, feed_dict=feed_dict)
                    #   except Exception:
                    #     if not hasattr(train_once, 'num_summary_errors'):
                    #       logging.warning('summary_str = sess.run(train_once.summary_op, feed_dict=feed_dict) fail')
                    #       train_once.num_summary_errors = 1
                    #       logging.warning(traceback.format_exc())
                    #     summary_str = ''
                    #   # #timer2.print()
                    if train_once.summary_op is not None:
                        for summary_str in summary_strs:
                            train_once.summary_writer.add_summary(
                                summary_str, step)
                else:
                    # #get eval loss for every batch eval, then add train loss for eval step average loss
                    # try:
                    #   summary_str = sess.run(train_once.summary_op, feed_dict=eval_feed_dict) if train_once.summary_op is not None else ''
                    # except Exception:
                    #   if not hasattr(train_once, 'num_summary_errors'):
                    #     logging.warning('summary_str = sess.run(train_once.summary_op, feed_dict=eval_feed_dict) fail')
                    #     train_once.num_summary_errors = 1
                    #     logging.warning(traceback.format_exc())
                    #   summary_str = ''
                    #all single value results will be add to summary here not using tf.scalar_summary..
                    #summary.ParseFromString(summary_str)
                    for summary_str in summary_strs:
                        train_once.summary_writer.add_summary(
                            summary_str, step)
                    suffix = 'eval' if not eval_names else ''
                    melt.add_summarys(summary,
                                      eval_results,
                                      eval_names_,
                                      suffix=suffix)

                if ops is not None:
                    melt.add_summarys(summary,
                                      train_average_loss,
                                      names_,
                                      suffix='train_avg')
                    ##optimizer has done this also
                    melt.add_summary(summary, learning_rate, 'learning_rate')
                    melt.add_summary(summary, melt.batch_size(), 'batch_size')
                    melt.add_summary(summary, melt.epoch(), 'epoch')
                    if steps_per_second:
                        melt.add_summary(summary, steps_per_second,
                                         'steps_per_second')
                    if instances_per_second:
                        melt.add_summary(summary, instances_per_second,
                                         'instances_per_second')
                    if hours_per_epoch:
                        melt.add_summary(summary, hours_per_epoch,
                                         'hours_per_epoch')

                if metric_evaluate:
                    #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
                    prefix = 'step/valid'
                    if model_path:
                        prefix = 'epoch/valid'
                        if not hasattr(train_once, 'epoch_step'):
                            train_once.epoch_step = 1
                        else:
                            train_once.epoch_step += 1
                        step = train_once.epoch_step

                    melt.add_summarys(summary,
                                      evaluate_results,
                                      evaluate_names,
                                      prefix=prefix)

                train_once.summary_writer.add_summary(summary, step)
                train_once.summary_writer.flush()

                #timer_.print()

            # if print_time:
            #   full_duration = train_once.eval_timer.elapsed()
            #   if metric_evaluate:
            #     metric_full_duration = train_once.metric_eval_timer.elapsed()
            #   full_duration_str = 'elapsed:{:.3f} '.format(full_duration)
            #   #info.write('duration:{:.3f} '.format(timer.elapsed()))
            #   duration = timer.elapsed()
            #   info.write('duration:{:.3f} '.format(duration))
            #   info.write(full_duration_str)
            #   info.write('eval_time_ratio:{:.3f} '.format(duration/full_duration))
            #   if metric_evaluate:
            #     info.write('metric_time_ratio:{:.3f} '.format(duration/metric_full_duration))
            # #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, info.getvalue())
            # logging.info2('{} {} {}'.format(epoch_str, 'eval_step: %d'%step, info.getvalue()))
            return stop
        elif metric_evaluate:
            summary = tf.Summary()
            for summary_str in summary_strs:
                train_once.summary_writer.add_summary(summary_str, step)
            #summary.ParseFromString(evaluate_summaries)
            summary_writer = train_once.summary_writer
            prefix = 'step/valid'
            if model_path:
                prefix = 'epoch/valid'
                if not hasattr(train_once, 'epoch_step'):
                    ## TODO: on restart this counter would otherwise reset to 1 again
                    #epoch_step = tf.Variable(0, trainable=False, name='epoch_step')
                    #epoch_step += 1
                    #train_once.epoch_step = sess.run(epoch_step)
                    valid_interval_epochs = 1.
                    try:
                        valid_interval_epochs = FLAGS.valid_interval_epochs
                    except Exception:
                        pass
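                    # recover the validation-step counter after a restart by counting
                    # how many validation intervals fit into the current epoch,
                    # e.g. melt.epoch() == 2.5 with valid_interval_epochs == 0.5
                    # gives int(25 / 5) == 5; the *10 avoids float-division issues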
                    train_once.epoch_step = 1 if melt.epoch() <= 1 else int(
                        int(melt.epoch() * 10) /
                        int(valid_interval_epochs * 10))
                    logging.info('train_once epoch start step is',
                                 train_once.epoch_step)
                else:
                    #epoch_step += 1
                    train_once.epoch_step += 1
                step = train_once.epoch_step
            #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
            melt.add_summarys(summary,
                              evaluate_results,
                              evaluate_names,
                              prefix=prefix)
            summary_writer.add_summary(summary, step)
            summary_writer.flush()
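
The melt.add_summary / melt.add_summarys helpers used throughout the example above come from the project's own melt library. As a rough sketch only, assuming the TF1-style tf.Summary protobuf seen in this code (the names and call signatures are taken from the example; the bodies are an assumption, not melt's actual code):

def add_summary(summary, value, tag, prefix='', suffix=''):
  # append one scalar to an existing tf.Summary proto; the proto is later
  # written with summary_writer.add_summary(summary, step)
  if prefix:
    tag = '%s/%s' % (prefix, tag)
  if suffix:
    tag = '%s/%s' % (tag, suffix)
  summary.value.add(tag=tag, simple_value=float(value))

def add_summarys(summary, values, tags, prefix='', suffix=''):
  # batched variant used for eval_results / evaluate_results above
  for value, tag in zip(values, tags):
    add_summary(summary, value, tag, prefix=prefix, suffix=suffix)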
Example #15
0
def evaluate_translation(predictor, random=False, index=None):
  timer = gezi.Timer('evaluate_translation')

  refs = prepare_refs()

  imgs, img_features = get_image_names_and_features()
  num_metric_eval_examples = min(FLAGS.num_metric_eval_examples, len(imgs))
  if num_metric_eval_examples <= 0:
    num_metric_eval_examples = len(imgs)
  if num_metric_eval_examples == len(imgs):
    random = False

  step = FLAGS.metric_eval_batch_size

  if random:
    if index is None:
      index = np.random.choice(len(imgs), num_metric_eval_examples, replace=False)
    imgs = imgs[index]
    img_features = img_features[index]
  else:
    img_features = img_features[:num_metric_eval_examples]

  results = {}

  start = 0
  while start < num_metric_eval_examples:
    end = start + step
    if end > num_metric_eval_examples:
      end = num_metric_eval_examples
    print('predicts image start:', start, 'end:', end, file=sys.stderr, end='\r')
    translation_predicts(imgs[start: end], img_features[start: end], predictor, results)
    start = end
    
  scorers = [
            (Bleu(4), ["bleu_1", "bleu_2", "bleu_3", "bleu_4"]),
            (Meteor(),"meteor"),
            (Rouge(), "rouge_l"),
            (Cider(), "cider")
        ]
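  # The scorer classes appear to come from the COCO caption evaluation toolkit:
  # compute_score(refs, results) takes two dicts keyed by example id, each mapping
  # to a list of sentences, and returns (corpus_score, per_example_scores);
  # for Bleu(4) both return values are lists of four entries, matching the names above.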

  score_list = []
  metric_list = []

  selected_refs = {}
  selected_results = {}
  # build both dicts from the predicted keys so refs and results share exactly the same key set
  for key in results:
    selected_refs[key] = refs[key]
    selected_results[key] = results[key]
    assert len(selected_results[key]) == 1, selected_results[key]
  assert selected_results.keys() == selected_refs.keys(), '%d %d'%(len(selected_results.keys()), len(selected_refs.keys())) 

  if FLAGS.eval_translation_reseg:
    print('tokenization...', file=sys.stderr)
    global tokenizer
    if tokenizer is None:
      tokenizer = PTBTokenizer()
    selected_refs  = tokenizer.tokenize(selected_refs)
    selected_results = tokenizer.tokenize(selected_results)

  logging.info('predict&label:{}{}{}'.format('|'.join(list(selected_results.items())[0][1]), '---', '|'.join(list(selected_refs.items())[0][1])))

  for scorer, method in scorers:
    print('computing %s score...' % scorer.method(), file=sys.stderr)
    score, scores = scorer.compute_score(selected_refs, selected_results)
    if isinstance(method, list):
      for sc, scs, m in zip(score, scores, method):
        score_list.append(sc)
        metric_list.append(m)
        if FLAGS.eval_result_dir:
          with open(os.path.join(FLAGS.eval_result_dir, m + '.txt'), 'w') as out:
            for i, s in enumerate(scs):
              key = list(selected_results.keys())[i]
              result = selected_results[key]
              refs_str = '\x01'.join(selected_refs[key])
              print(key, result, refs_str, s, sep='\t', file=out)
    else:
      score_list.append(score)
      metric_list.append(method)
      if FLAGS.eval_result_dir:
        with open(os.path.join(FLAGS.eval_result_dir, method + '.txt'), 'w') as out:
          for i, s in enumerate(scores):
            key = list(selected_results.keys())[i]
            result = selected_results[key]
            refs_str = '\x01'.join(selected_refs[key])
            print(key, result, refs_str, s, sep='\t', file=out)
  
  #exclude "bleu_1", "bleu_2", "bleu_3"
  score_list, metric_list = score_list[3:], metric_list[3:]
  assert(len(score_list) == 4)

  avg_score = sum(score_list) / len(score_list)
  score_list.append(avg_score)
  metric_list.append('avg')
  metric_list = ['trans_' + x for x in metric_list]

  melt.logging_results(
    score_list,
    metric_list,
    tag='evaluate: epoch:{} step:{} train:{} eval:{}'.format(
      melt.epoch(), 
      melt.step(),
      melt.train_loss(),
      melt.eval_loss()))

  timer.print()

  return score_list, metric_list
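
For reference, the refs/results dicts passed to compute_score above map one key per image to a list of sentences, with exactly one hypothesis per key in results. A minimal, hypothetical usage sketch (the image ids and captions are made up, and the import path depends on how the coco-caption code is packaged locally):

from pycocoevalcap.cider.cider import Cider  # hypothetical import path

refs = {
  'img_001': ['a dog runs on the grass', 'a brown dog running outdoors'],
  'img_002': ['two people ride bicycles'],
}
results = {
  'img_001': ['a dog running on grass'],  # exactly one hypothesis per key
  'img_002': ['people riding bikes'],
}
score, scores = Cider().compute_score(refs, results)
print('corpus CIDEr:', score)      # single float
print('per-image CIDEr:', scores)  # one score per example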
Example #16
0
def evaluate_scores(predictor, random=False, index=None, exact_predictor=None, exact_ratio=1.):
  """
  Rank-metric evaluation; by default reports recall@1,2,5,10,50.
  """
  timer = gezi.Timer('evaluate_scores')
  init()
  if FLAGS.eval_img2text:
    imgs, img_features = get_image_names_and_features()
    num_metric_eval_examples = min(FLAGS.num_metric_eval_examples, len(imgs)) 
    if num_metric_eval_examples <= 0:
      num_metric_eval_examples = len(imgs)
    if num_metric_eval_examples == len(imgs):
      random = False

    step = FLAGS.metric_eval_batch_size

    if random:
      if index is None:
        index = np.random.choice(len(imgs), num_metric_eval_examples, replace=False)
      imgs = imgs[index]
      img_features = img_features[index]
    else:
      img_features = img_features[:num_metric_eval_examples]

    rank_metrics = gezi.rank_metrics.RecallMetrics()

    start = 0
    while start < num_metric_eval_examples:
      end = start + step
      if end > num_metric_eval_examples:
        end = num_metric_eval_examples
      print('predicts image start:', start, 'end:', end, file=sys.stderr, end='\r')
      predicts(imgs[start: end], img_features[start: end], predictor, rank_metrics, 
               exact_predictor=exact_predictor, exact_ratio=exact_ratio)
      start = end
      
    melt.logging_results(
      rank_metrics.get_metrics(), 
      rank_metrics.get_names(), 
      tag='evaluate: epoch:{} step:{} train:{} eval:{}'.format(
        melt.epoch(), 
        melt.step(),
        melt.train_loss(),
        melt.eval_loss()))

  if FLAGS.eval_text2img:
    num_metric_eval_examples = min(FLAGS.num_metric_eval_examples, len(all_distinct_texts))
    # step is only set in the eval_img2text branch above, so set it here as well
    step = FLAGS.metric_eval_batch_size

    if random:
      index = np.random.choice(len(all_distinct_texts), num_metric_eval_examples, replace=False)
      text_strs = all_distinct_text_strs[index]
      texts = all_distinct_texts[index]
    else:
      text_strs = all_distinct_text_strs
      texts = all_distinct_texts

    rank_metrics2 = gezi.rank_metrics.RecallMetrics()

    start = 0
    while start < num_metric_eval_examples:
      end = start + step
      if end > num_metric_eval_examples:
        end = num_metric_eval_examples
      print('predicts start:', start, 'end:', end, file=sys.stderr, end='\r')
      predicts_txt2im(text_strs[start: end], texts[start: end], predictor, rank_metrics2, exact_predictor=exact_predictor)
      start = end
    
    melt.logging_results(
      rank_metrics2.get_metrics(), 
      ['t2i' + x for x in rank_metrics2.get_names()],
      tag='text2img')

  timer.print()

  if FLAGS.eval_img2text and FLAGS.eval_text2img:
    return rank_metrics.get_metrics() + rank_metrics2.get_metrics(), rank_metrics.get_names() + ['t2i' + x for x in rank_metrics2.get_names()]
  elif FLAGS.eval_img2text:
    return rank_metrics.get_metrics(), rank_metrics.get_names()
  else:
    return rank_metrics2.get_metrics(), rank_metrics2.get_names()
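
gezi.rank_metrics.RecallMetrics used above is part of the project's own gezi library; its real interface is not shown here. A minimal sketch of the recall@k bookkeeping such a class presumably does (class and method names reused from the calls above; the body is an assumption):

class RecallMetrics(object):
  """Tracks recall@k: the fraction of queries whose true item ranks within the top k."""

  def __init__(self, ks=(1, 2, 5, 10, 50)):
    self.ks = ks
    self.hits = {k: 0 for k in ks}
    self.total = 0

  def add(self, rank):
    # rank is the 1-based position of the correct text/image for one query
    self.total += 1
    for k in self.ks:
      if rank <= k:
        self.hits[k] += 1

  def get_metrics(self):
    return [self.hits[k] / float(max(self.total, 1)) for k in self.ks]

  def get_names(self):
    return ['recall@%d' % k for k in self.ks]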