class FilterKeys(DataProcess):
    """Remove keys from a data dict.

    At most one of the two lists may be configured:

    * ``required``: every listed key must be present in ``data``; all other
      keys are removed.
    * ``superfluous``: the listed keys are removed; everything else is kept.

    If neither list is configured, ``process`` removes every key.
    """
    required = State(default=[])
    superfluous = State(default=[])

    def __init__(self, **kwargs):
        # ``self`` is bound implicitly by ``super()``; passing it again would
        # forward a stray positional argument to the base initializer.
        super().__init__(**kwargs)
        self.required_keys = set(self.required)
        self.superfluous_keys = set(self.superfluous)
        if self.required_keys and self.superfluous_keys:
            raise ValueError(
                'required_keys and superfluous_keys can not be specified at the same time.')

    def process(self, data):
        """Check required keys, then delete superfluous ones in place.

        Args:
            data: the sample dict; it is modified in place.

        Returns:
            The same ``data`` dict.

        Raises:
            AssertionError: if a required key is missing.
        """
        for key in self.required:
            assert key in data, '%s is required in data' % key

        if self.superfluous_keys:
            superfluous = self.superfluous_keys
        else:
            # Build a fresh set per call. The configured set must not be
            # mutated, otherwise keys from earlier samples would leak into
            # later calls.
            superfluous = {key for key in data if key not in self.required_keys}

        for key in superfluous:
            # Explicitly configured superfluous keys may be absent from a sample.
            data.pop(key, None)
        return data
class ShowSettings(Configurable):
    """Settings container for showing samples.

    Groups a data loader, a representer and a visualizer. All three are
    filled in from the configuration by ``load_all``.
    """
    data_loader = State()
    representer = State()
    visualizer = State()

    def __init__(self, **kwargs):
        # Every declared State is resolved from the keyword configuration.
        self.load_all(**kwargs)
class SerializeBox(DataProcess):
    """Expose the quads of ``data['lines']`` under ``data[box_key]``."""
    box_key = State(default='charboxes')
    format = State(default='NP2')  # declared option; not read by ``process``

    def process(self, data):
        """Store ``data['lines'].quads`` in ``data[self.box_key]`` and return ``data``."""
        quads = data['lines'].quads
        data[self.box_key] = quads
        return data
class EvaluationSettings(Configurable):
    """Settings container for evaluation runs.

    Holds the evaluation data loaders, a ``visualize`` flag (default True) and
    a ``resume`` entry. All are populated by ``load_all``.
    """
    data_loaders = State()
    visualize = State(default=True)
    resume = State()

    def __init__(self, **kwargs):
        # Resolve all declared States from the keyword configuration.
        self.load_all(**kwargs)
class RandomSampleDataLoader(Configurable, torch.utils.data.DataLoader):
    """DataLoader that draws samples from several datasets by weight.

    Each sample of ``datasets[i]`` is given probability mass
    ``weights[i] / len(datasets[i])``. The total mass of each dataset therefore
    equals its weight, regardless of the dataset's size. Whether the masses are
    normalised is up to ``RandomSampleSampler`` (TODO confirm).

    ``cmd['batch_size']``, when present, overrides the configured batch size.
    """
    datasets = State()
    weights = State()
    batch_size = State(default=256)
    num_workers = State(default=10)
    size = State(default=2**31)

    def __init__(self, **kwargs):
        self.load_all(**kwargs)
        cmd = kwargs.get('cmd', {})
        if 'batch_size' in cmd:
            self.batch_size = cmd['batch_size']

        if len(self.datasets) != len(self.weights):
            # zip() would silently drop the unmatched tail.
            raise ValueError(
                'datasets and weights must have the same length, got %d and %d'
                % (len(self.datasets), len(self.weights)))

        per_dataset_probs = [
            np.full(len(dataset), weight / len(dataset))
            for dataset, weight in zip(self.datasets, self.weights)
        ]
        # ConcatDataset is PyTorch's built-in way of chaining sub-datasets.
        dataset = ConcatDataset(self.datasets)
        probs = np.concatenate(per_dataset_probs)
        assert len(dataset) == len(probs)

        sampler = RandomSampleSampler(dataset, probs, self.size)
        torch.utils.data.DataLoader.__init__(
            self, dataset,
            batch_size=self.batch_size, num_workers=self.num_workers,
            sampler=sampler, worker_init_fn=default_worker_init_fn,
        )
