def make_step(self, signals, dt, rng):
    """Build the per-timestep closure computing ``Y = castDecimal(Y) + A * X``.

    The three operands are fetched from ``signals`` using ``self.A``,
    ``self.X`` and ``self.Y``. Their shapes are left-padded with ones to two
    dimensions via ``npext.broadcast_shape`` and checked for compatibility
    once, here, so the returned closure performs no validation.

    Parameters
    ----------
    signals : mapping
        Maps the operator's signal keys to their values. Assumed to be
        NumPy-style arrays supporting ``.shape``, ``*`` and ``[...]``
        assignment -- NOTE(review): confirm against the simulator.
    dt, rng
        Unused; accepted only to match the ``make_step`` signature.

    Returns
    -------
    callable
        A zero-argument function that writes ``castDecimal(Y) + A * X``
        back into ``Y`` in place.

    Raises
    ------
    ValueError
        If, in some dimension, ``A`` or ``X`` is neither 1 nor ``Y``'s size,
        or if neither of them reaches ``Y``'s size.
    """
    A = signals[self.A]
    X = signals[self.X]
    Y = signals[self.Y]

    # check broadcasting shapes (left-pad each shape with ones to 2-D)
    Ashape = npext.broadcast_shape(A.shape, 2)
    Xshape = npext.broadcast_shape(X.shape, 2)
    Yshape = npext.broadcast_shape(Y.shape, 2)
    # Operands with more than two dimensions are not padded down, so they
    # trip this check.
    assert all(len(s) == 2 for s in [Ashape, Xshape, Yshape])
    for da, dx, dy in zip(Ashape, Xshape, Yshape):
        # Each of A's and X's dims must be 1 (broadcast) or equal Y's dim,
        # and together they must reproduce Y's dim.
        if not (da in [1, dy] and dx in [1, dy] and max(da, dx) == dy):
            raise ValueError("Incompatible shapes in ElementwiseInc: "
                             "Trying to do %s += %s * %s"
                             % (Yshape, Ashape, Xshape))

    def step():
        Z = A * X
        # NOTE(review): castDecimal lives in npext; presumably it converts
        # Y's current contents to Decimal before the add -- confirm.
        Y[...] = npext.castDecimal(Y[...]) + Z

    return step
def make_step(self, signals, dt, rng):
    """Build the per-timestep closure that computes ``Y += A * X``.

    The three operands are fetched from ``signals`` using ``self.A``,
    ``self.X`` and ``self.Y``. Their shapes are left-padded with ones to two
    dimensions via ``npext.broadcast_shape`` and checked for compatibility
    once, here, so the returned closure performs no validation.

    Parameters
    ----------
    signals : mapping
        Maps the operator's signal keys to their values. Assumed to be
        NumPy-style arrays supporting ``.shape``, ``*`` and in-place
        ``[...] +=`` -- NOTE(review): confirm against the simulator.
    dt, rng
        Unused; accepted only to match the ``make_step`` signature.

    Returns
    -------
    callable
        A zero-argument function that adds ``A * X`` into ``Y`` in place.

    Raises
    ------
    ValueError
        If, in some dimension, ``A`` or ``X`` is neither 1 nor ``Y``'s size,
        or if neither of them reaches ``Y``'s size. (Previously a bare
        ``assert``, which is stripped under ``python -O`` and carried no
        message.)
    """
    A = signals[self.A]
    X = signals[self.X]
    Y = signals[self.Y]

    # check broadcasting shapes (left-pad each shape with ones to 2-D)
    Ashape = npext.broadcast_shape(A.shape, 2)
    Xshape = npext.broadcast_shape(X.shape, 2)
    Yshape = npext.broadcast_shape(Y.shape, 2)
    # Operands with more than two dimensions are not padded down, so they
    # trip this check.
    assert all(len(s) == 2 for s in [Ashape, Xshape, Yshape])
    for da, dx, dy in zip(Ashape, Xshape, Yshape):
        # Each of A's and X's dims must be 1 (broadcast) or equal Y's dim,
        # and together they must reproduce Y's dim.
        if not (da in [1, dy] and dx in [1, dy] and max(da, dx) == dy):
            raise ValueError("Incompatible shapes in ElementwiseInc: "
                             "Trying to do %s += %s * %s"
                             % (Yshape, Ashape, Xshape))

    def step():
        Y[...] += A * X

    return step
