from mindspore import Tensor
from mindspore.parallel._tensor import _reshape_param_data


def test_reshape_param_data():
    expected_tensor = Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
    dev_mat = [2, 2]
    tensor_map = [0, 1]
    input_tensor = Tensor([[1, 2], [5, 6], [3, 4], [7, 8]])
    tensor = _reshape_param_data(input_tensor, dev_mat, tensor_map)
    if expected_tensor.__str__() != tensor.__str__():
        raise AssertionError

    tensor_map = [1, -1]
    input_tensor = Tensor([[1, 2, 3, 4], [1, 2, 3, 4], [5, 6, 7, 8],
                           [5, 6, 7, 8]])
    tensor = _reshape_param_data(input_tensor, dev_mat, tensor_map)
    if expected_tensor.__str__() != tensor.__str__():
        raise AssertionError

    expected_tensor = Tensor([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                              [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]])

    input_tensor = Tensor([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                           [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                           [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                           [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                           [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                           [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                           [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                           [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]])

    dev_mat = [4]
    tensor_map = [-1, -1, -1, -1]
    tensor = _reshape_param_data(input_tensor, dev_mat, tensor_map)
    if expected_tensor.__str__() != tensor.__str__():
        raise AssertionError
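The first case above also fixes the layout convention this test relies on: with dev_mat [2, 2] and tensor_map [0, 1], tensor dimension 0 is split along the last device-matrix dimension and tensor dimension 1 along the first. A minimal NumPy sketch (not MindSpore code; the rank arithmetic is inferred from the test data) that reproduces the rank-ordered slices the test feeds to _reshape_param_data:

import numpy as np

full = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
slices = []
for rank in range(4):        # device matrix [2, 2] -> 4 ranks
    row_block = rank % 2     # tensor dim 0 follows the last device dimension
    col_block = rank // 2    # tensor dim 1 follows the first device dimension
    slices.append(full[row_block:row_block + 1, col_block * 2:(col_block + 1) * 2])

# Stacking the per-rank slices in rank order yields the input_tensor of the first case;
# _reshape_param_data rearranges it back into the full 2x4 tensor.
assert (np.concatenate(slices) == np.array([[1, 2], [5, 6], [3, 4], [7, 8]])).all()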
Example #2
from mindspore import log as logger


def _get_merged_param_data(net, param_name, param_data):
    """
    Gets the merged data (tensor) from a tensor slice, according to the device arrangement and tensor map.

    Args:
        net (Cell): MindSpore network.
        param_name (str): Name of the parameter to be combined.
        param_data (Tensor): The parameter data on the local device,
                             which is a slice of the whole parameter data.
    Returns:
        Tensor, the combined tensor holding the whole data value.
    """
    layout = net.parameter_layout_dict[param_name]
    if len(layout) < 2:
        logger.info("layout dict does not contain the key %s", param_name)
        return param_data

    dev_mat = layout[0]
    tensor_map = layout[1]

    from mindspore.parallel._cell_wrapper import get_allgather_cell
    from mindspore.parallel._tensor import _reshape_param_data
    # if any dim is not -1, the parameter is split and needs to be merged
    for dim in tensor_map:
        if dim != -1:
            allgather_net = get_allgather_cell()
            param_data = allgather_net(param_data)
            return _reshape_param_data(param_data, dev_mat, tensor_map)

    return param_data
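A hedged sketch of how a checkpoint-saving loop might call _get_merged_param_data: walk the network's parameters and replace each local slice with the merged full tensor. The helper name _collect_merged_params and the returned dict are illustrative, not MindSpore API.

def _collect_merged_params(net):
    """Illustrative sketch: gather the full tensor for every sliced parameter on this rank."""
    merged = {}
    for _, param in net.parameters_and_names():
        data = param.data
        if param.name in net.parameter_layout_dict:
            # merge the local slice using the device matrix / tensor map recorded for it
            data = _get_merged_param_data(net, param.name, data)
        merged[param.name] = data
    return merged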
Example #3
from mindspore import log as logger


def _get_merged_param_data(net, param_name, param_data, integrated_save):
    """
    Gets the merged data (tensor) from a tensor slice, according to the device arrangement and tensor map.

    Args:
        net (Cell): MindSpore network.
        param_name (str): Name of the parameter to be combined.
        param_data (Tensor): The parameter data on the local device, which is a slice of the whole parameter data.
        integrated_save (bool): Whether to save the merged (integrated) parameter in the automatic model parallel scenario.
    Returns:
        Tensor, the combined tensor holding the whole data value.
    """
    from mindspore.parallel._cell_wrapper import get_allgather_cell
    from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
    layout = net.parameter_layout_dict[param_name]
    if len(layout) < 6:
        logger.info("layout dict does not contain the key %s", param_name)
        return param_data

    dev_mat = layout[0]
    tensor_map = layout[1]
    field_size = layout[3]
    uniform_split = layout[4]
    opt_shard_group = layout[5]

    allgather_net = None
    if param_name in net.parallel_parameter_merge_net_dict:
        allgather_net = net.parallel_parameter_merge_net_dict[param_name]
    else:
        logger.info("need to create allgather net for %s", param_name)

    if integrated_save:
        if uniform_split == 0:
            raise RuntimeError("Integrated save checkpoint only support uniform split tensor now.")
        # if any dim is not -1, the parameter is split and needs to be merged
        # pipeline parallelism needs to be supported here later
        for dim in tensor_map:
            if dim != -1:
                if allgather_net is None:
                    if opt_shard_group:
                        allgather_net = get_allgather_cell(opt_shard_group, True)
                    else:
                        allgather_net = get_allgather_cell(opt_shard_group, False)
                    net.parallel_parameter_merge_net_dict[param_name] = allgather_net
                param_data = allgather_net(param_data)
                if field_size:
                    return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
                return _reshape_param_data(param_data, dev_mat, tensor_map)
        if opt_shard_group:
            if allgather_net is None:
                allgather_net = get_allgather_cell(opt_shard_group, False)
                net.parallel_parameter_merge_net_dict[param_name] = allgather_net
            param_data = allgather_net(param_data)
    elif opt_shard_group:
        if allgather_net is None:
            allgather_net = get_allgather_cell(opt_shard_group, False)
            net.parallel_parameter_merge_net_dict[param_name] = allgather_net
        param_data = allgather_net(param_data)
    return param_data
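In practice integrated_save is driven by the public checkpoint API rather than set by hand. A minimal sketch, assuming the standard save_checkpoint signature with an integrated_save flag and a placeholder network: when the flag is True, each rank merges its parameter slices through _get_merged_param_data before serializing, so the checkpoint holds full tensors; when it is False, the local slices are saved as-is.

import mindspore as ms
from mindspore import nn

net = nn.Dense(4, 2)  # placeholder; a parallel-trained Cell in real use
ms.save_checkpoint(net, "merged.ckpt", integrated_save=True)   # merge slices, save full tensors
ms.save_checkpoint(net, "sliced.ckpt", integrated_save=False)  # keep per-rank slices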
Example #4
from mindspore import log as logger


def _get_merged_param_data(net, param_name, param_data):
    """
    Gets the merged data (tensor) from a tensor slice, according to the device arrangement and tensor map.

    Args:
        net (Cell): MindSpore network.
        param_name (str): Name of the parameter to be combined.
        param_data (Tensor): The parameter data on the local device,
                             which is a slice of the whole parameter data.
    Returns:
        Tensor, the combined tensor holding the whole data value.
    """
    layout = net.parameter_layout_dict[param_name]
    if len(layout) < 6:
        logger.info("layout dict does not contain the key %s", param_name)
        return param_data

    dev_mat = layout[0]
    tensor_map = layout[1]
    field_size = layout[3]
    uniform_split = layout[4]
    opt_shard_group = layout[5]
    if uniform_split == 0:
        raise RuntimeError(
            "Save checkpoint only support uniform split tensor now.")

    from mindspore.parallel._cell_wrapper import get_allgather_cell
    from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
    # if any dim is not -1, the parameter is split and needs to be merged
    # pipeline parallelism needs to be supported here later
    for dim in tensor_map:
        if dim != -1 or opt_shard_group:
            allgather_net = get_allgather_cell(opt_shard_group)
            param_data = allgather_net(param_data)
            if field_size:
                return _reshape_param_data_with_weight(param_data, dev_mat,
                                                       field_size)
            return _reshape_param_data(param_data, dev_mat, tensor_map)

    return param_data
Example #5
import numpy as np
from mindspore import Tensor


def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
    """
    Merge data slices to one tensor with whole data when strategy is not None.

    Args:
        sliced_data (list[numpy.ndarray]): Data slices in order of rank_id.
        parameter_name (str): Name of parameter.
        strategy (dict): Parameter slice strategy.
        is_even (bool): Slice manner; True means the parameter was sliced evenly, False means unevenly.

    Returns:
        Tensor, the merged Tensor which has the whole data.

    Raises:
        ValueError: Failed to merge.
    """
    layout = strategy.get(parameter_name)
    try:
        dev_mat = list(layout.dev_matrix[0].dim)
        tensor_map = list(layout.tensor_map[0].dim)
        param_split_shape = list(layout.param_split_shape[0].dim)
        field_size = int(layout.field)
    except BaseException as e:
        raise ValueError(
            f"{e.__str__()}. Please make sure that strategy matches the node_strategy.proto."
        )

    device_count = 1
    for dim in dev_mat:
        device_count *= dim

    if len(sliced_data) != device_count:
        raise ValueError(
            f"The sliced_parameters length should be equal to device_count. "
            f"the sliced_parameters length is {len(sliced_data)} but device_count is {device_count}."
        )

    merged_tensor = None
    if not param_split_shape:
        if not is_even:
            raise ValueError(
                "The shape of every parameter in sliced_parameters should be the same "
                "when slice manner is even.")

        all_gather_tensor = Tensor(np.concatenate(sliced_data))

        if field_size > 0:
            from mindspore.parallel._tensor import _reshape_param_data_with_weight
            merged_tensor = _reshape_param_data_with_weight(
                all_gather_tensor, dev_mat, field_size)

        else:
            from mindspore.parallel._tensor import _reshape_param_data
            merged_tensor = _reshape_param_data(all_gather_tensor, dev_mat,
                                                tensor_map)

    else:
        from mindspore.parallel._tensor import _get_tensor_strategy, _get_tensor_slice_index
        tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)

        slice_count = 1
        for dim in tensor_strategy:
            slice_count *= dim

        if len(param_split_shape) != slice_count:
            raise ValueError(
                f"The param_split_shape length in strategy should be {slice_count}, "
                f"but got {len(param_split_shape)}.")

        tensor_slices_new = list(range(slice_count))
        tensor_slices = sliced_data
        for i in range(device_count):
            slice_index = int(
                _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map,
                                        i))
            if tensor_slices[i].shape[0] != param_split_shape[slice_index]:
                raise ValueError(
                    f"Slice {slice_index} should have length {param_split_shape[slice_index]} "
                    f"along axis 0, but got {tensor_slices[i].shape[0]}.")
            tensor_slices_new[slice_index] = np.array(tensor_slices[i])

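        # merge the slices back dimension by dimension, starting from the last tensor dimension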
        dim_len = len(tensor_strategy)
        for i in range(dim_len):
            ele_count = int(
                len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
            tensor_slices_new_inner = []
            for j in range(ele_count):
                new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
                for l in range(j * tensor_strategy[dim_len - 1 - i] + 1,
                               (j + 1) * tensor_strategy[dim_len - 1 - i]):
                    new_tensor = np.concatenate(
                        (new_tensor, tensor_slices_new[l]), axis=dim_len - 1 - i)
                tensor_slices_new_inner.append(np.array(new_tensor))
            tensor_slices_new = tensor_slices_new_inner
        merged_tensor = Tensor(tensor_slices_new[0])

    return merged_tensor
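_merge_param_with_strategy is an internal helper; slices are normally merged through the public serialization API instead. A hedged usage sketch, assuming build_searched_strategy and merge_sliced_parameter from mindspore.train.serialization and purely illustrative file names:

import numpy as np
from mindspore import Tensor, Parameter
from mindspore.train.serialization import build_searched_strategy, merge_sliced_parameter

# parameter_name -> layout, parsed from the strategy file produced during parallel training
strategy = build_searched_strategy("./strategy_train.ckpt")
# per-rank slices of fc.weight, loaded in rank_id order (file names are illustrative)
sliced = [Parameter(Tensor(np.load(f"fc.weight_rank{r}.npy")), "fc.weight")
          for r in range(4)]
merged = merge_sliced_parameter(sliced, strategy)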