class ResizeData(_ResizeImage, DataProcess):
    """Resize ``data[key]`` and rescale the boxes in ``data[box_key]`` to match.

    The resize mode comes, in increasing precedence, from the configuration,
    the ``mode`` argument and ``cmd['resize_mode']``.
    """
    key = State(default='image')
    box_key = State(default='polygons')
    image_size = State(default=[64, 256])  # height, width

    def __init__(self, cmd=None, mode=None, key=None, box_key=None, **kwargs):
        cmd = {} if cmd is None else cmd
        self.load_all(**kwargs)
        if mode is not None:
            self.mode = mode
        if key is not None:
            self.key = key
        if box_key is not None:
            self.box_key = box_key
        if 'resize_mode' in cmd:
            self.mode = cmd['resize_mode']
        assert self.mode in self.MODES

    def process(self, data):
        """Resize the image entry and scale box coordinates by the same ratios.

        Box coordinates are scaled against the array that is actually resized
        (``data[self.key]``), so custom keys stay consistent with their boxes.
        """
        image = data[self.key]
        height, width = image.shape[:2]
        new_height, new_width = self.get_image_size(image)
        data[self.key] = self.resize_or_pad(image)

        # Assumes boxes form an array indexed (box, point, coord), with x at
        # coord 0 and y at coord 1. Copy so the caller's array is not mutated.
        boxes = data[self.box_key].copy()
        boxes[:, :, 0] = boxes[:, :, 0] * new_width / width
        boxes[:, :, 1] = boxes[:, :, 1] * new_height / height
        data[self.box_key] = boxes
        return data
class RecognitionMetaLoader(MetaLoader):
    """Parse text-recognition annotations into ``data_ids`` / ``gt`` records.

    Entries whose text is ``'###'`` are dropped. Entries annotated as vertical
    are also dropped when ``skip_vertical`` is set.
    """
    skip_vertical = State(default=False)
    case_sensitive = State(default=False)  # not read by this class
    simplify = State(default=False)
    key = State(default='words')
    scan_meta = True

    def __init__(self, key=None, cmd=None, **kwargs):
        # A fresh dict per instance: the base class may mutate ``cmd``.
        super().__init__(cmd={} if cmd is None else cmd, **kwargs)
        if key is not None:
            self.key = key

    def may_simplify(self, words):
        """Normalise ``words`` with ``stringQ2B``.

        When ``simplify`` is set, the result is additionally converted with
        ``HanziConv.toSimplified``.
        """
        # stringQ2B presumably converts full-width to half-width chars; verify.
        normalized = stringQ2B(words)
        if self.simplify:
            return HanziConv.toSimplified(normalized)
        return normalized

    def parse_meta(self, data_id, meta):
        """Return ``dict(data_ids=..., gt=...)``, or None to skip the entry."""
        annotation = self.get_annotation(meta)
        # Look the text up first so a missing key still raises KeyError.
        raw_word = annotation[self.key]
        if self.skip_vertical and annotation.get('vertical', False):
            return None
        word = self.may_simplify(raw_word)
        if word == '###':
            return None
        return dict(data_ids=data_id, gt=word)
class SimpleTextsnakeRepresenter(SimpleDetectionRepresenter):
    """TextSnake-style decoding: grow each heatmap region by predicted radii."""
    heatmap_thr = State(default=0.5)
    min_area = State(default=200)

    def postprocess(self, output):
        """Store the thresholded heatmap in ``output['heatmask']`` (in place)."""
        output['heatmask'] = self.threshold_heatmap(output['heatmap'],
                                                    self.heatmap_thr)

    def get_polygons(self, output):
        """Return one polygon per heatmask region that survives the filters.

        For every outer contour of the mask, a disc of the predicted radius is
        drawn at each contour pixel with ``radius > 1``. The outline of the
        union is kept when it is a single simple polygon larger than
        ``min_area``.
        """
        heatmask = output['heatmask'][0]
        radius = output['radius']
        # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
        # (contours, hierarchy); index -2 selects the contours in both.
        contours = cv2.findContours(heatmask.astype(np.uint8),
                                    cv2.RETR_EXTERNAL,
                                    cv2.CHAIN_APPROX_NONE)[-2]
        polygons = []
        for contour in contours:
            contour = contour[:, 0]
            mask = np.zeros(heatmask.shape[:2], dtype=np.uint8)
            for x, y in contour:
                r = radius[0, y, x]
                if r > 1:
                    cv2.circle(mask, (int(x), int(y)), int(r), 1, -1)
            conts = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                     cv2.CHAIN_APPROX_NONE)[-2]
            if len(conts) != 1:
                continue
            poly = conts[0][:, 0]
            if len(poly) < 3:
                # shapely cannot build a Polygon from fewer than 3 points.
                continue
            shape = Polygon(poly)
            if shape.geom_type == 'Polygon' and shape.area > self.min_area:
                polygons.append(poly)
        return polygons
