def evaluate(): nonlocal best_metric_value nonlocal patience_elapsed nonlocal stop nonlocal epoch corrects = [] for _ in tqdm(range(args['data.test_episodes']), desc="Epoch {:d} Val".format(epoch + 1)): sample = load_episode(val_data, test_tr, args['data.test_way'], args['data.test_shot'], args['data.test_query'], device) corrects.append(classification_accuracy(sample, model)[0]) val_acc = torch.mean(torch.cat(corrects)) iteration_logger.writerow({ 'global_iteration': epoch, 'val_acc': val_acc.item() }) plot_csv(iteration_logger.filename, iteration_logger.filename) print(f"Epoch {epoch}: Val Acc: {val_acc}") if val_acc > best_metric_value: best_metric_value = val_acc print("==> best model (metric = {:0.6f}), saving model...".format( best_metric_value)) model.cpu() torch.save(model, os.path.join(args['log.exp_dir'], 'best_model.pt')) model.to(device) patience_elapsed = 0 else: patience_elapsed += 1 if patience_elapsed > args['train.patience']: print("==> patience {:d} exceeded".format( args['train.patience'])) stop = True
def gan_step(engine, batch):
    """One flow-GAN training step for an ignite-style engine.

    Updates the discriminator (BCE + gradient penalty), then builds the
    generator/flow loss (GAN term, prior/logdet NLL terms, optional entropy
    and Jacobian regularizers) and steps `optimizer`. Periodically dumps
    sample grids, FID/BPD/recon stats, and SVD-of-Jacobian stats.

    NOTE: relies on many names from the enclosing scope (model,
    discriminator, optimizers, weight_* coefficients, output_dir, ...).
    Returns a dict with the batch NLL under 'total_loss'.
    """
    assert not y_condition
    # Track the step index on the engine itself; first call leaves it at -1,
    # every later call increments it.
    if 'iter_ind' in dir(engine):
        engine.iter_ind += 1
    else:
        engine.iter_ind = -1
    losses = {}
    model.train()
    discriminator.train()
    x, y = batch
    x = x.to(device)

    def run_noised_disc(discriminator, x):
        # Dequantize (uniform binning correction) before scoring.
        x = uniform_binning_correction(x)[0]
        return discriminator(x)

    real_acc = fake_acc = acc = 0
    if weight_gan > 0:
        # ---- Discriminator update ----
        fake = generate_from_noise(model, x.size(0), clamp=clamp)

        D_real_scores = run_noised_disc(discriminator, x.detach())
        D_fake_scores = run_noised_disc(discriminator, fake.detach())

        ones_target = torch.ones((x.size(0), 1), device=x.device)
        zeros_target = torch.zeros((x.size(0), 1), device=x.device)
        # torch.sigmoid replaces deprecated F.sigmoid (same values).
        D_real_accuracy = torch.sum(
            torch.round(torch.sigmoid(D_real_scores)) ==
            ones_target).float() / ones_target.size(0)
        D_fake_accuracy = torch.sum(
            torch.round(torch.sigmoid(D_fake_scores)) ==
            zeros_target).float() / zeros_target.size(0)

        D_real_loss = F.binary_cross_entropy_with_logits(
            D_real_scores, ones_target)
        D_fake_loss = F.binary_cross_entropy_with_logits(
            D_fake_scores, zeros_target)
        D_loss = (D_real_loss + D_fake_loss) / 2
        gp = gradient_penalty(
            x.detach(), fake.detach(),
            lambda _x: run_noised_disc(discriminator, _x))

        D_loss_plus_gp = D_loss + 10 * gp
        D_optimizer.zero_grad()
        D_loss_plus_gp.backward()
        D_optimizer.step()

        # ---- Generator (flow) GAN loss; fresh samples, NaNs not guarded so
        # they can be detected downstream. ----
        fake = generate_from_noise(model, x.size(0), clamp=clamp,
                                   guard_nans=False)
        G_loss = F.binary_cross_entropy_with_logits(
            run_noised_disc(discriminator, fake),
            torch.ones((x.size(0), 1), device=x.device))

        # Trace
        real_acc = D_real_accuracy.item()
        fake_acc = D_fake_accuracy.item()
        acc = .5 * (D_fake_accuracy.item() + D_real_accuracy.item())

    z, nll, y_logits, (prior, logdet) = model.forward(x, None,
                                                      return_details=True)
    train_bpd = nll.mean().item()

    loss = 0
    if weight_gan > 0:
        loss = loss + weight_gan * G_loss
    if weight_prior > 0:
        loss = loss + weight_prior * -prior.mean()
    if weight_logdet > 0:
        loss = loss + weight_logdet * -logdet.mean()
    if weight_entropy_reg > 0:
        _, _, _, (sample_prior, sample_logdet) = model.forward(
            fake, None, return_details=True)
        # notice this is actually "decreasing" sample likelihood.
        loss = loss + weight_entropy_reg * (sample_prior.mean() +
                                            sample_logdet.mean())
    # Jac Reg
    if jac_reg_lambda > 0:
        # Sample: forward Jacobian penalty on generated samples.
        x_samples = generate_from_noise(model, args.batch_size,
                                        clamp=clamp).detach()
        x_samples.requires_grad_()
        z = model.forward(x_samples, None, return_details=True)[0]
        # NOTE(review): z-tensors are reshaped with x.size(0) while x_samples
        # uses args.batch_size — these must be equal; confirm upstream.
        other_zs = torch.cat([
            split._last_z2.view(x.size(0), -1) for split in model.flow.splits
        ], -1)
        all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
        sample_foward_jac = compute_jacobian_regularizer(x_samples, all_z,
                                                         n_proj=1)
        _, c2, h, w = model.prior_h.shape
        c = c2 // 2
        zshape = (batch_size, c, h, w)
        # Sample on CPU then move, so the CPU RNG stream is unchanged;
        # requires_grad_ replaces the deprecated torch.autograd.Variable.
        randz = torch.randn(zshape).to(device)
        randz.requires_grad_()
        images = model(z=randz, y_onehot=None, temperature=1, reverse=True,
                       batch_size=0)
        other_zs = [split._last_z2 for split in model.flow.splits]
        all_z = [randz] + other_zs
        sample_inverse_jac = compute_jacobian_regularizer_manyinputs(
            all_z, images, n_proj=1)

        # Data: same two penalties on the real batch.
        x.requires_grad_()
        z = model.forward(x, None, return_details=True)[0]
        other_zs = torch.cat([
            split._last_z2.view(x.size(0), -1) for split in model.flow.splits
        ], -1)
        all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
        data_foward_jac = compute_jacobian_regularizer(x, all_z, n_proj=1)
        _, c2, h, w = model.prior_h.shape
        c = c2 // 2
        zshape = (batch_size, c, h, w)
        z.requires_grad_()
        images = model(z=z, y_onehot=None, temperature=1, reverse=True,
                       batch_size=0)
        other_zs = [split._last_z2 for split in model.flow.splits]
        all_z = [z] + other_zs
        data_inverse_jac = compute_jacobian_regularizer_manyinputs(
            all_z, images, n_proj=1)

        # loss = loss + jac_reg_lambda * (sample_foward_jac + sample_inverse_jac )
        loss = loss + jac_reg_lambda * (sample_foward_jac + sample_inverse_jac
                                        + data_foward_jac + data_inverse_jac)

    if not eval_only:
        optimizer.zero_grad()
        loss.backward()
        if not db:
            assert max_grad_clip == max_grad_norm == 0
        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        # Replace NaN gradient with 0
        for p in model.parameters():
            if p.requires_grad and p.grad is not None:
                g = p.grad.data
                g[g != g] = 0
        optimizer.step()

    # Every 100 steps: dump a grid of samples (titled by NaN status).
    if engine.iter_ind % 100 == 0:
        with torch.no_grad():
            fake = generate_from_noise(model, x.size(0), clamp=clamp)
            z = model.forward(fake, None, return_details=True)[0]
        print("Z max min")
        print(z.max().item(), z.min().item())
        if (fake != fake).float().sum() > 0:
            title = 'NaNs'
        else:
            title = "Good"
        grid = make_grid((postprocess(fake.detach().cpu(), dataset)[:30]),
                         nrow=6).permute(1, 2, 0)
        plt.figure(figsize=(10, 10))
        plt.imshow(grid)
        plt.axis('off')
        plt.title(title)
        plt.savefig(os.path.join(output_dir,
                                 f'sample_{engine.iter_ind}.png'))

    # Every eval_every steps: checkpoint (at 1, 10, 100, ... style
    # milestones), then compute recon PAD, FID, and eval BPD.
    if engine.iter_ind % eval_every == 0:

        def check_all_zero_except_leading(x):
            # True for 10, 200, 3000, ... (only the leading digit non-zero).
            return x % 10**np.floor(np.log10(x)) == 0

        if engine.iter_ind == 0 or check_all_zero_except_leading(
                engine.iter_ind):
            torch.save(
                model.state_dict(),
                os.path.join(output_dir, f'ckpt_sd_{engine.iter_ind}.pt'))
        model.eval()
        with torch.no_grad():
            # Plot recon
            fpath = os.path.join(output_dir, '_recon',
                                 f'recon_{engine.iter_ind}.png')
            sample_pad = run_recon_evolution(
                model,
                generate_from_noise(model, args.batch_size,
                                    clamp=clamp).detach(), fpath)
            print(
                f"Iter: {engine.iter_ind}, Recon Sample PAD: {sample_pad}")
            pad = run_recon_evolution(model, x_for_recon, fpath)
            print(f"Iter: {engine.iter_ind}, Recon PAD: {pad}")
            pad = pad.item()
            sample_pad = sample_pad.item()

            # Inception score
            sample = torch.cat([
                generate_from_noise(model, args.batch_size, clamp=clamp)
                for _ in range(N_inception // args.batch_size + 1)
            ], 0)[:N_inception]
            sample = sample + .5
            if (sample != sample).float().sum() > 0:
                print("Sample NaNs")
                # Bare raise (no active exception) -> RuntimeError; aborts
                # the step when sampling produced NaNs.
                raise
            else:
                fid = run_fid(x_real_inception.clamp_(0, 1),
                              sample.clamp_(0, 1))
                print(f'fid: {fid}, global_iter: {engine.iter_ind}')

            # Eval BPD
            eval_bpd = np.mean([
                model.forward(x.to(device), None,
                              return_details=True)[1].mean().item()
                for x, _ in test_loader
            ])
        stats_dict = {
            'global_iteration': engine.iter_ind,
            'fid': fid,
            'train_bpd': train_bpd,
            'pad': pad,
            'eval_bpd': eval_bpd,
            'sample_pad': sample_pad,
            'batch_real_acc': real_acc,
            'batch_fake_acc': fake_acc,
            'batch_acc': acc
        }
        iteration_logger.writerow(stats_dict)
        plot_csv(iteration_logger.filename)
        model.train()

    # BUG FIX: was `engine.iter_ind + 2 % svd_every == 0`, which parses as
    # `(engine.iter_ind + (2 % svd_every)) == 0` and, for svd_every > 2,
    # never fires — so SVD stats were never logged. Parenthesized to match
    # the evident intent.
    if (engine.iter_ind + 2) % svd_every == 0:
        model.eval()
        svd_dict = {}
        ret = utils.computeSVDjacobian(x_for_recon, model)
        D_for, D_inv = ret['D_for'], ret['D_inv']
        cn = float(D_for.max() / D_for.min())
        cn_inv = float(D_inv.max() / D_inv.min())
        svd_dict['global_iteration'] = engine.iter_ind
        svd_dict['condition_num'] = cn
        svd_dict['max_sv'] = float(D_for.max())
        svd_dict['min_sv'] = float(D_for.min())
        svd_dict['inverse_condition_num'] = cn_inv
        svd_dict['inverse_max_sv'] = float(D_inv.max())
        svd_dict['inverse_min_sv'] = float(D_inv.min())
        svd_logger.writerow(svd_dict)
        # plot_utils.plot_stability_stats(output_dir)
        # plot_utils.plot_individual_figures(output_dir, 'svd_log.csv')
        model.train()

    # NOTE(review): placement relative to the SVD block is ambiguous in the
    # original (collapsed) source; kept at function level so eval-only runs
    # terminate after one pass of the evaluations above.
    if eval_only:
        sys.exit()

    # Dummy
    losses['total_loss'] = torch.mean(nll).item()
    return losses
def main(args):
    """Train a ResNet classifier on mnist/cifar10/cifar-fs/miniimagenet.

    Handles dataset/transform selection, optional checkpoint loading and
    eval-only mode, optimizer/scheduler setup, resume, and the epoch loop
    (eval -> log -> train -> checkpoint).
    """
    # BUG FIX: this line was commented out but `use_cuda` is used below,
    # which raised NameError on the first reference.
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    set_random_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    # NOTE(review): `kwargs` is unused in this function as shown — presumably
    # intended for a DataLoader; kept for parity with the original.
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # Dataset + transforms. train_tr may include augmentation; test_tr never does.
    if args.dataset == 'mnist':
        train_data = get_dataset('mnist-train', args.dataroot)
        test_data = get_dataset('mnist-test', args.dataroot)
        train_tr = test_tr = get_transform('mnist_normalize')
    if args.dataset == 'cifar10':
        train_tr_name = 'cifar_augment_normalize' if args.data_augmentation else 'cifar_normalize'
        train_data = get_dataset('cifar10-train', args.dataroot)
        test_data = get_dataset('cifar10-test', args.dataroot)
        train_tr = get_transform(train_tr_name)
        test_tr = get_transform('cifar_normalize')
    if args.dataset == 'cifar-fs-train':
        train_tr_name = 'cifar_augment_normalize' if args.data_augmentation else 'cifar_normalize'
        train_data = get_dataset('cifar-fs-train-train', args.dataroot)
        test_data = get_dataset('cifar-fs-train-test', args.dataroot)
        train_tr = get_transform(train_tr_name)
        test_tr = get_transform('cifar_normalize')
    if args.dataset == 'miniimagenet':
        train_data = get_dataset('miniimagenet-train-train', args.dataroot)
        test_data = get_dataset('miniimagenet-train-test', args.dataroot)
        train_tr = get_transform('cifar_augment_normalize_84'
                                 if args.data_augmentation else 'cifar_normalize')
        test_tr = get_transform('cifar_normalize')

    model = ResNetClassifier(train_data['n_classes'],
                             train_data['im_size']).to(device)
    if args.ckpt_path != '':
        loaded = torch.load(args.ckpt_path)
        model.load_state_dict(loaded)
        # (removed stray ipdb.set_trace() debugging leftover)

    if args.eval:
        # BUG FIX: was `test(args, model, device, test_loader, args.n_eval_batches)`
        # — `test_loader` is undefined here; every other call site passes
        # (test_data, test_tr), so the eval branch crashed with NameError.
        acc = test(args, model, device, test_data, test_tr, args.n_eval_batches)
        print("Eval Acc: ", acc)
        sys.exit()

    # Trace logging
    mkdir(args.output_dir)
    eval_fieldnames = ['global_iteration', 'val_acc', 'train_acc']
    eval_logger = CSVLogger(every=1,
                            fieldnames=eval_fieldnames,
                            resume=args.resume,
                            filename=os.path.join(args.output_dir,
                                                  'eval_log.csv'))
    wandb.run.name = os.path.basename(args.output_dir)
    wandb.run.save()
    wandb.watch(model)

    if args.optim == 'adadelta':
        optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    elif args.optim == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    elif args.optim == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              nesterov=True,
                              weight_decay=5e-4)
    if args.dataset == 'mnist':
        scheduler = StepLR(optimizer, step_size=1, gamma=.7)
    else:
        scheduler = MultiStepLR(optimizer,
                                milestones=[60, 120, 160],
                                gamma=0.2)

    start_epoch = 1
    if args.resume:
        last_ckpt_path = os.path.join(args.output_dir, 'last_ckpt.pt')
        if os.path.exists(last_ckpt_path):
            loaded = torch.load(last_ckpt_path)
            model.load_state_dict(loaded['model_sd'])
            optimizer.load_state_dict(loaded['optimizer_sd'])
            scheduler.load_state_dict(loaded['scheduler_sd'])
            start_epoch = loaded['epoch']

    # It's important to set seed again before training b/c dataloading code
    # might have reset the seed.
    set_random_seed(args.seed)

    best_val = 0
    if args.db:
        # Debug mode: tiny schedule, 5 epochs.
        scheduler = MultiStepLR(optimizer, milestones=[1, 2, 3, 4], gamma=0.1)
        args.epochs = 5
    for epoch in range(start_epoch, args.epochs + 1):
        if epoch % args.ckpt_every == 0:
            torch.save(model.state_dict(),
                       os.path.join(args.output_dir, f"ckpt_{epoch}.pt"))

        # Evaluate (val on test_data, train acc with the eval transform),
        # log to wandb + CSV, then train one epoch.
        stats_dict = {'global_iteration': epoch}
        val = stats_dict['val_acc'] = test(args, model, device, test_data,
                                           test_tr, args.n_eval_batches)
        stats_dict['train_acc'] = test(args, model, device, train_data,
                                       test_tr, args.n_eval_batches)
        grid = make_grid(
            torch.stack([train_tr(x) for x in train_data['x'][:30]]),
            nrow=6).permute(1, 2, 0).numpy()
        img_dict = {"examples": [wandb.Image(grid, caption="Data batch")]}
        wandb.log(stats_dict)
        wandb.log(img_dict)
        eval_logger.writerow(stats_dict)
        plot_csv(eval_logger.filename,
                 os.path.join(args.output_dir, 'iteration_plots.png'))

        train(args, model, device, train_data, train_tr, optimizer, epoch)
        # Passing `epoch` is deprecated in newer torch but kept so resumed
        # runs reproduce the original LR schedule.
        scheduler.step(epoch)

        if val > best_val:
            best_val = val
            torch.save(model.state_dict(),
                       os.path.join(args.output_dir, f"ckpt_best.pt"))
        # For `resume`: full training state, saved from CPU so the file is
        # device-agnostic.
        model.cpu()
        torch.save({
            'model_sd': model.state_dict(),
            'optimizer_sd': optimizer.state_dict(),
            'scheduler_sd': scheduler.state_dict(),
            'epoch': epoch + 1
        }, os.path.join(args.output_dir, "last_ckpt.pt"))
        model.to(device)
