def __init__(self, *args, **kwargs):
    """Build the checkpoint hook, then wire training or evaluation hooks.

    A ``LambdaCheckpointHook`` delegating to ``self.save`` / ``self.restore``
    is always created.

    When ``test_mode`` is truthy in ``self.config``, a ``TemplateEvalHook``
    is registered. Its callbacks are resolved from the import paths listed
    under ``eval_hook/eval_callbacks``, and ``_train_ops`` / ``_log_ops``
    are set to empty lists.

    Otherwise, a ``LoggingHook`` is registered, wrapped in an
    ``IntervalHook`` whose interval starts at ``start_log_freq`` and is
    capped at ``log_freq``. The checkpoint hook is registered after it.

    Every option read here is filled in with its default through
    ``set_default``, so the defaults end up in ``self.config``.
    """
    super().__init__(*args, **kwargs)

    # Saving / restoring is delegated to this object's save and restore.
    self.ckpthook = LambdaCheckpointHook(
        root_path=ProjectManager.checkpoints,
        global_step_getter=self.get_global_step,
        global_step_setter=self.set_global_step,
        save=self.save,
        restore=self.restore,
        interval=set_default(self.config, "ckpt_freq", None),
    )

    in_test_mode = self.config.get("test_mode", False)
    if in_test_mode:
        # Evaluation: one TemplateEvalHook, no train or log ops.
        self._eval_op = set_default(
            self.config, "eval_hook/eval_op", "step_ops/eval_op"
        )
        configured_callbacks = set_default(
            self.config, "eval_hook/eval_callbacks", []
        )
        # A single import path is accepted as shorthand for a one-item list.
        if not isinstance(configured_callbacks, list):
            configured_callbacks = [configured_callbacks]
        self._eval_callbacks = [
            get_obj_from_str(import_path)
            for import_path in configured_callbacks
        ]
        labels_path = set_default(
            self.config, "eval_hook/label_key", "step_ops/eval_op/labels"
        )
        self.evalhook = TemplateEvalHook(
            dataset=self.dataset,
            step_getter=self.get_global_step,
            keypath=self._eval_op,
            meta=self.config,
            callbacks=self._eval_callbacks,
            label_key=labels_path,
        )
        self.hooks.append(self.evalhook)
        self._train_ops = []
        self._log_ops = []
    else:
        # Training: run the train ops and log the log ops.
        self._train_ops = set_default(
            self.config, "train_ops", ["step_ops/train_op"]
        )
        self._log_ops = set_default(self.config, "log_ops", ["step_ops/log_op"])
        self.loghook = LoggingHook(
            paths=self._log_ops, root_path=ProjectManager.train, interval=1
        )
        first_interval = set_default(self.config, "start_log_freq", 1)
        largest_interval = set_default(self.config, "log_freq", 1000)
        self.ihook = IntervalHook(
            [self.loghook],
            interval=first_interval,
            modify_each=1,
            max_interval=largest_interval,
            get_step=self.get_global_step,
        )
        self.hooks.append(self.ihook)
        # The checkpoint hook is only active while training.
        self.hooks.append(self.ckpthook)
Example #2
0
def test_set_default_key_not_contained():
    """A missing top-level key is inserted and its default is returned."""
    nested = {"b": {"c": {"d": 1}}, "e": 2}
    expected = {"a": "new", "b": {"c": {"d": 1}}, "e": 2}

    returned = set_default(nested, "a", "new")

    assert nested == expected
    assert returned == "new"
