# Network shape and per-layer parameter allocation for a deep, narrow
# tanh MLP with a log-softmax output layer.
#
# NOTE(review): this fragment reads `reduction`, `n_outputs`, `util`,
# `activation` and `initialization`, none of which is bound before this
# point in the file as seen; confirm the intended statement order.

# Flattened input size: assumes square 28x28 images downsampled by an
# integer factor `reduction` (TODO confirm). Floor division keeps this an
# int on Python 3 too; `/` would yield a float and break the shape tuples
# passed to util.shared_floatx below.
input_dim = (28 // reduction) ** 2

# input, hidden and output dims
hidden_dims = 32 * [32]
# NOTE(review): hidden_dims[0] is listed explicitly *and* as the first
# element of hidden_dims, giving len(hidden_dims) + 1 hidden layers;
# confirm this is intended.
dims = [input_dim, hidden_dims[0]] + hidden_dims + [n_outputs]

# --- allocate parameters -------------------------------------------------
# One activation per layer: tanh for every hidden layer, log-softmax last.
fs = [activation.tanh for _ in dims[1:-1]] + [activation.logsoftmax]
# layer input means (one vector per layer input)
cs = [util.shared_floatx((m,), initialization.constant(0))
      for m in dims[:-1]]
# layer input whitening matrices (identity-initialized)
Us = [util.shared_floatx((m, m), initialization.identity())
      for m in dims[:-1]]
# weight matrices (m inputs -> n outputs, orthogonal init)
Ws = [util.shared_floatx((m, n), initialization.orthogonal())
      for m, n in util.safezip(dims[:-1], dims[1:])]
# batch normalization diagonal scales
gammas = [util.shared_floatx((n,), initialization.constant(1))
          for n in dims[1:]]
# biases or betas
bs = [util.shared_floatx((n,), initialization.constant(0))
      for n in dims[1:]]

# reparametrization updates
updates = []
# theano graphs with assertions & breakpoints, to be evaluated after
# performing the updates
checks = []
parameters_by_layer = []
import theano.tensor as T

import util
import activation
import initialization
import steprules
import whitening
import mnist

# --- hyper-parameters ----------------------------------------------------
learning_rate = 1e-3
# use batch normalization in addition to PRONG (i.e. PRONG+)
batch_normalize = False

data = mnist.get_data()

# --- network definition --------------------------------------------------
n_outputs = 10
# Layer widths: input, three hidden layers, output.
dims = [784, 500, 300, 100, n_outputs]

# One parameter dict per layer, mapping an m-dim input to an n-dim output.
layers = [
    dict(f=activation.tanh,
         # input mean
         c=util.shared_floatx((m,), initialization.constant(0)),
         # input whitening matrix
         U=util.shared_floatx((m, m), initialization.identity()),
         # weight matrix
         W=util.shared_floatx((m, n), initialization.orthogonal()),
         # gammas (for batch normalization)
         g=util.shared_floatx((n,), initialization.constant(1)),
         # bias
         b=util.shared_floatx((n,), initialization.constant(0)))
    for m, n in util.safezip(dims[:-1], dims[1:])]
# The output layer uses log-softmax instead of tanh.
layers[-1]["f"] = activation.logsoftmax

# Symbolic inputs: a feature matrix and an int vector of class targets.
features, targets = T.matrix("features"), T.ivector("targets")

# Debug aid: uncommenting these also requires `import theano`, since
# `import theano.tensor as T` binds only `T`.
#theano.config.compute_test_value = "warn"
#features.tag.test_value = data["valid"]["features"][:11]
#targets.tag.test_value = data["valid"]["targets"][:11]

# reparametrization updates
reparameterization_updates = []
# theano graphs with assertions & breakpoints, to be evaluated after
# performing the updates