def check_dataset(name):
    """Check that every stream of dataset `name` yields batches of width x_dim.

    The streams are requested with a batch size of 10. One batch is pulled
    from each of the train, valid and test streams, flattened to 10 rows,
    and its shape is compared against the `x_dim` reported by
    `datasets.get_streams`.

    Raises SkipTest when `datasets.get_streams` raises IOError (presumably
    because the dataset files are not available locally -- confirm).
    """
    n_rows = 10
    try:
        x_dim, train, valid, test = datasets.get_streams(
            name, batch_size=n_rows, small_batch_size=n_rows)
    except IOError:
        raise SkipTest

    for stream in (train, valid, test):
        # Each batch must be a 1-tuple holding only the features.
        batch = next(stream.get_epoch_iterator())
        features, = batch

        flat = features.reshape([n_rows, -1])
        assert flat.shape == (n_rows, x_dim)
Example #2
0
def main(args):
    """Build a model, set up training and monitoring, and run the main loop.

    `args` is an attribute container (presumably the parsed command line --
    confirm against the caller). Attributes read here:

      data, batch_size        -- passed to datasets.get_streams
      method                  -- 'vae', 'rws', 'bihm' or 'continue'
      layer_spec, name, n_samples, learning_rate
      activation, batch_normalization        (method 'vae')
      deterministic_layers                   (methods 'rws' and 'bihm')
      l1reg, l2reg                           (method 'bihm')
      model_file                             (method 'continue')
      step_rule               -- 'momentum', 'rmsprop' or 'adam'
      live_plotting, patience, max_epochs

    Raises:
        ValueError: if `args.method`, `args.activation` (for 'vae') or
            `args.step_rule` is not one of the recognised values.
    """
    lr_tag = float_tag(args.learning_rate)

    x_dim, train_stream, valid_stream, test_stream = datasets.get_streams(args.data, args.batch_size)

    #------------------------------------------------------------
    # Setup model
    deterministic_act = Tanh
    deterministic_size = 1.

    if args.method == 'vae':
        # The last entry of the layer spec is the latent size; the entries
        # before it are the hidden layer sizes.
        sizes_tag = args.layer_spec.replace(",", "-")
        layer_sizes = [int(i) for i in args.layer_spec.split(",")]
        layer_sizes, z_dim = layer_sizes[:-1], layer_sizes[-1]

        name = "%s-%s-%s-lr%s-spl%d-%s" % \
            (args.data, args.method, args.name, lr_tag, args.n_samples, sizes_tag)

        if args.activation == "tanh":
            hidden_act = Tanh()
        elif args.activation == "logistic":
            hidden_act = Logistic()
        elif args.activation == "relu":
            hidden_act = Rectifier()
        else:
            # Raise a real exception (a bare string cannot be raised) and
            # report the attribute that was actually tested.
            raise ValueError("Unknown hidden nonlinearity %s" % args.activation)

        model = VAE(x_dim=x_dim, hidden_layers=layer_sizes, hidden_act=hidden_act, z_dim=z_dim,
                    batch_norm=args.batch_normalization)
        model.initialize()
    elif args.method == 'rws':
        sizes_tag = args.layer_spec.replace(",", "-")
        name = "%s-%s-%s-lr%s-dl%d-spl%d-%s" % \
            (args.data, args.method, args.name, lr_tag, args.deterministic_layers, args.n_samples, sizes_tag)

        p_layers, q_layers = create_layers(
                                args.layer_spec, x_dim,
                                args.deterministic_layers,
                                deterministic_act, deterministic_size)

        model = ReweightedWakeSleep(
                p_layers,
                q_layers,
            )
        model.initialize()
    elif args.method == 'bihm':
        sizes_tag = args.layer_spec.replace(",", "-")
        name = "%s-%s-%s-lr%s-dl%d-spl%d-%s" % \
            (args.data, args.method, args.name, lr_tag, args.deterministic_layers, args.n_samples, sizes_tag)

        p_layers, q_layers = create_layers(
                                args.layer_spec, x_dim,
                                args.deterministic_layers,
                                deterministic_act, deterministic_size)

        model = BiHM(
                p_layers,
                q_layers,
                l1reg=args.l1reg,
                l2reg=args.l2reg,
            )
        model.initialize()
    elif args.method == 'continue':
        import cPickle as pickle
        from os.path import basename

        # NOTE: unpickling executes code embedded in the file; only load
        # checkpoints that come from a trusted source.
        with open(args.model_file, 'rb') as f:
            m = pickle.load(f)

        # The file may hold either a whole MainLoop or just its model.
        if isinstance(m, MainLoop):
            m = m.model

        # Start from the first top brick and climb to its root parent.
        model = m.get_top_bricks()[0]
        while len(model.parents) > 0:
            model = model.parents[0]

        assert isinstance(model, (BiHM, ReweightedWakeSleep, VAE))

        # Keep whatever precedes "_model.pkl" in the file name; if that
        # suffix is absent, rpartition leaves mname empty.
        mname, _, _ = basename(args.model_file).rpartition("_model.pkl")
        name = "%s-cont-%s-lr%s-spl%s" % (mname, args.name, lr_tag, args.n_samples)
    else:
        raise ValueError("Unknown training method '%s'" % args.method)

    #------------------------------------------------------------

    x = tensor.matrix('features')

    #------------------------------------------------------------
    # Testset monitoring

    train_monitors = []
    valid_monitors = []
    test_monitors = []
    for s in [1, 10, 100, 1000,]:
        # Negated mean log-likelihood estimates, one pair per sample count s.
        log_p, log_ph = model.log_likelihood(x, s)
        log_p = -log_p.mean()
        log_ph = -log_ph.mean()
        log_p.name = "log_p_%d" % s
        log_ph.name = "log_ph_%d" % s

        #valid_monitors += [log_p, log_ph]
        test_monitors += [log_p, log_ph]

    #------------------------------------------------------------
    # Z estimation
    #for s in [100000]:
    #    z2 = tensor.exp(model.estimate_log_z2(s)) / s
    #    z2.name = "z2_%d" % s
    #
    #    valid_monitors += [z2]
    #    test_monitors += [z2]


    #------------------------------------------------------------
    # Gradient and training monitoring

    if args.method in ['vae', 'dvae']:
        # No explicit gradients are supplied for the bound; GradientDescent
        # receives gradients=None together with the cost.
        log_p_bound = model.log_likelihood_bound(x, args.n_samples)
        gradients = None
        log_p_bound = -log_p_bound.mean()
        log_p_bound.name = "log_p_bound"
        cost = log_p_bound

        train_monitors += [log_p_bound, named(model.kl_term.mean(), 'kl_term'), named(model.recons_term.mean(), 'recons_term')]
        valid_monitors += [log_p_bound, named(model.kl_term.mean(), 'kl_term'), named(model.recons_term.mean(), 'recons_term')]
        test_monitors  += [log_p_bound, named(model.kl_term.mean(), 'kl_term'), named(model.recons_term.mean(), 'recons_term')]
    else:
        # NOTE(review): a model restored via 'continue' also lands here,
        # whatever its class -- confirm every supported model provides
        # get_gradients() and log_p_bound.
        log_p, log_ph, gradients = model.get_gradients(x, args.n_samples)
        log_p_bound = named( -model.log_p_bound.mean(), "log_p_bound")
        log_p  = named( -log_p.mean(), "log_p")
        log_ph = named( -log_ph.mean(), "log_ph")
        cost = log_p

        train_monitors += [log_p_bound, log_p, log_ph]
        valid_monitors += [log_p_bound, log_p, log_ph]


    #------------------------------------------------------------
    cg = ComputationGraph([cost])

    if args.step_rule == "momentum":
        step_rule = Momentum(args.learning_rate, 0.95)
    elif args.step_rule == "rmsprop":
        step_rule = RMSProp(args.learning_rate)
    elif args.step_rule == "adam":
        step_rule = Adam(args.learning_rate)
    else:
        raise ValueError("Unknown step_rule %s" % args.step_rule)

    parameters = cg.parameters

    algorithm = GradientDescent(
        cost=cost,
        parameters=parameters,
        gradients=gradients,
        step_rule=CompositeRule([
            step_rule,
        ])
    )

    #------------------------------------------------------------

    train_monitors += [aggregation.mean(algorithm.total_gradient_norm),
                       aggregation.mean(algorithm.total_step_norm)]

    #------------------------------------------------------------

    # Live plotting?
    plotting_extensions = []
    if args.live_plotting:
        plotting_extensions = [
            PlotManager(
                name,
                [Plotter(channels=[
                        ["valid_%s" % cost.name, "valid_log_p"],
                        ["train_total_gradient_norm", "train_total_step_norm"]],
                    titles=[
                        "validation cost",
                        "norm of training gradient and step"
                    ]),
                DisplayImage([
                    WeightDisplay(
                        model.p_layers[0].mlp.linear_transformations[0].W,
                        n_weights=100, image_shape=(28, 28))]
                    #ImageDataStreamDisplay(test_stream, image_shape=(28,28))]
                )]
            )
        ]

    # Training monitors are recorded after every batch; the test monitors
    # are configured for every 10th epoch and after training ends.
    main_loop = MainLoop(
        model=Model(cost),
        data_stream=train_stream,
        algorithm=algorithm,
        extensions=[Timing(),
                    ProgressBar(),
                    TrainingDataMonitoring(
                        train_monitors,
                        prefix="train",
                        after_epoch=False,
                        after_batch=True),
                    DataStreamMonitoring(
                        valid_monitors,
                        data_stream=valid_stream,
                        prefix="valid"),
                    DataStreamMonitoring(
                        test_monitors,
                        data_stream=test_stream,
                        prefix="test",
                        after_epoch=False,
                        after_training=True,
                        every_n_epochs=10),
                    TrackTheBest('valid_%s' % cost.name),
                    Checkpoint(name+".pkl", save_separately=['log', 'model']),
                    FinishIfNoImprovementAfter('valid_%s_best_so_far' % cost.name, epochs=args.patience),
                    FinishAfter(after_n_epochs=args.max_epochs),
                    Printing()] + plotting_extensions)
    main_loop.run()
