예제 #1
0
def train():
    """Run the epoch/batch training loop configured by the command line.

    Settings come from ``parse_train_arguments()``. Training images are every
    ``.png``, ``.jpg`` and ``.bmp`` file found recursively under
    ``args.train_dataset_base_path``. The model, optimizer, starting epoch and
    clip norms are obtained from ``load_model``. Per-batch and per-epoch
    bookkeeping is delegated to ``on_batch_end`` / ``on_epoch_end``.
    """
    device = 'gpu:0' if tf.test.is_gpu_available() else 'cpu'
    args = parse_train_arguments()

    with tf.device(device):
        # Running metrics; insertion order matches the metric names below.
        metric_names = ('train_loss', 'train_mse', 'train_psnr', 'train_ssim')
        losses = {metric: Mean(name=metric) for metric in metric_names}

        # Gather images extension by extension (png, then jpg, then bmp).
        base_dir = args.train_dataset_base_path
        image_paths = []
        for extension in ('png', 'jpg', 'bmp'):
            pattern = os.path.join(base_dir, '**/**.' + extension)
            image_paths += glob(pattern, recursive=True)

        dataset = load_simulation_data(image_paths,
                                       args.batch_size * args.batches,
                                       args.patch_size,
                                       args.radious,
                                       args.epsilon)

        model, optimizer, initial_epoch, clip_norms = load_model(
            args.checkpoint_directory, args.restore_model, args.learning_rate)

        for epoch in range(initial_epoch, args.epochs):
            # Two float32 accumulators, reset each epoch and handed to the
            # callbacks (presumably one per network norm returned by
            # train_step — confirm in on_batch_end).
            total_clip_norms = [tf.cast(0, dtype=tf.float32) for _ in range(2)]
            batched = dataset.batch(args.batch_size).prefetch(
                buffer_size=tf.data.experimental.AUTOTUNE)
            progress_bar = tqdm(batched, total=args.batches)

            for step, batch in enumerate(progress_bar):
                dnet_norm, snet_norm = train_step(model, optimizer, batch,
                                                  losses, clip_norms,
                                                  args.radious)
                on_batch_end(epoch, step, dnet_norm, snet_norm,
                             total_clip_norms, losses, progress_bar)

            on_epoch_end(model, optimizer, epoch, losses, clip_norms,
                         total_clip_norms, args.checkpoint_directory,
                         args.checkpoint_frequency)
예제 #2
0
def main(data_path, abc, seq_proj, backend, snapshot, input_size, gpu, visualize):
    """Run test-time-augmented inference and write ``tmp_rcnn_tta10.csv``.

    Args:
        data_path: Dataset location for ``TextDataset`` (mode ``"pb"``); when
            None, a ``TestDataset`` built from ``abc`` is used instead.
        abc: Alphabet used only by the ``TestDataset`` fallback.
        seq_proj: Sequence projection as an ``"AxB"`` string of ints.
        backend: Backbone name forwarded to ``load_model``.
        snapshot: Checkpoint path forwarded to ``load_model``.
        input_size: ``"AxB"`` string; the two ints become the ``Resize`` size.
        gpu: Value for ``CUDA_VISIBLE_DEVICES``; an empty string disables CUDA.
        visualize: Forwarded to ``test_tta``.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # Compare by value: ``gpu is not ''`` tests identity, depends on string
    # interning, and emits a SyntaxWarning on Python 3.8+.
    cuda = gpu != ''

    input_size = [int(x) for x in input_size.split('x')]
    # Augmentations applied before resizing (Scale is intentionally disabled).
    transform = Compose([
        Rotation(),
        Translation(),
        Contrast(),
        Grid_distortion(),
        Resize(size=(input_size[0], input_size[1]))
    ])
    if data_path is not None:
        data = TextDataset(data_path=data_path, mode="pb", transform=transform)
    else:
        data = TestDataset(transform=transform, abc=abc)
    seq_proj = [int(x) for x in seq_proj.split('x')]
    net = load_model(data.get_abc(), seq_proj, backend, snapshot, cuda).eval()
    acc, avg_ed, pred_pb = test_tta(net, data, data.get_abc(), cuda, visualize)

    df_submit = pd.DataFrame()
    # NOTE(review): glob.glob returns paths in arbitrary filesystem order, so
    # these names line up with ``pred_pb`` only if the dataset enumerates the
    # same files in the same order — confirm before trusting the CSV.
    df_submit['name'] = [x.split('/')[-1] for x in glob.glob('../../input/public_test_data/*')]
    df_submit['label'] = pred_pb

    df_submit.to_csv('tmp_rcnn_tta10.csv', index=None)
    print("Accuracy: {}".format(acc))
    print("Edit distance: {}".format(avg_ed))
예제 #3
0
def main(data_path, abc, seq_proj, backend, snapshot, input_size, gpu,
         visualize):
    """Evaluate a CRNN snapshot and print its accuracy.

    Args:
        data_path: Dataset location for ``TextDataset`` (mode ``"test"``);
            when None, a ``TestDataset`` built from ``abc`` is used.
        abc: Alphabet passed to ``load_model`` and ``test``.
        seq_proj: Sequence projection as an ``"AxB"`` string of ints.
        backend: Backbone name forwarded to ``load_model``.
        snapshot: Checkpoint path forwarded to ``load_model``.
        input_size: ``"AxB"`` string; the two ints become the ``Resize`` size.
        gpu: Value for ``CUDA_VISIBLE_DEVICES``; an empty string disables CUDA.
        visualize: Forwarded to ``test``.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # Compare by value: ``gpu is not ''`` tests identity, depends on string
    # interning, and emits a SyntaxWarning on Python 3.8+.
    cuda = gpu != ''

    seq_proj = [int(x) for x in seq_proj.split('x')]
    net = load_model(abc, seq_proj, backend, snapshot, cuda).eval()
    input_size = [int(x) for x in input_size.split('x')]
    transform = Compose([Resize(size=(input_size[0], input_size[1]))])
    if data_path is not None:
        data = TextDataset(data_path=data_path,
                           mode="test",
                           transform=transform)
    else:
        data = TestDataset(transform=transform, abc=abc)
    acc = test(net, data, abc, cuda, visualize)
    print("Accuracy: {}".format(acc))
예제 #4
0
def main(data_path, base_data_dir, lexicon_path, output_path, seq_proj,
         backend, snapshot, input_height, visualize, do_beam_search,
         dataset_name):
    """Evaluate a lexicon-based CRNN snapshot on one dataset and print metrics.

    Args:
        data_path: Annotation path for ``TextDataset``.
        base_data_dir: Image root passed to ``TextDataset`` as ``base_path``.
        lexicon_path: Pickle file holding the lexicon mapping (its items are
            printed sorted by value).
        output_path: Forwarded to ``test`` as ``output_path``.
        seq_proj: Sequence projection as an ``"AxB"`` string of ints.
        backend, snapshot: Forwarded to ``load_model``.
        input_height: Target height for the ``Resize`` transform.
        visualize: Forwarded to ``test``.
        do_beam_search: Forwarded to both ``load_model`` and ``test``.
        dataset_name: Lower-cased and stored in ``SynthDataInfo``.
    """
    # This entry point always passes cuda=True to load_model and test.
    use_cuda = True

    # NOTE(review): pickle executes arbitrary code on load; only use trusted
    # lexicon files.
    with open(lexicon_path, 'rb') as lexicon_file:
        lexicon = pkl.load(lexicon_file)
    print(sorted(lexicon.items(), key=operator.itemgetter(1)))

    pipeline = Compose([Resize(hight=input_height), AddWidth(), Normalize()])
    dataset = TextDataset(data_path=data_path,
                          lexicon=lexicon,
                          base_path=base_data_dir,
                          transform=pipeline,
                          fonts=None)
    info = SynthDataInfo(None, None, None, dataset_name.lower())

    projection = [int(part) for part in seq_proj.split('x')]
    model_kwargs = dict(seq_proj=projection,
                        backend=backend,
                        snapshot=snapshot,
                        cuda=use_cuda,
                        do_beam_search=do_beam_search)
    net = load_model(lexicon=dataset.get_lexicon(), **model_kwargs).eval()

    # Single-sample, evaluation-only run that also writes result files.
    eval_kwargs = dict(visualize=visualize,
                       dataset_info=info,
                       batch_size=1,
                       tb_writer=None,
                       n_iter=0,
                       initial_title='val_orig',
                       loss_function=None,
                       is_trian=False,
                       output_path=output_path,
                       do_beam_search=do_beam_search,
                       do_results=True)
    acc, avg_ed, avg_no_stop_ed = test(net, dataset, dataset.get_lexicon(),
                                       use_cuda, **eval_kwargs)
    print("Accuracy: {}".format(acc))
    print("Edit distance: {}".format(avg_ed))
    print("Edit distance without stop signs: {}".format(avg_no_stop_ed))
예제 #5
0
def main(data_path, abc, seq_proj, backend, snapshot, input_size, gpu,
         visualize):
    """Evaluate a CRNN snapshot on the CPU and print accuracy / edit distance.

    Args:
        data_path: Dataset location for ``TextDataset`` (mode ``"test"``);
            when None, a ``TestDataset`` built from ``abc`` is used.
        abc: Alphabet used only by the ``TestDataset`` fallback.
        seq_proj: Sequence projection as an ``"AxB"`` string of ints.
        backend: Backbone name forwarded to ``load_model``.
        snapshot: Checkpoint path forwarded to ``load_model``.
        input_size: ``"AxB"`` string; the two ints become the ``Resize`` size.
        gpu: Only written to ``CUDA_VISIBLE_DEVICES``; CUDA itself is forced off.
        visualize: Forwarded to ``test``.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # CUDA is intentionally forced off for this entry point, so no value
    # derived from ``gpu`` is needed here.
    cuda = False
    input_size = [int(x) for x in input_size.split('x')]
    transform = Compose(
        [Rotation(), Resize(size=(input_size[0], input_size[1]))])
    if data_path is not None:
        data = TextDataset(data_path=data_path,
                           mode="test",
                           transform=transform)
    else:
        data = TestDataset(transform=transform, abc=abc)

    seq_proj = [int(x) for x in seq_proj.split('x')]
    net = load_model(data.get_abc(), seq_proj, backend, snapshot, cuda).eval()
    acc, avg_ed = test(net, data, data.get_abc(), cuda, visualize)
    print("Accuracy: {}".format(acc))
    print("Edit distance: {}".format(avg_ed))
예제 #6
0
파일: inference.py 프로젝트: hvloc15/crnn
def main():
    """Run inference on ``config.test_path`` and print accuracy / edit distance.

    Every setting is read from the module-level ``config``. ``test`` is
    called with ``visualize=True``, ``num_workers=0`` and the CSV/image output
    targets from ``config``.
    """
    input_size = [int(x) for x in config.input_size.split('x')]
    # NOTE(review): Rotation() is part of the inference transform; if it is a
    # random augmentation, results are not deterministic — confirm intended.
    transform = Compose([
        Rotation(),
        Resize(size=(input_size[0], input_size[1]))
    ])
    data = TextDataset(data_path=config.test_path, mode=config.test_mode, transform=transform)

    # input_size is parsed once above and reused for the model.
    net = load_model(input_size, data.get_abc(), None, config.backend, config.snapshot).eval()

    # Sanity check that the dataset honoured the requested mode.
    assert data.mode == config.test_mode
    acc, avg_ed = test(net, data, data.get_abc(), visualize=True,
                       batch_size=config.batch_size, num_workers=0,
                       output_csv=config.output_csv, output_image=config.output_image)

    print("Accuracy: {}".format(acc))
    print("Edit distance: {}".format(avg_ed))
예제 #7
0
def main(data_path, abc, seq_proj, backend, snapshot, input_size, gpu,
         visualize):
    """Debug / evaluation driver for single-image CRNN inference.

    Loads the alphabet from ``<data_path>/desc.json`` and builds the model.
    While the hard-coded ``DEBUG`` flag is True, it only recognizes ``2.png``
    and prints the result. Otherwise it runs over every file in ``./d/data``,
    derives the expected label from the file name, prints each mismatch and
    finally the fraction of correct predictions.

    Args:
        data_path: Directory containing ``desc.json`` (must have an ``"abc"`` key).
        abc: Unused; the alphabet comes from ``desc.json``.
        seq_proj: Sequence projection as an ``"AxB"`` string of ints.
        backend, snapshot: Forwarded to ``load_model``.
        input_size: ``"AxB"`` string; the two ints are the ``cv2.resize`` dsize.
        gpu: Value for ``CUDA_VISIBLE_DEVICES``; an empty string disables CUDA.
        visualize: Forwarded to ``singleOp``.

    Raises:
        ValueError: If an image cannot be read by ``cv2.imread``.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # Compare by value; ``is not ''`` tests identity (SyntaxWarning on 3.8+).
    cuda = gpu != ''

    input_size = [int(x) for x in input_size.split('x')]
    seq_proj = [int(x) for x in seq_proj.split('x')]
    # Use a context manager so the descriptor file is closed promptly.
    with open(os.path.join(data_path, "desc.json")) as desc_file:
        config = json.load(desc_file)
    net = load_model(config["abc"], seq_proj, backend, snapshot, cuda).eval()

    def _load_resized(path):
        # cv2.imread signals failure by returning None rather than raising.
        img = cv2.imread(path)
        if img is None:
            raise ValueError("could not read image: {}".format(path))
        return cv2.resize(img, (input_size[0], input_size[1]))

    # NOTE(review): hard-coded debug switch; while True, the directory
    # evaluation below never runs.
    DEBUG = True
    if DEBUG:
        out = singleOp(net, _load_resized('2.png'), cuda, visualize)
        print(out[0])
        return

    image_dir = './d/data'
    files = os.listdir(image_dir)
    if not files:
        # Avoid ZeroDivisionError in the accuracy computation below.
        print('no files found in {}'.format(image_dir))
        return

    correct = 0
    for file_name in files:
        out = singleOp(net, _load_resized(os.path.join(image_dir, file_name)),
                       cuda, visualize)
        splits = file_name.split('_')
        sample_id = splits[0]
        half = len(sample_id) // 2
        # The last '_'-separated part selects which half of the id is the
        # expected label. NOTE(review): the first-half slice also drops the
        # character before the midpoint — confirm against the naming scheme.
        if splits[-1] == '0.jpg':
            expected = sample_id[:half - 1]
        else:
            expected = sample_id[half:]
        if expected == out[0]:
            correct += 1
        else:
            print(file_name)
            print(out)
            print(expected)
            print('--------------------')
    print(correct * 1.0 / len(files))
