def create_compression_pipeline(flplan):
    if flplan.get('compression_pipeline_object_init') is not None:
        compression_pipeline = init_object(
            flplan.get('compression_pipeline_object_init'))
    else:
        compression_pipeline = NoCompressionPipeline()
    return compression_pipeline
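
# Illustrative usage sketch (not from the source): confirms the fallback branch,
# i.e. that a plan with no 'compression_pipeline_object_init' entry yields a
# NoCompressionPipeline. The helper name below is hypothetical.
def _example_create_compression_pipeline():
    flplan_without_entry = {}
    pipeline = create_compression_pipeline(flplan_without_entry)
    assert isinstance(pipeline, NoCompressionPipeline)
    return pipeline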
def __init__(self,
             aggregator_uuid,
             federation_uuid,
             collaborator_common_names,
             init_model_fpath,
             latest_model_fpath,
             best_model_fpath,
             rounds_to_train=256,
             minimum_reporting=-1,
             straggler_cutoff_time=np.inf,
             single_col_cert_common_name=None,
             compression_pipeline=None,
             end_of_round_metadata=None,
             init_metadata_fname=None,
             latest_metadata_fname=None,
             send_metadata_to_clients=False,
             **kwargs):
    self.logger = logging.getLogger(__name__)
    self.uuid = aggregator_uuid
    self.federation_uuid = federation_uuid
    # FIXME: Should we do anything to ensure the initial model is compressed?
    self.model = load_proto(init_model_fpath)
    self.latest_model_fpath = latest_model_fpath
    self.best_model_fpath = best_model_fpath
    self.collaborator_common_names = collaborator_common_names
    self.round_num = 1
    self.rounds_to_train = rounds_to_train
    self.quit_job_sent_to = []
    self.minimum_reporting = minimum_reporting
    self.straggler_cutoff_time = straggler_cutoff_time
    self.round_start_time = None
    self.single_col_cert_common_name = single_col_cert_common_name

    if self.single_col_cert_common_name is not None:
        self.log_big_warning()
    else:
        # FIXME: '' instead of None is just for protobuf compatibility. Cleaner solution?
        self.single_col_cert_common_name = ''

    self.model_update_in_progress = None
    self.init_per_col_round_stats()
    self.best_model_score = None
    self.aggregated_model_is_global_best = True
    self.mutex = Lock()
    self.compression_pipeline = compression_pipeline or NoCompressionPipeline()
    self.end_of_round_metadata = end_of_round_metadata
    self.init_metadata_fname = init_metadata_fname
    self.latest_metadata_fname = latest_metadata_fname
    self.send_metadata_to_clients = send_metadata_to_clients
    if self.init_metadata_fname is not None:
        self.metadata = load_yaml(init_metadata_fname)
    else:
        self.metadata = {}
    self.metadata['aggregator_uuid'] = aggregator_uuid
    self.metadata['federation_uuid'] = federation_uuid
    self.metadata_for_round = {}
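
# Illustrative construction sketch (not from the source). The class name
# Aggregator and all paths/UUIDs below are hypothetical placeholders; only the
# first six arguments are required by the signature above.
def _example_construct_aggregator():
    return Aggregator(aggregator_uuid='agg-0001',
                      federation_uuid='fed-0001',
                      collaborator_common_names=['col_1', 'col_2'],
                      init_model_fpath='federations/weights/init_model.pbuf',
                      latest_model_fpath='federations/weights/latest_model.pbuf',
                      best_model_fpath='federations/weights/best_model.pbuf',
                      rounds_to_train=10)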
def main(plan, model_weights_filename, native_model_weights_filepath,
         populate_weights_at_init, model_file_argument_name, data_dir,
         logging_config_path, logging_default_level, logging_directory,
         model_device, inference_patient=None):
    """Runs inference according to the flplan, data-dir and weights file.

    Output format is determined by the data object in the flplan.

    Args:
        plan (string)                          : The filename for the federation (FL) plan YAML file
        model_weights_filename (string)        : A .pbuf filename in the common weights directory
                                                 (mutually exclusive with native_model_weights_filepath).
                                                 NOTE: these must be uncompressed weights!!
        native_model_weights_filepath (string) : A framework-specific filepath. Path will be relative
                                                 to the working directory.
                                                 (mutually exclusive with model_weights_filename)
        populate_weights_at_init (boolean)     : Whether or not the model populates its own weights at instantiation
        model_file_argument_name (string)      : Name of the argument to be passed to the model __init__
                                                 providing model file location info
        data_dir (string)                      : The directory path for the parent directory containing the data.
                                                 Path will be relative to the working directory.
        logging_config_path (string)           : The logging configuration file
        logging_default_level (string)         : The log level
        logging_directory (string)             : The directory for log files
        model_device (string)                  : The device for the model to use
        inference_patient (string)             : Subdirectory of a single patient to run inference on (exclusively)
    """
    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    logging_directory = os.path.join(script_dir, logging_directory)

    setup_logging(path=logging_config_path,
                  default_level=logging_default_level,
                  logging_directory=logging_directory)

    flplan = parse_fl_plan(os.path.join(plan_dir, plan))

    # check the inference config
    if 'inference' not in flplan:
        sys.exit("FL Plan does not contain a top-level 'inference' entry. "
                 "By default, inference is disabled.")
    if 'allowed' not in flplan['inference'] or flplan['inference']['allowed'] != True:
        sys.exit("FL Plan must contain an {'inference': {'allowed': True}} entry "
                 "in order for inference to be allowed.")

    # create the data object
    data = create_data_object_with_explicit_data_path(flplan=flplan,
                                                      data_path=data_dir,
                                                      inference_patient=inference_patient)

    # TODO: Find a good way to detect and communicate mishandling of model_file_argument_name.
    # I.e., capture the exception of the model not getting its required kwarg for this purpose; also,
    # how to tell if the model is using its random initialization rather than weights from file?
    if populate_weights_at_init:
        # Supplement the flplan base model kwargs to include model weights file info.
        if model_weights_filename is not None:
            model_file_argument_name = model_file_argument_name or 'model_weights_filename'
            flplan['model_object_init']['init_kwargs'].update(
                {model_file_argument_name: model_weights_filename})
        # model_weights_filename and native_model_weights_filepath are mutually exclusive
        # and one is required (see argument parser)
        else:
            model_file_argument_name = model_file_argument_name or 'native_model_weights_filepath'
            flplan['model_object_init']['init_kwargs'].update(
                {model_file_argument_name: native_model_weights_filepath})

    # create the base model object
    model = create_model_object(flplan, data, model_device=model_device)

    # the model may have an 'infer_volume' method instead of 'infer_batch'
    if not hasattr(model, 'infer_batch'):
        if hasattr(model, 'infer_volume'):
            model = InferenceOnlyModelWrapper(data=data, base_model=model)
        elif not hasattr(model, 'run_inference_and_store_results'):
            sys.exit("If the model object does not have a 'run_inference_and_store_results' method, "
                     "it must have either an 'infer_batch' or 'infer_volume' method.")

    if not populate_weights_at_init:
        # if pbuf weights, we need to run deconstruct_proto with a NoCompression pipeline
        if model_weights_filename is not None:
            proto_path = os.path.join(base_dir, 'weights', model_weights_filename)
            proto = load_proto(proto_path)
            tensor_dict_from_proto = deconstruct_proto(proto, NoCompressionPipeline())

            # restore any tensors held out from the proto
            _, holdout_tensors = remove_and_save_holdout_tensors(model.get_tensor_dict())
            tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}

            model.set_tensor_dict(tensor_dict, with_opt_vars=False)
        # model_weights_filename and native_model_weights_filepath are mutually exclusive
        # and one is required (see argument parser)
        else:
            # FIXME: how do we handle kwargs here? Will they come from the flplan?
            model.load_native(native_model_weights_filepath)

    # finally, call the model object's run_inference_and_store_results with the kwargs from the inference block
    inference_kwargs = flplan['inference'].get('kwargs') or {}
    model.run_inference_and_store_results(**inference_kwargs)
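
# Illustrative invocation sketch (not from the source). All argument values are
# hypothetical placeholders; model_weights_filename and
# native_model_weights_filepath are mutually exclusive, so only one is set.
def _example_run_inference():
    main(plan='example_plan.yaml',
         model_weights_filename='example_weights.pbuf',
         native_model_weights_filepath=None,
         populate_weights_at_init=False,
         model_file_argument_name=None,
         data_dir='path/to/data',
         logging_config_path='logging.yaml',
         logging_default_level='INFO',
         logging_directory='logs',
         model_device='cpu')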
def __init__(self,
             aggregator_uuid,
             federation_uuid,
             collaborator_common_names,
             init_model_fpath,
             latest_model_fpath,
             best_model_fpath,
             rounds_to_train=256,
             minimum_reporting=-1,
             straggler_cutoff_time=np.inf,
             single_col_cert_common_name=None,
             compression_pipeline=None,
             end_of_round_metadata=None,
             init_metadata_fname=None,
             latest_metadata_fname=None,
             send_metadata_to_clients=False,
             save_all_models_path=None,
             runtime_aggregator_config_dir=None,
             runtime_configurable_params=None,
             **kwargs):
    self.logger = logging.getLogger(__name__)
    self.uuid = aggregator_uuid
    self.federation_uuid = federation_uuid
    self.latest_model_fpath = latest_model_fpath
    self.best_model_fpath = best_model_fpath
    self.collaborator_common_names = collaborator_common_names
    self.rounds_to_train = rounds_to_train
    self.quit_job_sent_to = []
    self.minimum_reporting = minimum_reporting
    self.straggler_cutoff_time = straggler_cutoff_time
    self.round_start_time = None
    self.single_col_cert_common_name = single_col_cert_common_name
    self.save_all_models_path = save_all_models_path

    self.runtime_aggregator_config_dir = runtime_aggregator_config_dir
    self.runtime_configurable_params = runtime_configurable_params
    if self.runtime_aggregator_config_dir is not None:
        self.update_config_from_filesystem()

    if self.single_col_cert_common_name is not None:
        self.log_big_warning()
    else:
        # FIXME: '' instead of None is just for protobuf compatibility. Cleaner solution?
        self.single_col_cert_common_name = ''

    # FIXME: Should we do anything to ensure the initial model is compressed?
    self.model = load_proto(init_model_fpath)
    self.logger.info("Loaded initial model from {}".format(init_model_fpath))
    self.logger.info("Initial model version is {}".format(self.model.header.version))
    self.round_num = self.model.header.version + 1

    self._GRACEFULLY_QUIT = False
    self._do_quit = False
    self.model_update_in_progress = None
    self.init_per_col_round_stats()
    self.best_model_score = None
    self.aggregated_model_is_global_best = True
    self.mutex = Lock()
    self.compression_pipeline = compression_pipeline or NoCompressionPipeline()
    self.end_of_round_metadata = end_of_round_metadata
    self.init_metadata_fname = init_metadata_fname
    self.latest_metadata_fname = latest_metadata_fname
    self.send_metadata_to_clients = send_metadata_to_clients
    if self.init_metadata_fname is not None:
        self.metadata = load_yaml(init_metadata_fname)
    else:
        self.metadata = {}
    self.metadata['aggregator_uuid'] = aggregator_uuid
    self.metadata['federation_uuid'] = federation_uuid
    self.metadata_for_round = {}
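
# Illustrative construction sketch (not from the source). This variant resumes
# round numbering from the loaded proto's header (round_num = version + 1) and,
# when runtime_aggregator_config_dir is given, pulls runtime-configurable
# parameters from the filesystem at init. The class name Aggregator, all paths,
# and the list format of runtime_configurable_params are assumptions.
def _example_construct_resuming_aggregator():
    return Aggregator(aggregator_uuid='agg-0002',
                      federation_uuid='fed-0002',
                      collaborator_common_names=['col_1', 'col_2'],
                      init_model_fpath='federations/weights/latest_model.pbuf',
                      latest_model_fpath='federations/weights/latest_model.pbuf',
                      best_model_fpath='federations/weights/best_model.pbuf',
                      runtime_aggregator_config_dir='federations/agg_config',
                      runtime_configurable_params=['straggler_cutoff_time'])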
def main(data_csv_path, gandlf_config_path, model_weights_path, output_pardir,
         model_output_tag, device):
    # we will use the GANDLFData val loader to serve up the samples to perform inference on;
    # the data also gets copied into the training loader (but is not used there)

    # these get passed to the data constructor
    data_path = {'train': data_csv_path,
                 'val': data_csv_path,
                 'model_params_filepath': gandlf_config_path}
    divisibility_factor = 16

    # construct the data object
    data = GANDLFData(data_path=data_path,
                      divisibility_factor=divisibility_factor)

    # construct the model object (requires cpu since we're passing [padded] whole brains)
    model = Model(data=data,
                  base_filters=30,
                  min_learning_rate=0.000001,
                  max_learning_rate=0.001,
                  learning_rate_cycles_per_epoch=0.5,
                  n_classes=4,
                  n_channels=4,
                  loss_function='dc',
                  opt='sgd',
                  use_penalties=False,
                  device=device)

    # populate the model weights
    proto_path = model_weights_path
    proto = load_proto(proto_path)
    tensor_dict_from_proto = deconstruct_proto(proto, NoCompressionPipeline())
    # restore any tensors held out from the proto
    _, holdout_tensors = split_tensor_dict_for_holdouts(None, model.get_tensor_dict())
    tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}
    model.set_tensor_dict(tensor_dict, with_opt_vars=False)

    print("\nWill be running inference on {} samples.\n".format(
        model.get_validation_data_size()))

    if not os.path.exists(output_pardir):
        os.mkdir(output_pardir)

    for subject in data.get_val_loader():
        # using mode '1' because it is the only one that is always defined
        first_mode_path = subject['1']['path'][0]
        subfolder = first_mode_path.split('/')[-2]

        # prep the path for the output file
        output_subdir = os.path.join(output_pardir, subfolder)
        if not os.path.exists(output_subdir):
            os.mkdir(output_subdir)
        outpath = os.path.join(output_subdir,
                               subfolder + model_output_tag + '_seg.nii.gz')

        if is_mask_present(subject):
            label_path = subject['label']['path'][0]
            label_file = label_path.split('/')[-1]
            # copy the label file over to the output subdir
            copy_label_path = os.path.join(output_subdir, label_file)
            shutil.copyfile(label_path, copy_label_path)

        features, labels = subject_to_feature_and_label(subject)

        output = infer(model, features)

        output = np.squeeze(output.cpu().numpy())

        # crop away the padding we put in
        output = output[:, :, :, :155]

        # the label on disk is transposed from what the gandlf loader produces
        print("\nWARNING: gandlf loader produces transposed output from what sitk gets from file, "
              "so transposing here.\n")
        output = np.transpose(output, [0, 3, 2, 1])

        # process float outputs (across output channels), providing labels as defined
        # in the values of class_label_map
        output = new_labels_from_float_output(array=output,
                                              class_label_map=class_label_map,
                                              binary_classification=False)

        # convert array to SimpleITK image
        image = sitk.GetImageFromArray(output)
        image.CopyInformation(sitk.ReadImage(first_mode_path))

        print("\nWriting inference NIfTI image of shape {} to {}\n".format(
            output.shape, outpath))
        sitk.WriteImage(image, outpath)
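
# Illustrative invocation sketch (not from the source). All paths and the output
# tag below are hypothetical placeholders; the same CSV is reused for both the
# 'train' and 'val' entries, matching the loader setup inside main.
def _example_run_gandlf_inference():
    main(data_csv_path='path/to/data.csv',
         gandlf_config_path='path/to/gandlf_config.yaml',
         model_weights_path='path/to/model.pbuf',
         output_pardir='path/to/output',
         model_output_tag='_example_model',
         device='cpu')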
def __init__(self,
             collaborator_common_name,
             aggregator_uuid,
             federation_uuid,
             wrapped_model,
             channel,
             polling_interval=4,
             opt_treatment="CONTINUE_GLOBAL",
             compression_pipeline=None,
             epochs_per_round=1.0,
             num_batches_per_round=None,
             send_model_deltas=False,
             single_col_cert_common_name=None,
             save_best_native_path=None,
             save_best_native_kwargs=None,
             save_metadata_path=None,
             num_retries=5,
             **kwargs):
    self.logger = logging.getLogger(__name__)
    self.channel = channel
    self.polling_interval = polling_interval
    self.save_best_native_path = save_best_native_path
    self.save_best_native_kwargs = save_best_native_kwargs
    self.save_metadata_path = save_metadata_path
    self.num_retries = num_retries

    # this is really about sanity/correctness checking to ensure the bookkeeping
    # and control flow are correct
    self.common_name = collaborator_common_name
    self.aggregator_uuid = aggregator_uuid
    self.federation_uuid = federation_uuid
    self.single_col_cert_common_name = single_col_cert_common_name
    if self.single_col_cert_common_name is None:
        # FIXME: this is just for protobuf compatibility. Cleaner solution?
        self.single_col_cert_common_name = ''
    self.counter = 0
    self.model_header = ModelHeader(id=wrapped_model.__class__.__name__,
                                    is_delta=send_model_deltas,
                                    delta_from_version=-1,
                                    version=-1)

    # number of epochs to perform per round of FL (a float that is converted to
    # num_batches before calling the wrapped model's train_batches method).
    # This is overridden by "num_batches_per_round".
    self.epochs_per_round = epochs_per_round
    self.num_batches_per_round = num_batches_per_round
    if num_batches_per_round is not None:
        self.logger.info("Collaborator {} overriding epochs_per_round of {} with num_batches_per_round of {}"
                         .format(self.common_name, self.epochs_per_round, self.num_batches_per_round))

    self.wrapped_model = wrapped_model
    self.tensor_dict_split_fn_kwargs = wrapped_model.tensor_dict_split_fn_kwargs or {}

    # pipeline translating tensor_dict to and from a list of tensor protos
    self.compression_pipeline = compression_pipeline or NoCompressionPipeline()

    # RESET/CONTINUE_LOCAL/CONTINUE_GLOBAL
    if hasattr(OptTreatment, opt_treatment):
        self.opt_treatment = OptTreatment[opt_treatment]
    else:
        self.logger.error("Unknown opt_treatment: %s." % opt_treatment)
        raise NotImplementedError("Unknown opt_treatment: %s." % opt_treatment)

    # FIXME: this is a temporary fix for non-float values and other named params
    # designated to be held out from aggregation.
    # Needs updating when we have proper collaborator-side state saving.
    self._remove_and_save_holdout_tensors(
        self.wrapped_model.get_tensor_dict(with_opt_vars=self._with_opt_vars()))

    # when sending model deltas, baseline values for shared tensors must be kept
    self.send_model_deltas = send_model_deltas
    # the base_for_deltas attribute is only accessed in the case that model deltas are being sent
    if self.send_model_deltas:
        self.base_for_deltas = {"tensor_dict": None, "version": None}
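
# Illustrative construction sketch (not from the source). The class name
# Collaborator is an assumption; wrapped_model and channel are placeholders that
# would in practice come from the flplan factories. opt_treatment must be one of
# RESET, CONTINUE_LOCAL, or CONTINUE_GLOBAL, per the enum check above.
def _example_construct_collaborator(wrapped_model, channel):
    return Collaborator(collaborator_common_name='col_1',
                        aggregator_uuid='agg-0001',
                        federation_uuid='fed-0001',
                        wrapped_model=wrapped_model,
                        channel=channel,
                        opt_treatment='RESET',
                        epochs_per_round=0.5,  # half an epoch of training per round
                        send_model_deltas=True)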
def main(plan, model_weights_filename, native_model_weights_filepath, data_dir,
         logging_config_path, logging_default_level, logging_directory, model_device):
    """Runs inference according to the flplan, data-dir and weights file.

    Output format is determined by the data object in the flplan.

    Args:
        plan (string)                          : The filename for the federation (FL) plan YAML file
        model_weights_filename (string)        : A .pbuf filename in the common weights directory
                                                 (mutually exclusive with native_model_weights_filepath).
                                                 NOTE: these must be uncompressed weights!!
        native_model_weights_filepath (string) : A framework-specific filepath. Path will be relative
                                                 to the working directory.
                                                 (mutually exclusive with model_weights_filename)
        data_dir (string)                      : The directory path for the parent directory containing the data.
                                                 Path will be relative to the working directory.
        logging_config_path (string)           : The logging configuration file
        logging_default_level (string)         : The log level
    """
    if model_weights_filename is not None and native_model_weights_filepath is not None:
        sys.exit("Parameters model_weights_filename and native_model_weights_filepath are mutually exclusive.\n"
                 "model_weights_filename was set to {}\n"
                 "native_model_weights_filepath was set to {}".format(model_weights_filename,
                                                                      native_model_weights_filepath))

    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    logging_directory = os.path.join(script_dir, logging_directory)

    setup_logging(path=logging_config_path,
                  default_level=logging_default_level,
                  logging_directory=logging_directory)

    flplan = parse_fl_plan(os.path.join(plan_dir, plan))

    # check the inference config
    if 'inference' not in flplan:
        sys.exit("FL Plan does not contain a top-level 'inference' entry. "
                 "By default, inference is disabled.")
    if 'allowed' not in flplan['inference'] or flplan['inference']['allowed'] != True:
        sys.exit("FL Plan must contain an {'inference': {'allowed': True}} entry "
                 "in order for inference to be allowed.")

    # create the data object
    data = create_data_object_with_explicit_data_path(flplan=flplan, data_path=data_dir)

    # create the model object
    model = create_model_object(flplan, data, model_device=model_device)

    # record which tensors were held out from the saved proto
    _, holdout_tensors = remove_and_save_holdout_tensors(model.get_tensor_dict())

    # if pbuf weights, we need to run deconstruct_proto with a NoCompression pipeline
    if model_weights_filename is not None:
        proto_path = os.path.join(base_dir, 'weights', model_weights_filename)
        proto = load_proto(proto_path)
        tensor_dict_from_proto = deconstruct_proto(proto, NoCompressionPipeline())

        # restore any tensors held out from the proto
        tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}

        model.set_tensor_dict(tensor_dict, with_opt_vars=False)
    elif native_model_weights_filepath is not None:
        # FIXME: how do we handle kwargs here? Will they come from the flplan?
        model.load_native(native_model_weights_filepath)
    else:
        sys.exit("One of model_weights_filename or native_model_weights_filepath is required.")

    # finally, call the model object's run_inference_and_store_results with the kwargs from the inference block
    inference_kwargs = flplan['inference'].get('kwargs') or {}
    model.run_inference_and_store_results(**inference_kwargs)
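
# Illustrative invocation sketch (not from the source). This variant exercises
# the framework-native weights branch; all paths are hypothetical placeholders.
# Passing both weight arguments would trigger the mutual-exclusivity sys.exit.
def _example_run_inference_native():
    main(plan='example_plan.yaml',
         model_weights_filename=None,
         native_model_weights_filepath='path/to/native_weights.pt',
         data_dir='path/to/data',
         logging_config_path='logging.yaml',
         logging_default_level='INFO',
         logging_directory='logs',
         model_device='cpu')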