예제 #1
0
def test(model, dataloader, idx_to_char, dtype):
    """Return the mean character error rate (CER) of ``model`` on ``dataloader``.

    Args:
        model: Network invoked as ``model(line_imgs, online)``. It is put
            into eval mode before any batch is processed.
        dataloader: Iterable of batch dicts; the keys ``'line_imgs'``,
            ``'online'`` and ``'gt'`` are read from every batch.
        idx_to_char: Index-to-character mapping forwarded to
            ``string_utils.label2str``.
        dtype: Tensor type the image and ``online`` tensors are cast to.

    Returns:
        The sum of per-line CER values divided by the number of lines.

    Raises:
        ZeroDivisionError: If no ground-truth line was evaluated.
    """
    import torch  # local import: only needed for the no_grad() context

    sum_loss = 0.0
    steps = 0.0
    model.eval()
    # ``Variable(..., volatile=True)`` has had no effect since PyTorch 0.4,
    # so the old code silently tracked gradients during evaluation.
    # ``torch.no_grad()`` is the supported way to switch autograd off.
    with torch.no_grad():
        for x in dataloader:
            line_imgs = x['line_imgs'].type(dtype)
            online = x['online'].type(dtype).view(1, -1, 1)

            preds = model(line_imgs, online).cpu()

            # Swap the first two axes so each ground-truth line can be
            # indexed on axis 0 (assumes the model emits batch on axis 1
            # -- TODO confirm against the model definition).
            out = preds.permute(1, 0, 2).detach().numpy()

            for i, gt_line in enumerate(x['gt']):
                logits = out[i, ...]
                pred, _ = string_utils.naive_decode(logits)
                pred_str = string_utils.label2str(pred, idx_to_char, False)
                sum_loss += error_rates.cer(gt_line, pred_str)
                steps += 1
    return sum_loss / steps
예제 #2
0
def main():
    """Recognise the text of a single line image and print it.

    Command line: ``<config.json> <image_path>``.

    The checkpoint at ``config['model_save_path']`` is loaded into a CRNN,
    the image is rescaled to ``config['network']['input_height']`` if needed,
    and two strings are printed: the raw decoding, then the final decoding.
    """
    config_path = sys.argv[1]
    image_path = sys.argv[2]

    with open(config_path) as f:
        config = json.load(f)

    idx_to_char, char_to_idx = character_set.load_char_set(
        config['character_set_path'])

    hw = crnn.create_CRNN({
        'cnn_out_size': config['network']['cnn_out_size'],
        'num_of_channels': 3,
        'num_of_outputs': len(idx_to_char) + 1
    })

    # Map the checkpoint to the CPU first so that weights saved from a GPU
    # run can still be opened on a machine without CUDA; ``hw.cuda()`` below
    # moves them to the GPU when one is available.
    hw.load_state_dict(
        torch.load(config['model_save_path'], map_location='cpu'))
    if torch.cuda.is_available():
        hw.cuda()
        dtype = torch.cuda.FloatTensor
        print("Using GPU")
    else:
        dtype = torch.FloatTensor
        print("No GPU detected")

    hw.eval()

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals an unreadable/missing file by returning None.
        sys.exit("Could not read image: " + image_path)

    target_height = config['network']['input_height']
    if img.shape[0] != target_height:
        # Uniform scale so the aspect ratio is preserved.
        percent = float(target_height) / img.shape[0]
        img = cv2.resize(img, (0, 0),
                         fx=percent,
                         fy=percent,
                         interpolation=cv2.INTER_CUBIC)

    # Channels-first layout, pixel values mapped from [0, 255] to about [-1, 1).
    img = torch.from_numpy(img.transpose(2, 0, 1).astype(np.float32) / 128 - 1)
    img = img[None, ...].type(dtype)  # add the batch axis

    # ``volatile=True`` is ignored since PyTorch 0.4; no_grad() is what
    # actually disables gradient tracking for inference.
    with torch.no_grad():
        preds = hw(img)

    out = preds.permute(1, 0, 2).detach().cpu().numpy()

    pred, pred_raw = string_utils.naive_decode(out[0])
    pred_str = string_utils.label2str(pred, idx_to_char, False)
    pred_raw_str = string_utils.label2str(pred_raw, idx_to_char, True)
    print(pred_raw_str)
    print(pred_str)
