Esempio n. 1
0
    def doStep(w):
        """One stochastic gradient step on model params ``w`` and variational
        params ``v``, with a variance-reducing control variate on the
        variational gradient, followed by AdaGrad updates in place.

        Returns (copy of last minibatch's latent sample z,
        last minibatch's lower-bound contribution logpx + logpz - logqz).

        NOTE(review): relies on many enclosing-scope names (v, x, n_batch,
        n_subbatch, convertImgs, model_p, model_q, cv_cov, cv_var, cv_lr,
        gw_ss, gv_ss, ada_stepsize, nsteps, warmup) — confirm against the
        enclosing function.
        """
        grad = ndict.cloneZeros(v)  # accumulated d[LB]/d[v]
        gw = ndict.cloneZeros(w)    # accumulated d[LB]/d[w]

        # Total number of datapoints. Fix: dict.itervalues().next() was
        # removed in Python 3; also hoist this loop-invariant out of the
        # minibatch loop (it is needed after the loop as well).
        n_tot = next(iter(x.values())).shape[1]

        for _ in range(n_batch):
            idx_minibatch = np.random.randint(0, n_tot, n_subbatch)
            x_minibatch = {i: x[i][:, idx_minibatch] for i in x}
            if convertImgs:
                # Scale raw 8-bit image values to [0, 1).
                x_minibatch = {i: x_minibatch[i] / 256. for i in x_minibatch}

            # Use z ~ q(z|x) to compute d[LB]/d[gw]
            _, z, _ = model_q.gen_xz(v, x_minibatch, {}, n_subbatch)
            _, logpz_q = model_q.logpxz(v, x_minibatch, z)
            logpx_p, logpz_p, _gw, gz_p = model_p.dlogpxz_dwz(
                w, x_minibatch, z)
            for i in _gw:
                gw[i] += _gw[i]

            # Compute d[LB]/d[gv]  where gv = v (variational params)
            _, _, gv, _ = model_q.dlogpxz_dwz(v, x_minibatch, z)
            weight = np.sum(logpx_p) + np.sum(logpz_p) - np.sum(logpz_q)

            for i in v:
                f = gv[i] * weight
                h = gv[i]
                # Exponential running estimates of Cov(f, h) and Var(h)
                # for the control variate; 1e-8 guards the division.
                cv_cov[i] = cv_cov[i] + cv_lr * (f * h - cv_cov[i])
                cv_var[i] = cv_var[i] + cv_lr * (h**2 - cv_var[i])
                grad[i] += f - (cv_cov[i] / (cv_var[i] + 1e-8)) * h

        # Add the prior gradient, scaled by the fraction of data seen.
        _, gwprior = model_p.dlogpw_dw(w)
        for i in gw:
            gw[i] += float(n_subbatch * n_batch) / n_tot * gwprior[i]

        def optimize(_w, _gw, gw_ss, stepsize):
            # AdaGrad: accumulate squared gradients; only apply parameter
            # updates once past the warmup period.
            reg = 1e-8
            for i in _gw:
                gw_ss[i] += _gw[i]**2
                if nsteps[0] > warmup:
                    _w[i] += stepsize / np.sqrt(gw_ss[i] + reg) * _gw[i]

        optimize(w, gw, gw_ss, ada_stepsize)
        optimize(v, grad, gv_ss, ada_stepsize)

        nsteps[0] += 1

        # Fail fast on numerical blow-up, with descriptive messages
        # (the originals raised bare Exception()).
        if ndict.hasNaN(grad):
            raise Exception("doStep(): NaN found in variational gradient")
        if ndict.hasNaN(v):
            raise Exception("doStep(): NaN found in variational params")

        return z.copy(), logpx_p + logpz_p - logpz_q
    def checknan(self, v, w, gv, gw):
        """Raise if any gradient dict contains NaN, dumping the parameter
        and gradient dicts first to aid debugging.

        Fix: the original raised before the print statements, making the
        diagnostic dump (and the second raise) unreachable dead code.
        """
        if ndict.hasNaN(gv) or ndict.hasNaN(gw):
            print('v:')
            ndict.p(v)
            print('w:')
            ndict.p(w)
            print('gv:')
            ndict.p(gv)
            print('gw:')
            ndict.p(gw)
            raise Exception("dL_dw(): NaN found in gradients")
