コード例 #1
0
ファイル: sharedvalue.py プロジェクト: rpgoldman/Theano-PyMC
    def __init__(self,
                 name,
                 type,
                 value,
                 strict,
                 allow_downcast=None,
                 container=None):
        """Initialise the shared variable and attach its storage container.

        The variable is registered with the base class as having no owner
        node (``owner=None``, ``index=None``).

        Parameters
        ----------
        name
            Name forwarded to the base-class constructor.
        type
            Type of the variable. Forwarded to the base class. When no
            ``container`` is given, ``type.filter`` is also applied to
            ``value``.
        value
            Initial value. It is passed through ``type.filter`` and stored in a
            newly built ``Container``. Must be ``None`` when ``container`` is
            given.
        strict
            Forwarded to ``type.filter`` and to the new ``Container``. Must
            be ``None`` when ``container`` is given.
        allow_downcast
            Forwarded to ``type.filter`` and to the new ``Container``.
        container
            An existing container to use as-is instead of building one.

        Raises
        ------
        TypeError
            If ``container`` is given together with a non-``None`` ``value``
            or ``strict``. ``self.container`` has already been assigned at
            that point.
        """
        super().__init__(type=type, name=name, owner=None, index=None)

        if container is None:
            # No container supplied: validate/convert the initial value with
            # the type, then wrap it in a fresh writable container.
            filtered_value = type.filter(value,
                                         strict=strict,
                                         allow_downcast=allow_downcast)
            self.container = Container(
                self,
                storage=[filtered_value],
                readonly=False,
                strict=strict,
                allow_downcast=allow_downcast,
            )
            return

        # Caller-supplied container: adopt it. Then reject arguments that
        # would have no effect on it.
        self.container = container
        if value is not None or strict is not None:
            raise TypeError("value and strict are ignored if you pass "
                            "a container here")
コード例 #2
0
ファイル: test_link.py プロジェクト: rpgoldman/Theano-PyMC
def test_container_deepcopy():
    """Container storage keeps an ndarray of the expected dtype across deepcopy.

    This is a regression test for a NumPy workaround. For both read-only
    and writable containers, the stored value must be:

    - an ``np.ndarray``, not a NumPy scalar;
    - of the same dtype as the original value;
    - of the container type's dtype.

    These checks apply to both the container and its ``deepcopy``.
    """
    t = theano.tensor.scalar()
    # With some NumPy versions, ``numpy.asarray(0.).astype(floatX)`` can
    # return a NumPy scalar. Passing ``dtype`` to ``numpy.asarray``
    # directly guarantees an ndarray.
    v = np.asarray(0.0, dtype=theano.config.floatX)
    assert isinstance(v, np.ndarray), type(v)

    def check_storage(container, type_dtype):
        # Verify the first storage cell of ``container``.
        stored = container.storage[0]
        assert isinstance(stored, np.ndarray), (stored, type(stored))
        assert stored.dtype == v.dtype, (stored.dtype, v.dtype)
        assert stored.dtype == type_dtype, (stored.dtype, type_dtype)

    for readonly in (True, False):
        original = Container(t, [v], readonly=readonly)
        check_storage(original, original.type.dtype)
        # The copy is compared against the ORIGINAL container's type dtype.
        check_storage(deepcopy(original), original.type.dtype)
コード例 #3
0
ファイル: jax_linker.py プロジェクト: rpgoldman/Theano-PyMC
    def make_all(self,
                 input_storage=None,
                 output_storage=None,
                 storage_map=None):
        """Build the callable that evaluates ``self.fgraph``.

        First it tries to build JAX thunks with ``self.create_jax_thunks``.
        If that raises ``NotImplementedError``, the error is re-raised unless
        ``self.allow_non_jax`` is set. In that case a warning is emitted and
        per-node Python thunks (``op.make_thunk(..., "py")``) are built
        instead.

        Parameters
        ----------
        input_storage, output_storage, storage_map
            Optional pre-allocated storage. Passed to ``map_storage``, which
            returns the storage actually used.

        Returns
        -------
        tuple
            ``(fn, input_containers, output_containers, thunks, nodes)``:

            - ``fn`` is the result of ``streamline``. ``allow_gc`` and
              ``storage_map`` are attached to it, and
              ``add_clear_storage`` is applied to it.
            - ``input_containers`` wrap the input storage, one per
              ``fgraph.inputs``.
            - ``output_containers`` wrap the output storage, one per
              ``fgraph.outputs``. They are built with a third positional
              argument of ``True`` (presumably read-only; confirm against
              ``Container``'s signature).
            - ``thunks`` and ``nodes`` are the ones actually used to build
              ``fn``.

        Raises
        ------
        NotImplementedError
            Propagated from ``create_jax_thunks`` when ``allow_non_jax`` is
            false.
        """
        fgraph = self.fgraph
        nodes = self.schedule(fgraph)
        no_recycling = self.no_recycling

        input_storage, output_storage, storage_map = map_storage(
            fgraph, nodes, input_storage, output_storage, storage_map)

        # Variables with no owner node start out marked as computed.
        compute_map = {var: [var.owner is None] for var in storage_map}

        try:
            # The JAX thunks fill the output storage cells with
            # JAX-computed values.
            thunks, nodes = self.create_jax_thunks(compute_map, storage_map)
        except NotImplementedError as exc:
            if not self.allow_non_jax:
                raise

            warn(f"JaxLinker could not JAXify graph: {exc}")

            # Fallback: one pure-Python thunk per scheduled node. Each thunk
            # carries its input and output storage cells.
            thunks = []
            for node in nodes:
                thunk = node.op.make_thunk(node, storage_map, compute_map,
                                           no_recycling, "py")
                thunk.inputs, thunk.outputs = (
                    [storage_map[var] for var in node.inputs],
                    [storage_map[var] for var in node.outputs],
                )
                thunks.append(thunk)

        computed, last_user = gc_helper(nodes)

        post_thunk_old_storage = None
        if self.allow_gc:
            # Each node gets a list of storage cells for its inputs that
            # satisfy all of the following:
            #   - they are computed;
            #   - they are not graph outputs;
            #   - this node is their last user.
            # Presumably ``streamline`` frees these after the node runs.
            post_thunk_old_storage = [
                [
                    storage_map[var]
                    for var in node.inputs
                    if var in computed
                    and var not in fgraph.outputs
                    and node == last_user[var]
                ]
                for node in nodes
            ]

        if no_recycling is True:
            # ``True`` means: every storage cell except the input storage.
            no_recycling_storage = utils.difference(
                list(storage_map.values()), input_storage)
        else:
            no_recycling_storage = [
                storage_map[var] for var in no_recycling
                if var not in fgraph.inputs
            ]

        fn = streamline(fgraph,
                        thunks,
                        nodes,
                        post_thunk_old_storage,
                        no_recycling=no_recycling_storage)
        fn.allow_gc = self.allow_gc
        add_clear_storage(fn, computed, storage_map)
        fn.storage_map = storage_map

        input_containers = [
            Container(var, cell)
            for var, cell in zip(fgraph.inputs, input_storage)
        ]
        output_containers = [
            Container(var, cell, True)
            for var, cell in zip(fgraph.outputs, output_storage)
        ]
        return fn, input_containers, output_containers, thunks, nodes