Code Example #1
    def __init__(self, config, oracle, qgen, guesser, tokenizer):
        self.storage = []

        self.tokenizer = tokenizer

        self.batch_size = config["optimizer"]["batch_size"]

        self.max_no_question = config['loop']['max_question']
        self.max_depth = config['loop']['max_depth']
        self.k_best = config['loop']['beam_k_best']

        self.oracle = OracleWrapper(oracle, tokenizer)
        self.guesser = GuesserWrapper(guesser)
        self.qgen = QGenWrapper(qgen, tokenizer, max_length=self.max_depth, k_best=self.k_best)
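The constructor above only reads a handful of keys from the config. A minimal config dict covering all of them could look like this (the values are illustrative, not taken from the repository):

loop_config = {
    "optimizer": {"batch_size": 64},
    "loop": {
        "max_question": 5,   # number of question/answer rounds per game
        "max_depth": 12,     # maximum number of tokens per generated question
        "beam_k_best": 20,   # beam width used by the question generator
    },
}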
Code Example #2
        guesser_saver.restore(
            sess_loop,
            os.path.join(guesser_dir, guesser_checkpoint, 'best',
                         'params.ckpt'))
        qgen_saver.restore(
            sess_loop, os.path.join(xp_manager.dir_best_ckpt, 'params.ckpt'))

        # Oracle: answers the generated questions about the target object.
        oracle_split_mode = 1
        oracle_batchifier = oracle_batchifier_cstor(
            tokenizer,
            sources=oracle_network.get_sources(sess_loop),
            split_mode=oracle_split_mode)
        oracle_wrapper = OracleWrapper(oracle_network, oracle_batchifier,
                                       tokenizer)

        # Guesser: predicts the target object from the final dialogue.
        guesser_batchifier = guesser_batchifier_cstor(
            tokenizer, sources=guesser_network.get_sources(sess_loop))
        guesser_wrapper = GuesserWrapper(guesser_network, guesser_batchifier,
                                         tokenizer, guesser_listener)

        # QGen: generates the questions; its batchifier is built in generation mode.
        qgen_batchifier = qgen_batchifier_cstor(
            tokenizer,
            sources=qgen_network.get_sources(sess_loop),
            generate=True)
        qgen_wrapper = QGenWrapper(qgen_network,
                                   qgen_batchifier,
                                   tokenizer,
                                   max_length=12,
                                   k_best=20)

        xp_manager.configure_score_tracking("valid_accuracy", max_is_best=True)

        loop_config = {}  # fake config
        loop_config['loop'] = {}
Code Example #3
import numpy as np
from tqdm import tqdm

# OracleWrapper, GuesserWrapper, QGenWrapper, clear_after_stop_dialogue and
# list_to_padded_tokens are provided by the GuessWhat?! codebase.


class BasicLooper(object):
    def __init__(self, config, oracle, qgen, guesser, tokenizer):
        self.storage = []

        self.tokenizer = tokenizer

        self.batch_size = config["optimizer"]["batch_size"]

        self.max_no_question = config['loop']['max_question']
        self.max_depth = config['loop']['max_depth']
        self.k_best = config['loop']['beam_k_best']

        self.oracle = OracleWrapper(oracle, tokenizer)
        self.guesser = GuesserWrapper(guesser)
        self.qgen = QGenWrapper(qgen, tokenizer, max_length=self.max_depth, k_best=self.k_best)

    def process(self, sess, iterator, mode, optimizer=list(), store_games=False):

        # initialize the wrapper
        self.qgen.initialize(sess)
        self.oracle.initialize(sess)
        self.guesser.initialize(sess)

        self.storage = []
        score, total_elem = 0, 0
        for game_data in tqdm(iterator):

            # initialize the dialogue
            full_dialogues = [np.array([self.tokenizer.start_token]) for _ in range(self.batch_size)]
            prev_answers = full_dialogues

            no_elem = len(game_data["raw"])
            total_elem += no_elem

            # Step 1: generate question/answer
            self.qgen.reset(batch_size=no_elem)
            for no_question in range(self.max_no_question):

                # Step 1.1: Generate new question
                padded_questions, questions, seq_length = \
                    self.qgen.sample_next_question(sess, prev_answers, game_data=game_data, mode=mode)

                # Step 1.2: Answer the question
                answers = self.oracle.answer_question(sess,
                                                      question=padded_questions,
                                                      seq_length=seq_length,
                                                      game_data=game_data)

                # Step 1.3: store the full dialogues
                for i in range(self.batch_size):
                    full_dialogues[i] = np.concatenate((full_dialogues[i], questions[i], [answers[i]]))

                # Step 1.4: set new input tokens
                prev_answers = [[a] for a in answers]

            # Step 2: clear questions after <stop_dialogue>
            full_dialogues, _ = clear_after_stop_dialogue(full_dialogues, self.tokenizer)
            padded_dialogue, seq_length = list_to_padded_tokens(full_dialogues, self.tokenizer)

            # Step 3: find the object
            found_object, softmax, guess_objects = self.guesser.find_object(sess, padded_dialogue, seq_length, game_data)
            score += np.sum(found_object)

            if store_games:
                for d, g, t, f, go in zip(full_dialogues, game_data["raw"], game_data["targets"], found_object, guess_objects):
                    self.storage.append({"dialogue": d, "game": g, "object_id": g.objects[t].id, "success": f, "guess_object_id": g.objects[go].id})

            if len(optimizer) > 0:
                final_reward = found_object + 0  # +1 if found otherwise 0

                self.apply_policy_gradient(sess,
                                           final_reward=final_reward,
                                           padded_dialogue=padded_dialogue,
                                           seq_length=seq_length,
                                           game_data=game_data,
                                           optimizer=optimizer)

        score = 1.0 * score / iterator.n_examples

        return score

    def get_storage(self):
        return self.storage

    def apply_policy_gradient(self, sess, final_reward, padded_dialogue, seq_length, game_data, optimizer):

        # Compute cumulative reward TODO: move into an external function
        cum_rewards = np.zeros_like(padded_dialogue, dtype=np.float32)
        for i, (end_of_dialogue, r) in enumerate(zip(seq_length, final_reward)):
            cum_rewards[i, :(end_of_dialogue - 1)] = r  # gamma = 1

        # Create answer mask to ignore the reward for yes/no tokens
        answer_mask = np.ones_like(padded_dialogue)  # quick and dirty mask -> TODO to improve
        answer_mask[padded_dialogue == self.tokenizer.yes_token] = 0
        answer_mask[padded_dialogue == self.tokenizer.no_token] = 0
        answer_mask[padded_dialogue == self.tokenizer.non_applicable_token] = 0

        # Create padding mask to ignore the reward after <stop_dialogue>
        padding_mask = np.ones_like(padded_dialogue)
        padding_mask[padded_dialogue == self.tokenizer.padding_token] = 0
        # for i in range(np.max(seq_length)): print(cum_rewards[0][i], answer_mask[0][i],self.tokenizer.decode([padded_dialogue[0][i]]))

        # Step 4.4: optim step
        qgen = self.qgen.qgen  # retrieve qgen from wrapper (dirty)

        sess.run(optimizer,
                 feed_dict={
                     qgen.images: game_data["images"],
                     qgen.dialogues: padded_dialogue,
                     qgen.seq_length: seq_length,
                     qgen.padding_mask: padding_mask,
                     qgen.answer_mask: answer_mask,
                     qgen.cum_rewards: cum_rewards,
                 })
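To make the reward bookkeeping in apply_policy_gradient concrete, the following self-contained sketch reproduces the same numpy operations on a toy batch; the token ids for <yes>, <no>, <n/a> and <padding> are made up for illustration and do not match the real tokenizer:

import numpy as np

YES, NO, NA, PAD = 4, 5, 6, 0

# Two padded dialogues of length 6; the second one is shorter.
padded_dialogue = np.array([[10, 11, YES, 12, 13, NO],
                            [20, 21, NA, PAD, PAD, PAD]])
seq_length = np.array([6, 3])        # true lengths before padding
final_reward = np.array([1.0, 0.0])  # 1 if the guesser found the object

# Cumulative reward: the final reward is spread over the generated tokens (gamma = 1).
cum_rewards = np.zeros_like(padded_dialogue, dtype=np.float32)
for i, (end_of_dialogue, r) in enumerate(zip(seq_length, final_reward)):
    cum_rewards[i, :(end_of_dialogue - 1)] = r

# Answer mask: the oracle's yes/no/n-a tokens are not produced by QGen, so no gradient for them.
answer_mask = np.ones_like(padded_dialogue)
answer_mask[padded_dialogue == YES] = 0
answer_mask[padded_dialogue == NO] = 0
answer_mask[padded_dialogue == NA] = 0

# Padding mask: tokens after the end of the dialogue are ignored as well.
padding_mask = np.ones_like(padded_dialogue)
padding_mask[padded_dialogue == PAD] = 0

print(cum_rewards)                 # reward spread over the generated tokens
print(answer_mask * padding_mask)  # tokens that actually contribute to the policy gradient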
Code Example #4
        # create training tools
        loop_sources = qgen_network.get_sources(sess)
        logger.info("Sources: " + ', '.join(loop_sources))

        evaluator = Evaluator(loop_sources,
                              qgen_network.scope_name,
                              network=qgen_network,
                              tokenizer=tokenizer)

        train_batchifier = LooperBatchifier(tokenizer, generate_new_games=True)
        eval_batchifier = LooperBatchifier(tokenizer, generate_new_games=False)

        # Initialize the looper to eval/train the game-simulation

        oracle_wrapper = OracleWrapper(oracle_network, tokenizer)
        guesser_wrapper = GuesserWrapper(guesser_network)
        qgen_network.build_sampling_graph(
            qgen_config["model"],
            tokenizer=tokenizer,
            max_length=loop_config['loop']['max_depth'])
        qgen_wrapper = QGenWrapper(qgen_network,
                                   tokenizer,
                                   max_length=loop_config['loop']['max_depth'],
                                   k_best=loop_config['loop']['beam_k_best'])

        looper_evaluator = BasicLooper(
            loop_config,
            oracle_wrapper=oracle_wrapper,
            guesser_wrapper=guesser_wrapper,
            qgen_wrapper=qgen_wrapper,
            tokenizer=tokenizer,