Exemple #1
0
def test_for_xray(opt, model=None, loader=None, plot=False, vae=False):
    """Evaluate an autoencoder on the X-ray test split and return the ROC AUC.

    Each sample is scored by its mean squared reconstruction error
    ``(out - x) ** 2``. When ``opt.u`` is set the model is expected to return
    ``(out, logvar)`` and the error is weighted by ``exp(-logvar)``.

    Args:
        opt: options object; reads ``opt.u`` always, and ``opt.ls``,
            ``opt.mp`` and ``opt.exp`` when the model is built here (the
            checkpoint is ``./models/<opt.exp>.pth``).
        model: model to evaluate. When None it is built with ``models.AE``
            and loaded from the checkpoint. NOTE(review): a caller-supplied
            model must match ``opt.u`` (tuple vs. single output).
        loader: iterable of ``(x, label)`` batches. When None, the X-ray
            'test' split is loaded with batch size 1.
        plot: when True, also call ``metrics_at_eer`` and show per-label
            score histograms (label 0 blue, label 1 red) and the ROC curve.
        vae: forwarded to ``models.AE`` when the model is built here.

    Returns:
        ``metrics.roc_auc_score(y_true, y_score)`` over the whole loader.
    """
    if model is None:
        model = models.AE(opt.ls, opt.mp, opt.u, img_size=IMG_SIZE,
                          vae=vae).to(device)
        # map_location lets a checkpoint saved on another device (e.g. a
        # different GPU index, or GPU vs. CPU) be restored onto `device`.
        model.load_state_dict(
            torch.load('./models/{}.pth'.format(opt.exp), map_location=device))
    if loader is None:
        loader = xray_data.get_xray_dataloader(1,
                                               WORKERS,
                                               'test',
                                               dataset=DATASET,
                                               img_size=IMG_SIZE)

    model.eval()
    with torch.no_grad():
        y_score, y_true = [], []
        # Iterate the loader directly so tqdm can read its length.
        for x, label in tqdm(loader):
            x = x.to(device)
            if opt.u:
                out, logvar = model(x)
                res = torch.exp(-logvar) * (out - x) ** 2
            else:
                out = model(x)
                res = (out - x) ** 2

            # One score per sample: average over dims 1-3
            # (assumes 4-D (batch, C, H, W) inputs — TODO confirm).
            res = res.mean(dim=(1, 2, 3))

            y_true.append(label.cpu())
            y_score.append(res.cpu().view(-1))

        y_true = np.concatenate(y_true)
        y_score = np.concatenate(y_score)
        auc = metrics.roc_auc_score(y_true, y_score)
        print('AUC', auc)
        if plot:
            metrics_at_eer(y_score, y_true)
            plt.figure()
            plt.hist(y_score[y_true == 0],
                     bins=100,
                     density=True,
                     color='blue',
                     alpha=0.5)
            plt.hist(y_score[y_true == 1],
                     bins=100,
                     density=True,
                     color='red',
                     alpha=0.5)
            plt.figure()
            fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
            plt.plot(fpr, tpr)
            plt.plot([0, 1], [0, 1], 'k--')
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.show()
        return auc
def train(opt):
    """Train an autoencoder on the MVTec category ``CLASSES[opt.c]``.

    Builds ``models.AE(opt.ls, opt.mp, img_size=IMG_SIZE)`` on ``device`` and
    creates 'train' and 'test' loaders for the category. It then sets
    ``opt.model_dir = './models'`` and ``opt.epochs = 250`` before handing
    everything to ``train_loop``.
    """
    net = models.AE(opt.ls, opt.mp, img_size=IMG_SIZE)
    net.to(device)

    def make_loader(split):
        # Same dataset root / batch size / workers / image size for both splits.
        return mvtec_data.get_dataloader(DATA_PATH, CLASSES[opt.c], split,
                                         BATCH_SIZE, WORKERS, IMG_SIZE)

    train_dl = make_loader('train')
    test_dl = make_loader('test')

    opt.model_dir = './models'
    opt.epochs = 250
    train_loop(net, train_dl, test_dl, opt)
Exemple #3
0
def train():
    """Train an autoencoder on the X-ray dataset ``DATASET``.

    Reads the module-level ``opt`` (``opt.ls``, ``opt.mp``, ``opt.u``) to
    build ``models.AE`` on ``device``. It then creates 'train' and 'test'
    X-ray loaders, sets ``opt.epochs = 250``, and calls ``train_loop``.
    """
    ae = models.AE(opt.ls, opt.mp, opt.u, img_size=IMG_SIZE)
    ae.to(device)

    train_dl, test_dl = [
        xray_data.get_xray_dataloader(BATCH_SIZE, WORKERS, split,
                                      img_size=IMG_SIZE, dataset=DATASET)
        for split in ('train', 'test')
    ]

    opt.epochs = 250
    train_loop(ae, train_dl, test_dl, opt)
