def main():
    if not os.path.exists(opt.output):
        os.makedirs(opt.output)

    converter = utils.strLabelConverter(opt.alphabet)
    collate = dataset.AlignCollate()

    train_dataset = dataset.TextLineDataset(text_file=opt.train_list,
                                            transform=dataset.ResizeNormalize(100, 32),
                                            converter=converter)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batchsize, shuffle=True,
                                               num_workers=opt.num_workers, collate_fn=collate)

    # NOTE: the test set is built from opt.train_list here; point it at a held-out
    # list if one is available.
    test_dataset = dataset.TextLineDataset(text_file=opt.train_list,
                                           transform=dataset.ResizeNormalize(100, 32),
                                           converter=converter)
    test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=opt.batchsize,
                                              num_workers=opt.num_workers, collate_fn=collate)

    criterion = nn.CTCLoss()

    import models.crnn as crnn
    crnn = crnn.CRNN(opt.imgH, opt.nc, opt.num_classes, opt.nh)
    crnn.apply(utils.weights_init)
    if opt.pretrained != '':
        print('loading pretrained model from %s' % opt.pretrained)
        crnn.load_state_dict(torch.load(opt.pretrained), strict=False)
    print(crnn)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    crnn = crnn.to(device)
    criterion = criterion.to(device)

    # setup optimizer
    optimizer = optim.Adam(crnn.parameters(), lr=opt.lr)

    for epoch in range(opt.num_epochs):
        loss_avg = 0.0
        i = 0
        # create the iterator once per epoch; re-creating it inside the loop
        # would keep returning the first batch
        train_iter = iter(train_loader)
        while i < len(train_loader):
            time0 = time.time()
            # train on one batch at a time
            cost = trainBatch(crnn, train_iter, criterion, optimizer, device)
            loss_avg += cost
            i += 1
            if i % opt.interval == 0:
                print('[%d/%d][%d/%d] Loss: %f Time: %f s' %
                      (epoch, opt.num_epochs, i, len(train_loader),
                       loss_avg / opt.interval, time.time() - time0))
                loss_avg = 0.0

        if (epoch + 1) % opt.valinterval == 0:
            val(crnn, test_loader, criterion, converter=converter, device=device, max_iter=100)
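# A minimal sketch of the `trainBatch` helper called above. It is not defined in
# this file, so the batch layout (images, encoded targets, and target lengths as
# produced by utils.strLabelConverter / AlignCollate) is an assumption.
import torch

def trainBatch(crnn, train_iter, criterion, optimizer, device):
    images, targets, target_lengths = next(train_iter)
    images = images.to(device)

    crnn.train()
    optimizer.zero_grad()

    # CRNN output: (seq_len, batch, num_classes) log-probabilities for CTC
    preds = crnn(images).log_softmax(2)
    seq_len, batch_size, _ = preds.size()
    pred_lengths = torch.full((batch_size,), seq_len, dtype=torch.long)

    # nn.CTCLoss expects (T, N, C) predictions plus per-sample lengths
    cost = criterion(preds, targets, pred_lengths, target_lengths)
    cost.backward()
    optimizer.step()
    return cost.item()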
def main():
    if not os.path.exists(cfg.model):
        os.makedirs(cfg.model)

    # create train dataset
    train_dataset = dataset.TextLineDataset(text_line_file=cfg.train_list, transform=None)
    sampler = dataset.RandomSequentialSampler(train_dataset, cfg.batch_size)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=cfg.batch_size, shuffle=False, sampler=sampler,
        num_workers=int(cfg.num_workers),
        collate_fn=dataset.AlignCollate(img_height=cfg.img_height, img_width=cfg.img_width))

    # create test dataset
    test_dataset = dataset.TextLineDataset(text_line_file=cfg.eval_list,
                                           transform=dataset.ResizeNormalize(img_width=cfg.img_width,
                                                                             img_height=cfg.img_height))
    test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=1,
                                              num_workers=int(cfg.num_workers))

    # create crnn/seq2seq/attention network
    encoder = crnn.Encoder(channel_size=3, hidden_size=cfg.hidden_size)
    # decoder for prediction of an indefinitely long sequence
    decoder = crnn.Decoder(hidden_size=cfg.hidden_size, output_size=num_classes,
                           dropout_p=0.1, max_length=cfg.max_width)
    print(encoder)
    print(decoder)
    encoder.apply(utils.weights_init)
    decoder.apply(utils.weights_init)
    if cfg.encoder:
        print('loading pretrained encoder model from %s' % cfg.encoder)
        encoder.load_state_dict(torch.load(cfg.encoder))
    if cfg.decoder:
        print('loading pretrained decoder model from %s' % cfg.decoder)
        decoder.load_state_dict(torch.load(cfg.decoder))

    # create input tensors
    image = torch.FloatTensor(cfg.batch_size, 3, cfg.img_height, cfg.img_width)
    text = torch.LongTensor(cfg.batch_size)

    criterion = torch.nn.NLLLoss()

    assert torch.cuda.is_available(), "Please run 'train.py' script on nvidia cuda devices."
    encoder.cuda()
    decoder.cuda()
    image = image.cuda()
    text = text.cuda()
    criterion = criterion.cuda()

    # train crnn
    train(image, text, encoder, decoder, criterion, train_loader,
          teach_forcing_prob=cfg.teaching_forcing_prob)

    # do evaluation after training
    evaluate(image, text, encoder, decoder, test_loader, max_eval_iter=100)
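# A minimal sketch of one teacher-forced decoding step as it might appear inside
# the `train` loop above. The crnn.Decoder forward signature, decoder.initHidden,
# and utils.SOS_TOKEN are assumptions, following the usual attention-seq2seq recipe.
import random
import torch

def decode_batch(encoder, decoder, criterion, image, target, teach_forcing_prob=0.5):
    encoder_outputs = encoder(image)                  # (seq_len, batch, hidden)
    batch_size = image.size(0)
    decoder_input = torch.full((batch_size,), utils.SOS_TOKEN,
                               dtype=torch.long, device=image.device)
    decoder_hidden = decoder.initHidden(batch_size).to(image.device)

    loss = 0.0
    use_teacher_forcing = random.random() < teach_forcing_prob
    for t in range(target.size(1)):
        decoder_output, decoder_hidden, _ = decoder(decoder_input, decoder_hidden, encoder_outputs)
        loss += criterion(decoder_output, target[:, t])
        if use_teacher_forcing:
            decoder_input = target[:, t]              # feed the ground-truth symbol
        else:
            decoder_input = decoder_output.argmax(1)  # feed the model's own prediction
    return loss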
def main():
    if not os.path.exists(cfg.model):
        os.makedirs(cfg.model)

    # path to images
    path_to_images = 'data/sample/images_processed/'

    # create train dataset
    train_dataset = dataset.TextLineDataset(text_line_file=cfg.train_list, transform=None,
                                            target_transform=get_formula,
                                            path_to_images=path_to_images)
    sampler = dataset.RandomSequentialSampler(train_dataset, cfg.batch_size)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=cfg.batch_size, shuffle=False, sampler=sampler,
        num_workers=int(cfg.num_workers),
        collate_fn=dataset.AlignCollate(img_height=cfg.img_height, img_width=cfg.img_width))

    # create test dataset
    test_dataset = dataset.TextLineDataset(text_line_file=cfg.eval_list,
                                           transform=dataset.ResizeNormalize(img_width=cfg.img_width,
                                                                             img_height=cfg.img_height),
                                           target_transform=get_formula,
                                           path_to_images=path_to_images)
    test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=1,
                                              num_workers=int(cfg.num_workers))

    # create input tensors
    image = torch.FloatTensor(cfg.batch_size, 3, cfg.img_height, cfg.img_width)
    text = torch.LongTensor(cfg.batch_size)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # model hyper-parameters
    INPUT_DIM = 512
    OUTPUT_DIM = num_classes
    ENC_EMB_DIM = 256
    DEC_EMB_DIM = 256
    ENC_HID_DIM = 512
    DEC_HID_DIM = 512
    ENC_DROPOUT = 0.5
    DEC_DROPOUT = 0.5

    # create CNN + attention-based seq2seq network
    cnn = S2S.CNN(channel_size=3)
    attn = S2S.Attention(ENC_HID_DIM, DEC_HID_DIM)
    enc = S2S.Encoder(INPUT_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
    dec = S2S.Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)
    model = S2S.Seq2Seq(cnn, enc, dec, device).to(device)
    print(model)
    print(f'The model has {S2S.count_parameters(model):,} trainable parameters\n')

    # padded target positions are ignored by the loss
    criterion = torch.nn.CrossEntropyLoss(ignore_index=utils.PAD_TOKEN)

    if torch.cuda.is_available():
        image = image.cuda()
        text = text.cuda()
        criterion = criterion.cuda()

    # train the seq2seq model
    train(image, text, model, criterion, train_loader,
          teach_forcing_prob=cfg.teaching_forcing_prob)

    # do evaluation after training
    evaluate(image, text, model, criterion, test_loader, max_eval_iter=100)
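# A minimal sketch of how one batch might be pushed through the Seq2Seq model built
# above. The forward signature S2S.Seq2Seq(images, targets, teacher_forcing_ratio)
# returning (trg_len, batch, num_classes) logits is an assumption.
import torch

def train_batch(model, criterion, optimizer, images, targets, teach_forcing_prob=0.5):
    model.train()
    optimizer.zero_grad()

    # targets: (batch, trg_len) -> (trg_len, batch) to match the decoding loop
    targets = targets.t().contiguous()
    output = model(images, targets, teacher_forcing_ratio=teach_forcing_prob)

    # skip the initial <sos> position and flatten for CrossEntropyLoss;
    # PAD_TOKEN positions are ignored through ignore_index set above
    logits = output[1:].reshape(-1, output.size(-1))
    target_flat = targets[1:].reshape(-1)

    loss = criterion(logits, target_flat)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    return loss.item()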