예제 #3
0
 def decode(self, data, as_idx=False):
     """Decode ``data`` with ``self._decode``.

     Args:
         data: Input handed unchanged to ``self._decode``.
         as_idx: When true, the first element of the decoded result is
             converted with ``string_utils.str2label_single``.

     Returns:
         ``self._decode(data)`` as-is when ``as_idx`` is false; otherwise
         the tuple ``(str2label_single(res[0], self.char_to_idx), res[1])``.
     """
     # The previous version also ran string_utils.naive_decode(data) on every
     # call purely for (commented-out) debug printing; that dead work is gone.
     res = self._decode(data)
     if not as_idx:
         return res
     return string_utils.str2label_single(res[0], self.char_to_idx), res[1]
예제 #4
0
    def run(self, img):
        """Recognise the text in one line image.

        Args:
            img: Image array indexed ``[row, column, channel]`` (it is
                transposed with ``(2, 0, 1)`` before being fed to the
                network). It is rescaled when its height differs from
                ``self.config['network']['input_height']``.

        Returns:
            The decoded string, or ``"UNREADABLE"`` when the network raises.
        """
        target_height = self.config['network']['input_height']
        if img.shape[0] != target_height:
            # Uniform scale so the aspect ratio is preserved.
            percent = float(target_height) / img.shape[0]
            img = cv2.resize(img, (0, 0), fx=percent, fy=percent,
                             interpolation=cv2.INTER_CUBIC)

        # Channels-first layout, pixel values mapped to about [-1, 1).
        img = torch.from_numpy(img.transpose(2, 0, 1).astype(np.float32) / 128 - 1)
        img = img[None, ...].type(self.dtype)  # add the batch axis

        try:
            # ``volatile=True`` is ignored since PyTorch 0.4; no_grad() is
            # what actually keeps inference from building autograd graphs.
            with torch.no_grad():
                preds = self.network(img)
        except Exception as e:
            # Deliberate best effort: report the problem, return a placeholder.
            print(e)
            return "UNREADABLE"

        out = preds.permute(1, 0, 2).detach().cpu().numpy()

        pred, _ = string_utils.naive_decode(out[0])
        return string_utils.label2str(pred, self.idx_to_char, False)
예제 #5
0
    def batch_run(self, input_batch):
        """Recognise the text of every image in a batch.

        Args:
            input_batch: Array indexed ``[image, row, column, channel]``
                (it is transposed with ``[0, 3, 1, 2]`` before inference).
                It is passed through ``batch_resize`` when axis 1 differs
                from ``self.config['network']['input_height']``.

        Returns:
            One decoded string per network output row. When the network
            raises, a list holding ``'UNREADABLE'`` once per input image.
        """
        batch_size = input_batch.shape[0]
        target_height = self.config['network']['input_height']
        if input_batch.shape[1] != target_height:
            input_batch = batch_resize(input_batch, target_height)

        # Channels-first layout, pixel values mapped to about [-1, 1).
        input_batch = input_batch.transpose([0, 3, 1, 2])
        imgs = input_batch.astype(np.float32) / 128 - 1
        line_imgs = torch.from_numpy(imgs).type(self.dtype)

        try:
            # ``volatile=True`` is ignored since PyTorch 0.4; no_grad() is
            # what actually keeps inference from building autograd graphs.
            with torch.no_grad():
                preds = self.network(line_imgs)
        except Exception as e:
            print(e)
            # One placeholder per input image. The old code always returned
            # eight, which misaligned results for any other batch size.
            return ['UNREADABLE'] * batch_size

        out = preds.permute(1, 0, 2).detach().cpu().numpy()

        retval = []
        for logits in out:
            pred, _ = string_utils.naive_decode(logits)
            retval.append(
                string_utils.label2str(pred, self.idx_to_char, False))
        return retval