Esempio n. 3
0
    def doStep(w):
        """One stochastic gradient step on model params ``w`` and variational
        params ``v``, with a variance-reducing control variate on the
        variational gradient, followed by AdaGrad updates in place.

        Returns (copy of last minibatch's latent sample z,
        last minibatch's lower-bound contribution logpx + logpz - logqz).

        NOTE(review): relies on many enclosing-scope names (v, x, n_batch,
        n_subbatch, convertImgs, model_p, model_q, cv_cov, cv_var, cv_lr,
        gw_ss, gv_ss, ada_stepsize, nsteps, warmup) — confirm against the
        enclosing function.
        """
        grad = ndict.cloneZeros(v)  # accumulated d[LB]/d[v]
        gw = ndict.cloneZeros(w)    # accumulated d[LB]/d[w]

        # Total number of datapoints. Fix: dict.itervalues().next() was
        # removed in Python 3; also hoist this loop-invariant out of the
        # minibatch loop (it is needed after the loop as well).
        n_tot = next(iter(x.values())).shape[1]

        for _ in range(n_batch):
            idx_minibatch = np.random.randint(0, n_tot, n_subbatch)
            x_minibatch = {i: x[i][:, idx_minibatch] for i in x}
            if convertImgs:
                # Scale raw 8-bit image values to [0, 1).
                x_minibatch = {i: x_minibatch[i] / 256. for i in x_minibatch}

            # Use z ~ q(z|x) to compute d[LB]/d[gw]
            _, z, _ = model_q.gen_xz(v, x_minibatch, {}, n_subbatch)
            _, logpz_q = model_q.logpxz(v, x_minibatch, z)
            logpx_p, logpz_p, _gw, gz_p = model_p.dlogpxz_dwz(
                w, x_minibatch, z)
            for i in _gw:
                gw[i] += _gw[i]

            # Compute d[LB]/d[gv]  where gv = v (variational params)
            _, _, gv, _ = model_q.dlogpxz_dwz(v, x_minibatch, z)
            weight = np.sum(logpx_p) + np.sum(logpz_p) - np.sum(logpz_q)

            for i in v:
                f = gv[i] * weight
                h = gv[i]
                # Exponential running estimates of Cov(f, h) and Var(h)
                # for the control variate; 1e-8 guards the division.
                cv_cov[i] = cv_cov[i] + cv_lr * (f * h - cv_cov[i])
                cv_var[i] = cv_var[i] + cv_lr * (h**2 - cv_var[i])
                grad[i] += f - (cv_cov[i] / (cv_var[i] + 1e-8)) * h

        # Add the prior gradient, scaled by the fraction of data seen.
        _, gwprior = model_p.dlogpw_dw(w)
        for i in gw:
            gw[i] += float(n_subbatch * n_batch) / n_tot * gwprior[i]

        def optimize(_w, _gw, gw_ss, stepsize):
            # AdaGrad: accumulate squared gradients; only apply parameter
            # updates once past the warmup period.
            reg = 1e-8
            for i in _gw:
                gw_ss[i] += _gw[i]**2
                if nsteps[0] > warmup:
                    _w[i] += stepsize / np.sqrt(gw_ss[i] + reg) * _gw[i]

        optimize(w, gw, gw_ss, ada_stepsize)
        optimize(v, grad, gv_ss, ada_stepsize)

        nsteps[0] += 1

        # Fail fast on numerical blow-up, with descriptive messages
        # (the originals raised bare Exception()).
        if ndict.hasNaN(grad):
            raise Exception("doStep(): NaN found in variational gradient")
        if ndict.hasNaN(v):
            raise Exception("doStep(): NaN found in variational params")

        return z.copy(), logpx_p + logpz_p - logpz_q
Esempio n. 4
0
    def dfd_dw(self, w, x, z, gz2):
        """Compute logpx, logpz, fd and d[fd]/d[w] via the compiled
        Theano function ``self.f_dfd_dw``.

        Fix: the NaN branch raised a bare Exception before its diagnostic
        dump, leaving the dump and a NaN-zeroing loop unreachable; dump
        the offending values first, then raise with a descriptive message.
        """
        x, z = self.xz_to_theano(x, z)
        w, z, x, gz2 = ndict.ordereddicts((w, z, x, gz2))
        A = self.get_A(x)
        # Argument order must match the compiled function: w, x, z, A, gz2.
        r = self.f_dfd_dw(*(list(w.values()) + list(x.values()) +
                            list(z.values()) + [A] + list(gz2.values())))
        # r = [logpx, logpz, fd, dfd/dw_1 .. dfd/dw_k] in key order of w.
        logpx, logpz, fd = r[0], r[1], r[2]
        gw = dict(zip(list(w.keys()), r[3:3 + len(w)]))

        if ndict.hasNaN(gw):
            print('fd: ', fd)
            print('Values:')
            ndict.p(w)
            ndict.p(z)
            print('Gradients:')
            ndict.p(gw)
            raise Exception("dfd_dw(): NaN found in gradients")

        gw, _ = self.gwgz_to_numpy(gw, {})
        return logpx, logpz, fd, gw
