import itertools
import math
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# GradInitConv2d, GradInitLinear, GradInitBatchNorm2d, GradInitBias,
# GradInitBasicBlock, conv3x3, _DenseBlock, _Transition, and the VGG `cfg`
# dict are the repo's GradInit wrappers and helpers, defined elsewhere.


class GradInitFixUpResNet(torch.nn.Module):
    def __init__(self, block, layers, num_classes=10, **kwargs):
        super(GradInitFixUpResNet, self).__init__()
        self.num_layers = sum(layers)
        self.inplanes = 16
        self.conv1 = conv3x3(3, 16)
        # Fixup's scalar biases (originally plain nn.Parameter(torch.zeros(1))),
        # wrapped as GradInitBias so GradInit can rescale them.
        self.bias1 = GradInitBias()
        self.relu = torch.nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.bias2 = GradInitBias()
        self.fc = GradInitLinear(64, num_classes)

        for m in self.modules():
            if isinstance(m, GradInitConv2d):
                # He (fan-out) initialization. The Fixup-style depth-dependent
                # rescaling below is disabled in this variant:
                # torch.nn.init.normal_(m.weight, mean=0, std=np.sqrt(
                #     2 / (m.weight.shape[0] * np.prod(m.weight.shape[2:]))) * self.num_layers ** (-0.5))
                torch.nn.init.normal_(m.weight, mean=0, std=np.sqrt(
                    2 / (m.weight.shape[0] * np.prod(m.weight.shape[2:]))))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, GradInitLinear):
                if m.bias is not None:
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1:
            downsample = torch.nn.AvgPool2d(1, stride=stride)
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(block(planes, planes))
        return torch.nn.Sequential(*layers)

    def gradinit(self, mode=True):
        # Fixup has no normalization layers, so there is nothing to toggle.
        pass

    def opt_mode(self, mode=True):
        self.conv1.opt_mode(mode)
        self.bias1.opt_mode(mode)
        for layer in itertools.chain(self.layer1, self.layer2, self.layer3):
            layer.opt_mode(mode)
        self.bias2.opt_mode(mode)
        self.fc.opt_mode(mode)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(self.bias1(x))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(self.bias2(x))
        return x
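# Illustrative usage sketch (not from the original source): layers=[3, 3, 3]
# gives the classic 20-layer CIFAR configuration. The argument name
# `fixup_block_cls` is a placeholder for whatever Fixup basic-block class the
# repo defines; it is an assumption, not an identifier from this codebase.
def _example_fixup_resnet20(fixup_block_cls):
    model = GradInitFixUpResNet(fixup_block_cls, [3, 3, 3], num_classes=10)
    # opt_mode(True) presumably switches every GradInit wrapper into the mode
    # used while the initialization scales are being optimized.
    model.opt_mode(True)
    x = torch.randn(2, 3, 32, 32)  # a CIFAR-sized dummy batch
    return model(x).shape          # -> torch.Size([2, 10])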
class GradInitDenseNet(torch.nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 3 or 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for the number of bottleneck layers
            (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        small_inputs (bool) - set to True if images are 32x32. Otherwise assumes images are larger.
        efficient (bool) - set to True to use checkpointing. Much more memory efficient, but slower.
    """

    def __init__(self, growth_rate=12, block_config=(16, 16, 16), compression=0.5,
                 num_init_features=24, bn_size=4, drop_rate=0,
                 num_classes=10, small_inputs=True, efficient=False,
                 use_bn=True, use_pt_init=False, init_multip=1., **kwargs):
        super(GradInitDenseNet, self).__init__()
        assert 0 < compression <= 1, 'compression of densenet should be between 0 and 1'
        self.use_bn = use_bn

        # First convolution
        if small_inputs:
            self.features = torch.nn.Sequential(OrderedDict([
                ('conv0', GradInitConv2d(3, num_init_features, kernel_size=3,
                                         stride=1, padding=1, bias=not use_bn)),
            ]))
        else:
            self.features = torch.nn.Sequential(OrderedDict([
                ('conv0', GradInitConv2d(3, num_init_features, kernel_size=7,
                                         stride=2, padding=3, bias=not use_bn)),
            ]))
            if use_bn:
                self.features.add_module('norm0', GradInitBatchNorm2d(num_init_features))
            self.features.add_module('relu0', torch.nn.ReLU(inplace=True))
            self.features.add_module('pool0', torch.nn.MaxPool2d(kernel_size=3, stride=2,
                                                                 padding=1, ceil_mode=False))

        # Each dense block, with a transition layer between consecutive blocks
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                efficient=efficient,
                use_bn=use_bn,
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=int(num_features * compression),
                                    use_bn=use_bn)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = int(num_features * compression)

        # Final batch norm
        if use_bn:
            self.features.add_module('norm_final', GradInitBatchNorm2d(num_features))

        # Linear layer
        self.classifier = GradInitLinear(num_features, num_classes)

        # Initialization: He init for conv weights, standard affine init for norms
        if not use_pt_init:
            for name, param in self.named_parameters():
                if 'conv' in name and 'weight' in name:
                    n = param.size(0) * param.size(2) * param.size(3)
                    param.data.normal_().mul_(math.sqrt(2. / n))
                elif 'conv' in name and 'bias' in name:
                    param.data.zero_()
                elif 'norm' in name and 'weight' in name:
                    param.data.fill_(1)
                elif 'norm' in name and 'bias' in name:
                    param.data.fill_(0)
                elif 'classifier' in name and 'bias' in name:
                    param.data.fill_(0)

        self.gradinit_ = False
        if init_multip != 1:
            for param in self.parameters():
                param.data *= init_multip

    def gradinit(self, mode=True):
        for name, layer in self.features.named_children():
            if 'norm' in name or 'denseblock' in name or 'transition' in name:
                layer.gradinit(mode)

    def opt_mode(self, mode=True):
        # The classifier is not inside self.features, so it is handled
        # explicitly after the loop.
        for name, layer in self.features.named_children():
            if 'norm' in name or 'conv' in name or 'denseblock' in name or 'transition' in name:
                layer.opt_mode(mode)
        self.classifier.opt_mode(mode)

    def get_plotting_names(self):
        bn_names, conv_names = [], []
        for n, p in self.named_parameters():
            if (('conv' in n and 'layer' in n) or 'classifier' in n) and 'weight' in n:
                conv_names.append('module.' + n)
            elif 'norm' in n and 'weight' in n and 'layer' in n:
                bn_names.append('module.' + n)
        if self.use_bn:
            return {'Linear': conv_names, 'BN': bn_names}
        else:
            return {'Linear': conv_names}

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out
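# Illustrative usage sketch (not from the original source): the defaults
# growth_rate=12 and block_config=(16, 16, 16) correspond to a DenseNet-BC-100
# for 32x32 inputs. This simply exercises the forward pass.
def _example_densenet_forward():
    model = GradInitDenseNet(growth_rate=12, block_config=(16, 16, 16),
                             num_classes=10, small_inputs=True)
    x = torch.randn(2, 3, 32, 32)
    return model(x).shape  # -> torch.Size([2, 10])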
class GradInitResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10, use_bn=True,
                 use_zero_init=False, init_multip=1, **kwargs):
        super(GradInitResNet, self).__init__()
        self.num_layers = sum(layers)
        self.inplanes = 16
        self.conv1 = conv3x3(3, 16, bias=not use_bn)
        self.use_bn = use_bn
        if use_bn:
            self.bn1 = GradInitBatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = GradInitLinear(64, num_classes)

        if use_zero_init:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
            # Zero-initialize the last BN in each residual branch, so that the
            # residual branch starts with zeros and each residual block behaves
            # like an identity. This improves the model by 0.2~0.3% according
            # to https://arxiv.org/abs/1706.02677
            for m in self.modules():
                if isinstance(m, GradInitBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
        else:
            for m in self.modules():
                if isinstance(m, GradInitConv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        m.bias.data.zero_()
                elif isinstance(m, GradInitBatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, GradInitLinear):
                    if m.bias is not None:
                        m.bias.data.zero_()

        if init_multip != 1:
            for m in self.modules():
                if isinstance(m, GradInitConv2d):
                    m.weight.data *= init_multip
                    if m.bias is not None:
                        m.bias.data *= init_multip
                elif isinstance(m, GradInitBatchNorm2d):
                    m.weight.data *= init_multip
                    m.bias.data *= init_multip
                elif isinstance(m, GradInitLinear):
                    m.weight.data *= init_multip
                    if m.bias is not None:
                        m.bias.data *= init_multip

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1:
            if self.use_bn:
                downsample = nn.Sequential(
                    nn.AvgPool2d(1, stride=stride),
                    GradInitBatchNorm2d(self.inplanes),
                )
            else:
                downsample = nn.Sequential(nn.AvgPool2d(1, stride=stride))
        layers = []
        layers.append(block(
            self.inplanes, planes, stride, downsample, use_bn=self.use_bn))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(block(planes, planes, use_bn=self.use_bn))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        if self.use_bn:
            x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def gradinit(self, mode=True):
        if self.use_bn:
            self.bn1.gradinit(mode)
        for layer in itertools.chain(self.layer1, self.layer2, self.layer3):
            layer.gradinit(mode=mode)

    def opt_mode(self, mode=True):
        self.conv1.opt_mode(mode)
        if self.use_bn:
            self.bn1.opt_mode(mode)
        for layer in itertools.chain(self.layer1, self.layer2, self.layer3):
            layer.opt_mode(mode)
        self.fc.opt_mode(mode)

    def get_plotting_names(self):
        bn_names, conv_names = [], []
        for n, p in self.named_parameters():
            if (('conv' in n and 'layer' in n) or 'fc' in n) and 'weight' in n:
                conv_names.append('module.' + n)
            elif 'bn' in n and 'weight' in n and 'layer' in n:
                bn_names.append('module.' + n)
        if self.use_bn:
            return {'Linear': conv_names, 'BN': bn_names}
        else:
            return {'Linear': conv_names}
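# Illustrative usage sketch (not from the original source): layers=[3, 3, 3]
# with a basic block gives the classic CIFAR ResNet-20 (depth 6n+2, n=3).
# GradInitBasicBlock is the block class the zero-init branch above refers to;
# the gradinit/opt_mode semantics described in the comments are assumptions
# about the GradInit wrappers, which are defined elsewhere in the repo.
def _example_resnet20():
    model = GradInitResNet(GradInitBasicBlock, [3, 3, 3], num_classes=10,
                           use_bn=True, use_zero_init=False)
    model.gradinit(True)   # presumably toggles BN layers into GradInit behavior
    model.opt_mode(True)   # presumably exposes the learnable init scales
    x = torch.randn(2, 3, 32, 32)
    return model(x).shape  # -> torch.Size([2, 10])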
class VGG(torch.nn.Module):
    def __init__(self, vgg_name, use_bn=True, use_pt_init=False, init_multip=1, **kwargs):
        super(VGG, self).__init__()
        self.use_bn = use_bn
        self.conv_names = []
        self.bn_names = []
        self._make_layers(cfg[vgg_name])
        self.classifier = GradInitLinear(512, 10)
        self.conv_names.append('module.classifier.weight')
        if not use_pt_init:
            self._initialize_weights()

        if init_multip != 1:
            for m in self.modules():
                if isinstance(m, GradInitConv2d):
                    m.weight.data *= init_multip
                    if m.bias is not None:
                        m.bias.data *= init_multip
                elif isinstance(m, GradInitBatchNorm2d):
                    m.weight.data *= init_multip
                    m.bias.data *= init_multip
                elif isinstance(m, GradInitLinear):
                    m.weight.data *= init_multip
                    m.bias.data *= init_multip

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        in_channels = 3
        pool_num, block_num = 0, 0
        self.features = torch.nn.Sequential(OrderedDict([]))
        for x in cfg:
            if x == 'M':
                self.features.add_module(
                    f'pool{pool_num}', torch.nn.MaxPool2d(kernel_size=2, stride=2))
                pool_num += 1
            else:
                self.features.add_module(
                    f'conv{block_num}',
                    GradInitConv2d(in_channels, x, kernel_size=3, padding=1))
                if self.use_bn:
                    self.features.add_module(f'bn{block_num}', GradInitBatchNorm2d(x))
                self.features.add_module(f'relu{block_num}', torch.nn.ReLU(inplace=True))
                in_channels = x
                self.conv_names.append(f'module.features.conv{block_num}.weight')
                self.bn_names.append(f'module.features.bn{block_num}.weight')
                block_num += 1
        # Note: this pooling layer is registered but never used in forward().
        self.add_module('global_pool', torch.nn.AvgPool2d(kernel_size=1, stride=1))

    def gradinit(self, mode=True):
        for name, layer in self.features.named_children():
            if 'bn' in name:
                layer.gradinit(mode)

    def opt_mode(self, mode=True):
        # Match the 'bn{n}' and 'conv{n}' names assigned in _make_layers
        # (the original checked 'norm', which never matches these names).
        for name, layer in self.features.named_children():
            if 'bn' in name or 'conv' in name:
                layer.opt_mode(mode)
        self.classifier.opt_mode(mode)

    def _initialize_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, GradInitConv2d):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, GradInitBatchNorm2d):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, GradInitLinear):
                torch.nn.init.normal_(m.weight, 0, 0.01)
                torch.nn.init.constant_(m.bias, 0)

    def get_plotting_names(self):
        if self.use_bn:
            return {'Linear': self.conv_names, 'BN': self.bn_names}
        else:
            return {'Linear': self.conv_names}
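# Illustrative usage sketch (not from the original source): `cfg` is the usual
# VGG configuration dict this file indexes with cfg[vgg_name]. The 'VGG16'
# entry shown below is the standard CIFAR-VGG layout and is an assumption
# about how the repo defines it:
# cfg['VGG16'] = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
#                 512, 512, 512, 'M', 512, 512, 512, 'M']
# Five 'M' poolings reduce a 32x32 input to 1x1x512, matching the
# GradInitLinear(512, 10) classifier.
def _example_vgg16():
    model = VGG('VGG16', use_bn=True)
    x = torch.randn(2, 3, 32, 32)
    return model(x).shape  # -> torch.Size([2, 10])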