Example #1
0
    def sample_negatives(self, y):
        """Draw negative (distractor) frames for every time step of ``y``.

        Two kinds of negatives are produced:

        * ``self.n_negatives`` in-sample negatives per time step, taken from
          the first ``high`` frames of the *same* batch element, where
          ``high`` is ``T`` or ``min(T, self.sample_distance)``;
        * ``self.cross_sample_negatives`` negatives per time step, taken from
          any frame of any batch element.

        In both cases the frame that is the positive for that time step is
        never selected.

        Args:
            y: tensor of shape ``(B, C, T)`` (batch, channels, time).

        Returns:
            Tensor of shape ``(N, B, C, T)`` with
            ``N = self.n_negatives + self.cross_sample_negatives``; in-sample
            negatives come first along ``N``. When ``N == 0`` an empty tensor
            of shape ``(0, B, C, T)`` is returned.

        Raises:
            AssertionError: if fewer than two frames are available to sample
                from.
        """
        bsz, fsz, tsz = y.shape
        n_total = self.n_negatives + self.cross_sample_negatives

        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
            # Nothing to sample: keep the N x B x C x T layout with N == 0.
            return y.new_empty((0, bsz, fsz, tsz))

        y = y.transpose(0, 1)  # BCT -> CBT
        y = y.contiguous().view(fsz, -1)  # CBT => C(BxT)

        cross_high = tsz * bsz
        high = tsz if self.sample_distance is None else min(tsz, self.sample_distance)
        assert high > 1

        idx_parts = []
        with torch.no_grad():
            # Column of frame t of batch element b in the flattened C x (B*T)
            # view is b * tsz + t, so the per-row offset must use tsz (not
            # ``high``, which is smaller when sample_distance < tsz).
            batch_offsets = torch.arange(bsz).unsqueeze(1) * tsz  # B x 1

            if self.n_negatives > 0:
                # time step of every index slot, laid out as (t, n)
                tszs = (
                    torch.arange(tsz)
                    .unsqueeze(-1)
                    .expand(-1, self.n_negatives)
                    .flatten()
                )

                # Sample from high - 1 candidates and shift everything at or
                # above the positive's position by one: uniform over all
                # candidates except the positive itself.
                neg_idxs = torch.randint(
                    low=0, high=high - 1, size=(bsz, self.n_negatives * tsz)
                )
                neg_idxs[neg_idxs >= tszs] += 1
                neg_idxs = neg_idxs + batch_offsets
                idx_parts.append(neg_idxs.view(bsz, tsz, self.n_negatives))

            if self.cross_sample_negatives > 0:
                tszs = (
                    torch.arange(tsz)
                    .unsqueeze(-1)
                    .expand(-1, self.cross_sample_negatives)
                    .flatten()
                )

                cross_neg_idxs = torch.randint(
                    low=0,
                    high=cross_high - 1,
                    size=(bsz, self.cross_sample_negatives * tsz),
                )
                # Cross-sample indices live in the flattened B*T space, so the
                # position to skip is the positive's flattened index.
                own_idxs = tszs.unsqueeze(0) + batch_offsets
                cross_neg_idxs[cross_neg_idxs >= own_idxs] += 1
                idx_parts.append(
                    cross_neg_idxs.view(bsz, tsz, self.cross_sample_negatives)
                )

        # Join along the "negative" axis so that every time step keeps its own
        # in-sample and cross-sample indices together: B x T x N.
        neg_idxs = torch.cat(idx_parts, dim=2)

        negs = y[..., neg_idxs.view(-1)]
        negs = negs.view(fsz, bsz, tsz, n_total).permute(
            3, 1, 0, 2
        )  # to NxBxCxT

        return negs
