コード例 #1
0
def _deep_lod2dol_v3(list_of_nested_things):
    """Stack a list of nested dictionaries into one nested dictionary.

    Every leaf keypath of the *first* element is looked up in each element,
    and the collected leaves are combined with ``np.stack``. The result has
    the same nesting as the first element, with stacked arrays at the leaves.

    .. Note::

        Unlike :func:`deep_lod2dol`, this variant never validates the input
        types; malformed input fails inside the helpers.

    Parameters
    ----------
    list_of_nested_things : list(dict(anything))
        A list of deep dictionaries sharing the same leaf keypaths.

    Returns
    -------
    out : dict(anything(list))
        Nested dict whose leaves are the stacked leaf entries.
    """
    first_item = list_of_nested_things[0]
    result = {}
    for keypath in get_leaf_names(first_item):
        leaves = [retrieve(item, keypath) for item in list_of_nested_things]
        set_value(result, keypath, np.stack(leaves))
    return result
コード例 #2
0
def test_append_to_list():
    """Writing past the end of a list grows it, padding gaps with ``None``."""
    untouched_b = {"c": {"d": 1}, "e": 2}
    dol = {"a": [1, 2], "b": {"c": {"d": 1}, "e": 2}}

    # Index 2 is exactly one past the end: behaves like an append.
    set_value(dol, "a/2", 3)
    assert dol == {"a": [1, 2, 3], "b": untouched_b}

    # Index 5 leaves a hole, which is filled with None.
    set_value(dol, "a/5", 6)
    assert dol == {"a": [1, 2, 3, None, None, 6], "b": untouched_b}
コード例 #3
0
def test_add_key():
    """New keys can be added at the top level and inside nested dicts."""
    dol = {"a": [1, 2], "b": {"c": {"d": 1}, "e": 2}}

    set_value(dol, "f", 3)
    expected = {"a": [1, 2], "b": {"c": {"d": 1}, "e": 2}, "f": 3}
    assert dol == expected

    # A numeric component below an existing dict becomes an int key.
    set_value(dol, "b/1", 3)
    expected["b"][1] = 3
    assert dol == expected
コード例 #4
0
 def after_epoch(self, *args, **kwargs):
     """Run the parent ``after_epoch`` once, if active, and report results.

     When ``self.cb_handler`` is set, the parent's return value is nested
     under ``self.keypath`` and handed to the handler together with one
     ``<keypath>/<callback name>`` path per entry of ``self.cb_names``.
     Afterwards the hook deactivates itself until ``_active`` is set again.
     """
     if not self._active:
         return

     cb_results = super().after_epoch(*args, **kwargs)
     if self.cb_handler is not None:
         results = {}
         set_value(results, self.keypath, cb_results)
         paths = [self.keypath + "/" + name for name in self.cb_names]
         self.cb_handler(results=results, paths=paths)
     self._active = False
コード例 #5
0
def test_top_is_list():
    """Keypaths also work when the outermost container is a list."""
    dol = [{"a": [1, 2], "b": {"c": {"d": 1}}, "e": 2}, 2, 3]

    set_value(dol, "0/k", 4)
    first = {"a": [1, 2], "b": {"c": {"d": 1}}, "e": 2, "k": 4}
    assert [first, 2, 3] == dol

    # Addressing the index itself replaces the whole element.
    set_value(dol, "0", 1)
    assert [1, 2, 3] == dol
コード例 #6
0
def test_fancy_overwriting():
    """Leaves are replaced by containers when a deeper keypath is written."""
    unchanged = {"a": [1, 2], "b": {"c": {"d": 1}}}
    dol = {"a": [1, 2], "b": {"c": {"d": 1}}, "e": 2}

    # The int leaf at "e" is turned into a dict.
    set_value(dol, "e/f", 3)
    assert {**unchanged, "e": {"f": 3}} == dol

    # The int leaf at "e/f" is turned into a list padded with None.
    set_value(dol, "e/f/1/g", 3)
    assert {**unchanged, "e": {"f": [None, {"g": 3}]}} == dol
