Exemplo n.º 1
0
    def update(self, replay_batch_size, history_size, update_from=0, discount_gamma=0.0):
        """Compute the recurrent Q-learning loss over a sampled batch of sequences.

        ``history_size + 1`` consecutive steps per game are drawn from
        ``self.replay_memory``, and the recurrent action scorer (ras) is unrolled
        over them. For every step ``i > update_from``, a Huber loss is computed.
        It compares the Q-value of the command taken at step ``i - 1`` with the
        target ``reward[i-1] + discount_gamma * (1 - done[i-1]) * maxQ[i]``,
        both multiplied by ``mask[i-1]``. The final step is only used to
        bootstrap the last target; its own Q-value is detached.

        Args:
            replay_batch_size: number of sequences to sample.
            history_size: number of steps per sequence, excluding the extra
                bootstrap step that is also sampled.
            update_from: steps with index below this only warm up the recurrent
                state; their Q-values and hidden states are detached.
            discount_gamma: discount applied to the bootstrapped max Q-value.

        Returns:
            The mean of the per-step losses (a scalar tensor), or ``None`` in three
            cases: the memory holds fewer than ``replay_batch_size`` entries,
            ``get_batch`` returns ``None``, or no step lies after ``update_from``.
        """
        if len(self.replay_memory) < replay_batch_size:
            return None
        # list (history_size + 1) of list (batch) of tuples; the last step is
        # only used to compute the final bootstrap Q-value
        transitions = self.replay_memory.get_batch(replay_batch_size, history_size + 1)
        if transitions is None:
            return None
        sequences = [Transition(*zip(*batch)) for batch in transitions]

        def score_step(step, hidden, cell):
            # Pad one step's observations, run the scorer, and pick out the
            # Q-value of the (verb, noun) command that was actually taken.
            observation_id_list = pad_sequences(step.observation_id_list,
                                                maxlen=max_len(step.observation_id_list),
                                                padding='post').astype('int32')
            input_observation = to_pt(observation_id_list, self.use_cuda)
            v_idx = torch.stack(step.v_idx, 0)  # batch x 1
            n_idx = torch.stack(step.n_idx, 0)  # batch x 1
            verb_rank, noun_rank, hidden, cell = self.get_ranks(input_observation, hidden, cell)
            v_qvalue = verb_rank.gather(1, v_idx.unsqueeze(-1)).squeeze(-1)  # batch
            n_qvalue = noun_rank.gather(1, n_idx.unsqueeze(-1)).squeeze(-1)  # batch
            # a command's Q-value is the mean of its verb and noun scores
            q_taken = torch.mean(torch.stack([v_qvalue, n_qvalue], -1), -1)  # batch
            return verb_rank, noun_rank, q_taken, hidden, cell

        losses = []
        # first step starts the recurrent action scorer from an empty state
        _, _, prev_qvalue, curr_ras_hidden, curr_ras_cell = score_step(sequences[0], None, None)
        if update_from > 0:
            prev_qvalue, curr_ras_hidden, curr_ras_cell = \
                prev_qvalue.detach(), curr_ras_hidden.detach(), curr_ras_cell.detach()

        last_step = len(sequences) - 1
        for i in range(1, len(sequences)):
            verb_rank, noun_rank, q_value, curr_ras_hidden, curr_ras_cell = \
                score_step(sequences[i], curr_ras_hidden, curr_ras_cell)
            if i < update_from or i == last_step:
                # warm-up steps and the bootstrap-only final step carry no gradient
                q_value, curr_ras_hidden, curr_ras_cell = \
                    q_value.detach(), curr_ras_hidden.detach(), curr_ras_cell.detach()
            if i > update_from:
                # the greedy Q-value is only needed to build a target, so skip the
                # (host-syncing) argmax for steps that contribute no loss
                v_qvalue_max, _, n_qvalue_max, _ = self.choose_maxQ_command(verb_rank, noun_rank)
                q_value_max = torch.mean(torch.stack([v_qvalue_max, n_qvalue_max], -1), -1)  # batch
                q_value_max = q_value_max.detach()

                prev_step = sequences[i - 1]
                prev_rewards = torch.stack(prev_step.reward)  # batch
                prev_not_done = 1.0 - np.array(prev_step.done, dtype='float32')  # batch
                prev_not_done = to_pt(prev_not_done, self.use_cuda, type='float')
                prev_rewards = prev_rewards + prev_not_done * q_value_max * discount_gamma  # batch
                prev_mask = torch.stack(prev_step.mask)  # batch
                prev_loss = F.smooth_l1_loss(prev_qvalue * prev_mask, prev_rewards * prev_mask)  # huber_loss
                losses.append(prev_loss)
            prev_qvalue = q_value

        if not losses:
            # update_from >= len(sequences) - 1: nothing to train on; previously
            # torch.stack([]) raised here
            return None
        return torch.stack(losses).mean()