예제 #6
0
def run_epoch(model, dataloader, criterion, optimizer, idx_to_char, dtype):
    """Run one optimisation pass over ``dataloader`` and report the CER.

    For every batch the model is called as ``model(line_imgs, online)``,
    ``criterion(preds, labels, preds_size, label_lengths)`` is minimised with
    one ``optimizer`` step, and the pre-update predictions are decoded and
    scored against ``batch['gt']``.

    Args:
        model: Network to train; switched to train mode first.
        dataloader: Iterable of batch dicts with the keys ``'line_imgs'``,
            ``'labels'``, ``'label_lengths'``, ``'online'`` and ``'gt'``.
        criterion: Loss callable taking ``(preds, labels, preds_size,
            label_lengths)``.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        idx_to_char: Mapping forwarded to ``string_utils.label2str``.
        dtype: Tensor type the image and ``online`` tensors are cast to.

    Returns:
        Summed per-line CER divided by the number of lines scored.

    Raises:
        ZeroDivisionError: If no line was scored.
    """
    cer_total = 0.0
    line_count = 0.0
    model.train()
    for batch in dataloader:
        images = Variable(batch['line_imgs'].type(dtype), requires_grad=False)
        targets = Variable(batch['labels'], requires_grad=False)
        target_lengths = Variable(batch['label_lengths'], requires_grad=False)
        online_flags = Variable(batch['online'].type(dtype), requires_grad=False)
        online_flags = online_flags.view(1, -1, 1)

        preds = model(images, online_flags).cpu()
        # Every sequence in the batch is reported with the full length
        # preds.size(0), repeated preds.size(1) times.
        seq_len = preds.size(0)
        batch_len = preds.size(1)
        preds_size = Variable(torch.IntTensor([seq_len] * batch_len))

        # Grab the predictions as numpy before the weights are updated.
        batch_first = preds.permute(1, 0, 2)
        decoded_input = batch_first.data.cpu().numpy()

        loss = criterion(preds, targets, preds_size, target_lengths)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        for row, logits in enumerate(decoded_input):
            best_path, _ = string_utils.naive_decode(logits)
            hypothesis = string_utils.label2str(best_path, idx_to_char, False)
            reference = batch['gt'][row]
            cer_total += error_rates.cer(reference, hypothesis)
            line_count += 1

    return cer_total / line_count
