def __init__(self, hparams: Params):
        """Set up a slimmed-down MobileNetV2 classifier and its image transforms.

        Args:
            hparams: Hyper-parameters object (type ``Params``), stored unchanged
                on ``self.hparams``.

        The network has 10 output classes and uses only the first five rows of
        the inverted-residual table. Both dataset slots start as ``None``; the
        common and training-only transforms are kept as plain lists.
        """
        super().__init__()

        # NOTE(review): recent PyTorch Lightning releases do not allow assigning
        # ``self.hparams`` directly (``save_hyperparameters`` is the supported
        # route) — confirm which base class/version this module targets.
        self.hparams = hparams

        # One row per inverted-residual stage, laid out as (t, c, n, s):
        # expansion factor, output channels, repeat count, stride.
        # The stock table's last two rows, [6, 160, 3, 2] and [6, 320, 1, 1],
        # are not used here.
        residual_stages = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
        ]
        self.model = MobileNetV2(num_classes=10,
                                 inverted_residual_setting=residual_stages)
        self.loss = F.cross_entropy
        self.train_dataset = self.val_dataset = None

        # Applied to every split: tensor conversion, then single-channel
        # normalisation with mean 0.285 and std 0.3523.
        self.tfms_common = [
            transforms.ToTensor(),
            transforms.Normalize((0.285, ), (.3523, )),
        ]
        # Training-only augmentation: rotation of up to 7 degrees (filled with 0)
        # and a 28x28 random crop from a 3-pixel padded image.
        self.tfms_train = [
            transforms.RandomRotation(7., fill=(0, )),
            transforms.RandomCrop(28, padding=3),
        ]
示例#2
0
    def __init__(self, pretrained=False, num_classes=5):
        """Build a MobileNetV2 image classifier with ``num_classes`` outputs.

        Args:
            pretrained: If True, load ``mobilenet_v2(pretrained=True)``, freeze
                its ``features`` block and replace the classifier with a freshly
                initialised ``Dropout(0.2) -> Linear`` head, so only the head
                trains. If False, build a randomly initialised ``MobileNetV2``
                whose parameters are all trainable.
            num_classes: Number of output classes.

        Also records ``self.datapath`` (``<cwd>/data``) and a ``Softmax`` over
        dim 1 on ``self.softmax``.
        """
        super().__init__()
        self.datapath = os.path.join(os.getcwd(), 'data')

        if pretrained:
            self.model = mobilenet_v2(pretrained=True)
            # Transfer learning: keep the pretrained feature extractor fixed.
            self.model.features.requires_grad_(False)
            # Take the head's input width from the existing final Linear layer
            # instead of hard-coding 1280.
            in_features = self.model.classifier[-1].in_features
            fc_layer = nn.Linear(in_features, num_classes, bias=True)
            torch.nn.init.xavier_normal_(fc_layer.weight, gain=1.0)
            torch.nn.init.zeros_(fc_layer.bias)
            self.model.classifier = nn.Sequential(nn.Dropout(0.2), fc_layer)
        else:
            # Training from scratch: every weight starts random, so the feature
            # extractor must stay trainable. Freezing it would leave only the
            # classifier learning on top of fixed random features.
            self.model = MobileNetV2(num_classes=num_classes)

        self.softmax = nn.Softmax(dim=1)
示例#3
0
def test_hawq_precision_init(_seed, dataset_dir, tmp_path, mocker,
                             config_creator: Callable, filename_suffix: str):
    """Check the mixed-bitwidth graph HAWQ builds for a pretrained MobileNetV2.

    Hessian trace estimation is patched to return ``get_mock_avg_traces(model)``.
    The resulting bitwidth graph is compared via ``check_graph`` with
    ``mobilenet_v2_mixed_bitwidth_graph_<filename_suffix>.dot`` under
    ``quantized/hawq``. Requires CUDA (the criterion and model are moved there).
    """
    batch_size, num_data_points = 10, 100
    config = config_creator(batch_size, num_data_points)

    model = MobileNetV2(num_classes=10)
    model.eval()

    criterion = nn.CrossEntropyLoss().cuda()
    dataset_dir = dataset_dir or str(tmp_path)
    train_loader, _ = create_test_dataloaders(config.get("model_size"),
                                              dataset_dir, batch_size)
    config = register_default_init_args(config, criterion, train_loader)

    # Replace the expensive trace estimation with canned per-layer traces.
    trace_estimator = mocker.patch(
        'nncf.quantization.hessian_trace.HessianTraceEstimator.get_average_traces'
    )
    trace_estimator.return_value = get_mock_avg_traces(model)

    from torchvision.models.mobilenet import model_urls
    pretrained_state = model_zoo.load_url(model_urls['mobilenet_v2'])
    load_state(model, pretrained_state)

    compressed_model, algo_ctrl = create_compressed_model_and_algo_for_test(model, config)
    compressed_model = compressed_model.cuda()

    quantizers_by_scope = get_all_quantizers_per_full_scope(compressed_model)
    graph = get_bitwidth_graph(algo_ctrl, compressed_model, quantizers_by_scope)

    check_graph(graph,
                f'mobilenet_v2_mixed_bitwidth_graph_{filename_suffix}.dot',
                os.path.join('quantized', 'hawq'),
                sort_dot_graph=False)
示例#4
0
 def __init__(self, num_labels: int):
     """Pretrained ``mobilenet_v2`` behind a 1->3 channel conv stem, with a ``num_labels`` head.

     Args:
         num_labels: Number of output labels; also forwarded to the parent constructor.
     """
     super().__init__(num_labels)
     # Conv/BN/ReLU/max-pool stem mapping one input channel to three.
     # NOTE(review): kernel 3 with padding (1, 3) keeps the height but widens
     # the width by 4 before the (1, 2) pool halves it — confirm intended.
     self.downsample = nn.Sequential(
         nn.Conv2d(1, 3, 3, padding=(1, 3)),
         nn.BatchNorm2d(3),
         nn.ReLU(),
         nn.MaxPool2d((1, 2)),
     )
     self.model = mobilenet_v2(pretrained=True)
     # A whole MobileNetV2 is built only to borrow its num_labels-sized classifier.
     head_donor = MobileNetV2(num_classes=num_labels)
     self.model.classifier = head_donor.classifier
示例#5
0
 def __init__(self, num_classes, weight_path, image_size, cuda):
     """Load CardRotator's MobileNetV2 weights and put the model in eval mode.

     Args:
         num_classes: Output size of the MobileNetV2.
         weight_path: Checkpoint path, resolved through ``utils.abs_path``.
         image_size: Stored on ``self.image_size``.
         cuda: Truthy to run on ``'cuda'``, otherwise ``'cpu'``.
     """
     super(CardRotator, self).__init__()
     self.image_size = image_size
     self.device = torch.device('cuda' if cuda else 'cpu')

     network = MobileNetV2(num_classes)
     # Deserialise onto the CPU first; the model is moved to self.device afterwards.
     checkpoint = torch.load(utils.abs_path(weight_path), map_location='cpu')
     network.load_state_dict(checkpoint)
     network.to(self.device)
     network.eval()
     self.model = network
