Example #1
class OracleWrapper(object):
    def __init__(self, oracle, tokenizer):

        self.oracle = oracle
        self.evaluator = None
        self.tokenizer = tokenizer

    def initialize(self, sess):
        self.evaluator = Evaluator(self.oracle.get_sources(sess),
                                   self.oracle.scope_name)

    def answer_question(self, sess, question, seq_length, game_data):

        game_data["question"] = question
        game_data["seq_length"] = seq_length

        # rename the batch keys to match the oracle's expected inputs
        game_data["category"] = game_data.get("targets_category", None)
        game_data["spatial"] = game_data.get("targets_spatial", None)

        # sample
        answers_indices = self.evaluator.execute(sess,
                                                 output=self.oracle.best_pred,
                                                 batch=game_data)

        # Decode the answer indices into tokens ['<yes>', '<no>', '<n/a>'].
        # WARNING: magic order... TODO move this order into the tokenizer
        answer_dico = [self.tokenizer.yes_token,
                       self.tokenizer.no_token,
                       self.tokenizer.non_applicable_token]

        # turn prediction indices into tokenizer ids
        answers = [answer_dico[a] for a in answers_indices]

        return answers
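A minimal usage sketch for this wrapper (the TensorFlow session `sess`, the `oracle` model, the `tokenizer`, the question tensors and the `game_data` batch are assumptions taken from the surrounding training loop; the `Evaluator` helper is likewise provided by the host codebase):

# Hypothetical driver code -- not part of the original example
oracle_wrapper = OracleWrapper(oracle, tokenizer)
oracle_wrapper.initialize(sess)

# padded_questions / seq_length would normally come from the question generator
answers = oracle_wrapper.answer_question(sess,
                                         question=padded_questions,
                                         seq_length=seq_length,
                                         game_data=game_data)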
Example #2

import numpy as np


class QGenSamplingWrapper(object):
    def __init__(self, qgen, tokenizer, max_length):

        self.qgen = qgen

        self.tokenizer = tokenizer
        self.max_length = max_length

        self.evaluator = None

        # Track the hidden state of LSTM
        self.state_c = None
        self.state_h = None
        self.state_size = int(qgen.decoder_zero_state_c.get_shape()[1])

    def initialize(self, sess):
        self.evaluator = Evaluator(self.qgen.get_sources(sess),
                                   self.qgen.scope_name)

    def reset(self, batch_size):
        # reset state
        self.state_c = np.zeros((batch_size, self.state_size))
        self.state_h = np.zeros((batch_size, self.state_size))

    def sample_next_question(self, sess, prev_answers, game_data, greedy):

        game_data["dialogues"] = prev_answers
        game_data["seq_length"] = [1] * len(prev_answers)
        game_data["state_c"] = self.state_c
        game_data["state_h"] = self.state_h
        game_data["greedy"] = greedy

        # sample
        res = self.evaluator.execute(sess, self.qgen.samples, game_data)

        self.state_c = res[0]
        self.state_h = res[1]
        transpose_questions = res[2]
        seq_length = res[3]

        # Reshape questions from [time, batch] to [batch, time]
        padded_questions = transpose_questions.transpose([1, 0])
        padded_questions = padded_questions[:, 1:]  # drop the first (start) token

        for i, l in enumerate(seq_length):
            padded_questions[i, l:] = self.tokenizer.padding_token

        questions = [q[:l] for q, l in zip(padded_questions, seq_length)]

        return padded_questions, questions, seq_length
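A minimal sketch of the sampling loop this wrapper is built for (`sess`, a compatible `qgen` model, the `tokenizer`, `batch_size` and `game_data` are assumptions, not part of the original example):

# Hypothetical driver code -- not part of the original example
qgen_wrapper = QGenSamplingWrapper(qgen, tokenizer, max_length=12)
qgen_wrapper.initialize(sess)
qgen_wrapper.reset(batch_size)

# every dialogue starts from the <start> token used as the "previous answer"
prev_answers = [[tokenizer.start_token]] * batch_size

padded_questions, questions, seq_length = qgen_wrapper.sample_next_question(
    sess, prev_answers, game_data, greedy=True)

The returned questions can then be fed to the oracle wrapper above, and the oracle's answers become `prev_answers` for the next turn.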
Example #3

import copy


class OracleWrapper(object):
    def __init__(self, oracle, batchifier, tokenizer):

        self.oracle = oracle
        self.evaluator = None

        self.tokenizer = tokenizer
        self.batchifier = batchifier

    def initialize(self, sess):
        self.evaluator = Evaluator(self.oracle.get_sources(sess),
                                   self.oracle.scope_name)

    def answer_question(self, sess, games):

        # Build the oracle batch  # TODO: hack -> to remove
        oracle_games = []
        if self.batchifier.split_mode == 1:
            # only feed the most recent question to the oracle
            for game in games:
                g = copy.copy(game)
                g.questions = [game.questions[-1]]
                g.question_ids = [game.question_ids[-1]]
                oracle_games.append(g)
        else:
            oracle_games = games

        batch = self.batchifier.apply(oracle_games, skip_targets=True)
        batch["is_training"] = False

        # Sample
        answers_index = self.evaluator.execute(sess,
                                               output=self.oracle.prediction,
                                               batch=batch)

        # Update games with the predicted answers
        new_games = []
        for game, answer in zip(games, answers_index):
            # stop appending answers once the dialogue is over
            if not game.user_data["has_stop_token"]:
                game.answers.append(
                    self.tokenizer.decode_oracle_answer(answer, sparse=True))
            new_games.append(game)

        return new_games
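Unlike the first oracle wrapper, this variant operates directly on a list of `games` objects and delegates batching to the `batchifier`; a hypothetical call (all names are assumptions from the surrounding loop):

# Hypothetical driver code -- not part of the original example
oracle_wrapper = OracleWrapper(oracle, oracle_batchifier, tokenizer)
oracle_wrapper.initialize(sess)
games = oracle_wrapper.answer_question(sess, games)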
Example #4
class GuesserWrapper(object):

    def __init__(self, guesser):
        self.guesser = guesser
        self.evaluator = None

    def initialize(self, sess):
        self.evaluator = Evaluator(self.guesser.get_sources(sess), self.guesser.scope_name)

    def find_object(self, sess, dialogues, seq_length, game_data):
        game_data["dialogues"] = dialogues
        game_data["seq_length"] = seq_length

        # sample
        selected_object, softmax = self.evaluator.execute(
            sess,
            output=[self.guesser.selected_object, self.guesser.softmax],
            batch=game_data)

        found = (selected_object == game_data["targets_index"])

        return found, softmax, selected_object
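A short, hypothetical sketch of how the guesser wrapper might be queried once a dialogue is finished (`sess`, the `guesser` model, `dialogues`, `seq_length` and `game_data` are assumptions):

# Hypothetical driver code -- not part of the original example
guesser_wrapper = GuesserWrapper(guesser)
guesser_wrapper.initialize(sess)

found, softmax, selected_object = guesser_wrapper.find_object(
    sess, dialogues, seq_length, game_data)

# assuming numpy outputs, `found` is a boolean array over the batch
success_rate = found.mean()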
Example #5

import copy
import numpy as np

# Evaluator, BasicIterator and AccOptimizer are assumed to be provided by the surrounding codebase


class QGenWrapper(object):
    def __init__(self, qgen, batchifier, tokenizer, max_length, k_best):

        self.qgen = qgen

        self.batchifier = batchifier
        self.tokenizer = tokenizer

        self.ops = dict()
        self.ops["sampling"] = qgen.create_sampling_graph(
            start_token=tokenizer.start_token,
            stop_token=tokenizer.stop_token,
            max_tokens=max_length)

        self.ops["greedy"] = qgen.create_greedy_graph(
            start_token=tokenizer.start_token,
            stop_token=tokenizer.stop_token,
            max_tokens=max_length)

        beam_predicted_ids, seq_length, att = qgen.create_beam_graph(
            start_token=tokenizer.start_token,
            stop_token=tokenizer.stop_token,
            max_tokens=max_length,
            k_best=k_best)
        # Only keep the best beam (index 0). The zeroed array stands in for the
        # per-token state values returned by the sampling/greedy graphs.
        self.ops["beam"] = (beam_predicted_ids[:, 0, :],
                            seq_length[:, 0],
                            beam_predicted_ids[:, 0, :] * 0,
                            att)

        self.evaluator = None

    def initialize(self, sess):
        self.evaluator = Evaluator(self.qgen.get_sources(sess),
                                   self.qgen.scope_name,
                                   network=self.qgen,
                                   tokenizer=self.tokenizer)

    def policy_update(self, sess, games, optimizer):

        # ugly hack: reconfigure the batchifier so that it builds RL training batches
        batchifier = copy.copy(self.batchifier)
        batchifier.generate = False
        batchifier.supervised = False

        iterator = BasicIterator(games,
                                 batch_size=len(games),
                                 batchifier=batchifier)

        # Check whether the gradient is accumulated
        if isinstance(optimizer, AccOptimizer):
            sess.run(optimizer.zero)  # reset gradient
            local_optimizer = optimizer.accumulate
        else:
            local_optimizer = optimizer

        # Compute the gradient
        self.evaluator.process(sess,
                               iterator,
                               outputs=[local_optimizer],
                               show_progress=False)

        if isinstance(optimizer, AccOptimizer):
            sess.run(optimizer.update)  # Apply accumulated gradient

    def sample_next_question(self, sess, games, att_dict, beta_dict, mode):

        # ugly hack: reconfigure the batchifier so that it builds generation (rollout) batches
        batchifier = copy.copy(self.batchifier)
        batchifier.generate = True
        batchifier.supervised = False

        # create the generation batch (no targets needed)
        batch = batchifier.apply(games, skip_targets=True)
        batch["is_training"] = False
        batch["is_dynamic"] = True

        # Sample
        tokens, seq_length, state_values, atts = self.evaluator.execute(
            sess, output=self.ops[mode], batch=batch)
        # tokens, seq_length, state_values, atts, betas = self.evaluator.execute(sess, output=self.ops[mode], batch=batch)

        # Update game
        new_games = []
        for game, question_tokens, l, state_value, att in zip(
                games, tokens, seq_length, state_values, atts):
            # for game, question_tokens, l, state_value, att, beta in zip(games, tokens, seq_length, state_values, atts, betas):

            # stop adding questions once the dialogue is over
            if not game.user_data["has_stop_token"]:

                # truncate the question after the first stop_dialogue token
                if self.tokenizer.stop_dialogue in question_tokens:
                    game.user_data["has_stop_token"] = True
                    # keep everything up to (and including) the first stop_dialogue occurrence
                    l = np.nonzero(
                        question_tokens == self.tokenizer.stop_dialogue)[0][0] + 1

                # Append the newly generated question
                game.questions.append(
                    self.tokenizer.decode(question_tokens[:l]))
                game.question_ids.append(len(game.question_ids))

                game.user_data["state_values"] = game.user_data.get(
                    "state_values", [])
                game.user_data["state_values"].append(state_value[:l].tolist())
            att = att.tolist()
            att_i = np.argsort(att).tolist()
            att_3 = np.sort(att).tolist()
            att_dict.setdefault(game.dialogue_id, []).append((att_i, att_3))

            # beta = beta.tolist()
            # if game.dialogue_id not in beta_dict:
            #     beta_dict[game.dialogue_id] = []
            #     beta_dict[game.dialogue_id].append(beta)
            # else:
            #     beta_dict[game.dialogue_id].append(beta)

            new_games.append(game)

        return new_games, att_dict  #, beta_dict
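A hypothetical rollout loop tying this QGenWrapper to the game-based OracleWrapper defined above (the models, batchifiers, `tokenizer`, `sess`, the initial `games` and `max_turns` are all assumptions; `mode` must be one of the keys of `self.ops`, i.e. "sampling", "greedy" or "beam"):

# Hypothetical rollout loop -- not part of the original examples
qgen_wrapper = QGenWrapper(qgen, qgen_batchifier, tokenizer, max_length=12, k_best=5)
oracle_wrapper = OracleWrapper(oracle, oracle_batchifier, tokenizer)
qgen_wrapper.initialize(sess)
oracle_wrapper.initialize(sess)

att_dict = {}
for _ in range(max_turns):
    games, att_dict = qgen_wrapper.sample_next_question(
        sess, games, att_dict, beta_dict={}, mode="sampling")
    games = oracle_wrapper.answer_question(sess, games)

After the rollout, `policy_update(sess, games, optimizer)` can be called to apply a policy-gradient step, optionally with an AccOptimizer to accumulate gradients over several batches.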