Esempio n. 5
0
	def gen_xz(self, w, x, z, n_batch):
		"""Generate (or complete) a sample through the network.

		Missing entries are filled in: x['x'] is sampled as Bernoulli(0.5)
		noise when absent, and z['eps0'] ~ N(mean, exp(logvar/2)) when
		absent. Returns (x, z, {'mean': ..., 'logvar': ...}).

		Fix: dict.has_key() was removed in Python 3; the `in` operator is
		equivalent and also works in Python 2.
		"""
		A = np.ones((1, n_batch))  # all-ones row used to broadcast biases

		if 'x' not in x:
			x['x'] = np.random.binomial(1, 0.5, size=(self.n_units[0], n_batch))

		# Elementwise nonlinearity selected by self.nonlinear.
		def f_sigmoid(x): return 1./(1.+np.exp(-x))
		def f_softplus(x): return np.log(np.exp(x) + 1)# - np.log(2)
		nonlinear = {'tanh': np.tanh, 'sigmoid': f_sigmoid,'softplus': f_softplus}[self.nonlinear]

		# Forward pass through the hidden layers (bias applied via np.dot
		# with the all-ones row A).
		hidden = []
		hidden.append(x['x'])
		for i in range(1, len(self.n_units)-1):
			hidden.append(nonlinear(np.dot(w['w_%i'%i], hidden[i-1]) + np.dot(w['b_%i'%i], A)))

		mean = np.dot(w['w_mean'], hidden[-1]) + self.logmeanb_factor * np.dot(w['b_mean'], A)
		logvar = self.logvar_const + self.logvar_factor * np.dot(w['w_logvar'], hidden[-1]) + np.dot(w['b_logvar'], A)

		if 'eps0' not in z:
			z['eps0'] = np.random.normal(mean, np.exp(logvar/2))

		if ndict.hasNaN(z):
			raise Exception("NaN detected")

		_z = {'mean':mean, 'logvar':logvar}
		return x, z, _z
Esempio n. 6
0
 def checknan(self, v, w, gv, gw):
     """Raise if any gradient dict contains NaN, dumping the parameter
     and gradient dicts first to aid debugging.

     Fixes: the original raised before the print statements, making the
     diagnostic dump unreachable dead code; Python-2-only print
     statements converted to the print() function (single-argument form,
     valid in both Python 2 and 3).
     """
     if ndict.hasNaN(gv) or ndict.hasNaN(gw):
         print('v:')
         ndict.p(v)
         print('w:')
         ndict.p(w)
         print('gv:')
         ndict.p(gv)
         print('gw:')
         ndict.p(gw)
         raise Exception("dL_dw(): NaN found in gradients")
Esempio n. 7
0
    def checknan(self, v, w, gv, gw):
        """Raise if any gradient dict contains NaN, dumping the parameter
        and gradient dicts first to aid debugging.

        Fixes: the original raised before the print statements, making the
        diagnostic dump unreachable dead code; Python-2-only print
        statements converted to the print() function (single-argument
        form, valid in both Python 2 and 3).
        """
        if ndict.hasNaN(gv) or ndict.hasNaN(gw):
            print('v:')
            ndict.p(v)
            print('w:')
            ndict.p(w)
            print('gv:')
            ndict.p(gv)
            print('gw:')
            ndict.p(gw)
            raise Exception("dL_dw(): NaN found in gradients")
