def split_one_doc_to_true_len_sens(doc_t, split_token, padding_token,
                                   max_doc_len, max_sen_len):
    """Cut one document into sentences, keeping each sentence's own length.

    Sentence boundaries are the positions where `doc_t` equals `split_token`.

    Args:
      doc_t: rank-1 tensor of token ids, [doc_word_len].
      split_token: token id that terminates a sentence.
      padding_token: token id written wherever a `split_token` remains in the
        result; also forwarded to `cut_or_padding` for the length vector.
      max_doc_len: target number of sentences (handed to `cut_or_padding`).
      max_sen_len: target sentence length (handed to `cut_or_padding`).

    Returns:
      out_t: the stacked sentences, [max_doc_len, max_sen_len].

    Raises:
      ValueError: if `doc_t` is not a rank-1 tensor.
    """
    if len(doc_t.get_shape()) != 1:
        raise ValueError("doc_t should be a tensor with rank 1.")

    boundaries = tf.squeeze(tf.where(tf.equal(doc_t, split_token)), axis=1)
    boundaries.set_shape([None])

    # The first sentence runs from position 0 up to and including its split
    # token; every later sentence spans the gap between two split tokens.
    first_len = boundaries[:1] + 1
    later_lens = boundaries[1:] - boundaries[:-1]
    sentence_lens = tf.concat([first_len, later_lens], axis=0)
    # NOTE(review): the length vector is padded with `padding_token`; that is
    # only a neutral length if the padding id is 0 -- confirm with callers.
    sentence_lens = cut_or_padding(
        sentence_lens, max_doc_len, padding_token=padding_token)

    # Only the prefix covered by the retained sentence lengths is split.
    covered_len = tf.reduce_sum(sentence_lens)
    pieces = tf.split(doc_t[:covered_len], sentence_lens)
    out_t = tf.stack([cut_or_padding(piece, max_sen_len) for piece in pieces])

    # Split tokens that survived inside the sentences become padding.
    fill = tf.multiply(tf.ones_like(out_t, dtype=tf.int32), padding_token)
    return tf.where(tf.equal(out_t, split_token), fill, out_t)
def cross_entropy(logits,
                  labels,
                  input_length=None,
                  label_length=None,
                  smoothing=0.0,
                  reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS):
    """Softmax cross entropy for classification and sequence classification.

    Args:
      logits: unnormalised class scores; the last axis is the class axis.
      labels: class ids (one rank lower than `logits`) or labels that already
        have the same rank as `logits`.
      input_length: accepted for interface compatibility and discarded.
      label_length: for sequence tasks, the target sequence length, e.g.
        `a b c </s>` has length 4. When given it is turned into loss weights
        with `utils.len_to_mask`; otherwise every position has weight 1.0.
      smoothing: label smoothing factor passed to the loss.
      reduction: reduction mode passed to the loss.

    Returns:
      The tensor produced by `tf.losses.softmax_cross_entropy`.
    """
    del input_length

    def _ids_to_onehot():
        return tf.one_hot(labels, tf.shape(logits)[-1], dtype=tf.int32)

    def _labels_as_is():
        return labels

    # Labels exactly one rank below the logits are treated as class ids.
    onehot_labels = tf.cond(
        pred=tf.equal(tf.rank(logits) - tf.rank(labels), 1),
        true_fn=_ids_to_onehot,
        false_fn=_labels_as_is)

    weights = 1.0 if label_length is None else utils.len_to_mask(label_length)

    return tf.losses.softmax_cross_entropy(
        onehot_labels=onehot_labels,
        logits=logits,
        weights=weights,
        label_smoothing=smoothing,
        reduction=reduction)