예제 #8
0
def forward(img):
    """Recognize the text in one image and return the decoded prediction.

    The image is resized with ``Resize`` to the module-level ``input_size``.
    It is then transposed with ``(2, 0, 1)``, so it is expected to be
    height x width x channels. The size check below requires 3 channels and
    a width greater than the height.

    A fresh CPU model is built from the module-level ``config`` on every
    call.

    Args:
        img: Image array; must not be None.

    Returns:
        The first element of ``net(batch, decode=True)``.
    """
    assert img is not None
    resize = Compose([Resize(size=(input_size[0], input_size[1]))])
    resized = resize({"img": img})
    chw = resized["img"].transpose((2, 0, 1))
    tensor = torch.from_numpy(chw).float()

    model = load_model(input_size,
                       config.abc,
                       None,
                       config.backend,
                       config.snapshot,
                       cuda=False)
    # The model must live on the CPU since cuda=False was requested.
    assert not next(model.parameters()).is_cuda
    model = model.eval()

    with torch.no_grad():
        batch = tensor.unsqueeze(0)
        dims = batch.size()
        # One sample, three channels, landscape orientation.
        assert dims[0] == 1 and dims[1] == 3 and dims[2] < dims[3]
        decoded = model(batch, decode=True)
        return decoded[0]
예제 #9
0
def train(
    train_dataloader,
    query_dataloader,
    retrieval_dataloader,
    arch,
    code_length,
    device,
    lr,
    max_iter,
    mu,
    nu,
    eta,
    topk,
    evaluate_interval,
):
    """
    Train a DSDH hashing model, keeping the checkpoint with the best mAP.

    Each iteration alternates a CNN step (gradient updates), a W step
    (closed-form regularized least squares) and a B step (``solve_dcc``).

    Args
        train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.DataLoader): Data loader.
        arch(str): CNN model name.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Maximum number of iterations.
        mu, nu, eta(float): Hyper-parameters.
        topk(int): Compute mAP using top k retrieval result.
        evaluate_interval(int): Evaluate every ``evaluate_interval`` iterations.

    Returns
        checkpoint(dict or None): Codes, one-hot targets, model state and mAP
            from the best evaluation; None if no evaluation ran
            (max_iter < evaluate_interval) or mAP never exceeded 0.
    """
    # Construct network, optimizer, loss
    model = load_model(arch, code_length).to(device)
    criterion = DSDHLoss(eta)
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    scheduler = CosineAnnealingLR(optimizer, max_iter, 1e-7)

    # Initialize: B holds binary codes, U caches network outputs, one column
    # per training sample.
    N = len(train_dataloader.dataset)
    B = torch.randn(code_length, N).sign().to(device)
    U = torch.zeros(code_length, N).to(device)
    train_targets = train_dataloader.dataset.get_onehot_targets().to(device)
    # Pairwise similarity: 1 where two samples share at least one label.
    S = (train_targets @ train_targets.t() > 0).float()
    Y = train_targets.t()
    best_map = 0.
    # Defined up front so the return below never hits an unbound name.
    checkpoint = None
    iter_time = time.time()

    for it in range(max_iter):
        model.train()
        # CNN-step (targets are not needed here; similarity comes from S).
        for data, _, index in train_dataloader:
            data = data.to(device)
            optimizer.zero_grad()

            U_batch = model(data).t()
            U[:, index] = U_batch.data
            loss = criterion(U_batch, U, S[:, index], B[:, index])

            loss.backward()
            optimizer.step()
        scheduler.step()

        # W-step: W = (B B^T + (nu/mu) I)^-1 B Y^T
        W = torch.inverse(B @ B.t() + nu / mu *
                          torch.eye(code_length, device=device)) @ B @ Y.t()

        # B-step
        B = solve_dcc(W, Y, U, B, eta, mu)

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            iter_time = time.time() - iter_time
            epoch_loss = calc_loss(U, S, Y, W, B, mu, nu, eta)

            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length,
                                       device)
            retrieval_code = generate_code(model, retrieval_dataloader,
                                           code_length, device)
            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets(
            )

            # Compute map
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )
            logger.info(
                '[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                    it + 1, max_iter, epoch_loss, mAP, iter_time))

            # Save checkpoint
            if best_map < mAP:
                best_map = mAP
                checkpoint = {
                    'qB': query_code,
                    'qL': query_targets,
                    'rB': retrieval_code,
                    'rL': retrieval_targets,
                    'model': model.state_dict(),
                    'map': best_map,
                }
            iter_time = time.time()

    return checkpoint
예제 #10
0
def _vote_per_position(fold_predictions, max_len=10):
    """Combine aligned per-fold prediction lists by per-character majority vote.

    Args:
        fold_predictions: One list of predicted strings per fold; all lists
            must follow the same sample order.
        max_len: Number of character positions to vote on.

    Returns:
        One voted string per sample, always ``max_len`` characters long. A
        position that no fold predicted becomes ``'*'``. Ties go to the
        character encountered first (``Counter.most_common`` ordering).
    """
    voted = []
    for sample_preds in zip(*fold_predictions):
        chars = []
        for pos in range(max_len):
            candidates = [pred[pos] for pred in sample_preds if len(pred) > pos]
            if candidates:
                chars.append(Counter(candidates).most_common(1)[0][0])
            else:
                chars.append('*')
        voted.append(''.join(chars))
    return voted


def main(data_path, abc, seq_proj, backend, snapshot, input_size, gpu, visualize):
    """Ensemble every fold checkpoint with TTA and write ``tmp_rcnn_tta10_pb.csv``.

    The ``snapshot`` argument is not used; every checkpoint matching
    ``./tmp/fold*_best`` or ``./tmp2/fold*_best`` is evaluated instead. With
    more than one fold, predictions are merged by per-character vote and the
    raw per-fold predictions are dumped to ``fold_tta.pkl``. Submission names
    come from ``../data/desc.json`` (``"pb"`` entries).

    Args:
        data_path: Dataset location for ``TextDataset`` (mode ``"pb"``); when
            None, a ``TestDataset`` built from ``abc`` is used.
        abc: Alphabet used only by the ``TestDataset`` fallback.
        seq_proj: Sequence projection as an ``"AxB"`` string of ints.
        backend: Backbone name forwarded to ``load_model``.
        snapshot: Ignored (see above).
        input_size: ``"AxB"`` string; the two ints become the ``Resize`` size.
        gpu: Value for ``CUDA_VISIBLE_DEVICES``; an empty string disables CUDA.
        visualize: Forwarded to ``test_tta``.

    Raises:
        FileNotFoundError: If no fold checkpoint matches.
    """
    # Honour the ``gpu`` argument (it was previously overridden with '1',
    # which disagreed with the ``cuda`` flag derived from ``gpu``).
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    cuda = gpu != ''

    input_size = [int(x) for x in input_size.split('x')]
    seq_proj = [int(x) for x in seq_proj.split('x')]

    snapshots = glob.glob('./tmp/fold*_best') + glob.glob('./tmp2/fold*_best')
    print(snapshots)
    if not snapshots:
        raise FileNotFoundError(
            "no fold checkpoints matched ./tmp/fold*_best or ./tmp2/fold*_best")

    fold_pred_pb_tta = []
    for snapshot in snapshots:
        # TTA pipeline (Rotation, Scale and Grid_distortion are disabled).
        transform = Compose([
            Translation(),
            Contrast(),
            Resize(size=(input_size[0], input_size[1]))
        ])
        if data_path is not None:
            data = TextDataset(data_path=data_path, mode="pb", transform=transform)
        else:
            data = TestDataset(transform=transform, abc=abc)
        print(snapshot)

        net = load_model(data.get_abc(), seq_proj, backend, snapshot, cuda).eval()
        acc, avg_ed, pred_pb = test_tta(net, data, data.get_abc(), cuda, visualize)
        fold_pred_pb_tta.append(pred_pb)

    with open('../data/desc.json') as up:
        data_json = json.load(up)

    df_submit = pd.DataFrame()
    df_submit['name'] = [x['name'] for x in data_json['pb']]
    if len(fold_pred_pb_tta) > 1:
        fold_pred_pb = _vote_per_position(fold_pred_pb_tta)
        joblib.dump(fold_pred_pb_tta, 'fold_tta.pkl')
        df_submit['label'] = fold_pred_pb
    else:
        df_submit['label'] = fold_pred_pb_tta[0]

    df_submit.to_csv('tmp_rcnn_tta10_pb.csv', index=None)
    # acc / avg_ed are those of the last fold evaluated.
    print("Accuracy: {}".format(acc))
    print("Edit distance: {}".format(avg_ed))
예제 #11
0
def train(
    train_dataloader,
    query_dataloader,
    retrieval_dataloader,
    arch,
    code_length,
    device,
    eta,
    lr,
    max_iter,
    topk,
    evaluate_interval,
):
    """
    Train a DPSH hashing model, keeping the checkpoint with the best mAP.

    Args
        train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        arch(str): CNN model name.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        eta(float): Hyper-parameter.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        topk(int): Calculate map of top k.
        evaluate_interval(int): Evaluate every ``evaluate_interval`` iterations.

    Returns
        checkpoint(dict or None): Codes, one-hot targets, model state and mAP
            from the best evaluation; None if no evaluation ran
            (max_iter < evaluate_interval) or mAP never exceeded 0.
    """
    # Create model, optimizer, criterion, scheduler
    model = load_model(arch, code_length).to(device)
    criterion = DPSHLoss(eta)
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    scheduler = CosineAnnealingLR(optimizer, max_iter, 1e-7)

    # Initialization: U caches the latest network output for every training
    # sample, one row per dataset index yielded by the loader.
    N = len(train_dataloader.dataset)
    U = torch.zeros(N, code_length).to(device)
    train_targets = train_dataloader.dataset.get_onehot_targets().to(device)

    # Training
    best_map = 0.0
    # Defined up front so the return below never hits an unbound name.
    checkpoint = None
    iter_time = time.time()
    for it in range(max_iter):
        model.train()
        # Sum of batch losses over this iteration (logged at evaluation time).
        running_loss = 0.
        for data, targets, index in train_dataloader:
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()

            # Batch-vs-all similarity: 1 where samples share a label.
            S = (targets @ train_targets.t() > 0).float()
            U_cnn = model(data)
            U[index, :] = U_cnn.data
            loss = criterion(U_cnn, U, S)

            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        scheduler.step()

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            iter_time = time.time() - iter_time

            # Generate hash code and one-hot targets
            query_code = generate_code(model, query_dataloader, code_length,
                                       device)
            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_code = generate_code(model, retrieval_dataloader,
                                           code_length, device)
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets(
            )

            # Compute map
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Save checkpoint
            if best_map < mAP:
                best_map = mAP
                checkpoint = {
                    'qB': query_code,
                    'qL': query_targets,
                    'rB': retrieval_code,
                    'rL': retrieval_targets,
                    'model': model.state_dict(),
                    'map': best_map,
                }
            logger.info(
                '[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                    it + 1,
                    max_iter,
                    running_loss,
                    mAP,
                    iter_time,
                ))
            iter_time = time.time()

    return checkpoint