class SimpleMSRRepresenter(SimpleDetectionRepresenter):
    """MSR-style decoding: move sampled contour points by predicted offsets."""
    heatmap_thr = State(default=0.5)
    min_area = State(default=200)
    max_poly_points = State(default=32)

    def postprocess(self, output):
        """Store the thresholded heatmap in ``output['heatmask']`` (in place)."""
        output['heatmask'] = self.threshold_heatmap(output['heatmap'],
                                                    self.heatmap_thr)

    def get_polygons(self, output):
        """Return polygons built from offset-corrected contour samples.

        Each outer contour is subsampled to at most ``max_poly_points`` points.
        Each sample is shifted by ``offset[:, y, x]``. The result is kept when
        it has at least 3 points and area above ``min_area``.
        """
        heatmask = output['heatmask'][0]
        offset = output['offset']
        # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
        # (contours, hierarchy); index -2 selects the contours in both.
        contours = cv2.findContours(heatmask.astype(np.uint8),
                                    cv2.RETR_EXTERNAL,
                                    cv2.CHAIN_APPROX_NONE)[-2]
        polygons = []
        for contour in contours:
            contour = contour[:, 0]
            # "+ 1" keeps the stride >= 1 and the sample count <= max_poly_points.
            stride = len(contour) // self.max_poly_points + 1
            poly = [offset[:, y, x] + (x, y) for x, y in contour[::stride]]
            if len(poly) >= 3 and Polygon(poly).area > self.min_area:
                polygons.append(poly)
        return polygons
class BuitlinLearningRate(Configurable):
    """Adapter exposing a ``torch.optim.lr_scheduler`` class via ``get_learning_rate``.

    ``klass`` names the scheduler class. ``args`` / ``kwargs`` are forwarded
    to its constructor in ``prepare``.
    """
    lr = State(default=0.001)
    klass = State(default='StepLR')
    args = State(default=[])
    kwargs = State(default={})

    def __init__(self, cmd=None, **kwargs):
        cmd = {} if cmd is None else cmd
        self.load_all(**kwargs)
        self.lr = cmd.get('lr', None) or self.lr
        self.scheduler = None

    def prepare(self, optimizer):
        """Instantiate the scheduler for ``optimizer``; must precede get_learning_rate."""
        self.scheduler = getattr(lr_scheduler, self.klass)(
            optimizer, *self.args, **self.kwargs)

    def get_learning_rate(self, epoch, step=None):
        """Return the first parameter group's learning rate at ``epoch``.

        Raises:
            Exception: if ``prepare`` has not been called yet.
        """
        if self.scheduler is None:
            raise Exception(
                'learning rate not ready(prepared with optimizer) ')
        self.scheduler.last_epoch = epoch
        # Since PyTorch 1.4, get_lr() of many schedulers is "chainable": it
        # derives the next rate from the optimizer's *current* rate. That gives
        # wrong values once last_epoch is set directly.
        # _get_closed_form_lr() computes the rate from last_epoch alone where
        # the scheduler provides it.
        # Either call returns one rate per parameter group; index 0 is the
        # first group's rate.
        if hasattr(self.scheduler, '_get_closed_form_lr'):
            return self.scheduler._get_closed_form_lr()[0]
        return self.scheduler.get_lr()[0]
class Checkpoint(Configurable):
    """Resume information: a weights file plus starting epoch/iteration counters.

    ``cmd`` entries ``start_epoch``, ``start_iter`` and ``resume`` override the
    configured values.
    """
    start_epoch = State(default=0)
    start_iter = State(default=0)
    resume = State()

    def __init__(self, **kwargs):
        self.load_all(**kwargs)
        cmd = kwargs.get('cmd', {})
        if 'start_epoch' in cmd:
            self.start_epoch = cmd['start_epoch']
        if 'start_iter' in cmd:
            self.start_iter = cmd['start_iter']
        if 'resume' in cmd:
            self.resume = cmd['resume']

    def restore_model(self, model, device, logger):
        """Load weights from ``self.resume`` into ``model`` with ``strict=False``.

        Does nothing when no checkpoint is configured. Only logs a warning when
        the file does not exist.
        """
        if self.resume is None:
            return
        if not os.path.exists(self.resume):
            # Use the logger argument: ``self.logger`` is never set on this class.
            logger.warning("Checkpoint not found: " + self.resume)
            return
        logger.info("Resuming from " + self.resume)
        # torch.load unpickles the file: only resume from trusted checkpoints.
        state_dict = torch.load(self.resume, map_location=device)
        # strict=False tolerates missing/unexpected keys in the checkpoint.
        model.load_state_dict(state_dict, strict=False)
        logger.info("Resumed from " + self.resume)

    def restore_counter(self):
        """Return ``(start_epoch, start_iter)``."""
        return self.start_epoch, self.start_iter
