def stage_valmixup(net, dataloader, device, name="stage1_mixup"):
    """Compare ``net``'s energy on loader batches against their mixup versions.

    Every batch from ``dataloader`` is passed through ``net`` twice: unchanged,
    and after ``mixup(inputs, targets, args)``. The ``"energy"`` outputs of both
    passes are concatenated, plotted as two histograms via ``plot_listhist``
    (plot name ``name + "_energy"``) and summarised on stdout.

    Relies on the module-level names ``args``, ``mixup``, ``progress_bar`` and
    ``plot_listhist``.

    Args:
        net: model whose forward pass returns a dict containing ``"energy"``.
        dataloader: iterable of ``(inputs, targets)`` batches supporting ``len()``.
        device: device the batches are moved to before the forward pass.
        name: prefix of the plot name.

    Returns:
        dict with ``"mid_known"`` (median energy of the loader data) and
        ``"mid_unknown"`` (median energy of the mixup data), both tensors.
    """
    print("validating mixup and trainloader ...")
    energy_loader_list = []
    energy_mixup_list = []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            mixed = mixup(inputs, targets, args)
            out_loader = net(inputs)
            out_mixed = net(mixed)
            energy_loader_list.append(out_loader["energy"])
            energy_mixup_list.append(out_mixed["energy"])
            # Progress is reported against the loader actually being iterated.
            progress_bar(batch_idx, len(dataloader))

    energy_loader_list = torch.cat(energy_loader_list, dim=0)
    energy_mixup_list = torch.cat(energy_mixup_list, dim=0)

    plot_listhist([energy_loader_list, energy_mixup_list],
                  args, labels=["loader", "mixup"],
                  name=name + "_energy")

    mid_known = energy_loader_list.median()
    mid_unknown = energy_mixup_list.median()
    print("_______________Validate statistics:____________")
    print(f"train mid:{mid_known} | mixup mid:{mid_unknown}")
    print(f"min  energy:{min(energy_loader_list.min(), energy_mixup_list.min())} "
          f"| max  energy:{max(energy_loader_list.max(), energy_mixup_list.max())}")
    return {
        "mid_known": mid_known,
        "mid_unknown": mid_unknown,
    }
Beispiel #2
0
def stage1_valvae(net, testloader, device):
    """Plot ``norm_fea`` histograms for test data, VAE samples and mixup data.

    Builds a ``VanillaVAE`` and, if ``args.vae_resume`` points to an existing
    file, loads its ``'net'`` state dict. For every test batch, one batch is
    drawn with ``sampler(vae, device, args)``, and ``net`` is run on both the
    test batch and the sampled batch.

    A second pass runs ``net`` on ``mixup`` versions of the batches from the
    module-level ``mixuploader``. NOTE(review): ``mixuploader`` is not a
    parameter; confirm it is defined wherever this function is called.

    The largest target label seen in ``testloader`` is treated as the unknown
    class. Test features are split into known/unknown by that label and all
    four groups are plotted together.

    Args:
        net: model whose forward pass returns a dict containing ``"norm_fea"``.
        testloader: iterable of ``(inputs, targets)`` test batches.
        device: device for the batches and the VAE (compared against ``'cuda'``).
    """
    print("validating vae and net ...")
    # loading vae model
    vae = VanillaVAE(in_channels=1, latent_dim=args.latent_dim)
    vae = vae.to(device)
    if device == 'cuda':
        vae = torch.nn.DataParallel(vae)
        cudnn.benchmark = True
    if os.path.isfile(args.vae_resume):
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only run.
        vae_checkpoint = torch.load(args.vae_resume, map_location=device)
        vae.load_state_dict(vae_checkpoint['net'])
        print('==> Resuming vae from checkpoint, loaded..')

    normfea_test_list = []
    normfea_sample_list = []
    normfea_mix_list = []
    target_list = []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            sampled = sampler(vae, device, args)
            out_test = net(inputs)
            out_sample = net(sampled)
            normfea_test_list.append(out_test["norm_fea"])
            normfea_sample_list.append(out_sample["norm_fea"])
            target_list.append(targets)
            # Progress is measured against the loader being iterated here.
            progress_bar(batch_idx, len(testloader))

        for batch_idx, (inputs, targets) in enumerate(mixuploader):
            inputs, targets = inputs.to(device), targets.to(device)
            mixed = mixup(inputs, targets, args)
            out_mixed = net(mixed)
            normfea_mix_list.append(out_mixed["norm_fea"])
            progress_bar(batch_idx, len(mixuploader))

    normfea_test_list = torch.cat(normfea_test_list, dim=0)
    normfea_sample_list = torch.cat(normfea_sample_list, dim=0)
    normfea_mix_list = torch.cat(normfea_mix_list, dim=0)
    target_list = torch.cat(target_list, dim=0)

    # The highest label present in the test set is the "unknown" class.
    unknown_label = target_list.max()
    normfea_test_unknown_list = normfea_test_list[target_list == unknown_label]
    normfea_test_known_list = normfea_test_list[target_list != unknown_label]

    plot_listhist([
        normfea_test_known_list, normfea_test_unknown_list,
        normfea_sample_list, normfea_mix_list
    ],
                  args,
                  labels=["test_known", "test_unknown", "sampled", "mixed"],
                  name="stage1_valvaemix_normfea_result")
