Example #1
def test_correlation_block():
    """ perception test """

    for d0, d1 in [([100, 16], [100, 16]), ([100, 5, 16], [100, 5, 16]),
                   ([100, 5, 16], [100, 16]), ([100, 16], [100, 5, 16]),
                   ([100, 25, 16], [100, 10, 16])]:
        in_dict = build_multi_input_dict(dims=[d0, d1])

        # without reduction
        net = CorrelationBlock(in_keys=["in_key_0", "in_key_1"],
                               out_keys="correlation",
                               in_shapes=[d0[1:], d1[1:]],
                               reduce=False)
        str(net)
        out_dict_no_reduce = net(in_dict)

        assert isinstance(out_dict_no_reduce, Dict)
        assert out_dict_no_reduce["correlation"].shape[-1] == 16
        assert net.out_shapes() == [
            out_dict_no_reduce["correlation"].shape[1:]
        ]

        # with reduction
        net = CorrelationBlock(in_keys=["in_key_0", "in_key_1"],
                               out_keys="correlation",
                               in_shapes=[d0[1:], d1[1:]],
                               reduce=True)
        str(net)
        out_dict_reduce = net(in_dict)

        assert isinstance(out_dict_reduce, Dict)
        assert (out_dict_reduce["correlation"].ndim +
                1) == out_dict_no_reduce["correlation"].ndim
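
Every example in this section builds its inputs with a build_multi_input_dict helper that is not shown here. A minimal sketch of what such a helper could look like, assuming it simply maps each requested shape to a random tensor under the keys in_key_0, in_key_1, ... (the real test utility may differ):

import torch


def build_multi_input_dict(dims):
    """ hypothetical sketch: one random tensor per requested shape """
    return {f"in_key_{i}": torch.rand(*d) for i, d in enumerate(dims)}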
Example #2
def perform_masked_global_pooling_test_dim3_pool_3(feature_dim_1: int,
                                                   feature_dim_2: int,
                                                   feature_dim_3: int,
                                                   use_masking: bool,
                                                   pooling_func_name: str):
    """ perception test helper: masked global pooling over 3d features, with and without a mask """
    batch_dim = 4
    in_dict = build_multi_input_dict(
        dims=[[batch_dim, feature_dim_1, feature_dim_2, feature_dim_3],
              [batch_dim, feature_dim_1, feature_dim_2]])

    if use_masking:
        in_keys = ["in_key_0", "in_key_1"]
        in_shapes = [(feature_dim_1, feature_dim_2, feature_dim_3),
                     (feature_dim_1, feature_dim_2)]
    else:
        in_keys = ["in_key_0"]
        in_shapes = [(feature_dim_1, feature_dim_2, feature_dim_3)]

    net: MaskedGlobalPoolingBlock = MaskedGlobalPoolingBlock(
        in_keys=in_keys,
        out_keys="out_key",
        in_shapes=in_shapes,
        pooling_func=pooling_func_name,
        pooling_dim=-1)

    out_dict = net(
        in_dict if use_masking else {'in_key_0': in_dict['in_key_0']})
    str(net)
    assert set(net.out_keys).issubset(set(out_dict.keys()))
    assert out_dict[net.out_keys[0]].shape == (batch_dim, feature_dim_1,
                                               feature_dim_2)
Example #3
def test_torch_model_block():
    """ perception test """
    with torch.no_grad():
        custom_model = CustomPytorchModel()
        torch_model_block = TorchModelBlock(
            in_keys=['in_key_0', 'in_key_1'],
            out_keys=['out_key_0', 'out_key_1'],
            in_shapes=[[3, 5, 5], [16]],
            in_num_dims=[4, 2],
            out_num_dims=[4, 2],
            net=custom_model)

        in_tensor_dict = build_multi_input_dict(dims=[[3, 5, 5], [16]])
        out_dict = torch_model_block(in_tensor_dict)
        assert out_dict['out_key_0'].numpy().shape == (4, 5, 5)
        assert out_dict['out_key_1'].numpy().shape == (32, )

        in_tensor_dict = build_multi_input_dict(dims=[[8, 3, 5, 5], [8, 16]])
        out_dict = torch_model_block(in_tensor_dict)
        assert out_dict['out_key_0'].numpy().shape == (8, 4, 5, 5)
        assert out_dict['out_key_1'].numpy().shape == (8, 32)
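
The CustomPytorchModel fixture used above is not shown. A plausible sketch that matches the asserted shapes, assuming the wrapped model receives the two input tensors as arguments and returns one output per out_key (the actual fixture and calling convention may differ):

import torch
from torch import nn


class CustomPytorchModel(nn.Module):
    """ hypothetical sketch: a conv head (3 -> 4 channels) and a linear head (16 -> 32) """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, padding=1)
        self.linear = nn.Linear(in_features=16, out_features=32)

    def forward(self, in_key_0: torch.Tensor, in_key_1: torch.Tensor):
        # (B, 3, 5, 5) -> (B, 4, 5, 5) and (B, 16) -> (B, 32)
        return self.conv(in_key_0), self.linear(in_key_1)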
Example #4
def test_attention_sequential_2():
    """test_attention_sequential"""
    in_dict = build_multi_input_dict(dims=[(2, 10), (2, 7), (2, 9)])

    self_attn_block = MultiHeadAttentionBlock(in_keys=['in_key_0', 'in_key_1', 'in_key_2'],
                                              out_keys='self_attention', in_shapes=[(10,), (7,), (9,)],
                                              num_heads=10, dropout=0.0, bias=False,
                                              add_input_to_output=True, add_bias_kv=False, add_zero_attn=False,
                                              kdim=7, vdim=9, use_key_padding_mask=False)
    str(self_attn_block)
    out_dict = self_attn_block(in_dict)
    assert self_attn_block.get_num_of_parameters() == 361
    assert len(out_dict.keys()) == len(self_attn_block.out_keys) == 1
    assert out_dict[self_attn_block.out_keys[0]].shape == (2, 10)
Example #5
def test_concat_block():
    """ perception test """

    for d0, d1 in [([100, 16], [100, 16]), ([100, 1, 16], [100, 1, 16])]:
        in_dict = build_multi_input_dict(dims=[d0, d1])
        net = ConcatenationBlock(in_keys=["in_key_0", "in_key_1"],
                                 out_keys="concat",
                                 in_shapes=[d0[1:], d1[1:]],
                                 concat_dim=-1)
        str(net)
        out_dict = net(in_dict)

        assert isinstance(out_dict, Dict)
        assert out_dict["concat"].shape[-1] == 32
        assert net.out_shapes() == [out_dict["concat"].shape[1:]]
Example #6
def test_attention_1d():
    """test_attention_1d"""
    in_dict = build_multi_input_dict(dims=[(2, 10), (2, 10), (2, 10)])

    self_attn_block = MultiHeadAttentionBlock(in_keys=['in_key_0', 'in_key_1', 'in_key_2'],
                                              out_keys=['added_attention', 'attention'], in_shapes=[(10,), (10,),
                                                                                                    (10,)],
                                              num_heads=10, dropout=0.0, bias=False,
                                              add_input_to_output=True, add_bias_kv=False, add_zero_attn=False,
                                              kdim=None, vdim=None, use_key_padding_mask=False)

    out_dict = self_attn_block(in_dict)
    assert self_attn_block.get_num_of_parameters() == 401
    assert len(out_dict.keys()) == len(self_attn_block.out_keys) == 2
    assert out_dict[self_attn_block.out_keys[0]].shape == (2, 10)
Example #7
def test_functional_block_multi_arg_lambda():
    """ perception test """
    in_dict = build_multi_input_dict(dims=[[100, 64, 1], [100, 64, 1]])

    net: FunctionalBlock = FunctionalBlock(
        in_keys=["in_key_0", 'in_key_1'],
        out_keys="out_key",
        in_shapes=[(100, 64, 1), (100, 64, 1)],
        func=lambda in_key_0, in_key_1: torch.cat(
            (in_key_0, in_key_1), dim=-1))
    str(net)
    out_dict = net(in_dict)

    assert isinstance(out_dict, Dict)
    assert set(net.out_keys).issubset(set(out_dict.keys()))
    assert out_dict[net.out_keys[0]].shape == (100, 64, 2)
Example #8
def test_self_attention_sequential_masked():
    """test_self_attention_sequential"""
    in_dict = build_multi_input_dict(dims=[(2, 7, 10), (2, 7, 7)])
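    # turn the second input into a boolean mask for the attention block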
    in_dict['in_key_1'] = in_dict['in_key_1'] != 0

    self_attn_block = SelfAttentionSeqBlock(in_keys=['in_key_0', 'in_key_1'],
                                            out_keys='self_attention',
                                            in_shapes=[(7, 10), (7, 7)],
                                            num_heads=10,
                                            dropout=0.0,
                                            bias=False,
                                            add_input_to_output=False)
    str(self_attn_block)
    out_dict = self_attn_block(in_dict)
    assert self_attn_block.get_num_of_parameters() == 401
    assert len(out_dict.keys()) == len(self_attn_block.out_keys) == 1
    assert out_dict[self_attn_block.out_keys[0]].shape == (2, 7, 10)
Example #9
def test_functional_block_multi_arg_order():
    """ perception test """
    in_dict = build_multi_input_dict(dims=[[100, 64], [100, 64, 1]])

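    # arguments are deliberately listed in swapped order here: the block
    # matches inputs to function parameters by name rather than by position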
    def my_func(in_key_1, in_key_0):
        squeeze_in_1 = torch.squeeze(in_key_1, dim=-1)
        return torch.cat((in_key_0, squeeze_in_1), dim=-1)

    net: FunctionalBlock = FunctionalBlock(in_keys=["in_key_0", 'in_key_1'],
                                           out_keys="out_key",
                                           in_shapes=[(100, 64), (100, 64, 1)],
                                           func=my_func)
    str(net)
    out_dict = net(in_dict)

    assert isinstance(out_dict, Dict)
    assert set(net.out_keys).issubset(set(out_dict.keys()))
    assert out_dict[net.out_keys[0]].shape == (100, 128)
Example #10
def test_functional_block_multi_arg_multi_out():
    """ perception test """
    in_dict = build_multi_input_dict(dims=[[100, 64, 32, 1], [100, 64, 1]])

    def my_func(in_key_1, in_key_0):
        return torch.squeeze(in_key_0), torch.squeeze(in_key_1)

    net: FunctionalBlock = FunctionalBlock(in_keys=["in_key_0", 'in_key_1'],
                                           out_keys=["out_key_0", 'out_key_1'],
                                           in_shapes=[(100, 64, 32, 1),
                                                      (100, 64, 1)],
                                           func=my_func)
    str(net)
    out_dict = net(in_dict)

    assert isinstance(out_dict, Dict)
    assert set(net.out_keys).issubset(set(out_dict.keys()))
    assert out_dict[net.out_keys[0]].shape == (100, 64, 32)
    assert out_dict[net.out_keys[1]].shape == (100, 64)
Example #11
def test_self_attention_2d():
    """test_self_attention_2d"""
    in_dict = build_multi_input_dict(dims=[(2, 16, 5, 5), (2, 25, 25)])
    in_dict['in_key_1'] = in_dict['in_key_1'] != 0

    self_attn_block = SelfAttentionConvBlock(
        in_keys=['in_key_0', 'in_key_1'],
        out_keys=['self_attention', 'attention'],
        in_shapes=[(16, 5, 5), (25, 25)],
        embed_dim=2,
        add_input_to_output=True,
        bias=True,
        dropout=None)
    str(self_attn_block)
    out_dict = self_attn_block(in_dict)
    assert len(out_dict.keys()) == len(self_attn_block.out_keys) == 2
    assert out_dict[
        self_attn_block.out_keys[0]].shape == in_dict['in_key_0'].shape
    assert out_dict[self_attn_block.out_keys[1]].shape == (2, 25, 25)
Example #12
def test_repeat_block():
    """ perception test """
    in_dict = build_multi_input_dict(dims=[[4, 1, 2], [4, 100, 2]])
    net: RepeatToMatchBlock = RepeatToMatchBlock(
        in_keys=["in_key_0", 'in_key_1'],
        out_keys="out_key",
        in_shapes=[(4, 1, 2), (4, 100, 2)],
        repeat_at_idx=-2)

    str(net)
    out_dict = net(in_dict)

    assert isinstance(out_dict, Dict)
    assert set(net.out_keys).issubset(set(out_dict.keys()))
    assert out_dict[net.out_keys[0]].shape == (4, 100, 2)
    for i in range(out_dict[net.out_keys[0]].shape[0]):
        for j in range(out_dict[net.out_keys[0]].shape[1]):
            assert all(
                torch.eq(out_dict[net.out_keys[0]][i][j],
                         in_dict['in_key_0'][i][0]))
Example #13
def build_perception_dict():
    """ helper function """
    in_dict = build_multi_input_dict(dims=[[100, 1, 16], [100, 1, 8]])

    perception_dict = dict()
    for in_key, in_tensor in in_dict.items():
        # compile network block
        net = DenseBlock(in_keys=in_key,
                         out_keys=f"{in_key}_feat",
                         in_shapes=[in_tensor.shape[-1:]],
                         hidden_units=[32, 32],
                         non_lin=nn.ReLU)
        perception_dict[f"{in_key}_feat"] = net

    net = ConcatenationBlock(in_keys=list(perception_dict.keys()),
                             out_keys="concat",
                             in_shapes=[(32, ), (32, )],
                             concat_dim=-1)
    perception_dict["concat"] = net

    return in_dict, perception_dict
Example #14
def test_mlp_and_concat():
    """ perception test """

    in_dict = build_multi_input_dict(dims=[[100, 1, 16], [100, 1, 8]])

    feat_dict = dict()
    for in_key, in_tensor in in_dict.items():
        # compile network block
        net = DenseBlock(in_keys=in_key,
                         out_keys=f"{in_key}_feat",
                         in_shapes=(in_tensor.shape[-1], ),
                         hidden_units=[32, 32],
                         non_lin=nn.ReLU)

        # update output dictionary
        feat_dict.update(net(in_dict))

    net = ConcatenationBlock(in_keys=list(feat_dict.keys()),
                             out_keys="concat",
                             in_shapes=[(32, ), (32, )],
                             concat_dim=-1)
    out_dict = net(feat_dict)
    assert out_dict["concat"].ndim == 3
    assert out_dict["concat"].shape[-1] == 64