Пример #1
0
def _create_topk_unique(inputs, k):
    """Creates the top k values in sorted order with indices.

    Selects the row-wise maximum ``k`` times. After each pick, every element
    greater than or equal to the picked value is masked to -inf, so the
    values are collected in strictly descending order. Column indices are then
    read back from the low bits of each selected value's float32 bit pattern.

    NOTE(review): the index extraction is only meaningful if the low bits of
    ``inputs`` were previously overwritten with column indices (as
    ``_create_make_unique`` does). Confirm this at the call site. Because all
    elements ``>=`` the current maximum are masked at once, duplicate values
    are returned only once.

    Args:
      inputs: rank-2 float32 tensor of shape (height, width). The static shape
        must be known, since ``inputs.shape[0]`` / ``[1]`` size the constants.
      k: Python int, number of values to select per row.

    Returns:
      topk_r2: float32 tensor (height, k) holding the selected values in
        descending order. Entries become -inf once a row has no unmasked
        values left.
      topk_indices_r2: int32 tensor (height, k) holding the low
        ``ceil(log2(width))`` bits of each selected value's bit pattern.
    """
    height = inputs.shape[0]
    width = inputs.shape[1]
    # Replace NaNs with -inf so they can never be picked as a maximum.
    neg_inf_r0 = tf.constant(-np.inf, dtype=tf.float32)
    ones = tf.ones([height, width], dtype=tf.float32)
    neg_inf_r2 = ones * neg_inf_r0
    inputs = tf.where(tf.is_nan(inputs), neg_inf_r2, inputs)

    # Select the current row maximum k times; `tmp` is `inputs` with all
    # previously selected (and larger-or-equal) values replaced by -inf.
    tmp = inputs
    topk_r2 = tf.zeros([height, k], dtype=tf.float32)
    for i in range(k):
        kth_order_statistic = tf.reduce_max(tmp, axis=1, keepdims=True)
        # Boolean (height, k) mask that is True only in column i.
        k_mask = tf.tile(
            tf.expand_dims(tf.equal(tf.range(k), tf.fill([k], i)), 0),
            [height, 1])
        # Write this iteration's maximum into column i of the result.
        topk_r2 = tf.where(k_mask, tf.tile(kth_order_statistic, [1, k]),
                           topk_r2)
        # Masking from `inputs` rather than `tmp` is equivalent: the selected
        # maximum never increases between iterations, so values masked
        # earlier are also >= the current threshold and get masked again.
        ge_r2 = tf.greater_equal(inputs,
                                 tf.tile(kth_order_statistic, [1, width]))
        tmp = tf.where(ge_r2, neg_inf_r2, inputs)

    # Keep only the low ceil(log2(width)) bits of each selected value's
    # int32 bit pattern; these hold the embedded column index.
    log2_ceiling = int(math.ceil(math.log(float(int(width)), 2)))
    next_power_of_two = 1 << log2_ceiling
    count_mask = next_power_of_two - 1
    mask_r0 = tf.constant(count_mask)
    mask_r2 = tf.fill([height, k], mask_r0)
    topk_r2_s32 = tf.bitcast(topk_r2, tf.int32)
    topk_indices_r2 = tf.bitwise.bitwise_and(topk_r2_s32, mask_r2)
    return topk_r2, topk_indices_r2
Пример #2
0
    def test_maxpool(self):
        '''Check cl.max_pool with a 3x3 window and stride 1 on a 5x5 ramp.'''
        # Batch of one 5x5 single-channel image holding 0..24 row-major.
        image = tf.reshape(tf.range(25), shape=[1, 5, 5, 1])
        pooled = cl.max_pool(image, [3, 3], [1, 1])
        self.assertAllEqual(tf.shape(pooled), [1, 3, 3, 1])

        # On an increasing ramp, every window's maximum is its bottom-right
        # element.
        expected = tf.constant([[[[12], [13], [14]],
                                 [[17], [18], [19]],
                                 [[22], [23], [24]]]])
        self.assertAllEqual(pooled, expected)
