def main(args):
    """Train the RNN once on the concatenation of a range of saved batches.

    For every batch number ``n`` in ``[args.start_batch, args.max_batch]``
    (inclusive) this loads ``./data/rnn_input_<n>.npy`` and
    ``./data/rnn_output_<n>.npy``, joins all of them along the first axis and
    passes the result to ``rnn.train``.

    Args:
        args: Namespace providing ``start_batch`` and ``max_batch`` (batch
            numbers, inclusive) and ``new_model`` (truthy to start from a
            fresh model instead of loading ``./rnn/weights.h5``).

    Raises:
        ValueError: If ``start_batch`` is greater than ``max_batch``, i.e. the
            range contains no batch to train on.
    """
    start_batch = args.start_batch
    max_batch = args.max_batch
    new_model = args.new_model

    # Fail fast on an empty range: otherwise the model would be built (and
    # its weights loaded) only to crash later on unbound training arrays.
    if start_batch > max_batch:
        raise ValueError(
            'start_batch ({}) must not be greater than max_batch ({})'.format(
                start_batch, max_batch))

    rnn = RNN()

    if not new_model:
        try:
            rnn.set_weights('./rnn/weights.h5')
        except Exception:
            print("Either set --new_model or ensure ./rnn/weights.h5 exists")
            raise

    # Collect the per-batch arrays and join them once at the end; growing the
    # array with np.concatenate inside the loop re-copies all earlier batches
    # on every iteration.
    input_parts = []
    output_parts = []
    for batch_num in range(start_batch, max_batch + 1):
        print('Building batch {}...'.format(batch_num))
        input_parts.append(
            np.load('./data/rnn_input_' + str(batch_num) + '.npy'))
        output_parts.append(
            np.load('./data/rnn_output_' + str(batch_num) + '.npy'))

    # A single batch is used as loaded, without an extra copy.
    if len(input_parts) == 1:
        rnn_input = input_parts[0]
        rnn_output = output_parts[0]
    else:
        rnn_input = np.concatenate(input_parts)
        rnn_output = np.concatenate(output_parts)

    rnn.train(rnn_input, rnn_output)
def make_model():
    """Build a ``Model`` from a VAE, an RNN and a fresh ``Controller``.

    The VAE weights are read from ``./vae/weights.h5`` and the RNN weights
    from ``./rnn/weights.h5``; the controller is created with its defaults.

    Returns:
        The ``Model`` constructed as ``Model(controller, vae, rnn)``.
    """
    vision = VAE()
    vision.set_weights('./vae/weights.h5')

    memory = RNN()
    memory.set_weights('./rnn/weights.h5')

    # The controller is created last, right before the model is assembled.
    return Model(Controller(), vision, memory)
def make_model():
    """Build a ``Model`` from an ACAI, an RNN and a fresh ``Controller``.

    Only the RNN has weights loaded here (from ``./rnn/weights.h5``); the
    ``ACAI`` instance is used exactly as its constructor returns it.

    Returns:
        The ``Model`` constructed as ``Model(controller, acai, rnn)``.
    """
    encoder = ACAI()

    memory = RNN()
    memory.set_weights('./rnn/weights.h5')

    # The controller is created last, right before the model is assembled.
    return Model(Controller(), encoder, memory)
def main(args):
    """Train the RNN on randomly sampled batches for a fixed number of steps.

    Each step draws a batch with ``random_batch``, builds next-step
    prediction pairs from it and calls ``rnn.train``. Weights are written to
    ``./rnn/weights.h5`` every 10th step and once more after the loop. The
    arrays of the very first step are also dumped to
    ``ROOT_DIR_NAME + 'rnn_files.npz'``.

    Args:
        args: Namespace with ``new_model`` (truthy to skip loading existing
            weights), ``N`` (forwarded to ``get_filelist``), ``steps`` and
            ``batch_size``; the last three are converted with ``int``.
    """
    resume = not args.new_model
    n_files = int(args.N)
    total_steps = int(args.steps)
    batch_size = int(args.batch_size)
    weights_path = './rnn/weights.h5'

    # TODO: Take the learning rate as an argument
    rnn = RNN(learning_rate=0.0001)

    if resume:
        try:
            rnn.set_weights(weights_path)
        except:
            print("Either set --new_model or ensure ./rnn/weights.h5 exists")
            raise

    filelist, _ = get_filelist(n_files)

    # Axis 1 is sliced as "all but last" for the inputs and "all but first"
    # for the targets (presumably the time axis -- TODO confirm).
    head = slice(None, -1)
    tail = slice(1, None)

    for step in range(total_steps):
        print('STEP {}'.format(step))

        z, action, rew, _ = random_batch(filelist, batch_size)

        input_parts = [z[:, head, :], action[:, head, :], rew[:, head, :]]
        target_parts = [z[:, tail, :], rew[:, tail, :]]
        rnn_input = np.concatenate(input_parts, axis=2)
        rnn_output = np.concatenate(target_parts, axis=2)

        # Keep a copy of the first batch on disk.
        if not step:
            np.savez_compressed(ROOT_DIR_NAME + 'rnn_files.npz',
                                rnn_input=rnn_input,
                                rnn_output=rnn_output)

        rnn.train(rnn_input, rnn_output)

        # Periodic checkpoint.
        if not step % 10:
            rnn.model.save_weights(weights_path)

    rnn.model.save_weights(weights_path)
