コード例 #1
0
ファイル: test_attention.py プロジェクト: rnsandeep/Attention
def load_attention_model(model_path, num_classes):
    """Build an attention-enabled AttnVGG and load weights from a checkpoint.

    Args:
        model_path: path to a checkpoint file whose ``'state_dict'`` entry
            holds the AttnVGG weights.
        num_classes: number of output classes the checkpoint was trained with.

    Returns:
        The model with loaded weights, moved to the module-level ``device``.
        The model's train/eval mode is left untouched.
    """
    model = AttnVGG(num_classes=num_classes,
                    attention=True,
                    normalize_attn=True)
    # Deserialize onto the CPU so a checkpoint saved on a GPU also loads on a
    # machine without that GPU; the whole model is moved to `device` below.
    checkpoint = torch.load(model_path, map_location="cpu")
    model.load_state_dict(checkpoint['state_dict'])

    return model.to(device)
コード例 #2
0
        return delta_im

if __name__ == "__main__":
    print("======>load pretrained models")
    net = AttnVGG(num_classes=2,
                  attention=True,
                  normalize_attn=True,
                  vis=False)
    # net = VGG(num_classes=2, gap=False)
    modelFile = "/home/lrh/git/libadver/examples/IPIM-AttnModel/models/checkpoint.pth"
    testCSVFile = "/home/lrh/git/libadver/examples/IPIM-AttnModel/test.csv"
    trainCSVFile = "/home/lrh/git/libadver/examples/IPIM-AttnModel/train.csv"

    checkpoint = torch.load(modelFile)
    net.load_state_dict(checkpoint['state_dict'])
    pretrained_clf = nn.DataParallel(net).cuda()
    pretrained_clf.eval()

    print("=======>load ISIC2016 dataset")
    mean = (0.7012, 0.5517, 0.4875)
    std = (0.0942, 0.1331, 0.1521)
    normalize = Normalize(mean, std)
    transform = torch_transforms.Compose([
        RatioCenterCrop(0.8),
        Resize((256, 256)),
        CenterCrop((224, 224)),
        ToTensor(), normalize
    ])
    testset = ISIC(csv_file=testCSVFile, transform=transform)
    testloader = torch.utils.data.DataLoader(testset,
コード例 #3
0
def main():
    """Evaluate a trained AttnVGG checkpoint on the ISIC test split.

    Configuration comes from the module-level ``opt`` (``dataset``,
    ``no_attention``, ``normalize_attn``, ``outf``, ``log_images``,
    ``base_up_factor``) and the module-level ``device`` / ``device_ids``.
    Weights are read from ``checkpoint.pth`` and the test set from
    ``test.csv``. Per-image softmax responses are written to
    ``test_results.csv``. When ``opt.log_images`` is set, image grids and
    attention maps are logged to TensorBoard under ``opt.outf``. Accuracy,
    precision/recall and AP/AUC (from ``compute_metrics``) are printed.

    Raises:
        ValueError: if ``opt.dataset`` is not "ISIC2016" or "ISIC2017".
    """
    # load data
    print('\nloading the dataset ...')
    # Per-dataset channel mean/std used for input normalization.
    if opt.dataset == "ISIC2016":
        normalize = Normalize((0.7012, 0.5517, 0.4875),
                              (0.0942, 0.1331, 0.1521))
    elif opt.dataset == "ISIC2017":
        normalize = Normalize((0.6820, 0.5312, 0.4736),
                              (0.0840, 0.1140, 0.1282))
    else:
        # Explicit raise rather than `assert`: asserts are stripped under
        # `python -O`, which would leave `normalize` unbound.
        raise ValueError("unsupported dataset: %r" % (opt.dataset,))
    transform_test = torch_transforms.Compose([
        RatioCenterCrop(0.8),
        Resize((256, 256)),
        CenterCrop((224, 224)),
        ToTensor(), normalize
    ])
    testset = ISIC(csv_file='test.csv', transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=64,
                                             shuffle=False,
                                             num_workers=8)
    print('done\n')

    # load network
    print('\nloading the model ...')
    if not opt.no_attention:
        print('turn on attention ...')
        if opt.normalize_attn:
            print('use softmax for attention map ...')
        else:
            print('use sigmoid for attention map ...')
    else:
        print('turn off attention ...')

    net = AttnVGG(num_classes=2,
                  attention=not opt.no_attention,
                  normalize_attn=opt.normalize_attn)
    # Deserialize onto the CPU so a GPU-saved checkpoint loads anywhere; the
    # wrapped model is moved to `device` right after.
    checkpoint = torch.load('checkpoint.pth', map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    model = nn.DataParallel(net, device_ids=device_ids).to(device)
    model.eval()
    print('done\n')

    # testing
    print('\nstart testing ...\n')
    writer = SummaryWriter(opt.outf)
    total = 0
    correct = 0
    try:
        with torch.no_grad():
            with open('test_results.csv', 'wt', newline='') as csv_file:
                csv_writer = csv.writer(csv_file, delimiter=',')
                for i, data in enumerate(testloader, 0):
                    images_test = data['image'].to(device)
                    labels_test = data['label'].to(device)
                    # One forward pass supplies both the logits and the
                    # attention maps used for logging below.
                    pred_test, a1, a2 = model(images_test)
                    predict = torch.argmax(pred_test, 1)
                    total += labels_test.size(0)
                    correct += torch.eq(predict,
                                        labels_test).sum().double().item()
                    # Record one row of class probabilities per image. No
                    # squeeze(): a final batch of a single image must stay
                    # 2-D so that every row passed to writerows is iterable.
                    responses = F.softmax(pred_test, dim=1).cpu().numpy()
                    csv_writer.writerows(responses)
                    # log images
                    if opt.log_images:
                        I_test = utils.make_grid(images_test,
                                                 nrow=8,
                                                 normalize=True,
                                                 scale_each=True)
                        writer.add_image('test/image', I_test, i)
                        # attention maps
                        if not opt.no_attention:
                            if a1 is not None:
                                attn1 = visualize_attn(
                                    I_test,
                                    a1,
                                    up_factor=opt.base_up_factor,
                                    nrow=8)
                                writer.add_image('test/attention_map_1',
                                                 attn1, i)
                            if a2 is not None:
                                attn2 = visualize_attn(
                                    I_test,
                                    a2,
                                    up_factor=2 * opt.base_up_factor,
                                    nrow=8)
                                writer.add_image('test/attention_map_2',
                                                 attn2, i)
    finally:
        # Flush pending TensorBoard events and release the event file.
        writer.close()
    AP, AUC, precision_mean, precision_mel, recall_mean, recall_mel = compute_metrics(
        'test_results.csv', 'test.csv')
    print("\ntest result: accuracy %.2f%%" % (100 * correct / total))
    print(
        "\nmean precision %.2f%% mean recall %.2f%% \nprecision for mel %.2f%% recall for mel %.2f%%"
        % (100 * precision_mean, 100 * recall_mean, 100 * precision_mel,
           100 * recall_mel))
    print("\nAP %.4f AUC %.4f\n" % (AP, AUC))
コード例 #4
0
def main():
    """Evaluate an AttnVGG checkpoint against PGD adversarial examples on ISIC2016.

    Loads the checkpoint at ``modelFile`` and builds the test set from
    ``testCSVFile``. Each test batch is attacked with
    ``attack.ProjectGradientDescent``, using a copy of the module-level
    ``pgd_params`` extended with per-batch ``y``, ``clip_max`` and
    ``clip_min``. The function then prints:

    * the ROC AUC of the class-1 probability on the adversarial images,
    * ``adv_ACC``: the fraction of adversarial images classified correctly,
    * ``FR``: the fraction of images whose predicted class differs between
      the clean and the adversarial input.

    Raises:
        RuntimeError: if the test loader yields no batches.
    """
    modelFile = "/home/lrh/store/modelpath/adversarial_defense/Dermothsis/adv_training.pth"
    testCSVFile = "/home/lrh/git/libadver/examples/IPIM-AttnModel/test.csv"
    print("======>load pretrained models")
    net = AttnVGG(num_classes=2,
                  attention=True,
                  normalize_attn=False,
                  vis=False)
    checkpoint = torch.load(modelFile)
    net.load_state_dict(checkpoint['state_dict'])
    pretrained_clf = nn.DataParallel(net).cuda()
    pretrained_clf.eval()

    print("=======>load ISIC2016 dataset")
    normalize = Normalize((0.7012, 0.5517, 0.4875), (0.0942, 0.1331, 0.1521))
    transform_test = torch_transforms.Compose([
        RatioCenterCrop(0.8),
        Resize((256, 256)),
        CenterCrop((224, 224)),
        ToTensor(), normalize
    ])
    testset = ISIC(csv_file=testCSVFile, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=8,
                                             shuffle=False,
                                             num_workers=4)

    PGDAttack = attack.ProjectGradientDescent(model=pretrained_clf)

    # Collect per-batch CPU tensors and concatenate once at the end
    # (repeated torch.cat inside the loop is quadratic in the dataset size).
    gt_batches = []
    pred_batches = []
    adv_pred_batches = []
    adv_score_batches = []

    for i, data in enumerate(testloader, 0):
        print(i)
        images_test = data['image'].cuda()
        labels_test = data['label'].cuda()

        # Per-batch attack parameters. The dict is copied so the module-level
        # `pgd_params` is not mutated.
        batch_params = dict(pgd_params)
        batch_params['y'] = labels_test
        batch_params['clip_max'] = torch.max(images_test)
        batch_params['clip_min'] = torch.min(images_test)
        adv_x = PGDAttack.generate(images_test, **batch_params)

        # Plain evaluation passes; no autograd graph is needed here.
        with torch.no_grad():
            outputs_x, _, _ = pretrained_clf(images_test)
            outputs_advx, _, _ = pretrained_clf(adv_x)

        gt_batches.append(labels_test.float().cpu())
        pred_batches.append(torch.argmax(outputs_x, dim=1).float().cpu())
        adv_pred_batches.append(
            torch.argmax(outputs_advx, dim=1).float().cpu())
        # Probability of class 1, the positive label for roc_auc_score.
        adv_score_batches.append(torch.softmax(outputs_advx, dim=1)[:, 1].cpu())

    if not gt_batches:
        raise RuntimeError("test set is empty: %s" % testCSVFile)

    gt = torch.cat(gt_batches, 0)
    pred = torch.cat(pred_batches, 0)
    pred_advx = torch.cat(adv_pred_batches, 0)
    score_advx = torch.cat(adv_score_batches, 0)

    # ROC AUC needs a continuous score. Hard argmax labels collapse the curve
    # to a single operating point, which makes the AUC meaningless.
    AUC_ROC = roc_auc_score(gt.numpy(), score_advx.numpy())
    print("\nArea under the ROC curve: " + str(AUC_ROC))

    total = len(gt)
    correct_acc = gt.eq(pred_advx).sum()
    correct_fr = pred.eq(pred_advx).sum()

    print("adv_ACC: %.8f" % (float(correct_acc) / total))
    print("FR: %.8f" % (1 - float(correct_fr) / total))