def main():
    """Train a shallow linear discriminator on pre-extracted VAE gradients.

    In-distribution samples (challenge ``00_00``) are contrasted with the
    training outlier challenge ``chall``. Each sample is described by two
    gradient files read through ``datasets.GradDataset``: the KLD gradient
    w.r.t. ``gradient_layer`` (under ``<args.dataset_dir>/kld_grad``) and the
    BCE gradient w.r.t. ``gradient_layer2`` (under
    ``<args.dataset_dir>/bce_grad``).

    The VAE checkpoint is loaded only to size the discriminator input.

    Side effects: stores the parsed command-line arguments in the module-level
    ``args``; when ``args.write_enable`` is set, creates the checkpoint
    directory (failing if it already exists), writes TensorBoard scalars and
    saves checkpoints via ``utils.save_checkpoint``.
    """
    global args
    args = parser.parse_args()

    vae_ckpt = './checkpoints/cure-tsr/vae/' \
               'vae_BCE_gradient_reducedCnnSeq-4layer_train-00_00_val-00_00/model_best.pth.tar'
    gradient_layer = 'down_6'  # kld
    gradient_layer2 = 'up_0'  # bce
    chall = '07_01'  # Training outlier class
    vae_name = vae_ckpt.split('/')[-2]
    savedir = 'cure-tsr/d/%s/bce_kld_grad/d_BCE_ShallowLinear_norm_bce-%s_kld-%s_in-00_00_out-%s' \
              % (vae_name, gradient_layer2, gradient_layer, chall)

    checkpointdir = os.path.join('./checkpoints', savedir)
    logdir = os.path.join('./logs', savedir)

    seed = random.randint(1, 100000)
    torch.manual_seed(seed)
    dataset_dir = os.path.join(args.dataset_dir, 'kld_grad/%s' % vae_name)
    dataset_dir2 = os.path.join(args.dataset_dir, 'bce_grad/%s' % vae_name)

    if args.write_enable:
        os.makedirs(checkpointdir)
        writer = SummaryWriter(log_dir=logdir)
        print('log directory: %s' % logdir)
        print('checkpoints directory: %s' % checkpointdir)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    best_score = 1e20  # lowest validation loss seen so far
    best_epoch = 0  # stays 0 if no epoch improves, so the summary print cannot fail
    batch_size = 64

    vae = models.VAECURECNN()
    vae = torch.nn.DataParallel(vae).to(device)
    if os.path.isfile(vae_ckpt):
        print("=> loading checkpoint '{}'".format(vae_ckpt))
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        checkpoint = torch.load(vae_ckpt, map_location=device)
        best_loss = checkpoint['best_loss']
        vae.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}, best_loss {})".format(
            vae_ckpt, checkpoint['epoch'], best_loss))
    else:
        print("=> no checkpoint found at '{}'".format(vae_ckpt))

    # Input size = number of weights in both monitored layers; presumably
    # GradDataset concatenates the two gradients per sample -- TODO confirm.
    grad_dim = vae.module.down[6].weight.view(
        -1).shape[0] + vae.module.up[0].weight.view(-1).shape[0]

    d = models.DisShallowLinear(grad_dim)
    d = torch.nn.DataParallel(d).to(device)
    optimizer = optim.Adam(d.parameters(), lr=1e-3)

    def _grad_loader(kld_file, bce_file):
        # One shuffled loader over a KLD-gradient file and its BCE-gradient twin.
        return torch.utils.data.DataLoader(
            datasets.GradDataset([
                os.path.join(dataset_dir, kld_file),
                os.path.join(dataset_dir2, bce_file)
            ]),
            batch_size=batch_size,
            shuffle=True)

    in_train_loader = _grad_loader('00_00_train_%s.pt' % gradient_layer,
                                   '00_00_train_%s.pt' % gradient_layer2)
    out_train_loader = _grad_loader('%s_train_%s.pt' % (chall, gradient_layer),
                                    '%s_train_%s.pt' % (chall, gradient_layer2))
    in_val_loader = _grad_loader('00_00_val_%s.pt' % gradient_layer,
                                 '00_00_val_%s.pt' % gradient_layer2)
    out_val_loader = _grad_loader('%s_val_%s.pt' % (chall, gradient_layer),
                                  '%s_val_%s.pt' % (chall, gradient_layer2))

    # Start training
    timestart = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        print('\n*** Start Training *** Epoch: [%d/%d]\n' %
              (epoch + 1, args.epochs))
        d_ext_gradient_train.train(d,
                                   None,
                                   device,
                                   in_train_loader,
                                   optimizer,
                                   epoch + 1,
                                   args.print_freq,
                                   out_iter=iter(out_train_loader))

        print('\n*** Start Testing *** Epoch: [%d/%d]\n' %
              (epoch + 1, args.epochs))
        loss, acc, _ = d_ext_gradient_train.test(d,
                                                 None,
                                                 device,
                                                 in_val_loader,
                                                 epoch + 1,
                                                 args.print_freq,
                                                 out_iter=iter(out_val_loader))

        # Lower validation loss is better.
        is_best = loss < best_score
        best_score = min(loss, best_score)
        if is_best:
            best_epoch = epoch + 1

        # Plain truthiness instead of `is True`: a numpy/torch comparison
        # result is never the `True` singleton, so best epochs were skipped.
        if args.write_enable and (epoch % args.write_freq == 0 or is_best):
            writer.add_scalar('loss', loss, epoch + 1)
            writer.add_scalar('accuracy', acc, epoch + 1)
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': d.state_dict(),
                    # Key kept as 'best_acc' for existing checkpoint readers;
                    # the stored value is the best (lowest) validation loss.
                    'best_acc': best_score,
                    'last_loss': loss,
                    'optimizer': optimizer.state_dict(),
                }, is_best, checkpointdir)

    if args.write_enable:
        writer.close()

    print('Best Testing Acc/Loss: %.3f at epoch %d' % (best_score, best_epoch))
    print('Best epoch: ', best_epoch)
    print('Total processing time: %.4f' % (time.time() - timestart))