Example #3
0
def default_repose_eval(root, data_in, data_out, config):
    """Run the RePose evaluation with project defaults filled in.

    Reads ``config["repose_kwargs"]``. When that key is present, the dict is
    updated in place: defaults are added and some values are forced. The
    steps are:

    - set ``data_out_im_key`` by default;
    - force ``data_in_kp_key`` to ``'target_keypoints_rel'`` and
      ``metrics`` to ``['pck']``;
    - set the PCK thresholds to ``PCK_THRESH`` by default;
    - build a ``RePoseEval`` from the resulting kwargs and call it.

    If the environment variable ``DEBUG_MODE`` equals the string ``'True'``,
    ``data_in`` is evaluated against itself (``data_out`` is replaced) and
    the default image key is ``'target'`` instead of ``'frame_gen'``.

    Args:
        root: Forwarded to the ``RePoseEval`` call.
        data_in: Input data, forwarded to the ``RePoseEval`` call.
        data_out: Generated data; replaced by ``data_in`` in debug mode.
        config: Evaluation config; ``repose_kwargs`` is read from it and it
            is forwarded to the ``RePoseEval`` call.
    """
    debug_mode = os.environ.get('DEBUG_MODE', 'False') == 'True'
    # Diagnostic output goes through the logger instead of print so it can
    # be silenced or redirected together with the rest of the eval logs.
    LOGGER.debug('DEBUG %s', debug_mode)

    LOGGER.info("Setting up repose eval...")

    repose_config = config.get('repose_kwargs', {})
    if debug_mode:
        # Set data_out to be data_in: evaluate the inputs against themselves.
        data_out = data_in
        set_default(repose_config, 'data_out_im_key', 'target')
    else:
        set_default(repose_config, 'data_out_im_key', 'frame_gen')
    set_value(repose_config, 'data_in_kp_key', 'target_keypoints_rel')

    # Only use pck for now
    set_value(repose_config, 'metrics', ['pck'])
    set_default(repose_config, 'metrics_kwargs/pck/thresholds', PCK_THRESH)

    # NOTE(review): the original computed the generated image size from
    # data_out[0] "for scaling the keypoints from relative to absolute", but
    # never used the result. That dead code (which also forced loading an
    # example and failed on empty data) was removed.

    rp_eval = RePoseEval(**repose_config)

    LOGGER.info("Running repose eval...")
    LOGGER.debug("len(data_in)=%d len(data_out)=%d",
                 len(data_in), len(data_out))
    LOGGER.debug("repose config: %s", repose_config)

    rp_eval(root, data_in, data_out, config)

    LOGGER.info("repose eval finished!")
