Exemple #1
0
def train(config):
    """Run one wrapped epoch of start-of-line (SOL) pretraining.

    The data pipeline is built from the module-level ``pretrain_config``;
    the ``config`` argument is not read by this function.  The network
    (``solf``), ``optimizer``, ``dtype`` and the two alignment alphas are
    likewise taken from the enclosing module scope.

    For every batch the alignment loss is computed and back-propagated, and a
    visualisation of the predictions drawn on the first image of the batch is
    written to ``snapshots/sol_train/<step>.png``.

    Returns:
        The summed loss divided by the number of processed batches.
    """
    file_list = load_file_list(pretrain_config['training_set'])
    sol_cfg = pretrain_config['sol']
    dataset = SolDataset(
        file_list,
        rescale_range=sol_cfg['training_rescale_range'],
        transform=CropTransform(sol_cfg['crop_params']))

    loader = DataLoader(dataset,
                        batch_size=sol_cfg['batch_size'],
                        shuffle=True,
                        num_workers=0,
                        collate_fn=sol.sol_dataset.collate)

    # Cap the number of batches consumed per call at a fixed "epoch" size.
    batches_per_epoch = int(sol_cfg['images_per_epoch'] /
                            sol_cfg['batch_size'])
    loader = DatasetWrapper(loader, batches_per_epoch)

    if not os.path.exists("snapshots/sol_train"):
        os.makedirs("snapshots/sol_train")

    solf.train()
    total_loss = 0.0
    num_steps = 0.0

    for batch_idx, batch in enumerate(loader):
        img = Variable(batch['img'].type(dtype), requires_grad=False)

        # A batch without any ground-truth positions carries sol_gt=None;
        # it is forwarded unchanged so alignment_loss can deal with it.
        raw_gt = batch['sol_gt']
        if raw_gt is None:
            sol_gt = None
        else:
            sol_gt = Variable(raw_gt.type(dtype), requires_grad=False)

        predictions = solf(img)
        loss = alignment_loss(predictions, sol_gt, batch['label_sizes'],
                              alpha_alignment, alpha_backprop)

        # Turn the first image of the batch back into a uint8 array and dump
        # the predictions drawn on top of it for visual inspection.
        canvas = img[0].data.cpu().numpy().transpose([2, 1, 0])
        canvas = ((canvas + 1) * 128).astype(np.uint8).copy()
        canvas = drawing.draw_sol_torch(predictions, canvas)
        cv2.imwrite("snapshots/sol_train/{}.png".format(batch_idx), canvas)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.data.cpu().numpy()
        num_steps += 1

        # Drop the references before forcing a collection pass.
        predictions = None
        loss = None
        gc.collect()

    return total_loss / num_steps
