def main(args):
    """Train a VAE for the requested number of epochs, saving after each.

    Loads existing weights from ./vae/weights.h5 unless --new_model is set;
    imports up to N data items, then trains on the full array once per epoch.
    """
    new_model = args.new_model
    N = int(args.N)
    epochs = int(args.epochs)

    vae = VAE()

    if not new_model:
        try:
            vae.set_weights('./vae/weights.h5')
        except Exception:
            # Fix: was a bare `except:`, which also traps KeyboardInterrupt.
            print("Either set --new_model or ensure ./vae/weights.h5 exists")
            raise

    try:
        data, N = import_data(N)
    except Exception:
        print('NO DATA FOUND')
        raise

    print('DATA SHAPE = {}'.format(data.shape))

    for epoch in range(epochs):
        print('EPOCH ' + str(epoch))
        vae.train(data)
        # Save every epoch so an interrupted run keeps its latest progress.
        vae.save_weights('./vae/weights.h5')
def train_on_drives_gen(args):
    """Train a VAE from training/validation drive-data generators.

    Uses 64x64x3 images, streamed in batches of 100 (images only, no actions).
    Existing weights are loaded unless --new_model is set.
    """
    input_dim = (64, 64, 3)
    gen = DriveDataGenerator(args.dirs, image_size=input_dim[0:2],
                             batch_size=100, shuffle=True, max_load=10000,
                             images_only=True)
    val = DriveDataGenerator(args.val, image_size=input_dim[0:2],
                             batch_size=100, shuffle=True, max_load=10000,
                             images_only=True)

    vae = VAE(input_dim=input_dim)
    print("Train: {}".format(gen.count))
    print("Val : {}".format(val.count))

    if not args.new_model:
        try:
            vae.set_weights('./vae/weights.h5')
        except Exception:
            # Fix: narrowed from a bare `except:`.
            print("Either set --new_model or ensure ./vae/weights.h5 exists")
            raise

    vae.train_gen(gen, val, epochs=100)
def sample_vae(args):
    """ For vae from: https://github.com/AppliedDataSciencePartners/WorldModels.git

    Draws args.count standard-normal latent vectors, decodes them with the
    trained VAE, and displays the samples in a matplotlib figure.
    """
    vae = VAE(input_dim=(120, 120, 3))
    try:
        vae.set_weights('./vae/weights.h5')
    except Exception:
        # Fix: narrowed from a bare `except:`.
        print(
            "./vae/weights.h5 does not exist - ensure you have run 02_train_vae.py first"
        )
        raise

    # Sample from the prior N(0, I) over the latent space.
    z = np.random.normal(size=(args.count, vae.z_dim))
    samples = vae.decoder.predict(z)
    input_dim = samples.shape[1:]

    n = args.count
    plt.figure(figsize=(20, 4))
    plt.title('VAE samples')
    for i in range(n):
        # Only the first of the 3 subplot rows is populated.
        ax = plt.subplot(3, n, i + 1)
        plt.imshow(samples[i].reshape(input_dim[0], input_dim[1],
                                      input_dim[2]))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        #plt.savefig( image_path )
    plt.show()
def main(args):
    """Encode up to N rollout files through the trained VAE into latent series.

    Each rollout is encoded to (mu, log_var, action, reward, done) and saved
    compressed; the per-episode initial latents are aggregated into
    initial_z.npz. Files that fail to load/encode are skipped with a log line.
    """
    N = int(args.N)

    vae = VAE()
    try:
        vae.set_weights('./vae/weights.h5')
    except Exception as e:
        print(e)
        print(
            "./vae/weights.h5 does not exist - ensure you have run 02_train_vae.py first"
        )
        raise

    filelist, N = get_filelist(N)

    file_count = 0
    mu = None  # holds the last successfully encoded mu for the shape report
    initial_mus = []
    initial_log_vars = []

    for file in filelist:
        try:
            rollout_data = np.load(ROLLOUT_DIR_NAME + file)
            mu, log_var, action, reward, done, initial_mu, initial_log_var = encode_episode(
                vae, rollout_data)
            np.savez_compressed(SERIES_DIR_NAME + file, mu=mu, log_var=log_var,
                                action=action, reward=reward, done=done)
            initial_mus.append(initial_mu)
            initial_log_vars.append(initial_log_var)
            file_count += 1
            if file_count % 50 == 0:
                print('Encoded {} / {} episodes'.format(file_count, N))
        except Exception as e:
            print(e)
            print('Skipped {}...'.format(file))

    print('Encoded {} / {} episodes'.format(file_count, N))
    initial_mus = np.array(initial_mus)
    initial_log_vars = np.array(initial_log_vars)
    # Fix: guard against a NameError when every file was skipped — `mu` was
    # previously referenced unconditionally after the loop.
    if mu is not None:
        print('ONE MU SHAPE = {}'.format(mu.shape))
    print('INITIAL MU SHAPE = {}'.format(initial_mus.shape))
    np.savez_compressed(ROOT_DIR_NAME + 'initial_z.npz',
                        initial_mu=initial_mus,
                        initial_log_var=initial_log_vars)
