Beispiel #1
0
def main():
    start_epoch = 0

    save_model = "./save_model"
    tensorboard_dir = "./tensorboard/OOD"
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


    # Hyper-parameters
    eps = 1e-8

    ### data config
    train_dataset = load_data.Dog_dataloader(image_dir = image_dir,
                                                    num_class = args.num_classes,
                                                    mode = "train")
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=2)

    test_dataset = load_data.Dog_dataloader(image_dir = image_dir,
                                         num_class = args.num_classes,
                                         mode = "test")
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=8,
                                              shuffle=False,
                                              num_workers=2)

    ##### model, optimizer config
    if args.net_type == "resnet50":
        model = models.resnet50(num_c=args.num_classes, pretrained=True)
    elif args.net_type == "resnet34":
        model = models.resnet34(num_c=args.num_classes, pretrained=True)


    # optimizer = optim.Adam(model.parameters(), lr=args.init_lr, weight_decay=1e-5)
    optimizer = optim.SGD(model.parameters(), lr=args.init_lr, momentum=0.9, nesterov=True)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           args.num_epochs * len(train_loader),
                                                           )



    if args.resume == True:
        print("load checkpoint_last")
        checkpoint = torch.load(os.path.join(save_model, "checkpoint_last.pth.tar"))

        ##### load model
        model.load_state_dict(checkpoint["model"])
        start_epoch = checkpoint["epoch"]
        optimizer = optim.Adam(model.parameters(), lr = checkpoint["init_lr"])

    #### loss config
    criterion = nn.BCEWithLogitsLoss()

    #### create folder
    Path(os.path.join(save_model, env, args.net_type)).mkdir(exist_ok=True, parents=True)

    if args.board_clear == True:
        files = glob.glob(tensorboard_dir+"/*")
        for f in files:
            shutil.rmtree(f)
    i = 0
    while True:
        if Path(os.path.join(tensorboard_dir, str(i))).exists() == True:
            i += 1
        else:
            Path(os.path.join(tensorboard_dir, str(i))).mkdir(exist_ok=True, parents=True)
            break
    summary = SummaryWriter(os.path.join(tensorboard_dir, str(i)))


    # Start training
    j=0
    best_score=0
    score = 0

    for epoch in range(start_epoch, args.num_epochs):
        for i in range(args.num_classes):
            locals()["train_label{}".format(i)] = 0
            locals()["test_label{}".format(i)] = 0
        total_loss = 0
        train_acc = 0
        test_acc = 0
        stime = time.time()

        for i, train_data in enumerate(train_loader):
            #### initialized
            org_image = train_data['input'].to(device)
            gt = train_data['label'].type(torch.FloatTensor).to(device)

            model = model.to(device).train()
            optimizer.zero_grad()

            #### forward path
            output = model(org_image)

            #### calc loss
            class_loss = criterion(output, gt)

            #### calc accuracy
            train_acc += sum(torch.argmax(torch.sigmoid(output), dim=1) == torch.argmax(gt, dim=1)).cpu().detach().item()

            gt_label = torch.argmax(gt, dim=1).cpu().detach().tolist()
            output_label = torch.argmax(torch.sigmoid(output), dim=1).cpu().detach().tolist()
            for idx, label in enumerate(gt_label):
                if label == output_label[idx]:
                    locals()["train_label{}".format(label)] += 1

            with autograd.detect_anomaly():
                class_loss.backward()
                optimizer.step()
            scheduler.step()

            total_loss += class_loss.item()


        with torch.no_grad():
            for i, test_data in enumerate(test_loader):
                org_image = test_data['input'].to(device)
                gt = test_data['label'].type(torch.FloatTensor).to(device)

                model = model.to(device).eval()
                #### forward path
                output = model(org_image)

                gt_label = torch.argmax(gt, dim=1).cpu().detach().tolist()
                output_label = torch.argmax(torch.sigmoid(output), dim=1).cpu().detach().tolist()
                for idx, label in enumerate(gt_label):
                    if label == output_label[idx]:
                        locals()["test_label{}".format(label)] += 1

                test_acc += sum(torch.argmax(torch.sigmoid(output), dim=1) == torch.argmax(gt, dim=1)).cpu().detach().item()


        print('Epoch [{}/{}], Step {}, loss = {:.4f}, exe time: {:.2f}, lr: {:.4f}*e-4'
                  .format(epoch, args.num_epochs, i+1,
                          total_loss/len(train_loader),
                          time.time() - stime,
                          scheduler.get_last_lr()[0] * 10 ** 4))

        print("train accuracy total : {:.4f}".format(train_acc/train_data.num_image))
        for num in range(args.num_classes):
            print("label{} : {:.4f}"
                  .format(num, locals()["train_label{}".format(num)]/train_data.len_list[num])
                  , end=" ")
        print()
        print("test accuracy total : {:.4f}".format(test_acc/test_data.num_image))
        for num in range(args.num_classes):
            print("label{} : {:.4f}"
                  .format(num, locals()["test_label{}".format(num)]/test_data.len_list[num])
                  , end=" ")
        print("\n")



        summary.add_scalar('loss/loss', total_loss/len(train_loader), epoch)
        summary.add_scalar('acc/train_acc', train_acc/train_data.num_image, epoch)
        summary.add_scalar('acc/test_acc', test_acc/test_data.num_image, epoch)
        summary.add_scalar("learning_rate/lr", scheduler.get_last_lr()[0], epoch)
        time.sleep(0.001)
        torch.save({
            'model': model.state_dict(),
            'epoch': epoch,
            'init_lr' : scheduler.get_last_lr()[0]
            }, os.path.join(save_model, env,args.net_type, 'checkpoint_last.pth.tar'))