コード例 #7
0
def test_extra_labels():
    """ExtraLabelsDataset merges the labels returned by its callable."""
    base = DebugDataset(size=10)
    base.append_labels = True

    extended = ExtraLabelsDataset(base, lambda dset, idx: {"new": idx})

    # Fetch from the wrapper first, then build the expectation from the base.
    example = extended[0]
    expected = base[0]
    set_value(expected, "labels_/new", 0)
    assert example == expected

    assert "new" in extended.labels
    assert np.all(extended.labels["new"] == np.arange(10))
コード例 #8
0
ファイル: commandline_kwargs.py プロジェクト: mritv/edflow
def update_config(config, additional_kwargs):
    """Apply keypath overrides to ``config`` and resolve ``{keypath}`` strings.

    The keys of ``additional_kwargs`` are ``/``-separated keypaths. They are
    applied in lexicographically sorted order. A string always sorts before
    every longer string it is a prefix of, so a parent such as ``'a'`` is set
    before ``'a/b'``. The more specific ``'a/b'`` therefore overrides the
    corresponding part of ``'a'``.

    Afterwards every string visited by ``walk`` that has the form
    ``"{some/keypath}"`` is replaced (in place) by the value retrieved from
    ``config`` at that keypath. The original string is passed as
    ``retrieve``'s default.

    Parameters
    ----------
    config : dict
        Configuration to update in place.
    additional_kwargs : dict
        Mapping from keypath to value.
    """
    for key in sorted(additional_kwargs):
        set_value(config, key, additional_kwargs[key])

    def replace(k):
        # startswith/endswith instead of k[0]/k[-1]: an empty string value
        # must be left alone, not raise IndexError.
        if isinstance(k, str) and k.startswith("{") and k.endswith("}"):
            return retrieve(config, k[1:-1].strip(), default=k)
        return k

    walk(config, replace, inplace=True)
コード例 #9
0
ファイル: meta.py プロジェクト: jhaux/edflow
        def __call__(self, key_path, path):
            """Register a matching ``.npy`` file as a memmap in ``self.labels``.

            Inputs that are not strings or do not match ``regex`` are ignored.
            The file name (minus ``.npy``) is split on ``-*-`` into
            ``<key>-*-<shape>-*-<dtype>``, where ``shape`` is ``x``-separated
            integers. The memmap is opened with ``mode="c"`` and stored under
            the parent keypath of ``key_path`` with its last component replaced
            by ``<key>``.
            """
            if not (isinstance(path, str) and regex.match(path)):
                return

            stem = os.path.basename(path)[: -len(".npy")]
            name, shape_spec, dtype = stem.split("-*-")
            shape = tuple(int(dim) for dim in shape_spec.split("x"))

            # For a single-component key_path the parent is empty and the
            # join simply yields ``name``.
            parent_parts = key_path.split("/")[:-1]
            key = "/".join(parent_parts + [name])

            set_value(
                self.labels,
                key,
                np.memmap(path, mode="c", shape=shape, dtype=dtype),
            )
