Example #1
def load_omniglot_experiment_datasets(context: ReptileExperimentContext):
    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'
    omniglot_train_datasets, omniglot_test_datasets = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=context.num_clients_train,
        num_clients_test=context.num_clients_test,
        num_classes_per_client=context.num_classes_per_client,
        num_shots_per_class=context.num_shots_per_class,
        inner_batch_size=context.inner_batch_size)
    return omniglot_train_datasets, omniglot_test_datasets
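A short usage sketch (assuming an already-constructed ReptileExperimentContext; see Example #2 for the full field list). Judging by the later examples, the returned federated dataset bundles expose per-client dataloader dicts:

train_datasets, test_datasets = load_omniglot_experiment_datasets(context)
# The dicts are keyed by client id, consistently across train/test loaders:
first_client = next(iter(train_datasets.train_data_local_dict))
support_loader = train_datasets.train_data_local_dict[first_client]
query_loader = train_datasets.test_data_local_dict[first_client]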
Example #2
def run_reptile_experiment(
    name,
    dataset,
    swap_labels,
    classes,
    shots,
    seed,
    model_class,
    sgd,
    adam_betas,
    num_clients_train,
    num_clients_test,
    meta_batch_size,
    num_meta_steps,
    meta_learning_rate_initial,
    meta_learning_rate_final,
    eval_interval,
    num_eval_clients_training,
    do_final_evaluation,
    num_eval_clients_final,
    inner_batch_size,
    inner_learning_rate,
    num_inner_steps,
    num_inner_steps_eval,
    mean=None,
    std=None
):
    fix_random_seeds(seed)
    fed_dataset_test = None
    if dataset == 'femnist':
        fed_dataset_train = load_femnist_dataset(
            data_dir=str((REPO_ROOT / 'data').absolute()),
            num_clients=num_clients_train,
            batch_size=inner_batch_size,
            random_seed=seed
        )
    elif dataset == 'omniglot':
        fed_dataset_train, fed_dataset_test = load_omniglot_datasets(
            data_dir=str((REPO_ROOT / 'data' / 'omniglot').absolute()),
            num_clients_train=num_clients_train,
            num_clients_test=num_clients_test,
            num_classes_per_client=classes,
            num_shots_per_class=shots,
            inner_batch_size=inner_batch_size,
            random_seed=seed
        )
    elif dataset == 'ham10k':
        fed_dataset_train = load_ham10k_federated(partitions=num_clients_train,
                                                  batch_size=inner_batch_size,
                                                  mean=mean, std=std)
    else:
        raise ValueError(f'dataset "{dataset}" unknown')

    #data_distribution_logged = False
    for lr in inner_learning_rate:
        for _is in num_inner_steps:
            reptile_context = ReptileExperimentContext(
                name=name,
                dataset_name=dataset,
                swap_labels=swap_labels,
                num_classes_per_client=classes,
                num_shots_per_class=shots,
                seed=seed,
                model_class=model_class,
                sgd=sgd,
                adam_betas=adam_betas,
                num_clients_train=num_clients_train,
                num_clients_test=num_clients_test,
                meta_batch_size=meta_batch_size,
                num_meta_steps=num_meta_steps,
                meta_learning_rate_initial=meta_learning_rate_initial,
                meta_learning_rate_final=meta_learning_rate_final,
                eval_interval=eval_interval,
                num_eval_clients_training=num_eval_clients_training,
                do_final_evaluation=do_final_evaluation,
                num_eval_clients_final=num_eval_clients_final,
                inner_batch_size=inner_batch_size,
                inner_learning_rate=lr,
                num_inner_steps=_is,
                num_inner_steps_eval=_is
            )

            experiment_specification = f'{reptile_context}'
            experiment_logger = create_tensorboard_logger(
                reptile_context.name, experiment_specification
            )
            reptile_context.experiment_logger = experiment_logger

            log_after_round_evaluation_fns = [
                partial(log_after_round_evaluation, experiment_logger)
            ]
            run_reptile(
                context=reptile_context,
                dataset_train=fed_dataset_train,
                dataset_test=fed_dataset_test,
                initial_model_state=None,
                after_round_evaluation=log_after_round_evaluation_fns
            )
Example #3
def run_reptile(context: ExperimentContext, initial_model_state=None):

    num_clients_train = 10000
    num_clients_test = 1000
    num_classes_per_client = 5
    num_shots_per_class = 5

    eval_iters = 10

    reptile_args = ReptileTrainingArgs(model=OmniglotLightning,
                                       inner_optimizer=optim.Adam,
                                       inner_learning_rate=0.001,
                                       num_inner_steps=5,
                                       num_inner_steps_eval=50,
                                       log_every_n_steps=3,
                                       inner_batch_size=10,
                                       meta_batch_size=5,
                                       meta_learning_rate_initial=1,
                                       meta_learning_rate_final=0,
                                       num_meta_steps=3000)
    experiment_logger = create_tensorboard_logger(
        context.name, "dataloading_ours;models_ours")

    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'

    #######
    tf.disable_eager_execution()
    #######

    omniglot_train_clients, omniglot_test_clients = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=num_clients_train,
        num_clients_test=num_clients_test,
        num_classes_per_client=num_classes_per_client,
        num_shots_per_class=num_shots_per_class,
        inner_batch_size=reptile_args.inner_batch_size,
        random_seed=RANDOM_SEED)

    # Prepare ModelArgs for task training
    inner_optimizer_args = OptimizerArgs(
        optimizer_class=reptile_args.inner_optimizer,
        lr=reptile_args.inner_learning_rate,
        betas=(0, 0.999))
    inner_model_args = ModelArgs(model_class=reptile_args.model,
                                 optimizer_args=inner_optimizer_args,
                                 num_classes=num_classes_per_client)
    dummy_optimizer_args = OptimizerArgs(optimizer_class=optim.SGD)
    meta_model_args = ModelArgs(reptile_args.model,
                                dummy_optimizer_args,
                                num_classes=num_classes_per_client)
    """
    # Set up clients
    # Since we are doing meta-learning, we need separate sets of training and
    # test clients
    train_clients = []
    for c in omniglot_train_clients.train_data_local_dict.keys():
        client = ReptileClient(
            client_id=str(c),
            model_args=inner_model_args,
            context=context,
            train_dataloader=omniglot_train_clients.train_data_local_dict[c],
            num_train_samples=omniglot_train_clients.data_local_train_num_dict[c],
            test_dataloader=omniglot_train_clients.test_data_local_dict[c],
            num_test_samples=omniglot_train_clients.data_local_test_num_dict[c],
            lightning_logger=experiment_logger
        )
        checkpoint_callback = ModelCheckpoint(
            filepath=str(client.get_checkpoint_path(suffix='cb').absolute()))
        client.set_trainer_callbacks([checkpoint_callback])
        train_clients.append(client)
    test_clients = []
    for c in omniglot_test_clients.train_data_local_dict.keys():
        client = ReptileClient(
            client_id=str(c),
            model_args=inner_model_args,
            context=context,
            train_dataloader=omniglot_test_clients.train_data_local_dict[c],
            num_train_samples=omniglot_test_clients.data_local_train_num_dict[c],
            test_dataloader=omniglot_test_clients.test_data_local_dict[c],
            num_test_samples=omniglot_test_clients.data_local_test_num_dict[c],
            lightning_logger=experiment_logger
        )
        test_clients.append(client)

    # Set up server
    server = ReptileServer(
        participant_name='initial_server',
        model_args=meta_model_args,
        context=context,
        initial_model_state=initial_model_state
    )"""

    #torch_model = OmniglotModel(num_classes=num_classes_per_client)
    torch_model = OmniglotLightning(participant_name='global_model',
                                    **inner_model_args.kwargs)
    #torch_optimizer = inner_optimizer_args.optimizer_class(
    #    torch_model.parameters(),
    #    **inner_optimizer_args.optimizer_kwargs
    #)

    reptile = Reptile(global_model=torch_model,
                      model_kwargs=inner_model_args.kwargs,
                      inner_iterations=reptile_args.num_inner_steps,
                      inner_iterations_eval=reptile_args.num_inner_steps_eval)

    for i in range(reptile_args.num_meta_steps):
        frac_done = i / reptile_args.num_meta_steps
        cur_meta_step_size = frac_done * reptile_args.meta_learning_rate_final + (
            1 - frac_done) * reptile_args.meta_learning_rate_initial

        num_train = len(omniglot_train_clients.train_data_local_dict)
        meta_batch = {
            k: omniglot_train_clients.train_data_local_dict[k]
            for k in cyclerange(
                i * reptile_args.meta_batch_size % num_train,
                (i + 1) * reptile_args.meta_batch_size % num_train,
                num_train)
        }

        reptile.train_step(meta_batch=meta_batch,
                           meta_step_size=cur_meta_step_size)

        if i % eval_iters == 0:
            accuracies = []
            k = RANDOM.randrange(
                len(omniglot_train_clients.train_data_local_dict))
            train_train = omniglot_train_clients.train_data_local_dict[k]
            train_test = omniglot_train_clients.test_data_local_dict[k]
            k = RANDOM.randrange(
                len(omniglot_test_clients.train_data_local_dict))
            test_train = omniglot_test_clients.train_data_local_dict[k]
            test_test = omniglot_test_clients.test_data_local_dict[k]

            for train_dl, test_dl in [(train_train, train_test),
                                      (test_train, test_test)]:
                correct = reptile.evaluate(train_dl, test_dl)
                accuracies.append(correct / num_classes_per_client)
            print('batch %d: train=%f test=%f' %
                  (i, accuracies[0], accuracies[1]))

            # Write to TensorBoard
            experiment_logger.experiment.add_scalar(
                'train-test/acc/{}/mean'.format('global_model'),
                accuracies[0],
                global_step=i)
            experiment_logger.experiment.add_scalar(
                'test-test/acc/{}/mean'.format('global_model'),
                accuracies[1],
                global_step=i)
    """
Example #4
def create_client_batches(clients: List[ReptileClient],
                          batch_size: int) -> List[List[ReptileClient]]:
    if batch_size == -1:
        client_batches = [clients]
    else:
        client_batches = [
            clients[i:i + batch_size]
            for i in range(0, len(clients), batch_size)
        ]
    return client_batches

    num_clients_train = 10000
    num_clients_test = 1000
    num_classes_per_client = 5
    num_shots_per_class = 5

    eval_iters = 10

    reptile_args = ReptileTrainingArgs(model=OmniglotModel,
                                       inner_optimizer=optim.Adam,
                                       inner_learning_rate=0.001,
                                       num_inner_steps=5,
                                       num_inner_steps_eval=50,
                                       log_every_n_steps=3,
                                       inner_batch_size=10,
                                       meta_batch_size=5,
                                       meta_learning_rate_initial=1,
                                       meta_learning_rate_final=0,
                                       num_meta_steps=3000)
    experiment_logger = create_tensorboard_logger(
        context.name, "dataloading_ours;models_ours")

    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'

    #######
    tf.disable_eager_execution()
    #######

    omniglot_train_clients, omniglot_test_clients = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=num_clients_train,
        num_clients_test=num_clients_test,
        num_classes_per_client=num_classes_per_client,
        num_shots_per_class=num_shots_per_class,
        inner_batch_size=reptile_args.inner_batch_size,
        random_seed=RANDOM_SEED)

    # Prepare ModelArgs for task training
    inner_optimizer_args = OptimizerArgs(
        optimizer_class=reptile_args.inner_optimizer,
        lr=reptile_args.inner_learning_rate,
        betas=(0, 0.999))
    inner_model_args = ModelArgs(reptile_args.model,
                                 inner_optimizer_args,
                                 num_classes=num_classes_per_client)
    dummy_optimizer_args = OptimizerArgs(optimizer_class=optim.SGD)
    meta_model_args = ModelArgs(reptile_args.model,
                                dummy_optimizer_args,
                                num_classes=num_classes_per_client)
    """
    # Set up clients
    # Since we are doing meta-learning, we need separate sets of training and
    # test clients
    train_clients = initialize_reptile_clients(context, train_datasets)
    test_clients = initialize_reptile_clients(context, test_datasets)

    # Set up server
    server = ReptileServer(
        participant_name='initial_server',
        model_args=context.meta_model_args,
        context=context,
        initial_model_state=initial_model_state
    )"""

    torch_model = OmniglotModel(num_classes=num_classes_per_client)
    torch_optimizer = inner_optimizer_args.optimizer_class(
        torch_model.parameters(), **inner_optimizer_args.optimizer_kwargs)

    reptile = Reptile(model=torch_model,
                      optimizer=torch_optimizer,
                      inner_iterations=reptile_args.num_inner_steps,
                      inner_iterations_eval=reptile_args.num_inner_steps_eval)

    for i in range(reptile_args.num_meta_steps):
        frac_done = i / reptile_args.num_meta_steps
        cur_meta_step_size = frac_done * reptile_args.meta_learning_rate_final + (
            1 - frac_done) * reptile_args.meta_learning_rate_initial

        num_train = len(omniglot_train_clients.train_data_local_dict)
        meta_batch = {
            k: omniglot_train_clients.train_data_local_dict[k]
            for k in cyclerange(
                i * reptile_args.meta_batch_size % num_train,
                (i + 1) * reptile_args.meta_batch_size % num_train,
                num_train)
        }

        reptile.train_step(meta_batch=meta_batch,
                           meta_step_size=cur_meta_step_size)

        if i % eval_iters == 0:
            accuracies = []
            k = RANDOM.randrange(
                len(omniglot_train_clients.train_data_local_dict))
            train_train = omniglot_train_clients.train_data_local_dict[k]
            train_test = omniglot_train_clients.test_data_local_dict[k]
            k = RANDOM.randrange(
                len(omniglot_test_clients.train_data_local_dict))
            test_train = omniglot_test_clients.train_data_local_dict[k]
            test_test = omniglot_test_clients.test_data_local_dict[k]

            for train_dl, test_dl in [(train_train, train_test),
                                      (test_train, test_test)]:
                correct = reptile.evaluate(train_dl, test_dl)
                accuracies.append(correct / num_classes_per_client)
            print('batch %d: train=%f test=%f' %
                  (i, accuracies[0], accuracies[1]))

            # Write to TensorBoard
            experiment_logger.experiment.add_scalar(
                'train-test/acc/{}/mean'.format('global_model'),
                accuracies[0],
                global_step=i)
            experiment_logger.experiment.add_scalar(
                'test-test/acc/{}/mean'.format('global_model'),
                accuracies[1],
                global_step=i)
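For reference, the meta-update behind the Reptile.train_step calls above is a linear interpolation of the global weights toward the averaged task-adapted weights. A minimal PyTorch sketch of that update rule; an illustration of the algorithm, not the project's actual implementation:

import copy

import torch


def reptile_meta_update(global_model, adapted_states, meta_step_size):
    """theta <- theta + eps * (mean(theta_adapted) - theta), per tensor.

    adapted_states: list of state_dicts produced by inner-loop training on
    each client of the meta-batch (the Reptile class above may differ).
    """
    new_state = copy.deepcopy(global_model.state_dict())
    for name, param in new_state.items():
        if not torch.is_floating_point(param):
            continue  # skip integer buffers, e.g. BatchNorm step counters
        mean_adapted = torch.stack([s[name] for s in adapted_states]).mean(dim=0)
        new_state[name] = param + meta_step_size * (mean_adapted - param)
    global_model.load_state_dict(new_state)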
    """
def run_reptile(context: str, initial_model_state=None):

    args = argument_parser().parse_args()
    RANDOM = random.Random(args.seed)

    # TODO: Possibly implement logic using ReptileExperimentContext
    reptile_args = ReptileTrainingArgs(
        model_class=OmniglotLightning,
        sgd=args.sgd,
        inner_learning_rate=args.learning_rate,
        num_inner_steps=args.inner_iters,
        num_inner_steps_eval=args.eval_iters,
        log_every_n_steps=3,
        meta_learning_rate_initial=args.meta_step,
        meta_learning_rate_final=args.meta_step_final,
        num_classes_per_client=args.classes
    )
    experiment_logger = create_tensorboard_logger(
        'reptile',
        (
            f"{context};seed{args.seed};"
            f"train-clients{args.train_clients};"
            f"{args.classes}-way{args.shots}-shot;"
            f"ib{args.inner_batch}ii{args.inner_iters}"
            f"ilr{str(args.learning_rate).replace('.', '')}"
            f"ms{str(args.meta_step).replace('.', '')}"
            f"mb{args.meta_batch}ei{args.eval_iters}"
            f"{'sgd' if args.sgd else 'adam'}"
        )
    )

    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'
    omniglot_train_clients, omniglot_test_clients = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=args.train_clients,
        num_clients_test=args.test_clients,
        num_classes_per_client=args.classes,
        num_shots_per_class=args.shots,
        inner_batch_size=args.inner_batch,
        random_seed=args.seed
    )

    # Set up clients
    # Since we are doing meta-learning, we need separate sets of training and
    # test clients
    train_clients = initialize_clients(omniglot_train_clients,
                                       reptile_args.get_inner_model_args(),
                                       context, experiment_logger)
    test_clients = initialize_clients(omniglot_test_clients,
                                      reptile_args.get_inner_model_args(),
                                      context, experiment_logger)

    # Set up server
    server = ReptileServer(
        participant_name='initial_server',
        model_args=reptile_args.get_meta_model_args(),
        context=context,  # TODO: Change to ReptileExperimentContext
        initial_model_state=initial_model_state
    )

    # Perform training
    for i in range(args.meta_iters):
        if args.meta_batch == -1:
            meta_batch = train_clients
        else:
            meta_batch = [
                train_clients[k] for k in cyclerange(
                    i*args.meta_batch % len(train_clients),
                    (i+1)*args.meta_batch % len(train_clients),
                    len(train_clients)
                )
            ]
        # Meta training step
        reptile_train_step(
            aggregator=server,
            participants=meta_batch,
            inner_training_args=reptile_args.get_inner_training_args(),
            meta_training_args=reptile_args.get_meta_training_args(
                frac_done=i / args.meta_iters
            )
        )

        # Evaluation on train and test clients
        if i % args.eval_interval == 0:
            # train-test set
            # Pick one train client at random and test on it
            k = RANDOM.randrange(len(train_clients))
            client = [train_clients[k]]
            reptile_train_step(
                aggregator=server,
                participants=client,
                inner_training_args=reptile_args.get_inner_training_args(eval=True),
                evaluation_mode=True
            )
            result = evaluate_local_models(participants=client)
            experiment_logger.experiment.add_scalar(
                'train-test/acc/{}/mean'.format('global_model'),
                torch.mean(result.get('test/acc')),
                global_step=i + 1
            )
            # test-test set
            # Pick one test client at random and test on it
            k = RANDOM.randrange(len(test_clients))
            client = [test_clients[k]]
            reptile_train_step(
                aggregator=server,
                participants=client,
                inner_training_args=reptile_args.get_inner_training_args(eval=True),
                evaluation_mode=True
            )
            result = evaluate_local_models(participants=client)
            experiment_logger.experiment.add_scalar(
                'test-test/acc/{}/mean'.format('global_model'),
                torch.mean(result.get('test/acc')),
                global_step=i + 1
            )
        logger.info('finished training round')

    # Final evaluation on a sample of train/test clients
    for label, client_set in zip(['Train', 'Test'], [train_clients, test_clients]):
        eval_sample = RANDOM.sample(client_set, args.eval_samples)
        reptile_train_step(
            aggregator=server,
            participants=eval_sample,
            inner_training_args=reptile_args.get_inner_training_args(eval=True),
            evaluation_mode=True
        )
        result = evaluate_local_models(participants=eval_sample)
        log_loss_and_acc('global_model', result.get('test/loss'), result.get('test/acc'),
                         experiment_logger, global_step=args.meta_iters)
        experiment_logger.experiment.add_scalar(
            f'final_{label}_acc',
            torch.mean(result.get('test/acc')),
            global_step=0
        )
        print(f"{label} accuracy: {torch.mean(result.get('test/acc'))}")
Example #6
def run_reptile(context: ExperimentContext, initial_model_state=None):

    num_clients_train = 10000
    num_clients_test = 1000
    num_classes_per_client = 5
    num_shots_per_class = 5

    eval_iters = 10

    reptile_args = ReptileTrainingArgs(model=OmniglotModel,
                                       inner_optimizer=optim.Adam,
                                       inner_learning_rate=0.001,
                                       num_inner_steps=5,
                                       num_inner_steps_eval=50,
                                       log_every_n_steps=3,
                                       inner_batch_size=10,
                                       meta_batch_size=5,
                                       meta_learning_rate_initial=1,
                                       meta_learning_rate_final=0,
                                       num_meta_steps=3000)
    experiment_logger = create_tensorboard_logger(context.name,
                                                  "framework_ours_test_1")

    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'

    omniglot_train_clients, omniglot_test_clients = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=num_clients_train,
        num_clients_test=num_clients_test,
        num_classes_per_client=num_classes_per_client,
        num_shots_per_class=num_shots_per_class,
        inner_batch_size=reptile_args.inner_batch_size,
        random_seed=RANDOM_SEED)

    # Prepare ModelArgs for task training
    inner_optimizer_args = OptimizerArgs(
        optimizer_class=reptile_args.inner_optimizer,
        lr=reptile_args.inner_learning_rate,
        betas=(0, 0.999))
    inner_model_args = ModelArgs(reptile_args.model,
                                 inner_optimizer_args,
                                 num_classes=num_classes_per_client)
    dummy_optimizer_args = OptimizerArgs(optimizer_class=optim.SGD)
    meta_model_args = ModelArgs(reptile_args.model,
                                dummy_optimizer_args,
                                num_classes=num_classes_per_client)

    ####################
    torch_model = OmniglotModel(num_classes=num_classes_per_client)
    torch_optimizer = inner_optimizer_args.optimizer_class(
        torch_model.parameters(), **inner_optimizer_args.optimizer_kwargs)
    torch_criterion = torch.nn.CrossEntropyLoss()
    ####################

    # Set up clients
    # Since we are doing meta-learning, we need separate sets of training and
    # test clients
    train_clients = []
    for c in omniglot_train_clients.train_data_local_dict.keys():
        client = ReptileClient(
            client_id=str(c),
            model_args=inner_model_args,
            context=context,
            train_dataloader=omniglot_train_clients.train_data_local_dict[c],
            num_train_samples=omniglot_train_clients.data_local_train_num_dict[c],
            test_dataloader=omniglot_train_clients.test_data_local_dict[c],
            num_test_samples=omniglot_train_clients.data_local_test_num_dict[c],
            lightning_logger=experiment_logger)
        #checkpoint_callback = ModelCheckpoint(
        #    filepath=str(client.get_checkpoint_path(suffix='cb').absolute()))
        #client.set_trainer_callbacks([checkpoint_callback])
        train_clients.append(client)
    test_clients = []
    for c in omniglot_test_clients.train_data_local_dict.keys():
        client = ReptileClient(
            client_id=str(c),
            model_args=inner_model_args,
            context=context,
            train_dataloader=omniglot_test_clients.train_data_local_dict[c],
            num_train_samples=omniglot_test_clients.data_local_train_num_dict[c],
            test_dataloader=omniglot_test_clients.test_data_local_dict[c],
            num_test_samples=omniglot_test_clients.data_local_test_num_dict[c],
            lightning_logger=experiment_logger)
        test_clients.append(client)

    # Set up server
    server = ReptileServer(model_state=copy.deepcopy(torch_model.state_dict()),
                           participant_name='initial_server',
                           model_args=meta_model_args,
                           context=context,
                           initial_model_state=initial_model_state)

    for i in range(reptile_args.num_meta_steps):

        meta_batch = [
            train_clients[k] for k in cyclerange(
                i * reptile_args.meta_batch_size %
                len(train_clients), (i + 1) * reptile_args.meta_batch_size %
                len(train_clients), len(train_clients))
        ]
        """reptile.train_step(
            meta_batch=meta_batch,
            meta_step_size=cur_meta_step_size
        )"""
        reptile_train_step(
            model=torch_model,
            optimizer=torch_optimizer,
            criterion=torch_criterion,
            aggregator=server,
            participants=meta_batch,
            inner_training_args=reptile_args.get_inner_training_args(),
            meta_training_args=reptile_args.get_meta_training_args(
                frac_done=i / reptile_args.num_meta_steps))

        if i % eval_iters == 0:
            accuracies = []
            # train-test set
            # Pick one train client at random and test on it
            client = [RANDOM.choice(train_clients)]
            reptile_train_step(
                model=torch_model,
                optimizer=torch_optimizer,
                criterion=torch_criterion,
                aggregator=server,
                participants=client,
                inner_training_args=reptile_args.get_inner_training_args(
                    eval=True),
                evaluation_mode=True)
            inputs, labels = list(client[0]._test_dataloader)[0]
            torch_model.load_state_dict(client[0].model_state)
            test_preds = torch_model(inputs).argmax(dim=1)
            num_correct = int(
                sum([pred == label
                     for pred, label in zip(test_preds, labels)]))
            ##############result = evaluate_local_models(participants=client)
            accuracies.append(num_correct / num_classes_per_client)
            # log_loss_and_acc('global_model', result.get('test/loss'), result.get('test/acc'),
            #                 experiment_logger, global_step=i+1)
            #experiment_logger.experiment.add_scalar('train-test/acc/{}/mean'.format('global_model'),
            #                                        torch.mean(result.get('test/acc')),
            #                                        global_step=i + 1)
            # test-test set
            # Pick one test client at random and test on it
            client = [RANDOM.choice(test_clients)]
            reptile_train_step(
                model=torch_model,
                optimizer=torch_optimizer,
                criterion=torch_criterion,
                aggregator=server,
                participants=client,
                inner_training_args=reptile_args.get_inner_training_args(
                    eval=True),
                evaluation_mode=True)
            inputs, labels = list(client[0]._test_dataloader)[0]
            torch_model.load_state_dict(client[0].model_state)
            test_preds = torch_model(inputs).argmax(dim=1)
            num_correct = int(
                sum([pred == label
                     for pred, label in zip(test_preds, labels)]))
            ##############result = evaluate_local_models(participants=client)
            accuracies.append(num_correct / num_classes_per_client)
            # log_loss_and_acc('global_model', result.get('test/loss'), result.get('test/acc'),
            #                 experiment_logger, global_step=i+1)

            print('batch %d: train=%f test=%f' %
                  (i, accuracies[0], accuracies[1]))
            experiment_logger.experiment.add_scalar(
                'train-test/acc/{}/mean'.format('global_model'),
                accuracies[0],
                global_step=i)
            experiment_logger.experiment.add_scalar(
                'test-test/acc/{}/mean'.format('global_model'),
                accuracies[1],
                global_step=i)
    """
Example #7
def run_reptile(context: ExperimentContext, initial_model_state=None):

    num_clients_train = 10000
    num_clients_test = 1000
    num_classes_per_client = 5
    num_shots_per_class = 5

    # Every *eval_iters* meta steps, evaluation is performed on one random
    # client in the training and test set, respectively
    eval_iters = 10

    reptile_args = ReptileTrainingArgs(model=OmniglotLightning,
                                       inner_optimizer=optim.Adam,
                                       inner_learning_rate=0.001,
                                       num_inner_steps=5,
                                       num_inner_steps_eval=50,
                                       log_every_n_steps=3,
                                       inner_batch_size=10,
                                       meta_batch_size=5,
                                       meta_learning_rate_initial=1,
                                       meta_learning_rate_final=0,
                                       num_meta_steps=3000)
    experiment_logger = create_tensorboard_logger(context.name, (
        f"nichol_reptile_our_dataloading;{num_clients_train}clients;{num_classes_per_client}"
        f"-way{num_shots_per_class}-shot;"
        f"mlr{str(reptile_args.meta_learning_rate_initial).replace('.', '')}"
        f"ilr{str(reptile_args.inner_learning_rate).replace('.', '')}"
        f"is{reptile_args.num_inner_steps}"))

    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'
    tf.disable_eager_execution()
    omniglot_train_clients, omniglot_test_clients = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=num_clients_train,
        num_clients_test=num_clients_test,
        num_classes_per_client=num_classes_per_client,
        num_shots_per_class=num_shots_per_class,
        inner_batch_size=reptile_args.inner_batch_size,
        tensorflow=True,
        random_seed=RANDOM_SEED)

    with tf.Session() as sess:
        model = OmniglotModel(num_classes_per_client,
                              learning_rate=reptile_args.inner_learning_rate)
        reptile = Reptile(sess, transductive=True, pre_step_op=weight_decay(1))
        accuracy_ph = tf.placeholder(tf.float32, shape=())
        tf.summary.scalar('accuracy', accuracy_ph)
        merged = tf.summary.merge_all()
        sess.run(tf.global_variables_initializer())
        for i in range(reptile_args.num_meta_steps):
            frac_done = i / reptile_args.num_meta_steps
            cur_meta_step_size = frac_done * reptile_args.meta_learning_rate_final + (
                1 - frac_done) * reptile_args.meta_learning_rate_initial

            num_train = len(omniglot_train_clients.train_data_local_dict)
            meta_batch = {
                k: omniglot_train_clients.train_data_local_dict[k]
                for k in cyclerange(
                    i * reptile_args.meta_batch_size % num_train,
                    (i + 1) * reptile_args.meta_batch_size % num_train,
                    num_train)
            }

            reptile.train_step(meta_batch=meta_batch,
                               input_ph=model.input_ph,
                               label_ph=model.label_ph,
                               minimize_op=model.minimize_op,
                               inner_iters=reptile_args.num_inner_steps,
                               meta_step_size=cur_meta_step_size)
            if i % eval_iters == 0:
                accuracies = []
                k = RANDOM.randrange(
                    len(omniglot_train_clients.train_data_local_dict))
                train_train = omniglot_train_clients.train_data_local_dict[k]
                train_test = omniglot_train_clients.test_data_local_dict[k]
                k = RANDOM.randrange(
                    len(omniglot_test_clients.train_data_local_dict))
                test_train = omniglot_test_clients.train_data_local_dict[k]
                test_test = omniglot_test_clients.test_data_local_dict[k]

                for train_dl, test_dl in [(train_train, train_test),
                                          (test_train, test_test)]:
                    correct = reptile.evaluate(
                        train_data_loader=train_dl,
                        test_data_loader=test_dl,
                        input_ph=model.input_ph,
                        label_ph=model.label_ph,
                        minimize_op=model.minimize_op,
                        predictions=model.predictions,
                        inner_iters=reptile_args.num_inner_steps_eval)
                    #summary = sess.run(merged, feed_dict={accuracy_ph: correct/num_classes_per_client})
                    accuracies.append(correct / num_classes_per_client)
                print('batch %d: train=%f test=%f' %
                      (i, accuracies[0], accuracies[1]))

                # Write to TensorBoard
                experiment_logger.experiment.add_scalar(
                    'train-test/acc/{}/mean'.format('global_model'),
                    accuracies[0],
                    global_step=i)
                experiment_logger.experiment.add_scalar(
                    'test-test/acc/{}/mean'.format('global_model'),
                    accuracies[1],
                    global_step=i)
Example #8
def run_reptile(context: str, initial_model_state=None):

    args = argument_parser().parse_args()
    RANDOM = random.Random(args.seed)

    experiment_logger = create_tensorboard_logger(
        'reptile', (f"{context};seed{args.seed};"
                    f"train-clients{args.train_clients};"
                    f"{args.classes}-way{args.shots}-shot;"
                    f"ib{args.inner_batch}ii{args.inner_iters}"
                    f"ilr{str(args.learning_rate).replace('.', '')}"
                    f"ms{str(args.meta_step).replace('.', '')}"
                    f"mb{args.meta_batch}ei{args.eval_iters}"
                    f"{'sgd' if args.sgd else 'adam'}"))

    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'
    tf.disable_eager_execution()
    omniglot_train_clients, omniglot_test_clients = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=args.train_clients,
        num_clients_test=args.test_clients,
        num_classes_per_client=args.classes,
        num_shots_per_class=args.shots,
        inner_batch_size=args.inner_batch,
        tensorflow=True,
        random_seed=args.seed)

    model_kwargs = {'learning_rate': args.learning_rate}
    if args.sgd:
        model_kwargs['optimizer'] = tf.train.GradientDescentOptimizer
    model = OmniglotModel(num_classes=args.classes, **model_kwargs)

    with tf.Session() as sess:
        reptile = ReptileForFederatedData(session=sess,
                                          transductive=True,
                                          pre_step_op=weight_decay(1))
        accuracy_ph = tf.placeholder(tf.float32, shape=())
        tf.summary.scalar('accuracy', accuracy_ph)
        merged = tf.summary.merge_all()
        sess.run(tf.global_variables_initializer())

        for i in range(args.meta_iters):
            frac_done = i / args.meta_iters
            cur_meta_step_size = frac_done * args.meta_step_final + (
                1 - frac_done) * args.meta_step

            crange = cyclerange(
                start=i * args.meta_batch %
                len(omniglot_train_clients.train_data_local_dict),
                stop=(i + 1) * args.meta_batch %
                len(omniglot_train_clients.train_data_local_dict),
                len=len(omniglot_train_clients.train_data_local_dict))
            #print(f"Meta-step {i}: train clients {crange[0]}-{crange[-1]}")
            meta_batch = {
                k: omniglot_train_clients.train_data_local_dict[k]
                for k in crange
            }
            reptile.train_step(meta_batch=meta_batch,
                               input_ph=model.input_ph,
                               label_ph=model.label_ph,
                               minimize_op=model.minimize_op,
                               inner_iters=args.inner_iters,
                               meta_step_size=cur_meta_step_size)
            if i % args.eval_interval == 0:
                accuracies = []
                k = RANDOM.randrange(
                    len(omniglot_train_clients.train_data_local_dict))
                train_train = omniglot_train_clients.train_data_local_dict[k]
                train_test = omniglot_train_clients.test_data_local_dict[k]
                k = RANDOM.randrange(
                    len(omniglot_test_clients.train_data_local_dict))
                test_train = omniglot_test_clients.train_data_local_dict[k]
                test_test = omniglot_test_clients.test_data_local_dict[k]

                for train_dl, test_dl in [(train_train, train_test),
                                          (test_train, test_test)]:
                    correct = reptile.evaluate(train_data_loader=train_dl,
                                               test_data_loader=test_dl,
                                               input_ph=model.input_ph,
                                               label_ph=model.label_ph,
                                               minimize_op=model.minimize_op,
                                               predictions=model.predictions,
                                               inner_iters=args.eval_iters)
                    #summary = sess.run(merged, feed_dict={accuracy_ph: correct/num_classes_per_client})
                    accuracies.append(correct / args.classes)
                print('batch %d: train=%f test=%f' %
                      (i, accuracies[0], accuracies[1]))

                # Write to TensorBoard
                experiment_logger.experiment.add_scalar(
                    'train-test/acc/{}/mean'.format('global_model'),
                    accuracies[0],
                    global_step=i)
                experiment_logger.experiment.add_scalar(
                    'test-test/acc/{}/mean'.format('global_model'),
                    accuracies[1],
                    global_step=i)

        # Final evaluation on a sample of training/test clients
        for label, dataset in zip(
            ['Train', 'Test'],
            [omniglot_train_clients, omniglot_test_clients]):
            keys = RANDOM.sample(list(dataset.train_data_local_dict.keys()),
                                 args.eval_samples)
            train_eval_sample = {
                k: dataset.train_data_local_dict[k]
                for k in keys
            }
            test_eval_sample = {
                k: dataset.test_data_local_dict[k]
                for k in keys
            }
            accuracy = evaluate(sess=sess,
                                model=model,
                                train_dataloaders=train_eval_sample,
                                test_dataloaders=test_eval_sample,
                                num_classes=args.classes,
                                eval_inner_iters=args.eval_iters,
                                transductive=True)
            experiment_logger.experiment.add_scalar(f'final_{label}_acc',
                                                    accuracy,
                                                    global_step=0)
            print(f"{label} accuracy: {accuracy}")