Example #1
    def before_step(self, step, fetches, feeds, batch):
        """
        Steps taken before the training step.
        Parameters
        ----------
        step
            Training step.
        fetches
            Fetches for the next session.run call.
        feeds
            Feeds for the next session.run call.
        batch
            The batch to be iterated over.

        """

        def to_image(obj):
            if isinstance(obj, np.ndarray) and len(obj.shape) == 3:
                batches, height, width = obj.shape
                obj = obj.reshape(batches, height, width, 1)
            if isinstance(obj, np.ndarray) and len(obj.shape) == 4:
                return obj.transpose(0, 3, 1, 2)
            else:
                return obj

        walk(feeds, to_image, inplace=True)

        super().before_step(step, fetches, feeds, batch)
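A minimal sketch of what the hook above does to a feed dict, assuming the edflow.util.walk used throughout these examples; the keys and shapes below are illustrative only.

import numpy as np
from edflow.util import walk  # assumed import path, as in the examples in this listing

def to_image(obj):
    # add a channel axis to (N, H, W) arrays, then move channels to the front
    if isinstance(obj, np.ndarray) and len(obj.shape) == 3:
        batches, height, width = obj.shape
        obj = obj.reshape(batches, height, width, 1)
    if isinstance(obj, np.ndarray) and len(obj.shape) == 4:
        return obj.transpose(0, 3, 1, 2)
    return obj

feeds = {"image": np.zeros((8, 64, 64, 3)), "mask": np.zeros((8, 64, 64))}
walk(feeds, to_image, inplace=True)
print(feeds["image"].shape)  # (8, 3, 64, 64) -- NHWC became NCHW
print(feeds["mask"].shape)   # (8, 1, 64, 64) -- channel axis added, then moved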
Example #2
def plot_datum(
    nested_thing,
    savename="datum.png",
    heuristics=default_heuristic,
    plt_functions=PLOT_FUNCTIONS,
):
    """Plots all data in the nested_thing as best as can.

    If heuristics is given, this determines how each leaf datum is converted
    to something plottable.

    Parameters
    ---------
    nested_thing : dict or list
	Some nested object.
    savename : str
	``Path/to/the/plot.png``.
    heuristics : Callable
	If given this should produce a string specifying
        the kind of data of the leaf. If ``None`` determinde automatically.
        See :func:`default_heuristic`.
    plt_functions : dict of Callables
	Maps a ``kind`` to a function which
        can plot it. Each callable must be able to receive a the key, the
        leaf object and the Axes to plot it in.

    """

    class Plotter(object):
        def __init__(self, kind_fn, savename):
            self.kind_fn = kind_fn
            self.savename = savename
            self.buffer = []

        def __call__(self, key, obj):
            kind = self.kind_fn(key, obj)
            self.buffer += [[kind, key, obj]]

        def plot(self):
            n_pl = len(self.buffer)

            f = plt.figure(figsize=(5, 2 * n_pl))

            gs = gridspec.GridSpec(n_pl, 1)

            for i, [kind, key, obj] in enumerate(self.buffer):
                ax = f.add_subplot(gs[i])
                plt_functions[kind](key, obj, ax)

            f.savefig(self.savename)

        def __str__(self):
            self.plot()
            return "Saved Plot at {}".format(self.savename)

    P = Plotter(heuristics, savename)

    walk(nested_thing, P, pass_key=True)

    print(P)
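The Plotter above follows a pattern that recurs throughout this listing: a callable object handed to walk with pass_key=True collects every (key, leaf) pair for later processing. A minimal sketch, assuming edflow.util.walk; the data is illustrative and the exact key format is an assumption.

from edflow.util import walk  # assumed import path

class Collector:
    """Buffers all (key, leaf) pairs that walk hands to it."""

    def __init__(self):
        self.buffer = []

    def __call__(self, key, obj):
        self.buffer.append((key, obj))

nested = {"a": [1, 2], "b": {"c": 3}}
C = Collector()
walk(nested, C, pass_key=True)
print(C.buffer)  # e.g. [('a/0', 1), ('a/1', 2), ('b/c', 3)]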
Example #3
    def __init__(self, root):
        """
        Parameters
        ----------
        root : str
            Where to look for all the data.
        """
        meta_path = os.path.join(root, "meta.yaml")
        self.meta = meta = yaml.safe_load(open(meta_path, "r"))

        labels = load_labels(os.path.join(root, "labels"))
        self.loaders, self.loader_kwargs = setup_loaders(labels, meta)
        self.labels = clean_keys(labels, self.loaders)

        class Lenner:
            def __init__(self):
                self.l = None
                self.visited = []

            def __call__(self, key, label):
                if self.l is None:
                    self.l = len(label)
                else:
                    if len(label) != self.l:
                        raise ValueError(f"Label {key} has a different length "
                                         "than the other labels.\n"
                                         f"Already seen: {self.visited}")
                self.visited += [key]

        L = Lenner()
        walk(self.labels, L, pass_key=True)

        self.num_examples = L.l

        self.append_labels = True
Example #4
def convert_logs2numpy(logs):
    def conditional_convert2np(log_item):
        if isinstance(log_item, torch.Tensor):
            log_item = log_item.detach().cpu().numpy()
        return log_item

    # convert to numpy
    walk(logs, conditional_convert2np, inplace=True)
    return logs
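A short, self-contained sketch of the conversion above, assuming edflow.util.walk and a CPU tensor; the logs dict is made up.

import torch
from edflow.util import walk  # assumed import path

def conditional_convert2np(log_item):
    # torch.Tensor leaves become numpy arrays, everything else passes through
    if isinstance(log_item, torch.Tensor):
        log_item = log_item.detach().cpu().numpy()
    return log_item

logs = {"scalars": {"loss": torch.tensor(0.5)}, "step": 10}
walk(logs, conditional_convert2np, inplace=True)
print(type(logs["scalars"]["loss"]))  # <class 'numpy.ndarray'>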
Example #5
def show_example(dset, idx):
    ex = dset[idx]
    st.header("Keys")
    walk(ex, display, pass_key=True)
    st.header("Summary")
    summary = pp2mkdtable(ex, jupyter_style=True)
    # print markdown summary on console for easy copy and pasting in readme etc
    print(summary)
    st.markdown(summary)
Example #6
def test_walk_pass_key_inplace():
    dol = {"a": [1, 2], "b": {"c": {"d": 1}}, "e": 2}
    ref = {"a": [-1, -2], "b": {"c": {"d": -1}}, "e": -2}

    def fn(key, leaf):
        return -leaf

    walk(dol, fn, inplace=True, pass_key=True)

    assert dol == ref
Example #7
    def get_example(self, idx):
        """Get the examples from the base dataset at defined at ``view[idx]``.
        """
        def get_view(view):
            return view[idx]

        view = walk(self.views, get_view)

        view_example = walk(view, self.base.__getitem__, walk_np_arrays=True)

        return view_example
Example #8
    def after_step(self, step, results):
        def convert(var_or_tens):
            if hasattr(var_or_tens, 'cpu'):
                var_or_tens = var_or_tens.cpu()

            if isinstance(var_or_tens, torch.autograd.Variable):
                return var_or_tens.data.numpy()
            elif isinstance(var_or_tens, torch.Tensor):
                return var_or_tens.numpy()
            else:
                return var_or_tens

        walk(results, convert, inplace=True)
Example #9
def example_callbacks(example, dset_handler):
    '''Creates all interactive connections to update and visualize the content
    of examples.'''

    Cooon = Connector(dset_handler)
    walk(example, Cooon, pass_key=True)

    callbacks = Cooon.callbacks
    callbacks['toggle_all'] = {
        'args': Cooon.toggle_all_ins_and_outs,
        'callback': ToggleAllCallback(len(callbacks), True)
    }
    return callbacks
Example #10
    def get_example(self, idx):
        """Get the examples from the base dataset at defined at ``view[idx]``. Load loaders if applicable.
        """
        def get_view(view):
            return view[idx]

        view = walk(self.views, get_view)

        view_example = walk(view, self.base.__getitem__, walk_np_arrays=True)

        if len(self.loaders) > 0:
            loaders_example = super().get_example(idx)
            view_example.update(loaders_example)

        return view_example
Example #11
def update_config(config, additional_kwargs):
    """additional_kwargs are added in order of the keys' length, e.g. 'a'
    is overriden by 'a/b'."""
    keys = sorted(additional_kwargs.keys())
    for k in keys:
        set_value(config, k, additional_kwargs[k])

    def replace(k):
        if isinstance(k, str) and k.startswith("{") and k.endswith("}"):
            k_ = k[1:-1].strip()
            return retrieve(config, k_, default=k)
        else:
            return k

    walk(config, replace, inplace=True)
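A hedged, worked example of the interpolation step above, assuming the edflow.util helpers (walk, retrieve) used elsewhere in this listing; the config contents are made up.

from edflow.util import retrieve, walk  # assumed import paths

config = {"data": {"root": "/tmp/data"}, "eval": {"root": "{data/root}"}}

def replace(k):
    # a string leaf of the form "{some/key}" is looked up in the config itself
    if isinstance(k, str) and k.startswith("{") and k.endswith("}"):
        return retrieve(config, k[1:-1].strip(), default=k)
    return k

walk(config, replace, inplace=True)
print(config["eval"]["root"])  # '/tmp/data' -- the placeholder was resolved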
Example #12
    def before_step(self, step, fetches, feeds, batch):
        def convert(obj):
            if isinstance(obj, np.ndarray):
                try:
                    obj = torch.tensor(obj)
                    obj = obj.to(self.dtype)
                    if self.use_gpu:
                        obj = obj.cuda()
                    return obj
                except Exception:
                    return obj
            else:
                return obj

        walk(feeds, convert, inplace=True)
Example #13
def test_meta_view_dset():
    N = 100
    V = 25
    try:
        super_root, base_root, view_root = _setup(".", N, V)

        M = MetaViewDataset(view_root)
        M.expand = True
        M.append_labels = False
        M.show()

        assert len(M) == V

        for kk in ["simple1", "simple", "complex"]:
            assert kk in M.labels
            if kk == "complex":
                for i in range(2):
                    for k in ["attr1", "attr2", "image_", "keypoints"]:
                        assert k in M.labels[kk][i]
                        assert len(M.labels[kk][i][k]) == V
            else:
                for k in ["attr1", "attr2", "image_", "keypoints"]:
                    assert k in M.labels[kk]
                    assert len(M.labels[kk][k]) == V

        d = M[0]
        # For ex 0 this is the same for both complex and simple
        single_ref = {"image": np.ones(shape=(64, 64, 3)), "index_": 0}

        ref_simple = single_ref
        ref_complex = [[single_ref] * 3] * 20

        ref = {
            "simple1": ref_simple,
            "simple": ref_simple,
            "complex": [ref_complex, ref_simple],
            "index_": 0,
        }

        def tester(key, val):
            assert np.all(val == retrieve(ref, key))

        walk(d, tester, pass_key=True)

        assert hasattr(M, "meta")

    finally:
        _teardown(super_root)
Example #14
    def _maybe_append_labels(self, datum, index):
        if self.append_labels:

            def label_getter(labels):
                return labels[index]

            labels = walk(self.labels, label_getter)
            update(datum, {"labels_": labels})
Example #15
    def prepare_inputs_inplace(self, inputs):
        '''Casts all inputs to torch Tensors and pushes them to the GPU.'''
        before = time.time()

        inputs = walk(inputs, np2pt, inplace=True)

        if retrieve(self.config, "debug_timing", default=False):
            self.logger.info("prepare of data needed {} s".format(time.time() - before))
Example #16
def test_walk():
    dol = {"a": [1, 2], "b": {"c": {"d": 1}}, "e": 2}
    ref = {"a": [-1, -2], "b": {"c": {"d": -1}}, "e": -2}

    def fn(leaf):
        return -leaf

    val = walk(dol, fn)

    assert val == ref
Example #17
def test_walk_pass_key():
    dol = {"a": [1, 2], "b": {"c": {"d": 1}}, "e": 2}
    ref = {"a": [-1, -2], "b": {"c": {"d": -1}}, "e": -2}

    def fn(key, leaf):
        return -leaf

    val = walk(dol, fn, pass_key=True)

    assert val == ref
Example #18
File: meta.py Project: jhaux/edflow
def load_labels(root):
    """
    Parameters
    ----------
    root : str
        Where to look for the labels.

    Returns
    -------
    labels : dict
        All labels as ``np.memmap`` objects.
    """

    regex = re.compile(r".*-\*-.*-\*-.*\.npy")

    label_files = _get_label_files(root)

    class Loader:
        def __init__(self):
            self.labels = {}

        def __call__(self, key_path, path):
            if isinstance(path, str) and regex.match(path):
                f = os.path.basename(path)
                f_ = f[: -len(".npy")]
                key_, shape, dtype = f_.split("-*-")
                shape = tuple([int(s) for s in shape.split("x")])

                key_path = key_path.split("/")
                if len(key_path) == 1:
                    key = key_
                else:
                    key = "/".join(key_path[:-1] + [key_])

                mmap = np.memmap(path, mode="c", shape=shape, dtype=dtype)

                set_value(self.labels, key, mmap)

    L = Loader()
    walk(label_files, L, pass_key=True)

    return L.labels
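A sketch of the filename convention parsed above: a label file name encodes its key, shape and dtype, separated by '-*-'. The path below is illustrative only; the parsing lines mirror the loader's own code.

import os

path = "labels/keypoints-*-100x17x2-*-float32.npy"
f_ = os.path.basename(path)[: -len(".npy")]
key, shape, dtype = f_.split("-*-")
shape = tuple(int(s) for s in shape.split("x"))
print(key, shape, dtype)  # keypoints (100, 17, 2) float32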
Example #19
    def before_step(self, *args, **kwargs):
        """Checks if something changed and if yes runs the callback."""

        try:
            updates = yaml.full_load(open(self.ufile, "r"))

            if self.last_updates is not None:
                changes = {}

                def is_changed(key, val, changes=changes):
                    if contains_key(key, updates):
                        other_val = retrieve(key, updates)

                        change = np.any(val != other_val)
                    else:
                        # This key is new -> Changes did happen!
                        change = True
                    changes[key] = change

                self.logger.debug("Pre  CHANGES: {}".format(changes))
                walk(self.last_updates, is_changed, pass_key=True)
                self.logger.debug("Post CHANGES: {}".format(changes))

                if np.any(list(changes.values())):
                    self.callback(updates)

                    self.logger.debug("Runtime inputs received.")
                    self.logger.debug("{}".format(updates))

                    self.last_updates = updates
            else:
                if updates is not None:
                    self.callback(updates)

                    self.logger.info("Runtime inputs received.")
                    self.logger.debug("{}".format(updates))

                    self.last_updates = updates
        except Exception as e:
            self.logger.error("Something bad happend :(")
            self.logger.error("{}".format(e))
            self.logger.error(traceback.format_exc())
Example #20
def display_controls(datasource_1_value):
    '''Makes sure that the slider changes the displayed example.
    '''

    # Get example
    ex = Ht[int(datasource_1_value)]

    # display leaf variables
    de = DisplayElements()
    walk(ex, de, pass_key=True)
    content = de.elements

    connector = Connector()
    walk(ex, connector, pass_key=True)

    for key, connection in connector.callbacks.items():
        print(key, connection)
        app.callback(**connection['args'])(connection['callback'])

    return html.Div(content)
Example #21
    def after_step(self, step, results):
        """
        Steps taken after the training step.
        :param step: Training step.
        :param results: Result of the session.
        :return:
        """
        super().after_step(step, results)

        def to_image(k, obj):
            if (
                "weights" not in k
                and isinstance(obj, np.ndarray)
                and len(obj.shape) == 4
            ):
                return obj.transpose(0, 2, 3, 1)
            else:
                return obj

        walk(results, to_image, inplace=True, pass_key=True)
Example #22
File: meta.py Project: jhaux/edflow
def clean_keys(labels, loaders):
    """Removes all loader information from the keys.

    Parameters
    ----------
    labels : dict(str, numpy.memmap)
        Labels contain all easily loadable, dataset-relevant data.
    loaders : dict
        Maps label keys to the loaders to be applied to them.

    Returns
    -------
    labels : dict(str, numpy.memmap)
        The original labels, with keys without the ``:loader`` part.
    """

    class Cleaner:
        def __init__(self):
            self.to_delete = []
            self.to_set = []

        def __call__(self, key, val):
            k, l = loader_from_key(key)
            if l is not None:
                self.to_set += [[k + "_", retrieve(labels, key)]]
                self.to_delete += [key]

    C = Cleaner()
    walk(labels, C, pass_key=True)

    for key, val in C.to_set:
        set_value(labels, key, val)

    for key in C.to_delete:
        pop_keypath(labels, key)

    for k_ in list(loaders.keys()):
        if k_ in labels:
            k = k_ + "_"
            labels[k] = labels[k_]
            del labels[k_]

    return labels
Example #23
    def __init__(self, root):
        super().__init__(root)

        base_import = retrieve(self.meta, "base_dset")
        base_kwargs = retrieve(self.meta, "base_kwargs")
        self.base = get_obj_from_str(base_import)(**base_kwargs)
        self.base.append_labels = False

        views = retrieve(self.meta, "views", default="view")

        def get_label(key):
            return retrieve(self.labels, key)

        self.views = walk(views, get_label)

        if not os.path.exists(os.path.join(root, ".constructed.txt")):

            def constructor(name, view):
                folder_name = name
                savefolder = os.path.join(root, "labels", folder_name)

                os.makedirs(savefolder, exist_ok=True)

                for key, label in tqdm(self.base.labels.items(),
                                       desc=f"Exporting View {name}"):

                    savepath = os.path.join(root, "labels", name)
                    label_view = np.take(label, view, axis=0)
                    store_label_mmap(label_view, savepath, key)

            walk(self.views, constructor, pass_key=True)

            with open(os.path.join(root, ".constructed.txt"), "w+") as cf:
                cf.write("Do not delete, this reduces loading times.\n"
                         "If you need to re-render the view, you can safely "
                         "delete this file.")

            # Re-initialize as we need to load the labels again.
            super().__init__(root)
Example #24
    def prepare_logs(self, inputs, predictions, losses, model, granularity):
        '''Logs need to be differentiated into ``images`` and ``scalars``. This
        function casts everything we want to log to numpy and stores it
        correctly in the output log dictionary.
        '''
        # sample variational part
        output_sample = model(inputs["pt"], mode="sample_appearance")
        output_sample = {'image': output_sample}
        output_sample.update(model.saved_tensors)

        losses_sample = self.criterion(inputs["pt"], output_sample)
        sample_images = {
            "images_prediction_sample": pt2np(output_sample["image"]),
        }

        # concatenate logs
        logs = {
            "images": {
                "appearance": inputs["np"]["appearance"],
                "target": inputs["np"]["target"],
                "pose": inputs["np"]["pose"],
                "images_prediction": pt2np(predictions['image']),
            },
            "scalars": {
                **losses[granularity]
            },
        }

        # convert to numpy
        def conditional_convert2np(log_item):
            if isinstance(log_item, torch.Tensor):
                log_item = log_item.detach().cpu().numpy()
            return log_item

        walk(logs, conditional_convert2np, inplace=True)

        return logs
Example #25
def test_meta_dset():
    N = 100
    try:
        root = _setup(".", N)

        M = MetaDataset(root)
        M.expand = True
        M.show()

        assert len(M) == N

        for k in ["attr1", "attr2", "image_", "keypoints"]:
            assert k in M.labels
            assert len(M.labels[k]) == N

        d = M[0]
        ref = {
            "image": np.ones(shape=(64, 64, 3)),
            "index_": 0,
            "labels_": {
                "image_": os.path.join(root, "images", "000.png"),
                "attr1": 0,
                "attr2": np.zeros((2)),
                "keypoints": np.ones((17, 2)),
            },
        }

        def tester(key, val):
            assert np.all(val == retrieve(ref, key))

        walk(d, tester, pass_key=True)

        assert hasattr(M, "meta")

    finally:
        _teardown(root)
Example #26
def test_walk_npon():
    import numpy as np

    dol = {"a": np.array([1, 2]), "b": {"c": {"d": 1}}, "e": 2}
    ref = {"a": np.array([-1, -2]), "b": {"c": {"d": -1}}, "e": -2}

    def fn(leaf):
        return -leaf

    val = walk(dol, fn, walk_np_arrays=True)

    assert np.all(val["a"] == ref["a"])
    del val["a"]
    del ref["a"]
    assert val == ref
Example #27
    def run(self, fetches, feed_dict):
        """Runs all fetch ops and stores the results.

        Args:
            fetches (dict): name: Callable pairs.
            feed_dict (dict): Passed as kwargs to all fetch ops.

        Returns:
            dict: name: results pairs.
        """
        def fn(fetch_fn):
            return fetch_fn(self.model, **feed_dict)

        results = walk(fetches, fn)

        return results
Example #28
    def run(self, fetches, feed_dict):
        """Runs all fetch ops and stores the results.

        Parameters
        ----------
        fetches : dict
            name: Callable pairs.
        feed_dict : dict
            Passed as kwargs to all fetch ops.

        Returns
        -------
        dict
            name: results pairs.
        """
        def fn(fetch_fn):
            return fetch_fn(self.model, **feed_dict)

        results = walk(fetches, fn)

        return results
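A minimal sketch of the fetches pattern shown in the two examples above: walk maps every Callable leaf of the (possibly nested) fetches dict to its result. The model and fetch functions are stand-ins, assuming edflow.util.walk.

from edflow.util import walk  # assumed import path

model = object()  # stand-in for self.model
feed_dict = {"x": 3}

fetches = {
    "metrics": {"double": lambda model, x: 2 * x},
    "echo": lambda model, x: x,
}

def fn(fetch_fn):
    return fetch_fn(model, **feed_dict)

results = walk(fetches, fn)
print(results)  # {'metrics': {'double': 6}, 'echo': 3}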
Example #29
    def make_feeds(self, batch):
        # copy of batches
        feeds = walk(batch, lambda val: val)
        return feeds
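Hedged note on the example above: with an identity function and the default inplace=False, walk is expected to return a rebuilt nested structure, which the hook uses as a copy of the batch. That the containers are rebuilt rather than returned as-is is an assumption based on the tests earlier in this listing.

from edflow.util import walk  # assumed import path

batch = {"a": [1, 2], "b": {"c": 3}}
feeds = walk(batch, lambda val: val)
print(feeds)           # {'a': [1, 2], 'b': {'c': 3}}
print(feeds is batch)  # expected False -- a new top-level container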
Example #30
def sizes(t):
    return walk(t, sizes_)