def main(args):
    """Train the RNN on random batches that carry two latent streams.

    ``random_batch`` yields ``zS`` and ``zB`` alongside action, reward and
    done. The inputs are built from ``zS`` and the targets from ``zB``, each
    reshaped to ``(batch, -1, Z_DIM)`` using the batch size actually returned.
    Weights go to ``./rnn/weights.h5`` every 10th step and after the loop;
    the first step's arrays are saved to ``ROOT_DIR_NAME + 'rnn_files.npz'``.

    Args:
        args: Namespace with ``new_model``, ``N``, ``steps`` and
            ``batch_size`` (the last three converted with ``int``).
    """
    resume = not args.new_model
    n_files = int(args.N)
    total_steps = int(args.steps)
    requested_batch = int(args.batch_size)
    weights_path = './rnn/weights.h5'

    rnn = RNN()

    if resume:
        try:
            rnn.set_weights(weights_path)
        except:
            print("Either set --new_model or ensure ./rnn/weights.h5 exists")
            raise

    filelist, _ = get_filelist(n_files)

    for step in range(total_steps):
        print('STEP {}/{}'.format(step, total_steps))

        zS, zB, action, rew, _ = random_batch(filelist, requested_batch)
        print(zS.shape)

        # Reshape with the size of the returned batch, not the requested one.
        actual_batch = zS.shape[0]

        source_latents = zS[:, :-1, :].reshape(actual_batch, -1, Z_DIM)
        rnn_input = np.concatenate(
            [source_latents, action[:, :-1, :], rew[:, :-1, :]], axis=2)
        print(rnn_input.shape)

        target_latents = zB[:, 1:, :].reshape(actual_batch, -1, Z_DIM)
        rnn_output = np.concatenate([target_latents, rew[:, 1:, :]], axis=2)
        print(rnn_output.shape)

        # Keep a copy of the first batch on disk.
        if not step:
            np.savez_compressed(ROOT_DIR_NAME + 'rnn_files.npz',
                                rnn_input=rnn_input,
                                rnn_output=rnn_output)

        rnn.train(rnn_input, rnn_output)

        # Periodic checkpoint.
        if not step % 10:
            rnn.model.save_weights(weights_path)

    rnn.model.save_weights(weights_path)
def main(args):
    """Train an RNN of the requested type and persist model and history.

    After the step loop the weights are written to ``./rnn/weights.h5``. The
    full model is then saved to ``./rnn/model_lstm.h5`` when
    ``args.model_type`` is ``'LSTM'`` or ``'lstm'`` and to
    ``./rnn/model_gru.h5`` otherwise. ``rnn.model.history.history`` is
    pickled to ``./rnn/history.pickle``.

    Args:
        args: Namespace with ``new_model``, ``N``, ``steps``, ``model_type``
            (passed to ``RNN``) and ``batch_size``.
    """
    resume = not args.new_model
    n_files = int(args.N)
    total_steps = int(args.steps)
    mtype = args.model_type
    batch_size = int(args.batch_size)
    weights_path = './rnn/weights.h5'

    rnn = RNN(mtype)

    if resume:
        try:
            rnn.set_weights(weights_path)
        except:
            print("Either set --new_model or ensure ./rnn/weights.h5 exists")
            raise

    filelist, _ = get_filelist(n_files)

    for step in range(total_steps):
        print('STEP {}'.format(step))

        z, action, rew, _ = random_batch(filelist, batch_size)

        # Inputs drop the last index of axis 1, targets drop the first.
        rnn_input = np.concatenate(
            [arr[:, :-1, :] for arr in (z, action, rew)], axis=2)
        rnn_output = np.concatenate(
            [arr[:, 1:, :] for arr in (z, rew)], axis=2)

        # Keep a copy of the first batch on disk.
        if not step:
            np.savez_compressed(ROOT_DIR_NAME + 'rnn_files.npz',
                                rnn_input=rnn_input,
                                rnn_output=rnn_output)

        rnn.train(rnn_input, rnn_output)

        # Periodic checkpoint.
        if not step % 10:
            rnn.model.save_weights(weights_path)

    rnn.model.save_weights(weights_path)

    # Only the two exact spellings select the LSTM file name.
    if args.model_type in ('LSTM', 'lstm'):
        model_path = './rnn/model_lstm.h5'
    else:
        model_path = './rnn/model_gru.h5'
    rnn.model.save(model_path)

    import pickle
    with open('./rnn/history.pickle', 'wb') as history_file:
        pickle.dump(rnn.model.history.history, history_file)
