Example #1
import os

import torch
import torch.nn as nn

from fairseq.modules import (
    LayerNorm,
    TransformerSentenceEncoder,
    TransformerSentenceEncoderLayer,
)
from fairseq.models.roberta.model import RobertaClassificationHead
from onmt.encoders.encoder import EncoderBase

# BaseRanker, expandEmbeddingByN, update_init_roberta_model_state and the
# forwad1 helper are project-specific and assumed to be defined elsewhere.


class OnmtRobertaEncoder(EncoderBase):
    '''
    Returns:
        (torch.FloatTensor, torch.FloatTensor):

        * embeddings ``(src_len, batch_size, model_dim)``
        * memory_bank ``(src_len, batch_size, model_dim)``
    '''

    def __init__(self, model_path, padding_idx, vocab_size):
        super(OnmtRobertaEncoder, self).__init__()

        # Load the pretrained RoBERTa checkpoint first: its saved `args`
        # provide the hyperparameters used to size the encoder below.
        model_ckpt_file = os.path.join(model_path, "model.pt")
        if not os.path.exists(model_ckpt_file):
            raise FileNotFoundError(model_ckpt_file)
        ckpt = torch.load(model_ckpt_file, map_location="cpu")
        args = ckpt["args"]

        self.roberta_encoder = TransformerSentenceEncoder(
            padding_idx=padding_idx,
            vocab_size=vocab_size,
            num_encoder_layers=args.encoder_layers,
            embedding_dim=args.encoder_embed_dim,
            ffn_embedding_dim=args.encoder_ffn_embed_dim,
            num_attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.activation_dropout,
            max_seq_len=args.max_positions,
            num_segments=0,
            encoder_normalize_before=True,
            apply_bert_init=True,
            activation_fn=args.activation_fn,
        )
        print(self.roberta_encoder)
        print("defined the roberta network!")

        # Copy every parameter that belongs to the sentence encoder, stripping
        # the "decoder.sentence_encoder." prefix used in fairseq RoBERTa
        # checkpoints and skipping keys this module does not have.
        encoder_keys = set(self.roberta_encoder.state_dict().keys())
        model_dict = {}
        for k, v in ckpt["model"].items():
            if "decoder.sentence_encoder." in k:
                k = k.replace("decoder.sentence_encoder.", "")
                if k not in encoder_keys:
                    print("skip", k)
                    continue
                model_dict[k] = v
                print("{}:{}".format(k, v.size()))

        self.roberta_encoder.load_state_dict(model_dict)
        print("loaded {}/{} weights".format(len(model_dict), len(encoder_keys)))

        # Expand the embedding table by a few slots to make room for extra
        # special tokens added on top of the pretrained vocabulary.
        self.roberta_encoder.embed_tokens = expandEmbeddingByN(
            self.roberta_encoder.embed_tokens, 4)
        print("*" * 50)

    def forward(self, src, lengths=None):
        """See :func:`EncoderBase.forward()`"""
        self._check_args(src, lengths)
        # OpenNMT feeds tokens as (src_len, batch, nfeat); drop the feature
        # dimension and switch to (batch, src_len) for the fairseq encoder.
        src = src.squeeze(2).transpose(0, 1).contiguous()

        # forwad1 is a project-local helper (defined elsewhere) that runs the
        # RoBERTa encoder and returns (embeddings, per-layer states, sentence rep).
        emb, outs, sent_out = self.forwad1(self.roberta_encoder, src)
        out = outs[-1]
        return emb, out, lengths
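
The checkpoint-loading loop in __init__ above strips the "decoder.sentence_encoder." prefix that fairseq RoBERTa checkpoints use before copying weights into the local encoder. Below is a minimal, self-contained sketch of that remapping on a toy state dict; the key names and dummy tensors are illustrative only, not from the original project.

import torch

# Toy stand-ins for ckpt["model"] and the target module's state_dict keys.
ckpt_model = {
    "decoder.sentence_encoder.embed_tokens.weight": torch.zeros(10, 4),
    "decoder.lm_head.weight": torch.zeros(10, 4),  # not an encoder key -> ignored
}
target_keys = {"embed_tokens.weight"}

model_dict = {}
for k, v in ckpt_model.items():
    if "decoder.sentence_encoder." in k:
        k = k.replace("decoder.sentence_encoder.", "")
        if k in target_keys:
            model_dict[k] = v

print(sorted(model_dict))  # ['embed_tokens.weight']
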
class BertRanker(BaseRanker):
    def __init__(self, args, task):
        super(BertRanker, self).__init__(args, task)

        init_model = getattr(args, "pretrained_model", "")
        self.joint_layers = nn.ModuleList()
        if os.path.isfile(init_model):
            print(f"initialize weight from {init_model}")

            from fairseq import hub_utils

            x = hub_utils.from_pretrained(
                os.path.dirname(init_model),
                checkpoint_file=os.path.basename(init_model),
            )

            in_state_dict = x["models"][0].state_dict()
            init_args = x["args"].model

            num_positional_emb = init_args.max_positions + task.dictionary.pad() + 1

            # follow the setup in roberta
            self.model = TransformerSentenceEncoder(
                padding_idx=task.dictionary.pad(),
                vocab_size=len(task.dictionary),
                num_encoder_layers=getattr(args, "encoder_layers",
                                           init_args.encoder_layers),
                embedding_dim=init_args.encoder_embed_dim,
                ffn_embedding_dim=init_args.encoder_ffn_embed_dim,
                num_attention_heads=init_args.encoder_attention_heads,
                dropout=init_args.dropout,
                attention_dropout=init_args.attention_dropout,
                activation_dropout=init_args.activation_dropout,
                num_segments=2,  # add language embeddings
                max_seq_len=num_positional_emb,
                offset_positions_by_padding=False,
                encoder_normalize_before=True,
                apply_bert_init=True,
                activation_fn=init_args.activation_fn,
                freeze_embeddings=args.freeze_embeddings,
                n_trans_layers_to_freeze=args.n_trans_layers_to_freeze,
            )

            # still need to learn segment embeddings as we added a second language embedding
            if args.freeze_embeddings:
                for p in self.model.segment_embeddings.parameters():
                    p.requires_grad = False

            # Remap parameter names from the pretrained RoBERTa checkpoint to
            # this module's naming scheme (helper defined elsewhere in the file).
            update_init_roberta_model_state(in_state_dict)
            print("loading weights from the pretrained model")
            self.model.load_state_dict(
                in_state_dict, strict=False
            )  # ignore mismatch in language embeddings

            ffn_embedding_dim = init_args.encoder_ffn_embed_dim
            num_attention_heads = init_args.encoder_attention_heads
            dropout = init_args.dropout
            attention_dropout = init_args.attention_dropout
            activation_dropout = init_args.activation_dropout
            activation_fn = init_args.activation_fn

            classifier_embed_dim = getattr(args, "embed_dim",
                                           init_args.encoder_embed_dim)
            if classifier_embed_dim != init_args.encoder_embed_dim:
                self.transform_layer = nn.Linear(init_args.encoder_embed_dim,
                                                 classifier_embed_dim)
        else:
            self.model = TransformerSentenceEncoder(
                padding_idx=task.dictionary.pad(),
                vocab_size=len(task.dictionary),
                num_encoder_layers=args.encoder_layers,
                embedding_dim=args.embed_dim,
                ffn_embedding_dim=args.ffn_embed_dim,
                num_attention_heads=args.attention_heads,
                dropout=args.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                max_seq_len=(task.max_positions()
                             if task.max_positions() else args.tokens_per_sample),
                num_segments=2,
                offset_positions_by_padding=False,
                encoder_normalize_before=args.encoder_normalize_before,
                apply_bert_init=args.apply_bert_init,
                activation_fn=args.activation_fn,
            )

            classifier_embed_dim = args.embed_dim
            ffn_embedding_dim = args.ffn_embed_dim
            num_attention_heads = args.attention_heads
            dropout = args.dropout
            attention_dropout = args.attention_dropout
            activation_dropout = args.activation_dropout
            activation_fn = args.activation_fn

        self.joint_classification = args.joint_classification
        if args.joint_classification == "sent":
            if args.joint_normalize_before:
                self.joint_layer_norm = LayerNorm(classifier_embed_dim)
            else:
                self.joint_layer_norm = None

            self.joint_layers = nn.ModuleList([
                TransformerSentenceEncoderLayer(
                    embedding_dim=classifier_embed_dim,
                    ffn_embedding_dim=ffn_embedding_dim,
                    num_attention_heads=num_attention_heads,
                    dropout=dropout,
                    attention_dropout=attention_dropout,
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                ) for _ in range(args.num_joint_layers)
            ])

        self.classifier = RobertaClassificationHead(
            classifier_embed_dim,
            classifier_embed_dim,
            1,  # num_classes
            "tanh",
            args.classifier_dropout,
        )

    def forward(self, src_tokens, src_lengths):
        segment_labels = self.get_segment_labels(src_tokens)
        positions = self.get_positions(src_tokens, segment_labels)

        inner_states, _ = self.model(
            tokens=src_tokens,
            segment_labels=segment_labels,
            last_state_only=True,
            positions=positions,
        )

        return inner_states[-1].transpose(0, 1)  # T x B x C -> B x T x C

    def sentence_forward(self,
                         encoder_out,
                         src_tokens=None,
                         sentence_rep="head"):
        # encoder_out: B x T x C
        if sentence_rep == "head":
            x = encoder_out[:, :1, :]
        else:  # 'meanpool' or 'maxpool'
            assert src_tokens is not None, "meanpool/maxpool requires src_tokens input"
            segment_labels = self.get_segment_labels(src_tokens)
            padding_mask = src_tokens.ne(self.padding_idx)
            encoder_mask = segment_labels * padding_mask.type_as(
                segment_labels)

            if sentence_rep == "meanpool":
                ntokens = torch.sum(encoder_mask, dim=1, keepdim=True)
                x = torch.sum(
                    encoder_out * encoder_mask.unsqueeze(2),
                    dim=1,
                    keepdim=True) / ntokens.unsqueeze(2).type_as(encoder_out)
            else:  # 'maxpool'
                encoder_out[(encoder_mask == 0).unsqueeze(2).repeat(
                    1, 1, encoder_out.shape[-1])] = -float("inf")
                x, _ = torch.max(encoder_out, dim=1, keepdim=True)

        if hasattr(self, "transform_layer"):
            x = self.transform_layer(x)

        return x  # B x 1 x C

    def joint_forward(self, x):
        # x: T x B x C
        if self.joint_layer_norm:
            x = self.joint_layer_norm(x.transpose(0, 1))
            x = x.transpose(0, 1)

        for layer in self.joint_layers:
            x, _ = layer(x, self_attn_padding_mask=None)
        return x

    def classification_forward(self, x):
        # x: B x T x C
        return self.classifier(x)
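
The "meanpool" branch of sentence_forward averages encoder states over positions whose segment label is nonzero and that are not padding. The following small runnable sketch reproduces that masking arithmetic on dummy tensors; shapes follow the B x T x C comment and all values are illustrative, not taken from the model.

import torch

B, T, C = 1, 4, 3
encoder_out = torch.arange(B * T * C, dtype=torch.float).view(B, T, C)
segment_labels = torch.tensor([[0, 1, 1, 0]])              # 1 marks the second segment
padding_mask = torch.tensor([[True, True, True, False]])   # True where not padding

encoder_mask = segment_labels * padding_mask.type_as(segment_labels)
ntokens = torch.sum(encoder_mask, dim=1, keepdim=True)
x = torch.sum(encoder_out * encoder_mask.unsqueeze(2), dim=1,
              keepdim=True) / ntokens.unsqueeze(2).type_as(encoder_out)
print(x.shape, x)  # torch.Size([1, 1, 3]): mean of positions 1 and 2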