Example #1
    def __init__(self, config, **kwargs):
        # kwargs: encoder (required), tasks, num_labels_per_task, mask_id
        super(BertMultiTask, self).__init__(config)
        # encoder_class must be BertModel or RobertaModel
        # arguments specific to BertMultiTask can be passed either in config or in kwargs
        #self.pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
        encoder_class = kwargs["encoder"]
        # each argument can come from kwargs, with the matching config attribute as fallback
        tasks = kwargs.get("tasks", config.tasks)
        num_labels_per_task = kwargs.get("num_labels_per_task",
                                         config.num_labels_per_task)
        mask_id = kwargs.get("mask_id", config.mask_id)
        config.dropout_classifier = kwargs.get("dropout_classifier", 0.1)

        if "parsing" in tasks:
            config.graph_head_hidden_size_mlp_arc = kwargs.get(
                "graph_head_hidden_size_mlp_arc", 500)
            config.graph_head_hidden_size_mlp_rel = kwargs.get(
                "graph_head_hidden_size_mlp_rel", 200)

        if "AutoModel" in str(encoder_class):
            self.encoder = encoder_class.from_config(config)
        else:
            self.encoder = encoder_class(config)

        self.config = config
        assert isinstance(num_labels_per_task, dict)
        assert isinstance(tasks, list) and len(tasks) >= 1, \
            "config.tasks should be a list of len >= 1"
        self.head = nn.ModuleDict()
        self.mask_index_bert = mask_id
        self.tasks = tasks
        # all tasks available in the model, not only the ones used in a given run (self.tasks)
        self.tasks_available = tasks
        self.task_parameters = TASKS_PARAMETER
        self.layer_wise_attention = None
        self.labels_supported = [
            label for task in tasks
            for label in self.task_parameters[task]["label"]
        ]
        self.sanity_checking_num_labels_per_task(num_labels_per_task, tasks,
                                                 self.task_parameters)
        self.num_labels_dic = num_labels_per_task

        for task in TASKS_PARAMETER:
            if task in tasks:
                num_label = get_key_name_num_label(task, self.task_parameters)
                if not self.task_parameters[task]["num_labels_mandatory"]:
                    # in this case we need to define and load MLM head of the model
                    self.head[task] = eval(self.task_parameters[task]["head"])(
                        config
                    )  #, self.encoder.embeddings.word_embeddings.weight)
                else:
                    self.head[task] = eval(self.task_parameters[task]["head"])(
                        config, num_labels=self.num_labels_dic[num_label])
            else:
                # keep an empty entry so extra heads can be appended downstream
                self.head[task] = None
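
For orientation, here is a minimal usage sketch of this constructor; it is not from the source. The task name "pos", the num_labels_per_task key, and the mask id are hypothetical placeholders, and the real names depend on the project's TASKS_PARAMETER registry and get_key_name_num_label.

    from transformers import BertConfig, BertModel

    config = BertConfig()
    model = BertMultiTask(
        config,
        encoder=BertModel,  # "encoder" is the only mandatory kwarg; BertModel or RobertaModel
        tasks=["pos"],  # hypothetical task name
        num_labels_per_task={"num_labels_pos": 17},  # hypothetical key and label count
        mask_id=103,  # [MASK] token id in the bert-base-uncased vocabulary
    )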
Example #2
    def append_extra_heads_model(self, downstream_tasks, num_labels_dic_new):
        # extend the supported label set with the labels of the new downstream tasks
        self.labels_supported.extend([
            label for task in downstream_tasks
            for label in self.task_parameters[task]["label"]
        ])
        self.sanity_check_new_num_labels_per_task(
            num_labels_new=num_labels_dic_new,
            num_labels_original=self.num_labels_dic)
        self.num_labels_dic.update(num_labels_dic_new)
        for new_task in downstream_tasks:
            if new_task not in self.tasks:
                # build a head only for tasks that do not already have one
                num_label = get_key_name_num_label(new_task,
                                                   self.task_parameters)
                self.head[new_task] = eval(
                    self.task_parameters[new_task]["head"])(
                        self.config, num_labels=num_labels_dic_new[num_label])

        # update the task attributes
        self.tasks_available = list(set(self.tasks + downstream_tasks))
        # tasks to be used at prediction time (and possibly training)
        self.tasks = downstream_tasks
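
A hedged sketch of appending a new head to the model built above; the task name and the num_labels key are again hypothetical and would have to match the project's TASKS_PARAMETER and get_key_name_num_label.

    model.append_extra_heads_model(
        downstream_tasks=["parsing"],
        num_labels_dic_new={"num_labels_parsing": 42},  # hypothetical key and count
    )
    # afterwards model.tasks == ["parsing"], while model.tasks_available
    # keeps the union of the original and the new tasks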