예제 #12
0
def main(base_data_dir, train_data_path, train_base_dir, orig_eval_data_path,
         orig_eval_base_dir, synth_eval_data_path, synth_eval_base_dir,
         lexicon_path, seq_proj, backend, snapshot, input_height, base_lr,
         elastic_alpha, elastic_sigma, step_size, max_iter, batch_size,
         output_dir, test_iter, show_iter, test_init, use_gpu,
         use_no_font_repeat_data, do_vat, do_at, vat_ratio, test_vat_ratio,
         vat_epsilon, vat_ip, vat_xi, vat_sign, do_comp, comp_ratio,
         do_remove_augs, aug_to_remove, do_beam_search, dropout_conv,
         dropout_rnn, dropout_output, do_ema, do_gray, do_test_vat,
         do_test_entropy, do_test_vat_cnn, do_test_vat_rnn, do_test_rand,
         ada_after_rnn, ada_before_rnn, do_ada_lr, ada_ratio, rnn_hidden_size,
         do_test_pseudo, test_pseudo_ratio, test_pseudo_thresh, do_lr_step,
         do_test_ensemble, test_ensemble_ratio, test_ensemble_thresh):
    num_nets = 4

    train_data_path = os.path.join(base_data_dir, train_data_path)
    train_base_dir = os.path.join(base_data_dir, train_base_dir)
    synth_eval_data_path = os.path.join(base_data_dir, synth_eval_data_path)
    synth_eval_base_dir = os.path.join(base_data_dir, synth_eval_base_dir)

    orig_eval_data_path = os.path.join(base_data_dir, orig_eval_data_path)
    orig_eval_base_dir = os.path.join(base_data_dir, orig_eval_base_dir)
    lexicon_path = os.path.join(base_data_dir, lexicon_path)

    all_parameters = locals()
    cuda = use_gpu
    #print(train_base_dir)
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        tb_writer = TbSummary(output_dir)
        output_dir = os.path.join(output_dir, 'model')
        os.makedirs(output_dir, exist_ok=True)

    with open(lexicon_path, 'rb') as f:
        lexicon = pkl.load(f)
    #print(sorted(lexicon.items(), key=operator.itemgetter(1)))

    with open(os.path.join(output_dir, 'params.txt'), 'w') as f:
        f.writelines(str(all_parameters))
    print(all_parameters)
    print('new vat')

    sin_magnitude = 4
    rotate_max_angle = 2
    train_fonts = [
        'Qomolangma-Betsu', 'Shangshung Sgoba-KhraChen',
        'Shangshung Sgoba-KhraChung', 'Qomolangma-Drutsa'
    ]

    all_args = locals()

    print('doing all transforms :)')
    rand_trans = [
        ElasticAndSine(elastic_alpha=elastic_alpha,
                       elastic_sigma=elastic_sigma,
                       sin_magnitude=sin_magnitude),
        Rotation(angle=rotate_max_angle, fill_value=255),
        ColorGradGausNoise()
    ]
    if do_gray:
        rand_trans = rand_trans + [
            Resize(hight=input_height),
            AddWidth(),
            ToGray(),
            Normalize()
        ]
    else:
        rand_trans = rand_trans + [
            Resize(hight=input_height),
            AddWidth(), Normalize()
        ]

    transform_random = Compose(rand_trans)
    if do_gray:
        transform_simple = Compose(
            [Resize(hight=input_height),
             AddWidth(),
             ToGray(),
             Normalize()])
    else:
        transform_simple = Compose(
            [Resize(hight=input_height),
             AddWidth(), Normalize()])

    if use_no_font_repeat_data:
        print('create dataset')
        train_data = TextDatasetRandomFont(data_path=train_data_path,
                                           lexicon=lexicon,
                                           base_path=train_base_dir,
                                           transform=transform_random,
                                           fonts=train_fonts)
        print('finished creating dataset')
    else:
        print('train data path:\n{}'.format(train_data_path))
        print('train_base_dir:\n{}'.format(train_base_dir))
        train_data = TextDataset(data_path=train_data_path,
                                 lexicon=lexicon,
                                 base_path=train_base_dir,
                                 transform=transform_random,
                                 fonts=train_fonts)
    synth_eval_data = TextDataset(data_path=synth_eval_data_path,
                                  lexicon=lexicon,
                                  base_path=synth_eval_base_dir,
                                  transform=transform_random,
                                  fonts=train_fonts)
    orig_eval_data = TextDataset(data_path=orig_eval_data_path,
                                 lexicon=lexicon,
                                 base_path=orig_eval_base_dir,
                                 transform=transform_simple,
                                 fonts=None)
    if do_test_ensemble:
        orig_vat_data = TextDataset(data_path=orig_eval_data_path,
                                    lexicon=lexicon,
                                    base_path=orig_eval_base_dir,
                                    transform=transform_simple,
                                    fonts=None)

    #else:
    #    train_data = TestDataset(transform=transform, abc=abc).set_mode("train")
    #    synth_eval_data = TestDataset(transform=transform, abc=abc).set_mode("test")
    #    orig_eval_data = TestDataset(transform=transform, abc=abc).set_mode("test")
    seq_proj = [int(x) for x in seq_proj.split('x')]
    nets = []
    optimizers = []
    lr_schedulers = []
    for neti in range(num_nets):
        nets.append(
            load_model(lexicon=train_data.get_lexicon(),
                       seq_proj=seq_proj,
                       backend=backend,
                       snapshot=snapshot,
                       cuda=cuda,
                       do_beam_search=do_beam_search,
                       dropout_conv=dropout_conv,
                       dropout_rnn=dropout_rnn,
                       dropout_output=dropout_output,
                       do_ema=do_ema,
                       ada_after_rnn=ada_after_rnn,
                       ada_before_rnn=ada_before_rnn,
                       rnn_hidden_size=rnn_hidden_size,
                       gpu=neti))
        optimizers.append(
            optim.Adam(nets[neti].parameters(),
                       lr=base_lr,
                       weight_decay=0.0001))
        lr_schedulers.append(
            StepLR(optimizers[neti], step_size=step_size, max_iter=max_iter))
    loss_function = CTCLoss()

    synth_avg_ed_best = float("inf")
    orig_avg_ed_best = float("inf")
    epoch_count = 0

    if do_test_ensemble:
        collate_vat = lambda x: text_collate(x, do_mask=True)
        vat_load = DataLoader(orig_vat_data,
                              batch_size=batch_size,
                              num_workers=4,
                              shuffle=True,
                              collate_fn=collate_vat)
        vat_len = len(vat_load)
        cur_vat = 0
        vat_iter = iter(vat_load)

    loss_domain = torch.nn.NLLLoss()

    while True:
        collate = lambda x: text_collate(
            x, do_mask=(do_vat or ada_before_rnn or ada_after_rnn))
        data_loader = DataLoader(train_data,
                                 batch_size=batch_size,
                                 num_workers=4,
                                 shuffle=True,
                                 collate_fn=collate)
        if do_comp:
            data_loader_comp = DataLoader(train_data_comp,
                                          batch_size=batch_size,
                                          num_workers=4,
                                          shuffle=True,
                                          collate_fn=collate_comp)
            iter_comp = iter(data_loader_comp)

        loss_mean_ctc = []
        loss_mean_total = []
        loss_mean_test_ensemble = []
        num_labels_used_total = 0
        iterator = tqdm(data_loader)
        nll_loss = torch.nn.NLLLoss()
        iter_count = 0
        for iter_num, sample in enumerate(iterator):
            total_iter = (epoch_count * len(data_loader)) + iter_num
            if ((total_iter > 1)
                    and total_iter % test_iter == 0) or (test_init
                                                         and total_iter == 0):
                # epoch_count != 0 and

                print("Test phase")
                for net in nets:
                    net = net.eval()
                    if do_ema:
                        net.start_test()

                synth_acc, synth_avg_ed, synth_avg_no_stop_ed, synth_avg_loss = test(
                    nets,
                    synth_eval_data,
                    synth_eval_data.get_lexicon(),
                    cuda,
                    batch_size=batch_size,
                    visualize=False,
                    tb_writer=tb_writer,
                    n_iter=total_iter,
                    initial_title='val_synth',
                    loss_function=loss_function,
                    output_path=os.path.join(output_dir, 'results'),
                    do_beam_search=False)

                orig_acc, orig_avg_ed, orig_avg_no_stop_ed, orig_avg_loss = test(
                    nets,
                    orig_eval_data,
                    orig_eval_data.get_lexicon(),
                    cuda,
                    batch_size=batch_size,
                    visualize=False,
                    tb_writer=tb_writer,
                    n_iter=total_iter,
                    initial_title='test_orig',
                    loss_function=loss_function,
                    output_path=os.path.join(output_dir, 'results'),
                    do_beam_search=do_beam_search)

                for net in nets:
                    net = net.train()
                #save periodic
                if output_dir is not None and total_iter // 30000:
                    periodic_save = os.path.join(output_dir, 'periodic_save')
                    os.makedirs(periodic_save, exist_ok=True)
                    old_save = glob.glob(os.path.join(periodic_save, '*'))
                    for neti, net in enumerate(nets):
                        torch.save(
                            net.state_dict(),
                            os.path.join(
                                output_dir, "crnn_{}_".format(neti) + backend +
                                "_" + str(total_iter)))

                if orig_avg_no_stop_ed < orig_avg_ed_best:
                    orig_avg_ed_best = orig_avg_no_stop_ed
                if output_dir is not None:
                    for neti, net in enumerate(nets):
                        torch.save(
                            net.state_dict(),
                            os.path.join(
                                output_dir, "crnn_{}_".format(neti) + backend +
                                "_iter_{}".format(total_iter)))

                if synth_avg_no_stop_ed < synth_avg_ed_best:
                    synth_avg_ed_best = synth_avg_no_stop_ed
                if do_ema:
                    for net in nets:
                        net.end_test()
                print(
                    "synth: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}"
                    .format(synth_avg_ed_best, synth_avg_ed,
                            synth_avg_no_stop_ed, synth_acc))
                print(
                    "orig: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}"
                    .format(orig_avg_ed_best, orig_avg_ed, orig_avg_no_stop_ed,
                            orig_acc))
                tb_writer.get_writer().add_scalars(
                    'data/test', {
                        'synth_ed_total': synth_avg_ed,
                        'synth_ed_no_stop': synth_avg_no_stop_ed,
                        'synth_avg_loss': synth_avg_loss,
                        'orig_ed_total': orig_avg_ed,
                        'orig_ed_no_stop': orig_avg_no_stop_ed,
                        'orig_avg_loss': orig_avg_loss
                    }, total_iter)
                if len(loss_mean_ctc) > 0:
                    train_dict = {'mean_ctc_loss': np.mean(loss_mean_ctc)}
                    train_dict = {
                        **train_dict,
                        **{
                            'mean_test_ensemble_loss':
                            np.mean(loss_mean_test_ensemble)
                        }
                    }
                    train_dict = {
                        **train_dict,
                        **{
                            'num_labels_used': num_labels_used_total
                        }
                    }
                    num_labels_used_total = 0
                    print(train_dict)
                    tb_writer.get_writer().add_scalars('data/train',
                                                       train_dict, total_iter)
            '''
            # for multi-gpu support
            if sample["img"].size(0) % len(gpu.split(',')) != 0:
                continue
            '''
            for optimizer in optimizers:
                optimizer.zero_grad()
            imgs = Variable(sample["img"])
            #print("images sizes are:")
            #print(sample["img"].shape)
            if do_vat or ada_after_rnn or ada_before_rnn:
                mask = sample['mask']
            labels_flatten = Variable(sample["seq"]).view(-1)
            label_lens = Variable(sample["seq_len"].int())
            #print("image sequence length is:")
            #print(sample["im_seq_len"])
            #print("label sequence length is:")
            #print(sample["seq_len"].view(1,-1))
            img_seq_lens = sample["im_seq_len"]

            if do_test_ensemble:
                if cur_vat >= vat_len:
                    vat_load = DataLoader(orig_vat_data,
                                          batch_size=batch_size,
                                          num_workers=4,
                                          shuffle=True,
                                          collate_fn=collate_vat)
                    vat_len = len(vat_load)
                    cur_vat = 0
                    vat_iter = iter(vat_load)
                vat_batch = next(vat_iter)
                cur_vat += 1
                vat_mask = vat_batch['mask']
                vat_imgs = Variable(vat_batch["img"])
                vat_img_seq_lens = vat_batch["im_seq_len"]
                all_net_classes = []
                all_net_preds = []

                def run_net_get_classes(neti_net_pair, cur_vat_imgs,
                                        cur_vat_mask, cur_vat_img_seq_lens,
                                        cuda):
                    neti, net = neti_net_pair
                    if cuda:
                        cur_vat_imgs = cur_vat_imgs.cuda(neti)
                        cur_vat_mask = cur_vat_mask.cuda(neti)
                    vat_pred = net.vat_forward(cur_vat_imgs,
                                               cur_vat_img_seq_lens)
                    vat_pred = vat_pred * cur_vat_mask
                    vat_pred = F.softmax(vat_pred,
                                         dim=2).view(-1,
                                                     vat_pred.size()[-1])
                    all_net_preds.append(vat_pred)
                    np_vat_preds = vat_pred.cpu().data.numpy()
                    classes_by_index = np.argmax(np_vat_preds, axis=1)
                    return classes_by_index

                for neti, net in enumerate(nets):
                    if cuda:
                        vat_imgs = vat_imgs.cuda(neti)
                        vat_mask = vat_mask.cuda(neti)
                    vat_pred = net.vat_forward(vat_imgs, vat_img_seq_lens)
                    vat_pred = vat_pred * vat_mask
                    vat_pred = F.softmax(vat_pred,
                                         dim=2).view(-1,
                                                     vat_pred.size()[-1])
                    all_net_preds.append(vat_pred)
                    np_vat_preds = vat_pred.cpu().data.numpy()
                    classes_by_index = np.argmax(np_vat_preds, axis=1)
                    all_net_classes.append(classes_by_index)
                all_net_classes = np.stack(all_net_classes)
                all_net_classes, all_nets_count = stats.mode(all_net_classes,
                                                             axis=0)
                all_net_classes = all_net_classes.reshape(-1)
                all_nets_count = all_nets_count.reshape(-1)
                ens_indices = np.argwhere(
                    all_nets_count > test_ensemble_thresh)
                ens_indices = ens_indices.reshape(-1)
                ens_classes = all_net_classes[
                    all_nets_count > test_ensemble_thresh]
                net_ens_losses = []
                num_labels_used = len(ens_indices)
                for neti, net in enumerate(nets):
                    indices = Variable(
                        torch.from_numpy(ens_indices).cuda(neti))
                    labels = Variable(torch.from_numpy(ens_classes).cuda(neti))
                    net_preds_to_ens = all_net_preds[neti][indices]
                    loss = nll_loss(net_preds_to_ens, labels)
                    net_ens_losses.append(loss.cpu())
            nets_total_losses = []
            nets_ctc_losses = []
            loss_is_inf = False
            for neti, net in enumerate(nets):
                if cuda:
                    imgs = imgs.cuda(neti)
                preds = net(imgs, img_seq_lens)
                loss_ctc = loss_function(
                    preds, labels_flatten,
                    Variable(torch.IntTensor(np.array(img_seq_lens))),
                    label_lens) / batch_size

                if loss_ctc.data[0] in [float("inf"), -float("inf")]:
                    print("warnning: loss should not be inf.")
                    loss_is_inf = True
                    break
                total_loss = loss_ctc

                if do_test_ensemble:
                    total_loss = total_loss + test_ensemble_ratio * net_ens_losses[
                        neti]
                    net_ens_losses[neti] = net_ens_losses[neti].data[0]
                total_loss.backward()
                nets_total_losses.append(total_loss.data[0])
                nets_ctc_losses.append(loss_ctc.data[0])
                nn.utils.clip_grad_norm(net.parameters(), 10.0)
            if loss_is_inf:
                continue
            if -400 < loss_ctc.data[0] < 400:
                loss_mean_ctc.append(np.mean(nets_ctc_losses))
            if -400 < total_loss.data[0] < 400:
                loss_mean_total.append(np.mean(nets_total_losses))
            status = "epoch: {0:5d}; iter_num: {1:5d}; lr: {2:.2E}; loss_mean: {3:.3f}; loss: {4:.3f}".format(
                epoch_count, lr_schedulers[0].last_iter,
                lr_schedulers[0].get_lr(), np.mean(nets_total_losses),
                np.mean(nets_ctc_losses))

            if do_test_ensemble:
                ens_loss = np.mean(net_ens_losses)
                if ens_loss != 0:
                    loss_mean_test_ensemble.append(ens_loss)
                    status += "; loss_ens: {0:.3f}".format(ens_loss)
                    status += "; num_ens_used {}".format(num_labels_used)
                else:
                    loss_mean_test_ensemble.append(0)
                    status += "; loss_ens: {}".format(0)
            iterator.set_description(status)
            for optimizer in optimizers:
                optimizer.step()
            if do_lr_step:
                for lr_scheduler in lr_schedulers:
                    lr_scheduler.step()
            iter_count += 1
        if output_dir is not None:
            for neti, net in enumerate(nets):
                torch.save(
                    net.state_dict(),
                    os.path.join(output_dir,
                                 "crnn_{}_".format(neti) + backend + "_last"))
        epoch_count += 1

    return
