Example #1
    def encode(self, codec: Optional[PytorchCodec] = None) -> None:
        """
        Adds a codec to the dataset and encodes all text lines.

        Has to be run before sampling from the dataset.
        """
        if codec:
            self.codec = codec
        else:
            self.codec = PytorchCodec(''.join(self.alphabet.keys()))
        self.training_set = []  # type: List[Tuple[Union[Image, torch.Tensor], torch.Tensor]]
        for im, gt in zip(self._images, self._gt):
            self.training_set.append((im, self.codec.encode(gt)))
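A minimal usage sketch for this method, assuming kraken's GroundTruthDataset (listed in full in a later example) and line images with accompanying .gt.txt transcripts; paths and codec contents are illustrative:

from kraken.lib.codec import PytorchCodec
from kraken.lib.dataset import GroundTruthDataset

ds = GroundTruthDataset()
ds.add('lines/0001.png')         # transcript is read from lines/0001.gt.txt
ds.encode()                      # derive a codec from the dataset alphabet
# alternatively, impose an explicit codec:
ds.encode(PytorchCodec('abc '))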
Example #2
    def load_model(cls, path: str):
        """
        Deserializes a VGSL model from a CoreML file.

        Args:
            path (str): CoreML file

        Returns:
            A TorchVGSLModel instance.

        Raises:
            KrakenInvalidModelException if the model data is invalid (not a
            string, protobuf file, or without appropriate metadata).
            FileNotFoundError if the path doesn't point to a file.
        """
        try:
            mlmodel = MLModel(path)
        except TypeError as e:
            raise KrakenInvalidModelException(str(e))
        except DecodeError as e:
            raise KrakenInvalidModelException('Failure parsing model protobuf: {}'.format(str(e)))
        if 'vgsl' not in mlmodel.user_defined_metadata:
            raise KrakenInvalidModelException('No VGSL spec in model metadata')
        vgsl_spec = mlmodel.user_defined_metadata['vgsl']
        nn = cls(vgsl_spec)
        for name, layer in nn.nn.named_children():
            layer.deserialize(name, mlmodel.get_spec())

        if 'codec' in mlmodel.user_defined_metadata:
            nn.add_codec(PytorchCodec(json.loads(mlmodel.user_defined_metadata['codec'])))

        nn.user_metadata = {'accuracy': [], 'seg_type': 'bbox', 'one_channel_mode': '1', 'model_type': None, 'hyper_params': {}}  # type: dict[str, str]
        if 'kraken_meta' in mlmodel.user_defined_metadata:
            nn.user_metadata.update(json.loads(mlmodel.user_defined_metadata['kraken_meta']))
        return nn
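A brief usage sketch, assuming a kraken recognition model in CoreML format on disk (the filename is illustrative):

from kraken.lib import vgsl

nn = vgsl.TorchVGSLModel.load_model('en_best.mlmodel')
print(nn.spec)                        # VGSL spec recovered from the metadata
print(nn.user_metadata['seg_type'])   # 'bbox' unless kraken_meta overrides it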
Example #3
    def load_pronn_model(cls, path: str):
        """
        Loads a pronn model into VGSL.
        """
        with open(path, 'rb') as fp:
            net = pyrnn_pb2.pyrnn()
            try:
                net.ParseFromString(fp.read())
            except Exception:
                raise KrakenInvalidModelException('File does not contain valid proto msg')
            if not net.IsInitialized():
                raise KrakenInvalidModelException('Model incomplete')

        # extract codec
        codec = PytorchCodec(net.codec)

        input = net.ninput
        hidden = net.fwdnet.wgi.dim[0]

        # extract weights
        weightnames = ('wgi', 'wgf', 'wci', 'wgo', 'wip', 'wfp', 'wop')

        fwd_w = []
        rev_w = []
        for w in weightnames:
            fwd_ar = getattr(net.fwdnet, w)
            rev_ar = getattr(net.revnet, w)
            fwd_w.append(torch.Tensor(fwd_ar.value).view(list(fwd_ar.dim)))
            rev_w.append(torch.Tensor(rev_ar.value).view(list(rev_ar.dim)))

        t = torch.cat(fwd_w[:4])
        weight_ih_l0 = t[:, :input+1]
        weight_hh_l0 = t[:, input+1:]

        t = torch.cat(rev_w[:4])
        weight_ih_l0_rev = t[:, :input+1]
        weight_hh_l0_rev = t[:, input+1:]

        weight_lin = torch.Tensor(net.softmax.w2.value).view(list(net.softmax.w2.dim))

        # build vgsl spec and set weights
        nn = cls('[1,1,0,{} Lbxo{} O1ca{}]'.format(input, hidden, len(net.codec)))

        nn.nn.L_0.layer.weight_ih_l0 = torch.nn.Parameter(weight_ih_l0)
        nn.nn.L_0.layer.weight_hh_l0 = torch.nn.Parameter(weight_hh_l0)
        nn.nn.L_0.layer.weight_ih_l0_reverse = torch.nn.Parameter(weight_ih_l0_rev)
        nn.nn.L_0.layer.weight_hh_l0_reverse = torch.nn.Parameter(weight_hh_l0_rev)
        nn.nn.L_0.layer.weight_ip_l0 = torch.nn.Parameter(fwd_w[4])
        nn.nn.L_0.layer.weight_fp_l0 = torch.nn.Parameter(fwd_w[5])
        nn.nn.L_0.layer.weight_op_l0 = torch.nn.Parameter(fwd_w[6])
        nn.nn.L_0.layer.weight_ip_l0_reverse = torch.nn.Parameter(rev_w[4])
        nn.nn.L_0.layer.weight_fp_l0_reverse = torch.nn.Parameter(rev_w[5])
        nn.nn.L_0.layer.weight_op_l0_reverse = torch.nn.Parameter(rev_w[6])

        nn.nn.O_1.lin.weight = torch.nn.Parameter(weight_lin)

        nn.add_codec(codec)

        return nn
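The slicing above recovers PyTorch's packed LSTM weight layout from the legacy format, which stacks the four gate matrices with a 1-augmented input dimension. A standalone sketch of that split, with illustrative shapes:

import torch

input_size, hidden = 48, 100
# each legacy gate matrix is (hidden, input+1+hidden); the extra column
# comes from the 1-augmentation of the input
gates = [torch.randn(hidden, input_size + 1 + hidden) for _ in range(4)]
t = torch.cat(gates)                  # (4*hidden, input+1+hidden)
weight_ih = t[:, :input_size + 1]     # PyTorch's weight_ih_l0
weight_hh = t[:, input_size + 1:]     # PyTorch's weight_hh_l0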
Example #4
    def load_model(cls, path: Union[str, pathlib.Path]):
        """
        Deserializes a VGSL model from a CoreML file.

        Args:
            path: CoreML file

        Returns:
            A TorchVGSLModel instance.

        Raises:
            KrakenInvalidModelException if the model data is invalid (not a
            string, protobuf file, or without appropriate metadata).
            FileNotFoundError if the path doesn't point to a file.
        """
        if isinstance(path, pathlib.Path):
            path = path.as_posix()
        try:
            mlmodel = MLModel(path)
        except TypeError as e:
            raise KrakenInvalidModelException(str(e))
        except DecodeError as e:
            raise KrakenInvalidModelException(
                'Failure parsing model protobuf: {}'.format(str(e)))
        if 'vgsl' not in mlmodel.user_defined_metadata:
            raise KrakenInvalidModelException('No VGSL spec in model metadata')
        vgsl_spec = mlmodel.user_defined_metadata['vgsl']
        nn = cls(vgsl_spec)

        def _deserialize_layers(name, layer):
            logger.debug(f'Deserializing layer {name} with type {type(layer)}')
            if type(layer) in (layers.MultiParamParallel,
                               layers.MultiParamSequential):
                for name, l in layer.named_children():
                    _deserialize_layers(name, l)
            else:
                layer.deserialize(name, mlmodel.get_spec())

        _deserialize_layers('', nn.nn)

        if 'codec' in mlmodel.user_defined_metadata:
            nn.add_codec(
                PytorchCodec(json.loads(
                    mlmodel.user_defined_metadata['codec'])))

        nn.user_metadata = {
            'accuracy': [],
            'seg_type': 'bbox',
            'one_channel_mode': '1',
            'model_type': None,
            'hyper_params': {}
        }  # type: dict[str, str]
        if 'kraken_meta' in mlmodel.user_defined_metadata:
            nn.user_metadata.update(
                json.loads(mlmodel.user_defined_metadata['kraken_meta']))
        return nn
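Unlike the earlier variant, this one also accepts pathlib.Path objects and descends into nested container layers during deserialization. A usage sketch (the path is illustrative):

import pathlib
from kraken.lib import vgsl

nn = vgsl.TorchVGSLModel.load_model(pathlib.Path('models') / 'en_best.mlmodel')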
Example #5
    def load_model(cls, path: str):
        """
        Deserializes a VGSL model from a CoreML file.

        Args:
            path (str): CoreML file
        """
        mlmodel = MLModel(path)
        if 'vgsl' not in mlmodel.user_defined_metadata:
            raise ValueError('No VGSL spec in model metadata')
        vgsl_spec = mlmodel.user_defined_metadata['vgsl']
        nn = cls(vgsl_spec)
        for name, layer in nn.nn.named_children():
            layer.deserialize(name, mlmodel.get_spec())

        if 'codec' in mlmodel.user_defined_metadata:
            nn.add_codec(
                PytorchCodec(json.loads(
                    mlmodel.user_defined_metadata['codec'])))
        return nn
Example #6
    def load_clstm_model(cls, path: str):
        """
        Loads a CLSTM model into VGSL.
        """
        net = clstm_pb2.NetworkProto()
        with open(path, 'rb') as fp:
            try:
                net.ParseFromString(fp.read())
            except Exception:
                raise KrakenInvalidModelException('File does not contain valid proto msg')
            if not net.IsInitialized():
                raise KrakenInvalidModelException('Model incomplete')

        input = net.ninput
        attrib = {a.key: a.value for a in list(net.attribute)}
        # mainline clstm model
        if len(attrib) > 1:
            mode = 'clstm'
        else:
            mode = 'clstm_compat'

        # extract codec
        codec = PytorchCodec([''] + [chr(x) for x in net.codec[1:]])

        # separate layers
        nets = {}
        nets['softm'] = [n for n in list(net.sub) if n.kind == 'SoftmaxLayer'][0]
        parallel = [n for n in list(net.sub) if n.kind == 'Parallel'][0]
        nets['lstm1'] = [n for n in list(parallel.sub) if n.kind.startswith('NPLSTM')][0]
        rev = [n for n in list(parallel.sub) if n.kind == 'Reversed'][0]
        nets['lstm2'] = rev.sub[0]

        hidden = int(nets['lstm1'].attribute[0].value)

        weights = {}  # type: Dict[str, torch.Tensor]
        for n in nets:
            weights[n] = {}
            for w in list(nets[n].weights):
                weights[n][w.name] = torch.Tensor(w.value).view(list(w.dim))

        if mode == 'clstm_compat':
            weightnames = ('.WGI', '.WGF', '.WCI', '.WGO')
            weightname_softm = '.W'
        else:
            weightnames = ('WGI', 'WGF', 'WCI', 'WGO')
            weightname_softm = 'W1'

        # input hidden and hidden-hidden weights are in one matrix. also
        # CLSTM/ocropy likes 1-augmenting every other tensor so the ih weights
        # are input+1 in one dimension.
        t = torch.cat([weights['lstm1'][wn] for wn in weightnames])
        weight_ih_l0 = t[:, :input+1]
        weight_hh_l0 = t[:, input+1:]

        t = torch.cat([weights['lstm2'][wn] for wn in weightnames])
        weight_ih_l0_rev = t[:, :input+1]
        weight_hh_l0_rev = t[:, input+1:]

        weight_lin = weights['softm'][weightname_softm]
        if mode == 'clstm_compat':
            weight_lin = torch.cat([torch.zeros(len(weight_lin), 1), weight_lin], 1)

        # build vgsl spec and set weights
        nn = cls('[1,1,0,{} Lbxc{} O1ca{}]'.format(input, hidden, len(net.codec)))
        nn.nn.L_0.layer.weight_ih_l0 = torch.nn.Parameter(weight_ih_l0)
        nn.nn.L_0.layer.weight_hh_l0 = torch.nn.Parameter(weight_hh_l0)
        nn.nn.L_0.layer.weight_ih_l0_reverse = torch.nn.Parameter(weight_ih_l0_rev)
        nn.nn.L_0.layer.weight_hh_l0_reverse = torch.nn.Parameter(weight_hh_l0_rev)
        nn.nn.O_1.lin.weight = torch.nn.Parameter(weight_lin)

        nn.add_codec(codec)

        return nn
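The clstm_compat branch pads the softmax weights with a leading zero column before assigning them to the linear output layer. Mechanically that is a single concatenation; a standalone sketch with illustrative shapes:

import torch

w = torch.randn(5, 10)                                # (classes, features)
w_padded = torch.cat([torch.zeros(len(w), 1), w], 1)  # prepend a zero column
assert w_padded.shape == (5, 11)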
Example #7
    def load_pyrnn_model(cls, path: str):
        """
        Loads a pyrnn model into VGSL.
        """
        if not PY2:
            raise KrakenInvalidModelException('Loading pickle models is not supported on python 3')

        import cPickle

        def find_global(mname, cname):
            aliases = {
                'lstm.lstm': kraken.lib.lstm,
                'ocrolib.lstm': kraken.lib.lstm,
                'ocrolib.lineest': kraken.lib.lineest,
            }
            if mname in aliases:
                return getattr(aliases[mname], cname)
            return getattr(sys.modules[mname], cname)

        of = io.open
        if path.endswith('.gz'):
            of = gzip.open
        with io.BufferedReader(of(path, 'rb')) as fp:
            unpickler = cPickle.Unpickler(fp)
            unpickler.find_global = find_global
            try:
                net = unpickler.load()
            except Exception as e:
                raise KrakenInvalidModelException(str(e))
            if not isinstance(net, kraken.lib.lstm.SeqRecognizer):
                raise KrakenInvalidModelException('Pickle is %s instead of '
                                                  'SeqRecognizer' %
                                                  type(net).__name__)
        # extract codec
        codec = PytorchCodec({k: [v] for k, v in net.codec.char2code.items()})

        input = net.Ni
        parallel, softmax = net.lstm.nets
        fwdnet, revnet = parallel.nets
        revnet = revnet.net

        hidden = fwdnet.WGI.shape[0]

        # extract weights
        weightnames = ('WGI', 'WGF', 'WCI', 'WGO', 'WIP', 'WFP', 'WOP')

        fwd_w = []
        rev_w = []
        for w in weightnames:
            fwd_w.append(torch.Tensor(getattr(fwdnet, w)))
            rev_w.append(torch.Tensor(getattr(revnet, w)))

        t = torch.cat(fwd_w[:4])
        weight_ih_l0 = t[:, :input+1]
        weight_hh_l0 = t[:, input+1:]

        t = torch.cat(rev_w[:4])
        weight_ih_l0_rev = t[:, :input+1]
        weight_hh_l0_rev = t[:, input+1:]

        weight_lin = torch.Tensor(softmax.W2)

        # build vgsl spec and set weights
        nn = cls('[1,1,0,{} Lbxo{} O1ca{}]'.format(input, hidden, len(net.codec.code2char)))

        nn.nn.L_0.layer.weight_ih_l0 = torch.nn.Parameter(weight_ih_l0)
        nn.nn.L_0.layer.weight_hh_l0 = torch.nn.Parameter(weight_hh_l0)
        nn.nn.L_0.layer.weight_ih_l0_reverse = torch.nn.Parameter(weight_ih_l0_rev)
        nn.nn.L_0.layer.weight_hh_l0_reverse = torch.nn.Parameter(weight_hh_l0_rev)
        nn.nn.L_0.layer.weight_ip_l0 = torch.nn.Parameter(fwd_w[4])
        nn.nn.L_0.layer.weight_fp_l0 = torch.nn.Parameter(fwd_w[5])
        nn.nn.L_0.layer.weight_op_l0 = torch.nn.Parameter(fwd_w[6])
        nn.nn.L_0.layer.weight_ip_l0_reverse = torch.nn.Parameter(rev_w[4])
        nn.nn.L_0.layer.weight_fp_l0_reverse = torch.nn.Parameter(rev_w[5])
        nn.nn.L_0.layer.weight_op_l0_reverse = torch.nn.Parameter(rev_w[6])

        nn.nn.O_1.lin.weight = torch.nn.Parameter(weight_lin)

        nn.add_codec(codec)

        return nn
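The find_global hook above remaps legacy ocropy module paths during unpickling. On Python 3 the equivalent hook is pickle.Unpickler.find_class; a minimal sketch of the same aliasing idea (module names are illustrative):

import pickle

class AliasingUnpickler(pickle.Unpickler):
    # map legacy module paths onto their current locations
    ALIASES = {'ocrolib.lstm': 'kraken.lib.lstm',
               'ocrolib.lineest': 'kraken.lib.lineest'}

    def find_class(self, module, name):
        module = self.ALIASES.get(module, module)
        return super().find_class(module, name)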
Example #8
class GroundTruthDataset(Dataset):
    """
    Dataset for training a line recognition model.

    All data is cached in memory.
    """
    def __init__(self,
                 split: Callable[[str], str] = lambda x: path.splitext(x)[0],
                 suffix: str = '.gt.txt',
                 normalization: Optional[str] = None,
                 whitespace_normalization: bool = True,
                 reorder: bool = True,
                 im_transforms: Callable[[Any],
                                         torch.Tensor] = transforms.Compose(
                                             []),
                 preload: bool = True,
                 augmentation: bool = False) -> None:
        """
        Reads a list of image-text pairs and creates a ground truth set.

        Args:
            split (func): Function for generating the base name without
                          extensions from paths
            suffix (str): Suffix to attach to image base name for text
                          retrieval
            normalization (str): Unicode normalization for gt
            whitespace_normalization (bool): Normalizes unicode whitespace and
                                             strips whitespace.
            reorder (bool): Whether to rearrange code points in "display"/LTR
                            order
            im_transforms (func): Function taking a PIL.Image and returning a
                                  tensor suitable for forward passes.
            preload (bool): Enables preloading and preprocessing of image files.
            augmentation (bool): Enables image augmentation.
        """
        self.suffix = suffix
        self.split = lambda x: split(x) + self.suffix
        self._images = []  # type:  Union[List[Image], List[torch.Tensor]]
        self._gt = []  # type:  List[str]
        self.alphabet = Counter()  # type: Counter
        self.text_transforms = []  # type: List[Callable[[str], str]]
        # split image transforms into two. one part giving the final PIL image
        # before conversion to a tensor and the actual tensor conversion part.
        self.head_transforms = transforms.Compose(im_transforms.transforms[:2])
        self.tail_transforms = transforms.Compose(im_transforms.transforms[2:])
        self.aug = None

        self.preload = preload
        self.seg_type = 'bbox'
        # build text transformations
        if normalization:
            self.text_transforms.append(
                lambda x: unicodedata.normalize(cast(str, normalization), x))
        if whitespace_normalization:
            self.text_transforms.append(
                lambda x: regex.sub(r'\s', ' ', x).strip())
        if reorder:
            self.text_transforms.append(bd.get_display)
        if augmentation:
            from albumentations import (
                Compose,
                ToFloat,
                FromFloat,
                Flip,
                OneOf,
                MotionBlur,
                MedianBlur,
                Blur,
                ShiftScaleRotate,
                OpticalDistortion,
                ElasticTransform,
                RandomBrightnessContrast,
            )

            self.aug = Compose([
                ToFloat(),
                OneOf([
                    MotionBlur(p=0.2),
                    MedianBlur(blur_limit=3, p=0.1),
                    Blur(blur_limit=3, p=0.1),
                ],
                      p=0.2),
                ShiftScaleRotate(shift_limit=0.0625,
                                 scale_limit=0.2,
                                 rotate_limit=45,
                                 p=0.2),
                OneOf([
                    OpticalDistortion(p=0.3),
                    ElasticTransform(p=0.1),
                ],
                      p=0.2),
            ],
                               p=0.5)

        self.im_mode = '1'

    def add(self, image: Union[str, Image.Image], *args, **kwargs) -> None:
        """
        Adds a line-image-text pair to the dataset.

        Args:
            image (str): Input image path
        """
        with open(self.split(image), 'r', encoding='utf-8') as fp:
            gt = fp.read().strip('\n\r')
            for func in self.text_transforms:
                gt = func(gt)
            if not gt:
                raise KrakenInputException(f'Text line is empty ({fp.name})')
        if self.preload:
            try:
                im = Image.open(image)
                im = self.head_transforms(im)
                if not is_bitonal(im):
                    self.im_mode = im.mode
                im = self.tail_transforms(im)
            except ValueError:
                raise KrakenInputException(
                    f'Image transforms failed on {image}')
            self._images.append(im)
        else:
            self._images.append(image)
        self._gt.append(gt)
        self.alphabet.update(gt)

    def add_loaded(self, image: Image.Image, gt: str) -> None:
        """
        Adds an already loaded line-image-text pair to the dataset.

        Args:
            image (PIL.Image.Image): Line image
            gt (str): Text contained in the line image
        """
        if self.preload:
            try:
                im = self.head_transforms(image)
                if not is_bitonal(im):
                    self.im_mode = im.mode
                im = self.tail_transforms(im)
            except ValueError:
                raise KrakenInputException(
                    f'Image transforms failed on {image}')
            self._images.append(im)
        else:
            self._images.append(image)
        for func in self.text_transforms:
            gt = func(gt)
        self._gt.append(gt)
        self.alphabet.update(gt)

    def encode(self, codec: Optional[PytorchCodec] = None) -> None:
        """
        Adds a codec to the dataset and encodes all text lines.

        Has to be run before sampling from the dataset.
        """
        if codec:
            self.codec = codec
        else:
            self.codec = PytorchCodec(''.join(self.alphabet.keys()))
        self.training_set = [
        ]  # type: List[Tuple[Union[Image, torch.Tensor], torch.Tensor]]
        for im, gt in zip(self._images, self._gt):
            self.training_set.append((im, self.codec.encode(gt)))

    def no_encode(self) -> None:
        """
        Creates an unencoded dataset.
        """
        self.training_set = [
        ]  # type: List[Tuple[Union[Image, torch.Tensor], str]]
        for im, gt in zip(self._images, self._gt):
            self.training_set.append((im, gt))

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        if self.preload:
            x, y = self.training_set[index]
            if self.aug:
                im = x.permute((1, 2, 0)).numpy()
                o = self.aug(image=im)
                im = torch.tensor(o['image'].transpose(2, 0, 1))
                return {'image': im, 'target': y}
            return {'image': x, 'target': y}
        else:
            item = self.training_set[index]
            try:
                logger.debug(f'Attempting to load {item[0]}')
                im = item[0]
                if not isinstance(im, Image.Image):
                    im = Image.open(im)
                im = self.head_transforms(im)
                if not is_bitonal(im):
                    self.im_mode = im.mode
                im = self.tail_transforms(im)
                if self.aug:
                    im = im.permute((1, 2, 0)).numpy()
                    o = self.aug(image=im)
                    im = torch.tensor(o['image'].transpose(2, 0, 1))
                return {'image': im, 'target': item[1]}
            except Exception:
                idx = np.random.randint(0, len(self.training_set))
                logger.debug(traceback.format_exc())
                logger.info(f'Failed. Replacing with sample {idx}')
                return self[idx]

    def __len__(self) -> int:
        return len(self.training_set)
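A sketch of driving this dataset with a PyTorch DataLoader, assuming line images with accompanying .gt.txt transcripts (paths are illustrative):

from torch.utils.data import DataLoader
from kraken.lib.dataset import GroundTruthDataset

ds = GroundTruthDataset(preload=True)
for im in ('lines/0001.png', 'lines/0002.png'):
    ds.add(im)
ds.encode()
loader = DataLoader(ds, batch_size=1, shuffle=True)
batch = next(iter(loader))   # {'image': ..., 'target': ...}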
Example #9
class PolygonGTDataset(Dataset):
    """
    Dataset for training a line recognition model from polygonal/baseline data.
    """
    def __init__(self,
                 normalization: Optional[str] = None,
                 whitespace_normalization: bool = True,
                 reorder: bool = True,
                 im_transforms: Callable[[Any],
                                         torch.Tensor] = transforms.Compose(
                                             []),
                 preload: bool = True,
                 augmentation: bool = False) -> None:
        self._images = []  # type:  Union[List[Image], List[torch.Tensor]]
        self._gt = []  # type:  List[str]
        self.alphabet = Counter()  # type: Counter
        self.text_transforms = []  # type: List[Callable[[str], str]]
        # split image transforms into two. one part giving the final PIL image
        # before conversion to a tensor and the actual tensor conversion part.
        self.head_transforms = transforms.Compose(im_transforms.transforms[:2])
        self.tail_transforms = transforms.Compose(im_transforms.transforms[2:])
        self.transforms = im_transforms
        self.preload = preload
        self.aug = None

        self.seg_type = 'baselines'
        # build text transformations
        if normalization:
            self.text_transforms.append(
                lambda x: unicodedata.normalize(cast(str, normalization), x))
        if whitespace_normalization:
            self.text_transforms.append(
                lambda x: regex.sub(r'\s', ' ', x).strip())
        if reorder:
            self.text_transforms.append(bd.get_display)
        if augmentation:
            from albumentations import (
                Compose,
                ToFloat,
                FromFloat,
                Flip,
                OneOf,
                MotionBlur,
                MedianBlur,
                Blur,
                ShiftScaleRotate,
                OpticalDistortion,
                ElasticTransform,
                RandomBrightnessContrast,
            )

            self.aug = Compose([
                ToFloat(),
                OneOf([
                    MotionBlur(p=0.2),
                    MedianBlur(blur_limit=3, p=0.1),
                    Blur(blur_limit=3, p=0.1),
                ],
                      p=0.2),
                ShiftScaleRotate(
                    shift_limit=0.0625, scale_limit=0.2, rotate_limit=3,
                    p=0.2),
                OneOf([
                    OpticalDistortion(p=0.3),
                    ElasticTransform(p=0.1),
                ],
                      p=0.2),
            ],
                               p=0.5)

        self.im_mode = '1'

    def add(self, image: Union[str, Image.Image], text: str,
            baseline: List[Tuple[int, int]], boundary: List[Tuple[int, int]],
            *args, **kwargs):
        """
        Adds a line to the dataset.

        Args:
            image (str): Path to the whole page image
            text (str): Transcription of the line.
            baseline (list): A list of coordinates [[x0, y0], ..., [xn, yn]].
            boundary (list): A polygon mask for the line.
        """
        for func in self.text_transforms:
            text = func(text)
        if not text:
            raise KrakenInputException(
                'Text line is empty after transformations')
        if not baseline:
            raise KrakenInputException('No baseline given for line')
        if not boundary:
            raise KrakenInputException('No boundary given for line')
        if self.preload:
            if not isinstance(image, Image.Image):
                im = Image.open(image)
            else:
                im = image
            try:
                im, _ = next(
                    extract_polygons(
                        im, {
                            'type': 'baselines',
                            'lines': [{
                                'baseline': baseline,
                                'boundary': boundary
                            }]
                        }))
            except IndexError:
                raise KrakenInputException(
                    'Patch extraction failed for baseline')
            try:
                im = self.head_transforms(im)
                if not is_bitonal(im):
                    self.im_mode = im.mode
                im = self.tail_transforms(im)
            except ValueError:
                raise KrakenInputException(
                    f'Image transforms failed on {image}')
            self._images.append(im)
        else:
            self._images.append((image, baseline, boundary))
        self._gt.append(text)
        self.alphabet.update(text)

    def encode(self, codec: Optional[PytorchCodec] = None) -> None:
        """
        Adds a codec to the dataset and encodes all text lines.

        Has to be run before sampling from the dataset.
        """
        if codec:
            self.codec = codec
        else:
            self.codec = PytorchCodec(''.join(self.alphabet.keys()))
        self.training_set = [
        ]  # type: List[Tuple[Union[Image, torch.Tensor], torch.Tensor]]
        for im, gt in zip(self._images, self._gt):
            self.training_set.append((im, self.codec.encode(gt)))

    def no_encode(self) -> None:
        """
        Creates an unencoded dataset.
        """
        self.training_set = [
        ]  # type: List[Tuple[Union[Image, torch.Tensor], str]]
        for im, gt in zip(self._images, self._gt):
            self.training_set.append((im, gt))

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        if self.preload:
            x, y = self.training_set[index]
            if self.aug:
                x = x.permute((1, 2, 0)).numpy()
                o = self.aug(image=x)
                x = torch.tensor(o['image'].transpose(2, 0, 1))
            return {'image': x, 'target': y}
        else:
            item = self.training_set[index]
            try:
                logger.debug(f'Attempting to load {item[0]}')
                im = item[0][0]
                if not isinstance(im, Image.Image):
                    im = Image.open(im)
                im, _ = next(
                    extract_polygons(
                        im, {
                            'type':
                            'baselines',
                            'lines': [{
                                'baseline': item[0][1],
                                'boundary': item[0][2]
                            }]
                        }))
                im = self.head_transforms(im)
                if not is_bitonal(im):
                    self.im_mode = im.mode
                im = self.tail_transforms(im)
                if self.aug:
                    im = im.permute((1, 2, 0)).numpy()
                    o = self.aug(image=im)
                    im = torch.tensor(o['image'].transpose(2, 0, 1))
                return {'image': im, 'target': item[1]}
            except Exception:
                idx = np.random.randint(0, len(self.training_set))
                logger.debug(traceback.format_exc())
                logger.info(f'Failed. Replacing with sample {idx}')
                return self[idx]

    def __len__(self) -> int:
        return len(self.training_set)
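A usage sketch for the polygonal variant; the image path, baseline, and boundary coordinates are illustrative:

from kraken.lib.dataset import PolygonGTDataset

ds = PolygonGTDataset(preload=False)
ds.add(image='page.png',
       text='a line of text',
       baseline=[(10, 40), (400, 42)],
       boundary=[(10, 20), (400, 20), (400, 60), (10, 60)])
ds.encode()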
Example #10
def train(ctx, pad, output, spec, append, load, freq, quit, epochs, lag,
          min_delta, device, optimizer, lrate, momentum, weight_decay,
          schedule, partition, normalization, normalize_whitespace, codec,
          resize, reorder, training_files, evaluation_files, preload, threads,
          ground_truth):
    """
    Trains a model from image-text pairs.
    """
    if not load and append:
        raise click.BadOptionUsage(
            'append', 'append option requires loading an existing model')

    if resize != 'fail' and not load:
        raise click.BadOptionUsage(
            'resize', 'resize option requires loading an existing model')

    import re
    import torch
    import shutil
    import numpy as np

    from torch.utils.data import DataLoader

    from kraken.lib import models, vgsl, train
    from kraken.lib.util import make_printable
    from kraken.lib.train import EarlyStopping, EpochStopping, TrainStopper, TrainScheduler, add_1cycle
    from kraken.lib.codec import PytorchCodec
    from kraken.lib.dataset import GroundTruthDataset, generate_input_transforms

    logger.info('Building ground truth set from {} line images'.format(
        len(ground_truth) + len(training_files)))

    completed_epochs = 0
    # load model if given. if a new model has to be created we need to do that
    # after data set initialization, otherwise the output size is still unknown.
    nn = None
    #hyper_fields = ['freq', 'quit', 'epochs', 'lag', 'min_delta', 'optimizer', 'lrate', 'momentum', 'weight_decay', 'schedule', 'partition', 'normalization', 'normalize_whitespace', 'reorder', 'preload', 'completed_epochs', 'output']

    if load:
        logger.info('Loading existing model from {} '.format(load))
        message('Loading existing model from {}'.format(load), nl=False)
        nn = vgsl.TorchVGSLModel.load_model(load)
        #if nn.user_metadata and load_hyper_parameters:
        #    for param in hyper_fields:
        #        if param in nn.user_metadata:
        #            logger.info('Setting \'{}\' to \'{}\''.format(param, nn.user_metadata[param]))
        #            message('Setting \'{}\' to \'{}\''.format(param, nn.user_metadata[param]))
        #            locals()[param] = nn.user_metadata[param]
        message('\u2713', fg='green', nl=False)

    # preparse input sizes from vgsl string to seed ground truth data set
    # sizes and dimension ordering.
    if not nn:
        spec = spec.strip()
        if spec[0] != '[' or spec[-1] != ']':
            raise click.BadOptionUsage(
                'spec', 'VGSL spec {} not bracketed'.format(spec))
        blocks = spec[1:-1].split(' ')
        m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
        if not m:
            raise click.BadOptionUsage(
                'spec', 'Invalid input spec {}'.format(blocks[0]))
        batch, height, width, channels = [int(x) for x in m.groups()]
    else:
        batch, channels, height, width = nn.input
    try:
        transforms = generate_input_transforms(batch, height, width, channels,
                                               pad)
    except KrakenInputException as e:
        raise click.BadOptionUsage('spec', str(e))

    # disable automatic partition when given evaluation set explicitly
    if evaluation_files:
        partition = 1
    ground_truth = list(ground_truth)

    # merge training_files into ground_truth list
    if training_files:
        ground_truth.extend(training_files)

    if len(ground_truth) == 0:
        raise click.UsageError(
            'No training data was provided to the train command. Use `-t` or the `ground_truth` argument.'
        )

    np.random.shuffle(ground_truth)

    if len(ground_truth) > 2500 and not preload:
        logger.info(
            'Disabling preloading for large (>2500) training data set. Enable by setting --preload parameter'
        )
        preload = False
    # implicit preloading enabled for small data sets
    if preload is None:
        preload = True

    tr_im = ground_truth[:int(len(ground_truth) * partition)]
    if evaluation_files:
        logger.debug('Using {} lines from explicit eval set'.format(
            len(evaluation_files)))
        te_im = evaluation_files
    else:
        te_im = ground_truth[int(len(ground_truth) * partition):]
        logger.debug('Taking {} lines from training for evaluation'.format(
            len(te_im)))

    # set multiprocessing tensor sharing strategy
    if 'file_system' in torch.multiprocessing.get_all_sharing_strategies():
        logger.debug(
            'Setting multiprocessing tensor sharing strategy to file_system')
        torch.multiprocessing.set_sharing_strategy('file_system')

    gt_set = GroundTruthDataset(normalization=normalization,
                                whitespace_normalization=normalize_whitespace,
                                reorder=reorder,
                                im_transforms=transforms,
                                preload=preload)
    with log.progressbar(tr_im, label='Building training set') as bar:
        for im in bar:
            logger.debug('Adding line {} to training set'.format(im))
            try:
                gt_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    val_set = GroundTruthDataset(normalization=normalization,
                                 whitespace_normalization=normalize_whitespace,
                                 reorder=reorder,
                                 im_transforms=transforms,
                                 preload=preload)
    with log.progressbar(te_im, label='Building validation set') as bar:
        for im in bar:
            logger.debug('Adding line {} to validation set'.format(im))
            try:
                val_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    logger.info(
        'Training set {} lines, validation set {} lines, alphabet {} symbols'.
        format(len(gt_set._images), len(val_set._images),
               len(gt_set.alphabet)))
    alpha_diff = set(gt_set.alphabet).symmetric_difference(
        set(val_set.alphabet))
    if alpha_diff:
        logger.warning('alphabet mismatch {}'.format(alpha_diff))
    logger.info('grapheme\tcount')
    for k, v in sorted(gt_set.alphabet.items(),
                       key=lambda x: x[1],
                       reverse=True):
        char = make_printable(k)
        if char == k:
            char = '\t' + char
        logger.info(u'{}\t{}'.format(char, v))

    logger.debug('Encoding training set')

    # use model codec when given
    if append:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)
        gt_set.encode(codec)
        message('Slicing and dicing model ', nl=False)
        # now we can create a new model
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        logger.info('Appending {} to existing model {} after {}'.format(
            spec, nn.spec, append))
        nn.append(append, spec)
        nn.add_codec(gt_set.codec)
        message('\u2713', fg='green')
        logger.info('Assembled model spec: {}'.format(nn.spec))
    elif load:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)

        # prefer explicitly given codec over network codec if mode is 'both'
        codec = codec if (codec and resize == 'both') else nn.codec

        try:
            gt_set.encode(codec)
        except KrakenEncodeException as e:
            message('Network codec not compatible with training set')
            alpha_diff = set(gt_set.alphabet).difference(set(codec.c2l.keys()))
            if resize == 'fail':
                logger.error(
                    'Training data and model codec alphabets mismatch: {}'.
                    format(alpha_diff))
                ctx.exit(code=1)
            elif resize == 'add':
                message('Adding missing labels to network ', nl=False)
                logger.info(
                    'Resizing codec to include {} new code points'.format(
                        len(alpha_diff)))
                codec.c2l.update({
                    k: [v]
                    for v, k in enumerate(alpha_diff,
                                          start=codec.max_label() + 1)
                })
                nn.add_codec(PytorchCodec(codec.c2l))
                logger.info(
                    'Resizing last layer in network to {} outputs'.format(
                        codec.max_label() + 1))
                nn.resize_output(codec.max_label() + 1)
                gt_set.encode(nn.codec)
                message('\u2713', fg='green')
            elif resize == 'both':
                message('Fitting network exactly to training set ', nl=False)
                logger.info(
                    'Resizing network or given codec to {} code sequences'.
                    format(len(gt_set.alphabet)))
                gt_set.encode(None)
                ncodec, del_labels = codec.merge(gt_set.codec)
                logger.info(
                    'Deleting {} output classes from network ({} retained)'.
                    format(len(del_labels),
                           len(codec) - len(del_labels)))
                gt_set.encode(ncodec)
                nn.resize_output(ncodec.max_label() + 1, del_labels)
                message('\u2713', fg='green')
            else:
                raise click.BadOptionUsage(
                    'resize', 'Invalid resize value {}'.format(resize))
    else:
        gt_set.encode(codec)
        logger.info('Creating new model {} with {} outputs'.format(
            spec,
            gt_set.codec.max_label() + 1))
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        nn = vgsl.TorchVGSLModel(spec)
        # initialize weights
        message('Initializing model ', nl=False)
        nn.init_weights()
        nn.add_codec(gt_set.codec)
        # initialize codec
        message('\u2713', fg='green')

    # halve the number of data loading processes if the device isn't cuda and preloading isn't enabled
    if device == 'cpu' and not preload:
        loader_threads = threads // 2
    else:
        loader_threads = threads
    train_loader = DataLoader(gt_set,
                              batch_size=1,
                              shuffle=True,
                              num_workers=loader_threads,
                              pin_memory=True)
    threads -= loader_threads

    # don't encode validation set as the alphabets may not match causing encoding failures
    val_set.training_set = list(zip(val_set._images, val_set._gt))

    logger.debug('Constructing {} optimizer (lr: {}, momentum: {})'.format(
        optimizer, lrate, momentum))

    # set mode to training
    nn.train()

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    nn.set_num_threads(threads)

    logger.debug('Moving model to device {}'.format(device))
    optim = getattr(torch.optim, optimizer)(nn.nn.parameters(), lr=0)

    if 'accuracy' not in nn.user_metadata:
        nn.user_metadata['accuracy'] = []

    tr_it = TrainScheduler(optim)
    if schedule == '1cycle':
        add_1cycle(tr_it, int(len(gt_set) * epochs), lrate, momentum,
                   momentum - 0.10, weight_decay)
    else:
        # constant learning rate scheduler
        tr_it.add_phase(1, (lrate, lrate), (momentum, momentum), weight_decay,
                        train.annealing_const)

    if quit == 'early':
        st_it = EarlyStopping(min_delta, lag)
    elif quit == 'dumb':
        st_it = EpochStopping(epochs - completed_epochs)
    else:
        raise click.BadOptionUsage(
            'quit', 'Invalid training interruption scheme {}'.format(quit))

    #for param in hyper_fields:
    #    logger.debug('Setting \'{}\' to \'{}\' in model metadata'.format(param, locals()[param]))
    #    nn.user_metadata[param] = locals()[param]

    trainer = train.KrakenTrainer(model=nn,
                                  optimizer=optim,
                                  device=device,
                                  filename_prefix=output,
                                  event_frequency=freq,
                                  train_set=train_loader,
                                  val_set=val_set,
                                  stopper=st_it)

    trainer.add_lr_scheduler(tr_it)

    with log.progressbar(label='stage {}/{}'.format(
            1, trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞'),
                         length=trainer.event_it,
                         show_pos=True) as bar:

        def _draw_progressbar():
            bar.update(1)

        def _print_eval(epoch, accuracy, chars, error):
            message('Accuracy report ({}) {:0.4f} {} {}'.format(
                epoch, accuracy, chars, error))
            # reset progress bar
            bar.label = 'stage {}/{}'.format(
                epoch + 1,
                trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞')
            bar.pos = 0
            bar.finished = False

        trainer.run(_print_eval, _draw_progressbar)

    if quit == 'early':
        message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.
                format(output, trainer.stopper.best_epoch,
                       trainer.stopper.best_loss))
        logger.info(
            'Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.
            format(output, trainer.stopper.best_epoch,
                   trainer.stopper.best_loss))
        shutil.copy('{}_{}.mlmodel'.format(output, trainer.stopper.best_epoch),
                    '{}_best.mlmodel'.format(output))
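The input-block parsing near the top of train() boils down to a regex over the first VGSL token; a standalone sketch (the spec string is illustrative):

import re

spec = '[1,48,0,1 Cr3,3,32 Mp2,2 Lbx100 O1c57]'
blocks = spec.strip()[1:-1].split(' ')
m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
batch, height, width, channels = (int(x) for x in m.groups())
# -> 1, 48, 0, 1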
Example #11
    def recognition_train_gen(
            cls,
            hyper_params: Dict = default_specs.RECOGNITION_HYPER_PARAMS,
            progress_callback: Callable[[str, int], Callable[
                [], None]] = lambda string, length: lambda: None,
            message: Callable[[str], None] = lambda *args, **kwargs: None,
            output: str = 'model',
            spec: str = default_specs.RECOGNITION_SPEC,
            append: Optional[int] = None,
            load: Optional[str] = None,
            device: str = 'cpu',
            reorder: bool = True,
            training_data: Optional[Sequence[Dict]] = None,
            evaluation_data: Optional[Sequence[Dict]] = None,
            preload: Optional[bool] = None,
            threads: int = 1,
            load_hyper_parameters: bool = False,
            repolygonize: bool = False,
            force_binarization: bool = False,
            format_type: str = 'path',
            codec: Optional[Dict] = None,
            resize: str = 'fail',
            augment: bool = False):
        """
        This is an ugly constructor that takes all the arguments from the command
        line driver, finagles the datasets, models, and hyperparameters correctly
        and returns a KrakenTrainer object.

        Setup parameters (load, training_data, evaluation_data, ...) are named;
        model hyperparameters (everything in
        kraken.lib.default_specs.RECOGNITION_HYPER_PARAMS) are in the
        `hyper_params` argument.

        Args:
            hyper_params (dict): Hyperparameter dictionary containing all fields
                                 from
                                 kraken.lib.default_specs.RECOGNITION_HYPER_PARAMS
            progress_callback (Callable): Callback for progress reports on various
                                          computationally expensive processes. A
                                          human readable string and the process
                                          length is supplied. The callback has to
                                          return another function which will be
                                          executed after each step.
            message (Callable): Messaging printing method for above log but below
                                warning level output, i.e. infos that should
                                generally be shown to users.
            **kwargs: Setup parameters, i.e. CLI parameters of the train() command.

        Returns:
            A KrakenTrainer object.
        """
        # load model if given. if a new model has to be created we need to do that
        # after data set initialization, otherwise the output size is still unknown.
        nn = None

        if load:
            logger.info(f'Loading existing model from {load} ')
            message(f'Loading existing model from {load} ', nl=False)
            nn = vgsl.TorchVGSLModel.load_model(load)
            if load_hyper_parameters:
                hyper_params.update(nn.hyper_params)
                nn.hyper_params = hyper_params
            message('\u2713', fg='green', nl=False)

        DatasetClass = GroundTruthDataset
        valid_norm = True
        if format_type and format_type != 'path':
            logger.info(
                f'Parsing {len(training_data)} XML files for training data')
            if repolygonize:
                message('Repolygonizing data')
            training_data = preparse_xml_data(training_data, format_type,
                                              repolygonize)
            evaluation_data = preparse_xml_data(evaluation_data, format_type,
                                                repolygonize)
            DatasetClass = PolygonGTDataset
            valid_norm = False
        elif format_type == 'path':
            if force_binarization:
                logger.warning(
                    'Forced binarization enabled in `path` mode. Will be ignored.'
                )
                force_binarization = False
            if repolygonize:
                logger.warning(
                    'Repolygonization enabled in `path` mode. Will be ignored.'
                )
            training_data = [{'image': im} for im in training_data]
            if evaluation_data:
                evaluation_data = [{'image': im} for im in evaluation_data]
            valid_norm = True
        # format_type is None. Determine training type from length of training data entry
        else:
            if len(training_data[0]) >= 4:
                DatasetClass = PolygonGTDataset
                valid_norm = False
            else:
                if force_binarization:
                    logger.warning(
                        'Forced binarization enabled with box lines. Will be ignored.'
                    )
                    force_binarization = False
                if repolygonize:
                    logger.warning(
                        'Repolygonization enabled with box lines. Will be ignored.'
                    )

        # preparse input sizes from vgsl string to seed ground truth data set
        # sizes and dimension ordering.
        if not nn:
            spec = spec.strip()
            if spec[0] != '[' or spec[-1] != ']':
                raise click.BadOptionUsage(
                    'spec', 'VGSL spec {} not bracketed'.format(spec))
            blocks = spec[1:-1].split(' ')
            m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
            if not m:
                raise click.BadOptionUsage('spec',
                                           f'Invalid input spec {blocks[0]}')
            batch, height, width, channels = [int(x) for x in m.groups()]
        else:
            batch, channels, height, width = nn.input
        try:
            transforms = generate_input_transforms(batch, height, width,
                                                   channels,
                                                   hyper_params['pad'],
                                                   valid_norm,
                                                   force_binarization)
        except KrakenInputException as e:
            raise click.BadOptionUsage('spec', str(e))

        if len(training_data) > 2500 and not preload:
            logger.info(
                'Disabling preloading for large (>2500) training data set. Enable by setting --preload parameter'
            )
            preload = False
        # implicit preloading enabled for small data sets
        if preload is None:
            preload = True

        # set multiprocessing tensor sharing strategy
        if 'file_system' in torch.multiprocessing.get_all_sharing_strategies():
            logger.debug(
                'Setting multiprocessing tensor sharing strategy to file_system'
            )
            torch.multiprocessing.set_sharing_strategy('file_system')

        gt_set = DatasetClass(
            normalization=hyper_params['normalization'],
            whitespace_normalization=hyper_params['normalize_whitespace'],
            reorder=reorder,
            im_transforms=transforms,
            preload=preload,
            augmentation=hyper_params['augment'])
        bar = progress_callback('Building training set', len(training_data))
        for im in training_data:
            logger.debug(f'Adding line {im} to training set')
            try:
                gt_set.add(**im)
                bar()
            except FileNotFoundError as e:
                logger.warning(f'{e.strerror}: {e.filename}. Skipping.')
            except KrakenInputException as e:
                logger.warning(str(e))

        val_set = DatasetClass(
            normalization=hyper_params['normalization'],
            whitespace_normalization=hyper_params['normalize_whitespace'],
            reorder=reorder,
            im_transforms=transforms,
            preload=preload)
        bar = progress_callback('Building validation set',
                                len(evaluation_data))
        for im in evaluation_data:
            logger.debug(f'Adding line {im} to validation set')
            try:
                val_set.add(**im)
                bar()
            except FileNotFoundError as e:
                logger.warning(f'{e.strerror}: {e.filename}. Skipping.')
            except KrakenInputException as e:
                logger.warning(str(e))

        if len(gt_set._images) == 0:
            logger.error(
                'No valid training data was provided to the train command. Please add valid XML or line data.'
            )
            return None

        logger.info(
            f'Training set {len(gt_set._images)} lines, validation set {len(val_set._images)} lines, alphabet {len(gt_set.alphabet)} symbols'
        )
        alpha_diff_only_train = set(gt_set.alphabet).difference(
            set(val_set.alphabet))
        alpha_diff_only_val = set(val_set.alphabet).difference(
            set(gt_set.alphabet))
        if alpha_diff_only_train:
            logger.warning(
                f'alphabet mismatch: chars in training set only: {alpha_diff_only_train} (not included in accuracy test during training)'
            )
        if alpha_diff_only_val:
            logger.warning(
                f'alphabet mismatch: chars in validation set only: {alpha_diff_only_val} (not trained)'
            )
        logger.info('grapheme\tcount')
        for k, v in sorted(gt_set.alphabet.items(),
                           key=lambda x: x[1],
                           reverse=True):
            char = make_printable(k)
            if char == k:
                char = '\t' + char
            logger.info(f'{char}\t{v}')

        logger.debug('Encoding training set')

        # use model codec when given
        if append:
            # is already loaded
            nn = cast(vgsl.TorchVGSLModel, nn)
            gt_set.encode(codec)
            message('Slicing and dicing model ', nl=False)
            # now we can create a new model
            spec = '[{} O1c{}]'.format(spec[1:-1],
                                       gt_set.codec.max_label() + 1)
            logger.info(
                f'Appending {spec} to existing model {nn.spec} after {append}')
            nn.append(append, spec)
            nn.add_codec(gt_set.codec)
            message('\u2713', fg='green')
            logger.info(f'Assembled model spec: {nn.spec}')
        elif load:
            # is already loaded
            nn = cast(vgsl.TorchVGSLModel, nn)

            # prefer explicitly given codec over network codec if mode is 'both'
            codec = codec if (codec and resize == 'both') else nn.codec

            try:
                gt_set.encode(codec)
            except KrakenEncodeException:
                message('Network codec not compatible with training set')
                alpha_diff = set(gt_set.alphabet).difference(
                    set(codec.c2l.keys()))
                if resize == 'fail':
                    logger.error(
                        f'Training data and model codec alphabets mismatch: {alpha_diff}'
                    )
                    return None
                elif resize == 'add':
                    message('Adding missing labels to network ', nl=False)
                    logger.info(
                        f'Resizing codec to include {len(alpha_diff)} new code points'
                    )
                    codec.c2l.update({
                        k: [v]
                        for v, k in enumerate(alpha_diff,
                                              start=codec.max_label() + 1)
                    })
                    nn.add_codec(PytorchCodec(codec.c2l))
                    logger.info(
                        f'Resizing last layer in network to {codec.max_label()+1} outputs'
                    )
                    nn.resize_output(codec.max_label() + 1)
                    gt_set.encode(nn.codec)
                    message('\u2713', fg='green')
                elif resize == 'both':
                    message('Fitting network exactly to training set ',
                            nl=False)
                    logger.info(
                        f'Resizing network or given codec to {gt_set.alphabet} code sequences'
                    )
                    gt_set.encode(None)
                    ncodec, del_labels = codec.merge(gt_set.codec)
                    logger.info(
                        f'Deleting {len(del_labels)} output classes from network ({len(codec)-len(del_labels)} retained)'
                    )
                    gt_set.encode(ncodec)
                    nn.resize_output(ncodec.max_label() + 1, del_labels)
                    message('\u2713', fg='green')
                else:
                    logger.error(f'invalid resize parameter value {resize}')
                    return None
        else:
            gt_set.encode(codec)
            logger.info(
                f'Creating new model {spec} with {gt_set.codec.max_label()+1} outputs'
            )
            spec = '[{} O1c{}]'.format(spec[1:-1],
                                       gt_set.codec.max_label() + 1)
            nn = vgsl.TorchVGSLModel(spec)
            # initialize weights
            message('Initializing model ', nl=False)
            nn.init_weights()
            # initialize codec
            nn.add_codec(gt_set.codec)
            message('\u2713', fg='green')

        if nn.one_channel_mode and gt_set.im_mode != nn.one_channel_mode:
            logger.warning(
                f'Neural network has been trained on mode {nn.one_channel_mode} images, training set contains mode {gt_set.im_mode} data. Consider setting `force_binarization`'
            )

        if format_type != 'path' and nn.seg_type == 'bbox':
            logger.warning(
                'Neural network has been trained on bounding box image information but training set is polygonal.'
            )

        # halve the number of data loading processes if the device is the CPU
        # and preloading is disabled
        if device == 'cpu' and not preload:
            loader_threads = threads // 2
        else:
            loader_threads = threads
        train_loader = InfiniteDataLoader(
            gt_set,
            batch_size=hyper_params['batch_size'],
            shuffle=True,
            num_workers=loader_threads,
            pin_memory=True,
            collate_fn=collate_sequences)
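        # hand the remaining threads to OpenMP for the actual compute; the
        # loader workers claimed their share above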
        threads = max(threads - loader_threads, 1)

        # don't encode validation set as the alphabets may not match causing encoding failures
        val_set.no_encode()
        val_loader = DataLoader(val_set,
                                batch_size=hyper_params['batch_size'],
                                num_workers=loader_threads,
                                pin_memory=True,
                                collate_fn=collate_sequences)

        logger.debug('Constructing {} optimizer (lr: {}, momentum: {})'.format(
            hyper_params['optimizer'], hyper_params['lrate'],
            hyper_params['momentum']))

        # set model type metadata field
        nn.model_type = 'recognition'

        # set mode to training
        nn.train()

        # set number of OpenMP threads
        logger.debug(f'Set OpenMP threads to {threads}')
        nn.set_num_threads(threads)

        optim = getattr(torch.optim,
                        hyper_params['optimizer'])(nn.nn.parameters(), lr=0)
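        # lr=0 is intentional: the TrainScheduler constructed below sets the
        # effective learning rate (and momentum) on every step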

        if 'seg_type' not in nn.user_metadata:
            nn.user_metadata[
                'seg_type'] = 'baselines' if format_type != 'path' else 'bbox'

        tr_it = TrainScheduler(optim)
        if hyper_params['schedule'] == '1cycle':
            add_1cycle(tr_it, int(len(gt_set) * hyper_params['epochs']),
                       hyper_params['lrate'], hyper_params['momentum'],
                       hyper_params['momentum'] - 0.10,
                       hyper_params['weight_decay'])
        elif hyper_params['schedule'] == 'exponential':
            add_exponential_decay(tr_it,
                                  int(len(gt_set) * hyper_params['epochs']),
                                  len(gt_set), hyper_params['lrate'], 0.95,
                                  hyper_params['momentum'],
                                  hyper_params['weight_decay'])
        else:
            # constant learning rate scheduler
            tr_it.add_phase(1, 2 * (hyper_params['lrate'], ),
                            2 * (hyper_params['momentum'], ),
                            hyper_params['weight_decay'], annealing_const)
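            # 2 * (x,) == (x, x): identical start and end values under
            # annealing_const yield a constant lr/momentum phase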

        if hyper_params['quit'] == 'early':
            st_it = EarlyStopping(hyper_params['min_delta'],
                                  hyper_params['lag'])
        elif hyper_params['quit'] == 'dumb':
            st_it = EpochStopping(hyper_params['epochs'] -
                                  hyper_params['completed_epochs'])
        else:
            logger.error(f'Invalid training interruption scheme {hyper_params["quit"]}')
            return None

        trainer = cls(model=nn,
                      optimizer=optim,
                      device=device,
                      filename_prefix=output,
                      event_frequency=hyper_params['freq'],
                      train_set=train_loader,
                      val_set=val_loader,
                      stopper=st_it)

        trainer.add_lr_scheduler(tr_it)

        return trainer
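For reference, a rough sketch of the early-stopping policy selected by the `quit` hyperparameter above (a simplified stand-in, not kraken's EarlyStopping implementation; min_delta and lag mirror the hyperparameters of the same name):

class SimpleEarlyStopping:
    """Stop once the metric fails to improve by min_delta for `lag` checks."""
    def __init__(self, min_delta: float = 0.005, lag: int = 5) -> None:
        self.min_delta = min_delta
        self.lag = lag
        self.best = float('-inf')
        self.stale = 0
        self.stop = False

    def update(self, metric: float) -> None:
        if metric > self.best + self.min_delta:
            self.best = metric      # improvement: reset the patience counter
            self.stale = 0
        else:
            self.stale += 1         # no improvement this check
            if self.stale >= self.lag:
                self.stop = True    # patience exhausted, signal termination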
Exemplo n.º 13
class GroundTruthDataset(Dataset):
    """
    Dataset for ground truth used during training.

    All data is cached in memory.
    """
    def __init__(self, split: Callable[[str], str] = lambda x: os.path.splitext(x)[0],
                 suffix: str = '.gt.txt',
                 normalization: Optional[str] = None,
                 whitespace_normalization: bool = True,
                 reorder: bool = True,
                 im_transforms: Callable[[Any], torch.Tensor] = transforms.Compose([]),
                 preload: bool = True) -> None:
        """
        Reads a list of image-text pairs and creates a ground truth set.

        Args:
            split (func): Function for generating the base name without
                          extensions from paths
            suffix (str): Suffix to attach to image base name for text
                          retrieval
            normalization (str): Unicode normalization applied to the ground
                                 truth text
            whitespace_normalization (bool): Normalizes Unicode whitespace and
                                             strips leading/trailing whitespace.
            reorder (bool): Whether to rearrange code points in "display"/LTR
                            order
            im_transforms (func): Function taking a PIL.Image and returning a
                                  tensor suitable for forward passes.
            preload (bool): Enables preloading and preprocessing of image files.
        """
        self.suffix = suffix
        self.split = lambda x: split(x) + self.suffix
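        # self.split maps an image path to its transcription path, e.g.
        # 'line_01.png' -> 'line_01.gt.txt' with the default split and suffix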
        self._images = []  # type:  Union[List[Image], List[torch.Tensor]]
        self._gt = []  # type:  List[str]
        self.alphabet = Counter()  # type: Counter
        self.text_transforms = []  # type: List[Callable[[str], str]]
        self.transforms = im_transforms
        self.preload = preload
        # build text transformations
        if normalization:
            self.text_transforms.append(lambda x: unicodedata.normalize(cast(str, normalization), x))
        if whitespace_normalization:
            self.text_transforms.append(lambda x: regex.sub(r'\s', ' ', x).strip())
        if reorder:
            self.text_transforms.append(bd.get_display)

    def add(self, image: str) -> None:
        """
        Adds a line-image-text pair to the dataset.

        Args:
            image (str): Input image path
        """
        with open(self.split(image), 'r', encoding='utf-8') as fp:
            gt = fp.read().strip('\n\r')
            for func in self.text_transforms:
                gt = func(gt)
            if not gt:
                raise KrakenInputException('Text line is empty ({})'.format(fp.name))
        if self.preload:
            im = Image.open(image)
            try:
                im = self.transforms(im)
            except ValueError as e:
                raise KrakenInputException('Image transforms failed on {}'.format(image)) from e
            self._images.append(im)
        else:
            self._images.append(image)
        self._gt.append(gt)
        self.alphabet.update(gt)

    def add_loaded(self, image: Image.Image, gt: str) -> None:
        """
        Adds an already loaded line-image-text pair to the dataset.

        Args:
            image (PIL.Image.Image): Line image
            gt (str): Text contained in the line image
        """
        if self.preload:
            try:
                im = self.transforms(image)
            except ValueError as e:
                raise KrakenInputException('Image transforms failed on {}'.format(image)) from e
            self._images.append(im)
        else:
            self._images.append(image)
        for func in self.text_transforms:
            gt = func(gt)
        self._gt.append(gt)
        self.alphabet.update(gt)

    def encode(self, codec: Optional[PytorchCodec] = None) -> None:
        """
        Adds a codec to the dataset and encodes all text lines.

        Has to be run before sampling from the dataset.
        """
        if codec:
            self.codec = codec
        else:
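            # derive a codec from every grapheme observed in the ground truth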
            self.codec = PytorchCodec(''.join(self.alphabet.keys()))
        self.training_set = []  # type: List[Tuple[Union[Image, torch.Tensor], torch.Tensor]]
        for im, gt in zip(self._images, self._gt):
            self.training_set.append((im, self.codec.encode(gt)))

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.preload:
            return self.training_set[index]
        else:
            item = self.training_set[index]
            try:
                logger.debug('Attempting to load {}'.format(item[0]))
                im = item[0]
                if not isinstance(im, Image.Image):
                    im = Image.open(im)
                return (self.transforms(im), item[1])
            except Exception:
                idx = np.random.randint(0, len(self.training_set))
                logger.debug('Failed. Replacing with sample {}'.format(idx))
                return self[idx]

    def __len__(self) -> int:
        return len(self.training_set)
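A minimal usage sketch for GroundTruthDataset (file paths are hypothetical and image transforms are left at the identity default):

from torch.utils.data import DataLoader

gt_set = GroundTruthDataset(normalization='NFD', preload=True)
for path in ('lines/0001.png', 'lines/0002.png'):  # hypothetical paths
    # add() reads the transcription from split(path) + suffix,
    # i.e. 'lines/0001.gt.txt' with the defaults
    gt_set.add(path)
gt_set.encode()            # derive a PytorchCodec from the observed alphabet
image, labels = gt_set[0]  # with preload: (transformed image, encoded labels)
loader = DataLoader(gt_set, batch_size=1, shuffle=True)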
Exemplo n.º 14
def train(ctx, pad, output, spec, append, load, savefreq, report, quit, epochs,
          lag, min_delta, device, optimizer, lrate, momentum, weight_decay,
          schedule, partition, normalization, codec, resize, reorder,
          training_files, evaluation_files, preload, threads, ground_truth):
    """
    Trains a model from image-text pairs.
    """
    if not load and append:
        raise click.BadOptionUsage(
            'append', 'append option requires loading an existing model')

    if resize != 'fail' and not load:
        raise click.BadOptionUsage(
            'resize', 'resize option requires loading an existing model')

    import re
    import torch
    import shutil
    import numpy as np

    from torch.utils.data import DataLoader

    from kraken.lib import models, vgsl, train
    from kraken.lib.util import make_printable
    from kraken.lib.train import EarlyStopping, EpochStopping, TrainStopper, TrainScheduler, add_1cycle
    from kraken.lib.codec import PytorchCodec
    from kraken.lib.dataset import GroundTruthDataset, compute_error, generate_input_transforms

    logger.info('Building ground truth set from {} line images'.format(
        len(ground_truth) + len(training_files)))

    # load model if given. if a new model has to be created we need to do that
    # after data set initialization, otherwise the output size is still unknown.
    nn = None
    if load:
        logger.info('Loading existing model from {} '.format(load))
        message('Loading model {}'.format(load), nl=False)
        nn = vgsl.TorchVGSLModel.load_model(load)
        message('\u2713', fg='green', nl=False)

    # preparse input sizes from vgsl string to seed ground truth data set
    # sizes and dimension ordering.
    if not nn:
        spec = spec.strip()
        if spec[0] != '[' or spec[-1] != ']':
            raise click.BadOptionUsage(
                'spec', 'VGSL spec {} not bracketed'.format(spec))
        blocks = spec[1:-1].split(' ')
        m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
        if not m:
            raise click.BadOptionUsage(
                'spec', 'Invalid input spec {}'.format(blocks[0]))
        batch, height, width, channels = [int(x) for x in m.groups()]
    else:
        batch, channels, height, width = nn.input
    try:
        transforms = generate_input_transforms(batch, height, width, channels,
                                               pad)
    except KrakenInputException as e:
        raise click.BadOptionUsage('spec', str(e))

    # disable automatic partition when given evaluation set explicitly
    if evaluation_files:
        partition = 1
    ground_truth = list(ground_truth)

    # merge training_files into ground_truth list
    if training_files:
        ground_truth.extend(training_files)

    if len(ground_truth) == 0:
        raise click.UsageError(
            'No training data was provided to the train command. Use `-t` or the `ground_truth` argument.'
        )

    np.random.shuffle(ground_truth)

    if len(ground_truth) > 2500 and preload is None:
        logger.info(
            'Disabling preloading for large (>2500) training data set. Enable by setting --preload parameter'
        )
        preload = False
    # implicit preloading enabled for small data sets
    if preload is None:
        preload = True

    tr_im = ground_truth[:int(len(ground_truth) * partition)]
    if evaluation_files:
        logger.debug('Using {} lines from explicit eval set'.format(
            len(evaluation_files)))
        te_im = evaluation_files
    else:
        te_im = ground_truth[int(len(ground_truth) * partition):]
        logger.debug('Taking {} lines from training for evaluation'.format(
            len(te_im)))

    gt_set = GroundTruthDataset(normalization=normalization,
                                reorder=reorder,
                                im_transforms=transforms,
                                preload=preload)
    with log.progressbar(tr_im, label='Building training set') as bar:
        for im in bar:
            logger.debug('Adding line {} to training set'.format(im))
            try:
                gt_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    val_set = GroundTruthDataset(normalization=normalization,
                                 reorder=reorder,
                                 im_transforms=transforms,
                                 preload=preload)
    with log.progressbar(te_im, label='Building validation set') as bar:
        for im in bar:
            logger.debug('Adding line {} to validation set'.format(im))
            try:
                val_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    logger.info(
        'Training set {} lines, validation set {} lines, alphabet {} symbols'.
        format(len(gt_set._images), len(val_set._images),
               len(gt_set.alphabet)))
    alpha_diff = set(gt_set.alphabet).symmetric_difference(
        set(val_set.alphabet))
    if alpha_diff:
        logger.warning('alphabet mismatch {}'.format(alpha_diff))
    logger.info('grapheme\tcount')
    for k, v in sorted(gt_set.alphabet.items(),
                       key=lambda x: x[1],
                       reverse=True):
        char = make_printable(k)
        if char == k:
            char = '\t' + char
        logger.info(u'{}\t{}'.format(char, v))

    logger.debug('Encoding training set')

    # use model codec when given
    if append:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)
        gt_set.encode(codec)
        message('Slicing and dicing model ', nl=False)
        # now we can create a new model
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        logger.info('Appending {} to existing model {} after {}'.format(
            spec, nn.spec, append))
        nn.append(append, spec)
        nn.add_codec(gt_set.codec)
        message('\u2713', fg='green')
        logger.info('Assembled model spec: {}'.format(nn.spec))
    elif load:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)

        # prefer explicitly given codec over network codec if mode is 'both'
        codec = codec if (codec and resize == 'both') else nn.codec

        try:
            gt_set.encode(codec)
        except KrakenEncodeException as e:
            message('Network codec not compatible with training set')
            alpha_diff = set(gt_set.alphabet).difference(set(codec.c2l.keys()))
            if resize == 'fail':
                logger.error(
                    'Training data and model codec alphabets mismatch: {}'.
                    format(alpha_diff))
                ctx.exit(code=1)
            elif resize == 'add':
                message('Adding missing labels to network ', nl=False)
                logger.info(
                    'Resizing codec to include {} new code points'.format(
                        len(alpha_diff)))
                codec.c2l.update({
                    k: [v]
                    for v, k in enumerate(alpha_diff,
                                          start=codec.max_label() + 1)
                })
                nn.add_codec(PytorchCodec(codec.c2l))
                logger.info(
                    'Resizing last layer in network to {} outputs'.format(
                        codec.max_label() + 1))
                nn.resize_output(codec.max_label() + 1)
                message('\u2713', fg='green')
            elif resize == 'both':
                message('Fitting network exactly to training set ', nl=False)
                logger.info(
                    'Resizing network or given codec to {} code sequences'.
                    format(len(gt_set.alphabet)))
                gt_set.encode(None)
                ncodec, del_labels = codec.merge(gt_set.codec)
                logger.info(
                    'Deleting {} output classes from network ({} retained)'.
                    format(len(del_labels),
                           len(codec) - len(del_labels)))
                gt_set.encode(ncodec)
                nn.resize_output(ncodec.max_label() + 1, del_labels)
                message('\u2713', fg='green')
            else:
                raise click.BadOptionUsage(
                    'resize', 'Invalid resize value {}'.format(resize))
    else:
        gt_set.encode(codec)
        logger.info('Creating new model {} with {} outputs'.format(
            spec,
            gt_set.codec.max_label() + 1))
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        nn = vgsl.TorchVGSLModel(spec)
        # initialize weights
        message('Initializing model ', nl=False)
        nn.init_weights()
        # initialize codec
        nn.add_codec(gt_set.codec)
        message('\u2713', fg='green')

    train_loader = DataLoader(gt_set,
                              batch_size=1,
                              shuffle=True,
                              pin_memory=True)

    # don't encode validation set as the alphabets may not match causing encoding failures
    val_set.training_set = list(zip(val_set._images, val_set._gt))

    logger.debug('Constructing {} optimizer (lr: {}, momentum: {})'.format(
        optimizer, lrate, momentum))

    # set mode to training
    nn.train()

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    nn.set_num_threads(threads)

    logger.debug('Moving model to device {}'.format(device))
    rec = models.TorchSeqRecognizer(nn, train=True, device=device)
    optim = getattr(torch.optim, optimizer)(nn.nn.parameters(), lr=0)

    tr_it = TrainScheduler(optim)
    if schedule == '1cycle':
        add_1cycle(tr_it, epochs * len(gt_set), lrate, momentum,
                   momentum - 0.10, weight_decay)
    else:
        # constant learning rate scheduler
        tr_it.add_phase(1, (lrate, lrate), (momentum, momentum), weight_decay,
                        train.annealing_const)

    st_it = cast(TrainStopper, None)  # type: TrainStopper
    if quit == 'early':
        st_it = EarlyStopping(train_loader, min_delta, lag)
    elif quit == 'dumb':
        st_it = EpochStopping(train_loader, epochs)
    else:
        raise click.BadOptionUsage(
            'quit', 'Invalid training interruption scheme {}'.format(quit))

    for epoch, loader in enumerate(st_it):
        with log.progressbar(label='epoch {}/{}'.format(
                epoch, epochs - 1 if epochs > 0 else '∞'),
                             length=len(loader),
                             show_pos=True) as bar:
            acc_loss = torch.tensor(0.0).to(device, non_blocking=True)
            for trial, (input, target) in enumerate(loader):
                tr_it.step()
                input = input.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)
                input = input.requires_grad_()
                o = nn.nn(input)
                # height should be 1 by now
                if o.size(2) != 1:
                    raise KrakenInputException(
                        'Expected dimension 3 to be 1, actual {}'.format(
                            o.size(2)))
                o = o.squeeze(2)
                optim.zero_grad()
                # NCW -> WNC
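                # nn.criterion is assumed to be a CTC-style loss taking
                # (activations, targets, output lengths, target lengths)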
                loss = nn.criterion(
                    o.permute(2, 0, 1),  # type: ignore
                    target,
                    (o.size(2), ),
                    (target.size(1), ))
                logger.info('trial {}'.format(trial))
                if not torch.isinf(loss):
                    loss.backward()
                    optim.step()
                else:
                    logger.debug('infinite loss in trial {}'.format(trial))
                bar.update(1)
        if not epoch % savefreq:
            logger.info('Saving to {}_{}'.format(output, epoch))
            try:
                nn.save_model('{}_{}.mlmodel'.format(output, epoch))
            except Exception as e:
                logger.error('Saving model failed: {}'.format(str(e)))
        if not epoch % report:
            logger.debug('Starting evaluation run')
            nn.eval()
            chars, error = compute_error(rec, list(val_set))
            nn.train()
            accuracy = (chars - error) / chars
            logger.info('Accuracy report ({}) {:0.4f} {} {}'.format(
                epoch, accuracy, chars, error))
            message('Accuracy report ({}) {:0.4f} {} {}'.format(
                epoch, accuracy, chars, error))
            st_it.update(accuracy)
    if quit == 'early':
        message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.
                format(output, st_it.best_epoch, st_it.best_loss))
        logger.info(
            'Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.
            format(output, st_it.best_epoch, st_it.best_loss))
        shutil.copy('{}_{}.mlmodel'.format(output, st_it.best_epoch),
                    '{}_best.mlmodel'.format(output))
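The accuracy reported above is character accuracy over the validation set. Conceptually (a simplified sketch, not kraken's compute_error) the error term is an edit distance between prediction and ground truth:

def levenshtein(a: str, b: str) -> int:
    """Number of single-character edits turning a into b."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                  # deletion
                           cur[j - 1] + 1,               # insertion
                           prev[j - 1] + (ca != cb)))    # substitution
        prev = cur
    return prev[-1]

gt, pred = 'kraken', 'krakem'         # hypothetical line pair
chars = len(gt)
error = levenshtein(gt, pred)         # 1 substitution
accuracy = (chars - error) / chars    # 5/6 ≈ 0.8333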