Example 1
def _weight_norm(v, g, dim):
    shape = v.shape
    ndims = len(shape)

    if dim is None:
        v_normalized = v / (F.sqrt(F.reduce_sum(F.square(v))) + 1e-12)
    elif dim == 0:
        p_matrix = F.reshape(v, (shape[0], -1))
        v_normalized = F.l2_normalize(p_matrix, axis=1)
        v_normalized = F.reshape(v_normalized, shape)
    elif dim == -1 or dim == ndims - 1:
        p_matrix = F.reshape(v, (-1, shape[-1]))
        v_normalized = F.l2_normalize(p_matrix, axis=0)
        v_normalized = F.reshape(v_normalized, shape)
    else:
        perm = list(range(ndims))
        perm[0] = dim
        perm[dim] = 0
        p_transposed = F.transpose(v, perm)
        transposed_shape = p_transposed.shape
        p_matrix = F.reshape(p_transposed, (p_transposed.shape[0], -1))
        v_normalized = F.l2_normalize(p_matrix, axis=1)
        v_normalized = F.reshape(v_normalized, transposed_shape)
        v_normalized = F.transpose(v_normalized, perm)
    weight = F.elementwise_mul(v_normalized,
                               g,
                               axis=dim if dim is not None else -1)
    return weight
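
As a quick sanity check, the decomposition this helper implements (a direction tensor v normalized over every axis except dim, then rescaled by g) can be reproduced with plain NumPy. The sketch below assumes dim=0 and made-up shapes; it is only illustrative and not part of the original snippet.

import numpy as np

# NumPy sketch of weight norm for dim=0: each output slice v[i] is divided
# by its own L2 norm and then scaled by g[i].
v = np.random.randn(4, 3, 5).astype("float32")
g = (np.random.rand(4) + 0.5).astype("float32")

flat = v.reshape(v.shape[0], -1)                                  # (4, 15)
direction = flat / np.sqrt((flat ** 2).sum(axis=1, keepdims=True))
w = direction.reshape(v.shape) * g[:, None, None]

# every slice of the direction tensor has unit L2 norm
print(np.linalg.norm(direction, axis=1))                          # ~[1. 1. 1. 1.]
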
Example 2
 def test_square(self):
     program = Program()
     with program_guard(program):
         input = layers.data(name="input", shape=[16], dtype="float32")
         out = layers.square(input, name='square')
         self.assertIsNotNone(out)
     print(str(program))
Example 3
def layer_norm(x,
               begin_norm_axis=1,
               epsilon=1e-12,
               param_attr=None,
               bias_attr=None):
    """
    Replace the built-in layer_norm op with this function
    """
    helper = LayerHelper('layer_norm', **locals())
    mean = layers.reduce_mean(x, dim=begin_norm_axis, keep_dim=True)
    shift_x = layers.elementwise_sub(x=x, y=mean, axis=0)
    variance = layers.reduce_mean(
        layers.square(shift_x), dim=begin_norm_axis, keep_dim=True)
    r_stdev = layers.rsqrt(variance + epsilon)
    norm_x = layers.elementwise_mul(x=shift_x, y=r_stdev, axis=0)

    param_shape = [reduce(lambda x, y: x * y, norm_x.shape[begin_norm_axis:])]
    param_dtype = norm_x.dtype
    scale = helper.create_parameter(
        attr=param_attr,
        shape=param_shape,
        dtype=param_dtype,
        default_initializer=fluid.initializer.Constant(1.))
    bias = helper.create_parameter(
        attr=bias_attr,
        shape=param_shape,
        dtype=param_dtype,
        is_bias=True,
        default_initializer=fluid.initializer.Constant(0.))

    out = layers.elementwise_mul(x=norm_x, y=scale, axis=-1)
    out = layers.elementwise_add(x=out, y=bias, axis=-1)

    return out
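
The op composition above is the usual layer normalization: mean and (biased) variance over the trailing axes, an rsqrt, then a learned scale and bias. A small NumPy sketch of the same arithmetic, assuming begin_norm_axis=1 and a 2-D input (names here are illustrative):

import numpy as np

def layer_norm_np(x, scale, bias, begin_norm_axis=1, epsilon=1e-12):
    # normalize over every axis from begin_norm_axis onwards
    axes = tuple(range(begin_norm_axis, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    variance = ((x - mean) ** 2).mean(axis=axes, keepdims=True)
    norm_x = (x - mean) / np.sqrt(variance + epsilon)
    return norm_x * scale + bias

x = np.random.randn(8, 16).astype("float32")
out = layer_norm_np(x, scale=np.ones(16, "float32"), bias=np.zeros(16, "float32"))
print(out.mean(axis=1))  # ~0 for every row
print(out.std(axis=1))   # ~1 for every row
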
Example 4
def compute_l2_normalized_weight(v, g, dim):
    shape = v.shape
    ndim = len(shape)

    if dim is None:
        v_normalized = v / (F.sqrt(F.reduce_sum(F.square(v))) + 1e-12)
    elif dim == 0:
        param_matrix = F.reshape(v, (shape[0], np.prod(shape[1:])))
        v_normalized = F.l2_normalize(param_matrix, axis=1)
    elif dim == -1 or dim == ndim - 1:
        param_matrix = F.reshape(v, (np.prod(shape[:-1]), shape[-1]))
        v_normalized = F.l2_normalize(param_matrix, axis=0)
    else:
        perm = list(range(ndim))
        perm[0] = dim
        perm[dim] = 0
        transposed_param = F.transpose(v, perm)
        param_matrix = F.reshape(
            transposed_param,
            (transposed_param.shape[0], np.prod(transposed_param.shape[1:])))
        v_normalized = F.l2_normalize(param_matrix, axis=1)
        # reshape back to the transposed layout before undoing the transpose
        v_normalized = F.reshape(v_normalized, transposed_param.shape)
        v_normalized = F.transpose(v_normalized, perm)
    v_normalized = F.reshape(v_normalized, shape)
    weight = F.elementwise_mul(v_normalized, g, axis=dim)
    return weight
Example 5
 def _distance(self, anchor_emb, other_emb):
     """¼ÆËãÁ½Êä³ö¾ØÕóµÄ¾àÀë
     """
     square_out = L.square(anchor_emb - other_emb)
     #logging.info("square_out shape: {}".format(square_out.shape))
     distance = L.reduce_sum(square_out, 1)
     #logging.info("distance shape: {}".format(distance.shape))
     return distance
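
The reduction above is simply the squared Euclidean distance per row. A NumPy equivalent with assumed shapes, for illustration only:

import numpy as np

anchor_emb = np.random.randn(4, 128).astype("float32")
other_emb = np.random.randn(4, 128).astype("float32")

# matches reduce_sum(square(anchor - other), dim=1)
distance = ((anchor_emb - other_emb) ** 2).sum(axis=1)
print(distance.shape)  # (4,)
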
Example 6
    def _dygraph_clip(self, params_grads):
        params_and_grads = []
        # clip by value first
        for p, g in params_grads:
            if g is None:
                continue
            if self._need_clip_func is not None and not self._need_clip_func(
                    p):
                params_and_grads.append((p, g))
                continue
            new_grad = layers.clip(x=g,
                                   min=-self.clip_value,
                                   max=self.clip_value)
            params_and_grads.append((p, new_grad))
        params_grads = params_and_grads

        # clip by global norm
        params_and_grads = []
        sum_square_list = []
        for p, g in params_grads:
            if g is None:
                continue
            if self._need_clip_func is not None and not self._need_clip_func(
                    p):
                continue
            merge_grad = g
            if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                merge_grad = layers.merge_selected_rows(g)
                merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
            square = layers.square(merge_grad)
            sum_square = layers.reduce_sum(square)
            sum_square_list.append(sum_square)

        # all parameters have been filtered out
        if len(sum_square_list) == 0:
            return params_grads

        global_norm_var = layers.concat(sum_square_list)
        global_norm_var = layers.reduce_sum(global_norm_var)
        global_norm_var = layers.sqrt(global_norm_var)
        max_global_norm = layers.fill_constant(shape=[1],
                                               dtype='float32',
                                               value=self.clip_norm)
        clip_var = layers.elementwise_div(x=max_global_norm,
                                          y=layers.elementwise_max(
                                              x=global_norm_var,
                                              y=max_global_norm))
        for p, g in params_grads:
            if g is None:
                continue
            if self._need_clip_func is not None and not self._need_clip_func(
                    p):
                params_and_grads.append((p, g))
                continue
            new_grad = layers.elementwise_mul(x=g, y=clip_var)
            params_and_grads.append((p, new_grad))

        return params_and_grads
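
The second stage above implements the usual clip-by-global-norm rule: every gradient is multiplied by clip_norm / max(global_norm, clip_norm), so gradients are left untouched while the global norm is already within the limit. A NumPy sketch of just that scaling rule (names and shapes are assumptions):

import numpy as np

def clip_by_global_norm(grads, clip_norm):
    # grads: list of NumPy arrays; returns rescaled copies
    global_norm = np.sqrt(sum((g ** 2).sum() for g in grads))
    scale = clip_norm / max(global_norm, clip_norm)
    return [g * scale for g in grads]

grads = [np.random.randn(3, 3), np.random.randn(10)]
clipped = clip_by_global_norm(grads, clip_norm=1.0)
print(np.sqrt(sum((g ** 2).sum() for g in clipped)))  # <= 1.0
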
Example 7
 def loss(self, x, label):
     '''
     Args:
         label: the original image
         x: the image after decoding
     Return:
         the Euclidean distance between the original image and the decoded image
     '''
     return F.square(x - label)
Example 8
 def forward(self, output1, output2, label):
     """
     :param output1: [n, 128]
     :param output2: [n, 128]
     :param label: [n, 1]
     :return: [1]
     """
     distance = layers.elementwise_sub(output1, output2)
     distance = layers.square(distance)
     euclidean_distance = layers.reduce_sum(distance, dim=1, keep_dim=True)
     euclidean_distance = layers.sqrt(euclidean_distance)
     loss_contrastive = layers.elementwise_mul(
         1 - label, layers.square(euclidean_distance),
         axis=0) + layers.elementwise_mul(
             label,
             layers.square(
                 layers.clamp(self.margin - euclidean_distance, min=0.0)),
             axis=0)
     return loss_contrastive, euclidean_distance.numpy(), label.numpy()
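
For reference, the expression above is the standard contrastive loss, (1 - label) * d^2 + label * max(margin - d, 0)^2, where d is the Euclidean distance between the two embeddings and label = 1 marks a dissimilar pair as used in the snippet. A NumPy version of the same formula (the margin value is an assumption):

import numpy as np

def contrastive_loss_np(output1, output2, label, margin=1.0):
    # label: (n, 1), 0 = similar pair, 1 = dissimilar pair (as in the snippet above)
    d = np.sqrt(((output1 - output2) ** 2).sum(axis=1, keepdims=True))
    return (1 - label) * d ** 2 + label * np.maximum(margin - d, 0.0) ** 2

output1 = np.random.randn(4, 128)
output2 = np.random.randn(4, 128)
label = np.array([[0.], [1.], [0.], [1.]])
print(contrastive_loss_np(output1, output2, label).shape)  # (4, 1)
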
Example 9
def var(x, dim, unbiased=True, keepdim=False):
    # unbiased variance
    shape = x.shape
    if isinstance(dim, int):
        e = shape[dim]
    else:
        e = int(np.prod([shape[d] for d in dim]))
    if unbiased:
        e -= 1
    return L.reduce_sum(
        L.square(x - L.reduce_mean(x, dim=dim, keep_dim=True)),
        dim=dim,
        keep_dim=keepdim) / e
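
A quick NumPy check of what this helper computes, assuming x is a plain array and dim is a single axis: it matches the unbiased variance np.var(..., ddof=1).

import numpy as np

x = np.random.randn(4, 5).astype("float32")

# same reduction as var(x, dim=1) above, written in NumPy
manual = ((x - x.mean(axis=1, keepdims=True)) ** 2).sum(axis=1) / (x.shape[1] - 1)
print(np.allclose(manual, x.var(axis=1, ddof=1)))  # True
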
Example 10
 def loss(self, x, label, rote_list):
     '''
     Args:
         label: the original image
         x: the image after decoding
     Return:
         the Euclidean distance between the original image and the decoded image
     '''
     index = rote_list.index('ed')
     total_kl_loss = getattr(self, 'AE' + str(index)).kl_loss
     return F.square(x - label) + total_kl_loss, total_kl_loss
Example 11
    def func(self, place):
        # the shape of the input variable should be clearly specified and should not include -1.
        shape = [2, 3, 7, 9]
        eps = 0.005
        dtype = np.float64

        x = layers.data('x', shape, False, dtype)
        x.persistable = True
        y = layers.square(x)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)

        gradient_checker.double_grad_check(
            [x], y, x_init=x_arr, place=place, eps=eps)
Example 12
 def forward(self, x):
     """ Forward process of LayerNorm. """
     mean = layers.reduce_mean(x,
                               dim=list(range(self._begin_norm_axis, len(x.shape))),
                               keep_dim=True)
     shift_x = layers.elementwise_sub(x=x, y=mean, axis=0)
     variance = layers.reduce_mean(layers.square(shift_x),
                                   dim=list(range(self._begin_norm_axis, len(x.shape))),
                                   keep_dim=True)
     r_stdev = layers.rsqrt(variance + self._epsilon)
     norm_x = layers.elementwise_mul(x=shift_x, y=r_stdev, axis=0)
     out = layers.elementwise_mul(x=norm_x, y=self._scale_w, axis=-1)
     out = layers.elementwise_add(x=out, y=self._bias_w, axis=-1)
     return out
Example 13
    def forward(self, pred, target):
        target = 1 - target[:, 0]
        batch_size, vector_size = pred.shape[0], pred.shape[1]

        pred = L.l2_normalize(pred, axis=1, epsilon=1e-10)

        square_norm = L.reduce_sum(L.square(pred), dim=1)
        dist = L.elementwise_add(-2.0 * L.matmul(pred, pred, transpose_y=True),
                                 square_norm,
                                 axis=0)
        dist = L.elementwise_add(dist, square_norm, axis=1)
        dist = L.elementwise_max(dist, L.zeros_like(dist))
        dist = L.sqrt(dist)

        ap_dist = L.reshape(dist, (0, 0, 1))
        an_dist = L.reshape(dist, (0, 1, -1))

        loss = L.expand(ap_dist, (1, 1, batch_size)) - L.expand(
            an_dist, (1, batch_size, 1)) + self.magin

        indice_equal = L.diag(
            L.fill_constant((batch_size, ), dtype='float32', value=1.0))
        indice_not_equal = 1.0 - indice_equal

        broad_matrix = L.expand(L.reshape(target, (-1, 1)),
                                (1, batch_size)) + L.expand(
                                    L.reshape(target, (1, -1)),
                                    (batch_size, 1))

        pp = L.cast(L.equal(broad_matrix, L.zeros_like(broad_matrix)),
                    dtype='float32')
        pp = L.reshape(indice_not_equal * pp, (0, 0, 1))

        pn = L.cast(L.equal(broad_matrix,
                            L.zeros_like(broad_matrix) + 1),
                    dtype='float32')
        pn = L.reshape(indice_not_equal * pn, (1, 0, -1))

        apn = L.expand(pp,
                       (1, 1, batch_size)) * L.expand(pn, (batch_size, 1, 1))

        loss = loss * L.cast(apn, dtype='float32')
        loss = L.elementwise_max(loss, L.zeros_like(loss))

        num_tri = L.reduce_sum(
            L.cast(L.greater_than(loss, L.zeros_like(loss)), dtype='float32'))

        loss = L.reduce_sum(loss) * self.loss_weight / (num_tri + 1e-16)

        return loss
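
The distance matrix at the top of this loss uses the identity ||a - b||^2 = ||a||^2 - 2 a·b + ||b||^2 evaluated for every pair at once, followed by a clamp at zero to absorb small negative values from rounding. A NumPy sketch of that pairwise-distance trick (shapes are illustrative):

import numpy as np

pred = np.random.randn(6, 32).astype("float32")
pred /= np.linalg.norm(pred, axis=1, keepdims=True)    # l2-normalize rows

sq = (pred ** 2).sum(axis=1)                            # ||x_i||^2, shape (6,)
dist2 = -2.0 * pred @ pred.T + sq[:, None] + sq[None, :]
dist2 = np.maximum(dist2, 0.0)                          # clamp tiny negatives
dist = np.sqrt(dist2)

# agrees with the direct pairwise computation
direct = ((pred[:, None, :] - pred[None, :, :]) ** 2).sum(axis=-1)
print(np.allclose(dist2, direct, atol=1e-4))            # True
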
Example 14
def norm_except_dim(p, dim):
    shape = p.shape
    ndims = len(shape)
    if dim is None:
        return F.sqrt(F.reduce_sum(F.square(p)))
    elif dim == 0:
        p_matrix = F.reshape(p, (shape[0], -1))
        return l2_norm(p_matrix, axis=1)
    elif dim == -1 or dim == ndims - 1:
        p_matrix = F.reshape(p, (-1, shape[-1]))
        return l2_norm(p_matrix, axis=0)
    else:
        perm = list(range(ndims))
        perm[0] = dim
        perm[dim] = 0
        p_transposed = F.transpose(p, perm)
        return norm_except_dim(p_transposed, 0)
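
In NumPy terms this returns one L2 norm per slice along dim, i.e. the norm taken over every other axis. A short equivalent for the dim=0 case, purely for illustration:

import numpy as np

p = np.random.randn(4, 3, 5)
# one norm per slice p[i], taken over all remaining axes
norms = np.linalg.norm(p.reshape(p.shape[0], -1), axis=1)
print(norms.shape)  # (4,)
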
Example 15
    def func(self, place):
        # the shape of the input variable should be clearly specified and should not include -1.
        shape = [2, 3, 7, 9]
        eps = 0.005
        dtype = np.float64

        x = layers.data('x', shape, False, dtype)
        x.persistable = True
        y = layers.square(x)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)

        gradient_checker.double_grad_check(
            [x], y, x_init=x_arr, place=place, eps=eps)
        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
        gradient_checker.double_grad_check_for_dygraph(
            self.square_wrapper, [x], y, x_init=x_arr, place=place)
        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
    def build_program(self, dtype):
        with fluid.program_guard(self.main_program, self.startup_program):
            self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 3)
            self.feed_vars.append(
                fluid.data(name="data3", shape=[128, 128], dtype=dtype))

            # subgraph with 2 op nodes
            tmp_0 = layers.sum(
                [self.feed_vars[0], self.feed_vars[1], self.feed_vars[2]])
            tmp_1 = layers.sqrt(tmp_0)
            tmp_2 = layers.mul(tmp_0, self.feed_vars[3])
            # subgraph with 2 op nodes
            tmp_3 = layers.square(layers.sum([tmp_1, tmp_2]))

        self.append_gradients(tmp_3)

        self.num_fused_ops = 4
        self.fetch_list = [tmp_3, self.grad(tmp_0)]
Example 17
    def _dygraph_clip(self, params_grads):
        params_and_grads = []

        sum_square_dist_fp16 = []
        sum_square_dist_fp32 = []
        sum_square_not_dist_fp16 = []
        sum_square_not_dist_fp32 = []

        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                continue
            merge_grad = g
            if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                merge_grad = layers.merge_selected_rows(g)
                merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
            square = layers.square(merge_grad)
            sum_square = layers.reduce_sum(square)

            not_shared_enable = (not hasattr(p, 'is_firstly_shared')) or (
                hasattr(p, 'is_firstly_shared')
                and getattr(p, 'is_firstly_shared', True))

            if not_shared_enable:
                if p.is_distributed:
                    if p.dtype == paddle.float16:
                        sum_square_dist_fp16.append(sum_square)
                    elif p.dtype == paddle.float32:
                        sum_square_dist_fp32.append(sum_square)
                else:
                    if p.dtype == paddle.float16:
                        sum_square_not_dist_fp16.append(sum_square)
                    elif p.dtype == paddle.float32:
                        sum_square_not_dist_fp32.append(sum_square)

        # global norm of distributed FP16 params_and_grads
        if len(sum_square_dist_fp16) == 0:
            global_norm_dist_fp16 = paddle.to_tensor([0.],
                                                     dtype=paddle.float32)
        else:
            global_norm_dist_fp16 = layers.concat(sum_square_dist_fp16)
            global_norm_dist_fp16 = layers.reduce_sum(global_norm_dist_fp16)
            global_norm_dist_fp16 = paddle.cast(global_norm_dist_fp16,
                                                dtype=paddle.float32)

        # global norm of non-distributed FP16 params_and_grads
        if len(sum_square_not_dist_fp16) == 0:
            global_norm_not_dist_fp16 = paddle.to_tensor([0.],
                                                         dtype=paddle.float32)
        else:
            global_norm_not_dist_fp16 = layers.concat(sum_square_not_dist_fp16)
            global_norm_not_dist_fp16 = layers.reduce_sum(
                global_norm_not_dist_fp16)
            global_norm_not_dist_fp16 = paddle.cast(global_norm_not_dist_fp16,
                                                    dtype=paddle.float32)

        # global norm of distributed FP32 params_and_grads
        global_norm_dist_fp32 = layers.concat(sum_square_dist_fp32) if len(
            sum_square_dist_fp32) != 0 else paddle.to_tensor(
                [0.], dtype=paddle.float32)
        global_norm_dist_fp32 = layers.reduce_sum(global_norm_dist_fp32)

        # global norm of non-distributed FP32 params_and_grads
        global_norm_not_dist_fp32 = layers.concat(
            sum_square_not_dist_fp32
        ) if len(sum_square_not_dist_fp32) != 0 else paddle.to_tensor(
            [0.], dtype=paddle.float32)
        global_norm_not_dist_fp32 = layers.reduce_sum(
            global_norm_not_dist_fp32)

        global_norm_var_dist = global_norm_dist_fp16 + global_norm_dist_fp32
        global_norm_var_not_dist = global_norm_not_dist_fp16 + global_norm_not_dist_fp32

        # add all reduce to get global norm of distributed params_and_grads
        if self._hcg.get_model_parallel_world_size() > 1:
            paddle.distributed.all_reduce(
                global_norm_var_dist,
                group=self._hcg.get_check_parallel_group())

        # add all reduce to get global norm of non-distributed params_and_grads in groups of pp
        if self._hcg.get_pipe_parallel_world_size() > 1:
            paddle.distributed.all_reduce(
                global_norm_var_not_dist,
                group=self._hcg.get_pipe_parallel_group())

        # In sharding mode, parameters and gradients may be mapped to different ranks in the optimizer,
        # so ClipGradByGlobalNorm needs an all-reduce to get the global norm.
        if self._hcg.get_sharding_parallel_world_size() > 1:
            paddle.distributed.all_reduce(
                global_norm_var_not_dist,
                group=self._hcg.get_sharding_parallel_group())

        global_norm_var_fp32 = layers.sqrt(global_norm_var_dist +
                                           global_norm_var_not_dist)

        max_global_norm = layers.fill_constant(
            shape=[1], dtype=global_norm_var_fp32.dtype, value=self.clip_norm)
        clip_var = layers.elementwise_div(x=max_global_norm,
                                          y=layers.elementwise_max(
                                              x=global_norm_var_fp32,
                                              y=max_global_norm))
        clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                params_and_grads.append((p, g))
                continue
            if p.dtype == paddle.float16:
                new_grad = layers.elementwise_mul(x=g, y=clip_var_fp16)
            else:
                new_grad = layers.elementwise_mul(x=g, y=clip_var)
            params_and_grads.append((p, new_grad))

        return params_and_grads
Example 18
    def _dygraph_clip(self, params_grads):
        sum_square_fp32, sum_square_fp16 = [], []
        unslice_params_fp32, unslice_params_fp16 = [], []

        for p, g in params_grads:
            p_slice = True  # marks parameters that are sliced in sharding stage3
            if g is None or getattr(p, 'need_clip', True) is False:
                continue
            if hasattr(p, "unslice"):
                p_slice = False

            merge_grad = g
            if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                merge_grad = layers.get_tensor_from_selected_rows(
                    layers.merge_selected_rows(g))
            square = layers.square(merge_grad)
            sum_square = layers.reduce_sum(square)

            if p.dtype == paddle.float16:
                if p_slice: sum_square_fp16.append(sum_square)
                else: unslice_params_fp16.append(sum_square)
            elif p.dtype == paddle.float32:
                if p_slice: sum_square_fp32.append(sum_square)
                else: unslice_params_fp32.append(sum_square)

        # global norm of non-distributed FP16 params_and_grads
        if len(sum_square_fp16) == 0:
            global_norm_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
        else:
            global_norm_fp16 = layers.concat(sum_square_fp16)
            global_norm_fp16 = layers.reduce_sum(global_norm_fp16)
            global_norm_fp16 = paddle.cast(
                global_norm_fp16, dtype=paddle.float32)

        # global norm of non-distributed FP16 params_and_grads for unsliced parameters
        if len(unslice_params_fp16) == 0:
            global_unslice_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
        else:
            global_unslice_fp16 = layers.concat(unslice_params_fp16)
            global_unslice_fp16 = layers.reduce_sum(global_unslice_fp16)
            global_unslice_fp16 = paddle.cast(
                global_unslice_fp16, dtype=paddle.float32)

        # global norm of non-distributed FP32 params_and_grads
        global_norm_fp32 = layers.concat(sum_square_fp32) if len(
            sum_square_fp32) != 0 else paddle.to_tensor(
                [0.], dtype=paddle.float32)
        global_norm_fp32 = layers.reduce_sum(global_norm_fp32)

        # global norm of non-distributed FP32 params_and_grads for unsliced parameters
        global_unslice_fp32 = layers.concat(unslice_params_fp32) if len(
            unslice_params_fp32) != 0 else paddle.to_tensor(
                [0.], dtype=paddle.float32)
        global_unslice_fp32 = layers.reduce_sum(global_unslice_fp32)
        global_unslice_var = global_unslice_fp16 + global_unslice_fp32

        global_norm_var = global_norm_fp16 + global_norm_fp32 + 1.0 / self._group.nranks * global_unslice_var

        # add all reduce to get global norm of distributed params_and_grads
        dev_id = int(self._device.split(":")[1])
        if paddle.device.get_device() == "cpu":
            global_norm_var = global_norm_var.cuda(dev_id)

        with device_guard(dev_id, "gpu"):
            paddle.distributed.all_reduce(global_norm_var, group=self._group)

        global_norm_var = layers.sqrt(global_norm_var)
        max_global_norm = layers.fill_constant(
            shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)

        clip_var = layers.elementwise_div(
            x=max_global_norm,
            y=layers.elementwise_max(
                x=global_norm_var, y=max_global_norm))
        clip_var_fp16 = paddle.cast(clip_var, paddle.float16)

        for p, g in params_grads:
            if getattr(p, 'need_clip', True) is False or g is None:
                continue
            origin_state = g.stop_gradient
            g.stop_gradient = True
            if p.dtype == paddle.float16:
                g.scale_(clip_var_fp16.item())
            else:
                g.scale_(clip_var.item())
            g.stop_gradient = origin_state
            # p._reset_grad_inplace_version(True)

        return params_grads
Example 19
                           dtype="float32",
                           append_batch_size=False)
lstm1_init_c = layers.data(name="lstm1_c",
                           shape=[1, batch_size, 50],
                           dtype="float32",
                           append_batch_size=False)

lstm1, lstm1_h, lstm1_c = basic_lstm(x,
                                     lstm1_init_h,
                                     lstm1_init_c,
                                     50,
                                     num_layers=1)
_, lstm2_h, lstm2_c = basic_lstm(lstm1, lstm1_h, lstm1_c, 50, num_layers=1)
lstm2_c_batch_first = layers.transpose(lstm2_c, [1, 0, 2])
pred = layers.fc(lstm2_c_batch_first, 1)
loss = layers.reduce_mean(layers.square(pred - y))

test_program = fluid.default_main_program().clone(for_test=True)

optimizer = fluid.optimizer.RMSPropOptimizer(learning_rate=0.001)
optimizer.minimize(loss)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())


def batch_generator(x_data, y_data, batch_size):
    batch_x, batch_y = [], []
    for sample_x, sample_y in zip(x_data, y_data):
        batch_x.append(sample_x)
        batch_y.append(sample_y)
    def forward(self, input, conv, conv_g):
        # deal with the weight and the grad of self.pre_dxdw!
        self._check_input_dim(input)
        N, C, H, W = input.shape
        NHW = N * H * W
        y = input  # [N, C, H, W]
        weight = conv.weight

        # burnin
        if self.training and self.burnin > 0:
            self.iter_count += 1
            self._update_buffer_num()

        if self.buffer_num > 0 and self.training and (
                not input.stop_gradient):  # some layers are frozen!
            # cal current batch mu and sigma
            cur_mu = L.reduce_mean(y, dim=[0, 2, 3], keep_dim=False)  # [C, ]
            if self.special_kernel is None:  # trick to compute (x - cur_mu) quickly
                special_kernel = np.ones((self.num_features, 1, 1, 1),
                                         np.float32)
                self.special_kernel = paddle.to_tensor(special_kernel)
                self.special_kernel.stop_gradient = True
            cur_sigma2 = F.conv2d(
                y, self.special_kernel, -cur_mu,
                groups=self.num_features)  # trick to compute (x - cur_mu) quickly
            cur_sigma2 = L.reduce_sum(
                L.square(cur_sigma2), dim=[0, 2, 3], keep_dim=False) / (
                    NHW - 1)  # [C, ]  the author's original implementation uses the sample variance, hence the -1 in the denominator

            y2 = L.square(y)
            cur_meanx2 = L.reduce_mean(y2, dim=[0, 2, 3],
                                       keep_dim=False)  # [C, ]

            # cal dmu/dw dsigma2/dw
            # dmudw = paddle.grad(outputs=[cur_mu], inputs=[weight], create_graph=False, retain_graph=True)[0]
            # dmeanx2dw = paddle.grad(outputs=[cur_meanx2], inputs=[weight], create_graph=False, retain_graph=True)[0]

            # our own way of computing these gradients
            dmudinput = np.zeros(input.shape, np.float32) + 1.0 / NHW
            dmudinput = paddle.to_tensor(dmudinput)
            dmeanx2dinput = input.numpy()
            dmeanx2dinput = paddle.to_tensor(dmeanx2dinput)
            dmeanx2dinput *= 2.0 / NHW
            dmudw = conv_g.get_grad_w(conv.weight, conv.bias, dmudinput)
            dmeanx2dw = conv_g.get_grad_w(conv.weight, conv.bias,
                                          dmeanx2dinput)

            # update cur_mu and cur_sigma2 with pres
            weight_data = weight.numpy()
            weight_data = paddle.to_tensor(weight_data)
            weight_data.stop_gradient = True
            # L.stack() raises an error here, so use L.concat() instead.
            mu_all = [
                cur_mu,
            ] + [
                tmp_mu + L.reduce_sum(self.rho * tmp_d * (weight_data - tmp_w),
                                      dim=[1, 2, 3]) for tmp_mu, tmp_d, tmp_w
                in zip(self.pre_mu, self.pre_dmudw, self.pre_weight)
            ]
            meanx2_all = [
                cur_meanx2,
            ] + [
                tmp_meanx2 + L.reduce_sum(
                    self.rho * tmp_d * (weight_data - tmp_w), dim=[1, 2, 3])
                for tmp_meanx2, tmp_d, tmp_w in zip(
                    self.pre_meanx2, self.pre_dmeanx2dw, self.pre_weight)
            ]
            mu_all = [L.unsqueeze(mu_, 0) for mu_ in mu_all]
            meanx2_all = [L.unsqueeze(meanx2_, 0) for meanx2_ in meanx2_all]
            mu_all = L.concat(mu_all, 0)
            meanx2_all = L.concat(meanx2_all, 0)

            sigma2_all = meanx2_all - L.square(mu_all)

            # with considering count
            re_mu_all = mu_all.clone()
            re_meanx2_all = meanx2_all.clone()
            mask1 = L.cast(sigma2_all >= 0., dtype="float32")
            mask1.stop_gradient = True
            re_mu_all *= mask1
            re_meanx2_all *= mask1
            count = L.reduce_sum(L.cast(sigma2_all >= 0., dtype="float32"),
                                 dim=[
                                     0,
                                 ])
            mu = L.reduce_sum(re_mu_all, dim=[
                0,
            ]) / count
            sigma2 = L.reduce_sum(re_meanx2_all, dim=[
                0,
            ]) / count - L.square(mu)

            cur_mu_ = cur_mu.numpy()
            cur_mu_ = paddle.to_tensor(cur_mu_)
            cur_mu_.stop_gradient = True
            self.pre_mu = [
                cur_mu_,
            ] + self.pre_mu[:(self.buffer_num - 1)]
            cur_meanx2_ = cur_meanx2.numpy()
            cur_meanx2_ = paddle.to_tensor(cur_meanx2_)
            cur_meanx2_.stop_gradient = True
            self.pre_meanx2 = [
                cur_meanx2_,
            ] + self.pre_meanx2[:(self.buffer_num - 1)]
            dmudw_ = dmudw.numpy()
            dmudw_ = paddle.to_tensor(dmudw_)
            dmudw_.stop_gradient = True
            self.pre_dmudw = [
                dmudw_,
            ] + self.pre_dmudw[:(self.buffer_num - 1)]
            dmeanx2dw_ = dmeanx2dw.numpy()
            dmeanx2dw_ = paddle.to_tensor(dmeanx2dw_)
            dmeanx2dw_.stop_gradient = True
            self.pre_dmeanx2dw = [
                dmeanx2dw_,
            ] + self.pre_dmeanx2dw[:(self.buffer_num - 1)]

            tmp_weight = weight.numpy()
            tmp_weight = paddle.to_tensor(tmp_weight)
            tmp_weight.stop_gradient = True
            self.pre_weight = [
                tmp_weight,
            ] + self.pre_weight[:(self.buffer_num - 1)]

        else:
            mu = L.reduce_mean(y, dim=[0, 2, 3], keep_dim=False)  # [C, ]
            if self.special_kernel is None:  # trick to compute (x - mu) quickly
                special_kernel = np.ones((self.num_features, 1, 1, 1),
                                         np.float32)
                self.special_kernel = paddle.to_tensor(special_kernel)
                self.special_kernel.stop_gradient = True
            sigma2 = F.conv2d(y,
                              self.special_kernel,
                              -mu,
                              groups=self.num_features)  # trick to compute (x - mu) quickly
            sigma2 = L.reduce_sum(L.square(sigma2),
                                  dim=[0, 2, 3],
                                  keep_dim=False) / (NHW - 1)  # [C, ]
            cur_mu = mu
            cur_sigma2 = sigma2

        if not self.training or self.FROZEN:  # eval() mode
            U = self._mean
            # TODO: outside **0.5?
            if self.out_p:
                std = L.sqrt(self._variance + self.eps)
            else:
                std = L.sqrt(self._variance) + self.eps

        else:  # train() mode
            if self.track_running_stats is True:
                state_dict = self.state_dict()
                momentum = self.momentum
                _mean = self._mean.numpy() * momentum + cur_mu.numpy() * (
                    1. - momentum)
                _variance = self._variance.numpy(
                ) * momentum + cur_sigma2.numpy() * (1. - momentum)
                state_dict['_mean'] = _mean.astype(np.float32)
                state_dict['_variance'] = _variance.astype(np.float32)
                self.set_state_dict(state_dict)
            U = mu
            # TODO: outside **0.5?
            if self.out_p:
                std = L.sqrt(sigma2 + self.eps)
            else:
                std = L.sqrt(sigma2) + self.eps

        A = self.weight / std  # [C, ]
        B = self.bias - U * A  # [C, ]
        A = L.unsqueeze(A, [1, 2, 3])  # [C, 1, 1, 1]
        y = F.conv2d(y, A, B, groups=self.num_features)
        return y
    def forward2(self, input, weight):
        # deal with the weight and the grad of self.pre_dxdw!
        self._check_input_dim(input)
        y = L.transpose(input, [1, 0, 2, 3])  # [C, N, H, W]
        return_shape = y.shape  # [C, N, H, W]
        C, N, H, W = return_shape
        NHW = N * H * W
        y = L.reshape(y, (return_shape[0], -1))  # [C, N*H*W]

        # burnin
        if self.training and self.burnin > 0:
            self.iter_count += 1
            self._update_buffer_num()

        if self.buffer_num > 0 and self.training and (
                not input.stop_gradient):  # some layers are frozen!
            # cal current batch mu and sigma
            _cur_mu = L.reduce_mean(y, dim=[
                1,
            ], keep_dim=True)  # [C, 1]
            _cur_sigma2 = L.reduce_sum(
                L.square(y - _cur_mu), dim=[
                    1,
                ], keep_dim=True) / (NHW - 1)  # [C, 1]  the author's original implementation uses the sample variance, hence the -1 in the denominator
            cur_mu = L.reshape(_cur_mu, (-1, ))  # [C, ]
            cur_sigma2 = L.reshape(_cur_sigma2, (-1, ))  # [C, ]
            y2 = L.square(y)
            cur_meanx2 = L.reduce_mean(y2, dim=[
                1,
            ], keep_dim=False)  # [C, ]
            # cal dmu/dw dsigma2/dw
            dmudw = paddle.grad(outputs=[cur_mu],
                                inputs=[weight],
                                create_graph=False,
                                retain_graph=True)[0]
            dmeanx2dw = paddle.grad(outputs=[cur_meanx2],
                                    inputs=[weight],
                                    create_graph=False,
                                    retain_graph=True)[0]

            # update cur_mu and cur_sigma2 with pres
            weight_data = weight.numpy()
            weight_data = paddle.to_tensor(weight_data)
            weight_data.stop_gradient = True
            # L.stack() raises an error here, so use L.concat() instead.
            mu_all = [
                cur_mu,
            ] + [
                tmp_mu + L.reduce_sum(self.rho * tmp_d * (weight_data - tmp_w),
                                      dim=[1, 2, 3]) for tmp_mu, tmp_d, tmp_w
                in zip(self.pre_mu, self.pre_dmudw, self.pre_weight)
            ]
            meanx2_all = [
                cur_meanx2,
            ] + [
                tmp_meanx2 + L.reduce_sum(
                    self.rho * tmp_d * (weight_data - tmp_w), dim=[1, 2, 3])
                for tmp_meanx2, tmp_d, tmp_w in zip(
                    self.pre_meanx2, self.pre_dmeanx2dw, self.pre_weight)
            ]
            mu_all = [L.unsqueeze(mu_, 0) for mu_ in mu_all]
            meanx2_all = [L.unsqueeze(meanx2_, 0) for meanx2_ in meanx2_all]
            mu_all = L.concat(mu_all, 0)
            meanx2_all = L.concat(meanx2_all, 0)

            sigma2_all = meanx2_all - L.square(mu_all)

            # with considering count
            re_mu_all = mu_all.clone()
            re_meanx2_all = meanx2_all.clone()
            mask1 = L.cast(sigma2_all >= 0., dtype="float32")
            mask1.stop_gradient = True
            re_mu_all *= mask1
            re_meanx2_all *= mask1
            count = L.reduce_sum(L.cast(sigma2_all >= 0., dtype="float32"),
                                 dim=[
                                     0,
                                 ])
            mu = L.reduce_sum(re_mu_all, dim=[
                0,
            ]) / count
            sigma2 = L.reduce_sum(re_meanx2_all, dim=[
                0,
            ]) / count - L.square(mu)

            cur_mu_ = cur_mu.numpy()
            cur_mu_ = paddle.to_tensor(cur_mu_)
            cur_mu_.stop_gradient = True
            self.pre_mu = [
                cur_mu_,
            ] + self.pre_mu[:(self.buffer_num - 1)]
            cur_meanx2_ = cur_meanx2.numpy()
            cur_meanx2_ = paddle.to_tensor(cur_meanx2_)
            cur_meanx2_.stop_gradient = True
            self.pre_meanx2 = [
                cur_meanx2_,
            ] + self.pre_meanx2[:(self.buffer_num - 1)]
            dmudw_ = dmudw.numpy()
            dmudw_ = paddle.to_tensor(dmudw_)
            dmudw_.stop_gradient = True
            self.pre_dmudw = [
                dmudw_,
            ] + self.pre_dmudw[:(self.buffer_num - 1)]
            dmeanx2dw_ = dmeanx2dw.numpy()
            dmeanx2dw_ = paddle.to_tensor(dmeanx2dw_)
            dmeanx2dw_.stop_gradient = True
            self.pre_dmeanx2dw = [
                dmeanx2dw_,
            ] + self.pre_dmeanx2dw[:(self.buffer_num - 1)]

            tmp_weight = weight.numpy()
            tmp_weight = paddle.to_tensor(tmp_weight)
            tmp_weight.stop_gradient = True
            self.pre_weight = [
                tmp_weight,
            ] + self.pre_weight[:(self.buffer_num - 1)]

        else:
            x = y  # [C, N*H*W]
            mu = L.reduce_mean(x, dim=[
                1,
            ], keep_dim=True)  # [C, 1]
            sigma2 = L.reduce_sum(L.square(x - mu), dim=[
                1,
            ], keep_dim=True) / (NHW - 1)  # [C, 1]  the author's original implementation uses the sample variance, hence the -1 in the denominator
            mu = L.reshape(mu, (-1, ))  # [C, ]
            sigma2 = L.reshape(sigma2, (-1, ))  # [C, ]
            cur_mu = mu
            cur_sigma2 = sigma2

        if not self.training or self.FROZEN:
            y = y - L.reshape(self._mean, (-1, 1))
            # TODO: outside **0.5?
            if self.out_p:
                y = y / (L.reshape(self._variance, (-1, 1)) + self.eps)**.5
            else:
                y = y / (L.reshape(self._variance, (-1, 1))**.5 + self.eps)

        else:
            if self.track_running_stats is True:
                state_dict = self.state_dict()
                momentum = self.momentum
                _mean = self._mean.numpy() * momentum + cur_mu.numpy() * (
                    1. - momentum)
                _variance = self._variance.numpy(
                ) * momentum + cur_sigma2.numpy() * (1. - momentum)
                state_dict['_mean'] = _mean.astype(np.float32)
                state_dict['_variance'] = _variance.astype(np.float32)
                self.set_state_dict(state_dict)
            y = y - L.reshape(mu, (-1, 1))  # [C, N*H*W]
            # TODO: outside **0.5?
            if self.out_p:
                y = y / (L.reshape(sigma2, (-1, 1)) + self.eps)**.5
            else:
                y = y / (L.reshape(sigma2, (-1, 1))**.5 + self.eps)

        y = L.reshape(self.weight, (-1, 1)) * y + L.reshape(self.bias, (-1, 1))
        y = L.reshape(y, return_shape)
        y = L.transpose(y, [1, 0, 2, 3])  # [N, C, H, W]
        return y