Example #1
class VariableNormTransformerLayer(nn.Module):
    def __init__(self, config):
        super(VariableNormTransformerLayer, self).__init__()
        self.config = config

        if self.config.norm_type == 'layer':
            self.attention_norm = nn.LayerNorm(config.hidden_size,
                                               eps=config.layer_norm_eps)
        elif self.config.norm_type == 'adanorm':
            self.attention_norm = AdaNorm(0.3, config.layer_norm_eps)
        elif self.config.norm_type == 'scalenorm':
            self.attention_norm = ScaleNorm(config.hidden_size**0.5)

        self.self_attention = BertSelfAttention(config)
        self.self_out = nn.Linear(config.hidden_size, config.hidden_size)
        self.self_dropout = nn.Dropout(config.hidden_dropout_prob)

        if self.config.norm_type == 'layer':
            self.ff_norm = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        elif self.config.norm_type == 'adanorm':
            self.ff_norm = AdaNorm(0.3, config.layer_norm_eps)
        elif self.config.norm_type == 'scalenorm':
            self.ff_norm = ScaleNorm(config.hidden_size**0.5)

        self.ff1 = BertIntermediate(config)
        self.ff2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.ff_dropout = nn.Dropout(config.hidden_dropout_prob)
Example #2
    def __init__(self, config):
        check_vit_in_transformers()
        super().__init__(config)
        # We need to support attention masks for vision-language input.
        # ViTAttention from transformers doesn't currently support attention
        # masks, so for those versions we use these clones of the ViT modules,
        # built on BertSelfAttention, to enable masking.
        self.attention = BertSelfAttention(config)
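Only the constructor is shown above; a sketch of how such a clone's forward pass might convert a binary padding mask into the additive form BertSelfAttention expects (the method below is an assumption, not the repository's code):

    def forward(self, hidden_states, attention_mask=None):
        if attention_mask is not None and attention_mask.ndim == 2:
            # BertSelfAttention adds the mask to the attention scores, so a binary
            # [batch, seq] padding mask becomes an additive mask broadcastable to
            # [batch, num_heads, seq, seq]: 0 to keep, a large negative value to mask.
            attention_mask = attention_mask[:, None, None, :].to(hidden_states.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min
        # Element 0 of the returned tuple is the attended hidden states.
        return self.attention(hidden_states, attention_mask=attention_mask)[0]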
Example #3
    def __init__(self, config):
        super(CBOW, self).__init__()

        self.embeddings = BertEmbeddings(config)
        self.attention = BertSelfAttention(config)
        self.act_fn = nn.ReLU()
        self.linear_1 = nn.Linear(config.hidden_size, config.hidden_size)
        self.linear_2 = nn.Linear(config.hidden_size, config.hidden_size)
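Only the constructor is shown; a plausible forward pass that wires these modules together (an assumption for illustration, not the repository's code) could look like:

    def forward(self, input_ids, extended_attention_mask=None):
        # Assumed wiring: token embeddings -> self-attention -> two-layer MLP.
        hidden_states = self.embeddings(input_ids=input_ids)
        # BertSelfAttention returns a tuple and expects any mask in additive
        # ("extended") form: 0 to keep a position, a large negative value to mask it.
        hidden_states = self.attention(hidden_states, attention_mask=extended_attention_mask)[0]
        hidden_states = self.act_fn(self.linear_1(hidden_states))
        return self.linear_2(hidden_states)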
Example #4
    def __init__(self, config):
        super(Transformer, self).__init__()
        self.embeddings = MyEmbeddings(config)

        intermediate_size = config.hidden_size * 4

        self.causal = False
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.attentions, self.feed_forwards = nn.ModuleList(), nn.ModuleList()
        self.layer_norms_1, self.layer_norms_2 = nn.ModuleList(
        ), nn.ModuleList()

        for _ in range(config.num_layers):
            self.attentions.append(BertSelfAttention(config))
            self.feed_forwards.append(
                nn.Sequential(nn.Linear(config.hidden_size, intermediate_size),
                              nn.ReLU(),
                              nn.Linear(intermediate_size,
                                        config.hidden_size)))
            self.layer_norms_1.append(
                nn.LayerNorm(config.hidden_size, eps=1e-12))
            self.layer_norms_2.append(
                nn.LayerNorm(config.hidden_size, eps=1e-12))
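The constructor above only builds the module lists; a sketch of the per-layer loop such a model typically runs (assumed, in a pre-norm arrangement; causal masking via self.causal is omitted):

    def forward(self, hidden_states, attention_mask=None):
        # attention_mask, if given, must already be an additive (extended) mask,
        # since BertSelfAttention adds it directly to the attention scores.
        for attention, feed_forward, norm_1, norm_2 in zip(
                self.attentions, self.feed_forwards,
                self.layer_norms_1, self.layer_norms_2):
            h = norm_1(hidden_states)
            h = attention(h, attention_mask=attention_mask)[0]
            hidden_states = hidden_states + self.dropout(h)
            h = norm_2(hidden_states)
            h = feed_forward(h)
            hidden_states = hidden_states + self.dropout(h)
        return hidden_states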
Example #5
    def __init__(self, dim):
        super(SelfAttention, self).__init__()
        cfg = BertConfig(hidden_size=dim, num_hidden_layers=1)
        self.atten = BertSelfAttention(cfg)
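Since this wrapper declares no forward method, the underlying BertSelfAttention can be called directly. A usage sketch, assuming the class above is in scope (shapes are illustrative; dim must be divisible by BertConfig's default num_attention_heads of 12 unless the head count is set explicitly):

    import torch

    layer = SelfAttention(dim=768)             # 768 is divisible by the default 12 heads
    hidden_states = torch.randn(2, 16, 768)    # [batch, seq_len, hidden_size]
    context = layer.atten(hidden_states)[0]    # BertSelfAttention returns a tuple
    print(context.shape)                       # torch.Size([2, 16, 768])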
Example #6
    def __init__(self, config):
        super().__init__()
        self.att = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
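This pairing mirrors how transformers' own BertAttention combines the two modules; the forward pass it would typically need (a sketch, not the repository's code) is:

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        self_outputs = self.att(hidden_states,
                                attention_mask=attention_mask,
                                output_attentions=output_attentions)
        # BertSelfOutput applies a dense projection, dropout, and residual + LayerNorm.
        attention_output = self.output(self_outputs[0], hidden_states)
        return (attention_output,) + self_outputs[1:]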
Example #7
    def __init__(self, config, opt):
        super(SelfAttention, self).__init__()
        self.opt = opt
        self.config = config
        self.SA = BertSelfAttention(config)
        self.tanh = torch.nn.Tanh()
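A plausible forward for this wrapper, passing an all-zero additive mask so that no positions are masked out (assumed from the declared modules, not the original code):

    def forward(self, inputs):
        # An all-zero additive mask leaves the attention scores unchanged.
        zero_mask = torch.zeros(inputs.size(0), 1, 1, inputs.size(1),
                                dtype=inputs.dtype, device=inputs.device)
        sa_out = self.SA(inputs, zero_mask)
        return self.tanh(sa_out[0])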
Example #8
class VariableNormTransformerLayer(nn.Module):
    def __init__(self, config):
        super(VariableNormTransformerLayer, self).__init__()
        self.config = config

        if self.config.norm_type == 'layer':
            self.attention_norm = nn.LayerNorm(config.hidden_size,
                                               eps=config.layer_norm_eps)
        elif self.config.norm_type == 'adanorm':
            self.attention_norm = AdaNorm(0.3, config.layer_norm_eps)
        elif self.config.norm_type == 'scalenorm':
            self.attention_norm = ScaleNorm(config.hidden_size**0.5)

        self.self_attention = BertSelfAttention(config)
        self.self_out = nn.Linear(config.hidden_size, config.hidden_size)
        self.self_dropout = nn.Dropout(config.hidden_dropout_prob)

        if self.config.norm_type == 'layer':
            self.ff_norm = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        elif self.config.norm_type == 'adanorm':
            self.ff_norm = AdaNorm(0.3, config.layer_norm_eps)
        elif self.config.norm_type == 'scalenorm':
            self.ff_norm = ScaleNorm(config.hidden_size**0.5)

        self.ff1 = BertIntermediate(config)
        self.ff2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.ff_dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, attention_mask=None, *args, **kwargs):
        residual = hidden_states
        if self.config.prenorm:
            hidden_states = self.attention_norm(hidden_states)
        # Self-attention sublayer
        if attention_mask is not None:
            if attention_mask.ndim == 2:
                # Convert a binary [batch, seq] padding mask into the additive form
                # BertSelfAttention expects (0 = keep, large negative = masked).
                attention_mask = attention_mask[:, None, None, :].to(hidden_states.dtype)
                attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min
        attention_outputs = self.self_attention(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=self.config.output_attentions)
        # BertSelfAttention returns a tuple; attention probabilities are only
        # included when output_attentions is True.
        hidden_states = attention_outputs[0]
        attentions = attention_outputs[1] if len(attention_outputs) > 1 else None
        hidden_states = self.self_out(hidden_states)
        hidden_states = self.self_dropout(hidden_states) + residual
        if not self.config.prenorm:
            hidden_states = self.attention_norm(hidden_states)

        residual = hidden_states
        if self.config.prenorm:
            hidden_states = self.ff_norm(hidden_states)
        # FF sublayer (BertIntermediate already applies the configured hidden
        # activation, e.g. GELU, after its linear projection, so no extra
        # activation is applied here).
        hidden_states = self.ff1(hidden_states)
        hidden_states = self.ff2(hidden_states)
        hidden_states = self.ff_dropout(hidden_states) + residual
        if not self.config.prenorm:
            hidden_states = self.ff_norm(hidden_states)

        return hidden_states, attentions

    def load_from_bert(self, bert_layer):
        self.self_attention.load_state_dict(
            bert_layer.attention.self.state_dict())
        self.self_out.load_state_dict(
            bert_layer.attention.output.dense.state_dict())
        self.ff1.load_state_dict(bert_layer.intermediate.state_dict())
        self.ff2.load_state_dict(bert_layer.output.dense.state_dict())
        if self.config.norm_type == "layer":
            self.attention_norm.load_state_dict(
                bert_layer.attention.output.LayerNorm.state_dict())
            self.ff_norm.load_state_dict(
                bert_layer.output.LayerNorm.state_dict())
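
A usage sketch for this layer (the model name and config values are illustrative; the config only needs the extra norm_type and prenorm attributes the layer reads, on top of a standard BertConfig):

    import torch
    from transformers import BertConfig, BertModel

    config = BertConfig(output_attentions=True)
    config.norm_type = "layer"      # 'layer', 'adanorm' or 'scalenorm'
    config.prenorm = False          # post-norm, as in the original BERT layer

    layer = VariableNormTransformerLayer(config)

    # Optionally initialise from a pretrained encoder layer (downloads weights).
    bert = BertModel.from_pretrained("bert-base-uncased")
    layer.load_from_bert(bert.encoder.layer[0])

    hidden_states = torch.randn(2, 16, config.hidden_size)
    padding_mask = torch.ones(2, 16)    # 1 = real token, 0 = padding
    out, attn = layer(hidden_states, attention_mask=padding_mask)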