def test(opt):
    """Evaluate a saved recognition model and log its accuracy.

    Builds the label converter and model described by ``opt``, loads the
    weights from ``opt.saved_model``, copies that checkpoint into
    ``./result/<exp_name>/`` and then either runs ``benchmark_all_eval`` or
    evaluates ``opt.eval_data`` once, appending the dataset log and the
    accuracy to ``log_evaluation.txt`` in the result directory.

    Side effects on ``opt``: ``num_class`` and ``exp_name`` are set, and
    ``input_channel`` is forced to 3 when ``opt.rgb`` is truthy.
    """
    import shutil  # local import: only needed for the checkpoint copy below

    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))
    # Experiment name = checkpoint path without its first component.
    opt.exp_name = '_'.join(opt.saved_model.split('/')[1:])
    # print(model)

    # keep evaluation model and result logs
    result_dir = f'./result/{opt.exp_name}'
    os.makedirs(result_dir, exist_ok=True)
    try:
        # shutil.copy replaces the former shell `cp`: no shell quoting
        # problems with unusual paths and it also works without a POSIX shell.
        shutil.copy(opt.saved_model, result_dir)
    except shutil.SameFileError:
        # The checkpoint already lives in the result directory.
        pass

    # setup loss
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    # evaluation
    model.eval()
    with torch.no_grad():
        if opt.benchmark_all_eval:
            benchmark_all_eval(model, criterion, converter, opt)
        else:
            # The context manager guarantees the log is closed even if
            # dataset construction or validation raises.
            with open(f'{result_dir}/log_evaluation.txt', 'a') as log:
                AlignCollate_evaluation = AlignCollate(
                    imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
                eval_data, eval_data_log = hierarchical_dataset(
                    root=opt.eval_data, opt=opt)
                evaluation_loader = torch.utils.data.DataLoader(
                    eval_data,
                    batch_size=opt.batch_size,
                    shuffle=False,
                    num_workers=int(opt.workers),
                    collate_fn=AlignCollate_evaluation,
                    pin_memory=True)
                # Only the second element of the validation result (the
                # accuracy) is reported here.
                _, accuracy_by_best_model, _, _, _, _, _, _ = validation(
                    model, criterion, evaluation_loader, converter, opt)
                log.write(eval_data_log)
                print(f'{accuracy_by_best_model:0.3f}')
                log.write(f'{accuracy_by_best_model:0.3f}\n')
class Mode():
    """Wrapper around the detection ``Model`` for training, evaluation and inference.

    The constructor builds the network, one ``YOLOLoss`` per output scale
    (three in total) and, when training, the optimizer. Checkpoints are
    dictionaries with the keys ``'state_dict'``, ``'epoch'`` and
    ``'global step'``.
    """

    def __init__(self, config, is_training):
        """Build the network, the loss heads and optionally restore weights.

        Args:
            config: configuration object; the attributes read here are
                ``parallels``, ``anchors``, ``image_size``, ``num_classes``,
                ``pretrained_weights`` and ``official_weights`` (plus those
                read by ``_get_optimizer`` when training).
            is_training: when truthy the net is put in train mode and an
                optimizer is created; otherwise the net is put in eval mode.
        """
        self.config = config
        self.is_training = is_training
        self.net = Model(self.config, is_training=self.is_training)
        if self.is_training:
            self.net.train(is_training)
        else:
            self.net.eval()
        self.net.init_weights()
        if self.is_training:
            # Created before the DataParallel wrap so that
            # ``self.net.backbone`` is still directly reachable.
            self.optimizer = self._get_optimizer()
        if len(self.config.parallels) > 0:
            self.net = nn.DataParallel(self.net)
        self.net = self.net.cuda()

        # One loss module per detection scale.
        self.yolo_loss = [
            YOLOLoss(config.anchors[i], config.image_size, config.num_classes)
            for i in range(3)
        ]
        #if is_refine:
        #    self.refine_loss = RefineLoss(config.anchors, config.num_classes, (config.image_size, config.image_size))

        if config.pretrained_weights:
            logging.info("Load pretrained weights from {}".format(
                config.pretrained_weights))
            checkpoint = torch.load(config.pretrained_weights)
            state_dict = checkpoint['state_dict']
            self.net.load_state_dict(state_dict)
            # Resume right after the saved position.
            self.epoch = checkpoint["epoch"] + 1
            self.global_step = checkpoint['global step'] + 1
        else:
            self.epoch = 0
            self.global_step = 0

        if config.official_weights:
            logging.info("Loading official weights from {}".format(
                config.official_weights))
            self.net.load_state_dict(torch.load(config.official_weights))
            # NOTE(review): hard-coded starting step; presumably chosen to
            # skip the burn-in phase of the LR schedule — confirm.
            self.global_step = 20000

    def _get_optimizer(self):
        """Create the optimizer selected by ``config.optimizer``.

        When ``config.freeze_backbone`` is set, the backbone parameters get
        ``requires_grad = False`` and only the remaining parameters are
        handed to the optimizer. Recognized optimizer names are ``"adam"``,
        ``"amsgrad"`` and ``"rmsprop"``; anything else falls back to SGD
        (with Nesterov momentum when the name is ``"nesterov"``).

        Returns:
            The constructed ``torch.optim`` optimizer.
        """
        if not self.config.freeze_backbone:
            params = [
                {
                    "params": self.net.parameters(),
                    "lr": self.config.learning_rate
                },
            ]
        else:
            logging.info("freeze backbone's parameters.")
            # A set gives O(1) membership tests when separating the head
            # parameters from the backbone parameters.
            backbone_ids = {id(p) for p in self.net.backbone.parameters()}
            for p in self.net.backbone.parameters():
                p.requires_grad = False
            head_params = [
                p for p in self.net.parameters() if id(p) not in backbone_ids
            ]
            params = [
                {
                    "params": head_params,
                    "lr": self.config.learning_rate
                },
            ]

        # Initialize optimizer class
        if self.config.optimizer == "adam":
            optimizer = optim.Adam(params,
                                   weight_decay=self.config.weight_decay)
        elif self.config.optimizer == "amsgrad":
            optimizer = optim.Adam(params,
                                   weight_decay=self.config.weight_decay,
                                   amsgrad=True)
        elif self.config.optimizer == "rmsprop":
            optimizer = optim.RMSprop(params,
                                      weight_decay=self.config.weight_decay)
        else:
            # Default to sgd
            logging.info("Using SGD optimizer.")
            optimizer = optim.SGD(
                params,
                momentum=self.config.momentum,
                weight_decay=self.config.weight_decay,
                nesterov=(self.config.optimizer == "nesterov"))
        return optimizer

    def _save_checkpoint(self, checkpoint_path):
        """Write the current weights, epoch and global step to ``checkpoint_path``."""
        checkpoint = {
            'state_dict': self.net.state_dict(),
            'epoch': self.epoch,
            "global step": self.global_step
        }
        torch.save(checkpoint, checkpoint_path)
        logging.info("Model checkpoint saved to {}".format(checkpoint_path))

    def _sum_scale_losses(self, outputs, labels, n_items, *loss_args):
        """Sum each loss component over the three detection scales.

        Every ``YOLOLoss`` call yields a sequence of components; the result
        is a list of ``n_items`` entries where entry ``j`` is the sum of
        component ``j`` over all scales.
        """
        per_item = [[] for _ in range(n_items)]
        for i in range(3):
            loss_items = self.yolo_loss[i](outputs[i], labels, *loss_args)
            for j, item in enumerate(loss_items):
                per_item[j].append(item)
        return [sum(items) for items in per_item]

    def train(self, train_dataloader, val_dataloader):
        """Run the training loop until ``config.max_iter`` global steps.

        Each pass over ``train_dataloader`` is followed by saving
        ``model_<epoch>.pth`` in ``config.save_dir`` and one validation pass
        over ``val_dataloader``. Inside an epoch a rolling
        ``model_backup.pth`` is written every 1000 steps and scalars are
        logged to the ``SummaryWriter`` every 10 steps.
        """

        def adjust_learning_rate(optimizer, config, global_step):
            """Set and return the learning rate for ``global_step``.

            Quadratic warm-up during ``config.burn_in`` steps, then the base
            rate, then multiplied by ``decay_gamma`` after each of the two
            ``decay_step`` boundaries.
            """
            lr = config.learning_rate
            if global_step < config.burn_in:
                ratio = global_step / config.burn_in
                lr = lr * ratio * ratio
            elif global_step < config.decay_step[0]:
                lr = lr
            elif global_step < config.decay_step[1]:
                lr = config.decay_gamma * lr
            else:
                lr = config.decay_gamma * config.decay_gamma * lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            return lr

        losses_name = [
            "total_loss", "x", "y", "w", "h", "conf", "cls", "a"
        ]
        summary = SummaryWriter(self.config.write)
        logging.info("Start training")
        while self.global_step < self.config.max_iter:
            #train step
            train_dataloader.dataset.random_shuffle()
            train_dataloader.dataset.update(self.global_step)
            for step, samples in enumerate(train_dataloader):
                images, labels = samples['image'], samples['label']
                images = images.cuda()
                image_size = images.size(2)
                batch_size = images.size(0)
                start_time = time.time()
                lr = adjust_learning_rate(self.optimizer, self.config,
                                          self.global_step)
                self.optimizer.zero_grad()
                outputs = self.net(images)
                losses = self._sum_scale_losses(outputs, labels,
                                                len(losses_name),
                                                self.global_step)
                # The first component is the total loss that is optimized.
                loss = losses[0]
                loss.backward()
                self.optimizer.step()
                #memory_usage_psutil()
                if step >= 0 and step % 10 == 0:
                    _loss = loss.item()
                    time_per_example = float(time.time() -
                                             start_time) / batch_size
                    logging.info(
                        "epoch [%.3d] step = %d size = %d loss = %.2f "
                        "time/example = %.3f lr = %.5f loss_x = %.3f "
                        "loss_y = %.3f loss_w = %.3f loss_h = %.3f "
                        "loss_conf = %.3f loss_cls = %.3f loss_a = %.3f" %
                        (self.epoch, step, image_size, _loss,
                         time_per_example, lr, losses[1], losses[2],
                         losses[3], losses[4], losses[5], losses[6],
                         losses[7]))
                    summary.add_scalar("lr", lr, self.global_step)
                    for i, name in enumerate(losses_name):
                        v = _loss if i == 0 else losses[i]
                        summary.add_scalar(name, v, self.global_step)
                if step > 0 and step % 1000 == 0:
                    self._save_checkpoint(
                        os.path.join(self.config.save_dir,
                                     "model_backup.pth"))
                self.global_step += 1

            self._save_checkpoint(
                os.path.join(self.config.save_dir,
                             "model_{}.pth".format(self.epoch)))

            #val every epoch
            logging.info('Start validating after epoch {}'.format(self.epoch))
            val_losses = []
            val_num = len(val_dataloader)
            for step, samples in enumerate(val_dataloader):
                images, labels = samples['image'], samples['label']
                with torch.no_grad():
                    outputs = self.net(images)
                losses = self._sum_scale_losses(outputs, labels,
                                                len(losses_name))
                val_loss = losses[0].item()
                if step > 0 and step % 10 == 0:
                    logging.info("Having validated [%.3d/%.3d]" %
                                 (step, val_num))
                val_losses.append(val_loss)
            val_loss = np.mean(np.asarray(val_losses))
            logging.info("val loss = %.2f at epoch [%.3d]" %
                         (val_loss, self.epoch))
            self.epoch += 1

    #def inference(self, inputs):
    #    with torch.no_grad():
    #        outputs = self.net(inputs)
    #        output = self.yolo_loss(outputs)
    #        detections =

    def _detect(self, images, conf_thres, nms_thres):
        """Forward ``images`` and return the per-image NMS detections."""
        with torch.no_grad():
            outputs = self.net(images)
            # Calling a YOLOLoss without labels is used here to decode the
            # raw output of each scale before concatenating them.
            output_list = [self.yolo_loss[i](outputs[i]) for i in range(3)]
            output = torch.cat(output_list, 1)
            return non_max_suppression(output,
                                       self.config.num_classes,
                                       conf_thres=conf_thres,
                                       nms_thres=nms_thres)

    def eval_coco(self, val_dataset):
        """Run detection on ``val_dataset`` and score it with the COCO API.

        Detections are rescaled to the original image frame, written to
        ``coco_results.json`` and evaluated against ``config.annotation``.
        The class index -> COCO category mapping is read from
        ``coco_index2category.json`` in the working directory.
        """
        # Context manager so the mapping file handle is not leaked.
        with open("coco_index2category.json") as f:
            index2category = json.load(f)
        logging.info('Start Evaling')
        coco_result = []
        coco_img_ids = set()

        for step, samples in enumerate(val_dataset):
            images = samples['image']
            image_paths, origin_sizes = samples['image_path'], samples[
                'origin_size']
            batch_detections = self._detect(images,
                                            conf_thres=0.001,
                                            nms_thres=0.45)
            for idx, detections in enumerate(batch_detections):
                # The numeric image id is taken from the 12 characters that
                # precede the 4-character extension of the file name.
                image_id = int(os.path.basename(image_paths[idx])[-16:-4])
                coco_img_ids.add(image_id)
                if detections is not None:
                    # SECURITY NOTE(review): eval() on a string coming from
                    # the dataset; safe only while the dataset is trusted.
                    # Consider ast.literal_eval if the format allows it.
                    origin_size = eval(origin_sizes[idx])
                    detections = detections.cpu().numpy()
                    # Undo the square padding: the shorter side was padded
                    # by dim_diff split into pad1/pad2.
                    dim_diff = np.abs(origin_size[0] - origin_size[1])
                    pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2
                    if origin_size[1] <= origin_size[0]:
                        pad = ((pad1, pad2), (0, 0), (0, 0))
                        scale = origin_size[0]
                    else:
                        pad = ((0, 0), (pad1, pad2), (0, 0))
                        scale = origin_size[1]
                    for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
                        x1 = x1 / self.config.image_size * scale
                        x2 = x2 / self.config.image_size * scale
                        y1 = y1 / self.config.image_size * scale
                        y2 = y2 / self.config.image_size * scale
                        x1 -= pad[1][0]
                        y1 -= pad[0][0]
                        x2 -= pad[1][0]
                        y2 -= pad[0][0]
                        w = x2 - x1
                        h = y2 - y1
                        coco_result.append({
                            "image_id":
                            image_id,
                            "category_id":
                            index2category[str(int(cls_pred.item()))],
                            "bbox": (float(x1), float(y1), float(w),
                                     float(h)),
                            "score":
                            float(conf),
                        })
            logging.info("Now have finished [%.3d/%.3d]" %
                         (step, len(val_dataset)))

        save_path = "coco_results.json"
        with open(save_path, "w") as f:
            json.dump(coco_result,
                      f,
                      sort_keys=True,
                      indent=4,
                      separators=(',', ':'))
        logging.info('Save result in {}'.format(save_path))

        logging.info('Using COCO APi to evaluate')
        cocoGt = COCO(self.config.annotation)
        cocoDt = cocoGt.loadRes(save_path)
        cocoEval = COCOeval(cocoGt, cocoDt, "bbox")
        cocoEval.params.imgIds = list(coco_img_ids)
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()

    @staticmethod
    def _voc_ap(rec, prec, use_07_metric=False):
        """Compute VOC average precision from recall/precision arrays.

        Args:
            rec: numpy array of cumulative recall values.
            prec: numpy array of cumulative precision values.
            use_07_metric: when true use the VOC07 11-point interpolation,
                otherwise the area under the monotone precision envelope.
        """
        if use_07_metric:
            # 11 point metric
            ap = 0.
            for t in np.arange(0., 1.1, 0.1):
                if np.sum(rec >= t) == 0:
                    p = 0
                else:
                    p = np.max(prec[rec >= t])
                ap = ap + p / 11.
        else:
            # correct AP calculation
            # first append sentinel values at the end
            mrec = np.concatenate(([0.], rec, [1.]))
            mpre = np.concatenate(([0.], prec, [0.]))

            # compute the precision envelope
            for i in range(mpre.size - 1, 0, -1):
                mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

            # to calculate area under PR curve, look for points
            # where X axis (recall) changes value
            i = np.where(mrec[1:] != mrec[:-1])[0]

            # and sum (\Delta recall) * prec
            ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
        return ap

    @staticmethod
    def _calculate_ap(correct, conf, pred_cls, total, classes):
        """Compute per-class AP and their mean.

        Args:
            correct: per-detection 0/1 flags (1 = matched a ground truth).
            conf: per-detection confidence, used to rank the detections.
            pred_cls: per-detection predicted class index.
            total: mapping class name -> number of ground-truth boxes.
            classes: class names; position in the list is the class index.

        Returns:
            ``(mAP, AP)`` where ``AP`` maps class name -> AP. Classes with
            neither ground truth nor predictions are skipped; ``mAP`` is 0.0
            when every class was skipped (instead of NaN).
        """
        correct, conf, pred_cls = np.array(correct), np.array(
            conf), np.array(pred_cls)
        index = np.argsort(-conf)
        correct, pred_cls = correct[index], pred_cls[index]

        ap = []
        AP = {}
        for i, c in enumerate(classes):
            k = pred_cls == i
            n_gt = total[c]
            n_p = int(np.sum(k))

            if n_gt == 0 and n_p == 0:
                continue
            elif n_p == 0 or n_gt == 0:
                ap.append(0)
                AP[c] = 0
            else:
                fpc = np.cumsum(1 - correct[k])
                tpc = np.cumsum(correct[k])

                rec = tpc / n_gt
                prec = tpc / (tpc + fpc)

                _ap = Mode._voc_ap(rec, prec)
                ap.append(_ap)
                AP[c] = _ap
        # Guard against the mean of an empty list, which would be NaN.
        mAP = float(np.mean(ap)) if ap else 0.0
        return mAP, AP

    @staticmethod
    def _parse_rec(imagename, classes):
        """Parse the VOC XML annotation that sits next to ``imagename``.

        Objects whose class is not in ``classes`` or that are flagged
        difficult are skipped. Each returned row is
        ``[xmin, xmax, ymin, ymax, class_index]``.
        """
        filename = imagename.replace('jpg', 'xml')
        tree = ET.parse(filename)
        objects = []
        for obj in tree.findall('object'):
            difficult = obj.find('difficult').text
            cls = obj.find('name').text
            if cls not in classes or int(difficult) == 1:
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            objects.append([
                float(xmlbox.find('xmin').text),
                float(xmlbox.find('xmax').text),
                float(xmlbox.find('ymin').text),
                float(xmlbox.find('ymax').text), cls_id
            ])
        return np.asarray(objects)

    def eval_voc(self, val_dataset, classes, iou_thresh=0.5):
        """Evaluate mAP in the VOC style on ``val_dataset``.

        Args:
            val_dataset: iterable of batches with ``'image'`` and ``'label'``.
            classes: class names indexed by class id.
            iou_thresh: a detection counts as correct when its best IoU
                exceeds this value, the class matches and that ground truth
                has not been matched before.

        The raw matching results are written to ``results.json`` and read
        back before the AP computation; the mAP is logged.
        """
        logging.info('Start Evaling')
        results = {}

        total = {cls: 0 for cls in classes}
        correct = []
        conf_list = []
        pred_list = []

        for step, samples in enumerate(val_dataset):
            images, labels = samples['image'], samples['label']
            logging.info("Now have finished [%.3d/%.3d]" %
                         (step, len(val_dataset)))
            batch_detections = self._detect(images,
                                            conf_thres=0.001,
                                            nms_thres=0.4)
            for idx, detections in enumerate(batch_detections):
                label = labels[idx]
                # Labels are truncated at the first all-zero row.
                for t in range(label.size(0)):
                    if label[t, :].sum() == 0:
                        label = label[:t, :]
                        break
                label_cls = np.array(label[:, 0])
                for cls_id in label_cls:
                    total[classes[int(cls_id)]] += 1

                if detections is None:
                    if label.size(0) != 0:
                        # One false entry per class that was present but
                        # got no detection at all.
                        label_cls = np.unique(label_cls)
                        for cls_id in label_cls:
                            correct.append(0)
                            conf_list.append(1)
                            pred_list.append(int(cls_id))
                    continue

                if label.size(0) == 0:
                    # No ground truth: every detection is a false positive.
                    for *pred_box, conf, cls_conf, cls_pred in detections:
                        correct.append(0)
                        # float() keeps the list JSON-serializable; a raw
                        # tensor here made json.dump fail.
                        conf_list.append(float(conf))
                        pred_list.append(int(cls_pred))
                else:
                    detections = detections[np.argsort(-detections[:, 4])]
                    detected = []
                    for *pred_box, conf, cls_conf, cls_pred in detections:
                        pred_box = torch.FloatTensor(pred_box).view(1, -1)
                        # corner format -> (center_x, center_y, w, h),
                        # then normalize by the network input size.
                        pred_box[:, 2:] = pred_box[:, 2:] - pred_box[:, :2]
                        pred_box[:, :2] = pred_box[:, :2] + pred_box[:,
                                                                     2:] / 2
                        pred_box = pred_box / self.config.image_size
                        ious = bbox_iou(pred_box, label[:, 1:])
                        best_i = int(np.argmax(ious))
                        if ious[best_i] > iou_thresh and int(
                                cls_pred) == int(label[
                                    best_i, 0]) and best_i not in detected:
                            correct.append(1)
                            detected.append(best_i)
                        else:
                            correct.append(0)
                        pred_list.append(int(cls_pred))
                        conf_list.append(float(conf))

        results['correct'] = correct
        results['conf'] = conf_list
        results['pred_cls'] = pred_list
        results['total'] = total
        with open('results.json', 'w') as f:
            json.dump(results, f)
            logging.info('Having saved to results.json')

        logging.info('Begin calculating....')
        with open('results.json', 'r') as result_file:
            results = json.load(result_file)
        mAP, AP_class = self._calculate_ap(correct=results['correct'],
                                           conf=results['conf'],
                                           pred_cls=results['pred_cls'],
                                           total=results['total'],
                                           classes=classes)
        logging.info('mAP(IoU=0.5):{:.1f}'.format(mAP * 100))

    def inference(self, image, classes, colors):
        """Detect objects in one image and draw them onto it.

        Args:
            image: image array as accepted by ``cv2.cvtColor`` with
                ``COLOR_BGR2RGB``; boxes are drawn on this same array.
            classes: class names indexed by predicted class id.
            colors: drawing colors indexed by predicted class id.

        Returns:
            ``(annotated_image, elapsed_seconds)`` where the time covers the
            transfer to the GPU (if any), the forward pass and NMS.
        """
        image_origin = image
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image,
                           (self.config.image_size, self.config.image_size),
                           interpolation=cv2.INTER_LINEAR)
        image = np.expand_dims(image, 0)
        image = image.astype(np.float32)
        image /= 255
        # NHWC -> NCHW
        image = np.transpose(image, (0, 3, 1, 2))
        image = image.astype(np.float32)
        image = torch.from_numpy(image)
        start_time = time.time()
        if torch.cuda.is_available():
            image = image.cuda()
        batch_detections = self._detect(image, conf_thres=0.5, nms_thres=0.4)
        spand_time = float(time.time() - start_time)
        detection = batch_detections[0]
        if detection is not None:
            origin_size = image_origin.shape[:2]
            detection = detection.cpu().numpy()
            for x1, y1, x2, y2, conf, cls_conf, cls_pred in detection:
                # Map from network input coordinates back to the original
                # image (width for x, height for y).
                x1 = int(x1 / self.config.image_size * origin_size[1])
                x2 = int(x2 / self.config.image_size * origin_size[1])
                y1 = int(y1 / self.config.image_size * origin_size[0])
                y2 = int(y2 / self.config.image_size * origin_size[0])
                color = colors[int(cls_pred)]
                image_origin = cv2.rectangle(image_origin, (x1, y1),
                                             (x2, y2), color, 3)
                # Filled strip at the top of the box as caption background.
                image_origin = cv2.rectangle(image_origin, (x1, y1),
                                             (x2, y1 + 20),
                                             color,
                                             thickness=-1)
                caption = "{}:{:.2f}".format(classes[int(cls_pred)],
                                             cls_conf)
                image_origin = cv2.putText(image_origin, caption,
                                           (x1, y1 + 15),
                                           cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                                           (255, 255, 255), 2)
        return image_origin, spand_time