Beispiel #2
0
def main():
    start_epoch = 0

    pretrained_model = os.path.join("./pre_trained", args.dataset,
                                    args.net_type + ".pth.tar")
    save_model = "./save_model_dis/pre_training"
    tensorboard_dir = "./tensorboard/OOD_dis/pre_training" + args.dataset
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Hyper-parameters
    eps = 1e-8

    ### data config
    train_dataset = load_data.Dog_metric_dataloader(image_dir=image_dir,
                                                    num_class=args.num_classes,
                                                    mode="train",
                                                    soft_label=args.soft_label)
    if args.custom_sampler:
        MySampler = load_data.customSampler(train_dataset, args.batch_size,
                                            args.num_instances)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_sampler=MySampler,
                                                   num_workers=2)
    else:
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=2)

    test_dataset = load_data.Dog_dataloader(image_dir=image_dir,
                                            num_class=args.num_classes,
                                            mode="test")
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=8,
                                              shuffle=False,
                                              num_workers=2)

    out_test_dataset = load_data.Dog_dataloader(image_dir=image_dir,
                                                num_class=args.num_classes,
                                                mode="OOD")
    out_test_loader = torch.utils.data.DataLoader(out_test_dataset,
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=2)

    if args.transfer:
        ### perfectly OOD data
        OOD_dataset = load_data.Dog_dataloader(image_dir=OOD_dir,
                                               num_class=args.OOD_num_classes,
                                               mode="OOD")
        OOD_loader = torch.utils.data.DataLoader(OOD_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 num_workers=2)

    ##### model, optimizer config
    if args.net_type == "resnet50":
        model = models.resnet50(num_c=args.num_classes, pretrained=True)
    elif args.net_type == "resnet34":
        model = models.resnet34(num_c=args.num_classes, pretrained=True)
    elif args.net_type == "vgg19":
        model = models.vgg19(num_c=args.num_classes, pretrained=True)
    elif args.net_type == "vgg16":
        model = models.vgg16(num_c=args.num_classes, pretrained=True)
    elif args.net_type == "vgg19_bn":
        model = models.vgg19_bn(num_c=args.num_classes, pretrained=True)
    elif args.net_type == "vgg16_bn":
        model = models.vgg16_bn(num_c=args.num_classes, pretrained=True)

    if args.transfer:
        extra_fc = nn.Linear(2048, args.num_classes + args.OOD_num_classes)

    if args.load == True:
        print("loading model")
        checkpoint = torch.load(pretrained_model)

        ##### load model
        model.load_state_dict(checkpoint["model"])

    batch_num = len(
        train_loader) / args.batch_size if args.custom_sampler else len(
            train_loader)

    optimizer = optim.SGD(model.parameters(),
                          lr=args.init_lr,
                          momentum=0.9,
                          nesterov=args.nesterov)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.num_epochs * batch_num)

    #### loss config
    criterion = nn.BCEWithLogitsLoss()

    #### create folder
    Path(os.path.join(save_model, env, args.net_type)).mkdir(exist_ok=True,
                                                             parents=True)

    if args.board_clear == True:
        files = glob.glob(tensorboard_dir + "/*")
        for f in files:
            shutil.rmtree(f)
    i = 0
    while True:
        if Path(os.path.join(tensorboard_dir, str(i))).exists() == True:
            i += 1
        else:
            Path(os.path.join(tensorboard_dir, str(i))).mkdir(exist_ok=True,
                                                              parents=True)
            break
    summary = SummaryWriter(os.path.join(tensorboard_dir, str(i)))

    # Start training
    j = 0
    best_score = 0
    score = 0
    membership_loss = torch.tensor(0)
    transfer_loss = torch.tensor(0)
    for epoch in range(start_epoch, args.num_epochs):
        running_loss = 0
        running_membership_loss = 0
        running_transfer_loss = 0
        running_class_loss = 0
        train_acc = 0
        test_acc = 0
        stime = time.time()

        # for i, (train_data, OOD_data) in enumerate(zip(train_loader, OOD_loader)):
        for i, train_data in enumerate(train_loader):
            #### initialized
            org_image = train_data['input'] + 0.01 * torch.randn_like(
                train_data['input'])
            org_image = org_image.to(device)
            gt = train_data['label'].type(torch.FloatTensor).to(device)

            model = model.to(device).train()
            optimizer.zero_grad()

            #### forward path
            out1, out2 = model.pendis_forward(org_image)

            if args.membership:
                membership_loss = (
                    Membership_loss(out2, gt, args.num_classes) +
                    Membership_loss(out1, gt, args.num_classes))
                running_membership_loss += membership_loss.item()

            if args.transfer:
                extra_fc = extra_fc.to(device).train()

                OOD_image = (
                    OOD_data['input'] +
                    0.01 * torch.randn_like(OOD_data['input'])).to(device)
                OOD_gt = torch.cat(
                    (torch.zeros(args.batch_size, args.num_classes),
                     OOD_data['label'].type(torch.FloatTensor)),
                    dim=1).to(device)

                #### forward path
                _, feature = model.gen_forward(OOD_image)
                OOD_output = extra_fc(feature)
                transfer_loss = criterion(OOD_output, OOD_gt)
                running_transfer_loss += transfer_loss.item()

            #### calc loss
            class1_loss = criterion(out1, gt)
            class2_loss = criterion(out2, gt)
            class_loss = (class1_loss + class2_loss)

            total_loss = class_loss + membership_loss * 0.3 + transfer_loss

            #### calc accuracy
            train_acc += sum(
                torch.argmax(out1, dim=1) == torch.argmax(
                    gt, dim=1)).cpu().detach().item()
            train_acc += sum(
                torch.argmax(out2, dim=1) == torch.argmax(
                    gt, dim=1)).cpu().detach().item()

            total_loss.backward()
            optimizer.step()
            scheduler.step()

            running_class_loss += class_loss.item()
            running_loss += total_loss.item()

        with torch.no_grad():
            for i, test_data in enumerate(test_loader):
                org_image = test_data['input'].to(device)
                model = model.to(device).eval()
                gt = test_data['label'].type(torch.FloatTensor).to(device)

                #### forward path
                out1, out2 = model.pendis_forward(org_image)
                score_1 = nn.functional.softmax(out1, dim=1)
                score_2 = nn.functional.softmax(out2, dim=1)
                dist = torch.sum(torch.abs(score_1 - score_2), dim=1).reshape(
                    (org_image.shape[0], -1))
                if i == 0:
                    dists = dist
                    labels = torch.zeros((org_image.shape[0], ))
                else:
                    dists = torch.cat((dists, dist), dim=0)
                    labels = torch.cat(
                        (labels, torch.zeros((org_image.shape[0]))), dim=0)

                test_acc += sum(
                    torch.argmax(torch.sigmoid(out1), dim=1) == torch.argmax(
                        gt, dim=1)).cpu().detach().item()
                test_acc += sum(
                    torch.argmax(torch.sigmoid(out2), dim=1) == torch.argmax(
                        gt, dim=1)).cpu().detach().item()

            for i, out_org_data in enumerate(out_test_loader):
                out_org_image = out_org_data['input'].to(device)

                out1, out2 = model.pendis_forward(out_org_image)
                score_1 = nn.functional.softmax(out1, dim=1)
                score_2 = nn.functional.softmax(out2, dim=1)
                dist = torch.sum(torch.abs(score_1 - score_2), dim=1).reshape(
                    (out_org_image.shape[0], -1))

                dists = torch.cat((dists, dist), dim=0)
                labels = torch.cat((labels, torch.ones(
                    (out_org_image.shape[0]))),
                                   dim=0)

        roc = evaluate(labels.cpu(), dists.cpu(), metric='roc')
        print('Epoch{} AUROC: {:.3f}, test accuracy : {:.4f}'.format(
            epoch, roc, test_acc / test_dataset.num_image / 2))

        print(
            'Epoch [{}/{}], Step {}, total_loss = {:.4f}, class = {:.4f}, membership = {:.4f}, transfer = {:.4f}, exe time: {:.2f}, lr: {:.4f}*e-4'
            .format(epoch, args.num_epochs, i + 1, running_loss / batch_num,
                    running_class_loss / batch_num,
                    running_membership_loss / batch_num,
                    running_transfer_loss / batch_num,
                    time.time() - stime,
                    scheduler.get_last_lr()[0] * 10**4))
        print('exe time: {:.2f}, lr: {:.4f}*e-4'.format(
            time.time() - stime,
            scheduler.get_last_lr()[0] * 10**4))

        print("train accuracy total : {:.4f}".format(
            train_acc / train_dataset.num_image / 2))
        print("test accuracy total : {:.4f}".format(
            test_acc / test_dataset.num_image / 2))

        summary.add_scalar('loss/total_loss', running_loss / batch_num, epoch)
        summary.add_scalar('loss/class_loss', running_class_loss / batch_num,
                           epoch)
        summary.add_scalar('loss/membership_loss',
                           running_membership_loss / batch_num, epoch)
        summary.add_scalar('acc/train_acc',
                           train_acc / train_dataset.num_image / 2, epoch)
        summary.add_scalar('acc/test_acc',
                           test_acc / test_dataset.num_image / 2, epoch)
        summary.add_scalar("learning_rate/lr",
                           scheduler.get_last_lr()[0], epoch)
        time.sleep(0.001)
        torch.save(
            {
                'model': model.state_dict(),
                'epoch': epoch,
                'init_lr': scheduler.get_last_lr()[0]
            },
            os.path.join(save_model, env, args.net_type,
                         'checkpoint_last_pre.pth.tar'))
