    def test_fwd(self):
        lexicon = [(0, 0), (0, 1), (1, 0), (1, 1)]
        blank_idx = 2
        kernel_size = 5
        stride = 3
        convTrans = transducer.ConvTransduce1D(lexicon, kernel_size, stride,
                                               blank_idx)

        B = 2
        C = 3
        # Zero-length inputs are not allowed
        inputs = torch.randn(B, 0, C)
        with self.assertRaises(ValueError):
            convTrans(inputs)
        # Inputs shorter than kernel_size should be padded internally
        for Tin in [1, 2, 3, 4]:
            inputs = torch.randn(B, Tin, C)
            convTrans(inputs)

        Tin = (1, 3, 4, 6, 7, 8)
        Tout = (1, 1, 2, 2, 3, 3)
        for Ti, To in zip(Tin, Tout):
            inputs = torch.randn(B, Ti, C)
            outputs = convTrans(inputs)
            self.assertEqual(outputs.shape, (B, To, len(lexicon)))
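
The Tin/Tout pairs above follow a simple length rule: with this padding scheme, the output length works out to ceil(Tin / stride). A minimal sketch checking that relation (the formula is inferred from the test values, not stated explicitly in the source):

import math

stride = 3
for t_in, t_out in zip((1, 3, 4, 6, 7, 8), (1, 1, 2, 2, 3, 3)):
    # Inferred relation: output length = ceil(input length / stride).
    assert math.ceil(t_in / stride) == t_out
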
Example #2
    def __init__(
        self,
        input_size,
        output_size,
        tokens,
        kernel_size,
        stride,
        tds1,
        tds2,
        wfst=True,
        **kwargs,
    ):
        super(TDS2dTransducer, self).__init__()
        # TDS2d -> ConvTransducer -> TDS2d

        # Setup lexicon for transducer layer:
        with open(tokens, 'r') as fid:
            output_tokens = [l.strip() for l in fid]
        input_tokens = set(t for token in output_tokens for t in token)
        input_tokens = {t: e for e, t in enumerate(sorted(input_tokens))}
        lexicon = [
            tuple(input_tokens[t] for t in token) for token in output_tokens
        ]
        in_token_size = len(input_tokens) + 1
        blank_idx = len(input_tokens)

        # output size of tds1 is number of input tokens + 1 for blank
        self.tds1 = TDS2d(input_size, in_token_size, **tds1)
        stride_h = np.prod([grp["stride"][0] for grp in tds1["tds_groups"]])
        inner_size = input_size // stride_h

        # output size of conv is the size of the lexicon
        if wfst:
            self.conv = transducer.ConvTransduce1D(lexicon, kernel_size,
                                                   stride, blank_idx, **kwargs)
        else:
            # As a control, use a "dumb" conv with the same parameters as the WFST conv:
            self.conv = torch.nn.Conv1d(in_channels=in_token_size,
                                        out_channels=len(lexicon),
                                        kernel_size=kernel_size,
                                        padding=kernel_size // 2,
                                        stride=stride)
        self.wfst = wfst

        # in_channels should be set to out_channels of the previous TDS group * depth
        in_channels = tds1["tds_groups"][-1]["channels"] * tds1["depth"]
        tds2["in_channels"] = in_channels
        self.linear = torch.nn.Linear(len(lexicon), in_channels * inner_size)
        self.tds2 = TDS2d(inner_size, output_size, **tds2)
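
To make the lexicon setup above concrete, here is a hedged walk-through assuming a hypothetical two-line tokens file (the file contents are an illustration, not from the source):

# Suppose the tokens file contains the two lines "ab" and "ba" (hypothetical).
output_tokens = ["ab", "ba"]
# The characters are collected, sorted, and indexed: "a" -> 0, "b" -> 1.
input_tokens = {"a": 0, "b": 1}
# Each output token is spelled out in character indices.
lexicon = [(0, 1), (1, 0)]
# One extra class is reserved for blank.
in_token_size = len(input_tokens) + 1   # 3
blank_idx = len(input_tokens)           # 2

The stride_h arithmetic is analogous: for example, if tds1["tds_groups"] held two groups with height strides of 2, stride_h would be 4 and inner_size would be input_size // 4 (again, example values, not from the source).
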
    def test_bwd(self):
        lexicon = [(0, 0), (0, 1), (1, 0), (1, 1)]
        blank_idx = 2
        kernel_size = 5
        stride = 3
        convTrans = transducer.ConvTransduce1D(lexicon, kernel_size, stride,
                                               blank_idx)

        B = 2
        C = 3
        Tin = (1, 3, 4, 6, 7, 8)
        Tout = (1, 1, 2, 2, 3, 3)
        for Ti, To in zip(Tin, Tout):
            inputs = torch.randn(B, Ti, C, requires_grad=True)
            outputs = convTrans(inputs)
            outputs.backward(torch.ones_like(outputs))
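
As a follow-up to test_bwd, a hedged sketch of an extra gradient sanity check one might add; it assumes ConvTransduce1D propagates gradients to its inputs, which the backward call above exercises but does not assert:

import torch

def check_input_grad(convTrans, B=2, C=3, Ti=7):
    # Hypothetical helper, not part of the source test suite.
    inputs = torch.randn(B, Ti, C, requires_grad=True)
    outputs = convTrans(inputs)
    outputs.backward(torch.ones_like(outputs))
    # The input gradient should exist and match the input's shape.
    assert inputs.grad is not None
    assert inputs.grad.shape == inputs.shape

Here convTrans is an already-constructed transducer.ConvTransduce1D, as in the tests above.
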