def meta_eval(self, **kwargs):
    """Evaluate the meta-model on the test split, untrained ("-1") then "best".

    The model is produced by the method named ``"meta_" + args.meta_testing_method``.
    Records are loaded from the meta pickler/csver before evaluating and saved
    afterwards. Extra keyword arguments are forwarded unchanged to ``eval_model``.

    Only ``splits_to_eval == ["test"]`` is supported (enforced by an assert).
    """
    assert self.args.splits_to_eval == ["test"]
    getter = getattr(self, "meta_" + self.args.meta_testing_method)
    self.models = {}
    self.record_keeper = self.meta_record_keeper
    self.pickler_and_csver = self.meta_pickler_and_csver
    self.pickler_and_csver.load_records()
    self.hooks = logging_presets.HookContainer(
        self.record_keeper, record_group_name_prefix=getter.__name__
    )
    self.tester_obj = self.pytorch_getter.get(
        "tester", self.args.testing_method, self.get_tester_kwargs()
    )

    test_group = self.hooks.record_group_name(self.tester_obj, "test")
    existing = self.meta_record_keeper.get_record(test_group)
    # NOTE: the "best" model is logged at an iteration equal to one less than
    # the number of entries in the first recorded column (0 when the group has
    # no records yet). The original author flagged this as an unavoidable hack.
    if len(existing) > 0:
        best_iteration = len(next(iter(existing.values()))) - 1
    else:
        best_iteration = 0

    # Untrained model first (iteration -1), then the "best" model.
    for label, iteration in (("-1", -1), ("best", best_iteration)):
        self.models["trunk"], self.models["embedder"] = getter(label)
        self.set_transforms()
        self.eval_model(
            iteration,
            label,
            splits_to_eval=self.args.splits_to_eval,
            load_model=False,
            **kwargs
        )
    self.pickler_and_csver.save_records()
def meta_eval(self):
    """Evaluate the meta-model for each requested training status and log it.

    The getter named by ``self.get_curr_meta_testing_method()`` returns a
    (trunk, embedder) pair for each status: ``const.TRAINED`` always,
    preceded by ``const.UNTRAINED_TRUNK`` and
    ``const.UNTRAINED_TRUNK_AND_EMBEDDER`` when
    ``args.check_untrained_accuracy`` is set.

    After each evaluation that ``eval_model`` reports as not skipped, every
    record group for ``args.splits_to_eval`` is updated as follows:
      * the status and a timestamp are written at iteration
        ``c_f.try_getting_db_count(...) + 1``;
      * its ``best_epoch`` / ``best_accuracy`` entries are removed;
    then the records are saved.
    """
    getter = getattr(self, self.get_curr_meta_testing_method())
    self.models = {}
    self.record_keeper = self.meta_record_keeper
    self.hooks = logging_presets.HookContainer(
        self.record_keeper,
        record_group_name_prefix=getter.__name__,
        primary_metric=self.args.eval_primary_metric,
    )
    self.tester_obj = self.pytorch_getter.get(
        "tester", self.args.testing_method, self.get_tester_kwargs()
    )

    statuses = (
        [const.UNTRAINED_TRUNK, const.UNTRAINED_TRUNK_AND_EMBEDDER]
        if self.args.check_untrained_accuracy
        else []
    )
    statuses.append(const.TRAINED)

    group_names = []
    for split_name in self.args.splits_to_eval:
        name_dict = self.get_eval_record_name_dict(self.curr_meta_testing_method)
        group_names.append(name_dict[split_name])

    for status in statuses:
        self.models["trunk"], self.models["embedder"] = getter(status)
        ran_eval = self.eval_model(
            status,
            status,
            load_model=False,
            skip_eval_if_already_done=self.args.skip_meta_eval_if_already_done,
        )
        if not ran_eval:
            continue
        for group_name in group_names:
            next_iteration = c_f.try_getting_db_count(self.meta_record_keeper, group_name) + 1
            self.record_keeper.update_records(
                {const.TRAINED_STATUS_COL_NAME: status},
                global_iteration=next_iteration,
                input_group_name_for_non_objects=group_name,
            )
            self.record_keeper.update_records(
                {"timestamp": c_f.get_datetime()},
                global_iteration=next_iteration,
                input_group_name_for_non_objects=group_name,
            )
            group_records = self.record_keeper.record_writer.records[group_name]
            for irrelevant_key in ("best_epoch", "best_accuracy"):
                group_records.pop(irrelevant_key, None)
        self.record_keeper.save_records()
