def __init__(self, in_channel, out_channel):
    """Create the convolution stack and the ``downsample`` module.

    Args:
        in_channel: Channel count of the input feature map.
        out_channel: Channel count produced by ``self.conv``.
    """
    super(ResidualBlock, self).__init__()
    self.in_channel = in_channel
    self.out_channel = out_channel
    self.relu = nn.ReLU()

    if in_channel != out_channel:
        # Channel count changes: ``downsample`` is a 1x1 conv mapping
        # in_channel -> out_channel, and the main stack is 3x3 then 1x1.
        self.downsample = nn.Conv2d(in_channel, out_channel, 1, bias=False)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, in_channel, 3, padding=1, bias=False),
            nn.SyncBatchNorm(in_channel),
            nn.Conv2d(in_channel, out_channel, 1, bias=False),
            nn.SyncBatchNorm(out_channel),
        )
    else:
        # Same channel count: ``downsample`` is an empty Sequential and the
        # main stack is a 1x1 -> 3x3 -> 1x1 stack with a quarter-width middle.
        self.downsample = nn.Sequential()
        squeezed = in_channel // 4
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, squeezed, 1, bias=False),
            nn.SyncBatchNorm(squeezed),
            nn.Conv2d(squeezed, squeezed, 3, padding=1, bias=False),
            nn.SyncBatchNorm(squeezed),
            nn.Conv2d(squeezed, out_channel, 1, bias=False),
            nn.SyncBatchNorm(out_channel),
        )

    # ``_specify_ddp_gpu_num`` is a private SyncBatchNorm hook that is not
    # present in every PyTorch release; call it only where it exists so
    # construction does not fail with AttributeError elsewhere.
    for m in self.modules():
        if isinstance(m, nn.SyncBatchNorm):
            specify_gpu_num = getattr(m, '_specify_ddp_gpu_num', None)
            if specify_gpu_num is not None:
                specify_gpu_num(1)
def _make_self_attention(self, in_channels, latent_channels):
    """Build the four 1x1-conv transforms used by a self-attention block.

    Args:
        in_channels: Channel count of the incoming feature map.
        latent_channels: Channel count of the key/query/down projections.

    Returns:
        Tuple ``(key_transform, query_transform, down_transform,
        up_transform)``. Key and query are Conv-SyncBatchNorm-ReLU stacks;
        down and up are bare 1x1 convolutions (in -> latent, latent -> in).
    """
    pointwise = dict(kernel_size=1, stride=1, padding=0)

    def projection():
        # 1x1 conv followed by normalisation and an in-place ReLU.
        return nn.Sequential(
            nn.Conv2d(in_channels, latent_channels, **pointwise),
            nn.SyncBatchNorm(latent_channels),
            nn.ReLU(inplace=True),
        )

    key_transform = projection()
    query_transform = projection()
    down_transform = nn.Conv2d(in_channels, latent_channels, **pointwise)
    up_transform = nn.Conv2d(latent_channels, in_channels, **pointwise)
    return key_transform, query_transform, down_transform, up_transform
def __init__(self, opts):
    """Assemble the descriptor network and the alpha/segmentation head.

    Args:
        opts: Options object; ``num_classes``, ``embedding_size``,
            ``num_planes`` and ``mpi_encoder_features`` are read here.
            NOTE: ``opts.num_classes`` is overwritten in place with
            ``opts.embedding_size`` when the two differ.
    """
    super(ConvNetwork, self).__init__()

    # Force num_classes to follow embedding_size (mutates the caller's opts).
    if opts.num_classes != opts.embedding_size:
        opts.__dict__['num_classes'] = opts.embedding_size

    self.opts = opts
    num_planes = opts.num_planes
    enc_features = opts.mpi_encoder_features

    self.input_channels = opts.embedding_size
    self.num_classes = opts.num_classes
    self.num_planes = num_planes
    self.out_seg_chans = self.num_classes

    self.discriptor_net = BaseEncoderDecoder(self.input_channels)
    self.base_res_layers = nn.Sequential(
        ResBlock(enc_features, 3),
        ResBlock(enc_features, 3),
    )

    # One alpha channel plus out_seg_chans segmentation channels per plane.
    head_channels = int(num_planes * (self.out_seg_chans + 1))
    half_head_channels = head_channels // 2
    self.blending_alpha_seg_pred = nn.Sequential(
        ResBlock(enc_features, 3),
        ResBlock(enc_features, 3),
        nn.SyncBatchNorm(enc_features),
        ConvBlock(enc_features, half_head_channels, 3, down_sample=False),
        nn.SyncBatchNorm(half_head_channels),
        ConvBlock(half_head_channels, head_channels, 3,
                  down_sample=False, use_no_relu=True),
    )