Beispiel #3
0
def data_config(image_dir,
                OOD_dir,
                num_classes,
                OOD_num_classes,
                batch_size,
                num_instances,
                soft_label,
                custom_sampler,
                not_test_ODIN,
                transfer,
                resize=(160, 160)):
    train_dataset = load_data.Dog_metric_dataloader(image_dir=image_dir,
                                                    num_class=num_classes,
                                                    mode="train",
                                                    resize=resize,
                                                    soft_label=soft_label)
    if custom_sampler:
        MySampler = load_data.customSampler(train_dataset, batch_size,
                                            num_instances)
        train_loader = DataLoader(train_dataset,
                                  batch_sampler=MySampler,
                                  num_workers=2)
    else:
        train_loader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=2)

    test_dataset = load_data.Dog_dataloader(image_dir=image_dir,
                                            num_class=num_classes,
                                            mode="test",
                                            resize=resize)
    test_loader = DataLoader(test_dataset,
                             batch_size=8,
                             shuffle=False,
                             num_workers=2)

    out_test_dataset, out_test_loader, OOD_dataset, OOD_loader = 0, 0, 0, 0
    ### novelty data
    if not_test_ODIN:
        out_test_dataset = load_data.Dog_dataloader(image_dir=image_dir,
                                                    num_class=num_classes,
                                                    mode="OOD",
                                                    resize=resize)
        out_test_loader = DataLoader(out_test_dataset,
                                     batch_size=8,
                                     shuffle=True,
                                     num_workers=2)

    ### perfectly OOD data
    if transfer:
        OOD_dataset = load_data.Dog_dataloader(image_dir=OOD_dir,
                                               num_class=OOD_num_classes,
                                               mode="OOD",
                                               resize=resize)
        OOD_loader = DataLoader(OOD_dataset,
                                batch_size=batch_size,
                                shuffle=True,
                                num_workers=2)

    return train_dataset, train_loader, test_dataset, test_loader, out_test_dataset, out_test_loader, OOD_dataset, OOD_loader
