def _h5files(path: str, sino_fn='analytical_phantom_sinogram.h5', recon_fn='recons.h5', recon_ms_fn='recon_multi_scale_merged.h5'):
    """Return the sinogram, recon and multi-scale-recon HDF5 file paths under *path* as strings."""
    from dxpy.core.path import Path
    base = Path(path)
    return tuple(str(base / fn) for fn in (sino_fn, recon_fn, recon_ms_fn))
class TensorSpec:
    """Specification of an HDF5 dataset location plus creation parameters.

    ``path`` may address nested groups, e.g. ``'grp/sub/data'``: all leading
    components are groups and the last component is the dataset name.
    """

    def __init__(self, path='data', shape=None, dtype=None, chunks=None):
        self._path = Path(path)
        self.shape = shape
        self.dtype = dtype
        self.chunks = chunks

    @property
    def dataset(self):
        """Name of the dataset (last path component)."""
        return self._path.basename

    @property
    def groups(self):
        """Tuple of group names leading to the dataset."""
        return self._path.parts()[:-1]

    def get_group(self, h5py_file, *, restrict=False):
        """Return the group holding the dataset.

        With ``restrict=True`` the groups must already exist (KeyError
        otherwise); by default missing groups are created via require_group.
        """
        ds = h5py_file
        for g in self.groups:
            if restrict:
                ds = ds[g]
            else:
                ds = ds.require_group(g)
        return ds

    def create_dateset(self, h5py_file, data=None):
        """Create the dataset in *h5py_file*, from this spec or from *data*."""
        gp = self.get_group(h5py_file)
        if data is None:
            # BUGFIX: was ``gp.create_dateset(...)`` (no such h5py method) and
            # used a bare ``chunks`` name (NameError) instead of self.chunks.
            return gp.create_dataset(self.dataset, shape=self.shape, dtype=self.dtype, chunks=self.chunks)
        else:
            return gp.create_dataset(self.dataset, data=data)

    # Correctly spelled alias; the misspelled name is kept for backward compatibility.
    create_dataset = create_dateset

    def get_dataset(self, h5py_file):
        """Return the existing dataset object from *h5py_file*."""
        gp = self.get_group(h5py_file)
        # BUGFIX: was ``gp[self.basename]`` — TensorSpec has no ``basename``
        # attribute; the dataset name is exposed via the ``dataset`` property.
        return gp[self.dataset]
def _update_statistics(dataset_config_name, low_dose_ratio):
    """Scale the mean/std stored under *dataset_config_name* by 1/low_dose_ratio.

    Mutates the global config in place and returns the updated (mean, std).
    """
    from dxpy.learn.config import config
    from dxpy.core.path import Path
    # Accept either a 'a/b/c' path string or an already-split key sequence.
    if isinstance(dataset_config_name, str):
        dataset_config_name = Path(dataset_config_name).parts()
    node = config
    for key in dataset_config_name:
        node = node[key]
    mean, std = node['mean'], node['std']
    mean /= low_dose_ratio
    std /= low_dose_ratio
    node['mean'], node['std'] = mean, std
    return mean, std
def __enter__(self):
    """Start a MATLAB engine, add the dxmat scripts to its path, and make it the default engine."""
    global default_engine
    import os
    import matlab.engine
    from dxpy.core.path import Path
    self.eng = matlab.engine.start_matlab()
    # Script root comes from the environment; optionally narrowed to a sub path.
    script_root = Path(os.environ['PATH_DXL_DXMAT'])
    if self.sub_path is not None:
        script_root = script_root / self.sub_path
    self.eng.addpath(str(script_root))
    # Remember the previous default engine so it can be restored on exit.
    self.pre = default_engine
    default_engine = self.eng
    return self.eng
def test_mid_tilde_support(self):
    """A tilde in the middle of a path must be kept verbatim (no home expansion)."""
    result = Path('a/x~y')
    self.assertEqual(str(result), 'a/x~y')
def test_tilde_support(self):
    """A leading tilde expands to the user's home directory."""
    import os
    home = os.environ['HOME']
    self.assertEqual(str(Path('~/x')), home + '/x')
def test_div(self):
    """The / operator joins a path component."""
    base = Path('/tmp')
    joined = base / 'sub'
    self.assertEqual(joined.abs, '/tmp/sub')
def test_copy_init(self):
    """Constructing a Path from another Path yields an equivalent path."""
    original = Path('/tmp/file')
    duplicate = Path(original)
    self.assertEqual(original.abs, duplicate.abs)