def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=0,
             dilation=1, bias=False, separable=True):
    """Create either a depthwise-separable conv stack or a plain conv stack.

    Args:
        inplanes: Input channel count.
        planes: Output channel count.
        kernel_size, stride, padding, dilation: Forwarded to the spatial
            convolution (the depthwise one when ``separable`` is true).
        bias: Whether the convolutions carry a bias term.
        separable: If true, build depthwise + pointwise convolutions, each
            with its own SyncBatchNorm and ReLU; otherwise build one regular
            convolution with SyncBatchNorm and ReLU.
    """
    super(SeparableConv2d, self).__init__()
    self.separable = separable
    bn_options = dict(eps=1e-05, momentum=0.0003)

    if not self.separable:
        self.conv = nn.Conv2d(inplanes, planes, kernel_size, stride, padding,
                              dilation, bias=bias)
        self.bn = nn.SyncBatchNorm(planes, **bn_options)
        self.relu = nn.ReLU(inplace=True)
    else:
        # groups=inplanes makes the spatial convolution per-channel.
        self.depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, stride,
                                   padding, dilation, groups=inplanes,
                                   bias=bias)
        self.depthwise_bn = nn.SyncBatchNorm(inplanes, **bn_options)
        self.depthwise_relu = nn.ReLU(inplace=True)
        self.pointwise = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1,
                                   padding=0, dilation=1, groups=1, bias=bias)
        self.pointwise_bn = nn.SyncBatchNorm(planes, **bn_options)
        self.pointwise_relu = nn.ReLU(inplace=True)

    self._init_weight()
def __init__(self, opts):
    """Assemble the descriptor network and the seg/alpha/beta prediction head.

    Args:
        opts: Options object; ``num_classes``, ``num_planes``,
            ``mpi_encoder_features``, ``embedding_size`` and ``num_layers``
            are read here.
    """
    super(MulLayerConvNetwork, self).__init__()
    self.opts = opts

    num_planes = opts.num_planes
    enc_features = opts.mpi_encoder_features

    self.input_channels = opts.num_classes
    self.num_classes = opts.num_classes
    self.num_planes = num_planes
    self.out_seg_chans = self.opts.embedding_size

    self.discriptor_net = BaseEncoderDecoder(self.input_channels)
    self.base_res_layers = nn.Sequential(
        ResBlock(enc_features, 3),
        ResBlock(enc_features, 3),
    )

    # The input semantics are re-used, so only (num_layers - 1) layers of
    # segmentation channels are predicted.
    self.total_seg_channels = (self.opts.num_layers - 1) * self.out_seg_chans
    self.total_alpha_channels = num_planes
    self.total_beta_channels = num_planes * self.opts.num_layers

    out_channels = (self.total_seg_channels
                    + self.total_alpha_channels
                    + self.total_beta_channels)
    half_out_channels = out_channels // 2

    self.blending_alpha_seg_beta_pred = nn.Sequential(
        ResBlock(enc_features, 3),
        ResBlock(enc_features, 3),
        nn.SyncBatchNorm(enc_features),
        ConvBlock(enc_features, half_out_channels, 3, down_sample=False),
        nn.SyncBatchNorm(half_out_channels),
        ConvBlock(half_out_channels, out_channels, 3,
                  down_sample=False, use_no_relu=True),
    )
def __init__(self, cfg):
    """Build the encoder, projection head, predictor and input normaliser.

    Args:
        cfg: Configuration object. Reads ``MODEL.DEVICE``, ``MODEL.BYOL.*``,
            ``SOLVER.*``, ``MODEL.RESNETS.NORM``, ``MODEL.PIXEL_MEAN`` and
            ``MODEL.PIXEL_STD``. NOTE: ``cfg.MODEL.RESNETS.NUM_CLASSES`` is
            overwritten with the configured output dimension before the
            backbone is built.
    """
    super(SimSiam, self).__init__()

    self.device = torch.device(cfg.MODEL.DEVICE)

    self.proj_dim = cfg.MODEL.BYOL.PROJ_DIM
    self.pred_dim = cfg.MODEL.BYOL.PRED_DIM
    self.out_dim = cfg.MODEL.BYOL.OUT_DIM

    self.total_steps = (cfg.SOLVER.LR_SCHEDULER.MAX_ITER
                        * cfg.SOLVER.BATCH_SUBDIVISIONS)

    # The backbone's final fc width is driven by NUM_CLASSES, so set it
    # before calling build_backbone.
    cfg.MODEL.RESNETS.NUM_CLASSES = self.out_dim
    backbone_input = ShapeSpec(channels=len(cfg.MODEL.PIXEL_MEAN))
    self.encoder = cfg.build_backbone(cfg, input_shape=backbone_input)

    # Replace the backbone stem with a single stride-1 3x3 convolution.
    stem_conv = Conv2d(
        3, 64, kernel_size=3, stride=1, padding=1, bias=False,
        norm=get_norm(cfg.MODEL.RESNETS.NORM, 64),
    )
    self.encoder.stem = nn.Sequential(stem_conv, nn.ReLU())

    self.size_divisibility = self.encoder.size_divisibility
    # Input width of the backbone's original linear layer.
    dim_mlp = self.encoder.linear.weight.shape[1]

    # Projection head (replaces the backbone's linear layer).
    self.encoder.linear = nn.Sequential(
        nn.Linear(dim_mlp, self.proj_dim),
        nn.SyncBatchNorm(self.proj_dim),
        nn.ReLU(),
        nn.Linear(self.proj_dim, self.proj_dim),
        nn.SyncBatchNorm(self.proj_dim),
    )

    # Predictor
    self.predictor = nn.Sequential(
        nn.Linear(self.proj_dim, self.pred_dim),
        nn.SyncBatchNorm(self.pred_dim),
        nn.ReLU(),
        nn.Linear(self.pred_dim, self.out_dim),
    )

    mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(1, 3, 1, 1)
    std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(1, 3, 1, 1)
    # Scale inputs by 1/255, then standardise with the configured mean/std.
    self.normalizer = lambda x: (x / 255.0 - mean) / std

    self.to(self.device)
