コード例 #1
0
    def provide(self, request):
        """Return a batch holding an undirected, edge-less graph of random points.

        The ROI requested for ``self.graph_key`` is handed to
        ``self.random_point_generator.get_random_points``; each ``(id, location)``
        pair it returns becomes one ``gp.Node``. The graph's spec uses that same
        ROI and is marked ``directed=False``.
        """
        requested_roi = request[self.graph_key].roi

        sampled = self.random_point_generator.get_random_points(requested_roi)

        batch = gp.Batch()

        nodes = [
            gp.Node(id=node_id, location=location)
            for node_id, location in sampled.items()
        ]
        graph_spec = gp.GraphSpec(roi=requested_roi, directed=False)
        batch[self.graph_key] = gp.Graph(nodes, [], graph_spec)

        return batch
コード例 #2
0
ファイル: utils.py プロジェクト: funkelab/contraband
    def provide(self, request):
        """Serve a batch of points, either as an Array or as a Graph.

        If ``self.points`` is a ``gp.ArrayKey``, ``self.num_points`` row indices
        are drawn with ``np.random.choice`` (which samples with replacement by
        default) from ``self.data``. The selected rows get a leading singleton
        axis, are optionally multiplied by ``self.scale``, and are stored as a
        ``gp.Array``.

        Otherwise, a boolean mask is built over the rows of ``self.data``. It
        keeps rows whose first ``self.ndims`` coordinates lie inside the
        requested ROI (begin inclusive, end exclusive). The mask is passed to
        ``self._get_points`` and the result is stored as a ``gp.Graph``.

        When ``self.label_data`` is set, the matching labels are also stored
        as a ``gp.Array`` under ``self.labels``.

        Timing of the whole call is recorded in ``batch.profiling_stats``.
        """
        timing = Timing(self)
        timing.start()

        batch = gp.Batch()

        # If an Array is requested, randomly choose the requested number
        # of points.
        if isinstance(self.points, gp.ArrayKey):
            points = np.random.choice(self.data.shape[0], self.num_points)
            data = self.data[points][np.newaxis]
            if self.scale is not None:
                data = data * self.scale
            if self.label_data is not None:
                labels = self.label_data[points]
            batch[self.points] = gp.Array(data, self.spec[self.points])

        else:
            # If a graph is requested, select the points that fall inside
            # the requested ROI.
            roi = request[self.points].roi
            min_bb = roi.get_begin()
            max_bb = roi.get_end()

            logger.debug("Points source got request for %s", roi)

            # `np.bool` was removed in NumPy 1.24. The builtin `bool` gives
            # the same numpy boolean dtype.
            point_filter = np.ones((self.data.shape[0], ), dtype=bool)
            for d in range(self.ndims):
                point_filter = np.logical_and(point_filter,
                                              self.data[:, d] >= min_bb[d])
                point_filter = np.logical_and(point_filter,
                                              self.data[:, d] < max_bb[d])

            points_data, labels = self._get_points(point_filter)
            logger.debug("Found %d points", len(points_data))
            points_spec = gp.GraphSpec(roi=roi.copy())
            batch.graphs[self.points] = gp.Graph(points_data, [], points_spec)

        # Labels are always returned as an Array, whichever branch ran.
        if self.label_data is not None:
            batch[self.labels] = gp.Array(labels, self.spec[self.labels])

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
コード例 #3
0
ファイル: snapshot_source.py プロジェクト: pattonw/neurolight
    def graph_from_path(self, graph_key, data, path):
        """Rebuild a ``gp.Graph`` from the entries stored under ``path`` in ``data``.

        Reads ``{path}-ids``, ``{path}-edges`` and ``{path}-locations``. It also
        reads one ``{path}/node_attrs/<name>`` entry for every name in
        ``self.node_attrs.get(graph_key, [])``, and one
        ``{path}/edge_attrs/<name>`` entry for every name in
        ``self.edge_attrs.get(graph_key, [])``.

        The graph's ROI has begin and shape all ``None``. Its dimensionality is
        the length of the first saved location, and ``directed`` is
        ``self.directed.get(graph_key)``.
        """

        def attribute_rows(kind, names, count):
            # Read every requested attribute column up front. Then yield one
            # {name: value} dict per element. The trailing all-None column
            # keeps zip producing `count` rows when `names` is empty. The final
            # zip(names, row) drops that padding value again.
            columns = [data[f"{path}/{kind}/{name}"] for name in names]
            return (
                dict(zip(names, row))
                for row in zip(*columns, (None,) * count)
            )

        saved_ids = data[f"{path}-ids"]
        saved_edges = data[f"{path}-edges"]
        saved_locations = data[f"{path}-locations"]

        node_rows = attribute_rows(
            "node_attrs",
            list(self.node_attrs.get(graph_key, [])),
            len(saved_locations),
        )
        nodes = [
            gp.Node(node_id, location=location, attrs=node_attrs)
            for node_id, location, node_attrs in zip(
                saved_ids, saved_locations, node_rows
            )
        ]

        edge_rows = attribute_rows(
            "edge_attrs",
            list(self.edge_attrs.get(graph_key, [])),
            len(saved_edges),
        )
        edges = [
            gp.Edge(u, v, attrs=edge_attrs)
            for (u, v), edge_attrs in zip(saved_edges, edge_rows)
        ]

        # NOTE(review): raises IndexError when no locations were saved
        # (empty graph), because the dimensionality comes from element 0.
        unbounded = (None,) * len(saved_locations[0])
        spec = gp.GraphSpec(
            gp.Roi(unbounded, unbounded),
            directed=self.directed.get(graph_key),
        )
        return gp.Graph(nodes, edges, spec)
