Example #1
    def deserve_idx(self, decoded, len_decoded, labels, len_labels):
        """
        if one sent is correct during training, then not to train on it
        """
        decoded_sparse = dense_sequence_to_sparse(seq=decoded,
                                                  len_seq=len_decoded)
        label_sparse = dense_sequence_to_sparse(seq=labels, len_seq=len_labels)

        distance = tf.edit_distance(decoded_sparse,
                                    label_sparse,
                                    normalize=False)
        # keep only utterances whose edit distance to the reference exceeds 1
        indices = tf.where(distance > 1)

        return indices
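All of these examples convert padded dense [batch, max_time] label tensors into the tf.SparseTensor form that tf.edit_distance and tf.nn.ctc_loss expect via dense_sequence_to_sparse from tfTools.tfTools. Its implementation is not shown on this page; the following is only a minimal sketch of what such a helper typically looks like, built on tf.sequence_mask.

import tensorflow as tf

def dense_sequence_to_sparse(seq, len_seq):
    """Minimal sketch (assumption, not the repository's code): keep the first
    len_seq[i] entries of each row of a padded [batch, max_time] tensor."""
    mask = tf.sequence_mask(len_seq, maxlen=tf.shape(seq)[1])
    indices = tf.where(mask)                     # positions of the valid entries
    values = tf.gather_nd(seq, indices)          # their label ids
    shape = tf.shape(seq, out_type=tf.int64)     # dense shape [batch, max_time]
    return tf.SparseTensor(indices, values, shape)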
Example #2
def policy_ctc_loss(logits, len_logits, flabels, len_flabels, batch_reward, args, ctc_merge_repeated=True):
    """
    flabels: not the ground-truth
    if len_flabels=None, means the `flabels` is sparse
    """
    from tfTools.math_tf import non_linear
    from tfTools.tfTools import dense_sequence_to_sparse

    with tf.name_scope("policy_ctc_loss"):
        if len_flabels is not None:
            flabels_sparse = dense_sequence_to_sparse(
                flabels,
                len_flabels)
        else:
            flabels_sparse = flabels

        ctc_loss_batch = tf.nn.ctc_loss(
            flabels_sparse,
            logits,
            sequence_length=len_logits,
            ignore_longer_outputs_than_inputs=True,
            ctc_merge_repeated=ctc_merge_repeated,
            time_major=False)
        ctc_loss_batch *= batch_reward  # scale each utterance's loss by its reward
        ctc_loss_batch = non_linear(
            ctc_loss_batch,
            args.model.non_linear,
            args.model.min_reward)
        loss = tf.reduce_mean(ctc_loss_batch)  # utterance-level CTC loss

    return loss, ctc_loss_batch
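A rough usage sketch of the weighting in Example #2, with dummy tensors (TF 1.x graph mode and the dense_sequence_to_sparse sketch above are assumed; the non_linear reward-shaping step is omitted). Each utterance's CTC loss is scaled by its scalar reward before averaging:

import tensorflow as tf

batch, time, vocab = 2, 5, 4                      # index vocab-1 is the CTC blank
logits = tf.random.normal([batch, time, vocab])   # stand-in for model outputs
len_logits = tf.constant([5, 5], tf.int32)
flabels = tf.constant([[1, 2, 0], [2, 1, 1]], tf.int32)
len_flabels = tf.constant([2, 3], tf.int32)
batch_reward = tf.constant([0.3, -0.1])           # e.g. wer_bias - wer per utterance

flabels_sparse = dense_sequence_to_sparse(flabels, len_flabels)
ctc_loss_batch = tf.nn.ctc_loss(flabels_sparse, logits,
                                sequence_length=len_logits,
                                ignore_longer_outputs_than_inputs=True,
                                time_major=False)
loss = tf.reduce_mean(ctc_loss_batch * batch_reward)  # reward-weighted utterance losses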
Example #3
def policy_learning(logits, len_logits, decoded_sparse, labels, len_labels, softmax_temperature, dim_output, args):
    from tfModels.CTCLoss import ctc_sample, ctc_reduce_map
    from tfTools.tfTools import dense_sequence_to_sparse, sparse_shrink, pad_to_same
    print('using policy learning')
    with tf.name_scope("policy_learning"):
        label_sparse = dense_sequence_to_sparse(labels, len_labels)
        wer_bias = tf.edit_distance(decoded_sparse, label_sparse, normalize=True)
        wer_bias = tf.stop_gradient(wer_bias)

        sampled_align = ctc_sample(logits, softmax_temperature)
        sample_sparse = ctc_reduce_map(sampled_align, id_blank=dim_output-1)
        wer = tf.edit_distance(sample_sparse, label_sparse, normalize=True)
        seq_sample, len_sample, _ = sparse_shrink(sample_sparse)

        # a sampled length of 0 means the sampling failed; fall back to the labels
        seq_sample, labels = pad_to_same([seq_sample, labels])
        seq_sample = tf.where(len_sample<1, labels, seq_sample)
        len_sample = tf.where(len_sample<1, len_labels, len_sample)

        reward = wer_bias - wer

        rl_loss, _ = policy_ctc_loss(
            logits=logits,
            len_logits=len_logits,
            flabels=seq_sample,
            len_flabels=len_sample,
            batch_reward=reward,
            args=args)

    return rl_loss
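The reward used above is the improvement of the sampled hypothesis over the greedy decode, both measured as normalized edit distance against the reference. A tiny hand-made illustration (again assuming a dense_sequence_to_sparse helper like the sketch after Example #1):

import tensorflow as tf

labels = tf.constant([[1, 2, 3, 0]]); len_labels = tf.constant([3])
greedy = tf.constant([[1, 2, 0, 0]]); len_greedy = tf.constant([2])   # baseline decode
sample = tf.constant([[1, 2, 3, 0]]); len_sample = tf.constant([3])   # sampled hypothesis

label_sparse = dense_sequence_to_sparse(labels, len_labels)
wer_bias = tf.edit_distance(dense_sequence_to_sparse(greedy, len_greedy),
                            label_sparse, normalize=True)   # baseline error rate
wer = tf.edit_distance(dense_sequence_to_sparse(sample, len_sample),
                       label_sparse, normalize=True)        # sampled error rate
reward = wer_bias - wer   # positive when the sample beats the greedy baseline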
Example #4
    def rna_loss(self, logits, len_logits, labels, len_labels, encoded=None, len_encoded=None):
        with tf.name_scope("ctc_loss"):
            labels_sparse = dense_sequence_to_sparse(
                labels,
                len_labels)
            loss = tf.nn.ctc_loss(
                labels_sparse,
                logits,
                sequence_length=len_logits,
                ctc_merge_repeated=False,  # RNA criterion: do not collapse repeated labels
                ignore_longer_outputs_than_inputs=True,
                time_major=False)

        if self.args.model.decoder.confidence_penalty:
            ls_loss = self.args.model.decoder.confidence_penalty * \
                        confidence_penalty(logits, len_logits)
            loss += ls_loss

        return loss
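Examples #4 and #6 add a confidence_penalty term whose implementation is not shown here; presumably it is an entropy-based regularizer in the spirit of Pereyra et al., discouraging over-confident (low-entropy) output distributions. A hedged sketch of such a penalty, masked over the valid frames:

import tensorflow as tf

def confidence_penalty(logits, len_logits):
    """Sketch only (assumption, not the repository's code): negative entropy of
    the per-frame softmax, averaged over the valid frames of each utterance."""
    probs = tf.nn.softmax(logits)                                    # [batch, time, vocab]
    entropy = -tf.reduce_sum(probs * tf.math.log(probs + 1e-8), -1)  # [batch, time]
    mask = tf.sequence_mask(len_logits, maxlen=tf.shape(logits)[1], dtype=tf.float32)
    return -tf.reduce_sum(entropy * mask, -1) / tf.cast(len_logits, tf.float32)  # [batch]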
Example #5
    def ctc_loss(self, logits, len_logits, labels, len_labels):
        """
        No valid path found: It is possible that no valid path is found if the
        activations for the targets are zero.
        """
        with tf.name_scope("ctc_loss"):
            if self.args.model.use_wrapctc:
                import warpctc_tensorflow
                from tfTools.tfTools import get_indices

                indices = get_indices(len_labels)
                flat_labels = tf.gather_nd(labels, indices)  # flatten padded labels into a 1-D vector
                ctc_loss = warpctc_tensorflow.ctc(
                    activations=tf.transpose(logits, [1, 0, 2]),
                    flat_labels=flat_labels,
                    label_lengths=len_labels,
                    input_lengths=len_logits,
                    blank_label=self.args.dim_output)
            else:
                # with tf.get_default_graph()._kernel_label_map({"CTCLoss": "WarpCTC"}):
                labels_sparse = dense_sequence_to_sparse(labels, len_labels)
                ctc_loss = tf.nn.ctc_loss(
                    labels_sparse,
                    logits,
                    sequence_length=len_logits,
                    ctc_merge_repeated=self.ctc_merge_repeated,
                    ignore_longer_outputs_than_inputs=True,
                    time_major=False)

        if self.args.model.policy_learning:
            from tfModels.regularization import policy_learning

            softmax_temperature = self.model.decoder.softmax_temperature
            dim_output = self.dim_output
            decoded_sparse = self.ctc_decode(logits, len_logits)
            rl_loss = policy_learning(logits, len_logits, decoded_sparse,
                                      labels, len_labels, softmax_temperature,
                                      dim_output, self.args)
            ctc_loss += self.args.model.policy_learning * rl_loss  # policy-gradient loss, weighted by its coefficient

        return ctc_loss
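The warp-ctc branch in Example #5 flattens the padded labels with get_indices and tf.gather_nd, because warpctc_tensorflow.ctc expects a flat 1-D label vector plus per-utterance label lengths. A minimal sketch of what such a get_indices helper could look like (assumption, not the repository's code):

import tensorflow as tf

def get_indices(len_seq):
    """Sketch: [sum(len_seq), 2] indices of the valid (non-padded) positions
    of a [batch, max_time] tensor, in row-major order, for tf.gather_nd."""
    mask = tf.sequence_mask(len_seq)   # [batch, max(len_seq)] boolean mask
    return tf.where(mask)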
Example #6
    def ctc_loss(self, logits, len_logits, labels, len_labels):
        """
        No valid path found: It is possible that no valid path is found if the
        activations for the targets are zero.
        return batch shape loss
        """
        with tf.name_scope("ctc_loss"):
            labels_sparse = dense_sequence_to_sparse(labels, len_labels)
            loss = tf.nn.ctc_loss(
                labels_sparse,
                logits,
                sequence_length=len_logits,
                ctc_merge_repeated=self.args.model.avg_repeated,
                ignore_longer_outputs_than_inputs=True,
                time_major=False)

        if self.args.model.decoder.confidence_penalty:
            ls_loss = self.args.model.decoder.confidence_penalty * \
                        confidence_penalty(logits, len_logits)
            loss += ls_loss

        return loss