예제 #7
0
def main():
    """Evaluate a saved CRNN checkpoint on the configured test set.

    Command line: ``<config.json>``.

    The checkpoint at ``config['model_load_path']`` is loaded, every batch of
    the test set is decoded with ``string_utils.naive_decode`` and the mean
    CER/WER plus the error totals are printed. The validation-set and
    beam-search loops that used to live here as commented-out code have been
    removed; only the test-set evaluation was ever executed.

    Raises:
        ValueError: If ``config['model']`` is not ``"crnn"``.
        SystemExit: If the test set yields no sample to score.
    """
    torch.manual_seed(68)
    torch.backends.cudnn.deterministic = True

    # Printed so that runs can be compared for identical seeding.
    print(torch.LongTensor(10).random_(0, 10))

    config_path = sys.argv[1]
    is_rimes = 'rimes' in config_path.lower()
    print(is_rimes)
    with open(config_path) as f:
        config = json.load(f)

    idx_to_char, char_to_idx = character_set.load_char_set(config['character_set_path'])

    input_height = config['network']['input_height']
    image_root = config['image_root_directory']
    batch_size = config['batch_size']

    train_dataset = HwDataset(config['training_set_path'], char_to_idx,
                              img_height=input_height, root_path=image_root,
                              augmentation=False)

    # Only a missing config key triggers the fallback split. Constructing the
    # dataset happens outside the ``try`` so that a KeyError raised while
    # loading data is no longer mistaken for "no validation set configured".
    try:
        validation_set_path = config['validation_set_path']
    except KeyError:
        print("No validation set found, generating one")
        master = train_dataset
        print("Total of " + str(len(master)) + " Training Examples")
        n = len(master)
        n_test = int(n * .1)  # hold out the last 10% for validation
        n_train = n - n_test
        idx = list(range(n))
        val_dataset = data.Subset(master, idx[n_train:])
        train_dataset = data.Subset(master, idx[:n_train])
    else:
        val_dataset = HwDataset(validation_set_path, char_to_idx,
                                img_height=input_height, root_path=image_root,
                                remove_errors=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                                num_workers=1, collate_fn=hw_dataset.collate)

    if not is_rimes:
        # NOTE(review): this second validation set is loaded but not
        # evaluated below; kept so a missing/invalid 'validation2_set_path'
        # still fails the run as before. Confirm whether it can be dropped.
        val2_dataset = HwDataset(config['validation2_set_path'], char_to_idx,
                                 img_height=input_height, root_path=image_root,
                                 remove_errors=True)
        val2_dataloader = DataLoader(val2_dataset, batch_size=batch_size,
                                     shuffle=False, num_workers=0,
                                     collate_fn=hw_dataset.collate)

    test_dataset = HwDataset(config['test_set_path'], char_to_idx,
                             img_height=input_height, root_path=image_root,
                             remove_errors=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                                 num_workers=0, collate_fn=hw_dataset.collate)

    if config['model'] != "crnn":
        # Previously ``hw`` was simply left unbound here and the script died
        # later with an UnboundLocalError.
        raise ValueError("Unsupported model type: " + repr(config['model']))
    print("Using CRNN")
    hw = crnn.create_model({
        'input_height': input_height,
        'cnn_out_size': config['network']['cnn_out_size'],
        'num_of_channels': 3,
        'num_of_outputs': len(idx_to_char) + 1
    })

    # Map to CPU first so GPU-saved checkpoints also load on CPU-only hosts;
    # ``hw.cuda()`` moves the weights afterwards when CUDA is available.
    hw.load_state_dict(torch.load(config['model_load_path'], map_location='cpu'))

    if torch.cuda.is_available():
        hw.cuda()
        dtype = torch.cuda.FloatTensor
        print("Using GPU")
    else:
        dtype = torch.FloatTensor
        print("No GPU detected")

    print(char_to_idx)
    # Vocabulary string: a leading space followed by the characters for
    # indices 1..len(idx_to_char).
    voc = " " + "".join(idx_to_char[x] for x in range(1, len(idx_to_char) + 1))
    print(voc)

    tot_ce = 0.0
    tot_we = 0.0
    sum_loss = 0.0
    sum_wer = 0.0
    steps = 0.0
    hw.eval()
    print(idx_to_char)
    print("Validation Set Size = " + str(len(val_dataloader)))

    for x in test_dataloader:
        if x is None:
            continue
        with torch.no_grad():
            line_imgs = x['line_imgs'].type(dtype)
            preds = hw(line_imgs).permute(1, 0, 2)
            out = preds.detach().cpu().numpy()
            for i, gt_line in enumerate(x['gt']):
                logits = out[i, ...]
                pred, _ = string_utils.naive_decode(logits)
                pred_str = string_utils.label2str(pred, idx_to_char, False)
                wer = error_rates.wer(gt_line, pred_str)
                cer = error_rates.cer(gt_line, pred_str)
                sum_wer += wer
                sum_loss += cer
                # Convert the per-line rates back into absolute error counts.
                tot_we += wer * len(gt_line.split())
                tot_ce += cer * len(u' '.join(gt_line.split()))
                steps += 1

    if steps == 0:
        # Avoid an opaque ZeroDivisionError when nothing could be scored.
        sys.exit("No test samples were evaluated; cannot compute error rates")

    # NOTE(review): the labels say "Validation" although these numbers come
    # from the test set; the text is kept unchanged for log compatibility.
    print("Validation CER", sum_loss / steps)
    print("Validation WER", sum_wer / steps)
    print("Total character Errors:", tot_ce)
    print("Total word errors", tot_we)
