Ejemplo n.º 1
0
    def _save_model(self, round_number, file_path):
        """
        Write the model stored in the TensorDB for one round to disk.

        The tensor names are taken from the current model proto; the values
        are read back from the TensorDB cache for ``round_number``. If any
        tensor is missing from the cache, an info message is logged and the
        method returns without touching ``self.model`` or the file.

        Args:
            round_number: int
                Model round to be saved
            file_path: str
                Either the best model or latest model file path

        Returns:
            None
        """
        # Only the tensor names of the current proto are needed here; the
        # values are looked up again in the TensorDB below.
        current_tensors, _ = utils.deconstruct_model_proto(
            self.model, compression_pipeline=self.compression_pipeline)
        model_keys = [
            TensorKey(name, self.uuid, round_number, False, ('model', ))
            for name in current_tensors
        ]

        cached_tensors = {}
        for key in model_keys:
            key_name, *_ = key
            cached = self.tensor_db.get_tensor_from_cache(key)
            if cached is None:
                # An incomplete model is never written out.
                message = ('Cannot save model for round {}.'
                           ' Continuing...'.format(round_number))
                self.logger.info(message)
                return
            cached_tensors[key_name] = cached

        # Two independent checks (not elif): both paths may be the same file.
        if file_path == self.best_state_path:
            self.best_tensor_dict = cached_tensors
        if file_path == self.last_state_path:
            self.last_tensor_dict = cached_tensors

        rebuilt_proto = utils.construct_model_proto(
            cached_tensors, round_number, self.compression_pipeline)
        self.model = rebuilt_proto
        utils.dump_proto(rebuilt_proto, file_path)
