def main(plan,
         model_weights_filename,
         native_model_weights_filepath,
         populate_weights_at_init,
         model_file_argument_name,
         data_dir,
         logging_config_path,
         logging_default_level,
         logging_directory,
         model_device,
         inference_patient=None):
    """Runs the inference according to the flplan, data-dir and weights file. Output format is determined by the data object in the flplan

    The process exits via ``sys.exit`` when:
    - both or neither of the weights arguments are given;
    - the plan does not contain ``{'inference': {'allowed': True}}``;
    - the model exposes none of the supported inference methods.

    Args:
        plan (string)                           : The filename for the federation (FL) plan YAML file, relative to <script_dir>/federations/plans
        model_weights_filename (string)         : A .pbuf filename in the common weights directory (mutually exclusive with native_model_weights_filepath). NOTE: these must be uncompressed weights!!
        native_model_weights_filepath (string)  : A framework-specific filepath. Path will be relative to the working directory. (mutually exclusive with model_weights_filename)
        populate_weights_at_init (boolean)      : Whether or not the model populates its own weights at instantiation
        model_file_argument_name (string)       : Name of argument to be passed to model __init__ providing model file location info.
                                                  When falsy, defaults to 'model_weights_filename' or 'native_model_weights_filepath'
                                                  (matching whichever weights argument was supplied)
        data_dir (string)                       : The directory path for the parent directory containing the data. Path will be relative to the working directory.
        logging_config_path (string)            : Path of the logging configuration file passed to setup_logging
        logging_default_level (string)          : The log level
        logging_directory (string)              : Directory for log output, joined onto the script directory
        model_device (string)                   : Device forwarded to create_model_object
        inference_patient (string)              : Subdirectory of single patient to run inference on (exclusively)

    """
    # The argument parser normally enforces these, but guard here as well so that
    # programmatic callers get a clear error instead of e.g. load_native(None).
    if model_weights_filename is not None and native_model_weights_filepath is not None:
        sys.exit(
            "Parameters model_weights_filename and native_model_weights_filepath are mutually exclusive.\n"
            "model_weights_filename was set to {}\nnative_model_weights_filepath was set to {}"
            .format(model_weights_filename, native_model_weights_filepath))
    if model_weights_filename is None and native_model_weights_filepath is None:
        sys.exit(
            "One of model_weights_filename or native_model_weights_filepath is required."
        )

    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    logging_directory = os.path.join(script_dir, logging_directory)

    setup_logging(path=logging_config_path,
                  default_level=logging_default_level,
                  logging_directory=logging_directory)

    flplan = parse_fl_plan(os.path.join(plan_dir, plan))

    # check the inference config
    if 'inference' not in flplan:
        sys.exit(
            "FL Plan does not contain a top-level 'inference' entry. By default, inference is disabled."
        )

    # an empty 'inference:' entry in YAML parses as None; treat it as "not allowed"
    inference_config = flplan['inference'] or {}

    # `!= True` (rather than `is not True`) keeps the original acceptance of any value equal to True
    if inference_config.get('allowed') != True:  # noqa: E712
        sys.exit(
            "FL Plan must contain a {'inference': {'allowed': True}} entry in order for inference to be allowed."
        )

    # create the data object
    data = create_data_object_with_explicit_data_path(
        flplan=flplan, data_path=data_dir, inference_patient=inference_patient)

    # TODO: Find a good way to detect and communicate mishandling of model_file_argument_name
    #       Ie, capture exception of model not gettinng its required kwarg for this purpose, also
    #       how to tell if the model is using its random initialization rather than weights from file?
    if populate_weights_at_init:
        # Supplementing the flplan base model kwargs to include model weights file info.
        # model_weights_filename and native_model_weights_filepath are mutually exclusive and one is required (checked above)
        if model_weights_filename is not None:
            default_argument_name = 'model_weights_filename'
            weights_location = model_weights_filename
        else:
            default_argument_name = 'native_model_weights_filepath'
            weights_location = native_model_weights_filepath

        # tolerate a missing or empty (None) init_kwargs entry in the plan
        model_init = flplan['model_object_init']
        if model_init.get('init_kwargs') is None:
            model_init['init_kwargs'] = {}
        model_init['init_kwargs'][model_file_argument_name
                                  or default_argument_name] = weights_location

    # create the base model object
    model = create_model_object(flplan, data, model_device=model_device)

    # the model may have an 'infer_volume' method instead of 'infer_batch'
    if not hasattr(model, 'infer_batch'):
        if hasattr(model, 'infer_volume'):
            model = InferenceOnlyModelWrapper(data=data, base_model=model)
        elif not hasattr(model, 'run_inference_and_store_results'):
            sys.exit(
                "If model object does not have a 'run_inference_and_store_results' method, it must have either an 'infer_batch' or 'infer_volume' method."
            )

    if not populate_weights_at_init:
        # if pbuf weights, we need to run deconstruct proto with a NoCompression pipeline
        if model_weights_filename is not None:
            proto_path = os.path.join(base_dir, 'weights',
                                      model_weights_filename)
            proto = load_proto(proto_path)
            tensor_dict_from_proto = deconstruct_proto(proto,
                                                       NoCompressionPipeline())

            # restore any tensors held out from the proto
            _, holdout_tensors = remove_and_save_holdout_tensors(
                model.get_tensor_dict())
            tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}

            model.set_tensor_dict(tensor_dict, with_opt_vars=False)

        # native_model_weights_filepath is guaranteed non-None here (checked above)
        else:
            # FIXME: how do we handle kwargs here? Will they come from the flplan?
            model.load_native(native_model_weights_filepath)

    # finally, call the model object's run_inference_and_store_results with the kwargs from the inference block
    inference_kwargs = inference_config.get('kwargs') or {}
    model.run_inference_and_store_results(**inference_kwargs)
    def UploadLocalModelUpdate(self, message):
        """Merge one collaborator's local model update into this round's running aggregate.

        The whole handler runs while holding ``self.mutex``. If the collaborator's
        model header is out of sync (per ``collaborator_out_of_sync``), the update is
        logged, acknowledged and otherwise ignored.

        The first accepted update of a round is stored as-is. Each later update is
        folded in as a running average. The weights are ``message.data_size`` versus
        the sum of training sizes already recorded for the round, cast to float32.

        Args:
            message: Upload message from the collaborator. Its header, model proto,
                loss and data_size are read.

        Returns:
            A ``LocalModelUpdateAck`` whose header comes from ``create_reply_header``.

        """
        self.mutex.acquire(blocking=True)
        try:
            started_at = time.time()
            self.validate_header(message)

            sender = message.header.sender
            self.logger.info("Receive model update from %s " % sender)

            # decode the uploaded tensors with the aggregator's compression pipeline
            local_tensors = deconstruct_proto(
                model_proto=message.model,
                compression_pipeline=self.compression_pipeline)
            local_header = message.model.header
            local_is_delta = local_header.is_delta
            local_delta_base = local_header.delta_from_version

            # a collaborator that is out of sync gets an ack, but its update is discarded
            if self.collaborator_out_of_sync(local_header):
                self.logger.info(
                    "Model version mismatch in UploadLocalModelUpdate from {}. Aggregator version: {} Collaborator version: {}. Ignoring update"
                    .format(sender, self.model.header.version,
                            local_header.version))
                return LocalModelUpdateAck(
                    header=self.create_reply_header(message))

            # a collaborator must not have already reported for this round
            for stat_name in ("loss_results", "collaborator_training_sizes"):
                check_not_in(sender, self.per_col_round_stats[stat_name],
                             self.logger)

            # optionally persist the raw local update
            if self.save_all_models_path is not None:
                self.save_local_update(sender, message.model)

            pending = self.model_update_in_progress
            if pending is None:
                # first update of the round: adopt the tensors unchanged
                # FIXME: move to model deltas, add with original to reconstruct
                # FIXME: this really only works with a trusted collaborator. Sanity check this against self.model
                self.model_update_in_progress = {
                    "tensor_dict": local_tensors,
                    "is_delta": local_is_delta,
                    "delta_from_version": local_delta_base
                }
            else:
                # streaming weighted average: existing aggregate vs this update,
                # weighted by training-set sizes
                seen_size = np.sum(
                    list(self.per_col_round_stats[
                        "collaborator_training_sizes"].values()))
                combined_size = message.data_size + seen_size
                # float32 matches the dtype the model parameters are transmitted in
                pair_weights = [
                    (seen_size / combined_size).astype(np.float32),
                    (message.data_size / combined_size).astype(np.float32),
                ]

                # FIXME: right now we're really using names just to sanity check consistent ordering
                check_equal(len(pending["tensor_dict"]), len(local_tensors),
                            self.logger)
                check_equal(pending["is_delta"], local_is_delta, self.logger)
                check_equal(pending["delta_from_version"], local_delta_base,
                            self.logger)

                running = pending["tensor_dict"]
                for tensor_name, local_value in local_tensors.items():
                    global_value = running[tensor_name]
                    check_equal(global_value.shape, local_value.shape,
                                self.logger)
                    running[tensor_name] = np.average(
                        [global_value, local_value],
                        weights=pair_weights,
                        axis=0)

            # record this collaborator's loss and training size for the round
            self.per_col_round_stats["loss_results"][sender] = message.loss
            self.per_col_round_stats["collaborator_training_sizes"][
                sender] = message.data_size

            self.logger.debug("Complete model update from %s " % sender)
            # the ack is built before end_of_round_check runs (original ordering preserved)
            reply = LocalModelUpdateAck(
                header=self.create_reply_header(message))

            self.end_of_round_check()

            self.logger.debug(
                'aggregator handled UploadLocalModelUpdate in time {}'.format(
                    time.time() - started_at))
        finally:
            self.mutex.release()

        return reply