def train(opt):
    """Train a recognition model as described by ``opt``.

    Steps, in order: build the balanced training dataset and the validation
    loader, build the label converter and model, initialize the weights,
    optionally load ``opt.saved_model``, set up the loss and the optimizer,
    dump the options to ``opt.txt`` and run the training loop.

    The loop validates every ``opt.valInterval`` iterations (and at the very
    first one), keeps ``best_accuracy.pth`` / ``best_norm_ED.pth`` under
    ``./saved_models/<exp_name>/``, saves a snapshot every 1e5 iterations and
    terminates the process with ``sys.exit()`` after ``opt.num_iter``
    iterations.

    Side effects on ``opt``: ``select_data`` and ``batch_ratio`` are replaced
    by their ``'-'``-split lists, ``num_class`` is set and ``input_channel``
    is forced to 3 when ``opt.rgb`` is truthy.
    """
    # dataset preparation
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # Context manager so the dataset log is closed even if building the
    # validation set raises.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=
            True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # model configuration
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # Deliberate best-effort fallback: weights that the initializer
            # rejects are filled with ones instead.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            # Non-strict loading tolerates missing/unexpected keys.
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # setup loss
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # final options
    # print(opt)
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # start training
    start_iter = 0
    if opt.saved_model != '':
        try:
            # Checkpoints named like ``..._<iteration>.pth`` carry the
            # iteration to resume from.
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # The file name does not end in a number: start from zero.
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if (
                iteration + 1
        ) % opt.valInterval == 0 or iteration == 0:  # To see training progress, we also conduct validation when 'iteration == 0'
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration + 1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        # cut both strings at the end-of-sentence token
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            torch.save(
                model.state_dict(),
                f'./saved_models/{opt.exp_name}/iter_{iteration + 1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
def demo(opt):
    """Run a saved recognition model on the images in ``opt.image_folder``.

    For every batch the predicted text and a confidence score (product of
    the per-step maximum probabilities) are printed and appended to
    ``./log_demo_result.txt``.

    Side effects on ``opt``: ``num_class`` is set and ``input_channel`` is
    forced to 3 when ``opt.rgb`` is truthy.
    """
    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=0,  # In Linux use int(opt.workers), in Windows 0
        collate_fn=AlignCollate_demo,
        pin_memory=True)

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(
                batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)

                # Select max probabilty (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                # preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index, preds_size)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probabilty (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            # Opened per batch in append mode; the context manager closes the
            # file even if formatting a result raises.
            with open('./log_demo_result.txt', 'a') as log:
                dashed_line = '-' * 80
                head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

                print(f'{dashed_line}\n{head}\n{dashed_line}')
                log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

                preds_prob = F.softmax(preds, dim=2)
                preds_max_prob, _ = preds_prob.max(dim=2)
                for img_name, pred, pred_max_prob in zip(
                        image_path_list, preds_str, preds_max_prob):
                    if 'Attn' in opt.Prediction:
                        pred_EOS = pred.find('[s]')
                        pred = pred[:
                                    pred_EOS]  # prune after "end of sentence" token ([s])
                        pred_max_prob = pred_max_prob[:pred_EOS]

                    # calculate confidence score (= multiply of pred_max_prob)
                    if pred_max_prob.numel() > 0:
                        confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                    else:
                        # Empty prediction ([s] at position 0): indexing the
                        # empty cumprod would raise IndexError, so report 0.
                        confidence_score = 0.0

                    print(
                        f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}'
                    )
                    log.write(
                        f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n'
                    )