Пример #1
0
def _h5files(path: str,
             sino_fn='analytical_phantom_sinogram.h5',
             recon_fn='recons.h5',
             recon_ms_fn='recon_multi_scale_merged.h5'):
    """Build the three HDF5 file paths that live in directory *path*.

    Returns:
        A 3-tuple of strings: (sinogram file, reconstruction file,
        merged multi-scale reconstruction file), each joined onto *path*.
    """
    from dxpy.core.path import Path
    file_names = (sino_fn, recon_fn, recon_ms_fn)
    return tuple(str(Path(path) / name) for name in file_names)
Пример #2
0
class TensorSpec:
    """Location, shape, dtype and chunking of one HDF5 dataset.

    ``path`` is split into the groups leading to the dataset (all path parts
    except the last) and the dataset name (the path's basename).
    """

    def __init__(self, path='data', shape=None, dtype=None, chunks=None):
        self._path = Path(path)
        self.shape = shape
        self.dtype = dtype
        self.chunks = chunks

    @property
    def dataset(self):
        """Name of the dataset: the basename of the spec's path."""
        return self._path.basename

    @property
    def groups(self):
        """Group names leading to the dataset: every path part but the last."""
        return self._path.parts()[:-1]

    def get_group(self, h5py_file, *, restrict=False):
        """Return the group that contains the dataset.

        Args:
            h5py_file: Root file/group object to start from.
            restrict: If False, missing groups are created via
                ``require_group``. If True, groups are only looked up with
                ``[]``, so a missing group raises instead of being created.
        """
        group = h5py_file
        for name in self.groups:
            group = group[name] if restrict else group.require_group(name)
        return group

    def create_dataset(self, h5py_file, data=None):
        """Create the dataset described by this spec inside ``h5py_file``.

        Missing parent groups are created. When ``data`` is None, an empty
        dataset is created from ``shape``/``dtype``/``chunks``. Otherwise the
        dataset is initialised from ``data`` alone.

        Returns:
            Whatever the group's ``create_dataset`` returns.
        """
        group = self.get_group(h5py_file)
        if data is None:
            return group.create_dataset(self.dataset,
                                        shape=self.shape,
                                        dtype=self.dtype,
                                        chunks=self.chunks)
        return group.create_dataset(self.dataset, data=data)

    # Backward-compatible alias for the original, misspelled method name.
    create_dateset = create_dataset

    def get_dataset(self, h5py_file):
        """Return the existing dataset described by this spec.

        Groups are looked up without being created, so reading a missing
        dataset does not leave empty groups behind in the file.
        """
        group = self.get_group(h5py_file, restrict=True)
        return group[self.dataset]
Пример #3
0
def _update_statistics(dataset_config_name, low_dose_ratio):
    """Divide the stored mean/std of a dataset config entry by *low_dose_ratio*.

    ``dataset_config_name`` is either a path-like string (split into parts)
    or an iterable of keys used to walk down the global ``config``. The
    scaled values are written back into that config node.

    Returns:
        The scaled ``(mean, std)`` pair.
    """
    from dxpy.learn.config import config
    from dxpy.core.path import Path
    if isinstance(dataset_config_name, str):
        keys = Path(dataset_config_name).parts()
    else:
        keys = dataset_config_name
    node = config
    for key in keys:
        node = node[key]
    mean = node['mean']
    std = node['std']
    # In-place division is kept on purpose: for mutable array values it
    # updates the stored object itself, not just a copy.
    mean /= low_dose_ratio
    std /= low_dose_ratio
    node['mean'] = mean
    node['std'] = std
    return mean, std
Пример #4
0
 def __enter__(self):
     """Start a MATLAB engine and install it as the module's default engine.

     The directory named by the ``PATH_DXL_DXMAT`` environment variable is
     added to the engine's MATLAB path. When ``self.sub_path`` is not None,
     it is joined onto that directory first. The previous value of the global
     ``default_engine`` is saved on ``self.pre``.

     Returns:
         The started engine, also stored as ``self.eng``.

     Raises:
         KeyError: if ``PATH_DXL_DXMAT`` is not set.
     """
     global default_engine
     import matlab.engine
     import os
     from dxpy.core.path import Path
     self.eng = matlab.engine.start_matlab()
     matlab_dir = Path(os.environ['PATH_DXL_DXMAT'])
     if self.sub_path is not None:
         matlab_dir = matlab_dir / self.sub_path
     self.eng.addpath(str(matlab_dir))
     self.pre, default_engine = default_engine, self.eng
     return self.eng
