def main():
    """Train a shallow linear discriminator on concatenated VAE gradient features.

    The features are pre-computed gradient files under ``args.dataset_dir``:
    the KLD gradient of VAE layer ``down_6`` and the BCE gradient of layer
    ``up_0``, both for the VAE checkpoint ``vae_ckpt``. Inlier samples come
    from the ``'00_00'`` files and outlier samples from the ``chall`` files.

    After each epoch the discriminator is validated. When
    ``args.write_enable`` is set, a checkpoint is written every
    ``args.write_freq`` epochs and whenever the validation loss improves.

    Command-line options are parsed with the module-level ``parser`` and
    stored in the global ``args``.
    """
    global args
    args = parser.parse_args()

    vae_ckpt = './checkpoints/cure-tsr/vae/' \
               'vae_BCE_gradient_reducedCnnSeq-4layer_train-00_00_val-00_00/model_best.pth.tar'
    vae_name = vae_ckpt.split('/')[-2]  # experiment directory name of the VAE
    gradient_layer = 'down_6'  # kld
    gradient_layer2 = 'up_0'  # bce
    chall = '07_01'  # Training outlier class
    savedir = 'cure-tsr/d/%s/bce_kld_grad/d_BCE_ShallowLinear_norm_bce-%s_kld-%s_in-00_00_out-%s' \
              % (vae_name, gradient_layer2, gradient_layer, chall)

    checkpointdir = os.path.join('./checkpoints', savedir)
    logdir = os.path.join('./logs', savedir)

    seed = random.randint(1, 100000)
    torch.manual_seed(seed)
    dataset_dir = os.path.join(args.dataset_dir, 'kld_grad/%s' % vae_name)
    dataset_dir2 = os.path.join(args.dataset_dir, 'bce_grad/%s' % vae_name)

    if args.write_enable:
        # Raises FileExistsError if the directory already exists, so an
        # earlier run's checkpoints are not silently overwritten.
        os.makedirs(checkpointdir)
        writer = SummaryWriter(log_dir=logdir)
        print('log directory: %s' % logdir)
        print('checkpoints directory: %s' % checkpointdir)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    best_score = 1e20  # lowest validation loss seen so far
    best_epoch = 0  # stays 0 if no epoch runs, so the final report cannot fail
    batch_size = 64

    # The VAE is only used to size the discriminator input: the feature length
    # is the number of weights in its down[6] and up[0] layers.
    vae = models.VAECURECNN()
    vae = torch.nn.DataParallel(vae).to(device)
    if os.path.isfile(vae_ckpt):
        print("=> loading checkpoint '{}'".format(vae_ckpt))
        checkpoint = torch.load(vae_ckpt)
        best_loss = checkpoint['best_loss']
        vae.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}, best_loss {})".format(
            vae_ckpt, checkpoint['epoch'], best_loss))
    else:
        print("=> no checkpoint found at '{}'".format(vae_ckpt))

    grad_dim = (vae.module.down[6].weight.view(-1).shape[0]
                + vae.module.up[0].weight.view(-1).shape[0])

    d = models.DisShallowLinear(grad_dim)
    d = torch.nn.DataParallel(d).to(device)
    optimizer = optim.Adam(d.parameters(), lr=1e-3)

    def grad_loader(prefix, split):
        """Return a shuffled loader over the paired KLD/BCE gradient files of one split."""
        return torch.utils.data.DataLoader(
            datasets.GradDataset([
                os.path.join(dataset_dir, '%s_%s_%s.pt' % (prefix, split, gradient_layer)),
                os.path.join(dataset_dir2, '%s_%s_%s.pt' % (prefix, split, gradient_layer2)),
            ]),
            batch_size=batch_size,
            shuffle=True)

    in_train_loader = grad_loader('00_00', 'train')
    out_train_loader = grad_loader(chall, 'train')
    in_val_loader = grad_loader('00_00', 'val')
    out_val_loader = grad_loader(chall, 'val')

    # Start training
    timestart = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        print('\n*** Start Training *** Epoch: [%d/%d]\n' %
              (epoch + 1, args.epochs))
        d_ext_gradient_train.train(d,
                                   None,
                                   device,
                                   in_train_loader,
                                   optimizer,
                                   epoch + 1,
                                   args.print_freq,
                                   out_iter=iter(out_train_loader))

        print('\n*** Start Testing *** Epoch: [%d/%d]\n' %
              (epoch + 1, args.epochs))
        loss, acc, _ = d_ext_gradient_train.test(d,
                                                 None,
                                                 device,
                                                 in_val_loader,
                                                 epoch + 1,
                                                 args.print_freq,
                                                 out_iter=iter(out_val_loader))

        is_best = loss < best_score
        best_score = min(loss, best_score)
        if is_best:
            best_epoch = epoch + 1

        # Plain truthiness (not `is True`) so a numpy/tensor boolean from the
        # comparison above still triggers a save.
        if args.write_enable and (epoch % args.write_freq == 0 or is_best):
            writer.add_scalar('loss', loss, epoch + 1)
            writer.add_scalar('accuracy', acc, epoch + 1)
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': d.state_dict(),
                    # Key name kept for existing readers; the value is the best loss.
                    'best_acc': best_score,
                    'last_loss': loss,
                    'optimizer': optimizer.state_dict(),
                }, is_best, checkpointdir)

    if args.write_enable:
        writer.close()

    print('Best Testing Acc/Loss: %.3f at epoch %d' % (best_score, best_epoch))
    print('Best epoch: ', best_epoch)
    print('Total processing time: %.4f' % (time.time() - timestart))