def generate(args):
    """Build RNN training series from saved observation/action rollouts.

    For each file id (or once, for validation data), loads per-environment
    .npz rollouts, drops episodes whose length differs from seq_len, encodes
    the rest through the VAE, and writes rnn_input/rnn_output archives.
    """
    num_files = args.num_files
    validation = args.validation
    seq_len = args.seq_len

    vae = VAE()
    weights = "final.h5"
    try:
        vae.set_weights("./vae/" + weights)
    except Exception:
        print("./vae/" + weights + " does not exist")
        # Fix: was `raise FileNotFoundError`, which raised a bare exception
        # class and discarded the real failure; re-raise the original error.
        raise

    for file_id in range(num_files):
        print("Generating file {}...".format(file_id))
        obs_data = []
        action_data = []

        for env_name in config.train_envs:
            try:
                if validation:
                    new_obs_data = np.load("./data/obs_valid_" + env_name +
                                           ".npz")["arr_0"]
                    new_action_data = np.load("./data/action_valid_" +
                                              env_name + ".npz")["arr_0"]
                else:
                    new_obs_data = np.load("./data/obs_data_" + env_name +
                                           "_" + str(file_id) +
                                           ".npz")["arr_0"]
                    new_action_data = np.load("./data/action_data_" +
                                              env_name + "_" + str(file_id) +
                                              ".npz")["arr_0"]

                # Filter out episodes that are not exactly seq_len long;
                # index only advances past episodes that are kept.
                index = 0
                for episode in new_obs_data:
                    # print(len(episode))
                    if len(episode) != seq_len:
                        new_obs_data = np.delete(new_obs_data, index)
                        new_action_data = np.delete(new_action_data, index)
                    else:
                        index += 1

                obs_data = np.append(obs_data, new_obs_data)
                action_data = np.append(action_data, new_action_data)
                print("Found {}...current data size = {} episodes".format(
                    env_name, len(obs_data)))
            except Exception as e:
                print(e)

        # NOTE(review): env_name below is the last value from the loop above,
        # so all envs' data lands in one file named for the last env — confirm
        # this is intentional before changing the output filenames.
        if validation:
            rnn_input, rnn_output = vae.generate_rnn_data(obs_data,
                                                          action_data)
            np.savez_compressed("./data/rnn_input_" + env_name + "_valid",
                                rnn_input)
            np.savez_compressed("./data/rnn_output_" + env_name + "_valid",
                                rnn_output)
        else:
            rnn_input, rnn_output = vae.generate_rnn_data(obs_data,
                                                          action_data)
            np.savez_compressed("./data/rnn_input_" + env_name + "_" +
                                str(file_id), rnn_input)
            np.savez_compressed("./data/rnn_output_" + env_name + "_" +
                                str(file_id), rnn_output)

        if validation:
            break
