def __init__(self, rbm, units_list, dimensions_list, W, name=None):
    """Register terms and gradients for a weight tensor linking two units.

    Args:
        rbm: forwarded unchanged to the parent constructor.
        units_list: exactly two units objects, ``[visibles, hiddens]``.
        dimensions_list: ``[vd, hd]``, the number of non-minibatch
            dimensions of the visible and hidden units respectively.
        W: the weight variable. It has ``vd + hd`` dimensions: the first
            ``vd`` are contracted against the visibles, the last ``hd``
            against the hiddens.
        name: optional name, forwarded to the parent constructor.
    """
    super(AdvancedProdParameters, self).__init__(rbm, units_list, name=name)
    assert len(units_list) == 2
    self.var = W
    self.variables = [self.var]
    self.vu = units_list[0]
    self.hu = units_list[1]
    self.vd = dimensions_list[0]
    self.hd = dimensions_list[1]
    self.vard = self.vd + self.hd

    # There are vd visible dimensions and hd hidden dimensions, meaning that
    # the weight tensor has vd + hd dimensions.
    # The hiddens and visibles have hd + 1 and vd + 1 dimensions respectively,
    # because the first dimension is reserved for minibatches!

    # NOTE: every range(...) below is wrapped in list(...). On Python 2 this
    # is a no-op copy; on Python 3 it is required because a range object
    # cannot be concatenated with a list.

    def visible_term(vmap):
        # Contract the non-minibatch axes of the hiddens with the trailing
        # (hidden) axes of W.
        hidden_axes = list(range(1, self.hd + 1))
        w_hidden_axes = list(range(self.vd, self.vard))
        return tensordot(vmap[self.hu], W, axes=(hidden_axes, w_hidden_axes))

    def hidden_term(vmap):
        # Contract the non-minibatch axes of the visibles with the leading
        # (visible) axes of W.
        visible_axes = list(range(1, self.vd + 1))
        w_visible_axes = list(range(0, self.vd))
        return tensordot(vmap[self.vu], W, axes=(visible_axes, w_visible_axes))

    self.terms[self.vu] = visible_term
    self.terms[self.hu] = hidden_term

    def gradient(vmap):
        # Pad both tensors with broadcastable ('x') axes so that their
        # elementwise product has layout
        # (minibatch, visible dims..., hidden dims...).
        v_indices = list(range(0, self.vd + 1)) + ['x'] * self.hd
        h_indices = [0] + ['x'] * self.vd + list(range(1, self.hd + 1))
        v_reshaped = vmap[self.vu].dimshuffle(v_indices)
        h_reshaped = vmap[self.hu].dimshuffle(h_indices)
        return v_reshaped * h_reshaped

    self.energy_gradients[self.var] = gradient
    # Summing the gradient over the minibatch is a contraction of the two
    # minibatch axes (axis 0 of each tensor).
    self.energy_gradient_sums[self.var] = lambda vmap: tensordot(
        vmap[self.vu], vmap[self.hu], axes=([0], [0]))
def energy_term(self, vmap):
    """Return the negated, summed contraction of the units with ``self.var``.

    Axes ``1..nd`` of ``vmap[self.u]`` are contracted with axes ``0..nd-1``
    of ``self.var``; the last ``self.sd`` axes of that result are then summed
    out and the sign is flipped.
    """
    # Why not pad self.var on the right with self.sd broadcastable axes
    # (T.shape_padright) and contract everything in a single tensordot?
    # That does not work because tensordot cannot handle broadcastable
    # dimensions. Instead, the dimensions that would have been broadcastable
    # are summed out afterwards, which comes down to the same thing.
    units = vmap[self.u]
    shared_var = self.var
    contraction_axes = (range(1, self.nd + 1), range(0, self.nd))
    contracted = tensordot(units, shared_var, axes=contraction_axes)

    # Sum over the trailing shared dimensions, which mimics the
    # broadcast + tensordot behaviour.
    first_shared_axis = contracted.ndim - self.sd
    shared_axes = range(first_shared_axis, contracted.ndim)
    return -T.sum(contracted, axis=shared_axes)
def energy_term(self, vmap):
    """Return the negated contraction of the hidden term with the hiddens.

    The hidden term (``self.terms[self.hu]``) is contracted with
    ``vmap[self.hu]`` over axes ``1..hd`` of both operands; axis 0 of each
    operand is not summed over.
    """
    # Reuse the already-registered hidden term instead of recomputing the
    # visibles-times-weights contraction here.
    hidden_input = self.terms[self.hu](vmap)
    hiddens = vmap[self.hu]

    # The first dimension is reserved for minibatches, so only the remaining
    # hd dimensions are contracted.
    # NOTE(review): if both operands carry the minibatch on axis 0, this
    # contraction presumably keeps *both* minibatch axes (an mb x mb result)
    # rather than one energy per example -- confirm this is intended.
    non_batch_axes = range(1, self.hd + 1)
    negative_energy = tensordot(
        hidden_input, hiddens, axes=(non_batch_axes, range(1, self.hd + 1)))

    # Energy is the negative of the contraction.
    return -negative_energy