def main():
    """Evaluate the bce/kld gradient discriminator on every test challenge.

    The discriminator was trained against outlier challenge ``train_chall``.
    It is tested on the ``'00_00'`` inlier test gradients against the test
    gradients of each challenge type in ``chall_type_list`` at levels 1-5.

    Each challenge fills one column of the results table:
    - row 0 holds the accuracy;
    - rows 1-4 hold the values returned by ``cal_metric.calMetric``
      (assumed to be four metrics, since the table has exactly four rows
      for them).

    The table is written to ``./results_<vae_dir>_<d_dir>.csv``.

    Command-line options are parsed with the module-level ``parser`` and
    stored in the global ``args``.
    """
    global args
    args = parser.parse_args()

    batch_size = 64
    train_chall = '07_01'
    chall_type_list = [1, 2, 3, 5, 8, 9]
    level_list = range(1, 6)
    all_results = np.zeros([5, len(chall_type_list) * len(level_list)])

    # None of the setup below depends on the test challenge, so the models,
    # checkpoints and inlier loader are built once instead of per challenge.
    vae_ckpt = './checkpoints/cure-tsr/vae/' \
               'vae_BCE_gradient_reducedCnnSeq-4layer_train-00_00_val-00_00/model_best.pth.tar'
    vae_name = vae_ckpt.split('/')[-2]
    gradient_layer = 'down_6'  # kld
    gradient_layer2 = 'up_6'  # bce
    d_ckpt = './checkpoints/cure-tsr/d/%s/bce_kld_grad/' \
             'd_BCE_ShallowLinear_bce-%s_kld-%s_in-00_00_out-%s/model_best.pth.tar' \
             % (vae_name, gradient_layer2, gradient_layer, train_chall)
    dataset_dir = os.path.join(args.dataset_dir, 'kld_grad/%s' % vae_name)
    dataset_dir2 = os.path.join(args.dataset_dir, 'bce_grad/%s' % vae_name)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # The VAE is only used to size the discriminator input: the feature length
    # is the number of weights in its down[6] and up[6] layers.
    vae = models.VAECURECNN()
    vae = torch.nn.DataParallel(vae).to(device)
    vae.eval()
    if os.path.isfile(vae_ckpt):
        print("=> loading checkpoint '{}'".format(vae_ckpt))
        checkpoint_vae = torch.load(vae_ckpt)
        best_loss = checkpoint_vae['best_loss']
        vae.load_state_dict(checkpoint_vae['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}, best_loss {})".
              format(vae_ckpt, checkpoint_vae['epoch'], best_loss))
    else:
        print("=> no checkpoint found at '{}'".format(vae_ckpt))

    grad_dim = (vae.module.down[6].weight.view(-1).shape[0]
                + vae.module.up[6].weight.view(-1).shape[0])

    d = models.DisShallowLinear(grad_dim)
    d = torch.nn.DataParallel(d).to(device)
    d.eval()

    if os.path.isfile(d_ckpt):
        print("=> loading checkpoint '{}'".format(d_ckpt))
        checkpoint_d = torch.load(d_ckpt)
        best_acc = checkpoint_d['best_acc']
        d.load_state_dict(checkpoint_d['state_dict'])
        print(
            "=> loaded checkpoint '{}' (epoch {}, best_acc {})".format(
                d_ckpt, checkpoint_d['epoch'], best_acc))
    else:
        print("=> no checkpoint found at '{}'".format(d_ckpt))

    in_test_loader = torch.utils.data.DataLoader(
        datasets.GradDataset([
            os.path.join(dataset_dir, '00_00_test_%s.pt' % gradient_layer),
            os.path.join(dataset_dir2, '00_00_test_%s.pt' % gradient_layer2),
        ]),
        batch_size=batch_size,
        shuffle=True)

    challcnt = 0
    for challID in chall_type_list:
        for levelID in level_list:
            test_chall = '%02d_%02d' % (challID, levelID)

            out_test_loader = torch.utils.data.DataLoader(
                datasets.GradDataset([
                    os.path.join(
                        dataset_dir,
                        '%s_test_%s.pt' % (test_chall, gradient_layer)),
                    os.path.join(
                        dataset_dir2,
                        '%s_test_%s.pt' % (test_chall, gradient_layer2)),
                ]),
                batch_size=batch_size,
                shuffle=True)

            # Start evaluation
            timestart = time.time()
            print('\n*** Start Testing *** \n')
            d.eval()  # d is reused across challenges; start every run in eval mode
            loss, acc, result = d_ext_gradient_train.test(
                d,
                None,
                device,
                in_test_loader,
                1,
                args.print_freq,
                out_iter=iter(out_test_loader),
                is_eval=True)
            print('Total processing time: %.4f' % (time.time() - timestart))

            # result entries are (label, score): label 1 = inlier, 0 = outlier.
            in_pred = [x[1] for x in result if x[0] == 1]
            out_pred = [x[1] for x in result if x[0] == 0]

            all_results[0, challcnt] = acc
            all_results[1:, challcnt] = cal_metric.calMetric(in_pred, out_pred)
            challcnt += 1

    np.savetxt('./results_%s_%s.csv' %
               (vae_ckpt.split('/')[-2], d_ckpt.split('/')[-2]),
               all_results,
               fmt='%.3f',
               delimiter=',')
# Example No. 3
# 0
def main():
    """Train the CURE-TSR VAE or a shallow linear discriminator.

    ``dis_train`` selects what is trained:
    - 0 trains ``models.VAECURECNN`` on the inlier class;
    - 1 trains ``models.DisShallowLinear`` to separate the inlier class from
      images of the ``chall`` challenge. When ``use_vae`` is 1, the
      discriminator is fed through a pretrained ``models.VAECURELinear``.

    Both modes share the data loaders and the epoch loop below. When
    ``args.write_enable`` is set, metrics are logged and a checkpoint is
    written every ``args.write_freq`` epochs and whenever the validation
    loss improves.

    Command-line options are parsed with the module-level ``parser`` and
    stored in the global ``args``.
    """
    global args
    args = parser.parse_args()

    dis_train = 0  # 0: vae_train, 1: d_train
    use_vae = 0  # 0: Not use vae for discriminator training 1: Use vae for discriminator training
    std = 1.0
    chall = '07_01'

    if dis_train:
        if use_vae:
            vae_ckpt = './checkpoints/cure-tsr/vae/' \
                       'vae_BCE_gradient_reducedCnnSeq-4layer_train-00_00_val-00_00/model_best.pth.tar'
            savedir = 'cure-tsr/d/%s/d_BCE_ShallowLinear_in-00_00_out-%s' \
                      % (vae_ckpt.split('/')[-2], chall)
        else:
            savedir = 'cure-tsr/d/d_BCE_ShallowLinear_in-00_00_out-%s' % chall
    else:
        savedir = 'vae/vae_BCE_gradient_reducedCnnSeq-4layer_train-14_val-14'

    checkpointdir = os.path.join('./checkpoints', savedir)
    logdir = os.path.join('./logs', savedir)

    seed = random.randint(1, 100000)
    torch.manual_seed(seed)
    dataset_dir = '/media/gukyeong/HardDisk/dataset/CURE-TSR/folds/RealChallengeFree'

    if args.write_enable:
        # makedirs (not mkdir) also creates missing parent directories; it
        # still raises FileExistsError if the run directory already exists.
        os.makedirs(checkpointdir)
        writer = SummaryWriter(log_dir=logdir)
        print('log directory: %s' % logdir)
        print('checkpoints directory: %s' % checkpointdir)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    best_score = 1e20  # lowest validation loss seen so far
    best_epoch = 0  # stays 0 if no epoch improves, so the final report cannot fail

    if dis_train:
        batch_size = 64  # there will be additional 64 fake samples
        d = models.DisShallowLinear(28 * 28 * 3)
        d = torch.nn.DataParallel(d).to(device)
        d.apply(models.weights_init)
        optimizer = optim.Adam(d.parameters(), lr=1e-3)

        if use_vae:
            vae = models.VAECURELinear()
            vae = torch.nn.DataParallel(vae).to(device)
            vae.eval()
            if os.path.isfile(vae_ckpt):
                print("=> loading checkpoint '{}'".format(vae_ckpt))
                checkpoint = torch.load(vae_ckpt)
                best_loss = checkpoint['best_loss']
                vae.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {}, best_loss {})"
                      .format(vae_ckpt, checkpoint['epoch'], best_loss))
            else:
                print("=> no checkpoint found at '{}'".format(vae_ckpt))
    else:
        vae = models.VAECURECNN()
        vae = torch.nn.DataParallel(vae).to(device)
        batch_size = 128
        optimizer = optim.Adam(vae.parameters(), lr=1e-3)

        if args.resume:
            vae_resume_ckpt = './checkpoints/cure-tsr/vae/' \
                              'vae_BCE_gradient_reducedCnnSeq-4layer_train-00_00_val-00_00/model_best.pth.tar'
            if os.path.isfile(vae_resume_ckpt):
                print("=> loading checkpoint '{}'".format(vae_resume_ckpt))
                checkpoint = torch.load(vae_resume_ckpt)
                vae.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {}, best_loss {})"
                      .format(vae_resume_ckpt, checkpoint['epoch'], checkpoint['best_loss']))
            else:
                print("=> no checkpoint found at '{}'".format(vae_resume_ckpt))

    # CURE-TSR (Train : Val : Test = 6 : 2 : 2)
    nsamples_fold = [325, 71, 17, 10, 235, 26, 295, 59, 20, 101, 16, 31, 179, 735]  # 14 classes
    in_cls = [13]
    out_cls = list(range(13))
    _, nsamples_per_cls = utils.cal_nsample_perclass(in_cls, out_cls, nsamples_fold)
    out_nsamples = nsamples_per_cls[len(in_cls):]  # counts passed alongside out_cls

    # '.../RealChallengeFree'[:-4] == '.../RealChallenge', the parent of the
    # per-challenge folders.
    challenge_dir = dataset_dir[:-4]
    # Resize + ToTensor keeps no state, so one pipeline serves every input and target.
    resize_to_tensor = transforms.Compose([transforms.Resize([28, 28]),
                                           transforms.ToTensor()])

    in_train_loader = torch.utils.data.DataLoader(
        datasets.CURETSRdataset(os.path.join(dataset_dir, 'train'),
                                transform=resize_to_tensor,
                                target_transform=resize_to_tensor,
                                cls=in_cls),
        batch_size=batch_size, shuffle=True)

    in_val_loader = torch.utils.data.DataLoader(
        datasets.CURETSRdataset(os.path.join(dataset_dir, 'val'),
                                transform=resize_to_tensor,
                                target_transform=resize_to_tensor,
                                cls=in_cls),
        batch_size=batch_size, shuffle=True)

    out_train_loader = torch.utils.data.DataLoader(
        datasets.CURETSRdataset(os.path.join(challenge_dir, chall, 'train'),
                                transform=resize_to_tensor,
                                target_transform=resize_to_tensor,
                                cls=out_cls, nsamples_per_cls=[3 * n for n in out_nsamples]),
        batch_size=batch_size, shuffle=True)

    out_val_loader = torch.utils.data.DataLoader(
        datasets.CURETSRdataset(os.path.join(challenge_dir, chall, 'val'),
                                transform=resize_to_tensor,
                                target_transform=resize_to_tensor,
                                cls=out_cls, nsamples_per_cls=list(out_nsamples)),
        batch_size=batch_size, shuffle=True)

    # Start training
    timestart = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        print('\n*** Start Training *** Epoch: [%d/%d]\n' % (epoch + 1, args.epochs))
        if dis_train:
            if use_vae:
                d_train.train(d, device, in_train_loader, optimizer, epoch + 1, args.print_freq, vae=vae,
                              std=std, out_iter=iter(out_train_loader))
            else:
                d_train.train(d, device, in_train_loader, optimizer, epoch + 1, args.print_freq,
                              out_iter=iter(out_train_loader))

            print('\n*** Start Testing *** Epoch: [%d/%d]\n' % (epoch + 1, args.epochs))
            if use_vae:
                loss, acc, _ = d_train.test(d, device, in_val_loader, epoch + 1, args.print_freq, vae=vae,
                                            std=std, out_iter=iter(out_val_loader))
            else:
                loss, acc, _ = d_train.test(d, device, in_val_loader, epoch + 1, args.print_freq,
                                            out_iter=iter(out_val_loader))
        else:
            vae_train.train(vae, device, in_train_loader, optimizer, epoch + 1, args.print_freq)
            print('\n*** Start Testing *** Epoch: [%d/%d]\n' % (epoch + 1, args.epochs))
            loss, input_img, recon_img, target_img = vae_train.test(vae, device, in_val_loader, epoch + 1,
                                                                    args.print_freq)

        is_best = loss < best_score
        best_score = min(loss, best_score)
        if is_best:
            best_epoch = epoch + 1

        # An improved model is always checkpointed, in both modes, so the saved
        # best model matches best_score. Plain truthiness also accepts
        # numpy/tensor booleans.
        if args.write_enable and (epoch % args.write_freq == 0 or is_best):
            writer.add_scalar('loss', loss, epoch + 1)
            if dis_train:
                writer.add_scalar('accuracy', acc, epoch + 1)
            else:
                writer.add_image('input_img', vutils.make_grid(input_img, nrow=3), epoch + 1)
                writer.add_image('recon_img', vutils.make_grid(recon_img, nrow=3), epoch + 1)
                writer.add_image('target_img', vutils.make_grid(target_img, nrow=3), epoch + 1)

            trained_model = d if dis_train else vae
            utils.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': trained_model.state_dict(),
                'best_loss': best_score,
                'last_loss': loss,
                'optimizer': optimizer.state_dict(),
            }, is_best, checkpointdir)

    if args.write_enable:
        writer.close()

    print('Best Testing Acc/Loss: %.3f at epoch %d' % (best_score, best_epoch))
    print('Best epoch: ', best_epoch)
    print('Total processing time: %.4f' % (time.time() - timestart))
