    def __init__(self,
                 embedding_size,
                 hidden_size,
                 num_layers,
                 num_heads,
                 total_key_depth,
                 total_value_depth,
                 filter_size,
                 max_length=100,
                 input_dropout=0.0,
                 layer_dropout=0.0,
                 attention_dropout=0.0,
                 relu_dropout=0.0,
                 use_mask=False,
                 act=False):
        """
        Parameters:
            embedding_size: Size of embeddings
            hidden_size: Hidden size
            num_layers: Total layers in the Encoder
            num_heads: Number of attention heads
            total_key_depth: Size of last dimension of keys. Must be divisible by num_heads
            total_value_depth: Size of last dimension of values. Must be divisible by num_heads
            filter_size: Hidden size of the middle layer in FFN
            max_length: Max sequence length (required for timing signal)
            input_dropout: Dropout just after embedding
            layer_dropout: Dropout for each layer
            attention_dropout: Dropout probability after attention (Should be non-zero only during training)
            relu_dropout: Dropout probability after relu in FFN (Should be non-zero only during training)
            use_mask: Set to True to turn on future value masking
            act: Set to True to use Adaptive Computation Time (ACT) halting over layers
        """

        super(Encoder, self).__init__()

        self.timing_signal = _gen_timing_signal(max_length, hidden_size)
        # Position (depth) signal: one timing vector per layer index t,
        # added at each recurrent depth step.
        self.position_signal = _gen_timing_signal(num_layers, hidden_size)

        self.num_layers = num_layers
        self.act = act
        params = (hidden_size,
                  total_key_depth or hidden_size,
                  total_value_depth or hidden_size,
                  filter_size,
                  num_heads,
                  _gen_bias_mask(max_length) if use_mask else None,
                  layer_dropout,
                  attention_dropout,
                  relu_dropout)

        self.proj_flag = False
        if (embedding_size != hidden_size):
            # Project embeddings to hidden_size only when the two sizes differ;
            # otherwise the input can be fed to the layers as-is.
            self.embedding_proj = nn.Linear(embedding_size,
                                            hidden_size,
                                            bias=False)
            self.proj_flag = True

        self.enc = EncoderLayer(*params)

        self.layer_norm = LayerNorm(hidden_size)
        self.input_dropout = nn.Dropout(input_dropout)
        if (self.act):
            self.act_fn = ACT_basic(hidden_size)
Example 2
    def __init__(self,
                 embedding_size,
                 hidden_size,
                 num_layers,
                 num_heads,
                 total_key_depth,
                 total_value_depth,
                 filter_size,
                 max_length=1000,
                 input_dropout=0.0,
                 layer_dropout=0.0,
                 attention_dropout=0.0,
                 relu_dropout=0.0,
                 use_mask=False,
                 universal=False):
        super(Encoder, self).__init__()
        self.universal = universal
        self.num_layers = num_layers
        self.timing_signal = _gen_timing_signal(max_length, hidden_size)

        if self.universal:
            # Depth signal is only needed when one shared layer is re-applied
            # num_layers times (Universal Transformer style).
            self.position_signal = _gen_timing_signal(num_layers, hidden_size)

        params = (hidden_size,
                  total_key_depth or hidden_size,
                  total_value_depth or hidden_size,
                  filter_size,
                  num_heads,
                  _gen_bias_mask(max_length) if use_mask else None,
                  layer_dropout,
                  attention_dropout,
                  relu_dropout)

        self.embedding_proj = nn.Linear(embedding_size,
                                        hidden_size,
                                        bias=False)

        if self.universal:
            # A single EncoderLayer whose weights are shared across every
            # depth step.
            self.enc = EncoderLayer(*params)
        else:
            # Independent weights for each of the num_layers layers, as in the
            # standard Transformer encoder.
            self.enc = nn.ModuleList(
                [EncoderLayer(*params) for _ in range(num_layers)])

        self.layer_norm = LayerNorm(hidden_size)
        self.input_dropout = nn.Dropout(input_dropout)

        # ACT (Adaptive Computation Time) halting module; remainders and
        # n_updates hold per-token ponder statistics and are filled in when
        # ACT runs.
        self.act_fn = ACT_basic(hidden_size)
        self.remainders = None
        self.n_updates = None