class OptimizerScheduler(Configurable):
    """Build a ``torch.optim`` optimizer and attach the learning-rate policy.

    ``optimizer`` names a class in ``torch.optim``. ``optimizer_args`` are its
    keyword arguments. ``cmd['lr']``, when present, overrides
    ``optimizer_args['lr']``.
    """
    optimizer = State()
    optimizer_args = State(default={})
    learning_rate = State(autoload=False)

    def __init__(self, cmd=None, **kwargs):
        cmd = {} if cmd is None else cmd
        self.load_all(**kwargs)
        self.load('learning_rate', cmd=cmd, **kwargs)
        # Copy before overriding. The State default ``{}`` may be a single
        # object shared between instances, and it must not be mutated.
        self.optimizer_args = dict(self.optimizer_args)
        if 'lr' in cmd:
            self.optimizer_args['lr'] = cmd['lr']

    def create_optimizer(self, parameters):
        """Instantiate the optimizer over ``parameters``.

        Calls ``learning_rate.prepare(optimizer)`` when the policy provides it.
        """
        optimizer = getattr(torch.optim, self.optimizer)(
            parameters, **self.optimizer_args)
        if hasattr(self.learning_rate, 'prepare'):
            self.learning_rate.prepare(optimizer)
        return optimizer
class OcrDataset(data.Dataset, Configurable):
    r'''Dataset reading images and their text-region annotations.

    Args:
        processes: A series of callable objects. Each accepts the data dict as
            its parameter and returns the (updated) data dict. They typically
            inherit the `DataProcess` class (data/processes/data_process.py).
    '''
    data_names = State()
    filter = State()
    is_training = State()
    processes = State(default=[])

    def __init__(self, data_names=None, filter=None, cmd=None, **kwargs):
        cmd = {} if cmd is None else cmd
        self.load_all(**kwargs)
        self.data_names = data_names or self.data_names
        self.filter = filter or self.filter
        # The loaded value is only honoured when passed explicitly.
        self.is_training = self.is_training if 'is_training' in kwargs else False
        self.debug = cmd.get('debug', False)

        # load dataset
        split = 'train' if self.is_training else 'test'
        dataset = get_dataset_by_name(self.data_names, filter=self.filter, split=split)
        dataset.verbose()
        self.dataset = dataset

    def __getitem__(self, index, retry=0):
        '''Load one sample and run it through ``processes``.

        The underlying item is expected to look like
        ``{'img': img, 'type': 'contour', 'bboxes': bboxes, 'tags': tags,
        'path': img_path}``. Only 'path', 'bboxes' and 'tags' are read here.
        Out-of-range indices wrap around modulo the dataset size.

        Raises:
            OSError: if the image file cannot be read.
        '''
        size = self.dataset.size()
        if index >= size:
            index = index % size
        item = self.dataset.getData(index)

        image_path = item['path']
        img = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError('Failed to read image: %s' % image_path)

        data = {}
        # Training keeps the full path; evaluation uses the bare file name.
        name = image_path if self.is_training else os.path.basename(image_path)
        data['filename'] = name
        data['data_id'] = name
        data['image'] = img.astype('float32')

        target = []
        for idx, poly in enumerate(item['bboxes']):
            # '###' marks a region to ignore. Other regions get a placeholder
            # text value.
            text = 1234 if item['tags'][idx] else '###'
            target.append({'poly': poly, 'text': text})
        data['lines'] = target

        if self.processes is not None:
            for data_process in self.processes:
                data = data_process(data)
        return data

    def __len__(self):
        return self.dataset.size()
class NoriMetaLoader(Configurable):
    """Scan a nori dataset and collect per-instance metadata into column lists.

    Subclasses implement ``parse_meta``. Results are read from and written to
    ``cache`` when one is configured, unless ``force_reload`` is set.
    """
    cache = State()
    force_reload = State(default=False)
    scan_meta = True
    scan_data = False
    # Optional callable applied to the collected meta_info
    # (attribute name kept as-is for compatibility with subclasses).
    post_prosess = None

    def __init__(self, force_reload=None, cmd=None, **kwargs):
        cmd = {} if cmd is None else cmd
        self.load_all(cmd=cmd, **kwargs)
        self.force_reload = cmd.get('force_reload', self.force_reload)
        if force_reload is not None:
            self.force_reload = force_reload

    def load_meta(self, nori_path):
        """Return ``{field: [value per valid instance]}`` for ``nori_path``.

        Records for which ``parse_meta`` returns None are skipped.
        """
        if not self.force_reload and self.cache is not None:
            meta = self.cache.read(nori_path)
            if meta is not None:
                return meta

        meta_info = dict()
        valid_count = 0
        with nori.open(nori_path) as reader:
            for data_id, data, meta in reader.scan(
                    scan_data=self.scan_data, scan_meta=self.scan_meta):
                # parse_meta receives only the parts that were scanned.
                args = [data_id]
                if self.scan_data:
                    args.append(data)
                if self.scan_meta:
                    args.append(meta)
                meta_instance = self.parse_meta(*args)
                if meta_instance is None:
                    continue
                valid_count += 1
                if valid_count % 100000 == 0:
                    print("%d instances processd" % valid_count)
                for key, value in meta_instance.items():
                    meta_info.setdefault(key, []).append(value)
        print(valid_count, 'instances found')

        if self.post_prosess is not None:
            meta_info = self.post_prosess(meta_info)
        if self.cache is not None:
            self.cache.save(nori_path, meta_info)
        return meta_info

    def parse_meta(self, data_id, meta):
        """Return a dict of fields for one record, or None to skip it."""
        raise NotImplementedError

    def get_annotation(self, meta):
        return meta['extra']
