def run(**kwargs):
    """Load MNIST, then build, train and test a network.

    The dataset is loaded with one-hot encoding enabled and published via
    ``shared.data``. A network is set up in 'train' mode and trained, then a
    second network is set up in 'inference' mode and tested. If plotting is
    allowed, predictions for a small test batch are plotted.

    NOTE(review): a second ``run`` is defined later in this file and rebinds
    the name when the module is loaded -- confirm which one is intended.

    Args:
        **kwargs: Optional overrides for the module-level defaults.
            Recognized keys are ``allow_plots``, ``continue_training``,
            ``inference_cpu_only`` and ``use_batchnorm``. Any other key is
            reported with a warning and otherwise ignored.
    """
    # Start from the module-level defaults; recognized keywords override them.
    options = {
        'allow_plots': ALLOW_PLOTS,
        'continue_training': CONTINUE_TRAINING,
        'inference_cpu_only': INFERENCE_CPU_ONLY,
        'use_batchnorm': USE_BATCHNORM,
    }

    for key, value in kwargs.items():
        if key in options:
            options[key] = value
        else:
            # `Logger.warn` is a deprecated alias; use `warning` instead.
            logger.warning('Keyword \'%s\' is unknown.', key)

    allow_plots = options['allow_plots']
    continue_training = options['continue_training']
    inference_cpu_only = options['inference_cpu_only']
    use_batchnorm = options['use_batchnorm']

    logger.info('### Loading dataset ...')

    data = MNISTData(config.dataset_path, use_one_hot=True)

    # Important! Let the network know, which dataset to use.
    shared.data = data

    logger.info('### Loading dataset ... Done')

    logger.info('### Build, train and test network ...')

    # Network used for training.
    train_net = setup_network(allow_plots,
                              continue_training,
                              inference_cpu_only,
                              use_batchnorm,
                              mode='train')

    # Presumably the number of training iterations -- see the network's
    # `train` method.
    train_net.train(num_iter=10001)

    # A separate network instance, set up in inference mode, is used for
    # testing.
    test_net = setup_network(allow_plots,
                             continue_training,
                             inference_cpu_only,
                             use_batchnorm,
                             mode='inference')
    test_net.test()

    if allow_plots:
        # Example Test Samples
        # Presumably `sample_batch` is an (inputs, targets) pair -- verify
        # against `next_test_batch`.
        sample_batch = data.next_test_batch(8)
        predictions = test_net.run(sample_batch[0])
        shared.data.plot_samples('Example MNIST Predictions',
                                 sample_batch[0],
                                 outputs=sample_batch[1],
                                 predictions=predictions,
                                 interactive=True)

    logger.info('### Build, train and test network ... Done')
def run(**kwargs):
    """Load MNIST, then build, train and test a ``SimpleAE`` network.

    The dataset is loaded and published via ``shared.data``. One ``SimpleAE``
    instance is built in 'train' mode and trained, a second one is built in
    'inference' mode and tested. If plotting is allowed, a single test sample
    is fed through the inference network and the input is shown next to the
    network output.

    Args:
        **kwargs: Optional overrides for the module-level defaults. The only
            recognized key is ``allow_plots``. Any other key is reported with
            a warning and otherwise ignored.
    """
    allow_plots = ALLOW_PLOTS

    for key, value in kwargs.items():
        if key == 'allow_plots':
            allow_plots = value
        else:
            # `Logger.warn` is a deprecated alias; use `warning` instead.
            logger.warning('Keyword \'%s\' is unknown.', key)

    logger.info('### Loading dataset ...')

    data = MNISTData(config.dataset_path)

    # Important! Let the network know, which dataset to use.
    shared.data = data

    logger.info('### Loading dataset ... Done')

    logger.info('### Build, train and test network ...')

    # Train the network
    train_net = SimpleAE(mode='train')
    train_net.allow_plots = allow_plots
    train_net.build()
    train_net.train()

    # Test the network
    test_net = SimpleAE(mode='inference')
    test_net.allow_plots = allow_plots
    test_net.build()
    test_net.test()

    if allow_plots:
        # Feed a random test sample through the network and display the output
        # for the user.
        sample = data.next_test_batch(1)
        net_out = test_net.run(sample[0])

        fig = plt.figure()
        plt.ion()
        plt.suptitle('Sample Image')

        # Left panel: the network input; right panel: the network output.
        # Both are reshaped to the dataset's input shape and drawn with the
        # same fixed color range.
        panels = ((1, sample[0], 'Input'), (2, net_out, 'Output'))
        for position, image, title in panels:
            ax = fig.add_subplot(1, 2, position)
            ax.set_axis_off()
            ax.imshow(np.squeeze(image.reshape(data.in_shape)),
                      vmin=-1.0, vmax=1.0)
            ax.set_title(title)
        plt.show()

    logger.info('### Build, train and test network ... Done')