Example #1
def decode_handwriting(out, idx_to_char):
    hw_out = out['hw']
    list_of_pred = []
    list_of_raw_pred = []
    for i in range(hw_out.shape[0]):
        logits = hw_out[i, ...]
        pred, raw_pred = string_utils.naive_decode(logits)
        pred_str = string_utils.label2str_single(pred, idx_to_char, False)
        raw_pred_str = string_utils.label2str_single(raw_pred, idx_to_char,
                                                     True)
        list_of_pred.append(pred_str)
        list_of_raw_pred.append(raw_pred_str)

    return list_of_pred, list_of_raw_pred
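
All of the examples on this page lean on the same decoding contract: string_utils.naive_decode appears to perform greedy (best-path) CTC decoding over a (T, num_classes) logit matrix, returning the collapsed label sequence plus the raw per-timestep argmax labels. A minimal sketch of that behavior, assuming blank index 0 (the function name below is illustrative, not the library's):

import numpy as np

def greedy_ctc_decode(logits, blank=0):
    """Greedy (best-path) CTC decode of a (T, num_classes) logit matrix.

    Returns (pred, raw_pred): the label sequence with repeats collapsed
    and blanks removed, plus the raw per-timestep argmax sequence.
    """
    raw_pred = np.argmax(logits, axis=1)   # best class at each time step
    pred = []
    prev = blank
    for label in raw_pred:
        if label != blank and label != prev:  # drop blanks and repeats
            pred.append(int(label))
        prev = label
    return pred, raw_pred.tolist()
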
Example #2
def update_alignment(out, gt_lines, alignments, idx_to_char, idx_mapping,
                     sol_positions):

    preds = out.cpu()
    batch_size = preds.size(1)
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))

    for i, logits in enumerate(out.data.cpu().numpy()):
        raw_decode, raw_decode_full = string_utils.naive_decode(logits)
        pred_str = string_utils.label2str_single(raw_decode, idx_to_char,
                                                 False)

        global_i = idx_mapping[i]
        c = sol_positions[i, 0, -1].item()  # .item() replaces deprecated .data[0]

        for j, gt in enumerate(gt_lines):
            cer = error_rates.cer(gt, pred_str)

            # alignment_error = cer
            alignment_error = cer + 0.1 * (1.0 - c)

            if alignment_error < alignments[j][0]:
                alignments[j][0] = alignment_error
                alignments[j][1] = global_i
                # alignments[j][2] = out[i][:,None,:]
                alignments[j][2] = None
                alignments[j][3] = pred_str
Example #3
def getCER(gt, pred, idx_to_char):
    cer = []
    pred_strs = []
    for i, gt_line in enumerate(gt):
        logits = pred[:, i]
        pred_str, raw_pred = string_utils.naive_decode(logits)
        pred_str = string_utils.label2str_single(pred_str, idx_to_char, False)
        cer.append(error_rates.cer(gt_line, pred_str))
        pred_strs.append(pred_str)
    return cer, pred_strs
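
error_rates.cer, called in every example here, is the character error rate: the Levenshtein (edit) distance between prediction and ground truth, normalized by the ground-truth length. A self-contained sketch under that assumption:

def cer(gt, pred):
    """Character error rate: Levenshtein distance, normalized by len(gt)."""
    # One row at a time of the standard dynamic-programming edit-distance table.
    prev = list(range(len(pred) + 1))
    for i, g in enumerate(gt, 1):
        curr = [i]
        for j, p in enumerate(pred, 1):
            cost = 0 if g == p else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution
        prev = curr
    return prev[-1] / max(len(gt), 1)
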
Example #4
def accumulate_scores(out, out_positions, xy_positions, gt_state, idx_to_char):

    preds = out.transpose(0, 1).cpu()
    batch_size = preds.size(1)
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))

    for i, logits in enumerate(out.data.cpu().numpy()):
        raw_decode, raw_decode_full = string_utils.naive_decode(logits)
        pred_str = string_utils.label2str_single(raw_decode, idx_to_char,
                                                 False)
        pred_str_full = string_utils.label2str_single(raw_decode_full,
                                                      idx_to_char, True)

        sub_out_positions = [
            o[i].data.cpu().numpy().tolist() for o in out_positions
        ]
        sub_xy_positions = [
            o[i].data.cpu().numpy().tolist() for o in xy_positions
        ]

        for gt_obj in gt_state:
            gt_text = gt_obj['gt']
            cer = error_rates.cer(gt_text, pred_str)

            #This is a terrible way to do this...
            gt_obj['errors'] = gt_obj.get('errors', [])
            gt_obj['pred'] = gt_obj.get('pred', [])
            gt_obj['pred_full'] = gt_obj.get('pred_full', [])
            gt_obj['path'] = gt_obj.get('path', [])
            gt_obj['path_xy'] = gt_obj.get('path_xy', [])

            gt_obj['errors'].append(cer)
            gt_obj['pred'].append(pred_str)
            gt_obj['pred_full'].append(pred_str_full)
            gt_obj['path'].append(sub_out_positions)
            gt_obj['path_xy'].append(sub_xy_positions)
Example #5
def getCER(self, gt, pred, individual=False):
    cer = 0
    if individual:
        all_cer = []
    pred_strs = []
    for i, gt_line in enumerate(gt):
        logits = pred[:, i]
        pred_str, raw_pred = string_utils.naive_decode(logits)
        pred_str = string_utils.label2str_single(pred_str,
                                                 self.idx_to_char, False)
        this_cer = error_rates.cer(gt_line, pred_str)
        cer += this_cer
        if individual:
            all_cer.append(this_cer)
        pred_strs.append(pred_str)
    cer /= len(gt)
    if individual:
        return cer, pred_strs, all_cer
    return cer, pred_strs
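
Examples #3 and #5 both finish by calling string_utils.label2str_single, which maps label indices back to text through idx_to_char; the boolean flag appears to select a raw mode that also renders the CTC blank. A sketch under those assumptions (the blank placeholder character is hypothetical):

def label2str_single(labels, idx_to_char, as_raw, blank_char='~'):
    """Convert label indices to a string; in raw mode, render blanks too."""
    chars = []
    for label in labels:
        if label == 0:           # CTC blank
            if as_raw:
                chars.append(blank_char)
        else:
            chars.append(idx_to_char[label])
    return ''.join(chars)

# Hypothetical usage with a toy character map:
# idx_to_char = {1: 'h', 2: 'i'}
# label2str_single([1, 0, 2], idx_to_char, False)  -> 'hi'
# label2str_single([1, 0, 2], idx_to_char, True)   -> 'h~i'
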
Example #6
        line_imgs = generator(line_imgs)
        #for b in range(line_imgs.size(0)):
        #    draw = ((line_imgs[b,0]+1)*128).cpu().numpy().astype(np.uint8)
        #    cv2.imwrite('test/line{}.png'.format(b),draw)
        #    print('gt[{}]: {}'.format(b,x['gt'][b]))
        #cv2.waitKey()
        preds = hw(line_imgs).cpu()

        output_batch = preds.permute(1, 0, 2)
        out = output_batch.data.cpu().numpy()
        toprint = []
        for b, gt_line in enumerate(x['gt']):
            logits = out[b, ...]
            pred, raw_pred = string_utils.naive_decode(logits)
            pred_str = string_utils.label2str_single(pred, idx_to_char, False)
            cer = error_rates.cer(gt_line, pred_str)
            sum_cer += cer
            steps += 1

            if i % print_freq == 0:
                toprint.append('[cer]:{:.2f} [gt]: {} [pred]: {}'.format(
                    cer, gt_line, pred_str))

        batch_size = preds.size(1)
        preds_size = torch.IntTensor([preds.size(0)] * batch_size)

        # print "before"
        loss = criterion(preds, labels, preds_size, label_lengths)
        # print "after"
Example #7
def training_step(config):

    hw_network_config = config['network']['hw']
    train_config = config['training']

    allowed_training_time = train_config['hw']['reset_interval']
    init_training_time = time.time()

    char_set_path = hw_network_config['char_set_path']

    with open(char_set_path) as f:
        char_set = json.load(f)

    idx_to_char = {}
    for k, v in char_set['idx_to_char'].items():
        idx_to_char[int(k)] = v

    training_set_list = load_file_list(train_config['training_set'])
    train_dataset = HwDataset(training_set_list,
                              char_set['char_to_idx'],
                              augmentation=True,
                              img_height=hw_network_config['input_height'])

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=train_config['hw']['batch_size'],
                                  shuffle=False,
                                  num_workers=0,
                                  collate_fn=hw_dataset.collate)

    batches_per_epoch = int(train_config['hw']['images_per_epoch'] /
                            train_config['hw']['batch_size'])
    train_dataloader = DatasetWrapper(train_dataloader, batches_per_epoch)

    test_set_list = load_file_list(train_config['validation_set'])
    test_dataset = HwDataset(
        test_set_list,
        char_set['char_to_idx'],
        img_height=hw_network_config['input_height'],
        random_subset_size=train_config['hw']['validation_subset_size'])

    test_dataloader = DataLoader(test_dataset,
                                 batch_size=train_config['hw']['batch_size'],
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=hw_dataset.collate)

    hw = cnn_lstm.create_model(hw_network_config)
    hw_path = os.path.join(train_config['snapshot']['best_validation'],
                           "hw.pt")
    hw_state = safe_load.torch_state(hw_path)
    hw.load_state_dict(hw_state)
    hw.cuda()
    criterion = CTCLoss()
    dtype = torch.cuda.FloatTensor

    lowest_loss = np.inf
    lowest_loss_i = 0
    for epoch in range(10000000000):
        sum_loss = 0.0
        steps = 0.0
        hw.eval()
        for x in test_dataloader:
            sys.stdout.flush()
            line_imgs = Variable(x['line_imgs'].type(dtype),
                                 requires_grad=False,
                                 volatile=True)
            labels = Variable(x['labels'], requires_grad=False, volatile=True)
            label_lengths = Variable(x['label_lengths'],
                                     requires_grad=False,
                                     volatile=True)

            preds = hw(line_imgs).cpu()

            output_batch = preds.permute(1, 0, 2)
            out = output_batch.data.cpu().numpy()

            for i, gt_line in enumerate(x['gt']):
                logits = out[i, ...]
                pred, raw_pred = string_utils.naive_decode(logits)
                pred_str = string_utils.label2str_single(
                    pred, idx_to_char, False)
                cer = error_rates.cer(gt_line, pred_str)
                sum_loss += cer
                steps += 1

        if epoch == 0:
            print "First Validation Step Complete"
            print "Benchmark Validation CER:", sum_loss / steps
            lowest_loss = sum_loss / steps

            hw = cnn_lstm.create_model(hw_network_config)
            hw_path = os.path.join(train_config['snapshot']['current'],
                                   "hw.pt")
            hw_state = safe_load.torch_state(hw_path)
            hw.load_state_dict(hw_state)
            hw.cuda()

            optimizer = torch.optim.Adam(
                hw.parameters(), lr=train_config['hw']['learning_rate'])
            optim_path = os.path.join(train_config['snapshot']['current'],
                                      "hw_optim.pt")
            if os.path.exists(optim_path):
                print "Loading Optim Settings"
                optimizer.load_state_dict(safe_load.torch_state(optim_path))
            else:
                print "Failed to load Optim Settings"

        if lowest_loss > sum_loss / steps:
            lowest_loss = sum_loss / steps
            print "Saving Best"

            dirname = train_config['snapshot']['best_validation']
            if not len(dirname) != 0 and os.path.exists(dirname):
                os.makedirs(dirname)

            save_path = os.path.join(dirname, "hw.pt")

            torch.save(hw.state_dict(), save_path)
            lowest_loss_i = epoch

        print "Test Loss", sum_loss / steps, lowest_loss
        print ""

        if allowed_training_time < (time.time() - init_training_time):
            print "Out of time: Exiting..."
            break

        print "Epoch", epoch
        sum_loss = 0.0
        steps = 0.0
        hw.train()
        for i, x in enumerate(train_dataloader):

            line_imgs = Variable(x['line_imgs'].type(dtype),
                                 requires_grad=False)
            labels = Variable(x['labels'], requires_grad=False)
            label_lengths = Variable(x['label_lengths'], requires_grad=False)

            preds = hw(line_imgs).cpu()

            output_batch = preds.permute(1, 0, 2)
            out = output_batch.data.cpu().numpy()

            # if i == 0:
            #     for i in xrange(out.shape[0]):
            #         pred, pred_raw = string_utils.naive_decode(out[i,...])
            #         pred_str = string_utils.label2str_single(pred_raw, idx_to_char, True)
            #         print pred_str

            for i, gt_line in enumerate(x['gt']):
                logits = out[i, ...]
                pred, raw_pred = string_utils.naive_decode(logits)
                pred_str = string_utils.label2str_single(
                    pred, idx_to_char, False)
                cer = error_rates.cer(gt_line, pred_str)
                sum_loss += cer
                steps += 1

            batch_size = preds.size(1)
            preds_size = Variable(torch.IntTensor([preds.size(0)] *
                                                  batch_size))

            loss = criterion(preds, labels, preds_size, label_lengths)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print "Train Loss", sum_loss / steps
        print "Real Epoch", train_dataloader.epoch

    ## Save current snapshots for next iteration
    print "Saving Current"
    dirname = train_config['snapshot']['current']
    if not len(dirname) != 0 and os.path.exists(dirname):
        os.makedirs(dirname)

    save_path = os.path.join(dirname, "hw.pt")
    torch.save(hw.state_dict(), save_path)

    optim_path = os.path.join(dirname, "hw_optim.pt")
    torch.save(optimizer.state_dict(), optim_path)
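
DatasetWrapper is not shown on this page; judging from its use (a fixed number of batches per "epoch" and a .epoch attribute counting real passes over the data), it can be approximated as below. The implementation is inferred, not the library's:

class DatasetWrapper(object):
    """Yield a fixed number of batches per 'epoch', cycling the wrapped
    dataloader and counting real passes over the data in self.epoch."""

    def __init__(self, dataloader, batches_per_epoch):
        self.dataloader = dataloader
        self.batches_per_epoch = batches_per_epoch
        self.epoch = 0
        self._iter = iter(self.dataloader)

    def __iter__(self):
        for _ in range(self.batches_per_epoch):
            try:
                yield next(self._iter)
            except StopIteration:
                self.epoch += 1          # completed one real pass
                self._iter = iter(self.dataloader)
                yield next(self._iter)
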
Example #8
def HWDataset_eval(config,
                   instance,
                   trainer,
                   metrics,
                   outDir=None,
                   startIndex=None,
                   lossFunc=None,
                   toEval=None):
    def __eval_metrics(output, target):
        acc_metrics = np.zeros((output.shape[0], len(metrics)))
        for ind in range(output.shape[0]):
            for i, metric in enumerate(metrics):
                acc_metrics[ind, i] += metric(output[ind:ind + 1],
                                              target[ind:ind + 1])
        return acc_metrics

    if type(trainer) is HWRWithSynthTrainer:
        pred, losses = trainer.run(instance)
        out = {'pred': pred}
    elif type(trainer) is AutoTrainer:
        losses, out = trainer.run_gen(instance, toEval)
    else:  #if type(trainer) is HWWithStyleTrainer:
        if toEval is None:

            pred, recon, losses, style, spaced = trainer.run(instance,
                                                             get_style=True)
            toEval = ['pred', 'recon', 'style', 'spaced']
            out = {}
            if pred is not None:
                out['pred'] = pred
            if recon is not None:
                out['recon'] = recon
            if style is not None:
                out['style'] = style
            if spaced is not None:
                out['spaced'] = spaced
        elif type(toEval) is list:
            losses, out = trainer.run_gen(instance,
                                          trainer.curriculum.getEval(), toEval)
        else:
            if toEval == 'spaced':
                justSpaced(trainer.model, instance,
                           trainer.gpu if trainer.with_cuda else None)
            elif toEval == 'spacing':
                justSpacing(trainer.model, instance,
                            trainer.gpu if trainer.with_cuda else None)
            elif toEval == 'mask':
                justMask(trainer.model, instance,
                         trainer.gpu if trainer.with_cuda else None)
            else:
                raise ValueError('unknown just: {}'.format(toEval))
            return {}, (None, )

    images = instance['image'].numpy()
    gt = instance['gt']
    name = instance['name']
    batchSize = len(gt)
    #style = style.cpu().detach().numpy()
    if 'pred' in out:
        pred = out['pred'].cpu().detach()
        aligned1 = correct_pred(pred, instance['label'])
        aligned = []
        for b in range(batchSize):
            #a, raw_aligned = string_utils.naive_decode(aligned1[:,b])
            a = []
            for i in range(len(aligned1[:, b])):
                if aligned1[i, b] != 0 and not (i > 0 and aligned1[i, b]
                                                == aligned1[i - 1, b]):
                    a.append(aligned1[i, b].item())
            aligned.append(
                string_utils.label2str_single(a, trainer.idx_to_char, False))

        pred = pred.numpy()
        sum_cer, sum_wer, pred_str, cer = trainer.getCER(gt,
                                                         pred,
                                                         individual=True)
        out['cer'] = cer
        out['pred_str'] = pred_str

        for b in range(batchSize):
            print('{} GT:      {}'.format(instance['name'][b], gt[b]))
            print('{} aligned: {}'.format(instance['name'][b], aligned[b]))
            print('{} pred:    {}'.format(instance['name'][b], pred_str[b]))
            print(pred[:, b])
        if 'by_author' in config:
            by_author = defaultdict(list)
            for b, author in enumerate(instance['author']):
                by_author['cer_' + author].append(cer[b])

    if outDir is not None:
        for key_name in ['recon', 'recon_gt_mask', 'recon_pred_space']:
            if key_name in out:  # and 'pred' in out:
                recon = out[key_name].cpu().detach().numpy()
                if 'show_attention' in config:
                    rs = np.random.RandomState(0)
                    colors = (rs.rand(
                        trainer.model.style_extractor.mhAtt1.h *
                        trainer.model.style_extractor.keys1.size(1), 3) *
                              255).astype(np.uint8)
                    attn = trainer.model.style_extractor.mhAtt1.attn
                    assert (attn.size(0) == 1)
                    #OR
                    #attn = attn.view(batchSize*a_batch_size
                    scale = images.shape[3] * images.shape[0] / attn.size(3)
                    batch_len = attn.size(3) // images.shape[0]  # integer, so b below is an int
                    c_index = 0
                    attn_for = defaultdict(list)
                    for head in range(attn.size(1)):
                        for query in range(attn.size(2)):
                            loc = attn[0, head, query].argmax().item()
                            b = loc // batch_len
                            x_pixel_loc = int((loc % batch_len) * scale)
                            y_pixel_loc = query * images.shape[2] // attn.size(
                                2)  #+ head
                            attn_for[b].append(
                                (y_pixel_loc, x_pixel_loc, colors[c_index]))
                            #print('h:{}, q:{}, b:{}, ({},{})'.format(head,query,b,x_pixel_loc,y_pixel_loc))
                            c_index += 1
                    maxA = attn.max()
                    minA = attn.min()
                    streched_attn = F.interpolate(
                        (attn[0] - minA) / (maxA - minA),
                        size=int(images.shape[3] * batchSize)).cpu()
                for b in range(batchSize):
                    if 'cer_thresh' in config and cer[b] < config['cer_thresh']:
                        continue
                    toColor = False
                    image = (1 - ((1 + np.transpose(images[b][:, :, :],
                                                    (1, 2, 0))) / 2.0)).copy()
                    if recon is not None:
                        reconstructed = (
                            1 - ((1 + np.transpose(recon[b][:, :, :],
                                                   (1, 2, 0))) / 2.0)).copy()
                        #border = np.zeros((image.shape[0],5,image.shape[2]))

                        #bigPic = np.concatenate((image,border,reconstructed),axis=1)

                        padded = None
                        if reconstructed.shape[1] > image.shape[1]:
                            #reconstructed=reconstructed[:,:image.shape[1]]
                            dif = -(image.shape[1] - reconstructed.shape[1])
                            image = np.pad(image,
                                           ((0, 0),
                                            (dif // 2, dif // 2 + dif % 2),
                                            (0, 0)),
                                           mode='constant')
                            padded = 'real'
                        elif image.shape[1] > reconstructed.shape[1]:  #pad
                            dif = image.shape[1] - reconstructed.shape[1]
                            reconstructed = np.pad(
                                reconstructed,
                                ((0, 0), (dif // 2, dif // 2 + dif % 2),
                                 (0, 0)),
                                mode='constant')
                            padded = 'gen'
                        border = np.zeros((2, image.shape[1], image.shape[2]))
                        bigPic = np.concatenate((image, border, reconstructed),
                                                axis=0)

                        #add color border for visibility in paper figures
                        if bigPic.shape[2] == 1:
                            bigPic *= 255
                            bigPic = bigPic.astype(np.uint8)
                            bigPic = cv2.cvtColor(bigPic[:, :, 0],
                                                  cv2.COLOR_GRAY2RGB)
                            toColor = True
                            padReal = dif // 2 if padded == 'real' else 0
                            padGen = dif // 2 if padded == 'gen' else 0
                            #print('COLOR!')
                            if padReal != 0:
                                bigPic[0, padReal:-padReal, 1] = 255
                                bigPic[image.shape[0], padReal:-padReal,
                                       1] = 255
                                bigPic[image.shape[0], padReal:-padReal, 0] = 0
                                bigPic[0, padReal:-padReal, 0] = 0
                                bigPic[0, padReal:-padReal, 2] = 0
                                bigPic[image.shape[0], padReal:-padReal, 2] = 0
                            else:
                                bigPic[0, :, 1] = 255
                                bigPic[image.shape[0], :, 1] = 255
                                bigPic[image.shape[0], :, 0] = 0
                                bigPic[0, :, 0] = 0
                                bigPic[0, :, 2] = 0
                                bigPic[image.shape[0], :, 2] = 0
                            bigPic[:image.shape[0], padReal, 1] = 255
                            bigPic[:image.shape[0], -1 - padReal, 1] = 255
                            bigPic[:image.shape[0], padReal, 0] = 0
                            bigPic[:image.shape[0], -1 - padReal, 0] = 0
                            bigPic[:image.shape[0], padReal, 2] = 0
                            bigPic[:image.shape[0], -1 - padReal, 2] = 0

                            if padGen != 0:
                                bigPic[image.shape[0] + 1, padGen:-padGen,
                                       0] = 255
                                bigPic[-1, padGen:-padGen, 0] = 255
                                bigPic[image.shape[0] + 1, padGen:-padGen,
                                       2] = 0
                                bigPic[-1, padGen:-padGen, 2] = 0
                                bigPic[image.shape[0] + 1, padGen:-padGen,
                                       1] = 0
                                bigPic[-1, padGen:-padGen, 1] = 0
                            else:
                                bigPic[image.shape[0] + 1, :, 0] = 255
                                bigPic[-1, :, 0] = 255
                                bigPic[image.shape[0] + 1, :, 2] = 0
                                bigPic[-1, :, 2] = 0
                                bigPic[image.shape[0] + 1, :, 1] = 0
                                bigPic[-1, :, 1] = 0

                            bigPic[image.shape[0] + 2:, padGen, 0] = 255
                            bigPic[image.shape[0] + 2:, -1 - padGen, 0] = 255
                            bigPic[image.shape[0] + 2:, padGen, 2] = 0
                            bigPic[image.shape[0] + 2:, -1 - padGen, 2] = 0
                            bigPic[image.shape[0] + 2:, padGen, 1] = 0
                            bigPic[image.shape[0] + 2:, -1 - padGen, 1] = 0
                    else:
                        bigPic = image

                    #if image.shape[2]==1:
                    #    image = cv2.cvtColor(image,cv2.COLOR_GRAY2RGB)
                    if 'pred' in out:
                        border = np.zeros(
                            (50, bigPic.shape[1], bigPic.shape[2]))
                        bigPic = np.concatenate((bigPic, border), axis=0)
                        cv2.putText(
                            bigPic,
                            'CER: {:.3f}, T: {}'.format(cer[b], pred_str[b]),
                            (0, image.shape[0] + 25), cv2.FONT_HERSHEY_SIMPLEX,
                            0.5, (0.9, 0.3, 0), 2, cv2.LINE_AA)
                    if not toColor:
                        bigPic *= 255
                        bigPic = bigPic.astype(np.uint8)
                    if 'show_attention' in config:
                        if bigPic.shape[2] == 1:
                            bigPic = cv2.cvtColor(bigPic, cv2.COLOR_GRAY2RGB)
                        #if 'head' in config['show_attention']:
                        if 'full' in config['show_attention']:
                            attnImage = np.zeros((attn.size(1) * attn.size(2),
                                                  bigPic.shape[1], 3))
                            for head in range(attn.size(1)):
                                for query in range(attn.size(2)):
                                    y_pixel_loc = head + attn.size(
                                        1
                                    ) * query  #query*images.shape[2]//attn.size(2) #+ head
                                    x_start = int(b * image.shape[1])
                                    x_end = int((b + 1) * image.shape[1])
                                    if head < 3:
                                        attnImage[y_pixel_loc,
                                                  0:image.shape[1],
                                                  head] = streched_attn[
                                                      head, query,
                                                      x_start:x_end].numpy()
                                    else:
                                        attnImage[y_pixel_loc,
                                                  0:image.shape[1],
                                                  head % 3] = streched_attn[
                                                      head, query,
                                                      x_start:x_end].numpy()
                                        attnImage[y_pixel_loc,
                                                  0:image.shape[1],
                                                  (head + 1) %
                                                  3] = streched_attn[
                                                      head, query,
                                                      x_start:x_end].numpy()

                            attnImage *= 255
                            attnImage = attnImage.astype(np.uint8)
                            bigPic = np.concatenate((attnImage, bigPic),
                                                    axis=0)

                        else:
                            for y, x, c in attn_for[b]:
                                bigPic[y:y + 2, x:x + 2] = c
                                #print('{}, {}  ({},{})'.format(x,y,image.shape[1],image.shape[0]))

                    saveName = '{}_{}.png'.format(name[b], key_name)
                    if 'cer_thresh' in config:
                        saveName = '{:.3f}_'.format(cer[b]) + saveName
                    cv2.imwrite(os.path.join(outDir, saveName), bigPic)
                    #io.imsave(os.path.join(outDir,saveName),bigPic)
                    print('saved: ' + os.path.join(outDir, saveName))
                    #import pdb;pdb.set_trace()

        if 'gen' in out or 'gen_img' in out:
            if 'gen' in out and out['gen'] is not None:
                gen = out['gen'].cpu().detach().numpy()
            elif 'gen_img' in out and out['gen_img'] is not None:
                gen = out['gen_img'].cpu().detach().numpy()
            else:
                #not sure why this happens
                print('ERROR, None for generated images, {}'.format(name))
                gen = np.ones((batchSize, 1, 5, 5))
            for b in range(batchSize):
                generated = (1 - (
                    (1 + np.transpose(gen[b][:, :, :],
                                      (1, 2, 0))) / 2.0)).copy() * 255
                saveName = 'gen_{}.png'.format(name[b])
                cv2.imwrite(os.path.join(outDir, saveName), generated)
        if 'mask' in out:
            mask = ((1 + out['mask']) * 127.5).cpu().detach().permute(
                0, 2, 3, 1).numpy().astype(np.uint8)
            for b in range(batchSize):
                saveName = '{}_mask.png'.format(name[b])
                cv2.imwrite(os.path.join(outDir, saveName), mask[b])
        if 'gen_mask' in out:
            gen_mask = ((1 + out['gen_mask']) * 127.5).cpu().detach().permute(
                0, 2, 3, 1).numpy().astype(np.uint8)
            for b in range(batchSize):
                saveName = 'gen_{}_mask.png'.format(name[b])
                cv2.imwrite(os.path.join(outDir, saveName), gen_mask[b])
    #return metricsOut
    for loss_name in losses:
        losses[loss_name] = losses[loss_name].item()
    toRet = {
        **losses,
    }
    if 'pred' in out:
        toRet['cer'] = sum_cer
        toRet['wer'] = sum_wer
    if 'by_author' in config and 'pred' in out:
        toRet.update(by_author)

    #decode spaced
    #d_spaced = []
    #for b in range(spaced.shape[1]):
    #    string=''
    #    for i in range(spaced.shape[0]):#instance['label_lengths'][b]):
    #        index=spaced[i,b].argmax().item()
    #        if index>0:
    #            string+=trainer.idx_to_char[index]
    #        else:
    #            string+='\0'
    #    d_spaced.append(string)
    return (toRet, out)
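
The bigPic construction in Example #8 pads the narrower of the real and reconstructed images to a common width and stacks them with a separator row. The same idea as a small standalone helper (a sketch, not the trainer's actual code):

import numpy as np

def stack_with_border(top, bottom, border_rows=2):
    """Pad the narrower (H, W, C) image symmetrically in width, then
    stack top / border / bottom vertically."""
    dif = top.shape[1] - bottom.shape[1]
    pad = ((0, 0), (abs(dif) // 2, abs(dif) // 2 + abs(dif) % 2), (0, 0))
    if dif > 0:
        bottom = np.pad(bottom, pad, mode='constant')
    elif dif < 0:
        top = np.pad(top, pad, mode='constant')
    border = np.zeros((border_rows, top.shape[1], top.shape[2]))
    return np.concatenate((top, border, bottom), axis=0)
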
Example #9
def training_step(config):

    char_set_path = config['network']['hw']['char_set_path']

    with open(char_set_path) as f:
        char_set = json.load(f)

    idx_to_char = {}
    for k, v in char_set['idx_to_char'].items():
        idx_to_char[int(k)] = v

    train_config = config['training']

    allowed_training_time = train_config['lf']['reset_interval']
    init_training_time = time.time()

    training_set_list = load_file_list(train_config['training_set'])
    train_dataset = LfDataset(training_set_list, augmentation=True)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=1,
                                  shuffle=True,
                                  num_workers=0,
                                  collate_fn=lf_dataset.collate)
    batches_per_epoch = int(train_config['lf']['images_per_epoch'] /
                            train_config['lf']['batch_size'])
    train_dataloader = DatasetWrapper(train_dataloader, batches_per_epoch)

    test_set_list = load_file_list(train_config['validation_set'])
    test_dataset = LfDataset(
        test_set_list,
        random_subset_size=train_config['lf']['validation_subset_size'])
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=lf_dataset.collate)

    _, lf, hw = init_model(config, only_load=['lf', 'hw'])
    hw.eval()

    dtype = torch.cuda.FloatTensor

    lowest_loss = np.inf
    lowest_loss_i = 0
    for epoch in range(10000000):
        lf.eval()
        sum_loss = 0.0
        steps = 0.0
        start_time = time.time()
        for step_i, x in enumerate(test_dataloader):
            if x is None:
                continue
            #Only single batch for now
            x = x[0]
            if x is None:
                continue

            positions = [
                Variable(x_i.type(dtype), requires_grad=False)[None, ...]
                for x_i in x['lf_xyrs']
            ]
            xy_positions = [
                Variable(x_i.type(dtype), requires_grad=False)[None, ...]
                for x_i in x['lf_xyxy']
            ]
            img = Variable(x['img'].type(dtype), requires_grad=False)[None,
                                                                      ...]

            #There might be a way to handle this case later,
            #but for now we will skip it
            if len(xy_positions) <= 1:
                print "Skipping"
                continue

            grid_line, _, _, xy_output = lf(img,
                                            positions[:1],
                                            steps=len(positions),
                                            skip_grid=False)

            line = torch.nn.functional.grid_sample(img.transpose(2, 3),
                                                   grid_line)
            line = line.transpose(2, 3)
            predictions = hw(line)

            out = predictions.permute(1, 0, 2).data.cpu().numpy()
            gt_line = x['gt']
            pred, raw_pred = string_utils.naive_decode(out[0])
            pred_str = string_utils.label2str_single(pred, idx_to_char, False)
            cer = error_rates.cer(gt_line, pred_str)
            sum_loss += cer
            steps += 1

            # l = line[0].transpose(0,1).transpose(1,2)
            # l = (l + 1)*128
            # l_np = l.data.cpu().numpy()
            #
            # cv2.imwrite("example_line_out.png", l_np)
            # print "Saved!"
            # raw_input()

            # loss = lf_loss.point_loss(xy_output, xy_positions)
            #
            # sum_loss += loss.data[0]
            # steps += 1

        if epoch == 0:
            print "First Validation Step Complete"
            print "Benchmark Validation Loss:", sum_loss / steps
            lowest_loss = sum_loss / steps

            _, lf, _ = init_model(config, lf_dir='current', only_load="lf")

            optimizer = torch.optim.Adam(
                lf.parameters(), lr=train_config['lf']['learning_rate'])
            optim_path = os.path.join(train_config['snapshot']['current'],
                                      "lf_optim.pt")
            if os.path.exists(optim_path):
                print "Loading Optim Settings"
                optimizer.load_state_dict(safe_load.torch_state(optim_path))
            else:
                print "Failed to load Optim Settings"

        if lowest_loss > sum_loss / steps:
            lowest_loss = sum_loss / steps
            print "Saving Best"

            dirname = train_config['snapshot']['best_validation']
            if not len(dirname) != 0 and os.path.exists(dirname):
                os.makedirs(dirname)

            save_path = os.path.join(dirname, "lf.pt")

            torch.save(lf.state_dict(), save_path)
            lowest_loss_i = epoch

        test_loss = sum_loss / steps

        print "Test Loss", sum_loss / steps, lowest_loss
        print "Time:", time.time() - start_time
        print ""

        if allowed_training_time < (time.time() - init_training_time):
            print "Out of time: Exiting..."
            break

        print "Epoch", epoch
        sum_loss = 0.0
        steps = 0.0
        lf.train()
        start_time = time.time()
        for x in train_dataloader:
            if x is None:
                continue
            #Only single batch for now
            x = x[0]
            if x is None:
                continue

            positions = [
                Variable(x_i.type(dtype), requires_grad=False)[None, ...]
                for x_i in x['lf_xyrs']
            ]
            xy_positions = [
                Variable(x_i.type(dtype), requires_grad=False)[None, ...]
                for x_i in x['lf_xyxy']
            ]
            img = Variable(x['img'].type(dtype), requires_grad=False)[None,
                                                                      ...]

            #There might be a way to handle this case later,
            #but for now we will skip it
            if len(xy_positions) <= 1:
                continue

            reset_interval = 4
            grid_line, _, _, xy_output = lf(img,
                                            positions[:1],
                                            steps=len(positions),
                                            all_positions=positions,
                                            reset_interval=reset_interval,
                                            randomize=True,
                                            skip_grid=True)

            loss = lf_loss.point_loss(xy_output, xy_positions)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sum_loss += loss.data.item()
            steps += 1

        print "Train Loss", sum_loss / steps
        print "Real Epoch", train_dataloader.epoch
        print "Time:", time.time() - start_time

    ## Save current snapshots for next iteration
    print "Saving Current"
    dirname = train_config['snapshot']['current']
    if not len(dirname) != 0 and os.path.exists(dirname):
        os.makedirs(dirname)

    save_path = os.path.join(dirname, "lf.pt")
    torch.save(lf.state_dict(), save_path)

    optim_path = os.path.join(dirname, "lf_optim.pt")
    torch.save(optimizer.state_dict(), optim_path)
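
The line extraction in Example #9 hinges on torch.nn.functional.grid_sample: the line follower emits a sampling grid in normalized [-1, 1] coordinates, and grid_sample warps the page image into a straightened strip for the handwriting network. A toy sketch of that call with an identity grid, just to show the shapes:

import torch
import torch.nn.functional as F

# (N, C, H, W) page image; random stand-in for x['img']
img = torch.randn(1, 1, 256, 512)

# grid_sample expects a (N, H_out, W_out, 2) grid of normalized (x, y)
# coordinates; here an identity grid for a 32x400 line strip (the line
# follower would emit a curved grid that follows the text line).
ys = torch.linspace(-1, 1, 32).view(32, 1).expand(32, 400)
xs = torch.linspace(-1, 1, 400).view(1, 400).expand(32, 400)
grid = torch.stack((xs, ys), dim=-1)[None]            # (1, 32, 400, 2)

line = F.grid_sample(img, grid, align_corners=True)   # (1, 1, 32, 400)
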