def from_torch(attention: TorchDistilMultiHeadSelfAttention, layernorm: nn.LayerNorm):
    """Build a DistillBertAttention from a torch DistilBERT self-attention
    module and the LayerNorm that follows it.

    The separate q_lin / k_lin / v_lin projection weights are concatenated
    into one fused qkv matrix, and every parameter is wrapped as a turbo
    tensor via convert2tt_tensor.
    """
    params = dict(attention.named_parameters())
    layernorm_params = dict(layernorm.named_parameters())
    with torch.no_grad():
        # Fuse q_lin.weight, k_lin.weight and v_lin.weight into a single
        # qkv weight.  torch.nn.Linear stores its weight transposed, so the
        # concatenation is transposed (torch.t) and cloned contiguous into
        # the layout the fused attention kernel expects.
        qkv_weight = torch.clone(
            torch.t(
                torch.cat((params['q_lin.weight'], params['k_lin.weight'],
                           params['v_lin.weight']), 0).contiguous()).contiguous())
        qkv_bias = torch.cat((params['q_lin.bias'], params['k_lin.bias'],
                              params['v_lin.bias']), 0).contiguous()
        # Output projection weight, likewise transposed out of torch's layout.
        output_weight = torch.clone(
            torch.t(params['out_lin.weight']).contiguous())
        att = DistillBertAttention(
            convert2tt_tensor(qkv_weight), convert2tt_tensor(qkv_bias),
            convert2tt_tensor(output_weight),
            convert2tt_tensor(params['out_lin.bias']),
            convert2tt_tensor(layernorm_params['weight']),
            convert2tt_tensor(layernorm_params['bias']), attention.n_heads)
        return att
def from_onmt(multi_headed_attn: OnmtMultiHeadedAttention,
              layer_norm: TorchLayerNorm,
              is_trans_weight: bool = False):
    """Convert an OpenNMT MultiHeadedAttention plus its LayerNorm into a
    turbo MultiHeadedAttention.

    Parameter packing (and optional weight transposition, controlled by
    ``is_trans_weight``) is delegated to MultiHeadedAttention.pack_parameter.
    """
    norm_params = dict(layer_norm.named_parameters())
    with torch.no_grad():
        packed = MultiHeadedAttention.pack_parameter(multi_headed_attn,
                                                     is_trans_weight)
        return MultiHeadedAttention(*packed,
                                    convert2tt_tensor(norm_params['weight']),
                                    convert2tt_tensor(norm_params['bias']),
                                    multi_headed_attn.head_count)
def from_torch(ffn: TorchDistilFFN,
               layernorm: nn.LayerNorm,
               is_trans_weight: bool = True):
    """Build a DistrillFFN from a torch DistilBERT feed-forward module and
    the LayerNorm that follows it.

    torch.nn.Linear stores its weight transposed, so when
    ``is_trans_weight`` is True the weights are passed through as-is
    (already in the transposed layout); otherwise they are transposed back
    to the plain layout before conversion.
    """
    ffn_params = dict(ffn.named_parameters())
    layernorm_params = dict(layernorm.named_parameters())
    if is_trans_weight:
        w_1 = convert2tt_tensor(ffn_params['lin1.weight'])
        w_2 = convert2tt_tensor(ffn_params['lin2.weight'])
    else:
        w_1 = convert2tt_tensor(
            torch.clone(torch.t(ffn_params['lin1.weight']).contiguous()))
        w_2 = convert2tt_tensor(
            torch.clone(torch.t(ffn_params['lin2.weight']).contiguous()))
    with torch.no_grad():
        # Distinct result name: the original rebound the `ffn` parameter here,
        # shadowing the source module.
        tt_ffn = DistrillFFN(w_1, convert2tt_tensor(ffn_params['lin1.bias']),
                             w_2, convert2tt_tensor(ffn_params['lin2.bias']),
                             convert2tt_tensor(layernorm_params['weight']),
                             convert2tt_tensor(layernorm_params['bias']))
        return tt_ffn