Пример #5
0
 def test_mid_tilde_support(self):
     """A '~' inside a path component is kept literally, not expanded."""
     self.assertEqual(str(Path('a/x~y')), 'a/x~y')
Пример #6
0
 def test_tilde_support(self):
     """A leading '~' is expanded to the value of $HOME."""
     home_relative = Path('~/x')
     import os
     self.assertEqual(str(home_relative), os.environ['HOME'] + '/x')
Пример #7
0
 def test_div(self):
     """The '/' operator appends a child component to the path."""
     joined = Path('/tmp') / 'sub'
     self.assertEqual(joined.abs, '/tmp/sub')
Пример #8
0
 def test_copy_init(self):
     """Constructing a Path from another Path yields the same absolute path."""
     source = Path('/tmp/file')
     duplicate = Path(source)
     self.assertEqual(source.abs, duplicate.abs)
Пример #9
0
 def test_name_dir(self):
     """A trailing slash does not change the basename."""
     self.assertEqual(Path('/tmp/base/').basename, 'base')
Пример #10
0
 def test_parent(self):
     """parent() drops the last path component."""
     self.assertEqual(Path('/tmp/base').parent(), Path('/tmp'))
Пример #11
0
class ConfigsView:
    """View into a nested config dict, rooted at ``base_path``.

    Keys are resolved relative to the base path. When a key is not found
    there, ``_search`` retries the last key component against a shorter
    prefix of the path, recursively. This makes values defined higher up the
    tree visible from deeper views.
    """

    def __init__(self, dct, base_path=''):
        self.data = dct
        self.base = Path(base_path)

    def __unified_keys(self, path_or_keys):
        """Return ``(base_keys, local_keys)`` as lists.

        ``path_or_keys`` may be a list/tuple of keys or a path string.
        """
        if isinstance(path_or_keys, (list, tuple)):
            return list(self.base.parts()), list(path_or_keys)
        else:
            return list(self.base.parts()), list(Path(path_or_keys).parts())

    def _get_value_raw(self, keys):
        """Walk ``keys`` from the root data.

        Returns None if a key is missing or a non-mapping node is reached
        before all keys are consumed.
        """
        result = self.data
        for k in keys:
            if not isinstance(result, (dict, ConfigsView)):
                return None
            result = result.get(k)
        return result

    def _form_path(self, keys):
        """Join key components into a '/'-separated path string."""
        return '/'.join(keys)

    def _query(self, base_keys, local_keys):
        """Look up the full key path ``base_keys + local_keys`` directly."""
        return self._get_value_raw(base_keys + local_keys)

    def _search(self, key):
        """Resolve ``key`` relative to the base path, with ancestor fallback.

        Returns:
            ``(result, path)``. ``result`` is the found value or None.
            ``path`` is the requested full path, even when the value was found
            through the fallback.
        """
        base_keys, local_keys = self.__unified_keys(key)
        result = self._query(base_keys, local_keys)
        path = self._form_path(base_keys + local_keys)
        # Fallback needs a last key component to retry; an empty key has none.
        if result is None and base_keys and local_keys:
            new_path = self._form_path((base_keys + local_keys)[:-2])
            result, _ = ConfigsView(self.data,
                                    new_path)._search(local_keys[-1])
        return result, path

    def _post_processing(self, result, path, default, restrict):
        """Shape a raw lookup result for callers.

        - A dict result is wrapped in a ``ConfigsView`` on ``path``.
        - A missing result becomes an (empty) view on ``path``, unless
          ``restrict`` is set or a non-None ``default`` is given; in that
          case ``default`` is returned.
        """
        if isinstance(result, dict) or (result is None and default is None
                                        and not restrict):
            result = ConfigsView(self.data, path)
        elif result is None:
            result = default
        return result

    def to_dict(self):
        """Return the raw dict at this view's base path.

        The base path is looked up directly first. It used to be searched
        relative to itself (i.e. at ``base/base``), which could return an
        unrelated dict. If the direct lookup is not a dict, the same ancestor
        fallback as ``get`` is applied to the base's last component.

        Raises:
            ValueError: if no dict is found for the base path.
        """
        base_keys = list(self.base.parts())
        result = self._get_value_raw(base_keys)
        if not isinstance(result, dict) and base_keys:
            parent = ConfigsView(self.data, self._form_path(base_keys[:-1]))
            result, _ = parent._search(base_keys[-1])
        if isinstance(result, dict):
            return result
        raise ValueError("Configs view on base path is not dict.")

    def get(self, key, default=None, *, restrict=False):
        """Look up ``key`` (path string or key sequence) relative to the base."""
        result, path = self._search(key)
        return self._post_processing(result, path, default, restrict)

    def __getitem__(self, key):
        # restrict=True: a missing key yields None instead of an empty view.
        return self.get(key, restrict=True)

    def __iter__(self):
        """Iterate over the keys of the node at the base path (empty if missing)."""
        dct = self._get_value_raw(self.base.parts())
        if dct is None:
            return list().__iter__()
        else:
            return dct.__iter__()
