Example #1
 def meta_eval(self, **kwargs):
     assert self.args.splits_to_eval == ["test"]
     meta_model_getter = getattr(self,
                                 "meta_" + self.args.meta_testing_method)
     self.models = {}
     self.record_keeper = self.meta_record_keeper
     self.pickler_and_csver = self.meta_pickler_and_csver
     self.pickler_and_csver.load_records()
     self.hooks = logging_presets.HookContainer(
         self.record_keeper,
         record_group_name_prefix=meta_model_getter.__name__)
     self.tester_obj = self.pytorch_getter.get("tester",
                                               self.args.testing_method,
                                               self.get_tester_kwargs())
     group_name = self.hooks.record_group_name(self.tester_obj, "test")
     curr_records = self.meta_record_keeper.get_record(group_name)
     # derive the latest iteration index from the number of existing records
     iteration = len(list(curr_records.values())[0]) - 1 if len(curr_records) > 0 else 0
     for name, i in {"-1": -1, "best": iteration}.items():
         self.models["trunk"], self.models["embedder"] = meta_model_getter(
             name)
         self.set_transforms()
         self.eval_model(i,
                         name,
                         splits_to_eval=self.args.splits_to_eval,
                         load_model=False,
                         **kwargs)
     self.pickler_and_csver.save_records()
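
All of the examples on this page appear to come from the powerful-benchmarker project and build on the logging_presets module of pytorch-metric-learning. As a minimal, self-contained sketch of the pattern they rely on (the folder names and the primary_metric value are placeholders, not taken from the examples), a record keeper, hook container, and tester are typically wired together like this:

from pytorch_metric_learning import testers
from pytorch_metric_learning.utils import logging_presets

# the record keeper writes logs to the given CSV/SQLite and TensorBoard folders
record_keeper, _, _ = logging_presets.get_record_keeper(
    "example_logs", "example_tensorboard")

# HookContainer bundles the end-of-iteration / end-of-epoch / end-of-testing
# hooks that the examples above pass around as self.hooks
hooks = logging_presets.HookContainer(record_keeper,
                                      primary_metric="precision_at_1")

# a tester that reports its accuracies back through the hook container
tester = testers.GlobalEmbeddingSpaceTester(
    end_of_testing_hook=hooks.end_of_testing_hook)

In the examples themselves, self.pytorch_getter.get("tester", self.args.testing_method, self.get_tester_kwargs()) builds an equivalent tester from the benchmarker's configuration rather than constructing it directly.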
Example #2
    def meta_eval(self):
        meta_model_getter = getattr(self, self.get_curr_meta_testing_method())
        self.models = {}
        self.record_keeper = self.meta_record_keeper
        self.hooks = logging_presets.HookContainer(self.record_keeper, record_group_name_prefix=meta_model_getter.__name__, primary_metric=self.args.eval_primary_metric)
        self.tester_obj = self.pytorch_getter.get("tester", 
                                                self.args.testing_method, 
                                                self.get_tester_kwargs())

        models_to_eval = []
        if self.args.check_untrained_accuracy: 
            models_to_eval.append(const.UNTRAINED_TRUNK)
            models_to_eval.append(const.UNTRAINED_TRUNK_AND_EMBEDDER)
        models_to_eval.append(const.TRAINED)

        group_names = [self.get_eval_record_name_dict(self.curr_meta_testing_method)[split_name] for split_name in self.args.splits_to_eval]

        for name in models_to_eval:
            self.models["trunk"], self.models["embedder"] = meta_model_getter(name)
            did_not_skip = self.eval_model(name, name, load_model=False, skip_eval_if_already_done=self.args.skip_meta_eval_if_already_done)
            if did_not_skip:
                for group_name in group_names:
                    len_of_existing_records = c_f.try_getting_db_count(self.meta_record_keeper, group_name) + 1
                    self.record_keeper.update_records({const.TRAINED_STATUS_COL_NAME: name}, global_iteration=len_of_existing_records, input_group_name_for_non_objects=group_name)
                    self.record_keeper.update_records({"timestamp": c_f.get_datetime()}, global_iteration=len_of_existing_records, input_group_name_for_non_objects=group_name)

                for irrelevant_key in ["best_epoch", "best_accuracy"]:
                    self.record_keeper.record_writer.records[group_name].pop(irrelevant_key, None)
                self.record_keeper.save_records()
Example #3
    def meta_eval(self):
        meta_model_getter = getattr(self, "meta_"+self.args.meta_testing_method)
        self.models = {}
        self.record_keeper = self.meta_record_keeper
        self.hooks = logging_presets.HookContainer(self.record_keeper, record_group_name_prefix=meta_model_getter.__name__, primary_metric=self.args.eval_primary_metric)
        self.tester_obj = self.pytorch_getter.get("tester", 
                                                self.args.testing_method, 
                                                self.get_tester_kwargs())

        eval_dict = {"best": 1}
        if self.args.check_untrained_accuracy: eval_dict["-1"] = -1

        group_name = self.get_eval_record_name_dict("meta_ConcatenateEmbeddings")["test"]
        len_of_existing_records = c_f.try_getting_db_count(self.meta_record_keeper, group_name)

        for name, i in eval_dict.items():
            self.models["trunk"], self.models["embedder"] = meta_model_getter(name)
            self.set_transforms()
            did_not_skip = self.eval_model(i, name, splits_to_eval=self.args.splits_to_eval, load_model=False, skip_eval_if_already_done=self.args.skip_meta_eval_if_already_done)
            if did_not_skip:
                is_trained = int(i==1)
                global_iteration = len_of_existing_records + is_trained + 1
                self.record_keeper.update_records({"is_trained": int(i==1)}, global_iteration=global_iteration, input_group_name_for_non_objects=group_name)
                self.record_keeper.update_records({"timestamp": c_f.get_datetime()}, global_iteration=global_iteration, input_group_name_for_non_objects=group_name)

                for irrelevant_key in ["epoch", "best_epoch", "best_accuracy"]:
                    self.record_keeper.record_writer.records.pop(irrelevant_key, None)
                self.record_keeper.save_records()
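
Examples #2 and #3 attach extra columns (trained status, timestamp) to the evaluation record groups through record_keeper.update_records and then flush them with save_records. A small sketch of that pattern, with a made-up group name and literal values standing in for the benchmarker's constants and c_f.get_datetime():

from pytorch_metric_learning.utils import logging_presets

record_keeper, _, _ = logging_presets.get_record_keeper(
    "example_logs", "example_tensorboard")

group_name = "accuracies_test"  # placeholder record group name

# append plain (non-object) values to the group, keyed by a global iteration
record_keeper.update_records({"is_trained": 1},
                             global_iteration=5,
                             input_group_name_for_non_objects=group_name)
record_keeper.update_records({"timestamp": "2024-01-01 00:00:00"},
                             global_iteration=5,
                             input_group_name_for_non_objects=group_name)

# write the buffered records out to the log folders
record_keeper.save_records()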
Example #4
 def set_models_optimizers_losses(self):
     self.set_model()
     self.set_sampler()
     self.set_loss_function()
     self.set_mining_function()
     self.set_optimizers()
     self.set_record_keeper()
     self.hooks = logging_presets.HookContainer(self.record_keeper, primary_metric=self.args.eval_primary_metric, validation_split_name="val")
     self.tester_obj = self.pytorch_getter.get("tester", self.args.testing_method, self.get_tester_kwargs())
     self.trainer = self.pytorch_getter.get("trainer", self.args.training_method, self.get_trainer_kwargs())
     if self.is_training():
         self.epoch = self.maybe_load_latest_saved_models()
     self.set_dataparallel()
     self.set_devices()
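
The trainer configured by get_trainer_kwargs() in Example #4 is driven by hooks from the same HookContainer. A hedged sketch of the hook side of that wiring, assuming the hooks and tester objects from the first sketch and a hypothetical val_dataset:

# assumes `hooks` and `tester` built as in the first sketch, and `val_dataset`
# being any dataset the tester can embed (hypothetical placeholder)
dataset_dict = {"val": val_dataset}
model_folder = "example_saved_models"

# runs the tester on dataset_dict at the end of each epoch, saves models to
# model_folder, and stops early after `patience` epochs without improvement
# in the primary metric
end_of_epoch_hook = hooks.end_of_epoch_hook(tester,
                                            dataset_dict,
                                            model_folder,
                                            test_interval=1,
                                            patience=1)

This hook, together with hooks.end_of_iteration_hook, is what the trainer built by self.pytorch_getter.get("trainer", self.args.training_method, self.get_trainer_kwargs()) is expected to receive.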
Example #5
 def set_models_optimizers_losses(self):
     self.set_model()
     self.set_transforms()
     self.set_sampler()
     self.set_dataparallel()
     self.set_loss_function()
     self.set_mining_function()
     self.set_optimizers()
     self.set_record_keeper()
     self.hooks = logging_presets.HookContainer(self.record_keeper,
                                                end_of_epoch_test=False)
     self.tester_obj = self.pytorch_getter.get("tester",
                                               self.args.testing_method,
                                               self.get_tester_kwargs())
     self.epoch = self.maybe_load_models_and_records() + 1
Example #6
    def get_eval_record_name_dict(self,
                                  eval_type=const.NON_META,
                                  return_all=False,
                                  return_base_record_group_name=False):
        if not getattr(self, "hooks", None):
            self.hooks = logging_presets.HookContainer(
                None, primary_metric=self.args.eval_primary_metric)
        if not getattr(self, "tester_obj", None):
            if not getattr(self, "split_manager", None):
                self.split_manager = self.get_split_manager()
            self.tester_obj = self.pytorch_getter.get("tester",
                                                      self.args.testing_method,
                                                      self.get_tester_kwargs())
        prefix = self.hooks.record_group_name_prefix
        self.hooks.record_group_name_prefix = ""  #temporary
        if return_base_record_group_name:
            non_meta = {
                "base_record_group_name":
                self.hooks.base_record_group_name(self.tester_obj)
            }
        else:
            non_meta = {
                k: self.hooks.record_group_name(self.tester_obj, k)
                for k in self.split_manager.split_names
            }
        meta_separate = {
            k: "{}_{}".format(const.META_SEPARATE_EMBEDDINGS, v)
            for k, v in non_meta.items()
        }
        meta_concatenate = {
            k: "{}_{}".format(const.META_CONCATENATE_EMBEDDINGS, v)
            for k, v in non_meta.items()
        }
        self.hooks.record_group_name_prefix = prefix

        name_dict = {
            const.NON_META: non_meta,
            const.META_SEPARATE_EMBEDDINGS: meta_separate,
            const.META_CONCATENATE_EMBEDDINGS: meta_concatenate
        }

        if return_all:
            return name_dict
        return name_dict[eval_type]
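
Example #6 returns a mapping from evaluation type to per-split record group names. A hypothetical call site (runner stands in for an instance of the class these methods belong to; it is not defined in the examples):

# all record group names, keyed first by evaluation type, then by split name
all_names = runner.get_eval_record_name_dict(return_all=True)

# or just the group name used for the "test" split of the
# ConcatenateEmbeddings meta-evaluation
test_group = runner.get_eval_record_name_dict(
    const.META_CONCATENATE_EMBEDDINGS)["test"]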