Example #1
 # T is the tensor type alias (torch.Tensor) used throughout this codebase;
 # LongformerEncoder is the project-local encoder class (see Example #2).
 import torch
 from torch import Tensor as T


 def get_representation(sub_model: LongformerEncoder, ids: T, attn_mask: T, global_attn_mask: T,
                        fix_encoder: bool = False) -> T:
     """Return the Longformer sequence output for the given inputs (None if ids is None).

     With fix_encoder=True the forward pass runs under torch.no_grad(), so the encoder
     weights receive no gradient; requires_grad_ is then re-enabled on the output so that
     modules stacked on top of the frozen encoder can still be trained.
     """
     sequence_output = None
     if ids is not None:
         if fix_encoder:
             with torch.no_grad():
                 sequence_output, _, _ = sub_model.forward(input_ids=ids, attention_mask=attn_mask,
                                                           global_attention_mask=global_attn_mask)
             if sub_model.training:
                 sequence_output.requires_grad_(requires_grad=True)
         else:
             sequence_output, _, _ = sub_model.forward(input_ids=ids, attention_mask=attn_mask,
                                                       global_attention_mask=global_attn_mask)
     return sequence_output
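
A minimal usage sketch for get_representation, assuming the encoder is built with LongformerEncoder.init_encoder as in Example #2; the configuration name, dropout values, and tensor shapes below are illustrative assumptions rather than values from the original project.

 # Hypothetical usage sketch (not from the original repository).
 import torch

 encoder = LongformerEncoder.init_encoder(cfg_name='allenai/longformer-base-4096',
                                          projection_dim=0, hidden_dropout=0.1,
                                          attn_dropout=0.1, seq_project=False)
 encoder.eval()

 ids = torch.randint(low=5, high=1000, size=(1, 128))   # fake token ids, batch size 1
 attn_mask = torch.ones_like(ids)                        # attend to every position locally
 global_attn_mask = torch.zeros_like(ids)
 global_attn_mask[:, 0] = 1                              # global attention on the first token

 seq_out = get_representation(sub_model=encoder, ids=ids, attn_mask=attn_mask,
                              global_attn_mask=global_attn_mask, fix_encoder=True)
 print(seq_out.shape)                                    # (1, 128, hidden_size)
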
Example #2
 def __init__(self, args: Namespace, fix_encoder=False):
     super().__init__()
     # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.tokenizer = get_hotpotqa_longformer_tokenizer(model_name=args.pretrained_cfg_name)
     # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     longEncoder = LongformerEncoder.init_encoder(cfg_name=args.pretrained_cfg_name, projection_dim=args.project_dim,
                                                  hidden_dropout=args.input_drop, attn_dropout=args.attn_drop,
                                                  seq_project=args.seq_project)
     longEncoder.resize_token_embeddings(len(self.tokenizer))
     # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     if args.frozen_layer_num > 0:
         modules = [longEncoder.embeddings, *longEncoder.encoder.layer[:args.frozen_layer_num]]
         for module in modules:
             for param in module.parameters():
                 param.requires_grad = False
         logging.info('Froze the first {} layers of the encoder'.format(args.frozen_layer_num))
     # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.longformer = longEncoder  #### Longformer encoder
     self.hidden_size = longEncoder.get_out_size()
     self.doc_mlp = MLP(d_input=self.hidden_size, d_mid=4 * self.hidden_size, d_out=1) ## support document prediction
     self.sent_mlp = MLP(d_input=self.hidden_size, d_mid=4 * self.hidden_size, d_out=1) ## support sentence prediction
     self.fix_encoder = fix_encoder
     ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.hparams = args
     ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.graph_training = self.hparams.with_graph_training == 1
     ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.with_graph = self.hparams.with_graph == 1
     if self.with_graph:
         self.graph_encoder = TransformerModule(layer_num=self.hparams.layer_number, d_model=self.hidden_size,
                                                heads=self.hparams.heads)
     ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.mask_value = MASK_VALUE
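
The constructor in Example #2 reads a fixed set of fields from args. The sketch below shows one way to build such a Namespace and instantiate the module; only the field names are taken from the code above, while the class name HotPotQAModel and all hyperparameter values are illustrative assumptions.

 # Illustrative only: field names mirror what __init__ above accesses;
 # the values and the class name are assumptions, not the original configuration.
 from argparse import Namespace

 args = Namespace(
     pretrained_cfg_name='allenai/longformer-base-4096',  # tokenizer / encoder config
     project_dim=0,            # projection_dim of the encoder head
     input_drop=0.1,           # hidden_dropout
     attn_drop=0.1,            # attn_dropout
     seq_project=False,
     frozen_layer_num=2,       # freeze embeddings + first two Longformer layers
     with_graph_training=0,
     with_graph=1,             # 1 enables the TransformerModule graph encoder
     layer_number=2,           # graph encoder depth
     heads=8,                  # graph encoder attention heads
 )
 model = HotPotQAModel(args=args, fix_encoder=False)  # hypothetical class name for Example #2
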
Example #3
 def __init__(self, args: Namespace, fix_encoder=False):
     super().__init__()
     # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.tokenizer = get_hotpotqa_longformer_tokenizer(
         model_name=args.pretrained_cfg_name)
     # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     longEncoder = LongformerEncoder.init_encoder(
         cfg_name=args.pretrained_cfg_name,
         projection_dim=args.project_dim,
         hidden_dropout=args.input_drop,
         attn_dropout=args.attn_drop,
         seq_project=args.seq_project)
     longEncoder.resize_token_embeddings(len(self.tokenizer))
     # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     if args.frozen_layer_num > 0:
         modules = [
             longEncoder.embeddings,
             *longEncoder.encoder.layer[:args.frozen_layer_num]
         ]
         for module in modules:
             for param in module.parameters():
                 param.requires_grad = False
         logging.info('Froze the first {} layers of the encoder'.format(
             args.frozen_layer_num))
     # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.longformer = longEncoder  #### Longformer encoder
     self.hidden_size = longEncoder.get_out_size()
     self.answer_type_outputs = MLP(
         d_input=self.hidden_size, d_mid=4 * self.hidden_size,
         d_out=3)  ## yes, no, span question score
     self.answer_span_outputs = MLP(d_input=self.hidden_size,
                                    d_mid=4 * self.hidden_size,
                                    d_out=2)  ## span prediction score
     self.doc_mlp = MLP(d_input=self.hidden_size,
                        d_mid=4 * self.hidden_size,
                        d_out=1)  ## support document prediction
     self.sent_mlp = MLP(d_input=self.hidden_size,
                         d_mid=4 * self.hidden_size,
                         d_out=1)  ## support sentence prediction
     self.fix_encoder = fix_encoder
     ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.hparams = args
     ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.hop_model_name = self.hparams.hop_model_name  ## triple score
     ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.graph_training = (self.hparams.with_graph_training == 1)
     ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     if self.hop_model_name not in ['DotProduct', 'BiLinear']:
         self.hop_model_name = None
     else:
         self.hop_doc_dotproduct = DotProduct(
             args=self.hparams
         ) if self.hop_model_name == 'DotProduct' else None
         self.hop_doc_bilinear = BiLinear(
             args=self.hparams, project_dim=self.hidden_size
         ) if self.hop_model_name == 'BiLinear' else None
     ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
     self.mask_value = MASK_VALUE