Exemplo n.º 2
0
 def choose_maxQ_command(self, verb_rank, noun_rank):
     """Greedily select the highest-scoring verb and noun for every batch item.

     Args:
         verb_rank: verb scores, indexed as ``verb_rank[batch, verb]``.
         noun_rank: noun scores, indexed as ``noun_rank[batch, noun]``.

     Returns:
         ``(v_qvalue, v_idx, n_qvalue, n_idx)``. The q-values are the selected
         scores, one per batch item, taken from the input tensors (so they stay
         in the autograd graph). The indices are the argmax positions converted
         with ``to_pt``.
     """
     # argmax on the host keeps NumPy's first-occurrence tie-breaking
     v_idx_np = np.argmax(to_np(verb_rank), -1)
     n_idx_np = np.argmax(to_np(noun_rank), -1)
     v_idx, n_idx = to_pt(v_idx_np, self.use_cuda), to_pt(n_idx_np, self.use_cuda)
     # one vectorised gather replaces the per-item Python indexing + stack
     v_qvalue = verb_rank.gather(1, v_idx.unsqueeze(-1)).squeeze(-1)
     n_qvalue = noun_rank.gather(1, n_idx.unsqueeze(-1)).squeeze(-1)
     return v_qvalue, v_idx, n_qvalue, n_idx
Exemplo n.º 3
0
    def choose_random_command(self, verb_rank, noun_rank):
        """Sample a uniformly random verb and noun for every batch item.

        Args:
            verb_rank: verb scores, indexed as ``verb_rank[batch, verb]``.
            noun_rank: noun scores, indexed as ``noun_rank[batch, noun]``.

        Returns:
            ``(v_qvalue, v_idx, n_qvalue, n_idx)``. The q-values are the scores at
            the sampled positions, taken from the input tensors (so they stay in
            the autograd graph). The indices are the sampled positions converted
            with ``to_pt``.
        """
        batch_size = verb_rank.size(0)
        # Only the vocabulary sizes are needed; read them from the tensor shape
        # rather than copying both score matrices to the host.
        n_verbs, n_nouns = verb_rank.size(1), noun_rank.size(1)

        v_idx, n_idx = [], []
        for _ in range(batch_size):
            # keep the verb-then-noun draw order per item so seeded runs
            # sample the same commands
            v_idx.append(np.random.choice(n_verbs, 1)[0])
            n_idx.append(np.random.choice(n_nouns, 1)[0])
        v_idx = to_pt(np.array(v_idx), self.use_cuda)
        n_idx = to_pt(np.array(n_idx), self.use_cuda)
        v_qvalue = verb_rank.gather(1, v_idx.unsqueeze(-1)).squeeze(-1)
        n_qvalue = noun_rank.gather(1, n_idx.unsqueeze(-1)).squeeze(-1)
        return v_qvalue, v_idx, n_qvalue, n_idx