def main(args):
    """Optionally load weights, then train the RNN on ``args.num_files``.

    Args:
        args: Namespace with ``num_files`` (forwarded to ``rnn.train``) and
            ``load_model``. The literal string ``"None"`` means no weights
            are loaded; any other value is passed to ``rnn.set_weights``.
    """
    num_files = args.num_files
    weights_path = args.load_model

    rnn = RNN()

    # The sentinel is the string "None", not the None object.
    if weights_path != "None":
        try:
            print("Loading model " + weights_path)
            rnn.set_weights(weights_path)
        except:
            print("Either don't set --load_model or ensure " + weights_path + " exists")
            raise

    rnn.train(num_files)
def main(args):
    """Train a named RNN, one ``get_batch`` batch per step.

    All artefacts for the run live under ``./rnn/<model_name>/``. When
    ``args.new_model`` is falsy, existing weights are loaded from
    ``./rnn/<model_name>/<model_name>.h5``. Otherwise the model directory and
    its ``log/`` subdirectory are created if missing. Weights are written to
    ``./rnn/<model_name>/<model_name>_weights.h5`` every 10th step and once
    more after the loop.

    Args:
        args: Namespace with ``new_model``, ``S`` and ``N`` (both converted
            with ``int`` and forwarded to ``get_filelist``) and
            ``model_name``.
    """
    new_model = args.new_model
    S = int(args.S)
    N = int(args.N)
    model_name = str(args.model_name)

    model_dir = './rnn/' + model_name
    load_path = model_dir + '/' + model_name + '.h5'
    save_path = model_dir + '/' + model_name + '_weights.h5'
    # NOTE(review): the load and save file names differ ('<name>.h5' vs
    # '<name>_weights.h5'); presumably '<name>.h5' is written elsewhere
    # (e.g. by rnn.train) -- confirm.

    rnn = RNN()

    if not new_model:
        try:
            rnn.set_weights(load_path)
        except Exception:
            # Name the file that was actually tried, not a generic path.
            print("Either set --new_model or ensure " + load_path + " exists")
            raise
    else:
        # Creates missing parents and is a no-op when the directories exist,
        # so a model directory left without its log/ folder gets repaired.
        os.makedirs(model_dir + '/log/', exist_ok=True)

    # get_filelist returns the number of steps to run as its second value.
    filelist, N = get_filelist(S, N)

    for step in range(N):
        print('STEP ' + str(step))

        zS, zB, action, rew = get_batch(filelist, step)

        # Inputs drop the last index of axis 1, targets drop the first; the
        # latents are reshaped to a single batch of 32-wide vectors.
        rnn_input = np.concatenate([
            zS[:, :-1, :].reshape(1, -1, 32),
            action[:, :-1, :],
            rew[:, :-1, :],
        ], axis=2)
        rnn_output = np.concatenate(
            [zB[:, 1:, :].reshape(1, -1, 32), rew[:, 1:, :]], axis=2)

        # Keep a copy of the first batch on disk.
        if step == 0:
            np.savez_compressed(ROOT_DIR_NAME + 'rnn_files.npz',
                                rnn_input=rnn_input,
                                rnn_output=rnn_output)

        rnn.train(rnn_input, rnn_output, model_name)

        # Periodic checkpoint.
        if step % 10 == 0:
            rnn.model.save_weights(save_path)

    rnn.model.save_weights(save_path)
