def evaluation_procedure(config):
    """Train model and evaluate eigenvalues with given configuration."""
    # Setup data
    augmentations = False
    trainloader, testloader = dl.get_loaders('CIFAR10',
                                             config['batch_size'],
                                             augmentations=augmentations,
                                             normalize=True,
                                             shuffle=False)

    if config['model'] == 'MLP':
        net = torch.nn.Sequential(
            OrderedDict([('flatten', torch.nn.Flatten()),
                         ('linear0', torch.nn.Linear(3072, 2048)),
                         ('relu0', torch.nn.ReLU()),
                         ('linear1', torch.nn.Linear(2048, 2048)),
                         ('relu1', torch.nn.ReLU()),
                         ('linear2', torch.nn.Linear(2048, 1024)),
                         ('relu2', torch.nn.ReLU()),
                         ('linear3', torch.nn.Linear(1024, 10))]))
    elif config['model'] == 'ResNet':
        net = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=10)
    else:
        raise NotImplementedError()

    linear_classifier = torch.nn.Sequential(torch.nn.Flatten(),
                                            torch.nn.Linear(3072, 10))
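    # A linear probe: one linear layer acting on the flattened
    # 3 * 32 * 32 = 3072 raw CIFAR-10 pixels.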

    linear_classifier.to(**config['setup'])
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)
        linear_classifier = torch.nn.DataParallel(linear_classifier)
    net.eval()

    # Optimizer and loss
    optimizer = torch.optim.SGD(linear_classifier.parameters(),
                                lr=config['lr'],
                                momentum=0.9,
                                weight_decay=config['weight_decay'])
    config['epochs'] = config['epochs_linear']
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[50, 75, 85, 95], gamma=0.1)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Check initial model
    analyze_model(linear_classifier, trainloader, testloader, loss_fn, config)

    linear_classifier.to(**config['setup'])
    net.to(**config['setup'])
    # Train
    print('Starting training linear classifier ...')
    dl.train(linear_classifier,
             optimizer,
             scheduler,
             loss_fn,
             trainloader,
             config,
             dryrun=args.dryrun)
    # Analyze results
    print('----Results after training linear classifier ------------')
    analyze_model(linear_classifier, trainloader, testloader, loss_fn, config)
    for name, param in linear_classifier.named_parameters():
        dprint(name, param)
        param.requires_grad = False  # freeze the trained classifier for distillation
    # Check full model
    print('----Distill learned classifier onto network ------------')
    config['epochs'] = config['epochs_distill']
    loss_distill = torch.nn.KLDivLoss(reduction='batchmean')
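    # KLDivLoss with reduction='batchmean' expects log-probabilities as input
    # and probabilities as target; dl.distill is assumed to apply log_softmax
    # to the student and softmax to the teacher logits accordingly.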
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=config['lr'],
                                momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[120, 180, 240], gamma=0.2)
    dl.distill(linear_classifier,
               net,
               optimizer,
               scheduler,
               loss_distill,
               trainloader,
               config,
               dryrun=args.dryrun)

    # Analyze results
    analyze_model(net, trainloader, testloader, loss_fn, config)
def main():
    """Check ntks in a single call."""
    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    trainloader, testloader = dl.get_loaders('CIFAR10',
                                             config['batch_size'],
                                             augmentations=False,
                                             shuffle=False)

    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2],
                         widen_factor=config['width'])
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16,
                          widen_factor=config['width'],
                          dropout_rate=0.0,
                          num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(
            OrderedDict([
                ('flatten', torch.nn.Flatten()),
                ('linear0', torch.nn.Linear(3072, config['width'])),
                ('relu0', torch.nn.ReLU()),
                ('linear1', torch.nn.Linear(config['width'], config['width'])),
                ('relu1', torch.nn.ReLU()),
                ('linear2', torch.nn.Linear(config['width'], config['width'])),
                ('relu2', torch.nn.ReLU()),
                ('linear3', torch.nn.Linear(config['width'], 10))
            ]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(
            OrderedDict([('flatten', torch.nn.Flatten()),
                         ('linear0', torch.nn.Linear(3072, config['width'])),
                         ('relu0', torch.nn.ReLU()),
                         ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10,
                          width_mult=config['width'],
                          round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [
            64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'
        ]
        cfg = [c * config['width'] for c in cfg_base if isinstance(c, int)]
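        # Note: the isinstance filter also drops the 'M' (max-pool) entries,
        # so make_layers receives a pooling-free, width-scaled channel list.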
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * config['width'],
                                            4096)
    elif args.net == 'ConvNet':
        net = torch.nn.Sequential(
            OrderedDict([
                ('conv0',
                 torch.nn.Conv2d(3,
                                 1 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu0', torch.nn.ReLU()),
                # ('pool0', torch.nn.MaxPool2d(3)),
                ('conv1',
                 torch.nn.Conv2d(1 * config['width'],
                                 2 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu1', torch.nn.ReLU()),
                #  ('pool1', torch.nn.MaxPool2d(3)),
                ('conv2',
                 torch.nn.Conv2d(2 * config['width'],
                                 2 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu2', torch.nn.ReLU()),
                # ('pool2', torch.nn.MaxPool2d(3)),
                ('conv3',
                 torch.nn.Conv2d(2 * config['width'],
                                 4 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu3', torch.nn.ReLU()),
                ('pool3', torch.nn.MaxPool2d(3)),
                ('conv4',
                 torch.nn.Conv2d(4 * config['width'],
                                 4 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu4', torch.nn.ReLU()),
                ('pool4', torch.nn.MaxPool2d(3)),
                ('flatten', torch.nn.Flatten()),
                ('linear', torch.nn.Linear(36 * config['width'], 10))
            ]))
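        # Classifier input size: the two MaxPool2d(3) stages shrink the 32x32
        # CIFAR maps to 10x10 and then 3x3, so flattening the 4 * width
        # channels gives 4 * width * 3 * 3 = 36 * width features.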
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    try:
        net.load_state_dict(
            torch.load(config['path'] + 'Cifar10_' + args.net +
                       str(config["width"]) + '_before.pth',
                       map_location=device))
        print('Initialized net loaded from file.')
    except Exception:  # loading failed, e.g. no file yet; keep and save the fresh initialization
        path = config['path'] + 'Cifar10_' + args.net + str(
            config["width"]) + '_before.pth'
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Initialized net saved to file.')
        else:
            print(f'Would save to {path}')

    num_params = sum([p.numel() for p in net.parameters()])
    print(
        f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
        f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')

    ntk_matrix_before = batch_wise_ntk(net,
                                       trainloader,
                                       samplesize=args.sampling)
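    # batch_wise_ntk presumably samples the empirical NTK Gram matrix
    # K[i, j] = <grad_theta f(x_i), grad_theta f(x_j)> over `samplesize`
    # training points (inferred from its name and how the result is used).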
    plt.imshow(ntk_matrix_before)
    plt.savefig(config['path'] +
                f'{args.net}{config["width"]}_CIFAR_NTK_BEFORE.png',
                bbox_inches='tight',
                dpi=1200)
    ntk_matrix_before_norm = np.linalg.norm(ntk_matrix_before.flatten())
    print(
        f'The total norm of the NTK sample before training is {ntk_matrix_before_norm:.2f}'
    )
    param_norm_before = np.sqrt(
        np.sum(
            [p.pow(2).sum().detach().cpu().numpy() for p in net.parameters()]))
    print(f'The L2 norm of the parameter vector is {param_norm_before:.2f}')

    if args.pdist:
        pdist_init, cos_init, prod_init = batch_feature_correlations(
            trainloader)
        pdist_init_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in pdist_init])
        cos_init_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in cos_init])
        prod_init_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in prod_init])
        print(
            f'The total norm of feature distances before training is {pdist_init_norm:.2f}'
        )
        print(
            f'The total norm of feature cosine similarity before training is {cos_init_norm:.2f}'
        )
        print(
            f'The total norm of feature inner product before training is {prod_init_norm:.2f}'
        )

        save_plot(pdist_init, trainloader, name='pdist_before_training')
        save_plot(cos_init, trainloader, name='cosine_before_training')
        save_plot(prod_init, trainloader, name='prod_before_training')

    # Start training
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)

    optimizer = torch.optim.SGD(net.parameters(),
                                lr=config['lr'],
                                momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[60, 120, 160],
                                                     gamma=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    try:
        net.load_state_dict(
            torch.load(config['path'] + 'Cifar10_' + args.net +
                       str(config["width"]) + '_after.pth',
                       map_location=device))
        print('Net loaded from file.')
    except Exception:  # no trained checkpoint found; train from scratch
        path = config['path'] + 'Cifar10_' + args.net + str(
            config["width"]) + '_after.pth'
        dl.train(net,
                 optimizer,
                 scheduler,
                 loss_fn,
                 trainloader,
                 config,
                 path=None,
                 dryrun=args.dryrun)
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Net saved to file.')
        else:
            print(f'Would save to {path}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    param_norm_after = np.sqrt(
        np.sum(
            [p.pow(2).sum().detach().cpu().numpy() for p in net.parameters()]))
    print(f'The L2 norm of the parameter vector is {param_norm_after:.2f}')

    ntk_matrix_after = batch_wise_ntk(net,
                                      trainloader,
                                      samplesize=args.sampling)
    plt.imshow(ntk_matrix_after)
    plt.savefig(config['path'] +
                f'{args.net}{config["width"]}_CIFAR_NTK_AFTER.png',
                bbox_inches='tight',
                dpi=1200)
    ntk_matrix_after_norm = np.linalg.norm(ntk_matrix_after.flatten())
    print(
        f'The total norm of the NTK sample after training is {ntk_matrix_after_norm:.2f}'
    )

    ntk_matrix_diff = np.abs(ntk_matrix_before - ntk_matrix_after)
    plt.imshow(ntk_matrix_diff)
    plt.savefig(config['path'] +
                f'{args.net}{config["width"]}_CIFAR_NTK_DIFF.png',
                bbox_inches='tight',
                dpi=1200)
    ntk_matrix_diff_norm = np.linalg.norm(ntk_matrix_diff.flatten())
    print(
        f'The total norm of the NTK sample diff is {ntk_matrix_diff_norm:.2f}')

    ntk_matrix_rdiff = np.abs(ntk_matrix_before - ntk_matrix_after) / (
        np.abs(ntk_matrix_before) + 1e-4)
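    # The 1e-4 in the denominator guards against division by near-zero NTK entries.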
    plt.imshow(ntk_matrix_rdiff)
    plt.savefig(config['path'] +
                f'{args.net}{config["width"]}_CIFAR_NTK_RDIFF.png',
                bbox_inches='tight',
                dpi=1200)
    ntk_matrix_rdiff_norm = np.linalg.norm(ntk_matrix_rdiff.flatten())
    print(
        f'The total norm of the NTK sample relative diff is {ntk_matrix_rdiff_norm:.2f}'
    )

    n1_mean = np.mean(ntk_matrix_before)
    n2_mean = np.mean(ntk_matrix_after)
    matrix_corr = (ntk_matrix_before - n1_mean) * (ntk_matrix_after - n2_mean) / \
        np.std(ntk_matrix_before) / np.std(ntk_matrix_after)
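    # Element-wise Pearson-style correlation map of the two NTK samples; its
    # mean is reported below as a scalar correlation coefficient.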
    plt.imshow(matrix_corr)
    plt.savefig(config['path'] +
                f'{args.net}{config["width"]}_CIFAR_NTK_CORR.png',
                bbox_inches='tight',
                dpi=1200)
    corr_coeff = np.mean(matrix_corr)
    print(
        f'The Correlation coefficient of the NTK sample before and after training is {corr_coeff:.2f}'
    )

    matrix_sim = (ntk_matrix_before * ntk_matrix_after) / \
        np.sqrt(np.sum(ntk_matrix_before**2) * np.sum(ntk_matrix_after**2))
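    # matrix_sim sums to the cosine similarity of the flattened NTK samples,
    # <K_before, K_after> / (||K_before|| * ||K_after||), so corr_tom lies in [-1, 1].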
    plt.imshow(matrix_sim)
    plt.savefig(config['path'] +
                f'{args.net}{config["width"]}_CIFAR_NTK_SIM.png',
                bbox_inches='tight',
                dpi=1200)
    corr_tom = np.sum(matrix_sim)
    print(
        f'The Similarity coefficient of the NTK sample before and after training is {corr_tom:.2f}'
    )

    save_output(args.table_path,
                name='ntk',
                width=config['width'],
                num_params=num_params,
                before_norm=ntk_matrix_before_norm,
                after_norm=ntk_matrix_after_norm,
                diff_norm=ntk_matrix_diff_norm,
                rdiff_norm=ntk_matrix_rdiff_norm,
                param_norm_before=param_norm_before,
                param_norm_after=param_norm_after,
                corr_coeff=corr_coeff,
                corr_tom=corr_tom)

    if args.pdist:
        # Check feature maps after training
        pdist_after, cos_after, prod_after = batch_feature_correlations(
            trainloader)

        pdist_after_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in pdist_after])
        cos_after_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in cos_after])
        prod_after_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in prod_after])
        print(
            f'The total norm of feature distances after training is {pdist_after_norm:.2f}'
        )
        print(
            f'The total norm of feature cosine similarity after training is {cos_after_norm:.2f}'
        )
        print(
            f'The total norm of feature inner product after training is {prod_after_norm:.2f}'
        )

        save_plot(pdist_after, trainloader, name='pdist_after_training')
        save_plot(cos_after, trainloader, name='cosine_after_training')
        save_plot(prod_after, trainloader, name='prod_after_training')

        # Check feature map differences
        pdist_ndiff = [
            np.abs(co1 - co2) / pdist_init_norm
            for co1, co2 in zip(pdist_init, pdist_after)
        ]
        cos_ndiff = [
            np.abs(co1 - co2) / cos_init_norm
            for co1, co2 in zip(cos_init, cos_after)
        ]
        prod_ndiff = [
            np.abs(co1 - co2) / prod_init_norm
            for co1, co2 in zip(prod_init, prod_after)
        ]

        pdist_ndiff_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in pdist_ndiff])
        cos_ndiff_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in cos_ndiff])
        prod_ndiff_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in prod_ndiff])
        print(
            f'The total norm normalized diff of feature distances after training is {pdist_ndiff_norm:.2f}'
        )
        print(
            f'The total norm normalized diff of feature cosine similarity after training is {cos_ndiff_norm:.2f}'
        )
        print(
            f'The total norm normalized diff of feature inner product after training is {prod_ndiff_norm:.2f}'
        )

        save_plot(pdist_ndiff, trainloader, name='pdist_ndiff')
        save_plot(cos_ndiff, trainloader, name='cosine_ndiff')
        save_plot(prod_ndiff, trainloader, name='prod_ndiff')

        # Check feature map differences
        pdist_rdiff = [
            np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
            for co1, co2 in zip(pdist_init, pdist_after)
        ]
        cos_rdiff = [
            np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
            for co1, co2 in zip(cos_init, cos_after)
        ]
        prod_rdiff = [
            np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
            for co1, co2 in zip(prod_init, prod_after)
        ]

        pdist_rdiff_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in pdist_rdiff])
        cos_rdiff_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in cos_rdiff])
        prod_rdiff_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in prod_rdiff])
        print(
            f'The total norm relative diff of feature distances after training is {pdist_rdiff_norm:.2f}'
        )
        print(
            f'The total norm relative diff of feature cosine similarity after training is {cos_rdiff_norm:.2f}'
        )
        print(
            f'The total norm relative diff of feature inner product after training is {prod_rdiff_norm:.2f}'
        )

        save_plot(pdist_rdiff, trainloader, name='pdist_rdiff')
        save_plot(cos_rdiff, trainloader, name='cosine_rdiff')
        save_plot(prod_rdiff, trainloader, name='prod_rdiff')

        save_output(args.table_path,
                    'pdist',
                    width=config['width'],
                    num_params=num_params,
                    pdist_init_norm=pdist_init_norm,
                    pdist_after_norm=pdist_after_norm,
                    pdist_ndiff_norm=pdist_ndiff_norm,
                    pdist_rdiff_norm=pdist_rdiff_norm,
                    cos_init_norm=cos_init_norm,
                    cos_after_norm=cos_after_norm,
                    cos_ndiff_norm=cos_ndiff_norm,
                    cos_rdiff_norm=cos_rdiff_norm,
                    prod_init_norm=prod_init_norm,
                    prod_after_norm=prod_after_norm,
                    prod_ndiff_norm=prod_ndiff_norm,
                    prod_rdiff_norm=prod_rdiff_norm)

    # Save raw data
    # raw_pkg = dict(pdist_init=pdist_init, cos_init=cos_init, prod_init=prod_init,
    #                pdist_after=pdist_after, cos_after=cos_after, prod_after=prod_after,
    #                pdist_ndiff=pdist_ndiff, cos_ndiff=cos_ndiff, prod_ndiff=prod_ndiff,
    #                pdist_rdiff=pdist_rdiff, cos_rdiff=cos_rdiff, prod_rdiff=prod_rdiff)
    # path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_rawmaps.pth'
    # torch.save(raw_pkg, path)

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')
def evaluation_procedure(config):
    """Train model and evaluate eigenvalues with given configuration."""
    # Setup data
    augmentations = False
    trainloader, testloader = dl.get_loaders('CIFAR10',
                                             config['batch_size'],
                                             augmentations=augmentations,
                                             normalize=args.normalize,
                                             shuffle=False)

    class Restrict(torch.nn.Module):
        def __init__(self, subrank):
            super(Restrict, self).__init__()
            self.shape = int(subrank)

        def forward(self, x):
            return x[:, :self.shape]
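
    # Restrict(k) truncates each flattened input to its first k features, so
    # the subnet below operates on a width-limited slice of the 3072 pixels.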

    if config['model'] == 'MLP':

        fullnet = torch.nn.Sequential(
            OrderedDict([('flatten', torch.nn.Flatten()),
                         ('linear0', torch.nn.Linear(3072, args.width)),
                         ('relu0', torch.nn.ReLU()),
                         ('linear1', torch.nn.Linear(args.width, args.width)),
                         ('relu1', torch.nn.ReLU()),
                         ('linear2', torch.nn.Linear(args.width, args.width)),
                         ('relu2', torch.nn.ReLU()),
                         ('linear3', torch.nn.Linear(args.width, 10))]))
        # breakpoint()
        subnet = torch.nn.Sequential(
            torch.nn.Flatten(), Restrict(args.width),
            *list(fullnet.children())[-config['subnet_depth']:])
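        # The subnet reuses the last `subnet_depth` modules of fullnet directly
        # (shared parameters, not copies), so training the subnet also updates
        # those layers inside fullnet.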
    else:
        raise NotImplementedError()

    subnet.to(**config['setup'])
    fullnet.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        subnet = torch.nn.DataParallel(subnet)
        fullnet = torch.nn.DataParallel(fullnet)
    subnet.eval()
    fullnet.eval()

    # Optimizer and loss
    optimizer = torch.optim.SGD(subnet.parameters(),
                                lr=config['lr'],
                                momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[50, 200, 400, 600, 700], gamma=0.1)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Check initial model
    analyze_model(subnet, trainloader, testloader, loss_fn, config)

    subnet.to(**config['setup'])
    fullnet.to(**config['setup'])
    # Train
    print(
        'Starting training subnet ...........................................')
    dl.train(subnet,
             optimizer,
             scheduler,
             loss_fn,
             trainloader,
             config,
             dryrun=args.dryrun)

    # Analyze results
    print(
        '----Results after training subnet -----------------------------------------------------------'
    )
    analyze_model(subnet, trainloader, testloader, loss_fn, config)
    for name, param in subnet.named_parameters():
        dprint(name, param)
    # Check full model
    print(
        '----Extend to full model and check local optimality -----------------------------------------'
    )
    # assert all([p1 is  p2 for (p1, p2) in zip(fullnet[-1].parameters(), subnet.parameters())])
    bias_first = True
    bias_offset = 2
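    # Extension scheme: every fullnet layer that is not shared with the subnet
    # is initialized to act as the identity, with weights set to Id, the first
    # bias set to a positive offset (keeping pre-activations positive so each
    # ReLU passes them through unchanged), and all later biases set to zero.
    # The shared output bias is corrected below so the offset cancels.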
    for name, param in fullnet.named_parameters():
        if all([param is not p for p in subnet.parameters()]):
            dprint(f'Currently setting {name}')
            if 'conv' in name and 'weight' in name:
                # conv weights must be matched before the generic 'weight'
                # case, which would otherwise shadow this branch
                torch.nn.init.dirac_(param)
                dprint(f'{name} set to dirac.')
            elif 'weight' in name:
                torch.nn.init.eye_(param)
                dprint(f'{name} set to Id.')
            elif 'bias' in name:
                if bias_first:
                    torch.nn.init.constant_(param, bias_offset)
                    bias_first = False
                    dprint(f'{name} set to {bias_offset}.')
                else:
                    torch.nn.init.constant_(param, 0)
                    dprint(f'{name} set to 0.')
                    # if normalize=False, input will be in [0,1] so no bias is necessary
        else:
            if 'linear3.bias' in name:
                # subnet(bias_offset * 1) returns (propagated offset) + b, so
                # b -= (Axb - b) leaves b - (propagated offset): the constant
                # shift from the identity layers cancels at the output.
                Axb = subnet(
                    bias_offset *
                    torch.ones(1, 3072, **config['setup'])).detach().squeeze()
                param.data -= Axb - param.data
                dprint(f'{name} set to b - Ax')
    print('Model extended to full model.')
    for name, param in fullnet.named_parameters():
        dprint(name, param)
    # Analyze results
    analyze_model(fullnet, trainloader, testloader, loss_fn, config)
    # Finetune
    print(
        'Finetune full net .................................................')
    config['full_batch'] = False
    optimizer = torch.optim.SGD(fullnet.parameters(),  # finetune all of fullnet, not just the shared subnet
                                lr=1e-4,
                                momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[50, 200, 400, 600, 700], gamma=0.1)
    dl.train(fullnet,
             optimizer,
             scheduler,
             loss_fn,
             trainloader,
             config,
             dryrun=args.dryrun)
    analyze_model(fullnet, trainloader, testloader, loss_fn, config)
def main():
    """Check ntks in a single call."""
    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    trainloader, testloader = dl.get_loaders('CIFAR10',
                                             config['batch_size'],
                                             augmentations=False,
                                             shuffle=False)

    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2],
                         widen_factor=config['width'])
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16,
                          widen_factor=config['width'],
                          dropout_rate=0.0,
                          num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(
            OrderedDict([
                ('flatten', torch.nn.Flatten()),
                ('linear0', torch.nn.Linear(3072, config['width'])),
                ('relu0', torch.nn.ReLU()),
                ('linear1', torch.nn.Linear(config['width'], config['width'])),
                ('relu1', torch.nn.ReLU()),
                ('linear2', torch.nn.Linear(config['width'], config['width'])),
                ('relu2', torch.nn.ReLU()),
                ('linear3', torch.nn.Linear(config['width'], 10))
            ]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(
            OrderedDict([('flatten', torch.nn.Flatten()),
                         ('linear0', torch.nn.Linear(3072, config['width'])),
                         ('relu0', torch.nn.ReLU()),
                         ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10,
                          width_mult=config['width'],
                          round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [
            64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'
        ]
        cfg = [c * config['width'] for c in cfg_base if isinstance(c, int)]
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * config['width'],
                                            4096)
    elif args.net == 'ConvNet':
        net = torch.nn.Sequential(
            OrderedDict([
                ('conv0',
                 torch.nn.Conv2d(3,
                                 1 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu0', torch.nn.ReLU()),
                # ('pool0', torch.nn.MaxPool2d(3)),
                ('conv1',
                 torch.nn.Conv2d(1 * config['width'],
                                 2 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu1', torch.nn.ReLU()),
                #  ('pool1', torch.nn.MaxPool2d(3)),
                ('conv2',
                 torch.nn.Conv2d(2 * config['width'],
                                 2 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu2', torch.nn.ReLU()),
                # ('pool2', torch.nn.MaxPool2d(3)),
                ('conv3',
                 torch.nn.Conv2d(2 * config['width'],
                                 4 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu3', torch.nn.ReLU()),
                ('pool3', torch.nn.MaxPool2d(3)),
                ('conv4',
                 torch.nn.Conv2d(4 * config['width'],
                                 4 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu4', torch.nn.ReLU()),
                ('pool4', torch.nn.MaxPool2d(3)),
                ('flatten', torch.nn.Flatten()),
                ('linear', torch.nn.Linear(36 * config['width'], 10))
            ]))
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    try:
        net.load_state_dict(
            torch.load(config['path'] + 'Cifar10_' + args.net +
                       str(config["width"]) + '_before.pth',
                       map_location=device))
        print('Initialized net loaded from file.')
    except Exception:  # loading failed, e.g. no file yet; keep and save the fresh initialization
        path = config['path'] + 'Cifar10_' + args.net + str(
            config["width"]) + '_before.pth'
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Initialized net saved to file.')
        else:
            print(f'Would save to {path}')

    num_params = sum([p.numel() for p in net.parameters()])
    print(
        f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
        f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')
    param_norm_before = np.sqrt(
        np.sum(
            [p.pow(2).sum().detach().cpu().numpy() for p in net.parameters()]))
    print(f'The L2 norm of the parameter vector is {param_norm_before:.2f}')

    net_init = [p.detach().clone() for p in net.parameters()]
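    # net_init holds the freshly initialized parameters; the change_*
    # statistics below compare the trained model against this snapshot.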

    # Start training
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)

    optimizer = torch.optim.SGD(net.parameters(),
                                lr=config['lr'],
                                momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[60, 120, 160],
                                                     gamma=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()
    analyze_model(net, trainloader, testloader, loss_fn, config)
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    try:
        net.load_state_dict(
            torch.load(config['path'] + 'Cifar10_' + args.net +
                       str(config["width"]) + '_after.pth',
                       map_location=device))
        print('Net loaded from file.')
    except Exception as e:  # this variant only analyzes an existing trained checkpoint
        print(repr(e))
        print('Could not find model data ... aborting ...')
        return
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    param_norm_after = np.sqrt(
        np.sum(
            [p.pow(2).sum().detach().cpu().numpy() for p in net.parameters()]))
    print(f'The L2 norm of the parameter vector is {param_norm_after:.2f}')
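
    # Three views of how far training moved the parameters: change_total is
    # the l2 distance of the full parameter vector, change_rel aggregates
    # per-tensor mean squared changes (normalizing for tensor size), and
    # change_nrmsum adds up the per-tensor l2 norms of the changes.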

    change_total = 0.0
    for p1, p2 in zip(net_init, net.parameters()):
        change_total += (p1 - p2).detach().pow(2).sum()
    change_total = change_total.sqrt().cpu().numpy()

    change_rel = 0.0
    for p1, p2 in zip(net_init, net.parameters()):
        change_rel += (p1 - p2).detach().pow(2).mean()
    change_rel = change_rel.sqrt().cpu().numpy()

    change_nrmsum = 0.0
    for p1, p2 in zip(net_init, net.parameters()):
        change_nrmsum += (p1 - p2).norm()
    change_nrmsum = change_nrmsum.cpu().numpy()

    # Analyze results
    acc_train, acc_test, loss_train, loss_trainw, grd_train = analyze_model(
        net, trainloader, testloader, loss_fn, config)

    save_output(args.table_path,
                name='ntk_stats',
                width=config['width'],
                num_params=num_params,
                acc_train=acc_train,
                acc_test=acc_test,
                loss_train=loss_train,
                loss_trainw=loss_trainw,
                grd_train=grd_train,
                param_norm_before=param_norm_before,
                param_norm_after=param_norm_after,
                change_total=change_total,
                change_rel=change_rel,
                change_nrmsum=change_nrmsum)

    # Save raw data
    # raw_pkg = dict(pdist_init=pdist_init, cos_init=cos_init, prod_init=prod_init,
    #                pdist_after=pdist_after, cos_after=cos_after, prod_after=prod_after,
    #                pdist_ndiff=pdist_ndiff, cos_ndiff=cos_ndiff, prod_ndiff=prod_ndiff,
    #                pdist_rdiff=pdist_rdiff, cos_rdiff=cos_rdiff, prod_rdiff=prod_rdiff)
    # path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_rawmaps.pth'
    # torch.save(raw_pkg, path)

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')
def main():
    """Check ntks in a single call."""
    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    trainloader, testloader = dl.get_loaders('CIFAR10', config['batch_size'], augmentations=False, shuffle=False)

    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2], widen_factor=config['width'])
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16, widen_factor=config['width'], dropout_rate=0.0, num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(OrderedDict([
                                 ('flatten', torch.nn.Flatten()),
                                 ('linear0', torch.nn.Linear(3072, config['width'])),
                                 ('relu0', torch.nn.ReLU()),
                                 ('linear1', torch.nn.Linear(config['width'], config['width'])),
                                 ('relu1', torch.nn.ReLU()),
                                 ('linear2', torch.nn.Linear(config['width'], config['width'])),
                                 ('relu2', torch.nn.ReLU()),
                                 ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(OrderedDict([
                                 ('flatten', torch.nn.Flatten()),
                                 ('linear0', torch.nn.Linear(3072, config['width'])),
                                 ('relu0', torch.nn.ReLU()),
                                 ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10, width_mult=config['width'], round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
        cfg = [c * config['width'] for c in cfg_base if isinstance(c, int)]
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * config['width'], 4096)
    elif args.net == 'ConvNet':
        net = torch.nn.Sequential(OrderedDict([
                                  ('conv0', torch.nn.Conv2d(3, 1 * config['width'], kernel_size=3, padding=1)),
                                  ('relu0', torch.nn.ReLU()),
                                  # ('pool0', torch.nn.MaxPool2d(3)),
                                  ('conv1', torch.nn.Conv2d(1 * config['width'],
                                                            2 * config['width'], kernel_size=3, padding=1)),
                                  ('relu1', torch.nn.ReLU()),
                                  #  ('pool1', torch.nn.MaxPool2d(3)),
                                  ('conv2', torch.nn.Conv2d(2 * config['width'],
                                                            2 * config['width'], kernel_size=3, padding=1)),
                                  ('relu2', torch.nn.ReLU()),
                                  # ('pool2', torch.nn.MaxPool2d(3)),
                                  ('conv3', torch.nn.Conv2d(2 * config['width'],
                                                            4 * config['width'], kernel_size=3, padding=1)),
                                  ('relu3', torch.nn.ReLU()),
                                  ('pool3', torch.nn.MaxPool2d(3)),
                                  ('conv4', torch.nn.Conv2d(4 * config['width'],
                                                            4 * config['width'], kernel_size=3, padding=1)),
                                  ('relu4', torch.nn.ReLU()),
                                  ('pool4', torch.nn.MaxPool2d(3)),
                                  ('flatten', torch.nn.Flatten()),
                                  ('linear', torch.nn.Linear(36 * config['width'], 10))
                                  ]))
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    try:
        net.load_state_dict(torch.load(config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_before.pth',
                                       map_location=device))
        print('Initialized net loaded from file.')
    except Exception:  # loading failed, e.g. no file yet; keep and save the fresh initialization
        path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_before.pth'
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Initialized net saved to file.')
        else:
            print(f'Would save to {path}')

    num_params = sum([p.numel() for p in net.parameters()])
    print(f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
          f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')
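
    # NTK sample and parameter norm at initialization, as in the fuller main()
    # variant above; save_output() at the end of this function needs these.
    ntk_matrix_before = batch_wise_ntk(net, trainloader, samplesize=args.sampling)
    ntk_matrix_before_norm = np.linalg.norm(ntk_matrix_before.flatten())
    param_norm_before = np.sqrt(
        np.sum([p.pow(2).sum().detach().cpu().numpy() for p in net.parameters()]))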

    # Start training
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)

    optimizer = torch.optim.SGD(net.parameters(), lr=config['lr'], momentum=0.9, weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    try:
        net.load_state_dict(torch.load(config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_after.pth',
                                       map_location=device))
        print('Net loaded from file.')
    except Exception:  # no trained checkpoint found; train from scratch
        path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_after.pth'
        dl.train(net, optimizer, scheduler, loss_fn, trainloader, config, path=None, dryrun=args.dryrun)
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Net saved to file.')
        else:
            print(f'Would save to {path}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    if isinstance(net, torch.nn.DataParallel):
        net = net.module
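
    # Post-training counterparts and the difference/correlation statistics,
    # mirroring the fuller main() above; save_output() below needs these.
    param_norm_after = np.sqrt(
        np.sum([p.pow(2).sum().detach().cpu().numpy() for p in net.parameters()]))
    ntk_matrix_after = batch_wise_ntk(net, trainloader, samplesize=args.sampling)
    ntk_matrix_after_norm = np.linalg.norm(ntk_matrix_after.flatten())
    ntk_matrix_diff_norm = np.linalg.norm(
        np.abs(ntk_matrix_before - ntk_matrix_after).flatten())
    ntk_matrix_rdiff_norm = np.linalg.norm(
        (np.abs(ntk_matrix_before - ntk_matrix_after) /
         (np.abs(ntk_matrix_before) + 1e-4)).flatten())
    corr_coeff = np.mean(
        (ntk_matrix_before - np.mean(ntk_matrix_before)) *
        (ntk_matrix_after - np.mean(ntk_matrix_after)) /
        np.std(ntk_matrix_before) / np.std(ntk_matrix_after))
    corr_tom = np.sum(ntk_matrix_before * ntk_matrix_after /
                      np.sqrt(np.sum(ntk_matrix_before**2) *
                              np.sum(ntk_matrix_after**2)))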

    save_output(args.table_path, name='ntk', width=config['width'], num_params=num_params,
                before_norm=ntk_matrix_before_norm, after_norm=ntk_matrix_after_norm,
                diff_norm=ntk_matrix_diff_norm, rdiff_norm=ntk_matrix_rdiff_norm,
                param_norm_before=param_norm_before, param_norm_after=param_norm_after,
                corr_coeff=corr_coeff, corr_tom=corr_tom)

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')
def evaluation_procedure(config):
    """Train model and evaluate eigenvalues with given configuration."""
    # Setup data
    augmentations = False
    trainloader, testloader = dl.get_loaders('CIFAR10',
                                             config['batch_size'],
                                             augmentations=augmentations)

    # Setup Network
    if config['model'] == 'MLP':
        net = torch.nn.Sequential(
            OrderedDict([('flatten', torch.nn.Flatten()),
                         ('linear0', torch.nn.Linear(3072, 2048)),
                         ('relu0', torch.nn.ReLU()),
                         ('linear1', torch.nn.Linear(2048, 2048)),
                         ('relu1', torch.nn.ReLU()),
                         ('linear2', torch.nn.Linear(2048, 1024)),
                         ('relu2', torch.nn.ReLU()),
                         ('linear3', torch.nn.Linear(1024, 10))]))
    elif config['model'] == 'MLPsmall':
        net = torch.nn.Sequential(
            OrderedDict([('flatten', torch.nn.Flatten()),
                         ('linear0', torch.nn.Linear(3072, 256)),
                         ('relu0', torch.nn.ReLU()),
                         ('linear1', torch.nn.Linear(256, 256)),
                         ('relu1', torch.nn.ReLU()),
                         ('linear2', torch.nn.Linear(256, 256)),
                         ('relu2', torch.nn.ReLU()),
                         ('linear3', torch.nn.Linear(256, 10))]))
    elif config['model'] == 'MLPsmallB':
        # BatchNorm1d (the activations are flattened) with unique layer names;
        # duplicate OrderedDict keys would silently drop all but one BN layer.
        net = torch.nn.Sequential(
            OrderedDict([('flatten', torch.nn.Flatten()),
                         ('linear0', torch.nn.Linear(3072, 256)),
                         ('relu0', torch.nn.ReLU()),
                         ('bn0', torch.nn.BatchNorm1d(256)),
                         ('linear1', torch.nn.Linear(256, 256)),
                         ('relu1', torch.nn.ReLU()),
                         ('bn1', torch.nn.BatchNorm1d(256)),
                         ('linear2', torch.nn.Linear(256, 256)),
                         ('relu2', torch.nn.ReLU()),
                         ('bn2', torch.nn.BatchNorm1d(256)),
                         ('linear3', torch.nn.Linear(256, 10))]))
    elif config['model'] == 'ResNet':
        net = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=10)
    elif config['model'] == 'L-MLP':
        net = torch.nn.Sequential(
            OrderedDict([('flatten', torch.nn.Flatten()),
                         ('linear0', torch.nn.Linear(3072, 2048)),
                         ('linear1', torch.nn.Linear(2048, 2048)),
                         ('linear2', torch.nn.Linear(2048, 1024)),
                         ('linear3', torch.nn.Linear(1024, 10))]))
    elif config['model'] == 'L-ResNet':
        net = ResNetLinear(BasicBlockLinear, [2, 2, 2, 2], num_classes=10)
    else:
        raise ValueError('Invalid model specified.')

    net.to(**config['setup'])
    net = torch.nn.DataParallel(net)
    net.eval()

    def initialize_net(net, init):
        for name, param in net.named_parameters():
            with torch.no_grad():
                if init == 'default':
                    pass
                elif init == 'zero':
                    param.zero_()
                elif init == 'low_bias':
                    if 'bias' in name:
                        param -= 20
                elif init == 'high_bias':
                    if 'bias' in name:
                        param += 20
                elif init == 'equal':
                    torch.nn.init.constant_(param, 0.001)
                elif init == 'variant_bias':
                    if 'bias' in name:
                        torch.nn.init.uniform_(param, -args.var, args.var)
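
    # The non-default variants start from degenerate points: 'zero' wipes all
    # parameters, 'low_bias'/'high_bias' shift every bias by -/+20 (driving
    # ReLU pre-activations far negative or positive at the start), 'equal'
    # makes all hidden units identical, and 'variant_bias' randomizes biases
    # within [-args.var, args.var].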

    initialize_net(net, config['init'])

    # Optimizer and loss
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=config['lr'],
                                momentum=args.mom,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[150, 250, 350], gamma=0.1)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Analyze model before training
    analyze_model(net, trainloader, testloader, loss_fn, config)

    # Train
    print('Starting training ...')
    dl.train(net,
             optimizer,
             scheduler,
             loss_fn,
             trainloader,
             config,
             dryrun=args.dryrun)

    # Analyze results
    acc_train, acc_test, loss_train, loss_trainw, grd_train, maxeig, mineig = analyze_model(
        net, trainloader, testloader, loss_fn, config)

    save_output(args.table_path,
                init=config['init'],
                var=args.var,
                acc_train=acc_train,
                acc_test=acc_test,
                loss_train=loss_train,
                loss_trainw=loss_trainw,
                grd_train=grd_train,
                maxeig=maxeig,
                mineig=mineig)
def main():
    """Check ntks in a single call."""
    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    trainloader, testloader = dl.get_loaders('CIFAR10',
                                             config['batch_size'],
                                             augmentations=False,
                                             shuffle=False)

    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2],
                         widen_factor=config['width'])
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16,
                          widen_factor=config['width'],
                          dropout_rate=0.0,
                          num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(
            OrderedDict([
                ('flatten', torch.nn.Flatten()),
                ('linear0', torch.nn.Linear(3072, config['width'])),
                ('relu0', torch.nn.ReLU()),
                ('linear1', torch.nn.Linear(config['width'], config['width'])),
                ('relu1', torch.nn.ReLU()),
                ('linear2', torch.nn.Linear(config['width'], config['width'])),
                ('relu2', torch.nn.ReLU()),
                ('linear3', torch.nn.Linear(config['width'], 10))
            ]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(
            OrderedDict([('flatten', torch.nn.Flatten()),
                         ('linear0', torch.nn.Linear(3072, config['width'])),
                         ('relu0', torch.nn.ReLU()),
                         ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10,
                          width_mult=config['width'],
                          round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [
            64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'
        ]
        cfg = [c * config['width'] for c in cfg_base if isinstance(c, int)]
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * config['width'],
                                            4096)
    elif args.net == 'ConvNet':
        net = torch.nn.Sequential(
            OrderedDict([
                ('conv0',
                 torch.nn.Conv2d(3,
                                 1 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu0', torch.nn.ReLU()),
                # ('pool0', torch.nn.MaxPool2d(3)),
                ('conv1',
                 torch.nn.Conv2d(1 * config['width'],
                                 2 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu1', torch.nn.ReLU()),
                #  ('pool1', torch.nn.MaxPool2d(3)),
                ('conv2',
                 torch.nn.Conv2d(2 * config['width'],
                                 2 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu2', torch.nn.ReLU()),
                # ('pool2', torch.nn.MaxPool2d(3)),
                ('conv3',
                 torch.nn.Conv2d(2 * config['width'],
                                 4 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu3', torch.nn.ReLU()),
                ('pool3', torch.nn.MaxPool2d(3)),
                ('conv4',
                 torch.nn.Conv2d(4 * config['width'],
                                 4 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu4', torch.nn.ReLU()),
                ('pool4', torch.nn.MaxPool2d(3)),
                ('flatten', torch.nn.Flatten()),
                ('linear', torch.nn.Linear(36 * config['width'], 10))
            ]))
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    num_params = sum([p.numel() for p in net.parameters()])
    print(
        f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
        f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')

    def batch_feature_correlations(dataloader, device=torch.device('cpu')):
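        # For each batch, a forward hook on the final linear layer captures its
        # input (the penultimate feature vectors) and records three pairwise
        # maps: Euclidean distances, cosine similarities, and mean inner
        # products. This assumes every batch holds exactly
        # dataloader.batch_size samples.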
        net.eval()
        net.to(device)
        dist_maps = list()
        cosine_maps = list()
        prod_maps = list()
        hooks = []

        def batch_wise_feature_correlation(self, input, output):
            feat_vec = input[0].detach().view(dataloader.batch_size, -1)
            dist_maps.append(
                torch.cdist(feat_vec, feat_vec, 2).detach().cpu().numpy())

            cosine_map = np.empty(
                (dataloader.batch_size, dataloader.batch_size))
            prod_map = np.empty((dataloader.batch_size, dataloader.batch_size))
            for row in range(dataloader.batch_size):
                cosine_map[row, :] = torch.nn.functional.cosine_similarity(
                    feat_vec[row:row + 1, :], feat_vec, dim=1,
                    eps=1e-8).detach().cpu().numpy()
                prod_map[row, :] = torch.mean(feat_vec[row:row + 1, :] *
                                              feat_vec,
                                              dim=1).detach().cpu().numpy()
            cosine_maps.append(cosine_map)
            prod_maps.append(prod_map)

        if isinstance(net, torch.nn.DataParallel):
            hooks.append(
                net.module.linear.register_forward_hook(
                    batch_wise_feature_correlation))
        else:
            if args.net in ['MLP', 'TwoLP']:
                hooks.append(
                    net.linear3.register_forward_hook(
                        batch_wise_feature_correlation))
            elif args.net in ['VGG', 'MobileNetV2']:
                hooks.append(
                    net.classifier.register_forward_hook(
                        batch_wise_feature_correlation))
            else:
                hooks.append(
                    net.linear.register_forward_hook(
                        batch_wise_feature_correlation))

        for inputs, _ in dataloader:
            outputs = net(inputs.to(device))
            if args.dryrun:
                break

        for hook in hooks:
            hook.remove()

        return dist_maps, cosine_maps, prod_maps

    pdist_init, cos_init, prod_init = batch_feature_correlations(trainloader)
    pdist_init_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in pdist_init])
    cos_init_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in cos_init])
    prod_init_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in prod_init])
    print(
        f'The total norm of feature distances before training is {pdist_init_norm:.2f}'
    )
    print(
        f'The total norm of feature cosine similarity before training is {cos_init_norm:.2f}'
    )
    print(
        f'The total norm of feature inner product before training is {prod_init_norm:.2f}'
    )

    save_plot(pdist_init, trainloader, name='pdist_before_training')
    save_plot(cos_init, trainloader, name='cosine_before_training')
    save_plot(prod_init, trainloader, name='prod_before_training')

    # Start training
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)

    optimizer = torch.optim.SGD(net.parameters(),
                                lr=config['lr'],
                                momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[60, 120, 160],
                                                     gamma=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    try:
        net.load_state_dict(
            torch.load(config['path'] + 'Cifar10_' + args.net +
                       str(config["width"]) + '.pth',
                       map_location=device))
        print('Net loaded from file.')
    except Exception:  # no trained checkpoint found; train from scratch
        path = config['path'] + 'Cifar10_' + args.net + str(
            config["width"]) + '.pth'
        dl.train(net,
                 optimizer,
                 scheduler,
                 loss_fn,
                 trainloader,
                 config,
                 path=None,
                 dryrun=args.dryrun)
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Net saved to file.')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    # Check feature maps after training
    pdist_after, cos_after, prod_after = batch_feature_correlations(
        trainloader)

    pdist_after_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in pdist_after])
    cos_after_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in cos_after])
    prod_after_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in prod_after])
    print(
        f'The total norm of feature distances after training is {pdist_after_norm:.2f}'
    )
    print(
        f'The total norm of feature cosine similarity after training is {cos_after_norm:.2f}'
    )
    print(
        f'The total norm of feature inner product after training is {prod_after_norm:.2f}'
    )

    save_plot(pdist_after, trainloader, name='pdist_after_training')
    save_plot(cos_after, trainloader, name='cosine_after_training')
    save_plot(prod_after, trainloader, name='prod_after_training')

    # Check feature map differences
    pdist_ndiff = [
        np.abs(co1 - co2) / pdist_init_norm
        for co1, co2 in zip(pdist_init, pdist_after)
    ]
    cos_ndiff = [
        np.abs(co1 - co2) / cos_init_norm
        for co1, co2 in zip(cos_init, cos_after)
    ]
    prod_ndiff = [
        np.abs(co1 - co2) / prod_init_norm
        for co1, co2 in zip(prod_init, prod_after)
    ]

    pdist_ndiff_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in pdist_ndiff])
    cos_ndiff_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in cos_ndiff])
    prod_ndiff_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in prod_ndiff])
    print(
        f'The total norm normalized diff of feature distances after training is {pdist_ndiff_norm:.2f}'
    )
    print(
        f'The total norm normalized diff of feature cosine similarity after training is {cos_ndiff_norm:.2f}'
    )
    print(
        f'The total norm normalized diff of feature inner product after training is {prod_ndiff_norm:.2f}'
    )

    save_plot(pdist_ndiff, trainloader, name='pdist_ndiff')
    save_plot(cos_ndiff, trainloader, name='cosine_ndiff')
    save_plot(prod_ndiff, trainloader, name='prod_ndiff')

    # Check feature map differences
    pdist_rdiff = [
        np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
        for co1, co2 in zip(pdist_init, pdist_after)
    ]
    cos_rdiff = [
        np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
        for co1, co2 in zip(cos_init, cos_after)
    ]
    prod_rdiff = [
        np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
        for co1, co2 in zip(prod_init, prod_after)
    ]

    pdist_rdiff_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in pdist_rdiff])
    cos_rdiff_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in cos_rdiff])
    prod_rdiff_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in prod_rdiff])
    print(
        f'The total norm relative diff of feature distances after training is {pdist_rdiff_norm:.2f}'
    )
    print(
        f'The total norm relative diff of feature cosine similarity after training is {cos_rdiff_norm:.2f}'
    )
    print(
        f'The total norm relative diff of feature inner product after training is {prod_rdiff_norm:.2f}'
    )

    save_plot(pdist_rdiff, trainloader, name='pdist_rdiff')
    save_plot(cos_rdiff, trainloader, name='cosine_rdiff')
    save_plot(prod_rdiff, trainloader, name='prod_rdiff')

    save_output(args.table_path,
                width=config['width'],
                num_params=num_params,
                pdist_init_norm=pdist_init_norm,
                pdist_after_norm=pdist_after_norm,
                pdist_ndiff_norm=pdist_ndiff_norm,
                pdist_rdiff_norm=pdist_rdiff_norm,
                cos_init_norm=cos_init_norm,
                cos_after_norm=cos_after_norm,
                cos_ndiff_norm=cos_ndiff_norm,
                cos_rdiff_norm=cos_rdiff_norm,
                prod_init_norm=prod_init_norm,
                prod_after_norm=prod_after_norm,
                prod_ndiff_norm=prod_ndiff_norm,
                prod_rdiff_norm=prod_rdiff_norm)

    # Save raw data
    raw_pkg = dict(pdist_init=pdist_init,
                   cos_init=cos_init,
                   prod_init=prod_init,
                   pdist_after=pdist_after,
                   cos_after=cos_after,
                   prod_after=prod_after,
                   pdist_ndiff=pdist_ndiff,
                   cos_ndiff=cos_ndiff,
                   prod_ndiff=prod_ndiff,
                   pdist_rdiff=pdist_rdiff,
                   cos_rdiff=cos_rdiff,
                   prod_rdiff=prod_rdiff)
    path = config['path'] + 'Cifar10_' + args.net + str(
        config["width"]) + '_rawmaps.pth'
    torch.save(raw_pkg, path)

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')