def main(plan, out_filepath):

    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')

    flplan = yaml.dump(parse_fl_plan(os.path.join(plan_dir, plan)))

    print(flplan)
    if out_filepath is not None:
        with open(out_filepath, 'w') as f:
            f.write(flplan)
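
A minimal usage sketch for this entry point (the plan and output filenames below are hypothetical, not taken from this listing): it resolves plan against the federations/plans/ directory next to the script, prints the dumped YAML, and optionally writes it out.

# Hypothetical usage sketch -- 'an_fl_plan.yaml' and 'dumped_plan.yaml'
# are illustrative names only.
main(plan='an_fl_plan.yaml', out_filepath='dumped_plan.yaml')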
Example #2
def main(plan, collaborators_file, data_config_fname, logging_config_path, logging_default_level, logging_directory, model_device, **kwargs):
    """Run the federation simulation from the federation (FL) plan.

    Runs a federated training from the federation (FL) plan but creates the
    aggregator and collaborators on the same compute node. This allows
    the developer to test the model and data loaders before running
    on the remote collaborator nodes.

    Args:
        plan: The Federation (FL) plan (YAML file)
        collaborators_file: The file listing the collaborators
        data_config_fname: The file describing where the dataset is located on the collaborators
        logging_config_path: The logging configuration file
        logging_default_level: The log level
        logging_directory: The directory for log output
        model_device: The device (e.g. 'cpu' or 'cuda') passed to the model
        **kwargs: Variable parameters to pass to the function

    """
    # FIXME: consistent filesystem (#15)
    # establish location for fl plan as well as
    # where to get and write model protobufs
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    weights_dir = os.path.join(base_dir, 'weights')
    metadata_dir = os.path.join(base_dir, 'metadata')
    collaborators_dir = os.path.join(base_dir, 'collaborator_lists')
    logging_config_path = os.path.join(script_dir, logging_config_path)
    logging_directory = os.path.join(script_dir, logging_directory)

    setup_logging(path=logging_config_path, default_level=logging_default_level, logging_directory=logging_directory)

    # load the flplan, local_config and collaborators file
    flplan = parse_fl_plan(os.path.join(plan_dir, plan))
    local_config = load_yaml(os.path.join(base_dir, data_config_fname))
    collaborator_common_names = load_yaml(os.path.join(collaborators_dir, collaborators_file))['collaborator_common_names']
  
    # TODO: Run a loop here over various parameter values and iterations
    # TODO: implement more than just saving init, best, and latest model
    federate(flplan,
             local_config,
             collaborator_common_names,
             base_dir,
             weights_dir,
             metadata_dir,
             model_device)
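
For reference, the collaborator list loaded above only needs a top-level 'collaborator_common_names' key; a minimal sketch of the parsed YAML as a Python dict (the collaborator names are hypothetical):

# Expected shape of load_yaml(os.path.join(collaborators_dir, collaborators_file));
# 'col_one' and 'col_two' are hypothetical collaborator common names.
parsed = {'collaborator_common_names': ['col_one', 'col_two']}
collaborator_common_names = parsed['collaborator_common_names']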
Example #3
def main(plan, collaborators_file, single_col_cert_common_name,
         logging_config_path, logging_default_level, logging_directory, resume,
         script_dir):
    """Runs the aggregator service from the Federation (FL) plan

    Args:
        plan: The Federation (FL) plan
        collaborators_file: The file listing the collaborators
        single_col_cert_common_name: The common name used when all collaborators share a single SSL certificate
        logging_config_path: The log configuration file
        logging_default_level: The log level
        logging_directory: The directory for log output
        resume: Whether the aggregator should load the latest model instead of the initial model
        script_dir: If None, this script's directory is used; otherwise, the given directory is used as the script dir
    """
    # FIXME: consistent filesystem (#15)
    if script_dir is None:
        script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    weights_dir = os.path.join(base_dir, 'weights')
    collaborators_dir = os.path.join(base_dir, 'collaborator_lists')
    metadata_dir = os.path.join(base_dir, 'metadata')
    logging_config_path = os.path.join(script_dir, logging_config_path)
    logging_directory = os.path.join(script_dir, logging_directory)

    setup_logging(path=logging_config_path,
                  default_level=logging_default_level,
                  logging_directory=logging_directory)

    flplan = parse_fl_plan(os.path.join(plan_dir, plan))
    collaborator_common_names = load_yaml(
        os.path.join(collaborators_dir,
                     collaborators_file))['collaborator_common_names']

    agg = create_aggregator_object_from_flplan(flplan,
                                               collaborator_common_names,
                                               single_col_cert_common_name,
                                               base_dir, weights_dir,
                                               metadata_dir, resume)
    server = create_aggregator_server_from_flplan(agg, flplan)
    serve_kwargs = get_serve_kwargs_from_flpan(flplan, base_dir)

    server.serve(**serve_kwargs)
def main(plan, collaborators_file, single_col_cert_common_name,
         logging_config_path, logging_default_level, logging_directory):
    """Runs the aggregator service from the Federation (FL) plan

    Args:
        plan: The Federation (FL) plan
        collaborators_file: The file listing the collaborators
        single_col_cert_common_name: The common name used when all collaborators share a single SSL certificate
        logging_config_path: The log configuration file
        logging_default_level: The log level
        logging_directory: The directory for log output

    """
    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    weights_dir = os.path.join(base_dir, 'weights')
    collaborators_dir = os.path.join(base_dir, 'collaborator_lists')
    metadata_dir = os.path.join(base_dir, 'metadata')
    logging_config_path = os.path.join(script_dir, logging_config_path)
    logging_directory = os.path.join(script_dir, logging_directory)

    setup_logging(path=logging_config_path,
                  default_level=logging_default_level,
                  logging_directory=logging_directory)

    flplan = parse_fl_plan(os.path.join(plan_dir, plan))
    collaborator_common_names = load_yaml(
        os.path.join(collaborators_dir,
                     collaborators_file))['collaborator_common_names']

    agg = create_aggregator_object_from_flplan(flplan,
                                               collaborator_common_names,
                                               single_col_cert_common_name,
                                               weights_dir, metadata_dir)
    server = create_aggregator_server_from_flplan(agg, flplan)
    serve_kwargs = get_serve_kwargs_from_flpan(flplan, base_dir)

    server.serve(**serve_kwargs)
def main(plan,
         model_weights_filename,
         native_model_weights_filepath,
         populate_weights_at_init,
         model_file_argument_name,
         data_dir,
         logging_config_path,
         logging_default_level,
         logging_directory,
         model_device,
         inference_patient=None):
    """Runs the inference according to the flplan, data-dir and weights file. Output format is determined by the data object in the flplan

    Args:
        plan (string)                           : The filename for the federation (FL) plan YAML file
        model_weights_filename (string)         : A .pbuf filename in the common weights directory (mutually exclusive with native_model_weights_filepath). NOTE: these must be uncompressed weights!!
        native_model_weights_filepath (string)  : A framework-specific filepath. Path will be relative to the working directory. (mutually exclusive with model_weights_filename)
        populate_weights_at_init (boolean)      : Whether or not the model populates its own weights at instantiation
        model_file_argument_name (string)       : Name of argument to be passed to model __init__ providing model file location info
        data_dir (string)                       : The directory path for the parent directory containing the data. Path will be relative to the working directory.
        logging_config_path (string)            : The logging configuration file
        logging_default_level (string)          : The log level
        logging_directory (string)              : The directory for log output
        model_device (string)                   : The device (e.g. 'cpu' or 'cuda') passed to the model
        inference_patient (string)              : Subdirectory of a single patient to run inference on (exclusively)

    """

    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    logging_directory = os.path.join(script_dir, logging_directory)

    setup_logging(path=logging_config_path,
                  default_level=logging_default_level,
                  logging_directory=logging_directory)

    flplan = parse_fl_plan(os.path.join(plan_dir, plan))

    # check the inference config
    if 'inference' not in flplan:
        sys.exit(
            "FL Plan does not contain a top-level 'inference' entry. By default, inference is disabled."
        )

    if 'allowed' not in flplan['inference'] or flplan['inference']['allowed'] != True:
        sys.exit(
            "FL Plan must contain an {'inference': {'allowed': True}} entry in order for inference to be allowed."
        )

    # create the data object
    data = create_data_object_with_explicit_data_path(
        flplan=flplan, data_path=data_dir, inference_patient=inference_patient)

    # TODO: Find a good way to detect and communicate mishandling of model_file_argument_name
    #       i.e., capture the exception of the model not getting its required kwarg for this purpose; also,
    #       how to tell if the model is using its random initialization rather than weights from file?
    if populate_weights_at_init:
        # Supplementing the flplan base model kwargs to include model weights file info.
        if model_weights_filename is not None:
            model_file_argument_name = model_file_argument_name or 'model_weights_filename'
            flplan['model_object_init']['init_kwargs'].update(
                {model_file_argument_name: model_weights_filename})
        # model_weights_filename and native_model_weights_filepath are mutually exclusive and required (see argument parser)
        else:
            model_file_argument_name = model_file_argument_name or 'native_model_weights_filepath'
            flplan['model_object_init']['init_kwargs'].update(
                {model_file_argument_name: native_model_weights_filepath})

    # create the base model object
    model = create_model_object(flplan, data, model_device=model_device)

    # the model may have an 'infer_volume' method instead of 'infer_batch'
    if not hasattr(model, 'infer_batch'):
        if hasattr(model, 'infer_volume'):
            model = InferenceOnlyModelWrapper(data=data, base_model=model)
        elif not hasattr(model, 'run_inference_and_store_results'):
            sys.exit(
                "If model object does not have a 'run_inference_and_store_results' method, it must have either an 'infer_batch' or 'infer_volume' method."
            )

    if not populate_weights_at_init:
        # if pbuf weights, we need to run deconstruct proto with a NoCompression pipeline
        if model_weights_filename is not None:
            proto_path = os.path.join(base_dir, 'weights',
                                      model_weights_filename)
            proto = load_proto(proto_path)
            tensor_dict_from_proto = deconstruct_proto(proto,
                                                       NoCompressionPipeline())

            # restore any tensors held out from the proto
            _, holdout_tensors = remove_and_save_holdout_tensors(
                model.get_tensor_dict())
            tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}

            model.set_tensor_dict(tensor_dict, with_opt_vars=False)

        # model_weights_filename and native_model_weights_filepath are mutually exclusive and required (see argument parser)
        else:
            # FIXME: how do we handle kwargs here? Will they come from the flplan?
            model.load_native(native_model_weights_filepath)

    # finally, call the model object's run_inference_and_store_results with the kwargs from the inference block
    inference_kwargs = flplan['inference'].get('kwargs') or {}
    model.run_inference_and_store_results(**inference_kwargs)
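
Putting the two checks above together, the parsed plan must carry an inference block shaped like the following before this entry point will run. The 'output_dir' kwarg is hypothetical; whatever kwargs the plan carries are forwarded verbatim to run_inference_and_store_results.

# Minimal 'inference' section the checks above require, as a parsed-plan fragment.
# 'output_dir' is a hypothetical kwarg, not taken from this listing.
flplan_fragment = {
    'inference': {
        'allowed': True,
        'kwargs': {'output_dir': 'inference_results'},
    },
}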
Example #6
def main(plan, collaborator_common_name, single_col_cert_common_name,
         data_config_fname, data_dir, validate_without_patches_flag,
         data_in_memory_flag, data_queue_max_length, data_queue_num_workers,
         torch_threads, kmp_affinity_flag, logging_config_path,
         logging_default_level, logging_directory, model_device):
    """Runs the collaborator client process from the federation (FL) plan

    Args:
        plan                            : The filename for the federation (FL) plan YAML file
        collaborator_common_name        : The common name for the collaborator node
        single_col_cert_common_name     : The common name used when collaborators share a single SSL certificate
        data_config_fname               : The dataset configuration filename (YAML)
        data_dir                        : Parent directory holding the patient data subdirectories (to be split into train and val)
        validate_without_patches_flag   : controls a model init kwarg
        data_in_memory_flag             : controls a data init kwarg
        data_queue_max_length           : controls a data init kwarg
        data_queue_num_workers          : controls a data init kwarg
        torch_threads                   : model init kwarg
        kmp_affinity_flag               : controls a model init kwarg
        logging_config_path             : The logging configuration file
        logging_default_level           : The log level
        model_device                    : gets passed to model 'init' function as "device"
    """
    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    weights_dir = os.path.join(base_dir, 'weights')
    metadata_dir = os.path.join(base_dir, 'metadata')
    logging_config_path = os.path.join(script_dir, logging_config_path)
    logging_directory = os.path.join(script_dir, logging_directory)

    setup_logging(path=logging_config_path,
                  default_level=logging_default_level,
                  logging_directory=logging_directory)

    flplan = parse_fl_plan(os.path.join(plan_dir, plan))

    # FIXME: Find a better solution for passing model and data init kwargs
    model_init_kwarg_keys = [
        'validate_without_patches', 'torch_threads', 'kmp_affinity'
    ]
    model_init_kwarg_vals = [
        validate_without_patches_flag, torch_threads, kmp_affinity_flag
    ]
    for key, value in zip(model_init_kwarg_keys, model_init_kwarg_vals):
        if (value is not None) and (value != False):
            flplan['model_object_init']['init_kwargs'][key] = value

    data_init_kwarg_keys = ['in_memory', 'q_max_length', 'q_num_workers']
    data_init_kwarg_vals = [
        data_in_memory_flag, data_queue_max_length, data_queue_num_workers
    ]
    for key, value in zip(data_init_kwarg_keys, data_init_kwarg_vals):
        if (value is not None) and (value != False):
            flplan['data_object_init']['init_kwargs'][key] = value

    local_config = load_yaml(os.path.join(base_dir, data_config_fname))

    try:
        collaborator = create_collaborator_object_from_flplan(
            flplan,
            collaborator_common_name,
            local_config,
            base_dir,
            weights_dir,
            metadata_dir,
            single_col_cert_common_name,
            data_dir=data_dir,
            model_device=model_device)

        collaborator.run()
        sys.exit(0)
    except Exception as e:
        logging.getLogger(__name__).exception(repr(e))
        # this is for Sarthak
        sys.exit(666)
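
The two zip loops above implement a conditional-override pattern: a command-line value replaces the plan's init kwarg only when it is neither None nor equal to False (note that a numeric 0 also compares equal to False and is skipped), so unset flags leave the flplan defaults intact. A self-contained sketch of the same pattern on a toy dict:

# Conditional-override pattern from above, demonstrated on a toy kwargs dict.
init_kwargs = {'torch_threads': 4}           # plan default
overrides = {'torch_threads': None,          # not set on the CLI -> default kept
             'kmp_affinity': True}           # set on the CLI     -> applied
for key, value in overrides.items():
    if (value is not None) and (value != False):
        init_kwargs[key] = value
assert init_kwargs == {'torch_threads': 4, 'kmp_affinity': True}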
def main(plan, native_model_weights_filepath, collaborators_file, feature_shape, n_classes, data_config_fname, logging_config_path, logging_default_level, model_device):
    """Creates a protobuf file of the initial weights for the model

    Uses the federation (FL) plan to create an initial weights file
    for the federation.

    Args:
        plan: The federation (FL) plan filename
        native_model_weights_filepath: A framework-specific filepath. Path will be relative to the working directory.
        collaborators_file: The file listing the collaborators
        feature_shape: The input shape to the model
        n_classes: The number of classes (used when creating the data object)
        data_config_fname: The data configuration file (defines where the datasets are located)
        logging_config_path: The log path
        logging_default_level (int): The default log level
        model_device: The device (e.g. 'cpu' or 'cuda') passed to the model

    """

    setup_logging(path=logging_config_path, default_level=logging_default_level)

    logger = logging.getLogger(__name__)

    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    weights_dir = os.path.join(base_dir, 'weights')

    # ensure the weights dir exists
    if not os.path.exists(weights_dir):
        print('creating folder:', weights_dir)
        os.makedirs(weights_dir)

    # parse the plan and local config
    flplan = parse_fl_plan(os.path.join(plan_dir, plan))
    local_config = load_yaml(os.path.join(base_dir, data_config_fname))

    # get the output filename
    fpath = os.path.join(weights_dir, flplan['aggregator_object_init']['init_kwargs']['init_model_fname'])

    # create the data object for models whose architecture depends on the feature shape
    if feature_shape is None:
        if collaborators_file is None:
            sys.exit("You must specify either a feature shape or a collaborator list in order for the script to determine the input layer shape")
        # FIXME: this will ultimately run in a governor environment and should not require any data to work
        # pick the first collaborator to create the data and model (could be any)
        collaborator_common_name = load_yaml(os.path.join(base_dir, 'collaborator_lists', collaborators_file))['collaborator_common_names'][0]
        data = create_data_object(flplan, collaborator_common_name, local_config, n_classes=n_classes)
    else:
        data = get_object('openfl.data.dummy.randomdata', 'RandomData', feature_shape=feature_shape)
        logger.info('Using data object of type {} and feature shape {}'.format(type(data), feature_shape))

    # create the model object and compression pipeline
    wrapped_model = create_model_object(flplan, data, model_device=model_device)
    compression_pipeline = create_compression_pipeline(flplan)

    # determine if we need to store the optimizer variables
    # FIXME: what if this key is missing?
    try:
        opt_treatment = OptTreatment[flplan['collaborator_object_init']['init_kwargs']['opt_treatment']]
    except KeyError:
        # FIXME: this error message should use the exception to determine the missing key and the Enum to display the options dynamically
        sys.exit("FL plan must specify ['collaborator_object_init']['init_kwargs']['opt_treatment'] as [RESET|CONTINUE_LOCAL|CONTINUE_GLOBAL]")

    # FIXME: this should be an "opt_treatment requires parameters" type check rather than a magic string
    with_opt_vars = opt_treatment == OptTreatment['CONTINUE_GLOBAL']

    if native_model_weights_filepath is not None:
        wrapped_model.load_native(native_model_weights_filepath)
    
    tensor_dict_split_fn_kwargs = wrapped_model.tensor_dict_split_fn_kwargs or {}

    tensor_dict, holdout_params = split_tensor_dict_for_holdouts(logger,
                                                                 wrapped_model.get_tensor_dict(with_opt_vars=with_opt_vars),
                                                                 **tensor_dict_split_fn_kwargs)
    logger.warning('Following parameters omitted from global initial model, '\
                   'local initialization will determine values: {}'.format(list(holdout_params.keys())))

    model_proto = construct_proto(tensor_dict=tensor_dict,
                                  model_id=wrapped_model.__class__.__name__,
                                  model_version=0,
                                  is_delta=False,
                                  delta_from_version=-1,
                                  compression_pipeline=compression_pipeline)

    dump_proto(model_proto=model_proto, fpath=fpath)

    logger.info("Created initial weights file: {}".format(fpath))
Example #8
def main(data_path,
         plan_path,
         model_weights_path,
         output_pardir,
         model_output_tag,
         device,
         legacy_model_flag=False):

    # TODO: We do not currently make use of the ability for brainmage to infer by first cropping external
    #       zero planes, or inference by patching and fusing.

    flplan = parse_fl_plan(plan_path)

    # make sure the class list we are using is compatible with the hard-coded class_label_map above
    if flplan['data_object_init']['init_kwargs']['class_list'] != class_list:
        raise ValueError('We currently only support class_list={}'.format(class_list))

    # construct the data object
    data = create_data_object_with_explicit_data_path(flplan=flplan,
                                                      data_path=data_path)

    # the code assumes we are using the GANDLF data object
    if not issubclass(data.__class__, GANDLFData):
        raise ValueError(
            'This script currently assumes a child of fets.data.pytorch.gandlf_data.GANDLFData; you are using: {}'.format(
                data.__class__.__name__))

    # construct the model object (requires cpu since we're passing [padded] whole brains)
    model = create_model_object(flplan=flplan,
                                data_object=data,
                                model_device=device)

    # the code assumes we are using the BrainMaGe model object
    if not issubclass(model.__class__, BrainMaGeModel):
        raise ValueError(
            'This script currently assumes a child of fets.models.pytorch.brainmage.BrainMaGeModel; you are using: {}'.format(
                model.__class__.__name__))

    # legacy models are defined in a single file, newer ones have a folder that holds per-layer files
    if legacy_model_flag:
        tensor_dict_from_proto = load_legacy_model_protobuf(model_weights_path)
    else:
        tensor_dict_from_proto = load_model(model_weights_path)

    # restore any tensors held out from the proto
    _, holdout_tensors = split_tensor_dict_for_holdouts(
        None, model.get_tensor_dict())
    tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}
    model.set_tensor_dict(tensor_dict, with_opt_vars=False)

    print("\nWill be running inference on {} validation samples.\n".format(
        model.get_validation_data_size()))

    if not os.path.exists(output_pardir):
        os.mkdir(output_pardir)

    subdir_to_DICE = {}
    dice_outpath = None

    for subject in data.get_val_loader():
        # mode '1' is used here because it is the only one that is always defined
        first_mode_path = subject['1']['path'][0]
        subfolder = first_mode_path.split('/')[-2]

        # prep the path for the output files
        output_subdir = os.path.join(output_pardir, subfolder)
        if not os.path.exists(output_subdir):
            os.mkdir(output_subdir)
        inference_outpath = os.path.join(
            output_subdir, subfolder + model_output_tag + '_seg.nii.gz')
        if dice_outpath is None:
            dice_outpath = os.path.join(
                output_pardir, model_output_tag + '_subdirs_to_DICE.pkl')

        if not is_mask_present(subject):
            raise ValueError(
                'We are expecting to run this on subjects that have labels.')

        label_path = subject['label']['path'][0]
        label_file = label_path.split('/')[-1]
        subdir_name = label_path.split('/')[-2]

        # copy the label file over to the output subdir
        copy_label_path = os.path.join(output_subdir, label_file)
        shutil.copyfile(label_path, copy_label_path)

        features, ground_truth = subject_to_feature_and_label(subject=subject,
                                                              pad_z=pad_z)

        output = infer(model, features)

        # FIXME: Find a better solution
        # crop away the padding we put in
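        # (BraTS-style volumes have 155 axial slices, so keeping only the first
        #  155 z-planes is assumed to undo the pad_z padding applied above)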
        output = output[:, :, :, :, :155]

        print(
            one_hot(segmask_array=ground_truth, class_list=class_list).shape,
            output.shape)

        # get the DICE score
        dice_dict = clinical_dice(output=output,
                                  target=one_hot(segmask_array=ground_truth,
                                                 class_list=class_list),
                                  class_list=class_list,
                                  to_scalar=True)

        subdir_to_DICE[subdir_name] = dice_dict

        output = np.squeeze(output.cpu().numpy())

        # GANDLFData loader produces transposed output from what sitk gets from file, so transposing here.
        output = np.transpose(output, [0, 3, 2, 1])

        # process float outputs (across output channels), assigning labels as defined in the values of class_label_map
        output = new_labels_from_float_output(array=output,
                                              class_label_map=class_label_map,
                                              binary_classification=False)

        # convert array to SimpleITK image
        image = sitk.GetImageFromArray(output)

        image.CopyInformation(sitk.ReadImage(first_mode_path))

        print("\nWriting inference NIfTI image of shape {} to {}".format(
            output.shape, inference_outpath))
        sitk.WriteImage(image, inference_outpath)
        print("\nCorresponding DICE scores were: ")
        print("{}\n\n".format(dice_dict))

    print("Saving subdir_name_to_DICE at: ", dice_outpath)
    with open(dice_outpath, 'wb') as _file:
        pkl.dump(subdir_to_DICE, _file)
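
Since the per-subdirectory DICE results are pickled at the end of the loop, they can be read back later; a short usage sketch (the path placeholder stands for whatever dice_outpath resolved to):

import pickle as pkl

# Read back the subdir -> DICE dictionary written above; replace the
# placeholder path with the actual dice_outpath.
with open('path/to/<model_output_tag>_subdirs_to_DICE.pkl', 'rb') as _file:
    subdir_to_DICE = pkl.load(_file)
for subdir_name, dice_dict in subdir_to_DICE.items():
    print(subdir_name, dice_dict)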
def main(plan, model_weights_filename, native_model_weights_filepath, data_dir, logging_config_path, logging_default_level, logging_directory, model_device):
    """Runs the inference according to the flplan, data-dir and weights file. Output format is determined by the data object in the flplan

    Args:
        plan (string)                   : The filename for the federation (FL) plan YAML file
        model_weights_filename (string) : A .pbuf filename in the common weights directory (mutually exclusive with native_model_weights_filepath). NOTE: these must be uncompressed weights!!
        native_model_weights_filepath   : A framework-specific filepath. Path will be relative to the working directory. (mutually exclusive with model_weights_filename)
        data_dir (string)               : The directory path for the parent directory containing the data. Path will be relative to the working directory.
        logging_config_path (string)    : The logging configuration file
        logging_default_level (string)  : The log level
        logging_directory (string)      : The directory for log output
        model_device (string)           : The device (e.g. 'cpu' or 'cuda') passed to the model

    """

    if model_weights_filename is not None and native_model_weights_filepath is not None:
        sys.exit("Parameters model_weights_filename and native_model_weights_filepath are mutually exclusive.\nmodel_weights_filename was set to {}\nnative_model_weights_filepath was set to {}".format(model_weights_filename, native_model_weights_filepath))

    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    logging_directory = os.path.join(script_dir, logging_directory)

    setup_logging(path=logging_config_path, default_level=logging_default_level, logging_directory=logging_directory)

    flplan = parse_fl_plan(os.path.join(plan_dir, plan))

    # check the inference config
    if 'inference' not in flplan:
        sys.exit("FL Plan does not contain a top-level 'inference' entry. By default, inference is disabled.")
    
    if 'allowed' not in flplan['inference'] or flplan['inference']['allowed'] != True:
        sys.exit("FL Plan must contain an {'inference': {'allowed': True}} entry in order for inference to be allowed.")

    # create the data object
    data = create_data_object_with_explicit_data_path(flplan=flplan, data_path=data_dir)

    # create the model object
    model = create_model_object(flplan, data, model_device=model_device)

    # record which tensors were held out from the saved proto
    _, holdout_tensors = remove_and_save_holdout_tensors(model.get_tensor_dict())

    # if pbuf weights, we need to run deconstruct proto with a NoCompression pipeline
    if model_weights_filename is not None:        
        proto_path = os.path.join(base_dir, 'weights', model_weights_filename)
        proto = load_proto(proto_path)
        tensor_dict_from_proto = deconstruct_proto(proto, NoCompressionPipeline())

        # restore any tensors held out from the proto
        tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}

        model.set_tensor_dict(tensor_dict, with_opt_vars=False)
    elif native_model_weights_filepath is not None:
        # FIXME: how do we handle kwargs here? Will they come from the flplan?
        model.load_native(native_model_weights_filepath)
    else:
        sys.exit("One of model_weights_filename or native_model_weights_filepath is required.")

    # finally, call the model object's run_inference_and_store_results with the kwargs from the inference block
    inference_kwargs = flplan['inference'].get('kwargs') or {}
    model.run_inference_and_store_results(**inference_kwargs)
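
The docstrings above note that model_weights_filename and native_model_weights_filepath are mutually exclusive "(see argument parser)". The parser itself is not part of this listing, but argparse can enforce the constraint directly; a hedged sketch with hypothetical flag spellings:

import argparse

# Sketch of how a parser could enforce the mutual exclusivity; the flag names
# are hypothetical, not taken from this listing.
parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--model_weights_filename', type=str, default=None)
group.add_argument('--native_model_weights_filepath', type=str, default=None)
args = parser.parse_args(['--model_weights_filename', 'init_weights.pbuf'])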
def main(plan, resume, collaborators_file, data_config_fname,
         validate_without_patches_flag, data_in_memory_flag,
         data_queue_max_length, data_queue_num_workers, torch_threads,
         kmp_affinity_flag, logging_config_path, logging_default_level,
         logging_directory, model_device, **kwargs):
    """Run the federation simulation from the federation (FL) plan.

    Runs a federated training from the federation (FL) plan but creates the
    aggregator and collaborators on the same compute node. This allows
    the developer to test the model and data loaders before running
    on the remote collaborator nodes.

    Args:
        plan                            : The Federation (FL) plan (YAML file)
        resume                          : Whether or not the aggregator is told to resume from previous best
        collaborators_file              : The file listing the collaborators
        data_config_fname               : The file describing where the dataset is located on the collaborators
        validate_without_patches_flag   : controls a model init kwarg
        data_in_memory_flag             : controls a data init kwarg
        data_queue_max_length           : controls a data init kwarg
        data_queue_num_workers          : controls a data init kwarg
        torch_threads                   : number of threads to set in torch
        kmp_affinity_flag               : controls a model init kwarg
        logging_config_path             : The logging configuration file
        logging_default_level           : The log level
        **kwargs                        : Variable parameters to pass to the function

    """
    # FIXME: consistent filesystem (#15)
    # establish location for fl plan as well as
    # where to get and write model protobufs
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    weights_dir = os.path.join(base_dir, 'weights')
    metadata_dir = os.path.join(base_dir, 'metadata')
    collaborators_dir = os.path.join(base_dir, 'collaborator_lists')
    logging_config_path = os.path.join(script_dir, logging_config_path)
    logging_directory = os.path.join(script_dir, logging_directory)

    setup_logging(path=logging_config_path,
                  default_level=logging_default_level,
                  logging_directory=logging_directory)

    # load the flplan, local_config and collaborators file
    flplan = parse_fl_plan(os.path.join(plan_dir, plan))

    # FIXME: Find a better solution for passing model and data init kwargs
    model_init_kwarg_keys = [
        'validate_without_patches', 'torch_threads', 'kmp_affinity'
    ]
    model_init_kwarg_vals = [
        validate_without_patches_flag, torch_threads, kmp_affinity_flag
    ]
    for key, value in zip(model_init_kwarg_keys, model_init_kwarg_vals):
        if (value is not None) and (value != False):
            flplan['model_object_init']['init_kwargs'][key] = value

    data_init_kwarg_keys = ['in_memory', 'q_max_length', 'q_num_workers']
    data_init_kwarg_vals = [
        data_in_memory_flag, data_queue_max_length, data_queue_num_workers
    ]
    for key, value in zip(data_init_kwarg_keys, data_init_kwarg_vals):
        if (value is not None) and (value != False):
            flplan['data_object_init']['init_kwargs'][key] = value

    local_config = load_yaml(os.path.join(base_dir, data_config_fname))
    collaborator_common_names = load_yaml(
        os.path.join(collaborators_dir,
                     collaborators_file))['collaborator_common_names']

    # TODO: Run a loop here over various parameter values and iterations
    # TODO: implement more than just saving init, best, and latest model
    federate(flplan=flplan,
             resume=resume,
             local_config=local_config,
             collaborator_common_names=collaborator_common_names,
             base_dir=base_dir,
             weights_dir=weights_dir,
             metadata_dir=metadata_dir,
             model_device=model_device)
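
The torch_threads argument is documented above as "number of threads to set in torch"; how it is applied is not shown in this listing, but the assumed downstream effect is something like the following sketch:

import torch

# Assumed downstream use of the torch_threads init kwarg (not shown in this
# listing): cap intra-op parallelism in PyTorch.
torch_threads = 4  # hypothetical value
if torch_threads is not None:
    torch.set_num_threads(torch_threads)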