def test_name_dir(self):
    """basename ignores a trailing slash."""
    trailing = Path('/tmp/base/')
    self.assertEqual(trailing.basename, 'base')
def test_parent(self):
    """parent() returns the containing directory as a Path."""
    child = Path('/tmp/base')
    self.assertEqual(child.parent(), Path('/tmp'))
class ConfigsView:
    """Read-only hierarchical view over a nested configuration dict.

    Keys are addressed with slash-separated paths (or key sequences) relative
    to ``base_path``.  When a key is missing under the view's base, lookup
    falls back to ancestor paths (closest first), so nested views inherit
    values from enclosing scopes.
    """

    def __init__(self, dct, base_path=''):
        self.data = dct
        self.base = Path(base_path)

    def __unified_keys(self, path_or_keys):
        # Normalize input into (base key list, local key list).
        if isinstance(path_or_keys, (list, tuple)):
            return list(self.base.parts()), list(path_or_keys)
        else:
            return list(self.base.parts()), list(Path(path_or_keys).parts())

    def _get_value_raw(self, keys):
        # Walk the nested mapping; None if any step is missing or not mapping-like.
        result = self.data
        for k in keys:
            if not isinstance(result, (dict, ConfigsView)):
                return None
            result = result.get(k)
        return result

    def _form_path(self, keys):
        return '/'.join(keys)

    def _query(self, base_keys, local_keys):
        # (A block of commented-out dead code was removed here.)
        return self._get_value_raw(base_keys + local_keys)

    def _search(self, key):
        """Resolve *key*; on a miss, retry the last key component from an ancestor scope."""
        base_keys, local_keys = self.__unified_keys(key)
        result = self._query(base_keys, local_keys)
        path = self._form_path(base_keys + local_keys)
        if result is None and len(base_keys) > 0:
            # Drop the last two components and look the final key up there.
            new_path = self._form_path((base_keys + local_keys)[:-2])
            result, _ = ConfigsView(self.data, new_path)._search(local_keys[-1])
        return result, path

    def _post_processing(self, result, path, default, restrict):
        # dicts (and unrestricted misses without a default) become sub-views.
        if isinstance(result, dict) or (result is None and default is None and not restrict):
            result = ConfigsView(self.data, path)
        elif result is None:
            result = default
        return result

    def to_dict(self):
        """Return the raw dict at the view's base; ValueError if it is not a dict."""
        result, _ = self._search(str(self.base))
        if isinstance(result, dict):
            return result
        else:
            raise ValueError("Configs view on base path is not dict.")

    def get(self, key, default=None, *, restrict=False):
        result, path = self._search(key)
        return self._post_processing(result, path, default, restrict)

    def __getitem__(self, key):
        return self.get(key, restrict=True)

    def __iter__(self):
        dct = self._get_value_raw(self.base.parts())
        if dct is None:
            return list().__iter__()
        else:
            return dct.__iter__()
def test_str_root(self):
    """The filesystem root keeps its canonical form."""
    root = Path('/')
    self.assertEqual(root.abs, '/')
def __init__(self, path='data', shape=None, dtype=None, chunks=None):
    """Record the HDF5 target path plus optional shape/dtype/chunking hints."""
    # ``path`` may address nested groups, e.g. 'grp/sub/data'.
    self._path = Path(path)
    self.shape, self.dtype, self.chunks = shape, dtype, chunks
def __init__(self, dct, base_path=''):
    """Wrap *dct* and anchor all lookups at *base_path*."""
    self.data = dct
    self.base = Path(base_path)