Пример #12
0
 def test_str_root(self):
     """The filesystem root's absolute path is '/'."""
     self.assertEqual(Path('/').abs, '/')
Пример #13
0
 def __init__(self, path='data', shape=None, dtype=None, chunks=None):
     """Store the dataset path (as a Path) and its shape/dtype/chunks."""
     self._path = Path(path)
     self.shape, self.dtype, self.chunks = shape, dtype, chunks
Пример #14
0
 def __init__(self, dct, base_path=''):
     """Wrap config dict ``dct`` as a view rooted at ``base_path``."""
     self.data = dct
     self.base = Path(base_path)  # base path kept as a Path object
Пример #15
0
class Graph:
    """
    Base class of components.

    A `Graph` is a dict-like collection of nodes and edges. Nodes of a `Graph` can be tensors or other sub `Graph`s.
    Edges represent relations between sub graphs/tensors, i.e. the flow of information.

    A `Graph` is a generalization of `tf.Graph`, designed for the following features:
        1. A unified interface for `tf.Graph` and general compute graphs or operations/procedures;
        2. Separate config and implementation, use TreeDict for configs, and support multiple ways of configuring;
        3. An easy-to-use way to separate/reuse subgraphs;
        4. Support a wrapper around session.run/normal python functions.
            Please add member methods for tasks, and register them to tasks.

    Config module:
    The config module is designed to support external hierarchical configs, so that, using the name of the graph,
    a graph can load/search its configs. This is designed to reduce the complexity of configuring complex networks.
    To achieve this, prefer configuration over directly passing arguments to init functions.
    For sub-models, in most cases do not pass arguments through the parent graph; use the config system directly.
    There are some exceptions to this design: for some simple/frequently used blocks, some simple arguments
    may be passed to simplify configuration. Another situation is a parent with many child models that
    all share simple but identical configurations; you may pass the shared arguments through the parent Graph.

    In most cases, a Graph should communicate (connect) with others via dicts of Tensors/Graphs.


    Methods:

    -   as_tensor(self):
        return self.nodes['main'], which is designed for sub_graphs.

    -   get_feed_dict(self, task=None):
        A method which returns a feed_dict, which can be used to update the parent graph's (in most cases, the graph
        which called subgraph.get_feed_dict()) get_feed_dict() or to run a task.
        It is used to guarantee that output nodes (if they are Tensors) are valid under certain tasks; if task is None,
        a feed_dict should be provided so that all nodes are valid.

    -   run(self, task_name, feeds):
        Call a registered function. # TODO: make it different from directly call function, provide default feeds / unified feeds, etc.

    Properties:
        name: Path, name is used for:
            1. loading configs from global config object;
            2. its basename sometimes used for namescope/variable name in TensorFlow;
    """
    def __init__(self, name, **config):
        self.name = Path(name)
        self.c = self.__load_config(config)
        self.nodes = dict()
        self.tasks = dict()

    # Methods to be overridden:
    @classmethod
    def _default_config(cls):
        """ Override this method to provide default configs. """
        return dict()

    def __hash__(self):
        return self.name.__hash__()

    def keys(self):
        return self.nodes.keys()

    def values(self):
        return self.nodes.values()

    def items(self):
        return self.nodes.items()

    def __iter__(self):
        return self.nodes.__iter__()

    def _externel_feeds(self):
        """Override to provide feeds that are always part of the feed dict."""
        return dict()

    def get_feed_dict(self, feeds=None, task=None):
        """
        Return feed dict for this graph to work for specific tasks.

        In most cases, it works as a translator from feeds to feed_dict,
        replacing node names (or node objects) used as keys in ``feeds`` with
        the corresponding tensor from ``self.tensor``. Results start from
        ``self._externel_feeds()``.
        """
        result = self._externel_feeds()
        if feeds is None:
            return result
        for n in self.nodes:
            if n in feeds:
                result[self.tensor(n)] = feeds[n]
            if self.nodes[n] in feeds:
                result[self.tensor(n)] = feeds[self.nodes[n]]
        return result

    def _print_config_kernel(self, recursive, indent, fout):
        """Print this graph's header line and one line per tensor/sub-graph node.

        Sub-graphs are expanded recursively when ``recursive`` is true. Nodes
        that are neither sub-graphs nor tensors are not printed.
        """
        title = "{ind}>{cls}:{name}({fullname})".format(ind=" " * indent,
                                                        cls=type(self).__name__,
                                                        name=self.basename,
                                                        fullname=self.name)
        print(title, file=fout)
        indent_sub = indent + 4
        for k, node in self.nodes.items():
            if isinstance(node, Graph):
                if recursive:
                    # Keywords: print_config's positional order is (fout, recursive, indent).
                    node.print_config(fout=fout,
                                      recursive=recursive,
                                      indent=indent_sub)
                else:
                    print('{ind}Graph:{name}({sub_name})'.format(
                        ind=" " * indent_sub, name=k, sub_name=node.name),
                          file=fout)
            elif isinstance(node, tf.Tensor):
                print('{ind}tf.Tensor:{name}({sp})'.format(
                    ind=" " * indent_sub, name=k, sp=node.shape),
                      file=fout)

    @property
    def basename(self):
        """
        Get the base name of graph name. Useful for variable_scope or name_scope of graph.
        """
        return self.name.basename

    def register_node(self, name=None, tensor_or_subgraph=None):
        """Register a node.

        If called with one argument, that argument is the node and its name is
        derived by ``refined_tensor_or_graph_name``.
        """
        if tensor_or_subgraph is None:
            from .utils import refined_tensor_or_graph_name
            tensor_or_subgraph = name
            name = refined_tensor_or_graph_name(tensor_or_subgraph)
        self.nodes[name] = tensor_or_subgraph

    def register_task(self, name, func):
        self.tasks[name] = func

    def register_main_node(self, tensor_or_subgraph):
        self.register_node(NodeKeys.MAIN, tensor_or_subgraph)

    def register_main_task(self, func):
        self.register_task(NodeKeys.MAIN, func)

    def create_variable_node(self,
                             dtype,
                             shape,
                             name,
                             *,
                             trainable=False,
                             init_value=None):
        """Create a ``tf.get_variable`` node named ``name`` and register it.

        A non-None ``init_value`` is used through ``tf.constant_initializer``.
        """
        if init_value is not None:
            initer = tf.constant_initializer(init_value)
        else:
            initer = None
        self.register_node(
            name,
            tf.get_variable(name,
                            shape,
                            dtype,
                            trainable=trainable,
                            initializer=initer))
        return self.nodes[name]

    def create_placeholder_node(self, dtype, shape, name):
        self.register_node(name, tf.placeholder(dtype, shape, name))
        return self.nodes[name]

    def create_sub_graph_node(self, name, create_func):
        self.register_node(name, create_func())
        return self.nodes[name]

    def param(self, key, feeds=None, *, default=None, raise_key_error=True):
        """
        Best practice: always use param instead of directly using self.c

        ``feeds[key]`` takes precedence over the config. A missing/None value
        raises KeyError unless ``raise_key_error`` is False.
        """
        if isinstance(feeds, dict) and key in feeds:
            return feeds[key]
        result = self.c.get(key, default)
        if result is None and raise_key_error:
            raise KeyError(key)
        return result

    def tensor(self, key=None):
        """Return the tensor for node ``key``.

        ``key`` may be a node name or a node object itself (matched by
        identity). Sub-graph nodes are resolved through their ``as_tensor()``.
        With ``key=None`` this graph's main tensor is returned.

        Raises:
            KeyError: if ``key`` matches no node, or the node is neither a
                tensor nor a sub-graph.
        """
        if key is None:
            return self.as_tensor()
        if key not in self.nodes:
            # Allow lookup by the node object itself; compare by identity
            # because tensor ``==`` may be element-wise.
            for node in self.nodes.values():
                if node is key:
                    return node
            raise KeyError("Key {} not found in graph {}.".format(
                key, self.name))
        node = self.nodes[key]
        if isinstance(node, tf.Tensor):
            return node
        elif isinstance(node, Graph):
            return node.as_tensor()
        raise KeyError("Key {} not found in graph {}.".format(key, self.name))

    def graph(self, key=None):
        if key is None:
            key = NodeKeys.MAIN
        if key in self.nodes:
            if isinstance(self.nodes[key], Graph):
                return self.nodes[key]
            else:
                raise TypeError(
                    "Type of  self.nodes[{}] of in graph {} is {}, required Graph."
                    .format(key, self.name, type(self.nodes[key])))
        raise KeyError("Key {} not found in graph {}.".format(key, self.name))

    def print_config(self, fout=None, recursive=False, indent=0):
        """Print a summary of this graph to ``fout`` (stdout when None)."""
        self._print_config_kernel(recursive, indent, fout)

    def __getitem__(self, name):
        if name in self.nodes:
            return self.nodes[name]
        elif name in self.tasks:
            return self.tasks[name]
        else:
            return None

    def __call__(self, feeds=None):
        return self.run(NodeKeys.MAIN, feeds)

    def run(self, task_name, feeds=None):
        return self.tasks[task_name](feeds)

    def as_tensor(self):
        return self.nodes.get(NodeKeys.MAIN)

    def __load_config(self, config_direct):
        from .config import config as config_global
        from dxpy.collections.dicts import combine_dicts
        return combine_dicts(config_direct, config_global,
                             self._default_config())