Exemple #3
0
    def do_download_model_job(self):
        """Download model operation

        Asks the aggregator for the latest model, retrying failed requests, and loads
        the received tensors (plus this collaborator's held-out tensors) into
        ``self.wrapped_model``.

        Side effects:
        - updates ``self.model_header``;
        - when ``self.send_model_deltas`` is set, updates the delta base;
        - may reset optimizer variables;
        - may save the best native model and/or the metadata YAML to disk.

        Raises:
            The last exception raised by ``self.channel.DownloadModel`` once every
            attempt has failed. Any exception from ``reset_opt_vars`` in RESET mode
            is logged and re-raised.

        """

        # time the download
        download_start = time.time()

        # sanity check on version is implicit in send
        # FIXME: this needs to be a more robust response. The aggregator should actually have sent an error code, rather than an unhandled exception
        # an exception can happen in cases where we simply need to retry
        # Always make at least one attempt: with num_retries <= 0 the loop would
        # otherwise never run and `reply` would be unbound below.
        num_attempts = max(1, self.num_retries)
        for attempt in range(1, num_attempts + 1):
            try:
                reply = self.channel.DownloadModel(
                    ModelDownloadRequest(header=self.create_message_header(),
                                         model_header=self.model_header))
                break
            except Exception as e:
                self.logger.exception(repr(e))
                # if final attempt, re-raise with the original traceback
                if attempt == num_attempts:
                    raise
                self.logger.warning(
                    "Retrying download of model. Try {} of {}".format(
                        attempt, num_attempts))

        received_model_proto = reply.model
        received_model_version = received_model_proto.header.version

        # handling possibility that the received model is delta
        received_model_is_delta = received_model_proto.header.is_delta
        received_model_delta_from_version = received_model_proto.header.delta_from_version

        self.logger.info("{} took {} seconds to download the model".format(
            self, round(time.time() - download_start, 3)))

        self.validate_header(reply)
        self.logger.info(
            "{} - Completed the model downloading job.".format(self))

        check_type(reply, GlobalModelUpdate, self.logger)

        # check if our version has been reverted, possibly due to an aggregator reset
        version_reverted = self.model_header.version > received_model_version

        # set our model header
        self.model_header = received_model_proto.header

        # compute the aggregated tensors dict from the model proto
        agg_tensor_dict = deconstruct_proto(
            model_proto=received_model_proto,
            compression_pipeline=self.compression_pipeline)

        # TODO: If updating of base is not done every round, we will no longer be able to use the base to get
        #       the current global values of the shared tensors.
        if self.send_model_deltas:
            self.update_base_for_deltas(
                tensor_dict=agg_tensor_dict,
                delta_from_version=received_model_delta_from_version,
                version=received_model_version,
                is_delta=received_model_is_delta)
            # base_for_deltas can provide the global shared tensor values here
            agg_tensor_dict = self.base_for_deltas["tensor_dict"]

        # restore any tensors held out from aggregation
        tensor_dict = {**agg_tensor_dict, **self.holdout_tensors}

        # optimizer variables are only taken from the global model under CONTINUE_GLOBAL
        with_opt_vars = self.opt_treatment == OptTreatment.CONTINUE_GLOBAL

        # Ensuring proper initialization regardless of model state. Initial global models
        # do not contain optimizer state, and so cannot be used to reset the optimizer params.
        # Additionally, in any mode other than continue global, if we received an older model, we need to
        # reset our optimizer parameters
        if reply.model.header.version == 0 or \
           (self.opt_treatment != OptTreatment.CONTINUE_GLOBAL and version_reverted):
            with_opt_vars = False
            self.logger.info("Resetting optimizer vars")
            self.wrapped_model.reset_opt_vars()

        self.wrapped_model.set_tensor_dict(tensor_dict,
                                           with_opt_vars=with_opt_vars)
        self.logger.debug("Loaded the model.")

        # if we are supposed to save the best model and the model is the best, we save it
        if reply.is_global_best and self.save_best_native_path:
            self.wrapped_model.save_native(self.save_best_native_path,
                                           self.save_best_native_kwargs)
            self.logger.info("Saving best model to {}".format(
                self.save_best_native_path))

        # if we should save metadata and have meta in protobuf
        # NOTE(review): if metadata_yaml is a protobuf string field it is never None
        # (unset fields read as ''), so this may always write the file — confirm.
        if self.save_metadata_path is not None and reply.metadata_yaml is not None:
            with open(self.save_metadata_path, 'w') as f:
                f.write(reply.metadata_yaml)
                self.logger.info("Wrote metadata to {}".format(
                    self.save_metadata_path))

        # FIXME: for the CONTINUE_LOCAL treatment, we need to store the status in case of a crash.
        if self.opt_treatment == OptTreatment.RESET:
            try:
                self.wrapped_model.reset_opt_vars()
            except Exception:
                self.logger.exception(
                    "Failed to reset the optimization variables.")
                raise
            else:
                self.logger.debug("Reset the optimization variables.")