def test_batchnorm(self):
    """Check how ``add_weight_decay`` groups the parameters of norm layers."""
    norm_layers = [
        nn.BatchNorm2d(1),
        nn.SyncBatchNorm(1),
        nn.SyncBatchNorm(1),
    ]
    model = nn.Sequential(*norm_layers)

    no_decay_group, decay_group = add_weight_decay(model, 9.0)

    # Three norm layers x (weight, bias) = 6 parameters, all expected in the
    # first returned group; the second group must stay empty.
    assert len(no_decay_group["params"]) == 6
    assert len(decay_group["params"]) == 0
def __init__(self, in_channels, out_channels, mid_channels=None,
             running_on_gpu=True):
    """Build two consecutive 3x3 Conv -> Norm -> ReLU stages.

    Args:
        in_channels: Input channel count of the first convolution.
        out_channels: Output channel count of the second convolution.
        mid_channels: Channel count between the two convolutions; falls back
            to ``out_channels`` when falsy.
        running_on_gpu: Selects ``nn.SyncBatchNorm`` when true and
            ``nn.BatchNorm2d`` otherwise.
    """
    super().__init__()
    mid_channels = mid_channels or out_channels
    norm_cls = nn.SyncBatchNorm if running_on_gpu else nn.BatchNorm2d

    stages = []
    for c_in, c_out in ((in_channels, mid_channels),
                        (mid_channels, out_channels)):
        stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
        stages.append(norm_cls(c_out))
        stages.append(nn.ReLU(inplace=True))

    self.double_conv = nn.Sequential(*stages)
def __init__(self):
    """Build three SeparableConv2d stages (3->4->8->16) and a final 1x1 conv."""
    super(DMNet, self).__init__()

    def stage(c_in, c_out):
        # 3x3 separable conv that keeps spatial size, then norm and ReLU.
        return nn.Sequential(
            SeparableConv2d(c_in, c_out, kernel_size=3, stride=1, padding=1,
                            bias=True),
            nn.SyncBatchNorm(c_out),
            nn.ReLU(inplace=True),
        )

    self.conv1 = stage(3, 4)
    self.conv2 = stage(4, 8)
    self.conv3 = stage(8, 16)
    self.conv4 = nn.Conv2d(16, 16, kernel_size=1, stride=1, padding=0,
                           bias=True)
    self.weight_init()
def __init__(self, n_classes=21):
    """Build the Xception backbone, ASPP branches, decoder and classifier.

    Args:
        n_classes: Number of output channels of the final 1x1 ``logits``
            convolution.
    """
    super(deeplabv3plus, self).__init__()
    backbone_channels = 2048
    aspp_channels = 256
    bn_momentum = 0.0003

    # Atrous Conv
    self.xception_features = Xception()

    # ASPP: the rate-1 branch uses a regular (non-separable) convolution.
    self.aspp0 = ASPP_module(backbone_channels, aspp_channels, rate=1,
                             separable=False)
    self.aspp1 = ASPP_module(backbone_channels, aspp_channels, rate=6)
    self.aspp2 = ASPP_module(backbone_channels, aspp_channels, rate=12)
    self.aspp3 = ASPP_module(backbone_channels, aspp_channels, rate=18)

    self.image_pooling = nn.Sequential(
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Conv2d(backbone_channels, aspp_channels, 1, stride=1, bias=False),
        nn.SyncBatchNorm(aspp_channels, eps=1e-05, momentum=bn_momentum),
        nn.ReLU(inplace=True),
    )

    # 1280 input channels = 5 x 256.
    self.concat_projection = nn.Sequential(
        nn.Conv2d(1280, aspp_channels, 1, stride=1, bias=False),
        nn.SyncBatchNorm(aspp_channels, eps=1e-05, momentum=bn_momentum),
        nn.ReLU(inplace=True),
        nn.Dropout2d(p=0.1),
    )

    # adopt [1x1, 48] for channel reduction.
    self.feature_projection0_conv = nn.Conv2d(256, 48, 1, bias=False)
    self.feature_projection0_bn = nn.SyncBatchNorm(48, eps=1e-03,
                                                   momentum=bn_momentum)
    self.feature_projection0_relu = nn.ReLU(inplace=True)

    # 304 input channels = 256 + 48.
    self.decoder = nn.Sequential(
        SeparableConv2d(304, 256, kernel_size=3, stride=1, padding=1,
                        bias=False),
        SeparableConv2d(256, 256, kernel_size=3, stride=1, padding=1,
                        bias=False),
    )
    self.logits = nn.Conv2d(256, n_classes, kernel_size=1, stride=1,
                            padding=0, bias=True)
