Example #1
0
def main(data_path, abc, seq_proj, backend, snapshot, input_size, gpu, visualize,
         submission_path='tmp_rcnn_tta10.csv',
         test_images_glob='../../input/public_test_data/*'):
    """Run test-time-augmented inference and write a submission CSV.

    Args:
        data_path: data path for ``TextDataset`` (opened in ``"pb"`` mode);
            when ``None`` a ``TestDataset`` built from ``abc`` is used instead.
        abc: alphabet for the fallback ``TestDataset``.
        seq_proj: ``"AxB"`` string of two integers forwarded to ``load_model``.
        backend: backbone name forwarded to ``load_model``.
        snapshot: checkpoint forwarded to ``load_model``.
        input_size: ``"AxB"`` string of two integers passed to ``Resize(size=...)``.
        gpu: value exported as ``CUDA_VISIBLE_DEVICES``; ``''`` disables CUDA.
        visualize: forwarded to ``test_tta``.
        submission_path: output CSV path (defaults to the previous hard-coded name).
        test_images_glob: glob whose file basenames form the ``name`` column
            (defaults to the previous hard-coded pattern).

    Raises:
        ValueError: if the number of matched test images differs from the
            number of predictions returned by ``test_tta``.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # Compare by value: `is not ''` tests object identity, which is not
    # guaranteed for equal strings (and is a SyntaxWarning since Python 3.8).
    cuda = gpu != ''

    input_size = [int(x) for x in input_size.split('x')]
    transform = Compose([
        Rotation(),
        Translation(),
        # Scale(),
        Contrast(),
        Grid_distortion(),
        Resize(size=(input_size[0], input_size[1]))
    ])
    if data_path is not None:
        data = TextDataset(data_path=data_path, mode="pb", transform=transform)
    else:
        data = TestDataset(transform=transform, abc=abc)
    seq_proj = [int(x) for x in seq_proj.split('x')]
    net = load_model(data.get_abc(), seq_proj, backend, snapshot, cuda).eval()
    acc, avg_ed, pred_pb = test_tta(net, data, data.get_abc(), cuda, visualize)

    # NOTE(review): glob order is filesystem-dependent; this assumes it matches
    # the order in which `data` yields predictions -- confirm.
    names = [os.path.basename(p) for p in glob.glob(test_images_glob)]
    if len(names) != len(pred_pb):
        raise ValueError(
            "Found {} test images but got {} predictions".format(
                len(names), len(pred_pb)))
    df_submit = pd.DataFrame({'name': names, 'label': pred_pb})

    df_submit.to_csv(submission_path, index=False)
    print("Accuracy: {}".format(acc))
    print("Edit distance: {}".format(avg_ed))
Example #2
0
def main(data_path, abc, seq_proj, backend, snapshot, input_size, gpu, visualize):
    """Evaluate a CRNN snapshot and print accuracy and edit distance.

    Args:
        data_path: data path for ``TextDataset`` (opened in ``"test"`` mode);
            when ``None`` a ``TestDataset`` built from ``abc`` is used instead.
        abc: alphabet for the fallback ``TestDataset``.
        seq_proj: ``"AxB"`` string of two integers forwarded to ``load_model``.
        backend: backbone name forwarded to ``load_model``.
        snapshot: checkpoint forwarded to ``load_model``.
        input_size: ``"AxB"`` string of two integers passed to ``Resize(size=...)``.
        gpu: value exported as ``CUDA_VISIBLE_DEVICES``; ``''`` disables CUDA.
        visualize: forwarded to ``test``.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # Compare by value: `is not ''` tests object identity, which is not
    # guaranteed for equal strings (and is a SyntaxWarning since Python 3.8).
    cuda = gpu != ''

    input_size = [int(x) for x in input_size.split('x')]
    transform = Compose([
        Rotation(),
        Resize(size=(input_size[0], input_size[1]))
    ])
    if data_path is not None:
        data = TextDataset(data_path=data_path, mode="test", transform=transform)
    else:
        data = TestDataset(transform=transform, abc=abc)
    seq_proj = [int(x) for x in seq_proj.split('x')]
    net = load_model(data.get_abc(), seq_proj, backend, snapshot, cuda).eval()
    acc, avg_ed = test(net, data, data.get_abc(), cuda, visualize)
    print("Accuracy: {}".format(acc))
    print("Edit distance: {}".format(avg_ed))
Example #3
0
def main(data_path, base_data_dir, lexicon_path, output_path, seq_proj,
         backend, snapshot, input_height, visualize, do_beam_search,
         dataset_name):
    """Evaluate a lexicon-based CRNN snapshot on a single dataset.

    The pickled lexicon at ``lexicon_path`` is loaded and its items are printed
    sorted by value. Images go through ``Resize(hight=input_height)``,
    ``AddWidth()`` and ``Normalize()``. The model restored from ``snapshot`` is
    run through ``test`` with batch size 1 (CUDA always enabled), writing
    results to ``output_path``. Accuracy and both edit-distance figures are
    printed.
    """
    use_cuda = True

    # NOTE(review): pickle can execute arbitrary code; only load trusted files.
    with open(lexicon_path, 'rb') as lexicon_file:
        lexicon = pkl.load(lexicon_file)
    print(sorted(lexicon.items(), key=operator.itemgetter(1)))

    dataset = TextDataset(
        data_path=data_path,
        lexicon=lexicon,
        base_path=base_data_dir,
        transform=Compose([Resize(hight=input_height), AddWidth(), Normalize()]),
        fonts=None,
    )
    info = SynthDataInfo(None, None, None, dataset_name.lower())

    projection = [int(part) for part in seq_proj.split('x')]
    model = load_model(
        lexicon=dataset.get_lexicon(),
        seq_proj=projection,
        backend=backend,
        snapshot=snapshot,
        cuda=use_cuda,
        do_beam_search=do_beam_search,
    ).eval()

    acc, avg_ed, avg_no_stop_ed = test(
        model,
        dataset,
        dataset.get_lexicon(),
        use_cuda,
        visualize=visualize,
        dataset_info=info,
        batch_size=1,
        tb_writer=None,
        n_iter=0,
        initial_title='val_orig',
        loss_function=None,
        is_trian=False,  # keyword spelled exactly as in the original call
        output_path=output_path,
        do_beam_search=do_beam_search,
        do_results=True,
    )

    print("Accuracy: {}".format(acc))
    print("Edit distance: {}".format(avg_ed))
    print("Edit distance without stop signs: {}".format(avg_no_stop_ed))
Example #4
0
def main():
    """Evaluate ``config.snapshot`` on the ``config.test_mode`` split.

    Settings are read from the ``config`` object in scope: ``input_size``
    (an ``"AxB"`` string of two integers), ``test_path``, ``test_mode``,
    ``backend``, ``snapshot``, ``batch_size``, ``output_csv`` and
    ``output_image``. Prints accuracy and average edit distance.

    Raises:
        RuntimeError: if the dataset is not in ``config.test_mode`` after
            construction.
    """
    # Parsed once; the same list feeds the Resize transform and load_model
    # (the string was previously parsed twice).
    input_size = [int(x) for x in config.input_size.split('x')]
    transform = Compose([
        Rotation(),
        Resize(size=(input_size[0], input_size[1]))
    ])
    data = TextDataset(data_path=config.test_path, mode=config.test_mode,
                       transform=transform)

    net = load_model(input_size, data.get_abc(), None, config.backend,
                     config.snapshot).eval()

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if data.mode != config.test_mode:
        raise RuntimeError(
            "dataset mode {!r} does not match config.test_mode {!r}".format(
                data.mode, config.test_mode))

    acc, avg_ed = test(net, data, data.get_abc(), visualize=True,
                       batch_size=config.batch_size, num_workers=0,
                       output_csv=config.output_csv,
                       output_image=config.output_image)

    print("Accuracy: {}".format(acc))
    print("Edit distance: {}".format(avg_ed))
Example #5
0
def _build_transforms(input_height, elastic_alpha, elastic_sigma, do_gray,
                      sin_magnitude=4, rotate_max_angle=2):
    """Build the ``(transform_random, transform_simple)`` pipelines.

    Both pipelines end with ``Resize(hight=input_height)``, ``AddWidth()``,
    ``ToGray()`` (only when ``do_gray``) and ``Normalize()``. The random
    pipeline first applies ``ElasticAndSine``, ``Rotation`` (angle
    ``rotate_max_angle``, fill value 255) and ``ColorGradGausNoise``.
    """
    def _tail():
        # Fresh step instances for each pipeline; the two never share objects.
        steps = [Resize(hight=input_height), AddWidth()]
        if do_gray:
            steps.append(ToGray())
        steps.append(Normalize())
        return steps

    random_steps = [
        ElasticAndSine(elastic_alpha=elastic_alpha,
                       elastic_sigma=elastic_sigma,
                       sin_magnitude=sin_magnitude),
        Rotation(angle=rotate_max_angle, fill_value=255),
        ColorGradGausNoise(),
    ]
    transform_random = Compose(random_steps + _tail())
    transform_simple = Compose(_tail())
    return transform_random, transform_simple


