def main(base_data_dir, train_data_path, train_base_dir, orig_eval_data_path, orig_eval_base_dir,
         synth_eval_data_path, synth_eval_base_dir, lexicon_path, seq_proj, backend, snapshot,
         input_height, base_lr, elastic_alpha, elastic_sigma, step_size, max_iter, batch_size,
         output_dir, test_iter, show_iter, test_init, use_gpu, use_no_font_repeat_data, do_vat,
         do_at, vat_ratio, test_vat_ratio, vat_epsilon, vat_ip, vat_xi, vat_sign, do_comp,
         comp_ratio, do_remove_augs, aug_to_remove, do_beam_search, dropout_conv, dropout_rnn,
         dropout_output, do_ema, do_gray, do_test_vat, do_test_entropy, do_test_vat_cnn,
         do_test_vat_rnn, do_test_rand, ada_after_rnn, ada_before_rnn, do_ada_lr, ada_ratio,
         rnn_hidden_size, do_test_pseudo, test_pseudo_ratio, test_pseudo_thresh, do_lr_step,
         do_test_ensemble, test_ensemble_ratio, test_ensemble_thresh):
    num_nets = 4
    # resolve all paths relative to the base data directory
    train_data_path = os.path.join(base_data_dir, train_data_path)
    train_base_dir = os.path.join(base_data_dir, train_base_dir)
    synth_eval_data_path = os.path.join(base_data_dir, synth_eval_data_path)
    synth_eval_base_dir = os.path.join(base_data_dir, synth_eval_base_dir)
    orig_eval_data_path = os.path.join(base_data_dir, orig_eval_data_path)
    orig_eval_base_dir = os.path.join(base_data_dir, orig_eval_base_dir)
    lexicon_path = os.path.join(base_data_dir, lexicon_path)
    all_parameters = locals()
    cuda = use_gpu
    #print(train_base_dir)
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        tb_writer = TbSummary(output_dir)
        output_dir = os.path.join(output_dir, 'model')
        os.makedirs(output_dir, exist_ok=True)
    with open(lexicon_path, 'rb') as f:
        lexicon = pkl.load(f)
    #print(sorted(lexicon.items(), key=operator.itemgetter(1)))
    with open(os.path.join(output_dir, 'params.txt'), 'w') as f:
        f.writelines(str(all_parameters))
    print(all_parameters)
    print('new vat')
    sin_magnitude = 4
    rotate_max_angle = 2
    train_fonts = ['Qomolangma-Betsu', 'Shangshung Sgoba-KhraChen', 'Shangshung Sgoba-KhraChung',
                   'Qomolangma-Drutsa']
    all_args = locals()
    # build the augmentation pipelines
    print('doing all transforms :)')
    rand_trans = [
        ElasticAndSine(elastic_alpha=elastic_alpha, elastic_sigma=elastic_sigma, sin_magnitude=sin_magnitude),
        Rotation(angle=rotate_max_angle, fill_value=255),
        ColorGradGausNoise()
    ]
    if do_gray:
        rand_trans = rand_trans + [Resize(hight=input_height), AddWidth(), ToGray(), Normalize()]
    else:
        rand_trans = rand_trans + [Resize(hight=input_height), AddWidth(), Normalize()]
    transform_random = Compose(rand_trans)
    if do_gray:
        transform_simple = Compose([Resize(hight=input_height), AddWidth(), ToGray(), Normalize()])
    else:
        transform_simple = Compose([Resize(hight=input_height), AddWidth(), Normalize()])
    if use_no_font_repeat_data:
        print('creating dataset')
        train_data = TextDatasetRandomFont(data_path=train_data_path, lexicon=lexicon,
                                           base_path=train_base_dir, transform=transform_random,
                                           fonts=train_fonts)
        print('finished creating dataset')
    else:
        print('train data path:\n{}'.format(train_data_path))
        print('train_base_dir:\n{}'.format(train_base_dir))
        train_data = TextDataset(data_path=train_data_path, lexicon=lexicon,
                                 base_path=train_base_dir, transform=transform_random,
                                 fonts=train_fonts)
    synth_eval_data = TextDataset(data_path=synth_eval_data_path, lexicon=lexicon,
                                  base_path=synth_eval_base_dir, transform=transform_random,
                                  fonts=train_fonts)
    orig_eval_data = TextDataset(data_path=orig_eval_data_path, lexicon=lexicon,
                                 base_path=orig_eval_base_dir, transform=transform_simple,
                                 fonts=None)
    if do_test_ensemble:
        orig_vat_data = TextDataset(data_path=orig_eval_data_path, lexicon=lexicon,
                                    base_path=orig_eval_base_dir, transform=transform_simple,
                                    fonts=None)
    #else:
    #    train_data = TestDataset(transform=transform, abc=abc).set_mode("train")
    #    synth_eval_data = TestDataset(transform=transform, abc=abc).set_mode("test")
    #    orig_eval_data = TestDataset(transform=transform, abc=abc).set_mode("test")
    seq_proj = [int(x) for x in seq_proj.split('x')]
    # one model / optimizer / scheduler per ensemble member, each on its own GPU
    nets = []
    optimizers = []
    lr_schedulers = []
    for neti in range(num_nets):
        nets.append(load_model(lexicon=train_data.get_lexicon(), seq_proj=seq_proj, backend=backend,
                               snapshot=snapshot, cuda=cuda, do_beam_search=do_beam_search,
                               dropout_conv=dropout_conv, dropout_rnn=dropout_rnn,
                               dropout_output=dropout_output, do_ema=do_ema,
                               ada_after_rnn=ada_after_rnn, ada_before_rnn=ada_before_rnn,
                               rnn_hidden_size=rnn_hidden_size, gpu=neti))
        optimizers.append(optim.Adam(nets[neti].parameters(), lr=base_lr, weight_decay=0.0001))
        lr_schedulers.append(StepLR(optimizers[neti], step_size=step_size, max_iter=max_iter))
    loss_function = CTCLoss()
    synth_avg_ed_best = float("inf")
    orig_avg_ed_best = float("inf")
    epoch_count = 0
    if do_test_ensemble:
        collate_vat = lambda x: text_collate(x, do_mask=True)
        vat_load = DataLoader(orig_vat_data, batch_size=batch_size, num_workers=4, shuffle=True,
                              collate_fn=collate_vat)
        vat_len = len(vat_load)
        cur_vat = 0
        vat_iter = iter(vat_load)
    loss_domain = torch.nn.NLLLoss()
    while True:
        collate = lambda x: text_collate(x, do_mask=(do_vat or ada_before_rnn or ada_after_rnn))
        data_loader = DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True,
                                 collate_fn=collate)
        if do_comp:
            data_loader_comp = DataLoader(train_data_comp, batch_size=batch_size, num_workers=4,
                                          shuffle=True, collate_fn=collate_comp)
            iter_comp = iter(data_loader_comp)
        loss_mean_ctc = []
        loss_mean_total = []
        loss_mean_test_ensemble = []
        num_labels_used_total = 0
        iterator = tqdm(data_loader)
        nll_loss = torch.nn.NLLLoss()
        iter_count = 0
        for iter_num, sample in enumerate(iterator):
            total_iter = (epoch_count * len(data_loader)) + iter_num
            if ((total_iter > 1) and total_iter % test_iter == 0) or (test_init and total_iter == 0):
                # epoch_count != 0 and
                print("Test phase")
                for net in nets:
                    net = net.eval()
                    if do_ema:
                        net.start_test()
                synth_acc, synth_avg_ed, synth_avg_no_stop_ed, synth_avg_loss = test(
                    nets, synth_eval_data, synth_eval_data.get_lexicon(), cuda,
                    batch_size=batch_size, visualize=False, tb_writer=tb_writer, n_iter=total_iter,
                    initial_title='val_synth', loss_function=loss_function,
                    output_path=os.path.join(output_dir, 'results'), do_beam_search=False)
                orig_acc, orig_avg_ed, orig_avg_no_stop_ed, orig_avg_loss = test(
                    nets, orig_eval_data, orig_eval_data.get_lexicon(), cuda,
                    batch_size=batch_size, visualize=False, tb_writer=tb_writer, n_iter=total_iter,
                    initial_title='test_orig', loss_function=loss_function,
                    output_path=os.path.join(output_dir, 'results'), do_beam_search=do_beam_search)
                for net in nets:
                    net = net.train()
                # save periodic
                if output_dir is not None and total_iter // 30000:
                    periodic_save = os.path.join(output_dir, 'periodic_save')
                    os.makedirs(periodic_save, exist_ok=True)
                    old_save = glob.glob(os.path.join(periodic_save, '*'))
                    for neti, net in enumerate(nets):
                        torch.save(net.state_dict(),
                                   os.path.join(output_dir,
                                                "crnn_{}_".format(neti) + backend + "_" + str(total_iter)))
                if orig_avg_no_stop_ed < orig_avg_ed_best:
                    orig_avg_ed_best = orig_avg_no_stop_ed
                    if output_dir is not None:
                        for neti, net in enumerate(nets):
                            torch.save(net.state_dict(),
                                       os.path.join(output_dir,
                                                    "crnn_{}_".format(neti) + backend + "_iter_{}".format(total_iter)))
                if synth_avg_no_stop_ed < synth_avg_ed_best:
                    synth_avg_ed_best = synth_avg_no_stop_ed
                if do_ema:
                    for net in nets:
                        net.end_test()
                print("synth: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}".format(
                    synth_avg_ed_best, synth_avg_ed, synth_avg_no_stop_ed, synth_acc))
                print("orig: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}".format(
                    orig_avg_ed_best, orig_avg_ed, orig_avg_no_stop_ed, orig_acc))
                tb_writer.get_writer().add_scalars('data/test', {
                    'synth_ed_total': synth_avg_ed,
                    'synth_ed_no_stop': synth_avg_no_stop_ed,
                    'synth_avg_loss': synth_avg_loss,
                    'orig_ed_total': orig_avg_ed,
                    'orig_ed_no_stop': orig_avg_no_stop_ed,
                    'orig_avg_loss': orig_avg_loss
                }, total_iter)
                if len(loss_mean_ctc) > 0:
                    train_dict = {'mean_ctc_loss': np.mean(loss_mean_ctc)}
                    train_dict = {**train_dict,
                                  **{'mean_test_ensemble_loss': np.mean(loss_mean_test_ensemble)}}
                    train_dict = {**train_dict, **{'num_labels_used': num_labels_used_total}}
                    num_labels_used_total = 0
                    print(train_dict)
                    tb_writer.get_writer().add_scalars('data/train', train_dict, total_iter)
            '''
            # for multi-gpu support
            if sample["img"].size(0) % len(gpu.split(',')) != 0:
                continue
            '''
            for optimizer in optimizers:
                optimizer.zero_grad()
            imgs = Variable(sample["img"])
            #print("images sizes are:")
            #print(sample["img"].shape)
            if do_vat or ada_after_rnn or ada_before_rnn:
                mask = sample['mask']
            labels_flatten = Variable(sample["seq"]).view(-1)
            label_lens = Variable(sample["seq_len"].int())
            #print("image sequence length is:")
            #print(sample["im_seq_len"])
            #print("label sequence length is:")
            #print(sample["seq_len"].view(1,-1))
            img_seq_lens = sample["im_seq_len"]
            if do_test_ensemble:
                # draw an unlabeled batch and build pseudo-labels by majority vote over the nets
                if cur_vat >= vat_len:
                    vat_load = DataLoader(orig_vat_data, batch_size=batch_size, num_workers=4,
                                          shuffle=True, collate_fn=collate_vat)
                    vat_len = len(vat_load)
                    cur_vat = 0
                    vat_iter = iter(vat_load)
                vat_batch = next(vat_iter)
                cur_vat += 1
                vat_mask = vat_batch['mask']
                vat_imgs = Variable(vat_batch["img"])
                vat_img_seq_lens = vat_batch["im_seq_len"]
                all_net_classes = []
                all_net_preds = []

                def run_net_get_classes(neti_net_pair, cur_vat_imgs, cur_vat_mask, cur_vat_img_seq_lens, cuda):
                    neti, net = neti_net_pair
                    if cuda:
                        cur_vat_imgs = cur_vat_imgs.cuda(neti)
                        cur_vat_mask = cur_vat_mask.cuda(neti)
                    vat_pred = net.vat_forward(cur_vat_imgs, cur_vat_img_seq_lens)
                    vat_pred = vat_pred * cur_vat_mask
                    vat_pred = F.softmax(vat_pred, dim=2).view(-1, vat_pred.size()[-1])
                    all_net_preds.append(vat_pred)
                    np_vat_preds = vat_pred.cpu().data.numpy()
                    classes_by_index = np.argmax(np_vat_preds, axis=1)
                    return classes_by_index

                for neti, net in enumerate(nets):
                    if cuda:
                        vat_imgs = vat_imgs.cuda(neti)
                        vat_mask = vat_mask.cuda(neti)
                    vat_pred = net.vat_forward(vat_imgs, vat_img_seq_lens)
                    vat_pred = vat_pred * vat_mask
                    vat_pred = F.softmax(vat_pred, dim=2).view(-1, vat_pred.size()[-1])
                    all_net_preds.append(vat_pred)
                    np_vat_preds = vat_pred.cpu().data.numpy()
                    classes_by_index = np.argmax(np_vat_preds, axis=1)
                    all_net_classes.append(classes_by_index)
                all_net_classes = np.stack(all_net_classes)
                all_net_classes, all_nets_count = stats.mode(all_net_classes, axis=0)
                all_net_classes = all_net_classes.reshape(-1)
                all_nets_count = all_nets_count.reshape(-1)
                # keep only frames on which enough nets agree
                ens_indices = np.argwhere(all_nets_count > test_ensemble_thresh)
                ens_indices = ens_indices.reshape(-1)
                ens_classes = all_net_classes[all_nets_count > test_ensemble_thresh]
                net_ens_losses = []
                num_labels_used = len(ens_indices)
                for neti, net in enumerate(nets):
                    indices = Variable(torch.from_numpy(ens_indices).cuda(neti))
                    labels = Variable(torch.from_numpy(ens_classes).cuda(neti))
                    net_preds_to_ens = all_net_preds[neti][indices]
                    loss = nll_loss(net_preds_to_ens, labels)
                    net_ens_losses.append(loss.cpu())
            nets_total_losses = []
            nets_ctc_losses = []
            loss_is_inf = False
            for neti, net in enumerate(nets):
                if cuda:
                    imgs = imgs.cuda(neti)
                preds = net(imgs, img_seq_lens)
                loss_ctc = loss_function(preds, labels_flatten,
                                         Variable(torch.IntTensor(np.array(img_seq_lens))),
                                         label_lens) / batch_size
                if loss_ctc.data[0] in [float("inf"), -float("inf")]:
                    print("warning: loss should not be inf.")
                    loss_is_inf = True
                    break
                total_loss = loss_ctc
                if do_test_ensemble:
                    total_loss = total_loss + test_ensemble_ratio * net_ens_losses[neti]
                    net_ens_losses[neti] = net_ens_losses[neti].data[0]
                total_loss.backward()
                nets_total_losses.append(total_loss.data[0])
                nets_ctc_losses.append(loss_ctc.data[0])
                nn.utils.clip_grad_norm(net.parameters(), 10.0)
            if loss_is_inf:
                continue
            if -400 < loss_ctc.data[0] < 400:
                loss_mean_ctc.append(np.mean(nets_ctc_losses))
            if -400 < total_loss.data[0] < 400:
                loss_mean_total.append(np.mean(nets_total_losses))
            status = "epoch: {0:5d}; iter_num: {1:5d}; lr: {2:.2E}; loss_mean: {3:.3f}; loss: {4:.3f}".format(
                epoch_count, lr_schedulers[0].last_iter, lr_schedulers[0].get_lr(),
                np.mean(nets_total_losses), np.mean(nets_ctc_losses))
            if do_test_ensemble:
                ens_loss = np.mean(net_ens_losses)
                if ens_loss != 0:
                    loss_mean_test_ensemble.append(ens_loss)
                    status += "; loss_ens: {0:.3f}".format(ens_loss)
                    status += "; num_ens_used {}".format(num_labels_used)
                else:
                    loss_mean_test_ensemble.append(0)
                    status += "; loss_ens: {}".format(0)
            iterator.set_description(status)
            for optimizer in optimizers:
                optimizer.step()
            if do_lr_step:
                for lr_scheduler in lr_schedulers:
                    lr_scheduler.step()
            iter_count += 1
        if output_dir is not None:
            for neti, net in enumerate(nets):
                torch.save(net.state_dict(),
                           os.path.join(output_dir, "crnn_{}_".format(neti) + backend + "_last"))
        epoch_count += 1
    return
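
# Illustrative sketch (not part of the training flow above): the do_test_ensemble branch
# builds pseudo-labels by majority vote over the per-net argmax classes and keeps only
# frames whose agreement count exceeds test_ensemble_thresh. The helper below reproduces
# just that voting step on plain numpy arrays; the function name and toy shapes are
# assumptions for illustration only.
def _ensemble_vote_sketch(per_net_classes, agreement_thresh):
    import numpy as np
    from scipy import stats

    # per_net_classes: (num_nets, num_frames) array of argmax class ids, one row per net
    voted, counts = stats.mode(per_net_classes, axis=0)
    voted = voted.reshape(-1)
    counts = counts.reshape(-1)
    keep = counts > agreement_thresh
    ens_indices = np.argwhere(keep).reshape(-1)  # frame indices used as pseudo-labels
    ens_classes = voted[keep]                    # their majority-vote classes
    return ens_indices, ens_classes
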
def test_attn(net, data, abc, cuda, visualize, batch_size=1, tb_writer=None, n_iter=0,
              initial_title="", is_trian=True, output_path=None):
    collate = lambda x: text_collate(x, do_mask=True)
    net.eval()
    data_loader = DataLoader(data, batch_size=1, num_workers=2, shuffle=False, collate_fn=collate)
    stop_characters = ['-', '.', '༎', '༑', '།', '་']
    count = 0
    tp = 0
    avg_ed = 0
    avg_no_stop_ed = 0
    avg_loss = 0
    min_ed = 1000
    iterator = tqdm(data_loader)
    # accumulated predictions, labels and image paths for the output files
    all_pred_text = []
    all_label_text = []
    all_im_pathes = []
    test_letter_statistics = Statistics()
    with torch.no_grad():
        for i, sample in enumerate(iterator):
            if is_trian and (i > 1000):
                break
            imgs = Variable(sample["img"])
            mask = sample["mask"]
            padded_labels = sample["padded_seq"]
            if cuda:
                imgs = imgs.cuda()
                mask = mask.cuda()
                padded_labels = padded_labels.cuda()
            img_seq_lens = sample["im_seq_len"]
            # Forward propagation
            decoder_outputs, decoder_hidden, other = net(imgs, img_seq_lens, mask, None,
                                                         teacher_forcing_ratio=0)
            # Get loss
            loss = NLLLoss()
            loss.reset()
            zero_labels = torch.zeros_like(padded_labels[:, 1])
            max_label_size = padded_labels.size(1)
            for step, step_output in enumerate(decoder_outputs):
                batch_size = padded_labels.size(0)
                if (step + 1) < max_label_size:
                    loss.eval_batch(step_output.contiguous().view(batch_size, -1),
                                    padded_labels[:, step + 1])
                else:
                    loss.eval_batch(step_output.contiguous().view(batch_size, -1), zero_labels)
            total_loss = loss.get_loss().data[0]
            avg_loss += total_loss
            labels_flatten = Variable(sample["seq"]).view(-1)
            label_lens = Variable(sample["seq_len"].int())
            preds_text = net.predict(other)
            padded_labels = (sample["padded_seq"].numpy()).tolist()
            lens = sample["seq_len"].numpy().tolist()
            label_text = net.padded_seq_to_txt(padded_labels, lens)
            if output_path is not None:
                all_pred_text = all_pred_text + [pd + '\n' for pd in preds_text]
                all_label_text = all_label_text + [lb + '\n' for lb in label_text]
                all_im_pathes.append(sample["im_path"] + '\n')  #[imp +'\n' for imp in sample["im_path"]]
            if i == 0:
                if tb_writer is not None:
                    tb_writer.show_images(sample["img"],
                                          label_text=[lb + '\n' for lb in label_text],
                                          pred_text=[pd + '\n' for pd in preds_text],
                                          n_iter=n_iter, initial_title=initial_title)
            pos = 0
            key = ''
            for i in range(len(label_text)):
                cur_out_no_stops = ''.join(c for c in label_text[i] if not c in stop_characters)
                cur_gts_no_stops = ''.join(c for c in preds_text[i] if not c in stop_characters)
                cur_ed = editdistance.eval(preds_text[i], label_text[i]) / max(len(preds_text[i]),
                                                                               len(label_text[i]))
                errors, matches, bp = my_edit_distance_backpointer(cur_out_no_stops, cur_gts_no_stops)
                test_letter_statistics.add_data(bp)
                my_no_stop_ed = errors / max(len(cur_out_no_stops), len(cur_gts_no_stops))
                cur_no_stop_ed = editdistance.eval(cur_out_no_stops, cur_gts_no_stops) / max(
                    len(cur_out_no_stops), len(cur_gts_no_stops))
                if my_no_stop_ed != cur_no_stop_ed:
                    print('old ed: {} , vs. new ed: {}\n'.format(my_no_stop_ed, cur_no_stop_ed))
                avg_no_stop_ed += cur_no_stop_ed
                avg_ed += cur_ed
                if cur_ed < min_ed:
                    min_ed = cur_ed
                count += 1
                if visualize:
                    status = "pred: {}; gt: {}".format(preds_text[i], label_text[i])
                    iterator.set_description(status)
                    img = imgs[i].permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
                    cv2.imshow("img", img)
                    key = chr(cv2.waitKey() & 255)
                    if key == 'q':
                        break
            if key == 'q':
                break
            if not visualize:
                iterator.set_description("acc: {0:.4f}; avg_ed: {0:.4f}".format(
                    float(tp) / float(count), float(avg_ed) / float(count)))
    with open(output_path + '_{}_{}_statistics.pkl'.format(initial_title, n_iter), 'wb') as sf:
        pkl.dump(test_letter_statistics.total_actions_hists, sf)
    if output_path is not None:
        os.makedirs(output_path, exist_ok=True)
        print('writing output')
        with open(output_path + '_{}_{}_pred.txt'.format(initial_title, n_iter), 'w') as fp:
            fp.writelines(all_pred_text)
        with open(output_path + '_{}_{}_label.txt'.format(initial_title, n_iter), 'w') as fp:
            fp.writelines(all_label_text)
        with open(output_path + '_{}_{}_im.txt'.format(initial_title, n_iter), 'w') as fp:
            fp.writelines(all_im_pathes)
        stop_characters = ['-', '.', '༎', '༑', '།', '་']
        all_pred_text = [''.join(c for c in line if not c in stop_characters) for line in all_pred_text]
        with open(output_path + '_{}_{}_pred_no_stopchars.txt'.format(initial_title, n_iter), 'w') as rf:
            rf.writelines(all_pred_text)
        all_label_text = [''.join(c for c in line if not c in stop_characters) for line in all_label_text]
        with open(output_path + '_{}_{}_label_no_stopchars.txt'.format(initial_title, n_iter), 'w') as rf:
            rf.writelines(all_label_text)
    acc = float(tp) / float(count)
    avg_ed = float(avg_ed) / float(count)
    avg_no_stop_ed = float(avg_no_stop_ed) / float(count)
    avg_loss = float(avg_loss) / float(count)
    return acc, avg_ed, avg_no_stop_ed, avg_loss
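
# Illustrative sketch (not part of the evaluation flow above): test_attn scores each
# prediction with a length-normalised edit distance after stripping the Tibetan stop
# characters. This standalone helper shows just that metric; the function name is an
# assumption for illustration, the stop-character list mirrors the one used above.
def _normalised_edit_distance_sketch(pred, label):
    import editdistance

    stop_characters = ['-', '.', '༎', '༑', '།', '་']
    pred_no_stops = ''.join(c for c in pred if c not in stop_characters)
    label_no_stops = ''.join(c for c in label if c not in stop_characters)
    # normalise by the longer string so the score stays in [0, 1]
    return editdistance.eval(pred_no_stops, label_no_stops) / max(len(pred_no_stops),
                                                                  len(label_no_stops))
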
def test(net, data, abc, cuda, visualize, batch_size=1, tb_writer=None, n_iter=0, initial_title="",
         loss_function=None, is_trian=True, output_path=None, do_beam_search=False,
         do_results=False, word_lexicon=None):
    collate = lambda x: text_collate(x, do_mask=False)
    data_loader = DataLoader(data, batch_size=1, num_workers=2, shuffle=False, collate_fn=collate)
    stop_characters = ['-', '.', '༎', '༑', '།', '་']
    garbage = '-'
    count = 0
    tp = 0
    avg_ed = 0
    avg_no_stop_ed = 0
    avg_accuracy = 0
    avg_loss = 0
    min_ed = 1000
    iterator = tqdm(data_loader)
    # accumulated predictions, labels and image paths for the output files
    all_pred_text = []
    all_label_text = []
    all_im_pathes = []
    test_letter_statistics = Statistics()
    im_by_error = {}
    for i, sample in enumerate(iterator):
        if is_trian and (i > 500):
            break
        imgs = Variable(sample["img"])
        if cuda:
            imgs = imgs.cuda()
        img_seq_lens = sample["im_seq_len"]
        out, orig_seq = net(imgs, img_seq_lens, decode=True, do_beam_search=do_beam_search)
        if loss_function is not None:
            labels_flatten = Variable(sample["seq"]).view(-1)
            label_lens = Variable(sample["seq_len"].int())
            loss = loss_function(orig_seq, labels_flatten,
                                 Variable(torch.IntTensor(np.array(img_seq_lens))),
                                 label_lens) / batch_size
            avg_loss += loss.data[0]
        gt = (sample["seq"].numpy()).tolist()
        lens = sample["seq_len"].numpy().tolist()
        labels_flatten = Variable(sample["seq"]).view(-1)
        label_lens = Variable(sample["seq_len"].int())
        if output_path is not None:
            preds_text = net.decode(orig_seq, data.get_lexicon())
            all_pred_text = all_pred_text + [''.join(c for c in pd if c != garbage) + '\n'
                                             for pd in preds_text]
            label_text = net.decode_flatten(labels_flatten, label_lens, data.get_lexicon())
            all_label_text = all_label_text + [lb + '\n' for lb in label_text]
            all_im_pathes.append(sample["im_path"] + '\n')  #[imp +'\n' for imp in sample["im_path"]]
        if i == 0:
            if tb_writer is not None:
                print_data_visuals(net, tb_writer, data.get_lexicon(), sample["img"],
                                   labels_flatten, label_lens, orig_seq, n_iter, initial_title)
        pos = 0
        key = ''
        for i in range(len(out)):
            gts = ''.join(abc[c] for c in gt[pos:pos + lens[i]])
            pos += lens[i]
            if gts == out[i]:
                tp += 1
            else:
                cur_out = ''.join(c for c in out[i] if c != garbage)
                cur_gts = ''.join(c for c in gts if c != garbage)
                cur_out_no_stops = ''.join(c for c in out[i] if not c in stop_characters)
                cur_gts_no_stops = ''.join(c for c in gts if not c in stop_characters)
                cur_ed = editdistance.eval(cur_out, cur_gts) / len(cur_gts)
                if word_lexicon is not None:
                    closest_word = get_close_matches(cur_out, word_lexicon, n=1, cutoff=0.2)
                else:
                    closest_word = cur_out
                if len(closest_word) > 0 and closest_word[0] == cur_gts:
                    avg_accuracy += 1
                errors, matches, bp = my_edit_distance_backpointer(cur_out_no_stops, cur_gts_no_stops)
                test_letter_statistics.add_data(bp)
                #my_no_stop_ed = errors / max(len(cur_out_no_stops), len(cur_gts_no_stops))
                #cur_no_stop_ed = editdistance.eval(cur_out_no_stops, cur_gts_no_stops) / max(len(cur_out_no_stops), len(cur_gts_no_stops))
                if do_results:
                    im_by_error[sample["im_path"]] = cur_ed
                my_no_stop_ed = errors / len(cur_gts_no_stops)
                cur_no_stop_ed = editdistance.eval(cur_out_no_stops, cur_gts_no_stops) / len(cur_gts_no_stops)
                if my_no_stop_ed != cur_no_stop_ed:
                    print('old ed: {} , vs. new ed: {}\n'.format(my_no_stop_ed, cur_no_stop_ed))
                avg_no_stop_ed += cur_no_stop_ed
                avg_ed += cur_ed
                if cur_ed < min_ed:
                    min_ed = cur_ed
            count += 1
            if visualize:
                status = "pred: {}; gt: {}".format(out[i], gts)
                iterator.set_description(status)
                img = imgs[i].permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
                cv2.imshow("img", img)
                key = chr(cv2.waitKey() & 255)
                if key == 'q':
                    break
        #if not visualize:
        #    iterator.set_description("acc: {0:.4f}; avg_ed: {0:.4f}".format(
        #        float(tp) / float(count), float(avg_ed) / float(count)))
        #with open(output_path + '_{}_{}_statistics.pkl'.format(initial_title, n_iter), 'wb') as sf:
        #    pkl.dump(test_letter_statistics.total_actions_hists, sf)
    if do_results and output_path is not None:
        print('printing results! :)')
        sorted_im_by_error = sorted(im_by_error.items(), key=operator.itemgetter(1))
        sorted_im = [key for (key, value) in sorted_im_by_error]
        all_im_pathes_no_new_line = [im.replace('\n', '') for im in all_im_pathes]
        printed_res_best = ""
        printed_res_worst = ""
        for im in sorted_im[:20]:
            im_id = all_im_pathes_no_new_line.index(im)
            pred = all_pred_text[im_id]
            label = all_label_text[im_id]
            printed_res_best += im + '\n' + label + pred
        for im in list(reversed(sorted_im))[:20]:
            im_id = all_im_pathes_no_new_line.index(im)
            pred = all_pred_text[im_id]
            label = all_label_text[im_id]
            printed_res_worst += im + '\n' + label + pred
        with open(output_path + '_{}_{}_sorted_images_by_errors.txt'.format(initial_title, n_iter), 'w') as fp:
            fp.writelines([key + ',' + str(value) + '\n' for (key, value) in sorted_im_by_error])
        with open(output_path + '_{}_{}_res_on_best.txt'.format(initial_title, n_iter), 'w') as fp:
            fp.writelines([printed_res_best])
        with open(output_path + '_{}_{}_res_on_worst.txt'.format(initial_title, n_iter), 'w') as fp:
            fp.writelines([printed_res_worst])
    if output_path is not None:
        os.makedirs(output_path, exist_ok=True)
        with open(output_path + '_{}_{}_pred.txt'.format(initial_title, n_iter), 'w') as fp:
            fp.writelines(all_pred_text)
        with open(output_path + '_{}_{}_label.txt'.format(initial_title, n_iter), 'w') as fp:
            fp.writelines(all_label_text)
        with open(output_path + '_{}_{}_im.txt'.format(initial_title, n_iter), 'w') as fp:
            fp.writelines(all_im_pathes)
        stop_characters = ['-', '.', '༎', '༑', '།', '་']
        all_pred_text = [''.join(c for c in line if not c in stop_characters) for line in all_pred_text]
        with open(output_path + '_{}_{}_pred_no_stopchars.txt'.format(initial_title, n_iter), 'w') as rf:
            rf.writelines(all_pred_text)
        all_label_text = [''.join(c for c in line if not c in stop_characters) for line in all_label_text]
        with open(output_path + '_{}_{}_label_no_stopchars.txt'.format(initial_title, n_iter), 'w') as rf:
            rf.writelines(all_label_text)
    acc = float(avg_accuracy) / float(count)
    avg_ed = float(avg_ed) / float(count)
    avg_no_stop_ed = float(avg_no_stop_ed) / float(count)
    if loss_function is not None:
        avg_loss = float(avg_loss) / float(count)
        return acc, avg_ed, avg_no_stop_ed, avg_loss
    return acc, avg_ed, avg_no_stop_ed
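
# Illustrative sketch (not part of the evaluation flow above): when word_lexicon is given,
# test() snaps each decoded word onto its closest lexicon entry with
# difflib.get_close_matches and counts a hit only if the snapped word equals the ground
# truth. The helper name and toy inputs are assumptions for illustration only.
def _lexicon_correction_sketch(decoded, ground_truth, word_lexicon):
    from difflib import get_close_matches

    closest = get_close_matches(decoded, word_lexicon, n=1, cutoff=0.2)
    return len(closest) > 0 and closest[0] == ground_truth
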
def main(base_data_dir, train_data_path, train_base_dir, orig_eval_data_path, orig_eval_base_dir,
         synth_eval_data_path, synth_eval_base_dir, lexicon_path, seq_proj, backend, snapshot,
         input_height, base_lr, elastic_alpha, elastic_sigma, step_size, max_iter, batch_size,
         output_dir, test_iter, show_iter, test_init, use_gpu, use_no_font_repeat_data, do_vat,
         do_at, vat_ratio, test_vat_ratio, vat_epsilon, vat_ip, vat_xi, vat_sign, do_remove_augs,
         aug_to_remove, do_beam_search, dropout_conv, dropout_rnn, dropout_output, do_ema, do_gray,
         do_test_vat, do_test_entropy, do_test_vat_cnn, do_test_vat_rnn, ada_after_rnn,
         ada_before_rnn, do_ada_lr, ada_ratio, rnn_hidden_size, do_lr_step, dataset_name):
    if not do_lr_step and not do_ada_lr:
        raise NotImplementedError('learning rate should be either step or ada.')
    # resolve all paths relative to the base data directory
    train_data_path = os.path.join(base_data_dir, train_data_path)
    train_base_dir = os.path.join(base_data_dir, train_base_dir)
    synth_eval_data_path = os.path.join(base_data_dir, synth_eval_data_path)
    synth_eval_base_dir = os.path.join(base_data_dir, synth_eval_base_dir)
    orig_eval_data_path = os.path.join(base_data_dir, orig_eval_data_path)
    orig_eval_base_dir = os.path.join(base_data_dir, orig_eval_base_dir)
    lexicon_path = os.path.join(base_data_dir, lexicon_path)
    all_parameters = locals()
    cuda = use_gpu
    #print(train_base_dir)
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        tb_writer = TbSummary(output_dir)
        output_dir = os.path.join(output_dir, 'model')
        os.makedirs(output_dir, exist_ok=True)
    with open(lexicon_path, 'rb') as f:
        lexicon = pkl.load(f)
    #print(sorted(lexicon.items(), key=operator.itemgetter(1)))
    with open(os.path.join(output_dir, 'params.txt'), 'w') as f:
        f.writelines(str(all_parameters))
    print(all_parameters)
    print('new vat')
    sin_magnitude = 4
    rotate_max_angle = 2
    dataset_info = SynthDataInfo(None, None, None, dataset_name.lower())
    train_fonts = dataset_info.font_names
    all_args = locals()
    # build the augmentation pipeline, optionally dropping one augmentation for ablation
    allowed_removals = ['elastic', 'sine', 'sine_rotate', 'rotation', 'color_aug', 'color_gaus',
                        'color_sine']
    if do_remove_augs and aug_to_remove not in allowed_removals:
        raise Exception('augmentation removal value is not allowed.')
    if do_remove_augs:
        rand_trans = []
        if aug_to_remove == 'elastic':
            print('doing sine transform :)')
            rand_trans.append(OnlySine(sin_magnitude=sin_magnitude))
        elif aug_to_remove in ['sine', 'sine_rotate']:
            print('doing elastic transform :)')
            rand_trans.append(OnlyElastic(elastic_alpha=elastic_alpha, elastic_sigma=elastic_sigma))
        if aug_to_remove not in ['elastic', 'sine', 'sine_rotate']:
            print('doing elastic transform :)')
            print('doing sine transform :)')
            rand_trans.append(ElasticAndSine(elastic_alpha=elastic_alpha, elastic_sigma=elastic_sigma,
                                             sin_magnitude=sin_magnitude))
        if aug_to_remove not in ['rotation', 'sine_rotate']:
            print('doing rotation transform :)')
            rand_trans.append(Rotation(angle=rotate_max_angle, fill_value=255))
        if aug_to_remove not in ['color_aug', 'color_gaus', 'color_sine']:
            print('doing color_aug transform :)')
            rand_trans.append(ColorGradGausNoise())
        elif aug_to_remove == 'color_gaus':
            print('doing color_sine transform :)')
            rand_trans.append(ColorGrad())
        elif aug_to_remove == 'color_sine':
            print('doing color_gaus transform :)')
            rand_trans.append(ColorGausNoise())
    else:
        print('doing all transforms :)')
        rand_trans = [
            ElasticAndSine(elastic_alpha=elastic_alpha, elastic_sigma=elastic_sigma,
                           sin_magnitude=sin_magnitude),
            Rotation(angle=rotate_max_angle, fill_value=255),
            ColorGradGausNoise()
        ]
    if do_gray:
        rand_trans = rand_trans + [Resize(hight=input_height), AddWidth(), ToGray(), Normalize()]
    else:
        rand_trans = rand_trans + [Resize(hight=input_height), AddWidth(), Normalize()]
    transform_random = Compose(rand_trans)
    if do_gray:
        transform_simple = Compose([Resize(hight=input_height), AddWidth(), ToGray(), Normalize()])
    else:
        transform_simple = Compose([Resize(hight=input_height), AddWidth(), Normalize()])
    if use_no_font_repeat_data:
        print('creating dataset')
        train_data = TextDatasetRandomFont(data_path=train_data_path, lexicon=lexicon,
                                           base_path=train_base_dir, transform=transform_random,
                                           fonts=train_fonts)
        print('finished creating dataset')
    else:
        print('train data path:\n{}'.format(train_data_path))
        print('train_base_dir:\n{}'.format(train_base_dir))
        train_data = TextDataset(data_path=train_data_path, lexicon=lexicon,
                                 base_path=train_base_dir, transform=transform_random,
                                 fonts=train_fonts)
    synth_eval_data = TextDataset(data_path=synth_eval_data_path, lexicon=lexicon,
                                  base_path=synth_eval_base_dir, transform=transform_random,
                                  fonts=train_fonts)
    orig_eval_data = TextDataset(data_path=orig_eval_data_path, lexicon=lexicon,
                                 base_path=orig_eval_base_dir, transform=transform_simple, fonts=None)
    if do_test_vat or do_test_vat_rnn or do_test_vat_cnn:
        orig_vat_data = TextDataset(data_path=orig_eval_data_path, lexicon=lexicon,
                                    base_path=orig_eval_base_dir, transform=transform_simple,
                                    fonts=None)
    if ada_after_rnn or ada_before_rnn:
        orig_ada_data = TextDataset(data_path=orig_eval_data_path, lexicon=lexicon,
                                    base_path=orig_eval_base_dir, transform=transform_simple,
                                    fonts=None)
    #else:
    #    train_data = TestDataset(transform=transform, abc=abc).set_mode("train")
    #    synth_eval_data = TestDataset(transform=transform, abc=abc).set_mode("test")
    #    orig_eval_data = TestDataset(transform=transform, abc=abc).set_mode("test")
    seq_proj = [int(x) for x in seq_proj.split('x')]
    net = load_model(lexicon=train_data.get_lexicon(), seq_proj=seq_proj, backend=backend,
                     snapshot=snapshot, cuda=cuda, do_beam_search=do_beam_search,
                     dropout_conv=dropout_conv, dropout_rnn=dropout_rnn,
                     dropout_output=dropout_output, do_ema=do_ema, ada_after_rnn=ada_after_rnn,
                     ada_before_rnn=ada_before_rnn, rnn_hidden_size=rnn_hidden_size)
    optimizer = optim.Adam(net.parameters(), lr=base_lr, weight_decay=0.0001)
    if do_ada_lr:
        print('using ada lr')
        lr_scheduler = DannLR(optimizer, max_iter=max_iter)
    elif do_lr_step:
        print('using step lr')
        lr_scheduler = StepLR(optimizer, step_size=step_size, max_iter=max_iter)
    loss_function = CTCLoss()
    synth_avg_ed_best = float("inf")
    orig_avg_ed_best = float("inf")
    epoch_count = 0
    if do_test_vat or do_test_vat_rnn or do_test_vat_cnn:
        collate_vat = lambda x: text_collate(x, do_mask=True)
        vat_load = DataLoader(orig_vat_data, batch_size=batch_size, num_workers=4, shuffle=True,
                              collate_fn=collate_vat)
        vat_len = len(vat_load)
        cur_vat = 0
        vat_iter = iter(vat_load)
    if ada_after_rnn or ada_before_rnn:
        collate_ada = lambda x: text_collate(x, do_mask=True)
        ada_load = DataLoader(orig_ada_data, batch_size=batch_size, num_workers=4, shuffle=True,
                              collate_fn=collate_ada)
        ada_len = len(ada_load)
        cur_ada = 0
        ada_iter = iter(ada_load)
    loss_domain = torch.nn.NLLLoss()
    while True:
        collate = lambda x: text_collate(x, do_mask=(do_vat or ada_before_rnn or ada_after_rnn))
        data_loader = DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True,
                                 collate_fn=collate)
        loss_mean_ctc = []
        loss_mean_vat = []
        loss_mean_at = []
        loss_mean_comp = []
        loss_mean_total = []
        loss_mean_test_vat = []
        loss_mean_test_pseudo = []
        loss_mean_test_rand = []
        loss_mean_ada_rnn_s = []
        loss_mean_ada_rnn_t = []
        loss_mean_ada_cnn_s = []
        loss_mean_ada_cnn_t = []
        iterator = tqdm(data_loader)
        iter_count = 0
        for iter_num, sample in enumerate(iterator):
            total_iter = (epoch_count * len(data_loader)) + iter_num
            if ((total_iter > 1) and total_iter % test_iter == 0) or (test_init and total_iter == 0):
                # epoch_count != 0 and
                print("Test phase")
                net = net.eval()
                if do_ema:
                    net.start_test()
                synth_acc, synth_avg_ed, synth_avg_no_stop_ed, synth_avg_loss = test(
                    net, synth_eval_data, synth_eval_data.get_lexicon(), cuda, visualize=False,
                    dataset_info=dataset_info, batch_size=batch_size, tb_writer=tb_writer,
                    n_iter=total_iter, initial_title='val_synth', loss_function=loss_function,
                    output_path=os.path.join(output_dir, 'results'), do_beam_search=False)
                orig_acc, orig_avg_ed, orig_avg_no_stop_ed, orig_avg_loss = test(
                    net, orig_eval_data, orig_eval_data.get_lexicon(), cuda, visualize=False,
                    dataset_info=dataset_info, batch_size=batch_size, tb_writer=tb_writer,
                    n_iter=total_iter, initial_title='test_orig', loss_function=loss_function,
                    output_path=os.path.join(output_dir, 'results'), do_beam_search=do_beam_search)
                net = net.train()
                # save periodic
                if output_dir is not None and total_iter // 30000:
                    periodic_save = os.path.join(output_dir, 'periodic_save')
                    os.makedirs(periodic_save, exist_ok=True)
                    old_save = glob.glob(os.path.join(periodic_save, '*'))
                    torch.save(net.state_dict(),
                               os.path.join(output_dir, "crnn_" + backend + "_" + str(total_iter)))
                if orig_avg_no_stop_ed < orig_avg_ed_best:
                    orig_avg_ed_best = orig_avg_no_stop_ed
                    if output_dir is not None:
                        torch.save(net.state_dict(),
                                   os.path.join(output_dir, "crnn_" + backend + "_best"))
                if synth_avg_no_stop_ed < synth_avg_ed_best:
                    synth_avg_ed_best = synth_avg_no_stop_ed
                if do_ema:
                    net.end_test()
                print("synth: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}".format(
                    synth_avg_ed_best, synth_avg_ed, synth_avg_no_stop_ed, synth_acc))
                print("orig: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}".format(
                    orig_avg_ed_best, orig_avg_ed, orig_avg_no_stop_ed, orig_acc))
                tb_writer.get_writer().add_scalars('data/test', {
                    'synth_ed_total': synth_avg_ed,
                    'synth_ed_no_stop': synth_avg_no_stop_ed,
                    'synth_avg_loss': synth_avg_loss,
                    'orig_ed_total': orig_avg_ed,
                    'orig_ed_no_stop': orig_avg_no_stop_ed,
                    'orig_avg_loss': orig_avg_loss
                }, total_iter)
                if len(loss_mean_ctc) > 0:
                    train_dict = {'mean_ctc_loss': np.mean(loss_mean_ctc)}
                    if do_vat:
                        train_dict = {**train_dict, **{'mean_vat_loss': np.mean(loss_mean_vat)}}
                    if do_at:
                        train_dict = {**train_dict, **{'mean_at_loss': np.mean(loss_mean_at)}}
                    if do_test_vat:
                        train_dict = {**train_dict, **{'mean_test_vat_loss': np.mean(loss_mean_test_vat)}}
                    if do_test_vat_rnn and do_test_vat_cnn:
                        train_dict = {**train_dict, **{'mean_test_vat_crnn_loss': np.mean(loss_mean_test_vat)}}
                    elif do_test_vat_rnn:
                        train_dict = {**train_dict, **{'mean_test_vat_rnn_loss': np.mean(loss_mean_test_vat)}}
                    elif do_test_vat_cnn:
                        train_dict = {**train_dict, **{'mean_test_vat_cnn_loss': np.mean(loss_mean_test_vat)}}
                    if ada_after_rnn:
                        train_dict = {**train_dict,
                                      **{'mean_ada_rnn_s_loss': np.mean(loss_mean_ada_rnn_s),
                                         'mean_ada_rnn_t_loss': np.mean(loss_mean_ada_rnn_t)}}
                    if ada_before_rnn:
                        train_dict = {**train_dict,
                                      **{'mean_ada_cnn_s_loss': np.mean(loss_mean_ada_cnn_s),
                                         'mean_ada_cnn_t_loss': np.mean(loss_mean_ada_cnn_t)}}
                    print(train_dict)
                    tb_writer.get_writer().add_scalars('data/train', train_dict, total_iter)
            '''
            # for multi-gpu support
            if sample["img"].size(0) % len(gpu.split(',')) != 0:
                continue
            '''
            optimizer.zero_grad()
            imgs = Variable(sample["img"])
            #print("images sizes are:")
            #print(sample["img"].shape)
            if do_vat or ada_after_rnn or ada_before_rnn:
                mask = sample['mask']
            labels_flatten = Variable(sample["seq"]).view(-1)
            label_lens = Variable(sample["seq_len"].int())
            #print("image sequence length is:")
            #print(sample["im_seq_len"])
            #print("label sequence length is:")
            #print(sample["seq_len"].view(1,-1))
            img_seq_lens = sample["im_seq_len"]
            if cuda:
                imgs = imgs.cuda()
                if do_vat or ada_after_rnn or ada_before_rnn:
                    mask = mask.cuda()
            if do_ada_lr:
                ada_p = float(iter_count) / max_iter
                lr_scheduler.update(ada_p)
            if ada_before_rnn or ada_after_rnn:
                if not do_ada_lr:
                    ada_p = float(iter_count) / max_iter
                # DANN-style ramp for the gradient-reversal weight
                ada_alpha = 2. / (1. + np.exp(-10. * ada_p)) - 1
                if cur_ada >= ada_len:
                    ada_load = DataLoader(orig_ada_data, batch_size=batch_size, num_workers=4,
                                          shuffle=True, collate_fn=collate_ada)
                    ada_len = len(ada_load)
                    cur_ada = 0
                    ada_iter = iter(ada_load)
                ada_batch = next(ada_iter)
                cur_ada += 1
                ada_imgs = Variable(ada_batch["img"])
                ada_img_seq_lens = ada_batch["im_seq_len"]
                ada_mask = ada_batch['mask'].byte()
                if cuda:
                    ada_imgs = ada_imgs.cuda()
                _, ada_cnn, ada_rnn = net(ada_imgs, ada_img_seq_lens, ada_alpha=ada_alpha, mask=ada_mask)
                if ada_before_rnn:
                    ada_num_features = ada_cnn.size(0)
                else:
                    ada_num_features = ada_rnn.size(0)
                # target-domain samples get label 0 for the domain classifier
                domain_label = torch.zeros(ada_num_features)
                domain_label = domain_label.long()
                if cuda:
                    domain_label = domain_label.cuda()
                domain_label = Variable(domain_label)
                if ada_before_rnn:
                    err_ada_cnn_t = loss_domain(ada_cnn, domain_label)
                if ada_after_rnn:
                    err_ada_rnn_t = loss_domain(ada_rnn, domain_label)
            if do_test_vat and do_at:
                # test part!
                if cur_vat >= vat_len:
                    vat_load = DataLoader(orig_vat_data, batch_size=batch_size, num_workers=4,
                                          shuffle=True, collate_fn=collate_vat)
                    vat_len = len(vat_load)
                    cur_vat = 0
                    vat_iter = iter(vat_load)
                test_vat_batch = next(vat_iter)
                cur_vat += 1
                test_vat_mask = test_vat_batch['mask']
                test_vat_imgs = Variable(test_vat_batch["img"])
                test_vat_img_seq_lens = test_vat_batch["im_seq_len"]
                if cuda:
                    test_vat_imgs = test_vat_imgs.cuda()
                    test_vat_mask = test_vat_mask.cuda()
                # train part
                at_test_vat_loss = LabeledAtAndUnlabeledTestVatLoss(xi=vat_xi, eps=vat_epsilon, ip=vat_ip)
                at_loss, test_vat_loss = at_test_vat_loss(model=net, train_x=imgs,
                                                          train_labels_flatten=labels_flatten,
                                                          train_img_seq_lens=img_seq_lens,
                                                          train_label_lens=label_lens,
                                                          batch_size=batch_size,
                                                          test_x=test_vat_imgs,
                                                          test_seq_len=test_vat_img_seq_lens,
                                                          test_mask=test_vat_mask)
            elif do_test_vat or do_test_vat_rnn or do_test_vat_cnn:
                if cur_vat >= vat_len:
                    vat_load = DataLoader(orig_vat_data, batch_size=batch_size, num_workers=4,
                                          shuffle=True, collate_fn=collate_vat)
                    vat_len = len(vat_load)
                    cur_vat = 0
                    vat_iter = iter(vat_load)
                vat_batch = next(vat_iter)
                cur_vat += 1
                vat_mask = vat_batch['mask']
                vat_imgs = Variable(vat_batch["img"])
                vat_img_seq_lens = vat_batch["im_seq_len"]
                if cuda:
                    vat_imgs = vat_imgs.cuda()
                    vat_mask = vat_mask.cuda()
                if do_test_vat:
                    if do_test_vat_rnn or do_test_vat_cnn:
                        raise Exception("can only do one of do_test_vat | (do_test_vat_rnn, do_test_vat_cnn)")
                    if vat_sign == True:
                        test_vat_loss = VATLossSign(do_test_entropy=do_test_entropy, xi=vat_xi,
                                                    eps=vat_epsilon, ip=vat_ip)
                    else:
                        test_vat_loss = VATLoss(xi=vat_xi, eps=vat_epsilon, ip=vat_ip)
                elif do_test_vat_rnn and do_test_vat_cnn:
                    test_vat_loss = VATonRnnCnnSign(xi=vat_xi, eps=vat_epsilon, ip=vat_ip)
                elif do_test_vat_rnn:
                    test_vat_loss = VATonRnnSign(xi=vat_xi, eps=vat_epsilon, ip=vat_ip)
                elif do_test_vat_cnn:
                    test_vat_loss = VATonCnnSign(xi=vat_xi, eps=vat_epsilon, ip=vat_ip)
                if do_test_vat_cnn and do_test_vat_rnn:
                    test_vat_loss, cnn_lds, rnn_lds = test_vat_loss(net, vat_imgs, vat_img_seq_lens, vat_mask)
                elif do_test_vat:
                    test_vat_loss = test_vat_loss(net, vat_imgs, vat_img_seq_lens, vat_mask)
            elif do_vat:
                vat_loss = VATLoss(xi=vat_xi, eps=vat_epsilon, ip=vat_ip)
                vat_loss = vat_loss(net, imgs, img_seq_lens, mask)
            elif do_at:
                at_loss = LabeledATLoss(xi=vat_xi, eps=vat_epsilon, ip=vat_ip)
                at_loss = at_loss(net, imgs, labels_flatten, img_seq_lens, label_lens, batch_size)
            if ada_after_rnn or ada_before_rnn:
                preds, ada_cnn, ada_rnn = net(imgs, img_seq_lens, ada_alpha=ada_alpha, mask=mask)
                if ada_before_rnn:
                    ada_num_features = ada_cnn.size(0)
                else:
                    ada_num_features = ada_rnn.size(0)
                # source-domain samples get label 1 for the domain classifier
                domain_label = torch.ones(ada_num_features)
                domain_label = domain_label.long()
                if cuda:
                    domain_label = domain_label.cuda()
                domain_label = Variable(domain_label)
                if ada_before_rnn:
                    err_ada_cnn_s = loss_domain(ada_cnn, domain_label)
                if ada_after_rnn:
                    err_ada_rnn_s = loss_domain(ada_rnn, domain_label)
            else:
                preds = net(imgs, img_seq_lens)
            '''
            if output_dir is not None:
                if (show_iter is not None and iter_num != 0 and iter_num % show_iter == 0):
                    print_data_visuals(net, tb_writer, train_data.get_lexicon(), sample["img"],
                                       labels_flatten, label_lens, preds,
                                       ((epoch_count * len(data_loader)) + iter_num))
            '''
            loss_ctc = loss_function(preds, labels_flatten,
                                     Variable(torch.IntTensor(np.array(img_seq_lens))),
                                     label_lens) / batch_size
            if loss_ctc.data[0] in [float("inf"), -float("inf")]:
                print("warning: loss should not be inf.")
                continue
            total_loss = loss_ctc
            if do_vat:
                #mask = sample['mask']
                #if cuda:
                #    mask = mask.cuda()
                #vat_loss = virtual_adversarial_loss(net, imgs, img_seq_lens, mask, is_training=True,
                #                                    do_entropy=False, epsilon=vat_epsilon,
                #                                    num_power_iterations=1, xi=1e-6, average_loss=True)
                total_loss = total_loss + vat_ratio * vat_loss.cpu()
            if do_test_vat or do_test_vat_rnn or do_test_vat_cnn:
                total_loss = total_loss + test_vat_ratio * test_vat_loss.cpu()
            if ada_before_rnn:
                total_loss = total_loss + ada_ratio * err_ada_cnn_s.cpu() + ada_ratio * err_ada_cnn_t.cpu()
            if ada_after_rnn:
                total_loss = total_loss + ada_ratio * err_ada_rnn_s.cpu() + ada_ratio * err_ada_rnn_t.cpu()
            total_loss.backward()
            nn.utils.clip_grad_norm(net.parameters(), 10.0)
            if -400 < loss_ctc.data[0] < 400:
                loss_mean_ctc.append(loss_ctc.data[0])
            if -1000 < total_loss.data[0] < 1000:
                loss_mean_total.append(total_loss.data[0])
            if len(loss_mean_total) > 100:
                loss_mean_total = loss_mean_total[-100:]
            status = "epoch: {0:5d}; iter_num: {1:5d}; lr: {2:.2E}; loss_mean: {3:.3f}; loss: {4:.3f}".format(
                epoch_count, lr_scheduler.last_iter, lr_scheduler.get_lr(), np.mean(loss_mean_total),
                loss_ctc.data[0])
            if ada_after_rnn:
                loss_mean_ada_rnn_s.append(err_ada_rnn_s.data[0])
                loss_mean_ada_rnn_t.append(err_ada_rnn_t.data[0])
                status += "; ladatrnns: {0:.3f}; ladatrnnt: {1:.3f}".format(err_ada_rnn_s.data[0],
                                                                            err_ada_rnn_t.data[0])
            if ada_before_rnn:
                loss_mean_ada_cnn_s.append(err_ada_cnn_s.data[0])
                loss_mean_ada_cnn_t.append(err_ada_cnn_t.data[0])
                status += "; ladatcnns: {0:.3f}; ladatcnnt: {1:.3f}".format(err_ada_cnn_s.data[0],
                                                                            err_ada_cnn_t.data[0])
            if do_vat:
                loss_mean_vat.append(vat_loss.data[0])
                status += "; lvat: {0:.3f}".format(vat_loss.data[0])
            if do_at:
                loss_mean_at.append(at_loss.data[0])
                status += "; lat: {0:.3f}".format(at_loss.data[0])
            if do_test_vat:
                loss_mean_test_vat.append(test_vat_loss.data[0])
                status += "; l_tvat: {0:.3f}".format(test_vat_loss.data[0])
            if do_test_vat_rnn or do_test_vat_cnn:
                loss_mean_test_vat.append(test_vat_loss.data[0])
                if do_test_vat_rnn and do_test_vat_cnn:
                    status += "; l_tvatc: {}".format(cnn_lds.data[0])
                    status += "; l_tvatr: {}".format(rnn_lds.data[0])
                else:
                    status += "; l_tvat: {}".format(test_vat_loss.data[0])
            iterator.set_description(status)
            optimizer.step()
            if do_lr_step:
                lr_scheduler.step()
            if do_ema:
                net.udate_ema()
            iter_count += 1
        if output_dir is not None:
            torch.save(net.state_dict(), os.path.join(output_dir, "crnn_" + backend + "_last"))
        epoch_count += 1
    return
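
# Illustrative sketch (not part of the training flow above): the ada_before_rnn /
# ada_after_rnn branches ramp the gradient-reversal weight with the usual DANN schedule,
# ada_alpha = 2 / (1 + exp(-10 * p)) - 1, where p is the fraction of training completed.
# A minimal standalone version of that schedule, assuming the same formula as used above;
# the function name is an assumption for illustration only.
def _dann_alpha_sketch(iter_count, max_iter):
    import numpy as np

    p = float(iter_count) / max_iter
    return 2. / (1. + np.exp(-10. * p)) - 1
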