def make_rbm2(Q, args):
    """Load RBM #2 from `args.rbm2_dirpath`, or train it on `Q` if absent.

    When the directory `args.rbm2_dirpath` exists, the model is loaded from
    it and returned. Otherwise a new `BernoulliRBM` is built, fitted on `Q`
    and returned.

    Training uses a staged schedule. With ``n_stages = epochs // n_every``:

    * the number of Gibbs steps starts at ``args.n_gibbs_steps[1]`` and grows
      by one per stage;
    * the learning rate at stage ``k`` (1-based) is ``args.lr[1] / k``;
    * every stage value is repeated for ``n_every`` consecutive entries.

    Parameters
    ----------
    Q : array-like
        Training data handed to `BernoulliRBM.fit` unchanged.
        NOTE(review): presumably the hidden representation produced by
        RBM #1 (n_visible is `args.n_hiddens[0]`) -- confirm with the caller.
    args : namespace-like
        Reads `rbm2_dirpath`, `increase_n_gibbs_steps_every`, `n_hiddens[0]`,
        `n_hiddens[1]`, and index 1 of `epochs`, `n_gibbs_steps`, `lr`,
        `batch_size`, `l2` and `random_seed`.

    Returns
    -------
    BernoulliRBM
        The loaded or freshly fitted model.
    """
    if os.path.isdir(args.rbm2_dirpath):
        # print(...) with one string argument behaves the same on Python 2 and 3.
        print("\nLoading RBM #2 ...\n\n")
        return BernoulliRBM.load_model(args.rbm2_dirpath)

    print("\nTraining RBM #2 ...\n\n")

    epochs = args.epochs[1]
    n_every = args.increase_n_gibbs_steps_every

    # Floor division keeps the stage count an integer on Python 3 as well;
    # true division would turn the Gibbs-step schedule into floats.
    n_stages = epochs // n_every

    n_gibbs_steps = np.arange(args.n_gibbs_steps[1],
                              args.n_gibbs_steps[1] + n_stages)
    learning_rate = args.lr[1] / np.arange(1, 1 + n_stages)

    # Hold each stage's value for `n_every` consecutive entries.
    n_gibbs_steps = np.repeat(n_gibbs_steps, n_every)
    learning_rate = np.repeat(learning_rate, n_every)

    rbm2 = BernoulliRBM(
        n_visible=args.n_hiddens[0],
        n_hidden=args.n_hiddens[1],
        W_init=0.005,
        vb_init=0.,
        hb_init=0.,
        n_gibbs_steps=n_gibbs_steps,
        learning_rate=learning_rate,
        momentum=[0.5] * 5 + [0.9],
        # Train for at least `n_every` epochs even if fewer were requested.
        max_epoch=max(args.epochs[1], n_every),
        batch_size=args.batch_size[1],
        l2=args.l2[1],
        sample_h_states=True,
        sample_v_states=True,
        sparsity_cost=0.,
        dbm_last=True,  # !!!
        metrics_config=dict(
            msre=True,
            pll=True,
            train_metrics_every_iter=500,
        ),
        verbose=True,
        display_filters=0,
        display_hidden_activations=24,
        random_seed=args.random_seed[1],
        dtype='float32',
        tf_saver_params=dict(max_to_keep=1),
        model_path=args.rbm2_dirpath)
    rbm2.fit(Q)
    return rbm2