def main():
    """Extract per-sample VAE gradients from CURE-TSR and save them to disk.

    For every challenge in ``challenges``, layer in ``vae_layers`` and split
    in ``sets``, each image is passed through ``gradient.VanillaBackprop``
    (``loss_type='bce'``). The gradient recorded for the layer is flattened
    into one row of ``grad_data``. ``(labels, grad_data)`` is then saved with
    ``torch.save`` under ``<root_dir>/in-14_out-others/<vae checkpoint dir>``.

    ``train``/``val`` use only ``in_cls``. ``test`` uses ``in_cls + out_cls``
    with the per-class counts from ``utils.cal_nsample_perclass``.

    Raises:
        ValueError: if a layer or split name is not recognised.
    """
    global args
    args = parser.parse_args()

    root_dir = '/media/gukyeong/HardDisk/dataset/CURE-TSR/folds'
    batch_size = 1  # one sample per batch: row i of grad_data is batch i
    vae_ckpt = './checkpoints/vae/' \
               'vae_BCE_gradient_reducedCnnSeq-4layer_train-14_val-14/model_best.pth.tar'

    sets = ['train', 'val', 'test']
    vae_layers = ['up_6']
    challenges = ['00_00']
    nsamples_fold = [
        325, 71, 17, 10, 235, 26, 295, 59, 20, 101, 16, 31, 179, 735
    ]  # 14 classes
    in_cls = [13]
    out_cls = list(range(13))
    _, nsamples_per_cls = utils.cal_nsample_perclass(in_cls, out_cls,
                                                     nsamples_fold)

    # for challID in range(1, 13):
    #     for levelID in range(1, 6):
    #         challenges.append('%02d_%02d' % (challID, levelID))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    vae = models.VAECURECNN()
    vae = torch.nn.DataParallel(vae).to(device)
    if os.path.isfile(vae_ckpt):
        print("=> loading checkpoint '{}'".format(vae_ckpt))
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        checkpoint_vae = torch.load(vae_ckpt, map_location=device)
        best_loss = checkpoint_vae['best_loss']
        vae.load_state_dict(checkpoint_vae['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}, best_loss {})".format(
            vae_ckpt, checkpoint_vae['epoch'], best_loss))
    else:
        print("=> no checkpoint found at '{}'".format(vae_ckpt))

    def _grad_dim(layer):
        # Number of gradient entries stored per sample for `layer`.
        if layer == 'latent':
            return vae.module.fc11.weight.shape[0]
        if layer == 'input':
            return 28 * 28 * 3
        block, _, index = layer.partition('_')
        if block == 'down':
            return vae.module.down[int(index)].weight.view(-1).shape[0]
        if block == 'up':
            return vae.module.up[int(index)].weight.view(-1).shape[0]
        # Previously an unknown name silently reused a stale grad_dim.
        raise ValueError('Unknown gradient layer: %s' % layer)

    def _resize_to_tensor():
        return transforms.Compose(
            [transforms.Resize([28, 28]),
             transforms.ToTensor()])

    def _make_loader(dataset_dir, set_name, cls, **kwargs):
        # Shuffled loader over one split; extra kwargs go to CURETSRdataset.
        return torch.utils.data.DataLoader(
            datasets.CURETSRdataset(os.path.join(dataset_dir, set_name),
                                    transform=_resize_to_tensor(),
                                    target_transform=_resize_to_tensor(),
                                    cls=cls,
                                    **kwargs),
            batch_size=batch_size,
            shuffle=True)

    savedir = os.path.join(root_dir, 'in-14_out-others',
                           vae_ckpt.split('/')[-2])
    in_tag = ''.join(str(i) for i in in_cls)
    out_tag = ''.join(str(i) for i in out_cls)

    for chall in challenges:
        if chall == '00_00':
            dataset_dir = os.path.join(root_dir, 'RealChallengeFree')
        else:
            dataset_dir = os.path.join(root_dir, 'RealChallenge/%s' % chall)
        for gradient_layer in vae_layers:
            grad_dim = _grad_dim(gradient_layer)
            for set_name in sets:
                print('Extracting %s chall %s layer %s set\n' %
                      (chall, gradient_layer, set_name))

                if set_name == 'test':
                    data_loader = _make_loader(
                        dataset_dir, set_name, in_cls + out_cls,
                        nsamples_per_cls=nsamples_per_cls)
                    filename = '%s_incls%s_outcls%s_%s.pt' % (
                        set_name, in_tag, out_tag, gradient_layer)
                elif set_name in ('train', 'val'):
                    data_loader = _make_loader(dataset_dir, set_name, in_cls)
                    filename = '%s_cls%s_%s.pt' % (set_name, in_tag,
                                                   gradient_layer)
                else:
                    raise ValueError('Unknown set: %s' % set_name)

                # Start evaluation
                vbackprop = gradient.VanillaBackprop(vae, loss_type='bce')
                timestart = time.time()
                print('\n*** Start Testing *** \n')

                num_batches = len(data_loader)
                labels = torch.zeros(num_batches)
                grad_data = torch.zeros(num_batches, grad_dim)
                for batch_idx, (input_img, target_img,
                                class_label) in enumerate(data_loader):
                    print('Processing...(%d / %d)' % (batch_idx, num_batches))
                    input_img = input_img.to(device)
                    target_img = target_img.to(device)

                    vbackprop.generate_gradients(input_img, target_img)
                    grad_data[batch_idx, :] = vbackprop.gradients[
                        gradient_layer].view(1, -1)
                    labels[batch_idx] = class_label

                os.makedirs(savedir, exist_ok=True)
                torch.save((labels, grad_data),
                           os.path.join(savedir, filename))

                print('Total processing time: %.4f' %
                      (time.time() - timestart))
