Code example #1 (score: 0)
 def __init__(self, config):
     """Build a dual-branch (DE) transformer layer.

     One shared attention block feeds two parallel intermediate/output
     branches, each sized at half the configured hidden/intermediate width.

     Args:
         config: BERT-style configuration providing ``hidden_size`` and
             ``intermediate_size``.
     """
     # Local import: this chunk shows no top-level import block to extend.
     import copy

     super(BertDELayer, self).__init__()
     self.attention = BertDEAttention(config)

     # BUG FIX: the original did ``half_config = config``, which only
     # aliases the object — halving hidden_size/intermediate_size then
     # mutated the shared config, corrupting it for every other module
     # built from it (and halving again for each stacked layer).
     # A deep copy keeps the caller's config untouched.
     half_config = copy.deepcopy(config)
     half_config.hidden_size = int(config.hidden_size / 2)
     half_config.intermediate_size = int(config.intermediate_size / 2)

     self.intermediate1 = BertIntermediate(half_config)
     self.intermediate2 = BertIntermediate(half_config)

     self.output1 = BertOutput(half_config)
     self.output2 = BertOutput(half_config)
Code example #2 (score: 0)
    def __init__(self, config):
        """Build the dual-stream (language + vision) transformer layer.

        Each stream gets its own self-attention and feed-forward stack;
        a cross-attention module connects the two streams.

        Note: sub-modules are constructed in the same order as before,
        since construction order can affect RNG-based weight init.
        """
        super().__init__()

        # Language stream: self-attention followed by its FFN.
        self.lang_self_att = BertAttention(config)
        self.lang_inter = BertIntermediate(config)
        self.lang_output = BertOutput(config)

        # Vision stream: self-attention followed by its FFN.
        self.visn_self_att = BertAttention(config)
        self.visn_inter = BertIntermediate(config)
        self.visn_output = BertOutput(config)

        # Cross-modality attention bridging the two streams.
        self.visual_attention = BertXAttention(config)
Code example #3 (score: 0) — File: span_attention_layer.py, Project: zxlzr/kb
 def __init__(self, config):
     """Span-attention transformer layer: attention -> intermediate -> output.

     The feed-forward sub-modules are re-initialized with BERT-style
     weight init; the attention module is left with its own init.
     """
     super(SpanAttentionLayer, self).__init__()
     self.attention = SpanAttention(config)
     self.intermediate = BertIntermediate(config)
     self.output = BertOutput(config)
     # Same two calls as before, in the same order (intermediate, output).
     for ffn_part in (self.intermediate, self.output):
         init_bert_weights(ffn_part, config.initializer_range)
Code example #4 (score: 0)
    def __init__(self, config):
        """Standard transformer encoder layer.

        Composition: self-attention, then the position-wise feed-forward
        block split across BertIntermediate (expand) and BertOutput
        (project back + residual handling, per the usual BERT layout).
        """
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
Code example #5 (score: 0)
def test_BertIntermediate():
    """Smoke-test BertIntermediate on a tiny (2, 3) token batch.

    Embeds the ids, runs them through BertIntermediate, and checks the
    output shape. The original built an attention mask that was never
    used (removed) and only printed the output without asserting.
    """
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000,
                        hidden_size=768,
                        num_hidden_layers=12,
                        num_attention_heads=12,
                        intermediate_size=3072)
    embeddings = BertEmbeddings(config)
    model = BertIntermediate(config)

    embedding_output = embeddings(input_ids, token_type_ids)
    output = model(embedding_output)
    # BertIntermediate projects hidden_size -> intermediate_size while
    # keeping the batch/sequence dims — presumably (2, 3, 3072); confirm
    # against the project's BertIntermediate definition.
    assert output.shape == (2, 3, config.intermediate_size)
    print(output)
Code example #6 (score: 0)
    def __init__(self, config: VerticalAttentionTableBertConfig):
        """Vertical-attention layer: attention -> intermediate -> output."""
        # Direct base-class init kept exactly as written; switching to
        # super() could change MRO behavior, so it is preserved verbatim.
        nn.Module.__init__(self)

        self.attention = BertVerticalAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
Code example #7 (score: 0)
 def __init__(self, config, layer_id):
     """Transformer encoder layer (attention -> intermediate -> output).

     Args:
         config: model configuration; provides ``hidden_size``.
         layer_id: layer index; accepted for interface compatibility but
             not referenced anywhere in this constructor.
     """
     super(BertLayer, self).__init__()
     hidden = config.hidden_size
     self.input_size = hidden
     self.attention = BertAttention(config, hidden)
     self.intermediate = BertIntermediate(config)
     self.output = BertOutput(config)