Beispiel #1
0
def update_alignment(out, gt_lines, alignments, idx_to_char, idx_mapping,
                     sol_positions):
    """Update, in place, the best-matching prediction for each ground-truth line.

    For every decoded prediction ``i`` in ``out`` and every ground-truth line
    ``j`` the alignment error is ``cer + 0.1 * (1 - c)``, where ``cer`` is the
    character error rate of the prediction against ``gt_lines[j]`` and ``c`` is
    read from ``sol_positions[i, 0, -1]``.  Whenever that error is lower than
    ``alignments[j][0]`` the record is overwritten.

    Args:
        out: per-prediction logits, iterated along the first axis after
            ``.data.cpu().numpy()``.
        gt_lines: sequence of ground-truth strings.
        alignments: one mutable ``[error, global_idx, _, pred_str]`` record per
            ground-truth line; mutated in place.
        idx_to_char: mapping from label index to character.
        idx_mapping: maps the local prediction index ``i`` to a global index.
        sol_positions: tensor indexed as ``[i, 0, -1]``; presumably a
            start-of-line confidence in the last channel -- TODO confirm.

    Returns:
        None.
    """
    # With no ground-truth lines nothing can be updated; returning early also
    # keeps the hoisted per-prediction lookups below from running needlessly.
    if len(gt_lines) == 0:
        return

    for i, logits in enumerate(out.data.cpu().numpy()):
        raw_decode, _ = string_utils.naive_decode(logits)
        pred_str = string_utils.label2str_single(raw_decode, idx_to_char,
                                                 False)

        # These depend only on the prediction, not on the ground-truth line,
        # so they are computed once per prediction.
        global_i = idx_mapping[i]
        c = sol_positions[i, 0, -1].data[0]
        confidence_penalty = 0.1 * (1.0 - c)

        for j, gt in enumerate(gt_lines):
            alignment_error = error_rates.cer(gt, pred_str) + confidence_penalty

            if alignment_error < alignments[j][0]:
                alignments[j][0] = alignment_error
                alignments[j][1] = global_i
                # Slot 2 is left empty: the raw logits are not stored.
                alignments[j][2] = None
                alignments[j][3] = pred_str
def getCER(gt, pred, idx_to_char):
    """Decode each prediction column and score it against its ground truth.

    The prediction for ground-truth line ``k`` is ``pred[:, k]``; it is decoded
    with ``string_utils.naive_decode`` and mapped to text with ``idx_to_char``.

    Returns:
        ``(cer, pred_strs)``: two lists aligned with ``gt`` holding the
        character error rate of each line and its decoded string.
    """
    scores = []
    decoded = []
    for column, reference in enumerate(gt):
        labels, _unused_raw = string_utils.naive_decode(pred[:, column])
        text = string_utils.label2str_single(labels, idx_to_char, False)
        scores.append(error_rates.cer(reference, text))
        decoded.append(text)
    return scores, decoded
Beispiel #3
0
def align_to_gt_lines(decoded_hw, gt_lines):
    """Match every ground-truth line to its lowest-CER decoded prediction.

    Builds a cost matrix with one row per decoded prediction and one column
    per ground-truth line, where entry ``[p, g]`` is
    ``error_rates.cer(gt_lines[g], decoded_hw[p])``, then reduces it over the
    prediction axis.

    Returns:
        ``(min_idx, min_val)``: for each ground-truth line, the index of the
        best prediction and that prediction's CER.
    """
    cost_matrix = np.array([
        [error_rates.cer(reference, hypothesis) for reference in gt_lines]
        for hypothesis in decoded_hw
    ])
    best_idx = cost_matrix.argmin(axis=0)
    best_cost = cost_matrix.min(axis=0)
    return best_idx, best_cost
 def getCER(self, gt, pred, individual=False):
     """Return the mean character error rate of ``pred`` against ``gt``.

     Ground-truth line ``gt[i]`` is compared with ``pred[:, i]`` after
     decoding it with ``string_utils.naive_decode`` and ``self.idx_to_char``.

     Args:
         gt: sequence of ground-truth strings, one per line.
         pred: predictions indexed as ``pred[:, i]`` for line ``i``.
         individual: when True, also return the per-line CER values.

     Returns:
         ``(cer, pred_strs)``, or ``(cer, pred_strs, all_cer)`` when
         ``individual`` is True.  ``cer`` is the mean over all lines and is
         0.0 for an empty ``gt``.
     """
     all_cer = []
     pred_strs = []
     for i, gt_line in enumerate(gt):
         raw_labels, _ = string_utils.naive_decode(pred[:, i])
         pred_str = string_utils.label2str_single(raw_labels,
                                                  self.idx_to_char, False)
         all_cer.append(error_rates.cer(gt_line, pred_str))
         pred_strs.append(pred_str)
     # Without this guard an empty ``gt`` would raise ZeroDivisionError.
     cer = sum(all_cer) / len(all_cer) if all_cer else 0.0
     if individual:
         return cer, pred_strs, all_cer
     return cer, pred_strs
