Example #1
0
    def command_generation_greedy_generation(self, observation_feats,
                                             task_desc_strings,
                                             previous_dynamics):
        """Greedily decode one command string per batch item.

        The visual features and the mean-pooled task description encoding
        are concatenated into a single context, which the online network's
        ``vision_decode`` attends to while tokens are generated one step at
        a time (argmax at every step). Runs entirely under
        ``torch.no_grad()``.

        Args:
            observation_feats: per-item visual features; its length defines
                the batch size.
            task_desc_strings: task descriptions, passed to ``self.encode``.
            previous_dynamics: recurrent state from the previous step, or
                ``None`` to call the RNN cell without a prior state.

        Returns:
            A tuple ``(commands, current_dynamics)``: the decoded strings
            with "[CLS]"/"[SEP]" removed, and the new recurrent state
            (``None`` when ``self.recurrent`` is false).
        """
        with torch.no_grad():
            batch_size = len(observation_feats)

            # Build the context: visual slots followed by one pooled
            # task-description slot (original note: batch x k boxes x hid).
            obs_feat = self.aggregate_feats_seq(observation_feats)
            h_obs = self.online_net.vision_fc(obs_feat)
            h_td, td_mask = self.encode(task_desc_strings, use_model="online")
            h_td_mean = self.online_net.masked_mean(h_td, td_mask).unsqueeze(1)
            device = h_td_mean.device
            h_obs = h_obs.to(device)
            vision_td = torch.cat((h_obs, h_td_mean), dim=1)
            n_slots = h_obs.shape[1] + h_td_mean.shape[1]
            # Every context slot is marked valid.
            vision_td_mask = torch.ones((batch_size, n_slots)).to(device)

            current_dynamics = None
            if self.recurrent:
                pooled_context = self.online_net.masked_mean(
                    vision_td, vision_td_mask)
                if previous_dynamics is not None:
                    current_dynamics = self.online_net.rnncell(
                        pooled_context, previous_dynamics)
                else:
                    current_dynamics = self.online_net.rnncell(pooled_context)

            # Greedy generation: every sequence starts with [CLS] and stops
            # growing once it has emitted [SEP].
            generated = [[self.word2id["[CLS]"]] for _ in range(batch_size)]
            finished = np.zeros(batch_size)
            for _step in range(self.max_target_length):
                # Work on a copy so the helpers never see the live lists.
                step_tokens = copy.deepcopy(generated)
                width = max_len(step_tokens)
                step_input = pad_sequences(step_tokens,
                                           maxlen=width).astype('int32')
                step_input = to_pt(step_input, self.use_cuda)
                step_mask = compute_mask(step_input)
                # original note: batch x target_length x vocab
                logits = self.online_net.vision_decode(
                    step_input, step_mask, vision_td, vision_td_mask,
                    current_dynamics)
                # Only the prediction at the last position matters.
                next_tokens = np.argmax(to_np(logits[:, -1]), -1)  # batch

                for b in range(batch_size):
                    token = next_tokens[b]
                    if finished[b] == 0:
                        generated[b].append(token)
                    if token == self.word2id["[SEP]"]:
                        finished[b] = 1
                if np.sum(finished) == batch_size:
                    break

            commands = []
            for token_ids in generated:
                text = self.tokenizer.decode(token_ids)
                text = text.replace("[CLS]", "").replace("[SEP]", "").strip()
                commands.append(text.replace(" in / on ", " in/on "))
            return commands, current_dynamics
Example #2
0
    def get_action_candidate_representations(self,
                                             action_candidate_list,
                                             use_model="online"):
        """Encode every action candidate into a single pooled vector.

        Candidates from all batch items are flattened and pushed through
        ``self.encode_text`` in chunks of 64, so a data point with very many
        candidates does not create one oversized encoder batch. The pooled
        vectors are then written back to their (batch item, candidate)
        position.

        Args:
            action_candidate_list: one list of candidates per batch item;
                the inner lists may differ in length.
            use_model: forwarded unchanged to ``self.encode_text``.

        Returns:
            A tuple ``(representations, mask)`` of shapes
            batch x max_num_candidate x block_hidden_dim and
            batch x max_num_candidate. Positions with no candidate stay
            zero in both tensors.
        """
        n_items = len(action_candidate_list)
        # presumably the length of the longest inner list -- max_len is
        # defined elsewhere
        widest = max_len(action_candidate_list)
        res_representations = torch.zeros(n_items, widest,
                                          self.online_net.block_hidden_dim)
        res_mask = torch.zeros(n_items, widest)
        if self.use_cuda:
            res_representations = res_representations.cuda()
            res_mask = res_mask.cuda()

        # Flatten the candidates, remembering where each one came from.
        flat_candidates, origins = [], []
        for b, candidates in enumerate(action_candidate_list):
            flat_candidates += candidates
            origins.extend((b, c) for c in range(len(candidates)))

        chunk_size = 64
        for start in range(0, len(flat_candidates), chunk_size):
            stop = start + chunk_size
            chunk = flat_candidates[start:stop]
            chunk_origins = origins[start:stop]

            # chunk x num_word x hid, chunk x num_word
            word_reprs, word_mask = self.encode_text(chunk,
                                                     use_model=use_model)

            # Mean over the num_word dimension. The sum covers every word
            # position and the mask only supplies the divisor.
            # NOTE(review): this assumes encode_text already zeroes padded
            # positions -- confirm.
            word_counts = torch.sum(word_mask, -1)  # chunk
            summed = torch.sum(word_reprs, -2)  # chunk x hid
            # Bump zero counts to one so empty candidates don't divide by 0.
            is_empty = torch.eq(word_counts, 0).float()
            if summed.is_cuda:
                is_empty = is_empty.cuda()
            word_counts = word_counts + is_empty
            pooled = summed / word_counts.unsqueeze(-1)  # chunk x hid
            # A candidate is valid if at least one of its words is unmasked.
            has_words = word_mask.byte().any(-1).float()  # chunk

            for row, (b, c) in enumerate(chunk_origins):
                res_representations[b, c, :] = pooled[row]
                res_mask[b, c] = has_words[row]

        return res_representations, res_mask
Example #3
0
 def get_word_input_from_ids(self, word_id_list):
     """Pad lists of word ids to a common width and convert them via ``to_pt``.

     The width is ``max_len(word_id_list) + 3``; the source attributes the
     extra 3 to ``layer.DepthwiseSeparableConv.padding``.

     Args:
         word_id_list: one list of word ids per sequence.

     Returns:
         The result of ``to_pt`` applied to the padded 'int32' ids, using
         ``self.use_cuda``.
     """
     # 3 --> see layer.DepthwiseSeparableConv.padding
     padded_width = max_len(word_id_list) + 3
     padded_ids = pad_sequences(word_id_list,
                                maxlen=padded_width,
                                dtype='int32')
     return to_pt(padded_ids, self.use_cuda)