예제 #1
0
 def num_examples_per_epoch(subset='train', dir=None):
     """Return the record count stored in ``num_records.txt`` for a data subset.

     Args:
         subset: one of ``'train'``, ``'valid'`` or ``'test'``.
         dir: directory that holds ``num_records.txt``. When falsy, the
             directory of the subset's input flag is used instead
             (``FLAGS.train_input``, ``FLAGS.valid_input`` or
             ``FLAGS.test_input``).

     Returns:
         The result of ``gezi.read_int_from`` on that file, called with
         ``None`` as its default value. Presumably this is ``None`` when the
         file is missing or unreadable (TODO confirm against gezi).

     Raises:
         ValueError: if ``subset`` is not one of the three names above.
     """
     if subset not in ('train', 'valid', 'test'):
         raise ValueError('Invalid data subset "%s"' % subset)

     records_dir = dir
     if not records_dir:
         # Only consult FLAGS when no explicit directory was supplied.
         if subset == 'train':
             # train_input may hold several comma-separated paths; the count
             # file is looked up beside the first one.
             records_dir = gezi.dirname(FLAGS.train_input.split(',')[0])
         elif subset == 'valid':
             records_dir = gezi.dirname(FLAGS.valid_input)
         else:
             records_dir = gezi.dirname(FLAGS.test_input)

     return gezi.read_int_from(records_dir + '/num_records.txt', None)
예제 #2
0
    def gen_train_input(self, inputs, decode_fn):
        """Build the training input pipeline and record per-epoch statistics.

        Lists the files matched by ``FLAGS.train_input`` and sanity-checks how
        many were found. It then reads the total record count from
        ``FLAGS.num_records_file``, derives the number of steps per epoch, and
        finally calls ``inputs`` to create the batched reader.

        Side effects: sets ``self.num_records`` and
        ``self.num_steps_per_epoch``. When ``FLAGS.feed_dict`` is set, it also
        sets ``self.text_place``, ``self.text_str``,
        ``self.image_feature_place`` and ``self.image_feature``.

        Args:
            inputs: callable that builds the reader from the file list. It is
                unpacked here into
                ``(image_name, image_feature, text, text_str)``.
            decode_fn: decoder forwarded unchanged to ``inputs``.

        Returns:
            ``((image_name, image_feature, text, text_str), trainset)``, where
            ``trainset`` is the list of training files. With
            ``FLAGS.feed_dict`` set, the returned ``text`` and
            ``image_feature`` are the placeholders rather than the pipeline
            tensors.
        """
        logging.info('train_input: %s' % FLAGS.train_input)
        train_files = list_files(FLAGS.train_input)
        file_count = len(train_files)
        logging.info('trainset:{} {}'.format(file_count, train_files[:2]))

        # NOTE(review): min_records / num_records are checked against the
        # number of *files*, not records; confirm the flag names are intended.
        assert file_count >= FLAGS.min_records, '%d %d' % (
            file_count, FLAGS.min_records)
        if FLAGS.num_records > 0:
            assert file_count == FLAGS.num_records, file_count

        record_count = gezi.read_int_from(FLAGS.num_records_file)
        logging.info('num_records:{}'.format(record_count))
        logging.info('batch_size:{}'.format(FLAGS.batch_size))
        logging.info('FLAGS.num_gpus:{}'.format(FLAGS.num_gpus))
        # A non-positive GPU count is treated as a single device.
        devices = max(FLAGS.num_gpus, 1)
        steps_per_epoch = record_count // (FLAGS.batch_size * devices)
        logging.info('num_steps_per_epoch:{}'.format(steps_per_epoch))
        self.num_records, self.num_steps_per_epoch = (record_count,
                                                      steps_per_epoch)

        reader_options = dict(
            decode_fn=decode_fn,
            batch_size=FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs,
            num_threads=FLAGS.num_threads,
            batch_join=FLAGS.batch_join,
            shuffle_files=FLAGS.shuffle_files,
            fix_sequence=FLAGS.fix_sequence,
            num_prefetch_batches=FLAGS.num_prefetch_batches,
            min_after_dequeue=FLAGS.min_after_dequeue,
            name=self.input_train_name,
        )
        image_name, image_feature, text, text_str = inputs(
            train_files, **reader_options)

        if FLAGS.feed_dict:
            # Keep the pipeline tensors on self; downstream code receives
            # placeholders in their place.
            text = self.text_place = text_placeholder('text_place')
            self.text_str = text_str

            self.image_feature_place = image_feature_placeholder(
                'image_feature_place')
            # Store the pipeline tensor before rebinding the local name.
            self.image_feature, image_feature = (image_feature,
                                                 self.image_feature_place)

        if FLAGS.monitor_level > 1:
            lengths = melt.length(text)
            for tag, reduce_fn in (("text/batch_min", tf.reduce_min),
                                   ("text/batch_max", tf.reduce_max),
                                   ("text/batch_mean", tf.reduce_mean)):
                melt.scalar_summary(tag, reduce_fn(lengths))
        return (image_name, image_feature, text, text_str), train_files
