def main(plan, model_weights_filename, native_model_weights_filepath,
         populate_weights_at_init, model_file_argument_name, data_dir,
         logging_config_path, logging_default_level, logging_directory,
         model_device, inference_patient=None):
    """Runs inference according to the flplan, data-dir and weights file.

    Output format is determined by the data object in the flplan.

    Args:
        plan (string)                          : The filename for the federation (FL) plan YAML file
        model_weights_filename (string)        : A .pbuf filename in the common weights directory
                                                 (mutually exclusive with native_model_weights_filepath).
                                                 NOTE: these must be uncompressed weights!!
        native_model_weights_filepath (string) : A framework-specific filepath. Path will be relative to the
                                                 working directory. (mutually exclusive with model_weights_filename)
        populate_weights_at_init (boolean)     : Whether or not the model populates its own weights at instantiation
        model_file_argument_name (string)      : Name of the argument to be passed to the model __init__ providing
                                                 model file location info
        data_dir (string)                      : The directory path for the parent directory containing the data.
                                                 Path will be relative to the working directory.
        logging_config_path (string)           : The logging configuration file
        logging_default_level (string)         : The log level
        logging_directory (string)             : The directory to write log files into
        model_device (string)                  : The device on which to run the model
        inference_patient (string)             : Subdirectory of a single patient to run inference on (exclusively)
    """
    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    logging_directory = os.path.join(script_dir, logging_directory)

    setup_logging(path=logging_config_path,
                  default_level=logging_default_level,
                  logging_directory=logging_directory)

    flplan = parse_fl_plan(os.path.join(plan_dir, plan))

    # check the inference config
    if 'inference' not in flplan:
        sys.exit("FL Plan does not contain a top-level 'inference' entry. "
                 "By default, inference is disabled.")
    if 'allowed' not in flplan['inference'] or flplan['inference']['allowed'] != True:
        sys.exit("FL Plan must contain an {'inference': {'allowed': True}} entry "
                 "in order for inference to be allowed.")

    # create the data object
    data = create_data_object_with_explicit_data_path(
        flplan=flplan, data_path=data_dir, inference_patient=inference_patient)

    # TODO: Find a good way to detect and communicate mishandling of model_file_argument_name.
    #       I.e., capture the exception of the model not getting its required kwarg for this purpose;
    #       also, how do we tell whether the model is using its random initialization rather than
    #       weights from file?
    if populate_weights_at_init:
        # Supplement the flplan base model kwargs to include model weights file info.
        if model_weights_filename is not None:
            model_file_argument_name = model_file_argument_name or 'model_weights_filename'
            flplan['model_object_init']['init_kwargs'].update(
                {model_file_argument_name: model_weights_filename})
        # model_weights_filename and native_model_weights_filepath are mutually exclusive
        # and one is required (see argument parser)
        else:
            model_file_argument_name = model_file_argument_name or 'native_model_weights_filepath'
            flplan['model_object_init']['init_kwargs'].update(
                {model_file_argument_name: native_model_weights_filepath})

    # create the base model object
    model = create_model_object(flplan, data, model_device=model_device)

    # the model may have an 'infer_volume' method instead of 'infer_batch'
    if not hasattr(model, 'infer_batch'):
        if hasattr(model, 'infer_volume'):
            model = InferenceOnlyModelWrapper(data=data, base_model=model)
        elif not hasattr(model, 'run_inference_and_store_results'):
            sys.exit("If the model object does not have a 'run_inference_and_store_results' method, "
                     "it must have either an 'infer_batch' or an 'infer_volume' method.")

    if not populate_weights_at_init:
        # if pbuf weights, we need to run deconstruct_proto with a NoCompression pipeline
        if model_weights_filename is not None:
            proto_path = os.path.join(base_dir, 'weights', model_weights_filename)
            proto = load_proto(proto_path)
            tensor_dict_from_proto = deconstruct_proto(proto, NoCompressionPipeline())

            # restore any tensors held out from the proto
            _, holdout_tensors = remove_and_save_holdout_tensors(model.get_tensor_dict())
            tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}

            model.set_tensor_dict(tensor_dict, with_opt_vars=False)
        # model_weights_filename and native_model_weights_filepath are mutually exclusive
        # and one is required (see argument parser)
        else:
            # FIXME: how do we handle kwargs here? Will they come from the flplan?
            model.load_native(native_model_weights_filepath)

    # finally, call the model object's run_inference_and_store_results with the kwargs
    # from the inference block
    inference_kwargs = flplan['inference'].get('kwargs') or {}
    model.run_inference_and_store_results(**inference_kwargs)
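
# A hedged sketch (not part of the original script) of the parsed flplan fields the
# function above relies on. The structure is inferred only from the checks in the code;
# real plans carry many more keys.
example_flplan_fragment = {
    'inference': {
        'allowed': True,    # inference is refused unless this entry is present and True
        'kwargs': {},       # forwarded verbatim to run_inference_and_store_results
    },
    'model_object_init': {
        'init_kwargs': {},  # supplemented with the weights-file argument when
                            # populate_weights_at_init is set
    },
}
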
def UploadLocalModelUpdate(self, message):
    """Parses the collaborator reply message to get the collaborator model update

    Args:
        message: Message from the collaborator

    Returns:
        The reply to the message (usually just the acknowledgement to the collaborator)
    """
    self.mutex.acquire(blocking=True)
    try:
        t = time.time()
        self.validate_header(message)

        self.logger.info("Receive model update from %s " % message.header.sender)

        # Get the model parameters from the model proto and additional model info
        model_tensors = deconstruct_proto(
            model_proto=message.model,
            compression_pipeline=self.compression_pipeline)
        is_delta = message.model.header.is_delta
        delta_from_version = message.model.header.delta_from_version

        # if the collaborator is out of sync, we need to log and ignore
        if self.collaborator_out_of_sync(message.model.header):
            self.logger.info(
                "Model version mismatch in UploadLocalModelUpdate from {}. "
                "Aggregator version: {} Collaborator version: {}. Ignoring update"
                .format(message.header.sender,
                        self.model.header.version,
                        message.model.header.version))
            return LocalModelUpdateAck(header=self.create_reply_header(message))

        # ensure we haven't received an update from this collaborator already
        check_not_in(message.header.sender,
                     self.per_col_round_stats["loss_results"],
                     self.logger)
        check_not_in(message.header.sender,
                     self.per_col_round_stats["collaborator_training_sizes"],
                     self.logger)

        # dump the local update, if necessary
        if self.save_all_models_path is not None:
            self.save_local_update(message.header.sender, message.model)

        # if this is our very first update for the round, we take these model tensors as-is
        # FIXME: move to model deltas, add with original to reconstruct
        # FIXME: this really only works with a trusted collaborator. Sanity check this against self.model
        if self.model_update_in_progress is None:
            self.model_update_in_progress = {
                "tensor_dict": model_tensors,
                "is_delta": is_delta,
                "delta_from_version": delta_from_version
            }
        # otherwise, we compute the streaming weighted average
        else:
            # get the current total update size
            total_update_size = np.sum(
                list(self.per_col_round_stats["collaborator_training_sizes"].values()))

            # compute the weights for the global vs local tensors for our streaming average
            weight_g = total_update_size / (message.data_size + total_update_size)
            weight_l = message.data_size / (message.data_size + total_update_size)

            # The model parameters are represented in float32 and will be transmitted as a byte stream.
            weight_g = weight_g.astype(np.float32)
            weight_l = weight_l.astype(np.float32)

            # FIXME: right now we're really using names just to sanity check consistent ordering

            # check that the models include the same number of tensors, and that whether or not
            # it is a delta and from what version are the same
            check_equal(len(self.model_update_in_progress["tensor_dict"]),
                        len(model_tensors),
                        self.logger)
            check_equal(self.model_update_in_progress["is_delta"], is_delta, self.logger)
            check_equal(self.model_update_in_progress["delta_from_version"],
                        delta_from_version,
                        self.logger)

            # aggregate all the model tensors in the tensor_dict
            # (weighted average of local update l and global tensor g for all l, g)
            for name, l in model_tensors.items():
                g = self.model_update_in_progress["tensor_dict"][name]
                # check that g and l have the same shape
                check_equal(g.shape, l.shape, self.logger)

                # now store a weighted average into the update in progress
                self.model_update_in_progress["tensor_dict"][name] = np.average(
                    [g, l], weights=[weight_g, weight_l], axis=0)

        # store the loss results and training update size
        self.per_col_round_stats["loss_results"][message.header.sender] = message.loss
        self.per_col_round_stats["collaborator_training_sizes"][
            message.header.sender] = message.data_size

        # return LocalModelUpdateAck
        self.logger.debug("Complete model update from %s " % message.header.sender)
        reply = LocalModelUpdateAck(header=self.create_reply_header(message))

        self.end_of_round_check()

        self.logger.debug(
            'aggregator handled UploadLocalModelUpdate in time {}'.format(time.time() - t))
    finally:
        self.mutex.release()

    return reply
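
# A standalone sketch (illustration only, not aggregator code) of the streaming weighted
# average used in UploadLocalModelUpdate above: folding collaborator updates in one at a
# time, weighted by data size, matches a single data-size-weighted mean over all updates.
# The tensors and data sizes below are made up.
import numpy as np

updates = [(np.array([1.0, 2.0]), 10),
           (np.array([3.0, 4.0]), 30),
           (np.array([5.0, 6.0]), 60)]

in_progress, total_update_size = None, 0
for tensor, data_size in updates:
    if in_progress is None:
        # first update of the round is taken as-is
        in_progress, total_update_size = tensor, data_size
    else:
        weight_g = total_update_size / (data_size + total_update_size)
        weight_l = data_size / (data_size + total_update_size)
        in_progress = np.average([in_progress, tensor],
                                 weights=[weight_g, weight_l], axis=0)
        total_update_size += data_size

direct = np.average([t for t, _ in updates],
                    weights=[n for _, n in updates], axis=0)
assert np.allclose(in_progress, direct)
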
def do_download_model_job(self):
    """Download model operation

    Asks the aggregator for the latest model to download and downloads it.
    """
    # time the download
    download_start = time.time()

    # sanity check on version is implicit in send
    # FIXME: this needs to be a more robust response. The aggregator should actually have sent
    #        an error code, rather than an unhandled exception
    # an exception can happen in cases where we simply need to retry
    for i in range(self.num_retries):
        try:
            reply = self.channel.DownloadModel(
                ModelDownloadRequest(header=self.create_message_header(),
                                     model_header=self.model_header))
            break
        except Exception as e:
            self.logger.exception(repr(e))
            # if this was the final retry, raise the exception
            if i + 1 == self.num_retries:
                raise e
            else:
                self.logger.warning("Retrying download of model. Try {} of {}".format(
                    i + 1, self.num_retries))

    received_model_proto = reply.model
    received_model_version = received_model_proto.header.version

    # handle the possibility that the received model is a delta
    received_model_is_delta = received_model_proto.header.is_delta
    received_model_delta_from_version = received_model_proto.header.delta_from_version

    self.logger.info("{} took {} seconds to download the model".format(
        self, round(time.time() - download_start, 3)))

    self.validate_header(reply)
    self.logger.info("{} - Completed the model downloading job.".format(self))

    check_type(reply, GlobalModelUpdate, self.logger)

    # check if our version has been reverted, possibly due to an aggregator reset
    version_reverted = self.model_header.version > received_model_proto.header.version

    # set our model header
    self.model_header = received_model_proto.header

    # compute the aggregated tensors dict from the model proto
    agg_tensor_dict = deconstruct_proto(model_proto=received_model_proto,
                                        compression_pipeline=self.compression_pipeline)

    # TODO: If updating of the base is not done every round, we will no longer be able to use
    #       the base to get the current global values of the shared tensors.
    if self.send_model_deltas:
        self.update_base_for_deltas(
            tensor_dict=agg_tensor_dict,
            delta_from_version=received_model_delta_from_version,
            version=received_model_version,
            is_delta=received_model_is_delta)
        # base_for_deltas can provide the global shared tensor values here
        agg_tensor_dict = self.base_for_deltas["tensor_dict"]

    # restore any tensors held out from aggregation
    tensor_dict = {**agg_tensor_dict, **self.holdout_tensors}

    if self.opt_treatment == OptTreatment.CONTINUE_GLOBAL:
        with_opt_vars = True
    else:
        with_opt_vars = False

    # Ensure proper initialization regardless of model state. Initial global models do not
    # contain optimizer state, and so cannot be used to reset the optimizer params.
    # Additionally, in any mode other than CONTINUE_GLOBAL, if we received an older model,
    # we need to reset our optimizer parameters.
    if reply.model.header.version == 0 or \
            (self.opt_treatment != OptTreatment.CONTINUE_GLOBAL and version_reverted):
        with_opt_vars = False
        self.logger.info("Resetting optimizer vars")
        self.wrapped_model.reset_opt_vars()

    self.wrapped_model.set_tensor_dict(tensor_dict, with_opt_vars=with_opt_vars)
    self.logger.debug("Loaded the model.")

    # if we are supposed to save the best model and this model is the best, we save it
    if reply.is_global_best and self.save_best_native_path:
        self.wrapped_model.save_native(self.save_best_native_path,
                                       self.save_best_native_kwargs)
        self.logger.info("Saving best model to {}".format(self.save_best_native_path))

    # if we should save metadata and have metadata in the protobuf
    if self.save_metadata_path is not None and reply.metadata_yaml is not None:
        with open(self.save_metadata_path, 'w') as f:
            f.write(reply.metadata_yaml)
        self.logger.info("Wrote metadata to {}".format(self.save_metadata_path))

    # FIXME: for the CONTINUE_LOCAL treatment, we need to store the status in case of a crash.
    if self.opt_treatment == OptTreatment.RESET:
        try:
            self.wrapped_model.reset_opt_vars()
        except:
            self.logger.exception("Failed to reset the optimization variables.")
            raise
        else:
            self.logger.debug("Reset the optimization variables.")
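
# A minimal sketch (assumption: deltas are additive per tensor and keyed like the base, as
# suggested by the "add with original to reconstruct" FIXME in the aggregator) of how a delta
# update could be folded into a locally held base to recover full global tensor values. The
# real bookkeeping, including version tracking, lives in update_base_for_deltas / base_for_deltas.
import numpy as np

def apply_delta(base_tensor_dict, delta_tensor_dict):
    """Return base + delta for every tensor named in the delta."""
    return {name: base_tensor_dict[name] + delta_tensor_dict[name]
            for name in delta_tensor_dict}

base = {'conv1.weight': np.zeros(3), 'conv1.bias': np.ones(3)}     # toy tensors
delta = {'conv1.weight': np.full(3, 0.5), 'conv1.bias': np.full(3, -0.1)}
reconstructed = apply_delta(base, delta)
assert np.allclose(reconstructed['conv1.weight'], 0.5)
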
def main(data_csv_path, gandlf_config_path, model_weights_path, output_pardir,
         model_output_tag, device):
    # we will use the GANDLFData val loader to serve up the samples to perform inference on;
    # the same csv is also copied into the training loader (but that loader is not used)

    # These get passed to the data constructor
    data_path = {
        'train': data_csv_path,
        'val': data_csv_path,
        'model_params_filepath': gandlf_config_path
    }
    divisibility_factor = 16

    # construct the data object
    data = GANDLFData(data_path=data_path,
                      divisibility_factor=divisibility_factor)

    # construct the model object (requires cpu since we're passing [padded] whole brains)
    model = Model(data=data,
                  base_filters=30,
                  min_learning_rate=0.000001,
                  max_learning_rate=0.001,
                  learning_rate_cycles_per_epoch=0.5,
                  n_classes=4,
                  n_channels=4,
                  loss_function='dc',
                  opt='sgd',
                  use_penalties=False,
                  device=device)

    # populate the model weights
    proto_path = model_weights_path
    proto = load_proto(proto_path)
    tensor_dict_from_proto = deconstruct_proto(proto, NoCompressionPipeline())

    # restore any tensors held out from the proto
    _, holdout_tensors = split_tensor_dict_for_holdouts(None, model.get_tensor_dict())
    tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}
    model.set_tensor_dict(tensor_dict, with_opt_vars=False)

    print("\nWill be running inference on {} samples.\n".format(
        model.get_validation_data_size()))

    if not os.path.exists(output_pardir):
        os.mkdir(output_pardir)

    for subject in data.get_val_loader():
        # using the first mode path because it is the only one that is always defined
        first_mode_path = subject['1']['path'][0]
        subfolder = first_mode_path.split('/')[-2]

        # prep the path for the output file
        output_subdir = os.path.join(output_pardir, subfolder)
        if not os.path.exists(output_subdir):
            os.mkdir(output_subdir)
        outpath = os.path.join(output_subdir,
                               subfolder + model_output_tag + '_seg.nii.gz')

        if is_mask_present(subject):
            label_path = subject['label']['path'][0]
            label_file = label_path.split('/')[-1]
            # copy the label file over to the output subdir
            copy_label_path = os.path.join(output_subdir, label_file)
            shutil.copyfile(label_path, copy_label_path)

        features, labels = subject_to_feature_and_label(subject)

        output = infer(model, features)
        output = np.squeeze(output.cpu().numpy())

        # crop away the padding we put in
        output = output[:, :, :, :155]

        # the label on disk is transposed from what the gandlf loader produces
        print("\nWARNING: the gandlf loader produces output transposed from what sitk reads "
              "from file, so transposing here.\n")
        output = np.transpose(output, [0, 3, 2, 1])

        # process the float outputs (across output channels), providing labels as defined
        # in the values of class_label_map
        output = new_labels_from_float_output(array=output,
                                              class_label_map=class_label_map,
                                              binary_classification=False)

        # convert the array to a SimpleITK image
        image = sitk.GetImageFromArray(output)
        image.CopyInformation(sitk.ReadImage(first_mode_path))

        print("\nWriting inference NIfTI image of shape {} to {}\n".format(
            output.shape, outpath))
        sitk.WriteImage(image, outpath)
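
# A small standalone sketch (shapes assumed from the 240x240x155 BraTS convention and the
# divisibility_factor of 16 used above): the loader pads the slice axis up to a multiple of
# 16, so the crop to :155 in the loop above simply removes that padding again.
import numpy as np

divisibility_factor = 16
n_slices = 155
padded_slices = -(-n_slices // divisibility_factor) * divisibility_factor  # ceil -> 160

padded_output = np.zeros((4, 240, 240, padded_slices), dtype=np.float32)   # channels first
cropped_output = padded_output[:, :, :, :n_slices]
assert cropped_output.shape == (4, 240, 240, 155)
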
def main(plan, model_weights_filename, native_model_weights_filepath, data_dir,
         logging_config_path, logging_default_level, logging_directory, model_device):
    """Runs inference according to the flplan, data-dir and weights file.

    Output format is determined by the data object in the flplan.

    Args:
        plan (string)                          : The filename for the federation (FL) plan YAML file
        model_weights_filename (string)        : A .pbuf filename in the common weights directory
                                                 (mutually exclusive with native_model_weights_filepath).
                                                 NOTE: these must be uncompressed weights!!
        native_model_weights_filepath (string) : A framework-specific filepath. Path will be relative to the
                                                 working directory. (mutually exclusive with model_weights_filename)
        data_dir (string)                      : The directory path for the parent directory containing the data.
                                                 Path will be relative to the working directory.
        logging_config_path (string)           : The logging configuration file
        logging_default_level (string)         : The log level
        logging_directory (string)             : The directory to write log files into
        model_device (string)                  : The device on which to run the model
    """
    if model_weights_filename is not None and native_model_weights_filepath is not None:
        sys.exit("Parameters model_weights_filename and native_model_weights_filepath are "
                 "mutually exclusive.\nmodel_weights_filename was set to {}\n"
                 "native_model_weights_filepath was set to {}".format(
                     model_weights_filename, native_model_weights_filepath))

    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    logging_directory = os.path.join(script_dir, logging_directory)

    setup_logging(path=logging_config_path,
                  default_level=logging_default_level,
                  logging_directory=logging_directory)

    flplan = parse_fl_plan(os.path.join(plan_dir, plan))

    # check the inference config
    if 'inference' not in flplan:
        sys.exit("FL Plan does not contain a top-level 'inference' entry. "
                 "By default, inference is disabled.")
    if 'allowed' not in flplan['inference'] or flplan['inference']['allowed'] != True:
        sys.exit("FL Plan must contain an {'inference': {'allowed': True}} entry "
                 "in order for inference to be allowed.")

    # create the data object
    data = create_data_object_with_explicit_data_path(flplan=flplan, data_path=data_dir)

    # create the model object
    model = create_model_object(flplan, data, model_device=model_device)

    # record which tensors were held out from the saved proto
    _, holdout_tensors = remove_and_save_holdout_tensors(model.get_tensor_dict())

    # if pbuf weights, we need to run deconstruct_proto with a NoCompression pipeline
    if model_weights_filename is not None:
        proto_path = os.path.join(base_dir, 'weights', model_weights_filename)
        proto = load_proto(proto_path)
        tensor_dict_from_proto = deconstruct_proto(proto, NoCompressionPipeline())

        # restore any tensors held out from the proto
        tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}

        model.set_tensor_dict(tensor_dict, with_opt_vars=False)
    elif native_model_weights_filepath is not None:
        # FIXME: how do we handle kwargs here? Will they come from the flplan?
        model.load_native(native_model_weights_filepath)
    else:
        sys.exit("One of model_weights_filename or native_model_weights_filepath is required.")

    # finally, call the model object's run_inference_and_store_results with the kwargs
    # from the inference block
    inference_kwargs = flplan['inference'].get('kwargs') or {}
    model.run_inference_and_store_results(**inference_kwargs)
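
# A tiny sketch (toy names, not real tensor keys) of the dict merge used when restoring
# held-out tensors above: tensors held out of the proto are re-added from the local model,
# and because holdout_tensors is unpacked last, its values take precedence if a key were
# ever present in both dicts.
tensor_dict_from_proto = {'shared.weight': 1.0}   # tensors carried by the proto
holdout_tensors = {'holdout.stat': 42.0}          # tensors held out, taken from the local model
tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}
assert tensor_dict == {'shared.weight': 1.0, 'holdout.stat': 42.0}
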