def main(plan, out_filepath):
    """Print the parsed federation (FL) plan as YAML and optionally save it.

    Args:
        plan: Filename of the FL plan, resolved inside the
            'federations/plans' directory that sits next to this script.
        out_filepath: Path of a file to write the YAML text to, or None to
            only print it.
    """
    # FIXME: consistent filesystem (#15)
    here = os.path.dirname(os.path.realpath(__file__))
    plan_path = os.path.join(here, 'federations', 'plans', plan)

    plan_yaml = yaml.dump(parse_fl_plan(plan_path))
    print(plan_yaml)

    if out_filepath is None:
        return

    with open(out_filepath, 'w') as out_file:
        out_file.write(plan_yaml)
def main(plan, collaborators_file, data_config_fname, logging_config_path,
         logging_default_level, logging_directory, model_device, **kwargs):
    """Run a single-node federation simulation described by an FL plan.

    The plan, the data configuration and the collaborator list are loaded
    from the 'federations' directory next to this script and handed to
    ``federate``.

    Args:
        plan: Filename of the federation (FL) plan (YAML) inside
            'federations/plans'.
        collaborators_file: Filename inside 'federations/collaborator_lists';
            its 'collaborator_common_names' entry is used.
        data_config_fname: Data configuration filename inside 'federations'.
        logging_config_path: Logging configuration path, relative to the
            script directory.
        logging_default_level: Default log level passed to ``setup_logging``.
        logging_directory: Log directory, relative to the script directory.
        model_device: Forwarded unchanged to ``federate``.
        **kwargs: Accepted for call compatibility; not used here.
    """
    # FIXME: consistent filesystem (#15)
    # Everything (plans, weights, metadata, collaborator lists) lives under
    # the 'federations' directory beside this script.
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')

    def under_base(*parts):
        return os.path.join(base_dir, *parts)

    setup_logging(path=os.path.join(script_dir, logging_config_path),
                  default_level=logging_default_level,
                  logging_directory=os.path.join(script_dir,
                                                 logging_directory))

    # load the flplan, local_config and collaborators file
    flplan = parse_fl_plan(under_base('plans', plan))
    local_config = load_yaml(under_base(data_config_fname))
    collaborator_listing = load_yaml(
        under_base('collaborator_lists', collaborators_file))
    collaborator_common_names = collaborator_listing[
        'collaborator_common_names']

    # TODO: Run a loop here over various parameter values and iterations
    # TODO: implement more than just saving init, best, and latest model
    federate(flplan,
             local_config,
             collaborator_common_names,
             base_dir,
             under_base('weights'),
             under_base('metadata'),
             model_device)
def main(plan, collaborators_file, single_col_cert_common_name,
         logging_config_path, logging_default_level, logging_directory,
         resume, script_dir):
    """Start the aggregator server described by a federation (FL) plan.

    Args:
        plan: Filename of the FL plan inside 'federations/plans'.
        collaborators_file: Filename inside 'federations/collaborator_lists';
            its 'collaborator_common_names' entry is used.
        single_col_cert_common_name: Forwarded to
            ``create_aggregator_object_from_flplan``.
        logging_config_path: Logging configuration path, relative to the
            script directory.
        logging_default_level: Default log level passed to ``setup_logging``.
        logging_directory: Log directory, relative to the script directory.
        resume: Forwarded to ``create_aggregator_object_from_flplan``.
        script_dir: Directory to treat as the script directory; when None,
            the directory containing this file is used.
    """
    # FIXME: consistent filesystem (#15)
    if script_dir is not None:
        root_dir = script_dir
    else:
        root_dir = os.path.dirname(os.path.realpath(__file__))

    base_dir = os.path.join(root_dir, 'federations')
    weights_dir = os.path.join(base_dir, 'weights')
    metadata_dir = os.path.join(base_dir, 'metadata')

    setup_logging(path=os.path.join(root_dir, logging_config_path),
                  default_level=logging_default_level,
                  logging_directory=os.path.join(root_dir, logging_directory))

    flplan = parse_fl_plan(os.path.join(base_dir, 'plans', plan))

    collaborator_listing = load_yaml(
        os.path.join(base_dir, 'collaborator_lists', collaborators_file))
    collaborator_common_names = collaborator_listing[
        'collaborator_common_names']

    aggregator = create_aggregator_object_from_flplan(
        flplan,
        collaborator_common_names,
        single_col_cert_common_name,
        base_dir,
        weights_dir,
        metadata_dir,
        resume)

    aggregator_server = create_aggregator_server_from_flplan(aggregator,
                                                             flplan)
    aggregator_server.serve(**get_serve_kwargs_from_flpan(flplan, base_dir))