def main(args):
    """Train an augmenting VAE on cropped/resized image pairs, in batches.

    For each batch, every source image is cropped at two scales (a small crop
    at alpha*beta and a big crop at beta, with beta drawn from [alpha, 1)),
    resized to screen size, and fed to vae.train(dataS, dataB, model_name).
    """
    new_model = args.new_model
    S = int(args.S)
    N = int(args.N)
    batch = int(args.batch)
    model_name = str(args.model_name)
    print(args.alpha)
    alpha = float(args.alpha)

    vae = VAE()

    if not new_model:
        try:
            vae.set_weights('./vae/' + model_name + '/' + model_name +
                            '_weights.h5')
        except Exception:
            # Fix: narrowed from a bare `except:`.
            print("Either set --new_model or ensure ./vae/weights.h5 exists")
            raise
    else:
        if os.path.isdir('./vae/' + model_name):
            print("A model with this name already exists")
        else:
            os.mkdir('./vae/' + model_name)
            os.mkdir('./vae/' + model_name + '/log/')

    filelist = os.listdir(DIR_NAME)
    filelist = [x for x in filelist if x != '.DS_Store' and x != '.gitignore']
    filelist.sort()

    for i in range(round(float(N - S) / batch)):
        data = import_data(S + i * batch, S + (i + 1) * batch, filelist)
        dataS = []
        dataB = []
        for d in data:
            # beta is drawn uniformly from [alpha, 1).
            beta = alpha + np.random.rand() * (1 - alpha)
            dataS.append(
                cv2.resize(crop(d, alpha * beta),
                           dsize=(SCREEN_SIZE_X, SCREEN_SIZE_Y),
                           interpolation=cv2.INTER_CUBIC))
            dataB.append(
                cv2.resize(crop(d, beta),
                           dsize=(SCREEN_SIZE_X, SCREEN_SIZE_Y),
                           interpolation=cv2.INTER_CUBIC))
        dataS = np.asarray(dataS)
        dataB = np.asarray(dataB)
        vae.train(dataS, dataB, model_name
                  )  # uncomment this to train augmenting VAE, simple RNN (2)
        #vae.train(np.vstack([dataS, dataB]), np.vstack([dataS, dataB]), model_name) # uncomment this to train simple VAE, RNN (1)
        vae.save_weights('./vae/' + model_name + '/' + model_name +
                         '_weights.h5')
        print('Imported {} / {}'.format(S + (i + 1) * batch, N))
def make_model():
    """Assemble the agent from its three pre-trained components.

    Loads VAE and RNN weights from their default paths and wires them,
    together with a fresh Controller, into a single Model.
    """
    vision = VAE()
    vision.set_weights('./vae/weights.h5')

    memory = RNN()
    memory.set_weights('./rnn/weights.h5')

    return Model(Controller(), vision, memory)
def main(args):
    """Train a VAE over num_files data files, optionally resuming.

    If --load_model is given (anything other than the literal string "None"),
    weights are loaded from that path before training.
    """
    num_files = args.num_files
    load_model = args.load_model

    vae = VAE()

    # Fix: was `if not load_model=="None"`, which parses correctly but reads
    # ambiguously; state the comparison directly.
    if load_model != "None":
        try:
            print("Loading model " + load_model)
            vae.set_weights(load_model)
        except Exception:
            # Fix: narrowed from a bare `except:`.
            print("Either don't set --load_model or ensure " + load_model +
                  " exists")
            raise

    vae.train(num_files)
def main(args):
    """Generate RNN training data for each batch index in the given range."""
    start_batch = args.start_batch
    max_batch = args.max_batch

    vae = VAE()
    try:
        vae.set_weights('./vae/weights.h5')
    except Exception:
        # Fix: narrowed from a bare `except:`.
        print(
            "./vae/weights.h5 does not exist - ensure you have run 02_train_vae.py first"
        )
        raise

    for i in range(start_batch, max_batch + 1):
        print('Generating batch {}...'.format(i))
        vae.generate_rnn_data(i)
def main(args):
    """Encode per-env obs/action batches into rnn_input/rnn_output files.

    For each batch index, concatenates every environment's data that exists
    on disk, then runs it through the VAE. Batches with no data are skipped
    with a message.
    """
    start_batch = args.start_batch
    max_batch = args.max_batch

    vae = VAE()
    try:
        vae.set_weights('./vae/weights.h5')
    except Exception:
        # Fix: narrowed from a bare `except:`.
        print(
            "./vae/weights.h5 does not exist - ensure you have run 02_train_vae.py first"
        )
        raise

    for batch_num in range(start_batch, max_batch + 1):
        first_item = True
        print('Generating batch {}...'.format(batch_num))
        for env_name in config.train_envs:
            try:
                new_obs_data = np.load('./data/obs_data_' + env_name + '_' +
                                       str(batch_num) + '.npy')
                new_action_data = np.load('./data/action_data_' + env_name +
                                          '_' + str(batch_num) + '.npy')
            except IOError:
                # Fix: the old bare `except: pass` swallowed every error,
                # including real bugs in the concatenation below. Missing
                # files for an env/batch are the only expected failure.
                continue
            if first_item:
                obs_data = new_obs_data
                action_data = new_action_data
                first_item = False
            else:
                obs_data = np.concatenate([obs_data, new_obs_data])
                action_data = np.concatenate([action_data, new_action_data])
            print('Found {}...current data size = {} episodes'.format(
                env_name, len(obs_data)))

        if not first_item:
            rnn_input, rnn_output = vae.generate_rnn_data(obs_data,
                                                          action_data)
            np.save('./data/rnn_input_' + str(batch_num), rnn_input)
            np.save('./data/rnn_output_' + str(batch_num), rnn_output)
        else:
            print('no data found for batch number {}'.format(batch_num))