class Graph:
    """
    Base class of components.

    A `Graph` is a dict-like collection of nodes and edges.  A node can be a
    tensor or another sub `Graph`; edges represent relations between sub
    graphs/tensors and the flow of information.

    `Graph` generalizes `tf.Graph` and is designed for the following features:

    1. a unified interface to `tf.Graph` and general compute graphs or
       operations/procedures;
    2. separation of config and implementation: use TreeDict for configs and
       support multiple ways of configuration;
    3. an easy way to separate/reuse subgraphs;
    4. a wrapper over session.run / normal python functions (tasks).

    Please add member methods for tasks and register them to ``tasks``.

    Config module:
        The configs module supports an external hierarchical config, so with
        its name a graph can load/search configs.  Prefer configuration over
        directly passing arguments to __init__.  For sub-models, in most cases
        do not pass arguments from the parent graph; use the config system
        directly.  Exceptions: simple/frequently used blocks may accept a few
        arguments, and a parent graph may forward arguments that many child
        models share.

    In most cases, Graphs should communicate (connect) with each other via
    dicts of Tensors/Graphs.

    Methods:
        - as_tensor(self): returns self.nodes['main']; designed for sub_graphs.
        - get_feed_dict(self, task=None): returns a feed_dict which can be used
          to update the parent graph's get_feed_dict() or to run a task.  It is
          used to guarantee output nodes (if Tensors) are valid under certain
          tasks; if task is None, the feed_dict should make all nodes valid.
        - run(self, task_name, feeds): call a registered function.
          TODO: make it different from directly calling the function; provide
          default feeds / unified feeds, etc.

    Properties:
        name: Path, used for: 1. loading configs from the global config object;
        2. its basename is sometimes used as a namescope/variable name in
        TensorFlow.
    """

    def __init__(self, name, **config):
        self.name = Path(name)
        self.c = self.__load_config(config)
        self.nodes = dict()
        self.tasks = dict()

    # Methods to be overridden:
    @classmethod
    def _default_config(cls):
        """Override this method to provide default configs."""
        return dict()

    def __hash__(self):
        return self.name.__hash__()

    def keys(self):
        return self.nodes.keys()

    def values(self):
        return self.nodes.values()

    def items(self):
        return self.nodes.items()

    def __iter__(self):
        return self.nodes.__iter__()

    def _externel_feeds(self):
        # Hook for subclasses: extra feed entries that are always required.
        return dict()

    def get_feed_dict(self, feeds=None, task=None):
        """
        Return a feed dict for this graph to work for specific tasks.

        In most cases it works as a translator from `feeds` (keyed by node
        name or by node object) to a feed_dict keyed by tf.Tensor.
        """
        result = self._externel_feeds()
        if feeds is None:
            return result
        for n in self.nodes:
            if n in feeds:
                result[self.tensor(n)] = feeds[n]
            if self.nodes[n] in feeds:
                result[self.tensor(n)] = feeds[self.nodes[n]]
        return result

    def _print_config_kernel(self, recursive, indent, fout):
        # BUGFIX: the original format call omitted the ``fullname`` argument
        # (KeyError at runtime), called the undefined self.parse_name(), and
        # never printed the computed title.
        title = "{ind}>{cls}:{name}({fullname})".format(
            ind=" " * indent, cls=type(self).__name__,
            name=self.basename, fullname=self.name)
        print(title, file=fout)
        indent_sub = indent + 4
        for k in self.nodes:
            if isinstance(self.nodes[k], tf.Tensor):
                print('{ind}tf.Tensor:{name}({sp})'.format(
                    ind=" " * indent_sub, name=k, sp=self.nodes[k].shape),
                    file=fout)
            elif isinstance(self.nodes[k], Graph):
                if recursive:
                    # BUGFIX: arguments were previously passed in the wrong
                    # order (recursive landed in the fout slot and vice versa).
                    self.nodes[k].print_config(fout, recursive, indent_sub)
                else:
                    print('{ind}Graph:{name}({sub_name})'.format(
                        ind=" " * indent_sub, name=k,
                        sub_name=self.nodes[k].name), file=fout)

    @property
    def basename(self):
        """
        Get the base name of graph name.

        Useful for variable_scope or name_scope of graph.
        """
        return self.name.basename

    def register_node(self, name=None, tensor_or_subgraph=None):
        from .utils import refined_tensor_or_graph_name
        if tensor_or_subgraph is None:
            # Single-argument call: derive the node name from the node itself.
            tensor_or_subgraph = name
            name = refined_tensor_or_graph_name(tensor_or_subgraph)
        self.nodes[name] = tensor_or_subgraph

    def register_task(self, name, func):
        self.tasks[name] = func

    def register_main_node(self, tensor_or_subgraph):
        self.register_node(NodeKeys.MAIN, tensor_or_subgraph)

    def register_main_task(self, func):
        self.register_task(NodeKeys.MAIN, func)

    def create_variable_node(self, dtype, shape, name, *, trainable=False, init_value=None):
        """Create a tf variable, register it as a node, and return it."""
        if init_value is not None:
            initer = tf.constant_initializer(init_value)
        else:
            initer = None
        self.register_node(name, tf.get_variable(
            name, shape, dtype, trainable=trainable, initializer=initer))
        return self.nodes[name]

    def create_placeholder_node(self, dtype, shape, name):
        """Create a tf placeholder, register it as a node, and return it."""
        self.register_node(name, tf.placeholder(dtype, shape, name))
        return self.nodes[name]

    def create_sub_graph_node(self, name, create_func):
        """Register the result of ``create_func()`` as a sub-graph node."""
        self.register_node(name, create_func())
        return self.nodes[name]

    def param(self, key, feeds=None, *, default=None, raise_key_error=True):
        """
        Best practice: always use param instead of directly using self.c.

        Lookup order: *feeds* first, then the config view, then *default*.
        """
        if isinstance(feeds, dict) and key in feeds:
            return feeds[key]
        result = self.c.get(key, default)
        if result is None and raise_key_error:
            raise KeyError(key)
        return result

    def tensor(self, key=None):
        """Return the tf.Tensor registered under *key* (or the main tensor)."""
        if key is None:
            return self.as_tensor()
        if key not in self.nodes:
            # Allow reverse lookup by tf.Tensor value.
            if isinstance(key, tf.Tensor):
                for n in self.nodes:
                    if self.nodes[n] == key:
                        return self.nodes[n]
            # BUGFIX: the original fell through to self.nodes[key] here and
            # raised a bare KeyError instead of this informative one.
            raise KeyError("Key {} not found in graph {}.".format(key, self.name))
        if isinstance(self.nodes[key], tf.Tensor):
            return self.nodes[key]
        elif isinstance(self.nodes[key], Graph):
            return self.nodes[key].as_tensor()
        raise KeyError("Key {} not found in graph {}.".format(key, self.name))

    def graph(self, key=None):
        """Return the sub-Graph registered under *key* (or the main node)."""
        if key is None:
            key = NodeKeys.MAIN
        if key in self.nodes:
            if isinstance(self.nodes[key], Graph):
                return self.nodes[key]
            else:
                raise TypeError(
                    "Type of self.nodes[{}] of in graph {} is {}, required Graph."
                    .format(key, self.name, type(self.nodes[key])))
        raise KeyError("Key {} not found in graph {}.".format(key, self.name))

    def print_config(self, fout=None, recursive=False, indent=0):
        self._print_config_kernel(recursive, indent, fout)

    def __getitem__(self, name):
        # Nodes shadow tasks; unknown names yield None rather than raising.
        if name in self.nodes:
            return self.nodes[name]
        elif name in self.tasks:
            return self.tasks[name]
        else:
            return None

    def __call__(self, feeds=None):
        return self.run(NodeKeys.MAIN, feeds)

    def run(self, task_name, feeds=None):
        return self.tasks[task_name](feeds)

    def as_tensor(self):
        return self.nodes.get(NodeKeys.MAIN)

    def __load_config(self, config_direct):
        # Direct kwargs take precedence over the global config, which takes
        # precedence over class defaults.
        from .config import config as config_global
        from dxpy.collections.dicts import combine_dicts
        return combine_dicts(config_direct, config_global, self._default_config())