def main(plan, collaborators_file, single_col_cert_common_name,
         logging_config_path, logging_default_level, logging_directory):
    """Start the aggregator server described by a federation (FL) plan.

    Args:
        plan: Filename of the FL plan inside 'federations/plans'.
        collaborators_file: Filename inside 'federations/collaborator_lists';
            its 'collaborator_common_names' entry is used.
        single_col_cert_common_name: Forwarded to
            ``create_aggregator_object_from_flplan``.
        logging_config_path: Logging configuration path, relative to the
            script directory.
        logging_default_level: Default log level passed to ``setup_logging``.
        logging_directory: Log directory, relative to the script directory.
    """
    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')

    # One lookup table for the sub-directories used below.
    subdirs = {
        name: os.path.join(base_dir, name)
        for name in ('plans', 'weights', 'collaborator_lists', 'metadata')
    }

    setup_logging(path=os.path.join(script_dir, logging_config_path),
                  default_level=logging_default_level,
                  logging_directory=os.path.join(script_dir,
                                                 logging_directory))

    flplan = parse_fl_plan(os.path.join(subdirs['plans'], plan))

    collaborator_listing = load_yaml(
        os.path.join(subdirs['collaborator_lists'], collaborators_file))
    collaborator_common_names = collaborator_listing[
        'collaborator_common_names']

    aggregator = create_aggregator_object_from_flplan(
        flplan,
        collaborator_common_names,
        single_col_cert_common_name,
        subdirs['weights'],
        subdirs['metadata'])

    aggregator_server = create_aggregator_server_from_flplan(aggregator,
                                                             flplan)
    aggregator_server.serve(**get_serve_kwargs_from_flpan(flplan, base_dir))