def _ensemble_pseudo_label_losses(nets, vat_batch, nll_loss,
                                  test_ensemble_thresh, cuda):
    """Compute each net's loss against the ensemble's majority-vote labels.

    Every net runs ``vat_forward`` on ``vat_batch["img"]``; the output is
    multiplied by ``vat_batch["mask"]``, softmaxed over dim 2 and flattened to
    ``(-1, last_dim)``. For each flattened position the most common argmax
    class across nets (``stats.mode``) becomes a pseudo-label, kept only where
    its vote count exceeds ``test_ensemble_thresh``. The batch's own labels
    are not used.

    Returns:
        ``(net_ens_losses, num_labels_used)``: one loss per net (moved with
        ``.cpu()``) and the number of pseudo-labelled positions.
    """
    vat_mask = vat_batch['mask']
    vat_imgs = Variable(vat_batch["img"])
    vat_img_seq_lens = vat_batch["im_seq_len"]

    all_net_preds = []
    all_net_classes = []
    for neti, net in enumerate(nets):
        if cuda:
            vat_imgs = vat_imgs.cuda(neti)
            vat_mask = vat_mask.cuda(neti)
        vat_pred = net.vat_forward(vat_imgs, vat_img_seq_lens) * vat_mask
        last_dim = vat_pred.size()[-1]
        vat_pred = F.softmax(vat_pred, dim=2).view(-1, last_dim)
        all_net_preds.append(vat_pred)
        all_net_classes.append(np.argmax(vat_pred.cpu().data.numpy(), axis=1))

    voted_classes, vote_counts = stats.mode(np.stack(all_net_classes), axis=0)
    voted_classes = voted_classes.reshape(-1)
    vote_counts = vote_counts.reshape(-1)
    agreed = vote_counts > test_ensemble_thresh
    ens_indices = np.argwhere(agreed).reshape(-1)
    ens_classes = voted_classes[agreed]

    net_ens_losses = []
    for neti in range(len(nets)):
        indices = torch.from_numpy(ens_indices)
        labels = torch.from_numpy(ens_classes)
        # Only move to a GPU when CUDA is enabled (this was unconditional
        # before and crashed CPU-only runs).
        if cuda:
            indices = indices.cuda(neti)
            labels = labels.cuda(neti)
        # NOTE(review): NLLLoss expects log-probabilities, but these are
        # softmax probabilities, so this optimises -p rather than -log p.
        # Kept as-is because changing it alters the training objective.
        preds_to_ens = all_net_preds[neti][Variable(indices)]
        loss = nll_loss(preds_to_ens, Variable(labels))
        net_ens_losses.append(loss.cpu())
    return net_ens_losses, len(ens_indices)


