def sample_negatives(self, y):
    bsz, fsz, tsz = y.shape

    y = y.transpose(0, 1)  # BCT -> CBT
    y = y.contiguous().view(fsz, -1)  # CBT => C(BxT)

    cross_high = tsz * bsz
    high = tsz if self.sample_distance is None else min(tsz, self.sample_distance)
    assert high > 1

    neg_idxs = torch.randint(low=0, high=high, size=(bsz, self.n_negatives * tsz))

    with torch.no_grad():
        if self.n_negatives > 0:
            tszs = (
                buffered_arange(tsz)
                .unsqueeze(-1)
                .expand(-1, self.n_negatives)
                .flatten()
            )

            neg_idxs = torch.randint(
                low=0, high=high - 1, size=(bsz, self.n_negatives * tsz)
            )
            neg_idxs[neg_idxs >= tszs] += 1

        if self.cross_sample_negatives > 0:
            tszs = (
                buffered_arange(tsz)
                .unsqueeze(-1)
                .expand(-1, self.cross_sample_negatives)
                .flatten()
            )

            cross_neg_idxs = torch.randint(
                low=0,
                high=cross_high - 1,
                size=(bsz, self.cross_sample_negatives * tsz),
            )
            cross_neg_idxs[cross_neg_idxs >= tszs] += 1

    if self.n_negatives > 0:
        for i in range(1, bsz):
            neg_idxs[i] += i * high
    else:
        neg_idxs = cross_neg_idxs

    if self.cross_sample_negatives > 0 and self.n_negatives > 0:
        neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

    negs = y[..., neg_idxs.view(-1)]
    negs = negs.view(
        fsz, bsz, self.n_negatives + self.cross_sample_negatives, tsz
    ).permute(
        2, 1, 0, 3
    )  # to NxBxCxT

    return negs
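# Every snippet below calls `buffered_arange`. The helper here is a sketch of
# that utility as it exists in fairseq.utils (an arange whose backing buffer is
# cached and reused across calls), reproduced from memory; verify against the
# repo before relying on it.
import torch


def buffered_arange(max):
    if not hasattr(buffered_arange, "buf"):
        buffered_arange.buf = torch.LongTensor()
    if max > buffered_arange.buf.numel():
        buffered_arange.buf.resize_(max)
        torch.arange(max, out=buffered_arange.buf)
    return buffered_arange.buf[:max]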
def sample_negatives(self, y, num):

    if self.n_negatives == 0 and self.cross_sample_negatives == 0:
        return y.new(0)

    bsz, tsz, fsz = y.shape
    y = y.view(-1, fsz)  # BTC => (BxT)C

    cross_high = tsz * bsz
    high = tsz
    with torch.no_grad():
        assert high > 1, f"{bsz,tsz,fsz}"

        if self.n_negatives > 0:
            tszs = (
                buffered_arange(num)
                .unsqueeze(-1)
                .expand(-1, self.n_negatives)
                .flatten()
            )

            neg_idxs = torch.randint(
                low=0, high=high - 1, size=(bsz, self.n_negatives * num)
            )
            neg_idxs[neg_idxs >= tszs] += 1

        if self.cross_sample_negatives > 0:
            tszs = (
                buffered_arange(num)
                .unsqueeze(-1)
                .expand(-1, self.cross_sample_negatives)
                .flatten()
            )

            cross_neg_idxs = torch.randint(
                low=0,
                high=cross_high - 1,
                size=(bsz, self.cross_sample_negatives * num),
            )
            cross_neg_idxs[cross_neg_idxs >= tszs] += 1

    if self.n_negatives > 0:
        for i in range(1, bsz):
            neg_idxs[i] += i * high
    else:
        neg_idxs = cross_neg_idxs

    if self.cross_sample_negatives > 0 and self.n_negatives > 0:
        neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

    negs = y[neg_idxs.view(-1)]
    negs = negs.view(
        bsz, num, self.n_negatives + self.cross_sample_negatives, fsz
    ).permute(
        2, 0, 1, 3
    )  # to NxBxTxC
    return negs, neg_idxs
def sample_negatives(self, y, num, padding_count=None):

    if self.n_negatives == 0 and self.cross_sample_negatives == 0:
        return y.new(0)

    bsz, tsz, fsz = y.shape
    y = y.view(-1, fsz)  # BTC => (BxT)C

    # FIXME: what happens if padding_count is specified?
    cross_high = tsz * bsz
    high = tsz - (padding_count or 0)
    with torch.no_grad():
        assert high > 1, f"{bsz,tsz,fsz}"

        if self.n_negatives > 0:
            tszs = (
                buffered_arange(num)
                .unsqueeze(-1)
                .expand(-1, self.n_negatives)
                .flatten()
            )

            neg_idxs = torch.randint(
                low=0, high=high - 1, size=(bsz, self.n_negatives * num)
            )
            neg_idxs[neg_idxs >= tszs] += 1

        if self.cross_sample_negatives > 0:
            tszs = (
                buffered_arange(num)
                .unsqueeze(-1)
                .expand(-1, self.cross_sample_negatives)
                .flatten()
            )

            cross_neg_idxs = torch.randint(
                low=0,
                high=cross_high - 1,
                size=(bsz, self.cross_sample_negatives * num),
            )
            cross_neg_idxs[cross_neg_idxs >= tszs] += 1

    if self.n_negatives > 0:
        neg_idxs = neg_idxs + (torch.arange(bsz).unsqueeze(1) * high)
    else:
        neg_idxs = cross_neg_idxs

    if self.cross_sample_negatives > 0 and self.n_negatives > 0:
        neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

    negs = y[neg_idxs.view(-1)]
    negs = negs.view(
        bsz, num, self.n_negatives + self.cross_sample_negatives, fsz
    ).permute(
        2, 0, 1, 3
    )  # to NxBxTxC
    return negs, neg_idxs
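# Hedged, standalone usage sketch for the padding-aware sample_negatives above.
# `_NegativeSampler` is a hypothetical stand-in for the wav2vec 2.0 model: it
# carries only the two attributes the method reads and borrows the function
# defined just above as a bound method. It assumes `buffered_arange` (e.g. the
# fairseq.utils helper sketched earlier) is in scope.
import torch


class _NegativeSampler:
    n_negatives = 10            # within-utterance negatives per time step
    cross_sample_negatives = 0  # negatives drawn from other items in the batch
    sample_negatives = sample_negatives  # bind the padding-aware variant above


if __name__ == "__main__":
    targets = torch.randn(2, 50, 256)  # B x T x C masked-step targets
    negs, neg_idxs = _NegativeSampler().sample_negatives(targets, targets.size(1))
    print(negs.shape)      # torch.Size([10, 2, 50, 256])  -> N x B x T x C
    print(neg_idxs.shape)  # torch.Size([2, 500])          -> B x (N * T)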
def convert_padding_direction(
    src_frames,
    src_lengths,
    right_to_left=False,
    left_to_right=False,
):
    """Counterpart of :func:`~fairseq.utils.convert_padding_direction`,
    operating on 3d tensors of size B x T x C. Note that this function is
    unaware of whether the input has been right padded or left padded (since
    any real value is legal for non-padded elements), so be sure of the actual
    padding direction before calling this function.
    """
    assert right_to_left ^ left_to_right
    assert src_frames.size(0) == src_lengths.size(0)
    max_len = src_frames.size(1)
    if not src_lengths.eq(max_len).any():
        # no padding, return early
        return src_frames
    range = utils.buffered_arange(max_len).unsqueeze(-1).expand_as(src_frames)
    num_pads = (max_len - src_lengths.type_as(range)).unsqueeze(-1).unsqueeze(-1)
    if right_to_left:
        index = torch.remainder(range - num_pads, max_len)
    else:
        index = torch.remainder(range + num_pads, max_len)
    return src_frames.gather(1, index)
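# Hedged usage sketch for convert_padding_direction (the tensor values below
# are illustrative, not from the source). A batch of two right-padded feature
# sequences of lengths 3 and 5 is converted so that the padding of the short
# sequence sits on the left instead of the right. Assumes fairseq is importable,
# since the function above calls `utils.buffered_arange`.
import torch
from fairseq import utils  # noqa: F401  (referenced as `utils` above)


if __name__ == "__main__":
    B, T, C = 2, 5, 4
    src_frames = torch.randn(B, T, C)
    src_lengths = torch.tensor([3, 5])
    src_frames[0, 3:] = 0.0  # explicit right padding for the short sequence

    left_padded = convert_padding_direction(
        src_frames, src_lengths, right_to_left=True
    )
    # left_padded[0] now holds its two padding frames at positions 0-1 and its
    # three real frames at positions 2-4; the full-length sequence is unchanged.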
def _with_sentence_boundaries(
    self,
    input: torch.Tensor,
):
    if not self.add_bos and not self.add_eos:
        return input

    zero_block = input.new(input.size(0), 0)
    bos_block = input.new_full(
        (input.size(0), 1), self.eos_idx) if self.add_bos else zero_block
    pad_block = input.new_full(
        (input.size(0), 1), self.padding_idx) if self.add_eos else zero_block

    # add eos in the beginning and pad to the end of the sentence
    input = torch.cat([bos_block, input, pad_block], dim=1)

    if self.add_eos:
        num_pads = input.eq(self.padding_idx).long().sum(dim=1, keepdim=True)
        max_len = input.size(1)

        # index of the first pad
        first_pads = buffered_arange(max_len).type_as(input).view(
            1, -1).expand(input.size(0), -1).eq(max_len - num_pads)
        input[first_pads] = self.eos_idx

    return input
def _with_sentence_boundaries(
        self,
        input: torch.Tensor):
    """
    Args:
        input: the sentence Tensor; bs * seq_len * num_chars in case of char
            input and bs * seq_len in case of token input

    Returns:
        tuple of
            1) the processed input,
            2) a tensor mask marking the eos position of each sentence, or
               None if eos was not added
    """
    if not self.add_bos and not self.add_eos:
        return input, None

    zero_block = input.new(0, 0)
    block_size = (input.size(0), 1, input.size(2)) if self.char_inputs else (input.size(0), 1)
    bos_block = torch.full(block_size, self.eos_idx).type_as(input) if self.add_bos else zero_block
    pad_block = torch.full(block_size, self.padding_idx).type_as(input) if self.add_eos else zero_block

    # add eos in the beginning and pad to the end of the sentence
    input = torch.cat([bos_block, input, pad_block], dim=1)

    first_pads = None  # if not add_eos, then first_pads is not valid, set to None
    if self.add_eos:
        index_block = input[:, :, 0] if self.char_inputs else input
        padding_mask = index_block.eq(self.padding_idx)
        num_pads = padding_mask.long().sum(dim=1, keepdim=True)
        max_len = input.size(1)

        # index of the first pad
        if self.onnx_trace:
            first_pads = torch._dim_arange(input, 1).type_as(input).view(1, -1).\
                repeat(input.size(0), 1).eq(max_len - num_pads)
            eos_indices = first_pads
            if self.char_inputs:
                eos_indices = eos_indices.unsqueeze(2).repeat(1, 1, input.size(-1))
            input = torch.where(eos_indices, torch.Tensor([self.eos_idx]).type_as(input), input)
        else:
            first_pads = buffered_arange(max_len).type_as(input).view(1, -1).\
                expand(input.size(0), -1).eq(max_len - num_pads)
            eos_indices = first_pads
            if self.char_inputs:
                eos_indices = eos_indices.unsqueeze(2).expand_as(input)
            input[eos_indices] = self.eos_idx

    return input, first_pads
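# Hedged, standalone sketch of the token-input path of the method above.
# `_BoundaryAdder` is a hypothetical container exposing only the flags and
# indices the method reads; in fairseq these would come from the model config
# and the task dictionary. It borrows the function defined just above as a
# bound method and assumes `buffered_arange` is available in scope.
import torch


class _BoundaryAdder:
    add_bos = True
    add_eos = True
    char_inputs = False
    onnx_trace = False
    eos_idx = 2
    padding_idx = 1
    _with_sentence_boundaries = _with_sentence_boundaries  # bind the method above


if __name__ == "__main__":
    # two token sequences, the second right-padded with padding_idx == 1
    tokens = torch.tensor([[5, 6, 7],
                           [8, 9, 1]])
    out, eos_mask = _BoundaryAdder()._with_sentence_boundaries(tokens)
    # out:      [[2, 5, 6, 7, 2],    <- bos (reusing eos_idx), tokens, eos
    #            [2, 8, 9, 2, 1]]    <- eos written at the first pad position
    # eos_mask: True exactly where the eos token was written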