Exemple #2
0
def training_step(config):

    train_config = config['training']

    allowed_training_time = train_config['sol']['reset_interval']
    init_training_time = time.time()

    training_set_list = load_file_list(train_config['training_set'])
    train_dataset = SolDataset(
        training_set_list,
        rescale_range=train_config['sol']['training_rescale_range'],
        transform=CropTransform(train_config['sol']['crop_params']))

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=train_config['sol']['batch_size'],
                                  shuffle=True,
                                  num_workers=0,
                                  collate_fn=sol_dataset.collate)

    batches_per_epoch = int(train_config['sol']['images_per_epoch'] /
                            train_config['sol']['batch_size'])
    train_dataloader = DatasetWrapper(train_dataloader, batches_per_epoch)

    test_set_list = load_file_list(train_config['validation_set'])
    test_dataset = SolDataset(
        test_set_list,
        rescale_range=train_config['sol']['validation_rescale_range'],
        random_subset_size=train_config['sol']['validation_subset_size'],
        transform=None)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=sol_dataset.collate)

    alpha_alignment = train_config['sol']['alpha_alignment']
    alpha_backprop = train_config['sol']['alpha_backprop']

    sol, lf, hw = init_model(config, only_load='sol')

    dtype = torch.cuda.FloatTensor

    lowest_loss = np.inf
    lowest_loss_i = 0
    epoch = -1
    while True:  #This ends on a break based on the current itme
        epoch += 1
        print "Train Time:", (
            time.time() -
            init_training_time), "Allowed Time:", allowed_training_time

        sol.eval()
        sum_loss = 0.0
        steps = 0.0
        start_time = time.time()
        for step_i, x in enumerate(test_dataloader):
            img = Variable(x['img'].type(dtype), requires_grad=False)

            sol_gt = None
            if x['sol_gt'] is not None:
                sol_gt = Variable(x['sol_gt'].type(dtype), requires_grad=False)

            predictions = sol(img)
            predictions = transformation_utils.pt_xyrs_2_xyxy(predictions)
            loss = alignment_loss(predictions, sol_gt, x['label_sizes'],
                                  alpha_alignment, alpha_backprop)
            sum_loss += loss.data[0]
            steps += 1

        if epoch == 0:
            print "First Validation Step Complete"
            print "Benchmark Validation CER:", sum_loss / steps
            lowest_loss = sum_loss / steps

            sol, lf, hw = init_model(config,
                                     sol_dir='current',
                                     only_load='sol')

            optimizer = torch.optim.Adam(
                sol.parameters(), lr=train_config['sol']['learning_rate'])
            optim_path = os.path.join(train_config['snapshot']['current'],
                                      "sol_optim.pt")
            if os.path.exists(optim_path):
                print "Loading Optim Settings"
                optimizer.load_state_dict(safe_load.torch_state(optim_path))
            else:
                print "Failed to load Optim Settings"

        elif lowest_loss > sum_loss / steps:
            lowest_loss = sum_loss / steps
            print "Saving Best"

            dirname = train_config['snapshot']['best_validation']
            if not len(dirname) != 0 and os.path.exists(dirname):
                os.makedirs(dirname)

            save_path = os.path.join(dirname, "sol.pt")

            torch.save(sol.state_dict(), save_path)
            lowest_loss_i = epoch

        print "Test Loss", sum_loss / steps, lowest_loss
        print "Time:", time.time() - start_time
        print ""

        print "Epoch", epoch

        if allowed_training_time < (time.time() - init_training_time):
            print "Out of time. Saving current state and exiting..."
            dirname = train_config['snapshot']['current']
            if not len(dirname) != 0 and os.path.exists(dirname):
                os.makedirs(dirname)

            save_path = os.path.join(dirname, "sol.pt")
            torch.save(sol.state_dict(), save_path)

            optim_path = os.path.join(dirname, "sol_optim.pt")
            torch.save(optimizer.state_dict(), optim_path)
            break

        sol.train()
        sum_loss = 0.0
        steps = 0.0
        start_time = time.time()
        for step_i, x in enumerate(train_dataloader):
            img = Variable(x['img'].type(dtype), requires_grad=False)

            sol_gt = None
            if x['sol_gt'] is not None:
                sol_gt = Variable(x['sol_gt'].type(dtype), requires_grad=False)

            predictions = sol(img)
            predictions = transformation_utils.pt_xyrs_2_xyxy(predictions)
            loss = alignment_loss(predictions, sol_gt, x['label_sizes'],
                                  alpha_alignment, alpha_backprop)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sum_loss += loss.data[0]
            steps += 1

        print "Train Loss", sum_loss / steps
        print "Real Epoch", train_dataloader.epoch
        print "Time:", time.time() - start_time
Exemple #3
0
# Data pipeline and network construction for start-of-line (SOL) pretraining.
# All settings come from ``pretrain_config`` / ``sol_network_config``.

# --- Training data: cropped samples, shuffled ---
training_set_list = load_file_list(pretrain_config['training_set'])
sol_pretrain_cfg = pretrain_config['sol']
train_dataset = SolDataset(
    training_set_list,
    rescale_range=sol_pretrain_cfg['training_rescale_range'],
    transform=CropTransform(sol_pretrain_cfg['crop_params']))

train_dataloader = DataLoader(
    train_dataset,
    batch_size=sol_pretrain_cfg['batch_size'],
    shuffle=True,
    num_workers=0,
    collate_fn=sol.sol_dataset.collate)

# One "epoch" is a fixed number of batches, enforced by the wrapper.
batches_per_epoch = int(sol_pretrain_cfg['images_per_epoch'] /
                        sol_pretrain_cfg['batch_size'])
train_dataloader = DatasetWrapper(train_dataloader, batches_per_epoch)

