Example 1
    def evaluate():

        nonlocal best_metric_value
        nonlocal patience_elapsed
        nonlocal stop
        nonlocal epoch

        corrects = []
        for _ in tqdm(range(args['data.test_episodes']),
                      desc="Epoch {:d} Val".format(epoch + 1)):
            sample = load_episode(val_data, test_tr, args['data.test_way'],
                                  args['data.test_shot'],
                                  args['data.test_query'], device)
            corrects.append(classification_accuracy(sample, model)[0])
        val_acc = torch.mean(torch.cat(corrects))
        iteration_logger.writerow({
            'global_iteration': epoch,
            'val_acc': val_acc.item()
        })
        plot_csv(iteration_logger.filename, iteration_logger.filename)

        print(f"Epoch {epoch}: Val Acc: {val_acc}")

        if val_acc > best_metric_value:
            best_metric_value = val_acc
            print("==> best model (metric = {:0.6f}), saving model...".format(
                best_metric_value))
            model.cpu()
            torch.save(model, os.path.join(args['log.exp_dir'],
                                           'best_model.pt'))
            model.to(device)
            patience_elapsed = 0

        else:
            patience_elapsed += 1
            if patience_elapsed > args['train.patience']:
                print("==> patience {:d} exceeded".format(
                    args['train.patience']))
                stop = True
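
    # A sketch (assumption) of the enclosing training closure that the
    # `nonlocal` declarations above require; names mirror the snippet and the
    # epoch body is elided:
    #
    #     def train(model, val_data, test_tr, args, device, iteration_logger):
    #         best_metric_value = -1.0
    #         patience_elapsed = 0
    #         stop = False
    #         for epoch in range(args['train.epochs']):
    #             ...          # one epoch of optimization
    #             evaluate()   # may set `stop` once patience is exceeded
    #             if stop:
    #                 break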
Example 2
    def gan_step(engine, batch):
        assert not y_condition
        if hasattr(engine, 'iter_ind'):
            engine.iter_ind += 1
        else:
            engine.iter_ind = -1
        losses = {}
        model.train()
        discriminator.train()

        x, y = batch
        x = x.to(device)

        def run_noised_disc(discriminator, x):
            x = uniform_binning_correction(x)[0]
            return discriminator(x)
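
        # `uniform_binning_correction` is assumed (not shown here) to perform
        # standard uniform dequantization for flow models. A minimal sketch of
        # that behavior, with hypothetical defaults:
        def _uniform_binning_correction_sketch(x, n_bits=8):
            # Add u ~ U[0, 1/n_bins) per pixel; return the noised input and
            # the per-sample log-likelihood correction -log(n_bins) * n_dims.
            n_bins = 2 ** n_bits
            u = torch.zeros_like(x).uniform_(0, 1.0 / n_bins)
            objective = x.new_full((x.size(0),),
                                   -np.log(n_bins) * x[0].numel())
            return x + u, objective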

        real_acc = fake_acc = acc = 0
        if weight_gan > 0:
            fake = generate_from_noise(model, x.size(0), clamp=clamp)

            D_real_scores = run_noised_disc(discriminator, x.detach())
            D_fake_scores = run_noised_disc(discriminator, fake.detach())

            ones_target = torch.ones((x.size(0), 1), device=x.device)
            zeros_target = torch.zeros((x.size(0), 1), device=x.device)

            D_real_accuracy = torch.sum(
                torch.round(torch.sigmoid(D_real_scores)) ==
                ones_target).float() / ones_target.size(0)
            D_fake_accuracy = torch.sum(
                torch.round(torch.sigmoid(D_fake_scores)) ==
                zeros_target).float() / zeros_target.size(0)

            D_real_loss = F.binary_cross_entropy_with_logits(
                D_real_scores, ones_target)
            D_fake_loss = F.binary_cross_entropy_with_logits(
                D_fake_scores, zeros_target)

            D_loss = (D_real_loss + D_fake_loss) / 2
            gp = gradient_penalty(
                x.detach(), fake.detach(),
                lambda _x: run_noised_disc(discriminator, _x))
            D_loss_plus_gp = D_loss + 10 * gp
            D_optimizer.zero_grad()
            D_loss_plus_gp.backward()
            D_optimizer.step()
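
            # `gradient_penalty` above is assumed to be a WGAN-GP-style penalty
            # on interpolates between real and fake inputs; a sketch matching
            # the call signature used here (the real helper lives elsewhere):
            def _gradient_penalty_sketch(real, fake, disc_fn):
                eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
                xhat = (eps * real + (1 - eps) * fake).requires_grad_(True)
                grads, = torch.autograd.grad(disc_fn(xhat).sum(), xhat,
                                             create_graph=True)
                # Penalize deviation of each sample's gradient norm from 1.
                return ((grads.flatten(1).norm(2, dim=1) - 1) ** 2).mean()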

            # Train generator
            fake = generate_from_noise(model,
                                       x.size(0),
                                       clamp=clamp,
                                       guard_nans=False)
            G_loss = F.binary_cross_entropy_with_logits(
                run_noised_disc(discriminator, fake),
                torch.ones((x.size(0), 1), device=x.device))

            # Trace
            real_acc = D_real_accuracy.item()
            fake_acc = D_fake_accuracy.item()
            acc = .5 * (D_fake_accuracy.item() + D_real_accuracy.item())

        z, nll, y_logits, (prior, logdet) = model.forward(x,
                                                          None,
                                                          return_details=True)
        train_bpd = nll.mean().item()

        loss = 0
        if weight_gan > 0:
            loss = loss + weight_gan * G_loss
        if weight_prior > 0:
            loss = loss + weight_prior * -prior.mean()
        if weight_logdet > 0:
            loss = loss + weight_logdet * -logdet.mean()

        if weight_entropy_reg > 0:
            _, _, _, (sample_prior,
                      sample_logdet) = model.forward(fake,
                                                     None,
                                                     return_details=True)
            # Note: this term *decreases* sample likelihood (it acts as an
            # entropy regularizer on generated samples).
            loss = loss + weight_entropy_reg * (sample_prior.mean() +
                                                sample_logdet.mean())
        # Jac Reg
        if jac_reg_lambda > 0:
            # Sample
            x_samples = generate_from_noise(model,
                                            args.batch_size,
                                            clamp=clamp).detach()
            x_samples.requires_grad_()
            z = model.forward(x_samples, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1)
                for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            sample_forward_jac = compute_jacobian_regularizer(x_samples,
                                                              all_z,
                                                              n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            randz = torch.randn(zshape, device=device)
            randz.requires_grad_()
            images = model(z=randz,
                           y_onehot=None,
                           temperature=1,
                           reverse=True,
                           batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [randz] + other_zs
            sample_inverse_jac = compute_jacobian_regularizer_manyinputs(
                all_z, images, n_proj=1)

            # Data
            x.requires_grad_()
            z = model.forward(x, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1)
                for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            data_forward_jac = compute_jacobian_regularizer(x, all_z, n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            z.requires_grad_()
            images = model(z=z,
                           y_onehot=None,
                           temperature=1,
                           reverse=True,
                           batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [z] + other_zs
            data_inverse_jac = compute_jacobian_regularizer_manyinputs(
                all_z, images, n_proj=1)

            loss = loss + jac_reg_lambda * (sample_forward_jac +
                                            sample_inverse_jac +
                                            data_forward_jac +
                                            data_inverse_jac)
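
            # `compute_jacobian_regularizer` is taken to be a Hutchinson-style
            # estimator of ||J||_F^2 via random vector-Jacobian products (an
            # assumption); a sketch for a single (x, z) pair:
            def _jac_reg_sketch(x, z, n_proj=1):
                reg = 0.0
                for _ in range(n_proj):
                    v = torch.randn_like(z)  # E[||v^T J||^2] = ||J||_F^2
                    Jv, = torch.autograd.grad(z, x, grad_outputs=v,
                                              create_graph=True)
                    reg = reg + (Jv ** 2).sum() / n_proj
                return reg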

        if not eval_only:
            optimizer.zero_grad()
            loss.backward()
            if not db:
                assert max_grad_clip == max_grad_norm == 0
            if max_grad_clip > 0:
                torch.nn.utils.clip_grad_value_(model.parameters(),
                                                max_grad_clip)
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_grad_norm)

            # Replace NaN gradient with 0
            for p in model.parameters():
                if p.requires_grad and p.grad is not None:
                    g = p.grad.data
                    g[g != g] = 0
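                    # NaN != NaN, so only NaN entries are zeroed (newer
                    # PyTorch offers g.nan_to_num_() for a similar effect).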

            optimizer.step()

        if engine.iter_ind % 100 == 0:
            with torch.no_grad():
                fake = generate_from_noise(model, x.size(0), clamp=clamp)
                z = model.forward(fake, None, return_details=True)[0]
            print("Z max min")
            print(z.max().item(), z.min().item())
            if torch.isnan(fake).any():
                title = 'NaNs'
            else:
                title = "Good"
            grid = make_grid((postprocess(fake.detach().cpu(), dataset)[:30]),
                             nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.title(title)
            plt.savefig(
                os.path.join(output_dir, f'sample_{engine.iter_ind}.png'))

        if engine.iter_ind % eval_every == 0:

            def check_all_zero_except_leading(x):
                return x % 10**np.floor(np.log10(x)) == 0
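
            # True exactly for numbers with a single non-zero leading digit
            # (1..9, 10, 20, ..., 100, 200, ...), i.e. a roughly logarithmic
            # checkpointing schedule.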

            if engine.iter_ind == 0 or check_all_zero_except_leading(
                    engine.iter_ind):
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f'ckpt_sd_{engine.iter_ind}.pt'))

            model.eval()

            with torch.no_grad():
                # Plot recon
                fpath = os.path.join(output_dir, '_recon',
                                     f'recon_{engine.iter_ind}.png')
                sample_pad = run_recon_evolution(
                    model,
                    generate_from_noise(model, args.batch_size,
                                        clamp=clamp).detach(), fpath)
                print(
                    f"Iter: {engine.iter_ind}, Recon Sample PAD: {sample_pad}")

                pad = run_recon_evolution(model, x_for_recon, fpath)
                print(f"Iter: {engine.iter_ind}, Recon PAD: {pad}")
                pad = pad.item()
                sample_pad = sample_pad.item()

                # Inception score
                sample = torch.cat([
                    generate_from_noise(model, args.batch_size, clamp=clamp)
                    for _ in range(N_inception // args.batch_size + 1)
                ], 0)[:N_inception]
                sample = sample + .5

                if torch.isnan(sample).any():
                    raise ValueError("NaNs in generated samples")
                else:
                    fid = run_fid(x_real_inception.clamp_(0, 1),
                                  sample.clamp_(0, 1))
                    print(f'fid: {fid}, global_iter: {engine.iter_ind}')

                # Eval BPD
                eval_bpd = np.mean([
                    model.forward(x.to(device), None,
                                  return_details=True)[1].mean().item()
                    for x, _ in test_loader
                ])

                stats_dict = {
                    'global_iteration': engine.iter_ind,
                    'fid': fid,
                    'train_bpd': train_bpd,
                    'pad': pad,
                    'eval_bpd': eval_bpd,
                    'sample_pad': sample_pad,
                    'batch_real_acc': real_acc,
                    'batch_fake_acc': fake_acc,
                    'batch_acc': acc
                }
                iteration_logger.writerow(stats_dict)
                plot_csv(iteration_logger.filename)
            model.train()

        if (engine.iter_ind + 2) % svd_every == 0:
            model.eval()
            svd_dict = {}
            ret = utils.computeSVDjacobian(x_for_recon, model)
            D_for, D_inv = ret['D_for'], ret['D_inv']
            cn = float(D_for.max() / D_for.min())
            cn_inv = float(D_inv.max() / D_inv.min())
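            # Condition numbers sigma_max / sigma_min of the singular values
            # of the forward and inverse Jacobians.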
            svd_dict['global_iteration'] = engine.iter_ind
            svd_dict['condition_num'] = cn
            svd_dict['max_sv'] = float(D_for.max())
            svd_dict['min_sv'] = float(D_for.min())
            svd_dict['inverse_condition_num'] = cn_inv
            svd_dict['inverse_max_sv'] = float(D_inv.max())
            svd_dict['inverse_min_sv'] = float(D_inv.min())
            svd_logger.writerow(svd_dict)
            model.train()
            if eval_only:
                sys.exit()

        # Dummy
        losses['total_loss'] = torch.mean(nll).item()
        return losses
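
# `gan_step(engine, batch)` has the process-function signature expected by
# pytorch-ignite's Engine (suggested by the `engine` argument). A minimal
# wiring sketch, assuming a `train_loader` is defined:
#
#     from ignite.engine import Engine
#
#     trainer = Engine(gan_step)
#     trainer.run(train_loader, max_epochs=n_epochs)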
Example 3
def main(args):
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    set_random_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    if args.dataset == 'mnist':
        train_data = get_dataset('mnist-train', args.dataroot)
        test_data = get_dataset('mnist-test', args.dataroot)
        train_tr = test_tr = get_transform('mnist_normalize')

    if args.dataset == 'cifar10':
        train_tr_name = ('cifar_augment_normalize'
                         if args.data_augmentation else 'cifar_normalize')
        train_data = get_dataset('cifar10-train', args.dataroot)
        test_data = get_dataset('cifar10-test', args.dataroot)
        train_tr = get_transform(train_tr_name)
        test_tr = get_transform('cifar_normalize')

    if args.dataset == 'cifar-fs-train':
        train_tr_name = ('cifar_augment_normalize'
                         if args.data_augmentation else 'cifar_normalize')
        train_data = get_dataset('cifar-fs-train-train', args.dataroot)
        test_data = get_dataset('cifar-fs-train-test', args.dataroot)
        train_tr = get_transform(train_tr_name)
        test_tr = get_transform('cifar_normalize')

    if args.dataset == 'miniimagenet':
        train_data = get_dataset('miniimagenet-train-train', args.dataroot)
        test_data = get_dataset('miniimagenet-train-test', args.dataroot)
        train_tr = get_transform('cifar_augment_normalize_84'
                                 if args.data_augmentation
                                 else 'cifar_normalize')
        test_tr = get_transform('cifar_normalize')

    model = ResNetClassifier(train_data['n_classes'],
                             train_data['im_size']).to(device)
    if args.ckpt_path != '':
        loaded = torch.load(args.ckpt_path)
        model.load_state_dict(loaded)
    if args.eval:
        acc = test(args, model, device, test_data, test_tr,
                   args.n_eval_batches)
        print("Eval Acc: ", acc)
        sys.exit()

    # Trace logging
    mkdir(args.output_dir)
    eval_fieldnames = ['global_iteration', 'val_acc', 'train_acc']
    eval_logger = CSVLogger(every=1,
                            fieldnames=eval_fieldnames,
                            resume=args.resume,
                            filename=os.path.join(args.output_dir,
                                                  'eval_log.csv'))
    wandb.run.name = os.path.basename(args.output_dir)
    wandb.run.save()
    wandb.watch(model)

    if args.optim == 'adadelta':
        optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    elif args.optim == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    elif args.optim == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,
                              nesterov=True, weight_decay=5e-4)
    if args.dataset == 'mnist':
        scheduler = StepLR(optimizer, step_size=1, gamma=.7)
    else:
        scheduler = MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)

    start_epoch = 1
    if args.resume:
        last_ckpt_path = os.path.join(args.output_dir, 'last_ckpt.pt')
        if os.path.exists(last_ckpt_path):
            loaded = torch.load(last_ckpt_path)
            model.load_state_dict(loaded['model_sd'])
            optimizer.load_state_dict(loaded['optimizer_sd'])
            scheduler.load_state_dict(loaded['scheduler_sd'])
            start_epoch = loaded['epoch']

    # It's important to set seed again before training b/c dataloading code
    # might have reset the seed.
    set_random_seed(args.seed)
    best_val = 0
    if args.db: 
        scheduler = MultiStepLR(optimizer, milestones=[1, 2, 3, 4], gamma=0.1)
        args.epochs = 5
    for epoch in range(start_epoch, args.epochs + 1):
        if epoch % args.ckpt_every == 0:
            torch.save(model.state_dict(),
                       os.path.join(args.output_dir, f"ckpt_{epoch}.pt"))

        stats_dict = {'global_iteration': epoch}
        val = stats_dict['val_acc'] = test(args, model, device, test_data,
                                           test_tr, args.n_eval_batches)
        stats_dict['train_acc'] = test(args, model, device, train_data,
                                       test_tr, args.n_eval_batches)
        grid = make_grid(torch.stack([train_tr(x)
                                      for x in train_data['x'][:30]]),
                         nrow=6).permute(1, 2, 0).numpy()
        img_dict = {"examples": [wandb.Image(grid, caption="Data batch")]}
        wandb.log(stats_dict)
        wandb.log(img_dict)
        eval_logger.writerow(stats_dict)
        plot_csv(eval_logger.filename, os.path.join(args.output_dir, 'iteration_plots.png'))

        train(args, model, device, train_data, train_tr, optimizer, epoch)
        
        scheduler.step()

        if val > best_val:
            best_val = val
            torch.save(model.state_dict(),
                       os.path.join(args.output_dir, "ckpt_best.pt"))

        # For `resume`
        model.cpu()
        torch.save({
            'model_sd': model.state_dict(),
            'optimizer_sd': optimizer.state_dict(),
            'scheduler_sd': scheduler.state_dict(),
            'epoch': epoch + 1
        }, os.path.join(args.output_dir, "last_ckpt.pt"))
        model.to(device)
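
# The `test` helper is not shown. A minimal sketch consistent with the calls
# above (hypothetical: assumes data dicts with raw images under 'x' and label
# tensors under 'y'):
def _test_sketch(args, model, device, data, tr, n_batches, batch_size=256):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for i in range(n_batches):
            lo, hi = i * batch_size, (i + 1) * batch_size
            if lo >= len(data['x']):
                break
            xb = torch.stack([tr(x) for x in data['x'][lo:hi]]).to(device)
            yb = data['y'][lo:hi].to(device)
            correct += (model(xb).argmax(dim=1) == yb).sum().item()
            total += yb.size(0)
    model.train()
    return correct / max(total, 1)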
Example 4
        def _replace_nan_with_k_inplace(x, k):
            # NaN != NaN, so the mask selects exactly the NaN entries.
            mask = x != x
            x[mask] = k

        _replace_nan_with_k_inplace(x_is, -1)
        with torch.no_grad():
            issf, _, _, acts_fake = inception_score(x_is, cuda=True,
                                                    batch_size=32, resize=True,
                                                    splits=10,
                                                    return_preds=True)
        # Filter out samples with extremely large activation magnitudes.
        idxs_ = np.argsort(np.abs(acts_fake).sum(-1))[:1800]
        acts_fake = acts_fake[idxs_]

        m1, s1 = calculate_activation_statistics(acts_real)
        m2, s2 = calculate_activation_statistics(acts_fake)
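        # FID is the Fréchet distance between Gaussians fitted to the real and
        # fake activations: ||m1 - m2||^2 + Tr(s1 + s2 - 2 (s1 s2)^(1/2)).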
        try:
            fid_value = calculate_frechet_distance(m1, s1, m2, s2)
        except ValueError:
            # This mostly happens when a few really bad samples produce
            # abnormally large activations (~1e30). These activation outliers
            # corrupt the statistics and make calculate_frechet_distance
            # raise a ValueError, so fall back to a sentinel FID.
            fid_value = 2000
        print(idx, issf, fid_value)
        stats_dict = {
            'global_iteration': idx,
            'fid': fid_value
        }
        iteration_logger.writerow(stats_dict)
        try:
            plot_csv(iteration_logger.filename)
        except Exception:
            pass
Example 5
                    try:
                        plot_utils.plot_individual_figures(
                            save_dir, 'svd_log.csv'
                        )  # Makes separate PDFs for each logged measure
                    except Exception:
                        print('Something went wrong when computing the SVD...')

                try:
                    plot_utils.plot_individual_figures(
                        save_dir, 'every_N_log.csv'
                    )  # Makes separate PDFs for each logged measure
                except Exception:
                    pass

                if args.plot_csv:
                    try:
                        plot_csv(iteration_logger.filename,
                                 key_name='global_iteration',
                                 yscale='log')
                    except Exception:
                        pass

                model.train()
                projection.train()

            if torch.isnan(loss):
                print('=' * 80)
                print('Loss is NaN. Quitting...')
                print('=' * 80)
                sys.exit(1)

            global_iteration += 1
Example 6
def main(opt):

    # Logging
    trace_file = os.path.join(opt['output_dir'],
                              '{}_trace.txt'.format(opt['exp_name']))

    # Load data
    if opt['dataset'] == 'cifar-fs':
        train_data = get_dataset('cifar-fs-train-train', opt['dataroot'])
        val_data = get_dataset('cifar-fs-val', opt['dataroot'])
        test_data = get_dataset('cifar-fs-test', opt['dataroot'])
        tr = get_transform('cifar_resize_normalize')
        normalize = cifar_normalize
    elif opt['dataset'] == 'miniimagenet':
        train_data = get_dataset('miniimagenet-train-train', opt['dataroot'])
        val_data = get_dataset('miniimagenet-val', opt['dataroot'])
        test_data = get_dataset('miniimagenet-test', opt['dataroot'])
        tr = get_transform('cifar_resize_normalize_84')
        normalize = cifar_normalize

    if opt['input_regularization'] == 'oe':
        reg_data = load_ood_data({
            'name': 'tinyimages',
            'ood_scale': 1,
            'n_anom': 50000,
        })

    if not opt['ooe_only']:
        if opt['db']:
            ood_distributions = ['ooe', 'gaussian']
        else:
            ood_distributions = [
                'ooe', 'gaussian', 'rademacher', 'texture3', 'svhn',
                'tinyimagenet', 'lsun'
            ]
            if opt['input_regularization'] == 'oe':
                ood_distributions.append('tinyimages')

        ood_tensors = [('ooe', None)]
        ood_tensors += [(out_name, load_ood_data({
            'name': out_name,
            'ood_scale': 1,
            'n_anom': 10000,
        })) for out_name in ood_distributions[1:]]

    # Load trained model
    loaded = torch.load(opt['model.model_path'])
    if not isinstance(loaded, OrderedDict):
        fs_model = loaded
    else:
        classifier = ResNetClassifier(64, train_data['im_size']).to(device)
        classifier.load_state_dict(loaded)
        fs_model = Protonet(classifier.encoder)
    fs_model.eval()
    fs_model = fs_model.to(device)

    # Init Confidence Methods
    if opt['confidence_method'] == 'oec':
        init_sample = load_episode(train_data, tr, opt['data.test_way'],
                                   opt['data.test_shot'],
                                   opt['data.test_query'], device)
        conf_model = OECConfidence(None, fs_model, init_sample, opt)
    elif opt['confidence_method'] == 'deep-oec':
        init_sample = load_episode(train_data, tr, opt['data.test_way'],
                                   opt['data.test_shot'],
                                   opt['data.test_query'], device)
        conf_model = DeepOECConfidence(None, fs_model, init_sample, opt)
    elif opt['confidence_method'] == 'dm-iso':
        encoder = fs_model.encoder
        deep_mahala_obj = DeepMahala(None,
                                     None,
                                     None,
                                     encoder,
                                     device,
                                     num_feats=encoder.depth,
                                     num_classes=train_data['n_classes'],
                                     pretrained_path="",
                                     fit=False,
                                     normalize=None)

        conf_model = DMConfidence(deep_mahala_obj, {
            'ls': range(encoder.depth),
            'reduction': 'max',
            'g_magnitude': .1
        }, True, 'iso')

    if opt['pretrained_oec_path']:
        conf_model.load_state_dict(torch.load(opt['pretrained_oec_path']))

    conf_model.to(device)
    print(conf_model)

    optimizer = optim.Adam(conf_model.confidence_parameters(),
                           lr=opt['lr'],
                           weight_decay=opt['wd'])
    scheduler = StepLR(optimizer,
                       step_size=opt['lrsche_step_size'],
                       gamma=opt['lrsche_gamma'])

    num_param = sum(p.numel() for p in conf_model.confidence_parameters())
    print(f"Learning Confidence, Number of Parameters -- {num_param}")

    if conf_model.pretrain_parameters() is not None:
        pretrain_optimizer = optim.Adam(conf_model.pretrain_parameters(),
                                        lr=10)
        pretrain_iter = 100

    start_idx = 0
    if opt['resume']:
        last_ckpt_path = os.path.join(opt['output_dir'], 'last_ckpt.pt')
        if os.path.exists(last_ckpt_path):
            try:
                last_ckpt = torch.load(last_ckpt_path)
                if 'conf_model' in last_ckpt:
                    conf_model = last_ckpt['conf_model']
                else:
                    sd = last_ckpt['conf_model_sd']
                    conf_model.load_state_dict(sd)
                optimizer = last_ckpt['optimizer']
                pretrain_optimizer = last_ckpt['pretrain_optimizer']
                scheduler = last_ckpt['scheduler']
                start_idx = last_ckpt['outer_idx']
                conf_model.to(device)
            except EOFError:
                print(
                    "\n\nResuming but got EOF error, starting from init..\n\n")

    wandb.run.name = opt['exp_name']
    wandb.run.save()
    wandb.watch(conf_model)

    # Eval and Logging
    confs = {
        opt['confidence_method']: conf_model,
    }
    if opt['confidence_method'] == 'oec':
        confs['ed'] = FSCConfidence(fs_model, 'ed')
    elif opt['confidence_method'] == 'deep-oec':
        encoder = fs_model.encoder
        deep_mahala_obj = DeepMahala(None,
                                     None,
                                     None,
                                     encoder,
                                     device,
                                     num_feats=encoder.depth,
                                     num_classes=train_data['n_classes'],
                                     pretrained_path="",
                                     fit=False,
                                     normalize=None)
        confs['dm'] = DMConfidence(deep_mahala_obj, {
            'ls': range(encoder.depth),
            'reduction': 'max',
            'g_magnitude': 0
        }, True, 'iso').to(device)
    # Temporal Ensemble for Evaluation
    if opt['n_ensemble'] > 1:
        nets = [deepcopy(conf_model) for _ in range(opt['n_ensemble'])]
        confs['mixture-' + opt['confidence_method']] = Ensemble(
            nets, 'mixture')
        confs['poe-' + opt['confidence_method']] = Ensemble(nets, 'poe')
        ensemble_update_interval = opt['eval_every_outer'] // opt['n_ensemble']
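        # Ensemble members are temporal snapshots of conf_model, refreshed
        # round-robin once every ensemble_update_interval iterations (see the
        # update block in the training loop below).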

    iteration_fieldnames = ['global_iteration']
    for c in confs:
        iteration_fieldnames += [
            f'{c}_train_ooe', f'{c}_val_ooe', f'{c}_test_ooe', f'{c}_ood'
        ]
    iteration_logger = CSVLogger(every=0,
                                 fieldnames=iteration_fieldnames,
                                 filename=os.path.join(opt['output_dir'],
                                                       'iteration_log.csv'))

    best_val_ooe = 0
    PATIENCE = 5  # Number of evaluations to wait
    waited = 0

    progress_bar = tqdm(range(start_idx, opt['train_iter']))
    for outer_idx in progress_bar:
        sample = load_episode(train_data, tr, opt['data.test_way'],
                              opt['data.test_shot'], opt['data.test_query'],
                              device)

        conf_model.train()
        if opt['full_supervision']:  # sanity check
            conf_model.support(sample['xs'])
            in_score = conf_model.score(sample['xq'], detach=False).squeeze()
            out_score = conf_model.score(sample['ooc_xq'],
                                         detach=False).squeeze()
            out_scores = [out_score]
            for curr_ood, ood_tensor in ood_tensors:
                if curr_ood == 'ooe':
                    continue
                start = outer_idx % (len(ood_tensor) // 2)
                stop = min(
                    start + sample['xq'].shape[0] * sample['xq'].shape[0],
                    len(ood_tensor) // 2)
                oxq = torch.stack([tr(x)
                                   for x in ood_tensor[start:stop]]).to(device)
                o = conf_model.score(oxq, detach=False).squeeze()
                out_scores.append(o)
            out_score = torch.cat(out_scores)
            in_score = in_score.repeat(len(ood_tensors))
            loss, acc = compute_loss_bce(in_score,
                                         out_score,
                                         mean_center=False)
        else:
            conf_model.support(sample['xs'])
            if opt['interpolate']:
                half_n_way = sample['xq'].shape[0] // 2
                interp = .5 * (sample['xq'][:half_n_way] +
                               sample['xq'][half_n_way:2 * half_n_way])
                sample['ooc_xq'][:half_n_way] = interp

            if opt['input_regularization'] == 'oe':
                # Reshape ooc_xq
                nw, nq, c, h, w = sample['ooc_xq'].shape
                sample['ooc_xq'] = sample['ooc_xq'].view(1, nw * nq, c, h, w)
                oe_bs = int(nw * nq * opt['input_regularization_percent'])

                start = (outer_idx * oe_bs) % len(reg_data)
                end = np.min([start + oe_bs, len(reg_data)])
                oe_batch = torch.stack([tr(x) for x in reg_data[start:end]
                                        ]).to(device)
                oe_batch = oe_batch.unsqueeze(0)
                sample['ooc_xq'][:, :oe_batch.shape[1]] = oe_batch

            if opt['in_out_1_batch']:
                inps = torch.cat([sample['xq'], sample['ooc_xq']], 1)
                scores = conf_model.score(inps, detach=False).squeeze()
                n_in = sample['xq'].shape[1]
                in_score, out_score = scores[:n_in], scores[n_in:]
            else:
                in_score = conf_model.score(sample['xq'],
                                            detach=False).squeeze()
                out_score = conf_model.score(sample['ooc_xq'],
                                             detach=False).squeeze()

            loss, acc = compute_loss_bce(in_score,
                                         out_score,
                                         mean_center=False)
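
        # `compute_loss_bce` is assumed to label in-distribution scores 1 and
        # OOD scores 0 and return (loss, accuracy); a sketch of that contract
        # (hypothetical; assumes `torch.nn.functional as F`):
        def _compute_loss_bce_sketch(in_score, out_score, mean_center=False):
            scores = torch.cat([in_score, out_score])
            targets = torch.cat([torch.ones_like(in_score),
                                 torch.zeros_like(out_score)])
            loss = F.binary_cross_entropy_with_logits(scores, targets)
            acc = ((scores > 0).float() == targets).float().mean().item()
            return loss, acc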

        if (conf_model.pretrain_parameters() is not None
                and outer_idx < pretrain_iter):
            pretrain_optimizer.zero_grad()
            loss.backward()
            pretrain_optimizer.step()
        else:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()

        progress_bar.set_postfix(loss='{:.3e}'.format(loss.item()),
                                 acc='{:.3e}'.format(acc))

        # Update Ensemble
        if opt['n_ensemble'] > 1 and outer_idx % ensemble_update_interval == 0:
            update_ind = (outer_idx //
                          ensemble_update_interval) % opt['n_ensemble']
            if opt['db']:
                print(f"===> Updating Ensemble: {update_ind}")
            confs['mixture-' + opt['confidence_method']].nets[
                update_ind] = deepcopy(conf_model)
            confs['poe-' + opt['confidence_method']].nets[
                update_ind] = deepcopy(conf_model)

        # AUROC eval
        if outer_idx % opt['eval_every_outer'] == 0:
            if not opt['eval_in_train']:
                conf_model.eval()

            # Eval..
            stats_dict = {'global_iteration': outer_idx}
            for conf_name, conf in confs.items():
                conf.eval()
                # OOE eval
                ooe_aurocs = {}
                for split, in_data in [('train', train_data),
                                       ('val', val_data), ('test', test_data)]:
                    auroc = np.mean(
                        eval_ood_aurocs(
                            None,
                            in_data,
                            tr,
                            opt['data.test_way'],
                            opt['data.test_shot'],
                            opt['data.test_query'],
                            opt['data.test_episodes'],
                            device,
                            conf,
                            no_grad=not opt['confidence_method'].startswith(
                                'dm'))['aurocs'])
                    ooe_aurocs[split] = auroc
                    print_str = '{}, iter: {} ({}), auroc: {:.3e}'.format(
                        conf_name, outer_idx, split, ooe_aurocs[split])
                    _print_and_log(print_str, trace_file)
                stats_dict[f'{conf_name}_train_ooe'] = ooe_aurocs['train']
                stats_dict[f'{conf_name}_val_ooe'] = ooe_aurocs['val']
                stats_dict[f'{conf_name}_test_ooe'] = ooe_aurocs['test']

                # OOD eval
                if not opt['ooe_only']:
                    aurocs = []
                    for curr_ood, ood_tensor in ood_tensors:
                        auroc = np.mean(
                            eval_ood_aurocs(
                                ood_tensor,
                                test_data,
                                tr,
                                opt['data.test_way'],
                                opt['data.test_shot'],
                                opt['data.test_query'],
                                opt['data.test_episodes'],
                                device,
                                conf,
                                no_grad=not opt['confidence_method'].startswith(
                                    'dm'))['aurocs'])
                        aurocs.append(auroc)

                        print_str = '{}, iter: {} ({}), auroc: {:.3e}'.format(
                            conf_name, outer_idx, curr_ood, auroc)
                        _print_and_log(print_str, trace_file)

                    mean_ood_auroc = np.mean(aurocs)
                    print_str = '{}, iter: {} (OOD_mean), auroc: {:.3e}'.format(
                        conf_name, outer_idx, mean_ood_auroc)
                    _print_and_log(print_str, trace_file)

                    stats_dict[f'{conf_name}_ood'] = mean_ood_auroc

            iteration_logger.writerow(stats_dict)
            plot_csv(iteration_logger.filename, iteration_logger.filename)
            wandb.log(stats_dict)

            cur_val_ooe = stats_dict[f'{opt["confidence_method"]}_val_ooe']
            if cur_val_ooe > best_val_ooe:
                best_val_ooe = cur_val_ooe
                conf_model.cpu()
                torch.save(
                    conf_model.state_dict(),
                    os.path.join(opt['output_dir'],
                                 opt['exp_name'] + '_conf_best.pt'))
                conf_model.to(device)
                # Ckpt ensemble
                if opt['n_ensemble'] > 1:
                    ensemble = confs['mixture-' + opt['confidence_method']]
                    ensemble.cpu()
                    torch.save(
                        ensemble.state_dict(),
                        os.path.join(opt['output_dir'],
                                     opt['exp_name'] + '_ensemble_best.pt'))
                    ensemble.to(device)
                waited = 0
            else:
                waited += 1
                if waited >= PATIENCE:
                    print("PATIENCE exceeded...exiting")
                    sys.exit()
            # For `resume`
            conf_model.cpu()
            torch.save(
                {
                    'conf_model_sd': conf_model.state_dict(),
                    'optimizer': optimizer,
                    'pretrain_optimizer': pretrain_optimizer
                    if conf_model.pretrain_parameters() is not None else None,
                    'scheduler': scheduler,
                    'outer_idx': outer_idx,
                }, os.path.join(opt['output_dir'], 'last_ckpt.pt'))
            conf_model.to(device)
            conf_model.train()
    sys.exit()