Ejemplo n.º 1
0
    def test_random_case(self):
        """Compose two randomly generated arc maps and verify every entry.

        For each device under test:

        * the first map has a random length in ``[1, 99]`` and values in
          ``[-1, hi]``, where ``hi`` is itself drawn from ``[0, 99]``;
        * the second map has a random length in ``[1, 99]`` and values in
          ``[-1, len(first) - 1]``, i.e. each entry is -1 or a valid index
          into the first map.

        The composed map must sit on the same device, have one entry per
        element of the second map, be -1 wherever the second map is -1, and
        otherwise equal ``first[second[i]]``.
        """
        for device in self.devices:
            # Scalars are drawn in a fixed order so the RNG stream is
            # consumed the same way on every run with the same seed.
            first_len = torch.randint(low=1, high=100, size=(1,)).item()
            first_hi = torch.randint(low=0, high=100, size=(1,)).item()
            second_len = torch.randint(low=1, high=100, size=(1,)).item()

            first_map = torch.randint(low=-1,
                                      high=first_hi + 1,
                                      size=(first_len,),
                                      dtype=torch.int32,
                                      device=device)

            # Upper bound is exclusive, so entries land in [-1, first_len - 1].
            second_map = torch.randint(low=-1,
                                       high=first_len,
                                       size=(second_len,),
                                       dtype=torch.int32,
                                       device=device)

            composed = k2.compose_arc_maps(first_map, second_map)
            assert composed.device == device

            first_vals = first_map.tolist()
            second_vals = second_map.tolist()
            composed_vals = composed.tolist()

            assert len(second_vals) == len(composed_vals)
            for src, got in zip(second_vals, composed_vals):
                # -1 in the second map must propagate; otherwise look the
                # index up in the first map.
                want = -1 if src == -1 else first_vals[src]
                assert got == want
Ejemplo n.º 2
0
    def test_simple_case(self):
        """Check ``k2.compose_arc_maps`` against a small hand-computed example.

        Output entry ``i`` is expected to be ``step1_arc_map[step2_arc_map[i]]``,
        or -1 where ``step2_arc_map[i]`` is -1 (see ``expected_arc_map``).
        The result must also be on the same device as the inputs and have
        exactly the expected shape.
        """
        for device in self.devices:
            #                             0   1  2  3  4   5  6
            step1_arc_map = torch.tensor([-1, 0, 0, 2, 3, -1, 3],
                                         dtype=torch.int32,
                                         device=device)

            #                             0  1  2  3   4  5   6  7  8  9
            step2_arc_map = torch.tensor([0, 6, 3, -1, 4, 3, -1, 2, 1, 0],
                                         dtype=torch.int32,
                                         device=device)

            ans_arc_map = k2.compose_arc_maps(step1_arc_map, step2_arc_map)
            # The composed map must stay on the device of its inputs.
            assert ans_arc_map.device == device

            expected_arc_map = torch.tensor([-1, 3, 2, -1, 3, 2, -1, 0, 0, -1],
                                            dtype=torch.int32,
                                            device=device)
            # torch.equal requires identical sizes as well as elements; an
            # element-wise torch.eq + torch.all would silently broadcast a
            # wrongly shaped (but broadcastable) result.
            assert torch.equal(ans_arc_map, expected_arc_map)