def main(plan, model_weights_filename, native_model_weights_filepath,
         populate_weights_at_init, model_file_argument_name, data_dir,
         logging_config_path, logging_default_level, logging_directory,
         model_device, inference_patient=None):
    """Run inference as configured by the FL plan, data directory and weights.

    The output format is determined by the data object named in the plan.

    Args:
        plan (string): Filename of the federation (FL) plan YAML file inside
            'federations/plans'.
        model_weights_filename (string): A .pbuf filename in the common
            weights directory (mutually exclusive with
            native_model_weights_filepath). NOTE: these must be uncompressed
            weights!!
        native_model_weights_filepath (string): A framework-specific
            filepath, relative to the working directory (mutually exclusive
            with model_weights_filename).
        populate_weights_at_init (boolean): When true, the weights location
            is injected into the model init kwargs and the model is expected
            to load its own weights; otherwise they are loaded here after
            the model is built.
        model_file_argument_name (string): Name of the model ``__init__``
            argument that receives the weights location. A falsy value
            selects 'model_weights_filename' or
            'native_model_weights_filepath' depending on which was given.
        data_dir (string): Parent directory containing the data, relative to
            the working directory.
        logging_config_path (string): Logging configuration path, passed to
            ``setup_logging`` as given.
        logging_default_level (string): Default log level.
        logging_directory (string): Log directory, relative to the script
            directory.
        model_device: Forwarded to ``create_model_object``.
        inference_patient (string): Subdirectory of a single patient to run
            inference on (exclusively).
    """
    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')

    setup_logging(path=logging_config_path,
                  default_level=logging_default_level,
                  logging_directory=os.path.join(script_dir,
                                                 logging_directory))

    flplan = parse_fl_plan(os.path.join(base_dir, 'plans', plan))

    # check the inference config
    if 'inference' not in flplan:
        sys.exit(
            "FL Plan does not contain a top-level 'inference' entry. By default, inference is disabled."
        )
    inference_config = flplan['inference']
    # Equality (not identity) on purpose: any value equal to True is accepted.
    if 'allowed' not in inference_config or inference_config['allowed'] != True:  # noqa: E712
        sys.exit(
            "FL Plan must contain a {'inference: {'allowed': True}} entry in order for inference to be allowed."
        )

    # create the data object
    data = create_data_object_with_explicit_data_path(
        flplan=flplan,
        data_path=data_dir,
        inference_patient=inference_patient)

    # TODO: Find a good way to detect and communicate mishandling of model_file_argument_name
    #       Ie, capture exception of model not getting its required kwarg for this purpose, also
    #       how to tell if the model is using its random initialization rather than weights from file?
    if populate_weights_at_init:
        # Supplement the flplan model init kwargs with the weights location.
        # model_weights_filename and native_model_weights_filepath are
        # mutually exclusive and required (see argument parser).
        if model_weights_filename is not None:
            fallback_name = 'model_weights_filename'
            weights_location = model_weights_filename
        else:
            fallback_name = 'native_model_weights_filepath'
            weights_location = native_model_weights_filepath
        model_file_argument_name = model_file_argument_name or fallback_name
        flplan['model_object_init']['init_kwargs'][
            model_file_argument_name] = weights_location

    # create the base model object
    model = create_model_object(flplan, data, model_device=model_device)

    # the model may have an 'infer_volume' method instead of 'infer_batch'
    if not hasattr(model, 'infer_batch'):
        if hasattr(model, 'infer_volume'):
            model = InferenceOnlyModelWrapper(data=data, base_model=model)
        elif not hasattr(model, 'run_inference_and_store_results'):
            sys.exit(
                "If model object does not have a 'run_inference_and_store_results' method, it must have either an 'infer_batch' or 'infer_volume' method."
            )

    if not populate_weights_at_init:
        # model_weights_filename and native_model_weights_filepath are
        # mutually exclusive and required (see argument parser).
        if model_weights_filename is None:
            # FIXME: how do we handle kwargs here? Will they come from the flplan?
            model.load_native(native_model_weights_filepath)
        else:
            # pbuf weights: deconstruct the proto with a NoCompression pipeline
            proto = load_proto(
                os.path.join(base_dir, 'weights', model_weights_filename))
            proto_tensors = deconstruct_proto(proto, NoCompressionPipeline())

            # restore any tensors held out from the proto
            _, held_out = remove_and_save_holdout_tensors(
                model.get_tensor_dict())
            model.set_tensor_dict({**proto_tensors, **held_out},
                                  with_opt_vars=False)

    # finally, call the model object's run_inference_and_store_results with
    # the kwargs from the inference block
    inference_kwargs = flplan['inference'].get('kwargs') or {}
    model.run_inference_and_store_results(**inference_kwargs)