def main(base_data_dir, train_data_path, train_base_dir, orig_eval_data_path,
         orig_eval_base_dir, synth_eval_data_path, synth_eval_base_dir,
         lexicon_path, seq_proj, backend, snapshot, input_height, base_lr,
         elastic_alpha, elastic_sigma, step_size, max_iter, batch_size,
         output_dir, test_iter, show_iter, test_init, use_gpu,
         use_no_font_repeat_data, do_vat, do_at, vat_ratio, test_vat_ratio,
         vat_epsilon, vat_ip, vat_xi, vat_sign, do_comp, comp_ratio,
         do_remove_augs, aug_to_remove, do_beam_search, dropout_conv,
         dropout_rnn, dropout_output, do_ema, do_gray, do_test_vat,
         do_test_entropy, do_test_vat_cnn, do_test_vat_rnn, do_test_rand,
         ada_after_rnn, ada_before_rnn, do_ada_lr, ada_ratio, rnn_hidden_size,
         do_test_pseudo, test_pseudo_ratio, test_pseudo_thresh, do_lr_step,
         do_test_ensemble, test_ensemble_ratio, test_ensemble_thresh):
    """Train an ensemble of 4 CRNN models with CTC loss.

    All data paths and ``lexicon_path`` are resolved relative to
    ``base_data_dir``. Net ``i`` is loaded with ``gpu=i`` and gets its own
    Adam optimizer and ``StepLR`` scheduler. Evaluation on the synthetic and
    original evaluation sets runs every ``test_iter`` iterations (with
    ``total_iter > 1``), and at iteration 0 when ``test_init``. Metrics go to
    TensorBoard and a ``crnn_<i>_<backend>_iter_<n>`` checkpoint per net is
    written to ``<output_dir>/model``. A ``crnn_<i>_<backend>_last`` checkpoint
    is written after every epoch. The loop never terminates on its own.

    When ``do_test_ensemble`` is set, each step adds ``test_ensemble_ratio``
    times a pseudo-label loss on images from the original evaluation set (see
    ``_ensemble_pseudo_label_losses``).

    All arguments (after path resolution) are written to
    ``<output_dir>/model/params.txt``. Several arguments (e.g. ``do_at``,
    ``vat_ratio``, ``show_iter``) are not otherwise referenced here.

    Raises:
        ValueError: if ``output_dir`` is ``None``; parameters, checkpoints
            and TensorBoard summaries are all written under it.
    """
    # Fail fast: previously a None output_dir crashed later with a TypeError
    # from os.path.join and an undefined tb_writer.
    if output_dir is None:
        raise ValueError("output_dir is required")
    num_nets = 4

    train_data_path = os.path.join(base_data_dir, train_data_path)
    train_base_dir = os.path.join(base_data_dir, train_base_dir)
    synth_eval_data_path = os.path.join(base_data_dir, synth_eval_data_path)
    synth_eval_base_dir = os.path.join(base_data_dir, synth_eval_base_dir)

    orig_eval_data_path = os.path.join(base_data_dir, orig_eval_data_path)
    orig_eval_base_dir = os.path.join(base_data_dir, orig_eval_base_dir)
    lexicon_path = os.path.join(base_data_dir, lexicon_path)

    # Snapshot of the arguments (plus num_nets) for params.txt; taken before
    # any other local is created.
    all_parameters = locals()
    cuda = use_gpu

    os.makedirs(output_dir, exist_ok=True)
    tb_writer = TbSummary(output_dir)
    output_dir = os.path.join(output_dir, 'model')
    os.makedirs(output_dir, exist_ok=True)

    # NOTE(review): pickle can execute arbitrary code; load trusted files only.
    with open(lexicon_path, 'rb') as f:
        lexicon = pkl.load(f)

    with open(os.path.join(output_dir, 'params.txt'), 'w') as f:
        f.write(str(all_parameters))
    print(all_parameters)
    print('new vat')

    train_fonts = [
        'Qomolangma-Betsu', 'Shangshung Sgoba-KhraChen',
        'Shangshung Sgoba-KhraChung', 'Qomolangma-Drutsa'
    ]

    print('doing all transforms :)')
    transform_random, transform_simple = _build_transforms(
        input_height, elastic_alpha, elastic_sigma, do_gray)

    if use_no_font_repeat_data:
        print('create dataset')
        train_data = TextDatasetRandomFont(data_path=train_data_path,
                                           lexicon=lexicon,
                                           base_path=train_base_dir,
                                           transform=transform_random,
                                           fonts=train_fonts)
        print('finished creating dataset')
    else:
        print('train data path:\n{}'.format(train_data_path))
        print('train_base_dir:\n{}'.format(train_base_dir))
        train_data = TextDataset(data_path=train_data_path,
                                 lexicon=lexicon,
                                 base_path=train_base_dir,
                                 transform=transform_random,
                                 fonts=train_fonts)
    synth_eval_data = TextDataset(data_path=synth_eval_data_path,
                                  lexicon=lexicon,
                                  base_path=synth_eval_base_dir,
                                  transform=transform_random,
                                  fonts=train_fonts)
    orig_eval_data = TextDataset(data_path=orig_eval_data_path,
                                 lexicon=lexicon,
                                 base_path=orig_eval_base_dir,
                                 transform=transform_simple,
                                 fonts=None)
    if do_test_ensemble:
        orig_vat_data = TextDataset(data_path=orig_eval_data_path,
                                    lexicon=lexicon,
                                    base_path=orig_eval_base_dir,
                                    transform=transform_simple,
                                    fonts=None)

    seq_proj = [int(x) for x in seq_proj.split('x')]
    nets = []
    optimizers = []
    lr_schedulers = []
    for neti in range(num_nets):
        net = load_model(lexicon=train_data.get_lexicon(),
                         seq_proj=seq_proj,
                         backend=backend,
                         snapshot=snapshot,
                         cuda=cuda,
                         do_beam_search=do_beam_search,
                         dropout_conv=dropout_conv,
                         dropout_rnn=dropout_rnn,
                         dropout_output=dropout_output,
                         do_ema=do_ema,
                         ada_after_rnn=ada_after_rnn,
                         ada_before_rnn=ada_before_rnn,
                         rnn_hidden_size=rnn_hidden_size,
                         gpu=neti)
        optimizer = optim.Adam(net.parameters(),
                               lr=base_lr,
                               weight_decay=0.0001)
        nets.append(net)
        optimizers.append(optimizer)
        lr_schedulers.append(
            StepLR(optimizer, step_size=step_size, max_iter=max_iter))
    loss_function = CTCLoss()
    nll_loss = torch.nn.NLLLoss()

    synth_avg_ed_best = float("inf")
    orig_avg_ed_best = float("inf")
    epoch_count = 0
    results_path = os.path.join(output_dir, 'results')

    def collate(batch):
        # Masks are requested when do_vat or either ada_* flag is set.
        return text_collate(
            batch, do_mask=(do_vat or ada_before_rnn or ada_after_rnn))

    if do_test_ensemble:
        def collate_vat(batch):
            return text_collate(batch, do_mask=True)

        def make_vat_loader():
            return DataLoader(orig_vat_data,
                              batch_size=batch_size,
                              num_workers=4,
                              shuffle=True,
                              collate_fn=collate_vat)

        vat_load = make_vat_loader()
        vat_len = len(vat_load)
        cur_vat = 0
        vat_iter = iter(vat_load)

    while True:
        data_loader = DataLoader(train_data,
                                 batch_size=batch_size,
                                 num_workers=4,
                                 shuffle=True,
                                 collate_fn=collate)
        if do_comp:
            # NOTE(review): `train_data_comp` and `collate_comp` are not
            # defined in this function; this branch depends on names defined
            # elsewhere (or raises NameError) -- confirm.
            data_loader_comp = DataLoader(train_data_comp,
                                          batch_size=batch_size,
                                          num_workers=4,
                                          shuffle=True,
                                          collate_fn=collate_comp)
            iter_comp = iter(data_loader_comp)

        loss_mean_ctc = []
        loss_mean_total = []
        loss_mean_test_ensemble = []
        num_labels_used_total = 0
        iterator = tqdm(data_loader)
        for iter_num, sample in enumerate(iterator):
            total_iter = (epoch_count * len(data_loader)) + iter_num
            if ((total_iter > 1 and total_iter % test_iter == 0)
                    or (test_init and total_iter == 0)):
                print("Test phase")
                for net in nets:
                    net.eval()
                    if do_ema:
                        net.start_test()

                synth_acc, synth_avg_ed, synth_avg_no_stop_ed, synth_avg_loss = test(
                    nets,
                    synth_eval_data,
                    synth_eval_data.get_lexicon(),
                    cuda,
                    batch_size=batch_size,
                    visualize=False,
                    tb_writer=tb_writer,
                    n_iter=total_iter,
                    initial_title='val_synth',
                    loss_function=loss_function,
                    output_path=results_path,
                    do_beam_search=False)

                orig_acc, orig_avg_ed, orig_avg_no_stop_ed, orig_avg_loss = test(
                    nets,
                    orig_eval_data,
                    orig_eval_data.get_lexicon(),
                    cuda,
                    batch_size=batch_size,
                    visualize=False,
                    tb_writer=tb_writer,
                    n_iter=total_iter,
                    initial_title='test_orig',
                    loss_function=loss_function,
                    output_path=results_path,
                    do_beam_search=do_beam_search)

                for net in nets:
                    net.train()

                # NOTE(review): `total_iter // 30000` is truthy for every
                # total_iter >= 30000, so this fires on every test phase after
                # 30k iterations and saves into `output_dir`, not into the
                # `periodic_save` directory it creates. Behavior kept as-is;
                # `% 30000 == 0` and saving under `periodic_save` may have been
                # intended -- confirm.
                if total_iter // 30000:
                    periodic_save = os.path.join(output_dir, 'periodic_save')
                    os.makedirs(periodic_save, exist_ok=True)
                    for neti, net in enumerate(nets):
                        torch.save(
                            net.state_dict(),
                            os.path.join(
                                output_dir, "crnn_{}_".format(neti) + backend +
                                "_" + str(total_iter)))

                if orig_avg_no_stop_ed < orig_avg_ed_best:
                    orig_avg_ed_best = orig_avg_no_stop_ed
                for neti, net in enumerate(nets):
                    torch.save(
                        net.state_dict(),
                        os.path.join(
                            output_dir, "crnn_{}_".format(neti) + backend +
                            "_iter_{}".format(total_iter)))

                if synth_avg_no_stop_ed < synth_avg_ed_best:
                    synth_avg_ed_best = synth_avg_no_stop_ed
                if do_ema:
                    for net in nets:
                        net.end_test()
                print(
                    "synth: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}"
                    .format(synth_avg_ed_best, synth_avg_ed,
                            synth_avg_no_stop_ed, synth_acc))
                print(
                    "orig: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}"
                    .format(orig_avg_ed_best, orig_avg_ed, orig_avg_no_stop_ed,
                            orig_acc))
                tb_writer.get_writer().add_scalars(
                    'data/test', {
                        'synth_ed_total': synth_avg_ed,
                        'synth_ed_no_stop': synth_avg_no_stop_ed,
                        'synth_avg_loss': synth_avg_loss,
                        'orig_ed_total': orig_avg_ed,
                        'orig_ed_no_stop': orig_avg_no_stop_ed,
                        'orig_avg_loss': orig_avg_loss
                    }, total_iter)
                if loss_mean_ctc:
                    train_dict = {'mean_ctc_loss': np.mean(loss_mean_ctc)}
                    # Avoid logging np.mean([]) == NaN when ensembling is off.
                    if loss_mean_test_ensemble:
                        train_dict['mean_test_ensemble_loss'] = np.mean(
                            loss_mean_test_ensemble)
                    train_dict['num_labels_used'] = num_labels_used_total
                    num_labels_used_total = 0
                    print(train_dict)
                    tb_writer.get_writer().add_scalars('data/train',
                                                       train_dict, total_iter)

            for optimizer in optimizers:
                optimizer.zero_grad()
            imgs = Variable(sample["img"])
            labels_flatten = Variable(sample["seq"]).view(-1)
            label_lens = Variable(sample["seq_len"].int())
            img_seq_lens = sample["im_seq_len"]

            if do_test_ensemble:
                if cur_vat >= vat_len:
                    # The VAT loader is exhausted: start a new shuffled pass.
                    vat_load = make_vat_loader()
                    vat_len = len(vat_load)
                    cur_vat = 0
                    vat_iter = iter(vat_load)
                vat_batch = next(vat_iter)
                cur_vat += 1
                net_ens_losses, num_labels_used = _ensemble_pseudo_label_losses(
                    nets, vat_batch, nll_loss, test_ensemble_thresh, cuda)

            nets_total_losses = []
            nets_ctc_losses = []
            loss_is_inf = False
            pred_seq_lens = Variable(torch.IntTensor(np.array(img_seq_lens)))
            for neti, net in enumerate(nets):
                if cuda:
                    imgs = imgs.cuda(neti)
                preds = net(imgs, img_seq_lens)
                loss_ctc = loss_function(preds, labels_flatten, pred_seq_lens,
                                         label_lens) / batch_size

                if loss_ctc.data[0] in [float("inf"), -float("inf")]:
                    print("warnning: loss should not be inf.")
                    loss_is_inf = True
                    break
                total_loss = loss_ctc

                if do_test_ensemble:
                    total_loss = total_loss + test_ensemble_ratio * net_ens_losses[
                        neti]
                    net_ens_losses[neti] = net_ens_losses[neti].data[0]
                total_loss.backward()
                nets_total_losses.append(total_loss.data[0])
                nets_ctc_losses.append(loss_ctc.data[0])
                nn.utils.clip_grad_norm(net.parameters(), 10.0)
            if loss_is_inf:
                continue

            if do_test_ensemble:
                # Previously never accumulated, so the logged count was 0.
                num_labels_used_total += num_labels_used

            mean_ctc = np.mean(nets_ctc_losses)
            mean_total = np.mean(nets_total_losses)
            # Filter outliers on the same ensemble mean that gets recorded
            # (the check used to look only at the last net's loss).
            if -400 < mean_ctc < 400:
                loss_mean_ctc.append(mean_ctc)
            if -400 < mean_total < 400:
                loss_mean_total.append(mean_total)
            status = "epoch: {0:5d}; iter_num: {1:5d}; lr: {2:.2E}; loss_mean: {3:.3f}; loss: {4:.3f}".format(
                epoch_count, lr_schedulers[0].last_iter,
                lr_schedulers[0].get_lr(), mean_total, mean_ctc)

            if do_test_ensemble:
                ens_loss = np.mean(net_ens_losses)
                if ens_loss != 0:
                    loss_mean_test_ensemble.append(ens_loss)
                    status += "; loss_ens: {0:.3f}".format(ens_loss)
                    status += "; num_ens_used {}".format(num_labels_used)
                else:
                    loss_mean_test_ensemble.append(0)
                    status += "; loss_ens: {}".format(0)
            iterator.set_description(status)
            for optimizer in optimizers:
                optimizer.step()
            if do_lr_step:
                for lr_scheduler in lr_schedulers:
                    lr_scheduler.step()
        for neti, net in enumerate(nets):
            torch.save(
                net.state_dict(),
                os.path.join(output_dir,
                             "crnn_{}_".format(neti) + backend + "_last"))
        epoch_count += 1