def main():
    """Evaluate a trained gradient discriminator across CURE-TSR challenges.

    For each challenge type in ``chall_type_list`` and level 1-5, the
    discriminator trained against ``train_chall`` is tested with
    ``d_ext_gradient_train.test``. Inliers are the ``00_00`` test gradients;
    outliers are the test gradients of that challenge. Results go to
    ``./results_<vae>_<d>.csv``: row 0 holds accuracy, and rows 1-4 hold the
    values returned by ``cal_metric.calMetric``, one column per challenge.

    The VAE, the discriminator and the in-distribution loader do not depend
    on the challenge, so they are built once rather than per challenge.
    """
    global args
    args = parser.parse_args()

    batch_size = 64
    train_chall = '07_01'
    chall_type_list = [1, 2, 3, 5, 8, 9]
    num_levels = 5
    test_challs = ['%02d_%02d' % (challID, levelID)
                   for challID in chall_type_list
                   for levelID in range(1, num_levels + 1)]
    all_results = np.zeros([5, len(test_challs)])

    vae_ckpt = './checkpoints/cure-tsr/vae/' \
               'vae_BCE_gradient_reducedCnnSeq-4layer_train-00_00_val-00_00/model_best.pth.tar'
    gradient_layer = 'down_6'  # kld
    gradient_layer2 = 'up_6'  # bce
    vae_name = vae_ckpt.split('/')[-2]
    d_ckpt = './checkpoints/cure-tsr/d/%s/bce_kld_grad/' \
             'd_BCE_ShallowLinear_bce-%s_kld-%s_in-00_00_out-%s/model_best.pth.tar' \
             % (vae_name, gradient_layer2, gradient_layer, train_chall)
    dataset_dir = os.path.join(args.dataset_dir, 'kld_grad/%s' % vae_name)
    dataset_dir2 = os.path.join(args.dataset_dir, 'bce_grad/%s' % vae_name)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    vae = models.VAECURECNN()
    vae = torch.nn.DataParallel(vae).to(device)
    vae.eval()
    if os.path.isfile(vae_ckpt):
        print("=> loading checkpoint '{}'".format(vae_ckpt))
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        checkpoint_vae = torch.load(vae_ckpt, map_location=device)
        best_loss = checkpoint_vae['best_loss']
        vae.load_state_dict(checkpoint_vae['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}, best_loss {})".format(
            vae_ckpt, checkpoint_vae['epoch'], best_loss))
    else:
        print("=> no checkpoint found at '{}'".format(vae_ckpt))

    # Discriminator input size = weights of both monitored layers.
    grad_dim = vae.module.down[6].weight.view(
        -1).shape[0] + vae.module.up[6].weight.view(-1).shape[0]

    d = models.DisShallowLinear(grad_dim)
    d = torch.nn.DataParallel(d).to(device)
    d.eval()

    if os.path.isfile(d_ckpt):
        print("=> loading checkpoint '{}'".format(d_ckpt))
        checkpoint_d = torch.load(d_ckpt, map_location=device)
        best_acc = checkpoint_d['best_acc']
        d.load_state_dict(checkpoint_d['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}, best_acc {})".format(
            d_ckpt, checkpoint_d['epoch'], best_acc))
    else:
        print("=> no checkpoint found at '{}'".format(d_ckpt))

    def _grad_loader(kld_file, bce_file):
        # One shuffled loader over a KLD-gradient file and its BCE-gradient twin.
        return torch.utils.data.DataLoader(
            datasets.GradDataset([
                os.path.join(dataset_dir, kld_file),
                os.path.join(dataset_dir2, bce_file)
            ]),
            batch_size=batch_size,
            shuffle=True)

    in_test_loader = _grad_loader('00_00_test_%s.pt' % gradient_layer,
                                  '00_00_test_%s.pt' % gradient_layer2)

    for challcnt, test_chall in enumerate(test_challs):
        out_test_loader = _grad_loader(
            '%s_test_%s.pt' % (test_chall, gradient_layer),
            '%s_test_%s.pt' % (test_chall, gradient_layer2))

        # Start evaluation
        timestart = time.time()
        print('\n*** Start Testing *** \n')
        loss, acc, result = d_ext_gradient_train.test(
            d,
            None,
            device,
            in_test_loader,
            1,
            args.print_freq,
            out_iter=iter(out_test_loader),
            is_eval=True)
        print('Total processing time: %.4f' % (time.time() - timestart))

        # Each result entry is (label, prediction); label 1 = inlier, 0 = outlier.
        in_pred = [x[1] for x in result if x[0] == 1]
        out_pred = [x[1] for x in result if x[0] == 0]

        all_results[0, challcnt] = acc
        all_results[1::, challcnt] = cal_metric.calMetric(in_pred, out_pred)

    np.savetxt('./results_%s_%s.csv' % (vae_name, d_ckpt.split('/')[-2]),
               all_results,
               fmt='%.3f',
               delimiter=',')
