def __init__(
    self,
    checkpoint=None,
    text_postproc=None,
    data_preproc=None,
    codec=None,
    network=None,
    batch_size=1,
    processes=1,
    auto_update_checkpoints=True,
    with_gt=False,
    ctc_decoder_params=None,
):
    """Predict a dataset based on a trained model.

    Exactly one of ``checkpoint`` or ``network`` must be provided.

    Parameters
    ----------
    checkpoint : str, optional
        filepath of the checkpoint of the network to load, alternatively you can
        directly use a loaded `network`
    text_postproc : TextProcessor, optional
        text processor to be applied on the predicted sentence for the final output.
        If loaded from a checkpoint the text processor will be loaded from it.
    data_preproc : DataProcessor, optional
        data processor (must be the same as of the trained model) to be applied to
        the input image. If loaded from a checkpoint the data processor will be
        loaded from it.
    codec : Codec, optional
        Codec of the deep net to use for decoding. This parameter is only required if
        a custom codec is used, or a `network` has been provided instead of a
        `checkpoint`
    network : ModelInterface, optional
        DNN instance to use. Alternatively you can provide a `checkpoint` to load a
        network.
    batch_size : int, optional
        Batch size to use for prediction. Only used when loading from a `checkpoint`.
    processes : int, optional
        The number of processes to use for prediction
    auto_update_checkpoints : bool, optional
        Update old models automatically (this will change the checkpoint files)
    with_gt : bool, optional
        The prediction will also output the ground truth if available else None
    ctc_decoder_params : optional
        Parameters of the ctc decoder. Only used when loading from a `checkpoint`.

    Raises
    ------
    ValueError
        If both or neither of `checkpoint` and `network` are given, or if a
        `network` is given without a `codec`.
    """
    self.network = network
    self.checkpoint = checkpoint
    self.processes = processes
    self.auto_update_checkpoints = auto_update_checkpoints
    self.with_gt = with_gt

    if checkpoint:
        if network:
            raise ValueError(
                "Either a checkpoint or a network can be provided")

        ckpt = Checkpoint(checkpoint, auto_update=self.auto_update_checkpoints)
        # Keep the path as resolved by the Checkpoint wrapper, not the raw argument.
        self.checkpoint = ckpt.ckpt_path
        checkpoint_params = ckpt.checkpoint
        self.model_params = checkpoint_params.model
        # A caller-supplied codec takes precedence over the charset stored in the checkpoint.
        self.codec = codec if codec else Codec(self.model_params.codec.charset)
        self.network_params = self.model_params.network
        backend = create_backend_from_checkpoint(
            checkpoint_params=checkpoint_params, processes=processes)
        # Fall back to the processors stored in the checkpoint unless overridden.
        self.text_postproc = text_postproc if text_postproc else text_processor_from_proto(
            self.model_params.text_postprocessor, "post")
        self.data_preproc = data_preproc if data_preproc else data_processor_from_proto(
            self.model_params.data_preprocessor)
        self.network = backend.create_net(
            codec=self.codec,
            ctc_decoder_params=ctc_decoder_params,
            checkpoint_to_load=ckpt,
            graph_type="predict",
            batch_size=batch_size,
        )
    elif network:
        # A preloaded network carries no charset information, so the codec is
        # mandatory. Validate before touching any further state.
        if not codec:
            raise ValueError(
                "A codec is required if preloaded network is used.")
        self.codec = codec
        self.model_params = None
        self.network_params = network.network_proto
        self.text_postproc = text_postproc
        self.data_preproc = data_preproc
    else:
        raise ValueError(
            "Either a checkpoint or a existing backend must be provided")

    self.out_to_in_trans = OutputToInputTransformer(
        self.data_preproc, self.network)
def __init__(self, checkpoint=None, text_postproc=None, data_preproc=None, codec=None,
             network=None, batch_size=1, processes=1):
    """Predict a dataset based on a trained model.

    Exactly one of ``checkpoint`` or ``network`` must be provided.

    Parameters
    ----------
    checkpoint : str, optional
        filepath of the checkpoint of the network to load, alternatively you can
        directly use a loaded `network`. The network parameters are read from
        ``checkpoint + '.json'``.
    text_postproc : TextProcessor, optional
        text processor to be applied on the predicted sentence for the final output.
        If loaded from a checkpoint the text processor will be loaded from it.
    data_preproc : DataProcessor, optional
        data processor (must be the same as of the trained model) to be applied to
        the input image. If loaded from a checkpoint the data processor will be
        loaded from it.
    codec : Codec, optional
        Codec of the deep net to use for decoding. This parameter is only required if
        a custom codec is used, or a `network` has been provided instead of a
        `checkpoint`
    network : ModelInterface, optional
        DNN instance to use. Alternatively you can provide a `checkpoint` to load a
        network.
    batch_size : int, optional
        Batch size to use for prediction. Only used when loading from a `checkpoint`.
    processes : int, optional
        The number of processes to use for prediction

    Raises
    ------
    ValueError
        If both or neither of `checkpoint` and `network` are given, or if a
        `network` is given without a `codec`.
    """
    self.network = network
    self.checkpoint = checkpoint
    self.processes = processes

    if checkpoint:
        if network:
            raise ValueError(
                "Either a checkpoint or a network can be provided")

        # Decode explicitly as UTF-8 so reading the parameter file does not depend
        # on the platform's default locale encoding.
        with open(checkpoint + '.json', 'r', encoding='utf-8') as f:
            checkpoint_params = json_format.Parse(f.read(), CheckpointParams())
        self.model_params = checkpoint_params.model
        self.network_params = self.model_params.network
        backend = create_backend_from_proto(self.network_params,
                                            restore=self.checkpoint,
                                            processes=processes)
        self.network = backend.create_net(restore=self.checkpoint, weights=None,
                                          graph_type="predict", batch_size=batch_size)
        # Fall back to the processors stored in the checkpoint unless overridden.
        self.text_postproc = text_postproc if text_postproc else text_processor_from_proto(
            self.model_params.text_postprocessor, "post")
        self.data_preproc = data_preproc if data_preproc else data_processor_from_proto(
            self.model_params.data_preprocessor)
    elif network:
        # A preloaded network carries no charset information, so the codec is
        # mandatory. Validate before touching any further state.
        if not codec:
            raise ValueError(
                "A codec is required if preloaded network is used.")
        self.model_params = None
        self.network_params = network.network_proto
        self.text_postproc = text_postproc
        self.data_preproc = data_preproc
    else:
        raise ValueError(
            "Either a checkpoint or a existing backend must be provided")

    # A caller-supplied codec takes precedence; otherwise use the checkpoint's charset
    # (model_params is only None in the network branch, where codec is guaranteed).
    self.codec = codec if codec else Codec(self.model_params.codec.charset)
    self.out_to_in_trans = OutputToInputTransformer(
        self.data_preproc, self.network)