def __init__(self, strategy1, strategy2):
    super().__init__()
    # Semi-auto-parallel test cell: two sharded MatMuls around a fused batch norm.
    self.matmul1 = P.MatMul().shard(strategy1)
    self.norm = P.FusedBatchNormEx()
    # Per-channel affine parameters and running statistics for 64 channels.
    self.gamma = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="gamma")
    self.beta = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="beta")
    self.mean = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="mean")
    self.var = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="var")
    self.matmul2 = P.MatMul().shard(strategy2)
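For context, a minimal `construct` for this cell would chain the two sharded MatMuls around the fused batch norm. The original test body is not shown here, so the input names below (`x`, `w1`, `w2`) are assumptions; the `[0]` index reflects that `FusedBatchNormEx` returns a tuple whose first element is the normalized output. Treat this as a sketch, not the exact test code:

def construct(self, x, w1, w2):
    # Hypothetical forward pass: MatMul -> FusedBatchNormEx -> MatMul.
    out = self.matmul1(x, w1)
    # FusedBatchNormEx returns a tuple; index 0 is the normalized tensor.
    out = self.norm(out, self.gamma, self.beta, self.mean, self.var)[0]
    return self.matmul2(out, w2)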
def __init__(self, num_features, gamma_init, beta_init, mean_init, var_init, use_batch_statistics=None):
    super(NetFusedBatchNormExDynamic, self).__init__()
    # GPU dynamic-shape test net: exercise FusedBatchNormEx after converting
    # the input to a dynamic-shape tensor. Note use_batch_statistics is
    # accepted but unused in this constructor.
    self.bn = P.FusedBatchNormEx(mode=1, epsilon=0.00001, momentum=0.1)
    self.moving_mean = Parameter(initializer(mean_init, num_features), name="mean", requires_grad=False)
    self.moving_variance = Parameter(initializer(var_init, num_features), name="variance", requires_grad=False)
    self.gamma = Parameter(initializer(gamma_init, num_features), name="gamma", requires_grad=True)
    self.beta = Parameter(initializer(beta_init, num_features), name="beta", requires_grad=True)
    self.dynshape = inner.GpuConvertToDynamicShape()
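A sketch of the matching `construct`, assuming the usual pattern in these dynamic-shape tests: first convert the input with `GpuConvertToDynamicShape`, then run the fused kernel and keep the first tuple element. This is an assumed reconstruction, not the verbatim test:

def construct(self, x):
    # Convert to a dynamic-shape tensor so the kernel is exercised on the
    # dynamic-shape code path.
    x = self.dynshape(x)
    # [0] selects the normalized output from the tuple FusedBatchNormEx returns.
    return self.bn(x, self.gamma, self.beta, self.moving_mean, self.moving_variance)[0]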
def __init__(self,
             num_features,
             eps=1e-5,
             momentum=0.9,
             affine=True,
             gamma_init='ones',
             beta_init='zeros',
             moving_mean_init='zeros',
             moving_var_init='ones',
             use_batch_statistics=None,
             device_num_each_group=1,
             input_dims='2d',
             data_format='NCHW'):
    super(_BatchNorm, self).__init__()
    # Validate constructor arguments.
    if num_features < 1:
        raise ValueError("num_features must be at least 1")
    if momentum < 0 or momentum > 1:
        raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
    self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
    if context.get_context("device_target") != "GPU" and self.format == "NHWC":
        raise ValueError("NHWC format is only supported on the GPU target.")
    self.use_batch_statistics = use_batch_statistics
    self.num_features = num_features
    self.eps = eps
    self.input_dims = input_dims
    # Running statistics are not trained; gamma/beta are trainable only when affine=True.
    self.moving_mean = Parameter(initializer(
        moving_mean_init, num_features), name="mean", requires_grad=False)
    self.moving_variance = Parameter(initializer(
        moving_var_init, num_features), name="variance", requires_grad=False)
    self.gamma = Parameter(initializer(
        gamma_init, num_features), name="gamma", requires_grad=affine)
    self.beta = Parameter(initializer(
        beta_init, num_features), name="beta", requires_grad=affine)
    # Cross-device ("global") batch norm: build communication groups when
    # statistics are synchronized across more than one device.
    self.group = validator.check_positive_int(device_num_each_group)
    self.is_global = False
    if self.group != 1:
        self.rank_id = get_rank()
        self.rank_size = get_group_size()
        self.device_list = [i for i in range(0, self.rank_size)]
        self.rank_list = self.list_group(self.device_list, self.group)
        self.rank_list_idx = len(self.rank_list)
        for i in range(self.rank_list_idx):
            if self.rank_id in self.rank_list[i] and self.group != 1:
                self.is_global = True
                management.create_group('group' + str(i), self.rank_list[i])
                self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
    self.shape = P.Shape()
    self.reduce_mean = P.ReduceMean(keep_dims=True)
    self.square = P.Square()
    self.sqrt = P.Sqrt()
    self.cast = P.Cast()
    self.dtype = P.DType()
    self.reshape = P.Reshape()
    self.is_ascend = context.get_context("device_target") == "Ascend"
    self.is_gpu = context.get_context("device_target") == "GPU"
    self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
    # The nn-layer momentum (fraction of the old statistic kept) is converted
    # to the ops-level momentum (fraction of the new statistic applied).
    self.momentum = 1.0 - momentum
    if context.get_context("enable_ge"):
        self.is_ge_backend = True
    else:
        self.is_ge_backend = False
    # Pick the training kernel per backend: plain BatchNorm on Ascend/GE in
    # graph mode, FusedBatchNormEx on GPU, FusedBatchNorm otherwise.
    if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
        self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps)
    elif self.is_gpu:
        self.bn_train = P.FusedBatchNormEx(mode=1,
                                           epsilon=self.eps,
                                           momentum=self.momentum,
                                           data_format=self.format)
    else:
        self.bn_train = P.FusedBatchNorm(mode=1,
                                         epsilon=self.eps,
                                         momentum=self.momentum)
    self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
    self.enable_global_sync = self.is_global and (self.is_ge_backend or (self.is_graph_mode and self.is_ascend))
    self.enable_default_train = self.is_graph_mode and not self.is_global and \
                                (self.is_ge_backend or self.is_ascend)

    # Shard the ops used for the manual moving-statistics update on the
    # default-train path.
    data_parallel_strategy = ((1,), (1,))
    data_parallel_strategy_one = ((1,), ())
    self.sub_mean = P.Sub().shard(data_parallel_strategy)
    self.sub_var = P.Sub().shard(data_parallel_strategy)
    self.mul_mean = P.Mul().shard(data_parallel_strategy_one)
    self.mul_var = P.Mul().shard(data_parallel_strategy_one)
    self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
    self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
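The constructor only selects kernels; the actual train/infer dispatch happens in `construct`. Below is a simplified sketch of that dispatch following the standard `_BatchNorm` pattern; the real method also covers the global-sync and default-train update paths, which are omitted here:

def construct(self, x):
    # use_batch_statistics=None means "follow self.training"; an explicit
    # True/False overrides the cell's training flag.
    if self.use_batch_statistics is None:
        flag = self.training
    else:
        flag = self.use_batch_statistics
    if flag:
        # The training kernel also updates moving_mean/moving_variance in place.
        return self.bn_train(x, self.gamma, self.beta,
                             self.moving_mean, self.moving_variance)[0]
    # Inference normalizes with the frozen running statistics.
    return self.bn_infer(x, self.gamma, self.beta,
                         self.moving_mean, self.moving_variance)[0]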