Example #6
0
def main(data_path, abc, seq_proj, backend, snapshot, input_size, base_lr, step_size, max_iter, batch_size, output_dir, test_epoch, test_init, gpu,
         num_folds=24, num_epochs=15):
    """Train and validate one CRNN per cross-validation fold.

    For every fold ``k`` in ``range(num_folds)`` a model is loaded from
    ``snapshot`` and trained for ``num_epochs`` epochs on dataset mode
    ``fold<k>_train``. After each epoch it is evaluated on ``fold<k>_test``.
    When ``output_dir`` is not ``None``, checkpoints are written there:
    ``<mode>_crnn_<backend>_<abc>_best`` whenever validation accuracy improves,
    and ``<mode>_crnn_<backend>_<abc>_<acc>`` whenever accuracy exceeds 0.985.

    Args:
        data_path: data path for ``TextDataset``; when ``None`` a
            ``TestDataset`` built from ``abc`` is used instead.
        seq_proj, input_size: ``"AxB"`` strings of two integers.
        gpu: exported as ``CUDA_VISIBLE_DEVICES``; ``''`` disables CUDA.
            Batches whose size is not divisible by the number of listed
            devices are skipped.
        max_iter, test_epoch, test_init: accepted for interface compatibility;
            not used by this function.
        num_folds: number of folds to run (default 24, the previous constant).
        num_epochs: epochs per fold (default 15, the previous constant).
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # Compare by value: `is not ''` tests identity (SyntaxWarning on 3.8+).
    cuda = gpu != ''
    num_gpus = len(gpu.split(','))

    input_size = [int(x) for x in input_size.split('x')]
    transform = Compose([
        Rotation(),
        Translation(),
        # Scale(),
        Contrast(),
        # Grid_distortion(),
        Resize(size=(input_size[0], input_size[1]))
    ])
    seq_proj = [int(x) for x in seq_proj.split('x')]

    for fold_idx in range(num_folds):
        train_mode = 'fold{0}_train'.format(fold_idx)
        val_mode = 'fold{0}_test'.format(fold_idx)

        if data_path is not None:
            data = TextDataset(data_path=data_path, mode=train_mode, transform=transform)
        else:
            data = TestDataset(transform=transform, abc=abc)

        net = load_model(data.get_abc(), seq_proj, backend, snapshot, cuda)
        optimizer = optim.Adam(net.parameters(), lr=base_lr, weight_decay=0.0001)
        lr_scheduler = StepLR(optimizer, step_size=step_size)
        loss_function = CTCLoss()

        print(fold_idx)

        acc_best = 0
        for epoch_count in range(num_epochs):
            data_loader = DataLoader(data, batch_size=batch_size, num_workers=10, shuffle=True, collate_fn=text_collate)
            loss_mean = []
            iterator = tqdm(data_loader)
            for sample in iterator:
                cur_batch_size = sample["img"].size(0)
                # for multi-gpu support
                if cur_batch_size % num_gpus != 0:
                    continue
                optimizer.zero_grad()
                imgs = Variable(sample["img"])
                labels = Variable(sample["seq"]).view(-1)
                label_lens = Variable(sample["seq_len"].int())
                if cuda:
                    imgs = imgs.cuda()
                preds = net(imgs).cpu()
                # One prediction length per sample in *this* batch: the last
                # batch of an epoch can be smaller than `batch_size`.
                pred_lens = Variable(Tensor([preds.size(0)] * cur_batch_size).int())
                loss = loss_function(preds, labels, pred_lens, label_lens) / cur_batch_size
                loss.backward()
                loss_mean.append(loss.data[0])
                status = "{}/{}; lr: {}; loss_mean: {}; loss: {}".format(epoch_count, lr_scheduler.last_iter, lr_scheduler.get_lr(), np.mean(loss_mean), loss.data[0])
                iterator.set_description(status)
                optimizer.step()
                lr_scheduler.step()

            logging.info("Test phase")
            net = net.eval()

            data.set_mode(val_mode)
            acc, avg_ed, error_idx = test(net, data, data.get_abc(), cuda, visualize=False)

            if acc > 0.95:
                error_name = [data.config[data.mode][idx]["name"] for idx in error_idx]
                logging.info('Val: ' + ','.join(error_name))

            net = net.train()
            data.set_mode(train_mode)

            if acc > acc_best:
                if output_dir is not None:
                    torch.save(net.state_dict(), os.path.join(output_dir, train_mode+"_crnn_" + backend + "_" + str(data.get_abc()) + "_best"))
                acc_best = acc

            if acc > 0.985 and output_dir is not None:
                torch.save(net.state_dict(), os.path.join(output_dir, train_mode+"_crnn_" + backend + "_" + str(data.get_abc()) + "_"+str(acc)))
            # Three placeholders for three values (the old format string had
            # two, so acc_best was logged as avg_ed and avg_ed was dropped).
            logging.info("acc: {}\tacc_best: {}; avg_ed: {}\n\n".format(acc, acc_best, avg_ed))
Example #7
0
def main(data_path, abc, seq_proj, backend, snapshot, input_size, base_lr,
         step_size, max_iter, batch_size, output_dir, test_epoch, test_init,
         gpu, max_epochs=50):
    """Train a single CRNN, periodically evaluating on the ``"test"`` split.

    The model trains for ``max_epochs`` epochs on mode ``"train"``. Before
    epoch ``e`` it is evaluated on mode ``"test"`` if ``test_epoch`` is set and
    ``e`` is a positive multiple of it, or if ``e == 0`` and ``test_init`` is
    truthy. When ``output_dir`` is not ``None``, a ``crnn_<backend>__best``
    checkpoint is written whenever test accuracy improves, and
    ``crnn_<backend>__last`` is written after every epoch.

    Args:
        data_path: data path for ``TextDataset``; when ``None`` a
            ``TestDataset`` built from ``abc`` is used instead.
        seq_proj, input_size: ``"AxB"`` strings of two integers.
        gpu: exported as ``CUDA_VISIBLE_DEVICES``; ``''`` disables CUDA.
            Batches whose size is not divisible by the number of listed
            devices are skipped.
        max_epochs: number of training epochs (default 50, the previous
            hard-coded limit).
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    cuda = gpu != ''
    num_gpus = len(gpu.split(','))

    input_size = [int(x) for x in input_size.split('x')]
    transform = Compose([
        Rotation(),
        # Translation(),
        # Scale(),
        Resize(size=(input_size[0], input_size[1]))
    ])
    if data_path is not None:
        data = TextDataset(data_path=data_path,
                           mode="train",
                           transform=transform)
    else:
        data = TestDataset(transform=transform, abc=abc)
    seq_proj = [int(x) for x in seq_proj.split('x')]
    net = load_model(data.get_abc(), seq_proj, backend, snapshot, cuda)
    optimizer = optim.Adam(net.parameters(), lr=base_lr, weight_decay=0.0001)
    lr_scheduler = StepLR(optimizer, step_size=step_size, max_iter=max_iter)
    loss_function = CTCLoss()

    acc_best = 0
    for epoch_count in range(max_epochs):
        periodic_test = (test_epoch is not None and epoch_count != 0
                         and epoch_count % test_epoch == 0)
        if periodic_test or (test_init and epoch_count == 0):
            print("Test phase")
            data.set_mode("test")
            net = net.eval()
            acc, avg_ed = test(net,
                               data,
                               data.get_abc(),
                               cuda,
                               visualize=False)
            net = net.train()
            data.set_mode("train")
            if acc > acc_best:
                if output_dir is not None:
                    torch.save(
                        net.state_dict(),
                        os.path.join(output_dir,
                                     "crnn_" + backend + "_" + "_best"))
                acc_best = acc
            print("acc: {}\tacc_best: {}; avg_ed: {}".format(
                acc, acc_best, avg_ed))

        data_loader = DataLoader(data,
                                 batch_size=batch_size,
                                 num_workers=1,
                                 shuffle=True,
                                 collate_fn=text_collate)
        loss_mean = []
        iterator = tqdm(data_loader)
        for sample in iterator:
            cur_batch_size = sample["img"].size(0)
            # for multi-gpu support
            if cur_batch_size % num_gpus != 0:
                continue
            optimizer.zero_grad()
            imgs = Variable(sample["img"])
            labels = Variable(sample["seq"]).view(-1)
            label_lens = Variable(sample["seq_len"].int())
            if cuda:
                imgs = imgs.cuda()
            preds = net(imgs).cpu()
            # One prediction length per sample in *this* batch: the last
            # batch of an epoch can be smaller than `batch_size`.
            pred_lens = Variable(Tensor([preds.size(0)] * cur_batch_size).int())
            loss = loss_function(preds, labels, pred_lens,
                                 label_lens) / cur_batch_size
            loss.backward()
            nn.utils.clip_grad_norm(net.parameters(), 10.0)
            loss_mean.append(loss.data[0])
            status = "epoch: {}; iter: {}; lr: {}; loss_mean: {}; loss: {}".format(
                epoch_count, lr_scheduler.last_iter, lr_scheduler.get_lr(),
                np.mean(loss_mean), loss.data[0])
            iterator.set_description(status)
            optimizer.step()
            lr_scheduler.step()
        if output_dir is not None:
            torch.save(
                net.state_dict(),
                os.path.join(output_dir, "crnn_" + backend + "_" + "_last"))
Example #8
0
def _save_crnn_checkpoint(net, suffix):
    """Save ``net``'s weights under ``config.output_dir`` and report the path.

    The file is named ``"crnn_" + config.backend + suffix``.

    Args:
        net: model whose ``state_dict()`` is saved.
        suffix: file-name suffix, e.g. ``"_best"`` or ``"_best_anno"``.

    Returns:
        The path that was written.
    """
    path = os.path.join(config.output_dir, "crnn_" + config.backend + suffix)
    torch.save(net.state_dict(), path)
    print("Saving best model to", path)
    return path