Esempio n. 8
0
    def dlogpxz_dwz(self, w, x, z):
        """Compute logpx, logpz and the gradients of the log-joint w.r.t.
        w and z via the compiled Theano function ``self.f_dlogpxz_dwz``.

        Raises Exception when the input dict keys do not match
        ``self.allvars_keys`` or when NaNs appear in the gradients.

        Fixes: the key-mismatch message was a bare no-op string
        expression — attach it to the raised exception; the NaN branch
        raised before its diagnostic dump, leaving the dump (and a
        NaN-zeroing loop) unreachable — dump first, then raise.
        """
        x, z = self.xz_to_theano(x, z)
        w, z, x = ndict.ordereddicts((w, z, x))
        A = self.get_A(x)
        allvars = list(w.values()) + list(x.values()) + list(z.values()) + [A]

        # Check if keys are correct
        keys = list(w.keys()) + list(x.keys()) + list(z.keys()) + ['A']
        for i in range(len(keys)):
            if keys[i] != self.allvars_keys[i]:
                print('Input:', keys)
                print('Should be:', self.allvars_keys)
                raise Exception("Input values are incorrect!")

        r = self.f_dlogpxz_dwz(*allvars)
        # r = [logpx, logpz, dlogp/dw_1..k, dlogp/dz_1..m] in key order.
        logpx, logpz = r[0], r[1]
        gw = dict(zip(list(w.keys()), r[2:2 + len(w)]))
        gz = dict(zip(list(z.keys()), r[2 + len(w):]))

        if ndict.hasNaN(gw) or ndict.hasNaN(gz):
            print('logpx: ', logpx)
            print('logpz: ', logpz)
            print('Values:')
            ndict.p(w)
            ndict.p(z)
            print('Gradients:')
            ndict.p(gw)
            ndict.p(gz)
            raise Exception("dlogpxz_dwz(): NaN found in gradients")

        gw, gz = self.gwgz_to_numpy(gw, gz)
        return logpx, logpz, gw, gz
Esempio n. 9
0
 def dlogpxz_dwz(self, w, x, z):
     """Compute logpx, logpz and the gradients of the log-joint w.r.t.
     w and z via the compiled Theano function ``self.f_dlogpxz_dwz``.

     Raises Exception when the input dict keys do not match
     ``self.allvars_keys`` or when NaNs appear in the gradients.

     Fixes: Python-2-only syntax (print statements, direct concatenation
     of dict views) replaced with Python-3-compatible equivalents; the
     key-mismatch message was a bare no-op string expression — attach it
     to the raised exception; the NaN branch raised before its
     diagnostic dump, leaving the dump unreachable — dump first, then
     raise.
     """
     x, z = self.xz_to_theano(x, z)
     w, z, x = ndict.ordereddicts((w, z, x))
     A = self.get_A(x)
     allvars = list(w.values()) + list(x.values()) + list(z.values()) + [A]

     # Check if keys are correct
     keys = list(w.keys()) + list(x.keys()) + list(z.keys()) + ['A']
     for i in range(len(keys)):
         if keys[i] != self.allvars_keys[i]:
             print('Input:', keys)
             print('Should be:', self.allvars_keys)
             raise Exception("Input values are incorrect!")

     r = self.f_dlogpxz_dwz(*allvars)
     # r = [logpx, logpz, dlogp/dw_1..k, dlogp/dz_1..m] in key order.
     logpx, logpz = r[0], r[1]
     gw = dict(zip(list(w.keys()), r[2:2 + len(w)]))
     gz = dict(zip(list(z.keys()), r[2 + len(w):]))

     if ndict.hasNaN(gw) or ndict.hasNaN(gz):
         print('logpx: ', logpx)
         print('logpz: ', logpz)
         print('Values:')
         ndict.p(w)
         ndict.p(z)
         print('Gradients:')
         ndict.p(gw)
         ndict.p(gz)
         raise Exception("dlogpxz_dwz(): NaN found in gradients")

     gw, gz = self.gwgz_to_numpy(gw, gz)
     return logpx, logpz, gw, gz
Esempio n. 10
0
 def dfd_dw(self, w, x, z, gz2):
     """Compute logpx, logpz, fd and d[fd]/d[w] via the compiled Theano
     function ``self.f_dfd_dw``.

     Fixes: Python-2-only syntax (print statements, direct concatenation
     of dict views) replaced with Python-3-compatible equivalents; the
     NaN branch raised a bare Exception before its diagnostic dump,
     leaving the dump (and a NaN-zeroing loop) unreachable — dump the
     offending values first, then raise with a descriptive message.
     """
     x, z = self.xz_to_theano(x, z)
     w, z, x, gz2 = ndict.ordereddicts((w, z, x, gz2))
     A = self.get_A(x)
     # Argument order must match the compiled function: w, x, z, A, gz2.
     r = self.f_dfd_dw(*(list(w.values()) + list(x.values()) +
                         list(z.values()) + [A] + list(gz2.values())))
     # r = [logpx, logpz, fd, dfd/dw_1 .. dfd/dw_k] in key order of w.
     logpx, logpz, fd = r[0], r[1], r[2]
     gw = dict(zip(list(w.keys()), r[3:3 + len(w)]))

     if ndict.hasNaN(gw):
         print('fd: ', fd)
         print('Values:')
         ndict.p(w)
         ndict.p(z)
         print('Gradients:')
         ndict.p(gw)
         raise Exception("dfd_dw(): NaN found in gradients")

     gw, _ = self.gwgz_to_numpy(gw, {})
     return logpx, logpz, fd, gw