def test_random_case(self):
    """Compare k2.compose_arc_maps with a pure-Python reference on random maps.

    For each device in ``self.devices``, two random int32 arc maps are drawn:

    * the first has a random length in ``[1, 99]`` and entries in
      ``[-1, upper]``, where ``upper`` is itself drawn from ``[0, 99]``;
    * the second has a random length in ``[1, 99]`` and entries in
      ``[-1, len(first) - 1]``, so every entry is either ``-1`` or a valid
      index into the first map.

    The composed map must live on the same device and have one entry per
    entry of the second map. Each entry is ``-1`` where the second map is
    ``-1``; otherwise it is ``first[second[i]]``.
    """
    for dev in self.devices:
        # Scalar draws, kept in the original order: first length,
        # first upper bound, second length.
        n_first = torch.randint(low=1, high=100, size=(1,)).item()
        upper_first = torch.randint(low=0, high=100, size=(1,)).item()
        n_second = torch.randint(low=1, high=100, size=(1,)).item()

        first = torch.randint(low=-1,
                              high=upper_first + 1,
                              size=(n_first,),
                              dtype=torch.int32,
                              device=dev)
        # Upper bound (exclusive) is n_first, so entries index into `first`.
        second = torch.randint(low=-1,
                               high=n_first,
                               size=(n_second,),
                               dtype=torch.int32,
                               device=dev)

        composed = k2.compose_arc_maps(first, second)
        assert composed.device == dev

        first_vals = first.tolist()
        second_vals = second.tolist()
        composed_vals = composed.tolist()
        assert len(second_vals) == len(composed_vals)

        # Reference composition: -1 propagates, otherwise look up `first`.
        for src, got in zip(second_vals, composed_vals):
            want = -1 if src == -1 else first_vals[src]
            assert got == want
def test_simple_case(self):
    """Check k2.compose_arc_maps on a small hand-computed example.

    Entry ``i`` of the expected result is ``-1`` where ``step2_arc_map[i]``
    is ``-1``, and ``step1_arc_map[step2_arc_map[i]]`` otherwise. The
    result must also be on the input device and have exactly the expected
    shape.
    """
    for device in self.devices:
        #                             0  1  2  3  4   5  6
        step1_arc_map = torch.tensor([-1, 0, 0, 2, 3, -1, 3],
                                     dtype=torch.int32,
                                     device=device)
        #                             0  1  2   3  4  5   6  7  8  9
        step2_arc_map = torch.tensor([0, 6, 3, -1, 4, 3, -1, 2, 1, 0],
                                     dtype=torch.int32,
                                     device=device)

        ans_arc_map = k2.compose_arc_maps(step1_arc_map, step2_arc_map)
        assert ans_arc_map.device == device

        expected_arc_map = torch.tensor([-1, 3, 2, -1, 3, 2, -1, 0, 0, -1],
                                        dtype=torch.int32,
                                        device=device)
        # torch.eq broadcasts its operands, so a result with correct values
        # but a wrong shape (e.g. an extra leading dimension) would still
        # pass the element-wise check; pin the shape explicitly.
        assert ans_arc_map.shape == expected_arc_map.shape
        assert torch.all(torch.eq(ans_arc_map, expected_arc_map))