Example #1
import logging
import os
from collections import namedtuple

import gym
import numpy as np
import torch

# CLEVR_Dataset and the `rewards` registry are assumed to be provided
# elsewhere in the project.


class ClevrEnv(gym.Env):
    """Clevr Env"""
    metadata = {'render.modes': ['human']}

    def __init__(self,
                 data_path,
                 max_len,
                 reward_type="levenshtein",
                 reward_path=None,
                 max_samples=None,
                 debug=False,
                 mode="train",
                 num_questions=10):
        super(ClevrEnv, self).__init__()
        self.mode = mode
        self.data_path = data_path
        h5_questions_path = os.path.join(data_path,
                                         '{}_questions.h5'.format(self.mode))
        h5_feats_path = os.path.join(data_path,
                                     '{}_features.h5'.format(self.mode))
        vocab_path = os.path.join(data_path, 'vocab.json')
        # self.debug_true_questions = torch.randint(0,debug_len_vocab, (2,))
        self.debug = debug
        self.num_questions = num_questions
        self.clevr_dataset = CLEVR_Dataset(h5_questions_path=h5_questions_path,
                                           h5_feats_path=h5_feats_path,
                                           vocab_path=vocab_path,
                                           max_samples=max_samples)

        # num_tokens = self.clevr_dataset.len_vocab
        # feats_shape = self.clevr_dataset.feats_shape
        SOS_idx = self.clevr_dataset.vocab_questions["<SOS>"]
        EOS_idx = self.clevr_dataset.vocab_questions["<EOS>"]

        Special_Tokens = namedtuple('Special_Tokens', ('SOS_idx', 'EOS_idx'))
        self.special_tokens = Special_Tokens(SOS_idx, EOS_idx)
        self.State = namedtuple('State', ('text', 'img'))
        self.Episode = namedtuple('Episode',
                                  ('img_idx', 'closest_question', 'dialog',
                                   'rewards', 'valid_actions'))
        self.max_len = max_len
        # self.ref_questions = torch.randint(0, self.debug_len_vocab,
        #                                  (3, self.max_len)) if self.debug_len_vocab is not None else None
        # self.reset()

        self.reward_func = rewards[reward_type](reward_path)
        self.step_idx = 0
        self.state, self.dialog = None, None
        self.ref_questions, self.ref_questions_decoded = None, None
        self.img_idx, self.img_feats = None, None

    def step(self, action):
        action = torch.tensor(action).view(1, 1)
        self.state = self.State(torch.cat([self.state.text, action], dim=1),
                                self.state.img)
        question = self.clevr_dataset.idx2word(self.state.text.numpy()[0])
        done = (action.item() == self.special_tokens.EOS_idx
                or self.step_idx == self.max_len - 1)
        # question = preprocess_final_state(state_text=self.state.text, dataset=self.clevr_dataset,
        #                               EOS_idx=self.special_tokens.EOS_idx)
        if done:
            reward, closest_question = self.reward_func.get(
                question=question,
                ep_questions_decoded=self.ref_questions_decoded)
        else:
            reward, closest_question = 0, None
        self.step_idx += 1
        if done:
            self.dialog = question
            logging.info(question)
        return self.state, (reward, closest_question), done, {}

    def reset(self):
        if self.debug:
            self.img_idx = np.random.randint(0, self.debug)
        else:
            self.img_idx = np.random.randint(
                0, self.clevr_dataset.all_feats.shape[0])
        # self.img_idx = 0
        self.ref_questions = self.clevr_dataset.get_questions_from_img_idx(
            self.img_idx)[:, :self.max_len]  # shape (10, 45)
        #if self.debug > 0:
        self.ref_questions = self.ref_questions[0:self.num_questions]
        # if self.debug:
        # self.ref_questions = torch.tensor([[7, 8, 10, 12, 14]])
        self.ref_questions_decoded = [
            self.clevr_dataset.idx2word(question, clean=True)
            for question in self.ref_questions.numpy()
        ]
        logging.info("Questions for image {} : {}".format(
            self.img_idx, self.ref_questions_decoded))
        # self.ref_questions_decoded = [self.ref_questions_decoded[0]]  # FOR DEBUGGING.
        self.img_feats = self.clevr_dataset.get_feats_from_img_idx(
            self.img_idx)  # shape (1024, 14, 14)
        self.state = self.State(
            torch.LongTensor([self.special_tokens.SOS_idx]).view(1, 1),
            self.img_feats.unsqueeze(0))
        self.step_idx = 0
        self.dialog = None
        self.current_episode = self.Episode(self.img_idx, None, None, None,
                                            None)

        return self.state

    def decode_current_episode(self):
        valid_actions = self.current_episode.valid_actions
        assert valid_actions is not None
        valid_actions_decoded = [
            self.clevr_dataset.idx2word(actions, delim=',')
            for actions in valid_actions
        ]
        # dialog_split = [self.current_episode.dialog.split()[:i] for i in range(valid_actions)]
        # return dict(zip(dialog_split, valid_actions_decoded))
        return valid_actions_decoded

    def clean_ref_questions(self):
        questions_decoded = [
            tokens.replace('<PAD>', '')
            for tokens in self.ref_questions_decoded
        ]
        questions_decoded = [q.strip() for q in questions_decoded]
        self.ref_questions_decoded = questions_decoded

    def get_reduced_action_space(self):
        assert self.ref_questions_decoded is not None
        reduced_vocab = [q.split() for q in self.ref_questions_decoded]
        reduced_vocab = [i for l in reduced_vocab for i in l]
        reduced_vocab = list(set(reduced_vocab))
        unique_tokens = self.clevr_dataset.word2idx(seq_tokens=reduced_vocab)
        dict_tokens = dict(enumerate(unique_tokens))
        return dict_tokens, reduced_vocab

    def render(self, mode='human', close=False):
        pass
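
A minimal usage sketch for the environment above, assuming a local CLEVR data directory and the project's `rewards` registry (the data path, max_len and the random-action policy below are illustrative, not taken from the source):

# Hypothetical rollout with a placeholder random policy.
env = ClevrEnv(data_path="data/CLEVR_v1.0", max_len=10,
               reward_type="levenshtein", num_questions=10)
state = env.reset()
done = False
while not done:
    # sample a random token id as the next action (illustrative only)
    action = np.random.randint(0, env.clevr_dataset.len_vocab)
    state, (reward, closest_question), done, _ = env.step(action)
print(env.dialog, reward, closest_question)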
Example #2
    test_loader = DataLoader(dataset=test_dataset, batch_size=len(test_dataset), drop_last=True,
                             num_workers=args.num_workers)
    out_file_top_k_words = os.path.join(args.out_path,
                                        'generate_top_k_words_k_{}_seed_{}.json'.format(args.top_k, args.seed))
    out_file_log = os.path.join(args.out_path, 'eval_log.log')
    logger = create_logger(out_file_log)
    log_interval = int(args.words / 10)

    ###############################################################################
    # generate words
    ###############################################################################
    indexes = list(range(5))
    for index in indexes:
        img_feats = test_dataset.get_feats_from_img_idx(index).unsqueeze(0)
        input = torch.LongTensor([SOS_idx]).view(1,1).to(device)
        input_word = test_dataset.idx2word([input[0].item()], delim='')
        for temp in args.temperature:
            logger.info("generating text with temperature: {}".format(temp))
            out_file_generate = os.path.join(args.out_path, 'generate_words_temp_{}_img_{}.txt'.format(temp, index))
            with open(out_file_generate, 'w') as f:
                f.write(input_word + '\n')
                with torch.no_grad():
                    for i in range(args.words):
                        output, hidden, _ = model(input, img_feats)  # output (1, num_tokens)
                        if temp is not None:
                            word_weights = output.squeeze().div(temp).exp()  # exp(log_softmax / temp) = p_i^(1/T)
                            word_weights = word_weights / word_weights.sum(dim=-1)  # keep weights on the same device as the model output
                            word_idx = torch.multinomial(word_weights, num_samples=1)[0]  # [0] to have a scalar tensor.
                        else:
                            word_idx = output.squeeze().argmax()
                        input.fill_(word_idx)
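
The sampling step above (divide log-probabilities by a temperature, exponentiate, renormalize, then draw from a multinomial) can be illustrated in isolation; this standalone sketch uses synthetic scores rather than the model's output:

import torch

# p_i^(1/T): dividing log-probabilities by T and exponentiating sharpens
# (T < 1) or flattens (T > 1) the distribution before sampling.
log_probs = torch.log_softmax(torch.randn(20), dim=-1)  # synthetic scores over 20 tokens
temp = 0.7
word_weights = log_probs.div(temp).exp()
word_weights = word_weights / word_weights.sum(dim=-1)
word_idx = torch.multinomial(word_weights, num_samples=1)[0]  # [0] to get a scalar tensor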