Exemple #4
0
def main(data_csv_path, gandlf_config_path, model_weights_path, output_pardir,
         model_output_tag, device):
    """Run segmentation inference on every sample of a CSV and write NIfTI outputs.

    Each subject from the GANDLFData validation loader produces
    ``<output_pardir>/<subdir>/<subdir><model_output_tag>_seg.nii.gz``. Here
    ``<subdir>`` is the name of the directory containing the subject's modality '1'
    image. If the subject has a label mask, the label file is copied into the same
    output subdirectory.

    Args:
        data_csv_path (string): CSV describing the samples. It is used for both the
            train and val loaders; only the val loader is iterated.
        gandlf_config_path (string): GANDLF model-parameters config file.
        model_weights_path (string): Path to a .pbuf weights file. It is decoded
            with NoCompressionPipeline, so the weights must be uncompressed.
        output_pardir (string): Parent directory for the per-subject outputs
            (created, including parents, if missing).
        model_output_tag (string): Tag inserted into each output filename.
        device (string): Device passed to the Model constructor.
    """

    # we will use the GANDLFData val loader to serve up the samples to perform inference on
    # will copy the data into the training loader (but not used)

    # These get passed to data constructor
    data_path = {
        'train': data_csv_path,
        'val': data_csv_path,
        'model_params_filepath': gandlf_config_path
    }
    divisibility_factor = 16

    # construct the data object
    data = GANDLFData(data_path=data_path,
                      divisibility_factor=divisibility_factor)

    # construct the model object (requires cpu since we're passing [padded] whole brains)
    model = Model(data=data,
                  base_filters=30,
                  min_learning_rate=0.000001,
                  max_learning_rate=0.001,
                  learning_rate_cycles_per_epoch=0.5,
                  n_classes=4,
                  n_channels=4,
                  loss_function='dc',
                  opt='sgd',
                  use_penalties=False,
                  device=device)

    # Populate the model weights from the (uncompressed) protobuf file
    proto = load_proto(model_weights_path)
    tensor_dict_from_proto = deconstruct_proto(proto, NoCompressionPipeline())
    # restore any tensors held out from the proto
    _, holdout_tensors = split_tensor_dict_for_holdouts(
        None, model.get_tensor_dict())
    tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}
    model.set_tensor_dict(tensor_dict, with_opt_vars=False)

    print("\nWill be running inference on {} samples.\n".format(
        model.get_validation_data_size()))

    # makedirs(exist_ok=True) avoids the exists/mkdir race and creates missing parents
    os.makedirs(output_pardir, exist_ok=True)

    # the label on disk is transposed from what the gandlf loader produces;
    # the warning applies to every subject, so print it once rather than per subject
    print(
        "\nWARNING: gandlf loader produces transposed output from what sitk gets from file, so transposing here.\n"
    )

    for subject in data.get_val_loader():
        # using modality '1' because this is the only one that's always defined
        first_mode_path = subject['1']['path'][0]
        # name of the directory holding the image (portable, unlike split('/'))
        subfolder = os.path.basename(os.path.dirname(first_mode_path))

        # prep the path for the output file
        output_subdir = os.path.join(output_pardir, subfolder)
        os.makedirs(output_subdir, exist_ok=True)
        outpath = os.path.join(output_subdir,
                               subfolder + model_output_tag + '_seg.nii.gz')

        if is_mask_present(subject):
            label_path = subject['label']['path'][0]
            # copy the label file over to the output subdir
            copy_label_path = os.path.join(output_subdir,
                                           os.path.basename(label_path))
            shutil.copyfile(label_path, copy_label_path)

        features, _ = subject_to_feature_and_label(subject)

        output = infer(model, features)

        output = np.squeeze(output.cpu().numpy())

        # crop away the padding we put in
        # NOTE(review): 155 is presumably the unpadded depth of the input volumes — confirm
        output = output[:, :, :, :155]

        # match the on-disk (sitk) orientation; see the warning printed above
        output = np.transpose(output, [0, 3, 2, 1])

        # process float outputs (across output channels), providing labels as defined in values of class_label_map
        # NOTE(review): class_label_map is not defined in this function; it must come from module scope
        output = new_labels_from_float_output(array=output,
                                              class_label_map=class_label_map,
                                              binary_classification=False)

        # convert array to SimpleITK image, copying spatial metadata from the input image
        image = sitk.GetImageFromArray(output)
        image.CopyInformation(sitk.ReadImage(first_mode_path))

        print("\nWriting inference NIfTI image of shape {} to {}\n".format(
            output.shape, outpath))
        sitk.WriteImage(image, outpath)
