def __init__(self, name: str, d_model: int, num_heads: int,
             transition_type='dot',
             residual_dropout: float = 0, attention_dropout: float = 0,
             activation: Optional[Union[str, Callable]] = 'gelu',
             compression_window_size: int = None,
             size_multiplier: int = 4,
             use_masking: bool = True,
             local_masking: int = None,
             vanilla_wiring=False):
    self.size_multiplier = size_multiplier
    self.name = name
    self.activation = activation
    self.attention_layer = MultiHeadSelfAttention(
        d_model, num_heads, use_masking=use_masking,
        dropout=attention_dropout,
        compression_window_size=compression_window_size,
        local_masking=local_masking,
        name=f'{name}_self_attention')
    self.norm1_layer = LayerNormalization(name=f'{name}_normalization1')
    # Residual dropout is optional: fall back to an identity function
    # when it is disabled.
    self.dropout_layer = (
        Dropout(residual_dropout, name=f'{name}_dropout')
        if residual_dropout > 0
        else lambda x: x)
    self.norm2_layer = LayerNormalization(name=f'{name}_normalization2')
    if transition_type == 'dot':
        # Position-wise feed-forward transition.
        self.transition_type = 'dot'
        self.transition_layer = TransformerTransition(
            name=f'{name}_transition', activation=activation,
            size_multiplier=size_multiplier)
    elif transition_type == 'cnn':
        # Convolutional transition: no layer is created here (left as None).
        self.transition_type = 'cnn'
        self.transition_layer = None
    else:
        raise NotImplementedError(
            'Transformer transition {} is not implemented.'.format(transition_type))
    self.addition_layer = Add(name=f'{name}_add')
    self.vanilla_wiring = vanilla_wiring
def __init__(self, name: str, num_heads: int,
             residual_dropout: float = 0, attention_dropout: float = 0,
             activation: Optional[Union[str, Callable]] = 'gelu',
             compression_window_size: int = None,
             use_masking: bool = True,
             vanilla_wiring=False,
             agglomerative_attention: bool = False,
             dropout_cls: Type[Layer] = Dropout):
    if agglomerative_attention:
        # Agglomerative attention does not support memory compression.
        assert compression_window_size is None, \
            'compression not supported for agglomerative attention'
        self.attention_layer = MultiHeadAgglomerativeSelfAttention(
            num_heads, use_masking=use_masking, dropout=attention_dropout,
            name=f'{name}_self_attention')
    else:
        self.attention_layer = MultiHeadSelfAttention(
            num_heads, use_masking=use_masking, dropout=attention_dropout,
            compression_window_size=compression_window_size,
            name=f'{name}_self_attention')
    self.norm1_layer = LayerNormalization(name=f'{name}_normalization1')
    # Residual dropout is optional: fall back to an identity function
    # when it is disabled.
    self.dropout_layer = (
        dropout_cls(residual_dropout, name=f'{name}_dropout')
        if residual_dropout > 0
        else lambda x: x)
    self.norm2_layer = LayerNormalization(name=f'{name}_normalization2')
    self.transition_layer = TransformerTransition(
        name=f'{name}_transition', activation=activation)
    self.addition_layer = Add(name=f'{name}_add')
    self.vanilla_wiring = vanilla_wiring
def __init__(self, num_heads: int,
             residual_dropout: float = 0, attention_dropout: float = 0,
             activation: Optional[Union[str, Callable]] = 'gelu',
             compression_window_size: int = None,
             use_masking: bool = True,
             vanilla_wiring=False,
             name='TransformerBlock'):
    # Register the block under the given name with the Layer base class.
    super().__init__(name=name)
    self.attention_layer = MultiHeadSelfAttention(
        num_heads, use_masking=use_masking, dropout=attention_dropout,
        compression_window_size=compression_window_size,
        name='self_attention')
    self.norm1_layer = LayerNormalization(name='normalization1')

    # Use this instead of a lambda to avoid autograph issues
    def identity(x):
        return x

    self.dropout_layer = (
        Dropout(residual_dropout, name='dropout')
        if residual_dropout > 0
        else identity)
    self.norm2_layer = LayerNormalization(name='normalization2')
    self.transition_layer = TransformerTransition(
        name='transition', activation=activation)
    self.addition_layer = Add(name='add')
    self.vanilla_wiring = vanilla_wiring
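# --- Usage sketch (illustrative, not part of the original module) ------------
# A minimal sketch of exercising the subclassed block above in eager mode.
# Assumptions: the full TransformerBlock class (build/call included) and its
# sub-layers are available in the current scope; the shapes and names below
# are invented for the example.
import numpy as np
import tensorflow as tf

demo_block = TransformerBlock(
    num_heads=8,
    residual_dropout=0.1,
    attention_dropout=0.1,
    use_masking=True,
    name='demo_block')

# A dummy batch: 2 sequences of 16 tokens with 64-dimensional embeddings.
# The model width (64) must be divisible by num_heads so attention can split
# it evenly across heads.
demo_input = tf.constant(np.random.rand(2, 16, 64), dtype=tf.float32)
demo_output = demo_block(demo_input)   # sub-layers are built lazily here
assert demo_output.shape == (2, 16, 64)  # the block preserves the input shape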
class TransformerBlock(Layer):
    """
    A pseudo-layer combining together all the nuts and bolts needed to
    assemble a complete section of both the Transformer and the Universal
    Transformer models, following the description in the
    "Universal Transformers" paper.

    Each such block is, essentially:

    - Multi-head self-attention (masked or unmasked, with attention dropout,
      but without input dropout)
    - Residual connection
    - Dropout
    - Layer normalization
    - Transition function
    - Residual connection
    - Dropout
    - Layer normalization

    Also check the TransformerACT class if you need support for ACT
    (Adaptive Computation Time).

    IMPORTANT: The older Transformer 2017 model ("Attention is all you need")
    uses a slightly different order of operations. A quote from the paper:

        "We apply dropout [33] to the output of each sub-layer, before it is
        added to the sub-layer input and normalized"

    while the Universal Transformer paper puts dropout one step *after* the
    sub-layer's output is added to its input (Figure 4 in the paper).

    In this code the order from the Universal Transformer is used, as arguably
    the more reasonable one. You can use the classical Transformer's (2017)
    way of connecting the pieces by passing vanilla_wiring=True to the
    constructor.
    """
    def __init__(self, name: str, num_heads: int,
                 residual_dropout: float = 0, attention_dropout: float = 0,
                 activation: Optional[Union[str, Callable]] = 'gelu',
                 compression_window_size: int = None,
                 use_masking: bool = True,
                 vanilla_wiring=False,
                 **kwargs):
        self.attention_layer = MultiHeadSelfAttention(
            num_heads, use_masking=use_masking, dropout=attention_dropout,
            compression_window_size=compression_window_size,
            name=f'{name}_self_attention')
        self.norm1_layer = LayerNormalization(name=f'{name}_normalization1')
        # Residual dropout is optional: fall back to an identity function
        # when it is disabled.
        self.dropout_layer = (
            Dropout(residual_dropout, name=f'{name}_dropout')
            if residual_dropout > 0
            else lambda x: x)
        self.norm2_layer = LayerNormalization(name=f'{name}_normalization2')
        self.transition_layer = TransformerTransition(
            name=f'{name}_transition', activation=activation)
        self.addition_layer = Add(name=f'{name}_add')
        self.vanilla_wiring = vanilla_wiring
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.attention_layer.build(input_shape)
        self.norm1_layer.build(input_shape)
        self.norm2_layer.build(input_shape)
        self.transition_layer.build(input_shape)
        # Mark the layer itself as built.
        super().build(input_shape)

    def call(self, _input):
        output = self.attention_layer(_input)
        # Vanilla (2017) wiring: dropout -> residual addition -> normalization.
        # Universal Transformer wiring: residual addition -> dropout -> normalization.
        post_residual1 = (
            self.addition_layer([_input, self.dropout_layer(output)])
            if self.vanilla_wiring
            else self.dropout_layer(self.addition_layer([_input, output])))
        norm1_output = self.norm1_layer(post_residual1)
        output = self.transition_layer(norm1_output)
        post_residual2 = (
            self.addition_layer([norm1_output, self.dropout_layer(output)])
            if self.vanilla_wiring
            else self.dropout_layer(
                self.addition_layer([norm1_output, output])))
        output = self.norm2_layer(post_residual2)
        return output
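# --- Functional-API usage sketch (illustrative, not from the original code) --
# A minimal sketch of how the block above might be stacked on top of an
# embedding layer, assuming a tf.keras-based setup with TransformerBlock,
# MultiHeadSelfAttention and TransformerTransition in scope. The vocabulary
# size, model width, and layer names are invented for the example.
from tensorflow.keras.layers import Input, Embedding
from tensorflow.keras.models import Model

max_len, vocab_size, d_model = 128, 10000, 64

word_ids = Input(shape=(max_len,), dtype='int32', name='word_ids')
x = Embedding(vocab_size, d_model, name='word_embeddings')(word_ids)

# Chain two blocks; each keeps the (batch, max_len, d_model) shape, so the
# output can feed any task-specific head (e.g. a softmax over the vocabulary).
for i in range(2):
    x = TransformerBlock(
        name=f'transformer{i}', num_heads=8,
        residual_dropout=0.1, attention_dropout=0.1,
        use_masking=True)(x)

model = Model(inputs=word_ids, outputs=x)
model.summary()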