# Example #1
def main(args):
    """Evaluate the smart-coordinator FPAN on the chosen testset.

    For every seed, the two pre-trained experts and the matching FPAN
    checkpoint are loaded, the ``smart_coordinator_fpan`` experiment is
    run, and the per-iteration records are collected (and optionally
    saved).
    """

    # Per-testset configuration: loaders, expert metadata, experiment module.
    if args.testset == "disjoint_mnist":
        test_loaders = [
            mnist_combined_test_loader(args.test_batch_size),
            mnist_combined_test_loader(args.test_batch_size),
        ]
        args.test_expert = ["first5_mnist", "last5_mnist"]
        args.test_input_channel = [1, 1]
        args.test_output_size = 5
        args.test_arch = ["lenet5", "lenet5"]
        test_arch = [LeNet5, LeNet5]
        fpan_input_channel = 1
        m = mnist
    elif args.testset == "mnist_cifar10":
        test_loaders = [
            mnist_cifar10_single_channel_test_loader(args.test_batch_size),
            mnist_cifar10_3_channel_test_loader(args.test_batch_size),
        ]
        args.test_expert = ["mnist", "cifar10"]
        args.test_input_channel = [1, 3]
        args.test_output_size = 10
        args.test_arch = ["lenet5", "resnet18"]
        test_arch = [LeNet5, ResNet18]
        fpan_input_channel = 1
        m = mnist_cifar10
    elif args.testset == "fmnist_kmnist":
        test_loaders = [
            fmnist_kmnist_test_loader(args.test_batch_size),
            fmnist_kmnist_test_loader(args.test_batch_size),
        ]
        args.test_expert = ["fmnist", "kmnist"]
        args.test_input_channel = [1, 1]
        args.test_output_size = 10
        args.test_arch = ["resnet18", "resnet18"]
        test_arch = [ResNet18, ResNet18]
        fpan_input_channel = 1
        m = fmnist_kmnist

    # The experiment entry point lives in the dataset-specific module.
    experiment = m.smart_coordinator_fpan

    # Resolve the FPAN architecture class.
    if args.fpan_arch == "resnet18":
        fpan_arch = ResNet18
    elif args.fpan_arch == "lenet5":
        fpan_arch = LeNet5
    elif args.fpan_arch == "lenet5_halfed":
        fpan_arch = LeNet5Halfed

    print(f"Testset: {args.testset}")
    print(f"FPAN arch: {args.fpan_arch}")
    print(f"UPAN Dataset: {args.upan_data}")
    print(f"UPAN type: {args.upan_type}")
    results = []

    for trial, seed in enumerate(args.seeds):
        print(f"\nIteration: {trial+1}, Seed: {seed}")

        # Fix every RNG so each trial is reproducible.
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

        # Load the pre-trained experts for this seed.
        # NOTE(review): experts use the module-level `device` while the FPAN
        # below uses args.device — confirm the two always agree.
        all_experts = []
        for name, channels, arch_name, arch_cls in zip(
                args.test_expert, args.test_input_channel, args.test_arch,
                test_arch):
            expert = arch_cls(input_channel=channels,
                              output_size=args.test_output_size).to(device)
            state = torch.load(
                args.model_dir + f"{name}_{arch_name}_{seed}",
                map_location=torch.device(device),
            )
            expert.load_state_dict(state)
            all_experts.append(expert)

        # Load the FPAN checkpoint trained for this configuration.
        fpan = fpan_arch(input_channel=fpan_input_channel,
                         output_size=len(args.test_expert)).to(args.device)
        fpan_state = torch.load(
            args.fpan_dir +
            f"fpan_{args.testset}({fpan_input_channel})_({args.upan_data}_{args.upan_type}){seed}",
            map_location=torch.device(args.device),
        )
        fpan.load_state_dict(fpan_state)

        result = experiment(args, all_experts[0], all_experts[1], fpan, device,
                            test_loaders)

        # Tag each record with its trial metadata before accumulating.
        for record in result:
            record.update({"iteration": trial, "seed": seed})
        results.extend(result)

    # Persist the accumulated results.
    if args.save_results:
        save_results(
            f"fpan_{args.testset}({fpan_input_channel})_({args.upan_data}_{args.upan_type})",
            results,
            f"{args.results_dir}",
        )
