示例#1
0
    def start_session(self):
        """Run the GuessWhat loop inside a newly opened TensorFlow session.

        Within one ``tf.Session`` created from ``self.tf_config`` this method:

        1. builds the guesser network and restores, from ``GUESS_NTW_PATH``,
           the global variables whose name contains ``'guesser'``;
        2. builds the question generator (``policy_gradient=False``),
           restores the global variables whose name contains ``'qgen'`` from
           ``QGEN_NTW_PATH`` and then builds its sampling graph;
        3. wraps both networks, creates the oracle wrapper and passes the
           session and the three wrappers to ``self.loop``.
        """
        with tf.Session(config=self.tf_config) as session:
            # --- Guesser -------------------------------------------------
            guesser_net = GuesserNetwork(
                self.guesser_config['model'],
                num_words=self.tokenizer.no_words)
            # Select the variables to restore by name.
            guesser_weights = []
            for variable in tf.global_variables():
                if 'guesser' in variable.name:
                    guesser_weights.append(variable)
            tf.train.Saver(var_list=guesser_weights).restore(
                session, GUESS_NTW_PATH)
            guesser = GuesserROSWrapper(guesser_net)

            # --- Question generator --------------------------------------
            qgen_net = QGenNetworkLSTM(
                self.qgen_config['model'],
                num_words=self.tokenizer.no_words,
                policy_gradient=False)
            qgen_weights = []
            for variable in tf.global_variables():
                if 'qgen' in variable.name:
                    qgen_weights.append(variable)
            tf.train.Saver(var_list=qgen_weights).restore(
                session, QGEN_NTW_PATH)

            # The sampling graph is built only after the weights have been
            # restored; keep this ordering.
            qgen_net.build_sampling_graph(
                self.qgen_config['model'],
                tokenizer=self.tokenizer,
                max_length=self.eval_config['loop']['max_depth'])

            loop_settings = self.eval_config['loop']
            qgen = QGenWrapper(
                qgen_net,
                self.tokenizer,
                max_length=loop_settings['max_depth'],
                k_best=loop_settings['beam_k_best'])

            # --- Oracle --------------------------------------------------
            oracle = OracleROSWrapper(self.tokenizer)

            self.loop(session, guesser, qgen, oracle)
    # NOTE(review): this run of statements has no visible enclosing `def`;
    # `args`, `image_loader`, `crop_loader` and the three *_config objects are
    # presumably defined earlier in the original function -- confirm.

    # Build the three dataset splits from the same data directory and loaders.
    logger.info('Loading data..')
    trainset = Dataset(args.data_dir, "train", image_loader, crop_loader)
    validset = Dataset(args.data_dir, "valid", image_loader, crop_loader)
    testset = Dataset(args.data_dir, "test", image_loader, crop_loader)

    # Load dictionary
    logger.info('Loading dictionary..')
    tokenizer = GWTokenizer(os.path.join(args.data_dir, args.dict_file))

    ###############################
    #  LOAD NETWORKS
    #############################

    logger.info('Building networks..')

    # For each network: build it, pick its variables out of
    # tf.global_variables() by a substring of the variable name, and create a
    # saver limited to those variables.

    # Question generator, built with policy_gradient=True. Variables whose
    # name contains 'rl_baseline' are excluded from its saver.
    qgen_network = QGenNetworkLSTM(qgen_config["model"], num_words=tokenizer.no_words, policy_gradient=True)
    qgen_var = [v for v in tf.global_variables() if "qgen" in v.name and 'rl_baseline' not in v.name]
    qgen_saver = tf.train.Saver(var_list=qgen_var)

    # Oracle.
    # NOTE(review): unlike the other two networks, OracleNetwork receives the
    # whole `oracle_config` rather than its "model" entry -- confirm intended.
    oracle_network = OracleNetwork(oracle_config, num_words=tokenizer.no_words)
    oracle_var = [v for v in tf.global_variables() if "oracle" in v.name]
    oracle_saver = tf.train.Saver(var_list=oracle_var)

    # Guesser.
    guesser_network = GuesserNetwork(guesser_config["model"], num_words=tokenizer.no_words)
    guesser_var = [v for v in tf.global_variables() if "guesser" in v.name]
    guesser_saver = tf.train.Saver(var_list=guesser_var)

    # Created without a var_list, so this saver is not restricted to a single
    # network's variables; allow_empty=False is passed explicitly.
    loop_saver = tf.train.Saver(allow_empty=False)
    ###############################
    #  REINFORCE OPTIMIZER
    # NOTE(review): this run of statements has no visible enclosing `def`;
    # `args`, `config` and `image_loader` are presumably defined earlier in
    # the original function -- confirm.

    # No crop loader is used here; the call that would create one is left
    # commented out on the same line.
    crop_loader = None  # get_img_loader(config, 'crop', args.image_dir)

    # Load data
    # Build the three dataset splits; `crop_loader` is passed as None.
    logger.info('Loading data..')
    trainset = Dataset(args.data_dir, "train", image_loader, crop_loader)
    validset = Dataset(args.data_dir, "valid", image_loader, crop_loader)
    testset = Dataset(args.data_dir, "test", image_loader, crop_loader)

    # Load dictionary
    logger.info('Loading dictionary..')
    tokenizer = GWTokenizer(os.path.join(args.data_dir, args.dict_file))

    # Build Network
    # Question generator built with policy_gradient=False.
    logger.info('Building network..')
    network = QGenNetworkLSTM(config["model"],
                              num_words=tokenizer.no_words,
                              policy_gradient=False)

    # Build Optimizer
    # create_optimizer returns a pair, unpacked as (optimizer, outputs).
    logger.info('Building optimizer..')
    optimizer, outputs = create_optimizer(network, config)

    ###############################
    #  START TRAINING
    #############################

    # Load config
    # Both values are read from the "optimizer" section of `config`.
    batch_size = config['optimizer']['batch_size']
    no_epoch = config["optimizer"]["no_epoch"]

    # create a saver to store/load checkpoint
    # NOTE(review): the saver creation itself is not part of the visible code.
示例#4
0
                os.path.join(args.networks_dir, 'guesser',
                             args.guesser_identifier, 'params.ckpt'))

            guesser_wrapper = GuesserWrapper(guesser_network)

        else:
            guesser_wrapper = GuesserUserWrapper(tokenizer,
                                                 img_raw_dir=args.img_raw_dir)
            logger.info("No Guesser was registered >>> use user input")

        if args.qgen_identifier is not None:
            qgen_config = get_config_from_xp(
                os.path.join(args.networks_dir, "qgen"), args.qgen_identifier)

            qgen_network = QGenNetworkLSTM(qgen_config["model"],
                                           num_words=tokenizer.no_words,
                                           policy_gradient=False)
            qgen_var = [v for v in tf.global_variables()
                        if "qgen" in v.name]  # and 'rl_baseline' not in v.name
            qgen_saver = tf.train.Saver(var_list=qgen_var)

            qgen_saver.restore(
                sess,
                os.path.join(args.networks_dir, 'qgen', args.qgen_identifier,
                             'params.ckpt'))

            qgen_network.build_sampling_graph(
                qgen_config["model"],
                tokenizer=tokenizer,
                max_length=eval_config['loop']['max_depth'])
            qgen_wrapper = QGenWrapper(