dl_dw = torch.autograd.grad(loss, model.parameters(), create_graph=True)
            d2l_dw2 = torch.autograd.grad(dl_dw, model.parameters(), ones, create_graph=True)
            sum_dl2_dw2 = sum([torch.abs(g).sum() if args.abs else g.sum() for g in d2l_dw2])
            loss = loss + args.sf * sum_dl2_dw2

        loss.backward()
        optimiser.step()
        print(loss)
        loss_lst.append(loss)
        Hess = FullHessian(crit='MSELoss',
                            loader=training_loader,
                            device=device,
                            model=model,
                            double=False,
                            num_classes=10,
                            hessian_type='Hessian',
                            init_poly_deg=64,
                            poly_deg=128,
                            spectrum_margin=0.05,
                            poly_points=1024,
                            SSI_iters=128
                            )

        lmin, lmax = Hess.compute_lb_ub()
        lmax_lst.append(lmax)
        lmin_lst.append(lmin)

if args.num_epochs > args.scaling_epoch:
    f=osp.join(ckpt_dir , 'scaling_lmax_lmin.npz')
    f_loss = osp.join(images_dir,'scaling_loss.png')    
    f_lmax = osp.join(images_dir,'scaling_lmax.png')
                                map_location=lambda storage, loc: storage)
        if 'model' in state_dict.keys():
            state_dict = state_dict['model']
        model.load_state_dict(state_dict, strict=True)
        model = model.to(device)

        model = model.eval()
        logger.info('weights loaded from path: {}'.format(weights_path))
        logger.info('for epoch: {}'.format(epoch_number))

        Hess = FullHessian(crit='CrossEntropyLoss',
                           loader=loader,
                           device=device,
                           model=model,
                           num_classes=C,
                           hessian_type='Hessian',
                           init_poly_deg=64,
                           poly_deg=128,
                           spectrum_margin=0.05,
                           poly_points=1024,
                           SSI_iters=128)

        lmin, lmax = Hess.compute_lb_ub()

        lmax_list.append((epoch_number, lmax))
        lmin_list.append((epoch_number, lmin))

        logger.info(
            'computation finished for epoch number: {}'.format(epoch_number))
        print(lmax)
    lmax_path = osp.join(ckpt_dir, 'full_lambda_max.pkl')
        examples_per_class=args.examples_per_class,
        epc_seed=row.epc_seed,
        root=osp.join(args.dataset_root, args.dataset),
        train=True,
        transform=transform,
        download=True)

    loader = DataLoader(dataset=subset_dataset,
                        drop_last=False,
                        batch_size=args.batch_size)

    logger.info('Starting G decomposition...')
    res = FullHessian(
        crit='CrossEntropyLoss',
        loader=loader,
        device=device,
        model=model,
        num_classes=row.num_classes,
        hessian_type='G',
    ).compute_G_decomp()

    logger.info('G decomposition finished...')
    logger.info('Starting G eigenspectrum computation')
    Hess = FullHessian(
        crit='CrossEntropyLoss',
        loader=loader,
        device=device,
        model=model,
        num_classes=row.num_classes,
        hessian_type='G',
        init_poly_deg=
        64,  # number of iterations used to compute maximal/minimal eigenvalue
Exemplo n.º 4
0
    for i in range(args.num_models):
        model_weights_path = osp.join(ckpt_dir,
                                      f'model_weights_epochs_{i}.pth')
        print(f'Processing model number {i}')
        state_dict = torch.load(model_weights_path, device)
        model.load_state_dict(state_dict, strict=True)
        model = model.to(device)
        C = 10
        Hess = FullHessian(
            crit='CrossEntropyLoss',
            loader=loader,
            device=device,
            model=model,
            num_classes=C,
            hessian_type='Hessian',
            init_poly_deg=
            64,  # number of iterations used to compute maximal/minimal eigenvalue
            poly_deg=
            128,  # the higher the parameter the better the approximation
            spectrum_margin=0.05,
            poly_points=1024,  # number of points in spectrum approximation
            SSI_iters=128,  # iterations of subspace iterations
        )

        # Exact outliers computation using subspace iteration
        eigvecs_cur, eigvals_cur, _ = Hess.SubspaceIteration()

        # Flatten out eigvecs_cur as it is a list
        top_subspace = torch.zeros(0, device=device)
        for _ in range(len(eigvecs_cur)):
            b = torch.zeros(0, device=device)
                                        download=True
                                        )

    loader = DataLoader(dataset=subset_dataset,
                        drop_last=False,
                        batch_size=args.batch_size)

    C = row.num_classes
    logger.info('Starting hessiandecomposition...')

    Hess = FullHessian(crit='CrossEntropyLoss',
                        loader=loader,
                        device=device,
                        model=model,
                        num_classes=C,
                        hessian_type='Hessian',
                        init_poly_deg=64,
                        poly_deg=128,
                        spectrum_margin=0.05,
                        poly_points=1024,
                        SSI_iters=128
                        )

    Hess_eigval, \
    Hess_eigval_density = Hess.LanczosLoop(denormalize=True)     # the higher the parameter the better the approximation

    H = FullHessian(crit='CrossEntropyLoss',
                        loader=loader,
                        device=device,
                        model=model,
                        num_classes=C,
                        hessian_type='H',
Exemplo n.º 6
0
                                        epc_seed=row.epc_seed,
                                        root=osp.join(args.dataset_root, args.dataset),
                                        train=True,
                                        transform=transform,
                                        download=True
                                        )

    loader = DataLoader(dataset=subset_dataset,
                        drop_last=False,
                        batch_size=args.batch_size)
    C = row.num_classes
    logger.info('Starting G decomposition...')
    res = FullHessian(crit='CrossEntropyLoss',
              loader=loader,
              device=device,
              model=model,
              num_classes=C,
              hessian_type='G',
              ).compute_G_decomp()
    logger.info('G decomposed')
    logger.info('Starting TSNE...')
    tsne_embedded = TSNE(n_components=2,
                        metric='precomputed',
                        perplexity=C).fit_transform(res['dist'])
    logger.info('TSNE finished...')
    # t-SNE X
    delta_c_X = tsne_embedded[:C,0]
    delta_ccp_X = tsne_embedded[C:,0]

    # t-SNE Y
    delta_c_Y = tsne_embedded[:C,1]