Example #3
0
    layer_kl = [
        tensor.sum(lq - lp, axis=1) / n_samples
        for lp, lq in zip(log_p[:], log_q[:])
    ]

    do_kl = theano.function([x, n_samples], [log_px, total_kl] + layer_kl,
                            name="do_kl",
                            allow_input_downcast=True)

    #----------------------------------------------------------------------
    logger.info("Loading dataset...")

    n_samples = args.nsamples
    batch_size = max(1, 10000 // args.nsamples)

    x_dim, stream_train, stream_valid, stream_test = datasets.get_streams(
        args.data, batch_size)
    stream = stream_test

    log_px = np.array([])
    total_kl = np.array([])
    layer_kl = [np.array([]) for _ in xrange(n_layers)]
    for batch in stream.get_epoch_iterator():
        features = batch[0]
        ret = do_kl(features, n_samples)
        log_px_batch, total_kl_batch = ret[:2]
        layer_kl_batch = ret[2:]

        log_px = np.concatenate([log_px, log_px_batch])
        total_kl = np.concatenate([total_kl, total_kl_batch])
        for l, kl_batch in enumerate(layer_kl_batch):
            layer_kl[l] = np.concatenate([layer_kl[l], kl_batch])
Example #4
0
def main(args):
    """Build a model, set up training and monitoring, and run the main loop.

    `args` is an attribute container (presumably the parsed command line --
    confirm against the caller). Attributes read here:

      data, batch_size        -- passed to datasets.get_streams
      method                  -- 'vae', 'dvae', 'rws', 'bihm-rws' or
                                 'continue'
      layer_spec, name, n_samples, learning_rate
      activation, batch_normalization        (methods 'vae' and 'dvae')
      deterministic_layers                   (methods 'rws' and 'bihm-rws')
      no_qbaseline                           (method 'rws')
      l1reg, l2reg                           (method 'bihm-rws')
      model_file                             (method 'continue')
      step_rule               -- 'momentum', 'rmsprop' or 'adam'
      live_plotting, patience, max_epochs

    Raises:
        ValueError: if `args.method`, `args.activation` (for 'vae'/'dvae')
            or `args.step_rule` is not one of the recognised values.
    """
    lr_tag = float_tag(args.learning_rate)

    x_dim, train_stream, valid_stream, test_stream = datasets.get_streams(
        args.data, args.batch_size)

    #------------------------------------------------------------
    # Setup model
    deterministic_act = Tanh
    deterministic_size = 1.

    if args.method in ('vae', 'dvae'):
        # Both variants are configured identically; only the model class
        # differs. The last entry of the layer spec is the latent size, the
        # entries before it are the hidden layer sizes.
        sizes_tag = args.layer_spec.replace(",", "-")
        layer_sizes = [int(i) for i in args.layer_spec.split(",")]
        layer_sizes, z_dim = layer_sizes[:-1], layer_sizes[-1]

        name = "%s-%s-%s-lr%s-spl%d-%s" % \
            (args.data, args.method, args.name, lr_tag, args.n_samples, sizes_tag)

        if args.activation == "tanh":
            hidden_act = Tanh()
        elif args.activation == "logistic":
            hidden_act = Logistic()
        elif args.activation == "relu":
            hidden_act = Rectifier()
        else:
            # Raise a real exception (a bare string cannot be raised) and
            # report the attribute that was actually tested.
            raise ValueError("Unknown hidden nonlinearity %s" % args.activation)

        model_class = VAE if args.method == 'vae' else DVAE
        model = model_class(x_dim=x_dim,
                            hidden_layers=layer_sizes,
                            hidden_act=hidden_act,
                            z_dim=z_dim,
                            batch_norm=args.batch_normalization)
        model.initialize()
    elif args.method == 'rws':
        sizes_tag = args.layer_spec.replace(",", "-")
        # Tag the run name when the q-baseline is switched off.
        qbase = "" if not args.no_qbaseline else "noqb-"

        name = "%s-%s-%s-%slr%s-dl%d-spl%d-%s" % \
            (args.data, args.method, args.name, qbase, lr_tag, args.deterministic_layers, args.n_samples, sizes_tag)

        p_layers, q_layers = create_layers(args.layer_spec, x_dim,
                                           args.deterministic_layers,
                                           deterministic_act,
                                           deterministic_size)

        model = ReweightedWakeSleep(
            p_layers,
            q_layers,
            qbaseline=(not args.no_qbaseline),
        )
        model.initialize()
    elif args.method == 'bihm-rws':
        sizes_tag = args.layer_spec.replace(",", "-")
        name = "%s-%s-%s-lr%s-dl%d-spl%d-%s" % \
            (args.data, args.method, args.name, lr_tag, args.deterministic_layers, args.n_samples, sizes_tag)

        p_layers, q_layers = create_layers(args.layer_spec, x_dim,
                                           args.deterministic_layers,
                                           deterministic_act,
                                           deterministic_size)

        model = BiHM(
            p_layers,
            q_layers,
            l1reg=args.l1reg,
            l2reg=args.l2reg,
        )
        model.initialize()
    elif args.method == 'continue':
        import cPickle as pickle
        from os.path import basename

        # NOTE: unpickling executes code embedded in the file; only load
        # checkpoints that come from a trusted source.
        with open(args.model_file, 'rb') as f:
            m = pickle.load(f)

        # The file may hold either a whole MainLoop or just its model.
        if isinstance(m, MainLoop):
            m = m.model

        # Start from the first top brick and climb to its root parent.
        model = m.get_top_bricks()[0]
        while len(model.parents) > 0:
            model = model.parents[0]

        assert isinstance(model, (BiHM, ReweightedWakeSleep, VAE))

        # Keep whatever precedes "_model.pkl" in the file name; if that
        # suffix is absent, rpartition leaves mname empty.
        mname, _, _ = basename(args.model_file).rpartition("_model.pkl")
        name = "%s-cont-%s-lr%s-spl%s" % (mname, args.name, lr_tag,
                                          args.n_samples)
    else:
        raise ValueError("Unknown training method '%s'" % args.method)

    #------------------------------------------------------------

    x = tensor.matrix('features')

    #------------------------------------------------------------
    # Testset monitoring

    train_monitors = []
    valid_monitors = []
    test_monitors = []
    for s in [1, 10, 100, 1000]:
        # Negated mean log-likelihood estimates, one pair per sample count s.
        log_p, log_ph = model.log_likelihood(x, s)
        log_p = -log_p.mean()
        log_ph = -log_ph.mean()
        log_p.name = "log_p_%d" % s
        log_ph.name = "log_ph_%d" % s

        #train_monitors += [log_p, log_ph]
        #valid_monitors += [log_p, log_ph]
        test_monitors += [log_p, log_ph]

    #------------------------------------------------------------
    # Z estimation
    #for s in [100000]:
    #    z2 = tensor.exp(model.estimate_log_z2(s)) / s
    #    z2.name = "z2_%d" % s
    #
    #    valid_monitors += [z2]
    #    test_monitors += [z2]

    #------------------------------------------------------------
    # Gradient and training monitoring

    if args.method in ['vae', 'dvae']:
        log_p_bound, gradients = model.get_gradients(x, args.n_samples)
        log_p_bound = -log_p_bound.mean()
        log_p_bound.name = "log_p_bound"
        cost = log_p_bound

        train_monitors += [
            log_p_bound,
            named(model.kl_term.mean(), 'kl_term'),
            named(model.recons_term.mean(), 'recons_term')
        ]
        valid_monitors += [
            log_p_bound,
            named(model.kl_term.mean(), 'kl_term'),
            named(model.recons_term.mean(), 'recons_term')
        ]
        test_monitors += [
            log_p_bound,
            named(model.kl_term.mean(), 'kl_term'),
            named(model.recons_term.mean(), 'recons_term')
        ]
    else:
        # NOTE(review): a model restored via 'continue' also lands here and
        # is unpacked into three values, whereas the branch above unpacks a
        # VAE's get_gradients() into two -- confirm that continuing a VAE
        # run is actually supported.
        log_p, log_ph, gradients = model.get_gradients(x, args.n_samples)
        log_p = -log_p.mean()
        log_ph = -log_ph.mean()
        log_p.name = "log_p"
        log_ph.name = "log_ph"
        cost = log_ph

        train_monitors += [log_p, log_ph]
        valid_monitors += [log_p, log_ph]

    #------------------------------------------------------------
    # Detailed monitoring
    # (disabled: the string literal below is never executed)
    """
    n_layers = len(p_layers)

    log_px, w, log_p, log_q, samples = model.log_likelihood(x, n_samples)

    exp_samples = []
    for l in xrange(n_layers):
        e = (w.dimshuffle(0, 1, 'x')*samples[l]).sum(axis=1)
        e.name = "inference_h%d" % l
        e.tag.aggregation_scheme = aggregation.TakeLast(e)
        exp_samples.append(e)

    s1 = samples[1]
    sh1 = s1.shape
    s1_ = s1.reshape([sh1[0]*sh1[1], sh1[2]])
    s0, _ = model.p_layers[0].sample_expected(s1_)
    s0 = s0.reshape([sh1[0], sh1[1], s0.shape[1]])
    s0 = (w.dimshuffle(0, 1, 'x')*s0).sum(axis=1)
    s0.name = "inference_h0^"
    s0.tag.aggregation_scheme = aggregation.TakeLast(s0)
    exp_samples.append(s0)

    # Draw P-samples
    p_samples, _, _ = model.sample_p(100)
    #weights = model.importance_weights(samples)
    #weights = weights / weights.sum()

    for i, s in enumerate(p_samples):
        s.name = "psamples_h%d" % i
        s.tag.aggregation_scheme = aggregation.TakeLast(s)

    #
    samples = model.sample(100, oversample=100)

    for i, s in enumerate(samples):
        s.name = "samples_h%d" % i
        s.tag.aggregation_scheme = aggregation.TakeLast(s)
    """
    cg = ComputationGraph([cost])

    #------------------------------------------------------------

    if args.step_rule == "momentum":
        step_rule = Momentum(args.learning_rate, 0.95)
    elif args.step_rule == "rmsprop":
        step_rule = RMSProp(args.learning_rate)
    elif args.step_rule == "adam":
        step_rule = Adam(args.learning_rate)
    else:
        raise ValueError("Unknown step_rule %s" % args.step_rule)

    #parameters = cg.parameters[:4] + cg.parameters[5:]
    parameters = cg.parameters

    algorithm = GradientDescent(
        cost=cost,
        parameters=parameters,
        gradients=gradients,
        step_rule=CompositeRule([
            #StepClipping(25),
            step_rule,
            #RemoveNotFinite(1.0),
        ]))

    #------------------------------------------------------------

    train_monitors += [
        aggregation.mean(algorithm.total_gradient_norm),
        aggregation.mean(algorithm.total_step_norm)
    ]

    #------------------------------------------------------------

    # Live plotting?
    plotting_extensions = []
    if args.live_plotting:
        plotting_extensions = [
            PlotManager(
                name,
                [
                    Plotter(channels=[[
                        "valid_%s" % cost.name, "valid_log_p"
                    ], ["train_total_gradient_norm", "train_total_step_norm"]],
                            titles=[
                                "validation cost",
                                "norm of training gradient and step"
                            ]),
                    DisplayImage(
                        [
                            WeightDisplay(model.p_layers[0].mlp.
                                          linear_transformations[0].W,
                                          n_weights=100,
                                          image_shape=(28, 28))
                        ]
                        #ImageDataStreamDisplay(test_stream, image_shape=(28,28))]
                    )
                ])
        ]

    # Training monitors are recorded after each epoch; the test monitors
    # are configured for every 10th epoch and after training ends.
    main_loop = MainLoop(
        model=Model(cost),
        data_stream=train_stream,
        algorithm=algorithm,
        extensions=[
            Timing(),
            ProgressBar(),
            TrainingDataMonitoring(
                train_monitors, prefix="train", after_epoch=True),
            DataStreamMonitoring(
                valid_monitors, data_stream=valid_stream, prefix="valid"),
            DataStreamMonitoring(test_monitors,
                                 data_stream=test_stream,
                                 prefix="test",
                                 after_epoch=False,
                                 after_training=True,
                                 every_n_epochs=10),
            #SharedVariableModifier(
            #    algorithm.step_rule.components[0].learning_rate,
            #    half_lr_func,
            #    before_training=False,
            #    after_epoch=False,
            #    after_batch=False,
            #    every_n_epochs=half_lr),
            TrackTheBest('valid_%s' % cost.name),
            Checkpoint(name + ".pkl", save_separately=['log', 'model']),
            FinishIfNoImprovementAfter('valid_%s_best_so_far' % cost.name,
                                       epochs=args.patience),
            FinishAfter(after_n_epochs=args.max_epochs),
            Printing()
        ] + plotting_extensions)
    main_loop.run()
Example #5
0
                             allow_input_downcast=True)

    #----------------------------------------------------------------------
    logger.info("Loading dataset...")

    x_dim, _, _, data_test = datasets.get_data(args.data)

    num_examples = data_test.num_examples
    n_samples = (int(s) for s in args.nsamples.split(","))

    dict_p = {}
    dict_ps = {}

    for K in n_samples:
        batch_size = max(args.max_batch // K, 1)
        x_dim, _, _, stream = datasets.get_streams(args.data, batch_size)

        log_p = np.asarray([])
        log_ps = np.asarray([])
        for batch in stream.get_epoch_iterator(as_dict=True):
            log_p_, log_ps_ = do_nll(batch['features'], K)

            log_p = np.concatenate((log_p, log_p_))
            log_ps = np.concatenate((log_ps, log_ps_))

        log_p_ = stats.sem(log_p)
        log_p = np.mean(log_p)
        log_ps_ = stats.sem(log_ps)
        log_ps = np.mean(log_ps)

        dict_p[K] = log_p
Example #6
0
File: train.py Project: afcarl/bihm
def main(args):
    """Build a model, set up training and monitoring, and run the main loop.

    `args` is an attribute container (presumably the parsed command line --
    confirm against the caller). Attributes read here:

      data, batch_size        -- passed to datasets.get_streams
      method                  -- 'vae', 'rws', 'bihm' or 'continue'
      layer_spec, name, n_samples, learning_rate
      activation, batch_normalization        (method 'vae')
      deterministic_layers                   (methods 'rws' and 'bihm')
      l1reg, l2reg                           (method 'bihm')
      model_file                             (method 'continue')
      step_rule               -- 'momentum', 'rmsprop' or 'adam'
      live_plotting, patience, max_epochs

    Raises:
        ValueError: if `args.method`, `args.activation` (for 'vae') or
            `args.step_rule` is not one of the recognised values.
    """
    lr_tag = float_tag(args.learning_rate)

    x_dim, train_stream, valid_stream, test_stream = datasets.get_streams(
        args.data, args.batch_size)

    #------------------------------------------------------------
    # Setup model
    deterministic_act = Tanh
    deterministic_size = 1.

    if args.method == 'vae':
        # The last entry of the layer spec is the latent size; the entries
        # before it are the hidden layer sizes.
        sizes_tag = args.layer_spec.replace(",", "-")
        layer_sizes = [int(i) for i in args.layer_spec.split(",")]
        layer_sizes, z_dim = layer_sizes[:-1], layer_sizes[-1]

        name = "%s-%s-%s-lr%s-spl%d-%s" % \
            (args.data, args.method, args.name, lr_tag, args.n_samples, sizes_tag)

        if args.activation == "tanh":
            hidden_act = Tanh()
        elif args.activation == "logistic":
            hidden_act = Logistic()
        elif args.activation == "relu":
            hidden_act = Rectifier()
        else:
            # Raise a real exception (a bare string cannot be raised) and
            # report the attribute that was actually tested.
            raise ValueError("Unknown hidden nonlinearity %s" % args.activation)

        model = VAE(x_dim=x_dim,
                    hidden_layers=layer_sizes,
                    hidden_act=hidden_act,
                    z_dim=z_dim,
                    batch_norm=args.batch_normalization)
        model.initialize()
    elif args.method == 'rws':
        sizes_tag = args.layer_spec.replace(",", "-")
        name = "%s-%s-%s-lr%s-dl%d-spl%d-%s" % \
            (args.data, args.method, args.name, lr_tag, args.deterministic_layers, args.n_samples, sizes_tag)

        p_layers, q_layers = create_layers(args.layer_spec, x_dim,
                                           args.deterministic_layers,
                                           deterministic_act,
                                           deterministic_size)

        model = ReweightedWakeSleep(
            p_layers,
            q_layers,
        )
        model.initialize()
    elif args.method == 'bihm':
        sizes_tag = args.layer_spec.replace(",", "-")
        name = "%s-%s-%s-lr%s-dl%d-spl%d-%s" % \
            (args.data, args.method, args.name, lr_tag, args.deterministic_layers, args.n_samples, sizes_tag)

        p_layers, q_layers = create_layers(args.layer_spec, x_dim,
                                           args.deterministic_layers,
                                           deterministic_act,
                                           deterministic_size)

        model = BiHM(
            p_layers,
            q_layers,
            l1reg=args.l1reg,
            l2reg=args.l2reg,
        )
        model.initialize()
    elif args.method == 'continue':
        import cPickle as pickle
        from os.path import basename

        # NOTE: unpickling executes code embedded in the file; only load
        # checkpoints that come from a trusted source.
        with open(args.model_file, 'rb') as f:
            m = pickle.load(f)

        # The file may hold either a whole MainLoop or just its model.
        if isinstance(m, MainLoop):
            m = m.model

        # Start from the first top brick and climb to its root parent.
        model = m.get_top_bricks()[0]
        while len(model.parents) > 0:
            model = model.parents[0]

        assert isinstance(model, (BiHM, ReweightedWakeSleep, VAE))

        # Keep whatever precedes "_model.pkl" in the file name; if that
        # suffix is absent, rpartition leaves mname empty.
        mname, _, _ = basename(args.model_file).rpartition("_model.pkl")
        name = "%s-cont-%s-lr%s-spl%s" % (mname, args.name, lr_tag,
                                          args.n_samples)
    else:
        raise ValueError("Unknown training method '%s'" % args.method)

    #------------------------------------------------------------

    x = tensor.matrix('features')

    #------------------------------------------------------------
    # Testset monitoring

    train_monitors = []
    valid_monitors = []
    test_monitors = []
    for s in [
            1,
            10,
            100,
            1000,
    ]:
        # Negated mean log-likelihood estimates, one pair per sample count s.
        log_p, log_ph = model.log_likelihood(x, s)
        log_p = -log_p.mean()
        log_ph = -log_ph.mean()
        log_p.name = "log_p_%d" % s
        log_ph.name = "log_ph_%d" % s

        #valid_monitors += [log_p, log_ph]
        test_monitors += [log_p, log_ph]

    #------------------------------------------------------------
    # Z estimation
    #for s in [100000]:
    #    z2 = tensor.exp(model.estimate_log_z2(s)) / s
    #    z2.name = "z2_%d" % s
    #
    #    valid_monitors += [z2]
    #    test_monitors += [z2]

    #------------------------------------------------------------
    # Gradient and training monitoring

    if args.method in ['vae', 'dvae']:
        # No explicit gradients are supplied for the bound; GradientDescent
        # receives gradients=None together with the cost.
        log_p_bound = model.log_likelihood_bound(x, args.n_samples)
        gradients = None
        log_p_bound = -log_p_bound.mean()
        log_p_bound.name = "log_p_bound"
        cost = log_p_bound

        train_monitors += [
            log_p_bound,
            named(model.kl_term.mean(), 'kl_term'),
            named(model.recons_term.mean(), 'recons_term')
        ]
        valid_monitors += [
            log_p_bound,
            named(model.kl_term.mean(), 'kl_term'),
            named(model.recons_term.mean(), 'recons_term')
        ]
        test_monitors += [
            log_p_bound,
            named(model.kl_term.mean(), 'kl_term'),
            named(model.recons_term.mean(), 'recons_term')
        ]
    else:
        # NOTE(review): a model restored via 'continue' also lands here,
        # whatever its class -- confirm every supported model provides
        # get_gradients() and log_p_bound.
        log_p, log_ph, gradients = model.get_gradients(x, args.n_samples)
        log_p_bound = named(-model.log_p_bound.mean(), "log_p_bound")
        log_p = named(-log_p.mean(), "log_p")
        log_ph = named(-log_ph.mean(), "log_ph")
        cost = log_p

        train_monitors += [log_p_bound, log_p, log_ph]
        valid_monitors += [log_p_bound, log_p, log_ph]

    #------------------------------------------------------------
    cg = ComputationGraph([cost])

    if args.step_rule == "momentum":
        step_rule = Momentum(args.learning_rate, 0.95)
    elif args.step_rule == "rmsprop":
        step_rule = RMSProp(args.learning_rate)
    elif args.step_rule == "adam":
        step_rule = Adam(args.learning_rate)
    else:
        raise ValueError("Unknown step_rule %s" % args.step_rule)

    parameters = cg.parameters

    algorithm = GradientDescent(cost=cost,
                                parameters=parameters,
                                gradients=gradients,
                                step_rule=CompositeRule([
                                    step_rule,
                                ]))

    #------------------------------------------------------------

    train_monitors += [
        aggregation.mean(algorithm.total_gradient_norm),
        aggregation.mean(algorithm.total_step_norm)
    ]

    #------------------------------------------------------------

    # Live plotting?
    plotting_extensions = []
    if args.live_plotting:
        plotting_extensions = [
            PlotManager(
                name,
                [
                    Plotter(channels=[[
                        "valid_%s" % cost.name, "valid_log_p"
                    ], ["train_total_gradient_norm", "train_total_step_norm"]],
                            titles=[
                                "validation cost",
                                "norm of training gradient and step"
                            ]),
                    DisplayImage(
                        [
                            WeightDisplay(model.p_layers[0].mlp.
                                          linear_transformations[0].W,
                                          n_weights=100,
                                          image_shape=(28, 28))
                        ]
                        #ImageDataStreamDisplay(test_stream, image_shape=(28,28))]
                    )
                ])
        ]

    # Training monitors are recorded after every batch; the test monitors
    # are configured for every 10th epoch and after training ends.
    main_loop = MainLoop(
        model=Model(cost),
        data_stream=train_stream,
        algorithm=algorithm,
        extensions=[
            Timing(),
            ProgressBar(),
            TrainingDataMonitoring(train_monitors,
                                   prefix="train",
                                   after_epoch=False,
                                   after_batch=True),
            DataStreamMonitoring(
                valid_monitors, data_stream=valid_stream, prefix="valid"),
            DataStreamMonitoring(test_monitors,
                                 data_stream=test_stream,
                                 prefix="test",
                                 after_epoch=False,
                                 after_training=True,
                                 every_n_epochs=10),
            TrackTheBest('valid_%s' % cost.name),
            Checkpoint(name + ".pkl", save_separately=['log', 'model']),
            FinishIfNoImprovementAfter('valid_%s_best_so_far' % cost.name,
                                       epochs=args.patience),
            FinishAfter(after_n_epochs=args.max_epochs),
            Printing()
        ] + plotting_extensions)
    main_loop.run()