Example #1
def main():

    if not os.path.exists(opt.output):
        os.makedirs(opt.output)

    converter = utils.strLabelConverter(opt.alphabet)

    collate = dataset.AlignCollate()
    # create train and test data loaders
    train_dataset = dataset.TextLineDataset(text_file=opt.train_list, transform=dataset.ResizeNormalize(100, 32), converter=converter)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batchsize, shuffle=True,
                                               num_workers=opt.num_workers, collate_fn=collate)
    # NOTE: the test set reuses opt.train_list here; point it at a held-out list if one is available
    test_dataset = dataset.TextLineDataset(text_file=opt.train_list, transform=dataset.ResizeNormalize(100, 32), converter=converter)
    test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=opt.batchsize,
                                              num_workers=opt.num_workers, collate_fn=collate)

    criterion = nn.CTCLoss()

    import models.crnn as crnn

    # build the CRNN model and initialise its weights
    crnn = crnn.CRNN(opt.imgH, opt.nc, opt.num_classes, opt.nh)
    crnn.apply(utils.weights_init)
    if opt.pretrained != '':
        print('loading pretrained model from %s' % opt.pretrained)
        crnn.load_state_dict(torch.load(opt.pretrained), strict=False)
    print(crnn)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    crnn = crnn.to(device)
    criterion = criterion.to(device)


    # setup optimizer
    optimizer = optim.Adam(crnn.parameters(), lr=opt.lr)

    for epoch in range(opt.num_epochs):

        loss_avg = 0.0
        train_iter = iter(train_loader)  # one iterator per epoch, not per batch
        i = 0
        while i < len(train_loader):

            time0 = time.time()
            # train on a single batch
            cost = trainBatch(crnn, train_iter, criterion, optimizer, device)
            loss_avg += cost
            i += 1

            if i % opt.interval == 0:
                print('[%d/%d][%d/%d] Loss: %f Time: %f s' %
                      (epoch, opt.num_epochs, i, len(train_loader), loss_avg / opt.interval,
                       time.time() - time0))
                loss_avg = 0.0



        if (epoch + 1) % opt.valinterval == 0:
            val(crnn, test_loader, criterion, converter=converter, device=device, max_iter=100)
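
Example #1 calls a trainBatch helper that is not shown in the snippet. Below is a minimal sketch of what such a helper could look like with nn.CTCLoss; the collate output layout (images, targets, target_lengths) and the CRNN output shape (T, N, num_classes) are assumptions, not the repository's actual interface.

def trainBatch(crnn, train_iter, criterion, optimizer, device):
    # Assumed collate layout: (images, targets, target_lengths); the real
    # AlignCollate in the project may return something different.
    images, targets, target_lengths = next(train_iter)
    images = images.to(device)
    targets = targets.to(device)

    preds = crnn(images)  # assumed shape (T, N, num_classes), raw scores
    input_lengths = torch.full((images.size(0),), preds.size(0), dtype=torch.long)

    # nn.CTCLoss expects log-probabilities over the class dimension
    cost = criterion(preds.log_softmax(2), targets, input_lengths, target_lengths)
    optimizer.zero_grad()
    cost.backward()
    optimizer.step()
    return cost.item()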
Example #2
def main():
    if not os.path.exists(cfg.model):
        os.makedirs(cfg.model)

    # create train dataset
    train_dataset = dataset.TextLineDataset(text_line_file=cfg.train_list,
                                            transform=None)
    sampler = dataset.RandomSequentialSampler(train_dataset, cfg.batch_size)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=int(cfg.num_workers),
        collate_fn=dataset.AlignCollate(img_height=cfg.img_height,
                                        img_width=cfg.img_width))

    # create test dataset
    test_dataset = dataset.TextLineDataset(text_line_file=cfg.eval_list,
                                           transform=dataset.ResizeNormalize(
                                               img_width=cfg.img_width,
                                               img_height=cfg.img_height))
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              shuffle=False,
                                              batch_size=1,
                                              num_workers=int(cfg.num_workers))

    # create crnn/seq2seq/attention network
    encoder = crnn.Encoder(channel_size=3, hidden_size=cfg.hidden_size)
    # attention decoder for predicting variable-length output sequences (up to max_length)
    decoder = crnn.Decoder(hidden_size=cfg.hidden_size,
                           output_size=num_classes,
                           dropout_p=0.1,
                           max_length=cfg.max_width)
    print(encoder)
    print(decoder)
    encoder.apply(utils.weights_init)
    decoder.apply(utils.weights_init)
    if cfg.encoder:
        print('loading pretrained encoder model from %s' % cfg.encoder)
        encoder.load_state_dict(torch.load(cfg.encoder))
    if cfg.decoder:
        print('loading pretrained decoder model from %s' % cfg.decoder)
        decoder.load_state_dict(torch.load(cfg.decoder))

    # create input tensor
    image = torch.FloatTensor(cfg.batch_size, 3, cfg.img_height, cfg.img_width)
    text = torch.LongTensor(cfg.batch_size)

    criterion = torch.nn.NLLLoss()

    assert torch.cuda.is_available(), "Please run 'train.py' on an NVIDIA CUDA device."
    encoder.cuda()
    decoder.cuda()
    image = image.cuda()
    text = text.cuda()
    criterion = criterion.cuda()

    # train crnn
    train(image,
          text,
          encoder,
          decoder,
          criterion,
          train_loader,
          teach_forcing_prob=cfg.teaching_forcing_prob)

    # do evaluation after training
    evaluate(image, text, encoder, decoder, test_loader, max_eval_iter=100)
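
Examples #2 and #3 pass teach_forcing_prob=cfg.teaching_forcing_prob into train(). The self-contained sketch below illustrates the teacher-forcing strategy that parameter controls; the tiny GRU decoder is a stand-in, not the project's crnn.Decoder, and the SOS id of 0 is an assumption.

import random
import torch
import torch.nn as nn

vocab_size, hidden_size, batch_size, max_len = 10, 16, 4, 5
embed = nn.Embedding(vocab_size, hidden_size)
gru = nn.GRU(hidden_size, hidden_size)
proj = nn.Linear(hidden_size, vocab_size)

targets = torch.randint(1, vocab_size, (max_len, batch_size))  # ground-truth token ids
decoder_input = torch.zeros(batch_size, dtype=torch.long)      # assumed SOS id 0
hidden = torch.zeros(1, batch_size, hidden_size)
use_teacher_forcing = random.random() < 0.5                    # stand-in for cfg.teaching_forcing_prob

loss, criterion = 0.0, nn.CrossEntropyLoss()
for t in range(max_len):
    out, hidden = gru(embed(decoder_input).unsqueeze(0), hidden)
    logits = proj(out.squeeze(0))
    loss = loss + criterion(logits, targets[t])
    # feed the ground-truth token (teacher forcing) or the model's own prediction
    decoder_input = targets[t] if use_teacher_forcing else logits.argmax(1)
print(loss.item())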
Example #3
def main():
    if not os.path.exists(cfg.model):
        os.makedirs(cfg.model)

    # path to images
    path_to_images = 'data/sample/images_processed/'

    # create train dataset
    train_dataset = dataset.TextLineDataset(text_line_file=cfg.train_list,
                                            transform=None,
                                            target_transform=get_formula,
                                            path_to_images=path_to_images)
    sampler = dataset.RandomSequentialSampler(train_dataset, cfg.batch_size)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=int(cfg.num_workers),
        collate_fn=dataset.AlignCollate(img_height=cfg.img_height,
                                        img_width=cfg.img_width))

    # create test dataset
    test_dataset = dataset.TextLineDataset(text_line_file=cfg.eval_list,
                                           transform=dataset.ResizeNormalize(
                                               img_width=cfg.img_width,
                                               img_height=cfg.img_height),
                                           target_transform=get_formula,
                                           path_to_images=path_to_images)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              shuffle=False,
                                              batch_size=1,
                                              num_workers=int(cfg.num_workers))

    # create input tensor
    image = torch.FloatTensor(cfg.batch_size, 3, cfg.img_height, cfg.img_width)
    text = torch.LongTensor(cfg.batch_size)

    # # create crnn/seq2seq/attention network
    # encoder = crnn.Encoder(channel_size=3, hidden_size=cfg.hidden_size)
    #
    # # max length for the decoder
    # max_width = cfg.max_width
    # max_width = encoder.get_max_lenght_for_Decoder(image)
    #
    # # for prediction of an indefinite long sequence
    # decoder = crnn.Decoder(hidden_size=cfg.hidden_size, output_size=num_classes, dropout_p=0.1, max_length=max_width)
    # print(encoder)
    # print(decoder)
    # encoder.apply(utils.weights_init)
    # decoder.apply(utils.weights_init)
    #
    #
    # if cfg.encoder:
    #     print('loading pretrained encoder model from %s' % cfg.encoder)
    #     encoder.load_state_dict(torch.load(cfg.encoder))
    # if cfg.decoder:
    #     print('loading pretrained encoder model from %s' % cfg.decoder)
    #     decoder.load_state_dict(torch.load(cfg.decoder))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    INPUT_DIM = 512
    OUTPUT_DIM = num_classes
    ENC_EMB_DIM = 256
    DEC_EMB_DIM = 256
    ENC_HID_DIM = 512
    DEC_HID_DIM = 512
    ENC_DROPOUT = 0.5
    DEC_DROPOUT = 0.5

    cnn = S2S.CNN(channel_size=3)
    attn = S2S.Attention(ENC_HID_DIM, DEC_HID_DIM)
    enc = S2S.Encoder(INPUT_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
    dec = S2S.Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM,
                      DEC_DROPOUT, attn)

    model = S2S.Seq2Seq(cnn, enc, dec, device).to(device)

    # model.apply(S2S.init_weights)
    # model.apply(utils.weights_init)
    print(model)

    print(
        f'The model has {S2S.count_parameters(model):,} trainable parameters\n'
    )

    # criterion = torch.nn.NLLLoss(ignore_index=utils.PAD_TOKEN)
    criterion = torch.nn.CrossEntropyLoss(ignore_index=utils.PAD_TOKEN)

    # assert torch.cuda.is_available(), "Please run \'train.py\' script on nvidia cuda devices."
    if torch.cuda.is_available():
        # encoder.cuda()
        # decoder.cuda()
        image = image.cuda()
        text = text.cuda()
        criterion = criterion.cuda()

    # #     test
    # evaluate(image, text, model, criterion, test_loader, max_eval_iter=100)

    # train the seq2seq model
    train(image,
          text,
          model,
          criterion,
          train_loader,
          teach_forcing_prob=cfg.teaching_forcing_prob)

    # do evaluation after training
    evaluate(image, text, model, criterion, test_loader, max_eval_iter=100)
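
Example #3 scores decoder outputs with CrossEntropyLoss(ignore_index=utils.PAD_TOKEN), so padded positions in the formula targets do not contribute to the loss. A minimal, self-contained illustration of that mechanism follows; PAD_ID is a stand-in for utils.PAD_TOKEN and the shapes mimic flattened (time step, vocabulary) logits.

import torch
import torch.nn as nn

PAD_ID = 0
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)

logits = torch.randn(6, 10)                                # 6 time steps, 10-way vocabulary
targets = torch.tensor([3, 7, 1, PAD_ID, PAD_ID, PAD_ID])  # last three positions are padding
loss = criterion(logits, targets)                          # padding positions are masked out
print(loss.item())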