Exemple #4
0
def _load_legacy_checkpoint(path):
    """Load a ``torch.save`` checkpoint, retrying with latin-1 decoding.

    The first attempt uses ``torch.load`` defaults. On failure, it retries
    with storages mapped to CPU and ``encoding='latin1'``, which
    ``torch.load`` forwards to the unpickler. Unlike patching the global
    ``pickle`` module, this leaves every other pickle user in the process
    unaffected.
    """
    try:
        return torch.load(path)
    except Exception:
        print("Falling back encoding")
        return torch.load(path,
                          map_location=lambda storage, loc: storage,
                          encoding='latin1')


def _transfer_source_weights(net, source, num_classes):
    """Copy weights from the ``'net'`` entry of checkpoint *source* into *net*.

    Every entry whose key does not contain 'linear' is copied as-is.
    'linears.0.weight' / 'linears.0.bias' are copied into
    'linears.0.0.weight' / 'linears.0.0.bias' only when their first dimension
    equals *num_classes*. All other 'linear' entries keep *net*'s values.
    A ``DataParallel``-style ``.module`` wrapper on the stored net is unwrapped.
    """
    old_net = _load_legacy_checkpoint(source)['net']
    if hasattr(old_net, "module"):
        old_net = old_net.module
    new_state_dict = net.state_dict()
    for key, weight in old_net.state_dict().items():
        if 'linear' not in key:
            new_state_dict[key] = weight
        elif key == 'linears.0.weight' and weight.shape[0] == num_classes:
            new_state_dict['linears.0.0.weight'] = weight
        elif key == 'linears.0.bias' and weight.shape[0] == num_classes:
            new_state_dict['linears.0.0.bias'] = weight
    net.load_state_dict(new_state_dict)


def _initial_best_metric(task):
    """Initial 'best' value passed to the test function: +inf for 'ae' and
    'inst_disc', -1 for every other task."""
    return np.inf if task in ('ae', 'inst_disc') else -1


