def middle_validate(net, trainloader, device, name=""):
    """Compare ``net``'s "pnorm" outputs on real data against VAE samples.

    Builds a ``VanillaVAE`` (``in_channels=1``, ``latent_dim=args.latent_dim``)
    and restores its weights from ``args.vae_resume`` when that file exists.
    For every batch of ``trainloader``:
    - ``net`` is run on the real inputs ("known");
    - ``net`` is also run on a batch produced by ``sampler(vae, device, args)``
      ("unknown").
    The collected "pnorm" values are passed to ``energy_hist_sperate`` and
    summarised by their medians.

    Args:
        net: model whose forward returns a dict containing a "pnorm" tensor.
        trainloader: iterable of ``(inputs, targets)`` batches. It must yield
            at least one batch, because ``torch.cat`` rejects an empty list.
        device: target device. The string ``'cuda'`` additionally wraps the
            VAE in ``DataParallel`` and enables cudnn benchmarking.
        name: label forwarded to ``energy_hist_sperate``.

    Returns:
        dict with keys:
        - "vae": the constructed VAE (in eval mode);
        - "mid_known": median "pnorm" over the real batches;
        - "mid_unknown": median "pnorm" over the sampled batches.
    """
    print("validating vae and net ...")
    known_energy, unknown_energy = [], []

    # Build the VAE used to synthesise "unknown" inputs.
    vae = VanillaVAE(in_channels=1, latent_dim=args.latent_dim)
    vae = vae.to(device)
    if device == 'cuda':
        vae = torch.nn.DataParallel(vae)
        cudnn.benchmark = True
    if os.path.isfile(args.vae_resume):
        # map_location lets a checkpoint written on a GPU be restored on a
        # CPU-only run (without it torch.load fails when CUDA is unavailable).
        vae_checkpoint = torch.load(args.vae_resume, map_location=device)
        vae.load_state_dict(vae_checkpoint['net'])
        print('==> Resuming vae from checkpoint, loaded..')
    else:
        # Previously this fell through silently and sampled from random weights.
        print(f'==> No vae checkpoint found at {args.vae_resume}; '
              f'sampling from an untrained vae')
    # The VAE is only used for sampling here. Eval mode stops layers such as
    # BatchNorm/Dropout (if VanillaVAE contains any) from using batch
    # statistics or updating running statistics.
    vae.eval()

    with torch.no_grad():
        for batch_idx, (inputs, _) in enumerate(trainloader):
            inputs = inputs.to(device)
            sampled = sampler(vae, device, args)
            out_known = net(inputs)
            out_unknown = net(sampled)
            known_energy.append(out_known["pnorm"])
            unknown_energy.append(out_unknown["pnorm"])
            progress_bar(batch_idx, len(trainloader))

    known_energy = torch.cat(known_energy, dim=0)
    unknown_energy = torch.cat(unknown_energy, dim=0)
    energy_hist_sperate(known_energy, unknown_energy, args, name)
    return {
        # Original author's expectation: unknown is smaller than known.
        "vae": vae,
        "mid_known": known_energy.median().data,
        "mid_unknown": unknown_energy.median().data,
    }
def mixup_validate(net, trainloader, mixuploader, device, stage="1"):
    """Compare ``net``'s "energy" outputs on real data against mixup batches.

    Iterates ``trainloader`` and ``mixuploader`` in lockstep. Each real batch
    is combined with its partner by
    ``mixup(inputs, targets, inputs_bak, targets_bak, args)``. ``net`` is run
    on both the real ("known") and the mixed ("unknown") batch, and the
    collected "energy" values are passed to ``energy_hist_sperate`` under the
    label ``"mixup_stage" + stage``.

    Args:
        net: model whose forward returns a dict containing an "energy" tensor.
        trainloader: iterable of ``(inputs, targets)`` batches.
        mixuploader: second iterable of ``(inputs, targets)`` batches, paired
            with ``trainloader``. ``zip`` stops at the shorter of the two.
            The progress bar total still uses ``len(trainloader)``.
        device: device the batches are moved to.
        stage: suffix used to build the histogram label.

    Returns:
        dict with keys:
        - "mid_known": median "energy" over the real batches;
        - "mid_unknown": median "energy" over the mixed batches.

    At least one batch pair must be produced, because ``torch.cat`` rejects
    an empty list.
    """
    print("validating mixup ...")
    known_energy, unknown_energy = [], []
    with torch.no_grad():
        batches = zip(trainloader, mixuploader)
        for batch_idx, ((inputs, targets), (inputs_bak, targets_bak)) in enumerate(batches):
            inputs, targets = inputs.to(device), targets.to(device)
            inputs_bak, targets_bak = inputs_bak.to(device), targets_bak.to(device)
            mixed = mixup(inputs, targets, inputs_bak, targets_bak, args)
            out_known = net(inputs)
            out_unknown = net(mixed)
            known_energy.append(out_known["energy"])
            unknown_energy.append(out_unknown["energy"])
            progress_bar(batch_idx, len(trainloader))

    known_energy = torch.cat(known_energy, dim=0)
    unknown_energy = torch.cat(unknown_energy, dim=0)
    energy_hist_sperate(known_energy, unknown_energy, args, "mixup_stage" + stage)
    return {
        # Original author's expectation: unknown is smaller than known.
        "mid_known": known_energy.median().data,
        "mid_unknown": unknown_energy.median().data,
    }