示例#6
0
def test_hawq_hw_vpu_config_e2e(_seed, dataset_dir, tmp_path):
    """Smoke test: HAWQ with the VPU config (ratio 1.01) compresses MobileNetV2 without raising."""
    config = HAWQConfigBuilder().for_vpu().with_ratio(1.01).build()
    model = MobileNetV2(num_classes=10)
    criterion = nn.CrossEntropyLoss()

    train_loader, _ = create_test_dataloaders(config, dataset_dir or str(tmp_path))
    config = register_default_init_args(config, train_loader, criterion)

    create_compressed_model_and_algo_for_test(model, config)
示例#7
0
    def __init__(self, device, num_classes=2):
        """Instantiate the ensemble's member networks on ``device``.

        Args:
            device: Target device; every member is moved there with
                ``.to(device)`` and the value is kept on ``self.device``.
            num_classes: Output size forwarded to the members that take it
                (FishNet150 via ``num_cls`` and MobileNetV2). The FeatherNet
                and MobileLiteNet members use their own defaults.
        """
        super(Ensemble, self).__init__()
        from torch import nn

        self.num_classes = num_classes
        # A plain Python list is invisible to nn.Module, so its members would be
        # skipped by Ensemble.parameters(), .to(), .train()/.eval() and
        # state_dict(). nn.ModuleList registers them as submodules (assuming
        # Ensemble is an nn.Module — its base class is not visible here) and
        # still supports indexing, iteration and len() like a list.
        # NOTE: state_dict() now includes member weights under "models.<i>.".
        self.models = nn.ModuleList([
            FeatherNetA().to(device),
            FeatherNetB().to(device),
            FishNet150(num_cls=self.num_classes).to(device),
            MobileNetV2(num_classes=self.num_classes).to(device),
            MobileLiteNet54().to(device),
            MobileLiteNet54_se().to(device),
        ])
        self.device = device
def disable_quantizer_gradients():
    """Quantize a MobileNetV2 and disable all gradients except quantized-module weights.

    Returns:
        ``(quantizers_switcher, disabled_parameters, model, original_requires_grad_per_param)``,
        where the last item is the ``requires_grad`` snapshot taken right after
        compression and before anything was disabled.
    """
    config = get_quantization_config_without_range_init()
    config['input_info'] = {"sample_size": [1, 3, 10, 10]}

    model = MobileNetV2(num_classes=10)
    model.eval()
    model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)
    requires_grad_before = get_requires_grad_per_param(model)

    quantizer_type_names = [cls.__name__ for cls in QUANTIZATION_MODULES.registry_dict.values()]
    quantizer_modules = get_all_modules_by_type(model, quantizer_type_names)
    switcher = QuantizersSwitcher(list(quantizer_modules.values()))

    disabled = HAWQPrecisionInitializer.disable_all_gradients_except_weights_of_quantized_modules(
        switcher,
        compression_ctrl.quantized_weight_modules_registry,
        model,
        get_scopes_of_skipped_weight_quantizers())
    return switcher, disabled, model, requires_grad_before
def test_disable_quantizer_gradients():
    """Check the requires_grad flags left after disabling quantizer gradients.

    The flags are compared against the stored JSON reference via
    ``compare_with_ref_if_exists``.
    """
    config = get_basic_quantization_config()
    config['input_info'] = {"sample_size": (1, 3, 10, 10)}

    model = MobileNetV2(num_classes=10)
    model.eval()
    model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)

    type_names = [cls.__name__ for cls in QUANTIZATION_MODULES.registry_dict.values()]
    quantizers = get_all_modules_by_type(model, type_names)
    HessianAwarePrecisionInitializeRunner.disable_quantizer_gradients(
        quantizers,
        compression_ctrl.quantized_weight_modules_registry,
        model)

    observed = get_requires_grad_per_param(model)
    reference = str(TEST_ROOT / 'data/hawq_reference/mobilenet_v2_requires_grad_per_param.json')
    compare_with_ref_if_exists(observed, reference)
def test_enable_quantizer_gradients():
    """Disabling then re-enabling quantizer gradients must restore every requires_grad flag."""
    config = get_basic_quantization_config()
    config['input_info'] = {"sample_size": (1, 3, 10, 10)}

    model = MobileNetV2(num_classes=10)
    model.eval()
    model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)

    type_names = [cls.__name__ for cls in QUANTIZATION_MODULES.registry_dict.values()]
    quantizers = get_all_modules_by_type(model, type_names)

    before = get_requires_grad_per_param(model)
    runner = HessianAwarePrecisionInitializeRunner
    disabled = runner.disable_quantizer_gradients(
        quantizers,
        compression_ctrl.quantized_weight_modules_registry,
        model)
    runner.enable_quantizer_gradients(model, quantizers, disabled)
    after = get_requires_grad_per_param(model)
    assert before == after
