def command_generation_greedy_generation(self, observation_feats, task_desc_strings, previous_dynamics):
    with torch.no_grad():
        batch_size = len(observation_feats)

        # encode visual observations and task descriptions, then fuse them
        aggregated_obs_feat = self.aggregate_feats_seq(observation_feats)
        h_obs = self.online_net.vision_fc(aggregated_obs_feat)
        h_td, td_mask = self.encode(task_desc_strings, use_model="online")
        h_td_mean = self.online_net.masked_mean(h_td, td_mask).unsqueeze(1)
        h_obs = h_obs.to(h_td_mean.device)
        vision_td = torch.cat((h_obs, h_td_mean), dim=1)  # batch x k boxes x hid
        vision_td_mask = torch.ones(
            (batch_size, h_obs.shape[1] + h_td_mean.shape[1])).to(h_td_mean.device)

        if self.recurrent:
            averaged_vision_td_representation = self.online_net.masked_mean(
                vision_td, vision_td_mask)
            current_dynamics = self.online_net.rnncell(
                averaged_vision_td_representation, previous_dynamics
            ) if previous_dynamics is not None else self.online_net.rnncell(
                averaged_vision_td_representation)
        else:
            current_dynamics = None

        # greedy generation: start every sequence with [CLS], append the argmax
        # token each step, and stop once every example has produced [SEP]
        input_target_list = [[self.word2id["[CLS]"]] for _ in range(batch_size)]
        eos = np.zeros(batch_size)
        for _ in range(self.max_target_length):
            input_target = copy.deepcopy(input_target_list)
            input_target = pad_sequences(
                input_target, maxlen=max_len(input_target)).astype('int32')
            input_target = to_pt(input_target, self.use_cuda)
            target_mask = compute_mask(input_target)  # mask of ground truth should be the same
            pred = self.online_net.vision_decode(
                input_target, target_mask, vision_td, vision_td_mask,
                current_dynamics)  # batch x target_length x vocab
            # take the distribution over the last decoded position and pick the argmax token
            pred = to_np(pred[:, -1])  # batch x vocab
            pred = np.argmax(pred, -1)  # batch
            for b in range(batch_size):
                new_stuff = [pred[b]] if eos[b] == 0 else []
                input_target_list[b] = input_target_list[b] + new_stuff
                if pred[b] == self.word2id["[SEP]"]:
                    eos[b] = 1
            if np.sum(eos) == batch_size:
                break
        res = [self.tokenizer.decode(item) for item in input_target_list]
        res = [item.replace("[CLS]", "").replace("[SEP]", "").strip() for item in res]
        res = [item.replace(" in / on ", " in/on ") for item in res]
        return res, current_dynamics
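# The method above follows a standard batched greedy-decoding loop: keep a list of
# generated ids per example, append the argmax of the last position at every step,
# and stop once every example has emitted [SEP]. A minimal, self-contained sketch of
# that pattern follows; `step_fn`, `cls_id`, and `sep_id` are stand-ins for the real
# vision decoder and vocabulary, and numpy is assumed imported as `np` (as in the
# surrounding file). This is an illustration, not the project's decoder.
def _greedy_decode_sketch(step_fn, batch_size, cls_id, sep_id, max_steps):
    # sequences start with [CLS]; finished[b] flips to 1 once example b emits [SEP]
    sequences = [[cls_id] for _ in range(batch_size)]
    finished = np.zeros(batch_size)
    for _ in range(max_steps):
        logits = step_fn(sequences)            # batch x vocab, scores for the next token
        next_ids = np.argmax(logits, axis=-1)  # batch
        for b in range(batch_size):
            if finished[b] == 0:
                sequences[b].append(int(next_ids[b]))
            if next_ids[b] == sep_id:
                finished[b] = 1
        if finished.sum() == batch_size:
            break
    return sequences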
def get_action_candidate_representations(self, action_candidate_list, use_model="online"):
    # in case there are too many candidates in certain data points, compute their
    # representations in small batches
    batch_size = len(action_candidate_list)
    max_num_candidate = max_len(action_candidate_list)
    res_representations = torch.zeros(
        batch_size, max_num_candidate, self.online_net.block_hidden_dim)
    res_mask = torch.zeros(batch_size, max_num_candidate)
    if self.use_cuda:
        res_representations = res_representations.cuda()
        res_mask = res_mask.cuda()

    # flatten all candidates into a single list, remembering where each one came from
    squeezed_candidate_list, from_which_original_batch = [], []
    for b in range(batch_size):
        squeezed_candidate_list += action_candidate_list[b]
        for i in range(len(action_candidate_list[b])):
            from_which_original_batch.append((b, i))

    tmp_batch_size = 64
    n_tmp_batches = (len(squeezed_candidate_list) + tmp_batch_size - 1) // tmp_batch_size
    for tmp_batch_id in range(n_tmp_batches):
        tmp_batch_cand = squeezed_candidate_list[
            tmp_batch_id * tmp_batch_size:(tmp_batch_id + 1) * tmp_batch_size]  # tmp_batch of candidates
        tmp_batch_from = from_which_original_batch[
            tmp_batch_id * tmp_batch_size:(tmp_batch_id + 1) * tmp_batch_size]
        tmp_batch_cand_representation_sequence, tmp_batch_cand_mask = self.encode_text(
            tmp_batch_cand, use_model=use_model
        )  # tmp_batch x num_word x hid, tmp_batch x num_word
        # masked mean over the num_word dimension (padded positions are assumed
        # to be zeroed out by the encoder)
        _mask = torch.sum(tmp_batch_cand_mask, -1)  # tmp_batch
        tmp_batch_cand_representation = torch.sum(
            tmp_batch_cand_representation_sequence, -2)  # tmp_batch x hid
        tmp = torch.eq(_mask, 0).float()  # avoid division by zero for empty candidates
        if tmp_batch_cand_representation.is_cuda:
            tmp = tmp.cuda()
        _mask = _mask + tmp
        tmp_batch_cand_representation = tmp_batch_cand_representation / _mask.unsqueeze(-1)  # tmp_batch x hid
        tmp_batch_cand_mask = tmp_batch_cand_mask.byte().any(-1).float()  # tmp_batch
        # scatter the pooled representations back into the padded result tensors
        for i in range(len(tmp_batch_from)):
            res_representations[
                tmp_batch_from[i][0], tmp_batch_from[i][1], :] = tmp_batch_cand_representation[i]
            res_mask[tmp_batch_from[i][0], tmp_batch_from[i][1]] = tmp_batch_cand_mask[i]
    return res_representations, res_mask
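# The pooling step above is a masked mean: sum token representations, divide by the
# number of real tokens, and guard against zero-length candidates. A minimal sketch of
# that operation; unlike the loop above, the mask is applied explicitly here rather
# than relying on the encoder to zero padded positions, and torch is assumed imported
# as in the rest of the file.
def _masked_mean_sketch(representations, mask):
    # representations: batch x num_word x hid; mask: batch x num_word, 1.0 for real tokens
    summed = (representations * mask.unsqueeze(-1)).sum(dim=-2)  # batch x hid
    lengths = mask.sum(dim=-1)                                   # batch
    lengths = lengths + torch.eq(lengths, 0).float()             # guard against empty candidates
    return summed / lengths.unsqueeze(-1)                        # batch x hid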
def get_word_input_from_ids(self, word_id_list):
    input_word = pad_sequences(
        word_id_list, maxlen=max_len(word_id_list) + 3,
        dtype='int32')  # 3 --> see layer.DepthwiseSeparableConv.padding
    input_word = to_pt(input_word, self.use_cuda)
    return input_word
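# The methods in this section lean on two small project utilities, max_len and
# pad_sequences. A rough sketch of what they are assumed to do (longest list length,
# and zero right-padding to a fixed width); the actual helpers may differ in details
# such as truncation or padding side.
def _max_len_sketch(list_of_lists):
    return max(len(item) for item in list_of_lists)

def _pad_sequences_sketch(sequences, maxlen, dtype='int32', value=0):
    padded = np.full((len(sequences), maxlen), value, dtype=dtype)
    for i, seq in enumerate(sequences):
        truncated = seq[:maxlen]
        padded[i, :len(truncated)] = truncated
    return padded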