    def save_aggregated_model(self):
        """Saves the current aggregated model proto to the save-all-models directory."""
        if self.save_all_models_path is None:
            return
            return

        dir_path = self.ensure_save_all_path_exists()

        dump_proto(self.model, os.path.join(dir_path, "aggregated.pbuf"))

    def save_local_update(self, collaborator, update):
        """Saves a collaborator's local model update proto to the save-all-models directory.

        Args:
            collaborator (str): The collaborator common name.
            update: The model update protobuf received from the collaborator.
        """
        if self.save_all_models_path is None:
            return

        # FIXME: better user experience would be good
        # hash the collaborator name so the resulting file name is guaranteed
        # to be legal on the filesystem
        md5 = hashlib.md5()
        md5.update(collaborator.encode())
        hashed_col = md5.hexdigest()[:8]
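        # NOTE: truncating the digest to 8 hex characters (32 bits) keeps the
        # file names short; collisions are unlikely but possible in very
        # large federations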

        dir_path = self.ensure_save_all_path_exists()

        dump_proto(update, os.path.join(dir_path,
                                        "{}.pbuf".format(hashed_col)))

    def end_of_round(self):
        """Runs required tasks when the training round has ended.
        """
        # FIXME: what all should we do to track results/metrics? It should really be an easy, extensible solution

        # compute the weighted loss average
        round_loss = self.get_weighted_average_of_collaborators(
            self.per_col_round_stats["loss_results"],
            self.per_col_round_stats["collaborator_training_sizes"])

        # compute the weighted validation average
        round_val = self.get_weighted_average_of_collaborators(
            self.per_col_round_stats["agg_validation_results"],
            self.per_col_round_stats["collaborator_validation_sizes"])

        # FIXME: is it correct to put this in the metadata?
        self.metadata_for_round.update({
            'loss': round_loss,
            'round_{}_validation'.format(self.round_num - 1): round_val,
        })
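
        # illustrative shape of metadata_for_round at this point (keys and
        # values are hypothetical):
        #   {'loss': 0.25, 'round_4_validation': 0.91}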

        # FIXME: proper logging
        self.logger.info('round results for model id/version {}/{}'.format(
            self.model.header.id, self.model.header.version))
        self.logger.info('\tvalidation: {}'.format(round_val))
        self.logger.info('\tloss: {}'.format(round_loss))

        # construct the model protobuf from in progress tensors (with incremented version number)
        self.model = construct_proto(
            tensor_dict=self.model_update_in_progress["tensor_dict"],
            model_id=self.model.header.id,
            model_version=self.model.header.version + 1,
            is_delta=self.model_update_in_progress["is_delta"],
            delta_from_version=self.model_update_in_progress["delta_from_version"],
            compression_pipeline=self.compression_pipeline)
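
        # NOTE (assumed semantics): is_delta / delta_from_version indicate
        # whether the tensors encode a difference from an earlier model
        # version rather than full weights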

        # add end of round metadata
        self.metadata_for_round.update(self.get_end_of_round_metadata())

        # add the metadata for this round to the total metadata file
        self.metadata['round {}'.format(
            self.round_num)] = self.metadata_for_round
        self.metadata_for_round = {}

        self.logger.info("Metadata:\n{}".format(yaml.dump(self.metadata)))

        if self.latest_metadata_fname is not None:
            with open(self.latest_metadata_fname, 'w') as f:
                f.write(yaml.dump(self.metadata))
            self.logger.info("Wrote metadata to {}".format(
                self.latest_metadata_fname))

        # Save the new model as latest model.
        dump_proto(self.model, self.latest_model_fpath)

        # if configured, also save to the backup location
        if self.save_all_models_path is not None:
            self.save_aggregated_model()

        # if round_val is a dictionary (assumed to be one level deep), base
        # the best-model decision on the average of its values
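        # e.g. round_val == {'dice': 0.8, 'acc': 0.9} yields
        # model_score == np.average([0.8, 0.9]) == 0.85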
        if isinstance(round_val, dict):
            model_score = np.average(list(round_val.values()))
        else:
            model_score = round_val
        if self.best_model_score is None or self.best_model_score < model_score:
            self.logger.info(
                "Saving new best model with score {:f}.".format(model_score))
            self.best_model_score = model_score
            # Save a model proto version to file as current best model.
            dump_proto(self.model, self.best_model_fpath)
            self.aggregated_model_is_global_best = True
        else:
            self.aggregated_model_is_global_best = False

        # clear the update pointer
        self.model_update_in_progress = None

        # if we have enabled runtime configuration updates, do that now
        if self.runtime_aggregator_config_dir is not None:
            self.update_config_from_filesystem()

        self.init_per_col_round_stats()

        self.round_num += 1
        self.logger.debug("Start a new round %d." % self.round_num)
        self.round_start_time = None

        self._do_quit = self._GRACEFULLY_QUIT
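        # (assumed semantics: _GRACEFULLY_QUIT is set when a shutdown was
        # requested mid-round, and copying it into _do_quit lets the serving
        # loop exit cleanly at the round boundary)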


def main(plan, native_model_weights_filepath, collaborators_file,
         feature_shape, n_classes, data_config_fname, logging_config_path,
         logging_default_level, model_device):
    """Creates a protobuf file of the initial weights for the model.

    Uses the federation (FL) plan to create an initial weights file
    for the federation.

    Args:
        plan: The federation (FL) plan filename
        native_model_weights_filepath: A framework-specific filepath. Path will be relative to the working directory.
        collaborators_file: The file listing collaborator common names; used to
            build a data object (and so infer the input shape) when
            feature_shape is not given
        feature_shape: The input shape to the model
        n_classes: The number of classes in the dataset (passed to the data object)
        data_config_fname: The data configuration file (defines where the datasets are located)
        logging_config_path: The logging configuration file path
        logging_default_level (int): The default log level
        model_device: The device (e.g. 'cpu' or 'cuda') on which the model is constructed

    """

    setup_logging(path=logging_config_path, default_level=logging_default_level)

    logger = logging.getLogger(__name__)

    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    weights_dir = os.path.join(base_dir, 'weights')

    # ensure the weights dir exists
    if not os.path.exists(weights_dir):
        logger.info('Creating folder: {}'.format(weights_dir))
        os.makedirs(weights_dir)

    # parse the plan and local config
    flplan = parse_fl_plan(os.path.join(plan_dir, plan))
    local_config = load_yaml(os.path.join(base_dir, data_config_fname))

    # get the output filename
    fpath = os.path.join(weights_dir, flplan['aggregator_object_init']['init_kwargs']['init_model_fname'])
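
    # Illustrative plan snippet supplying the keys read here and further
    # below (the file name and value shown are hypothetical):
    #
    #   aggregator_object_init:
    #     init_kwargs:
    #       init_model_fname: init_weights.pbuf
    #   collaborator_object_init:
    #     init_kwargs:
    #       opt_treatment: CONTINUE_GLOBAL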

    # create the data object for models whose architecture depends on the feature shape
    if feature_shape is None:
        if collaborators_file is None:
            sys.exit("You must specify either a feature shape or a collaborator list in order for the script to determine the input layer shape")
        # FIXME: this will ultimately run in a governor environment and should not require any data to work
        # pick the first collaborator to create the data and model (could be any)
        collaborator_common_name = load_yaml(os.path.join(base_dir, 'collaborator_lists', collaborators_file))['collaborator_common_names'][0]
        data = create_data_object(flplan, collaborator_common_name, local_config, n_classes=n_classes)
    else:
        data = get_object('openfl.data.dummy.randomdata', 'RandomData', feature_shape=feature_shape)
        logger.info('Using data object of type {} and feature shape {}'.format(type(data), feature_shape))

    # create the model object and compression pipeline
    wrapped_model = create_model_object(flplan, data, model_device=model_device)
    compression_pipeline = create_compression_pipeline(flplan)

    # determine if we need to store the optimizer variables
    # FIXME: what if this key is missing?
    try:
        opt_treatment = OptTreatment[flplan['collaborator_object_init']['init_kwargs']['opt_treatment']]
    except KeyError:
        # FIXME: this error message should use the exception to determine the missing key and the Enum to display the options dynamically
        sys.exit("FL plan must specify ['collaborator_object_init']['init_kwargs']['opt_treatment'] as [RESET|CONTINUE_LOCAL|CONTINUE_GLOBAL]")

    # FIXME: this should be an "opt_treatment requires parameters" type check rather than a magic string
    with_opt_vars = opt_treatment == OptTreatment['CONTINUE_GLOBAL']
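    # CONTINUE_GLOBAL appears to mean collaborators resume from globally
    # shared optimizer state, so the optimizer variables must be included in
    # the initial weights file; RESET and CONTINUE_LOCAL keep optimizer
    # state out of the global model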

    if native_model_weights_filepath is not None:
        wrapped_model.load_native(native_model_weights_filepath)

    tensor_dict_split_fn_kwargs = wrapped_model.tensor_dict_split_fn_kwargs or {}

    tensor_dict, holdout_params = split_tensor_dict_for_holdouts(logger,
                                                                 wrapped_model.get_tensor_dict(with_opt_vars=with_opt_vars),
                                                                 **tensor_dict_split_fn_kwargs)
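    # split_tensor_dict_for_holdouts is assumed to partition the model's
    # tensors into those shipped in the initial global model (tensor_dict)
    # and those held out for local initialization (holdout_params),
    # according to tensor_dict_split_fn_kwargs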
    logger.warning('Following parameters omitted from the global initial model; '
                   'local initialization will determine their values: {}'.format(
                       list(holdout_params.keys())))

    model_proto = construct_proto(tensor_dict=tensor_dict,
                                  model_id=wrapped_model.__class__.__name__,
                                  model_version=0,
                                  is_delta=False,
                                  delta_from_version=-1,
                                  compression_pipeline=compression_pipeline)

    dump_proto(model_proto=model_proto, fpath=fpath)

    logger.info("Created initial weights file: {}".format(fpath))