Example #1
0
def find_perfect_indices(logits, target_transcript, char_encoder):
    """Given the logits and expected transcript labels, find the pairs with
    perfect prediction from the logits.

    Parameters
    ----------
    logits :
        Array-like of per-line logits, indexed by sample on the first axis.
    target_transcript :
        Iterable of ground-truth line strings, paired index-wise with logits.
    char_encoder :
        Encoder object providing `encoder.inverse`, `blank_char`, and
        `blank_idx` for converting label sequences back to strings.

    Returns
    -------
    list(int)
        Indices of the lines whose decoded prediction has zero CER against
        the ground truth.
    """
    matches = []

    for idx, expected_line in enumerate(target_transcript):
        # Greedy (naive) decode of this line's logits; [0] is the collapsed
        # prediction label sequence.
        decoded_labels = string_utils.naive_decode(logits[idx, ...])[0]

        decoded_str = string_utils.label2str(
            decoded_labels,
            char_encoder.encoder.inverse,
            False,
            char_encoder.blank_char,
            char_encoder.blank_idx,
        )

        # Zero CER means the entire line was predicted exactly.
        if error_rates.cer(expected_line, decoded_str) <= 0:
            matches.append(idx)

    # TODO find the indices within the logits that have perfect predictions.
    #   Currently, only finds indices of logits w/ correct line predictions,
    #   but must confirm that means every character pred is correct within
    #   the logits. Note naive decode does not add a char from the raw
    #   prediction to pred if that char is the same as the prior; this means
    #   that if the CER is zero, then every character within the logits is
    #   correct, and so all char indices of the logits, and their respective
    #   layer, may be used to train the MEVM.

    return matches
def eval_transcription(texts, preds):
    """Evaluates the given predicted transcriptions to expected texts.

    Parameters
    ----------
    texts : np.ndarray(str)
        An iterable of strings that represents the expected lines of texts to
        be predicted.
    preds : np.ndarray
        array of shape [samples], where samples is the number of predicted
        line strings, paired index-wise with `texts`.

    Returns
    -------
    TranscriptResults
        The mean character error rate and mean word error rate over all
        lines. NOTE(review): raises ZeroDivisionError if `texts` is empty —
        presumably callers never pass empty input; confirm.
    """
    total_cer = 0
    total_wer = 0

    for i, pred in enumerate(preds):
        total_cer += error_rates.cer(texts[i], pred)
        # Fix: WER must be accumulated with error_rates.wer; the original
        # mistakenly summed cer into total_wer as well.
        total_wer += error_rates.wer(texts[i], pred)

    return TranscriptResults(total_cer / len(texts), total_wer / len(texts))
def eval_transcription_logits(
    texts,
    logits,
    label_encoder,
    decode='naive',
    argmax=True,
):
    """Evaluates the given predicted transcriptions to expected texts where the
    predictions are given as a probability vector per character.

    Parameters
    ----------
    texts : np.ndarray(str)
        An iterable of strings that represents the expected lines of texts to
        be predicted.
    logits : np.ndarray
        array of shape [samples, timesteps, characters], where samples is the
        number of samples, timesteps is the number timesteps of the respective
        RNN's output, and characters is the number of known characters by the
        predictor.
    label_encoder :
        Encoder providing `encoder.inverse`, `blank_char`, and `blank_idx`
        for converting label sequences back to strings.
    decode : str, optional
        NOTE(review): currently unused — only naive decoding is implemented.
        Kept for interface compatibility.
    argmax : bool, optional
        Passed through to `string_utils.naive_decode`.

    Returns
    -------
    TranscriptResults
        The mean character error rate and mean word error rate over all
        lines.
    """
    total_cer = 0
    total_wer = 0

    for i, logit in enumerate(logits):
        # raw_pred (uncollapsed per-timestep labels) is currently unused.
        pred, raw_pred = string_utils.naive_decode(logit, argmax)

        pred_str = string_utils.label2str(
            pred,
            label_encoder.encoder.inverse,
            False,
            blank_char=label_encoder.blank_char,
            blank=label_encoder.blank_idx,
        )

        total_cer += error_rates.cer(texts[i], pred_str)
        # Fix: WER must be accumulated with error_rates.wer; the original
        # mistakenly summed cer into total_wer as well.
        total_wer += error_rates.wer(texts[i], pred_str)

        # TODO log CER and WER for each line?

    return TranscriptResults(total_cer / len(texts), total_wer / len(texts))
