import os

import torch

from fairseq.modules import TransformerSentenceEncoder
from onmt.encoders.encoder import EncoderBase


class OnmtRobertaEncoder(EncoderBase):
    """Wraps a pretrained fairseq RoBERTa encoder for use as an OpenNMT encoder.

    Returns:
        (torch.FloatTensor, torch.FloatTensor):

        * embeddings ``(src_len, batch_size, model_dim)``
        * memory_bank ``(src_len, batch_size, model_dim)``
    """

    def __init__(self, model_path, padding_idx, vocab_size):
        super(OnmtRobertaEncoder, self).__init__()

        # Load the checkpoint first: the architecture hyper-parameters
        # (`args`) are stored inside it and are needed to build the encoder.
        # (The original constructed the encoder before reading `args`.)
        model_ckpt_file = os.path.join(model_path, "model.pt")
        if not os.path.exists(model_ckpt_file):
            raise FileNotFoundError(
                "expected a fairseq checkpoint at {}".format(model_ckpt_file))
        ckpt = torch.load(model_ckpt_file, map_location="cpu")
        args = ckpt["args"]

        self.roberta_encoder = TransformerSentenceEncoder(
            padding_idx=padding_idx,
            vocab_size=vocab_size,
            num_encoder_layers=args.encoder_layers,
            embedding_dim=args.encoder_embed_dim,
            ffn_embedding_dim=args.encoder_ffn_embed_dim,
            num_attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.activation_dropout,
            max_seq_len=args.max_positions,
            num_segments=0,
            encoder_normalize_before=True,
            apply_bert_init=True,
            activation_fn=args.activation_fn,
        )
        print(self.roberta_encoder)
        print("defined the roberta network!")

        # Copy the pretrained weights, stripping fairseq's
        # "decoder.sentence_encoder." prefix and skipping keys that have no
        # counterpart in this module (e.g. the LM head).
        model_dict = {}
        for k, v in ckpt["model"].items():
            if "decoder.sentence_encoder." in k:
                k = k.replace("decoder.sentence_encoder.", "")
            if k not in self.roberta_encoder.state_dict().keys():
                print("skip", k)
                continue
            model_dict[k] = v
            print("{}: {}".format(k, v.size()))
        self.roberta_encoder.load_state_dict(model_dict)
        print("loaded {}/{} weights".format(
            len(model_dict.keys()),
            len(self.roberta_encoder.state_dict().keys())))

        # Grow the embedding table by 4 rows for task-specific special tokens.
        self.roberta_encoder.embed_tokens = expandEmbeddingByN(
            self.roberta_encoder.embed_tokens, 4)
        print("*" * 50)

    def forward(self, src, lengths=None):
        """See :func:`EncoderBase.forward()`"""
        self._check_args(src, lengths)
        # (src_len, batch, 1) -> (batch, src_len)
        src = src.squeeze(2).transpose(0, 1).contiguous()
        # `forwad1` re-runs the encoder while also exposing the token
        # embeddings; see the sketch below the class.
        emb, outs, sent_out = self.forwad1(self.roberta_encoder, src)
        out = outs[-1]
        return emb, out, lengths
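
# Neither `expandEmbeddingByN` nor `forwad1` (likely a typo for `forward1`)
# is defined in this snippet; they presumably live elsewhere in the repo.
# Below are minimal sketches of what they plausibly do, inferred from how
# they are called above and from fairseq's TransformerSentenceEncoder
# internals (older fairseq, where `encoder.dropout` is a float). Every
# detail here is an assumption, not the repo's actual implementation.

import torch.nn as nn
import torch.nn.functional as F


def expandEmbeddingByN(embedding, n):
    # Sketch: append `n` freshly initialized rows to a pretrained
    # nn.Embedding (e.g. for task-specific special tokens), keeping the
    # pretrained weights for the original vocabulary.
    old_vocab, dim = embedding.weight.size()
    expanded = nn.Embedding(old_vocab + n, dim,
                            padding_idx=embedding.padding_idx)
    expanded.weight.data[:old_vocab] = embedding.weight.data
    return expanded


def forwad1(self, encoder, tokens):
    # Sketch: a variant of TransformerSentenceEncoder.forward that also
    # returns the post-embedding activations. `tokens` is (batch, src_len);
    # `emb` and the inner states are (src_len, batch, model_dim).
    padding_mask = tokens.eq(encoder.padding_idx)
    x = encoder.embed_tokens(tokens)
    if encoder.embed_scale is not None:
        x *= encoder.embed_scale
    if encoder.embed_positions is not None:
        x += encoder.embed_positions(tokens)
    if encoder.emb_layer_norm is not None:
        x = encoder.emb_layer_norm(x)
    x = F.dropout(x, p=encoder.dropout, training=encoder.training)
    x *= 1 - padding_mask.unsqueeze(-1).type_as(x)  # zero out pad positions
    x = x.transpose(0, 1)  # (batch, src_len, dim) -> (src_len, batch, dim)
    emb = x
    inner_states = [x]
    for layer in encoder.layers:
        x, _ = layer(x, self_attn_padding_mask=padding_mask)
        inner_states.append(x)
    sentence_rep = x[0, :, :]  # first-token representation
    return emb, inner_states, sentence_rep


OnmtRobertaEncoder.forwad1 = forwad1  # attach as a method (sketch only)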
import os

import torch
import torch.nn as nn

from fairseq.models.roberta.model import RobertaClassificationHead
from fairseq.modules import (
    LayerNorm,
    TransformerSentenceEncoder,
    TransformerSentenceEncoderLayer,
)


class BertRanker(BaseRanker):
    def __init__(self, args, task):
        super(BertRanker, self).__init__(args, task)

        init_model = getattr(args, "pretrained_model", "")
        self.joint_layers = nn.ModuleList()
        if os.path.isfile(init_model):
            print(f"initialize weight from {init_model}")

            from fairseq import hub_utils

            x = hub_utils.from_pretrained(
                os.path.dirname(init_model),
                checkpoint_file=os.path.basename(init_model),
            )
            in_state_dict = x["models"][0].state_dict()
            init_args = x["args"].model

            # follow the setup in roberta
            num_positional_emb = init_args.max_positions + task.dictionary.pad() + 1

            self.model = TransformerSentenceEncoder(
                padding_idx=task.dictionary.pad(),
                vocab_size=len(task.dictionary),
                num_encoder_layers=getattr(
                    args, "encoder_layers", init_args.encoder_layers
                ),
                embedding_dim=init_args.encoder_embed_dim,
                ffn_embedding_dim=init_args.encoder_ffn_embed_dim,
                num_attention_heads=init_args.encoder_attention_heads,
                dropout=init_args.dropout,
                attention_dropout=init_args.attention_dropout,
                activation_dropout=init_args.activation_dropout,
                num_segments=2,  # add language embeddings
                max_seq_len=num_positional_emb,
                offset_positions_by_padding=False,
                encoder_normalize_before=True,
                apply_bert_init=True,
                activation_fn=init_args.activation_fn,
                freeze_embeddings=args.freeze_embeddings,
                n_trans_layers_to_freeze=args.n_trans_layers_to_freeze,
            )

            # still need to learn segment embeddings as we added a second
            # language embedding
            if args.freeze_embeddings:
                for p in self.model.segment_embeddings.parameters():
                    p.requires_grad = False

            # remap the RoBERTa checkpoint keys (helper defined elsewhere)
            update_init_roberta_model_state(in_state_dict)
            print("loading weights from the pretrained model")
            self.model.load_state_dict(
                in_state_dict, strict=False
            )  # ignore mismatch in language embeddings

            ffn_embedding_dim = init_args.encoder_ffn_embed_dim
            num_attention_heads = init_args.encoder_attention_heads
            dropout = init_args.dropout
            attention_dropout = init_args.attention_dropout
            activation_dropout = init_args.activation_dropout
            activation_fn = init_args.activation_fn

            classifier_embed_dim = getattr(
                args, "embed_dim", init_args.encoder_embed_dim
            )
            if classifier_embed_dim != init_args.encoder_embed_dim:
                self.transform_layer = nn.Linear(
                    init_args.encoder_embed_dim, classifier_embed_dim
                )
        else:
            self.model = TransformerSentenceEncoder(
                padding_idx=task.dictionary.pad(),
                vocab_size=len(task.dictionary),
                num_encoder_layers=args.encoder_layers,
                embedding_dim=args.embed_dim,
                ffn_embedding_dim=args.ffn_embed_dim,
                num_attention_heads=args.attention_heads,
                dropout=args.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                max_seq_len=task.max_positions()
                if task.max_positions()
                else args.tokens_per_sample,
                num_segments=2,
                offset_positions_by_padding=False,
                encoder_normalize_before=args.encoder_normalize_before,
                apply_bert_init=args.apply_bert_init,
                activation_fn=args.activation_fn,
            )

            classifier_embed_dim = args.embed_dim
            ffn_embedding_dim = args.ffn_embed_dim
            num_attention_heads = args.attention_heads
            dropout = args.dropout
            attention_dropout = args.attention_dropout
            activation_dropout = args.activation_dropout
            activation_fn = args.activation_fn

        self.joint_classification = args.joint_classification
        if args.joint_classification == "sent":
            if args.joint_normalize_before:
                self.joint_layer_norm = LayerNorm(classifier_embed_dim)
            else:
                self.joint_layer_norm = None

            self.joint_layers = nn.ModuleList(
                [
                    TransformerSentenceEncoderLayer(
                        embedding_dim=classifier_embed_dim,
                        ffn_embedding_dim=ffn_embedding_dim,
                        num_attention_heads=num_attention_heads,
                        dropout=dropout,
                        attention_dropout=attention_dropout,
                        activation_dropout=activation_dropout,
                        activation_fn=activation_fn,
                    )
                    for _ in range(args.num_joint_layers)
                ]
            )

        self.classifier = RobertaClassificationHead(
            classifier_embed_dim,
            classifier_embed_dim,
            1,  # num_classes
            "tanh",
            args.classifier_dropout,
        )

    def forward(self, src_tokens, src_lengths):
        segment_labels = self.get_segment_labels(src_tokens)
        positions = self.get_positions(src_tokens, segment_labels)

        inner_states, _ = self.model(
            tokens=src_tokens,
            segment_labels=segment_labels,
            last_state_only=True,
            positions=positions,
        )

        return inner_states[-1].transpose(0, 1)  # T x B x C -> B x T x C

    def sentence_forward(self, encoder_out, src_tokens=None, sentence_rep="head"):
        # encoder_out: B x T x C
        if sentence_rep == "head":
            x = encoder_out[:, :1, :]
        else:  # 'meanpool', 'maxpool'
            assert src_tokens is not None, "meanpool/maxpool requires src_tokens input"
            segment_labels = self.get_segment_labels(src_tokens)
            padding_mask = src_tokens.ne(self.padding_idx)
            encoder_mask = segment_labels * padding_mask.type_as(segment_labels)

            if sentence_rep == "meanpool":
                ntokens = torch.sum(encoder_mask, dim=1, keepdim=True)
                x = torch.sum(
                    encoder_out * encoder_mask.unsqueeze(2), dim=1, keepdim=True
                ) / ntokens.unsqueeze(2).type_as(encoder_out)
            else:  # 'maxpool'
                encoder_out[
                    (encoder_mask == 0).unsqueeze(2).repeat(1, 1, encoder_out.shape[-1])
                ] = -float("inf")
                x, _ = torch.max(encoder_out, dim=1, keepdim=True)

        if hasattr(self, "transform_layer"):
            x = self.transform_layer(x)

        return x  # B x 1 x C

    def joint_forward(self, x):
        # x: T x B x C
        if self.joint_layer_norm:
            x = self.joint_layer_norm(x.transpose(0, 1))
            x = x.transpose(0, 1)

        for layer in self.joint_layers:
            x, _ = layer(x, self_attn_padding_mask=None)
        return x

    def classification_forward(self, x):
        # x: B x T x C
        return self.classifier(x)
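
# Hedged usage sketch (not from the source): how the pieces of BertRanker
# compose into a single reranking score per <source, hypothesis> pair. The
# function name and the `num_hypos` grouping of candidates per source
# sentence are assumptions for illustration; only the method calls mirror
# the class above.
def score_hypotheses(ranker, src_tokens, src_lengths, num_hypos):
    # src_tokens: B x T, where B = n_src * num_hypos rows of <src, hyp> pairs
    feats = ranker(src_tokens, src_lengths)  # encoder output, B x T x C
    x = ranker.sentence_forward(feats, src_tokens)  # B x 1 x C
    if ranker.joint_classification == "sent":
        # Let the joint layers attend across the num_hypos candidates that
        # share a source sentence: B x 1 x C -> num_hypos x n_src x C.
        x = x.view(-1, num_hypos, x.size(-1)).transpose(0, 1)
        x = ranker.joint_forward(x)
        x = x.transpose(0, 1).reshape(-1, 1, x.size(-1))  # back to B x 1 x C
    return ranker.classification_forward(x)  # B x 1 scores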