Exemplo n.º 1
0
 def create_bindings(self):
     """Return the bindings used to merge `self.models_to_merge`.

     Returns:
       A new dict mapping binding names to values. The per-model entries
       ("checkpoints", "checkpoint_tasks", "additional_model_bindings",
       "tasks") are lists in the same order as `self.models_to_merge`;
       "checkpoint_tasks" and "tasks" have equal contents but are separate
       list objects.
     """
     # Entries are added one at a time so values are evaluated, and keys
     # inserted, in a fixed top-to-bottom order.
     bindings = {}

     # Merge components.
     bindings["mergeable_model"] = (
         diag_execs.diagonal_mergeable_model_from_checkpoint)
     bindings["model_merger"] = self.model_merger

     # Fisher lookup and weightings.
     bindings["checkpoint_to_fisher_matrix_uuid"] = (
         self.get_checkpoint_to_fisher_matrix_uuid())
     bindings["weightings"] = create_pairwise_weightings(self.num_weightings)

     # Per-model entries, one per element of `self.models_to_merge`.
     bindings["checkpoints"] = [
         spec.model_checkpoint_uuid for spec in self.models_to_merge
     ]
     bindings["checkpoint_tasks"] = [
         spec.task for spec in self.models_to_merge
     ]
     bindings["additional_model_bindings"] = [
         spec.additional_model_bindings for spec in self.models_to_merge
     ]
     bindings["tasks"] = [spec.task for spec in self.models_to_merge]

     bindings["pretrained_model"] = self.pretrained_model

     # Data settings; "num_examples" is taken from `self.validation_examples`.
     bindings["num_examples"] = self.validation_examples
     bindings["sequence_length"] = self.sequence_length
     bindings["batch_size"] = self.batch_size

     bindings["normalize_fishers"] = self.normalize_fishers
     return bindings
Exemplo n.º 2
0
    def create_bindings(self):
        """Return the bindings used to merge `self.models_to_merge`.

        Returns:
          A new dict mapping binding names to values. Weightings come from
          `create_pairwise_weightings(self.num_weightings,
          self.min_target_weighting)`; per-model entries ("checkpoints",
          "checkpoint_tasks", "tasks") are lists in the same order as
          `self.models_to_merge`; the input size binding is "image_size".
        """
        # Computed up front, before get_checkpoint_to_fisher_matrix_uuid().
        weightings = create_pairwise_weightings(
            self.num_weightings, self.min_target_weighting)

        merge_components = {
            "mergeable_model": (
                diag_execs.diagonal_mergeable_model_from_checkpoint_or_pretrained
            ),
            "model_merger": diag_execs.diagonal_model_merger,
        }
        fisher_settings = {
            "checkpoint_to_fisher_matrix_uuid":
                self.get_checkpoint_to_fisher_matrix_uuid(),
            "weightings": weightings,
        }
        # One entry per element of `self.models_to_merge`, in order.
        # "checkpoint_tasks" and "tasks" are distinct lists with equal contents.
        per_model = {
            "checkpoints": [
                spec.model_checkpoint_uuid for spec in self.models_to_merge
            ],
            "checkpoint_tasks": [spec.task for spec in self.models_to_merge],
            "tasks": [spec.task for spec in self.models_to_merge],
        }
        # "num_examples" is taken from `self.validation_examples`.
        run_settings = {
            "pretrained_model": self.pretrained_model,
            "num_examples": self.validation_examples,
            "image_size": self.image_size,
            "batch_size": self.batch_size,
            "normalize_fishers": self.normalize_fishers,
        }

        # Merging in this order reproduces the original key order.
        return {
            **merge_components,
            **fisher_settings,
            **per_model,
            **run_settings,
        }
Exemplo n.º 3
0
 def create_bindings(self):
     """Return the bindings used to merge `self.models_to_merge`.

     Returns:
       A new dict mapping binding names to values. Per-model entries
       ("checkpoints", "checkpoint_tasks", "tasks") are lists in the same
       order as `self.models_to_merge`. "hf_back_compat" is always False and
       "glue_label_map_overrides" is `defs.LABEL_MAP_OVERRIDES` (not copied).
     """
     # Entries are added one at a time so values are evaluated, and keys
     # inserted, in a fixed top-to-bottom order.
     bindings = {}

     # Merge components.
     bindings["mergeable_model"] = (
         diag_execs.diagonal_mergeable_model_from_checkpoint_or_pretrained)
     bindings["model_merger"] = self.model_merger

     # Fisher lookup and weightings.
     bindings["checkpoint_to_fisher_matrix_uuid"] = (
         self.get_checkpoint_to_fisher_matrix_uuid())
     bindings["weightings"] = create_pairwise_weightings(self.num_weightings)

     # Per-model entries; "checkpoint_tasks" and "tasks" are distinct lists
     # with equal contents.
     bindings["checkpoints"] = [
         spec.model_checkpoint_uuid for spec in self.models_to_merge
     ]
     bindings["checkpoint_tasks"] = [
         spec.task for spec in self.models_to_merge
     ]
     bindings["tasks"] = [spec.task for spec in self.models_to_merge]

     bindings["pretrained_model"] = self.pretrained_model

     # Data settings; "num_examples" is taken from `self.validation_examples`.
     bindings["num_examples"] = self.validation_examples
     bindings["sequence_length"] = self.sequence_length
     bindings["batch_size"] = self.batch_size

     bindings["normalize_fishers"] = self.normalize_fishers

     # Fixed settings.
     bindings["hf_back_compat"] = False
     bindings["glue_label_map_overrides"] = defs.LABEL_MAP_OVERRIDES
     return bindings
Exemplo n.º 4
0
 def create_bindings(self):
     """Return the bindings used to merge `self.models_to_merge`.

     Returns:
       A new dict mapping binding names to values. "task" and
       "pretrained_model" are read from the first element of
       `self.models_to_merge` (indexed with [0]), so that sequence is
       assumed to be non-empty. Per-model lists keep the order of
       `self.models_to_merge`.
     """
     merge_components = {
         "mergeable_model":
             diag_execs.diagonal_mergeable_model_from_checkpoint_or_pretrained,
         "model_merger": diag_execs.diagonal_model_merger,
     }
     fisher_settings = {
         "checkpoint_to_fisher_matrix_uuid":
             self.get_checkpoint_to_fisher_matrix_uuid(),
         "weightings": create_pairwise_weightings(self.num_weightings),
         "normalize_fishers": self.normalize_fishers,
     }
     # "checkpoint_tasks" and "tasks" are distinct lists with equal contents;
     # "task" is the first model's task only.
     per_model = {
         "checkpoints": [
             spec.model_checkpoint_uuid for spec in self.models_to_merge
         ],
         "checkpoint_tasks": [spec.task for spec in self.models_to_merge],
         "task": self.models_to_merge[0].task,
         "tasks": [spec.task for spec in self.models_to_merge],
     }
     # NOTE: model_checkpoint_uuid is actually the name of the pretrained_model.
     # This won't be the pretrained model for every merged model, but the
     # expectation is that this binding isn't really used.
     pretrained = {
         "pretrained_model": self.models_to_merge[0].model_checkpoint_uuid,
     }
     # "num_examples" is taken from `self.validation_examples`.
     data_settings = {
         "num_examples": self.validation_examples,
         "sequence_length": self.sequence_length,
         "batch_size": self.batch_size,
     }

     # Merging in this order reproduces the original key order.
     return {
         **merge_components,
         **fisher_settings,
         **per_model,
         **pretrained,
         **data_settings,
     }