def meta_eval(self):
    """Evaluate the meta-model built by the configured meta-testing method.

    The model returned by ``self.meta_<args.meta_testing_method>`` is evaluated
    as ``"best"`` (iteration 1, trained). When ``args.check_untrained_accuracy``
    is set, it is also evaluated as ``"-1"`` (iteration -1, untrained).

    For every evaluation that ``eval_model`` reports as not skipped:
      * an ``is_trained`` flag and a timestamp are written to the test-split
        record group of that meta-testing method;
      * the ``epoch``, ``best_epoch`` and ``best_accuracy`` entries are
        removed from that group;
      * the records are saved.
    """
    meta_method_name = "meta_" + self.args.meta_testing_method
    meta_model_getter = getattr(self, meta_method_name)
    self.models = {}
    self.record_keeper = self.meta_record_keeper
    self.hooks = logging_presets.HookContainer(
        self.record_keeper,
        record_group_name_prefix=meta_model_getter.__name__,
        primary_metric=self.args.eval_primary_metric,
    )
    self.tester_obj = self.pytorch_getter.get(
        "tester", self.args.testing_method, self.get_tester_kwargs()
    )

    # Label -> iteration passed to eval_model; iteration 1 marks the trained model.
    eval_dict = {"best": 1}
    if self.args.check_untrained_accuracy:
        eval_dict["-1"] = -1

    # Write to the record group of the meta-testing method actually being run.
    # The original hard-coded "meta_ConcatenateEmbeddings" here, which
    # misfiled results for any other method.
    group_name = self.get_eval_record_name_dict(meta_method_name)["test"]
    len_of_existing_records = c_f.try_getting_db_count(self.meta_record_keeper, group_name)
    for name, i in eval_dict.items():
        self.models["trunk"], self.models["embedder"] = meta_model_getter(name)
        self.set_transforms()
        did_not_skip = self.eval_model(
            i,
            name,
            splits_to_eval=self.args.splits_to_eval,
            load_model=False,
            skip_eval_if_already_done=self.args.skip_meta_eval_if_already_done,
        )
        if not did_not_skip:
            continue
        is_trained = int(i == 1)
        # Untrained rows go to count+1 and trained rows to count+2, so the two
        # never collide even though "best" is evaluated first.
        global_iteration = len_of_existing_records + is_trained + 1
        self.record_keeper.update_records(
            {"is_trained": is_trained},
            global_iteration=global_iteration,
            input_group_name_for_non_objects=group_name,
        )
        self.record_keeper.update_records(
            {"timestamp": c_f.get_datetime()},
            global_iteration=global_iteration,
            input_group_name_for_non_objects=group_name,
        )
        # The entries live under records[group_name]. Popping from the
        # top-level records dict (keyed by group name) silently did nothing.
        group_records = self.record_keeper.record_writer.records[group_name]
        for irrelevant_key in ("epoch", "best_epoch", "best_accuracy"):
            group_records.pop(irrelevant_key, None)
        self.record_keeper.save_records()
def set_models_optimizers_losses(self):
    """Construct all training components and attach them to ``self``.

    Steps, in order:
      * call the model/sampler/loss/miner/optimizer/record-keeper setters;
      * create the hook container (validation split ``"val"``), the tester
        and the trainer;
      * when ``self.is_training()`` is true, store the result of
        ``maybe_load_latest_saved_models()`` as ``self.epoch``;
      * call ``set_dataparallel`` and ``set_devices``.
    """
    # Order matters: later setters may depend on attributes set by earlier ones.
    for setup_step in (
        "set_model",
        "set_sampler",
        "set_loss_function",
        "set_mining_function",
        "set_optimizers",
        "set_record_keeper",
    ):
        getattr(self, setup_step)()

    self.hooks = logging_presets.HookContainer(
        self.record_keeper,
        primary_metric=self.args.eval_primary_metric,
        validation_split_name="val",
    )
    self.tester_obj = self.pytorch_getter.get(
        "tester", self.args.testing_method, self.get_tester_kwargs()
    )
    self.trainer = self.pytorch_getter.get(
        "trainer", self.args.training_method, self.get_trainer_kwargs()
    )

    if self.is_training():
        self.epoch = self.maybe_load_latest_saved_models()

    for finalize_step in ("set_dataparallel", "set_devices"):
        getattr(self, finalize_step)()
