for task in LOW_RESOURCE_TASKS ], fixed_params={ "batch_size": 2, "sequence_length": 256, # "y_samples": 8, }, key_fields={ "pretrained_model", "task", "num_examples", "y_samples", }, bindings=[ scopes.ArgNameBindingSpec("fisher_type", "diagonal"), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", target_tasks.mlm_dataset), # scopes.ArgNameBindingSpec("compiled_model", mlm_execs.roberta_mlm_model), # scopes.ArgNameBindingSpec("tokenizer", bert_common.bert_tokenizer), scopes.ArgNameBindingSpec("initializer", mlm_execs.roberta_initializer), scopes.ArgNameBindingSpec("loader", mlm_execs.roberta_loader), scopes.ArgNameBindingSpec("builder", mlm_execs.roberta_builder), ], ) class FisherComputation_MlmTargetTask(ExperimentAbc): pass
"validation_examples": 4096, # NOTE: I should find a way to specific these cleaner. "examples_per_epoch": 128, "num_epochs": 4, # "learning_rate": 3e-5, }, key_fields={"pretrained_model", "task", "reg_strength", "reg_type"}, # # # ### Need to add these bindings: # bindings=[ # scopes.ArgNameBindingSpec("with_validation", False), scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", glue.glue_finetuning_dataset), scopes.ArgNameBindingSpec("compiled_model", gc_exe.bert_finetuning_model), scopes.ArgNameBindingSpec("optimizer", optimizers.adam_optimizer), scopes.ArgNameBindingSpec("callbacks", checkpoints.checkpoint_saver_callback), ], ) class FinetuneGlueIsoExperiment_Tiny(object): def create_run_instance_config(self, params): return runs.RunInstanceConfig( global_binding_specs=params.create_binding_specs()) ###############################################################################
"reg_type": "iso", "reg_strength": 3e-4, # "batch_size": 16, "learning_rate": 1e-5, "sequence_length": 64, # "num_ckpt_saves": 10, }, "key_fields": { "trial_index", "task", "pretrained_model", }, "bindings": [ scopes.ArgNameBindingSpec("with_validation", False), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", glue.glue_finetuning_dataset), # scopes.ArgNameBindingSpec("callbacks", ckpt_exec.checkpoint_saver_callback), scopes.ArgNameBindingSpec("compiled_model", gc_exe.bert_finetuning_model), # scopes.ArgNameBindingSpec("optimizer", optimizers.adam_optimizer), # scopes.ArgNameBindingSpec("pretrained_body_only", True), ], }
varying_params=functools.partial( create_varying_params, train_exp=Finetune_Dapt_LowResource_FOR_REAL, task_to_example_count=TRAIN_EXAMPLES, ), fixed_params={ "batch_size": 2, "sequence_length": 256, # "y_samples": None, }, key_fields={ "finetuned_ckpt_uuid", }, bindings=[ scopes.ArgNameBindingSpec("fisher_type", "diagonal"), scopes.ArgNameBindingSpec("y_samples", None), # scopes.ArgNameBindingSpec("fisher_class_chunk_size", 4), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", target_tasks.finetuning_dataset), # scopes.ArgNameBindingSpec("hf_back_compat", False), scopes.ArgNameBindingSpec("pretrained_body_only", True), scopes.ArgNameBindingSpec("use_roberta_head", True), ], ) class FisherComputation_Dapt_TargetTask_FOR_REAL(ExperimentAbc): pass
), fixed_params={ "num_weightings": 101, # "validation_examples": 2048, "image_size": simclr.IMAGE_SIZE, "batch_size": 256, # "normalize_fishers": True, }, key_fields={ "trial_index", "models_to_merge", }, bindings=[ scopes.ArgNameBindingSpec("fisher_type", "diagonal"), scopes.ArgNameBindingSpec("multitask_merge", False), # scopes.ArgNameBindingSpec("split", "validation"), scopes.ArgNameBindingSpec("shuffle", False), scopes.ArgNameBindingSpec("repeat", False), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec( "dataset", image_classification.simclr_finetuning_dataset), # scopes.ArgNameBindingSpec("evaluate_model", eval_execs.robust_evaluate_model), scopes.ArgNameBindingSpec( "robust_evaluate_dataset", image_classification.robust_evaluation_dataset),
def create_binding_specs(self):
    """Turn the mapping from ``self.create_bindings()`` into binding specs.

    Returns:
      A list holding one ``scopes.ArgNameBindingSpec(name, value)`` for each
      ``(name, value)`` pair of the mapping, in the mapping's iteration order.
    """
    name_to_value = self.create_bindings()
    return [
        scopes.ArgNameBindingSpec(arg_name, value)
        for arg_name, value in name_to_value.items()
    ]
def create_binding_specs(self):
    """Create the binding specs for a direct Fisher computation run.

    Returns:
      The run-level bindings (pretrained model, single-task list, fine-tuned
      run/checkpoint uuids, sequence length, batch size) followed by the
      Fisher-specific bindings.

    Raises:
      ValueError: If ``self.fisher_type`` is anything other than
        ``"diagonal"``, the only Fisher type handled here.
    """
    if self.fisher_type != "diagonal":
        raise ValueError(f"Invalid fisher_type {self.fisher_type}.")

    fisher_bindings = [
        scopes.ArgNameBindingSpec(
            "compiled_fisher_computer", diagonal_execs.diagonal_fisher_computer
        ),
        scopes.ArgNameBindingSpec("fisher_type", self.fisher_type),
        # NOTE: bound globally for now; with more scopes available this
        # should be bound closer to where it is actually consumed.
        scopes.ArgNameBindingSpec("y_samples", self.diagonal_y_samples),
    ]
    # The chunk size is only bound when it is not None.
    if self.fisher_class_chunk_size is not None:
        chunk_spec = scopes.ArgNameBindingSpec(
            "fisher_class_chunk_size", self.fisher_class_chunk_size
        )
        fisher_bindings.append(chunk_spec)

    # `finetuned_exp_uuid` and `num_examples` are intentionally not bound.
    run_bindings = [
        scopes.ArgNameBindingSpec("pretrained_model", self.pretrained_model),
        scopes.ArgNameBindingSpec("tasks", [self.task]),
        scopes.ArgNameBindingSpec("finetuned_run_uuid", self.finetuned_run_uuid),
        scopes.ArgNameBindingSpec("finetuned_ckpt_uuid", self.finetuned_ckpt_uuid),
        scopes.ArgNameBindingSpec("sequence_length", self.sequence_length),
        scopes.ArgNameBindingSpec("batch_size", self.batch_size),
    ]
    return run_bindings + fisher_bindings
def create_binding_specs(self):
    """Create the binding specs for a variational-diagonal Fisher run.

    ``steps_per_epoch`` is ``examples_per_epoch // batch_size`` (floor
    division, so any remainder examples are not counted as a step).

    Returns:
      The run-level bindings (pretrained model, single-task list, fine-tuned
      run/checkpoint uuids, sequence length, batch size, steps per epoch)
      followed by the Fisher-specific bindings.

    Raises:
      ZeroDivisionError: If ``self.batch_size`` is 0 (raised before the
        Fisher type is checked).
      ValueError: If ``self.fisher_type`` is not ``"variational_diagonal"``.
    """
    steps_per_epoch = self.examples_per_epoch // self.batch_size

    if self.fisher_type != "variational_diagonal":
        raise ValueError(f"Invalid fisher_type {self.fisher_type}.")

    fisher_bindings = [
        scopes.ArgNameBindingSpec("fisher_type", self.fisher_type),
        scopes.ArgNameBindingSpec(
            "compiled_fisher_computer",
            vardiag_exes.variational_diag_fisher_computer,
        ),
        scopes.ArgNameBindingSpec(
            "variational_fisher_beta", self.variational_fisher_beta
        ),
        scopes.ArgNameBindingSpec(
            "save_fisher_at_each_epoch", self.save_fisher_at_each_epoch
        ),
    ]

    # `finetuned_exp_uuid`, `num_examples`, `learning_rate` and `epochs` are
    # intentionally not bound here.
    run_bindings = [
        scopes.ArgNameBindingSpec("pretrained_model", self.pretrained_model),
        scopes.ArgNameBindingSpec("tasks", [self.task]),
        scopes.ArgNameBindingSpec("finetuned_run_uuid", self.finetuned_run_uuid),
        scopes.ArgNameBindingSpec("finetuned_ckpt_uuid", self.finetuned_ckpt_uuid),
        scopes.ArgNameBindingSpec("sequence_length", self.sequence_length),
        scopes.ArgNameBindingSpec("batch_size", self.batch_size),
        scopes.ArgNameBindingSpec("steps_per_epoch", steps_per_epoch),
    ]
    return run_bindings + fisher_bindings
"reg_strength": 3e-4, # "batch_size": 8, "learning_rate": 1e-5, "sequence_length": 256, # "num_task_epochs": LOW_RESOURCE_EPOCHS, "num_ckpt_saves": LOW_RESOURCE_EPOCHS, }, key_fields={ "trial_index", "task", "checkpoint", }, bindings=[ scopes.ArgNameBindingSpec("with_validation", False), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", target_tasks.finetuning_dataset), # scopes.ArgNameBindingSpec("callbacks", ckpt_exec.checkpoint_saver_callback), scopes.ArgNameBindingSpec("compiled_model", gc_exe.bert_finetuning_model), # scopes.ArgNameBindingSpec("optimizer", optimizers.adam_optimizer), # scopes.ArgNameBindingSpec("hf_back_compat", False), scopes.ArgNameBindingSpec("pretrained_body_only", True), scopes.ArgNameBindingSpec("use_roberta_head", True), #
} for task in TASKS for reg_str in REG_STRENGTHS], fixed_params={ "pretrained_model": "r50_1x", "batch_size": 32, "reg_type": "iso", "image_size": simclr.IMAGE_SIZE, "train_examples": None, "validation_examples": 4096, "train_steps": 80_000, "steps_per_epoch": 10_000, "learning_rate": 1e-3, }, key_fields={"pretrained_model", "task", "reg_strength", "reg_type"}, bindings=[ # For some reason, validation can cause the program to hang indefinitely. scopes.ArgNameBindingSpec("with_validation", False), scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec( "dataset", image_classification.simclr_finetuning_dataset), scopes.ArgNameBindingSpec("compiled_model", sc_exe.simclr_finetuning_model), scopes.ArgNameBindingSpec("optimizer", optimizers.adam_optimizer), scopes.ArgNameBindingSpec("callbacks", ckpt_exec.checkpoint_saver_callback), ], ) class FinetuneSimclrIso_r50_1x(object): def create_run_instance_config(self, params): return runs.RunInstanceConfig( global_binding_specs=params.create_binding_specs())
"examples_per_epoch": 4096, # "save_fisher_at_each_epoch": True, }, key_fields={ "finetuned_run_uuid", "finetuned_ckpt_uuid", # "fisher_type", # "variational_fisher_beta", "learning_rate", "num_examples", }, bindings=[ scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", glue.glue_finetuning_dataset), scopes.ArgNameBindingSpec("optimizer", optimizers.adam_optimizer), ], ) class RteBestCkpt_Iso_0003_PhaseI(object): def create_run_instance_config(self, params): return runs.RunInstanceConfig( global_binding_specs=params.create_binding_specs()) def create_preload_blob_uuids(self, params): return params.create_preload_blob_uuids() @experiment.experiment( uuid="1db11bdaa6ce4ee7b8ddeb8e1829da0d",
varying_params=functools.partial( create_varying_params, train_exp=Finetune_ROBERTA_LowResource, task_to_example_count=TRAIN_EXAMPLES, ), fixed_params={ "batch_size": 2, "sequence_length": 256, # "y_samples": None, }, key_fields={ "finetuned_ckpt_uuid", }, bindings=[ scopes.ArgNameBindingSpec("fisher_type", "diagonal"), scopes.ArgNameBindingSpec("y_samples", None), # scopes.ArgNameBindingSpec("fisher_class_chunk_size", 4), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", target_tasks.finetuning_dataset), # scopes.ArgNameBindingSpec("hf_back_compat", False), scopes.ArgNameBindingSpec("pretrained_body_only", True), scopes.ArgNameBindingSpec("use_roberta_head", True), ], ) class FisherComputation_ROBERTA_TargetTasks(ExperimentAbc): pass
params_cls=DirectFisherParams, executable_cls=fisher_execs.fisher_computation, varying_params=functools.partial( create_varying_params, train_exp=Finetune_Rte, task_to_example_count=TASK_TO_EXAMPLE_COUNT, ), fixed_params={ "batch_size": 4, "sequence_length": 64, }, key_fields={ "finetuned_ckpt_uuid", }, bindings=[ scopes.ArgNameBindingSpec("fisher_type", "diagonal"), scopes.ArgNameBindingSpec("y_samples", None), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", glue.glue_finetuning_dataset), # scopes.ArgNameBindingSpec("hf_back_compat", False), scopes.ArgNameBindingSpec("pretrained_body_only", True), ], ) class Fisher_Rte(ExperimentAbc): pass @experiment.experiment( uuid="84092b50c15a4bad8b2cafb1f43e6524",
"reg_type": "iso", "reg_strength": 3e-4, # "batch_size": 8, "learning_rate": 1e-5, "sequence_length": 64, # "num_ckpt_saves": 10, }, key_fields={ "trial_index", "task", "pretrained_model", }, bindings=[ scopes.ArgNameBindingSpec("with_validation", False), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", glue.glue_finetuning_dataset), # scopes.ArgNameBindingSpec("callbacks", ckpt_exec.checkpoint_saver_callback), scopes.ArgNameBindingSpec("compiled_model", gc_exe.bert_finetuning_model), # scopes.ArgNameBindingSpec("optimizer", optimizers.adam_optimizer), # scopes.ArgNameBindingSpec("hf_back_compat", False), scopes.ArgNameBindingSpec("pretrained_body_only", True), scopes.ArgNameBindingSpec("glue_label_map_overrides", defs.LABEL_MAP_OVERRIDES),
varying_params=functools.partial( create_varying_eval_params, train_exp=GlueFinetune_BertBase_LrSrc_Sft, ), fixed_params={ "sequence_length": 64, "num_examples": 2048, "batch_size": 512, }, key_fields={ "trial_index", "task", "checkpoints_summary", }, bindings=[ scopes.ArgNameBindingSpec("split", "validation"), scopes.ArgNameBindingSpec("shuffle", False), scopes.ArgNameBindingSpec("repeat", False), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", glue.glue_finetuning_dataset), # scopes.ArgNameBindingSpec("evaluate_model", eval_execs.robust_evaluate_model), scopes.ArgNameBindingSpec( "robust_evaluate_dataset", glue.glue_robust_evaluation_dataset ), scopes.ArgNameBindingSpec("metrics_for_tasks", metrics_exe.glue_robust_metrics), scopes.ArgNameBindingSpec("cache_validation_batches_as_lists", True), # scopes.ArgNameBindingSpec("compiled_model", gc_exe.bert_finetuning_model), #
task, "pretrained_model": defs.TASK_TO_CKPT_BERT_BASE[task], "num_examples": min(NUM_GLUE_TRAIN_EXAMPLES[task], MAX_EXAMPLES), } for task in defs.HIGH_RESOURCE_TASKS], fixed_params={ "batch_size": 4, "sequence_length": 64, }, key_fields={ "task", "pretrained_model", }, bindings=[ scopes.ArgNameBindingSpec("fisher_type", "diagonal"), scopes.ArgNameBindingSpec("y_samples", None), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", glue.glue_finetuning_dataset), # scopes.ArgNameBindingSpec("fisher_class_chunk_size", 3), # scopes.ArgNameBindingSpec("loader", gc_exe.bert_loader), # scopes.ArgNameBindingSpec("pretrained_body_only", False), ], ) class FisherComputation_BertBase_HighResource(ExperimentAbc): pass
), fixed_params={ "num_weightings": 76, # "validation_examples": 2048, "sequence_length": 256, "batch_size": 128, # "normalize_fishers": True, }, key_fields={ "trial_index", "models_to_merge", }, bindings=[ scopes.ArgNameBindingSpec("fisher_type", "diagonal"), # scopes.ArgNameBindingSpec("split", "test"), scopes.ArgNameBindingSpec("shuffle", False), scopes.ArgNameBindingSpec("repeat", False), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", target_tasks.finetuning_dataset), # scopes.ArgNameBindingSpec("evaluate_model", eval_execs.robust_evaluate_model), scopes.ArgNameBindingSpec("robust_evaluate_dataset", target_tasks.robust_evaluation_dataset), scopes.ArgNameBindingSpec("metrics_for_tasks", metrics_exe.glue_robust_metrics), scopes.ArgNameBindingSpec("cache_validation_batches_as_lists", True),
params_cls=FisherParams, executable_cls=fisher_execs.fisher_computation, varying_params=functools.partial( create_varying_params_last_ckpt, train_exp=Finetune_Subsets, max_examples=MAX_EXAMPLES, ), fixed_params={ "batch_size": 4, "image_size": simclr.IMAGE_SIZE, }, key_fields={ "finetuned_ckpt_uuid", }, bindings=[ scopes.ArgNameBindingSpec("fisher_type", "diagonal"), scopes.ArgNameBindingSpec("y_samples", None), scopes.ArgNameBindingSpec("fisher_class_chunk_size", 4), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec( "dataset", image_classification.simclr_finetuning_dataset ), # scopes.ArgNameBindingSpec("initializer", sc_exe.simclr_initializer), scopes.ArgNameBindingSpec("loader", sc_exe.simclr_loader), scopes.ArgNameBindingSpec("builder", sc_exe.simclr_builder), # scopes.ArgNameBindingSpec("pretrained_body_only", True), ], )
"num_examples": 4096, "fisher_type": "diagonal", # Compute y expectation exactly. "diagonal_y_samples": None, "sequence_length": 64, "fisher_class_chunk_size": 4, }, key_fields={ "finetuned_run_uuid", "finetuned_ckpt_uuid", "fisher_type", "diagonal_y_samples", "num_examples", }, bindings=[ scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", glue.glue_finetuning_dataset), ], ) class GlueRegs_Fisher_BestCkpt(object): def create_run_instance_config(self, params): return runs.RunInstanceConfig( global_binding_specs=params.create_binding_specs()) @experiment.experiment( uuid="3c7279a4c70e4c5a91372dd28870fd84", group=BertMergingPrelimsGroup, params_cls=FisherParams, executable_cls=fisher_execs.fisher_computation, varying_params=functools.partial(
varying_params=[{ "task": task, "pretrained_model": TASK_TO_CKPT[task], "num_examples": min(NUM_GLUE_TRAIN_EXAMPLES[task], MAX_MNLI_RTE_FISHER_EXAMPLES), } for task in ["rte", "mnli"]], fixed_params={ "batch_size": 4, "sequence_length": 64, }, key_fields={ "pretrained_model", "num_examples", }, bindings=[ scopes.ArgNameBindingSpec("fisher_type", "diagonal"), scopes.ArgNameBindingSpec("y_samples", None), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", glue.glue_finetuning_dataset), # scopes.ArgNameBindingSpec("fisher_class_chunk_size", 3), # scopes.ArgNameBindingSpec("loader", gc_exe.bert_loader), ], ) class FisherComputation_Base_MnliRte(ExperimentAbc): pass
def create_binding_specs(self):
    """Create the binding specs for merging the models in ``models_to_merge``.

    Returns:
      Bindings for the checkpoint -> Fisher-matrix uuid mapping, the pair
      weightings, the per-model checkpoints and tasks, the task list,
      sequence length and batch size, followed by the Fisher-type-specific
      merge bindings.

    Raises:
      ValueError: If ``self.fisher_type`` is not ``"diagonal"``.
    """
    if self.fisher_type != "diagonal":
        raise ValueError(f"Invalid fisher_type {self.fisher_type}.")

    fisher_bindings = [
        scopes.ArgNameBindingSpec(
            "mergeable_model",
            diagonal_execs.diagonal_mergeable_model_from_checkpoint,
        ),
        scopes.ArgNameBindingSpec(
            "model_merger", diagonal_execs.diagonal_model_merger
        ),
    ]

    # Computed in the same order the bindings are listed below.
    checkpoint_to_fisher = self.get_checkpoint_to_fisher_matrix_uuid()
    weightings = create_pair_weightings(self.num_weightings)
    checkpoint_uuids = [
        model.model_checkpoint_uuid for model in self.models_to_merge
    ]
    checkpoint_tasks = [model.task for model in self.models_to_merge]
    merged_tasks = [model.task for model in self.models_to_merge]

    # `task`, `pretrained_model` and `num_examples` are intentionally not
    # bound here.
    merge_bindings = [
        scopes.ArgNameBindingSpec(
            "checkpoint_to_fisher_matrix_uuid", checkpoint_to_fisher
        ),
        scopes.ArgNameBindingSpec("weightings", weightings),
        scopes.ArgNameBindingSpec("checkpoints", checkpoint_uuids),
        scopes.ArgNameBindingSpec("checkpoint_tasks", checkpoint_tasks),
        scopes.ArgNameBindingSpec("tasks", merged_tasks),
        scopes.ArgNameBindingSpec("sequence_length", self.sequence_length),
        scopes.ArgNameBindingSpec("batch_size", self.batch_size),
    ]
    return merge_bindings + fisher_bindings
"pretrained_model": "bert-base-uncased", # "num_weightings": 51, # "validation_examples": 2048, "sequence_length": 64, "batch_size": 512, # "normalize_fishers": True, }, key_fields={ "models_to_merge", "normalize_fishers", }, bindings=[ scopes.ArgNameBindingSpec("fisher_type", "diagonal"), # scopes.ArgNameBindingSpec("split", "validation"), scopes.ArgNameBindingSpec("shuffle", False), scopes.ArgNameBindingSpec("repeat", False), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", glue.glue_finetuning_dataset), # scopes.ArgNameBindingSpec("evaluate_model", eval_execs.robust_evaluate_model), scopes.ArgNameBindingSpec("robust_evaluate_dataset", glue.glue_robust_evaluation_dataset), scopes.ArgNameBindingSpec("metrics_for_tasks", metrics_exe.glue_robust_metrics), scopes.ArgNameBindingSpec("cache_validation_batches_as_lists", True),
} @experiment.experiment( uuid="acf06e101c2146168f251cb868d2c4e3", group=PaperExpGroup, params_cls=DirectFisherParams, executable_cls=fisher_execs.fisher_computation, varying_params=functools.partial( create_varying_params, train_exp=GlueFinetune, task_to_example_count=TASK_TO_EXAMPLE_COUNT, ), fixed_params={ "batch_size": 4, "sequence_length": 64, }, key_fields={ "finetuned_ckpt_uuid", }, bindings=[ scopes.ArgNameBindingSpec("fisher_type", "diagonal"), scopes.ArgNameBindingSpec("y_samples", None), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", glue.glue_finetuning_dataset), ], ) class FisherComputation(ExperimentAbc): pass
"reg_strength": 3e-4, # "batch_size": 8, "learning_rate": 1e-5, "sequence_length": 256, # "num_task_epochs": LOW_RESOURCE_EPOCHS, "num_ckpt_saves": LOW_RESOURCE_EPOCHS, }, key_fields={ "trial_index", "task", "pretrained_model", }, bindings=[ scopes.ArgNameBindingSpec("with_validation", False), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", target_tasks.finetuning_dataset), # scopes.ArgNameBindingSpec("callbacks", ckpt_exec.checkpoint_saver_callback), scopes.ArgNameBindingSpec("compiled_model", gc_exe.bert_finetuning_model), # scopes.ArgNameBindingSpec("optimizer", optimizers.adam_optimizer), ], ) class Finetune_LowResource(ExperimentAbc): pass
def get_checkpoint_to_fisher_matrix_uuid(self): key = (lambda m: m.model_checkpoint_uuid if m.model_checkpoint_uuid else self.pretrained_mlm_model) return {key(m): m.fisher_matrix_uuid for m in self.models_to_merge} def create_preload_blob_uuids(self): dikt = self.get_checkpoint_to_fisher_matrix_uuid() return tuple((set(dikt.keys()) | set(dikt.values())) - {self.pretrained_mlm_model}) ############################################################################### COMMON_BINDINGS = [ scopes.ArgNameBindingSpec("fisher_type", "diagonal"), # scopes.ArgNameBindingSpec("split", "validation"), scopes.ArgNameBindingSpec("shuffle", False), scopes.ArgNameBindingSpec("repeat", False), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", target_tasks.finetuning_dataset), # scopes.ArgNameBindingSpec("evaluate_model", eval_execs.robust_evaluate_model), scopes.ArgNameBindingSpec("robust_evaluate_dataset", target_tasks.robust_evaluation_dataset), scopes.ArgNameBindingSpec("metrics_for_tasks", metrics_exe.glue_robust_metrics), scopes.ArgNameBindingSpec("cache_validation_batches_as_lists", True),
), fixed_params={ "batch_size": 1, "num_examples": 4096, "fisher_type": "diagonal", # Compute y expectation exactly. "diagonal_y_samples": None, "image_size": simclr.IMAGE_SIZE, }, key_fields={ "finetuned_run_uuid", "finetuned_ckpt_uuid", "fisher_type", "diagonal_y_samples", "num_examples", }, bindings=[ scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec( "dataset", image_classification.simclr_finetuning_dataset), scopes.ArgNameBindingSpec("initializer", sc_exe.simclr_initializer), scopes.ArgNameBindingSpec("builder", sc_exe.simclr_builder), scopes.ArgNameBindingSpec( "metrics", model_execs.multitask_classification_metrics), ], ) class SimclrFisherIso__r50_1x__ckpt_20k(object): def create_run_instance_config(self, params): return runs.RunInstanceConfig( global_binding_specs=params.create_binding_specs())
def create_binding_specs(self):
    """Create the binding specs for an iso-regularized fine-tuning run.

    ``steps_per_epoch`` is ``examples_per_epoch / batch_size`` passed through
    the built-in ``round`` (exact halves round to the even integer).

    Returns:
      The training bindings (pretrained model, train/validation example
      counts, batch size, single-task list, steps per epoch, epochs,
      sequence length, learning rate) followed by the regularizer bindings.

    Raises:
      ZeroDivisionError: If ``self.batch_size`` is 0 (raised before the
        regularizer type is checked).
      ValueError: If ``self.reg_type`` is not ``"iso"``.
    """
    steps_per_epoch = round(self.examples_per_epoch / self.batch_size)

    if self.reg_type != "iso":
        raise ValueError(f"Invalid reg_type {self.reg_type}.")

    reg_bindings = [
        scopes.ArgNameBindingSpec(
            "regularizer", gc_exe.regularize_body_l2_from_initial
        ),
        scopes.ArgNameBindingSpec("reg_strength", self.reg_strength),
    ]

    train_bindings = [
        scopes.ArgNameBindingSpec("pretrained_model", self.pretrained_model),
        scopes.ArgNameBindingSpec("train_num_examples", self.train_examples),
        scopes.ArgNameBindingSpec(
            "validation_num_examples", self.validation_examples
        ),
        scopes.ArgNameBindingSpec("batch_size", self.batch_size),
        scopes.ArgNameBindingSpec("tasks", [self.task]),
        scopes.ArgNameBindingSpec("steps_per_epoch", steps_per_epoch),
        scopes.ArgNameBindingSpec("epochs", self.num_epochs),
        scopes.ArgNameBindingSpec("sequence_length", self.sequence_length),
        scopes.ArgNameBindingSpec("learning_rate", self.learning_rate),
    ]
    return train_bindings + reg_bindings
"sequence_length": 256, # "batch_size": 2, "y_samples": 1, }, key_fields={ "trial_index", # "checkpoint", "task", # "num_examples", "y_samples", }, bindings=[ scopes.ArgNameBindingSpec("fisher_type", "diagonal"), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", s2orc.mlm_dataset), # scopes.ArgNameBindingSpec("compiled_model", mlm_execs.roberta_mlm_model), # scopes.ArgNameBindingSpec("tokenizer", bert_common.bert_tokenizer), scopes.ArgNameBindingSpec("initializer", mlm_execs.roberta_initializer), scopes.ArgNameBindingSpec("loader", mlm_execs.roberta_loader), scopes.ArgNameBindingSpec("builder", mlm_execs.roberta_builder), # scopes.ArgNameBindingSpec("hf_back_compat", False), scopes.ArgNameBindingSpec("pretrained_body_only", False),
"learning_rate": 1e-5, "sequence_length": 256, # "num_ckpt_saves": 1, }, key_fields={ "trial_index", # "pretrained_model", "task", # "num_examples", "reg_strength", }, bindings=[ scopes.ArgNameBindingSpec("with_validation", False), # scopes.ArgNameBindingSpec("fisher_type", "diagonal"), # scopes.ArgNameBindingSpec("tfds_dataset", tfds_execs.gcp_tfds_dataset), scopes.ArgNameBindingSpec("dataset", s2orc.mlm_dataset), # scopes.ArgNameBindingSpec("compiled_model", mlm_execs.roberta_mlm_model), # scopes.ArgNameBindingSpec("tokenizer", bert_common.bert_tokenizer), scopes.ArgNameBindingSpec("initializer", mlm_execs.roberta_initializer), scopes.ArgNameBindingSpec("loader", mlm_execs.roberta_loader), scopes.ArgNameBindingSpec("builder", mlm_execs.roberta_builder), #
"fisher_type": "diagonal", "diagonal_y_samples": 8, # "sequence_length": 256, "batch_size": 1, }, key_fields={ "pretrained_model", # "dataset", "num_examples", # "fisher_type", "diagonal_y_samples", }, bindings=[ scopes.ArgNameBindingSpec("initializer", mlm_exe.bert_initializer), scopes.ArgNameBindingSpec("loader", mlm_exe.bert_loader), scopes.ArgNameBindingSpec("builder", mlm_exe.bert_builder), scopes.ArgNameBindingSpec("metrics", mlm_exe.bert_mlm_metrics), # scopes.ArgNameBindingSpec("split", "train"), scopes.ArgNameBindingSpec("shuffle", True), scopes.ArgNameBindingSpec("repeat", False), ], ) class MlmFisher_Bert(object): def create_run_instance_config(self, params): return runs.RunInstanceConfig( global_binding_specs=params.create_binding_specs())