crop_builder = get_img_builder(config['model']['crop'], args.crop_dir, is_crop=True) use_resnet = crop_builder.is_raw_image() # Load data logger.info('Loading data..') trainset = OracleDataset.load(args.data_dir, "train", image_builder, crop_builder) validset = OracleDataset.load(args.data_dir, "valid", image_builder, crop_builder) testset = OracleDataset.load(args.data_dir, "test", image_builder, crop_builder) # Load dictionary logger.info('Loading dictionary..') tokenizer = GWTokenizer(os.path.join(args.data_dir, args.dict_file)) # Build Network logger.info('Building network..') network = OracleNetwork(config, num_words=tokenizer.no_words) # Build Optimizer logger.info('Building optimizer..') optimizer, outputs = create_optimizer(network, config, finetune=finetune) ############################### # START TRAINING ############################# # create a saver to store/load checkpoint saver = tf.train.Saver() resnet_saver = None # Retrieve only resnet variabes if use_resnet:
# Load dictionary
logger.info('Loading dictionary..')
tokenizer = GWTokenizer(os.path.join(args.data_dir, args.dict_file))

###############################
#  LOAD NETWORKS
#############################

logger.info('Building networks..')

# Each network is built first and its variables are then picked out of the
# global collection by name substring, so that a dedicated Saver can
# restore that network independently of the others. Statement order matters:
# tf.global_variables() only sees variables created so far.
# NOTE(review): the filters assume every variable of a network contains its
# keyword ("qgen", "oracle", "guesser") in its name and that no other
# network's variables do -- confirm against the network definitions.
# Each per-network Saver raises if its filtered list is empty, because
# tf.train.Saver's allow_empty defaults to False.

# Question generator built for policy-gradient training; variables whose
# name contains 'rl_baseline' are excluded from its saver (presumably because
# they are not present in a supervised qgen checkpoint -- confirm).
qgen_network = QGenNetworkLSTM(qgen_config["model"], num_words=tokenizer.no_words, policy_gradient=True)
qgen_var = [v for v in tf.global_variables() if "qgen" in v.name and 'rl_baseline' not in v.name]
qgen_saver = tf.train.Saver(var_list=qgen_var)

# Oracle: receives the whole oracle config, not its "model" entry.
oracle_network = OracleNetwork(oracle_config, num_words=tokenizer.no_words)
oracle_var = [v for v in tf.global_variables() if "oracle" in v.name]
oracle_saver = tf.train.Saver(var_list=oracle_var)

# Guesser
guesser_network = GuesserNetwork(guesser_config["model"], num_words=tokenizer.no_words)
guesser_var = [v for v in tf.global_variables() if "guesser" in v.name]
guesser_saver = tf.train.Saver(var_list=guesser_var)

# Saver over every saveable variable existing at this point (all three
# networks); allow_empty=False (the default) makes it fail if there are none.
loop_saver = tf.train.Saver(allow_empty=False)

###############################
#  REINFORCE OPTIMIZER
#############################

logger.info('Building optimizer..')