Beispiel #4
0
def main():
    start_epoch = 0

    save_model = "./pre_trained"
    tensorboard_dir = "./tensorboard/OOD"
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Hyper-parameters
    eps = 1e-8
    init_lr = 5e-4

    unorm = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

    ### data config
    test_data = load_data.Dog_dataloader(image_dir=image_dir,
                                         num_class=args.num_classes,
                                         mode="OOD")
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=1,
                                              shuffle=True,
                                              num_workers=2)

    ##### model, optimizer config
    if args.net_type == "resnet":
        model = models.resnet50(num_c=args.num_classes, pretrained=True)

    optimizer = optim.Adam(model.parameters(), lr=init_lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [30],
                                                     gamma=0.3)

    if args.resume == True:
        print("load checkpoint_last")
        checkpoint = torch.load(os.path.join(save_model, "resnet50.pth.tar"))

        ##### load model
        model.load_state_dict(checkpoint["model"])

    for i in range(args.num_classes):
        locals()["test_label{}".format(i)] = 0
    test_acc = 0
    MSP = torch.tensor([])
    with torch.no_grad():
        for i, (org_image, gt) in enumerate(test_loader):
            org_image = org_image.to(device)
            model = model.to(device).eval()
            gt = gt.type(torch.FloatTensor).to(device)
            #### forward path
            output = model(org_image)
            raw_image = unorm(org_image.squeeze(0)).cpu().detach()
            gt_label = torch.argmax(gt, dim=1).cpu().detach().tolist()
            output_label = torch.argmax(torch.sigmoid(output),
                                        dim=1).cpu().detach().tolist()
            for idx, label in enumerate(gt_label):
                if label == output_label[idx]:
                    locals()["test_label{}".format(label)] += 1
                MSP = torch.cat(
                    (MSP,
                     (torch.softmax(output, dim=1).max().cpu()).unsqueeze(0)),
                    dim=0)
                # print(torch.softmax(output, dim=1).max())
                # print("label : {}, predicted class : {}".format(label, output_label))
            test_acc += sum(
                torch.argmax(torch.sigmoid(output), dim=1) == torch.argmax(
                    gt, dim=1)).cpu().detach().item()

            # thismanager = get_current_fig_manager()
            # thismanager.window.SetPosition((500, 0))
            # plt.get_current_fig_manager().window.wm_geometry("+1000+100") # move the window
            # plt.imshow(raw_image.permute(1,2,0))
            # plt.show()
    thres_list = [0.501, 0.601, 0.701, 0.801, 0.901]
    print("total # of data : {}".format(test_data.num_image))
    for idx, thres in enumerate(thres_list):
        print(thres, end=" ")
        if idx == 0:
            print(torch.sum(MSP < thres))
        else:
            print(torch.sum(torch.mul((thres + 0.1) >= MSP, thres < MSP)))

    print("test accuracy total : {:.4f}".format(test_acc /
                                                test_data.num_image))
    for num in range(args.num_classes):
        print("label{} : {:.4f}".format(
            num,
            locals()["test_label{}".format(num)] / test_data.len_list[num]),
              end=" ")
    print("\n")
    time.sleep(0.001)