def full_training(args):
    """Train a network for ``args.task`` and evaluate the saved checkpoint.

    Steps:
      1. Create ``args.expdir``, or return early if it already contains
         ``results.npy``.
      2. Build train/val loaders from ``args.datadir + args.dataset``
         (optionally restricted to 10% / half the classes). Also build a
         second evaluation loader without those restrictions.
      3. Build the model: ``models.AE`` for 'ae', ``JigsawModel`` for
         'jigsaw', otherwise ``models.resnet26``. Optionally initialise it
         from ``args.source``.
      4. Train with SGD for ``args.nb_epochs`` epochs using
         ``utils_pytorch.train_<task>`` / ``utils_pytorch.test_<task>``.
         After every epoch, save rows [train_loss, train_acc, test_loss,
         test_acc] to ``<expdir>/results.npy``.
      5. For tasks other than 'inst_disc', reload ``checkpoint.t7`` and run
         the test function on the unrestricted evaluation loader.

    Side effects: sets ``args.num_classes`` and ``args.criterion``. For
    'inst_disc` it also sets (and finally clears) ``args.train_lemniscate``
    and ``args.test_lemniscate``. Uses CUDA unconditionally.
    """
    if not os.path.isdir(args.expdir):
        os.makedirs(args.expdir)
    elif os.path.exists(args.expdir + '/results.npy'):
        # results.npy is written after every epoch, so this also skips
        # experiments that were interrupted after at least one epoch.
        return

    if 'ae' in args.task:
        # exist_ok: a previous run that stopped before writing results.npy
        # leaves this directory behind, and os.mkdir would then raise.
        os.makedirs(args.expdir + '/figs/', exist_ok=True)

    # The 'rot' task uses a quarter of the configured batch sizes
    # (presumably each image is expanded into 4 rotations — confirm in
    # utils_pytorch).
    train_batch_size = args.train_batch_size // 4 if args.task == 'rot' else args.train_batch_size
    test_batch_size = args.test_batch_size // 4 if args.task == 'rot' else args.test_batch_size
    yield_indices = (args.task == 'inst_disc')
    datadir = args.datadir + args.dataset
    trainloader, valloader, num_classes = general_dataset_loader.prepare_data_loaders(
        datadir,
        image_dim=args.image_dim,
        yield_indices=yield_indices,
        train_batch_size=train_batch_size,
        test_batch_size=test_batch_size,
        train_on_10_percent=args.train_on_10,
        train_on_half_classes=args.train_on_half)
    # Evaluation loader built without the 10% / half-class restrictions.
    _, testloader, _ = general_dataset_loader.prepare_data_loaders(
        datadir,
        image_dim=args.image_dim,
        yield_indices=yield_indices,
        train_batch_size=train_batch_size,
        test_batch_size=test_batch_size,
    )

    # args keeps the dataset's class count; the local value becomes the
    # network's output size, which differs for the self-supervised tasks.
    args.num_classes = num_classes
    if args.task == 'rot':
        num_classes = 4
    elif args.task == 'inst_disc':
        num_classes = args.low_dim

    if args.task == 'ae':
        net = models.AE([args.code_dim], image_dim=args.image_dim)
    elif args.task == 'jigsaw':
        net = JigsawModel(num_perms=args.num_perms,
                          code_dim=args.code_dim,
                          gray_prob=args.gray_prob,
                          image_dim=args.image_dim)
    else:
        net = models.resnet26(num_classes,
                              mlp_depth=args.mlp_depth,
                              normalize=(args.task == 'inst_disc'))
    if args.task == 'inst_disc':
        train_lemniscate = LinearAverage(args.low_dim,
                                         len(trainloader.dataset),
                                         args.nce_t, args.nce_m)
        train_lemniscate.cuda()
        args.train_lemniscate = train_lemniscate
        test_lemniscate = LinearAverage(args.low_dim,
                                        len(valloader.dataset),
                                        args.nce_t, args.nce_m)
        test_lemniscate.cuda()
        args.test_lemniscate = test_lemniscate
    if args.source:
        _transfer_source_weights(net, args.source, num_classes)
    net = torch.nn.DataParallel(net).cuda()
    start_epoch = 0
    best_acc = _initial_best_metric(args.task)
    results = np.zeros((4, start_epoch + args.nb_epochs))

    cudnn.benchmark = True

    if args.task in ['ae']:
        args.criterion = nn.MSELoss()
    else:
        args.criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                       net.parameters()),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=args.wd)

    print("Start training")
    # getattr instead of eval: same lookup, without executing arbitrary text
    # taken from args.task.
    train_func = getattr(utils_pytorch, 'train_' + args.task)
    test_func = getattr(utils_pytorch, 'test_' + args.task)
    if args.test_first:
        with torch.no_grad():
            test_func(0, valloader, net, best_acc, args, optimizer)
    for epoch in range(start_epoch, start_epoch + args.nb_epochs):
        utils_pytorch.adjust_learning_rate(optimizer, epoch, args)
        st_time = time.time()

        # Training and validation
        train_acc, train_loss = train_func(epoch, trainloader, net, args,
                                           optimizer)
        test_acc, test_loss, best_acc = test_func(epoch, valloader, net,
                                                  best_acc, args, optimizer)

        # Record statistics (rewritten to disk every epoch).
        results[0:2, epoch] = [train_loss, train_acc]
        results[2:4, epoch] = [test_loss, test_acc]
        np.save(args.expdir + '/results.npy', results)
        print('Epoch lasted {0}'.format(time.time() - st_time))
        sys.stdout.flush()
        if (args.task == 'rot') and (train_acc >= 98) and args.early_stopping:
            break
    if args.task == 'inst_disc':
        args.train_lemniscate = None
        args.test_lemniscate = None
    else:
        # NOTE(review): unlike 'results.npy' above, no '/' is inserted before
        # 'checkpoint.t7'; confirm this matches where the test function saves.
        best_net = torch.load(args.expdir + 'checkpoint.t7')['net']
        test_func(0, testloader, best_net, _initial_best_metric(args.task),
                  args, None)
Exemple #5
0
                                      directed=True)
# Wrap each split in a Dataloader with batch_size=1. The `dataset` used for
# train_loader comes from the statement just above (presumably the train
# partition — confirm). Only the val/test loaders pass shuffle=False
# explicitly; the train loader relies on Dataloader's default.
# NOTE(review): `Dataloader` (lowercase 'l') is not torch's `DataLoader`
# name — presumably a project-local class; confirm.
train_loader = Dataloader(dataset, batch_size=1)
# `dataset` is rebound per split; after this run it refers to the test split.
dataset = d_selector.retrieve_dataset(args.dataset,
                                      partition="val",
                                      directed=True)
val_loader = Dataloader(dataset, batch_size=1, shuffle=False)
dataset = d_selector.retrieve_dataset(args.dataset,
                                      partition="test",
                                      directed=True)
test_loader = Dataloader(dataset, batch_size=1, shuffle=False)

if args.model == 'ae':
    model = models.AE(hidden_nf=args.nf,
                      embedding_nf=args.emb_nf,
                      noise_dim=args.noise_dim,
                      act_fn=nn.SiLU(),
                      learnable_dec=1,
                      device=device,
                      attention=args.attention,
                      n_layers=args.n_layers)
elif args.model == 'ae_rf':
    model = models.AE_rf(embedding_nf=args.K,
                         nf=args.nf,
                         device=device,
                         n_layers=args.n_layers,
                         reg=args.reg,
                         act_fn=nn.SiLU(),
                         clamp=args.clamp)