Beispiel #5
0
    def run(self, instance):
        """Run one forward pass and compute losses, metrics and decoded strings.

        Args:
            instance: batch dict; ``instance['gt_char']`` holds the ground-truth
                strings and the tensors come from ``self._to_tensor(instance)``.

        Returns:
            ``(losses, log, chars)``:
            ``losses`` has ``'charLoss'`` and, when a ``'valid'`` loss is
            configured, ``'validLoss'``; ``log`` has ``'cer'`` (summed CER over
            valid items divided by their count, 0.0 if there are none) and
            ``'valid_acc'``; ``chars`` holds the decoded string of each batch
            item, or ``None`` where ``outvalid`` is negative.
        """
        image, targetvalid, targetchars = self._to_tensor(instance)
        gt_chars = instance['gt_char']
        outvalid, outchars = self.model(image)
        batch_size = image.size(0)
        losses = {}
        charLoss = self.loss['char'](
            outchars.reshape(batch_size * outchars.size(1), -1),
            targetchars.view(-1),
            *self.loss_params['char'],
            reduction='none')
        # Only use the loss of valid QR images: invalid items are zeroed.
        # NOTE(review): .mean() still averages over the zeroed items too --
        # confirm this dilution is intended.
        losses['charLoss'] = (charLoss.view(batch_size, -1) *
                              targetvalid[:, None]).mean()
        if 'valid' in self.loss:
            losses['validLoss'] = self.loss['valid'](
                outvalid, targetvalid, *self.loss_params['valid'])

        index_to_char = self.data_loader.dataset.index_to_char
        char_indexes = outchars.argmax(dim=2)
        chars = []
        b_cer = 0
        for b in range(batch_size):
            # Index 0 is the null character and is skipped.
            s = ''.join(index_to_char[idx]
                        for idx in char_indexes[b].tolist() if idx > 0)
            chars.append(s)

            if targetvalid[b]:
                # NOTE(review): elsewhere in this file error_rates.cer is called
                # as (gt, pred); here the prediction comes first -- confirm.
                b_cer += cer(s, gt_chars[b])
            if outvalid[b] < 0:
                chars[b] = None

        acc = ((outvalid > 0) == (targetvalid > 0)).float().mean().item()
        n_valid = targetvalid.sum().item()
        # A batch without any valid target would otherwise divide by zero.
        batch_cer = b_cer / n_valid if n_valid else 0.0
        log = {'cer': batch_cer, 'valid_acc': acc}

        return losses, log, chars
