Example No. 1
def main():
    eval_batch_size = config["eval_batch_size"]
    cpu_workers = config["cpu_workers"]
    reload_checkpoint = config["reload_checkpoint"]

    img_height = config["img_height"]
    img_width = config["img_width"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")

    test_dataset = Synth90kDataset(
        root_dir=config["data_dir"],
        mode="test",
        img_height=img_height,
        img_width=img_width,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=cpu_workers,
        collate_fn=synth90k_collate_fn,
    )

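    # +1 reserves a class index for the CTC blank label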
    num_class = len(Synth90kDataset.LABEL2CHAR) + 1
    crnn = CRNN(
        1,
        img_height,
        img_width,
        num_class,
        map_to_seq_hidden=config["map_to_seq_hidden"],
        rnn_hidden=config["rnn_hidden"],
        leaky_relu=config["leaky_relu"],
    )
    crnn.load_state_dict(torch.load(reload_checkpoint, map_location=device))
    crnn.to(device)

    criterion = CTCLoss(reduction="sum")
    criterion.to(device)

    evaluation = evaluate(
        crnn,
        test_loader,
        criterion,
        decode_method=config["decode_method"],
        beam_size=config["beam_size"],
    )
    print("test_evaluation: loss={loss}, acc={acc}".format(**evaluation))
Example No. 2
def main():
    eval_batch_size = config['eval_batch_size']
    cpu_workers = config['cpu_workers']
    reload_checkpoint = config['reload_checkpoint']

    img_height = config['img_height']
    img_width = config['img_width']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device: {device}')

    test_dataset = Synth90kDataset(root_dir=config['data_dir'],
                                   mode='test',
                                   img_height=img_height,
                                   img_width=img_width)

    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=eval_batch_size,
                             shuffle=False,
                             num_workers=cpu_workers,
                             collate_fn=synth90k_collate_fn)

    num_class = len(Synth90kDataset.LABEL2CHAR) + 1
    crnn = CRNN(1,
                img_height,
                img_width,
                num_class,
                map_to_seq_hidden=config['map_to_seq_hidden'],
                rnn_hidden=config['rnn_hidden'],
                leaky_relu=config['leaky_relu'])
    crnn.load_state_dict(torch.load(reload_checkpoint, map_location=device))
    crnn.to(device)

    criterion = CTCLoss(reduction='sum')
    criterion.to(device)

    evaluation = evaluate(crnn,
                          test_loader,
                          criterion,
                          decode_method=config['decode_method'],
                          beam_size=config['beam_size'])
    print('test_evaluation: loss={loss}, acc={acc}'.format(**evaluation))
Example No. 3
def main():
    arguments = docopt(__doc__)

    images = arguments['IMAGE']
    reload_checkpoint = arguments['-m']
    batch_size = int(arguments['-s'])
    decode_method = arguments['-d']
    beam_size = int(arguments['-b'])

    img_height = config['img_height']
    img_width = config['img_width']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device: {device}')

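    # Build the dataset directly from the image paths passed on the command line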
    predict_dataset = Synth90kDataset(paths=images,
                                      img_height=img_height,
                                      img_width=img_width)

    predict_loader = DataLoader(dataset=predict_dataset,
                                batch_size=batch_size,
                                shuffle=False)

    num_class = len(Synth90kDataset.LABEL2CHAR) + 1
    crnn = CRNN(1,
                img_height,
                img_width,
                num_class,
                map_to_seq_hidden=config['map_to_seq_hidden'],
                rnn_hidden=config['rnn_hidden'],
                leaky_relu=config['leaky_relu'])
    crnn.load_state_dict(torch.load(reload_checkpoint, map_location=device))
    crnn.to(device)

    preds = predict(crnn,
                    predict_loader,
                    Synth90kDataset.LABEL2CHAR,
                    decode_method=decode_method,
                    beam_size=beam_size)

    show_result(images, preds)
Example No. 4
def main():
    epochs = config['epochs']
    train_batch_size = config['train_batch_size']
    eval_batch_size = config['eval_batch_size']
    lr = config['lr']
    show_interval = config['show_interval']
    valid_interval = config['valid_interval']
    save_interval = config['save_interval']
    cpu_workers = config['cpu_workers']
    reload_checkpoint = config['reload_checkpoint']
    valid_max_iter = config['valid_max_iter']

    img_width = config['img_width']
    img_height = config['img_height']
    data_dir = config['data_dir']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device: {device}')

    train_dataset = Synth90kDataset(root_dir=data_dir,
                                    mode='train',
                                    img_height=img_height,
                                    img_width=img_width)
    valid_dataset = Synth90kDataset(root_dir=data_dir,
                                    mode='dev',
                                    img_height=img_height,
                                    img_width=img_width)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=train_batch_size,
                              shuffle=True,
                              num_workers=cpu_workers,
                              collate_fn=synth90k_collate_fn)
    valid_loader = DataLoader(dataset=valid_dataset,
                              batch_size=eval_batch_size,
                              shuffle=True,
                              num_workers=cpu_workers,
                              collate_fn=synth90k_collate_fn)

    num_class = len(Synth90kDataset.LABEL2CHAR) + 1
    crnn = CRNN(1,
                img_height,
                img_width,
                num_class,
                map_to_seq_hidden=config['map_to_seq_hidden'],
                rnn_hidden=config['rnn_hidden'],
                leaky_relu=config['leaky_relu'])
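    # Optionally resume from a previously saved checkpoint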
    if reload_checkpoint:
        crnn.load_state_dict(torch.load(reload_checkpoint,
                                        map_location=device))
    crnn.to(device)

    optimizer = optim.RMSprop(crnn.parameters(), lr=lr)
    criterion = CTCLoss(reduction='sum')
    criterion.to(device)

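    # Checkpoints are only written on validation steps, so save_interval must be a multiple of valid_interval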
    assert save_interval % valid_interval == 0
    i = 1
    for epoch in range(1, epochs + 1):
        print(f'epoch: {epoch}')
        tot_train_loss = 0.
        tot_train_count = 0
        for train_data in train_loader:
            loss = train_batch(crnn, train_data, optimizer, criterion, device)
            train_size = train_data[0].size(0)

            tot_train_loss += loss
            tot_train_count += train_size
            if i % show_interval == 0:
                print('train_batch_loss[', i, ']: ', loss / train_size)

            if i % valid_interval == 0:
                evaluation = evaluate(crnn,
                                      valid_loader,
                                      criterion,
                                      decode_method=config['decode_method'],
                                      beam_size=config['beam_size'])
                print('valid_evaluation: loss={loss}, acc={acc}'.format(
                    **evaluation))

                if i % save_interval == 0:
                    prefix = 'crnn'
                    loss = evaluation['loss']
                    save_model_path = os.path.join(
                        config['checkpoints_dir'],
                        f'{prefix}_{i:06}_loss{loss}.pt')
                    torch.save(crnn.state_dict(), save_model_path)
                    print('save model at ', save_model_path)

            i += 1

        print('train_loss: ', tot_train_loss / tot_train_count)
Example No. 5
def main():
    epochs = config["epochs"]
    train_batch_size = config["train_batch_size"]
    eval_batch_size = config["eval_batch_size"]
    lr = config["lr"]
    show_interval = config["show_interval"]
    valid_interval = config["valid_interval"]
    save_interval = config["save_interval"]
    cpu_workers = config["cpu_workers"]
    reload_checkpoint = config["reload_checkpoint"]
    valid_max_iter = config["valid_max_iter"]

    img_width = config["img_width"]
    img_height = config["img_height"]
    data_dir = config["data_dir"]

    torch.cuda.empty_cache()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")

    # # Extracts metadata related to the file path, label, and image width + height.
    # wbsin_dir = Path.cwd() / "data" / "processed" / "cropped_wbsin"
    # wbsin_meta_df = extract_jpg_meta(img_dir=wbsin_dir, img_type="wbsin")
    # # Saves the extracted metadata.
    # interim_path = Path.cwd() / "data" / "interim"
    # interim_path.mkdir(parents=True, exist_ok=True)
    # wbsin_meta_df.to_csv(interim_path / "wbsin_meta.csv", index=False)

    X_transforms = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((160, 1440)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, )),
    ])

    wbsin_dataset = WbsinImageDataset(
        meta_file=(Path.cwd() / "data" / "processed" /
                   "processed_wbsin_meta.csv"),
        transform=X_transforms,
    )

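    # 80/20 train/test split; the fixed generator seed keeps the split reproducible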
    train_size = int(0.8 * len(wbsin_dataset))
    test_size = len(wbsin_dataset) - train_size

    train_dataset, test_dataset = torch.utils.data.random_split(
        wbsin_dataset,
        [train_size, test_size],
        generator=torch.Generator().manual_seed(42),
    )

    # Save the test_dataset

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=cpu_workers,
        collate_fn=wbsin_collate_fn,
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=eval_batch_size,
        shuffle=True,
        num_workers=cpu_workers,
        collate_fn=wbsin_collate_fn,
    )

    num_class = len(WbsinImageDataset.LABEL2CHAR) + 1
    crnn = CRNN(
        1,
        img_height,
        img_width,
        num_class,
        map_to_seq_hidden=config["map_to_seq_hidden"],
        rnn_hidden=config["rnn_hidden"],
        leaky_relu=config["leaky_relu"],
    )
    if reload_checkpoint:
        crnn.load_state_dict(torch.load(reload_checkpoint,
                                        map_location=device))
    crnn.to(device)

    optimizer = optim.RMSprop(crnn.parameters(), lr=lr)
    criterion = CTCLoss(reduction="sum")
    criterion.to(device)

    assert save_interval % valid_interval == 0
    i = 1
    for epoch in range(1, epochs + 1):
        print(f"epoch: {epoch}")
        tot_train_loss = 0.0
        tot_train_count = 0
        for train_data in train_dataloader:
            loss = train_batch(crnn, train_data, optimizer, criterion, device)
            train_size = train_data[0].size(0)

            tot_train_loss += loss
            tot_train_count += train_size
            if i % show_interval == 0:
                print("train_batch_loss[", i, "]: ", loss / train_size)

            if i % valid_interval == 0:
                evaluation = evaluate(
                    crnn,
                    test_dataloader,
                    criterion,
                    decode_method=config["decode_method"],
                    beam_size=config["beam_size"],
                )
                print(
                    "valid_evaluation: loss={loss}, acc={acc}, char_acc={char_acc}"
                    .format(**evaluation))

                if i % save_interval == 0:
                    prefix = "crnn"
                    loss = evaluation["loss"]
                    save_model_path = os.path.join(
                        config["checkpoints_dir"],
                        f"{prefix}_{i:06}_loss{loss}.pt")
                    torch.save(crnn.state_dict(), save_model_path)
                    print("save model at ", save_model_path)
            i += 1

        print("train_loss: ", tot_train_loss / tot_train_count)
Example No. 6
def main():
    arguments = docopt(__doc__)

    images = arguments["IMAGE"]
    images = [image for image in (Path.cwd() / images[0]).glob("*.jpg")]
    reload_checkpoint = arguments["-m"]
    batch_size = int(arguments["-s"])
    decode_method = arguments["-d"]
    beam_size = int(arguments["-b"])

    img_height = config["img_height"]
    img_width = config["img_width"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")

    X_transforms = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((160, 1440)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, )),
    ])
    wbsin_dataset = WbsinImageDataset(
        meta_file=(Path.cwd() / "data" / "interim" / "cropped_wbsin_meta.csv"),
        transform=X_transforms)

    train_size = int(0.8 * len(wbsin_dataset))
    test_size = len(wbsin_dataset) - train_size

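    # Same 80/20 split and seed as in training, so the train/test subsets line up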
    train_dataset, test_dataset = torch.utils.data.random_split(
        wbsin_dataset,
        [train_size, test_size],
        generator=torch.Generator().manual_seed(42),
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        collate_fn=wbsin_collate_fn,
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        collate_fn=wbsin_collate_fn,
    )

    num_class = len(WbsinImageDataset.LABEL2CHAR) + 1
    crnn = CRNN(
        1,
        img_height,
        img_width,
        num_class,
        map_to_seq_hidden=config["map_to_seq_hidden"],
        rnn_hidden=config["rnn_hidden"],
        leaky_relu=config["leaky_relu"],
    )
    crnn.load_state_dict(torch.load(reload_checkpoint, map_location=device))
    crnn.to(device)

    preds = predict(
        crnn,
        train_dataloader,
        WbsinImageDataset.LABEL2CHAR,
        decode_method=decode_method,
        beam_size=beam_size,
    )

    show_result(images, preds)
Example No. 7
    net.train()
    return running_loss / i


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([
        transforms.Resize((RESIZE_H, RESIZE_W)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    net = CRNN()
    net = net.to(device)
    criterion = nn.CTCLoss()
    optimizer = optim.Adam(net.parameters(),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)
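    # Halve the learning rate every 5000 scheduler steps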
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer,
                                             step_size=5000,
                                             gamma=0.5)

    if args.load_model:
        load_path = os.path.join(os.getcwd(), args.load_model)
        net.load_state_dict(torch.load(load_path))

    save_dir = os.path.join(os.getcwd(), args.save_dir)
    if not os.path.isdir(save_dir):
        os.mkdir(save_dir)
Example No. 8
best_recall = 0
best_f1 = 0

if not args.exp_name:
    exp_name = f"{args.cnn_type}_{args.only_cnn}_" + ("" if args.only_cnn else str(args.hidden_size))
else:
    exp_name = args.exp_name
model_dir = 'model_saves/' + exp_name

best_epoch = 0
while epoch <= epochs:
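    # Re-seed each epoch so data shuffling is reproducible from run to run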
    torch.manual_seed(SEED + epoch)
    torch.cuda.empty_cache()
    model.train()
    model.to(device)
    train_loss = 0
    train_loss_aux = 0
    total_train_correct_pred = 0
    total_train_correct_pred_aux = 0
    # enumerate(tqdm(train_loader, total=int(freqs_copy.sum()/train_bs)+1)) would show a progress bar
    for batch_num, (audio, label) in enumerate(train_loader):
        audio, label = audio.double().to(device), label.to(device)
        if args.cnn_type in ('vgg', 'inceptionv3_s', 'inceptionv3_m'):
            train_correct_pred = 0
            pred = model(audio)
            loss = criterion(pred, label)
            pred_labels = torch.argmax(pred, dim=1)
            train_correct_pred = torch.sum(pred_labels == label).item()
            total_train_correct_pred += train_correct_pred
Example No. 9
def main():

    parser = ArgumentParser()
    parser.add_argument('-d',
                        '--data_path',
                        dest='data_path',
                        type=str,
                        default='../../data/',
                        help='path to the data')
    parser.add_argument('--epochs',
                        '-e',
                        dest='epochs',
                        type=int,
                        help='number of train epochs',
                        default=2)
    parser.add_argument('--batch_size',
                        '-b',
                        dest='batch_size',
                        type=int,
                        help='batch size',
                        default=16)
    parser.add_argument('--load',
                        '-l',
                        dest='load',
                        type=str,
                        help='pretrained weights',
                        default=None)
    parser.add_argument('-v',
                        '--val_split',
                        dest='val_split',
                        default=0.8,
                        type=float,
                        help='train/val split')
    parser.add_argument('--augs',
                        '-a',
                        dest='augs',
                        type=float,
                        help='degree of geometric augs',
                        default=0)

    args = parser.parse_args()
    OCR_MODEL_PATH = '../pretrained/ocr.pt'

    all_marks = load_json(os.path.join(args.data_path, 'train.json'))
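    # The first val_split fraction of the annotations goes to training, the rest to validation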
    test_start = int(args.val_split * len(all_marks))
    train_marks = all_marks[:test_start]
    val_marks = all_marks[test_start:]

    w, h = (320, 64)
    train_transforms = transforms.Compose([
        #Rotate(max_angle=args.augs * 7.5, p=0.5),  # 5 -> 7.5
        #Pad(max_size=args.augs / 10, p=0.1),
        Resize(size=(w, h)),
        transforms.ToTensor()
    ])
    val_transforms = transforms.Compose(
        [Resize(size=(w, h)), transforms.ToTensor()])
    alphabet = abc

    train_dataset = OCRDataset(marks=train_marks,
                               img_folder=args.data_path,
                               alphabet=alphabet,
                               transforms=train_transforms)
    val_dataset = OCRDataset(marks=val_marks,
                             img_folder=args.data_path,
                             alphabet=alphabet,
                             transforms=val_transforms)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  drop_last=True,
                                  num_workers=0,
                                  collate_fn=collate_fn_ocr,
                                  timeout=0,
                                  shuffle=True)

    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                drop_last=False,
                                num_workers=0,
                                collate_fn=collate_fn_ocr,
                                timeout=0)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = CRNN(alphabet)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=3e-4,
                                 amsgrad=True,
                                 weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=10,
                                                           factor=0.5,
                                                           verbose=True)
    criterion = F.ctc_loss

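    # Save the current weights before exiting if training is interrupted with Ctrl-C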
    try:
        train(model, criterion, optimizer, scheduler, train_dataloader,
              val_dataloader, OCR_MODEL_PATH, args.epochs, device)
    except KeyboardInterrupt:
        torch.save(model.state_dict(), OCR_MODEL_PATH + 'INTERRUPTED_')
        #logger.info('Saved interrupt')
        sys.exit(0)