def train_classifier(networks, optimizers, dataloader, epoch=None, **options):
    """Train ``networks['classifier_kplusone']`` for one epoch as a K+1-way classifier.

    Each step combines two losses on the same optimizer step:

    * real batch: the K classifier logits are padded with a constant zero
      logit (the extra "open set" class) and trained with NLL against the
      ground-truth class index;
    * auxiliary batch (from ``options['aux_dataset']``): the log-probability
      of the extra class is maximised, i.e. aux images are pushed into the
      open-set class.

    Args:
        networks: dict of modules; every module is put in train mode and
            ``'classifier_kplusone'`` is the one trained here.
        optimizers: dict of optimizers; ``'classifier_kplusone'`` is stepped.
        dataloader: iterable of ``(images, class_labels)`` where
            ``class_labels`` are class indices (they are fed to ``nll_loss``
            directly).
        epoch: unused; kept for interface compatibility.
        **options: must contain ``batch_size``, ``image_size`` and
            ``aux_dataset`` (path to an existing dataset file).

    Returns:
        dict with ``'errC'`` and ``'errOpenSet'``: the two losses of the
        last batch, as Python floats.

    Raises:
        ValueError: if the aux dataset is missing, or if ``dataloader``
            yields no batches.
    """
    for net in networks.values():
        net.train()
    netC = networks['classifier_kplusone']
    optimizerC = optimizers['classifier_kplusone']
    batch_size = options['batch_size']
    image_size = options['image_size']

    dataset_filename = options.get('aux_dataset')
    if not dataset_filename or not os.path.exists(dataset_filename):
        raise ValueError("Aux Dataset not available")
    print("Using aux_dataset {}".format(dataset_filename))
    aux_dataloader = FlexibleCustomDataloader(
        dataset_filename, batch_size=batch_size, image_size=image_size)

    errC = None
    errOpenSet = None
    for images, class_labels in dataloader:
        images = Variable(images).cuda()
        labels = Variable(class_labels).cuda()

        ############################
        # Classifier Update
        ############################
        netC.zero_grad()

        # Classify real examples into the correct K classes; the zero logit
        # appended by F.pad is the (K+1)-th "open set" class.
        augmented_logits = F.pad(netC(images), (0, 1))
        errC = F.nll_loss(F.log_softmax(augmented_logits, dim=1), labels)
        errC.backward()

        # Classify aux_dataset examples as open set (last column). The aux
        # batch is moved to the GPU exactly like the real batch above.
        aux_images, _ = aux_dataloader.get_batch()
        aux_logits = F.pad(netC(Variable(aux_images).cuda()), (0, 1))
        errOpenSet = -F.log_softmax(aux_logits, dim=1)[:, -1].mean()
        errOpenSet.backward()

        optimizerC.step()

    # Without at least one batch there is no loss to report.
    if errC is None:
        raise ValueError("dataloader yielded no batches; nothing to train on")

    return {
        'errC': errC.item(),
        'errOpenSet': errOpenSet.item(),
    }
# Script entry point (fragment). `parser` is an argparse parser created
# earlier in this script, outside this view.
# NOTE(review): options['mode'] is not read anywhere in this fragment --
# presumably consumed by build_networks/train code via **options; confirm.
parser.add_argument('--mode', default='', help='If set to "baseline" use the baseline classifier')
options = vars(parser.parse_args())

# Put the grandparent directory of this file on sys.path so the project-local
# modules imported below can be resolved.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from dataloader import CustomDataloader, FlexibleCustomDataloader
from training import train_classifier
from networks import build_networks, save_networks, get_optimizers
from options import load_options, get_current_epoch
from comparison import evaluate_with_comparison
from evaluation import save_evaluation

# Presumably merges the CLI options with options saved for this run -- confirm
# against options.load_options.
options = load_options(options)

# Training data, networks sized to the number of classes it reports, and
# their optimizers (finetune=True is forwarded to get_optimizers).
dataloader = FlexibleCustomDataloader(fold='train', **options)
networks = build_networks(dataloader.num_classes, **options)
optimizers = get_optimizers(networks, finetune=True, **options)