コード例 #10
0
def _deep_lod2dol(list_of_nested_things):
    """Turns a list of nested dictionaries into a nested dictionary of lists.

    All leaves of the nested dictionaries are treated as full keys, not only
    the top-level keys. The leaf keypaths are taken from the first element.
    Each leaf is the ``np.stack`` of the corresponding entries of all
    elements.

    Parameters
    ----------
    list_of_nested_things : list
        A list of deep nested dictionaries.

    Returns
    -------
    out : dict
        A dict containing the stacked leaf entries.

    Raises
    ------
    TypeError
        Raised if the passed object is not a ``list`` or if its values are
        not ``dict`` s. The underlying error is chained as ``__cause__``.
    IndexError
        Re-raised if ``list_of_nested_things`` is an empty list.
    """

    # Type checks only run after something has already failed, so the
    # common, well-typed path does not pay for them.
    try:
        leaf_keypaths = get_leaf_names(list_of_nested_things[0])
    except Exception as e:
        if not isinstance(list_of_nested_things, list):
            raise TypeError("Expected `list` but got "
                            "{}".format(type(list_of_nested_things))) from e
        raise

    try:
        out = {}
        for key in leaf_keypaths:
            stacked_entry = np.stack(
                [retrieve(d, key) for d in list_of_nested_things])
            set_value(out, key, stacked_entry)
    except Exception as e:
        for v in list_of_nested_things:
            if not isinstance(v, dict):
                raise TypeError("Entries must be `dict` but got "
                                "{}".format(type(v))) from e
        raise

    return out
コード例 #11
0
def _deep_lod2dol_v2(list_of_nested_things):
    """Turns a list of nested dictionaries into a nested dictionary of lists.

    All leaves of the nested dictionaries are treated as full keys, not only
    the top-level keys. The leaf keypaths are taken from the first element.
    Each leaf is the ``np.stack`` of the corresponding entries of all
    elements.

    .. Note::

        The difference to :func:`deep_lod2dol` is that the types are always
        checked up front, not only after an exception.

    Parameters
    ----------
    list_of_nested_things : list
        A list of deep dictionaries.

    Returns
    -------
    out : dict
        A dict containing the stacked leaf entries.

    Raises
    ------
    TypeError
        Raised if the passed object is not a ``list`` or if its values are
        not ``dict`` s.
    IndexError
        Raised if ``list_of_nested_things`` is an empty list.
    """

    if not isinstance(list_of_nested_things, list):
        raise TypeError("Expected `list` but got "
                        "{}".format(type(list_of_nested_things)))

    # Validate every entry *before* inspecting the first one. Otherwise a
    # non-dict first entry would fail inside get_leaf_names instead of
    # producing the documented TypeError.
    for v in list_of_nested_things:
        if not isinstance(v, dict):
            raise TypeError("Entries must be `dict` but got "
                            "{}".format(type(v)))

    leaf_keypaths = get_leaf_names(list_of_nested_things[0])

    out = {}
    for key in leaf_keypaths:
        stacked_entry = np.stack(
            [retrieve(d, key) for d in list_of_nested_things])
        set_value(out, key, stacked_entry)

    return out
コード例 #12
0
ファイル: meta.py プロジェクト: jhaux/edflow
def clean_keys(labels, loaders):
    """Strip loader information from the keys of ``labels``.

    Two kinds of keys are renamed in place:

    * Any keypath for which ``loader_from_key`` reports a loader is moved to
      ``<base key>_``, where ``<base key>`` is the key without its loader part.
    * Any top-level key of ``labels`` that also appears in ``loaders`` is
      renamed to ``<key>_``.

    Parameters
    ----------
    labels : dict(str, numpy.memmap)
        Labels containing all load-easy dataset relevant data.
    loaders : dict
        Mapping whose keys name top-level label entries that have a loader.

    Returns
    -------
    labels : dict(str, numpy.memmap)
        The same ``labels`` object, with the renamed keys.
    """
    moves = []  # [new_keypath, value] pairs to write after the walk
    stale = []  # original keypaths to drop after the walk

    def collect(key, val):
        base, loader = loader_from_key(key)
        if loader is not None:
            moves.append([base + "_", retrieve(labels, key)])
            stale.append(key)

    # Only collect during the walk; mutating labels while walking it is unsafe.
    walk(labels, collect, pass_key=True)

    for new_key, value in moves:
        set_value(labels, new_key, value)
    for old_key in stale:
        pop_keypath(labels, old_key)

    for name in list(loaders.keys()):
        if name in labels:
            labels[name + "_"] = labels[name]
            del labels[name]

    return labels
