Example #1
0
 def call(self, inputs, **kwargs):
     """Add the tensor returned by ``positional_encoding`` to ``inputs``.

     Args:
         inputs: rank-3 tensor, unpacked here as ``[B, T, V]``; the last
             dimension must be even.
         **kwargs: accepted and ignored.

     Returns:
         ``inputs + positional_encoding(T, V)``, with the encoding cast to
         ``inputs.dtype`` first.

     Raises:
         ValueError: if the last dimension of ``inputs`` is odd.
     """
     _, seq_len, depth = shape_list(inputs)
     if depth % 2 != 0:
         raise ValueError(f"Input last dim must be even: {depth}")
     encoding = positional_encoding(seq_len, depth)
     return inputs + tf.cast(encoding, dtype=inputs.dtype)
Example #2
0
        def recognize_pb(features, length, training=False):
            """Greedy-decode each batch element separately and flatten the result.

            Every element ``features[i]`` is decoded with
            ``self.perform_greedy(..., streaming=False)``. The
            ``self.text_featurizer.stop`` id is appended after each decoded
            sequence, and all sequences are concatenated into one flat int32
            vector.

            ``length`` and ``training`` are accepted but not used here.

            Returns:
                A one-element list holding the flat 1-D int32 tensor.
            """
            batch_size = shape_list(features)[0]

            def _has_more(idx, n, feats, out):
                return tf.less(idx, n)

            def _decode_one(idx, n, feats, out):
                # Re-add a batch axis of 1 so perform_greedy sees a single utterance.
                single = tf.expand_dims(feats[idx], axis=0)
                tokens = self.perform_greedy(single, streaming=False)
                # Mark the end of this utterance inside the flat output.
                stop_col = tf.constant([[self.text_featurizer.stop]], tf.int32)
                tokens = tf.concat([tokens, stop_col], axis=-1)
                out = tf.concat([out, tokens[0]], axis=0)
                return idx + 1, n, feats, out

            start_idx = tf.constant(0, dtype=tf.int32)
            empty_out = tf.constant([], dtype=tf.int32)

            _, _, _, flat_tokens = tf.while_loop(
                _has_more,
                _decode_one,
                loop_vars=(start_idx, batch_size, features, empty_out),
                shape_invariants=(
                    tf.TensorShape([]),
                    tf.TensorShape([]),
                    get_shape_invariants(features),
                    # The output grows every iteration, so its length is unknown.
                    tf.TensorShape([None]),
                ),
            )

            return [flat_tokens]
Example #3
0
    def call(self, inputs, training=False):
        """Forward pass through the ESP-style blocks of this layer.

        Order: ``in_conv`` -> ``BN_Prelu1``. Its output feeds two branches:
        ``avg``, and ``espblock1`` followed by ``block1_projecter``. The two
        branches are concatenated on the last axis -> ``espblock2``. The last
        two axes of that rank-4 result are merged -> ``projecter`` ->
        ``last_block``.

        Args:
            inputs: input tensor for ``self.in_conv``.
            training: forwarded to the normalisation and ESP blocks.

        Returns:
            The output of ``self.last_block``.
        """
        stem = self.in_conv(inputs)
        stem = self.BN_Prelu1(stem, training=training)
        pooled = self.avg(stem)
        branch = self.block1_projecter(self.espblock1(stem, training=training))
        merged = tf.concat([pooled, branch], -1)
        features = self.espblock2(merged, training=training)
        # Fold the last two axes together so the projecter sees a rank-3 tensor.
        batch, dim1, dim2, channels = shape_list(features)
        features = tf.reshape(features, [batch, dim1, dim2 * channels])
        projected = self.projecter(features)
        return self.last_block(projected, training=training)