Пример #16
0
 def __init__(self, name, **config):
     """Set the graph name (as a Path), load its config, and start empty node/task registries."""
     self.name = Path(name)
     self.c = self.__load_config(config)
     self.nodes, self.tasks = {}, {}
Пример #17
0
 def __unified_keys(self, path_or_keys):
     """Return (base key list, local key list) for a key sequence or path string."""
     base_keys = list(self.base.parts())
     if isinstance(path_or_keys, (list, tuple)):
         return base_keys, list(path_or_keys)
     return base_keys, list(Path(path_or_keys).parts())
Пример #18
0
 def test_str_basic(self):
     """An absolute path reports itself unchanged via .abs."""
     self.assertEqual(Path('/tmp').abs, '/tmp')
Пример #19
0
 def test_parts2(self):
     """parts() of a relative path is a tuple of its components."""
     expected = ('tmp', 'base')
     self.assertEqual(Path('tmp/base').parts(), expected)
Пример #20
0
def _h5_file(path: str, file_name=DEFAULT_FILE_NAME):
    """Return ``file_name`` joined onto directory ``path``, as a string."""
    from dxpy.core.path import Path
    target = Path(path) / file_name
    return str(target)
Пример #21
0
def h5filename():
    """Return DEFAULT_FILE_NAME joined onto config['PATH_DATASETS'], as a string."""
    from ..config import config
    datasets_dir = Path(config['PATH_DATASETS'])
    return str(datasets_dir / DEFAULT_FILE_NAME)