# NOTE(review): this chunk starts mid-definition — `x[mask] = k` appears to be
# the tail of a `_replace_nan_with_k_inplace(x, k)` helper whose header is
# outside this view (mask presumably selects NaN entries); confirm upstream.
x[mask] = k
# Replace NaNs in the generated batch with -1 before scoring it.
_replace_nan_with_k_inplace(x_is, -1)
with torch.no_grad():
    # Inception score plus the raw activations (return_preds=True) so FID
    # can be computed from the same forward pass.
    issf, _, _, acts_fake = inception_score(x_is,
                                            cuda=True,
                                            batch_size=32,
                                            resize=True,
                                            splits=10,
                                            return_preds=True)
# Keep the 1800 samples with the smallest total |activation|.
idxs_ = np.argsort(np.abs(acts_fake).sum(-1))[:1800]
# filter the ones with super large values
acts_fake = acts_fake[idxs_]
# ipdb.set_trace()
m1, s1 = calculate_activation_statistics(acts_real)
m2, s2 = calculate_activation_statistics(acts_fake)
try:
    fid_value = calculate_frechet_distance(m1, s1, m2, s2)
except:
    # ipdb.set_trace()
    # This mostly happens when there are "a few really bad samples", which
    # results in, say, usually large activation (1e30).
    # These "activation outliers" mess up the statistics, and results in
    # ValueError
    fid_value = 2000