Beispiel #3
0
def stage_valvae(net,
                 dataloader,
                 device,
                 name="stage1_valtrain&sample_normfea_result"):
    """Compare ``net``'s ``norm_fea`` on loader data against VAE-sampled data.

    Builds a ``VanillaVAE`` and, if ``args.vae_resume`` points to an existing
    file, loads its ``'net'`` state dict. For every batch of ``dataloader``,
    one batch is drawn with ``sampler(vae, device, args)``, and ``net`` is run
    on both. The collected ``"norm_fea"`` values are plotted as two histograms
    and summarised on stdout.

    Args:
        net: model whose forward pass returns a dict containing ``"norm_fea"``.
        dataloader: iterable of ``(inputs, targets)`` batches supporting ``len()``.
        device: device for the batches and the VAE (compared against ``'cuda'``).
        name: name of the plot.

    Returns:
        dict with the VAE under ``"vae"`` and the median ``norm_fea`` of the
        loader data (``"mid_known"``) and of the sampled data
        (``"mid_unknown"``).
    """
    print("validating vae and net ...")
    # loading vae model
    vae = VanillaVAE(in_channels=1, latent_dim=args.latent_dim)
    vae = vae.to(device)
    if device == 'cuda':
        vae = torch.nn.DataParallel(vae)
        cudnn.benchmark = True
    if os.path.isfile(args.vae_resume):
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only run.
        vae_checkpoint = torch.load(args.vae_resume, map_location=device)
        vae.load_state_dict(vae_checkpoint['net'])
        print('==> Resuming vae from checkpoint, loaded..')

    normfea_loader_list = []
    normfea_sample_list = []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            sampled = sampler(vae, device, args)
            out_test = net(inputs)
            out_sample = net(sampled)
            normfea_loader_list.append(out_test["norm_fea"])
            normfea_sample_list.append(out_sample["norm_fea"])
            # Progress is reported against the loader actually being iterated.
            progress_bar(batch_idx, len(dataloader))

    normfea_loader_list = torch.cat(normfea_loader_list, dim=0)
    normfea_sample_list = torch.cat(normfea_sample_list, dim=0)

    plot_listhist([normfea_loader_list, normfea_sample_list],
                  args,
                  labels=["train data", "sampled data"],
                  name=name)

    mid_known = normfea_loader_list.median()
    mid_unknown = normfea_sample_list.median()
    print(f"train mid:{mid_known} | sampled mid:{mid_unknown}")
    print(
        f"min  norm:{min(normfea_loader_list.min(), normfea_sample_list.min())} "
        f"| max  norm:{max(normfea_loader_list.max(), normfea_sample_list.max())}"
    )
    return {
        "vae": vae,
        "mid_known": mid_known,
        "mid_unknown": mid_unknown,
    }