コード例 #4
0
ファイル: utils.py プロジェクト: funkelab/contraband
    def process(self, batch, request):
        """Remove one dimension from the array or graph stored under ``self.key``.

        Does nothing when ``self.key`` is not in ``batch``.

        For a ``gp.ArrayKey``, the first of the trailing ``roi.dims()`` data
        axes (the first axis covered by the ROI) must have size 1. That axis is
        reshaped away. The ROI begin/shape and the voxel size are passed
        through ``self.__remove_dim``.

        For a ``gp.GraphKey``, the ROI begin/shape are passed through
        ``self.__remove_dim``. The graph is then rebuilt with every node's
        location stripped of its first coordinate, and all edges are carried
        over unchanged.

        Raises:
            AssertionError: if the array axis to remove is not of size 1.
        """
        if self.key not in batch:
            return

        if isinstance(self.key, gp.ArrayKey):
            array = batch[self.key]
            data = array.data
            shape = data.shape
            roi = array.spec.roi
            assert shape[-roi.dims()] == 1, (
                "Channel to delete must be size 1, but given shape "
                + str(shape))

            # Same axis as asserted above, expressed as a non-negative index.
            shape = self.__remove_dim(shape, len(shape) - roi.dims())
            array.data = data.reshape(shape)
            array.spec.roi = gp.Roi(
                self.__remove_dim(roi.get_begin()),
                self.__remove_dim(roi.get_shape()))
            array.spec.voxel_size = self.__remove_dim(array.spec.voxel_size)

        if isinstance(self.key, gp.GraphKey):
            old_graph = batch[self.key]
            roi = old_graph.spec.roi

            old_graph.spec.roi = gp.Roi(
                self.__remove_dim(roi.get_begin()),
                self.__remove_dim(roi.get_shape()))

            graph = gp.Graph([], [], spec=old_graph.spec)
            for node in old_graph.nodes:
                # Copy attrs so the new node does not share a mutable dict
                # with the old one, in case gp.Node stores it by reference.
                graph.add_node(gp.Node(node.id,
                                       node.location[1:],
                                       temporary=node.temporary,
                                       attrs=dict(node.attrs)))
            # Copy the connectivity too. Rebuilding from nodes alone would
            # silently drop every edge of the graph.
            for edge in old_graph.edges:
                graph.add_edge(edge)
            batch[self.key] = graph
コード例 #5
0
 def process(self, batch, request):
     """Store an empty graph (no nodes, no edges) under ``self.graph_key``.

     The graph uses ``self.spec[self.graph_key]`` as its spec.
     """
     # NOTE(review): the spec object is passed as-is (not copied); confirm
     # that downstream nodes never mutate it in place.
     provided_spec = self.spec[self.graph_key]
     batch[self.graph_key] = gp.Graph([], [], spec=provided_spec)
コード例 #6
0
 def process(self, batch, request):
     """Store a one-node graph under ``self.graph_key``.

     The single node has id 0 and is located at ``self.center``. The graph has
     no edges and uses ``self.spec[self.graph_key]`` as its spec.
     """
     center_node = gp.Node(id=0, location=self.center)
     provided_spec = self.spec[self.graph_key]
     batch[self.graph_key] = gp.Graph([center_node], [], provided_spec)