def _create_topk_unique(inputs, k):
    """Creates the top k values in sorted order with indices.

    The row maximum is taken `k` times; after each pass every element that is
    greater than or equal to the value just selected is replaced by -inf so
    the next pass picks a strictly smaller value. An index is then read back
    from the low bits of each selected value's float32 bit pattern.

    NOTE(review): reading indices from the low bits is only meaningful if the
    low ceil(log2(width)) bits of every element already hold its column index
    (the encoding `_create_make_unique` writes) -- confirm at the call site.

    Args:
      inputs: rank-2 tensor with statically known dimensions (`width` is
        converted with `int()`, and both dimensions are used to build
        constants). It is combined with float32 constants, so it is presumably
        float32 -- TODO confirm.
      k: Python int, number of values to select per row.

    Returns:
      A tuple `(topk_r2, topk_indices_r2)`: a float32 tensor of shape
      [height, k] whose column i holds the value selected in pass i, and an
      int32 tensor of the same shape holding the low bits of those values.
    """
    height = inputs.shape[0]
    width = inputs.shape[1]
    neg_inf_r0 = tf.constant(-np.inf, dtype=tf.float32)
    ones = tf.ones([height, width], dtype=tf.float32)
    neg_inf_r2 = ones * neg_inf_r0
    # NaNs are mapped to -inf before any maximum is taken.
    inputs = tf.where(tf.is_nan(inputs), neg_inf_r2, inputs)

    # Select the current largest value k times and keep them in topk_r2.
    tmp = inputs
    topk_r2 = tf.zeros([height, k], dtype=tf.float32)
    for i in range(k):
        kth_order_statistic = tf.reduce_max(tmp, axis=1, keepdims=True)
        # Boolean mask that is True only in column i of every row.
        k_mask = tf.tile(
            tf.expand_dims(tf.equal(tf.range(k), tf.fill([k], i)), 0),
            [height, 1])
        topk_r2 = tf.where(k_mask, tf.tile(kth_order_statistic, [1, k]), topk_r2)
        # The mask is rebuilt from the original `inputs` (not `tmp`): anything
        # >= the value just selected, including all earlier selections, is
        # excluded from the next pass.
        ge_r2 = tf.greater_equal(inputs, tf.tile(kth_order_statistic, [1, width]))
        tmp = tf.where(ge_r2, neg_inf_r2, inputs)

    # Mask covering the low ceil(log2(width)) bits of a 32-bit word.
    log2_ceiling = int(math.ceil(math.log(float(int(width)), 2)))
    next_power_of_two = 1 << log2_ceiling
    count_mask = next_power_of_two - 1
    mask_r0 = tf.constant(count_mask)
    mask_r2 = tf.fill([height, k], mask_r0)
    # Reinterpret the selected floats as int32 and keep only the low bits.
    topk_r2_s32 = tf.bitcast(topk_r2, tf.int32)
    topk_indices_r2 = tf.bitwise.bitwise_and(topk_r2_s32, mask_r2)
    return topk_r2, topk_indices_r2
def masked_softmax(logits, mask, axis):
    """Compute softmax with input mask.

    Each exponentiated logit is multiplied by `mask` before normalisation, so
    positions where the mask is zero get probability zero.

    Args:
      logits: tensor of unnormalised scores.
      mask: tensor multiplied element-wise with `exp(logits)`.
      axis: axis (or axes) along which to normalise.

    Returns:
      Tensor of the same shape as `logits`. Slices whose masked sum is zero
      (for example an all-zero mask) come back as all zeros instead of NaN.
    """
    e_logits = tf.exp(logits)
    masked_e = tf.multiply(e_logits, mask)
    # `keepdims` replaces the deprecated `keep_dims` spelling, which newer
    # TensorFlow releases no longer accept.
    sum_masked_e = tf.reduce_sum(masked_e, axis, keepdims=True)
    ones = tf.ones_like(sum_masked_e)
    # If every position along `axis` is masked out the sum is zero; divide by
    # one instead so the result is 0 rather than 0 / 0.
    sum_masked_e_safe = tf.where(tf.equal(sum_masked_e, 0), ones, sum_masked_e)
    return masked_e / sum_masked_e_safe
