예제 #1
0
def middle_validate(net, trainloader, device, name=""):
    """Compare ``net``'s ``"pnorm"`` outputs on real batches vs. VAE samples.

    Builds a ``VanillaVAE`` (``in_channels=1``, ``latent_dim=args.latent_dim``)
    and restores its weights from ``args.vae_resume`` when that file exists.
    Then, for every batch of ``trainloader``, it runs ``net`` on the real
    inputs and on a batch drawn by ``sampler(vae, device, args)``. The
    collected ``"pnorm"`` values are passed to ``energy_hist_sperate``
    together with ``name``.

    Args:
        net: model whose output is indexable by ``"pnorm"``.
        trainloader: iterable of ``(inputs, targets)`` batches; must also
            support ``len`` (used for the progress bar).
        device: target passed to ``.to``. The exact string ``'cuda'`` also
            wraps the VAE in ``DataParallel`` and sets ``cudnn.benchmark``.
        name: label forwarded to ``energy_hist_sperate``.

    Returns:
        dict with the VAE under ``"vae"`` and the medians of the
        concatenated known / unknown ``"pnorm"`` values under
        ``"mid_known"`` / ``"mid_unknown"``.
    """
    print("validating vae and net ...")
    known_energy, unknown_energy = [], []

    # Build the VAE used to generate the "unknown" samples.
    vae = VanillaVAE(in_channels=1, latent_dim=args.latent_dim)
    vae = vae.to(device)
    if device == 'cuda':
        vae = torch.nn.DataParallel(vae)
        cudnn.benchmark = True
    if os.path.isfile(args.vae_resume):
        # map_location lets a checkpoint saved on one device be restored on
        # another (e.g. GPU-saved weights on a CPU-only run).
        # NOTE(review): torch.load unpickles the file; only point
        # args.vae_resume at trusted checkpoints.
        vae_checkpoint = torch.load(args.vae_resume, map_location=device)
        vae.load_state_dict(vae_checkpoint['net'])
        print('==> Resuming vae from checkpoint, loaded..')
    else:
        print('==> No vae checkpoint found at %s; sampling from an unloaded vae'
              % args.vae_resume)
    # The VAE is only sampled from here. Inference mode keeps layers such as
    # BatchNorm/Dropout (if VanillaVAE has any) from using batch statistics
    # or updating running buffers; no_grad alone does not prevent that.
    vae.eval()

    with torch.no_grad():
        for batch_idx, (inputs, _targets) in enumerate(trainloader):
            # Labels are not needed for this comparison, so only inputs move.
            inputs = inputs.to(device)
            sampled = sampler(vae, device, args)
            out_known = net(inputs)
            out_unknown = net(sampled)
            known_energy.append(out_known["pnorm"])
            unknown_energy.append(out_unknown["pnorm"])
            progress_bar(batch_idx, len(trainloader))

    known_energy = torch.cat(known_energy, dim=0)
    unknown_energy = torch.cat(unknown_energy, dim=0)
    energy_hist_sperate(known_energy, unknown_energy, args, name)
    return {
        # Author's note: unknown values are expected to be smaller than known.
        "vae": vae,
        "mid_known": known_energy.median().data,
        "mid_unknown": unknown_energy.median().data
    }
예제 #2
0
def mixup_validate(net, trainloader, mixuploader, device, stage="1"):
    """Compare ``net``'s ``"energy"`` outputs on real batches vs. mixup batches.

    Walks ``trainloader`` and ``mixuploader`` in lockstep. Iteration stops
    at the shorter of the two because ``zip`` is used. For each pair it runs
    ``net`` on the real inputs and on
    ``mixup(inputs, targets, inputs_bak, targets_bak, args)``, and collects
    the ``"energy"`` values. The results are passed to
    ``energy_hist_sperate`` under the label ``"mixup_stage<stage>"``.

    Args:
        net: model whose output is indexable by ``"energy"``.
        trainloader: iterable of ``(inputs, targets)`` batches; must also
            support ``len`` (used for the progress bar).
        mixuploader: second iterable of ``(inputs, targets)`` batches that is
            mixed with ``trainloader``'s batches.
        device: target passed to ``.to`` for all four tensors.
        stage: suffix of the label; any value is accepted and converted with
            ``str`` (default ``"1"``).

    Returns:
        dict with the medians of the concatenated known / unknown
        ``"energy"`` values under ``"mid_known"`` / ``"mid_unknown"``.
    """
    print("validating mixup ...")
    known_energy, unknown_energy = [], []
    with torch.no_grad():
        paired_batches = zip(trainloader, mixuploader)
        for batch_idx, ((inputs, targets), (inputs_bak, targets_bak)) in enumerate(
                paired_batches):
            inputs, targets = inputs.to(device), targets.to(device)
            inputs_bak = inputs_bak.to(device)
            targets_bak = targets_bak.to(device)
            mixed = mixup(inputs, targets, inputs_bak, targets_bak, args)
            out_known = net(inputs)
            out_unknown = net(mixed)
            known_energy.append(out_known["energy"])
            unknown_energy.append(out_unknown["energy"])
            progress_bar(batch_idx, len(trainloader))

    known_energy = torch.cat(known_energy, dim=0)
    unknown_energy = torch.cat(unknown_energy, dim=0)
    # str() lets integer stages (e.g. stage=2) work as well as the default "1".
    energy_hist_sperate(known_energy, unknown_energy, args,
                        "mixup_stage" + str(stage))
    return {
        # Author's note: unknown values are expected to be smaller than known.
        "mid_known": known_energy.median().data,
        "mid_unknown": unknown_energy.median().data
    }