def test(model, net):
    """Run a full pass over the validation split and score it with CTC loss.

    Args:
        model: the underlying module; put into eval mode here.
        net: the callable used for the forward pass (presumably a
            DataParallel wrapper around ``model`` — confirm with caller).

    Returns:
        Tuple of floats ``(mean_loss, mean_wer, mean_cer)`` averaged over
        the whole validation set.

    Relies on module-level globals: ``opt``, ``MyDataset``,
    ``dataset2dataloader``, ``ctc_decode``. Runs under ``torch.no_grad()``
    so no graph is retained.
    """
    with torch.no_grad():
        dataset = MyDataset(opt.video_path, opt.anno_path, opt.val_list,
                            opt.vid_padding, opt.txt_padding, 'test')
        print('num_test_data:{}'.format(len(dataset.data)))
        model.eval()
        loader = dataset2dataloader(dataset, shuffle=False)
        loss_list = []
        wer = []
        cer = []
        crit = nn.CTCLoss()
        tic = time.time()
        # Horizontal rule for the prediction table; original rebuilt it with
        # a redundant ''.join(101 * '-') every print.
        rule = '-' * 101
        # Renamed the loop variable: the original shadowed the builtin `input`.
        for i_iter, batch in enumerate(loader):
            vid = batch.get('vid').cuda()
            txt = batch.get('txt').cuda()
            vid_len = batch.get('vid_len').cuda()
            txt_len = batch.get('txt_len').cuda()
            y = net(vid)
            # CTCLoss expects (T, N, C) log-probs; y is assumed (N, T, C),
            # hence the transpose — TODO confirm against the model's output.
            loss = crit(y.transpose(0, 1).log_softmax(-1), txt,
                        vid_len.view(-1),
                        txt_len.view(-1)).detach().cpu().numpy()
            loss_list.append(loss)
            pred_txt = ctc_decode(y)
            # start=1 presumably skips the CTC blank/pad index when
            # rendering ground truth — verify against MyDataset.arr2txt.
            truth_txt = [MyDataset.arr2txt(txt[_], start=1)
                         for _ in range(txt.size(0))]
            wer.extend(MyDataset.wer(pred_txt, truth_txt))
            cer.extend(MyDataset.cer(pred_txt, truth_txt))
            if i_iter % opt.display == 0:
                v = 1.0 * (time.time() - tic) / (i_iter + 1)  # sec per iter
                eta = v * (len(loader) - i_iter) / 3600.0     # hours left
                print(rule)
                print('{:<50}|{:>50}'.format('predict', 'truth'))
                print(rule)
                # Show at most 10 (prediction, truth) pairs from this batch.
                for predict, truth in list(zip(pred_txt, truth_txt))[:10]:
                    print('{:<50}|{:>50}'.format(predict, truth))
                print(rule)
                print('test_iter={},eta={},wer={},cer={}'.format(
                    i_iter, eta, np.array(wer).mean(), np.array(cer).mean()))
                print(rule)
        return (np.array(loss_list).mean(), np.array(wer).mean(),
                np.array(cer).mean())
def train(model, net):
    """Train the lipreading model with CTC loss.

    Args:
        model: module whose parameters are optimized and checkpointed.
        net: the callable used for the forward pass (presumably a
            DataParallel wrapper around ``model`` — confirm with caller).

    Side effects:
        - Logs train loss/WER (and periodic val metrics) to the global
          TensorBoard ``writer``.
        - Every ``opt.test_step`` iterations runs ``test()`` and saves a
          checkpoint named after the current loss/WER/CER.
        - If ``opt.is_optimize`` is falsy, gradients are computed but the
          optimizer never steps, and the process exits after the first
          evaluation (evaluation-only mode).

    Relies on module-level globals: ``opt``, ``MyDataset``,
    ``dataset2dataloader``, ``ctc_decode``, ``writer``, ``show_lr``,
    ``test``.
    """
    dataset = MyDataset(opt.video_path, opt.anno_path, opt.train_list,
                        opt.vid_padding, opt.txt_padding, 'train')
    loader = dataset2dataloader(dataset)
    optimizer = optim.Adam(model.parameters(), lr=opt.base_lr,
                           weight_decay=0., amsgrad=True)
    print('num_train_data:{}'.format(len(dataset.data)))
    crit = nn.CTCLoss()
    tic = time.time()
    train_wer = []
    # Horizontal rule for the prediction table; original rebuilt it with a
    # redundant ''.join(101 * '-') on every print.
    rule = '-' * 101
    for epoch in range(opt.max_epoch):
        # Renamed loop variable: the original shadowed the builtin `input`.
        for i_iter, batch in enumerate(loader):
            model.train()
            vid = batch.get('vid').cuda()
            txt = batch.get('txt').cuda()
            vid_len = batch.get('vid_len').cuda()
            txt_len = batch.get('txt_len').cuda()
            optimizer.zero_grad()
            y = net(vid)
            # CTCLoss expects (T, N, C) log-probs; y is assumed (N, T, C),
            # hence the transpose — TODO confirm against the model's output.
            loss = crit(y.transpose(0, 1).log_softmax(-1), txt,
                        vid_len.view(-1), txt_len.view(-1))
            loss.backward()
            if opt.is_optimize:
                optimizer.step()
            tot_iter = i_iter + epoch * len(loader)
            pred_txt = ctc_decode(y)
            # start=1 presumably skips the CTC blank/pad index when
            # rendering ground truth — verify against MyDataset.arr2txt.
            truth_txt = [MyDataset.arr2txt(txt[_], start=1)
                         for _ in range(txt.size(0))]
            train_wer.extend(MyDataset.wer(pred_txt, truth_txt))
            if tot_iter % opt.display == 0:
                v = 1.0 * (time.time() - tic) / (tot_iter + 1)  # sec/iter
                # NOTE(review): counts only the remaining iters of the
                # current epoch, so the printed ETA ignores later epochs.
                eta = (len(loader) - i_iter) * v / 3600.0
                writer.add_scalar('train loss', loss, tot_iter)
                writer.add_scalar('train wer', np.array(train_wer).mean(),
                                  tot_iter)
                print(rule)
                print('{:<50}|{:>50}'.format('predict', 'truth'))
                print(rule)
                for predict, truth in list(zip(pred_txt, truth_txt))[:3]:
                    print('{:<50}|{:>50}'.format(predict, truth))
                print(rule)
                print('epoch={},tot_iter={},eta={},loss={},train_wer={}'
                      .format(epoch, tot_iter, eta, loss,
                              np.array(train_wer).mean()))
                print(rule)
            if tot_iter % opt.test_step == 0:
                (loss, wer, cer) = test(model, net)
                print('i_iter={},lr={},loss={},wer={},cer={}'
                      .format(tot_iter, show_lr(optimizer), loss, wer, cer))
                writer.add_scalar('val loss', loss, tot_iter)
                writer.add_scalar('wer', wer, tot_iter)
                writer.add_scalar('cer', cer, tot_iter)
                savename = '{}_loss_{}_wer_{}_cer_{}.pt'.format(
                    opt.save_prefix, loss, wer, cer)
                (path, name) = os.path.split(savename)
                # exist_ok avoids the original's check-then-create race;
                # the `path` guard avoids makedirs('') raising when
                # save_prefix has no directory component.
                if path:
                    os.makedirs(path, exist_ok=True)
                torch.save(model.state_dict(), savename)
                if not opt.is_optimize:
                    exit()