Ejemplo n.º 2
0
def run_challenge_experiment(aggregation_function,
                             choose_training_collaborators,
                             training_hyper_parameters_for_round,
                             institution_split_csv_filename,
                             brats_training_data_parent_dir,
                             db_store_rounds=5,
                             rounds_to_train=5,
                             device='cpu',
                             save_checkpoints=True,
                             restore_from_checkpoint_folder=None,
                             include_validation_with_hausdorff=True,
                             use_pretrained_model=True):
    """
    Run a simulated federated training experiment round by round.

    A single task runner is shared by all collaborators; before each
    collaborator runs, the runner's data loader is swapped for that
    collaborator's loader. After every round the validation metrics are
    read from the aggregator's TensorDB, a projected convergence score is
    computed, and (optionally) a checkpoint is written.

    Args:
        aggregation_function : callable
            Wrapped in ``CustomAggregationWrapper`` and installed as the
            plan override ``tasks.train.aggregation_type``.
        choose_training_collaborators : callable
            Called once per round with ``(collaborator_names,
            aggregator.tensor_db._iterate(), round_num,
            collaborators_chosen_each_round, collaborator_times_per_round)``;
            its return value is used as the round's training collaborators.
        training_hyper_parameters_for_round : callable
            Called once per round with the same arguments; must return
            ``(learning_rate, epochs_per_round, batches_per_round)`` where
            exactly one of the last two is ``None``.
        institution_split_csv_filename : str
            Forwarded to ``construct_fedsim_csv``.
        brats_training_data_parent_dir : str
            Forwarded to ``construct_fedsim_csv``.
        db_store_rounds : int
            Plan override ``aggregator.settings.db_store_rounds``.
        rounds_to_train : int
            Plan override ``aggregator.settings.rounds_to_train`` and upper
            bound (exclusive) of the round loop.
        device : str
            Plan override ``task_runner.settings.device``; the value
            ``'cpu'`` additionally makes the pretrained checkpoint load
            with ``map_location=torch.device('cpu')``.
        save_checkpoints : bool
            When true, ``save_checkpoint`` is called at the end of every
            round.
        restore_from_checkpoint_folder : str or None
            ``None`` creates a fresh folder via ``setup_checkpoint_folder``;
            otherwise state is restored from ``checkpoint/<folder>`` and
            training resumes at the round after the stored one.
        include_validation_with_hausdorff : bool
            When false, the plan metrics are restricted to ``dice`` and
            ``dice_per_label`` and no Hausdorff columns are recorded.
        use_pretrained_model : bool
            When true, model and optimizer state are loaded from
            ``<this file's directory>/pretrained_model/resunet_pretrained.pth``.

    Returns:
        tuple : (pandas.DataFrame, checkpoint folder)
            The DataFrame is built from the per-round ``experiment_results``
            lists. NOTE: ``None`` is returned instead (bare ``return``) when
            the hyper-parameter function violates the "exactly one None"
            rule.

    Raises:
        ValueError
            If a new best DICE is found after round 0 but
            ``checkpoint/<folder>/temp_model.pkl`` does not exist.
        SystemExit
            Via ``sys.exit(1)`` if the checkpoint folder to restore from is
            missing or its collaborator names differ from the current ones.
    """

    fx.init('fets_challenge_workspace')

    from sys import path, exit

    file = Path(__file__).resolve()
    root = file.parent.resolve()  # interface root, containing command modules
    work = Path.cwd().resolve()

    # make both the interface root and the working directory importable
    path.append(str(root))
    path.insert(0, str(work))

    # create gandlf_csv and get collaborator names
    gandlf_csv_path = os.path.join(work, 'gandlf_paths.csv')
    # split_csv_path = os.path.join(work, institution_split_csv_filename)
    # NOTE(review): 0.8 is presumably the train/validation split fraction
    # expected by construct_fedsim_csv - confirm against that helper
    collaborator_names = construct_fedsim_csv(brats_training_data_parent_dir,
                                              institution_split_csv_filename,
                                              0.8, gandlf_csv_path)

    aggregation_wrapper = CustomAggregationWrapper(aggregation_function)

    overrides = {
        'aggregator.settings.rounds_to_train': rounds_to_train,
        'aggregator.settings.db_store_rounds': db_store_rounds,
        'tasks.train.aggregation_type': aggregation_wrapper,
        'task_runner.settings.device': device,
    }

    # Update the plan if necessary
    plan = fx.update_plan(overrides)

    if not include_validation_with_hausdorff:
        plan.config['task_runner']['settings']['fets_config_dict'][
            'metrics'] = ['dice', 'dice_per_label']

    # Overwrite collaborator names
    plan.authorized_cols = collaborator_names
    # overwrite datapath values with the collaborator name itself
    for col in collaborator_names:
        plan.cols_data_paths[col] = col

    # get the data loaders for each collaborator
    collaborator_data_loaders = {
        col: copy(plan).get_data_loader(col)
        for col in collaborator_names
    }

    transformed_csv_dict = extract_csv_partitions(
        os.path.join(work, 'gandlf_paths.csv'))
    # get the task runner
    # NOTE: the two CSV files and task_runner are overwritten on every
    # iteration, so what remains after the loop belongs to the last
    # collaborator iterated
    for col in collaborator_data_loaders:
        #Insert logic to serialize train / val CSVs here
        transformed_csv_dict[col]['train'].to_csv(
            os.path.join(work, 'seg_test_train.csv'))
        transformed_csv_dict[col]['val'].to_csv(
            os.path.join(work, 'seg_test_val.csv'))
        task_runner = copy(plan).get_task_runner(
            collaborator_data_loaders[col])

    if use_pretrained_model:
        print('Loading pretrained model...')
        # the two branches differ only in the map_location passed to
        # torch.load
        if device == 'cpu':
            checkpoint = torch.load(
                f'{root}/pretrained_model/resunet_pretrained.pth',
                map_location=torch.device('cpu'))
            task_runner.model.load_state_dict(checkpoint['model_state_dict'])
            task_runner.optimizer.load_state_dict(
                checkpoint['optimizer_state_dict'])
        else:
            checkpoint = torch.load(
                f'{root}/pretrained_model/resunet_pretrained.pth')
            task_runner.model.load_state_dict(checkpoint['model_state_dict'])
            task_runner.optimizer.load_state_dict(
                checkpoint['optimizer_state_dict'])

    tensor_pipe = plan.get_tensor_pipe()

    # Initialize model weights
    init_state_path = plan.config['aggregator']['settings']['init_state_path']
    tensor_dict, _ = split_tensor_dict_for_holdouts(
        logger, task_runner.get_tensor_dict(False))

    model_snap = utils.construct_model_proto(tensor_dict=tensor_dict,
                                             round_number=0,
                                             tensor_pipe=tensor_pipe)

    utils.dump_proto(model_proto=model_snap, fpath=init_state_path)

    # get the aggregator, now that we have the initial weights file set up
    logger.info('Creating aggregator...')
    aggregator = plan.get_aggregator()
    # manually override the aggregator UUID (for checkpoint resume when rounds change)
    aggregator.uuid = 'aggregator'
    aggregator._load_initial_tensors()

    # create our collaborators
    # every collaborator shares the same task_runner and talks to the
    # aggregator object directly (passed as its client)
    logger.info('Creating collaborators...')
    collaborators = {
        col: copy(plan).get_collaborator(col,
                                         task_runner=task_runner,
                                         client=aggregator)
        for col in collaborator_names
    }

    collaborator_time_stats = gen_collaborator_time_stats(plan.authorized_cols)

    # per-round bookkeeping, keyed by round number
    collaborators_chosen_each_round = {}
    collaborator_times_per_round = {}

    logger.info('Starting experiment')

    total_simulated_time = 0
    best_dice = -1.0
    best_dice_over_time_auc = 0

    # results dataframe data
    experiment_results = {
        'round': [],
        'time': [],
        'convergence_score': [],
        'round_dice': [],
        'dice_label_0': [],
        'dice_label_1': [],
        'dice_label_2': [],
        'dice_label_4': [],
    }
    if include_validation_with_hausdorff:
        experiment_results.update({
            'hausdorff95_label_0': [],
            'hausdorff95_label_1': [],
            'hausdorff95_label_2': [],
            'hausdorff95_label_4': [],
        })

    if restore_from_checkpoint_folder is None:
        checkpoint_folder = setup_checkpoint_folder()
        logger.info(f'\nCreated experiment folder {checkpoint_folder}...')
        starting_round_num = 0
    else:
        if not Path(f'checkpoint/{restore_from_checkpoint_folder}').exists():
            logger.warning(
                f'Could not find provided checkpoint folder: {restore_from_checkpoint_folder}. Exiting...'
            )
            exit(1)
        else:
            logger.info(
                f'Attempting to load last completed round from {restore_from_checkpoint_folder}'
            )
            state = load_checkpoint(restore_from_checkpoint_folder)
            checkpoint_folder = restore_from_checkpoint_folder

            # the restored values replace the fresh bookkeeping set up above
            [
                loaded_collaborator_names, starting_round_num,
                collaborator_time_stats, total_simulated_time, best_dice,
                best_dice_over_time_auc, collaborators_chosen_each_round,
                collaborator_times_per_round, experiment_results, summary,
                agg_tensor_db
            ] = state

            if loaded_collaborator_names != collaborator_names:
                logger.error(
                    f'Collaborator names found in checkpoint ({loaded_collaborator_names}) '
                    f'do not match provided collaborators ({collaborator_names})'
                )
                exit(1)

            logger.info(f'Previous summary for round {starting_round_num}')
            logger.info(summary)

            # the stored round was completed, so resume with the next one
            starting_round_num += 1
            aggregator.tensor_db.tensor_db = agg_tensor_db
            aggregator.round_number = starting_round_num

    for round_num in range(starting_round_num, rounds_to_train):
        # pick collaborators to train for the round
        training_collaborators = choose_training_collaborators(
            collaborator_names, aggregator.tensor_db._iterate(), round_num,
            collaborators_chosen_each_round, collaborator_times_per_round)

        logger.info('Collaborators chosen to train for round {}:\n\t{}'.format(
            round_num, training_collaborators))

        # save the collaborators chosen this round
        collaborators_chosen_each_round[round_num] = training_collaborators

        # get the hyper-parameters from the competitor
        hparams = training_hyper_parameters_for_round(
            collaborator_names, aggregator.tensor_db._iterate(), round_num,
            collaborators_chosen_each_round, collaborator_times_per_round)

        learning_rate, epochs_per_round, batches_per_round = hparams

        # exactly one of the two must be None; both-None and neither-None
        # are rejected. NOTE: this path returns None, not the usual tuple
        if (epochs_per_round is None) == (batches_per_round is None):
            logger.error(
                'Hyper-parameter function error: function must return "None" for either "epochs_per_round" or "batches_per_round" but not both.'
            )
            return

        hparam_message = "\n\tlearning rate: {}".format(learning_rate)

        # None gets mapped to -1 in the tensor_db
        if epochs_per_round is None:
            epochs_per_round = -1
            hparam_message += "\n\tbatches_per_round: {}".format(
                batches_per_round)
        elif batches_per_round is None:
            batches_per_round = -1
            hparam_message += "\n\tepochs_per_round: {}".format(
                epochs_per_round)

        logger.info("Hyper-parameters for round {}:{}".format(
            round_num, hparam_message))

        # cache each tensor in the aggregator tensor_db
        hparam_dict = {}
        tk = TensorKey(tensor_name='learning_rate',
                       origin=aggregator.uuid,
                       round_number=round_num,
                       report=False,
                       tags=('hparam', 'model'))
        hparam_dict[tk] = np.array(learning_rate)
        tk = TensorKey(tensor_name='epochs_per_round',
                       origin=aggregator.uuid,
                       round_number=round_num,
                       report=False,
                       tags=('hparam', 'model'))
        hparam_dict[tk] = np.array(epochs_per_round)
        tk = TensorKey(tensor_name='batches_per_round',
                       origin=aggregator.uuid,
                       round_number=round_num,
                       report=False,
                       tags=('hparam', 'model'))
        hparam_dict[tk] = np.array(batches_per_round)
        aggregator.tensor_db.cache_tensor(hparam_dict)

        # pre-compute the times for each collaborator
        times_per_collaborator = compute_times_per_collaborator(
            collaborator_names, training_collaborators, batches_per_round,
            epochs_per_round, collaborator_data_loaders,
            collaborator_time_stats, round_num)
        collaborator_times_per_round[round_num] = times_per_collaborator

        aggregator.assigner.set_training_collaborators(training_collaborators)

        # update the state in the aggregation wrapper
        aggregation_wrapper.set_state_data_for_round(
            collaborators_chosen_each_round, collaborator_times_per_round)

        # turn the times list into a list of tuples and sort it
        times_list = [(t, col) for col, t in times_per_collaborator.items()]
        times_list = sorted(times_list)

        # now call each collaborator in order of time
        # FIXME: this doesn't break up each task. We need this if we're doing straggler handling
        for t, col in times_list:
            # set the task_runner data loader
            task_runner.data_loader = collaborator_data_loaders[col]

            # run the collaborator
            collaborators[col].run_simulation()

            logger.info(
                "Collaborator {} took simulated time: {} minutes".format(
                    col, round(t / 60, 2)))

        # the round time is the max of the times_list
        round_time = max([t for t, _ in times_list])
        total_simulated_time += round_time

        # get the performance validation scores for the round
        round_dice = get_metric('valid_dice', round_num, aggregator.tensor_db)
        dice_label_0 = get_metric('valid_dice_per_label_0', round_num,
                                  aggregator.tensor_db)
        dice_label_1 = get_metric('valid_dice_per_label_1', round_num,
                                  aggregator.tensor_db)
        dice_label_2 = get_metric('valid_dice_per_label_2', round_num,
                                  aggregator.tensor_db)
        dice_label_4 = get_metric('valid_dice_per_label_4', round_num,
                                  aggregator.tensor_db)
        if include_validation_with_hausdorff:
            hausdorff95_label_0 = get_metric('valid_hd95_per_label_0',
                                             round_num, aggregator.tensor_db)
            hausdorff95_label_1 = get_metric('valid_hd95_per_label_1',
                                             round_num, aggregator.tensor_db)
            hausdorff95_label_2 = get_metric('valid_hd95_per_label_2',
                                             round_num, aggregator.tensor_db)
            hausdorff95_label_4 = get_metric('valid_hd95_per_label_4',
                                             round_num, aggregator.tensor_db)

        # update best score
        if best_dice < round_dice:
            best_dice = round_dice
            # Set the weights for the final model
            if round_num == 0:
                # here the initial model was validated (temp model does not exist)
                logger.info(
                    f'Skipping best model saving to disk as it is a random initialization.'
                )
            elif not os.path.exists(
                    f'checkpoint/{checkpoint_folder}/temp_model.pkl'):
                raise ValueError(
                    f'Expected temporary model at: checkpoint/{checkpoint_folder}/temp_model.pkl to exist but it was not found.'
                )
            else:
                # here the temp model was the one validated
                shutil.copyfile(
                    src=f'checkpoint/{checkpoint_folder}/temp_model.pkl',
                    dst=f'checkpoint/{checkpoint_folder}/best_model.pkl')
                logger.info(
                    f'Saved model with best average binary DICE: {best_dice} to ~/.local/workspace/checkpoint/{checkpoint_folder}/best_model.pkl'
                )

        ## RUN VALIDATION ON INTERMEDIATE CONSENSUS MODEL
        # set the task_runner data loader
        # task_runner.data_loader = collaborator_data_loaders[col]
        ### DELETE THIS LINE ###
        # print(f'Collaborator {col} training data count = {task_runner.data_loader.get_train_data_size()}')

        # run the collaborator
        #collaborators[col].run_simulation()

        ## CONVERGENCE METRIC COMPUTATION
        # update the auc score
        best_dice_over_time_auc += best_dice * round_time

        # project the auc score as remaining time * best dice
        # this projection assumes that the current best score is carried forward for the entire week
        projected_auc = (MAX_SIMULATION_TIME - total_simulated_time
                         ) * best_dice + best_dice_over_time_auc
        projected_auc /= MAX_SIMULATION_TIME

        # End of round summary
        summary = '"**** END OF ROUND {} SUMMARY *****"'.format(round_num)
        summary += "\n\tSimulation Time: {} minutes".format(
            round(total_simulated_time / 60, 2))
        summary += "\n\t(Projected) Convergence Score: {}".format(
            projected_auc)
        summary += "\n\tDICE Label 0: {}".format(dice_label_0)
        summary += "\n\tDICE Label 1: {}".format(dice_label_1)
        summary += "\n\tDICE Label 2: {}".format(dice_label_2)
        summary += "\n\tDICE Label 4: {}".format(dice_label_4)
        if include_validation_with_hausdorff:
            summary += "\n\tHausdorff95 Label 0: {}".format(
                hausdorff95_label_0)
            summary += "\n\tHausdorff95 Label 1: {}".format(
                hausdorff95_label_1)
            summary += "\n\tHausdorff95 Label 2: {}".format(
                hausdorff95_label_2)
            summary += "\n\tHausdorff95 Label 4: {}".format(
                hausdorff95_label_4)

        # record this round as one row of the results dataframe
        experiment_results['round'].append(round_num)
        experiment_results['time'].append(total_simulated_time)
        experiment_results['convergence_score'].append(projected_auc)
        experiment_results['round_dice'].append(round_dice)
        experiment_results['dice_label_0'].append(dice_label_0)
        experiment_results['dice_label_1'].append(dice_label_1)
        experiment_results['dice_label_2'].append(dice_label_2)
        experiment_results['dice_label_4'].append(dice_label_4)
        if include_validation_with_hausdorff:
            experiment_results['hausdorff95_label_0'].append(
                hausdorff95_label_0)
            experiment_results['hausdorff95_label_1'].append(
                hausdorff95_label_1)
            experiment_results['hausdorff95_label_2'].append(
                hausdorff95_label_2)
            experiment_results['hausdorff95_label_4'].append(
                hausdorff95_label_4)
        logger.info(summary)

        if save_checkpoints:
            logger.info(f'Saving checkpoint for round {round_num}')
            logger.info(
                f'To resume from this checkpoint, set the restore_from_checkpoint_folder parameter to \'{checkpoint_folder}\''
            )
            save_checkpoint(checkpoint_folder, aggregator, collaborator_names,
                            collaborators, round_num, collaborator_time_stats,
                            total_simulated_time, best_dice,
                            best_dice_over_time_auc,
                            collaborators_chosen_each_round,
                            collaborator_times_per_round, experiment_results,
                            summary)

        # if the total_simulated_time has exceeded the maximum time, we break
        # in practice, this means that the previous round's model is the last model scored,
        # so a long final round should not actually benefit the competitor, since that final
        # model is never globally validated
        if total_simulated_time > MAX_SIMULATION_TIME:
            logger.info("Simulation time exceeded. Ending Experiment")
            break

        # save the most recent aggregated model in native format to be copied over as best when appropriate
        # (note this model has not been validated by the collaborators yet)
        task_runner.rebuild_model(round_num,
                                  aggregator.last_tensor_dict,
                                  validation=True)
        task_runner.save_native(
            f'checkpoint/{checkpoint_folder}/temp_model.pkl')

    return pd.DataFrame.from_dict(experiment_results), checkpoint_folder
