Example #1
0
def main():
    """Train a softmax classifier on MNIST and record its test accuracy.

    Images are loaded as flat 784-dimensional float vectors.  The model is
    optimized with Adam under an annealed learning rate, test accuracy is
    evaluated every 5 epochs, and the metrics of the most recent evaluation
    are committed to ``results`` once the training loop has finished.
    """
    (train_x, train_y), (test_x, test_y) = load_mnist(
        shape=[784], dtype=np.float32, normalize=True)

    # ---- graph inputs ----
    x_ph = tf.placeholder(dtype=tf.float32,
                          shape=(None,) + train_x.shape[1:],
                          name='input_x')
    y_ph = tf.placeholder(dtype=tf.int32, shape=[None], name='input_y')
    training_flag = tf.placeholder(dtype=tf.bool, shape=(),
                                   name='is_training')
    lr_ph = tf.placeholder(shape=(), dtype=tf.float32)
    # host-side learning-rate value, fed into `lr_ph` and annealed by
    # `anneal_after` below
    lr_value = AnnealingDynamicValue(config.initial_lr,
                                     config.lr_anneal_factor)

    adam = tf.train.AdamOptimizer(lr_ph)

    # ---- forward pass, objective and metrics ----
    logits = model(x_ph, is_training=training_flag)
    total_loss = (softmax_classification_loss(logits, y_ph) +
                  regularization_loss())
    accuracy = classification_accuracy(
        softmax_classification_output(logits), y_ph)

    # ---- training op ----
    # the gradient step is made to depend on everything the model put into
    # UPDATE_OPS, so those updates run on every training step
    trainable = tf.trainable_variables()
    grads_and_vars = adam.compute_gradients(total_loss, var_list=trainable)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = adam.apply_gradients(grads_and_vars)

    # ---- data pipelines ----
    train_stream = DataFlow.arrays([train_x, train_y], config.batch_size,
                                   shuffle=True, skip_incomplete=True)
    test_stream = DataFlow.arrays([test_x, test_y], config.batch_size)

    with create_session().as_default():
        with TrainLoop(trainable,
                       max_epoch=config.max_epoch,
                       summary_dir=results.make_dir('train_summary'),
                       summary_graph=tf.get_default_graph(),
                       summary_commit_freqs={'loss': 10, 'acc': 10},
                       early_stopping=False) as loop:
            trainer = Trainer(
                loop, train_op, [x_ph, y_ph], train_stream,
                feed_dict={lr_ph: lr_value, training_flag: True},
                metrics={'loss': total_loss, 'acc': accuracy})
            anneal_after(trainer, lr_value,
                         epochs=config.lr_anneal_epoch_freq,
                         steps=config.lr_anneal_step_freq)
            evaluator = Evaluator(loop,
                                  metrics={'test_acc': accuracy},
                                  inputs=[x_ph, y_ph],
                                  data_flow=test_stream,
                                  feed_dict={training_flag: False},
                                  time_metric_name='test_time')
            trainer.evaluate_after_epochs(evaluator, freq=5)
            trainer.log_after_epochs(freq=1)
            trainer.run()

        # commit the metrics of the most recent evaluation
        results.commit(evaluator.last_metrics_dict)
