def call(self, inputs, inputs_padding, is_training=True):
    """ Encodes the inputs.

    Args:
        inputs: The embedded input, a float tensor with shape
            [batch_size, max_length, embedding_dim].
        inputs_padding: A float tensor with shape [batch_size, max_length]
            indicating the padding positions, where 1.0 marks a padding
            position and 0.0 a non-padding position.
        is_training: A bool, whether in training mode or not.

    Returns:
        The encoded output with shape [batch_size, max_length, hidden_size].
    """
    # [batch_size, max_length], FLOAT_MIN for padding, 0.0 for non-padding
    all_layers = []
    self_attention_bias = layer_utils.input_padding_to_bias(inputs_padding)
    if self.get_config()["attention_monotonic"]:
        self_attention_bias = tf.minimum(
            tf.expand_dims(tf.expand_dims(self_attention_bias, axis=1), axis=1),
            layer_utils.lower_triangle_attention_bias(tf.shape(inputs)[1]))
    x = inputs
    if is_training:
        x = tf.nn.dropout(
            x, rate=self.get_config()["layer_postprocess_dropout_rate"])
    for idx, layer in enumerate(self._stacking_layers):
        x = layer(x, self_attention_bias, is_training=is_training)
        all_layers.append(x)
    if self.get_config()["post_normalize"]:
        if self._return_all_layers:
            return all_layers
        return x
    outputs = self.quant(self._output_norm_layer(x), name="output_ln")
    return outputs
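
# A minimal, self-contained sketch of the two bias helpers the encoder above
# relies on (`layer_utils.input_padding_to_bias` and
# `layer_utils.lower_triangle_attention_bias`). The real implementations live
# in the library; the versions below are illustrative re-implementations, and
# the constant `FLOAT_MIN` is an assumed stand-in for the library's
# large-negative value.
import tensorflow as tf

FLOAT_MIN = -1e9  # assumed stand-in for the library's FLOAT_MIN constant


def input_padding_to_bias_sketch(inputs_padding):
    # 1.0 for padding -> FLOAT_MIN, 0.0 for non-padding -> 0.0
    return inputs_padding * FLOAT_MIN


def lower_triangle_attention_bias_sketch(length):
    # [1, 1, length, length]: 0.0 where key position <= query position,
    # FLOAT_MIN elsewhere, which forbids attending to future positions.
    valid = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
    return (1.0 - valid)[tf.newaxis, tf.newaxis, :, :] * FLOAT_MIN


padding = tf.constant([[0., 0., 1.]])                    # last position is padding
pad_bias = input_padding_to_bias_sketch(padding)         # [1, 3]
# Same combination as in the monotonic branch of `call` above.
combined = tf.minimum(pad_bias[:, tf.newaxis, tf.newaxis, :],
                      lower_triangle_attention_bias_sketch(3))
print(combined.shape)  # (1, 1, 3, 3)
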
def call(self, inputs, inputs_padding, is_training=True):
    """ Encodes the inputs.

    Args:
        inputs: The embedded input, a float tensor with shape
            [batch_size, max_length, embedding_dim].
        inputs_padding: A float tensor with shape [batch_size, max_length]
            indicating the padding positions, where 1.0 marks a padding
            position and 0.0 a non-padding position.
        is_training: A bool, whether in training mode or not.

    Returns:
        The encoded output with shape [batch_size, max_length, hidden_size].
    """
    # [batch_size, max_length], FLOAT_MIN for padding, 0.0 for non-padding
    self_attention_bias = layer_utils.input_padding_to_bias(inputs_padding)
    x = inputs
    if is_training:
        x = tf.nn.dropout(
            x, rate=self.get_config()["layer_postprocess_dropout_rate"])
    for idx, layer in enumerate(self._stacking_layers):
        self_attention_layer = layer[0]
        ffn_layer = layer[1]
        with tf.name_scope("layer_{}".format(idx)):
            # self attention layer
            x = self_attention_layer(
                x,  # x as query
                bias=self_attention_bias,
                is_training=is_training)
            # ffn
            x = ffn_layer(x, is_training=is_training)
    return self._output_norm_layer(x)
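
# Hedged usage sketch for the encoder `call` methods above: it only shows how
# the two arguments are typically prepared. `embedding` (a token-embedding
# layer), `encoder` (an already-built instance of the encoder class), and
# `PAD_ID` are assumptions for illustration, not names taken from the library.
import tensorflow as tf

PAD_ID = 0  # assumed padding token id
token_ids = tf.constant([[5, 8, 2, PAD_ID, PAD_ID]])               # [batch, max_length]
inputs_padding = tf.cast(tf.equal(token_ids, PAD_ID), tf.float32)  # 1.0 at padding
# inputs = embedding(token_ids)                # [batch, max_length, embedding_dim]
# encoded = encoder(inputs, inputs_padding, is_training=False)
# encoded.shape -> [batch, max_length, hidden_size]
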
def create_decoding_internal_cache(self,
                                   encoder_outputs,
                                   encoder_inputs_padding,
                                   is_inference=False,
                                   decode_padded_length=None):
    """ Creates internal cache for decoding.

    Args:
        encoder_outputs: The output tensor from the encoder, with shape
            [batch_size, max_input_length, hidden_size].
        encoder_inputs_padding: A float tensor with shape
            [batch_size, max_input_length] indicating the padding positions of
            `encoder_outputs`, where 1.0 marks a padding position and 0.0 a
            non-padding position.
        is_inference: A boolean scalar, whether in inference mode or not.
        decode_padded_length: The maximum decoding length at inference time,
            used to create a static-shape cache.

    Returns:
        `cache`, a dictionary containing static tensors (e.g. the encoder hidden
        states for attention) and dynamic tensors (e.g. the transformer decoding
        cache) that are used during decoding and passed to `call()`. Note that
        the dynamic tensors must be stored in cache["decoding_states"] for beam
        search use.
    """
    # [batch_size, max_length], FLOAT_MIN for padding, 0.0 for non-padding
    if is_inference:
        params = self.get_config()
        decoding_states = {}
        batch_size = tf.shape(encoder_outputs)[0]
        num_heads = params["num_attention_heads"]
        num_units_per_head = params["hidden_size"] // num_heads
        # initialize decoder self attention keys/values
        for lid in range(params["num_layers"]):
            # Ensure shape invariance for tf.while_loop.
            decoding_states["layer_{}".format(lid)] = {
                "self_attention": {
                    "keys": tf.zeros(
                        [batch_size, decode_padded_length or 0,
                         num_heads, num_units_per_head],
                        dtype=compat.CUSTOM_GLOBAL_FLOATX),
                    "values": tf.zeros(
                        [batch_size, decode_padded_length or 0,
                         num_heads, num_units_per_head],
                        dtype=compat.CUSTOM_GLOBAL_FLOATX)},
            }
    else:
        decoding_states = None
    cache = dict(decoding_states=decoding_states)
    if self._with_encoder_decoder_attention:
        cache["memory"] = encoder_outputs
        cache["memory_bias"] = layer_utils.input_padding_to_bias(
            encoder_inputs_padding)
    return cache
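
# Sketch of why the self-attention cache above starts with time length 0 (or
# `decode_padded_length` for static shapes): at every incremental decoding step
# the new key/value of shape [batch, 1, num_heads, units_per_head] is appended
# along axis 1. The update loop below is an illustrative stand-in, not the
# library's attention-layer code.
import tensorflow as tf

batch_size, num_heads, units_per_head = 2, 4, 16
cache = {"keys": tf.zeros([batch_size, 0, num_heads, units_per_head]),
         "values": tf.zeros([batch_size, 0, num_heads, units_per_head])}

for step in range(3):
    new_k = tf.random.normal([batch_size, 1, num_heads, units_per_head])
    new_v = tf.random.normal([batch_size, 1, num_heads, units_per_head])
    cache["keys"] = tf.concat([cache["keys"], new_k], axis=1)
    cache["values"] = tf.concat([cache["values"], new_v], axis=1)

print(cache["keys"].shape)  # (2, 3, 4, 16)
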
def create_decoding_internal_cache(self,
                                   encoder_outputs,
                                   encoder_inputs_padding,
                                   is_inference=False,
                                   decode_padded_length=None):
    """ Creates internal cache for decoding.

    Args:
        encoder_outputs: The output tensor from the encoder, with shape
            [batch_size, max_input_length, hidden_size].
        encoder_inputs_padding: A float tensor with shape
            [batch_size, max_input_length] indicating the padding positions of
            `encoder_outputs`, where 1.0 marks a padding position and 0.0 a
            non-padding position.
        is_inference: A boolean scalar, whether in inference mode or not.
        decode_padded_length: The maximum decoding length at inference time,
            used to create a static-shape cache.

    Returns:
        `cache`, a dictionary containing static tensors (e.g. the encoder hidden
        states for attention) and dynamic tensors (e.g. the transformer decoding
        cache) that are used during decoding and passed to `call()`. Note that
        the dynamic tensors must be stored in cache["decoding_states"] for beam
        search use.
    """
    # [batch_size, max_length], FLOAT_MIN for padding, 0.0 for non-padding
    enc_dec_attention_bias = layer_utils.input_padding_to_bias(
        encoder_inputs_padding)
    if is_inference:
        params = self.get_config()
        decoding_states = {}
        batch_size = tf.shape(encoder_outputs)[0]
        # initialize decoder conv hidden states
        for lid in range(params["num_layers"]):
            # Ensure shape invariance for tf.while_loop.
            if decode_padded_length is None:
                init_len = params["conv_kernel_size_list"][lid] - 1
            else:
                init_len = params["conv_kernel_size_list"][lid] - 1 + decode_padded_length
            decoding_states["layer_{}".format(lid)] = {
                "light_conv": {
                    "conv": tf.zeros(
                        [batch_size, init_len, params["conv_hidden_size"]],
                        dtype=compat.CUSTOM_GLOBAL_FLOATX)},
            }
    else:
        decoding_states = None
    cache = dict(decoding_states=decoding_states,
                 memory=encoder_outputs,
                 memory_bias=enc_dec_attention_bias)
    return cache
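
# Sketch of the rolling `light_conv` cache update implied above: with kernel
# size k, the cache holds the k-1 most recent hidden states (hence the initial
# zeros of length k-1), and at each step the window slides forward by one. This
# is an illustrative stand-in for the decoder's incremental convolution, not
# the library's implementation.
import tensorflow as tf

batch_size, conv_hidden_size, kernel_size = 2, 8, 3
conv_cache = tf.zeros([batch_size, kernel_size - 1, conv_hidden_size])

for step in range(4):
    new_state = tf.random.normal([batch_size, 1, conv_hidden_size])
    window = tf.concat([conv_cache, new_state], axis=1)   # [batch, k, hidden]
    # ... apply the lightweight convolution on `window` here ...
    conv_cache = window[:, 1:, :]                          # keep the last k-1 states

print(conv_cache.shape)  # (2, 2, 8)
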
def create_decoding_internal_cache(self,
                                   encoder_outputs,
                                   encoder_inputs_padding,
                                   is_inference=False,
                                   decode_padded_length=None):
    """ Creates internal cache for decoding.

    Args:
        encoder_outputs: The output tensor from the encoder, with shape
            [batch_size, max_input_length, hidden_size].
        encoder_inputs_padding: A float tensor with shape
            [batch_size, max_input_length] indicating the padding positions of
            `encoder_outputs`, where 1.0 marks a padding position and 0.0 a
            non-padding position.
        is_inference: A boolean scalar, whether in inference mode or not.
        decode_padded_length: The maximum decoding length at inference time,
            used to create a static-shape cache.

    Returns:
        `cache`, a dictionary containing static tensors (e.g. the encoder hidden
        states for attention) and dynamic tensors (e.g. the transformer decoding
        cache) that are used during decoding and passed to `call()`. Note that
        the dynamic tensors must be stored in cache["decoding_states"] for beam
        search use.
    """
    # [batch_size, max_length], FLOAT_MIN for padding, 0.0 for non-padding
    if is_inference:
        decoding_states = {}
        batch_size = tf.shape(encoder_outputs)[0]
        # initialize decoder self attention keys/values
        for lid, layer in enumerate(self._stacking_layers):
            # Ensure shape invariance for tf.while_loop.
            decoding_states[f"layer_{lid}"] = layer.create_decoding_internal_cache(
                decode_padded_length)
        decoding_states = tf.nest.map_structure(
            lambda ts: tile_tensor(ts, batch_size, axis=0), decoding_states)
        for lid, layer in enumerate(self._stacking_layers):
            decoding_states[f"layer_{lid}"].update(
                layer.memorize_memory(encoder_outputs))
    else:
        decoding_states = None
    cache = dict(decoding_states=decoding_states)
    if encoder_inputs_padding is not None:
        cache["memory"] = encoder_outputs
        cache["memory_bias"] = layer_utils.input_padding_to_bias(
            encoder_inputs_padding)
    return cache
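
# Sketch of the `tile_tensor` broadcast used above: the per-layer caches are
# presumably created batch-agnostically (leading dimension 1) and then tiled to
# the runtime batch size with `tf.nest.map_structure`. `tile_tensor_sketch` is
# an assumed minimal equivalent of the library helper, shown only to make the
# shape handling concrete.
import tensorflow as tf


def tile_tensor_sketch(tensor, size, axis=0):
    # Repeat `tensor` `size` times along `axis`, leaving other axes untouched.
    multiples = [1] * len(tensor.shape)
    multiples[axis] = size
    return tf.tile(tensor, multiples)


decoding_states = {"layer_0": {"keys": tf.zeros([1, 0, 4, 16]),
                               "values": tf.zeros([1, 0, 4, 16])}}
decoding_states = tf.nest.map_structure(
    lambda ts: tile_tensor_sketch(ts, 8, axis=0), decoding_states)
print(decoding_states["layer_0"]["keys"].shape)  # (8, 0, 4, 16)
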