Exemplo n.º 4
0
 def generate_one_command(self, input_description, prev_hidden=None,
                          prev_cell=None, epsilon=0.2, return_att=False, att_mask=None):
     """Choose one (verb, noun) command per game using an epsilon-greedy rule.

     For each game, a uniform draw below ``epsilon`` takes the randomly sampled
     command; otherwise the greedy (max-score) command is taken.

     Args:
         input_description: padded observation batch; ``size(0)`` is the batch.
         prev_hidden, prev_cell: recurrent state passed through to ``get_ranks``.
         epsilon: probability of picking the random command for each game.
         return_att, att_mask: forwarded unchanged to ``get_ranks``.

     Returns:
         ``(v_idx, n_idx, chosen_strings, curr_hidden, curr_cell)``. The indices
         and the new recurrent state are detached from the graph.
     """
     # batch x n_verb, batch x n_noun
     verb_rank, noun_rank, hidden, cell = self.get_ranks(
         input_description, prev_hidden, prev_cell,
         return_att=return_att, att_mask=att_mask)
     hidden, cell = hidden.detach(), cell.detach()

     _, greedy_v, _, greedy_n = self.choose_maxQ_command(verb_rank, noun_rank)
     _, random_v, _, random_n = self.choose_random_command(verb_rank, noun_rank)

     # one uniform draw per game decides between exploring and exploiting
     draws = np.random.uniform(low=0.0, high=1.0, size=(input_description.size(0),))
     explore_np = (draws < epsilon).astype("float32")  # batch
     exploit_np = 1.0 - explore_np
     explore = to_pt(explore_np, self.use_cuda, type='float').long()
     exploit = to_pt(exploit_np, self.use_cuda, type='float').long()

     # exactly one of explore/exploit is 1 per item, so this selects an index
     v_idx = (explore * random_v + exploit * greedy_v).detach()
     n_idx = (explore * random_n + exploit * greedy_n).detach()
     chosen_strings = self.get_chosen_strings(v_idx, n_idx)
     return v_idx, n_idx, chosen_strings, hidden, cell
Exemplo n.º 5
0
    def compute_reward(self, revisit_counting_lambda=0.0, revisit_counting=True):
        """Assemble rewards and masks for the most recent step of each game.

        Args:
            revisit_counting_lambda: weight applied to the latest entry of
                ``self.revisit_counting_rewards`` before it is added.
            revisit_counting: whether to add that weighted bonus at all.

        Returns:
            ``(rewards, rewards_pt, mask, mask_pt, memory_mask)``:
            - ``rewards`` and ``mask`` are float32 NumPy arrays with one entry
              per game.
            - The ``_pt`` values are their ``to_pt(..., type='float')``
              conversions.
            - ``memory_mask`` is a plain list of 1.0/0.0 floats.
        """
        dones = self.dones
        if len(dones) == 1:
            # very first step: every game is still active
            mask = [1.0 for _ in dones[-1]]
        else:
            assert len(dones) > 1
            # games already done at the previous step are masked out
            mask = [0.0 if dones[-2][i] else 1.0 for i in range(len(dones[-1]))]
        mask = np.array(mask, dtype='float32')
        mask_pt = to_pt(mask, self.use_cuda, type='float')

        # self.rewards: list of list, max_game_length x batch_size
        rewards = np.array(self.rewards[-1], dtype='float32')  # batch
        if revisit_counting and len(self.revisit_counting_rewards) > 0:
            bonus = np.array(self.revisit_counting_rewards[-1], dtype='float32')
            rewards = rewards + bonus * revisit_counting_lambda
        rewards_pt = to_pt(rewards, self.use_cuda, type='float')

        # memory mask: also keep the one step played right after a game ends
        if len(dones) < 3:
            memory_mask = [1.0 for _ in dones[-1]]
        else:
            memory_mask = []
            for i in range(len(dones[-1])):
                if mask[i] == 1 or ((not dones[-3][i]) and dones[-2][i]):
                    memory_mask.append(1.0)
                else:
                    memory_mask.append(0.0)
        return rewards, rewards_pt, mask, mask_pt, memory_mask
