import logging
import os
import unittest

import torch
import torch.optim as optim
import yaml
from easydict import EasyDict as edict
from torch.autograd import Variable
# Assumed: the warp-ctc binding, whose criterion(preds, targets, pred_sizes,
# target_lengths) call matches the usage below; the original imports are not
# shown. The built-in torch.nn.CTCLoss expects log-probabilities instead.
from warpctc_pytorch import CTCLoss

# CRNN, weights_init, train_set, loadData and val are project-local helpers
# whose definitions/imports were omitted from this snippet.


def main():
    conf_file = "conf/train.yml"
    with open(conf_file, 'r') as f:
        args = edict(yaml.safe_load(f))

    train_root = args.train_root
    test_root = args.test_root
    batch_size = args.batch_size
    max_len = args.max_len
    img_h = args.img_h
    img_w = args.img_w
    n_hidden = args.n_hidden
    n_iter = args.n_iter
    lr = args.lr
    cuda = args.cuda
    val_interval = args.val_interval
    save_interval = args.save_interval
    model_dir = args.model_dir
    debug_level = args.debug_level
    experiment = args.experiment
    n_channel = args.n_channel
    n_class = args.n_class
    beta = args.beta

    image = torch.FloatTensor(batch_size, n_channel, img_h, img_w)
    text = torch.IntTensor(batch_size * max_len)
    length = torch.IntTensor(batch_size)

    logging.getLogger().setLevel(debug_level)
    '''
        50 - critical
        40 - error
        30 - warning
        20 - info
        10 - debug
    '''
    crnn = CRNN(img_h, n_channel, n_class, n_hidden).cuda()
    crnn.apply(weights_init)

    criterion = CTCLoss().cuda()

    optimizer = optim.RMSprop(crnn.parameters(), lr=lr)
    # optimizer = optim.Adam(crnn.parameters(), lr=lr,
    #                    betas=(beta, 0.999))

    trainset = train_set(train_root, batch_size, img_h, img_w, n_class)
    valset = train_set(test_root, batch_size, img_h, img_w, n_class)

    cur_iter = 0
    for ITER in range(n_iter):
        for train_img, train_label, train_lengths, batch_label \
                in iter(trainset):
            for p in crnn.parameters():
                p.requires_grad = True
            crnn.train()

            if train_img is None:
                break
            cur_iter += 1
            loadData(image, train_img)
            loadData(text, train_label)
            loadData(length, train_lengths)
            preds = crnn(train_img.cuda())
            # preds = F.softmax(preds, dim=2)
            # print(preds.shape)
            preds_size = Variable(torch.IntTensor([preds.size(0)] *
                                                  batch_size))
            # print(batch_label, text, length, len(text), len(length), length.sum(),
            #     preds.shape, preds_size.shape)
            cost = criterion(preds, text, preds_size, length)\
                    / batch_size
            crnn.zero_grad()
            cost.backward()
            optimizer.step()
            print("training-iter {} cost {}".format(
                ITER,
                cost.cpu().detach().numpy()[0]))
            if cur_iter % val_interval == 0:
                val(crnn, valset, criterion, n_class)
            if cur_iter % save_interval == 0:
                model_file = os.path.join(model_dir,
                                          "crnn_iter{}.pth".format(ITER))
                print("saving in file {}".format(model_file))
                with open(model_file, 'wb') as f:
                    torch.save(crnn, f)
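
# A minimal conf/train.yml providing the keys read above; the values are taken
# from the hard-coded test_train example below, so treat this as a sketch:
#
#   train_root: data/ocr_dataset_train_400_10/
#   test_root: data/ocr_dataset_train_50_10_val/
#   batch_size: 20
#   max_len: 15
#   img_h: 32
#   img_w: 150
#   n_hidden: 512
#   n_iter: 400
#   lr: 0.00005
#   cuda: true
#   val_interval: 250
#   save_interval: 1000
#   model_dir: models
#   debug_level: 20
#   experiment: experiment
#   n_channel: 3
#   n_class: 11
#   beta: 0.5


if __name__ == "__main__":
    main()
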
# A unittest.TestCase wrapper (assumed from the `self` argument; the original
# class definition is not shown in this snippet) so the method parses at
# module level.
class TestTrain(unittest.TestCase):
    def test_train(self):
        '''
        parameters of train
        '''
        # test_root = "data/ocr_dataset_val"
        # train_root = "data/ocr_dataset"
        train_root = "data/ocr_dataset_train_400_10/"
        test_root = "data/ocr_dataset_train_50_10_val/"
        batch_size = 20
        max_len = 15
        img_h, img_w = 32, 150
        n_hidden = 512
        n_iter = 400
        lr = 0.00005
        cuda = True
        val_interval = 250
        save_interval = 1000
        model_dir = "models"
        debug_level = 20
        experiment = "experiment"
        n_channel = 3
        n_class = 11
        beta = 0.5

        image = torch.FloatTensor(batch_size, n_channel, img_h, img_w)
        text = torch.IntTensor(batch_size * max_len)
        length = torch.IntTensor(batch_size)

        logging.getLogger().setLevel(debug_level)
        '''
            50 - critical
            40 - error
            30 - warning
            20 - info
            10 - debug
        '''
        crnn = CRNN(img_h, n_channel, n_class, n_hidden).cuda()
        crnn.apply(weights_init)

        criterion = CTCLoss().cuda()

        optimizer = optim.RMSprop(crnn.parameters(), lr=lr)
        # optimizer = optim.Adam(crnn.parameters(), lr=lr,
        #                    betas=(beta, 0.999))

        trainset = train_set(train_root, batch_size, img_h, img_w, n_class)
        valset = train_set(test_root, batch_size, img_h, img_w, n_class)

        cur_iter = 0
        for ITER in range(n_iter):
            for train_img, train_label, train_lengths, batch_label in iter(
                    trainset):
                for p in crnn.parameters():
                    p.requires_grad = True
                crnn.train()

                if train_img is None:
                    break
                cur_iter += 1
                loadData(image, train_img)
                loadData(text, train_label)
                loadData(length, train_lengths)
                preds = crnn(train_img.cuda())
                # preds = F.softmax(preds, dim=2)
                # print(preds.shape)
                preds_size = Variable(
                    torch.IntTensor([preds.size(0)] * batch_size))
                # print(batch_label, text, length, len(text), len(length), length.sum(),
                #     preds.shape, preds_size.shape)
                cost = criterion(preds, text, preds_size, length) / batch_size
                crnn.zero_grad()
                cost.backward()
                optimizer.step()
                print("training-iter {} cost {}".format(
                    ITER,
                    cost.cpu().detach().numpy()[0]))
                if cur_iter % val_interval == 0:
                    val(crnn, valset, criterion, n_class)
                if cur_iter % save_interval == 0:
                    model_file = os.path.join(model_dir,
                                              "crnn_iter{}.pth".format(ITER))
                    print("saving in file {}".format(model_file))
                    with open(model_file, 'wb') as f:
                        torch.save(crnn, f)
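
Both training loops above call the criterion as criterion(preds, text, preds_size, length), the warp-ctc calling convention. A minimal equivalent using the built-in torch.nn.CTCLoss is sketched below; the shapes follow the code above, while the log_softmax step and the blank=0 index are assumptions about the label scheme:

import torch
import torch.nn as nn
import torch.nn.functional as F

T, N, C = 26, 20, 11          # time steps, batch_size, n_class
preds = torch.randn(T, N, C)  # raw CRNN output, shape (T, N, C)

text = torch.randint(1, C, (N * 4,), dtype=torch.long)  # flattened targets
length = torch.full((N,), 4, dtype=torch.long)          # per-sample target lengths
preds_size = torch.full((N,), T, dtype=torch.long)      # every sequence spans T steps

criterion = nn.CTCLoss(blank=0, reduction='sum', zero_infinity=True)
log_probs = F.log_softmax(preds, dim=2)  # nn.CTCLoss expects log-probabilities
cost = criterion(log_probs, text, preds_size, length) / N  # matches the /batch_size above
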
Example #3

# (fragment: the snippet begins partway through a validation routine that
# prints predictions and returns the batch accuracy)
            for j, k in zip(predicted_label, label_tmp):
                print(k.lower())
                print(j)

    accuracy = n_correct / float(10 * option.batch_size)
    print('loss: %.4f accuracy: %.4f' % (total_loss / 10, accuracy))
    crnn.train()

    return accuracy
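
# The loop below relies on a `converter` with an encode() method that maps a
# batch of strings to a flat index sequence plus per-string lengths. A minimal
# sketch of such a converter (modelled on the common CRNN string-to-index
# scheme; the original definition is not shown, and the alphabet is an
# assumption):
class LabelConverter:
    def __init__(self, alphabet):
        # index 0 is reserved for the CTC blank, so characters start at 1
        self.char2idx = {c: i + 1 for i, c in enumerate(alphabet)}

    def encode(self, labels):
        # labels: iterable of strings -> (flat target indices, per-string lengths)
        indices = [self.char2idx[c] for s in labels for c in s]
        lengths = [len(s) for s in labels]
        return indices, lengths

# e.g. converter = LabelConverter('0123456789')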


for i in range(option.nepoch):
    for j, (input, label) in enumerate(trainset_dataloader):
        if j == len(trainset_dataloader) - 1:
            continue
        crnn.zero_grad()
        label, length = converter.encode(label)
        input = input.cuda()
        predicted_label = crnn(input)
        predicted_length = [predicted_label.size(0)] * option.batch_size
        label = torch.tensor(label, dtype=torch.long)
        label = label.cuda()
        predicted_length = torch.tensor(predicted_length, dtype=torch.long)
        length = torch.tensor(length, dtype=torch.long)
        loss = loss_function(predicted_label, label, predicted_length, length)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()  # .item() avoids retaining the graph
        if j % print_every == 0:
            # the argument tuple was truncated in the original; reconstructed
            # to match the format string (assumed epoch/batch indices and
            # running mean over print_every batches)
            print('[%d / %d] [%d / %d] loss: %.4f' %
                  (i, option.nepoch, j, len(trainset_dataloader),
                   total_loss / print_every))
            total_loss = 0
Example #4

    if accuracy > best_acc:
        best_acc = accuracy
        # save the best-performing model seen so far
        print("SAVING MODEL")
        torch.save(model.state_dict(), "trained_models/best_model.pt")

    test_accuracy.append(accuracy)

    running_loss = 0.0
    for i, (data, true_labels) in enumerate(training_dataloader):

        data = data.float()
        true_labels = true_labels.long()

        # set all gradients to zero
        model.zero_grad()

        # Here we get the data from all layers, and the corresponding timesteps
        output_conv, output_lstm1, output_lstm2, predictions = model.out(data)
        loss = loss_function(predictions, true_labels)

        # Optimization part
        loss.backward()

        # Gradient Clipping to avoid exploding gradients
        #nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        running_loss += loss.item()
        total_loss += loss.item()
        if i % 20 == 0:  # print every 20 mini-batches