    def from_torch(attention: TorchDistilMultiHeadSelfAttention,
                   layernorm: nn.LayerNorm):
        params = {k: v for k, v in attention.named_parameters()}
        layernorm_params = {k: v for k, v in layernorm.named_parameters()}

        with torch.no_grad():
            # merge q_lin.weight, k_lin.weight and v_lin.weight together as qkv.weight
            qkv_weight = torch.clone(
                torch.t(
                    torch.cat((params['q_lin.weight'], params['k_lin.weight'],
                               params['v_lin.weight']),
                              0).contiguous()).contiguous())
            qkv_bias = torch.cat((params['q_lin.bias'], params['k_lin.bias'],
                                  params['v_lin.bias']), 0).contiguous()

            output_weight = torch.clone(
                torch.t(params['out_lin.weight']).contiguous())
            att = DistillBertAttention(
                convert2tt_tensor(qkv_weight), convert2tt_tensor(qkv_bias),
                convert2tt_tensor(output_weight),
                convert2tt_tensor(params['out_lin.bias']),
                convert2tt_tensor(layernorm_params['weight']),
                convert2tt_tensor(layernorm_params['bias']), attention.n_heads)

            return att
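# Usage sketch (assumptions: from_torch is exposed as a @staticmethod on
# turbo_transformers' DistillBertAttention, and HuggingFace `transformers`
# supplies the torch modules being converted):
from transformers import DistilBertModel

torch_model = DistilBertModel.from_pretrained("distilbert-base-uncased")
block = torch_model.transformer.layer[0]          # first TransformerBlock

# block.attention holds q_lin/k_lin/v_lin/out_lin; sa_layer_norm is the
# post-attention LayerNorm consumed by the converter above
turbo_attention = DistillBertAttention.from_torch(block.attention,
                                                  block.sa_layer_norm)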
    def from_onmt(multi_headed_attn: OnmtMultiHeadedAttention,
                  layer_norm: TorchLayerNorm,
                  is_trans_weight: bool = False):
        ln_params = {k: v for k, v in layer_norm.named_parameters()}
        with torch.no_grad():
            # pack_parameter gathers the q/k/v and output projection weights and
            # biases; the layer-norm parameters and head count are appended last
            att = MultiHeadedAttention(
                *(MultiHeadedAttention.pack_parameter(multi_headed_attn,
                                                      is_trans_weight)),
                convert2tt_tensor(ln_params['weight']),
                convert2tt_tensor(ln_params['bias']),
                multi_headed_attn.head_count)
            return att
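# Usage sketch (assumptions: OpenNMT-py is installed and from_onmt is a
# @staticmethod on turbo_transformers' MultiHeadedAttention wrapper):
import torch.nn as nn
from onmt.modules import MultiHeadedAttention as OnmtMultiHeadedAttention

onmt_attn = OnmtMultiHeadedAttention(8, 512, dropout=0.1)   # 8 heads, model_dim = 512
post_norm = nn.LayerNorm(512)
turbo_attn = MultiHeadedAttention.from_onmt(onmt_attn, post_norm,
                                            is_trans_weight=False)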
    def from_torch(ffn: TorchDistilFFN,
                   layernorm: nn.LayerNorm,
                   is_trans_weight: Optional[bool] = True):
        ffn_params = {k: v for k, v in ffn.named_parameters()}
        layernorm_params = {k: v for k, v in layernorm.named_parameters()}

        # Note that torch stores a linear layer's weight transposed, i.e. as (out_features, in_features)
        if is_trans_weight:
            w_1 = convert2tt_tensor(ffn_params['lin1.weight'])
            w_2 = convert2tt_tensor(ffn_params['lin2.weight'])
        else:
            w_1 = convert2tt_tensor(
                torch.clone(torch.t(ffn_params['lin1.weight']).contiguous()))
            w_2 = convert2tt_tensor(
                torch.clone(torch.t(ffn_params['lin2.weight']).contiguous()))

        with torch.no_grad():
            ffn = DistrillFFN(w_1, convert2tt_tensor(ffn_params['lin1.bias']),
                              w_2, convert2tt_tensor(ffn_params['lin2.bias']),
                              convert2tt_tensor(layernorm_params['weight']),
                              convert2tt_tensor(layernorm_params['bias']))
            return ffn
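# Sketch of the transposition convention handled above: torch.nn.Linear keeps
# its weight as (out_features, in_features), while the is_trans_weight=False
# branch materializes a contiguous (in_features, out_features) copy (that layout
# expectation is implied by the converter above, not stated elsewhere here):
import torch
import torch.nn as nn

lin = nn.Linear(768, 3072)              # DistilBERT FFN first projection
print(lin.weight.shape)                 # torch.Size([3072, 768])

w_t = torch.t(lin.weight).contiguous()  # contiguous copy in (768, 3072) layout
print(w_t.shape)                        # torch.Size([768, 3072])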