示例#4
0
def main():
    """Extract per-sample BCE-backprop VAE gradients for CURE-TSR splits.

    For every challenge in ``challenges`` (``00_00`` maps to
    ``RealChallengeFree``, others to ``RealChallenge/<chall>``), each layer in
    ``vae_layers`` and each split in ``sets``, every image is passed through
    ``gradient.VanillaBackprop`` (``loss_type='bce'``). The flattened
    gradient for the layer becomes one row of ``grad_data``.
    ``(labels, grad_data)`` is saved as
    ``<root_dir>/bce_grad/<vae checkpoint dir>/<chall>_<set>_<layer>.pt``.

    Raises:
        ValueError: if a layer name is not recognised.
    """
    global args
    args = parser.parse_args()

    root_dir = args.dataset_dir
    batch_size = 1  # one sample per batch: row i of grad_data is batch i
    vae_ckpt = './checkpoints/cure-tsr/vae/' \
               'vae_BCE_gradient_reducedCnnSeq-4layer_train-00_00_val-00_00/model_best.pth.tar'

    # Extract gradients for training and validation
    # sets = ['train', 'val']
    # vae_layers = ['down_6']
    # challenges = ['00_00', '07_01']

    # Extract gradients for testing
    sets = ['test']
    vae_layers = ['down_6']
    challenges = ['00_00']

    # Challenge types 01-12, levels 01-05.
    challenges += ['%02d_%02d' % (challID, levelID)
                   for challID in range(1, 13)
                   for levelID in range(1, 6)]

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    vae = models.VAECURECNN()
    vae = torch.nn.DataParallel(vae).to(device)
    if os.path.isfile(vae_ckpt):
        print("=> loading checkpoint '{}'".format(vae_ckpt))
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        checkpoint_vae = torch.load(vae_ckpt, map_location=device)
        best_loss = checkpoint_vae['best_loss']
        vae.load_state_dict(checkpoint_vae['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}, best_loss {})".format(
            vae_ckpt, checkpoint_vae['epoch'], best_loss))
    else:
        print("=> no checkpoint found at '{}'".format(vae_ckpt))

    def _grad_dim(layer):
        # Number of gradient entries stored per sample for `layer`.
        if layer == 'latent':
            return vae.module.fc11.weight.shape[0]
        if layer == 'input':
            return 28 * 28 * 3
        block, _, index = layer.partition('_')
        if block == 'down':
            return vae.module.down[int(index)].weight.view(-1).shape[0]
        if block == 'up':
            return vae.module.up[int(index)].weight.view(-1).shape[0]
        # Previously an unknown name silently reused a stale grad_dim.
        raise ValueError('Unknown gradient layer: %s' % layer)

    def _resize_to_tensor():
        return transforms.Compose(
            [transforms.Resize([28, 28]),
             transforms.ToTensor()])

    savedir = os.path.join(root_dir, 'bce_grad', vae_ckpt.split('/')[-2])

    for chall in challenges:
        if chall == '00_00':
            dataset_dir = os.path.join(root_dir, 'RealChallengeFree')
        else:
            dataset_dir = os.path.join(root_dir, 'RealChallenge/%s' % chall)

        for gradient_layer in vae_layers:
            grad_dim = _grad_dim(gradient_layer)
            for set_name in sets:
                print('Extracting %s chall %s layer %s set\n' %
                      (chall, gradient_layer, set_name))

                data_loader = torch.utils.data.DataLoader(
                    datasets.CURETSRdataset(
                        os.path.join(dataset_dir, set_name),
                        transform=_resize_to_tensor(),
                        target_transform=_resize_to_tensor()),
                    batch_size=batch_size,
                    shuffle=True)

                # Start evaluation
                vbackprop = gradient.VanillaBackprop(vae, loss_type='bce')
                timestart = time.time()
                print('\n*** Start Testing *** \n')

                num_batches = len(data_loader)
                labels = torch.zeros(num_batches)
                grad_data = torch.zeros(num_batches, grad_dim)
                for batch_idx, (input_img, target_img,
                                class_label) in enumerate(data_loader):
                    print('Processing...(%d / %d)' % (batch_idx, num_batches))
                    input_img = input_img.to(device)
                    target_img = target_img.to(device)

                    vbackprop.generate_gradients(input_img, target_img)
                    grad_data[batch_idx, :] = vbackprop.gradients[
                        gradient_layer].view(1, -1)
                    labels[batch_idx] = class_label

                os.makedirs(savedir, exist_ok=True)
                torch.save(
                    (labels, grad_data),
                    os.path.join(
                        savedir,
                        '%s_%s_%s.pt' % (chall, set_name, gradient_layer)))
                print('Total processing time: %.4f' %
                      (time.time() - timestart))