# --- Validation data: no transform, one image per batch, fixed order ---
test_set_list = load_file_list(pretrain_config['validation_set'])
test_dataset = SolDataset(
    test_set_list,
    rescale_range=sol_pretrain_cfg['validation_rescale_range'],
    transform=None)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=sol.sol_dataset.collate)

# --- Network ---
# NOTE: from here on ``sol`` is rebound from whatever provided
# ``sol.sol_dataset`` above to the network instance.
base0 = sol_network_config['base0']
base1 = sol_network_config['base1']
sol = StartOfLineFinder(base0, base1)
Exemple #4
0
def training_step(config):

    hw_network_config = config['network']['hw']
    train_config = config['training']

    allowed_training_time = train_config['hw']['reset_interval']
    init_training_time = time.time()

    char_set_path = hw_network_config['char_set_path']

    with open(char_set_path) as f:
        char_set = json.load(f)

    idx_to_char = {}
    for k, v in char_set['idx_to_char'].iteritems():
        idx_to_char[int(k)] = v

    training_set_list = load_file_list(train_config['training_set'])
    train_dataset = HwDataset(training_set_list,
                              char_set['char_to_idx'],
                              augmentation=True,
                              img_height=hw_network_config['input_height'])

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=train_config['hw']['batch_size'],
                                  shuffle=False,
                                  num_workers=0,
                                  collate_fn=hw_dataset.collate)

    batches_per_epoch = int(train_config['hw']['images_per_epoch'] /
                            train_config['hw']['batch_size'])
    train_dataloader = DatasetWrapper(train_dataloader, batches_per_epoch)

    test_set_list = load_file_list(train_config['validation_set'])
    test_dataset = HwDataset(
        test_set_list,
        char_set['char_to_idx'],
        img_height=hw_network_config['input_height'],
        random_subset_size=train_config['hw']['validation_subset_size'])

    test_dataloader = DataLoader(test_dataset,
                                 batch_size=train_config['hw']['batch_size'],
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=hw_dataset.collate)

    hw = cnn_lstm.create_model(hw_network_config)
    hw_path = os.path.join(train_config['snapshot']['best_validation'],
                           "hw.pt")
    hw_state = safe_load.torch_state(hw_path)
    hw.load_state_dict(hw_state)
    hw.cuda()
    criterion = CTCLoss()
    dtype = torch.cuda.FloatTensor

    lowest_loss = np.inf
    lowest_loss_i = 0
    for epoch in xrange(10000000000):
        sum_loss = 0.0
        steps = 0.0
        hw.eval()
        for x in test_dataloader:
            sys.stdout.flush()
            line_imgs = Variable(x['line_imgs'].type(dtype),
                                 requires_grad=False,
                                 volatile=True)
            labels = Variable(x['labels'], requires_grad=False, volatile=True)
            label_lengths = Variable(x['label_lengths'],
                                     requires_grad=False,
                                     volatile=True)

            preds = hw(line_imgs).cpu()

            output_batch = preds.permute(1, 0, 2)
            out = output_batch.data.cpu().numpy()

            for i, gt_line in enumerate(x['gt']):
                logits = out[i, ...]
                pred, raw_pred = string_utils.naive_decode(logits)
                pred_str = string_utils.label2str_single(
                    pred, idx_to_char, False)
                cer = error_rates.cer(gt_line, pred_str)
                sum_loss += cer
                steps += 1

        if epoch == 0:
            print "First Validation Step Complete"
            print "Benchmark Validation CER:", sum_loss / steps
            lowest_loss = sum_loss / steps

            hw = cnn_lstm.create_model(hw_network_config)
            hw_path = os.path.join(train_config['snapshot']['current'],
                                   "hw.pt")
            hw_state = safe_load.torch_state(hw_path)
            hw.load_state_dict(hw_state)
            hw.cuda()

            optimizer = torch.optim.Adam(
                hw.parameters(), lr=train_config['hw']['learning_rate'])
            optim_path = os.path.join(train_config['snapshot']['current'],
                                      "hw_optim.pt")
            if os.path.exists(optim_path):
                print "Loading Optim Settings"
                optimizer.load_state_dict(safe_load.torch_state(optim_path))
            else:
                print "Failed to load Optim Settings"

        if lowest_loss > sum_loss / steps:
            lowest_loss = sum_loss / steps
            print "Saving Best"

            dirname = train_config['snapshot']['best_validation']
            if not len(dirname) != 0 and os.path.exists(dirname):
                os.makedirs(dirname)

            save_path = os.path.join(dirname, "hw.pt")

            torch.save(hw.state_dict(), save_path)
            lowest_loss_i = epoch

        print "Test Loss", sum_loss / steps, lowest_loss
        print ""

        if allowed_training_time < (time.time() - init_training_time):
            print "Out of time: Exiting..."
            break

        print "Epoch", epoch
        sum_loss = 0.0
        steps = 0.0
        hw.train()
        for i, x in enumerate(train_dataloader):

            line_imgs = Variable(x['line_imgs'].type(dtype),
                                 requires_grad=False)
            labels = Variable(x['labels'], requires_grad=False)
            label_lengths = Variable(x['label_lengths'], requires_grad=False)

            preds = hw(line_imgs).cpu()

            output_batch = preds.permute(1, 0, 2)
            out = output_batch.data.cpu().numpy()

            # if i == 0:
            #     for i in xrange(out.shape[0]):
            #         pred, pred_raw = string_utils.naive_decode(out[i,...])
            #         pred_str = string_utils.label2str_single(pred_raw, idx_to_char, True)
            #         print pred_str

            for i, gt_line in enumerate(x['gt']):
                logits = out[i, ...]
                pred, raw_pred = string_utils.naive_decode(logits)
                pred_str = string_utils.label2str_single(
                    pred, idx_to_char, False)
                cer = error_rates.cer(gt_line, pred_str)
                sum_loss += cer
                steps += 1

            batch_size = preds.size(1)
            preds_size = Variable(torch.IntTensor([preds.size(0)] *
                                                  batch_size))

            loss = criterion(preds, labels, preds_size, label_lengths)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print "Train Loss", sum_loss / steps
        print "Real Epoch", train_dataloader.epoch

    ## Save current snapshots for next iteration
    print "Saving Current"
    dirname = train_config['snapshot']['current']
    if not len(dirname) != 0 and os.path.exists(dirname):
        os.makedirs(dirname)

    save_path = os.path.join(dirname, "hw.pt")
    torch.save(hw.state_dict(), save_path)

    optim_path = os.path.join(dirname, "hw_optim.pt")
    torch.save(optimizer.state_dict(), optim_path)
