예제 #1
0
def main(base_data_dir, train_data_path, train_base_dir, orig_eval_data_path,
         orig_eval_base_dir, synth_eval_data_path, synth_eval_base_dir,
         lexicon_path, seq_proj, backend, snapshot, input_height, base_lr,
         elastic_alpha, elastic_sigma, step_size, max_iter, batch_size,
         output_dir, test_iter, show_iter, test_init, use_gpu,
         use_no_font_repeat_data, do_vat, do_at, vat_ratio, test_vat_ratio,
         vat_epsilon, vat_ip, vat_xi, vat_sign, do_comp, comp_ratio,
         do_remove_augs, aug_to_remove, do_beam_search, dropout_conv,
         dropout_rnn, dropout_output, do_ema, do_gray, do_test_vat,
         do_test_entropy, do_test_vat_cnn, do_test_vat_rnn, do_test_rand,
         ada_after_rnn, ada_before_rnn, do_ada_lr, ada_ratio, rnn_hidden_size,
         do_test_pseudo, test_pseudo_ratio, test_pseudo_thresh, do_lr_step,
         do_test_ensemble, test_ensemble_ratio, test_ensemble_thresh):
    """Train an ensemble of CRNN-style OCR models with CTC loss.

    Builds augmented train/eval datasets, creates `num_nets` models (one per
    GPU index), then loops over epochs indefinitely: periodically evaluates
    on a synthetic and an original eval set, optionally adds a
    "test-ensemble" pseudo-label loss computed from the majority vote of the
    nets' predictions on unlabeled eval data, and checkpoints each net under
    `output_dir`.

    NOTE(review): the `while True` loop has no break, so the trailing
    `return` is unreachable; training only stops externally.  Many
    parameters (do_vat, do_at, vat_*, do_test_vat*, do_ada_lr, show_iter,
    do_remove_augs, ...) are accepted but never read in this body —
    presumably consumed by sibling variants of this trainer; confirm before
    relying on them here.
    """
    # One network per GPU index: net i is placed on cuda:i below.
    num_nets = 4

    # All data paths are interpreted relative to base_data_dir.
    train_data_path = os.path.join(base_data_dir, train_data_path)
    train_base_dir = os.path.join(base_data_dir, train_base_dir)
    synth_eval_data_path = os.path.join(base_data_dir, synth_eval_data_path)
    synth_eval_base_dir = os.path.join(base_data_dir, synth_eval_base_dir)

    orig_eval_data_path = os.path.join(base_data_dir, orig_eval_data_path)
    orig_eval_base_dir = os.path.join(base_data_dir, orig_eval_base_dir)
    lexicon_path = os.path.join(base_data_dir, lexicon_path)

    # Snapshot of every argument (plus num_nets) for logging to params.txt.
    all_parameters = locals()
    cuda = use_gpu
    #print(train_base_dir)
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        tb_writer = TbSummary(output_dir)
        output_dir = os.path.join(output_dir, 'model')
        os.makedirs(output_dir, exist_ok=True)

    # Character lexicon: mapping used by the datasets/decoder (pickled dict).
    with open(lexicon_path, 'rb') as f:
        lexicon = pkl.load(f)
    #print(sorted(lexicon.items(), key=operator.itemgetter(1)))

    # NOTE(review): executed unconditionally — raises TypeError if
    # output_dir is None (os.path.join(None, ...)); the guard above only
    # covers directory creation.
    with open(os.path.join(output_dir, 'params.txt'), 'w') as f:
        f.writelines(str(all_parameters))
    print(all_parameters)
    print('new vat')

    # Augmentation constants; train_fonts limits synthetic rendering fonts.
    sin_magnitude = 4
    rotate_max_angle = 2
    train_fonts = [
        'Qomolangma-Betsu', 'Shangshung Sgoba-KhraChen',
        'Shangshung Sgoba-KhraChung', 'Qomolangma-Drutsa'
    ]

    # NOTE(review): captured but never used in this body.
    all_args = locals()

    print('doing all transforms :)')
    # Random (train-time) augmentation pipeline: elastic + sine warp,
    # small rotation, color/gradient/gaussian noise, then resize/normalize.
    rand_trans = [
        ElasticAndSine(elastic_alpha=elastic_alpha,
                       elastic_sigma=elastic_sigma,
                       sin_magnitude=sin_magnitude),
        Rotation(angle=rotate_max_angle, fill_value=255),
        ColorGradGausNoise()
    ]
    if do_gray:
        rand_trans = rand_trans + [
            Resize(hight=input_height),
            AddWidth(),
            ToGray(),
            Normalize()
        ]
    else:
        rand_trans = rand_trans + [
            Resize(hight=input_height),
            AddWidth(), Normalize()
        ]

    transform_random = Compose(rand_trans)
    # Deterministic (eval-time) pipeline: no augmentation, only
    # resize/width/normalize (and grayscale when do_gray).
    if do_gray:
        transform_simple = Compose(
            [Resize(hight=input_height),
             AddWidth(),
             ToGray(),
             Normalize()])
    else:
        transform_simple = Compose(
            [Resize(hight=input_height),
             AddWidth(), Normalize()])

    if use_no_font_repeat_data:
        print('create dataset')
        train_data = TextDatasetRandomFont(data_path=train_data_path,
                                           lexicon=lexicon,
                                           base_path=train_base_dir,
                                           transform=transform_random,
                                           fonts=train_fonts)
        print('finished creating dataset')
    else:
        print('train data path:\n{}'.format(train_data_path))
        print('train_base_dir:\n{}'.format(train_base_dir))
        train_data = TextDataset(data_path=train_data_path,
                                 lexicon=lexicon,
                                 base_path=train_base_dir,
                                 transform=transform_random,
                                 fonts=train_fonts)
    # Synthetic eval set uses the *random* transform (augmented eval);
    # original eval set uses the simple transform.
    synth_eval_data = TextDataset(data_path=synth_eval_data_path,
                                  lexicon=lexicon,
                                  base_path=synth_eval_base_dir,
                                  transform=transform_random,
                                  fonts=train_fonts)
    orig_eval_data = TextDataset(data_path=orig_eval_data_path,
                                 lexicon=lexicon,
                                 base_path=orig_eval_base_dir,
                                 transform=transform_simple,
                                 fonts=None)
    if do_test_ensemble:
        # Separate (unlabeled-use) copy of the original eval data that
        # feeds the ensemble pseudo-label loss.
        orig_vat_data = TextDataset(data_path=orig_eval_data_path,
                                    lexicon=lexicon,
                                    base_path=orig_eval_base_dir,
                                    transform=transform_simple,
                                    fonts=None)

    #else:
    #    train_data = TestDataset(transform=transform, abc=abc).set_mode("train")
    #    synth_eval_data = TestDataset(transform=transform, abc=abc).set_mode("test")
    #    orig_eval_data = TestDataset(transform=transform, abc=abc).set_mode("test")
    # seq_proj arrives as a "WxH"-style string, e.g. "10x20".
    seq_proj = [int(x) for x in seq_proj.split('x')]
    nets = []
    optimizers = []
    lr_schedulers = []
    # One (net, Adam optimizer, StepLR scheduler) triple per GPU index.
    for neti in range(num_nets):
        nets.append(
            load_model(lexicon=train_data.get_lexicon(),
                       seq_proj=seq_proj,
                       backend=backend,
                       snapshot=snapshot,
                       cuda=cuda,
                       do_beam_search=do_beam_search,
                       dropout_conv=dropout_conv,
                       dropout_rnn=dropout_rnn,
                       dropout_output=dropout_output,
                       do_ema=do_ema,
                       ada_after_rnn=ada_after_rnn,
                       ada_before_rnn=ada_before_rnn,
                       rnn_hidden_size=rnn_hidden_size,
                       gpu=neti))
        optimizers.append(
            optim.Adam(nets[neti].parameters(),
                       lr=base_lr,
                       weight_decay=0.0001))
        lr_schedulers.append(
            StepLR(optimizers[neti], step_size=step_size, max_iter=max_iter))
    loss_function = CTCLoss()

    synth_avg_ed_best = float("inf")
    orig_avg_ed_best = float("inf")
    epoch_count = 0

    if do_test_ensemble:
        # Masked collate so per-timestep predictions can be zeroed past the
        # true sequence length.
        collate_vat = lambda x: text_collate(x, do_mask=True)
        vat_load = DataLoader(orig_vat_data,
                              batch_size=batch_size,
                              num_workers=4,
                              shuffle=True,
                              collate_fn=collate_vat)
        vat_len = len(vat_load)
        cur_vat = 0
        vat_iter = iter(vat_load)

    # NOTE(review): loss_domain is never used in this body.
    loss_domain = torch.nn.NLLLoss()

    while True:
        # Masks are only needed when a per-timestep auxiliary loss is on.
        collate = lambda x: text_collate(
            x, do_mask=(do_vat or ada_before_rnn or ada_after_rnn))
        data_loader = DataLoader(train_data,
                                 batch_size=batch_size,
                                 num_workers=4,
                                 shuffle=True,
                                 collate_fn=collate)
        if do_comp:
            # NOTE(review): train_data_comp and collate_comp are not defined
            # anywhere in this function — do_comp=True raises NameError here.
            data_loader_comp = DataLoader(train_data_comp,
                                          batch_size=batch_size,
                                          num_workers=4,
                                          shuffle=True,
                                          collate_fn=collate_comp)
            iter_comp = iter(data_loader_comp)

        loss_mean_ctc = []
        loss_mean_total = []
        loss_mean_test_ensemble = []
        # NOTE(review): reported to TensorBoard below but never incremented —
        # presumably `num_labels_used` was meant to be accumulated into it.
        num_labels_used_total = 0
        iterator = tqdm(data_loader)
        nll_loss = torch.nn.NLLLoss()
        iter_count = 0
        for iter_num, sample in enumerate(iterator):
            total_iter = (epoch_count * len(data_loader)) + iter_num
            # Run the eval phase every test_iter steps (and optionally at
            # step 0 when test_init is set).
            if ((total_iter > 1)
                    and total_iter % test_iter == 0) or (test_init
                                                         and total_iter == 0):
                # epoch_count != 0 and

                print("Test phase")
                for net in nets:
                    net = net.eval()
                    if do_ema:
                        # Swap in EMA weights for evaluation.
                        net.start_test()

                synth_acc, synth_avg_ed, synth_avg_no_stop_ed, synth_avg_loss = test(
                    nets,
                    synth_eval_data,
                    synth_eval_data.get_lexicon(),
                    cuda,
                    batch_size=batch_size,
                    visualize=False,
                    tb_writer=tb_writer,
                    n_iter=total_iter,
                    initial_title='val_synth',
                    loss_function=loss_function,
                    output_path=os.path.join(output_dir, 'results'),
                    do_beam_search=False)

                orig_acc, orig_avg_ed, orig_avg_no_stop_ed, orig_avg_loss = test(
                    nets,
                    orig_eval_data,
                    orig_eval_data.get_lexicon(),
                    cuda,
                    batch_size=batch_size,
                    visualize=False,
                    tb_writer=tb_writer,
                    n_iter=total_iter,
                    initial_title='test_orig',
                    loss_function=loss_function,
                    output_path=os.path.join(output_dir, 'results'),
                    do_beam_search=do_beam_search)

                for net in nets:
                    net = net.train()
                #save periodic
                # NOTE(review): `total_iter // 30000` is truthy for EVERY
                # total_iter >= 30000, so this saves on every test phase past
                # 30k iterations — likely meant `total_iter % 30000 == 0`.
                # Also `old_save` is collected but never used (dead code).
                if output_dir is not None and total_iter // 30000:
                    periodic_save = os.path.join(output_dir, 'periodic_save')
                    os.makedirs(periodic_save, exist_ok=True)
                    old_save = glob.glob(os.path.join(periodic_save, '*'))
                    for neti, net in enumerate(nets):
                        torch.save(
                            net.state_dict(),
                            os.path.join(
                                output_dir, "crnn_{}_".format(neti) + backend +
                                "_" + str(total_iter)))

                if orig_avg_no_stop_ed < orig_avg_ed_best:
                    orig_avg_ed_best = orig_avg_no_stop_ed
                # Checkpoint every net after each test phase.
                if output_dir is not None:
                    for neti, net in enumerate(nets):
                        torch.save(
                            net.state_dict(),
                            os.path.join(
                                output_dir, "crnn_{}_".format(neti) + backend +
                                "_iter_{}".format(total_iter)))

                if synth_avg_no_stop_ed < synth_avg_ed_best:
                    synth_avg_ed_best = synth_avg_no_stop_ed
                if do_ema:
                    # Restore the live (non-EMA) weights for training.
                    for net in nets:
                        net.end_test()
                print(
                    "synth: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}"
                    .format(synth_avg_ed_best, synth_avg_ed,
                            synth_avg_no_stop_ed, synth_acc))
                print(
                    "orig: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}"
                    .format(orig_avg_ed_best, orig_avg_ed, orig_avg_no_stop_ed,
                            orig_acc))
                tb_writer.get_writer().add_scalars(
                    'data/test', {
                        'synth_ed_total': synth_avg_ed,
                        'synth_ed_no_stop': synth_avg_no_stop_ed,
                        'synth_avg_loss': synth_avg_loss,
                        'orig_ed_total': orig_avg_ed,
                        'orig_ed_no_stop': orig_avg_no_stop_ed,
                        'orig_avg_loss': orig_avg_loss
                    }, total_iter)
                if len(loss_mean_ctc) > 0:
                    train_dict = {'mean_ctc_loss': np.mean(loss_mean_ctc)}
                    train_dict = {
                        **train_dict,
                        **{
                            'mean_test_ensemble_loss':
                            np.mean(loss_mean_test_ensemble)
                        }
                    }
                    train_dict = {
                        **train_dict,
                        **{
                            'num_labels_used': num_labels_used_total
                        }
                    }
                    num_labels_used_total = 0
                    print(train_dict)
                    tb_writer.get_writer().add_scalars('data/train',
                                                       train_dict, total_iter)
            '''
            # for multi-gpu support
            if sample["img"].size(0) % len(gpu.split(',')) != 0:
                continue
            '''
            for optimizer in optimizers:
                optimizer.zero_grad()
            imgs = Variable(sample["img"])
            #print("images sizes are:")
            #print(sample["img"].shape)
            if do_vat or ada_after_rnn or ada_before_rnn:
                mask = sample['mask']
            labels_flatten = Variable(sample["seq"]).view(-1)
            label_lens = Variable(sample["seq_len"].int())
            #print("image sequence length is:")
            #print(sample["im_seq_len"])
            #print("label sequence length is:")
            #print(sample["seq_len"].view(1,-1))
            img_seq_lens = sample["im_seq_len"]

            if do_test_ensemble:
                # Cycle the unlabeled loader: rebuild it when exhausted.
                if cur_vat >= vat_len:
                    vat_load = DataLoader(orig_vat_data,
                                          batch_size=batch_size,
                                          num_workers=4,
                                          shuffle=True,
                                          collate_fn=collate_vat)
                    vat_len = len(vat_load)
                    cur_vat = 0
                    vat_iter = iter(vat_load)
                vat_batch = next(vat_iter)
                cur_vat += 1
                vat_mask = vat_batch['mask']
                vat_imgs = Variable(vat_batch["img"])
                vat_img_seq_lens = vat_batch["im_seq_len"]
                all_net_classes = []
                all_net_preds = []

                # NOTE(review): defined but never called — the equivalent
                # logic is inlined in the loop just below; dead code.
                def run_net_get_classes(neti_net_pair, cur_vat_imgs,
                                        cur_vat_mask, cur_vat_img_seq_lens,
                                        cuda):
                    neti, net = neti_net_pair
                    if cuda:
                        cur_vat_imgs = cur_vat_imgs.cuda(neti)
                        cur_vat_mask = cur_vat_mask.cuda(neti)
                    vat_pred = net.vat_forward(cur_vat_imgs,
                                               cur_vat_img_seq_lens)
                    vat_pred = vat_pred * cur_vat_mask
                    vat_pred = F.softmax(vat_pred,
                                         dim=2).view(-1,
                                                     vat_pred.size()[-1])
                    all_net_preds.append(vat_pred)
                    np_vat_preds = vat_pred.cpu().data.numpy()
                    classes_by_index = np.argmax(np_vat_preds, axis=1)
                    return classes_by_index

                # Collect each net's per-timestep class predictions on the
                # unlabeled batch (softmax over masked logits).
                for neti, net in enumerate(nets):
                    if cuda:
                        vat_imgs = vat_imgs.cuda(neti)
                        vat_mask = vat_mask.cuda(neti)
                    vat_pred = net.vat_forward(vat_imgs, vat_img_seq_lens)
                    vat_pred = vat_pred * vat_mask
                    vat_pred = F.softmax(vat_pred,
                                         dim=2).view(-1,
                                                     vat_pred.size()[-1])
                    all_net_preds.append(vat_pred)
                    np_vat_preds = vat_pred.cpu().data.numpy()
                    classes_by_index = np.argmax(np_vat_preds, axis=1)
                    all_net_classes.append(classes_by_index)
                all_net_classes = np.stack(all_net_classes)
                # Majority vote across nets; keep only positions where the
                # agreement count exceeds test_ensemble_thresh.
                all_net_classes, all_nets_count = stats.mode(all_net_classes,
                                                             axis=0)
                all_net_classes = all_net_classes.reshape(-1)
                all_nets_count = all_nets_count.reshape(-1)
                ens_indices = np.argwhere(
                    all_nets_count > test_ensemble_thresh)
                ens_indices = ens_indices.reshape(-1)
                ens_classes = all_net_classes[
                    all_nets_count > test_ensemble_thresh]
                net_ens_losses = []
                num_labels_used = len(ens_indices)
                # NLL of each net's own predictions against the voted
                # pseudo-labels at the agreed positions.
                for neti, net in enumerate(nets):
                    indices = Variable(
                        torch.from_numpy(ens_indices).cuda(neti))
                    labels = Variable(torch.from_numpy(ens_classes).cuda(neti))
                    net_preds_to_ens = all_net_preds[neti][indices]
                    loss = nll_loss(net_preds_to_ens, labels)
                    net_ens_losses.append(loss.cpu())
            nets_total_losses = []
            nets_ctc_losses = []
            loss_is_inf = False
            for neti, net in enumerate(nets):
                if cuda:
                    imgs = imgs.cuda(neti)
                preds = net(imgs, img_seq_lens)
                loss_ctc = loss_function(
                    preds, labels_flatten,
                    Variable(torch.IntTensor(np.array(img_seq_lens))),
                    label_lens) / batch_size

                # .data[0] is legacy PyTorch<0.4 scalar access (now .item()).
                if loss_ctc.data[0] in [float("inf"), -float("inf")]:
                    print("warnning: loss should not be inf.")
                    loss_is_inf = True
                    break
                total_loss = loss_ctc

                if do_test_ensemble:
                    total_loss = total_loss + test_ensemble_ratio * net_ens_losses[
                        neti]
                    net_ens_losses[neti] = net_ens_losses[neti].data[0]
                total_loss.backward()
                nets_total_losses.append(total_loss.data[0])
                nets_ctc_losses.append(loss_ctc.data[0])
                # clip_grad_norm is the pre-0.4 (deprecated) spelling of
                # clip_grad_norm_.
                nn.utils.clip_grad_norm(net.parameters(), 10.0)
            if loss_is_inf:
                continue
            # NOTE(review): these bounds check only the LAST net's losses
            # (loop variables leak out of the for) while appending the mean
            # over all nets — confirm that is intended.
            if -400 < loss_ctc.data[0] < 400:
                loss_mean_ctc.append(np.mean(nets_ctc_losses))
            if -400 < total_loss.data[0] < 400:
                loss_mean_total.append(np.mean(nets_total_losses))
            status = "epoch: {0:5d}; iter_num: {1:5d}; lr: {2:.2E}; loss_mean: {3:.3f}; loss: {4:.3f}".format(
                epoch_count, lr_schedulers[0].last_iter,
                lr_schedulers[0].get_lr(), np.mean(nets_total_losses),
                np.mean(nets_ctc_losses))

            if do_test_ensemble:
                ens_loss = np.mean(net_ens_losses)
                if ens_loss != 0:
                    loss_mean_test_ensemble.append(ens_loss)
                    status += "; loss_ens: {0:.3f}".format(ens_loss)
                    status += "; num_ens_used {}".format(num_labels_used)
                else:
                    loss_mean_test_ensemble.append(0)
                    status += "; loss_ens: {}".format(0)
            iterator.set_description(status)
            for optimizer in optimizers:
                optimizer.step()
            if do_lr_step:
                for lr_scheduler in lr_schedulers:
                    lr_scheduler.step()
            iter_count += 1
        # End-of-epoch checkpoint ("_last") for every net.
        if output_dir is not None:
            for neti, net in enumerate(nets):
                torch.save(
                    net.state_dict(),
                    os.path.join(output_dir,
                                 "crnn_{}_".format(neti) + backend + "_last"))
        epoch_count += 1

    # NOTE(review): unreachable — the while-loop above never breaks.
    return