示例#5
0
def main():
    """Train either the CURE-TSR VAE or a pixel-space discriminator.

    The mode is chosen by the local flags below:

    * ``dis_train = 0``: train ``models.VAECURECNN`` with ``vae_train`` on the
      in-distribution classes ``in_cls`` of ``RealChallengeFree``.
    * ``dis_train = 1``: train ``models.DisShallowLinear`` with ``d_train`` to
      separate in-distribution images from images of challenge ``chall``.
      A pre-trained VAE is passed to ``d_train`` when ``use_vae = 1``.

    Side effects: stores the parsed command-line arguments in the module-level
    ``args``; when ``args.write_enable`` is set, creates the checkpoint
    directory (failing if it already exists), writes TensorBoard logs and
    saves checkpoints via ``utils.save_checkpoint``.
    """
    global args
    args = parser.parse_args()

    dis_train = 0  # 0: vae_train, 1: d_train
    use_vae = 0  # 0: Not use vae for discriminator training 1: Use vae for discriminator training
    std = 1.0
    chall = '07_01'

    if dis_train:
        if use_vae:
            vae_ckpt = './checkpoints/cure-tsr/vae/' \
                       'vae_BCE_gradient_reducedCnnSeq-4layer_train-00_00_val-00_00/model_best.pth.tar'
            savedir = 'cure-tsr/d/%s/d_BCE_ShallowLinear_in-00_00_out-%s' \
                      % (vae_ckpt.split('/')[-2], chall)
        else:
            savedir = 'cure-tsr/d/d_BCE_ShallowLinear_in-00_00_out-%s' % chall
    else:
        savedir = 'vae/vae_BCE_gradient_reducedCnnSeq-4layer_train-14_val-14'

    checkpointdir = os.path.join('./checkpoints', savedir)
    logdir = os.path.join('./logs', savedir)

    seed = random.randint(1, 100000)
    torch.manual_seed(seed)
    dataset_dir = '/media/gukyeong/HardDisk/dataset/CURE-TSR/folds/RealChallengeFree'
    # Strip the trailing 'Free' to get the sibling '.../RealChallenge' root.
    chall_root = dataset_dir[:-4]

    if args.write_enable:
        # makedirs, not mkdir: the nested parent directories may not exist yet.
        os.makedirs(checkpointdir)
        writer = SummaryWriter(log_dir=logdir)
        print('log directory: %s' % logdir)
        print('checkpoints directory: %s' % checkpointdir)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if dis_train:
        best_score = 1e20  # Initinalization of a variable for accuracy
        batch_size = 64  # there will be additional 64 fake samples
        d = models.DisShallowLinear(28 * 28 * 3)
        d = torch.nn.DataParallel(d).to(device)
        d.apply(models.weights_init)
        optimizer = optim.Adam(d.parameters(), lr=1e-3)

        if use_vae:
            vae = models.VAECURELinear()
            vae = torch.nn.DataParallel(vae).to(device)
            vae.eval()
            if os.path.isfile(vae_ckpt):
                print("=> loading checkpoint '{}'".format(vae_ckpt))
                # map_location lets a GPU-saved checkpoint load on CPU.
                checkpoint = torch.load(vae_ckpt, map_location=device)
                best_loss = checkpoint['best_loss']
                vae.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {}, best_loss {})"
                      .format(vae_ckpt, checkpoint['epoch'], best_loss))
            else:
                print("=> no checkpoint found at '{}'".format(vae_ckpt))

    else:
        vae = models.VAECURECNN()
        vae = torch.nn.DataParallel(vae).to(device)
        best_score = 1e20
        batch_size = 128
        optimizer = optim.Adam(vae.parameters(), lr=1e-3)

        if args.resume:
            vae_resume_ckpt = './checkpoints/cure-tsr/vae/' \
                              'vae_BCE_gradient_reducedCnnSeq-4layer_train-00_00_val-00_00/model_best.pth.tar'
            if os.path.isfile(vae_resume_ckpt):
                print("=> loading checkpoint '{}'".format(vae_resume_ckpt))
                checkpoint = torch.load(vae_resume_ckpt, map_location=device)
                # best_loss = checkpoint['best_loss']
                vae.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {}, best_loss {})"
                      .format(vae_resume_ckpt, checkpoint['epoch'], checkpoint['best_loss']))
            else:
                print("=> no checkpoint found at '{}'".format(vae_resume_ckpt))

    # stays 0 if no epoch improves, so the summary print cannot fail
    best_epoch = 0

    # The data loaders and training loop used to sit inside the VAE `else:`
    # branch, so with dis_train = 1 nothing was ever trained. They are shared
    # by both modes.

    # CURE-TSR (Train : Val : Test = 6 : 2 : 2)
    nsamples_fold = [325, 71, 17, 10, 235, 26, 295, 59, 20, 101, 16, 31, 179, 735]  # 14 classes
    in_cls = [13]
    out_cls = list(range(13))
    _, nsamples_per_cls = utils.cal_nsample_perclass(in_cls, out_cls, nsamples_fold)
    out_nsamples = nsamples_per_cls[len(in_cls):]  # counts for out_cls only

    def _resize_to_tensor():
        return transforms.Compose([transforms.Resize([28, 28]),
                                   transforms.ToTensor()])

    def _make_loader(split_dir, cls, **kwargs):
        # Shuffled loader over one split; extra kwargs go to CURETSRdataset.
        return torch.utils.data.DataLoader(
            datasets.CURETSRdataset(split_dir,
                                    transform=_resize_to_tensor(),
                                    target_transform=_resize_to_tensor(),
                                    cls=cls, **kwargs),
            batch_size=batch_size, shuffle=True)

    in_train_loader = _make_loader(os.path.join(dataset_dir, 'train'), in_cls)
    in_val_loader = _make_loader(os.path.join(dataset_dir, 'val'), in_cls)
    out_train_loader = _make_loader(os.path.join(chall_root, chall, 'train'), out_cls,
                                    nsamples_per_cls=[3 * i for i in out_nsamples])
    out_val_loader = _make_loader(os.path.join(chall_root, chall, 'val'), out_cls,
                                  nsamples_per_cls=list(out_nsamples))

    # Start training
    timestart = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        print('\n*** Start Training *** Epoch: [%d/%d]\n' % (epoch + 1, args.epochs))
        if dis_train:
            if use_vae:
                d_train.train(d, device, in_train_loader, optimizer, epoch + 1, args.print_freq, vae=vae,
                              std=std, out_iter=iter(out_train_loader))
            else:
                d_train.train(d, device, in_train_loader, optimizer, epoch + 1, args.print_freq,
                              out_iter=iter(out_train_loader))

            print('\n*** Start Testing *** Epoch: [%d/%d]\n' % (epoch + 1, args.epochs))

            if use_vae:
                loss, acc, _ = d_train.test(d, device, in_val_loader, epoch + 1, args.print_freq, vae=vae,
                                            std=std, out_iter=iter(out_val_loader))
            else:
                loss, acc, _ = d_train.test(d, device, in_val_loader, epoch + 1, args.print_freq,
                                            out_iter=iter(out_val_loader))
        else:
            vae_train.train(vae, device, in_train_loader, optimizer, epoch + 1, args.print_freq)
            print('\n*** Start Testing *** Epoch: [%d/%d]\n' % (epoch + 1, args.epochs))
            loss, input_img, recon_img, target_img = vae_train.test(vae, device, in_val_loader, epoch + 1,
                                                                    args.print_freq)

        # Lower validation loss is better in both modes.
        is_best = loss < best_score
        best_score = min(loss, best_score)
        if is_best:
            best_epoch = epoch + 1

        if not args.write_enable:
            continue

        # Plain truthiness instead of `is True` (numpy/torch booleans are not
        # the True singleton). Best epochs are saved in both modes so that a
        # best epoch off the write_freq grid still reaches the best checkpoint.
        if not (epoch % args.write_freq == 0 or is_best):
            continue

        writer.add_scalar('loss', loss, epoch + 1)
        if dis_train:
            writer.add_scalar('accuracy', acc, epoch + 1)
            model = d
        else:
            writer.add_image('input_img', vutils.make_grid(input_img, nrow=3), epoch + 1)
            writer.add_image('recon_img', vutils.make_grid(recon_img, nrow=3), epoch + 1)
            writer.add_image('target_img', vutils.make_grid(target_img, nrow=3), epoch + 1)
            model = vae

        utils.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_loss': best_score,
            'last_loss': loss,
            'optimizer': optimizer.state_dict(),
        }, is_best, checkpointdir)

    if args.write_enable:
        writer.close()

    print('Best Testing Acc/Loss: %.3f at epoch %d' % (best_score, best_epoch))
    print('Best epoch: ', best_epoch)
    print('Total processing time: %.4f' % (time.time() - timestart))