Example #2
0
    def sample_negatives(self, y, num):
        """Draw negative (distractor) frames for ``num`` time steps of ``y``.

        ``self.n_negatives`` negatives per step are taken from the same batch
        element and ``self.cross_sample_negatives`` per step from any frame of
        any batch element. The frame at the step's own position is never
        selected.

        Args:
            y: tensor of shape ``(B, T, C)`` (batch, time, channels).
            num: number of time steps to sample negatives for.

        Returns:
            A pair ``(negs, neg_idxs)``:

            * ``negs`` has shape ``(N, B, num, C)`` with
              ``N = self.n_negatives + self.cross_sample_negatives``
              (in-sample negatives first along ``N``);
            * ``neg_idxs`` has shape ``(B, num * N)`` and holds the row
              indices into ``y`` viewed as ``(B * T, C)``.

            When ``N == 0`` both elements are empty tensors.

        Raises:
            AssertionError: if ``T <= 1``.
        """
        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
            # Keep the two-element return shape so callers can always unpack.
            return y.new(0), torch.empty(0, dtype=torch.long)

        bsz, tsz, fsz = y.shape
        y = y.view(-1, fsz)  # BTC => (BxT)C
        n_total = self.n_negatives + self.cross_sample_negatives

        cross_high = tsz * bsz
        high = tsz
        idx_parts = []
        with torch.no_grad():
            assert high > 1, f"{bsz,tsz,fsz}"

            # Row of frame t of batch element b in the flattened view.
            batch_offsets = torch.arange(bsz).unsqueeze(1) * tsz  # B x 1

            if self.n_negatives > 0:
                # time step of every index slot, laid out as (t, n)
                tszs = (
                    torch.arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.n_negatives)
                    .flatten()
                )

                # Sample from high - 1 candidates, then shift everything at or
                # above the step's own position by one to skip it.
                neg_idxs = torch.randint(
                    low=0, high=high - 1, size=(bsz, self.n_negatives * num)
                )
                neg_idxs[neg_idxs >= tszs] += 1
                neg_idxs = neg_idxs + batch_offsets
                idx_parts.append(neg_idxs.view(bsz, num, self.n_negatives))

            if self.cross_sample_negatives > 0:
                tszs = (
                    torch.arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.cross_sample_negatives)
                    .flatten()
                )

                cross_neg_idxs = torch.randint(
                    low=0,
                    high=cross_high - 1,
                    size=(bsz, self.cross_sample_negatives * num),
                )
                # Cross-sample indices are flattened (B*T) rows, so the row to
                # skip is the step's own flattened position.
                own_idxs = tszs.unsqueeze(0) + batch_offsets
                cross_neg_idxs[cross_neg_idxs >= own_idxs] += 1
                idx_parts.append(
                    cross_neg_idxs.view(bsz, num, self.cross_sample_negatives)
                )

        # Join along the "negative" axis (B x num x N) so each step keeps its
        # own indices, then flatten back to B x (num * N).
        neg_idxs = torch.cat(idx_parts, dim=2).view(bsz, -1)

        negs = y[neg_idxs.view(-1)]
        negs = negs.view(bsz, num, n_total, fsz).permute(
            2, 0, 1, 3
        )  # to NxBxTxC
        return negs, neg_idxs
Example #3
0
    def sample_negatives(self, y, num, padding_count=None):
        """Draw negative (distractor) frames for ``num`` time steps of ``y``.

        ``self.n_negatives`` negatives per step are taken from the first
        ``T - padding_count`` frames of the same batch element, and
        ``self.cross_sample_negatives`` per step from any frame of any batch
        element. The frame at the step's own position is never selected.

        Args:
            y: tensor of shape ``(B, T, C)`` (batch, time, channels).
            num: number of time steps to sample negatives for.
            padding_count: number of frames excluded from the end of the
                in-sample candidate range; ``None`` is treated as 0.

        Returns:
            A pair ``(negs, neg_idxs)``:

            * ``negs`` has shape ``(N, B, num, C)`` with
              ``N = self.n_negatives + self.cross_sample_negatives``
              (in-sample negatives first along ``N``);
            * ``neg_idxs`` has shape ``(B, num * N)`` and holds the row
              indices into ``y`` viewed as ``(B * T, C)``.

            When ``N == 0`` both elements are empty tensors.

        Raises:
            AssertionError: if fewer than two in-sample candidate frames
                remain.
        """
        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
            # Keep the two-element return shape so callers can always unpack.
            return y.new(0), torch.empty(0, dtype=torch.long)

        bsz, tsz, fsz = y.shape
        y = y.view(-1, fsz)  # BTC => (BxT)C
        n_total = self.n_negatives + self.cross_sample_negatives

        # NOTE(review): padding_count only narrows the in-sample range below;
        # cross-sample negatives are still drawn from all B*T rows, padded
        # frames included. This presumes padding sits at the end of every
        # sequence - confirm against the caller.
        cross_high = tsz * bsz
        high = tsz - (padding_count or 0)
        idx_parts = []
        with torch.no_grad():
            assert high > 1, f"{bsz,tsz,fsz}"

            # Row of frame t of batch element b in the flattened view is
            # b * tsz + t; the stride is tsz even when high < tsz.
            batch_offsets = torch.arange(bsz).unsqueeze(1) * tsz  # B x 1

            if self.n_negatives > 0:
                # time step of every index slot, laid out as (t, n)
                tszs = (
                    torch.arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.n_negatives)
                    .flatten()
                )

                # Sample from high - 1 candidates, then shift everything at or
                # above the step's own position by one to skip it.
                neg_idxs = torch.randint(
                    low=0, high=high - 1, size=(bsz, self.n_negatives * num)
                )
                neg_idxs[neg_idxs >= tszs] += 1
                neg_idxs = neg_idxs + batch_offsets
                idx_parts.append(neg_idxs.view(bsz, num, self.n_negatives))

            if self.cross_sample_negatives > 0:
                tszs = (
                    torch.arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.cross_sample_negatives)
                    .flatten()
                )

                cross_neg_idxs = torch.randint(
                    low=0,
                    high=cross_high - 1,
                    size=(bsz, self.cross_sample_negatives * num),
                )
                # Cross-sample indices are flattened (B*T) rows, so the row to
                # skip is the step's own flattened position.
                own_idxs = tszs.unsqueeze(0) + batch_offsets
                cross_neg_idxs[cross_neg_idxs >= own_idxs] += 1
                idx_parts.append(
                    cross_neg_idxs.view(bsz, num, self.cross_sample_negatives)
                )

        # Join along the "negative" axis (B x num x N) so each step keeps its
        # own indices, then flatten back to B x (num * N).
        neg_idxs = torch.cat(idx_parts, dim=2).view(bsz, -1)

        negs = y[neg_idxs.view(-1)]
        negs = negs.view(bsz, num, n_total, fsz).permute(
            2, 0, 1, 3
        )  # to NxBxTxC
        return negs, neg_idxs