def main(args):
    """Run the RNN training loop over randomly sampled batches.

    Every step samples a batch, concatenates z/action/reward (all but the
    last index of axis 1) into the input and z/reward (all but the first
    index of axis 1) into the target, and trains on the pair. Weights are
    checkpointed to ``./rnn/weights.h5`` every 10th step and after the loop;
    the first step's arrays are saved to ``ROOT_DIR_NAME + 'rnn_files.npz'``.

    Args:
        args: Namespace with ``new_model``, ``N``, ``steps`` and
            ``batch_size`` (the last three converted with ``int``).
    """
    load_existing = not args.new_model
    file_count = int(args.N)
    step_count = int(args.steps)
    batch_size = int(args.batch_size)
    checkpoint = './rnn/weights.h5'

    rnn = RNN()

    if load_existing:
        try:
            rnn.set_weights(checkpoint)
        except:
            print("Either set --new_model or ensure ./rnn/weights.h5 exists")
            raise

    filelist, _ = get_filelist(file_count)

    for step in range(step_count):
        print('STEP {}'.format(step))

        z, action, rew, _ = random_batch(filelist, batch_size)

        current = (z[:, :-1, :], action[:, :-1, :], rew[:, :-1, :])
        following = (z[:, 1:, :], rew[:, 1:, :])
        rnn_input = np.concatenate(list(current), axis=2)
        rnn_output = np.concatenate(list(following), axis=2)

        is_first_step = step == 0
        if is_first_step:
            np.savez_compressed(ROOT_DIR_NAME + 'rnn_files.npz',
                                rnn_input=rnn_input,
                                rnn_output=rnn_output)

        rnn.train(rnn_input, rnn_output)

        # Periodic checkpoint.
        if not step % 10:
            rnn.model.save_weights(checkpoint)

    rnn.model.save_weights(checkpoint)
def main(args):
    """Train a named RNN on random batches and save its weights at the end.

    Existing weights are loaded from ``./rnn/<model_name>_weights.h5`` unless
    ``args.new_model`` is truthy. After the step loop the weights are written
    to ``./rnn/<model_name>/<model_name>_weights.h5``. The first step's
    arrays are saved to ``ROOT_DIR_NAME + 'rnn_files.npz'``.

    Args:
        args: Namespace with ``new_model``, ``S``, ``N``, ``steps``,
            ``batch_size`` and ``model_name``. ``S`` and ``N`` are forwarded
            to ``get_filelist`` together with the model name.
    """
    resume = not args.new_model
    sim_count = int(args.S)
    file_count = int(args.N)
    step_count = int(args.steps)
    batch_size = int(args.batch_size)
    model_name = str(args.model_name)

    load_path = './rnn/' + model_name + '_weights.h5'
    # NOTE(review): weights are saved inside a per-model subdirectory but
    # loaded from directly under ./rnn/ -- confirm this is intended.
    save_path = './rnn/' + model_name + '/' + model_name + '_weights.h5'

    rnn = RNN()

    if resume:
        try:
            rnn.set_weights(load_path)
        except:
            print("Either set --new_model or ensure " + load_path + " exists")
            raise

    filelist, _ = get_filelist(sim_count, file_count, model_name)

    for step in range(step_count):
        print('STEP {}'.format(step))

        zS, zB, action, rew, _ = random_batch(filelist, batch_size, model_name)

        # Inputs come from zS (all but the last index of axis 1), targets
        # from zB (all but the first).
        input_parts = [zS[:, :-1, :], action[:, :-1, :], rew[:, :-1, :]]
        target_parts = [zB[:, 1:, :], rew[:, 1:, :]]
        rnn_input = np.concatenate(input_parts, axis=2)
        rnn_output = np.concatenate(target_parts, axis=2)

        # Keep a copy of the first batch on disk.
        if not step:
            np.savez_compressed(ROOT_DIR_NAME + 'rnn_files.npz',
                                rnn_input=rnn_input,
                                rnn_output=rnn_output)

        rnn.train(rnn_input, rnn_output, model_name)

    rnn.model.save_weights(save_path)
# Ignoring time for now, will need to use it to not try to learn across time gaps rnn_input = data[1:, 1:] rnn_output = data[:-1, 3:] print( "RNN Input : {}".format( rnn_input.shape ) ) print( "RNN Output: {}".format( rnn_output.shape ) ) return rnn_input, rnn_output def main(args): new_model = args.new_model N = int(args.N) steps = int(args.steps) batch_size = int(args.batch_size) rnn = RNN( z_dim=128, action_dim=2) if not new_model: try: rnn.set_weights('./rnn/weights.h5') except: print("Either set --new_model or ensure ./rnn/weights.h5 exists") raise print( "Loading data..." ) rnn_input, rnn_output = load_data() for step in range(steps): print('STEP ' + str(step)) z, action, rew ,done = random_batch(filelist, batch_size)