Example #1
def _get_merged_param_data(net, param_name, param_data, integrated_save):
    """
    Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.

    Args:
        net (Cell): MindSpore network.
        param_name (str): The parameter name, which to be combined.
        param_data (Tensor): The parameter data on the local device, which was a slice of the whole parameter data.
        integrated_save (bool): Whether to integrated save in automatic model parallel scene.
    Returns:
        Tensor, the combined tensor which with the whole data value.
    """
    from mindspore import log as logger  # logger used below; imported here so the snippet is self-contained
    from mindspore.parallel._cell_wrapper import get_allgather_cell
    from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
    layout = net.parameter_layout_dict[param_name]
    if len(layout) < 6:
        logger.info("layout dict does not contain the key %s", param_name)
        return param_data

    dev_mat = layout[0]
    tensor_map = layout[1]
    field_size = layout[3]
    uniform_split = layout[4]
    opt_shard_group = layout[5]

    allgather_net = None
    if param_name in net.parallel_parameter_merge_net_dict:
        allgather_net = net.parallel_parameter_merge_net_dict[param_name]
    else:
        logger.info("need to create allgather net for %s", param_name)

    if integrated_save:
        if uniform_split == 0:
            raise RuntimeError("Integrated save checkpoint only support uniform split tensor now.")
        # if any dim is not equal to -1, the param is split and needs to be merged
        # pipeline parallelism needs to be supported here later
        for dim in tensor_map:
            if dim != -1:
                if allgather_net is None:
                    if opt_shard_group:
                        allgather_net = get_allgather_cell(opt_shard_group, True)
                    else:
                        allgather_net = get_allgather_cell(opt_shard_group, False)
                    net.parallel_parameter_merge_net_dict[param_name] = allgather_net
                param_data = allgather_net(param_data)
                if field_size:
                    return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
                return _reshape_param_data(param_data, dev_mat, tensor_map)
        if opt_shard_group:
            if allgather_net is None:
                allgather_net = get_allgather_cell(opt_shard_group, False)
                net.parallel_parameter_merge_net_dict[param_name] = allgather_net
            param_data = allgather_net(param_data)
    elif opt_shard_group:
        if allgather_net is None:
            allgather_net = get_allgather_cell(opt_shard_group, False)
            net.parallel_parameter_merge_net_dict[param_name] = allgather_net
        param_data = allgather_net(param_data)
    return param_data
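
A minimal usage sketch for this variant follows, assuming it is called from a checkpoint-saving routine that walks the network's parameters. The surrounding loop and the _collect_merged_params name are illustrative assumptions, not taken from the source; only the _get_merged_param_data call itself reflects the code above.

# Hypothetical caller: merge every sliced parameter before writing it to a
# checkpoint. Only the _get_merged_param_data call reflects the code above;
# the loop and all other names are illustrative.
def _collect_merged_params(net, integrated_save=True):
    merged = {}
    for _, param in net.parameters_and_names():
        data = param.data
        # Parameters that appear in the layout dict are sliced across devices
        # and have to be gathered back before saving.
        if param.name in net.parameter_layout_dict:
            data = _get_merged_param_data(net, param.name, data, integrated_save)
        merged[param.name] = data
    return merged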
Example #2
def _get_merged_param_data(net, param_name, param_data):
    """
    Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.

    Args:
        net (Cell): MindSpore network.
        param_name(str): The parameter name, which to be combined.
        param_data(Tensor):The parameter data on the local device,
                           It was a slice of the whole parameter data.
    Returns:
        Tensor, the combined tensor which with the whole data value.
    """
    from mindspore import log as logger  # logger used below; imported here so the snippet is self-contained
    layout = net.parameter_layout_dict[param_name]
    if len(layout) < 2:
        logger.info("layout dict does not contain the key %s", param_name)
        return param_data

    dev_mat = layout[0]
    tensor_map = layout[1]

    from mindspore.parallel._cell_wrapper import get_allgather_cell
    from mindspore.parallel._tensor import _reshape_param_data
    # if any dim is not equal to -1, the param is split and needs to be merged
    for dim in tensor_map:
        if dim != -1:
            allgather_net = get_allgather_cell()
            param_data = allgather_net(param_data)
            return _reshape_param_data(param_data, dev_mat, tensor_map)

    return param_data
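
The loop above triggers a gather only when some entry of tensor_map is not -1. That check can be illustrated in isolation with made-up layout values; the dev_mat and tensor_map numbers below are hypothetical, not taken from a real network.

# Illustrative only: an entry of -1 in tensor_map means the corresponding
# tensor dimension is not split across devices, so a fully replicated
# parameter needs no AllGather.
dev_mat = [2, 4]              # hypothetical arrangement of 8 devices
tensor_map_split = [1, -1]    # dim 0 is split across devices, dim 1 is replicated
tensor_map_replicated = [-1, -1]

def needs_merge(tensor_map):
    return any(dim != -1 for dim in tensor_map)

assert needs_merge(tensor_map_split)
assert not needs_merge(tensor_map_replicated)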
Example #3
def _get_merged_param_data(net, param_name, param_data):
    """
    Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.

    Args:
        net (Cell): MindSpore network.
        param_name(str): The parameter name, which to be combined.
        param_data(Tensor):The parameter data on the local device,
                           It was a slice of the whole parameter data.
    Returns:
        Tensor, the combined tensor which with the whole data value.
    """
    from mindspore import log as logger  # logger used below; imported here so the snippet is self-contained
    layout = net.parameter_layout_dict[param_name]
    if len(layout) < 6:
        logger.info("layout dict does not contain the key %s", param_name)
        return param_data

    dev_mat = layout[0]
    tensor_map = layout[1]
    field_size = layout[3]
    uniform_split = layout[4]
    opt_shard_group = layout[5]
    if uniform_split == 0:
        raise RuntimeError(
            "Saving checkpoints only supports uniformly split tensors for now.")

    from mindspore.parallel._cell_wrapper import get_allgather_cell
    from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
    # if any dim is not equal to -1, the param is split and needs to be merged
    # pipeline parallelism needs to be supported here later
    for dim in tensor_map:
        if dim != -1 or opt_shard_group:
            allgather_net = get_allgather_cell(opt_shard_group)
            param_data = allgather_net(param_data)
            if field_size:
                return _reshape_param_data_with_weight(param_data, dev_mat,
                                                       field_size)
            return _reshape_param_data(param_data, dev_mat, tensor_map)

    return param_data
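
This variant expects a layout entry with at least six fields and reads indices 0, 1, 3, 4 and 5. A hedged illustration of unpacking such an entry follows; every value is hypothetical, and index 2 is shown only because the function above skips it.

# Hypothetical layout entry matching the indices read by the function above.
layout = (
    [2, 4],   # index 0, dev_mat: device arrangement
    [1, -1],  # index 1, tensor_map: dim 0 split, dim 1 replicated
    [4, 8],   # index 2, not read by the function above
    0,        # index 3, field_size: falsy, so the field-sliced reshape branch is skipped
    1,        # index 4, uniform_split: non-zero, so no RuntimeError is raised
    "",       # index 5, opt_shard_group: empty, so merging depends on tensor_map alone
)
dev_mat, tensor_map = layout[0], layout[1]
field_size, uniform_split, opt_shard_group = layout[3], layout[4], layout[5]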