Example #4
0
def convert_padding_direction(
    src_frames,
    src_lengths,
    right_to_left=False,
    left_to_right=False,
):
    """Counterpart of :func:`~fairseq.utils.convert_padding_direction`,
    operating on 3d tensors of size B x T x C.

    Each sequence is rotated along the time axis by its number of padded
    frames (``T - length``), which moves the padding from one end to the
    other. Note that this function is unaware of whether the input has already
    been right padded or left padded (since any real value is legal for
    non-padded elements), so be clear of the actual padding direction before
    calling this function.

    Args:
        src_frames: tensor of size B x T x C.
        src_lengths: tensor of size B with the unpadded length of each
            sequence.
        right_to_left: convert right padding into left padding.
        left_to_right: convert left padding into right padding.

    Returns:
        A tensor of the same size as ``src_frames``. ``src_frames`` itself is
        returned when no sequence is padded.

    Raises:
        AssertionError: unless exactly one of ``right_to_left`` /
            ``left_to_right`` is set, or if the batch sizes disagree.
    """
    assert right_to_left ^ left_to_right
    assert src_frames.size(0) == src_lengths.size(0)
    max_len = src_frames.size(1)
    if src_lengths.eq(max_len).all():
        # no padding, return early
        return src_frames
    # Build the positions on the same device as the frames: gather() needs
    # its index tensor to live alongside its input.
    positions = (
        torch.arange(max_len, device=src_frames.device)
        .unsqueeze(-1)
        .expand_as(src_frames)
    )
    num_pads = (max_len -
                src_lengths.type_as(positions)).unsqueeze(-1).unsqueeze(-1)
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_frames.gather(1, index)
    def _with_sentence_boundaries(
        self,
        input: torch.Tensor,
    ):
        if not self.add_bos and not self.add_eos:
            return input

        zero_block = input.new(input.size(0), 0)
        bos_block = input.new_full(
            (input.size(0), 1), self.eos_idx) if self.add_bos else zero_block
        pad_block = input.new_full(
            (input.size(0),
             1), self.padding_idx) if self.add_eos else zero_block

        # add eos in the beginning and pad to the end of the sentence
        input = torch.cat([bos_block, input, pad_block], dim=1)

        if self.add_eos:
            num_pads = input.eq(self.padding_idx).long().sum(dim=1,
                                                             keepdim=True)
            max_len = input.size(1)

            # index of the first pad
            first_pads = buffered_arange(max_len).type_as(input).view(
                1, -1).expand(input.size(0), -1).eq(max_len - num_pads)
            input[first_pads] = self.eos_idx

        return input
Example #6
0
    def _with_sentence_boundaries(
            self,
            input: torch.Tensor):
        """
        Args:
            input: the sentence Tensor
            it's bs * seq_len * num_chars in case of char input and bs*seq_len in case of token input

        Returns:
            tuple,
                1) processed input,
                2) tensor mask for the eos position of each sentence,
                    None if did not add eos
        """
        if not self.add_bos and not self.add_eos:
            return input, None

        zero_block = input.new(0, 0)
        block_size = (input.size(0), 1, input.size(2)) if self.char_inputs else (input.size(0), 1)
        bos_block = torch.full(block_size, self.eos_idx).type_as(input) if self.add_bos else zero_block
        pad_block = torch.full(block_size, self.padding_idx).type_as(input) if self.add_eos else zero_block

        # add eos in the beginning and pad to the end of the sentence
        input = torch.cat([bos_block, input, pad_block], dim=1)

        first_pads = None  # if not add_eos, then first_pads is not valid, set to None
        if self.add_eos:
            index_block = input[:, :, 0] if self.char_inputs else input
            padding_mask = index_block.eq(self.padding_idx)
            num_pads = padding_mask.long().sum(dim=1, keepdim=True)
            max_len = input.size(1)

            # index of the first pad
            if self.onnx_trace:
                first_pads = torch._dim_arange(input, 1).type_as(input).view(1, -1).\
                    repeat(input.size(0), 1).eq(max_len - num_pads)
                eos_indices = first_pads
                if self.char_inputs:
                    eos_indices = eos_indices.unsqueeze(2).repeat(1, 1, input.size(-1))
                input = torch.where(eos_indices, torch.Tensor([self.eos_idx]).type_as(input), input)
            else:
                first_pads = buffered_arange(max_len).type_as(input).view(1, -1).\
                    expand(input.size(0), -1).eq(max_len - num_pads)
                eos_indices = first_pads
                if self.char_inputs:
                    eos_indices = eos_indices.unsqueeze(2).expand_as(input)
                input[eos_indices] = self.eos_idx

        return input, first_pads