예제 #2
0
def test_attn(net,
              data,
              abc,
              cuda,
              visualize,
              batch_size=1,
              tb_writer=None,
              n_iter=0,
              initial_title="",
              is_trian=True,
              output_path=None):
    """Evaluate an attention-decoder OCR net on `data`.

    Runs the net with teacher forcing disabled, accumulates per-step NLL
    loss and edit-distance statistics, optionally visualizes predictions
    with OpenCV, and — when `output_path` is given — writes prediction /
    label / image-path text files plus a letter-statistics pickle.

    Fixes vs. the original:
      * the three result lists were chained-assigned to ONE shared list;
      * the statistics pickle was written even when output_path was None
        (TypeError); it is now guarded;
      * the progress format string printed accuracy twice ({0} used for
        both placeholders);
      * the inner per-line loop no longer shadows the outer sample index.

    Args:
        net: model exposing __call__, predict() and padded_seq_to_txt().
        data: dataset whose samples carry img/mask/padded_seq/seq/seq_len/
            im_seq_len/im_path keys (see text_collate).
        abc: lexicon — unused here, kept for signature compatibility.
        cuda: move tensors to the GPU when True.
        visualize: show each line image and wait for a key; 'q' aborts.
        batch_size: unused (the loader is hard-coded to batch_size=1);
            kept for signature compatibility.
        tb_writer: optional TensorBoard wrapper for the first batch's images.
        n_iter: global step, used in file names and TB tags.
        initial_title: tag prefix for TB and output files.
        is_trian: (sic) when True, stop after ~1000 samples.
        output_path: prefix for result files; None disables file output.

    Returns:
        (acc, avg_ed, avg_no_stop_ed, avg_loss) averaged over lines seen.
    """
    collate = lambda x: text_collate(x, do_mask=True)
    net.eval()
    data_loader = DataLoader(data,
                             batch_size=1,
                             num_workers=2,
                             shuffle=False,
                             collate_fn=collate)
    # Tibetan punctuation / separators excluded from the "no-stop" metric.
    stop_characters = ['-', '.', '༎', '༑', '།', '་']
    count = 0
    tp = 0
    avg_ed = 0
    avg_no_stop_ed = 0
    avg_loss = 0
    min_ed = 1000
    iterator = tqdm(data_loader)
    # Three independent lists (the original aliased them to one object).
    all_pred_text = []
    all_label_text = []
    all_im_pathes = []
    test_letter_statistics = Statistics()
    with torch.no_grad():
        for sample_idx, sample in enumerate(iterator):
            if is_trian and (sample_idx > 1000):
                break
            imgs = Variable(sample["img"])
            mask = sample["mask"]
            padded_labels = sample["padded_seq"]
            if cuda:
                imgs = imgs.cuda()
                mask = mask.cuda()
                padded_labels = padded_labels.cuda()

            img_seq_lens = sample["im_seq_len"]

            # Forward pass, no teacher forcing.
            decoder_outputs, decoder_hidden, other = net(
                imgs, img_seq_lens, mask, None, teacher_forcing_ratio=0)

            # Per-step NLL loss; steps past the label length are scored
            # against the zero (padding) label.
            loss = NLLLoss()
            loss.reset()
            zero_labels = torch.zeros_like(padded_labels[:, 1])
            max_label_size = padded_labels.size(1)
            for step, step_output in enumerate(decoder_outputs):
                batch_size = padded_labels.size(0)
                flat_step = step_output.contiguous().view(batch_size, -1)
                if (step + 1) < max_label_size:
                    loss.eval_batch(flat_step, padded_labels[:, step + 1])
                else:
                    loss.eval_batch(flat_step, zero_labels)
            # .data[0] is legacy PyTorch<0.4 scalar access, as elsewhere in
            # this file.
            avg_loss += loss.get_loss().data[0]

            labels_flatten = Variable(sample["seq"]).view(-1)
            label_lens = Variable(sample["seq_len"].int())
            preds_text = net.predict(other)
            padded_labels = (sample["padded_seq"].numpy()).tolist()
            lens = sample["seq_len"].numpy().tolist()
            label_text = net.padded_seq_to_txt(padded_labels, lens)
            if output_path is not None:
                all_pred_text += [pd + '\n' for pd in preds_text]
                all_label_text += [lb + '\n' for lb in label_text]
                all_im_pathes.append(sample["im_path"] + '\n')

            if sample_idx == 0 and tb_writer is not None:
                tb_writer.show_images(
                    sample["img"],
                    label_text=[lb + '\n' for lb in label_text],
                    pred_text=[pd + '\n' for pd in preds_text],
                    n_iter=n_iter,
                    initial_title=initial_title)

            key = ''
            for line_idx in range(len(label_text)):
                cur_out_no_stops = ''.join(c for c in label_text[line_idx]
                                           if not c in stop_characters)
                cur_gts_no_stops = ''.join(c for c in preds_text[line_idx]
                                           if not c in stop_characters)
                cur_ed = editdistance.eval(
                    preds_text[line_idx], label_text[line_idx]) / max(
                        len(preds_text[line_idx]), len(label_text[line_idx]))
                errors, matches, bp = my_edit_distance_backpointer(
                    cur_out_no_stops, cur_gts_no_stops)
                test_letter_statistics.add_data(bp)
                # Cross-check the backpointer distance against editdistance.
                my_no_stop_ed = errors / max(len(cur_out_no_stops),
                                             len(cur_gts_no_stops))
                cur_no_stop_ed = editdistance.eval(
                    cur_out_no_stops, cur_gts_no_stops) / max(
                        len(cur_out_no_stops), len(cur_gts_no_stops))
                if my_no_stop_ed != cur_no_stop_ed:
                    print('old ed: {} , vs. new ed: {}\n'.format(
                        my_no_stop_ed, cur_no_stop_ed))

                avg_no_stop_ed += cur_no_stop_ed
                avg_ed += cur_ed
                if cur_ed < min_ed:
                    min_ed = cur_ed
                count += 1

                if visualize:
                    status = "pred: {}; gt: {}".format(preds_text[line_idx],
                                                       label_text[line_idx])
                    iterator.set_description(status)
                    img = imgs[line_idx].permute(
                        1, 2, 0).cpu().data.numpy().astype(np.uint8)
                    cv2.imshow("img", img)
                    key = chr(cv2.waitKey() & 255)
                    if key == 'q':
                        break
            if key == 'q':
                break
            if not visualize:
                # Second placeholder fixed to {1} (original repeated {0}).
                iterator.set_description(
                    "acc: {0:.4f}; avg_ed: {1:.4f}".format(
                        float(tp) / float(count),
                        float(avg_ed) / float(count)))

    if output_path is not None:
        # Guarded: the original wrote this pickle unconditionally and
        # crashed with TypeError when output_path was None.
        with open(
                output_path +
                '_{}_{}_statistics.pkl'.format(initial_title, n_iter),
                'wb') as sf:
            pkl.dump(test_letter_statistics.total_actions_hists, sf)

        os.makedirs(output_path, exist_ok=True)
        print('writing output')
        with open(
                output_path + '_{}_{}_pred.txt'.format(initial_title, n_iter),
                'w') as fp:
            fp.writelines(all_pred_text)
        with open(
                output_path + '_{}_{}_label.txt'.format(initial_title, n_iter),
                'w') as fp:
            fp.writelines(all_label_text)
        with open(output_path + '_{}_{}_im.txt'.format(initial_title, n_iter),
                  'w') as fp:
            fp.writelines(all_im_pathes)

        # Second copies with stop characters stripped.
        all_pred_text = [
            ''.join(c for c in line if not c in stop_characters)
            for line in all_pred_text
        ]
        with open(
                output_path +
                '_{}_{}_pred_no_stopchars.txt'.format(initial_title, n_iter),
                'w') as rf:
            rf.writelines(all_pred_text)
        all_label_text = [
            ''.join(c for c in line if not c in stop_characters)
            for line in all_label_text
        ]
        with open(
                output_path +
                '_{}_{}_label_no_stopchars.txt'.format(initial_title, n_iter),
                'w') as rf:
            rf.writelines(all_label_text)

    # NOTE: raises ZeroDivisionError on an empty dataset, as the original did.
    acc = float(tp) / float(count)
    avg_ed = float(avg_ed) / float(count)
    avg_no_stop_ed = float(avg_no_stop_ed) / float(count)
    avg_loss = float(avg_loss) / float(count)
    return acc, avg_ed, avg_no_stop_ed, avg_loss
