Example #1
def sample_vae2(args):
    """ For vae from https://github.com/hardmaru/WorldModelsExperiments.git
    """
    z_size = 32
    batch_size = args.count
    learning_rate = 0.0001
    kl_tolerance = 0.5
    model_path_name = "tf_vae"

    reset_graph()
    vae = ConvVAE(
        z_size=z_size,
        batch_size=batch_size,
        learning_rate=learning_rate,
        kl_tolerance=kl_tolerance,
        is_training=False,
        reuse=False,
        gpu_mode=False)  # CPU mode is sufficient for decoding a small batch of samples

    vae.load_json(os.path.join(model_path_name, 'vae.json'))

    z = np.random.normal(size=(args.count, z_size))
    samples = vae.decode(z)
    input_dim = samples.shape[1:]

    n = args.count
    plt.figure(figsize=(20, 4))
    plt.title('VAE samples')
    for i in range(n):
        ax = plt.subplot(2, n, i + 1)
        plt.imshow(samples[i].reshape(input_dim[0], input_dim[1],
                                      input_dim[2]))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    #plt.savefig( image_path )
    plt.show()
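A minimal way to drive the function above, sketched here under the assumption that `args` only needs a `count` attribute (the `--count` flag name is hypothetical):

# Hypothetical driver for sample_vae2; only the `count` attribute is assumed.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Decode random latent vectors with a trained ConvVAE')
    parser.add_argument('--count', type=int, default=10,
                        help='number of latent vectors to sample and decode')
    sample_vae2(parser.parse_args())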
Example #2

if __name__ == '__main__':
    arglist = parse_args()

    if not os.path.exists(arglist.series_dir):
        os.makedirs(arglist.series_dir)

    filelist = os.listdir(arglist.data_dir)
    filelist.sort()
    filelist = filelist[0:10000]

    dataset, action_dataset, oppo_action_dataset = load_raw_data_list(
        filelist, arglist)

    reset_graph()
    if arglist.use_vae:
        vae = ConvVAE(
            z_size=arglist.z_size,
            batch_size=arglist.batch_size,
            learning_rate=arglist.lr,
            kl_tolerance=arglist.kl_tolerance,
            is_training=False,
            reuse=False,
            gpu_mode=True)  # use GPU on batchsize of 1000 -> much faster

        vae.load_json(os.path.join(arglist.vae_path, 'vae.json'))

    mu_dataset = []
    logvar_dataset = []
    action_dataset_real = []
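The snippet ends right after the empty accumulator lists are created. A minimal sketch of a plausible continuation, assuming `ConvVAE.encode_mu_logvar` (referenced in Example #4) and that each entry of `dataset` is an array of frames paired with the corresponding entry of `action_dataset`:

    # Hypothetical continuation of the snippet above: encode each recorded
    # episode into per-frame latent statistics with the trained VAE.
    if arglist.use_vae:
        for obs, actions in zip(dataset, action_dataset):
            obs = obs.astype(np.float32) / 255.0      # scale frames to [0, 1]
            mu, logvar = vae.encode_mu_logvar(obs)    # per-frame latent stats
            mu_dataset.append(mu)
            logvar_dataset.append(logvar)
            action_dataset_real.append(actions)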
Example #3
def main(dirs,
         z_size=32,
         batch_size=100,
         learning_rate=0.0001,
         kl_tolerance=0.5,
         epochs=100,
         save_model=False,
         verbose=True,
         optimizer="Adam"):

    if save_model:
        model_save_path = "tf_vae"
        if not os.path.exists(model_save_path):
            os.makedirs(model_save_path)

    gen = DriveDataGenerator(dirs,
                             image_size=(64, 64),
                             batch_size=batch_size,
                             shuffle=True,
                             max_load=10000,
                             images_only=True)

    num_batches = len(gen)

    reset_graph()

    vae = ConvVAE(z_size=z_size,
                  batch_size=batch_size,
                  learning_rate=learning_rate,
                  kl_tolerance=kl_tolerance,
                  is_training=True,
                  reuse=False,
                  gpu_mode=True,
                  optimizer=optimizer)

    early = EarlyStopping(monitor='loss', min_delta=0.1, patience=5, verbose=verbose, mode='auto')
    early.set_model(vae)
    early.on_train_begin()

    best_loss = sys.maxsize

    if verbose:
        print("epoch\tstep\tloss\trecon_loss\tkl_loss")
    for epoch in range(epochs):
        for idx in range(num_batches):
            batch = gen[idx]

            obs = batch.astype(np.float32) / 255.0  # scale pixels to [0, 1]

            feed = {vae.x: obs,}

            (train_loss, r_loss, kl_loss, train_step, _) = vae.sess.run([
              vae.loss, vae.r_loss, vae.kl_loss, vae.global_step, vae.train_op
            ], feed)
            
            if train_loss < best_loss:
                best_loss = train_loss

            if save_model and (train_step + 1) % 5000 == 0:
                vae.save_json(os.path.join(model_save_path, "vae.json"))
        if verbose:
            print("{} of {}\t{}\t{:.2f}\t{:.2f}\t{:.2f}".format( epoch, epochs, (train_step+1), train_loss, r_loss, kl_loss) )
        gen.on_epoch_end()
        early.on_epoch_end(epoch, logs={"loss": train_loss})
        if vae.stop_training:
            break
    early.on_train_end()


    # Finished training; save the final model.
    if save_model:
        vae.save_json(os.path.join(model_save_path, "vae.json"))

    return best_loss
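For reference, a call to this trainer might look like the following; the directory paths are placeholders and the remaining keywords simply restate the defaults from the signature:

# Hypothetical invocation of the trainer above; the data paths are placeholders.
if __name__ == '__main__':
    best = main(["./data/run1", "./data/run2"],
                z_size=32,
                batch_size=100,
                epochs=100,
                save_model=True,
                verbose=True)
    print("best training loss:", best)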
Example #4
def sample_vae2(args):
    """ For vae from https://github.com/hardmaru/WorldModelsExperiments.git
    """
    z_size = 64  # This needs to match the size of the trained vae
    batch_size = args.count
    learning_rate = 0.0001
    kl_tolerance = 0.5
    model_path_name = "tf_vae"

    reset_graph()
    vae = ConvVAE(
        z_size=z_size,
        batch_size=batch_size,
        learning_rate=learning_rate,
        kl_tolerance=kl_tolerance,
        is_training=False,
        reuse=False,
        gpu_mode=False)  # CPU mode is sufficient for decoding a small batch of samples

    vae.load_json(os.path.join(model_path_name, 'vae.json'))

    z = np.random.normal(size=(args.count, z_size))
    samples = vae.decode(z)
    input_dim = samples.shape[1:]

    gen = DriveDataGenerator(args.dirs,
                             image_size=(64, 64),
                             batch_size=args.count,
                             shuffle=True,
                             max_load=10000,
                             images_only=True)
    orig = gen[0].astype(np.float32) / 255.0
    #mu, logvar = vae.encode_mu_logvar(orig)
    #recon = vae.decode( mu )
    recon = vae.decode(vae.encode(orig))

    n = args.count
    plt.figure(figsize=(20, 6), tight_layout=False)
    plt.title('VAE samples')
    for i in range(n):
        ax = plt.subplot(3, n, i + 1)
        plt.imshow(samples[i].reshape(input_dim[0], input_dim[1],
                                      input_dim[2]))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == 0:
            ax.set_title("Random")

    for i in range(n):
        ax = plt.subplot(3, n, n + i + 1)
        plt.imshow(orig[i].reshape(input_dim[0], input_dim[1], input_dim[2]))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == 0:
            ax.set_title("Real")

        ax = plt.subplot(3, n, (2 * n) + i + 1)
        plt.imshow(recon[i].reshape(input_dim[0], input_dim[1], input_dim[2]))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == 0:
            ax.set_title("Reconstructed")

    plt.savefig("samples_vae.png")
    plt.show()
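The commented-out lines above hint at the alternative reconstruction path: in this ConvVAE implementation, `encode` returns a sampled z (so each call yields a slightly different image), while `encode_mu_logvar` returns the latent mean and log-variance, so decoding `mu` is deterministic. A minimal sketch of the deterministic variant, reusing the `vae` and `orig` from the function above:

# Deterministic reconstruction sketch: decode the latent mean instead of a sampled z.
mu, logvar = vae.encode_mu_logvar(orig)  # same `vae` and `orig` as in sample_vae2 above
recon_det = vae.decode(mu)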