import numpy as np
from mindspore import Tensor
from mindspore import log as logger
from mindspore.parallel._tensor import _reshape_param_data


def test_reshape_param_data():
    """Test merging gathered slices back into the full tensor for several tensor maps."""
    expected_tensor = Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
    dev_mat = [2, 2]
    tensor_map = [0, 1]
    input_tensor = Tensor([[1, 2], [5, 6], [3, 4], [7, 8]])
    tensor = _reshape_param_data(input_tensor, dev_mat, tensor_map)
    if expected_tensor.__str__() != tensor.__str__():
        raise AssertionError

    tensor_map = [1, -1]
    input_tensor = Tensor([[1, 2, 3, 4], [1, 2, 3, 4], [5, 6, 7, 8], [5, 6, 7, 8]])
    tensor = _reshape_param_data(input_tensor, dev_mat, tensor_map)
    if expected_tensor.__str__() != tensor.__str__():
        raise AssertionError

    # Not split at all: each of the 4 devices holds a full copy and only one is kept.
    expected_tensor = Tensor([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                              [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]])
    input_tensor = Tensor([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                           [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                           [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                           [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                           [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                           [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                           [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                           [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]])
    dev_mat = [4]
    tensor_map = [-1, -1, -1, -1]
    tensor = _reshape_param_data(input_tensor, dev_mat, tensor_map)
    if expected_tensor.__str__() != tensor.__str__():
        raise AssertionError
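
# A minimal NumPy sketch of what the first test case above checks, kept separate
# from the test itself. With dev_mat = [2, 2] and tensor_map = [0, 1], the
# gathered 4 x 2 input is the 2 x 4 original cut into four 1 x 2 blocks; the
# block ordering below is read off from the test's input/expected pair, not
# derived from the layout spec. This is an illustrative re-derivation, not the
# implementation of _reshape_param_data.
def _numpy_merge_sketch():
    gathered = np.array([[1, 2], [5, 6], [3, 4], [7, 8]])
    # View the stacked slices as a (grid_col, grid_row, 1, 2) block array,
    # move each grid axis next to the tensor axis it splits, then collapse.
    blocks = gathered.reshape(2, 2, 1, 2)
    merged = blocks.transpose(1, 2, 0, 3).reshape(2, 4)
    assert (merged == np.array([[1, 2, 3, 4], [5, 6, 7, 8]])).all()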

def _get_merged_param_data(net, param_name, param_data):
    """
    Get the merged tensor from tensor slices, according to the device
    arrangement and tensor map.

    Args:
        net (Cell): MindSpore network.
        param_name (str): The name of the parameter to be combined.
        param_data (Tensor): The parameter data on the local device,
            which is a slice of the whole parameter data.

    Returns:
        Tensor, the combined tensor holding the whole data value.
    """
    layout = net.parameter_layout_dict[param_name]
    if len(layout) < 2:
        logger.info("layout dict does not contain the key %s", param_name)
        return param_data

    dev_mat = layout[0]
    tensor_map = layout[1]

    from mindspore.parallel._cell_wrapper import get_allgather_cell
    from mindspore.parallel._tensor import _reshape_param_data
    # While any dim is not equal to -1, the param is split and needs to be merged.
    for dim in tensor_map:
        if dim != -1:
            allgather_net = get_allgather_cell()
            param_data = allgather_net(param_data)
            return _reshape_param_data(param_data, dev_mat, tensor_map)
    return param_data
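
# A minimal usage sketch for the function above, assuming `net` has been compiled
# under (semi-)auto parallel so that net.parameter_layout_dict is populated. The
# loop below is illustrative, not taken from the checkpoint-saving code.
def _merge_all_params_sketch(net):
    merged = {}
    for param in net.trainable_params():
        # Returns the local slice unchanged when the layout shows no split.
        merged[param.name] = _get_merged_param_data(net, param.name, param.data)
    return merged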

def _get_merged_param_data(net, param_name, param_data, integrated_save):
    """
    Get the merged tensor from tensor slices, according to the device
    arrangement and tensor map.

    Args:
        net (Cell): MindSpore network.
        param_name (str): The name of the parameter to be combined.
        param_data (Tensor): The parameter data on the local device,
            which is a slice of the whole parameter data.
        integrated_save (bool): Whether to perform integrated (merged) saving
            in the automatic model-parallel scenario.

    Returns:
        Tensor, the combined tensor holding the whole data value.
    """
    from mindspore.parallel._cell_wrapper import get_allgather_cell
    from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight

    layout = net.parameter_layout_dict[param_name]
    if len(layout) < 6:
        logger.info("layout dict does not contain the key %s", param_name)
        return param_data

    dev_mat = layout[0]
    tensor_map = layout[1]
    field_size = layout[3]
    uniform_split = layout[4]
    opt_shard_group = layout[5]

    allgather_net = None
    if param_name in net.parallel_parameter_merge_net_dict:
        allgather_net = net.parallel_parameter_merge_net_dict[param_name]
    else:
        logger.info("need to create allgather net for %s", param_name)

    if integrated_save:
        if uniform_split == 0:
            raise RuntimeError("Integrated save checkpoint only supports uniformly split tensors for now.")
        # While any dim is not equal to -1, the param is split and needs to be merged.
        # Pipeline parallel needs to be supported here later.
        for dim in tensor_map:
            if dim != -1:
                if allgather_net is None:
                    if opt_shard_group:
                        allgather_net = get_allgather_cell(opt_shard_group, True)
                    else:
                        allgather_net = get_allgather_cell(opt_shard_group, False)
                    net.parallel_parameter_merge_net_dict[param_name] = allgather_net
                param_data = allgather_net(param_data)
                if field_size:
                    return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
                return _reshape_param_data(param_data, dev_mat, tensor_map)
        # Not split by model parallelism: merge the optimizer shards only.
        if opt_shard_group:
            if allgather_net is None:
                allgather_net = get_allgather_cell(opt_shard_group, False)
                net.parallel_parameter_merge_net_dict[param_name] = allgather_net
            param_data = allgather_net(param_data)
    elif opt_shard_group:
        if allgather_net is None:
            allgather_net = get_allgather_cell(opt_shard_group, False)
            net.parallel_parameter_merge_net_dict[param_name] = allgather_net
        param_data = allgather_net(param_data)
    return param_data
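
# Sketch of why the allgather cell is cached in net.parallel_parameter_merge_net_dict:
# when checkpoints are saved repeatedly, the first save builds one merge net per
# split parameter and later saves reuse it. `save_every_epoch` and
# `train_one_epoch` are assumed, illustrative helpers, not library APIs.
def save_every_epoch(net, num_epochs):
    for _ in range(num_epochs):
        train_one_epoch(net)  # hypothetical training step
        for param in net.trainable_params():
            # First epoch: builds and caches the AllGather cell.
            # Later epochs: cache hit, no rebuild.
            _get_merged_param_data(net, param.name, param.data, integrated_save=True)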

def _get_merged_param_data(net, param_name, param_data):
    """
    Get the merged tensor from tensor slices, according to the device
    arrangement and tensor map.

    Args:
        net (Cell): MindSpore network.
        param_name (str): The name of the parameter to be combined.
        param_data (Tensor): The parameter data on the local device,
            which is a slice of the whole parameter data.

    Returns:
        Tensor, the combined tensor holding the whole data value.
    """
    layout = net.parameter_layout_dict[param_name]
    if len(layout) < 6:
        logger.info("layout dict does not contain the key %s", param_name)
        return param_data

    dev_mat = layout[0]
    tensor_map = layout[1]
    field_size = layout[3]
    uniform_split = layout[4]
    opt_shard_group = layout[5]
    if uniform_split == 0:
        raise RuntimeError("Save checkpoint only supports uniformly split tensors for now.")

    from mindspore.parallel._cell_wrapper import get_allgather_cell
    from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
    # While any dim is not equal to -1, the param is split and needs to be merged.
    # Pipeline parallel needs to be supported here later.
    for dim in tensor_map:
        if dim != -1 or opt_shard_group:
            allgather_net = get_allgather_cell(opt_shard_group)
            param_data = allgather_net(param_data)
            if field_size:
                return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
            return _reshape_param_data(param_data, dev_mat, tensor_map)
    return param_data
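
# Hedged sketch of the failure path in the variant above: it raises as soon as
# the layout reports a non-uniform split, so a caller that must tolerate such
# parameters can fall back to keeping the local slice. The try/except wrapper
# is illustrative, not part of the library.
def _merge_or_keep_slice(net, param):
    try:
        return _get_merged_param_data(net, param.name, param.data)
    except RuntimeError:
        return param.data  # non-uniform split: keep the local slice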

def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
    """
    Merge data slices into one tensor with the whole data when strategy is not None.

    Args:
        sliced_data (list[numpy.ndarray]): Data slices in order of rank_id.
        parameter_name (str): Name of the parameter.
        strategy (dict): Parameter slice strategy.
        is_even (bool): Slice manner: True means the parameter is sliced evenly,
            False means it is sliced unevenly.

    Returns:
        Tensor, the merged Tensor holding the whole data.

    Raises:
        ValueError: Failed to merge.
    """
    layout = strategy.get(parameter_name)
    try:
        dev_mat = list(layout.dev_matrix[0].dim)
        tensor_map = list(layout.tensor_map[0].dim)
        param_split_shape = list(layout.param_split_shape[0].dim)
        field_size = int(layout.field)
    except BaseException as e:
        raise ValueError(f"{e}. Please make sure that strategy matches the node_strategy.proto.")

    device_count = 1
    for dim in dev_mat:
        device_count *= dim

    if len(sliced_data) != device_count:
        raise ValueError(f"The sliced_parameters length should be equal to device_count. "
                         f"The sliced_parameters length is {len(sliced_data)} but device_count is {device_count}.")

    merged_tensor = None
    if not param_split_shape:
        if not is_even:
            raise ValueError("The shape of every parameter in sliced_parameters should be the same "
                             "when the slice manner is even.")

        all_gather_tensor = Tensor(np.concatenate(sliced_data))

        if field_size > 0:
            from mindspore.parallel._tensor import _reshape_param_data_with_weight
            merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, field_size)
        else:
            from mindspore.parallel._tensor import _reshape_param_data
            merged_tensor = _reshape_param_data(all_gather_tensor, dev_mat, tensor_map)
    else:
        from mindspore.parallel._tensor import _get_tensor_strategy, _get_tensor_slice_index
        tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)

        slice_count = 1
        for dim in tensor_strategy:
            slice_count *= dim

        if len(param_split_shape) != slice_count:
            raise ValueError(f"The param_split_shape length in strategy should be {slice_count}, "
                             f"but got {len(param_split_shape)}.")

        # Reorder the slices from rank order into slice-index order.
        tensor_slices_new = list(range(slice_count))
        tensor_slices = sliced_data
        for i in range(device_count):
            slice_index = int(_get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i))
            if tensor_slices[i].shape[0] != param_split_shape[slice_index]:
                raise ValueError(f"The slice {slice_index} is {param_split_shape[slice_index]} in 0 axis, "
                                 f"but got {tensor_slices[i].shape[0]}.")
            tensor_slices_new[slice_index] = np.array(tensor_slices[i])

        # Concatenate the slices dimension by dimension, from the last dim to the first.
        dim_len = len(tensor_strategy)
        for i in range(dim_len):
            ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
            tensor_slices_new_inner = []
            for j in range(ele_count):
                new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
                for l in range(j * tensor_strategy[dim_len - 1 - i] + 1,
                               (j + 1) * tensor_strategy[dim_len - 1 - i]):
                    new_tensor = np.concatenate((new_tensor, tensor_slices_new[l]),
                                                axis=dim_len - 1 - i)
                tensor_slices_new_inner.insert(len(tensor_slices_new_inner), np.array(new_tensor))
            tensor_slices_new = tensor_slices_new_inner
        merged_tensor = Tensor(tensor_slices_new[0])

    return merged_tensor
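
# Offline usage sketch for _merge_param_with_strategy, assuming `strategy` was
# loaded from a strategy file (e.g. with mindspore.build_searched_strategy) and
# that each rank saved its numpy slice under a path like the one below; the file
# layout and parameter name are hypothetical.
def _merge_from_disk_sketch(strategy, device_count):
    sliced_data = [np.load(f"rank_{i}/dense.weight.npy") for i in range(device_count)]
    # With is_even=True and no param_split_shape in the layout, the slices are
    # simply concatenated along axis 0 and reshaped into the full tensor.
    return _merge_param_with_strategy(sliced_data, "dense.weight", strategy, is_even=True)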