# Beispiel #1 (score: 0)
def main(args):
    """Evaluate per-seed FPAN checkpoints together with their expert models.

    For every seed in ``args.seeds``, loads the two expert networks and the
    matching FPAN checkpoint for ``args.testset``, runs the
    ``smart_coordinator_fpan`` experiment on the test loaders, and collects
    the per-iteration results (optionally persisted via ``save_results``).
    """
    # Bug fix: `device` was referenced below (expert construction/loading and
    # the experiment call) without ever being defined in this function.
    # Derive it the same way the sibling train_* functions do.  The FPAN
    # itself is still moved to `args.device`, preserving the original mix;
    # if a module-level `device` global was intended instead, this local
    # definition shadows it with an equivalent value.
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Initialize arguments based on the chosen testset.  Each branch fixes
    # the expert identities, architectures, and the channel/class counts the
    # checkpoints were trained with.
    if args.testset == "disjoint_mnist":
        test_loaders = [
            mnist_combined_test_loader(args.test_batch_size),
            mnist_combined_test_loader(args.test_batch_size),
        ]
        args.test_expert = ["first5_mnist", "last5_mnist"]
        args.test_input_channel = [1, 1]
        args.test_output_size = 5
        args.test_arch = ["lenet5", "lenet5"]
        test_arch = [LeNet5, LeNet5]
        fpan_input_channel = 1
        m = mnist
    elif args.testset == "mnist_cifar10":
        test_loaders = [
            mnist_cifar10_single_channel_test_loader(args.test_batch_size),
            mnist_cifar10_3_channel_test_loader(args.test_batch_size),
        ]
        args.test_expert = ["mnist", "cifar10"]
        args.test_input_channel = [1, 3]
        args.test_output_size = 10
        args.test_arch = ["lenet5", "resnet18"]
        test_arch = [LeNet5, ResNet18]
        fpan_input_channel = 1
        m = mnist_cifar10
    elif args.testset == "fmnist_kmnist":
        test_loaders = [
            fmnist_kmnist_test_loader(args.test_batch_size),
            fmnist_kmnist_test_loader(args.test_batch_size),
        ]
        args.test_expert = ["fmnist", "kmnist"]
        args.test_input_channel = [1, 1]
        args.test_output_size = 10
        args.test_arch = ["resnet18", "resnet18"]
        test_arch = [ResNet18, ResNet18]
        fpan_input_channel = 1
        m = fmnist_kmnist

    # Experiment entry point provided by the selected dataset module
    experiment = m.smart_coordinator_fpan

    # FPAN architecture
    if args.fpan_arch == "resnet18":
        fpan_arch = ResNet18
    elif args.fpan_arch == "lenet5":
        fpan_arch = LeNet5
    elif args.fpan_arch == "lenet5_halfed":
        fpan_arch = LeNet5Halfed

    # Running the test
    print(f"Testset: {args.testset}")
    print(f"FPAN arch: {args.fpan_arch}")
    print(f"UPAN Dataset: {args.upan_data}")
    print(f"UPAN type: {args.upan_type}")
    results = []

    for i, seed in enumerate(args.seeds):
        print(f"\nIteration: {i+1}, Seed: {seed}")

        # Seed every RNG so each iteration is reproducible
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

        # Load experts (one per sub-task), restoring the per-seed weights
        all_experts = []
        for expert_idx in range(len(test_arch)):
            expert = test_arch[expert_idx](
                input_channel=args.test_input_channel[expert_idx],
                output_size=args.test_output_size).to(device)
            expert.load_state_dict(
                torch.load(
                    args.model_dir +
                    f"{args.test_expert[expert_idx]}_{args.test_arch[expert_idx]}_{seed}",
                    map_location=torch.device(device),
                ))
            all_experts.append(expert)

        # Load the FPAN checkpoint for this seed and run the experiment
        fpan = fpan_arch(input_channel=fpan_input_channel,
                         output_size=len(args.test_expert)).to(args.device)
        fpan.load_state_dict(
            torch.load(
                args.fpan_dir +
                f"fpan_{args.testset}({fpan_input_channel})_({args.upan_data}_{args.upan_type}){seed}",
                map_location=torch.device(args.device),
            ))
        result = experiment(args, all_experts[0], all_experts[1], fpan, device,
                            test_loaders)

        # Tag each result row with the iteration/seed it came from
        for r in result:
            r.update({"iteration": i, "seed": seed})
        results.extend(result)

    # Save the results
    if args.save_results:
        save_results(
            f"fpan_{args.testset}({fpan_input_channel})_({args.upan_data}_{args.upan_type})",
            results,
            args.results_dir,
        )
