Example #1
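# __init__ of a Segmenter class (the class header and imports are not part of
# this snippet): caches the SOL network, pretraining, and evaluation sections
# of the config and builds the network via continuous_state.init_model.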
    def __init__(self, config, image_folder):
        super(Segmenter, self).__init__()
        self.config = config
        self.image_folder = image_folder
        self.sol_network_config = config['network']['sol']
        self.pretrain_config = config['pretraining']
        self.outpath = config['evaluation']['output_path']

        self.rows = int(config['evaluation']['rows'])
        self.columns = int(config['evaluation']['columns'])

        self.network = continuous_state.init_model(config)
Example #2
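# Loads a YAML config, builds the start-of-line (SOL) model, runs it over
# every image in the validation folder, and writes a visualization of the
# predicted start-of-line points for each rescaled image.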
def main():
  with open(sys.argv[1]) as f:
    config = yaml.load(f)

  sol_network_config = config['network']['sol']
  pretrain_config = config['pretraining']
  eval_folder = pretrain_config['validation_set']['img_folder']

  solf = continuous_state.init_model(config)

  if torch.cuda.is_available():
    print("Using GPU")
    solf.cuda()
    dtype = torch.cuda.FloatTensor
  else:
    print("Warning: Not using a GPU, untested")
    dtype = torch.FloatTensor

  writep = config['evaluation']['output_path'].split('_')[0]
  writep = 'data/{}_val'.format(writep)
  if not os.path.exists(writep):
    os.makedirs(writep)

  for fil in os.listdir(eval_folder):
    imgfil = os.path.join(eval_folder, fil)
    org_img = cv2.imread(imgfil, cv2.IMREAD_COLOR)
    if org_img is not None:
      rescale_range = config['pretraining']['sol']['validation_rescale_range']
      target_dim1 = rescale_range[0]

      s = target_dim1 / float(org_img.shape[1])
      target_dim0 = int(org_img.shape[0] * s)
      org_img = cv2.resize(org_img,(target_dim1, target_dim0), interpolation=cv2.INTER_CUBIC)

      img = org_img.transpose([2,1,0])[None,...]
      img = img.astype(np.float32)
      img = torch.from_numpy(img)
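      # Scale pixel values from [0, 255] to roughly [-1, 1] for the network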
      img = img / 128.0 - 1.0

      img = Variable(img.type(dtype), requires_grad=False)

      # print((img))
      predictions = solf(img)

      org_img = img[0].data.cpu().numpy().transpose([2,1,0])
      org_img = ((org_img + 1)*128).astype(np.uint8)
      org_img = org_img.copy()
      org_img = drawing.draw_sol_torch(predictions, org_img)
      predictions = None
      cv2.imwrite(os.path.join(writep, fil), org_img)
Example #3
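# End-to-end decoding driver: walks a folder of .npz network outputs,
# post-processes each one, decodes the handwriting predictions (optionally
# with a language model and/or test-time augmentation), and writes text
# files, visualizations, and PAGE XML results.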
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('config_path')
    parser.add_argument('npz_folder')
    parser.add_argument('--in_xml_folder')
    parser.add_argument('--out_xml_folder')
    parser.add_argument('--lm', action='store_true')
    parser.add_argument('--aug', action='store_true')
    parser.add_argument('--roi', action='store_true')
    args = parser.parse_args()

    config_path = args.config_path
    npz_folder = args.npz_folder
    in_xml_folder = args.in_xml_folder
    out_xml_folder = args.out_xml_folder

    in_xml_files = {}
    if in_xml_folder and out_xml_folder:
        for root, folders, files in os.walk(in_xml_folder):
            for f in files:
                if f.endswith(".xml"):
                    basename = os.path.basename(f).replace(".xml", "")
                    in_xml_files[basename] = os.path.join(root, f)

    use_lm = args.lm
    use_aug = args.aug
    use_roi = args.roi

    if use_lm:
        from utils import lm_decoder

    with open(config_path) as f:
        config = yaml.load(f)

    npz_paths = []
    for root, folder, files in os.walk(npz_folder):
        for f in files:
            if f.lower().endswith(".npz"):
                npz_paths.append(os.path.join(root, f))

    char_set_path = config['network']['hw']['char_set_path']

    with open(char_set_path) as f:
        char_set = json.load(f)

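    # JSON object keys are always strings; the decoder needs integer indices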
    idx_to_char = {}
    for k, v in char_set['idx_to_char'].iteritems():
        idx_to_char[int(k)] = v

    if use_aug:
        model_mode = "pretrain"
        _, _, hw = init_model(config, hw_dir=model_mode, only_load="hw")
        dtype = torch.cuda.FloatTensor
        hw.eval()

    if use_lm:
        lm_params = config['network']['lm']
        print "Loading LM"
        decoder = lm_decoder.LMDecoder(idx_to_char, lm_params)
        print "Done Loading LM"

        print "Accumulating stats for LM"
        for npz_path in sorted(npz_paths):
            out = np.load(npz_path)
            out = dict(out)
            for o in out['hw']:
                o = log_softmax(o)
                decoder.add_stats(o)
        print "Done accumulating stats for LM"
    else:
        print "Skip Loading LM"

    for npz_path in sorted(npz_paths):

        out = np.load(npz_path)
        out = dict(out)

        image_path = str(out['image_path'])
        print image_path
        org_img = cv2.imread(image_path)

        # Postprocessing Steps
        out['idx'] = np.arange(out['sol'].shape[0])
        out = e2e_postprocessing.trim_ends(out)
        e2e_postprocessing.filter_on_pick(
            out, e2e_postprocessing.select_non_empty_string(out))
        out = e2e_postprocessing.postprocess(
            out,
            sol_threshold=config['post_processing']['sol_threshold'],
            lf_nms_params={
                "overlap_range": config['post_processing']['lf_nms_range'],
                "overlap_threshold":
                config['post_processing']['lf_nms_threshold']
            }
            # },
            # lf_nms_2_params={
            #     "overlap_threshold": 0.5
            # }
        )
        order = e2e_postprocessing.read_order(out)
        e2e_postprocessing.filter_on_pick(out, order)

        # Decoding network output
        output_strings = []
        if use_aug:
            number_of_iterations = 20
            for line_img in out['line_imgs']:
                batch = []
                for i in range(number_of_iterations):
                    warped_image = grid_distortion.warp_image(line_img)
                    batch.append(warped_image)

                batch = np.array(batch)
                batch = Variable(torch.from_numpy(batch),
                                 requires_grad=False,
                                 volatile=True).cuda()
                batch = batch / 128.0 - 1.0
                batch = batch.transpose(2, 3)
                batch = batch.transpose(1, 2)
                hw_out = hw(batch)
                hw_out = hw_out.transpose(0, 1)
                hw_out = hw_out.data.cpu().numpy()

                if use_lm:
                    decoded_hw = []
                    for line in hw_out:
                        log_softmax_line = log_softmax(line)
                        lm_output = decoder.decode(log_softmax_line)[0]
                        decoded_hw.append(lm_output)
                else:
                    decoded_hw, decoded_raw_hw = e2e_postprocessing.decode_handwriting(
                        {"hw": hw_out}, idx_to_char)

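                # Majority vote over the augmented decodings: group identical
                # strings and keep the first index of the most common one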
                cnt_d = defaultdict(list)
                for i, d in enumerate(decoded_hw):
                    cnt_d[d].append(i)

                cnt_d = dict(cnt_d)
                sorted_list = list(
                    sorted(cnt_d.iteritems(), key=lambda x: len(x[1])))

                best_idx = sorted_list[-1][1][0]
                output_strings.append(decoded_hw[best_idx])

        else:
            if use_lm:
                for line in out['hw']:
                    log_softmax_line = log_softmax(line)
                    lm_output = decoder.decode(log_softmax_line)[0]
                    output_strings.append(lm_output)
            else:
                output_strings, decoded_raw_hw = e2e_postprocessing.decode_handwriting(
                    out, idx_to_char)

        draw_img = visualization.draw_output(out, org_img)
        cv2.imwrite(npz_path + ".png", draw_img)

        # Save results
        label_string = "_"
        if use_lm:
            label_string += "lm_"
        if use_aug:
            label_string += "aug_"
        filepath = npz_path + label_string + ".txt"

        with codecs.open(filepath, 'w', encoding='utf-8') as f:
            f.write("\n".join(output_strings))

        key = os.path.basename(image_path)[:-len(".jpg")]
        if in_xml_folder:
            if use_roi:

                key, region_id = key.split("_", 1)
                region_id = region_id.split(".")[0]

                if key in in_xml_files:
                    in_xml_file = in_xml_files[key]
                    out_xml_file = os.path.join(out_xml_folder,
                                                os.path.basename(in_xml_file))
                    PAGE_xml.create_output_xml_roi(in_xml_file, out,
                                                   output_strings,
                                                   out_xml_file, region_id)
                    # After the first region, append to the xml just written
                    in_xml_files[key] = out_xml_file
                else:
                    print "Couldn't find xml file for ", key
            else:
                if key in in_xml_files:
                    in_xml_file = in_xml_files[key]
                    out_xml_file = os.path.join(out_xml_folder,
                                                os.path.basename(in_xml_file))
                    PAGE_xml.create_output_xml(in_xml_file, out,
                                               output_strings, out_xml_file)
                else:
                    print "Couldn't find xml file for ", key
Example #4
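# One training round for the start-of-line (SOL) detector: validates the
# current snapshot, reloads it together with its optimizer state, then
# alternates validation and training until the allowed wall-clock budget is
# used up, saving the best and current weights along the way.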
def training_step(config):

    train_config = config['training']

    allowed_training_time = train_config['sol']['reset_interval']
    init_training_time = time.time()

    training_set_list = load_file_list(train_config['training_set'])
    train_dataset = SolDataset(
        training_set_list,
        rescale_range=train_config['sol']['training_rescale_range'],
        transform=CropTransform(train_config['sol']['crop_params']))

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=train_config['sol']['batch_size'],
                                  shuffle=True,
                                  num_workers=0,
                                  collate_fn=sol_dataset.collate)

    batches_per_epoch = int(train_config['sol']['images_per_epoch'] /
                            train_config['sol']['batch_size'])
    train_dataloader = DatasetWrapper(train_dataloader, batches_per_epoch)

    test_set_list = load_file_list(train_config['validation_set'])
    test_dataset = SolDataset(
        test_set_list,
        rescale_range=train_config['sol']['validation_rescale_range'],
        random_subset_size=train_config['sol']['validation_subset_size'],
        transform=None)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=sol_dataset.collate)

    alpha_alignment = train_config['sol']['alpha_alignment']
    alpha_backprop = train_config['sol']['alpha_backprop']

    sol, lf, hw = init_model(config, only_load='sol')

    dtype = torch.cuda.FloatTensor

    lowest_loss = np.inf
    lowest_loss_i = 0
    epoch = -1
    while True:  # This loop ends on a break once the time budget is exhausted
        epoch += 1
        print "Train Time:", (
            time.time() -
            init_training_time), "Allowed Time:", allowed_training_time

        sol.eval()
        sum_loss = 0.0
        steps = 0.0
        start_time = time.time()
        for step_i, x in enumerate(test_dataloader):
            img = Variable(x['img'].type(dtype), requires_grad=False)

            sol_gt = None
            if x['sol_gt'] is not None:
                sol_gt = Variable(x['sol_gt'].type(dtype), requires_grad=False)

            predictions = sol(img)
            predictions = transformation_utils.pt_xyrs_2_xyxy(predictions)
            loss = alignment_loss(predictions, sol_gt, x['label_sizes'],
                                  alpha_alignment, alpha_backprop)
            sum_loss += loss.data[0]
            steps += 1

        if epoch == 0:
            print "First Validation Step Complete"
            print "Benchmark Validation CER:", sum_loss / steps
            lowest_loss = sum_loss / steps

            sol, lf, hw = init_model(config,
                                     sol_dir='current',
                                     only_load='sol')

            optimizer = torch.optim.Adam(
                sol.parameters(), lr=train_config['sol']['learning_rate'])
            optim_path = os.path.join(train_config['snapshot']['current'],
                                      "sol_optim.pt")
            if os.path.exists(optim_path):
                print "Loading Optim Settings"
                optimizer.load_state_dict(safe_load.torch_state(optim_path))
            else:
                print "Failed to load Optim Settings"

        elif lowest_loss > sum_loss / steps:
            lowest_loss = sum_loss / steps
            print "Saving Best"

            dirname = train_config['snapshot']['best_validation']
            if len(dirname) != 0 and not os.path.exists(dirname):
                os.makedirs(dirname)

            save_path = os.path.join(dirname, "sol.pt")

            torch.save(sol.state_dict(), save_path)
            lowest_loss_i = epoch

        print "Test Loss", sum_loss / steps, lowest_loss
        print "Time:", time.time() - start_time
        print ""

        print "Epoch", epoch

        if allowed_training_time < (time.time() - init_training_time):
            print "Out of time. Saving current state and exiting..."
            dirname = train_config['snapshot']['current']
            if len(dirname) != 0 and not os.path.exists(dirname):
                os.makedirs(dirname)

            save_path = os.path.join(dirname, "sol.pt")
            torch.save(sol.state_dict(), save_path)

            optim_path = os.path.join(dirname, "sol_optim.pt")
            torch.save(optimizer.state_dict(), optim_path)
            break

        sol.train()
        sum_loss = 0.0
        steps = 0.0
        start_time = time.time()
        for step_i, x in enumerate(train_dataloader):
            img = Variable(x['img'].type(dtype), requires_grad=False)

            sol_gt = None
            if x['sol_gt'] is not None:
                sol_gt = Variable(x['sol_gt'].type(dtype), requires_grad=False)

            predictions = sol(img)
            predictions = transformation_utils.pt_xyrs_2_xyxy(predictions)
            loss = alignment_loss(predictions, sol_gt, x['label_sizes'],
                                  alpha_alignment, alpha_backprop)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sum_loss += loss.data[0]
            steps += 1

        print "Train Loss", sum_loss / steps
        print "Real Epoch", train_dataloader.epoch
        print "Time:", time.time() - start_time
Example #5
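# Alignment pass: runs the full SOL -> line follower -> handwriting pipeline
# over a dataset, aligns the decoded lines to the ground truth, records the
# best result seen so far for each line, and (for the validation set only)
# grid-searches the post-processing hyperparameters.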
def alignment_step(config,
                   dataset_lookup=None,
                   model_mode='best_validation',
                   percent_range=None):

    set_list = load_file_list(config['training'][dataset_lookup])

    if percent_range is not None:
        start = int(len(set_list) * percent_range[0])
        end = int(len(set_list) * percent_range[1])
        set_list = set_list[start:end]

    dataset = AlignmentDataset(set_list, None)
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=0,
                            collate_fn=alignment_dataset.collate)

    char_set_path = config['network']['hw']['char_set_path']

    with open(char_set_path) as f:
        char_set = json.load(f)

    idx_to_char = {}
    for k, v in char_set['idx_to_char'].iteritems():
        idx_to_char[int(k)] = v

    sol, lf, hw = init_model(config,
                             sol_dir=model_mode,
                             lf_dir=model_mode,
                             hw_dir=model_mode)

    e2e = E2EModel(sol, lf, hw)
    dtype = torch.cuda.FloatTensor
    e2e.eval()

    post_processing_config = config['training']['alignment'][
        'validation_post_processing']
    sol_thresholds = post_processing_config['sol_thresholds']
    sol_thresholds_idx = range(len(sol_thresholds))

    lf_nms_ranges = post_processing_config['lf_nms_ranges']
    lf_nms_ranges_idx = range(len(lf_nms_ranges))

    lf_nms_thresholds = post_processing_config['lf_nms_thresholds']
    lf_nms_thresholds_idx = range(len(lf_nms_thresholds))

    results = defaultdict(list)
    aligned_results = []
    best_ever_results = []

    prev_time = time.time()
    cnt = 0
    a = 0
    for x in dataloader:
        sys.stdout.flush()
        a += 1

        if a % 100 == 0:
            print a, np.mean(aligned_results)

        x = x[0]
        if x is None:
            print "Skipping alignment because it returned None"
            continue

        img = x['resized_img'].numpy()[0, ...].transpose([2, 1, 0])
        img = ((img + 1) * 128).astype(np.uint8)

        full_img = x['full_img'].numpy()[0, ...].transpose([2, 1, 0])
        full_img = ((full_img + 1) * 128).astype(np.uint8)

        gt_lines = x['gt_lines']
        gt = "\n".join(gt_lines)

        out_original = e2e(x)
        if out_original is None:
            #TODO: not a good way to handle this, but fine for now
            print "Possible Error: Skipping alignment on image"
            continue

        out_original = e2e_postprocessing.results_to_numpy(out_original)
        out_original['idx'] = np.arange(out_original['sol'].shape[0])
        e2e_postprocessing.trim_ends(out_original)
        decoded_hw, decoded_raw_hw = e2e_postprocessing.decode_handwriting(
            out_original, idx_to_char)
        pick, costs = e2e_postprocessing.align_to_gt_lines(
            decoded_hw, gt_lines)

        best_ever_pred_lines, improved_idxs = validation_utils.update_ideal_results(
            pick, costs, decoded_hw, x['gt_json'])
        validation_utils.save_improved_idxs(
            improved_idxs, decoded_hw, decoded_raw_hw, out_original, x,
            config['training'][dataset_lookup]['json_folder'])

        best_ever_pred_lines = "\n".join(best_ever_pred_lines)
        error = error_rates.cer(gt, best_ever_pred_lines)
        best_ever_results.append(error)

        aligned_pred_lines = [decoded_hw[i] for i in pick]
        aligned_pred_lines = "\n".join(aligned_pred_lines)
        error = error_rates.cer(gt, aligned_pred_lines)
        aligned_results.append(error)

        if dataset_lookup == "validation_set":
            # The post-processing hyperparameter search only matters for the validation set
            for key in itertools.product(sol_thresholds_idx, lf_nms_ranges_idx,
                                         lf_nms_thresholds_idx):
                i, j, k = key
                sol_threshold = sol_thresholds[i]
                lf_nms_range = lf_nms_ranges[j]
                lf_nms_threshold = lf_nms_thresholds[k]

                out = copy.copy(out_original)

                out = e2e_postprocessing.postprocess(
                    out,
                    sol_threshold=sol_threshold,
                    lf_nms_params={
                        "overlap_range": lf_nms_range,
                        "overlap_threshold": lf_nms_threshold
                    })
                order = e2e_postprocessing.read_order(out)
                e2e_postprocessing.filter_on_pick(out, order)

                e2e_postprocessing.trim_ends(out)

                preds = [decoded_hw[i] for i in out['idx']]
                pred = "\n".join(preds)

                error = error_rates.cer(gt, pred)

                results[key].append(error)

    sum_results = None
    if dataset_lookup == "validation_set":
        # Summarize the hyperparameter search; sum_results stays None when the search was skipped
        sum_results = {}
        for k, v in results.iteritems():
            sum_results[k] = np.mean(v)

        sum_results = sorted(sum_results.iteritems(),
                             key=operator.itemgetter(1))
        sum_results = sum_results[0]

    return sum_results, np.mean(aligned_results), np.mean(
        best_ever_results), sol, lf, hw
Example #6
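# Inference snippet (the beginning and end of the script are not shown):
# loads the character set, builds the best-overall end-to-end model, and
# iterates over the input images, rescaling each page to a fixed width.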
    output_directory = sys.argv[3]

    char_set_path = config['network']['hw']['char_set_path']

    with open(char_set_path) as f:
        char_set = json.load(f)

    idx_to_char = {}
    for k, v in char_set['idx_to_char'].iteritems():
        idx_to_char[int(k)] = v

    char_to_idx = char_set['char_to_idx']

    model_mode = "best_overall"
    sol, lf, hw = init_model(config,
                             sol_dir=model_mode,
                             lf_dir=model_mode,
                             hw_dir=model_mode)

    e2e = E2EModel(sol, lf, hw)
    dtype = torch.cuda.FloatTensor
    e2e.eval()

    for image_path in sorted(image_paths):
        print image_path

        org_img = cv2.imread(image_path)

        target_dim1 = 512
        s = target_dim1 / float(org_img.shape[1])

        pad_amount = 128
Example #7
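# One training round for the line follower (LF): validates by decoding each
# followed line with the frozen handwriting network and measuring the
# character error rate, trains with a point loss on the predicted path, and
# saves the best and current weights within the allowed wall-clock budget.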
def training_step(config):

    char_set_path = config['network']['hw']['char_set_path']

    with open(char_set_path) as f:
        char_set = json.load(f)

    idx_to_char = {}
    for k, v in char_set['idx_to_char'].iteritems():
        idx_to_char[int(k)] = v

    train_config = config['training']

    allowed_training_time = train_config['lf']['reset_interval']
    init_training_time = time.time()

    training_set_list = load_file_list(train_config['training_set'])
    train_dataset = LfDataset(training_set_list, augmentation=True)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=1,
                                  shuffle=True,
                                  num_workers=0,
                                  collate_fn=lf_dataset.collate)
    batches_per_epoch = int(train_config['lf']['images_per_epoch'] /
                            train_config['lf']['batch_size'])
    train_dataloader = DatasetWrapper(train_dataloader, batches_per_epoch)

    test_set_list = load_file_list(train_config['validation_set'])
    test_dataset = LfDataset(
        test_set_list,
        random_subset_size=train_config['lf']['validation_subset_size'])
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=lf_dataset.collate)

    _, lf, hw = init_model(config, only_load=['lf', 'hw'])
    hw.eval()

    dtype = torch.cuda.FloatTensor

    lowest_loss = np.inf
    lowest_loss_i = 0
    for epoch in xrange(10000000):
        lf.eval()
        sum_loss = 0.0
        steps = 0.0
        start_time = time.time()
        for step_i, x in enumerate(test_dataloader):
            if x is None:
                continue
            #Only single batch for now
            x = x[0]
            if x is None:
                continue

            positions = [
                Variable(x_i.type(dtype), requires_grad=False)[None, ...]
                for x_i in x['lf_xyrs']
            ]
            xy_positions = [
                Variable(x_i.type(dtype), requires_grad=False)[None, ...]
                for x_i in x['lf_xyxy']
            ]
            img = Variable(x['img'].type(dtype), requires_grad=False)[None,
                                                                      ...]

            #There might be a way to handle this case later,
            #but for now we will skip it
            if len(xy_positions) <= 1:
                print "Skipping"
                continue

            grid_line, _, _, xy_output = lf(img,
                                            positions[:1],
                                            steps=len(positions),
                                            skip_grid=False)

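            # Sample the followed line region out of the page image with the
            # predicted grid, then run the handwriting recognizer on it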
            line = torch.nn.functional.grid_sample(img.transpose(2, 3),
                                                   grid_line)
            line = line.transpose(2, 3)
            predictions = hw(line)

            out = predictions.permute(1, 0, 2).data.cpu().numpy()
            gt_line = x['gt']
            pred, raw_pred = string_utils.naive_decode(out[0])
            pred_str = string_utils.label2str_single(pred, idx_to_char, False)
            cer = error_rates.cer(gt_line, pred_str)
            sum_loss += cer
            steps += 1

            # l = line[0].transpose(0,1).transpose(1,2)
            # l = (l + 1)*128
            # l_np = l.data.cpu().numpy()
            #
            # cv2.imwrite("example_line_out.png", l_np)
            # print "Saved!"
            # raw_input()

            # loss = lf_loss.point_loss(xy_output, xy_positions)
            #
            # sum_loss += loss.data[0]
            # steps += 1

        if epoch == 0:
            print "First Validation Step Complete"
            print "Benchmark Validation Loss:", sum_loss / steps
            lowest_loss = sum_loss / steps

            _, lf, _ = init_model(config, lf_dir='current', only_load="lf")

            optimizer = torch.optim.Adam(
                lf.parameters(), lr=train_config['lf']['learning_rate'])
            optim_path = os.path.join(train_config['snapshot']['current'],
                                      "lf_optim.pt")
            if os.path.exists(optim_path):
                print "Loading Optim Settings"
                optimizer.load_state_dict(safe_load.torch_state(optim_path))
            else:
                print "Failed to load Optim Settings"

        if lowest_loss > sum_loss / steps:
            lowest_loss = sum_loss / steps
            print "Saving Best"

            dirname = train_config['snapshot']['best_validation']
            if len(dirname) != 0 and not os.path.exists(dirname):
                os.makedirs(dirname)

            save_path = os.path.join(dirname, "lf.pt")

            torch.save(lf.state_dict(), save_path)
            lowest_loss_i = epoch

        test_loss = sum_loss / steps

        print "Test Loss", sum_loss / steps, lowest_loss
        print "Time:", time.time() - start_time
        print ""

        if allowed_training_time < (time.time() - init_training_time):
            print "Out of time: Exiting..."
            break

        print "Epoch", epoch
        sum_loss = 0.0
        steps = 0.0
        lf.train()
        start_time = time.time()
        for x in train_dataloader:
            if x is None:
                continue
            #Only single batch for now
            x = x[0]
            if x is None:
                continue

            positions = [
                Variable(x_i.type(dtype), requires_grad=False)[None, ...]
                for x_i in x['lf_xyrs']
            ]
            xy_positions = [
                Variable(x_i.type(dtype), requires_grad=False)[None, ...]
                for x_i in x['lf_xyxy']
            ]
            img = Variable(x['img'].type(dtype), requires_grad=False)[None,
                                                                      ...]

            #There might be a way to handle this case later,
            #but for now we will skip it
            if len(xy_positions) <= 1:
                continue

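            # During training the follower is periodically reset to the
            # ground-truth positions (a form of scheduled teacher forcing)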
            reset_interval = 4
            grid_line, _, _, xy_output = lf(img,
                                            positions[:1],
                                            steps=len(positions),
                                            all_positions=positions,
                                            reset_interval=reset_interval,
                                            randomize=True,
                                            skip_grid=True)

            loss = lf_loss.point_loss(xy_output, xy_positions)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sum_loss += loss.data.item()
            steps += 1

        print "Train Loss", sum_loss / steps
        print "Real Epoch", train_dataloader.epoch
        print "Time:", time.time() - start_time

    ## Save current snapshots for next iteration
    print "Saving Current"
    dirname = train_config['snapshot']['current']
    if len(dirname) != 0 and not os.path.exists(dirname):
        os.makedirs(dirname)

    save_path = os.path.join(dirname, "lf.pt")
    torch.save(lf.state_dict(), save_path)

    optim_path = os.path.join(dirname, "lf_optim.pt")
    torch.save(optimizer.state_dict(), optim_path)
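
For reference, here is a minimal sketch of the config layout these examples assume. Only the keys actually accessed above are shown; the values are illustrative placeholders, not the project's real defaults.

# Hypothetical config structure (keys inferred from the lookups above;
# all values are placeholder assumptions).
config = {
    'network': {
        'sol': {},                                   # SOL architecture params
        'lm': {},                                    # language model params
        'hw': {'char_set_path': 'char_set.json'},
    },
    'pretraining': {
        'sol': {'validation_rescale_range': [512, 512]},
        'validation_set': {'img_folder': 'data/val_images'},
    },
    'training': {
        'training_set': {'json_folder': 'data/train_json'},
        'validation_set': {'json_folder': 'data/val_json'},
        'sol': {
            'reset_interval': 3600,                  # seconds per training round
            'batch_size': 1,
            'images_per_epoch': 1000,
            'training_rescale_range': [384, 640],
            'validation_rescale_range': [512, 512],
            'validation_subset_size': 100,
            'crop_params': {},
            'alpha_alignment': 0.1,
            'alpha_backprop': 0.1,
            'learning_rate': 0.0001,
        },
        'lf': {
            'reset_interval': 3600,
            'batch_size': 1,
            'images_per_epoch': 1000,
            'validation_subset_size': 100,
            'learning_rate': 0.0001,
        },
        'alignment': {
            'validation_post_processing': {
                'sol_thresholds': [0.1],
                'lf_nms_ranges': [[0, 6]],
                'lf_nms_thresholds': [0.5],
            },
        },
        'snapshot': {
            'current': 'snapshots/current',
            'best_validation': 'snapshots/best_validation',
        },
    },
    'post_processing': {
        'sol_threshold': 0.1,
        'lf_nms_range': [0, 6],
        'lf_nms_threshold': 0.5,
    },
    'evaluation': {
        'output_path': 'sol_eval',
        'rows': 4,
        'columns': 4,
    },
}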