def start_session(self):
    """Launch the TensorFlow session and start the GuessWhat?! loop.

    Builds the guesser and QGen networks, restores their pretrained
    weights from the checkpoint paths ``GUESS_NTW_PATH`` / ``QGEN_NTW_PATH``
    (module-level constants defined outside this block), wraps each model
    in its ROS/loop wrapper, and hands everything to ``self.loop``.

    Reads from ``self``: ``tf_config``, ``guesser_config``, ``qgen_config``,
    ``eval_config``, ``tokenizer``.
    """
    with tf.Session(config=self.tf_config) as sess:
        # --- Guesser: build, restore only the variables in the 'guesser'
        # scope (name-substring filter), then wrap for the ROS loop.
        guesser_network = GuesserNetwork(self.guesser_config['model'],
                                         num_words=self.tokenizer.no_words)
        guesser_var = [v for v in tf.global_variables()
                       if 'guesser' in v.name]
        guesser_saver = tf.train.Saver(var_list=guesser_var)
        guesser_saver.restore(sess, GUESS_NTW_PATH)
        guesser_wrapper = GuesserROSWrapper(guesser_network)

        # --- Question generator: supervised weights only
        # (policy_gradient=False), restored from the 'qgen'-scoped variables.
        qgen_network = QGenNetworkLSTM(self.qgen_config['model'],
                                       num_words=self.tokenizer.no_words,
                                       policy_gradient=False)
        qgen_var = [v for v in tf.global_variables() if 'qgen' in v.name]
        qgen_saver = tf.train.Saver(var_list=qgen_var)
        qgen_saver.restore(sess, QGEN_NTW_PATH)

        # Sampling graph must be built after restore; max_length caps the
        # dialogue depth used both for sampling and for the wrapper below.
        qgen_network.build_sampling_graph(
            self.qgen_config['model'],
            tokenizer=self.tokenizer,
            max_length=self.eval_config['loop']['max_depth'])
        qgen_wrapper = QGenWrapper(
            qgen_network, self.tokenizer,
            max_length=self.eval_config['loop']['max_depth'],
            k_best=self.eval_config['loop']['beam_k_best'])

        # --- Oracle answers come from outside this process; the wrapper
        # only needs the tokenizer.
        oracle_wrapper = OracleROSWrapper(self.tokenizer)

        # Run the full game loop with the three players.
        self.loop(sess, guesser_wrapper, qgen_wrapper, oracle_wrapper)
# Load the three dataset splits. image_loader/crop_loader and the various
# *_config dicts are defined earlier in this script (outside this view).
logger.info('Loading data..')
trainset = Dataset(args.data_dir, "train", image_loader, crop_loader)
validset = Dataset(args.data_dir, "valid", image_loader, crop_loader)
testset = Dataset(args.data_dir, "test", image_loader, crop_loader)

# Load dictionary
logger.info('Loading dictionary..')
tokenizer = GWTokenizer(os.path.join(args.data_dir, args.dict_file))

###############################
#  LOAD NETWORKS
#############################

logger.info('Building networks..')

# QGen is built with policy_gradient=True (RL/REINFORCE variant).
# NOTE(review): the saver filter drops 'rl_baseline' variables — presumably
# because the supervised checkpoint being restored does not contain them;
# confirm against the restore call later in this script.
qgen_network = QGenNetworkLSTM(qgen_config["model"],
                               num_words=tokenizer.no_words,
                               policy_gradient=True)
qgen_var = [v for v in tf.global_variables()
            if "qgen" in v.name and 'rl_baseline' not in v.name]
qgen_saver = tf.train.Saver(var_list=qgen_var)

# Oracle: per-scope saver over variables whose name contains "oracle".
oracle_network = OracleNetwork(oracle_config, num_words=tokenizer.no_words)
oracle_var = [v for v in tf.global_variables() if "oracle" in v.name]
oracle_saver = tf.train.Saver(var_list=oracle_var)

# Guesser: per-scope saver over variables whose name contains "guesser".
guesser_network = GuesserNetwork(guesser_config["model"],
                                 num_words=tokenizer.no_words)
guesser_var = [v for v in tf.global_variables() if "guesser" in v.name]
guesser_saver = tf.train.Saver(var_list=guesser_var)

# Saver over ALL variables (including rl_baseline) used to checkpoint the
# whole RL loop; allow_empty=False makes it fail loudly if nothing to save.
loop_saver = tf.train.Saver(allow_empty=False)

###############################
#  REINFORCE OPTIMIZER
crop_loader = None # get_img_loader(config, 'crop', args.image_dir) # Load data logger.info('Loading data..') trainset = Dataset(args.data_dir, "train", image_loader, crop_loader) validset = Dataset(args.data_dir, "valid", image_loader, crop_loader) testset = Dataset(args.data_dir, "test", image_loader, crop_loader) # Load dictionary logger.info('Loading dictionary..') tokenizer = GWTokenizer(os.path.join(args.data_dir, args.dict_file)) # Build Network logger.info('Building network..') network = QGenNetworkLSTM(config["model"], num_words=tokenizer.no_words, policy_gradient=False) # Build Optimizer logger.info('Building optimizer..') optimizer, outputs = create_optimizer(network, config) ############################### # START TRAINING ############################# # Load config batch_size = config['optimizer']['batch_size'] no_epoch = config["optimizer"]["no_epoch"] # create a saver to store/load checkpoint
os.path.join(args.networks_dir, 'guesser', args.guesser_identifier, 'params.ckpt')) guesser_wrapper = GuesserWrapper(guesser_network) else: guesser_wrapper = GuesserUserWrapper(tokenizer, img_raw_dir=args.img_raw_dir) logger.info("No Guesser was registered >>> use user input") if args.qgen_identifier is not None: qgen_config = get_config_from_xp( os.path.join(args.networks_dir, "qgen"), args.qgen_identifier) qgen_network = QGenNetworkLSTM(qgen_config["model"], num_words=tokenizer.no_words, policy_gradient=False) qgen_var = [v for v in tf.global_variables() if "qgen" in v.name] # and 'rl_baseline' not in v.name qgen_saver = tf.train.Saver(var_list=qgen_var) qgen_saver.restore( sess, os.path.join(args.networks_dir, 'qgen', args.qgen_identifier, 'params.ckpt')) qgen_network.build_sampling_graph( qgen_config["model"], tokenizer=tokenizer, max_length=eval_config['loop']['max_depth']) qgen_wrapper = QGenWrapper(