def main(plan, collaborator_common_name, single_col_cert_common_name,
         data_config_fname, data_dir, validate_without_patches_flag,
         data_in_memory_flag, data_queue_max_length, data_queue_num_workers,
         torch_threads, kmp_affinity_flag, logging_config_path,
         logging_default_level, logging_directory, model_device):
    """Run the collaborator client process described by an FL plan.

    The process exits with status 0 after a successful run; any exception is
    logged and the process exits with status 666.

    Args:
        plan: Filename of the federation (FL) plan YAML inside
            'federations/plans'.
        collaborator_common_name: The common name for the collaborator node.
        single_col_cert_common_name: Forwarded to
            ``create_collaborator_object_from_flplan``.
        data_config_fname: Dataset configuration filename (YAML) inside
            'federations'.
        data_dir: Forwarded as ``data_dir`` when creating the collaborator.
        validate_without_patches_flag: Overrides model init kwarg
            'validate_without_patches'.
        data_in_memory_flag: Overrides data init kwarg 'in_memory'.
        data_queue_max_length: Overrides data init kwarg 'q_max_length'.
        data_queue_num_workers: Overrides data init kwarg 'q_num_workers'.
        torch_threads: Overrides model init kwarg 'torch_threads'.
        kmp_affinity_flag: Overrides model init kwarg 'kmp_affinity'.
        logging_config_path: Logging configuration path, relative to the
            script directory.
        logging_default_level: Default log level.
        logging_directory: Log directory, relative to the script directory.
        model_device: Forwarded as ``model_device`` when creating the
            collaborator.
    """
    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    weights_dir = os.path.join(base_dir, 'weights')
    metadata_dir = os.path.join(base_dir, 'metadata')

    setup_logging(path=os.path.join(script_dir, logging_config_path),
                  default_level=logging_default_level,
                  logging_directory=os.path.join(script_dir,
                                                 logging_directory))

    flplan = parse_fl_plan(os.path.join(base_dir, 'plans', plan))

    # FIXME: Find a better solution for passing model and data init kwargs
    # (plan section, init kwarg name, command-line value)
    init_kwarg_overrides = (
        ('model_object_init', 'validate_without_patches',
         validate_without_patches_flag),
        ('model_object_init', 'torch_threads', torch_threads),
        ('model_object_init', 'kmp_affinity', kmp_affinity_flag),
        ('data_object_init', 'in_memory', data_in_memory_flag),
        ('data_object_init', 'q_max_length', data_queue_max_length),
        ('data_object_init', 'q_num_workers', data_queue_num_workers),
    )
    for section, key, value in init_kwarg_overrides:
        # NOTE(review): `!= False` also skips an explicit 0 (0 == False).
        if value is not None and value != False:  # noqa: E712
            flplan[section]['init_kwargs'][key] = value

    local_config = load_yaml(os.path.join(base_dir, data_config_fname))

    try:
        collaborator = create_collaborator_object_from_flplan(
            flplan,
            collaborator_common_name,
            local_config,
            base_dir,
            weights_dir,
            metadata_dir,
            single_col_cert_common_name,
            data_dir=data_dir,
            model_device=model_device)
        collaborator.run()
    except Exception as e:
        logging.getLogger(__name__).exception(repr(e))
        # this is for Sarthak
        sys.exit(666)
    else:
        sys.exit(0)
