Example #1
0
def main():
    """Train a CRNN text-recognition model with CTC loss.

    Relies on module-level globals defined elsewhere in this file:
    ``opt`` (parsed CLI options), ``utils``, ``dataset``, ``nn``, ``optim``,
    ``time`` and the helpers ``trainBatch`` and ``val``.
    """
    if not os.path.exists(opt.output):
        os.makedirs(opt.output)

    converter = utils.strLabelConverter(opt.alphabet)

    collate = dataset.AlignCollate()
    train_dataset = dataset.TextLineDataset(text_file=opt.train_list, transform=dataset.ResizeNormalize(100, 32), converter=converter)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batchsize, shuffle=True,
                                               num_workers=opt.num_workers, collate_fn=collate)
    # NOTE(review): the test set is built from opt.train_list as well —
    # presumably this should be a separate eval list; confirm against `opt`.
    test_dataset = dataset.TextLineDataset(text_file=opt.train_list, transform=dataset.ResizeNormalize(100, 32), converter=converter)
    test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=opt.batchsize,
                                              num_workers=opt.num_workers, collate_fn=collate)

    criterion = nn.CTCLoss()

    import models.crnn as crnn

    crnn = crnn.CRNN(opt.imgH, opt.nc, opt.num_classes, opt.nh)
    crnn.apply(utils.weights_init)
    if opt.pretrained != '':
        print('loading pretrained model from %s' % opt.pretrained)
        crnn.load_state_dict(torch.load(opt.pretrained), strict=False)
    print(crnn)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    crnn = crnn.to(device)
    criterion = criterion.to(device)

    # setup optimizer
    optimizer = optim.Adam(crnn.parameters(), lr=opt.lr)

    for epoch in range(opt.num_epochs):

        loss_avg = 0.0
        # BUG FIX: the iterator was previously re-created inside the batch
        # loop, which restarted the DataLoader on every iteration (always
        # drawing a fresh first batch and re-spawning workers). Create it
        # once per epoch so the whole dataset is actually traversed.
        train_iter = iter(train_loader)
        for i in range(1, len(train_loader) + 1):

            time0 = time.time()
            # train on one batch
            cost = trainBatch(crnn, train_iter, criterion, optimizer, device)
            loss_avg += cost

            if i % opt.interval == 0:
                # Report the *mean* loss over the last `interval` batches
                # (the original printed the accumulated sum labelled "Loss").
                print('[%d/%d][%d/%d] Loss: %f Time: %f s' %
                      (epoch, opt.num_epochs, i, len(train_loader),
                       loss_avg / opt.interval,
                       time.time() - time0))
                loss_avg = 0.0

        if (epoch + 1) % opt.valinterval == 0:
            val(crnn, test_loader, criterion, converter=converter, device=device, max_iter=100)
