def test_TextSnake(config_file):
    import sys
    sys.path.append('./detection_model/TextSnake_pytorch')

    import os
    import time
    import numpy as np
    import torch
    import json
    import torch.backends.cudnn as cudnn
    import torch.utils.data as data
    import torch.nn.functional as func
    from dataset.total_text import TotalText
    from network.textnet import TextNet
    from util.detection import TextDetector
    from util.augmentation import BaseTransform, EvalTransform
    from util.config import config as cfg, update_config, print_config
    from util.misc import to_device, fill_hole
    from util.option import BaseOptions
    from util.visualize import visualize_detection
    from Evaluation.Detval import detval
    import cv2
    from yacs.config import CfgNode as CN

    def read_config_file(config_file):
        """ read config information from yaml file """
        f = open(config_file)
        opt = CN.load_cfg(f)
        return opt

    opt = read_config_file(config_file)

    def result2polygon(image, result):
        """ convert geometric info (center_x, center_y, radii) into contours
        :param result: (list) of instances, each a list of disk arrays with
                       shape (n, 3), 3 denotes (x, y, radius)
        :return: (np.ndarray list), polygon format contours
        """
        conts = []
        for instance in result:
            mask = np.zeros(image.shape[:2], dtype=np.uint8)
            for disk in instance:
                for x, y, r in disk:
                    cv2.circle(mask, (int(x), int(y)), int(r), (1), -1)
            cont, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                       cv2.CHAIN_APPROX_SIMPLE)
            if len(cont) > 0:
                # for item in cont:
                #     conts.append(item)
                conts.append(cont[0])
        conts = [cont[:, 0, :] for cont in conts]
        return conts

    def mask2conts(mask):
        conts, _ = cv2.findContours(mask, cv2.RETR_TREE,
                                    cv2.CHAIN_APPROX_SIMPLE)
        conts = [cont[:, 0, :] for cont in conts]
        return conts

    def rescale_result(image, polygons, H, W):
        ori_H, ori_W = image.shape[:2]
        image = cv2.resize(image, (W, H))
        for polygon in polygons:
            cont = polygon['points']
            cont[:, 0] = (cont[:, 0] * W / ori_W).astype(int)
            cont[:, 1] = (cont[:, 1] * H / ori_H).astype(int)
            polygon['points'] = cont
        return image, polygons

    def rescale_padding_result(image, polygons, ori_h, ori_w):
        h, w = image.shape[:2]

        # get no-padding image size
        resize_h = ori_h if ori_h % 32 == 0 else (ori_h // 32) * 32
        resize_w = ori_w if ori_w % 32 == 0 else (ori_w // 32) * 32
        ratio = float(h) / resize_h if resize_h > resize_w else float(w) / resize_w
        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)

        # crop no-padding image
        no_padding_image = image[0:resize_h, 0:resize_w, ::-1]
        no_padding_image = cv2.resize(no_padding_image, (ori_w, ori_h))

        # rescale points
        for polygon in polygons:
            polygon['points'][:, 0] = (polygon['points'][:, 0] *
                                       float(ori_w) / resize_w).astype(np.int32)
            polygon['points'][:, 1] = (polygon['points'][:, 1] *
                                       float(ori_h) / resize_h).astype(np.int32)

        return no_padding_image, polygons

    def calc_confidence(contours, score_map):
        polygons = []
        for cnt in contours:
            drawing = np.zeros(score_map.shape[1:], np.int8)
            mask = cv2.fillPoly(drawing, [cnt.astype(np.int32)], 1)
            area = np.sum(np.greater(mask, 0))
            if not area > 0:
                continue
            confidence = np.sum(mask * score_map[0]) / area
            polygon = {'points': cnt, 'confidence': confidence}
            polygons.append(polygon)
        return polygons

    def load_model(model, model_path):
        """ load pretrained model

        Args:
            model: the model to load weights into
            model_path: the path to the checkpoint
        """
        print('Loading from {}'.format(model_path))
        state_dict = torch.load(model_path)
        model.load_state_dict(state_dict['model'])

    def inference(model, detector, test_loader):
        """ start inference with the parameters provided earlier """
        gt_json_path = os.path.join('/home/shf/fudan_ocr_system/datasets/',
                                    opt.dataset, 'train_labels.json')
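        # The ground-truth file is assumed (from the int indexing and 'points'
        # access below) to be a JSON array indexed by the numeric part of the
        # image id, each entry carrying at least a 'points' polygon in
        # original-image coordinates, e.g. a hypothetical entry:
        #     {"points": [[10, 20], [200, 20], [200, 80], [10, 80]]}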
        # gt_json_path = '/workspace/mnt/group/ocr/wangxunyan/maskscoring_rcnn/crop_train/crop_result_js.json'
        with open(gt_json_path, 'r') as f:
            gt_dict = json.load(f)

        model.eval()
        result = dict()

        for i, (img, train_mask, tr_mask, tcl_mask, radius_map, sin_map,
                cos_map, meta) in enumerate(test_loader):
            timer = {'model': 0, 'detect': 0, 'viz': 0, 'restore': 0}
            start = time.time()

            img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map = to_device(
                img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map)

            # inference
            output = model(img)
            if opt.multi_scale:
                size_h, size_w = img.shape[2:4]
                img_rescale = func.interpolate(img, scale_factor=0.5, mode='nearest')
                output_rescale = model(img_rescale)
                output_rescale = func.interpolate(output_rescale,
                                                  size=(size_h, size_w),
                                                  mode='nearest')

            timer['model'] = time.time() - start

            for idx in range(img.size(0)):
                start = time.time()
                print('detect {} / {} images: {}.'.format(
                    i, len(test_loader), meta['image_id'][idx]))

                tr_pred = output[idx, 0:2].softmax(dim=0).data.cpu().numpy()
                tcl_pred = output[idx, 2:4].softmax(dim=0).data.cpu().numpy()
                sin_pred = output[idx, 4].data.cpu().numpy()
                cos_pred = output[idx, 5].data.cpu().numpy()
                radii_pred = output[idx, 6].data.cpu().numpy()

                # tr_pred_mask = 1 / (1 + np.exp(-12 * tr_pred[1] + 3))
                tr_pred_mask = np.where(tr_pred[1] > detector.tr_conf_thresh, 1,
                                        tr_pred[1])
                # tr_pred_mask = fill_hole(tr_pred_mask)
                tcl_pred_mask = (tcl_pred * tr_pred_mask)[1] > detector.tcl_conf_thresh

                if opt.multi_scale:
                    tr_pred_rescale = output_rescale[idx, 0:2].sigmoid().data.cpu().numpy()
                    tcl_pred_rescale = output_rescale[idx, 2:4].softmax(dim=0).data.cpu().numpy()

                    tr_pred_scale_mask = np.where(
                        tr_pred_rescale[1] + tr_pred[1] > 1, 1,
                        tr_pred_rescale[1] + tr_pred[1])
                    tr_pred_mask = tr_pred_scale_mask

                    # weighted adding
                    origin_ratio = 0.5
                    rescale_ratio = 0.5
                    tcl_pred = (tcl_pred * origin_ratio +
                                tcl_pred_rescale * rescale_ratio).astype(np.float32)

                    tcl_pred_mask = (tcl_pred * tr_pred_mask)[1] > detector.tcl_conf_thresh

                batch_result = detector.complete_detect(
                    tr_pred_mask, tcl_pred_mask, sin_pred, cos_pred,
                    radii_pred)  # (n_tcl, 3)
                timer['detect'] = time.time() - start

                start = time.time()

                # visualization
                img_show = img[idx].permute(1, 2, 0).cpu().numpy()
                img_show = ((img_show * opt.stds + opt.means) * 255).astype(np.uint8)
                H, W = meta['Height'][idx].item(), meta['Width'][idx].item()

                # get pred_contours
                contours = result2polygon(img_show, batch_result)

                if opt.viz:
                    resize_H = H if H % 32 == 0 else (H // 32) * 32
                    resize_W = W if W % 32 == 0 else (W // 32) * 32
                    ratio = float(img_show.shape[0]) / resize_H \
                        if resize_H > resize_W else float(img_show.shape[1]) / resize_W
                    resize_H = int(resize_H * ratio)
                    resize_W = int(resize_W * ratio)

                    # note: str.lstrip/rstrip strip character *sets*, not fixed
                    # prefixes/suffixes, so this parsing relies on the dataset's
                    # exact 'gt_*.jpg' image-id naming
                    gt_info = gt_dict[int(meta['image_id'][idx].lstrip('gt_').rstrip('.jpg').split('_')[1])]
                    gt_contours = []
                    # for gt in gt_info:
                    #     if not gt['illegibility']:
                    #         gt_cont = np.array(gt['points'])
                    #         gt_cont[:, 0] = (gt_cont[:, 0] * float(resize_W) / W).astype(np.int32)
                    #         gt_cont[:, 1] = (gt_cont[:, 1] * float(resize_H) / H).astype(np.int32)
                    #         gt_contours.append(gt_cont)
                    gt_cont = np.array(gt_info['points'])
                    gt_cont[:, 0] = gt_cont[:, 0] * float(resize_W) / float(W)
                    gt_cont[:, 1] = gt_cont[:, 1] * float(resize_H) / float(H)
                    gt_contours.append(gt_cont.astype(np.int32))
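                    # Worked example of the rescale arithmetic above (a sketch,
                    # assuming EvalTransform pads the resized image to a square
                    # 1280x1280 network input): for an original H=1000, W=760
                    # image, resize_H = (1000 // 32) * 32 = 992 and
                    # resize_W = (760 // 32) * 32 = 736; since resize_H > resize_W,
                    # ratio = 1280 / 992 ~= 1.29, giving resize_H = 1280 and
                    # resize_W = 949. Ground-truth x coordinates are therefore
                    # scaled by 949 / 760 and y coordinates by 1280 / 1000,
                    # mapping them into the same padded input frame the
                    # predictions live in.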
                    illegal_contours = mask2conts(meta['illegal_mask'][idx].cpu().numpy())

                    predict_vis = visualize_detection(
                        img_show, tr_pred_mask, tcl_pred_mask.astype(np.uint8),
                        contours.copy())
                    gt_vis = visualize_detection(img_show,
                                                 tr_mask[idx].cpu().numpy(),
                                                 tcl_mask[idx].cpu().numpy(),
                                                 gt_contours, illegal_contours)
                    im_vis = np.concatenate([predict_vis, gt_vis], axis=0)
                    path = os.path.join(opt.vis_dir, meta['image_id'][idx])
                    cv2.imwrite(path, im_vis)

                timer['viz'] = time.time() - start

                start = time.time()
                polygons = calc_confidence(contours, tr_pred)
                img_show, polygons = rescale_padding_result(img_show, polygons, H, W)

                # filter out too-small polygons; use a separate index so the
                # batch counter `i` is not shadowed
                for j, poly in enumerate(polygons):
                    if cv2.contourArea(poly['points']) < 100:
                        polygons[j] = []
                polygons = [item for item in polygons if item != []]

                # convert np.array to list
                for polygon in polygons:
                    polygon['points'] = polygon['points'].tolist()

                result[meta['image_id'][idx].replace('.jpg', '').replace('gt', 'res')] = polygons
                timer['restore'] = time.time() - start

                print(
                    'Cost time {:.2f}s: model {:.2f}s, detect {:.2f}s, viz {:.2f}s, restore {:.2f}s'
                    .format(
                        timer['model'] + timer['detect'] + timer['viz'] + timer['restore'],
                        timer['model'], timer['detect'], timer['viz'],
                        timer['restore']))

        # write to json file
        with open(os.path.join(opt.output_dir, 'result.json'), 'w') as f:
            json.dump(result, f)
        print("Output json file in {}.".format(opt.output_dir))

    torch.cuda.set_device(opt.num_device)

    option = BaseOptions(config_file)
    args = option.initialize()

    update_config(opt, args)
    print_config(opt)

    data_root = os.path.join(opt.data_root, opt.dataset)
    testset = TotalText(
        data_root=data_root,
        ignore_list=os.path.join(data_root, 'ignore_list.txt'),
        is_training=False,
        transform=EvalTransform(size=1280, mean=opt.means, std=opt.stds)
        # transform=BaseTransform(size=1280, mean=opt.means, std=opt.stds)
    )
    test_loader = data.DataLoader(testset,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=opt.num_workers)

    # Model
    model = TextNet(backbone=opt.backbone, output_channel=7)
    model_path = os.path.join(
        opt.save_dir, opt.exp_name,
        'textsnake_{}_{}.pth'.format(model.backbone_name, opt.checkepoch))
    load_model(model, model_path)

    # copy to cuda
    model = model.to(opt.device)
    if opt.cuda:
        cudnn.benchmark = True
    detector = TextDetector(tcl_conf_thresh=0.3, tr_conf_thresh=1.0)  # 0.3

    # check vis_dir and output_dir exist
    if opt.viz:
        if not os.path.exists(opt.vis_dir):
            os.mkdir(opt.vis_dir)
    if not os.path.exists(opt.output_dir):
        os.mkdir(opt.output_dir)

    print('Start testing TextSnake.')
    inference(model, detector, test_loader)
    detval()
    print('End.')
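# A minimal invocation sketch (the config path is hypothetical; see the
# example YAML below for the fields the file must define):
#
#     test_TextSnake('./config/textsnake_test.yaml')
#
# Detections are written to <output_dir>/result.json as a dict mapping each
# 'res_*' image id to a list of {'points': [[x, y], ...], 'confidence': float}
# entries, after which detval() is run to score them.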
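# A sketch of the YAML config both functions expect. The field names are taken
# from the `opt.*` accesses in this file; the values are illustrative guesses,
# not the project's shipped defaults:
#
#     exp_name: textsnake_vgg
#     backbone: vgg
#     dataset: ArT_train
#     data_root: /path/to/datasets
#     means: [0.485, 0.456, 0.406]
#     stds: [0.229, 0.224, 0.225]
#     num_device: 0
#     device: cuda
#     cuda: true
#     num_workers: 4
#     multi_scale: false
#     viz: true
#     vis_dir: ./vis
#     output_dir: ./output
#     save_dir: ./save
#     summary_dir: ./summary
#     checkepoch: 100
#     lr: 0.0001
#     batch_size: 4
#     input_size: 512
#     start_epoch: 0
#     max_epoch: 200
#     save_freq: 5
#     display_freq: 10
#     summary_freq: 100
#     vis_num: 4
#     resume: false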
def train_TextSnake(config_file):
    import sys
    sys.path.append('./detection_model/TextSnake_pytorch')

    import os
    import time
    import torch
    import torch.backends.cudnn as cudnn
    import torch.utils.data as data
    from torch.optim import lr_scheduler
    import torchvision.utils as vutils
    from tensorboardX import SummaryWriter
    from dataset.total_text import TotalText
    from network.loss import TextLoss
    from network.textnet import TextNet
    from util.augmentation import EvalTransform, NewAugmentation
    from util.config import config as cfg, update_config, print_config, init_config
    from util.misc import AverageMeter
    from util.misc import mkdirs, to_device
    from util.option import BaseOptions
    from util.visualize import visualize_network_output
    from yacs.config import CfgNode as CN

    global total_iter
    total_iter = 0

    def read_config_file(config_file):
        """ read config information from yaml file """
        f = open(config_file)
        opt = CN.load_cfg(f)
        return opt

    opt = read_config_file(config_file)

    def adjust_learning_rate(optimizer, i):
        if 0 <= i * opt.batch_size < 100000:
            lr = opt.lr
        elif 100000 <= i * opt.batch_size < 400000:
            lr = opt.lr * 0.1
        else:
            lr = opt.lr * 0.1 * (0.94 ** ((i * opt.batch_size - 300000) // 100000))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def adjust_requires_grad(model, i):
        if 0 <= i < 4000:
            for name, param in model.named_parameters():
                if name == 'conv1.0.weight' or name == 'conv1.0.bias' or \
                        name == 'conv1.1.weight' or name == 'conv1.1.bias':
                    param.requires_grad = False
        else:
            for name, param in model.named_parameters():
                if name == 'conv1.0.weight' or name == 'conv1.0.bias' or \
                        name == 'conv1.1.weight' or name == 'conv1.1.bias':
                    param.requires_grad = True

    def save_model(model, optimizer, scheduler, epoch):
        save_dir = os.path.join(opt.save_dir, opt.exp_name)
        if not os.path.exists(save_dir):
            mkdirs(save_dir)

        save_path = os.path.join(
            save_dir,
            'textsnake_{}_{}.pth'.format(model.backbone_name, epoch))
        print('Saving to {}.'.format(save_path))
        state_dict = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optim': optimizer.state_dict()
            # 'scheduler': scheduler.state_dict()
        }
        torch.save(state_dict, save_path)

    def load_model(save_path):
        print('Loading from {}.'.format(save_path))
        checkpoint = torch.load(save_path)
        return checkpoint

    def train(model, train_loader, criterion, scheduler, optimizer, epoch,
              summary_writer):
        start = time.time()
        losses = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        end = time.time()
        model.train()

        global total_iter

        for i, (img, train_mask, tr_mask, tcl_mask, radius_map, sin_map,
                cos_map, meta) in enumerate(train_loader):
            data_time.update(time.time() - end)

            img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map = to_device(
                img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map)

            output = model(img)
            tr_loss, tcl_loss, sin_loss, cos_loss, radii_loss = \
                criterion(output, tr_mask, tcl_mask, sin_map, cos_map,
                          radius_map, train_mask, total_iter)
            loss = tr_loss + tcl_loss + sin_loss + cos_loss + radii_loss
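            # The total objective is the unweighted sum of TextSnake's five
            # geometry heads: text-region (tr) and text-center-line (tcl)
            # classification plus sin/cos direction and radius regression.
            # E.g. if the five terms are 0.40, 0.30, 0.05, 0.05, 0.20, the
            # displayed Loss is 1.00; any per-term weighting would live
            # inside TextLoss, not here.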
            # backward
            # scheduler.step()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.update(loss.item())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if opt.viz and i < opt.vis_num:
                visualize_network_output(output, tr_mask, tcl_mask,
                                         prefix='train_{}'.format(i))

            if i % opt.display_freq == 0:
                print(
                    'Epoch: [ {} ][ {:03d} / {:03d} ] - Loss: {:.4f} - tr_loss: {:.4f} - tcl_loss: {:.4f} - sin_loss: {:.4f} - cos_loss: {:.4f} - radii_loss: {:.4f} - {:.2f}s/step'
                    .format(epoch, i, len(train_loader), loss.item(),
                            tr_loss.item(), tcl_loss.item(), sin_loss.item(),
                            cos_loss.item(), radii_loss.item(),
                            batch_time.avg))

            # write summary
            if total_iter % opt.summary_freq == 0:
                print('Summary in {}'.format(
                    os.path.join(opt.summary_dir, opt.exp_name)))
                tr_pred = output[:, 0:2].softmax(dim=1)[:, 1:2]
                tcl_pred = output[:, 2:4].softmax(dim=1)[:, 1:2]
                summary_writer.add_image('input_image',
                                         vutils.make_grid(img, normalize=True),
                                         total_iter)
                summary_writer.add_image(
                    'tr/tr_pred',
                    vutils.make_grid(tr_pred * 255, normalize=True),
                    total_iter)
                summary_writer.add_image(
                    'tr/tr_mask',
                    vutils.make_grid(torch.unsqueeze(tr_mask * train_mask, 1) * 255),
                    total_iter)
                summary_writer.add_image(
                    'tcl/tcl_pred',
                    vutils.make_grid(tcl_pred * 255, normalize=True),
                    total_iter)
                summary_writer.add_image(
                    'tcl/tcl_mask',
                    vutils.make_grid(torch.unsqueeze(tcl_mask * train_mask, 1) * 255),
                    total_iter)
                summary_writer.add_scalar('learning_rate',
                                          optimizer.param_groups[0]['lr'],
                                          total_iter)
                summary_writer.add_scalar('model/tr_loss', tr_loss.item(), total_iter)
                summary_writer.add_scalar('model/tcl_loss', tcl_loss.item(), total_iter)
                summary_writer.add_scalar('model/sin_loss', sin_loss.item(), total_iter)
                summary_writer.add_scalar('model/cos_loss', cos_loss.item(), total_iter)
                summary_writer.add_scalar('model/radii_loss', radii_loss.item(), total_iter)
                summary_writer.add_scalar('model/loss', loss.item(), total_iter)

            total_iter += 1

        print('Speed: {}s /step, {}s /epoch'.format(batch_time.avg,
                                                    time.time() - start))

        if epoch % opt.save_freq == 0:
            save_model(model, optimizer, scheduler, epoch)

        print('Training Loss: {}'.format(losses.avg))

    def validation(model, valid_loader, criterion):
        """ print per-batch and average validation losses """
        model.eval()
        losses = AverageMeter()

        for i, (img, train_mask, tr_mask, tcl_mask, radius_map, sin_map,
                cos_map, meta) in enumerate(valid_loader):
            print(meta['image_id'])
            img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map = to_device(
                img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map)

            output = model(img)

            tr_loss, tcl_loss, sin_loss, cos_loss, radii_loss = \
                criterion(output, tr_mask, tcl_mask, sin_map, cos_map,
                          radius_map, train_mask)
            loss = tr_loss + tcl_loss + sin_loss + cos_loss + radii_loss
            losses.update(loss.item())

            if opt.viz and i < opt.vis_num:
                visualize_network_output(output, tr_mask, tcl_mask,
                                         prefix='val_{}'.format(i))

            if i % opt.display_freq == 0:
                print(
                    'Validation: - Loss: {:.4f} - tr_loss: {:.4f} - tcl_loss: {:.4f} - sin_loss: {:.4f} - cos_loss: {:.4f} - radii_loss: {:.4f}'
                    .format(loss.item(), tr_loss.item(), tcl_loss.item(),
                            sin_loss.item(), cos_loss.item(),
                            radii_loss.item()))

        print('Validation Loss: {}'.format(losses.avg))

    # parse arguments
    torch.cuda.set_device(opt.num_device)
    option = BaseOptions(config_file)
    args = option.initialize()

    init_config(opt, config_file)
    update_config(opt, args)
    print_config(opt)

    data_root = os.path.join(opt.data_root, opt.dataset)

    trainset = TotalText(data_root=data_root,
                         ignore_list=os.path.join(data_root, 'ignore_list.txt'),
                         is_training=True,
                         transform=NewAugmentation(size=opt.input_size,
                                                   mean=opt.means,
                                                   std=opt.stds,
                                                   maxlen=1280,
                                                   minlen=512))
    train_loader = data.DataLoader(trainset,
                                   batch_size=opt.batch_size,
                                   shuffle=True,
                                   num_workers=opt.num_workers)

    # Model
    model = TextNet(backbone=opt.backbone, output_channel=7)
    model = model.to(opt.device)
    if opt.cuda:
        cudnn.benchmark = True

    criterion = TextLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)
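    # Scheduling is done manually via adjust_learning_rate() above rather than
    # through a torch scheduler; the commented StepLR (x0.94 every 10000 steps)
    # would give roughly the same decay as the manual 0.94-per-100000-samples
    # branch, depending on batch_size.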
    # if opt.dataset == 'ArT_train':
    #     scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[10000, 50000], gamma=0.1)
    # elif opt.dataset == 'LSVT_full_train':
    #     scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[10000, 50000], gamma=0.1)

    # load model if resume
    if opt.resume is not False:
        checkpoint = load_model(opt.resume)
        opt.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optim'])
        # scheduler.load_state_dict(checkpoint['scheduler'])
        total_iter = checkpoint['epoch'] * len(train_loader)

    if not os.path.exists(os.path.join(opt.summary_dir, opt.exp_name)):
        os.mkdir(os.path.join(opt.summary_dir, opt.exp_name))
    summary_writer = SummaryWriter(
        log_dir=os.path.join(opt.summary_dir, opt.exp_name))

    print('Start training TextSnake.')
    for epoch in range(opt.start_epoch, opt.max_epoch):
        adjust_learning_rate(optimizer, total_iter)
        train(model, train_loader, criterion, None, optimizer, epoch,
              summary_writer)

    print('End.')
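# A minimal invocation sketch (hypothetical config path; see the example YAML
# between the two functions for the fields the file must define):
#
#     train_TextSnake('./config/textsnake_train.yaml')
#
# Checkpoints land in <save_dir>/<exp_name>/textsnake_<backbone>_<epoch>.pth
# every save_freq epochs; to resume, point `resume` in the YAML at one of
# them, and start_epoch / total_iter are restored from its 'epoch' field.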