def make_rbm(X_train, X_val, args):
    """Return a Bernoulli RBM, loaded from disk or trained from scratch.

    If `args.model_dirpath` is an existing directory the model stored there
    is loaded and returned. Otherwise a `BernoulliRBM` with 784 visible units
    is configured from `args`, fitted on `X_train` with `X_val` as the
    validation set, and returned.

    Parameters
    ----------
    X_train : array-like
        Training data; also fed to `logit_mean` to initialise the visible
        bias when `args.vb_init` is truthy.
    X_val : array-like
        Validation data passed as the second argument of `fit`.
    args : namespace-like
        Hyper-parameters and paths read as attributes.

    Returns
    -------
    BernoulliRBM
    """
    # Guard clause: reuse a previously saved model when one is available.
    if os.path.isdir(args.model_dirpath):
        print("\nLoading model ...\n\n")
        return BernoulliRBM.load_model(args.model_dirpath)

    print("\nTraining model ...\n\n")

    # Visible bias: data-dependent initialisation only when requested.
    if args.vb_init:
        visible_bias = logit_mean(X_train)
    else:
        visible_bias = 0.

    metrics = {
        'msre': True,
        'pll': True,
        'feg': True,
        'train_metrics_every_iter': 1000,
        'val_metrics_every_epoch': 2,
        'feg_every_epoch': 4,
        'n_batches_for_feg': 50,
    }

    model = BernoulliRBM(
        n_visible=784,
        n_hidden=args.n_hidden,
        W_init=args.w_init,
        vb_init=visible_bias,
        hb_init=args.hb_init,
        n_gibbs_steps=args.n_gibbs_steps,
        learning_rate=args.lr,
        # Momentum values spaced geometrically from 0.5 to 0.9 (8 values).
        momentum=np.geomspace(0.5, 0.9, 8),
        max_epoch=args.epochs,
        batch_size=args.batch_size,
        l2=args.l2,
        sample_v_states=args.sample_v_states,
        sample_h_states=True,
        dropout=args.dropout,
        sparsity_target=args.sparsity_target,
        sparsity_cost=args.sparsity_cost,
        sparsity_damping=args.sparsity_damping,
        metrics_config=metrics,
        verbose=True,
        display_filters=30,
        display_hidden_activations=24,
        v_shape=(28, 28),
        random_seed=args.random_seed,
        dtype=args.dtype,
        tf_saver_params={'max_to_keep': 1},
        model_path=args.model_dirpath,
    )
    model.fit(X_train, X_val)
    return model
def test_consistency_val(self):
    """Fit two identically configured RBMs with validation data and compare them.

    Both models share `self.rbm_config` and `max_epoch=2` and differ only in
    `model_path`. After fitting on `(self.X, self.X_val)` their weights and
    transforms are checked via `compare_weights` / `compare_transforms`.
    """
    model_dirs = ('test_rbm_1/', 'test_rbm_2/')

    # Construct both models up front, then fit them one after the other.
    models = [BernoulliRBM(max_epoch=2, model_path=model_dir, **self.rbm_config)
              for model_dir in model_dirs]
    for model in models:
        model.fit(self.X, self.X_val)

    first, second = models
    self.compare_weights(first, second)
    self.compare_transforms(first, second)

    # cleanup
    self.cleanup()
def make_rbm1(X, args):
    """Load RBM #1 from `args.rbm1_dirpath`, or train it on `X` if absent.

    When the directory `args.rbm1_dirpath` exists, the model is loaded from
    it and returned. Otherwise a new `BernoulliRBM` with 784 visible units
    and `args.n_hiddens[0]` hidden units is built, fitted on `X` and
    returned.

    Parameters
    ----------
    X : array-like
        Training data handed to `BernoulliRBM.fit` unchanged.
    args : namespace-like
        Reads `rbm1_dirpath`, `n_hiddens[0]`, and index 0 of `n_gibbs_steps`,
        `lr`, `epochs`, `batch_size`, `l2` and `random_seed`.

    Returns
    -------
    BernoulliRBM
        The loaded or freshly fitted model.
    """
    if os.path.isdir(args.rbm1_dirpath):
        # print(...) with one string argument behaves the same on Python 2 and 3.
        print("\nLoading RBM #1 ...\n\n")
        return BernoulliRBM.load_model(args.rbm1_dirpath)

    print("\nTraining RBM #1 ...\n\n")
    rbm1 = BernoulliRBM(
        n_visible=784,
        n_hidden=args.n_hiddens[0],
        W_init=0.001,
        vb_init=0.,
        hb_init=0.,
        n_gibbs_steps=args.n_gibbs_steps[0],
        learning_rate=args.lr[0],
        momentum=[0.5] * 5 + [0.9],
        max_epoch=args.epochs[0],
        batch_size=args.batch_size[0],
        l2=args.l2[0],
        sample_h_states=True,
        sample_v_states=True,
        sparsity_cost=0.,
        dbm_first=True,  # !!!
        metrics_config=dict(
            msre=True,
            pll=True,
            train_metrics_every_iter=500,
        ),
        verbose=True,
        display_filters=30,
        display_hidden_activations=24,
        v_shape=(28, 28),
        random_seed=args.random_seed[0],
        dtype='float32',
        tf_saver_params=dict(max_to_keep=1),
        model_path=args.rbm1_dirpath)
    rbm1.fit(X)
    return rbm1
