class GradInitFixUpResNet(torch.nn.Module):

    def __init__(self, block, layers, num_classes=10, **kwargs):
        super(GradInitFixUpResNet, self).__init__()
        self.num_layers = sum(layers)
        self.inplanes = 16
        self.conv1 = conv3x3(3, 16)
        # self.bias1 = torch.nn.Parameter(torch.zeros(1))
        self.bias1 = GradInitBias()
        self.relu = torch.nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        # self.bias2 = nn.Parameter(torch.zeros(1))
        self.bias2 = GradInitBias()
        self.fc = GradInitLinear(64, num_classes)

        for m in self.modules():
            if isinstance(m, GradInitConv2d):
                # torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                # torch.nn.init.normal_(m.weight, mean=0, std=np.sqrt(
                #     2 / (m.weight.shape[0] * np.prod(m.weight.shape[2:]))) * self.num_layers ** (-0.5))
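                # Fan-out He initialization: weight.shape[0] is out_channels and
                # shape[2:] is the kernel size, i.e. std = sqrt(2 / (out_channels * kH * kW)).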
                torch.nn.init.normal_(m.weight, mean=0, std=np.sqrt(
                    2 / (m.weight.shape[0] * np.prod(m.weight.shape[2:]))))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, GradInitLinear):
                if m.bias is not None:
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1:
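            # A 1x1 average pool with stride acts as a parameter-free strided
            # shortcut (no convolution or normalization on the skip path).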
            downsample = torch.nn.AvgPool2d(1, stride=stride)

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(block(planes, planes))

        return torch.nn.Sequential(*layers)

    def gradinit(self, mode=True):
        # The FixUp variant has no BatchNorm layers, so there is nothing to toggle here.
        pass

    def opt_mode(self, mode=True):
        self.conv1.opt_mode(mode)
        self.bias1.opt_mode(mode)
        for layer in itertools.chain(self.layer1, self.layer2, self.layer3):
            layer.opt_mode(mode)
        self.bias2.opt_mode(mode)
        self.fc.opt_mode(mode)

    def forward(self, x):
        x = self.conv1(x)
        # x = self.relu(x + self.bias1)
        x = self.relu(self.bias1(x))

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        # x = self.fc(x + self.bias2)
        x = self.fc(self.bias2(x))

        return x
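

# Usage sketch (illustrative): assumes the module-level imports (torch) and the
# GradInit* layers defined elsewhere in this file. The FixUp-style residual block
# class is not shown in this excerpt, so it is passed in as an argument here rather
# than named directly.
def _example_fixup_resnet(block_cls):
    # [3, 3, 3] is the standard 3-stage CIFAR layout used with this class.
    model = GradInitFixUpResNet(block_cls, [3, 3, 3], num_classes=10)
    logits = model(torch.randn(2, 3, 32, 32))  # CIFAR-sized batch -> shape (2, 10)
    model.opt_mode(True)  # propagates opt_mode to conv1, the biases, the blocks and fc
    return logits
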
class GradInitResNet(nn.Module):

    def __init__(self, block, layers, num_classes=10, use_bn=True, use_zero_init=False, init_multip=1, **kwargs):
        super(GradInitResNet, self).__init__()

        self.num_layers = sum(layers)
        self.inplanes = 16
        self.conv1 = conv3x3(3, 16, bias=not use_bn)
        self.use_bn = use_bn
        if use_bn:
            self.bn1 = GradInitBatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = GradInitLinear(64, num_classes)

        if use_zero_init:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
            # Zero-initialize the last BN in each residual branch,
            # so that the residual branch starts with zeros, and each residual block behaves like an identity.
            # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
            for m in self.modules():
                if isinstance(m, GradInitBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
        else:
            for m in self.modules():
                if isinstance(m, GradInitConv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        m.bias.data.zero_()
                elif isinstance(m, GradInitBatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, GradInitLinear):
                    if m.bias is not None:
                        m.bias.data.zero_()

        if init_multip != 1:
            for m in self.modules():
                if isinstance(m, GradInitConv2d):
                    m.weight.data *= init_multip
                    if m.bias is not None:
                        m.bias.data *= init_multip
                elif isinstance(m, GradInitBatchNorm2d):
                    m.weight.data *= init_multip
                    m.bias.data *= init_multip
                elif isinstance(m, GradInitLinear):
                    m.weight.data *= init_multip
                    if m.bias is not None:
                        m.bias.data *= init_multip

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1:
            if self.use_bn:
                downsample = nn.Sequential(
                    nn.AvgPool2d(1, stride=stride),
                    GradInitBatchNorm2d(self.inplanes),
                )
            else:
                downsample = nn.Sequential(nn.AvgPool2d(1, stride=stride))

        layers = []
        layers.append(block(
            self.inplanes, planes, stride, downsample, use_bn=self.use_bn))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(block(planes, planes, use_bn=self.use_bn))

        # return nn.ModuleList(layers)
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        if self.use_bn:
            x = self.bn1(x)

        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        x = self.fc(x)

        return x

    def gradinit(self, mode=True):
        if self.use_bn:
            self.bn1.gradinit(mode)
        for layer in itertools.chain(self.layer1, self.layer2, self.layer3):
            layer.gradinit(mode=mode)

    def opt_mode(self, mode=True):
        self.conv1.opt_mode(mode)
        if self.use_bn:
            self.bn1.opt_mode(mode)
        for layer in itertools.chain(self.layer1, self.layer2, self.layer3):
            layer.opt_mode(mode)
        self.fc.opt_mode(mode)

    def get_plotting_names(self):
        bn_names, conv_names = [], []
        for n, p in self.named_parameters():
            if (('conv' in n and 'layer' in n) or 'fc' in n) and 'weight' in n:
                conv_names.append('module.' + n)
            elif 'bn' in n and 'weight' in n and 'layer' in n:
                bn_names.append('module.' + n)

        if self.use_bn:
            return {'Linear': conv_names, 'BN': bn_names}
        else:
            return {'Linear': conv_names}
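

# Usage sketch (illustrative): assumes the module-level imports (torch) and the
# GradInitBasicBlock class referenced by the zero-init branch above, both defined
# elsewhere in this file.
def _example_gradinit_resnet():
    # [3, 3, 3] gives the three-stage CIFAR ResNet (ResNet-20 with two-conv blocks).
    model = GradInitResNet(GradInitBasicBlock, [3, 3, 3], num_classes=10, use_bn=True)
    logits = model(torch.randn(2, 3, 32, 32))  # CIFAR-sized batch -> shape (2, 10)
    model.gradinit(True)   # forwards the flag to bn1 and the residual blocks
    model.gradinit(False)
    model.opt_mode(True)   # forwards the flag to conv1, bn1, the blocks and fc
    return logits
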
class GradInitDenseNet(torch.nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 3 or 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottleneck layers
            (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        small_inputs (bool) - set to True if images are 32x32. Otherwise assumes images are larger.
        efficient (bool) - set to True to use checkpointing. Much more memory efficient, but slower.
    """
    def __init__(self, growth_rate=12, block_config=(16, 16, 16), compression=0.5,
                 num_init_features=24, bn_size=4, drop_rate=0,
                 num_classes=10, small_inputs=True, efficient=False, use_bn=True, use_pt_init=False, init_multip=1., **kwargs):

        super(GradInitDenseNet, self).__init__()
        assert 0 < compression <= 1, 'compression of densenet should be between 0 and 1'
        self.use_bn = use_bn

        # First convolution
        if small_inputs:
            self.features = torch.nn.Sequential(OrderedDict([
                ('conv0', GradInitConv2d(3, num_init_features, kernel_size=3, stride=1, padding=1, bias=not use_bn)),
            ]))
        else:
            self.features = torch.nn.Sequential(OrderedDict([
                ('conv0', GradInitConv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=not use_bn)),
            ]))
            if use_bn:
                self.features.add_module('norm0', GradInitBatchNorm2d(num_init_features))
            self.features.add_module('relu0', torch.nn.ReLU(inplace=True))
            self.features.add_module('pool0', torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1,
                                                           ceil_mode=False))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                efficient=efficient,
                use_bn=use_bn
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=int(num_features * compression),
                                    use_bn=use_bn)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = int(num_features * compression)

        # Final batch norm
        if use_bn:
            self.features.add_module('norm_final', GradInitBatchNorm2d(num_features))

        # Linear layer
        self.classifier = GradInitLinear(num_features, num_classes)

        # Initialization
        if not use_pt_init:
            for name, param in self.named_parameters():
                if 'conv' in name and 'weight' in name:
                    n = param.size(0) * param.size(2) * param.size(3)
                    param.data.normal_().mul_(math.sqrt(2. / n))
                elif 'conv' in name and 'bias' in name:
                    param.data.zero_()
                elif 'norm' in name and 'weight' in name:
                    param.data.fill_(1)
                elif 'norm' in name and 'bias' in name:
                    param.data.fill_(0)
                elif 'classifier' in name and 'bias' in name:
                    param.data.fill_(0)

        self.gradinit_ = False

        if init_multip != 1:
            for param in self.parameters():
                param.data *= init_multip

    def gradinit(self, mode=True):
        for name, layer in self.features.named_children():
            if 'norm' in name or 'denseblock' in name or 'transition' in name:
                layer.gradinit(mode)

    def opt_mode(self, mode=True):
        for name, layer in self.features.named_children():
            if 'norm' in name or 'conv' in name or 'denseblock' in name or 'transition' in name:
                layer.opt_mode(mode)
        self.classifier.opt_mode(mode)

    def get_plotting_names(self):
        bn_names, conv_names = [], []
        for n, p in self.named_parameters():
            if (('conv' in n and 'layer' in n) or 'classifier' in n) and 'weight' in n:
                conv_names.append('module.' + n)
            elif 'norm' in n and 'weight' in n and 'layer' in n:
                bn_names.append('module.' + n)

        # bn_names = sorted(bn_names)
        # conv_names = sorted(conv_names)
        if self.use_bn:
            return {'Linear': conv_names, 'BN': bn_names}
        else:
            return {'Linear': conv_names}

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out
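

# Worked example (illustrative): assumes the module-level imports (torch) and the
# _DenseBlock/_Transition helpers defined elsewhere in this file. With the defaults
# above (num_init_features=24, block_config=(16, 16, 16), growth_rate=12,
# compression=0.5) the channel count evolves as
#   denseblock1: 24 + 16*12 = 216,  transition1: int(216 * 0.5) = 108
#   denseblock2: 108 + 16*12 = 300, transition2: int(300 * 0.5) = 150
#   denseblock3: 150 + 16*12 = 342  (no transition after the last block)
# so the classifier is GradInitLinear(342, num_classes).
def _example_gradinit_densenet():
    model = GradInitDenseNet(growth_rate=12, block_config=(16, 16, 16), num_classes=10)
    logits = model(torch.randn(2, 3, 32, 32))  # CIFAR-sized batch
    assert logits.shape == (2, 10)
    return logits
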
class VGG(torch.nn.Module):
    def __init__(self,
                 vgg_name,
                 use_bn=True,
                 use_pt_init=False,
                 init_multip=1,
                 **kwargs):
        super(VGG, self).__init__()
        self.use_bn = use_bn
        self.conv_names = []
        self.bn_names = []
        self._make_layers(cfg[vgg_name])
        self.classifier = GradInitLinear(512, 10)
        self.conv_names.append('module.classifier.weight')
        if not use_pt_init:
            self._initialize_weights()

        if init_multip != 1:
            for m in self.modules():
                if isinstance(m, GradInitConv2d):
                    m.weight.data *= init_multip
                    if m.bias is not None:
                        m.bias.data *= init_multip
                elif isinstance(m, GradInitBatchNorm2d):
                    m.weight.data *= init_multip
                    m.bias.data *= init_multip
                elif isinstance(m, GradInitLinear):
                    m.weight.data *= init_multip
                    m.bias.data *= init_multip

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        # layers = []
        in_channels = 3
        pool_num, block_num = 0, 0
        self.features = torch.nn.Sequential(OrderedDict([]))
        for x in cfg:
            if x == 'M':
                self.features.add_module(
                    f'pool{pool_num}',
                    torch.nn.MaxPool2d(kernel_size=2, stride=2))
                pool_num += 1
            else:
                self.features.add_module(
                    f'conv{block_num}',
                    GradInitConv2d(in_channels, x, kernel_size=3, padding=1))
                if self.use_bn:
                    self.features.add_module(f'bn{block_num}',
                                             GradInitBatchNorm2d(x))
                self.features.add_module(f'relu{block_num}',
                                         torch.nn.ReLU(inplace=True))
                in_channels = x
                self.conv_names.append(
                    f'module.features.conv{block_num}.weight')
                self.bn_names.append(f'module.features.bn{block_num}.weight')
                block_num += 1

        self.add_module('global_pool',
                        torch.nn.AvgPool2d(kernel_size=1, stride=1))

    def gradinit(self, mode=True):
        for name, layer in self.features.named_children():
            if 'bn' in name:
                layer.gradinit(mode)

    def opt_mode(self, mode=True):
        for name, layer in self.features.named_children():
            if 'bn' in name or 'conv' in name:  # BN children are named 'bn{i}', not 'norm{i}'
                layer.opt_mode(mode)
        self.classifier.opt_mode(mode)

    def _initialize_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, GradInitConv2d):
                torch.nn.init.kaiming_normal_(m.weight,
                                              mode='fan_out',
                                              nonlinearity='relu')
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, GradInitBatchNorm2d):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, GradInitLinear):
                torch.nn.init.normal_(m.weight, 0, 0.01)
                torch.nn.init.constant_(m.bias, 0)

    def get_plotting_names(self):
        if self.use_bn:
            return {
                'Linear': self.conv_names,
                'BN': self.bn_names,
            }
        else:
            return {
                'Linear': self.conv_names,
            }
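

# Usage sketch (illustrative): assumes the module-level imports (torch) and the
# module-level `cfg` dict defined elsewhere in this source, with an entry for the
# given name (e.g. 'VGG16') in the usual layer-list format. The classifier is
# hard-coded to GradInitLinear(512, 10), so this is a CIFAR-10 configuration.
def _example_vgg(vgg_name='VGG16'):
    model = VGG(vgg_name, use_bn=True)
    logits = model(torch.randn(2, 3, 32, 32))  # CIFAR-sized batch -> shape (2, 10)
    model.gradinit(True)   # forwards the flag to the 'bn*' children of self.features
    model.gradinit(False)
    return logits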