# --- Deterministic model (baseNN) -------------------------------------------
# Load the test split, restore the trained baseNN and evaluate it on the
# adversarial examples previously saved for it.
# NOTE(review): `m` must hold the baseNN settings here; it is set outside
# this chunk.
# load_dataset(...)[2:] keeps the last four returned values
# (presumably dropping the training split -- confirm).
x_test, y_test, inp_shape, num_classes = load_dataset(
    dataset_name=m["dataset"], shuffle=False, n_inputs=n_inputs)[2:]

det_model_savedir = get_model_savedir(model="baseNN",
                                      dataset=m["dataset"],
                                      architecture=m["architecture"],
                                      debug=args.debug,
                                      model_idx=args.model_idx)

# Every settings value is forwarded positionally, so the insertion order of
# `m` must match baseNN's constructor parameters after (inp_shape, num_classes).
detnet = baseNN(inp_shape, num_classes, *list(m.values()))
detnet.load(savedir=det_model_savedir, device=args.device)

det_attacks = load_attack(method=args.attack_method,
                          model_savedir=det_model_savedir)

(det_predictions, det_atk_predictions, det_softmax_robustness,
 det_successful_idxs, det_failed_idxs) = evaluate_attack(
    net=detnet,
    x_test=x_test,
    x_attack=det_attacks,
    y_test=y_test,
    device=args.device,
    return_classification_idxs=True)

# --- Bayesian model (fullBNN) with the same index ----------------------------
# The dataset is reloaded so that inp_shape / out_size describe the BNN's own
# dataset. Note that this rebinds x_test and y_test.
m = fullBNN_settings["model_" + str(args.model_idx)]
x_test, y_test, inp_shape, out_size = load_dataset(dataset_name=m["dataset"],
                                                   shuffle=False,
                                                   n_inputs=n_inputs)[2:]

bay_model_savedir = get_model_savedir(model="fullBNN",
                                      dataset=m["dataset"],
                                      architecture=m["architecture"],
                                      model_idx=args.model_idx,
                                      debug=args.debug)

# The first settings value (the dataset name) is passed explicitly, so it is
# skipped in the positional forwarding below.
# Use the class count of the BNN's own dataset (out_size). The baseNN's
# num_classes only coincides with it when both settings share a dataset;
# before this fix out_size was loaded but never used.
bayesnet = BNN(m["dataset"], *list(m.values())[1:], inp_shape, out_size)
bayesnet.load(savedir=bay_model_savedir, device=args.device)
                               model_savedir=savedir)
# NOTE(review): the line above is the tail of a call whose head is not in
# this chunk (its keyword matches the load_attack(...) calls elsewhere). This
# region appears to be spliced from another script -- confirm against the
# complete file.

    # Otherwise craft new adversarial examples against `net` with the chosen
    # method and persist them (together with x_test) under `savedir`.
    # NOTE(review): this `else:` pairs with an `if` outside this chunk.
    else:
        x_attack = attack(net=net,
                          x_test=x_test,
                          y_test=y_test,
                          device=args.device,
                          method=args.attack_method)
        save_attack(x_test,
                    x_attack,
                    method=args.attack_method,
                    model_savedir=savedir)

    # Evaluate `net` on the test inputs and on their adversarial counterparts
    # (evaluate_attack is defined elsewhere; its return value is discarded here).
    evaluate_attack(net=net,
                    x_test=x_test,
                    x_attack=x_attack,
                    y_test=y_test,
                    device=args.device)