예제 #8
0
def main():
    """Train the handwriting model and keep every checkpoint that improves.

    Command line: ``<config.json> [job_id] [model_save_dir]``.

    Each of up to 1000 epochs runs one training pass and one validation
    pass. Whenever the validation score (CER by default) improves, the
    weights are saved as ``<model_save_path><epoch>.pt`` and a report is sent
    through ``email_update``. Training stops once the score has not improved
    for more epochs than the metric's patience.

    NOTE: an empty training or validation loader raises ZeroDivisionError
    when the epoch averages are computed.
    """
    config_path = sys.argv[1]
    job_id = sys.argv[2] if len(sys.argv) > 2 else ""
    print(job_id)

    with open(config_path) as f:
        config = json.load(f)

    # An absent or empty third argument falls back to the configured path.
    if len(sys.argv) > 3 and sys.argv[3]:
        model_save_path = sys.argv[3]
        if model_save_path[-1] != os.path.sep:
            model_save_path = model_save_path + os.path.sep
    else:
        model_save_path = config['model_save_path']
    dirname = os.path.dirname(model_save_path)
    print(dirname)
    if len(dirname) > 0 and not os.path.exists(dirname):
        os.makedirs(dirname)

    with open(config_path) as f:
        param_lines = f.readlines()

    for line in param_lines:
        # rstrip instead of [:-1]: an unterminated last line keeps its
        # final character.
        print(line.rstrip('\n'))

    # The raw config text prefixes every e-mailed report.
    base_message = "".join(param_lines)

    idx_to_char, char_to_idx = character_set.load_char_set(
        config['character_set_path'])

    input_height = config['network']['input_height']
    image_root = config['image_root_directory']

    train_dataset = HwDataset(
        config['training_set_path'],
        char_to_idx,
        img_height=input_height,
        root_path=image_root,
        augmentation=config['augmentation'],
    )

    # Only a missing config key triggers the fallback split. Constructing the
    # dataset happens outside the ``try`` so that a KeyError raised while
    # loading data is no longer mistaken for "no validation set configured".
    try:
        validation_set_path = config['validation_set_path']
    except KeyError:
        print("No validation set found, generating one")
        master = train_dataset
        print("Total of " + str(len(master)) + " Training Examples")
        n = len(master)
        n_test = int(n * .1)  # hold out the last 10% for validation
        n_train = n - n_test
        idx = list(range(n))
        test_dataset = data.Subset(master, idx[n_train:])
        train_dataset = data.Subset(master, idx[:n_train])
    else:
        test_dataset = HwDataset(validation_set_path,
                                 char_to_idx,
                                 img_height=input_height,
                                 root_path=image_root)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config['batch_size'],
                                  shuffle=False,
                                  num_workers=1,
                                  collate_fn=hw_dataset.collate)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=config['batch_size'],
                                 shuffle=False,
                                 num_workers=1,
                                 collate_fn=hw_dataset.collate)
    print("Train Dataset Length: " + str(len(train_dataset)))
    print("Test Dataset Length: " + str(len(test_dataset)))

    hw = model.create_model(len(idx_to_char))

    if torch.cuda.is_available():
        hw.cuda()
        dtype = torch.cuda.FloatTensor
        print("Using GPU")
    else:
        dtype = torch.FloatTensor
        print("No GPU detected")

    optimizer = torch.optim.Adadelta(hw.parameters(),
                                     lr=config['network']['learning_rate'])
    criterion = CTCLoss(reduction='sum', zero_infinity=True)
    lowest_loss = float('inf')
    best_distance = 0  # epochs since the last improvement
    metric = "CER"     # model-selection metric: "CER" or "WER"
    for epoch in range(1000):
        start_time = time.time()
        message = base_message
        sum_loss = 0.0
        sum_wer_loss = 0.0
        steps = 0.0
        hw.train()
        disp_ctc_loss = 0.0
        disp_loss = 0.0
        gt = ""
        ot = ""
        print("Train Set Size = " + str(len(train_dataloader)))
        prog_bar = tqdm(train_dataloader, total=len(train_dataloader))
        for x in prog_bar:
            prog_bar.set_description(
                f'CER: {disp_loss} CTC: {disp_ctc_loss} Ground Truth: |{gt}| Network Output: |{ot}|'
            )
            line_imgs = x['line_imgs']
            # Zero-pad the last axis up to the next multiple of 32.
            rem = line_imgs.shape[3] % 32
            if rem != 0:
                imgshape = line_imgs.shape
                padded = torch.zeros(imgshape[0], imgshape[1], imgshape[2],
                                     imgshape[3] + (32 - rem))
                padded[:, :, :, :imgshape[3]] = line_imgs
                line_imgs = padded
            line_imgs = line_imgs.type(dtype)

            labels = x['labels']
            label_lengths = x['label_lengths']

            preds = hw(line_imgs).cpu()
            # Every sequence is reported with the full length preds.size(0).
            preds_size = torch.IntTensor([preds.size(0)] * preds.size(1))

            out = preds.permute(1, 0, 2).detach().cpu().numpy()
            loss = criterion(preds, labels, preds_size, label_lengths)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Show a plain float in the progress bar, not the tensor repr.
            disp_ctc_loss = loss.item()

            for j in range(out.shape[0]):
                logits = out[j, ...]
                pred, _ = string_utils.naive_decode(logits)
                pred_str = string_utils.label2str(pred, idx_to_char, False)
                gt_str = x['gt'][j]
                sum_loss += error_rates.cer(gt_str, pred_str)
                sum_wer_loss += error_rates.wer(gt_str, pred_str)
                # Remember the latest pair for the progress-bar text.
                gt = gt_str
                ot = pred_str
                steps += 1
            disp_loss = sum_loss / steps
        elapsed = time.time() - start_time
        message = message + "\n" + "Epoch: " + str(
            epoch) + " Training CER: " + str(
                sum_loss / steps) + " Training WER: " + str(
                    sum_wer_loss /
                    steps) + "\n" + "Time: " + str(elapsed) + " Seconds"
        print("Epoch: " + str(epoch) + " Training CER", sum_loss / steps)
        print("Training WER: " + str(sum_wer_loss / steps))
        print("Time: " + str(elapsed) + " Seconds")

        sum_loss = 0.0
        sum_wer_loss = 0.0
        steps = 0.0
        hw.eval()
        print("Validation Set Size = " + str(len(test_dataloader)))
        # A bare ``torch.no_grad()`` statement only creates and discards a
        # context manager; it must be entered with ``with`` to have any
        # effect. Before this fix validation ran with gradients enabled.
        with torch.no_grad():
            for x in tqdm(test_dataloader):
                line_imgs = x['line_imgs'].type(dtype)
                preds = hw(line_imgs).cpu()
                out = preds.permute(1, 0, 2).detach().cpu().numpy()
                for i, gt_line in enumerate(x['gt']):
                    logits = out[i, ...]
                    pred, _ = string_utils.naive_decode(logits)
                    pred_str = string_utils.label2str(pred, idx_to_char, False)
                    sum_wer_loss += error_rates.wer(gt_line, pred_str)
                    sum_loss += error_rates.cer(gt_line, pred_str)
                    steps += 1

        message = message + "\nTest CER: " + str(sum_loss / steps)
        message = message + "\nTest WER: " + str(sum_wer_loss / steps)
        print("Test CER", sum_loss / steps)
        print("Test WER", sum_wer_loss / steps)
        best_distance += 1

        # Pick the score to track and how many stale epochs are tolerated.
        if metric == "CER":
            score, patience = sum_loss / steps, 800
        elif metric == "WER":
            score, patience = sum_wer_loss / steps, 80
        else:
            print("This is actually very bad")
            continue

        if lowest_loss > score:
            lowest_loss = score
            print("Saving Best")
            message = message + "\nBest Result :)"
            # model_save_path is used as a plain prefix: "<path><epoch>.pt".
            torch.save(hw.state_dict(), model_save_path + str(epoch) + ".pt")
            email_update(message, job_id)
            best_distance = 0
        if best_distance > patience:
            break