# Example #2
def train_fpan(args):
    """Train an FPAN on synthetic Gaussian inputs and save model/results.

    The FPAN learns to route inputs to the right expert. Training data here
    is random N(0.5, 0.5) noise (single- and 3-channel variants), while
    evaluation uses the real test set of the chosen dataset pair.
    """
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Fixed seed so the synthetic Gaussian datasets are identical across runs.
    np.random.seed(0)
    torch.manual_seed(0)

    transform = transforms.Compose([
        lambda x: PIL.Image.fromarray(x),
        lambda x: transforms.ToTensor()(np.array(x)),
        transforms.Normalize((0.5, ), (0.5, )),
    ])

    transform_rgb = transforms.Compose([
        lambda x: PIL.Image.fromarray(x, 'RGB'),
        lambda x: transforms.ToTensor()(np.array(x)),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def _gaussian_loader(shape, tf):
        # Build 8000 synthetic N(0.5, 0.5) samples (PIL range = [0, 1]).
        # The "label" is just the sample index and is unused downstream.
        samples = []
        for idx in range(8000):
            x = np.random.normal(0.5, 0.5, shape)
            samples.append((tf(x), idx))
        return DataLoader(samples,
                          batch_size=args.batch_size,
                          shuffle=False)

    # Deduplicated construction of the two synthetic loaders (originally the
    # same loop was written out twice); np.random call order is preserved.
    gaussian_train_loader_noshuffle = _gaussian_loader((32, 32), transform)
    gaussian_3_channel_train_loader_noshuffle = _gaussian_loader(
        (32, 32, 3), transform_rgb)

    # Initialize arguments based on the chosen dataset
    if args.fpan_data == "disjoint_mnist":
        train_loaders = [
            gaussian_train_loader_noshuffle,
            gaussian_train_loader_noshuffle,
        ]
        test_loaders = mnist_combined_test_loader(args.test_batch_size)
        args.expert = ["first5_mnist", "last5_mnist"]
        args.expert_input_channel = [1, 1]
        args.expert_output_size = 5
        args.expert_arch = ["lenet5", "lenet5"]
        expert_arch = [LeNet5, LeNet5]
        target_create_fns = craft_disjoint_mnist_target
        fpan_output_size = 2
        fpan_input_channel = 1
    elif args.fpan_data == "mnist_cifar10":
        train_loaders = [
            gaussian_train_loader_noshuffle,
            gaussian_3_channel_train_loader_noshuffle,
        ]
        test_loaders = mnist_cifar10_single_channel_test_loader(
            args.test_batch_size)
        args.expert = ["mnist", "cifar10"]
        args.expert_input_channel = [1, 3]
        args.expert_output_size = 10
        args.expert_arch = ["lenet5", "resnet18"]
        expert_arch = [LeNet5, ResNet18]
        target_create_fns = craft_mnist_cifar10_target
        fpan_output_size = 2
        fpan_input_channel = 1
    elif args.fpan_data == "fmnist_kmnist":
        train_loaders = [
            gaussian_train_loader_noshuffle,
            gaussian_train_loader_noshuffle,
        ]
        test_loaders = fmnist_kmnist_test_loader(args.test_batch_size)
        args.expert = ["fmnist", "kmnist"]
        args.expert_input_channel = [1, 1]
        args.expert_output_size = 10
        args.expert_arch = ["resnet18", "resnet18"]
        expert_arch = [ResNet18, ResNet18]
        target_create_fns = craft_fmnist_kmnist_target
        fpan_output_size = 2
        fpan_input_channel = 1

    # Resolve the FPAN architecture class.
    if args.fpan_arch == "resnet18":
        fpan_arch = ResNet18
    elif args.fpan_arch == "lenet5":
        fpan_arch = LeNet5
    elif args.fpan_arch == "lenet5_halfed":
        fpan_arch = LeNet5Halfed

    # Initialize arguments based on the dataset that UPAN was trained on
    if args.upan_data == "disjoint_mnist":
        args.model_arch = ["lenet5", "lenet5"]
        args.model_output_size = 5
    elif args.upan_data == "mnist_cifar10":
        args.model_arch = ["lenet5", "resnet18"]
        args.model_output_size = 10
    elif args.upan_data == "fmnist_kmnist":
        args.model_arch = ["resnet18", "resnet18"]
        args.model_output_size = 10

    # Create the directory for saving if it does not exist
    create_op_dir(args.output_dir)

    print(f"\nFPAN Dataset: {args.fpan_data}")
    print(f"FPAN arch: {args.fpan_arch}")
    print(f"UPAN Dataset: {args.upan_data}")
    print(f"UPAN type: {args.upan_type}\n")

    fpan_results = []
    for i in range(len(args.seeds)):
        print(f"Iteration {i+1}, Seed {args.seeds[i]}")

        # Re-seed every RNG so each trial is reproducible.
        np.random.seed(args.seeds[i])
        torch.manual_seed(args.seeds[i])
        torch.cuda.manual_seed_all(args.seeds[i])
        torch.backends.cudnn.deterministic = True

        # Train FPAN model
        fpan, fpan_test_loss, fpan_acc = train_model(
            fpan=fpan_arch(input_channel=fpan_input_channel,
                           output_size=fpan_output_size).to(device),
            trial=i,
            device=device,
            expert_arch=expert_arch,
            train_loader=train_loaders,
            test_loader=test_loaders,
            target_create_fn=target_create_fns,
            config_args=args,
        )

        # Save the FPAN model
        if args.save_models:
            torch.save(
                fpan.state_dict(),
                args.output_dir +
                f"fpan_{args.fpan_data}({fpan_input_channel})_({args.upan_data}_{args.upan_type}){args.seeds[i]}",
            )

        # Save the results in list first
        fpan_results.append({
            "iteration": i,
            "seed": args.seeds[i],
            "loss": fpan_test_loss,
            "acc": fpan_acc,
        })

    # Save all the results
    if args.save_results:
        save_results(
            f"fpan_{args.fpan_data}({fpan_input_channel})_({args.upan_data}_{args.upan_type})",
            fpan_results,
            args.results_dir,
        )
# Example #3
def train_pan(args):
    """Train one PAN per expert of the chosen dataset pair and save them.

    Fixes over the original:
    - the iteration banner prints ``i+1`` (1-based), consistent with the
      other training entry points in this file;
    - the duplicated load-expert/train-PAN sequence is factored into a
      local helper.
    """

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    device = torch.device("cuda" if use_cuda else "cpu")

    # Initialize arguments based on dataset chosen
    if args.dataset == "disjoint_mnist":
        train_loaders = [
            mnist_combined_train_loader(args.batch_size),
            mnist_combined_train_loader(args.batch_size),
        ]
        test_loaders = [
            mnist_combined_test_loader(args.test_batch_size),
            mnist_combined_test_loader(args.test_batch_size),
        ]
        args.d1 = "first5_mnist"
        args.d2 = "last5_mnist"
        args.m1_input_channel = 1
        args.m2_input_channel = 1
        args.output_size = 5
        target_create_fns = [craft_first_5_target, craft_last_5_target]
    elif args.dataset == "mnist_cifar10":
        train_loaders = [
            mnist_cifar10_single_channel_train_loader(args.batch_size),
            mnist_cifar10_3_channel_train_loader(args.batch_size),
        ]
        test_loaders = [
            mnist_cifar10_single_channel_test_loader(args.test_batch_size),
            mnist_cifar10_3_channel_test_loader(args.test_batch_size),
        ]
        args.d1 = "mnist"
        args.d2 = "cifar10"
        args.m1_input_channel = 1
        args.m2_input_channel = 3
        args.output_size = 10
        target_create_fns = [craft_mnist_target, craft_cifar10_target]

    # Initialize models based on architecture chosen
    if args.arch == "lenet5":
        arch = LeNet5
        feature_size = 120
    elif args.arch == "lenet5_halfed":
        arch = LeNet5Halfed
        feature_size = 60
    elif args.arch == "resnet18":
        arch = ResNet18
        feature_size = 512

    # Initialize PAN based on its type
    if args.pan_type == "feature":
        pan_input_size = feature_size
        pan_arch = PAN
    elif args.pan_type == "logits":
        pan_input_size = args.output_size
        pan_arch = PAN
    elif args.pan_type == "agnostic_feature":
        pan_input_size = 3
        pan_arch = AgnosticPAN
    elif args.pan_type == "agnostic_logits":
        pan_input_size = 3
        pan_arch = AgnosticPAN

    def _load_and_train(expert_name, input_channel, train_loader, test_loader,
                        target_create_fn, seed):
        # Load the frozen, pre-trained expert, then train a fresh PAN on it.
        model = arch(
            input_channel=input_channel, output_size=args.output_size
        ).to(device)
        model.load_state_dict(
            torch.load(
                args.model_dir + f"{expert_name}_{args.arch}_{seed}",
                map_location=torch.device("cpu"),
            )
        )
        return train_model(
            pan=pan_arch(input_size=pan_input_size).to(device),
            model=model,
            device=device,
            train_loader=train_loader,
            test_loader=test_loader,
            target_create_fn=target_create_fn,
            config_args=args,
        )

    # Create the directory for saving if it does not exist
    create_op_dir(args.output_dir)

    print(f"Dataset: {args.dataset}")
    print(f"Model: {args.arch}")
    print(f"PAN type: {args.pan_type}")
    pan1_results = []
    pan2_results = []

    for i in range(len(args.seeds)):
        # 1-based banner for consistency with train_fpan/train_upan.
        print(f"Iteration {i+1}, Seed {args.seeds[i]}")

        np.random.seed(args.seeds[i])
        torch.manual_seed(args.seeds[i])

        # Train one PAN per expert (same order as the original code).
        pan1, pan1_test_loss, pan1_acc = _load_and_train(
            args.d1, args.m1_input_channel, train_loaders[0], test_loaders[0],
            target_create_fns[0], args.seeds[i])
        pan2, pan2_test_loss, pan2_acc = _load_and_train(
            args.d2, args.m2_input_channel, train_loaders[1], test_loaders[1],
            target_create_fns[1], args.seeds[i])

        # Save the pan model
        torch.save(
            pan1.state_dict(),
            args.output_dir
            + f"pan_{args.pan_type}_{args.dataset}({args.d1})_{args.arch}_{args.seeds[i]}",
        )
        torch.save(
            pan2.state_dict(),
            args.output_dir
            + f"pan_{args.pan_type}_{args.dataset}({args.d2})_{args.arch}_{args.seeds[i]}",
        )

        # save the results in list first
        pan1_results.append(
            {
                "iteration": i,
                "seed": args.seeds[i],
                "loss": pan1_test_loss,
                "acc": pan1_acc,
            }
        )
        pan2_results.append(
            {
                "iteration": i,
                "seed": args.seeds[i],
                "loss": pan2_test_loss,
                "acc": pan2_acc,
            }
        )

    # Save all the results
    if args.save_results:
        save_results(
            f"pan_{args.pan_type}_{args.dataset}({args.d1})_{args.arch}",
            pan1_results,
            args.results_dir,
        )
        save_results(
            f"pan_{args.pan_type}_{args.dataset}({args.d2})_{args.arch}",
            pan2_results,
            args.results_dir,
        )
# Example #4
def main(args):
    """Run a coordination experiment between two pre-trained experts.

    Depending on ``args.experiment`` this gathers logits statistics, runs a
    multi-pass augmentation variant, or evaluates the smart coordinator
    (which additionally loads one PAN per expert).

    Bug fix: the experiment calls previously passed an undefined name
    ``device``; this function never defines it, and the models are placed on
    ``args.device``, so ``args.device`` is now passed instead.
    """
    # Initialize arguments based on dataset chosen
    if args.dataset == "disjoint_mnist":
        test_loader = mnist_combined_test_loader(args.test_batch_size)
        args.d1 = "first5_mnist"
        args.d2 = "last5_mnist"
        args.m1_input_channel = 1
        args.m2_input_channel = 1
        args.output_size = 5
        m = mnist
    elif args.dataset == "mnist_cifar10":
        test_loader = [
            dual_channel_mnist_test_loader(args.test_batch_size),
            dual_channel_cifar10_test_loader(args.test_batch_size),
        ]
        args.d1 = "mnist"
        args.d2 = "cifar10"
        args.m1_input_channel = 1
        args.m2_input_channel = 3
        args.output_size = 10
        m = mnist_cifar10

    # Initialize models based on architecture chosen
    if args.arch == "lenet5":
        arch = LeNet5
        args.feature_size = 120
    elif args.arch == "lenet5_halfed":
        arch = LeNet5Halfed
        args.feature_size = 60
    elif args.arch == "resnet18":
        arch = ResNet18
        args.feature_size = 512

    # Select the experiment function from the dataset-specific module.
    if args.experiment == "logits_statistics":
        experiment = m.logits_statistics
    elif args.experiment == "multi_pass_aug_mean":
        experiment = m.multi_pass_aug_mean
    elif args.experiment == "multi_pass_aug_voting":
        experiment = m.multi_pass_aug_voting
    elif args.experiment == "smart_coord":
        experiment = m.smart_coordinator

    # Pan settings
    if args.pan_type == "feature":
        pan_input_size = args.feature_size
        pan_arch = PAN
    elif args.pan_type == "logits":
        pan_input_size = args.output_size
        pan_arch = PAN
    elif args.pan_type == "agnostic_feature":
        pan_input_size = 3
        pan_arch = AgnosticPAN
    elif args.pan_type == "agnostic_logits":
        pan_input_size = 3
        pan_arch = AgnosticPAN

    # Running the test
    print(f"Dataset: {args.dataset}")
    print(f"Model: {args.arch}")
    results = []

    for i in range(len(args.seeds)):
        seed = args.seeds[i]
        np.random.seed(seed)
        torch.manual_seed(seed)
        print(f"\nIteration: {i+1}, Seed: {seed}")

        # Load the two pre-trained experts for this seed.
        model1 = arch(input_channel=args.m1_input_channel,
                      output_size=args.output_size).to(args.device)
        model1.load_state_dict(
            torch.load(
                args.output_dir + f"{args.d1}_{args.arch}_{args.seeds[i]}",
                map_location=torch.device("cpu"),
            ))
        model2 = arch(input_channel=args.m2_input_channel,
                      output_size=args.output_size).to(args.device)
        model2.load_state_dict(
            torch.load(
                args.output_dir + f"{args.d2}_{args.arch}_{args.seeds[i]}",
                map_location=torch.device("cpu"),
            ))

        # Running the experiment; smart_coord additionally needs both PANs.
        if args.experiment == "smart_coord":
            pan1 = pan_arch(input_size=pan_input_size).to(args.device)
            pan1.load_state_dict(
                torch.load(
                    args.pan_dir +
                    f"pan_{args.pan_type}_{args.dataset}({args.d1})_{args.arch}_{args.seeds[i]}",
                    map_location=torch.device("cpu"),
                ))
            pan2 = pan_arch(input_size=pan_input_size).to(args.device)
            pan2.load_state_dict(
                torch.load(
                    args.pan_dir +
                    f"pan_{args.pan_type}_{args.dataset}({args.d2})_{args.arch}_{args.seeds[i]}",
                    map_location=torch.device("cpu"),
                ))
            # was: undefined `device`
            result = experiment(args, model1, model2, pan1, pan2, args.device,
                                test_loader)
        else:
            # was: undefined `device`
            result = experiment(args, model1, model2, args.device, test_loader)

        # Adding more info to the result to be saved
        for r in result:
            r.update({"iteration": i, "seed": args.seeds[i]})
        results.extend(result)

    # Save the results
    if args.save_results and args.experiment == "smart_coord":
        save_results(
            f"{args.dataset}_{args.arch}_{args.pan_type}",
            results,
            f"{args.results_dir}{args.experiment}/",
        )
    elif args.save_results:
        save_results(
            f"{args.dataset}_{args.arch}",
            results,
            f"{args.results_dir}{args.experiment}/",
        )
def train_fpan(args):
    """Train an FPAN on the experts' (unshuffled) real training data.

    For every seed the FPAN is trained from scratch, its checkpoint is
    written to ``args.output_dir``, and the loss/accuracy are recorded
    (optionally persisted via ``save_results``).
    """
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Dataset-pair specific setup: loaders, expert metadata, target crafting.
    if args.fpan_data == "disjoint_mnist":
        train_loaders = [
            mnist_combined_train_loader_noshuffle(args.batch_size),
            mnist_combined_train_loader_noshuffle(args.batch_size),
        ]
        test_loaders = mnist_combined_test_loader(args.test_batch_size)
        args.expert = ["first5_mnist", "last5_mnist"]
        args.expert_input_channel = [1, 1]
        args.expert_output_size = 5
        args.expert_arch = ["lenet5", "lenet5"]
        expert_arch = [LeNet5, LeNet5]
        target_create_fns = craft_disjoint_mnist_target
        fpan_output_size = 2
        fpan_input_channel = 1
    elif args.fpan_data == "mnist_cifar10":
        train_loaders = [
            mnist_cifar10_single_channel_train_loader_noshuffle(
                args.batch_size),
            mnist_cifar10_3_channel_train_loader_noshuffle(args.batch_size),
        ]
        test_loaders = mnist_cifar10_single_channel_test_loader(
            args.test_batch_size)
        args.expert = ["mnist", "cifar10"]
        args.expert_input_channel = [1, 3]
        args.expert_output_size = 10
        args.expert_arch = ["lenet5", "resnet18"]
        expert_arch = [LeNet5, ResNet18]
        target_create_fns = craft_mnist_cifar10_target
        fpan_output_size = 2
        fpan_input_channel = 1
    elif args.fpan_data == "fmnist_kmnist":
        train_loaders = [
            fmnist_kmnist_train_loader_noshuffle(args.batch_size),
            fmnist_kmnist_train_loader_noshuffle(args.batch_size),
        ]
        test_loaders = fmnist_kmnist_test_loader(args.test_batch_size)
        args.expert = ["fmnist", "kmnist"]
        args.expert_input_channel = [1, 1]
        args.expert_output_size = 10
        args.expert_arch = ["resnet18", "resnet18"]
        expert_arch = [ResNet18, ResNet18]
        target_create_fns = craft_fmnist_kmnist_target
        fpan_output_size = 2
        fpan_input_channel = 1

    # FPAN backbone.
    if args.fpan_arch == "resnet18":
        fpan_arch = ResNet18
    elif args.fpan_arch == "lenet5":
        fpan_arch = LeNet5
    elif args.fpan_arch == "lenet5_halfed":
        fpan_arch = LeNet5Halfed

    # Metadata of the experts the UPAN was originally trained against.
    if args.upan_data == "disjoint_mnist":
        args.model_arch = ["lenet5", "lenet5"]
        args.model_output_size = 5
    elif args.upan_data == "mnist_cifar10":
        args.model_arch = ["lenet5", "resnet18"]
        args.model_output_size = 10
    elif args.upan_data == "fmnist_kmnist":
        args.model_arch = ["resnet18", "resnet18"]
        args.model_output_size = 10

    # Make sure the checkpoint directory exists.
    create_op_dir(args.output_dir)

    print(f"\nFPAN Dataset: {args.fpan_data}")
    print(f"FPAN arch: {args.fpan_arch}")
    print(f"UPAN Dataset: {args.upan_data}")
    print(f"UPAN type: {args.upan_type}\n")

    fpan_results = []
    for trial, seed in enumerate(args.seeds):
        print(f"Iteration {trial+1}, Seed {seed}")

        # Reproducibility for this trial.
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

        # Train a fresh FPAN for this seed.
        fpan, fpan_test_loss, fpan_acc = train_model(
            fpan=fpan_arch(input_channel=fpan_input_channel,
                           output_size=fpan_output_size).to(device),
            trial=trial,
            device=device,
            expert_arch=expert_arch,
            train_loader=train_loaders,
            test_loader=test_loaders,
            target_create_fn=target_create_fns,
            config_args=args,
        )

        # Persist the trained FPAN checkpoint.
        checkpoint_path = (
            args.output_dir +
            f"fpan_{args.fpan_data}({fpan_input_channel})_({args.upan_data}_{args.upan_type}){seed}"
        )
        torch.save(fpan.state_dict(), checkpoint_path)

        # Record this trial's metrics.
        fpan_results.append({
            "iteration": trial,
            "seed": seed,
            "loss": fpan_test_loss,
            "acc": fpan_acc,
        })

    # Persist all metrics at once.
    if args.save_results:
        save_results(
            f"fpan_{args.fpan_data}({fpan_input_channel})_({args.upan_data}_{args.upan_type})",
            fpan_results,
            args.results_dir,
        )
def train_upan(args):
    """Train a UPAN on one dataset pair and evaluate it on another.

    The UPAN is trained against the experts of ``args.dataset`` and tested
    against the (possibly different) experts of ``args.testset``; each
    seed's checkpoint and metrics are saved.
    """
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Training-side configuration: loaders, expert metadata, target crafting.
    if args.dataset == "disjoint_mnist":
        train_loaders = [
            mnist_combined_train_loader(args.batch_size),
            mnist_combined_train_loader(args.batch_size),
        ]
        args.train_expert = ["first5_mnist", "last5_mnist"]
        args.train_input_channel = [1, 1]
        args.train_output_size = 5
        args.train_arch = ["lenet5", "lenet5"]
        train_arch = [LeNet5, LeNet5]
        target_create_fns = [craft_first_5_target, craft_last_5_target]
    elif args.dataset == "mnist_cifar10":
        train_loaders = [
            mnist_cifar10_single_channel_train_loader(args.batch_size),
            mnist_cifar10_3_channel_train_loader(args.batch_size),
        ]
        args.train_expert = ["mnist", "cifar10"]
        args.train_input_channel = [1, 3]
        args.train_output_size = 10
        args.train_arch = ["lenet5", "resnet18"]
        train_arch = [LeNet5, ResNet18]
        target_create_fns = [craft_mnist_target, craft_cifar10_target]
    elif args.dataset == "fmnist_kmnist":
        train_loaders = [
            fmnist_kmnist_train_loader(args.batch_size),
            fmnist_kmnist_train_loader(args.batch_size),
        ]
        args.train_expert = ["fmnist", "kmnist"]
        args.train_input_channel = [1, 1]
        args.train_output_size = 10
        args.train_arch = ["resnet18", "resnet18"]
        train_arch = [ResNet18, ResNet18]
        target_create_fns = [craft_fmnist_target, craft_kmnist_target]

    # Test-side configuration, mirroring the training-side structure.
    if args.testset == "disjoint_mnist":
        test_loaders = [
            mnist_combined_test_loader(args.test_batch_size),
            mnist_combined_test_loader(args.test_batch_size),
        ]
        args.test_expert = ["first5_mnist", "last5_mnist"]
        args.test_input_channel = [1, 1]
        args.test_output_size = 5
        args.test_arch = ["lenet5", "lenet5"]
        test_arch = [LeNet5, LeNet5]
        test_target_create_fns = [craft_first_5_target, craft_last_5_target]
    elif args.testset == "mnist_cifar10":
        test_loaders = [
            mnist_cifar10_single_channel_test_loader(args.test_batch_size),
            mnist_cifar10_3_channel_test_loader(args.test_batch_size),
        ]
        args.test_expert = ["mnist", "cifar10"]
        args.test_input_channel = [1, 3]
        args.test_output_size = 10
        args.test_arch = ["lenet5", "resnet18"]
        test_arch = [LeNet5, ResNet18]
        test_target_create_fns = [craft_mnist_target, craft_cifar10_target]
    elif args.testset == "fmnist_kmnist":
        test_loaders = [
            fmnist_kmnist_test_loader(args.test_batch_size),
            fmnist_kmnist_test_loader(args.test_batch_size),
        ]
        args.test_expert = ["fmnist", "kmnist"]
        args.test_input_channel = [1, 1]
        args.test_output_size = 10
        args.test_arch = ["resnet18", "resnet18"]
        test_arch = [ResNet18, ResNet18]
        test_target_create_fns = [craft_fmnist_target, craft_kmnist_target]

    # UPAN input layout depends on its type.
    if args.upan_type == "logits":
        upan_input_size = args.train_output_size  # output size of expert
        upan_arch = PAN
    elif args.upan_type == "agnostic_logits":
        upan_input_size = 5  # number of statistical functions used
        upan_arch = AgnosticPAN

    # Make sure the checkpoint directory exists.
    create_op_dir(args.output_dir)

    print(f"\nDataset: {args.dataset}")
    print(f"Testset: {args.testset}")
    print(f"UPAN type: {args.upan_type}\n")

    upan_results = []
    for trial, seed in enumerate(args.seeds):
        print(f"Iteration {trial+1}, Seed {seed}")

        # Reproducibility for this trial.
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

        # Train a fresh UPAN for this seed.
        upan, upan_test_loss, upan_acc = train_model(
            upan=upan_arch(input_size=upan_input_size).to(device),
            trial=trial,
            train_arch=train_arch,
            test_arch=test_arch,
            device=device,
            train_loader=train_loaders,
            test_loader=test_loaders,
            target_create_fn=target_create_fns,
            test_target_create_fn=test_target_create_fns,
            config_args=args,
        )

        # Persist the trained UPAN checkpoint.
        checkpoint_path = (
            args.output_dir +
            f"upan_{args.upan_type}_{args.dataset}{args.train_arch}_{seed}"
        )
        torch.save(upan.state_dict(), checkpoint_path)

        # Record this trial's metrics.
        upan_results.append({
            "iteration": trial,
            "seed": seed,
            "loss": upan_test_loss,
            "acc": upan_acc,
        })

    # Persist all metrics at once.
    if args.save_results:
        save_results(
            f"upan_{args.upan_type}_{args.dataset}{args.train_arch}_{args.testset}{args.test_arch}",
            upan_results,
            args.results_dir,
        )