def main():
    """Train the CRNN model in an endless epoch loop driven by ``config``.

    Each epoch:
      * trains on the "train" split with CTC loss on the text sequence plus
        an NLL loss on ``sample["label"]`` (the network returns both CTC
        predictions and label log-probabilities);
      * evaluates on "dev" and saves ``crnn_<backend>_best`` when the dev
        average edit distance improves;
      * evaluates on "test_annotated" and saves ``crnn_<backend>_best_anno``
        when its average edit distance beats the best seen so far;
      * steps the ``ReduceLROnPlateau`` scheduler on the dev edit distance.

    The loop has no exit condition; the process is expected to be stopped
    externally.

    Raises:
        ValueError: if ``config.output_dir`` is not set, since checkpoints
            could not be written.
    """
    # Checked up front so a missing output dir is reported before any training
    # time is spent (and is not skipped under ``python -O`` like an assert).
    if config.output_dir is None:
        raise ValueError("config.output_dir must be set to save checkpoints")

    input_size = [int(x) for x in config.input_size.split('x')]
    # TODO: data augmentation: 1) use an elastic transform, 2) randomly erase
    # part of the image.
    transform = Compose([
        Resize(size=(input_size[0], input_size[1]))
    ])
    data = TextDataset(data_path=config.data_path,
                       mode="train",
                       transform=transform)
    # Report the size of every split, then switch back to training mode.
    print("Len of train =", len(data))
    data.set_mode("dev")
    print("Len of dev =", len(data))
    data.set_mode("test")
    print("Len of test =", len(data))
    data.set_mode("test_annotated")
    print("Len of test_annotated =", len(data))
    data.set_mode("train")

    net = load_model(input_size, data.get_abc(), None, config.backend,
                     config.snapshot)
    total_params = sum(p.numel() for p in net.parameters())
    train_total_params = sum(p.numel() for p in net.parameters()
                             if p.requires_grad)
    print("# of parameters =", total_params)
    print("# of non-training parameters =", total_params - train_total_params)
    print("")
    if config.output_image:
        # Delete regular files left in <output_dir>/input_images.
        input_img_path = os.path.join(config.output_dir, "input_images")
        old_files = glob.glob(input_img_path + "/*")
        print("Remove the old", input_img_path)
        for path in old_files:
            if os.path.isfile(path):
                os.remove(path)

    optimizer = optim.Adam(net.parameters(), lr=config.base_lr)
    # Halve the LR after `patience` epochs without dev edit-distance
    # improvement (stepped once per epoch with dev_avg_ed below).
    lr_scheduler = ReduceLROnPlateau(optimizer,
                                     factor=0.5,
                                     patience=5,
                                     verbose=True)
    loss_function = CTCLoss(blank=0)
    loss_label = nn.NLLLoss()

    dev_avg_ed_best = float("inf")
    # "_best_anno" checkpoints are only written once the test_annotated edit
    # distance beats this hard-coded threshold.
    # NOTE(review): presumably the score of an earlier checkpoint — confirm,
    # and consider moving it into config.
    anno_avg_ed_best = 0.1544685954462857
    epoch_count = 0
    print("Start running ...")

    while True:
        # ---- training epoch ----
        net = net.train()
        data.set_mode("train")
        data_loader = DataLoader(data,
                                 batch_size=config.batch_size,
                                 num_workers=config.num_worker,
                                 shuffle=True,
                                 collate_fn=text_collate)
        loss_mean = []
        for sample in tqdm(data_loader):
            optimizer.zero_grad()
            imgs = Variable(sample["img"])
            labels_ocr = Variable(sample["seq"]).view(-1)
            labels_ocr_len = Variable(sample["seq_len"].int())
            labels = Variable(sample["label"].long())
            imgs = imgs.cuda()

            preds, label_logsoftmax = net(imgs)
            preds = preds.cpu()
            label_logsoftmax = label_logsoftmax.cpu()
            # Every sample uses the full output length of the network.
            pred_lens = Variable(
                Tensor([preds.size(0)] * len(labels_ocr_len)).int())

            # The CTC output sequence must be longer than every target.
            assert preds.size(0) > labels_ocr_len.max().item(), \
                "CTC output is not longer than the longest label"
            loss = loss_function(preds, labels_ocr, pred_lens,
                                 labels_ocr_len) + loss_label(
                                     label_logsoftmax, labels)

            # Sanity checks on the combined loss.
            assert not torch.isnan(loss).any(), "loss is NaN"
            assert not torch.isinf(loss).any(), "loss is inf"
            loss_value = loss.item()
            assert loss_value != 0, "loss is exactly zero"
            loss.backward()
            for name, para in net.named_parameters():
                # Flag trainable parameters whose gradient is missing or all
                # zeros; `not grad.any()` avoids allocating a zeros tensor
                # per parameter per step.
                if para.requires_grad and (para.grad is None
                                           or not para.grad.any()):
                    print("WARNING: There is no grad at", name)

            nn.utils.clip_grad_norm_(net.parameters(), 10.0)
            loss_mean.append(loss_value)
            optimizer.step()

        # ---- dev evaluation ----
        print("dev phase")
        data.set_mode("dev")
        # NOTE(review): dev evaluation uses num_workers=0 while test_annotated
        # uses config.num_worker — confirm the asymmetry is intentional.
        acc, dev_avg_ed = test(net,
                               data,
                               data.get_abc(),
                               visualize=True,
                               batch_size=config.batch_size,
                               num_workers=0)
        if dev_avg_ed < dev_avg_ed_best:
            _save_crnn_checkpoint(net, "_best")
            dev_avg_ed_best = dev_avg_ed

        # TODO: print avg_ed & acc in train epoch
        print("train: epoch: {}; loss_mean: {}".format(epoch_count,
                                                       np.mean(loss_mean)))
        print("dev: acc: {}; avg_ed: {}; avg_ed_best: {}".format(
            acc, dev_avg_ed, dev_avg_ed_best))

        # ---- annotated-test evaluation ----
        data.set_mode("test_annotated")
        annotated_acc, annotated_avg_ed = test(net,
                                               data,
                                               data.get_abc(),
                                               visualize=True,
                                               batch_size=config.batch_size,
                                               num_workers=config.num_worker)
        if annotated_avg_ed < anno_avg_ed_best:
            _save_crnn_checkpoint(net, "_best_anno")
            anno_avg_ed_best = annotated_avg_ed
        print("ANNOTATED: acc: {}; avg_ed: {}, best: {}".format(
            annotated_acc, annotated_avg_ed, anno_avg_ed_best))

        # TODO: add tensorboard to visualize loss_mean & avg_ed & acc
        lr_scheduler.step(dev_avg_ed)
        epoch_count += 1
