def load_model_from_checkpoint(checkpoint_file_name, use_gpu=False):
    """Load a pretrained CRNN model.

    Builds ``CRNN(line_size, 1, len(vocab), 256)`` and fills it with the
    ``'state_dict'`` entry of the checkpoint stored at
    ``checkpoint_file_name``.

    Args:
        checkpoint_file_name: Path to a file readable by ``torch.load`` whose
            loaded object supports ``['state_dict']``.
        use_gpu: When true, tensors keep the device they were saved on
            (``map_location=None``) and the model is moved with ``.cuda()``.
            Otherwise everything is mapped to, and kept on, the CPU.

    Returns:
        The model, cast to float32 parameters and switched to eval mode.

    Note:
        ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    # Construct first, as before, so weight-init RNG usage is unchanged.
    net = CRNN(line_size, 1, len(vocab), 256)
    location = None if use_gpu else 'cpu'
    saved = torch.load(checkpoint_file_name, map_location=location)
    net.load_state_dict(saved['state_dict'])
    net.float()
    net.eval()
    if use_gpu:
        return net.cuda()
    return net.cpu()
def _decode_indices(tensor, max_length=None, remove_repetitions=False):
    """Turn a 1-D tensor of class indices into text via ``idx2char``.

    Tokens that map to ``'B'`` (the blank symbol) are dropped. When
    ``remove_repetitions`` is set, a character equal to the character of the
    immediately preceding token (blank or not) is also dropped. This looks
    like greedy CTC-style decoding; confirm against the training setup.
    Only the first ``max_length`` tokens are considered when it is given.
    """
    sequence = tensor.cpu().detach().numpy()
    chars = []
    previous = None  # character of the previous raw token, blanks included
    for position, index in enumerate(sequence):
        if max_length is not None and position >= max_length:
            break
        char = idx2char[index]
        is_repeat = remove_repetitions and char == previous
        if char != 'B' and not is_repeat:
            chars.append(char)
        previous = char
    return ''.join(chars)


def ocr(orig_img, lines, checkpoint_file_name, use_gpu=False):
    """OCR on segmented lines.

    Args:
        orig_img: Source image. It is sliced as ``orig_img[y1:y2, x1:x2]``,
            so presumably a numpy array with rows first (TODO confirm).
        lines: Iterable of ``((x1, y1), (x2, y2))`` corner pairs, one per
            line region to recognise.
        checkpoint_file_name: Checkpoint passed to
            ``load_model_from_checkpoint``.
        use_gpu: Run the model, and its inputs, on CUDA.

    Returns:
        A list of ``(line_img, predicted_text)`` tuples, in the order of
        ``lines``. ``line_img`` is the rotated, resized crop that was fed to
        the model.
    """
    model = load_model_from_checkpoint(checkpoint_file_name, use_gpu=use_gpu)
    result = []
    # Scoped instead of torch.set_grad_enabled(False), which previously
    # turned autograd off for the whole process and never restored it.
    with torch.no_grad():
        for (x1, y1), (x2, y2) in lines:
            # Each crop is rotated 90 degrees before resizing to the model's
            # line height; np.array() makes the rotated view a real copy.
            crop = np.array(np.rot90(orig_img[y1:y2, x1:x2]))
            line_img = image_resize(crop, height=line_size)
            # Scale to [0, 1] and add two leading singleton dims
            # (presumably batch and channel -- TODO confirm).
            inputs = torch.from_numpy(line_img / 255).float().unsqueeze(0).unsqueeze(0)
            if use_gpu:
                # The model lives on CUDA; a CPU input would raise a
                # device-mismatch error.
                inputs = inputs.cuda()
            outputs = model(inputs)
            prediction = outputs.softmax(2).max(2)[1]
            # Column 0 of the index tensor; presumably the only batch entry.
            predicted_text = _decode_indices(prediction[:, 0], remove_repetitions=True)
            result.append((line_img, predicted_text))
    return result