Пример #3
0
    def test_splice_layer(self):
        '''Check cl.splice_layer against hand-computed outputs.

        The input is a (1, 5, 3) tensor holding 0..14. Per the expected
        values, context offsets that fall outside the sequence repeat the
        nearest edge frame.
        '''
        frames = tf.reshape(tf.range(15), shape=[1, 5, 3])
        cases = [
            ([0, 1],
             [[0, 1, 2, 3, 4, 5],
              [3, 4, 5, 6, 7, 8],
              [6, 7, 8, 9, 10, 11],
              [9, 10, 11, 12, 13, 14],
              [12, 13, 14, 12, 13, 14]]),
            ([-1, 0, 1],
             [[0, 1, 2, 0, 1, 2, 3, 4, 5],
              [0, 1, 2, 3, 4, 5, 6, 7, 8],
              [3, 4, 5, 6, 7, 8, 9, 10, 11],
              [6, 7, 8, 9, 10, 11, 12, 13, 14],
              [9, 10, 11, 12, 13, 14, 12, 13, 14]]),
            ([0, 1, 3],
             [[0, 1, 2, 3, 4, 5, 9, 10, 11],
              [3, 4, 5, 6, 7, 8, 12, 13, 14],
              [6, 7, 8, 9, 10, 11, 12, 13, 14],
              [9, 10, 11, 12, 13, 14, 12, 13, 14],
              [12, 13, 14, 12, 13, 14, 12, 13, 14]]),
            ([1, 3],
             [[3, 4, 5, 9, 10, 11],
              [6, 7, 8, 12, 13, 14],
              [9, 10, 11, 12, 13, 14],
              [12, 13, 14, 12, 13, 14],
              [12, 13, 14, 12, 13, 14]]),
            ([1, 2, 3],
             [[3, 4, 5, 6, 7, 8, 9, 10, 11],
              [6, 7, 8, 9, 10, 11, 12, 13, 14],
              [9, 10, 11, 12, 13, 14, 12, 13, 14],
              [12, 13, 14, 12, 13, 14, 12, 13, 14],
              [12, 13, 14, 12, 13, 14, 12, 13, 14]]),
        ]
        for context, expected_rows in cases:
            spliced = cl.splice_layer(frames, 'splice', context)
            # Wrap the rows in a batch dimension of size one.
            self.assertAllEqual(spliced, tf.constant([expected_rows]))