Example #9
0
def main(base_data_dir, train_data_path, train_base_dir, orig_eval_data_path,
         orig_eval_base_dir, synth_eval_data_path, synth_eval_base_dir,
         lexicon_path, seq_proj, backend, snapshot, input_height, base_lr,
         elastic_alpha, elastic_sigma, step_size, max_iter, batch_size,
         output_dir, test_iter, show_iter, test_init, use_gpu,
         use_no_font_repeat_data, do_vat, do_at, vat_ratio, test_vat_ratio,
         vat_epsilon, vat_ip, vat_xi, vat_sign, do_remove_augs, aug_to_remove,
         do_beam_search, dropout_conv, dropout_rnn, dropout_output, do_ema,
         do_gray, do_test_vat, do_test_entropy, do_test_vat_cnn,
         do_test_vat_rnn, ada_after_rnn, ada_before_rnn, do_ada_lr, ada_ratio,
         rnn_hidden_size, do_lr_step, dataset_name):
    if not do_lr_step and not do_ada_lr:
        raise NotImplementedError(
            'learning rate should be either step or ada.')
    train_data_path = os.path.join(base_data_dir, train_data_path)
    train_base_dir = os.path.join(base_data_dir, train_base_dir)
    synth_eval_data_path = os.path.join(base_data_dir, synth_eval_data_path)
    synth_eval_base_dir = os.path.join(base_data_dir, synth_eval_base_dir)
    orig_eval_data_path = os.path.join(base_data_dir, orig_eval_data_path)
    orig_eval_base_dir = os.path.join(base_data_dir, orig_eval_base_dir)
    lexicon_path = os.path.join(base_data_dir, lexicon_path)

    all_parameters = locals()
    cuda = use_gpu
    #print(train_base_dir)
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        tb_writer = TbSummary(output_dir)
        output_dir = os.path.join(output_dir, 'model')
        os.makedirs(output_dir, exist_ok=True)

    with open(lexicon_path, 'rb') as f:
        lexicon = pkl.load(f)
    #print(sorted(lexicon.items(), key=operator.itemgetter(1)))

    with open(os.path.join(output_dir, 'params.txt'), 'w') as f:
        f.writelines(str(all_parameters))
    print(all_parameters)
    print('new vat')

    sin_magnitude = 4
    rotate_max_angle = 2
    dataset_info = SynthDataInfo(None, None, None, dataset_name.lower())
    train_fonts = dataset_info.font_names

    all_args = locals()

    allowed_removals = [
        'elastic', 'sine', 'sine_rotate', 'rotation', 'color_aug',
        'color_gaus', 'color_sine'
    ]
    if do_remove_augs and aug_to_remove not in allowed_removals:
        raise Exception('augmentation removal value is not allowed.')

    if do_remove_augs:
        rand_trans = []
        if aug_to_remove == 'elastic':
            print('doing sine transform :)')
            rand_trans.append(OnlySine(sin_magnitude=sin_magnitude))
        elif aug_to_remove in ['sine', 'sine_rotate']:
            print('doing elastic transform :)')
            rand_trans.append(
                OnlyElastic(elastic_alpha=elastic_alpha,
                            elastic_sigma=elastic_sigma))
        if aug_to_remove not in ['elastic', 'sine', 'sine_rotate']:
            print('doing elastic transform :)')
            print('doing sine transform :)')
            rand_trans.append(
                ElasticAndSine(elastic_alpha=elastic_alpha,
                               elastic_sigma=elastic_sigma,
                               sin_magnitude=sin_magnitude))
        if aug_to_remove not in ['rotation', 'sine_rotate']:
            print('doing rotation transform :)')
            rand_trans.append(Rotation(angle=rotate_max_angle, fill_value=255))
        if aug_to_remove not in ['color_aug', 'color_gaus', 'color_sine']:
            print('doing color_aug transform :)')
            rand_trans.append(ColorGradGausNoise())
        elif aug_to_remove == 'color_gaus':
            print('doing color_sine transform :)')
            rand_trans.append(ColorGrad())
        elif aug_to_remove == 'color_sine':
            print('doing color_gaus transform :)')
            rand_trans.append(ColorGausNoise())
    else:
        print('doing all transforms :)')
        rand_trans = [
            ElasticAndSine(elastic_alpha=elastic_alpha,
                           elastic_sigma=elastic_sigma,
                           sin_magnitude=sin_magnitude),
            Rotation(angle=rotate_max_angle, fill_value=255),
            ColorGradGausNoise()
        ]
    if do_gray:
        rand_trans = rand_trans + [
            Resize(hight=input_height),
            AddWidth(),
            ToGray(),
            Normalize()
        ]
    else:
        rand_trans = rand_trans + [
            Resize(hight=input_height),
            AddWidth(), Normalize()
        ]

    transform_random = Compose(rand_trans)
    if do_gray:
        transform_simple = Compose(
            [Resize(hight=input_height),
             AddWidth(),
             ToGray(),
             Normalize()])
    else:
        transform_simple = Compose(
            [Resize(hight=input_height),
             AddWidth(), Normalize()])

    if use_no_font_repeat_data:
        print('creating dataset')
        train_data = TextDatasetRandomFont(data_path=train_data_path,
                                           lexicon=lexicon,
                                           base_path=train_base_dir,
                                           transform=transform_random,
                                           fonts=train_fonts)
        print('finished creating dataset')
    else:
        print('train data path:\n{}'.format(train_data_path))
        print('train_base_dir:\n{}'.format(train_base_dir))
        train_data = TextDataset(data_path=train_data_path,
                                 lexicon=lexicon,
                                 base_path=train_base_dir,
                                 transform=transform_random,
                                 fonts=train_fonts)
    synth_eval_data = TextDataset(data_path=synth_eval_data_path,
                                  lexicon=lexicon,
                                  base_path=synth_eval_base_dir,
                                  transform=transform_random,
                                  fonts=train_fonts)
    orig_eval_data = TextDataset(data_path=orig_eval_data_path,
                                 lexicon=lexicon,
                                 base_path=orig_eval_base_dir,
                                 transform=transform_simple,
                                 fonts=None)
    if do_test_vat or do_test_vat_rnn or do_test_vat_cnn:
        orig_vat_data = TextDataset(data_path=orig_eval_data_path,
                                    lexicon=lexicon,
                                    base_path=orig_eval_base_dir,
                                    transform=transform_simple,
                                    fonts=None)

    if ada_after_rnn or ada_before_rnn:
        orig_ada_data = TextDataset(data_path=orig_eval_data_path,
                                    lexicon=lexicon,
                                    base_path=orig_eval_base_dir,
                                    transform=transform_simple,
                                    fonts=None)

    #else:
    #    train_data = TestDataset(transform=transform, abc=abc).set_mode("train")
    #    synth_eval_data = TestDataset(transform=transform, abc=abc).set_mode("test")
    #    orig_eval_data = TestDataset(transform=transform, abc=abc).set_mode("test")
    seq_proj = [int(x) for x in seq_proj.split('x')]
    net = load_model(lexicon=train_data.get_lexicon(),
                     seq_proj=seq_proj,
                     backend=backend,
                     snapshot=snapshot,
                     cuda=cuda,
                     do_beam_search=do_beam_search,
                     dropout_conv=dropout_conv,
                     dropout_rnn=dropout_rnn,
                     dropout_output=dropout_output,
                     do_ema=do_ema,
                     ada_after_rnn=ada_after_rnn,
                     ada_before_rnn=ada_before_rnn,
                     rnn_hidden_size=rnn_hidden_size)
    optimizer = optim.Adam(net.parameters(), lr=base_lr, weight_decay=0.0001)
    if do_ada_lr:
        print('using ada lr')
        lr_scheduler = DannLR(optimizer, max_iter=max_iter)
    elif do_lr_step:
        print('using step lr')
        lr_scheduler = StepLR(optimizer,
                              step_size=step_size,
                              max_iter=max_iter)
    loss_function = CTCLoss()

    synth_avg_ed_best = float("inf")
    orig_avg_ed_best = float("inf")
    epoch_count = 0

    if do_test_vat or do_test_vat_rnn or do_test_vat_cnn:
        collate_vat = lambda x: text_collate(x, do_mask=True)
        vat_load = DataLoader(orig_vat_data,
                              batch_size=batch_size,
                              num_workers=4,
                              shuffle=True,
                              collate_fn=collate_vat)
        vat_len = len(vat_load)
        cur_vat = 0
        vat_iter = iter(vat_load)
    if ada_after_rnn or ada_before_rnn:
        collate_ada = lambda x: text_collate(x, do_mask=True)
        ada_load = DataLoader(orig_ada_data,
                              batch_size=batch_size,
                              num_workers=4,
                              shuffle=True,
                              collate_fn=collate_ada)
        ada_len = len(ada_load)
        cur_ada = 0
        ada_iter = iter(ada_load)

    loss_domain = torch.nn.NLLLoss()

    while True:
        collate = lambda x: text_collate(
            x, do_mask=(do_vat or ada_before_rnn or ada_after_rnn))
        data_loader = DataLoader(train_data,
                                 batch_size=batch_size,
                                 num_workers=4,
                                 shuffle=True,
                                 collate_fn=collate)

        loss_mean_ctc = []
        loss_mean_vat = []
        loss_mean_at = []
        loss_mean_comp = []
        loss_mean_total = []
        loss_mean_test_vat = []
        loss_mean_test_pseudo = []
        loss_mean_test_rand = []
        loss_mean_ada_rnn_s = []
        loss_mean_ada_rnn_t = []
        loss_mean_ada_cnn_s = []
        loss_mean_ada_cnn_t = []
        iterator = tqdm(data_loader)
        iter_count = 0
        for iter_num, sample in enumerate(iterator):
            total_iter = (epoch_count * len(data_loader)) + iter_num
            if ((total_iter > 1)
                    and total_iter % test_iter == 0) or (test_init
                                                         and total_iter == 0):
                # epoch_count != 0 and

                print("Test phase")
                net = net.eval()
                if do_ema:
                    net.start_test()

                synth_acc, synth_avg_ed, synth_avg_no_stop_ed, synth_avg_loss = test(
                    net,
                    synth_eval_data,
                    synth_eval_data.get_lexicon(),
                    cuda,
                    visualize=False,
                    dataset_info=dataset_info,
                    batch_size=batch_size,
                    tb_writer=tb_writer,
                    n_iter=total_iter,
                    initial_title='val_synth',
                    loss_function=loss_function,
                    output_path=os.path.join(output_dir, 'results'),
                    do_beam_search=False)

                orig_acc, orig_avg_ed, orig_avg_no_stop_ed, orig_avg_loss = test(
                    net,
                    orig_eval_data,
                    orig_eval_data.get_lexicon(),
                    cuda,
                    visualize=False,
                    dataset_info=dataset_info,
                    batch_size=batch_size,
                    tb_writer=tb_writer,
                    n_iter=total_iter,
                    initial_title='test_orig',
                    loss_function=loss_function,
                    output_path=os.path.join(output_dir, 'results'),
                    do_beam_search=do_beam_search)

                net = net.train()
                #save periodic
                if output_dir is not None and total_iter // 30000:
                    periodic_save = os.path.join(output_dir, 'periodic_save')
                    os.makedirs(periodic_save, exist_ok=True)
                    old_save = glob.glob(os.path.join(periodic_save, '*'))

                    torch.save(
                        net.state_dict(),
                        os.path.join(output_dir, "crnn_" + backend + "_" +
                                     str(total_iter)))

                if orig_avg_no_stop_ed < orig_avg_ed_best:
                    orig_avg_ed_best = orig_avg_no_stop_ed
                    if output_dir is not None:
                        torch.save(
                            net.state_dict(),
                            os.path.join(output_dir,
                                         "crnn_" + backend + "_best"))

                if synth_avg_no_stop_ed < synth_avg_ed_best:
                    synth_avg_ed_best = synth_avg_no_stop_ed
                if do_ema:
                    net.end_test()
                print(
                    "synth: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}"
                    .format(synth_avg_ed_best, synth_avg_ed,
                            synth_avg_no_stop_ed, synth_acc))
                print(
                    "orig: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}"
                    .format(orig_avg_ed_best, orig_avg_ed, orig_avg_no_stop_ed,
                            orig_acc))
                tb_writer.get_writer().add_scalars(
                    'data/test', {
                        'synth_ed_total': synth_avg_ed,
                        'synth_ed_no_stop': synth_avg_no_stop_ed,
                        'synth_avg_loss': synth_avg_loss,
                        'orig_ed_total': orig_avg_ed,
                        'orig_ed_no_stop': orig_avg_no_stop_ed,
                        'orig_avg_loss': orig_avg_loss
                    }, total_iter)
                if len(loss_mean_ctc) > 0:
                    train_dict = {'mean_ctc_loss': np.mean(loss_mean_ctc)}
                    if do_vat:
                        train_dict = {
                            **train_dict,
                            **{
                                'mean_vat_loss': np.mean(loss_mean_vat)
                            }
                        }
                    if do_at:
                        train_dict = {
                            **train_dict,
                            **{
                                'mean_at_loss': np.mean(loss_mean_at)
                            }
                        }
                    if do_test_vat:
                        train_dict = {
                            **train_dict,
                            **{
                                'mean_test_vat_loss': np.mean(loss_mean_test_vat)
                            }
                        }
                    if do_test_vat_rnn and do_test_vat_cnn:
                        train_dict = {
                            **train_dict,
                            **{
                                'mean_test_vat_crnn_loss':
                                np.mean(loss_mean_test_vat)
                            }
                        }
                    elif do_test_vat_rnn:
                        train_dict = {
                            **train_dict,
                            **{
                                'mean_test_vat_rnn_loss':
                                np.mean(loss_mean_test_vat)
                            }
                        }
                    elif do_test_vat_cnn:
                        train_dict = {
                            **train_dict,
                            **{
                                'mean_test_vat_cnn_loss':
                                np.mean(loss_mean_test_vat)
                            }
                        }
                    if ada_after_rnn:
                        train_dict = {
                            **train_dict,
                            **{
                                'mean_ada_rnn_s_loss': np.mean(loss_mean_ada_rnn_s),
                                'mean_ada_rnn_t_loss': np.mean(loss_mean_ada_rnn_t)
                            }
                        }
                    if ada_before_rnn:
                        train_dict = {
                            **train_dict,
                            **{
                                'mean_ada_cnn_s_loss': np.mean(loss_mean_ada_cnn_s),
                                'mean_ada_cnn_t_loss': np.mean(loss_mean_ada_cnn_t)
                            }
                        }
                    print(train_dict)
                    tb_writer.get_writer().add_scalars('data/train',
                                                       train_dict, total_iter)
            '''
            # for multi-gpu support
            if sample["img"].size(0) % len(gpu.split(',')) != 0:
                continue
            '''
            optimizer.zero_grad()
            imgs = Variable(sample["img"])
            #print("images sizes are:")
            #print(sample["img"].shape)
            if do_vat or ada_after_rnn or ada_before_rnn:
                mask = sample['mask']
            labels_flatten = Variable(sample["seq"]).view(-1)
            label_lens = Variable(sample["seq_len"].int())
            #print("image sequence length is:")
            #print(sample["im_seq_len"])
            #print("label sequence length is:")
            #print(sample["seq_len"].view(1,-1))
            img_seq_lens = sample["im_seq_len"]
            if cuda:
                imgs = imgs.cuda()
                if do_vat or ada_after_rnn or ada_before_rnn:
                    mask = mask.cuda()

            if do_ada_lr:
                ada_p = float(iter_count) / max_iter
                lr_scheduler.update(ada_p)

            if ada_before_rnn or ada_after_rnn:
                if not do_ada_lr:
                    ada_p = float(iter_count) / max_iter
                ada_alpha = 2. / (1. + np.exp(-10. * ada_p)) - 1

                if cur_ada >= ada_len:
                    ada_load = DataLoader(orig_ada_data,
                                          batch_size=batch_size,
                                          num_workers=4,
                                          shuffle=True,
                                          collate_fn=collate_ada)
                    ada_len = len(ada_load)
                    cur_ada = 0
                    ada_iter = iter(ada_load)
                ada_batch = next(ada_iter)
                cur_ada += 1
                ada_imgs = Variable(ada_batch["img"])
                ada_img_seq_lens = ada_batch["im_seq_len"]
                ada_mask = ada_batch['mask'].byte()
                if cuda:
                    ada_imgs = ada_imgs.cuda()

                _, ada_cnn, ada_rnn = net(ada_imgs,
                                          ada_img_seq_lens,
                                          ada_alpha=ada_alpha,
                                          mask=ada_mask)
                if ada_before_rnn:
                    ada_num_features = ada_cnn.size(0)
                else:
                    ada_num_features = ada_rnn.size(0)
                domain_label = torch.zeros(ada_num_features)
                domain_label = domain_label.long()
                if cuda:
                    domain_label = domain_label.cuda()
                domain_label = Variable(domain_label)

                if ada_before_rnn:
                    err_ada_cnn_t = loss_domain(ada_cnn, domain_label)
                if ada_after_rnn:
                    err_ada_rnn_t = loss_domain(ada_rnn, domain_label)

            if do_test_vat and do_at:
                # test part!
                if cur_vat >= vat_len:
                    vat_load = DataLoader(orig_vat_data,
                                          batch_size=batch_size,
                                          num_workers=4,
                                          shuffle=True,
                                          collate_fn=collate_vat)
                    vat_len = len(vat_load)
                    cur_vat = 0
                    vat_iter = iter(vat_load)
                test_vat_batch = next(vat_iter)
                cur_vat += 1
                test_vat_mask = test_vat_batch['mask']
                test_vat_imgs = Variable(test_vat_batch["img"])
                test_vat_img_seq_lens = test_vat_batch["im_seq_len"]
                if cuda:
                    test_vat_imgs = test_vat_imgs.cuda()
                    test_vat_mask = test_vat_mask.cuda()
                # train part
                at_test_vat_loss = LabeledAtAndUnlabeledTestVatLoss(
                    xi=vat_xi, eps=vat_epsilon, ip=vat_ip)

                at_loss, test_vat_loss = at_test_vat_loss(
                    model=net,
                    train_x=imgs,
                    train_labels_flatten=labels_flatten,
                    train_img_seq_lens=img_seq_lens,
                    train_label_lens=label_lens,
                    batch_size=batch_size,
                    test_x=test_vat_imgs,
                    test_seq_len=test_vat_img_seq_lens,
                    test_mask=test_vat_mask)
            elif do_test_vat or do_test_vat_rnn or do_test_vat_cnn:
                if cur_vat >= vat_len:
                    vat_load = DataLoader(orig_vat_data,
                                          batch_size=batch_size,
                                          num_workers=4,
                                          shuffle=True,
                                          collate_fn=collate_vat)
                    vat_len = len(vat_load)
                    cur_vat = 0
                    vat_iter = iter(vat_load)
                vat_batch = next(vat_iter)
                cur_vat += 1
                vat_mask = vat_batch['mask']
                vat_imgs = Variable(vat_batch["img"])
                vat_img_seq_lens = vat_batch["im_seq_len"]
                if cuda:
                    vat_imgs = vat_imgs.cuda()
                    vat_mask = vat_mask.cuda()
                if do_test_vat:
                    if do_test_vat_rnn or do_test_vat_cnn:
                        raise "can only do one of do_test_vat | (do_test_vat_rnn, do_test_vat_cnn)"
                    if vat_sign == True:
                        test_vat_loss = VATLossSign(
                            do_test_entropy=do_test_entropy,
                            xi=vat_xi,
                            eps=vat_epsilon,
                            ip=vat_ip)
                    else:
                        test_vat_loss = VATLoss(xi=vat_xi,
                                                eps=vat_epsilon,
                                                ip=vat_ip)
                elif do_test_vat_rnn and do_test_vat_cnn:
                    test_vat_loss = VATonRnnCnnSign(xi=vat_xi,
                                                    eps=vat_epsilon,
                                                    ip=vat_ip)
                elif do_test_vat_rnn:
                    test_vat_loss = VATonRnnSign(xi=vat_xi,
                                                 eps=vat_epsilon,
                                                 ip=vat_ip)
                elif do_test_vat_cnn:
                    test_vat_loss = VATonCnnSign(xi=vat_xi,
                                                 eps=vat_epsilon,
                                                 ip=vat_ip)
                if do_test_vat_cnn and do_test_vat_rnn:
                    test_vat_loss, cnn_lds, rnn_lds = test_vat_loss(
                        net, vat_imgs, vat_img_seq_lens, vat_mask)
                elif do_test_vat:
                    test_vat_loss = test_vat_loss(net, vat_imgs,
                                                  vat_img_seq_lens, vat_mask)
            elif do_vat:
                vat_loss = VATLoss(xi=vat_xi, eps=vat_epsilon, ip=vat_ip)
                vat_loss = vat_loss(net, imgs, img_seq_lens, mask)
            elif do_at:
                at_loss = LabeledATLoss(xi=vat_xi, eps=vat_epsilon, ip=vat_ip)
                at_loss = at_loss(net, imgs, labels_flatten, img_seq_lens,
                                  label_lens, batch_size)

            if ada_after_rnn or ada_before_rnn:
                preds, ada_cnn, ada_rnn = net(imgs,
                                              img_seq_lens,
                                              ada_alpha=ada_alpha,
                                              mask=mask)

                if ada_before_rnn:
                    ada_num_features = ada_cnn.size(0)
                else:
                    ada_num_features = ada_rnn.size(0)

                domain_label = torch.ones(ada_num_features)
                domain_label = domain_label.long()
                if cuda:
                    domain_label = domain_label.cuda()
                domain_label = Variable(domain_label)

                if ada_before_rnn:
                    err_ada_cnn_s = loss_domain(ada_cnn, domain_label)
                if ada_after_rnn:
                    err_ada_rnn_s = loss_domain(ada_rnn, domain_label)

            else:
                preds = net(imgs, img_seq_lens)
            '''
            if output_dir is not None:
                if (show_iter is not None and iter_num != 0 and iter_num % show_iter == 0):
                    print_data_visuals(net, tb_writer, train_data.get_lexicon(), sample["img"], labels_flatten, label_lens,
                                       preds, ((epoch_count * len(data_loader)) + iter_num))
            '''
            loss_ctc = loss_function(
                preds, labels_flatten,
                Variable(torch.IntTensor(np.array(img_seq_lens))),
                label_lens) / batch_size

            if loss_ctc.data[0] in [float("inf"), -float("inf")]:
                print("warnning: loss should not be inf.")
                continue
            total_loss = loss_ctc

            if do_vat:
                #mask = sample['mask']
                #if cuda:
                #    mask = mask.cuda()
                #vat_loss = virtual_adversarial_loss(net, imgs, img_seq_lens, mask, is_training=True, do_entropy=False, epsilon=vat_epsilon, num_power_iterations=1,
                #             xi=1e-6, average_loss=True)
                total_loss = total_loss + vat_ratio * vat_loss.cpu()
            if do_test_vat or do_test_vat_rnn or do_test_vat_cnn:
                total_loss = total_loss + test_vat_ratio * test_vat_loss.cpu()

            if ada_before_rnn:
                total_loss = total_loss + ada_ratio * err_ada_cnn_s.cpu(
                ) + ada_ratio * err_ada_cnn_t.cpu()
            if ada_after_rnn:
                total_loss = total_loss + ada_ratio * err_ada_rnn_s.cpu(
                ) + ada_ratio * err_ada_rnn_t.cpu()

            total_loss.backward()
            nn.utils.clip_grad_norm(net.parameters(), 10.0)
            if -400 < loss_ctc.data[0] < 400:
                loss_mean_ctc.append(loss_ctc.data[0])
            if -1000 < total_loss.data[0] < 1000:
                loss_mean_total.append(total_loss.data[0])
            if len(loss_mean_total) > 100:
                loss_mean_total = loss_mean_total[-100:]
            status = "epoch: {0:5d}; iter_num: {1:5d}; lr: {2:.2E}; loss_mean: {3:.3f}; loss: {4:.3f}".format(
                epoch_count, lr_scheduler.last_iter, lr_scheduler.get_lr(),
                np.mean(loss_mean_total), loss_ctc.data[0])
            if ada_after_rnn:
                loss_mean_ada_rnn_s.append(err_ada_rnn_s.data[0])
                loss_mean_ada_rnn_t.append(err_ada_rnn_t.data[0])
                status += "; ladatrnns: {0:.3f}; ladatrnnt: {1:.3f}".format(
                    err_ada_rnn_s.data[0], err_ada_rnn_t.data[0])
            if ada_before_rnn:
                loss_mean_ada_cnn_s.append(err_ada_cnn_s.data[0])
                loss_mean_ada_cnn_t.append(err_ada_cnn_t.data[0])
                status += "; ladatcnns: {0:.3f}; ladatcnnt: {1:.3f}".format(
                    err_ada_cnn_s.data[0], err_ada_cnn_t.data[0])
            if do_vat:
                loss_mean_vat.append(vat_loss.data[0])
                status += "; lvat: {0:.3f}".format(vat_loss.data[0])
            if do_at:
                loss_mean_at.append(at_loss.data[0])
                status += "; lat: {0:.3f}".format(at_loss.data[0])
            if do_test_vat:
                loss_mean_test_vat.append(test_vat_loss.data[0])
                status += "; l_tvat: {0:.3f}".format(test_vat_loss.data[0])
            if do_test_vat_rnn or do_test_vat_cnn:
                loss_mean_test_vat.append(test_vat_loss.data[0])
                if do_test_vat_rnn and do_test_vat_cnn:
                    status += "; l_tvatc: {}".format(cnn_lds.data[0])
                    status += "; l_tvatr: {}".format(rnn_lds.data[0])
                else:
                    status += "; l_tvat: {}".format(test_vat_loss.data[0])

            iterator.set_description(status)
            optimizer.step()
            if do_lr_step:
                lr_scheduler.step()
            if do_ema:
                net.udate_ema()
            iter_count += 1
        if output_dir is not None:
            torch.save(net.state_dict(),
                       os.path.join(output_dir, "crnn_" + backend + "_last"))
        epoch_count += 1

    return
