Example #1
def split_one_doc_to_true_len_sens(doc_t, split_token, padding_token,
                                   max_doc_len, max_sen_len):
    """
  Split a document to sentences with true sentence lengths.
  doc_t: [doc_word_len]
  out_t: [max_doc_len, max_sen_len]
  """
    if len(doc_t.get_shape()) == 1:
        split_token_index = tf.squeeze(tf.where(tf.equal(doc_t, split_token)),
                                       axis=1)
        split_token_index.set_shape([None])
        split_len_part_1 = split_token_index[:1] + 1
        split_len_part_2 = split_token_index[1:] - split_token_index[:-1]
        split_lens = tf.concat([split_len_part_1, split_len_part_2], axis=0)
        split_lens = cut_or_padding(split_lens,
                                    max_doc_len,
                                    padding_token=padding_token)
        new_doc_len = tf.reduce_sum(split_lens)
        splited_sentences = tf.split(doc_t[:new_doc_len], split_lens)
        splited_sentences = [
            cut_or_padding(s, max_sen_len) for s in splited_sentences
        ]
        out_t = tf.stack(splited_sentences)
        padding_tokens = tf.multiply(tf.ones_like(out_t, dtype=tf.int32),
                                     padding_token)
        out_t = tf.where(tf.equal(out_t, split_token), padding_tokens, out_t)
        return out_t

    raise ValueError("doc_t should be a tensor with rank 1.")
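
The `cut_or_padding` helper is assumed above but not shown; here is a minimal sketch of what it likely does (a hypothetical reconstruction, the real implementation may differ): truncate a rank-1 tensor to max_len or right-pad it with padding_token.

import tensorflow as tf

def cut_or_padding(t, max_len, padding_token=0):
    # right-pad with padding_token if shorter than max_len, then truncate
    pad_len = tf.maximum(max_len - tf.shape(t)[0], 0)
    padded = tf.pad(t, [[0, pad_len]], constant_values=padding_token)
    return padded[:max_len]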
Example #2
def cross_entropy(logits,
                  labels,
                  input_length=None,
                  label_length=None,
                  smoothing=0.0,
                  reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS):
    '''
    Cross-entropy loss for classification and sequence classification.
    :param label_length: for sequence tasks, the target sequence length,
        e.g. `a b c </s>` has length 4
    '''
    del input_length

    onehot_labels = tf.cond(pred=tf.equal(
        tf.rank(logits) - tf.rank(labels), 1),
                            true_fn=lambda: tf.one_hot(
                                labels, tf.shape(logits)[-1], dtype=tf.int32),
                            false_fn=lambda: labels)

    if label_length is not None:
        weights = utils.len_to_mask(label_length)
    else:
        weights = 1.0

    loss = tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels,
                                           logits=logits,
                                           weights=weights,
                                           label_smoothing=smoothing,
                                           reduction=reduction)

    return loss
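
A minimal usage sketch for the plain classification path (TF1 graph/session style assumed; `utils` is only needed when label_length is given, so it is unused here):

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])  # [B, D]
labels = tf.constant([0, 1])                              # [B]
loss = cross_entropy(logits, labels, smoothing=0.1)
with tf.Session() as sess:
    print(sess.run(loss))  # scalar loss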
Example #3
def _create_topk_unique(inputs, k):
    """Creates the top k values in sorted order with indices."""
    height = inputs.shape[0]
    width = inputs.shape[1]
    neg_inf_r0 = tf.constant(-np.inf, dtype=tf.float32)
    ones = tf.ones([height, width], dtype=tf.float32)
    neg_inf_r2 = ones * neg_inf_r0
    inputs = tf.where(tf.is_nan(inputs), neg_inf_r2, inputs)

    tmp = inputs
    topk_r2 = tf.zeros([height, k], dtype=tf.float32)
    for i in range(k):
        kth_order_statistic = tf.reduce_max(tmp, axis=1, keepdims=True)
        k_mask = tf.tile(
            tf.expand_dims(tf.equal(tf.range(k), tf.fill([k], i)), 0),
            [height, 1])
        topk_r2 = tf.where(k_mask, tf.tile(kth_order_statistic, [1, k]),
                           topk_r2)
        ge_r2 = tf.greater_equal(inputs,
                                 tf.tile(kth_order_statistic, [1, width]))
        tmp = tf.where(ge_r2, neg_inf_r2, inputs)

    log2_ceiling = int(math.ceil(math.log(float(int(width)), 2)))
    next_power_of_two = 1 << log2_ceiling
    count_mask = next_power_of_two - 1
    mask_r0 = tf.constant(count_mask)
    mask_r2 = tf.fill([height, k], mask_r0)
    topk_r2_s32 = tf.bitcast(topk_r2, tf.int32)
    topk_indices_r2 = tf.bitwise.bitwise_and(topk_r2_s32, mask_r2)
    return topk_r2, topk_indices_r2
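
A hedged illustration of the index-recovery trick: `_create_make_unique` (Example #14 below) packs each column index into the low bits of the float's bit pattern, so bitcasting the top-k values and masking with `next_power_of_two - 1` recovers the original column positions. Assumes both helpers plus `math` and `numpy` are in scope:

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(2, 8).astype(np.float32))
values, indices = _create_topk_unique(_create_make_unique(x), k=3)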
Example #4
def masked_softmax(logits, mask, axis):
    """Compute softmax with input mask."""
    e_logits = tf.exp(logits)
    masked_e = tf.multiply(e_logits, mask)
    sum_masked_e = tf.reduce_sum(masked_e, axis, keep_dims=True)
    ones = tf.ones_like(sum_masked_e)
    # if the mask sums to zero along the axis, the denominator would be
    # zero; set it to 1 to avoid division by zero
    sum_masked_e_safe = tf.where(tf.equal(sum_masked_e, 0), ones, sum_masked_e)
    return masked_e / sum_masked_e_safe
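
A minimal usage sketch (the mask marks valid positions with 1.0):

import tensorflow as tf

logits = tf.constant([[1.0, 2.0, 3.0]])
mask = tf.constant([[1.0, 1.0, 0.0]])           # last position is padding
probs = masked_softmax(logits, mask, axis=-1)   # masked position gets 0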
Example #5
def ctc_lambda_loss(logits, labels, input_length, label_length, blank_index=0):
  '''
  CTC loss function.
  :param logits: (B, T, D)
  :param input_length: (B, 1), input lengths of the encoder
  :param labels: (B, T)
  :param label_length: (B, 1), label lengths, used to convert dense labels to sparse
  :returns: loss, scalar
  '''
  ilen = tf.cond(
      pred=tf.equal(tf.rank(input_length), 1),
      true_fn=lambda: input_length,
      false_fn=lambda: tf.squeeze(input_length),
  )
  ilen = tf.cast(ilen, tf.int32)

  olen = tf.cond(
      pred=tf.equal(tf.rank(label_length), 1),
      true_fn=lambda: label_length,
      false_fn=lambda: tf.squeeze(label_length))
  olen = tf.cast(olen, tf.int32)

  deps = [
      tf.assert_rank(labels, 2, name='label_rank_check'),
      tf.assert_rank(logits, 3, name='logits_rank_check'),
      tf.assert_rank(ilen, 1, name='src_len_rank_check'),  # input_length
      tf.assert_rank(olen, 1, name='tgt_len_rank_check'),  # output_length
  ]

  labels, logits = ctc_data_transform(labels, logits, blank_index)

  with tf.control_dependencies(deps):
    # (B, 1)
    # blank index is consistent with Espnet, zero
    batch_loss = tf.nn.ctc_loss(
        labels=labels,
        inputs=logits,
        sequence_length=ilen,
        time_major=False,
        preprocess_collapse_repeated=False,
        ctc_merge_repeated=True,
        ignore_longer_outputs_than_inputs=False)
  return batch_loss
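
Since `ctc_data_transform` is defined elsewhere, here is a standalone sketch of the underlying TF1 `tf.nn.ctc_loss` call with hand-built sparse labels (note: raw `tf.nn.ctc_loss` uses the last class as blank by default, unlike the Espnet-style zero blank arranged above):

import tensorflow as tf

B, T, D = 2, 5, 4  # batch, time, num classes (incl. blank at D - 1)
logits = tf.random_normal([B, T, D])
# sparse labels: batch 0 -> [1, 2], batch 1 -> [3]
labels = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                         values=tf.constant([1, 2, 3], tf.int32),
                         dense_shape=[B, 2])
loss = tf.nn.ctc_loss(labels=labels, inputs=logits,
                      sequence_length=tf.constant([T, T], tf.int32),
                      time_major=False)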
Example #6
def load_textline_dataset(paths, column_num):
    """Load raw data for text task."""
    ds = tf.data.TextLineDataset(paths)
    ds = ds.map(
        lambda x: tf.strings.split(x, sep="\t", result_type="RaggedTensor"))
    ds = ds.filter(lambda line: tf.equal(tf.size(line), column_num))
    ds_list = []
    for i in range(column_num):
        # Dataset.map traces the lambda immediately, so each map call
        # captures the current value of i
        ds_list.append(ds.map(lambda x: x[i]))

    return tuple(ds_list)
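
A usage sketch with a hypothetical two-column TSV (the path and column layout are assumptions):

import tensorflow as tf

# each line of /tmp/sample.tsv: "<label>\t<sentence>"
label_ds, text_ds = load_textline_dataset(["/tmp/sample.tsv"], column_num=2)
ds = tf.data.Dataset.zip((label_ds, text_ds)).batch(2)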
Example #7
def compute_sen_lens(inputs, padding_token=0):
    """
  Count how many words in a sentence.
  inputs: [..., time_steps]
  sen_lens: [...]
  """
    x_binary = tf.cast(tf.not_equal(inputs, padding_token), tf.int32)
    sen_lens = tf.reduce_sum(x_binary, axis=-1)
    # clamp zero-length (all-padding) rows to 1; assumes utils.PAD_IDX == 0
    ones = tf.ones_like(sen_lens)
    sen_lens = tf.where(tf.equal(sen_lens, utils.PAD_IDX), x=ones, y=sen_lens)
    return sen_lens
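
A usage sketch, assuming `utils.PAD_IDX == 0`:

import tensorflow as tf

inputs = tf.constant([[3, 5, 0, 0],
                      [0, 0, 0, 0]])
sen_lens = compute_sen_lens(inputs)  # -> [2, 1]; empty rows are clamped to 1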
Example #8
  def compute_lens(inputs, max_len):
    """count sequence length.
    input: [batch_size, max_len]
    lens: [batch_size]
    """

    x_binary = tf.cast(tf.cast(tf.reverse(inputs, axis=[1]), tf.bool), tf.int32)
    lens = max_len - tf.argmax(x_binary, axis=1, output_type=tf.int32)

    zeros = tf.zeros_like(lens, dtype=tf.int32)
    x_sum = tf.reduce_sum(inputs, axis=1)
    sen_lens = tf.where(tf.equal(x_sum, 0), zeros, lens)
    return sen_lens
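
A usage sketch (token id 0 is treated as padding):

import tensorflow as tf

inputs = tf.constant([[4, 2, 0, 0],
                      [0, 0, 0, 0]])
lens = compute_lens(inputs, max_len=4)  # -> [2, 0]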
Example #9
def compute_mel_filterbank_features(waveforms,
                                    sample_rate=16000,
                                    preemphasis=0.97,
                                    frame_length=0.025,
                                    frame_step=0.010,
                                    fft_length=None,
                                    lower_edge_hertz=80.0,
                                    upper_edge_hertz=7600.0,
                                    num_mel_bins=80,
                                    log_noise_floor=1e-3,
                                    apply_mask=True):
    """Implement mel-filterbank extraction using tf ops.
  Args:
    waveforms: float32 tensor with shape [max_len, nchannels]
    sample_rate: sampling rate of the waveform
    preemphasis: waveform high-pass filtering constant
    frame_length: frame length in ms
    frame_step: frame_Step in ms
    fft_length: number of fft bins
    lower_edge_hertz: lowest frequency of the filterbank
    upper_edge_hertz: highest frequency of the filterbank
    num_mel_bins: filterbank size
    log_noise_floor: clip small values to prevent numeric overflow in log
    apply_mask: When working on a batch of samples, set padding frames to zero
  Returns:
    filterbanks: a float32 tensor with shape [nchannles, max_len, num_bins]
  """
    del log_noise_floor, apply_mask
    spectrogram = powspec_feat(waveforms,
                               sr=sample_rate,
                               nfft=512 if not fft_length else fft_length,
                               winlen=frame_length,
                               winstep=frame_step,
                               lowfreq=lower_edge_hertz,
                               highfreq=upper_edge_hertz,
                               preemph=preemphasis)

    # [channels, time, feat_dim]
    fbank = fbank_feat(spectrogram,
                       sr=sample_rate,
                       feature_size=num_mel_bins,
                       nfft=512 if not fft_length else fft_length,
                       lowfreq=lower_edge_hertz,
                       highfreq=upper_edge_hertz)

    # [time, feat_dim]
    fbank = tf.cond(tf.equal(tf.rank(fbank), 3),
                    true_fn=lambda: fbank[0, :, :],
                    false_fn=lambda: fbank)
    return fbank
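
The `powspec_feat`/`fbank_feat` helpers wrap custom `py_x_ops` kernels. As a rough standalone alternative (an approximation under stated assumptions, not the same kernels), `tf.signal` in TF >= 1.14 can produce similar log-mel features:

import tensorflow as tf

def mel_fbank_tf_signal(waveform, sample_rate=16000, frame_length=400,
                        frame_step=160, fft_length=512, num_mel_bins=80,
                        lower_edge_hertz=80.0, upper_edge_hertz=7600.0):
    # waveform: [num_samples] float32; frame_length/frame_step in samples
    stfts = tf.signal.stft(waveform, frame_length=frame_length,
                           frame_step=frame_step, fft_length=fft_length)
    spectrogram = tf.abs(stfts) ** 2
    mel_matrix = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins, fft_length // 2 + 1, sample_rate,
        lower_edge_hertz, upper_edge_hertz)
    mel = tf.tensordot(spectrogram, mel_matrix, 1)
    return tf.math.log(mel + 1e-6)  # noise floor to avoid log(0)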
Example #10
        def grow_topk(i, alive_seq, alive_log_probs, states):
            """Inner beam search loop."""

            # (batch_size * beam_size, decoded_length)
            flat_ids = tf.reshape(alive_seq, [batch_size * beam_size, -1])

            if states:
                flat_states = nest.map_structure(_merge_beam_dim, states)
                flat_logits, flat_states = symbols_to_logits_fn(
                    flat_ids, i, flat_states)
                states = nest.map_structure(
                    lambda t: _unmerge_beam_dim(t, batch_size, beam_size),
                    flat_states)
            else:
                flat_logits = symbols_to_logits_fn(flat_ids)

            logits = tf.reshape(flat_logits, [batch_size, beam_size, -1])

            candidate_log_probs = log_prob_from_logits(logits)

            log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs,
                                                             axis=2)

            # GNMT-style length penalty: ((5 + step) / 6) ** alpha
            length_penalty = tf.pow(((5. + tf.to_float(i + 1)) / 6.), alpha)

            curr_scores = log_probs / length_penalty
            flat_curr_scores = tf.reshape(curr_scores,
                                          [-1, beam_size * vocab_size])

            # keep 2 * beam_size candidates so enough non-EOS hypotheses
            # survive after finished ones are filtered out
            topk_scores, topk_ids = tf.nn.top_k(flat_curr_scores,
                                                k=beam_size * 2)

            topk_log_probs = topk_scores * length_penalty

            topk_beam_index = topk_ids // vocab_size
            topk_ids %= vocab_size  # Unflatten the ids
            batch_pos = compute_batch_indices(batch_size, beam_size * 2)
            topk_coordinates = tf.stack([batch_pos, topk_beam_index], axis=2)

            topk_seq = tf.gather_nd(alive_seq, topk_coordinates)
            if states:
                states = nest.map_structure(
                    lambda state: tf.gather_nd(state, topk_coordinates),
                    states)
            topk_seq = tf.concat(
                [topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)

            topk_finished = tf.equal(topk_ids, eos_id)

            return topk_seq, topk_log_probs, topk_scores, topk_finished, states
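
`compute_batch_indices` comes from the same tensor2tensor-style beam search; a sketch of its usual definition (a hedged reconstruction):

import tensorflow as tf

def compute_batch_indices(batch_size, beam_size):
    # [batch_size, beam_size] tensor where entry (i, j) == i, used to
    # build gather_nd coordinates into the [batch, beam, ...] tensors
    batch_pos = tf.range(batch_size * beam_size) // beam_size
    return tf.reshape(batch_pos, [batch_size, beam_size])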
Example #11
def accuracy(logits, labels):
    ''' Compute classification accuracy.
    params:
      logits: [B, ..., D]
      labels: [B, ...]
    returns:
      accuracy tensor
    '''
    with tf.name_scope('accuracy'):
        assert_rank = tf.assert_equal(tf.rank(logits), tf.rank(labels) + 1)
        assert_shape = tf.assert_equal(tf.shape(logits)[:-1], tf.shape(labels))
        with tf.control_dependencies([assert_rank, assert_shape]):
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int64)
            labels = tf.cast(labels, tf.int64)
            return tf.reduce_mean(
                tf.cast(tf.equal(predictions, labels), dtype=tf.float32))
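
A minimal usage sketch:

import tensorflow as tf

logits = tf.constant([[0.1, 2.0], [3.0, 0.2]])
labels = tf.constant([1, 0])
acc = accuracy(logits, labels)  # -> 1.0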
Example #12
def delta_delta(feat, order=2):
    '''
    params:
      feat: a tensor of shape [nframe, nfbank] or [nframe, nfbank, 1]
    return: [nframe, nfbank, order + 1] (3 for the default order=2)
    '''
    feat = tf.cond(tf.equal(tf.rank(feat), 3),
                   true_fn=lambda: feat[:, :, 0],
                   false_fn=lambda: feat)

    shape = tf.shape(feat)
    nframe = shape[0]
    nfbank = shape[1]
    # py_x_ops.delta_delta returns a flat [nframe, nfbank * (order + 1)] tensor
    delta = py_x_ops.delta_delta(feat, order=order)
    feat_with_delta_delta = tf.reshape(delta, (nframe, nfbank, (order + 1)))
    return feat_with_delta_delta
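
`py_x_ops.delta_delta` is a custom kernel; for intuition, a rough standalone sketch of a first-order delta (a two-frame central difference, not the actual op):

import tensorflow as tf

def simple_delta(feat):
    # feat: [nframe, nfbank] float; delta[t] = (feat[t+1] - feat[t-1]) / 2
    padded = tf.pad(feat, [[1, 1], [0, 0]], mode="SYMMETRIC")
    return (padded[2:] - padded[:-2]) / 2.0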
Example #13
def labels_blankid_to_last(labels, blank_index, num_class=None):
  ''' Move blank labels from blank_index to num_class - 1, shifting labels
  greater than blank_index down by one. '''
  assert num_class is not None, 'The num_class should not be None!'

  labels = transform_preprocess(
      labels=labels, blank_index=blank_index, num_class=num_class)

  labels_values = labels.values
  labels_num_class = tf.zeros_like(labels_values, dtype=tf.int32) + num_class
  labels_values_change_blank = tf.where(
      tf.equal(labels_values, blank_index), labels_num_class, labels_values)
  labels_values = tf.where(labels_values_change_blank < blank_index,
                           labels_values_change_blank,
                           labels_values_change_blank - 1)

  labels = tf.SparseTensor(
      indices=labels.indices,
      values=labels_values,
      dense_shape=labels.dense_shape)
  return labels
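
A hedged walk-through of the two `tf.where` steps with plain dense values (assuming `transform_preprocess` only validates its inputs), for blank_index=0 and num_class=4:

import tensorflow as tf

values = tf.constant([0, 1, 2, 3])
moved = tf.where(tf.equal(values, 0), tf.fill([4], 4), values)  # [4, 1, 2, 3]
shifted = tf.where(moved < 0, moved, moved - 1)                 # [3, 0, 1, 2]
# the blank now sits at num_class - 1 == 3; labels above it shifted down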
Example #14
def _create_make_unique(inputs):
    """Replaces the lower bits of each element with iota."""
    if inputs.shape.ndims != 2:
        raise ValueError("Input of top_k_with_unique must be rank-2 "
                         "but got: %s" % inputs.shape)

    height = inputs.shape[0]
    width = inputs.shape[1]
    zeros = tf.zeros([height, width], dtype=tf.int32)

    log2_ceiling = int(math.ceil(math.log(int(width), 2)))
    next_power_of_two = 1 << log2_ceiling
    count_mask = ~(next_power_of_two - 1)
    count_mask_r0 = tf.constant(count_mask)
    count_mask_r2 = tf.fill([height, width], count_mask_r0)

    # 1 << 23 is the bit pattern of the smallest normal float32 (2**-126)
    smallest_normal = 1 << 23
    smallest_normal_r0 = tf.constant(smallest_normal, dtype=tf.int32)
    smallest_normal_r2 = tf.fill([height, width], smallest_normal_r0)

    low_bit_mask = ~(1 << 31)
    low_bit_mask_r0 = tf.constant(low_bit_mask, dtype=tf.int32)
    low_bit_mask_r2 = tf.fill([height, width], low_bit_mask_r0)

    iota = tf.tile(tf.expand_dims(tf.range(width, dtype=tf.int32), 0),
                   [height, 1])

    input_r2 = tf.bitcast(inputs, tf.int32)
    abs_r2 = tf.bitwise.bitwise_and(input_r2, low_bit_mask_r2)
    if_zero_r2 = tf.equal(abs_r2, zeros)
    smallest_normal_preserving_sign_r2 = tf.bitwise.bitwise_or(
        input_r2, smallest_normal_r2)
    input_no_zeros_r2 = tf.where(if_zero_r2,
                                 smallest_normal_preserving_sign_r2, input_r2)

    and_r2 = tf.bitwise.bitwise_and(input_no_zeros_r2, count_mask_r2)
    or_r2 = tf.bitwise.bitwise_or(and_r2, iota)
    return tf.bitcast(or_r2, tf.float32)
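
A hedged sketch of the wrapper that typically combines this with `_create_topk_unique` (Example #3), following the tensor2tensor `top_k_with_unique` pattern:

import tensorflow as tf

def top_k_with_unique(inputs, k):
    # top-k with ties broken deterministically by the packed iota bits
    unique_inputs = _create_make_unique(tf.cast(inputs, tf.float32))
    top_values, indices = _create_topk_unique(unique_inputs, k)
    top_values = tf.cast(top_values, inputs.dtype)
    return top_values, indices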
Example #15
def fbank_feat(powspec,
               sr=8000,
               feature_size=40,
               nfft=512,
               lowfreq=0,
               highfreq=None):
    ''' powspec: [audio_channels, spectrogram_length, spectrogram_feat_dim]
    return: [audio_channels, nframe, nfbank]
    '''
    del nfft

    true_fn = lambda: tf.expand_dims(powspec, 0)
    false_fn = lambda: powspec
    powspec = tf.cond(tf.equal(tf.rank(powspec), 2), true_fn, false_fn)

    feat = py_x_ops.fbank(
        powspec,
        sr,
        filterbank_channel_count=feature_size,
        lower_frequency_limit=lowfreq,
        upper_frequency_limit=highfreq,
    )
    return feat
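
The rank check via `tf.cond` is a recurring idiom in these snippets (also in Examples #9 and #12); in isolation:

import tensorflow as tf

x = tf.zeros([100, 40])  # [time, freq]
x = tf.cond(tf.equal(tf.rank(x), 2),
            true_fn=lambda: tf.expand_dims(x, 0),  # promote to [1, time, freq]
            false_fn=lambda: x)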
Example #16
 def mask_outputs(origin_outputs):
     """mask position embedding"""
     inputs, outputs = origin_outputs
     outputs = tf.where(tf.equal(inputs, 0), inputs, outputs)
     return outputs
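
A usage sketch (int tensors keep the `tf.where` dtypes consistent; id 0 is padding):

import tensorflow as tf

inputs = tf.constant([[5, 3, 0, 0]])        # token ids, 0 = pad
outputs = tf.fill([1, 4], 7)                # stand-in for embeddings
masked = mask_outputs((inputs, outputs))    # -> [[7, 7, 0, 0]]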