예제 #3
0
def test(net,
         data,
         abc,
         cuda,
         visualize,
         batch_size=1,
         tb_writer=None,
         n_iter=0,
         initial_title="",
         loss_function=None,
         is_trian=True,
         output_path=None,
         do_beam_search=False,
         do_results=False,
         word_lexicon=None):
    collate = lambda x: text_collate(x, do_mask=False)
    data_loader = DataLoader(data,
                             batch_size=1,
                             num_workers=2,
                             shuffle=False,
                             collate_fn=collate)
    stop_characters = ['-', '.', '༎', '༑', '།', '་']
    garbage = '-'
    count = 0
    tp = 0
    avg_ed = 0
    avg_no_stop_ed = 0
    avg_accuracy = 0
    avg_loss = 0
    min_ed = 1000
    iterator = tqdm(data_loader)
    all_pred_text = all_label_text = all_im_pathes = []
    test_letter_statistics = Statistics()
    im_by_error = {}

    for i, sample in enumerate(iterator):
        if is_trian and (i > 500):
            break
        imgs = Variable(sample["img"])
        if cuda:
            imgs = imgs.cuda()
        img_seq_lens = sample["im_seq_len"]
        out, orig_seq = net(imgs,
                            img_seq_lens,
                            decode=True,
                            do_beam_search=do_beam_search)
        if loss_function is not None:
            labels_flatten = Variable(sample["seq"]).view(-1)
            label_lens = Variable(sample["seq_len"].int())
            loss = loss_function(
                orig_seq, labels_flatten,
                Variable(torch.IntTensor(np.array(img_seq_lens))),
                label_lens) / batch_size
            avg_loss += loss.data[0]
        gt = (sample["seq"].numpy()).tolist()
        lens = sample["seq_len"].numpy().tolist()
        labels_flatten = Variable(sample["seq"]).view(-1)
        label_lens = Variable(sample["seq_len"].int())
        if output_path is not None:
            preds_text = net.decode(orig_seq, data.get_lexicon())
            all_pred_text = all_pred_text + [
                ''.join(c for c in pd if c != garbage) + '\n'
                for pd in preds_text
            ]

            label_text = net.decode_flatten(labels_flatten, label_lens,
                                            data.get_lexicon())
            all_label_text = all_label_text + [lb + '\n' for lb in label_text]

            all_im_pathes.append(
                sample["im_path"] +
                '\n')  #[imp +'\n' for imp in sample["im_path"]]

        if i == 0:
            if tb_writer is not None:
                print_data_visuals(net, tb_writer, data.get_lexicon(),
                                   sample["img"], labels_flatten, label_lens,
                                   orig_seq, n_iter, initial_title)

        pos = 0
        key = ''
        for i in range(len(out)):
            gts = ''.join(abc[c] for c in gt[pos:pos + lens[i]])

            pos += lens[i]
            if gts == out[i]:
                tp += 1
            else:
                cur_out = ''.join(c for c in out[i] if c != garbage)
                cur_gts = ''.join(c for c in gts if c != garbage)
                cur_out_no_stops = ''.join(c for c in out[i]
                                           if not c in stop_characters)
                cur_gts_no_stops = ''.join(c for c in gts
                                           if not c in stop_characters)
                cur_ed = editdistance.eval(cur_out, cur_gts) / len(cur_gts)
                if word_lexicon is not None:
                    closest_word = get_close_matches(cur_out,
                                                     word_lexicon,
                                                     n=1,
                                                     cutoff=0.2)
                else:
                    closest_word = cur_out

                if len(closest_word) > 0 and closest_word[0] == cur_gts:
                    avg_accuracy += 1

                errors, matches, bp = my_edit_distance_backpointer(
                    cur_out_no_stops, cur_gts_no_stops)
                test_letter_statistics.add_data(bp)
                #my_no_stop_ed = errors / max(len(cur_out_no_stops), len(cur_gts_no_stops))
                #cur_no_stop_ed = editdistance.eval(cur_out_no_stops, cur_gts_no_stops) / max(len(cur_out_no_stops), len(cur_gts_no_stops))
                if do_results:
                    im_by_error[sample["im_path"]] = cur_ed
                my_no_stop_ed = errors / len(cur_gts_no_stops)
                cur_no_stop_ed = editdistance.eval(
                    cur_out_no_stops, cur_gts_no_stops) / len(cur_gts_no_stops)

                if my_no_stop_ed != cur_no_stop_ed:
                    print('old ed: {} , vs. new ed: {}\n'.format(
                        my_no_stop_ed, cur_no_stop_ed))
                avg_no_stop_ed += cur_no_stop_ed
                avg_ed += cur_ed
                if cur_ed < min_ed: min_ed = cur_ed
            count += 1
            if visualize:
                status = "pred: {}; gt: {}".format(out[i], gts)
                iterator.set_description(status)
                img = imgs[i].permute(1, 2,
                                      0).cpu().data.numpy().astype(np.uint8)
                cv2.imshow("img", img)
                key = chr(cv2.waitKey() & 255)
                if key == 'q':
                    break

        #if not visualize:
        #    iterator.set_description("acc: {0:.4f}; avg_ed: {0:.4f}".format(
        #        float(tp) / float(count), float(avg_ed) / float(count)))
    #with open(output_path + '_{}_{}_statistics.pkl'.format(initial_title,n_iter), 'wb') as sf:

    #    pkl.dump(test_letter_statistics.total_actions_hists, sf)

    if do_results and output_path is not None:
        print('printing results! :)')
        sorted_im_by_error = sorted(im_by_error.items(),
                                    key=operator.itemgetter(1))
        sorted_im = [key for (key, value) in sorted_im_by_error]
        all_im_pathes_no_new_line = [
            im.replace('\n', '') for im in all_im_pathes
        ]
        printed_res_best = ""
        printed_res_worst = ""
        for im in sorted_im[:20]:
            im_id = all_im_pathes_no_new_line.index(im)
            pred = all_pred_text[im_id]
            label = all_label_text[im_id]
            printed_res_best += im + '\n' + label + pred

        for im in list(reversed(sorted_im))[:20]:
            im_id = all_im_pathes_no_new_line.index(im)
            pred = all_pred_text[im_id]
            label = all_label_text[im_id]
            printed_res_worst += im + '\n' + label + pred

        with open(
                output_path + '_{}_{}_sorted_images_by_errors.txt'.format(
                    initial_title, n_iter), 'w') as fp:
            fp.writelines([
                key + ',' + str(value) + '\n'
                for (key, value) in sorted_im_by_error
            ])

        with open(
                output_path +
                '_{}_{}_res_on_best.txt'.format(initial_title, n_iter),
                'w') as fp:
            fp.writelines([printed_res_best])
            with open(
                    output_path +
                    '_{}_{}_res_on_worst.txt'.format(initial_title, n_iter),
                    'w') as fp:
                fp.writelines([printed_res_worst])
        os.makedirs(output_path, exist_ok=True)
        with open(
                output_path + '_{}_{}_pred.txt'.format(initial_title, n_iter),
                'w') as fp:
            fp.writelines(all_pred_text)
        with open(
                output_path + '_{}_{}_label.txt'.format(initial_title, n_iter),
                'w') as fp:
            fp.writelines(all_label_text)
        with open(output_path + '_{}_{}_im.txt'.format(initial_title, n_iter),
                  'w') as fp:
            fp.writelines(all_im_pathes)
        stop_characters = ['-', '.', '༎', '༑', '།', '་']

        all_pred_text = [
            ''.join(c for c in line if not c in stop_characters)
            for line in all_pred_text
        ]
        with open(
                output_path +
                '_{}_{}_pred_no_stopchars.txt'.format(initial_title, n_iter),
                'w') as rf:
            rf.writelines(all_pred_text)
        all_label_text = [
            ''.join(c for c in line if not c in stop_characters)
            for line in all_label_text
        ]
        with open(
                output_path +
                '_{}_{}_label_no_stopchars.txt'.format(initial_title, n_iter),
                'w') as rf:
            rf.writelines(all_label_text)

    acc = float(avg_accuracy) / float(count)
    avg_ed = float(avg_ed) / float(count)
    avg_no_stop_ed = float(avg_no_stop_ed) / float(count)
    if loss_function is not None:
        avg_loss = float(avg_loss) / float(count)
        return acc, avg_ed, avg_no_stop_ed, avg_loss
    return acc, avg_ed, avg_no_stop_ed
