def main(args):
    """Train the VAE on imported data for ``int(args.epochs)`` passes.

    Unless ``args.new_model`` is truthy, weights are first restored from
    ``./vae/weights.h5``. After every pass (one ``vae.train(data)`` call) the
    weights are written back to that same file.

    Raises:
        Whatever ``vae.set_weights`` or ``import_data`` raise; a hint is
        printed before re-raising.
    """
    weights_path = './vae/weights.h5'
    fresh_start = args.new_model
    sample_count = int(args.N)
    epoch_count = int(args.epochs)

    vae = VAE()

    if not fresh_start:
        try:
            vae.set_weights(weights_path)
        except BaseException:
            print("Either set --new_model or ensure ./vae/weights.h5 exists")
            raise

    try:
        data, sample_count = import_data(sample_count)
    except BaseException:
        print('NO DATA FOUND')
        raise

    print('DATA SHAPE = {}'.format(data.shape))

    for epoch_idx in range(epoch_count):
        print('EPOCH ' + str(epoch_idx))
        vae.train(data)
        # Checkpoint after every pass so an interrupted run keeps progress.
        vae.save_weights(weights_path)
Beispiel #2
0
def train_on_drives_gen(args):
    """Train a VAE on 64x64x3 images streamed from drive recordings.

    Two ``DriveDataGenerator`` instances share the same options (images only,
    batch size 100, shuffled, at most 10000 loaded): one over ``args.dirs``
    for training and one over ``args.val`` for validation. Unless
    ``args.new_model`` is truthy, ``./vae/weights.h5`` is restored first.
    Training then runs through ``vae.train_gen`` with ``epochs=100``.
    """
    input_dim = (64, 64, 3)
    loader_opts = dict(image_size=input_dim[0:2],
                       batch_size=100,
                       shuffle=True,
                       max_load=10000,
                       images_only=True)

    train_gen = DriveDataGenerator(args.dirs, **loader_opts)
    val_gen = DriveDataGenerator(args.val, **loader_opts)
    vae = VAE(input_dim=input_dim)
    print("Train: {}".format(train_gen.count))
    print("Val  : {}".format(val_gen.count))

    if not args.new_model:
        try:
            vae.set_weights('./vae/weights.h5')
        except BaseException:
            print("Either set --new_model or ensure ./vae/weights.h5 exists")
            raise

    vae.train_gen(train_gen, val_gen, epochs=100)