Beispiel #6
0
def update_ideal_results(pick, costs, decoded_hw, gt_json):
    """Pick, per ground-truth line, the better of the stored and new prediction.

    For line ``n`` the candidate is ``decoded_hw[pick[n]]`` with cost
    ``costs[n]``.  The stored prediction ``gt_json[n].get('pred', '')`` is
    kept when the candidate is empty or costs more than the stored
    prediction's CER; on a tie the candidate wins.

    Returns:
        ``(most_ideal_pred, improved_idxs)``: the chosen string for each line,
        and a dict mapping every line whose candidate won to ``pick[n]``.
    """
    chosen = []
    improved = {}

    for n in range(len(gt_json)):
        entry = gt_json[n]
        previous = entry.get('pred', '')
        reference = entry['gt']
        candidate = decoded_hw[pick[n]]
        previous_cer = error_rates.cer(reference, previous)

        if costs[n] > previous_cer or not len(candidate):
            chosen.append(previous)
        else:
            chosen.append(candidate)
            improved[n] = pick[n]

    return chosen, improved
Beispiel #7
0
def accumulate_scores(out, out_positions, xy_positions, gt_state, idx_to_char):
    """Record every decoded prediction, its CER and its path on each GT entry.

    For each prediction ``i`` in ``out`` and each dict in ``gt_state``, values
    are appended to the list-valued keys ``'errors'``, ``'pred'``,
    ``'pred_full'``, ``'path'`` and ``'path_xy'`` (created on first use).

    Args:
        out: logits iterated along the first axis after ``.data.cpu().numpy()``.
        out_positions: sequence of tensors; ``o[i]`` becomes prediction ``i``'s
            ``'path'`` entry as a nested list.
        xy_positions: same as ``out_positions`` but stored under ``'path_xy'``.
        gt_state: list of dicts holding a ``'gt'`` string; mutated in place.
        idx_to_char: mapping from label index to character.

    Returns:
        None.
    """
    for i, logits in enumerate(out.data.cpu().numpy()):
        raw_decode, raw_decode_full = string_utils.naive_decode(logits)
        pred_str = string_utils.label2str_single(raw_decode, idx_to_char,
                                                 False)
        pred_str_full = string_utils.label2str_single(raw_decode_full,
                                                      idx_to_char, True)

        sub_out_positions = [
            o[i].data.cpu().numpy().tolist() for o in out_positions
        ]
        sub_xy_positions = [
            o[i].data.cpu().numpy().tolist() for o in xy_positions
        ]

        for gt_obj in gt_state:
            cer = error_rates.cer(gt_obj['gt'], pred_str)

            # setdefault creates each list the first time it is needed.
            gt_obj.setdefault('errors', []).append(cer)
            gt_obj.setdefault('pred', []).append(pred_str)
            gt_obj.setdefault('pred_full', []).append(pred_str_full)
            gt_obj.setdefault('path', []).append(sub_out_positions)
            gt_obj.setdefault('path_xy', []).append(sub_xy_positions)
def forward_pass(x, e2e, config, thresholds, idx_to_char, update_json=False):
    """Run the end-to-end model on ``x`` and score a post-processing grid.

    Args:
        x: sample dict with at least ``'gt_lines'`` and ``'gt_json'``.
        e2e: model callable; a ``None`` return means the sample is skipped.
        config: not used here (only the disabled JSON-update code needed it).
        thresholds: indexable as ``[0]`` SOL thresholds, ``[1]`` LF NMS
            overlap ranges and ``[2]`` LF NMS overlap thresholds.
        idx_to_char: mapping from label index to character.
        update_json: currently unused (the JSON update is disabled).

    Returns:
        ``None`` when ``e2e(x)`` returns None, otherwise
        ``(results, ideal_result, most_ideal_result)``.
        ``results`` maps each ``(i, j, k)`` threshold-index triple to the CER
        of the post-processed output.  ``ideal_result`` is the CER of the
        aligned predictions, and ``most_ideal_result`` is the CER of the best
        predictions chosen by ``validation_utils.update_ideal_results``.
    """
    gt_lines = x['gt_lines']
    gt = "\n".join(gt_lines)

    out_original = e2e(x)
    if out_original is None:
        # TODO: not a good way to handle this, but fine for now -- the caller
        # should skip this sample.
        return None

    out_original = E2EModel.results_to_numpy(out_original)
    out_original['idx'] = np.arange(out_original['sol'].shape[0])

    decoded_hw, _ = E2EModel.decode_handwriting(out_original, idx_to_char)
    pick, costs = E2EModel.align_to_gt_lines(decoded_hw, gt_lines)

    # Saving the improved alignments back to JSON
    # (validation_utils.save_improved_idxs) is disabled here.
    most_ideal_pred_lines, _ = validation_utils.update_ideal_results(
        pick, costs, decoded_hw, x['gt_json'])

    sol_thresholds = thresholds[0]
    lf_nms_ranges = thresholds[1]
    lf_nms_thresholds = thresholds[2]

    ideal_result = error_rates.cer(gt, "\n".join(decoded_hw[i] for i in pick))
    most_ideal_result = error_rates.cer(gt, "\n".join(most_ideal_pred_lines))

    results = {}
    index_grid = itertools.product(range(len(sol_thresholds)),
                                   range(len(lf_nms_ranges)),
                                   range(len(lf_nms_thresholds)))
    for key in index_grid:
        i, j, k = key

        out = E2EModel.postprocess(
            copy.copy(out_original),
            sol_threshold=sol_thresholds[i],
            lf_nms_params={
                "overlap_range": lf_nms_ranges[j],
                "overlap_threshold": lf_nms_thresholds[k],
            })
        order = E2EModel.read_order(out)
        E2EModel.filter_on_pick(out, order)

        pred = "\n".join(decoded_hw[idx] for idx in out['idx'])
        results[key] = error_rates.cer(gt, pred)

    return results, ideal_result, most_ideal_result
