def test_group_norm_error(self, device):
    # group norm has to call native_group_norm. This checks that it hits the same errors
    # that normal group norm would
    N = 3
    C = 5
    inp = torch.randn(N, C)
    with self.assertRaisesRegex(RuntimeError, r"Expected number of channels in input to be divisible"):
        F.group_norm(inp, 2)  # 5 is not divisible by 2
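# --- Illustrative sketch (not from the test above): the divisibility rule that
# F.group_norm enforces, shown on a minimal standalone tensor.
import torch
import torch.nn.functional as F

inp = torch.randn(3, 6)      # N=3, C=6
out = F.group_norm(inp, 2)   # ok: 6 channels split into 2 groups of 3
# F.group_norm(inp, 4)       # would raise RuntimeError: 6 is not divisible by 4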
def forward(self, x):
    skip = []
    res = x
    for block in self.conv_blocks:
        x = block(x)
        res = res + x
        res = F.group_norm(res, 8)
        skip.append(res)
    skip = torch.stack(skip).sum(0)
    skip = F.group_norm(skip, 8)
    return skip
def forward(self, input):
    num_channels = int(self.num_channels * self.sr_in_list[self.sr_idx])
    weight, bias = self.weight[:num_channels], self.bias[:num_channels]
    return F.group_norm(
        input, round(num_channels * self.num_groups / float(self.num_channels)),
        weight, bias, self.eps)
def forward(self, input, pad_mask=None, is_encoder=False):
    # input: only reduce over the C dim.
    shaped_input = (len(input.shape) == 2)
    if shaped_input:
        input = input.unsqueeze(0)
    T, B, C = input.shape

    # Permute the mask_input to (B, T, C)
    # mask_input = input.transpose(0, 1)
    #
    # Compute the mean, var for LN, size to be BxTx1 -> BxCxT
    # # Split the mask_input into groups
    # gn_input = mask_input.view(B, T, self.num_groups, self.group_feature)
    # gn_input = gn_input.permute(1, 2, 3, 0).contiguous().view(T, self.num_groups, self.group_feature * B)
    # # TxGx1 -> TxC -> BxTxC -> BxCxT
    # mean_gn = tile(gn_input.mean(-1, keepdim=True).squeeze(-1), self.group_feature, -1).expand_as(mask_input).transpose(1, 2)
    # var_gn = tile(gn_input.var(-1, keepdim=True).squeeze(-1), self.group_feature, -1).expand_as(mask_input).transpose(1, 2)
    #
    # # Resize the input to (B, C, -1).
    # input = input.permute(1, 2, 0).contiguous()
    # input_shape = input.size()
    # input = input.view(input.size(0), self.num_channels, -1)
    #
    # input = (input - mean_gn) / (var_gn + self.eps).sqrt()
    # input = input * self.weight.unsqueeze(-1) + self.bias.unsqueeze(-1)
    # input = input.view(B, C, T)
    # input = input.permute(2, 0, 1).contiguous()
    # return input

    input = input.contiguous().view(T * B, C)
    input = F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)
    input = input.contiguous().view(T, B, C)
    if shaped_input:
        input = input.squeeze(0)
    return input
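# --- Hedged sketch of the active path above: flattening (T, B, C) to (T*B, C) makes
# F.group_norm treat every (t, b) position as its own sample, normalizing only over
# channel groups. Shapes here are illustrative, not taken from the snippet.
import torch
import torch.nn.functional as F

T, B, C, num_groups = 5, 2, 8, 4
x = torch.randn(T, B, C)
out = F.group_norm(x.contiguous().view(T * B, C), num_groups).view(T, B, C)
# group norm of a single position alone gives the same result:
ref = F.group_norm(x[0, 0].view(1, C), num_groups).view(C)
assert torch.allclose(out[0, 0], ref, atol=1e-6)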
def _hook_compute_trace(self, mod, grad_input, grad_output):
    mod_class = mod.__class__.__name__
    gy = grad_output[0]
    x = self.xs[mod]
    if mod_class == 'Linear':
        self._trace += torch.mm(gy.t()**2, x**2).sum()
        if mod.bias is not None:
            self._trace += (gy**2).sum()
    elif mod_class == 'Conv2d':
        indiv_gw = per_example_grad_conv(mod, x, gy)
        self._trace += (indiv_gw**2).sum()
        if mod.bias is not None:
            self._trace += (gy.sum(dim=(2, 3))**2).sum()
    elif mod_class == 'BatchNorm1d':
        self._check_bn_training(mod)
        x_normalized = F.batch_norm(x, mod.running_mean, mod.running_var,
                                    None, None, mod.training)
        self._trace += (gy**2 * x_normalized**2).sum()
        self._trace += (gy**2).sum()
    elif mod_class == 'BatchNorm2d':
        self._check_bn_training(mod)
        x_normalized = F.batch_norm(x, mod.running_mean, mod.running_var,
                                    None, None, mod.training)
        self._trace += ((gy * x_normalized).sum(dim=(2, 3))**2).sum()
        self._trace += (gy.sum(dim=(2, 3))**2).sum()
    elif mod_class == 'GroupNorm':
        x_normalized = F.group_norm(x, mod.num_groups, None, None, mod.eps)
        self._trace += ((gy * x_normalized).sum(dim=(2, 3))**2).sum()
        self._trace += (gy.sum(dim=(2, 3))**2).sum()
    else:
        raise NotImplementedError
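# --- Hedged check of the GroupNorm identity the hook above relies on: for one example,
# the gradient of the loss w.r.t. the affine weight is (gy * x_normalized).sum over (H, W).
import torch
import torch.nn.functional as F

mod = torch.nn.GroupNorm(2, 4)
x = torch.randn(1, 4, 5, 5)
out = mod(x)
gy = torch.randn_like(out)
out.backward(gy)
x_normalized = F.group_norm(x, mod.num_groups, None, None, mod.eps)
manual = (gy * x_normalized).sum(dim=(2, 3)).squeeze(0)
assert torch.allclose(mod.weight.grad, manual, atol=1e-5)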
def forward(self, instrument, x):
    """
    Arguments:
        instrument {torch.tensor} -- Instrument embedding of shape (4, E_1)
        x {torch.tensor} -- Input of the groupnorm of shape (B, 4, C, T)

    Returns:
        torch.tensor -- Output of the groupnorm of shape (B, 4, C, T)
    """
    batch_size = x.shape[0]
    instrument = self.bottleneck(instrument)  # shape: (4, E_2)
    affine = self.affine(instrument)  # shape: (4, 2*C)
    scale = affine[:, :self.num_channels].contiguous().view(-1)  # shape: (4*C)
    bias = affine[:, self.num_channels:].contiguous().view(-1)  # shape: (4*C)
    x = x.view(batch_size, 4 * self.num_channels, -1)  # shape: (B, 4*C, T)
    x = F.group_norm(x, 4 * self.num_groups, scale, bias, self.eps)  # shape: (B, 4*C, T)
    x = x.view(batch_size, 4, self.num_channels, -1)  # shape: (B, 4, C, T)
    return x  # shape: (B, 4, C, T)
def forward(self, x):
    x = functional.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
    if self.activation == ACT_RELU:
        return functional.relu(x, inplace=True)
    elif self.activation == ACT_RELU6:
        return functional.relu6(x, inplace=True)
    elif self.activation == ACT_LEAKY_RELU:
        return functional.leaky_relu(x, negative_slope=self.slope, inplace=True)
    elif self.activation == ACT_ELU:
        return functional.elu(x, inplace=True)
    elif self.activation == ACT_SELU:
        return functional.selu(x, inplace=True)
    elif self.activation == ACT_SWISH:
        return swish(x)
    elif self.activation == ACT_HARD_SWISH:
        return hard_swish(x, inplace=True)
    elif self.activation == ACT_HARD_SIGMOID:
        return hard_sigmoid(x, inplace=True)
    elif self.activation == ACT_NONE:
        return x
    else:
        raise KeyError(self.activation)
def forward(self, input):
    gn_input = F.group_norm(input, self.num_groups, None, None, self.eps)
    sigmoid_weight = torch.sigmoid(input * self.sigmoid_weight.view(1, -1, 1, 1))
    return (gn_input * sigmoid_weight * self.weight.view(1, -1, 1, 1)
            + self.bias.view(1, -1, 1, 1))
def forward(self, x):
    # in F.batch_norm `training` regulates whether to use batch stats or buffer stats;
    # if `training` is True and buffers are given, they will always be updated!
    use_batch_stats = self.training and not self.estimated_stats
    x = F.batch_norm(
        x,
        self.running_mean,
        self.running_var,
        self.weight,
        self.bias,
        use_batch_stats,
        self.momentum,
        self.eps,
    )
    if self.training and self.estimated_stats:
        with torch.no_grad():  # not sure if needed but just in case
            # PyTorch BN uses biased var by default
            var, mean = torch.var_mean(x, dim=(0, 2, 3), unbiased=False)
            self.running_mean = self.running_mean.mul(1 - self.momentum).add(mean, alpha=self.momentum)
            self.running_var = self.running_var.mul(1 - self.momentum).add(var, alpha=self.momentum)
    x = F.group_norm(x, self.num_groups, self.weight_gn, self.bias_gn, self.eps)
    func = ACT_FUNC_DICT[self.activation]
    if self.activation == ACT.LEAKY_RELU:
        return func(x, inplace=True, negative_slope=self.activation_param)
    elif self.activation == ACT.ELU:
        return func(x, inplace=True, alpha=self.activation_param)
    else:
        return func(x, inplace=True)
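# --- Minimal sketch of the F.batch_norm behavior the comment above warns about: with
# training=True and buffers supplied, the running statistics are updated in place.
import torch
import torch.nn.functional as F

running_mean, running_var = torch.zeros(3), torch.ones(3)
x = torch.randn(4, 3, 2, 2) + 5.0
F.batch_norm(x, running_mean, running_var, None, None, True, 0.1, 1e-5)
print(running_mean)  # no longer zeros: the buffer was updated by the call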
def forward(self, x):
    if self.use_custom:
        for group_idx in self.group_idxs:
            # Select the group of channels and normalize together
            group = torch.index_select(x, dim=1, index=group_idx)
            group_mean = torch.mean(group)
            group_var = torch.var(group)
            # Normalize
            for i in group_idx:
                x[:, i] -= group_mean
                x[:, i] /= torch.sqrt(group_var + self.eps)
        if self.affine:
            # Scale and shift
            if self.is_3d:
                x = x * self.weight.view(-1, 1, 1, 1) + self.bias.view(-1, 1, 1, 1)
            else:
                x = x * self.weight.view(-1, 1, 1) + self.bias.view(-1, 1, 1)
        return x
    else:
        return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
def forward(self, x, y, z, w0, b0, w1, b1, w2, b2):
    x = F.group_norm(x, 2, w0, b0)
    x = F.group_norm(x, 1, None, None)
    x = F.group_norm(x, 4, self.w3, self.b3)
    y = F.group_norm(y, 3, w1, b1, eps=1e-4)
    y = F.group_norm(y, 4, None, None)
    y = F.group_norm(y, 6, self.w4, self.b4)
    z = F.group_norm(z, 32, w2, b2)
    z = F.group_norm(z, 4, None, None, eps=1e-2)
    z = F.group_norm(z, 8, self.w5, self.b5)
    return x, y, z
def forward(self, x, **kwargs):
    if hasattr(self, 'in_sequence') and self.in_sequence:
        x = x.permute(0, 2, 1)
    x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
    if hasattr(self, 'in_sequence') and self.in_sequence:
        x = x.permute(0, 2, 1)
    return x
def forward(self, *x):
    x = enforce_singleton(x)
    if hasattr(self, 'in_sequence') and self.in_sequence:
        x = x.permute(0, 2, 1)
    x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
    if hasattr(self, 'in_sequence') and self.in_sequence:
        x = x.permute(0, 2, 1)
    return x
def forward(self, x, skip_x1):
    out_up_conv = self.up_conv(x)
    out_up_conv = self.activation(F.group_norm(out_up_conv, 8))
    concat = torch.cat((out_up_conv, skip_x1), 1)
    out = self.ops1(concat)
    out = self.ops2(out)
    return out
def forward(self, x):
    x = self.ops1(x)
    out_before_pool = self.ops2(x)
    out = self.pool(out_before_pool)
    out = self.activation(F.group_norm(out, 8))
    return out, out_before_pool
def forward(self, input):
    # training mode:
    if self.training:  # note: `self.train` is nn.Module's method and is always truthy
        input_norm = F.group_norm(input, self.num_groups, None, None, self.eps)
        if self.weight is None or self.bias is None:
            return input_norm
        else:
            # only the affine transform needs private gradients
            return affine(input_norm, self.weight, self.bias, self.clip, self.std)
    # inference mode:
    else:
        return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)
def forward(self, x):
    x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
    func = ACT_FUNC_DICT[self.activation]
    if self.activation == ACT.LEAKY_RELU:
        return func(x, inplace=True, negative_slope=self.activation_param)
    elif self.activation == ACT.ELU:
        return func(x, inplace=True, alpha=self.activation_param)
    else:
        return func(x, inplace=True)
def forward(self, input):
    output = F.group_norm(
        input.float(),
        self.num_groups,
        self.weight.float() if self.weight is not None else None,
        self.bias.float() if self.bias is not None else None,
        self.eps,
    )
    return output.type_as(input)
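# --- Sketch of the mixed-precision pattern above: run the normalization in float32 for
# numerical stability, then cast back to the input dtype. Shapes are illustrative.
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 4, dtype=torch.float16)
out = F.group_norm(x.float(), 4).type_as(x)
assert out.dtype == torch.float16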
def forward(self, input):
    input = input.permute(1, 0, 2, 3)
    input = F.group_norm(input, self.num_groups, None, None, self.eps)
    input = input.permute(1, 0, 2, 3)
    if self.affine:
        weight = self.weight.unsqueeze(-1).unsqueeze(-1)
        bias = self.bias.unsqueeze(-1).unsqueeze(-1)
        return input * weight + bias
    else:
        return input
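# --- Hedged sketch of the permute trick above for the special case num_groups=1:
# group norm over a (C, N, H, W) view reproduces batch normalization with on-the-fly
# (biased) batch statistics and no affine parameters.
import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 5, 5)
gn_way = F.group_norm(x.permute(1, 0, 2, 3), 1).permute(1, 0, 2, 3)
bn_way = F.batch_norm(x, None, None, None, None, True)  # training=True, no buffers
assert torch.allclose(gn_way, bn_way, atol=1e-5)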
def _hook_compute_diag(self, mod, grad_input, grad_output):
    mod_class = mod.__class__.__name__
    gy = grad_output[0]
    x = self.xs[mod]
    layer_id = self.m_to_l[mod]
    start_p = self.layer_collection.p_pos[layer_id]
    if mod_class == 'Linear':
        self.diag_m[start_p:start_p + mod.weight.numel()] \
            .add_(torch.mm(gy.t()**2, x**2).view(-1))
        if self.layer_collection[layer_id].bias is not None:
            start_p += mod.weight.numel()
            self.diag_m[start_p:start_p + mod.bias.numel()] \
                .add_((gy**2).sum(dim=0))
    elif mod_class == 'Conv2d':
        indiv_gw = per_example_grad_conv(mod, x, gy)
        self.diag_m[start_p:start_p + mod.weight.numel()] \
            .add_((indiv_gw**2).sum(dim=0).view(-1))
        if self.layer_collection[layer_id].bias is not None:
            start_p += mod.weight.numel()
            self.diag_m[start_p:start_p + mod.bias.numel()] \
                .add_((gy.sum(dim=(2, 3))**2).sum(dim=0))
    elif mod_class == 'BatchNorm1d':
        self._check_bn_training(mod)
        x_normalized = F.batch_norm(x, mod.running_mean, mod.running_var,
                                    None, None, mod.training)
        self.diag_m[start_p:start_p + mod.weight.numel()] \
            .add_((gy**2 * x_normalized**2).sum(dim=0).view(-1))
        start_p += mod.weight.numel()
        self.diag_m[start_p:start_p + mod.bias.numel()] \
            .add_((gy**2).sum(dim=0))
    elif mod_class == 'BatchNorm2d':
        self._check_bn_training(mod)
        x_normalized = F.batch_norm(x, mod.running_mean, mod.running_var,
                                    None, None, mod.training)
        self.diag_m[start_p:start_p + mod.weight.numel()] \
            .add_(((gy * x_normalized).sum(dim=(2, 3))**2).sum(dim=0).view(-1))
        start_p += mod.weight.numel()
        self.diag_m[start_p:start_p + mod.bias.numel()] \
            .add_((gy.sum(dim=(2, 3))**2).sum(dim=0))
    elif mod_class == 'GroupNorm':
        x_normalized = F.group_norm(x, mod.num_groups, None, None, eps=mod.eps)
        self.diag_m[start_p:start_p + mod.weight.numel()] \
            .add_(((gy * x_normalized).sum(dim=(2, 3))**2).sum(dim=0).view(-1))
        start_p += mod.weight.numel()
        self.diag_m[start_p:start_p + mod.bias.numel()] \
            .add_((gy.sum(dim=(2, 3))**2).sum(dim=0))
    else:
        raise NotImplementedError
def groupnorm(x, norm_style):
    # If number of channels specified in norm_style:
    if 'ch' in norm_style:
        ch = int(norm_style.split('_')[-1])
        groups = max(int(x.shape[1]) // ch, 1)
    # If number of groups specified in norm_style:
    elif 'grp' in norm_style:
        groups = int(norm_style.split('_')[-1])
    # If neither, default to groups = 16:
    else:
        groups = 16
    return F.group_norm(x, groups)
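# --- Hypothetical usage of groupnorm() above; the norm_style strings are assumed
# examples, not taken from the snippet.
import torch

x = torch.randn(2, 64, 8, 8)
y1 = groupnorm(x, 'ch_16')   # 64 channels / 16 per group -> 4 groups
y2 = groupnorm(x, 'grp_8')   # 8 groups directly
y3 = groupnorm(x, 'gn')      # neither 'ch' nor 'grp' -> default 16 groups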
def forward(self, x): x = functional.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) if self.activation == "relu": return functional.relu(x, inplace=True) elif self.activation == "leaky_relu": return functional.leaky_relu(x, negative_slope=self.activation_param, inplace=True) elif self.activation == "elu": return functional.elu(x, alpha=self.activation_param, inplace=True) elif self.activation == "identity": return x else: raise RuntimeError("Unknown activation function {}".format(self.activation))
def forward(self, x):
    output = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
    size = output.size()
    y = self.attention_weights(x)  # TODO: use output as attention input
    weight = y @ self.weight_
    bias = y @ self.bias_
    weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
    bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
    return weight * output + bias
def forward(self, input):
    if self.training:
        # update running estimates
        input_size = list(input.size())
        input = torch.stack(torch.chunk(input, self.num_models, dim=0), dim=0)
        input = input.transpose(1, 2).reshape(self.num_models, self.num_features, -1)
        batch_mean = torch.mean(input, dim=-1)
        batch_var = torch.var(input, dim=-1)
        input = input.view([
            self.num_models, self.num_features,
            input_size[0] // self.num_models, *input_size[2:]
        ]).transpose(1, 2).reshape(*input_size)
        self.running_mean -= self.momentum * (self.running_mean - batch_mean)
        self.running_var -= self.momentum * (self.running_var - batch_var)

        # Forward pass.
        input = input.permute(1, 0, 2, 3)
        input = F.group_norm(input, self.num_models, None, None, self.eps)
        input = input.permute(1, 0, 2, 3)
        if self.affine:
            num_examples_per_model = input.size(0) // self.num_models
            weight = torch.cat([self.weight for i in range(num_examples_per_model)],
                               dim=1).view([-1, self.num_features])
            weight = weight.unsqueeze(-1).unsqueeze(-1)
            bias = torch.cat([self.bias for i in range(num_examples_per_model)],
                             dim=1).view([-1, self.num_features])
            bias = bias.unsqueeze(-1).unsqueeze(-1)
            return input * weight + bias
        else:
            return input
    else:
        inputs = torch.chunk(input, self.num_models, dim=0)
        res = torch.cat([
            F.batch_norm(inputs[i], self.running_mean[i], self.running_var[i],
                         self.weights[i], self.bias[i], False, 0., self.eps)
            for i in range(self.num_models)
        ], dim=0)
        return res
def forward(self, x, c):
    # Normalize
    x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
    # Condition it
    gamma = 1.0 + self.fc_gamma(c)  # centered at 1
    beta = self.fc_beta(c)
    gamma = gamma.view(gamma.size(0), gamma.size(1), 1, 1)
    beta = beta.view(beta.size(0), beta.size(1), 1, 1)
    if self.affine:
        # learned affine (redundant in principle, since fc_gamma/fc_beta could absorb
        # it, but it seems to help)
        weight = self.weight.view(1, -1, 1, 1).repeat(x.size(0), 1, 1, 1)
        bias = self.bias.view(1, -1, 1, 1).repeat(x.size(0), 1, 1, 1)
        gamma = gamma + weight
        beta = beta + bias
    return (gamma * x) + beta
def forward(self, inputs):
    """
    inputs: b*c*h*w for each tensor
    """
    b, c = inputs[0].shape[:2]
    shapes = [x.shape[2:] for x in inputs]
    inputs_reshaped = []
    for x, s in zip(inputs, shapes):
        input_reshaped = x.contiguous().view(b, c, s[0] * s[1], 1)
        inputs_reshaped.append(input_reshaped)
    inputs_reshaped = torch.cat(inputs_reshaped, dim=2)
    outs = F.group_norm(inputs_reshaped, self.num_groups, self.weight, self.bias, self.eps)
    outs = torch.split(outs, [s[0] * s[1] for s in shapes], dim=2)
    outs = [out.view(b, c, s[0], s[1]) for s, out in zip(shapes, outs)]
    return outs
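# --- Hedged sketch of the shared-statistics trick above: concatenating the flattened
# feature maps along one spatial axis makes F.group_norm compute a single mean/var per
# group across all resolutions rather than per feature map. Shapes are illustrative.
import torch
import torch.nn.functional as F

b, c, groups = 2, 8, 4
feats = [torch.randn(b, c, s, s) for s in (8, 4)]
flat = torch.cat([f.view(b, c, -1, 1) for f in feats], dim=2)
shared = F.group_norm(flat, groups)
shared_first = shared[:, :, :64, :].reshape(feats[0].shape)
independent = F.group_norm(feats[0], groups)
# the results differ because the shared statistics also include the smaller map:
assert not torch.allclose(shared_first, independent)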
def compute_group_norm_grad_sample(
    layer: nn.GroupNorm,
    A: torch.Tensor,
    B: torch.Tensor,
    batch_dim: int = 0,
) -> None:
    """
    Computes per sample gradients for GroupNorm

    Args:
        layer: Layer
        A: Activations
        B: Backpropagations
        batch_dim: Batch dimension position
    """
    gs = F.group_norm(A, layer.num_groups, eps=layer.eps) * B
    create_or_extend_grad_sample(layer.weight, torch.einsum("ni...->ni", gs), batch_dim)
    if layer.bias is not None:
        create_or_extend_grad_sample(layer.bias, torch.einsum("ni...->ni", B), batch_dim)
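# --- Hedged check of the per-sample gradient formula above, against a full backward
# pass (the full gradient must equal the sum of per-sample gradients). Assumes a 4D
# input and an affine GroupNorm layer.
import torch
import torch.nn.functional as F
from torch import nn

layer = nn.GroupNorm(2, 4)
A = torch.randn(3, 4, 5, 5)
B = torch.randn(3, 4, 5, 5)                     # stand-in for the backpropagated gradient
gs = F.group_norm(A, layer.num_groups, eps=layer.eps) * B
per_sample_gw = torch.einsum("ni...->ni", gs)   # shape (3, 4): one row per sample
layer(A).backward(B)
assert torch.allclose(layer.weight.grad, per_sample_gw.sum(0), atol=1e-5)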
def backward(ctx, grad_output):
    input, num_groups = ctx.input, ctx.num_groups
    weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
    mean, rstd = ctx.mean, ctx.rstd

    results: List[Optional[torch.Tensor]] = []
    results.append(None)  # for kwarg names
    results.append(None)  # for op reference

    if input.requires_grad:
        weight_c = unpack_expanded_weight_or_tensor(weight, lambda t: t.contiguous())
        input_c = input.contiguous()
        grad_output_c = grad_output.contiguous() if grad_output is not None else None
        N = input.shape[0]
        C = input.shape[1]
        HxW = 1
        for s in input.shape[2:]:
            HxW *= s
        bw_fn = torch.ops.aten.native_group_norm_backward
        results.append(bw_fn(grad_output_c, input_c, mean, rstd, weight_c,
                             N, C, HxW, num_groups, (True, False, False))[0])
    else:
        results.append(None)

    # weight and bias don't compute batched gradients; no other arguments are differentiable
    results = results + [None] * 4

    # set grad_sample field for weight and bias with per sample gradients
    if hasattr(ctx, "weight"):
        set_grad_sample_if_exists(
            weight,
            lambda _: torch.einsum("ni...->ni",
                                   F.group_norm(input, num_groups, eps=eps) * grad_output))
    if hasattr(ctx, "bias"):
        set_grad_sample_if_exists(bias, lambda _: torch.einsum("ni...->ni", grad_output))
    return tuple(results)
def forward(self, input):
    if self.affine:
        weight = noise_fn(
            self.mu_weight,
            self.sigma_weight,
            self.eps_weight,
            self.sigma_0,
            self.N,
            self.alpha,
        )
        bias = noise_fn(
            self.mu_bias,
            self.sigma_bias,
            self.eps_bias,
            self.sigma_0,
            self.N,
            self.alpha,
        )
    else:
        weight = None
        bias = None
    return F.group_norm(input, self.num_groups, weight, bias, self.eps)
def _hook_compute_layer_blocks(self, mod, grad_input, grad_output):
    mod_class = mod.__class__.__name__
    gy = grad_output[0]
    x = self.xs[mod]
    bs = x.size(0)
    layer_id = self.m_to_l[mod]
    block = self._blocks[layer_id]
    if mod_class == 'Linear':
        gw = torch.bmm(gy.unsqueeze(2), x.unsqueeze(1)).view(bs, -1)
        if self.layer_collection[layer_id].bias is not None:
            gw = torch.cat([gw.view(bs, -1), gy.view(bs, -1)], dim=1)
        block.add_(torch.mm(gw.t(), gw))
    elif mod_class == 'Conv2d':
        gw = per_example_grad_conv(mod, x, gy).view(bs, -1)
        if self.layer_collection[layer_id].bias is not None:
            gw = torch.cat([gw, gy.sum(dim=(2, 3)).view(bs, -1)], dim=1)
        block.add_(torch.mm(gw.t(), gw))
    elif mod_class == 'BatchNorm1d':
        self._check_bn_training(mod)
        x_normalized = F.batch_norm(x, mod.running_mean, mod.running_var,
                                    None, None, mod.training)
        gw = gy * x_normalized
        gw = torch.cat([gw, gy], dim=1)
        block.add_(torch.mm(gw.t(), gw))
    elif mod_class == 'BatchNorm2d':
        self._check_bn_training(mod)
        x_normalized = F.batch_norm(x, mod.running_mean, mod.running_var,
                                    None, None, mod.training)
        gw = (gy * x_normalized).sum(dim=(2, 3))
        gw = torch.cat([gw, gy.sum(dim=(2, 3))], dim=1)
        block.add_(torch.mm(gw.t(), gw))
    elif mod_class == 'GroupNorm':
        x_normalized = F.group_norm(x, mod.num_groups, None, None, mod.eps)
        gw = (gy * x_normalized).sum(dim=(2, 3))
        gw = torch.cat([gw, gy.sum(dim=(2, 3))], dim=1)
        block.add_(torch.mm(gw.t(), gw))
    else:
        raise NotImplementedError