class DecayLearningRate(Configurable):
    """Polynomial decay: ``lr * (1 - epoch / (epochs + 1)) ** factor``."""
    lr = State(default=0.001)
    epochs = State(default=100)
    factor = State(default=0.9)

    def __init__(self, **kwargs):
        self.load_all(**kwargs)

    def get_learning_rate(self, epoch, step=None):
        """Return the decayed rate for ``epoch``; ``step`` is ignored."""
        progress = epoch / float(self.epochs + 1)
        return self.lr * np.power(1.0 - progress, self.factor)
class MultiStepLR(Configurable):
    """Step decay: ``lr * gamma ** (number of milestones <= epoch)``."""
    lr = State()
    milestones = State(default=[])  # epochs at which the rate is multiplied by gamma
    gamma = State(default=0.1)

    def __init__(self, cmd=None, **kwargs):
        cmd = {} if cmd is None else cmd
        self.load_all(**kwargs)
        self.lr = cmd.get('lr', self.lr)

    def get_learning_rate(self, epoch, step=None):
        """Return the rate for ``epoch``; ``step`` is accepted for API parity.

        ``step`` defaults to None, matching the other schedulers in this file.
        """
        # bisect_right requires sorted input; sorting removes the former
        # "milestones must be sorted" precondition. It is a no-op when the
        # milestones are already sorted.
        return self.lr * self.gamma ** bisect_right(sorted(self.milestones), epoch)