print (idx, issf, fid_value)
stats_dict = {
    'global_iteration': idx,
    'fid': fid_value
}
iteration_logger.writerow(stats_dict)
try:
    # Best-effort plotting; failures are deliberately ignored.
    plot_csv(iteration_logger.filename)
except:
    pass
# ipdb.set_trace()
# NOTE(review): this chunk starts mid-call — the two lines below close a
# plot_utils.plot_individual_figures(...) call (and its enclosing try:) whose
# opening lines are outside this view.
save_dir, 'svd_log.csv'
)  # Makes separate PDFs for each logged measure
except:
    print('Something went wrong when computing the SVD...')
try:
    plot_utils.plot_individual_figures(
        save_dir, 'every_N_log.csv'
    )  # Makes separate PDFs for each logged measure
except:
    # Best-effort plotting; failures are deliberately ignored.
    pass
if args.plot_csv:
    try:
        plot_csv(iteration_logger.filename,
                 key_name='global_iteration',
                 yscale='log')
    except:
        pass
# Restore training mode after the evaluation/plotting above.
model.train()
projection.train()
if torch.isnan(loss):
    # Abort the run outright once the loss diverges.
    print('=' * 80)
    print('Loss is NaN. Quitting...')
    print('=' * 80)
    sys.exit(1)
