def test_reshape_param_data():
    """Check that ``_reshape_param_data`` rebuilds the expected tensor.

    Three cases are exercised with different ``dev_mat`` / ``tensor_map``
    layouts.  Judging from the fixtures, each input looks like the slices of
    every device stacked along the first axis (NOTE(review): inferred from
    the test data, not from ``_reshape_param_data`` itself).  Results are
    compared with the expected tensor through their string representations,
    exactly as before.

    Raises:
        AssertionError: if any case produces an unexpected tensor; the message
            names the failing case and shows both tensors.
    """

    def _check(expected, actual, case):
        # Compare via str(), as the original test did, but report which case
        # failed and what was produced instead of raising a bare error.
        if str(expected) != str(actual):
            raise AssertionError(
                "{}: expected {}, got {}".format(case, expected, actual))

    expected_tensor = Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])

    # Case 1: 2x2 device matrix, both tensor dimensions mapped to a device axis.
    dev_mat = [2, 2]
    tensor_map = [0, 1]
    input_tensor = Tensor([[1, 2], [5, 6], [3, 4], [7, 8]])
    tensor = _reshape_param_data(input_tensor, dev_mat, tensor_map)
    _check(expected_tensor, tensor, "dev_mat=[2, 2], tensor_map=[0, 1]")

    # Case 2: same device matrix; -1 presumably marks an unsplit dimension,
    # hence the duplicated rows in the input (TODO confirm).
    tensor_map = [1, -1]
    input_tensor = Tensor([[1, 2, 3, 4],
                           [1, 2, 3, 4],
                           [5, 6, 7, 8],
                           [5, 6, 7, 8]])
    tensor = _reshape_param_data(input_tensor, dev_mat, tensor_map)
    _check(expected_tensor, tensor, "dev_mat=[2, 2], tensor_map=[1, -1]")

    # Case 3: 4 devices and no dimension mapped (all -1).  The input is eight
    # copies of the same 2x2x2 block; the expected result is two copies.
    block = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
    expected_tensor = Tensor([block] * 2)
    input_tensor = Tensor([block] * 8)
    dev_mat = [4]
    tensor_map = [-1, -1, -1, -1]
    tensor = _reshape_param_data(input_tensor, dev_mat, tensor_map)
    _check(expected_tensor, tensor, "dev_mat=[4], tensor_map=[-1, -1, -1, -1]")
def test_load_tensor():
    """Check that ``_load_tensor`` returns the slice belonging to the current rank.

    The current rank is simulated by assigning ``rank_id`` on the ``Hccl``
    test helper.  The rank is always restored to the default value (0)
    afterwards, including when a check fails.  Without that reset, a failure
    would leave the simulated rank changed for later tests.  NOTE(review):
    this presumes ``Hccl`` state is shared across tests — confirm.

    Raises:
        AssertionError: if the slice returned for a simulated rank differs
            from the expected one; the message names the rank and both tensors.
    """
    hccl = Hccl()
    tensor = Tensor([[1, 2, 3], [4, 5, 6]])
    dev_mat = [2, 3]
    tensor_map = [1, -1]
    try:
        # Rank 5 is expected to receive the second row.
        hccl.rank_id = 5
        tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
        expected_tensor = Tensor([[4, 5, 6]])
        if str(expected_tensor) != str(tensor_slice):
            raise AssertionError(
                "rank_id=5: expected {}, got {}".format(expected_tensor, tensor_slice))

        # Rank 2 is expected to receive the first row.
        hccl.rank_id = 2
        tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
        expected_tensor = Tensor([[1, 2, 3]])
        if str(expected_tensor) != str(tensor_slice):
            raise AssertionError(
                "rank_id=2: expected {}, got {}".format(expected_tensor, tensor_slice))
    finally:
        # Set back to the default value even on failure.
        hccl.rank_id = 0