Exemple #1
0
def print_img(img, i):
    """Log the `img_html` template filled with an image and training stats.

    Args:
        img: image name or absolute URL. Values that do not already start with
            "http://" or "https://" are prefixed with FLAGS.image_url_prefix.
        i: value placed in the template's second positional field.
    """
    # Treat https URLs as absolute too; previously only "http://" was checked,
    # so an https URL got FLAGS.image_url_prefix wrongly prepended.
    if img.startswith(("http://", "https://")):
        img_url = img
    else:
        img_url = FLAGS.image_url_prefix + img
    logging.info(
        img_html.format(img_url, i, img, melt.epoch(), melt.step(),
                        melt.train_loss(), melt.eval_loss(), melt.duration(),
                        gezi.now_time()))
Exemple #2
0
def print_img(img, i):
  """Log `img_html` filled with the image URL, its index and training stats."""
  template_args = (
      get_img_url(img),
      i,
      img,
      melt.epoch(),
      melt.step(),
      melt.train_loss(),
      melt.eval_loss(),
      melt.duration(),
      gezi.now_time(),
  )
  logging.info(img_html.format(*template_args))
Exemple #3
0
def evaluate_scores(predictor, random=False):
    """Run image->text recall evaluation in batches and log the metrics.

    Args:
        predictor: model handle forwarded to `predicts` for each batch.
        random: when truthy, evaluate a random subset of the images (sampled
            without replacement) instead of the leading ones.

    Returns:
        (metrics, names) taken from the RecallMetrics accumulator.
    """
    eval_timer = gezi.Timer('evaluate_scores')
    init()
    imgs, img_features = get_image_names_and_features()

    total = min(FLAGS.num_metric_eval_examples, len(imgs))
    batch_size = FLAGS.metric_eval_batch_size

    if random:
        picked = np.random.choice(len(imgs), total, replace=False)
        imgs, img_features = imgs[picked], img_features[picked]

    text_max_words = all_distinct_texts.shape[1]
    rank_metrics = gezi.rank_metrics.RecallMetrics()

    print('text_max_words:', text_max_words)
    begin = 0
    while begin < total:
        # Clamp the last batch to the evaluation limit.
        finish = min(begin + batch_size, total)
        print('predicts start:', begin, 'end:', finish, file=sys.stderr)
        predicts(imgs[begin:finish], img_features[begin:finish], predictor,
                 rank_metrics)
        begin = finish

    metric_values = rank_metrics.get_metrics()
    metric_names = rank_metrics.get_names()
    tag = 'evaluate: epoch:{} step:{} train:{} eval:{}'.format(
        melt.epoch(), melt.step(), melt.train_loss(), melt.eval_loss())
    melt.logging_results(metric_values, metric_names, tag=tag)

    eval_timer.print()

    return rank_metrics.get_metrics(), rank_metrics.get_names()