Beispiel #9
0
def alignment_step(config,
                   dataset_lookup=None,
                   model_mode='best_validation',
                   percent_range=None):

    set_list = load_file_list(config['training'][dataset_lookup])

    if percent_range is not None:
        start = int(len(set_list) * percent_range[0])
        end = int(len(set_list) * percent_range[1])
        set_list = set_list[start:end]

    dataset = AlignmentDataset(set_list, None)
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=0,
                            collate_fn=alignment_dataset.collate)

    char_set_path = config['network']['hw']['char_set_path']

    with open(char_set_path) as f:
        char_set = json.load(f)

    idx_to_char = {}
    for k, v in char_set['idx_to_char'].iteritems():
        idx_to_char[int(k)] = v

    sol, lf, hw = init_model(config,
                             sol_dir=model_mode,
                             lf_dir=model_mode,
                             hw_dir=model_mode)

    e2e = E2EModel(sol, lf, hw)
    dtype = torch.cuda.FloatTensor
    e2e.eval()

    post_processing_config = config['training']['alignment'][
        'validation_post_processing']
    sol_thresholds = post_processing_config['sol_thresholds']
    sol_thresholds_idx = range(len(sol_thresholds))

    lf_nms_ranges = post_processing_config['lf_nms_ranges']
    lf_nms_ranges_idx = range(len(lf_nms_ranges))

    lf_nms_thresholds = post_processing_config['lf_nms_thresholds']
    lf_nms_thresholds_idx = range(len(lf_nms_thresholds))

    results = defaultdict(list)
    aligned_results = []
    best_ever_results = []

    prev_time = time.time()
    cnt = 0
    a = 0
    for x in dataloader:
        sys.stdout.flush()
        a += 1

        if a % 100 == 0:
            print a, np.mean(aligned_results)

        x = x[0]
        if x is None:
            print "Skipping alignment because it returned None"
            continue

        img = x['resized_img'].numpy()[0, ...].transpose([2, 1, 0])
        img = ((img + 1) * 128).astype(np.uint8)

        full_img = x['full_img'].numpy()[0, ...].transpose([2, 1, 0])
        full_img = ((full_img + 1) * 128).astype(np.uint8)

        gt_lines = x['gt_lines']
        gt = "\n".join(gt_lines)

        out_original = e2e(x)
        if out_original is None:
            #TODO: not a good way to handle this, but fine for now
            print "Possible Error: Skipping alignment on image"
            continue

        out_original = e2e_postprocessing.results_to_numpy(out_original)
        out_original['idx'] = np.arange(out_original['sol'].shape[0])
        e2e_postprocessing.trim_ends(out_original)
        decoded_hw, decoded_raw_hw = e2e_postprocessing.decode_handwriting(
            out_original, idx_to_char)
        pick, costs = e2e_postprocessing.align_to_gt_lines(
            decoded_hw, gt_lines)

        best_ever_pred_lines, improved_idxs = validation_utils.update_ideal_results(
            pick, costs, decoded_hw, x['gt_json'])
        validation_utils.save_improved_idxs(
            improved_idxs, decoded_hw, decoded_raw_hw, out_original, x,
            config['training'][dataset_lookup]['json_folder'])

        best_ever_pred_lines = "\n".join(best_ever_pred_lines)
        error = error_rates.cer(gt, best_ever_pred_lines)
        best_ever_results.append(error)

        aligned_pred_lines = [decoded_hw[i] for i in pick]
        aligned_pred_lines = "\n".join(aligned_pred_lines)
        error = error_rates.cer(gt, aligned_pred_lines)
        aligned_results.append(error)

        if dataset_lookup == "validation_set":
            # We only care about the hyperparameter postprocessing seach for the validation set
            for key in itertools.product(sol_thresholds_idx, lf_nms_ranges_idx,
                                         lf_nms_thresholds_idx):
                i, j, k = key
                sol_threshold = sol_thresholds[i]
                lf_nms_range = lf_nms_ranges[j]
                lf_nms_threshold = lf_nms_thresholds[k]

                out = copy.copy(out_original)

                out = e2e_postprocessing.postprocess(
                    out,
                    sol_threshold=sol_threshold,
                    lf_nms_params={
                        "overlap_range": lf_nms_range,
                        "overlap_threshold": lf_nms_threshold
                    })
                order = e2e_postprocessing.read_order(out)
                e2e_postprocessing.filter_on_pick(out, order)

                e2e_postprocessing.trim_ends(out)

                preds = [decoded_hw[i] for i in out['idx']]
                pred = "\n".join(preds)

                error = error_rates.cer(gt, pred)

                results[key].append(error)

    sum_results = None
    if dataset_lookup == "validation_set":
        # Skipping because we didn't do the hyperparameter search
        sum_results = {}
        for k, v in results.iteritems():
            sum_results[k] = np.mean(v)

        sum_results = sorted(sum_results.iteritems(),
                             key=operator.itemgetter(1))
        sum_results = sum_results[0]

    return sum_results, np.mean(aligned_results), np.mean(
        best_ever_results), sol, lf, hw
