from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import LongTensor, Tensor
# ModelOutput is importable from `transformers.utils` in recent versions of the
# transformers library (older releases expose it from `transformers.file_utils`).
from transformers.utils import ModelOutput

# Module-level device used by the attention decoder below; the exact definition
# is an assumption, following the usual "cuda if available" convention.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _expand_inputs_for_generation(
        input_ids: torch.LongTensor,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        attention_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[ModelOutput] = None,
        **model_kwargs) -> Tuple[torch.LongTensor, Dict[str, Any]]:
    # Index [0, ..., 0, 1, ..., 1, ...] that repeats every batch row
    # `expand_size` times (e.g. for beam search or num_return_sequences > 1).
    expanded_return_idx = (torch.arange(input_ids.shape[0]).view(
        -1, 1).repeat(1, expand_size).view(-1).to(input_ids.device))
    input_ids = input_ids.index_select(0, expanded_return_idx)

    if attention_mask is not None:
        model_kwargs["attention_mask"] = attention_mask.index_select(
            0, expanded_return_idx)

    # Use .get() so a missing "token_type_ids" entry does not raise a KeyError.
    if model_kwargs.get("token_type_ids") is not None:
        model_kwargs["token_type_ids"] = model_kwargs[
            "token_type_ids"].index_select(0, expanded_return_idx)

    if is_encoder_decoder:
        # Encoder-decoder models also need their encoder states expanded.
        assert encoder_outputs is not None
        encoder_outputs[
            "last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
                0, expanded_return_idx)
        model_kwargs["encoder_outputs"] = encoder_outputs

    return input_ids, model_kwargs
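# A minimal usage sketch (the helper name `_demo_expand_inputs` and the toy
# tensors are illustrative, not from the original codebase): two prompts are
# tiled for expand_size=3, so every row of `input_ids` and `attention_mask`
# appears three consecutive times.
def _demo_expand_inputs() -> None:
    input_ids = torch.tensor([[5, 6, 7], [8, 9, 10]])
    attention_mask = torch.ones_like(input_ids)
    expanded_ids, kwargs = _expand_inputs_for_generation(
        input_ids, expand_size=3, attention_mask=attention_mask)
    assert expanded_ids.shape == (6, 3)                    # 2 prompts * expand_size
    assert kwargs["attention_mask"].shape == (6, 3)
    assert torch.equal(expanded_ids[0], expanded_ids[2])   # copies of prompt 0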
def forward(self, input: Tensor, target: LongTensor) -> Tensor:  # type: ignore
    """
    input  :: [len*bsz x d_proj]
    target :: [len*bsz]
    """
    input_shape = input.size()
    input = input.contiguous().view(-1, input_shape[-1])
    target = target.contiguous().view(-1)
    if input.size(0) != target.size(0):
        raise RuntimeError('Input and target should have the same size '
                           'in the batch dimension.')

    if self.n_clusters == 0:
        # No tail clusters: a single full softmax over the whole vocabulary.
        logits = self._compute_logits(input, self.out_layers[0].weight,
                                      self.out_layers[0].bias, self.out_projs[0])
        nll = F.nll_loss(logits, target, reduction='none')
    else:
        weights, biases = self._construct_weights()

        # Head softmax over the shortlist plus one "cluster token" per tail cluster.
        head_weight, head_bias = weights[0], biases[0]
        head_proj = self.out_projs[0] if len(self.out_projs) > 0 else None
        head_logits = self._compute_logits(input, head_weight, head_bias, head_proj)
        head_log_probs = F.log_softmax(head_logits, dim=1)

        # Positions whose target falls into each tail cluster [cutoff_i, cutoff_{i+1}).
        # squeeze(-1) (rather than squeeze()) keeps the index 1-D even when a
        # cluster contains a single position.
        nonzero_indices: List[torch.LongTensor] = [
            ((target >= left) & (target < right)).nonzero().squeeze(-1)
            for left, right in zip(self.cutoffs[:-1], self.cutoffs[1:])
        ]

        # Replace tail targets with the id of their cluster token in the head.
        head_indices: LongTensor = target.clone()
        for idx, indices in enumerate(nonzero_indices):
            if indices.numel() == 0:
                continue
            index = self.shortlist_size + self.n_clusters - 1 - idx
            head_indices.index_fill_(0, indices, index)
        head_nll = F.nll_loss(head_log_probs, head_indices, reduction='none')

        # Add the within-cluster NLL for every position routed to a tail cluster.
        for idx, indices in enumerate(nonzero_indices):
            if indices.numel() == 0:
                continue
            weight_i, bias_i = weights[idx + 1], biases[idx + 1]
            proj_i = self.out_projs[idx + 1] if len(self.out_projs) > idx + 1 else None
            cluster_hidden = input.index_select(0, indices)
            cluster_target = target.index_select(0, indices) - self.cutoffs[idx]
            cluster_logits = self._compute_logits(cluster_hidden, weight_i,
                                                  bias_i, proj_i)
            cluster_nll = F.cross_entropy(cluster_logits, cluster_target,
                                          reduction='none')
            tail_nll = torch.zeros_like(head_nll)
            tail_nll.index_copy_(0, indices, cluster_nll)
            head_nll = head_nll + tail_nll
        nll = head_nll

    nll = nll.view(input_shape[:-1])
    return nll
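# A small self-contained sketch (helper name and numbers are made up for
# illustration) of the head-index bookkeeping used in forward() above: shortlist
# tokens keep their own id, while every tail token is replaced by the id of its
# cluster token in the head, following the same
# `shortlist_size + n_clusters - 1 - idx` convention.
def _demo_head_indices() -> torch.LongTensor:
    shortlist_size, n_clusters = 4, 2
    cutoffs = [shortlist_size, 7, 10]   # two tail clusters: [4, 7) and [7, 10)
    target = torch.tensor([1, 6, 8, 3, 9])
    head_indices = target.clone()
    for idx, (left, right) in enumerate(zip(cutoffs[:-1], cutoffs[1:])):
        indices = ((target >= left) & (target < right)).nonzero().squeeze(-1)
        if indices.numel() == 0:
            continue
        head_indices.index_fill_(0, indices, shortlist_size + n_clusters - 1 - idx)
    return head_indices                 # tensor([1, 5, 4, 3, 4])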
def forward(self, input_tokens: torch.LongTensor, input_lengths: List[int],
            init_hidden: Tuple[torch.Tensor, torch.Tensor],
            encoded_commands: torch.Tensor, commands_lengths: List[int],
            encoded_situations: torch.Tensor
            ) -> Tuple[torch.Tensor, List[int], torch.Tensor]:
    """Run the batch attention decoder forward for a series of steps.

    Each decoder step attends over all of the encoder outputs; attention
    retrieval is based on the decoder hidden state (not the cell state).

    :param input_tokens: [batch_size, max_length]; padded target sequences.
    :param input_lengths: [batch_size]; length of each padded target sequence.
    :param init_hidden: tuple of tensors [num_layers, batch_size, hidden_size]
        (hidden and cell state).
    :param encoded_commands: [max_input_length, batch_size, embedding_dim].
    :param commands_lengths: [batch_size]; length of each encoder sequence
        (without padding).
    :param encoded_situations: [batch_size, image_width * image_width,
        image_features]; encoded image situations.
    :return: lstm_output: unnormalized log-scores,
             [max_time, batch_size, output_size];
             seq_len: target sequence lengths, restored to the original batch
             order;
             attention_weights: situation attention weights summed over time,
             [batch_size, image_width * image_width].
    """
    batch_size, max_time = input_tokens.size()

    # Sort the sequences by length in descending order.
    input_lengths = torch.tensor(input_lengths, dtype=torch.long, device=device)
    input_lengths, perm_idx = torch.sort(input_lengths, descending=True)
    input_tokens_sorted = input_tokens.index_select(dim=0, index=perm_idx)
    initial_h, initial_c = init_hidden
    hidden = (initial_h.index_select(dim=1, index=perm_idx),
              initial_c.index_select(dim=1, index=perm_idx))
    encoded_commands = encoded_commands.index_select(dim=1, index=perm_idx)
    commands_lengths = torch.tensor(commands_lengths, device=device)
    commands_lengths = commands_lengths.index_select(dim=0, index=perm_idx)
    encoded_situations = encoded_situations.index_select(dim=0, index=perm_idx)

    # Project the encoder outputs to attention keys once, for efficiency.
    projected_keys_visual = self.visual_attention.key_layer(
        encoded_situations)  # [batch_size, situation_length, dec_hidden_dim]
    projected_keys_textual = self.textual_attention.key_layer(
        encoded_commands)  # [max_input_length, batch_size, dec_hidden_dim]

    # Unroll the decoder one time step at a time.
    all_attention_weights = []
    lstm_output = []
    for time in range(max_time):
        input_token = input_tokens_sorted[:, time]
        (output, hidden, context_situation, attention_weights_commands,
         attention_weights_situations) = self.forward_step(
             input_token, hidden, projected_keys_textual, commands_lengths,
             projected_keys_visual)
        all_attention_weights.append(attention_weights_situations.unsqueeze(0))
        lstm_output.append(output.unsqueeze(0))
    lstm_output = torch.cat(lstm_output, dim=0)  # [max_time, batch_size, output_size]
    attention_weights = torch.cat(
        all_attention_weights, dim=0)  # [max_time, batch_size, situation_dim**2]

    # Reverse the sorting to restore the original batch order.
    _, unperm_idx = perm_idx.sort(0)
    lstm_output = lstm_output.index_select(
        dim=1, index=unperm_idx)  # [max_time, batch_size, output_size]
    seq_len = input_lengths[unperm_idx].tolist()
    attention_weights = attention_weights.index_select(dim=1, index=unperm_idx)

    return lstm_output, seq_len, attention_weights.sum(dim=0)
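# A standalone sketch (toy tensors, illustrative helper name) of the
# sort / unsort pattern used in forward() above: rows are sorted by length in
# descending order before the time loop, and `unperm_idx` restores the original
# batch order afterwards (the real code applies it along dim=1 of the
# time-major outputs).
def _demo_sort_unsort() -> None:
    lengths = torch.tensor([2, 5, 3])
    data = torch.tensor([[10], [20], [30]])
    sorted_lengths, perm_idx = torch.sort(lengths, descending=True)
    sorted_data = data.index_select(dim=0, index=perm_idx)        # rows 20, 30, 10
    _, unperm_idx = perm_idx.sort(0)
    restored = sorted_data.index_select(dim=0, index=unperm_idx)  # original order
    assert torch.equal(restored, data)
    assert sorted_lengths[unperm_idx].tolist() == lengths.tolist()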