Exemple #4
0
def tf_train_flow(
        train_once_fn,
        model_dir='./model',
        max_models_keep=1,
        save_interval_seconds=600,
        save_interval_steps=1000,
        num_epochs=None,
        num_steps=None,
        save_model=True,
        save_interval_epochs=1,
        num_steps_per_epoch=0,
        restore_from_latest=True,
        metric_eval_fn=None,
        init_fn=None,
        restore_fn=None,
        restore_scope=None,
        save_all_scope=False,  #TODO save load from restore scope only but save all
        variables_to_restore=None,
        variables_to_save=None,  #by default will be the same as variables_to_restore
        sess=None):
    """Training loop similar to tf_flow, plus checkpoint reload and save.

    Initializes all variables, restores from an existing checkpoint under
    `model_dir` when one is found (otherwise calls `init_fn`), then calls
    `train_once_fn(sess, step, is_start=..., fixed_step=...)` once per step
    until it returns True, `num_steps` steps have run, or `num_epochs` epochs
    are done. Step checkpoints are saved every `save_interval_steps`; when
    `num_steps_per_epoch` is set, epoch checkpoints plus best / no-improve
    checkpoints (driven by `melt.eval_loss()`) go under `<model_dir>/epoch`.

    Args:
        train_once_fn: called once per step; returning `True` stops training.
        model_dir: checkpoint directory (a checkpoint path is reduced to its
            directory via gezi.get_dir).
        max_models_keep: `max_to_keep` of the step saver.
        save_interval_seconds: converted to hours for
            `keep_checkpoint_every_n_hours`.
        save_interval_steps: step-checkpoint interval.
        num_epochs: epoch limit; negative runs one step then exits; None or 0
            means no epoch limit.
        num_steps: step limit; negative (or smaller than the resume step) just
            re-saves the loaded model and exits; None or 0 means no limit.
        save_model: whether to write checkpoints.
        save_interval_epochs: epoch-checkpoint interval, in epochs.
        num_steps_per_epoch: steps per epoch; 0 disables all epoch logic.
        restore_from_latest: if False, restore `melt.recent_checkpoint`.
        metric_eval_fn: optional callable run once when training stops.
        init_fn: optional `init_fn(sess)` used when no checkpoint is restored.
        restore_fn: optional `restore_fn(sess)` run before restoring.
        restore_scope: if set, only GLOBAL_VARIABLES in this scope are
            restored/saved by default.
        save_all_scope: if True, save all variables regardless of scope.
        variables_to_restore: explicit restore list; defaults to the scope's
            variables, or to those present in the checkpoint.
        variables_to_save: explicit save list; defaults to
            variables_to_restore.
        sess: session to use; defaults to melt.get_session().

    On normal completion the process exits via sys.exit(0); an unexpected
    stop re-raises the original OutOfRangeError.
    """
    if sess is None:
        #TODO melt.get_session is global session but may cause non close at last
        sess = melt.get_session()
    logging.info('tf_train_flow start')
    print('max_models_keep:', max_models_keep, file=sys.stderr)
    print('save_interval_seconds:', save_interval_seconds, file=sys.stderr)

    # Restricting to restore_scope lets another model live in the same graph
    # under a different scope while only this scope's vars are loaded/saved.
    var_list = None if not restore_scope else tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)
    if not variables_to_restore:
        variables_to_restore = var_list
    if not variables_to_save:
        variables_to_save = variables_to_restore
    if save_all_scope:
        variables_to_save = None

    if variables_to_restore is None:
        # Restore every graph variable whose name is in the checkpoint; saving
        # may still cover more variables than the checkpoint had.
        varnames_in_checkpoint = melt.get_checkpoint_varnames(model_dir)
        variables_to_restore = slim.get_variables_to_restore(
            include=varnames_in_checkpoint)

    loader = tf.train.Saver(var_list=variables_to_restore)

    saver = tf.train.Saver(
        max_to_keep=max_models_keep,
        keep_checkpoint_every_n_hours=save_interval_seconds / 3600.0,
        var_list=variables_to_save)
    epoch_saver = tf.train.Saver(var_list=variables_to_save, max_to_keep=1000)
    best_epoch_saver = tf.train.Saver(var_list=variables_to_save)

    # Always initialize everything; a restore below overwrites the restored
    # variables. An assistant predictor sharing this session would be
    # re-initialized too, so it should use its own session.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    # pre_step is the step of the last saved checkpoint (-1 when training from
    # scratch). fixed_pre_step rescales it so epoch numbering stays correct if
    # the batch size changed since that checkpoint was written.
    pre_step = -1
    fixed_pre_step = -1
    model_path = _get_model_path(model_dir, save_model)
    model_dir = gezi.get_dir(
        model_dir)  #incase you pass ./model/model-ckpt1000 -> ./model
    if model_path is not None:
        if not restore_from_latest:
            print('using recent but not latest model', file=sys.stderr)
            model_path = melt.recent_checkpoint(model_dir)
        timer = gezi.Timer('Loading and training from existing model [%s]' %
                           model_path)
        if restore_fn is not None:
            restore_fn(sess)
        loader.restore(sess, model_path)
        timer.print()
        pre_step = melt.get_model_step(model_path)
        pre_epoch = melt.get_model_epoch(model_path)
        fixed_pre_step = pre_step
        # Guard num_steps_per_epoch: with the default 0 the ratio below used to
        # raise ZeroDivisionError whenever the checkpoint carried an epoch.
        if pre_epoch is not None and num_steps_per_epoch:
            #like using batch size 32, then reload train using batch size 64
            if abs(pre_step / num_steps_per_epoch - pre_epoch) > 0.1:
                fixed_pre_step = int(pre_epoch * num_steps_per_epoch)
                logging.info(
                    'Warning, epoch is diff with pre_step / num_steps_per_epoch:{}, pre_epoch:{},maybe you change batch size and we will adjust to set pre_step as {}'
                    .format(pre_step / num_steps_per_epoch, pre_epoch,
                            fixed_pre_step))
    else:
        print('Train all start step 0', file=sys.stderr)
        # e.g. finetuning: initialize from another model's weights.
        if init_fn is not None:
            init_fn(sess)

    if save_interval_epochs and num_steps_per_epoch:
        epoch_dir = os.path.join(model_dir, 'epoch')
        gezi.try_mkdir(epoch_dir)

    coord = tf.train.Coordinator()
    tf.train.start_queue_runners(sess=sess, coord=coord)
    checkpoint_path = os.path.join(model_dir, 'model.ckpt')

    tf.train.write_graph(sess.graph_def, model_dir, 'train.pbtxt')
    only_one_step = False
    try:
        step = start = pre_step + 1
        fixed_step = fixed_pre_step + 1
        # Hack: a negative num_steps (or one below the resume step) just
        # re-saves the loaded model and exits. The explicit None check matters
        # on Python 2, where `None < 0` is True and the default num_steps=None
        # used to trigger this path.
        if num_steps is not None and (num_steps < 0 or
                                      (num_steps and num_steps < step)):
            print('just load and resave then exit', file=sys.stderr)
            saver.save(sess,
                       _get_checkpoint_path(checkpoint_path, step,
                                            num_steps_per_epoch),
                       global_step=step)
            sess.close()
            sys.exit(0)

        # Same Python 2 `None < 0` pitfall for the default num_epochs=None.
        if num_epochs is not None and num_epochs < 0:
            only_one_step = True
            print('just run one step', file=sys.stderr)

        num_bad_epochs = 0
        pre_epoch_eval_loss = 1e20
        best_epoch_eval_loss = 1e20
        # Warn after num_allowed_bad_epochs + 1 consecutive non-decreasing
        # eval losses (early stopping on this is currently disabled).
        num_allowed_bad_epochs = 4
        while not coord.should_stop():
            stop = train_once_fn(sess,
                                 step,
                                 is_start=(step == start),
                                 fixed_step=fixed_step)
            if only_one_step:
                stop = True
            # Step 0 is skipped by the `and step` guard.
            if save_model and step:
                if step % save_interval_steps == 0:
                    timer = gezi.Timer('save model step %d to %s' %
                                       (step, checkpoint_path))
                    saver.save(sess,
                               _get_checkpoint_path(checkpoint_path,
                                                    fixed_step,
                                                    num_steps_per_epoch),
                               global_step=step)
                    timer.print()
                # Epoch boundaries use fixed_step so they stay aligned after a
                # batch-size change.
                if save_interval_epochs and num_steps_per_epoch and fixed_step % num_steps_per_epoch == 0:
                    epoch = fixed_step // num_steps_per_epoch
                    eval_loss = melt.eval_loss()
                    if eval_loss:
                        # e.g. "['eval_loss:3.2','eval_accuracy:4.3']" -> 3.2
                        eval_loss = float(
                            eval_loss.strip('[]').split(',')[0].strip(
                                "'").split(':')[-1])
                        best_loss_file = os.path.join(epoch_dir,
                                                      'best_eval_loss.txt')
                        if os.path.exists(best_loss_file):
                            with open(best_loss_file) as f:
                                best_epoch_eval_loss = float(
                                    f.readline().split()[-1].strip())
                        if eval_loss < best_epoch_eval_loss:
                            best_epoch_eval_loss = eval_loss
                            logging.info(
                                'Now best eval loss is epoch %d eval_loss:%f' %
                                (epoch, eval_loss))
                            with open(best_loss_file, 'w') as f:
                                f.write('%d %d %f\n' %
                                        (epoch, step, best_epoch_eval_loss))
                            best_epoch_saver.save(
                                sess, os.path.join(epoch_dir,
                                                   'model.ckpt-best'))

                        with open(os.path.join(epoch_dir, 'eval_loss.txt'),
                                  'a') as f:
                            f.write('%d %d %f\n' % (epoch, step, eval_loss))
                        if eval_loss >= pre_epoch_eval_loss:
                            num_bad_epochs += 1
                            if num_bad_epochs > num_allowed_bad_epochs:
                                logging.warning(
                                    'Evaluate loss not decrease for last %d epochs'
                                    % (num_allowed_bad_epochs + 1))
                                noimprove_path = os.path.join(
                                    epoch_dir, 'model.ckpt-noimprove')
                                # V2 checkpoints write <prefix>.index rather
                                # than <prefix>; check both so the no-improve
                                # model is saved only once.
                                if not (os.path.exists(noimprove_path) or
                                        os.path.exists(noimprove_path +
                                                       '.index')):
                                    best_epoch_saver.save(sess,
                                                          noimprove_path)
                        else:
                            num_bad_epochs = 0
                        pre_epoch_eval_loss = eval_loss
                    if step % (num_steps_per_epoch *
                               save_interval_epochs) == 0:
                        epoch_saver.save(sess,
                                         os.path.join(epoch_dir,
                                                      'model.ckpt-%d' % epoch),
                                         global_step=step)
            if stop is True:
                print('Early stop running %d stpes' % (step), file=sys.stderr)
                raise tf.errors.OutOfRangeError(
                    None, None, 'Early stop running %d stpes' % (step))
            if num_steps and (step + 1) == start + num_steps:
                raise tf.errors.OutOfRangeError(None, None,
                                                'Reached max num steps')
            max_num_epochs = num_epochs
            if max_num_epochs and num_steps_per_epoch and step // num_steps_per_epoch >= max_num_epochs:
                raise tf.errors.OutOfRangeError(
                    None, None,
                    'Reached max num epochs of %d' % max_num_epochs)
            step += 1
            fixed_step += 1
    except tf.errors.OutOfRangeError:
        if not (step
                == start) and save_model and step % save_interval_steps != 0:
            saver.save(sess,
                       _get_checkpoint_path(checkpoint_path, step,
                                            num_steps_per_epoch),
                       global_step=step)
        if only_one_step:
            print('Done one step', file=sys.stderr)
            sys.exit(0)
        if metric_eval_fn is not None:
            metric_eval_fn()
        # Epoch progress is undefined without num_steps_per_epoch; the bare
        # division used to raise ZeroDivisionError here with the default 0.
        cur_epoch = (step / num_steps_per_epoch
                     if num_steps_per_epoch else float('nan'))
        if (num_epochs and num_steps_per_epoch and cur_epoch >= num_epochs) or (
                num_steps and (step + 1) == start + num_steps):
            print('Done training for %.3f epochs, %d steps.' %
                  (cur_epoch, step + 1),
                  file=sys.stderr)
            #FIXME because coord.join seems not work, RuntimeError: Coordinator stopped with threads still running: Thread-9
            sys.exit(0)
        else:
            print('Should not stop, but stopped at epoch: %.3f' % cur_epoch,
                  file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            # Bare raise keeps the original traceback (`raise e` dropped it on
            # Python 2).
            raise
Exemple #5
0
def tf_train_flow(train_once_fn, 
                  model_dir='./model', 
                  max_models_keep=1, 
                  save_interval_seconds=600, 
                  save_interval_steps=1000, 
                  num_epochs=None,
                  num_steps=None, 
                  save_model=True,
                  save_interval_epochs=1, 
                  num_steps_per_epoch=0,
                  restore_from_latest=True,
                  metric_eval_fn=None,
                  init_fn=None,
                  sess=None):
  """Training loop similar to tf_flow, plus checkpoint reload and save.

  Restores all variables from an existing checkpoint under `model_dir` when
  one is found (otherwise initializes them and calls `init_fn`), then calls
  `train_once_fn(sess, step, is_start=...)` once per step until it returns
  True, `num_steps` steps have run, or `num_epochs` epochs are done. Step
  checkpoints are saved every `save_interval_steps`; when
  `num_steps_per_epoch` is set, epoch / best / no-improve checkpoints go under
  `<model_dir>/epoch`.

  Args:
    train_once_fn: called once per step; returning `True` stops training.
    model_dir: checkpoint directory (a checkpoint path is reduced to its
      directory via gezi.get_dir).
    max_models_keep: `max_to_keep` of the step saver.
    save_interval_seconds: converted to hours for
      `keep_checkpoint_every_n_hours`.
    save_interval_steps: step-checkpoint interval.
    num_epochs: epoch limit; None or 0 means no epoch limit.
    num_steps: step limit; a value below the resume step just re-saves the
      loaded model and exits. None or 0 means no step limit.
    save_model: whether to write checkpoints.
    save_interval_epochs: epoch-checkpoint interval, in epochs.
    num_steps_per_epoch: steps per epoch; 0 disables all epoch logic.
    restore_from_latest: if False, restore `melt.recent_checkpoint`.
    metric_eval_fn: optional callable run once when training stops.
    init_fn: optional `init_fn(sess)` used when no checkpoint is restored.
    sess: session to use; defaults to melt.get_session().

  On normal completion the process exits via sys.exit(0); an unexpected stop
  re-raises the original OutOfRangeError.
  """
  if sess is None:
    #TODO melt.get_session is global session but may cause
    sess = melt.get_session()
  logging.info('tf_train_flow start')
  print('max_models_keep:', max_models_keep)
  print('save_interval_seconds:', save_interval_seconds)

  saver = tf.train.Saver(
    max_to_keep=max_models_keep,
    keep_checkpoint_every_n_hours=save_interval_seconds / 3600.0)

  epoch_saver = tf.train.Saver()
  best_epoch_saver = tf.train.Saver()

  # pre_step is the step of the last saved checkpoint; -1 when training from scratch.
  pre_step = -1
  model_path = _get_model_path(model_dir, save_model)
  model_dir = gezi.get_dir(model_dir)  # in case you pass ./model/model-ckpt1000 -> ./model
  if model_path is not None:
    if not restore_from_latest:
      print('using recent but not latest model', file=sys.stderr)
      model_path = melt.recent_checkpoint(model_dir)
    model_name = os.path.basename(model_path)
    timer = gezi.Timer('Loading and training from existing model [%s]'%model_path)
    saver.restore(sess, model_path)
    timer.print()
    pre_step = melt.get_model_step(model_path)
    if 'epoch' in model_name:
      # Presumably the number parsed from an epoch checkpoint name is an epoch
      # index; scale it to steps.
      pre_step *= num_steps_per_epoch
    # For non-zero epochs, without this you get:
    # Attempting to use uninitialized value input/input_producer/limit_epochs/epochs
    sess.run(tf.local_variables_initializer())
  else:
    print('Train all start step 0', file=sys.stderr)
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    if init_fn is not None:
      init_fn(sess)

  if save_interval_epochs and num_steps_per_epoch:
    epoch_dir = os.path.join(model_dir, 'epoch')
    gezi.try_mkdir(epoch_dir)

  coord = tf.train.Coordinator()
  tf.train.start_queue_runners(sess=sess, coord=coord)
  checkpoint_path = os.path.join(model_dir, 'model.ckpt')

  tf.train.write_graph(sess.graph_def, model_dir, 'train.pbtxt')
  try:
    step = start = pre_step + 1
    # Hack: a num_steps below the resume step just re-saves the loaded model
    # and exits.
    if num_steps and num_steps < step:
      print('just load and resave then exit', file=sys.stderr)
      saver.save(sess,
                 _get_checkpoint_path(checkpoint_path, step, num_steps_per_epoch),
                 global_step=step)
      sess.close()
      sys.exit(0)

    num_bad_epochs = 0
    pre_epoch_eval_loss = 1e20
    best_epoch_eval_loss = 1e20
    # Warn after num_allowed_bad_epochs + 1 consecutive non-decreasing eval
    # losses (early stopping on this is currently disabled).
    num_allowed_bad_epochs = 4
    while not coord.should_stop():
      stop = train_once_fn(sess, step, is_start=(step == start))
      # Step 0 is skipped by the `and step` guard.
      if save_model and step:
        if step % save_interval_steps == 0:
          timer = gezi.Timer('save model step %d to %s'%(step, checkpoint_path))
          saver.save(sess,
                     _get_checkpoint_path(checkpoint_path, step, num_steps_per_epoch),
                     global_step=step)
          timer.print()
        if save_interval_epochs and num_steps_per_epoch and step % num_steps_per_epoch == 0:
          epoch = step // num_steps_per_epoch
          eval_loss = melt.eval_loss()
          if eval_loss:
            # e.g. "['eval_loss:3.2','eval_accuracy:4.3']" -> 3.2
            eval_loss = float(eval_loss.strip('[]').split(',')[0].strip("'").split(':')[-1])
            best_loss_file = os.path.join(epoch_dir, 'best_eval_loss.txt')
            if os.path.exists(best_loss_file):
              with open(best_loss_file) as f:
                best_epoch_eval_loss = float(f.readline().split()[-1].strip())
            if eval_loss < best_epoch_eval_loss:
              best_epoch_eval_loss = eval_loss
              logging.info('Now best eval loss is epoch %d eval_loss:%f' % (epoch, eval_loss))
              with open(best_loss_file, 'w') as f:
                f.write('%d %d %f\n'%(epoch, step, best_epoch_eval_loss))
              best_epoch_saver.save(sess,
                                    os.path.join(epoch_dir,'model.cpkt-best'))

            with open(os.path.join(epoch_dir, 'eval_loss.txt'), 'a') as f:
              f.write('%d %d %f\n'%(epoch, step, eval_loss))
            if eval_loss >= pre_epoch_eval_loss:
              num_bad_epochs += 1
              if num_bad_epochs > num_allowed_bad_epochs:
                logging.warning('Evaluate loss not decrease for last %d epochs'% (num_allowed_bad_epochs + 1))
                noimprove_path = os.path.join(epoch_dir,'model.cpkt-noimprove')
                # V2 checkpoints write <prefix>.index rather than <prefix>;
                # check both so the no-improve model is saved only once.
                if not (os.path.exists(noimprove_path) or
                        os.path.exists(noimprove_path + '.index')):
                  best_epoch_saver.save(sess, noimprove_path)
            else:
              num_bad_epochs = 0
            pre_epoch_eval_loss = eval_loss
          if step % (num_steps_per_epoch * save_interval_epochs) == 0:
            epoch_saver.save(sess,
                             os.path.join(epoch_dir,'model.cpkt-%d'%epoch),
                             global_step=step)
      if stop is True:
        print('Early stop running %d stpes'%(step), file=sys.stderr)
        raise tf.errors.OutOfRangeError(None, None,'Early stop running %d stpes'%(step))
      if num_steps and (step + 1) == start + num_steps:
        raise tf.errors.OutOfRangeError(None, None,'Reached max num steps')
      max_num_epochs = num_epochs
      # Require a set epoch limit: on Python 2 `x >= None` is True, which used
      # to stop training after the first step whenever num_epochs was None.
      if max_num_epochs and num_steps_per_epoch and step // num_steps_per_epoch >= max_num_epochs:
        raise tf.errors.OutOfRangeError(None, None,'Reached max num epochs of %d'%max_num_epochs)
      step += 1
  except tf.errors.OutOfRangeError:
    if not (step == start) and save_model and step % save_interval_steps != 0:
      saver.save(sess,
                 _get_checkpoint_path(checkpoint_path, step, num_steps_per_epoch),
                 global_step=step)
    if metric_eval_fn is not None:
      metric_eval_fn()
    # Epoch progress is undefined without num_steps_per_epoch; the bare
    # division used to raise ZeroDivisionError here with the default 0.
    cur_epoch = step / num_steps_per_epoch if num_steps_per_epoch else float('nan')
    if (num_epochs and num_steps_per_epoch and cur_epoch >= num_epochs) or (num_steps and (step + 1) == start + num_steps):
      print('Done training for %.3f epochs, %d steps.' % (cur_epoch, step + 1), file=sys.stderr)
      #FIXME because coord.join seems not work, RuntimeError: Coordinator stopped with threads still running: Thread-9
      sys.exit(0)
    else:
      print('Should not stop, but stopped at epoch: %.3f'%(cur_epoch), file=sys.stderr)
      print(traceback.format_exc(), file=sys.stderr)
      # Bare raise keeps the original traceback (`raise e` dropped it on Python 2).
      raise
Exemple #6
0
def evaluate_scores(predictor, random=False):
    """Run recall-based ranking evaluation for image->text and/or text->image.

    FLAGS.eval_img2text and FLAGS.eval_text2img select which directions run.

    Args:
        predictor: model handle forwarded to `predicts` / `predicts_txt2im`.
        random: when truthy, evaluate a random subset (sampled without
            replacement) instead of the leading examples.

    Returns:
        (metrics, names). When both directions run, text->image results follow
        the image->text ones and their names get a 't2i' prefix. When neither
        direction runs, two empty lists.
    """
    timer = gezi.Timer('evaluate_scores')
    init()
    # Shared by both directions; it used to be set only in the img2text
    # branch, so a text2img-only run raised NameError.
    step = FLAGS.metric_eval_batch_size

    if FLAGS.eval_img2text:
        imgs, img_features = get_image_names_and_features()
        num_metric_eval_examples = min(FLAGS.num_metric_eval_examples,
                                       len(imgs))

        if random:
            index = np.random.choice(len(imgs),
                                     num_metric_eval_examples,
                                     replace=False)
            imgs = imgs[index]
            img_features = img_features[index]

        rank_metrics = gezi.rank_metrics.RecallMetrics()

        start = 0
        while start < num_metric_eval_examples:
            end = min(start + step, num_metric_eval_examples)
            print('predicts image start:', start, 'end:', end, file=sys.stderr)
            predicts(imgs[start:end], img_features[start:end], predictor,
                     rank_metrics)
            start = end

        melt.logging_results(
            rank_metrics.get_metrics(),
            rank_metrics.get_names(),
            tag='evaluate: epoch:{} step:{} train:{} eval:{}'.format(
                melt.epoch(), melt.step(), melt.train_loss(),
                melt.eval_loss()))

    if FLAGS.eval_text2img:
        num_metric_eval_examples = min(FLAGS.num_metric_eval_examples,
                                       len(all_distinct_texts))
        if random:
            index = np.random.choice(len(all_distinct_texts),
                                     num_metric_eval_examples,
                                     replace=False)
            text_strs = all_distinct_text_strs[index]
            texts = all_distinct_texts[index]
        else:
            # Previously left unset, which raised NameError in the loop below.
            text_strs = all_distinct_text_strs
            texts = all_distinct_texts

        rank_metrics2 = gezi.rank_metrics.RecallMetrics()

        start = 0
        while start < num_metric_eval_examples:
            end = min(start + step, num_metric_eval_examples)
            print('predicts start:', start, 'end:', end, file=sys.stderr)
            predicts_txt2im(text_strs[start:end], texts[start:end], predictor,
                            rank_metrics2)
            start = end

        melt.logging_results(rank_metrics2.get_metrics(),
                             ['t2i' + x for x in rank_metrics2.get_names()],
                             tag='text2img')

    timer.print()

    if FLAGS.eval_img2text and FLAGS.eval_text2img:
        return (rank_metrics.get_metrics() + rank_metrics2.get_metrics(),
                rank_metrics.get_names() +
                ['t2i' + x for x in rank_metrics2.get_names()])
    if FLAGS.eval_img2text:
        return rank_metrics.get_metrics(), rank_metrics.get_names()
    if FLAGS.eval_text2img:
        return rank_metrics2.get_metrics(), rank_metrics2.get_names()
    # Neither direction enabled: nothing was evaluated (used to raise NameError).
    return [], []
Exemple #7
0
def evaluate_translation(predictor, random=False, index=None):
  """Score generated captions against references with caption metrics.

  Runs `translation_predicts` over (a subset of) the evaluation images, scores
  the results with Bleu(4), Meteor, Rouge and Cider, keeps bleu_4, meteor,
  rouge_l and cider plus their average, and logs them via melt.logging_results.
  When FLAGS.eval_result_dir is set, per-example scores for every metric are
  written to `<eval_result_dir>/<metric>.txt`.

  Args:
    predictor: model handle forwarded to `translation_predicts`.
    random: when truthy and fewer than all images are evaluated, evaluate a
      random subset sampled without replacement.
    index: optional precomputed image indices used instead of sampling when
      `random` is in effect.

  Returns:
    (score_list, metric_list) with names prefixed 'trans_'.
  """
  global tokenizer
  timer = gezi.Timer('evaluate_translation')

  refs = prepare_refs()

  imgs, img_features = get_image_names_and_features()
  num_metric_eval_examples = min(FLAGS.num_metric_eval_examples, len(imgs))
  if num_metric_eval_examples <= 0:
    num_metric_eval_examples = len(imgs)
  if num_metric_eval_examples == len(imgs):
    random = False

  step = FLAGS.metric_eval_batch_size

  if random:
    if index is None:
      index = np.random.choice(len(imgs), num_metric_eval_examples, replace=False)
    imgs = imgs[index]
    img_features = img_features[index]
  else:
    img_features = img_features[:num_metric_eval_examples]

  results = {}

  start = 0
  while start < num_metric_eval_examples:
    end = min(start + step, num_metric_eval_examples)
    print('predicts image start:', start, 'end:', end, file=sys.stderr, end='\r')
    translation_predicts(imgs[start: end], img_features[start: end], predictor, results)
    start = end

  scorers = [
      (Bleu(4), ["bleu_1", "bleu_2", "bleu_3", "bleu_4"]),
      (Meteor(), "meteor"),
      (Rouge(), "rouge_l"),
      (Cider(), "cider"),
  ]

  # Copy only the predicted keys so refs and results share exactly the same keys.
  selected_refs = {}
  selected_results = {}
  for key in results:
    selected_refs[key] = refs[key]
    selected_results[key] = results[key]
    assert len(selected_results[key]) == 1, selected_results[key]
  assert selected_results.keys() == selected_refs.keys(), '%d %d'%(len(selected_results.keys()), len(selected_refs.keys()))

  if FLAGS.eval_translation_reseg:
    print('tokenization...', file=sys.stderr)
    if tokenizer is None:
      tokenizer = PTBTokenizer()
    selected_refs = tokenizer.tokenize(selected_refs)
    selected_results = tokenizer.tokenize(selected_results)

  if selected_results:
    # Show one prediction next to the references for the SAME key.
    sample_key = next(iter(selected_results))
    logging.info('predict&label:{}{}{}'.format('|'.join(selected_results[sample_key]), '---', '|'.join(selected_refs[sample_key])))

  # Computed once (keys()[i] per row was O(n^2) and Python 2 only).
  # NOTE(review): assumes each scorer returns per-example scores in this key
  # order, as the original code did — verify against the scorer implementation.
  result_keys = list(selected_results.keys())

  def _dump_per_example(metric_name, per_example_scores):
    """Write `key<TAB>result<TAB>refs<TAB>score` lines to <metric_name>.txt."""
    out_path = os.path.join(FLAGS.eval_result_dir, metric_name + '.txt')
    with open(out_path, 'w') as out:
      for key, example_score in zip(result_keys, per_example_scores):
        ref_str = '\x01'.join(selected_refs[key])
        print(key, selected_results[key], ref_str, example_score, sep='\t', file=out)

  score_list = []
  metric_list = []
  for scorer, method in scorers:
    print('computing %s score...'%(scorer.method()), file=sys.stderr)
    score, scores = scorer.compute_score(selected_refs, selected_results)
    if isinstance(method, list):
      for sc, scs, m in zip(score, scores, method):
        score_list.append(sc)
        metric_list.append(m)
        if FLAGS.eval_result_dir:
          _dump_per_example(m, scs)
    else:
      score_list.append(score)
      metric_list.append(method)
      if FLAGS.eval_result_dir:
        # Was named after `m`, a leftover BLEU loop variable, so every
        # single-valued metric overwrote bleu_4.txt.
        _dump_per_example(method, scores)

  # Exclude "bleu_1", "bleu_2", "bleu_3".
  score_list, metric_list = score_list[3:], metric_list[3:]
  assert len(score_list) == 4

  avg_score = sum(score_list) / len(score_list)
  score_list.append(avg_score)
  metric_list.append('avg')
  metric_list = ['trans_' + x for x in metric_list]

  melt.logging_results(
    score_list,
    metric_list,
    tag='evaluate: epoch:{} step:{} train:{} eval:{}'.format(
      melt.epoch(),
      melt.step(),
      melt.train_loss(),
      melt.eval_loss()))

  timer.print()

  return score_list, metric_list
Exemple #8
0
def evaluate_scores(predictor, random=False, index=None, exact_predictor=None, exact_ratio=1.):
  """Rank-metric evaluation; by default recall@1,2,5,10,50.

  FLAGS.eval_img2text and FLAGS.eval_text2img select which directions run.

  Args:
    predictor: model handle forwarded to `predicts` / `predicts_txt2im`.
    random: when truthy, evaluate a random subset (sampled without
      replacement) instead of the leading examples. Each direction decides
      independently, and sampling is skipped for a direction that evaluates
      all of its examples.
    index: optional precomputed image indices used instead of sampling for
      image->text.
    exact_predictor: forwarded to `predicts` and `predicts_txt2im`.
    exact_ratio: forwarded to `predicts`.

  Returns:
    (metrics, names). When both directions run, text->image results follow
    the image->text ones and their names get a 't2i' prefix. When neither
    direction runs, two empty lists.
  """
  timer = gezi.Timer('evaluate_scores')
  init()
  # Shared by both directions; it used to be set only in the img2text branch,
  # so a text2img-only run raised NameError.
  step = FLAGS.metric_eval_batch_size

  if FLAGS.eval_img2text:
    imgs, img_features = get_image_names_and_features()
    num_metric_eval_examples = min(FLAGS.num_metric_eval_examples, len(imgs))
    if num_metric_eval_examples <= 0:
      num_metric_eval_examples = len(imgs)
    # Branch-local flag: turning sampling off here must not also turn it off
    # for the text->image branch (the shared `random` used to be overwritten).
    img_random = random and num_metric_eval_examples != len(imgs)

    if img_random:
      if index is None:
        index = np.random.choice(len(imgs), num_metric_eval_examples, replace=False)
      imgs = imgs[index]
      img_features = img_features[index]
    else:
      img_features = img_features[:num_metric_eval_examples]

    rank_metrics = gezi.rank_metrics.RecallMetrics()

    start = 0
    while start < num_metric_eval_examples:
      end = min(start + step, num_metric_eval_examples)
      print('predicts image start:', start, 'end:', end, file=sys.stderr, end='\r')
      predicts(imgs[start: end], img_features[start: end], predictor, rank_metrics,
               exact_predictor=exact_predictor, exact_ratio=exact_ratio)
      start = end

    melt.logging_results(
      rank_metrics.get_metrics(),
      rank_metrics.get_names(),
      tag='evaluate: epoch:{} step:{} train:{} eval:{}'.format(
        melt.epoch(),
        melt.step(),
        melt.train_loss(),
        melt.eval_loss()))

  if FLAGS.eval_text2img:
    num_metric_eval_examples = min(FLAGS.num_metric_eval_examples, len(all_distinct_texts))
    # Same convention as image->text: a non-positive limit means "all".
    if num_metric_eval_examples <= 0:
      num_metric_eval_examples = len(all_distinct_texts)
    text_random = random and num_metric_eval_examples != len(all_distinct_texts)

    if text_random:
      text_index = np.random.choice(len(all_distinct_texts), num_metric_eval_examples, replace=False)
      text_strs = all_distinct_text_strs[text_index]
      texts = all_distinct_texts[text_index]
    else:
      text_strs = all_distinct_text_strs
      texts = all_distinct_texts

    rank_metrics2 = gezi.rank_metrics.RecallMetrics()

    start = 0
    while start < num_metric_eval_examples:
      end = min(start + step, num_metric_eval_examples)
      print('predicts start:', start, 'end:', end, file=sys.stderr, end='\r')
      predicts_txt2im(text_strs[start: end], texts[start: end], predictor, rank_metrics2, exact_predictor=exact_predictor)
      start = end

    melt.logging_results(
      rank_metrics2.get_metrics(),
      ['t2i' + x for x in rank_metrics2.get_names()],
      tag='text2img')

  timer.print()

  if FLAGS.eval_img2text and FLAGS.eval_text2img:
    return rank_metrics.get_metrics() + rank_metrics2.get_metrics(), rank_metrics.get_names() + ['t2i' + x for x in rank_metrics2.get_names()]
  if FLAGS.eval_img2text:
    return rank_metrics.get_metrics(), rank_metrics.get_names()
  if FLAGS.eval_text2img:
    return rank_metrics2.get_metrics(), rank_metrics2.get_names()
  # Neither direction enabled: nothing was evaluated (used to raise NameError).
  return [], []