def call(self, inputs, inputs_padding, is_training=True):
    """ Encodes the inputs.

    Args:
        inputs: The embedded input, a float tensor with shape
            [batch_size, max_length, embedding_dim].
        inputs_padding: A float tensor with shape [batch_size, max_length],
            indicating the padding positions, where 1.0 for padding and
            0.0 for non-padding.
        is_training: A bool, whether in training mode or not.

    Returns:
        The encoded output with shape [batch_size, max_length, hidden_size].
        When `post_normalize` is enabled and `self._return_all_layers` is set,
        a list with each layer's output is returned instead.
    """
    all_layers = []
    # [batch_size, max_length], FLOAT_MIN for padding, 0.0 for non-padding
    self_attention_bias = layer_utils.input_padding_to_bias(inputs_padding)
    if self.get_config()["attention_monotonic"]:
        self_attention_bias = tf.minimum(
            tf.expand_dims(tf.expand_dims(self_attention_bias, axis=1), axis=1),
            layer_utils.lower_triangle_attention_bias(tf.shape(inputs)[1]))
    x = inputs
    if is_training:
        x = tf.nn.dropout(
            x, rate=self.get_config()["layer_postprocess_dropout_rate"])
    for idx, layer in enumerate(self._stacking_layers):
        x = layer(x, self_attention_bias, is_training=is_training)
        all_layers.append(x)
    if self.get_config()["post_normalize"]:
        if self._return_all_layers:
            return all_layers
        return x
    outputs = self.quant(self._output_norm_layer(x), name="output_ln")
    return outputs
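# Hedged usage sketch (not part of the original source): `encoder`, the shapes and
# the padding layout below are illustrative assumptions about how this `call`
# might be driven once the layer and the input embeddings are built elsewhere.
#
#     embedded = tf.random.normal([8, 20, 512])                      # [batch, max_length, embedding_dim]
#     padding = tf.concat([tf.zeros([8, 16]), tf.ones([8, 4])], 1)   # 1.0 marks padded positions
#     encoded = encoder(embedded, padding, is_training=False)        # [8, 20, hidden_size]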
def call(self, decoder_inputs, cache, is_training=True, decode_loop_step=None):
    """ Computes the decoder outputs.

    Args:
        decoder_inputs: The embedded decoder input, a float tensor with shape
            [batch_size, max_target_length, embedding_dim] or
            [batch_size, embedding_dim] for one decoding step.
        cache: A dictionary, generated from self.create_decoding_internal_cache.
        is_training: A bool, whether in training mode or not.
        decode_loop_step: An integer, step number of the decoding loop. Used only
            for autoregressive inference with static-shape cache.

    Returns:
        The decoder output with shape [batch_size, max_target_length, hidden_size]
        when `decoder_inputs` is a 3-d tensor or with shape
        [batch_size, hidden_size] when `decoder_inputs` is a 2-d tensor.
    """
    ori_ndims = decoder_inputs.get_shape().ndims
    if ori_ndims == 2:
        decoder_inputs = tf.expand_dims(decoder_inputs, axis=1)
    # decoder self attention has shape [1, 1, max_target_len, max_target_len]
    decoder_self_attention_bias = layer_utils.lower_triangle_attention_bias(
        tf.shape(decoder_inputs)[1])
    x = decoder_inputs
    if is_training:
        x = tf.nn.dropout(
            decoder_inputs,
            rate=self.get_config()["layer_postprocess_dropout_rate"])
    for idx, layer in enumerate(self._stacking_layers):
        selfatt_layer = layer[0]
        encdecatt_layer = layer[1]
        ffn_layer = layer[2]
        layer_name = "layer_{}".format(idx)
        layer_cache = (None if cache["decoding_states"] is None
                       else cache["decoding_states"][layer_name])
        selfatt_cache = (None if layer_cache is None
                         else layer_cache["self_attention"])
        with tf.name_scope(layer_name):
            # self attention layer
            x = selfatt_layer(
                x,  # x as query
                bias=decoder_self_attention_bias,
                cache=selfatt_cache,
                is_training=is_training,
                decode_loop_step=decode_loop_step)
            # enc-dec attention layer
            if encdecatt_layer is not None:
                x = encdecatt_layer(
                    x,  # x as query
                    memory=cache["memory"],  # None indicates self-attention
                    memory_bias=cache["memory_bias"],
                    is_training=is_training)
            # ffn
            x = ffn_layer(x, is_training=is_training)
    outputs = x
    if not self.get_config()["post_normalize"]:
        outputs = self.quant(self._output_norm_layer(x), name="output_ln")
    if ori_ndims == 2:
        outputs = tf.squeeze(outputs, axis=1)
    return outputs
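# Hedged usage sketch (illustrative only): `decoder`, `target_embeddings` and
# `step_embedding` are assumed names, and `cache` is assumed to come from
# `create_decoding_internal_cache`. Teacher-forced training passes the whole
# shifted target sequence as a 3-d tensor; autoregressive inference passes a
# single 2-d step together with `decode_loop_step` for the static-shape cache.
#
#     # training / scoring: [batch, max_target_length, embedding_dim]
#     dec_out = decoder(target_embeddings, cache, is_training=True)
#     # inference, one step at a time: [batch, embedding_dim]
#     step_out = decoder(step_embedding, cache, is_training=False, decode_loop_step=t)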
def incremental_encode(self, inputs, cache, time=None):
    """ Encoding function for streaming input.

    Args:
        inputs: The embedded input at time t, a float tensor with shape
            [batch, embedding_dim] or [batch, length, embedding_dim].
        cache: A dict containing cached tensors.
        time: The start time of the inputs.

    Returns:
        The incremental encoder output with shape [batch, t+1, dim],
        and the updated cache dict.
    """
    params = self.get_config()
    assert params["attention_monotonic"], (
        "function `incremental_encode` only available when attention_monotonic=True")
    if cache is None:
        cache = {}
    if len(cache) == 0:
        batch_size = tf.shape(inputs)[0]
        for lid in range(params["num_layers"]):
            cache[f"layer_{lid}"] = self._stacking_layers[lid].create_internal_cache()
        cache = tf.nest.map_structure(
            lambda ts: layer_utils.tile_tensor(ts, batch_size, axis=0), cache)
    if inputs.get_shape().ndims == 2:
        x = tf.expand_dims(inputs, axis=1)
        x_bias = None
    else:
        x = inputs
        if time is None:
            time = 0
        x_bias = layer_utils.lower_triangle_attention_bias(
            time + tf.shape(x)[1])[:, :, -tf.shape(x)[1]:]
    for idx, layer in enumerate(self._stacking_layers):
        layer_cache = None if cache is None else cache[f"layer_{idx}"]
        x = layer(x, x_bias, layer_cache, is_training=False)
    outputs = x
    if not params["post_normalize"]:
        outputs = self.quant(self._output_norm_layer(x), name="output_ln")
    return outputs, cache
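# Hedged usage sketch (illustrative only; `encoder` and `embedded_frames` are
# assumed names): a streaming caller feeds one embedded frame at a time and
# threads the returned cache through successive calls, so each call only
# attends to positions that have already been seen.
#
#     cache = {}
#     for t, frame in enumerate(embedded_frames):   # each frame: [batch, embedding_dim]
#         encoded, cache = encoder.incremental_encode(frame, cache, time=t)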
def call(self, decoder_inputs, cache, decode_lagging=None,
         is_training=True, decode_loop_step=None):
    """ Computes the decoder outputs with wait-k attention over the memory.

    Args:
        decoder_inputs: The embedded decoder input, a float tensor with shape
            [batch_size, max_target_length, embedding_dim] or
            [batch_size, embedding_dim] for one decoding step.
        cache: A dictionary, generated from self.create_decoding_internal_cache.
        decode_lagging: The wait-k lagging for streaming input when training.
            During inference, it is the lagging for the current step.
        is_training: A bool, whether in training mode or not.
        decode_loop_step: An integer, step number of the decoding loop. Used only
            for autoregressive inference with static-shape cache.

    Returns:
        The decoder output with shape [batch_size, max_target_length, hidden_size]
        when `decoder_inputs` is a 3-d tensor or with shape
        [batch_size, hidden_size] when `decoder_inputs` is a 2-d tensor.
    """
    ori_ndims = decoder_inputs.get_shape().ndims
    if ori_ndims == 2:
        decoder_inputs = tf.expand_dims(decoder_inputs, axis=1)
    memory_bias = cache.get("memory_bias", None)  # [batch, memory_length]
    if memory_bias is not None and decode_lagging is not None:
        if ori_ndims == 3:
            memory_bias = tf.minimum(
                tf.expand_dims(memory_bias, axis=1),
                tf.expand_dims(
                    layer_utils.waitk_attention_bias(
                        memory_length=tf.shape(memory_bias)[1],
                        query_length=tf.shape(decoder_inputs)[1],
                        waitk_lagging=decode_lagging),
                    axis=0))
        else:  # ori_ndims == 2
            memory_bias = tf.minimum(
                memory_bias,
                tf.expand_dims(
                    layer_utils.waitk_attention_bias(
                        memory_length=tf.shape(memory_bias)[1],
                        waitk_lagging=decode_lagging),
                    axis=0))
    # decoder self attention has shape [1, 1, max_target_len, max_target_len]
    decoder_self_attention_bias = layer_utils.lower_triangle_attention_bias(
        tf.shape(decoder_inputs)[1])
    x = decoder_inputs
    if is_training:
        x = tf.nn.dropout(
            decoder_inputs,
            rate=self.get_config()["layer_postprocess_dropout_rate"])
    for idx, layer in enumerate(self._stacking_layers):
        layer_cache = (None if cache["decoding_states"] is None
                       else cache["decoding_states"][f"layer_{idx}"])
        x = layer(x, decoder_self_attention_bias, layer_cache,
                  memory=cache.get("memory", None),
                  memory_bias=memory_bias,
                  is_training=is_training,
                  decode_loop_step=decode_loop_step)
    outputs = x
    if not self.get_config()["post_normalize"]:
        outputs = self.quant(self._output_norm_layer(x), name="output_ln")
    if ori_ndims == 2:
        outputs = tf.squeeze(outputs, axis=1)
    return outputs
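# Hedged usage sketch (illustrative only; `decoder`, `target_embeddings`,
# `step_embedding` and `current_lagging` are assumed names): for simultaneous
# (wait-k) translation the caller supplies `decode_lagging`, which restricts the
# cross-attention of each target position to a prefix of the source memory whose
# length grows with the position (the exact offset is determined by
# `layer_utils.waitk_attention_bias`).
#
#     # training with the full target sequence and a fixed lagging k
#     dec_out = decoder(target_embeddings, cache, decode_lagging=k, is_training=True)
#     # step-by-step inference: the lagging reflects how many source tokens were read
#     step_out = decoder(step_embedding, cache, decode_lagging=current_lagging,
#                        is_training=False, decode_loop_step=t)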
def test_lower_triangle_attention_bias():
    assert_equal_numpy(
        lower_triangle_attention_bias(5).numpy(),
        pt_lower_triangle_attention_bias(5).detach().numpy())
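# For reference (hedged, standard Transformer causal-mask convention): for length 3
# the bias is broadcastable to [1, 1, 3, 3], with 0.0 on and below the diagonal and
# a large negative value (FLOAT_MIN above) elsewhere, so after being added to the
# attention logits a query at position i cannot attend to positions j > i:
#
#     [[0.0,  -inf, -inf],
#      [0.0,   0.0, -inf],
#      [0.0,   0.0,  0.0]]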