def main():
    """Check ntks in a single call.

    Compare a network at initialisation with its trained checkpoint and record
    summary statistics.

    The network is chosen by ``args.net`` and scaled by ``config['width']``.
    Initial weights are loaded from ``<config['path']>Cifar10_<net><width>_before.pth``.
    If that load fails, they are saved there instead (unless ``args.dryrun``).
    Trained weights must load from ``..._after.pth``; otherwise the error is
    printed and the function returns early. The ``analyze_model`` results and
    the parameter-change norms are written with ``save_output`` under the name
    ``'ntk_stats'``.

    Relies on the module-level ``args``, ``config`` and ``device`` objects.
    """
    def _timestamp():
        print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    def _checkpoint_path(stage):
        # <config['path']>Cifar10_<net><width>_<stage>.pth
        return config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_' + stage + '.pth'

    def _param_l2_norm(module):
        # L2 norm of all parameters of `module` viewed as one flat vector.
        return np.sqrt(
            np.sum([p.pow(2).sum().detach().cpu().numpy() for p in module.parameters()]))

    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()}')
    _timestamp()

    trainloader, testloader = dl.get_loaders('CIFAR10',
                                             config['batch_size'],
                                             augmentations=False,
                                             shuffle=False)

    width = config['width']
    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2], widen_factor=width)
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16, widen_factor=width, dropout_rate=0.0, num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, width)),
            ('relu0', torch.nn.ReLU()),
            ('linear1', torch.nn.Linear(width, width)),
            ('relu1', torch.nn.ReLU()),
            ('linear2', torch.nn.Linear(width, width)),
            ('relu2', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(width, 10)),
        ]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(OrderedDict([
            ('flatten', torch.nn.Flatten()),
            ('linear0', torch.nn.Linear(3072, width)),
            ('relu0', torch.nn.ReLU()),
            ('linear3', torch.nn.Linear(width, 10)),
        ]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10, width_mult=width, round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
        # NOTE(review): the isinstance filter drops every 'M' entry, so (presumably)
        # no max-pool layers end up in `features`. Keeping them
        # (`c * width if isinstance(c, int) else c`) would shift layer indices and
        # invalidate existing *_before/_after.pth checkpoints — confirm intent first.
        cfg = [c * width for c in cfg_base if isinstance(c, int)]
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * width, 4096)
    elif args.net == 'ConvNet':
        # For 32x32 inputs the two MaxPool2d(3) layers give 10x10 and then 3x3,
        # hence the 3 * 3 * (4 * width) = 36 * width input features of `linear`.
        net = torch.nn.Sequential(OrderedDict([
            ('conv0', torch.nn.Conv2d(3, 1 * width, kernel_size=3, padding=1)),
            ('relu0', torch.nn.ReLU()),
            ('conv1', torch.nn.Conv2d(1 * width, 2 * width, kernel_size=3, padding=1)),
            ('relu1', torch.nn.ReLU()),
            ('conv2', torch.nn.Conv2d(2 * width, 2 * width, kernel_size=3, padding=1)),
            ('relu2', torch.nn.ReLU()),
            ('conv3', torch.nn.Conv2d(2 * width, 4 * width, kernel_size=3, padding=1)),
            ('relu3', torch.nn.ReLU()),
            ('pool3', torch.nn.MaxPool2d(3)),
            ('conv4', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)),
            ('relu4', torch.nn.ReLU()),
            ('pool4', torch.nn.MaxPool2d(3)),
            ('flatten', torch.nn.Flatten()),
            ('linear', torch.nn.Linear(36 * width, 10)),
        ]))
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    before_path = _checkpoint_path('before')
    try:
        net.load_state_dict(torch.load(before_path, map_location=device))
        print('Initialized net loaded from file.')
    except Exception:  # no usable init checkpoint: persist this initialisation instead
        if not args.dryrun:
            torch.save(net.state_dict(), before_path)
            print('Initialized net saved to file.')
        else:
            print(f'Would save to {before_path}')

    num_params = sum(p.numel() for p in net.parameters())
    print(
        f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
        f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')
    param_norm_before = _param_l2_norm(net)
    print(f'The L2 norm of the parameter vector is {param_norm_before:.2f}')

    # Snapshot of the initial parameters, compared with the trained ones below.
    net_init = [p.detach().clone() for p in net.parameters()]

    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)

    loss_fn = torch.nn.CrossEntropyLoss()
    # Run on the initial network; the returned statistics are not used here.
    analyze_model(net, trainloader, testloader, loss_fn, config)
    _timestamp()
    try:
        net.load_state_dict(torch.load(_checkpoint_path('after'), map_location=device))
        print('Net loaded from file.')
    except Exception as e:  # the trained weights are required for the rest of the analysis
        print(repr(e))
        print('Could not find model data ... aborting ...')
        return
    _timestamp()
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    param_norm_after = _param_l2_norm(net)
    print(f'The L2 norm of the parameter vector is {param_norm_after:.2f}')

    # Distance between initial and trained parameters, computed under no_grad:
    # the trained parameters require grad, and .numpy() on a tensor that is part
    # of the autograd graph raises a RuntimeError.
    change_total = change_rel = change_nrmsum = 0.0
    with torch.no_grad():
        for p_init, p_trained in zip(net_init, net.parameters()):
            delta = p_init - p_trained
            change_total += delta.pow(2).sum()   # -> global L2 distance after sqrt
            change_rel += delta.pow(2).mean()    # per-tensor mean squared change
            change_nrmsum += delta.norm()        # sum of per-tensor L2 distances
    change_total = change_total.sqrt().cpu().numpy()
    change_rel = change_rel.sqrt().cpu().numpy()
    change_nrmsum = change_nrmsum.cpu().numpy()

    # Analyze results
    acc_train, acc_test, loss_train, loss_trainw, grd_train = analyze_model(
        net, trainloader, testloader, loss_fn, config)

    save_output(args.table_path,
                name='ntk_stats',
                width=config['width'],
                num_params=num_params,
                acc_train=acc_train,
                acc_test=acc_test,
                loss_train=loss_train,
                loss_trainw=loss_trainw,
                grd_train=grd_train,
                param_norm_before=param_norm_before,
                param_norm_after=param_norm_after,
                change_total=change_total,
                change_rel=change_rel,
                change_nrmsum=change_nrmsum)

    _timestamp()
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')
示例#12
0
 def __init__(self, config: MNClassifierConfig):
     """Pretrained ``mobilenet_v2`` with its head swapped for one of ``config.num_labels`` outputs."""
     super().__init__()
     backbone = mobilenet_v2(pretrained=True)
     # Only the classifier of this freshly built MobileNetV2 is kept.
     backbone.classifier = MobileNetV2(num_classes=config.num_labels).classifier
     self.model = backbone