def main():
    """Train the CURE-TSR VAE and checkpoint it on validation loss.

    Parses command-line options into the module-level ``args`` and seeds torch
    with a random seed. If ``args.resume`` is set, it resumes from a fixed VAE
    checkpoint. It then runs one training epoch followed by one validation
    pass for every epoch in ``range(args.start_epoch, args.epochs)``.

    When ``args.write_enable`` is set, it does the following:

    * It creates the checkpoint directory, which must not already exist.
    * It logs the loss and sample images to TensorBoard.
    * It saves a checkpoint every ``args.write_freq`` epochs and on every
      epoch that improves the best validation loss.
    """
    global args
    args = parser.parse_args()

    savedir = 'vae/vae_BCE_gradient_reducedCnnSeq-4layer_train-00_00_val-00_00'
    checkpointdir = os.path.join('./checkpoints', savedir)
    logdir = os.path.join('./logs', savedir)

    seed = random.randint(1, 100000)
    torch.manual_seed(seed)
    dataset_dir = args.dataset_dir

    if args.write_enable:
        os.makedirs(checkpointdir)
        writer = SummaryWriter(log_dir=logdir)
        print('log directory: %s' % logdir)
        print('checkpoints directory: %s' % checkpointdir)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    vae = models.VAECURECNN()
    vae = torch.nn.DataParallel(vae).to(device)
    best_score = 1e20
    # Initialised so the final summary still works when the loss never
    # improves (e.g. NaN loss or an empty epoch range). Previously this
    # raised UnboundLocalError.
    best_epoch = 0
    batch_size = 128
    optimizer = optim.Adam(vae.parameters(), lr=1e-3)

    if args.resume:
        vae_resume_ckpt = './checkpoints/cure-tsr/vae/' \
                          'vae_BCE_gradient_reducedCnnSeq-4layer_train-00_00_val-00_00/model_best.pth.tar'
        if os.path.isfile(vae_resume_ckpt):
            print("=> loading checkpoint '{}'".format(vae_resume_ckpt))
            # map_location lets a GPU-saved checkpoint load when `device` fell
            # back to the CPU.
            checkpoint = torch.load(vae_resume_ckpt, map_location=device)
            vae.load_state_dict(checkpoint['state_dict'])
            # Restore the optimizer state and the best score that were saved
            # alongside the weights. Otherwise Adam restarts from scratch and
            # the first resumed epoch always overwrites model_best, even when
            # it is worse.
            if 'optimizer' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
            best_score = checkpoint['best_loss']
            print("=> loaded checkpoint '{}' (epoch {}, best_loss {})".format(
                vae_resume_ckpt, checkpoint['epoch'], checkpoint['best_loss']))
        else:
            print("=> no checkpoint found at '{}'".format(vae_resume_ckpt))

    def _make_loader(split):
        """Return a shuffled loader over ``dataset_dir/split``; inputs and targets are resized to 28x28."""
        return torch.utils.data.DataLoader(
            datasets.CURETSRdataset(
                os.path.join(dataset_dir, split),
                transform=transforms.Compose(
                    [transforms.Resize([28, 28]),
                     transforms.ToTensor()]),
                target_transform=transforms.Compose(
                    [transforms.Resize([28, 28]),
                     transforms.ToTensor()])),
            batch_size=batch_size,
            shuffle=True)

    in_train_loader = _make_loader('train')
    in_val_loader = _make_loader('val')

    # Start training
    timestart = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        print('\n*** Start Training *** Epoch: [%d/%d]\n' %
              (epoch + 1, args.epochs))
        vae_train.train(vae, device, in_train_loader, optimizer, epoch + 1,
                        args.print_freq)
        print('\n*** Start Testing *** Epoch: [%d/%d]\n' %
              (epoch + 1, args.epochs))
        loss, input_img, recon_img, target_img = vae_train.test(
            vae, device, in_val_loader, epoch + 1, args.print_freq)

        is_best = loss < best_score
        best_score = min(loss, best_score)

        if is_best:
            best_epoch = epoch + 1

        # Also write on improving epochs. Otherwise the best model is lost
        # whenever it does not land on a write_freq boundary.
        if args.write_enable and (epoch % args.write_freq == 0 or is_best):
            writer.add_scalar('loss', loss, epoch + 1)
            writer.add_image('input_img',
                             vutils.make_grid(input_img, nrow=3),
                             epoch + 1)
            writer.add_image('recon_img',
                             vutils.make_grid(recon_img, nrow=3),
                             epoch + 1)
            writer.add_image('target_img',
                             vutils.make_grid(target_img, nrow=3),
                             epoch + 1)

            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': vae.state_dict(),
                    'best_loss': best_score,
                    'last_loss': loss,
                    'optimizer': optimizer.state_dict(),
                }, is_best, checkpointdir)

    if args.write_enable:
        writer.close()

    print('Best Testing Acc/Loss: %.3f at epoch %d' % (best_score, best_epoch))
    print('Best epoch: ', best_epoch)
    print('Total processing time: %.4f' % (time.time() - timestart))