Ejemplo n.º 3
0
def run_experiment(collaborator_dict, override_config=None):
    """
    Core function that executes the FL Plan.

    All collaborators are simulated in one process with a single shared
    model object (the last value of ``collaborator_dict``). Before each
    collaborator runs, the model's data loader is swapped for that
    collaborator's loader and, after round 0, the collaborator's own
    previous tensors (including optimizer variables) are restored.

    Args:
        collaborator_dict : dict {collaborator_name(str): FederatedModel}
            This dictionary defines which collaborators will participate in the
            experiment, as well as a reference to that collaborator's
            federated model.
        override_config : dict {flplan.key : flplan.value}, optional
            Override any of the plan parameters at runtime using this
            dictionary. To get a list of the available options, execute
            `fx.get_plan()`. ``None`` (the default) or an empty dict leaves
            the plan untouched.

    Returns:
        final_federated_model : FederatedModel
            The final model resulting from the federated learning experiment
    """
    from sys import path

    this_file = Path(__file__).resolve()
    root = this_file.parent.resolve()  # interface root, containing command modules
    work = Path.cwd().resolve()

    path.append(str(root))
    path.insert(0, str(work))

    # Update the plan if necessary. The default is None rather than {} so
    # that no dict object is shared between calls.
    if override_config:
        update_plan(override_config)

    # TODO: Fix this implementation. The full plan parsing is reused here,
    # but the model and data will be overwritten based on user specifications
    plan_config = 'plan/plan.yaml'
    cols_config = 'plan/cols.yaml'
    data_config = 'plan/data.yaml'

    plan = Plan.Parse(plan_config_path=Path(plan_config),
                      cols_config_path=Path(cols_config),
                      data_config_path=Path(data_config))

    # Overwrite plan values
    plan.authorized_cols = list(collaborator_dict)
    tensor_pipe = plan.get_tensor_pipe()

    # This must be set to the final index of the list (this is the last
    # tensorflow session to get created)
    plan.runner_ = list(collaborator_dict.values())[-1]
    model = plan.runner_

    # Initialize model weights
    aggregator_settings = plan.config['aggregator']['settings']
    init_state_path = aggregator_settings['init_state_path']
    rounds_to_train = aggregator_settings['rounds_to_train']
    # The held-out parameters are not needed for the initial snapshot.
    tensor_dict, _ = split_tensor_dict_for_holdouts(
        logger, plan.runner_.get_tensor_dict(False))

    model_snap = utils.construct_model_proto(tensor_dict=tensor_dict,
                                             round_number=0,
                                             tensor_pipe=tensor_pipe)

    logger.info(f'Creating Initial Weights File    🠆 {init_state_path}')

    utils.dump_proto(model_proto=model_snap, fpath=init_state_path)

    logger.info('Starting Experiment...')

    aggregator = plan.get_aggregator()

    # Last known tensors of each collaborator; None until it has run once.
    model_states = dict.fromkeys(collaborator_dict)

    # Create the collaborators
    collaborators = {
        collaborator: create_collaborator(plan, collaborator, model,
                                          aggregator)
        for collaborator in plan.authorized_cols
    }

    for round_num in range(rounds_to_train):
        for col in plan.authorized_cols:

            collaborator = collaborators[col]
            model.set_data_loader(collaborator_dict[col].data_loader)

            # Round 0 starts from the freshly initialised shared model.
            if round_num != 0:
                model.rebuild_model(round_num, model_states[col])

            collaborator.run_simulation()

            model_states[col] = model.get_tensor_dict(with_opt_vars=True)

    # Set the weights for the final model
    model.rebuild_model(rounds_to_train - 1,
                        aggregator.last_tensor_dict,
                        validation=True)
    return model