Exemplo n.º 7
0
def train(model,
          optimizer,
          scheduler,
          dataloaders,
          criterion,
          device,
          num_epochs=100,
          args=None,
          dataset_sizes=None,
          images_dir=None,
          ckpt_dir=None):
    """Train ``model`` and record Hessian eigenspectra after every epoch.

    Each epoch: run one pass of gradient updates over
    ``dataloaders['train']``, step the LR scheduler, compute the
    full-network Hessian eigenspectrum plus one eigenspectrum per valid
    layer (saved under ``ckpt_dir``), then evaluate on the 'train' and
    'test' phases and update the loss/accuracy plots in ``images_dir``.

    Args:
        model: network to train (trained in place, left in train mode).
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch after the
            optimizer steps (the order required by recent PyTorch).
        dataloaders: dict with 'train' and 'test' DataLoaders.
        criterion: loss callable, e.g. ``nn.CrossEntropyLoss()``.
        device: torch device batches are moved to.
        num_epochs: number of training epochs.
        args: namespace providing ``dataset``, ``dataset_root``,
            ``examples_per_class`` and ``batch_size``.
        dataset_sizes: per-phase example counts passed to
            ``evaluate_model``; defaults to ``{'train': 5e4, 'test': 1e4}``.
        images_dir: output directory for plots (required).
        ckpt_dir: output directory for eigenspectrum files (required).

    Returns:
        np.ndarray with two rows per epoch: the full-Hessian eigenvalues
        followed by their densities, also saved to
        ``training_eigenspectrum_full.npy`` in ``ckpt_dir``.

    Raises:
        Exception: if ``args.dataset`` is not a supported dataset name.
    """
    # Avoid the shared-mutable-default pitfall: build the dict per call.
    if dataset_sizes is None:
        dataset_sizes = {'train': 5e4, 'test': 1e4}

    logger = logging.getLogger('train')
    loss_list = {'train': list(), 'test': list()}
    acc_list = {'train': list(), 'test': list()}

    assert images_dir is not None
    assert ckpt_dir is not None

    loss_image_path = osp.join(images_dir, 'loss.png')
    acc_image_path = osp.join(images_dir, 'acc.png')

    model.train()

    full_eigenspectrums = list()

    full_eigenspectrums_path = osp.join(ckpt_dir,
                                        'training_eigenspectrum_full.npy')
    C = config.num_classes
    valid_layers = get_valid_layers(model)

    # The subset dataset used for the Hessian computations depends only on
    # args/config, which never change across epochs, so build it (and its
    # loader) once up front instead of once per epoch.
    mean, std = get_mean_std(args.dataset)
    pad = int((config.padded_im_size - config.im_size) / 2)
    transform = transforms.Compose([
        transforms.Pad(pad),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    if args.dataset in ['MNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100']:
        full_dataset = getattr(datasets, args.dataset)
        subset_dataset = get_subset_dataset(
            full_dataset=full_dataset,
            examples_per_class=args.examples_per_class,
            epc_seed=config.epc_seed,
            root=osp.join(args.dataset_root, args.dataset),
            train=True,
            transform=transform,
            download=True)
    elif args.dataset in ['STL10', 'SVHN']:
        # These torchvision datasets select the split via `split=` rather
        # than the boolean `train=` flag.
        full_dataset = getattr(datasets, args.dataset)
        subset_dataset = get_subset_dataset(
            full_dataset=full_dataset,
            examples_per_class=args.examples_per_class,
            epc_seed=config.epc_seed,
            root=osp.join(args.dataset_root, args.dataset),
            split='train',
            transform=transform,
            download=True)
    else:
        raise Exception('Unknown dataset: {}'.format(args.dataset))

    loader = data.DataLoader(dataset=subset_dataset,
                             drop_last=False,
                             batch_size=args.batch_size)

    for epoch in range(num_epochs):
        logger.info('epoch: %d' % epoch)
        with torch.enable_grad():
            for batch, truth in dataloaders['train']:

                batch = batch.to(device)
                truth = truth.to(device)
                optimizer.zero_grad()

                output = model(batch)
                loss = criterion(output, truth)

                loss.backward()
                optimizer.step()

        scheduler.step()

        # Parameter updates finished for this epoch; compute spectra next.

        Hess = FullHessian(crit='CrossEntropyLoss',
                           loader=loader,
                           device=device,
                           model=model,
                           num_classes=C,
                           hessian_type='Hessian',
                           init_poly_deg=64,
                           poly_deg=128,
                           spectrum_margin=0.05,
                           poly_points=1024,
                           SSI_iters=128)

        Hess_eigval, \
        Hess_eigval_density = Hess.LanczosLoop(denormalize=True)

        # Two rows per epoch: eigenvalues, then their densities.
        full_eigenspectrums.append(Hess_eigval)
        full_eigenspectrums.append(Hess_eigval_density)

        for layer_name, _ in model.named_parameters():
            if layer_name not in valid_layers:
                continue

            Hess = LayerHessian(crit='CrossEntropyLoss',
                                loader=loader,
                                device=device,
                                model=model,
                                num_classes=C,
                                layer_name=layer_name,
                                hessian_type='Hessian',
                                init_poly_deg=64,
                                poly_deg=128,
                                spectrum_margin=0.05,
                                poly_points=1024,
                                SSI_iters=128)

            Hess_eigval, \
            Hess_eigval_density = Hess.LanczosLoop(denormalize=True)

            layerwise_eigenspectrums_path = osp.join(
                ckpt_dir,
                'training_eigenspectrums_epoch_{}_layer_{}.npz'.format(
                    epoch, layer_name))
            np.savez(layerwise_eigenspectrums_path,
                     eigval=Hess_eigval,
                     eigval_density=Hess_eigval_density)

        for phase in ['train', 'test']:

            stats = evaluate_model(model, criterion, dataloaders[phase],
                                   device, dataset_sizes[phase])

            loss_list[phase].append(stats['loss'])
            acc_list[phase].append(stats['acc'])

            logger.info('{}:'.format(phase))
            logger.info('\tloss:{}'.format(stats['loss']))
            logger.info('\tacc :{}'.format(stats['acc']))

            # Redraw both curves once the test stats for this epoch exist.
            if phase == 'test':
                plt.clf()
                plt.plot(loss_list['test'], label='test_loss')
                plt.plot(loss_list['train'], label='train_loss')
                plt.legend()
                plt.savefig(loss_image_path)

                plt.clf()
                plt.plot(acc_list['test'], label='test_acc')
                plt.plot(acc_list['train'], label='train_acc')
                plt.legend()
                plt.savefig(acc_image_path)
                plt.clf()

    full_eigenspectrums = np.array(full_eigenspectrums)
    assert full_eigenspectrums.shape[0] % 2 == 0
    assert full_eigenspectrums.shape[0] // 2 == num_epochs
    np.save(full_eigenspectrums_path, full_eigenspectrums)
    return full_eigenspectrums