def eval(model, net):
    """Evaluate on the validation split: WER, CER, word- and sentence-level
    accuracy (WLA / SLA).

    NOTE(review): the function name shadows the builtin ``eval``; kept
    unchanged because it is the public interface callers use.

    Args:
        model: module put into eval mode.
        net: the callable used for the forward pass (presumably a
            DataParallel wrapper around ``model`` — confirm with caller).

    Returns:
        Tuple ``(mean_wer, mean_cer, mean_wla, sla)`` where ``sla`` is the
        fraction of exactly-correct sentences over the whole split.

    Side effects: logs running wer/cer/wla (and sla under the tag 'bla',
    matching the original) to the global TensorBoard ``writer`` every
    ``opt.display`` iterations.
    """
    with torch.no_grad():
        dataset = MyDataset(opt.video_path, opt.anno_path, opt.val_list,
                            opt.vid_padding, opt.txt_padding, 'test')
        print('num_test_data:{}'.format(len(dataset.data)))
        model.eval()
        loader = dataset2dataloader(dataset, shuffle=False)
        wer = []
        cer = []
        wla = []
        total_sentences = 0.0
        correct_sentences = 0.0
        # Horizontal rule for the prediction table; original rebuilt it with
        # a redundant ''.join(101 * '-') every print.
        rule = '-' * 101
        # Renamed loop variable: the original shadowed the builtin `input`.
        for i_iter, batch in enumerate(loader):
            vid = batch.get('vid').cuda()
            txt = batch.get('txt').cuda()
            y = net(vid)
            pred_txt = ctc_decode(y)
            # start=1 presumably skips the CTC blank/pad index when
            # rendering ground truth — verify against MyDataset.arr2txt.
            truth_txt = [MyDataset.arr2txt(txt[_], start=1)
                         for _ in range(txt.size(0))]
            wer.extend(MyDataset.wer(pred_txt, truth_txt))
            cer.extend(MyDataset.cer(pred_txt, truth_txt))
            wla.extend(MyDataset.wla(pred_txt, truth_txt))
            batch_correct_sentences, batch_total_sentences = \
                MyDataset.sentences(pred_txt, truth_txt)
            correct_sentences = correct_sentences + batch_correct_sentences
            total_sentences = total_sentences + batch_total_sentences
            # Running sentence-level accuracy over everything seen so far.
            sla = correct_sentences / total_sentences
            if i_iter % opt.display == 0:
                print(rule)
                print('{:<50}|{:>50}'.format('predict', 'truth'))
                print(rule)
                for predict, truth in list(zip(pred_txt, truth_txt))[:10]:
                    print('{:<50}|{:>50}'.format(predict, truth))
                print(rule)
                print('test_iter={}, wer={}, cer={}, wla={} , sla={}'.format(
                    i_iter, np.array(wer).mean(), np.array(cer).mean(),
                    np.array(wla).mean(), sla))
                print(rule)
                writer.add_scalar('wer', np.array(wer).mean(), i_iter)
                writer.add_scalar('cer', np.array(cer).mean(), i_iter)
                writer.add_scalar('wla', np.array(wla).mean(), i_iter)
                # Tag 'bla' kept byte-identical to the original, even though
                # the value logged is sla.
                writer.add_scalar('bla', sla, i_iter)
        return (np.array(wer).mean(), np.array(cer).mean(),
                np.array(wla).mean(), sla)