예제 #13
0
def main(data_path, abc, seq_proj, backend, snapshot, input_size, base_lr, step_size, max_iter, batch_size, output_dir, test_epoch, test_init, gpu):
    """Train and validate a CRNN text recogniser over 24 cross-validation folds.

    For every fold ``k`` (0..23) a model is built with ``load_model`` (from
    ``snapshot`` if given), trained for 15 epochs on the ``fold{k}_train``
    split and evaluated on the ``fold{k}_test`` split after each epoch.
    Weights are saved to ``output_dir`` whenever validation accuracy beats the
    fold's best so far, and additionally whenever it exceeds 0.985.

    Args:
        data_path: Dataset location for ``TextDataset``; when ``None`` a
            ``TestDataset`` built from ``abc`` is used instead.
        abc: Alphabet for the fallback ``TestDataset``.
        seq_proj: ``"AxB"`` string, parsed to two ints for ``load_model``.
        backend: Network backend name passed to ``load_model``; also part of
            the checkpoint file names.
        snapshot: Optional weights every fold's model is initialised from.
        input_size: ``"AxB"`` string, parsed to the two ints given to ``Resize``.
        base_lr: Adam learning rate.
        step_size: ``StepLR`` step size.
        max_iter, test_epoch, test_init: Not used by this function.
        batch_size: Training batch size.
        output_dir: Directory for checkpoints, or ``None`` to skip saving.
        gpu: Value for ``CUDA_VISIBLE_DEVICES``; ``''`` trains on the CPU.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # Compare by value: ``gpu is not ''`` tests object identity, which is not
    # guaranteed for equal strings (and is a SyntaxWarning on Python 3.8+).
    cuda = gpu != ''
    # Batches are only used when they split evenly across the listed GPUs.
    num_gpus = len(gpu.split(','))

    input_size = [int(x) for x in input_size.split('x')]
    transform = Compose([
        Rotation(),
        Translation(),
        Contrast(),
        Resize(size=(input_size[0], input_size[1]))
    ])
    seq_proj = [int(x) for x in seq_proj.split('x')]

    for fold_idx in range(24):
        train_mode = 'fold{0}_train'.format(fold_idx)
        val_mode = 'fold{0}_test'.format(fold_idx)

        if data_path is not None:
            data = TextDataset(data_path=data_path, mode=train_mode, transform=transform)
        else:
            data = TestDataset(transform=transform, abc=abc)

        # Fresh model / optimiser per fold so folds do not leak into each other.
        net = load_model(data.get_abc(), seq_proj, backend, snapshot, cuda)
        optimizer = optim.Adam(net.parameters(), lr=base_lr, weight_decay=0.0001)
        lr_scheduler = StepLR(optimizer, step_size=step_size)
        loss_function = CTCLoss()

        print(fold_idx)

        acc_best = 0
        for epoch_count in range(15):
            data_loader = DataLoader(data, batch_size=batch_size, num_workers=10, shuffle=True, collate_fn=text_collate)
            loss_mean = []
            iterator = tqdm(data_loader)
            for sample in iterator:
                cur_batch_size = sample["img"].size(0)
                # for multi-gpu support
                if cur_batch_size % num_gpus != 0:
                    continue
                optimizer.zero_grad()
                imgs = Variable(sample["img"])
                labels = Variable(sample["seq"]).view(-1)
                label_lens = Variable(sample["seq_len"].int())
                if cuda:
                    imgs = imgs.cuda()
                preds = net(imgs).cpu()
                # One prediction length per sample of *this* batch: the last
                # batch of an epoch can be smaller than ``batch_size`` because
                # the DataLoader does not drop it.
                pred_lens = Variable(Tensor([preds.size(0)] * cur_batch_size).int())
                loss = loss_function(preds, labels, pred_lens, label_lens) / cur_batch_size
                loss.backward()
                loss_mean.append(loss.data[0])
                status = "{}/{}; lr: {}; loss_mean: {}; loss: {}".format(epoch_count, lr_scheduler.last_iter, lr_scheduler.get_lr(), np.mean(loss_mean), loss.data[0])
                iterator.set_description(status)
                optimizer.step()
                lr_scheduler.step()

            logging.info("Test phase")
            net = net.eval()
            data.set_mode(val_mode)
            acc, avg_ed, error_idx = test(net, data, data.get_abc(), cuda, visualize=False)

            # Only list the misrecognised samples once accuracy is high enough
            # for the list to be short.
            if acc > 0.95:
                error_name = [data.config[data.mode][idx]["name"] for idx in error_idx]
                logging.info('Val: ' + ','.join(error_name))

            net = net.train()
            data.set_mode(train_mode)

            if acc > acc_best:
                if output_dir is not None:
                    torch.save(net.state_dict(), os.path.join(output_dir, train_mode+"_crnn_" + backend + "_" + str(data.get_abc()) + "_best"))
                acc_best = acc

            if acc > 0.985 and output_dir is not None:
                torch.save(net.state_dict(), os.path.join(output_dir, train_mode+"_crnn_" + backend + "_" + str(data.get_abc()) + "_"+str(acc)))
            # Three placeholders for three values (the old message printed
            # acc_best under the "avg_ed" label and dropped avg_ed).
            logging.info("acc: {}\tacc_best: {}; avg_ed: {}\n\n".format(acc, acc_best, avg_ed))
def _noisy_path_for(clean_path):
    """Return the noisy-image path paired with a ground-truth image path.

    Only the file name is rewritten (``GT`` -> ``NOISY``). A ``GT`` substring
    in a parent directory name is left untouched, unlike a plain
    ``str.replace`` on the whole path.
    """
    directory, filename = os.path.split(clean_path)
    return os.path.join(directory, filename.replace('GT', 'NOISY'))


def train():
    """Train the denoising model with per-epoch validation.

    Ground-truth training images are found recursively under
    ``args.train_dataset_base_path`` (``*GT*.PNG``). The noisy counterpart of
    each is the same path with ``GT`` replaced by ``NOISY`` in the file name.
    After each training epoch the model is evaluated on the validation set,
    and ``on_epoch_end`` receives the losses and the best validation metrics
    so far.

    Raises:
        FileNotFoundError: if no ground-truth training images are found.
    """
    device = 'gpu:0' if tf.test.is_gpu_available() else 'cpu'

    args = parse_train_arguments()

    with tf.device(device):

        best_losses = {'validation_psnr': 0, 'validation_ssim': 0}

        losses = {
            'train_loss': Mean(name='train_loss'),
            'train_mse': Mean(name='train_mse'),
            'validation_mse': Mean(name='validation_mse'),
            'validation_psnr': Mean(name='validation_psnr'),
            'validation_ssim': Mean(name='validation_ssim')
        }

        train_dataset_path = glob(os.path.join(args.train_dataset_base_path,
                                               '**/*GT*.PNG'),
                                  recursive=True)
        # Fail early with a clear message instead of training on an empty set.
        if not train_dataset_path:
            raise FileNotFoundError(
                'No *GT*.PNG training images found under {}'.format(
                    args.train_dataset_base_path))
        train_noisy_dataset_path = [
            _noisy_path_for(path_image) for path_image in train_dataset_path
        ]

        dataset = load_train_data(
            train_dataset_path=train_dataset_path,
            train_noisy_dataset_path=train_noisy_dataset_path,
            data_length=args.batch_size * args.batches,
            patch_size=args.patch_size,
            radious=args.radious,
            epsilon=args.epsilon)

        validation_dataset = load_validation_dataset(
            validation_clean_dataset_path=args.validation_clean_dataset_path,
            validation_noisy_dataset_path=args.validation_noisy_dataset_path)

        model, optimizer, initial_epoch, clip_norms = load_model(
            checkpoint_directory=args.checkpoint_directory,
            restore_model=args.restore_model,
            learning_rate=args.learning_rate)

        for epoch in range(initial_epoch, args.epochs + 1):

            # Per-epoch accumulators for the two networks' gradient norms.
            total_clip_norms = [
                tf.cast(0, dtype=tf.float32),
                tf.cast(0, dtype=tf.float32)
            ]
            batched_dataset = dataset.batch(args.batch_size).prefetch(
                buffer_size=tf.data.experimental.AUTOTUNE)
            progress_bar = tqdm(batched_dataset, total=args.batches)

            for index, data_batch in enumerate(progress_bar):
                dnet_new_norm, snet_new_norm = train_step(
                    model, optimizer, data_batch, losses, clip_norms,
                    args.radious)
                on_batch_end(epoch, index, dnet_new_norm, snet_new_norm,
                             total_clip_norms, losses, progress_bar,
                             args.batches)

            validation_progress_bar = tqdm(
                validation_dataset.batch(args.batch_size))

            for validation_data_batch in validation_progress_bar:
                validation_step(model, validation_data_batch, losses)

            on_epoch_end(model, optimizer, epoch, losses, best_losses,
                         clip_norms, total_clip_norms,
                         args.checkpoint_directory)
예제 #15
0
def train(train_dataloader, query_dataloader, retrieval_dataloader, arch, code_length, device, lr,
          max_iter, topk, evaluate_interval, anchor_num, proportion
          ):
    """Train a hashing network supervised by anchors from a momentum model.

    Each of the ``max_iter`` iterations:
      1. runs the momentum model ``model_mo`` over the whole training loader
         without gradients and derives anchors with ``get_anchor``;
      2. for every mini-batch, builds a pairwise similarity matrix ``S`` from
         the momentum features' distances to the anchors and optimises
         ``model`` against it with ``PrototypicalLoss``;
      3. updates each ``model_mo`` parameter as
         ``proportion * k + (1 - proportion) * q`` from ``model``'s parameters.

    Every ``evaluate_interval`` iterations, hash codes are generated for the
    query and retrieval sets and mAP and the PR curve are computed. The best
    evaluation is stored as the checkpoint.

    Args:
        train_dataloader, query_dataloader, retrieval_dataloader: Data loaders
            yielding ``(data, targets, index)`` batches.
        arch: Model name passed to ``load_model`` / ``load_model_mo``.
        code_length: Hash code length.
        device: Torch device.
        lr: Initial RMSprop learning rate (cosine-annealed to ``lr / 100``).
        max_iter: Number of passes over the training loader.
        topk: Cut-off for mAP.
        evaluate_interval: Evaluate every this many iterations.
        anchor_num: Number of anchors requested from ``get_anchor``.
        proportion: Momentum coefficient for the ``model_mo`` update.

    Returns:
        dict or None: The checkpoint with keys ``model``, ``qB``, ``rB``,
        ``qL``, ``rL``, ``P``, ``R``, ``map`` for the best mAP seen, or
        ``None`` if no evaluation produced mAP > 0 (e.g. when
        ``max_iter < evaluate_interval``).
    """
    # Load model
    model = load_model(arch, code_length).to(device)
    model_mo = load_model_mo(arch).to(device)

    # Create criterion, optimizer, scheduler
    criterion = PrototypicalLoss()
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=5e-4,
    )
    scheduler = CosineAnnealingLR(
        optimizer,
        max_iter,
        lr / 100,
    )

    # Initialization
    running_loss = 0.
    best_map = 0.
    training_time = 0.
    # Defined up front so the function never hits an UnboundLocalError when
    # no evaluation improves on best_map.
    checkpoint = None

    # Training
    for it in range(max_iter):
        tic = time.time()

        # Harvest momentum features for the whole training set and derive the
        # anchors from them.
        with torch.no_grad():
            mo_features = []
            for data, _, _ in train_dataloader:
                mo_features.append(model_mo(data.to(device)))
                torch.cuda.empty_cache()
            # Concatenate once; growing a tensor with torch.cat inside the
            # loop re-copies every earlier batch (quadratic in dataset size).
            if mo_features:
                output_mo = torch.cat(mo_features, 0)
            else:
                output_mo = torch.tensor([]).to(device)

            anchor = get_anchor(output_mo, anchor_num, device)  # compute anchor

        # self-supervised deep learning
        model.train()
        for data, _, _ in train_dataloader:
            data = data.to(device)
            optimizer.zero_grad()

            output_B = model(data)

            # The similarity targets come only from the momentum model and are
            # used as constants, so no autograd graph is needed for them.
            with torch.no_grad():
                output_mo_batch = model_mo(data)
                # Squared distance of every sample to every anchor.
                dist = torch.sum((output_mo_batch[:, None, :] - anchor.to(device)) ** 2, dim=2)
                k = dist.size(1)
                dist = torch.exp(-1 * dist / torch.max(dist)).to(device)
                Z_su = torch.ones(k, 1).to(device)
                Z_sum = torch.sqrt(dist.mm(Z_su)) + 1e-12
                Z_simi = torch.div(dist, Z_sum).to(device)
                S = (Z_simi.mm(Z_simi.t()))
                # NOTE(review): this scales S by 2 / (max - min) and shifts by
                # -1; it is not a min-max map to [-1, 1] (that would be
                # 2 * (S - min) / (max - min) - 1). Confirm the intended range.
                S = (2 / (torch.max(S) - torch.min(S))) * S - 1

            loss = criterion(output_B, S)

            running_loss += loss.item()
            loss.backward()

            optimizer.step()

        with torch.no_grad():
            # momentum update: proportion = 0.999 for update
            for param_q, param_k in zip(model.parameters(), model_mo.parameters()):
                param_k.data = param_k.data * proportion + param_q.data * (1. - proportion)

        scheduler.step()
        training_time += time.time() - tic

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length, device)
            retrieval_code = generate_code(model, retrieval_dataloader, code_length, device)

            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets()

            # Compute map
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Compute pr curve
            P, R = pr_curve(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
            )

            # Log
            logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                it + 1,
                max_iter,
                running_loss / evaluate_interval,
                mAP,
                training_time,
            ))
            running_loss = 0.

            # Checkpoint
            if best_map < mAP:
                best_map = mAP

                checkpoint = {
                    'model': model.state_dict(),
                    'qB': query_code.cpu(),
                    'rB': retrieval_code.cpu(),
                    'qL': query_targets.cpu(),
                    'rL': retrieval_targets.cpu(),
                    'P': P,
                    'R': R,
                    'map': best_map,
                }

    return checkpoint
예제 #16
0
    # Create the output directories requested by the options. os.makedirs is
    # called without exist_ok, so each call raises FileExistsError if the
    # directory already exists (e.g. on a re-run).
    # NOTE(review): confirm that failing on an existing directory is intended.
    if opt.frac_save > 0.0:
        # Generated outputs go under <save_dir>/output.
        img_dir = os.path.join(save_dir, 'output')
        os.makedirs(img_dir)

    if opt.save_gt:
        # Ground-truth images are written inside the dataset root (<dataroot>/gt).
        gt_dir = os.path.join(opt.dataroot, 'gt')
        os.makedirs(gt_dir)

    if opt.save_noisy:
        # Noisy images also go inside the dataset root, in a directory named
        # after the noise configuration: <noise_type>_<noise_stat1>_<noise_stat2>.
        noisy_dir = os.path.join(
            opt.dataroot, '{}_{}_{}'.format(opt.noise_type, opt.noise_stat1,
                                            opt.noise_stat2))
        os.makedirs(noisy_dir)

# Build the network and an unshuffled loader over the test split.
model = load_model(opt)
test_dataset = load_dataset(opt, 'test')
test_loader = DataLoader(
    test_dataset,
    batch_size=opt.test_bsz,
    shuffle=False,
    num_workers=int(opt.nThreads),
)

num_test_bs = len(test_loader)

# Restore the trained generator weights saved under <model_save>/train.
load_dir = os.path.join(opt.model_save, 'train')
if opt.model_type != 'cyclegan':
    # NOTE(review): unlike the cyclegan branch, netG is not switched to eval
    # mode here. This may be deliberate (pix2pix-style test-time dropout);
    # confirm.
    g_state = torch.load(os.path.join(load_dir, 'G_net.pth'))
    model.netG.load_state_dict(g_state)
else:
    g_a_state = torch.load(os.path.join(load_dir, 'G_A_net.pth'))
    model.netG_A.load_state_dict(g_a_state)
    model.netG_A.eval()
예제 #17
0
def main(data_path, abc, seq_proj, backend, snapshot, input_size, base_lr,
         step_size, max_iter, batch_size, output_dir, test_epoch, test_init,
         gpu):
    """Train a CRNN text recogniser for 50 epochs with periodic testing.

    A test pass on the ``"test"`` split runs at the start of an epoch when
    ``test_init`` is set (epoch 0) or when ``test_epoch`` divides the non-zero
    epoch index. The best-accuracy weights are saved as ``crnn_<backend>__best``
    and the latest weights as ``crnn_<backend>__last`` after every epoch.

    Args:
        data_path: Dataset location for ``TextDataset``; when ``None`` a
            ``TestDataset`` built from ``abc`` is used instead.
        abc: Alphabet for the fallback ``TestDataset``.
        seq_proj: ``"AxB"`` string, parsed to two ints for ``load_model``.
        backend: Network backend name passed to ``load_model``; also part of
            the checkpoint file names.
        snapshot: Optional weights to initialise the model from.
        input_size: ``"AxB"`` string, parsed to the two ints given to ``Resize``.
        base_lr: Adam learning rate.
        step_size, max_iter: Passed to ``StepLR``.
        batch_size: Training batch size.
        output_dir: Directory for checkpoints, or ``None`` to skip saving.
        test_epoch: Run a test pass every this many epochs (``None`` disables).
        test_init: Run a test pass before the first epoch.
        gpu: Value for ``CUDA_VISIBLE_DEVICES``; ``''`` trains on the CPU.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    cuda = gpu != ''
    # Batches are only used when they split evenly across the listed GPUs.
    num_gpus = len(gpu.split(','))

    input_size = [int(x) for x in input_size.split('x')]
    transform = Compose([
        Rotation(),
        Resize(size=(input_size[0], input_size[1]))
    ])
    if data_path is not None:
        data = TextDataset(data_path=data_path,
                           mode="train",
                           transform=transform)
    else:
        data = TestDataset(transform=transform, abc=abc)
    seq_proj = [int(x) for x in seq_proj.split('x')]
    net = load_model(data.get_abc(), seq_proj, backend, snapshot, cuda)
    optimizer = optim.Adam(net.parameters(), lr=base_lr, weight_decay=0.0001)
    lr_scheduler = StepLR(optimizer, step_size=step_size, max_iter=max_iter)
    loss_function = CTCLoss()

    acc_best = 0
    for epoch_count in range(50):
        run_test = ((test_epoch is not None and epoch_count != 0
                     and epoch_count % test_epoch == 0)
                    or (test_init and epoch_count == 0))
        if run_test:
            print("Test phase")
            data.set_mode("test")
            net = net.eval()
            acc, avg_ed = test(net,
                               data,
                               data.get_abc(),
                               cuda,
                               visualize=False)
            net = net.train()
            data.set_mode("train")
            if acc > acc_best:
                if output_dir is not None:
                    torch.save(
                        net.state_dict(),
                        os.path.join(output_dir,
                                     "crnn_" + backend + "_" + "_best"))
                acc_best = acc
            print("acc: {}\tacc_best: {}; avg_ed: {}".format(
                acc, acc_best, avg_ed))

        data_loader = DataLoader(data,
                                 batch_size=batch_size,
                                 num_workers=1,
                                 shuffle=True,
                                 collate_fn=text_collate)
        loss_mean = []
        iterator = tqdm(data_loader)
        for sample in iterator:
            cur_batch_size = sample["img"].size(0)
            # for multi-gpu support
            if cur_batch_size % num_gpus != 0:
                continue
            optimizer.zero_grad()
            imgs = Variable(sample["img"])
            labels = Variable(sample["seq"]).view(-1)
            label_lens = Variable(sample["seq_len"].int())
            if cuda:
                imgs = imgs.cuda()
            preds = net(imgs).cpu()
            # One prediction length per sample of *this* batch: the final
            # batch of an epoch can be smaller than ``batch_size``.
            pred_lens = Variable(Tensor([preds.size(0)] * cur_batch_size).int())
            loss = loss_function(preds, labels, pred_lens,
                                 label_lens) / cur_batch_size
            loss.backward()
            nn.utils.clip_grad_norm(net.parameters(), 10.0)
            loss_mean.append(loss.data[0])
            status = "epoch: {}; iter: {}; lr: {}; loss_mean: {}; loss: {}".format(
                epoch_count, lr_scheduler.last_iter, lr_scheduler.get_lr(),
                np.mean(loss_mean), loss.data[0])
            iterator.set_description(status)
            optimizer.step()
            lr_scheduler.step()
        if output_dir is not None:
            torch.save(
                net.state_dict(),
                os.path.join(output_dir, "crnn_" + backend + "_" + "_last"))
    return