Example #4
0
    def __init__(self, *args, **kwargs):
        """Register checkpoint, train, logging and evaluation hooks.

        Step hooks (``self.hooks``):

        - a ``LambdaCheckpointHook`` (not in test mode);
        - an ``ExpandHook`` over ``train_ops``;
        - an ``IntervalHook``-wrapped ``LoggingHook`` over ``log_ops``.

        Outside test mode, wandb and/or tensorboard handlers are attached to
        the logging hook when ``integrations/<name>/active`` is truthy.

        Epoch hooks (``self.epoch_hooks``): a ``TemplateEvalHook`` over
        ``eval_op``. Its callbacks are merged from
        ``eval_hook/eval_callbacks``, ``self.callbacks["eval_op"]`` and
        ``self.model.callbacks["eval_op"]``. Callbacks from the latter two
        are also written back into the config as import paths. In test mode
        with ``eval_functor`` configured, the step ops and callbacks come
        from that functor only.

        Every option read here is filled in with its default via
        ``set_default``, so the defaults end up in ``self.config``.
        """
        super().__init__(*args, **kwargs)
        test_mode = self.config.get("test_mode", False)

        # wrap save and restore into a LambdaCheckpointHook
        self.ckpthook = LambdaCheckpointHook(
            root_path=ProjectManager.checkpoints,
            global_step_getter=self.get_global_step,
            global_step_setter=self.set_global_step,
            save=self.save,
            restore=self.restore,
            interval=set_default(self.config, "ckpt_freq", None),
            ckpt_zero=set_default(self.config, "ckpt_zero", False),
        )
        # write checkpoints after epoch or when interrupted during training
        if not test_mode:
            self.hooks.append(self.ckpthook)

        ## step hooks
        # NOTE(review): the original comment said these are "disabled unless
        # -t is specified"; that gating is not visible in this block.

        # execute train ops
        self._train_ops = set_default(self.config, "train_ops",
                                      ["train/train_op"])
        train_hook = ExpandHook(paths=self._train_ops, interval=1)
        self.hooks.append(train_hook)

        # log train/step_ops/log_ops in increasing intervals
        self._log_ops = set_default(self.config, "log_ops",
                                    ["train/log_op", "validation/log_op"])
        self.loghook = LoggingHook(paths=self._log_ops,
                                   root_path=ProjectManager.train,
                                   interval=1)
        self.ihook = IntervalHook(
            [self.loghook],
            interval=set_default(self.config, "start_log_freq", 1),
            modify_each=1,
            max_interval=set_default(self.config, "log_freq", 1000),
            get_step=self.get_global_step,
        )
        self.hooks.append(self.ihook)

        # setup logging integrations (training only)
        if not test_mode:
            default_wandb_logging = {
                "active": False,
                "handlers": ["scalars", "images"]
            }
            wandb_logging = set_default(self.config, "integrations/wandb",
                                        default_wandb_logging)
            # A user-provided integrations/wandb section may omit "active";
            # treat that like the default (inactive) instead of KeyError.
            if wandb_logging.get("active", False):
                import wandb
                from edflow.hooks.logging_hooks.wandb_handler import (
                    log_wandb,
                    log_wandb_images,
                )

                # Derive a stable run id from the project root so that a
                # restarted run resumes the same wandb run.
                os.environ["WANDB_RESUME"] = "allow"
                os.environ["WANDB_RUN_ID"] = ProjectManager.root.strip(
                    "/").replace("/", "-")
                wandb_project = set_default(self.config,
                                            "integrations/wandb/project", None)
                wandb_entity = set_default(self.config,
                                           "integrations/wandb/entity", None)
                wandb.init(
                    name=ProjectManager.root,
                    config=self.config,
                    project=wandb_project,
                    entity=wandb_entity,
                )

                handlers = set_default(
                    self.config,
                    "integrations/wandb/handlers",
                    default_wandb_logging["handlers"],
                )
                if "scalars" in handlers:
                    self.loghook.handlers["scalars"].append(log_wandb)
                if "images" in handlers:
                    self.loghook.handlers["images"].append(log_wandb_images)

            default_tensorboard_logging = {
                "active": False,
                "handlers": ["scalars", "images", "figures"],
            }
            tensorboard_logging = set_default(self.config,
                                              "integrations/tensorboard",
                                              default_tensorboard_logging)
            # Same tolerance for a missing "active" flag as for wandb above.
            if tensorboard_logging.get("active", False):
                try:
                    from torch.utils.tensorboard import SummaryWriter
                except ImportError:
                    # Only fall back when torch's writer cannot be imported;
                    # a bare except would also swallow e.g. KeyboardInterrupt.
                    from tensorboardX import SummaryWriter

                from edflow.hooks.logging_hooks.tensorboard_handler import (
                    log_tensorboard_config,
                    log_tensorboard_scalars,
                    log_tensorboard_images,
                    log_tensorboard_figures,
                )

                self.tensorboard_writer = SummaryWriter(ProjectManager.root)
                log_tensorboard_config(self.tensorboard_writer, self.config,
                                       self.get_global_step())
                handlers = set_default(
                    self.config,
                    "integrations/tensorboard/handlers",
                    default_tensorboard_logging["handlers"],
                )
                # The lambdas read self.tensorboard_writer at call time.
                if "scalars" in handlers:
                    self.loghook.handlers["scalars"].append(
                        lambda *args, **kwargs: log_tensorboard_scalars(
                            self.tensorboard_writer, *args, **kwargs))
                if "images" in handlers:
                    self.loghook.handlers["images"].append(
                        lambda *args, **kwargs: log_tensorboard_images(
                            self.tensorboard_writer, *args, **kwargs))
                if "figures" in handlers:
                    self.loghook.handlers["figures"].append(
                        lambda *args, **kwargs: log_tensorboard_figures(
                            self.tensorboard_writer, *args, **kwargs))
        ## epoch hooks

        # evaluate validation/step_ops/eval_op after each epoch
        self._eval_op = set_default(self.config, "eval_hook/eval_op",
                                    "validation/eval_op")
        _eval_callbacks = set_default(self.config, "eval_hook/eval_callbacks",
                                      dict())
        # A single callback is accepted as shorthand for {"cb": callback}.
        if not isinstance(_eval_callbacks, dict):
            _eval_callbacks = {"cb": _eval_callbacks}
        # Shallow copy so that merging below does not alter the config dict.
        eval_callbacks = dict(_eval_callbacks)
        # Callbacks provided by the iterator and the model are stored in the
        # config as import paths and merged in (later sources win per key).
        if hasattr(self, "callbacks"):
            iterator_callbacks = retrieve(self.callbacks,
                                          "eval_op",
                                          default=dict())
            for k in iterator_callbacks:
                import_path = get_str_from_obj(iterator_callbacks[k])
                set_value(self.config, "eval_hook/eval_callbacks/{}".format(k),
                          import_path)
                eval_callbacks[k] = import_path
        if hasattr(self.model, "callbacks"):
            model_callbacks = retrieve(self.model.callbacks,
                                       "eval_op",
                                       default=dict())
            for k in model_callbacks:
                import_path = get_str_from_obj(model_callbacks[k])
                set_value(self.config, "eval_hook/eval_callbacks/{}".format(k),
                          import_path)
                eval_callbacks[k] = import_path
        # In training, callback results are forwarded to the logging hook.
        callback_handler = None
        if not test_mode:
            callback_handler = lambda results, paths: self.loghook(
                results=results,
                step=self.get_global_step(),
                paths=paths,
            )

        # offer option to run eval functor:
        # overwrite step op to only include the evaluation of the functor and
        # overwrite callbacks to only include the callbacks of the functor
        if test_mode and "eval_functor" in self.config:
            # offer option to use eval functor for evaluation
            eval_functor = get_obj_from_str(
                self.config["eval_functor"])(config=self.config)
            self.step_ops = lambda: {"eval_op": eval_functor}
            eval_callbacks = dict()
            if hasattr(eval_functor, "callbacks"):
                for k in eval_functor.callbacks:
                    eval_callbacks[k] = get_str_from_obj(
                        eval_functor.callbacks[k])
            set_value(self.config, "eval_hook/eval_callbacks", eval_callbacks)
        self.evalhook = TemplateEvalHook(
            datasets=self.datasets,
            step_getter=self.get_global_step,
            keypath=self._eval_op,
            config=self.config,
            callbacks=eval_callbacks,
            callback_handler=callback_handler,
        )
        self.epoch_hooks.append(self.evalhook)