def main(plan, native_model_weights_filepath, collaborators_file,
         feature_shape, n_classes, data_config_fname, logging_config_path,
         logging_default_level, model_device):
    """Create the protobuf file holding the federation's initial weights.

    A model is built from the federation (FL) plan, optionally loaded with
    native weights, and its tensors (minus any holdouts) are written to the
    'init_model_fname' named in the plan, inside 'federations/weights'.

    Args:
        plan: The federation (FL) plan filename inside 'federations/plans'.
        native_model_weights_filepath: A framework-specific filepath,
            relative to the working directory. When not None the model loads
            these weights before they are exported.
        collaborators_file: Filename inside 'federations/collaborator_lists'.
            Only used when feature_shape is None: the first listed
            collaborator is used to create the data object.
        feature_shape: The input shape to the model. When given, a
            ``RandomData`` dummy data object is used instead of real data.
        n_classes: Forwarded to ``create_data_object`` (only used when
            feature_shape is None).
        data_config_fname: The data configuration file (defines where the
            datasets are located), inside 'federations'.
        logging_config_path: The logging configuration path, passed to
            ``setup_logging`` as given.
        logging_default_level (int): The default log level.
        model_device: Forwarded to ``create_model_object``.
    """
    setup_logging(path=logging_config_path,
                  default_level=logging_default_level)

    logger = logging.getLogger(__name__)

    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    weights_dir = os.path.join(base_dir, 'weights')

    # ensure the weights dir exists
    if not os.path.exists(weights_dir):
        print('creating folder:', weights_dir)
        # exist_ok guards against another process creating it in between.
        os.makedirs(weights_dir, exist_ok=True)

    # parse the plan and local config
    flplan = parse_fl_plan(os.path.join(plan_dir, plan))
    local_config = load_yaml(os.path.join(base_dir, data_config_fname))

    # get the output filename
    fpath = os.path.join(
        weights_dir,
        flplan['aggregator_object_init']['init_kwargs']['init_model_fname'])

    # create the data object for models whose architecture depends on the
    # feature shape
    if feature_shape is None:
        if collaborators_file is None:
            sys.exit("You must specify either a feature shape or a collaborator list in order for the script to determine the input layer shape")
        # FIXME: this will ultimately run in a governor environment and
        # should not require any data to work
        # pick the first collaborator to create the data and model (could be any)
        collaborator_common_name = load_yaml(
            os.path.join(base_dir, 'collaborator_lists', collaborators_file)
        )['collaborator_common_names'][0]
        data = create_data_object(flplan, collaborator_common_name,
                                  local_config, n_classes=n_classes)
    else:
        data = get_object('openfl.data.dummy.randomdata', 'RandomData',
                          feature_shape=feature_shape)
        logger.info('Using data object of type {} and feature shape {}'.format(type(data), feature_shape))

    # create the model object and compression pipeline
    wrapped_model = create_model_object(flplan, data,
                                        model_device=model_device)
    compression_pipeline = create_compression_pipeline(flplan)

    # determine if we need to store the optimizer variables
    # A KeyError here means either a missing plan key or an unknown
    # OptTreatment name; both end with the same message.
    try:
        opt_treatment = OptTreatment[
            flplan['collaborator_object_init']['init_kwargs']['opt_treatment']]
    except KeyError:
        # FIXME: this error message should use the exception to determine the
        # missing key and the Enum to display the options dynamically
        sys.exit("FL plan must specify ['collaborator_object_init']['init_kwargs']['opt_treatment'] as [RESET|CONTINUE_LOCAL|CONTINUE_GLOBAL]")

    # FIXME: this should be an "opt_treatment requires parameters" type check
    # rather than a magic string
    with_opt_vars = opt_treatment == OptTreatment['CONTINUE_GLOBAL']

    if native_model_weights_filepath is not None:
        wrapped_model.load_native(native_model_weights_filepath)

    tensor_dict_split_fn_kwargs = (
        wrapped_model.tensor_dict_split_fn_kwargs or {})

    tensor_dict, holdout_params = split_tensor_dict_for_holdouts(
        logger,
        wrapped_model.get_tensor_dict(with_opt_vars=with_opt_vars),
        **tensor_dict_split_fn_kwargs)
    # Logger.warn is a deprecated alias; use Logger.warning.
    logger.warning('Following paramters omitted from global initial model, '
                   'local initialization will determine values: {}'.format(list(holdout_params.keys())))

    model_proto = construct_proto(tensor_dict=tensor_dict,
                                  model_id=wrapped_model.__class__.__name__,
                                  model_version=0,
                                  is_delta=False,
                                  delta_from_version=-1,
                                  compression_pipeline=compression_pipeline)

    dump_proto(model_proto=model_proto, fpath=fpath)

    logger.info("Created initial weights file: {}".format(fpath))
