class DataParams(DataBaseParams):
    skip_invalid_gt: bool = True
    input_channels: int = 1
    downscale_factor: int = field(default=-1, metadata=pai_meta(mode="ignore"))  # Set based on model
    line_height: int = field(default=48, metadata=pai_meta(help="The line height"))
    ensemble: int = field(default=0, metadata=pai_meta(mode="ignore"))  # Set based on model
    codec: Optional[Codec] = field(default=None, metadata=pai_meta(mode="ignore"))

    @staticmethod
    def cls():
        from calamari_ocr.ocr.dataset.data import Data

        return Data

    def __post_init__(self):
        from calamari_ocr.ocr.dataset.imageprocessors.center_normalizer import (
            CenterNormalizerProcessorParams,
        )
        from calamari_ocr.ocr.dataset.imageprocessors.scale_to_height_processor import (
            ScaleToHeightProcessorParams,
        )

        for p in self.post_proc.processors + self.pre_proc.processors:
            if isinstance(p, ScaleToHeightProcessorParams):
                p.height = self.line_height
            elif isinstance(p, CenterNormalizerProcessorParams):
                p.line_height = self.line_height
class DilatedBlockLayerParams(LayerParams):
    @classmethod
    def name_prefix(cls) -> str:
        return "dilated_block"

    @classmethod
    def cls(cls) -> Type["Layer"]:
        return DilatedBlockLayer

    def downscale(self, size: IntVec2D) -> IntVec2D:
        return IntVec2D(
            (size.x + self.strides.x - 1) // self.strides.x,
            (size.y + self.strides.y - 1) // self.strides.y,
        )

    def downscale_factor(self, factor: IntVec2D) -> IntVec2D:
        return IntVec2D(factor.x * self.strides.x, factor.y * self.strides.y)

    filters: int = 40
    kernel_size: IntVec2D = field(default_factory=lambda: IntVec2D(3, 3), metadata=pai_meta(tuple_like=True))
    strides: IntVec2D = field(default_factory=lambda: IntVec2D(1, 1), metadata=pai_meta(tuple_like=True))
    padding: str = "same"
    activation: str = "relu"
    dilated_depth: int = 2
class TransposedConv2DLayerParams(LayerParams):
    @classmethod
    def name_prefix(cls) -> str:
        return "tconv2d"

    @classmethod
    def cls(cls) -> Type["Layer"]:
        return TransposedConv2DLayer

    def downscale(self, size: IntVec2D) -> IntVec2D:
        # transposed convolutions upsample, so the size grows by the strides
        return IntVec2D(size.x * self.strides.x, size.y * self.strides.y)

    def downscale_factor(self, size: IntVec2D) -> IntVec2D:
        return IntVec2D(
            (size.x + self.strides.x - 1) // self.strides.x,
            (size.y + self.strides.y - 1) // self.strides.y,
        )

    filters: int = 40
    kernel_size: IntVec2D = field(default_factory=lambda: IntVec2D(3, 3), metadata=pai_meta(tuple_like=True))
    strides: IntVec2D = field(default_factory=lambda: IntVec2D(2, 2), metadata=pai_meta(tuple_like=True))
    padding: str = "same"
    activation: str = "relu"
class OptimizerParams(ABC):
    """General parameters of an Optimizer"""

    @abstractmethod
    def create(self) -> Tuple[Type["tf.keras.optimizers.Optimizer"], Dict[str, Any]]:
        raise NotImplementedError

    clip_norm: Optional[float] = field(
        default=None,
        metadata=pai_meta(help="float or None. If set, clips gradients to a maximum norm."),
    )
    clip_value: Optional[float] = field(
        default=None,
        metadata=pai_meta(help="float or None. If set, clips gradients to a maximum value."),
    )
    global_clip_norm: Optional[float] = field(
        default=None,
        metadata=pai_meta(
            help="float or None. If set, the gradient of all weights is clipped so that "
            "their global norm is no higher than this value."
        ),
    )

    def _clip_grad_args(self):
        return {
            "clipnorm": self.clip_norm,
            "clipvalue": self.clip_value,
            "global_clipnorm": self.global_clip_norm,
        }
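# Usage sketch (hypothetical, not part of the library): a concrete subclass
# showing how create() and _clip_grad_args() are meant to combine. Keras
# optimizers accept clipnorm/clipvalue/global_clipnorm as keyword arguments,
# so the non-None clipping values can simply be merged in.
@dataclass
class SketchAdamOptimizerParams(OptimizerParams):
    learning_rate: float = 0.001

    def create(self):
        import tensorflow as tf

        clip = {k: v for k, v in self._clip_grad_args().items() if v is not None}
        return tf.keras.optimizers.Adam, {"learning_rate": self.learning_rate, **clip}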
class ListsFileGeneratorParams(DataGeneratorParams):
    """Parameters for the ListsFileDataGenerator"""

    @staticmethod
    def cls():
        from tfaip.scenario.listfile.datagenerator import (
            ListsFileDataGenerator,
        )  # pylint: disable=import-outside-toplevel

        return ListsFileDataGenerator

    lists: Optional[List[str]] = field(default_factory=list, metadata=pai_meta(help="Training list files."))
    list_ratios: Optional[List[float]] = field(
        default=None,
        metadata=pai_meta(help="Ratios of picking list files. Must be supported by the scenario."),
    )
    ignore_prefix: Optional[str] = field(
        default="#",
        metadata=pai_meta(
            help="Ignore entries in the files that start with this prefix. Default is '#' for commenting out samples."
        ),
    )

    def __post_init__(self):
        if self.lists:
            if not self.list_ratios:
                self.list_ratios = [1.0] * len(self.lists)
            elif len(self.list_ratios) != len(self.lists):
                raise ValueError(
                    f"Length of list_ratios must be equal to the number of lists. "
                    f"Got {len(self.list_ratios)} != {len(self.lists)}"
                )
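# Usage sketch: two training list files picked at a 3:1 ratio (hypothetical
# file names). If list_ratios is omitted, __post_init__ fills it with 1.0
# per list; mismatched lengths raise a ValueError.
params = ListsFileGeneratorParams(
    lists=["train_a.lst", "train_b.lst"],
    list_ratios=[3.0, 1.0],
)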
class MaxPool2DLayerParams(LayerParams):
    @classmethod
    def name_prefix(cls) -> str:
        return "maxpool2d"

    @classmethod
    def cls(cls) -> Type["Layer"]:
        return MaxPool2DLayer

    def downscale(self, size: IntVec2D) -> IntVec2D:
        strides = self.real_strides()
        return IntVec2D((size.x + strides.x - 1) // strides.x, (size.y + strides.y - 1) // strides.y)

    def downscale_factor(self, factor: IntVec2D) -> IntVec2D:
        strides = self.real_strides()
        return IntVec2D(factor.x * strides.x, factor.y * strides.y)

    def real_strides(self) -> IntVec2D:
        if self.strides is None:
            return self.pool_size
        return IntVec2D(
            self.strides.x if self.strides.x >= 0 else self.pool_size.x,
            self.strides.y if self.strides.y >= 0 else self.pool_size.y,
        )

    pool_size: IntVec2D = field(default_factory=lambda: IntVec2D(2, 2), metadata=pai_meta(tuple_like=True))
    strides: IntVec2D = field(default_factory=lambda: IntVec2D(-1, -1), metadata=pai_meta(tuple_like=True))
    padding: str = "same"
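# Sketch of the stride fallback: the default strides of (-1, -1) mean
# "use the pool size", so a default layer halves both dimensions
# (assuming IntVec2D provides dataclass equality):
p = MaxPool2DLayerParams()  # pool_size=(2, 2), strides=(-1, -1)
assert p.real_strides() == IntVec2D(2, 2)
assert p.downscale(IntVec2D(100, 48)) == IntVec2D(50, 24)  # ceil division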
class Hdf5(CalamariDataGeneratorParams):
    files: List[str] = field(default_factory=list, metadata=pai_meta(required=True))
    pred_extension: str = field(
        default=".pred.h5",
        metadata=pai_meta(help="Default extension of the prediction files"),
    )

    def __len__(self):
        return len(self.files)

    def to_prediction(self):
        self.files = sorted(glob_all(self.files))
        pred = deepcopy(self)
        pred.files = [split_all_ext(f)[0] + self.pred_extension for f in self.files]
        return pred

    @staticmethod
    def cls():
        return Hdf5Generator

    def prepare_for_mode(self, mode: PipelineMode):
        self.files = sorted(glob_all(self.files))
class DeviceConfigParams:
    """Configuration of the devices (GPUs).

    By default, no GPUs are used. Specify which GPUs to use either by
    setting `gpus` or via CUDA_VISIBLE_DEVICES.
    """

    gpus: Optional[List[int]] = field(default=None, metadata=pai_meta(help="List of the GPUs to use."))
    gpu_auto_tune: bool = field(default=False, metadata=pai_meta(help="Enable auto-tuning of the GPUs."))
    gpu_memory: Optional[int] = field(
        default=None,
        metadata=pai_meta(help="Limit the per-GPU memory in MB. By default the memory will grow automatically."),
    )
    soft_device_placement: bool = field(
        default=True,
        metadata=pai_meta(help="Whether soft device placement is enabled."),
    )
    dist_strategy: DistributionStrategy = field(
        default=DistributionStrategy.DEFAULT,
        metadata=pai_meta(help="Distribution strategy for multi-GPU training, select 'mirror' or 'central_storage'."),
    )
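# Example configuration (sketch): train on GPUs 0 and 1 with a fixed
# per-GPU memory limit instead of automatic growth.
devices = DeviceConfigParams(
    gpus=[0, 1],
    gpu_memory=4096,  # MB per GPU; None lets the memory grow automatically
)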
class ExponentialDecayParams(LearningRateParams):
    """Exponential decay parameters"""

    @staticmethod
    def cls():
        from tfaip.trainer.scheduler.exponential_decay import (
            ExponentialDecaySchedule,
        )  # pylint: disable=import-outside-toplevel

        return ExponentialDecaySchedule

    learning_circle: int = field(
        default=3,
        metadata=pai_meta(help="(type dependent) The number of epochs with a flat, constant learning rate."),
    )
    lr_decay_rate: float = field(
        default=0.99,
        metadata=pai_meta(help="(type dependent) The exponential decay factor."),
    )
    decay_min_fraction: float = field(
        default=0.0,
        metadata=pai_meta(
            help="(type dependent) Minimal fraction of the initial learning rate that the exponential decay can drop to."
        ),
    )
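# Hedged sketch of the schedule these parameters describe (the actual
# implementation lives in ExponentialDecaySchedule): flat for learning_circle
# epochs, then exponential decay, floored at a fraction of the initial rate.
def sketch_lr(epoch: int, lr0: float, p: ExponentialDecayParams) -> float:
    if epoch < p.learning_circle:
        return lr0
    decayed = lr0 * p.lr_decay_rate ** (epoch - p.learning_circle)
    return max(decayed, p.decay_min_fraction * lr0)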
class CalamariDataGeneratorParams(DataGeneratorParams, ImageLoaderParams, ABC):
    skip_invalid: bool = True
    non_existing_as_empty: bool = False
    n_folds: int = field(default=-1, metadata=pai_meta(mode="ignore"))
    preload: bool = field(
        default=True,
        metadata=pai_meta(
            help="Preload all data into RAM before training. If disabled, the data is loaded on the fly, "
            "which is slower but might be required for limited RAM or large datasets."
        ),
    )

    def __len__(self):
        raise NotImplementedError

    @abstractmethod
    def to_prediction(self):
        raise NotImplementedError

    def select(self, indices: List[int]):
        raise NotImplementedError

    def prepare_for_mode(self, mode: PipelineMode) -> None:
        pass

    def create(self, mode: PipelineMode) -> "CalamariDataGenerator":
        params = deepcopy(self)  # always work on a copy of the params
        params.prepare_for_mode(mode)
        gen: CalamariDataGenerator = self.cls()(mode, params)
        gen.post_init()
        return gen

    def image_loader(self) -> ImageLoader:
        return ImageLoader(self)
class CodecConstructionParams:
    keep_loaded: bool = field(
        default=True,
        metadata=pai_meta(help="Fully include the codec of the loaded model in the new codec."),
    )
    auto_compute: bool = field(
        default=True,
        metadata=pai_meta(help="Compute the codec automatically. See also include."),
    )
    include: List[str] = field(
        default_factory=list,
        metadata=pai_meta(
            help="Whitelist of characters that may not be removed when restoring a model. "
            "For large datasets you can use this to skip the automatic codec computation "
            "(see auto_compute)."
        ),
    )
    include_files: List[str] = field(
        default_factory=list,
        metadata=pai_meta(help="Whitelist of txt files whose characters may not be removed when restoring a model."),
    )
    resolved_include_chars: Set[str] = field(default_factory=set, metadata=pai_meta(mode="ignore"))

    def __post_init__(self):
        # parse the whitelist: a single entry is interpreted as a string of
        # characters, multiple entries as literal tokens
        if len(self.include) == 1:
            include = set(self.include[0])
        else:
            include = set(self.include)
        for f in glob_all(self.include_files):
            with open(f) as txt:
                include = include.union(txt.read())
        self.resolved_include_chars = include
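# Sketch of the whitelist parsing above: a single include entry is split
# into its characters, while multiple entries are kept as literal tokens.
assert CodecConstructionParams(include=["abc"]).resolved_include_chars == {"a", "b", "c"}
assert CodecConstructionParams(include=["ab", "c"]).resolved_include_chars == {"ab", "c"}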
class Abbyy(CalamariDataGeneratorParams):
    images: List[str] = field(default_factory=list, metadata=pai_meta(required=True))
    xml_files: List[str] = field(default_factory=list)
    gt_extension: str = field(
        default=".abbyy.xml",
        metadata=pai_meta(help="Default extension of the gt files (expected to exist in the same dir)"),
    )
    binary: bool = False
    pred_extension: str = field(
        default=".abbyy.pred.xml",
        metadata=pai_meta(help="Default extension of the prediction files"),
    )

    def __len__(self):
        return len(self.images)

    def select(self, indices: List[int]):
        if self.images:
            self.images = [self.images[i] for i in indices]
        if self.xml_files:
            self.xml_files = [self.xml_files[i] for i in indices]

    def to_prediction(self):
        pred = deepcopy(self)
        pred.xml_files = [split_all_ext(f)[0] + self.pred_extension for f in self.xml_files]
        return pred

    @staticmethod
    def cls():
        return AbbyyGenerator

    def prepare_for_mode(self, mode: PipelineMode):
        self.images = sorted(glob_all(self.images))
        self.xml_files = sorted(glob_all(self.xml_files))
        if not self.xml_files:
            self.xml_files = [split_all_ext(f)[0] + self.gt_extension for f in self.images]
        if not self.images:
            self.images = [None] * len(self.xml_files)
        if len(self.images) != len(self.xml_files):
            raise ValueError(
                f"Different number of image and xml files, {len(self.images)} != {len(self.xml_files)}"
            )
        for img_path, xml_path in zip(self.images, self.xml_files):
            if img_path and xml_path:
                img_bn, xml_bn = split_all_ext(img_path)[0], split_all_ext(xml_path)[0]
                if img_bn != xml_bn:
                    logger.warning(f"Filenames do not match:\n  image: {img_bn}\n  xml:   {xml_bn}")
class Base:
    i1: Optional[int] = None
    i2: Optional[int] = -1
    i3: int = 3
    sub1: Optional[Sub] = field(default_factory=Sub, metadata=pai_meta(choices=[Sub]))
    sub2: Optional[Sub] = field(default=None, metadata=pai_meta(choices=[Sub]))
class CalamariDefaultTrainerPipelineParams(
    TrainerPipelineParams[CalamariDataGeneratorParams, CalamariDataGeneratorParams]
):
    train: CalamariDataGeneratorParams = field(
        default_factory=FileDataParams,
        metadata=pai_meta(choices=DATA_GENERATOR_CHOICES, mode="flat"),
    )
    val: CalamariDataGeneratorParams = field(
        default_factory=FileDataParams,
        metadata=pai_meta(choices=DATA_GENERATOR_CHOICES, mode="flat"),
    )
class TextRegularizerProcessorParams(DataProcessorParams):
    # TODO: groups as enums
    replacement_groups: List[str] = field(
        default_factory=lambda: ["extended"],
        metadata=pai_meta(help="Text regularization to apply."),
    )
    replacements: Optional[List[Replacement]] = field(default=None, metadata=pai_meta(mode="ignore"))

    @staticmethod
    def cls() -> Type["TextProcessor"]:
        return TextRegularizerProcessor
class EvalArgs:
    gt: CalamariDataGeneratorParams = field(
        default_factory=FileDataParams,
        metadata=pai_meta(help="GT", mode="flat", choices=DATA_GENERATOR_CHOICES),
    )
    pred: Optional[CalamariDataGeneratorParams] = field(
        default=None,
        metadata=pai_meta(
            help="Optional prediction dataset",
            mode="flat",
            choices=DATA_GENERATOR_CHOICES,
        ),
    )
    n_confusions: int = field(
        default=10,
        metadata=pai_meta(
            help="Only print the n most common confusions. Defaults to 10, use -1 for all.",
            mode="flat",
        ),
    )
    n_worst_lines: int = field(
        default=0,
        metadata=pai_meta(help="Print the n worst recognized text lines with their errors", mode="flat"),
    )
    xlsx_output: Optional[str] = field(
        default=None,
        metadata=pai_meta(help="Optionally write an xlsx file with the evaluation results", mode="flat"),
    )
    non_existing_file_handling_mode: str = field(
        default="error",
        metadata=pai_meta(
            mode="flat",
            choices=["error", "skip", "empty"],
            help="How to handle non-existing .pred.txt files. Possible modes: skip, empty, error. "
            "'Skip' will simply skip the evaluation of that file (not counting it as an error). "
            "'Empty' will handle this file as if it were empty (fully checking for errors). "
            "'Error' will throw an exception if a file does not exist. This is the default behaviour.",
        ),
    )
    skip_empty_gt: bool = field(
        default=False,
        metadata=pai_meta(help="Ignore lines of the gt that are empty.", mode="flat"),
    )
    checkpoint: Optional[str] = field(
        default=None,
        metadata=pai_meta(
            help="Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)",
            mode="flat",
        ),
    )
    evaluator: EvaluatorParams = field(
        default_factory=EvaluatorParams,
        metadata=pai_meta(mode="flat", fix_dc=True),
    )
class WarmupDecayParams(LearningRateParams):
    """Cosine decay with warmup"""

    @staticmethod
    def cls():
        from tfaip.trainer.scheduler.warmup_decay import (
            WarmupDecaySchedule,
        )  # pylint: disable=import-outside-toplevel

        return WarmupDecaySchedule

    warmup_epochs: int = field(default=-1, metadata=pai_meta(help="Number of epochs for the linear increase"))
    warmup_steps: int = field(default=-1, metadata=pai_meta(help="Number of steps for the linear increase"))
class TrainerPipelineParams(
    TrainerPipelineParamsBase[TDataGeneratorTrain, TDataGeneratorVal],
    metaclass=TrainerPipelineParamsMeta,
):
    train: TDataGeneratorTrain = field(default_factory=DataGeneratorParams, metadata=pai_meta(mode="flat"))
    val: TDataGeneratorVal = field(default_factory=DataGeneratorParams, metadata=pai_meta(mode="flat"))

    def train_gen(self) -> TDataGeneratorTrain:
        return self.train

    def val_gen(self) -> TDataGeneratorVal:
        return self.val
class TrainerPipelines:
    train: DataPipelineParams = field(
        default_factory=lambda: DataPipelineParams(mode=PipelineMode.TRAINING),
        metadata=pai_meta(fix_dc=True, mode="flat"),
    )
    val: DataPipelineParams = field(
        default_factory=lambda: DataPipelineParams(mode=PipelineMode.EVALUATION),
        metadata=pai_meta(fix_dc=True, mode="flat"),
    )

    def __post_init__(self):
        self.train.mode = PipelineMode.TRAINING
        self.val.mode = PipelineMode.EVALUATION
class PageXML(CalamariDataGeneratorParams):
    images: List[str] = field(default_factory=list)
    xml_files: List[str] = field(default_factory=list)
    gt_extension: str = field(
        default=".xml",
        metadata=pai_meta(help="Default extension of the gt files (expected to exist in the same dir)"),
    )
    text_index: int = 0
    pad: Optional[List[int]] = field(
        default=None,
        metadata=pai_meta(help="Additional padding after lines were cut out."),
    )
    pred_extension: str = field(
        default=".pred.xml",
        metadata=pai_meta(help="Default extension of the prediction files"),
    )
    skip_commented: bool = field(
        default=False,
        metadata=pai_meta(help='Skip lines with a "comments" attribute.'),
    )

    def __len__(self):
        return len(self.images)

    def select(self, indices: List[int]):
        if self.images:
            self.images = [self.images[i] for i in indices]
        if self.xml_files:
            self.xml_files = [self.xml_files[i] for i in indices]

    def to_prediction(self):
        pred = deepcopy(self)
        pred.xml_files = [split_all_ext(f)[0] + self.pred_extension for f in self.xml_files]
        return pred

    @staticmethod
    def cls():
        return PageXMLReader

    def prepare_for_mode(self, mode: PipelineMode):
        self.images = sorted(glob_all(self.images))
        self.xml_files = sorted(self.xml_files)
        if not self.xml_files:
            self.xml_files = [split_all_ext(f)[0] + self.gt_extension for f in self.images]
        if not self.images:
            self.xml_files = sorted(glob_all(self.xml_files))
            self.images = [None] * len(self.xml_files)
class TrainerParams(AIPTrainerParams[CalamariScenarioParams, CalamariDefaultTrainerPipelineParams]):
    version: int = SavedCalamariModel.VERSION
    data_aug_retrain_on_original: bool = field(
        default=True,
        metadata=pai_meta(
            help="When training with augmentations, the model is usually retrained in a second run on "
            "only the non-augmented data. This will take longer. Set this flag to False to disable "
            "the second run."
        ),
    )
    # Current training progress: 0 standard, 1 retraining on non-augmented data.
    current_stage: int = field(default=0, metadata=pai_meta(mode="ignore"))
    progress_bar: bool = True
    auto_upgrade_checkpoints: bool = field(
        default=True,
        metadata=pai_meta(help="Automatically update older checkpoints for warm start."),
    )
    codec: CodecConstructionParams = field(
        default_factory=CodecConstructionParams,
        metadata=pai_meta(
            help="Parameters defining how to construct the codec.",
            mode="flat",  # the actual codec is stored in data
        ),
    )
    gen: TrainerPipelineParamsBase = field(
        default_factory=CalamariDefaultTrainerPipelineParams,
        metadata=pai_meta(
            help="Parameters that set up the data generators (i.e. the input data).",
            disable_subclass_check=False,
            choices=[
                CalamariDefaultTrainerPipelineParams,
                CalamariTrainOnlyPipelineParams,
                CalamariSplitTrainerPipelineParams,
            ],
        ),
    )
    best_model_prefix: str = field(
        default="best",
        metadata=pai_meta(help="The prefix of the best model determined by early stopping."),
    )
    network: Optional[str] = field(
        default=None,
        metadata=pai_meta(
            mode="flat",
            help="Pass a network configuration to construct a simple graph. "
            "Defaults to: --network=cnn=40:3x3,pool=2x2,cnn=60:3x3,pool=2x2,lstm=200,dropout=0.5",
        ),
    )

    def __post_init__(self):
        self.scenario.default_serve_dir = f"{self.best_model_prefix}.ckpt.h5"
        self.scenario.trainer_params_filename = f"{self.best_model_prefix}.ckpt.json"
        self.early_stopping.best_model_name = ""

        self.gen.train_gen().n_folds = self.scenario.model.ensemble
        self.gen.train_gen().channels = self.scenario.data.input_channels
        if self.gen.val_gen() is not None:
            self.gen.val_gen().channels = self.scenario.data.input_channels
            self.gen.val_gen().n_folds = self.scenario.model.ensemble

        if self.network:
            self.scenario.model.layers = graph_params_from_definition_string(self.network)
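# Sketch of the --network shorthand named in the help text above: the string
# is expanded by graph_params_from_definition_string into LayerParams. This
# example assumes the default definition listed in the help.
params = TrainerParams(
    network="cnn=40:3x3,pool=2x2,cnn=60:3x3,pool=2x2,lstm=200,dropout=0.5"
)
# After __post_init__, scenario.model.layers holds two 3x3 convolutions
# (40 and 60 filters), two 2x2 pooling layers, an LSTM layer with 200
# hidden units, and a dropout of 0.5.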
class WarmStartParams:
    """Parameters for warm-starting from a model."""

    model: Optional[str] = field(
        default=None,
        metadata=pai_meta(help="Path to the saved model or checkpoint to load the weights from."),
    )
    allow_partial: bool = field(
        default=False,
        metadata=pai_meta(help="Allow that not all weights can be matched."),
    )
    trim_graph_name: bool = field(
        default=True,
        metadata=pai_meta(
            help="Remove the graph name from the loaded model and the target model. "
            "This is useful if the model name changed."
        ),
    )
    rename: List[str] = field(
        default_factory=list,
        metadata=pai_meta(
            help="A list of renaming rules to apply to the loaded weights. Format: FROM->TO FROM->TO ..."
        ),
    )
    add_suffix: str = field(default="", metadata=pai_meta(help="Add a suffix to all variable names."))
    rename_targets: List[str] = field(
        default_factory=list,
        metadata=pai_meta(
            help="A list of renaming rules to apply to the target weights. Format: FROM->TO FROM->TO ..."
        ),
    )
    exclude: Optional[str] = field(
        default=None,
        metadata=pai_meta(help="A regex applied to the loaded weights to exclude them from loading."),
    )
    include: Optional[str] = field(
        default=None,
        metadata=pai_meta(help="A regex applied to the loaded weights to include them in loading."),
    )
    auto_remove_numbers_for: List[str] = field(default_factory=lambda: ["lstm_cell"])

    def create(self, **kwargs) -> "WarmStarter":
        from tfaip.trainer.warmstart.warmstarter import WarmStarter

        return WarmStarter(params=self, **kwargs)
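# Usage sketch with hypothetical names: load weights from an old checkpoint,
# rename "conv2d" variables to "conv", and skip the logits layer.
ws = WarmStartParams(
    model="old_model/best.ckpt.h5",  # hypothetical path
    rename=["conv2d->conv"],         # "FROM->TO" format from the help text
    exclude=".*logits.*",            # regex on the loaded weight names
    allow_partial=True,
)
starter = ws.create()  # may require extra kwargs depending on WarmStarter's signature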
class ICTrainerPipelineParams(TrainerPipelineParamsBase[ICDataGeneratorParams, ICDataGeneratorParams]):
    dataset_path: str = field(default="")
    validation_split: float = 0.2
    shuffle_files: bool = True

    # resolved files, a list of file names per class
    image_files: Dict[str, List[str]] = field(default_factory=dict, metadata=pai_meta(mode="ignore"))

    def train_gen(self) -> ICDataGeneratorParams:
        return ICDataGeneratorParams(
            image_files={k: v[int(self.validation_split * len(v)):] for k, v in self.image_files.items()}
        )

    def val_gen(self) -> Optional[ICDataGeneratorParams]:
        val = ICDataGeneratorParams(
            image_files={k: v[: int(self.validation_split * len(v))] for k, v in self.image_files.items()}
        )
        if val.num_files() == 0:
            return None  # no validation
        return val

    def __post_init__(self):
        self.image_files = {}
        if os.path.exists(self.dataset_path):
            for class_name in os.listdir(self.dataset_path):
                class_path = os.path.join(self.dataset_path, class_name)
                if not os.path.isdir(class_path):
                    continue
                self.image_files[class_name] = [os.path.join(class_path, file) for file in os.listdir(class_path)]
                if self.shuffle_files:
                    shuffle(self.image_files[class_name])
class TutorialTrainerGeneratorParams(
    TrainerPipelineParamsBase[TutorialDataGeneratorParams, TutorialDataGeneratorParams]
):
    """Definition of the training data.

    Since the dataset is loaded from keras.datasets, training and validation
    data are loaded jointly (parameter `train_val`), which is why `train_gen`
    and `val_gen` return the same generator. The decision whether to choose
    training or validation data depends on the `PipelineMode`.

    Furthermore, the `lav_gen` method is overwritten to perform lav on both
    the training and the validation set. For this purpose, the `force_train`
    variable is overwritten to select the training data even if the
    PipelineMode is PipelineMode.EVALUATION.
    """

    train_val: TutorialDataGeneratorParams = field(
        default_factory=TutorialDataGeneratorParams, metadata=pai_meta(mode="flat")
    )

    def train_gen(self) -> TutorialDataGeneratorParams:
        return self.train_val

    def val_gen(self) -> Optional[TutorialDataGeneratorParams]:
        return self.train_val

    def lav_gen(self) -> Iterable[TutorialDataGeneratorParams]:
        train: TutorialDataGeneratorParams = copy(self.train_val)
        train.force_train = True
        return [train, self.train_val]
class ImageLoaderParams:
    channels: int = field(
        default=1,
        metadata=pai_meta(help="Number of channels to produce, by default 1=grayscale. Use 3 for colour."),
    )
    to_gray_method: str = field(
        default="cv",
        metadata=pai_meta(
            help="Method to apply to convert color to gray.",
            choices=["avg", "cv"],
        ),
    )

    def create(self) -> "ImageLoader":
        return ImageLoader(self)
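# Usage sketch: a loader producing 3-channel (colour) images; create()
# simply wraps these params in an ImageLoader.
loader = ImageLoaderParams(channels=3).create()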
class ModelParams(ModelBaseParams):
    layers: List[LayerParams] = field(
        default_factory=default_layers,
        metadata=pai_meta(
            choices=all_layers(),
            help="Layers of the graph. See the docs for more information.",
        ),
    )
    classes: int = -1
    ctc_merge_repeated: bool = True
    ensemble: int = 0  # for usage with the ensemble-model graph
    masking_mode: int = 0  # for evaluation only; should not be used in production

    @staticmethod
    def cls():
        from calamari_ocr.ocr.model.model import Model

        return Model

    def graph_cls(self):
        from calamari_ocr.ocr.model.graph import CalamariGraph

        return CalamariGraph

    def __post_init__(self):
        # set up unique layer names by counting per name prefix
        counts = {}
        for layer in self.layers:
            counts[layer.name_prefix()] = counts.get(layer.name_prefix(), -1) + 1
            layer.name = f"{layer.name_prefix()}_{counts[layer.name_prefix()]}"

    def compute_downscale_factor(self) -> IntVec2D:
        factor = IntVec2D(1, 1)
        for layer in self.layers:
            factor = layer.downscale_factor(factor)
        return factor

    def compute_max_downscale_factor(self) -> IntVec2D:
        factor = IntVec2D(1, 1)
        max_factor = IntVec2D(1, 1)
        for layer in self.layers:
            factor = layer.downscale_factor(factor)
            max_factor.x = max(max_factor.x, factor.x)
            max_factor.y = max(max_factor.y, factor.y)
        return max_factor

    def compute_downscaled(self, size: Union[int, IntVec2D, Tuple[Any, Any]]):
        if isinstance(size, int):
            for layer in self.layers:
                size = layer.downscale(IntVec2D(size, 1)).x
        elif isinstance(size, IntVec2D):
            for layer in self.layers:
                size = layer.downscale(size)
        elif isinstance(size, tuple):
            for layer in self.layers:
                size = layer.downscale(IntVec2D(size[0], size[1]))
                size = size.x, size.y
        else:
            raise NotImplementedError
        return size
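# Sketch: composing the downscale factor over a layer stack. Assuming the
# default layers contain two 2x2 pooling layers (and otherwise stride-1
# layers), the accumulated factor is 4 in both axes:
model = ModelParams()
factor = model.compute_downscale_factor()  # IntVec2D(4, 4) for two 2x2 pools
width = model.compute_downscaled(100)      # ceil(ceil(100 / 2) / 2) = 25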
class ICPredictionDataGeneratorParams(DataGeneratorParams):
    @staticmethod
    def cls() -> Type["DataGenerator"]:
        return ICPredictionDataGenerator

    image_files: List[str] = field(default_factory=list, metadata=pai_meta(required=True))
class CalamariSplitTrainerPipelineParams(
    TrainerPipelineParams[CalamariDataGeneratorParams, CalamariDataGeneratorParams]
):
    train: CalamariDataGeneratorParams = field(
        default_factory=FileDataParams,
        metadata=pai_meta(
            choices=[FileDataParams, PageXML],
            enforce_choices=True,
            mode="flat",
        ),
    )
    validation_split_ratio: float = field(
        default=0.2,
        metadata=pai_meta(help="Use this fraction of the training dataset for validation."),
    )
    val: Optional[CalamariDataGeneratorParams] = field(default=None, metadata=pai_meta(mode="ignore"))

    def __post_init__(self):
        if self.val is not None:
            # already initialized
            return
        if not 0 < self.validation_split_ratio < 1:
            raise ValueError("validation_split_ratio must be in (0, 1)")

        # resolve all files so we can split them
        self.train.prepare_for_mode(PipelineMode.TRAINING)
        self.val = deepcopy(self.train)
        samples = len(self.train)
        n = int(self.validation_split_ratio * samples)
        if n == 0:
            raise ValueError(
                f"Ratio is too small since {self.validation_split_ratio} * {samples} = {n}. "
                f"Increase the amount of data or the split ratio."
            )
        logger.info(
            f"Splitting training and validation files with ratio {self.validation_split_ratio}: "
            f"{n}/{samples - n} for validation/training."
        )
        indices = list(range(samples))
        shuffle(indices)

        # split the train and val img/gt files; use the train settings
        self.train.select(indices[n:])
        self.val.select(indices[:n])
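# Minimal usage sketch: an 80/20 train/validation split over the resolved
# files (hypothetical glob; FileDataParams is assumed to accept an images
# pattern).
gen = CalamariSplitTrainerPipelineParams(
    train=FileDataParams(images=["data/*.png"]),
    validation_split_ratio=0.2,
)
# __post_init__ resolves the glob, shuffles the sample indices, and moves
# 20% of the samples into gen.val, keeping the rest in gen.train.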
class PrepareSampleProcessorParams(DataProcessorParams):
    @staticmethod
    def cls() -> Type["MappingDataProcessor"]:
        return PrepareSample

    max_line_width: int = field(
        default=4096,
        metadata=pai_meta(help="Max width of a line. Set to -1 or 0 to skip this check."),
    )
class TextRegularizerProcessorParams(DataProcessorParams):
    replacements: Optional[List[Replacement]] = field(default=None, metadata=pai_meta(mode="ignore"))
    rulesets: List[str] = field(default_factory=lambda: ["spaces"])
    rulegroups: List[str] = field(default_factory=list)

    @staticmethod
    def cls() -> Type["TextProcessor"]:
        return TextRegularizerProcessor