Beispiel #5
0
def main():
    start_epoch = 0

    if args.metric:
        save_model = "./save_model_" + args.dataset + "_metric"
        tensorboard_dir = "./tensorboard/OOD_" + args.dataset
    else:
        save_model = "./save_model_" + args.dataset
        tensorboard_dir = "./tensorboard/OOD_" + args.dataset

    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


    # Hyper-parameters
    eps = 1e-8


    ### data config
    train_dataset = load_data.Dog_metric_dataloader(image_dir = image_dir,
                                                    num_class = args.num_classes,
                                                    mode = "train",
                                                    soft_label=args.soft_label)
    MySampler = customSampler(train_dataset, args.batch_size, args.num_instances)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               # batch_size=args.batch_size,
                                               batch_sampler= MySampler,
                                               num_workers=2)

    test_dataset = load_data.Dog_dataloader(image_dir = image_dir,
                                         num_class = args.num_classes,
                                         mode = "test")
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=8,
                                              shuffle=False,
                                              num_workers=2)

    out_test_dataset = load_data.Dog_dataloader(image_dir = image_dir,
                                             num_class = args.num_classes,
                                             mode = "OOD")
    out_test_loader = torch.utils.data.DataLoader(out_test_dataset,
                                              batch_size=8,
                                              shuffle=True,
                                              num_workers=2)



    ##### model, optimizer config
    if args.net_type == "resnet50":
        model = models.resnet50(num_c=args.num_classes, pretrained=True)
    elif args.net_type == "resnet34":
        model = models.resnet34(num_c=args.num_classes, pretrained=True)


    optimizer = optim.SGD(model.parameters(), lr=args.init_lr, momentum=0.9, nesterov=args.nesterov)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           args.num_epochs * len(train_loader)//50,
                                                           eta_min=args.init_lr/10)


    if args.resume == True:
        print("load checkpoint_last")
        checkpoint = torch.load(os.path.join(save_model, "checkpoint_last.pth.tar"))

        ##### load model
        model.load_state_dict(checkpoint["model"])
        start_epoch = checkpoint["epoch"]
        optimizer = optim.SGD(model.parameters(), lr = checkpoint["init_lr"])

    #### loss config
    criterion = nn.BCEWithLogitsLoss()
    triplet = torch.nn.TripletMarginLoss(margin=0.5, p=2)

    #### create folder
    Path(os.path.join(save_model, env, args.net_type)).mkdir(exist_ok=True, parents=True)

    if args.board_clear == True:
        files = glob.glob(tensorboard_dir+"/*")
        for f in files:
            shutil.rmtree(f)
    i = 0
    while True:
        if Path(os.path.join(tensorboard_dir, str(i))).exists() == True:
            i += 1
        else:
            Path(os.path.join(tensorboard_dir, str(i))).mkdir(exist_ok=True, parents=True)
            break
    summary = SummaryWriter(os.path.join(tensorboard_dir, str(i)))


    # Start training
    j=0
    best_score=0
    score = 0
    triplet_loss = torch.tensor(0)
    membership_loss = torch.tensor(0)
    for epoch in range(start_epoch, args.num_epochs):
        for i in range(args.num_classes):
            locals()["train_label{}".format(i)] = 0
            locals()["test_label{}".format(i)] = 0
        total_loss = 0
        triplet_running_loss = 0
        membership_running_loss = 0
        class_running_loss = 0
        train_acc = 0
        test_acc = 0
        stime = time.time()

        for i, train_data in enumerate(train_loader):
            #### initialized
            org_image = train_data['input'] + 0.01 * torch.randn_like(train_data['input'])
            org_image = org_image.to(device)
            model = model.to(device).train()
            gt = train_data['label'].type(torch.FloatTensor).to(device)
            optimizer.zero_grad()

            #### forward path
            output, output_list = model.feature_list(org_image)


            if args.metric:

                target_layer = output_list[-1]
                negative_list = []
                for batch_idx in range(args.batch_size):
                    gt_arg = gt.argmax(dim=1)
                    negative = (gt_arg != gt_arg[batch_idx])
                    if batch_idx == 0:
                        negative_tensor = target_layer[np.random.choice(np.where(negative.cpu().numpy() == True)[0], 1)[0]]
                        positive_tensor = target_layer[np.random.choice(np.delete(
                            np.where(~negative.cpu().numpy() == True)[0],np.where(np.where(~negative.cpu().numpy() == True)[0] == batch_idx)),
                            1)[0]]
                        negative_tensor = torch.unsqueeze(negative_tensor, dim=0)
                        positive_tensor = torch.unsqueeze(positive_tensor, dim=0)
                    else:
                        tmp_negative_tensor = target_layer[np.random.choice(np.where(negative.cpu().numpy() == True)[0], 1)[0]]
                        negative_tensor = torch.cat((negative_tensor, torch.unsqueeze(tmp_negative_tensor, dim=0)), dim=0)

                        tmp_positive_tensor =  target_layer[np.random.choice(np.delete(
                            np.where(~negative.cpu().numpy() == True)[0],np.where(np.where(~negative.cpu().numpy() == True)[0] == batch_idx)),
                            1)[0]]
                        positive_tensor = torch.cat((positive_tensor, torch.unsqueeze(tmp_positive_tensor, dim=0)), dim=0)

                triplet_loss = 0.5 * triplet(target_layer, positive_tensor, negative_tensor)


            if args.membership:
                R_wrong = 0
                R_correct = 0
                gt_idx = torch.argmax(gt, dim=1)
                for batch_idx, which in enumerate(gt_idx):
                    for idx in range(args.num_classes):
                        output_sigmoid = torch.sigmoid(output)
                        if which == idx:
                            R_wrong += (1 - output_sigmoid[batch_idx][idx]) ** 2
                        else:
                            R_correct += output_sigmoid[batch_idx][idx] / (args.num_classes-1)
                membership_loss = (R_wrong + R_correct) / args.batch_size


            #### calc loss
            class_loss = criterion(output, gt)

            #### calc accuracy
            train_acc += sum(torch.argmax(torch.sigmoid(output), dim=1) == torch.argmax(gt, dim=1)).cpu().detach().item()

            gt_label = torch.argmax(gt, dim=1).cpu().detach().tolist()
            output_label = torch.argmax(torch.sigmoid(output), dim=1).cpu().detach().tolist()
            for idx, label in enumerate(gt_label):
                if label == output_label[idx]:
                    locals()["train_label{}".format(label)] += 1

            total_backward_loss = class_loss + triplet_loss + membership_loss
            total_backward_loss.backward()
            optimizer.step()
            scheduler.step()

            class_running_loss += class_loss.item()
            triplet_running_loss += triplet_loss.item()
            membership_running_loss += membership_loss.item()
            total_loss += total_backward_loss.item()


        with torch.no_grad():
            for i, test_data in enumerate(test_loader):
                org_image = test_data['input'].to(device)
                model = model.to(device).eval()
                gt = test_data['label'].type(torch.FloatTensor).to(device)

                #### forward path
                output = model(org_image)

                gt_label = torch.argmax(gt, dim=1).cpu().detach().tolist()
                output_label = torch.argmax(torch.sigmoid(output), dim=1).cpu().detach().tolist()
                for idx, label in enumerate(gt_label):
                    if label == output_label[idx]:
                        locals()["test_label{}".format(label)] += 1


                test_acc += sum(torch.argmax(torch.sigmoid(output), dim=1) == torch.argmax(gt, dim=1)).cpu().detach().item()





        print('Epoch [{}/{}], Step {}, class_loss = {:.4f}, membership_loss = {:.4f}, total_loss = {:.4f}, exe time: {:.2f}, lr: {:.4f}*e-4'
                  .format(epoch, args.num_epochs, i+1,
                          class_running_loss/len(train_loader),
                          membership_running_loss/len(train_loader),
                          total_loss/len(train_loader),
                          time.time() - stime,
                          scheduler.get_last_lr()[0] * 10 ** 4))

        print("train accuracy total : {:.4f}".format(train_acc/(len(MySampler)*args.batch_size)))
        # print("train accuracy total : {:.4f}".format(train_acc/train_dataset.num_image))
        for num in range(args.num_classes):
            print("label{} : {:.4f}"
                  .format(num, locals()["train_label{}".format(num)]/train_dataset.len_list[num])
                  , end=" ")
        print()
        print("test accuracy total : {:.4f}".format(test_acc/test_dataset.num_image))
        for num in range(args.num_classes):
            print("label{} : {:.4f}"
                  .format(num, locals()["test_label{}".format(num)]/test_dataset.len_list[num])
                  , end=" ")
        print("\n")

        if epoch % 10 == 9:
            best_TNR, best_AUROC = test_ODIN(model, test_loader, out_test_loader, args.net_type, args)
            summary.add_scalar('AD_acc/AUROC', best_AUROC, epoch)
            summary.add_scalar('AD_acc/TNR', best_TNR, epoch)


        summary.add_scalar('loss/loss', total_loss/len(train_loader), epoch)
        summary.add_scalar('loss/membership_loss', membership_running_loss/len(train_loader), epoch)
        summary.add_scalar('acc/train_acc', train_acc/train_dataset.num_image, epoch)
        summary.add_scalar('acc/test_acc', test_acc/test_dataset.num_image, epoch)
        summary.add_scalar("learning_rate/lr", scheduler.get_last_lr()[0], epoch)
        time.sleep(0.001)
        torch.save({
            'model': model.state_dict(),
            'epoch': epoch,
            'init_lr' : scheduler.get_last_lr()[0]
            }, os.path.join(save_model, env, args.net_type, 'checkpoint_last.pth.tar'))
        scheduler.step()