def main(data_path, plan_path, model_weights_path, output_pardir,
         model_output_tag, device, legacy_model_flag=False):
    """Run inference on the validation subjects and record their DICE scores.

    For every subject yielded by the data object's validation loader this
    copies the label file into ``output_pardir/<subject folder>``, writes the
    predicted segmentation there as a NIfTI image and collects the DICE
    scores. The scores of all subjects are pickled to
    ``output_pardir/<model_output_tag>_subdirs_to_DICE.pkl``.

    Args:
        data_path: Data location passed to
            ``create_data_object_with_explicit_data_path``.
        plan_path: Path of the federation (FL) plan to parse.
        model_weights_path: Location of the model weights (a single file for
            legacy models, otherwise whatever ``load_model`` accepts).
        output_pardir: Parent directory for all outputs; created if missing.
        model_output_tag: Tag inserted into the output filenames.
        device: Passed to ``create_model_object`` as ``model_device``.
        legacy_model_flag: When true, load the weights with
            ``load_legacy_model_protobuf`` instead of ``load_model``.

    Raises:
        ValueError: If the plan's class list differs from the module-level
            ``class_list``, if the data object is not a ``GANDLFData``, if
            the model is not a ``BrainMaGeModel``, or if a subject has no
            label.
    """
    # TODO: We do not currently make use of the ability for brainmage to infer
    #       by first cropping external zero planes, or inference by patching
    #       and fusing.

    flplan = parse_fl_plan(plan_path)

    # make sure the class list we are using is compatible with the hard-coded
    # class_label_map defined at module level
    if flplan['data_object_init']['init_kwargs']['class_list'] != class_list:
        raise ValueError('We currently only support class_list=', class_list)

    # construct the data object
    data = create_data_object_with_explicit_data_path(flplan=flplan,
                                                      data_path=data_path)

    # code is written with assumption we are using the gandlf data object
    if not isinstance(data, GANDLFData):
        raise ValueError(
            'This script is currently assumed to be using a child of fets.data.pytorch.gandlf_data.GANDLFData, you are using: ',
            data.__class__.__name__)

    # construct the model object (requires cpu since we're passing [padded]
    # whole brains)
    model = create_model_object(flplan=flplan,
                                data_object=data,
                                model_device=device)

    # code is written with assumption we are using the brainmage object
    if not isinstance(model, BrainMaGeModel):
        # Report the model's class here (previously the data class was
        # reported by mistake).
        raise ValueError(
            'This script is currently assumed to be using a child of fets.models.pytorch.brainmage.BrainMaGeModel, you are using: ',
            model.__class__.__name__)

    # legacy models are defined in a single file, newer ones have a folder
    # that holds per-layer files
    if legacy_model_flag:
        tensor_dict_from_proto = load_legacy_model_protobuf(model_weights_path)
    else:
        tensor_dict_from_proto = load_model(model_weights_path)

    # restore any tensors held out from the proto
    _, holdout_tensors = split_tensor_dict_for_holdouts(
        None, model.get_tensor_dict())
    tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}
    model.set_tensor_dict(tensor_dict, with_opt_vars=False)

    print("\nWill be running inference on {} validation samples.\n".format(
        model.get_validation_data_size()))

    os.makedirs(output_pardir, exist_ok=True)

    # Computed up front so the final dump also works when the validation
    # loader yields no subjects.
    dice_outpath = os.path.join(output_pardir,
                                model_output_tag + '_subdirs_to_DICE.pkl')
    subdir_to_DICE = {}

    for subject in data.get_val_loader():
        # using this because this is only one that's always defined
        first_mode_path = subject['1']['path'][0]
        subfolder = os.path.basename(os.path.dirname(first_mode_path))

        # prep the path for the output files
        output_subdir = os.path.join(output_pardir, subfolder)
        os.makedirs(output_subdir, exist_ok=True)
        inference_outpath = os.path.join(
            output_subdir, subfolder + model_output_tag + '_seg.nii.gz')

        if not is_mask_present(subject):
            raise ValueError(
                'We are expecting to run this on subjects that have labels.')

        label_path = subject['label']['path'][0]
        label_file = os.path.basename(label_path)
        subdir_name = os.path.basename(os.path.dirname(label_path))

        # copy the label file over to the output subdir
        shutil.copyfile(label_path, os.path.join(output_subdir, label_file))

        features, ground_truth = subject_to_feature_and_label(subject=subject,
                                                              pad_z=pad_z)

        output = infer(model, features)

        # FIXME: Find a better solution
        # crop away the padding we put in (last axis cut back to 155)
        output = output[:, :, :, :, :155]

        target = one_hot(segmask_array=ground_truth, class_list=class_list)
        print(target.shape, output.shape)

        # get the DICE score
        dice_dict = clinical_dice(output=output,
                                  target=target,
                                  class_list=class_list,
                                  to_scalar=True)

        subdir_to_DICE[subdir_name] = dice_dict

        output = np.squeeze(output.cpu().numpy())

        # GANDLFData loader produces transposed output from what sitk gets
        # from file, so transposing here.
        output = np.transpose(output, [0, 3, 2, 1])

        # process float outputs (across output channels), providing labels as
        # defined in values of class_label_map
        output = new_labels_from_float_output(array=output,
                                              class_label_map=class_label_map,
                                              binary_classification=False)

        # convert array to SimpleITK image, reusing the input's metadata
        image = sitk.GetImageFromArray(output)
        image.CopyInformation(sitk.ReadImage(first_mode_path))

        print("\nWriting inference NIfTI image of shape {} to {}".format(
            output.shape, inference_outpath))
        sitk.WriteImage(image, inference_outpath)
        print("\nCorresponding DICE scores were: ")
        print("{}\n\n".format(dice_dict))

    print("Saving subdir_name_to_DICE at: ", dice_outpath)
    with open(dice_outpath, 'wb') as _file:
        pkl.dump(subdir_to_DICE, _file)