コード例 #13
0
def default_repose_eval(root, data_in, data_out, config):
    """Configure and run a ``RePoseEval`` on the given datasets.

    The keyword arguments for ``RePoseEval`` come from
    ``config['repose_kwargs']``. If that dict exists it is modified in place.
    The following settings are forced or defaulted:

    * ``data_out_im_key`` defaults to ``'frame_gen'``. In debug mode it
      defaults to ``'target'``.
    * ``data_in_kp_key`` is always set to ``'target_keypoints_rel'``.
    * ``metrics`` is always ``['pck']``.
    * ``metrics_kwargs/pck/thresholds`` defaults to ``PCK_THRESH``.

    Setting the environment variable ``DEBUG_MODE=True`` replaces
    ``data_out`` by ``data_in``, so the inputs are evaluated against
    themselves.
    """
    debug_mode = os.environ.get('DEBUG_MODE', 'False') == 'True'
    # Diagnostics go through the logger instead of print() to stdout.
    LOGGER.debug('DEBUG %s', debug_mode)

    LOGGER.info("Setting up repose eval...")

    repose_config = config.get('repose_kwargs', {})
    if debug_mode:
        data_out = data_in
        set_default(repose_config, 'data_out_im_key', 'target')
    else:
        set_default(repose_config, 'data_out_im_key', 'frame_gen')
    set_value(repose_config, 'data_in_kp_key', 'target_keypoints_rel')

    # Only use pck for now
    set_value(repose_config, 'metrics', ['pck'])
    set_default(repose_config, 'metrics_kwargs/pck/thresholds', PCK_THRESH)

    rp_eval = RePoseEval(**repose_config)

    LOGGER.info("Running repose eval...")
    LOGGER.debug("len(data_in)=%s, len(data_out)=%s",
                 len(data_in), len(data_out))
    LOGGER.debug("repose_config=%s", repose_config)

    rp_eval(root, data_in, data_out, config)

    LOGGER.info("repose eval finished!")
コード例 #14
0
def test_set_value():
    """Existing list items and dict values can be overwritten by keypath."""
    dol = {"a": [1, 2], "b": {"c": {"d": 1}, "e": 2}}

    set_value(dol, "a/0", 3)
    assert dol == {"a": [3, 2], "b": {"c": {"d": 1}, "e": 2}}

    set_value(dol, "b/e", 3)
    assert dol == {"a": [3, 2], "b": {"c": {"d": 1}, "e": 3}}

    # Writing below the int stored at "a/1" replaces it with a dict.
    set_value(dol, "a/1/f", 3)
    assert dol == {"a": [3, {"f": 3}], "b": {"c": {"d": 1}, "e": 3}}
コード例 #15
0
def test_top_is_dict():
    """Keys and intermediate containers are created under a top-level dict."""
    dol = {"a": [1, 2], "b": {"c": {"d": 1}}, "e": 2}
    expected = {"a": [1, 2], "b": {"c": {"d": 1}}, "e": 2}

    # A plain new top-level key.
    set_value(dol, "h", 4)
    expected["h"] = 4
    assert expected == dol

    # Missing intermediate string keys become nested dicts.
    set_value(dol, "i/j/k", 4)
    expected["i"] = {"j": {"k": 4}}
    assert expected == dol

    # A missing intermediate numeric component becomes a list.
    set_value(dol, "j/0/k", 4)
    expected["j"] = [{"k": 4}]
    assert expected == dol