Exemple #5
0
def training_step(config):

    char_set_path = config['network']['hw']['char_set_path']

    with open(char_set_path) as f:
        char_set = json.load(f)

    idx_to_char = {}
    for k, v in char_set['idx_to_char'].iteritems():
        idx_to_char[int(k)] = v

    train_config = config['training']

    allowed_training_time = train_config['lf']['reset_interval']
    init_training_time = time.time()

    training_set_list = load_file_list(train_config['training_set'])
    train_dataset = LfDataset(training_set_list, augmentation=True)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=1,
                                  shuffle=True,
                                  num_workers=0,
                                  collate_fn=lf_dataset.collate)
    batches_per_epoch = int(train_config['lf']['images_per_epoch'] /
                            train_config['lf']['batch_size'])
    train_dataloader = DatasetWrapper(train_dataloader, batches_per_epoch)

    test_set_list = load_file_list(train_config['validation_set'])
    test_dataset = LfDataset(
        test_set_list,
        random_subset_size=train_config['lf']['validation_subset_size'])
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=lf_dataset.collate)

    _, lf, hw = init_model(config, only_load=['lf', 'hw'])
    hw.eval()

    dtype = torch.cuda.FloatTensor

    lowest_loss = np.inf
    lowest_loss_i = 0
    for epoch in xrange(10000000):
        lf.eval()
        sum_loss = 0.0
        steps = 0.0
        start_time = time.time()
        for step_i, x in enumerate(test_dataloader):
            if x is None:
                continue
            #Only single batch for now
            x = x[0]
            if x is None:
                continue

            positions = [
                Variable(x_i.type(dtype), requires_grad=False)[None, ...]
                for x_i in x['lf_xyrs']
            ]
            xy_positions = [
                Variable(x_i.type(dtype), requires_grad=False)[None, ...]
                for x_i in x['lf_xyxy']
            ]
            img = Variable(x['img'].type(dtype), requires_grad=False)[None,
                                                                      ...]

            #There might be a way to handle this case later,
            #but for now we will skip it
            if len(xy_positions) <= 1:
                print "Skipping"
                continue

            grid_line, _, _, xy_output = lf(img,
                                            positions[:1],
                                            steps=len(positions),
                                            skip_grid=False)

            line = torch.nn.functional.grid_sample(img.transpose(2, 3),
                                                   grid_line)
            line = line.transpose(2, 3)
            predictions = hw(line)

            out = predictions.permute(1, 0, 2).data.cpu().numpy()
            gt_line = x['gt']
            pred, raw_pred = string_utils.naive_decode(out[0])
            pred_str = string_utils.label2str_single(pred, idx_to_char, False)
            cer = error_rates.cer(gt_line, pred_str)
            sum_loss += cer
            steps += 1

            # l = line[0].transpose(0,1).transpose(1,2)
            # l = (l + 1)*128
            # l_np = l.data.cpu().numpy()
            #
            # cv2.imwrite("example_line_out.png", l_np)
            # print "Saved!"
            # raw_input()

            # loss = lf_loss.point_loss(xy_output, xy_positions)
            #
            # sum_loss += loss.data[0]
            # steps += 1

        if epoch == 0:
            print "First Validation Step Complete"
            print "Benchmark Validation Loss:", sum_loss / steps
            lowest_loss = sum_loss / steps

            _, lf, _ = init_model(config, lf_dir='current', only_load="lf")

            optimizer = torch.optim.Adam(
                lf.parameters(), lr=train_config['lf']['learning_rate'])
            optim_path = os.path.join(train_config['snapshot']['current'],
                                      "lf_optim.pt")
            if os.path.exists(optim_path):
                print "Loading Optim Settings"
                optimizer.load_state_dict(safe_load.torch_state(optim_path))
            else:
                print "Failed to load Optim Settings"

        if lowest_loss > sum_loss / steps:
            lowest_loss = sum_loss / steps
            print "Saving Best"

            dirname = train_config['snapshot']['best_validation']
            if not len(dirname) != 0 and os.path.exists(dirname):
                os.makedirs(dirname)

            save_path = os.path.join(dirname, "lf.pt")

            torch.save(lf.state_dict(), save_path)
            lowest_loss_i = 0

        test_loss = sum_loss / steps

        print "Test Loss", sum_loss / steps, lowest_loss
        print "Time:", time.time() - start_time
        print ""

        if allowed_training_time < (time.time() - init_training_time):
            print "Out of time: Exiting..."
            break

        print "Epoch", epoch
        sum_loss = 0.0
        steps = 0.0
        lf.train()
        start_time = time.time()
        for x in train_dataloader:
            if x is None:
                continue
            #Only single batch for now
            x = x[0]
            if x is None:
                continue

            positions = [
                Variable(x_i.type(dtype), requires_grad=False)[None, ...]
                for x_i in x['lf_xyrs']
            ]
            xy_positions = [
                Variable(x_i.type(dtype), requires_grad=False)[None, ...]
                for x_i in x['lf_xyxy']
            ]
            img = Variable(x['img'].type(dtype), requires_grad=False)[None,
                                                                      ...]

            #There might be a way to handle this case later,
            #but for now we will skip it
            if len(xy_positions) <= 1:
                continue

            reset_interval = 4
            grid_line, _, _, xy_output = lf(img,
                                            positions[:1],
                                            steps=len(positions),
                                            all_positions=positions,
                                            reset_interval=reset_interval,
                                            randomize=True,
                                            skip_grid=True)

            loss = lf_loss.point_loss(xy_output, xy_positions)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sum_loss += loss.data.item()
            steps += 1

        print "Train Loss", sum_loss / steps
        print "Real Epoch", train_dataloader.epoch
        print "Time:", time.time() - start_time

    ## Save current snapshots for next iteration
    print "Saving Current"
    dirname = train_config['snapshot']['current']
    if not len(dirname) != 0 and os.path.exists(dirname):
        os.makedirs(dirname)

    save_path = os.path.join(dirname, "lf.pt")
    torch.save(lf.state_dict(), save_path)

    optim_path = os.path.join(dirname, "lf_optim.pt")
    torch.save(optimizer.state_dict(), optim_path)