Example #1
class DataParams(DataBaseParams):
    skip_invalid_gt: bool = True
    input_channels: int = 1
    downscale_factor: int = field(
        default=-1, metadata=pai_meta(mode="ignore"))  # Set based on model
    line_height: int = field(default=48,
                             metadata=pai_meta(help="The line height"))
    ensemble: int = field(
        default=0, metadata=pai_meta(mode="ignore"))  # Set based on model
    codec: Optional[Codec] = field(default=None,
                                   metadata=pai_meta(mode="ignore"))

    @staticmethod
    def cls():
        from calamari_ocr.ocr.dataset.data import Data

        return Data

    def __post_init__(self):
        from calamari_ocr.ocr.dataset.imageprocessors.center_normalizer import (
            CenterNormalizerProcessorParams, )
        from calamari_ocr.ocr.dataset.imageprocessors.scale_to_height_processor import (
            ScaleToHeightProcessorParams, )

        for p in self.post_proc.processors + self.pre_proc.processors:
            if isinstance(p, ScaleToHeightProcessorParams):
                p.height = self.line_height
            elif isinstance(p, CenterNormalizerProcessorParams):
                p.line_height = self.line_height
Example #2
class DilatedBlockLayerParams(LayerParams):
    @classmethod
    def name_prefix(cls) -> str:
        return "dilated_block"

    @classmethod
    def cls(cls) -> Type["Layer"]:
        return DilatedBlockLayer

    def downscale(self, size: IntVec2D) -> IntVec2D:
        return IntVec2D(
            (size.x + self.strides.x - 1) // self.strides.x,
            (size.y + self.strides.y - 1) // self.strides.y,
        )

    def downscale_factor(self, factor: IntVec2D) -> IntVec2D:
        return IntVec2D(factor.x * self.strides.x, factor.y * self.strides.y)

    filters: int = 40
    kernel_size: IntVec2D = field(default_factory=lambda: IntVec2D(3, 3), metadata=pai_meta(tuple_like=True))
    strides: IntVec2D = field(default_factory=lambda: IntVec2D(1, 1), metadata=pai_meta(tuple_like=True))

    padding: str = "same"
    activation: str = "relu"
    dilated_depth: int = 2
Example #3
class TransposedConv2DLayerParams(LayerParams):
    @classmethod
    def name_prefix(cls) -> str:
        return "tconv2d"

    @classmethod
    def cls(cls) -> Type["Layer"]:
        return TransposedConv2DLayer

    def downscale(self, size: IntVec2D) -> IntVec2D:
        return IntVec2D(size.x * self.strides.x, size.y * self.strides.y)

    def downscale_factor(self, size: IntVec2D) -> IntVec2D:
        return IntVec2D(
            (size.x + self.strides.x - 1) // self.strides.x,
            (size.y + self.strides.y - 1) // self.strides.y,
        )

    filters: int = 40
    kernel_size: IntVec2D = field(default_factory=lambda: IntVec2D(3, 3),
                                  metadata=pai_meta(tuple_like=True))
    strides: IntVec2D = field(default_factory=lambda: IntVec2D(2, 2),
                              metadata=pai_meta(tuple_like=True))

    padding: str = "same"
    activation: str = "relu"
Example #4
class OptimizerParams(ABC):
    """General parameters of a Optimizer"""

    @abstractmethod
    def create(self) -> Tuple[Type["tf.keras.optimizers.Optimizer"], Dict[str, Any]]:
        raise NotImplementedError

    clip_norm: Optional[float] = field(
        default=None, metadata=pai_meta(help="float or None. If set, clips gradients to a maximum norm.")
    )
    clip_value: Optional[float] = field(
        default=None, metadata=pai_meta(help="float or None. If set, clips gradients to a maximum value.")
    )
    global_clip_norm: Optional[float] = field(
        default=None,
        metadata=pai_meta(
            help="float or None. If set, the gradient of all weights is clipped so that "
            "their global norm is no higher than this value."
        ),
    )

    def _clip_grad_args(self):
        return {
            "clipnorm": self.clip_norm,
            "clipvalue": self.clip_value,
            "global_clipnorm": self.global_clip_norm,
        }
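The dictionary returned by `_clip_grad_args` matches the gradient-clipping keyword arguments (`clipnorm`, `clipvalue`, `global_clipnorm`) accepted by `tf.keras` optimizers. A hedged usage sketch (the `build_optimizer` helper and the `None` filtering are my additions, not part of the snippet):

def build_optimizer(params: "OptimizerParams"):
    # Hypothetical caller: drop None entries so Keras only sees active clip args.
    optimizer_cls, kwargs = params.create()
    clip_args = {k: v for k, v in params._clip_grad_args().items() if v is not None}
    return optimizer_cls(**kwargs, **clip_args)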
Example #5
class ListsFileGeneratorParams(DataGeneratorParams):
    """
    Parameters for the ListsFileDataGenerator
    """
    @staticmethod
    def cls():
        from tfaip.scenario.listfile.datagenerator import (
            ListsFileDataGenerator, )  # pylint: disable=import-outside-toplevel

        return ListsFileDataGenerator

    lists: Optional[List[str]] = field(
        default_factory=list, metadata=pai_meta(help="Training list files."))
    list_ratios: Optional[List[float]] = field(
        default=None,
        metadata=pai_meta(
            help=
            "Ratios of picking list files. Must be supported by the scenario"))
    ignore_prefix: Optional[str] = field(
        default="#",
        metadata=pai_meta(
            help=
            "Ignore entries in the files that start with the corresponding sign. Default is '#' for commenting out samples."
        ),
    )

    def __post_init__(self):
        if self.lists:
            if not self.list_ratios:
                self.list_ratios = [1.0] * len(self.lists)
            else:
                if len(self.list_ratios) != len(self.lists):
                    raise ValueError(
                        f"Length of list_ratios must be equals to number of lists. "
                        f"Got {self.list_ratios}!={self.lists}")
Example #6
class MaxPool2DLayerParams(LayerParams):
    @classmethod
    def name_prefix(cls) -> str:
        return "maxpool2d"

    @classmethod
    def cls(cls) -> Type["Layer"]:
        return MaxPool2DLayer

    def downscale(self, size: IntVec2D) -> IntVec2D:
        strides = self.real_strides()
        return IntVec2D((size.x + strides.x - 1) // strides.x,
                        (size.y + strides.y - 1) // strides.y)

    def downscale_factor(self, factor: IntVec2D) -> IntVec2D:
        strides = self.real_strides()
        return IntVec2D(factor.x * strides.x, factor.y * strides.y)

    def real_strides(self) -> IntVec2D:
        if self.strides is None:
            return self.pool_size
        return IntVec2D(
            self.strides.x if self.strides.x >= 0 else self.pool_size.x,
            self.strides.y if self.strides.y >= 0 else self.pool_size.y,
        )

    pool_size: IntVec2D = field(default_factory=lambda: IntVec2D(2, 2),
                                metadata=pai_meta(tuple_like=True))
    strides: IntVec2D = field(default_factory=lambda: IntVec2D(-1, -1),
                              metadata=pai_meta(tuple_like=True))

    padding: str = "same"
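The default stride of (-1, -1) means "fall back to the pool size", as resolved by `real_strides`. A quick plain-tuple check of that resolution and the rounding-up downscale (IntVec2D stubbed out):

pool_size, strides = (2, 2), (-1, -1)
real = tuple(s if s >= 0 else p for s, p in zip(strides, pool_size))
assert real == (2, 2)

# downscale then performs rounding-up integer division per axis:
size = (100, 49)
assert tuple((n + s - 1) // s for n, s in zip(size, real)) == (50, 25)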
Example #7
class Hdf5(CalamariDataGeneratorParams):
    files: List[str] = field(default_factory=list,
                             metadata=pai_meta(required=True))
    pred_extension: str = field(
        default=".pred.h5",
        metadata=pai_meta(help="Default extension of the prediction files"),
    )

    def __len__(self):
        return len(self.files)

    def to_prediction(self):
        self.files = sorted(glob_all(self.files))
        pred = deepcopy(self)
        pred.files = [
            split_all_ext(f)[0] + self.pred_extension for f in self.files
        ]
        return pred

    @staticmethod
    def cls():
        return Hdf5Generator

    def prepare_for_mode(self, mode: PipelineMode):
        self.files = sorted(glob_all(self.files))
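`to_prediction` swaps each file's full extension chain for `pred_extension`. Assuming `split_all_ext` splits at the first dot and keeps the whole extension chain (a simplified reading, consistent with its use throughout these examples), the mapping behaves like this sketch:

def split_all_ext(path):
    # simplified stand-in for calamari's helper; dots in directory names not handled
    stem, _, ext = path.partition(".")
    return stem, ("." + ext) if ext else ""

assert split_all_ext("data/page.gt.h5") == ("data/page", ".gt.h5")
# a source file "data/page.gt.h5" therefore predicts into "data/page.pred.h5"
assert split_all_ext("data/page.gt.h5")[0] + ".pred.h5" == "data/page.pred.h5"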
Example #8
class DeviceConfigParams:
    """Configuration of the devices (GPUs).

    By default, no GPUs are added.
    Specify which GPUs to use either by setting `gpus` or CUDA_VISIBLE_DEVICES.
    """

    gpus: Optional[List[int]] = field(
        default=None, metadata=pai_meta(help="List of the GPUs to use."))
    gpu_auto_tune: bool = field(
        default=False,
        metadata=pai_meta(help="Enable auto tuning of the GPUs"))
    gpu_memory: Optional[int] = field(
        default=None,
        metadata=pai_meta(
            help=
            "Limit the per GPU memory in MB. By default the memory will grow automatically"
        ),
    )
    soft_device_placement: bool = field(
        default=True,
        metadata=pai_meta(help="Set up soft device placement is enabled"))
    dist_strategy: DistributionStrategy = field(
        default=DistributionStrategy.DEFAULT,
        metadata=pai_meta(
            help=
            "Distribution strategy for multi GPU, select 'mirror' or 'central_storage'"
        ),
    )
Example #9
class ExponentialDecayParams(LearningRateParams):
    """Exponential decay parameters"""
    @staticmethod
    def cls():
        from tfaip.trainer.scheduler.exponential_decay import (
            ExponentialDecaySchedule, )  # pylint: disable=import-outside-toplevel

        return ExponentialDecaySchedule

    learning_circle: int = field(
        default=3,
        metadata=pai_meta(
            help=
            "(type dependent) The number of epochs with a flat constant learning rate"
        ))
    lr_decay_rate: float = field(
        default=0.99,
        metadata=pai_meta(
            help="(type dependent) The exponential decay factor"))
    decay_min_fraction: float = field(
        default=0.0,
        metadata=pai_meta(
            help=
            "(type dependent) Minimal fraction the learning rate can drop to by exponential decay)"
        ),
    )
Example #10
class CalamariDataGeneratorParams(DataGeneratorParams, ImageLoaderParams, ABC):
    skip_invalid: bool = True
    non_existing_as_empty: bool = False
    n_folds: int = field(default=-1, metadata=pai_meta(mode="ignore"))
    preload: bool = field(
        default=True,
        metadata=pai_meta(
            help="Instead of preloading all data, load the data on the fly. "
            "This is slower, but might be required for limited RAM or large dataset"
        ),
    )

    def __len__(self):
        raise NotImplementedError

    @abstractmethod
    def to_prediction(self):
        raise NotImplementedError

    def select(self, indices: List[int]):
        raise NotImplementedError

    def prepare_for_mode(self, mode: PipelineMode) -> NoReturn:
        pass

    def create(self, mode: PipelineMode) -> "CalamariDataGenerator":
        params = deepcopy(self)  # always copy of params
        params.prepare_for_mode(mode)
        gen: CalamariDataGenerator = self.cls()(mode, params)
        gen.post_init()
        return gen

    def image_loader(self) -> ImageLoader:
        return ImageLoader(self)
Example #11
class CodecConstructionParams:
    keep_loaded: bool = field(default=True, metadata=pai_meta(
        help="Fully include the codec of the loaded model to the new codec"))
    auto_compute: bool = field(default=True, metadata=pai_meta(
        help="Compute the codec automatically. See also include."))
    include: List[str] = field(default_factory=list, metadata=pai_meta(
        help="Whitelist of characters that may not be removed on restoring a model. "
             "For large dataset you can use this to skip the automatic codec computation "
             "(see auto_compute)"))
    include_files: List[str] = field(default_factory=list, metadata=pai_meta(
        help="Whitelist of txt files that may not be removed on restoring a model"))

    resolved_include_chars: Set[str] = field(default_factory=set, metadata=pai_meta(mode='ignore'))

    def __post_init__(self):
        # parse whitelist
        if len(self.include) == 1:
            include = set(self.include[0])
        else:
            include = set(self.include)

        for f in glob_all(self.include_files):
            with open(f) as txt:
                include = include.union(txt.read())

        self.resolved_include_chars = include
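Note the special case in `__post_init__`: a single `include` entry is split into individual characters, while multiple entries are kept as-is. Likewise, `include.union(txt.read())` adds each character of the file contents, since set.union iterates the string. For example:

# include=["abc"]: one entry is split into individual characters
assert set("abc") == {"a", "b", "c"}
# include=["ab", "cd"]: several entries are kept whole
assert set(["ab", "cd"]) == {"ab", "cd"}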
Example #12
class Abbyy(CalamariDataGeneratorParams):
    images: List[str] = field(default_factory=list,
                              metadata=pai_meta(required=True))
    xml_files: List[str] = field(default_factory=list)
    gt_extension: str = field(
        default=".abbyy.xml",
        metadata=pai_meta(
            help=
            "Default extension of the gt files (expected to exist in same dir)"
        ),
    )
    binary: bool = False
    pred_extension: str = field(
        default=".abbyy.pred.xml",
        metadata=pai_meta(help="Default extension of the prediction files"),
    )

    def __len__(self):
        return len(self.images)

    def select(self, indices: List[int]):
        if self.images:
            self.images = [self.images[i] for i in indices]
        if self.xml_files:
            self.xml_files = [self.xml_files[i] for i in indices]

    def to_prediction(self):
        pred = deepcopy(self)
        pred.xml_files = [
            split_all_ext(f)[0] + self.pred_extension for f in self.xml_files
        ]
        return pred

    @staticmethod
    def cls():
        return AbbyyGenerator

    def prepare_for_mode(self, mode: PipelineMode):
        self.images = sorted(glob_all(self.images))
        self.xml_files = sorted(glob_all(self.xml_files))
        if not self.xml_files:
            self.xml_files = [
                split_all_ext(f)[0] + self.gt_extension for f in self.images
            ]
        if not self.images:
            self.images = [None] * len(self.xml_files)

        if len(self.images) != len(self.xml_files):
            raise ValueError(
                f"Different number of image and xml files, {len(self.images)} != {len(self.xml_files)}"
            )
        for img_path, xml_path in zip(self.images, self.xml_files):
            if img_path and xml_path:
                img_bn, xml_bn = split_all_ext(img_path)[0], split_all_ext(
                    xml_path)[0]
                if img_bn != xml_bn:
                    logger.warning(
                        f"Filenames are not matching, got base names \n  image: {img_bn}\n  xml:   {xml_bn}\n."
                    )
Example #13
class Base:
    i1: Optional[int] = None
    i2: Optional[int] = -1
    i3: int = 3

    sub1: Optional[Sub] = field(default_factory=Sub,
                                metadata=pai_meta(choices=[Sub]))
    sub2: Optional[Sub] = field(default=None, metadata=pai_meta(choices=[Sub]))
Example #14
class CalamariDefaultTrainerPipelineParams(
        TrainerPipelineParams[CalamariDataGeneratorParams,
                              CalamariDataGeneratorParams]):
    train: CalamariDataGeneratorParams = field(
        default_factory=FileDataParams,
        metadata=pai_meta(choices=DATA_GENERATOR_CHOICES, mode='flat'))
    val: CalamariDataGeneratorParams = field(
        default_factory=FileDataParams,
        metadata=pai_meta(choices=DATA_GENERATOR_CHOICES, mode='flat'))
Example #15
class TextRegularizerProcessorParams(DataProcessorParams):
    # TODO: groups as enums
    replacement_groups: List[str] = field(
        default_factory=lambda: ["extended"],
        metadata=pai_meta(help="Text regularization to apply."))
    replacements: Optional[List[Replacement]] = field(
        default=None, metadata=pai_meta(mode='ignore'))

    @staticmethod
    def cls() -> Type['TextProcessor']:
        return TextRegularizerProcessor
Example #16
class EvalArgs:
    gt: CalamariDataGeneratorParams = field(
        default_factory=FileDataParams,
        metadata=pai_meta(help="GT", mode="flat", choices=DATA_GENERATOR_CHOICES),
    )
    pred: Optional[CalamariDataGeneratorParams] = field(
        default=None,
        metadata=pai_meta(
            help="Optional prediction dataset",
            mode="flat",
            choices=DATA_GENERATOR_CHOICES,
        ),
    )
    n_confusions: int = field(
        default=10,
        metadata=pai_meta(
            help="Only print n most common confusions. Defaults to 10, use -1 for all.",
            mode="flat",
        ),
    )
    n_worst_lines: int = field(
        default=0,
        metadata=pai_meta(help="Print the n worst recognized text lines with its error", mode="flat"),
    )
    xlsx_output: Optional[str] = field(
        default=None,
        metadata=pai_meta(help="Optionally write a xlsx file with the evaluation results", mode="flat"),
    )
    non_existing_file_handling_mode: str = field(
        default="error",
        metadata=pai_meta(
            mode="flat",
            choices=["error", "skip", "empty"],
            help="How to handle non existing .pred.txt files. Possible modes: skip, empty, error. "
            "'Skip' will simply skip the evaluation of that file (not counting it to errors). "
            "'Empty' will handle this file as would it be empty (fully checking for errors)."
            "'Error' will throw an exception if a file is not existing. This is the default behaviour.",
        ),
    )
    skip_empty_gt: bool = field(
        default=False,
        metadata=pai_meta(help="Ignore lines of the gt that are empty.", mode="flat"),
    )
    checkpoint: Optional[str] = field(
        default=None,
        metadata=pai_meta(
            help="Specify an optional checkpoint to parse the text preprocessor (for the gt txt files)",
            mode="flat",
        ),
    )
    evaluator: EvaluatorParams = field(
        default_factory=EvaluatorParams,
        metadata=pai_meta(
            mode="flat",
            fix_dc=True,
        ),
    )
Example #17
class WarmupDecayParams(LearningRateParams):
    """Cosine decay with warmup"""

    @staticmethod
    def cls():
        from tfaip.trainer.scheduler.warmup_decay import (
            WarmupDecaySchedule,
        )  # pylint: disable=import-outside-toplevel

        return WarmupDecaySchedule

    warmup_epochs: int = field(default=-1, metadata=pai_meta(help="Number of epochs for linear increase"))
    warmup_steps: int = field(default=-1, metadata=pai_meta(help="Number of steps for linear increase"))
Example #18
class TrainerPipelineParams(TrainerPipelineParamsBase[TDataGeneratorTrain,
                                                      TDataGeneratorVal],
                            metaclass=TrainerPipelineParamsMeta):
    train: TDataGeneratorTrain = field(default_factory=DataGeneratorParams,
                                       metadata=pai_meta(mode="flat"))
    val: TDataGeneratorVal = field(default_factory=DataGeneratorParams,
                                   metadata=pai_meta(mode="flat"))

    def train_gen(self) -> TDataGeneratorTrain:
        return self.train

    def val_gen(self) -> TDataGeneratorVal:
        return self.val
Example #19
class TrainerPipelines:
    train: DataPipelineParams = field(
        default_factory=lambda: DataPipelineParams(mode=PipelineMode.TRAINING),
        metadata=pai_meta(fix_dc=True, mode="flat"),
    )
    val: DataPipelineParams = field(
        default_factory=lambda: DataPipelineParams(mode=PipelineMode.EVALUATION
                                                   ),
        metadata=pai_meta(fix_dc=True, mode="flat"),
    )

    def __post_init__(self):
        self.train.mode = PipelineMode.TRAINING
        self.val.mode = PipelineMode.EVALUATION
Example #20
class PageXML(CalamariDataGeneratorParams):
    images: List[str] = field(default_factory=list)
    xml_files: List[str] = field(default_factory=list)
    gt_extension: str = field(
        default='.xml',
        metadata=pai_meta(
            help=
            "Default extension of the gt files (expected to exist in same dir)"
        ))
    text_index: int = 0
    pad: Optional[List[int]] = field(
        default=None,
        metadata=pai_meta(help="Additional padding after lines were cut out."))
    pred_extension: str = field(
        default='.pred.xml',
        metadata=pai_meta(help="Default extension of the prediction files"))
    skip_commented: bool = field(
        default=False,
        metadata=pai_meta(help='Skip lines with "comments" attribute.'))

    def __len__(self):
        return len(self.images)

    def select(self, indices: List[int]):
        if self.images:
            self.images = [self.images[i] for i in indices]
        if self.xml_files:
            self.xml_files = [self.xml_files[i] for i in indices]

    def to_prediction(self):
        pred = deepcopy(self)
        pred.xml_files = [
            split_all_ext(f)[0] + self.pred_extension for f in self.xml_files
        ]
        return pred

    @staticmethod
    def cls():
        return PageXMLReader

    def prepare_for_mode(self, mode: PipelineMode):
        self.images = sorted(glob_all(self.images))
        self.xml_files = sorted(self.xml_files)
        if not self.xml_files:
            self.xml_files = [
                split_all_ext(f)[0] + self.gt_extension for f in self.images
            ]
        if not self.images:
            self.xml_files = sorted(glob_all(self.xml_files))
            self.images = [None] * len(self.xml_files)
Example #21
class TrainerParams(AIPTrainerParams[CalamariScenarioParams, CalamariDefaultTrainerPipelineParams]):
    version: int = SavedCalamariModel.VERSION

    data_aug_retrain_on_original: bool = field(default=True, metadata=pai_meta(
        help="When training with augmentations usually the model is retrained in a second run with "
             "only the non augmented data. This will take longer. Use this flag to disable this "
             "behavior."))

    # Current training progress: 0 standard, 1 retraining on non aug.
    current_stage: int = field(default=0, metadata=pai_meta(mode='ignore'))

    progress_bar: bool = True

    auto_upgrade_checkpoints: bool = field(default=True, metadata=pai_meta(
        help='Automatically update older checkpoints for warm start.'))

    codec: CodecConstructionParams = field(default_factory=CodecConstructionParams, metadata=pai_meta(
        help="Parameters defining how to construct the codec.", mode='flat'  # The actual codec is stored in data
    ))

    gen: TrainerPipelineParamsBase = field(default_factory=CalamariDefaultTrainerPipelineParams, metadata=pai_meta(
        help="Parameters that setup the data generators (i.e. the input data).",
        disable_subclass_check=False,
        choices=[CalamariDefaultTrainerPipelineParams, CalamariTrainOnlyPipelineParams,
                 CalamariSplitTrainerPipelineParams]
    ))

    best_model_prefix: str = field(default="best", metadata=pai_meta(
        help="The prefix of the best model using early stopping"))

    network: Optional[str] = field(default=None, metadata=pai_meta(
        mode='flat',
        help='Pass a network configuration to construct a simple graph. '
             'Defaults to: --network=cnn=40:3x3,pool=2x2,cnn=60:3x3,pool=2x2,lstm=200,dropout=0.5'
    ))

    def __post_init__(self):
        self.scenario.default_serve_dir = f'{self.best_model_prefix}.ckpt.h5'
        self.scenario.trainer_params_filename = f'{self.best_model_prefix}.ckpt.json'
        self.early_stopping.best_model_name = ''

        self.gen.train_gen().n_folds = self.scenario.model.ensemble
        self.gen.train_gen().channels = self.scenario.data.input_channels
        if self.gen.val_gen() is not None:
            self.gen.val_gen().channels = self.scenario.data.input_channels
            self.gen.val_gen().n_folds = self.scenario.model.ensemble

        if self.network:
            self.scenario.model.layers = graph_params_from_definition_string(self.network)
Example #22
class WarmStartParams:
    """Parameters for warm-starting from a model."""

    model: Optional[str] = field(
        default=None,
        metadata=pai_meta(
            help=
            "Path to the saved model or checkpoint to load the weights from."))

    allow_partial: bool = field(
        default=False,
        metadata=pai_meta(help="Allow that not all weights can be matched."))
    trim_graph_name: bool = field(
        default=True,
        metadata=pai_meta(
            help=
            "Remove the graph name from the loaded model and the target model. This is useful if the model name "
            "changed"),
    )

    rename: List[str] = field(
        default_factory=list,
        metadata=pai_meta(
            help=
            "A list of renaming rules to perform on the loaded weights. Format: FROM->TO FROM->TO ..."
        ),
    )

    add_suffix: str = field(
        default="",
        metadata=pai_meta(help="Add suffix str to all variable names"))

    rename_targets: List[str] = field(
        default_factory=list,
        metadata=pai_meta(
            help=
            "A list of renaming rules to perform on the target weights. Format: FROM->TO FROM->TO ..."
        ),
    )

    exclude: Optional[str] = field(
        default=None,
        metadata=pai_meta(
            help="A regex applied on the loaded weights to ignore from loading."
        ))
    include: Optional[str] = field(
        default=None,
        metadata=pai_meta(
            help=
            "A regex applied on the loaded weights to include from loading."))

    auto_remove_numbers_for: List[str] = field(
        default_factory=lambda: ["lstm_cell"])

    def create(self, **kwargs) -> "WarmStarter":
        from tfaip.trainer.warmstart.warmstarter import WarmStarter

        return WarmStarter(params=self, **kwargs)
Example #23
class ICTrainerPipelineParams(TrainerPipelineParamsBase[ICDataGeneratorParams, ICDataGeneratorParams]):
    dataset_path: str = field(default='')
    validation_split: float = 0.2
    shuffle_files: bool = True

    # resolved files, a list of file names per class
    image_files: Dict[str, List[str]] = field(default_factory=dict, metadata=pai_meta(mode='ignore'))

    def train_gen(self) -> ICDataGeneratorParams:
        return ICDataGeneratorParams(
            image_files={k: v[int(self.validation_split * len(v)):] for k, v in self.image_files.items()})

    def val_gen(self) -> Optional[ICDataGeneratorParams]:
        val = ICDataGeneratorParams(
            image_files={k: v[:int(self.validation_split * len(v))] for k, v in self.image_files.items()})
        if val.num_files() == 0:
            return None  # No validation
        return val

    def __post_init__(self):
        self.image_files = {}
        if os.path.exists(self.dataset_path):
            for class_name in os.listdir(self.dataset_path):
                class_path = os.path.join(self.dataset_path, class_name)
                if not os.path.isdir(class_path):
                    continue
                self.image_files[class_name] = [os.path.join(class_path, file) for file in os.listdir(class_path)]
                if self.shuffle_files:
                    shuffle(self.image_files[class_name])
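The split in `train_gen`/`val_gen` is purely index-based on the shuffled per-class file lists: the first `validation_split * len(v)` files of each class go to validation, the rest to training. A worked example with hypothetical file names:

validation_split = 0.2
files = [f"img_{i}.png" for i in range(10)]  # one class with 10 files
cut = int(validation_split * len(files))     # 2

val_files, train_files = files[:cut], files[cut:]
assert len(val_files) == 2 and len(train_files) == 8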
Example #24
class TutorialTrainerGeneratorParams(
        TrainerPipelineParamsBase[TutorialDataGeneratorParams,
                                  TutorialDataGeneratorParams]):
    """
    Definition of the training data. Since the dataset is loaded from keras.datasets, training and validation data
    are jointly loaded (parameter `train_val`), which is why `train_gen` and `val_gen` return the same generator.
    The decision whether to choose training or validation data depends on the `PipelineMode`.

    Furthermore, the `lav_gen` method is overwritten to perform lav on both the training and the validation set.
    For this purpose, the `force_train` variable is overwritten to select the training data even if the PipelineMode is
    PipelineMode.EVALUATION.
    """

    train_val: TutorialDataGeneratorParams = field(
        default_factory=TutorialDataGeneratorParams,
        metadata=pai_meta(mode="flat"))

    def train_gen(self) -> TutorialDataGeneratorParams:
        return self.train_val

    def val_gen(self) -> Optional[TutorialDataGeneratorParams]:
        return self.train_val

    def lav_gen(self) -> Iterable[TutorialDataGeneratorParams]:
        train: TutorialDataGeneratorParams = copy(self.train_val)
        train.force_train = True
        return [train, self.train_val]
Example #25
class ImageLoaderParams:
    channels: int = field(
        default=1,
        metadata=pai_meta(
            help=
            'Number of channels to produce, by default 1=grayscale. Use 3 for colour.'
        ))
    to_gray_method: str = field(
        default='cv',
        metadata=pai_meta(
            help='Method to apply to convert color to gray.',
            choices=['avg', 'cv'],
        ))

    def create(self) -> 'ImageLoader':
        return ImageLoader(self)
Example #26
class ModelParams(ModelBaseParams):
    layers: List[LayerParams] = field(
        default_factory=default_layers,
        metadata=pai_meta(
            choices=all_layers(),
            help="Layers of the graph. See the docs for more information."),
    )
    classes: int = -1
    ctc_merge_repeated: bool = True
    ensemble: int = 0  # For usage with the ensemble-model graph
    masking_mode: int = 0  # This parameter is for evaluation only and should not be used in production

    @staticmethod
    def cls():
        from calamari_ocr.ocr.model.model import Model

        return Model

    def graph_cls(self):
        from calamari_ocr.ocr.model.graph import CalamariGraph

        return CalamariGraph

    def __post_init__(self):
        # setup layer names
        counts = {}
        for layer in self.layers:
            counts[layer.name_prefix()] = counts.get(layer.name_prefix(),
                                                     -1) + 1
            layer.name = f"{layer.name_prefix()}_{counts[layer.name_prefix()]}"

    def compute_downscale_factor(self) -> IntVec2D:
        factor = IntVec2D(1, 1)
        for layer in self.layers:
            factor = layer.downscale_factor(factor)
        return factor

    def compute_max_downscale_factor(self) -> IntVec2D:
        factor = IntVec2D(1, 1)
        max_factor = IntVec2D(1, 1)
        for layer in self.layers:
            factor = layer.downscale_factor(factor)
            max_factor.x = max(max_factor.x, factor.x)
            max_factor.y = max(max_factor.y, factor.y)
        return max_factor

    def compute_downscaled(self, size: Union[int, IntVec2D, Tuple[Any, Any]]):
        if isinstance(size, int):
            for layer in self.layers:
                size = layer.downscale(IntVec2D(size, 1)).x
        elif isinstance(size, IntVec2D):
            for layer in self.layers:
                size = layer.downscale(size)
        elif isinstance(size, tuple):
            for layer in self.layers:
                size = layer.downscale(IntVec2D(size[0], size[1]))
                size = size.x, size.y
        else:
            raise NotImplementedError
        return size
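`compute_downscaled` folds `downscale` over all layers. With the default network from Example #21 (cnn=40:3x3,pool=2x2,cnn=60:3x3,pool=2x2,lstm=200,dropout=0.5), the two 2x2 pooling layers reduce a line height of 48 to 12; a plain-integer walk-through of that fold:

size = 48
for stride in (1, 2, 1, 2):  # cnn (stride 1), pool 2x2, cnn (stride 1), pool 2x2
    size = (size + stride - 1) // stride
assert size == 12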
Example #27
class ICPredictionDataGeneratorParams(DataGeneratorParams):
    @staticmethod
    def cls() -> Type["DataGenerator"]:
        return ICPredictionDataGenerator

    image_files: List[str] = field(default_factory=list,
                                   metadata=pai_meta(required=True))
Example #28
class CalamariSplitTrainerPipelineParams(
        TrainerPipelineParams[CalamariDataGeneratorParams,
                              CalamariDataGeneratorParams]):
    train: CalamariDataGeneratorParams = field(
        default_factory=FileDataParams,
        metadata=pai_meta(
            choices=[FileDataParams, PageXML],
            enforce_choices=True,
            mode="flat",
        ),
    )
    validation_split_ratio: float = field(
        default=0.2,
        metadata=pai_meta(
            help="Use factor of n of the training dataset for validation."),
    )

    val: Optional[CalamariDataGeneratorParams] = field(
        default=None, metadata=pai_meta(mode="ignore"))

    def __post_init__(self):
        if self.val is not None:
            # Already initialized
            return

        if not 0 < self.validation_split_ratio < 1:
            raise ValueError("validation_split_ratio must be in (0, 1)")

        # resolve all files so we can split them
        self.train.prepare_for_mode(PipelineMode.TRAINING)
        self.val = deepcopy(self.train)
        samples = len(self.train)
        n = int(self.validation_split_ratio * samples)
        if n == 0:
            raise ValueError(
                f"Ratio is to small since {self.validation_split_ratio} * {samples} = {n}. "
                f"Increase the amount of data or the split ratio.")
        logger.info(
            f"Splitting training and validation files with ratio {self.validation_split_ratio}: "
            f"{n}/{samples - n} for validation/training.")
        indices = list(range(samples))
        shuffle(indices)

        # split train and val img/gt files. Use train settings
        self.train.select(indices[n:])
        self.val.select(indices[:n])
Example #29
class PrepareSampleProcessorParams(DataProcessorParams):
    @staticmethod
    def cls() -> Type["MappingDataProcessor"]:
        return PrepareSample

    max_line_width: int = field(
        default=4096, metadata=pai_meta(help="Max width of a line. Set to -1 or 0 to skip this check.")
    )
Example #30
class TextRegularizerProcessorParams(DataProcessorParams):
    replacements: Optional[List[Replacement]] = field(default=None, metadata=pai_meta(mode="ignore"))
    rulesets: List[str] = field(default_factory=lambda: ["spaces"])
    rulegroups: List[str] = field(default_factory=list)

    @staticmethod
    def cls() -> Type["TextProcessor"]:
        return TextRegularizerProcessor