示例#13
0
def main():
    """Check ntks in a single call.

    Builds the network selected by ``args.net`` (scaled by ``config['width']``).
    Initial weights are loaded from ``<config['path']>Cifar10_<net><width>_before.pth``;
    if that fails and ``args.dryrun`` is false, they are saved there instead.
    Trained weights are then loaded from ``..._after.pth``. When that load
    fails, the net is trained with ``dl.train`` and saved there (unless
    ``args.dryrun``). Results are written with ``save_output(..., name='ntk', ...)``.

    Relies on module-level ``args``, ``config`` and ``device``.

    NOTE(review): the final ``save_output`` call passes ``ntk_matrix_before_norm``,
    ``ntk_matrix_after_norm``, ``ntk_matrix_diff_norm``, ``ntk_matrix_rdiff_norm``,
    ``param_norm_before``, ``param_norm_after``, ``corr_coeff`` and ``corr_tom``,
    none of which is assigned in this function. Unless they are module-level
    globals this raises NameError; the code computing them appears to be missing.
    """
    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.torch.cuda.device_count()}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    # NOTE(review): `testloader` is not used anywhere below in this function.
    trainloader, testloader = dl.get_loaders('CIFAR10', config['batch_size'], augmentations=False, shuffle=False)

    # Network selection; every variant is scaled by config['width'].
    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2], widen_factor=config['width'])
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16, widen_factor=config['width'], dropout_rate=0.0, num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(OrderedDict([
                                 ('flatten', torch.nn.Flatten()),
                                 ('linear0', torch.nn.Linear(3072, config['width'])),
                                 ('relu0', torch.nn.ReLU()),
                                 ('linear1', torch.nn.Linear(config['width'], config['width'])),
                                 ('relu1', torch.nn.ReLU()),
                                 ('linear2', torch.nn.Linear(config['width'], config['width'])),
                                 ('relu2', torch.nn.ReLU()),
                                 ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(OrderedDict([
                                 ('flatten', torch.nn.Flatten()),
                                 ('linear0', torch.nn.Linear(3072, config['width'])),
                                 ('relu0', torch.nn.ReLU()),
                                 ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10, width_mult=config['width'], round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
        # NOTE(review): the isinstance filter drops every 'M' (max-pool) entry
        # from the scaled config — confirm that is intended.
        cfg = [c * config['width'] for c in cfg_base if isinstance(c, int)]
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * config['width'], 4096)
    elif args.net == 'ConvNet':
        net = torch.nn.Sequential(OrderedDict([
                                  ('conv0', torch.nn.Conv2d(3, 1 * config['width'], kernel_size=3, padding=1)),
                                  ('relu0', torch.nn.ReLU()),
                                  # ('pool0', torch.nn.MaxPool2d(3)),
                                  ('conv1', torch.nn.Conv2d(1 * config['width'],
                                                            2 * config['width'], kernel_size=3, padding=1)),
                                  ('relu1', torch.nn.ReLU()),
                                  #  ('pool1', torch.nn.MaxPool2d(3)),
                                  ('conv2', torch.nn.Conv2d(2 * config['width'],
                                                            2 * config['width'], kernel_size=3, padding=1)),
                                  ('relu2', torch.nn.ReLU()),
                                  # ('pool2', torch.nn.MaxPool2d(3)),
                                  ('conv3', torch.nn.Conv2d(2 * config['width'],
                                                            4 * config['width'], kernel_size=3, padding=1)),
                                  ('relu3', torch.nn.ReLU()),
                                  ('pool3', torch.nn.MaxPool2d(3)),
                                  ('conv4', torch.nn.Conv2d(4 * config['width'],
                                                            4 * config['width'], kernel_size=3, padding=1)),
                                  ('relu4', torch.nn.ReLU()),
                                  ('pool4', torch.nn.MaxPool2d(3)),
                                  ('flatten', torch.nn.Flatten()),
                                  ('linear', torch.nn.Linear(36 * config['width'], 10))
                                  ]))
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    # Reuse a stored initialisation when one loads; otherwise store this one.
    try:
        net.load_state_dict(torch.load(config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_before.pth',
                                       map_location=device))
        print('Initialized net loaded from file.')
    except Exception as e:  # :>
        path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_before.pth'
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Initialized net saved to file.')
        else:
            print(f'Would save to {path}')

    num_params = sum([p.numel() for p in net.parameters()])
    print(f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
          f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')

    # Start training
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)

    optimizer = torch.optim.SGD(net.parameters(), lr=config['lr'], momentum=0.9, weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    # Load trained weights if available; otherwise train now and store the result.
    try:
        net.load_state_dict(torch.load(config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_after.pth',
                                       map_location=device))
        print('Net loaded from file.')
    except Exception as e:  # :>
        path = config['path'] + 'Cifar10_' + args.net + str(config["width"]) + '_after.pth'
        dl.train(net, optimizer, scheduler, loss_fn, trainloader, config, path=None, dryrun=args.dryrun)
        if not args.dryrun:
            torch.save(net.state_dict(), path)
            print('Net saved to file.')
        else:
            print(f'Would save to {path}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    # NOTE(review): see the docstring — most of these keyword values are never
    # assigned in this function.
    save_output(args.table_path, name='ntk', width=config['width'], num_params=num_params,
                before_norm=ntk_matrix_before_norm, after_norm=ntk_matrix_after_norm,
                diff_norm=ntk_matrix_diff_norm, rdiff_norm=ntk_matrix_rdiff_norm,
                param_norm_before=param_norm_before, param_norm_after=param_norm_after,
                corr_coeff=corr_coeff, corr_tom=corr_tom)

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')
示例#14
0
# Deployment settings, overridable through environment variables.
S3_BUCKET = os.environ.get('S3_BUCKET', 'models-eva')
MODEL_PATH = os.environ.get('MODEL_PATH', 'mobilenet_v2-b0353104.pth')

print("Downloading model")
# NOTE(review): this URL is fixed and does not use S3_BUCKET or MODEL_PATH, so
# overriding those variables does not change what gets downloaded.
model_full_path = 'https://models-eva.s3.ap-south-1.amazonaws.com/mobilenet_v2-b0353104.pth'
s3 = boto3.client("s3")

# NOTE(review): no except/finally for this `try:` is visible; the snippet
# appears to be cut off here, so its error handling cannot be reviewed.
try:
    # Only fetch the weights when MODEL_PATH is not an existing local file.
    # NOTE(review): nothing active in this branch writes to MODEL_PATH, so this
    # check stays true on every run unless code outside this view saves the file.
    if os.path.isfile(MODEL_PATH) != True:
        # Disabled alternative: fetch the checkpoint directly from S3.
        #obj = s3.get_object(Bucket=S3_BUCKET, Key=MODEL_PATH)
        #s3.Bucket(S3_BUCKET).download_file(MODEL_PATH, location)
        #print("Creating Bytestream")
        #bytestream = io.BytesIO(obj['Body'].read())
        #decodedd = decoder.b64decode(bytestream)
        model = MobileNetV2()
        # Download the checkpoint at model_full_path; '/tmp' is passed as the second
        # positional argument (presumably torch.hub's model_dir — confirm).
        state_dict = load_state_dict_from_url(model_full_path, '/tmp', progress=True)
        #print(type(model))

        #mm = mobilenet_v2(False)
        model.load_state_dict(state_dict)
        #m = torch.jit.script(mm)

        # Save to file (disabled: TorchScript export below is commented out)
        #torch.jit.save(m, '/tmp/mobilenet_v2-b0353104.pth')
     
        # This line is equivalent to the previous
        #m.save("scriptmodule.pt")
        #with open('/tmp/scriptmodule.pt', 'wb') as f:
        #    f.write(bytestream.read())
        
示例#15
0
def _mobilenetv2_train_top_layers_only(model: MobileNetV2, num_top_layers: int = 2):
    """Freeze ``model`` in place except for its last ``num_top_layers`` feature blocks.

    With the default of 2 and the stock torchvision layout (19 entries in
    ``model.features``), this unfreezes ``features[17]`` and ``features[18]``.
    Using negative indexing also keeps it working for MobileNetV2 variants with
    fewer feature blocks.

    NOTE(review): ``model.classifier`` is left frozen as well. That is fine if
    the caller replaces the head afterwards; otherwise confirm it is intended.

    Args:
        model: The MobileNetV2 to modify in place.
        num_top_layers: Number of trailing ``model.features`` entries to leave
            trainable; 0 freezes everything.

    Returns:
        The same ``model`` object, for chaining.

    Raises:
        TypeError: If ``model`` is not a MobileNetV2.
        ValueError: If ``num_top_layers`` is negative.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not isinstance(model, MobileNetV2):
        raise TypeError(f'expected a MobileNetV2, got {type(model).__name__}')
    if num_top_layers < 0:
        raise ValueError(f'num_top_layers must be >= 0, got {num_top_layers}')
    model.requires_grad_(False)
    if num_top_layers:  # features[-0:] would select every block, so skip 0
        model.features[-num_top_layers:].requires_grad_(True)
    return model
def main():
    """Measure how the empirical NTK of a CIFAR10 network changes with training.

    Steps:
      1. Build the network selected by ``args.net`` at width ``config['width']``.
      2. Load its initial weights from ``<path>Cifar10_<net><width>_before.pth``.
         If loading fails for any reason, save the current initialization
         there instead. Under ``args.dryrun`` only a message is printed.
      3. Sample an NTK matrix with ``batch_wise_ntk``. Record its norm and the
         parameter L2 norm.
      4. When ``args.pdist`` is set, compute feature distance, cosine and
         inner-product maps before training.
      5. Load trained weights from ``..._after.pth``. Otherwise train with
         ``dl.train`` and save them, unless ``args.dryrun`` is set.
      6. Sample the NTK again. Report norms, absolute and relative
         differences, an element-wise correlation score and a similarity
         score.
      7. Record scalar results with ``save_output``. When ``args.pdist`` is
         set, repeat the feature-map comparison after training and record
         those results too.

    Heatmaps are written under ``config['path']``. The function reads the
    module-level ``args``, ``config`` and ``device``.
    """
    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(
        f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()}'
    )
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    trainloader, testloader = dl.get_loaders('CIFAR10',
                                             config['batch_size'],
                                             augmentations=False,
                                             shuffle=False)

    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2],
                         widen_factor=config['width'])
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16,
                          widen_factor=config['width'],
                          dropout_rate=0.0,
                          num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(
            OrderedDict([
                ('flatten', torch.nn.Flatten()),
                ('linear0', torch.nn.Linear(3072, config['width'])),
                ('relu0', torch.nn.ReLU()),
                ('linear1', torch.nn.Linear(config['width'], config['width'])),
                ('relu1', torch.nn.ReLU()),
                ('linear2', torch.nn.Linear(config['width'], config['width'])),
                ('relu2', torch.nn.ReLU()),
                ('linear3', torch.nn.Linear(config['width'], 10))
            ]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(
            OrderedDict([('flatten', torch.nn.Flatten()),
                         ('linear0', torch.nn.Linear(3072, config['width'])),
                         ('relu0', torch.nn.ReLU()),
                         ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10,
                          width_mult=config['width'],
                          round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [
            64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'
        ]
        # Only the integer (conv channel) entries are kept; the 'M' entries
        # are dropped by the isinstance filter.
        cfg = [c * config['width'] for c in cfg_base if isinstance(c, int)]
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * config['width'],
                                            4096)
    elif args.net == 'ConvNet':
        net = torch.nn.Sequential(
            OrderedDict([
                ('conv0',
                 torch.nn.Conv2d(3,
                                 1 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu0', torch.nn.ReLU()),
                # ('pool0', torch.nn.MaxPool2d(3)),
                ('conv1',
                 torch.nn.Conv2d(1 * config['width'],
                                 2 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu1', torch.nn.ReLU()),
                #  ('pool1', torch.nn.MaxPool2d(3)),
                ('conv2',
                 torch.nn.Conv2d(2 * config['width'],
                                 2 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu2', torch.nn.ReLU()),
                # ('pool2', torch.nn.MaxPool2d(3)),
                ('conv3',
                 torch.nn.Conv2d(2 * config['width'],
                                 4 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu3', torch.nn.ReLU()),
                ('pool3', torch.nn.MaxPool2d(3)),
                ('conv4',
                 torch.nn.Conv2d(4 * config['width'],
                                 4 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu4', torch.nn.ReLU()),
                ('pool4', torch.nn.MaxPool2d(3)),
                ('flatten', torch.nn.Flatten()),
                ('linear', torch.nn.Linear(36 * config['width'], 10))
            ]))
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    def _save_heatmap(matrix, tag):
        """Save ``matrix`` as an image to ``<path><net><width>_CIFAR_NTK_<tag>.png``."""
        plt.imshow(matrix)
        plt.savefig(config['path'] +
                    f'{args.net}{config["width"]}_CIFAR_NTK_{tag}.png',
                    bbox_inches='tight',
                    dpi=1200)
        # Close the figure so the next heatmap starts on fresh axes rather
        # than being drawn on top of this one.
        plt.close()

    checkpoint_stem = (config['path'] + 'Cifar10_' + args.net +
                       str(config["width"]))
    before_path = checkpoint_stem + '_before.pth'
    try:
        net.load_state_dict(torch.load(before_path, map_location=device))
        print('Initialized net loaded from file.')
    except Exception:  # :> any load failure: persist this initialization
        if not args.dryrun:
            torch.save(net.state_dict(), before_path)
            print('Initialized net saved to file.')
        else:
            print(f'Would save to {before_path}')

    num_params = sum(p.numel() for p in net.parameters())
    print(
        f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
        f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')

    ntk_matrix_before = batch_wise_ntk(net,
                                       trainloader,
                                       samplesize=args.sampling)
    _save_heatmap(ntk_matrix_before, 'BEFORE')
    ntk_matrix_before_norm = np.linalg.norm(ntk_matrix_before.flatten())
    print(
        f'The total norm of the NTK sample before training is {ntk_matrix_before_norm:.2f}'
    )
    param_norm_before = np.sqrt(
        np.sum(
            [p.pow(2).sum().detach().cpu().numpy() for p in net.parameters()]))
    print(f'The L2 norm of the parameter vector is {param_norm_before:.2f}')

    if args.pdist:
        pdist_init, cos_init, prod_init = batch_feature_correlations(
            trainloader)
        pdist_init_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in pdist_init])
        cos_init_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in cos_init])
        prod_init_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in prod_init])
        print(
            f'The total norm of feature distances before training is {pdist_init_norm:.2f}'
        )
        print(
            f'The total norm of feature cosine similarity before training is {cos_init_norm:.2f}'
        )
        print(
            f'The total norm of feature inner product before training is {prod_init_norm:.2f}'
        )

        save_plot(pdist_init, trainloader, name='pdist_before_training')
        save_plot(cos_init, trainloader, name='cosine_before_training')
        save_plot(prod_init, trainloader, name='prod_before_training')

    # Start training. Move the net to the configured device again in case the
    # feature-correlation pass above relocated it.
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)

    optimizer = torch.optim.SGD(net.parameters(),
                                lr=config['lr'],
                                momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[60, 120, 160],
                                                     gamma=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    after_path = checkpoint_stem + '_after.pth'
    try:
        net.load_state_dict(torch.load(after_path, map_location=device))
        print('Net loaded from file.')
    except Exception:  # :> no usable trained checkpoint: train from scratch
        dl.train(net,
                 optimizer,
                 scheduler,
                 loss_fn,
                 trainloader,
                 config,
                 path=None,
                 dryrun=args.dryrun)
        if not args.dryrun:
            torch.save(net.state_dict(), after_path)
            print('Net saved to file.')
        else:
            print(f'Would save to {after_path}')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    param_norm_after = np.sqrt(
        np.sum(
            [p.pow(2).sum().detach().cpu().numpy() for p in net.parameters()]))
    print(f'The L2 norm of the parameter vector is {param_norm_after:.2f}')

    ntk_matrix_after = batch_wise_ntk(net,
                                      trainloader,
                                      samplesize=args.sampling)
    _save_heatmap(ntk_matrix_after, 'AFTER')
    ntk_matrix_after_norm = np.linalg.norm(ntk_matrix_after.flatten())
    print(
        f'The total norm of the NTK sample after training is {ntk_matrix_after_norm:.2f}'
    )

    ntk_matrix_diff = np.abs(ntk_matrix_before - ntk_matrix_after)
    _save_heatmap(ntk_matrix_diff, 'DIFF')
    ntk_matrix_diff_norm = np.linalg.norm(ntk_matrix_diff.flatten())
    print(
        f'The total norm of the NTK sample diff is {ntk_matrix_diff_norm:.2f}')

    # Element-wise relative change; the 1e-4 guards against division by zero.
    ntk_matrix_rdiff = np.abs(ntk_matrix_before - ntk_matrix_after) / (
        np.abs(ntk_matrix_before) + 1e-4)
    _save_heatmap(ntk_matrix_rdiff, 'RDIFF')
    ntk_matrix_rdiff_norm = np.linalg.norm(ntk_matrix_rdiff.flatten())
    print(
        f'The total norm of the NTK sample relative diff is {ntk_matrix_rdiff_norm:.2f}'
    )

    # Element-wise standardized product; its mean is the Pearson correlation
    # of the two matrices' entries (population std).
    n1_mean = np.mean(ntk_matrix_before)
    n2_mean = np.mean(ntk_matrix_after)
    matrix_corr = (ntk_matrix_before - n1_mean) * (ntk_matrix_after - n2_mean) / \
        np.std(ntk_matrix_before) / np.std(ntk_matrix_after)
    _save_heatmap(matrix_corr, 'CORR')
    corr_coeff = np.mean(matrix_corr)
    print(
        f'The Correlation coefficient of the NTK sample before and after training is {corr_coeff:.2f}'
    )

    # Element-wise contributions to the cosine similarity of the flattened
    # matrices; their sum is that similarity.
    matrix_sim = (ntk_matrix_before * ntk_matrix_after) / \
        np.sqrt(np.sum(ntk_matrix_before**2) * np.sum(ntk_matrix_after**2))
    # Plot the similarity map itself under its own name (previously this
    # re-plotted matrix_corr and overwrote the CORR image).
    _save_heatmap(matrix_sim, 'SIM')
    corr_tom = np.sum(matrix_sim)
    print(
        f'The Similarity coefficient of the NTK sample before and after training is {corr_tom:.2f}'
    )

    save_output(args.table_path,
                name='ntk',
                width=config['width'],
                num_params=num_params,
                before_norm=ntk_matrix_before_norm,
                after_norm=ntk_matrix_after_norm,
                diff_norm=ntk_matrix_diff_norm,
                rdiff_norm=ntk_matrix_rdiff_norm,
                param_norm_before=param_norm_before,
                param_norm_after=param_norm_after,
                corr_coeff=corr_coeff,
                corr_tom=corr_tom)

    if args.pdist:
        # Check feature maps after training
        pdist_after, cos_after, prod_after = batch_feature_correlations(
            trainloader)

        pdist_after_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in pdist_after])
        cos_after_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in cos_after])
        prod_after_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in prod_after])
        print(
            f'The total norm of feature distances after training is {pdist_after_norm:.2f}'
        )
        print(
            f'The total norm of feature cosine similarity after training is {cos_after_norm:.2f}'
        )
        print(
            f'The total norm of feature inner product after training is {prod_after_norm:.2f}'
        )

        save_plot(pdist_after, trainloader, name='pdist_after_training')
        save_plot(cos_after, trainloader, name='cosine_after_training')
        save_plot(prod_after, trainloader, name='prod_after_training')

        # Differences normalized by the mean pre-training map norm
        pdist_ndiff = [
            np.abs(co1 - co2) / pdist_init_norm
            for co1, co2 in zip(pdist_init, pdist_after)
        ]
        cos_ndiff = [
            np.abs(co1 - co2) / cos_init_norm
            for co1, co2 in zip(cos_init, cos_after)
        ]
        prod_ndiff = [
            np.abs(co1 - co2) / prod_init_norm
            for co1, co2 in zip(prod_init, prod_after)
        ]

        pdist_ndiff_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in pdist_ndiff])
        cos_ndiff_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in cos_ndiff])
        prod_ndiff_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in prod_ndiff])
        print(
            f'The total norm normalized diff of feature distances after training is {pdist_ndiff_norm:.2f}'
        )
        print(
            f'The total norm normalized diff of feature cosine similarity after training is {cos_ndiff_norm:.2f}'
        )
        print(
            f'The total norm normalized diff of feature inner product after training is {prod_ndiff_norm:.2f}'
        )

        save_plot(pdist_ndiff, trainloader, name='pdist_ndiff')
        save_plot(cos_ndiff, trainloader, name='cosine_ndiff')
        save_plot(prod_ndiff, trainloader, name='prod_ndiff')

        # Element-wise relative differences (1e-6 guards division by zero)
        pdist_rdiff = [
            np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
            for co1, co2 in zip(pdist_init, pdist_after)
        ]
        cos_rdiff = [
            np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
            for co1, co2 in zip(cos_init, cos_after)
        ]
        prod_rdiff = [
            np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
            for co1, co2 in zip(prod_init, prod_after)
        ]

        pdist_rdiff_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in pdist_rdiff])
        cos_rdiff_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in cos_rdiff])
        prod_rdiff_norm = np.mean(
            [np.linalg.norm(cm.flatten()) for cm in prod_rdiff])
        print(
            f'The total norm relative diff of feature distances after training is {pdist_rdiff_norm:.2f}'
        )
        print(
            f'The total norm relative diff of feature cosine similarity after training is {cos_rdiff_norm:.2f}'
        )
        print(
            f'The total norm relative diff of feature inner product after training is {prod_rdiff_norm:.2f}'
        )

        save_plot(pdist_rdiff, trainloader, name='pdist_rdiff')
        save_plot(cos_rdiff, trainloader, name='cosine_rdiff')
        save_plot(prod_rdiff, trainloader, name='prod_rdiff')

        # Each metric family reports its own values (previously the cos_* and
        # prod_* init/after/ndiff entries were filled with pdist_* values).
        save_output(args.table_path,
                    'pdist',
                    width=config['width'],
                    num_params=num_params,
                    pdist_init_norm=pdist_init_norm,
                    pdist_after_norm=pdist_after_norm,
                    pdist_ndiff_norm=pdist_ndiff_norm,
                    pdist_rdiff_norm=pdist_rdiff_norm,
                    cos_init_norm=cos_init_norm,
                    cos_after_norm=cos_after_norm,
                    cos_ndiff_norm=cos_ndiff_norm,
                    cos_rdiff_norm=cos_rdiff_norm,
                    prod_init_norm=prod_init_norm,
                    prod_after_norm=prod_after_norm,
                    prod_ndiff_norm=prod_ndiff_norm,
                    prod_rdiff_norm=prod_rdiff_norm)

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')
def main():
    """Compare feature-space similarity maps of a CIFAR10 network before/after training.

    Builds the network selected by ``args.net`` at width ``config['width']``.
    A forward hook on the network's final layer records three per-batch
    matrices from the flattened inputs to that layer: pairwise L2 distances,
    cosine similarities and mean inner products. These maps are collected
    before training and after training. Training is skipped when
    ``<path>Cifar10_<net><width>.pth`` loads successfully. The function then
    reports norms of the maps, of their normalized differences and of their
    relative differences, and records the scalars via ``save_output``. Plots
    are written via ``save_plot``. The raw maps are saved to a
    ``..._rawmaps.pth`` file unless ``args.dryrun`` is set.

    The function reads the module-level ``args``, ``config`` and ``device``.
    """
    print(f'RUNNING NTK EXPERIMENT WITH NET {args.net} and WIDTH {args.width}')
    print(
        f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()}'
    )
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))

    trainloader, testloader = dl.get_loaders('CIFAR10',
                                             config['batch_size'],
                                             augmentations=False,
                                             shuffle=False)

    if args.net == 'ResNet':
        net = WideResNet(BasicBlock, [2, 2, 2, 2],
                         widen_factor=config['width'])
    elif args.net == 'WideResNet':  # meliketoy wideresnet variant
        net = Wide_ResNet(depth=16,
                          widen_factor=config['width'],
                          dropout_rate=0.0,
                          num_classes=10)
    elif args.net == 'MLP':
        net = torch.nn.Sequential(
            OrderedDict([
                ('flatten', torch.nn.Flatten()),
                ('linear0', torch.nn.Linear(3072, config['width'])),
                ('relu0', torch.nn.ReLU()),
                ('linear1', torch.nn.Linear(config['width'], config['width'])),
                ('relu1', torch.nn.ReLU()),
                ('linear2', torch.nn.Linear(config['width'], config['width'])),
                ('relu2', torch.nn.ReLU()),
                ('linear3', torch.nn.Linear(config['width'], 10))
            ]))
    elif args.net == 'TwoLP':
        net = torch.nn.Sequential(
            OrderedDict([('flatten', torch.nn.Flatten()),
                         ('linear0', torch.nn.Linear(3072, config['width'])),
                         ('relu0', torch.nn.ReLU()),
                         ('linear3', torch.nn.Linear(config['width'], 10))]))
    elif args.net == 'MobileNetV2':
        net = MobileNetV2(num_classes=10,
                          width_mult=config['width'],
                          round_nearest=4)
    elif args.net == 'VGG':
        cfg_base = [
            64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'
        ]
        # Only the integer (conv channel) entries are kept; the 'M' entries
        # are dropped by the isinstance filter.
        cfg = [c * config['width'] for c in cfg_base if isinstance(c, int)]
        print(cfg)
        net = VGG(make_layers(cfg), num_classes=10)
        net.classifier[0] = torch.nn.Linear(512 * 7 * 7 * config['width'],
                                            4096)
    elif args.net == 'ConvNet':
        net = torch.nn.Sequential(
            OrderedDict([
                ('conv0',
                 torch.nn.Conv2d(3,
                                 1 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu0', torch.nn.ReLU()),
                # ('pool0', torch.nn.MaxPool2d(3)),
                ('conv1',
                 torch.nn.Conv2d(1 * config['width'],
                                 2 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu1', torch.nn.ReLU()),
                #  ('pool1', torch.nn.MaxPool2d(3)),
                ('conv2',
                 torch.nn.Conv2d(2 * config['width'],
                                 2 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu2', torch.nn.ReLU()),
                # ('pool2', torch.nn.MaxPool2d(3)),
                ('conv3',
                 torch.nn.Conv2d(2 * config['width'],
                                 4 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu3', torch.nn.ReLU()),
                ('pool3', torch.nn.MaxPool2d(3)),
                ('conv4',
                 torch.nn.Conv2d(4 * config['width'],
                                 4 * config['width'],
                                 kernel_size=3,
                                 padding=1)),
                ('relu4', torch.nn.ReLU()),
                ('pool4', torch.nn.MaxPool2d(3)),
                ('flatten', torch.nn.Flatten()),
                ('linear', torch.nn.Linear(36 * config['width'], 10))
            ]))
    else:
        raise ValueError('Invalid network specified.')
    net.to(**config['setup'])

    num_params = sum(p.numel() for p in net.parameters())
    print(
        f'Number of params: {num_params} - number of data points: {len(trainloader.dataset)} '
        f'- ratio : {len(trainloader.dataset) / num_params * 100:.2f}%')

    def batch_feature_correlations(dataloader, device=torch.device('cpu')):
        """Collect per-batch feature distance, cosine and inner-product maps.

        Puts ``net`` (read from the enclosing scope at call time) into eval
        mode on ``device``. Hooks the network's final layer and runs every
        batch of ``dataloader`` through it without gradient tracking. Under
        ``args.dryrun`` only the first batch is processed.

        Returns:
            Three lists with one (B, B) numpy array per batch, where B is that
            batch's size: pairwise L2 distances, cosine similarities, and mean
            element-wise products of the flattened inputs to the hooked layer.
        """
        net.eval()
        net.to(device)
        dist_maps = list()
        cosine_maps = list()
        prod_maps = list()
        hooks = []

        def batch_wise_feature_correlation(module, layer_inputs, output):
            # Take B from the tensor itself instead of dataloader.batch_size,
            # so a smaller final batch is not reshaped incorrectly.
            batch_size = layer_inputs[0].shape[0]
            feat_vec = layer_inputs[0].detach().reshape(batch_size, -1)
            dist_maps.append(
                torch.cdist(feat_vec, feat_vec, 2).detach().cpu().numpy())

            cosine_map = np.empty((batch_size, batch_size))
            prod_map = np.empty((batch_size, batch_size))
            for row in range(batch_size):
                cosine_map[row, :] = torch.nn.functional.cosine_similarity(
                    feat_vec[row:row + 1, :], feat_vec, dim=1,
                    eps=1e-8).detach().cpu().numpy()
                prod_map[row, :] = torch.mean(feat_vec[row:row + 1, :] *
                                              feat_vec,
                                              dim=1).detach().cpu().numpy()
            cosine_maps.append(cosine_map)
            prod_maps.append(prod_map)

        # Pick the final layer by architecture; unwrap DataParallel first so
        # the same attribute lookup works in both cases.
        model = net.module if isinstance(net,
                                         torch.nn.DataParallel) else net
        if args.net in ['MLP', 'TwoLP']:
            target_layer = model.linear3
        elif args.net in ['VGG', 'MobileNetV2']:
            target_layer = model.classifier
        else:
            target_layer = model.linear
        hooks.append(
            target_layer.register_forward_hook(batch_wise_feature_correlation))

        try:
            # Outputs are discarded; only the hook's side effects matter, so
            # there is no need to build an autograd graph.
            with torch.no_grad():
                for inputs, _ in dataloader:
                    net(inputs.to(device))
                    if args.dryrun:
                        break
        finally:
            # Always detach the hooks, otherwise a failure here would leave
            # them firing during later training forward passes.
            for hook in hooks:
                hook.remove()

        return dist_maps, cosine_maps, prod_maps

    pdist_init, cos_init, prod_init = batch_feature_correlations(trainloader)
    pdist_init_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in pdist_init])
    cos_init_norm = np.mean([np.linalg.norm(cm.flatten()) for cm in cos_init])
    prod_init_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in prod_init])
    print(
        f'The total norm of feature distances before training is {pdist_init_norm:.2f}'
    )
    print(
        f'The total norm of feature cosine similarity before training is {cos_init_norm:.2f}'
    )
    print(
        f'The total norm of feature inner product before training is {prod_init_norm:.2f}'
    )

    save_plot(pdist_init, trainloader, name='pdist_before_training')
    save_plot(cos_init, trainloader, name='cosine_before_training')
    save_plot(prod_init, trainloader, name='prod_before_training')

    # Start training; the correlation pass above moved the net to the CPU.
    net.to(**config['setup'])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)

    optimizer = torch.optim.SGD(net.parameters(),
                                lr=config['lr'],
                                momentum=0.9,
                                weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[60, 120, 160],
                                                     gamma=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()

    checkpoint_stem = (config['path'] + 'Cifar10_' + args.net +
                       str(config["width"]))
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    trained_path = checkpoint_stem + '.pth'
    try:
        net.load_state_dict(torch.load(trained_path, map_location=device))
        print('Net loaded from file.')
    except Exception:  # :> no usable trained checkpoint: train from scratch
        dl.train(net,
                 optimizer,
                 scheduler,
                 loss_fn,
                 trainloader,
                 config,
                 path=None,
                 dryrun=args.dryrun)
        if not args.dryrun:
            torch.save(net.state_dict(), trained_path)
            print('Net saved to file.')
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    if isinstance(net, torch.nn.DataParallel):
        net = net.module

    # Check feature maps after training
    pdist_after, cos_after, prod_after = batch_feature_correlations(
        trainloader)

    pdist_after_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in pdist_after])
    cos_after_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in cos_after])
    prod_after_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in prod_after])
    print(
        f'The total norm of feature distances after training is {pdist_after_norm:.2f}'
    )
    print(
        f'The total norm of feature cosine similarity after training is {cos_after_norm:.2f}'
    )
    print(
        f'The total norm of feature inner product after training is {prod_after_norm:.2f}'
    )

    save_plot(pdist_after, trainloader, name='pdist_after_training')
    save_plot(cos_after, trainloader, name='cosine_after_training')
    save_plot(prod_after, trainloader, name='prod_after_training')

    # Differences normalized by the mean pre-training map norm
    pdist_ndiff = [
        np.abs(co1 - co2) / pdist_init_norm
        for co1, co2 in zip(pdist_init, pdist_after)
    ]
    cos_ndiff = [
        np.abs(co1 - co2) / cos_init_norm
        for co1, co2 in zip(cos_init, cos_after)
    ]
    prod_ndiff = [
        np.abs(co1 - co2) / prod_init_norm
        for co1, co2 in zip(prod_init, prod_after)
    ]

    pdist_ndiff_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in pdist_ndiff])
    cos_ndiff_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in cos_ndiff])
    prod_ndiff_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in prod_ndiff])
    print(
        f'The total norm normalized diff of feature distances after training is {pdist_ndiff_norm:.2f}'
    )
    print(
        f'The total norm normalized diff of feature cosine similarity after training is {cos_ndiff_norm:.2f}'
    )
    print(
        f'The total norm normalized diff of feature inner product after training is {prod_ndiff_norm:.2f}'
    )

    save_plot(pdist_ndiff, trainloader, name='pdist_ndiff')
    save_plot(cos_ndiff, trainloader, name='cosine_ndiff')
    save_plot(prod_ndiff, trainloader, name='prod_ndiff')

    # Element-wise relative differences (1e-6 guards division by zero)
    pdist_rdiff = [
        np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
        for co1, co2 in zip(pdist_init, pdist_after)
    ]
    cos_rdiff = [
        np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
        for co1, co2 in zip(cos_init, cos_after)
    ]
    prod_rdiff = [
        np.abs(co1 - co2) / (np.abs(co1) + 1e-6)
        for co1, co2 in zip(prod_init, prod_after)
    ]

    pdist_rdiff_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in pdist_rdiff])
    cos_rdiff_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in cos_rdiff])
    prod_rdiff_norm = np.mean(
        [np.linalg.norm(cm.flatten()) for cm in prod_rdiff])
    print(
        f'The total norm relative diff of feature distances after training is {pdist_rdiff_norm:.2f}'
    )
    print(
        f'The total norm relative diff of feature cosine similarity after training is {cos_rdiff_norm:.2f}'
    )
    print(
        f'The total norm relative diff of feature inner product after training is {prod_rdiff_norm:.2f}'
    )

    save_plot(pdist_rdiff, trainloader, name='pdist_rdiff')
    save_plot(cos_rdiff, trainloader, name='cosine_rdiff')
    save_plot(prod_rdiff, trainloader, name='prod_rdiff')

    # Each metric family reports its own values (previously the cos_* and
    # prod_* init/after/ndiff entries were filled with pdist_* values).
    save_output(args.table_path,
                width=config['width'],
                num_params=num_params,
                pdist_init_norm=pdist_init_norm,
                pdist_after_norm=pdist_after_norm,
                pdist_ndiff_norm=pdist_ndiff_norm,
                pdist_rdiff_norm=pdist_rdiff_norm,
                cos_init_norm=cos_init_norm,
                cos_after_norm=cos_after_norm,
                cos_ndiff_norm=cos_ndiff_norm,
                cos_rdiff_norm=cos_rdiff_norm,
                prod_init_norm=prod_init_norm,
                prod_after_norm=prod_after_norm,
                prod_ndiff_norm=prod_ndiff_norm,
                prod_rdiff_norm=prod_rdiff_norm)

    # Save raw data
    raw_pkg = dict(pdist_init=pdist_init,
                   cos_init=cos_init,
                   prod_init=prod_init,
                   pdist_after=pdist_after,
                   cos_after=cos_after,
                   prod_after=prod_after,
                   pdist_ndiff=pdist_ndiff,
                   cos_ndiff=cos_ndiff,
                   prod_ndiff=prod_ndiff,
                   pdist_rdiff=pdist_rdiff,
                   cos_rdiff=cos_rdiff,
                   prod_rdiff=prod_rdiff)
    raw_path = checkpoint_stem + '_rawmaps.pth'
    # Respect dryrun here too, matching the checkpoint save above.
    if not args.dryrun:
        torch.save(raw_pkg, raw_path)
    else:
        print(f'Would save to {raw_path}')

    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('-----------------------------------------------------')
    print('Job finished.----------------------------------------')
    print('-----------------------------------------------------')