Exemplo n.º 6
0
    def get_game_step_info(self, obs, infos, prev_actions=None, prune=False,
                           ret_desc=False, teacher_actions=None):
        """Encode the current observation of every game as a padded id batch.

        Each game's description, inventory, objective, command feedback and
        previous action are tokenised and concatenated in that order. The result
        is pushed into ``self.observation_cache``, and the cached history is
        padded into one int32 batch converted with ``to_pt``.

        Args:
            obs: not used by this method; kept for interface compatibility.
            infos: per-game dicts providing "description", "inventory",
                "objective" and "command_feedback" strings.
            prev_actions: optional per-game previous command strings; when
                ``None`` an empty previous action is used for every game.
            prune: if True, each field is passed through
                ``self.bs_obj.prune_state`` together with ``teacher_actions[k]``
                (so ``teacher_actions`` must then be provided).
            ret_desc: if True, also return the concatenated token lists and the
                unpruned objective strings.
            teacher_actions: per-game actions used only when ``prune`` is True.

        Returns:
            ``(input_description, description_with_history_id_list)``, followed
            by ``description_str_list`` and ``orig_quest_string`` when ``ret_desc``.
        """
        def field_strings(key, **prune_kwargs):
            # Raw text of one info field per game, optionally pruned.
            if prune:
                return [self.bs_obj.prune_state(info[key], teacher_actions[k], **prune_kwargs)
                        for k, info in enumerate(infos)]
            return [info[key] for info in infos]

        def to_ids(token_lists):
            # Map each game's token list to vocabulary ids.
            return [_words_to_ids(tokens, self.word2id) for tokens in token_lists]

        # concat d/i/q/f together as one string; fields are processed in the
        # original order: inventory, feedback, objective, previous action, description
        inventory_token_list = [preproc(item, str_type='inventory', lower_case=True)
                                for item in field_strings("inventory", add_prefix=False)]
        inventory_id_list = to_ids(inventory_token_list)

        feedback_token_list = [preproc(item, str_type='feedback', lower_case=True)
                               for item in field_strings("command_feedback", add_prefix=False)]
        feedback_id_list = to_ids(feedback_token_list)

        orig_quest_string = [info["objective"] for info in infos]
        quest_token_list = [preproc(item, str_type='None', lower_case=True)
                            for item in field_strings("objective", add_prefix=False)]
        quest_id_list = to_ids(quest_token_list)

        if prev_actions is not None:
            prev_action_token_list = [preproc(item, str_type='None', lower_case=True) for item in prev_actions]
            prev_action_id_list = to_ids(prev_action_token_list)
        else:
            prev_action_token_list = [[] for _ in infos]
            prev_action_id_list = [[] for _ in infos]

        # the description is pruned with prune_state's default keyword arguments
        description_token_list = [preproc(item, str_type='description', lower_case=True)
                                  for item in field_strings("description")]
        # hack: an empty description is replaced by the single word "end"
        description_token_list = [d if len(d) > 0 else ["end"] for d in description_token_list]
        description_id_list = to_ids(description_token_list)

        description_id_list = [_d + _i + _q + _f + _pa for (_d, _i, _q, _f, _pa) in
                               zip(description_id_list, inventory_id_list, quest_id_list,
                                   feedback_id_list, prev_action_id_list)]
        description_str_list = [_d + _i + _q + _f + _pa for (_d, _i, _q, _f, _pa) in
                                zip(description_token_list, inventory_token_list,
                                    quest_token_list, feedback_token_list, prev_action_token_list)]

        self.observation_cache.push(description_id_list)
        description_with_history_id_list = self.observation_cache.get_all()
        input_description = pad_sequences(description_with_history_id_list,
                                          maxlen=max_len(description_with_history_id_list),
                                          padding='post').astype('int32')
        input_description = to_pt(input_description, self.use_cuda)
        if ret_desc:
            return input_description, description_with_history_id_list, description_str_list, orig_quest_string
        return input_description, description_with_history_id_list