def main():
    """Train a TD(lambda) agent on a 5-element sorting task and print one solve.

    Builds a ``PermutationSorting`` environment of size 5, wraps it in an
    ``RBFStateTransformer`` fed by a ``MaxStateTransformer``, trains a
    ``TDLambdaAgent`` and prints whatever ``agent.solve`` returns for the
    permutation ``[0, 4, 3, 1, 2]``.
    """
    size = 5
    environment = PermutationSorting(size)
    base_transformer = MaxStateTransformer(size)
    transformer = RBFStateTransformer(environment, base_transformer)

    agent = TDLambdaAgent(environment, transformer)
    # Positional arguments are passed through unchanged; presumably
    # (episodes, epsilon schedule, verbose flag) -- confirm against
    # TDLambdaAgent.train.
    agent.train(500, Eps1().eps, True)

    sample_permutation = [0, 4, 3, 1, 2]
    print(agent.solve(sample_permutation))
def main(argv):
    """Train a DDPG agent on permutation sorting inside a TF session.

    Args:
        argv: Command-line style argument list; ``argv[1]`` is parsed as the
            integer permutation size.
    """
    np.random.seed(12345678)

    n = int(argv[1])
    n_neighbors = 0.1

    # Both weight files share the same "<n><n_neighbors>.h5" tail (the two
    # values are concatenated without a separator).
    weights_tail = str(n) + str(n_neighbors) + '.h5'
    pretrain = './saved_models/ddpg_tf_pretrain_weights_' + weights_tail
    train = './saved_models/ddpg_tf_final_weights_' + weights_tail

    env = PermutationSorting(n, transpositions=True)
    state_transformer = OneHotStateTransformer(n)
    actor_layer_sizes = [400, 300]
    critic_layer_sizes = [400, 300]

    agent = DDPGAgent(
        env,
        state_transformer,
        actor_layer_sizes,
        critic_layer_sizes,
        batch_size=32,
        train_start=60000,
        maxlen=1e6,
        neighbors_percent=n_neighbors,
        render=False,
        fill_mem=True,
        pretrain_path=pretrain,
        train_path=train,
    )

    # Let TensorFlow grow GPU memory on demand instead of reserving it all.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        agent.set_session(sess)
        sess.run(tf.global_variables_initializer())
        agent.train(episodes=1000)
def main():
    """Fill the replay memory of a DDQN agent (n = 15) and train it.

    The environment is created with ``transpositions=True`` and states are
    encoded with ``OneHotStateTransformer``.
    """
    np.random.seed(12345678)

    n = 15
    agent = DDQNAgent(
        PermutationSorting(n, transpositions=True),
        OneHotStateTransformer(n),
        batch_size=126,
    )

    # Alternative start-up steps that are currently disabled:
    #   agent.serial_pretrain()
    #   agent.load_pretrain_weights()
    #   agent.load_final_weights()
    agent.fill_memory()
    agent.train()
def main(argv):
    """Fill the replay memory of a DDPG agent and train it.

    Args:
        argv: Command-line style argument list; ``argv[1]`` is parsed as the
            integer permutation size.
    """
    np.random.seed(12345678)

    n = int(argv[1])
    agent = DDPGAgent(
        PermutationSorting(n, transpositions=True),
        OneHotStateTransformer(n),
        batch_size=64,
        neighbors_percent=0.1,
        render=False,
    )

    agent.fill_memory()
    agent.train()
