def __init__(self):
    """Build the CRNN recognizer: load config, pick a device, restore weights.

    Loads the checkpoint path given by ``config.checkpoint``; accepts both a
    raw ``state_dict`` and a wrapped ``{'state_dict': ...}`` training checkpoint.
    Leaves the model in eval mode, ready for inference.
    """
    self.config = get_config()
    self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    self.model = crnn.get_crnn(self.config).to(self.device)
    # map_location is required: without it a checkpoint saved on GPU fails to
    # load on a CPU-only machine, defeating the CPU fallback chosen above.
    checkpoint = torch.load(self.config.checkpoint, map_location=self.device)
    if 'state_dict' in checkpoint.keys():
        # full training checkpoint: weights are nested under 'state_dict'
        self.model.load_state_dict(checkpoint['state_dict'])
    else:
        # bare state_dict export
        self.model.load_state_dict(checkpoint)
    self.model.eval()
def main():
    """Evaluate the CRNN model once over the validation split."""
    config, args = parse_arg()

    # Prefer the configured GPU when CUDA is present, otherwise run on CPU.
    device = (torch.device("cuda:{}".format(config.GPUID))
              if torch.cuda.is_available()
              else torch.device("cpu:0"))

    model = crnn.get_crnn(config)
    model = model.to(device)
    criterion = torch.nn.CTCLoss()

    # Validation data pipeline.
    val_dataset = get_dataset(config)(config, is_train=False)
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=config.TEST.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    # Converter maps between label strings and class indices for CTC decoding.
    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
    acc = validate(config, val_loader, val_dataset, converter, model,
                   criterion, device)
def main():
    """Train the CRNN model: set up data, optimize per epoch, validate, checkpoint.

    Reads all hyper-parameters from the parsed config; writes tensorboard logs
    and per-epoch checkpoints into the folders created by
    ``utils.create_log_folder``.
    """
    # load config
    config = parse_arg()

    # create output folder (checkpoints dir + tensorboard dir)
    output_dict = utils.create_log_folder(config, phase='train')

    # cudnn
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # writer dict
    writer_dict = {
        'writer': SummaryWriter(log_dir=output_dict['tb_dir']),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # construct the recognition network
    model = crnn.get_crnn(config)

    # get device
    if torch.cuda.is_available():
        device = torch.device("cuda:{}".format(config.GPUID))
    else:
        device = torch.device("cpu:0")
    model = model.to(device)

    # define loss function
    criterion = torch.nn.CTCLoss()
    optimizer = utils.get_optimizer(config, model)

    last_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME.IS_RESUME:
        model_state_file = config.TRAIN.RESUME.FILE
        if model_state_file == '':
            # FIX: previously fell through and crashed on torch.load('') with an
            # unrelated error; now we warn and train from scratch instead.
            print(" => no checkpoint found")
        else:
            checkpoint = torch.load(model_state_file, map_location='cpu')
            model.load_state_dict(checkpoint['state_dict'])
            last_epoch = checkpoint['epoch']

    # MultiStepLR when LR_STEP is a milestone list, StepLR for a single step size
    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch - 1
        )
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch - 1
        )

    train_dataset = get_dataset(config)(config, is_train=True)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    val_dataset = get_dataset(config)(config, is_train=False)
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=config.TEST.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    best_acc = 0.5
    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        function.train(config, train_loader, train_dataset, converter, model,
                       criterion, optimizer, device, epoch, writer_dict,
                       output_dict)
        lr_scheduler.step()

        acc = function.validate(config, val_loader, val_dataset, converter,
                                model, criterion, device, epoch, writer_dict,
                                output_dict)

        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        print("is best:", is_best)
        print("best acc is:", best_acc)

        # save checkpoint (epoch is stored as epoch + 1 so a resume restarts
        # at the next epoch)
        torch.save(
            {
                "state_dict": model.state_dict(),
                "epoch": epoch + 1,
                "best_acc": best_acc,
            },
            os.path.join(output_dict['chs_dir'],
                         "checkpoint_{}_acc_{:.4f}.pth".format(epoch, acc))
        )

    writer_dict['writer'].close()