Example #5
0
    def __init__(self, *args, **kwargs):
        """Register training logging/checkpoint hooks or an evaluation hook.

        A ``LambdaCheckpointHook`` delegating to ``self.save`` /
        ``self.restore`` is always built.

        When ``test_mode`` is falsy, the following are registered:

        - a train ``LoggingHook`` (wrapped in an ``IntervalHook``);
        - a validation ``LoggingHook`` writing below
          ``<ProjectManager.train>/validation``;
        - the checkpoint hook.

        Optional wandb / tensorboardX scalar handlers are attached when
        ``integrations/wandb`` / ``integrations/tensorboardX`` are truthy.

        Otherwise, a ``TemplateEvalHook`` is registered with callbacks
        resolved from the import paths under ``eval_hook/eval_callbacks``,
        and the train/log op lists are emptied.

        Every option read here is filled in with its default via
        ``set_default``, so the defaults end up in ``self.config``.
        """
        super().__init__(*args, **kwargs)
        # wrap save and restore into a LambdaCheckpointHook
        self.ckpthook = LambdaCheckpointHook(
            root_path=ProjectManager.checkpoints,
            global_step_getter=self.get_global_step,
            global_step_setter=self.set_global_step,
            save=self.save,
            restore=self.restore,
            interval=set_default(self.config, "ckpt_freq", None),
        )
        if not self.config.get("test_mode", False):
            # in training, execute train ops and add logginghook for train
            # and validation batches
            self._train_ops = set_default(self.config, "train_ops",
                                          ["step_ops/train_op"])
            self._log_ops = set_default(self.config, "log_ops",
                                        ["step_ops/log_op"])
            # logging
            self.loghook = LoggingHook(
                paths=self._log_ops,
                root_path=ProjectManager.train,
                interval=1,
                name="train",
            )
            # wrap it in interval hook
            self.ihook = IntervalHook(
                [self.loghook],
                interval=set_default(self.config, "start_log_freq", 1),
                modify_each=1,
                max_interval=set_default(self.config, "log_freq", 1000),
                get_step=self.get_global_step,
            )
            self.hooks.append(self.ihook)
            # validation logging
            self._validation_log_ops = set_default(self.config,
                                                   "validation_log_ops",
                                                   ["validation_ops/log_op"])
            self._validation_root = os.path.join(ProjectManager.train,
                                                 "validation")
            os.makedirs(self._validation_root, exist_ok=True)
            # logging
            self.validation_loghook = LoggingHook(
                paths=self._validation_log_ops,
                root_path=self._validation_root,
                interval=1,
                name="validation",
            )
            self.hooks.append(self.validation_loghook)
            # write checkpoints after epoch or when interrupted
            self.hooks.append(self.ckpthook)
            wandb_logging = set_default(self.config, "integrations/wandb",
                                        False)
            if wandb_logging:
                import wandb
                from edflow.hooks.logging_hooks.wandb_handler import log_wandb

                # Derive the run id from the project root so that a restarted
                # run resumes the same wandb run.
                os.environ["WANDB_RESUME"] = "allow"
                os.environ["WANDB_RUN_ID"] = ProjectManager.root.replace(
                    "/", "-")
                wandb.init(name=ProjectManager.root, config=self.config)
                self.loghook.handlers["scalars"].append(log_wandb)
                self.validation_loghook.handlers["scalars"].append(
                    lambda *args, **kwargs: log_wandb(
                        *args, **kwargs, prefix="validation"))
            tensorboardX_logging = set_default(self.config,
                                               "integrations/tensorboardX",
                                               False)
            if tensorboardX_logging:
                from tensorboardX import SummaryWriter
                from edflow.hooks.logging_hooks.tensorboardX_handler import (
                    log_tensorboard_config,
                    log_tensorboard_scalars,
                )

                self.tensorboardX_writer = SummaryWriter(ProjectManager.root)
                log_tensorboard_config(self.tensorboardX_writer, self.config,
                                       self.get_global_step())
                # The lambdas read self.tensorboardX_writer at call time.
                self.loghook.handlers["scalars"].append(
                    lambda *args, **kwargs: log_tensorboard_scalars(
                        self.tensorboardX_writer, *args, **kwargs))
                self.validation_loghook.handlers["scalars"].append(
                    lambda *args, **kwargs: log_tensorboard_scalars(
                        self.tensorboardX_writer,
                        *args,
                        **kwargs,
                        prefix="validation"))

        else:
            # evaluate
            self._eval_op = set_default(self.config, "eval_hook/eval_op",
                                        "step_ops/eval_op")
            configured_callbacks = set_default(self.config,
                                               "eval_hook/eval_callbacks",
                                               dict())
            # A single import path is shorthand for {"cb": path}.
            if not isinstance(configured_callbacks, dict):
                configured_callbacks = {"cb": configured_callbacks}
            # Resolve into a NEW dict. set_default may hand back the very
            # dict stored in self.config, and writing callables into it
            # would replace the config's import-path strings with objects.
            self._eval_callbacks = {
                name: get_obj_from_str(import_path)
                for name, import_path in configured_callbacks.items()
            }
            label_key = set_default(self.config, "eval_hook/label_key",
                                    "step_ops/eval_op/labels")
            self.evalhook = TemplateEvalHook(
                dataset=self.dataset,
                step_getter=self.get_global_step,
                keypath=self._eval_op,
                config=self.config,
                callbacks=self._eval_callbacks,
                labels_key=label_key,
            )
            self.hooks.append(self.evalhook)
            self._train_ops = []
            self._log_ops = []