def __init__(self, input_size, hidden_size, num_layers=1, has_bias=True,
             batch_first=False, dropout=0, bidirectional=False):
    super(LSTM, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.has_bias = has_bias
    self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
    self.hidden_size = validator.check_integer("hidden_size", hidden_size, 0, Rel.GT, self.cls_name)
    self.num_layers = validator.check_integer("num_layers", num_layers, 0, Rel.GT, self.cls_name)
    self.dropout = float(dropout)
    self.bidirectional = bidirectional
    if self.batch_first:
        self.transpose1 = P.Transpose()
        self.transpose2 = P.Transpose()
    num_directions = 2 if self.bidirectional else 1
    self.cpu_target = False
    enable_debug = context.get_context("enable_debug_runtime")
    if context.get_context("device_target") == "CPU" and not enable_debug:
        self.cpu_target = True
    if not self.cpu_target:
        self.lstm = P.LSTM(input_size=self.input_size,
                           hidden_size=self.hidden_size,
                           num_layers=self.num_layers,
                           has_bias=self.has_bias,
                           bidirectional=self.bidirectional,
                           dropout=self.dropout)
        weight_size = 0
        gate_size = 4 * self.hidden_size
        for layer in range(self.num_layers):
            input_layer_size = self.input_size if layer == 0 else self.hidden_size * num_directions
            increment_size = gate_size * input_layer_size
            increment_size += gate_size * self.hidden_size
            if self.has_bias:
                increment_size += 2 * gate_size
            weight_size += increment_size * num_directions
        stdv = 1 / math.sqrt(hidden_size)
        w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
        self.weight = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight')
    else:
        input_size_list = []
        input_size_list.append(self.input_size)
        for i in range(self.num_layers - 1):
            input_size_list.append(self.hidden_size * num_directions)
        weights = []
        layers = []
        bias_size = 0 if not self.has_bias else num_directions * self.hidden_size * 4
        stdv = 1 / math.sqrt(hidden_size)
        for i in range(num_layers):
            weight_size = (input_size_list[i] + self.hidden_size) * num_directions * self.hidden_size * 4
            if has_bias:
                weight_size = weight_size + bias_size
            w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
            weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name='weight' + str(i)))
            layers.append(nn.LSTMCell(input_size=input_size_list[i],
                                      hidden_size=self.hidden_size,
                                      has_bias=self.has_bias,
                                      bidirectional=self.bidirectional,
                                      dropout=self.dropout))
        self.lstms = layers
        self.weight = ParameterTuple(tuple(weights))
    self.fill = P.Fill()
    self.shape = P.Shape()
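# A minimal sketch (plain Python, outside the class; the helper name is hypothetical)
# of how the flat weight size above is derived for the fused P.LSTM path: per layer
# and direction, the LSTM holds 4 gates, each with an input-to-hidden matrix, a
# hidden-to-hidden matrix and, optionally, two bias vectors.
def _flat_lstm_weight_size(input_size, hidden_size, num_layers, has_bias, bidirectional):
    num_directions = 2 if bidirectional else 1
    gate_size = 4 * hidden_size
    weight_size = 0
    for layer in range(num_layers):
        input_layer_size = input_size if layer == 0 else hidden_size * num_directions
        increment_size = gate_size * input_layer_size    # input-to-hidden weights
        increment_size += gate_size * hidden_size        # hidden-to-hidden weights
        if has_bias:
            increment_size += 2 * gate_size              # ih and hh bias vectors
        weight_size += increment_size * num_directions
    return weight_size

# e.g. input_size=10, hidden_size=16, one bidirectional layer with bias:
# 2 * (4*16*10 + 4*16*16 + 2*4*16) = 3584 flat weight elements.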
def test_check_non_positive_int2():
    a = np.random.randint(1, 100)
    with pytest.raises(ValueError):
        Validator.check_non_positive_int(a)
def __init__(self, alpha=0.2):
    super(LeakyReLU, self).__init__()
    validator.check_value_type('alpha', alpha, [float, int], self.cls_name)
    self.greater_equal = P.GreaterEqual()
    self.mul = P.Mul()
    self.alpha = alpha
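# A minimal NumPy reference (an assumption, not the cell's actual construct method)
# of the math LeakyReLU computes with the GreaterEqual/Mul ops declared above:
# out = x where x >= 0, and alpha * x elsewhere.
import numpy as np

def leaky_relu_reference(x, alpha=0.2):
    x = np.asarray(x, dtype=np.float32)
    return np.where(x >= 0, x, alpha * x)

# e.g. leaky_relu_reference([-2.0, 0.0, 3.0]) -> [-0.4, 0.0, 3.0]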
def test_check_positive_int3():
    with pytest.raises(TypeError):
        Validator.check_positive_int(3.3)
def test_check_negative_int4():
    with pytest.raises(TypeError):
        Validator.check_negative_int("str")
def test_check_int2():
    with pytest.raises(TypeError):
        Validator.check_is_int(3.3)
def test_check_is_int5():
    with pytest.raises(TypeError):
        Validator.check_is_int(True)
    with pytest.raises(TypeError):
        Validator.check_is_int(False)
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
    validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
def __init__(self, max_val=1.0):
    super(PSNR, self).__init__()
    validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
    validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
    self.max_val = max_val
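# A minimal NumPy sketch (an assumption about what the metric computes, not the
# cell's construct method) of the standard PSNR definition that max_val feeds into:
# PSNR = 10 * log10(max_val**2 / MSE(img_a, img_b)).
import numpy as np

def psnr_reference(img_a, img_b, max_val=1.0):
    img_a = np.asarray(img_a, dtype=np.float32)
    img_b = np.asarray(img_b, dtype=np.float32)
    mse = np.mean((img_a - img_b) ** 2)
    return 10.0 * np.log10((max_val ** 2) / mse)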
def _check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate,
                       power, beta1, beta2, eps, weight_decay, prim_name):
    """Check the type of inputs."""
    validator.check_value_type("start_learning_rate", start_learning_rate, [float], prim_name)
    validator.check_number_range("start_learning_rate", start_learning_rate, 0.0, float("inf"),
                                 Rel.INC_LEFT, prim_name)
    validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
    validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"),
                                 Rel.INC_LEFT, prim_name)
    validator.check_float_positive('power', power, prim_name)
    validator.check_float_legal_value('power', power, prim_name)
    validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
    validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GE, prim_name)
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
    validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
    validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
def __init__(self, seed, dtype, name, param):
    """
    Constructor of distribution class.
    """
    super(Distribution, self).__init__()
    if seed is None:
        seed = get_seed()
        if seed is None:
            seed = 0
    validator.check_value_type('name', name, [str], type(self).__name__)
    validator.check_integer('seed', seed, 0, Rel.GE, name)

    self._name = name
    self._seed = seed
    self._dtype = cast_type_for_device(dtype)
    self._parameters = {}

    # parsing parameters
    for k in param.keys():
        if not (k == 'self' or k.startswith('_')):
            self._parameters[k] = param[k]

    # some attributes
    if 'distribution' in self.parameters.keys():
        self.parameter_type = self.parameters['distribution'].parameter_type
    else:
        self.parameter_type = set_param_type(self.parameters['param_dict'], dtype)
    self._broadcast_shape = self._calc_broadcast_shape()
    self._is_scalar_batch = self._check_is_scalar_batch()

    # set the function to call according to the derived class's attributes
    self._set_prob()
    self._set_log_prob()
    self._set_sd()
    self._set_var()
    self._set_cdf()
    self._set_survival()
    self._set_log_cdf()
    self._set_log_survival()
    self._set_cross_entropy()

    self.context_mode = context.get_context('mode')
    self.device_target = context.get_context('device_target')
    self.checktuple = CheckTuple()
    self.checktensor = CheckTensor()
    self.broadcast = broadcast_to

    # ops needed for the base class
    self.cast_base = P.Cast()
    self.dtype_base = P.DType()
    self.exp_base = exp_generic
    self.fill_base = P.Fill()
    self.log_base = log_generic
    self.sametypeshape_base = P.SameTypeShape()
    self.sq_base = P.Square()
    self.sqrt_base = P.Sqrt()
    self.shape_base = P.Shape()
def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
    """
    Training.

    Args:
        epoch (int): Total number of iterations on the data.
        train_dataset (Dataset): A training dataset iterator. If there is no loss_fn, a tuple with
                                 multiple data (data1, data2, data3, ...) will be returned and passed
                                 to the network. Otherwise, a tuple (data, label) will be returned, and
                                 the data and label are passed to the network and loss function
                                 respectively.
        callbacks (list): List of callback objects which should be executed while training.
                          Default: None.
        dataset_sink_mode (bool): Determines whether to pass the data through the dataset channel.
                                  Default: True. In PyNative mode, the training process will be
                                  performed with the dataset not sinking.
        sink_size (int): Controls the amount of data in each sink. Default: -1.
    """
    epoch = Validator.check_positive_int(epoch)
    self._train_network.set_train()

    if self._parameter_broadcast:
        self._train_network.set_broadcast_flag()

    cb_params = _InternalCallbackParam()
    cb_params.train_network = self._train_network
    cb_params.epoch_num = epoch
    if dataset_sink_mode and sink_size > 0:
        cb_params.batch_num = sink_size
    else:
        cb_params.batch_num = train_dataset.get_dataset_size()
    cb_params.mode = "train"
    cb_params.loss_fn = self._loss_fn
    cb_params.optimizer = self._optimizer
    cb_params.parallel_mode = self._parallel_mode
    cb_params.device_number = self._device_number
    cb_params.train_dataset = train_dataset
    cb_params.list_callback = self._transform_callbacks(callbacks)
    cb_params.train_dataset_element = None
    cb_params.network = self._network

    # build callback list
    with _CallbackManager(callbacks) as list_callback:
        if not dataset_sink_mode:
            self._train_process(epoch, train_dataset, list_callback, cb_params)
        elif context.get_context("mode") == context.PYNATIVE_MODE:
            logger.warning("The pynative mode cannot support dataset sink mode currently. "
                           "So the training process will be performed with dataset not sink.")
            self._train_process(epoch, train_dataset, list_callback, cb_params)
        else:
            self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size)
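# A minimal usage sketch, assuming the surrounding class is the usual Model wrapper
# that forwards its public train() to this _train(); the network, loss, optimizer and
# dataset names below are placeholders, not part of the source above.
#
#   model = Model(network=net, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=True),
#                 optimizer=nn.Momentum(net.trainable_params(), 0.01, 0.9))
#   model.train(epoch=10, train_dataset=train_ds,
#               callbacks=[LossMonitor()], dataset_sink_mode=False)
#
# With dataset_sink_mode=True and sink_size > 0, cb_params.batch_num is the sink size;
# otherwise it is train_dataset.get_dataset_size().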
def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, prim_name):
    """Check the type of inputs."""
    validator.check_float_positive('learning_rate', learning_rate, prim_name)
    validator.check_float_legal_value('learning_rate', learning_rate, prim_name)
    validator.check_float_positive('end_learning_rate', end_learning_rate, prim_name)
    validator.check_float_legal_value('end_learning_rate', end_learning_rate, prim_name)
    validator.check_float_positive('power', power, prim_name)
    validator.check_float_legal_value('power', power, prim_name)
    validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
    """Check the type of inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
    validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
    validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
def test_check_bool_5():
    with pytest.raises(TypeError):
        Validator.check_bool(3.5)
def _init_group_params(self, parameters, learning_rate, weight_decay, grad_centralization):
    """Initialize learning rate, weight decay or grad centralization in group params."""
    self._parse_group_params(parameters, learning_rate)
    default_lr = self._build_single_lr(learning_rate, 'learning_rate')

    params_store = []
    for group_num, group_param in enumerate(parameters):
        if 'order_params' in group_param.keys():
            ordered_parameters = group_param['order_params']
            continue

        self.group_params += group_param['params']

        if 'lr' in group_param.keys():
            lr_param_name = 'learning_rate_group_' + str(group_num)
            lr = self._preprocess_single_lr(group_param['lr'])
            lr = self._build_single_lr(lr, lr_param_name)
        else:
            lr = default_lr

        if 'weight_decay' in group_param.keys():
            cur_weight_decay = self._preprocess_weight_decay(group_param['weight_decay'])
            weight_decay_ = cur_weight_decay * self.loss_scale
        else:
            weight_decay_ = weight_decay * self.loss_scale

        if 'grad_centralization' in group_param.keys():
            self.grad_centralization = self._preprocess_grad_centralization(group_param['grad_centralization'])
            for param in group_param['params']:
                validator.check_value_type("parameter", param, [Parameter], self.cls_name)
                if "conv" not in param.name and self.grad_centralization is True:
                    raise ValueError("Grad centralization can be performed only on the conv layer. If the "
                                     "parameter is not a convolution layer, this parameter cannot be set to True.")
                grad_centralization_ = self.grad_centralization
        else:
            grad_centralization_ = grad_centralization

        for key in group_param.keys():
            if key not in ('params', 'lr', 'weight_decay', 'grad_centralization'):
                logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")

        for param in group_param['params']:
            validator.check_value_type("parameter", param, [Parameter], self.cls_name)
            if param.name in params_store:
                raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")

            params_store.append(param.name)
            self.group_lr.append(lr)
            self.group_weight_decay.append(weight_decay_)
            self.group_grad_centralization.append(grad_centralization_)

    if self.is_group_params_ordered:
        self._order_and_adjust_group_params(ordered_parameters)
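# A minimal usage sketch (net and the optimizer subclass are placeholders, assuming a
# Momentum-style subclass of this Optimizer) of the group-parameter dicts this method
# parses. Recognized keys are 'params', 'lr', 'weight_decay', 'grad_centralization'
# and 'order_params'; anything else only triggers a warning.
#
#   conv_params = [p for p in net.trainable_params() if 'conv' in p.name]
#   other_params = [p for p in net.trainable_params() if 'conv' not in p.name]
#   group_params = [
#       {'params': conv_params, 'lr': 0.05, 'grad_centralization': True},
#       {'params': other_params, 'weight_decay': 0.0},
#       {'order_params': net.trainable_params()},
#   ]
#   optim = nn.Momentum(group_params, learning_rate=0.01, momentum=0.9)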
def test_check_int1():
    a = np.random.randint(-100, 100)
    assert Validator.check_is_int(a) == a
def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0):
    super(Optimizer, self).__init__(auto_prefix=False)
    if parameters is not None and not isinstance(parameters, list):
        parameters = list(parameters)

    if not parameters:
        raise ValueError("Optimizer got an empty parameter list.")

    if not isinstance(parameters[0], (dict, Parameter)):
        raise TypeError("Only a list of Parameter or dict can be supported.")

    if isinstance(loss_scale, int):
        loss_scale = float(loss_scale)
    validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
    validator.check_positive_float(loss_scale, "loss_scale", self.cls_name)
    self.loss_scale = loss_scale

    weight_decay = self._preprocess_weight_decay(weight_decay)
    self.grad_centralization = False

    self._unique = True
    self._target = context.get_context("device_target")
    self.dynamic_lr = False
    self.assignadd = None
    self.global_step = None
    self.is_group = False
    self.is_group_lr = False
    self.is_group_params_ordered = False
    learning_rate = self._preprocess_single_lr(learning_rate)
    if isinstance(parameters[0], dict):
        self.is_group = True
        self.group_params = []
        self.group_lr = []
        self.group_weight_decay = []
        self.group_grad_centralization = []
        self._init_group_params(parameters, learning_rate, weight_decay, self.grad_centralization)

    # The final value of dynamic_lr can be determined after the process of parse_single_lr and init_group_params
    if self.dynamic_lr:
        self.assignadd = P.AssignAdd()
        self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')

    if self.is_group_lr:
        self.learning_rate = CellList(self.group_lr) if self.dynamic_lr else ParameterTuple(self.group_lr)
    else:
        self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate')

    if self.is_group:
        self.parameters = ParameterTuple(self.group_params)
        self.weight_decay = tuple(self.group_weight_decay)
        self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float32) for x in self.group_weight_decay)
        decay_filter = lambda x: x > 0
        self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
        self.exec_weight_decay = any(self.decay_flags)
        self.grad_centralization_flags = tuple(self.group_grad_centralization)
    else:
        self.parameters = ParameterTuple(parameters)
        self.weight_decay = weight_decay * loss_scale
        self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float32)
        decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
        self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
        self.exec_weight_decay = self.weight_decay > 0

    # when a parameter has been unique, there is no need to do another unique in the optimizer.
    for param in self.parameters:
        if param.unique:
            self._unique = False
            break
    ps_filter = lambda x: x.is_param_ps
    self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
    cache_filter = lambda x: x.cache_enable
    self.cache_enable = tuple(cache_filter(x) for x in self.parameters)
    self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
    self.need_scale = loss_scale != 1.0
    self.global_step_increase_tensor = Tensor(1, mstype.int32)
    self.param_length = len(self.parameters)
    self.map_ = C.Map()
    self._use_parallel_optimizer()
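# A minimal usage sketch, assuming a concrete subclass such as nn.Momentum built on
# this Optimizer base; net is a placeholder. With a flat parameter list, weight decay
# is skipped for parameters whose name contains 'beta' or 'gamma' (the default
# decay_filter above), and loss_scale is stored as reciprocal_scale for rescaling.
#
#   optim = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9,
#                       weight_decay=1e-4, loss_scale=128.0)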
def test_check_int3():
    with pytest.raises(TypeError):
        Validator.check_is_int("str")
def _check_param_value(beta1, beta2, eps, prim_name):
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)
def test_check_positive_int1():
    a = np.random.randint(1, 100)
    assert Validator.check_positive_int(a) == a
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False):
    """
    Saves checkpoint info to a specified file.

    Args:
        save_obj (nn.Cell or list): The cell object or data list (each element is a dictionary, like
                                    [{"name": param_name, "data": param_data}, ...], where param_name
                                    is a string and param_data is a Parameter or Tensor).
        ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
        integrated_save (bool): Whether to perform integrated save in the automatic model parallel scene.
                                Default: True.
        async_save (bool): Whether to save the checkpoint to the file asynchronously. Default: False.

    Raises:
        TypeError: If the parameter save_obj is not an nn.Cell or list, or if the parameters
            integrated_save and async_save are not bool.
    """
    if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
        raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj)))
    integrated_save = Validator.check_bool(integrated_save)
    async_save = Validator.check_bool(async_save)

    logger.info("Execute save checkpoint process.")

    if isinstance(save_obj, nn.Cell):
        save_obj.init_parameters_data()
        param_dict = {}
        for _, param in save_obj.parameters_and_names():
            param_dict[param.name] = param
        param_list = []
        for (key, value) in param_dict.items():
            each_param = {"name": key}
            param_data = Tensor(value.data)

            # in the automatic model parallel scenario, some parameters were split across all the devices,
            # which should be combined before saving
            if integrated_save and key in save_obj.parameter_layout_dict:
                param_data = _get_merged_param_data(save_obj, key, param_data)

            each_param["data"] = param_data
            param_list.append(each_param)
        save_obj = param_list

    data_list = {}
    with _ckpt_mutex:
        for param in save_obj:
            key = param["name"]
            data_list[key] = []
            if isinstance(param["data"], Parameter):
                param["data"].init_data()
            dims = []
            if param['data'].shape == ():
                dims.append(0)
            else:
                for dim in param['data'].shape:
                    dims.append(dim)
            data_list[key].append(dims)
            tensor_type = str(param["data"].dtype)
            data_list[key].append(tensor_type)
            data = param["data"].asnumpy().reshape(-1)
            data_list[key].append(data)

    if async_save:
        thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list), name="asyn_save_ckpt")
        thr.start()
    else:
        _exec_save(ckpt_file_name, data_list)

    logger.info("Save checkpoint process finish.")
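# A minimal usage sketch of both accepted save_obj forms; net and the file names are
# placeholders, not part of the source above.
#
#   # 1) Save a whole Cell.
#   save_checkpoint(net, "net.ckpt")
#
#   # 2) Save an explicit list of {"name", "data"} dictionaries, e.g. a filtered subset.
#   params = [{"name": p.name, "data": p.data} for p in net.trainable_params()]
#   save_checkpoint(params, "net_trainable.ckpt", async_save=True)
#
# With async_save=True the serialization runs in a background thread named
# "asyn_save_ckpt", so the call returns before the file is fully written.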
def test_check_negative_int1():
    a = np.random.randint(-100, -1)
    assert Validator.check_negative_int(a) == a
def test_check_bool_1():
    assert Validator.check_bool(True)
def test_check_non_positive_int1():
    a = np.random.randint(-100, 0)
    assert Validator.check_non_positive_int(a) == a
def test_check_bool_2():
    assert Validator.check_bool(False) is not True
def conv2d(x, weight, bias=None, stride=1, pad=0,
           dilation=1, groups=1, padding_mode='zeros'):
    """Convolution 2D."""
    # pylint: disable=unused-argument
    validator.check_value_type('stride', stride, (int, tuple), None)
    if isinstance(stride, int):
        stride = (stride, stride)
    elif len(stride) == 4:
        stride = (stride[2], stride[3])
    if len(stride) != 2 or (not isinstance(stride[0], int)) or \
            (not isinstance(stride[1], int)) or \
            stride[0] < 1 or stride[1] < 1:
        raise ValueError(f"The 'stride' of 'conv2d' should be a positive int number or "
                         f"a tuple of two positive int numbers, but got {stride}")
    stride_h = stride[0]
    stride_w = stride[1]
    validator.check_value_type('dilation', dilation, (int, tuple), None)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)
    elif len(dilation) == 4:
        dilation = (dilation[2], dilation[3])
    if len(dilation) != 2 or (not isinstance(dilation[0], int)) or \
            (not isinstance(dilation[1], int)) or \
            dilation[0] < 1 or dilation[1] < 1:
        raise ValueError(f"The 'dilation' of 'conv2d' should be a positive int number or "
                         f"a tuple of two positive int numbers, but got {dilation}")
    dilation_h = dilation[0]
    dilation_w = dilation[1]

    if isinstance(pad, int):
        pad_top = pad
        pad_bottom = pad
        pad_left = pad
        pad_right = pad
    elif isinstance(pad, tuple) and len(pad) == 4:
        pad_top, pad_bottom, pad_left, pad_right = pad
    else:
        raise ValueError(f"The 'pad' should be an int number or "
                         f"a tuple of two or four int numbers, but got {pad}")

    batch_num, _, x_h, x_w = x.shape
    filter_num, _, filter_h, filter_w = weight.shape
    out_h = 1 + int((x_h + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation_h - 1)) / stride_h)
    out_w = 1 + int((x_w + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation_w - 1)) / stride_w)
    col = im2col(x, filter_h, filter_w, stride, pad, dilation)
    col_w = np.reshape(weight, (filter_num, -1)).T
    out = np.dot(col, col_w)
    out = out.reshape(batch_num, out_h, out_w, -1).transpose(0, 3, 1, 2)
    if bias is not None:
        out += bias
    return out
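# A minimal usage sketch (random NumPy inputs, assuming im2col is available in the
# same module) showing the NCHW shapes and the output-size formula used above:
#   out_h = 1 + (x_h + pad_top + pad_bottom - effective_filter_h) // stride_h,
# where effective_filter_h = filter_h + (filter_h - 1) * (dilation_h - 1).
#
#   x = np.random.randn(2, 3, 32, 32).astype(np.float32)   # (N, C_in, H, W)
#   w = np.random.randn(8, 3, 3, 3).astype(np.float32)     # (C_out, C_in, kH, kW)
#   y = conv2d(x, w, stride=1, pad=1)
#   # y.shape == (2, 8, 32, 32): out_h = 1 + (32 + 1 + 1 - 3) / 1 = 32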
def test_check_bool_3():
    with pytest.raises(TypeError):
        Validator.check_bool("str")
def _check_input_filter_size(input_shape, param_name, filter_size, func_name):
    _check_input_4d(input_shape, param_name, func_name)
    validator.check(param_name + " shape[2]", input_shape[2], "filter_size", filter_size, Rel.GE, func_name)
    validator.check(param_name + " shape[3]", input_shape[3], "filter_size", filter_size, Rel.GE, func_name)
def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0):
    super(Optimizer, self).__init__(auto_prefix=False)
    if parameters is not None and not isinstance(parameters, list):
        parameters = list(parameters)

    if not parameters:
        raise ValueError("Optimizer got an empty parameter list.")

    if not isinstance(parameters[0], (dict, Parameter)):
        raise TypeError("Only a list of Parameter or dict can be supported.")

    if isinstance(loss_scale, int):
        loss_scale = float(loss_scale)
    validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
    validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name)
    self.loss_scale = loss_scale

    weight_decay = self._preprocess_weight_decay(weight_decay)

    self.dynamic_lr = False
    self.assignadd = None
    self.global_step = None
    self.is_group = False
    self.is_group_lr = False
    self.is_group_params_ordered = False
    learning_rate = self._preprocess_single_lr(learning_rate)
    if isinstance(parameters[0], dict):
        self.is_group = True
        self.group_params = []
        self.group_lr = []
        self.group_weight_decay = []
        self._init_group_params(parameters, learning_rate, weight_decay)

    # The final value of dynamic_lr can be determined after the process of parse_single_lr and init_group_params
    if self.dynamic_lr:
        self.assignadd = P.AssignAdd()
        self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')

    if self.is_group_lr:
        if self.dynamic_lr:
            self.learning_rate = CellList(self.group_lr)
        else:
            self.learning_rate = ParameterTuple(self.group_lr)
    else:
        self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate')

    if self.is_group:
        self.parameters = ParameterTuple(self.group_params)
        self.weight_decay = tuple(self.group_weight_decay)
        decay_filter = lambda x: x > 0
        self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
        self.exec_weight_decay = any(self.decay_flags)
    else:
        self.parameters = ParameterTuple(parameters)
        self.weight_decay = weight_decay * loss_scale
        decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
        self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
        self.exec_weight_decay = self.weight_decay > 0

    ps_filter = lambda x: x.is_param_ps
    self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
    self.reciprocal_scale = 1.0 / loss_scale
    self.param_length = len(self.parameters)
    self.map_ = C.Map()

    use_parallel = context.get_auto_parallel_context("enable_parallel_optimizer")
    self.use_parallel = use_parallel
    if use_parallel:
        if self.cls_name not in ["Lamb", "AdamWeightDecay"]:
            raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
        if _get_parallel_mode() != ParallelMode.DATA_PARALLEL:
            raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format(
                _get_parallel_mode()))
        self.dev_num = _get_device_num()
        if self.dev_num > self.param_length:
            raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is"
                               " less than the number of devices {}".format(self.param_length, self.dev_num))
        self.param_rank = self._get_parameter_group_id()
        self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
        self.param_names = []
        for param in self.parameters:
            self.param_names.append(param.name)
    else:
        self.optim_filter = (True,) * self.param_length