Example #4
0
def rnnt_ctc_loss(logits, labels, label_length, logit_length, blank=0):
    """CTC loss computed on transducer logits collapsed over axis 2.

    ``logits`` is summed over axis 2. The result is unpacked into three
    dimensions, so the input must be rank 4 (e.g. ``[B, T, U, V]`` ->
    ``[B, T, V]``).

    If ``blank`` is the last class, ``tf.keras.backend.ctc_batch_cost`` is
    used on softmax probabilities, with both lengths expanded to ``[B, 1]``.
    Otherwise ``tf.nn.ctc_loss`` is called batch-major, with ``blank_index``
    set to ``blank`` and ids/lengths cast to int32.

    NOTE(review): the Keras branch returns a ``[B, 1]`` tensor while
    ``tf.nn.ctc_loss`` returns ``[B]``; confirm callers handle both shapes.
    """
    summed = tf.reduce_sum(logits, 2)
    _, _, num_classes = shape_list(summed)

    if num_classes - 1 != blank:
        return tf.nn.ctc_loss(
            labels=tf.cast(labels, tf.int32),
            logit_length=tf.cast(logit_length, tf.int32),
            logits=tf.cast(summed, tf.float32),
            label_length=tf.cast(label_length, tf.int32),
            logits_time_major=False,
            blank_index=blank,
        )

    # Keras' ctc_batch_cost treats the last class as blank and expects probabilities.
    probs = tf.nn.softmax(summed, -1)
    return tf.keras.backend.ctc_batch_cost(
        labels,
        probs,
        tf.expand_dims(logit_length, -1),
        tf.expand_dims(label_length, -1),
    )
Example #5
0
    def perform_greedy(self, features):
        """Greedy-decode a whole batch in lock-step, one encoder frame per step.

        Each step does the following for every batch element:
        - feed the last token of the hypothesis to ``self.predict_net``;
        - combine the result with the current encoder frame in
          ``self.joint_net``;
        - append the arg-max token.
        No blank filtering happens here: exactly one token is appended per
        processed frame.

        The loop ends after the last frame. It ends earlier once every batch
        element has produced ``self.text_featurizer.stop`` at least once.

        Args:
            features: batch-first input. When ``self.mel_layer`` is set,
                features pass through it before ``self.encoder``.

        Returns:
            int32 tensor ``[B, 1 + steps]``: the start token followed by the
            greedy token of every processed frame.

        NOTE(review): the score field (``Hypotheses[0]``) is only incremented
        by 1 per step; it does not accumulate log-probabilities. Also, the
        predictor's initial state comes from the features *before*
        ``mel_layer``.
        """
        batch = tf.shape(features)[0]
        # Built before mel_layer is applied, exactly as in the original order.
        init_hyps = Hypotheses(
            tf.zeros([batch], tf.float32),
            self.text_featurizer.start * tf.ones([batch, 1], dtype=tf.int32),
            self.predict_net.get_initial_state(features))

        if self.mel_layer is not None:
            features = self.mel_layer(features)
        enc = self.encoder(features, training=False)  # presumably [B, T, E]
        stopped = tf.zeros([batch, 1], tf.float32)
        num_frames = tf.cast(shape_list(enc)[1], dtype=tf.int32)
        first_step = tf.constant(0, dtype=tf.int32)

        def _has_frames_left(enc, step, hyps, num_frames, stopped):
            return tf.less(step, num_frames)

        def _decode_frame(enc, step, hyps, num_frames, stopped):
            frame = enc[:, step:step + 1]
            pred_out, next_states = self.predict_net(
                inputs=hyps[1][:, -1:],
                p_memory_states=hyps[2],
                training=False)
            joint = self.joint_net([frame, pred_out], training=False)
            # squeeze(axis=None) drops every size-1 axis. If the joint output is
            # [B, 1, 1, V] and B == 1, the batch axis disappears as well; the
            # reshape to [-1, 1] below restores a [B, 1] column in both cases.
            log_probs = tf.squeeze(tf.nn.log_softmax(joint), axis=None)
            best = tf.reshape(
                tf.argmax(log_probs, axis=-1, output_type=tf.int32), [-1, 1])

            hyps = Hypotheses(
                hyps[0] + 1,
                tf.concat([hyps[1], best], -1),
                next_states)

            # Count stop emissions per element; once all are non-zero, jump to the end.
            stopped += tf.cast(
                tf.equal(best, self.text_featurizer.stop), tf.float32)
            all_stopped = tf.reduce_all(tf.cast(stopped, tf.bool))
            next_step = tf.cond(
                all_stopped,
                true_fn=lambda: num_frames,
                false_fn=lambda: step + 1,
            )
            return enc, next_step, hyps, num_frames, stopped

        _, _, final_hyps, _, stopped = tf.while_loop(
            _has_frames_left,
            _decode_frame,
            loop_vars=(enc, first_step, init_hyps, num_frames, stopped),
            shape_invariants=(
                tf.TensorShape([None, None, None]),
                tf.TensorShape([]),
                Hypotheses(
                    tf.TensorShape([None]),
                    tf.TensorShape([None, None]),  # token history grows each step
                    tf.nest.map_structure(get_shape_invariants, init_hyps[-1])),
                tf.TensorShape([]),
                tf.TensorShape([None, 1]),
            ))

        return final_hyps[1]