def __init__(self, C_in, C_out, affine=True, bn=False, **kwargs):
    """Build a convolution (optionally normalised) combined with a ReLU.

    Args:
        C_in: Input and output channel count of the convolution.
        C_out: Channel count handed to the normalisation layer.
        affine: Accepted but not used in this constructor.
        bn: When true, append a norm layer chosen by ``cfg`` (GroupNorm,
            SyncBatchNorm or BatchNorm2d) and drop the conv bias.
        **kwargs: Forwarded to ``nn.Conv2d``.
    """
    super(Normal_Relu_Conv, self).__init__()

    if bn:
        if cfg['GN']:
            norm_layer = nn.GroupNorm(32, C_out)
        elif cfg["syncBN"]:
            norm_layer = nn.SyncBatchNorm(C_out)
        else:
            norm_layer = nn.BatchNorm2d(C_out)
        # NOTE(review): the conv emits C_in channels while the norm layer is
        # sized for C_out; this looks like it assumes C_in == C_out - confirm.
        body = nn.Sequential(
            nn.Conv2d(C_in, C_in, bias=False, **kwargs),
            norm_layer,
        )
    else:
        body = nn.Sequential(
            nn.Conv2d(C_in, C_in, bias=True, **kwargs),
        )

    if RELU_FIRST:
        # ReLU becomes module '0'; the conv/norm layers follow as '1', '2', ...
        self.op = nn.Sequential(nn.ReLU())
        for position, layer in enumerate(body, start=1):
            self.op.add_module(str(position), layer)
    else:
        # ReLU is appended after the existing layers.
        body.add_module(str(len(body)), nn.ReLU())
        self.op = body
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True,
             bn=False):
    """Build a depthwise + pointwise convolution pair combined with a ReLU.

    Args:
        C_in: Input channel count (also the depthwise group count).
        C_out: Output channel count of the pointwise convolution.
        kernel_size, stride, padding: Settings of the depthwise convolution.
        affine: Accepted but not used in this constructor.
        bn: When true, append a norm layer chosen by ``cfg`` (GroupNorm,
            SyncBatchNorm or BatchNorm2d) and drop the conv biases.
    """
    super(SepConv, self).__init__()

    if bn:
        if cfg['GN']:
            norm_layer = nn.GroupNorm(32, C_out)
        elif cfg["syncBN"]:
            norm_layer = nn.SyncBatchNorm(C_out)
        else:
            norm_layer = nn.BatchNorm2d(C_out)
    else:
        norm_layer = None

    # A bias is only kept when no normalisation layer follows.
    use_bias = not bn
    layers = [
        nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
                  padding=padding, groups=C_in, bias=use_bias),
        nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=use_bias),
    ]
    if norm_layer is not None:
        layers.append(norm_layer)
    body = nn.Sequential(*layers)

    if RELU_FIRST:
        # ReLU becomes module '0'; the remaining layers follow as '1', '2', ...
        self.op = nn.Sequential()
        self.op.add_module('0', nn.ReLU())
        for position, layer in enumerate(body, start=1):
            self.op.add_module(str(position), layer)
    else:
        # ReLU is appended after the existing layers.
        body.add_module(str(len(body)), nn.ReLU())
        self.op = body
def __init__(
    self,
    in_features,
    hidden_features=None,
    out_features=None,
    act_layer=nn.GELU,
    drop=0.0,
    kernel_size=3,
):
    """Create two grouped convolutions, an activation and a norm layer.

    Args:
        in_features: Input channel count of both convolutions and the width
            of the SyncBatchNorm layer.
        hidden_features: Accepted but not used in this constructor.
        out_features: Output channel count and group count of both
            convolutions; defaults to ``in_features`` when falsy.
        act_layer: Zero-argument factory for the activation module.
        drop: Accepted but not used in this constructor.
        kernel_size: Kernel size; padding is ``kernel_size // 2``.
    """
    super().__init__()
    out_features = out_features or in_features

    conv_options = dict(
        kernel_size=kernel_size,
        padding=kernel_size // 2,
        groups=out_features,
    )

    self.conv1 = torch.nn.Conv2d(in_features, out_features, **conv_options)
    self.act = act_layer()
    self.bn = nn.SyncBatchNorm(in_features)
    self.conv2 = torch.nn.Conv2d(in_features, out_features, **conv_options)
