def main():
    """Instantiate a ``Checkpoint`` for every path matched by ``--checkpoints``.

    The file extension of each matched path is stripped before the path is
    handed to ``Checkpoint``; the ``--dry_run`` flag is forwarded unchanged.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--checkpoints', nargs='+', type=str, required=True)
    arg_parser.add_argument('--dry_run', action='store_true')
    cli_args = arg_parser.parse_args()

    for matched_path in tqdm(glob_all(cli_args.checkpoints)):
        base_path, _extension = os.path.splitext(matched_path)
        Checkpoint(base_path, dry_run=cli_args.dry_run)
def __init__(self, checkpoint=None, text_postproc=None, data_preproc=None, codec=None, network=None,
             batch_size=1, processes=1,
             auto_update_checkpoints=True,
             with_gt=False,
             ):
    """ Predicting a dataset based on a trained model

    Exactly one of `checkpoint` or `network` must be given.

    Parameters
    ----------
    checkpoint : str, optional
        filepath of the checkpoint of the network to load, alternatively you can directly use a loaded `network`
    text_postproc : TextProcessor, optional
        text processor to be applied on the predicted sentence for the final output.
        If loaded from a checkpoint and not given, the text processor will be loaded from it.
    data_preproc : DataProcessor, optional
        data processor (must be the same as of the trained model) to be applied to the input image.
        If loaded from a checkpoint and not given, the data processor will be loaded from it.
    codec : Codec, optional
        Codec of the deep net to use for decoding. This parameter is only required if a custom codec is used,
        or a `network` has been provided instead of a `checkpoint`
    network : ModelInterface, optional
        DNN instance to use. Alternatively you can provide a `checkpoint` to load a network.
    batch_size : int, optional
        Batch size to use for prediction (only used when the network is built from a `checkpoint`)
    processes : int, optional
        The number of processes to use for prediction. Stored as `self.processes`; when loading
        from a `checkpoint` it is also forwarded to the backend.
    auto_update_checkpoints : bool, optional
        Update old models automatically (this will change the checkpoint files)
    with_gt : bool, optional
        The prediction will also output the ground truth if available else None

    Raises
    ------
    Exception
        If both `checkpoint` and `network` are given, if neither is given, or if a
        `network` is given without a `codec`.
    """
    self.network = network
    self.checkpoint = checkpoint
    self.processes = processes
    self.auto_update_checkpoints = auto_update_checkpoints
    self.with_gt = with_gt

    if checkpoint:
        if network:
            raise Exception("Either a checkpoint or a network can be provided")

        ckpt = Checkpoint(checkpoint, auto_update=self.auto_update_checkpoints)
        # from here on, use the path as reported by the loaded Checkpoint object
        self.checkpoint = ckpt.ckpt_path
        checkpoint_params = ckpt.checkpoint
        self.model_params = checkpoint_params.model
        # an explicitly passed codec takes precedence over the charset stored in the checkpoint
        self.codec = codec if codec else Codec(self.model_params.codec.charset)

        self.network_params = self.model_params.network
        backend = create_backend_from_proto(self.network_params, restore=self.checkpoint, processes=processes)
        # explicitly passed processors take precedence over the ones stored in the checkpoint
        self.text_postproc = text_postproc if text_postproc else text_processor_from_proto(
            self.model_params.text_postprocessor, "post")
        self.data_preproc = data_preproc if data_preproc else data_processor_from_proto(
            self.model_params.data_preprocessor)
        self.network = backend.create_net(
            dataset=None, codec=self.codec, restore=self.checkpoint, weights=None, graph_type="predict",
            batch_size=batch_size)
    elif network:
        # fail fast: a preloaded network cannot be decoded without a codec
        if not codec:
            raise Exception("A codec is required if preloaded network is used.")

        self.codec = codec
        self.model_params = None
        self.network_params = network.network_proto
        self.text_postproc = text_postproc
        self.data_preproc = data_preproc
    else:
        raise Exception("Either a checkpoint or a existing backend must be provided")

    self.out_to_in_trans = OutputToInputTransformer(self.data_preproc, self.network)
def train(self, auto_compute_codec=False, progress_bar=False):
    """ Launch the training

    Parameters
    ----------
    auto_compute_codec : bool
        Compute the codec automatically based on the provided ground truth.
        Else provide a codec using a whitelist (faster).
    progress_bar : bool
        Show or hide any progress bar

    Raises
    ------
    Exception
        If `self.weights` is set and the model to restore has a different line height
        than the one requested by `checkpoint_params.model.line_height`.
    """
    checkpoint_params = self.checkpoint_params

    # NOTE(review): if _run_train measures elapsed time as `time.time() - train_start_time`,
    # adding total_time here subtracts previously spent time on resume; confirm the sign.
    train_start_time = time.time() + self.checkpoint_params.total_time

    # load training dataset
    if self.preload_training:
        self.dataset.preload(processes=checkpoint_params.processes, progress_bar=progress_bar)

    # load validation dataset
    if self.validation_dataset and self.preload_validation:
        self.validation_dataset.preload(processes=checkpoint_params.processes, progress_bar=progress_bar)

    # compute the codec
    if self.codec:
        codec = self.codec
    elif len(self.codec_whitelist) == 0 or auto_compute_codec:
        codec = Codec.from_input_dataset([self.dataset, self.validation_dataset],
                                         whitelist=self.codec_whitelist, progress_bar=progress_bar)
    else:
        codec = Codec.from_texts([], whitelist=self.codec_whitelist)

    # create backend
    network_params = checkpoint_params.model.network
    network_params.features = checkpoint_params.model.line_height

    if self.weights:
        # if we load the weights, take care of codec changes as-well
        ckpt = Checkpoint(self.weights + '.json', auto_update=self.auto_update_checkpoints)
        restore_checkpoint_params = ckpt.checkpoint
        restore_model_params = restore_checkpoint_params.model

        # checks: the restored model must match the requested line height. Comparing against
        # checkpoint_params.model.line_height would always pass, because network_params.features
        # was assigned from it just above.
        if restore_model_params.line_height != network_params.features:
            raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                restore_model_params.line_height, network_params.features
            ))

        # create codec of the same type
        restore_codec = codec.__class__(restore_model_params.codec.charset)
        # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
        codec_changes = restore_codec.align(codec)
        codec = restore_codec
        print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
        # The actual weight/bias matrix will be changed after loading the old weights.
        # Both entries are sequences, so test their lengths (comparing a sequence to 0 is always False).
        if all(len(c) == 0 for c in codec_changes):
            codec_changes = None  # No codec changes
    else:
        codec_changes = None

    # store the new codec; the class count is taken from the final (possibly restored) codec
    network_params.classes = len(codec)
    checkpoint_params.model.codec.charset[:] = codec.charset
    print("CODEC: {}".format(codec.charset))

    backend = create_backend_from_proto(
        network_params,
        weights=self.weights,
    )
    train_net = backend.create_net(self.dataset, codec, restore=None, weights=self.weights, graph_type="train",
                                   batch_size=checkpoint_params.batch_size)
    test_net = backend.create_net(self.validation_dataset, codec, restore=None, weights=self.weights,
                                  graph_type="test", batch_size=checkpoint_params.batch_size)
    if codec_changes:
        # only required on one net, since the other shares the same variables
        train_net.realign_model_labels(*codec_changes)

    train_net.prepare()
    test_net.prepare()

    if checkpoint_params.current_stage == 0:
        self._run_train(train_net, test_net, codec, train_start_time, progress_bar)

    if checkpoint_params.data_aug_retrain_on_original and self.dataset.data_augmenter and self.dataset.data_augmentation_amount > 0:
        print("Starting training on original data only")
        if checkpoint_params.current_stage == 0:
            # entering the second stage: reset iteration and early-stopping bookkeeping
            checkpoint_params.current_stage = 1
            checkpoint_params.iter = 0
            checkpoint_params.early_stopping_best_at_iter = 0
            checkpoint_params.early_stopping_best_cur_nbest = 0
            checkpoint_params.early_stopping_best_accuracy = 0

        self.dataset.generate_only_non_augmented = True  # this is the important line!
        train_net.prepare()
        test_net.prepare()
        self._run_train(train_net, test_net, codec, train_start_time, progress_bar)

    train_net.prepare()  # reset the state
    test_net.prepare()  # to prevent blocking of tensorflow on shutdown
def train(self, progress_bar=False):
    """ Launch the training

    Parameters
    ----------
    progress_bar : bool
        Show or hide any progress bar

    Raises
    ------
    Exception
        If the training dataset is empty, if a validation dataset is given but empty, or if
        `self.weights` is set and the model to restore has a different line height than the
        one requested by `checkpoint_params.model.line_height`.
    """
    checkpoint_params = self.checkpoint_params

    # NOTE(review): if _run_train measures elapsed time as `time.time() - train_start_time`,
    # adding total_time here subtracts previously spent time on resume; confirm the sign.
    train_start_time = time.time() + self.checkpoint_params.total_time

    self.dataset.load_samples(processes=1, progress_bar=progress_bar)
    datas, txts = self.dataset.train_samples(skip_empty=checkpoint_params.skip_invalid_gt)
    if len(datas) == 0:
        raise Exception("Empty dataset is not allowed. Check if the data is at the correct location")

    if self.validation_dataset:
        self.validation_dataset.load_samples(processes=1, progress_bar=progress_bar)
        validation_datas, validation_txts = self.validation_dataset.train_samples(
            skip_empty=checkpoint_params.skip_invalid_gt)
        if len(validation_datas) == 0:
            raise Exception("Validation dataset is empty. Provide valid validation data for early stopping.")
    else:
        validation_datas, validation_txts = [], []

    # preprocessing steps
    texts = self.txt_preproc.apply(txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
    # data_preproc yields (data, params) pairs; only the data is needed for training
    datas, _params = [list(a) for a in zip(*self.data_preproc.apply(
        datas, processes=checkpoint_params.processes, progress_bar=progress_bar))]
    validation_txts = self.txt_preproc.apply(
        validation_txts, processes=checkpoint_params.processes, progress_bar=progress_bar)
    # NOTE(review): these preprocessed pairs are passed to _run_train, while test_net below
    # receives the un-preprocessed validation_datas; confirm this is intended.
    validation_data_params = self.data_preproc.apply(
        validation_datas, processes=checkpoint_params.processes, progress_bar=progress_bar)

    # compute the codec
    codec = self.codec if self.codec else Codec.from_texts(texts, whitelist=self.codec_whitelist)

    # store original data in case data augmentation is used with a second step
    original_texts = texts
    original_datas = datas

    # data augmentation on preprocessed data
    if self.data_augmenter:
        datas, texts = self.data_augmenter.augment_datas(datas, texts, n_augmentations=self.n_augmentations,
                                                         processes=checkpoint_params.processes,
                                                         progress_bar=progress_bar)

        # TODO: validation data augmentation (e.g. self.data_augmenter.augment_datas on the
        # validation data with n_augmentations=0)

    # create backend
    network_params = checkpoint_params.model.network
    network_params.features = checkpoint_params.model.line_height

    if self.weights:
        # if we load the weights, take care of codec changes as-well
        ckpt = Checkpoint(self.weights + '.json', auto_update=self.auto_update_checkpoints)
        restore_checkpoint_params = ckpt.checkpoint
        restore_model_params = restore_checkpoint_params.model

        # checks: the restored model must match the requested line height. Comparing against
        # checkpoint_params.model.line_height would always pass, because network_params.features
        # was assigned from it just above.
        if restore_model_params.line_height != network_params.features:
            raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                restore_model_params.line_height, network_params.features
            ))

        # create codec of the same type
        restore_codec = codec.__class__(restore_model_params.codec.charset)
        # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
        codec_changes = restore_codec.align(codec)
        codec = restore_codec
        print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
        # The actual weight/bias matrix will be changed after loading the old weights.
        # Both entries are sequences, so test their lengths (comparing a sequence to 0 is always False).
        if all(len(c) == 0 for c in codec_changes):
            codec_changes = None  # No codec changes
    else:
        codec_changes = None

    # store the new codec; the class count is taken from the final (possibly restored) codec
    network_params.classes = len(codec)
    checkpoint_params.model.codec.charset[:] = codec.charset
    print("CODEC: {}".format(codec.charset))

    # compute the labels with (new/current) codec
    labels = [codec.encode(txt) for txt in texts]

    backend = create_backend_from_proto(
        network_params,
        weights=self.weights,
    )
    train_net = backend.create_net(restore=None, weights=self.weights, graph_type="train",
                                   batch_size=checkpoint_params.batch_size)
    test_net = backend.create_net(restore=None, weights=self.weights, graph_type="test",
                                  batch_size=checkpoint_params.batch_size)
    train_net.set_data(datas, labels)
    test_net.set_data(validation_datas, validation_txts)
    if codec_changes:
        # only required on one net, since the other shares the same variables
        train_net.realign_model_labels(*codec_changes)

    train_net.prepare()
    test_net.prepare()

    if checkpoint_params.current_stage == 0:
        self._run_train(train_net, test_net, codec, validation_data_params, train_start_time, progress_bar)

    if checkpoint_params.data_aug_retrain_on_original and self.data_augmenter and self.n_augmentations > 0:
        print("Starting training on original data only")
        if checkpoint_params.current_stage == 0:
            # entering the second stage: reset iteration and early-stopping bookkeeping
            checkpoint_params.current_stage = 1
            checkpoint_params.iter = 0
            checkpoint_params.early_stopping_best_at_iter = 0
            checkpoint_params.early_stopping_best_cur_nbest = 0
            checkpoint_params.early_stopping_best_accuracy = 0

        # retrain on the non-augmented (but preprocessed) samples
        train_net.set_data(original_datas, [codec.encode(txt) for txt in original_texts])
        test_net.set_data(validation_datas, validation_txts)
        train_net.prepare()
        test_net.prepare()
        self._run_train(train_net, test_net, codec, validation_data_params, train_start_time, progress_bar)
def train(self, auto_compute_codec=False, progress_bar=False, training_callback=None):
    """ Launch the training

    Parameters
    ----------
    auto_compute_codec : bool
        Compute the codec automatically based on the provided ground truth.
        Else provide a codec using a whitelist (faster).
    progress_bar : bool
        Show or hide any progress bar
    training_callback : TrainingCallback, optional
        Callback for the training process (e.g., for displaying the current cer, loss in the console).
        If omitted, a new ``ConsoleTrainingCallback`` is created for this call.

    Raises
    ------
    Exception
        If `self.weights` is set and the model to restore has a different line height
        than the one requested by `checkpoint_params.model.line_height`.
    """
    if training_callback is None:
        # build the default per call rather than sharing one instance created at definition time
        training_callback = ConsoleTrainingCallback()

    with ExitStackWithPop() as exit_stack:
        checkpoint_params = self.checkpoint_params

        # NOTE(review): if _run_train measures elapsed time as `time.time() - train_start_time`,
        # adding total_time here subtracts previously spent time on resume; confirm the sign.
        train_start_time = time.time() + self.checkpoint_params.total_time

        exit_stack.enter_context(self.dataset)
        if self.validation_dataset:
            exit_stack.enter_context(self.validation_dataset)

        # load training dataset: swap the managed context over to the raw-input replacement
        if self.preload_training:
            new_dataset = self.dataset.to_raw_input_dataset(processes=checkpoint_params.processes,
                                                            progress_bar=progress_bar)
            exit_stack.pop(self.dataset)
            self.dataset = new_dataset
            exit_stack.enter_context(self.dataset)

        # load validation dataset
        if self.validation_dataset and self.preload_validation:
            new_dataset = self.validation_dataset.to_raw_input_dataset(processes=checkpoint_params.processes,
                                                                       progress_bar=progress_bar)
            exit_stack.pop(self.validation_dataset)
            self.validation_dataset = new_dataset
            exit_stack.enter_context(self.validation_dataset)

        # compute the codec
        if self.codec:
            codec = self.codec
        elif len(self.codec_whitelist) == 0 or auto_compute_codec:
            codec = Codec.from_input_dataset([self.dataset, self.validation_dataset],
                                             whitelist=self.codec_whitelist, progress_bar=progress_bar)
        else:
            codec = Codec.from_texts([], whitelist=self.codec_whitelist)

        # create backend
        network_params = checkpoint_params.model.network
        network_params.features = checkpoint_params.model.line_height

        if self.weights:
            # if we load the weights, take care of codec changes as-well
            ckpt = Checkpoint(self.weights + '.json', auto_update=self.auto_update_checkpoints)
            restore_checkpoint_params = ckpt.checkpoint
            restore_model_params = restore_checkpoint_params.model

            # checks: the restored model must match the requested line height. Comparing against
            # checkpoint_params.model.line_height would always pass, because network_params.features
            # was assigned from it just above.
            if restore_model_params.line_height != network_params.features:
                raise Exception("The model to restore has a line height of {} but a line height of {} is requested".format(
                    restore_model_params.line_height, network_params.features
                ))

            # create codec of the same type
            restore_codec = codec.__class__(restore_model_params.codec.charset)
            # the codec changes as tuple (deletions/insertions), and the new codec is the changed old one
            codec_changes = restore_codec.align(codec, shrink=not self.keep_loaded_codec)
            codec = restore_codec
            print("Codec changes: {} deletions, {} appends".format(len(codec_changes[0]), len(codec_changes[1])))
            # The actual weight/bias matrix will be changed after loading the old weights.
            # Both entries are sequences, so test their lengths (comparing a sequence to 0 is always False).
            if all(len(c) == 0 for c in codec_changes):
                codec_changes = None  # No codec changes
        else:
            codec_changes = None

        # store the new codec
        network_params.classes = len(codec)
        checkpoint_params.model.codec.charset[:] = codec.charset
        print("CODEC: {}".format(codec.charset))

        backend = create_backend_from_checkpoint(
            checkpoint_params=checkpoint_params,
            processes=checkpoint_params.processes,
        )
        train_net = backend.create_net(codec, graph_type="train",
                                       checkpoint_to_load=Checkpoint(self.weights) if self.weights else None,
                                       batch_size=checkpoint_params.batch_size, codec_changes=codec_changes)

        if checkpoint_params.current_stage == 0:
            self._run_train(train_net, train_start_time, progress_bar, self.dataset, self.validation_dataset,
                            training_callback)

        if checkpoint_params.data_aug_retrain_on_original and self.data_augmenter and self.n_augmentations != 0:
            print("Starting training on original data only")
            if checkpoint_params.current_stage == 0:
                # entering the second stage: reset iteration and early-stopping bookkeeping
                checkpoint_params.current_stage = 1
                checkpoint_params.iter = 0
                checkpoint_params.early_stopping_best_at_iter = 0
                checkpoint_params.early_stopping_best_cur_nbest = 0
                checkpoint_params.early_stopping_best_accuracy = 0

            self.dataset.generate_only_non_augmented = True  # this is the important line!
            self._run_train(train_net, train_start_time, progress_bar, self.dataset, self.validation_dataset,
                            training_callback)