예제 #1
0
파일: tvm_vggs.py 프로젝트: zhjpqq/scalenet
def vgg13_bn(pretrained=False, **kwargs):
    """Build a batch-normalized VGG-13 network (configuration "B").

    Args:
        pretrained (bool): when True, skip weight initialisation and load the
            ImageNet weights published at ``model_urls['vgg13_bn']``.
        **kwargs: forwarded unchanged to the ``VGG`` constructor.

    Returns:
        The constructed ``VGG`` model.
    """
    if pretrained:
        # The checkpoint loaded below overwrites every weight anyway.
        kwargs['init_weights'] = False
    features = make_layers(cfg['B'], batch_norm=True)
    net = VGG(features, **kwargs)
    if not pretrained:
        return net
    checkpoint = model_zoo.load_url(model_urls['vgg13_bn'])
    net.load_state_dict(checkpoint)
    return net
예제 #2
0
    def __init__(self, vgg_model: torchvision_models.VGG):
        """Wrap a VGG feature extractor as five consecutive stages.

        Args:
            vgg_model: VGG network whose ``features`` are sliced into stages.
                It is switched to eval mode and kept as ``self.model``.
        """
        super().__init__()

        self.normalizer = ImageNetNormalizer()
        self.model = vgg_model.eval()
        trunk = self.model.features

        # Stage boundaries are indices into ``features``. Assuming the
        # torchvision VGG-16 layout without batch norm, each stage ends with
        # the ReLU preceding a max-pool (TODO confirm for other variants).
        self.layer1 = nn.Sequential(trunk[0:4])
        self.layer2 = nn.Sequential(trunk[4:9])
        self.layer3 = nn.Sequential(trunk[9:16])
        self.layer4 = nn.Sequential(trunk[16:23])
        self.layer5 = nn.Sequential(trunk[23:30])
예제 #3
0
def get_model_by_name(model_name) -> Module:
    """Instantiate one of the supported two-class networks by name.

    Args:
        model_name: one of ``"vgg"``, ``"mobilenetv2"``, ``"simple"`` or
            ``"squeezenet"``.

    Returns:
        A freshly constructed model.

    Raises:
        ValueError: if ``model_name`` is not recognised. ValueError subclasses
            Exception, so existing ``except Exception`` handlers still work.
    """
    # Lambdas keep construction lazy: only the requested model is built.
    factories = {
        "vgg": lambda: VGG(vgg11_bn().features, num_classes=2),
        "mobilenetv2": lambda: MobileNetV2(num_classes=2),
        "simple": lambda: SimpleNetwork(),
        "squeezenet": lambda: SqueezeNet(version=1.1, num_classes=2),
    }
    try:
        factory = factories[model_name]
    except (KeyError, TypeError):  # TypeError: unhashable model_name
        raise ValueError(f"Invalid model name: {model_name}.") from None
    return factory()
예제 #4
0
파일: loss.py 프로젝트: qgking/MMT-PSM
    def __init__(self, requires_grad=False, pretrained=True):
        """Build a VGG-16 feature extractor from a local checkpoint file.

        Args:
            requires_grad: when False, every parameter of this module is frozen.
            pretrained: accepted for API compatibility. NOTE(review): it is not
                read here -- the checkpoint is always loaded.
        """
        super(vgg16, self).__init__()
        # VGG-16 (configuration "D") convolutional layout.
        layout = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
                  512, 512, 512, 'M', 512, 512, 512, 'M']
        backbone = VGG(make_layers(layout), init_weights=False)
        # Path is relative to the current working directory.
        checkpoint = torch.load('vgg16-397923af.pth')
        backbone.load_state_dict(checkpoint)

        # ImageNet channel statistics shaped (1, 3, 1, 1), presumably to
        # broadcast over NCHW batches. NOTE(review): these are plain attributes
        # placed on the current CUDA device, not registered buffers, so
        # ``module.to(...)`` will not move them.
        self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1, 3, 1, 1)

        features = backbone.features
        del checkpoint

        self.slice1, self.slice2, self.slice3, self.slice4, self.slice5 = (
            torch.nn.Sequential() for _ in range(5))
        self.N_slices = 5

        # Slice k (k = 1..4) receives only feature layer k-1; slice5 stays empty.
        # NOTE(review): the conventional split assigns layers [0, 4), [4, 9),
        # [9, 16), [16, 23) and [23, 30) to slices 1-5; confirm this
        # single-layer mapping is intended.
        first_four = (self.slice1, self.slice2, self.slice3, self.slice4)
        for idx, stage in enumerate(first_four):
            stage.add_module(str(idx), features[idx])

        if not requires_grad:
            for weight in self.parameters():
                weight.requires_grad_(False)
예제 #5
0
파일: tvm_vggs.py 프로젝트: zhjpqq/scalenet
def vgg16(pretrained=False, model_path=None, **kwargs):
    """VGG 16-layer model (configuration "D").

    Args:
        pretrained (bool): If True, returns a model with pre-trained weights.
        model_path (str or None): local checkpoint to load when ``pretrained``
            is True. If None, the ImageNet weights at ``model_urls['vgg16']``
            are downloaded instead. Ignored when ``pretrained`` is False.
        **kwargs: forwarded to the ``VGG`` constructor.

    Returns:
        The constructed ``VGG`` model.
    """
    if pretrained:
        # Weights are overwritten by the checkpoint, so skip initialising them.
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['D']), **kwargs)
    if pretrained:
        if model_path is not None:
            # map_location='cpu' lets checkpoints saved from a GPU load on
            # CPU-only hosts; load_state_dict then copies the tensors into the
            # model's own parameters, so the result is unchanged otherwise.
            state_dict = torch.load(model_path, map_location='cpu')
        else:
            state_dict = model_zoo.load_url(model_urls['vgg16'])
        model.load_state_dict(state_dict)
    return model