Example #4
0
            # Check pred and gt pairing to ensure they match
            assert len(gt) == len(pred)
            assert gt.keys() == pred.keys()

            # Initialize the saved values for the split of this fold
            cer_sum = 0
            wer_sum = 0

            unique_chars = set()

            pred_nd = []
            actual_nd = []

            for key, actual in gt.items():
                # Calculate the transcription results
                cer = error_rates.cer(actual, pred[key])
                wer = error_rates.wer(actual, pred[key])

                cer_sum += cer
                wer_sum += wer

                actual_char_set = set(actual)
                unique_chars |= actual_char_set

                # Novelty predicted when novel/unknown char '#' in pred
                pred_nd.append(
                    'unknown' if unknown_char in pred[key] else 'known'
                )
                # novelty exists when novel char in gt
                if any([
                    v == unknown_char or v not in known_chars
Example #5
0
def eval_crnn(
    hw_crnn,
    dataloader,
    char_encoder,
    dtype,
    output_crnn_eval=True,
    layer=None,
    return_logits=True,
    return_slice=False,
    deterministic=True,
    random_seed=None,
    return_col_chars=False,
    skip_none_labels=True,
):
    """Evaluates CRNN and returns the CRNN output. Optionally, this is also
    used to obtain certain layer's outputs such as the penultimate RNN or CNN
    layers of the CRNN.

    Parameters
    ----------
    hw_crnn :
        The handwriting CRNN model to evaluate (a torch Module; called on
        line images).
    dataloader :
        Pytorch dataloader for input to CRNN
    char_encoder :
        Encoder providing `encoder.inverse`, `blank_char`, and `blank_idx`
        for converting label sequences back to strings.
    dtype : type
        Pytorch type of a the handwritten line image data. Used to handle CPU
        or GPU use.
    output_crnn_eval : bool, optional
        Outputs the CRNNs performance without the EVM.
    layer : str, optional
        If 'rnn', uses the CRNN's final RNN output as input to the MultipleEVM.
        If 'conv', uses the final convolutional layer's output. If 'concat',
        then returns both concatenated together (Concat is to be implemented).
        Defaults to None, and thus only evaluates the CRNN itself.
    return_logits : bool, optional
        If True, the per-line logits are collected and returned.
    return_slice : bool, optional
        Returns the indices of the perfect slices
    deterministic : bool, optional
        If True, seeds torch and forces deterministic cuDNN behavior.
    random_seed : int, optional
        Seed used when `deterministic`; defaults to 4816 with a warning.
    return_col_chars : bool, optional
        If True, collects `x['col_chars'][i]` for each evaluated line.
    skip_none_labels : bool, optional
        If True, lines with None or empty ground truth are skipped.

    Returns
    -------
    list(np.ndarray)
        Returns a list of the selected layer's output for each input sample.
        `layer` determines which layer of the CRNN is used. The shape of each
        np.ndarray is [glyph_window, classes]. This assumes batch size is
        always 1.
    """
    if deterministic:
        if random_seed is None:
            logging.warning(' '.join([
                'Model is being evaluated and deterministic is true, but no',
                'seed was provided. The default seed of `4816` is being used.',
            ]))
            random_seed = 4816
        logging.info('random seed = %d', random_seed)

        torch.manual_seed(random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        logging.warning(' '.join([
            'Model is being evaluated and deterministic is False! The results',
            'may not be reproducible now due to GPU hardware, even if there',
            'is no shuffling or updating of the model.',
        ]))

    # Initialize metrics unconditionally. Fix: the original guarded this on
    # `output_crnn_eval or return_slice or save_col_chars`, where
    # `save_col_chars` was an undefined name (NameError), and the guard could
    # also leave these variables unbound when `layer is None` alone triggers
    # their use below.
    tot_ce = 0.0
    tot_we = 0.0
    sum_loss = 0.0
    sum_wer = 0.0
    steps = 0.0
    total_chars = 0.0
    total_words = 0.0

    count_skips = 0

    hw_crnn.eval()

    layer_outs = []

    if return_logits:
        logits_list = []

    if return_slice:
        perfect_indices = []

    if return_col_chars:
        col_chars = []

    # For batch in dataloader
    for x in dataloader:
        if x is None:
            continue
        with torch.no_grad():
            line_imgs = Variable(
                x['line_imgs'].type(dtype),
                requires_grad=False,
            )

            if layer is None:
                preds = hw_crnn(line_imgs)
            elif layer.lower() == 'rnn':
                preds, layer_out = hw_crnn(line_imgs, return_rnn=True)

                # Shape is then [timesteps, hidden layer width]
                layer_outs.append(layer_out.data.cpu().numpy())

            elif layer.lower() in {'conv', 'cnn'}:
                # Last Convolution Layer
                preds, layer_out = hw_crnn(line_imgs, return_conv=True)

                # Shape is then [timesteps, conv layer flat: height * width]
                # NOTE(review): two appends occur per sample here — both the
                # raw output and a squeezed/permuted copy. One is likely
                # stale; confirm which consumers expect before changing.
                layer_outs.append(layer_out.data.cpu().numpy())
                layer_outs.append(
                    np.squeeze(layer_out.permute(1, 0, 2).data.cpu().numpy()))
            else:
                raise NotImplementedError('Concat/both RNN and Conv of CRNN.')

            # Swap 0 and 1 indices to have:
            #   batch sample, "character window", classes
            # Except, since batch sample is always 1 here, that dim is removed:
            #   "character windows", classes

            if layer is None or output_crnn_eval:
                output_batch = preds.permute(1, 0, 2)
                out = output_batch.data.cpu().numpy()

                # Consider MEVM input here after enough obtained to do batch
                # training Or save the layer_outs to be used in training the
                # MEVM

                # Loop through the batch
                for i, gt_line in enumerate(x['gt']):
                    if skip_none_labels and (gt_line is None or gt_line == ''):
                        count_skips += 1
                        logging.debug(
                            'No ground truth label. Count: %d; `%s`',
                            count_skips,
                            x['gt'],
                        )
                        continue
                    logits = out[i, ...]

                    pred, raw_pred = string_utils.naive_decode(logits)

                    pred_str = string_utils.label2str(
                        pred,
                        char_encoder.encoder.inverse,
                        False,
                        char_encoder.blank_char,
                        char_encoder.blank_idx,
                    )

                    wer = error_rates.wer(gt_line, pred_str)
                    sum_wer += wer

                    cer = error_rates.cer(gt_line, pred_str)

                    # Weight per-line rates by line length to get totals.
                    tot_we += wer * len(gt_line.split())
                    tot_ce += cer * len(u' '.join(gt_line.split()))

                    total_words += len(gt_line.split())
                    total_chars += len(u' '.join(gt_line.split()))

                    sum_loss += cer

                    if return_slice and cer <= 0:
                        # NOTE(review): steps is a float counter; the index
                        # appended here is therefore a float.
                        perfect_indices.append(steps)

                    if return_logits:
                        logits_list.append(logits)

                    if return_col_chars:
                        # NOTE this is easily obtainable w/ the DataLoader
                        # alone, so consider removing this here. Only useful
                        # when shuffle is on in the dataloader.
                        col_chars.append(x['col_chars'][i])

                    steps += 1

    if layer is None or output_crnn_eval:
        logging.info('CRNN results:')
        logging.info("Eval CER %f", sum_loss / steps)
        logging.info("Eval WER %f", sum_wer / steps)

        logging.info("Total character Errors: %d", tot_ce)
        logging.info("Total characters: %d", total_chars)
        logging.info("Total character errors rate: %f", tot_ce / total_chars)

        logging.info("Total word errors %d", tot_we)
        logging.info("Total words: %d", total_words)
        logging.info("Total word error rate: %f", tot_we / total_words)

    # NOTE that the way this is setup, it always expects to return the layers
    if not (return_logits or isinstance(layer, str) or return_slice
            or return_col_chars):
        return None

    return_list = []
    if return_logits:
        logging.debug(
            'logits_list shapes:\n%s',
            [logit.shape for logit in logits_list],
        )
        return_list.append(logits_list)

    if isinstance(layer, str):
        logging.debug('layer shapes:\n%s',
                      [layer.shape for layer in layer_outs])
        return_list.append(layer_outs)

    if return_slice:
        logging.debug('perfect_indices len: %d', len(perfect_indices))
        logging.debug('perfect_indices:\n%s', perfect_indices)

        return_list.append(perfect_indices)

    if return_col_chars:
        return_list.append(col_chars)

    #return tuple(return_list)
    return return_list
Example #6
0
def train_crnn(
    hw_crnn,
    optimizer,
    criterion,
    char_encoder,
    train_dataloader,
    dtype,
    model_save_path=None,
    test_dataloader=None,
    epochs=1000,
    metric='CER',
    base_message='',
    thresh=None,
    max_epochs_no_improvement=800,
    skip_none_labels=True,
):
    """Streamline the training of the CRNN.

    Trains `hw_crnn` with CTC-style `criterion` for up to `epochs` epochs,
    logging CER/WER on the training data each epoch and, when
    `test_dataloader` is given, validating per epoch and saving the best
    model so far (by `metric`, either 'CER' or 'WER') under
    `model_save_path`.

    NOTE(review): if a best model is to be saved, `model_save_path` must not
    be None (os.path.join would raise) — confirm callers always provide it
    when `test_dataloader` is set.
    """
    # Variables for training loop
    lowest_loss = float('inf')
    best_distance = 0

    # Training Epoch Loop
    for epoch in range(epochs):
        # NOTE(review): bare torch.enable_grad() call is a no-op outside a
        # `with` block; gradients are enabled by default during training.
        torch.enable_grad()

        startTime = time.time()
        message = base_message

        sum_loss = 0.0
        sum_wer_loss = 0.0
        steps = 0.0

        hw_crnn.train()

        disp_ctc_loss = 0.0
        disp_loss = 0.0
        gt = ""
        ot = ""
        loss = 0.0

        count_skips = 0
        count_skips_train = 0

        logging.info("Train Set Size = %d", len(train_dataloader))

        # Training Batch Loop
        prog_bar = tqdm(
            enumerate(train_dataloader),
            total=len(train_dataloader),
        )
        for i, x in prog_bar:
            prog_bar.set_description(' '.join([
                f'CER: {disp_loss} CTC: {loss} Ground Truth: |{gt}| Network',
                f'Output: |{ot}|',
            ]))

            for ground_truth in x['gt']:
                if ground_truth is None or ground_truth == '':
                    raise ValueError('Ground Truth is None or empty string!')

            line_imgs = x['line_imgs']
            #"""
            # Pad image width up to the next multiple of 32 with zeros, as
            # the CRNN's conv stack expects width divisible by 32.
            rem = line_imgs.shape[3] % 32
            if rem != 0:
                imgshape = line_imgs.shape
                temp = torch.zeros(
                    imgshape[0],
                    imgshape[1],
                    imgshape[2],
                    imgshape[3] + (32 - rem),
                )
                temp[:, :, :, :imgshape[3]] = line_imgs
                line_imgs = temp
                del temp
            #"""
            line_imgs = Variable(line_imgs.type(dtype), requires_grad=False)

            labels = Variable(x['labels'], requires_grad=False)
            label_lengths = Variable(x['label_lengths'], requires_grad=False)

            preds = hw_crnn(line_imgs).cpu()
            preds_size = Variable(
                torch.IntTensor([preds.size(0)] * preds.size(1)))

            output_batch = preds.permute(1, 0, 2)
            out = output_batch.data.cpu().numpy()
            loss = criterion(preds, labels, preds_size, label_lengths)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # TODO output the Loss of the network!
            # TODO and change sum_loss name to sum_cer

            # Training Eval loop on training data
            for j in range(out.shape[0]):
                # Fix: gt_line was referenced without ever being assigned in
                # this loop (NameError); bind it to this sample's label.
                gt_line = x['gt'][j]
                if skip_none_labels and (gt_line is None or gt_line == ''):
                    count_skips_train += 1
                    logging.debug(
                        'No ground truth label for train sample. Count: %d; `%s`',
                        count_skips_train,
                        x['gt'],
                    )
                    # NOTE(review): raising here contradicts "skip"; the
                    # earlier per-batch check already raises for empty labels,
                    # so this branch is effectively unreachable.
                    raise ValueError('Ground Truth is empty string!')

                logits = out[j, ...]

                pred, raw_pred = string_utils.naive_decode(logits)

                pred_str = string_utils.label2str(
                    pred,
                    char_encoder.encoder.inverse,
                    False,
                    char_encoder.blank_char,
                    char_encoder.blank_idx,
                )

                gt_str = x['gt'][j]
                cer = error_rates.cer(gt_str, pred_str)
                wer = error_rates.wer(gt_str, pred_str)
                gt = gt_str
                ot = pred_str
                sum_loss += cer
                sum_wer_loss += wer
                steps += 1
            disp_loss = sum_loss / steps
        eTime = time.time() - startTime

        message = (message + "\n" + "Epoch: " + str(epoch) +
                   " Training CER: " + str(sum_loss / steps) +
                   " Training WER: " + str(sum_wer_loss / steps) + "\n" +
                   "Time: " + str(eTime) + " Seconds")

        logging.info("Epoch: %d: Training CER %f", epoch, sum_loss / steps)
        logging.info("Training WER: %f", sum_wer_loss / steps)
        logging.info("Time: %f Seconds.", eTime)

        sum_loss = 0.0
        sum_wer_loss = 0.0
        steps = 0.0
        hw_crnn.eval()

        # Validation loop per epoch
        if test_dataloader is not None:
            logging.info("Validation Set Size = %d", len(test_dataloader))

            for x in tqdm(test_dataloader):
                # Fix: a bare torch.no_grad() call is a no-op; it must be
                # used as a context manager to actually disable autograd for
                # the forward pass.
                with torch.no_grad():
                    line_imgs = Variable(
                        x['line_imgs'].type(dtype),
                        requires_grad=False,
                    )

                    preds = hw_crnn(line_imgs).cpu()
                output_batch = preds.permute(1, 0, 2)
                out = output_batch.data.cpu().numpy()
                for i, gt_line in enumerate(x['gt']):
                    if skip_none_labels and (gt_line is None or gt_line == ''):
                        count_skips += 1
                        logging.debug(
                            'No ground truth label. Count: %d; `%s`',
                            count_skips,
                            x['gt'],
                        )
                        continue
                    logits = out[i, ...]
                    pred, raw_pred = string_utils.naive_decode(logits)

                    pred_str = string_utils.label2str(
                        pred,
                        char_encoder.encoder.inverse,
                        False,
                        char_encoder.blank_char,
                        char_encoder.blank_idx,
                    )

                    cer = error_rates.cer(gt_line, pred_str)
                    wer = error_rates.wer(gt_line, pred_str)
                    sum_wer_loss += wer
                    sum_loss += cer
                    steps += 1

            message = message + "\nTest CER: " + str(sum_loss / steps)
            message = message + "\nTest WER: " + str(sum_wer_loss / steps)
            logging.info("Test CER %f", sum_loss / steps)
            logging.info("Test WER %f", sum_wer_loss / steps)
            best_distance += 1

            # Repeatedly saves the best performing model so-far based on Val.
            if metric == "CER":
                if lowest_loss > sum_loss / steps:
                    if thresh and lowest_loss - sum_loss / steps > thresh:
                        lowest_loss = sum_loss / steps
                        logging.info("Saving Best")
                        message = message + "\nBest Result :)"
                        torch.save(
                            hw_crnn.state_dict(),
                            os.path.join(
                                model_save_path,
                                f'crnn_ep{str(epoch)}.pt',
                            ),
                        )
                        best_distance = 0
                    elif thresh is None:
                        # Save the weights for this epoch if the ANN has the
                        # lowest CER yet. NOTE that this is not the Loss of the
                        # network, but the CER.
                        # TODO include saving network w/ best ANN Loss on Val
                        lowest_loss = sum_loss / steps
                        logging.info("Saving Best")
                        message = message + "\nBest Result :)"
                        torch.save(
                            hw_crnn.state_dict(),
                            os.path.join(
                                model_save_path,
                                f'crnn_ep{str(epoch)}.pt',
                            ),
                        )
                        best_distance = 0
                if best_distance > max_epochs_no_improvement:
                    break
            elif metric == "WER":
                # NOTE(review): the WER branch ignores `thresh`, unlike the
                # CER branch — confirm whether that asymmetry is intended.
                if lowest_loss > sum_wer_loss / steps:
                    lowest_loss = sum_wer_loss / steps
                    logging.info("Saving Best")
                    message = message + "\nBest Result :)"
                    torch.save(
                        hw_crnn.state_dict(),
                        os.path.join(
                            model_save_path,
                            f'crnn_ep{str(epoch)}.pt',
                        ),
                    )
                    best_distance = 0
                if best_distance > max_epochs_no_improvement:
                    break
            else:
                raise ValueError("This is actually very bad")
    return