def _flatten_legacy_lf_state(lf_state):
    """Flatten a legacy line-follower snapshot into a plain state dict.

    The legacy layout stores the 'cnn' weights as a nested mapping and
    'position_linear' as an object exposing ``state_dict()``.  Both are
    re-keyed as '<top-level key>.<nested key>'.  Entries stored under any
    other top-level key are not carried over.
    """
    new_state = {}
    # .items() rather than .iteritems(): the latter exists on Python 2 only.
    for k, v in lf_state.items():
        if k == 'cnn':
            nested = v
        elif k == 'position_linear':
            nested = v.state_dict()
        else:
            continue
        for k2, v2 in nested.items():
            new_state[k + "." + k2] = v2
    return new_state


def init_model(config,
               sol_dir='best_validation',
               lf_dir='best_validation',
               hw_dir='best_validation',
               only_load=None):
    """Build the SOL, LF and HW networks and load their snapshot weights.

    Args:
        config: nested mapping.  ``config['network']['sol']['base0']`` and
            ``['base1']`` are always read; ``config['network']['hw']`` and
            ``config['training']['snapshot']`` are read only for the
            networks that are actually loaded.
        sol_dir, lf_dir, hw_dir: keys into ``config['training']['snapshot']``
            selecting the directory that holds "sol.pt", "lf.pt" and "hw.pt"
            respectively.
        only_load: ``None`` to load all three networks; otherwise a name
            ('sol', 'lf', 'hw') or a container of names.  A network is
            loaded when its name equals ``only_load`` or is ``in`` it.

    Returns:
        Tuple ``(sol, lf, hw)``.  Networks that were not requested are
        ``None``; loaded networks have had ``.cuda()`` called on them.
    """
    base_0 = config['network']['sol']['base0']
    base_1 = config['network']['sol']['base1']

    def wanted(name):
        # Same predicate the three sections share: everything by default,
        # otherwise exact match or membership.
        return only_load is None or only_load == name or name in only_load

    def load_state(snapshot_key, filename):
        # Looked up lazily so configs without a 'training' section still
        # work when nothing is loaded.
        return safe_load.torch_state(
            os.path.join(config['training']['snapshot'][snapshot_key],
                         filename))

    sol = None
    lf = None
    hw = None

    if wanted('sol'):
        sol = StartOfLineFinder(base_0, base_1)
        sol.load_state_dict(load_state(sol_dir, "sol.pt"))
        sol.cuda()

    if wanted('lf'):
        lf = LineFollower(config['network']['hw']['input_height'])
        lf_state = load_state(lf_dir, "lf.pt")

        # Backward support for the previous way of saving the LF weights,
        # recognised by a top-level 'cnn' entry.
        if 'cnn' in lf_state:
            lf_state = _flatten_legacy_lf_state(lf_state)

        lf.load_state_dict(lf_state)
        lf.cuda()

    if wanted('hw'):
        hw = cnn_lstm.create_model(config['network']['hw'])
        hw.load_state_dict(load_state(hw_dir, "hw.pt"))
        hw.cuda()

    return sol, lf, hw
# Example #2
# 0
    # Build the test dataset from SyntheticDataset, configured with the
    # character-to-index table, the HW network's input height, the synth
    # parameter file and the generator aux path.
    # dataset_size is fixed at 2400 (presumably the number of samples --
    # confirm against SyntheticDataset).
    # NOTE(review): this statement is indented, but the statement that opens
    # its enclosing block is not visible in this excerpt.
    test_dataset = SyntheticDataset(
        char_set['char_to_idx'],
        img_height=hw_network_config['input_height'],
        param_file=pretrain_config['synth_params'],
        generator_aux_path=generator_aux_path,
        dataset_size=2400)

# Loader over the test dataset: shuffling disabled, `threads` worker
# processes, batches assembled by hw_dataset.collate.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=pretrain_config['hw']['batch_size'],
    shuffle=False,
    num_workers=threads,
    collate_fn=hw_dataset.collate,
)

criterion = CTCLoss()

hw = cnn_lstm.create_model(hw_network_config)

