コード例 #1
0
    def add_args(parser):
        """Add model-specific arguments to the parser."""

        parser.add_argument('--triplet_type', type=str, default=None,
                            help='type of triplet model to use for inference')

        RobertaWrapper.add_args(parser)
コード例 #2
0
 def build_model(cls, args, task):
     encoder = RobertaWrapper.build_model(args, task)
     model_dict = nn.ModuleDict()
     for task_name, sub_task in task.tasks.items():
         task_override_args = args.tasks[task_name]
         if 'arch' in task_override_args:
             model_dict[task_name] = ARCH_MODEL_REGISTRY[
                 task_override_args['arch']].build_model(sub_task.args,
                                                         sub_task,
                                                         encoder=encoder)
         else:
             model_dict[task_name] = encoder
     return cls(args, encoder, model_dict)
コード例 #3
0
ファイル: encoder_mlm.py プロジェクト: Michiel29/graphqa
 def build_model(cls, args, task, encoder=None):
     if encoder is None:
         encoder = RobertaWrapper.build_model(args, task)
     return cls(args, encoder)
コード例 #4
0
ファイル: encoder_mlm.py プロジェクト: Michiel29/graphqa
 def add_args(parser):
     """Add model-specific arguments to the parser."""
     RobertaWrapper.add_args(parser)
コード例 #5
0
 def build_model(cls, args, task, encoder=None):
     if encoder is None:
         encoder = RobertaWrapper.build_model(args, task)
     n_entities = len(task.entity_dictionary)
     return cls(args, encoder, n_entities)
コード例 #6
0
ファイル: encoder_triplet.py プロジェクト: Michiel29/graphqa
 def build_model(cls, args, task, encoder=None):
     if encoder is None:
         encoder = RobertaWrapper.build_model(args, task)
     triplet_model = triplet_dict[args.triplet_type](args)
     n_entities = len(task.entity_dictionary)
     return cls(args, encoder, triplet_model, n_entities)