from typing import Tuple

import tensorflow as tf

# The layer combinators used below (Sequential, Select, Dropout, Conv1d, Dense,
# Add, SpatialDropout1D, AttentionMask, Scale, PositionalEncoding, Scope,
# SelfMultiheadAttention, Normalization, SliceLastPadded) and the `base` module
# are provided by the surrounding library.


def FeedForward(
    inputs: str,
    outputs: str,
    units_inner: int,
    units_readout: int,
    dim: int,
    dropout_rate: float,
):
    """Position-wise feed-forward block with dropout and a residual connection.

    Two kernel-size-1 convolutions (ReLU then linear) are followed by a Dense
    projection back to `dim`; the block input is added back at the end.
    """
    if inputs == "_x":
        raise ValueError(
            "Cannot use name '_x' for inputs (used as intermediary node).")
    return Sequential(
        Select(inputs=inputs, outputs="_x"),
        Dropout(inputs="_x", outputs="_x", dropout_rate=dropout_rate),
        # Inner projection with ReLU (position-wise, kernel_size=1).
        Conv1d(
            inputs="_x",
            outputs="_x",
            filters=units_inner,
            kernel_size=1,
            activation=tf.nn.relu,
            use_bias=True),
        Dropout(inputs="_x", outputs="_x", dropout_rate=dropout_rate),
        # Linear readout projection (position-wise, kernel_size=1).
        Conv1d(
            inputs="_x",
            outputs="_x",
            filters=units_readout,
            kernel_size=1,
            activation=None,
            use_bias=True),
        Dropout(inputs="_x", outputs="_x", dropout_rate=dropout_rate),
        Dense(inputs="_x", outputs="_x", units=dim),
        # Residual connection back to the block input.
        Add(inputs=(inputs, "_x"), outputs=outputs),
    )
def Transformer(
    dim: int,
    num_heads: int = 4,
    encoding_blocks: int = 2,
    dim_head: int = 128,
    residual_connection: bool = True,
    use_layer_normalization: bool = True,
    event_dropout_rate: float = 0.0,
    use_feedforward: bool = True,
    ff_dropout_rate: float = 0.0,
    ff_normalization: bool = False,
    scale: bool = False,
    use_positional_encoding: bool = True,
    trainable_positional_encoding: bool = True,
    use_look_ahead_mask: bool = True,
    inputs: Tuple[str, str] = ("inputEmbeddings", "inputMask"),
    outputs: str = "userEmbeddings",
) -> base.Layer:
    """Transformer model.

    Encodes a padded sequence of input embeddings with `encoding_blocks`
    self-attention blocks and returns the encoding at the last non-padded
    position of each sequence.
    """
    return Sequential(
        Select(n_in=2, inputs=inputs,
               outputs=("inputEmbeddings", "inputMask")),
        # Randomly drop whole events (timesteps) from the input sequence.
        SpatialDropout1D(
            inputs="inputEmbeddings",
            outputs="inputEmbeddingsDropout",
            dropout_rate=event_dropout_rate),
        # Build the attention mask (optionally causal / look-ahead).
        AttentionMask(
            inputs="inputMask",
            outputs="mask",
            use_look_ahead_mask=use_look_ahead_mask),
        # Optionally scale embeddings by sqrt(num_heads * dim_head).
        (Scale(
            inputs="inputEmbeddingsDropout",
            outputs="inputEnc",
            multiplier=(num_heads * dim_head)**0.5) if scale else Select(
                inputs="inputEmbeddingsDropout", outputs="inputEnc")),
        (PositionalEncoding(
            inputs="inputEnc",
            outputs="inputEnc",
            trainable=trainable_positional_encoding,
        ) if use_positional_encoding else []),
        # Stack of encoder blocks: self-attention [+ norm] [+ feed-forward]
        # [+ norm].
        [
            Scope(
                Sequential(
                    SelfMultiheadAttention(
                        inputs=("inputEnc", "mask"),
                        outputs="inputEnc",
                        dim_head=dim_head,
                        num_heads=num_heads,
                        residual_connection=residual_connection,
                    ),
                    (Scope(
                        Normalization(inputs="inputEnc", outputs="inputEnc"),
                        "attention_norm") if use_layer_normalization and
                     not (not use_feedforward and
                          block_id == encoding_blocks - 1) else []),
                    (FeedForward(
                        inputs="inputEnc",
                        outputs="inputEnc",
                        units_inner=(num_heads * dim_head),
                        units_readout=(num_heads * dim_head),
                        dim=dim,
                        dropout_rate=ff_dropout_rate,
                    ) if use_feedforward else []),
                    (Scope(
                        Normalization(inputs="inputEnc", outputs="inputEnc"),
                        "ff_norm") if use_feedforward and ff_normalization and
                     block_id != encoding_blocks - 1 else []),
                ),
                f"block_{block_id}",
            ) for block_id in range(encoding_blocks)
        ],
        # Read out the encoding at the last non-padded position.
        SliceLastPadded(inputs=("inputEnc", "inputMask"), outputs=outputs),
    )
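# Usage sketch: a minimal, hedged example of assembling the Transformer defined
# above with illustrative parameter values. How the returned `base.Layer` is
# applied to the named tensors ("inputEmbeddings", "inputMask") depends on the
# surrounding layer library and is not shown here; this helper is illustration
# only, not part of the library's API.
def _example_transformer() -> base.Layer:
    """Builds a small two-block Transformer purely for illustration."""
    return Transformer(
        dim=64,
        num_heads=4,
        encoding_blocks=2,
        dim_head=32,
        use_feedforward=True,
        use_positional_encoding=True,
        use_look_ahead_mask=True,
        inputs=("inputEmbeddings", "inputMask"),
        outputs="userEmbeddings",
    )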