Пример #4
0
  def call(self, inputs: list, **kwargs) -> typing.Any:
    """
    The computation logic of DynamicPoolingLayer.

    Prefixes every coordinate in ``dpool_index`` with its batch row, gathers
    the addressed elements of ``x`` and max-pools the gathered tensor. The
    pooling windows do not overlap; they have size
    (msize1 // psize1, msize2 // psize2).

    :param inputs: two input tensors, ``[x, dpool_index]``.
        NOTE(review): the concat below requires ``dpool_index`` to be int32
        with leading dims (batch, msize1, msize2) -- confirm against callers.
    :return: the tensor produced by ``tf.nn.max_pool`` with "VALID" padding.
    """
    self._validate_dpool_size()
    x, dpool_index = inputs
    batch_size = tf.shape(dpool_index)[0]

    # (batch,) -> (batch, 1, 1) -> (batch, msize1, msize2) -> (..., 1):
    # every position carries the id of the batch row it belongs to.
    batch_ids = tf.expand_dims(
        tf.expand_dims(tf.range(batch_size), axis=-1), axis=-1)
    batch_ids = tf.expand_dims(
        tf.tile(batch_ids, [1, self._msize1, self._msize2]), axis=-1)

    # Full gather_nd coordinates: batch id followed by dpool_index entries.
    gather_coords = tf.concat([batch_ids, dpool_index], axis=3)
    gathered = tf.gather_nd(x, gather_coords)

    # Identical kernel size and stride give non-overlapping windows.
    window = [1, self._msize1 // self._psize1,
              self._msize2 // self._psize2, 1]
    return tf.nn.max_pool(gathered, window, window, "VALID")
Пример #5
0
def _create_make_unique(inputs):
    """Replaces the lower bits of each element with iota.

    Overwrites the low ``ceil(log2(width))`` bits of every element's float32
    bit pattern with its column index (0 .. width-1). This makes the values
    in each row pairwise distinct and lets the column index be recovered
    from the value's bits later (``_create_topk_unique`` reads it back).

    Args:
      inputs: rank-2 32-bit float tensor (it is bitcast to int32) with a
        statically known shape (height, width).

    Returns:
      float32 tensor of the same shape with each element's column index
      embedded in its low bits.

    Raises:
      ValueError: if ``inputs`` is not rank 2.
    """
    if inputs.shape.ndims != 2:
        raise ValueError("Input of top_k_with_unique must be rank-2 "
                         "but got: %s" % inputs.shape)

    height = inputs.shape[0]
    width = inputs.shape[1]
    zeros = tf.zeros([height, width], dtype=tf.int32)

    # Number of low bits needed to hold any column index in [0, width).
    # Must match the mask width computed in _create_topk_unique.
    log2_ceiling = int(math.ceil(math.log(int(width), 2)))
    next_power_of_two = 1 << log2_ceiling
    # ~(2**n - 1) keeps the high bits and clears the n low index bits; it is
    # a negative value that still fits in int32.
    count_mask = ~(next_power_of_two - 1)
    count_mask_r0 = tf.constant(count_mask)
    count_mask_r2 = tf.fill([height, width], count_mask_r0)

    # Bit pattern of the smallest positive normal float32 (exponent field 1).
    smallest_normal = 1 << 23
    smallest_normal_r0 = tf.constant(smallest_normal, dtype=tf.int32)
    smallest_normal_r2 = tf.fill([height, width], smallest_normal_r0)

    # 0x7FFFFFFF: every bit except the sign bit. The previous spelling
    # ``~(1 << 31)`` is the Python int -2**31 - 1, which lies outside int32.
    # It only became this mask when the int32 conversion silently wrapped
    # around, and stricter conversions (e.g. NumPy >= 2) reject it.
    low_bit_mask = (1 << 31) - 1
    low_bit_mask_r0 = tf.constant(low_bit_mask, dtype=tf.int32)
    low_bit_mask_r2 = tf.fill([height, width], low_bit_mask_r0)

    # Column index of every element, shape (height, width).
    iota = tf.tile(tf.expand_dims(tf.range(width, dtype=tf.int32), 0),
                   [height, 1])

    input_r2 = tf.bitcast(inputs, tf.int32)
    # +0.0 / -0.0 have all magnitude bits clear. Replace them with a signed
    # smallest normal so that OR-ing in the index does not produce a
    # denormal. NOTE(review): presumably this is because denormals may be
    # flushed to zero on some hardware; confirm.
    abs_r2 = tf.bitwise.bitwise_and(input_r2, low_bit_mask_r2)
    if_zero_r2 = tf.equal(abs_r2, zeros)
    smallest_normal_preserving_sign_r2 = tf.bitwise.bitwise_or(
        input_r2, smallest_normal_r2)
    input_no_zeros_r2 = tf.where(if_zero_r2,
                                 smallest_normal_preserving_sign_r2, input_r2)

    # Clear the low index bits, then write the column index into them.
    and_r2 = tf.bitwise.bitwise_and(input_no_zeros_r2, count_mask_r2)
    or_r2 = tf.bitwise.bitwise_or(and_r2, iota)
    return tf.bitcast(or_r2, tf.float32)
Пример #6
0
def compute_mel_filterbank_features(waveforms,
                                    sample_rate=16000,
                                    dither=1.0 / np.iinfo(np.int16).max,
                                    preemphasis=0.97,
                                    frame_length=25,
                                    frame_step=10,
                                    fft_length=None,
                                    window_fn=functools.partial(
                                        tf.signal.hann_window, periodic=True),
                                    lower_edge_hertz=80.0,
                                    upper_edge_hertz=7600.0,
                                    num_mel_bins=80,
                                    log_noise_floor=1e-3,
                                    apply_mask=True):
    """Implement mel-filterbank extraction using tf ops.

    Args:
      waveforms: float32 tensor with shape [batch_size, max_len]
      sample_rate: sampling rate of the waveform
      dither: stddev of Gaussian noise added to waveform to prevent
        quantization artefacts
      preemphasis: waveform high-pass filtering constant
      frame_length: frame length in ms
      frame_step: frame step in ms
      fft_length: number of fft bins; defaults to the smallest power of two
        that is >= the frame length in samples
      window_fn: windowing function
      lower_edge_hertz: lowest frequency of the filterbank
      upper_edge_hertz: highest frequency of the filterbank
      num_mel_bins: filterbank size
      log_noise_floor: lower clip applied before the log, so bins with zero
        energy yield log(log_noise_floor) instead of -inf
      apply_mask: When working on a batch of samples, set padding frames to
        zero

    Returns:
      filterbanks: a float32 tensor with shape [batch_size, len, num_bins, 1]
    """
    # Find the wave length: the largest index for which the value is != 0.
    # Waveform samples that are exactly 0.0 are quite common, so simply
    # doing sum(waveforms != 0, axis=-1) would not work correctly.
    wav_lens = tf.reduce_max(
        tf.expand_dims(tf.range(tf.shape(waveforms)[1]), 0) *
        tf.to_int32(tf.not_equal(waveforms, 0.0)),
        axis=-1) + 1
    if dither > 0:
        waveforms += tf.random_normal(tf.shape(waveforms), stddev=dither)
    if preemphasis > 0:
        # y[t] = x[t] - preemphasis * x[t-1]; the output is one sample shorter.
        waveforms = waveforms[:, 1:] - preemphasis * waveforms[:, :-1]
        wav_lens -= 1
    # Convert frame sizes from milliseconds to samples.
    frame_length = int(frame_length * sample_rate / 1e3)
    frame_step = int(frame_step * sample_rate / 1e3)
    if fft_length is None:
        fft_length = int(2**(np.ceil(np.log2(frame_length))))

    # `stfts` is a complex64 Tensor holding the short-time Fourier transform
    # of each signal in `waveforms`. Its shape is
    # [batch_size, ?, fft_unique_bins], where
    # fft_unique_bins = fft_length // 2 + 1.
    stfts = tf.signal.stft(waveforms,
                           frame_length=frame_length,
                           frame_step=frame_step,
                           fft_length=fft_length,
                           window_fn=window_fn,
                           pad_end=True)

    # Number of frames that overlap the unpadded signal (rounded up).
    stft_lens = (wav_lens + (frame_step - 1)) // frame_step
    # NOTE(review): `less_equal` keeps frame indices 0..stft_lens inclusive,
    # which appears to include one frame starting at or after the end of
    # the signal. It is kept as-is to preserve existing features; confirm
    # whether this is intended.
    masks = tf.to_float(
        tf.less_equal(tf.expand_dims(tf.range(tf.shape(stfts)[1]), 0),
                      tf.expand_dims(stft_lens, 1)))

    # An energy spectrogram is the magnitude of the complex-valued STFT,
    # shape [batch_size, ?, fft_length // 2 + 1] (257 for the defaults).
    magnitude_spectrograms = tf.abs(stfts)

    # Warp the linear-scale, magnitude spectrograms into the mel-scale.
    # The rfft bin count equals the static last dimension of `stfts`. It is
    # computed directly so it does not depend on the TF1-only Dimension.value.
    num_spectrogram_bins = fft_length // 2 + 1
    linear_to_mel_weight_matrix = (tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz,
        upper_edge_hertz))
    mel_spectrograms = tf.tensordot(magnitude_spectrograms,
                                    linear_to_mel_weight_matrix, 1)
    # Note: Shape inference for tensordot does not currently handle this case.
    mel_spectrograms.set_shape(magnitude_spectrograms.shape[:-1].concatenate(
        linear_to_mel_weight_matrix.shape[-1:]))

    log_mel_sgram = tf.log(tf.maximum(log_noise_floor, mel_spectrograms))

    if apply_mask:
        # `masks` is already float32; add a trailing axis to broadcast it
        # over the mel bins.
        log_mel_sgram *= tf.expand_dims(masks, -1)

    return tf.expand_dims(log_mel_sgram, -1, name="mel_sgrams")
Пример #7
0
def compute_batch_indices(batch_size, beam_size):
    """Return a [batch_size, beam_size] tensor whose entry (b, j) equals b.

    This serves as the batch coordinate when assembling indices for gathers
    over beams.
    """
    flat_positions = tf.range(batch_size * beam_size)
    return tf.reshape(flat_positions // beam_size, [batch_size, beam_size])
Пример #8
0
 def get_pos(inputs):
     """Return position ids 0..seq_len-1 repeated for every batch row.

     batch_size and seq_len are the first two dynamic dimensions of
     ``inputs``; the result has shape [batch_size, seq_len].
     """
     dims = tf.shape(inputs)
     row = tf.expand_dims(tf.range(dims[1]), 0)
     return tf.tile(row, [dims[0], 1])