Example #1
class FilterKeys(DataProcess):
    required = State(default=[])
    superfluous = State(default=[])

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.required_keys = set(self.required)
        self.superfluous_keys = set(self.superfluous)
        if len(self.required_keys) > 0 and len(self.superfluous_keys) > 0:
            raise ValueError(
                'required_keys and superfluous_keys can not be specified at the same time.')

    def process(self, data):
        for key in self.required:
            assert key in data, '%s is required in data' % key

        superfluous = self.superfluous_keys
        if len(superfluous) == 0:
            for key in data.keys():
                if key not in self.required_keys:
                    superfluous.add(key)

        for key in superfluous:
            del data[key]
        return data
Example #2
class ShowSettings(Configurable):
    data_loader = State()
    representer = State()
    visualizer = State()

    def __init__(self, **kwargs):
        self.load_all(**kwargs)
Example #3
class SerializeBox(DataProcess):
    box_key = State(default='charboxes')
    format = State(default='NP2')

    def process(self, data):
        data[self.box_key] = data['lines'].quads
        return data
Example #4
class EvaluationSettings(Configurable):
    data_loaders = State()
    visualize = State(default=True)
    resume = State()

    def __init__(self, **kwargs):
        self.load_all(**kwargs)
Example #5
class RandomSampleDataLoader(Configurable, torch.utils.data.DataLoader):
    datasets = State()
    weights = State()
    batch_size = State(default=256)
    num_workers = State(default=10)
    size = State(default=2**31)

    def __init__(self, **kwargs):
        self.load_all(**kwargs)

        cmd = kwargs['cmd']
        if 'batch_size' in cmd:
            self.batch_size = cmd['batch_size']

        probs = []
        for dataset, weight in zip(self.datasets, self.weights):
            probs.append(np.full(len(dataset), weight / len(dataset)))

        # ConcatDataset is PyTorch's built-in way of concatenating sub-datasets.
        dataset = ConcatDataset(self.datasets)
        probs = np.concatenate(probs)
        assert (len(dataset) == len(probs))

        sampler = RandomSampleSampler(dataset, probs, self.size)

        torch.utils.data.DataLoader.__init__(
            self,
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            sampler=sampler,
            worker_init_fn=default_worker_init_fn,
        )
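The comment above refers to PyTorch's ConcatDataset. As a rough, standalone sketch (toy tensors; torch's own WeightedRandomSampler standing in for the repo's RandomSampleSampler), the concatenation and per-sample weighting pattern looks like this:

import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset, WeightedRandomSampler

# Two toy datasets of different sizes.
a = TensorDataset(torch.zeros(4, 3))
b = TensorDataset(torch.ones(6, 3))
merged = ConcatDataset([a, b])  # len(merged) == 10; indices 0-3 map to a, 4-9 to b
# Give each dataset the same total probability mass, spread over its samples.
weights = [0.5 / len(a)] * len(a) + [0.5 / len(b)] * len(b)
sampler = WeightedRandomSampler(weights, num_samples=8)
loader = DataLoader(merged, batch_size=4, sampler=sampler)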
Example #6
class ResizeData(_ResizeImage, DataProcess):
    key = State(default='image')
    box_key = State(default='polygons')
    image_size = State(default=[64, 256])  # height, width

    def __init__(self, cmd={}, mode=None, key=None, box_key=None, **kwargs):
        self.load_all(**kwargs)
        if mode is not None:
            self.mode = mode
        if key is not None:
            self.key = key
        if box_key is not None:
            self.box_key = box_key
        if 'resize_mode' in cmd:
            self.mode = cmd['resize_mode']
        assert self.mode in self.MODES

    def process(self, data):
        height, width = data['image'].shape[:2]
        new_height, new_width = self.get_image_size(data['image'])
        data[self.key] = self.resize_or_pad(data[self.key])

        charboxes = data[self.box_key]
        data[self.box_key] = charboxes.copy()
        data[self.box_key][:, :, 0] = data[self.box_key][:, :, 0] * \
            new_width / width
        data[self.box_key][:, :, 1] = data[self.box_key][:, :, 1] * \
            new_height / height
        return data
Example #7
class RecognitionMetaLoader(MetaLoader):
    skip_vertical = State(default=False)
    case_sensitive = State(default=False)
    simplify = State(default=False)
    key = State(default='words')
    scan_meta = True

    def __init__(self, key=None, cmd={}, **kwargs):
        super().__init__(cmd=cmd, **kwargs)
        if key is not None:
            self.key = key

    def may_simplify(self, words):
        garbled = stringQ2B(words)
        if self.simplify:
            return HanziConv.toSimplified(garbled)
        return garbled

    def parse_meta(self, data_id, meta):
        word = self.may_simplify(self.get_annotation(meta)[self.key])
        vertical = self.get_annotation(meta).get('vertical', False)
        if self.skip_vertical and vertical:
            return None

        if word == '###':
            return None
        return dict(data_ids=data_id, gt=word)
Example #8
class SimpleTextsnakeRepresenter(SimpleDetectionRepresenter):
    heatmap_thr = State(default=0.5)
    min_area = State(default=200)

    def postprocess(self, output):
        output['heatmask'] = self.threshold_heatmap(output['heatmap'],
                                                    self.heatmap_thr)

    def get_polygons(self, output):
        heatmask = output['heatmask'][0]
        radius = output['radius']

        _, contours, _ = cv2.findContours(heatmask.astype(np.uint8),
                                          cv2.RETR_EXTERNAL,
                                          cv2.CHAIN_APPROX_NONE)

        polygons = []
        for contour in contours:
            contour = contour[:, 0]
            mask = np.zeros(heatmask.shape[:2], dtype=np.uint8)
            for x, y in contour:
                r = radius[0, y, x]
                if r > 1:
                    cv2.circle(mask, (int(x), int(y)), int(r), 1, -1)

            _, conts, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                           cv2.CHAIN_APPROX_NONE)
            if len(conts) == 1:
                poly = conts[0][:, 0]
                if Polygon(poly).geom_type == 'Polygon' and Polygon(
                        poly).area > self.min_area:
                    polygons.append(poly)

        return polygons
Example #9
class SimpleMSRRepresenter(SimpleDetectionRepresenter):
    heatmap_thr = State(default=0.5)
    min_area = State(default=200)
    max_poly_points = State(default=32)

    def postprocess(self, output):
        output['heatmask'] = self.threshold_heatmap(output['heatmap'],
                                                    self.heatmap_thr)

    def get_polygons(self, output):
        heatmask = output['heatmask'][0]
        offset = output['offset']
        _, contours, _ = cv2.findContours(heatmask.astype(np.uint8),
                                          cv2.RETR_EXTERNAL,
                                          cv2.CHAIN_APPROX_NONE)

        polygons = []
        for contour in contours:
            contour = contour[:, 0]
            poly = []
            stride = len(contour) // self.max_poly_points + 1
            for x, y in contour[::stride]:
                point = offset[:, y, x] + (x, y)
                poly.append(point)
            if len(poly) >= 3:
                if Polygon(poly).area > self.min_area:
                    polygons.append(poly)

        return polygons
Example #10
class BuitlinLearningRate(Configurable):
    lr = State(default=0.001)
    klass = State(default='StepLR')
    args = State(default=[])
    kwargs = State(default={})

    def __init__(self, cmd={}, **kwargs):
        self.load_all(**kwargs)
        self.lr = cmd.get('lr', None) or self.lr
        self.scheduler = None

    def prepare(self, optimizer):
        self.scheduler = getattr(lr_scheduler,
                                 self.klass)(optimizer, *self.args,
                                             **self.kwargs)

    def get_learning_rate(self, epoch, step=None):
        if self.scheduler is None:
            raise Exception(
                'learning rate is not ready; call prepare(optimizer) first')
        self.scheduler.last_epoch = epoch
        # The return value of get_lr is a list, where each element is the
        # learning rate of the corresponding parameter group.
        return self.scheduler.get_lr()[0]
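As the comment notes, the scheduler's get_lr() returns one value per parameter group; a small sketch against plain PyTorch (toy model and StepLR, not from this repo) illustrates why the code above indexes [0]:

import torch
from torch import nn, optim
from torch.optim import lr_scheduler

# Two parameter groups with different learning rates.
model = nn.Linear(4, 2)
optimizer = optim.SGD([
    {'params': model.weight, 'lr': 0.01},
    {'params': model.bias, 'lr': 0.001},
])
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
print(scheduler.get_lr())  # e.g. [0.01, 0.001]: one entry per parameter group
# (Recent PyTorch versions prefer scheduler.get_last_lr() for this query.)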
Example #11
class Checkpoint(Configurable):
    start_epoch = State(default=0)
    start_iter = State(default=0)
    resume = State()

    def __init__(self, **kwargs):
        self.load_all(**kwargs)

        cmd = kwargs['cmd']
        if 'start_epoch' in cmd:
            self.start_epoch = cmd['start_epoch']
        if 'start_iter' in cmd:
            self.start_iter = cmd['start_iter']
        if 'resume' in cmd:
            self.resume = cmd['resume']

    def restore_model(self, model, device, logger):
        if self.resume is None:
            return

        if not os.path.exists(self.resume):
            logger.warning("Checkpoint not found: " + self.resume)
            return

        logger.info("Resuming from " + self.resume)
        state_dict = torch.load(self.resume, map_location=device)
        model.load_state_dict(state_dict, strict=False)
        logger.info("Resumed from " + self.resume)

    def restore_counter(self):
        return self.start_epoch, self.start_iter
Example #12
class OptimizerScheduler(Configurable):
    optimizer = State()
    optimizer_args = State(default={})
    learning_rate = State(autoload=False)

    def __init__(self, cmd={}, **kwargs):
        self.load_all(**kwargs)
        self.load('learning_rate', cmd=cmd, **kwargs)
        if 'lr' in cmd:
            self.optimizer_args['lr'] = cmd['lr']

    def create_optimizer(self, parameters):
        optimizer = getattr(torch.optim, self.optimizer)(parameters,
                                                         **self.optimizer_args)
        if hasattr(self.learning_rate, 'prepare'):
            self.learning_rate.prepare(optimizer)
        return optimizer


# optimizer = getattr(torch.optim, "Adam")
# print(optimizer)
# # optimizer2 = torch.optim.Adam()
# # print(optimizer2)
# optimizer3 = getattr(torch.optim, "SGD")
# print(optimizer3)
Example #13
class OcrDataset(data.Dataset, Configurable):
    r'''Dataset reading from images.
    Args:
        Processes: A series of callable objects, each of which accepts the data dict as its
            parameter and returns it, typically inheriting from the
            `DataProcess`(data/processes/data_process.py) class.
    '''
    data_names = State()
    filter = State()
    is_training = State()
    processes = State(default=[])

    def __init__(self, data_names=None, filter=None, cmd={}, **kwargs):
        self.load_all(**kwargs)
        self.data_names = data_names or self.data_names
        self.filter = filter or self.filter
        self.is_training = self.is_training if 'is_training' in kwargs else False
        self.debug = cmd.get('debug', False)

        # load dataset
        split = 'train' if self.is_training else 'test'
        dataset = get_dataset_by_name(self.data_names, filter=self.filter, split=split)
        dataset.verbose()
        self.dataset = dataset

    def __getitem__(self, index, retry=0):
        '''
        item = {'img': img, 'type': 'contour', 'bboxes': bboxes, 'tags': tags,
                'path': img_path}
        '''
        if index >= self.dataset.size():
            index = index % self.dataset.size()

        item = self.dataset.getData(index)

        data = {}
        image_path = item['path']
        img = cv2.imread(image_path, cv2.IMREAD_COLOR).astype('float32')
        if self.is_training:
            data['filename'] = image_path
            data['data_id'] = image_path
        else:
            data['filename'] = image_path.split('/')[-1]
            data['data_id'] = image_path.split('/')[-1]
        data['image'] = img
        target = []
        num_targets = len(item['bboxes'])
        for idx in range(num_targets):
            text = 1234 if item['tags'][idx] else '###'
            target.append({'poly': item['bboxes'][idx], 'text': text})
        data['lines'] = target
        if self.processes is not None:
            for data_process in self.processes:
                data = data_process(data)
        return data

    def __len__(self):
        return self.dataset.size()
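The docstring above describes a process as a callable that takes the data dict and returns it. A hypothetical minimal process (not part of this repo) would look like:

import numpy as np

class AddImageShape:
    '''Hypothetical process: records the image's height and width in the data dict.'''
    def __call__(self, data):
        data['shape'] = data['image'].shape[:2]
        return data

data = AddImageShape()({'image': np.zeros((32, 100, 3), dtype='float32')})
print(data['shape'])  # (32, 100)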
Example #14
class NoriMetaLoader(Configurable):
    cache = State()
    force_reload = State(default=False)

    scan_meta = True
    scan_data = False
    post_prosess = None

    def __init__(self, force_reload=None, cmd={}, **kwargs):
        self.load_all(cmd=cmd, **kwargs)
        self.force_reload = cmd.get('force_reload', self.force_reload)
        if force_reload is not None:
            self.force_reload = force_reload

    def load_meta(self, nori_path):
        if not self.force_reload and self.cache is not None:
            meta = self.cache.read(nori_path)
            if meta is not None:
                return meta

        meta_info = dict()
        valid_count = 0
        with nori.open(nori_path) as reader:
            for data_id, data, meta in reader.scan(
                    scan_data=self.scan_data, scan_meta=self.scan_meta):
                args_tuple = (data_id, )
                if self.scan_data:
                    args_tuple = tuple((*args_tuple, data))
                if self.scan_meta:
                    args_tuple = tuple((*args_tuple, meta))
                meta_instance = self.parse_meta(*args_tuple)
                if meta_instance is None:
                    continue
                valid_count += 1
                if valid_count % 100000 == 0:
                    print("%d instances processed" % valid_count)
                for key in meta_instance:
                    the_list = meta_info.get(key, [])
                    the_list.append(meta_instance[key])
                    meta_info[key] = the_list

        print(valid_count, 'instances found')
        if self.post_prosess is not None:
            meta_info = self.post_prosess(meta_info)

        if self.cache is not None:
            self.cache.save(nori_path, meta_info)

        return meta_info

    def parse_meta(self, data_id, meta):
        raise NotImplementedError

    def get_annotation(self, meta):
        return meta['extra']
Example #15
class DecayLearningRate(Configurable):
    lr = State(default=0.001)
    epochs = State(default=100)
    factor = State(default=0.9)

    def __init__(self, **kwargs):
        self.load_all(**kwargs)

    def get_learning_rate(self, epoch, step=None):
        rate = np.power(1.0 - epoch / float(self.epochs + 1), self.factor)
        return rate * self.lr
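A quick worked run of the polynomial decay above, using the class defaults (lr=0.001, epochs=100, factor=0.9):

import numpy as np

lr, epochs, factor = 0.001, 100, 0.9
for epoch in (0, 50, 100):
    print(epoch, np.power(1.0 - epoch / float(epochs + 1), factor) * lr)
# 0 -> 0.001, 50 -> ~5.4e-4, 100 -> ~1.6e-5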
Example #16
class MultiStepLR(Configurable):
    lr = State()
    milestones = State(default=[])  # milestones must be sorted
    gamma = State(default=0.1)

    def __init__(self, cmd={}, **kwargs):
        self.load_all(**kwargs)
        self.lr = cmd.get('lr', self.lr)

    def get_learning_rate(self, epoch, step):
        return self.lr * self.gamma ** bisect_right(self.milestones, epoch)
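bisect_right counts how many milestones the current epoch has passed, so the learning rate is multiplied by gamma once per passed milestone. A small illustration with assumed values (lr=0.01, milestones=[60, 80]):

from bisect import bisect_right

lr, milestones, gamma = 0.01, [60, 80], 0.1
for epoch in (0, 59, 60, 80, 100):
    print(epoch, lr * gamma ** bisect_right(milestones, epoch))
# 0/59 -> 0.01, 60 -> 0.001, 80 and later -> ~0.0001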
Example #17
class MakeCenterMap(DataProcess):
    max_size = State(default=32)
    shape = State(default=(64, 256))
    sigma_ratio = State(default=16)
    function_name = State(default='sample_gaussian')
    points_key = 'points'
    correlation = 0  # The formulation of the Gaussian is simplified when correlation is 0

    def process(self, data):
        assert self.points_key in data, '%s in data is required' % self.points_key
        points = data['points'] * self.shape[::-1]  # N, 2
        assert points.shape[0] >= self.max_size
        func = getattr(self, self.function_name)
        data['charmaps'] = func(points, *self.shape)
        return data

    def gaussian(self, points, height, width):
        index_x, index_y = np.meshgrid(np.linspace(0, width, width),
                                       np.linspace(0, height, height))
        index_x = np.repeat(index_x[np.newaxis], points.shape[0], axis=0)
        index_y = np.repeat(index_y[np.newaxis], points.shape[0], axis=0)
        mu_x = points[:, 0][:, np.newaxis, np.newaxis]
        mu_y = points[:, 1][:, np.newaxis, np.newaxis]
        mask_is_zero = ((mu_x == 0) + (mu_y == 0)) == 0
        result = np.reciprocal(2 * np.pi * width / self.sigma_ratio * height / self.sigma_ratio)\
            * np.exp(- 0.5 * (np.square((index_x - mu_x) / width * self.sigma_ratio) +
                              np.square((index_y - mu_y) / height * self.sigma_ratio)))

        result = result / \
            np.maximum(result.max(axis=1, keepdims=True).max(
                axis=2, keepdims=True), np.finfo(np.float32).eps)
        result = result * mask_is_zero
        return result.astype(np.float32)

    def sample_gaussian(self, points, height, width):
        points = (points + 0.5).astype(np.int32)
        canvas = np.zeros((self.max_size, height, width), dtype=np.float32)
        for index in range(canvas.shape[0]):
            point = points[index]
            canvas[index, point[1], point[0]] = 1.
            if point.sum() > 0:
                fi.gaussian_filter(
                    canvas[index],
                    (height // self.sigma_ratio, width // self.sigma_ratio),
                    output=canvas[index],
                    mode='mirror')
                canvas[index] = canvas[index] / canvas[index].max()
                x_range = min(point[0], width - point[0])
                canvas[index, :, :point[0] - x_range] = 0
                canvas[index, :, point[0] + x_range:] = 0
                y_range = min(point[1], height - point[1])
                canvas[index, :point[1] - y_range, :] = 0
                canvas[index, point[1] + y_range:, :] = 0
        return canvas
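For reference, the expression built in gaussian() above is the bivariate normal density with correlation rho = 0 (hence the simplification mentioned in the `correlation` comment), with sigma_x = width / sigma_ratio and sigma_y = height / sigma_ratio:

f(x, y) = \frac{1}{2\pi\,\sigma_x \sigma_y}
          \exp\!\left(-\frac{1}{2}\left[\left(\frac{x - \mu_x}{\sigma_x}\right)^2
                                      + \left(\frac{y - \mu_y}{\sigma_y}\right)^2\right]\right)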
Example #18
class TrainSettings(Configurable):
    data_loader = State()
    model_saver = State()
    checkpoint = State()
    scheduler = State()
    epochs = State(default=10)

    def __init__(self, **kwargs):
        kwargs['cmd'].update(is_train=True)
        self.load_all(**kwargs)
        if 'epochs' in kwargs['cmd']:
            self.epochs = kwargs['cmd']['epochs']
Example #19
class Structure(Configurable):
    builder = State()
    representer = State()
    measurer = State()
    visualizer = State()

    def __init__(self, **kwargs):
        self.load_all(**kwargs)

    @property
    def model_name(self):
        return self.builder.model_name
Example #20
class ValidationSettings(Configurable):
    data_loaders = State()
    visualize = State()
    interval = State(default=100)
    exempt = State(default=-1)

    def __init__(self, **kwargs):
        kwargs['cmd'].update(is_train=False)
        self.load_all(**kwargs)

        cmd = kwargs['cmd']
        self.visualize = cmd['visualize']
Example #21
class PiecewiseConstantLearningRate(Configurable):
    boundaries = State(default=[10000, 20000])
    values = State(default=[0.001, 0.0001, 0.00001])

    def __init__(self, **kwargs):
        self.load_all(**kwargs)

    def get_learning_rate(self, epoch, step):
        for boundary, value in zip(self.boundaries, self.values[:-1]):
            if step < boundary:
                return value
        return self.values[-1]
Example #22
class WarmupLR(Configurable):
    steps = State(default=4000)
    warmup_lr = State(default=1e-5)
    origin_lr = State()

    def __init__(self, cmd={}, **kwargs):
        self.load_all(**kwargs)

    def get_learning_rate(self, epoch, step):
        if epoch == 0 and step < self.steps:
            return self.warmup_lr
        return self.origin_lr.get_learning_rate(epoch, step)
Example #23
class MakeCenterPoints(DataProcess):
    box_key = State(default='charboxes')
    size = State(default=32)

    def process(self, data):
        shape = data['image'].shape[:2]
        points = np.zeros((self.size, 2), dtype=np.float32)
        boxes = np.array(data[self.box_key])[:self.size]

        size = boxes.shape[0]
        points[:size] = boxes.mean(axis=1)
        data['points'] = (points / shape[::-1]).astype(np.float32)
        return data
Example #24
class FileMetaCache(MetaCache):
    storage_dir = State(default=os.path.join(config.db_path, 'meta_cache'))
    inplace = State(default=not config.will_use_nori)

    def __init__(self, storage_dir=None, cmd={}, **kwargs):
        super(FileMetaCache, self).__init__(cmd=cmd, **kwargs)

        self.debug = cmd.get('debug', False)

    def ensure_dir(self):
        if not os.path.exists(self.storage_dir):
            os.makedirs(self.storage_dir)

    def storage_path(self, nori_path):
        if self.inplace:
            if config.will_use_lmdb:
                return os.path.join(
                    nori_path,
                    'meta_cache-%s.pickle' % self.client)
            else:
                return os.path.join(
                    os.path.dirname(nori_path),
                    'meta_cache-%s.pickle' % self.client)
        return os.path.join(self.storage_dir, self.hash(nori_path) + '.pickle')

    def hash(self, nori_path: str):
        return hashlib.md5(nori_path.encode('utf-8')).hexdigest() + '-' + self.client

    def read(self, nori_path):
        file_path = self.storage_path(nori_path)
        if not os.path.exists(file_path):
            warnings.warn(
                'Meta cache not found: ' + file_path)
            warnings.warn('Now trying to read meta from scratch')
            return None
        with open(file_path, 'rb') as reader:
            try:
                return pickle.load(reader)
            except EOFError as e:  # recover from broken file
                if self.debug:
                    raise e
                return None

    def save(self, nori_path, meta):
        self.ensure_dir()

        with open(self.storage_path(nori_path), 'wb') as writer:
            pickle.dump(meta, writer)
        return True
Example #25
class ResizeImage(_ResizeImage, DataProcess):
    mode = State(default='keep_ratio')
    image_size = State(default=[1152, 2048])  # height, width
    key = State(default='image')

    def __init__(self, cmd={}, mode=None, **kwargs):
        self.load_all(**kwargs)
        if mode is not None:
            self.mode = mode
        if 'resize_mode' in cmd:
            self.mode = cmd['resize_mode']
        assert self.mode in self.MODES

    def process(self, data):
        data[self.key] = self.resize_or_pad(data[self.key])
        return data
Example #26
class CTCVisualizer2D(Configurable):
    eager_show = State(default=False, cmd_key='eager_show')

    def visualize(self, batch, output, interested):
        return self.visualize_batch(batch, output)

    def visualize_batch(self, batch, output):
        visualization = dict()
        for index, output_dict in enumerate(output):
            image = batch['image'][index]
            image = NormalizeImage.restore(image)

            mask = output_dict['mask']
            mask = cv2.resize(Visualize.visualize_weights(mask),
                              image.shape[:2][::-1])

            classify = output_dict['classify']
            classify = cv2.resize(
                Visualize.visualize_heatmap(classify, format='CHW'),
                image.shape[:2][::-1])

            canvas = np.concatenate([image, mask, classify], axis=0)
            key = "【%s-%s】" % (output_dict['label_string'],
                               output_dict['pred_string'])
            vis_dict = {key: canvas}

            if self.eager_show:
                for k, v in vis_dict.items():
                    # if output_dict['label_string'] != output_dict['pred_string']:
                    webcv2.imshow(k, v)
            visualization.update(mask=mask, classify=classify, image=image)
        if self.eager_show:
            webcv2.waitKey()
        return visualization
Example #27
class SimpleEASTRepresenter(SimpleDetectionRepresenter):
    heatmap_thr = State(default=0.5)

    def postprocess(self, output):
        output['heatmask'] = self.threshold_heatmap(output['heatmap'],
                                                    self.heatmap_thr)

    def get_polygons(self, output):
        heatmask = output['heatmask'][0]
        densebox = output['densebox']

        _, contours, _ = cv2.findContours(heatmask.astype(np.uint8),
                                          cv2.RETR_EXTERNAL,
                                          cv2.CHAIN_APPROX_NONE)

        polygons = []
        for contour in contours:
            points = []
            for x, y in contour[:, 0]:
                quad = densebox[:, y, x].reshape(4, 2) + (x, y)
                points.extend(quad)
            quad = cv2.boxPoints(cv2.minAreaRect(np.array(points, np.float32)))
            polygons.append(quad)

        return polygons
Example #28
class EnsembleBuilder(Configurable):
    '''Ensemble multiple models into one model
    Input:
        builders: A dict which consists of several builders.
    Example:
        >>> builder:
                class: EnsembleBuilder
                builders:
                    ctc:
                        model: CTCModel
                    atten:
                        model: AttentionDecoderModel
    '''
    builders = State(default={})

    def __init__(self, cmd={}, **kwargs):
        resume_paths = dict()
        for key, value_dict in kwargs['builders'].items():
            resume_paths[key] = value_dict.pop('resume')
        self.resume_paths = resume_paths
        self.load_all(**kwargs)

    @property
    def model_name(self):
        return 'ensembled-model'

    def build(self, device, *args, **kwargs):
        models = OrderedDict()
        for key, builder in self.builders.items():
            models[key] = builder.build(device=device, *args, **kwargs)
            models[key].load_state_dict(torch.load(self.resume_paths[key],
                                                   map_location=device),
                                        strict=True)
        return EnsembleModel(models)
Example #29
class SequenceRecognitionRepresenter(Configurable):
    charset = State(default=DefaultCharset())

    def __init__(self, cmd={}, **kwargs):
        self.load_all(**kwargs)

    def label_to_string(self, label):
        return self.charset.label_to_string(label)

    def represent(self, batch, pred):
        images, labels = batch['image'], batch['label']
        mask = torch.ones(pred.shape[0], dtype=torch.int).to(pred.device)

        for i in range(pred.shape[1]):
            mask = (1 -
                    (pred[:, i] == self.charset.blank).type(torch.int)) * mask
            pred[:, i] = pred[:, i] * mask + self.charset.blank * (1 - mask)

        output = []
        for i in range(labels.shape[0]):
            label_str = self.label_to_string(labels[i])
            pred_str = self.label_to_string(pred[i])
            if False and label_str != pred_str:
                print('label: %s , pred: %s' % (label_str, pred_str))
                img = (np.clip(
                    images[i].cpu().data.numpy().transpose(1, 2, 0) + 0.5, 0,
                    1) * 255).astype('uint8')
                webcv.imshow(
                    '【 pred: <%s> , label: <%s> 】' % (pred_str, label_str),
                    np.array(img, dtype=np.uint8))
                if webcv.waitKey() == ord('q'):
                    continue
            output.append({'label_string': label_str, 'pred_string': pred_str})

        return output
Example #30
class SequenceRecognitionVisualizer(Configurable):
    charset = State(default=DefaultCharset())

    def __init__(self, cmd={}, **kwargs):
        self.eager = cmd.get('eager_show', False)
        self.load_all(**kwargs)

    def visualize(self, batch, output, interested):
        return self.visualize_batch(batch, output)

    def visualize_batch(self, batch, output):
        images, labels, lengths = batch['image'], batch['label'], batch[
            'length']
        for i in range(images.shape[0]):
            image = NormalizeImage.restore(images[i])
            gt = self.charset.label_to_string(labels[i])
            webcv2.imshow(output[i]['pred_string'] + '_' + str(i) + '_' + gt,
                          image)
            # folder = 'images/dropout/lexicon/'
            # np.save(folder + output[i]['pred_string'] + '_' + gt + '_' + batch['data_ids'][i], image)
        webcv2.waitKey()
        return {
            'image': (np.clip(
                batch['image'][0].cpu().data.numpy().transpose(1, 2, 0) + 0.5,
                0, 1) * 255).astype('uint8')
        }