コード例 #16
0
    def __init__(self, *args, **kwargs):
        """Build the hook pipeline from ``self.config``.

        All arguments are forwarded unchanged to the parent constructor.
        Afterwards this sets up, in order:

        * a checkpoint hook (appended only outside ``test_mode``),
        * an ``ExpandHook`` over ``train_ops`` with interval 1,
        * a ``LoggingHook`` over ``log_ops`` wrapped in an ``IntervalHook``,
        * optional wandb / tensorboard logging handlers (outside
          ``test_mode`` only),
        * a ``TemplateEvalHook`` epoch hook. Its callbacks are collected from
          the config, ``self.callbacks``, ``self.model.callbacks``, or, in
          ``test_mode``, an optional ``eval_functor``.

        Options are read with ``set_default``, which presumably also records
        the chosen default in ``self.config`` (TODO confirm).
        """
        super().__init__(*args, **kwargs)

        # Read once; the flag gates checkpointing, logging integrations,
        # result logging and the eval functor below.
        test_mode = self.config.get("test_mode", False)

        # wrap save and restore into a LambdaCheckpointHook
        self.ckpthook = LambdaCheckpointHook(
            root_path=ProjectManager.checkpoints,
            global_step_getter=self.get_global_step,
            global_step_setter=self.set_global_step,
            save=self.save,
            restore=self.restore,
            interval=set_default(self.config, "ckpt_freq", None),
            ckpt_zero=set_default(self.config, "ckpt_zero", False),
        )
        # write checkpoints after epoch or when interrupted during training
        if not test_mode:
            self.hooks.append(self.ckpthook)

        ## hooks - disabled unless -t is specified

        # execute train ops
        self._train_ops = set_default(self.config, "train_ops",
                                      ["train/train_op"])
        train_hook = ExpandHook(paths=self._train_ops, interval=1)
        self.hooks.append(train_hook)

        # log train/step_ops/log_ops in increasing intervals
        self._log_ops = set_default(self.config, "log_ops",
                                    ["train/log_op", "validation/log_op"])
        self.loghook = LoggingHook(paths=self._log_ops,
                                   root_path=ProjectManager.train,
                                   interval=1)
        self.ihook = IntervalHook(
            [self.loghook],
            interval=set_default(self.config, "start_log_freq", 1),
            modify_each=1,
            max_interval=set_default(self.config, "log_freq", 1000),
            get_step=self.get_global_step,
        )
        self.hooks.append(self.ihook)

        # setup logging integrations
        if not test_mode:
            default_wandb_logging = {
                "active": False,
                "handlers": ["scalars", "images"]
            }
            wandb_logging = set_default(self.config, "integrations/wandb",
                                        default_wandb_logging)
            # A user-supplied section may omit "active". Treat that as
            # disabled, like the default, instead of raising KeyError.
            if wandb_logging.get("active", False):
                import wandb
                from edflow.hooks.logging_hooks.wandb_handler import (
                    log_wandb,
                    log_wandb_images,
                )

                os.environ["WANDB_RESUME"] = "allow"
                os.environ["WANDB_RUN_ID"] = ProjectManager.root.strip(
                    "/").replace("/", "-")
                wandb_project = set_default(self.config,
                                            "integrations/wandb/project", None)
                wandb_entity = set_default(self.config,
                                           "integrations/wandb/entity", None)
                wandb.init(
                    name=ProjectManager.root,
                    config=self.config,
                    project=wandb_project,
                    entity=wandb_entity,
                )

                handlers = set_default(
                    self.config,
                    "integrations/wandb/handlers",
                    default_wandb_logging["handlers"],
                )
                if "scalars" in handlers:
                    self.loghook.handlers["scalars"].append(log_wandb)
                if "images" in handlers:
                    self.loghook.handlers["images"].append(log_wandb_images)

            default_tensorboard_logging = {
                "active": False,
                "handlers": ["scalars", "images", "figures"],
            }
            tensorboard_logging = set_default(self.config,
                                              "integrations/tensorboard",
                                              default_tensorboard_logging)
            if tensorboard_logging.get("active", False):
                try:
                    from torch.utils.tensorboard import SummaryWriter
                except ImportError:
                    # Fall back to tensorboardX when torch's writer is not
                    # available. A bare except would also swallow
                    # KeyboardInterrupt/SystemExit.
                    from tensorboardX import SummaryWriter

                from edflow.hooks.logging_hooks.tensorboard_handler import (
                    log_tensorboard_config,
                    log_tensorboard_scalars,
                    log_tensorboard_images,
                    log_tensorboard_figures,
                )

                self.tensorboard_writer = SummaryWriter(ProjectManager.root)
                log_tensorboard_config(self.tensorboard_writer, self.config,
                                       self.get_global_step())
                handlers = set_default(
                    self.config,
                    "integrations/tensorboard/handlers",
                    default_tensorboard_logging["handlers"],
                )
                if "scalars" in handlers:
                    self.loghook.handlers["scalars"].append(
                        lambda *args, **kwargs: log_tensorboard_scalars(
                            self.tensorboard_writer, *args, **kwargs))
                if "images" in handlers:
                    self.loghook.handlers["images"].append(
                        lambda *args, **kwargs: log_tensorboard_images(
                            self.tensorboard_writer, *args, **kwargs))
                if "figures" in handlers:
                    self.loghook.handlers["figures"].append(
                        lambda *args, **kwargs: log_tensorboard_figures(
                            self.tensorboard_writer, *args, **kwargs))
        ## epoch hooks

        # evaluate validation/step_ops/eval_op after each epoch
        self._eval_op = set_default(self.config, "eval_hook/eval_op",
                                    "validation/eval_op")
        _eval_callbacks = set_default(self.config, "eval_hook/eval_callbacks",
                                      dict())
        if not isinstance(_eval_callbacks, dict):
            _eval_callbacks = {"cb": _eval_callbacks}
        # Copy so that merging below does not mutate the config's dict.
        eval_callbacks = dict(_eval_callbacks)

        def _register_eval_callbacks(callback_source):
            # Record each "eval_op" callback of callback_source as an import
            # path, both in the config and in eval_callbacks.
            found = retrieve(callback_source, "eval_op", default=dict())
            for k in found:
                import_path = get_str_from_obj(found[k])
                set_value(self.config,
                          "eval_hook/eval_callbacks/{}".format(k),
                          import_path)
                eval_callbacks[k] = import_path

        if hasattr(self, "callbacks"):
            _register_eval_callbacks(self.callbacks)
        if hasattr(self.model, "callbacks"):
            _register_eval_callbacks(self.model.callbacks)

        callback_handler = None
        if not test_mode:

            def _log_eval_results(results, paths):
                # Forward evaluation callback results to the logging hook.
                return self.loghook(
                    results=results,
                    step=self.get_global_step(),
                    paths=paths,
                )

            callback_handler = _log_eval_results

        # offer option to run eval functor:
        # overwrite step op to only include the evaluation of the functor and
        # overwrite callbacks to only include the callbacks of the functor
        if test_mode and "eval_functor" in self.config:
            # offer option to use eval functor for evaluation
            eval_functor = get_obj_from_str(
                self.config["eval_functor"])(config=self.config)
            self.step_ops = lambda: {"eval_op": eval_functor}
            eval_callbacks = dict()
            if hasattr(eval_functor, "callbacks"):
                for k in eval_functor.callbacks:
                    eval_callbacks[k] = get_str_from_obj(
                        eval_functor.callbacks[k])
            set_value(self.config, "eval_hook/eval_callbacks", eval_callbacks)
        self.evalhook = TemplateEvalHook(
            datasets=self.datasets,
            step_getter=self.get_global_step,
            keypath=self._eval_op,
            config=self.config,
            callbacks=eval_callbacks,
            callback_handler=callback_handler,
        )
        self.epoch_hooks.append(self.evalhook)
コード例 #17
0
def test_set_value_fail():
    """A non-numeric component cannot index into an existing list."""
    data = {"a": [1, 2], "b": {"c": {"d": 1}, "e": 2}}
    with pytest.raises(Exception):
        set_value(data, "a/g", 3)  # "g" is not a valid list index