def _collect_snapshots(patterns, fallback=None):
    """Return the sorted checkpoint paths matching any glob in *patterns*.

    ``glob.glob`` returns paths in arbitrary order. Because snapshot order
    decides tie-breaks in the per-character vote, the paths are sorted to make
    the ensemble reproducible. If nothing matches, ``[fallback]`` is returned
    when a fallback is given.

    Raises:
        FileNotFoundError: if no pattern matches and no fallback is given.
    """
    paths = sorted(path for pattern in patterns for path in glob.glob(pattern))
    if paths:
        return paths
    if fallback:
        return [fallback]
    raise FileNotFoundError(
        "no snapshots matched {} and no snapshot argument was given".format(list(patterns)))


def _vote_predictions(fold_preds, num_chars=10, fill_char='*'):
    """Combine per-snapshot predictions by a per-position majority vote.

    Args:
        fold_preds: one sequence of predicted strings per snapshot. The
            sequences are index-aligned over the same test samples.
        num_chars: number of character positions voted on. Longer predictions
            are truncated to this length.
        fill_char: character emitted at a position where no snapshot predicted
            anything (every prediction was shorter than that position).

    Returns:
        A list with one voted string of exactly *num_chars* characters per
        sample. Ties go to the character seen first, i.e. from the earliest
        snapshot in *fold_preds* (``Counter.most_common`` ordering).
    """
    voted = []
    for sample_preds in zip(*fold_preds):
        chars = []
        for pos in range(num_chars):
            candidates = [pred[pos] for pred in sample_preds if len(pred) > pos]
            if candidates:
                chars.append(Counter(candidates).most_common(1)[0][0])
            else:
                chars.append(fill_char)
        voted.append(''.join(chars))
    return voted


