def load_omniglot_experiment_datasets(context: ReptileExperimentContext):
    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'
    omniglot_train_datasets, omniglot_test_datasets = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=context.num_clients_train,
        num_clients_test=context.num_clients_test,
        num_classes_per_client=context.num_classes_per_client,
        num_shots_per_class=context.num_shots_per_class,
        inner_batch_size=context.inner_batch_size)
    return omniglot_train_datasets, omniglot_test_datasets
def run_reptile_experiment(
        name, dataset, swap_labels, classes, shots, seed, model_class, sgd,
        adam_betas, num_clients_train, num_clients_test, meta_batch_size,
        num_meta_steps, meta_learning_rate_initial, meta_learning_rate_final,
        eval_interval, num_eval_clients_training, do_final_evaluation,
        num_eval_clients_final, inner_batch_size, inner_learning_rate,
        num_inner_steps, num_inner_steps_eval, mean=None, std=None
):
    fix_random_seeds(seed)

    fed_dataset_test = None
    if dataset == 'femnist':
        fed_dataset_train = load_femnist_dataset(
            data_dir=str((REPO_ROOT / 'data').absolute()),
            num_clients=num_clients_train,
            batch_size=inner_batch_size,
            random_seed=seed
        )
    elif dataset == 'omniglot':
        fed_dataset_train, fed_dataset_test = load_omniglot_datasets(
            data_dir=str((REPO_ROOT / 'data' / 'omniglot').absolute()),
            num_clients_train=num_clients_train,
            num_clients_test=num_clients_test,
            num_classes_per_client=classes,
            num_shots_per_class=shots,
            inner_batch_size=inner_batch_size,
            random_seed=seed
        )
    elif dataset == 'ham10k':
        fed_dataset_train = load_ham10k_federated(partitions=num_clients_train,
                                                  batch_size=inner_batch_size,
                                                  mean=mean, std=std)
    else:
        raise ValueError(f'dataset "{dataset}" unknown')

    # Run one experiment per (inner learning rate, inner steps) combination;
    # note that num_inner_steps_eval is tied to the current num_inner_steps
    for lr in inner_learning_rate:
        for _is in num_inner_steps:
            reptile_context = ReptileExperimentContext(
                name=name,
                dataset_name=dataset,
                swap_labels=swap_labels,
                num_classes_per_client=classes,
                num_shots_per_class=shots,
                seed=seed,
                model_class=model_class,
                sgd=sgd,
                adam_betas=adam_betas,
                num_clients_train=num_clients_train,
                num_clients_test=num_clients_test,
                meta_batch_size=meta_batch_size,
                num_meta_steps=num_meta_steps,
                meta_learning_rate_initial=meta_learning_rate_initial,
                meta_learning_rate_final=meta_learning_rate_final,
                eval_interval=eval_interval,
                num_eval_clients_training=num_eval_clients_training,
                do_final_evaluation=do_final_evaluation,
                num_eval_clients_final=num_eval_clients_final,
                inner_batch_size=inner_batch_size,
                inner_learning_rate=lr,
                num_inner_steps=_is,
                num_inner_steps_eval=_is
            )
            experiment_specification = f'{reptile_context}'
            experiment_logger = create_tensorboard_logger(
                reptile_context.name, experiment_specification
            )
            reptile_context.experiment_logger = experiment_logger

            log_after_round_evaluation_fns = [
                partial(log_after_round_evaluation, experiment_logger)
            ]
            run_reptile(
                context=reptile_context,
                dataset_train=fed_dataset_train,
                dataset_test=fed_dataset_test,
                initial_model_state=None,
                after_round_evaluation=log_after_round_evaluation_fns
            )
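# A hypothetical invocation of run_reptile_experiment; the values below are
# illustrative only, not defaults prescribed by this module. Since
# inner_learning_rate and num_inner_steps are grid-search lists, this call
# launches len(inner_learning_rate) * len(num_inner_steps) runs:
#
#   run_reptile_experiment(
#       name='omniglot_reptile', dataset='omniglot', swap_labels=False,
#       classes=5, shots=5, seed=123, model_class=OmniglotLightning,
#       sgd=True, adam_betas=None, num_clients_train=10000,
#       num_clients_test=1000, meta_batch_size=5, num_meta_steps=3000,
#       meta_learning_rate_initial=1.0, meta_learning_rate_final=0.0,
#       eval_interval=10, num_eval_clients_training=1,
#       do_final_evaluation=True, num_eval_clients_final=100,
#       inner_batch_size=10, inner_learning_rate=[0.001],
#       num_inner_steps=[5], num_inner_steps_eval=[5])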
def run_reptile(context: ExperimentContext, initial_model_state=None):
    num_clients_train = 10000
    num_clients_test = 1000
    num_classes_per_client = 5
    num_shots_per_class = 5
    # Every *eval_iters* meta steps, evaluate on one random training client
    # and one random test client
    eval_iters = 10

    reptile_args = ReptileTrainingArgs(model=OmniglotLightning,
                                       inner_optimizer=optim.Adam,
                                       inner_learning_rate=0.001,
                                       num_inner_steps=5,
                                       num_inner_steps_eval=50,
                                       log_every_n_steps=3,
                                       inner_batch_size=10,
                                       meta_batch_size=5,
                                       meta_learning_rate_initial=1,
                                       meta_learning_rate_final=0,
                                       num_meta_steps=3000)
    experiment_logger = create_tensorboard_logger(
        context.name, "dataloading_ours;models_ours")

    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'
    omniglot_train_clients, omniglot_test_clients = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=num_clients_train,
        num_clients_test=num_clients_test,
        num_classes_per_client=num_classes_per_client,
        num_shots_per_class=num_shots_per_class,
        inner_batch_size=reptile_args.inner_batch_size,
        random_seed=RANDOM_SEED)

    # Prepare ModelArgs for task training
    inner_optimizer_args = OptimizerArgs(
        optimizer_class=reptile_args.inner_optimizer,
        lr=reptile_args.inner_learning_rate,
        betas=(0, 0.999))
    inner_model_args = ModelArgs(model_class=reptile_args.model,
                                 optimizer_args=inner_optimizer_args,
                                 num_classes=num_classes_per_client)
    dummy_optimizer_args = OptimizerArgs(optimizer_class=optim.SGD)
    meta_model_args = ModelArgs(reptile_args.model,
                                dummy_optimizer_args,
                                num_classes=num_classes_per_client)

    torch_model = OmniglotLightning(participant_name='global_model',
                                    **inner_model_args.kwargs)
    reptile = Reptile(global_model=torch_model,
                      model_kwargs=inner_model_args.kwargs,
                      inner_iterations=reptile_args.num_inner_steps,
                      inner_iterations_eval=reptile_args.num_inner_steps_eval)

    for i in range(reptile_args.num_meta_steps):
        # Meta step size anneals linearly from the initial to the final value
        frac_done = i / reptile_args.num_meta_steps
        cur_meta_step_size = (
            frac_done * reptile_args.meta_learning_rate_final
            + (1 - frac_done) * reptile_args.meta_learning_rate_initial)

        meta_batch = {
            k: omniglot_train_clients.train_data_local_dict[k]
            for k in cyclerange(
                i * reptile_args.meta_batch_size
                % len(omniglot_train_clients.train_data_local_dict),
                (i + 1) * reptile_args.meta_batch_size
                % len(omniglot_train_clients.train_data_local_dict),
                len(omniglot_train_clients.train_data_local_dict))
        }
        reptile.train_step(meta_batch=meta_batch,
                           meta_step_size=cur_meta_step_size)

        if i % eval_iters == 0:
            accuracies = []
            k = RANDOM.randrange(
                len(omniglot_train_clients.train_data_local_dict))
            train_train = omniglot_train_clients.train_data_local_dict[k]
            train_test = omniglot_train_clients.test_data_local_dict[k]
            k = RANDOM.randrange(
                len(omniglot_test_clients.train_data_local_dict))
            test_train = omniglot_test_clients.train_data_local_dict[k]
            test_test = omniglot_test_clients.test_data_local_dict[k]
            for train_dl, test_dl in [(train_train, train_test),
                                      (test_train, test_test)]:
                correct = reptile.evaluate(train_dl, test_dl)
                accuracies.append(correct / num_classes_per_client)
            print('batch %d: train=%f test=%f' %
                  (i, accuracies[0], accuracies[1]))

            # Write to TensorBoard
            experiment_logger.experiment.add_scalar(
                'train-test/acc/{}/mean'.format('global_model'),
                accuracies[0], global_step=i)
            experiment_logger.experiment.add_scalar(
                'test-test/acc/{}/mean'.format('global_model'),
                accuracies[1], global_step=i)
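# `cyclerange` is used throughout this file but not defined here; it is
# assumed to be a repo-local helper that slices the client index space
# cyclically, wrapping around when a meta-batch crosses the end of an
# epoch. A minimal sketch consistent with the call sites in this file (the
# parameter name `len` mirrors the keyword used in one call site below and
# deliberately shadows the builtin):
def cyclerange(start, stop, len):
    """Return the indices [start, stop) on a cycle of length `len`."""
    assert start < len and stop <= len, 'Invalid cyclerange arguments'
    if start <= stop:
        return list(range(start, stop))
    return list(range(start, len)) + list(range(0, stop))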
def create_client_batches(clients: List[ReptileClient],
                          batch_size: int) -> List[List[ReptileClient]]:
    """Group the client list into consecutive meta-batches."""
    if batch_size == -1:
        client_batches = [clients]
    else:
        client_batches = [
            clients[i:i + batch_size]
            for i in range(0, len(clients), batch_size)
        ]
    return client_batches
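# Example: with batch_size=2, consecutive clients are grouped into
# meta-batches of two, with a shorter final batch; batch_size=-1 returns
# all clients as a single batch (strings stand in for ReptileClient
# instances here):
#
#   >>> create_client_batches(['c0', 'c1', 'c2', 'c3', 'c4'], 2)
#   [['c0', 'c1'], ['c2', 'c3'], ['c4']]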
def run_reptile(context: str, initial_model_state=None):
    args = argument_parser().parse_args()
    RANDOM = random.Random(args.seed)

    # TODO: Possibly implement logic using ReptileExperimentContext
    reptile_args = ReptileTrainingArgs(
        model_class=OmniglotLightning,
        sgd=args.sgd,
        inner_learning_rate=args.learning_rate,
        num_inner_steps=args.inner_iters,
        num_inner_steps_eval=args.eval_iters,
        log_every_n_steps=3,
        meta_learning_rate_initial=args.meta_step,
        meta_learning_rate_final=args.meta_step_final,
        num_classes_per_client=args.classes
    )
    experiment_logger = create_tensorboard_logger(
        'reptile',
        (
            f"{context};seed{args.seed};"
            f"train-clients{args.train_clients};"
            f"{args.classes}-way{args.shots}-shot;"
            f"ib{args.inner_batch}ii{args.inner_iters}"
            f"ilr{str(args.learning_rate).replace('.', '')}"
            f"ms{str(args.meta_step).replace('.', '')}"
            f"mb{args.meta_batch}ei{args.eval_iters}"
            f"{'sgd' if args.sgd else 'adam'}"
        )
    )

    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'
    omniglot_train_clients, omniglot_test_clients = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=args.train_clients,
        num_clients_test=args.test_clients,
        num_classes_per_client=args.classes,
        num_shots_per_class=args.shots,
        inner_batch_size=args.inner_batch,
        random_seed=args.seed
    )

    # Set up clients
    # Since we are doing meta-learning, we need separate sets of training and
    # test clients
    train_clients = initialize_clients(omniglot_train_clients,
                                       reptile_args.get_inner_model_args(),
                                       context, experiment_logger)
    test_clients = initialize_clients(omniglot_test_clients,
                                      reptile_args.get_inner_model_args(),
                                      context, experiment_logger)

    # Set up server
    server = ReptileServer(
        participant_name='initial_server',
        model_args=reptile_args.get_meta_model_args(),
        context=context,  # TODO: Change to ReptileExperimentContext
        initial_model_state=initial_model_state
    )

    # Perform training
    for i in range(args.meta_iters):
        if args.meta_batch == -1:
            meta_batch = train_clients
        else:
            meta_batch = [
                train_clients[k] for k in cyclerange(
                    i * args.meta_batch % len(train_clients),
                    (i + 1) * args.meta_batch % len(train_clients),
                    len(train_clients)
                )
            ]
        # Meta training step
        reptile_train_step(
            aggregator=server,
            participants=meta_batch,
            inner_training_args=reptile_args.get_inner_training_args(),
            meta_training_args=reptile_args.get_meta_training_args(
                frac_done=i / args.meta_iters
            )
        )

        # Evaluation on train and test clients
        if i % args.eval_interval == 0:
            # train-test set
            # Pick one train client at random and test on it
            k = RANDOM.randrange(len(train_clients))
            client = [train_clients[k]]
            reptile_train_step(
                aggregator=server,
                participants=client,
                inner_training_args=reptile_args.get_inner_training_args(
                    eval=True),
                evaluation_mode=True
            )
            result = evaluate_local_models(participants=client)
            experiment_logger.experiment.add_scalar(
                'train-test/acc/{}/mean'.format('global_model'),
                torch.mean(result.get('test/acc')),
                global_step=i + 1
            )

            # test-test set
            # Pick one test client at random and test on it
            k = RANDOM.randrange(len(test_clients))
            client = [test_clients[k]]
            reptile_train_step(
                aggregator=server,
                participants=client,
                inner_training_args=reptile_args.get_inner_training_args(
                    eval=True),
                evaluation_mode=True
            )
            result = evaluate_local_models(participants=client)
            experiment_logger.experiment.add_scalar(
                'test-test/acc/{}/mean'.format('global_model'),
                torch.mean(result.get('test/acc')),
                global_step=i + 1
            )

        logger.info('finished training round')

    # Final evaluation on a sample of train/test clients
    for label, client_set in zip(['Train', 'Test'],
                                 [train_clients, test_clients]):
        eval_sample = RANDOM.sample(client_set, args.eval_samples)
        reptile_train_step(
            aggregator=server,
            participants=eval_sample,
            inner_training_args=reptile_args.get_inner_training_args(
                eval=True),
            evaluation_mode=True
        )
        result = evaluate_local_models(participants=eval_sample)
        log_loss_and_acc('global_model',
                         result.get('test/loss'),
                         result.get('test/acc'),
                         experiment_logger,
                         global_step=args.meta_iters)
        experiment_logger.experiment.add_scalar(
            f'final_{label}_acc',
            torch.mean(result.get('test/acc')),
            global_step=0
        )
        print(f"{label} accuracy: {torch.mean(result.get('test/acc'))}")
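# `get_meta_training_args(frac_done=...)` above is assumed to implement the
# same linear annealing of the Reptile meta step size that the other
# variants in this file compute inline. A hypothetical helper expressing
# that schedule:
def linear_meta_step_size(frac_done: float,
                          initial: float,
                          final: float) -> float:
    """Interpolate the meta step size linearly over training, e.g. from
    1.0 down to 0.0 as frac_done goes from 0 to 1."""
    return frac_done * final + (1 - frac_done) * initial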
def run_reptile(context: ExperimentContext, initial_model_state=None):
    num_clients_train = 10000
    num_clients_test = 1000
    num_classes_per_client = 5
    num_shots_per_class = 5
    eval_iters = 10

    reptile_args = ReptileTrainingArgs(model=OmniglotModel,
                                       inner_optimizer=optim.Adam,
                                       inner_learning_rate=0.001,
                                       num_inner_steps=5,
                                       num_inner_steps_eval=50,
                                       log_every_n_steps=3,
                                       inner_batch_size=10,
                                       meta_batch_size=5,
                                       meta_learning_rate_initial=1,
                                       meta_learning_rate_final=0,
                                       num_meta_steps=3000)
    experiment_logger = create_tensorboard_logger(context.name,
                                                  "framework_ours_test_1")

    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'
    omniglot_train_clients, omniglot_test_clients = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=num_clients_train,
        num_clients_test=num_clients_test,
        num_classes_per_client=num_classes_per_client,
        num_shots_per_class=num_shots_per_class,
        inner_batch_size=reptile_args.inner_batch_size,
        random_seed=RANDOM_SEED)

    # Prepare ModelArgs for task training
    inner_optimizer_args = OptimizerArgs(
        optimizer_class=reptile_args.inner_optimizer,
        lr=reptile_args.inner_learning_rate,
        betas=(0, 0.999))
    inner_model_args = ModelArgs(reptile_args.model,
                                 inner_optimizer_args,
                                 num_classes=num_classes_per_client)
    dummy_optimizer_args = OptimizerArgs(optimizer_class=optim.SGD)
    meta_model_args = ModelArgs(reptile_args.model,
                                dummy_optimizer_args,
                                num_classes=num_classes_per_client)

    torch_model = OmniglotModel(num_classes=num_classes_per_client)
    torch_optimizer = inner_optimizer_args.optimizer_class(
        torch_model.parameters(), **inner_optimizer_args.optimizer_kwargs)
    torch_criterion = torch.nn.CrossEntropyLoss()

    # Set up clients
    # Since we are doing meta-learning, we need separate sets of training and
    # test clients
    train_clients = []
    for c in omniglot_train_clients.train_data_local_dict.keys():
        client = ReptileClient(
            client_id=str(c),
            model_args=inner_model_args,
            context=context,
            train_dataloader=omniglot_train_clients.train_data_local_dict[c],
            num_train_samples=omniglot_train_clients
            .data_local_train_num_dict[c],
            test_dataloader=omniglot_train_clients.test_data_local_dict[c],
            num_test_samples=omniglot_train_clients
            .data_local_test_num_dict[c],
            lightning_logger=experiment_logger)
        train_clients.append(client)

    test_clients = []
    for c in omniglot_test_clients.train_data_local_dict.keys():
        client = ReptileClient(
            client_id=str(c),
            model_args=inner_model_args,
            context=context,
            train_dataloader=omniglot_test_clients.train_data_local_dict[c],
            num_train_samples=omniglot_test_clients
            .data_local_train_num_dict[c],
            test_dataloader=omniglot_test_clients.test_data_local_dict[c],
            num_test_samples=omniglot_test_clients
            .data_local_test_num_dict[c],
            lightning_logger=experiment_logger)
        test_clients.append(client)

    # Set up server
    server = ReptileServer(
        model_state=copy.deepcopy(torch_model.state_dict()),
        participant_name='initial_server',
        model_args=meta_model_args,
        context=context,
        initial_model_state=initial_model_state)

    for i in range(reptile_args.num_meta_steps):
        meta_batch = [
            train_clients[k] for k in cyclerange(
                i * reptile_args.meta_batch_size % len(train_clients),
                (i + 1) * reptile_args.meta_batch_size % len(train_clients),
                len(train_clients))
        ]
        reptile_train_step(
            model=torch_model,
            optimizer=torch_optimizer,
            criterion=torch_criterion,
            aggregator=server,
            participants=meta_batch,
            inner_training_args=reptile_args.get_inner_training_args(),
            meta_training_args=reptile_args.get_meta_training_args(
                frac_done=i / reptile_args.num_meta_steps))

        if i % eval_iters == 0:
            accuracies = []

            # train-test set
            # Pick one train client at random and test on it
            client = [RANDOM.choice(train_clients)]
            reptile_train_step(
                model=torch_model,
                optimizer=torch_optimizer,
                criterion=torch_criterion,
                aggregator=server,
                participants=client,
                inner_training_args=reptile_args.get_inner_training_args(
                    eval=True),
                evaluation_mode=True)
            inputs, labels = next(iter(client[0]._test_dataloader))
            torch_model.load_state_dict(client[0].model_state)
            test_preds = torch_model(inputs).argmax(dim=1)
            num_correct = int(
                sum(pred == label for pred, label in zip(test_preds, labels)))
            accuracies.append(num_correct / num_classes_per_client)

            # test-test set
            # Pick one test client at random and test on it
            client = [RANDOM.choice(test_clients)]
            reptile_train_step(
                model=torch_model,
                optimizer=torch_optimizer,
                criterion=torch_criterion,
                aggregator=server,
                participants=client,
                inner_training_args=reptile_args.get_inner_training_args(
                    eval=True),
                evaluation_mode=True)
            inputs, labels = next(iter(client[0]._test_dataloader))
            torch_model.load_state_dict(client[0].model_state)
            test_preds = torch_model(inputs).argmax(dim=1)
            num_correct = int(
                sum(pred == label for pred, label in zip(test_preds, labels)))
            accuracies.append(num_correct / num_classes_per_client)

            print('batch %d: train=%f test=%f' %
                  (i, accuracies[0], accuracies[1]))

            # Write to TensorBoard
            experiment_logger.experiment.add_scalar(
                'train-test/acc/{}/mean'.format('global_model'),
                accuracies[0], global_step=i)
            experiment_logger.experiment.add_scalar(
                'test-test/acc/{}/mean'.format('global_model'),
                accuracies[1], global_step=i)
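# Note on the accuracy computations above and in the variants below: each
# client's test loader is assumed to hold exactly one query example per
# class (num_classes_per_client samples in total, matching the k-shot,
# n-way evaluation protocol), so dividing the number of correct
# predictions by num_classes_per_client gives the accuracy on that
# client's full query set.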
def run_reptile(context: ExperimentContext, initial_model_state=None):
    num_clients_train = 10000
    num_clients_test = 1000
    num_classes_per_client = 5
    num_shots_per_class = 5
    # Every *eval_iters* meta steps, evaluation is performed on one random
    # client in the training and test set, respectively
    eval_iters = 10

    reptile_args = ReptileTrainingArgs(model=OmniglotLightning,
                                       inner_optimizer=optim.Adam,
                                       inner_learning_rate=0.001,
                                       num_inner_steps=5,
                                       num_inner_steps_eval=50,
                                       log_every_n_steps=3,
                                       inner_batch_size=10,
                                       meta_batch_size=5,
                                       meta_learning_rate_initial=1,
                                       meta_learning_rate_final=0,
                                       num_meta_steps=3000)
    experiment_logger = create_tensorboard_logger(context.name, (
        f"nichol_reptile_our_dataloading;{num_clients_train}clients;"
        f"{num_classes_per_client}-way{num_shots_per_class}-shot;"
        f"mlr{str(reptile_args.meta_learning_rate_initial).replace('.', '')}"
        f"ilr{str(reptile_args.inner_learning_rate).replace('.', '')}"
        f"is{reptile_args.num_inner_steps}"))

    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'
    tf.disable_eager_execution()
    omniglot_train_clients, omniglot_test_clients = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=num_clients_train,
        num_clients_test=num_clients_test,
        num_classes_per_client=num_classes_per_client,
        num_shots_per_class=num_shots_per_class,
        inner_batch_size=reptile_args.inner_batch_size,
        tensorflow=True,
        random_seed=RANDOM_SEED)

    with tf.Session() as sess:
        model = OmniglotModel(num_classes_per_client,
                              learning_rate=reptile_args.inner_learning_rate)
        reptile = Reptile(sess, transductive=True, pre_step_op=weight_decay(1))
        accuracy_ph = tf.placeholder(tf.float32, shape=())
        tf.summary.scalar('accuracy', accuracy_ph)
        merged = tf.summary.merge_all()
        sess.run(tf.global_variables_initializer())

        for i in range(reptile_args.num_meta_steps):
            # Meta step size anneals linearly from initial to final value
            frac_done = i / reptile_args.num_meta_steps
            cur_meta_step_size = (
                frac_done * reptile_args.meta_learning_rate_final
                + (1 - frac_done) * reptile_args.meta_learning_rate_initial)

            meta_batch = {
                k: omniglot_train_clients.train_data_local_dict[k]
                for k in cyclerange(
                    i * reptile_args.meta_batch_size
                    % len(omniglot_train_clients.train_data_local_dict),
                    (i + 1) * reptile_args.meta_batch_size
                    % len(omniglot_train_clients.train_data_local_dict),
                    len(omniglot_train_clients.train_data_local_dict))
            }
            reptile.train_step(meta_batch=meta_batch,
                               input_ph=model.input_ph,
                               label_ph=model.label_ph,
                               minimize_op=model.minimize_op,
                               inner_iters=reptile_args.num_inner_steps,
                               meta_step_size=cur_meta_step_size)

            if i % eval_iters == 0:
                accuracies = []
                k = RANDOM.randrange(
                    len(omniglot_train_clients.train_data_local_dict))
                train_train = omniglot_train_clients.train_data_local_dict[k]
                train_test = omniglot_train_clients.test_data_local_dict[k]
                k = RANDOM.randrange(
                    len(omniglot_test_clients.train_data_local_dict))
                test_train = omniglot_test_clients.train_data_local_dict[k]
                test_test = omniglot_test_clients.test_data_local_dict[k]
                for train_dl, test_dl in [(train_train, train_test),
                                          (test_train, test_test)]:
                    correct = reptile.evaluate(
                        train_data_loader=train_dl,
                        test_data_loader=test_dl,
                        input_ph=model.input_ph,
                        label_ph=model.label_ph,
                        minimize_op=model.minimize_op,
                        predictions=model.predictions,
                        inner_iters=reptile_args.num_inner_steps_eval)
                    #summary = sess.run(
                    #    merged,
                    #    feed_dict={accuracy_ph:
                    #               correct / num_classes_per_client})
                    accuracies.append(correct / num_classes_per_client)
                print('batch %d: train=%f test=%f' %
                      (i, accuracies[0], accuracies[1]))

                # Write to TensorBoard
                experiment_logger.experiment.add_scalar(
                    'train-test/acc/{}/mean'.format('global_model'),
                    accuracies[0], global_step=i)
                experiment_logger.experiment.add_scalar(
                    'test-test/acc/{}/mean'.format('global_model'),
                    accuracies[1], global_step=i)
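# `weight_decay(1)` above supplies the Reptile pre-step op. It is assumed
# to follow the helper in Nichol et al.'s supervised-reptile code, which
# builds a TF1 op that rescales all trainable variables before each meta
# step (a rate of 1 is effectively a no-op). A sketch under that
# assumption:
def weight_decay(rate, variables=None):
    """Build a TF1 op that multiplies each variable by `rate`."""
    if variables is None:
        variables = tf.trainable_variables()
    return tf.group(*[tf.assign(var, var * rate) for var in variables])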
def run_reptile(context: str, initial_model_state=None):
    args = argument_parser().parse_args()
    RANDOM = random.Random(args.seed)

    experiment_logger = create_tensorboard_logger(
        'reptile',
        (f"{context};seed{args.seed};"
         f"train-clients{args.train_clients};"
         f"{args.classes}-way{args.shots}-shot;"
         f"ib{args.inner_batch}ii{args.inner_iters}"
         f"ilr{str(args.learning_rate).replace('.', '')}"
         f"ms{str(args.meta_step).replace('.', '')}"
         f"mb{args.meta_batch}ei{args.eval_iters}"
         f"{'sgd' if args.sgd else 'adam'}"))

    # Load and prepare Omniglot data
    data_dir = REPO_ROOT / 'data' / 'omniglot'
    tf.disable_eager_execution()
    omniglot_train_clients, omniglot_test_clients = load_omniglot_datasets(
        str(data_dir.absolute()),
        num_clients_train=args.train_clients,
        num_clients_test=args.test_clients,
        num_classes_per_client=args.classes,
        num_shots_per_class=args.shots,
        inner_batch_size=args.inner_batch,
        tensorflow=True,
        random_seed=args.seed)

    model_kwargs = {'learning_rate': args.learning_rate}
    if args.sgd:
        model_kwargs['optimizer'] = tf.train.GradientDescentOptimizer
    model = OmniglotModel(num_classes=args.classes, **model_kwargs)

    with tf.Session() as sess:
        reptile = ReptileForFederatedData(session=sess,
                                          transductive=True,
                                          pre_step_op=weight_decay(1))
        accuracy_ph = tf.placeholder(tf.float32, shape=())
        tf.summary.scalar('accuracy', accuracy_ph)
        merged = tf.summary.merge_all()
        sess.run(tf.global_variables_initializer())

        for i in range(args.meta_iters):
            # Meta step size anneals linearly from meta_step to
            # meta_step_final
            frac_done = i / args.meta_iters
            cur_meta_step_size = (frac_done * args.meta_step_final
                                  + (1 - frac_done) * args.meta_step)

            crange = cyclerange(
                start=i * args.meta_batch
                % len(omniglot_train_clients.train_data_local_dict),
                stop=(i + 1) * args.meta_batch
                % len(omniglot_train_clients.train_data_local_dict),
                len=len(omniglot_train_clients.train_data_local_dict))
            #print(f"Meta-step {i}: train clients {crange[0]}-{crange[-1]}")
            meta_batch = {
                k: omniglot_train_clients.train_data_local_dict[k]
                for k in crange
            }
            reptile.train_step(meta_batch=meta_batch,
                               input_ph=model.input_ph,
                               label_ph=model.label_ph,
                               minimize_op=model.minimize_op,
                               inner_iters=args.inner_iters,
                               meta_step_size=cur_meta_step_size)

            if i % args.eval_interval == 0:
                accuracies = []
                k = RANDOM.randrange(
                    len(omniglot_train_clients.train_data_local_dict))
                train_train = omniglot_train_clients.train_data_local_dict[k]
                train_test = omniglot_train_clients.test_data_local_dict[k]
                k = RANDOM.randrange(
                    len(omniglot_test_clients.train_data_local_dict))
                test_train = omniglot_test_clients.train_data_local_dict[k]
                test_test = omniglot_test_clients.test_data_local_dict[k]
                for train_dl, test_dl in [(train_train, train_test),
                                          (test_train, test_test)]:
                    correct = reptile.evaluate(train_data_loader=train_dl,
                                               test_data_loader=test_dl,
                                               input_ph=model.input_ph,
                                               label_ph=model.label_ph,
                                               minimize_op=model.minimize_op,
                                               predictions=model.predictions,
                                               inner_iters=args.eval_iters)
                    #summary = sess.run(
                    #    merged,
                    #    feed_dict={accuracy_ph: correct / args.classes})
                    accuracies.append(correct / args.classes)
                print('batch %d: train=%f test=%f' %
                      (i, accuracies[0], accuracies[1]))

                # Write to TensorBoard
                experiment_logger.experiment.add_scalar(
                    'train-test/acc/{}/mean'.format('global_model'),
                    accuracies[0], global_step=i)
                experiment_logger.experiment.add_scalar(
                    'test-test/acc/{}/mean'.format('global_model'),
                    accuracies[1], global_step=i)

        # Final evaluation on a sample of training/test clients
        for label, dataset in zip(
                ['Train', 'Test'],
                [omniglot_train_clients, omniglot_test_clients]):
            keys = RANDOM.sample(list(dataset.train_data_local_dict.keys()),
                                 args.eval_samples)
            train_eval_sample = {
                k: dataset.train_data_local_dict[k] for k in keys
            }
            test_eval_sample = {
                k: dataset.test_data_local_dict[k] for k in keys
            }
            accuracy = evaluate(sess=sess,
                                model=model,
                                train_dataloaders=train_eval_sample,
                                test_dataloaders=test_eval_sample,
                                num_classes=args.classes,
                                eval_inner_iters=args.eval_iters,
                                transductive=True)
            experiment_logger.experiment.add_scalar(f'final_{label}_acc',
                                                    accuracy,
                                                    global_step=0)
            print(f"{label} accuracy: {accuracy}")