def __init__(self,
                 aggregator_uuid,
                 federation_uuid,
                 collaborator_common_names,
                 init_model_fpath,
                 latest_model_fpath,
                 best_model_fpath,
                 rounds_to_train=256,
                 minimum_reporting=-1,
                 straggler_cutoff_time=np.inf,
                 single_col_cert_common_name=None,
                 compression_pipeline=None,
                 end_of_round_metadata=None,
                 init_metadata_fname=None,
                 latest_metadata_fname=None,
                 send_metadata_to_clients=False,
                 **kwargs):
        self.logger = logging.getLogger(__name__)
        self.uuid = aggregator_uuid
        self.federation_uuid = federation_uuid
        # FIXME: Should we do anything to ensure the initial model is compressed?
        self.model = load_proto(init_model_fpath)
        self.latest_model_fpath = latest_model_fpath
        self.best_model_fpath = best_model_fpath
        self.collaborator_common_names = collaborator_common_names
        self.round_num = 1
        self.rounds_to_train = rounds_to_train
        self.quit_job_sent_to = []
        self.minimum_reporting = minimum_reporting
        self.straggler_cutoff_time = straggler_cutoff_time
        self.round_start_time = None
        self.single_col_cert_common_name = single_col_cert_common_name

        if self.single_col_cert_common_name is not None:
            self.log_big_warning()
        else:
            self.single_col_cert_common_name = ''  # FIXME: '' instead of None is just for protobuf compatibility. Cleaner solution?

        self.model_update_in_progress = None

        self.init_per_col_round_stats()
        self.best_model_score = None
        self.aggregated_model_is_global_best = True
        self.mutex = Lock()

        self.compression_pipeline = compression_pipeline or NoCompressionPipeline()

        self.end_of_round_metadata = end_of_round_metadata
        self.init_metadata_fname = init_metadata_fname
        self.latest_metadata_fname = latest_metadata_fname
        self.send_metadata_to_clients = send_metadata_to_clients

        if self.init_metadata_fname is not None:
            self.metadata = load_yaml(init_metadata_fname)
        else:
            self.metadata = {}
        self.metadata['aggregator_uuid'] = aggregator_uuid
        self.metadata['federation_uuid'] = federation_uuid
        self.metadata_for_round = {}
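
# A minimal construction sketch for the __init__ above, assuming it belongs to the
# framework's Aggregator class (the class name and all paths below are illustrative):
#
#   agg = Aggregator(aggregator_uuid='agg-0',
#                    federation_uuid='fed-0',
#                    collaborator_common_names=['col_0', 'col_1'],
#                    init_model_fpath='weights/init.pbuf',
#                    latest_model_fpath='weights/latest.pbuf',
#                    best_model_fpath='weights/best.pbuf',
#                    rounds_to_train=10)
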
def load_model(directory):
    extra_model_info = load_proto(os.path.join(directory,
                                               'ExtraModelInfo.pbuf'),
                                  proto_type=ExtraModelInfo)

    tensor_dict_from_proto = {}
    for t in extra_model_info.tensor_names:
        t_hash = hash_string(t)
        tensor_proto = load_proto(os.path.join(directory,
                                               '{}.pbuf'.format(t_hash)),
                                  proto_type=TensorProto)
        if t != tensor_proto.name:
            raise RuntimeError(
                "Loaded the wrong tensor! Meant to load: {}, did load: {}, read file: {}"
                .format(t, tensor_proto.name, t_hash))
        tensor_dict_from_proto[t] = tensor_proto_to_numpy_array(tensor_proto)

    return tensor_dict_from_proto
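
# Illustrative usage of load_model (the directory path is a placeholder; it should
# contain ExtraModelInfo.pbuf plus one <hash>.pbuf file per named tensor):
#
#   tensor_dict = load_model('saved_models/round_10')
#   for name, array in tensor_dict.items():
#       print(name, array.shape)
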
def main(plan,
         model_weights_filename,
         native_model_weights_filepath,
         populate_weights_at_init,
         model_file_argument_name,
         data_dir,
         logging_config_path,
         logging_default_level,
         logging_directory,
         model_device,
         inference_patient=None):
    """Runs the inference according to the flplan, data-dir and weights file. Output format is determined by the data object in the flplan

    Args:
        plan (string)                           : The filename for the federation (FL) plan YAML file
        model_weights_filename (string)         : A .pbuf filename in the common weights directory (mutually exclusive with native_model_weights_filepath). NOTE: these must be uncompressed weights!!
        native_model_weights_filepath (string)  : A framework-specific filepath. Path will be relative to the working directory. (mutually exclusive with model_weights_filename)
        populate_weights_at_init (boolean)      : Whether or not the model populates its own weights at instantiation
        model_file_argument_name (string)       : Name of argument to be passed to model __init__ providing model file location info
        data_dir (string)                       : The directory path for the parent directory containing the data. Path will be relative to the working directory.
        logging_config_path (string)            : The logging configuration file path
        logging_default_level (string)          : The default logging level
        logging_directory (string)              : The directory where log files are written
        model_device (string)                   : The device string passed to the model (e.g. 'cpu' or 'cuda')
        inference_patient (string)              : Subdirectory of single patient to run inference on (exclusively)

    """

    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    logging_directory = os.path.join(script_dir, logging_directory)

    setup_logging(path=logging_config_path,
                  default_level=logging_default_level,
                  logging_directory=logging_directory)

    flplan = parse_fl_plan(os.path.join(plan_dir, plan))

    # check the inference config
    if 'inference' not in flplan:
        sys.exit(
            "FL Plan does not contain a top-level 'inference' entry. By default, inference is disabled."
        )

    if 'allowed' not in flplan[
            'inference'] or flplan['inference']['allowed'] != True:
        sys.exit(
            "FL Plan must contain a {'inference: {'allowed': True}} entry in order for inference to be allowed."
        )

    # create the data object
    data = create_data_object_with_explicit_data_path(
        flplan=flplan, data_path=data_dir, inference_patient=inference_patient)

    # TODO: Find a good way to detect and communicate mishandling of model_file_argument_name
    #       I.e., capture the exception when the model does not get its required kwarg for this purpose; also,
    #       how do we tell whether the model is using its random initialization rather than weights from a file?
    if populate_weights_at_init:
        # Supplementing the flplan base model kwargs to include model weights file info.
        if model_weights_filename is not None:
            model_file_argument_name = model_file_argument_name or 'model_weights_filename'
            flplan['model_object_init']['init_kwargs'].update(
                {model_file_argument_name: model_weights_filename})
        # model_weights_filename and native_model_weights_filepath are mutually exclusive and required (see argument parser)
        else:
            model_file_argument_name = model_file_argument_name or 'native_model_weights_filepath'
            flplan['model_object_init']['init_kwargs'].update(
                {model_file_argument_name: native_model_weights_filepath})

    # create the base model object
    model = create_model_object(flplan, data, model_device=model_device)

    # the model may have an 'infer_volume' method instead of 'infer_batch'
    if not hasattr(model, 'infer_batch'):
        if hasattr(model, 'infer_volume'):
            model = InferenceOnlyModelWrapper(data=data, base_model=model)
        elif not hasattr(model, 'run_inference_and_store_results'):
            sys.exit(
                "If model object does not have a 'run_inference_and_store_results' method, it must have either an 'infer_batch' or 'infer_volume' method."
            )

    if not populate_weights_at_init:
        # if pbuf weights, we need to run deconstruct proto with a NoCompression pipeline
        if model_weights_filename is not None:
            proto_path = os.path.join(base_dir, 'weights',
                                      model_weights_filename)
            proto = load_proto(proto_path)
            tensor_dict_from_proto = deconstruct_proto(proto,
                                                       NoCompressionPipeline())

            # restore any tensors held out from the proto
            _, holdout_tensors = remove_and_save_holdout_tensors(
                model.get_tensor_dict())
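            # dict-unpacking order matters: holdout tensors take precedence over
            # any same-named entries loaded from the proto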
            tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}

            model.set_tensor_dict(tensor_dict, with_opt_vars=False)

        # model_weights_filename and native_model_weights_filepath are mutually exclusive and required (see argument parser)
        else:
            # FIXME: how do we handle kwargs here? Will they come from the flplan?
            model.load_native(native_model_weights_filepath)

    # finally, call the model object's run_inference_and_store_results with the kwargs from the inference block
    inference_kwargs = flplan['inference'].get('kwargs') or {}
    model.run_inference_and_store_results(**inference_kwargs)
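
# A minimal invocation sketch for the main() above (all values are illustrative; in
# the real tooling these arguments are normally supplied by a CLI wrapper):
#
#   main(plan='my_flplan.yaml',
#        model_weights_filename='best.pbuf',
#        native_model_weights_filepath=None,
#        populate_weights_at_init=False,
#        model_file_argument_name=None,
#        data_dir='data/my_dataset',
#        logging_config_path='logging.yaml',
#        logging_default_level='INFO',
#        logging_directory='logs',
#        model_device='cpu')
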
    def __init__(self,
                 aggregator_uuid,
                 federation_uuid,
                 collaborator_common_names,
                 init_model_fpath,
                 latest_model_fpath,
                 best_model_fpath,
                 rounds_to_train=256,
                 minimum_reporting=-1,
                 straggler_cutoff_time=np.inf,
                 single_col_cert_common_name=None,
                 compression_pipeline=None,
                 end_of_round_metadata=None,
                 init_metadata_fname=None,
                 latest_metadata_fname=None,
                 send_metadata_to_clients=False,
                 save_all_models_path=None,
                 runtime_aggregator_config_dir=None,
                 runtime_configurable_params=None,
                 **kwargs):
        self.logger = logging.getLogger(__name__)
        self.uuid = aggregator_uuid
        self.federation_uuid = federation_uuid

        self.latest_model_fpath = latest_model_fpath
        self.best_model_fpath = best_model_fpath
        self.collaborator_common_names = collaborator_common_names
        self.rounds_to_train = rounds_to_train
        self.quit_job_sent_to = []
        self.minimum_reporting = minimum_reporting
        self.straggler_cutoff_time = straggler_cutoff_time
        self.round_start_time = None
        self.single_col_cert_common_name = single_col_cert_common_name

        self.save_all_models_path = save_all_models_path
        self.runtime_aggregator_config_dir = runtime_aggregator_config_dir
        self.runtime_configurable_params = runtime_configurable_params

        if self.runtime_aggregator_config_dir is not None:
            self.update_config_from_filesystem()

        if self.single_col_cert_common_name is not None:
            self.log_big_warning()
        else:
            self.single_col_cert_common_name = ''  # FIXME: '' instead of None is just for protobuf compatibility. Cleaner solution?

        # FIXME: Should we do anything to ensure the initial model is compressed?
        self.model = load_proto(init_model_fpath)
        self.logger.info(
            "Loaded initial model from {}".format(init_model_fpath))
        self.logger.info("Initial model version is {}".format(
            self.model.header.version))

        self.round_num = self.model.header.version + 1
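        # Round numbering resumes from the loaded model: if the initial model is at
        # version N, the next training round is N + 1 rather than 1.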

        self._GRACEFULLY_QUIT = False
        self._do_quit = False

        self.model_update_in_progress = None

        self.init_per_col_round_stats()
        self.best_model_score = None
        self.aggregated_model_is_global_best = True
        self.mutex = Lock()

        self.compression_pipeline = compression_pipeline or NoCompressionPipeline()

        self.end_of_round_metadata = end_of_round_metadata
        self.init_metadata_fname = init_metadata_fname
        self.latest_metadata_fname = latest_metadata_fname
        self.send_metadata_to_clients = send_metadata_to_clients

        if self.init_metadata_fname is not None:
            self.metadata = load_yaml(init_metadata_fname)
        else:
            self.metadata = {}
        self.metadata['aggregator_uuid'] = aggregator_uuid
        self.metadata['federation_uuid'] = federation_uuid
        self.metadata_for_round = {}
def main(data_csv_path, gandlf_config_path, model_weights_path, output_pardir,
         model_output_tag, device):

    # we will use the GANDLFData val loader to serve up the samples to perform inference on
    # the same csv is also passed as the training loader, but that loader is not used

    # These get passed to data constructor
    data_path = {
        'train': data_csv_path,
        'val': data_csv_path,
        'model_params_filepath': gandlf_config_path
    }
    divisibility_factor = 16

    # construct the data object
    data = GANDLFData(data_path=data_path,
                      divisibility_factor=divisibility_factor)

    # construct the model object (requires cpu since we're passing [padded] whole brains)
    model = Model(data=data,
                  base_filters=30,
                  min_learning_rate=0.000001,
                  max_learning_rate=0.001,
                  learning_rate_cycles_per_epoch=0.5,
                  n_classes=4,
                  n_channels=4,
                  loss_function='dc',
                  opt='sgd',
                  use_penalties=False,
                  device=device)

    # Populate the model weights

    proto_path = model_weights_path
    proto = load_proto(proto_path)
    tensor_dict_from_proto = deconstruct_proto(proto, NoCompressionPipeline())
    # restore any tensors held out from the proto
    _, holdout_tensors = split_tensor_dict_for_holdouts(
        None, model.get_tensor_dict())
    tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}
    model.set_tensor_dict(tensor_dict, with_opt_vars=False)

    print("\nWill be running inference on {} samples.\n".format(
        model.get_validation_data_size()))

    if not os.path.exists(output_pardir):
        os.mkdir(output_pardir)

    for subject in data.get_val_loader():
        # using mode '1' because it is the only one that's always defined
        first_mode_path = subject['1']['path'][0]
        subfolder = first_mode_path.split('/')[-2]

        #prep the path for the output file
        output_subdir = os.path.join(output_pardir, subfolder)
        if not os.path.exists(output_subdir):
            os.mkdir(output_subdir)
        outpath = os.path.join(output_subdir,
                               subfolder + model_output_tag + '_seg.nii.gz')

        if is_mask_present(subject):
            label_path = subject['label']['path'][0]
            label_file = label_path.split('/')[-1]
            # copy the label file over to the output subdir
            copy_label_path = os.path.join(output_subdir, label_file)
            shutil.copyfile(label_path, copy_label_path)

        features, labels = subject_to_feature_and_label(subject)

        output = infer(model, features)

        output = np.squeeze(output.cpu().numpy())

        # crop away the padding we put in
        output = output[:, :, :, :155]
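        # (the loader pads volumes so each spatial dim is divisible by
        # divisibility_factor; 155 is assumed to be the original, pre-padding depth)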

        # the label on disk is transposed from what the gandlf loader produces
        print(
            "\nWARNING: gandlf loader produces transposed output from what sitk gets from file, so transposing here.\n"
        )
        output = np.transpose(output, [0, 3, 2, 1])

        # process float outputs (across output channels), providing labels as defined in the values of class_label_map
        output = new_labels_from_float_output(array=output,
                                              class_label_map=class_label_map,
                                              binary_classification=False)

        # convert array to SimpleITK image
        image = sitk.GetImageFromArray(output)

        image.CopyInformation(sitk.ReadImage(first_mode_path))

        print("\nWriting inference NIfTI image of shape {} to {}\n".format(
            output.shape, outpath))
        sitk.WriteImage(image, outpath)
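
# Illustrative invocation of the GANDLF inference main() above (paths, tag, and
# device are placeholders):
#
#   main(data_csv_path='val_data.csv',
#        gandlf_config_path='gandlf_config.yaml',
#        model_weights_path='weights/best.pbuf',
#        output_pardir='inference_output',
#        model_output_tag='_best',
#        device='cpu')
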
def main(plan, model_weights_filename, native_model_weights_filepath, data_dir, logging_config_path, logging_default_level, logging_directory, model_device):
    """Runs the inference according to the flplan, data-dir and weights file. Output format is determined by the data object in the flplan

    Args:
        plan (string)                   : The filename for the federation (FL) plan YAML file
        model_weights_filename (string) : A .pbuf filename in the common weights directory (mutually exclusive with native_model_weights_filepath). NOTE: these must be uncompressed weights!!
        native_model_weights_filepath (string) : A framework-specific filepath. Path will be relative to the working directory. (mutually exclusive with model_weights_filename)
        data_dir (string)               : The directory path for the parent directory containing the data. Path will be relative to the working directory.
        logging_config_path (string)    : The logging configuration file path
        logging_default_level (string)  : The default logging level
        logging_directory (string)      : The directory where log files are written
        model_device (string)           : The device string passed to the model (e.g. 'cpu' or 'cuda')

    """

    if model_weights_filename is not None and native_model_weights_filepath is not None:
        sys.exit("Parameters model_weights_filename and native_model_weights_filepath are mutually exclusive.\nmodel_weights_file was set to {}\native_model_weights_filepath was set to {}".format(model_weights_filename, native_model_weights_filepath))

    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    logging_directory = os.path.join(script_dir, logging_directory)

    setup_logging(path=logging_config_path, default_level=logging_default_level, logging_directory=logging_directory)

    flplan = parse_fl_plan(os.path.join(plan_dir, plan))

    # check the inference config
    if 'inference' not in flplan:
        sys.exit("FL Plan does not contain a top-level 'inference' entry. By default, inference is disabled.")
    
    if 'allowed' not in flplan['inference'] or flplan['inference']['allowed'] != True:
        sys.exit("FL Plan must contain an {'inference': {'allowed': True}} entry in order for inference to be allowed.")

    # create the data object
    data = create_data_object_with_explicit_data_path(flplan=flplan, data_path=data_dir)

    # create the model object
    model = create_model_object(flplan, data, model_device=model_device)

    # record which tensors were held out from the saved proto
    _, holdout_tensors = remove_and_save_holdout_tensors(model.get_tensor_dict())

    # if pbuf weights, we need to run deconstruct proto with a NoCompression pipeline
    if model_weights_filename is not None:        
        proto_path = os.path.join(base_dir, 'weights', model_weights_filename)
        proto = load_proto(proto_path)
        tensor_dict_from_proto = deconstruct_proto(proto, NoCompressionPipeline())

        # restore any tensors held out from the proto
        tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}

        model.set_tensor_dict(tensor_dict, with_opt_vars=False)
    elif native_model_weights_filepath is not None:
        # FIXME: how do we handle kwargs here? Will they come from the flplan?
        model.load_native(native_model_weights_filepath)
    else:
        sys.exit("One of model_weights_filename or native_model_weights_filepath is required.")

    # finally, call the model object's run_inference_and_store_results with the kwargs from the inference block
    inference_kwargs = flplan['inference'].get('kwargs') or {}
    model.run_inference_and_store_results(**inference_kwargs)