Example #6
0
def compute_rnnt_loss_and_grad_helper(logits, labels, label_length, logit_length):
    """Compute the RNN-T loss together with its gradient w.r.t. ``logits``.

    The gradient is built explicitly rather than by autodiff, presumably so
    a caller can supply it through ``tf.custom_gradient`` (confirm at the
    call site).

    Args:
        logits: rank-4 tensor, unpacked as
            ``[batch_size, input_max_len, target_max_len, vocab_size]``.
        labels: rank-2 label ids, reshaped below to
            ``[batch_size, 1, target_max_len - 1, 1]``, i.e.
            ``[batch_size, target_max_len - 1]``. Must share dtype with the
            int32 ``tf.range`` tensors it is concatenated with.
        label_length: per-example label count (stacked with
            ``logit_length`` into ``[batch_size, 2]`` indices; int32 for the
            same reason).
        logit_length: per-example number of valid time steps.

    Returns:
        ``(loss, grads_logits)``. ``loss`` is ``-beta[:, 0, 0]``, one value
        per example. ``grads_logits`` has the same shape as ``logits``.

    The gradient layout places the blank gradient in channel 0 and scatters
    label gradients into channels ``1..vocab_size-1`` at ``labels - 1`` (see
    the final ``tf.concat``). This assumes the blank id is 0 and every label
    id is ``>= 1``.
    """
    batch_size, input_max_len, target_max_len, vocab_size = shape_list(logits)

    # Labels repeated for every time step, then one-hot: [B, T, U-1, V],
    # where T = input_max_len and U = target_max_len.
    one_hot_labels = tf.one_hot(tf.tile(tf.expand_dims(labels, axis=1),
                                        multiples=[1, input_max_len, 1]), depth=vocab_size)

    log_probs = tf.nn.log_softmax(logits)
    # transition_probs / extract_diagonals / forward_dp / backward_dp are defined
    # elsewhere. From their use below: blank_probs is [B, T, U] and truth_probs
    # is [B, T, U-1]. Presumably these are the log-probs of emitting blank / the
    # next label at each lattice node, and *_diags regroup them for the DP.
    # TODO confirm.
    blank_probs, truth_probs = transition_probs(one_hot_labels, log_probs)
    bp_diags = extract_diagonals(blank_probs)
    tp_diags = extract_diagonals(truth_probs)

    # Valid lattice region per example:
    #   label_mask       -> target positions 0..label_length  (label_length + 1 rows)
    #   small_label_mask -> target positions 0..label_length-1
    #   input_mask       -> time steps 0..logit_length-1
    #   small_input_mask -> time steps 0..logit_length-2
    label_mask = tf.expand_dims(tf.sequence_mask(
        label_length + 1, maxlen=target_max_len, dtype=tf.float32), axis=1)
    small_label_mask = tf.expand_dims(tf.sequence_mask(
        label_length, maxlen=target_max_len, dtype=tf.float32), axis=1)
    input_mask = tf.expand_dims(tf.sequence_mask(
        logit_length, maxlen=input_max_len, dtype=tf.float32), axis=2)
    small_input_mask = tf.expand_dims(tf.sequence_mask(
        logit_length - 1, maxlen=input_max_len, dtype=tf.float32), axis=2)
    mask = label_mask * input_mask
    # Masks sliced to match the shapes of the blank ([:, :-1, :]) and
    # truth ([:, :, :-1]) gradient terms below.
    grad_blank_mask = (label_mask * small_input_mask)[:, :-1, :]
    grad_truth_mask = (small_label_mask * input_mask)[:, :, :-1]

    alpha = forward_dp(bp_diags, tp_diags, batch_size, input_max_len, target_max_len) * mask

    # Final lattice node of each example: (t, u) = (logit_length - 1, label_length).
    indices = tf.stack([logit_length - 1, label_length], axis=1)
    # Blank log-prob at that final node; handed to backward_dp.
    blank_sl = tf.gather_nd(blank_probs, indices, batch_dims=1)

    beta = backward_dp(bp_diags, tp_diags, batch_size, input_max_len, target_max_len, label_length, logit_length,
                       blank_sl) * mask
    # Zero out NaNs coming from the backward pass before they reach the loss/gradients.
    beta = tf.where(tf.math.is_nan(beta), tf.zeros_like(beta), beta)
    # beta at the origin node; its negation is returned as the loss.
    final_state_probs = beta[:, 0, 0]

    # Compute gradients of loss w.r.t. blank log-probabilities.
    # The inner "* grad_blank_mask" zeroes the exponent at masked positions
    # (exp(0) = 1, keeping them finite). The outer mask then zeroes those
    # entries, which avoids inf * 0 = NaN.
    grads_blank = -tf.exp((alpha[:, :-1, :] + beta[:, 1:, :] - tf.reshape(final_state_probs,
                                                                          shape=[batch_size, 1, 1]) + blank_probs[:,
                                                                                                                  :-1,
                                                                                                                  :]) * grad_blank_mask) * grad_blank_mask
    # Pad the missing last time step with zeros so grads_blank is [B, T, U].
    grads_blank = tf.concat([grads_blank, tf.zeros(
        shape=(batch_size, 1, target_max_len))], axis=1)
    # The final blank at each example's last node gets a gradient of exactly -1.
    last_grads_blank = -1 * tf.scatter_nd(
        tf.concat([tf.reshape(tf.range(batch_size, dtype=tf.int32),
                              shape=[batch_size, 1]), indices], axis=1),
        tf.ones(batch_size, dtype=tf.float32), [batch_size, input_max_len, target_max_len])
    grads_blank = grads_blank + last_grads_blank

    # Compute gradients of loss w.r.t. truth log-probabilities.
    # Same double-masking trick as for grads_blank; result is [B, T, U-1].
    grads_truth = -tf.exp((alpha[:, :, :-1] + beta[:, :, 1:] - tf.reshape(final_state_probs, shape=[batch_size, 1,
                                                                                                    1]) + truth_probs) * grad_truth_mask) * grad_truth_mask

    # Compute gradients of loss w.r.t. activations.
    # Build scatter_idx of shape [B, T, U-1, 4]. Entry (b, t, u) holds the index
    # [b, t, u, labels[b, u] - 1]: a = u, b = label - 1, e = t, g = batch index.
    # NOTE(review): padded label positions must still hold ids >= 1. A padding
    # value of 0 yields index -1, which tf.scatter_nd rejects on CPU as out of
    # range (even though the matching update is masked to 0). Confirm the
    # padding convention.
    a = tf.tile(tf.reshape(tf.range(target_max_len - 1, dtype=tf.int32), shape=(1, 1, target_max_len - 1, 1)),
                multiples=[batch_size, 1, 1, 1])
    b = tf.reshape(labels - 1, shape=(batch_size, 1, target_max_len - 1, 1))
    c = tf.concat([a, b], axis=3)
    d = tf.tile(c, multiples=(1, input_max_len, 1, 1))
    e = tf.tile(tf.reshape(tf.range(input_max_len, dtype=tf.int32), shape=(1, input_max_len, 1, 1)),
                multiples=(batch_size, 1, target_max_len - 1, 1))
    f = tf.concat([e, d], axis=3)
    g = tf.tile(tf.reshape(tf.range(batch_size, dtype=tf.int32), shape=(batch_size, 1, 1, 1)),
                multiples=[1, input_max_len, target_max_len - 1, 1])
    scatter_idx = tf.concat([g, f], axis=3)
    # TODO - improve the part of code for scatter_idx computation.
    probs = tf.exp(log_probs)
    # Label gradients go into the non-blank channels; the blank channel is prepended below.
    grads_truth_scatter = tf.scatter_nd(scatter_idx, grads_truth,
                                        [batch_size, input_max_len, target_max_len, vocab_size-1])
    grads = tf.concat(
        [tf.reshape(grads_blank, shape=(batch_size, input_max_len, target_max_len, -1)), grads_truth_scatter], axis=3)
    # Chain rule through log_softmax: dL/dz_k = g_k - p_k * sum_j g_j.
    grads_logits = grads - probs * (tf.reduce_sum(grads, axis=3, keepdims=True))

    loss = -final_state_probs
    return loss, grads_logits