def __init__(self, inplanes, depth_list, skip_connection_type, stride,
             unit_rate_list=None, dilation=1,
             activation_fn_in_separable_conv=True, low_level_features=False):
    """Build three chained separable convolutions and an optional shortcut.

    Args:
        inplanes: Input channel count.
        depth_list: Three output channel counts, one per separable conv.
        skip_connection_type: When equal to ``'conv'`` a 1x1 conv +
            SyncBatchNorm is stored as ``self.shortcut``; the value itself
            is stored for later use.
        stride: Stride of the third separable conv and of the shortcut conv.
        unit_rate_list: Three per-conv dilation multipliers; defaults to
            ``[1, 1, 1]`` when falsy.
        dilation: Base dilation multiplied by each unit rate.
        activation_fn_in_separable_conv: Forwarded to each separable conv.
        low_level_features: Stored on the instance.

    Raises:
        ValueError: If ``depth_list`` or a non-empty ``unit_rate_list`` does
            not have exactly three elements.
    """
    super(xception_module, self).__init__()

    if len(depth_list) != 3:
        raise ValueError('Expect three elements in depth_list.')
    if not unit_rate_list:
        unit_rate_list = [1, 1, 1]
    elif len(unit_rate_list) != 3:
        raise ValueError('Expect three elements in unit_rate_list.')

    shared = dict(
        kernel_size=3,
        activation_fn_in_separable_conv=activation_fn_in_separable_conv,
    )

    # Each conv consumes the previous conv's output depth; only the last
    # one applies the requested stride.
    self.separable_conv1 = SeparableConv2d_same(
        inplanes, depth_list[0], stride=1,
        dilation=dilation * unit_rate_list[0], **shared)
    self.separable_conv2 = SeparableConv2d_same(
        depth_list[0], depth_list[1], stride=1,
        dilation=dilation * unit_rate_list[1], **shared)
    self.separable_conv3 = SeparableConv2d_same(
        depth_list[1], depth_list[2], stride=stride,
        dilation=dilation * unit_rate_list[2], **shared)

    if skip_connection_type == 'conv':
        projection = nn.Conv2d(inplanes, depth_list[-1], kernel_size=1,
                               stride=stride, bias=False)
        projection_bn = nn.SyncBatchNorm(depth_list[-1], eps=1e-03,
                                         momentum=0.0003)
        self.shortcut = nn.Sequential(projection, projection_bn)

    self.skip_connection_type = skip_connection_type
    self.low_level_features = low_level_features
    self._init_weight()
def add_norm_layer(layer, opt):
    """Wrap ``layer`` with the normalisation described by ``norm_type``.

    ``norm_type`` comes from the enclosing scope and is only read here. An
    optional ``'spectral'`` prefix applies spectral normalisation to the
    layer; the remainder selects the norm layer appended after it.

    Args:
        layer: The module to normalise.
        opt: Options object; ``opt.mpdist`` must be truthy for
            ``'sync_batch'`` to select ``nn.SyncBatchNorm``.

    Returns:
        ``layer`` itself when no norm layer is requested, otherwise
        ``nn.Sequential(layer, norm_layer)``.
    """
    prefix = 'spectral'
    if norm_type.startswith(prefix):
        layer = spectral_norm(layer)
        subnorm_type = norm_type[len(prefix):]
    else:
        subnorm_type = norm_type

    if subnorm_type in ('none', ''):
        return layer

    # remove bias in the previous layer, which is meaningless
    # since it has no effect after normalization
    if getattr(layer, 'bias', None) is not None:
        delattr(layer, 'bias')
        layer.register_parameter('bias', None)

    # Anything that is neither 'instance' nor an enabled 'sync_batch' falls
    # back to plain BatchNorm2d.
    if subnorm_type == 'instance':
        norm_cls, use_affine = nn.InstanceNorm2d, False
    elif subnorm_type == 'sync_batch' and opt.mpdist:
        norm_cls, use_affine = nn.SyncBatchNorm, True
    else:
        norm_cls, use_affine = nn.BatchNorm2d, True

    norm_layer = norm_cls(get_out_channel(layer), affine=use_affine)
    return nn.Sequential(layer, norm_layer)
def norm(channel, eps=1e-5, args=None, keyword=None, feature_stride=None, affine=False):
    """Return a normalisation layer for ``channel`` features chosen by keyword.

    Args:
        channel: Number of channels/features of the layer.
        eps: Accepted but not used by this function.
        args: Optional options object; supplies ``keyword`` when none is
            given, ``fm_quant_group`` for GroupNorm (default 32), and is
            forwarded to ``StaticBatchNorm2d``.
        keyword: Container (string or collection) tested with ``in`` for the
            tags below; the first matching tag wins.
        feature_stride: Accepted but not used by this function.
        affine: Passed to ``nn.InstanceNorm2d`` only.

    Returns:
        A new normalisation module; ``nn.BatchNorm2d(channel)`` when no
        keyword is available or no tag matches.
    """
    if keyword is None and args is not None:
        keyword = getattr(args, "keyword", None)
    if keyword is None:
        return nn.BatchNorm2d(channel)

    # Ordered (tag, factory) pairs; factories are lazy so that only the
    # selected layer type is ever looked up and constructed.
    factories = (
        ("group-norm", lambda: nn.GroupNorm(getattr(args, "fm_quant_group", 32), channel)),
        ("static-bn", lambda: StaticBatchNorm2d(channel, args=args)),
        ("freeze-bn", lambda: FrozenBatchNorm2d(channel)),
        ("reverse-bn", lambda: ReverseBatchNorm2d(channel)),
        ("sync-bn", lambda: nn.SyncBatchNorm(channel)),
        ("instance-norm", lambda: nn.InstanceNorm2d(channel, affine=affine)),
    )
    for tag, build in factories:
        if tag in keyword:
            return build()

    return nn.BatchNorm2d(channel)