def main(plan, model_weights_filename, native_model_weights_filepath, data_dir, logging_config_path, logging_default_level, logging_directory, model_device):
    """Runs the inference according to the flplan, data-dir and weights file. Output format is determined by the data object in the flplan

    The process exits via ``sys.exit`` when:
    - both or neither of the weights arguments are given;
    - the plan does not contain ``{'inference': {'allowed': True}}``.

    Args:
        plan (string)                           : The filename for the federation (FL) plan YAML file, relative to <script_dir>/federations/plans
        model_weights_filename (string)         : A .pbuf filename in the common weights directory (mutually exclusive with native_model_weights_filepath). NOTE: these must be uncompressed weights!!
        native_model_weights_filepath (string)  : A framework-specific filepath. Path will be relative to the working directory. (mutually exclusive with model_weights_filename)
        data_dir (string)                       : The directory path for the parent directory containing the data. Path will be relative to the working directory.
        logging_config_path (string)            : Path of the logging configuration file passed to setup_logging
        logging_default_level (string)          : The log level
        logging_directory (string)              : Directory for log output, joined onto the script directory
        model_device (string)                   : Device forwarded to create_model_object

    """
    # validate the weights arguments before any (potentially expensive) setup
    if model_weights_filename is not None and native_model_weights_filepath is not None:
        sys.exit(
            "Parameters model_weights_filename and native_model_weights_filepath are mutually exclusive.\n"
            "model_weights_filename was set to {}\nnative_model_weights_filepath was set to {}"
            .format(model_weights_filename, native_model_weights_filepath))
    if model_weights_filename is None and native_model_weights_filepath is None:
        sys.exit("One of model_weights_filename or native_model_weights_filepath is required.")

    # FIXME: consistent filesystem (#15)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    base_dir = os.path.join(script_dir, 'federations')
    plan_dir = os.path.join(base_dir, 'plans')
    logging_directory = os.path.join(script_dir, logging_directory)

    setup_logging(path=logging_config_path, default_level=logging_default_level, logging_directory=logging_directory)

    flplan = parse_fl_plan(os.path.join(plan_dir, plan))

    # check the inference config
    if 'inference' not in flplan:
        sys.exit("FL Plan does not contain a top-level 'inference' entry. By default, inference is disabled.")

    # an empty 'inference:' entry in YAML parses as None; treat it as "not allowed"
    inference_config = flplan['inference'] or {}

    # `!= True` (rather than `is not True`) keeps the original acceptance of any value equal to True
    if inference_config.get('allowed') != True:  # noqa: E712
        sys.exit("FL Plan must contain a {'inference': {'allowed': True}} entry in order for inference to be allowed.")

    # create the data object
    data = create_data_object_with_explicit_data_path(flplan=flplan, data_path=data_dir)

    # create the model object
    model = create_model_object(flplan, data, model_device=model_device)

    # if pbuf weights, we need to run deconstruct proto with a NoCompression pipeline
    if model_weights_filename is not None:
        proto_path = os.path.join(base_dir, 'weights', model_weights_filename)
        proto = load_proto(proto_path)
        tensor_dict_from_proto = deconstruct_proto(proto, NoCompressionPipeline())

        # restore any tensors held out from the saved proto (only needed on this path)
        _, holdout_tensors = remove_and_save_holdout_tensors(model.get_tensor_dict())
        tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}

        model.set_tensor_dict(tensor_dict, with_opt_vars=False)
    else:
        # native_model_weights_filepath is guaranteed non-None here (checked above)
        # FIXME: how do we handle kwargs here? Will they come from the flplan?
        model.load_native(native_model_weights_filepath)

    # finally, call the model object's run_inference_and_store_results with the kwargs from the inference block
    inference_kwargs = inference_config.get('kwargs') or {}
    model.run_inference_and_store_results(**inference_kwargs)