# NOTE(review): this top-level `else:` pairs with an `if` outside this chunk.
else:

    if args.model == "fullBNN":

        # Hyper-parameter settings of the selected fullBNN model.
        m = fullBNN_settings["model_" + str(args.model_idx)]

        # NOTE(review): unlike the other load_dataset calls in this file,
        # shuffle=False is not passed here. Confirm the default ordering is
        # intended, since saved attacks are compared to x_test element-wise.
        x_test, y_test, inp_shape, out_size = load_dataset(
            dataset_name=m["dataset"], n_inputs=n_inputs)[2:]

        # NOTE(review): this call is not terminated within this chunk; the
        # next visible line belongs to a different fragment.
        savedir = get_model_savedir(model=args.model,
                                    dataset=m["dataset"],
                                    architecture=m["architecture"],
            for im_idx in range(det_lrp.shape[0]):
                # Overwrite each image's deterministic heatmaps (clean and
                # attacked) with their `normalize`d version (normalize is
                # defined elsewhere).
                det_lrp[im_idx] = normalize(det_lrp[im_idx])
                det_attack_lrp[im_idx] = normalize(det_attack_lrp[im_idx])

                # Same for the Bayesian heatmaps, once per entry of
                # n_samples_list (bay_lrp is indexed [samples_setting][image]).
                for samp_idx in range(len(n_samples_list)):
                    bay_lrp[samp_idx][im_idx] = normalize(
                        bay_lrp[samp_idx][im_idx])
                    bay_attack_lrp[samp_idx][im_idx] = normalize(
                        bay_attack_lrp[samp_idx][im_idx])

        ### Evaluate explanations

        # Classification results of the deterministic net on `images` and
        # `det_attack`; with return_classification_idxs=True the call also
        # returns the successful / failed index sets used below.
        det_preds, det_atk_preds, det_softmax_robustness, det_successful_idxs, det_failed_idxs = evaluate_attack(
            net=detnet,
            x_test=images,
            x_attack=det_attack,
            y_test=y_test,
            device=args.device,
            return_classification_idxs=True)
        # Move the result to a CPU NumPy array (presumably a torch tensor).
        det_softmax_robustness = det_softmax_robustness.detach().cpu().numpy()

        # Heatmap robustness over all images, using the top-k pixel criterion.
        det_lrp_robustness, det_lrp_pxl_idxs = lrp_robustness(
            original_heatmaps=det_lrp,
            adversarial_heatmaps=det_attack_lrp,
            topk=topk,
            method=lrp_robustness_method)

        # Same measure restricted to the images in det_successful_idxs.
        # NOTE(review): this call looks truncated by a splice. Unlike the call
        # above, it lacks method=lrp_robustness_method, and the next line's
        # arguments (dataset_name, shuffle, n_inputs) plus the `[2:]` slice
        # belong to a load_dataset(...) call, not to lrp_robustness.
        succ_det_lrp_robustness, succ_det_lrp_pxl_idxs = lrp_robustness(
            original_heatmaps=det_lrp[det_successful_idxs],
            adversarial_heatmaps=det_attack_lrp[det_successful_idxs],
            topk=topk,
        dataset_name=m["dataset"], shuffle=False, n_inputs=n_inputs)[2:]
    # Body of the baseNN branch; its `if` header and the assignment of `m`
    # are outside this chunk.
    # Directory where the trained baseNN for this dataset/architecture/index
    # is stored.
    model_savedir = get_model_savedir(model="baseNN",
                                      dataset=m["dataset"],
                                      architecture=m["architecture"],
                                      debug=args.debug,
                                      model_idx=args.model_idx)
    # Every settings value is forwarded positionally after (inp_shape,
    # num_classes), so the order of `m` must match baseNN's constructor.
    detnet = baseNN(inp_shape, num_classes, *list(m.values()))
    detnet.load(savedir=model_savedir, device=args.device)

    # Adversarial examples previously saved for this attack method.
    attacks = load_attack(method=args.attack_method,
                          model_savedir=model_savedir)

    # Evaluate the net on x_test vs. the loaded attacks. The two index sets
    # are returned because return_classification_idxs=True.
    predictions, atk_predictions, softmax_robustness, successful_idxs, failed_idxs = evaluate_attack(
        net=detnet,
        x_test=x_test,
        x_attack=attacks,
        y_test=y_test,
        device=args.device,
        return_classification_idxs=True)

    # Exposed for code after this branch (not visible in this chunk).
    learnable_layers_idxs = detnet.learnable_layers_idxs

elif args.model == "fullBNN":

    # Hyper-parameter settings of the selected fullBNN model.
    m = fullBNN_settings["model_" + str(args.model_idx)]
    # Unshuffled test split for the model's dataset; [2:] keeps the last four
    # returned values (presumably dropping the training split -- confirm).
    x_test, y_test, inp_shape, num_classes = load_dataset(
        dataset_name=m["dataset"], shuffle=False, n_inputs=n_inputs)[2:]

    # Directory of the saved model (args.model is "fullBNN" in this branch).
    # NOTE(review): this call and the branch continue past the end of this chunk.
    model_savedir = get_model_savedir(model=args.model,
                                      dataset=m["dataset"],
                                      architecture=m["architecture"],