# Deterministic (unshuffled) test loader that keeps the final partial batch.
eval_dataloader = CustomDataloader(last_batch=True, shuffle=False, fold='test', **options)

# Continue epoch numbering after the epoch reported by get_current_epoch for
# result_dir, then run options['epochs'] more epochs.
start_epoch = get_current_epoch(options['result_dir']) + 1
for epoch in range(start_epoch, start_epoch + options['epochs']):
    train_classifier(networks, optimizers, dataloader, epoch=epoch, **options)
    #print(networks['classifier_kplusone'])
    #weights = networks['classifier_kplusone'].fc1.weight
    eval_results = evaluate_with_comparison(networks, eval_dataloader, **options)
    # NOTE(review): eval_results is never used in this fragment, and
    # save_networks / save_evaluation are imported but never called -- confirm
    # whether networks and evaluation results should be persisted each epoch.
def train_classifier(networks, optimizers, dataloader, epoch=None, **options):
    """Train ``networks['discriminator']`` for one epoch as a K+1-way classifier.

    The K discriminator logits are padded with a constant zero logit that
    acts as an extra "fake / open set" class. The loss has three parts:

    * real batch: log-likelihood of the entries where ``labels == 1``
      (weight 0.5, averaged over all K+1 columns);
    * aux batch, negative rows (no entry equal to 1): log-probability of
      the extra class (weight 1);
    * aux batch, positive rows: log-probability of the entries equal to 1
      (weight 0.5).

    The aux batch comes from ``<result_dir>/aux_dataset.dataset``.

    Args:
        networks: dict of modules; all are put in train mode and
            ``'discriminator'`` is trained.
        optimizers: dict of optimizers; ``'discriminator'`` is stepped.
        dataloader: iterable of ``(images, class_labels)``; ``class_labels``
            is a per-class matrix reduced with ``max(1)`` for accuracy.
        epoch: only used in the progress message.
        **options: must contain ``result_dir``, ``batch_size`` and
            ``image_size``.

    Returns:
        True.
    """
    for net in networks.values():
        net.train()
    netD = networks['discriminator']
    optimizerD = optimizers['discriminator']
    result_dir = options['result_dir']
    batch_size = options['batch_size']
    image_size = options['image_size']

    # Hack: use a ground-truth dataset to test
    #dataset_filename = '/mnt/data/svhn-59.dataset'
    dataset_filename = os.path.join(result_dir, 'aux_dataset.dataset')
    aux_dataloader = FlexibleCustomDataloader(
        dataset_filename, batch_size=batch_size, image_size=image_size)

    start_time = time.time()
    correct = 0
    total = 0
    for i, (images, class_labels) in enumerate(dataloader):
        images = Variable(images)
        labels = Variable(class_labels)

        ############################
        # Discriminator Updates
        ###########################
        netD.zero_grad()

        # Classify real examples into the correct K classes
        real_logits = netD(images)
        positive_labels = (labels == 1).type(torch.cuda.FloatTensor)
        augmented_logits = F.pad(real_logits, pad=(0, 1))
        augmented_labels = F.pad(positive_labels, pad=(0, 1))
        log_likelihood = F.log_softmax(augmented_logits, dim=1) * augmented_labels
        errC = -0.5 * log_likelihood.mean()

        # Classify the user-labeled (active learning) examples
        aux_images, aux_labels = aux_dataloader.get_batch()
        aux_images = Variable(aux_images)
        aux_labels = Variable(aux_labels)
        aux_logits = netD(aux_images)
        augmented_logits = F.pad(aux_logits, pad=(0, 1))
        augmented_labels = F.pad(aux_labels, pad=(0, 1))
        augmented_positive_labels = (augmented_labels == 1).type(torch.FloatTensor).cuda()
        # A row is "positive" if any of its label entries equals 1.
        is_positive = (aux_labels.max(dim=1)[0] == 1).type(torch.FloatTensor).cuda()
        is_negative = 1 - is_positive
        aux_log_probs = F.log_softmax(augmented_logits, dim=1)
        fake_log_likelihood = aux_log_probs[:, -1] * is_negative
        real_log_likelihood = (aux_log_probs * augmented_positive_labels).sum(dim=1)
        errC = errC - fake_log_likelihood.mean()
        errC = errC - 0.5 * real_log_likelihood.mean()

        errC.backward()
        optimizerD.step()
        ############################

        # Keep track of accuracy on positive-labeled examples for monitoring.
        # .item() is required: indexing a 0-dim tensor/array with [0] fails.
        _, pred_idx = real_logits.max(1)
        _, label_idx = labels.max(1)
        correct += (pred_idx == label_idx).sum().item()
        total += len(labels)

        if i % 100 == 0:
            bps = (i + 1) / (time.time() - start_time)
            ed = 0  # no discriminator loss is tracked in this function
            eg = 0  # no generator loss is tracked in this function
            ec = errC.item()
            acc = correct / max(total, 1)
            msg = '[{}][{}/{}] D:{:.3f} G:{:.3f} C:{:.3f} Acc. {:.3f} {:.3f} batch/sec'
            msg = msg.format(
                epoch, i + 1, len(dataloader), ed, eg, ec, acc, bps)
            print(msg)
    print("Accuracy {}/{}".format(correct, total))
    return True
