def __init__(self, name, type, value, strict, allow_downcast=None, container=None):
    """Initialise the variable and attach the Container that holds its value.

    Parameters
    ----------
    name
        Forwarded to the parent constructor.
    type
        Forwarded to the parent constructor. When a new Container is built,
        ``type.filter`` is applied to ``value`` to produce the stored object.
    value
        Initial value; only consumed when ``container`` is None.
    strict, allow_downcast
        Passed both to ``type.filter`` and to the newly built Container.
    container
        An existing Container to reuse. When supplied, ``value`` and
        ``strict`` must both be None.

    Raises
    ------
    TypeError
        If ``container`` is supplied together with a non-None ``value`` or
        ``strict``.
    """
    super().__init__(type=type, name=name, owner=None, index=None)

    if container is None:
        # Build a fresh, writable container around the filtered initial value.
        initial = type.filter(value, strict=strict, allow_downcast=allow_downcast)
        self.container = Container(
            self,
            storage=[initial],
            readonly=False,
            strict=strict,
            allow_downcast=allow_downcast,
        )
        return

    # Reuse the caller's container. The attribute is assigned before the
    # argument check, so it is already set if the TypeError below fires.
    self.container = container
    if not (value is None and strict is None):
        raise TypeError(
            "value and strict are ignored if you pass " "a container here"
        )
def test_container_deepcopy():
    """Container storage must hold a real ndarray of the expected dtype,
    both as constructed and after ``deepcopy``.

    This guards a workaround for a NumPy quirk (see the comment below).
    """
    scalar_var = theano.tensor.scalar()

    # With some NumPy versions ``numpy.asarray(0.).astype(floatX)`` yields a
    # NumPy scalar instead of a 0-d ndarray, so the dtype is passed directly
    # to ``numpy.asarray`` to guarantee an ndarray.
    v = np.asarray(0.0, dtype=theano.config.floatX)
    assert isinstance(v, np.ndarray), type(v)

    def _check_cell(cell, type_dtype):
        # The stored object must be an ndarray whose dtype matches both the
        # source value and the given type dtype.
        assert isinstance(cell, np.ndarray), (cell, type(cell))
        assert cell.dtype == v.dtype, (cell.dtype, v.dtype)
        assert cell.dtype == type_dtype, (cell.dtype, type_dtype)

    for readonly in (True, False):
        c = Container(scalar_var, [v], readonly=readonly)
        _check_cell(c.storage[0], c.type.dtype)

        # The copy is checked against the ORIGINAL container's type dtype.
        d = deepcopy(c)
        _check_cell(d.storage[0], c.type.dtype)
def make_all(self, input_storage=None, output_storage=None, storage_map=None):
    """Build the callable and storage containers that execute ``self.fgraph``.

    The graph's nodes are ordered with ``self.schedule``. Storage cells come
    from ``map_storage``, and the three optional arguments are forwarded to it
    unchanged. Thunks are produced by ``self.create_jax_thunks``. If that
    raises NotImplementedError and ``self.allow_non_jax`` is true, a warning is
    issued and each node's ``"py"`` thunk is used instead.

    Returns
    -------
    tuple
        ``(fn, input_containers, output_containers, thunks, nodes)`` where
        ``fn`` is the result of ``streamline`` with ``allow_gc`` and
        ``storage_map`` attributes attached.

    Raises
    ------
    NotImplementedError
        Re-raised from ``create_jax_thunks`` when ``self.allow_non_jax`` is
        false.
    """
    fgraph = self.fgraph
    order = self.schedule(fgraph)
    recycle_spec = self.no_recycling

    input_storage, output_storage, storage_map = map_storage(
        fgraph, order, input_storage, output_storage, storage_map
    )

    # Owner-less variables (graph inputs/constants) begin as already computed.
    compute_map = {var: [var.owner is None] for var in storage_map}

    def _python_thunk(node):
        # Build the op's plain-Python thunk and wire it to its storage cells.
        thunk = node.op.make_thunk(
            node, storage_map, compute_map, recycle_spec, "py"
        )
        in_cells = [storage_map[var] for var in node.inputs]
        out_cells = [storage_map[var] for var in node.outputs]
        thunk.inputs = in_cells
        thunk.outputs = out_cells
        return thunk

    try:
        # The JAX thunks populate the output storage cells with the values
        # JAX computes; this call may also return a different node list.
        thunks, order = self.create_jax_thunks(compute_map, storage_map)
    except NotImplementedError as err:
        if not self.allow_non_jax:
            raise
        warn(f"JaxLinker could not JAXify graph: {err}")
        thunks = [_python_thunk(node) for node in order]

    computed, last_user = gc_helper(order)

    post_thunk_old_storage = None
    if self.allow_gc:
        # Per node: storage of computed, non-output inputs whose last user is
        # this node. The condition order matters, because ``last_user`` is
        # only indexed once the first two tests pass.
        post_thunk_old_storage = [
            [
                storage_map[var]
                for var in node.inputs
                if var in computed
                and var not in fgraph.outputs
                and node == last_user[var]
            ]
            for node in order
        ]

    if recycle_spec is True:
        recycled_storage = utils.difference(
            list(storage_map.values()), input_storage
        )
    else:
        recycled_storage = [
            storage_map[var] for var in recycle_spec if var not in fgraph.inputs
        ]

    fn = streamline(
        fgraph, thunks, order, post_thunk_old_storage, no_recycling=recycled_storage
    )
    fn.allow_gc = self.allow_gc
    add_clear_storage(fn, computed, storage_map)
    fn.storage_map = storage_map

    input_containers = [
        Container(var, cell) for var, cell in zip(fgraph.inputs, input_storage)
    ]
    # The third positional argument is True for outputs (presumably the
    # ``readonly`` flag -- confirm against Container's signature).
    output_containers = [
        Container(var, cell, True)
        for var, cell in zip(fgraph.outputs, output_storage)
    ]
    return fn, input_containers, output_containers, thunks, order