Example No. 1
 def __init__(self,
              name: str,
              num_heads: int,
              residual_dropout: float = 0,
              attention_dropout: float = 0,
              activation: Optional[Union[str, Callable]] = 'gelu',
              compression_window_size: Optional[int] = None,
              use_masking: bool = True,
              vanilla_wiring: bool = False,
              **kwargs):
     self.attention_layer = MultiHeadSelfAttention(
         num_heads,
         use_masking=use_masking,
         dropout=attention_dropout,
         compression_window_size=compression_window_size,
         name=f'{name}_self_attention')
     self.norm1_layer = LayerNormalization(name=f'{name}_normalization1')
     self.dropout_layer = (Dropout(residual_dropout, name=f'{name}_dropout')
                           if residual_dropout > 0 else lambda x: x)
     self.norm2_layer = LayerNormalization(name=f'{name}_normalization2')
     self.transition_layer = TransformerTransition(
         name=f'{name}_transition', activation=activation)
     self.addition_layer = Add(name=f'{name}_add')
     self.vanilla_wiring = vanilla_wiring
     super().__init__(**kwargs)
Example No. 2
 def __init__(self, name: str, d_model: int, num_heads: int,
              transition_type: str = 'dot',
              residual_dropout: float = 0, attention_dropout: float = 0,
              activation: Optional[Union[str, Callable]] = 'gelu',
              compression_window_size: Optional[int] = None, size_multiplier: int = 4,
              use_masking: bool = True, local_masking: Optional[int] = None,
              vanilla_wiring: bool = False):
     self.size_multiplier = size_multiplier
     self.name = name
     self.activation = activation
     self.attention_layer = MultiHeadSelfAttention(
         d_model, num_heads, use_masking=use_masking, dropout=attention_dropout,
         compression_window_size=compression_window_size, local_masking=local_masking,
         name=f'{name}_self_attention')
     self.norm1_layer = LayerNormalization(name=f'{name}_normalization1')
     self.dropout_layer = (
         Dropout(residual_dropout, name=f'{name}_dropout')
         if residual_dropout > 0
         else lambda x: x)
     self.norm2_layer = LayerNormalization(name=f'{name}_normalization2')
     if transition_type == 'dot':
         self.transition_type = 'dot'
         self.transition_layer = TransformerTransition(
             name=f'{name}_transition', activation=activation, size_multiplier=size_multiplier)
     elif transition_type == 'cnn':
         self.transition_type = 'cnn'
         self.transition_layer = None
     else:
         raise NotImplementedError(f"Transformer transition {transition_type} is not implemented.")
     self.addition_layer = Add(name=f'{name}_add')
     self.vanilla_wiring = vanilla_wiring
Example No. 3
 def __init__(self,
              name: str,
              num_heads: int,
              residual_dropout: float = 0,
              attention_dropout: float = 0,
              activation: Optional[Union[str, Callable]] = 'gelu',
              compression_window_size: Optional[int] = None,
              use_masking: bool = True,
              vanilla_wiring: bool = False,
              agglomerative_attention: bool = False,
              dropout_cls: Type[Layer] = Dropout):
     if agglomerative_attention:
         assert compression_window_size is None, 'compression not supported for agglomerative attention'
         self.attention_layer = MultiHeadAgglomerativeSelfAttention(
             num_heads,
             use_masking=use_masking,
             dropout=attention_dropout,
             name=f'{name}_self_attention')
     else:
         self.attention_layer = MultiHeadSelfAttention(
             num_heads,
             use_masking=use_masking,
             dropout=attention_dropout,
             compression_window_size=compression_window_size,
             name=f'{name}_self_attention')
     self.norm1_layer = LayerNormalization(name=f'{name}_normalization1')
     self.dropout_layer = (dropout_cls(residual_dropout,
                                       name=f'{name}_dropout')
                           if residual_dropout > 0 else lambda x: x)
     self.norm2_layer = LayerNormalization(name=f'{name}_normalization2')
     self.transition_layer = TransformerTransition(
         name=f'{name}_transition', activation=activation)
     self.addition_layer = Add(name=f'{name}_add')
     self.vanilla_wiring = vanilla_wiring
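
Example No. 3 adds two extra constructor options: agglomerative_attention swaps the attention layer for MultiHeadAgglomerativeSelfAttention (and is not compatible with a compression window), while dropout_cls lets the caller substitute another dropout-style layer for the residual dropout. The sketch below shows how such a block might be constructed; since the snippet above is only the __init__ method, it assumes the rest of the class looks like Example No. 5, and SpatialDropout1D appears purely as an illustration of a dropout_cls replacement that accepts the same (rate, name=...) arguments as Dropout.

# Sketch only: assumes the TransformerBlock variant from Example No. 3 is in scope.
from keras.layers import SpatialDropout1D

block = TransformerBlock(
    name='encoder_block_0',
    num_heads=8,
    residual_dropout=0.1,
    attention_dropout=0.1,
    use_masking=True,
    agglomerative_attention=True,  # selects MultiHeadAgglomerativeSelfAttention
    dropout_cls=SpatialDropout1D)  # any Layer subclass accepting (rate, name=...)
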
Example No. 4
 def __init__(self, num_heads: int,
              residual_dropout: float = 0, attention_dropout: float = 0,
              activation: Optional[Union[str, Callable]] = 'gelu',
              compression_window_size: Optional[int] = None,
              use_masking: bool = True,
              vanilla_wiring: bool = False, name: str = 'TransformerBlock'):
     super().__init__()
     self.attention_layer = MultiHeadSelfAttention(
         num_heads, use_masking=use_masking, dropout=attention_dropout,
         compression_window_size=compression_window_size,
         name='self_attention')
     self.norm1_layer = LayerNormalization(name='normalization1')
     # Use this instead of lambda to avoid autograph issues
     def identity(x):
         return x
     self.dropout_layer = (
         Dropout(residual_dropout, name='dropout')
         if residual_dropout > 0
         else identity)
     self.norm2_layer = LayerNormalization(name='normalization2')
     self.transition_layer = TransformerTransition(
         name='transition', activation=activation)
     self.addition_layer = Add(name='add')
     self.vanilla_wiring = vanilla_wiring
Example No. 5
class TransformerBlock(Layer):
    """
    A pseudo-layer that combines all the nuts and bolts needed to assemble
    a complete section of both the Transformer and the Universal Transformer
    models, following the description in the "Universal Transformers" paper.
    Each such block is, essentially:

    - Multi-head self-attention (masked or unmasked, with attention dropout,
      but without input dropout)
    - Residual connection
    - Dropout
    - Layer normalization
    - Transition function
    - Residual connection
    - Dropout
    - Layer normalization

    Also check the TransformerACT class if you need support for ACT (Adaptive
    Computation Time).

    IMPORTANT: The older Transformer 2017 model ("Attention is all you need")
    uses a slightly different order of operations. A quote from the paper:

        "We apply dropout [33] to the output of each sub-layer,
         before it is added to the sub-layer input and normalized"

    while the Universal Transformer paper places dropout one step *after*
    the sub-layer's output has been added to its input (Figure 4 in the paper).

    This code uses the Universal Transformer order, as it is arguably the more
    reasonable one. You can switch to the classical Transformer's (2017) way of
    connecting the pieces by passing vanilla_wiring=True to the constructor.
    """
    def __init__(self,
                 name: str,
                 num_heads: int,
                 residual_dropout: float = 0,
                 attention_dropout: float = 0,
                 activation: Optional[Union[str, Callable]] = 'gelu',
                 compression_window_size: Optional[int] = None,
                 use_masking: bool = True,
                 vanilla_wiring: bool = False,
                 **kwargs):
        self.attention_layer = MultiHeadSelfAttention(
            num_heads,
            use_masking=use_masking,
            dropout=attention_dropout,
            compression_window_size=compression_window_size,
            name=f'{name}_self_attention')
        self.norm1_layer = LayerNormalization(name=f'{name}_normalization1')
        self.dropout_layer = (Dropout(residual_dropout, name=f'{name}_dropout')
                              if residual_dropout > 0 else lambda x: x)
        self.norm2_layer = LayerNormalization(name=f'{name}_normalization2')
        self.transition_layer = TransformerTransition(
            name=f'{name}_transition', activation=activation)
        self.addition_layer = Add(name=f'{name}_add')
        self.vanilla_wiring = vanilla_wiring
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.attention_layer.build(input_shape)
        self.norm1_layer.build(input_shape)
        self.norm2_layer.build(input_shape)
        self.transition_layer.build(input_shape)

    def call(self, _input):
        output = self.attention_layer(_input)
        # Vanilla (2017) wiring applies dropout to the sub-layer output before
        # the residual addition; Universal Transformer wiring adds the residual
        # first and applies dropout to the sum.
        post_residual1 = (
            self.addition_layer([_input, self.dropout_layer(output)])
            if self.vanilla_wiring
            else self.dropout_layer(self.addition_layer([_input, output])))
        norm1_output = self.norm1_layer(post_residual1)
        output = self.transition_layer(norm1_output)
        post_residual2 = (
            self.addition_layer([norm1_output, self.dropout_layer(output)])
            if self.vanilla_wiring
            else self.dropout_layer(self.addition_layer([norm1_output, output])))
        output = self.norm2_layer(post_residual2)
        return output
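
Since most of the snippets above show only the constructor, a minimal end-to-end usage sketch may help. It is only a sketch: the import path is hypothetical, the shapes are arbitrary, and it assumes the complete class from Example No. 5 together with its helper layers (MultiHeadSelfAttention, LayerNormalization, TransformerTransition) is importable from a local module.

# Usage sketch under the assumptions stated above.
from keras.layers import Input
from keras.models import Model
from transformer import TransformerBlock  # hypothetical module path

seq_len, d_model = 128, 64
inputs = Input(shape=(seq_len, d_model))
block = TransformerBlock(
    name='block0',
    num_heads=8,
    residual_dropout=0.1,
    attention_dropout=0.1,
    use_masking=True,        # causal masking, as in a decoder
    vanilla_wiring=False)    # Universal Transformer wiring: add, dropout, normalize
outputs = block(inputs)      # shape stays (batch, seq_len, d_model)
model = Model(inputs, outputs)

Stacking several such blocks, or applying one block repeatedly as the Universal Transformer does, produces the full encoder or decoder stack.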