Code example #1
 def mask(self, tensor):
     _, half_dim, dim = tensor.size()
     if self.onnx_trace:
         # triu and tril are not supported in onnx
         a = torch._dim_arange(tensor, 2).unsqueeze(0).repeat(half_dim, 1)
         b = torch._dim_arange(tensor, 1).unsqueeze(1).repeat(1, dim)
         mask = (a > b + half_dim).float() + (a < b).float()
         mask = torch.where(mask > 0,
                            torch.Tensor([0]).type_as(tensor),
                            torch.Tensor([float("-Inf")]).type_as(tensor))
     else:
         ones = tensor.new_ones(half_dim, dim).bool()
         mask = ones.triu(half_dim + 1) + ones.tril(-1)
         mask = utils.fill_with_neg_inf(tensor.new(
             mask.size())).masked_fill_(mask, 0)
     return mask
Code example #2
File: utils.py  Project: fyabc/fairseq
def make_positions(tensor, padding_idx, left_pad, onnx_trace=False):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """
    if onnx_trace:
        range_buf = torch._dim_arange(like=tensor, dim=1) + padding_idx + 1
        mask = tensor.ne(padding_idx)
        positions = range_buf.expand_as(tensor)
        if left_pad:
            positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
        return positions * mask.long() + padding_idx * (1 - mask.long())

    max_pos = padding_idx + 1 + tensor.size(1)
    if not hasattr(make_positions, 'range_buf'):
        make_positions.range_buf = tensor.new()
    make_positions.range_buf = make_positions.range_buf.type_as(tensor)
    if make_positions.range_buf.numel() < max_pos:
        torch.arange(padding_idx + 1, max_pos, out=make_positions.range_buf)
    mask = tensor.ne(padding_idx)
    positions = make_positions.range_buf[:tensor.size(1)].expand_as(tensor)
    if left_pad:
        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
    return tensor.clone().masked_scatter_(mask, positions[mask])
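
As a quick check of what the non-ONNX path above produces, here is a minimal usage sketch (an illustration only, assuming the make_positions function above is in scope and that padding_idx is 1); the expected output in the comments follows directly from the masked_scatter_ at the end.

import torch

# Toy batch, right-padded with padding_idx = 1:
#   row 0 has 4 real tokens, row 1 has 3.
tokens = torch.tensor([[5, 6, 7, 8, 1],
                       [9, 3, 4, 1, 1]])

positions = make_positions(tokens, padding_idx=1, left_pad=False)
# Position numbers start at padding_idx + 1 = 2 and padding slots keep
# their original value, so the expected result is:
# tensor([[2, 3, 4, 5, 1],
#         [2, 3, 4, 1, 1]])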
Code example #3
def make_positions(tensor, padding_idx, left_pad, onnx_trace=False):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """
    if onnx_trace:
        range_buf = torch._dim_arange(like=tensor, dim=1) + padding_idx + 1
        mask = tensor.ne(padding_idx)
        positions = range_buf.expand_as(tensor)
        if left_pad:
            positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
        return positions * mask.long() + padding_idx * (1 - mask.long())

    max_pos = padding_idx + 1 + tensor.size(1)
    # NOTE: with the hasattr check commented out, range_buf is re-created on
    # every call, so the caching below is effectively disabled.
    #if not hasattr(make_positions, 'range_buf'):
    make_positions.range_buf = tensor.new()
    make_positions.range_buf = make_positions.range_buf.type_as(tensor)
    if make_positions.range_buf.numel() < max_pos:
        torch.arange(padding_idx + 1, max_pos, out=make_positions.range_buf)
    mask = tensor.ne(padding_idx)
    positions = make_positions.range_buf[:tensor.size(1)].expand_as(tensor)
    if left_pad:
        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
    return tensor.clone().masked_scatter_(mask, positions[mask])
Code example #4
 def buffered_past_mask(self, tensor):
     dim = tensor.size(0)
     if self.onnx_trace:
         a = torch._dim_arange(tensor, 0).unsqueeze(0).repeat(dim, 1)
         b = torch._dim_arange(tensor, 0).unsqueeze(1).repeat(1, dim)
         past_mask = a < b
         past_mask_neg_inf = torch.where(past_mask,
                                         torch.Tensor([float("-Inf")]),
                                         torch.Tensor([0])).type_as(tensor)
         return past_mask_neg_inf
     if not hasattr(
             self, '_past_mask'
     ) or self._past_mask is None or self._past_mask.device != tensor.device:
         self._past_mask = torch.tril(
             utils.fill_with_neg_inf(tensor.new(dim, dim)), -1)
     if self._past_mask.size(0) < dim:
         self._past_mask = torch.tril(
             utils.fill_with_neg_inf(self._past_mask.resize_(dim, dim)), -1)
     return self._past_mask[:dim, :dim]
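
For intuition, this is the mask both branches above build for dim = 3, sketched here with plain torch ops instead of fairseq's utils.fill_with_neg_inf (assumed to simply fill a tensor with -inf): past positions, strictly below the diagonal, are blocked with -inf and everything else stays 0.

import torch

dim = 3
# Eager-mode equivalent of the cached (non-ONNX) branch above.
past_mask = torch.tril(torch.full((dim, dim), float("-inf")), -1)
# tensor([[0., 0., 0.],
#         [-inf, 0., 0.],
#         [-inf, -inf, 0.]])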
Code example #5
    def _with_sentence_boundaries(
            self,
            input: torch.Tensor):
        """
        Args:
            input: the sentence Tensor
            its shape is bs * seq_len * num_chars for char inputs and bs * seq_len for token inputs

        Returns:
            tuple,
                1) processed input,
                2) tensor mask for the eos position of each sentence,
                    None if eos was not added
        """
        if not self.add_bos and not self.add_eos:
            return input, None

        zero_block = input.new(0, 0)
        block_size = (input.size(0), 1, input.size(2)) if self.char_inputs else (input.size(0), 1)
        bos_block = torch.full(block_size, self.eos_idx).type_as(input) if self.add_bos else zero_block
        pad_block = torch.full(block_size, self.padding_idx).type_as(input) if self.add_eos else zero_block

        # add eos in the beginning and pad to the end of the sentence
        input = torch.cat([bos_block, input, pad_block], dim=1)

        first_pads = None  # if not add_eos, then first_pads is not valid, set to None
        if self.add_eos:
            index_block = input[:, :, 0] if self.char_inputs else input
            padding_mask = index_block.eq(self.padding_idx)
            num_pads = padding_mask.long().sum(dim=1, keepdim=True)
            max_len = input.size(1)

            # index of the first pad
            if self.onnx_trace:
                first_pads = torch._dim_arange(input, 1).type_as(input).view(1, -1).\
                    repeat(input.size(0), 1).eq(max_len - num_pads)
                eos_indices = first_pads
                if self.char_inputs:
                    eos_indices = eos_indices.unsqueeze(2).repeat(1, 1, input.size(-1))
                input = torch.where(eos_indices, torch.Tensor([self.eos_idx]).type_as(input), input)
            else:
                first_pads = buffered_arange(max_len).type_as(input).view(1, -1).\
                    expand(input.size(0), -1).eq(max_len - num_pads)
                eos_indices = first_pads
                if self.char_inputs:
                    eos_indices = eos_indices.unsqueeze(2).expand_as(input)
                input[eos_indices] = self.eos_idx

        return input, first_pads
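
A small worked sketch of what the eager (non-ONNX) branch above does to a single right-padded sentence, using made-up indices eos_idx = 2 and padding_idx = 1 and plain torch.arange in place of buffered_arange: a column of eos_idx is prepended as the bos block, one pad column is appended, and the first pad position is then overwritten with eos.

import torch

eos_idx, padding_idx = 2, 1                       # made-up indices
tokens = torch.tensor([[5, 6, 7, 1]])             # one sentence, right-padded

# prepend a "bos" column (filled with eos_idx) and append one pad column
x = torch.cat([torch.full((1, 1), eos_idx, dtype=torch.long), tokens,
               torch.full((1, 1), padding_idx, dtype=torch.long)], dim=1)
# x is [[2, 5, 6, 7, 1, 1]]

num_pads = x.eq(padding_idx).long().sum(dim=1, keepdim=True)   # [[2]]
max_len = x.size(1)                                            # 6
first_pads = torch.arange(max_len).view(1, -1).eq(max_len - num_pads)
x[first_pads] = eos_idx
# x is now [[2, 5, 6, 7, 2, 1]]: eos sits right after the last real token,
# which is what the masked assignment in the method above computes.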
Code example #6
File: positional.py  Project: shlomota/newscaptioning
def make_positions(X, padding_idx, left_pad, onnx_trace=False):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """
    max_seq_len = X.shape[1]
    # torch._dim_arange is a temporary hack to allow tracing of arange like
    # constructs with dynamic bounds on arange.  Normal arange is not traceable
    # because it does not take any tensor inputs; if the range you need is
    # based on another tensor, calling this function directly will preserve
    # tracing.  Get rid of this when arange can directly take tensors for
    # bounds (so that it can be traced directly).
    if onnx_trace:
        range_buf = torch._dim_arange(like=X, dim=1) + padding_idx + 1
        mask = X.ne(padding_idx)
        positions = range_buf.expand_as(X)
        if left_pad:
            offsets = max_seq_len - mask.long().sum(dim=1).unsqueeze(1)
            positions = positions - offsets
        return positions * mask.long() + padding_idx * (1 - mask.long())

    max_pos = padding_idx + 1 + X.size(1)

    # Function attributes are used for caching
    if not hasattr(make_positions, 'range_buf'):
        make_positions.range_buf = X.new()
    make_positions.range_buf = make_positions.range_buf.type_as(X)
    if make_positions.range_buf.numel() < max_pos:
        torch.arange(padding_idx + 1, max_pos, out=make_positions.range_buf)
    mask = X.ne(padding_idx)
    positions = make_positions.range_buf[:X.size(1)].expand_as(X)
    if left_pad:
        offsets = max_seq_len - mask.long().sum(dim=1).unsqueeze(1)
        positions = positions - offsets
    return X.clone().masked_scatter_(mask, positions[mask])
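
The long comment above is what motivates every torch._dim_arange call on this page. As a rough illustration only: in eager mode the call reduces to an ordinary arange over the given dimension's size; _dim_arange itself is a private helper from older PyTorch releases and may be missing from current ones.

import torch

X = torch.zeros(2, 5)

# Eager-mode equivalent of torch._dim_arange(like=X, dim=1): a range over
# the size of dimension 1, i.e. tensor([0, 1, 2, 3, 4]).
eager_range = torch.arange(X.size(1))

# Under tracing, _dim_arange kept the dependence on X's shape visible to the
# tracer instead of baking the upper bound in as a constant.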
Code example #7
 def forward(self, input):
     return torch._dim_arange(input, 1)
Code example #8
    def forward(self, input_token, target_token, timestep, *inputs):
        """
        Decoder step inputs correspond one-to-one to encoder outputs.
        """
        log_probs_per_model = []
        state_outputs = []

        next_state_input = len(self.models)

        # underlying assumption is each model has same vocab_reduction_module
        vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
        if vocab_reduction_module is not None:
            possible_translation_tokens = inputs[len(self.models)]
            next_state_input += 1
        else:
            possible_translation_tokens = None

        for i, model in enumerate(self.models):
            encoder_output = inputs[i]
            prev_hiddens = []
            prev_cells = []

            for _ in range(len(model.decoder.layers)):
                prev_hiddens.append(inputs[next_state_input])
                prev_cells.append(inputs[next_state_input + 1])
                next_state_input += 2
            prev_input_feed = inputs[next_state_input].view(1, -1)
            next_state_input += 1

            # no batching, we only care about the "max" length
            src_length_int = int(encoder_output.size()[0])
            src_length = torch.LongTensor(np.array([src_length_int]))

            # notional, not actually used for decoder computation
            src_tokens = torch.LongTensor(np.array([[0] * src_length_int]))
            src_embeddings = encoder_output.new_zeros(encoder_output.shape)

            encoder_out = (
                encoder_output,
                prev_hiddens,
                prev_cells,
                src_length,
                src_tokens,
                src_embeddings,
            )

            # store cached states, use evaluation mode
            model.decoder._is_incremental_eval = True
            model.eval()

            # placeholder
            incremental_state = {}

            # cache previous state inputs
            utils.set_incremental_state(
                model.decoder,
                incremental_state,
                "cached_state",
                (prev_hiddens, prev_cells, prev_input_feed),
            )

            decoder_output = model.decoder(
                input_token.view(1, 1),
                encoder_out,
                incremental_state=incremental_state,
                possible_translation_tokens=possible_translation_tokens,
            )
            logits, _, _ = decoder_output

            log_probs = F.log_softmax(logits, dim=2)

            log_probs_per_model.append(log_probs)

            (next_hiddens, next_cells,
             next_input_feed) = utils.get_incremental_state(
                 model.decoder, incremental_state, "cached_state")

            for h, c in zip(next_hiddens, next_cells):
                state_outputs.extend([h, c])
            state_outputs.append(next_input_feed)

        average_log_probs = torch.mean(torch.cat(log_probs_per_model, dim=0),
                                       dim=0,
                                       keepdim=True)

        if possible_translation_tokens is not None:
            reduced_indices = torch.zeros(self.vocab_size).long().fill_(
                self.unk_token)
            # ONNX-exportable arange (ATen op)
            possible_translation_token_range = torch._dim_arange(
                like=possible_translation_tokens, dim=0)
            reduced_indices[
                possible_translation_tokens] = possible_translation_token_range
            reduced_index = reduced_indices.index_select(dim=0,
                                                         index=target_token)
            score = average_log_probs.view(
                (-1, )).index_select(dim=0, index=reduced_index)
        else:
            score = average_log_probs.view(
                (-1, )).index_select(dim=0, index=target_token)

        word_reward = self.word_rewards.index_select(0, target_token)
        score += word_reward

        self.input_names = ["prev_token", "target_token", "timestep"]
        for i in range(len(self.models)):
            self.input_names.append(f"fixed_input_{i}")

        if possible_translation_tokens is not None:
            self.input_names.append("possible_translation_tokens")

        outputs = [score]
        self.output_names = ["score"]

        for i in range(len(self.models)):
            self.output_names.append(f"fixed_input_{i}")
            outputs.append(inputs[i])

        if possible_translation_tokens is not None:
            self.output_names.append("possible_translation_tokens")
            outputs.append(possible_translation_tokens)

        for i, state in enumerate(state_outputs):
            outputs.append(state)
            self.output_names.append(f"state_output_{i}")
            self.input_names.append(f"state_input_{i}")

        return tuple(outputs)
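
To make the vocabulary-reduction lookup near the end of the branch above concrete, here is a tiny worked sketch with made-up numbers, using torch.arange in place of torch._dim_arange: the full-vocab target id is translated into its position inside the reduced output distribution, and ids outside the reduced set fall back to unk.

import torch

vocab_size, unk_token = 12, 3                            # made-up sizes
possible_translation_tokens = torch.tensor([7, 2, 9])    # reduced output vocab
target_token = torch.tensor([9])

reduced_indices = torch.zeros(vocab_size).long().fill_(unk_token)
reduced_indices[possible_translation_tokens] = torch.arange(
    possible_translation_tokens.size(0))
reduced_index = reduced_indices.index_select(dim=0, index=target_token)
# reduced_index is tensor([2]): token id 9 sits at position 2 of the reduced
# softmax output, so that position's log-probability is used as the score.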
Code example #9
 def forward(self, x):
     return torch._dim_arange(x, 1)
Code example #10
def reverse_sort(x: torch.Tensor, dim: int) -> torch.Tensor:
    # x holds permutation indices along `dim` (e.g. the output of torch.argsort);
    # scattering each position's own index to the place it was moved to yields
    # the inverse permutation.
    new_indices = torch.empty_like(x)
    new_size = [1] * dim + [x.size(dim)] + [1] * (x.ndimension() - dim - 1)
    arange = torch._dim_arange(x, dim=dim).reshape(new_size).expand_as(x)
    new_indices.scatter_(dim=dim, index=x, src=arange)
    return new_indices
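
A small worked example of the scatter above, written for the 1-D case with plain torch.arange standing in for torch._dim_arange (a private helper that may not exist in newer PyTorch): the function returns the inverse of a permutation such as an argsort result.

import torch

# A permutation along dim 0, e.g. indices returned by torch.argsort.
perm = torch.tensor([2, 0, 3, 1])

inverse = torch.empty_like(perm)
inverse.scatter_(dim=0, index=perm, src=torch.arange(perm.numel()))
# inverse is tensor([1, 3, 0, 2]); indexing with it undoes the permutation:
# perm[inverse] == tensor([0, 1, 2, 3])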