コード例 #1
0
def main(_):
    """
    Run the script in the mode selected by ``params.mode``.

    - ``'generate'``: call ``generate_fake_data`` to synthesize a dataset under
      ``data_path``, also writing a params file and saving figures.
    - ``'train'``: load ``Ytrain``/``Yvalid`` from the pickled data file, build
      an ``Optimizer_TS`` in a fresh graph and session, and train it, writing
      results to a timestamped directory under ``params.local_rlt_dir``.

    Args:
        _: Unused; presumably the argv list passed by ``tf.app.run`` — confirm.
    """
    data_path = params.local_data_dir + params.this_data_dir
    rlt_dir = params.local_rlt_dir + params.this_data_dir + addDateTime() + '/'
    if params.mode == 'generate':
        generate_fake_data(lat_mod_class=params.lat_mod_class,
                           gen_mod_class=params.gen_mod_class,
                           params=params,
                           data_path=data_path,
                           save_data_file=params.save_data_file,
                           Nsamps=params.genNsamps,
                           NTbins=params.genNTbins,
                           write_params_file=True,
                           draw_quiver=True,
                           draw_heat_maps=True,
                           savefigs=True)
    elif params.mode == 'train':
        graph = tf.Graph()
        with graph.as_default():
            # Using the session as a context manager installs it as the default
            # session (like sess.as_default()) and closes it when training is
            # done, releasing its resources.
            with tf.Session(graph=graph) as sess:
                # Opened read-only: the data file is never written, so write
                # permission must not be required.
                # NOTE(review): pickle can execute arbitrary code on load; only
                # point this at trusted data files.
                with open(data_path + params.save_data_file, 'rb') as f:
                    if params.is_py2:
                        # Python 2 pickled data needs encoding='latin1'.
                        datadict = pickle.load(f, encoding='latin1')
                    else:
                        datadict = pickle.load(f)
                Ytrain = datadict['Ytrain']
                Yvalid = datadict['Yvalid']

                # The observation dimension is the size of the last axis.
                params.yDim = Ytrain.shape[-1]

                # exist_ok avoids the check-then-create race of exists()+makedirs().
                os.makedirs(rlt_dir, exist_ok=True)
                write_option_file(rlt_dir)

                opt = Optimizer_TS(params)

                sess.run(tf.global_variables_initializer())
                opt.train(sess, rlt_dir, Ytrain, Yvalid)
コード例 #2
0
class OptimizerTest(tf.test.TestCase):
    """
    Smoke test for training an ``Optimizer_TS`` on the data in ``DATA_FILE``.

    The graph, session and optimizer are built once, when the class body is
    executed, and are shared by the test methods through class attributes.
    """
    graph = tf.Graph()
    with graph.as_default():
        sess = tf.Session(graph=graph)
        with sess.as_default():
            opt = Optimizer_TS(params)
            sess.run(tf.global_variables_initializer())

    def test_train(self):
        """
        Print recognition-model outputs and the validation cost, run
        ``opt.train`` on the training data, then print the validation cost
        again. This test only prints values; it makes no assertions.
        """
        sess = self.sess
        # Opened read-only: the data file is never written.
        with open(DATA_FILE, 'rb') as f:
            # encoding='latin1' allows loading Python 2 pickled data.
            datadict = pickle.load(f, encoding='latin1')
            Ydata = datadict['Ytrain']
            Yvalid = datadict['Yvalid']
        with sess.as_default():
            X = sess.run(self.opt.mrec.Mu_NxTxd, feed_dict={'VAEC/Y:0': Ydata})
            Xv = sess.run(self.opt.mrec.Mu_NxTxd,
                          feed_dict={'VAEC/Y:0': Yvalid})
            print(len(X), len(Xv))
            pX = sess.run(self.opt.mrec.postX,
                          feed_dict={
                              'VAEC/Y:0': Ydata,
                              'VAEC/X:0': X
                          })
            pXv = sess.run(self.opt.mrec.postX,
                           feed_dict={
                               'VAEC/Y:0': Yvalid,
                               'VAEC/X:0': Xv
                           })
            print(len(pX), len(pXv))
            # Validation cost before training, evaluated at the recognition mean.
            new_valid_cost = sess.run(self.opt.cost,
                                      feed_dict={
                                          'VAEC/X:0': Xv,
                                          'VAEC/Y:0': Yvalid
                                      })
            print(new_valid_cost)
            self.opt.train(sess, RLT_DIR, Ydata)
            # Same validation cost after training, for comparison.
            print(
                sess.run(self.opt.cost,
                         feed_dict={
                             'VAEC/X:0': Xv,
                             'VAEC/Y:0': Yvalid
                         }))