def make_step(self, signals, dt, rng):
    """Return a closure that accumulates ``A * X`` into ``Y`` in place.

    Operands are looked up in ``signals`` by ``self.A``, ``self.X`` and
    ``self.Y``. Shape compatibility is validated once up front, after
    left-padding every shape with ones to two dimensions using
    ``npext.broadcast_shape``.

    ``dt`` and ``rng`` are not used; they exist to satisfy the
    ``make_step`` signature.

    Raises ``ValueError`` when, in any dimension, ``A`` or ``X`` is neither
    1 nor ``Y``'s size, or neither of them reaches ``Y``'s size.
    """
    a_val = signals[self.A]
    x_val = signals[self.X]
    y_val = signals[self.Y]

    padded = [npext.broadcast_shape(v.shape, 2) for v in (a_val, x_val, y_val)]
    # Anything that is still not 2-D after padding is rejected.
    assert all(len(s) == 2 for s in padded)
    a_shape, x_shape, y_shape = padded

    for a_dim, x_dim, y_dim in zip(a_shape, x_shape, y_shape):
        # Reject a dim of A/X that is neither 1 nor Y's, or a pair that
        # cannot produce Y's dim.
        if a_dim not in (1, y_dim) or x_dim not in (1, y_dim) \
                or max(a_dim, x_dim) != y_dim:
            msg = ("Incompatible shapes in ElementwiseInc: "
                   "Trying to do %s += %s * %s") % (y_shape, a_shape, x_shape)
            raise ValueError(msg)

    def step():
        y_val[...] += a_val * x_val

    return step
def make_step(self, signals, dt, rng):
    """Return a closure that accumulates ``A * X`` into ``Y`` in place.

    Operands are looked up in ``signals`` by ``self.A``, ``self.X`` and
    ``self.Y``. Before the closure is created, every shape is left-padded
    with ones to two dimensions using ``npext.broadcast_shape``, and the
    padded shapes are checked for compatibility.

    ``dt`` and ``rng`` are not used; they exist to satisfy the
    ``make_step`` signature.

    Raises ``BuildError`` when, in any dimension, ``A`` or ``X`` is neither
    1 nor ``Y``'s size, or neither of them reaches ``Y``'s size.
    """
    lhs = signals[self.A]
    rhs = signals[self.X]
    out = signals[self.Y]

    lhs_shape = npext.broadcast_shape(lhs.shape, 2)
    rhs_shape = npext.broadcast_shape(rhs.shape, 2)
    out_shape = npext.broadcast_shape(out.shape, 2)
    # Anything that is still not 2-D after padding is rejected.
    assert all(len(s) == 2 for s in [lhs_shape, rhs_shape, out_shape])

    # Stops at the first offending dimension, just like an explicit loop.
    mismatched = any(
        a_dim not in (1, y_dim)
        or x_dim not in (1, y_dim)
        or max(a_dim, x_dim) != y_dim
        for a_dim, x_dim, y_dim in zip(lhs_shape, rhs_shape, out_shape)
    )
    if mismatched:
        raise BuildError("Incompatible shapes in ElementwiseInc: "
                         "Trying to do %s += %s * %s"
                         % (out_shape, lhs_shape, rhs_shape))

    def step_elementwiseinc():
        out[...] += lhs * rhs

    return step_elementwiseinc
def test_brodcast_shape():
    """broadcast_shape left-pads a shape with ones up to ``length`` dims."""
    # NOTE: the function name misspells "broadcast"; it is kept unchanged so
    # existing test selections/IDs keep working.
    cases = [
        # (shape, length, expected)
        ((3, 2), 3, (1, 3, 2)),
        ((3, 2), 4, (1, 1, 3, 2)),
        ((3, 2), 2, (3, 2)),
    ]
    for shape, length, expected in cases:
        assert broadcast_shape(shape=shape, length=length) == expected