def main():
    """Load a previously trained VAE run and visualise it on test frames.

    Rebuilds the no-batch-norm convolutional VAE, restores its weights from the
    hard-coded ``run`` summaries directory, draws ``X_test_size`` frames from
    the Atari test recordings, shows originals next to reconstructions, and
    plots the encoder's convolutional layers.

    Relies on a module-level ``experiment`` name (not defined in this function)
    to locate the summaries directory.
    """
    # inputs -- must match the architecture the run below was trained with
    input_shape = (1, 84, 84)
    filters = 32
    kernel_size = 6
    batch_size = 1

    # log directory of an already-trained run; its name records the
    # hyperparameters it was trained with (epochs, lr, beta, ...)
    run = 'cvae_atari_pong_no_batch_norm_16_May_17_15_27_batch_size_1_beta_1_epochs_20_filters_32_kernel_size_6_loss_vae_loss_lr_0.0001_optimizer_adam'
    log_dir = './summaries/' + experiment + '/' + run + '/'

    # define model
    vae = PongEntangledConvolutionalLatentNoBatchNormVAE(
        input_shape, log_dir, filters=filters, kernel_size=kernel_size)

    # load weights
    vae.load_model()

    # extract models
    model = vae.get_model()
    decoder = vae.get_decoder()  # NOTE(review): currently unused below
    encoder = vae.get_encoder()

    # load testing data
    test_directory = './atari_agents/record/test/'
    test_generator = utils.atari_generator(test_directory,
                                           batch_size=batch_size)
    X_test_size = 100
    # Keep only the first input of each batch: next(gen)[0] is the input
    # batch, [0] its first sample. With batch_size > 1 the rest of each
    # batch would be discarded.
    X_test = np.asarray(
        [next(test_generator)[0][0] for _ in range(X_test_size)])

    # show original and reconstruction
    sampling.encode_decode_sample(X_test, model)

    # plot filters (8, 8 presumably sets the plot grid -- confirm in sampling)
    sampling.show_convolutional_layers(X_test, encoder, 8, 8)
def train_space_invaders_network_no_batch_norm(beta, epochs=10, batch_size=1,
                                               lr=1e-4, filters=32,
                                               kernel_size=6):
    """Train the no-batch-norm convolutional VAE on recorded Atari frames.

    All hyperparameters are recorded in the log directory name via
    ``utils.build_hyperparameter_string``. That directory sits under the
    module-level ``experiment`` name, which is not defined in this function.

    Args:
        beta: value passed to the VAE constructor as ``beta``.
        epochs: number of training epochs.
        batch_size: frames per batch for both training and validation.
        lr: Adam learning rate.
        filters: number of convolutional filters passed to the VAE.
        kernel_size: convolutional kernel size passed to the VAE.

    Raises:
        ValueError: if the train or test directory contains fewer images than
            ``batch_size``, which would otherwise hand Keras zero steps.
    """
    # inputs
    input_shape = (1, 84, 84)

    # define filename
    name = 'cvae_atari_space_invaders_no_batch_norm'

    # build hyperparameter dictionary (encoded into the log directory name)
    hp_dictionary = {
        'epochs': epochs,
        'batch_size': batch_size,
        'beta': beta,
        'filters': filters,
        'kernel_size': kernel_size,
        'lr': lr,
        'loss': 'vae_loss',
        'optimizer': 'adam'
    }

    # define log directory
    log_dir = './summaries/' + experiment + '/' + utils.build_hyperparameter_string(
        name, hp_dictionary) + '/'

    # make VAE
    # NOTE(review): this Space Invaders run reuses the Pong model class and the
    # generic ./atari_agents/record/ data directories -- confirm intended.
    vae = PongEntangledConvolutionalLatentNoBatchNormVAE(
        input_shape,
        log_dir,
        filters=filters,
        kernel_size=kernel_size,
        beta=beta)

    # compile VAE
    from keras import optimizers
    # `lr=` is the Keras 2 spelling; newer Keras renamed it `learning_rate`.
    optimizer = optimizers.Adam(lr=lr)
    vae.compile(optimizer=optimizer)

    # get dataset
    train_directory = './atari_agents/record/train/'
    test_directory = './atari_agents/record/test/'
    train_generator = utils.atari_generator(train_directory,
                                            batch_size=batch_size)
    test_generator = utils.atari_generator(test_directory,
                                           batch_size=batch_size)
    train_size = utils.count_images(train_directory)
    test_size = utils.count_images(test_directory)

    # Number of whole batches per epoch; any trailing partial batch is not
    # counted as a step.
    steps_per_epoch = train_size // batch_size
    validation_steps = test_size // batch_size
    if steps_per_epoch < 1:
        raise ValueError(
            'training set in %r has %d images, fewer than batch_size=%d'
            % (train_directory, train_size, batch_size))
    if validation_steps < 1:
        raise ValueError(
            'test set in %r has %d images, fewer than batch_size=%d'
            % (test_directory, test_size, batch_size))

    # print summaries
    vae.print_model_summaries()

    # fit VAE
    vae.fit_generator(train_generator,
                      epochs=epochs,
                      steps_per_epoch=steps_per_epoch,
                      validation_data=test_generator,
                      validation_steps=validation_steps)