Ejemplo n.º 4
0
    def __init__(self,
                 aggregator_uuid,
                 federation_uuid,
                 authorized_cols,
                 init_state_path,
                 best_state_path,
                 last_state_path,
                 assigner,
                 rounds_to_train=256,
                 single_col_cert_common_name=None,
                 compression_pipeline=None,
                 db_store_rounds=1,
                 **kwargs):
        """
        Initialize the aggregator state and load the initial model.

        Args:
            aggregator_uuid: stored as ``self.uuid``; also used to build
                ``self.log_dir``.
            federation_uuid: stored as-is; also used to build
                ``self.log_dir``.
            authorized_cols: stored as-is.
            init_state_path: file the initial model proto is loaded from
                when no ``initial_tensor_dict`` keyword is supplied.
            best_state_path: stored as-is.
            last_state_path: stored as-is.
            assigner: stored as-is.
            rounds_to_train: stored as-is.
            single_col_cert_common_name: when not None a warning is issued
                through ``_log_big_warning``; when None the attribute is
                stored as ``''``.
            compression_pipeline: falls back to ``NoCompressionPipeline()``
                when falsy.
            db_store_rounds: stored as-is.
            **kwargs: only ``initial_tensor_dict`` is read here; when it is
                present and not None the model is built from it instead of
                being loaded from ``init_state_path``.
        """
        self.round_number = 0

        self.single_col_cert_common_name = single_col_cert_common_name
        if single_col_cert_common_name is None:
            # FIXME: '' instead of None is just for protobuf compatibility.
            # Cleaner solution?
            self.single_col_cert_common_name = ''
        else:
            self._log_big_warning()

        # Identity and federation membership.
        self.uuid = aggregator_uuid
        self.federation_uuid = federation_uuid
        self.authorized_cols = authorized_cols
        self.assigner = assigner
        self.rounds_to_train = rounds_to_train
        self.quit_job_sent_to = []

        # Model file locations.
        self.init_state_path = init_state_path
        self.best_state_path = best_state_path
        self.last_state_path = last_state_path

        # Tensor storage and (de)compression.
        self.db_store_rounds = db_store_rounds
        self.tensor_db = TensorDB()
        if not compression_pipeline:
            compression_pipeline = NoCompressionPipeline()
        self.compression_pipeline = compression_pipeline
        self.tensor_codec = TensorCodec(self.compression_pipeline)
        self.logger = getLogger(__name__)

        self.best_tensor_dict: dict = {}
        self.last_tensor_dict: dict = {}
        self.best_model_score = None

        # Seed the model either from an in-memory tensor dict or from the
        # initial weights file on disk.
        initial_tensor_dict = kwargs.get('initial_tensor_dict')
        if initial_tensor_dict is None:
            self.model: ModelProto = utils.load_proto(self.init_state_path)
            self._load_initial_tensors()  # keys are TensorKeys
        else:
            self._load_initial_tensors_from_dict(initial_tensor_dict)
            self.model = utils.construct_model_proto(
                tensor_dict=initial_tensor_dict,
                round_number=0,
                tensor_pipe=self.compression_pipeline)

        self.log_dir = f'logs/{self.uuid}_{self.federation_uuid}'
        # TODO use native tensorboard
        # self.tb_writer = tb.SummaryWriter(self.log_dir, flush_secs = 10)

        # {TensorKey: nparray}
        self.collaborator_tensor_results = {}

        # these enable getting all tensors for a task
        # {TaskResultKey: list of TensorKeys}
        self.collaborator_tasks_results = {}
        # {TaskResultKey: data_size}
        self.collaborator_task_weight = {}
