import torch
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertEncoder


class HuggingFaceEncoder(torch.nn.Module):
    """Wraps HuggingFace's BertEncoder and loads a checkpoint exposed as a
    `weights.weights` dict, skipping the extra '*_amax' calibration entries."""

    def __init__(self, layer_num, head_num, head_size, weights=None):
        super().__init__()
        hidden_dim = head_num * head_size
        conf = BertConfig(hidden_size=hidden_dim,
                          intermediate_size=4 * hidden_dim,
                          num_attention_heads=head_num,
                          num_hidden_layers=layer_num)
        self.encoder = BertEncoder(conf)
        # Keep the 'bert.encoder.*' entries and strip the 13-character
        # 'bert.encoder.' prefix so the keys match BertEncoder's state dict.
        w = {}
        for k, v in weights.weights.items():
            if k.startswith('bert.encoder') and not k.endswith('_amax'):
                w[k[13:]] = v
        self.encoder.load_state_dict(w)
        self.head_mask = [None] * layer_num

    def forward(self, hidden_states, attention_mask):
        # Turn a {0, 1} mask into an additive bias: masked positions receive a
        # large negative value before the attention softmax.
        extended_attention_mask = (1.0 - attention_mask) * -10000.0
        output = self.encoder(hidden_states, extended_attention_mask, self.head_mask)
        return output
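
# A minimal sketch of driving the class above. The SimpleNamespace wrapper and
# the randomly initialised BertEncoder are illustrative assumptions standing in
# for a real checkpoint object: anything exposing a `.weights` dict whose keys
# carry the 'bert.encoder.' prefix would do.
from types import SimpleNamespace

layer_num, head_num, head_size = 2, 4, 16
hidden_dim = head_num * head_size

# Build a 'bert.encoder.*'-prefixed dict from a freshly initialised BertEncoder.
ref = BertEncoder(BertConfig(hidden_size=hidden_dim,
                             intermediate_size=4 * hidden_dim,
                             num_attention_heads=head_num,
                             num_hidden_layers=layer_num))
ckpt = {'bert.encoder.' + k: v for k, v in ref.state_dict().items()}

encoder = HuggingFaceEncoder(layer_num, head_num, head_size,
                             weights=SimpleNamespace(weights=ckpt))

batch_size, seq_len = 2, 8
hidden_states = torch.rand(batch_size, seq_len, hidden_dim)
# 1 = attend, 0 = padding; shape [batch, 1, 1, seq] broadcasts over heads and
# query positions inside BertEncoder's self-attention.
attention_mask = torch.ones(batch_size, 1, 1, seq_len)
output = encoder(hidden_states, attention_mask)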

class HuggingFaceEncoder(torch.nn.Module):
    """Same wrapper, but accepts either a plain HuggingFace state dict or an
    object carrying per-layer weight lists in `weights.w[i]`."""

    def __init__(self, layer_num, head_num, head_size, weights=None):
        super().__init__()
        hidden_dim = head_num * head_size
        conf = BertConfig(hidden_size=hidden_dim,
                          intermediate_size=4 * hidden_dim,
                          num_attention_heads=head_num,
                          num_hidden_layers=layer_num)
        self.encoder = BertEncoder(conf)
        if isinstance(weights, dict):
            # Plain state dict: keep the 'bert.encoder.*' entries and strip the
            # 13-character 'bert.encoder.' prefix to match BertEncoder's keys.
            w = {}
            for k, v in weights.items():
                if k.startswith('bert.encoder'):
                    w[k[13:]] = v
            self.encoder.load_state_dict(w)
        else:
            # Per-layer tensor lists: linear weights are stored transposed
            # relative to nn.Linear's [out, in] layout, so transpose them back.
            for i in range(layer_num):
                layer = self.encoder.layer[i]
                layer.attention.self.query.weight.data = weights.w[i][0].transpose(-1, -2).contiguous()
                layer.attention.self.query.bias.data = weights.w[i][1]
                layer.attention.self.key.weight.data = weights.w[i][2].transpose(-1, -2).contiguous()
                layer.attention.self.key.bias.data = weights.w[i][3]
                layer.attention.self.value.weight.data = weights.w[i][4].transpose(-1, -2).contiguous()
                layer.attention.self.value.bias.data = weights.w[i][5]
                layer.attention.output.dense.weight.data = weights.w[i][6].transpose(-1, -2).contiguous()
                layer.attention.output.dense.bias.data = weights.w[i][7]
                layer.attention.output.LayerNorm.weight.data = weights.w[i][8]
                layer.attention.output.LayerNorm.bias.data = weights.w[i][9]
                layer.intermediate.dense.weight.data = weights.w[i][10].transpose(-1, -2).contiguous()
                layer.intermediate.dense.bias.data = weights.w[i][11]
                layer.output.dense.weight.data = weights.w[i][12].transpose(-1, -2).contiguous()
                layer.output.dense.bias.data = weights.w[i][13]
                layer.output.LayerNorm.weight.data = weights.w[i][14]
                layer.output.LayerNorm.bias.data = weights.w[i][15]
        self.head_mask = [None] * layer_num

    def forward(self, hidden_states, attention_mask):
        # Turn a {0, 1} mask into an additive attention bias.
        extended_attention_mask = (1.0 - attention_mask) * -10000.0
        output = self.encoder(hidden_states, extended_attention_mask, self.head_mask)
        return output
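
# A minimal end-to-end sketch for the plain-dict path above. The randomly
# initialised BertModel and the hand-added 'bert.' key prefix are illustrative
# assumptions standing in for a real checkpoint whose keys already start with
# 'bert.encoder.'.
from transformers import BertModel

layer_num, head_num, head_size = 2, 4, 16
hidden_dim = head_num * head_size
conf = BertConfig(hidden_size=hidden_dim,
                  intermediate_size=4 * hidden_dim,
                  num_attention_heads=head_num,
                  num_hidden_layers=layer_num)

hf_bert = BertModel(conf)
state_dict = {'bert.' + k: v for k, v in hf_bert.state_dict().items()}

encoder = HuggingFaceEncoder(layer_num, head_num, head_size, weights=state_dict)

batch_size, seq_len = 2, 8
hidden_states = torch.rand(batch_size, seq_len, hidden_dim)
attention_mask = torch.ones(batch_size, 1, 1, seq_len)  # 1 = attend, 0 = padding
attention_mask[:, :, :, seq_len // 2:] = 0.0            # pretend the second half is padding
output = encoder(hidden_states, attention_mask)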