Beispiel #10
0
        line_imgs = generator(line_imgs)
        #for b in range(line_imgs.size(0)):
        #    draw = ((line_imgs[b,0]+1)*128).cpu().numpy().astype(np.uint8)
        #    cv2.imwrite('test/line{}.png'.format(b),draw)
        #    print('gt[{}]: {}'.format(b,x['gt'][b]))
        #cv2.waitKey()
        preds = hw(line_imgs).cpu()

        output_batch = preds.permute(1, 0, 2)
        out = output_batch.data.cpu().numpy()
        toprint = []
        for b, gt_line in enumerate(x['gt']):
            logits = out[b, ...]
            pred, raw_pred = string_utils.naive_decode(logits)
            pred_str = string_utils.label2str_single(pred, idx_to_char, False)
            cer = error_rates.cer(gt_line, pred_str)
            sum_cer += cer
            steps += 1

            if i % print_freq == 0:
                toprint.append('[cer]:{:.2f} [gt]: {} [pred]: {}'.format(
                    cer, gt_line, pred_str))

        batch_size = preds.size(1)
        preds_size = torch.IntTensor([preds.size(0)] * batch_size)

        # print "before"
        loss = criterion(preds, labels, preds_size, label_lengths)
        # print "after"

        optimizer.zero_grad()
