Example #1
0
    def forward(self, xs, xlens, batch_first=True):
        """Subsample a padded sequence batch with ``self.conv1d`` and a ReLU.

        Args:
            xs (FloatTensor): `[B, T, F]` if ``batch_first`` else `[T, B, F]`
            xlens (IntTensor): `[B]` (on CPU)
            batch_first (bool): whether ``xs`` is laid out batch-first
        Returns:
            xs (FloatTensor): `[B, T', F']` if ``batch_first`` else `[T', B, F']`
            xlens (IntTensor): `[B]` (on CPU), recomputed for ``self.conv1d``

        """
        # A factor of 1 means "no subsampling": hand the inputs straight back.
        if self.factor == 1:
            return xs, xlens

        # conv1d wants features in the channel slot ([B, F, T]); move them
        # there, convolve, then restore whichever layout the caller used.
        if batch_first:
            ys = self.conv1d(xs.transpose(1, 2)).transpose(1, 2)
        else:
            ys = self.conv1d(xs.permute(1, 2, 0)).permute(2, 0, 1)
        ys = torch.relu(ys.contiguous())

        return ys, update_lens_1d(xlens, self.conv1d)
Example #2
0
    def forward(self, xs, xlens):
        """Subsample a batch along the time axis with ``self.pool``.

        Args:
            xs (FloatTensor): `[B, T, F]`
            xlens (IntTensor): `[B]`
        Returns:
            xs (FloatTensor): `[B, T', F']`
            xlens (IntTensor): `[B]`, recomputed for ``self.pool``

        When ``self.factor`` is 1 both inputs are returned unchanged.
        """
        if self.factor != 1:
            # The 1-D pooling op runs over the last axis, so put time there
            # ([B, F, T]), pool, and swap back to [B, T', F'].
            channels_first = xs.transpose(1, 2)
            pooled = self.pool(channels_first)
            xs = pooled.transpose(1, 2).contiguous()
            xlens = update_lens_1d(xlens, self.pool)
        return xs, xlens
Example #3
0
    def forward(self, xs, xlens):
        """Downsample the time axis with ``self.pool``.

        Args:
            xs (FloatTensor): `[B, T, F]`
            xlens (IntTensor): `[B]` (on CPU)
        Returns:
            xs (FloatTensor): `[B, T', F']`
            xlens (IntTensor): `[B]` (on CPU), recomputed for ``self.pool``

        """
        # Identity when no subsampling is configured.
        if self.subsampling_factor == 1:
            return xs, xlens

        # Pool over time: [B, T, F] -> [B, F, T] -> pool -> [B, T', F'].
        pooled = self.pool(xs.transpose(1, 2))
        out = pooled.transpose(1, 2).contiguous()
        new_lens = update_lens_1d(xlens, self.pool)
        return out, new_lens
Example #4
0
    def forward(self, xs, xlens):
        """Forward pass: 1-D convolution over time on a 4-D feature map.

        The channel and frequency axes are folded together so ``self.conv1d``
        sees ``C_i * F`` input channels; its output is then unfolded into
        ``self.C_out`` channels of ``F`` bins, followed by ``self.norm``.

        Args:
            xs (FloatTensor): `[B, C_i, T, F]`
            xlens (IntTensor): `[B]`
        Returns:
            xs (FloatTensor): `[B, C_o, T', F]` (T' is the conv output length)
            xlens (IntTensor): `[B]`, recomputed for ``self.conv1d``

        """
        B, C, T, F = xs.size()
        # transpose() returns a non-contiguous view, and view() cannot merge
        # the C and F axes of such a tensor (it raises a stride error whenever
        # T > 1), so materialize a contiguous copy before reshaping.
        xs = xs.transpose(3, 2).contiguous().view(B, C * F, T)
        xs = self.dropout(torch.relu(self.conv1d(xs)))
        # Assumes self.conv1d emits self.C_out * F channels — TODO confirm
        # against the constructor; -1 absorbs the (possibly changed) time length.
        xs = xs.view(B, self.C_out, F, -1).transpose(3, 2)
        xs = self.norm(xs)

        xlens = update_lens_1d(xlens, self.conv1d)
        return xs, xlens