def ddqn_10():
    """Compare a trained DDQN agent with exact answers on sampled permutations.

    Loads DDQN weights for ``n = 10``, then 10000 times:

    1. picks a line from ``./data/10_exact/all10urt`` via ``random_line``;
    2. parses the comma-separated permutation at the start of the line
       (shifted down by one) and the exact answer, taken as the first
       whitespace-preceded integer in the line;
    3. asks the agent to solve the permutation and records the ratio
       ``ddqn_ans / exact_ans``.

    Every iteration is printed (prefixed with the elapsed time) and written to
    ``./data/10_res/ddqn_exact``; a summary line is appended at the end.

    Raises:
        AttributeError: If a sampled line has no whitespace-preceded integer
            (``re.search`` returns ``None``).
        ZeroDivisionError: If the exact answer parsed from a line is 0.
    """
    np.random.seed(12345678)
    n = 10
    env = PermutationSorting(n, transpositions=True)
    state_transformer = OneHotStateTransformer(n)
    ddqn = DDQNAgent(
        env,
        state_transformer,
        0,
        hidden_layer_sizes=(400, 300),
        batch_size=256,
        train_start=60000,
        epsilon=0.09,
        render=False,
        train_path='./saved_models/ddqn_tf_final_weights_' + str(n) + '.h5',
    )

    path = "./data/10_exact/all10urt"
    path_res = "./data/10_res/ddqn_exact"

    with tf.Session() as sess:
        ddqn.set_session(sess)
        ddqn.load_weights()

        n_better = 0
        n_worse = 0
        n_equal = 0
        t_dif = 0
        start = time.time()

        with open(path_res, 'w') as fp:
            for i in range(10000):
                # The context manager guarantees the dataset handle is closed
                # even if sampling raises; the file is reopened on every
                # iteration exactly as before.
                with open(path, 'r') as f:
                    line = random_line(f)

                p = np.fromstring(line, dtype=int, sep=',')
                p -= 1  # stored values are shifted down by one before solving
                res = re.search(r'\s+\d+', line)
                exact_ans = int(res.group())

                ddqn_ans = ddqn.solve(p)
                # NOTE(review): unguarded division -- fails when exact_ans is 0.
                dif = float(ddqn_ans) / float(exact_ans)
                # Running value is the midpoint of the previous value and the
                # new ratio (recent samples weigh more); it is not an
                # arithmetic mean over all iterations.
                t_dif = dif if t_dif == 0 else (dif + t_dif) / 2

                if ddqn_ans < exact_ans:
                    n_better += 1
                elif ddqn_ans > exact_ans:
                    n_worse += 1
                else:
                    n_equal += 1

                string = str(i) + ' - ' + str(p) + ' - ' + ' DDQN: ' + str(ddqn_ans) + ' Exact: ' +\
                    str(exact_ans) + ' Difference: ' + str(dif)
                print(time.time() - start, string)
                fp.write(string + '\n')

            string = "Total difference: " + str(t_dif) + " Times better: " + str(n_better) + " Times worse: " +\
                str(n_worse) + " Times equal: " + str(n_equal)
            print(string)
            fp.write(string + '\n')
def main(argv):
    """Serially pretrain and then train a DDQN agent inside a TF session.

    Args:
        argv: Command-line style argument list; ``argv[1]`` is parsed as the
            integer permutation size, which also names the weight files.
    """
    np.random.seed(12345678)

    n = int(argv[1])
    weights_tail = str(n) + '.h5'
    pretrain = './saved_models/ddqn_tf_pretrain_weights_' + weights_tail
    train = './saved_models/ddqn_tf_final_weights_' + weights_tail

    env = PermutationSorting(n, transpositions=True)
    state_transformer = OneHotStateTransformer(n)
    hidden_layer_sizes = (400, 300)

    agent = DDQNAgent(
        env,
        state_transformer,
        0,
        hidden_layer_sizes=hidden_layer_sizes,
        batch_size=32,
        train_start=20000,
        epsilon=0.7,
        fill_mem=True,
        render=False,
        pretrain_path=pretrain,
        train_path=train,
    )

    with tf.Session() as sess:
        agent.set_session(sess)
        sess.run(tf.global_variables_initializer())
        # Disabled alternative to pretraining: agent.load_pretrain_weights()
        agent.serial_pretrain()
        agent.train(episodes=10000)
def main():
    """Pretrain and train a DDQN agent (n = 10), then compare it to the exact solver.

    After training, ten random permutations are drawn with
    ``np.random.permutation``; for each one the agent's answer and the answer
    of ``PermutationExactSolver`` are printed side by side.
    """
    np.random.seed(12345678)

    n = 10
    agent = DDQNAgent(PermutationSorting(n), OneHotStateTransformer(n))

    # Disabled alternatives: agent.load_pretrain_weights() instead of
    # pretraining, agent.load_final_weights() instead of training.
    agent.parallel_pretrain(rows=10000, epochs=30)
    agent.train(episodes=10000, max_steps=250)

    trials = 10
    for _ in range(trials):
        permutation = np.random.permutation(n)
        rl_ans = agent.solve(permutation)
        # A fresh exact solver is built for every sample, as before.
        exact_ans = PermutationExactSolver(n).solve(permutation)
        print(permutation, '-', 'RL:', rl_ans, ' Exact:', exact_ans)