def __init__(self, name, **config):
    """Create a graph named *name*; remaining kwargs become direct configs."""
    self.name = Path(name)
    self.c = self.__load_config(config)
    self.nodes = {}
    self.tasks = {}
def __unified_keys(self, path_or_keys):
    """Normalize *path_or_keys* into (base key list, local key list)."""
    base_keys = list(self.base.parts())
    if isinstance(path_or_keys, (list, tuple)):
        local_keys = list(path_or_keys)
    else:
        local_keys = list(Path(path_or_keys).parts())
    return base_keys, local_keys
def test_str_basic(self):
    """A plain absolute path round-trips unchanged."""
    simple = Path('/tmp')
    self.assertEqual(simple.abs, '/tmp')
def test_parts2(self):
    """parts() splits a relative path into its components."""
    relative = Path('tmp/base')
    self.assertEqual(relative.parts(), ('tmp', 'base'))
def _h5_file(path: str, file_name=DEFAULT_FILE_NAME):
    """Return the path of *file_name* inside directory *path* as a string."""
    from dxpy.core.path import Path
    full_path = Path(path) / file_name
    return str(full_path)
def h5filename():
    """Default HDF5 file path under the configured datasets directory."""
    from ..config import config
    datasets_root = Path(config['PATH_DATASETS'])
    return str(datasets_root / DEFAULT_FILE_NAME)