예제 #18
0
def train(
    train_dataloader,
    query_dataloader,
    retrieval_dataloader,
    arch,
    code_length,
    device,
    lr,
    max_iter,
    alpha,
    topk,
    evaluate_interval,
):
    """
    Training model.

    Args
        train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
        arch(str): CNN model name.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Number of iterations.
        alpha(float): Hyper-parameters.
        topk(int): Compute top k map.
        evaluate_interval(int): Interval of evaluation.

    Returns
        checkpoint(dict or None): Checkpoint of the best evaluated model, or
            None if no evaluation produced mAP > 0 (e.g. when
            max_iter < evaluate_interval).
    """
    # Load model
    model = load_model(arch, code_length).to(device)

    # Create criterion, optimizer, scheduler
    criterion = HashNetLoss(alpha)
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=5e-4,
    )
    scheduler = CosineAnnealingLR(
        optimizer,
        max_iter,
        lr / 100,
    )

    # Initialization
    running_loss = 0.
    best_map = 0.
    training_time = 0.
    # Defined up front so the function never hits an UnboundLocalError when
    # no evaluation improves on best_map.
    checkpoint = None

    # Training
    # In this implementation, I do not use "scaled tanh".
    # It is useless and hard to tune parameters, sometimes it may decrease performance.
    # Refer to https://github.com/thuml/HashNet/issues/29
    for it in range(max_iter):
        tic = time.time()
        for data, targets, _ in train_dataloader:
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()

            # Pairwise similarity: 1 where two samples share any label.
            S = (targets @ targets.t() > 0).float()
            outputs = model(data)
            loss = criterion(outputs, S)

            running_loss += loss.item()
            loss.backward()
            optimizer.step()
        scheduler.step()
        training_time += time.time() - tic

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length,
                                       device)
            retrieval_code = generate_code(model, retrieval_dataloader,
                                           code_length, device)

            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets()

            # Compute map
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Compute pr curve
            P, R = pr_curve(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
            )

            # Log
            logger.info(
                '[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                    it + 1,
                    max_iter,
                    running_loss / evaluate_interval,
                    mAP,
                    training_time,
                ))
            running_loss = 0.

            # Checkpoint
            if best_map < mAP:
                best_map = mAP

                checkpoint = {
                    'model': model.state_dict(),
                    'qB': query_code.cpu(),
                    'rB': retrieval_code.cpu(),
                    'qL': query_targets.cpu(),
                    'rL': retrieval_targets.cpu(),
                    'P': P,
                    'R': R,
                    'map': best_map,
                }

    return checkpoint