def train_classifier(networks, optimizers, dataloader, epoch=None, **options):
    """Train ``networks['classifier_kplusone']`` for one epoch with the power loss.

    Each step applies ``loss_class.power_loss_05`` (a project loss from the
    ``losses`` module) twice:

    * real batch: K classifier logits vs. the ground-truth class index;
    * auxiliary batch (``options['aux_dataset']``): logits padded with a zero
      (K+1)-th column, every sample targeting that extra class index K.

    An alternative ``loss_class.kliep_loss`` and a plain K+1 NLL loss were
    tried previously; only the power loss is active.

    Args:
        networks: dict of modules; all are put in train mode and
            ``'classifier_kplusone'`` is trained.
        optimizers: dict of optimizers; ``'classifier_kplusone'`` is stepped.
        dataloader: iterable of ``(images, class_labels)``; ``class_labels``
            is a per-class matrix reduced with ``max(dim=1)`` to indices.
        epoch: unused; kept for interface compatibility.
        **options: must contain ``batch_size``, ``image_size`` and
            ``aux_dataset`` (path to an existing dataset file).

    Returns:
        True.

    Raises:
        ValueError: if the aux dataset is missing.
    """
    for net in networks.values():
        net.train()
    netC = networks['classifier_kplusone']
    optimizerC = optimizers['classifier_kplusone']
    batch_size = options['batch_size']
    image_size = options['image_size']

    dataset_filename = options.get('aux_dataset')
    if not dataset_filename or not os.path.exists(dataset_filename):
        raise ValueError("Aux Dataset not available")
    print("Using aux_dataset {}".format(dataset_filename))
    aux_dataloader = FlexibleCustomDataloader(
        dataset_filename, batch_size=batch_size, image_size=image_size)

    loss_class = losses.losses()
    for images, class_labels in dataloader:
        images = Variable(images)
        # NOTE: for MNIST the inputs were previously padded by 2 pixels here
        # (T.Pad(2)); re-add that step when training on MNIST.
        labels = Variable(class_labels)

        ############################
        # Classifier Update
        ############################
        netC.zero_grad()

        # Classify real examples into the correct K classes.
        classifier_logits = netC(images)
        _, labels_idx = labels.max(dim=1)
        errC = loss_class.power_loss_05(classifier_logits, labels_idx)
        errC.backward()
        log.collect('Classifier Loss', errC)

        # Classify aux_dataset examples as open set: every sample targets the
        # appended column, whose index equals the number of known classes K.
        aux_images, _ = aux_dataloader.get_batch()
        aux_logits = netC(Variable(aux_images))
        augmented_logits = F.pad(aux_logits, (0, 1))
        target_label = torch.full(
            (aux_logits.shape[0],), aux_logits.shape[1],
            dtype=torch.long, device=aux_logits.device)
        densityratio_loss = loss_class.power_loss_05(augmented_logits, target_label)
        densityratio_loss.backward()
        log.collect('Open Set Loss', densityratio_loss)

        optimizerC.step()
        ############################

        # Keep track of accuracy for monitoring only; no autograd graph is
        # needed for this extra forward pass.
        with torch.no_grad():
            monitor_logits = netC(images)
        log.collect_prediction('Classifier Accuracy', monitor_logits, labels)
        log.print_every()

    return True