def ddpg_15():
    """Run a trained DDPG agent (n = 15) over the ``db15`` permutation files.

    For every value ``k`` in 1..15 the file ``./data/db15/perms-10k-r<k>-t<k>``
    is read, at most 1000 lines are parsed as comma-separated permutations
    (shifted down by one), and the agent's answer for each is written on its
    own line to ``./data/15_wp/wp-1k-r<k>-t<k>.dist``. Progress (elapsed
    seconds, line index, answer) is printed as it goes.
    """
    n = 15
    neighbors = 0.8
    length = 1000  # maximum number of lines consumed per input file

    env = PermutationSorting(n, transpositions=True)
    state_transformer = OneHotStateTransformer(n)
    rt = np.arange(1, 16)
    actor_layer_sizes = [400, 300]
    critic_layer_sizes = [400, 300]

    weights_path = './final_weights/ddpg_tf_final_weights_' + str(n) + '_' + str(neighbors) + '.h5'
    ddpg = DDPGAgent(
        env,
        state_transformer,
        actor_layer_sizes,
        critic_layer_sizes,
        batch_size=32,
        actor_learning_rate=0.0001,
        critic_learning_rate=0.001,
        train_start=60000,
        maxlen=1e6,
        neighbors_percent=neighbors,
        render=False,
        train_path=weights_path,
    )

    with tf.Session() as sess:
        ddpg.set_session(sess)
        ddpg.load_weights()

        for revt in rt:
            tag = "r" + str(revt) + "-t" + str(revt)
            path_db = "./data/db15/perms-10k-" + tag
            path_res = "./data/15_wp/wp-1k-" + tag + ".dist"

            start = time.time()
            with open(path_db, 'r') as fp, open(path_res, 'w') as res:
                for i, line in enumerate(fp):
                    if i >= length:
                        break
                    # Values in the file are shifted down by one before solving.
                    p = np.fromstring(line, dtype=int, sep=',') - 1
                    ddpg_ans = ddpg.solve(p)
                    print(time.time() - start, i, ddpg_ans)
                    res.write(str(ddpg_ans) + '\n')
def ddpg_10():
    """Run a trained DDPG agent (n = 10) over every line of the ``10_db`` file.

    Each line of ``./data/10_res/10_db`` is parsed as a space-separated
    permutation (used as-is, without shifting) and the agent's answer is
    written on its own line to ``./data/10_res/wp.dist``. Progress (elapsed
    seconds, line index, answer) is printed as it goes.
    """
    n = 10
    neighbors = 0.8

    env = PermutationSorting(n, transpositions=True)
    state_transformer = OneHotStateTransformer(n)
    actor_layer_sizes = [400, 300]
    critic_layer_sizes = [400, 300]

    weights_path = './final_weights/ddpg_tf_final_weights_' + str(n) + '_' + str(neighbors) + '.h5'
    ddpg = DDPGAgent(
        env,
        state_transformer,
        actor_layer_sizes,
        critic_layer_sizes,
        batch_size=32,
        actor_learning_rate=0.0001,
        critic_learning_rate=0.001,
        train_start=60000,
        maxlen=1e6,
        neighbors_percent=neighbors,
        render=False,
        train_path=weights_path,
    )

    with tf.Session() as sess:
        ddpg.set_session(sess)
        ddpg.load_weights()

        path_db = "./data/10_res/10_db"
        path_res = "./data/10_res/wp.dist"
        start = time.time()

        with open(path_db, 'r') as fp, open(path_res, 'w') as res:
            for i, line in enumerate(fp):
                p = np.fromstring(line, dtype=int, sep=' ')
                ddpg_ans = ddpg.solve(p)
                print(time.time() - start, i, ddpg_ans)
                res.write(str(ddpg_ans) + '\n')