def main():
    start_epoch = 0

    save_model = "./save_model_dis/fine"
    pretrained_model_dir = "./save_model_dis/pre_training"
    tensorboard_dir = "./tensorboard/OOD_dis/fine/" + args.dataset
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


    # Hyper-parameters
    eps = 1e-8

    ### data config
    train_dataset = load_data.Dog_dataloader(image_dir = image_dir,
                                          num_class = args.num_classes,
                                          mode = "train")
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=2)
    test_dataset = load_data.Dog_dataloader(image_dir = image_dir,
                                         num_class = args.num_classes,
                                         mode = "test")
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.num_classes,
                                              shuffle=True,
                                              num_workers=2)
    out_train_dataset = load_data.Dog_dataloader(image_dir=image_dir,
                                             num_class=args.num_classes,
                                             mode="OOD_val")
    out_train_loader = torch.utils.data.DataLoader(out_train_dataset,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  num_workers=2)
    out_test_dataset = load_data.Dog_dataloader(image_dir=image_dir,
                                             num_class=args.num_classes,
                                             mode="OOD")
    out_test_loader = torch.utils.data.DataLoader(out_test_dataset,
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=2)




    ##### model, optimizer config
    if args.net_type == "resnet50":
        model = models.resnet50(num_c=args.num_classes, pretrained=True)
    elif args.net_type == "resnet34":
        model = models.resnet34(num_c=args.num_classes, pretrained=True)

    if args.load == True:
        print("loading model")
        checkpoint = torch.load(os.path.join(pretrained_model_dir, args.pretrained_model, "checkpoint_last_pre.pth.tar"))

        ##### load model
        model.load_state_dict(checkpoint["model"])

    optimizer = optim.SGD(model.parameters(), lr=args.init_lr, momentum=0.9, nesterov=args.nesterov)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           args.num_epochs * len(train_loader))

    #### loss config
    criterion = nn.BCEWithLogitsLoss()

    #### create folder
    Path(os.path.join(save_model, env, args.net_type)).mkdir(exist_ok=True, parents=True)

    if args.board_clear == True:
        files = glob.glob(tensorboard_dir+"/*")
        for f in files:
            shutil.rmtree(f)
    i = 0
    while True:
        if Path(os.path.join(tensorboard_dir, str(i))).exists() == True:
            i += 1
        else:
            Path(os.path.join(tensorboard_dir, str(i))).mkdir(exist_ok=True, parents=True)
            break
    summary = SummaryWriter(os.path.join(tensorboard_dir, str(i)))


    # Start training
    j=0
    best_score=0
    score = 0
    for epoch in range(start_epoch, args.num_epochs):
        total_class_loss = 0
        total_dis_loss = 0
        train_acc = 0
        test_acc = 0
        stime = time.time()



        model.eval().to(device)
        with torch.no_grad():
            for i, test_data in enumerate(test_loader):
                org_image = test_data['input'].to(device)
                gt = test_data['label'].type(torch.FloatTensor).to(device)

                out1, out2 = model.dis_forward(org_image)
                score_1 = nn.functional.softmax(out1, dim=1)
                score_2 = nn.functional.softmax(out2, dim=1)
                dist = torch.sum(torch.abs(score_1 - score_2), dim=1).reshape((org_image.shape[0], ))
                if i == 0:
                    dists = dist
                    labels = torch.zeros((org_image.shape[0],))
                else:
                    dists = torch.cat((dists, dist), dim=0)
                    labels = torch.cat((labels, torch.zeros((org_image.shape[0]))), dim=0)

                test_acc += sum(torch.argmax(torch.sigmoid(out1), dim=1) == torch.argmax(gt, dim=1)).cpu().detach().item()
                test_acc += sum(torch.argmax(torch.sigmoid(out2), dim=1) == torch.argmax(gt, dim=1)).cpu().detach().item()

            for i, out_org_data in enumerate(out_test_loader):
                out_org_image = out_org_data['input'].to(device)

                out1, out2 = model.dis_forward(out_org_image)
                score_1 = nn.functional.softmax(out1, dim=1)
                score_2 = nn.functional.softmax(out2, dim=1)
                dist = torch.sum(torch.abs(score_1 - score_2), dim=1).reshape((out_org_image.shape[0], -1))

                dists = torch.cat((dists, dist), dim=0)
                labels = torch.cat((labels, torch.ones((out_org_image.shape[0]))), dim=0)

        roc = evaluate(labels.cpu(), dists.cpu(), metric='roc')
        print('Epoch{} AUROC: {:.3f}, test accuracy : {:.4f}'.format(epoch, roc, test_acc/test_dataset.num_image/2))


        for i, (org_data, out_org_data) in enumerate(zip(train_loader, out_train_loader)):
            #### initialized
            org_image = org_data['input'].to(device)
            out_org_image = out_org_data['input'].to(device)
            model = model.to(device).train()
            gt = org_data['label'].type(torch.FloatTensor).to(device)
            optimizer.zero_grad()

            #### forward path
            out1, out2 = model.dis_forward(org_image)

            #### calc accuracy
            train_acc += sum(torch.argmax(out1, dim=1) == torch.argmax(gt, dim=1)).cpu().detach().item()
            train_acc += sum(torch.argmax(out2, dim=1) == torch.argmax(gt, dim=1)).cpu().detach().item()

            #### calc loss
            class1_loss = criterion(out1, gt)
            class2_loss = criterion(out2, gt)

            out1, out2 = model.dis_forward(out_org_image)
            dis_loss = DiscrepancyLoss(out1, out2, args.m)
            loss = class1_loss + class2_loss + dis_loss

            total_class_loss += class1_loss.item() + class2_loss.item()
            total_dis_loss += dis_loss.item()


            loss.backward()
            optimizer.step()
            scheduler.step()



        print('Epoch [{}/{}], Step {}, class_loss = {:.4f}, dis_loss = {:.4f}, exe time: {:.2f}, lr: {:.4f}*e-4'
                  .format(epoch, args.num_epochs, i+1,
                          total_class_loss/len(out_train_loader),
                          dis_loss/len(out_train_loader),
                          time.time() - stime,
                          scheduler.get_last_lr()[0] * 10 ** 4))




        summary.add_scalar('loss/class_loss', total_class_loss/len(train_loader), epoch)
        summary.add_scalar('loss/dis_loss', total_dis_loss/len(train_loader), epoch)
        summary.add_scalar('acc/roc', roc, epoch)
        summary.add_scalar("learning_rate/lr", scheduler.get_last_lr()[0], epoch)
        time.sleep(0.001)
        torch.save({
            'model': model.state_dict(),
            'epoch': epoch,
            'init_lr' : scheduler.get_last_lr()[0]
            }, os.path.join(save_model, env, args.net_type, 'checkpoint_last_fine.pth.tar'))
