Example #1
import torch
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertEncoder


class HuggingFaceEncoder(torch.nn.Module):
    def __init__(self, layer_num, head_num, head_size, weights=None):
        super().__init__()
        hidden_dim = head_num * head_size
        conf = BertConfig(hidden_size=hidden_dim,
                          intermediate_size=4 * hidden_dim,
                          num_attention_heads=head_num,
                          num_hidden_layers=layer_num)
        self.encoder = BertEncoder(conf)
        # Keep only encoder weights, dropping the 'bert.encoder.' prefix (13 chars)
        # and skipping quantization calibration tensors that end with '_amax'.
        w = {}
        for k, v in weights.weights.items():
            if k.startswith('bert.encoder') and not k.endswith('_amax'):
                w[k[13:]] = v
        self.encoder.load_state_dict(w)
        self.head_mask = [None] * layer_num

    def forward(self, hidden_states, attention_mask):
        # Convert a 0/1 attention mask into an additive bias:
        # 0 for tokens to attend to, -10000 for masked (padding) positions.
        extended_attention_mask = (1.0 - attention_mask) * -10000.0
        output = self.encoder(hidden_states, extended_attention_mask,
                              self.head_mask)
        return output
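
A minimal usage sketch for the example above. The only requirement the constructor places on `weights` is a `.weights` attribute holding a state dict whose keys start with 'bert.encoder.'; the `CheckpointWeights` wrapper and the use of `BertForSequenceClassification.from_pretrained('bert-base-uncased')` as the weight source are illustrative assumptions, not part of the original.

import torch
from transformers import BertForSequenceClassification


class CheckpointWeights:
    # Hypothetical wrapper matching the `weights.weights` access pattern above.
    def __init__(self, state_dict):
        self.weights = state_dict


layer_num, head_num, head_size = 12, 12, 64   # bert-base geometry: hidden_size = 768
state_dict = BertForSequenceClassification.from_pretrained('bert-base-uncased').state_dict()
encoder = HuggingFaceEncoder(layer_num, head_num, head_size,
                             weights=CheckpointWeights(state_dict)).eval()

batch_size, seq_len = 2, 32
hidden_states = torch.rand(batch_size, seq_len, head_num * head_size)
attention_mask = torch.ones(batch_size, 1, 1, seq_len)  # 1 = attend; broadcasts over heads and query positions
with torch.no_grad():
    output = encoder(hidden_states, attention_mask)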
Example #2
import torch
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertEncoder


class HuggingFaceEncoder(torch.nn.Module):
    def __init__(self, layer_num, head_num, head_size, weights=None):
        super().__init__()
        hidden_dim = head_num * head_size
        conf = BertConfig(hidden_size=hidden_dim,
                          intermediate_size=4 * hidden_dim,
                          num_attention_heads=head_num,
                          num_hidden_layers=layer_num)
        self.encoder = BertEncoder(conf)
        if isinstance(weights, dict):
            # A plain state dict: keep only encoder weights and drop the
            # 'bert.encoder.' prefix (13 chars) so the keys match BertEncoder.
            w = {}
            for k, v in weights.items():
                if k.startswith('bert.encoder'):
                    w[k[13:]] = v
            self.encoder.load_state_dict(w)
        else:
            # A custom weight container exposing weights.w[i] as a flat list of
            # 16 tensors per layer; weight matrices are stored transposed
            # relative to nn.Linear's [out_features, in_features] layout,
            # hence the transpose(-1, -2) before assignment.
            for i in range(layer_num):
                layer = self.encoder.layer[i]
                lw = weights.w[i]
                layer.attention.self.query.weight.data = lw[0].transpose(-1, -2).contiguous()
                layer.attention.self.query.bias.data = lw[1]
                layer.attention.self.key.weight.data = lw[2].transpose(-1, -2).contiguous()
                layer.attention.self.key.bias.data = lw[3]
                layer.attention.self.value.weight.data = lw[4].transpose(-1, -2).contiguous()
                layer.attention.self.value.bias.data = lw[5]
                layer.attention.output.dense.weight.data = lw[6].transpose(-1, -2).contiguous()
                layer.attention.output.dense.bias.data = lw[7]
                layer.attention.output.LayerNorm.weight.data = lw[8]
                layer.attention.output.LayerNorm.bias.data = lw[9]
                layer.intermediate.dense.weight.data = lw[10].transpose(-1, -2).contiguous()
                layer.intermediate.dense.bias.data = lw[11]
                layer.output.dense.weight.data = lw[12].transpose(-1, -2).contiguous()
                layer.output.dense.bias.data = lw[13]
                layer.output.LayerNorm.weight.data = lw[14]
                layer.output.LayerNorm.bias.data = lw[15]
        self.head_mask = [None] * layer_num

    def forward(self, hidden_states, attention_mask):
        # Convert a 0/1 attention mask into an additive bias (0 = attend, -10000 = masked).
        extended_attention_mask = (1.0 - attention_mask) * -10000.0
        output = self.encoder(hidden_states, extended_attention_mask,
                              self.head_mask)
        return output
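
A sketch of how the non-dict branch could be exercised. The `CustomWeights` container below is hypothetical: it only mirrors the layout the loop above expects (16 tensors per layer, weight matrices stored as [in_features, out_features]), with random values standing in for real checkpoint data.

import torch


class CustomWeights:
    # Hypothetical container mirroring the weights.w[i][j] layout used above.
    def __init__(self, layer_num, hidden_dim):
        inter_dim = 4 * hidden_dim
        self.w = [[
            torch.rand(hidden_dim, hidden_dim), torch.rand(hidden_dim),   # 0-1: query weight, bias
            torch.rand(hidden_dim, hidden_dim), torch.rand(hidden_dim),   # 2-3: key weight, bias
            torch.rand(hidden_dim, hidden_dim), torch.rand(hidden_dim),   # 4-5: value weight, bias
            torch.rand(hidden_dim, hidden_dim), torch.rand(hidden_dim),   # 6-7: attention output dense
            torch.rand(hidden_dim), torch.rand(hidden_dim),               # 8-9: attention output LayerNorm
            torch.rand(hidden_dim, inter_dim), torch.rand(inter_dim),     # 10-11: intermediate dense
            torch.rand(inter_dim, hidden_dim), torch.rand(hidden_dim),    # 12-13: output dense
            torch.rand(hidden_dim), torch.rand(hidden_dim),               # 14-15: output LayerNorm
        ] for _ in range(layer_num)]


layer_num, head_num, head_size = 2, 4, 16
encoder = HuggingFaceEncoder(layer_num, head_num, head_size,
                             weights=CustomWeights(layer_num, head_num * head_size)).eval()

hidden_states = torch.rand(1, 8, head_num * head_size)
attention_mask = torch.ones(1, 1, 1, 8)
with torch.no_grad():
    output = encoder(hidden_states, attention_mask)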