예제 #3
0
    def gen_train_input(self, inputs, decode_fn):
        """Build the training input pipeline and record per-epoch statistics.

        Lists the files matched by ``FLAGS.train_input`` and sanity-checks how
        many were found. It then reads the total record count from
        ``FLAGS.num_records_file``, derives the number of steps per epoch, and
        finally calls ``inputs`` to create the batched reader.

        Side effects: sets ``self.num_records`` and
        ``self.num_steps_per_epoch``.

        Args:
            inputs: callable that builds the reader from the file list. It is
                unpacked here into
                ``(image_name, text, text_str, input_text, input_text_str)``.
            decode_fn: decoder forwarded unchanged to ``inputs``.

        Returns:
            ``((image_name, text, text_str, input_text, input_text_str),
            trainset)``, where ``trainset`` is the list of training files.
        """
        logging.info('train_input: %s' % FLAGS.train_input)
        train_files = list_files(FLAGS.train_input)
        file_count = len(train_files)
        logging.info('trainset:{} {}'.format(file_count, train_files[:2]))

        # NOTE(review): min_records / num_records are checked against the
        # number of *files*, not records; confirm the flag names are intended.
        assert file_count >= FLAGS.min_records, file_count
        if FLAGS.num_records > 0:
            assert file_count == FLAGS.num_records, file_count

        record_count = gezi.read_int_from(FLAGS.num_records_file)
        logging.info('num_records:{}'.format(record_count))
        logging.info('batch_size:{}'.format(FLAGS.batch_size))
        logging.info('FLAGS.num_gpus:{}'.format(FLAGS.num_gpus))
        # A non-positive GPU count is treated as a single device.
        devices = max(FLAGS.num_gpus, 1)
        steps_per_epoch = record_count // (FLAGS.batch_size * devices)
        logging.info('num_steps_per_epoch:{}'.format(steps_per_epoch))
        self.num_records, self.num_steps_per_epoch = (record_count,
                                                      steps_per_epoch)

        reader_options = dict(
            decode_fn=decode_fn,
            batch_size=FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs,
            num_threads=FLAGS.num_threads,
            batch_join=FLAGS.batch_join,
            shuffle_files=FLAGS.shuffle_files,
            fix_sequence=FLAGS.fix_sequence,
            num_prefetch_batches=FLAGS.num_prefetch_batches,
            min_after_dequeue=FLAGS.min_after_dequeue,
            name=self.input_train_name,
        )
        image_name, text, text_str, input_text, input_text_str = inputs(
            train_files, **reader_options)

        if FLAGS.monitor_level > 1:
            # Length summaries for the target text first, then the input text.
            for sequence in (text, input_text):
                monitor_text_length(sequence)

        batch = (image_name, text, text_str, input_text, input_text_str)
        return batch, train_files