def main():
    """Evaluate a gradient-based discriminator on MNIST and print OOD metrics.

    Gradients of the VAE layer named by ``gradient_layer``
    (``'<down|up>_<index>'``) are computed with ``gradient.VanillaBackprop``
    and scored by ``models.DisShallowLinear``.

    On fold 0, classes 0-5 are inliers and 7-9 are outliers; class 6 is not
    used. The inlier classes are subsampled so that both groups have the
    same total size.

    The scores are swept over thresholds in steps of 0.05 to print FPR at
    95% TPR, detection error, AUPR-In and AUPR-Out. A higher score counts as
    "more inlier".

    Command-line options are parsed with the module-level ``parser`` and
    stored in the global ``args``.
    """
    global args
    args = parser.parse_args()

    batch_size = 1
    dataset_dir = '/home/gukyeong/dataset/mnist/folds/28_28'
    vae_ckpt = './checkpoints/mnist/vae/gradient/' \
               'vae_BCE_gradient_reducedCnnSeq-4layer_train-fold234-0-5_val-fold1-0-5/model_best.pth.tar'
    gradient_layer = 'down_0'
    d_ckpt = './checkpoints/mnist/d/gradient/d_BCE_ShallowLinear_%s_grad_train-fold234-0-5_val-fold1-0-5_' % gradient_layer + \
              vae_ckpt.split('/')[-2] + '/model_best.pth.tar'

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # NOTE(review): unlike d, the VAE is never switched to eval() before
    # gradients are extracted; confirm that is intended.
    vae = models.VAEReducedCNN()
    vae = torch.nn.DataParallel(vae).to(device)
    if os.path.isfile(vae_ckpt):
        print("=> loading checkpoint '{}'".format(vae_ckpt))
        checkpoint_vae = torch.load(vae_ckpt)
        best_loss = checkpoint_vae['best_loss']
        vae.load_state_dict(checkpoint_vae['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}, best_loss {})"
              .format(vae_ckpt, checkpoint_vae['epoch'], best_loss))
    else:
        print("=> no checkpoint found at '{}'".format(vae_ckpt))

    # gradient_layer is '<direction>_<index>'. 'down' selects vae.module.down;
    # anything else selects vae.module.up.
    layer_parts = gradient_layer.split('_')
    layer_stack = vae.module.down if layer_parts[0] == 'down' else vae.module.up
    grad_dim = layer_stack[int(layer_parts[1])].weight.view(-1).shape[0]

    d = models.DisShallowLinear(grad_dim)
    d = torch.nn.DataParallel(d).to(device)
    d.eval()

    if os.path.isfile(d_ckpt):
        print("=> loading checkpoint '{}'".format(d_ckpt))
        checkpoint_d = torch.load(d_ckpt)
        best_acc = checkpoint_d['best_acc']
        d.load_state_dict(checkpoint_d['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}, best_acc {})"
              .format(d_ckpt, checkpoint_d['epoch'], best_acc))
    else:
        print("=> no checkpoint found at '{}'".format(d_ckpt))

    # Load dataset
    nsamples_fold = [1380, 1575, 1398, 1428, 1364, 1262, 1375, 1458, 1365, 1391]
    inlier_class = [0, 1, 2, 3, 4, 5]
    outlier_class = [7, 8, 9]
    cls = inlier_class + outlier_class

    # Balance the groups: split the total outlier count evenly over the inlier
    # classes, giving the first `rem` inlier classes one extra sample.
    nout = [nsamples_fold[n] for n in outlier_class]
    base, rem = divmod(sum(nout), len(inlier_class))
    nin = [base + (1 if idx < rem else 0) for idx in range(len(inlier_class))]

    test_loader = torch.utils.data.DataLoader(
        datasets.FoldMNISTdataset(dataset_dir, folds=[0],
                                  transform=transforms.ToTensor(),
                                  target_transform=transforms.ToTensor(),
                                  cls=cls, nsamples_per_cls=nin + nout),
        batch_size=batch_size, shuffle=True)

    # Start evaluation
    vbackprop = gradient.VanillaBackprop(vae)
    timestart = time.time()
    print('\n*** Start Testing *** \n')
    _, _, result = d_gradient_train.test(d, vbackprop, outlier_class, gradient_layer, device, test_loader, 1,
                                         args.print_freq)
    print('Total processing time: %.4f' % (time.time() - timestart))

    # result entries are (label, score): label 1 = inlier, label 0 = outlier.
    X1 = [x[1] for x in result if x[0] == 1]  # inlier scores
    Y1 = [x[1] for x in result if x[0] == 0]  # outlier scores

    min_delta = min([x[1] for x in result]) - 0.05
    max_delta = max([x[1] for x in result]) + 0.05
    thresholds = np.arange(min_delta, max_delta, 0.05)

    ##################################################################
    # FPR at TPR 95
    ##################################################################
    fpr95 = 0.0
    closest_tpr = 1.0
    dist_tpr = 1.0
    for e in thresholds:
        tpr = np.sum(np.greater_equal(X1, e)) / len(X1)
        fpr = np.sum(np.greater_equal(Y1, e)) / len(Y1)
        if abs(tpr - 0.95) < dist_tpr:
            dist_tpr = abs(tpr - 0.95)
            closest_tpr = tpr
            fpr95 = fpr

    print("tpr: ", closest_tpr)
    print("fpr95: ", fpr95)

    ##################################################################
    # Detection error: min over thresholds of (FNR + FPR) / 2
    ##################################################################
    detect_error = 1.0
    for e in thresholds:
        fnr = np.sum(np.less(X1, e)) / len(X1)
        fpr = np.sum(np.greater_equal(Y1, e)) / len(Y1)
        detect_error = np.minimum(detect_error, (fnr + fpr) / 2.0)

    print("Detection error: ", detect_error)

    ##################################################################
    # AUPR IN (inliers are the positive class)
    ##################################################################
    auprin = 0.0
    recallTemp = 1.0
    precision = recall = 0.0  # defined even if every threshold is skipped
    for e in thresholds:
        tp = np.sum(np.greater_equal(X1, e))
        fp = np.sum(np.greater_equal(Y1, e))
        if tp + fp == 0:
            continue
        precision = tp / (tp + fp)
        recall = tp / len(X1)
        auprin += (recallTemp - recall) * precision
        recallTemp = recall
    auprin += recall * precision

    print("auprin: ", auprin)

    ##################################################################
    # AUPR OUT (outliers are the positive class)
    ##################################################################
    # Negate the scores (and the threshold range) so outliers score high.
    min_delta, max_delta = -max_delta, -min_delta
    X1 = [-x for x in X1]
    Y1 = [-x for x in Y1]
    auprout = 0.0
    recallTemp = 1.0
    precision = recall = 0.0  # defined even if every threshold is skipped
    for e in np.arange(min_delta, max_delta, 0.05):
        tp = np.sum(np.greater_equal(Y1, e))
        fp = np.sum(np.greater_equal(X1, e))
        if tp + fp == 0:
            continue
        precision = tp / (tp + fp)
        recall = tp / len(Y1)
        auprout += (recallTemp - recall) * precision
        recallTemp = recall
    auprout += recall * precision

    print("auprout: ", auprout)
def main():
    """Evaluate the VAE-fed discriminator on all CURE-TSR challenge types and levels.

    The discriminator was trained against outlier challenge ``train_chall``.
    It is tested on the challenge-free test images against the test images
    of challenge types 1-12 at levels 1-5, with images passed through the
    pretrained VAE by ``d_train.test``.

    Each challenge fills one column of the results table:
    - row 0 holds the accuracy;
    - rows 1-4 hold the values returned by ``cal_metric.calMetric``
      (assumed to be four metrics, since the table has exactly four rows
      for them).

    The table is written to ``<vae_dir>_<d_dir>.csv``.

    Command-line options are parsed with the module-level ``parser`` and
    stored in the global ``args``.
    """
    global args
    args = parser.parse_args()

    batch_size = 64
    train_chall = '07_01'
    chall_id_list = range(1, 13)
    level_list = range(1, 6)
    all_results = np.zeros([5, len(chall_id_list) * len(level_list)])

    # None of the setup below depends on the test challenge, so the models,
    # checkpoints and inlier loader are built once instead of per challenge.
    vae_ckpt = './checkpoints/cure-tsr/vae/' \
               'vae_BCE_gradient_reducedCnnSeq-4layer_train-00_00_val-00_00/model_best.pth.tar'
    d_ckpt = './checkpoints/cure-tsr/d/%s/d_BCE_ShallowLinear_in-00_00_out-%s' \
             % (vae_ckpt.split('/')[-2], train_chall) + '/model_best.pth.tar'
    dataset_dir = '/media/gukyeong/HardDisk/CURE-TSR/folds/RealChallengeFree'
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    vae = models.VAECURECNN()
    vae = torch.nn.DataParallel(vae).to(device)
    vae.eval()
    if os.path.isfile(vae_ckpt):
        print("=> loading checkpoint '{}'".format(vae_ckpt))
        checkpoint_vae = torch.load(vae_ckpt)
        best_loss = checkpoint_vae['best_loss']
        vae.load_state_dict(checkpoint_vae['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}, best_loss {})".
              format(vae_ckpt, checkpoint_vae['epoch'], best_loss))
    else:
        print("=> no checkpoint found at '{}'".format(vae_ckpt))

    d = models.DisShallowLinear(20)
    d = torch.nn.DataParallel(d).to(device)
    d.eval()

    if os.path.isfile(d_ckpt):
        print("=> loading checkpoint '{}'".format(d_ckpt))
        checkpoint_d = torch.load(d_ckpt)
        best_loss = checkpoint_d['best_loss']
        d.load_state_dict(checkpoint_d['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}, best_loss {})".
              format(d_ckpt, checkpoint_d['epoch'], best_loss))
    else:
        print("=> no checkpoint found at '{}'".format(d_ckpt))

    # Resize + ToTensor keeps no state, so one pipeline serves every input and target.
    resize_to_tensor = transforms.Compose([
        transforms.Resize([28, 28]),
        transforms.ToTensor()
    ])

    # CURE-TSR (Train : Val : Test = 6 : 2 : 2)
    in_test_loader = torch.utils.data.DataLoader(
        datasets.CURETSRdataset(os.path.join(dataset_dir, 'test'),
                                transform=resize_to_tensor,
                                target_transform=resize_to_tensor),
        batch_size=batch_size,
        shuffle=True)

    challcnt = 0
    for challID in chall_id_list:
        for levelID in level_list:
            test_chall = '%02d_%02d' % (challID, levelID)
            print('\n*** Test on Challenge %s ***\n' % test_chall)

            # '.../RealChallengeFree'[:-4] == '.../RealChallenge', the parent
            # of the per-challenge folders.
            out_test_loader = torch.utils.data.DataLoader(
                datasets.CURETSRdataset(os.path.join(dataset_dir[:-4], test_chall, 'test'),
                                        transform=resize_to_tensor,
                                        target_transform=resize_to_tensor),
                batch_size=batch_size,
                shuffle=True)

            # Start evaluation
            timestart = time.time()
            print('\n*** Start Testing *** \n')

            # Both models are reused across challenges; start every run in eval mode.
            vae.eval()
            d.eval()
            loss, acc, result = d_train.test(d,
                                             device,
                                             in_test_loader,
                                             1,
                                             args.print_freq,
                                             vae=vae,
                                             std=1.0,
                                             out_iter=iter(out_test_loader),
                                             is_eval=True)

            # result entries are (label, score): label 1 = inlier, 0 = outlier.
            in_pred = [x[1] for x in result if x[0] == 1]
            out_pred = [x[1] for x in result if x[0] == 0]

            all_results[0, challcnt] = acc
            all_results[1:, challcnt] = cal_metric.calMetric(in_pred, out_pred)
            challcnt += 1
            print('Total processing time: %.4f' % (time.time() - timestart))

    np.savetxt('%s_%s.csv' % (vae_ckpt.split('/')[-2], d_ckpt.split('/')[-2]),
               all_results,
               fmt='%.3f',
               delimiter=',')