def energy_term(self, vmap):
    """Return minus the contraction of the hidden term with the hiddens.

    ``self.terms[self.hu](vmap)`` and ``vmap[self.hu]`` are contracted over
    their axes ``1..hd``; axis 0 of each operand is left alone.
    """
    # The registered hidden term already holds the visibles contracted with
    # the weights, so it is reused rather than rebuilt.
    from_visibles = self.terms[self.hu](vmap)
    hidden_values = vmap[self.hu]

    # Axis 0 is reserved for minibatches and is therefore not summed over.
    # NOTE(review): with a minibatch axis on both operands, contracting only
    # axes 1..hd presumably yields an (mb x mb) result instead of a
    # per-example energy -- confirm against the callers.
    left_axes = range(1, self.hd + 1)
    right_axes = range(1, self.hd + 1)
    contraction = tensordot(
        from_visibles, hidden_values, axes=(left_axes, right_axes))

    # Flip the sign: the contraction is the negative energy.
    return -contraction
def term_u2(vmap):
    """Return the term for ``u2`` given the values of ``u0`` and ``u1``.

    ``vmap[self.u0]`` is contracted with the first axis of ``W``; the result
    is weighted elementwise by ``vmap[self.u1]`` and the ``u1`` axis is
    summed out.
    """
    # Contract axis 1 of u0 with axis 0 of W -> (mb, u1, u2).
    u0_times_w = tensordot(vmap[self.u0], W, axes=([1], [0]))
    # Append a broadcastable axis to u1 so it lines up with (mb, u1, u2).
    u1_broadcast = vmap[self.u1].dimshuffle(0, 1, 'x')
    weighted = u0_times_w * u1_broadcast
    return T.sum(weighted, axis=1)  # (mb, u2)
def term_u1(vmap):
    """Return the term for ``u1`` given the values of ``u0`` and ``u2``.

    ``vmap[self.u0]`` is contracted with the first axis of ``W``; the result
    is weighted elementwise by ``vmap[self.u2]`` and the ``u2`` axis is
    summed out.
    """
    # Contract axis 1 of u0 with axis 0 of W -> (mb, u1, u2).
    u0_times_w = tensordot(vmap[self.u0], W, axes=([1], [0]))
    # Insert a broadcastable middle axis into u2 so it lines up with
    # (mb, u1, u2).
    u2_broadcast = vmap[self.u2].dimshuffle(0, 'x', 1)
    weighted = u0_times_w * u2_broadcast
    return T.sum(weighted, axis=2)  # (mb, u1)
def energy_term(self, vmap):
    """Return the negated contraction of the units with ``self.var``.

    Axes ``1..ud`` of ``vmap[self.u]`` are contracted with axes ``0..ud-1``
    of ``self.var``. Axis 0 of the units is not contracted (presumably the
    minibatch axis -- TODO confirm).
    """
    units = vmap[self.u]
    variable = self.var
    unit_axes = range(1, self.ud + 1)
    variable_axes = range(0, self.ud)
    contraction = tensordot(units, variable, axes=(unit_axes, variable_axes))
    return -contraction
def term_u2(vmap):
    """Compute the ``u2`` term from the current ``u0`` and ``u1`` values.

    Contracts ``vmap[self.u0]`` with the leading axis of ``W``, multiplies
    by ``vmap[self.u1]`` (broadcast along the last axis) and sums over the
    ``u1`` axis.
    """
    projected = tensordot(vmap[self.u0], W, axes=([1], [0]))  # (mb, u1, u2)
    # u1 gets a trailing broadcastable axis to match the (mb, u1, u2) layout.
    u1_padded = vmap[self.u1].dimshuffle(0, 1, 'x')
    product = projected * u1_padded
    # Summing out axis 1 (u1) leaves (mb, u2).
    return T.sum(product, axis=1)
def term_u1(vmap):
    """Compute the ``u1`` term from the current ``u0`` and ``u2`` values.

    Contracts ``vmap[self.u0]`` with the leading axis of ``W``, multiplies
    by ``vmap[self.u2]`` (broadcast along the middle axis) and sums over the
    ``u2`` axis.
    """
    projected = tensordot(vmap[self.u0], W, axes=([1], [0]))  # (mb, u1, u2)
    # u2 gets a broadcastable middle axis to match the (mb, u1, u2) layout.
    u2_padded = vmap[self.u2].dimshuffle(0, 'x', 1)
    product = projected * u2_padded
    # Summing out axis 2 (u2) leaves (mb, u1).
    return T.sum(product, axis=2)