class MakeCenterMap(DataProcess):
    """Render one heatmap channel per character center point into ``data['charmaps']``.

    ``data[points_key]`` holds (x, y) points that are scaled by ``shape``
    reversed (width, height). They are presumably normalised to [0, 1], as
    produced by MakeCenterPoints. The renderer is chosen by ``function_name``.
    """
    max_size = State(default=32)
    shape = State(default=(64, 256))  # (height, width) of each output map
    sigma_ratio = State(default=16)
    function_name = State(default='sample_gaussian')
    points_key = 'points'
    correlation = 0  # The formulation of gaussian is simplified when correlation is 0

    def process(self, data):
        assert self.points_key in data, '%s in data is required' % self.points_key
        points = data[self.points_key] * self.shape[::-1]  # N, 2 in pixels
        assert points.shape[0] >= self.max_size
        func = getattr(self, self.function_name)
        data['charmaps'] = func(points, *self.shape)
        return data

    def gaussian(self, points, height, width):
        """Return an (N, height, width) float32 stack of analytic 2-D Gaussians.

        Each map is peak-normalised to 1. Maps whose point has a zero x or a
        zero y coordinate are zeroed.
        """
        index_x, index_y = np.meshgrid(np.linspace(0, width, width),
                                       np.linspace(0, height, height))
        index_x = np.repeat(index_x[np.newaxis], points.shape[0], axis=0)
        index_y = np.repeat(index_y[np.newaxis], points.shape[0], axis=0)
        mu_x = points[:, 0][:, np.newaxis, np.newaxis]
        mu_y = points[:, 1][:, np.newaxis, np.newaxis]
        # True where neither coordinate is zero (zero marks padding points).
        nonzero_mask = ((mu_x == 0) + (mu_y == 0)) == 0
        result = np.reciprocal(2 * np.pi * width / self.sigma_ratio * height / self.sigma_ratio)\
            * np.exp(- 0.5 * (np.square((index_x - mu_x) / width * self.sigma_ratio) +
                              np.square((index_y - mu_y) / height * self.sigma_ratio)))
        result = result / \
            np.maximum(result.max(axis=1, keepdims=True).max(
                axis=2, keepdims=True), np.finfo(np.float32).eps)
        result = result * nonzero_mask
        return result.astype(np.float32)

    def sample_gaussian(self, points, height, width):
        """Return a (max_size, height, width) float32 stack of blurred point maps.

        Each non-zero point is rendered as a Gaussian blob, normalised to peak 1.
        The blob is then cropped to a window symmetric around the point in x
        and in y.
        """
        points = (points + 0.5).astype(np.int32)
        # Rounded points on the far edge (== width / height) would index out
        # of bounds, and negative ones would silently wrap; clamp both.
        points = np.clip(points, 0, [width - 1, height - 1])
        canvas = np.zeros((self.max_size, height, width), dtype=np.float32)
        for index in range(canvas.shape[0]):
            point = points[index]
            x, y = int(point[0]), int(point[1])
            canvas[index, y, x] = 1.
            if point.sum() > 0:
                fi.gaussian_filter(
                    canvas[index],
                    (height // self.sigma_ratio, width // self.sigma_ratio),
                    output=canvas[index], mode='mirror')
                canvas[index] = canvas[index] / canvas[index].max()
                x_range = min(x, width - x)
                canvas[index, :, :x - x_range] = 0
                canvas[index, :, x + x_range:] = 0
                # The vertical window is bounded by the height (was: width).
                y_range = min(y, height - y)
                canvas[index, :y - y_range, :] = 0
                canvas[index, y + y_range:, :] = 0
        return canvas
class TrainSettings(Configurable):
    """Training components plus the epoch budget.

    ``cmd['epochs']``, when present, overrides the configured epoch count.
    """
    data_loader = State()
    model_saver = State()
    checkpoint = State()
    scheduler = State()
    epochs = State(default=10)

    def __init__(self, **kwargs):
        cmd = kwargs['cmd']
        # Flag training mode on the shared cmd dict before loading.
        # Presumably this lets the components built by load_all see it.
        cmd.update(is_train=True)
        self.load_all(**kwargs)
        if 'epochs' in cmd:
            self.epochs = cmd['epochs']
class Structure(Configurable):
    """Model structure: builder, representer, measurer and visualizer."""
    builder = State()
    representer = State()
    measurer = State()
    visualizer = State()

    def __init__(self, **kwargs):
        self.load_all(**kwargs)

    @property
    def model_name(self):
        """Name reported by the configured builder."""
        return self.builder.model_name
class ValidationSettings(Configurable):
    """Validation components and schedule.

    ``cmd['visualize']``, when present, overrides the configured flag.
    """
    data_loaders = State()
    visualize = State()
    interval = State(default=100)
    exempt = State(default=-1)

    def __init__(self, **kwargs):
        kwargs['cmd'].update(is_train=False)
        self.load_all(**kwargs)
        cmd = kwargs['cmd']
        # Only override when supplied. This avoids a KeyError and keeps the
        # configured value, consistent with how TrainSettings handles 'epochs'.
        if 'visualize' in cmd:
            self.visualize = cmd['visualize']
class PiecewiseConstantLearningRate(Configurable):
    """Step-indexed piecewise-constant rate.

    ``values[i]`` applies while ``step < boundaries[i]``. ``values[-1]``
    applies after the last boundary.
    """
    boundaries = State(default=[10000, 20000])
    values = State(default=[0.001, 0.0001, 0.00001])

    def __init__(self, **kwargs):
        self.load_all(**kwargs)

    def get_learning_rate(self, epoch, step):
        """Return the rate of the first interval containing ``step``; ``epoch`` is ignored."""
        candidates = (value
                      for boundary, value in zip(self.boundaries, self.values[:-1])
                      if step < boundary)
        return next(candidates, self.values[-1])
class WarmupLR(Configurable):
    """Use a fixed warm-up rate for the first ``steps`` steps of epoch 0.

    After warm-up, delegates to ``origin_lr``.
    """
    steps = State(default=4000)
    warmup_lr = State(default=1e-5)
    origin_lr = State()

    def __init__(self, cmd={}, **kwargs):
        self.load_all(**kwargs)

    def get_learning_rate(self, epoch, step):
        in_warmup = epoch == 0 and step < self.steps
        if not in_warmup:
            return self.origin_lr.get_learning_rate(epoch, step)
        return self.warmup_lr
class MakeCenterPoints(DataProcess):
    """Compute box centers, normalised by image size, into ``data['points']``.

    The output is always a ``(size, 2)`` float32 array of (x, y). Unused rows
    (fewer boxes than ``size``) stay zero, and boxes beyond ``size`` are dropped.
    """
    box_key = State(default='charboxes')
    size = State(default=32)

    def process(self, data):
        height, width = data['image'].shape[:2]
        points = np.zeros((self.size, 2), dtype=np.float32)
        boxes = np.asarray(data[self.box_key])[:self.size]
        # An empty box list would otherwise fail in mean(axis=1).
        if len(boxes) > 0:
            # Each center is the mean of the box's corner points.
            points[:len(boxes)] = boxes.mean(axis=1)
        data['points'] = (points / (width, height)).astype(np.float32)
        return data
class FileMetaCache(MetaCache):
    """Pickle-file cache for metadata produced by meta loaders.

    With ``inplace``, the cache file lives next to (or, for lmdb, inside) the
    data path. Otherwise it lives in ``storage_dir`` under an md5-based name.
    """
    storage_dir = State(default=os.path.join(config.db_path, 'meta_cache'))
    inplace = State(default=not config.will_use_nori)

    def __init__(self, storage_dir=None, cmd=None, **kwargs):
        cmd = {} if cmd is None else cmd
        super(FileMetaCache, self).__init__(cmd=cmd, **kwargs)
        # The explicit argument used to be accepted but silently ignored.
        if storage_dir is not None:
            self.storage_dir = storage_dir
        self.debug = cmd.get('debug', False)

    def ensure_dir(self):
        # exist_ok avoids a race when several processes create it at once.
        os.makedirs(self.storage_dir, exist_ok=True)

    def storate_path(self, nori_path):
        """Return the cache file path for ``nori_path``."""
        if self.inplace:
            if config.will_use_lmdb:
                return os.path.join(
                    nori_path, 'meta_cache-%s.pickle' % self.client)
            else:
                return os.path.join(
                    os.path.dirname(nori_path),
                    'meta_cache-%s.pickle' % self.client)
        return os.path.join(self.storage_dir, self.hash(nori_path) + '.pickle')

    def hash(self, nori_path: str):
        return hashlib.md5(nori_path.encode('utf-8')).hexdigest() + '-' + self.client

    def read(self, nori_path):
        """Return the cached meta, or None if the file is missing or unreadable."""
        file_path = self.storate_path(nori_path)
        if not os.path.exists(file_path):
            warnings.warn(
                'Meta cache not found: ' + file_path)
            warnings.warn('Now trying to read meta from scratch')
            return None

        # NOTE: pickle.load can execute arbitrary code; only point this cache
        # at trusted locations.
        with open(file_path, 'rb') as reader:
            try:
                return pickle.load(reader)
            except (EOFError, pickle.UnpicklingError):
                # Recover from a truncated/corrupt file by rebuilding the meta.
                if self.debug:
                    raise
                return None

    def save(self, nori_path, meta):
        """Atomically write ``meta`` to the cache file; returns True."""
        self.ensure_dir()
        file_path = self.storate_path(nori_path)
        tmp_path = file_path + '.tmp'
        # Write to a temporary file and swap it in. An interrupted dump then
        # never leaves a truncated cache behind.
        with open(tmp_path, 'wb') as writer:
            pickle.dump(meta, writer)
        os.replace(tmp_path, file_path)
        return True
class ResizeImage(_ResizeImage, DataProcess):
    """Resize (or pad) ``data[key]`` according to the configured mode."""
    mode = State(default='keep_ratio')
    image_size = State(default=[1152, 2048])  # height, width
    key = State(default='image')

    def __init__(self, cmd={}, mode=None, **kwargs):
        self.load_all(**kwargs)
        # Precedence: cmd['resize_mode'] > explicit ``mode`` > configured value.
        if 'resize_mode' in cmd:
            self.mode = cmd['resize_mode']
        elif mode is not None:
            self.mode = mode
        assert self.mode in self.MODES

    def process(self, data):
        resized = self.resize_or_pad(data[self.key])
        data[self.key] = resized
        return data
class CTCVisualizer2D(Configurable):
    """Visualise 2-D CTC outputs: image, attention mask and class heatmap."""
    eager_show = State(default=False, cmd_key='eager_show')

    def visualize(self, batch, output, interested):
        return self.visualize_batch(batch, output)

    def visualize_batch(self, batch, output):
        """Return the panels of the LAST sample; optionally show every sample.

        When ``eager_show`` is set, each sample's stacked canvas is shown under
        a "【label-pred】" title, and the method then waits for a key press.
        """
        visualization = dict()
        for index, output_dict in enumerate(output):
            image = NormalizeImage.restore(batch['image'][index])
            target_size = image.shape[:2][::-1]
            mask = cv2.resize(
                Visualize.visualize_weights(output_dict['mask']), target_size)
            classify = cv2.resize(
                Visualize.visualize_heatmap(output_dict['classify'], format='CHW'),
                target_size)
            canvas = np.concatenate([image, mask, classify], axis=0)
            title = "【%s-%s】" % (output_dict['label_string'],
                                 output_dict['pred_string'])
            if self.eager_show:
                webcv2.imshow(title, canvas)
            # Overwritten on every iteration: only the last sample survives.
            visualization.update(mask=mask, classify=classify, image=image)
        if self.eager_show:
            webcv2.waitKey()
        return visualization
class SimpleEASTRepresenter(SimpleDetectionRepresenter):
    """EAST-style decoding: fit a rotated box to all quads predicted in a region."""
    heatmap_thr = State(default=0.5)

    def postprocess(self, output):
        """Store the thresholded heatmap in ``output['heatmask']`` (in place)."""
        output['heatmask'] = self.threshold_heatmap(output['heatmap'],
                                                    self.heatmap_thr)

    def get_polygons(self, output):
        """Return one 4-point box per outer contour of the heatmask."""
        heatmask = output['heatmask'][0]
        densebox = output['densebox']
        # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
        # (contours, hierarchy); index -2 selects the contours in both.
        contours = cv2.findContours(heatmask.astype(np.uint8),
                                    cv2.RETR_EXTERNAL,
                                    cv2.CHAIN_APPROX_NONE)[-2]
        polygons = []
        for contour in contours:
            points = []
            for x, y in contour[:, 0]:
                # densebox[:, y, x] holds 4 (x, y) corner offsets relative to
                # this pixel.
                quad = densebox[:, y, x].reshape(4, 2) + (x, y)
                points.extend(quad)
            # The smallest rotated rectangle enclosing every predicted corner.
            quad = cv2.boxPoints(cv2.minAreaRect(np.array(points, np.float32)))
            polygons.append(quad)
        return polygons
class EnsembleBuilder(Configurable):
    '''Ensemble multiple models into one model.

    Input:
        builders: A dict of named builder configurations. Each must carry a
            'resume' entry with the checkpoint to load into that model.
    Example:
        >>> builder:
                class: EnsembleBuilder
                builders:
                    ctc:
                        model: CTCModel
                    atten:
                        model: AttentionDecoderModel
    '''
    builders = State(default={})

    def __init__(self, cmd={}, **kwargs):
        # Pop each sub-builder's 'resume' path out of its config (mutating the
        # passed-in dicts) before load_all sees them; keep the paths for build().
        self.resume_paths = {
            name: builder_cfg.pop('resume')
            for name, builder_cfg in kwargs['builders'].items()
        }
        self.load_all(**kwargs)

    @property
    def model_name(self):
        return 'ensembled-model'

    def build(self, device, *args, **kwargs):
        """Build every sub-model, load its checkpoint strictly, and wrap them."""
        models = OrderedDict()
        for name, builder in self.builders.items():
            model = builder.build(device=device, *args, **kwargs)
            state_dict = torch.load(self.resume_paths[name], map_location=device)
            model.load_state_dict(state_dict, strict=True)
            models[name] = model
        return EnsembleModel(models)
class SequenceRecognitionRepresenter(Configurable):
    """Turn label and prediction index sequences into strings."""
    charset = State(default=DefaultCharset())

    def __init__(self, cmd={}, **kwargs):
        self.load_all(**kwargs)

    def label_to_string(self, label):
        return self.charset.label_to_string(label)

    def represent(self, batch, pred):
        """Return ``[{'label_string', 'pred_string'}, ...]``, one per batch row.

        ``pred`` is indexed as ``pred[:, i]`` (presumably shape
        (batch, seq_len) — TODO confirm). It is modified IN PLACE: every
        position after the first blank of a row is overwritten with the blank
        index.
        """
        labels = batch['label']
        # ``alive`` is 1 for rows that have not yet hit a blank, 0 afterwards.
        alive = torch.ones(pred.shape[0], dtype=torch.int).to(pred.device)
        for i in range(pred.shape[1]):
            alive = (1 - (pred[:, i] == self.charset.blank).type(torch.int)) * alive
            pred[:, i] = pred[:, i] * alive + self.charset.blank * (1 - alive)

        output = []
        for i in range(labels.shape[0]):
            label_str = self.label_to_string(labels[i])
            pred_str = self.label_to_string(pred[i])
            output.append({'label_string': label_str, 'pred_string': pred_str})
        return output
class SequenceRecognitionVisualizer(Configurable):
    """Show each recognised image titled with prediction, index and ground truth."""
    charset = State(default=DefaultCharset())

    def __init__(self, cmd=None, **kwargs):
        cmd = {} if cmd is None else cmd
        # NOTE(review): ``self.eager`` is stored but visualize_batch always
        # shows images; confirm whether that should be gated on it.
        self.eager = cmd.get('eager_show', False)
        self.load_all(**kwargs)

    def visualize(self, batch, output, interested):
        return self.visualize_batch(batch, output)

    def visualize_batch(self, batch, output):
        """Show every image and wait for a key.

        Returns ``{'image': uint8 HWC array}`` for the first image only.
        """
        images, labels = batch['image'], batch['label']
        for i in range(images.shape[0]):
            image = NormalizeImage.restore(images[i])
            gt = self.charset.label_to_string(labels[i])
            webcv2.imshow(output[i]['pred_string'] + '_' + str(i) + '_' + gt, image)
        webcv2.waitKey()
        # CHW tensor -> HWC, shifted by +0.5, clipped to [0, 1], scaled to uint8.
        first = images[0].cpu().data.numpy().transpose(1, 2, 0)
        return {
            'image': (np.clip(first + 0.5, 0, 1) * 255).astype('uint8')
        }