alphabets_char_file = [ char.strip() for char in open(config_crnn.DATASET.CHAR_FILE).readlines()[1:] ] alphabets_char_file = ''.join(alphabets_char_file) config_crnn.DATASET.ALPHABETS = alphabets_char_file else: config_crnn.DATASET.ALPHABETS = alphabets.alphabet # alphabets_char_file = [char.split('\n')[0] for char in open(args.alphabet_path).readlines()[1:]] # alphabets_char_file = "".join(alphabets_char_file) # config_crnn.DATASET.ALPHABETS = alphabets_char_file config_crnn.MODEL.NUM_CLASSES = len(config_crnn.DATASET.ALPHABETS) ocr_rec_model = CRNN_model.get_crnn(config_crnn).to(0) print("loading pretrained rec model from {0}".format( args.ocr_recognition_model_path)) checkpoint = torch.load(args.ocr_recognition_model_path) if "state_dict" in checkpoint.keys(): ocr_rec_model.load_state_dict(checkpoint["state_dict"]) else: ocr_rec_model.load_state_dict(checkpoint) f = open(args.wts_save_path, 'w') f.write("{}\n".format(len(ocr_rec_model.state_dict().keys()))) for k, v in ocr_rec_model.state_dict().items(): print('key: ', k) print('value: ', v.shape) vr = v.reshape(-1).cpu().numpy() f.write("{} {}".format(k, len(vr)))
_, preds = preds.max(2) preds = preds.transpose(1, 0).contiguous().view(-1) preds_size = Variable(torch.IntTensor([preds.size(0)])) sim_pred = converter.decode(preds.data, preds_size.data, raw=False) print('results: {0}'.format(sim_pred)) if __name__ == '__main__': config, args = parse_arg() device = torch.device( 'cuda:0') if torch.cuda.is_available() else torch.device('cpu') model = crnn.get_crnn(config).to(device) print('loading pretrained model from {0}'.format(args.checkpoint)) checkpoint = torch.load(args.checkpoint) if 'state_dict' in checkpoint.keys(): model.load_state_dict(checkpoint['state_dict']) else: model.load_state_dict(checkpoint) started = time.time() img = cv2.imread(args.image_path) img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) converter = utils.strLabelConverter(config.DATASET.ALPHABETS) recognition(config, img, model, converter, device)
def main():
    """Train the CRNN model, with optional fine-tuning or resume.

    Fine-tune mode loads only the CNN backbone from a checkpoint (optionally
    freezing it); resume mode restores the full model and the epoch counter.
    Uses the warp-ctc ``CTCLoss`` bound at module level.
    """
    # load config
    config = parse_arg()

    # create output folder (checkpoints dir + tensorboard dir)
    output_dict = utils.create_log_folder(config, phase='train')

    # cudnn
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # writer dict
    writer_dict = {
        'writer': SummaryWriter(log_dir=output_dict['tb_dir']),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # construct the recognition network
    model = crnn.get_crnn(config)

    # get device
    if torch.cuda.is_available():
        device = torch.device("cuda:{}".format(config.GPUID))
    else:
        device = torch.device("cpu:0")
    model = model.to(device)

    # define loss function (module-level CTCLoss, deliberately not
    # torch.nn.CTCLoss — presumably the warp-ctc binding; confirm at import)
    criterion = CTCLoss()

    last_epoch = config.TRAIN.BEGIN_EPOCH
    optimizer = utils.get_optimizer(config, model)
    # NOTE(review): the scheduler is built before a resume updates last_epoch,
    # so a resumed run restarts the LR schedule from BEGIN_EPOCH.
    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch - 1)

    if config.TRAIN.FINETUNE.IS_FINETUNE:
        # load only the CNN backbone weights from the finetune checkpoint
        model_state_file = config.TRAIN.FINETUNE.FINETUNE_CHECKPOINIT
        if model_state_file == '':
            # FIX: previously fell through and crashed on torch.load('')
            print(" => no checkpoint found")
        else:
            checkpoint = torch.load(model_state_file, map_location='cpu')
            if 'state_dict' in checkpoint.keys():
                checkpoint = checkpoint['state_dict']

            from collections import OrderedDict
            model_dict = OrderedDict()
            for k, v in checkpoint.items():
                # backbone keys look like 'cnn.<layer>'; strip the 4-char
                # 'cnn.' prefix before loading into model.cnn
                if 'cnn' in k:
                    model_dict[k[4:]] = v
            model.cnn.load_state_dict(model_dict)
            if config.TRAIN.FINETUNE.FREEZE:
                # freeze the backbone so only the recurrent head trains
                for p in model.cnn.parameters():
                    p.requires_grad = False

    elif config.TRAIN.RESUME.IS_RESUME:
        model_state_file = config.TRAIN.RESUME.FILE
        if model_state_file == '':
            # FIX: previously fell through and crashed on torch.load('')
            print(" => no checkpoint found")
        else:
            checkpoint = torch.load(model_state_file, map_location='cpu')
            if 'state_dict' in checkpoint.keys():
                model.load_state_dict(checkpoint['state_dict'])
                last_epoch = checkpoint['epoch']
            else:
                model.load_state_dict(checkpoint)

    model_info(model)

    train_dataset = get_dataset(config)(config, is_train=True)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    val_dataset = get_dataset(config)(config, is_train=False)
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=config.TEST.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    best_acc = 0.5
    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        function.train(config, train_loader, train_dataset, converter, model,
                       criterion, optimizer, device, epoch, writer_dict,
                       output_dict)
        lr_scheduler.step()

        acc = function.validate(config, val_loader, val_dataset, converter,
                                model, criterion, device, epoch, writer_dict,
                                output_dict)

        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        print("is best:", is_best)
        print("best acc is:", best_acc)

        # save checkpoint (epoch stored as epoch + 1 so a resume restarts at
        # the next epoch)
        torch.save(
            {
                "state_dict": model.state_dict(),
                "epoch": epoch + 1,
                "best_acc": best_acc,
            },
            os.path.join(output_dict['chs_dir'],
                         "checkpoint_{}_acc_{:.4f}.pth".format(epoch, acc)))

    writer_dict['writer'].close()
preds = preds.transpose(1, 0).contiguous().view(-1) preds_size = Variable(torch.IntTensor([preds.size(0)])) sim_pred = converter.decode(preds.data, preds_size.data, raw=False) # 写入结果 with open('/home/lab/zts/datasets/warped_100_images_results.txt', 'a') as f: f.write(image_name + ' ' + sim_pred + '\n') f.close() print('results: {0}'.format(sim_pred)) if __name__ == '__main__': # 参数解析 config, args = parse_args() # 加载模型 device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') model = crnn.get_crnn(config, len(alphabets.alphabet)).to(device) print('loading pre trained model from {0}'.format(args.checkpoint)) checkpoint = torch.load(args.checkpoint) if 'state_dict' in checkpoint.keys(): model.load_state_dict(checkpoint['state_dict']) else: model.load_state_dict(checkpoint) # time started = time.time() # 标签转换 converter = utils.strLabelConverter(config.DATASET.ALPHABETS) # 处理文件夹 image_folder = args.image_folder for image_name in os.listdir(image_folder): image_path = os.path.join(image_folder, image_name) img = cv2.imread(image_path)