def test_autoencoder_logistic_linear_tied():
    """Check a tied-weight autoencoder with sigmoid encoder and linear decoder.

    The Theano-compiled reconstruction is compared with a NumPy reference
    computed as ``sigmoid(x . W + b_hid) . W^T + b_vis``.
    """
    inputs = np.random.randn(10, 5).astype(config.floatX)
    model = Autoencoder(5, 7, act_enc='sigmoid', act_dec='linear',
                        tied_weights=True)
    weights = model.weights.get_value()

    # Replace the biases with random values so they actually contribute.
    model.hidbias.set_value(np.random.randn(7).astype(config.floatX))
    hid_bias = model.hidbias.get_value()
    model.visbias.set_value(np.random.randn(5).astype(config.floatX))
    vis_bias = model.visbias.get_value()

    sym_inputs = tensor.matrix()

    # Tied weights: the decoder reuses the transposed encoder matrix.
    pre_activation = hid_bias + np.dot(inputs, weights)
    hidden = 1. / (1 + np.exp(-pre_activation))
    expected = np.dot(hidden, weights.T) + vis_bias

    reconstruct = theano.function([sym_inputs], model.reconstruct(sym_inputs))
    assert _allclose(reconstruct(inputs), expected)
def test_autoencoder_tanh_cos_untied():
    """Check an untied autoencoder with tanh encoder and cosine decoder.

    The Theano-compiled reconstruction is compared with a NumPy reference
    computed as ``cos(tanh(x . W + b_hid) . W_prime + b_vis)``, where
    ``W_prime`` is the model's ``w_prime`` decoder matrix.
    """
    inputs = np.random.randn(10, 5).astype(config.floatX)
    model = Autoencoder(5, 7, act_enc='tanh', act_dec='cos',
                        tied_weights=False)
    enc_weights = model.weights.get_value()
    dec_weights = model.w_prime.get_value()

    # Replace the biases with random values so they actually contribute.
    model.hidbias.set_value(np.random.randn(7).astype(config.floatX))
    hid_bias = model.hidbias.get_value()
    model.visbias.set_value(np.random.randn(5).astype(config.floatX))
    vis_bias = model.visbias.get_value()

    sym_inputs = tensor.matrix()

    hidden = np.tanh(hid_bias + np.dot(inputs, enc_weights))
    expected = np.cos(np.dot(hidden, dec_weights) + vis_bias)

    reconstruct = theano.function([sym_inputs], model.reconstruct(sym_inputs))
    assert _allclose(reconstruct(inputs), expected)
def create_ae(conf, layer, data, model=None):
    """
    Try to load a previously saved autoencoder for ``layer``.

    This was originally described as training an autoencoder according to
    the parameters in ``conf`` and saving the learned model. The code here
    only performs the loading step. NOTE(review): confirm where training and
    saving actually happen.

    Parameters
    ----------
    conf : dict-like
        Global configuration; ``savedir`` is resolved through
        ``utils.getboth(layer, conf, 'savedir')``.
    layer : dict-like
        Layer description. ``layer['autoenc_class']`` is always read;
        ``layer['name']`` is read only when ``model`` is None.
    data :
        Not used by this code.
    model : str, optional
        Name of the pickled model inside ``savedir``, with or without the
        ``.pkl`` extension. When None, no load is attempted.

    Returns
    -------
    The object returned by ``Autoencoder.load`` if ``model`` is given and
    loading succeeds; otherwise ``None`` (a warning is printed when loading
    fails).
    """
    savedir = utils.getboth(layer, conf, 'savedir')
    clsname = layer['autoenc_class']

    # Guess the filename: an explicit model name wins (``.pkl`` appended
    # when missing); otherwise fall back to the layer's name.
    if model is not None:
        fname = model if model.endswith('.pkl') else model + '.pkl'
    else:
        fname = layer['name'] + '.pkl'
    filename = os.path.join(savedir, fname)

    # Only try to load when a model name was explicitly requested.
    if model is not None:
        print('... loading layer: %s' % (clsname,))
        try:
            return Autoencoder.load(filename)
        except Exception as e:
            # Loading is best-effort. Do not index e.args: exceptions raised
            # without arguments (e.g. a bare EOFError from unpickling an
            # empty file) have empty args, and e.args[0] would raise an
            # IndexError from inside this handler, defeating the fallback.
            print('Warning: error while loading %s: %s' % (clsname, e))
            print('Switching back to training mode.')
def get_autoencoder(structure):
    """Build a tied-weight ``Autoencoder`` with a tanh encoder and sigmoid decoder.

    ``structure`` must unpack into exactly two values, ``(n_input, n_output)``.
    These become the visible (``nvis``) and hidden (``nhid``) sizes. The
    model is created with ``irange=0.001``.
    """
    # Distinct local names avoid shadowing the module-level ``config``.
    n_visible, n_hidden = structure
    return Autoencoder(
        nvis=n_visible,
        nhid=n_hidden,
        tied_weights=True,
        act_enc='tanh',
        act_dec='sigmoid',
        irange=0.001,
    )