def main():
    output_dir = "./save_fig"

    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Hyper-parameters
    eps = 1e-8

    ### data config
    test_dataset = load_data.Dog_dataloader(image_dir=image_dir,
                                            num_class=args.num_classes,
                                            mode="test")
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=2)

    ### novelty data
    out_test_dataset = load_data.Dog_dataloader(image_dir=image_dir,
                                                num_class=args.num_classes,
                                                mode="OOD")
    out_test_loader = torch.utils.data.DataLoader(out_test_dataset,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=2)

    ##### model, optimizer config
    if args.net_type == "resnet50":
        model = models.resnet50(num_c=args.num_classes, pretrained=True)
    elif args.net_type == "resnet34":
        model = models.resnet34(num_c=args.num_classes,
                                num_cc=args.OOD_num_classes,
                                pretrained=True)
    elif args.net_type == "vgg19":
        model = models.vgg19(num_c=args.num_classes,
                             num_cc=args.OOD_num_classes,
                             pretrained=True)
    elif args.net_type == "vgg16":
        model = models.vgg16(num_c=args.num_classes,
                             num_cc=args.OOD_num_classes,
                             pretrained=True)
    elif args.net_type == "vgg19_bn":
        model = models.vgg19_bn(num_c=args.num_classes,
                                num_cc=args.OOD_num_classes,
                                pretrained=True)
    elif args.net_type == "vgg16_bn":
        model = models.vgg16_bn(num_c=args.num_classes,
                                num_cc=args.OOD_num_classes,
                                pretrained=True)

    print("load checkpoint_last")
    checkpoint = torch.load(args.model_path)

    ##### load model
    model.load_state_dict(checkpoint["model"])
    start_epoch = checkpoint["epoch"]
    optimizer = optim.SGD(model.parameters(), lr=checkpoint["init_lr"])

    #### create folder
    Path(output_dir).mkdir(exist_ok=True, parents=True)

    model = model.to(device).eval()
    # Start grad-CAM
    bp = BackPropagation(model=model)
    inv_normalize = transforms.Normalize(
        mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.255],
        std=[1 / 0.229, 1 / 0.224, 1 / 0.255])
    target_layer = "layer4"

    stime = time.time()

    gcam = GradCAM(model=model)

    grad_cam = GradCAMmodule(target_layer, output_dir)
    grad_cam.model_config(model)
    for j, test_data in enumerate(test_loader):
        #### initialized
        org_image = test_data['input'].to(device)
        target_class = test_data['label'].to(device)

        target_class = int(target_class.argmax().cpu().detach())
        result = model(org_image).argmax()
        print("number: {} pred: {} target: {}".format(j, result, target_class))
        result = int(result.cpu().detach())
        grad_cam.saveGradCAM(org_image, result, j)