Example #1
 def __init__(self, embedding_dim, head_count, model_dim, drop_prob, dropout):
     super(Encoder, self).__init__()
     self.embedding_dim = embedding_dim  # was missing, but is referenced below
     self.head_count = head_count
     self.model_dim = model_dim
     self.drop_prob = drop_prob
     self.dropout = dropout
     # Bidirectional LSTM encoders over the title and the abstract
     self.title_lstm = RNN.LSTM(
         self.embedding_dim, self.model_dim, True, self.drop_prob)
     self.abstract_lstm = RNN.LSTM(
         self.embedding_dim, self.model_dim, True, self.drop_prob)
     self.title_linear = nn.Dense(self.model_dim, flatten=False, in_units=2 * self.model_dim)
     self.abstract_linear = nn.Dense(self.model_dim, flatten=False, in_units=2 * self.model_dim)
     self.final_linear = nn.Dense(2 * self.model_dim, flatten=False, in_units=self.model_dim)
     # NOTE: base_cell is assumed to be defined elsewhere in the original module
     # (e.g. a DotProductAttentionCell); it is not constructed in this snippet.
     # Mutual attention between title and abstract, plus self-attention.
     self.ta_mutal = MultiHeadAttentionCell(
         base_cell=base_cell, query_units=2 * self.model_dim, key_units=2 * self.model_dim,
         value_units=2 * self.model_dim, num_heads=self.head_count,
         use_bias=True, weight_initializer='Xavier')
     self.at_mutal = MultiHeadAttentionCell(
         base_cell=base_cell, query_units=2 * self.model_dim, key_units=2 * self.model_dim,
         value_units=2 * self.model_dim, num_heads=self.head_count,
         use_bias=True, weight_initializer='Xavier')
     self.self_attn = MultiHeadAttentionCell(
         base_cell=base_cell, query_units=2 * self.model_dim, key_units=2 * self.model_dim,
         value_units=2 * self.model_dim, num_heads=self.head_count,
         use_bias=True, weight_initializer='Xavier')
     self.ffn1 = Resblock(2 * self.model_dim, self.dropout)
     self.ffn2 = Resblock(2 * self.model_dim, self.dropout)
     self.W_G = nn.Dense(1, flatten=False, in_units=4 * self.model_dim)
     self.ffn3 = Resblock(2 * self.model_dim)
Example #2
def _get_attention_cell(attention_cell, units=None):
    """

    Parameters
    ----------
    attention_cell : AttentionCell or str
    units : int or None

    Returns
    -------
    attention_cell : AttentionCell
    """
    if isinstance(attention_cell, str):
        if attention_cell == 'scaled_luong':
            return DotProductAttentionCell(units=units,
                                           scaled=True,
                                           normalized=False,
                                           luong_style=True)
        elif attention_cell == 'scaled_dot':
            return DotProductAttentionCell(units=None,
                                           scaled=True,
                                           normalized=False,
                                           luong_style=False)
        elif attention_cell == 'dot':
            return DotProductAttentionCell(units=None,
                                           scaled=False,
                                           normalized=False,
                                           luong_style=False)
        elif attention_cell == 'cosine':
            return DotProductAttentionCell(units=units,
                                           scaled=False,
                                           normalized=True)
        elif attention_cell == 'mlp':
            return MLPAttentionCell(units=units, normalized=False)
        elif attention_cell == 'normed_mlp':
            return MLPAttentionCell(units=units, normalized=True)
        elif attention_cell == 'MultiHeadAttentionCell':
            # Wrap an MLP attention cell with 4 heads
            base_cell = MLPAttentionCell(units=units, normalized=False)
            return MultiHeadAttentionCell(base_cell=base_cell,
                                          query_units=units,
                                          key_units=units,
                                          value_units=units,
                                          num_heads=4)
        else:
            raise NotImplementedError
    else:
        assert isinstance(attention_cell, AttentionCell),\
            'attention_cell must be either string or AttentionCell. Received attention_cell={}'\
                .format(attention_cell)
        return attention_cell
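For context, a small usage sketch (not part of the source) showing how the factory above is typically called; the alias strings come from the branches it handles and the unit count is illustrative:

# Illustrative only: resolve alias strings into cells, or pass a cell straight through.
cell = _get_attention_cell('scaled_dot')            # DotProductAttentionCell without projection
mlp_cell = _get_attention_cell('mlp', units=256)    # MLPAttentionCell projecting to 256 units
same_cell = _get_attention_cell(mlp_cell)           # existing cells are returned unchanged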
Example #3
 def __init__(self, num_heads, **kwargs):
     super(SelfAttention, self).__init__(**kwargs)
     with self.name_scope():
         self.attention = MultiHeadAttentionCell(
             num_heads=num_heads,
             base_cell=DotProductAttentionCell(scaled=True,
                                               dropout=opt.layers_dropout,
                                               use_bias=False),
             query_units=opt.emb_encoder_conv_channels,
             key_units=opt.emb_encoder_conv_channels,
             value_units=opt.emb_encoder_conv_channels,
             use_bias=False,
             weight_initializer=Xavier())
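For reference, a minimal self-contained sketch (not from the source, assuming GluonNLP 0.x where these cells live in gluonnlp.model.attention_cell) of calling a multi-head cell like the one built above; the channel size, batch size and sequence length are illustrative:

import mxnet as mx
from mxnet.initializer import Xavier
from gluonnlp.model.attention_cell import MultiHeadAttentionCell, DotProductAttentionCell

# Illustrative shapes; GluonNLP attention cells expect (batch, length, channels).
cell = MultiHeadAttentionCell(
    num_heads=4,
    base_cell=DotProductAttentionCell(scaled=True, use_bias=False),
    query_units=128, key_units=128, value_units=128,
    use_bias=False, weight_initializer=Xavier())
cell.initialize()
x = mx.nd.random.normal(shape=(2, 10, 128))
# Self-attention: the same tensor serves as query, key and value.
context, att_weights = cell(x, x, x)
print(context.shape)       # (2, 10, 128)
print(att_weights.shape)   # (2, 4, 10, 10)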
Example #4
 def __init__(self, num_heads, **kwargs):
     super(SelfAttention, self).__init__(**kwargs)
     with self.name_scope():
         self.attention = MultiHeadAttentionCell(
             num_heads=num_heads,
             base_cell=DotProductAttentionCell(
                 scaled=True,
                 dropout=0.1,
                 use_bias=False
             ),
             query_units=EMB_ENCODER_CONV_CHANNELS,
             key_units=EMB_ENCODER_CONV_CHANNELS,
             value_units=EMB_ENCODER_CONV_CHANNELS,
             use_bias=False,
             weight_initializer=Xavier()
         )
Example #5
def _get_attention_cell(attention_cell, units=None,
                        scaled=True, num_heads=None,
                        use_bias=False, dropout=0.0):
    """

    Parameters
    ----------
    attention_cell : AttentionCell or str
    units : int or None

    Returns
    -------
    attention_cell : AttentionCell
    """
    if isinstance(attention_cell, str):
        if attention_cell == 'scaled_luong':
            return DotProductAttentionCell(units=units, scaled=True, normalized=False,
                                           use_bias=use_bias, dropout=dropout, luong_style=True)
        elif attention_cell == 'scaled_dot':
            return DotProductAttentionCell(units=units, scaled=True, normalized=False,
                                           use_bias=use_bias, dropout=dropout, luong_style=False)
        elif attention_cell == 'dot':
            return DotProductAttentionCell(units=units, scaled=False, normalized=False,
                                           use_bias=use_bias, dropout=dropout, luong_style=False)
        elif attention_cell == 'cosine':
            return DotProductAttentionCell(units=units, scaled=False, use_bias=use_bias,
                                           dropout=dropout, normalized=True)
        elif attention_cell == 'mlp':
            return MLPAttentionCell(units=units, normalized=False)
        elif attention_cell == 'normed_mlp':
            return MLPAttentionCell(units=units, normalized=True)
        elif attention_cell == 'multi_head':
            base_cell = DotProductAttentionCell(scaled=scaled, dropout=dropout)
            return MultiHeadAttentionCell(base_cell=base_cell, query_units=units, use_bias=use_bias,
                                          key_units=units, value_units=units, num_heads=num_heads)
        else:
            raise NotImplementedError
    else:
        assert isinstance(attention_cell, AttentionCell),\
            'attention_cell must be either string or AttentionCell. Received attention_cell={}'\
                .format(attention_cell)
        return attention_cell
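A minimal usage sketch (not part of the source) exercising the 'multi_head' branch above; units, num_heads and tensor shapes are illustrative, and units must be divisible by num_heads:

import mxnet as mx

# Illustrative values; query/key follow the (batch, length, channels) layout used by GluonNLP.
cell = _get_attention_cell('multi_head', units=512, num_heads=8, dropout=0.1)
cell.initialize()
query = mx.nd.random.normal(shape=(2, 5, 512))
key = mx.nd.random.normal(shape=(2, 7, 512))
context, att_weights = cell(query, key)    # context: (2, 5, 512), att_weights: (2, 8, 5, 7)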
Example #6
 def __init__(self, embedding_dim, model_dim, dropout, head_count, vocab_size, extended_size, gpu):
     super(Decoder, self).__init__()
     self.ctx = gpu
     self.embedding_dim = embedding_dim  # was missing, but is referenced below
     self.model_dim = model_dim
     self.dropout = dropout
     self.head_count = head_count
     self.vocab_size = vocab_size
     self.extended_size = extended_size  # fixed: "self.extended.size" would raise AttributeError
     self.decoder_ltsm = rnn.LSTM(
         2 * self.model_dim, layout='NTC',
         input_size=self.embedding_dim,
         i2h_weight_initializer='Orthogonal',
         h2h_weight_initializer='Orthogonal')
     # NOTE: base_cell is assumed to be defined elsewhere in the original module
     # (e.g. a DotProductAttentionCell); it is not constructed in this snippet.
     self.self_attn = MultiHeadAttentionCell(
         base_cell=base_cell, query_units=2 * self.model_dim, key_units=2 * self.model_dim,
         value_units=2 * self.model_dim, num_heads=self.head_count,
         use_bias=True, weight_initializer='Xavier')
     self.fnn = Resblock(2 * self.model_dim, self.dropout)
     self.V1 = nn.Dense(2 * self.model_dim, in_units=3 * self.model_dim)
     self.V2 = nn.Dense(self.vocab_size, in_units=2 * self.model_dim)
     self.W_c = nn.Dense(1)
     self.W_s = nn.Dense(1)
     self.W_x = nn.Dense(1)