Example #2
0
def main():
    """Train a VAE (``q_net`` -> latent ``z`` -> ``p_net``) on MNIST.

    The graph is built with ``MultiGPU`` data parallelism: every non
    pre-build device computes its own training loss (SGVB surrogate plus
    regularization), lower-bound and test NLL, and the per-device gradients
    are averaged by ``multi_gpu.average_grads`` before a single Adam update.
    Each batch of ``x`` is passed through ``sample_from_probs`` on the CPU
    before being fed (presumably a Bernoulli binarization of the normalized
    pixels -- confirm against ``sample_from_probs``).

    Every 10 epochs the test NLL / lower-bound are evaluated (and committed
    to ``results`` by an ``after_run`` hook) and a 10x10 grid of generated
    images is saved under ``plotting/``.  After training, the last
    evaluation metrics are committed and printed.
    """
    logging.basicConfig(
        level='INFO',
        format='%(asctime)s [%(levelname)s] %(name)s: %(message)s')

    # load mnist data
    (x_train, y_train), (x_test, y_test) = \
        load_mnist(shape=[config.x_dim], dtype=np.float32, normalize=True)

    # input placeholders
    # input_x is int32: it receives the output of `input_x_sampler`, not the
    # raw float pixels
    input_x = tf.placeholder(dtype=tf.int32,
                             shape=(None, ) + x_train.shape[1:],
                             name='input_x')
    is_training = tf.placeholder(dtype=tf.bool, shape=(), name='is_training')
    learning_rate = tf.placeholder(shape=(), dtype=tf.float32)
    learning_rate_var = AnnealingDynamicValue(config.initial_lr,
                                              config.lr_anneal_factor)
    multi_gpu = MultiGPU(disable_prebuild=False)

    # build the model
    # per-device results, merged after the loop
    grads = []
    losses = []
    lower_bounds = []
    test_nlls = []
    # dynamic batch size; later given to `multi_gpu.average` (presumably to
    # weight the per-device values)
    batch_size = get_batch_size(input_x)
    params = None
    optimizer = tf.train.AdamOptimizer(learning_rate)

    for dev, pre_build, [dev_input_x
                         ] in multi_gpu.data_parallel(batch_size, [input_x]):
        with tf.device(dev), multi_gpu.maybe_name_scope(dev):
            if pre_build:
                # build the networks once and discard the outputs (presumably
                # so variables exist before the device replicas reuse them)
                with arg_scope([p_net, q_net], is_training=is_training):
                    _ = q_net(dev_input_x).chain(p_net,
                                                 latent_names=['z'],
                                                 observed={'x': dev_input_x})

            else:
                with arg_scope([q_net, p_net], is_training=is_training):
                    # derive the loss and lower-bound for training
                    train_q_net = q_net(dev_input_x)
                    train_chain = train_q_net.chain(
                        p_net,
                        latent_names=['z'],
                        latent_axis=0,
                        observed={'x': dev_input_x})

                    dev_vae_loss = tf.reduce_mean(
                        train_chain.vi.training.sgvb())
                    dev_loss = dev_vae_loss + regularization_loss()
                    # NOTE(review): the reported lower bound is the negated
                    # SGVB training objective (without regularization)
                    dev_lower_bound = -dev_vae_loss
                    losses.append(dev_loss)
                    lower_bounds.append(dev_lower_bound)

                    # derive the nll and logits output for testing
                    # (config.test_n_z latent samples per input)
                    test_q_net = q_net(dev_input_x, n_z=config.test_n_z)
                    test_chain = test_q_net.chain(p_net,
                                                  latent_names=['z'],
                                                  latent_axis=0,
                                                  observed={'x': dev_input_x})
                    dev_test_nll = -tf.reduce_mean(
                        test_chain.vi.evaluation.is_loglikelihood())
                    test_nlls.append(dev_test_nll)

                    # derive the optimizer
                    # `params` is overwritten on every device; the value left
                    # after the loop is what TrainLoop receives
                    params = tf.trainable_variables()
                    grads.append(
                        optimizer.compute_gradients(dev_loss, var_list=params))

    # merge multi-gpu outputs and operations
    [loss, lower_bound, test_nll] = \
        multi_gpu.average([losses, lower_bounds, test_nlls], batch_size)
    train_op = multi_gpu.apply_grads(grads=multi_gpu.average_grads(grads),
                                     optimizer=optimizer,
                                     control_inputs=tf.get_collection(
                                         tf.GraphKeys.UPDATE_OPS))

    # derive the plotting function
    # 100 samples drawn from p_net, mapped from logits to uint8 intensities
    # and reshaped into 28x28 images
    work_dev = multi_gpu.work_devices[0]
    with tf.device(work_dev), tf.name_scope('plot_x'):
        plot_p_net = p_net(n_z=100, is_training=is_training)
        x = tf.cast(255 * tf.sigmoid(plot_p_net['x'].distribution.logits),
                    dtype=tf.uint8)
        x_plots = tf.reshape(x, [-1, 28, 28])

    # `session` is bound later by the `with ... as session` statement; this
    # closure looks it up only when called during training
    def plot_samples(loop):
        with loop.timeit('plot_time'):
            images = session.run(x_plots, feed_dict={is_training: False})
            save_images_collection(images=images,
                                   filename=results.prepare_parent(
                                       'plotting/{}.png'.format(loop.epoch)),
                                   grid_size=(10, 10))

    # prepare for training and testing data
    # Uses the explicitly captured `session` (bound later) rather than the
    # default session: TF's default session is per-thread, and `train_flow`
    # is wrapped in `.threaded(5)` below, so this may run on a worker thread.
    # `sampled_x` / `sample_input_x` are also defined below; they are looked
    # up at call time.
    def input_x_sampler(x):
        return session.run([sampled_x], feed_dict={sample_input_x: x})

    with tf.device('/device:CPU:0'):
        sample_input_x = tf.placeholder(dtype=tf.float32,
                                        shape=(None, config.x_dim),
                                        name='sample_input_x')
        sampled_x = sample_from_probs(sample_input_x)

    train_flow = DataFlow.arrays([x_train],
                                 config.batch_size,
                                 shuffle=True,
                                 skip_incomplete=True).map(input_x_sampler)
    test_flow = DataFlow.arrays([x_test], config.test_batch_size). \
        map(input_x_sampler)

    with create_session().as_default() as session, \
            train_flow.threaded(5) as train_flow:
        # fix the testing flow, reducing the testing time
        # (the test set is passed through `input_x_sampler` once here and the
        # resulting arrays are reused for every evaluation)
        test_flow = test_flow.to_arrays_flow(batch_size=config.test_batch_size)

        # train the network
        with TrainLoop(params,
                       var_groups=['p_net', 'q_net', 'posterior_flow'],
                       max_epoch=config.max_epoch,
                       summary_dir=results.make_dir('train_summary'),
                       summary_graph=tf.get_default_graph(),
                       early_stopping=False) as loop:
            trainer = Trainer(loop,
                              train_op, [input_x],
                              train_flow,
                              feed_dict={
                                  learning_rate: learning_rate_var,
                                  is_training: True
                              },
                              metrics={'loss': loss})
            anneal_after(trainer,
                         learning_rate_var,
                         epochs=config.lr_anneal_epoch_freq,
                         steps=config.lr_anneal_step_freq)
            evaluator = Evaluator(loop,
                                  metrics={
                                      'test_nll': test_nll,
                                      'test_lb': lower_bound
                                  },
                                  inputs=[input_x],
                                  data_flow=test_flow,
                                  feed_dict={is_training: False},
                                  time_metric_name='test_time')
            # commit the metrics after every evaluation run
            evaluator.after_run.add_hook(
                lambda: results.commit(evaluator.last_metrics_dict))
            trainer.evaluate_after_epochs(evaluator, freq=10)
            trainer.evaluate_after_epochs(functools.partial(
                plot_samples, loop),
                                          freq=10)
            trainer.log_after_epochs(freq=1)
            trainer.run()

    # write the final test_nll and test_lb
    results.commit_and_print(evaluator.last_metrics_dict)