def test_with_hist(net, dataloader, device, intervals=20, name="stage1_test"):
    """Evaluate ``net`` and search the energy threshold that maximises F1.

    Closed-set predictions are the argmax over dim 1 of
    ``out['normweight_fea2cen']``. The largest target label in ``dataloader``
    is treated as the unknown class when plotting known/unknown energy
    histograms (plot name ``name + "_energy"``).

    The threshold search runs over ``5 * intervals`` thresholds evenly spaced
    between the minimum and maximum observed energy. For each threshold,
    samples whose energy is strictly below it are relabelled
    ``args.train_class_num``, and an ``Evaluation`` is built from the
    predictions and targets.

    Args:
        net: model returning a dict with ``"energy"`` and
            ``"normweight_fea2cen"`` entries.
        dataloader: iterable of ``(inputs, targets)`` batches supporting ``len()``.
        device: device the batches are moved to.
        intervals: base number of thresholds (multiplied by 5).
        name: prefix of the plot name.

    Returns:
        dict with ``"best_F1"``, ``"best_thres"`` and ``"best_eval"``. If no
        threshold yields an F1 above 0, these stay ``0``, ``0`` and ``None``.
    """
    energy_list = []  # energy value
    Target_list = []
    Predict_list = []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            out = net(inputs)  # shape [batch,class]
            energy_list.append(out["energy"])
            Target_list.append(targets)
            _, predicted = (out['normweight_fea2cen']).max(1)
            Predict_list.append(predicted)
            progress_bar(batch_idx, len(dataloader), "|||")
    energy_list = torch.cat(energy_list, dim=0)
    Target_list = torch.cat(Target_list, dim=0)
    Predict_list = torch.cat(Predict_list, dim=0)
    unknown_label = Target_list.max()
    unknown_energy_list = energy_list[Target_list == unknown_label]
    known_energy_list = energy_list[Target_list != unknown_label]
    plot_listhist([known_energy_list, unknown_energy_list],
                  args,
                  labels=["known", "unknown"],
                  name=name + "_energy")

    best_F1 = 0
    best_thres = 0
    best_eval = None
    # for these unbounded metric, we explore more intervals by *5 to achieve a relatively fair comparison.
    expand_factor = 5
    openmetric_list = energy_list
    threshold_min = openmetric_list.min().item()
    threshold_max = openmetric_list.max().item()
    # Targets are the same for every threshold; convert them once.
    target_array = Target_list.cpu().numpy()
    for thres in np.linspace(threshold_min, threshold_max,
                             expand_factor * intervals):
        # Build a fresh prediction tensor for each threshold rather than
        # writing into Predict_list. On CPU, .numpy() shares memory with the
        # tensor, so in-place writes would silently change the predictions
        # held by Evaluation objects created earlier (including best_eval).
        thresholded = Predict_list.masked_fill(openmetric_list < thres,
                                               args.train_class_num)
        evaluation = Evaluation(thresholded.cpu().numpy(), target_array)
        if evaluation.f1_measure > best_F1:
            best_F1 = evaluation.f1_measure
            best_thres = thres
            best_eval = evaluation
    print(f"The energy range is [{threshold_min}, {threshold_max}] ")
    print(f"Best F1 is: {best_F1}  [in best threshold: {best_thres} ]")
    return {
        "best_F1": best_F1,
        "best_thres": best_thres,
        "best_eval": best_eval
    }