# Snapshot file is "<snapshot_prefix>_latest.pt" inside snapshot_path.
load_path = os.path.join(pretrain_config['snapshot_path'],
                         '{}_latest.pt'.format(
                             pretrain_config['snapshot_prefix']))

if not os.path.isfile(load_path):
    # No snapshot on disk: start from scratch.
    lowest_cer = np.inf
    log = []
    start_epoch = 0
else:
    # Resume: restore the weights and bookkeeping, and continue with the
    # epoch after the one recorded in the snapshot.
    loaded = torch.load(load_path)

    hw.load_state_dict(loaded['state_dict'])
    lowest_cer = loaded['lowest_cer']
    log = loaded['log']
    start_epoch = loaded['epoch'] + 1
    print('Loaded at epoch {}'.format(start_epoch))
# Example #3
# 0
def _ensure_dir(dirname):
    """Create *dirname* (with parents) unless it is empty or already exists."""
    if dirname and not os.path.exists(dirname):
        os.makedirs(dirname)


def _line_cers(preds, gt_lines, idx_to_char):
    """Yield the character error rate of each line in one batch.

    ``preds`` is the network output for the batch; it is permuted with
    ``(1, 0, 2)`` so that the first axis indexes the lines of ``gt_lines``,
    then each line is decoded with ``string_utils.naive_decode`` and scored
    against its ground-truth string with ``error_rates.cer``.
    """
    out = preds.permute(1, 0, 2).data.cpu().numpy()
    for line_idx, gt_line in enumerate(gt_lines):
        pred, _ = string_utils.naive_decode(out[line_idx, ...])
        pred_str = string_utils.label2str_single(pred, idx_to_char, False)
        yield error_rates.cer(gt_line, pred_str)