Example #3
0
def main():
    """Train a VAE with a Bernoulli latent ``z`` on MNIST via REINFORCE.

    The ``VAE`` is configured with Bernoulli ``p(z)``, ``p(x|z)`` and
    ``q(z|x)``.  Training uses ``vi.training.reinforce`` with a baseline
    from ``baseline_net``, whose baseline cost is added to the loss.
    Graph construction is spread over devices with ``MultiGPU`` data
    parallelism, and the per-device gradients are averaged before a single
    Adam update.

    Every 10 epochs the test NLL / lower-bound are evaluated and a 10x10
    grid of generated images is saved under ``plotting/``.  The last
    evaluation's metrics are committed to ``results`` after training.
    """
    # load mnist data
    (x_train, y_train), (x_test, y_test) = \
        load_mnist(shape=[config.x_dim], dtype=np.float32, normalize=True)

    # input placeholders
    # input_x is int32: it receives the output of `input_x_sampler`, not the
    # raw float pixels
    input_x = tf.placeholder(dtype=tf.int32,
                             shape=(None, ) + x_train.shape[1:],
                             name='input_x')
    is_training = tf.placeholder(dtype=tf.bool, shape=(), name='is_training')
    learning_rate = tf.placeholder(shape=(), dtype=tf.float32)
    learning_rate_var = AnnealingDynamicValue(config.initial_lr,
                                              config.lr_anneal_factor)
    multi_gpu = MultiGPU(disable_prebuild=False)

    # build the model
    # p_z is Bernoulli(tf.zeros([1, z_dim])) -- presumably zero logits, i.e.
    # p = 0.5 per latent bit; confirm against the Bernoulli constructor
    vae = VAE(
        p_z=Bernoulli(tf.zeros([1, config.z_dim])),
        p_x_given_z=Bernoulli,
        q_z_given_x=Bernoulli,
        h_for_p_x=functools.partial(h_for_p_x, is_training=is_training),
        h_for_q_z=functools.partial(h_for_q_z, is_training=is_training),
    )

    # per-device results, merged after the loop
    grads = []
    losses = []
    lower_bounds = []
    test_nlls = []
    batch_size = get_batch_size(input_x)
    params = None
    optimizer = tf.train.AdamOptimizer(learning_rate)

    for dev, pre_build, [dev_input_x
                         ] in multi_gpu.data_parallel(batch_size, [input_x]):
        with tf.device(dev), multi_gpu.maybe_name_scope(dev):
            if pre_build:
                # NOTE(review): this arg_scope sets no arguments, so it looks
                # like a no-op -- confirm whether a kwarg was intended
                with arg_scope([h_for_q_z, h_for_p_x]):
                    _ = vae.chain(dev_input_x)

            else:
                # derive the loss and lower-bound for training
                train_chain = vae.chain(dev_input_x)
                dev_baseline = baseline_net(dev_input_x)
                # reinforce() returns the estimator cost and the baseline
                # cost; both are minimized together
                dev_cost, dev_baseline_cost = \
                    train_chain.vi.training.reinforce(baseline=dev_baseline)
                dev_loss = regularization_loss() + \
                    tf.reduce_mean(dev_cost + dev_baseline_cost)
                dev_lower_bound = \
                    tf.reduce_mean(train_chain.vi.lower_bound.elbo())
                losses.append(dev_loss)
                lower_bounds.append(dev_lower_bound)

                # derive the nll and logits output for testing
                # (config.test_n_z latent samples per input)
                test_chain = vae.chain(dev_input_x, n_z=config.test_n_z)
                dev_test_nll = -tf.reduce_mean(
                    test_chain.vi.evaluation.is_loglikelihood())
                test_nlls.append(dev_test_nll)

                # derive the optimizer
                # `params` is overwritten on every device; the value left
                # after the loop is what TrainLoop receives
                params = tf.trainable_variables()
                grads.append(
                    optimizer.compute_gradients(dev_loss, var_list=params))

    # merge multi-gpu outputs and operations
    [loss, lower_bound, test_nll] = \
        multi_gpu.average([losses, lower_bounds, test_nlls], batch_size)
    train_op = multi_gpu.apply_grads(grads=multi_gpu.average_grads(grads),
                                     optimizer=optimizer,
                                     control_inputs=tf.get_collection(
                                         tf.GraphKeys.UPDATE_OPS))

    # derive the plotting function
    # 100 samples from the model, mapped from logits to uint8 intensities and
    # reshaped to 28x28; built with the data layout of the work device
    work_dev = multi_gpu.work_devices[0]
    with tf.device(work_dev), tf.name_scope('plot_x'), \
            arg_scope([h_for_q_z, h_for_p_x],
                      channels_last=multi_gpu.channels_last(work_dev)):
        x_plots = tf.reshape(
            tf.cast(255 *
                    tf.sigmoid(vae.model(n_z=100)['x'].distribution.logits),
                    dtype=tf.uint8), [-1, 28, 28])

    def plot_samples(loop):
        with loop.timeit('plot_time'):
            session = get_default_session_or_error()
            images = session.run(x_plots, feed_dict={is_training: False})
            save_images_collection(images=images,
                                   filename=results.prepare_parent(
                                       'plotting/{}.png'.format(loop.epoch)),
                                   grid_size=(10, 10))

    # prepare for training and testing data
    # Relies on the default session, so it must be called on a thread where
    # the session created below is the default; here `train_flow` is not
    # wrapped in `.threaded()`.  `sampled_x` / `sample_input_x` are defined
    # below and looked up at call time.
    def input_x_sampler(x):
        sess = get_default_session_or_error()
        return sess.run([sampled_x], feed_dict={sample_input_x: x})

    with tf.device('/device:CPU:0'):
        sample_input_x = tf.placeholder(dtype=tf.float32,
                                        shape=(None, config.x_dim),
                                        name='sample_input_x')
        sampled_x = sample_from_probs(sample_input_x)

    train_flow = DataFlow.arrays([x_train],
                                 config.batch_size,
                                 shuffle=True,
                                 skip_incomplete=True).map(input_x_sampler)
    test_flow = DataFlow.arrays([x_test], config.test_batch_size). \
        map(input_x_sampler)

    with create_session().as_default():
        # fix the testing flow, reducing the testing time
        # (the test set is passed through `input_x_sampler` once here and the
        # resulting arrays are reused for every evaluation)
        test_flow = test_flow.to_arrays_flow(batch_size=config.test_batch_size)

        # train the network
        with TrainLoop(params,
                       max_epoch=config.max_epoch,
                       summary_dir=results.make_dir('train_summary'),
                       summary_graph=tf.get_default_graph(),
                       early_stopping=False) as loop:
            trainer = Trainer(loop,
                              train_op, [input_x],
                              train_flow,
                              feed_dict={
                                  learning_rate: learning_rate_var,
                                  is_training: True
                              },
                              metrics={'loss': loss})
            anneal_after(trainer,
                         learning_rate_var,
                         epochs=config.lr_anneal_epoch_freq,
                         steps=config.lr_anneal_step_freq)
            evaluator = Evaluator(loop,
                                  metrics={
                                      'test_nll': test_nll,
                                      'test_lb': lower_bound
                                  },
                                  inputs=[input_x],
                                  data_flow=test_flow,
                                  feed_dict={is_training: False},
                                  time_metric_name='test_time')
            trainer.evaluate_after_epochs(evaluator, freq=10)
            trainer.evaluate_after_epochs(functools.partial(
                plot_samples, loop),
                                          freq=10)
            trainer.log_after_epochs(freq=1)
            trainer.run()

    # write the final test_nll and test_lb
    results.commit(evaluator.last_metrics_dict)
