Example #1
0
  def gen_valid_input(self, inputs, decode):
    """Construct the validation input pipeline.

    Reads the validation file list and batches it through `inputs`.  When
    FLAGS.show_eval is set, a fixed validation batch is optionally read and
    concatenated in front of the regular one so the same examples can be
    displayed on every evaluation run.

    Args:
      inputs: callable producing (image_name, text, text_str, input_text,
        input_text_str) tensor batches from a list of files.
      decode: decode function forwarded to `inputs`.

    Returns:
      Tuple (eval_result, eval_show_result, eval_batch_size); eval_show_result
      is None unless FLAGS.show_eval is enabled.
    """
    #---------------------- valid
    validset = list_files(FLAGS.valid_input)
    logging.info('validset:{} {}'.format(len(validset), validset[:2]))
    (eval_image_name, eval_text, eval_text_str,
     eval_input_text, eval_input_text_str) = inputs(
        validset,
        decode=decode,
        batch_size=FLAGS.eval_batch_size,
        num_threads=FLAGS.num_threads,
        batch_join=FLAGS.batch_join,
        shuffle=FLAGS.eval_shuffle,
        seed=FLAGS.eval_seed,
        fix_random=FLAGS.eval_fix_random,
        num_prefetch_batches=FLAGS.num_prefetch_batches,
        min_after_dequeue=FLAGS.min_after_dequeue,
        fix_sequence=FLAGS.fix_sequence,
        name=self.input_valid_name)

    eval_batch_size = FLAGS.eval_batch_size

    eval_result = (eval_image_name, eval_text, eval_text_str,
                   eval_input_text, eval_input_text_str)
    eval_show_result = None

    if FLAGS.show_eval:
      eval_fixed = bool(FLAGS.fixed_valid_input)
      self.eval_fixed = eval_fixed
      if eval_fixed:
        n_fixed = FLAGS.num_fixed_evaluate_examples
        assert FLAGS.fixed_eval_batch_size >= n_fixed, '%d %d'%(FLAGS.fixed_eval_batch_size, FLAGS.num_fixed_evaluate_examples)
        logging.info('fixed_eval_batch_size:{}'.format(FLAGS.fixed_eval_batch_size))
        logging.info('num_fixed_evaluate_examples:{}'.format(FLAGS.num_fixed_evaluate_examples))
        logging.info('num_evaluate_examples:{}'.format(FLAGS.num_evaluate_examples))
        #------------------- fixed valid
        fixed_validset = list_files(FLAGS.fixed_valid_input)
        logging.info('fixed_validset:{} {}'.format(len(fixed_validset), fixed_validset[:2]))
        (fixed_image_name, fixed_text, fixed_text_str,
         fixed_input_text, fixed_input_text_str) = inputs(
            fixed_validset,
            decode=decode,
            batch_size=FLAGS.fixed_eval_batch_size,
            fix_sequence=True,
            num_prefetch_batches=FLAGS.num_prefetch_batches,
            min_after_dequeue=FLAGS.min_after_dequeue,
            name=self.fixed_input_valid_name)

        # The input batch may be larger than what we want to show; keep only
        # the leading num_fixed_evaluate_examples rows of every fixed tensor.
        (fixed_image_name, fixed_text, fixed_text_str,
         fixed_input_text, fixed_input_text_str) = [
            melt.first_nrows(t, n_fixed)
            for t in (fixed_image_name, fixed_text, fixed_text_str,
                      fixed_input_text, fixed_input_text_str)]

        # NOTE: the reader always yields FLAGS.fixed_eval_batch_size rows; with
        # fewer fixed examples the data wraps around.
        eval_image_name = tf.concat([fixed_image_name, eval_image_name], axis=0)

        # melt.align is only needed when dynamic batch/padding is used.
        if FLAGS.dynamic_batch_length:
          fixed_text, eval_text = melt.align_col_padding2d(fixed_text, eval_text)
          fixed_input_text, eval_input_text = melt.align_col_padding2d(fixed_input_text, eval_input_text)
        eval_text = tf.concat([fixed_text, eval_text], axis=0)
        eval_text_str = tf.concat([fixed_text_str, eval_text_str], axis=0)
        eval_input_text = tf.concat([fixed_input_text, eval_input_text], axis=0)
        eval_input_text_str = tf.concat([fixed_input_text_str, eval_input_text_str], axis=0)
        eval_batch_size = n_fixed + FLAGS.eval_batch_size

      # Should always be num_fixed_evaluate_examples + num_evaluate_examples,
      # capped by the actual batch size.
      num_evaluate_examples = min(
          eval_batch_size,
          FLAGS.num_fixed_evaluate_examples + FLAGS.num_evaluate_examples)
      print('----num_evaluate_examples', num_evaluate_examples)
      print('----eval_batch_size', eval_batch_size)

      #------------combine valid and fixed valid
      (evaluate_image_name, evaluate_text, evaluate_text_str,
       evaluate_input_text, evaluate_input_text_str) = [
          melt.first_nrows(t, num_evaluate_examples)
          for t in (eval_image_name, eval_text, eval_text_str,
                    eval_input_text, eval_input_text_str)]

      self.num_evaluate_examples = num_evaluate_examples

      eval_result = (eval_image_name, eval_text, eval_text_str,
                     eval_input_text, eval_input_text_str)
      eval_show_result = (evaluate_image_name, evaluate_text, evaluate_text_str,
                          evaluate_input_text, evaluate_input_text_str)

    return eval_result, eval_show_result, eval_batch_size