def training_step(config):
    """Train the HW network until the configured time budget is used up.

    Every iteration of the outer loop first measures the mean CER on the
    validation loader and then trains for one epoch:

    * Iteration 0 validates the ``best_validation`` snapshot to obtain the
      benchmark CER, then swaps in the ``current`` snapshot (and its
      optimizer state, when "hw_optim.pt" exists) for training.
    * Whenever the validation CER drops below the best value seen so far,
      the weights are saved to ``best_validation``/hw.pt.
    * The loop stops, right after a validation pass, once more than
      ``training.hw.reset_interval`` seconds have elapsed since the call
      started.

    On exit the weights and optimizer state are written to the ``current``
    snapshot directory as "hw.pt" and "hw_optim.pt".

    Args:
        config: nested mapping; reads ``config['network']['hw']``
            (``char_set_path``, ``input_height``, plus whatever
            ``cnn_lstm.create_model`` needs) and ``config['training']``
            (``training_set``, ``validation_set``, ``snapshot`` and the
            ``hw`` sub-section).

    Raises:
        ZeroDivisionError: if a validation or training pass scores no lines.
    """
    # Local import so the function does not depend on the module header.
    import itertools

    hw_network_config = config['network']['hw']
    train_config = config['training']

    allowed_training_time = train_config['hw']['reset_interval']
    init_training_time = time.time()

    with open(hw_network_config['char_set_path']) as f:
        char_set = json.load(f)

    # JSON object keys are strings; re-key the table by integer index.
    idx_to_char = {int(k): v for k, v in char_set['idx_to_char'].items()}

    training_set_list = load_file_list(train_config['training_set'])
    train_dataset = HwDataset(training_set_list,
                              char_set['char_to_idx'],
                              augmentation=True,
                              img_height=hw_network_config['input_height'])

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=train_config['hw']['batch_size'],
                                  shuffle=False,
                                  num_workers=0,
                                  collate_fn=hw_dataset.collate)

    batches_per_epoch = int(train_config['hw']['images_per_epoch'] /
                            train_config['hw']['batch_size'])
    train_dataloader = DatasetWrapper(train_dataloader, batches_per_epoch)

    test_set_list = load_file_list(train_config['validation_set'])
    test_dataset = HwDataset(
        test_set_list,
        char_set['char_to_idx'],
        img_height=hw_network_config['input_height'],
        random_subset_size=train_config['hw']['validation_subset_size'])

    test_dataloader = DataLoader(test_dataset,
                                 batch_size=train_config['hw']['batch_size'],
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=hw_dataset.collate)

    # Start from the best snapshot: the first validation pass benchmarks it.
    hw = cnn_lstm.create_model(hw_network_config)
    hw_path = os.path.join(train_config['snapshot']['best_validation'],
                           "hw.pt")
    hw_state = safe_load.torch_state(hw_path)
    hw.load_state_dict(hw_state)
    hw.cuda()
    criterion = CTCLoss()
    dtype = torch.cuda.FloatTensor

    lowest_loss = np.inf
    # Unbounded counter: the loop only ends through the time-budget break.
    # (xrange(10000000000) overflows on builds with a 32-bit C long and
    # does not exist on Python 3.)
    for epoch in itertools.count():
        # ---- validation pass ----
        sum_loss = 0.0
        steps = 0.0
        hw.eval()
        for x in test_dataloader:
            sys.stdout.flush()
            line_imgs = Variable(x['line_imgs'].type(dtype),
                                 requires_grad=False,
                                 volatile=True)

            preds = hw(line_imgs).cpu()

            for cer in _line_cers(preds, x['gt'], idx_to_char):
                sum_loss += cer
                steps += 1

        validation_cer = sum_loss / steps

        if epoch == 0:
            print("First Validation Step Complete")
            print("Benchmark Validation CER: %s" % (validation_cer,))
            lowest_loss = validation_cer

            # The benchmark is done; training continues from the
            # 'current' snapshot rather than the best one.
            hw = cnn_lstm.create_model(hw_network_config)
            hw_path = os.path.join(train_config['snapshot']['current'],
                                   "hw.pt")
            hw_state = safe_load.torch_state(hw_path)
            hw.load_state_dict(hw_state)
            hw.cuda()

            optimizer = torch.optim.Adam(
                hw.parameters(), lr=train_config['hw']['learning_rate'])
            optim_path = os.path.join(train_config['snapshot']['current'],
                                      "hw_optim.pt")
            if os.path.exists(optim_path):
                print("Loading Optim Settings")
                optimizer.load_state_dict(safe_load.torch_state(optim_path))
            else:
                print("Failed to load Optim Settings")

        if lowest_loss > validation_cer:
            lowest_loss = validation_cer
            print("Saving Best")

            dirname = train_config['snapshot']['best_validation']
            _ensure_dir(dirname)
            torch.save(hw.state_dict(), os.path.join(dirname, "hw.pt"))

        print("Test Loss %s %s" % (validation_cer, lowest_loss))
        print("")

        if allowed_training_time < (time.time() - init_training_time):
            print("Out of time: Exiting...")
            break

        # ---- training pass ----
        print("Epoch %s" % (epoch,))
        sum_loss = 0.0
        steps = 0.0
        hw.train()
        for x in train_dataloader:
            line_imgs = Variable(x['line_imgs'].type(dtype),
                                 requires_grad=False)
            labels = Variable(x['labels'], requires_grad=False)
            label_lengths = Variable(x['label_lengths'], requires_grad=False)

            preds = hw(line_imgs).cpu()

            # Training CER is tracked for logging only.
            for cer in _line_cers(preds, x['gt'], idx_to_char):
                sum_loss += cer
                steps += 1

            # Every sequence in the batch is given the full output length.
            batch_size = preds.size(1)
            preds_size = Variable(torch.IntTensor([preds.size(0)] *
                                                  batch_size))

            loss = criterion(preds, labels, preds_size, label_lengths)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print("Train Loss %s" % (sum_loss / steps,))
        print("Real Epoch %s" % (train_dataloader.epoch,))

    # Save current snapshots for next iteration
    print("Saving Current")
    dirname = train_config['snapshot']['current']
    _ensure_dir(dirname)

    torch.save(hw.state_dict(), os.path.join(dirname, "hw.pt"))
    torch.save(optimizer.state_dict(), os.path.join(dirname, "hw_optim.pt"))