Example #7
0
    def perform_greedy(self,
                       features,
                       streaming: bool = False) -> tf.Tensor:
        """Greedy transducer decoding of a single utterance.

        For every encoder frame, the last token of the hypothesis goes through
        ``self.predict_net`` and is combined with the frame in
        ``self.joint_net``. The arg-max token is then handled as follows:
        - if it is neither ``self.text_featurizer.blank`` nor 0, it is
          appended, its log-prob is added to the score, and the new
          predictor state is kept;
        - otherwise the hypothesis is left unchanged.
        Exactly one frame is consumed per step.

        Args:
            features: input for ONE utterance. The encoder output is squeezed
                on axis 0, so a leading batch dimension of 1 is expected.
                Passed through ``self.mel_layer`` first when it is set.
            streaming: when True, the final hypothesis is stored in
                ``self.kept_hyps`` so the next call continues from it.

        Returns:
            int32 token ids with a leading axis of size 1 (``[1, L]``). They
            start with the start token, or with the kept hypothesis' tokens.

        NOTE(review): ``self.kept_hyps`` is reused whenever it is not None,
        even when ``streaming`` is False. A non-streaming call made after a
        streaming one therefore continues from the stored state. Confirm
        this is intended.
        """
        if self.mel_layer is not None:
            features = self.mel_layer(features)
        new_hyps = Hypotheses(
            tf.constant(0.0, dtype=tf.float32),
            self.text_featurizer.start * tf.ones([1], dtype=tf.int32),
            self.predict_net.get_initial_state(features)
        )

        if self.kept_hyps is not None:
            new_hyps = self.kept_hyps

        enc = self.encoder(features, training=False)  # presumably [1, T, E]
        enc = tf.squeeze(enc, axis=0)  # [T, E]

        T = tf.cast(shape_list(enc)[0], dtype=tf.int32)

        i = tf.constant(0, dtype=tf.int32)

        def _cond(enc, i, new_hyps, T):
            return tf.less(i, T)

        def _body(enc, i, new_hyps, T):
            hi = tf.reshape(enc[i], [1, 1, -1])  # [1, 1, E]
            y, n_memory_states = self.predict_net(
                inputs=tf.reshape(new_hyps[1][-1], [1, 1]),  # last emitted token, [1, 1]
                p_memory_states=new_hyps[2],
                training=False
            )
            # Joint output presumably [1, 1, 1, V]; squeezing every size-1 axis leaves [V].
            ytu = tf.nn.log_softmax(self.joint_net([hi, y], training=False))
            ytu = tf.squeeze(ytu, axis=None)
            n_predict = tf.argmax(ytu, axis=-1, output_type=tf.int32)  # scalar

            def return_no_blank():
                return Hypotheses(
                    new_hyps[0] + ytu[n_predict],
                    tf.concat([new_hyps[1], [n_predict]], axis=0),
                    n_memory_states,
                )

            # A Python `and` here would call bool() on a symbolic tensor, which
            # raises inside graph-mode tf.while_loop / tf.function. Combine the
            # two comparisons with tf ops instead; in eager mode the result is
            # the same.
            is_emission = tf.logical_and(
                tf.not_equal(n_predict, self.text_featurizer.blank),
                tf.not_equal(n_predict, 0),
            )
            hyps = tf.cond(
                is_emission,
                true_fn=return_no_blank,
                false_fn=lambda: new_hyps
            )

            return enc, i + 1, hyps, T

        _, _, new_hyps, _ = tf.while_loop(
            _cond,
            _body,
            loop_vars=(enc, i, new_hyps, T),
            shape_invariants=(
                tf.TensorShape([None, None]),
                tf.TensorShape([]),
                Hypotheses(
                    tf.TensorShape([]),
                    tf.TensorShape([None]),  # token history grows on each emission
                    tf.nest.map_structure(get_shape_invariants, new_hyps[-1])
                ),
                tf.TensorShape([])
            )
        )

        if streaming:
            self.kept_hyps = new_hyps

        return tf.expand_dims(new_hyps[1], axis=0)