def main(data_path, abc, seq_proj, backend, snapshot, input_size, gpu, visualize): os.environ["CUDA_VISIBLE_DEVICES"] = gpu cuda = True if gpu is not '' else False input_size = [int(x) for x in input_size.split('x')] transform = Compose([ Rotation(), Translation(), # Scale(), Contrast(), Grid_distortion(), Resize(size=(input_size[0], input_size[1])) ]) if data_path is not None: data = TextDataset(data_path=data_path, mode="pb", transform=transform) else: data = TestDataset(transform=transform, abc=abc) seq_proj = [int(x) for x in seq_proj.split('x')] net = load_model(data.get_abc(), seq_proj, backend, snapshot, cuda).eval() acc, avg_ed, pred_pb = test_tta(net, data, data.get_abc(), cuda, visualize) df_submit = pd.DataFrame() df_submit['name'] = [x.split('/')[-1] for x in glob.glob('../../input/public_test_data/*')] df_submit['label'] = pred_pb df_submit.to_csv('tmp_rcnn_tta10.csv', index=None) print("Accuracy: {}".format(acc)) print("Edit distance: {}".format(avg_ed))
def main(data_path, abc, seq_proj, backend, snapshot, input_size, gpu, visualize): os.environ["CUDA_VISIBLE_DEVICES"] = gpu cuda = True if gpu is not '' else False abc = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" #print(abc) os.environ['CUDA_VISIBLE_DEVICES'] = gpu input_size = [int(x) for x in input_size.split('x')] transform = Compose( [Rotation(), Resize(size=(input_size[0], input_size[1]))]) if data_path is not None: data = LoadDataset(data_path=data_path, mode="test", transform=transform) seq_proj = [int(x) for x in seq_proj.split('x')] #net = load_model(abc, seq_proj, backend, snapshot, cuda) net = CRNN(abc=abc, seq_proj=seq_proj, backend=backend) #net = nn.DataParallel(net) if snapshot is not None: load_weights(net, torch.load(snapshot)) if cuda: net = net.cuda() #import pdb;pdb.set_trace() net = net.eval() detect(net, data, cuda, visualize)
def main(data_path, abc, seq_proj, backend, snapshot, input_size, gpu, visualize): os.environ["CUDA_VISIBLE_DEVICES"] = gpu cuda = True if gpu is not '' else False input_size = [int(x) for x in input_size.split('x')] transform = Compose([ Rotation(), Resize(size=(input_size[0], input_size[1])) ]) if data_path is not None: data = TextDataset(data_path=data_path, mode="test", transform=transform) else: data = TestDataset(transform=transform, abc=abc) seq_proj = [int(x) for x in seq_proj.split('x')] net = load_model(data.get_abc(), seq_proj, backend, snapshot, cuda).eval() acc, avg_ed = test(net, data, data.get_abc(), cuda, visualize) print("Accuracy: {}".format(acc)) print("Edit distance: {}".format(avg_ed))
def main(): input_size = [int(x) for x in config.input_size.split('x')] transform = Compose([ Rotation(), Resize(size=(input_size[0], input_size[1])) ]) # if data_path is not None: data = TextDataset(data_path=config.test_path, mode=config.test_mode, transform=transform) # else: # data = TestDataset(transform=transform, abc=abc) # seq_proj = [int(x) for x in config.seq_proj.split('x')] input_size = [int(x) for x in config.input_size.split('x')] net = load_model(input_size, data.get_abc(), None, config.backend, config.snapshot).eval() assert data.mode == config.test_mode acc, avg_ed = test(net, data, data.get_abc(), visualize=True, batch_size=config.batch_size, num_workers=0, output_csv=config.output_csv, output_image=config.output_image) print("Accuracy: {}".format(acc)) print("Edit distance: {}".format(avg_ed))
def main(base_data_dir, train_data_path, train_base_dir, orig_eval_data_path, orig_eval_base_dir, synth_eval_data_path, synth_eval_base_dir, lexicon_path, seq_proj, backend, snapshot, input_height, base_lr, elastic_alpha, elastic_sigma, step_size, max_iter, batch_size, output_dir, test_iter, show_iter, test_init, use_gpu, use_no_font_repeat_data, do_vat, do_at, vat_ratio, test_vat_ratio, vat_epsilon, vat_ip, vat_xi, vat_sign, do_comp, comp_ratio, do_remove_augs, aug_to_remove, do_beam_search, dropout_conv, dropout_rnn, dropout_output, do_ema, do_gray, do_test_vat, do_test_entropy, do_test_vat_cnn, do_test_vat_rnn, do_test_rand, ada_after_rnn, ada_before_rnn, do_ada_lr, ada_ratio, rnn_hidden_size, do_test_pseudo, test_pseudo_ratio, test_pseudo_thresh, do_lr_step, do_test_ensemble, test_ensemble_ratio, test_ensemble_thresh): num_nets = 4 train_data_path = os.path.join(base_data_dir, train_data_path) train_base_dir = os.path.join(base_data_dir, train_base_dir) synth_eval_data_path = os.path.join(base_data_dir, synth_eval_data_path) synth_eval_base_dir = os.path.join(base_data_dir, synth_eval_base_dir) orig_eval_data_path = os.path.join(base_data_dir, orig_eval_data_path) orig_eval_base_dir = os.path.join(base_data_dir, orig_eval_base_dir) lexicon_path = os.path.join(base_data_dir, lexicon_path) all_parameters = locals() cuda = use_gpu #print(train_base_dir) if output_dir is not None: os.makedirs(output_dir, exist_ok=True) tb_writer = TbSummary(output_dir) output_dir = os.path.join(output_dir, 'model') os.makedirs(output_dir, exist_ok=True) with open(lexicon_path, 'rb') as f: lexicon = pkl.load(f) #print(sorted(lexicon.items(), key=operator.itemgetter(1))) with open(os.path.join(output_dir, 'params.txt'), 'w') as f: f.writelines(str(all_parameters)) print(all_parameters) print('new vat') sin_magnitude = 4 rotate_max_angle = 2 train_fonts = [ 'Qomolangma-Betsu', 'Shangshung Sgoba-KhraChen', 'Shangshung Sgoba-KhraChung', 'Qomolangma-Drutsa' ] all_args = locals() print('doing all transforms :)') rand_trans = [ ElasticAndSine(elastic_alpha=elastic_alpha, elastic_sigma=elastic_sigma, sin_magnitude=sin_magnitude), Rotation(angle=rotate_max_angle, fill_value=255), ColorGradGausNoise() ] if do_gray: rand_trans = rand_trans + [ Resize(hight=input_height), AddWidth(), ToGray(), Normalize() ] else: rand_trans = rand_trans + [ Resize(hight=input_height), AddWidth(), Normalize() ] transform_random = Compose(rand_trans) if do_gray: transform_simple = Compose( [Resize(hight=input_height), AddWidth(), ToGray(), Normalize()]) else: transform_simple = Compose( [Resize(hight=input_height), AddWidth(), Normalize()]) if use_no_font_repeat_data: print('create dataset') train_data = TextDatasetRandomFont(data_path=train_data_path, lexicon=lexicon, base_path=train_base_dir, transform=transform_random, fonts=train_fonts) print('finished creating dataset') else: print('train data path:\n{}'.format(train_data_path)) print('train_base_dir:\n{}'.format(train_base_dir)) train_data = TextDataset(data_path=train_data_path, lexicon=lexicon, base_path=train_base_dir, transform=transform_random, fonts=train_fonts) synth_eval_data = TextDataset(data_path=synth_eval_data_path, lexicon=lexicon, base_path=synth_eval_base_dir, transform=transform_random, fonts=train_fonts) orig_eval_data = TextDataset(data_path=orig_eval_data_path, lexicon=lexicon, base_path=orig_eval_base_dir, transform=transform_simple, fonts=None) if do_test_ensemble: orig_vat_data = TextDataset(data_path=orig_eval_data_path, lexicon=lexicon, base_path=orig_eval_base_dir, transform=transform_simple, fonts=None) #else: # train_data = TestDataset(transform=transform, abc=abc).set_mode("train") # synth_eval_data = TestDataset(transform=transform, abc=abc).set_mode("test") # orig_eval_data = TestDataset(transform=transform, abc=abc).set_mode("test") seq_proj = [int(x) for x in seq_proj.split('x')] nets = [] optimizers = [] lr_schedulers = [] for neti in range(num_nets): nets.append( load_model(lexicon=train_data.get_lexicon(), seq_proj=seq_proj, backend=backend, snapshot=snapshot, cuda=cuda, do_beam_search=do_beam_search, dropout_conv=dropout_conv, dropout_rnn=dropout_rnn, dropout_output=dropout_output, do_ema=do_ema, ada_after_rnn=ada_after_rnn, ada_before_rnn=ada_before_rnn, rnn_hidden_size=rnn_hidden_size, gpu=neti)) optimizers.append( optim.Adam(nets[neti].parameters(), lr=base_lr, weight_decay=0.0001)) lr_schedulers.append( StepLR(optimizers[neti], step_size=step_size, max_iter=max_iter)) loss_function = CTCLoss() synth_avg_ed_best = float("inf") orig_avg_ed_best = float("inf") epoch_count = 0 if do_test_ensemble: collate_vat = lambda x: text_collate(x, do_mask=True) vat_load = DataLoader(orig_vat_data, batch_size=batch_size, num_workers=4, shuffle=True, collate_fn=collate_vat) vat_len = len(vat_load) cur_vat = 0 vat_iter = iter(vat_load) loss_domain = torch.nn.NLLLoss() while True: collate = lambda x: text_collate( x, do_mask=(do_vat or ada_before_rnn or ada_after_rnn)) data_loader = DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True, collate_fn=collate) if do_comp: data_loader_comp = DataLoader(train_data_comp, batch_size=batch_size, num_workers=4, shuffle=True, collate_fn=collate_comp) iter_comp = iter(data_loader_comp) loss_mean_ctc = [] loss_mean_total = [] loss_mean_test_ensemble = [] num_labels_used_total = 0 iterator = tqdm(data_loader) nll_loss = torch.nn.NLLLoss() iter_count = 0 for iter_num, sample in enumerate(iterator): total_iter = (epoch_count * len(data_loader)) + iter_num if ((total_iter > 1) and total_iter % test_iter == 0) or (test_init and total_iter == 0): # epoch_count != 0 and print("Test phase") for net in nets: net = net.eval() if do_ema: net.start_test() synth_acc, synth_avg_ed, synth_avg_no_stop_ed, synth_avg_loss = test( nets, synth_eval_data, synth_eval_data.get_lexicon(), cuda, batch_size=batch_size, visualize=False, tb_writer=tb_writer, n_iter=total_iter, initial_title='val_synth', loss_function=loss_function, output_path=os.path.join(output_dir, 'results'), do_beam_search=False) orig_acc, orig_avg_ed, orig_avg_no_stop_ed, orig_avg_loss = test( nets, orig_eval_data, orig_eval_data.get_lexicon(), cuda, batch_size=batch_size, visualize=False, tb_writer=tb_writer, n_iter=total_iter, initial_title='test_orig', loss_function=loss_function, output_path=os.path.join(output_dir, 'results'), do_beam_search=do_beam_search) for net in nets: net = net.train() #save periodic if output_dir is not None and total_iter // 30000: periodic_save = os.path.join(output_dir, 'periodic_save') os.makedirs(periodic_save, exist_ok=True) old_save = glob.glob(os.path.join(periodic_save, '*')) for neti, net in enumerate(nets): torch.save( net.state_dict(), os.path.join( output_dir, "crnn_{}_".format(neti) + backend + "_" + str(total_iter))) if orig_avg_no_stop_ed < orig_avg_ed_best: orig_avg_ed_best = orig_avg_no_stop_ed if output_dir is not None: for neti, net in enumerate(nets): torch.save( net.state_dict(), os.path.join( output_dir, "crnn_{}_".format(neti) + backend + "_iter_{}".format(total_iter))) if synth_avg_no_stop_ed < synth_avg_ed_best: synth_avg_ed_best = synth_avg_no_stop_ed if do_ema: for net in nets: net.end_test() print( "synth: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}" .format(synth_avg_ed_best, synth_avg_ed, synth_avg_no_stop_ed, synth_acc)) print( "orig: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}" .format(orig_avg_ed_best, orig_avg_ed, orig_avg_no_stop_ed, orig_acc)) tb_writer.get_writer().add_scalars( 'data/test', { 'synth_ed_total': synth_avg_ed, 'synth_ed_no_stop': synth_avg_no_stop_ed, 'synth_avg_loss': synth_avg_loss, 'orig_ed_total': orig_avg_ed, 'orig_ed_no_stop': orig_avg_no_stop_ed, 'orig_avg_loss': orig_avg_loss }, total_iter) if len(loss_mean_ctc) > 0: train_dict = {'mean_ctc_loss': np.mean(loss_mean_ctc)} train_dict = { **train_dict, **{ 'mean_test_ensemble_loss': np.mean(loss_mean_test_ensemble) } } train_dict = { **train_dict, **{ 'num_labels_used': num_labels_used_total } } num_labels_used_total = 0 print(train_dict) tb_writer.get_writer().add_scalars('data/train', train_dict, total_iter) ''' # for multi-gpu support if sample["img"].size(0) % len(gpu.split(',')) != 0: continue ''' for optimizer in optimizers: optimizer.zero_grad() imgs = Variable(sample["img"]) #print("images sizes are:") #print(sample["img"].shape) if do_vat or ada_after_rnn or ada_before_rnn: mask = sample['mask'] labels_flatten = Variable(sample["seq"]).view(-1) label_lens = Variable(sample["seq_len"].int()) #print("image sequence length is:") #print(sample["im_seq_len"]) #print("label sequence length is:") #print(sample["seq_len"].view(1,-1)) img_seq_lens = sample["im_seq_len"] if do_test_ensemble: if cur_vat >= vat_len: vat_load = DataLoader(orig_vat_data, batch_size=batch_size, num_workers=4, shuffle=True, collate_fn=collate_vat) vat_len = len(vat_load) cur_vat = 0 vat_iter = iter(vat_load) vat_batch = next(vat_iter) cur_vat += 1 vat_mask = vat_batch['mask'] vat_imgs = Variable(vat_batch["img"]) vat_img_seq_lens = vat_batch["im_seq_len"] all_net_classes = [] all_net_preds = [] def run_net_get_classes(neti_net_pair, cur_vat_imgs, cur_vat_mask, cur_vat_img_seq_lens, cuda): neti, net = neti_net_pair if cuda: cur_vat_imgs = cur_vat_imgs.cuda(neti) cur_vat_mask = cur_vat_mask.cuda(neti) vat_pred = net.vat_forward(cur_vat_imgs, cur_vat_img_seq_lens) vat_pred = vat_pred * cur_vat_mask vat_pred = F.softmax(vat_pred, dim=2).view(-1, vat_pred.size()[-1]) all_net_preds.append(vat_pred) np_vat_preds = vat_pred.cpu().data.numpy() classes_by_index = np.argmax(np_vat_preds, axis=1) return classes_by_index for neti, net in enumerate(nets): if cuda: vat_imgs = vat_imgs.cuda(neti) vat_mask = vat_mask.cuda(neti) vat_pred = net.vat_forward(vat_imgs, vat_img_seq_lens) vat_pred = vat_pred * vat_mask vat_pred = F.softmax(vat_pred, dim=2).view(-1, vat_pred.size()[-1]) all_net_preds.append(vat_pred) np_vat_preds = vat_pred.cpu().data.numpy() classes_by_index = np.argmax(np_vat_preds, axis=1) all_net_classes.append(classes_by_index) all_net_classes = np.stack(all_net_classes) all_net_classes, all_nets_count = stats.mode(all_net_classes, axis=0) all_net_classes = all_net_classes.reshape(-1) all_nets_count = all_nets_count.reshape(-1) ens_indices = np.argwhere( all_nets_count > test_ensemble_thresh) ens_indices = ens_indices.reshape(-1) ens_classes = all_net_classes[ all_nets_count > test_ensemble_thresh] net_ens_losses = [] num_labels_used = len(ens_indices) for neti, net in enumerate(nets): indices = Variable( torch.from_numpy(ens_indices).cuda(neti)) labels = Variable(torch.from_numpy(ens_classes).cuda(neti)) net_preds_to_ens = all_net_preds[neti][indices] loss = nll_loss(net_preds_to_ens, labels) net_ens_losses.append(loss.cpu()) nets_total_losses = [] nets_ctc_losses = [] loss_is_inf = False for neti, net in enumerate(nets): if cuda: imgs = imgs.cuda(neti) preds = net(imgs, img_seq_lens) loss_ctc = loss_function( preds, labels_flatten, Variable(torch.IntTensor(np.array(img_seq_lens))), label_lens) / batch_size if loss_ctc.data[0] in [float("inf"), -float("inf")]: print("warnning: loss should not be inf.") loss_is_inf = True break total_loss = loss_ctc if do_test_ensemble: total_loss = total_loss + test_ensemble_ratio * net_ens_losses[ neti] net_ens_losses[neti] = net_ens_losses[neti].data[0] total_loss.backward() nets_total_losses.append(total_loss.data[0]) nets_ctc_losses.append(loss_ctc.data[0]) nn.utils.clip_grad_norm(net.parameters(), 10.0) if loss_is_inf: continue if -400 < loss_ctc.data[0] < 400: loss_mean_ctc.append(np.mean(nets_ctc_losses)) if -400 < total_loss.data[0] < 400: loss_mean_total.append(np.mean(nets_total_losses)) status = "epoch: {0:5d}; iter_num: {1:5d}; lr: {2:.2E}; loss_mean: {3:.3f}; loss: {4:.3f}".format( epoch_count, lr_schedulers[0].last_iter, lr_schedulers[0].get_lr(), np.mean(nets_total_losses), np.mean(nets_ctc_losses)) if do_test_ensemble: ens_loss = np.mean(net_ens_losses) if ens_loss != 0: loss_mean_test_ensemble.append(ens_loss) status += "; loss_ens: {0:.3f}".format(ens_loss) status += "; num_ens_used {}".format(num_labels_used) else: loss_mean_test_ensemble.append(0) status += "; loss_ens: {}".format(0) iterator.set_description(status) for optimizer in optimizers: optimizer.step() if do_lr_step: for lr_scheduler in lr_schedulers: lr_scheduler.step() iter_count += 1 if output_dir is not None: for neti, net in enumerate(nets): torch.save( net.state_dict(), os.path.join(output_dir, "crnn_{}_".format(neti) + backend + "_last")) epoch_count += 1 return
def main(data_path, abc, seq_proj, backend, snapshot, input_size, base_lr, step_size, max_iter, batch_size, output_dir, test_epoch, test_init, gpu): os.environ["CUDA_VISIBLE_DEVICES"] = gpu cuda = True if gpu is not '' else False input_size = [int(x) for x in input_size.split('x')] transform = Compose([ Rotation(), Translation(), # Scale(), Contrast(), # Grid_distortion(), Resize(size=(input_size[0], input_size[1])) ]) seq_proj = [int(x) for x in seq_proj.split('x')] for fold_idx in range(24): train_mode = 'fold{0}_train'.format(fold_idx) val_mode = 'fold{0}_test'.format(fold_idx) if data_path is not None: data = TextDataset(data_path=data_path, mode=train_mode, transform=transform) else: data = TestDataset(transform=transform, abc=abc) net = load_model(data.get_abc(), seq_proj, backend, snapshot, cuda) optimizer = optim.Adam(net.parameters(), lr = base_lr, weight_decay=0.0001) lr_scheduler = StepLR(optimizer, step_size=step_size) # lr_scheduler = StepLR(optimizer, step_size=len(data)/batch_size*2) loss_function = CTCLoss() print(fold_idx) # continue acc_best = 0 epoch_count = 0 for epoch_idx in range(15): data_loader = DataLoader(data, batch_size=batch_size, num_workers=10, shuffle=True, collate_fn=text_collate) loss_mean = [] iterator = tqdm(data_loader) iter_count = 0 for sample in iterator: # for multi-gpu support if sample["img"].size(0) % len(gpu.split(',')) != 0: continue optimizer.zero_grad() imgs = Variable(sample["img"]) labels = Variable(sample["seq"]).view(-1) label_lens = Variable(sample["seq_len"].int()) if cuda: imgs = imgs.cuda() preds = net(imgs).cpu() pred_lens = Variable(Tensor([preds.size(0)] * batch_size).int()) loss = loss_function(preds, labels, pred_lens, label_lens) / batch_size loss.backward() # nn.utils.clip_grad_norm(net.parameters(), 10.0) loss_mean.append(loss.data[0]) status = "{}/{}; lr: {}; loss_mean: {}; loss: {}".format(epoch_count, lr_scheduler.last_iter, lr_scheduler.get_lr(), np.mean(loss_mean), loss.data[0]) iterator.set_description(status) optimizer.step() lr_scheduler.step() iter_count += 1 if True: logging.info("Test phase") net = net.eval() # train_acc, train_avg_ed, error_idx = test(net, data, data.get_abc(), cuda, visualize=False) # if acc > 0.95: # error_name = [data.config[data.mode][idx]["name"] for idx in error_idx] # logging.info('Train: '+','.join(error_name)) # logging.info("acc: {}\tacc_best: {}; avg_ed: {}\n\n".format(train_acc, train_avg_ed)) data.set_mode(val_mode) acc, avg_ed, error_idx = test(net, data, data.get_abc(), cuda, visualize=False) if acc > 0.95: error_name = [data.config[data.mode][idx]["name"] for idx in error_idx] logging.info('Val: '+','.join(error_name)) net = net.train() data.set_mode(train_mode) if acc > acc_best: if output_dir is not None: torch.save(net.state_dict(), os.path.join(output_dir, train_mode+"_crnn_" + backend + "_" + str(data.get_abc()) + "_best")) acc_best = acc if acc > 0.985: if output_dir is not None: torch.save(net.state_dict(), os.path.join(output_dir, train_mode+"_crnn_" + backend + "_" + str(data.get_abc()) + "_"+str(acc))) logging.info("train_acc: {}\t; avg_ed: {}\n\n".format(acc, acc_best, avg_ed)) epoch_count += 1
def main(data_path, abc, seq_proj, backend, snapshot, input_size, base_lr, step_size, max_iter, batch_size, output_dir, test_epoch, test_init, gpu): os.environ["CUDA_VISIBLE_DEVICES"] = gpu if not gpu == '': cuda = True else: cuda = False input_size = [int(x) for x in input_size.split('x')] transform = Compose([ Rotation(), # Translation(), # Scale(), Resize(size=(input_size[0], input_size[1])) ]) if data_path is not None: data = TextDataset(data_path=data_path, mode="train", transform=transform) else: data = TestDataset(transform=transform, abc=abc) seq_proj = [int(x) for x in seq_proj.split('x')] # print(data_path) # print(data[0]) # print(data.get_abc()) # exit() net = load_model(data.get_abc(), seq_proj, backend, snapshot, cuda) optimizer = optim.Adam(net.parameters(), lr=base_lr, weight_decay=0.0001) lr_scheduler = StepLR(optimizer, step_size=step_size, max_iter=max_iter) loss_function = CTCLoss() acc_best = 0 epoch_count = 0 while True: if (test_epoch is not None and epoch_count != 0 and epoch_count % test_epoch == 0) or (test_init and epoch_count == 0): print("Test phase") data.set_mode("test") net = net.eval() acc, avg_ed = test(net, data, data.get_abc(), cuda, visualize=False) net = net.train() data.set_mode("train") if acc > acc_best: if output_dir is not None: torch.save( net.state_dict(), os.path.join(output_dir, "crnn_" + backend + "_" + "_best")) acc_best = acc print("acc: {}\tacc_best: {}; avg_ed: {}".format( acc, acc_best, avg_ed)) data_loader = DataLoader(data, batch_size=batch_size, num_workers=1, shuffle=True, collate_fn=text_collate) loss_mean = [] iterator = tqdm(data_loader) iter_count = 0 for sample in iterator: # for multi-gpu support if sample["img"].size(0) % len(gpu.split(',')) != 0: continue optimizer.zero_grad() imgs = Variable(sample["img"]) labels = Variable(sample["seq"]).view(-1) label_lens = Variable(sample["seq_len"].int()) if cuda: imgs = imgs.cuda() preds = net(imgs).cpu() pred_lens = Variable(Tensor([preds.size(0)] * batch_size).int()) loss = loss_function(preds, labels, pred_lens, label_lens) / batch_size loss.backward() nn.utils.clip_grad_norm(net.parameters(), 10.0) loss_mean.append(loss.data[0]) status = "epoch: {}; iter: {}; lr: {}; loss_mean: {}; loss: {}".format( epoch_count, lr_scheduler.last_iter, lr_scheduler.get_lr(), np.mean(loss_mean), loss.data[0]) iterator.set_description(status) optimizer.step() lr_scheduler.step() iter_count += 1 if output_dir is not None: torch.save( net.state_dict(), os.path.join(output_dir, "crnn_" + backend + "_" + "_last")) epoch_count += 1 if epoch_count == 50: break return
def main(base_data_dir, train_data_path, train_base_dir, orig_eval_data_path, orig_eval_base_dir, synth_eval_data_path, synth_eval_base_dir, lexicon_path, seq_proj, backend, snapshot, input_height, base_lr, elastic_alpha, elastic_sigma, step_size, max_iter, batch_size, output_dir, test_iter, show_iter, test_init, use_gpu, use_no_font_repeat_data, do_vat, do_at, vat_ratio, test_vat_ratio, vat_epsilon, vat_ip, vat_xi, vat_sign, do_remove_augs, aug_to_remove, do_beam_search, dropout_conv, dropout_rnn, dropout_output, do_ema, do_gray, do_test_vat, do_test_entropy, do_test_vat_cnn, do_test_vat_rnn, ada_after_rnn, ada_before_rnn, do_ada_lr, ada_ratio, rnn_hidden_size, do_lr_step, dataset_name): if not do_lr_step and not do_ada_lr: raise NotImplementedError( 'learning rate should be either step or ada.') train_data_path = os.path.join(base_data_dir, train_data_path) train_base_dir = os.path.join(base_data_dir, train_base_dir) synth_eval_data_path = os.path.join(base_data_dir, synth_eval_data_path) synth_eval_base_dir = os.path.join(base_data_dir, synth_eval_base_dir) orig_eval_data_path = os.path.join(base_data_dir, orig_eval_data_path) orig_eval_base_dir = os.path.join(base_data_dir, orig_eval_base_dir) lexicon_path = os.path.join(base_data_dir, lexicon_path) all_parameters = locals() cuda = use_gpu #print(train_base_dir) if output_dir is not None: os.makedirs(output_dir, exist_ok=True) tb_writer = TbSummary(output_dir) output_dir = os.path.join(output_dir, 'model') os.makedirs(output_dir, exist_ok=True) with open(lexicon_path, 'rb') as f: lexicon = pkl.load(f) #print(sorted(lexicon.items(), key=operator.itemgetter(1))) with open(os.path.join(output_dir, 'params.txt'), 'w') as f: f.writelines(str(all_parameters)) print(all_parameters) print('new vat') sin_magnitude = 4 rotate_max_angle = 2 dataset_info = SynthDataInfo(None, None, None, dataset_name.lower()) train_fonts = dataset_info.font_names all_args = locals() allowed_removals = [ 'elastic', 'sine', 'sine_rotate', 'rotation', 'color_aug', 'color_gaus', 'color_sine' ] if do_remove_augs and aug_to_remove not in allowed_removals: raise Exception('augmentation removal value is not allowed.') if do_remove_augs: rand_trans = [] if aug_to_remove == 'elastic': print('doing sine transform :)') rand_trans.append(OnlySine(sin_magnitude=sin_magnitude)) elif aug_to_remove in ['sine', 'sine_rotate']: print('doing elastic transform :)') rand_trans.append( OnlyElastic(elastic_alpha=elastic_alpha, elastic_sigma=elastic_sigma)) if aug_to_remove not in ['elastic', 'sine', 'sine_rotate']: print('doing elastic transform :)') print('doing sine transform :)') rand_trans.append( ElasticAndSine(elastic_alpha=elastic_alpha, elastic_sigma=elastic_sigma, sin_magnitude=sin_magnitude)) if aug_to_remove not in ['rotation', 'sine_rotate']: print('doing rotation transform :)') rand_trans.append(Rotation(angle=rotate_max_angle, fill_value=255)) if aug_to_remove not in ['color_aug', 'color_gaus', 'color_sine']: print('doing color_aug transform :)') rand_trans.append(ColorGradGausNoise()) elif aug_to_remove == 'color_gaus': print('doing color_sine transform :)') rand_trans.append(ColorGrad()) elif aug_to_remove == 'color_sine': print('doing color_gaus transform :)') rand_trans.append(ColorGausNoise()) else: print('doing all transforms :)') rand_trans = [ ElasticAndSine(elastic_alpha=elastic_alpha, elastic_sigma=elastic_sigma, sin_magnitude=sin_magnitude), Rotation(angle=rotate_max_angle, fill_value=255), ColorGradGausNoise() ] if do_gray: rand_trans = rand_trans + [ Resize(hight=input_height), AddWidth(), ToGray(), Normalize() ] else: rand_trans = rand_trans + [ Resize(hight=input_height), AddWidth(), Normalize() ] transform_random = Compose(rand_trans) if do_gray: transform_simple = Compose( [Resize(hight=input_height), AddWidth(), ToGray(), Normalize()]) else: transform_simple = Compose( [Resize(hight=input_height), AddWidth(), Normalize()]) if use_no_font_repeat_data: print('creating dataset') train_data = TextDatasetRandomFont(data_path=train_data_path, lexicon=lexicon, base_path=train_base_dir, transform=transform_random, fonts=train_fonts) print('finished creating dataset') else: print('train data path:\n{}'.format(train_data_path)) print('train_base_dir:\n{}'.format(train_base_dir)) train_data = TextDataset(data_path=train_data_path, lexicon=lexicon, base_path=train_base_dir, transform=transform_random, fonts=train_fonts) synth_eval_data = TextDataset(data_path=synth_eval_data_path, lexicon=lexicon, base_path=synth_eval_base_dir, transform=transform_random, fonts=train_fonts) orig_eval_data = TextDataset(data_path=orig_eval_data_path, lexicon=lexicon, base_path=orig_eval_base_dir, transform=transform_simple, fonts=None) if do_test_vat or do_test_vat_rnn or do_test_vat_cnn: orig_vat_data = TextDataset(data_path=orig_eval_data_path, lexicon=lexicon, base_path=orig_eval_base_dir, transform=transform_simple, fonts=None) if ada_after_rnn or ada_before_rnn: orig_ada_data = TextDataset(data_path=orig_eval_data_path, lexicon=lexicon, base_path=orig_eval_base_dir, transform=transform_simple, fonts=None) #else: # train_data = TestDataset(transform=transform, abc=abc).set_mode("train") # synth_eval_data = TestDataset(transform=transform, abc=abc).set_mode("test") # orig_eval_data = TestDataset(transform=transform, abc=abc).set_mode("test") seq_proj = [int(x) for x in seq_proj.split('x')] net = load_model(lexicon=train_data.get_lexicon(), seq_proj=seq_proj, backend=backend, snapshot=snapshot, cuda=cuda, do_beam_search=do_beam_search, dropout_conv=dropout_conv, dropout_rnn=dropout_rnn, dropout_output=dropout_output, do_ema=do_ema, ada_after_rnn=ada_after_rnn, ada_before_rnn=ada_before_rnn, rnn_hidden_size=rnn_hidden_size) optimizer = optim.Adam(net.parameters(), lr=base_lr, weight_decay=0.0001) if do_ada_lr: print('using ada lr') lr_scheduler = DannLR(optimizer, max_iter=max_iter) elif do_lr_step: print('using step lr') lr_scheduler = StepLR(optimizer, step_size=step_size, max_iter=max_iter) loss_function = CTCLoss() synth_avg_ed_best = float("inf") orig_avg_ed_best = float("inf") epoch_count = 0 if do_test_vat or do_test_vat_rnn or do_test_vat_cnn: collate_vat = lambda x: text_collate(x, do_mask=True) vat_load = DataLoader(orig_vat_data, batch_size=batch_size, num_workers=4, shuffle=True, collate_fn=collate_vat) vat_len = len(vat_load) cur_vat = 0 vat_iter = iter(vat_load) if ada_after_rnn or ada_before_rnn: collate_ada = lambda x: text_collate(x, do_mask=True) ada_load = DataLoader(orig_ada_data, batch_size=batch_size, num_workers=4, shuffle=True, collate_fn=collate_ada) ada_len = len(ada_load) cur_ada = 0 ada_iter = iter(ada_load) loss_domain = torch.nn.NLLLoss() while True: collate = lambda x: text_collate( x, do_mask=(do_vat or ada_before_rnn or ada_after_rnn)) data_loader = DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True, collate_fn=collate) loss_mean_ctc = [] loss_mean_vat = [] loss_mean_at = [] loss_mean_comp = [] loss_mean_total = [] loss_mean_test_vat = [] loss_mean_test_pseudo = [] loss_mean_test_rand = [] loss_mean_ada_rnn_s = [] loss_mean_ada_rnn_t = [] loss_mean_ada_cnn_s = [] loss_mean_ada_cnn_t = [] iterator = tqdm(data_loader) iter_count = 0 for iter_num, sample in enumerate(iterator): total_iter = (epoch_count * len(data_loader)) + iter_num if ((total_iter > 1) and total_iter % test_iter == 0) or (test_init and total_iter == 0): # epoch_count != 0 and print("Test phase") net = net.eval() if do_ema: net.start_test() synth_acc, synth_avg_ed, synth_avg_no_stop_ed, synth_avg_loss = test( net, synth_eval_data, synth_eval_data.get_lexicon(), cuda, visualize=False, dataset_info=dataset_info, batch_size=batch_size, tb_writer=tb_writer, n_iter=total_iter, initial_title='val_synth', loss_function=loss_function, output_path=os.path.join(output_dir, 'results'), do_beam_search=False) orig_acc, orig_avg_ed, orig_avg_no_stop_ed, orig_avg_loss = test( net, orig_eval_data, orig_eval_data.get_lexicon(), cuda, visualize=False, dataset_info=dataset_info, batch_size=batch_size, tb_writer=tb_writer, n_iter=total_iter, initial_title='test_orig', loss_function=loss_function, output_path=os.path.join(output_dir, 'results'), do_beam_search=do_beam_search) net = net.train() #save periodic if output_dir is not None and total_iter // 30000: periodic_save = os.path.join(output_dir, 'periodic_save') os.makedirs(periodic_save, exist_ok=True) old_save = glob.glob(os.path.join(periodic_save, '*')) torch.save( net.state_dict(), os.path.join(output_dir, "crnn_" + backend + "_" + str(total_iter))) if orig_avg_no_stop_ed < orig_avg_ed_best: orig_avg_ed_best = orig_avg_no_stop_ed if output_dir is not None: torch.save( net.state_dict(), os.path.join(output_dir, "crnn_" + backend + "_best")) if synth_avg_no_stop_ed < synth_avg_ed_best: synth_avg_ed_best = synth_avg_no_stop_ed if do_ema: net.end_test() print( "synth: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}" .format(synth_avg_ed_best, synth_avg_ed, synth_avg_no_stop_ed, synth_acc)) print( "orig: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}" .format(orig_avg_ed_best, orig_avg_ed, orig_avg_no_stop_ed, orig_acc)) tb_writer.get_writer().add_scalars( 'data/test', { 'synth_ed_total': synth_avg_ed, 'synth_ed_no_stop': synth_avg_no_stop_ed, 'synth_avg_loss': synth_avg_loss, 'orig_ed_total': orig_avg_ed, 'orig_ed_no_stop': orig_avg_no_stop_ed, 'orig_avg_loss': orig_avg_loss }, total_iter) if len(loss_mean_ctc) > 0: train_dict = {'mean_ctc_loss': np.mean(loss_mean_ctc)} if do_vat: train_dict = { **train_dict, **{ 'mean_vat_loss': np.mean(loss_mean_vat) } } if do_at: train_dict = { **train_dict, **{ 'mean_at_loss': np.mean(loss_mean_at) } } if do_test_vat: train_dict = { **train_dict, **{ 'mean_test_vat_loss': np.mean(loss_mean_test_vat) } } if do_test_vat_rnn and do_test_vat_cnn: train_dict = { **train_dict, **{ 'mean_test_vat_crnn_loss': np.mean(loss_mean_test_vat) } } elif do_test_vat_rnn: train_dict = { **train_dict, **{ 'mean_test_vat_rnn_loss': np.mean(loss_mean_test_vat) } } elif do_test_vat_cnn: train_dict = { **train_dict, **{ 'mean_test_vat_cnn_loss': np.mean(loss_mean_test_vat) } } if ada_after_rnn: train_dict = { **train_dict, **{ 'mean_ada_rnn_s_loss': np.mean(loss_mean_ada_rnn_s), 'mean_ada_rnn_t_loss': np.mean(loss_mean_ada_rnn_t) } } if ada_before_rnn: train_dict = { **train_dict, **{ 'mean_ada_cnn_s_loss': np.mean(loss_mean_ada_cnn_s), 'mean_ada_cnn_t_loss': np.mean(loss_mean_ada_cnn_t) } } print(train_dict) tb_writer.get_writer().add_scalars('data/train', train_dict, total_iter) ''' # for multi-gpu support if sample["img"].size(0) % len(gpu.split(',')) != 0: continue ''' optimizer.zero_grad() imgs = Variable(sample["img"]) #print("images sizes are:") #print(sample["img"].shape) if do_vat or ada_after_rnn or ada_before_rnn: mask = sample['mask'] labels_flatten = Variable(sample["seq"]).view(-1) label_lens = Variable(sample["seq_len"].int()) #print("image sequence length is:") #print(sample["im_seq_len"]) #print("label sequence length is:") #print(sample["seq_len"].view(1,-1)) img_seq_lens = sample["im_seq_len"] if cuda: imgs = imgs.cuda() if do_vat or ada_after_rnn or ada_before_rnn: mask = mask.cuda() if do_ada_lr: ada_p = float(iter_count) / max_iter lr_scheduler.update(ada_p) if ada_before_rnn or ada_after_rnn: if not do_ada_lr: ada_p = float(iter_count) / max_iter ada_alpha = 2. / (1. + np.exp(-10. * ada_p)) - 1 if cur_ada >= ada_len: ada_load = DataLoader(orig_ada_data, batch_size=batch_size, num_workers=4, shuffle=True, collate_fn=collate_ada) ada_len = len(ada_load) cur_ada = 0 ada_iter = iter(ada_load) ada_batch = next(ada_iter) cur_ada += 1 ada_imgs = Variable(ada_batch["img"]) ada_img_seq_lens = ada_batch["im_seq_len"] ada_mask = ada_batch['mask'].byte() if cuda: ada_imgs = ada_imgs.cuda() _, ada_cnn, ada_rnn = net(ada_imgs, ada_img_seq_lens, ada_alpha=ada_alpha, mask=ada_mask) if ada_before_rnn: ada_num_features = ada_cnn.size(0) else: ada_num_features = ada_rnn.size(0) domain_label = torch.zeros(ada_num_features) domain_label = domain_label.long() if cuda: domain_label = domain_label.cuda() domain_label = Variable(domain_label) if ada_before_rnn: err_ada_cnn_t = loss_domain(ada_cnn, domain_label) if ada_after_rnn: err_ada_rnn_t = loss_domain(ada_rnn, domain_label) if do_test_vat and do_at: # test part! if cur_vat >= vat_len: vat_load = DataLoader(orig_vat_data, batch_size=batch_size, num_workers=4, shuffle=True, collate_fn=collate_vat) vat_len = len(vat_load) cur_vat = 0 vat_iter = iter(vat_load) test_vat_batch = next(vat_iter) cur_vat += 1 test_vat_mask = test_vat_batch['mask'] test_vat_imgs = Variable(test_vat_batch["img"]) test_vat_img_seq_lens = test_vat_batch["im_seq_len"] if cuda: test_vat_imgs = test_vat_imgs.cuda() test_vat_mask = test_vat_mask.cuda() # train part at_test_vat_loss = LabeledAtAndUnlabeledTestVatLoss( xi=vat_xi, eps=vat_epsilon, ip=vat_ip) at_loss, test_vat_loss = at_test_vat_loss( model=net, train_x=imgs, train_labels_flatten=labels_flatten, train_img_seq_lens=img_seq_lens, train_label_lens=label_lens, batch_size=batch_size, test_x=test_vat_imgs, test_seq_len=test_vat_img_seq_lens, test_mask=test_vat_mask) elif do_test_vat or do_test_vat_rnn or do_test_vat_cnn: if cur_vat >= vat_len: vat_load = DataLoader(orig_vat_data, batch_size=batch_size, num_workers=4, shuffle=True, collate_fn=collate_vat) vat_len = len(vat_load) cur_vat = 0 vat_iter = iter(vat_load) vat_batch = next(vat_iter) cur_vat += 1 vat_mask = vat_batch['mask'] vat_imgs = Variable(vat_batch["img"]) vat_img_seq_lens = vat_batch["im_seq_len"] if cuda: vat_imgs = vat_imgs.cuda() vat_mask = vat_mask.cuda() if do_test_vat: if do_test_vat_rnn or do_test_vat_cnn: raise "can only do one of do_test_vat | (do_test_vat_rnn, do_test_vat_cnn)" if vat_sign == True: test_vat_loss = VATLossSign( do_test_entropy=do_test_entropy, xi=vat_xi, eps=vat_epsilon, ip=vat_ip) else: test_vat_loss = VATLoss(xi=vat_xi, eps=vat_epsilon, ip=vat_ip) elif do_test_vat_rnn and do_test_vat_cnn: test_vat_loss = VATonRnnCnnSign(xi=vat_xi, eps=vat_epsilon, ip=vat_ip) elif do_test_vat_rnn: test_vat_loss = VATonRnnSign(xi=vat_xi, eps=vat_epsilon, ip=vat_ip) elif do_test_vat_cnn: test_vat_loss = VATonCnnSign(xi=vat_xi, eps=vat_epsilon, ip=vat_ip) if do_test_vat_cnn and do_test_vat_rnn: test_vat_loss, cnn_lds, rnn_lds = test_vat_loss( net, vat_imgs, vat_img_seq_lens, vat_mask) elif do_test_vat: test_vat_loss = test_vat_loss(net, vat_imgs, vat_img_seq_lens, vat_mask) elif do_vat: vat_loss = VATLoss(xi=vat_xi, eps=vat_epsilon, ip=vat_ip) vat_loss = vat_loss(net, imgs, img_seq_lens, mask) elif do_at: at_loss = LabeledATLoss(xi=vat_xi, eps=vat_epsilon, ip=vat_ip) at_loss = at_loss(net, imgs, labels_flatten, img_seq_lens, label_lens, batch_size) if ada_after_rnn or ada_before_rnn: preds, ada_cnn, ada_rnn = net(imgs, img_seq_lens, ada_alpha=ada_alpha, mask=mask) if ada_before_rnn: ada_num_features = ada_cnn.size(0) else: ada_num_features = ada_rnn.size(0) domain_label = torch.ones(ada_num_features) domain_label = domain_label.long() if cuda: domain_label = domain_label.cuda() domain_label = Variable(domain_label) if ada_before_rnn: err_ada_cnn_s = loss_domain(ada_cnn, domain_label) if ada_after_rnn: err_ada_rnn_s = loss_domain(ada_rnn, domain_label) else: preds = net(imgs, img_seq_lens) ''' if output_dir is not None: if (show_iter is not None and iter_num != 0 and iter_num % show_iter == 0): print_data_visuals(net, tb_writer, train_data.get_lexicon(), sample["img"], labels_flatten, label_lens, preds, ((epoch_count * len(data_loader)) + iter_num)) ''' loss_ctc = loss_function( preds, labels_flatten, Variable(torch.IntTensor(np.array(img_seq_lens))), label_lens) / batch_size if loss_ctc.data[0] in [float("inf"), -float("inf")]: print("warnning: loss should not be inf.") continue total_loss = loss_ctc if do_vat: #mask = sample['mask'] #if cuda: # mask = mask.cuda() #vat_loss = virtual_adversarial_loss(net, imgs, img_seq_lens, mask, is_training=True, do_entropy=False, epsilon=vat_epsilon, num_power_iterations=1, # xi=1e-6, average_loss=True) total_loss = total_loss + vat_ratio * vat_loss.cpu() if do_test_vat or do_test_vat_rnn or do_test_vat_cnn: total_loss = total_loss + test_vat_ratio * test_vat_loss.cpu() if ada_before_rnn: total_loss = total_loss + ada_ratio * err_ada_cnn_s.cpu( ) + ada_ratio * err_ada_cnn_t.cpu() if ada_after_rnn: total_loss = total_loss + ada_ratio * err_ada_rnn_s.cpu( ) + ada_ratio * err_ada_rnn_t.cpu() total_loss.backward() nn.utils.clip_grad_norm(net.parameters(), 10.0) if -400 < loss_ctc.data[0] < 400: loss_mean_ctc.append(loss_ctc.data[0]) if -1000 < total_loss.data[0] < 1000: loss_mean_total.append(total_loss.data[0]) if len(loss_mean_total) > 100: loss_mean_total = loss_mean_total[-100:] status = "epoch: {0:5d}; iter_num: {1:5d}; lr: {2:.2E}; loss_mean: {3:.3f}; loss: {4:.3f}".format( epoch_count, lr_scheduler.last_iter, lr_scheduler.get_lr(), np.mean(loss_mean_total), loss_ctc.data[0]) if ada_after_rnn: loss_mean_ada_rnn_s.append(err_ada_rnn_s.data[0]) loss_mean_ada_rnn_t.append(err_ada_rnn_t.data[0]) status += "; ladatrnns: {0:.3f}; ladatrnnt: {1:.3f}".format( err_ada_rnn_s.data[0], err_ada_rnn_t.data[0]) if ada_before_rnn: loss_mean_ada_cnn_s.append(err_ada_cnn_s.data[0]) loss_mean_ada_cnn_t.append(err_ada_cnn_t.data[0]) status += "; ladatcnns: {0:.3f}; ladatcnnt: {1:.3f}".format( err_ada_cnn_s.data[0], err_ada_cnn_t.data[0]) if do_vat: loss_mean_vat.append(vat_loss.data[0]) status += "; lvat: {0:.3f}".format(vat_loss.data[0]) if do_at: loss_mean_at.append(at_loss.data[0]) status += "; lat: {0:.3f}".format(at_loss.data[0]) if do_test_vat: loss_mean_test_vat.append(test_vat_loss.data[0]) status += "; l_tvat: {0:.3f}".format(test_vat_loss.data[0]) if do_test_vat_rnn or do_test_vat_cnn: loss_mean_test_vat.append(test_vat_loss.data[0]) if do_test_vat_rnn and do_test_vat_cnn: status += "; l_tvatc: {}".format(cnn_lds.data[0]) status += "; l_tvatr: {}".format(rnn_lds.data[0]) else: status += "; l_tvat: {}".format(test_vat_loss.data[0]) iterator.set_description(status) optimizer.step() if do_lr_step: lr_scheduler.step() if do_ema: net.udate_ema() iter_count += 1 if output_dir is not None: torch.save(net.state_dict(), os.path.join(output_dir, "crnn_" + backend + "_last")) epoch_count += 1 return