def train_on_drives(args):
    """Train a VAE on batched drive images, 100 epochs per batch.

    The VAE is constructed lazily from the first batch, which determines the
    input image shape; existing weights are loaded unless --new_model is set.
    """
    vae = None
    for data, y, cat in loadDataBatches(args.dirs, skip_actions=True,
                                        max_batch=10000):
        if vae is None:
            # First batch: infer the image shape and build the model once.
            input_dim = data[0].shape
            print("Data shape: {}".format(data.shape))
            vae = VAE(input_dim=input_dim)
            if not args.new_model:
                try:
                    vae.set_weights('./vae/weights.h5')
                except Exception:
                    # Fix: narrowed from a bare `except:`.
                    print(
                        "Either set --new_model or ensure ./vae/weights.h5 exists"
                    )
                    raise
        vae.train(data, epochs=100)
def main(args):
    """Train the VAE on each saved observation batch in the given range."""
    start_batch = args.start_batch
    max_batch = args.max_batch
    new_model = args.new_model
    # epochs = args.epochs

    vae = VAE()

    if not new_model:
        try:
            vae.set_weights('./vae/weights.h5')
        except Exception:
            # Fix: narrowed from a bare `except:`.
            print("Either set --new_model or ensure ./vae/weights.h5 exists")
            raise

    for batch_num in range(start_batch, max_batch + 1):
        print('Building batch {}...'.format(batch_num))
        data = np.load('./data/obs_data_' + str(batch_num) + '.npy')
        # Flatten the per-episode structure into one array of observations.
        data = np.array([item for obs in data for item in obs])
        vae.train(data)
def main(args):
    """Train the VAE for several epochs over every env's observation batches.

    Each epoch walks the batch range, concatenating whatever per-environment
    .npy files exist; batches with no data are reported and skipped.
    """
    start_batch = args.start_batch
    max_batch = args.max_batch
    new_model = args.new_model
    epochs = args.epochs

    vae = VAE()

    if not new_model:
        try:
            vae.set_weights('./vae/weights.h5')
        except Exception:
            # Fix: narrowed from a bare `except:`.
            print("Either set --new_model or ensure ./vae/weights.h5 exists")
            raise

    for i in range(epochs):
        for batch_num in range(start_batch, max_batch + 1):
            print('Building batch {}...'.format(batch_num))
            first_item = True
            for env_name in config.train_envs:
                try:
                    new_data = np.load('./data/obs_data_' + env_name + '_' +
                                       str(batch_num) + '.npy')
                except IOError:
                    # Fix: was a bare `except: pass` that hid every error;
                    # a missing file for an env/batch is the expected case.
                    continue
                if first_item:
                    data = new_data
                    first_item = False
                else:
                    data = np.concatenate([data, new_data])
                print('Found {}...current data size = {} episodes'.format(
                    env_name, len(data)))

            if not first_item:  # i.e. data has been found for this batch number
                data = np.array([item for obs in data for item in obs])
                vae.train(data)
            else:
                print('no data found for batch number {}'.format(batch_num))
