def logdet(self):
    z = self.z0  # sxd
    u = self.u_  # d
    w = self.w_  # d
    b = self.b  # .
    deriv = self.h.deriv  # f'
    if not self.batched:
        # f'(sxd \dot d + .) * -xd = sxd
        phi = deriv(z.dot(w) + b).dimshuffle(0, "x") * w.dimshuffle("x", 0)
        # \abs(. + sxd \dot d) = s
        det = aet.abs_(1.0 + phi.dot(u))
        return aet.log(det)
    else:
        z = z.swapaxes(0, 1)
        b = b.dimshuffle(0, "x")
        # z bxsxd
        # u bxd
        # w bxd
        # b bx-x-
        # f'(bxsxd \bdot bxd + bx-x-) * bx-xd = bxsxd
        phi = deriv(aet.batched_dot(z, w) + b).dimshuffle(0, 1, "x") * w.dimshuffle(0, "x", 1)
        # \abs(. + bxsxd \bdot bxd) = bxs
        det = aet.abs_(1.0 + aet.batched_dot(phi, u))  # bxs
        return aet.log(det).sum(0)  # s
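
# A minimal NumPy sketch (not part of the original code) of the non-batched
# planar-flow log-determinant computed above: with f(z) = z + u * h(z @ w + b)
# and h assumed to be tanh, the Jacobian determinant is |1 + psi(z) @ u| with
# psi(z) = h'(z @ w + b) * w. Shape names (s, d) mirror the sxd / d comments.
import numpy as np

rng = np.random.default_rng(0)
s, d = 5, 3
z = rng.normal(size=(s, d))                      # sxd
u = rng.normal(size=d)                           # d
w = rng.normal(size=d)                           # d
b = 0.1                                          # scalar

deriv = lambda x: 1.0 - np.tanh(x) ** 2          # tanh'
phi = deriv(z @ w + b)[:, None] * w[None, :]     # sxd
logdet = np.log(np.abs(1.0 + phi @ u))           # s, one value per sample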
def forward(self):
    z = self.z0  # sxd
    H = self.H  # dxd
    if self.batched:
        return aet.batched_dot(z.swapaxes(0, 1), H).swapaxes(0, 1)
    else:
        return z.dot(H)
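
# Hedged NumPy sketch of what the batched branch above computes: batched_dot
# pairs the leading (batch) axes, i.e. it behaves like einsum("bsd,bde->bse").
# The sxbxd / bxdxd shapes below are assumptions chosen to match the comments.
import numpy as np

rng = np.random.default_rng(1)
s, b, d = 4, 2, 3
z = rng.normal(size=(s, b, d))                   # sxbxd
H = rng.normal(size=(b, d, d))                   # bxdxd

batched = np.einsum("bsd,bde->bse", z.swapaxes(0, 1), H).swapaxes(0, 1)   # sxbxd
plain = np.stack([z[:, i, :] @ H[i] for i in range(b)], axis=1)           # same result
assert np.allclose(batched, plain)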
def forward(self):
    z = self.z0  # sxd
    u = self.u_  # d
    w = self.w_  # d
    b = self.b  # .
    h = self.h  # f
    # h(sxd \dot d + .) = s
    if not self.batched:
        hwz = h(z.dot(w) + b)  # s
        # sxd + (s \outer d) = sxd
        z1 = z + aet.outer(hwz, u)  # sxd
        return z1
    else:
        z = z.swapaxes(0, 1)
        # z bxsxd
        # u bxd
        # w bxd
        b = b.dimshuffle(0, "x")
        # b bx-
        hwz = h(aet.batched_dot(z, w) + b)  # bxs
        # bxsxd + (bxsx- * bx-xd) = bxsxd
        hwz = hwz.dimshuffle(0, 1, "x")  # bxsx-
        u = u.dimshuffle(0, "x", 1)  # bx-xd
        z1 = z + hwz * u  # bxsxd
        return z1.swapaxes(0, 1)  # sxbxd
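
# Minimal NumPy sketch (illustrative only, assuming h = tanh) of the
# non-batched planar-flow forward pass above: z1 = z + outer(h(z @ w + b), u).
import numpy as np

rng = np.random.default_rng(2)
s, d = 5, 3
z = rng.normal(size=(s, d))                      # sxd
u, w, b = rng.normal(size=d), rng.normal(size=d), 0.1

hwz = np.tanh(z @ w + b)                         # s
z1 = z + np.outer(hwz, u)                        # sxd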
def L_op(self, inputs, outputs, output_grads):
    # Gradients computed by Op
    assert self.compute_grad and len(outputs) == 2
    gradients = outputs[1]
    assert gradients is not None
    # Gradients of original function, to compose chain rule
    grad_op = output_grads[0]
    grad_shuffle = GpuDimShuffle(
        input_broadcastable=(False, False, False),
        new_order=(1, 0, 2),
    )(gradients)
    grad_bdot = tt.batched_dot(grad_op, grad_shuffle)
    grad_shuffle_reverse = GpuDimShuffle(
        input_broadcastable=(False, False, False),
        new_order=(1, 0, 2),
    )(grad_bdot)
    return [
        grad_shuffle_reverse,
        grad_undefined(self, 1, inputs[1]),
        grad_undefined(self, 2, inputs[2]),
    ]
def symbolic_random(self):
    initial = self.symbolic_initial
    L = self.L
    mu = self.mean
    if self.batched:
        # initial: bxsxd
        # L: bxdxd
        initial = initial.swapaxes(0, 1)
        return at.batched_dot(initial, L.swapaxes(1, 2)).swapaxes(0, 1) + mu
    else:
        # sxd \dot dxd + d = sxd
        return initial.dot(L.T) + mu
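
# Hedged NumPy sketch of the non-batched branch above: draws are produced by
# the affine reparameterization x = e @ L.T + mu with e ~ N(0, I), so that
# x ~ N(mu, L @ L.T). All names and sizes below are illustrative.
import numpy as np

rng = np.random.default_rng(3)
s, d = 1000, 3
L = np.tril(rng.normal(size=(d, d)))             # lower-triangular factor
mu = rng.normal(size=d)

e = rng.normal(size=(s, d))                      # standard-normal "initial" draws
x = e @ L.T + mu                                 # samples with covariance L @ L.T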
def test_basic(self):
    # Reported in https://github.com/Aesara/Aesara/issues/5730
    x = tensor.tensor3()
    y = tensor.tensor3()
    z = tensor.batched_dot(x, y[:, 0, :, np.newaxis])
    f = aesara.function([x, y], z, mode=mode_with_gpu)
    x_num = np.arange(32 * 19 * 600, dtype=config.floatX).reshape((32, 19, 600))
    y_num = np.arange(7 * 32 * 600, dtype=config.floatX).reshape((32, 7, 600))
    f(x_num, y_num)
    assert f.maker.fgraph.toposort()[-2].op.inplace
def L_op(self, inputs, outputs, output_grads):
    # Gradients computed by Op
    assert self.compute_grad and len(outputs) == 2
    gradients = outputs[1]
    assert gradients is not None
    # Gradients of original function, to compose chain rule
    grad_op = output_grads[0]
    total_grad = tt.batched_dot(grad_op, gradients.dimshuffle(1, 0, 2)).dimshuffle(1, 0, 2)
    return [
        total_grad,
        grad_undefined(self, 1, inputs[1]),
        grad_undefined(self, 2, inputs[2]),
    ]
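
# NumPy sketch (illustrative only) of the contraction pattern shared by both
# L_op variants above: the stored per-step gradients are transposed on their
# first two axes and contracted with the incoming gradient by a batched
# matrix product, where batched_dot(a, b)[i] == a[i] @ b[i]. All shapes below
# are made up for the example.
import numpy as np

rng = np.random.default_rng(4)
b, m, k, n = 2, 3, 4, 5
grad_op = rng.normal(size=(b, m, k))             # incoming gradient
gradients = rng.normal(size=(k, b, n))           # gradients stored by the Op

shuffled = gradients.transpose(1, 0, 2)                      # bxkxn
total = np.einsum("bmk,bkn->bmn", grad_op, shuffled)         # bxmxn
assert np.allclose(total[0], grad_op[0] @ shuffled[0])       # per-batch matmul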
def cov(self):
    # \Sigma = L \dot L^T
    L = self.L
    if self.batched:
        return at.batched_dot(L, L.swapaxes(-1, -2))
    else:
        return L.dot(L.T)
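
# Quick NumPy check (not from the original source) that the non-batched branch
# returns the covariance implied by the factor L: Sigma = L @ L.T is symmetric
# and positive definite whenever L has a nonzero diagonal.
import numpy as np

L = np.array([[1.0, 0.0], [0.5, 2.0]])
cov = L @ L.T
assert np.allclose(cov, cov.T)                   # symmetric
assert np.all(np.linalg.eigvalsh(cov) > 0)       # positive definite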