global_iteration += 1
def main(opt):
    """Train a few-shot OOD confidence model (OEC / deep-OEC / DM) and
    periodically evaluate OOE/OOD AUROCs, with best-checkpointing, patience
    early-exit, resume support, and optional temporal ensembling.

    NOTE(review): `device` is read but never defined here — presumably a
    module-level global; confirm at file top.
    """
    # Logging
    trace_file = os.path.join(opt['output_dir'],
                              '{}_trace.txt'.format(opt['exp_name']))
    # Load data
    if opt['dataset'] == 'cifar-fs':
        train_data = get_dataset('cifar-fs-train-train', opt['dataroot'])
        val_data = get_dataset('cifar-fs-val', opt['dataroot'])
        test_data = get_dataset('cifar-fs-test', opt['dataroot'])
        tr = get_transform('cifar_resize_normalize')
        normalize = cifar_normalize
    elif opt['dataset'] == 'miniimagenet':
        train_data = get_dataset('miniimagenet-train-train', opt['dataroot'])
        val_data = get_dataset('miniimagenet-val', opt['dataroot'])
        test_data = get_dataset('miniimagenet-test', opt['dataroot'])
        tr = get_transform('cifar_resize_normalize_84')
        normalize = cifar_normalize

    if opt['input_regularization'] == 'oe':
        # Outlier-exposure pool for input regularization.
        reg_data = load_ood_data({
            'name': 'tinyimages',
            'ood_scale': 1,
            'n_anom': 50000,
        })

    if not opt['ooe_only']:
        if opt['db']:
            ood_distributions = ['ooe', 'gaussian']
        else:
            ood_distributions = [
                'ooe', 'gaussian', 'rademacher', 'texture3', 'svhn',
                'tinyimagenet', 'lsun'
            ]
        if opt['input_regularization'] == 'oe':
            ood_distributions.append('tinyimages')
        # 'ooe' has no fixed tensor (episodic); the rest are preloaded.
        ood_tensors = [('ooe', None)] + [(out_name,
                                          load_ood_data({
                                              'name': out_name,
                                              'ood_scale': 1,
                                              'n_anom': 10000,
                                          }))
                                         for out_name in ood_distributions[1:]]

    # Load trained model: either a pickled model object or a state dict.
    loaded = torch.load(opt['model.model_path'])
    if not isinstance(loaded, OrderedDict):
        fs_model = loaded
    else:
        classifier = ResNetClassifier(64, train_data['im_size']).to(device)
        classifier.load_state_dict(loaded)
        fs_model = Protonet(classifier.encoder)
    fs_model.eval()
    fs_model = fs_model.to(device)

    # Init Confidence Methods
    if opt['confidence_method'] == 'oec':
        init_sample = load_episode(train_data, tr, opt['data.test_way'],
                                   opt['data.test_shot'],
                                   opt['data.test_query'], device)
        conf_model = OECConfidence(None, fs_model, init_sample, opt)
    elif opt['confidence_method'] == 'deep-oec':
        init_sample = load_episode(train_data, tr, opt['data.test_way'],
                                   opt['data.test_shot'],
                                   opt['data.test_query'], device)
        conf_model = DeepOECConfidence(None, fs_model, init_sample, opt)
    elif opt['confidence_method'] == 'dm-iso':
        encoder = fs_model.encoder
        deep_mahala_obj = DeepMahala(None, None, None, encoder, device,
                                     num_feats=encoder.depth,
                                     num_classes=train_data['n_classes'],
                                     pretrained_path="",
                                     fit=False,
                                     normalize=None)
        conf_model = DMConfidence(deep_mahala_obj, {
            'ls': range(encoder.depth),
            'reduction': 'max',
            'g_magnitude': .1
        }, True, 'iso')
    if opt['pretrained_oec_path']:
        conf_model.load_state_dict(torch.load(opt['pretrained_oec_path']))
    conf_model.to(device)
    print(conf_model)

    optimizer = optim.Adam(conf_model.confidence_parameters(),
                           lr=opt['lr'],
                           weight_decay=opt['wd'])
    scheduler = StepLR(optimizer,
                       step_size=opt['lrsche_step_size'],
                       gamma=opt['lrsche_gamma'])
    num_param = sum(p.numel() for p in conf_model.confidence_parameters())
    print(f"Learning Confidence, Number of Parameters -- {num_param}")
    if conf_model.pretrain_parameters() is not None:
        # Warm-up optimizer used for the first `pretrain_iter` steps.
        pretrain_optimizer = optim.Adam(conf_model.pretrain_parameters(),
                                        lr=10)
        pretrain_iter = 100

    start_idx = 0
    if opt['resume']:
        last_ckpt_path = os.path.join(opt['output_dir'], 'last_ckpt.pt')
        if os.path.exists(last_ckpt_path):
            try:
                last_ckpt = torch.load(last_ckpt_path)
                if 'conf_model' in last_ckpt:
                    conf_model = last_ckpt['conf_model']
                else:
                    sd = last_ckpt['conf_model_sd']
                    conf_model.load_state_dict(sd)
                optimizer = last_ckpt['optimizer']
                pretrain_optimizer = last_ckpt['pretrain_optimizer']
                scheduler = last_ckpt['scheduler']
                start_idx = last_ckpt['outer_idx']
                conf_model.to(device)
            except EOFError:
                # Truncated checkpoint (e.g. killed mid-save): start fresh.
                print(
                    "\n\nResuming but got EOF error, starting from init..\n\n")

    wandb.run.name = opt['exp_name']
    wandb.run.save()
    # try:
    wandb.watch(conf_model)
    # except:  # resuming a run
    #     pass

    # Eval and Logging: the trained method plus reference baselines.
    confs = {
        opt['confidence_method']: conf_model,
    }
    if opt['confidence_method'] == 'oec':
        confs['ed'] = FSCConfidence(fs_model, 'ed')
    elif opt['confidence_method'] == 'deep-oec':
        encoder = fs_model.encoder
        deep_mahala_obj = DeepMahala(None, None, None, encoder, device,
                                     num_feats=encoder.depth,
                                     num_classes=train_data['n_classes'],
                                     pretrained_path="",
                                     fit=False,
                                     normalize=None)
        confs['dm'] = DMConfidence(deep_mahala_obj, {
            'ls': range(encoder.depth),
            'reduction': 'max',
            'g_magnitude': 0
        }, True, 'iso').to(device)

    # Temporal Ensemble for Evaluation
    if opt['n_ensemble'] > 1:
        nets = [deepcopy(conf_model) for _ in range(opt['n_ensemble'])]
        confs['mixture-' + opt['confidence_method']] = Ensemble(
            nets, 'mixture')
        confs['poe-' + opt['confidence_method']] = Ensemble(nets, 'poe')
        ensemble_update_interval = opt['eval_every_outer'] // opt['n_ensemble']

    iteration_fieldnames = ['global_iteration']
    for c in confs:
        iteration_fieldnames += [
            f'{c}_train_ooe', f'{c}_val_ooe', f'{c}_test_ooe', f'{c}_ood'
        ]
    iteration_logger = CSVLogger(every=0,
                                 fieldnames=iteration_fieldnames,
                                 filename=os.path.join(
                                     opt['output_dir'], 'iteration_log.csv'))

    best_val_ooe = 0
    PATIENCE = 5  # Number of evaluations to wait
    waited = 0
    progress_bar = tqdm(range(start_idx, opt['train_iter']))
    for outer_idx in progress_bar:
        sample = load_episode(train_data, tr, opt['data.test_way'],
                              opt['data.test_shot'], opt['data.test_query'],
                              device)
        conf_model.train()
        if opt['full_supervision']:
            # sanity check
            # NOTE(review): uses `ood_tensors`, which is only defined when
            # not opt['ooe_only'] — confirm those flags are never combined.
            conf_model.support(sample['xs'])
            in_score = conf_model.score(sample['xq'], detach=False).squeeze()
            out_score = conf_model.score(sample['ooc_xq'],
                                         detach=False).squeeze()
            out_scores = [out_score]
            for curr_ood, ood_tensor in ood_tensors:
                if curr_ood == 'ooe':
                    continue
                start = outer_idx % (len(ood_tensor) // 2)
                # NOTE(review): shape[0] * shape[0] looks like it may have
                # been meant as shape[0] * shape[1] (way * query); kept as-is.
                stop = min(
                    start + sample['xq'].shape[0] * sample['xq'].shape[0],
                    len(ood_tensor) // 2)
                oxq = torch.stack([tr(x) for x in ood_tensor[start:stop]
                                   ]).to(device)
                o = conf_model.score(oxq, detach=False).squeeze()
                out_scores.append(o)
            # out_score = torch.cat(out_scores)
            # NOTE(review): out_scores is collected but unused (the cat above
            # is commented out) while in_score is tiled — confirm intended.
            in_score = in_score.repeat(len(ood_tensors))
            loss, acc = compute_loss_bce(in_score, out_score,
                                         mean_center=False)
        else:
            conf_model.support(sample['xs'])
            if opt['interpolate']:
                # Replace half the OOC queries with interpolated ID queries.
                half_n_way = sample['xq'].shape[0] // 2
                interp = .5 * (sample['xq'][:half_n_way] +
                               sample['xq'][half_n_way:2 * half_n_way])
                sample['ooc_xq'][:half_n_way] = interp
            if opt['input_regularization'] == 'oe':
                # Reshape ooc_xq
                nw, nq, c, h, w = sample['ooc_xq'].shape
                sample['ooc_xq'] = sample['ooc_xq'].view(1, nw * nq, c, h, w)
                oe_bs = int(nw * nq * opt['input_regularization_percent'])
                start = (outer_idx * oe_bs) % len(reg_data)
                end = np.min([start + oe_bs, len(reg_data)])
                oe_batch = torch.stack([tr(x) for x in reg_data[start:end]
                                        ]).to(device)
                oe_batch = oe_batch.unsqueeze(0)
                sample['ooc_xq'][:, :oe_batch.shape[1]] = oe_batch
            if opt['in_out_1_batch']:
                # Score ID and OOC queries in one forward pass, then split.
                inps = torch.cat([sample['xq'], sample['ooc_xq']], 1)
                scores = conf_model.score(inps, detach=False).squeeze()
                in_score, out_score = scores[:sample['xq'].shape[1]], scores[
                    sample['xq'].shape[1]:]
            else:
                in_score = conf_model.score(sample['xq'],
                                            detach=False).squeeze()
                out_score = conf_model.score(sample['ooc_xq'],
                                             detach=False).squeeze()
            loss, acc = compute_loss_bce(in_score, out_score,
                                         mean_center=False)

        # Warm-up phase uses the pretrain optimizer; afterwards the main one.
        if conf_model.pretrain_parameters(
        ) is not None and outer_idx < pretrain_iter:
            pretrain_optimizer.zero_grad()
            loss.backward()
            pretrain_optimizer.step()
        else:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
        progress_bar.set_postfix(loss='{:.3e}'.format(loss),
                                 acc='{:.3e}'.format(acc))

        # Update Ensemble
        if opt['n_ensemble'] > 1 and outer_idx % ensemble_update_interval == 0:
            update_ind = (outer_idx //
                          ensemble_update_interval) % opt['n_ensemble']
            if opt['db']:
                print(f"===> Updating Ensemble: {update_ind}")
            confs['mixture-' +
                  opt['confidence_method']].nets[update_ind] = deepcopy(
                      conf_model)
            confs['poe-' +
                  opt['confidence_method']].nets[update_ind] = deepcopy(
                      conf_model)

        # AUROC eval
        if outer_idx % opt['eval_every_outer'] == 0:
            if not opt['eval_in_train']:
                conf_model.eval()
            # Eval..
            stats_dict = {'global_iteration': outer_idx}
            for conf_name, conf in confs.items():
                conf.eval()
                # OOE eval
                ooe_aurocs = {}
                for split, in_data in [('train', train_data),
                                       ('val', val_data),
                                       ('test', test_data)]:
                    auroc = np.mean(
                        eval_ood_aurocs(
                            None,
                            in_data,
                            tr,
                            opt['data.test_way'],
                            opt['data.test_shot'],
                            opt['data.test_query'],
                            opt['data.test_episodes'],
                            device,
                            conf,
                            no_grad=False
                            if opt['confidence_method'].startswith('dm') else
                            True)['aurocs'])
                    ooe_aurocs[split] = auroc
                    print_str = '{}, iter: {} ({}), auroc: {:.3e}'.format(
                        conf_name, outer_idx, split, ooe_aurocs[split])
                    _print_and_log(print_str, trace_file)
                stats_dict[f'{conf_name}_train_ooe'] = ooe_aurocs['train']
                stats_dict[f'{conf_name}_val_ooe'] = ooe_aurocs['val']
                stats_dict[f'{conf_name}_test_ooe'] = ooe_aurocs['test']
                # OOD eval
                if not opt['ooe_only']:
                    aurocs = []
                    for curr_ood, ood_tensor in ood_tensors:
                        auroc = np.mean(
                            eval_ood_aurocs(
                                ood_tensor,
                                test_data,
                                tr,
                                opt['data.test_way'],
                                opt['data.test_shot'],
                                opt['data.test_query'],
                                opt['data.test_episodes'],
                                device,
                                conf,
                                no_grad=False
                                if opt['confidence_method'].startswith('dm')
                                else True)['aurocs'])
                        aurocs.append(auroc)
                        print_str = '{}, iter: {} ({}), auroc: {:.3e}'.format(
                            conf_name, outer_idx, curr_ood, auroc)
                        _print_and_log(print_str, trace_file)
                    mean_ood_auroc = np.mean(aurocs)
                    print_str = '{}, iter: {} (OOD_mean), auroc: {:.3e}'.format(
                        conf_name, outer_idx, mean_ood_auroc)
                    _print_and_log(print_str, trace_file)
                    stats_dict[f'{conf_name}_ood'] = mean_ood_auroc
            iteration_logger.writerow(stats_dict)
            plot_csv(iteration_logger.filename, iteration_logger.filename)
            wandb.log(stats_dict)

            if stats_dict[f'{opt["confidence_method"]}_val_ooe'] > best_val_ooe:
                # BUG FIX: best_val_ooe was never updated, so every eval
                # counted as an improvement — the "best" checkpoint was
                # overwritten each time and PATIENCE could never trigger.
                best_val_ooe = stats_dict[
                    f'{opt["confidence_method"]}_val_ooe']
                conf_model.cpu()
                torch.save(
                    conf_model.state_dict(),
                    os.path.join(opt['output_dir'],
                                 opt['exp_name'] + '_conf_best.pt'))
                conf_model.to(device)
                # Ckpt ensemble
                if opt['n_ensemble'] > 1:
                    ensemble = confs['mixture-' + opt['confidence_method']]
                    ensemble.cpu()
                    torch.save(
                        ensemble.state_dict(),
                        os.path.join(opt['output_dir'],
                                     opt['exp_name'] + '_ensemble_best.pt'))
                    ensemble.to(device)
                waited = 0
            else:
                waited += 1
                if waited >= PATIENCE:
                    print("PATIENCE exceeded...exiting")
                    sys.exit()

            # For `resume`
            conf_model.cpu()
            torch.save(
                {
                    'conf_model_sd': conf_model.state_dict(),
                    'optimizer': optimizer,
                    'pretrain_optimizer': pretrain_optimizer
                    if conf_model.pretrain_parameters() is not None else None,
                    'scheduler': scheduler,
                    'outer_idx': outer_idx,
                }, os.path.join(opt['output_dir'], 'last_ckpt.pt'))
            conf_model.to(device)
            conf_model.train()
    sys.exit()