def encode(self, xs, task='all', streaming=False, lookback=False, lookahead=False):
    """Encode acoustic or text features.

    Depending on `self.input_type`, the inputs are either pre-processed as
    speech features (frame stacking, splicing, padding, and optional
    augmentation/noise/sequence-summary steps) or as integer sequences that
    are padded and embedded. The result is then passed to `self.enc`.

    Args:
        xs (list): A list of length `[B]`, which contains Tensor of size `[T, input_dim]`.
            For `input_type == 'text'`, each item is an iterable of integers
            (converted with `np.fromiter(..., dtype=np.int64)`).
        task (str): all/ys*/ys_sub1*/ys_sub2*. Only the part before the first
            `.` is forwarded to the encoder.
        streaming (bool): streaming encoding
        lookback (bool): truncate leftmost frames for lookback in CNN context
        lookahead (bool): truncate rightmost frames for lookahead in CNN context
    Returns:
        eout_dict (dict): the value returned by `self.enc`. This method reads
            and writes the entries `eout_dict['ys' | 'ys_sub1' | 'ys_sub2']`,
            each holding `'xs'` and `'xlens'`.
    Raises:
        ValueError: if `self.input_type` is neither `'speech'` nor `'text'`.

    """
    if self.input_type == 'speech':
        # Frame stacking
        if self.n_stacks > 1:
            xs = [stack_frame(x, self.n_stacks, self.n_skips) for x in xs]

        # Splicing
        if self.n_splices > 1:
            xs = [splice(x, self.n_splices, self.n_stacks) for x in xs]

        # Lengths are taken after stacking/splicing and before padding.
        xlens = torch.IntTensor([len(x) for x in xs])
        xs = pad_list([np2tensor(x, self.device).float() for x in xs], 0.)

        # SpecAugment (training only)
        if self.specaug is not None and self.training:
            xs = self.specaug(xs)

        # Weight noise injection (training only)
        if self.weight_noise_std > 0 and self.training:
            self.add_weight_noise(std=self.weight_noise_std)

        # Input Gaussian noise injection (training only)
        if self.input_noise_std > 0 and self.training:
            xs = add_input_noise(xs, std=self.input_noise_std)

        # Sequence summary network
        if self.ssn is not None:
            xs = self.ssn(xs, xlens)

    elif self.input_type == 'text':
        xlens = torch.IntTensor([len(x) for x in xs])
        xs = [np2tensor(np.fromiter(x, dtype=np.int64), self.device) for x in xs]
        xs = pad_list(xs, self.pad)
        xs = self.dropout_emb(self.embed(xs))
        # TODO(hirofumi): fix for Transformer

    else:
        # Without this guard `xlens` would be unbound below and the failure
        # would surface as an unrelated UnboundLocalError.
        raise ValueError(
            "Unsupported input_type: %r (expected 'speech' or 'text')"
            % (self.input_type,))

    # encoder
    eout_dict = self.enc(xs, xlens, task.split('.')[0], streaming, lookback, lookahead)

    # For these encoder types the main-task outputs are reused for both sub
    # tasks whenever the main task does not carry the full weight.
    if self.main_weight < 1 and self.enc_type in ['conv', 'tds', 'gated_conv']:
        for sub in ['sub1', 'sub2']:
            eout_dict['ys_' + sub]['xs'] = eout_dict['ys']['xs'].clone()
            # NOTE(review): if `xlens` is a torch.Tensor, `[:]` yields a view
            # sharing storage rather than an independent copy - confirm intent.
            eout_dict['ys_' + sub]['xlens'] = eout_dict['ys']['xlens'][:]

    return eout_dict
def test_forward():
    """Check that `add_input_noise` keeps the size of a padded batch."""
    # (batch_size, xmax, input_dim)
    shape = (4, 40, 80)
    device = "cpu"

    # Random float32 features, converted utterance by utterance and padded.
    feats = np.random.randn(*shape).astype(np.float32)
    utterances = [np2tensor(feat, device).float() for feat in feats]
    xs = pad_list(utterances, 0.)

    noisy = add_input_noise(xs, std=0.075)
    assert noisy.size() == xs.size()