def call(self, inputs, training=None, mask=None, cache=None, decode_loop_step=None):
    qx, sx = inputs
    if mask is not None:
        qmask, smask = mask
    else:
        qmask, smask = (None, None)
    q, q_mask = cm(self.q_layer, qx, training=training, mask=qmask)

    # In cross-attention we may compute KV once and perform attention on it in
    # subsequent steps.
    def get_kv():

        def do_get_kv():
            k, k_mask = cm(self.k_layer, sx, training=training, mask=smask)
            v, v_mask = cm(self.v_layer, sx, training=training, mask=smask)
            return k, k_mask, v, v_mask

        if self.cache_kv and cache_context.cache_context is not None:
            with cache_context.SubContext(self.name):
                if 'cached_kv' in cache_context.cache:
                    k, k_mask, v, v_mask = cache_context.cache['cached_kv']
                else:
                    k, k_mask, v, v_mask = do_get_kv()
                    cache_context.cache['cached_kv'] = k, k_mask, v, v_mask
        else:
            k, k_mask, v, v_mask = do_get_kv()
        return k, k_mask, v, v_mask

    if cache is not None:
        if 'k' not in cache:
            k, k_mask, v, v_mask = get_kv()
            # Update cache
            cache["k"] = k
            cache["v"] = v
            if mask is not None:
                cache["k_mask"] = k_mask
                cache["v_mask"] = v_mask
        else:
            k = cache["k"]
            v = cache["v"]
            if mask is not None:
                k_mask = cache["k_mask"]
                v_mask = cache["v_mask"]
    else:
        k, k_mask, v, v_mask = get_kv()

    if mask is not None:
        mask = [q_mask, tf.logical_and(k_mask, v_mask)]
    x, weights = self.attention_layer([q, k, v], mask=mask, training=training)
    if not self.skip_out:
        x = self.out_layer(x, mask=mask, training=training)
    if self.return_attn_weights:
        return x, weights
    return x
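# Illustrative sketch (not part of this module) of the cross-attention KV caching
# used above: the encoder-side K/V projections depend only on the fixed encoder
# output, so they are computed on the first decode step and reused afterwards.
# The helper name, the `k_proj`/`v_proj` layers, and the shapes below are
# assumptions for demonstration only.
import tensorflow as tf

def cached_cross_kv(cache, encoder_out, k_proj, v_proj):
    """Project K/V from `encoder_out` once, then serve them from `cache`."""
    if 'k' not in cache:
        cache['k'] = k_proj(encoder_out)
        cache['v'] = v_proj(encoder_out)
    return cache['k'], cache['v']

# Example: the second call performs no new projections.
_k_proj, _v_proj = tf.keras.layers.Dense(8), tf.keras.layers.Dense(8)
_cache, _enc = {}, tf.random.normal([2, 5, 8])            # [batch, src_len, depth]
k1, v1 = cached_cross_kv(_cache, _enc, _k_proj, _v_proj)  # fills the cache
k2, v2 = cached_cross_kv(_cache, _enc, _k_proj, _v_proj)  # reuses cached tensors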
def _call(self, inputs, training=None, cache=None, mask=None, **kwargs):
    max_req_len = max(self.first_kernel_size * self.dilation_rate,
                      self.second_kernel_size * self.dilation_rate)
    if cache is not None:
        inputs = tf.concat([cache['input_queue'], inputs], axis=1)
        if mask is not None:
            mask = tf.concat([cache['mask_queue'], mask], axis=1)
    x, x_mask = cm(self.dff_layer, inputs, training=training, mask=mask)
    if self.inner_gate:
        x, x_mask = cm(self.gating_layer, x, training=training, mask=mask)
    x, x_mask = cm(self.out_layer, x, training=training, mask=x_mask)
    if cache is not None:
        cache['input_queue'] = inputs[:, -max_req_len:, :]
        if mask is not None:
            cache['mask_queue'] = mask[:, -max_req_len:]
            mask = mask[:, -1:]
        return x[:, -1:, :], mask
    return x, mask
def _call(self, inputs, mask=None, cache=None, training=None, **kwargs):
    if 'decode_loop_step' in kwargs:
        _ = kwargs.pop('decode_loop_step')
    x = inputs
    for i, block in enumerate(self.blocks):
        if cache is not None:
            block_cache = cache[i]
        else:
            block_cache = None
        x, mask = cm(block, x, training=training, mask=mask, cache=block_cache,
                     **kwargs)
    return x
def _call(self, inputs, training=None, cache=None, mask=None, **kwargs):
    if cache is not None:
        inputs = tf.concat([cache['input_queue'], inputs], axis=1)
        if mask is not None:
            mask = tf.concat([cache['mask_queue'], mask], axis=1)
    x, _ = cm(self.processing_block, inputs=inputs, mask=mask, training=training,
              **kwargs)
    if cache is not None:
        cache['input_queue'] = inputs[:, -(self.kernel_size * self.dilation_rate):, :]
        if mask is not None:
            cache['mask_queue'] = mask[:, -(self.kernel_size * self.dilation_rate):]
            mask = mask[:, -1:]
        return x[:, -1:, :], mask
    return x, mask
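# Standalone sketch of the sliding-window cache pattern above: during incremental
# decoding with a causal convolution, keep the last kernel_size * dilation_rate
# input frames so the convolution still covers its receptive field, then emit only
# the newest position. All names and sizes here are illustrative assumptions, not
# this module's API.
import tensorflow as tf

kernel_size, dilation_rate, depth = 3, 2, 8
window = kernel_size * dilation_rate
conv = tf.keras.layers.Conv1D(depth, kernel_size, dilation_rate=dilation_rate,
                              padding='causal')
cache = {'input_queue': tf.zeros([1, window, depth])}     # [batch, window, depth]

def conv_decode_step(new_x):
    # Prepend the cached history, run the conv, slide the window, keep last step.
    x = tf.concat([cache['input_queue'], new_x], axis=1)
    cache['input_queue'] = x[:, -window:, :]
    return conv(x)[:, -1:, :]

y = conv_decode_step(tf.random.normal([1, 1, depth]))     # shape [1, 1, depth]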
def call_masked(self, inputs, training=None, mask=None, cache=None,
                decode_loop_step=None, **kwargs):
    if cache is None:
        cache = {}
    x = inputs
    for i, block in enumerate(self.blocks):
        # Per-block caches are keyed by block index; fall back to an empty dict
        # when a block has no entry yet.
        block_cache = cache.get(i, {})
        x, mask = cm(block, x, training=training, mask=mask, cache=block_cache,
                     decode_loop_step=decode_loop_step, **kwargs)
    return x, mask
def _call(self, inputs, training=None, mask=None, cache=None, decode_loop_step=None,
          pad_q_to_kv=False):
    if cache is None:
        cache = {}
    x, enc_in = inputs
    # LayerDrop factor: 0 drops this layer's sublayer outputs during training,
    # leaving only the residual paths; at inference it is always 1.
    ldf = tf.cast(
        tf.random.uniform([], maxval=1., dtype=inputs[0].dtype) > self.layerdrop_rate,
        inputs[0].dtype) if training else tf.convert_to_tensor(
            1., dtype=inputs[0].dtype)
    _x = x
    if mask is not None:
        x_mask, enc_mask = mask
    else:
        # Unsupported path; the assignments below only keep the variables defined
        # for AutoGraph.
        assert False
        x_mask = None
        enc_mask = None
    _x_mask = x_mask

    ## Self-attention
    if self.prenorm:
        x, x_mask = cm(self.norm0, x, training=training, mask=x_mask)
    x, x_mask = cm(self.sa_layer, x, training=training, mask=x_mask,
                   cache=cache.get('sa', None), decode_loop_step=decode_loop_step,
                   pad_q_to_kv=pad_q_to_kv)
    attn_weights_sa = None
    if isinstance(x, (list, tuple)):
        x, attn_weights_sa = x
    res, res_mask = cm(self.mha_skip_adapt, _x, training=training, mask=_x_mask)
    if x_mask is None:
        n_mask = res_mask
    elif res_mask is None:
        n_mask = x_mask
    else:
        n_mask = tf.math.logical_and(x_mask, res_mask)
    x = tfkl.Dropout(self.dropout_rate)(x, training=training)
    if self.prenorm:
        x = (ldf * x) + res
        x_mask = n_mask
    else:
        x, x_mask = cm(self.norm0, (ldf * x) + res, training=training, mask=n_mask)
    _x = x

    ## Attend to encoding
    if mask is not None:
        ca_mask = x_mask, enc_mask
    else:
        # Unsupported path; see above.
        assert False
        ca_mask = None
    res, res_mask = cm(self.mha_skip_adapt, _x, training=training, mask=x_mask)
    if self.prenorm:
        x, x_mask = cm(self.norm1, x, training=training, mask=x_mask)
    x, x_mask = cm(self.ca_layer, (x, enc_in), training=training, mask=ca_mask,
                   cache=cache.get('ca', None))
    attn_weights_ca = None
    if isinstance(x, (list, tuple)):
        x, attn_weights_ca = x
    if x_mask is None:
        n_mask = res_mask
    elif res_mask is None:
        n_mask = x_mask
    else:
        n_mask = tf.math.logical_and(x_mask, res_mask)
    x = tfkl.Dropout(self.dropout_rate)(x, training=training)
    if self.prenorm:
        x = (ldf * x) + res
        x_mask = n_mask
    else:
        x, x_mask = cm(self.norm1, (ldf * x) + res, training=training, mask=n_mask)
    _x = x

    res, res_mask = cm(self.ffn_skip_adapt, _x, training=training, mask=x_mask)

    ## FF-net
    if self.prenorm:
        x, x_mask = cm(self.norm2, x, training=training, mask=x_mask)
    f, f_mask = cm(self.ffn, x, training=training, mask=x_mask,
                   cache=cache.get('ffn', None))
    if f_mask is None:
        n_mask = res_mask
    elif res_mask is None:
        n_mask = f_mask
    else:
        n_mask = tf.math.logical_and(f_mask, res_mask)
    f = tfkl.Dropout(self.dropout_rate)(f, training=training)
    if self.prenorm:
        x = (ldf * f) + res
        mask = n_mask
    else:
        x, mask = cm(self.norm2, (ldf * f) + res, training=training, mask=n_mask)

    if attn_weights_ca is not None or attn_weights_sa is not None:
        return x, mask, {'sa': attn_weights_sa, 'ca': attn_weights_ca}
    return x, mask
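# Small helper sketch for the mask-combination pattern repeated above: fall back
# to whichever mask is defined, otherwise AND them together. The blocks above
# inline this logic; the helper name is illustrative only.
import tensorflow as tf

def combine_masks(a, b):
    if a is None:
        return b
    if b is None:
        return a
    return tf.math.logical_and(a, b)
# e.g. n_mask = combine_masks(x_mask, res_mask)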
def call(self, inputs, training=None, mask=None, cache=None, decode_loop_step=None):
    if cache is None:
        cache = {}
    with cache_context.SubContext(self.name):
        x = inputs
        # LayerDrop factor: 0 drops this layer's sublayer outputs during training,
        # leaving only the residual paths; at inference it is always 1.
        ldf = tf.cast(
            tf.random.uniform([], maxval=1., dtype=inputs.dtype) > self.layerdrop_rate,
            inputs.dtype) if training else tf.convert_to_tensor(1., dtype=inputs.dtype)

        ## Self-attention
        if self.prenorm:
            x, x_mask = cm(self.norm0, x, training=training, mask=mask)
        else:
            x_mask = mask
        x, x_mask = cm(self.sa_layer, x, training=training, mask=x_mask,
                       cache=cache.get('sa', None), decode_loop_step=decode_loop_step)
        res, res_mask = cm(self.mha_skip_adapt, inputs, training=training, mask=mask)
        if x_mask is None:
            n_mask = res_mask
        elif res_mask is None:
            n_mask = x_mask
        else:
            n_mask = tf.math.logical_and(x_mask, res_mask)
        x = tfkl.Dropout(self.dropout_rate)(x, training=training)
        if self.prenorm:
            x = (x * ldf) + res
            x_mask = n_mask
        else:
            x, x_mask = cm(self.norm0, (x * ldf) + res, training=training, mask=n_mask)

        ## FF-net
        res, res_mask = cm(self.ffn_skip_adapt, x, training=training, mask=x_mask)
        if self.prenorm:
            x, x_mask = cm(self.norm1, x, training=training, mask=x_mask)
        f, f_mask = cm(self.ffn, x, training=training, mask=x_mask,
                       cache=cache.get('ffn', None))
        if f_mask is None:
            n_mask = res_mask
        elif res_mask is None:
            n_mask = f_mask
        else:
            n_mask = tf.math.logical_and(f_mask, res_mask)
        f = tfkl.Dropout(self.dropout_rate)(f, training=training)
        if self.prenorm:
            x = (ldf * f) + res
        else:
            x, _ = cm(self.norm1, (ldf * f) + res, training=training, mask=n_mask)
        return x
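# Minimal sketch of the pre-norm residual structure with LayerDrop used by the
# blocks above: normalize, apply the sublayer, drop out, scale by a 0/1 LayerDrop
# factor, and add the skip path. Here the skip path is the identity, whereas the
# blocks above route it through a skip-adaptation layer; names and arguments are
# assumptions for demonstration only.
import tensorflow as tf

def prenorm_residual(x, sublayer, norm, dropout_rate, layerdrop_rate, training):
    keep = (tf.cast(tf.random.uniform([]) > layerdrop_rate, x.dtype)
            if training else tf.constant(1., dtype=x.dtype))
    h = sublayer(norm(x))
    h = tf.keras.layers.Dropout(dropout_rate)(h, training=training)
    return keep * h + x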
def call(self, inputs, training=None, mask=None, cache=None, decode_loop_step=None,
         pad_q_to_kv=False):
    x = inputs
    q, q_mask = cm(self.q_layer, x, training=training, mask=mask)
    k, k_mask = cm(self.k_layer, x, training=training, mask=mask)
    v, v_mask = cm(self.v_layer, x, training=training, mask=mask)
    if cache is not None:
        # Combine cached keys and values with new keys and values.
        if cache["k"] is not None:
            if decode_loop_step is not None:
                # Fixed-size cache: scatter the current step's K/V into position
                # `decode_loop_step` with a one-hot mask so the cache shape stays
                # static inside the decoding loop.
                cache_k_shape = cache["k"].shape.as_list()
                indices = tf.reshape(
                    tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=k.dtype),
                    [1, cache_k_shape[1], 1])
                k = cache["k"] + k * indices
                if mask is not None:
                    indices = tf.reshape(
                        tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=tf.float16),
                        [1, cache_k_shape[1]])
                    k_mask = tf.logical_or(
                        cache["k_mask"], (tf.cast(k_mask, tf.float16) * indices) > 0.)
                cache_v_shape = cache["v"].shape.as_list()
                indices = tf.reshape(
                    tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=v.dtype),
                    [1, cache_v_shape[1], 1])
                v = cache["v"] + v * indices
                if mask is not None:
                    indices = tf.reshape(
                        tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=tf.float16),
                        [1, cache_v_shape[1]])
                    v_mask = tf.logical_or(
                        cache["v_mask"], (tf.cast(v_mask, tf.float16) * indices) > 0.)
            else:
                # Growing cache: append the new keys/values along the time axis.
                k = tf.concat([tf.cast(cache["k"], k.dtype), k], axis=1)
                v = tf.concat([tf.cast(cache["v"], v.dtype), v], axis=1)
                if mask is not None:
                    k_mask = tf.concat(
                        [tf.cast(cache["k_mask"], k_mask.dtype), k_mask], axis=1)
                    v_mask = tf.concat(
                        [tf.cast(cache["v_mask"], v_mask.dtype), v_mask], axis=1)
        # Update cache
        cache["k"] = k
        cache["v"] = v
        if mask is not None:
            cache["k_mask"] = k_mask
            cache["v_mask"] = v_mask

    q_shape = t2t_common.shape_list(q)
    kv_shape = t2t_common.shape_list(k)
    if pad_q_to_kv:
        if q_shape[1] != kv_shape[1]:
            if decode_loop_step is not None:
                q_prepad = decode_loop_step
                q_postpad = (kv_shape[1] - q_shape[1]) - decode_loop_step
            else:
                q_prepad = (kv_shape[1] - q_shape[1])
                q_postpad = 0
            q = tf.pad(q, paddings=[[0, 0], [q_prepad, q_postpad], [0, 0]])
            if mask is not None:
                q_mask = tf.pad(q_mask, paddings=[[0, 0], [q_prepad, q_postpad]])
        else:
            # q_prepad must be assigned on every branch so AutoGraph sees a
            # consistent set of variables; it is unused in this branch.
            if decode_loop_step is not None:
                q_prepad = decode_loop_step
            else:
                q_prepad = (kv_shape[1] - q_shape[1])
    else:
        # Same AutoGraph bookkeeping as above; unused when pad_q_to_kv is False.
        if decode_loop_step is not None:
            q_prepad = decode_loop_step
        else:
            q_prepad = (kv_shape[1] - q_shape[1])

    if mask is not None:
        mask = [q_mask, tf.logical_and(k_mask, v_mask)]
    x, weights = self.attention_layer([q, k, v], mask=mask, training=training)
    if not self.skip_out:
        x = self.out_layer(x, mask=mask, training=training)

    x_shape = t2t_common.shape_list(x)
    if pad_q_to_kv:
        if q_shape[1] != kv_shape[1]:
            if decode_loop_step is not None:
                x = tf.slice(x, [0, q_prepad, 0], [x_shape[0], 1, x_shape[2]])
            else:
                x = tf.slice(x, [0, q_prepad, 0],
                             [x_shape[0], q_shape[1], x_shape[2]])
    if self.return_attn_weights:
        return x, weights
    return x
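# Standalone sketch of the one-hot scatter used above to write the current step's
# key/value projection into a pre-allocated cache at position `decode_loop_step`,
# so the cache keeps a static shape inside a decoding while-loop. Sizes below are
# illustrative.
import tensorflow as tf

batch, max_len, depth = 2, 8, 4
cache_k = tf.zeros([batch, max_len, depth])      # pre-allocated cache
new_k = tf.ones([batch, 1, depth])               # projection for the current step
decode_loop_step = 3

indices = tf.reshape(
    tf.one_hot(decode_loop_step, max_len, dtype=new_k.dtype),
    [1, max_len, 1])                             # 1.0 only at the current position
cache_k = cache_k + new_k * indices              # row `decode_loop_step` now holds new_k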