def sample_vae(args):
    """Decode ``args.count`` random latent vectors and plot them in one row.

    For vae from: https://github.com/AppliedDataSciencePartners/WorldModels.git

    Latent vectors are drawn from a standard normal distribution with shape
    ``(args.count, vae.z_dim)``. They are decoded with ``vae.decoder.predict``
    and each decoded sample is shown with its axes hidden.

    Raises:
        Whatever ``vae.set_weights`` raises when ``./vae/weights.h5`` cannot
        be loaded (re-raised after printing a hint).
    """
    vae = VAE(input_dim=(120, 120, 3))

    try:
        vae.set_weights('./vae/weights.h5')
    except Exception:
        print(
            "./vae/weights.h5 does not exist - ensure you have run 02_train_vae.py first"
        )
        raise

    n = args.count
    z = np.random.normal(size=(n, vae.z_dim))
    samples = vae.decoder.predict(z)

    plt.figure(figsize=(20, 4))
    # The title belongs to the figure. plt.title() here would first create a
    # full-figure Axes that the per-sample subplots then overlap or replace.
    plt.suptitle('VAE samples')
    for i in range(n):
        # One row: the former 3-row grid only ever filled its first row.
        ax = plt.subplot(1, n, i + 1)
        # Each decoded sample already has the decoder's per-image shape,
        # so no reshape is needed.
        plt.imshow(samples[i])
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    plt.show()
def main(args):
    """Encode every rollout file with the trained VAE and save latent series.

    For each file returned by ``get_filelist(N)``:
    - the rollout is loaded from ``ROLLOUT_DIR_NAME + file`` and encoded with
      ``encode_episode``;
    - ``mu``, ``log_var``, ``action``, ``reward`` and ``done`` are written to
      ``SERIES_DIR_NAME + file``.

    The per-episode initial mu / log_var values are collected and saved to
    ``ROOT_DIR_NAME + 'initial_z.npz'``. Files that fail to load, encode or
    save are reported and skipped.

    Raises:
        Whatever ``vae.set_weights`` raises if ``./vae/weights.h5`` cannot be
        loaded (re-raised after printing a hint).
    """
    N = int(args.N)

    vae = VAE()

    try:
        vae.set_weights('./vae/weights.h5')
    except Exception as e:
        print(e)
        print(
            "./vae/weights.h5 does not exist - ensure you have run 02_train_vae.py first"
        )
        raise

    filelist, N = get_filelist(N)

    file_count = 0
    initial_mus = []
    initial_log_vars = []
    # Shape of the most recently encoded mu series. It stays None when no file
    # could be encoded, instead of failing on an unbound name at the end.
    last_mu_shape = None

    for file_name in filelist:
        try:
            rollout_data = np.load(ROLLOUT_DIR_NAME + file_name)

            mu, log_var, action, reward, done, initial_mu, initial_log_var = encode_episode(
                vae, rollout_data)

            np.savez_compressed(SERIES_DIR_NAME + file_name,
                                mu=mu,
                                log_var=log_var,
                                action=action,
                                reward=reward,
                                done=done)
        except Exception as e:
            print(e)
            print('Skipped {}...'.format(file_name))
            continue

        initial_mus.append(initial_mu)
        initial_log_vars.append(initial_log_var)
        last_mu_shape = mu.shape
        file_count += 1

        if file_count % 50 == 0:
            print('Encoded {} / {} episodes'.format(file_count, N))

    print('Encoded {} / {} episodes'.format(file_count, N))

    initial_mus = np.array(initial_mus)
    initial_log_vars = np.array(initial_log_vars)

    print('ONE MU SHAPE = {}'.format(last_mu_shape))
    print('INITIAL MU SHAPE = {}'.format(initial_mus.shape))

    np.savez_compressed(ROOT_DIR_NAME + 'initial_z.npz',
                        initial_mu=initial_mus,
                        initial_log_var=initial_log_vars)
def _load_env_episodes(env_name, file_id, validation, seq_len):
    """Load one env's observation/action arrays, keeping only full-length episodes.

    When ``validation`` is true, reads ``./data/obs_valid_<env>.npz`` and
    ``./data/action_valid_<env>.npz``. Otherwise it reads the
    ``obs_data_<env>_<file_id>.npz`` / ``action_data_<env>_<file_id>.npz``
    pair.

    Episodes whose length differs from ``seq_len`` are dropped from both
    arrays with one shared boolean mask along the first axis. This keeps
    actions aligned with observations and never flattens multi-dimensional
    arrays, unlike ``np.delete`` without ``axis``.

    Returns:
        ``(obs, actions)`` restricted to episodes of length ``seq_len``.
    """
    if validation:
        obs_path = "./data/obs_valid_" + env_name + ".npz"
        action_path = "./data/action_valid_" + env_name + ".npz"
    else:
        obs_path = "./data/obs_data_" + env_name + "_" + str(file_id) + ".npz"
        action_path = "./data/action_data_" + env_name + "_" + str(file_id) + ".npz"

    # An .npz archive stays open until closed; use it as a context manager.
    with np.load(obs_path) as archive:
        obs = archive["arr_0"]
    with np.load(action_path) as archive:
        actions = archive["arr_0"]

    keep = np.array([len(episode) == seq_len for episode in obs], dtype=bool)
    return obs[keep], actions[keep]