Beispiel #11
0
def training_step(config):

    hw_network_config = config['network']['hw']
    train_config = config['training']

    allowed_training_time = train_config['hw']['reset_interval']
    init_training_time = time.time()

    char_set_path = hw_network_config['char_set_path']

    with open(char_set_path) as f:
        char_set = json.load(f)

    idx_to_char = {}
    for k, v in char_set['idx_to_char'].iteritems():
        idx_to_char[int(k)] = v

    training_set_list = load_file_list(train_config['training_set'])
    train_dataset = HwDataset(training_set_list,
                              char_set['char_to_idx'],
                              augmentation=True,
                              img_height=hw_network_config['input_height'])

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=train_config['hw']['batch_size'],
                                  shuffle=False,
                                  num_workers=0,
                                  collate_fn=hw_dataset.collate)

    batches_per_epoch = int(train_config['hw']['images_per_epoch'] /
                            train_config['hw']['batch_size'])
    train_dataloader = DatasetWrapper(train_dataloader, batches_per_epoch)

    test_set_list = load_file_list(train_config['validation_set'])
    test_dataset = HwDataset(
        test_set_list,
        char_set['char_to_idx'],
        img_height=hw_network_config['input_height'],
        random_subset_size=train_config['hw']['validation_subset_size'])

    test_dataloader = DataLoader(test_dataset,
                                 batch_size=train_config['hw']['batch_size'],
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=hw_dataset.collate)

    hw = cnn_lstm.create_model(hw_network_config)
    hw_path = os.path.join(train_config['snapshot']['best_validation'],
                           "hw.pt")
    hw_state = safe_load.torch_state(hw_path)
    hw.load_state_dict(hw_state)
    hw.cuda()
    criterion = CTCLoss()
    dtype = torch.cuda.FloatTensor

    lowest_loss = np.inf
    lowest_loss_i = 0
    for epoch in xrange(10000000000):
        sum_loss = 0.0
        steps = 0.0
        hw.eval()
        for x in test_dataloader:
            sys.stdout.flush()
            line_imgs = Variable(x['line_imgs'].type(dtype),
                                 requires_grad=False,
                                 volatile=True)
            labels = Variable(x['labels'], requires_grad=False, volatile=True)
            label_lengths = Variable(x['label_lengths'],
                                     requires_grad=False,
                                     volatile=True)

            preds = hw(line_imgs).cpu()

            output_batch = preds.permute(1, 0, 2)
            out = output_batch.data.cpu().numpy()

            for i, gt_line in enumerate(x['gt']):
                logits = out[i, ...]
                pred, raw_pred = string_utils.naive_decode(logits)
                pred_str = string_utils.label2str_single(
                    pred, idx_to_char, False)
                cer = error_rates.cer(gt_line, pred_str)
                sum_loss += cer
                steps += 1

        if epoch == 0:
            print "First Validation Step Complete"
            print "Benchmark Validation CER:", sum_loss / steps
            lowest_loss = sum_loss / steps

            hw = cnn_lstm.create_model(hw_network_config)
            hw_path = os.path.join(train_config['snapshot']['current'],
                                   "hw.pt")
            hw_state = safe_load.torch_state(hw_path)
            hw.load_state_dict(hw_state)
            hw.cuda()

            optimizer = torch.optim.Adam(
                hw.parameters(), lr=train_config['hw']['learning_rate'])
            optim_path = os.path.join(train_config['snapshot']['current'],
                                      "hw_optim.pt")
            if os.path.exists(optim_path):
                print "Loading Optim Settings"
                optimizer.load_state_dict(safe_load.torch_state(optim_path))
            else:
                print "Failed to load Optim Settings"

        if lowest_loss > sum_loss / steps:
            lowest_loss = sum_loss / steps
            print "Saving Best"

            dirname = train_config['snapshot']['best_validation']
            if not len(dirname) != 0 and os.path.exists(dirname):
                os.makedirs(dirname)

            save_path = os.path.join(dirname, "hw.pt")

            torch.save(hw.state_dict(), save_path)
            lowest_loss_i = epoch

        print "Test Loss", sum_loss / steps, lowest_loss
        print ""

        if allowed_training_time < (time.time() - init_training_time):
            print "Out of time: Exiting..."
            break

        print "Epoch", epoch
        sum_loss = 0.0
        steps = 0.0
        hw.train()
        for i, x in enumerate(train_dataloader):

            line_imgs = Variable(x['line_imgs'].type(dtype),
                                 requires_grad=False)
            labels = Variable(x['labels'], requires_grad=False)
            label_lengths = Variable(x['label_lengths'], requires_grad=False)

            preds = hw(line_imgs).cpu()

            output_batch = preds.permute(1, 0, 2)
            out = output_batch.data.cpu().numpy()

            # if i == 0:
            #     for i in xrange(out.shape[0]):
            #         pred, pred_raw = string_utils.naive_decode(out[i,...])
            #         pred_str = string_utils.label2str_single(pred_raw, idx_to_char, True)
            #         print pred_str

            for i, gt_line in enumerate(x['gt']):
                logits = out[i, ...]
                pred, raw_pred = string_utils.naive_decode(logits)
                pred_str = string_utils.label2str_single(
                    pred, idx_to_char, False)
                cer = error_rates.cer(gt_line, pred_str)
                sum_loss += cer
                steps += 1

            batch_size = preds.size(1)
            preds_size = Variable(torch.IntTensor([preds.size(0)] *
                                                  batch_size))

            loss = criterion(preds, labels, preds_size, label_lengths)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print "Train Loss", sum_loss / steps
        print "Real Epoch", train_dataloader.epoch

    ## Save current snapshots for next iteration
    print "Saving Current"
    dirname = train_config['snapshot']['current']
    if not len(dirname) != 0 and os.path.exists(dirname):
        os.makedirs(dirname)

    save_path = os.path.join(dirname, "hw.pt")
    torch.save(hw.state_dict(), save_path)

    optim_path = os.path.join(dirname, "hw_optim.pt")
    torch.save(optimizer.state_dict(), optim_path)