Beispiel #5
0
def stage_valmixup(net,
                   dataloader,
                   device,
                   name="stage1_valtrain&sample_normfea_result"):
    """Compare ``net``'s ``norm_fea`` on loader batches against their mixup versions.

    Every batch from ``dataloader`` is passed through ``net`` unchanged and
    after ``mixup(inputs, targets, args)``. The ``"norm_fea"`` outputs of both
    passes are plotted as two histograms and summarised on stdout.

    Args:
        net: model whose forward pass returns a dict containing ``"norm_fea"``.
        dataloader: iterable of ``(inputs, targets)`` batches supporting ``len()``.
        device: device the batches are moved to.
        name: name of the plot.

    Returns:
        dict with the median ``norm_fea`` of the loader data
        (``"mid_known"``) and of the mixup data (``"mid_unknown"``).
    """
    print("validating mixup and trainloader ...")
    normfea_loader_list = []
    normfea_mixup_list = []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            mixed = mixup(inputs, targets, args)
            out_loader = net(inputs)
            out_mixed = net(mixed)
            normfea_loader_list.append(out_loader["norm_fea"])
            normfea_mixup_list.append(out_mixed["norm_fea"])
            # Progress is reported against the loader actually being iterated.
            progress_bar(batch_idx, len(dataloader))

    normfea_loader_list = torch.cat(normfea_loader_list, dim=0)
    normfea_mixup_list = torch.cat(normfea_mixup_list, dim=0)

    # The second histogram holds mixup outputs, so label it as such.
    plot_listhist([normfea_loader_list, normfea_mixup_list],
                  args,
                  labels=["train data", "mixup data"],
                  name=name)

    mid_known = normfea_loader_list.median()
    mid_unknown = normfea_mixup_list.median()
    print("_______________Validate statistics:____________")
    print(f"train mid:{mid_known} | mixup mid:{mid_unknown}")
    print(
        f"min  norm:{min(normfea_loader_list.min(), normfea_mixup_list.min())} "
        f"| max  norm:{max(normfea_loader_list.max(), normfea_mixup_list.max())}"
    )
    return {
        "mid_known": mid_known,
        "mid_unknown": mid_unknown,
    }