Example #2
0
def main(cfg):
    """Run single-image CRNN+CTC inference and print the decoded text.

    Args:
        cfg: configuration object with ``model_path`` (checkpoint file)
            and ``image_path`` (input image file) attributes.

    Raises:
        FileNotFoundError: if ``cfg.image_path`` does not exist.  (The
            original silently skipped image loading in that case and then
            crashed with ``NameError`` on the unbound ``image``.)
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    alphabet = utils.get_alphabet('./data/char_std_5990.txt')
    num_classes = len(alphabet)
    converter = convert.StringLabelConverter(alphabet)
    transformer = dataset.ResizeNormalize(32, 280)

    model = crnn_ctc.CRNN(
        in_channels=3,
        hidden_size=256,
        output_size=num_classes,
    )
    net = model.to(device)

    if os.path.exists(cfg.model_path):
        print('Loading model from {0}'.format(cfg.model_path))
        model.load_state_dict(torch.load(cfg.model_path))
        print('Done!')

    # BUG FIX: fail fast on a missing input image instead of falling
    # through to a NameError below.
    if not os.path.exists(cfg.image_path):
        raise FileNotFoundError('image not found: {0}'.format(cfg.image_path))
    image = Image.open(cfg.image_path).convert('RGB')
    image = transformer(image)
    image = torch.unsqueeze(image, 0)  # add a batch dimension
    image = image.to(device)

    net.eval()
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        output = net(image)
    preds = output.max(2)[1]  # argmax over the class dimension
    # One sequence-length entry per batch item, as the CTC decoder expects.
    preds_len = torch.IntTensor([preds.size(0)] * int(preds.size(1)))
    results = converter.decode(preds, preds_len)
    print('Result: {0}'.format(results))
Example #3
0
if __name__ == '__main__':

    # Load the recognition alphabet: one character per line, stripped of
    # trailing whitespace and joined into a single lookup string.
    with open('./data/char_std_5990.txt', encoding="UTF-8") as f:
        data = f.readlines()
        alphabet = [x.rstrip() for x in data]
        alphabet = ''.join(alphabet)

    # Converter that maps between text strings and label-index sequences.
    converter = utils.ConvertBetweenStringAndLabel(alphabet)

    # len(alphabet) + SOS_TOKEN + EOS_TOKEN
    num_classes = len(alphabet) + 2

    # Resize/normalize transform applied to each input image crop.
    transformer = dataset.ResizeNormalize(img_width=args.img_width,
                                          img_height=args.img_height)

    # Text-detection network (CRAFT); weights are loaded just below.
    detect_net = CRAFT()  # initialize

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        detect_net.load_state_dict(
            copyStateDict(torch.load(args.trained_model)))
    else:
        # CPU-only run: remap tensors that were saved on GPU onto the CPU.
        detect_net.load_state_dict(
            copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        detect_net = detect_net.cuda()
        detect_net = torch.nn.DataParallel(detect_net)
def main(cfg):
    """Prepare data loaders, build the CRNN model, and launch training.

    Args:
        cfg: configuration namespace (paths, hyper-parameters, seed, gpu id).
    """
    # Make sure the checkpoint directory exists before training starts.
    if not os.path.exists(cfg.model_path):
        os.makedirs(cfg.model_path)

    # Seed the RNGs for reproducibility; select GPU when requested and present.
    torch.manual_seed(cfg.seed)
    use_cuda = cfg.gpu_id is not None and torch.cuda.is_available()
    if use_cuda:
        device = torch.device('cuda:{0}'.format(cfg.gpu_id))
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
    else:
        device = torch.device('cpu')

    alphabet = utils.get_alphabet(cfg.alpha_path)
    num_classes = len(alphabet)

    # Training split: shuffled loader over the train list.
    train_set = dataset.TextLineDataset(
        cfg.image_root,
        cfg.train_list,
        transform=dataset.ResizeNormalize(
            img_width=cfg.img_width,
            img_height=cfg.img_height,
        ),
    )
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        shuffle=True,
    )

    # Validation split: same transform, separate list.
    valid_set = dataset.TextLineDataset(
        cfg.image_root,
        cfg.valid_list,
        transform=dataset.ResizeNormalize(
            img_width=cfg.img_width,
            img_height=cfg.img_height,
        ),
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_set,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        shuffle=True,
    )

    model = crnn_ctc.CRNN(
        in_channels=3,
        hidden_size=cfg.hidden_size,
        output_size=num_classes,
    )
    # Only initialise weights from scratch when not resuming from pretrained.
    if not cfg.pretrained:
        model.apply(utils.weights_init)

    model = model.to(device)

    train(model, train_loader, valid_loader, device, cfg)
def main():
    """Train and evaluate a CRNN encoder / attention decoder on text lines.

    Relies on module-level globals: ``cfg`` (configuration), ``dataset``,
    ``crnn``, ``utils``, ``num_classes``, and the helpers ``train`` and
    ``evaluate`` defined elsewhere in this file.
    """
    if not os.path.exists(cfg.model):
        os.makedirs(cfg.model)

    # create train dataset (raw images; AlignCollate resizes per batch)
    train_dataset = dataset.TextLineDataset(text_line_file=cfg.train_list,
                                            transform=None)
    sampler = dataset.RandomSequentialSampler(train_dataset, cfg.batch_size)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=int(cfg.num_workers),
        collate_fn=dataset.AlignCollate(img_height=cfg.img_height,
                                        img_width=cfg.img_width))

    # create test dataset (fixed-size transform, batch size 1)
    test_dataset = dataset.TextLineDataset(text_line_file=cfg.eval_list,
                                           transform=dataset.ResizeNormalize(
                                               img_width=cfg.img_width,
                                               img_height=cfg.img_height))
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              shuffle=False,
                                              batch_size=1,
                                              num_workers=int(cfg.num_workers))

    # create crnn/seq2seq/attention network
    encoder = crnn.Encoder(channel_size=3, hidden_size=cfg.hidden_size)
    # for prediction of an indefinite long sequence
    # NOTE(review): num_classes is read from module scope, not defined here —
    # verify it is set before this function runs.
    decoder = crnn.Decoder(hidden_size=cfg.hidden_size,
                           output_size=num_classes,
                           dropout_p=0.1,
                           max_length=cfg.max_width)
    print(encoder)
    print(decoder)
    encoder.apply(utils.weights_init)
    decoder.apply(utils.weights_init)
    if cfg.encoder:
        print('loading pretrained encoder model from %s' % cfg.encoder)
        encoder.load_state_dict(torch.load(cfg.encoder))
    if cfg.decoder:
        # BUG FIX: this message previously said "encoder" while loading the
        # decoder checkpoint.
        print('loading pretrained decoder model from %s' % cfg.decoder)
        decoder.load_state_dict(torch.load(cfg.decoder))

    # create reusable input tensors (moved to GPU below)
    image = torch.FloatTensor(cfg.batch_size, 3, cfg.img_height, cfg.img_width)
    text = torch.LongTensor(cfg.batch_size)

    criterion = torch.nn.NLLLoss()

    assert torch.cuda.is_available(
    ), "Please run \'train.py\' script on nvidia cuda devices."
    encoder.cuda()
    decoder.cuda()
    image = image.cuda()
    text = text.cuda()
    criterion = criterion.cuda()

    # train crnn
    train(image,
          text,
          encoder,
          decoder,
          criterion,
          train_loader,
          teach_forcing_prob=cfg.teaching_forcing_prob)

    # do evaluation after training
    evaluate(image, text, encoder, decoder, test_loader, max_eval_iter=100)
Example #6
0
def main():
    """Build the formula-image data pipeline and Seq2Seq model, train, then evaluate.

    Uses module-level globals: ``cfg``, ``dataset``, ``S2S``, ``utils``,
    ``get_formula``, ``num_classes``, ``train`` and ``evaluate``.
    """
    # Ensure the checkpoint directory exists.
    if not os.path.exists(cfg.model):
        os.makedirs(cfg.model)

    # Location of the pre-processed formula images.
    path_to_images = 'data/sample/images_processed/'

    # ----- data -----------------------------------------------------------
    # Training split: raw images, per-batch alignment via AlignCollate.
    train_set = dataset.TextLineDataset(text_line_file=cfg.train_list,
                                        transform=None,
                                        target_transform=get_formula,
                                        path_to_images=path_to_images)
    train_sampler = dataset.RandomSequentialSampler(train_set, cfg.batch_size)
    align = dataset.AlignCollate(img_height=cfg.img_height,
                                 img_width=cfg.img_width)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=cfg.batch_size,
        shuffle=False,
        sampler=train_sampler,
        num_workers=int(cfg.num_workers),
        collate_fn=align)

    # Evaluation split: fixed-size transform, batch size 1.
    test_set = dataset.TextLineDataset(text_line_file=cfg.eval_list,
                                       transform=dataset.ResizeNormalize(
                                           img_width=cfg.img_width,
                                           img_height=cfg.img_height),
                                       target_transform=get_formula,
                                       path_to_images=path_to_images)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              shuffle=False,
                                              batch_size=1,
                                              num_workers=int(cfg.num_workers))

    # Reusable input buffers (moved to GPU below when available).
    image = torch.FloatTensor(cfg.batch_size, 3, cfg.img_height, cfg.img_width)
    text = torch.LongTensor(cfg.batch_size)

    # ----- model ----------------------------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    INPUT_DIM = 512
    # NOTE(review): num_classes is read from module scope, not defined here.
    OUTPUT_DIM = num_classes
    ENC_EMB_DIM = 256
    DEC_EMB_DIM = 256
    ENC_HID_DIM = 512
    DEC_HID_DIM = 512
    ENC_DROPOUT = 0.5
    DEC_DROPOUT = 0.5

    cnn = S2S.CNN(channel_size=3)
    attn = S2S.Attention(ENC_HID_DIM, DEC_HID_DIM)
    enc = S2S.Encoder(INPUT_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
    dec = S2S.Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM,
                      DEC_DROPOUT, attn)
    model = S2S.Seq2Seq(cnn, enc, dec, device).to(device)

    print(model)
    print(
        f'The model has {S2S.count_parameters(model):,} trainable parameters\n'
    )

    # Ignore padding positions when computing the loss.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=utils.PAD_TOKEN)

    if torch.cuda.is_available():
        image = image.cuda()
        text = text.cuda()
        criterion = criterion.cuda()

    # ----- train & evaluate ----------------------------------------------
    train(image,
          text,
          model,
          criterion,
          train_loader,
          teach_forcing_prob=cfg.teaching_forcing_prob)

    # do evaluation after training
    evaluate(image, text, model, criterion, test_loader, max_eval_iter=100)