def main(args):
    print(args)

    if args.push_to_hub:
        login_to_hub()

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    torch.backends.cudnn.benchmark = True

    vocab = VOCABS[args.vocab]
    fonts = args.font.split(",")

    # Load val data generator
    st = time.time()
    if isinstance(args.val_path, str):
        with open(os.path.join(args.val_path, "labels.json"), "rb") as f:
            val_hash = hashlib.sha256(f.read()).hexdigest()

        val_set = RecognitionDataset(
            img_folder=os.path.join(args.val_path, "images"),
            labels_path=os.path.join(args.val_path, "labels.json"),
            img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
        )
    else:
        val_hash = None
        # Load synthetic data generator
        val_set = WordGenerator(
            vocab=vocab,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            num_samples=args.val_samples * len(vocab),
            font_family=fonts,
            img_transforms=Compose([
                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
                # Ensure we have a 90% split of white-background images
                T.RandomApply(T.ColorInversion(), 0.9),
            ]),
        )

    val_loader = DataLoader(
        val_set,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=args.workers,
        sampler=SequentialSampler(val_set),
        pin_memory=torch.cuda.is_available(),
        collate_fn=val_set.collate_fn,
    )
    print(
        f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
        f"{len(val_loader)} batches)"
    )

    batch_transforms = Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301))

    # Load doctr model
    model = recognition.__dict__[args.arch](pretrained=args.pretrained, vocab=vocab)

    # Resume weights
    if isinstance(args.resume, str):
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location="cpu")
        model.load_state_dict(checkpoint)

    # GPU
    if isinstance(args.device, int):
        if not torch.cuda.is_available():
            raise AssertionError("PyTorch cannot access your GPU. Please investigate!")
        if args.device >= torch.cuda.device_count():
            raise ValueError("Invalid device index")
    # Silent default switch to GPU if available
    elif torch.cuda.is_available():
        args.device = 0
    else:
        logging.warning("No accessible GPU, target device set to CPU.")
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)
        model = model.cuda()

    # Metrics
    val_metric = TextMatch()

    if args.test_only:
        print("Running evaluation")
        val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp)
        print(f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})")
        return

    st = time.time()
    if isinstance(args.train_path, str):
        # Load train data generator
        base_path = Path(args.train_path)
        parts = (
            [base_path]
            if base_path.joinpath("labels.json").is_file()
            else [base_path.joinpath(sub) for sub in os.listdir(base_path)]
        )
        with open(parts[0].joinpath("labels.json"), "rb") as f:
            train_hash = hashlib.sha256(f.read()).hexdigest()

        train_set = RecognitionDataset(
            parts[0].joinpath("images"),
            parts[0].joinpath("labels.json"),
            img_transforms=Compose([
                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
                # Augmentations
                T.RandomApply(T.ColorInversion(), 0.1),
                ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
            ]),
        )
        if len(parts) > 1:
            for subfolder in parts[1:]:
                train_set.merge_dataset(
                    RecognitionDataset(subfolder.joinpath("images"), subfolder.joinpath("labels.json"))
                )
    else:
        train_hash = None
        # Load synthetic data generator
        train_set = WordGenerator(
            vocab=vocab,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            num_samples=args.train_samples * len(vocab),
            font_family=fonts,
            img_transforms=Compose([
                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
                # Ensure we have a 90% split of white-background images
                T.RandomApply(T.ColorInversion(), 0.9),
                ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
            ]),
        )

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        drop_last=True,
        num_workers=args.workers,
        sampler=RandomSampler(train_set),
        pin_memory=torch.cuda.is_available(),
        collate_fn=train_set.collate_fn,
    )
    print(
        f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
        f"{len(train_loader)} batches)"
    )

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, target)
        return

    # Optimizer
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        args.lr,
        betas=(0.95, 0.99),
        eps=1e-6,
        weight_decay=args.weight_decay,
    )

    # LR Finder
    if args.find_lr:
        lrs, losses = record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp)
        plot_recorder(lrs, losses)
        return

    # Scheduler
    if args.sched == "cosine":
        scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4)
    elif args.sched == "onecycle":
        scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader))

    # Training monitoring
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:
        run = wandb.init(
            name=exp_name,
            project="text-recognition",
            config={
                "learning_rate": args.lr,
                "epochs": args.epochs,
                "weight_decay": args.weight_decay,
                "batch_size": args.batch_size,
                "architecture": args.arch,
                "input_size": args.input_size,
                "optimizer": "adam",
                "framework": "pytorch",
                "scheduler": args.sched,
                "vocab": args.vocab,
                "train_hash": train_hash,
                "val_hash": val_hash,
                "pretrained": args.pretrained,
            },
        )

    # Create loss queue
    min_loss = np.inf

    # Training loop
    mb = master_bar(range(args.epochs))
    for epoch in mb:
        fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb, amp=args.amp)

        # Validation loop at the end of each epoch
        val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp)
        if val_loss < min_loss:
            print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
            torch.save(model.state_dict(), f"./{exp_name}.pt")
            min_loss = val_loss
        mb.write(
            f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
            f"(Exact: {exact_match:.2%} | Partial: {partial_match:.2%})"
        )
        # W&B
        if args.wb:
            wandb.log({
                "val_loss": val_loss,
                "exact_match": exact_match,
                "partial_match": partial_match,
            })

    if args.wb:
        run.finish()

    if args.push_to_hub:
        push_to_hf_hub(model, exp_name, task="recognition", run_config=args)
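
# The labels.json digests computed above (train_hash / val_hash) only serve as
# dataset fingerprints logged to the W&B run config. A minimal standalone
# sketch of the same idea; the path and helper name here are hypothetical:

import hashlib

def dataset_fingerprint(labels_path: str) -> str:
    # Hash the raw bytes of the label file, so any edit changes the digest
    with open(labels_path, "rb") as f:
        return hashlib.sha256(f.read()).hexdigest()

# e.g. dataset_fingerprint("./val_data/labels.json")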
def main(args):
    if not args.rotation:
        args.eval_straight = True

    predictor = ocr_predictor(
        args.detection,
        args.recognition,
        pretrained=True,
        reco_bs=args.batch_size,
        assume_straight_pages=not args.rotation,
    )

    if args.img_folder and args.label_file:
        testset = datasets.OCRDataset(
            img_folder=args.img_folder,
            label_file=args.label_file,
        )
        sets = [testset]
    else:
        train_set = datasets.__dict__[args.dataset](train=True, download=True, use_polygons=not args.eval_straight)
        val_set = datasets.__dict__[args.dataset](train=False, download=True, use_polygons=not args.eval_straight)
        sets = [train_set, val_set]

    reco_metric = TextMatch()
    if args.mask_shape:
        det_metric = LocalizationConfusion(
            iou_thresh=args.iou, use_polygons=not args.eval_straight, mask_shape=(args.mask_shape, args.mask_shape)
        )
        e2e_metric = OCRMetric(
            iou_thresh=args.iou, use_polygons=not args.eval_straight, mask_shape=(args.mask_shape, args.mask_shape)
        )
    else:
        det_metric = LocalizationConfusion(iou_thresh=args.iou, use_polygons=not args.eval_straight)
        e2e_metric = OCRMetric(iou_thresh=args.iou, use_polygons=not args.eval_straight)

    sample_idx = 0
    extraction_fn = extract_crops if args.eval_straight else extract_rcrops

    for dataset in sets:
        for page, target in tqdm(dataset):
            # GT
            gt_boxes = target["boxes"]
            gt_labels = target["labels"]

            if args.img_folder and args.label_file:
                x, y, w, h = gt_boxes[:, 0], gt_boxes[:, 1], gt_boxes[:, 2], gt_boxes[:, 3]
                xmin, ymin = np.clip(x - w / 2, 0, 1), np.clip(y - h / 2, 0, 1)
                xmax, ymax = np.clip(x + w / 2, 0, 1), np.clip(y + h / 2, 0, 1)
                gt_boxes = np.stack([xmin, ymin, xmax, ymax], axis=-1)

            # Forward
            if is_tf_available():
                out = predictor(page[None, ...])
                crops = extraction_fn(page, gt_boxes)
                reco_out = predictor.reco_predictor(crops)
            else:
                with torch.no_grad():
                    out = predictor(page[None, ...])
                    # We directly crop on PyTorch tensors, which are in channels_first
                    crops = extraction_fn(page, gt_boxes, channels_last=False)
                    reco_out = predictor.reco_predictor(crops)

            if len(reco_out):
                reco_words, _ = zip(*reco_out)
            else:
                reco_words = []

            # Unpack preds
            pred_boxes = []
            pred_labels = []
            for page in out.pages:
                height, width = page.dimensions
                for block in page.blocks:
                    for line in block.lines:
                        for word in line.words:
                            if not args.rotation:
                                (a, b), (c, d) = word.geometry
                            else:
                                [x1, y1], [x2, y2], [x3, y3], [x4, y4] = word.geometry
                            if gt_boxes.dtype == int:
                                if not args.rotation:
                                    pred_boxes.append(
                                        [int(a * width), int(b * height), int(c * width), int(d * height)]
                                    )
                                else:
                                    if args.eval_straight:
                                        pred_boxes.append([
                                            int(width * min(x1, x2, x3, x4)),
                                            int(height * min(y1, y2, y3, y4)),
                                            int(width * max(x1, x2, x3, x4)),
                                            int(height * max(y1, y2, y3, y4)),
                                        ])
                                    else:
                                        pred_boxes.append([
                                            [int(x1 * width), int(y1 * height)],
                                            [int(x2 * width), int(y2 * height)],
                                            [int(x3 * width), int(y3 * height)],
                                            [int(x4 * width), int(y4 * height)],
                                        ])
                            else:
                                if not args.rotation:
                                    pred_boxes.append([a, b, c, d])
                                else:
                                    if args.eval_straight:
                                        pred_boxes.append([
                                            min(x1, x2, x3, x4),
                                            min(y1, y2, y3, y4),
                                            max(x1, x2, x3, x4),
                                            max(y1, y2, y3, y4),
                                        ])
                                    else:
                                        pred_boxes.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
                            pred_labels.append(word.value)

            # Update the metric
            det_metric.update(gt_boxes, np.asarray(pred_boxes))
            reco_metric.update(gt_labels, reco_words)
            e2e_metric.update(gt_boxes, np.asarray(pred_boxes), gt_labels, pred_labels)

            # Loop break
            sample_idx += 1
            if isinstance(args.samples, int) and args.samples == sample_idx:
                break
        if isinstance(args.samples, int) and args.samples == sample_idx:
            break

    # Unpack aggregated metrics
    print(
        f"Model Evaluation (model={args.detection} + {args.recognition}, "
        f"dataset={'OCRDataset' if args.img_folder else args.dataset})"
    )
    recall, precision, mean_iou = det_metric.summary()
    print(f"Text Detection - Recall: {_pct(recall)}, Precision: {_pct(precision)}, Mean IoU: {_pct(mean_iou)}")
    acc = reco_metric.summary()
    print(f"Text Recognition - Accuracy: {_pct(acc['raw'])} (unicase: {_pct(acc['unicase'])})")
    recall, precision, mean_iou = e2e_metric.summary()
    print(
        f"OCR - Recall: {_pct(recall['raw'])} (unicase: {_pct(recall['unicase'])}), "
        f"Precision: {_pct(precision['raw'])} (unicase: {_pct(precision['unicase'])}), Mean IoU: {_pct(mean_iou)}"
    )
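
# The _pct formatter used in the prints above is not defined in this excerpt.
# A minimal sketch consistent with its usage (metric summaries may return None
# when a quantity is undefined, e.g. when no samples were seen):

def _pct(val):
    # Render a ratio as a percentage, tolerating undefined metrics
    return "N/A" if val is None else f"{val:.2%}"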
def main(args):
    print(args)

    if args.push_to_hub:
        login_to_hub()

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    vocab = VOCABS[args.vocab]
    fonts = args.font.split(",")

    # AMP
    if args.amp:
        mixed_precision.set_global_policy("mixed_float16")

    st = time.time()
    if isinstance(args.val_path, str):
        with open(os.path.join(args.val_path, "labels.json"), "rb") as f:
            val_hash = hashlib.sha256(f.read()).hexdigest()

        # Load val data generator
        val_set = RecognitionDataset(
            img_folder=os.path.join(args.val_path, "images"),
            labels_path=os.path.join(args.val_path, "labels.json"),
            img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
        )
    else:
        val_hash = None
        # Load synthetic data generator
        val_set = WordGenerator(
            vocab=vocab,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            num_samples=args.val_samples * len(vocab),
            font_family=fonts,
            img_transforms=T.Compose([
                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
                # Ensure we have a 90% split of white-background images
                T.RandomApply(T.ColorInversion(), 0.9),
            ]),
        )

    val_loader = DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.workers,
    )
    print(
        f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
        f"{val_loader.num_batches} batches)"
    )

    # Load doctr model
    model = recognition.__dict__[args.arch](
        pretrained=args.pretrained,
        input_shape=(args.input_size, 4 * args.input_size, 3),
        vocab=vocab,
    )
    # Resume weights
    if isinstance(args.resume, str):
        model.load_weights(args.resume)

    # Metrics
    val_metric = TextMatch()

    batch_transforms = T.Compose([
        T.Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301)),
    ])

    if args.test_only:
        print("Running evaluation")
        val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric)
        print(f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})")
        return

    st = time.time()
    if isinstance(args.train_path, str):
        # Load train data generator
        base_path = Path(args.train_path)
        parts = (
            [base_path]
            if base_path.joinpath("labels.json").is_file()
            else [base_path.joinpath(sub) for sub in os.listdir(base_path)]
        )
        with open(parts[0].joinpath("labels.json"), "rb") as f:
            train_hash = hashlib.sha256(f.read()).hexdigest()

        train_set = RecognitionDataset(
            parts[0].joinpath("images"),
            parts[0].joinpath("labels.json"),
            img_transforms=T.Compose([
                T.RandomApply(T.ColorInversion(), 0.1),
                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
                # Augmentations
                T.RandomJpegQuality(60),
                T.RandomSaturation(0.3),
                T.RandomContrast(0.3),
                T.RandomBrightness(0.3),
            ]),
        )
        if len(parts) > 1:
            for subfolder in parts[1:]:
                train_set.merge_dataset(
                    RecognitionDataset(subfolder.joinpath("images"), subfolder.joinpath("labels.json"))
                )
    else:
        train_hash = None
        # Load synthetic data generator
        train_set = WordGenerator(
            vocab=vocab,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            num_samples=args.train_samples * len(vocab),
            font_family=fonts,
            img_transforms=T.Compose([
                T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
                # Ensure we have a 90% split of white-background images
                T.RandomApply(T.ColorInversion(), 0.9),
                T.RandomJpegQuality(60),
                T.RandomSaturation(0.3),
                T.RandomContrast(0.3),
                T.RandomBrightness(0.3),
            ]),
        )

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
    )
    print(
        f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
        f"{train_loader.num_batches} batches)"
    )

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, target)
        return

    # Optimizer
    scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
        args.lr,
        decay_steps=args.epochs * len(train_loader),
        decay_rate=1 / (25e4),  # final lr as a fraction of initial lr
        staircase=False,
    )
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=scheduler,
        beta_1=0.95,
        beta_2=0.99,
        epsilon=1e-6,
        clipnorm=5,
    )
    if args.amp:
        optimizer = mixed_precision.LossScaleOptimizer(optimizer)

    # LR Finder
    if args.find_lr:
        lrs, losses = record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp)
        plot_recorder(lrs, losses)
        return

    # Tensorboard to monitor training
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:
        run = wandb.init(
            name=exp_name,
            project="text-recognition",
            config={
                "learning_rate": args.lr,
                "epochs": args.epochs,
                "weight_decay": 0.0,
                "batch_size": args.batch_size,
                "architecture": args.arch,
                "input_size": args.input_size,
                "optimizer": "adam",
                "framework": "tensorflow",
                "scheduler": "exp_decay",
                "vocab": args.vocab,
                "train_hash": train_hash,
                "val_hash": val_hash,
                "pretrained": args.pretrained,
            },
        )

    min_loss = np.inf

    # Training loop
    mb = master_bar(range(args.epochs))
    for epoch in mb:
        fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb, args.amp)

        # Validation loop at the end of each epoch
        val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric)
        if val_loss < min_loss:
            print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
            model.save_weights(f"./{exp_name}/weights")
            min_loss = val_loss
        mb.write(
            f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
            f"(Exact: {exact_match:.2%} | Partial: {partial_match:.2%})"
        )
        # W&B
        if args.wb:
            wandb.log({
                "val_loss": val_loss,
                "exact_match": exact_match,
                "partial_match": partial_match,
            })

    if args.wb:
        run.finish()

    if args.push_to_hub:
        push_to_hf_hub(model, exp_name, task="recognition", run_config=args)
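
# The decay_rate of 1 / 25e4 above drives the learning rate down to
# initial_lr / 250000 by the last step, mirroring the eta_min=args.lr / 25e4
# used with CosineAnnealingLR in the PyTorch script. A quick sanity check of
# the tf.keras ExponentialDecay formula
#   lr(step) = initial_lr * decay_rate ** (step / decay_steps)
# with illustrative values (not the script's defaults):

initial_lr, total_steps = 1e-3, 10_000
decay_rate = 1 / 25e4
final_lr = initial_lr * decay_rate ** (total_steps / total_steps)  # at the last step
assert abs(final_lr - initial_lr / 25e4) < 1e-15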
def main(args):
    print(args)

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    # AMP
    if args.amp:
        mixed_precision.set_global_policy("mixed_float16")

    # Load doctr model
    model = recognition.__dict__[args.arch](
        pretrained=True if args.resume is None else False,
        input_shape=(args.input_size, 4 * args.input_size, 3),
        vocab=VOCABS[args.vocab],
    )

    # Resume weights
    if isinstance(args.resume, str):
        model.load_weights(args.resume)

    st = time.time()
    ds = datasets.__dict__[args.dataset](
        train=True,
        download=True,
        recognition_task=True,
        use_polygons=args.regular,
        img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
    )

    _ds = datasets.__dict__[args.dataset](
        train=False,
        download=True,
        recognition_task=True,
        use_polygons=args.regular,
        img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
    )
    ds.data.extend([(np_img, target) for np_img, target in _ds.data])

    test_loader = DataLoader(
        ds,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=args.workers,
        shuffle=False,
    )
    print(
        f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in "
        f"{len(test_loader)} batches)"
    )

    mean, std = model.cfg["mean"], model.cfg["std"]
    batch_transforms = T.Normalize(mean=mean, std=std)

    # Metrics
    val_metric = TextMatch()

    print("Running evaluation")
    val_loss, exact_match, partial_match = evaluate(model, test_loader, batch_transforms, val_metric)
    print(f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})")
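
# A hypothetical invocation of this evaluation script, inferred from the args
# attributes read above; the actual flag spellings live in the argparse setup,
# which is not shown here, so treat these as assumptions:
#   python evaluate_tensorflow.py crnn_vgg16_bn --dataset FUNSD --batch_size 64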
def main(args):
    print(args)

    torch.backends.cudnn.benchmark = True

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    # Load doctr model
    model = recognition.__dict__[args.arch](
        pretrained=True if args.resume is None else False,
        input_shape=(3, args.input_size, 4 * args.input_size),
        vocab=VOCABS[args.vocab],
    ).eval()

    # Resume weights
    if isinstance(args.resume, str):
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location="cpu")
        model.load_state_dict(checkpoint)

    st = time.time()
    ds = datasets.__dict__[args.dataset](
        train=True,
        download=True,
        recognition_task=True,
        use_polygons=args.regular,
        img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
    )

    _ds = datasets.__dict__[args.dataset](
        train=False,
        download=True,
        recognition_task=True,
        use_polygons=args.regular,
        img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
    )
    ds.data.extend([(np_img, target) for np_img, target in _ds.data])

    test_loader = DataLoader(
        ds,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=args.workers,
        sampler=SequentialSampler(ds),
        pin_memory=torch.cuda.is_available(),
        collate_fn=ds.collate_fn,
    )
    print(
        f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in "
        f"{len(test_loader)} batches)"
    )

    mean, std = model.cfg["mean"], model.cfg["std"]
    batch_transforms = Normalize(mean=mean, std=std)

    # Metrics
    val_metric = TextMatch()

    # GPU
    if isinstance(args.device, int):
        if not torch.cuda.is_available():
            raise AssertionError("PyTorch cannot access your GPU. Please investigate!")
        if args.device >= torch.cuda.device_count():
            raise ValueError("Invalid device index")
    # Silent default switch to GPU if available
    elif torch.cuda.is_available():
        args.device = 0
    else:
        print("No accessible GPU, target device set to CPU.")

    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)
        model = model.cuda()

    print("Running evaluation")
    val_loss, exact_match, partial_match = evaluate(model, test_loader, batch_transforms, val_metric, amp=args.amp)
    print(f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})")
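
# The device-selection block above, restated as a standalone sketch: an explicit
# index is validated, otherwise the script silently prefers GPU 0 and falls back
# to CPU. pick_device is a hypothetical helper, not part of the library:

import torch

def pick_device(requested):
    if isinstance(requested, int):
        if not torch.cuda.is_available():
            raise AssertionError("PyTorch cannot access your GPU. Please investigate!")
        if requested >= torch.cuda.device_count():
            raise ValueError("Invalid device index")
        return torch.device(f"cuda:{requested}")
    # Silent default switch to GPU if available
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")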