Beispiel #12
0
def training_step(config):

    char_set_path = config['network']['hw']['char_set_path']

    with open(char_set_path) as f:
        char_set = json.load(f)

    idx_to_char = {}
    for k, v in char_set['idx_to_char'].iteritems():
        idx_to_char[int(k)] = v

    train_config = config['training']

    allowed_training_time = train_config['lf']['reset_interval']
    init_training_time = time.time()

    training_set_list = load_file_list(train_config['training_set'])
    train_dataset = LfDataset(training_set_list, augmentation=True)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=1,
                                  shuffle=True,
                                  num_workers=0,
                                  collate_fn=lf_dataset.collate)
    batches_per_epoch = int(train_config['lf']['images_per_epoch'] /
                            train_config['lf']['batch_size'])
    train_dataloader = DatasetWrapper(train_dataloader, batches_per_epoch)

    test_set_list = load_file_list(train_config['validation_set'])
    test_dataset = LfDataset(
        test_set_list,
        random_subset_size=train_config['lf']['validation_subset_size'])
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=lf_dataset.collate)

    _, lf, hw = init_model(config, only_load=['lf', 'hw'])
    hw.eval()

    dtype = torch.cuda.FloatTensor

    lowest_loss = np.inf
    lowest_loss_i = 0
    for epoch in xrange(10000000):
        lf.eval()
        sum_loss = 0.0
        steps = 0.0
        start_time = time.time()
        for step_i, x in enumerate(test_dataloader):
            if x is None:
                continue
            #Only single batch for now
            x = x[0]
            if x is None:
                continue

            positions = [
                Variable(x_i.type(dtype), requires_grad=False)[None, ...]
                for x_i in x['lf_xyrs']
            ]
            xy_positions = [
                Variable(x_i.type(dtype), requires_grad=False)[None, ...]
                for x_i in x['lf_xyxy']
            ]
            img = Variable(x['img'].type(dtype), requires_grad=False)[None,
                                                                      ...]

            #There might be a way to handle this case later,
            #but for now we will skip it
            if len(xy_positions) <= 1:
                print "Skipping"
                continue

            grid_line, _, _, xy_output = lf(img,
                                            positions[:1],
                                            steps=len(positions),
                                            skip_grid=False)

            line = torch.nn.functional.grid_sample(img.transpose(2, 3),
                                                   grid_line)
            line = line.transpose(2, 3)
            predictions = hw(line)

            out = predictions.permute(1, 0, 2).data.cpu().numpy()
            gt_line = x['gt']
            pred, raw_pred = string_utils.naive_decode(out[0])
            pred_str = string_utils.label2str_single(pred, idx_to_char, False)
            cer = error_rates.cer(gt_line, pred_str)
            sum_loss += cer
            steps += 1

            # l = line[0].transpose(0,1).transpose(1,2)
            # l = (l + 1)*128
            # l_np = l.data.cpu().numpy()
            #
            # cv2.imwrite("example_line_out.png", l_np)
            # print "Saved!"
            # raw_input()

            # loss = lf_loss.point_loss(xy_output, xy_positions)
            #
            # sum_loss += loss.data[0]
            # steps += 1

        if epoch == 0:
            print "First Validation Step Complete"
            print "Benchmark Validation Loss:", sum_loss / steps
            lowest_loss = sum_loss / steps

            _, lf, _ = init_model(config, lf_dir='current', only_load="lf")

            optimizer = torch.optim.Adam(
                lf.parameters(), lr=train_config['lf']['learning_rate'])
            optim_path = os.path.join(train_config['snapshot']['current'],
                                      "lf_optim.pt")
            if os.path.exists(optim_path):
                print "Loading Optim Settings"
                optimizer.load_state_dict(safe_load.torch_state(optim_path))
            else:
                print "Failed to load Optim Settings"

        if lowest_loss > sum_loss / steps:
            lowest_loss = sum_loss / steps
            print "Saving Best"

            dirname = train_config['snapshot']['best_validation']
            if not len(dirname) != 0 and os.path.exists(dirname):
                os.makedirs(dirname)

            save_path = os.path.join(dirname, "lf.pt")

            torch.save(lf.state_dict(), save_path)
            lowest_loss_i = 0

        test_loss = sum_loss / steps

        print "Test Loss", sum_loss / steps, lowest_loss
        print "Time:", time.time() - start_time
        print ""

        if allowed_training_time < (time.time() - init_training_time):
            print "Out of time: Exiting..."
            break

        print "Epoch", epoch
        sum_loss = 0.0
        steps = 0.0
        lf.train()
        start_time = time.time()
        for x in train_dataloader:
            if x is None:
                continue
            #Only single batch for now
            x = x[0]
            if x is None:
                continue

            positions = [
                Variable(x_i.type(dtype), requires_grad=False)[None, ...]
                for x_i in x['lf_xyrs']
            ]
            xy_positions = [
                Variable(x_i.type(dtype), requires_grad=False)[None, ...]
                for x_i in x['lf_xyxy']
            ]
            img = Variable(x['img'].type(dtype), requires_grad=False)[None,
                                                                      ...]

            #There might be a way to handle this case later,
            #but for now we will skip it
            if len(xy_positions) <= 1:
                continue

            reset_interval = 4
            grid_line, _, _, xy_output = lf(img,
                                            positions[:1],
                                            steps=len(positions),
                                            all_positions=positions,
                                            reset_interval=reset_interval,
                                            randomize=True,
                                            skip_grid=True)

            loss = lf_loss.point_loss(xy_output, xy_positions)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sum_loss += loss.data.item()
            steps += 1

        print "Train Loss", sum_loss / steps
        print "Real Epoch", train_dataloader.epoch
        print "Time:", time.time() - start_time

    ## Save current snapshots for next iteration
    print "Saving Current"
    dirname = train_config['snapshot']['current']
    if not len(dirname) != 0 and os.path.exists(dirname):
        os.makedirs(dirname)

    save_path = os.path.join(dirname, "lf.pt")
    torch.save(lf.state_dict(), save_path)

    optim_path = os.path.join(dirname, "lf_optim.pt")
    torch.save(optimizer.state_dict(), optim_path)