def main(data_path, abc, seq_proj, backend, snapshot, input_size, gpu, visualize):
    """Run CRNN text-recognition inference over a test dataset.

    Args:
        data_path: directory of the test dataset (required).
        abc: alphabet string from the caller; currently overridden below
             (parameter kept for interface compatibility).
        seq_proj: sequence-projection spec, e.g. "10x20".
        backend: CNN backbone name passed to CRNN.
        snapshot: optional path to a weights snapshot.
        input_size: input spec "WxH" (or "HxW" per Resize convention).
        gpu: CUDA_VISIBLE_DEVICES string; '' means CPU.
        visualize: forwarded to detect() for result display.

    Raises:
        ValueError: if data_path is None.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # BUG FIX: original used `gpu is not ''` (identity, not equality) —
    # undefined for non-interned strings; use `!=`.
    cuda = gpu != ''
    # NOTE: caller-supplied `abc` is deliberately overridden with the
    # alphanumeric alphabet, matching the original behavior.
    abc = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
    input_size = [int(x) for x in input_size.split('x')]
    transform = Compose(
        [Rotation(), Resize(size=(input_size[0], input_size[1]))])
    if data_path is None:
        # BUG FIX: original fell through with `data` unbound and crashed
        # later with NameError inside detect(); fail fast with a clear error.
        raise ValueError("data_path must be provided")
    data = LoadDataset(data_path=data_path, mode="test", transform=transform)
    seq_proj = [int(x) for x in seq_proj.split('x')]
    net = CRNN(abc=abc, seq_proj=seq_proj, backend=backend)
    if snapshot is not None:
        load_weights(net, torch.load(snapshot))
    if cuda:
        net = net.cuda()
    net = net.eval()
    detect(net, data, cuda, visualize)
# Validate a pretrained Chinese-OCR CRNN checkpoint on the test set.
# NOTE(review): `testdataset`, `opt`, `imgH`, `imgW`, `keep_ratio`,
# `alignCollate`, `keys`, `CRNN`, `strLabelConverter`, `OrderedDict` and
# `val` are defined elsewhere in this module.
val_loader = torch.utils.data.DataLoader(testdataset,
                                         shuffle=False,
                                         batch_size=opt.batch_size,
                                         num_workers=int(opt.workers),
                                         collate_fn=alignCollate(
                                             imgH=imgH, imgW=imgW, keep_ratio=keep_ratio))
alphabet = keys.alphabetChinese
print("char num ", len(alphabet))
# NOTE(review): positional args presumably (imgH=32, nc=1,
# nclass=len(alphabet)+1 for the CTC blank, nh=256, ...) — confirm
# against the CRNN constructor signature.
model = CRNN(32, 1, len(alphabet) + 1, 256, 1)
converter = strLabelConverter(''.join(alphabet))
# map_location forces the checkpoint onto CPU regardless of where it was saved.
state_dict = torch.load("../SceneOcr/model/ocr-lstm.pth",
                        map_location=lambda storage, loc: storage)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k
    # Drop BatchNorm `num_batches_tracked` bookkeeping entries; copy all
    # other parameters under their original key.
    if "num_batches_tracked" not in k:
        # name = name.replace('module.', '')  # remove `module.`
        new_state_dict[name] = v
model.cuda()
model = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
# load params
# NOTE(review): the model is wrapped in DataParallel *before* loading, so the
# checkpoint keys are expected to already carry the 'module.' prefix — confirm.
model.load_state_dict(new_state_dict)
model.eval()
curAcc = val(model, converter, val_loader, max_iter=5)
class Demo(object):
    """CRNN OCR demo: loads a trained checkpoint and recognizes text in
    single images, batches of images, or every image in a directory."""

    def __init__(self, args):
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
        self.args = args
        self.alphabet = alphabetChinese
        nclass = len(self.alphabet) + 1  # +1 for the CTC blank label
        nc = 1  # single-channel (grayscale) input
        self.net = CRNN(args.imgH, nc, args.nh, nclass)
        self.converter = utils.strLabelConverter(self.alphabet,
                                                 ignore_case=False)
        self.transformer = resizeNormalize(args.imgH)
        print('loading pretrained model from %s' % args.model_path)
        checkpoint = torch.load(args.model_path)
        if 'model_state_dict' in checkpoint.keys():
            checkpoint = checkpoint['model_state_dict']
        from collections import OrderedDict
        model_dict = OrderedDict()
        for k, v in checkpoint.items():
            # Strip the 'module.' prefix left over from DataParallel training.
            if 'module' in k:
                model_dict[k[7:]] = v
            else:
                model_dict[k] = v
        self.net.load_state_dict(model_dict)
        if args.cuda and torch.cuda.is_available():
            print('available gpus is,', torch.cuda.device_count())
            # BUG FIX: torch.nn.DataParallel has no `output_dim` kwarg
            # (original raised TypeError); `dim=1` scatters/gathers along
            # the batch axis of CRNN's (T, N, C) output.
            self.net = torch.nn.DataParallel(self.net, dim=1).cuda()
        self.net.eval()

    def predict(self, image):
        """Recognize text in a single PIL image; returns the decoded string."""
        image = self.transformer(image)
        # BUG FIX: move the input to GPU only when the model was put on GPU
        # (original checked torch.cuda.is_available() alone, causing a
        # device mismatch when args.cuda is False on a CUDA machine).
        if self.args.cuda and torch.cuda.is_available():
            image = image.cuda()
        image = image.view(1, *image.size())
        image = Variable(image)
        preds = self.net(image)
        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        preds_size = Variable(torch.IntTensor([preds.size(0)]))
        raw_pred = self.converter.decode(preds.data, preds_size.data, raw=True)
        sim_pred = self.converter.decode(preds.data, preds_size.data, raw=False)
        print('%-20s => %-20s' % (raw_pred, sim_pred))
        return sim_pred

    def predict_batch(self, images):
        """Recognize text in a list of PIL images, processed in mini-batches
        of args.batch_size; returns the decoded strings in input order."""
        N = len(images)
        n_batch = N // self.args.batch_size
        n_batch += 1 if N % self.args.batch_size else 0
        res = []
        # BUG FIX: the original reused `i` for the outer batch index AND both
        # inner loops; distinct names avoid the shadowing.
        for b in range(n_batch):
            batch = images[b * self.args.batch_size:
                           min((b + 1) * self.args.batch_size, N)]
            maxW = 0
            for j in range(len(batch)):
                batch[j] = self.transformer(batch[j])
                maxW = max(maxW, batch[j].shape[2])
            # Right-pad every tensor with zeros to the widest image so they
            # can be stacked into one batch.
            for j in range(len(batch)):
                if batch[j].shape[2] < maxW:
                    pad = torch.zeros(
                        (1, self.args.imgH, maxW - batch[j].shape[2]),
                        dtype=batch[j].dtype)
                    batch[j] = torch.cat((batch[j], pad), 2)
            batch_imgs = torch.cat([t.unsqueeze(0) for t in batch], 0)
            # BUG FIX: original never moved the batch to GPU, mismatching a
            # CUDA model; mirror the device handling in predict().
            if self.args.cuda and torch.cuda.is_available():
                batch_imgs = batch_imgs.cuda()
            preds = self.net(batch_imgs)
            preds_size = Variable(torch.IntTensor([preds.size(0)] * len(batch)))
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            raw_preds = self.converter.decode(preds.data, preds_size.data,
                                              raw=True)
            sim_preds = self.converter.decode(preds.data, preds_size.data,
                                              raw=False)
            for raw_pred, sim_pred in zip(raw_preds, sim_preds):
                print('%-20s => %-20s' % (raw_pred, sim_pred))
            res.extend(sim_preds)
        return res

    def inference(self, image_path, batch_pred=False):
        """Run recognition on one image file, or on every image (extension in
        `img_types`) inside a directory; returns the list of predictions."""
        if os.path.isdir(image_path):
            file_list = os.listdir(image_path)
            image_list = [os.path.join(image_path, i) for i in file_list
                          if i.rsplit('.')[-1].lower() in img_types]
        else:
            image_list = [image_path]
        res = []
        images = []
        for img_path in image_list:
            image = Image.open(img_path).convert('L')
            if not batch_pred:
                res.append(self.predict(image))
            else:
                images.append(image)
        if batch_pred and images:
            res = self.predict_batch(images)
        return res
# Scene-text pipeline: a TF1 object-detection graph localizes text regions,
# a PyTorch CRNN recognizes the cropped text.
# Build the TensorFlow detection graph from a frozen .pb file (TF1-style API).
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as f:
        od_graph_def.ParseFromString(f.read())
    tf.import_graph_def(od_graph_def, name='')
# Load CRNN model
# nclass=37: the 36-character alphabet below plus one CTC blank.
alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
crnn = CRNN(32, 1, 37, 256)
if torch.cuda.is_available():
    crnn = crnn.cuda()
# NOTE(review): no map_location here — loading a GPU-saved checkpoint on a
# CPU-only machine would fail; confirm how the snapshot was saved.
crnn.load_state_dict(torch.load(args.recognition_model_path))
converter = utils.strLabelConverter(alphabet)
transformer = dataset.resizeNormalize((100, 32))
crnn.eval()
# Open a video file or an image file
cap = cv2.VideoCapture(args.input if args.input else 0)
while cv2.waitKey(1) < 0:
    has_frame, frame = cap.read()
    if not has_frame:
        cv2.waitKey(0)  # hold the last frame on screen until a key is pressed
        break
    im_height, im_width, _ = frame.shape
    tf_frame = np.expand_dims(frame, axis=0)
    output_dict = run_inference_for_single_image(tf_frame, detection_graph)
    # NOTE(review): the body of this per-detection loop is truncated in this
    # chunk of the file.
    for i in range(output_dict['num_detections']):
class Trainer(object):
    """Trains a CRNN with CTC loss on the dataset described by the
    module-level `args`, with periodic validation and checkpointing."""

    def __init__(self):
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
        if args.chars_file == '':
            self.alphabet = alphabetChinese
        else:
            self.alphabet = utils.load_chars(args.chars_file)
        nclass = len(self.alphabet) + 1  # +1 for the CTC blank label
        nc = 1  # single-channel (grayscale) input
        self.net = CRNN(args.imgH, nc, args.nh, nclass)
        self.train_dataloader, self.val_dataloader = self.dataloader(
            self.alphabet)
        self.criterion = CTCLoss()
        self.optimizer = self.get_optimizer()
        self.converter = utils.strLabelConverter(self.alphabet,
                                                 ignore_case=False)
        self.best_acc = 0.00001
        model_name = '%s' % (args.dataset_name)
        if not os.path.exists(args.save_prefix):
            os.mkdir(args.save_prefix)
        args.save_prefix += model_name
        if args.pretrained != '':
            print('loading pretrained model from %s' % args.pretrained)
            checkpoint = torch.load(args.pretrained)
            if 'model_state_dict' in checkpoint.keys():
                # Resume epoch counter and best accuracy from the checkpoint.
                args.start_epoch = checkpoint['epoch']
                self.best_acc = checkpoint['best_acc']
                checkpoint = checkpoint['model_state_dict']
            from collections import OrderedDict
            model_dict = OrderedDict()
            for k, v in checkpoint.items():
                # Strip the 'module.' prefix left over from DataParallel.
                if 'module' in k:
                    model_dict[k[7:]] = v
                else:
                    model_dict[k] = v
            self.net.load_state_dict(model_dict)
        if not args.cuda and torch.cuda.is_available():
            print(
                "WARNING: You have a CUDA device, so you should probably run with --cuda"
            )
        elif args.cuda and torch.cuda.is_available():
            print('available gpus is ', torch.cuda.device_count())
            # BUG FIX: torch.nn.DataParallel has no `output_dim` kwarg
            # (original raised TypeError); `dim=1` scatters/gathers along
            # the batch axis of CRNN's (T, N, C) output.
            self.net = torch.nn.DataParallel(self.net, dim=1).cuda()
            self.criterion = self.criterion.cuda()

    def dataloader(self, alphabet):
        """Build the (train, val) DataLoaders; val is None when args.val_dir
        does not exist."""
        train_dataset = NumDataset(args.train_dir, alphabet,
                                   transform=resizeNormalize(args.imgH))
        train_dataloader = DataLoader(dataset=train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
        if os.path.exists(args.val_dir):
            val_dataset = NumDataset(args.val_dir, alphabet, mode='test',
                                     transform=resizeNormalize(args.imgH))
            val_dataloader = DataLoader(dataset=val_dataset,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.num_workers,
                                        pin_memory=True)
        else:
            val_dataloader = None
        return train_dataloader, val_dataloader

    def get_optimizer(self):
        """Create the optimizer selected by args.optimizer
        ('sgd', 'adam', anything else -> RMSprop)."""
        if args.optimizer == 'sgd':
            optimizer = optim.SGD(
                self.net.parameters(),
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.wd,
            )
        elif args.optimizer == 'adam':
            optimizer = optim.Adam(
                self.net.parameters(),
                lr=args.lr,
                betas=(args.beta1, 0.999),
            )
        else:
            optimizer = optim.RMSprop(
                self.net.parameters(),
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.wd,
            )
        return optimizer

    def train(self):
        """Main training loop: logs progress, validates every
        args.val_interval epochs, and checkpoints best/periodic models."""
        logging.basicConfig()
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        log_file_path = args.save_prefix + '_train.log'
        log_dir = os.path.dirname(log_file_path)
        if log_dir and not os.path.exists(log_dir):
            os.mkdir(log_dir)
        fh = logging.FileHandler(log_file_path)
        logger.addHandler(fh)
        logger.info(args)
        logger.info('Start training from [Epoch {}]'.format(args.start_epoch +
                                                            1))
        losses = utils.Averager()
        train_accuracy = utils.Averager()
        # BUG FIX: `acc` was only bound inside the val_interval branch but
        # read in the save_interval branch, raising NameError whenever a
        # periodic save happened before any validation.
        acc = 0.0
        for epoch in range(args.start_epoch, args.nepoch):
            self.net.train()
            btic = time.time()
            for i, (imgs, labels) in enumerate(self.train_dataloader):
                batch_size = imgs.size()[0]
                imgs = imgs.cuda()
                preds = self.net(imgs).cpu()
                # length: per-sample label lengths in the batch;
                # text: concatenated label indices for the whole batch.
                text, length = self.converter.encode(labels)
                preds_size = torch.IntTensor([preds.size(0)] * batch_size)
                loss_avg = self.criterion(preds, text, preds_size,
                                          length) / batch_size
                self.optimizer.zero_grad()
                loss_avg.backward()
                self.optimizer.step()
                losses.update(loss_avg.item(), batch_size)
                _, preds_m = preds.max(2)
                preds_m = preds_m.transpose(1, 0).contiguous().view(-1)
                sim_preds = self.converter.decode(preds_m.data,
                                                  preds_size.data,
                                                  raw=False)
                n_correct = 0
                for pred, target in zip(sim_preds, labels):
                    if pred == target:
                        n_correct += 1
                train_accuracy.update(n_correct, batch_size, MUL_n=False)
                if args.log_interval and not (i + 1) % args.log_interval:
                    logger.info(
                        '[Epoch {}/{}][Batch {}/{}], Speed: {:.3f} samples/sec, Loss:{:.3f}'
                        .format(epoch + 1, args.nepoch, i + 1,
                                len(self.train_dataloader),
                                batch_size / (time.time() - btic),
                                losses.val()))
                    losses.reset()
            logger.info(
                'Training accuracy: {:.3f}, [#correct:{} / #total:{}]'.format(
                    train_accuracy.val(), train_accuracy.sum,
                    train_accuracy.count))
            train_accuracy.reset()
            if args.val_interval and not (epoch + 1) % args.val_interval:
                acc = self.validate(logger)
                if acc > self.best_acc:
                    self.best_acc = acc
                    save_path = '{:s}_best.pth'.format(args.save_prefix)
                    torch.save(
                        {
                            'epoch': epoch,
                            'model_state_dict': self.net.state_dict(),
                            'best_acc': self.best_acc,
                        }, save_path)
                # BUG FIX: original called module-level logging.info() here,
                # bypassing the configured file handler; use `logger`.
                logger.info("best acc is:{:.3f}".format(self.best_acc))
            if args.save_interval and not (epoch + 1) % args.save_interval:
                save_path = '{:s}_{:04d}_{:.3f}.pth'.format(
                    args.save_prefix, epoch + 1, acc)
                torch.save(
                    {
                        'epoch': epoch,
                        'model_state_dict': self.net.state_dict(),
                        'best_acc': self.best_acc,
                    }, save_path)

    def validate(self, logger):
        """Evaluate on the validation set; returns sequence-level accuracy
        (0 when there is no validation loader)."""
        if self.val_dataloader is None:
            return 0
        logger.info('Start validate.')
        losses = utils.Averager()
        self.net.eval()
        n_correct = 0
        with torch.no_grad():
            for imgs, labels in self.val_dataloader:
                batch_size = imgs.size()[0]
                imgs = imgs.cuda()
                preds = self.net(imgs).cpu()
                # length: per-sample label lengths in the batch;
                # text: concatenated label indices for the whole batch.
                text, length = self.converter.encode(labels)
                preds_size = torch.IntTensor(
                    [preds.size(0)] * batch_size)  # timestep * batchsize
                loss_avg = self.criterion(preds, text, preds_size,
                                          length) / batch_size
                losses.update(loss_avg.item(), batch_size)
                _, preds = preds.max(2)
                preds = preds.transpose(1, 0).contiguous().view(-1)
                sim_preds = self.converter.decode(preds.data,
                                                  preds_size.data,
                                                  raw=False)
                for pred, target in zip(sim_preds, labels):
                    if pred == target:
                        n_correct += 1
        # NOTE(review): assumes utils.Averager.count accumulates the sample
        # count (not the batch count) — confirm against the Averager impl.
        accuracy = n_correct / float(losses.count)
        logger.info(
            'Evaling loss: {:.3f}, accuracy: {:.3f}, [#correct:{} / #total:{}]'
            .format(losses.val(), accuracy, n_correct, losses.count))
        return accuracy