def train_fpan(args):
    """Train one FPAN model per seed on the chosen real dataset.

    The FPAN learns a 2-way routing decision (which expert should handle an
    input); per-sample targets are crafted by ``target_create_fns`` inside
    ``train_model``.  Each trained model is saved to ``args.output_dir`` and
    the per-seed loss/accuracy are collected (optionally persisted via
    ``save_results``).
    """
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Initialize arguments based on the chosen dataset.  Each branch fixes
    # the expert identities/architectures, the unshuffled training loaders
    # (order must be stable so targets line up), and the FPAN shape.
    if args.fpan_data == "disjoint_mnist":
        train_loaders = [
            mnist_combined_train_loader_noshuffle(args.batch_size),
            mnist_combined_train_loader_noshuffle(args.batch_size),
        ]
        test_loaders = mnist_combined_test_loader(args.test_batch_size)
        args.expert = ["first5_mnist", "last5_mnist"]
        args.expert_input_channel = [1, 1]
        args.expert_output_size = 5
        args.expert_arch = ["lenet5", "lenet5"]
        expert_arch = [LeNet5, LeNet5]
        target_create_fns = craft_disjoint_mnist_target
        fpan_output_size = 2
        fpan_input_channel = 1
    elif args.fpan_data == "mnist_cifar10":
        train_loaders = [
            mnist_cifar10_single_channel_train_loader_noshuffle(
                args.batch_size),
            mnist_cifar10_3_channel_train_loader_noshuffle(args.batch_size),
        ]
        test_loaders = mnist_cifar10_single_channel_test_loader(
            args.test_batch_size)
        args.expert = ["mnist", "cifar10"]
        args.expert_input_channel = [1, 3]
        args.expert_output_size = 10
        args.expert_arch = ["lenet5", "resnet18"]
        expert_arch = [LeNet5, ResNet18]
        target_create_fns = craft_mnist_cifar10_target
        fpan_output_size = 2
        fpan_input_channel = 1
    elif args.fpan_data == "fmnist_kmnist":
        train_loaders = [
            fmnist_kmnist_train_loader_noshuffle(args.batch_size),
            fmnist_kmnist_train_loader_noshuffle(args.batch_size),
        ]
        test_loaders = fmnist_kmnist_test_loader(args.test_batch_size)
        args.expert = ["fmnist", "kmnist"]
        args.expert_input_channel = [1, 1]
        args.expert_output_size = 10
        args.expert_arch = ["resnet18", "resnet18"]
        expert_arch = [ResNet18, ResNet18]
        target_create_fns = craft_fmnist_kmnist_target
        fpan_output_size = 2
        fpan_input_channel = 1

    if args.fpan_arch == "resnet18":
        fpan_arch = ResNet18
    elif args.fpan_arch == "lenet5":
        fpan_arch = LeNet5
    elif args.fpan_arch == "lenet5_halfed":
        fpan_arch = LeNet5Halfed

    # Initialize arguments based on the dataset that UPAN was trained on
    if args.upan_data == "disjoint_mnist":
        args.model_arch = ["lenet5", "lenet5"]
        args.model_output_size = 5
    elif args.upan_data == "mnist_cifar10":
        args.model_arch = ["lenet5", "resnet18"]
        args.model_output_size = 10
    elif args.upan_data == "fmnist_kmnist":
        args.model_arch = ["resnet18", "resnet18"]
        args.model_output_size = 10

    # Create the directory for saving if it does not exist
    create_op_dir(args.output_dir)

    print(f"\nFPAN Dataset: {args.fpan_data}")
    print(f"FPAN arch: {args.fpan_arch}")
    print(f"UPAN Dataset: {args.upan_data}")
    print(f"UPAN type: {args.upan_type}\n")

    fpan_results = []
    for i, seed in enumerate(args.seeds):
        print(f"Iteration {i+1}, Seed {seed}")

        # Seed every RNG so each trial is reproducible
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

        # Train FPAN model
        fpan, fpan_test_loss, fpan_acc = train_model(
            fpan=fpan_arch(input_channel=fpan_input_channel,
                           output_size=fpan_output_size).to(device),
            trial=i,
            device=device,
            expert_arch=expert_arch,
            train_loader=train_loaders,
            test_loader=test_loaders,
            target_create_fn=target_create_fns,
            config_args=args,
        )

        # Save the FPAN model.  Guarded by `save_models` for consistency
        # with the Gaussian-data variant of this trainer; getattr preserves
        # the old always-save behaviour when the flag is absent from args.
        if getattr(args, "save_models", True):
            torch.save(
                fpan.state_dict(),
                args.output_dir +
                f"fpan_{args.fpan_data}({fpan_input_channel})_({args.upan_data}_{args.upan_type}){seed}",
            )

        # Save the results in list first
        fpan_results.append({
            "iteration": i,
            "seed": seed,
            "loss": fpan_test_loss,
            "acc": fpan_acc,
        })

    # Save all the results
    if args.save_results:
        save_results(
            f"fpan_{args.fpan_data}({fpan_input_channel})_({args.upan_data}_{args.upan_type})",
            fpan_results,
            args.results_dir,
        )