예제 #19
0
def main(base_data_dir, train_data_path, train_base_dir, orig_eval_data_path,
         orig_eval_base_dir, synth_eval_data_path, synth_eval_base_dir,
         lexicon_path, seq_proj, backend, snapshot, input_height, base_lr,
         elastic_alpha, elastic_sigma, step_size, max_iter, batch_size,
         output_dir, test_iter, show_iter, test_init, use_gpu,
         use_no_font_repeat_data, do_vat, do_at, vat_ratio, test_vat_ratio,
         vat_epsilon, vat_ip, vat_xi, vat_sign, do_remove_augs, aug_to_remove,
         do_beam_search, dropout_conv, dropout_rnn, dropout_output, do_ema,
         do_gray, do_test_vat, do_test_entropy, do_test_vat_cnn,
         do_test_vat_rnn, ada_after_rnn, ada_before_rnn, do_ada_lr, ada_ratio,
         rnn_hidden_size, do_lr_step, dataset_name):
    """Train a CRNN text recognizer with CTC loss and optional regularizers.

    Builds the augmentation pipelines and datasets, loads the network with
    ``load_model`` and runs an open-ended epoch loop.

    Every ``test_iter`` global iterations (and at iteration 0 when
    ``test_init`` is set) the net is evaluated on the synthetic and original
    eval sets. Metrics are logged to TensorBoard, and the checkpoint with the
    best original-set edit distance is saved as ``crnn_<backend>_best``.

    Optional extra loss terms are selected by flags:

    * ``do_vat``: ``vat_ratio`` * VAT loss on the training batch.
    * ``do_at``: labeled adversarial-training loss. It is computed and logged
      only (see the NOTE in the body).
    * ``do_test_vat`` / ``do_test_vat_rnn`` / ``do_test_vat_cnn``:
      ``test_vat_ratio`` * VAT loss on unlabeled batches drawn from the
      original eval set.
    * ``ada_before_rnn`` / ``ada_after_rnn``: ``ada_ratio`` * NLL of the
      net's domain outputs (``ada_cnn`` / ``ada_rnn``). The domain label is
      1 for training batches and 0 for original-eval batches.

    Args:
        base_data_dir: Root directory. All data/lexicon path arguments are
            joined onto it.
        lexicon_path: Pickle file loaded with ``pkl.load`` and passed to the
            datasets as ``lexicon``.
        seq_proj: ``x``-separated ints (e.g. ``"10x20"``) for ``load_model``.
        backend, snapshot, dropout_conv, dropout_rnn, dropout_output, do_ema,
            rnn_hidden_size: Forwarded to ``load_model``. ``backend`` is also
            used in checkpoint filenames.
        input_height: Height passed to ``Resize``.
        base_lr: Adam learning rate (weight decay 1e-4).
        elastic_alpha, elastic_sigma: Elastic-distortion parameters.
        step_size, max_iter: Scheduler parameters. ``max_iter`` is also the
            denominator of the DANN progress ``ada_p``.
        batch_size: Batch size of every DataLoader. The CTC loss is divided
            by it.
        output_dir: Directory for ``params.txt``, TensorBoard logs and
            checkpoints (checkpoints go to ``<output_dir>/model``).
            Effectively required: ``params.txt`` and ``tb_writer`` are used
            unconditionally.
        test_iter, test_init: Evaluation schedule (see above).
        show_iter: Unused.
        use_gpu: Move inputs to CUDA.
        use_no_font_repeat_data: Train on ``TextDatasetRandomFont`` instead
            of ``TextDataset``.
        do_vat, do_at, vat_ratio, test_vat_ratio, vat_epsilon, vat_ip,
            vat_xi, vat_sign, do_test_vat, do_test_vat_rnn, do_test_vat_cnn,
            do_test_entropy: VAT / AT settings (see above).
        do_remove_augs, aug_to_remove: Drop one augmentation (ablation).
        do_beam_search: Use beam search when testing on the original set.
        do_gray: Append ``ToGray`` to both transform pipelines.
        ada_before_rnn, ada_after_rnn, ada_ratio: Domain-adversarial settings.
        do_ada_lr, do_lr_step: Use ``DannLR`` or ``StepLR`` respectively.
        dataset_name: Lower-cased and passed to ``SynthDataInfo``, which
            supplies the training fonts.

    Returns:
        Never returns normally: the epoch loop is ``while True`` with no exit.

    Raises:
        NotImplementedError: if neither ``do_lr_step`` nor ``do_ada_lr`` is
            set.
        ValueError: for incompatible combinations of loss flags.
        Exception: if ``aug_to_remove`` is not an allowed value.
    """
    if not do_lr_step and not do_ada_lr:
        raise NotImplementedError(
            'learning rate should be either step or ada.')
    # Reject loss-flag combinations that the training step cannot handle.
    # Previously these reached a plain-string ``raise`` (itself a TypeError)
    # or referenced a loss that was never computed (NameError) in the first
    # iteration.
    if do_test_vat and (do_test_vat_rnn or do_test_vat_cnn):
        raise ValueError(
            'can only do one of do_test_vat | '
            '(do_test_vat_rnn, do_test_vat_cnn)')
    if do_vat and (do_test_vat or do_test_vat_rnn or do_test_vat_cnn):
        raise ValueError('do_vat cannot be combined with do_test_vat, '
                         'do_test_vat_rnn or do_test_vat_cnn')
    if do_at and (do_vat or do_test_vat_rnn or do_test_vat_cnn):
        raise ValueError('do_at cannot be combined with do_vat, '
                         'do_test_vat_rnn or do_test_vat_cnn')
    train_data_path = os.path.join(base_data_dir, train_data_path)
    train_base_dir = os.path.join(base_data_dir, train_base_dir)
    synth_eval_data_path = os.path.join(base_data_dir, synth_eval_data_path)
    synth_eval_base_dir = os.path.join(base_data_dir, synth_eval_base_dir)
    orig_eval_data_path = os.path.join(base_data_dir, orig_eval_data_path)
    orig_eval_base_dir = os.path.join(base_data_dir, orig_eval_base_dir)
    lexicon_path = os.path.join(base_data_dir, lexicon_path)

    # Snapshot of all arguments (with resolved paths) for params.txt.
    # No other locals may be defined before this line.
    all_parameters = locals()
    cuda = use_gpu
    # NOTE(review): tb_writer is only defined here, but it is used later
    # unconditionally, so output_dir=None does not actually work.
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        tb_writer = TbSummary(output_dir)
        output_dir = os.path.join(output_dir, 'model')
        os.makedirs(output_dir, exist_ok=True)

    with open(lexicon_path, 'rb') as f:
        lexicon = pkl.load(f)

    with open(os.path.join(output_dir, 'params.txt'), 'w') as f:
        f.writelines(str(all_parameters))
    print(all_parameters)
    print('new vat')

    sin_magnitude = 4
    rotate_max_angle = 2
    dataset_info = SynthDataInfo(None, None, None, dataset_name.lower())
    train_fonts = dataset_info.font_names

    allowed_removals = [
        'elastic', 'sine', 'sine_rotate', 'rotation', 'color_aug',
        'color_gaus', 'color_sine'
    ]
    if do_remove_augs and aug_to_remove not in allowed_removals:
        raise Exception('augmentation removal value is not allowed.')

    # Random augmentations, optionally with one of them removed (ablation).
    if do_remove_augs:
        rand_trans = []
        if aug_to_remove == 'elastic':
            print('doing sine transform :)')
            rand_trans.append(OnlySine(sin_magnitude=sin_magnitude))
        elif aug_to_remove in ['sine', 'sine_rotate']:
            print('doing elastic transform :)')
            rand_trans.append(
                OnlyElastic(elastic_alpha=elastic_alpha,
                            elastic_sigma=elastic_sigma))
        if aug_to_remove not in ['elastic', 'sine', 'sine_rotate']:
            print('doing elastic transform :)')
            print('doing sine transform :)')
            rand_trans.append(
                ElasticAndSine(elastic_alpha=elastic_alpha,
                               elastic_sigma=elastic_sigma,
                               sin_magnitude=sin_magnitude))
        if aug_to_remove not in ['rotation', 'sine_rotate']:
            print('doing rotation transform :)')
            rand_trans.append(Rotation(angle=rotate_max_angle, fill_value=255))
        if aug_to_remove not in ['color_aug', 'color_gaus', 'color_sine']:
            print('doing color_aug transform :)')
            rand_trans.append(ColorGradGausNoise())
        elif aug_to_remove == 'color_gaus':
            print('doing color_sine transform :)')
            rand_trans.append(ColorGrad())
        elif aug_to_remove == 'color_sine':
            print('doing color_gaus transform :)')
            rand_trans.append(ColorGausNoise())
    else:
        print('doing all transforms :)')
        rand_trans = [
            ElasticAndSine(elastic_alpha=elastic_alpha,
                           elastic_sigma=elastic_sigma,
                           sin_magnitude=sin_magnitude),
            Rotation(angle=rotate_max_angle, fill_value=255),
            ColorGradGausNoise()
        ]

    def _tail_transforms():
        # Fresh Resize/AddWidth/(ToGray)/Normalize instances for one pipeline.
        tail = [Resize(hight=input_height), AddWidth()]
        if do_gray:
            tail.append(ToGray())
        tail.append(Normalize())
        return tail

    transform_random = Compose(rand_trans + _tail_transforms())
    transform_simple = Compose(_tail_transforms())

    if use_no_font_repeat_data:
        print('creating dataset')
        train_data = TextDatasetRandomFont(data_path=train_data_path,
                                           lexicon=lexicon,
                                           base_path=train_base_dir,
                                           transform=transform_random,
                                           fonts=train_fonts)
        print('finished creating dataset')
    else:
        print('train data path:\n{}'.format(train_data_path))
        print('train_base_dir:\n{}'.format(train_base_dir))
        train_data = TextDataset(data_path=train_data_path,
                                 lexicon=lexicon,
                                 base_path=train_base_dir,
                                 transform=transform_random,
                                 fonts=train_fonts)
    # NOTE(review): the synthetic eval set uses the random (augmenting)
    # transform, unlike the original eval set; confirm this is intended.
    synth_eval_data = TextDataset(data_path=synth_eval_data_path,
                                  lexicon=lexicon,
                                  base_path=synth_eval_base_dir,
                                  transform=transform_random,
                                  fonts=train_fonts)
    orig_eval_data = TextDataset(data_path=orig_eval_data_path,
                                 lexicon=lexicon,
                                 base_path=orig_eval_base_dir,
                                 transform=transform_simple,
                                 fonts=None)
    # Unlabeled target-domain data for test-time VAT and domain adaptation.
    if do_test_vat or do_test_vat_rnn or do_test_vat_cnn:
        orig_vat_data = TextDataset(data_path=orig_eval_data_path,
                                    lexicon=lexicon,
                                    base_path=orig_eval_base_dir,
                                    transform=transform_simple,
                                    fonts=None)

    if ada_after_rnn or ada_before_rnn:
        orig_ada_data = TextDataset(data_path=orig_eval_data_path,
                                    lexicon=lexicon,
                                    base_path=orig_eval_base_dir,
                                    transform=transform_simple,
                                    fonts=None)

    seq_proj = [int(x) for x in seq_proj.split('x')]
    net = load_model(lexicon=train_data.get_lexicon(),
                     seq_proj=seq_proj,
                     backend=backend,
                     snapshot=snapshot,
                     cuda=cuda,
                     do_beam_search=do_beam_search,
                     dropout_conv=dropout_conv,
                     dropout_rnn=dropout_rnn,
                     dropout_output=dropout_output,
                     do_ema=do_ema,
                     ada_after_rnn=ada_after_rnn,
                     ada_before_rnn=ada_before_rnn,
                     rnn_hidden_size=rnn_hidden_size)
    optimizer = optim.Adam(net.parameters(), lr=base_lr, weight_decay=0.0001)
    # Exactly one branch is taken thanks to the check at the top.
    if do_ada_lr:
        print('using ada lr')
        lr_scheduler = DannLR(optimizer, max_iter=max_iter)
    elif do_lr_step:
        print('using step lr')
        lr_scheduler = StepLR(optimizer,
                              step_size=step_size,
                              max_iter=max_iter)
    loss_function = CTCLoss()

    synth_avg_ed_best = float("inf")
    orig_avg_ed_best = float("inf")
    epoch_count = 0

    # Auxiliary loaders are consumed manually and re-created when exhausted.
    if do_test_vat or do_test_vat_rnn or do_test_vat_cnn:
        collate_vat = lambda x: text_collate(x, do_mask=True)
        vat_load = DataLoader(orig_vat_data,
                              batch_size=batch_size,
                              num_workers=4,
                              shuffle=True,
                              collate_fn=collate_vat)
        vat_len = len(vat_load)
        cur_vat = 0
        vat_iter = iter(vat_load)
    if ada_after_rnn or ada_before_rnn:
        collate_ada = lambda x: text_collate(x, do_mask=True)
        ada_load = DataLoader(orig_ada_data,
                              batch_size=batch_size,
                              num_workers=4,
                              shuffle=True,
                              collate_fn=collate_ada)
        ada_len = len(ada_load)
        cur_ada = 0
        ada_iter = iter(ada_load)

    loss_domain = torch.nn.NLLLoss()

    while True:
        collate = lambda x: text_collate(
            x, do_mask=(do_vat or ada_before_rnn or ada_after_rnn))
        data_loader = DataLoader(train_data,
                                 batch_size=batch_size,
                                 num_workers=4,
                                 shuffle=True,
                                 collate_fn=collate)

        # Per-epoch loss histories used for the TensorBoard train summary.
        loss_mean_ctc = []
        loss_mean_vat = []
        loss_mean_at = []
        loss_mean_total = []
        loss_mean_test_vat = []
        loss_mean_ada_rnn_s = []
        loss_mean_ada_rnn_t = []
        loss_mean_ada_cnn_s = []
        loss_mean_ada_cnn_t = []
        iterator = tqdm(data_loader)
        iter_count = 0
        for iter_num, sample in enumerate(iterator):
            total_iter = (epoch_count * len(data_loader)) + iter_num
            if ((total_iter > 1) and total_iter % test_iter == 0) or (
                    test_init and total_iter == 0):
                print("Test phase")
                net = net.eval()
                if do_ema:
                    net.start_test()

                synth_acc, synth_avg_ed, synth_avg_no_stop_ed, synth_avg_loss = test(
                    net,
                    synth_eval_data,
                    synth_eval_data.get_lexicon(),
                    cuda,
                    visualize=False,
                    dataset_info=dataset_info,
                    batch_size=batch_size,
                    tb_writer=tb_writer,
                    n_iter=total_iter,
                    initial_title='val_synth',
                    loss_function=loss_function,
                    output_path=os.path.join(output_dir, 'results'),
                    do_beam_search=False)

                orig_acc, orig_avg_ed, orig_avg_no_stop_ed, orig_avg_loss = test(
                    net,
                    orig_eval_data,
                    orig_eval_data.get_lexicon(),
                    cuda,
                    visualize=False,
                    dataset_info=dataset_info,
                    batch_size=batch_size,
                    tb_writer=tb_writer,
                    n_iter=total_iter,
                    initial_title='test_orig',
                    loss_function=loss_function,
                    output_path=os.path.join(output_dir, 'results'),
                    do_beam_search=do_beam_search)

                net = net.train()
                # NOTE(review): ``total_iter // 30000`` is truthy for every
                # test phase once total_iter >= 30000. From then on, every
                # evaluation writes a new checkpoint into output_dir, not
                # into the ``periodic_save`` directory created here. Confirm
                # the intended schedule and location.
                if output_dir is not None and total_iter // 30000:
                    periodic_save = os.path.join(output_dir, 'periodic_save')
                    os.makedirs(periodic_save, exist_ok=True)

                    torch.save(
                        net.state_dict(),
                        os.path.join(output_dir, "crnn_" + backend + "_" +
                                     str(total_iter)))

                # Best checkpoint is selected by original-set edit distance.
                if orig_avg_no_stop_ed < orig_avg_ed_best:
                    orig_avg_ed_best = orig_avg_no_stop_ed
                    if output_dir is not None:
                        torch.save(
                            net.state_dict(),
                            os.path.join(output_dir,
                                         "crnn_" + backend + "_best"))

                if synth_avg_no_stop_ed < synth_avg_ed_best:
                    synth_avg_ed_best = synth_avg_no_stop_ed
                if do_ema:
                    net.end_test()
                print(
                    "synth: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}"
                    .format(synth_avg_ed_best, synth_avg_ed,
                            synth_avg_no_stop_ed, synth_acc))
                print(
                    "orig: avg_ed_best: {}\t avg_ed: {}; avg_nostop_ed: {}; acc: {}"
                    .format(orig_avg_ed_best, orig_avg_ed, orig_avg_no_stop_ed,
                            orig_acc))
                tb_writer.get_writer().add_scalars(
                    'data/test', {
                        'synth_ed_total': synth_avg_ed,
                        'synth_ed_no_stop': synth_avg_no_stop_ed,
                        'synth_avg_loss': synth_avg_loss,
                        'orig_ed_total': orig_avg_ed,
                        'orig_ed_no_stop': orig_avg_no_stop_ed,
                        'orig_avg_loss': orig_avg_loss
                    }, total_iter)
                if loss_mean_ctc:
                    train_dict = {'mean_ctc_loss': np.mean(loss_mean_ctc)}
                    if do_vat:
                        train_dict['mean_vat_loss'] = np.mean(loss_mean_vat)
                    if do_at:
                        train_dict['mean_at_loss'] = np.mean(loss_mean_at)
                    if do_test_vat:
                        train_dict['mean_test_vat_loss'] = np.mean(
                            loss_mean_test_vat)
                    if do_test_vat_rnn and do_test_vat_cnn:
                        train_dict['mean_test_vat_crnn_loss'] = np.mean(
                            loss_mean_test_vat)
                    elif do_test_vat_rnn:
                        train_dict['mean_test_vat_rnn_loss'] = np.mean(
                            loss_mean_test_vat)
                    elif do_test_vat_cnn:
                        train_dict['mean_test_vat_cnn_loss'] = np.mean(
                            loss_mean_test_vat)
                    if ada_after_rnn:
                        train_dict['mean_ada_rnn_s_loss'] = np.mean(
                            loss_mean_ada_rnn_s)
                        train_dict['mean_ada_rnn_t_loss'] = np.mean(
                            loss_mean_ada_rnn_t)
                    if ada_before_rnn:
                        train_dict['mean_ada_cnn_s_loss'] = np.mean(
                            loss_mean_ada_cnn_s)
                        train_dict['mean_ada_cnn_t_loss'] = np.mean(
                            loss_mean_ada_cnn_t)
                    print(train_dict)
                    tb_writer.get_writer().add_scalars('data/train',
                                                       train_dict, total_iter)

            optimizer.zero_grad()
            imgs = Variable(sample["img"])
            if do_vat or ada_after_rnn or ada_before_rnn:
                mask = sample['mask']
            labels_flatten = Variable(sample["seq"]).view(-1)
            label_lens = Variable(sample["seq_len"].int())
            img_seq_lens = sample["im_seq_len"]
            if cuda:
                imgs = imgs.cuda()
                if do_vat or ada_after_rnn or ada_before_rnn:
                    mask = mask.cuda()

            # NOTE(review): iter_count restarts at 0 every epoch, so ada_p
            # (and with it the DANN lr / alpha schedule) restarts each epoch
            # rather than spanning max_iter overall. Confirm this is intended.
            if do_ada_lr:
                ada_p = float(iter_count) / max_iter
                lr_scheduler.update(ada_p)

            # Domain-adversarial loss on a target batch (domain label 0).
            if ada_before_rnn or ada_after_rnn:
                if not do_ada_lr:
                    ada_p = float(iter_count) / max_iter
                ada_alpha = 2. / (1. + np.exp(-10. * ada_p)) - 1

                if cur_ada >= ada_len:
                    ada_load = DataLoader(orig_ada_data,
                                          batch_size=batch_size,
                                          num_workers=4,
                                          shuffle=True,
                                          collate_fn=collate_ada)
                    ada_len = len(ada_load)
                    cur_ada = 0
                    ada_iter = iter(ada_load)
                ada_batch = next(ada_iter)
                cur_ada += 1
                ada_imgs = Variable(ada_batch["img"])
                ada_img_seq_lens = ada_batch["im_seq_len"]
                ada_mask = ada_batch['mask'].byte()
                if cuda:
                    ada_imgs = ada_imgs.cuda()

                _, ada_cnn, ada_rnn = net(ada_imgs,
                                          ada_img_seq_lens,
                                          ada_alpha=ada_alpha,
                                          mask=ada_mask)
                if ada_before_rnn:
                    ada_num_features = ada_cnn.size(0)
                else:
                    ada_num_features = ada_rnn.size(0)
                domain_label = torch.zeros(ada_num_features)
                domain_label = domain_label.long()
                if cuda:
                    domain_label = domain_label.cuda()
                domain_label = Variable(domain_label)

                if ada_before_rnn:
                    err_ada_cnn_t = loss_domain(ada_cnn, domain_label)
                if ada_after_rnn:
                    err_ada_rnn_t = loss_domain(ada_rnn, domain_label)

            if do_test_vat and do_at:
                # Test (unlabeled) part.
                if cur_vat >= vat_len:
                    vat_load = DataLoader(orig_vat_data,
                                          batch_size=batch_size,
                                          num_workers=4,
                                          shuffle=True,
                                          collate_fn=collate_vat)
                    vat_len = len(vat_load)
                    cur_vat = 0
                    vat_iter = iter(vat_load)
                test_vat_batch = next(vat_iter)
                cur_vat += 1
                test_vat_mask = test_vat_batch['mask']
                test_vat_imgs = Variable(test_vat_batch["img"])
                test_vat_img_seq_lens = test_vat_batch["im_seq_len"]
                if cuda:
                    test_vat_imgs = test_vat_imgs.cuda()
                    test_vat_mask = test_vat_mask.cuda()
                # Train (labeled) part.
                at_test_vat_loss = LabeledAtAndUnlabeledTestVatLoss(
                    xi=vat_xi, eps=vat_epsilon, ip=vat_ip)

                at_loss, test_vat_loss = at_test_vat_loss(
                    model=net,
                    train_x=imgs,
                    train_labels_flatten=labels_flatten,
                    train_img_seq_lens=img_seq_lens,
                    train_label_lens=label_lens,
                    batch_size=batch_size,
                    test_x=test_vat_imgs,
                    test_seq_len=test_vat_img_seq_lens,
                    test_mask=test_vat_mask)
            elif do_test_vat or do_test_vat_rnn or do_test_vat_cnn:
                if cur_vat >= vat_len:
                    vat_load = DataLoader(orig_vat_data,
                                          batch_size=batch_size,
                                          num_workers=4,
                                          shuffle=True,
                                          collate_fn=collate_vat)
                    vat_len = len(vat_load)
                    cur_vat = 0
                    vat_iter = iter(vat_load)
                vat_batch = next(vat_iter)
                cur_vat += 1
                vat_mask = vat_batch['mask']
                vat_imgs = Variable(vat_batch["img"])
                vat_img_seq_lens = vat_batch["im_seq_len"]
                if cuda:
                    vat_imgs = vat_imgs.cuda()
                    vat_mask = vat_mask.cuda()
                # do_test_vat together with rnn/cnn is rejected up front.
                if do_test_vat:
                    if vat_sign == True:
                        test_vat_criterion = VATLossSign(
                            do_test_entropy=do_test_entropy,
                            xi=vat_xi,
                            eps=vat_epsilon,
                            ip=vat_ip)
                    else:
                        test_vat_criterion = VATLoss(xi=vat_xi,
                                                     eps=vat_epsilon,
                                                     ip=vat_ip)
                elif do_test_vat_rnn and do_test_vat_cnn:
                    test_vat_criterion = VATonRnnCnnSign(xi=vat_xi,
                                                         eps=vat_epsilon,
                                                         ip=vat_ip)
                elif do_test_vat_rnn:
                    test_vat_criterion = VATonRnnSign(xi=vat_xi,
                                                      eps=vat_epsilon,
                                                      ip=vat_ip)
                else:
                    test_vat_criterion = VATonCnnSign(xi=vat_xi,
                                                      eps=vat_epsilon,
                                                      ip=vat_ip)
                if do_test_vat_cnn and do_test_vat_rnn:
                    test_vat_loss, cnn_lds, rnn_lds = test_vat_criterion(
                        net, vat_imgs, vat_img_seq_lens, vat_mask)
                else:
                    # Previously only evaluated for do_test_vat, leaving the
                    # loss *object* in test_vat_loss for the rnn-only /
                    # cnn-only variants.
                    test_vat_loss = test_vat_criterion(net, vat_imgs,
                                                       vat_img_seq_lens,
                                                       vat_mask)
            elif do_vat:
                vat_loss = VATLoss(xi=vat_xi, eps=vat_epsilon, ip=vat_ip)
                vat_loss = vat_loss(net, imgs, img_seq_lens, mask)
            elif do_at:
                at_loss = LabeledATLoss(xi=vat_xi, eps=vat_epsilon, ip=vat_ip)
                at_loss = at_loss(net, imgs, labels_flatten, img_seq_lens,
                                  label_lens, batch_size)

            # Forward pass on the training batch (domain label 1 for ADA).
            if ada_after_rnn or ada_before_rnn:
                preds, ada_cnn, ada_rnn = net(imgs,
                                              img_seq_lens,
                                              ada_alpha=ada_alpha,
                                              mask=mask)

                if ada_before_rnn:
                    ada_num_features = ada_cnn.size(0)
                else:
                    ada_num_features = ada_rnn.size(0)

                domain_label = torch.ones(ada_num_features)
                domain_label = domain_label.long()
                if cuda:
                    domain_label = domain_label.cuda()
                domain_label = Variable(domain_label)

                if ada_before_rnn:
                    err_ada_cnn_s = loss_domain(ada_cnn, domain_label)
                if ada_after_rnn:
                    err_ada_rnn_s = loss_domain(ada_rnn, domain_label)

            else:
                preds = net(imgs, img_seq_lens)

            loss_ctc = loss_function(
                preds, labels_flatten,
                Variable(torch.IntTensor(np.array(img_seq_lens))),
                label_lens) / batch_size

            # Skip batches with a non-finite CTC loss. NaN is skipped as well
            # as +/-inf, because back-propagating it would corrupt the weights.
            if not np.isfinite(float(loss_ctc.data[0])):
                print("warnning: loss should not be inf.")
                continue
            total_loss = loss_ctc

            if do_vat:
                total_loss = total_loss + vat_ratio * vat_loss.cpu()
            if do_test_vat or do_test_vat_rnn or do_test_vat_cnn:
                total_loss = total_loss + test_vat_ratio * test_vat_loss.cpu()
            # NOTE(review): at_loss (do_at) is computed and logged but never
            # added to total_loss, so it does not affect the gradients.
            # Confirm whether a weighted term was intended.

            if ada_before_rnn:
                total_loss = total_loss + ada_ratio * err_ada_cnn_s.cpu(
                ) + ada_ratio * err_ada_cnn_t.cpu()
            if ada_after_rnn:
                total_loss = total_loss + ada_ratio * err_ada_rnn_s.cpu(
                ) + ada_ratio * err_ada_rnn_t.cpu()

            total_loss.backward()
            nn.utils.clip_grad_norm(net.parameters(), 10.0)
            # Outlier losses are kept out of the running means.
            if -400 < loss_ctc.data[0] < 400:
                loss_mean_ctc.append(loss_ctc.data[0])
            if -1000 < total_loss.data[0] < 1000:
                loss_mean_total.append(total_loss.data[0])
            if len(loss_mean_total) > 100:
                loss_mean_total = loss_mean_total[-100:]
            status = "epoch: {0:5d}; iter_num: {1:5d}; lr: {2:.2E}; loss_mean: {3:.3f}; loss: {4:.3f}".format(
                epoch_count, lr_scheduler.last_iter, lr_scheduler.get_lr(),
                np.mean(loss_mean_total), loss_ctc.data[0])
            if ada_after_rnn:
                loss_mean_ada_rnn_s.append(err_ada_rnn_s.data[0])
                loss_mean_ada_rnn_t.append(err_ada_rnn_t.data[0])
                status += "; ladatrnns: {0:.3f}; ladatrnnt: {1:.3f}".format(
                    err_ada_rnn_s.data[0], err_ada_rnn_t.data[0])
            if ada_before_rnn:
                loss_mean_ada_cnn_s.append(err_ada_cnn_s.data[0])
                loss_mean_ada_cnn_t.append(err_ada_cnn_t.data[0])
                status += "; ladatcnns: {0:.3f}; ladatcnnt: {1:.3f}".format(
                    err_ada_cnn_s.data[0], err_ada_cnn_t.data[0])
            if do_vat:
                loss_mean_vat.append(vat_loss.data[0])
                status += "; lvat: {0:.3f}".format(vat_loss.data[0])
            if do_at:
                loss_mean_at.append(at_loss.data[0])
                status += "; lat: {0:.3f}".format(at_loss.data[0])
            if do_test_vat:
                loss_mean_test_vat.append(test_vat_loss.data[0])
                status += "; l_tvat: {0:.3f}".format(test_vat_loss.data[0])
            if do_test_vat_rnn or do_test_vat_cnn:
                loss_mean_test_vat.append(test_vat_loss.data[0])
                if do_test_vat_rnn and do_test_vat_cnn:
                    status += "; l_tvatc: {}".format(cnn_lds.data[0])
                    status += "; l_tvatr: {}".format(rnn_lds.data[0])
                else:
                    status += "; l_tvat: {}".format(test_vat_loss.data[0])

            iterator.set_description(status)
            optimizer.step()
            if do_lr_step:
                lr_scheduler.step()
            if do_ema:
                # NOTE(review): spelled 'udate_ema'; confirm it matches the
                # model's method name.
                net.udate_ema()
            iter_count += 1
        if output_dir is not None:
            torch.save(net.state_dict(),
                       os.path.join(output_dir, "crnn_" + backend + "_last"))
        epoch_count += 1

    return
