Example no. 1
  def call(self, inputs, training=None, mask=None):
    """Embed the input tokens, optionally fuse pretrained-model and dense
    features, and return class scores of shape [batch_size, class_num]."""
    input_x = inputs["input_x"]
    if self.use_dense_task:
      dense_input = inputs["input_dense"]

    # [batch_size, max_len, embed_len]
    out = self.embed(input_x)
    if self.use_pretrained_model:
      logging.info("use_pretrained_model: {}, {}".format(
          self.pretrained_model_name, self.pretrained_model_mode))
      if self.pretrained_model_name == 'elmo':
        input_px = self.get_pre_train_graph(input_x)
        input_px = tf.reshape(input_px,
                              [-1, self.max_len, self.pretrained_model_dim])
        out = tf.concat([out, input_px], axis=-1)
        out = tf.reduce_max(out, axis=1)
      if self.pretrained_model_name == 'bert':
        out = self.get_pre_train_graph(input_x)
    else:
      out = tf.reduce_max(out, axis=1)
    out = self.embed_d(out, training=training)
    if self.use_dense_input:
      dense_out = self.dense_input_linear(dense_input)
      if self.only_dense_input:
        out = dense_out
      else:
        out = tf.keras.layers.Concatenate()([out, dense_out])
    # [batch_size, class_num]
    scores = self.final_dense(out)
    return scores
Example no. 2
    def curvature_range(self):
        # set up the curvature window
        self._curv_win = tf.Variable(
            np.zeros([self._curv_win_width]),
            dtype=tf.float32,
            name="curv_win",
            trainable=False)
        # we can use log smoothing for curvature range to follow trend faster
        # self._curv_win = tf.scatter_update(
        #   self._curv_win, self._global_step % self._curv_win_width,
        #   tf.log(self._grad_norm_squared + EPS))
        self._curv_win = tf.scatter_update(
            self._curv_win, self._global_step % self._curv_win_width,
            self._grad_norm_squared + EPS)
        # note here the iterations start from iteration 0
        valid_window = tf.slice(
            self._curv_win,
            tf.constant([0]),
            tf.expand_dims(
                tf.minimum(tf.constant(self._curv_win_width),
                           self._global_step + 1),
                axis=0))

        if self._h_min_log_smooth:
            self._h_min_t = tf.log(tf.reduce_min(valid_window) + EPS)
        else:
            self._h_min_t = tf.reduce_min(valid_window)
        if self._h_max_log_smooth:
            self._h_max_t = tf.log(tf.reduce_max(valid_window) + EPS)
        else:
            self._h_max_t = tf.reduce_max(valid_window)

        curv_range_ops = []
        with tf.control_dependencies([self._h_min_t, self._h_max_t]):
            avg_op = self._moving_averager.apply(
                [self._h_min_t, self._h_max_t])
            with tf.control_dependencies([avg_op]):
                if self._h_min_log_smooth:
                    self._h_min = tf.exp(
                        tf.identity(
                            self._moving_averager.average(self._h_min_t)))
                else:
                    self._h_min = \
                      tf.identity(self._moving_averager.average(self._h_min_t))
                if self._h_max_log_smooth:
                    self._h_max = tf.exp(
                        tf.identity(
                            self._moving_averager.average(self._h_max_t)))
                else:
                    self._h_max = \
                      tf.identity(self._moving_averager.average(self._h_max_t))
            if self._sparsity_debias:
                self._h_min = self._h_min * self._sparsity_avg
                self._h_max = self._h_max * self._sparsity_avg
        curv_range_ops.append(avg_op)
        return curv_range_ops
Example no. 3
def _create_topk_unique(inputs, k):
    """Creates the top k values in sorted order with indices."""
    height = inputs.shape[0]
    width = inputs.shape[1]
    neg_inf_r0 = tf.constant(-np.inf, dtype=tf.float32)
    ones = tf.ones([height, width], dtype=tf.float32)
    neg_inf_r2 = ones * neg_inf_r0
    inputs = tf.where(tf.is_nan(inputs), neg_inf_r2, inputs)

    # Select the largest remaining value k times: after each pick, every entry
    # of `inputs` that is >= the picked value is masked to -inf for the next
    # iteration, so successive iterations return decreasing values.
    tmp = inputs
    topk_r2 = tf.zeros([height, k], dtype=tf.float32)
    for i in range(k):
        kth_order_statistic = tf.reduce_max(tmp, axis=1, keepdims=True)
        k_mask = tf.tile(
            tf.expand_dims(tf.equal(tf.range(k), tf.fill([k], i)), 0),
            [height, 1])
        topk_r2 = tf.where(k_mask, tf.tile(kth_order_statistic, [1, k]),
                           topk_r2)
        ge_r2 = tf.greater_equal(inputs,
                                 tf.tile(kth_order_statistic, [1, width]))
        tmp = tf.where(ge_r2, neg_inf_r2, inputs)

    # Recover integer indices from the low-order bits of the selected values:
    # bitcast the floats to int32 and keep only the bits below the next power
    # of two >= width (enough bits to hold any column index).
    log2_ceiling = int(math.ceil(math.log(float(int(width)), 2)))
    next_power_of_two = 1 << log2_ceiling
    count_mask = next_power_of_two - 1
    mask_r0 = tf.constant(count_mask)
    mask_r2 = tf.fill([height, k], mask_r0)
    topk_r2_s32 = tf.bitcast(topk_r2, tf.int32)
    topk_indices_r2 = tf.bitwise.bitwise_and(topk_r2_s32, mask_r2)
    return topk_r2, topk_indices_r2
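
A minimal usage sketch for `_create_topk_unique` (an addition to this listing, not part of the original source), assuming TF 1.x graph mode, a statically known input shape, and `math`, `numpy as np`, and `tensorflow as tf` imported for the definition above. Note that the returned indices are just the low-order bits of the selected float values, so they only decode to real column positions when those bits have been pre-packed with indices (e.g. by a companion "make unique" step); the call below only illustrates the shapes involved.

import math
import numpy as np
import tensorflow as tf

scores = tf.constant(np.random.rand(2, 8).astype(np.float32))  # [height=2, width=8]
top_values, top_index_bits = _create_topk_unique(scores, k=3)  # both shaped [2, 3]
with tf.Session() as sess:
    print(sess.run([top_values, top_index_bits]))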
Example no. 4
    def call(self, tensors):
        """Attention layer."""
        left, right = tensors

        len_left = left.shape[1]
        len_right = right.shape[1]
        # Pair every left position with every right position:
        # [batch, len_left, len_right, dim_left + dim_right].
        tensor_left = tf.expand_dims(left, axis=2)
        tensor_right = tf.expand_dims(right, axis=1)
        tensor_left = tf.tile(tensor_left, [1, 1, len_right, 1])
        tensor_right = tf.tile(tensor_right, [1, len_left, 1, 1])
        tensor_merged = tf.concat([tensor_left, tensor_right], axis=-1)
        middle_output = self.middle_layer(tensor_merged)
        attn_scores = self.attn(middle_output)
        attn_scores = tf.squeeze(attn_scores, axis=3)
        # Numerically stable softmax over the last (right) axis.
        exp_attn_scores = tf.exp(
            attn_scores - tf.reduce_max(attn_scores, axis=-1, keepdims=True))
        exp_sum = tf.reduce_sum(exp_attn_scores, axis=-1, keepdims=True)
        attention_weights = exp_attn_scores / exp_sum
        return tf.matmul(attention_weights, right)
Example no. 5
        def _is_finished(i, unused_alive_seq, alive_log_probs,
                         unused_finished_seq, finished_scores,
                         unused_finished_in_finished, unused_states):
            """Checking termination condition.
      """
            max_length_penalty = tf.pow(
                ((5. + tf.to_float(decode_length)) / 6.), alpha)
            lower_bound_alive_scores = (
                alive_log_probs[:, 0] / max_length_penalty)

            if not stop_early:
                lowest_score_of_finished_in_finished = tf.reduce_min(
                    finished_scores)
            else:
                lowest_score_of_finished_in_finished = tf.reduce_max(
                    finished_scores, axis=1)

            bound_is_met = tf.reduce_all(
                tf.greater(lowest_score_of_finished_in_finished,
                           lower_bound_alive_scores))

            return tf.logical_and(tf.less(i, decode_length),
                                  tf.logical_not(bound_is_met))
Example no. 6
def compute_mel_filterbank_features(waveforms,
                                    sample_rate=16000,
                                    dither=1.0 / np.iinfo(np.int16).max,
                                    preemphasis=0.97,
                                    frame_length=25,
                                    frame_step=10,
                                    fft_length=None,
                                    window_fn=functools.partial(
                                        tf.signal.hann_window, periodic=True),
                                    lower_edge_hertz=80.0,
                                    upper_edge_hertz=7600.0,
                                    num_mel_bins=80,
                                    log_noise_floor=1e-3,
                                    apply_mask=True):
    """Implement mel-filterbank extraction using tf ops.
  Args:
    waveforms: float32 tensor with shape [batch_size, max_len]
    sample_rate: sampling rate of the waveform
    dither: stddev of Gaussian noise added to waveform to prevent quantization
      artefacts
    preemphasis: waveform high-pass filtering constant
    frame_length: frame length in ms
    frame_step: frame step in ms
    fft_length: number of fft bins
    window_fn: windowing function
    lower_edge_hertz: lowest frequency of the filterbank
    upper_edge_hertz: highest frequency of the filterbank
    num_mel_bins: filterbank size
    log_noise_floor: clip small values to prevent numeric overflow in log
    apply_mask: When working on a batch of samples, set padding frames to zero
  Returns:
    filterbanks: a float32 tensor with shape [batch_size, len, num_bins, 1]
  """
    # `stfts` (computed below) is a complex64 Tensor representing the
    # short-time Fourier transform of each signal in `waveforms`. Its shape is
    # [batch_size, ?, fft_unique_bins],
    # where fft_unique_bins = fft_length // 2 + 1

    # Find the wave length: the largest index for which the value is !=0
    # note that waveforms samples that are exactly 0.0 are quite common, so
    # simply doing sum(waveforms != 0, axis=-1) will not work correctly.
    wav_lens = tf.reduce_max(
        tf.expand_dims(tf.range(tf.shape(waveforms)[1]), 0) *
        tf.to_int32(tf.not_equal(waveforms, 0.0)),
        axis=-1) + 1
    if dither > 0:
        waveforms += tf.random_normal(tf.shape(waveforms), stddev=dither)
    if preemphasis > 0:
        waveforms = waveforms[:, 1:] - preemphasis * waveforms[:, :-1]
        wav_lens -= 1
    frame_length = int(frame_length * sample_rate / 1e3)
    frame_step = int(frame_step * sample_rate / 1e3)
    if fft_length is None:
        fft_length = int(2**(np.ceil(np.log2(frame_length))))

    stfts = tf.signal.stft(waveforms,
                           frame_length=frame_length,
                           frame_step=frame_step,
                           fft_length=fft_length,
                           window_fn=window_fn,
                           pad_end=True)

    stft_lens = (wav_lens + (frame_step - 1)) // frame_step
    masks = tf.to_float(
        tf.less_equal(tf.expand_dims(tf.range(tf.shape(stfts)[1]), 0),
                      tf.expand_dims(stft_lens, 1)))

    # An energy spectrogram is the magnitude of the complex-valued STFT.
    # A float32 Tensor of shape [batch_size, ?, 257].
    magnitude_spectrograms = tf.abs(stfts)

    # Warp the linear-scale, magnitude spectrograms into the mel-scale.
    num_spectrogram_bins = magnitude_spectrograms.shape[-1].value
    linear_to_mel_weight_matrix = (tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz,
        upper_edge_hertz))
    mel_spectrograms = tf.tensordot(magnitude_spectrograms,
                                    linear_to_mel_weight_matrix, 1)
    # Note: Shape inference for tensordot does not currently handle this case.
    mel_spectrograms.set_shape(magnitude_spectrograms.shape[:-1].concatenate(
        linear_to_mel_weight_matrix.shape[-1:]))

    log_mel_sgram = tf.log(tf.maximum(log_noise_floor, mel_spectrograms))

    if apply_mask:
        log_mel_sgram *= tf.expand_dims(tf.to_float(masks), -1)

    return tf.expand_dims(log_mel_sgram, -1, name="mel_sgrams")
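
A minimal usage sketch for `compute_mel_filterbank_features` (an addition to this listing, not part of the original source), assuming TF 1.x graph mode and that the definition above has `functools`, `numpy as np`, and `tensorflow as tf` imported; the random waveform below merely stands in for real audio.

import numpy as np
import tensorflow as tf

waveforms = tf.constant(
    np.random.uniform(-1.0, 1.0, size=(2, 16000)).astype(np.float32))
filterbanks = compute_mel_filterbank_features(waveforms, sample_rate=16000)
with tf.Session() as sess:
    # With the defaults this yields shape (2, num_frames, 80, 1).
    print(sess.run(filterbanks).shape)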
Example no. 7
  def call(self, inputs, training=None, mask=None):
    """Run the attention decoder: greedy or beam-search decoding at inference
    time, teacher-forced (TrainingHelper) decoding otherwise."""
    dec_emb_fn = lambda ids: self.embed(ids)
    if self.is_infer:
      enc_outputs, enc_state, enc_seq_len = inputs
      batch_size = tf.shape(enc_outputs)[0]
      helper = seq2seq.GreedyEmbeddingHelper(
          embedding=dec_emb_fn,
          start_tokens=tf.fill([batch_size], self.dec_start_id),
          end_token=self.dec_end_id)
    else:
      dec_inputs, dec_seq_len, enc_outputs, enc_state, \
      enc_seq_len = inputs
      batch_size = tf.shape(enc_outputs)[0]
      dec_inputs = self.embed(dec_inputs)
      helper = seq2seq.TrainingHelper(
          inputs=dec_inputs, sequence_length=dec_seq_len)

    if self.is_infer and self.beam_size > 1:
      tiled_enc_outputs = seq2seq.tile_batch(
          enc_outputs, multiplier=self.beam_size)
      tiled_seq_len = seq2seq.tile_batch(enc_seq_len, multiplier=self.beam_size)
      attn_mech = self._build_attention(
          enc_outputs=tiled_enc_outputs, enc_seq_len=tiled_seq_len)
      dec_cell = seq2seq.AttentionWrapper(self.cell, attn_mech)
      tiled_enc_last_state = seq2seq.tile_batch(
          enc_state, multiplier=self.beam_size)
      tiled_dec_init_state = dec_cell.zero_state(
          batch_size=batch_size * self.beam_size, dtype=tf.float32)
      if self.initial_decode_state:
        tiled_dec_init_state = tiled_dec_init_state.clone(
            cell_state=tiled_enc_last_state)

      dec = seq2seq.BeamSearchDecoder(
          cell=dec_cell,
          embedding=dec_emb_fn,
          start_tokens=tf.tile([self.dec_start_id], [batch_size]),
          end_token=self.dec_end_id,
          initial_state=tiled_dec_init_state,
          beam_width=self.beam_size,
          output_layer=tf.layers.Dense(self.vocab_size),
          length_penalty_weight=self.length_penalty)
    else:
      attn_mech = self._build_attention(
          enc_outputs=enc_outputs, enc_seq_len=enc_seq_len)
      dec_cell = seq2seq.AttentionWrapper(
          cell=self.cell, attention_mechanism=attn_mech)
      dec_init_state = dec_cell.zero_state(
          batch_size=batch_size, dtype=tf.float32)
      if self.initial_decode_state:
        dec_init_state = dec_init_state.clone(cell_state=enc_state)
      dec = seq2seq.BasicDecoder(
          cell=dec_cell,
          helper=helper,
          initial_state=dec_init_state,
          output_layer=tf.layers.Dense(self.vocab_size))
    if self.is_infer:
      dec_outputs, _, _ = \
        seq2seq.dynamic_decode(decoder=dec,
                               maximum_iterations=self.max_dec_len,
                               swap_memory=self.swap_memory,
                               output_time_major=self.time_major)
      return dec_outputs.predicted_ids[:, :, 0]
    else:
      dec_outputs, _, _ = \
        seq2seq.dynamic_decode(decoder=dec,
                               maximum_iterations=tf.reduce_max(dec_seq_len),
                               swap_memory=self.swap_memory,
                               output_time_major=self.time_major)
    return dec_outputs.rnn_output