import json
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

# Project-local helpers (parse_args, load_mnist_data, load_notmnist, MLP,
# mfij_predict, ood_metric, and the other repo utilities used below) are
# assumed importable from this repository's own modules.


def main(args=None):
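    """Evaluate MFIJ predictions on rotated MNIST (distribution shift).

    Sweeps rotation degrees, records MFIJ error/NLL/ECE per degree, saves
    the results to JSON, and plots ECE across rotations.
    """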
    args = parse_args(args)

    print(f'processing model: lr{args.lr}_d{args.lr_decay}_bs{args.batch_size}',
          'drop rate', args.drop_rate, 'seed', args.seed)

    # MFIJ-related argument; kept as a string (assumed to be parsed downstream)
    args.lambda0 = '3/(np.pi**2)'
    # Compute Hessian covariance matrix once
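    # Temporarily enlarge the batch size so feature extraction runs in a few
    # large batches, then restore the original value.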
    _bs = args.batch_size
    args.batch_size = 5000
    train_dl, _, _ = load_mnist_data(args, train_shuffle=False)
    args.batch_size = _bs
    train_pen_features = get_penultimate_features(args, train_dl)
    hess = compute_hess(train_pen_features)
    args.cov = torch.Tensor(invert_hess(hess))
    del train_dl, train_pen_features

    shift_res = dict()
    for deg in range(0, 181, 15):  # 0°, 15°, ..., 180°
        args.rotate_degs = deg
        if args.rotate_degs > 0:
            err, nll, ECE = eval_indomain(args, data_key='test_rotate')
        else:
            err, nll, ECE = eval_indomain(args, data_key='test')
        # cast to float so the per-degree metrics are JSON-serializable
        shift_res['rotate_' + str(deg)] = dict(err=float(err),
                                               nll=float(nll),
                                               ece=float(ECE))
    # save result
    model_dir = gen_model_dir(args)
    shift_fn = 'shift_Te' + str(args.temp_ens) + '_Ta' + str(args.temp_act) + '.json'
    with open(Path(model_dir, shift_fn), 'w') as fp:
        json.dump(shift_res, fp)

    # visualize ECE across rotation degrees
    fig = viz_distribution_shift(shift_res)
    Path("figs").mkdir(exist_ok=True)  # ensure the output directory exists
    fig.savefig("figs/mnist_shift_ece.eps", dpi=fig.dpi, bbox_inches='tight', format='eps')


def eval_indomain(args, data_key):
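    """Evaluate MLE and MFIJ predictors on one in-domain split.

    Prints LaTeX table rows (error %, NLL, ECE) for both predictors and
    returns the MFIJ metrics, which callers use for temperature tuning.
    """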
    model_dir = gen_model_dir(args)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # define model
    model = MLP(args.num_hiddens, drop_rate=0.0)
    model.to(device)
    model.eval()  # inference only; mirrors eval_ood()
    # load best model
    with open(Path(model_dir, "model"), 'rb') as f:
        params = torch.load(f, map_location=device)
        model.load_state_dict(params['model_weight'])

    ## load data
    if data_key == 'heldout':
        _, eval_dl, _ = load_mnist_data(args, train_shuffle=False)
    elif data_key == 'test':
        _, _, eval_dl = load_mnist_data(args, train_shuffle=False)
    elif data_key == 'test_rotate':
        eval_dl = load_mnist_shift(args, data_key)
    else:
        raise ValueError('unknown data_key: {}'.format(data_key))

    ## MLE predict
    mle_logits, labels = model_feedforward(model, eval_dl, device)
    mle_probs = F.softmax(mle_logits, dim=1).cpu().detach().numpy()
    labels = labels.cpu().detach().numpy()
    mle_preds = np.argmax(mle_probs, axis=-1)
    del mle_logits

    # get penultimate layer features
    pen_features = get_penultimate_features(args, eval_dl)
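    # num_train=55000: MNIST training examples left after the 5,000-image
    # heldout split (assumed repo convention)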
    mfij_probs = mfij_predict(
        args,
        loaders={data_key: pen_features},
        dimensions=[args.num_hiddens[-1], args.num_classes],
        num_train=55000)[data_key]
    mfij_preds = np.argmax(mfij_probs, axis=-1)

    # compute accuracy, NLL, Calibration
    mle_err = 1 - np.mean(mle_preds == labels)
    mfij_err = 1 - np.mean(mfij_preds == labels)

    mle_nll = np.mean(nll(labels, mle_probs))
    mfij_nll = np.mean(nll(labels, mfij_probs))

    # Expected Calibration Error over the reliability-diagram bins:
    # ECE = 100 * sum_b perc_b * |conf_b - acc_b|
    _, bin_confs, bin_accs, bin_percs = reliability_diagrams(
        mle_preds, labels, np.amax(mle_probs, axis=-1))
    mle_ECE = 100 * np.sum(
        np.array(bin_percs) * np.abs(np.array(bin_confs) - np.array(bin_accs)))

    _, bin_confs, bin_accs, bin_percs = reliability_diagrams(
        mfij_preds, labels, np.amax(mfij_probs, axis=-1))
    mfij_ECE = 100 * np.sum(
        np.array(bin_percs) * np.abs(np.array(bin_confs) - np.array(bin_accs)))

    ## print LaTeX table rows: error (%), NLL, ECE
    print("& mle & {:.3g} & {:.4g} & {:.4g} \\\\".format(
        mle_err * 100, mle_nll, mle_ECE))

    print("& mfij & {:.3g} & {:.4g} & {:.4g} \\\\".format(
        mfij_err * 100, mfij_nll, mfij_ECE))

    print(" ==================================== ")
    print("\n")
    return mfij_err, mfij_nll, mfij_ECE


def main(args=None):
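    """Tune MFIJ ensemble/activation temperatures on the heldout split.

    Grid-searches the two temperatures by heldout NLL, re-evaluates the best
    pair on the test split, saves results to JSON, and plots a heatmap.
    """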
    args = parse_args(args)

    print(f'processing model: lr{args.lr}_d{args.lr_decay}_bs{args.batch_size}',
          'drop rate', args.drop_rate, 'seed', args.seed)
    # MFIJ-related argument; kept as a string (assumed to be parsed downstream)
    args.lambda0 = '3/(np.pi**2)'
    # Compute Hessian covariance matrix once
    _bs = args.batch_size
    args.batch_size = 5000
    train_dl, _, _ = load_mnist_data(args, train_shuffle=False)
    args.batch_size = _bs
    train_pen_features = get_penultimate_features(args, train_dl)
    hess = compute_hess(train_pen_features)
    args.cov = torch.Tensor(invert_hess(hess))
    del train_dl, train_pen_features

    tune_res = dict()
    # find the best ensemble and activation temperatures on the heldout set;
    # ensemble temperatures are log10 exponents: temp_ens = 10 ** ens_T
    ensemble_temperature_list = np.arange(-4, 3, 1, dtype=float)
    activation_temperature_list = np.array(
        [0.001, 0.25, 0.5, 0.75, 1, 1.5, 2, 5])

    best_nll = np.inf

    for ens_T in ensemble_temperature_list:
        args.temp_ens = float(10**ens_T)
        for act_T in activation_temperature_list:
            args.temp_act = float(act_T)
            print('temperatures:', args.temp_ens, args.temp_act)
            mfij_errs, mfij_nlls, mfij_ECEs = eval_indomain(args,
                                                            data_key='heldout')
            # cast to float so the results are JSON-serializable
            tune_res[str(ens_T) + '_' + str(act_T)] = dict(err=float(mfij_errs),
                                                           nll=float(mfij_nlls),
                                                           ece=float(mfij_ECEs))
            if mfij_nlls < best_nll:
                best_nll = float(mfij_nlls)
                nll_ts = [float(ens_T), float(act_T)]
    tune_res['best_nll_temp'] = nll_ts
    tune_res['best_nll_heldout'] = best_nll

    print('eval best temperature on test', nll_ts)
    args.temp_ens, args.temp_act = nll_ts
    args.temp_ens = float(10**args.temp_ens)
    test_errs, test_nlls, test_ECE = eval_indomain(args, data_key='test')
    tune_res['best_on_test'] = dict(err=float(test_errs),
                                    nll=float(test_nlls),
                                    ece=float(test_ECE))
    # save result
    model_dir = gen_model_dir(args)
    tune_fn = 'in_domain_mfij.json'
    with open(Path(model_dir, tune_fn), 'w') as fp:
        json.dump(tune_res, fp)

    # visualize temperature heatmap
    fig = viz_temperature_heatmap(
        tune_res,
        ensemble_temperature_list=list(ensemble_temperature_list),
        activation_temperature_list=activation_temperature_list,
        ensemble_temperature_label=[
            '1e-4', '1e-3', '1e-2', '1e-1', '1.', '1e1', '1e2'
        ],
        task='in_domain')
    fig.savefig("figs/mnist_mf_nll_temp.eps",
                dpi=fig.dpi,
                bbox_inches='tight',
                format='eps')


def eval_ood(args, data_key):
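    """Compare softmax (MLE) and MFIJ OOD detection against NotMNIST.

    Scores samples by their maximum predicted probability and reports
    DTACC/AUROC/AUIN/AUOUT/TNR for both predictors.
    """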
    model_dir = gen_model_dir(args)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # define model
    model = MLP(args.num_hiddens, drop_rate=0.0)
    model.to(device)
    model.eval()
    # load best model
    with open(Path(model_dir, "model"), 'rb') as f:
        params = torch.load(f, map_location=device)
        model.load_state_dict(params['model_weight'])

    ## OOD data comes from the NotMNIST dataset
    assert args.notset == 'notmnist'

    if data_key == 'test':
        _, _, eval_dl = load_mnist_data(args, train_shuffle=False)
        print('load test data')
        notmnist_dl = load_notmnist(args, heldout=False)
    elif data_key == 'heldout':
        _, eval_dl, _ = load_mnist_data(args, train_shuffle=False)
        print('load heldout data for OOD eval')
        notmnist_dl = load_notmnist(args, heldout=True)
    else:
        raise ValueError('unknown data_key: {}'.format(data_key))

    ## MLE predict
    in_mle_logits, _ = model_feedforward(model, eval_dl, device)
    in_mle_probs = F.softmax(in_mle_logits, dim=1).cpu().detach().numpy()

    out_mle_logits, _ = model_feedforward(model, notmnist_dl, device)
    out_mle_probs = F.softmax(out_mle_logits, dim=1).cpu().detach().numpy()
    del in_mle_logits, out_mle_logits

    # get penultimate layer features
    in_pen_features = get_penultimate_features(args, eval_dl)
    out_pen_features = get_penultimate_features(args, notmnist_dl)

    # num_train=55000: MNIST training examples left after the 5,000-image
    # heldout split; block_size presumably chunks prediction to bound memory
    # (both assumed repo conventions)
    probs_dict = mfij_predict(
        args,
        loaders={
            data_key: in_pen_features,
            args.notset: out_pen_features
        },
        dimensions=[args.num_hiddens[-1], args.num_classes],
        num_train=55000,
        block_size=500)

    in_inf_probs = probs_dict[data_key]
    out_inf_probs = probs_dict[args.notset]

    # baseline: max softmax probability (MSP) as the OOD score
    in_mle_stats = {'prob': np.amax(in_mle_probs, axis=1)}
    out_mle_stats = {'prob': np.amax(out_mle_probs, axis=1)}
    res_mle = ood_metric(in_mle_stats,
                         out_mle_stats,
                         stypes=['prob'],
                         verbose=False)['prob']

    # mfij OOD detection
    in_inf_stats = {'prob': np.amax(in_inf_probs, axis=1)}
    out_inf_stats = {'prob': np.amax(out_inf_probs, axis=1)}
    res_inf = ood_metric(in_inf_stats,
                         out_inf_stats,
                         stypes=['prob'],
                         verbose=False)['prob']

    print("& mle & {:6.3f} & {:6.3f} & {:6.3f}/{:6.3f} & {:6.3f} \\\\".format(
        res_mle['DTACC'] * 100,
        res_mle['AUROC'] * 100,
        res_mle['AUIN'] * 100,
        res_mle['AUOUT'] * 100,
        res_mle['TNR'] * 100,
    ))
    print(" ==================================== ")
    print("& mfij & {:6.3f} & {:6.3f} & {:6.3f}/{:6.3f} & {:6.3f} \\\\".format(
        res_inf['DTACC'] * 100,
        res_inf['AUROC'] * 100,
        res_inf['AUIN'] * 100,
        res_inf['AUOUT'] * 100,
        res_inf['TNR'] * 100,
    ))
    print(" ==================================== ")
    print("\n")
    return in_mle_probs, out_mle_probs, in_inf_probs, out_inf_probs, res_inf


def main(args=None):
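    """Tune MFIJ temperatures for NotMNIST OOD detection.

    Grid-searches the two temperatures by heldout AUROC, re-evaluates the
    best pair on the test split, saves results to JSON, and plots a heatmap.
    """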
    args = parse_args(args)

    print(f'processing model: lr{args.lr}_d{args.lr_decay}_bs{args.batch_size}',
          'drop rate', args.drop_rate, 'seed', args.seed)

    # MFIJ-related argument; kept as a string (assumed to be parsed downstream)
    args.lambda0 = '3/(np.pi**2)'
    # Compute Hessian covariance matrix once
    _bs = args.batch_size
    args.batch_size = 5000
    train_dl, _, _ = load_mnist_data(args, train_shuffle=False)
    args.batch_size = _bs
    train_pen_features = get_penultimate_features(args, train_dl)
    hess = compute_hess(train_pen_features)
    args.cov = torch.Tensor(invert_hess(hess))
    del train_dl, train_pen_features

    notmnist_res = dict()
    best_auc = 0
    # ensemble temperatures are log10 exponents: temp_ens = 10 ** ens_T
    ensemble_temperature_list = np.arange(-4, 3, 1, dtype=float)
    activation_temperature_list = np.array(
        [0.001, 0.25, 0.5, 0.75, 1, 1.5, 2, 5])

    for ens_T in ensemble_temperature_list:
        args.temp_ens = float(10**ens_T)
        for act_T in activation_temperature_list:
            args.temp_act = float(act_T)

            print('temperatures:', args.temp_ens, args.temp_act)
            _, _, _, _, res_mfij = eval_ood(args, data_key='heldout')
            # cast to float so the results are JSON-serializable
            notmnist_res[str(ens_T) + '_' + str(act_T)] = {
                k: float(v) for k, v in res_mfij.items()}
            if res_mfij['AUROC'] > best_auc:
                best_auc = float(res_mfij['AUROC'])
                best_ts = [float(ens_T), float(act_T)]

    notmnist_res['best_auc'] = best_auc
    notmnist_res['best_ts'] = best_ts
    print('eval best temperature on test', best_ts)
    args.temp_ens = float(10**best_ts[0])
    args.temp_act = best_ts[1]
    _, _, _, _, test_mfij = eval_ood(args, data_key='test')
    notmnist_res['best_test_ood'] = {k: float(v) for k, v in test_mfij.items()}

    # save result
    model_dir = gen_model_dir(args)
    tune_fn = args.notset + '_mfij.json'
    with open(Path(model_dir, tune_fn), 'w') as fp:
        json.dump(notmnist_res, fp)

    # visualize temperature heatmap
    fig = viz_temperature_heatmap(
        notmnist_res,
        ensemble_temperature_list=list(ensemble_temperature_list),
        activation_temperature_list=activation_temperature_list,
        ensemble_temperature_label=[
            '1e-4', '1e-3', '1e-2', '1e-1', '1.', '1e1', '1e2'
        ],
        task='ood')
    fig.savefig("figs/mnist_mf_auc_temp.eps",
                dpi=fig.dpi,
                bbox_inches='tight',
                format='eps')


def train(args):
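    """Train the MLP on MNIST, checkpointing on best heldout error.

    Uses Adam with exponential learning-rate decay, saves the best
    checkpoint, and writes final heldout/test metrics to res.json.
    """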
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load data
    train_dl, valid_dl, test_dl = load_mnist_data(args)

    # define model
    if args.model_str == 'mlp':
        model = MLP(args.num_hiddens, drop_rate=args.drop_rate)
    else:
        raise ValueError('unknown model_str: {}'.format(args.model_str))

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    # exponential learning-rate decay, stepped once per epoch
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                          gamma=args.lr_decay)

    model_dir = gen_model_dir(args)
    model_dir.mkdir(parents=True, exist_ok=True)

    loss_fn = torch.nn.CrossEntropyLoss()
    min_err = 1
    for epoch in range(args.n_epochs):
        model.train()
        loss_train = 0
        for data, target in train_dl:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()

            loss_train += loss.item()
        loss_train /= len(train_dl)
        lr_scheduler.step()

        # eval heldout
        model.eval()
        with torch.no_grad():
            err_heldout = 1 - eval_acc(valid_dl, model, device=device)

        print(
            'Train Epoch: {}, Train Loss: {:.6f}, Heldout Err: {:.6f}'.format(
                epoch, loss_train, err_heldout))

        if err_heldout < min_err:
            min_err = err_heldout
            with torch.no_grad():  # no gradient graph needed for eval loss
                loss_heldout = sum(
                    loss_fn(model(xb.to(device)), yb.to(device))
                    for xb, yb in valid_dl).item() / len(valid_dl)
            # save model
            with open(Path(model_dir, "model"), 'wb') as f:
                torch.save(
                    {
                        'model_weight': model.state_dict(),
                        'epoch': epoch,
                        'loss_train': loss_train,
                        'loss_heldout': loss_heldout,
                        'err_heldout': err_heldout,
                    },
                    f,
                )
            print(
                'New best! epoch: {}, learning rate: {:.4g}, train loss: {:.4f}, val err: {:.2f}.'
                .format(epoch,
                        lr_scheduler.get_last_lr()[0], loss_train,
                        err_heldout * 100))

    # load best model
    with open(Path(model_dir, "model"), 'rb') as f:
        params = torch.load(f, map_location=device)
        model.load_state_dict(params['model_weight'])
    # test
    model.eval()
    err_test = 1 - eval_acc(test_dl, model, device=device)
    print('epoch: {}, val error: {:.4f}, test error: {:.4f}'.format(
        params["epoch"], params["err_heldout"] * 100, err_test * 100))

    with open(Path(model_dir, "res.json"), 'w') as fp:
        json.dump(
            {
                'epoch': params["epoch"],
                'loss_train': params["loss_train"],
                'loss_heldout': params["loss_heldout"],
                'err_heldout': params["err_heldout"],
                'err_test': err_test,
            }, fp)