Example #2
0
    def gen_valid_input(self, inputs, decode_fn):
        """Construct the validation input pipeline (with image features).

        Reads the validation file list and batches it through `inputs`.  When
        FLAGS.show_eval is set, a fixed validation batch is optionally read and
        concatenated in front of the regular one so the same examples can be
        displayed on every evaluation run.

        Args:
          inputs: callable producing (image_name, image_feature, text,
            text_str, input_text, input_text_str) tensor batches from a list
            of files.
          decode_fn: decode function forwarded to `inputs`.

        Returns:
          Tuple (eval_result, eval_show_result, eval_batch_size).
          eval_result is always a 6-tuple (image_name, image_feature, text,
          text_str, input_text, input_text_str); eval_show_result is None
          unless FLAGS.show_eval is enabled.
        """
        #---------------------- valid
        validset = list_files(FLAGS.valid_input)
        logging.info('validset:{} {}'.format(len(validset), validset[:2]))
        eval_image_name, eval_image_feature, eval_text, eval_text_str, eval_input_text, eval_input_text_str = inputs(
            validset,
            decode_fn=decode_fn,
            batch_size=FLAGS.eval_batch_size,
            num_threads=FLAGS.num_threads,
            batch_join=FLAGS.batch_join,
            shuffle_files=FLAGS.eval_shuffle_files,
            seed=FLAGS.eval_seed,
            fix_random=FLAGS.eval_fix_random,
            num_prefetch_batches=FLAGS.num_prefetch_batches,
            min_after_dequeue=FLAGS.min_after_dequeue,
            fix_sequence=FLAGS.fix_sequence,
            name=self.input_valid_name)

        eval_batch_size = FLAGS.eval_batch_size

        # BUG FIX: the default result previously omitted eval_image_feature
        # (a 5-tuple), while the show_eval branch below returned a 6-tuple,
        # so the result arity depended on FLAGS.show_eval.  Always return the
        # full 6-tuple.
        eval_result = eval_image_name, eval_image_feature, eval_text, eval_text_str, eval_input_text, eval_input_text_str
        eval_show_result = None

        if FLAGS.show_eval:
            eval_fixed = bool(FLAGS.fixed_valid_input)
            self.eval_fixed = eval_fixed
            if eval_fixed:
                assert FLAGS.fixed_eval_batch_size >= FLAGS.num_fixed_evaluate_examples, '%d %d' % (
                    FLAGS.fixed_eval_batch_size,
                    FLAGS.num_fixed_evaluate_examples)
                logging.info('fixed_eval_batch_size:{}'.format(
                    FLAGS.fixed_eval_batch_size))
                logging.info('num_fixed_evaluate_examples:{}'.format(
                    FLAGS.num_fixed_evaluate_examples))
                logging.info('num_evaluate_examples:{}'.format(
                    FLAGS.num_evaluate_examples))
                #------------------- fixed valid
                fixed_validset = list_files(FLAGS.fixed_valid_input)
                logging.info('fixed_validset:{} {}'.format(
                    len(fixed_validset), fixed_validset[:2]))
                fixed_image_name, fixed_image_feature, fixed_text, fixed_text_str, fixed_input_text, fixed_input_text_str = inputs(
                    fixed_validset,
                    decode_fn=decode_fn,
                    batch_size=FLAGS.fixed_eval_batch_size,
                    fix_sequence=True,
                    num_prefetch_batches=FLAGS.num_prefetch_batches,
                    min_after_dequeue=FLAGS.min_after_dequeue,
                    name=self.fixed_input_valid_name)

                # The input batch may be larger than what we want to show;
                # keep only the first num_fixed_evaluate_examples rows.
                fixed_image_name = melt.first_nrows(
                    fixed_image_name, FLAGS.num_fixed_evaluate_examples)
                fixed_image_feature = melt.first_nrows(
                    fixed_image_feature, FLAGS.num_fixed_evaluate_examples)
                fixed_text = melt.first_nrows(
                    fixed_text, FLAGS.num_fixed_evaluate_examples)
                fixed_text = melt.make_batch_compat(fixed_text)
                fixed_text_str = melt.first_nrows(
                    fixed_text_str, FLAGS.num_fixed_evaluate_examples)
                fixed_input_text = melt.first_nrows(
                    fixed_input_text, FLAGS.num_fixed_evaluate_examples)
                fixed_input_text = melt.make_batch_compat(fixed_input_text)
                fixed_input_text_str = melt.first_nrows(
                    fixed_input_text_str, FLAGS.num_fixed_evaluate_examples)

                # NOTE: the reader always yields FLAGS.fixed_eval_batch_size
                # rows; with fewer fixed examples the data wraps around.
                eval_image_name = tf.concat(
                    [fixed_image_name, eval_image_name], axis=0)
                eval_image_feature = tf.concat(
                    [fixed_image_feature, eval_image_feature], axis=0)

                # melt.align is only needed with dynamic batch/padding.
                # Worked well before but seems buggy for imtxt2txt, strange —
                # TODO FIXME: in imtxt2txt, sh ./train/fixme.sh reproduces it.
                # Maybe writing SequenceExample instead would be fine?
                if FLAGS.dynamic_batch_length:
                    fixed_text, eval_text = melt.align_col_padding2d(
                        fixed_text, eval_text)
                    fixed_input_text, eval_input_text = melt.align_col_padding2d(
                        fixed_input_text, eval_input_text)

                eval_text = tf.concat([fixed_text, eval_text], axis=0)
                eval_text_str = tf.concat([fixed_text_str, eval_text_str],
                                          axis=0)
                eval_input_text = tf.concat(
                    [fixed_input_text, eval_input_text], axis=0)
                eval_input_text_str = tf.concat(
                    [fixed_input_text_str, eval_input_text_str], axis=0)
                eval_batch_size = FLAGS.num_fixed_evaluate_examples + FLAGS.eval_batch_size

                # Expose these tensors for debugging elsewhere in the graph.
                tf.add_to_collection('fixed_input_text', fixed_input_text)
                tf.add_to_collection('fixed_text', fixed_text)
                tf.add_to_collection('eval_text', eval_text)

            # Should always be num_fixed_evaluate_examples +
            # num_evaluate_examples, capped by the actual batch size.
            num_evaluate_examples = min(
                eval_batch_size, FLAGS.num_fixed_evaluate_examples +
                FLAGS.num_evaluate_examples)
            print('----num_evaluate_examples', num_evaluate_examples)
            print('----eval_batch_size', eval_batch_size)

            #------------combine valid and fixed valid
            evaluate_image_name = melt.first_nrows(eval_image_name,
                                                   num_evaluate_examples)
            evaluate_image_feature = melt.first_nrows(eval_image_feature,
                                                      num_evaluate_examples)
            evaluate_text_str = melt.first_nrows(eval_text_str,
                                                 num_evaluate_examples)
            evaluate_text = melt.first_nrows(eval_text, num_evaluate_examples)
            evaluate_input_text_str = melt.first_nrows(eval_input_text_str,
                                                       num_evaluate_examples)
            evaluate_input_text = melt.first_nrows(eval_input_text,
                                                   num_evaluate_examples)

            self.num_evaluate_examples = num_evaluate_examples

            eval_result = eval_image_name, eval_image_feature, eval_text, eval_text_str, eval_input_text, eval_input_text_str
            eval_show_result = evaluate_image_name, evaluate_image_feature, evaluate_text, evaluate_text_str, evaluate_input_text, evaluate_input_text_str

        return eval_result, eval_show_result, eval_batch_size