def main():
    """Train and evaluate a softmax classifier on CIFAR-10 across devices.

    The graph is replicated with ``MultiGPU.data_parallel``.  Every non
    pre-build device computes its own loss, accuracy and gradients.  The
    per-device losses and accuracies are merged by ``multi_gpu.average``,
    which is given the dynamic batch size (presumably as weights).  The
    averaged gradients drive a single Adam update, whose learning rate is
    annealed by ``anneal_after``.

    Test accuracy is evaluated every 5 epochs.  Each evaluation's metrics
    are committed to ``results``, and the last one is committed and printed
    once training finishes.
    """
    # load cifar10 data
    (x_train, y_train), (x_test, y_test) = \
        load_cifar10(dtype=np.float32, normalize=True)
    # report the shape of the loaded training images
    print(x_train.shape)

    # input placeholders
    input_x = tf.placeholder(
        dtype=tf.float32, shape=(None,) + x_train.shape[1:], name='input_x')
    input_y = tf.placeholder(
        dtype=tf.int32, shape=[None], name='input_y')
    is_training = tf.placeholder(
        dtype=tf.bool, shape=(), name='is_training')
    learning_rate = tf.placeholder(shape=(), dtype=tf.float32)
    learning_rate_var = AnnealingDynamicValue(config.initial_lr,
                                              config.lr_anneal_factor)
    multi_gpu = MultiGPU()

    # build the model
    # per-device results, merged after the loop
    grads = []
    losses = []
    acc_list = []
    batch_size = get_batch_size(input_x)
    params = None
    optimizer = tf.train.AdamOptimizer(learning_rate)

    for dev, pre_build, [dev_input_x, dev_input_y] in multi_gpu.data_parallel(
            batch_size, [input_x, input_y]):
        with tf.device(dev), multi_gpu.maybe_name_scope(dev):
            if pre_build:
                # build the model once and discard the output (presumably so
                # variables exist before the device replicas reuse them)
                _ = model(dev_input_x, is_training, channels_last=True)

            else:
                # derive the loss, output and accuracy
                dev_logits = model(
                    dev_input_x,
                    is_training=is_training,
                    channels_last=multi_gpu.channels_last(dev)
                )
                dev_softmax_loss = \
                    softmax_classification_loss(dev_logits, dev_input_y)
                dev_loss = dev_softmax_loss + regularization_loss()
                dev_y = softmax_classification_output(dev_logits)
                dev_acc = classification_accuracy(dev_y, dev_input_y)
                losses.append(dev_loss)
                acc_list.append(dev_acc)

                # derive the optimizer
                # `params` is overwritten on every device; the value left
                # after the loop is what TrainLoop receives
                params = tf.trainable_variables()
                grads.append(
                    optimizer.compute_gradients(dev_loss, var_list=params))

    # merge multi-gpu outputs and operations
    [loss, acc] = multi_gpu.average([losses, acc_list], batch_size)
    train_op = multi_gpu.apply_grads(
        grads=multi_gpu.average_grads(grads),
        optimizer=optimizer,
        control_inputs=tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    )

    # prepare for training and testing data
    train_flow = DataFlow.arrays(
        [x_train, y_train], config.batch_size, shuffle=True,
        skip_incomplete=True
    )
    test_flow = DataFlow.arrays([x_test, y_test], config.batch_size)

    with create_session().as_default():
        # train the network
        with TrainLoop(params,
                       max_epoch=config.max_epoch,
                       summary_dir=results.make_dir('train_summary'),
                       summary_graph=tf.get_default_graph(),
                       summary_commit_freqs={'loss': 10, 'acc': 10},
                       early_stopping=False) as loop:
            trainer = Trainer(
                loop, train_op, [input_x, input_y], train_flow,
                feed_dict={learning_rate: learning_rate_var, is_training: True},
                metrics={'loss': loss, 'acc': acc}
            )
            anneal_after(
                trainer, learning_rate_var, epochs=config.lr_anneal_epoch_freq,
                steps=config.lr_anneal_step_freq
            )
            evaluator = Evaluator(
                loop,
                metrics={'test_acc': acc},
                inputs=[input_x, input_y],
                data_flow=test_flow,
                feed_dict={is_training: False},
                time_metric_name='test_time'
            )
            # commit the metrics after every evaluation run
            evaluator.after_run.add_hook(
                lambda: results.commit(evaluator.last_metrics_dict))
            trainer.evaluate_after_epochs(evaluator, freq=5)
            trainer.log_after_epochs(freq=1)
            trainer.run()

        # save test result
        results.commit_and_print(evaluator.last_metrics_dict)