def generate(args):
    """Encode recorded episodes with the VAE and save RNN input/output data.

    Iterates over ``range(args.num_files)``, stopping after the first file id
    when ``args.validation`` is set. For each file id:
    - episodes are loaded for every env in ``config.train_envs`` and filtered
      to length ``args.seq_len``;
    - data accumulates across envs within that file id;
    - after each env that loaded successfully, ``vae.generate_rnn_data`` runs
      on the accumulated data and the result is saved as
      ``./data/rnn_input_<env><suffix>`` and ``./data/rnn_output_<env><suffix>``,
      where the suffix is ``_valid`` or ``_<file_id>``.

    Raises:
        FileNotFoundError: if ``./vae/final.h5`` cannot be loaded.
    """
    num_files = args.num_files
    validation = args.validation
    seq_len = args.seq_len

    vae = VAE()

    weights = "final.h5"
    weights_path = "./vae/" + weights

    try:
        vae.set_weights(weights_path)
    except Exception as err:
        print(weights_path + " does not exist")
        raise FileNotFoundError(weights_path + " does not exist") from err

    for file_id in range(num_files):
        print("Generating file {}...".format(file_id))

        obs_data = None
        action_data = None

        for env_name in config.train_envs:
            try:
                new_obs_data, new_action_data = _load_env_episodes(
                    env_name, file_id, validation, seq_len)
                if obs_data is not None:
                    # Merge along the episode axis. Both merged arrays are
                    # built before either is assigned, so a failure cannot
                    # leave observations and actions misaligned.
                    new_obs_data = np.concatenate([obs_data, new_obs_data])
                    new_action_data = np.concatenate([action_data, new_action_data])
            except Exception as e:
                print(e)
                # Nothing new was loaded for this env, so write no output
                # files for it.
                continue

            obs_data, action_data = new_obs_data, new_action_data
            print("Found {}...current data size = {} episodes".format(env_name, len(obs_data)))

            rnn_input, rnn_output = vae.generate_rnn_data(obs_data, action_data)
            suffix = "_valid" if validation else "_" + str(file_id)
            np.savez_compressed("./data/rnn_input_" + env_name + suffix, rnn_input)
            np.savez_compressed("./data/rnn_output_" + env_name + suffix, rnn_output)

        if validation:
            break