def main():
    """Evaluate the trained VAE + discriminator on every CURE-TSR challenge.

    The VAE and discriminator checkpoints are loaded once. For each challenge
    type (1-12) and level (1-5), the following happens:

    * The challenge-free test split (inliers) is paired with that challenge's
      test split (outliers).
    * ``d_train.test`` runs in evaluation mode.
    * The results fill one column of ``all_results``.

    The table is written to ``<vae run name>_<discriminator run name>.csv`` in
    the working directory.
    """
    global args
    args = parser.parse_args()

    batch_size = 64
    train_chall = '07_01'  # outlier challenge the discriminator was trained on
    # Row 0 holds the accuracy from d_train.test. Rows 1-4 hold the values
    # returned by cal_metric.calMetric. There is one column per
    # (challenge, level) pair.
    all_results = np.zeros([5, 12 * 5])

    vae_ckpt = './checkpoints/cure-tsr/vae/' \
               'vae_BCE_gradient_reducedCnnSeq-4layer_train-00_00_val-00_00/model_best.pth.tar'
    d_ckpt = './checkpoints/cure-tsr/d/%s/d_BCE_ShallowLinear_in-00_00_out-%s' \
             % (vae_ckpt.split('/')[-2], train_chall) + '/model_best.pth.tar'

    dataset_dir = '/media/gukyeong/HardDisk/CURE-TSR/folds/RealChallengeFree'
    # [:-4] drops the trailing 'Free', so the challenge splits are read from
    # .../folds/RealChallenge/<challenge>/test.
    challenge_root = dataset_dir[:-4]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def _load_checkpoint(model, ckpt_path):
        """Load ``state_dict`` from ``ckpt_path`` into ``model``; only warn if the file is missing."""
        if not os.path.isfile(ckpt_path):
            print("=> no checkpoint found at '{}'".format(ckpt_path))
            return
        print("=> loading checkpoint '{}'".format(ckpt_path))
        # map_location lets GPU-saved checkpoints load on the CPU fallback.
        checkpoint = torch.load(ckpt_path, map_location=device)
        best_loss = checkpoint['best_loss']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}, best_loss {})".format(
            ckpt_path, checkpoint['epoch'], best_loss))

    def _make_loader(root):
        """Return a shuffled loader over ``root``; inputs and targets are resized to 28x28."""
        return torch.utils.data.DataLoader(
            datasets.CURETSRdataset(root,
                                    transform=transforms.Compose([
                                        transforms.Resize([28, 28]),
                                        transforms.ToTensor()
                                    ]),
                                    target_transform=transforms.Compose([
                                        transforms.Resize([28, 28]),
                                        transforms.ToTensor()
                                    ])),
            batch_size=batch_size,
            shuffle=True)

    # The models, their checkpoints and the inlier loader do not depend on the
    # challenge under test. Build them once instead of once per challenge.
    vae = models.VAECURECNN()
    vae = torch.nn.DataParallel(vae).to(device)
    _load_checkpoint(vae, vae_ckpt)

    d = models.DisShallowLinear(20)
    d = torch.nn.DataParallel(d).to(device)
    _load_checkpoint(d, d_ckpt)

    # CURE-TSR (Train : Val : Test = 6 : 2 : 2)
    in_test_loader = _make_loader(os.path.join(dataset_dir, 'test'))

    challcnt = 0
    for challID in range(1, 13):
        for levelID in range(1, 6):
            test_chall = '%02d_%02d' % (challID, levelID)
            print('\n*** Test on Challenge %s ***\n' % test_chall)

            out_test_loader = _make_loader(
                os.path.join(challenge_root, test_chall, 'test'))

            # Reset mode and gradient buffers on every run. This reproduces the
            # freshly-loaded state each challenge used to start from, whatever
            # d_train.test leaves behind.
            vae.eval()
            d.eval()
            vae.zero_grad()
            d.zero_grad()

            # Start evaluation
            timestart = time.time()
            print('\n*** Start Testing *** \n')

            loss, acc, result = d_train.test(d,
                                             device,
                                             in_test_loader,
                                             1,
                                             args.print_freq,
                                             vae=vae,
                                             std=1.0,
                                             out_iter=iter(out_test_loader),
                                             is_eval=True)

            # Each result entry is (label, prediction); label 1 = inlier, 0 = outlier.
            in_pred = [x[1] for x in result if x[0] == 1]
            out_pred = [x[1] for x in result if x[0] == 0]

            all_results[0, challcnt] = acc
            all_results[1:, challcnt] = cal_metric.calMetric(in_pred, out_pred)
            challcnt += 1
            print('Total processing time: %.4f' % (time.time() - timestart))

    np.savetxt('%s_%s.csv' % (vae_ckpt.split('/')[-2], d_ckpt.split('/')[-2]),
               all_results,
               fmt='%.3f',
               delimiter=',')