def test_reshape_param_data():
    """Exercise `_reshape_param_data` with three dev_mat / tensor_map pairs.

    Each case feeds an input tensor plus a device matrix and tensor map to
    `_reshape_param_data` and compares the string form of the result with
    the string form of the expected tensor. A mismatch raises a bare
    `AssertionError`.
    """
    # Cases 1 and 2 share the same 2x4 expected tensor and the same dev_mat.
    want = Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
    dev_mat = [2, 2]

    # Case 1: tensor_map [0, 1] with a 4x2 input.
    src = Tensor([[1, 2], [5, 6], [3, 4], [7, 8]])
    got = _reshape_param_data(src, dev_mat, [0, 1])
    if str(want) != str(got):
        raise AssertionError

    # Case 2: tensor_map [1, -1] with a 4x4 input whose rows come in
    # duplicated pairs.
    src = Tensor([
        [1, 2, 3, 4],
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [5, 6, 7, 8],
    ])
    got = _reshape_param_data(src, dev_mat, [1, -1])
    if str(want) != str(got):
        raise AssertionError

    # Case 3: 4-D data, dev_mat [4], every tensor_map entry is -1.
    # The input repeats the same 2x2x2 block eight times along the first
    # axis; the expected result holds that block twice.
    block = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
    want = Tensor([block] * 2)
    src = Tensor([block] * 8)

    got = _reshape_param_data(src, [4], [-1, -1, -1, -1])
    if str(want) != str(got):
        raise AssertionError


# Example no. 2 (listing marker, apparently left by a code-example page; accompanying counter: 0)
def test_load_tensor():
    """Check the slice `_load_tensor` returns for two different rank ids.

    The test sets `rank_id` on an `Hccl` object, calls `_load_tensor` with a
    2x3 tensor, `dev_mat = [2, 3]` and `tensor_map = [1, -1]`, and compares
    the string form of the returned slice with the expected one:

    * rank_id 5 -> [[4, 5, 6]]
    * rank_id 2 -> [[1, 2, 3]]

    Raises:
        AssertionError: if a returned slice does not match the expectation.
    """
    hccl = Hccl()
    tensor = Tensor([[1, 2, 3], [4, 5, 6]])
    dev_mat = [2, 3]
    tensor_map = [1, -1]

    try:
        hccl.rank_id = 5
        tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
        expected_tensor = Tensor([[4, 5, 6]])
        expected_str = str(expected_tensor)
        actual_str = str(tensor_slice)
        if expected_str != actual_str:
            raise AssertionError(
                "rank_id=5: expected %s, got %s" % (expected_str, actual_str))

        hccl.rank_id = 2
        tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
        expected_tensor = Tensor([[1, 2, 3]])
        expected_str = str(expected_tensor)
        actual_str = str(tensor_slice)
        if expected_str != actual_str:
            raise AssertionError(
                "rank_id=2: expected %s, got %s" % (expected_str, actual_str))
    finally:
        # Set back to the default value even when a check above fails, so a
        # modified rank id is not left behind on the Hccl object.
        # NOTE(review): presumably Hccl state is shared with other tests --
        # confirm against the Hccl test double.
        hccl.rank_id = 0