def main(plan, model_weights_filename, native_model_weights_filepath,
         data_dir, logging_config_path, logging_default_level,
         logging_directory, model_device):
    """Run inference as configured by the FL plan, data directory and weights.

    The output format is determined by the data object named in the plan.
    The process exits with a message when the weight arguments are
    inconsistent or when the plan does not allow inference.

    Args:
        plan (string): Filename of the federation (FL) plan YAML file inside
            'federations/plans'.
        model_weights_filename (string): A .pbuf filename in the common
            weights directory (mutually exclusive with
            native_model_weights_filepath). NOTE: these must be uncompressed
            weights!!
        native_model_weights_filepath (string): A framework-specific
            filepath, relative to the working directory (mutually exclusive
            with model_weights_filename).
        data_dir (string): Parent directory containing the data, relative to
            the working directory.
        logging_config_path (string): Logging configuration path, passed to
            ``setup_logging`` as given.
        logging_default_level (string): Default log level.
        logging_directory (string): Log directory, relative to the script
            directory.
        model_device: Forwarded to ``create_model_object``.
    """
    # Validate the weight arguments before doing any expensive work.
    if model_weights_filename is not None and native_model_weights_filepath is not None:
        sys.exit(
            "Parameters model_weights_filename and native_model_weights_filepath are mutually exclusive.\n"
            "model_weights_filename was set to {}\n"
            "native_model_weights_filepath was set to {}".format(
                model_weights_filename, native_model_weights_filepath))
    if model_weights_filename is None and native_model_weights_filepath is None:
        sys.exit("One of model_weights_filename or native_model_weights_filepath is required.")

    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    logging_directory = os.path.join(script_dir, logging_directory)

    setup_logging(path=logging_config_path,
                  default_level=logging_default_level,
                  logging_directory=logging_directory)

    flplan = parse_fl_plan(os.path.join(plan_dir, plan))

    # check the inference config
    if 'inference' not in flplan:
        sys.exit("FL Plan does not contain a top-level 'inference' entry. By default, inference is disabled.")

    # Equality (not identity) on purpose: any value equal to True is accepted.
    if 'allowed' not in flplan['inference'] or flplan['inference']['allowed'] != True:  # noqa: E712
        sys.exit("FL Plan must contain a {'inference: {'allowed': True}} entry in order for inference to be allowed.")

    # create the data object
    data = create_data_object_with_explicit_data_path(flplan=flplan,
                                                      data_path=data_dir)

    # create the model object
    model = create_model_object(flplan, data, model_device=model_device)

    # record which tensors were held out from the saved proto
    _, holdout_tensors = remove_and_save_holdout_tensors(
        model.get_tensor_dict())

    if model_weights_filename is not None:
        # pbuf weights: deconstruct the proto with a NoCompression pipeline
        proto_path = os.path.join(base_dir, 'weights', model_weights_filename)
        proto = load_proto(proto_path)
        tensor_dict_from_proto = deconstruct_proto(proto,
                                                   NoCompressionPipeline())

        # restore any tensors held out from the proto
        tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}

        model.set_tensor_dict(tensor_dict, with_opt_vars=False)
    else:
        # Exactly one of the two arguments is set (validated above), so the
        # native weights path is the one in use here.
        # FIXME: how do we handle kwargs here? Will they come from the flplan?
        model.load_native(native_model_weights_filepath)

    # finally, call the model object's run_inference_and_store_results with
    # the kwargs from the inference block
    inference_kwargs = flplan['inference'].get('kwargs') or {}
    model.run_inference_and_store_results(**inference_kwargs)
