Beispiel #1
0
    def test_single_block(self):
        """Run a ParallelBlock built from one Jasper block and check its output.

        The input has 8 channels and 140 frames with a declared length of 131;
        the test expects 16 output channels, an unchanged frame count, and an
        unchanged length value.
        """
        n_in, n_out = 8, 16
        tower = self.contrust_jasper_block(inplanes=n_in, planes=n_out)
        parallel = jasper.ParallelBlock([tower])

        signal = torch.randn(1, n_in, 140)
        lengths = torch.tensor([131])
        # The block takes a (list of tensors, lengths) pair and returns the same structure.
        outputs, out_lengths = parallel(([signal], lengths))

        expected_shape = torch.Size([1, n_out, 140])
        assert outputs[0].shape == expected_shape
        assert out_lengths[0] == 131
Beispiel #2
0
    def test_blocks_with_different_input_output_channels_sum_residual(self):
        """Expect a RuntimeError for 'sum' residual with 8 input and 16 output channels.

        NOTE(review): presumably the error comes from adding the 8-channel input
        to the 16-channel tower output -- confirm against ParallelBlock.
        """
        n_in, n_out = 8, 16
        towers = [
            self.contrust_jasper_block(inplanes=n_in, planes=n_out)
            for _ in range(2)
        ]
        parallel = jasper.ParallelBlock(towers, residual_mode='sum')

        signal = torch.randn(1, n_in, 140)
        lengths = torch.tensor([131])
        packed_input = ([signal], lengths)

        with pytest.raises(RuntimeError):
            parallel(packed_input)
Beispiel #3
0
    def test_tower_dropout(self):
        """Check that with block_dropout_prob=1.0 the output matches the input.

        Two Jasper blocks with equal input and output channel counts (8) are
        aggregated in 'dropout' mode.
        """
        n_channels = 8
        towers = [
            self.contrust_jasper_block(inplanes=n_channels, planes=n_channels)
            for _ in range(2)
        ]
        parallel = jasper.ParallelBlock(
            towers,
            aggregation_mode='dropout',
            block_dropout_prob=1.0,
        )

        signal = torch.randn(1, n_channels, 140)
        lengths = torch.tensor([131])
        outputs, _ = parallel(([signal], lengths))

        # A dropout probability of 1.0 drops every tower, so the result should
        # consist of the residual (the input) alone.
        torch.testing.assert_allclose(outputs[0], signal)
Beispiel #4
0
    def test_blocks_with_different_input_output_channels_conv_residual(self):
        """Check that 'conv' residual mode handles 8 input and 16 output channels.

        Two Jasper blocks are combined with residual_mode='conv' and explicit
        in/out filter counts; the test expects 16 output channels, an unchanged
        frame count of 140, and an unchanged length value of 131.
        """
        n_in, n_out = 8, 16
        towers = [
            self.contrust_jasper_block(inplanes=n_in, planes=n_out)
            for _ in range(2)
        ]
        parallel = jasper.ParallelBlock(
            towers,
            residual_mode='conv',
            in_filters=n_in,
            out_filters=n_out,
        )

        signal = torch.randn(1, n_in, 140)
        lengths = torch.tensor([131])
        outputs, out_lengths = parallel(([signal], lengths))

        expected_shape = torch.Size([1, n_out, 140])
        assert outputs[0].shape == expected_shape
        assert out_lengths[0] == 131