elif args.model == 'ae_egnn':
    model = models.AE_EGNN(hidden_nf=args.nf,
                           K=args.K,
                           act_fn=nn.SiLU(),
Exemple #6
0
def visualize(opt):
    """Refine MVTec test images toward an AE's reconstruction and report AUC.

    For every image under ``DATA_PATH/<CLASSES[opt.c]>/test/<type>/``:
      * normalise it to [-1, 1] and run ``ITERS`` gradient-based refinement
        steps against the model loaded from ``./models/<opt.exp>.pth``;
      * save a 4-row grid [mask, input, refined, error map] to
        ``./visual_grad/<opt.exp>/<type>/<img_name>``;
      * record the label (0 for type 'good', else 1) and the mean of the
        error map as the score.
    Finally prints ``metrics.roc_auc_score`` of labels against scores.
    """
    def optimize(x, model):
        """Run ``ITERS`` refinement steps on a copy of *x* and return it.

        Each step treats ``out = model(x)`` as a constant, back-propagates
        ``sum((x - out)**2) + ALPHA * sum(|x - x_0|)`` to x, and updates
        ``x <- x - LR * grad * (x - out)**2``. The caller's tensor is left
        untouched (no requires_grad flag, no accumulated .grad).
        """
        x_0 = x.detach().clone()
        x = x_0.clone().requires_grad_()
        for _ in range(ITERS):
            # The model output is a constant in the loss, so no graph is
            # needed through the model.
            with torch.no_grad():
                out = model(x)
            loss = torch.sum((x - out) ** 2) + ALPHA * torch.abs(x - x_0).sum()
            loss.backward()
            with torch.no_grad():
                x = x - LR * x.grad * (x - out) ** 2
            x.requires_grad_()
        return x

    device = torch.device('cuda:{}'.format(opt.cuda))
    model = models.AE(opt.ls, opt.mp, img_size=IMG_SIZE)
    model.load_state_dict(torch.load(
        './models/%s.pth' % opt.exp, map_location='cpu'))
    model.to(device)
    model.eval()
    path = os.path.join(DATA_PATH, CLASSES[opt.c], 'test')
    mask_path = os.path.join(DATA_PATH, CLASSES[opt.c], 'ground_truth')
    out_root = './visual_grad/{}'.format(opt.exp)

    y_score, y_true = [], []

    # exist_ok instead of a bare `except: pass`, which also hid permission
    # errors until save_image failed later.
    os.makedirs(out_root, exist_ok=True)

    # sorted(): deterministic processing/output order across runs.
    for type_ in sorted(os.listdir(path)):
        out_dir = '{}/{}'.format(out_root, type_)
        os.makedirs(out_dir, exist_ok=True)
        for img_name in tqdm(sorted(os.listdir(os.path.join(path, type_)))):
            with Image.open(os.path.join(path, type_, img_name)) as pil_img:
                img = pil_img.resize((IMG_SIZE, IMG_SIZE), resample=Image.BILINEAR)
            img = TF.to_tensor(img).unsqueeze(0)
            if img.size(1) == 1:
                img = img.repeat(1, 3, 1, 1)
            if type_ != 'good':
                # splitext keeps names with several dots intact
                # ('a.b.png' -> 'a.b_mask.png', not 'a_mask.png').
                mask_name = os.path.splitext(img_name)[0] + '_mask.png'
                with Image.open(os.path.join(mask_path, type_, mask_name)) as pil_mask:
                    mask = pil_mask.resize((IMG_SIZE, IMG_SIZE), resample=Image.BILINEAR)
                mask = TF.to_tensor(mask)
                mask = mask.unsqueeze(0).repeat(1, 3, 1, 1)
                y_true.append(1)
            else:
                mask = torch.zeros_like(img)
                y_true.append(0)
            img = img.to(device)

            img = (img - 0.5) / 0.5

            result = optimize(img, model)
            with torch.no_grad():
                rec_err = ((result - img) ** 2).mean(dim=1, keepdim=True)

                img = torch.clamp(img.cpu() * 0.5 + 0.5, 0, 1)
                rec = torch.clamp(result.cpu() * 0.5 + 0.5, 0, 1)
                rec_err = torch.clamp(rec_err.cpu() / 0.5, 0, 1).repeat(1, 3, 1, 1)
            # NOTE(review): the score is taken from the clamped display map, so
            # per-pixel errors above 0.5 saturate — confirm this is intended.
            y_score.append(rec_err.mean().item())

            cat = torch.cat((mask, img, rec, rec_err))
            save_image(cat, '{}/{}'.format(out_dir, img_name))

    y_score, y_true = np.array(y_score), np.array(y_true)

    print(metrics.roc_auc_score(y_true, y_score))