def set_models_optimizers_losses(self):
    """Construct all components, then resume state from any saved checkpoint.

    After running the component setters in order, builds a hook container
    with ``end_of_epoch_test=False`` and the tester. It then sets
    ``self.epoch`` to one more than the value returned by
    ``maybe_load_models_and_records()``.
    """
    # Order matters: e.g. data-parallel wrapping happens before losses and
    # optimizers are set up.
    for component_setter in (
        "set_model",
        "set_transforms",
        "set_sampler",
        "set_dataparallel",
        "set_loss_function",
        "set_mining_function",
        "set_optimizers",
        "set_record_keeper",
    ):
        getattr(self, component_setter)()

    self.hooks = logging_presets.HookContainer(self.record_keeper, end_of_epoch_test=False)
    self.tester_obj = self.pytorch_getter.get(
        "tester", self.args.testing_method, self.get_tester_kwargs()
    )
    # presumably the epoch of the loaded checkpoint; training resumes at the next one
    loaded_epoch = self.maybe_load_models_and_records()
    self.epoch = loaded_epoch + 1
def get_eval_record_name_dict(self, eval_type=const.NON_META, return_all=False, return_base_record_group_name=False):
    """Return the record-group names under which evaluation results are stored.

    Missing (or falsy) collaborators are created lazily:
      * ``self.hooks``, as a HookContainer with no record keeper;
      * ``self.tester_obj``, together with ``self.split_manager`` if that is
        also missing.

    Args:
        eval_type: Which mapping to return. One of ``const.NON_META``,
            ``const.META_SEPARATE_EMBEDDINGS`` or
            ``const.META_CONCATENATE_EMBEDDINGS``.
        return_all: If True, return all three mappings in a dict keyed by
            eval type, ignoring ``eval_type``.
        return_base_record_group_name: If True, each mapping has the single
            key ``"base_record_group_name"`` instead of one key per split name.

    Returns:
        A dict mapping each split name (or ``"base_record_group_name"``) to a
        group name. Non-meta names are computed with the hooks' prefix cleared.
        Meta names are the non-meta names prefixed with
        ``"<meta constant>_"``. With ``return_all``, a dict of such dicts.

    Raises:
        KeyError: If ``eval_type`` is not one of the three keys above and
            ``return_all`` is False.
    """
    if not getattr(self, "hooks", None):
        self.hooks = logging_presets.HookContainer(
            None, primary_metric=self.args.eval_primary_metric)
    if not getattr(self, "tester_obj", None):
        if not getattr(self, "split_manager", None):
            self.split_manager = self.get_split_manager()
        self.tester_obj = self.pytorch_getter.get("tester", self.args.testing_method, self.get_tester_kwargs())

    # Compute un-prefixed names by temporarily clearing the hooks' prefix.
    # Restore it in `finally` so the shared HookContainer is never left
    # modified if name generation raises (e.g. split_manager is missing).
    prefix = self.hooks.record_group_name_prefix
    self.hooks.record_group_name_prefix = ""
    try:
        if return_base_record_group_name:
            non_meta = {
                "base_record_group_name": self.hooks.base_record_group_name(self.tester_obj)
            }
        else:
            non_meta = {
                k: self.hooks.record_group_name(self.tester_obj, k)
                for k in self.split_manager.split_names
            }
    finally:
        self.hooks.record_group_name_prefix = prefix

    meta_separate = {
        k: "{}_{}".format(const.META_SEPARATE_EMBEDDINGS, v)
        for k, v in non_meta.items()
    }
    meta_concatenate = {
        k: "{}_{}".format(const.META_CONCATENATE_EMBEDDINGS, v)
        for k, v in non_meta.items()
    }
    name_dict = {
        const.NON_META: non_meta,
        const.META_SEPARATE_EMBEDDINGS: meta_separate,
        const.META_CONCATENATE_EMBEDDINGS: meta_concatenate
    }
    if return_all:
        return name_dict
    return name_dict[eval_type]