def add_norm_layer(layer):
    """Wrap ``layer`` with the normalisation described by ``norm_type``.

    ``norm_type`` comes from the enclosing scope and is only read here, so no
    ``nonlocal`` declaration is needed. An optional ``'spectral'`` prefix
    applies spectral normalisation to the layer; the remainder selects the
    norm layer appended after it.

    Args:
        layer: The module to normalise.

    Returns:
        ``layer`` itself when the norm type is ``'none'`` or empty,
        otherwise ``nn.Sequential(layer, norm_layer)``.

    Raises:
        ValueError: If the norm type is not ``'batch'``, ``'sync_batch'`` or
            ``'instance'``.
    """
    if norm_type.startswith('spectral'):
        layer = spectral_norm(layer)
        subnorm_type = norm_type[len('spectral'):]
    else:
        # Without this branch a norm_type lacking the 'spectral' prefix left
        # subnorm_type unbound and the next line raised UnboundLocalError.
        subnorm_type = norm_type

    if subnorm_type == 'none' or len(subnorm_type) == 0:
        return layer

    # remove bias in the previous layer, which is meaningless
    # since it has no effect after normalization
    if getattr(layer, 'bias', None) is not None:
        delattr(layer, 'bias')
        layer.register_parameter('bias', None)

    if subnorm_type == 'batch':
        norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
    elif subnorm_type == 'sync_batch':
        # synch batch norm is dropped in favor of pytorch's synch_batch_norm utility
        norm_layer = nn.SyncBatchNorm(get_out_channel(layer), affine=True)
    elif subnorm_type == 'instance':
        norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
    else:
        raise ValueError('normalization layer %s is not recognized' % subnorm_type)

    return nn.Sequential(layer, norm_layer)
def __init__(self):
    """Register a bias-free linear layer and one of each batch-norm flavour."""
    super().__init__()
    width = 10
    self.lin = nn.Linear(width, width, bias=False)
    self.bn1 = nn.BatchNorm1d(width)
    self.bn2 = nn.BatchNorm2d(width)
    self.bn3 = nn.BatchNorm3d(width)
    self.sync_bn = nn.SyncBatchNorm(width)
def _make_fuse_layers(self):
    """Build the branch-to-branch fusion modules.

    Reads ``self.num_branches`` and ``self.num_inchannels``.

    Returns:
        ``nn.ModuleList`` of ``nn.ModuleList``; entry ``[i][j]`` maps branch
        ``j`` features to branch ``i``'s channel count:
        - ``j > i``: 1x1 conv + SyncBatchNorm;
        - ``j == i``: ``None``;
        - ``j < i``: ``i - j`` stride-2 3x3 conv + SyncBatchNorm stages, with
          a ReLU after every stage except the last.
    """
    branch_count = self.num_branches
    channels = self.num_inchannels

    def conv_bn(c_in, c_out, kernel, stride, pad, with_relu=False):
        # Bias-free conv followed by SyncBatchNorm and an optional ReLU.
        parts = [
            nn.Conv2d(c_in, c_out, kernel, stride, pad, bias=False),
            nn.SyncBatchNorm(c_out),
        ]
        if with_relu:
            parts.append(nn.ReLU(inplace=True))
        return nn.Sequential(*parts)

    fuse_layers = []
    for target in range(branch_count):
        row = []
        for source in range(branch_count):
            if source == target:
                row.append(None)
            elif source > target:
                row.append(conv_bn(channels[source], channels[target], 1, 1, 0))
            else:
                steps = target - source
                chain = []
                for step in range(steps):
                    is_last = step == steps - 1
                    # Only the final stage switches to the target width.
                    c_out = channels[target] if is_last else channels[source]
                    chain.append(conv_bn(channels[source], c_out, 3, 2, 1,
                                         with_relu=not is_last))
                row.append(nn.Sequential(*chain))
        fuse_layers.append(nn.ModuleList(row))

    return nn.ModuleList(fuse_layers)
def __init__(self, num_features, process_group=None):
    """Create the wrapped batch-norm layer.

    Args:
        num_features: Number of features of the norm layer.
        process_group: When given, ``nn.SyncBatchNorm`` is used with this
            process group; when ``None``, plain ``nn.BatchNorm2d`` is used.
    """
    super(GBN, self).__init__()
    if process_group is not None:
        self.bn = nn.SyncBatchNorm(num_features, process_group=process_group)
    else:
        self.bn = nn.BatchNorm2d(num_features)
def __init__(self, in_channels, out_channels, sync=False, **kwargs):
    """Create a bias-free convolution and its batch-norm layer.

    Args:
        in_channels: Input channel count of the convolution.
        out_channels: Output channel count (and norm layer width).
        sync: When true, use ``nn.SyncBatchNorm`` (and print a notice);
            otherwise use ``nn.BatchNorm2d``.
        **kwargs: Forwarded to ``nn.Conv2d``.
    """
    super(BasicConv2d, self).__init__()
    self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)

    if sync:
        # for sync bn
        print('use sync inception')
    norm_cls = nn.SyncBatchNorm if sync else nn.BatchNorm2d
    self.bn = norm_cls(out_channels, eps=0.001)