def main():
    """Train a two-latent (``y``, ``z``) VAE on MNIST and score it as a clusterer.

    The training objective is chosen from ``config``:

    * ``use_concrete_distribution`` true: SGVB when ``train_n_samples`` is
      None, otherwise IWAE.
    * ``use_concrete_distribution`` false: REINFORCE with a baseline from
      ``reinforce_baseline_net`` when ``train_n_samples`` is None, otherwise
      VIMCO.

    ``tau_q`` / ``tau_p`` are fed to ``q_net`` and the ``p_net`` chain.
    They are presumably relaxation temperatures, each annealed by
    ``anneal_after``.

    Every 10 epochs the following run in order:

    1. evaluate the test NLL and commit it;
    2. save a sample grid (one row per cluster);
    3. fit a ``ClusteringClassifier`` on ``argmax q(y|x)`` over the training
       set;
    4. report its test accuracy.

    At the end, the classifier description is written to
    ``cluster_classifier.txt``, and the combined metrics are committed and
    printed.
    """
    logging.basicConfig(
        level='INFO',
        format='%(asctime)s [%(levelname)s] %(name)s: %(message)s')

    # load mnist data
    (x_train, y_train), (x_test, y_test) = \
        load_mnist(shape=[config.x_dim], dtype=np.float32, normalize=True)

    # input placeholders
    # input_x is int32: it receives the output of `input_x_sampler`, not the
    # raw float pixels
    input_x = tf.placeholder(dtype=tf.int32,
                             shape=(None, ) + x_train.shape[1:],
                             name='input_x')
    is_training = tf.placeholder(dtype=tf.bool, shape=(), name='is_training')
    learning_rate = tf.placeholder(shape=(),
                                   dtype=tf.float32,
                                   name='learning_rate')
    learning_rate_var = AnnealingDynamicValue(config.initial_lr,
                                              config.lr_anneal_factor)
    # third AnnealingDynamicValue argument is presumably the minimum value
    tau_p = tf.placeholder(shape=(), dtype=tf.float32, name='tau_p')
    tau_p_var = AnnealingDynamicValue(config.initial_tau_p,
                                      config.tau_p_anneal_factor,
                                      config.min_tau_p)
    tau_q = tf.placeholder(shape=(), dtype=tf.float32, name='tau_q')
    tau_q_var = AnnealingDynamicValue(config.initial_tau_q,
                                      config.tau_q_anneal_factor,
                                      config.min_tau_q)
    multi_gpu = MultiGPU(disable_prebuild=False)

    # build the model
    # per-device results, merged after the loop
    grads = []
    losses = []
    test_nlls = []
    y_given_x_list = []
    batch_size = get_batch_size(input_x)
    params = None
    optimizer = tf.train.AdamOptimizer(learning_rate)

    for dev, pre_build, [dev_input_x
                         ] in multi_gpu.data_parallel(batch_size, [input_x]):
        with tf.device(dev), multi_gpu.maybe_name_scope(dev):
            if pre_build:
                # build the networks once and discard the outputs (presumably
                # so variables exist before the device replicas reuse them)
                with arg_scope([q_net, p_net], is_training=is_training):
                    _ = q_net(dev_input_x).chain(p_net,
                                                 latent_names=['y', 'z'],
                                                 observed={'x': dev_input_x})

            else:
                with arg_scope([q_net, p_net], is_training=is_training):
                    # derive the loss and lower-bound for training
                    train_n_samples = (config.train_n_samples_for_concrete
                                       if config.use_concrete_distribution else
                                       config.train_n_samples)
                    train_q_net = q_net(dev_input_x,
                                        n_samples=train_n_samples,
                                        tau=tau_q)
                    train_chain = train_q_net.chain(
                        p_net,
                        latent_names=['y', 'z'],
                        latent_axis=0,
                        observed={'x': dev_input_x},
                        tau=tau_p)

                    # pick the gradient estimator: single-sample (None) vs
                    # multi-sample, relaxed vs score-function
                    if config.use_concrete_distribution:
                        if train_n_samples is None:
                            dev_vae_loss = tf.reduce_mean(
                                train_chain.vi.training.sgvb())
                        else:
                            dev_vae_loss = tf.reduce_mean(
                                train_chain.vi.training.iwae())
                    else:
                        if train_n_samples is None:
                            dev_baseline = reinforce_baseline_net(dev_input_x)
                            dev_vae_loss = tf.reduce_mean(
                                train_chain.vi.training.reinforce(
                                    baseline=dev_baseline))
                        else:
                            dev_vae_loss = tf.reduce_mean(
                                train_chain.vi.training.vimco())
                    dev_loss = dev_vae_loss + regularization_loss()
                    dev_loss = add_p_z_given_y_reg_loss(dev_loss)
                    losses.append(dev_loss)

                    # derive the nll and logits output for testing
                    test_q_net = q_net(dev_input_x,
                                       n_samples=config.test_n_samples)
                    test_chain = test_q_net.chain(p_net,
                                                  latent_names=['y', 'z'],
                                                  latent_axis=0,
                                                  observed={'x': dev_input_x})
                    dev_test_nll = -tf.reduce_mean(
                        test_chain.vi.evaluation.is_loglikelihood())
                    test_nlls.append(dev_test_nll)

                    # derive the classifier via q(y|x)
                    # (the predicted cluster is the argmax over the last axis
                    # of the q(y|x) logits)
                    dev_q_y_given_x = tf.argmax(
                        test_q_net['y'].distribution.logits, axis=-1)
                    y_given_x_list.append(dev_q_y_given_x)

                    # derive the optimizer
                    # `params` is overwritten on every device; the value left
                    # after the loop is what TrainLoop receives
                    params = tf.trainable_variables()
                    grads.append(
                        optimizer.compute_gradients(dev_loss, var_list=params))

    # merge multi-gpu outputs and operations
    [loss, test_nll] = \
        multi_gpu.average([losses, test_nlls], batch_size)
    [y_given_x] = multi_gpu.concat([y_given_x_list])

    train_op = multi_gpu.apply_grads(grads=multi_gpu.average_grads(grads),
                                     optimizer=optimizer,
                                     control_inputs=tf.get_collection(
                                         tf.GraphKeys.UPDATE_OPS))

    # derive the plotting function
    # Generates 10 samples for each cluster id in range(n_clusters).  The
    # transpose swaps the first two axes -- presumably [n_z, n_clusters, dim]
    # -> [n_clusters, n_z, dim], so each grid row shows one cluster.
    work_dev = multi_gpu.work_devices[0]
    with tf.device(work_dev), tf.name_scope('plot_x'):
        plot_p_net = p_net(
            observed={'y': tf.range(config.n_clusters, dtype=tf.int32)},
            n_z=10,
            is_training=is_training)
        x = tf.cast(255 * tf.sigmoid(plot_p_net['x'].distribution.logits),
                    dtype=tf.uint8)
        x_plots = tf.reshape(tf.transpose(x, [1, 0, 2]), [-1, 28, 28])

    # `session` is bound later by the `with ... as session` statement; this
    # closure looks it up only when called during training
    def plot_samples(loop):
        with loop.timeit('plot_time'):
            images = session.run(x_plots, feed_dict={is_training: False})
            save_images_collection(images=images,
                                   filename=results.prepare_parent(
                                       'plotting/{}.png'.format(loop.epoch)),
                                   grid_size=(config.n_clusters, 10))

    # derive the final un-supervised classifier
    # maps config.n_clusters cluster ids onto the 10 MNIST labels
    c_classifier = ClusteringClassifier(config.n_clusters, 10)
    # classifier metrics, merged with the evaluator's metrics at the end
    test_metrics = {}

    # Fits the cluster -> label mapping on the (unshuffled) training set, so
    # the collected predictions stay aligned with `y_train`.  A fresh flow is
    # built on each call and passed through `input_x_sampler` (defined below).
    def train_classifier(loop):
        df = DataFlow.arrays([x_train], batch_size=config.batch_size). \
            map(input_x_sampler)
        with loop.timeit('cls_train_time'):
            [c_pred] = collect_outputs(outputs=[y_given_x],
                                       inputs=[input_x],
                                       data_flow=df,
                                       feed_dict={is_training: False})
            c_classifier.fit(c_pred, y_train)
            print(c_classifier.describe())

    # Reads `test_flow` at call time, i.e. the arrays flow that is rebound
    # inside the session block below.
    def evaluate_classifier(loop):
        with loop.timeit('cls_test_time'):
            [c_pred] = collect_outputs(outputs=[y_given_x],
                                       inputs=[input_x],
                                       data_flow=test_flow,
                                       feed_dict={is_training: False})
            y_pred = c_classifier.predict(c_pred)
            cls_metrics = {'test_acc': accuracy_score(y_test, y_pred)}
            loop.collect_metrics(cls_metrics)
            test_metrics.update(cls_metrics)

    # prepare for training and testing data
    # Uses the explicitly captured `session` (bound later) rather than the
    # default session: TF's default session is per-thread, and `train_flow`
    # is wrapped in `.threaded(5)` below, so this may run on a worker thread.
    def input_x_sampler(x):
        return session.run([sampled_x], feed_dict={sample_input_x: x})

    with tf.device('/device:CPU:0'):
        sample_input_x = tf.placeholder(dtype=tf.float32,
                                        shape=(None, config.x_dim),
                                        name='sample_input_x')
        sampled_x = sample_from_probs(sample_input_x)

    train_flow = DataFlow.arrays([x_train],
                                 config.batch_size,
                                 shuffle=True,
                                 skip_incomplete=True).map(input_x_sampler)
    test_flow = DataFlow.arrays([x_test], config.test_batch_size). \
        map(input_x_sampler)

    with create_session().as_default() as session, \
            train_flow.threaded(5) as train_flow:
        # fix the testing flow, reducing the testing time
        # (the test set is passed through `input_x_sampler` once here and the
        # resulting arrays are reused for every evaluation)
        test_flow = test_flow.to_arrays_flow(batch_size=config.test_batch_size)

        # train the network
        with TrainLoop(params,
                       var_groups=['p_net', 'q_net', 'gaussian_mixture_prior'],
                       max_epoch=config.max_epoch,
                       summary_dir=results.make_dir('train_summary'),
                       summary_graph=tf.get_default_graph(),
                       summary_commit_freqs={'loss': 10},
                       early_stopping=False) as loop:
            trainer = Trainer(loop,
                              train_op, [input_x],
                              train_flow,
                              feed_dict={
                                  learning_rate: learning_rate_var,
                                  tau_p: tau_p_var,
                                  tau_q: tau_q_var,
                                  is_training: True
                              },
                              metrics={'loss': loss})
            anneal_after(trainer,
                         learning_rate_var,
                         epochs=config.lr_anneal_epoch_freq,
                         steps=config.lr_anneal_step_freq)
            anneal_after(trainer,
                         tau_p_var,
                         epochs=config.tau_p_anneal_epoch_freq,
                         steps=config.tau_p_anneal_step_freq)
            anneal_after(trainer,
                         tau_q_var,
                         epochs=config.tau_q_anneal_epoch_freq,
                         steps=config.tau_q_anneal_step_freq)
            evaluator = Evaluator(loop,
                                  metrics={'test_nll': test_nll},
                                  inputs=[input_x],
                                  data_flow=test_flow,
                                  feed_dict={is_training: False},
                                  time_metric_name='test_time')
            # commit the metrics after every evaluation run
            evaluator.after_run.add_hook(
                lambda: results.commit(evaluator.last_metrics_dict))
            # NOTE: train_classifier is registered before evaluate_classifier
            # with the same freq; this assumes the hooks run in registration
            # order so the classifier is fit before it is evaluated
            trainer.evaluate_after_epochs(evaluator, freq=10)
            trainer.evaluate_after_epochs(functools.partial(
                plot_samples, loop),
                                          freq=10)
            trainer.evaluate_after_epochs(functools.partial(
                train_classifier, loop),
                                          freq=10)
            trainer.evaluate_after_epochs(functools.partial(
                evaluate_classifier, loop),
                                          freq=10)

            trainer.log_after_epochs(freq=1)
            trainer.run()

    # write the final results
    # NOTE(review): this path is relative to the current working directory,
    # not placed under `results` -- confirm that is intended
    with codecs.open('cluster_classifier.txt', 'wb', 'utf-8') as f:
        f.write(c_classifier.describe())
    test_metrics.update(evaluator.last_metrics_dict)
    results.commit_and_print(test_metrics)