def gen_train_input(self, inputs, decode_fn):
    """Build the training input tensors and record per-epoch statistics.

    Lists the files matching ``FLAGS.train_input`` and checks how many there
    are. It then reads the total record count from ``FLAGS.num_records_file``
    and stores ``self.num_records`` and ``self.num_steps_per_epoch``. Finally
    it calls ``inputs(...)`` to build the batched tensors.

    Args:
      inputs: callable that builds the batched input tensors from the file
        list. It is called with the file list plus keyword options taken
        from FLAGS, and must return a 4-tuple
        ``(image_name, image_feature, text, text_str)``.
      decode_fn: record decoding function, forwarded unchanged to ``inputs``.

    Returns:
      ``((image_name, image_feature, text, text_str), trainset)``, where
      ``trainset`` is the list of training files. When ``FLAGS.feed_dict``
      is set, ``text`` and ``image_feature`` are replaced by the placeholders
      ``self.text_place`` and ``self.image_feature_place``.

    Raises:
      AssertionError: if fewer than ``FLAGS.min_records`` files are found, or
        if ``FLAGS.num_records > 0`` and the file count differs from it.
    """
    #--------------------- train
    logging.info('train_input: %s' % FLAGS.train_input)
    trainset = list_files(FLAGS.train_input)
    logging.info('trainset:{} {}'.format(len(trainset), trainset[:2]))

    # Fail fast if the input pattern matched fewer files than required.
    assert len(trainset) >= FLAGS.min_records, '%d %d' % (
        len(trainset), FLAGS.min_records)
    if FLAGS.num_records > 0:
        # With a positive FLAGS.num_records, the matched file count must equal it exactly.
        assert len(trainset) == FLAGS.num_records, len(trainset)

    num_records = gezi.read_int_from(FLAGS.num_records_file)
    logging.info('num_records:{}'.format(num_records))
    logging.info('batch_size:{}'.format(FLAGS.batch_size))
    logging.info('FLAGS.num_gpus:{}'.format(FLAGS.num_gpus))

    # Each step consumes one batch per GPU; num_gpus <= 0 counts as one device.
    num_gpus = max(FLAGS.num_gpus, 1)
    num_steps_per_epoch = num_records // (FLAGS.batch_size * num_gpus)
    logging.info('num_steps_per_epoch:{}'.format(num_steps_per_epoch))
    self.num_records = num_records
    self.num_steps_per_epoch = num_steps_per_epoch

    image_name, image_feature, text, text_str = inputs(
        trainset,
        decode_fn=decode_fn,
        batch_size=FLAGS.batch_size,
        num_epochs=FLAGS.num_epochs,
        num_threads=FLAGS.num_threads,
        batch_join=FLAGS.batch_join,
        shuffle_files=FLAGS.shuffle_files,
        fix_sequence=FLAGS.fix_sequence,
        num_prefetch_batches=FLAGS.num_prefetch_batches,
        min_after_dequeue=FLAGS.min_after_dequeue,
        name=self.input_train_name)

    if FLAGS.feed_dict:
        # Downstream graph consumes placeholders; the tensors produced by
        # `inputs` are kept on self.
        self.text_place = text_placeholder('text_place')
        # NOTE(review): this stores text_str, not the `text` tensor, while the
        # image branch stores image_feature itself -- confirm this asymmetry
        # is intended.
        self.text_str = text_str
        text = self.text_place
        self.image_feature_place = image_feature_placeholder(
            'image_feature_place')
        self.image_feature = image_feature
        image_feature = self.image_feature_place

    if FLAGS.monitor_level > 1:
        # Reuse the shared helper instead of duplicating its summary calls.
        monitor_text_length(text)

    return (image_name, image_feature, text, text_str), trainset
def monitor_text_length(text):
    """Emit min / max / mean scalar summaries of the lengths in ``text``.

    ``melt.length`` is applied to ``text`` to get the lengths (presumably one
    per example in the batch -- confirm against melt). Then one scalar
    summary is written for each reduction, under the tags ``text/batch_min``,
    ``text/batch_max`` and ``text/batch_mean``, in that order. Returns None.
    """
    seq_lengths = melt.length(text)
    reductions = (
        ("text/batch_min", tf.reduce_min),
        ("text/batch_max", tf.reduce_max),
        ("text/batch_mean", tf.reduce_mean),
    )
    for tag, reduce_op in reductions:
        melt.scalar_summary(tag, reduce_op(seq_lengths))