Beispiel #6
0
def stage_test(net, testloader, device, name="stage1_test_normfea_doublebar"):
    """Report test accuracy and plot known/unknown histograms of three outputs.

    Accuracy compares the argmax over dim 1 of ``out["normweight_fea2cen"]``
    with the targets. Afterwards the largest target label in ``testloader`` is
    treated as the unknown class. ``norm_fea``, ``pnorm`` and ``energy`` are
    each split into known/unknown and plotted under ``name``,
    ``name + "_pnorm"`` and ``name + "_energy"``. Median/min/max of
    ``norm_fea`` are printed.

    Args:
        net: model returning a dict with ``"norm_fea"``, ``"pnorm"``,
            ``"energy"`` and ``"normweight_fea2cen"`` entries.
        testloader: iterable of ``(inputs, targets)`` batches supporting ``len()``.
        device: device the batches are moved to.
        name: base name of the plots.
    """
    collected = {"norm_fea": [], "pnorm": [], "energy": []}
    seen_targets = []
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for step, (inputs, targets) in enumerate(testloader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            out = net(inputs)
            for key, bucket in collected.items():
                bucket.append(out[key])
            seen_targets.append(targets)
            # Index 1 of the (values, indices) pair is the predicted class.
            predicted = out["normweight_fea2cen"].max(1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()
            accuracy = 100. * n_correct / n_seen
            progress_bar(step, len(testloader),
                         '| Acc: %.3f%% (%d/%d)' % (accuracy, n_correct, n_seen))
    print("\nTesting results is {:.2f}%".format(100. * n_correct / n_seen))

    stacked = {key: torch.cat(parts, dim=0) for key, parts in collected.items()}
    all_targets = torch.cat(seen_targets, dim=0)
    # The highest label present is the "unknown" class.
    unknown_label = all_targets.max()
    unknown_mask = all_targets == unknown_label
    known_mask = all_targets != unknown_label
    unknown = {key: values[unknown_mask] for key, values in stacked.items()}
    known = {key: values[known_mask] for key, values in stacked.items()}

    known_norm = known["norm_fea"]
    unknown_norm = unknown["norm_fea"]
    print("_______________Testing statistics:____________")
    print(f"test known mid:{known_norm.median()} | unknown mid:{unknown_norm.median()}")
    lowest = min(known_norm.min(), unknown_norm.min())
    highest = max(known_norm.max(), unknown_norm.max())
    print(f"min  norm:{lowest} "
          f"| max  norm:{highest}")

    for key, suffix in (("norm_fea", None), ("pnorm", "_pnorm"), ("energy", "_energy")):
        plot_name = name if suffix is None else name + suffix
        plot_listhist([known[key], unknown[key]],
                      args,
                      labels=["known", "unknown"],
                      name=plot_name)
def stage_valmixup(net, dataloader, device, name="stage1_mixup_doublebar"):
    """Compare several ``net`` outputs on loader batches vs. their mixup versions.

    Every batch from ``dataloader`` is passed through ``net`` unchanged and
    after ``mixup(inputs, targets, args)``. For both passes the ``"norm_fea"``,
    ``"pnorm"``, ``"energy"`` and ``"normweight_fea2cen"`` outputs are
    collected.

    Loader-vs-mixup histograms are plotted in this order:
    ``pnorm`` (``name + "_pnorm"``), ``norm_fea`` (``name + "_normfea"``),
    ``energy`` (``name + "_energy"``), and the maximum softmax probability
    over dim 1 of ``normweight_fea2cen`` (``name + "_softmax"``). Energy
    statistics are then printed.

    Args:
        net: model returning a dict containing the four outputs listed above.
        dataloader: iterable of ``(inputs, targets)`` batches supporting ``len()``.
        device: device the batches are moved to.
        name: prefix of the plot names.

    Returns:
        dict with the median energy of the loader data (``"mid_known"``) and
        of the mixup data (``"mid_unknown"``).
    """
    print("validating mixup and trainloader ...")
    output_keys = ("norm_fea", "pnorm", "energy", "normweight_fea2cen")
    loader_parts = {key: [] for key in output_keys}
    mixup_parts = {key: [] for key in output_keys}
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            mixed = mixup(inputs, targets, args)
            out_loader = net(inputs)
            out_mixed = net(mixed)
            for key in output_keys:
                loader_parts[key].append(out_loader[key])
                mixup_parts[key].append(out_mixed[key])
            # Progress is reported against the loader actually being iterated.
            progress_bar(batch_idx, len(dataloader))

    loader_out = {key: torch.cat(parts, dim=0) for key, parts in loader_parts.items()}
    mixup_out = {key: torch.cat(parts, dim=0) for key, parts in mixup_parts.items()}

    # Highest softmax probability per sample, taken over dim 1.
    softmax_loader_list = torch.softmax(loader_out["normweight_fea2cen"],
                                        dim=1).max(dim=1, keepdim=False)[0]
    softmax_mixup_list = torch.softmax(mixup_out["normweight_fea2cen"],
                                       dim=1).max(dim=1, keepdim=False)[0]

    histograms = (
        (loader_out["pnorm"], mixup_out["pnorm"], "_pnorm"),
        (loader_out["norm_fea"], mixup_out["norm_fea"], "_normfea"),
        (loader_out["energy"], mixup_out["energy"], "_energy"),
        (softmax_loader_list, softmax_mixup_list, "_softmax"),
    )
    for loader_values, mixup_values, suffix in histograms:
        plot_listhist([loader_values, mixup_values],
                      args,
                      labels=["loader", "mixup"],
                      name=name + suffix)

    energy_loader_list = loader_out["energy"]
    energy_mixup_list = mixup_out["energy"]
    mid_known = energy_loader_list.median()
    mid_unknown = energy_mixup_list.median()
    print("_______________Validate statistics:____________")
    print(f"train mid:{mid_known} | mixup mid:{mid_unknown}")
    print(
        f"min  energy:{min(energy_loader_list.min(), energy_mixup_list.min())} "
        f"| max  energy:{max(energy_loader_list.max(), energy_mixup_list.max())}"
    )
    return {
        "mid_known": mid_known,
        "mid_unknown": mid_unknown,
    }