def ctc_lambda_loss(logits, labels, input_length, label_length, blank_index=0):
    """CTC loss function.

    Args:
      logits: (B, T, D) tensor.
      labels: (B, T) dense label tensor.
      input_length: (B,) or (B, 1), input length of the encoder.
      label_length: (B,) or (B, 1), label length. It is only shape-checked
        here; presumably `ctc_data_transform` handles the dense-to-sparse
        conversion of `labels` -- TODO confirm.
      blank_index: index of the blank label, forwarded to
        `ctc_data_transform`.

    Returns:
      The result of `tf.nn.ctc_loss`, i.e. one loss value per batch element
      (not a scalar).
    """

    def _as_batch_vector(lengths):
        # Reshape (B,) or (B, 1) lengths to (B,). A bare tf.squeeze would also
        # drop the batch axis when B == 1 and yield a scalar; reshaping to the
        # leading dimension keeps the batch axis and still fails at run time
        # if there is more than one length per example.
        return tf.cast(
            tf.reshape(lengths, [tf.shape(lengths)[0]]), tf.int32)

    ilen = _as_batch_vector(input_length)
    olen = _as_batch_vector(label_length)

    deps = [
        tf.assert_rank(labels, 2, name='label_rank_check'),
        tf.assert_rank(logits, 3, name='logits_rank_check'),
        tf.assert_rank(ilen, 1, name='src_len_rank_check'),  # input_length
        tf.assert_rank(olen, 1, name='tgt_len_rank_check'),  # output_length
    ]

    labels, logits = ctc_data_transform(labels, logits, blank_index)

    with tf.control_dependencies(deps):
        # blank index is consistent with Espnet, zero
        batch_loss = tf.nn.ctc_loss(
            labels=labels,
            inputs=logits,
            sequence_length=ilen,
            time_major=False,
            preprocess_collapse_repeated=False,
            ctc_merge_repeated=True,
            ignore_longer_outputs_than_inputs=False)
    return batch_loss
def load_textline_dataset(paths, column_num):
    """Load raw data for text task.

    Args:
      paths: file path(s) handed to `tf.data.TextLineDataset`.
      column_num: number of tab-separated columns a line must have; lines
        with any other column count are dropped.

    Returns:
      A tuple of `column_num` datasets, the i-th one yielding column i of
      every retained line.
    """
    lines = tf.data.TextLineDataset(paths)
    fields = lines.map(
        lambda x: tf.strings.split(x, sep="\t", result_type="RaggedTensor"))
    fields = fields.filter(
        lambda line: tf.equal(tf.size(line), column_num))
    # Each lambda closes over the loop variable `i`; the datasets are built
    # one by one while `i` still holds the matching column index.
    return tuple([fields.map(lambda x: x[i]) for i in range(column_num)])
def compute_sen_lens(inputs, padding_token=0):
    """Count how many words are in a sentence.

    Args:
      inputs: token ids, [..., time_steps].
      padding_token: id that does not count as a word.

    Returns:
      sen_lens: [...], the number of non-padding tokens along the last axis,
      with counts equal to `utils.PAD_IDX` replaced by 1.
    """
    is_word = tf.not_equal(inputs, padding_token)
    word_counts = tf.reduce_sum(tf.cast(is_word, tf.int32), axis=-1)
    # NOTE(review): the count is compared with `utils.PAD_IDX`; this clamps
    # empty sentences to length 1 only if that constant is 0 -- confirm.
    return tf.where(
        tf.equal(word_counts, utils.PAD_IDX),
        x=tf.ones_like(word_counts),
        y=word_counts)
def compute_lens(inputs, max_len):
    """Count sequence lengths.

    Args:
      inputs: [batch_size, max_len] tensor in which zero entries are treated
        as padding.
      max_len: size of the second axis.

    Returns:
      lens: [batch_size]. `max_len` minus the argmax position in the reversed
      non-zero indicator row, or 0 for rows that sum to zero.
    """
    reversed_rows = tf.reverse(inputs, axis=[1])
    nonzero_flags = tf.cast(tf.cast(reversed_rows, tf.bool), tf.int32)
    # Position of a maximal flag in the reversed row; presumably the first
    # one, i.e. the number of trailing zeros -- relies on argmax tie-breaking.
    offset_from_end = tf.argmax(nonzero_flags, axis=1, output_type=tf.int32)
    lens = max_len - offset_from_end

    # Rows whose entries sum to zero are reported as empty.
    row_sums = tf.reduce_sum(inputs, axis=1)
    empty_lens = tf.zeros_like(lens, dtype=tf.int32)
    return tf.where(tf.equal(row_sums, 0), empty_lens, lens)
