Beispiel #1
0
    def __init__(
        self,
        checkpoint=None,
        text_postproc=None,
        data_preproc=None,
        codec=None,
        network=None,
        batch_size=1,
        processes=1,
        auto_update_checkpoints=True,
        with_gt=False,
        ctc_decoder_params=None,
    ):
        """ Predicting a dataset based on a trained model

        Exactly one of `checkpoint` or `network` must be provided.

        Parameters
        ----------
        checkpoint : str, optional
            filepath of the checkpoint of the network to load, alternatively you can directly use a loaded `network`
        text_postproc : TextProcessor, optional
            text processor to be applied on the predicted sentence for the final output.
            If None and a checkpoint is loaded, the text processor is created from the checkpoint.
        data_preproc : DataProcessor, optional
            data processor (must be the same as of the trained model) to be applied to the input image.
            If None and a checkpoint is loaded, the data processor is created from the checkpoint.
        codec : Codec, optional
            Codec of the deep net to use for decoding. This parameter is only required if a custom codec is used,
            or a `network` has been provided instead of a `checkpoint`
        network : ModelInterface, optional
            DNN instance to used. Alternatively you can provide a `checkpoint` to load a network.
        batch_size : int, optional
            Batch size to use for prediction (only used when the network is created from a `checkpoint`)
        processes : int, optional
            The number of processes to use for prediction
        auto_update_checkpoints : bool, optional
            Update old models automatically (this will change the checkpoint files)
        with_gt : bool, optional
            The prediction will also output the ground truth if available else None
        ctc_decoder_params : optional
            Parameters of the ctc decoder (only used when the network is created from a `checkpoint`)

        Raises
        ------
        ValueError
            If both or neither of `checkpoint` and `network` are provided, or if a
            `network` is provided without a `codec`.
        """
        # Validate the mutually exclusive model sources up front, before any
        # checkpoint is opened or any backend is created.
        if checkpoint and network is not None:
            raise ValueError(
                "Either a checkpoint or a network can be provided, not both")
        if not checkpoint and network is None:
            raise ValueError(
                "Either a checkpoint or an existing network must be provided")
        if not checkpoint and codec is None:
            raise ValueError(
                "A codec is required if preloaded network is used.")

        self.network = network
        self.checkpoint = checkpoint
        self.processes = processes
        self.auto_update_checkpoints = auto_update_checkpoints
        self.with_gt = with_gt

        if checkpoint:
            ckpt = Checkpoint(checkpoint,
                              auto_update=self.auto_update_checkpoints)
            # Keep the path as reported by the Checkpoint object rather than
            # the raw argument.
            self.checkpoint = ckpt.ckpt_path
            checkpoint_params = ckpt.checkpoint
            self.model_params = checkpoint_params.model
            # Explicitly passed codec / processors take precedence over the
            # ones stored in the checkpoint.
            self.codec = codec if codec is not None else Codec(
                self.model_params.codec.charset)

            self.network_params = self.model_params.network
            backend = create_backend_from_checkpoint(
                checkpoint_params=checkpoint_params, processes=processes)
            self.text_postproc = (
                text_postproc if text_postproc is not None
                else text_processor_from_proto(
                    self.model_params.text_postprocessor, "post"))
            self.data_preproc = (
                data_preproc if data_preproc is not None
                else data_processor_from_proto(
                    self.model_params.data_preprocessor))
            self.network = backend.create_net(
                codec=self.codec,
                ctc_decoder_params=ctc_decoder_params,
                checkpoint_to_load=ckpt,
                graph_type="predict",
                batch_size=batch_size)
        else:
            # Preloaded network: the caller-provided codec and processors are
            # used as-is (the processors may be None).
            self.codec = codec
            self.model_params = None
            self.network_params = network.network_proto
            self.text_postproc = text_postproc
            self.data_preproc = data_preproc

        self.out_to_in_trans = OutputToInputTransformer(
            self.data_preproc, self.network)
Beispiel #2
0
    def __init__(self,
                 checkpoint=None,
                 text_postproc=None,
                 data_preproc=None,
                 codec=None,
                 network=None,
                 batch_size=1,
                 processes=1):
        """ Predicting a dataset based on a trained model

        Exactly one of `checkpoint` or `network` must be provided.

        Parameters
        ----------
        checkpoint : str, optional
            filepath of the checkpoint of the network to load, alternatively you can directly use a loaded `network`.
            The model parameters are read from the sidecar file `<checkpoint>.json`.
        text_postproc : TextProcessor, optional
            text processor to be applied on the predicted sentence for the final output.
            If None and a checkpoint is loaded, the text processor is created from the checkpoint.
        data_preproc : DataProcessor, optional
            data processor (must be the same as of the trained model) to be applied to the input image.
            If None and a checkpoint is loaded, the data processor is created from the checkpoint.
        codec : Codec, optional
            Codec of the deep net to use for decoding. This parameter is only required if a custom codec is used,
            or a `network` has been provided instead of a `checkpoint`
        network : ModelInterface, optional
            DNN instance to used. Alternatively you can provide a `checkpoint` to load a network.
        batch_size : int, optional
            Batch size to use for prediction (only used when the network is created from a `checkpoint`)
        processes : int, optional
            The number of processes to use for prediction

        Raises
        ------
        ValueError
            If both or neither of `checkpoint` and `network` are provided, or if a
            `network` is provided without a `codec`.
        """
        # Validate the mutually exclusive model sources up front, before any
        # file is opened or any backend is created.
        if checkpoint and network is not None:
            raise ValueError(
                "Either a checkpoint or a network can be provided, not both")
        if not checkpoint and network is None:
            raise ValueError(
                "Either a checkpoint or an existing network must be provided")
        if not checkpoint and codec is None:
            raise ValueError(
                "A codec is required if preloaded network is used.")

        self.network = network
        self.checkpoint = checkpoint
        self.processes = processes

        if checkpoint:
            # Read the sidecar JSON as UTF-8 explicitly: the stored charset may
            # contain non-ASCII characters and must not depend on the locale.
            with open(checkpoint + '.json', 'r', encoding='utf-8') as f:
                checkpoint_params = json_format.Parse(f.read(),
                                                      CheckpointParams())
            self.model_params = checkpoint_params.model

            self.network_params = self.model_params.network
            backend = create_backend_from_proto(self.network_params,
                                                restore=self.checkpoint,
                                                processes=processes)
            self.network = backend.create_net(restore=self.checkpoint,
                                              weights=None,
                                              graph_type="predict",
                                              batch_size=batch_size)
            # Explicitly passed processors / codec take precedence over the
            # ones stored in the checkpoint.
            self.text_postproc = (
                text_postproc if text_postproc is not None
                else text_processor_from_proto(
                    self.model_params.text_postprocessor, "post"))
            self.data_preproc = (
                data_preproc if data_preproc is not None
                else data_processor_from_proto(
                    self.model_params.data_preprocessor))
            self.codec = codec if codec is not None else Codec(
                self.model_params.codec.charset)
        else:
            # Preloaded network: there are no stored model parameters, so the
            # caller-provided codec and processors are used as-is.
            self.model_params = None
            self.network_params = network.network_proto
            self.text_postproc = text_postproc
            self.data_preproc = data_preproc
            self.codec = codec

        self.out_to_in_trans = OutputToInputTransformer(
            self.data_preproc, self.network)