def main():
    """Run the NTK experiment: compare the empirical NTK before and after training.

    Relies on the module-level ``args`` and ``config`` objects (and ``device``
    for checkpoint loading). Steps:

    1. Build the network selected by ``args.net`` with width ``config['width']``.
    2. Load initial weights from ``<path>Cifar10_<net><width>_before.pth``, or
       save the fresh initialisation there if loading fails.
    3. Compute ``batch_wise_ntk`` before training, plot it and print its norm.
    4. Load trained weights from ``..._after.pth``, or train with ``dl.train``
       and save them if loading fails.
    5. Recompute the NTK, plot the before/after/diff/relative-diff/correlation/
       similarity maps and record summary values via ``save_output``.
    6. If ``args.pdist`` is set, do the same for the feature distance, cosine
       and inner-product statistics from ``batch_feature_correlations``.

    With ``args.dryrun`` set, checkpoints are not written; the target path is
    printed instead.

    Raises:
        ValueError: if ``args.net`` is not a supported architecture.
    """
    def _timestamp():
        # Wall-clock time in the format used for all progress messages.
        return datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p")

    def _save_heatmap(matrix, tag):
        # Draw ``matrix`` on the current pyplot axes and write it to
        # ``<path><net><width>_CIFAR_NTK_<tag>.png``.
        plt.imshow(matrix)
        plt.savefig(config['path'] +
                    f'{args.net}{config["width"]}_CIFAR_NTK_{tag}.png',
                    bbox_inches='tight',
                    dpi=1200)

    def _mean_norm(matrices):
        # Mean L2 norm of the flattened arrays in ``matrices``.
        return np.mean([np.linalg.norm(cm.flatten()) for cm in matrices])

    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(
        f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()}'
    )
    print(_timestamp())

    # Only the training loader is used by this experiment.
    trainloader, _ = dl.get_loaders('CIFAR10',
                                    config['batch_size'],
                                    augmentations=False,
                                    shuffle=False)

    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2],
                         widen_factor=config['width'])
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16,
                          widen_factor=config['width'],
                          dropout_rate=0.0,
                          num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(
            OrderedDict([
                ('flatten', torch.nn.Flatten()),
                ('linear0', torch.nn.Linear(3072, config['width'])),
                ('relu0', torch.nn.ReLU()),
                ('linear1', torch.nn.Linear(config['width'], config['width'])),
                ('relu1', torch.nn.ReLU()),
                ('linear2', torch.nn.Linear(config['width'], config['width'])),
                ('relu2', torch.nn.ReLU()),
                ('linear3', torch.nn.Linear(config['width'], 10))
            ]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(
            OrderedDict([('flatten', torch.nn.Flatten()),
                         ('linear0', torch.nn.Linear(3072, config['width'])),
                         ('relu0', torch.nn.ReLU()),
                         ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10,
                          width_mult=config['width'],
                          round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [
            64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'
        ]
        # NOTE(review): the isinstance filter drops the 'M' entries, so
        # make_layers never sees them (presumably no max-pool layers). Changing
        # this would likely alter the layer indices in saved checkpoints --
        # confirm intent before touching it.
        cfg = [c * config['width'] for c in cfg_base if isinstance(c, int)]
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * config['width'],
                                            4096)
    elif args.net == 'ConvNet':
        net = torch.nn.Sequential(
            OrderedDict([
                ('conv0',
                 torch.nn.Conv2d(3,
                                 1 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu0', torch.nn.ReLU()),
                # ('pool0', torch.nn.MaxPool2d(3)),
                ('conv1',
                 torch.nn.Conv2d(1 * config['width'],
                                 2 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu1', torch.nn.ReLU()),
                #  ('pool1', torch.nn.MaxPool2d(3)),
                ('conv2',
                 torch.nn.Conv2d(2 * config['width'],
                                 2 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu2', torch.nn.ReLU()),
                # ('pool2', torch.nn.MaxPool2d(3)),
                ('conv3',
                 torch.nn.Conv2d(2 * config['width'],
                                 4 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu3', torch.nn.ReLU()),
                ('pool3', torch.nn.MaxPool2d(3)),
                ('conv4',
                 torch.nn.Conv2d(4 * config['width'],
                                 4 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu4', torch.nn.ReLU()),
                ('pool4', torch.nn.MaxPool2d(3)),
                ('flatten', torch.nn.Flatten()),
                ('linear', torch.nn.Linear(36 * config['width'], 10))
            ]))
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    ckpt_prefix = config['path'] + 'Cifar10_' + args.net + str(config["width"])

    # Reuse a stored initialisation if one loads. On any failure (missing
    # file, mismatched keys, ...) store the current initialisation instead.
    before_path = ckpt_prefix + '_before.pth'
    try:
        net.load_state_dict(torch.load(before_path, map_location=device))
        print('Initialized net loaded from file.')
    except Exception:  # deliberate load-or-create fallback
        if not args.dryrun:
            torch.save(net.state_dict(), before_path)
            print('Initialized net saved to file.')
        else:
            print(f'Would save to {before_path}')

    num_params = sum([p.numel() for p in net.parameters()])
    print(
        f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
        f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')

    ntk_matrix_before = batch_wise_ntk(net,
                                       trainloader,
                                       samplesize=args.sampling)
    _save_heatmap(ntk_matrix_before, 'BEFORE')
    ntk_matrix_before_norm = np.linalg.norm(ntk_matrix_before.flatten())
    print(
        f'The total norm of the NTK sample before training is {ntk_matrix_before_norm:.2f}'
    )
    param_norm_before = np.sqrt(
        np.sum(
            [p.pow(2).sum().detach().cpu().numpy() for p in net.parameters()]))
    print(f'The L2 norm of the parameter vector is {param_norm_before:.2f}')

    if args.pdist:
        pdist_init, cos_init, prod_init = batch_feature_correlations(
            trainloader)
        pdist_init_norm = _mean_norm(pdist_init)
        cos_init_norm = _mean_norm(cos_init)
        prod_init_norm = _mean_norm(prod_init)
        print(
            f'The total norm of feature distances before training is {pdist_init_norm:.2f}'
        )
        print(
            f'The total norm of feature cosine similarity before training is {cos_init_norm:.2f}'
        )
        print(
            f'The total norm of feature inner product before training is {prod_init_norm:.2f}'
        )

        save_plot(pdist_init, trainloader, name='pdist_before_training')
        save_plot(cos_init, trainloader, name='cosine_before_training')
        save_plot(prod_init, trainloader, name='prod_before_training')

    # Start training
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)

    optimizer = torch.optim.SGD(net.parameters(),
                                lr=config['lr'],
                                momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[60, 120, 160],
                                                     gamma=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()

    print(_timestamp())
    # Reuse trained weights if they load; otherwise train now and store them.
    after_path = ckpt_prefix + '_after.pth'
    try:
        net.load_state_dict(torch.load(after_path, map_location=device))
        print('Net loaded from file.')
    except Exception:  # deliberate load-or-train fallback
        dl.train(net,
                 optimizer,
                 scheduler,
                 loss_fn,
                 trainloader,
                 config,
                 path=None,
                 dryrun=args.dryrun)
        if not args.dryrun:
            torch.save(net.state_dict(), after_path)
            print('Net saved to file.')
        else:
            print(f'Would save to {after_path}')
    print(_timestamp())
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    param_norm_after = np.sqrt(
        np.sum(
            [p.pow(2).sum().detach().cpu().numpy() for p in net.parameters()]))
    print(f'The L2 norm of the parameter vector is {param_norm_after:.2f}')

    ntk_matrix_after = batch_wise_ntk(net,
                                      trainloader,
                                      samplesize=args.sampling)
    _save_heatmap(ntk_matrix_after, 'AFTER')
    ntk_matrix_after_norm = np.linalg.norm(ntk_matrix_after.flatten())
    print(
        f'The total norm of the NTK sample after training is {ntk_matrix_after_norm:.2f}'
    )

    ntk_matrix_diff = np.abs(ntk_matrix_before - ntk_matrix_after)
    _save_heatmap(ntk_matrix_diff, 'DIFF')
    ntk_matrix_diff_norm = np.linalg.norm(ntk_matrix_diff.flatten())
    print(
        f'The total norm of the NTK sample diff is {ntk_matrix_diff_norm:.2f}')

    # Element-wise relative change; 1e-4 guards against division by zero.
    ntk_matrix_rdiff = np.abs(ntk_matrix_before - ntk_matrix_after) / (
        np.abs(ntk_matrix_before) + 1e-4)
    _save_heatmap(ntk_matrix_rdiff, 'RDIFF')
    ntk_matrix_rdiff_norm = np.linalg.norm(ntk_matrix_rdiff.flatten())
    print(
        f'The total norm of the NTK sample relative diff is {ntk_matrix_rdiff_norm:.2f}'
    )

    # Product of the standardised matrices; its mean is the Pearson
    # correlation coefficient of the two NTK samples.
    n1_mean = np.mean(ntk_matrix_before)
    n2_mean = np.mean(ntk_matrix_after)
    matrix_corr = (ntk_matrix_before - n1_mean) * (ntk_matrix_after - n2_mean) / \
        np.std(ntk_matrix_before) / np.std(ntk_matrix_after)
    _save_heatmap(matrix_corr, 'CORR')
    corr_coeff = np.mean(matrix_corr)
    print(
        f'The Correlation coefficient of the NTK sample before and after training is {corr_coeff:.2f}'
    )

    # Normalised product; its sum is the cosine similarity of the flattened
    # NTK samples. Plotted to its own file (previously CORR was re-plotted
    # over CORR.png and this map was never saved).
    matrix_sim = (ntk_matrix_before * ntk_matrix_after) / \
        np.sqrt(np.sum(ntk_matrix_before**2) * np.sum(ntk_matrix_after**2))
    _save_heatmap(matrix_sim, 'SIM')
    corr_tom = np.sum(matrix_sim)
    print(
        f'The Similarity coefficient of the NTK sample before and after training is {corr_tom:.2f}'
    )

    save_output(args.table_path,
                name='ntk',
                width=config['width'],
                num_params=num_params,
                before_norm=ntk_matrix_before_norm,
                after_norm=ntk_matrix_after_norm,
                diff_norm=ntk_matrix_diff_norm,
                rdiff_norm=ntk_matrix_rdiff_norm,
                param_norm_before=param_norm_before,
                param_norm_after=param_norm_after,
                corr_coeff=corr_coeff,
                corr_tom=corr_tom)

    if args.pdist:
        # Check feature maps after training
        pdist_after, cos_after, prod_after = batch_feature_correlations(
            trainloader)

        pdist_after_norm = _mean_norm(pdist_after)
        cos_after_norm = _mean_norm(cos_after)
        prod_after_norm = _mean_norm(prod_after)
        print(
            f'The total norm of feature distances after training is {pdist_after_norm:.2f}'
        )
        print(
            f'The total norm of feature cosine similarity after training is {cos_after_norm:.2f}'
        )
        print(
            f'The total norm of feature inner product after training is {prod_after_norm:.2f}'
        )

        save_plot(pdist_after, trainloader, name='pdist_after_training')
        save_plot(cos_after, trainloader, name='cosine_after_training')
        save_plot(prod_after, trainloader, name='prod_after_training')

        # Differences normalised by the mean initial norm of each statistic.
        pdist_ndiff = [
            np.abs(co1 - co2) / pdist_init_norm
            for co1, co2 in zip(pdist_init, pdist_after)
        ]
        cos_ndiff = [
            np.abs(co1 - co2) / cos_init_norm
            for co1, co2 in zip(cos_init, cos_after)
        ]
        prod_ndiff = [
            np.abs(co1 - co2) / prod_init_norm
            for co1, co2 in zip(prod_init, prod_after)
        ]

        pdist_ndiff_norm = _mean_norm(pdist_ndiff)
        cos_ndiff_norm = _mean_norm(cos_ndiff)
        prod_ndiff_norm = _mean_norm(prod_ndiff)
        print(
            f'The total norm normalized diff of feature distances after training is {pdist_ndiff_norm:.2f}'
        )
        print(
            f'The total norm normalized diff of feature cosine similarity after training is {cos_ndiff_norm:.2f}'
        )
        print(
            f'The total norm normalized diff of feature inner product after training is {prod_ndiff_norm:.2f}'
        )

        save_plot(pdist_ndiff, trainloader, name='pdist_ndiff')
        save_plot(cos_ndiff, trainloader, name='cosine_ndiff')
        save_plot(prod_ndiff, trainloader, name='prod_ndiff')

        # Element-wise relative differences; 1e-6 guards against division by zero.
        pdist_rdiff = [
            np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
            for co1, co2 in zip(pdist_init, pdist_after)
        ]
        cos_rdiff = [
            np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
            for co1, co2 in zip(cos_init, cos_after)
        ]
        prod_rdiff = [
            np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
            for co1, co2 in zip(prod_init, prod_after)
        ]

        pdist_rdiff_norm = _mean_norm(pdist_rdiff)
        cos_rdiff_norm = _mean_norm(cos_rdiff)
        prod_rdiff_norm = _mean_norm(prod_rdiff)
        print(
            f'The total norm relative diff of feature distances after training is {pdist_rdiff_norm:.2f}'
        )
        print(
            f'The total norm relative diff of feature cosine similarity after training is {cos_rdiff_norm:.2f}'
        )
        print(
            f'The total norm relative diff of feature inner product after training is {prod_rdiff_norm:.2f}'
        )

        save_plot(pdist_rdiff, trainloader, name='pdist_rdiff')
        save_plot(cos_rdiff, trainloader, name='cosine_rdiff')
        save_plot(prod_rdiff, trainloader, name='prod_rdiff')

        # Each statistic is written to its own columns (cos_* / prod_* were
        # previously filled with the pdist values).
        save_output(args.table_path,
                    'pdist',
                    width=config['width'],
                    num_params=num_params,
                    pdist_init_norm=pdist_init_norm,
                    pdist_after_norm=pdist_after_norm,
                    pdist_ndiff_norm=pdist_ndiff_norm,
                    pdist_rdiff_norm=pdist_rdiff_norm,
                    cos_init_norm=cos_init_norm,
                    cos_after_norm=cos_after_norm,
                    cos_ndiff_norm=cos_ndiff_norm,
                    cos_rdiff_norm=cos_rdiff_norm,
                    prod_init_norm=prod_init_norm,
                    prod_after_norm=prod_after_norm,
                    prod_ndiff_norm=prod_ndiff_norm,
                    prod_rdiff_norm=prod_rdiff_norm)

    print(_timestamp())
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')
def main():
    """Compare a trained network with its stored initialisation.

    Relies on the module-level ``args`` and ``config`` objects (and ``device``
    for checkpoint loading). Steps:

    - Build the network selected by ``args.net``.
    - Load its initial weights from ``<path>Cifar10_<net><width>_before.pth``,
      or create and save them there.
    - Load the trained weights from ``..._after.pth``. If that fails, print
      the error and return without recording anything.
    - Otherwise record, via ``save_output(args.table_path, name='ntk_stats')``:
      the ``analyze_model`` metrics, parameter norms before/after, and three
      measures of how far the parameters moved.

    Raises:
        ValueError: if ``args.net`` is not a supported architecture.
    """
    def _timestamp():
        # Wall-clock time in the format used for all progress messages.
        return datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p")

    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(
        f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()}'
    )
    print(_timestamp())

    trainloader, testloader = dl.get_loaders('CIFAR10',
                                             config['batch_size'],
                                             augmentations=False,
                                             shuffle=False)

    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2],
                         widen_factor=config['width'])
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16,
                          widen_factor=config['width'],
                          dropout_rate=0.0,
                          num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(
            OrderedDict([
                ('flatten', torch.nn.Flatten()),
                ('linear0', torch.nn.Linear(3072, config['width'])),
                ('relu0', torch.nn.ReLU()),
                ('linear1', torch.nn.Linear(config['width'], config['width'])),
                ('relu1', torch.nn.ReLU()),
                ('linear2', torch.nn.Linear(config['width'], config['width'])),
                ('relu2', torch.nn.ReLU()),
                ('linear3', torch.nn.Linear(config['width'], 10))
            ]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(
            OrderedDict([('flatten', torch.nn.Flatten()),
                         ('linear0', torch.nn.Linear(3072, config['width'])),
                         ('relu0', torch.nn.ReLU()),
                         ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10,
                          width_mult=config['width'],
                          round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [
            64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'
        ]
        # NOTE(review): the isinstance filter drops the 'M' entries, so
        # make_layers never sees them. Kept as-is so stored checkpoints still
        # match -- confirm intent.
        cfg = [c * config['width'] for c in cfg_base if isinstance(c, int)]
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * config['width'],
                                            4096)
    elif args.net == 'ConvNet':
        net = torch.nn.Sequential(
            OrderedDict([
                ('conv0',
                 torch.nn.Conv2d(3,
                                 1 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu0', torch.nn.ReLU()),
                # ('pool0', torch.nn.MaxPool2d(3)),
                ('conv1',
                 torch.nn.Conv2d(1 * config['width'],
                                 2 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu1', torch.nn.ReLU()),
                #  ('pool1', torch.nn.MaxPool2d(3)),
                ('conv2',
                 torch.nn.Conv2d(2 * config['width'],
                                 2 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu2', torch.nn.ReLU()),
                # ('pool2', torch.nn.MaxPool2d(3)),
                ('conv3',
                 torch.nn.Conv2d(2 * config['width'],
                                 4 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu3', torch.nn.ReLU()),
                ('pool3', torch.nn.MaxPool2d(3)),
                ('conv4',
                 torch.nn.Conv2d(4 * config['width'],
                                 4 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu4', torch.nn.ReLU()),
                ('pool4', torch.nn.MaxPool2d(3)),
                ('flatten', torch.nn.Flatten()),
                ('linear', torch.nn.Linear(36 * config['width'], 10))
            ]))
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    ckpt_prefix = config['path'] + 'Cifar10_' + args.net + str(config["width"])

    # Reuse a stored initialisation if one loads. On any failure (missing
    # file, mismatched keys, ...) store the current initialisation instead.
    before_path = ckpt_prefix + '_before.pth'
    try:
        net.load_state_dict(torch.load(before_path, map_location=device))
        print('Initialized net loaded from file.')
    except Exception:  # deliberate load-or-create fallback
        if not args.dryrun:
            torch.save(net.state_dict(), before_path)
            print('Initialized net saved to file.')
        else:
            print(f'Would save to {before_path}')

    num_params = sum([p.numel() for p in net.parameters()])
    print(
        f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
        f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')
    param_norm_before = np.sqrt(
        np.sum(
            [p.pow(2).sum().detach().cpu().numpy() for p in net.parameters()]))
    print(f'The L2 norm of the parameter vector is {param_norm_before:.2f}')

    # Snapshot of the initial parameters, taken from the unwrapped module so
    # its order matches ``net.parameters()`` after DataParallel is removed.
    net_init = [p.detach().clone() for p in net.parameters()]

    # Start training
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)

    optimizer = torch.optim.SGD(net.parameters(),
                                lr=config['lr'],
                                momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[60, 120, 160],
                                                     gamma=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()
    # NOTE(review): this evaluates the network at its initial weights and
    # discards the result -- confirm whether the call is still needed.
    analyze_model(net, trainloader, testloader, loss_fn, config)
    print(_timestamp())
    try:
        net.load_state_dict(
            torch.load(ckpt_prefix + '_after.pth', map_location=device))
        print('Net loaded from file.')
    except Exception as e:  # no trained checkpoint: nothing to analyse
        print(repr(e))
        print('Could not find model data ... aborting ...')
        return
    print(_timestamp())
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    param_norm_after = np.sqrt(
        np.sum(
            [p.pow(2).sum().detach().cpu().numpy() for p in net.parameters()]))
    print(f'The L2 norm of the parameter vector is {param_norm_after:.2f}')

    # Parameter movement between initialisation and trained weights.
    # no_grad keeps the differences out of autograd: ``.numpy()`` below
    # raises on tensors that require grad, which ``(p1 - p2).norm()``
    # on a live Parameter otherwise would.
    change_total = 0.0   # sqrt of sum over tensors of summed squared diffs
    change_rel = 0.0     # sqrt of sum over tensors of mean squared diffs
    change_nrmsum = 0.0  # sum over tensors of the L2 norm of the diff
    with torch.no_grad():
        for p_init, p_now in zip(net_init, net.parameters()):
            delta = p_init - p_now
            change_total += delta.pow(2).sum()
            change_rel += delta.pow(2).mean()
            change_nrmsum += delta.norm()
    change_total = change_total.sqrt().cpu().numpy()
    change_rel = change_rel.sqrt().cpu().numpy()
    change_nrmsum = change_nrmsum.cpu().numpy()

    # Analyze results
    acc_train, acc_test, loss_train, loss_trainw, grd_train = analyze_model(
        net, trainloader, testloader, loss_fn, config)

    save_output(args.table_path,
                name='ntk_stats',
                width=config['width'],
                num_params=num_params,
                acc_train=acc_train,
                acc_test=acc_test,
                loss_train=loss_train,
                loss_trainw=loss_trainw,
                grd_train=grd_train,
                param_norm_before=param_norm_before,
                param_norm_after=param_norm_after,
                change_total=change_total,
                change_rel=change_rel,
                change_nrmsum=change_nrmsum)

    print(_timestamp())
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')
예제 #8
0
def main():
    """Train (or load) a CIFAR-10 network and record its statistics via ``save_output``.

    Reads the module-level ``args`` (``net``, ``width``, ``dryrun``, ``table_path``)
    and ``config`` (``batch_size``, ``width``, ``setup``, ``path``, ``lr``,
    ``weight_decay``). The initial and the trained weights are cached as
    ``<path>Cifar10_<net><width>_before.pth`` / ``..._after.pth``. An existing,
    loadable checkpoint is reused; otherwise the current weights are kept (before)
    or the net is trained (after) and the checkpoint is written. With
    ``args.dryrun`` set, checkpoints are never written.
    """
    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    trainloader, testloader = dl.get_loaders('CIFAR10', config['batch_size'], augmentations=False, shuffle=False)

    width = config['width']
    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2], widen_factor=width)
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16, widen_factor=width, dropout_rate=0.0, num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, width)),
            ('relu0', torch.nn.ReLU()),
            ('linear1', torch.nn.Linear(width, width)),
            ('relu1', torch.nn.ReLU()),
            ('linear2', torch.nn.Linear(width, width)),
            ('relu2', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(width, 10)),
        ]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, width)),
            ('relu0', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(width, 10)),
        ]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10, width_mult=width, round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
        # Scale only the channel counts and keep the 'M' (max-pool) markers in place;
        # filtering with `if isinstance(c, int)` used to drop every pooling layer.
        # NOTE: VGG checkpoints cached before this fix have no pooling layers and will not load.
        cfg = [c * width if isinstance(c, int) else c for c in cfg_base]
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * width, 4096)
    elif args.net == 'ConvNet':
        # 32x32 input -> 10x10 after pool3 -> 3x3 after pool4, so the head sees 4 * width * 3 * 3 features.
        net = torch.nn.Sequential(OrderedDict([
            ('conv0', torch.nn.Conv2d(3, 1 * width, kernel_size=3, padding=1)),
            ('relu0', torch.nn.ReLU()),
            ('conv1', torch.nn.Conv2d(1 * width, 2 * width, kernel_size=3, padding=1)),
            ('relu1', torch.nn.ReLU()),
            ('conv2', torch.nn.Conv2d(2 * width, 2 * width, kernel_size=3, padding=1)),
            ('relu2', torch.nn.ReLU()),
            ('conv3', torch.nn.Conv2d(2 * width, 4 * width, kernel_size=3, padding=1)),
            ('relu3', torch.nn.ReLU()),
            ('pool3', torch.nn.MaxPool2d(3)),
            ('conv4', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
            ('relu4', torch.nn.ReLU()),
            ('pool4', torch.nn.MaxPool2d(3)),
            ('flatten', torch.nn.Flatten()),
            ('linear', torch.nn.Linear(36 * width, 10)),
        ]))
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    # NOTE(review): `device` (used as map_location below) is not defined in this function;
    # presumably a module-level global.
    ckpt_prefix = config['path'] + 'Cifar10_' + args.net + str(config["width"])

    before_path = ckpt_prefix + '_before.pth'
    try:
        net.load_state_dict(torch.load(before_path, map_location=device))
        print('Initialized net loaded from file.')
    except Exception as exc:  # missing or incompatible checkpoint: keep (and cache) this init
        print(f'Could not load {before_path} ({type(exc).__name__}: {exc}).')
        if not args.dryrun:
            torch.save(net.state_dict(), before_path)
            print('Initialized net saved to file.')
        else:
            print(f'Would save to {before_path}')

    num_params = sum(p.numel() for p in net.parameters())
    print(f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
          f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')

    # Start training
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)

    optimizer = torch.optim.SGD(net.parameters(), lr=config['lr'], momentum=0.9, weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    after_path = ckpt_prefix + '_after.pth'
    try:
        net.load_state_dict(torch.load(after_path, map_location=device))
        print('Net loaded from file.')
    except Exception as exc:  # missing or incompatible checkpoint: train from scratch
        print(f'Could not load {after_path} ({type(exc).__name__}: {exc}); training.')
        dl.train(net, optimizer, scheduler, loss_fn, trainloader, config, path=None, dryrun=args.dryrun)
        if not args.dryrun:
            torch.save(net.state_dict(), after_path)
            print('Net saved to file.')
        else:
            print(f'Would save to {after_path}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    # NOTE(review): none of ntk_matrix_*_norm, param_norm_before/after, corr_coeff or corr_tom
    # is defined in this function; unless they are module-level globals this call raises
    # NameError. The NTK computation that should produce them appears to be missing here.
    save_output(args.table_path, name='ntk', width=config['width'], num_params=num_params,
                before_norm=ntk_matrix_before_norm, after_norm=ntk_matrix_after_norm,
                diff_norm=ntk_matrix_diff_norm, rdiff_norm=ntk_matrix_rdiff_norm,
                param_norm_before=param_norm_before, param_norm_after=param_norm_after,
                corr_coeff=corr_coeff, corr_tom=corr_tom)

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')
def main():
    """Measure how the features fed into the network head change with training.

    Builds the network selected by ``args.net`` at width ``config['width']``.
    For every training batch it records three maps of the inputs to the head
    module (``linear``, ``linear3`` or ``classifier``):
    - pairwise Euclidean distance;
    - cosine similarity;
    - mean inner product.
    It then trains the net, or loads ``<path>Cifar10_<net><width>.pth``, and
    records the maps again. The mean norms of the before/after maps and of
    their normalized and relative differences are printed, plotted and passed
    to ``save_output``. The checkpoint and the raw maps are only written when
    ``args.dryrun`` is not set.
    """
    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    trainloader, testloader = dl.get_loaders('CIFAR10',
                                             config['batch_size'],
                                             augmentations=False,
                                             shuffle=False)

    width = config['width']
    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2], widen_factor=width)
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16,
                          widen_factor=width,
                          dropout_rate=0.0,
                          num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(
            OrderedDict([
                ('flatten', torch.nn.Flatten()),
                ('linear0', torch.nn.Linear(3072, width)),
                ('relu0', torch.nn.ReLU()),
                ('linear1', torch.nn.Linear(width, width)),
                ('relu1', torch.nn.ReLU()),
                ('linear2', torch.nn.Linear(width, width)),
                ('relu2', torch.nn.ReLU()),
                ('linear3', torch.nn.Linear(width, 10)),
            ]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(
            OrderedDict([
                ('flatten', torch.nn.Flatten()),
                ('linear0', torch.nn.Linear(3072, width)),
                ('relu0', torch.nn.ReLU()),
                ('linear3', torch.nn.Linear(width, 10)),
            ]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10, width_mult=width, round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [
            64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'
        ]
        # Scale only the channel counts and keep the 'M' (max-pool) markers in place;
        # filtering with `if isinstance(c, int)` used to drop every pooling layer.
        # NOTE: VGG checkpoints cached before this fix have no pooling layers and will not load.
        cfg = [c * width if isinstance(c, int) else c for c in cfg_base]
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * width, 4096)
    elif args.net == 'ConvNet':
        # 32x32 input -> 10x10 after pool3 -> 3x3 after pool4, so the head sees 4 * width * 3 * 3 features.
        net = torch.nn.Sequential(
            OrderedDict([
                ('conv0', torch.nn.Conv2d(3, 1 * width, kernel_size=3, padding=1)),
                ('relu0', torch.nn.ReLU()),
                ('conv1', torch.nn.Conv2d(1 * width, 2 * width, kernel_size=3, padding=1)),
                ('relu1', torch.nn.ReLU()),
                ('conv2', torch.nn.Conv2d(2 * width, 2 * width, kernel_size=3, padding=1)),
                ('relu2', torch.nn.ReLU()),
                ('conv3', torch.nn.Conv2d(2 * width, 4 * width, kernel_size=3, padding=1)),
                ('relu3', torch.nn.ReLU()),
                ('pool3', torch.nn.MaxPool2d(3)),
                ('conv4', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
                ('relu4', torch.nn.ReLU()),
                ('pool4', torch.nn.MaxPool2d(3)),
                ('flatten', torch.nn.Flatten()),
                ('linear', torch.nn.Linear(36 * width, 10)),
            ]))
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    num_params = sum(p.numel() for p in net.parameters())
    print(
        f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
        f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')

    def batch_feature_correlations(dataloader, device=torch.device('cpu')):
        """Return per-batch (distance, cosine, mean-inner-product) maps of the head inputs.

        Moves ``net`` to ``device`` and switches it to eval mode. Only the first
        batch is processed when ``args.dryrun`` is set.
        """
        net.eval()
        net.to(device)
        dist_maps, cosine_maps, prod_maps = [], [], []

        def batch_wise_feature_correlation(module, inputs, output):
            # Use the actual batch size: the last batch may be smaller than
            # dataloader.batch_size, and view(batch_size, -1) would then either
            # fail or silently mix samples across rows.
            # NOTE(review): maps of a short last batch are smaller; confirm save_plot copes.
            feat_vec = inputs[0].detach().flatten(start_dim=1)
            dist_maps.append(torch.cdist(feat_vec, feat_vec, 2).cpu().numpy())
            # Row i / column j = cosine similarity resp. mean elementwise product of
            # samples i and j, computed as one matmul instead of a per-row loop.
            # Stored as float64 like the previous np.empty-based maps.
            normed = torch.nn.functional.normalize(feat_vec, p=2, dim=1, eps=1e-8)
            cosine_maps.append((normed @ normed.t()).double().cpu().numpy())
            prod_maps.append((feat_vec @ feat_vec.t() /
                              feat_vec.shape[1]).double().cpu().numpy())

        # Resolve the head on the underlying model so DataParallel hooks the same layer.
        model = net.module if isinstance(net, torch.nn.DataParallel) else net
        if args.net in ['MLP', 'TwoLP']:
            head = model.linear3
        elif args.net in ['VGG', 'MobileNetV2']:
            head = model.classifier
        else:
            head = model.linear
        hook = head.register_forward_hook(batch_wise_feature_correlation)
        try:
            with torch.no_grad():  # features are only inspected; no autograd graph needed
                for inputs, _ in dataloader:
                    net(inputs.to(device))
                    if args.dryrun:
                        break
        finally:
            hook.remove()

        return dist_maps, cosine_maps, prod_maps

    # (description used in log lines, prefix used for plot names), in (pdist, cos, prod) order.
    map_kinds = (('feature distances', 'pdist'),
                 ('feature cosine similarity', 'cosine'),
                 ('feature inner product', 'prod'))

    def mean_norm(maps):
        """Average Frobenius norm over a list of per-batch maps."""
        return np.mean([np.linalg.norm(m.flatten()) for m in maps])

    def summarize(all_maps, qualifier, stage, plot_suffix):
        """Print and plot the mean norm of the (pdist, cos, prod) map lists; return the norms."""
        norms = [mean_norm(maps) for maps in all_maps]
        for (label, _), value in zip(map_kinds, norms):
            print(f'The total norm {qualifier}of {label} {stage} is {value:.2f}')
        for (_, prefix), maps in zip(map_kinds, all_maps):
            save_plot(maps, trainloader, name=f'{prefix}_{plot_suffix}')
        return norms

    init_maps = batch_feature_correlations(trainloader)
    pdist_init, cos_init, prod_init = init_maps
    pdist_init_norm, cos_init_norm, prod_init_norm = summarize(
        init_maps, '', 'before training', 'before_training')

    # Start training
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)

    optimizer = torch.optim.SGD(net.parameters(),
                                lr=config['lr'],
                                momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[60, 120, 160],
                                                     gamma=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    # NOTE(review): `device` (map_location) is not defined in main; presumably a module-level global.
    ckpt_prefix = config['path'] + 'Cifar10_' + args.net + str(config["width"])
    ckpt_path = ckpt_prefix + '.pth'
    try:
        net.load_state_dict(torch.load(ckpt_path, map_location=device))
        print('Net loaded from file.')
    except Exception as exc:  # missing or incompatible checkpoint: train from scratch
        print(f'Could not load {ckpt_path} ({type(exc).__name__}: {exc}); training.')
        dl.train(net,
                 optimizer,
                 scheduler,
                 loss_fn,
                 trainloader,
                 config,
                 path=None,
                 dryrun=args.dryrun)
        if not args.dryrun:
            torch.save(net.state_dict(), ckpt_path)
            print('Net saved to file.')
        else:
            print(f'Would save to {ckpt_path}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    # Check feature maps after training
    after_maps = batch_feature_correlations(trainloader)
    pdist_after, cos_after, prod_after = after_maps
    pdist_after_norm, cos_after_norm, prod_after_norm = summarize(
        after_maps, '', 'after training', 'after_training')

    # Differences normalized by the mean norm of the corresponding initial maps
    ndiff_maps = [
        [np.abs(m0 - m1) / norm0 for m0, m1 in zip(maps0, maps1)]
        for maps0, maps1, norm0 in zip(
            init_maps, after_maps,
            (pdist_init_norm, cos_init_norm, prod_init_norm))
    ]
    pdist_ndiff, cos_ndiff, prod_ndiff = ndiff_maps
    pdist_ndiff_norm, cos_ndiff_norm, prod_ndiff_norm = summarize(
        ndiff_maps, 'normalized diff ', 'after training', 'ndiff')

    # Elementwise relative differences (1e-6 guards against division by zero)
    rdiff_maps = [
        [np.abs(m0 - m1) / (np.abs(m0) + 1e-6) for m0, m1 in zip(maps0, maps1)]
        for maps0, maps1 in zip(init_maps, after_maps)
    ]
    pdist_rdiff, cos_rdiff, prod_rdiff = rdiff_maps
    pdist_rdiff_norm, cos_rdiff_norm, prod_rdiff_norm = summarize(
        rdiff_maps, 'relative diff ', 'after training', 'rdiff')

    # Each family reports its own norms (cos_*/prod_* previously received the pdist values).
    save_output(args.table_path,
                width=config['width'],
                num_params=num_params,
                pdist_init_norm=pdist_init_norm,
                pdist_after_norm=pdist_after_norm,
                pdist_ndiff_norm=pdist_ndiff_norm,
                pdist_rdiff_norm=pdist_rdiff_norm,
                cos_init_norm=cos_init_norm,
                cos_after_norm=cos_after_norm,
                cos_ndiff_norm=cos_ndiff_norm,
                cos_rdiff_norm=cos_rdiff_norm,
                prod_init_norm=prod_init_norm,
                prod_after_norm=prod_after_norm,
                prod_ndiff_norm=prod_ndiff_norm,
                prod_rdiff_norm=prod_rdiff_norm)

    # Save raw data
    raw_pkg = dict(pdist_init=pdist_init,
                   cos_init=cos_init,
                   prod_init=prod_init,
                   pdist_after=pdist_after,
                   cos_after=cos_after,
                   prod_after=prod_after,
                   pdist_ndiff=pdist_ndiff,
                   cos_ndiff=cos_ndiff,
                   prod_ndiff=prod_ndiff,
                   pdist_rdiff=pdist_rdiff,
                   cos_rdiff=cos_rdiff,
                   prod_rdiff=prod_rdiff)
    path = ckpt_prefix + '_rawmaps.pth'
    # A dry run only sees one batch; don't let it overwrite the maps of a real run.
    if not args.dryrun:
        torch.save(raw_pkg, path)
    else:
        print(f'Would save to {path}')

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')