예제 #20
0
def train(train_dataloader, query_dataloader, retrieval_dataloader, arch, code_length, device, lr,
          max_iter, topk, evaluate_interval, anchor_num, proportion
          ):
    """Train a deep hashing model whose binary-code targets ``B`` are updated by ADMM.

    Each of the ``max_iter`` outer iterations:

    1. Embeds every training sample (no grad, eval mode) and builds a
       similarity matrix ``A`` from pairwise Euclidean distances, min-max
       rescaled to [-1, 1].
    2. Runs one ADMM step: ``Z1`` is ``B + Y1 / rho1`` clipped to [-1, 1],
       ``Z2`` is ``B + Y2 / rho2`` rescaled to Frobenius norm
       ``sqrt(n * code_length)``, the multipliers ``Y1``/``Y2`` are updated,
       and ``B`` is re-solved with L-BFGS-B on ``Baim_func``.
    3. Trains the network for one epoch against ``B`` with ``PrototypicalLoss``.

    Every ``evaluate_interval`` iterations the query/retrieval codes are
    generated, mAP and a PR curve are computed and logged, and the best-mAP
    state is kept as the checkpoint.

    Args:
        train_dataloader: yields ``(data, targets, index)`` batches.
        query_dataloader, retrieval_dataloader: evaluation loaders whose
            datasets provide ``get_onehot_targets()``.
        arch, code_length: passed to ``load_model``; ``code_length`` is also
            the number of columns of ``B``.
        device: torch device used for the model and all tensors.
        lr: initial RMSprop learning rate (cosine-annealed to ``lr / 100``).
        max_iter: number of outer iterations.
        topk: passed through to ``mean_average_precision`` (presumably the
            retrieval cut-off).
        evaluate_interval: evaluate every this many iterations.
        anchor_num, proportion: currently unused; kept for interface
            compatibility.

    Returns:
        The best checkpoint dict (model state, codes, targets, P/R, map), or
        ``None`` if no evaluation produced an mAP above 0.
    """
    # Baim_func (defined elsewhere) is handed only the flattened B by L-BFGS-B,
    # so the rest of the ADMM state reaches it through these module globals.
    # NOTE(review): assumes Baim_func reads exactly these names — confirm.
    global global_A, global_Z1, global_Z2, global_Y1, global_Y2

    rho1 = 1e-2   # ρ1: ADMM penalty paired with Z1 / Y1
    rho2 = 1e-2   # ρ2: ADMM penalty paired with Z2 / Y2
    gamma = 1e-3  # γ: step size of the multiplier (Y1, Y2) updates

    # One row of B / Y1 / Y2 per training sample, so count samples (dim 0),
    # without materialising the whole dataset on the device.
    n = sum(data.size(0) for data, _, _ in train_dataloader)
    Y1 = torch.rand(n, code_length).to(device)
    Y2 = torch.rand(n, code_length).to(device)
    B = torch.rand(n, code_length).to(device)

    # Load model
    model = load_model(arch, code_length).to(device)
    # Create criterion, optimizer, scheduler
    criterion = PrototypicalLoss()
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=5e-4,
    )
    scheduler = CosineAnnealingLR(
        optimizer,
        max_iter,
        lr / 100,
    )

    # Initialization
    running_loss = 0.
    best_map = 0.
    training_time = 0.
    checkpoint = None  # stays None unless an evaluation beats best_map

    for it in range(max_iter):
        tic = time.time()

        # ---- ADMM step on the code matrix B ---------------------------------
        # Eval mode so dropout/batch-norm do not perturb (or update on) the
        # embeddings used to build the similarity matrix.
        # NOTE(review): rows of A follow the loader order; if the loader
        # shuffles, they will not line up with rows of B across iterations.
        model.eval()
        with torch.no_grad():
            features = []
            for data, _, _ in train_dataloader:
                _, output_A = model(data.to(device))
                features.append(output_A)
                torch.cuda.empty_cache()
            output_mo = torch.cat(features, 0)

            dist = euclidean_dist(output_mo, output_mo)
            dist = torch.exp(-1 * dist / torch.max(dist)).to(device)
            # Min-max rescale the similarities into [-1, 1].
            d_min, d_max = torch.min(dist), torch.max(dist)
            A = 2 * (dist - d_min) / (d_max - d_min) - 1
            global_A = A.cpu().numpy()

            # Z1: B + Y1/ρ1 clipped to the box [-1, 1].
            Z1 = torch.clamp(B + Y1 / rho1, -1, 1)
            # Z2: B + Y2/ρ2 rescaled to norm sqrt(n * code_length).
            Z2 = B + Y2 / rho2
            Z2 = (n * code_length) ** 0.5 * Z2 / torch.norm(Z2)
            # Multiplier updates.
            Y1 = Y1 + gamma * rho1 * (B - Z1)
            Y2 = Y2 + gamma * rho2 * (B - Z2)

            global_Z1 = Z1.cpu().numpy()
            global_Z2 = Z2.cpu().numpy()
            global_Y1 = Y1.cpu().numpy()
            global_Y2 = Y2.cpu().numpy()

            # fmin_l_bfgs_b returns (x, f, info) with x flattened to 1-D
            # float64; keep x and restore B's shape, dtype and device.
            # NOTE(review): without fprime / approx_grad=True, Baim_func must
            # return (value, gradient) — confirm.
            B0 = B.cpu().numpy()
            B_opt, _, _ = scipy.optimize.fmin_l_bfgs_b(Baim_func, B0)
            B = torch.from_numpy(B_opt.reshape(B0.shape)).to(device=device, dtype=B.dtype)

        # ---- self-supervised deep learning ----------------------------------
        model.train()
        for data, _targets, _index in train_dataloader:
            data = data.to(device)
            optimizer.zero_grad()

            # output_B: hash-layer output; second output: features before it.
            output_B, _ = model(data)

            # NOTE(review): the batch output is compared with the full B and the
            # batch index is unused; B[index] may have been intended — confirm
            # against PrototypicalLoss.
            loss = criterion(output_B, B)

            running_loss += loss.item()
            loss.backward()
            optimizer.step()

        scheduler.step()
        training_time += time.time() - tic

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length, device)
            retrieval_code = generate_code(model, retrieval_dataloader, code_length, device)

            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets()

            # Compute map
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )

            # Compute pr curve
            P, R = pr_curve(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
            )

            # Log
            logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(
                it + 1,
                max_iter,
                running_loss / evaluate_interval,
                mAP,
                training_time,
            ))
            running_loss = 0.

            # Checkpoint
            if best_map < mAP:
                best_map = mAP

                checkpoint = {
                    'model': model.state_dict(),
                    'qB': query_code.cpu(),
                    'rB': retrieval_code.cpu(),
                    'qL': query_targets.cpu(),
                    'rL': retrieval_targets.cpu(),
                    'P': P,
                    'R': R,
                    'map': best_map,
                }

    return checkpoint
