def __init__(
    self,
    data,
    batch_size=128,
    dtype=None,
    broadcastable=None,
    name="Minibatch",
    random_seed=42,
    update_shared_f=None,
    in_memory_size=None,
):
    """Build a symbolic minibatch view over ``data``.

    The input is converted to an ndarray and a slice of it is stored in an
    ``aesara.shared`` variable. A randomly indexed subset of that shared
    variable becomes this variable's value, attached through a ``view_op``
    Apply node.

    Parameters
    ----------
    data : array-like
        Source data. It is passed through ``np.asarray``. When ``dtype`` is
        ``None``, it is also passed through ``pm.smartfloatX``.
    batch_size : int
        Forwarded to ``self.make_random_slices``. Presumably this is the number
        of rows per minibatch; confirm against that helper.
    dtype : dtype, optional
        Explicit dtype for ``np.asarray``. If given, ``smartfloatX`` is skipped.
    broadcastable : tuple of bool, optional
        Broadcast pattern applied to the minibatch. Defaults to all-``False``
        with one entry per minibatch dimension.
    name : str
        Variable name forwarded to the parent constructor.
    random_seed : int
        Forwarded to ``self.make_random_slices``.
    update_shared_f : callable, optional
        Stored as ``self.update_shared_f``. Its use is not visible here.
    in_memory_size : optional
        Forwarded to ``self.make_static_slices``. The resulting slice selects
        which part of ``data`` is copied into the shared variable.
    """
    # Normalise the input to an ndarray. Without an explicit dtype, let
    # smartfloatX pick the float type.
    array = (
        pm.smartfloatX(np.asarray(data))
        if dtype is None
        else np.asarray(data, dtype)
    )

    # Only the statically sliced portion of the data is placed in shared storage.
    static_slice = self.make_static_slices(in_memory_size)
    self.shared = aesara.shared(array[static_slice])
    self.update_shared_f = update_shared_f

    # Symbolic index into the shared data that yields one minibatch.
    self.random_slc = self.make_random_slices(
        self.shared.shape, batch_size, random_seed
    )
    batch = self.shared[self.random_slc]

    pattern = (False,) * batch.ndim if broadcastable is None else broadcastable
    self.minibatch = aet.patternbroadcast(batch, pattern)

    # Initialise as a variable of the minibatch's type, then make `self` the
    # output of a view_op applied to the minibatch. This order is required:
    # the Apply node needs `self` to be an initialised variable.
    super().__init__(self.minibatch.type, None, None, name=name)
    Apply(aesara.compile.view_op, inputs=[self.minibatch], outputs=[self])
    self.tag.test_value = copy(self.minibatch.tag.test_value)
def convert_variable(self, var):
    """Return ``var`` re-broadcast to this type if compatible, else ``None``.

    ``var`` is compatible when all of the following hold:

    * its type is an instance of ``type(self)``;
    * ``typecode``, ``ndim`` and ``context_name`` all match;
    * for every dimension, the broadcast flags agree, or ``var``'s dimension
      is broadcastable.

    On success the result is ``at.patternbroadcast(var, self.broadcastable)``.
    """
    other = var.type

    if not isinstance(other, type(self)):
        return None
    if self.typecode != other.typecode:
        return None
    if self.ndim != other.ndim:
        return None
    if self.context_name != other.context_name:
        return None

    # Reject if any dimension disagrees and var's dimension is not broadcastable.
    for mine, theirs in zip(self.broadcastable, other.broadcastable):
        if not (mine == theirs or theirs):
            return None

    return at.patternbroadcast(var, self.broadcastable)
def test_local_dimshuffle_subtensor():
    """Exercise the ``local_dimshuffle_subtensor`` rewrite on four graphs.

    Each graph is a Subtensor followed by a DimShuffle that drops
    broadcastable axes:

    1. a strided slice, rewritten directly on a FunctionGraph;
    2. an integer index, rewritten directly on a FunctionGraph;
    3. an integer index with the kept axis between the dropped ones, compiled
       through ``aesara.function``;
    4. a negative-step slice, evaluated end to end.
    """
    rewrite = out2in(local_dimshuffle_subtensor)

    def contains_non_dimshuffle(nodes):
        # NOTE(review): passes as soon as *any* node is not a DimShuffle; it
        # does not verify that every DimShuffle was removed.
        return any(not isinstance(node, DimShuffle) for node in nodes)

    # Case 1: slice with a symbolic step, then drop the broadcastable axis 1.
    x4 = tensor.patternbroadcast(tensor.dtensor4("x"), (False, True, False, False))
    i = tensor.iscalar("i")
    sliced = x4[:, :, 10:30, ::i].dimshuffle(0, 2, 3)
    fgraph = FunctionGraph([x4, i], [sliced])
    rewrite(fgraph)
    assert contains_non_dimshuffle(fgraph.toposort())

    # Case 2: the DimShuffle removes dimensions the Subtensor doesn't "see".
    x3 = tensor.tensor(broadcastable=(False, True, False), dtype="float64")
    indexed = x3[i].dimshuffle(1)
    fgraph = FunctionGraph([x3, i], [indexed])
    rewrite(fgraph)
    assert contains_non_dimshuffle(fgraph.toposort())

    # Case 3: as above, but the kept axis sits between the dropped
    # broadcastable axes.
    x4b = tensor.tensor(broadcastable=(False, True, False, True), dtype="float64")
    indexed = x4b[i].dimshuffle(1)
    fn = aesara.function([x4b, i], indexed)
    assert contains_non_dimshuffle(fn.maker.fgraph.toposort())
    assert fn(np.random.rand(5, 1, 4, 1), 2).shape == (4,)

    # Case 4: regression test for a corner case that used to make Aesara fail.
    x4c = tensor.patternbroadcast(tensor.dtensor4("x"), (False, True, False, False))
    reversed_slice = x4c[:, :, 0:3, ::-1].dimshuffle(0, 2, 3)
    result = reversed_slice.eval({x4c: np.ones((5, 1, 6, 7))})
    assert result.shape == (5, 3, 7)