Beispiel #6
0
def main(args):
    """Train the augmenting VAE on randomly cropped (small, big) image pairs.

    The sorted listing of ``DIR_NAME`` (minus ``.DS_Store`` / ``.gitignore``)
    is handed to ``import_data`` in consecutive index ranges of ``batch``
    covering ``[S, N)``. Presumably ``import_data`` slices the file list;
    TODO confirm.

    For every image a random ``beta`` in ``[alpha, 1)`` is drawn (assuming
    ``alpha <= 1``):
    - ``dataS`` gets the image cropped with factor ``alpha * beta``;
    - ``dataB`` gets the image cropped with factor ``beta``;
    - both are resized to ``(SCREEN_SIZE_X, SCREEN_SIZE_Y)``.

    Weights are saved to ``./vae/<model_name>/<model_name>_weights.h5`` after
    every batch.

    Raises:
        Whatever ``vae.set_weights`` raises when resuming and the weights file
        cannot be loaded (re-raised after printing a hint).
    """
    new_model = args.new_model
    S = int(args.S)
    N = int(args.N)
    batch = int(args.batch)
    model_name = str(args.model_name)
    print(args.alpha)
    alpha = float(args.alpha)
    vae = VAE()

    model_dir = './vae/' + model_name
    weights_path = model_dir + '/' + model_name + '_weights.h5'

    if not new_model:
        try:
            vae.set_weights(weights_path)
        except Exception:
            # Name the path that was actually tried.
            print("Either set --new_model or ensure " + weights_path + " exists")
            raise
    elif os.path.isdir(model_dir):
        print("A model with this name already exists")
    else:
        # makedirs also creates ./vae itself when it is missing.
        os.makedirs(os.path.join(model_dir, 'log'))

    filelist = sorted(name for name in os.listdir(DIR_NAME)
                      if name not in ('.DS_Store', '.gitignore'))

    # Ceiling division, so a trailing partial batch is not dropped.
    # round() uses banker's rounding: round(2.5) == 2.
    num_batches = max(0, -(-(N - S) // batch))
    for i in range(num_batches):
        start = S + i * batch
        end = min(start + batch, N)  # never read past the requested N
        data = import_data(start, end, filelist)
        dataS = []
        dataB = []
        for d in data:
            beta = alpha + np.random.rand() * (1 - alpha)
            dataS.append(
                cv2.resize(crop(d, alpha * beta),
                           dsize=(SCREEN_SIZE_X, SCREEN_SIZE_Y),
                           interpolation=cv2.INTER_CUBIC))
            dataB.append(
                cv2.resize(crop(d, beta),
                           dsize=(SCREEN_SIZE_X, SCREEN_SIZE_Y),
                           interpolation=cv2.INTER_CUBIC))

        dataS = np.asarray(dataS)
        dataB = np.asarray(dataB)

        # Augmenting VAE (simple RNN variant 2). To train a plain VAE instead
        # (RNN variant 1), pass np.vstack([dataS, dataB]) as both inputs.
        vae.train(dataS, dataB, model_name)

        vae.save_weights(weights_path)

        print('Imported {} / {}'.format(end, N))
Beispiel #7
0
def make_model():
    """Build a ``Model`` from a VAE, an RNN and a Controller.

    The VAE loads its weights from ``./vae/weights.h5`` and the RNN from
    ``./rnn/weights.h5``. The Controller is constructed without loading any
    weights here.
    """
    vision = VAE()
    vision.set_weights('./vae/weights.h5')

    memory = RNN()
    memory.set_weights('./rnn/weights.h5')

    return Model(Controller(), vision, memory)
Beispiel #8
0
def make_model():
    """Assemble a ``Model`` (controller, VAE, RNN) with pretrained VAE/RNN weights.

    Weights are read from ``./vae/weights.h5`` and ``./rnn/weights.h5``. The
    Controller gets no weights here.
    """
    encoder = VAE()
    encoder.set_weights('./vae/weights.h5')

    dynamics = RNN()
    dynamics.set_weights('./rnn/weights.h5')

    policy = Controller()
    return Model(policy, encoder, dynamics)
def main(args):
    """Train the VAE on ``args.num_files`` files, optionally warm-starting.

    ``args.load_model`` is a weights path. The string ``"None"`` or an actual
    ``None`` means training starts from freshly constructed weights.

    Raises:
        Whatever ``vae.set_weights`` raises when the given weights cannot be
        loaded (re-raised after printing a hint).
    """
    num_files = args.num_files
    load_model = args.load_model

    vae = VAE()

    # Accept a real None as well as the CLI sentinel string "None".
    if load_model is not None and load_model != "None":
        try:
            print("Loading model {}".format(load_model))
            vae.set_weights(load_model)
        except Exception:
            print("Either don't set --load_model or ensure {} exists".format(load_model))
            raise

    vae.train(num_files)
Beispiel #10
0
def main(args):
    """Generate RNN training data for every batch number in a closed range.

    Restores ``./vae/weights.h5`` (printing a hint and re-raising on failure).
    Then it calls ``vae.generate_rnn_data`` once for each batch number from
    ``args.start_batch`` up to and including ``args.max_batch``.
    """
    first_batch = args.start_batch
    last_batch = args.max_batch

    vae = VAE()
    try:
        vae.set_weights('./vae/weights.h5')
    except BaseException:
        print(
            "./vae/weights.h5 does not exist - ensure you have run 02_train_vae.py first"
        )
        raise

    for batch_index in range(first_batch, last_batch + 1):
        print('Generating batch {}...'.format(batch_index))
        vae.generate_rnn_data(batch_index)
def main(args):
    """Build RNN training data per batch number from all training envs.

    For each batch number in ``[args.start_batch, args.max_batch]``:
    - ``./data/obs_data_<env>_<n>.npy`` and ``./data/action_data_<env>_<n>.npy``
      are loaded for every env in ``config.train_envs``;
    - the arrays are concatenated across envs;
    - the result is passed to ``vae.generate_rnn_data``, and its outputs are
      saved as ``./data/rnn_input_<n>`` / ``./data/rnn_output_<n>``.

    Missing files are skipped silently. Other load or merge errors are
    printed, and that env is skipped.

    Raises:
        Whatever ``vae.set_weights`` raises if ``./vae/weights.h5`` cannot be
        loaded (re-raised after printing a hint).
    """
    start_batch = args.start_batch
    max_batch = args.max_batch

    vae = VAE()

    try:
        vae.set_weights('./vae/weights.h5')
    except Exception:
        print(
            "./vae/weights.h5 does not exist - ensure you have run 02_train_vae.py first"
        )
        raise

    for batch_num in range(start_batch, max_batch + 1):
        print('Generating batch {}...'.format(batch_num))
        obs_data = None
        action_data = None

        for env_name in config.train_envs:
            suffix = env_name + '_' + str(batch_num) + '.npy'
            try:
                new_obs_data = np.load('./data/obs_data_' + suffix)
                new_action_data = np.load('./data/action_data_' + suffix)
                if obs_data is not None:
                    # Build both merged arrays before assigning either, so a
                    # shape mismatch cannot leave obs and actions misaligned.
                    new_obs_data = np.concatenate([obs_data, new_obs_data])
                    new_action_data = np.concatenate([action_data, new_action_data])
            except FileNotFoundError:
                continue  # not every env has data for every batch number
            except Exception as e:
                print('Skipped {} for batch {}: {}'.format(env_name, batch_num, e))
                continue

            obs_data, action_data = new_obs_data, new_action_data
            print('Found {}...current data size = {} episodes'.format(
                env_name, len(obs_data)))

        if obs_data is not None:
            rnn_input, rnn_output = vae.generate_rnn_data(
                obs_data, action_data)
            np.save('./data/rnn_input_' + str(batch_num), rnn_input)
            np.save('./data/rnn_output_' + str(batch_num), rnn_output)
        else:
            print('no data found for batch number {}'.format(batch_num))
Beispiel #12
0
def train_on_drives(args):
    """Train a VAE batch by batch on recorded drive data.

    The VAE is created lazily, sized from the shape of the first sample in the
    first batch yielded by ``loadDataBatches`` (actions skipped, ``max_batch``
    10000). Unless ``args.new_model`` is truthy, ``./vae/weights.h5`` is loaded
    into it. Each batch is then passed to ``vae.train(..., epochs=100)``.
    """
    model = None

    batches = loadDataBatches(args.dirs, skip_actions=True, max_batch=10000)
    for batch_data, _labels, _categories in batches:
        if model is None:
            sample_shape = batch_data[0].shape
            print("Data shape: {}".format(batch_data.shape))
            model = VAE(input_dim=sample_shape)

            if not args.new_model:
                try:
                    model.set_weights('./vae/weights.h5')
                except BaseException:
                    print(
                        "Either set --new_model or ensure ./vae/weights.h5 exists"
                    )
                    raise

        model.train(batch_data, epochs=100)
Beispiel #13
0
def main(args):
    """Train the VAE on one observation file per batch number.

    Unless ``args.new_model`` is truthy, ``./vae/weights.h5`` is loaded first.
    Then, for each batch number from ``args.start_batch`` to
    ``args.max_batch`` inclusive:
    - ``./data/obs_data_<n>.npy`` is loaded;
    - its episodes are flattened into one array holding every item of every
      episode;
    - ``vae.train`` is called on that array.
    """
    start_batch = args.start_batch
    max_batch = args.max_batch
    new_model = args.new_model

    vae = VAE()

    if not new_model:
        try:
            vae.set_weights('./vae/weights.h5')
        except BaseException:
            print("Either set --new_model or ensure ./vae/weights.h5 exists")
            raise

    for batch_num in range(start_batch, max_batch + 1):
        print('Building batch {}...'.format(batch_num))
        episodes = np.load('./data/obs_data_' + str(batch_num) + '.npy')
        frames = [frame for episode in episodes for frame in episode]
        vae.train(np.array(frames))
Beispiel #14
0
def main(args):
    """Train the VAE for several epochs over per-batch observation files.

    Each epoch visits batch numbers ``args.start_batch``..``args.max_batch``
    (inclusive). For each batch number:
    - ``./data/obs_data_<env>_<n>.npy`` is loaded for every env in
      ``config.train_envs`` and concatenated;
    - the episodes are flattened into one array;
    - ``vae.train`` is called on it.

    Missing files are skipped silently. Other load or merge errors are
    printed, and that env is skipped.

    Raises:
        Whatever ``vae.set_weights`` raises when resuming and
        ``./vae/weights.h5`` cannot be loaded (re-raised after a hint).
    """
    start_batch = args.start_batch
    max_batch = args.max_batch
    new_model = args.new_model
    epochs = int(args.epochs)  # tolerate string values from the CLI

    vae = VAE()

    if not new_model:
        try:
            vae.set_weights('./vae/weights.h5')
        except Exception:
            print("Either set --new_model or ensure ./vae/weights.h5 exists")
            raise

    for _ in range(epochs):
        for batch_num in range(start_batch, max_batch + 1):
            print('Building batch {}...'.format(batch_num))
            data = None

            for env_name in config.train_envs:
                try:
                    new_data = np.load('./data/obs_data_' + env_name + '_' +
                                       str(batch_num) + '.npy')
                    if data is not None:
                        new_data = np.concatenate([data, new_data])
                except FileNotFoundError:
                    continue  # not every env has data for every batch number
                except Exception as e:
                    print('Skipped {} for batch {}: {}'.format(env_name, batch_num, e))
                    continue

                data = new_data
                print('Found {}...current data size = {} episodes'.format(
                    env_name, len(data)))

            if data is not None:
                frames = np.array([item for obs in data for item in obs])
                vae.train(frames)
            else:
                print('no data found for batch number {}'.format(batch_num))
def main(args):
    """Train the augmenting VAE on pre-cropped (S, B) image pairs.

    Files are listed from ``DIR_NAME + model_name``, sorted, with
    ``.DS_Store`` and ``.gitignore`` ignored. ``N`` is capped at the number of
    files. The range ``[S, N)`` is handed to ``import_data`` in chunks of
    ``batch``. Each chunk is trained on ``epochs`` times, then the weights are
    saved to ``./vae/<model_name>/<model_name>_weights.h5``.

    Raises:
        Whatever ``vae.set_weights`` raises when resuming and no weights file
        can be loaded (re-raised after printing a hint).
    """
    new_model = args.new_model
    S = int(args.S)
    N = int(args.N)
    batch = int(args.batch)
    epochs = int(args.epochs)
    model_name = str(args.model_name)

    vae = VAE()

    model_dir = './vae/' + model_name
    # Resume from the same file that save_weights() writes below.
    weights_path = model_dir + '/' + model_name + '_weights.h5'
    flat_weights_path = './vae/' + model_name + '_weights.h5'

    if not new_model:
        load_path = weights_path
        if not os.path.exists(load_path) and os.path.exists(flat_weights_path):
            # Fall back to the flat ./vae/<model_name>_weights.h5 location
            # when only that file exists.
            load_path = flat_weights_path
        try:
            vae.set_weights(load_path)
        except Exception:
            print("Either set --new_model or ensure " + weights_path + " exists")
            raise
    elif not os.path.isdir(model_dir):
        os.makedirs(os.path.join(model_dir, 'log'))

    filelist = sorted(name for name in os.listdir(DIR_NAME + model_name)
                      if name not in ('.DS_Store', '.gitignore'))
    # N is an upper bound; it cannot exceed the number of available files.
    N = min(N, len(filelist))

    # Ceiling division covers a trailing partial chunk. It never adds an
    # empty extra chunk when (N - S) is a multiple of batch.
    num_batches = max(0, -(-(N - S) // batch))
    for i in range(num_batches):
        start = S + i * batch
        end = min(start + batch, N)
        dataS, dataB = import_data(start, end, filelist, model_name)
        for _ in range(epochs):
            # Augmenting VAE (simple RNN variant 2). For a plain VAE (RNN
            # variant 1), train on np.vstack([dataS, dataB]) as both inputs.
            vae.train(dataS, dataB, model_name)
        vae.save_weights(weights_path)

        print('Imported {} / {}'.format(end, N))
Beispiel #16
0
def main(args):
    """Train the VAE on per-batch observation files from all training envs.

    For each batch number in ``[args.start_batch, args.max_batch]``:
    - ``./data/obs_data_<env>_<n>.npy`` is loaded for every env in
      ``config.train_envs`` and concatenated;
    - the episodes are flattened into one array;
    - ``vae.train`` is called on it.

    Missing files are skipped silently. Other load or merge errors are
    printed, and that env is skipped.

    Raises:
        Whatever ``vae.set_weights`` raises when resuming and
        ``./vae/weights.h5`` cannot be loaded (re-raised after a hint).
    """
    start_batch = args.start_batch
    max_batch = args.max_batch
    new_model = args.new_model

    vae = VAE()

    if not new_model:
        try:
            vae.set_weights('./vae/weights.h5')
        except Exception:
            print("Either set --new_model or ensure ./vae/weights.h5 exists")
            raise

    for batch_num in range(start_batch, max_batch + 1):
        print('Building batch {}...'.format(batch_num))
        data = None

        for env_name in config.train_envs:
            try:
                new_data = np.load('./data/obs_data_' + env_name + '_' + str(batch_num) + '.npy')
                if data is not None:
                    new_data = np.concatenate([data, new_data])
            except FileNotFoundError:
                continue  # not every env has data for every batch number
            except Exception as e:
                print('Skipped {} for batch {}: {}'.format(env_name, batch_num, e))
                continue

            data = new_data
            print('Found {}...current data size = {} episodes'.format(env_name, len(data)))

        if data is not None:
            frames = np.array([item for obs in data for item in obs])
            vae.train(frames)
        else:
            print('no data found for batch number {}'.format(batch_num))
def main(args):
    """Build RNN training data per batch number from all training envs.

    For each batch number in ``[args.start_batch, args.max_batch]``:
    - obs/action ``.npy`` files are loaded for every env in
      ``config.train_envs`` and concatenated;
    - the result is passed to ``vae.generate_rnn_data``;
    - its outputs are saved as ``./data/rnn_input_<n>`` and
      ``./data/rnn_output_<n>``.

    Missing files are skipped silently. Other load or merge errors are
    printed, and that env is skipped.

    Raises:
        Whatever ``vae.set_weights`` raises if ``./vae/weights.h5`` cannot be
        loaded (re-raised after printing a hint).
    """
    start_batch = args.start_batch
    max_batch = args.max_batch

    vae = VAE()

    try:
        vae.set_weights('./vae/weights.h5')
    except Exception:
        print("./vae/weights.h5 does not exist - ensure you have run 02_train_vae.py first")
        raise

    for batch_num in range(start_batch, max_batch + 1):
        print('Generating batch {}...'.format(batch_num))
        obs_data = None
        action_data = None

        for env_name in config.train_envs:
            suffix = env_name + '_' + str(batch_num) + '.npy'
            try:
                new_obs_data = np.load('./data/obs_data_' + suffix)
                new_action_data = np.load('./data/action_data_' + suffix)
                if obs_data is not None:
                    # Build both merged arrays before assigning either, so a
                    # shape mismatch cannot leave obs and actions misaligned.
                    new_obs_data = np.concatenate([obs_data, new_obs_data])
                    new_action_data = np.concatenate([action_data, new_action_data])
            except FileNotFoundError:
                continue  # not every env has data for every batch number
            except Exception as e:
                print('Skipped {} for batch {}: {}'.format(env_name, batch_num, e))
                continue

            obs_data, action_data = new_obs_data, new_action_data
            print('Found {}...current data size = {} episodes'.format(env_name, len(obs_data)))

        if obs_data is not None:
            rnn_input, rnn_output = vae.generate_rnn_data(obs_data, action_data)
            np.save('./data/rnn_input_' + str(batch_num), rnn_input)
            np.save('./data/rnn_output_' + str(batch_num), rnn_output)
        else:
            print('no data found for batch number {}'.format(batch_num))
Beispiel #18
0
def main(args):
    """Encode the S and B observation streams of each rollout with the VAE.

    Weights come from ``./vae/<model_name>/<model_name>_weights.h5``. For each
    file returned by ``get_filelist(S, N, model_name)``:
    - the rollout's ``obsS`` and ``obsB`` arrays are encoded with
      ``encode_episode``;
    - muS, log_varS, muB, log_varB, action, reward and done are written to
      ``SERIES_DIR_NAME + model_name + '/' + file``.

    Initial mu/log_var values are collected and saved to ``initial_zS.npz``
    and ``initial_zB.npz`` under ``ROOT_DIR_NAME``. Files that fail are
    reported and skipped.

    Raises:
        Whatever ``vae.set_weights`` raises if the weights cannot be loaded
        (re-raised after printing a hint).
    """
    N = int(args.N)
    S = int(args.S)
    model_name = str(args.model_name)
    vae = VAE()

    weights_path = './vae/' + model_name + '/' + model_name + '_weights.h5'
    try:
        vae.set_weights(weights_path)
    except Exception:
        # Name the path that was actually tried.
        print(weights_path + " does not exist - ensure you have run 02_train_vae.py first")
        raise

    filelist, N = get_filelist(S, N, model_name)

    rollout_dir = ROLLOUT_DIR_NAME + model_name + '/'
    series_dir = SERIES_DIR_NAME + model_name
    # Create the output directory once rather than checking for every file.
    os.makedirs(series_dir, exist_ok=True)

    file_count = 0
    # Shape of the most recently encoded muS series (None if nothing encoded).
    last_mu_shape = None

    initial_musS = []
    initial_log_varsS = []
    initial_musB = []
    initial_log_varsB = []

    for file_name in filelist:
        try:
            # Read every needed array, then close the archive.
            with np.load(rollout_dir + file_name) as rollout_data:
                obsS = rollout_data['obsS']
                obsB = rollout_data['obsB']
                action = rollout_data['action']
                reward = rollout_data['reward']
                done = rollout_data['done']

            muS, log_varS, _, _, _, initial_muS, initial_log_varS = \
                encode_episode(vae, obsS, action, reward, done)
            muB, log_varB, action, reward, done, initial_muB, initial_log_varB = \
                encode_episode(vae, obsB, action, reward, done)

            np.savez_compressed(series_dir + "/" + file_name,
                                muS=muS, log_varS=log_varS,
                                muB=muB, log_varB=log_varB,
                                action=action, reward=reward, done=done)
        except Exception as e:
            print(e)
            print('Skipped {}...'.format(file_name))
            continue

        initial_musS.append(initial_muS)
        initial_log_varsS.append(initial_log_varS)
        initial_musB.append(initial_muB)
        initial_log_varsB.append(initial_log_varB)
        last_mu_shape = muS.shape

        file_count += 1

        if file_count % 50 == 0:
            print('Encoded {} / {} episodes'.format(S + file_count, N))

    print('Encoded {} / {} episodes'.format(S + file_count, N))

    initial_musS = np.array(initial_musS)
    initial_log_varsS = np.array(initial_log_varsS)
    initial_musB = np.array(initial_musB)
    initial_log_varsB = np.array(initial_log_varsB)

    print('ONE MU SHAPE = {}'.format(last_mu_shape))
    print('INITIAL MU SHAPE = {}'.format(initial_musS.shape))

    np.savez_compressed(ROOT_DIR_NAME + 'initial_zS.npz', initial_muS=initial_musS, initial_log_varS=initial_log_varsS)
    np.savez_compressed(ROOT_DIR_NAME + 'initial_zB.npz', initial_muB=initial_musB, initial_log_varB=initial_log_varsB)
Beispiel #19
0
                img_cols=64,
                img_channels=3):
    # Visualize model by running the P-model forward from a grid of points in the latent space
    # ranging from -3 to +3.
    num_grid = 15
    grid_pts = np.linspace(-3, 3, num_grid)

    # Create a keras function that takes latent (x, y) as input and produces images as output
    renderer = K.function([sample_var], [reconstruction_var])
    latent_input = np.zeros((1, ndim))  # must have shape (1, ndim) not (ndim,)

    # Allocate space for the final (num_grid * rows, num_grid * cols) image
    final_image = np.zeros(
        (num_grid * img_rows, num_grid * img_cols, img_channels))

    # Populate each sub-image one at a time
    for i in range(num_grid):
        for j in range(num_grid):
            latent_input[0, :2] = (grid_pts[i], grid_pts[j])
            img = renderer([latent_input])[0]
            final_image[j * img_rows:(j + 1) * img_rows, i * img_cols:(i + 1) * img_cols, :] = \
                np.reshape(img, (img_rows, img_cols, img_channels))

    plt.imshow(final_image.squeeze(), cmap='gray', extent=(-3, 3, -3, 3))
    plt.show()


# Visualize results: build a VAE, load ./vae/weights.h5 into it, then call
# render_grid with the VAE's latent sample tensor and reconstruction tensor.
# render_grid decodes a grid of points over the first two latent dimensions.
# NOTE(review): this runs at import time; consider an
# `if __name__ == "__main__":` guard.
vae = VAE()
vae.set_weights('./vae/weights.h5')
render_grid(vae.latent.sample, vae.reconstruction)