from bm.utils.dataset import load_mnist

# Load the MNIST train and test splits and divide the inputs by 255 in place.
# NOTE(review): presumably rescales 0-255 pixel intensities to [0, 1], which
# requires load_mnist to return a float array -- confirm.
X, y = load_mnist(mode='train', path='../data/')
X /= 255.
X_test, y_test = load_mnist(mode='test', path='../data/')
X_test /= 255.
print(X.shape, y.shape, X_test.shape, y_test.shape)

# Plot the first 100 training rows as 28x28 grayscale images and save the figure.
fig = plt.figure(figsize=(10, 10))
im_plot(X[:100], shape=(28, 28), title='Training examples',
        imshow_params={'cmap': plt.cm.gray})
plt.savefig('mnist.png', dpi=196, bbox_inches='tight')

# Load the saved RBM #1 and fetch its weight matrix 'W'.
rbm1 = BernoulliRBM.load_model('../models/dbm_mnist_rbm1/')
rbm1_W = rbm1.get_tf_params(scope='weights')['W']

# Plot the transposed weights as 28x28 images.
# NOTE(review): this assumes each column of W is one filter over the 784
# inputs -- confirm against the layout returned by get_tf_params.
fig = plt.figure(figsize=(10, 10))
im_plot(rbm1_W.T, shape=(28, 28), title='First 100 filters extracted by RBM #1',
        imshow_params={'cmap': plt.cm.gray})
plt.savefig('dbm_mnist_rbm1.png', dpi=196, bbox_inches='tight')

# Load the saved RBM #2 and fetch its weight matrix 'W'.
rbm2 = BernoulliRBM.load_model('../models/dbm_mnist_rbm2/')
rbm2_W = rbm2.get_tf_params(scope='weights')['W']

# Multiply the two weight matrices.
# NOTE(review): presumably this expresses RBM #2's weights in RBM #1's input
# space for plotting -- the plot itself lies beyond this span.
U = rbm1_W.dot(rbm2_W)

# New figure for whatever is drawn next (not visible in this span).
fig = plt.figure(figsize=(10, 10))
shape=(28, 28), title='Training examples', imshow_params={'cmap': plt.cm.gray}) plt.savefig('mnist.png', dpi=196, bbox_inches='tight') rbm1 = GaussianRBM.load_model('../models/dbm_mnist_gauss_rbm1/') rbm1_W = rbm1.get_tf_params(scope='weights')['W'] fig = plt.figure(figsize=(10, 10)) im_plot(rbm1_W.T, shape=(28, 28), title='First 100 filters extracted by RBM #1', imshow_params={'cmap': plt.cm.gray}) plt.savefig('dbm_mnist_gauss_rbm1.png', dpi=196, bbox_inches='tight') rbm2 = BernoulliRBM.load_model('../models/dbm_mnist_gauss_rbm2/') rbm2_W = rbm2.get_tf_params(scope='weights')['W'] U = rbm1_W.dot(rbm2_W) fig = plt.figure(figsize=(10, 10)) im_plot(U.T, shape=(28, 28), title='First 100 (high-level) filters extracted by RBM #2', imshow_params={'cmap': plt.cm.gray}) plt.savefig('dbm_mnist_gauss_rbm2.png', dpi=196, bbox_inches='tight') dbm = DBM.load_model('../models/dbm_gauss_mnist/') dbm.load_rbms([rbm1, rbm2]) # !!! W1_joint = dbm.get_tf_params(scope='weights')['W']