def compute_mel_filterbank_features(waveforms,
                                    sample_rate=16000,
                                    preemphasis=0.97,
                                    frame_length=0.025,
                                    frame_step=0.010,
                                    fft_length=None,
                                    lower_edge_hertz=80.0,
                                    upper_edge_hertz=7600.0,
                                    num_mel_bins=80,
                                    log_noise_floor=1e-3,
                                    apply_mask=True):
    """Extract mel-filterbank features with `powspec_feat` and `fbank_feat`.

    Args:
      waveforms: float32 tensor with shape [max_len, nchannels].
      sample_rate: sampling rate of the waveform.
      preemphasis: waveform high-pass filtering constant.
      frame_length: frame length, passed as `winlen`. The 0.025 default
        suggests seconds rather than milliseconds -- TODO confirm.
      frame_step: frame step, passed as `winstep` (same unit caveat).
      fft_length: number of fft bins; 512 is used when this is falsy.
      lower_edge_hertz: lowest frequency of the filterbank.
      upper_edge_hertz: highest frequency of the filterbank.
      num_mel_bins: filterbank size.
      log_noise_floor: accepted for interface compatibility and discarded.
      apply_mask: accepted for interface compatibility and discarded.

    Returns:
      The filterbank tensor from `fbank_feat`. When that tensor has rank 3,
      only index 0 of its first axis is returned, i.e. a [time, feat_dim]
      slice.
    """
    del log_noise_floor, apply_mask

    nfft = fft_length if fft_length else 512

    spectrogram = powspec_feat(
        waveforms,
        sr=sample_rate,
        nfft=nfft,
        winlen=frame_length,
        winstep=frame_step,
        lowfreq=lower_edge_hertz,
        highfreq=upper_edge_hertz,
        preemph=preemphasis)

    # [channels, time, feat_dim]
    all_channels = fbank_feat(
        spectrogram,
        sr=sample_rate,
        feature_size=num_mel_bins,
        nfft=nfft,
        lowfreq=lower_edge_hertz,
        highfreq=upper_edge_hertz)

    # [time, feat_dim]: keep only the first channel of a rank-3 result.
    return tf.cond(
        tf.equal(tf.rank(all_channels), 3),
        true_fn=lambda: all_channels[0, :, :],
        false_fn=lambda: all_channels)