Ejemplo n.º 5
0
    def fit(self):
        """
        Run the estimator as a simulated federation.

        The plan is parsed from the workspace, an initial weights file is
        written, and one task runner is built per authorized collaborator
        (all of them wrap ``self.estimator``). Collaborators are then run
        one after another for ``rounds_to_train`` rounds, saving and
        restoring the estimator system state under ``save/<col>_state``.

        Returns:
            The ``model`` attribute of the runner of the last collaborator
            that was run, or ``None`` if nothing was run.
        """
        import fastestimator as fe
        from fastestimator.trace.io.best_model_saver import BestModelSaver
        from sys import path

        # interface root, containing command modules
        root = Path(__file__).resolve().parent.resolve()
        work = Path.cwd().resolve()

        path.append(str(root))
        path.insert(0, str(work))

        # TODO: Fix this implementation. The full plan parsing is reused here,
        # but the model and data will be overwritten based on
        # user specifications
        plan = Plan.Parse(
            plan_config_path=Path(fx.WORKSPACE_PREFIX) / 'plan' / 'plan.yaml',
            cols_config_path=Path(fx.WORKSPACE_PREFIX) / 'plan' / 'cols.yaml',
            data_config_path=Path(fx.WORKSPACE_PREFIX) / 'plan' / 'data.yaml')

        self.rounds = plan.config['aggregator']['settings']['rounds_to_train']

        # A runner over the unsplit pipeline, used only to extract the
        # initial tensors.
        initial_loader = FastEstimatorDataLoader(self.estimator.pipeline)
        initial_runner = FastEstimatorTaskRunner(self.estimator,
                                                 data_loader=initial_loader)
        tensor_pipe = plan.get_tensor_pipe()

        # Initialize model weights
        init_state_path = plan.config['aggregator']['settings'][
            'init_state_path']
        tensor_dict, _holdouts = split_tensor_dict_for_holdouts(
            self.logger, initial_runner.get_tensor_dict(False))
        initial_proto = utils.construct_model_proto(tensor_dict=tensor_dict,
                                                    round_number=0,
                                                    tensor_pipe=tensor_pipe)

        self.logger.info(f'Creating Initial Weights File'
                         f'    🠆 {init_state_path}')
        utils.dump_proto(model_proto=initial_proto, fpath=init_state_path)

        self.logger.info('Starting Experiment...')

        aggregator = plan.get_aggregator()

        # Last known tensors per collaborator; None until it has run once.
        model_states = dict.fromkeys(plan.authorized_cols)
        runners = {}
        save_dir = {}
        copied_pipeline_attrs = [
            'batch_size', 'ops', 'num_process', 'drop_last', 'pad_value',
            'collate_fn'
        ]
        # Shard indices handed to split_data start at 1.
        for shard_index, col in enumerate(plan.authorized_cols, start=1):
            data = self.estimator.pipeline.data
            train_data, eval_data, test_data = split_data(
                data['train'], data['eval'], data['test'], shard_index,
                len(plan.authorized_cols))

            # Reuse the original pipeline settings, swapping in this
            # collaborator's share of the data.
            pipeline_kwargs = {
                attr: value
                for attr, value in self.estimator.pipeline.__dict__.items()
                if attr in copied_pipeline_attrs
            }
            pipeline_kwargs['train_data'] = train_data
            pipeline_kwargs['eval_data'] = eval_data
            pipeline_kwargs['test_data'] = test_data
            pipeline = fe.Pipeline(**pipeline_kwargs)

            col_loader = FastEstimatorDataLoader(pipeline)
            self.estimator.system.pipeline = pipeline

            col_runner = FastEstimatorTaskRunner(estimator=self.estimator,
                                                 data_loader=col_loader)
            runners[col] = col_runner
            col_runner.set_optimizer_treatment('CONTINUE_LOCAL')

            # Give each collaborator its own BestModelSaver directory.
            for trace in col_runner.estimator.system.traces:
                if not isinstance(trace, BestModelSaver):
                    continue
                col_save_dir = f'{trace.save_dir}/{col}'
                os.makedirs(col_save_dir, exist_ok=True)
                save_dir[col] = col_save_dir

        # Create the collaborators
        collaborators = {
            name: fx.create_collaborator(plan, name, runners[name],
                                         aggregator)
            for name in plan.authorized_cols
        }

        model = None
        for round_num in range(self.rounds):
            for col in plan.authorized_cols:
                collaborator = collaborators[col]
                col_runner = runners[col]
                state_path = f'save/{col}_state'

                if round_num > 0:
                    # For FastEstimator Jupyter notebook, models must be
                    # saved in different directories (i.e. path must be
                    # reset here)
                    col_runner.estimator.system.load_state(state_path)
                    col_runner.rebuild_model(round_num, model_states[col])

                # Reset the save directory if BestModelSaver is present
                # in traces
                for trace in col_runner.estimator.system.traces:
                    if isinstance(trace, BestModelSaver):
                        trace.save_dir = save_dir[col]

                collaborator.run_simulation()

                model_states[col] = col_runner.get_tensor_dict(
                    with_opt_vars=True)
                model = col_runner.model
                col_runner.estimator.system.save_state(state_path)

        # TODO This will return the model from the last collaborator,
        #  NOT the final aggregated model (though they should be similar).
        # There should be a method added to the aggregator that will load
        # the best model from disk and return it
        return model
