# number of input units: (28 / reduction) squared.  Floor division keeps the
# result an int even under true division (Python 3, or
# `from __future__ import division`), where `/` would produce a float; the
# value ends up in `dims`, whose entries are passed as sizes to
# util.shared_floatx below.  For an integer `reduction` under Python 2 the
# result is unchanged.
input_dim = (28 // reduction) ** 2

# input, hidden and output dims
hidden_dims = 32 * [32]
# NOTE(review): hidden_dims[0] is listed in addition to hidden_dims itself, so
# dims contains len(hidden_dims) + 1 hidden entries -- confirm this is intended.
dims = [input_dim, hidden_dims[0]] + hidden_dims + [n_outputs]

# allocate parameters

# activations: tanh for every entry of dims[1:-1], followed by one logsoftmax
fs = [activation.tanh] * len(dims[1:-1]) + [activation.logsoftmax]

# layer input means (one vector per entry of dims[:-1])
cs = [
    util.shared_floatx((fan_in,), initialization.constant(0))
    for fan_in in dims[:-1]
]

# layer input whitening matrices (square, one per entry of dims[:-1])
Us = [
    util.shared_floatx((fan_in, fan_in), initialization.identity())
    for fan_in in dims[:-1]
]

# weight matrices, shaped by consecutive pairs of dims
Ws = [
    util.shared_floatx((fan_in, fan_out), initialization.orthogonal())
    for fan_in, fan_out in util.safezip(dims[:-1], dims[1:])
]

# batch normalization diagonal scales (one vector per entry of dims[1:])
gammas = [
    util.shared_floatx((fan_out,), initialization.constant(1))
    for fan_out in dims[1:]
]

# biases or betas (one vector per entry of dims[1:])
bs = [
    util.shared_floatx((fan_out,), initialization.constant(0))
    for fan_out in dims[1:]
]

# three separate, initially empty lists:
#   updates             -- reparametrization updates
#   checks              -- theano graphs with assertions & breakpoints, to be
#                          evaluated after performing the updates
#   parameters_by_layer -- undocumented in the original; presumably the
#                          parameters grouped per layer (confirm)
updates, checks, parameters_by_layer = [], [], []
import theano.tensor as T
import util, activation, initialization, steprules, whitening, mnist

# --- run configuration -------------------------------------------------------
learning_rate = 1e-3

# when True, batch normalization is used in addition to PRONG (i.e. PRONG+)
batch_normalize = False

# number of output units
n_outputs = 10

# dataset as returned by mnist.get_data()
data = mnist.get_data()

# layer dimensions; each consecutive pair (n_in, n_out) describes one layer
dims = [784, 500, 300, 100, n_outputs]

# one dict of parameters per layer
layers = [
    {
        # activation (replaced for the last layer below)
        "f": activation.tanh,
        # input mean
        "c": util.shared_floatx((n_in,), initialization.constant(0)),
        # input whitening matrix
        "U": util.shared_floatx((n_in, n_in), initialization.identity()),
        # weight matrix
        "W": util.shared_floatx((n_in, n_out), initialization.orthogonal()),
        # gammas (for batch normalization)
        "g": util.shared_floatx((n_out,), initialization.constant(1)),
        # bias
        "b": util.shared_floatx((n_out,), initialization.constant(0)),
    }
    for n_in, n_out in util.safezip(dims[:-1], dims[1:])
]

# the last layer uses logsoftmax instead of tanh
layers[-1]["f"] = activation.logsoftmax

# symbolic inputs: a matrix named "features" and an int vector named "targets"
features = T.matrix("features")
targets = T.ivector("targets")

# debugging aid, disabled: attach the first 11 validation examples as test
# values (NOTE(review): `theano` itself must be imported for the first line)
# theano.config.compute_test_value = "warn"
# features.tag.test_value = data["valid"]["features"][:11]
# targets.tag.test_value = data["valid"]["targets"][:11]

# list of reparametrization updates; starts out empty
reparameterization_updates = list()
# theano graphs with assertions & breakpoints, to be evaluated after
# performing the updates