import numpy as np

# VAE and config are project-local modules; the exact import paths here are
# assumptions based on the file layout used elsewhere in these snippets.
from vae.arch import VAE
import config


def generate(args):
    num_files = args.num_files
    validation = args.validation
    seq_len = args.seq_len

    vae = VAE()

    weights = "final.h5"

    try:
        vae.set_weights("./vae/" + weights)
    except Exception:
        print("./vae/" + weights + " does not exist")
        raise

    for file_id in range(num_files):
        print("Generating file {}...".format(file_id))

        obs_data = []
        action_data = []

        for env_name in config.train_envs:
            try:
                if validation:
                    new_obs_data = np.load("./data/obs_valid_" + env_name + ".npz")["arr_0"]
                    new_action_data = np.load("./data/action_valid_" + env_name + ".npz")["arr_0"]
                else:
                    new_obs_data = np.load("./data/obs_data_" + env_name + "_" + str(file_id) + ".npz")["arr_0"]
                    new_action_data = np.load("./data/action_data_" + env_name + "_" + str(file_id) + ".npz")["arr_0"]

                # Drop episodes that are not exactly seq_len steps long. The loop
                # iterates over the original array while `index` tracks the position
                # in the shrinking filtered array, so a deletion does not advance it.
                index = 0
                for episode in new_obs_data:
                    if len(episode) != seq_len:
                        new_obs_data = np.delete(new_obs_data, index, axis=0)
                        new_action_data = np.delete(new_action_data, index, axis=0)
                    else:
                        index += 1

                # Accumulate episodes across environments; with 1-D (object) arrays
                # of episodes this amounts to a simple concatenation.
                obs_data = np.append(obs_data, new_obs_data)
                action_data = np.append(action_data, new_action_data)

                print("Found {}...current data size = {} episodes".format(env_name, len(obs_data)))
            except Exception as e:
                # Missing or unreadable data files for this environment are skipped.
                print(e)

            # Note: obs_data/action_data accumulate across environments, so each
            # env's RNN file contains all episodes gathered so far in this pass.
            rnn_input, rnn_output = vae.generate_rnn_data(obs_data, action_data)
            if validation:
                np.savez_compressed("./data/rnn_input_" + env_name + "_valid", rnn_input)
                np.savez_compressed("./data/rnn_output_" + env_name + "_valid", rnn_output)
            else:
                np.savez_compressed("./data/rnn_input_" + env_name + "_" + str(file_id), rnn_input)
                np.savez_compressed("./data/rnn_output_" + env_name + "_" + str(file_id), rnn_output)

        if validation:
            break
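# For reference, a minimal command-line driver that could feed generate(args) is
# sketched below. The flag names mirror the attributes the function reads
# (num_files, validation, seq_len); the defaults and the script layout are
# assumptions, not part of the original code.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate RNN training data from VAE-encoded rollouts")
    parser.add_argument("--num_files", type=int, default=1,
                        help="number of obs/action data files to process")
    parser.add_argument("--validation", action="store_true",
                        help="use the validation split instead of numbered training files")
    parser.add_argument("--seq_len", type=int, default=300,
                        help="episodes must be exactly this many steps long to be kept")
    args = parser.parse_args()
    generate(args)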
def main(args):

    start_batch = args.start_batch
    max_batch = args.max_batch

    vae = VAE()
    try:
        vae.set_weights('./vae/weights.h5')
    except Exception:
        print(
            "./vae/weights.h5 does not exist - ensure you have run 02_train_vae.py first"
        )
        raise

    for i in range(start_batch, max_batch + 1):
        print('Generating batch {}...'.format(i))
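        # Note: unlike the variant below, this generate_rnn_data takes the
        # batch index directly rather than (obs_data, action_data) arrays.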
        vae.generate_rnn_data(i)
def main(args):

    start_batch = args.start_batch
    max_batch = args.max_batch

    vae = VAE()

    try:
        vae.set_weights('./vae/weights.h5')
    except Exception:
        print(
            "./vae/weights.h5 does not exist - ensure you have run 02_train_vae.py first"
        )
        raise

    for batch_num in range(start_batch, max_batch + 1):
        first_item = True
        print('Generating batch {}...'.format(batch_num))

        for env_name in config.train_envs:
            try:
                new_obs_data = np.load('./data/obs_data_' + env_name + '_' +
                                       str(batch_num) + '.npy')
                new_action_data = np.load('./data/action_data_' + env_name +
                                          '_' + str(batch_num) + '.npy')
                if first_item:
                    obs_data = new_obs_data
                    action_data = new_action_data
                    first_item = False
                else:
                    obs_data = np.concatenate([obs_data, new_obs_data])
                    action_data = np.concatenate(
                        [action_data, new_action_data])
                print('Found {}...current data size = {} episodes'.format(
                    env_name, len(obs_data)))
            except Exception:
                # Data for this environment/batch is missing or unreadable; skip it.
                pass

        if not first_item:
            rnn_input, rnn_output = vae.generate_rnn_data(
                obs_data, action_data)
            np.save('./data/rnn_input_' + str(batch_num), rnn_input)
            np.save('./data/rnn_output_' + str(batch_num), rnn_output)
        else:
            print('no data found for batch number {}'.format(batch_num))
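# For reference: np.save appends ".npy", so the batches written above can be read
# back with np.load as sketched here. This helper is an assumption about how a
# downstream RNN training step might gather the data, not part of the original code.
def load_rnn_batches(start_batch, max_batch):
    rnn_inputs, rnn_outputs = [], []
    for batch_num in range(start_batch, max_batch + 1):
        try:
            # allow_pickle may be required if the saved arrays are object-typed
            rnn_inputs.append(np.load('./data/rnn_input_' + str(batch_num) + '.npy'))
            rnn_outputs.append(np.load('./data/rnn_output_' + str(batch_num) + '.npy'))
        except FileNotFoundError:
            # No RNN data was generated for this batch; skip it.
            pass
    return rnn_inputs, rnn_outputs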