Example #1
    def call(self,
             inputs,
             training=None,
             mask=None,
             cache=None,
             decode_loop_step=None):
        qx, sx = inputs
        if mask is not None:
            qmask, smask = mask
        else:
            qmask, smask = (None, None)
        q, q_mask = cm(self.q_layer, qx, training=training, mask=qmask)

        # In cross-attention the keys and values can be computed once and
        # reused for attention in subsequent decoding steps.

        def get_kv():
            def do_get_kv():
                k, k_mask = cm(self.k_layer, sx, training=training, mask=smask)
                v, v_mask = cm(self.v_layer, sx, training=training, mask=smask)
                return k, k_mask, v, v_mask

            if self.cache_kv and cache_context.cache_context is not None:
                with cache_context.SubContext(self.name):
                    if 'cached_kv' in cache_context.cache:
                        k, k_mask, v, v_mask = cache_context.cache['cached_kv']
                    else:
                        k, k_mask, v, v_mask = do_get_kv()
                        cache_context.cache['cached_kv'] = k, k_mask, v, v_mask
            else:
                k, k_mask, v, v_mask = do_get_kv()
            return k, k_mask, v, v_mask

        if cache is not None:
            if 'k' not in cache:
                k, k_mask, v, v_mask = get_kv()

                # Update cache
                cache["k"] = k
                cache["v"] = v
                if mask is not None:
                    cache["k_mask"] = k_mask
                    cache["v_mask"] = v_mask
            else:
                k = cache["k"]
                v = cache["v"]
                if mask is not None:
                    k_mask = cache["k_mask"]
                    v_mask = cache["v_mask"]
        else:
            k, k_mask, v, v_mask = get_kv()
        if mask is not None:
            mask = [q_mask, tf.logical_and(k_mask, v_mask)]
        x, weights = self.attention_layer([q, k, v],
                                          mask=mask,
                                          training=training)
        if not self.skip_out:
            x = self.out_layer(x, mask=mask, training=training)
        if self.return_attn_weights:
            return x, weights
        return x
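
Every example here relies on a helper named cm, whose definition is not included in the snippets. From the call sites it takes a layer, its inputs, the usual training/mask keyword arguments plus extras such as cache, and returns an (outputs, mask) pair. Below is a minimal, hypothetical sketch of such a helper, assuming layers either expose a call_masked method or follow standard Keras mask propagation; the fallback branch is an assumption, not part of the original code.

    def cm(layer, inputs, training=None, mask=None, **kwargs):
        # Hypothetical "call with mask" helper: run the layer and hand back
        # (outputs, propagated_mask) so callers can thread masks through.
        if hasattr(layer, 'call_masked'):
            # Blocks like the one in example #5 already return (outputs, mask).
            return layer.call_masked(inputs, training=training, mask=mask,
                                     **kwargs)
        outputs = layer(inputs, training=training, **kwargs)
        # Fall back to standard Keras mask propagation; this may return None.
        return outputs, layer.compute_mask(inputs, mask)
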
Example #2
    def _call(self, inputs, training=None, cache=None, mask=None, **kwargs):
        max_req_len = max(self.first_kernel_size * self.dilation_rate,
                          self.second_kernel_size * self.dilation_rate)
        if cache is not None:
            # Prepend the cached frames so the layer sees its full receptive
            # field during incremental decoding.
            inputs = tf.concat([cache['input_queue'], inputs], axis=1)
            if mask is not None:
                mask = tf.concat([cache['mask_queue'], mask], axis=1)
        x, x_mask = cm(self.dff_layer, inputs, training=training, mask=mask)
        if self.inner_gate:
            x, x_mask = cm(self.gating_layer, x, training=training, mask=mask)
        x, x_mask = cm(self.out_layer, x, training=training, mask=x_mask)
        if cache is not None:
            # Keep only the most recent frames needed for the next step and
            # return just the newly produced position.
            cache['input_queue'] = inputs[:, -max_req_len:, :]
            if mask is not None:
                cache['mask_queue'] = mask[:, -max_req_len:]
                mask = mask[:, -1:]
            return x[:, -1:, :], mask
        return x, mask
Example #3
    def _call(self, inputs, mask=None, cache=None, training=None, **kwargs):
        if 'decode_loop_step' in kwargs:
            _ = kwargs.pop('decode_loop_step')
        x = inputs
        for i, block in enumerate(self.blocks):
            if cache is not None:
                block_cache = cache[i]
            else:
                block_cache = None
            x, mask = cm(block, x, training=training, mask=mask,
                         cache=block_cache, **kwargs)
        return x
Example #4
    def _call(self, inputs, training=None, cache=None, mask=None, **kwargs):

        if cache is not None:
            inputs = tf.concat([cache['input_queue'], inputs], axis=1)
            if mask is not None:
                mask = tf.concat([cache['mask_queue'], mask], axis=1)

        x, _ = cm(self.processing_block, inputs=inputs, mask=mask, training=training, **kwargs)

        if cache is not None:
            cache['input_queue'] = inputs[:, -(self.kernel_size * self.dilation_rate):, :]
            if mask is not None:
                cache['mask_queue'] = mask[:, -(self.kernel_size * self.dilation_rate):]
                mask = mask[:, -1:]
            return x[:, -1:, :], mask
        return x, mask
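
Examples #2 and #4 implement the same incremental-decoding trick for convolution-style blocks: the cache keeps a rolling queue of the last kernel_size * dilation_rate input frames (and their mask), prepends it to the new input, and trims it again after the call. The sketch below is purely illustrative of how a caller might prime and drive such a cache; the tensor shapes, the block handle and next_frame are made-up stand-ins.

    import tensorflow as tf

    def init_conv_cache(batch_size, d_model):
        # Start with empty queues; the block trims them back to its receptive
        # field (kernel_size * dilation_rate frames) after every step.
        return {
            'input_queue': tf.zeros([batch_size, 0, d_model]),
            'mask_queue': tf.zeros([batch_size, 0], dtype=tf.bool),
        }

    # Illustrative decode loop:
    # cache = init_conv_cache(batch_size=2, d_model=256)
    # for t in range(num_steps):
    #     frame = next_frame(t)                       # [2, 1, 256]
    #     frame_mask = tf.ones([2, 1], dtype=tf.bool)
    #     y, y_mask = block(frame, training=False, mask=frame_mask, cache=cache)
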
Example #5
    def call_masked(self,
                    inputs,
                    training=None,
                    mask=None,
                    cache=None,
                    decode_loop_step=None,
                    **kwargs):
        if cache is None:
            cache = {}
        x = inputs
        for i, block in enumerate(self.blocks):
            if cache is not None:
                block_cache = cache.get(i, {})
            else:
                block_cache = None
            x, mask = cm(block,
                         x,
                         training=training,
                         mask=mask,
                         cache=block_cache,
                         decode_loop_step=decode_loop_step,
                         **kwargs)
        return x, mask
Example #6
    def _call(self,
              inputs,
              training=None,
              mask=None,
              cache=None,
              decode_loop_step=None,
              pad_q_to_kv=False):
        if cache is None:
            cache = {}
        x, enc_in = inputs
        ldf = tf.cast(
            tf.random.uniform([], maxval=1.,
                              dtype=inputs[0].dtype) > self.layerdrop_rate,
            inputs[0].dtype) if training else tf.convert_to_tensor(
                1., dtype=inputs[0].dtype)
        _x = x
        if mask is not None:
            x_mask, enc_mask = mask
        else:
            assert False
            x_mask = None
            enc_mask = None
        _x_mask = x_mask
        ## Self-attention
        if self.prenorm:
            x, x_mask = cm(self.norm0, x, training=training, mask=x_mask)
        x, x_mask = cm(self.sa_layer,
                       x,
                       training=training,
                       mask=x_mask,
                       cache=cache.get('sa', None),
                       decode_loop_step=decode_loop_step,
                       pad_q_to_kv=pad_q_to_kv)
        attn_weights_sa = None
        if isinstance(x, (list, tuple)):
            x, attn_weights_sa = x
        res, res_mask = cm(self.mha_skip_adapt,
                           _x,
                           training=training,
                           mask=_x_mask)
        if x_mask is None:
            n_mask = res_mask
        elif res_mask is None:
            n_mask = x_mask
        else:
            n_mask = tf.math.logical_and(x_mask, res_mask)
        x = tfkl.Dropout(self.dropout_rate)(x, training=training)
        if self.prenorm:
            x = (ldf * x) + res
            x_mask = n_mask
        else:
            x, x_mask = cm(self.norm0, (ldf * x) + res,
                           training=training,
                           mask=n_mask)
        _x = x
        ## Attend to encoding
        if mask is not None:
            ca_mask = x_mask, enc_mask
        else:
            assert False
            ca_mask = None

        res, res_mask = cm(self.mha_skip_adapt,
                           _x,
                           training=training,
                           mask=x_mask)
        if self.prenorm:
            x, x_mask = cm(self.norm1, x, training=training, mask=x_mask)

        x, x_mask = cm(self.ca_layer, (x, enc_in),
                       training=training,
                       mask=ca_mask,
                       cache=cache.get('ca', None))
        attn_weights_ca = None
        if isinstance(x, (list, tuple)):
            x, attn_weights_ca = x

        if x_mask is None:
            n_mask = res_mask
        elif res_mask is None:
            n_mask = x_mask
        else:
            n_mask = tf.math.logical_and(x_mask, res_mask)
        x = tfkl.Dropout(self.dropout_rate)(x, training=training)
        if self.prenorm:
            x = (ldf * x) + res
            x_mask = n_mask
        else:
            x, x_mask = cm(self.norm1, (ldf * x) + res,
                           training=training,
                           mask=n_mask)
        _x = x
        res, res_mask = cm(self.ffn_skip_adapt,
                           _x,
                           training=training,
                           mask=x_mask)
        ## FF-net
        if self.prenorm:
            x, x_mask = cm(self.norm2, x, training=training, mask=x_mask)

        f, f_mask = cm(self.ffn,
                       x,
                       training=training,
                       mask=x_mask,
                       cache=cache.get('ffn', None))
        if f_mask is None:
            n_mask = res_mask
        elif res_mask is None:
            n_mask = f_mask
        else:
            n_mask = tf.math.logical_and(f_mask, res_mask)
        f = tfkl.Dropout(self.dropout_rate)(f, training=training)
        if self.prenorm:
            x = (ldf * f) + res
            mask = n_mask
        else:
            x, mask = cm(self.norm2, (ldf * f) + res,
                         training=training,
                         mask=n_mask)
        if attn_weights_ca is not None or attn_weights_sa is not None:
            return x, mask, {'sa': attn_weights_sa, 'ca': attn_weights_ca}
        return x, mask
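
The None-aware mask merge that precedes each residual addition (after self-attention, cross-attention and the FFN) follows the same three-way pattern in examples #6 and #7. Purely as an illustration, it could be factored into a small helper like the one below; this helper is not part of the original code.

    import tensorflow as tf

    def merge_masks(a, b):
        # Combine two optional boolean masks: if either is missing, use the
        # other; otherwise a position stays valid only where both are True.
        if a is None:
            return b
        if b is None:
            return a
        return tf.math.logical_and(a, b)

    # e.g. n_mask = merge_masks(x_mask, res_mask)
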
Example #7
    def call(self,
             inputs,
             training=None,
             mask=None,
             cache=None,
             decode_loop_step=None):
        if cache is None:
            cache = {}
        with cache_context.SubContext(self.name):
            x = inputs
            ldf = tf.cast(
                tf.random.uniform([], maxval=1.,
                                  dtype=inputs.dtype) > self.layerdrop_rate,
                inputs.dtype) if training else tf.convert_to_tensor(
                    1., dtype=inputs.dtype)
            if self.prenorm:
                x, x_mask = cm(self.norm0, x, training=training, mask=mask)
            else:
                x_mask = mask
            x, x_mask = cm(self.sa_layer,
                           x,
                           training=training,
                           mask=x_mask,
                           cache=cache.get('sa', None),
                           decode_loop_step=decode_loop_step)
            res, res_mask = cm(self.mha_skip_adapt,
                               inputs,
                               training=training,
                               mask=mask)
            if x_mask is None:
                n_mask = res_mask
            elif res_mask is None:
                n_mask = x_mask
            else:
                n_mask = tf.math.logical_and(x_mask, res_mask)
            x = tfkl.Dropout(self.dropout_rate)(x, training=training)
            if self.prenorm:
                x = (x * ldf) + res
                x_mask = n_mask
            else:
                x, x_mask = cm(self.norm0, (x * ldf) + res,
                               training=training,
                               mask=n_mask)

            res, res_mask = cm(self.ffn_skip_adapt,
                               x,
                               training=training,
                               mask=x_mask)
            if self.prenorm:
                x, x_mask = cm(self.norm1, x, training=training, mask=x_mask)
            f, f_mask = cm(self.ffn,
                           x,
                           training=training,
                           mask=x_mask,
                           cache=cache.get('ffn', None))
            if f_mask is None:
                n_mask = res_mask
            elif res_mask is None:
                n_mask = f_mask
            else:
                n_mask = tf.math.logical_and(f_mask, res_mask)
            f = tfkl.Dropout(self.dropout_rate)(f, training=training)
            if self.prenorm:
                x = (ldf * f) + res
            else:
                x, _ = cm(self.norm1, (ldf * f) + res,
                          training=training,
                          mask=n_mask)
            return x
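
The ldf factor in examples #6 and #7 acts as a LayerDrop gate: during training it is 1 with probability 1 - layerdrop_rate and 0 otherwise, so the block's contribution (ldf * x) is occasionally dropped and only the residual passes through; at inference it is always 1. The same expression, pulled out into a standalone function purely for readability:

    import tensorflow as tf

    def layerdrop_factor(layerdrop_rate, dtype, training):
        # 1.0 keeps the block's contribution, 0.0 drops it (residual only).
        if training:
            keep = tf.random.uniform([], maxval=1., dtype=dtype) > layerdrop_rate
            return tf.cast(keep, dtype)
        return tf.convert_to_tensor(1., dtype=dtype)
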
Example #8
    def call(self,
             inputs,
             training=None,
             mask=None,
             cache=None,
             decode_loop_step=None,
             pad_q_to_kv=False):
        x = inputs
        q, q_mask = cm(self.q_layer, x, training=training, mask=mask)
        k, k_mask = cm(self.k_layer, x, training=training, mask=mask)
        v, v_mask = cm(self.v_layer, x, training=training, mask=mask)
        if cache is not None:
            # Combine cached keys and values with new keys and values.
            if cache["k"] is not None:
                # Update cache
                if decode_loop_step is not None:

                    cache_k_shape = cache["k"].shape.as_list()
                    indices = tf.reshape(
                        tf.one_hot(decode_loop_step,
                                   cache_k_shape[1],
                                   dtype=k.dtype), [1, cache_k_shape[1], 1])
                    k = cache["k"] + k * indices
                    if mask is not None:
                        indices = tf.reshape(
                            tf.one_hot(decode_loop_step,
                                       cache_k_shape[1],
                                       dtype=tf.float16),
                            [1, cache_k_shape[1]])
                        k_mask = tf.logical_or(
                            cache["k_mask"],
                            (tf.cast(k_mask, tf.float16) * indices) > 0.)

                    cache_v_shape = cache["v"].shape.as_list()
                    indices = tf.reshape(
                        tf.one_hot(decode_loop_step,
                                   cache_v_shape[1],
                                   dtype=v.dtype), [1, cache_v_shape[1], 1])
                    v = cache["v"] + v * indices
                    if mask is not None:
                        indices = tf.reshape(
                            tf.one_hot(decode_loop_step,
                                       cache_v_shape[1],
                                       dtype=tf.float16),
                            [1, cache_v_shape[1]])
                        v_mask = tf.logical_or(
                            cache["v_mask"],
                            (tf.cast(v_mask, tf.float16) * indices) > 0.)
                else:
                    k = tf.concat([tf.cast(cache["k"], k.dtype), k], axis=1)
                    v = tf.concat([tf.cast(cache["v"], v.dtype), v], axis=1)
                    if mask is not None:
                        k_mask = tf.concat(
                            [tf.cast(cache["k_mask"], k_mask.dtype), k_mask],
                            axis=1)
                        v_mask = tf.concat(
                            [tf.cast(cache["v_mask"], v_mask.dtype), v_mask],
                            axis=1)

            # Update cache
            cache["k"] = k
            cache["v"] = v
            if mask is not None:
                cache["k_mask"] = k_mask
                cache["v_mask"] = v_mask

        q_shape = t2t_common.shape_list(q)
        kv_shape = t2t_common.shape_list(k)

        if pad_q_to_kv:
            if q_shape[1] != kv_shape[1]:
                if decode_loop_step is not None:
                    q_prepad = decode_loop_step
                    q_postpad = (kv_shape[1] - q_shape[1]) - decode_loop_step

                else:
                    q_prepad = (kv_shape[1] - q_shape[1])
                    q_postpad = 0
                q = tf.pad(q, paddings=[[0, 0], [q_prepad, q_postpad], [0, 0]])
                if mask is not None:
                    q_mask = tf.pad(q_mask,
                                    paddings=[[0, 0], [q_prepad, q_postpad]])
            else:
                # q_prepad must be assigned on every branch so AutoGraph can
                # trace this function; it is not used in this case.
                if decode_loop_step is not None:
                    q_prepad = decode_loop_step
                else:
                    q_prepad = (kv_shape[1] - q_shape[1])
        else:
            # q_prepad must be assigned on every branch so AutoGraph can
            # trace this function; it is not used in this case.
            if decode_loop_step is not None:
                q_prepad = decode_loop_step
            else:
                q_prepad = (kv_shape[1] - q_shape[1])

        if mask is not None:
            mask = [q_mask, tf.logical_and(k_mask, v_mask)]
        x, weights = self.attention_layer([q, k, v],
                                          mask=mask,
                                          training=training)
        if not self.skip_out:
            x = self.out_layer(x, mask=mask, training=training)
        x_shape = t2t_common.shape_list(x)
        if pad_q_to_kv:
            if q_shape[1] != kv_shape[1]:
                if decode_loop_step is not None:
                    x = tf.slice(x, [0, q_prepad, 0],
                                 [x_shape[0], 1, x_shape[2]])
                else:
                    x = tf.slice(x, [0, q_prepad, 0],
                                 [x_shape[0], q_shape[1], x_shape[2]])
        if self.return_attn_weights:
            return x, weights
        return x
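
The decode_loop_step branch in example #8 assumes the k/v cache tensors are pre-allocated to the full decode length; the key/value computed for the current step is written into its slot by broadcasting it against a one-hot indicator over the time axis and adding. A tiny self-contained illustration of that update, with shapes made up for the demo:

    import tensorflow as tf

    max_len, depth = 4, 2
    cache_k = tf.zeros([1, max_len, depth])      # pre-allocated key cache
    new_k = tf.ones([1, 1, depth])               # key computed for this step
    decode_loop_step = 2

    # The one-hot vector selects the slot for the current step; the broadcast
    # multiply zeroes the new key everywhere else before the add.
    indices = tf.reshape(
        tf.one_hot(decode_loop_step, max_len, dtype=new_k.dtype),
        [1, max_len, 1])
    cache_k = cache_k + new_k * indices          # slot 2 now holds the new key
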
Example #9
    def do_get_kv():
        k, k_mask = cm(self.k_layer, sx, training=training, mask=smask)
        v, v_mask = cm(self.v_layer, sx, training=training, mask=smask)
        return k, k_mask, v, v_mask