def get_representation(sub_model: LongformerEncoder, ids: T, attn_mask: T, global_attn_mask: T,
                       fix_encoder: bool = False) -> T:
    """Encode a batch with the Longformer and return the token-level sequence output."""
    sequence_output = None
    if ids is not None:
        if fix_encoder:
            # Encoder is kept fixed: run it without tracking gradients.
            with torch.no_grad():
                sequence_output, _, _ = sub_model.forward(input_ids=ids,
                                                          attention_mask=attn_mask,
                                                          global_attention_mask=global_attn_mask)
            if sub_model.training:
                # Re-enable gradients on the detached output so the layers on top can still train.
                sequence_output.requires_grad_(requires_grad=True)
        else:
            sequence_output, _, _ = sub_model.forward(input_ids=ids,
                                                      attention_mask=attn_mask,
                                                      global_attention_mask=global_attn_mask)
    return sequence_output
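####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Usage sketch (added for illustration, not part of the original module): calling
# get_representation on a dummy batch with a pre-built LongformerEncoder. The
# helper name, batch size and sequence length are assumptions for the example only.
####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def _example_get_representation(encoder: LongformerEncoder) -> T:
    """Run get_representation on random ids with a [CLS]-style global-attention token (illustrative)."""
    ids = torch.randint(low=0, high=100, size=(2, 512), dtype=torch.long)  # dummy token ids
    attn_mask = torch.ones_like(ids)          # attend to every position locally
    global_attn_mask = torch.zeros_like(ids)
    global_attn_mask[:, 0] = 1                # global attention on the first token
    return get_representation(sub_model=encoder, ids=ids, attn_mask=attn_mask,
                              global_attn_mask=global_attn_mask, fix_encoder=True)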
def __init__(self, args: Namespace, fix_encoder=False):
    super().__init__()
    # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    self.tokenizer = get_hotpotqa_longformer_tokenizer(model_name=args.pretrained_cfg_name)
    # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    longEncoder = LongformerEncoder.init_encoder(cfg_name=args.pretrained_cfg_name,
                                                 projection_dim=args.project_dim,
                                                 hidden_dropout=args.input_drop,
                                                 attn_dropout=args.attn_drop,
                                                 seq_project=args.seq_project)
    longEncoder.resize_token_embeddings(len(self.tokenizer))
    # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    if args.frozen_layer_num > 0:
        modules = [longEncoder.embeddings, *longEncoder.encoder.layer[:args.frozen_layer_num]]
        for module in modules:
            for param in module.parameters():
                param.requires_grad = False
        logging.info('Frozen the first {} layers'.format(args.frozen_layer_num))
    # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    self.longformer = longEncoder  #### LongFormer encoder
    self.hidden_size = longEncoder.get_out_size()
    self.doc_mlp = MLP(d_input=self.hidden_size, d_mid=4 * self.hidden_size, d_out=1)   ## support document prediction
    self.sent_mlp = MLP(d_input=self.hidden_size, d_mid=4 * self.hidden_size, d_out=1)  ## support sentence prediction
    self.fix_encoder = fix_encoder
    ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    self.hparams = args
    ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    self.graph_training = self.hparams.with_graph_training == 1
    ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    self.with_graph = self.hparams.with_graph == 1
    if self.with_graph:
        self.graph_encoder = TransformerModule(layer_num=self.hparams.layer_number,
                                               d_model=self.hidden_size,
                                               heads=self.hparams.heads)
    ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    self.mask_value = MASK_VALUE
def __init__(self, args: Namespace, fix_encoder=False):
    super().__init__()
    # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    self.tokenizer = get_hotpotqa_longformer_tokenizer(model_name=args.pretrained_cfg_name)
    # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    longEncoder = LongformerEncoder.init_encoder(cfg_name=args.pretrained_cfg_name,
                                                 projection_dim=args.project_dim,
                                                 hidden_dropout=args.input_drop,
                                                 attn_dropout=args.attn_drop,
                                                 seq_project=args.seq_project)
    longEncoder.resize_token_embeddings(len(self.tokenizer))
    # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    if args.frozen_layer_num > 0:
        modules = [longEncoder.embeddings, *longEncoder.encoder.layer[:args.frozen_layer_num]]
        for module in modules:
            for param in module.parameters():
                param.requires_grad = False
        logging.info('Frozen the first {} layers'.format(args.frozen_layer_num))
    # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    self.longformer = longEncoder  #### LongFormer encoder
    self.hidden_size = longEncoder.get_out_size()
    self.answer_type_outputs = MLP(d_input=self.hidden_size, d_mid=4 * self.hidden_size, d_out=3)  ## yes, no, span question score
    self.answer_span_outputs = MLP(d_input=self.hidden_size, d_mid=4 * self.hidden_size, d_out=2)  ## span prediction score
    self.doc_mlp = MLP(d_input=self.hidden_size, d_mid=4 * self.hidden_size, d_out=1)              ## support document prediction
    self.sent_mlp = MLP(d_input=self.hidden_size, d_mid=4 * self.hidden_size, d_out=1)             ## support sentence prediction
    self.fix_encoder = fix_encoder
    ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    self.hparams = args
    ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    self.hop_model_name = self.hparams.hop_model_name  ## triple score
    ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    self.graph_training = (self.hparams.with_graph_training == 1)
    ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    if self.hop_model_name not in ['DotProduct', 'BiLinear']:
        self.hop_model_name = None
    else:
        self.hop_doc_dotproduct = DotProduct(args=self.hparams) if self.hop_model_name == 'DotProduct' else None
        self.hop_doc_bilinear = BiLinear(args=self.hparams, project_dim=self.hidden_size) if self.hop_model_name == 'BiLinear' else None
    ####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    self.mask_value = MASK_VALUE
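####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Instantiation sketch (added for illustration, not part of the original module):
# a minimal Namespace carrying the hyper-parameters the constructor above reads
# from `args`. The field names come from the code; every value below is an
# illustrative assumption -- the project's argument parser defines the real ones.
####+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def _example_hparams() -> Namespace:
    """Build an example argparse.Namespace with the fields consumed by __init__ above (placeholder values)."""
    return Namespace(
        pretrained_cfg_name='allenai/longformer-base-4096',  # assumed pretrained config name
        project_dim=0,                # projection dimension passed to init_encoder
        input_drop=0.1,               # hidden dropout passed to init_encoder
        attn_drop=0.1,                # attention dropout passed to init_encoder
        seq_project=True,             # whether to project the sequence output
        frozen_layer_num=0,           # freeze the embeddings and first N encoder layers when > 0
        with_graph_training=0,        # compared against 1 to set self.graph_training
        hop_model_name='DotProduct',  # 'DotProduct' or 'BiLinear'; anything else disables the hop scorer
    )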