def __init__(self, channels, relu=True, affine=True, norm_type='bn'):
    """Select the normalisation module stored as ``self.norm``.

    Args:
        channels: Channel count handed to the norm layer.
        relu: Stored as ``self.relu``.
        affine: Accepted but not used in this constructor.
        norm_type: ``'bn'`` selects ``nn.SyncBatchNorm``, ``'ln'`` selects
            ``ChannelLayerNorm``; any other value selects ``nn.Identity``.
    """
    super().__init__()
    self.relu = relu
    if norm_type == 'bn':
        self.norm = nn.SyncBatchNorm(channels)
    elif norm_type == 'ln':
        self.norm = ChannelLayerNorm(channels)
    else:
        # The fallback used to be stored only as ``self.bn``, leaving
        # ``self.norm`` undefined for this branch. Store it under the same
        # attribute as the other branches; ``self.bn`` stays as an alias so
        # existing accesses keep working.
        self.norm = nn.Identity()
        self.bn = self.norm
def Norm_layer(norm_cfg, inplanes):
    """Create the normalisation layer named by ``norm_cfg``.

    Args:
        norm_cfg: One of ``'BN'`` (BatchNorm3d), ``'SyncBN'``
            (SyncBatchNorm), ``'GN'`` (GroupNorm with 16 groups) or ``'IN'``
            (affine InstanceNorm3d).
        inplanes: Number of channels of the layer.

    Returns:
        The newly constructed normalisation module.

    Raises:
        ValueError: If ``norm_cfg`` is not one of the supported names
            (previously this surfaced as an UnboundLocalError).
    """
    if norm_cfg == 'BN':
        out = nn.BatchNorm3d(inplanes)
    elif norm_cfg == 'SyncBN':
        out = nn.SyncBatchNorm(inplanes)
    elif norm_cfg == 'GN':
        out = nn.GroupNorm(16, inplanes)
    elif norm_cfg == 'IN':
        out = nn.InstanceNorm3d(inplanes, affine=True)
    else:
        raise ValueError(
            "Unsupported norm_cfg %r; expected 'BN', 'SyncBN', 'GN' or 'IN'."
            % (norm_cfg,)
        )
    return out
