def _load_tensor_for_layerwise(new_param, old_param):
    """
    Replaces parameters with sliced tensors by layerwise parallel strategies.

    The new parameter's data is sliced along dim 0 across all devices in the
    group and written back into ``new_param`` in place.

    Args:
        new_param (Parameter): The new layerwise parallel parameter, will be
            loaded into net.
        old_param (Parameter): The current parameter in the net.

    Raises:
        TypeError: If either parameter's data is not a Tensor.
    """
    if not isinstance(new_param.data, Tensor) or not isinstance(
            old_param.data, Tensor):
        logger.error("Failed to combine the net and the parameters.")
        msg = ("layerwise parallel parameter should be a Tensor, but got {}.".
               format(type(new_param.data)))
        raise TypeError(msg)
    # Shapes already match: the parameter is not sliced, nothing to do.
    if old_param.data.shape() == new_param.data.shape():
        return

    # Imported lazily to avoid a circular import at module load time.
    from mindspore.parallel._tensor import _load_tensor
    from mindspore.communication.management import get_group_size
    dev_mat = [get_group_size()]
    shape = new_param.data.shape()
    # BUGFIX: tensor_map was previously used without being initialized
    # (NameError), and dim 0 never received its mapping.
    tensor_map = []
    for x in range(len(shape)):
        # dim 0 set 0, others set -1
        if x:
            tensor_map.append(-1)
        else:
            tensor_map.append(0)
    new_tensor = _load_tensor(new_param.data, dev_mat, tensor_map)
    new_param.set_parameter_data(new_tensor)
def test_load_tensor():
    """Check that _load_tensor returns the expected shard for two rank ids."""
    hccl = Hccl()
    tensor = Tensor([[1, 2, 3], [4, 5, 6]])
    dev_mat = [2, 3]
    tensor_map = [1, -1]

    # Rank 5 should receive the second row shard.
    hccl.rank_id = 5
    sliced = _load_tensor(tensor, dev_mat, tensor_map)
    if str(Tensor([[4, 5, 6]])) != str(sliced):
        raise AssertionError

    # Rank 2 should receive the first row shard.
    hccl.rank_id = 2
    sliced = _load_tensor(tensor, dev_mat, tensor_map)
    if str(Tensor([[1, 2, 3]])) != str(sliced):
        raise AssertionError

    # Restore the default rank id so later tests are unaffected.
    hccl.rank_id = 0
def _merge_and_split(sliced_params, train_strategy, predict_strategy):
    """Merge sliced parameter and split it according to the predict strategy."""
    merged = merge_sliced_parameter(sliced_params, train_strategy)
    # With no predict strategy, the fully merged parameter is the answer.
    if predict_strategy is None:
        return merged

    name = merged.name
    layout = predict_strategy[name]
    # layout[0] is the device matrix, layout[1] the tensor map.
    resliced = _load_tensor(merged.data, layout[0], layout[1])
    return Parameter(resliced, name, merged.requires_grad,
                     merged.layerwise_parallel)
def infer_value(self, x, dev_mat, tensor_map):
    """Return x sliced by the given device matrix and tensor map as a Tensor."""
    # Imported lazily to avoid a circular import at module load time.
    from mindspore.parallel._tensor import _load_tensor

    # Both layout arguments must arrive as tuples.
    for arg_name, arg_value in (("dev_mat", dev_mat), ("tensor_map", tensor_map)):
        validator.check_value_type(arg_name, arg_value, [tuple], self.name)
    return Tensor(_load_tensor(x, dev_mat, tensor_map))