def main(data_path, abc, seq_proj, backend, snapshot, input_size, gpu, visualize):
    """Run TTA inference with every saved fold checkpoint and write a submission.

    The function evaluates each checkpoint matching ``./tmp/fold*_best`` or
    ``./tmp2/fold*_best`` with ``test_tta``. If no checkpoint matches, the
    *snapshot* argument is used instead. When more than one checkpoint was
    evaluated:

    * the raw per-snapshot predictions are dumped to ``fold_tta.pkl``;
    * the predictions are merged per character by majority vote
      (see ``_vote_predictions``).

    Output is written to ``tmp_rcnn_tta10_pb.csv``. Its ``name`` column comes
    from the ``pb`` entries of ``../data/desc.json``.

    Args:
        data_path: dataset root for ``TextDataset``. If ``None``, a
            ``TestDataset`` built from *abc* is used instead.
        abc: alphabet for ``TestDataset`` (only used when *data_path* is None).
        seq_proj: ``"AxB"`` string parsed into ints for ``load_model``.
        backend: backbone name passed to ``load_model``.
        snapshot: checkpoint used only when no fold checkpoint is found.
        input_size: ``"AxB"`` string. Its first two ints form the
            ``Resize(size=...)`` argument.
        gpu: value for ``CUDA_VISIBLE_DEVICES``. An empty string disables CUDA.
        visualize: forwarded to ``test_tta``.

    Raises:
        FileNotFoundError: if no fold checkpoint exists and *snapshot* is empty.
    """
    # Honour the requested device instead of a hard-coded GPU index.
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    cuda = gpu != ''

    input_size = [int(x) for x in input_size.split('x')]
    seq_proj = [int(x) for x in seq_proj.split('x')]

    snapshots = _collect_snapshots(['./tmp/fold*_best', './tmp2/fold*_best'],
                                   fallback=snapshot)
    print(snapshots)

    # Every snapshot is evaluated with the same augmentation pipeline.
    transform = Compose([
        Translation(),
        Contrast(),
        Resize(size=(input_size[0], input_size[1]))
    ])

    fold_pred_pb_tta = []
    for snapshot_path in snapshots:
        if data_path is not None:
            data = TextDataset(data_path=data_path, mode="pb", transform=transform)
        else:
            data = TestDataset(transform=transform, abc=abc)
        print(snapshot_path)

        net = load_model(data.get_abc(), seq_proj, backend, snapshot_path, cuda).eval()
        acc, avg_ed, pred_pb = test_tta(net, data, data.get_abc(), cuda, visualize)
        fold_pred_pb_tta.append(pred_pb)

    with open('../data/desc.json') as up:
        data_json = json.load(up)
    names = [x['name'] for x in data_json['pb']]

    if len(fold_pred_pb_tta) > 1:
        joblib.dump(fold_pred_pb_tta, 'fold_tta.pkl')
        labels = _vote_predictions(fold_pred_pb_tta)
    else:
        labels = fold_pred_pb_tta[0]

    df_submit = pd.DataFrame()
    df_submit['name'] = names
    df_submit['label'] = labels
    df_submit.to_csv('tmp_rcnn_tta10_pb.csv', index=None)

    # NOTE: these metrics belong to the last evaluated snapshot, not the ensemble.
    print("Accuracy: {}".format(acc))
    print("Edit distance: {}".format(avg_ed))