Example #1
    def __init__(self):
        self.config = get_config()
        self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        self.model = crnn.get_crnn(self.config).to(self.device)
        # load pretrained weights; map_location keeps this working on CPU-only machines
        checkpoint = torch.load(self.config.checkpoint, map_location=self.device)
        if 'state_dict' in checkpoint.keys():
            self.model.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint)
        self.model.eval()
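The snippet only shows the constructor of what appears to be an OCR wrapper class. A minimal sketch of an inference method such a class might expose (the method name, input size, and normalization constants are assumptions, not part of the original; the decode pattern mirrors Example #5 below):

import cv2
import numpy as np
import torch

@torch.no_grad()
def predict(self, img_bgr, converter, img_h=32, img_w=160):
    # grayscale, resize to the model's fixed input size, normalize
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    gray = cv2.resize(gray, (img_w, img_h)).astype(np.float32)
    gray = (gray / 255.0 - 0.5) / 0.5
    tensor = torch.from_numpy(gray).view(1, 1, img_h, img_w).to(self.device)

    preds = self.model(tensor)                            # (seq_len, 1, num_classes)
    _, preds = preds.max(2)
    preds = preds.transpose(1, 0).contiguous().view(-1)   # flatten to one label sequence
    preds_size = torch.IntTensor([preds.size(0)])
    return converter.decode(preds.data, preds_size, raw=False)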
Example #2
def main():
    config, args = parse_arg()
    model = crnn.get_crnn(config)
    criterion = torch.nn.CTCLoss()
    # get device
    if torch.cuda.is_available():
        device = torch.device("cuda:{}".format(config.GPUID))
    else:
        device = torch.device("cpu:0")
    model = model.to(device)
    val_dataset = get_dataset(config)(config, is_train=False)
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=config.TEST.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )
    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
    acc = validate(config, val_loader, val_dataset, converter, model,
                   criterion, device)
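The `validate` implementation itself is not shown in these snippets; a rough sketch of what a CTC validation loop of this shape usually does, assuming the loader yields `(images, labels)` pairs and that `converter.decode` returns a list of strings when given per-sample lengths:

@torch.no_grad()
def validate(config, val_loader, dataset, converter, model, criterion, device):
    model.eval()
    n_correct, n_total = 0, 0
    for images, labels in val_loader:
        images = images.to(device)
        preds = model(images)                                  # (seq_len, batch, num_classes)
        batch_size = images.size(0)
        preds_size = torch.IntTensor([preds.size(0)] * batch_size)
        _, idx = preds.max(2)
        idx = idx.transpose(1, 0).contiguous().view(-1)
        sim_preds = converter.decode(idx.data, preds_size.data, raw=False)
        for pred, target in zip(sim_preds, labels):
            n_correct += int(pred == target)
            n_total += 1
    # criterion could also be used here to report the validation CTC loss
    return n_correct / max(n_total, 1)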
Example #3
def main():

    # load config
    config = parse_arg()

    # create output folder
    output_dict = utils.create_log_folder(config, phase='train')

    # cudnn
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # writer dict
    writer_dict = {
        'writer': SummaryWriter(log_dir=output_dict['tb_dir']),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # construct the CRNN recognition network
    model = crnn.get_crnn(config)

    # get device
    if torch.cuda.is_available():
        device = torch.device("cuda:{}".format(config.GPUID))
    else:
        device = torch.device("cpu:0")

    model = model.to(device)

    # define loss function
    criterion = torch.nn.CTCLoss()

    optimizer = utils.get_optimizer(config, model)

    last_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME.IS_RESUME:
        model_state_file = config.TRAIN.RESUME.FILE
        if model_state_file == '':
            print(" => no checkpoint found")
        else:
            checkpoint = torch.load(model_state_file, map_location='cpu')
            model.load_state_dict(checkpoint['state_dict'])
            last_epoch = checkpoint['epoch']

    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch-1
        )
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch - 1
        )

    train_dataset = get_dataset(config)(config, is_train=True)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    val_dataset = get_dataset(config)(config, is_train=False)
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=config.TEST.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    best_acc = 0.5
    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):

        function.train(config, train_loader, train_dataset, converter, model, criterion, optimizer, device, epoch, writer_dict, output_dict)
        lr_scheduler.step()

        acc = function.validate(config, val_loader, val_dataset, converter, model, criterion, device, epoch, writer_dict, output_dict)

        is_best = acc > best_acc
        best_acc = max(acc, best_acc)

        print("is best:", is_best)
        print("best acc is:", best_acc)
        # save checkpoint
        torch.save(
            {
                "state_dict": model.state_dict(),
                "epoch": epoch + 1,
                "best_acc": best_acc,
            },  os.path.join(output_dict['chs_dir'], "checkpoint_{}_acc_{:.4f}.pth".format(epoch, acc))
        )

    writer_dict['writer'].close()
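`function.train` is not reproduced in these snippets; a minimal sketch of the per-batch work it has to do with `torch.nn.CTCLoss`, assuming the converter exposes an `encode` counterpart to the `decode` calls used elsewhere in these examples:

def train_step(model, images, labels, converter, criterion, optimizer, device):
    model.train()
    images = images.to(device)

    preds = model(images)                                    # (T, N, num_classes)
    batch_size = images.size(0)
    text, length = converter.encode(labels)                  # flattened targets + per-sample lengths
    preds_size = torch.IntTensor([preds.size(0)] * batch_size)

    # torch.nn.CTCLoss expects log-probabilities of shape (T, N, C)
    loss = criterion(preds.log_softmax(2), text, preds_size, length)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()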
Example #4
            alphabets_char_file = [
                char.strip()
                for char in open(config_crnn.DATASET.CHAR_FILE).readlines()[1:]
            ]
            alphabets_char_file = ''.join(alphabets_char_file)
            config_crnn.DATASET.ALPHABETS = alphabets_char_file
        else:
            config_crnn.DATASET.ALPHABETS = alphabets.alphabet

        # alphabets_char_file = [char.split('\n')[0] for char in open(args.alphabet_path).readlines()[1:]]
        # alphabets_char_file = "".join(alphabets_char_file)
        # config_crnn.DATASET.ALPHABETS = alphabets_char_file

        config_crnn.MODEL.NUM_CLASSES = len(config_crnn.DATASET.ALPHABETS)

    ocr_rec_model = CRNN_model.get_crnn(config_crnn).to(0)
    print("loading pretrained rec model from {0}".format(
        args.ocr_recognition_model_path))
    checkpoint = torch.load(args.ocr_recognition_model_path)
    if "state_dict" in checkpoint.keys():
        ocr_rec_model.load_state_dict(checkpoint["state_dict"])
    else:
        ocr_rec_model.load_state_dict(checkpoint)

    f = open(args.wts_save_path, 'w')
    f.write("{}\n".format(len(ocr_rec_model.state_dict().keys())))
    for k, v in ocr_rec_model.state_dict().items():
        print('key: ', k)
        print('value: ', v.shape)
        vr = v.reshape(-1).cpu().numpy()
        f.write("{} {}".format(k, len(vr)))
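The weight dump above stops after writing each key and element count. In the common TensorRT `.wts` format (as used by tensorrtx-style converters), every value is then appended as a big-endian float32 in hex; a sketch of the completed loop under that assumption:

import struct

with open(args.wts_save_path, 'w') as f:
    f.write("{}\n".format(len(ocr_rec_model.state_dict().keys())))
    for k, v in ocr_rec_model.state_dict().items():
        vr = v.reshape(-1).cpu().numpy()
        f.write("{} {}".format(k, len(vr)))
        for vv in vr:
            # each weight serialized as a big-endian float32 in hex
            f.write(" " + struct.pack(">f", float(vv)).hex())
        f.write("\n")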
Example #5
    # tail of recognition(): greedy decode of the CTC output
    _, preds = preds.max(2)
    preds = preds.transpose(1, 0).contiguous().view(-1)

    preds_size = Variable(torch.IntTensor([preds.size(0)]))
    sim_pred = converter.decode(preds.data, preds_size.data, raw=False)

    print('results: {0}'.format(sim_pred))


if __name__ == '__main__':

    config, args = parse_arg()
    device = torch.device(
        'cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    model = crnn.get_crnn(config).to(device)
    print('loading pretrained model from {0}'.format(args.checkpoint))
    checkpoint = torch.load(args.checkpoint)
    if 'state_dict' in checkpoint.keys():
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint)

    started = time.time()

    img = cv2.imread(args.image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)

    recognition(config, img, model, converter, device)
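Only the decoding tail of `recognition` appears above; a sketch of the preprocessing and forward pass its head typically performs (the input size and normalization constants below are placeholders, not the original config values):

import numpy as np

def recognition_head(img_gray, model, device, img_h=32, img_w=160, mean=0.5, std=0.5):
    # resize the grayscale crop to the model's fixed input size and normalize
    img = cv2.resize(img_gray, (img_w, img_h)).astype(np.float32)
    img = (img / 255.0 - mean) / std
    img = torch.from_numpy(img).view(1, 1, img_h, img_w).to(device)
    with torch.no_grad():
        preds = model(img)        # (seq_len, 1, num_classes), consumed by the decode code above
    return preds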
Example #6
def main():

    # load config
    config = parse_arg()

    # create output folder
    output_dict = utils.create_log_folder(config, phase='train')

    # cudnn
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # writer dict
    writer_dict = {
        'writer': SummaryWriter(log_dir=output_dict['tb_dir']),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # construct the CRNN recognition network
    model = crnn.get_crnn(config)
    #
    # checkpoint = torch.load('/data/yolov5/CRNN_Chinese_Characters_Rec/output/OWN/crnn/2020-09-15-22-13/checkpoints/checkpoint_98_acc_1.0983.pth')
    # if 'state_dict' in checkpoint.keys():
    #     model.load_state_dict(checkpoint['state_dict'])
    # else:
    #     model.load_state_dict(checkpoint)
    # get device
    if torch.cuda.is_available():
        device = torch.device("cuda:{}".format(config.GPUID))
    else:
        device = torch.device("cpu:0")

    model = model.to(device)

    # define loss function
    # criterion = torch.nn.CTCLoss()
    criterion = CTCLoss()

    last_epoch = config.TRAIN.BEGIN_EPOCH
    optimizer = utils.get_optimizer(config, model)
    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       config.TRAIN.LR_STEP,
                                                       config.TRAIN.LR_FACTOR,
                                                       last_epoch - 1)

    if config.TRAIN.FINETUNE.IS_FINETUNE:
        model_state_file = config.TRAIN.FINETUNE.FINETUNE_CHECKPOINIT
        if model_state_file == '':
            print(" => no checkpoint found")
        else:
            checkpoint = torch.load(model_state_file, map_location='cpu')
            if 'state_dict' in checkpoint.keys():
                checkpoint = checkpoint['state_dict']

            # keep only the CNN backbone weights and strip the 'cnn.' prefix
            from collections import OrderedDict
            model_dict = OrderedDict()
            for k, v in checkpoint.items():
                if 'cnn' in k:
                    model_dict[k[4:]] = v
            model.cnn.load_state_dict(model_dict)
            if config.TRAIN.FINETUNE.FREEZE:
                # freeze the backbone so only the recurrent head is trained
                for p in model.cnn.parameters():
                    p.requires_grad = False

    elif config.TRAIN.RESUME.IS_RESUME:
        model_state_file = config.TRAIN.RESUME.FILE
        if model_state_file == '':
            print(" => no checkpoint found")
        else:
            checkpoint = torch.load(model_state_file, map_location='cpu')
            if 'state_dict' in checkpoint.keys():
                model.load_state_dict(checkpoint['state_dict'])
                last_epoch = checkpoint['epoch']
                # optimizer.load_state_dict(checkpoint['optimizer'])
                # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            else:
                model.load_state_dict(checkpoint)

    model_info(model)
    train_dataset = get_dataset(config)(config, is_train=True)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    val_dataset = get_dataset(config)(config, is_train=False)
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=config.TEST.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    best_acc = 0.5
    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):

        function.train(config, train_loader, train_dataset, converter, model,
                       criterion, optimizer, device, epoch, writer_dict,
                       output_dict)
        lr_scheduler.step()

        acc = function.validate(config, val_loader, val_dataset, converter,
                                model, criterion, device, epoch, writer_dict,
                                output_dict)

        is_best = acc > best_acc
        best_acc = max(acc, best_acc)

        print("is best:", is_best)
        print("best acc is:", best_acc)
        # save checkpoint
        torch.save(
            {
                "state_dict": model.state_dict(),
                "epoch": epoch + 1,
                # "optimizer": optimizer.state_dict(),
                # "lr_scheduler": lr_scheduler.state_dict(),
                "best_acc": best_acc,
            },
            os.path.join(output_dict['chs_dir'],
                         "checkpoint_{}_acc_{:.4f}.pth".format(epoch, acc)))

    writer_dict['writer'].close()
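The commented-out lines in this example hint at resuming the optimizer and LR scheduler as well; a small sketch of how those states could be saved and restored (checkpoint_path is a placeholder):

# saving: extend the checkpoint dict written by torch.save above
torch.save({
    "state_dict": model.state_dict(),
    "epoch": epoch + 1,
    "optimizer": optimizer.state_dict(),
    "lr_scheduler": lr_scheduler.state_dict(),
    "best_acc": best_acc,
}, checkpoint_path)

# resuming: restore all three states and continue from the saved epoch
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
last_epoch = checkpoint["epoch"]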
Example #7
    # tail of recognition(): greedy decode of the CTC output
    preds = preds.transpose(1, 0).contiguous().view(-1)
    preds_size = Variable(torch.IntTensor([preds.size(0)]))
    sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
    # append the result to the output file (the with block closes it automatically)
    with open('/home/lab/zts/datasets/warped_100_images_results.txt', 'a') as f:
        f.write(image_name + ' ' + sim_pred + '\n')
    print('results: {0}'.format(sim_pred))


if __name__ == '__main__':
    # parse arguments
    config, args = parse_args()
    # load the model
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    model = crnn.get_crnn(config, len(alphabets.alphabet)).to(device)
    print('loading pretrained model from {0}'.format(args.checkpoint))
    checkpoint = torch.load(args.checkpoint)
    if 'state_dict' in checkpoint.keys():
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint)
    # time
    started = time.time()
    # label converter
    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
    # process the image folder
    image_folder = args.image_folder
    for image_name in os.listdir(image_folder):
        image_path = os.path.join(image_folder, image_name)
        img = cv2.imread(image_path)
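The folder loop above is cut off after reading each image; following the single-image flow of Example #5, the remaining per-image steps would look roughly like this (whether `image_name` is passed to `recognition` or picked up as a global is not shown, so passing it here is an assumption):

        # continue the loop body: convert to grayscale and run recognition on each image
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        recognition(config, img, model, converter, device, image_name)

    finished = time.time()
    print('elapsed time: {0}'.format(finished - started))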