def main():
    """Train ``classifier32`` on one fold of the closed-set split, then evaluate it.

    The ``a`` dataset of the fold is the closed set and the ``b`` dataset is
    the open set. After training, the latest checkpoint is reloaded and
    closed-set accuracy (``evalute_classifier`` on the ``a`` test split) and
    open-set AUROC (``evaluate_openset`` on ``a`` vs. ``b`` test splits) are
    printed and appended to ``<runs>/<DATASET><fold>.txt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=64, help="batch size")
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
    parser.add_argument('--fold', '-f', type=int, default=0, help='which fold you gonna train with')
    args = parser.parse_args()

    DATASET = 'tiny_imagenet-known-20-split'
    MODEL = 'classifier32'
    fold_num = args.fold
    batch_size = args.batch_size
    num_epochs = 30
    is_train = True
    is_write = True

    start_time = datetime.datetime.now().strftime('%Y-%m-%d_%I-%M-%S-%p')
    runs = 'runs/{}-{}{}-{}'.format(MODEL, DATASET, fold_num, start_time)
    if is_write:
        writer = SummaryWriter(runs)

    closed_dataset = './data/{}{}a.dataset'.format(DATASET, fold_num)
    open_dataset = './data/{}{}b.dataset'.format(DATASET, fold_num)
    closed_trainloader = FlexibleCustomDataloader(
        fold='train', batch_size=batch_size, dataset=closed_dataset)
    closed_testloader = FlexibleCustomDataloader(
        fold='test', batch_size=batch_size, dataset=closed_dataset)
    open_trainloader = FlexibleCustomDataloader(
        fold='train', batch_size=batch_size, dataset=open_dataset)
    open_testloader = FlexibleCustomDataloader(
        fold='test', batch_size=batch_size, dataset=open_dataset)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    PATH = '{}/{}{}_custom_network_15'.format(runs, DATASET, fold_num)
    if is_train:
        net = classifier32()
        net.to(device)
        net.train()
        criterion = nn.CrossEntropyLoss()
        # --lr is honoured; its default (1e-4) is the rate used historically.
        optimizer = optim.Adam(net.parameters(), lr=args.lr)

        running_loss = 0.0
        for epoch in range(num_epochs):
            for i, (images, labels) in enumerate(closed_trainloader, 0):
                images = Variable(images).cuda()
                labels = Variable(labels)

                optimizer.zero_grad()
                outputs = net(images)
                # Labels are per-class vectors (presumably one-hot);
                # CrossEntropyLoss needs class indices.
                labels = torch.argmax(labels, dim=1)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                if i % 100 == 99:
                    avg_loss = running_loss / 100
                    if is_write:
                        writer.add_scalar('training loss', avg_loss,
                                          epoch * len(closed_trainloader) + i)
                    current_time = datetime.datetime.now().strftime(
                        '%Y-%m-%d_%I-%M-%S-%p')
                    print(current_time)
                    print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, avg_loss))
                    running_loss = 0.0

            if epoch % 50 == 49:
                torch.save(net.state_dict(), "{}_{}.pth".format(PATH, epoch + 1))
        torch.save(net.state_dict(), "{}_latest.pth".format(PATH))

    test_net = classifier32()
    test_net.load_state_dict(torch.load("{}_latest.pth".format(PATH)))
    test_net.to(device)
    closed_acc = evalute_classifier(test_net, closed_testloader)
    print("closed-set accuracy: ", closed_acc)
    auc_d = evaluate_openset(test_net, closed_testloader, open_testloader)
    print("auc discriminator: ", auc_d)

    result_file = '{}/{}{}.txt'.format(runs, DATASET, fold_num)
    current_time = datetime.datetime.now().strftime('%Y-%m-%d_%I-%M-%S-%p')
    # Mode 'a' creates the file when it does not exist, so a single branch
    # handles both the first and subsequent runs.
    with open(result_file, 'a') as f:
        f.write(current_time + "\n")
        f.write("{}{} \n".format(DATASET, fold_num))
        # Record the number of epochs trained (not a batch index).
        f.write("{} epoch \n".format(num_epochs if is_train else 'latest'))
        f.write("close-set accuracy: {} \n".format(closed_acc))
        f.write("AUROC: {} \n".format(auc_d))
def parse_bool_flag(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as true; this parser maps the usual spellings explicitly.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


def _build_hybrid_flow(input_size=(64, 128, 4, 4), n_classes=20):
    """Build the ResidualFlow used by the hybrid model.

    Training and every evaluation path need the identical architecture
    (checkpoints are loaded with ``strict=True``), so it is defined once.
    """
    return ResidualFlow(
        n_classes=n_classes, input_size=input_size, n_blocks=[32, 32, 32],
        intermediate_dim=512, factor_out=False, quadratic=False,
        init_layer=None, actnorm=True, fc_actnorm=False, dropout=0,
        fc=False, coeff=0.98, vnorms='2222', n_lipschitz_iters=None,
        sn_atol=1e-3, sn_rtol=1e-3, n_power_series=None, n_dist='poisson',
        n_samples=1, kernels='3-1-3', activation_fn='swish', fc_end=True,
        n_exact_terms=2, preact=True, neumann_grad=True,
        grad_in_forward=False, first_resblock=True, learn_p=False,
        classification='hybrid', classification_hdim=256,
        block_type='resblock')


def _save_hybrid_checkpoint(path, tag, encoder, flow, classifier,
                            flow_optimizer, ema, args):
    """Save flow/encoder/classifier to ``{path}_{flow|encoder|classifier}_{tag}.pth``."""
    torch.save(
        {
            'state_dict': flow.state_dict(),
            # The flow checkpoint stores the optimizer that updates the flow.
            'optimizer_state_dict': flow_optimizer.state_dict(),
            'args': args,
            'ema': ema,
        }, "{}_flow_{}.pth".format(path, tag))
    torch.save(encoder.state_dict(), "{}_encoder_{}.pth".format(path, tag))
    torch.save(classifier.state_dict(), "{}_classifier_{}.pth".format(path, tag))


def _load_hybrid_model(path, tag, device, input_size=(64, 128, 4, 4)):
    """Rebuild the hybrid model from the ``{path}_*_{tag}.pth`` checkpoints."""
    encoder = encoder32()
    encoder.to(device)
    encoder.load_state_dict(torch.load("{}_encoder_{}.pth".format(path, tag)))

    classifier = classifier32()
    classifier.to(device)
    classifier.load_state_dict(
        torch.load("{}_classifier_{}.pth".format(path, tag)))

    flow = _build_hybrid_flow(input_size)
    flow.to(device)
    # One dummy forward pass before loading weights, presumably so lazily
    # initialised state (e.g. actnorm) exists -- TODO confirm. The input
    # shape is the flow's own (C, H, W).
    with torch.no_grad():
        x = torch.rand(1, *input_size[1:]).to(device)
        flow(x)
    checkpt = torch.load("{}_flow_{}.pth".format(path, tag))
    sd = {
        k: v
        for k, v in checkpt['state_dict'].items()
        if 'last_n_samples' not in k
    }
    state = flow.state_dict()
    state.update(sd)
    flow.load_state_dict(state, strict=True)

    return HybridModel(encoder, classifier, flow)


def _append_results(result_file, seed, dataset, fold_num, epoch_label,
                    closed_acc, auc_d):
    """Append one evaluation record to ``result_file`` (created if missing)."""
    current_time = datetime.datetime.now().strftime('%Y-%m-%d_%I-%M-%S-%p')
    with open(result_file, 'a') as f:
        f.write(current_time + "\n")
        f.write("seed: {}\n".format(seed))
        f.write("{}{} \n".format(dataset, fold_num))
        f.write("{} epoch \n".format(epoch_label))
        f.write("close-set accuracy: {} \n".format(closed_acc))
        f.write("AUROC: {} \n".format(auc_d))


def main():
    """Train and/or evaluate the encoder + ResidualFlow + classifier hybrid.

    Training (hard-coded ``is_train`` flag) minimises ``bpd + cross-entropy``
    on the closed-set ``a`` split. It saves a ``best`` checkpoint when the
    100-batch average loss improves, one every 50 epochs, and ``latest`` at
    the end.

    Evaluation loads checkpoints from ``--eval-dir``: epochs 50..500 when
    ``--multi-eval`` is set, otherwise ``latest``. It prints closed-set
    accuracy and open-set AUROC (``a`` vs. ``b`` test splits).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=64, help="batch size")
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
    parser.add_argument('--fold', '-f', type=int, default=0, help='which fold you gonna train with')
    parser.add_argument('--seed', type=int, default=None)
    # `--multi-eval`, `--multi-eval True` and `--multi-eval False` all work.
    parser.add_argument('--multi-eval', type=parse_bool_flag, nargs='?',
                        const=True, default=False)
    parser.add_argument('--update-freq', type=int, default=1)
    parser.add_argument(
        '--eval-dir',
        default="/home/taehokim/PycharmProjects/RNCl/runs/hybrid-tiny_imagenet-known-20-split0-2020-09-21_05-49-50-PM",
        help='run directory whose checkpoints are evaluated')
    args = parser.parse_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    if args.seed is None:
        args.seed = np.random.randint(100000)
    print("seed: {}".format(args.seed))
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed(args.seed)

    DATASET = 'tiny_imagenet-known-20-split'
    MODEL = 'hybrid'
    fold_num = args.fold
    batch_size = args.batch_size
    is_train = False
    is_write = False
    # (batch, C, H, W) given to ResidualFlow; only C, H, W are used for the
    # dummy forward pass when loading checkpoints.
    flow_input_size = (64, 128, 4, 4)

    start_time = datetime.datetime.now().strftime('%Y-%m-%d_%I-%M-%S-%p')
    runs = 'runs/{}-{}{}-{}'.format(MODEL, DATASET, fold_num, start_time)
    if is_write:
        writer = SummaryWriter(runs)

    closed_dataset = './data/{}{}a.dataset'.format(DATASET, fold_num)
    open_dataset = './data/{}{}b.dataset'.format(DATASET, fold_num)
    closed_trainloader = FlexibleCustomDataloader(
        fold='train', batch_size=batch_size, dataset=closed_dataset)
    closed_testloader = FlexibleCustomDataloader(
        fold='test', batch_size=batch_size, dataset=closed_dataset)
    open_trainloader = FlexibleCustomDataloader(
        fold='train', batch_size=batch_size, dataset=open_dataset)
    open_testloader = FlexibleCustomDataloader(
        fold='test', batch_size=batch_size, dataset=open_dataset)

    bpd_meter = RunningAverageMeter(0.97)
    logpz_meter = RunningAverageMeter(0.97)
    deltalogp_meter = RunningAverageMeter(0.97)
    firmom_meter = RunningAverageMeter(0.97)
    secmom_meter = RunningAverageMeter(0.97)
    gnorm_meter = RunningAverageMeter(0.97)

    PATH = '{}/{}{}_hybrid'.format(runs, DATASET, fold_num)
    if is_train:
        encoder = encoder32()
        encoder.to(device)
        encoder.train()

        flow = _build_hybrid_flow(flow_input_size)
        flow.to(device)
        flow.train()

        classifier = classifier32()
        classifier.to(device)
        classifier.train()

        ema = ExponentialMovingAverage(flow)

        criterion = nn.CrossEntropyLoss()
        # --lr drives encoder and flow; its default (1e-4) is the rate used
        # historically. The classifier keeps SGD(lr=0.1, momentum=0.9).
        optimizer = optim.Adam(encoder.parameters(), lr=args.lr)
        optimizer_2 = optim.Adam(flow.parameters(), lr=args.lr)
        optimizer_3 = optim.SGD(classifier.parameters(), lr=0.1, momentum=0.9)

        beta = 1
        running_loss = 0.0
        running_bpd = 0.0
        running_cls = 0.0
        best_loss = 1000
        num_epochs = 600

        for epoch in range(num_epochs):
            for i, (images, labels) in enumerate(closed_trainloader, 0):
                global_itr = epoch * len(closed_trainloader) + i
                images = Variable(images).cuda()
                labels = Variable(labels)

                outputs = encoder(images)
                bpd, logits, logpz, neg_delta_logp = compute_loss(outputs, flow, beta=beta)
                cls_outputs = classifier(outputs)
                labels = torch.argmax(labels, dim=1)
                cls_loss = criterion(cls_outputs, labels)

                firmom, secmom = estimator_moments(flow)
                bpd_meter.update(bpd.item())
                logpz_meter.update(logpz.item())
                deltalogp_meter.update(neg_delta_logp.item())
                firmom_meter.update(firmom)
                secmom_meter.update(secmom)

                loss = bpd + cls_loss
                loss.backward()

                # Gradient accumulation: step every `update_freq` iterations.
                if global_itr % args.update_freq == args.update_freq - 1:
                    if args.update_freq > 1:
                        with torch.no_grad():
                            for p in flow.parameters():
                                if p.grad is not None:
                                    p.grad /= args.update_freq
                    grad_norm = torch.nn.utils.clip_grad_norm_(flow.parameters(), 1.)
                    optimizer.step()
                    optimizer_2.step()
                    optimizer_3.step()
                    optimizer.zero_grad()
                    optimizer_2.zero_grad()
                    optimizer_3.zero_grad()
                    update_lipschitz(flow)
                    ema.apply()
                    # grad_norm only exists on update steps.
                    gnorm_meter.update(grad_norm)

                running_bpd += bpd.item()
                running_cls += cls_loss.item()
                running_loss += loss.item()
                if i % 100 == 99:
                    avg_bpd = running_bpd / 100
                    avg_cls = running_cls / 100
                    avg_loss = running_loss / 100
                    if is_write:
                        writer.add_scalar('bits per dimension', avg_bpd, global_itr)
                        writer.add_scalar('classification loss', avg_cls, global_itr)
                        writer.add_scalar('total loss', avg_loss, global_itr)
                    current_time = datetime.datetime.now().strftime(
                        '%Y-%m-%d_%I-%M-%S-%p')
                    print(current_time)
                    print('[%d, %5d] bpd: %.3f, cls_loss: %.3f, total_loss: %.3f' %
                          (epoch + 1, i + 1, avg_bpd, avg_cls, avg_loss))
                    if epoch > 1 and avg_loss < best_loss:
                        best_loss = avg_loss
                        print("best loss updated! :", best_loss)
                        _save_hybrid_checkpoint(PATH, 'best', encoder, flow,
                                                classifier, optimizer_2, ema, args)
                    running_loss = 0.0
                    running_bpd = 0.0
                    running_cls = 0.0

            # Release cached GPU memory and Python garbage once per epoch.
            torch.cuda.empty_cache()
            gc.collect()

            if epoch % 50 == 49:
                _save_hybrid_checkpoint(PATH, epoch + 1, encoder, flow,
                                        classifier, optimizer_2, ema, args)

        # The evaluation branch below reads `_latest` checkpoints.
        _save_hybrid_checkpoint(PATH, 'latest', encoder, flow, classifier,
                                optimizer_2, ema, args)

    eval_path = "{}/{}{}_hybrid".format(args.eval_dir, DATASET, fold_num)
    if args.multi_eval:
        for ckpt_epoch in range(50, 550, 50):
            hybrid = _load_hybrid_model(eval_path, ckpt_epoch, device, flow_input_size)
            closed_acc = evalute_classifier(hybrid, closed_testloader)
            print("closed-set accuracy: ", closed_acc)
            auc_d = evaluate_openset(hybrid, closed_testloader, open_testloader)
            print("auc discriminator: ", auc_d)
            if is_write:
                result_file = '{}/{}{}.txt'.format(runs, DATASET, fold_num)
                _append_results(result_file, args.seed, DATASET, fold_num,
                                ckpt_epoch, closed_acc, auc_d)
    else:
        hybrid = _load_hybrid_model(eval_path, 'latest', device, flow_input_size)
        closed_acc = evalute_classifier(hybrid, closed_testloader)
        print("closed-set accuracy: ", closed_acc)
        auc_d = evaluate_openset(hybrid, closed_testloader, open_testloader)
        print("auc discriminator: ", auc_d)