def main(args):
    """Train an augmenting VAE on pre-cropped small/big image pairs.

    Data files are imported batch by batch; each batch is trained for the
    requested number of epochs, and weights are saved after every batch.
    """
    new_model = args.new_model
    S = int(args.S)
    N = int(args.N)
    batch = int(args.batch)
    epochs = int(args.epochs)
    model_name = str(args.model_name)

    vae = VAE()

    if not new_model:
        try:
            vae.set_weights('./vae/' + model_name + '_weights.h5')
        except Exception:
            # Fix: narrowed from a bare `except:`.
            print("Either set --new_model or ensure ./vae/" + model_name +
                  "_weights.h5 exists")
            raise
    elif not os.path.isdir('./vae/' + model_name):
        os.mkdir('./vae/' + model_name)
        os.mkdir('./vae/' + model_name + '/log/')

    filelist = os.listdir(DIR_NAME + model_name)
    filelist = [x for x in filelist if x != '.DS_Store' and x != '.gitignore']
    filelist.sort()
    # NOTE(review): max() GROWS N to the full file count; if N was meant to
    # cap the number of files processed, this should be min() — confirm.
    N = max(N, len(filelist))

    for i in range(int(round(float(N - S) / batch) + 1)):
        dataS, dataB = import_data(S + i * batch, S + (i + 1) * batch,
                                   filelist, model_name)
        for epoch in range(epochs):
            vae.train(
                dataS, dataB, model_name
            )  # uncomment this to train augmenting VAE, simple RNN (2)
            #vae.train(np.vstack([dataS, dataB]), np.vstack([dataS, dataB]), model_name) # uncomment this to train simple VAE, RNN (1)
        vae.save_weights('./vae/' + model_name + '/' + model_name +
                         '_weights.h5')
        print('Imported {} / {}'.format(S + (i + 1) * batch, N))
def main(args):
    """Train the VAE once over every env's observation batches in range.

    Concatenates whatever per-environment .npy files exist for each batch
    index; batches with no data are reported and skipped.
    """
    start_batch = args.start_batch
    max_batch = args.max_batch
    new_model = args.new_model

    vae = VAE()

    if not new_model:
        try:
            vae.set_weights('./vae/weights.h5')
        except Exception:
            # Fix: narrowed from a bare `except:`.
            print("Either set --new_model or ensure ./vae/weights.h5 exists")
            raise

    for batch_num in range(start_batch, max_batch + 1):
        print('Building batch {}...'.format(batch_num))
        first_item = True
        for env_name in config.train_envs:
            try:
                new_data = np.load('./data/obs_data_' + env_name + '_' +
                                   str(batch_num) + '.npy')
            except IOError:
                # Fix: was a bare `except: pass` hiding every error; missing
                # files for an env/batch are the only expected failure.
                continue
            if first_item:
                data = new_data
                first_item = False
            else:
                data = np.concatenate([data, new_data])
            print('Found {}...current data size = {} episodes'.format(
                env_name, len(data)))

        if not first_item:  # i.e. data has been found for this batch number
            data = np.array([item for obs in data for item in obs])
            vae.train(data)
        else:
            print('no data found for batch number {}'.format(batch_num))
def main(args):
    """Encode per-env obs/action batches into rnn_input/rnn_output .npy files.

    For each batch index, concatenates every environment's data found on
    disk and runs it through the VAE; empty batches are reported and skipped.
    """
    start_batch = args.start_batch
    max_batch = args.max_batch

    vae = VAE()
    try:
        vae.set_weights('./vae/weights.h5')
    except Exception:
        # Fix: narrowed from a bare `except:`.
        print(
            "./vae/weights.h5 does not exist - ensure you have run 02_train_vae.py first"
        )
        raise

    for batch_num in range(start_batch, max_batch + 1):
        first_item = True
        print('Generating batch {}...'.format(batch_num))
        for env_name in config.train_envs:
            try:
                new_obs_data = np.load('./data/obs_data_' + env_name + '_' +
                                       str(batch_num) + '.npy')
                new_action_data = np.load('./data/action_data_' + env_name +
                                          '_' + str(batch_num) + '.npy')
            except IOError:
                # Fix: was a bare `except: pass` swallowing every error;
                # only missing files are expected here.
                continue
            if first_item:
                obs_data = new_obs_data
                action_data = new_action_data
                first_item = False
            else:
                obs_data = np.concatenate([obs_data, new_obs_data])
                action_data = np.concatenate([action_data, new_action_data])
            print('Found {}...current data size = {} episodes'.format(
                env_name, len(obs_data)))

        if not first_item:
            rnn_input, rnn_output = vae.generate_rnn_data(obs_data,
                                                          action_data)
            np.save('./data/rnn_input_' + str(batch_num), rnn_input)
            np.save('./data/rnn_output_' + str(batch_num), rnn_output)
        else:
            print('no data found for batch number {}'.format(batch_num))