예제 #21
0
def main():
    """Train the text-recognition network configured by the module-level ``config``.

    Builds a ``TextDataset`` (modes "train", "dev", "test", "test_annotated"),
    loads the network via ``load_model`` and then loops forever (``while True``
    with no ``break`` in this body), one epoch per pass:

    * trains with a CTC loss on ``sample["seq"]`` plus an NLL loss on
      ``sample["label"]``, clipping the gradient norm at 10.0;
    * evaluates on "dev" and saves ``crnn_<backend>_best`` when the dev
      ``avg_ed`` improves;
    * evaluates on "test_annotated" and saves ``crnn_<backend>_best_anno`` when
      its ``avg_ed`` beats the running best;
    * steps the ``ReduceLROnPlateau`` scheduler on the dev ``avg_ed``.

    ``avg_ed`` is presumably an average edit distance; lower is treated as
    better (all comparisons use ``<``). Requires a CUDA device: input images
    are moved with ``.cuda()``.
    """
    # config.input_size is an "AxB" string parsed into two ints.
    input_size = [int(x) for x in config.input_size.split('x')]
    # TODO: 1) Use an elastic transform 2) Randomly erase part of the image — both for data augmentation.
    transform = Compose([
        # Rotation(),
        # Resize(size=(input_size[0], input_size[1]), data_augmen=True)
        Resize(size=(input_size[0], input_size[1]))
    ])
    data = TextDataset(data_path=config.data_path,
                       mode="train",
                       transform=transform)
    # Report the size of every split; set_mode presumably selects which
    # split len()/indexing refer to.
    print("Len of train =", len(data))
    data.set_mode("dev")
    print("Len of dev =", len(data))
    data.set_mode("test")
    print("Len of test =", len(data))
    data.set_mode("test_annotated")
    print("Len of test_annotated =", len(data))
    data.set_mode("train")

    net = load_model(input_size, data.get_abc(), None, config.backend,
                     config.snapshot)
    total_params = sum(p.numel() for p in net.parameters())
    train_total_params = sum(p.numel() for p in net.parameters()
                             if p.requires_grad)
    print("# of parameters =", total_params)
    # Parameters with requires_grad=False (frozen).
    print("# of non-training parameters =", total_params - train_total_params)
    print("")
    # Optionally clear files left in <output_dir>/input_images by earlier runs.
    if config.output_image:
        input_img_path = os.path.join(config.output_dir, "input_images")
        file_list = glob.glob(input_img_path + "/*")
        print("Remove the old", input_img_path)
        for file in file_list:
            if os.path.isfile(file):
                os.remove(file)

    optimizer = optim.Adam(net.parameters(), lr=config.base_lr)
    # Halve the LR when the dev metric passed to step() stops improving for
    # 5 epochs (ReduceLROnPlateau's default mode is 'min').
    lr_scheduler = ReduceLROnPlateau(optimizer,
                                     factor=0.5,
                                     patience=5,
                                     verbose=True)
    loss_function = CTCLoss(blank=0)  # index 0 is the CTC blank symbol
    loss_label = nn.NLLLoss()

    dev_avg_ed_best = float("inf")
    # NOTE(review): hard-coded starting best for the annotated split (presumably
    # from an earlier run); "_best_anno" is only written once this is beaten.
    anno_avg_ed_best = 0.1544685954462857
    epoch_count = 0
    print("Start running ...")

    while True:
        # Optional evaluation before the first epoch (currently disabled).
        # if epoch_count == 0:
        #     print("dev phase")
        #     data.set_mode("dev")
        #     acc, dev_avg_ed = test(net, data, data.get_abc(), visualize=True,
        #                            batch_size=config.batch_size, num_workers=config.num_worker)
        #     print("DEV: acc: {}; avg_ed: {}; avg_ed_best: {}".format(acc, dev_avg_ed, dev_avg_ed_best))
        #
        #     data.set_mode("test_annotated")
        #     annotated_acc, annotated_avg_ed = test(net, data, data.get_abc(), visualize=True,
        #                                            batch_size=config.batch_size, num_workers=config.num_worker)
        #     print("ANNOTATED: acc: {}; avg_ed: {}".format(annotated_acc, annotated_avg_ed))

        # ---- one training epoch ----
        net = net.train()
        data.set_mode("train")
        data_loader = DataLoader(data,
                                 batch_size=config.batch_size,
                                 num_workers=config.num_worker,
                                 shuffle=True,
                                 collate_fn=text_collate)
        loss_mean = []
        iterator = tqdm(data_loader)
        for sample in iterator:
            optimizer.zero_grad()
            imgs = Variable(sample["img"])
            labels_ocr = Variable(sample["seq"]).view(-1)  # targets flattened to 1-D
            labels_ocr_len = Variable(sample["seq_len"].int())
            labels = Variable(sample["label"].long())
            # Only the images go to the GPU; both network outputs are moved back
            # to the CPU below, where the losses are computed.
            imgs = imgs.cuda()

            preds, label_logsoftmax = net(imgs)
            preds = preds.cpu()
            label_logsoftmax = label_logsoftmax.cpu()
            # Every sample is given the full output length preds.size(0).
            pred_lens = Variable(
                Tensor([preds.size(0)] * len(labels_ocr_len)).int())

            # CTC needs more output steps than the longest target sequence.
            assert preds.size()[0] > max(labels_ocr_len).item()
            loss = loss_function(preds, labels_ocr, pred_lens,
                                 labels_ocr_len) + loss_label(
                                     label_logsoftmax, labels)

            # Sanity checks (not a unit test): stop on NaN/inf or an exactly-zero loss.
            assert not torch.isnan(loss).any()
            assert not torch.isinf(loss).any()
            assert loss.item() != 0
            loss.backward()
            # Warn about trainable parameters that got no gradient (None or all
            # zeros). Runs on every step.
            for name, para in net.named_parameters():
                if (para.grad is None or para.grad.equal(
                        torch.zeros_like(para.grad))) and para.requires_grad:
                    print("WARNING: There is no grad at", name)

            nn.utils.clip_grad_norm_(net.parameters(), 10.0)
            loss_mean.append(loss.item())
            optimizer.step()

        # ---- dev evaluation; save when dev avg_ed improves ----
        print("dev phase")
        data.set_mode("dev")
        # NOTE(review): dev uses num_workers=0 while test_annotated below uses
        # config.num_worker — confirm this difference is intended.
        acc, dev_avg_ed = test(net,
                               data,
                               data.get_abc(),
                               visualize=True,
                               batch_size=config.batch_size,
                               num_workers=0)

        if dev_avg_ed < dev_avg_ed_best:
            assert config.output_dir is not None
            torch.save(
                net.state_dict(),
                os.path.join(config.output_dir,
                             "crnn_" + config.backend + "_best"))
            print(
                "Saving best model to",
                os.path.join(config.output_dir,
                             "crnn_" + config.backend + "_best"))
            dev_avg_ed_best = dev_avg_ed

        # TODO: print avg_ed & acc in train epoch
        print("train: epoch: {}; loss_mean: {}".format(epoch_count,
                                                       np.mean(loss_mean)))
        print("dev: acc: {}; avg_ed: {}; avg_ed_best: {}".format(
            acc, dev_avg_ed, dev_avg_ed_best))

        # ---- annotated-test evaluation; save when its avg_ed improves ----
        data.set_mode("test_annotated")
        annotated_acc, annotated_avg_ed = test(net,
                                               data,
                                               data.get_abc(),
                                               visualize=True,
                                               batch_size=config.batch_size,
                                               num_workers=config.num_worker)
        if annotated_avg_ed < anno_avg_ed_best:
            assert config.output_dir is not None
            torch.save(
                net.state_dict(),
                os.path.join(config.output_dir,
                             "crnn_" + config.backend + "_best_anno"))
            print(
                "Saving best model to",
                os.path.join(config.output_dir,
                             "crnn_" + config.backend + "_best_anno"))
            anno_avg_ed_best = annotated_avg_ed
        print("ANNOTATED: acc: {}; avg_ed: {}, best: {}".format(
            annotated_acc, annotated_avg_ed, anno_avg_ed_best))

        # TODO: add tensorboard to visualize loss_mean & avg_ed & acc
        # The LR schedule is driven by the dev avg_ed (lower is better).
        lr_scheduler.step(dev_avg_ed)
        epoch_count += 1