def deserve_idx(self, decoded, len_decoded, labels, len_labels):
    """
    If an utterance is already decoded correctly during training, do not
    train on it: return the indices of utterances whose edit distance to
    the reference exceeds 1.
    """
    decoded_sparse = dense_sequence_to_sparse(seq=decoded, len_seq=len_decoded)
    label_sparse = dense_sequence_to_sparse(seq=labels, len_seq=len_labels)
    distance = tf.edit_distance(decoded_sparse, label_sparse, normalize=False)
    indices = tf.where(distance > 1)

    return indices
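# deserve_idx above and the losses below convert padded dense label batches
# to tf.SparseTensor via dense_sequence_to_sparse (imported from
# tfTools.tfTools elsewhere in the repo). A minimal sketch of what such a
# helper typically looks like, assuming seq is [batch, time] and len_seq
# holds per-row lengths; this is an assumption, not the repo's code:
import tensorflow as tf

def dense_sequence_to_sparse_sketch(seq, len_seq):
    mask = tf.sequence_mask(len_seq, maxlen=tf.shape(seq)[1])
    indices = tf.where(mask)              # positions of the valid entries
    values = tf.gather_nd(seq, indices)
    return tf.SparseTensor(indices, values, tf.cast(tf.shape(seq), tf.int64))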
def policy_ctc_loss(logits, len_logits, flabels, len_flabels, batch_reward,
                    args, ctc_merge_repeated=True):
    """
    flabels: pseudo labels, not necessarily the ground truth.
    If len_flabels is None, `flabels` is already a SparseTensor.
    """
    from tfTools.math_tf import non_linear
    from tfTools.tfTools import dense_sequence_to_sparse

    with tf.name_scope("policy_ctc_loss"):
        if len_flabels is not None:
            flabels_sparse = dense_sequence_to_sparse(flabels, len_flabels)
        else:
            flabels_sparse = flabels

        ctc_loss_batch = tf.nn.ctc_loss(
            flabels_sparse,
            logits,
            sequence_length=len_logits,
            ignore_longer_outputs_than_inputs=True,
            ctc_merge_repeated=ctc_merge_repeated,
            time_major=False)

        ctc_loss_batch *= batch_reward
        ctc_loss_batch = non_linear(
            ctc_loss_batch,
            args.model.non_linear,
            args.model.min_reward)
        loss = tf.reduce_mean(ctc_loss_batch)  # utterance-level ctc loss

    return loss, ctc_loss_batch
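# The reward weighting above is a REINFORCE-style estimator: each utterance's
# differentiable CTC loss is scaled by a scalar reward that is treated as a
# constant during backprop. A self-contained illustration of that weighting
# with hypothetical tensors (not the repo's code):
import tensorflow as tf

def reward_weighted_loss_sketch(per_utt_loss, per_utt_reward):
    # per_utt_loss: [batch], differentiable; per_utt_reward: [batch], a score
    # such as (baseline WER - sampled WER) that should carry no gradient
    reward = tf.stop_gradient(per_utt_reward)
    return tf.reduce_mean(per_utt_loss * reward)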
def policy_learning(logits, len_logits, decoded_sparse, labels, len_labels,
                    softmax_temperature, dim_output, args):
    from tfModels.CTCLoss import ctc_sample, ctc_reduce_map
    from tfTools.tfTools import dense_sequence_to_sparse, sparse_shrink, pad_to_same

    print('using policy learning')
    with tf.name_scope("policy_learning"):
        label_sparse = dense_sequence_to_sparse(labels, len_labels)

        # baseline: WER of the decoded hypothesis, treated as a constant
        wer_bias = tf.edit_distance(decoded_sparse, label_sparse, normalize=True)
        wer_bias = tf.stop_gradient(wer_bias)

        # sample an alignment from the model and collapse it to a label sequence
        sampled_align = ctc_sample(logits, softmax_temperature)
        sample_sparse = ctc_reduce_map(sampled_align, id_blank=dim_output-1)
        wer = tf.edit_distance(sample_sparse, label_sparse, normalize=True)

        seq_sample, len_sample, _ = sparse_shrink(sample_sparse)

        # a sampled sequence of length 0 means sampling failed;
        # fall back to the ground-truth labels for those utterances
        seq_sample, labels = pad_to_same([seq_sample, labels])
        seq_sample = tf.where(len_sample < 1, labels, seq_sample)
        len_sample = tf.where(len_sample < 1, len_labels, len_sample)

        reward = wer_bias - wer
        rl_loss, _ = policy_ctc_loss(
            logits=logits,
            len_logits=len_logits,
            flabels=seq_sample,
            len_flabels=len_sample,
            batch_reward=reward,
            args=args)

    return rl_loss
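# policy_learning above relies on helpers from tfModels.CTCLoss: ctc_sample
# draws a frame-level alignment from the model's own temperature-scaled
# output distributions, and ctc_reduce_map collapses repeats and removes
# blanks to obtain a label sequence. A minimal sketch of the sampling step,
# assuming batch-major logits [batch, time, dim] and TF >= 1.13
# (tf.random.categorical); this is an assumption, not the repo's code:
import tensorflow as tf

def ctc_sample_sketch(logits, softmax_temperature):
    shape = tf.shape(logits)
    flat_logits = tf.reshape(logits / softmax_temperature, [-1, shape[2]])
    samples = tf.random.categorical(flat_logits, num_samples=1)  # [batch*time, 1]
    return tf.reshape(samples, [shape[0], shape[1]])             # [batch, time]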
def rna_loss(self, logits, len_logits, labels, len_labels,
             encoded=None, len_encoded=None):
    with tf.name_scope("ctc_loss"):
        labels_sparse = dense_sequence_to_sparse(labels, len_labels)
        loss = tf.nn.ctc_loss(
            labels_sparse,
            logits,
            sequence_length=len_logits,
            ctc_merge_repeated=False,
            ignore_longer_outputs_than_inputs=True,
            time_major=False)

        if self.args.model.decoder.confidence_penalty:
            ls_loss = self.args.model.decoder.confidence_penalty * \
                confidence_penalty(logits, len_logits)
            loss += ls_loss

    return loss
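# rna_loss above (and ctc_loss below) optionally add a confidence penalty,
# i.e. a negative-entropy regulariser on the per-frame output distributions
# (Pereyra et al., 2017). A minimal sketch of such a term, assuming
# batch-major logits [batch, time, dim]; the repo's confidence_penalty may
# differ in normalisation:
import tensorflow as tf

def confidence_penalty_sketch(logits, len_logits):
    probs = tf.nn.softmax(logits)
    log_probs = tf.nn.log_softmax(logits)
    neg_entropy = tf.reduce_sum(probs * log_probs, axis=-1)   # [batch, time]
    mask = tf.sequence_mask(len_logits, maxlen=tf.shape(logits)[1],
                            dtype=logits.dtype)
    # average negative entropy over the valid frames of each utterance
    return tf.reduce_sum(neg_entropy * mask, axis=-1) / tf.cast(len_logits, logits.dtype)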
def ctc_loss(self, logits, len_logits, labels, len_labels):
    """
    No valid path found: it is possible that no valid path is found if the
    activations for the targets are zero.
    """
    with tf.name_scope("ctc_loss"):
        if self.args.model.use_wrapctc:
            import warpctc_tensorflow
            from tfTools.tfTools import get_indices

            # WarpCTC expects the labels as one flat 1-D vector
            indices = get_indices(len_labels)
            flat_labels = tf.gather_nd(labels, indices)
            ctc_loss = warpctc_tensorflow.ctc(
                activations=tf.transpose(logits, [1, 0, 2]),
                flat_labels=flat_labels,
                label_lengths=len_labels,
                input_lengths=len_logits,
                blank_label=self.args.dim_output)
        else:
            # with tf.get_default_graph()._kernel_label_map({"CTCLoss": "WarpCTC"}):
            labels_sparse = dense_sequence_to_sparse(labels, len_labels)
            ctc_loss = tf.nn.ctc_loss(
                labels_sparse,
                logits,
                sequence_length=len_logits,
                ctc_merge_repeated=self.ctc_merge_repeated,
                ignore_longer_outputs_than_inputs=True,
                time_major=False)

        if self.args.model.policy_learning:
            from tfModels.regularization import policy_learning

            softmax_temperature = self.model.decoder.softmax_temperature
            dim_output = self.dim_output
            decoded_sparse = self.ctc_decode(logits, len_logits)
            rl_loss = policy_learning(
                logits,
                len_logits,
                decoded_sparse,
                labels,
                len_labels,
                softmax_temperature,
                dim_output,
                self.args)
            ctc_loss += self.args.model.policy_learning * rl_loss

    return ctc_loss
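# The WarpCTC branch above needs the padded label matrix flattened into one
# 1-D vector of valid labels. A sketch of a get_indices-style helper under
# that assumption (hypothetical; the repo's helper is in tfTools.tfTools):
import tensorflow as tf

def get_indices_sketch(len_labels):
    """Indices [n, 2] of the non-padding entries of a padded [batch, max_len]
    label matrix, given per-utterance lengths."""
    mask = tf.sequence_mask(len_labels)
    return tf.where(mask)

# usage (hypothetical shapes):
#   indices = get_indices_sketch(len_labels)
#   flat_labels = tf.gather_nd(labels, indices)  # 1-D vector of valid labels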
def ctc_loss(self, logits, len_logits, labels, len_labels):
    """
    No valid path found: it is possible that no valid path is found if the
    activations for the targets are zero.
    Returns the per-utterance (batch-shaped) loss.
    """
    with tf.name_scope("ctc_loss"):
        labels_sparse = dense_sequence_to_sparse(labels, len_labels)
        loss = tf.nn.ctc_loss(
            labels_sparse,
            logits,
            sequence_length=len_logits,
            ctc_merge_repeated=self.args.model.avg_repeated,
            ignore_longer_outputs_than_inputs=True,
            time_major=False)

        if self.args.model.decoder.confidence_penalty:
            ls_loss = self.args.model.decoder.confidence_penalty * \
                confidence_penalty(logits, len_logits)
            loss += ls_loss

    return loss