# Beispiel #3 (score: 0)
def train_fpan(args):
    """Train one FPAN model per seed using synthetic Gaussian inputs.

    Same training loop as the real-data FPAN trainer, but the (unshuffled)
    training loaders yield random Gaussian images, so the FPAN is fitted
    purely from the responses the experts/UPAN produce on noise.  Real test
    loaders are still used for evaluation.
    """
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Fixed seed so the synthetic Gaussian datasets built below are
    # identical across runs; the per-trial seeds are applied later,
    # inside the training loop.
    np.random.seed(0)
    torch.manual_seed(0)

    # Single-channel pipeline: numpy array -> PIL image -> tensor -> normalize.
    # NOTE(review): fromarray receives a float64 array here; this assumes PIL
    # round-trips the values unchanged through mode 'F' — confirm.
    transform = transforms.Compose([
        lambda x: PIL.Image.fromarray(x),
        lambda x: transforms.ToTensor()(np.array(x)),
        transforms.Normalize((0.5, ), (0.5, )),
    ])

    # Three-channel variant of the same pipeline.
    # NOTE(review): fromarray(x, 'RGB') on a float64 array reinterprets the
    # raw bytes rather than the float values — verify this produces the
    # intended noise images.
    transform_rgb = transforms.Compose([
        lambda x: PIL.Image.fromarray(x, 'RGB'),
        lambda x: transforms.ToTensor()(np.array(x)),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # 8000 random 32x32 grayscale images; the loop index doubles as a
    # placeholder label.
    dataset = []
    for _ in range(8000):
        x = np.random.normal(0.5, 0.5, (32, 32))  # N(0.5, 0.5); NOTE(review): unbounded, not clipped to [0, 1]
        x = transform(x)
        dataset.append((x, _))
    gaussian_train_loader_noshuffle = DataLoader(dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False)

    # Same again, but 32x32x3 for the RGB pipeline.
    dataset_rgb = []
    for _ in range(8000):
        x = np.random.normal(0.5, 0.5, (32, 32, 3))
        x = transform_rgb(x)
        dataset_rgb.append((x, _))
    gaussian_3_channel_train_loader_noshuffle = DataLoader(
        dataset_rgb, batch_size=args.batch_size, shuffle=False)

    # Initialize arguments based on the chosen dataset: expert identities,
    # architectures, target-crafting function, and FPAN input/output shape.
    # Training uses the Gaussian loaders; testing uses the real testset.
    if args.fpan_data == "disjoint_mnist":
        train_loaders = [
            gaussian_train_loader_noshuffle,
            gaussian_train_loader_noshuffle,
        ]
        test_loaders = mnist_combined_test_loader(args.test_batch_size)
        args.expert = ["first5_mnist", "last5_mnist"]
        args.expert_input_channel = [1, 1]
        args.expert_output_size = 5
        args.expert_arch = ["lenet5", "lenet5"]
        expert_arch = [LeNet5, LeNet5]
        target_create_fns = craft_disjoint_mnist_target
        fpan_output_size = 2
        fpan_input_channel = 1
    elif args.fpan_data == "mnist_cifar10":
        train_loaders = [
            gaussian_train_loader_noshuffle,
            gaussian_3_channel_train_loader_noshuffle,
        ]
        test_loaders = mnist_cifar10_single_channel_test_loader(
            args.test_batch_size)
        args.expert = ["mnist", "cifar10"]
        args.expert_input_channel = [1, 3]
        args.expert_output_size = 10
        args.expert_arch = ["lenet5", "resnet18"]
        expert_arch = [LeNet5, ResNet18]
        target_create_fns = craft_mnist_cifar10_target
        fpan_output_size = 2
        fpan_input_channel = 1
    elif args.fpan_data == "fmnist_kmnist":
        train_loaders = [
            gaussian_train_loader_noshuffle,
            gaussian_train_loader_noshuffle,
        ]
        test_loaders = fmnist_kmnist_test_loader(args.test_batch_size)
        args.expert = ["fmnist", "kmnist"]
        args.expert_input_channel = [1, 1]
        args.expert_output_size = 10
        args.expert_arch = ["resnet18", "resnet18"]
        expert_arch = [ResNet18, ResNet18]
        target_create_fns = craft_fmnist_kmnist_target
        fpan_output_size = 2
        fpan_input_channel = 1

    # FPAN architecture selection
    if args.fpan_arch == "resnet18":
        fpan_arch = ResNet18
    elif args.fpan_arch == "lenet5":
        fpan_arch = LeNet5
    elif args.fpan_arch == "lenet5_halfed":
        fpan_arch = LeNet5Halfed

    # Initialize arguments based on the dataset that UPAN was trained on
    if args.upan_data == "disjoint_mnist":
        args.model_arch = ["lenet5", "lenet5"]
        args.model_output_size = 5
    elif args.upan_data == "mnist_cifar10":
        args.model_arch = ["lenet5", "resnet18"]
        args.model_output_size = 10
    elif args.upan_data == "fmnist_kmnist":
        args.model_arch = ["resnet18", "resnet18"]
        args.model_output_size = 10

    # Create the directory for saving if it does not exist
    create_op_dir(args.output_dir)

    print(f"\nFPAN Dataset: {args.fpan_data}")
    print(f"FPAN arch: {args.fpan_arch}")
    print(f"UPAN Dataset: {args.upan_data}")
    print(f"UPAN type: {args.upan_type}\n")

    fpan_results = []
    for i in range(len(args.seeds)):
        print(f"Iteration {i+1}, Seed {args.seeds[i]}")

        # Re-seed every RNG per trial for reproducibility
        np.random.seed(args.seeds[i])
        torch.manual_seed(args.seeds[i])
        torch.cuda.manual_seed_all(args.seeds[i])
        torch.backends.cudnn.deterministic = True

        # Train FPAN model
        fpan, fpan_test_loss, fpan_acc = train_model(
            fpan=fpan_arch(input_channel=fpan_input_channel,
                           output_size=fpan_output_size).to(device),
            trial=i,
            device=device,
            expert_arch=expert_arch,
            train_loader=train_loaders,
            test_loader=test_loaders,
            target_create_fn=target_create_fns,
            config_args=args,
        )

        # Save the FPAN model (only when explicitly requested)
        if args.save_models:
            torch.save(
                fpan.state_dict(),
                args.output_dir +
                f"fpan_{args.fpan_data}({fpan_input_channel})_({args.upan_data}_{args.upan_type}){args.seeds[i]}",
            )

        # Save the results in list first
        fpan_results.append({
            "iteration": i,
            "seed": args.seeds[i],
            "loss": fpan_test_loss,
            "acc": fpan_acc,
        })

    # Save all the results
    if args.save_results:
        save_results(
            f"fpan_{args.fpan_data}({fpan_input_channel})_({args.upan_data}_{args.upan_type})",
            fpan_results,
            args.results_dir,
        )
def train_upan(args):
    """Train one UPAN model per seed and record its loss/accuracy.

    Configures training/test experts from ``args.dataset``/``args.testset``,
    picks the UPAN architecture from ``args.upan_type``, trains via
    ``train_model`` for each seed, saves every checkpoint, and optionally
    persists the collected results with ``save_results``.
    """
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # --- training-set configuration --------------------------------------
    if args.dataset == "disjoint_mnist":
        args.train_expert = ["first5_mnist", "last5_mnist"]
        args.train_input_channel = [1, 1]
        args.train_output_size = 5
        args.train_arch = ["lenet5", "lenet5"]
        expert_classes = [LeNet5, LeNet5]
        train_targets = [craft_first_5_target, craft_last_5_target]
        train_loaders = [
            mnist_combined_train_loader(args.batch_size),
            mnist_combined_train_loader(args.batch_size),
        ]
    elif args.dataset == "mnist_cifar10":
        args.train_expert = ["mnist", "cifar10"]
        args.train_input_channel = [1, 3]
        args.train_output_size = 10
        args.train_arch = ["lenet5", "resnet18"]
        expert_classes = [LeNet5, ResNet18]
        train_targets = [craft_mnist_target, craft_cifar10_target]
        train_loaders = [
            mnist_cifar10_single_channel_train_loader(args.batch_size),
            mnist_cifar10_3_channel_train_loader(args.batch_size),
        ]
    elif args.dataset == "fmnist_kmnist":
        args.train_expert = ["fmnist", "kmnist"]
        args.train_input_channel = [1, 1]
        args.train_output_size = 10
        args.train_arch = ["resnet18", "resnet18"]
        expert_classes = [ResNet18, ResNet18]
        train_targets = [craft_fmnist_target, craft_kmnist_target]
        train_loaders = [
            fmnist_kmnist_train_loader(args.batch_size),
            fmnist_kmnist_train_loader(args.batch_size),
        ]

    # --- test-set configuration ------------------------------------------
    if args.testset == "disjoint_mnist":
        args.test_expert = ["first5_mnist", "last5_mnist"]
        args.test_input_channel = [1, 1]
        args.test_output_size = 5
        args.test_arch = ["lenet5", "lenet5"]
        test_classes = [LeNet5, LeNet5]
        test_targets = [craft_first_5_target, craft_last_5_target]
        test_loaders = [
            mnist_combined_test_loader(args.test_batch_size),
            mnist_combined_test_loader(args.test_batch_size),
        ]
    elif args.testset == "mnist_cifar10":
        args.test_expert = ["mnist", "cifar10"]
        args.test_input_channel = [1, 3]
        args.test_output_size = 10
        args.test_arch = ["lenet5", "resnet18"]
        test_classes = [LeNet5, ResNet18]
        test_targets = [craft_mnist_target, craft_cifar10_target]
        test_loaders = [
            mnist_cifar10_single_channel_test_loader(args.test_batch_size),
            mnist_cifar10_3_channel_test_loader(args.test_batch_size),
        ]
    elif args.testset == "fmnist_kmnist":
        args.test_expert = ["fmnist", "kmnist"]
        args.test_input_channel = [1, 1]
        args.test_output_size = 10
        args.test_arch = ["resnet18", "resnet18"]
        test_classes = [ResNet18, ResNet18]
        test_targets = [craft_fmnist_target, craft_kmnist_target]
        test_loaders = [
            fmnist_kmnist_test_loader(args.test_batch_size),
            fmnist_kmnist_test_loader(args.test_batch_size),
        ]

    # --- UPAN flavour: raw expert logits vs. logit statistics -------------
    if args.upan_type == "logits":
        upan_arch = PAN
        upan_input_size = args.train_output_size  # output size of expert
    elif args.upan_type == "agnostic_logits":
        upan_arch = AgnosticPAN
        upan_input_size = 5  # number of statistical functions used

    # Create the directory for saving if it does not exist
    create_op_dir(args.output_dir)

    print(f"\nDataset: {args.dataset}")
    print(f"Testset: {args.testset}")
    print(f"UPAN type: {args.upan_type}\n")

    upan_results = []
    for trial, seed in enumerate(args.seeds):
        print(f"Iteration {trial+1}, Seed {seed}")

        # Re-seed every RNG per trial for reproducibility
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

        # Train UPAN model
        upan, upan_test_loss, upan_acc = train_model(
            upan=upan_arch(input_size=upan_input_size).to(device),
            trial=trial,
            train_arch=expert_classes,
            test_arch=test_classes,
            device=device,
            train_loader=train_loaders,
            test_loader=test_loaders,
            target_create_fn=train_targets,
            test_target_create_fn=test_targets,
            config_args=args,
        )

        # Save the UPAN model
        torch.save(
            upan.state_dict(),
            args.output_dir +
            f"upan_{args.upan_type}_{args.dataset}{args.train_arch}_{seed}",
        )

        # Save the results in list first
        upan_results.append({
            "iteration": trial,
            "seed": seed,
            "loss": upan_test_loss,
            "acc": upan_acc,
        })

    # Save all the results
    if args.save_results:
        save_results(
            f"upan_{args.upan_type}_{args.dataset}{args.train_arch}_{args.testset}{args.test_arch}",
            upan_results,
            args.results_dir,
        )