def main(args):
    """Encode small/big ('obsS'/'obsB') rollouts into per-model latent series.

    Each rollout file is encoded twice (small and big observations), saved to
    the per-model series directory, and the per-episode initial latents are
    aggregated into initial_zS.npz / initial_zB.npz.
    """
    N = int(args.N)
    S = int(args.S)
    model_name = str(args.model_name)

    vae = VAE()
    try:
        vae.set_weights('./vae/' + model_name + '/' + model_name +
                        '_weights.h5')
    except Exception:
        # Fix: narrowed from a bare `except:`.
        print("./vae/" + model_name +
              "_weights.h5 does not exist - ensure you have run 02_train_vae.py first")
        raise

    filelist, N = get_filelist(S, N, model_name)

    file_count = 0
    initial_musS = []
    initial_log_varsS = []
    initial_musB = []
    initial_log_varsB = []

    for file in filelist:
        try:
            rollout_data = np.load(ROLLOUT_DIR_NAME + model_name + '/' + file)
            muS, log_varS, action, reward, done, initial_muS, initial_log_varS = \
                encode_episode(vae, rollout_data['obsS'],
                               rollout_data['action'],
                               rollout_data['reward'], rollout_data['done'])
            muB, log_varB, action, reward, done, initial_muB, initial_log_varB = \
                encode_episode(vae, rollout_data['obsB'],
                               rollout_data['action'],
                               rollout_data['reward'], rollout_data['done'])

            if not os.path.isdir(SERIES_DIR_NAME + model_name):
                os.mkdir(SERIES_DIR_NAME + model_name)
            np.savez_compressed(SERIES_DIR_NAME + model_name + "/" + file,
                                muS=muS, log_varS=log_varS, muB=muB,
                                log_varB=log_varB, action=action,
                                reward=reward, done=done)
            initial_musS.append(initial_muS)
            initial_log_varsS.append(initial_log_varS)
            initial_musB.append(initial_muB)
            initial_log_varsB.append(initial_log_varB)
            file_count += 1
            if file_count % 50 == 0:
                print('Encoded {} / {} episodes'.format(S + file_count, N))
        except Exception as e:
            # Fix: was a bare `except:` with no diagnostics; log the reason
            # so skipped files can be investigated.
            print(e)
            print('Skipped {}...'.format(file))

    print('Encoded {} / {} episodes'.format(S + file_count, N))
    initial_musS = np.array(initial_musS)
    initial_log_varsS = np.array(initial_log_varsS)
    initial_musB = np.array(initial_musB)
    initial_log_varsB = np.array(initial_log_varsB)
    # NOTE(review): both lines print initial_musS.shape; "ONE MU SHAPE" likely
    # meant the shape of a single muS — confirm before changing the output.
    print('ONE MU SHAPE = {}'.format(initial_musS.shape))
    print('INITIAL MU SHAPE = {}'.format(initial_musS.shape))
    np.savez_compressed(ROOT_DIR_NAME + 'initial_zS.npz',
                        initial_muS=initial_musS,
                        initial_log_varS=initial_log_varsS)
    np.savez_compressed(ROOT_DIR_NAME + 'initial_zB.npz',
                        initial_muB=initial_musB,
                        initial_log_varB=initial_log_varsB)
img_cols=64, img_channels=3): # Visualize model by running the P-model forward from a grid of points in the latent space # ranging from -3 to +3. num_grid = 15 grid_pts = np.linspace(-3, 3, num_grid) # Create a keras function that takes latent (x, y) as input and produces images as output renderer = K.function([sample_var], [reconstruction_var]) latent_input = np.zeros((1, ndim)) # must have shape (1, ndim) not (ndim,) # Allocate space for the final (num_grid * rows, num_grid * cols) image final_image = np.zeros( (num_grid * img_rows, num_grid * img_cols, img_channels)) # Populate each sub-image one at a time for i in range(num_grid): for j in range(num_grid): latent_input[0, :2] = (grid_pts[i], grid_pts[j]) img = renderer([latent_input])[0] final_image[j * img_rows:(j + 1) * img_rows, i * img_cols:(i + 1) * img_cols, :] = \ np.reshape(img, (img_rows, img_cols, img_channels)) plt.imshow(final_image.squeeze(), cmap='gray', extent=(-3, 3, -3, 3)) plt.show() # Visualize results vae = VAE() vae.set_weights('./vae/weights.h5') render_grid(vae.latent.sample, vae.reconstruction)