def test_fwd(self):
    """Forward-pass checks for ConvTransduce1D with kernel 5 / stride 3:
    zero-length inputs raise, short inputs are accepted (padded), and
    output time lengths match the expected downsampled sizes."""
    lexicon = [(0, 0), (0, 1), (1, 0), (1, 1)]
    layer = transducer.ConvTransduce1D(lexicon, 5, 3, 2)
    batch, channels = 2, 3

    # A zero-length time dimension is not allowed.
    with self.assertRaises(ValueError):
        layer(torch.randn(batch, 0, channels))

    # Inputs shorter than the kernel width should still be accepted
    # (the layer pads them up to at least kernel_size internally).
    for length in range(1, 5):
        layer(torch.randn(batch, length, channels))

    # Check the exact output time length for a range of input lengths.
    for t_in, t_out in [(1, 1), (3, 1), (4, 2), (6, 2), (7, 3), (8, 3)]:
        result = layer(torch.randn(batch, t_in, channels))
        self.assertEqual(result.shape, (batch, t_out, len(lexicon)))
def __init__( self, input_size, output_size, tokens, kernel_size, stride, tds1, tds2, wfst=True, **kwargs, ): super(TDS2dTransducer, self).__init__() # TDS2d -> ConvTransducer -> TDS2d # Setup lexicon for transducer layer: with open(tokens, 'r') as fid: output_tokens = [l.strip() for l in fid] input_tokens = set(t for token in output_tokens for t in token) input_tokens = {t: e for e, t in enumerate(sorted(input_tokens))} lexicon = [ tuple(input_tokens[t] for t in token) for token in output_tokens ] in_token_size = len(input_tokens) + 1 blank_idx = len(input_tokens) # output size of tds1 is number of input tokens + 1 for blank self.tds1 = TDS2d(input_size, in_token_size, **tds1) stride_h = np.prod([grp["stride"][0] for grp in tds1["tds_groups"]]) inner_size = input_size // stride_h # output size of conv is the size of the lexicon if wfst: self.conv = transducer.ConvTransduce1D(lexicon, kernel_size, stride, blank_idx, **kwargs) else: # For control, use "dumb" conv with the same parameters as the WFST conv: self.conv = torch.nn.Conv1d(in_channels=in_token_size, out_channels=len(lexicon), kernel_size=kernel_size, padding=kernel_size // 2, stride=stride) self.wfst = wfst # in_channels should be set to out_channels of prevous tds group * depth in_channels = tds1["tds_groups"][-1]["channels"] * tds1["depth"] tds2["in_channels"] = in_channels self.linear = torch.nn.Linear(len(lexicon), in_channels * inner_size) self.tds2 = TDS2d(inner_size, output_size, **tds2)
def test_bwd(self):
    """Backward-pass smoke test: gradients propagate through
    ConvTransduce1D for a range of input time lengths."""
    lexicon = [(0, 0), (0, 1), (1, 0), (1, 1)]
    layer = transducer.ConvTransduce1D(lexicon, 5, 3, 2)
    batch, channels = 2, 3
    for length in (1, 3, 4, 6, 7, 8):
        x = torch.randn(batch, length, channels, requires_grad=True)
        y = layer(x)
        # Backprop a ones-gradient; just checking it runs without error.
        y.backward(torch.ones_like(y))