def get_norm_layer(norm_layer, num_channels, num_groups=None):
    """Instantiate the requested normalisation layer class.

    Args:
        norm_layer: One of the classes ``nn.BatchNorm2d``, ``nn.GroupNorm``
            or ``nn.SyncBatchNorm``.
        num_channels: Number of channels of the layer.
        num_groups: Group count; only used for ``nn.GroupNorm``.

    Returns:
        A new instance of the requested class.

    Raises:
        NotImplementedError: If ``norm_layer`` is not a supported class.
            (The exception used to be constructed without ``raise``, so the
            function silently returned ``None``.)
    """
    if norm_layer == nn.BatchNorm2d:
        return nn.BatchNorm2d(num_channels)
    elif norm_layer == nn.GroupNorm:
        return nn.GroupNorm(num_groups, num_channels)
    elif norm_layer == nn.SyncBatchNorm:
        return nn.SyncBatchNorm(num_channels)
    else:
        raise NotImplementedError(
            f"'norm_layer' should be BatchNorm2d, GroupNorm or SyncBatchNorm, {norm_layer} is not supported."
        )
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding, followed by SyncBatchNorm.

    Args:
        in_planes: Input channel count.
        out_planes: Output channel count.
        stride: Stride of the convolution.

    Returns:
        ``Sequential`` of a bias-free, padding-1 3x3 ``Conv2d`` and a
        ``SyncBatchNorm`` over ``out_planes`` channels.
    """
    conv = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    batch_norm = nn.SyncBatchNorm(out_planes)
    return torch.nn.Sequential(conv, batch_norm)
def __init__(self, in_channels, out_channels, **kwargs):
    """Create a convolution and the norm layer selected by ``cfg``.

    Args:
        in_channels: Input channel count of the convolution.
        out_channels: Output channel count (and norm layer width).
        **kwargs: Forwarded to ``nn.Conv2d``.
    """
    super(BasicConv2d, self).__init__()
    self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)

    norm_eps = 1e-5
    # cfg['GN'] takes precedence over cfg['syncBN']; BatchNorm2d otherwise.
    if cfg['GN']:
        self.bn = nn.GroupNorm(num_groups=32, num_channels=out_channels,
                               eps=norm_eps)
    elif cfg['syncBN']:
        self.bn = nn.SyncBatchNorm(out_channels, eps=norm_eps)
    else:
        self.bn = nn.BatchNorm2d(out_channels, eps=norm_eps)
def __init__(self, encoder_layer, num_layers, norm=None):
    """Clone the encoder layer and build the fusion modules.

    Args:
        encoder_layer: Layer passed to ``_get_clones`` ``num_layers`` times.
        num_layers: Number of clones to create.
        norm: Optional module stored as ``self.norm``.
    """
    super().__init__()
    self.layers = _get_clones(encoder_layer, num_layers)
    self.num_layers = num_layers
    self.norm = norm

    # 1x1 projection 382 -> 256 with norm and ReLU. 382 equals
    # 18 + 36 + 72 + 256, the per-branch channel counts used below.
    projection = [
        nn.Conv2d(382, 256, kernel_size=1),
        nn.SyncBatchNorm(256),
        nn.ReLU(inplace=True),
    ]
    self.fuse_output_proj = nn.Sequential(*projection)

    self.cross_fusion = CrossFusionLayer(
        num_branches=4,
        num_inchannels=[18, 36, 72, 256],
    )
def __init__(self, in_channels, out_channels, nl_layer=nn.ReLU(inplace=True),
             norm_type='GN'):
    """Build a 1x1 -> 3x3 -> 1x1 bottleneck with per-conv normalisation.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count; the two inner convolutions use
            ``out_channels // 2`` channels.
        nl_layer: Non-linearity module stored as ``self.nl_layer``.
        norm_type: ``'BN'`` for BatchNorm2d, ``'SYBN'`` for SyncBatchNorm;
            any other value selects GroupNorm with ``channels // 8`` groups.
    """
    super(ConvBottleNeck, self).__init__()
    self.nl_layer = nl_layer
    self.in_channels = in_channels
    self.out_channels = out_channels

    half = out_channels // 2
    self.conv1 = nn.Conv2d(in_channels, half, kernel_size=1)
    self.conv2 = nn.Conv2d(half, half, kernel_size=3, padding=1)
    self.conv3 = nn.Conv2d(half, out_channels, kernel_size=1)

    if norm_type == 'BN':
        batch_norm_cls = nn.BatchNorm2d
    elif norm_type == 'SYBN':
        batch_norm_cls = nn.SyncBatchNorm
    else:
        batch_norm_cls = None

    if batch_norm_cls is not None:
        self.norm1 = batch_norm_cls(half, affine=True)
        self.norm2 = batch_norm_cls(half, affine=True)
        self.norm3 = batch_norm_cls(out_channels, affine=True)
    else:
        self.norm1 = nn.GroupNorm(half // 8, half)
        self.norm2 = nn.GroupNorm(half // 8, half)
        self.norm3 = nn.GroupNorm(out_channels // 8, out_channels)

    # A 1x1 conv is only created when the channel count changes.
    if in_channels != out_channels:
        self.skip_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1,
             bias=False, activation_fn_in_separable_conv=True):
    """Create the depthwise and pointwise convolutions with their norms.

    Args:
        inplanes: Input channel count (also the depthwise group count).
        planes: Output channel count of the pointwise convolution.
        kernel_size, stride, dilation: Settings of the depthwise
            convolution, which is built with zero padding.
        bias: Whether the convolutions carry a bias term.
        activation_fn_in_separable_conv: When true, a ReLU module is created
            after each of the two norm layers.
    """
    super(SeparableConv2d_same, self).__init__()
    bn_options = dict(eps=1e-03, momentum=0.0003)
    with_inner_relu = activation_fn_in_separable_conv

    self.relu = nn.ReLU(inplace=False)
    self.activation_fn_in_separable_conv = activation_fn_in_separable_conv

    # Depthwise stage: padding is 0 in the conv itself.
    self.depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0,
                               dilation, groups=inplanes, bias=bias)
    self.depthwise_bn = nn.SyncBatchNorm(inplanes, **bn_options)
    if with_inner_relu:
        self.depthwise_relu = nn.ReLU(inplace=True)

    # Pointwise stage: 1x1 conv mixing channels.
    self.pointwise = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1,
                               padding=0, dilation=1, groups=1, bias=bias)
    self.pointwise_bn = nn.SyncBatchNorm(planes, **bn_options)
    if with_inner_relu:
        self.pointwise_relu = nn.ReLU(inplace=True)
def __init__(self, num_channels, num_groups, momentum, eps=1e-5, with_pbn=True):
    """Create the per-group affine parameters and the SyncBatchNorm layer.

    Args:
        num_channels: Channel count of the SyncBatchNorm layer.
        num_groups: Number of groups; ``weight`` and ``bias`` have shape
            ``(1, num_groups, 1)``.
        momentum: Accepted but not used in this constructor.
        eps: Stored as ``self.eps``.
        with_pbn: Accepted but not used in this constructor.
    """
    super(PBCNorm, self).__init__()
    self.num_channels = num_channels
    self.num_groups = num_groups
    self.eps = eps

    self.weight = Parameter(torch.ones(1, num_groups, 1))
    self.bias = Parameter(torch.zeros(1, num_groups, 1))

    self.pbn = nn.SyncBatchNorm(num_channels)
    # ``_specify_ddp_gpu_num`` is a private SyncBatchNorm hook that is not
    # present in every PyTorch release; call it only where it exists so
    # construction does not fail with AttributeError elsewhere.
    specify_gpu_num = getattr(self.pbn, '_specify_ddp_gpu_num', None)
    if specify_gpu_num is not None:
        specify_gpu_num(1)