コード例 #3
0
class DataTests(tf.test.TestCase):
    """
    Diagnostic tests that print statistics of the data and of the initial
    (untrained) model outputs, to help judge whether value ranges are sensible.

    The data, graph, session and model handles are loaded/built once, when the
    class body is executed, and are shared by all test methods through class
    attributes.
    """
    # Opened read-only: the data file is never written.
    with open(DATA_FILE, 'rb') as f:
        datadict = pickle.load(f, encoding='latin1') # `encoding='latin1'` for python 2 pickled data
        # NOTE(review): the data are scaled by 2 here; the reason is not
        # visible in this file — confirm it is intentional.
        Ydata = 2*datadict['Ytrain']
        Yvalid = 2*datadict['Yvalid']

    # Size of axis 2 of Ydata; presumably the observation dimension of
    # (N, T, D)-shaped data — confirm.
    yDim = Ydata.shape[2]

    graph = tf.Graph()
    with graph.as_default():
        sess = tf.Session(graph=graph)
        with sess.as_default():
            opt = Optimizer_TS(params)
            mrec = opt.mrec
            mlat = opt.lat_ev_model
            mgen = opt.mgen

            sess.run(tf.global_variables_initializer())

    def test_logdensity(self):
        """
        Print the cost evaluated at the recognition mean MuX, then the cost
        and ``checks1`` evaluated at ``postX``.
        """
        sess = self.sess
        with sess.as_default():
            MuX = sess.run(self.mrec.Mu_NxTxd, feed_dict={'VAEC/Y:0' : self.Ydata})
            cost = sess.run(self.opt.cost, feed_dict={'VAEC/Y:0' : self.Ydata,
                                                      'VAEC/X:0' : MuX})
            print('cost:', cost)
            postX = sess.run(self.mrec.postX_NxTxd, feed_dict={'VAEC/Y:0' : self.Ydata,
                                                         'VAEC/X:0' : MuX})
            cost = sess.run(self.opt.cost, feed_dict={'VAEC/Y:0' : self.Ydata,
                                                      'VAEC/X:0' : postX})
            checks = sess.run(self.opt.checks1, feed_dict={'VAEC/Y:0' : self.Ydata,
                                                          'VAEC/X:0' : postX})
            print('cost:', cost)
            print('checks1', checks)

    def test_postX(self):
        """
        Print statistics of ``postX`` (training and validation), of the
        generative mean ``MuY`` evaluated at ``postX``, of the inverse
        observation covariance ``SigmaInvY``, and of the costs at the
        recognition means.

        Only meaningful when ``params.gen_mod_class == 'Gaussian'``, because it
        reads ``mgen.SigmaInvY_DxD``. Otherwise the test is reported as skipped
        rather than silently passing.
        """
        if params.gen_mod_class != 'Gaussian':
            self.skipTest("test_postX requires params.gen_mod_class == 'Gaussian'")
        sess = self.sess
        with sess.as_default():
            MuX = sess.run(self.mrec.Mu_NxTxd, feed_dict={'VAEC/Y:0' : self.Ydata})
            MuX_valid = sess.run(self.mrec.Mu_NxTxd, feed_dict={'VAEC/Y:0' : self.Yvalid})
            postX = sess.run(self.mrec.postX_NxTxd, feed_dict={'VAEC/Y:0' : self.Ydata,
                                                         'VAEC/X:0' : MuX})
            postX_valid = sess.run(self.mrec.postX_NxTxd, feed_dict={'VAEC/Y:0' : self.Yvalid,
                                                               'VAEC/X:0' : MuX_valid})
            Yprime = sess.run(self.mgen.MuY_NxTxD, feed_dict={'VAEC/X:0' : postX})
            SigmaInvY = sess.run(self.mgen.SigmaInvY_DxD)

            cost = sess.run(self.opt.cost, feed_dict={'VAEC/X:0' : MuX,
                                                      'VAEC/Y:0' : self.Ydata})
            checks = sess.run(self.opt.checks1, feed_dict={'VAEC/X:0' : MuX,
                                                          'VAEC/Y:0' : self.Ydata})
            new_valid_cost = sess.run(self.opt.cost, feed_dict={'VAEC/X:0' : MuX_valid,
                                                            'VAEC/Y:0' : self.Yvalid})
            checks_v = sess.run(self.opt.checks1, feed_dict={'VAEC/X:0' : MuX_valid,
                                                          'VAEC/Y:0' : self.Yvalid})

            print('Yprime (mean, std)', np.mean(Yprime), np.std(Yprime))
            mins, maxs = np.min(Yprime, axis=(0,1)), np.max(Yprime, axis=(0,1))
            print('Yprime ranges', list(zip(mins, maxs)))
            # The matrix fetched is the *inverse* covariance, so label it as such.
            print('\nSigmaInvY[0,0]', SigmaInvY[0,0])
            print('det(SigmaInvY)', np.linalg.det(SigmaInvY))

            print('\npostX (mean, std)', np.mean(postX), np.std(postX))

            mins, maxs = np.min(postX, axis=(0,1)), np.max(postX, axis=(0,1))
            print('postX ranges', list(zip(mins, maxs)))
            print('postX_valid (mean, std)', np.mean(postX_valid), np.std(postX_valid))
            mins, maxs = np.min(postX_valid, axis=(0,1)), np.max(postX_valid, axis=(0,1))
            print('postX_valid ranges', list(zip(mins, maxs)))

            print('cost', cost, checks)
            print('valid cost', new_valid_cost, checks_v)
            print('')

    def test_data(self):
        """Print the mean, standard deviation and range of the training data."""
        print('Y (mean, std):', np.mean(self.Ydata), np.std(self.Ydata))
        print('Y range:', np.min(self.Ydata), np.max(self.Ydata))
        print('')

    def test_inferredX_range(self):
        """
        This computes the initial ranges for the values of the latent-space
        variables inferred by the recognition network for your data. Reasonable
        values per dimension are -30 <~ min(MuX) <~ max(MuX) < 30.
        """
        sess = self.sess
        with sess.as_default():
            MuX = sess.run(self.mrec.Mu_NxTxd, feed_dict={'VAEC/Y:0' : self.Ydata})
            print('MuX (mean, std)', np.mean(MuX), np.std(MuX))
            mins, maxs = np.min(MuX, axis=(0,1)), np.max(MuX, axis=(0,1))
            print('MuX ranges', list(zip(mins, maxs)))
            print('')

    def test_inferredLambdaX_range(self):
        """
        This computes the initial ranges for the values of the latent-space
        precision as yielded by the recognition network for your data. Reasonable
        values per LambdaX entry L are -3 < min(L) < max(L) < 3.
        """
        sess = self.sess
        with sess.as_default():
            LambdaX = sess.run(self.mrec.Lambda_NxTxdxd, feed_dict={'VAEC/Y:0' : self.Ydata})
            print('LambdaX (mean, std)', np.mean(LambdaX), np.std(LambdaX))
            mins, maxs = np.min(LambdaX, axis=(0,1)).flatten(), np.max(LambdaX, axis=(0,1)).flatten()
            print('LambdaX ranges', list(zip(mins, maxs)))
            print('')

    def test_nonlinearity_range(self):
        """
        The average and max values of the nonlinearity alpha*B should be <~
        o(10^-1). That is smaller than 1, yet sizable. This depends on the
        nonlinearity network and on the range of the values in the latent space
        that the recognition network yields.
        """
        sess = self.sess
        with sess.as_default():
            MuX = sess.run(self.mrec.Mu_NxTxd, feed_dict={'VAEC/Y:0' : self.Ydata})
            alphaB = sess.run(self.mlat.alpha*self.mlat.B_NxTxdxd,
                              feed_dict={'VAEC/X:0' : MuX})
            print('alphaB (mean, std)', np.mean(alphaB), np.std(alphaB))
            mins, maxs = np.min(alphaB, axis=(0,1)).flatten(), np.max(alphaB, axis=(0,1)).flatten()
            print('alphaB ranges', list(zip(mins, maxs)))