Example No. 1
    def test_group_norm_error(self, device):
        # group norm has to call native_group_norm. This checks that it hits the same errors
        # that normal group norm would

        N = 3
        C = 5
        inp = torch.randn(N, C)
        with self.assertRaisesRegex(RuntimeError, r"Expected number of channels in input to be divisible"):
            F.group_norm(inp, 2)  # 5 is not divisible by 2
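For contrast, a minimal sketch (not from the original test suite) of the passing case, where the channel count is divisible by the number of groups:

    import torch
    import torch.nn.functional as F

    inp = torch.randn(3, 6)      # 6 channels
    out = F.group_norm(inp, 2)   # 6 is divisible by 2, so no error is raised
    assert out.shape == inp.shape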
Example No. 2
 def forward(self, x):
     skip = []
     res = x
     for block in self.conv_blocks:
         x = block(x)
         res = res + x
         res = F.group_norm(res, 8)
         skip.append(res)
     skip = torch.stack(skip).sum(0)
     skip = F.group_norm(skip, 8)
     return skip
Example No. 3
 def forward(self, input):
     num_channels = int(self.num_channels * self.sr_in_list[self.sr_idx])
     weight, bias = self.weight[:num_channels], self.bias[:num_channels]
     return F.group_norm(
         input,
         round(num_channels * self.num_groups / float(self.num_channels)),
         weight, bias, self.eps)
Example No. 4
    def forward(self, input, pad_mask=None, is_encoder=False):
        # input: only reduce over the C dim.
        shaped_input = (len(input.shape) == 2)
        if shaped_input:
            input = input.unsqueeze(0)
        T, B, C = input.shape
        # Permute the mask_input to (B, T, C)
        # mask_input = input.transpose(0, 1)
        # # Compute the mean, var for LN, size to be BxTx1 -> BxCxT
        # # Split the mask_input into group
        # gn_input = mask_input.view(B, T, self.num_groups, self.group_feature)
        # gn_input = gn_input.permute(1, 2, 3, 0).contiguous().view(T, self.num_groups, self.group_feature * B)
        # # TxGx1 -> TxC -> BxTxC -> BxCxT
        # mean_gn = tile(gn_input.mean(-1, keepdim=True).squeeze(-1), self.group_feature, -1).expand_as(mask_input).transpose(1, 2)
        # var_gn = tile(gn_input.var(-1, keepdim=True).squeeze(-1), self.group_feature, -1).expand_as(mask_input).transpose(1, 2)
        #
        # # Resize the input to (B, C, -1).
        # input = input.permute(1, 2, 0).contiguous()
        # input_shape = input.size()
        # input = input.view(input.size(0), self.num_channels, -1)
        #
        # input = (input - mean_gn) / (var_gn + self.eps).sqrt()
        # input = input * (self.weight).unsqueeze(-1) + (self.bias).unsqueeze(-1)
        # input = input.view(B, C, T)
        # input = input.permute(2, 0, 1).contiguous()
        # return input

        input = input.contiguous().view(T * B, C)
        input = F.group_norm(input, self.num_groups, self.weight, self.bias,
                             self.eps)
        input = input.contiguous().view(T, B, C)
        if shaped_input:
            input = input.squeeze(0)
        return input
Example No. 5
 def _hook_compute_trace(self, mod, grad_input, grad_output):
     mod_class = mod.__class__.__name__
     gy = grad_output[0]
     x = self.xs[mod]
     if mod_class == 'Linear':
         self._trace += torch.mm(gy.t()**2, x**2).sum()
         if mod.bias is not None:
             self._trace += (gy**2).sum()
     elif mod_class == 'Conv2d':
         indiv_gw = per_example_grad_conv(mod, x, gy)
         self._trace += (indiv_gw**2).sum()
         if mod.bias is not None:
             self._trace += (gy.sum(dim=(2, 3))**2).sum()
     elif mod_class == 'BatchNorm1d':
         self._check_bn_training(mod)
         x_normalized = F.batch_norm(x, mod.running_mean, mod.running_var,
                                     None, None, mod.training)
         self._trace += (gy**2 * x_normalized**2).sum()
         self._trace += (gy**2).sum()
     elif mod_class == 'BatchNorm2d':
         self._check_bn_training(mod)
         x_normalized = F.batch_norm(x, mod.running_mean, mod.running_var,
                                     None, None, mod.training)
         self._trace += ((gy * x_normalized).sum(dim=(2, 3))**2).sum()
         self._trace += (gy.sum(dim=(2, 3))**2).sum()
     elif mod_class == 'GroupNorm':
         x_normalized = F.group_norm(x, mod.num_groups, None, None, mod.eps)
         self._trace += ((gy * x_normalized).sum(dim=(2, 3))**2).sum()
         self._trace += (gy.sum(dim=(2, 3))**2).sum()
     else:
         raise NotImplementedError
Example No. 6
    def forward(self, instrument, x):
        """
        Arguments:
            instrument {torch.tensor} -- Instrument embedding of shape (4, E_1)
            x {torch.tensor} -- Input of the groupnorm of shape (B, 4, C, T)

        Returns:
            torch.tensor -- Output of the groupnorm of shape (B, 4, C, T)
        """
        batch_size = x.shape[0]

        instrument = self.bottleneck(instrument)  # shape: (4, E_2)
        affine = self.affine(instrument)  # shape: (4, 2*C)

        scale = affine[:, :self.num_channels].contiguous().view(
            -1)  # shape: (4*C)
        bias = affine[:,
                      self.num_channels:].contiguous().view(-1)  # shape: (4*C)

        x = x.view(batch_size, 4 * self.num_channels, -1)  # shape: (B, 4*C, T)
        x = F.group_norm(x, 4 * self.num_groups, scale, bias,
                         self.eps)  # shape: (B, 4*C, T)
        x = x.view(batch_size, 4, self.num_channels, -1)  # shape: (B, 4, C, T)

        return x  # shape: (B, 4, C, T)
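The docstring above describes an instrument-conditioned group norm; the sketch below shows only the reshape-and-normalize core with made-up sizes (the bottleneck/affine layers that actually produce scale and bias are omitted):

    import torch
    import torch.nn.functional as F

    B, C, T, num_groups = 2, 8, 10, 2
    x = torch.randn(B, 4, C, T)
    scale = torch.rand(4 * C)    # would come from the instrument embedding
    bias = torch.zeros(4 * C)

    x = x.view(B, 4 * C, -1)                               # (B, 4*C, T)
    x = F.group_norm(x, 4 * num_groups, scale, bias, 1e-5)
    x = x.view(B, 4, C, -1)                                # (B, 4, C, T)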
Example No. 7
    def forward(self, x):
        x = functional.group_norm(x, self.num_groups, self.weight, self.bias,
                                  self.eps)

        if self.activation == ACT_RELU:
            return functional.relu(x, inplace=True)
        elif self.activation == ACT_RELU6:
            return functional.relu6(x, inplace=True)
        elif self.activation == ACT_LEAKY_RELU:
            return functional.leaky_relu(x,
                                         negative_slope=self.slope,
                                         inplace=True)
        elif self.activation == ACT_ELU:
            return functional.elu(x, inplace=True)
        elif self.activation == ACT_SELU:
            return functional.selu(x, inplace=True)
        elif self.activation == ACT_SWISH:
            return swish(x)
        elif self.activation == ACT_HARD_SWISH:
            return hard_swish(x, inplace=True)
        elif self.activation == ACT_HARD_SIGMOID:
            return hard_sigmoid(x, inplace=True)
        elif self.activation == ACT_NONE:
            return x
        else:
            raise KeyError(self.activation)
Example No. 8
    def forward(self, input):
        gn_input = F.group_norm(input, self.num_groups, None, None, self.eps)
        sigmoid_weight = torch.sigmoid(input *
                                       self.sigmoid_weight.view(1, -1, 1, 1))

        return gn_input * sigmoid_weight * self.weight.view(
            1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
Example No. 9
 def forward(self, x):
     # in F.batch_norm, `training` regulates whether to use batch stats or buffer stats;
     # if `training` is True and running-stat buffers are given, they will always be updated
     # (see the short check after this example)
     use_batch_stats = self.training and not self.estimated_stats
     x = F.batch_norm(
         x,
         self.running_mean,
         self.running_var,
         self.weight,
         self.bias,
         use_batch_stats,
         self.momentum,
         self.eps,
     )
     if self.training and self.estimated_stats:
         with torch.no_grad():  # not sure if needed but just in case
             # PyTorch BN uses biased var by default
             var, mean = torch.var_mean(x, dim=(0, 2, 3), unbiased=False)
             self.running_mean = self.running_mean.mul(
                 1 - self.momentum).add(mean, alpha=self.momentum)
             self.running_var = self.running_var.mul(1 - self.momentum).add(
                 var, alpha=self.momentum)
     x = F.group_norm(x, self.num_groups, self.weight_gn, self.bias_gn,
                      self.eps)
     func = ACT_FUNC_DICT[self.activation]
     if self.activation == ACT.LEAKY_RELU:
         return func(x, inplace=True, negative_slope=self.activation_param)
     elif self.activation == ACT.ELU:
         return func(x, inplace=True, alpha=self.activation_param)
     else:
         return func(x, inplace=True)
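The comment at the top of this example can be checked directly; a minimal sketch (shapes and values are illustrative):

    import torch
    import torch.nn.functional as F

    running_mean, running_var = torch.zeros(3), torch.ones(3)
    x = torch.randn(8, 3, 4, 4) * 2 + 5

    # training=True: batch statistics are used and the buffers are updated in place
    F.batch_norm(x, running_mean, running_var, None, None, True, 0.1, 1e-5)
    assert not torch.allclose(running_mean, torch.zeros(3))

    # training=False: the buffers are used as-is and left untouched
    before = running_mean.clone()
    F.batch_norm(x, running_mean, running_var, None, None, False, 0.1, 1e-5)
    assert torch.equal(running_mean, before)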
Example No. 10
    def forward(self, x):
        if self.use_custom:
            for group_idx in self.group_idxs:
                # Select the group of channels and normalize together
                group = torch.index_select(x, dim=1, index=group_idx)
                group_mean = torch.mean(group)
                group_var = torch.var(group)

                # Normalize
                for i in group_idx:
                    x[:, i] -= group_mean
                    x[:, i] /= torch.sqrt(group_var + self.eps)

            if self.affine:
                # Scale and shift
                if self.is_3d:
                    x = x * self.weight.view(-1, 1, 1, 1) + self.bias.view(
                        -1, 1, 1, 1)
                else:
                    x = x * self.weight.view(-1, 1, 1) + self.bias.view(
                        -1, 1, 1)

            return x
        else:
            return F.group_norm(x, self.num_groups, self.weight, self.bias,
                                self.eps)
Example No. 11
    def forward(self, x, y, z, w0, b0, w1, b1, w2, b2):
        x = F.group_norm(x, 2, w0, b0)
        x = F.group_norm(x, 1, None, None)
        x = F.group_norm(x, 4, self.w3, self.b3)

        y = F.group_norm(y, 3, w1, b1, eps=1e-4)
        y = F.group_norm(y, 4, None, None)
        y = F.group_norm(y, 6, self.w4, self.b4)

        z = F.group_norm(z, 32, w2, b2)
        z = F.group_norm(z, 4, None, None, eps=1e-2)
        z = F.group_norm(z, 8, self.w5, self.b5)
        return x, y, z
Example No. 12
    def forward(self, x, **kwargs):

        if hasattr(self, 'in_sequence') and self.in_sequence:
            x = x.permute(0, 2, 1)
        x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        if hasattr(self, 'in_sequence') and self.in_sequence:
            x = x.permute(0, 2, 1)
        return x
Example No. 13
 def forward(self, *x):
     x = enforce_singleton(x)
     if hasattr(self, 'in_sequence') and self.in_sequence:
         x = x.permute(0, 2, 1)
     x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
     if hasattr(self, 'in_sequence') and self.in_sequence:
         x = x.permute(0, 2, 1)
     return x
Example No. 14
    def forward(self, x, skip_x1):
        out_up_conv = self.up_conv(x)
        out_up_conv = self.activation(F.group_norm(out_up_conv, 8))

        concat = torch.cat((out_up_conv, skip_x1), 1)
        out = self.ops1(concat)
        out = self.ops2(out)
        return out
Example No. 15
    def forward(self, x):

        x = self.ops1(x)
        out_before_pool = self.ops2(x)
        out = self.pool(out_before_pool)
        out = self.activation(F.group_norm(out, 8))

        return out, out_before_pool
Example No. 16
    def forward(self, input):

        # training mode:
        if self.training:
            input_norm = F.group_norm(input, self.num_groups, None, None,
                                      self.eps)
            if self.weight is None or self.bias is None:
                return input_norm
            else:
                return affine(
                    input_norm, self.weight, self.bias, self.clip, self.std
                )  # only the affine transform needs private gradients

        # inference mode:
        else:
            return F.group_norm(input, self.num_groups, self.weight, self.bias,
                                self.eps)
Example No. 17
 def forward(self, x):
     x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
     func = ACT_FUNC_DICT[self.activation]
     if self.activation == ACT.LEAKY_RELU:
         return func(x, inplace=True, negative_slope=self.activation_param)
     elif self.activation == ACT.ELU:
         return func(x, inplace=True, alpha=self.activation_param)
     else:
         return func(x, inplace=True)
Example No. 18
 def forward(self, input):
     output = F.group_norm(
         input.float(),
         self.num_groups,
         self.weight.float() if self.weight is not None else None,
         self.bias.float() if self.bias is not None else None,
         self.eps,
     )
     return output.type_as(input)
Example No. 19
 def forward(self, input):
     input = input.permute(1, 0, 2, 3)
     input = F.group_norm(input, self.num_groups, None, None, self.eps)
     input = input.permute(1, 0, 2, 3)
     if self.affine:
         weight = self.weight.unsqueeze(-1).unsqueeze(-1)
         bias = self.bias.unsqueeze(-1).unsqueeze(-1)
         return input * weight + bias
     else:
         return input
Example No. 20
 def _hook_compute_diag(self, mod, grad_input, grad_output):
     mod_class = mod.__class__.__name__
     gy = grad_output[0]
     x = self.xs[mod]
     layer_id = self.m_to_l[mod]
     start_p = self.layer_collection.p_pos[layer_id]
     if mod_class == 'Linear':
         self.diag_m[start_p:start_p+mod.weight.numel()] \
             .add_(torch.mm(gy.t()**2, x**2).view(-1))
         if self.layer_collection[layer_id].bias is not None:
             start_p += mod.weight.numel()
             self.diag_m[start_p: start_p+mod.bias.numel()] \
                 .add_((gy**2).sum(dim=0))
     elif mod_class == 'Conv2d':
         indiv_gw = per_example_grad_conv(mod, x, gy)
         self.diag_m[start_p:start_p+mod.weight.numel()] \
             .add_((indiv_gw**2).sum(dim=0).view(-1))
         if self.layer_collection[layer_id].bias is not None:
             start_p += mod.weight.numel()
             self.diag_m[start_p:start_p+mod.bias.numel()] \
                 .add_((gy.sum(dim=(2, 3))**2).sum(dim=0))
     elif mod_class == 'BatchNorm1d':
         self._check_bn_training(mod)
         x_normalized = F.batch_norm(x, mod.running_mean, mod.running_var,
                                     None, None, mod.training)
         self.diag_m[start_p:start_p+mod.weight.numel()] \
             .add_((gy**2 * x_normalized**2).sum(dim=0).view(-1))
         start_p += mod.weight.numel()
         self.diag_m[start_p: start_p+mod.bias.numel()] \
             .add_((gy**2).sum(dim=0))
     elif mod_class == 'BatchNorm2d':
         self._check_bn_training(mod)
         x_normalized = F.batch_norm(x, mod.running_mean, mod.running_var,
                                     None, None, mod.training)
         self.diag_m[start_p:start_p+mod.weight.numel()] \
             .add_(((gy * x_normalized).sum(dim=(2, 3))**2).sum(dim=0)
                   .view(-1))
         start_p += mod.weight.numel()
         self.diag_m[start_p: start_p+mod.bias.numel()] \
             .add_((gy.sum(dim=(2, 3))**2).sum(dim=0))
     elif mod_class == 'GroupNorm':
         x_normalized = F.group_norm(x,
                                     mod.num_groups,
                                     None,
                                     None,
                                     eps=mod.eps)
         self.diag_m[start_p:start_p+mod.weight.numel()] \
             .add_(((gy * x_normalized).sum(dim=(2, 3))**2).sum(dim=0)
                   .view(-1))
         start_p += mod.weight.numel()
         self.diag_m[start_p: start_p+mod.bias.numel()] \
             .add_((gy.sum(dim=(2, 3))**2).sum(dim=0))
     else:
         raise NotImplementedError
Example No. 21
def groupnorm(x, norm_style):
    # If number of channels specified in norm_style:
    if 'ch' in norm_style:
        ch = int(norm_style.split('_')[-1])
        groups = max(int(x.shape[1]) // ch, 1)
    # If number of groups specified in norm style
    elif 'grp' in norm_style:
        groups = int(norm_style.split('_')[-1])
    # If neither, default to groups = 16
    else:
        groups = 16
    return F.group_norm(x, groups)
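The norm_style string parsed above encodes either a channels-per-group count ('ch_<n>') or a group count ('grp_<n>'), with a default of 16 groups otherwise. A quick illustration, assuming the function above is in scope and the tensor sizes are made up:

    import torch

    x = torch.randn(2, 64, 8, 8)
    y = groupnorm(x, 'ch_16')   # 64 // 16 = 4 groups
    y = groupnorm(x, 'grp_8')   # 8 groups
    y = groupnorm(x, 'none')    # neither marker present: defaults to 16 groups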
Example No. 22
    def forward(self, x):
        x = functional.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)

        if self.activation == "relu":
            return functional.relu(x, inplace=True)
        elif self.activation == "leaky_relu":
            return functional.leaky_relu(x, negative_slope=self.activation_param, inplace=True)
        elif self.activation == "elu":
            return functional.elu(x, alpha=self.activation_param, inplace=True)
        elif self.activation == "identity":
            return x
        else:
            raise RuntimeError("Unknown activation function {}".format(self.activation))
Example No. 23
    def forward(self, x):
        output = F.group_norm(x, self.num_groups, self.weight, self.bias,
                              self.eps)
        size = output.size()

        y = self.attention_weights(x)  # TODO: use output as attention input

        weight = y @ self.weight_
        bias = y @ self.bias_

        weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
        bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)

        return weight * output + bias
Example No. 24
    def forward(self, input):
        if self.training:
            # update running estimates
            input_size = list(input.size())
            input = torch.stack(torch.chunk(input, self.num_models, dim=0),
                                dim=0)
            input = input.transpose(1, 2).reshape(self.num_models,
                                                  self.num_features, -1)
            batch_mean = torch.mean(input, dim=-1)
            batch_var = torch.var(input, dim=-1)
            input = input.view([
                self.num_models, self.num_features,
                input_size[0] // self.num_models, *input_size[2:]
            ]).transpose(1, 2).reshape(*input_size)
            self.running_mean -= self.momentum * (self.running_mean -
                                                  batch_mean)
            self.running_var -= self.momentum * (self.running_var - batch_var)

            # Forward pass.
            input = input.permute(1, 0, 2, 3)
            input = F.group_norm(input, self.num_models, None, None, self.eps)
            input = input.permute(1, 0, 2, 3)
            if self.affine:
                num_examples_per_model = input.size(0) // self.num_models
                weight = torch.cat(
                    [self.weight for i in range(num_examples_per_model)],
                    dim=1).view([-1, self.num_features])
                weight = weight.unsqueeze(-1).unsqueeze(-1)
                bias = torch.cat(
                    [self.bias for i in range(num_examples_per_model)],
                    dim=1).view([-1, self.num_features])
                bias = bias.unsqueeze(-1).unsqueeze(-1)
                return input * weight + bias
            else:
                return input
        else:
            inputs = torch.chunk(input, self.num_models, dim=0)
            res = torch.cat([
                F.batch_norm(inputs[i], self.running_mean[i],
                             self.running_var[i], self.weight[i],
                             self.bias[i], False, 0., self.eps)
                for i in range(self.num_models)
            ],
                            dim=0)
            return res
Example No. 25
    def forward(self, x, c):
        # Normalize
        x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)

        # Condition it
        gamma = 1.0 + self.fc_gamma(c)  # 1 centered
        beta = self.fc_beta(c)
        gamma = gamma.view(gamma.size(0), gamma.size(1), 1, 1)
        beta = beta.view(beta.size(0), beta.size(1), 1, 1)

        if self.affine:
            # learned affine (unnecessary since it can be learned but seems to help)
            weight = self.weight.view(1, -1, 1, 1).repeat(x.size(0), 1, 1, 1)
            bias = self.bias.view(1, -1, 1, 1).repeat(x.size(0), 1, 1, 1)
            gamma = gamma + weight
            beta = beta + bias

        return (gamma * x) + beta
Example No. 26
    def forward(self, inputs):
        """
        inputs: a sequence of tensors, each of shape (b, c, h, w); b and c are shared across tensors
        """
        b, c = inputs[0].shape[:2]
        shapes = [x.shape[2:] for x in inputs]
        inputs_reshaped = []
        for x, s in zip(inputs, shapes):
            input_reshaped = x.contiguous().view(b, c, s[0] * s[1], 1)
            inputs_reshaped.append(input_reshaped)
        inputs_reshaped = torch.cat(inputs_reshaped, dim=2)

        outs = F.group_norm(inputs_reshaped, self.num_groups, self.weight,
                            self.bias, self.eps)

        outs = torch.split(outs, [s[0] * s[1] for s in shapes], dim=2)
        outs = [out.view(b, c, s[0], s[1]) for s, out in zip(shapes, outs)]

        return outs
Example No. 27
def compute_group_norm_grad_sample(
    layer: nn.GroupNorm,
    A: torch.Tensor,
    B: torch.Tensor,
    batch_dim: int = 0,
) -> None:
    """
    Computes per sample gradients for GroupNorm

    Args:
        layer: Layer
        A: Activations
        B: Backpropagations
        batch_dim: Batch dimension position
    """
    gs = F.group_norm(A, layer.num_groups, eps=layer.eps) * B
    create_or_extend_grad_sample(layer.weight, torch.einsum("ni...->ni", gs),
                                 batch_dim)
    if layer.bias is not None:
        create_or_extend_grad_sample(layer.bias, torch.einsum("ni...->ni", B),
                                     batch_dim)
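A minimal sketch (not part of the library code above) that sanity-checks the per-sample weight-gradient formula against autograd, assuming a small GroupNorm layer with its default affine parameters:

    import torch
    import torch.nn.functional as F
    from torch import nn

    torch.manual_seed(0)
    N, C, H, W, G = 4, 6, 5, 5, 3
    gn = nn.GroupNorm(G, C)
    x = torch.randn(N, C, H, W)
    grad_out = torch.randn(N, C, H, W)

    # Formula used above: per-sample weight grad = normalized activations times
    # backpropagated gradients, summed over the spatial dimensions
    x_hat = F.group_norm(x, G, eps=gn.eps)
    per_sample = torch.einsum("ni...->ni", x_hat * grad_out)   # (N, C)

    # Reference: autograd, one sample at a time
    for n in range(N):
        ref, = torch.autograd.grad(gn(x[n:n + 1]), gn.weight, grad_out[n:n + 1])
        assert torch.allclose(per_sample[n], ref, atol=1e-5)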
Example No. 28
    def backward(ctx, grad_output):
        input, num_groups = ctx.input, ctx.num_groups
        weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
        mean, rstd = ctx.mean, ctx.rstd

        results: List[Optional[torch.Tensor]] = []
        results.append(None)  # for kwarg names
        results.append(None)  # for op reference

        if input.requires_grad:
            weight_c = unpack_expanded_weight_or_tensor(
                weight, lambda t: t.contiguous())
            input_c = input.contiguous()
            grad_output_c = grad_output.contiguous(
            ) if grad_output is not None else None
            N = input.shape[0]
            C = input.shape[1]
            HxW = 1
            for s in input.shape[2:]:
                HxW *= s
            bw_fn = torch.ops.aten.native_group_norm_backward
            results.append(
                bw_fn(grad_output_c, input_c, mean, rstd, weight_c, N, C, HxW,
                      num_groups, (True, False, False))[0])
        else:
            results.append(None)

        # weight and bias don't compute batched gradients; no other arguments are differentiable
        results = results + [None] * 4

        # set grad_sample field for weight and bias with per sample gradients
        if hasattr(ctx, "weight"):
            set_grad_sample_if_exists(
                weight, lambda _: torch.einsum(
                    "ni...->ni",
                    F.group_norm(input, num_groups, eps=eps) * grad_output))
        if hasattr(ctx, "bias"):
            set_grad_sample_if_exists(
                bias, lambda _: torch.einsum("ni...->ni", grad_output))
        return tuple(results)
Example No. 29
 def forward(self, input):
     if self.affine:
         weight = noise_fn(
             self.mu_weight,
             self.sigma_weight,
             self.eps_weight,
             self.sigma_0,
             self.N,
             self.alpha,
         )
         bias = noise_fn(
             self.mu_bias,
             self.sigma_bias,
             self.eps_bias,
             self.sigma_0,
             self.N,
             self.alpha,
         )
     else:
         weight = None
         bias = None
     return F.group_norm(input, self.num_groups, weight, bias, self.eps)
Example No. 30
 def _hook_compute_layer_blocks(self, mod, grad_input, grad_output):
     mod_class = mod.__class__.__name__
     gy = grad_output[0]
     x = self.xs[mod]
     bs = x.size(0)
     layer_id = self.m_to_l[mod]
     block = self._blocks[layer_id]
     if mod_class == 'Linear':
         gw = torch.bmm(gy.unsqueeze(2), x.unsqueeze(1)).view(bs, -1)
         if self.layer_collection[layer_id].bias is not None:
             gw = torch.cat([gw.view(bs, -1), gy.view(bs, -1)], dim=1)
         block.add_(torch.mm(gw.t(), gw))
     elif mod_class == 'Conv2d':
         gw = per_example_grad_conv(mod, x, gy).view(bs, -1)
         if self.layer_collection[layer_id].bias is not None:
             gw = torch.cat([gw, gy.sum(dim=(2, 3)).view(bs, -1)], dim=1)
         block.add_(torch.mm(gw.t(), gw))
     elif mod_class == 'BatchNorm1d':
         self._check_bn_training(mod)
         x_normalized = F.batch_norm(x, mod.running_mean, mod.running_var,
                                     None, None, mod.training)
         gw = gy * x_normalized
         gw = torch.cat([gw, gy], dim=1)
         block.add_(torch.mm(gw.t(), gw))
     elif mod_class == 'BatchNorm2d':
         self._check_bn_training(mod)
         x_normalized = F.batch_norm(x, mod.running_mean, mod.running_var,
                                     None, None, mod.training)
         gw = (gy * x_normalized).sum(dim=(2, 3))
         gw = torch.cat([gw, gy.sum(dim=(2, 3))], dim=1)
         block.add_(torch.mm(gw.t(), gw))
     elif mod_class == 'GroupNorm':
         x_normalized = F.group_norm(x, mod.num_groups, None, None, mod.eps)
         gw = (gy * x_normalized).sum(dim=(2, 3))
         gw = torch.cat([gw, gy.sum(dim=(2, 3))], dim=1)
         block.add_(torch.mm(gw.t(), gw))
     else:
         raise NotImplementedError