def convert_type(shapes, types):
    """Translate paired (shape, numpy dtype) specs into ``Tensor`` dtypes.

    For every positional pair drawn from ``shapes`` and ``types``, a
    zero-filled numpy array of that shape and dtype is wrapped in a
    ``Tensor``, and the tensor's ``dtype()`` is recorded.

    Args:
        shapes: iterable of array shapes.
        types: iterable of numpy dtypes, paired positionally with ``shapes``
            (as with ``zip``, iteration stops at the shorter of the two).

    Returns:
        list: the dtype reported by each constructed ``Tensor``, in input order.
    """
    return [
        Tensor(np.zeros(shape, dtype)).dtype()
        for shape, dtype in zip(shapes, types)
    ]
def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
    """Convert numpy/Tensor inputs to tensors whose batch spans all devices.

    Each input's first dimension is treated as the per-device batch size. A
    zero-filled array whose first dimension is ``batch * device_num`` (other
    dimensions unchanged) is allocated. The input data is copied into rows
    ``[global_rank * batch, (global_rank + 1) * batch)``; every other row stays
    zero. This adapts per-device data to the minddata feed solution.

    Args:
        elem: a single ``np.ndarray``/``Tensor``, or a tuple/list of them.
        device_num: number of devices the batch dimension is expanded over.
        global_rank: index of this device; must satisfy
            ``0 <= global_rank < device_num``.
        scaling_sens: optional value. When it is not ``None``, it is appended as
            a float32 ``Tensor`` after the expanded inputs.

    Returns:
        tuple: the expanded tensors, followed by the scaling-sens tensor if given.

    Raises:
        ValueError: if ``global_rank`` is out of range, or an element is not an
            ``np.ndarray``/``Tensor``, or an element has no dimensions.
    """
    if not isinstance(elem, (tuple, list)):
        elem = [elem]
    if global_rank >= device_num:
        raise ValueError(
            "The global rank must be smaller than device number, the global rank is {}, "
            "the device num is {}".format(global_rank, device_num))
    # A negative rank would yield a negative slice start below and silently
    # write the data into the wrong rows (or nowhere), so reject it up front.
    if global_rank < 0:
        raise ValueError(
            "The global rank must be non-negative, the global rank is {}".format(global_rank))

    lst = []
    for data in elem:
        if isinstance(data, np.ndarray):
            data = Tensor(data)
        if not isinstance(data, Tensor):
            raise ValueError("elements in tensors must be Tensor")
        shape_ = data.shape()
        # The first axis is the batch axis; a 0-d tensor has nothing to expand.
        if not shape_:
            raise ValueError("elements in tensors must have at least one dimension")
        batchsize_per_device = shape_[0]
        new_shape = (batchsize_per_device * device_num,) + tuple(shape_[1:])
        new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(data.dtype()))
        # Place this device's batch in its own slot of the full batch.
        start = global_rank * batchsize_per_device
        new_tensor_numpy[start:start + batchsize_per_device] = data.asnumpy()
        lst.append(Tensor(new_tensor_numpy))
    # Use an explicit None check: a truthiness test would drop a legitimate 0
    # value and raise on multi-element numpy arrays.
    if scaling_sens is not None:
        lst.append(Tensor(scaling_sens, mstype.float32))
    return tuple(lst)