Ejemplo n.º 6
0
def initialize(context, plan_config, cols_config, data_config,
               aggregator_address, feature_shape):
    """
    Initialize Data Science plan.

    Create a protocol buffer file of the initial model weights for
     the federation.

    The weights come from a task runner built with the data loader of the
    first collaborator listed in the plan's data paths. The aggregator
    address in the plan file is rewritten when it is ``'auto'`` or when
    ``aggregator_address`` is given.

    Args:
        context: object whose ``obj`` mapping records initialized plans
            under the ``'plans'`` key.
        plan_config: path to the plan file; it is also the file rewritten
            when the aggregator address is patched.
        cols_config: path to the collaborators file.
        data_config: path to the data file.
        aggregator_address: address to write into the plan; when falsy and
            the plan says ``'auto'``, ``getfqdn()`` is used instead.
        feature_shape: currently unused by this function.
    """
    plan = Plan.Parse(plan_config_path=Path(plan_config),
                      cols_config_path=Path(cols_config),
                      data_config_path=Path(data_config))

    init_state_path = plan.config['aggregator']['settings']['init_state_path']

    # TODO:  Is this part really needed?  Why would we need to collaborator
    #  name to know the input shape to the model? (feature_shape is ignored
    #  and the first collaborator's data loader is always used.)
    print(plan.cols_data_paths)

    collaborator_cname = list(plan.cols_data_paths)[0]

    data_loader = plan.get_data_loader(collaborator_cname)
    task_runner = plan.get_task_runner(data_loader)
    tensor_pipe = plan.get_tensor_pipe()

    # I believe there is no need for this line as task_runner has this variable
    # initialized with empty dict tensor_dict_split_fn_kwargs =
    # task_runner.tensor_dict_split_fn_kwargs or {}
    tensor_dict, holdout_params = split_tensor_dict_for_holdouts(
        logger, task_runner.get_tensor_dict(False),
        **task_runner.tensor_dict_split_fn_kwargs)

    # Logger.warn is a deprecated alias; use Logger.warning.
    logger.warning(f'Following parameters omitted from global initial model, '
                   f'local initialization will determine'
                   f' values: {list(holdout_params.keys())}')

    model_snap = utils.construct_model_proto(tensor_dict=tensor_dict,
                                             round_number=0,
                                             tensor_pipe=tensor_pipe)

    logger.info(f'Creating Initial Weights File    🠆 {init_state_path}')

    utils.dump_proto(model_proto=model_snap, fpath=init_state_path)

    plan_origin = Plan.Parse(Path(plan_config), resolve=False).config

    # Default the settings section BEFORE reading 'agg_addr' from it, so a
    # plan without network settings no longer fails with KeyError here.
    network_settings = plan_origin['network'].get('settings', {})

    if network_settings.get('agg_addr') == 'auto' or aggregator_address:
        network_settings['agg_addr'] = aggregator_address or getfqdn()
        plan_origin['network']['settings'] = network_settings

        logger.warning(f"Patching Aggregator Addr in Plan"
                       f" 🠆 {plan_origin['network']['settings']['agg_addr']}")

        Plan.Dump(Path(plan_config), plan_origin)

    plan.config = plan_origin

    # Record that plan with this hash has been initialized
    if 'plans' not in context.obj:
        context.obj['plans'] = []
    context.obj['plans'].append(f"{Path(plan_config).stem}_{plan.hash[:8]}")
    logger.info(f"{context.obj['plans']}")