def grow_topk(i, alive_seq, alive_log_probs, states):
    """Inner beam search loop: extend every alive beam by one symbol.

    `batch_size`, `beam_size`, `vocab_size`, `alpha`, `eos_id` and
    `symbols_to_logits_fn` are free variables taken from the enclosing scope,
    which is not part of this block.

    Args:
      i: current decoding step; used in the length penalty and passed to
        `symbols_to_logits_fn` when `states` is non-empty.
      alive_seq: sequences decoded so far, reshaped here to
        [batch_size * beam_size, decoded_length].
      alive_log_probs: log probabilities of the alive sequences; expanded on
        axis 2 before being added to the candidate log probabilities.
      states: nested structure of decoding states; may be empty/falsy.

    Returns:
      A tuple `(topk_seq, topk_log_probs, topk_scores, topk_finished, states)`
      for the `2 * beam_size` best candidates per batch entry.
    """
    # (batch_size * beam_size, decoded_length)
    merged_ids = tf.reshape(alive_seq, [batch_size * beam_size, -1])

    if states:
        merged_states = nest.map_structure(_merge_beam_dim, states)
        merged_logits, merged_states = symbols_to_logits_fn(
            merged_ids, i, merged_states)
        states = nest.map_structure(
            lambda t: _unmerge_beam_dim(t, batch_size, beam_size),
            merged_states)
    else:
        merged_logits = symbols_to_logits_fn(merged_ids)

    beam_logits = tf.reshape(merged_logits, [batch_size, beam_size, -1])
    step_log_probs = log_prob_from_logits(beam_logits)
    total_log_probs = step_log_probs + tf.expand_dims(alive_log_probs, axis=2)

    # Scores are the log probabilities divided by the length penalty.
    penalty = tf.pow(((5. + tf.to_float(i + 1)) / 6.), alpha)
    penalised_scores = total_log_probs / penalty
    scores_per_batch = tf.reshape(penalised_scores,
                                  [-1, beam_size * vocab_size])

    # Twice the beam size is kept at this stage.
    topk_scores, flat_topk_ids = tf.nn.top_k(scores_per_batch, k=beam_size * 2)
    # Undo the penalty to recover the raw log probabilities.
    topk_log_probs = topk_scores * penalty

    # Unflatten the ids: a flat index is beam * vocab_size + symbol.
    source_beam = flat_topk_ids // vocab_size
    topk_ids = flat_topk_ids % vocab_size

    batch_pos = compute_batch_indices(batch_size, beam_size * 2)
    gather_coords = tf.stack([batch_pos, source_beam], axis=2)

    # Fetch the prefix (and state) each candidate grows from.
    topk_seq = tf.gather_nd(alive_seq, gather_coords)
    if states:
        states = nest.map_structure(
            lambda state: tf.gather_nd(state, gather_coords), states)

    topk_seq = tf.concat(
        [topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)
    topk_finished = tf.equal(topk_ids, eos_id)

    return topk_seq, topk_log_probs, topk_scores, topk_finished, states
def accuracy(logits, labels):
    """Fraction of positions whose argmax prediction equals the label.

    Args:
      logits: [B, ..., D]
      labels: [B, ...]

    Returns:
      Accuracy tensor: the mean over all positions of (argmax == label).
    """
    with tf.name_scope('accuracy'):
        rank_ok = tf.assert_equal(tf.rank(logits), tf.rank(labels) + 1)
        shape_ok = tf.assert_equal(tf.shape(logits)[:-1], tf.shape(labels))
        # Run the shape checks before anything is compared.
        with tf.control_dependencies([rank_ok, shape_ok]):
            predicted = tf.argmax(logits, axis=-1, output_type=tf.int64)
            expected = tf.cast(labels, tf.int64)
            hits = tf.cast(tf.equal(predicted, expected), dtype=tf.float32)
            return tf.reduce_mean(hits)
def delta_delta(feat, order=2):
    """Compute delta features with `py_x_ops.delta_delta`.

    Args:
      feat: a tensor of shape [nframe, nfbank] or [nframe, nfbank, 1].
      order: order passed to the op; the result is reshaped to have
        `order + 1` entries on its last axis.

    Returns:
      Tensor of shape [nframe, nfbank, order + 1] ([nframe, nfbank, 3] for
      the default order).
    """
    # Drop the trailing axis of a rank-3 input.
    feat_2d = tf.cond(
        tf.equal(tf.rank(feat), 3),
        true_fn=lambda: feat[:, :, 0],
        false_fn=lambda: feat)

    dims = tf.shape(feat_2d)
    nframe, nfbank = dims[0], dims[1]

    # [nframe, nfbank * 3] per the original note; reshaped below.
    delta = py_x_ops.delta_delta(feat_2d, order=order)
    return tf.reshape(delta, (nframe, nfbank, (order + 1)))
def labels_blankid_to_last(labels, blank_index, num_class=None):
    """Move the blank label from `blank_index` to `num_class - 1`.

    Values below `blank_index` are unchanged, values above it move down by
    one, and the blank itself becomes `num_class - 1`.

    Args:
      labels: labels handed to `transform_preprocess`, whose result is read
        as a sparse tensor (`.indices`, `.values`, `.dense_shape`).
      blank_index: current id of the blank label.
      num_class: total number of classes; must not be None.

    Returns:
      A `tf.SparseTensor` with the same indices and dense shape and the
      remapped values.
    """
    assert num_class is not None, 'The num_class should not be None!'

    sparse_labels = transform_preprocess(
        labels=labels, blank_index=blank_index, num_class=num_class)
    values = sparse_labels.values

    # Step 1: blank -> num_class (one past the last valid id).
    num_class_fill = tf.zeros_like(values, dtype=tf.int32) + num_class
    blank_moved = tf.where(
        tf.equal(values, blank_index), num_class_fill, values)

    # Step 2: everything at or above the old blank position moves down by one.
    remapped = tf.where(blank_moved < blank_index, blank_moved,
                        blank_moved - 1)

    return tf.SparseTensor(
        indices=sparse_labels.indices,
        values=remapped,
        dense_shape=sparse_labels.dense_shape)
def _create_make_unique(inputs):
    """Replaces the lower bits of each element with iota.

    Every element of a rank-2 float32 tensor is reinterpreted as int32, its
    low ceil(log2(width)) bits are cleared, and its column index is written
    into them. Elements that are +0.0 or -0.0 are first OR-ed with the
    smallest-normal bit pattern, with the sign bit kept.

    Args:
      inputs: rank-2 tensor with statically known dimensions, bitcast to
        int32 here (so a 32-bit float type is expected).

    Returns:
      A float32 tensor of the same shape whose low bits encode the column
      index of each element.

    Raises:
      ValueError: if `inputs` is not rank-2.
    """
    if inputs.shape.ndims != 2:
        raise ValueError("Input of top_k_with_unique must be rank-2 "
                         "but got: %s" % inputs.shape)

    height = inputs.shape[0]
    width = inputs.shape[1]
    zeros = tf.zeros([height, width], dtype=tf.int32)

    # Mask that clears the low ceil(log2(width)) bits.
    log2_ceiling = int(math.ceil(math.log(int(width), 2)))
    next_power_of_two = 1 << log2_ceiling
    count_mask = ~(next_power_of_two - 1)
    count_mask_r0 = tf.constant(count_mask)
    count_mask_r2 = tf.fill([height, width], count_mask_r0)

    # Bit pattern with only bit 23 set.
    smallest_normal = 1 << 23
    smallest_normal_r0 = tf.constant(smallest_normal, dtype=tf.int32)
    smallest_normal_r2 = tf.fill([height, width], smallest_normal_r0)

    # All bits except the sign bit (0x7FFFFFFF). Written as (1 << 31) - 1
    # because the Python value of ~(1 << 31) is -2147483649, which is outside
    # the int32 range and only yields this pattern through implicit
    # wrap-around during conversion.
    low_bit_mask = (1 << 31) - 1
    low_bit_mask_r0 = tf.constant(low_bit_mask, dtype=tf.int32)
    low_bit_mask_r2 = tf.fill([height, width], low_bit_mask_r0)

    # Column index of every element, [height, width].
    iota = tf.tile(
        tf.expand_dims(tf.range(width, dtype=tf.int32), 0), [height, 1])

    # Replace +/-0.0 (magnitude bits all zero) with the input OR-ed with the
    # smallest-normal pattern, which keeps the sign bit.
    input_r2 = tf.bitcast(inputs, tf.int32)
    abs_r2 = tf.bitwise.bitwise_and(input_r2, low_bit_mask_r2)
    if_zero_r2 = tf.equal(abs_r2, zeros)
    smallest_normal_preserving_sign_r2 = tf.bitwise.bitwise_or(
        input_r2, smallest_normal_r2)
    input_no_zeros_r2 = tf.where(if_zero_r2,
                                 smallest_normal_preserving_sign_r2, input_r2)

    # Clear the low bits, then store the column index in them.
    and_r2 = tf.bitwise.bitwise_and(input_no_zeros_r2, count_mask_r2)
    or_r2 = tf.bitwise.bitwise_or(and_r2, iota)
    return tf.bitcast(or_r2, tf.float32)
def fbank_feat(powspec,
               sr=8000,
               feature_size=40,
               nfft=512,
               lowfreq=0,
               highfreq=None):
    """Compute filterbank features from a power spectrogram.

    Args:
      powspec: [audio_channels, spectrogram_length, spectrogram_feat_dim];
        a rank-2 input gets a leading axis of size 1 added.
      sr: sample rate passed to `py_x_ops.fbank`.
      feature_size: number of filterbank channels.
      nfft: accepted for interface compatibility and discarded.
      lowfreq: lower frequency limit of the filterbank.
      highfreq: upper frequency limit of the filterbank.

    Returns:
      [audio_channels, nframe, nfbank]
    """
    del nfft

    # A rank-2 spectrogram is treated as a single channel.
    has_no_channel_axis = tf.equal(tf.rank(powspec), 2)
    with_channels = tf.cond(has_no_channel_axis,
                            lambda: tf.expand_dims(powspec, 0),
                            lambda: powspec)

    return py_x_ops.fbank(
        with_channels,
        sr,
        filterbank_channel_count=feature_size,
        lower_frequency_limit=lowfreq,
        upper_frequency_limit=highfreq,
    )
def mask_outputs(origin_outputs):
    """Mask position embedding.

    Args:
      origin_outputs: a pair `(inputs, outputs)` of tensors.

    Returns:
      `outputs`, except that wherever `inputs` equals 0 the value from
      `inputs` (i.e. 0) is used instead.
    """
    inputs, outputs = origin_outputs
    input_is_zero = tf.equal(inputs, 0)
    return tf.where(input_is_zero, inputs, outputs)