예제 #1
0
파일: graph.py 프로젝트: zaf05/mayo
 def _add_layer(self, from_names, to_names, layer_name, layer_params,
                module_path):
     """Wire a single layer (or a nested module) between named tensors.

     The input tensors named by ``from_names`` feed the layer; several of
     them are first merged through a ``JoinNode``.  When ``layer_name`` is
     given a ``LayerNode`` sits after the (joined) input, otherwise the input
     connects straight to the outputs.  The output tensors named by
     ``to_names`` are fed directly when there is one, or fanned out through a
     ``SplitNode`` when there are several.  Layer params whose ``'type'`` is
     ``'module'`` are handed off to ``self._add_module`` instead.
     """
     inputs = ensure_list(from_names)
     outputs = ensure_list(to_names)
     is_module = layer_params is not None and layer_params['type'] == 'module'
     if is_module:
         return self._add_module(
             inputs, outputs, layer_name, layer_params, module_path)

     # Input side: one tensor feeds directly, several are joined first.
     sources = [TensorNode(name, module_path, self) for name in inputs]
     if len(sources) != 1:
         head = JoinNode(inputs, module_path, self)
         for source in sources:
             self.add_edge(source, head)
     else:
         head = sources[0]

     # Optional layer node between the inputs and the outputs.
     if layer_name is not None:
         layer = LayerNode(layer_name, layer_params, module_path, self)
         self.add_edge(head, layer)
         head = layer

     # Output side: one tensor is fed directly, several go through a split.
     sinks = [TensorNode(name, module_path, self) for name in outputs]
     if len(sinks) == 1:
         self.add_edge(head, sinks[0])
         return
     splitter = SplitNode(outputs, module_path, self)
     self.add_edge(head, splitter)
     for sink in sinks:
         self.add_edge(splitter, sink)
예제 #2
0
파일: graph.py 프로젝트: zaf05/mayo
 def _add_module(self,
                 from_names,
                 to_names,
                 module_name,
                 module_params,
                 module_path=None):
     """Expand a module definition into the graph as a nested sub-graph.

     Layers and tensors declared inside the module are placed under
     ``submodule_path`` (``module_path`` plus ``module_name`` when given).
     The caller's ``from_names`` tensors are wired to the module's declared
     ``inputs`` (default ``['input']``), and the module's declared
     ``outputs`` (default ``['output']``) are wired to ``to_names``.

     :param from_names: caller-side input tensor name(s).
     :param to_names: caller-side output tensor name(s).
     :param module_name: name appended to the path for this module's
         contents, or None to keep them at ``module_path``.
     :param module_params: module definition; after
         ``_replace_module_kwargs`` it must provide ``'graph'`` and, for
         every layer a connection references, ``'layers'``.
     :param module_path: sequence of enclosing module names; defaults to [].
     :raises EdgeError: if a hop into a layer reads and writes the same
         tensor name.
     :raises KeyError: if a connection references an undefined layer.
     :raises GraphIOError: (from ``self._ensure_connection``) if some caller
         input cannot reach some caller output.
     """
     from_names = ensure_list(from_names)
     to_names = ensure_list(to_names)
     # replace kwargs in module params
     params = _replace_module_kwargs(module_params)
     # module path: the module's own contents live one level deeper
     module_path = module_path or []
     submodule_path = list(module_path)
     if module_name is not None:
         submodule_path += [module_name]
     # add graph connections
     for connection in ensure_list(params['graph']):
         # A connection ``from -> with[0] -> ... -> with[-1] -> to`` expands
         # into (input, output, layer) hops: each layer in ``with`` reads the
         # previous tensor and writes a tensor named after itself, and a
         # final layer-less hop (layer None) links the last tensor to ``to``.
         with_layers = ensure_list(connection.get('with') or [])
         edges = list(
             zip([connection['from']] + with_layers,
                 with_layers + [connection['to']], with_layers + [None]))
         for input_names, output_names, layer_name in edges:
             if input_names == output_names:
                 # a layer-less hop onto the same tensor is a no-op
                 if layer_name is None:
                     continue
                 raise EdgeError(
                     'Input name {!r} collides with output name {!r} '
                     'for layer {!r}.'.format(input_names, output_names,
                                              layer_name))
             layer_params = None
             if layer_name is not None:
                 try:
                     layer_params = params['layers'][layer_name]
                 except KeyError:
                     raise KeyError(
                         'Layer named {!r} is not defined.'.format(
                             layer_name))
             self._add_layer(input_names, output_names, layer_name,
                             layer_params, submodule_path)
     # add interface IO: wire the caller's tensors to the module's declared
     # input/output tensors.  The declarations may be written as a single
     # name or as a list, so normalize them like ``from_names``/``to_names``;
     # zipping over a bare string would pair names with its characters.
     # NOTE: ``zip`` stops at the shorter side, so surplus names on either
     # side are left unconnected without an error.
     from_nodes = []
     module_inputs = ensure_list(params.get('inputs', ['input']))
     for from_name, input_name in zip(from_names, module_inputs):
         from_node = TensorNode(from_name, module_path, self)
         from_nodes.append(from_node)
         input_node = TensorNode(input_name, submodule_path, self)
         self.add_edge(from_node, input_node)
     to_nodes = []
     module_outputs = ensure_list(params.get('outputs', ['output']))
     for output_name, to_name in zip(module_outputs, to_names):
         output_node = TensorNode(output_name, submodule_path, self)
         to_node = TensorNode(to_name, module_path, self)
         to_nodes.append(to_node)
         self.add_edge(output_node, to_node)
     # ensure every caller input reaches every caller output
     self._ensure_connection(from_nodes, to_nodes)
예제 #3
0
파일: graph.py 프로젝트: zaf05/mayo
 def _ensure_connection(self, from_nodes, to_nodes):
     """Check that every input node has a directed path to every output node.

     :param from_nodes: a node, or a list of nodes, used as path sources.
     :param to_nodes: a node, or a list of nodes, used as path targets.
     :raises GraphIOError: if some (input, output) pair of ``self.nx_graph``
         is not connected by a directed path.  The message lists the
         connected components of the undirected graph to help locate the
         break.
     """
     iterator = itertools.product(ensure_list(from_nodes),
                                  ensure_list(to_nodes))
     for i, o in iterator:
         # ``nx.has_path`` runs a single linear-time search.  The previous
         # ``any(nx.all_simple_paths(...))`` enumerated simple paths by DFS
         # without a global visited set, which can take exponential time
         # before the first path is found (or all are ruled out).
         if nx.has_path(self.nx_graph, i, o):
             continue
         undirected = self.nx_graph.to_undirected()
         subgraphs = pprint.pformat(
             list(nx.connected_components(undirected)))
         raise GraphIOError(
             'We expect the net to have a path from the inputs '
             'to the outputs, a path does not exist between {} and {}. '
             'Disjoint subgraph nodes:\n{}'.format(i, o, subgraphs))
예제 #4
0
파일: generate.py 프로젝트: randysuen/mayo
 def _actions(self, key):
     """Return the entries stored in ``self.actions`` under ``key`` as a list.

     A missing key or a falsy entry is treated as an empty list; any other
     value is normalized with ``ensure_list``.
     """
     found = self.actions.get(key)
     if not found:
         found = []
     return ensure_list(found)