예제 #9
0
def main():
    """Train a CRNN with CTC loss and save the best-validating weights.

    Command line: ``<config.json>``.

    Each of 1000 epochs runs one training pass (batches wider than 500 on
    the last image axis are skipped) and one validation pass. The weights
    are written to ``config['model_save_path']`` whenever the validation CER
    reaches a new minimum. Batches on which the network raises are reported
    and skipped.
    """
    config_path = sys.argv[1]

    with open(config_path) as f:
        config = json.load(f)

    idx_to_char, char_to_idx = character_set.load_char_set(
        config['character_set_path'])

    train_dataset = HwDataset(config['training_set_path'],
                              char_to_idx,
                              img_height=config['network']['input_height'],
                              root_path=config['image_root_directory'],
                              augmentation=True)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=8,
                                  shuffle=False,
                                  num_workers=0,
                                  collate_fn=hw_dataset.collate)

    test_dataset = HwDataset(config['validation_set_path'],
                             char_to_idx,
                             img_height=config['network']['input_height'],
                             root_path=config['image_root_directory'])
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=8,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=hw_dataset.collate)

    hw = crnn.create_model({
        'cnn_out_size': config['network']['cnn_out_size'],
        'num_of_channels': 3,
        'num_of_outputs': len(idx_to_char) + 1
    })

    if torch.cuda.is_available():
        hw.cuda()
        dtype = torch.cuda.FloatTensor
        print("Using GPU")
    else:
        dtype = torch.FloatTensor
        print("No GPU detected")

    optimizer = torch.optim.Adam(hw.parameters(),
                                 lr=config['network']['learning_rate'])
    criterion = CTCLoss()
    lowest_loss = float('inf')
    for epoch in range(1000):
        print('epoch', epoch)
        sum_loss = 0.0
        steps = 0.0
        hw.train()
        for i, x in enumerate(train_dataloader):
            print(i, '/', len(train_dataloader))
            if x['line_imgs'].shape[3] > 500:
                # Very wide batches are skipped entirely.
                continue
            line_imgs = x['line_imgs'].type(dtype)
            labels = x['labels']
            label_lengths = x['label_lengths']

            try:
                preds = hw(line_imgs).cpu()
            except Exception as e:
                # Best effort: report the failing batch and move on.
                print(e)
                continue
            # Every sequence is reported with the full length preds.size(0).
            preds_size = torch.IntTensor([preds.size(0)] * preds.size(1))

            out = preds.permute(1, 0, 2).detach().cpu().numpy()

            loss = criterion(preds, labels, preds_size, label_lengths)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            for j in range(out.shape[0]):
                logits = out[j, ...]
                pred, _ = string_utils.naive_decode(logits)
                pred_str = string_utils.label2str(pred, idx_to_char, False)
                gt_str = x['gt'][j]
                sum_loss += error_rates.cer(gt_str, pred_str)
                steps += 1

        if steps == 0:
            print("Training CER:", "Error")
        else:
            print("Training CER:", sum_loss / steps)

        sum_loss = 0.0
        steps = 0.0
        hw.eval()
        # ``volatile=True`` is ignored since PyTorch 0.4; no_grad() is what
        # actually disables gradient tracking during validation.
        with torch.no_grad():
            for x in test_dataloader:
                line_imgs = x['line_imgs'].type(dtype)

                try:
                    preds = hw(line_imgs).cpu()
                except Exception as e:
                    print(e)
                    continue

                out = preds.permute(1, 0, 2).detach().cpu().numpy()

                for i, gt_line in enumerate(x['gt']):
                    logits = out[i, ...]
                    pred, _ = string_utils.naive_decode(logits)
                    pred_str = string_utils.label2str(pred, idx_to_char, False)
                    sum_loss += error_rates.cer(gt_line, pred_str)
                    steps += 1

        if steps == 0:
            print("Test CER", "Error")
            # Nothing was scored, so there is no CER to compare. The old code
            # fell through to ``sum_loss / steps`` and crashed with
            # ZeroDivisionError despite the guard above.
            continue

        test_cer = sum_loss / steps
        print("Test CER", test_cer)

        if lowest_loss > test_cer:
            lowest_loss = test_cer
            print("Saving Best")
            dirname = os.path.dirname(config['model_save_path'])
            if len(dirname) > 0 and not os.path.exists(dirname):
                os.makedirs(dirname)

            torch.save(hw.state_dict(), config['model_save_path'])