예제 #1
0
    def apply(self, sample: Sample) -> Sample:
        """Unwrap length-one entries and trim the outputs to their length.

        ``inputs['img_len']`` and ``inputs['meta']`` must have shape ``(1,)``
        and are replaced by their single element. For the un-suffixed outputs
        and for every suffix ``_0`` .. ``_{ensemble_ - 1}``, ``out_len<suffix>``
        is unwrapped when it has shape ``(1,)`` and each present
        logits/softmax entry is cut to its first ``out_len<suffix>`` items.

        Both dicts are duplicated with ``.copy()`` before any entry is
        reassigned, so the dicts held by ``sample`` keep their entries.

        Returns:
            Whatever ``sample.new_inputs(...).new_outputs(...)`` yields for
            the adjusted dicts.
        """
        unwrap_keys = ('img_len', 'meta')
        trimmed_bases = ('logits', 'softmax', 'blank_last_logits', 'blank_last_softmax')

        src_inputs, src_outputs = sample.inputs, sample.outputs
        for key in unwrap_keys:
            assert src_inputs[key].shape == (1,)

        fixed_inputs = src_inputs.copy()
        fixed_outputs = src_outputs.copy()
        for key in unwrap_keys:
            fixed_inputs[key] = fixed_inputs[key][0]

        def trim_group(tag):
            # Handles one group of outputs that share the same key suffix.
            len_key = 'out_len' + tag
            if len_key in fixed_outputs and fixed_outputs[len_key].shape == (1,):
                fixed_outputs[len_key] = fixed_outputs[len_key][0]

            present = [base + tag for base in trimmed_bases if base + tag in fixed_outputs]
            for key in present:
                # The length is looked up only when there is something to cut.
                fixed_outputs[key] = fixed_outputs[key][:fixed_outputs[len_key]]

        trim_group('')
        for member in range(self.params.ensemble_):
            trim_group(f"_{member}")

        return sample.new_inputs(fixed_inputs).new_outputs(fixed_outputs)
예제 #2
0
 def apply(self, sample: Sample) -> Sample:
     """Return an augmented version of ``sample``, or ``sample`` unchanged.

     Augmentation runs only when all of the following hold, checked in
     this order: ``data_aug_params.no_augs()`` is falsy, the sample has
     inputs, a data augmenter is set, and a fresh ``np.random.rand()``
     draw is ``<= data_aug_params.to_rel()``.
     """
     aug_params = self.params.data_aug_params

     if aug_params.no_augs():
         return sample
     if sample.inputs is None:
         return sample
     if not self.data_augmenter:
         return sample

     # The random draw comes last so it is consumed only when augmentation
     # is actually possible.
     if np.random.rand() <= aug_params.to_rel():
         augmented_line, augmented_text = self.augment(
             sample.inputs, sample.targets, sample.meta)
         return sample.new_inputs(augmented_line).new_targets(augmented_text)

     return sample
예제 #3
0
파일: reshape.py 프로젝트: znsoftm/calamari
    def apply(self, sample: Sample) -> Sample:
        """Unwrap length-one entries and trim the outputs to ``out_len``.

        ``inputs['img_len']`` and ``inputs['meta']`` must have shape ``(1,)``
        and are replaced by their single element. ``outputs['out_len']`` is
        unwrapped the same way when it is present with shape ``(1,)``. Each
        logits/softmax entry that is present is then cut to its first
        ``out_len`` items.

        Both dicts are duplicated with ``.copy()`` before any entry is
        reassigned, so the dicts held by ``sample`` keep their entries.
        """
        original_inputs = sample.inputs
        original_outputs = sample.outputs
        for required in ('img_len', 'meta'):
            assert original_inputs[required].shape == (1, )

        squeezed_inputs = original_inputs.copy()
        squeezed_outputs = original_outputs.copy()
        for required in ('img_len', 'meta'):
            squeezed_inputs[required] = squeezed_inputs[required][0]

        has_wrapped_len = ('out_len' in squeezed_outputs
                           and squeezed_outputs['out_len'].shape == (1, ))
        if has_wrapped_len:
            squeezed_outputs['out_len'] = squeezed_outputs['out_len'][0]

        for key in ('logits', 'softmax', 'blank_last_logits', 'blank_last_softmax'):
            if key not in squeezed_outputs:
                continue
            # 'out_len' is read only when there is something to cut.
            limit = squeezed_outputs['out_len']
            squeezed_outputs[key] = squeezed_outputs[key][:limit]

        return sample.new_inputs(squeezed_inputs).new_outputs(squeezed_outputs)
예제 #4
0
    def apply(self, sample: Sample) -> Sample:
        """Encode the targets, validate the line and package both as dicts.

        Steps:
          1. Encode ``sample.targets`` with the configured codec (an empty
             int32 array when there are no targets).
          2. Give a 2-D line a trailing channel axis and check the channel
             count against ``params.input_channels``.
          3. In training and evaluation mode, mark the sample invalid when
             ``is_valid_line`` rejects it or when the encoded text is empty.
          4. Otherwise return the sample with dict-shaped inputs and targets.

        Raises:
            ValueError: if the last axis of the line does not match
                ``params.input_channels``.
        """
        codec = self.params.codec

        # final preparation
        if sample.targets:
            encoded = codec.encode(sample.targets)
        else:
            encoded = np.zeros((0, ), dtype='int32')
        text = np.array(encoded)
        line = sample.inputs

        # gray or binary input, add missing axis
        if len(line.shape) == 2:
            line = np.expand_dims(line, axis=-1)

        if line.shape[-1] != self.params.input_channels:
            raise ValueError(
                f"Expected {self.params.input_channels} channels but got {line.shape[-1]}. Shape of input {line.shape}"
            )

        needs_ground_truth = self.mode in {
            PipelineMode.Training, PipelineMode.Evaluation
        }
        if needs_ground_truth:
            # Length left after downscaling; is_valid_line (defined elsewhere)
            # decides whether the encoded text fits, and per the log message
            # rejects targets longer than the inputs. Also enforced in
            # evaluation due to loss computation.
            scaled_len = len(line) // self.params.downscale_factor_
            if not self.is_valid_line(text, scaled_len):
                logger.warning(
                    f"Skipping line with longer outputs than inputs (id={sample.meta['id']})"
                )
                return sample.new_invalid()

            if len(text) == 0:
                logger.warning(
                    f"Skipping empty line with empty GT (id={sample.meta['id']})")
                return sample.new_invalid()

        # Lengths, meta and fold id are wrapped in one-element lists.
        with_inputs = sample.new_inputs({
            'img': line.astype(np.uint8),
            'img_len': [len(line)],
            'meta': [json.dumps(sample.meta)],
        })
        return with_inputs.new_targets({
            'gt': text,
            'gt_len': [len(text)],
            'fold_id': [sample.meta.get('fold_id', -1)],
        })
예제 #5
0
 def apply(self, sample: Sample) -> Sample:
     """Replace the sample's inputs with ``_apply_single(inputs, meta)``.

     Returns whatever ``sample.new_inputs`` yields for the processed inputs.
     """
     processed = self._apply_single(sample.inputs, sample.meta)
     return sample.new_inputs(processed)