Example #1
    # Requires: torch, Tuple/Dict/Any from typing, and ModelOutput from transformers.
    @staticmethod
    def _expand_inputs_for_generation(
            input_ids: torch.LongTensor,
            expand_size: int = 1,
            is_encoder_decoder: bool = False,
            attention_mask: torch.LongTensor = None,
            encoder_outputs: ModelOutput = None,
            **model_kwargs) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        # Build indices [0, 0, ..., 1, 1, ...] so each example is repeated
        # expand_size times (e.g. once per beam / returned sequence).
        expanded_return_idx = (torch.arange(input_ids.shape[0]).view(
            -1, 1).repeat(1, expand_size).view(-1).to(input_ids.device))
        input_ids = input_ids.index_select(0, expanded_return_idx)

        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.index_select(
                0, expanded_return_idx)

        if model_kwargs["token_type_ids"] is not None:
            model_kwargs["token_type_ids"] = model_kwargs[
                "token_type_ids"].index_select(0, expanded_return_idx)

        if is_encoder_decoder:
            assert encoder_outputs is not None
            encoder_outputs[
                "last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
                    0, expanded_return_idx)
            model_kwargs["encoder_outputs"] = encoder_outputs

        return input_ids, model_kwargs
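The heart of the snippet is the index-expansion trick: every row of input_ids (and of the matching tensors in model_kwargs) is repeated expand_size times, e.g. once per beam. A minimal, self-contained sketch of just that trick with toy values (no class or transformers code assumed):

import torch

input_ids = torch.tensor([[101, 7592], [101, 2088]])  # [batch_size=2, seq_len=2]
expand_size = 3                                        # e.g. the number of beams

# [0, 0, 0, 1, 1, 1] -- each batch index repeated expand_size times
expanded_idx = (torch.arange(input_ids.shape[0])
                .view(-1, 1).repeat(1, expand_size).view(-1))
expanded = input_ids.index_select(0, expanded_idx)
print(expanded.shape)  # torch.Size([6, 2]); every original row now appears three times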
Example #2
File: utils.py  Project: neulab/lrlm
    def forward(self, input: Tensor, target: LongTensor) -> Tensor:  # type: ignore
        """
        hidden :: [len*bsz x d_proj]
        target :: [len*bsz]
        """
        input_shape = input.size()
        input = input.contiguous().view(-1, input_shape[-1])
        target = target.contiguous().view(-1)

        if input.size(0) != target.size(0):
            raise RuntimeError('Input and target should have the same size '
                               'in the batch dimension.')

        if self.n_clusters == 0:
            logits = self._compute_logits(input, self.out_layers[0].weight,
                                          self.out_layers[0].bias, self.out_projs[0])
            nll = F.nll_loss(logits, target, reduction='none')
        else:
            weights, biases = self._construct_weights()

            head_weight, head_bias = weights[0], biases[0]
            head_proj = self.out_projs[0] if len(self.out_projs) > 0 else None

            head_logits = self._compute_logits(input, head_weight, head_bias, head_proj)
            head_log_probs = F.log_softmax(head_logits, dim=1)

            # Indices of the targets that fall into each tail cluster [l, r); squeeze(-1)
            # keeps each index tensor 1-D even when a cluster contains a single target.
            nonzero_indices: List[torch.LongTensor] = [
                ((target >= l) & (target < r)).nonzero().squeeze(-1)
                for l, r in zip(self.cutoffs[:-1], self.cutoffs[1:])
            ]
            # Replace each tail-cluster target with its cluster's placeholder id in the head softmax.
            head_indices: LongTensor = target.clone()
            for idx, indices in enumerate(nonzero_indices):
                if indices.numel() == 0:
                    continue
                index = self.shortlist_size + self.n_clusters - 1 - idx
                head_indices.index_fill_(0, indices, index)

            head_nll = F.nll_loss(head_log_probs, head_indices, reduction='none')

            # Compute each tail cluster's NLL and scatter it back into the full-batch vector.
            for idx, indices in enumerate(nonzero_indices):
                if indices.numel() == 0:
                    continue

                weight_i, bias_i = weights[idx + 1], biases[idx + 1]
                proj_i = self.out_projs[idx + 1] if len(self.out_projs) > idx + 1 else None

                cluster_hidden = input.index_select(0, indices)
                cluster_target = target.index_select(0, indices) - self.cutoffs[idx]

                cluster_logits = self._compute_logits(cluster_hidden, weight_i, bias_i, proj_i)
                cluster_nll = F.cross_entropy(cluster_logits, cluster_target, reduction='none')

                tail_nll = torch.zeros_like(head_nll)
                tail_nll.index_copy_(0, indices, cluster_nll)
                head_nll = head_nll + tail_nll

            nll = head_nll

        nll = nll.view(input_shape[:-1])
        return nll
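The least obvious step above is rerouting tail-cluster targets onto placeholder ids inside the head softmax. A toy, self-contained sketch of just that mapping; the cutoff values, shortlist_size and n_clusters below are made up for illustration and are not taken from the project:

import torch

# Assumed toy vocabulary split: ids [0, 5) are the shortlist, [5, 8) and [8, 12) are tail clusters.
cutoffs = [5, 8, 12]
shortlist_size, n_clusters = 5, 2

target = torch.tensor([1, 7, 9, 3])
head_indices = target.clone()
for idx, (l, r) in enumerate(zip(cutoffs[:-1], cutoffs[1:])):
    indices = ((target >= l) & (target < r)).nonzero().squeeze(-1)
    if indices.numel() == 0:
        continue
    # Targets in tail cluster `idx` are replaced by that cluster's placeholder id in the head.
    head_indices.index_fill_(0, indices, shortlist_size + n_clusters - 1 - idx)

print(head_indices)  # tensor([1, 6, 5, 3]) -- shortlist ids kept, tail ids remapped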
Example #3
    def forward(self, input_tokens: torch.LongTensor, input_lengths: List[int],
                init_hidden: Tuple[torch.Tensor, torch.Tensor], encoded_commands: torch.Tensor,
                commands_lengths: List[int], encoded_situations: torch.Tensor) -> Tuple[torch.Tensor, List[int],
                                                                                        torch.Tensor]:
        """
        Run batch attention decoder forward for a series of steps
         Each decoder step considers all of the encoder_outputs through attention.
         Attention retrieval is based on decoder hidden state (not cell state)

        :param input_tokens: [batch_size, max_length];  padded target sequences
        :param input_lengths: [batch_size] for sequence length of each padded target sequence
        :param init_hidden: tuple of tensors [num_layers, batch_size, hidden_size] (for hidden and cell)
        :param encoded_commands: [max_input_length, batch_size, embedding_dim]
        :param commands_lengths: [batch_size] sequence length of each encoder sequence (without padding)
        :param encoded_situations: [batch_size, image_width * image_width, image_features]; encoded image situations.
        :return: output : unnormalized log-score, [max_length, batch_size, output_size]
          hidden : current decoder state, tuple with each [num_layers, batch_size, hidden_size] (for hidden and cell)
        """
        batch_size, max_time = input_tokens.size()

        # Sort the sequences by length in descending order (`device` is assumed to be a module-level global)
        input_lengths = torch.tensor(input_lengths, dtype=torch.long, device=device)
        input_lengths, perm_idx = torch.sort(input_lengths, descending=True)
        input_tokens_sorted = input_tokens.index_select(dim=0, index=perm_idx)
        initial_h, initial_c = init_hidden
        hidden = (initial_h.index_select(dim=1, index=perm_idx),
                  initial_c.index_select(dim=1, index=perm_idx))
        encoded_commands = encoded_commands.index_select(dim=1, index=perm_idx)
        commands_lengths = torch.tensor(commands_lengths, device=device)
        commands_lengths = commands_lengths.index_select(dim=0, index=perm_idx)
        encoded_situations = encoded_situations.index_select(dim=0, index=perm_idx)

        # For efficiency, project the attention keys once, outside the time loop
        projected_keys_visual = self.visual_attention.key_layer(
            encoded_situations)  # [batch_size, situation_length, dec_hidden_dim]
        projected_keys_textual = self.textual_attention.key_layer(
            encoded_commands)  # [max_input_length, batch_size, dec_hidden_dim]

        all_attention_weights = []
        lstm_output = []
        for time in range(max_time):
            input_token = input_tokens_sorted[:, time]
            (output, hidden, context_situation, attention_weights_commands,
             attention_weights_situations) = self.forward_step(input_token, hidden, projected_keys_textual,
                                                               commands_lengths,
                                                               projected_keys_visual)
            all_attention_weights.append(attention_weights_situations.unsqueeze(0))
            lstm_output.append(output.unsqueeze(0))
        lstm_output = torch.cat(lstm_output, dim=0)  # [max_time, batch_size, output_size]
        attention_weights = torch.cat(all_attention_weights, dim=0)  # [max_time, batch_size, situation_dim**2]

        # Reverse the sorting
        _, unperm_idx = perm_idx.sort(0)
        lstm_output = lstm_output.index_select(dim=1, index=unperm_idx)  # [max_time, batch_size, output_size]
        seq_len = input_lengths[unperm_idx].tolist()
        attention_weights = attention_weights.index_select(dim=1, index=unperm_idx)

        return lstm_output, seq_len, attention_weights.sum(dim=0)
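A detail worth isolating from this decoder is the sort-by-length / restore-order bookkeeping around perm_idx and unperm_idx (applied along the batch dimension, which is dim=1 for the time-major LSTM outputs above). A minimal sketch with toy tensors, assuming nothing beyond torch:

import torch

lengths = torch.tensor([3, 5, 2])
batch = torch.tensor([[10], [20], [30]])        # [batch_size, ...], original order

sorted_lengths, perm_idx = torch.sort(lengths, descending=True)
sorted_batch = batch.index_select(0, perm_idx)  # reordered, e.g. for packed RNN input

# ... the length-sorted batch would be processed here ...

_, unperm_idx = perm_idx.sort(0)                # inverse permutation
restored = sorted_batch.index_select(0, unperm_idx)
assert torch.equal(restored, batch)             # back in the original batch order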