def train():
  model_config = ModelConfig()
  training_config = TrainConfig()

  # Get the model and reader constructors.
  model_fn = get_model_creation_fn(FLAGS.model_type)
  reader_fn = create_reader('VAQ-Var', phase='train')

  # Set up the reward environment.
  env = MixReward()
  env.diversity_reward.mode = 'winner_take_all'
  env.set_cider_state(False)
  env.set_language_thresh(0.2)

  # Create the training directory.
  train_dir = FLAGS.train_dir % (FLAGS.version, FLAGS.model_type)
  if not tf.gfile.IsDirectory(train_dir):
    tf.logging.info("Creating training directory: %s", train_dir)
    tf.gfile.MakeDirs(train_dir)

  g = tf.Graph()
  with g.as_default():
    # Build the model.
    model = model_fn(model_config, 'train')
    model.build()

    # Set up the learning rate: fine-tune at a tenth of the configured
    # initial rate.
    learning_rate = tf.constant(training_config.initial_learning_rate * 0.1)

    def _learning_rate_decay_fn(learn_rate, global_step):
      return tf.train.exponential_decay(
          learn_rate,
          global_step,
          decay_steps=training_config.decay_step,
          decay_rate=training_config.decay_factor,
          staircase=False)

    learning_rate_decay_fn = _learning_rate_decay_fn

    train_op = tf.contrib.layers.optimize_loss(
        loss=model.loss,
        global_step=model.global_step,
        learning_rate=learning_rate,
        optimizer=training_config.optimizer,
        clip_gradients=training_config.clip_gradients,
        learning_rate_decay_fn=learning_rate_decay_fn)

    # Set up the Saver for saving and restoring model checkpoints.
    saver = tf.train.Saver(
        max_to_keep=training_config.max_checkpoints_to_keep)

    # Set up summaries.
    summary_op = tf.summary.merge_all()

    # Set up the language model for the reward environment.
    lm = LanguageModel()
    lm.build()
    env.set_language_model(lm)

    # Create the data reader.
    reader = reader_fn(
        batch_size=16,
        subset='kprestval',  # 'kptrain'
        version=FLAGS.version)

  # Run training.
  training_util.train(
      train_op,
      train_dir,
      log_every_n_steps=FLAGS.log_every_n_steps,
      graph=g,
      global_step=model.global_step,
      number_of_steps=FLAGS.number_of_steps,
      init_fn=model.init_fn,
      saver=saver,
      reader=reader,
      model=model,
      summary_op=summary_op,
      env=env)
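# A minimal, self-contained sketch (not part of the original script) of the
# schedule that tf.train.exponential_decay computes above. With
# staircase=False the decay is continuous:
#     decayed_lr = learning_rate * decay_rate ** (global_step / decay_steps)
# The default values below are illustrative assumptions, not the repo's
# actual TrainConfig settings.
def _decayed_learning_rate_sketch(initial_lr=0.001, decay_rate=0.5,
                                  decay_steps=100000, global_step=250000):
  """Pure-Python equivalent of the TF decay op, for intuition only."""
  return initial_lr * decay_rate ** (float(global_step) / decay_steps)

# Example: 0.001 * 0.5 ** 2.5 ~= 0.000177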
def train():
  model_config = ModelConfig()
  training_config = TrainConfig()

  # Get the model and reader constructors.
  model_fn = get_model_creation_fn(FLAGS.model_type)
  reader_fn = create_reader('VAQ-VIS', phase='test')

  # Set up the reward environment.
  env = MixReward(attention_vqa=True)
  env.diversity_reward.mode = 'winner_take_all'
  # env.set_language_thresh(0.1)
  env.set_language_thresh(0.2)
  env.set_cider_state(use_cider=False)
  env.set_replay_buffer(
      insert_thresh=0.1,  # if 0.5, already fooled others
      sv_dir='vqa_replay_buffer/tmp')

  # Create the training directory.
  train_dir = FLAGS.train_dir % (FLAGS.version, FLAGS.model_type)
  if not tf.gfile.IsDirectory(train_dir):
    tf.logging.info("Creating training directory: %s", train_dir)
    tf.gfile.MakeDirs(train_dir)
  ckpt_suffix = train_dir.split('/')[-1]

  g = tf.Graph()
  with g.as_default():
    # Build the model.
    model = model_fn(model_config, 'train')
    # model.set_init_ckpt('model/v1_var_att_noimage_cache_restval_VAQ-VarRL/model.ckpt-230000')
    # model.set_init_ckpt('model/v1_var_att_lowthresh_cache_restval_VAQ-VarRL/model.ckpt-1072000')
    model.build()

    # Set up the learning rate. The 0.0 multiplier zeroes it out, so this
    # run performs no parameter updates (rollout/evaluation only).
    learning_rate = tf.constant(training_config.initial_learning_rate * 0.0)

    def _learning_rate_decay_fn(learn_rate, global_step):
      return tf.train.exponential_decay(
          learn_rate,
          global_step,
          decay_steps=training_config.decay_step,
          decay_rate=training_config.decay_factor,
          staircase=False)

    learning_rate_decay_fn = _learning_rate_decay_fn

    train_op = tf.contrib.layers.optimize_loss(
        loss=model.loss,
        global_step=model.global_step,
        learning_rate=learning_rate,
        optimizer=training_config.optimizer,
        clip_gradients=training_config.clip_gradients,
        learning_rate_decay_fn=learning_rate_decay_fn)

    # Set up the Saver for saving and restoring model checkpoints.
    saver = tf.train.Saver(
        max_to_keep=training_config.max_checkpoints_to_keep)

    # Set up summaries.
    summary_op = tf.summary.merge_all()

    # Set up the language model for the reward environment.
    lm = LanguageModel()
    lm.build()
    lm.set_cache_dir(ckpt_suffix)
    env.set_language_model(lm)

    # Create the data reader.
    reader = reader_fn(
        batch_size=1,
        subset='kpval',  # 'kptrain'
        version=FLAGS.version)

  # Run training.
  training_util.train(
      train_op,
      train_dir,
      log_every_n_steps=FLAGS.log_every_n_steps,
      graph=g,
      global_step=model.global_step,
      number_of_steps=FLAGS.number_of_steps,
      init_fn=model.init_fn,
      saver=saver,
      reader=reader,
      model=model,
      summary_op=summary_op,
      env=env)
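# A hedged sketch of the flag definitions and entry point that both train()
# functions above assume; the original module presumably defines these
# elsewhere. All default values here are illustrative assumptions, not the
# repo's actual settings. Note that FLAGS.train_dir must contain two %s
# placeholders, since it is filled with (FLAGS.version, FLAGS.model_type).
import tensorflow as tf

tf.flags.DEFINE_string('model_type', 'VAQ-Var',
                       'Key passed to get_model_creation_fn.')
tf.flags.DEFINE_string('version', 'v1', 'Dataset/experiment version tag.')
tf.flags.DEFINE_string('train_dir', 'model/%s_%s',
                       'Training directory pattern, formatted with '
                       '(version, model_type).')
tf.flags.DEFINE_integer('log_every_n_steps', 10, 'Logging frequency.')
tf.flags.DEFINE_integer('number_of_steps', 1000000, 'Total training steps.')
FLAGS = tf.flags.FLAGS


def main(_):
  train()


if __name__ == '__main__':
  tf.app.run()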