예제 #4
0
def main(base_data_dir, train_data_path, train_base_dir, orig_eval_data_path,
         orig_eval_base_dir, synth_eval_data_path, synth_eval_base_dir,
         lexicon_path, seq_proj, backend, snapshot, input_height, base_lr,
         elastic_alpha, elastic_sigma, step_size, max_iter, batch_size,
         output_dir, test_iter, show_iter, test_init, use_gpu,
         use_no_font_repeat_data, do_vat, do_at, vat_ratio, test_vat_ratio,
         vat_epsilon, vat_ip, vat_xi, vat_sign, do_remove_augs, aug_to_remove,
         do_beam_search, dropout_conv, dropout_rnn, dropout_output, do_ema,
         do_gray, do_test_vat, do_test_entropy, do_test_vat_cnn,
         do_test_vat_rnn, ada_after_rnn, ada_before_rnn, do_ada_lr, ada_ratio,
         rnn_hidden_size, do_lr_step, dataset_name):
    """Train a CRNN text recognizer with CTC loss plus optional regularizers.

    Optional training signals, each gated by its ``do_*`` flag:
      * VAT / AT (virtual adversarial / adversarial training) on the
        labeled training batch (``do_vat`` / ``do_at``).
      * VAT on unlabeled test-domain images, applied to the full output,
        the RNN features, or the CNN features (``do_test_vat`` /
        ``do_test_vat_rnn`` / ``do_test_vat_cnn``).
      * Domain-adversarial adaptation (DANN-style) before or after the RNN
        (``ada_before_rnn`` / ``ada_after_rnn``) with a domain classifier
        trained via ``NLLLoss``.

    All ``*_path`` / ``*_dir`` arguments are joined onto ``base_data_dir``.
    Periodically evaluates on a synthetic validation set and the original
    (target-domain) test set, logging to TensorBoard and checkpointing the
    best model by no-stop-character edit distance.

    NOTE(review): the training loop is ``while True`` — ``max_iter`` only
    drives the LR schedulers, so the process never exits on its own; confirm
    whether a stop condition was intended.

    NOTE(review): ``output_dir`` is effectively required: ``params.txt`` is
    written into it and ``tb_writer`` is used unconditionally later, so
    passing ``None`` would crash — confirm against the CLI defaults.

    Raises:
        NotImplementedError: if neither step LR nor ada LR is selected.
        Exception: if ``aug_to_remove`` is not an allowed value.
        ValueError: if ``do_test_vat`` is combined with
            ``do_test_vat_rnn`` / ``do_test_vat_cnn``.
    """
    if not do_lr_step and not do_ada_lr:
        raise NotImplementedError(
            'learning rate should be either step or ada.')
    # Resolve all data locations relative to the base data directory.
    train_data_path = os.path.join(base_data_dir, train_data_path)
    train_base_dir = os.path.join(base_data_dir, train_base_dir)
    synth_eval_data_path = os.path.join(base_data_dir, synth_eval_data_path)
    synth_eval_base_dir = os.path.join(base_data_dir, synth_eval_base_dir)
    orig_eval_data_path = os.path.join(base_data_dir, orig_eval_data_path)
    orig_eval_base_dir = os.path.join(base_data_dir, orig_eval_base_dir)
    lexicon_path = os.path.join(base_data_dir, lexicon_path)

    # Snapshot of every argument (plus the resolved paths) for params.txt.
    all_parameters = locals()
    cuda = use_gpu
    #print(train_base_dir)
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        tb_writer = TbSummary(output_dir)
        output_dir = os.path.join(output_dir, 'model')
        os.makedirs(output_dir, exist_ok=True)

    with open(lexicon_path, 'rb') as f:
        lexicon = pkl.load(f)
    #print(sorted(lexicon.items(), key=operator.itemgetter(1)))

    with open(os.path.join(output_dir, 'params.txt'), 'w') as f:
        f.writelines(str(all_parameters))
    print(all_parameters)
    print('new vat')

    sin_magnitude = 4
    rotate_max_angle = 2
    dataset_info = SynthDataInfo(None, None, None, dataset_name.lower())
    train_fonts = dataset_info.font_names

    all_args = locals()

    allowed_removals = [
        'elastic', 'sine', 'sine_rotate', 'rotation', 'color_aug',
        'color_gaus', 'color_sine'
    ]
    if do_remove_augs and aug_to_remove not in allowed_removals:
        raise Exception('augmentation removal value is not allowed.')

    # Build the random-augmentation pipeline, possibly ablating one
    # augmentation (for ablation studies).
    if do_remove_augs:
        rand_trans = []
        if aug_to_remove == 'elastic':
            print('doing sine transform :)')
            rand_trans.append(OnlySine(sin_magnitude=sin_magnitude))
        elif aug_to_remove in ['sine', 'sine_rotate']:
            print('doing elastic transform :)')
            rand_trans.append(
                OnlyElastic(elastic_alpha=elastic_alpha,
                            elastic_sigma=elastic_sigma))
        if aug_to_remove not in ['elastic', 'sine', 'sine_rotate']:
            print('doing elastic transform :)')
            print('doing sine transform :)')
            rand_trans.append(
                ElasticAndSine(elastic_alpha=elastic_alpha,
                               elastic_sigma=elastic_sigma,
                               sin_magnitude=sin_magnitude))
        if aug_to_remove not in ['rotation', 'sine_rotate']:
            print('doing rotation transform :)')
            rand_trans.append(Rotation(angle=rotate_max_angle, fill_value=255))
        if aug_to_remove not in ['color_aug', 'color_gaus', 'color_sine']:
            print('doing color_aug transform :)')
            rand_trans.append(ColorGradGausNoise())
        elif aug_to_remove == 'color_gaus':
            # NOTE(review): message says "color_sine" while ColorGrad is
            # appended (and the next branch says "color_gaus" with
            # ColorGausNoise) — the print labels look swapped; confirm the
            # intended naming before changing the log text.
            print('doing color_sine transform :)')
            rand_trans.append(ColorGrad())
        elif aug_to_remove == 'color_sine':
            print('doing color_gaus transform :)')
            rand_trans.append(ColorGausNoise())
    else:
        print('doing all transforms :)')
        rand_trans = [
            ElasticAndSine(elastic_alpha=elastic_alpha,
                           elastic_sigma=elastic_sigma,
                           sin_magnitude=sin_magnitude),
            Rotation(angle=rotate_max_angle, fill_value=255),
            ColorGradGausNoise()
        ]
    # Common tail of the pipeline: resize to the network input height,
    # record the resulting width, optionally grayscale, then normalize.
    if do_gray:
        rand_trans = rand_trans + [
            Resize(hight=input_height),
            AddWidth(),
            ToGray(),
            Normalize()
        ]
    else:
        rand_trans = rand_trans + [
            Resize(hight=input_height),
            AddWidth(), Normalize()
        ]

    transform_random = Compose(rand_trans)
    # Evaluation pipeline: no random augmentation.
    if do_gray:
        transform_simple = Compose(
            [Resize(hight=input_height),
             AddWidth(),
             ToGray(),
             Normalize()])
    else:
        transform_simple = Compose(
            [Resize(hight=input_height),
             AddWidth(), Normalize()])

    if use_no_font_repeat_data:
        print('creating dataset')
        train_data = TextDatasetRandomFont(data_path=train_data_path,
                                           lexicon=lexicon,
                                           base_path=train_base_dir,
                                           transform=transform_random,
                                           fonts=train_fonts)
        print('finished creating dataset')
    else:
        print('train data path:\n{}'.format(train_data_path))
        print('train_base_dir:\n{}'.format(train_base_dir))
        train_data = TextDataset(data_path=train_data_path,
                                 lexicon=lexicon,
                                 base_path=train_base_dir,
                                 transform=transform_random,
                                 fonts=train_fonts)
    synth_eval_data = TextDataset(data_path=synth_eval_data_path,
                                  lexicon=lexicon,
                                  base_path=synth_eval_base_dir,
                                  transform=transform_random,
                                  fonts=train_fonts)
    orig_eval_data = TextDataset(data_path=orig_eval_data_path,
                                 lexicon=lexicon,
                                 base_path=orig_eval_base_dir,
                                 transform=transform_simple,
                                 fonts=None)
    # Unlabeled target-domain images reused as the VAT / ADA source.
    if do_test_vat or do_test_vat_rnn or do_test_vat_cnn:
        orig_vat_data = TextDataset(data_path=orig_eval_data_path,
                                    lexicon=lexicon,
                                    base_path=orig_eval_base_dir,
                                    transform=transform_simple,
                                    fonts=None)

    if ada_after_rnn or ada_before_rnn:
        orig_ada_data = TextDataset(data_path=orig_eval_data_path,
                                    lexicon=lexicon,
                                    base_path=orig_eval_base_dir,
                                    transform=transform_simple,
                                    fonts=None)

    #else:
    #    train_data = TestDataset(transform=transform, abc=abc).set_mode("train")
    #    synth_eval_data = TestDataset(transform=transform, abc=abc).set_mode("test")
    #    orig_eval_data = TestDataset(transform=transform, abc=abc).set_mode("test")
    seq_proj = [int(x) for x in seq_proj.split('x')]
    net = load_model(lexicon=train_data.get_lexicon(),
                     seq_proj=seq_proj,
                     backend=backend,
                     snapshot=snapshot,
                     cuda=cuda,
                     do_beam_search=do_beam_search,
                     dropout_conv=dropout_conv,
                     dropout_rnn=dropout_rnn,
                     dropout_output=dropout_output,
                     do_ema=do_ema,
                     ada_after_rnn=ada_after_rnn,
                     ada_before_rnn=ada_before_rnn,
                     rnn_hidden_size=rnn_hidden_size)
    optimizer = optim.Adam(net.parameters(), lr=base_lr, weight_decay=0.0001)
    if do_ada_lr:
        print('using ada lr')
        lr_scheduler = DannLR(optimizer, max_iter=max_iter)
    elif do_lr_step:
        print('using step lr')
        lr_scheduler = StepLR(optimizer,
                              step_size=step_size,
                              max_iter=max_iter)
    loss_function = CTCLoss()

    # Best (lowest) no-stop-character edit distances seen so far.
    synth_avg_ed_best = float("inf")
    orig_avg_ed_best = float("inf")
    epoch_count = 0

    # Separate loaders for the unlabeled streams; each is re-created when
    # exhausted so the unlabeled data cycles independently of the epochs.
    if do_test_vat or do_test_vat_rnn or do_test_vat_cnn:
        collate_vat = lambda x: text_collate(x, do_mask=True)
        vat_load = DataLoader(orig_vat_data,
                              batch_size=batch_size,
                              num_workers=4,
                              shuffle=True,
                              collate_fn=collate_vat)
        vat_len = len(vat_load)
        cur_vat = 0
        vat_iter = iter(vat_load)
    if ada_after_rnn or ada_before_rnn:
        collate_ada = lambda x: text_collate(x, do_mask=True)
        ada_load = DataLoader(orig_ada_data,
                              batch_size=batch_size,
                              num_workers=4,
                              shuffle=True,
                              collate_fn=collate_ada)
        ada_len = len(ada_load)
        cur_ada = 0
        ada_iter = iter(ada_load)

    loss_domain = torch.nn.NLLLoss()

    # Main training loop; one iteration of the outer loop == one epoch.
    # NOTE: this code targets the pre-0.4 PyTorch API (Variable,
    # `tensor.data[0]`, clip_grad_norm) — keep the style consistent.
    while True:
        collate = lambda x: text_collate(
            x, do_mask=(do_vat or ada_before_rnn or ada_after_rnn))
        data_loader = DataLoader(train_data,
                                 batch_size=batch_size,
                                 num_workers=4,
                                 shuffle=True,
                                 collate_fn=collate)

        # Running per-epoch loss statistics (for TensorBoard / tqdm status).
        loss_mean_ctc = []
        loss_mean_vat = []
        loss_mean_at = []
        loss_mean_comp = []
        loss_mean_total = []
        loss_mean_test_vat = []
        loss_mean_test_pseudo = []
        loss_mean_test_rand = []
        loss_mean_ada_rnn_s = []
        loss_mean_ada_rnn_t = []
        loss_mean_ada_cnn_s = []
        loss_mean_ada_cnn_t = []
        iterator = tqdm(data_loader)
        iter_count = 0
        for iter_num, sample in enumerate(iterator):
            total_iter = (epoch_count * len(data_loader)) + iter_num
            # Evaluate every `test_iter` iterations (and optionally once at
            # iteration 0 when `test_init` is set).
            if ((total_iter > 1)
                    and total_iter % test_iter == 0) or (test_init
                                                         and total_iter == 0):
                # epoch_count != 0 and

                print("Test phase")
                net = net.eval()
                if do_ema:
                    net.start_test()

                synth_acc, synth_avg_ed, synth_avg_no_stop_ed, synth_avg_loss = test(
                    net,
                    synth_eval_data,
                    synth_eval_data.get_lexicon(),
                    cuda,
                    visualize=False,
                    dataset_info=dataset_info,
                    batch_size=batch_size,
                    tb_writer=tb_writer,
                    n_iter=total_iter,
                    initial_title='val_synth',
                    loss_function=loss_function,
                    output_path=os.path.join(output_dir, 'results'),
                    do_beam_search=False)

                orig_acc, orig_avg_ed, orig_avg_no_stop_ed, orig_avg_loss = test(
                    net,
                    orig_eval_data,
                    orig_eval_data.get_lexicon(),
                    cuda,
                    visualize=False,
                    dataset_info=dataset_info,
                    batch_size=batch_size,
                    tb_writer=tb_writer,
                    n_iter=total_iter,
                    initial_title='test_orig',
                    loss_function=loss_function,
                    output_path=os.path.join(output_dir, 'results'),
                    do_beam_search=do_beam_search)

                net = net.train()
                # Periodic snapshot every 30000 iterations.
                # BUGFIX: the original condition was `total_iter // 30000`
                # (floor division), which is truthy at EVERY test phase once
                # total_iter >= 30000, and the snapshot was then written into
                # output_dir instead of the periodic_save directory it had
                # just created (plus a dead `glob.glob` call). Use a real
                # periodicity test and save into periodic_save.
                if (output_dir is not None and total_iter > 0
                        and total_iter % 30000 == 0):
                    periodic_save = os.path.join(output_dir, 'periodic_save')
                    os.makedirs(periodic_save, exist_ok=True)

                    torch.save(
                        net.state_dict(),
                        os.path.join(periodic_save, "crnn_" + backend + "_" +
                                     str(total_iter)))

                # Checkpoint the best model by target-domain edit distance
                # (ignoring stop characters).
                if orig_avg_no_stop_ed < orig_avg_ed_best:
                    orig_avg_ed_best = orig_avg_no_stop_ed
                    if output_dir is not None:
                        torch.save(
                            net.state_dict(),
                            os.path.join(output_dir,
                                         "crnn_" + backend + "_best"))

                if synth_avg_no_stop_ed < synth_avg_ed_best:
                    synth_avg_ed_best = synth_avg_no_stop_ed
                if do_ema:
                    net.end_test()
                print(
                    "synth: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}"
                    .format(synth_avg_ed_best, synth_avg_ed,
                            synth_avg_no_stop_ed, synth_acc))
                print(
                    "orig: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}"
                    .format(orig_avg_ed_best, orig_avg_ed, orig_avg_no_stop_ed,
                            orig_acc))
                tb_writer.get_writer().add_scalars(
                    'data/test', {
                        'synth_ed_total': synth_avg_ed,
                        'synth_ed_no_stop': synth_avg_no_stop_ed,
                        'synth_avg_loss': synth_avg_loss,
                        'orig_ed_total': orig_avg_ed,
                        'orig_ed_no_stop': orig_avg_no_stop_ed,
                        'orig_avg_loss': orig_avg_loss
                    }, total_iter)
                # Log whichever training-loss means are active this run.
                if len(loss_mean_ctc) > 0:
                    train_dict = {'mean_ctc_loss': np.mean(loss_mean_ctc)}
                    if do_vat:
                        train_dict = {
                            **train_dict,
                            **{
                                'mean_vat_loss': np.mean(loss_mean_vat)
                            }
                        }
                    if do_at:
                        train_dict = {
                            **train_dict,
                            **{
                                'mean_at_loss': np.mean(loss_mean_at)
                            }
                        }
                    if do_test_vat:
                        train_dict = {
                            **train_dict,
                            **{
                                'mean_test_vat_loss': np.mean(loss_mean_test_vat)
                            }
                        }
                    if do_test_vat_rnn and do_test_vat_cnn:
                        train_dict = {
                            **train_dict,
                            **{
                                'mean_test_vat_crnn_loss':
                                np.mean(loss_mean_test_vat)
                            }
                        }
                    elif do_test_vat_rnn:
                        train_dict = {
                            **train_dict,
                            **{
                                'mean_test_vat_rnn_loss':
                                np.mean(loss_mean_test_vat)
                            }
                        }
                    elif do_test_vat_cnn:
                        train_dict = {
                            **train_dict,
                            **{
                                'mean_test_vat_cnn_loss':
                                np.mean(loss_mean_test_vat)
                            }
                        }
                    if ada_after_rnn:
                        train_dict = {
                            **train_dict,
                            **{
                                'mean_ada_rnn_s_loss': np.mean(loss_mean_ada_rnn_s),
                                'mean_ada_rnn_t_loss': np.mean(loss_mean_ada_rnn_t)
                            }
                        }
                    if ada_before_rnn:
                        train_dict = {
                            **train_dict,
                            **{
                                'mean_ada_cnn_s_loss': np.mean(loss_mean_ada_cnn_s),
                                'mean_ada_cnn_t_loss': np.mean(loss_mean_ada_cnn_t)
                            }
                        }
                    print(train_dict)
                    tb_writer.get_writer().add_scalars('data/train',
                                                       train_dict, total_iter)
            '''
            # for multi-gpu support
            if sample["img"].size(0) % len(gpu.split(',')) != 0:
                continue
            '''
            optimizer.zero_grad()
            imgs = Variable(sample["img"])
            #print("images sizes are:")
            #print(sample["img"].shape)
            if do_vat or ada_after_rnn or ada_before_rnn:
                mask = sample['mask']
            labels_flatten = Variable(sample["seq"]).view(-1)
            label_lens = Variable(sample["seq_len"].int())
            #print("image sequence length is:")
            #print(sample["im_seq_len"])
            #print("label sequence length is:")
            #print(sample["seq_len"].view(1,-1))
            img_seq_lens = sample["im_seq_len"]
            if cuda:
                imgs = imgs.cuda()
                if do_vat or ada_after_rnn or ada_before_rnn:
                    mask = mask.cuda()

            if do_ada_lr:
                # Training progress in [0, 1) drives the DANN LR schedule.
                ada_p = float(iter_count) / max_iter
                lr_scheduler.update(ada_p)

            if ada_before_rnn or ada_after_rnn:
                if not do_ada_lr:
                    ada_p = float(iter_count) / max_iter
                # Standard DANN gradient-reversal coefficient ramp-up.
                ada_alpha = 2. / (1. + np.exp(-10. * ada_p)) - 1

                # Refill the target-domain loader when exhausted.
                if cur_ada >= ada_len:
                    ada_load = DataLoader(orig_ada_data,
                                          batch_size=batch_size,
                                          num_workers=4,
                                          shuffle=True,
                                          collate_fn=collate_ada)
                    ada_len = len(ada_load)
                    cur_ada = 0
                    ada_iter = iter(ada_load)
                ada_batch = next(ada_iter)
                cur_ada += 1
                ada_imgs = Variable(ada_batch["img"])
                ada_img_seq_lens = ada_batch["im_seq_len"]
                ada_mask = ada_batch['mask'].byte()
                if cuda:
                    ada_imgs = ada_imgs.cuda()

                _, ada_cnn, ada_rnn = net(ada_imgs,
                                          ada_img_seq_lens,
                                          ada_alpha=ada_alpha,
                                          mask=ada_mask)
                if ada_before_rnn:
                    ada_num_features = ada_cnn.size(0)
                else:
                    ada_num_features = ada_rnn.size(0)
                # Target-domain samples get label 0.
                domain_label = torch.zeros(ada_num_features)
                domain_label = domain_label.long()
                if cuda:
                    domain_label = domain_label.cuda()
                domain_label = Variable(domain_label)

                if ada_before_rnn:
                    err_ada_cnn_t = loss_domain(ada_cnn, domain_label)
                if ada_after_rnn:
                    err_ada_rnn_t = loss_domain(ada_rnn, domain_label)

            if do_test_vat and do_at:
                # Combined loss: AT on the labeled batch + VAT on an
                # unlabeled test-domain batch.
                if cur_vat >= vat_len:
                    vat_load = DataLoader(orig_vat_data,
                                          batch_size=batch_size,
                                          num_workers=4,
                                          shuffle=True,
                                          collate_fn=collate_vat)
                    vat_len = len(vat_load)
                    cur_vat = 0
                    vat_iter = iter(vat_load)
                test_vat_batch = next(vat_iter)
                cur_vat += 1
                test_vat_mask = test_vat_batch['mask']
                test_vat_imgs = Variable(test_vat_batch["img"])
                test_vat_img_seq_lens = test_vat_batch["im_seq_len"]
                if cuda:
                    test_vat_imgs = test_vat_imgs.cuda()
                    test_vat_mask = test_vat_mask.cuda()
                at_test_vat_loss = LabeledAtAndUnlabeledTestVatLoss(
                    xi=vat_xi, eps=vat_epsilon, ip=vat_ip)

                at_loss, test_vat_loss = at_test_vat_loss(
                    model=net,
                    train_x=imgs,
                    train_labels_flatten=labels_flatten,
                    train_img_seq_lens=img_seq_lens,
                    train_label_lens=label_lens,
                    batch_size=batch_size,
                    test_x=test_vat_imgs,
                    test_seq_len=test_vat_img_seq_lens,
                    test_mask=test_vat_mask)
            elif do_test_vat or do_test_vat_rnn or do_test_vat_cnn:
                if cur_vat >= vat_len:
                    vat_load = DataLoader(orig_vat_data,
                                          batch_size=batch_size,
                                          num_workers=4,
                                          shuffle=True,
                                          collate_fn=collate_vat)
                    vat_len = len(vat_load)
                    cur_vat = 0
                    vat_iter = iter(vat_load)
                vat_batch = next(vat_iter)
                cur_vat += 1
                vat_mask = vat_batch['mask']
                vat_imgs = Variable(vat_batch["img"])
                vat_img_seq_lens = vat_batch["im_seq_len"]
                if cuda:
                    vat_imgs = vat_imgs.cuda()
                    vat_mask = vat_mask.cuda()
                if do_test_vat:
                    # BUGFIX: the original `raise "..."` raised a plain
                    # string, which is itself a TypeError in Python 3, so the
                    # intended message never reached the user.
                    if do_test_vat_rnn or do_test_vat_cnn:
                        raise ValueError(
                            "can only do one of do_test_vat | (do_test_vat_rnn, do_test_vat_cnn)"
                        )
                    if vat_sign:
                        test_vat_loss = VATLossSign(
                            do_test_entropy=do_test_entropy,
                            xi=vat_xi,
                            eps=vat_epsilon,
                            ip=vat_ip)
                    else:
                        test_vat_loss = VATLoss(xi=vat_xi,
                                                eps=vat_epsilon,
                                                ip=vat_ip)
                elif do_test_vat_rnn and do_test_vat_cnn:
                    test_vat_loss = VATonRnnCnnSign(xi=vat_xi,
                                                    eps=vat_epsilon,
                                                    ip=vat_ip)
                elif do_test_vat_rnn:
                    test_vat_loss = VATonRnnSign(xi=vat_xi,
                                                 eps=vat_epsilon,
                                                 ip=vat_ip)
                elif do_test_vat_cnn:
                    test_vat_loss = VATonCnnSign(xi=vat_xi,
                                                 eps=vat_epsilon,
                                                 ip=vat_ip)
                if do_test_vat_cnn and do_test_vat_rnn:
                    test_vat_loss, cnn_lds, rnn_lds = test_vat_loss(
                        net, vat_imgs, vat_img_seq_lens, vat_mask)
                elif do_test_vat:
                    test_vat_loss = test_vat_loss(net, vat_imgs,
                                                  vat_img_seq_lens, vat_mask)
            elif do_vat:
                vat_loss = VATLoss(xi=vat_xi, eps=vat_epsilon, ip=vat_ip)
                vat_loss = vat_loss(net, imgs, img_seq_lens, mask)
            elif do_at:
                at_loss = LabeledATLoss(xi=vat_xi, eps=vat_epsilon, ip=vat_ip)
                at_loss = at_loss(net, imgs, labels_flatten, img_seq_lens,
                                  label_lens, batch_size)

            if ada_after_rnn or ada_before_rnn:
                preds, ada_cnn, ada_rnn = net(imgs,
                                              img_seq_lens,
                                              ada_alpha=ada_alpha,
                                              mask=mask)

                if ada_before_rnn:
                    ada_num_features = ada_cnn.size(0)
                else:
                    ada_num_features = ada_rnn.size(0)

                # Source-domain samples get label 1.
                domain_label = torch.ones(ada_num_features)
                domain_label = domain_label.long()
                if cuda:
                    domain_label = domain_label.cuda()
                domain_label = Variable(domain_label)

                if ada_before_rnn:
                    err_ada_cnn_s = loss_domain(ada_cnn, domain_label)
                if ada_after_rnn:
                    err_ada_rnn_s = loss_domain(ada_rnn, domain_label)

            else:
                preds = net(imgs, img_seq_lens)
            '''
            if output_dir is not None:
                if (show_iter is not None and iter_num != 0 and iter_num % show_iter == 0):
                    print_data_visuals(net, tb_writer, train_data.get_lexicon(), sample["img"], labels_flatten, label_lens,
                                       preds, ((epoch_count * len(data_loader)) + iter_num))
            '''
            loss_ctc = loss_function(
                preds, labels_flatten,
                Variable(torch.IntTensor(np.array(img_seq_lens))),
                label_lens) / batch_size

            # Skip batches where CTC diverges to +/-inf instead of poisoning
            # the gradients.
            if loss_ctc.data[0] in [float("inf"), -float("inf")]:
                print("warnning: loss should not be inf.")
                continue
            total_loss = loss_ctc

            if do_vat:
                #mask = sample['mask']
                #if cuda:
                #    mask = mask.cuda()
                #vat_loss = virtual_adversarial_loss(net, imgs, img_seq_lens, mask, is_training=True, do_entropy=False, epsilon=vat_epsilon, num_power_iterations=1,
                #             xi=1e-6, average_loss=True)
                total_loss = total_loss + vat_ratio * vat_loss.cpu()
            if do_test_vat or do_test_vat_rnn or do_test_vat_cnn:
                total_loss = total_loss + test_vat_ratio * test_vat_loss.cpu()

            if ada_before_rnn:
                total_loss = total_loss + ada_ratio * err_ada_cnn_s.cpu(
                ) + ada_ratio * err_ada_cnn_t.cpu()
            if ada_after_rnn:
                total_loss = total_loss + ada_ratio * err_ada_rnn_s.cpu(
                ) + ada_ratio * err_ada_rnn_t.cpu()

            total_loss.backward()
            nn.utils.clip_grad_norm(net.parameters(), 10.0)
            # Only record losses in a sane range so outliers don't skew the
            # running means shown in the progress bar / TensorBoard.
            if -400 < loss_ctc.data[0] < 400:
                loss_mean_ctc.append(loss_ctc.data[0])
            if -1000 < total_loss.data[0] < 1000:
                loss_mean_total.append(total_loss.data[0])
            if len(loss_mean_total) > 100:
                loss_mean_total = loss_mean_total[-100:]
            status = "epoch: {0:5d}; iter_num: {1:5d}; lr: {2:.2E}; loss_mean: {3:.3f}; loss: {4:.3f}".format(
                epoch_count, lr_scheduler.last_iter, lr_scheduler.get_lr(),
                np.mean(loss_mean_total), loss_ctc.data[0])
            if ada_after_rnn:
                loss_mean_ada_rnn_s.append(err_ada_rnn_s.data[0])
                loss_mean_ada_rnn_t.append(err_ada_rnn_t.data[0])
                status += "; ladatrnns: {0:.3f}; ladatrnnt: {1:.3f}".format(
                    err_ada_rnn_s.data[0], err_ada_rnn_t.data[0])
            if ada_before_rnn:
                loss_mean_ada_cnn_s.append(err_ada_cnn_s.data[0])
                loss_mean_ada_cnn_t.append(err_ada_cnn_t.data[0])
                status += "; ladatcnns: {0:.3f}; ladatcnnt: {1:.3f}".format(
                    err_ada_cnn_s.data[0], err_ada_cnn_t.data[0])
            if do_vat:
                loss_mean_vat.append(vat_loss.data[0])
                status += "; lvat: {0:.3f}".format(vat_loss.data[0])
            if do_at:
                loss_mean_at.append(at_loss.data[0])
                status += "; lat: {0:.3f}".format(at_loss.data[0])
            if do_test_vat:
                loss_mean_test_vat.append(test_vat_loss.data[0])
                status += "; l_tvat: {0:.3f}".format(test_vat_loss.data[0])
            if do_test_vat_rnn or do_test_vat_cnn:
                loss_mean_test_vat.append(test_vat_loss.data[0])
                if do_test_vat_rnn and do_test_vat_cnn:
                    status += "; l_tvatc: {}".format(cnn_lds.data[0])
                    status += "; l_tvatr: {}".format(rnn_lds.data[0])
                else:
                    status += "; l_tvat: {}".format(test_vat_loss.data[0])

            iterator.set_description(status)
            optimizer.step()
            if do_lr_step:
                lr_scheduler.step()
            if do_ema:
                net.udate_ema()
            iter_count += 1
        # End-of-epoch snapshot (always overwrites "_last").
        if output_dir is not None:
            torch.save(net.state_dict(),
                       os.path.join(output_dir, "crnn_" + backend + "_last"))
        epoch_count += 1

    return