def test_correlation_block():
    """Check CorrelationBlock output shapes with and without reduction.

    Iterates over input pairs of equal rank, mixed rank (2-D vs 3-D) and
    3-D inputs with different middle sizes. For each pair it verifies that the
    unreduced correlation keeps a last axis of 16 and agrees with
    ``out_shapes()``, and that ``reduce=True`` yields exactly one dimension
    fewer than ``reduce=False``.
    """
    shape_pairs = [
        ([100, 16], [100, 16]),
        ([100, 5, 16], [100, 5, 16]),
        ([100, 5, 16], [100, 16]),
        ([100, 16], [100, 5, 16]),
        ([100, 25, 16], [100, 10, 16]),
    ]
    for dims_a, dims_b in shape_pairs:
        in_dict = build_multi_input_dict(dims=[dims_a, dims_b])

        # Variant 1: no reduction. in_shapes drop the leading axis of dims.
        plain_block = CorrelationBlock(in_keys=["in_key_0", "in_key_1"],
                                       out_keys="correlation",
                                       in_shapes=[dims_a[1:], dims_b[1:]],
                                       reduce=False)
        str(plain_block)
        plain_out = plain_block(in_dict)
        assert isinstance(plain_out, Dict)
        plain_corr = plain_out["correlation"]
        assert plain_corr.shape[-1] == 16
        assert plain_block.out_shapes() == [plain_corr.shape[1:]]

        # Variant 2: with reduction, which must remove exactly one axis.
        reducing_block = CorrelationBlock(in_keys=["in_key_0", "in_key_1"],
                                          out_keys="correlation",
                                          in_shapes=[dims_a[1:], dims_b[1:]],
                                          reduce=True)
        str(reducing_block)
        reduced_out = reducing_block(in_dict)
        assert isinstance(reduced_out, Dict)
        assert reduced_out["correlation"].ndim == plain_corr.ndim - 1
def perform_masked_global_pooling_test_dim3_pool_3(feature_dim_1: int, feature_dim_2: int,
                                                   feature_dim_3: int, use_masking: bool,
                                                   pooling_func_name: str):
    """Run MaskedGlobalPoolingBlock over the last axis of a 4-D input.

    Two inputs are generated: a (batch, d1, d2, d3) feature tensor and a
    (batch, d1, d2) tensor. The second one is wired into the block (as the
    mask input) only when ``use_masking`` is set. The block pools over
    ``pooling_dim=-1`` using ``pooling_func_name``, and the result must have
    shape (batch, d1, d2).
    """
    batch_dim = 4
    feature_shape = (feature_dim_1, feature_dim_2, feature_dim_3)
    mask_shape = (feature_dim_1, feature_dim_2)
    in_dict = build_multi_input_dict(
        dims=[[batch_dim, *feature_shape], [batch_dim, *mask_shape]])

    if use_masking:
        block_keys = ['in_key_0', 'in_key_1']
        block_shapes = [feature_shape, mask_shape]
    else:
        block_keys = ["in_key_0"]
        block_shapes = [feature_shape]

    net: MaskedGlobalPoolingBlock = MaskedGlobalPoolingBlock(
        in_keys=block_keys, out_keys="out_key", in_shapes=block_shapes,
        pooling_func=pooling_func_name, pooling_dim=-1)

    # Without masking, only the feature tensor is passed to the block.
    block_input = in_dict if use_masking else {'in_key_0': in_dict['in_key_0']}
    out_dict = net(block_input)
    str(net)

    assert set(net.out_keys).issubset(set(out_dict.keys()))
    assert out_dict[net.out_keys[0]].shape == (batch_dim, feature_dim_1, feature_dim_2)
