def _load_records(path):
    """Read a line-oriented record file whose first line is a JSON header.

    Prints every ``key : value`` pair of the header object and returns the
    remaining lines (trailing newlines retained).

    Raises:
        ValueError: if the file is empty (no header line).
    """
    with open(path, 'r', encoding='utf-8') as imgf:
        lines = imgf.readlines()
    if not lines:
        raise ValueError('{} is empty; expected a JSON header line'.format(path))
    for key, value in json.loads(lines[0].strip('\r\n')).items():
        print(key, ':', value)
    return lines[1:]


def main(argv=None):
    """Train the CRNN indefinitely, validating after every epoch.

    Creates the model and TensorBoard directories if needed, loads the
    character list and the train/validation record files, then alternates
    ``train_on_batch`` and ``validate_on_batch`` forever (there is no stopping
    criterion; the loop runs until the process is interrupted).

    ``argv`` is accepted for entry-point compatibility but is not used.
    """
    os.makedirs(Constants.MODEL_DIR, exist_ok=True)
    os.makedirs(Constants.TENSORBOARD_DIR, exist_ok=True)

    # One character per line; only the newline is stripped so that
    # whitespace characters in the list survive.
    with open(Constants.CHARLIST_FILE, "r") as fp:
        char_list = [line.strip('\n') for line in fp]
    model = crnn(len(char_list))

    starter_learning_rate = 0.1
    optimizer = tf.compat.v1.train.AdadeltaOptimizer(learning_rate=starter_learning_rate)
    summary_writer = tf.contrib.summary.create_file_writer(Constants.TENSORBOARD_DIR)

    # Each record file starts with a JSON header that is printed and dropped.
    train_images = _load_records(Constants.TRAIN_TFRECORD)
    val_images = _load_records(Constants.VAL_TFRECORDS)

    epoch = 1
    with summary_writer.as_default(), tf.contrib.summary.always_record_summaries():
        while True:
            print("Epoch", epoch, time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
            train_on_batch(model, char_list, optimizer, epoch, train_images)
            validate_on_batch(model, char_list, epoch, val_images)
            epoch += 1
def test_with_builtin_data(args, hparams):
    """Interactively display CRNN predictions on a built-in dataset.

    Builds the test set selected by ``args.dataset`` ('coco' or
    'synthesized'), restores the network from ``args.checkpoint`` and, for
    each sample, opens an OpenCV window whose title shows the true and the
    predicted text. Any key advances to the next sample; Esc (key code 27)
    stops early.

    Raises:
        ValueError: if ``args.dataset`` is neither 'synthesized' nor 'coco'.
    """
    transform = Compose([
        resize(size=(hparams.resize_width, hparams.resize_height)),
        # gaussian_noise(gauss_mean, gauss_std),
        ToTensor(),
        Normalize(hparams.normalize_mean, hparams.normalize_std),
    ])

    if args.dataset == 'coco':
        testset = coco_test(root_dir='cropped_COCO',
                            annotation='desc.json',
                            transform=transform)
    elif args.dataset == 'synthesized':
        testset = synthetic_train(height=hparams.syn_height,
                                  width=hparams.syn_width,
                                  num_instances=hparams.syn_num_test,
                                  transform=transform)
    else:
        raise ValueError('Dataset not supported.')

    net = crnn(hid_dim=hparams.hidden_dim, chardict=testset.chardict)
    if cuda:
        net = net.cuda()

    # Restore weights; map_location keeps tensors on CPU while loading so a
    # GPU-saved checkpoint also loads on a CPU-only machine.
    ckpt = torch.load(args.checkpoint,
                      map_location=lambda storage, loc: storage)
    net.load_state_dict(ckpt['model_params'])
    net.eval()

    for idx in range(len(testset)):
        sample = testset[idx]
        img, true_label = sample['image'], sample['text']
        batch_img = torch.unsqueeze(img, 0)  # add a batch axis of size 1
        if cuda:
            batch_img = batch_img.cuda()
        # Inference only: skip building an autograd graph per sample.
        with torch.no_grad():
            pred = net(batch_img).cpu()
        # Greedy decoding: best class per step on dim 2, then swap the first
        # two axes (presumably (T, batch) -> (batch, T); TODO confirm layout).
        pred = torch.argmax(pred, dim=2).permute(1, 0)

        pred_label = net.seq_to_text(pred[0].tolist())

        cv2.imshow("true:{}, pred:{}".format(true_label, pred_label),
                   inv_transform(img))
        key = cv2.waitKey(0)
        cv2.destroyAllWindows()
        if key == 27:
            break

    print('exiting..')
# Example #3 (rating: 0)
def crnn_predict(crnn, img, transformer, decoder='bestPath', normalise=False):
    """
    Run a CRNN on a single image and decode its output with CTC.

    Params
    ------
    crnn: torch.nn
        Neural network architecture
    img: object with a ``copy()`` method
        Input image. It is copied before being handed to ``transformer``, so
        the caller's object is never modified.
    transformer: torchvision.transform
        Image transformer
    decoder: string, 'bestPath' or 'beamSearch'
        CTC decoder method.
    normalise: bool
        If true, divide the softmax probabilities by the module-level
        ``prior`` before decoding (only suitable for best path).

    Returns
    ------
    out: a list of tuples (predicted alphanumeric sequence, confidence level)

    Raises
    ------
    ValueError
        If ``decoder`` is neither 'bestPath' nor 'beamSearch'. This check runs
        before the forward pass.
    """
    # Validate before doing any (comparatively expensive) model work.
    if decoder not in ('bestPath', 'beamSearch'):
        raise ValueError("Invalid decoder method. "
                         "Choose either 'bestPath' or 'beamSearch'")

    import torch  # local import: only needed for no_grad below

    classes = string.ascii_uppercase + string.digits

    image = transformer(img.copy()).to(device)
    image = image.view(1, *image.size())  # prepend a batch axis of size 1

    # Forward pass without an autograd graph, then convert to numpy.
    with torch.no_grad():
        preds_np = crnn(image).detach().cpu().numpy().squeeze()

    # move first column to last (so that we can use CTCDecoder as it is)
    preds_np = np.hstack([preds_np[:, 1:], preds_np[:, [0]]])

    preds_sm = softmax(preds_np, axis=1)

    # normalise is only suitable for best path
    if normalise:
        preds_sm = np.divide(preds_sm, prior)

    if decoder == 'bestPath':
        return utils.ctcBestPath(preds_sm, classes)
    return utils.ctcBeamSearch(preds_sm, classes, None)
# Example #4 (rating: 0)
print("Num of training samples: {}".format(len(train_dl)))
if args.val_annotation_paths:
    val_dl = OCRDataLoader(args.val_annotation_paths, args.val_parse_funcs,
                           args.image_width, args.table_path, args.batch_size)
    print("Num of val samples: {}".format(len(val_dl)))
else:
    # Same call shape as a real loader: ``val_dl()`` yields None, presumably
    # meaning "no validation data" at the fit call (TODO confirm).
    val_dl = lambda: None

localtime = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
print("Start at {}".format(localtime))
os.makedirs("saved_models/{}".format(localtime))

# ModelCheckpoint formats the file name from the epoch's logs. The
# ``val_word_accuracy`` key only exists when validation data is evaluated.
# Referencing it without a validation set makes the callback raise KeyError
# at the end of the first epoch.
if args.val_annotation_paths:
    checkpoint_name = "{epoch:03d}_{word_accuracy:.4f}_{val_word_accuracy:.4f}.h5"
else:
    checkpoint_name = "{epoch:03d}_{word_accuracy:.4f}.h5"
saved_model_path = "saved_models/{}/".format(localtime) + checkpoint_name

model = crnn(train_dl.num_classes)
model.compile(optimizer=keras.optimizers.Adam(0.0001),
              loss=CTCLoss(),
              metrics=[WordAccuracy()])

model.summary()

callbacks = [
    keras.callbacks.ModelCheckpoint(saved_model_path),
    keras.callbacks.TensorBoard(log_dir="logs/{}".format(localtime),
                                histogram_freq=1,
                                profile_batch=0)
]
model.fit(train_dl(),
          epochs=args.epochs,
import numpy as np
from model import crnn
from data.dataset import get_train_test_split
from tensorflow.python.keras.callbacks import TensorBoard
from time import time

import sys

# Usage: <script> FOLDS, where FOLDS is a comma-separated list of integer
# fold indices held out for testing (e.g. "1,2").
if len(sys.argv) < 2:
    sys.exit("usage: {} FOLDS  (comma-separated test fold indices, e.g. 1,2)"
             .format(sys.argv[0]))

fold_arg = sys.argv[1]
test_folds = [int(fold) for fold in fold_arg.split(',')]

model = crnn()

# One TensorBoard log directory per held-out fold selection.
tensorboard = TensorBoard(log_dir="logs/fold_{}".format(fold_arg))

train_input, train_label, test_input, test_label = get_train_test_split(
    test_fold=test_folds, using_CRNN=True)

model.fit(
    [train_input],
    [train_label],
    # Keras documents validation_data as a tuple (x_val, y_val).
    validation_data=([test_input], [test_label]),
    epochs=100,
    batch_size=200,
    callbacks=[tensorboard],
)
# Example #6 (rating: 0)
    # Build a batch from either a directory of images or a single image path.
    if os.path.isdir(args_images):
        # os.listdir returns entries in arbitrary order; sort so that the
        # batch order (and the matching img_paths) is reproducible. Skip
        # sub-directories, which read_image cannot load.
        img_paths = sorted(
            os.path.join(args_images, name)
            for name in os.listdir(args_images)
            if os.path.isfile(os.path.join(args_images, name))
        )
        if not img_paths:
            raise ValueError('No image files found in {}'.format(args_images))
        imgs = tf.stack([read_image(path) for path in img_paths])
    else:
        img_paths = [args_images]
        # Single image: prepend a batch axis of size 1.
        imgs = tf.expand_dims(read_image(args_images), 0)
# Character table: one entry per line; the list index is the class id.
# NOTE(review): str.strip() also removes a space that is itself a table entry;
# consider rstrip('\n') if the table contains " " (confirm table format).
with open(args_table_path, "r") as f:
    inv_table = [char.strip() for char in f]
print('len inv_table', len(inv_table))
model = crnn(len(inv_table))
decoder = Decoder(inv_table)

# The environment variable overrides the original hard-coded location.
checkpoint_dir = os.environ.get(
    'CRNN_CHECKPOINT_DIR',
    r'E:\tsl_file\python_project\all_models\crnn_checkpoints')
checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint,
                                     directory=checkpoint_dir,
                                     checkpoint_name='ckpt',
                                     max_to_keep=5)

latest_checkpoint = manager.latest_checkpoint
if latest_checkpoint is None:
    # Checkpoint.restore(None) silently keeps the freshly initialised
    # (random) weights, so every prediction below would be meaningless.
    raise FileNotFoundError(
        'No checkpoint found in {}'.format(checkpoint_dir))
status = checkpoint.restore(latest_checkpoint)
print('status', status)

y_pred = model(imgs)
def train(hparams, args):
    """Train the CRNN with CTC loss, checkpointing and evaluating periodically.

    When ``args.restore_from_checkpoint`` is a non-empty path, the model
    weights, optimizer state and epoch counter are restored from it. Training
    runs until ``hparams.NUM_epochs`` epochs have completed.

    After every ``hparams.snapshot_interval`` epochs (and after the last one)
    a checkpoint is written to ``ckpt_dir``. After every
    ``hparams.eval_interval`` epochs (and after the last one) the mean edit
    distance between greedy predictions and ground-truth text on the test set
    is printed.
    """
    trainset, testset = get_dataset(hparams, args)
    trainloader = DataLoader(trainset,
                             batch_size=hparams.batch_size,
                             shuffle=True,
                             num_workers=1,
                             collate_fn=train_data_collate)
    testloader = DataLoader(testset,
                            batch_size=hparams.batch_size,
                            shuffle=True,
                            num_workers=1,
                            collate_fn=test_data_collate)

    net = crnn(hid_dim=hparams.hidden_dim, chardict=trainset.chardict)
    net.train()
    if cuda:
        net = net.cuda()

    optimizer = optim.Adam(net.parameters(),
                           lr=hparams.base_lr,
                           weight_decay=0.0001)
    loss_function = CTCLoss()

    # Restore from checkpoint; map_location loads tensors onto the CPU first.
    if args.restore_from_checkpoint:
        ckpt = torch.load(args.restore_from_checkpoint,
                          map_location=lambda storage, loc: storage)
        net.load_state_dict(ckpt['model_params'])
        optimizer.load_state_dict(ckpt['optim_params'])
        epoch = ckpt['epoch']
    else:
        epoch = 0

    # Most recent batch loss as a Python float. It stays None if the train
    # loader yields no batches, so the checkpoint write never hits an
    # undefined name.
    last_loss = None
    while epoch < hparams.NUM_epochs:
        iterator = tqdm(trainloader)
        batch_losses = []
        for iteration, batch in enumerate(iterator):
            optimizer.zero_grad()
            imgs = batch['img']
            labels = batch['seq']
            label_lens = batch['seq_len'].int()
            if cuda:
                imgs = imgs.cuda()

            preds = net(imgs).cpu()

            # Every sequence in the batch is given the full output length
            # preds.size(0); presumably preds is (T, batch, classes).
            pred_lens = torch.IntTensor([preds.size(0)] * preds.size(1))
            loss = loss_function(preds, labels, pred_lens,
                                 label_lens) / hparams.batch_size
            loss.backward()
            nn.utils.clip_grad_norm_(net.parameters(), 10.0)
            optimizer.step()

            # .item() works on the 0-dim loss tensor; .data[0] raises on it.
            last_loss = loss.item()
            batch_losses.append(last_loss)

            description = 'epoch:{}, iteration:{}, current loss:{}, mean loss:{}'.format(
                epoch, iteration, last_loss, np.mean(batch_losses))
            iterator.set_description(description)

        epoch += 1
        if epoch % hparams.snapshot_interval == 0 or epoch == hparams.NUM_epochs:
            ckpt_path = join(ckpt_dir, 'crnn_ckpt_epoch{}'.format(epoch))
            torch.save(
                {
                    'epoch': epoch,
                    'loss': last_loss,
                    'model_params': net.state_dict(),
                    'optim_params': optimizer.state_dict()
                }, ckpt_path)

        if epoch % hparams.eval_interval == 0 or epoch == hparams.NUM_epochs:
            net.eval()
            total_editdist = 0
            num_samples = 0
            for batch_idx, test_batch in enumerate(testloader):
                test_batch_size = test_batch['img'].size(0)
                with torch.no_grad():
                    imgs = test_batch['img']
                    if cuda:
                        imgs = imgs.cuda()

                    preds = net(imgs).cpu()
                # Greedy decoding: best class per step, then batch-first.
                preds = torch.argmax(preds, dim=2).permute(1, 0)

                for i in range(test_batch_size):
                    pred_label = net.seq_to_text(preds[i].tolist())
                    true_label = test_batch['text'][i]
                    total_editdist += distance(true_label, pred_label)
                    # Spot check: print every sample of the first batch.
                    if batch_idx == 0:
                        print('true: {}, pred: {}'.format(
                            true_label, pred_label))
                num_samples += test_batch_size
            # Guard against an empty test set (would divide by zero).
            avg_editdist = float(total_editdist) / max(num_samples, 1)
            print('epoch: {}, average edit distance: {}'.format(
                epoch, avg_editdist))
            net.train()