def main(plan, resume, collaborators_file, data_config_fname,
         validate_without_patches_flag, data_in_memory_flag,
         data_queue_max_length, data_queue_num_workers, torch_threads,
         kmp_affinity_flag, logging_config_path, logging_default_level,
         logging_directory, model_device, **kwargs):
    """Run a single-node federation simulation described by an FL plan.

    The plan, the data configuration and the collaborator list are loaded
    from the 'federations' directory next to this script, command-line
    overrides are merged into the plan's init kwargs, and everything is
    handed to ``federate``.

    Args:
        plan: Filename of the federation (FL) plan (YAML) inside
            'federations/plans'.
        resume: Forwarded to ``federate`` as ``resume``.
        collaborators_file: Filename inside 'federations/collaborator_lists';
            its 'collaborator_common_names' entry is used.
        data_config_fname: Data configuration filename inside 'federations'.
        validate_without_patches_flag: Overrides model init kwarg
            'validate_without_patches'.
        data_in_memory_flag: Overrides data init kwarg 'in_memory'.
        data_queue_max_length: Overrides data init kwarg 'q_max_length'.
        data_queue_num_workers: Overrides data init kwarg 'q_num_workers'.
        torch_threads: Overrides model init kwarg 'torch_threads'.
        kmp_affinity_flag: Overrides model init kwarg 'kmp_affinity'.
        logging_config_path: Logging configuration path, relative to the
            script directory.
        logging_default_level: Default log level.
        logging_directory: Log directory, relative to the script directory.
        model_device: Forwarded to ``federate`` as ``model_device``.
        **kwargs: Accepted for call compatibility; not used here.
    """
    # FIXME: consistent filesystem (#15)
    # establish location for fl plan as well as
    # where to get and write model protobufs
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')

    setup_logging(path=os.path.join(script_dir, logging_config_path),
                  default_level=logging_default_level,
                  logging_directory=os.path.join(script_dir,
                                                 logging_directory))

    # load the flplan, local_config and collaborators file
    flplan = parse_fl_plan(os.path.join(base_dir, 'plans', plan))

    # FIXME: Find a better solution for passing model and data init kwargs
    # (plan section, init kwarg name, command-line value)
    init_kwarg_overrides = (
        ('model_object_init', 'validate_without_patches',
         validate_without_patches_flag),
        ('model_object_init', 'torch_threads', torch_threads),
        ('model_object_init', 'kmp_affinity', kmp_affinity_flag),
        ('data_object_init', 'in_memory', data_in_memory_flag),
        ('data_object_init', 'q_max_length', data_queue_max_length),
        ('data_object_init', 'q_num_workers', data_queue_num_workers),
    )
    for section, key, value in init_kwarg_overrides:
        # NOTE(review): `!= False` also skips an explicit 0 (0 == False).
        if value is not None and value != False:  # noqa: E712
            flplan[section]['init_kwargs'][key] = value

    local_config = load_yaml(os.path.join(base_dir, data_config_fname))
    collaborator_listing = load_yaml(
        os.path.join(base_dir, 'collaborator_lists', collaborators_file))
    collaborator_common_names = collaborator_listing[
        'collaborator_common_names']

    # TODO: Run a loop here over various parameter values and iterations
    # TODO: implement more than just saving init, best, and latest model
    federate(flplan=flplan,
             resume=resume,
             local_config=local_config,
             collaborator_common_names=collaborator_common_names,
             base_dir=base_dir,
             weights_dir=os.path.join(base_dir, 'weights'),
             metadata_dir=os.path.join(base_dir, 'metadata'),
             model_device=model_device)