Example #1
    def _get_flow_r(self):
        self.z_T_f = self.weights
        # TODO:
        logdets_layers = []
        flow_r = lasagne.layers.InputLayer([None, self.num_params])
        # this flow is always RealNVP
        if self.coupling:
            layer_temp = CoupledDenseLayer(flow_r, 200)
            flow_r = IndexLayer(layer_temp, 0)
            logdets_layers.append(IndexLayer(layer_temp, 1))
            for c in range(self.coupling - 1):
                flow_r = PermuteLayer(flow_r, self.num_params)
                layer_temp = CoupledDenseLayer(flow_r, 200)
                flow_r = IndexLayer(layer_temp, 0)
                logdets_layers.append(IndexLayer(layer_temp, 1))

        self.flow_r = flow_r
        self.z_T_b = lasagne.layers.get_output(self.flow_r, self.z_T_f)
        # split z_T_b into the different layers:
        self.z_T_bs = []
        t = 0
        for ws in self.weight_shapes:
            self.z_T_bs.append(self.z_T_b[:, t:t + ws[0]])
            t += ws[0]
        # TODO
        # the log-determinants must be evaluated at the same input that was
        # fed to the flow (z_T_f), not at the hypernet noise
        self.logdets_z_T_b = sum(
            [get_output(ld, self.z_T_f) for ld in logdets_layers])
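Each CoupledDenseLayer above returns a (transformed output, log-determinant) pair, which the IndexLayer calls pick apart, and PermuteLayer shuffles coordinates so the next coupling step updates the half left unchanged. Below is a minimal numpy sketch of the affine coupling step these layers presumably implement (standard RealNVP); the conditioner here is a stand-in, not the library's code.

import numpy as np

def coupling_forward(x, W, b):
    d = x.shape[1] // 2
    x1, x2 = x[:, :d], x[:, d:]
    h = np.tanh(x1.dot(W) + b)       # conditioner network on the first half
    s, m = h[:, :d], h[:, d:]        # log-scale and shift for the second half
    y = np.concatenate([x1, x2 * np.exp(s) + m], axis=1)
    logdet = s.sum(axis=1)           # log|det J|: the Jacobian is triangular
    return y, logdet

rng = np.random.RandomState(0)
x = rng.randn(5, 8)
W = 0.01 * rng.randn(4, 8)           # maps the 4 kept dims to 4 scales + 4 shifts
y, logdet = coupling_forward(x, W, np.zeros(8))
print(y.shape, logdet.shape)         # (5, 8) (5,)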
Example #2
    def _get_hyper_net(self):
        # initial random noise
        ep = self.srng.normal(size=(1, self.num_params), dtype=floatX)
        logdets_layers = []
        h_net = lasagne.layers.InputLayer([1, self.num_params])

        # mean and variance of the initial noise
        layer_temp = LinearFlowLayer(h_net)
        h_net = IndexLayer(layer_temp, 0)
        logdets_layers.append(IndexLayer(layer_temp, 1))

        # split the noise: h_net1 for the conv filters, h_net2 for the WN params (DK)
        h_net = SplitLayer(h_net, self.num_params - self.num_classes, 1)
        h_net1 = IndexLayer(h_net, 0, (1, self.num_params - self.num_classes))
        # TODO: full h_net2
        h_net2 = IndexLayer(h_net, 1)

        h_net1 = lasagne.layers.ReshapeLayer(
            h_net1, (self.n_kernels, np.prod(self.kernel_shape)))
        if self.flow == 'RealNVP':
            if self.coupling:
                layer_temp = CoupledDenseLayer(h_net1, 100)
                h_net1 = IndexLayer(layer_temp, 0)
                logdets_layers.append(IndexLayer(layer_temp, 1))

                for c in range(self.coupling - 1):
                    h_net1 = PermuteLayer(h_net1, self.num_params)

                    layer_temp = CoupledDenseLayer(h_net1, 100)
                    h_net1 = IndexLayer(layer_temp, 0)
                    logdets_layers.append(IndexLayer(layer_temp, 1))
        elif self.flow == 'IAF':
            layer_temp = IAFDenseLayer(h_net1,
                                       200,
                                       1,
                                       L=self.coupling,
                                       cond_bias=False)
            h_net1 = IndexLayer(layer_temp, 0)
            logdets_layers.append(IndexLayer(layer_temp, 1))
        else:
            assert False

        self.kernel_weights = lasagne.layers.get_output(h_net1, ep)
        h_net1 = lasagne.layers.ReshapeLayer(
            h_net1, (1, self.n_kernels * np.prod(self.kernel_shape)))
        h_net = lasagne.layers.ConcatLayer([h_net1, h_net2], 1)
        self.h_net = h_net
        self.weights = lasagne.layers.get_output(h_net, ep)
        self.logdets = sum([get_output(ld, ep) for ld in logdets_layers])
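This hypernet transforms only the kernel block of the noise through the flow; the last num_classes entries (h_net2) bypass it and are concatenated back before the output. A plain-numpy sketch of that split/reshape/concat plumbing, with made-up sizes for illustration only:

import numpy as np
num_params, num_classes = 26, 10     # assumed sizes, not from the source
n_kernels, kernel_size = 4, 4        # so the kernel block has 16 entries
ep = np.random.randn(1, num_params)
h1 = ep[:, :num_params - num_classes]
h2 = ep[:, num_params - num_classes:]
h1 = h1.reshape(n_kernels, kernel_size)   # one row per kernel, as ReshapeLayer does
# ...the coupling flow would transform h1 here...
weights = np.concatenate([h1.reshape(1, -1), h2], axis=1)
print(weights.shape)                 # (1, 26): same layout as the input noise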
Example #3
    def _get_hyper_net(self):
        # initial random noise
        ep = self.srng.normal(size=(self.wd1, self.num_params), dtype=floatX)
        logdets_layers = []
        h_net = lasagne.layers.InputLayer([None, self.num_params])

        # mean and variance of the initial noise
        layer_temp = LinearFlowLayer(h_net)
        h_net = IndexLayer(layer_temp, 0)
        logdets_layers.append(IndexLayer(layer_temp, 1))

        if self.flow == 'RealNVP':
            if self.coupling:
                layer_temp = CoupledDenseLayer(h_net, 200)
                h_net = IndexLayer(layer_temp, 0)
                logdets_layers.append(IndexLayer(layer_temp, 1))
                for c in range(self.coupling - 1):
                    h_net = PermuteLayer(h_net, self.num_params)
                    layer_temp = CoupledDenseLayer(h_net, 200)
                    h_net = IndexLayer(layer_temp, 0)
                    logdets_layers.append(IndexLayer(layer_temp, 1))
        elif self.flow == 'IAF':
            layer_temp = IAFDenseLayer(h_net,
                                       200,
                                       1,
                                       L=self.coupling,
                                       cond_bias=False)
            h_net = IndexLayer(layer_temp, 0)
            logdets_layers.append(IndexLayer(layer_temp, 1))
        else:
            assert False

        self.h_net = h_net
        self.weights = lasagne.layers.get_output(h_net, ep)
        self.logdets = sum([get_output(ld, ep) for ld in logdets_layers])
Example #4
    def _get_hyper_net(self):
        # initial random noise
        self.ep = self.srng.normal(size=(self.wd1, self.num_params),
                                   dtype=floatX)
        logdets_layers = []
        h_net = lasagne.layers.InputLayer([None, self.num_params])

        # mean and variance of the initial noise
        # layer_temp = LinearFlowLayer(h_net, W=init.Normal(0.01, -7))
        layer_temp = LinearFlowLayer(h_net,
                                     b=init.Normal(0.01),
                                     W=init.Normal(1e-12, -22))
        self.mean = layer_temp.b
        self.log_var = layer_temp.W
        self.delta = .001  # default value from modules.py
        self.h_net = IndexLayer(layer_temp, 0)
        logdets_layers.append(IndexLayer(layer_temp, 1))

        self.weights = lasagne.layers.get_output(self.h_net, self.ep)
        self.logdets = sum([get_output(ld, self.ep) for ld in logdets_layers])
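Given that layer_temp.W is read back as self.log_var, initializing it near -22 makes the initial noise scale exp(-22/2), so the sampled weights start out essentially pinned at the mean b:

import numpy as np
print(np.exp(-22 / 2.0))  # ~1.7e-05: the flow starts out near-deterministic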
Example #5
    def _get_hyper_net(self):
        # initial random noise
        ep = self.srng.normal(size=(self.wd1, self.num_params), dtype=floatX)
        logdets_layers = []
        h_net = lasagne.layers.InputLayer([None, self.num_params])

        # mean and variance of the initial noise
        layer_temp = LinearFlowLayer(h_net)
        h_net = IndexLayer(layer_temp, 0)
        logdets_layers.append(IndexLayer(layer_temp, 1))

        if self.coupling:
            # add more to introduce more correlation if needed
            layer_temp = CoupledDenseLayer(h_net, 200)
            h_net = IndexLayer(layer_temp, 0)
            logdets_layers.append(IndexLayer(layer_temp, 1))

            h_net = PermuteLayer(h_net, self.num_params)

            layer_temp = CoupledDenseLayer(h_net, 200)
            h_net = IndexLayer(layer_temp, 0)
            logdets_layers.append(IndexLayer(layer_temp, 1))

        self.h_net = h_net
        self.weights = lasagne.layers.get_output(h_net, ep)
        self.logdets = sum([get_output(ld, ep) for ld in logdets_layers])
Example #6
    def _get_hyper_net(self):
        # initial random noise
        ep = self.srng.normal(size=(1, self.num_params), dtype=floatX)
        logdets_layers = []
        h_net = lasagne.layers.InputLayer([1, self.num_params])

        # mean and variance of the initial noise
        layer_temp = LinearFlowLayer(h_net)
        h_net = IndexLayer(layer_temp, 0)
        logdets_layers.append(IndexLayer(layer_temp, 1))

        h_net = SplitLayer(h_net, self.num_params - self.num_classes, 1)
        h_net1 = IndexLayer(h_net, 0, (1, self.num_params - self.num_classes))
        h_net2 = IndexLayer(h_net, 1)

        h_net1 = lasagne.layers.ReshapeLayer(
            h_net1, (self.n_kernels, np.prod(self.kernel_shape)))
        if self.coupling:
            # add more to introduce more correlation if needed
            layer_temp = CoupledDenseLayer(h_net1, 100)
            h_net1 = IndexLayer(layer_temp, 0)
            logdets_layers.append(IndexLayer(layer_temp, 1))

            h_net1 = PermuteLayer(h_net1, self.num_params)

            layer_temp = CoupledDenseLayer(h_net1, 100)
            h_net1 = IndexLayer(layer_temp, 0)
            logdets_layers.append(IndexLayer(layer_temp, 1))

        self.kernel_weights = lasagne.layers.get_output(h_net1, ep)
        h_net1 = lasagne.layers.ReshapeLayer(
            h_net1, (1, self.n_kernels * np.prod(self.kernel_shape)))
        h_net = lasagne.layers.ConcatLayer([h_net1, h_net2], 1)
        self.h_net = h_net
        self.weights = lasagne.layers.get_output(h_net, ep)
        self.logdets = sum([get_output(ld, ep) for ld in logdets_layers])
Example #7
    def _get_hyper_net(self):
        # initial random noise
        if self.noise_distribution == 'spherical_gaussian':
            self.ep = self.srng.normal(size=(self.wd1, self.num_params),
                                       dtype=floatX)
        elif self.noise_distribution == 'exponential_MoG':
            self.ep = self.srng.normal(size=(self.wd1, self.num_params),
                                       dtype=floatX)
            self.ep += 2 * self.srng.binomial(size=(self.wd1, self.num_params),
                                              dtype=floatX) - 1
        logdets_layers = []
        h_net = lasagne.layers.InputLayer([None, self.num_params])

        # mean and variance of the initial noise
        layer_temp = LinearFlowLayer(h_net)
        h_net = IndexLayer(layer_temp, 0)
        logdets_layers.append(IndexLayer(layer_temp, 1))

        if self.coupling:
            layer_temp = CoupledDenseLayer(h_net, 200)
            h_net = IndexLayer(layer_temp, 0)
            logdets_layers.append(IndexLayer(layer_temp, 1))

            for c in range(self.coupling - 1):
                h_net = PermuteLayer(h_net, self.num_params)

                layer_temp = CoupledDenseLayer(h_net, 200)
                h_net = IndexLayer(layer_temp, 0)
                logdets_layers.append(IndexLayer(layer_temp, 1))

        self.h_net = h_net
        self.logits = lasagne.layers.get_output(h_net, self.ep)
        self.drop_probs = T.nnet.sigmoid(self.logits)
        self.logdets = sum(
            [lasagne.layers.get_output(ld, self.ep) for ld in logdets_layers])
        # TODO: test this!
        # change of variables: add the log-Jacobian of the logit -> probability map
        self.logdets += T.log(T.grad(T.sum(self.drop_probs), self.logits)).sum()
        self.logqw = -self.logdets
        # TODO: we should multiply this by #units if we don't output them
        # independently...
        # Beta(alpha, beta) log-prior on the drop probabilities; the
        # - np.log(self.denom) normalizer is constant and omitted
        self.logpw = ((self.alpha - 1) * T.log(self.drop_probs).sum() +
                      (self.beta - 1) * T.log(1 - self.drop_probs).sum())
        # we'll compute the whole KL term right here...
        self.kl = (self.logqw - self.logpw).mean()
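The extra logdet term is the change-of-variables correction for pushing the logits through the sigmoid. A quick numpy check of the identity it relies on, sigmoid'(x) = sigmoid(x)(1 - sigmoid(x)):

import numpy as np
x = np.linspace(-3.0, 3.0, 7)
p = 1.0 / (1.0 + np.exp(-x))
print(np.allclose(np.log(p * (1.0 - p)), np.log(p) + np.log(1.0 - p)))  # True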
Example #8
        assert False  # TODO
        num_params = sum(np.prod(ws[1]) for ws in weight_shapes)

    if perdatapoint:
        wd1 = input_var.shape[0]
    else:
        wd1 = 1

    ###########################
    # hypernet graph
    ep = srng.normal(size=(wd1, num_params), dtype=floatX)
    logdets_layers = []
    h_layer = lasagne.layers.InputLayer([None, num_params])

    layer_temp = LinearFlowLayer(h_layer)
    h_layer = IndexLayer(layer_temp, 0)
    logdets_layers.append(IndexLayer(layer_temp, 1))

    if coupling == 'conv':
        if fix_sigma:
            assert False  # not implemented
        layer_temp = CoupledConv1DLayer(h_layer, 16, 5)
        h_layer = IndexLayer(layer_temp, 0)
        logdets_layers.append(IndexLayer(layer_temp, 1))

        h_layer = PermuteLayer(h_layer, num_params)

        layer_temp = CoupledConv1DLayer(h_layer, 16, 5)
        h_layer = IndexLayer(layer_temp, 0)
        logdets_layers.append(IndexLayer(layer_temp, 1))

    elif coupling == 'dense':
Example #9
def main():
    """
    MNIST example
    weight norm reparameterized MLP with prior on rescaling parameters
    """

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--perdatapoint', action='store_true')
    parser.add_argument('--coupling', action='store_true')
    parser.add_argument('--size', default=10000, type=int)
    parser.add_argument('--lrdecay', action='store_true')
    parser.add_argument('--lr0', default=0.1, type=float)
    parser.add_argument('--lbda', default=0.01, type=float)
    parser.add_argument('--bs', default=50, type=int)
    args = parser.parse_args()
    print args

    perdatapoint = args.perdatapoint
    coupling = 1  # NOTE: hard-coded on; the --coupling flag is ignored here
    lr0 = args.lr0
    lrdecay = args.lrdecay
    lbda = np.cast[floatX](args.lbda)
    bs = args.bs
    size = max(10, min(50000, args.size))
    clip_grad = 100
    max_norm = 100

    # load dataset
    filename = '/data/lisa/data/mnist.pkl.gz'
    train_x, train_y, valid_x, valid_y, test_x, test_y = load_mnist(filename)

    input_var = T.matrix('input_var')
    target_var = T.matrix('target_var')
    dataset_size = T.scalar('dataset_size')
    lr = T.scalar('lr')

    # 784 -> 200 -> 10
    weight_shapes = [(784, 200), (200, 10)]

    num_params = sum(ws[1] for ws in weight_shapes)
    if perdatapoint:
        wd1 = input_var.shape[0]
    else:
        wd1 = 1

    # stochastic hypernet
    ep = srng.normal(std=0.01, size=(wd1, num_params), dtype=floatX)
    logdets_layers = []
    h_layer = lasagne.layers.InputLayer([None, num_params])

    layer_temp = LinearFlowLayer(h_layer)
    h_layer = IndexLayer(layer_temp, 0)
    logdets_layers.append(IndexLayer(layer_temp, 1))

    if coupling:
        layer_temp = CoupledDenseLayer(h_layer, 200)
        h_layer = IndexLayer(layer_temp, 0)
        logdets_layers.append(IndexLayer(layer_temp, 1))

        h_layer = PermuteLayer(h_layer, num_params)

        layer_temp = CoupledDenseLayer(h_layer, 200)
        h_layer = IndexLayer(layer_temp, 0)
        logdets_layers.append(IndexLayer(layer_temp, 1))

    weights = lasagne.layers.get_output(h_layer, ep)

    # primary net
    t = np.cast['int32'](0)
    layer = lasagne.layers.InputLayer([None, 784])
    inputs = {layer: input_var}
    for ws in weight_shapes:
        num_param = ws[1]
        w_layer = lasagne.layers.InputLayer((None, ws[1]))
        weight = weights[:, t:t + num_param].reshape((wd1, ws[1]))
        inputs[w_layer] = weight
        layer = stochasticDenseLayer2([layer, w_layer], ws[1])
        print layer.output_shape
        t += num_param

    layer.nonlinearity = nonlinearities.softmax
    y = T.clip(get_output(layer, inputs), 0.001, 0.999)  # stability

    # loss terms
    logdets = sum([get_output(logdet, ep) for logdet in logdets_layers])
    logqw = -(0.5 * (ep**2).sum(1)
              + 0.5 * T.log(2 * np.pi) * num_params
              + logdets)
    #logpw = log_normal(weights,0.,-T.log(lbda)).sum(1)
    logpw = log_stdnormal(weights).sum(1)
    kl = (logqw - logpw).mean()
    logpyx = -cc(y, target_var).mean()
    loss = -(logpyx - kl / T.cast(dataset_size, floatX))

    params = lasagne.layers.get_all_params([h_layer, layer])
    grads = T.grad(loss, params)
    mgrads = lasagne.updates.total_norm_constraint(grads, max_norm=max_norm)
    cgrads = [T.clip(g, -clip_grad, clip_grad) for g in mgrads]
    updates = lasagne.updates.adam(cgrads, params, learning_rate=lr)

    train = theano.function([input_var, target_var, dataset_size, lr],
                            loss,
                            updates=updates)
    predict = theano.function([input_var], y.argmax(1))

    records = train_model(train, predict, train_x[:size], train_y[:size],
                          valid_x, valid_y, lr0, lrdecay, bs)
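If this script were saved as, say, mnist_weightnorm.py (a hypothetical filename; only the flags come from the parser above), a typical run would be:

python mnist_weightnorm.py --size 10000 --lr0 0.1 --lbda 0.01 --bs 50 --lrdecay

Note that coupling is hard-coded to 1 in this example, so passing --coupling has no effect.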
Example #10
def main():
    """
    MNIST example
    """

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--perdatapoint', action='store_true')
    parser.add_argument('--coupling', action='store_true')
    parser.add_argument('--size', default=10000, type=int)
    parser.add_argument('--lrdecay', action='store_true')
    parser.add_argument('--lr0', default=0.1, type=float)
    parser.add_argument('--lbda', default=10, type=float)
    parser.add_argument('--bs', default=50, type=int)
    args = parser.parse_args()
    print args

    perdatapoint = args.perdatapoint
    coupling = args.coupling
    size = max(10, min(50000, args.size))
    clip_grad = 10
    max_norm = 1000

    # load dataset
    filename = '/data/lisa/data/mnist.pkl.gz'
    train_x, train_y, valid_x, valid_y, test_x, test_y = load_mnist(filename)

    input_var = T.matrix('input_var')
    target_var = T.matrix('target_var')
    dataset_size = T.scalar('dataset_size')
    lr = T.scalar('lr')

    # 784 -> 20 -> 20 -> 10
    weight_shapes = [(784, 20), (20, 20), (20, 10)]

    num_params = sum(np.prod(ws) for ws in weight_shapes)
    if perdatapoint:
        wd1 = input_var.shape[0]
    else:
        wd1 = 1

    # stochastic hypernet
    ep = srng.normal(size=(wd1, num_params), dtype=floatX)
    logdets_layers = []
    h_layer = lasagne.layers.InputLayer([None, num_params])

    layer_temp = LinearFlowLayer(h_layer)
    h_layer = IndexLayer(layer_temp, 0)
    logdets_layers.append(IndexLayer(layer_temp, 1))

    if coupling:
        layer_temp = CoupledConv1DLayer(h_layer, 16, 5)
        h_layer = IndexLayer(layer_temp, 0)
        logdets_layers.append(IndexLayer(layer_temp, 1))

        h_layer = PermuteLayer(h_layer, num_params)

        layer_temp = CoupledConv1DLayer(h_layer, 16, 5)
        h_layer = IndexLayer(layer_temp, 0)
        logdets_layers.append(IndexLayer(layer_temp, 1))

    weights = lasagne.layers.get_output(h_layer, ep)

    # primary net
    t = np.cast['int32'](0)
    layer = lasagne.layers.InputLayer([None, 784])
    inputs = {layer: input_var}
    for ws in weight_shapes:
        num_param = np.prod(ws)
        print t, t + num_param
        w_layer = lasagne.layers.InputLayer((None,) + ws)
        weight = weights[:, t:t + num_param].reshape((wd1,) + ws)
        inputs[w_layer] = weight
        layer = stochasticDenseLayer([layer, w_layer], ws[1])
        t += num_param

    layer.nonlinearity = nonlinearities.softmax
    y = T.clip(get_output(layer, inputs), 0.001, 0.999)  # stability

    # loss terms
    logdets = sum([get_output(logdet, ep) for logdet in logdets_layers])
    logqw = -(0.5 * (ep**2).sum(1)
              + 0.5 * T.log(2 * np.pi) * num_params
              + logdets)
    logpw = log_stdnormal(weights).sum(1)
    kl = (logqw - logpw).mean()
    logpyx = -cc(y, target_var).mean()
    loss = -(logpyx - kl / T.cast(dataset_size, floatX))

    params = lasagne.layers.get_all_params([h_layer, layer])
    grads = T.grad(loss, params)
    mgrads = lasagne.updates.total_norm_constraint(grads, max_norm=max_norm)
    cgrads = [T.clip(g, -clip_grad, clip_grad) for g in mgrads]
    updates = lasagne.updates.nesterov_momentum(cgrads,
                                                params,
                                                learning_rate=lr)

    train = theano.function([input_var, target_var, dataset_size, lr],
                            loss,
                            updates=updates)
    predict = theano.function([input_var], y.argmax(1))

    records = train_model(train, predict, train_x[:size], train_y[:size],
                          valid_x, valid_y)

    output_probs = theano.function([input_var], y)
    MCt = np.zeros((100, 1000, 10))
    MCv = np.zeros((100, 1000, 10))
    for i in range(100):
        MCt[i] = output_probs(train_x[:1000])
        MCv[i] = output_probs(valid_x[:1000])

    tr = np.equal(MCt.mean(0).argmax(-1), train_y[:1000].argmax(-1)).mean()
    va = np.equal(MCv.mean(0).argmax(-1), valid_y[:1000].argmax(-1)).mean()
    print "train perf=", tr
    print "valid perf=", va

    for ii in range(15):
        print np.round(MCt[ii][0] * 1000)
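The loss assembled above is the negative per-datapoint ELBO, -(logpyx - kl / dataset_size), where kl = (logqw - logpw).mean() is a single-sample Monte Carlo estimate of KL(q(w) || p(w)). A self-contained numpy sketch of the same estimator on a 1-D Gaussian, where the exact KL is known:

import numpy as np
rng = np.random.RandomState(0)
mu, sig = 1.0, 0.5                   # q = N(mu, sig^2), p = N(0, 1)
w = mu + sig * rng.randn(100000)
logq = -0.5 * ((w - mu) / sig) ** 2 - np.log(sig) - 0.5 * np.log(2 * np.pi)
logp = -0.5 * w ** 2 - 0.5 * np.log(2 * np.pi)
print((logq - logp).mean())          # MC estimate of the KL
print(np.log(1 / sig) + (sig ** 2 + mu ** 2) / 2 - 0.5)  # exact KL, ~matches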
Example #11
def main():
    """
    MNIST example
    weight norm reparameterized MLP with prior on rescaling parameters
    """

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--coupling', action='store_true')
    parser.add_argument('--size', default=10000, type=int)
    parser.add_argument('--lrdecay', action='store_true')
    parser.add_argument('--lr0', default=0.1, type=float)
    parser.add_argument('--lbda', default=0.01, type=float)
    parser.add_argument('--bs', default=50, type=int)
    args = parser.parse_args()
    print args

    coupling = args.coupling
    lr0 = args.lr0
    lrdecay = args.lrdecay
    lbda = np.cast[floatX](args.lbda)
    bs = args.bs
    size = max(10, min(50000, args.size))
    clip_grad = 5
    max_norm = 10

    # load dataset
    filename = '/data/lisa/data/mnist.pkl.gz'
    train_x, train_y, valid_x, valid_y, test_x, test_y = load_mnist(filename)
    train_x = train_x.reshape(50000, 1, 28, 28)
    valid_x = valid_x.reshape(10000, 1, 28, 28)
    test_x = test_x.reshape(10000, 1, 28, 28)

    input_var = T.tensor4('input_var')
    target_var = T.matrix('target_var')
    dataset_size = T.scalar('dataset_size')
    lr = T.scalar('lr')

    # 1x28x28 -> 16x14x14 -> 16x7x7 -> 16x4x4 -> 10
    weight_shapes = [
        (16, 1, 5, 5),  # -> (None, 16, 14, 14)
        (16, 16, 5, 5),  # -> (None, 16,  7,  7)
        (16, 16, 5, 5)
    ]  # -> (None, 16,  4,  4)

    num_params = sum(np.prod(ws) for ws in weight_shapes) + 10
    wd1 = 1

    # stochastic hypernet
    ep = srng.normal(std=0.01, size=(wd1, num_params), dtype=floatX)
    logdets_layers = []
    h_layer = lasagne.layers.InputLayer([None, num_params])

    layer_temp = LinearFlowLayer(h_layer)
    h_layer = IndexLayer(layer_temp, 0)
    logdets_layers.append(IndexLayer(layer_temp, 1))

    if coupling:
        layer_temp = CoupledDenseLayer(h_layer, 200)
        h_layer = IndexLayer(layer_temp, 0)
        logdets_layers.append(IndexLayer(layer_temp, 1))

        h_layer = PermuteLayer(h_layer, num_params)

        layer_temp = CoupledDenseLayer(h_layer, 200)
        h_layer = IndexLayer(layer_temp, 0)
        logdets_layers.append(IndexLayer(layer_temp, 1))

    weights = lasagne.layers.get_output(h_layer, ep)

    # primary net
    t = np.cast['int32'](0)
    layer = lasagne.layers.InputLayer([None, 1, 28, 28])
    inputs = {layer: input_var}
    for ws in weight_shapes:
        num_param = np.prod(ws)
        weight = weights[:, t:t + num_param].reshape(ws)
        num_filters = ws[0]
        filter_size = ws[2]
        stride = 2
        pad = 'same'
        layer = stochasticConv2DLayer([layer, weight], num_filters,
                                      filter_size, stride, pad)
        print layer.output_shape
        t += num_param

    w_layer = lasagne.layers.InputLayer((None, 10))
    weight = weights[:, t:t + 10].reshape((wd1, 10))
    inputs[w_layer] = weight
    layer = stochasticDenseLayer2([layer, w_layer],
                                  10,
                                  nonlinearity=nonlinearities.softmax)

    y = T.clip(get_output(layer, inputs), 0.001, 0.999)

    # loss terms
    logdets = sum([get_output(logdet, ep) for logdet in logdets_layers])
    logqw = -(0.5 * (ep**2).sum(1)
              + 0.5 * T.log(2 * np.pi) * num_params
              + logdets)
    logpw = log_normal(weights, 0., -T.log(lbda)).sum(1)
    #logpw = log_stdnormal(weights).sum(1)
    kl = (logqw - logpw).mean()
    logpyx = -cc(y, target_var).mean()
    loss = -(logpyx - kl / T.cast(dataset_size, floatX))

    params = lasagne.layers.get_all_params([layer])[1:]  # excluding rand state
    grads = T.grad(loss, params)

    mgrads = lasagne.updates.total_norm_constraint(grads, max_norm=max_norm)
    cgrads = [T.clip(g, -clip_grad, clip_grad) for g in mgrads]
    updates = lasagne.updates.adam(cgrads, params, learning_rate=lr)

    train = theano.function([input_var, target_var, dataset_size, lr],
                            loss,
                            updates=updates)
    predict = theano.function([input_var], y.argmax(1))

    records = train_model(train, predict, train_x[:size], train_y[:size],
                          valid_x, valid_y, lr0, lrdecay, bs)

    output_probs = theano.function([input_var], y)
    MCt = np.zeros((100, 1000, 10))
    MCv = np.zeros((100, 1000, 10))
    for i in range(100):
        MCt[i] = output_probs(train_x[:1000])
        MCv[i] = output_probs(valid_x[:1000])

    tr = np.equal(MCt.mean(0).argmax(-1), train_y[:1000].argmax(-1)).mean()
    va = np.equal(MCv.mean(0).argmax(-1), valid_y[:1000].argmax(-1)).mean()
    print "train perf=", tr
    print "valid perf=", va

    for ii in range(15):
        print np.round(MCt[ii][0] * 1000)
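As a sanity check on the flow dimensionality used in this conv example, the hypernet output must cover all filter weights plus the 10 final dense units:

import numpy as np
weight_shapes = [(16, 1, 5, 5), (16, 16, 5, 5), (16, 16, 5, 5)]
print(sum(np.prod(ws) for ws in weight_shapes) + 10)  # 13210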