def test_torch_model_block():
    """Wrap a plain PyTorch module in a TorchModelBlock and check output shapes.

    The block is called once with inputs that lack the leading batch axis
    (3-D and 1-D) and once with a batch of 8. Both output tensors must have
    the expected shapes in each case. Runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        wrapped_model = CustomPytorchModel()
        torch_model_block = TorchModelBlock(
            in_keys=['in_key_0', 'in_key_1'],
            out_keys=['out_key_0', 'out_key_1'],
            in_shapes=[[3, 5, 5], [16]],
            in_num_dims=[4, 2],
            out_num_dims=[4, 2],
            net=wrapped_model)

        # (input dims, expected out_key_0 shape, expected out_key_1 shape)
        cases = [
            ([[3, 5, 5], [16]], (4, 5, 5), (32,)),
            ([[8, 3, 5, 5], [8, 16]], (8, 4, 5, 5), (8, 32)),
        ]
        for dims, expected_first, expected_second in cases:
            out_dict = torch_model_block(build_multi_input_dict(dims=dims))
            assert out_dict['out_key_0'].numpy().shape == expected_first
            assert out_dict['out_key_1'].numpy().shape == expected_second
def test_attention_sequential_2():
    """MultiHeadAttentionBlock with inputs of three different widths.

    The first input is 10 wide, the second is 7 wide (matching ``kdim=7``) and
    the third is 9 wide (matching ``vdim=9``). The test checks the parameter
    count, that exactly one output key is produced, and that its shape is
    (2, 10).
    """
    in_dict = build_multi_input_dict(dims=[(2, 10), (2, 7), (2, 9)])
    attention_options = dict(num_heads=10, dropout=0.0, bias=False,
                             add_input_to_output=True, add_bias_kv=False,
                             add_zero_attn=False, kdim=7, vdim=9,
                             use_key_padding_mask=False)
    block = MultiHeadAttentionBlock(in_keys=['in_key_0', 'in_key_1', 'in_key_2'],
                                    out_keys='self_attention',
                                    in_shapes=[(10,), (7,), (9,)],
                                    **attention_options)
    str(block)
    out_dict = block(in_dict)

    assert block.get_num_of_parameters() == 361
    assert len(out_dict.keys()) == len(block.out_keys) == 1
    assert out_dict[block.out_keys[0]].shape == (2, 10)
def test_concat_block():
    """Concatenate two equally shaped inputs along the last axis.

    Runs for (100, 16) and (100, 1, 16) input pairs. It checks that the last
    axis grows to 32 and that ``out_shapes()`` matches the produced tensor
    without its leading axis.
    """
    shape_pairs = ([[100, 16], [100, 16]], [[100, 1, 16], [100, 1, 16]])
    for pair in shape_pairs:
        in_dict = build_multi_input_dict(dims=list(pair))
        block = ConcatenationBlock(in_keys=["in_key_0", "in_key_1"],
                                   out_keys="concat",
                                   in_shapes=[dims[1:] for dims in pair],
                                   concat_dim=-1)
        str(block)
        out_dict = block(in_dict)
        assert isinstance(out_dict, Dict)
        concatenated = out_dict["concat"]
        assert concatenated.shape[-1] == 32
        assert block.out_shapes() == [concatenated.shape[1:]]
def test_attention_1d():
    """MultiHeadAttentionBlock on three 10-wide inputs with two output keys.

    The test checks the parameter count and that both requested outputs are
    produced. It also checks that the first output has shape (2, 10).
    """
    feature_shapes = [(10,), (10,), (10,)]
    in_dict = build_multi_input_dict(dims=[(2,) + shape for shape in feature_shapes])
    block = MultiHeadAttentionBlock(in_keys=['in_key_0', 'in_key_1', 'in_key_2'],
                                    out_keys=['added_attention', 'attention'],
                                    in_shapes=feature_shapes,
                                    num_heads=10,
                                    dropout=0.0,
                                    bias=False,
                                    add_input_to_output=True,
                                    add_bias_kv=False,
                                    add_zero_attn=False,
                                    kdim=None,
                                    vdim=None,
                                    use_key_padding_mask=False)
    out_dict = block(in_dict)

    assert block.get_num_of_parameters() == 401
    assert len(out_dict.keys()) == len(block.out_keys) == 2
    first_output = out_dict[block.out_keys[0]]
    assert first_output.shape == (2, 10)
def test_functional_block_multi_arg_lambda():
    """FunctionalBlock driven by a two-argument lambda that concatenates inputs.

    The lambda's parameter names are the same as the block's in_keys.
    Concatenating two (100, 64, 1) tensors on the last axis must give
    (100, 64, 2).
    """
    in_dict = build_multi_input_dict(dims=[[100, 64, 1], [100, 64, 1]])
    net: FunctionalBlock = FunctionalBlock(
        in_keys=["in_key_0", 'in_key_1'],
        out_keys="out_key",
        in_shapes=[(100, 64, 1)] * 2,
        func=lambda in_key_0, in_key_1: torch.cat((in_key_0, in_key_1), dim=-1))
    str(net)
    out_dict = net(in_dict)

    assert isinstance(out_dict, Dict)
    assert set(net.out_keys).issubset(set(out_dict.keys()))
    result = out_dict[net.out_keys[0]]
    assert result.shape == (100, 64, 2)
def test_self_attention_sequential_masked():
    """SelfAttentionSeqBlock on a (2, 7, 10) input plus a boolean (2, 7, 7) input.

    The second input is converted to booleans (``!= 0``) before use. The test
    checks the parameter count and that there is a single output. It also
    checks that the output has shape (2, 7, 10).
    """
    in_dict = build_multi_input_dict(dims=[(2, 7, 10), (2, 7, 7)])
    # Turn the second input into a boolean tensor used as the mask.
    in_dict['in_key_1'] = in_dict['in_key_1'].ne(0)

    block = SelfAttentionSeqBlock(in_keys=['in_key_0', 'in_key_1'],
                                  out_keys='self_attention',
                                  in_shapes=[(7, 10), (7, 7)],
                                  num_heads=10,
                                  dropout=0.0,
                                  bias=False,
                                  add_input_to_output=False)
    str(block)
    out_dict = block(in_dict)

    assert block.get_num_of_parameters() == 401
    assert len(out_dict.keys()) == len(block.out_keys) == 1
    assert out_dict[block.out_keys[0]].shape == (2, 7, 10)
def test_functional_block_multi_arg_order():
    """FunctionalBlock whose function declares its parameters in reverse order.

    ``my_func`` takes ``(in_key_1, in_key_0)``, while the block's in_keys are
    ordered ``["in_key_0", "in_key_1"]``. The (100, 128) result only arises if
    each tensor reaches the parameter of the same name. A positional swap would
    try to concatenate (100, 64, 1) with (100, 64) and fail.
    """
    in_dict = build_multi_input_dict(dims=[[100, 64], [100, 64, 1]])

    # Parameter names and order are deliberate; do not rename or reorder.
    def my_func(in_key_1, in_key_0):
        return torch.cat((in_key_0, in_key_1.squeeze(dim=-1)), dim=-1)

    net: FunctionalBlock = FunctionalBlock(in_keys=["in_key_0", 'in_key_1'],
                                           out_keys="out_key",
                                           in_shapes=[(100, 64), (100, 64, 1)],
                                           func=my_func)
    str(net)
    out_dict = net(in_dict)

    assert isinstance(out_dict, Dict)
    assert set(net.out_keys).issubset(set(out_dict.keys()))
    assert out_dict[net.out_keys[0]].shape == (100, 128)
def test_functional_block_multi_arg_multi_out():
    """FunctionalBlock whose function takes reversed parameters and returns two tensors.

    ``my_func`` squeezes both inputs and returns them in ``(in_key_0,
    in_key_1)`` order. Because the two inputs have different shapes, the
    per-key output shapes confirm both the argument routing and the output
    ordering.
    """
    in_dict = build_multi_input_dict(dims=[[100, 64, 32, 1], [100, 64, 1]])

    # Parameter names and order are deliberate; do not rename or reorder.
    def my_func(in_key_1, in_key_0):
        return in_key_0.squeeze(), in_key_1.squeeze()

    net: FunctionalBlock = FunctionalBlock(in_keys=["in_key_0", 'in_key_1'],
                                           out_keys=["out_key_0", 'out_key_1'],
                                           in_shapes=[(100, 64, 32, 1), (100, 64, 1)],
                                           func=my_func)
    str(net)
    out_dict = net(in_dict)

    assert isinstance(out_dict, Dict)
    assert set(net.out_keys).issubset(set(out_dict.keys()))
    first_key, second_key = net.out_keys[0], net.out_keys[1]
    assert out_dict[first_key].shape == (100, 64, 32)
    assert out_dict[second_key].shape == (100, 64)
def test_self_attention_2d():
    """SelfAttentionConvBlock on a (2, 16, 5, 5) input plus a boolean (2, 25, 25) input.

    The second input is converted to booleans (``!= 0``) before use. The test
    checks that both outputs are produced. The first output must keep the
    shape of the first input, and the second must be (2, 25, 25).
    """
    in_dict = build_multi_input_dict(dims=[(2, 16, 5, 5), (2, 25, 25)])
    # Turn the second input into a boolean tensor used as the mask.
    in_dict['in_key_1'] = in_dict['in_key_1'].ne(0)

    block = SelfAttentionConvBlock(in_keys=['in_key_0', 'in_key_1'],
                                   out_keys=['self_attention', 'attention'],
                                   in_shapes=[(16, 5, 5), (25, 25)],
                                   embed_dim=2,
                                   add_input_to_output=True,
                                   bias=True,
                                   dropout=None)
    str(block)
    out_dict = block(in_dict)

    assert len(out_dict.keys()) == len(block.out_keys) == 2
    attended, attention_map = (out_dict[key] for key in block.out_keys[:2])
    assert attended.shape == in_dict['in_key_0'].shape
    assert attention_map.shape == (2, 25, 25)
def test_repeat_block():
    """Check that RepeatToMatchBlock tiles its first input to match the second.

    ``in_key_0`` has a singleton axis at index -2 (shape (4, 1, 2)). With
    ``repeat_at_idx=-2`` the output must have shape (4, 100, 2), and every
    row along that axis must equal the single source row of ``in_key_0``.
    """
    in_dict = build_multi_input_dict(dims=[[4, 1, 2], [4, 100, 2]])
    net: RepeatToMatchBlock = RepeatToMatchBlock(
        in_keys=["in_key_0", 'in_key_1'],
        out_keys="out_key",
        in_shapes=[(4, 1, 2), (4, 100, 2)],
        repeat_at_idx=-2)
    str(net)
    out_dict = net(in_dict)

    assert isinstance(out_dict, Dict)
    assert set(net.out_keys).issubset(set(out_dict.keys()))
    repeated = out_dict[net.out_keys[0]]
    assert repeated.shape == (4, 100, 2)
    # Broadcasting (4, 1, 2) against (4, 100, 2) checks
    # repeated[i][j] == in_key_0[i][0] for every i, j in one vectorized op,
    # instead of 400 Python-level tensor comparisons in nested loops.
    assert torch.eq(repeated, in_dict['in_key_0']).all()
def build_perception_dict():
    """Helper: build test inputs plus a small perception-block graph over them.

    Returns ``(in_dict, perception_dict)``:

    - ``in_dict`` holds two inputs of shapes (100, 1, 16) and (100, 1, 8).
    - ``perception_dict`` maps ``"<in_key>_feat"`` to a DenseBlock (hidden
      units [32, 32], ReLU) for each input. A final ``"concat"`` entry holds a
      ConcatenationBlock that joins all feature keys on the last axis.
    """
    in_dict = build_multi_input_dict(dims=[[100, 1, 16], [100, 1, 8]])

    # One dense feature extractor per input, created in input order.
    perception_dict = {
        f"{key}_feat": DenseBlock(in_keys=key, out_keys=f"{key}_feat",
                                  in_shapes=[tensor.shape[-1:]],
                                  hidden_units=[32, 32], non_lin=nn.ReLU)
        for key, tensor in in_dict.items()
    }

    # The key list is taken before "concat" is inserted, so it only covers features.
    perception_dict["concat"] = ConcatenationBlock(in_keys=list(perception_dict),
                                                   out_keys="concat",
                                                   in_shapes=[(32,), (32,)],
                                                   concat_dim=-1)
    return in_dict, perception_dict
def test_mlp_and_concat():
    """Run one DenseBlock per input, then concatenate the resulting features.

    Each of the two inputs goes through a DenseBlock with hidden units
    [32, 32]. Concatenating the two feature tensors must keep three dimensions
    and give 64 features on the last axis.
    """
    in_dict = build_multi_input_dict(dims=[[100, 1, 16], [100, 1, 8]])

    features = {}
    for key, tensor in in_dict.items():
        dense = DenseBlock(in_keys=key, out_keys=f"{key}_feat",
                           in_shapes=(tensor.shape[-1],),
                           hidden_units=[32, 32], non_lin=nn.ReLU)
        # Merge this block's output dictionary into the collected features.
        features.update(dense(in_dict))

    concat = ConcatenationBlock(in_keys=list(features), out_keys="concat",
                                in_shapes=[(32,), (32,)], concat_dim=-1)
    concatenated = concat(features)["concat"]

    assert concatenated.ndim == 3
    assert concatenated.shape[-1] == 64