# Assumed module-level imports for this excerpt (the surrounding module,
# including the `cache_context` module and the `cm` call-with-mask helper,
# is not shown here):
import tensorflow as tf

tfkl = tf.keras.layers


def compute_mask(self, inputs, mask=None):
  # Chains the mask through every block that supports masking; note that each
  # block's compute_mask sees the original inputs, so the chain depends only
  # on the masks, not on intermediate activations.
  with cache_context.SubContext(self.name):
    x = inputs
    for block in self.blocks:
      if block.supports_masking:
        mask = block.compute_mask(x, mask=mask)
    return mask

def get_kv():
  # Nested inside the enclosing call(); `self`, `sx`, `smask`, and `training`
  # are captured from that scope. Keys/values over the source sequence are
  # computed once per cache context and reused on later calls (e.g. across
  # decode steps, where the source sequence does not change).
  def do_get_kv():
    k, k_mask = cm(self.k_layer, sx, training=training, mask=smask)
    v, v_mask = cm(self.v_layer, sx, training=training, mask=smask)
    return k, k_mask, v, v_mask

  if self.cache_kv and cache_context.cache_context is not None:
    with cache_context.SubContext(self.name):
      if 'cached_kv' in cache_context.cache:
        k, k_mask, v, v_mask = cache_context.cache['cached_kv']
      else:
        k, k_mask, v, v_mask = do_get_kv()
        cache_context.cache['cached_kv'] = k, k_mask, v, v_mask
  else:
    k, k_mask, v, v_mask = do_get_kv()
  return k, k_mask, v, v_mask
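
# A minimal standalone sketch of the memoize-on-first-use pattern that
# get_kv() applies above, assuming only a plain dict as the cache backend.
# `DemoCache` and `expensive_kv` are hypothetical names for illustration; the
# real cache_context module is not part of this excerpt.
class DemoCache:

  def __init__(self):
    self.store = {}

  def get_or_compute(self, key, fn):
    # Compute once, then serve every later request from the cache. This is
    # why cross-attention K/V for a fixed source sequence can be reused
    # across decode steps instead of being recomputed at each step.
    if key not in self.store:
      self.store[key] = fn()
    return self.store[key]


def _demo_kv_cache():
  cache = DemoCache()
  calls = []

  def expensive_kv():
    calls.append(1)
    return ('k', 'v')

  first = cache.get_or_compute('cached_kv', expensive_kv)
  second = cache.get_or_compute('cached_kv', expensive_kv)
  assert first is second and len(calls) == 1  # computed exactly once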

def compute_mask(self, inputs, mask=None):
  with cache_context.SubContext(self.name):
    x = inputs
    # Self-attention branch mask.
    if self.sa_layer.supports_masking:
      x_mask = self.sa_layer.compute_mask(x, mask=mask)
    else:
      x_mask = mask
    # Residual (skip-adapter) branch mask.
    if self.mha_skip_adapt.supports_masking:
      res_mask = self.mha_skip_adapt.compute_mask(inputs, mask=mask)
    else:
      res_mask = mask
    # None-aware intersection: a None mask means all positions are valid.
    if x_mask is None:
      n_mask = res_mask
    elif res_mask is None:
      n_mask = x_mask
    else:
      n_mask = tf.math.logical_and(x_mask, res_mask)
    if self.norm0.supports_masking:
      x_mask = self.norm0.compute_mask(x, mask=n_mask)
    else:
      x_mask = n_mask
    # FFN branch mask.
    if self.ffn.supports_masking:
      f_mask = self.ffn.compute_mask(x, mask=x_mask)
    else:
      f_mask = x_mask
    if self.ffn_skip_adapt.supports_masking:
      res_mask = self.ffn_skip_adapt.compute_mask(x, mask=x_mask)
    else:
      res_mask = x_mask
    # Same None-aware intersection for the FFN/residual pair.
    if f_mask is None:
      n_mask = res_mask
    elif res_mask is None:
      n_mask = f_mask
    else:
      n_mask = tf.math.logical_and(f_mask, res_mask)
    if self.norm1.supports_masking:
      mask = self.norm1.compute_mask(x, mask=n_mask)
    else:
      mask = n_mask
    return mask
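
# The None-aware mask intersection above is repeated for every residual pair
# in this file. A hypothetical helper (not part of the original code) that
# captures the rule, using the `tf` import assumed at the top of the excerpt:
def _combine_masks(a, b):
  # A None mask means "all positions valid", so the intersection with None is
  # just the other mask; otherwise take the elementwise logical AND.
  if a is None:
    return b
  if b is None:
    return a
  return tf.math.logical_and(a, b)
  # e.g. n_mask = _combine_masks(x_mask, res_mask)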

def call(self, inputs, training=None, mask=None, cache=None,
         decode_loop_step=None, **kwargs):
  if cache is None:
    cache = {}
  with cache_context.SubContext(self.name):
    x = inputs
    for i, block in enumerate(self.blocks):
      # `cache` is never None here (it was just defaulted to {}), so each
      # block simply gets its per-index sub-cache. Note that the {} fallback
      # is never written back into `cache`, so for cached state to persist
      # across calls the caller must pre-populate cache[i].
      block_cache = cache.get(i, {})
      x = block(x, training=training, mask=mask, cache=block_cache,
                decode_loop_step=decode_loop_step, **kwargs)
    return x
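
# Sketch of the nested cache layout the stack's call() above expects: one
# entry per block index, which each block splits into its own sub-caches
# ('sa' and 'ffn', as read by the block-level call() below). The keys are
# inferred from this excerpt; the inner dict contents are filled in by the
# attention/FFN layers themselves. `_init_stack_cache` is a hypothetical
# helper for illustration.
def _init_stack_cache(num_blocks):
  # Pre-populating cache[i] is what lets cached state survive across calls,
  # since the stack's call() only ever reads cache.get(i, {}).
  return {i: {'sa': {}, 'ffn': {}} for i in range(num_blocks)}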

def call(self, inputs, training=None, mask=None, cache=None,
         decode_loop_step=None):
  if cache is None:
    cache = {}
  with cache_context.SubContext(self.name):
    x = inputs
    # LayerDrop factor: during training, keep this block's residual branches
    # with probability (1 - layerdrop_rate); at inference, always keep them.
    if training:
      ldf = tf.cast(
          tf.random.uniform([], maxval=1., dtype=inputs.dtype)
          > self.layerdrop_rate,
          inputs.dtype)
    else:
      ldf = tf.convert_to_tensor(1., dtype=inputs.dtype)
    # Self-attention sub-block.
    if self.prenorm:
      x, x_mask = cm(self.norm0, x, training=training, mask=mask)
    else:
      x_mask = mask
    x, x_mask = cm(self.sa_layer, x, training=training, mask=x_mask,
                   cache=cache.get('sa', None),
                   decode_loop_step=decode_loop_step)
    res, res_mask = cm(self.mha_skip_adapt, inputs, training=training,
                       mask=mask)
    # None-aware intersection of the attention and residual masks.
    if x_mask is None:
      n_mask = res_mask
    elif res_mask is None:
      n_mask = x_mask
    else:
      n_mask = tf.math.logical_and(x_mask, res_mask)
    # A fresh Dropout layer per call; Dropout is stateless, so this behaves
    # like a persistent layer.
    x = tfkl.Dropout(self.dropout_rate)(x, training=training)
    if self.prenorm:
      x = (x * ldf) + res
      x_mask = n_mask
    else:
      x, x_mask = cm(self.norm0, (x * ldf) + res, training=training,
                     mask=n_mask)
    # Feed-forward sub-block.
    res, res_mask = cm(self.ffn_skip_adapt, x, training=training, mask=x_mask)
    if self.prenorm:
      x, x_mask = cm(self.norm1, x, training=training, mask=x_mask)
    f, f_mask = cm(self.ffn, x, training=training, mask=x_mask,
                   cache=cache.get('ffn', None))
    # None-aware intersection of the FFN and residual masks.
    if f_mask is None:
      n_mask = res_mask
    elif res_mask is None:
      n_mask = f_mask
    else:
      n_mask = tf.math.logical_and(f_mask, res_mask)
    f = tfkl.Dropout(self.dropout_rate)(f, training=training)
    if self.prenorm:
      x = (ldf * f) + res
    else:
      x, _ = cm(self.norm1, (ldf * f) + res, training=training, mask=n_mask)
    return x
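
# Standalone illustration of the LayerDrop factor `ldf` used above: with
# probability layerdrop_rate the entire residual branch is zeroed during
# training, so the block reduces to its skip-adapter path; at inference the
# branch is always kept. No rescaling is applied, matching the code above.
# `layerdrop_factor` is a hypothetical name for this sketch.
def layerdrop_factor(layerdrop_rate, dtype, training):
  if training:
    # uniform sample in [0, 1): keep the branch iff the sample exceeds the
    # drop rate, i.e. drop with probability layerdrop_rate.
    keep = tf.random.uniform([], maxval=1., dtype=dtype) > layerdrop_rate
    return tf.cast(keep, dtype)
  return tf.convert_to_tensor(1., dtype=dtype)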