def create_disc(sim, save_path):
    """Create a Discriminator sized for one transition and synchronise its
    initial weights across all processes.

    A transition is the concatenation of an observation, a scaled
    observation delta and an action (see the rollout builders below),
    hence the input size ``obs_dim * 2 + act_dim``.

    Args:
        sim: object exposing ``obs_dim`` and ``act_dim``.
        save_path: directory where the discriminator is saved. Must not be
            None.

    Returns:
        The Discriminator, with identical weights on every process.

    Raises:
        NameError: if ``save_path`` is None.
    """
    import os
    from rts.discriminator import Discriminator

    if save_path is None:
        raise NameError(
            "A save_path should always be given to a discriminator")

    disc = Discriminator(sim.obs_dim * 2 + sim.act_dim)
    disc.save_path = save_path

    # Use a discriminator-specific key, matching train_discrim's
    # ":disc_weight" convention, so this entry cannot collide with the
    # ":actor_weight" entry published by simple_actor in the same process.
    key = nodes.pnid + ":disc_weight"
    if nodes.mpi_role == 'main':
        os.makedirs(disc.save_path)
        disc.save(disc.save_path)
        # Publish the freshly initialised weights for the other processes.
        warehouse.send(
            {key: warehouse.Entry(action="set", value=disc.get_weights())})

    # Every process (main included) pulls the published weights so training
    # starts from an identical state everywhere.
    data = warehouse.send({key: warehouse.Entry(action="get", value=None)})
    disc.set_weights(data[key].value)
    return disc
def load_actor(actor, path):
    """Load actor weights from ``path`` on the main process and broadcast
    them so that every process ends up with identical weights.

    Args:
        actor: object exposing ``load``, ``get_weights`` and ``set_weights``.
        path: checkpoint path readable by ``actor.load``.
    """
    # Read role/id from `nodes`, consistent with the other functions in this
    # module (the bare module globals are not reliably set on workers).
    mpi_role = nodes.mpi_role
    pnid = nodes.pnid

    if mpi_role == 'main':
        actor.load(path)
        # Publish the loaded weights for the other processes.
        data_out = {pnid + ":actor_weight": warehouse.Entry(
            action="set", value=actor.get_weights())}
        data = warehouse.send(data_out)

    # Every process pulls the published weights.
    data_out = {pnid + ":actor_weight": warehouse.Entry(
        action="get", value=None)}
    data = warehouse.send(data_out)
    actor.set_weights(data[pnid + ":actor_weight"].value)
def node_wrapper(*args, **kwargs):
    """Assign this node call a fresh process number/id, register it with the
    warehouse, then delegate to ``func``.

    NOTE(review): ``func`` is not defined in this view — this def presumably
    lives inside a decorator closure that binds it; confirm against the
    full file.
    """
    global proc_num
    global pnid
    proc_num += 1
    # pnid is the "<proc_num>:" prefix used to namespace warehouse keys.
    pnid = str(proc_num)+":"
    # Track the highest process number seen so far. Workers elsewhere poll
    # "proc_num" and exit their loops once a newer (higher) node registers.
    warehouse.send({"proc_num": warehouse.Entry(action="set_max", value=proc_num)})
    return func(*args, **kwargs)
def generate_trans_batch(env, actor, rollout_nb, rollout_len, log_std, save_path):
    """Collect ``rollout_nb`` fixed-length rollouts of transitions from the
    workers and save them on the main process as ``all_trans.npy``.

    A transition is ``concat(obs, (next_obs - obs) * 10, action)``, where
    the action includes exploration noise scaled by ``exp(log_std)``.
    """
    mpi_role = nodes.mpi_role
    proc_num = nodes.proc_num
    pnid = nodes.pnid

    if mpi_role == 'main':
        os.makedirs(save_path)
        # Block until the workers have produced rollout_nb rollouts.
        data = warehouse.send(
            {pnid + "trans": warehouse.Entry(action="get_l", value=rollout_nb)})
        stacked = np.stack(data[pnid + "trans"].value)
        np.save(os.path.join(save_path, "all_trans.npy"), stacked)

    elif mpi_role == 'worker':
        data = warehouse.send(
            {"proc_num": warehouse.Entry(action="get", value=None)})
        # Keep producing rollouts until a newer node registers or the run ends.
        while proc_num >= data["proc_num"].value and not warehouse.is_work_done:
            rollout = []
            obs = np.asarray(env.reset()).reshape((1, 1, -1))
            for _ in range(rollout_len):
                flat_obs = obs.flatten()
                act = actor.model(obs).numpy()
                # Add exploration noise with per-dimension scale exp(log_std).
                noise = np.random.normal(size=act.flatten().shape[0]).reshape(act.shape)
                act = act + noise * np.exp(log_std)
                obs, rew, done = env.step(act)
                obs = np.asarray(obs).reshape((1, 1, -1))
                rollout.append(np.concatenate(
                    [flat_obs, (obs.flatten() - flat_obs) * 10, act.flatten()]))
            batch = np.asarray(rollout).reshape((rollout_len, -1))
            # Ship the rollout and refresh the current process number.
            data = warehouse.send({
                pnid + "trans": warehouse.Entry(action="add", value=batch),
                "proc_num": warehouse.Entry(action="get", value=None)})
def simple_actor(obs_dim, act_dim, obs_mean=None, obs_std=None, blindfold=None, inp_dim=None, save_path=None):
    """Create a SimpleActor and synchronise its initial weights across all
    processes.

    Args:
        obs_dim: observation dimensionality.
        act_dim: action dimensionality.
        obs_mean, obs_std: optional observation normalisation statistics.
        blindfold: optional observation-masking spec passed to SimpleActor.
        inp_dim: optional input dimensionality override.
        save_path: directory where the actor is saved. Must not be None.

    Returns:
        The SimpleActor, with identical weights on every process.

    Raises:
        NameError: if ``save_path`` is None.
    """
    import os
    from models.actor import SimpleActor

    if save_path is None:
        raise NameError("A save_path should always be given to an actor")

    actor = SimpleActor(obs_dim, act_dim, obs_mean=obs_mean, obs_std=obs_std,
                        blindfold=blindfold, inp_dim=inp_dim)
    actor.save_path = save_path

    # Read role/id from `nodes`, consistent with the other functions in this
    # module (the bare module globals are not reliably set on workers).
    mpi_role = nodes.mpi_role
    pnid = nodes.pnid
    if mpi_role == 'main':
        os.makedirs(actor.save_path)
        actor.save(actor.save_path)
        # Publish the freshly initialised weights for the other processes.
        data_out = {pnid + ":actor_weight": warehouse.Entry(
            action="set", value=actor.get_weights())}
        data = warehouse.send(data_out)

    # Every process pulls the published weights so training starts from an
    # identical state everywhere.
    data_out = {pnid + ":actor_weight": warehouse.Entry(
        action="get", value=None)}
    data = warehouse.send(data_out)
    actor.set_weights(data[pnid + ":actor_weight"].value)
    return actor
def train_discrim(disc, real_trans_path, all_env, actor, epoch_nb, train_step_per_epoch, rollout_per_epoch, rollout_len, log_std, model_save_interval, tensorboard_path):
    """Train the discriminator to tell real transitions from synthetic ones.

    Main process: each epoch, pull ``rollout_per_epoch`` synthetic rollouts
    from the warehouse, mix them with the pre-recorded real transitions
    (labels [1,0] real / [0,1] synthetic), jitter, and run
    ``train_step_per_epoch`` updates; finally publish the trained weights.

    Workers: repeatedly roll out the actor (with exploration noise
    ``exp(log_std)``) in a randomly chosen env from ``all_env``, pushing
    each rollout's transitions to the warehouse, then pull the final
    discriminator weights when done.

    NOTE(review): statement nesting is inferred from a collapsed source —
    confirm the final weight set/get sit after the epoch/worker loops.
    """
    mpi_role = nodes.mpi_role
    proc_num = nodes.proc_num
    pnid = nodes.pnid
    if mpi_role == 'main':
        os.makedirs(tensorboard_path)
        from rts.discriminator import Trainer
        trainer = Trainer(disc, tensorboard_path)
        trainer.model_save_interval = model_save_interval
        # One-hot labels: [1,0] = real transition, [0,1] = synthetic.
        real_lab = np.asarray([1, 0]).reshape((1,2))
        synth_lab = np.asarray([0, 1]).reshape((1,2))
        all_real_trans = np.load(os.path.join(real_trans_path, "all_trans.npy"))
        all_real_trans, all_real_labs = format_trans(all_real_trans, real_lab)
        start_time = time.time()
        for n in range(epoch_nb):
            # get the latest rollouts (and how many were dumped as stale)
            msg = {
                pnid+"trans" : warehouse.Entry(action="get_l", value=rollout_per_epoch),
                "dumped" : warehouse.Entry(action="get", value=None)
            }
            data = warehouse.send(msg)
            dumped_rollout_nb = data["dumped"].value
            all_synth_trans_raw = data[pnid+"trans"].value
            all_synth_trans, all_synth_labs = format_trans(np.concatenate(all_synth_trans_raw, axis=0), synth_lab)
            # put the training data together
            all_trans = np.concatenate([all_real_trans, all_synth_trans], axis=0)
            all_labs = np.concatenate([all_real_labs, all_synth_labs], axis=0)
            # random offset for regularisation
            all_trans += np.random.normal(size=all_trans.shape) * 0.003
            # update the network weights
            accuracy = trainer.train_network(n, all_trans, all_labs, train_step_per_epoch)
            # debug / progress logging
            n_rollouts = len(all_synth_trans_raw)
            print("Epoch {} :".format(n), flush=True)
            print("Loaded {} synthetic rollouts for training while dumping {} for a total of {} transitions.".format(n_rollouts, dumped_rollout_nb, all_synth_trans.shape[0]), flush=True)
            dt = time.time() - start_time
            start_time = time.time()
            if dt > 0:
                print("fps : {}".format(all_synth_trans.shape[0]/dt), flush=True)
            print("accuracy : {}/{}".format(accuracy, all_trans.shape[0]), flush=True)
        # Publish the trained discriminator weights for the workers.
        msg = {pnid+":disc_weight":warehouse.Entry(action="set", value=disc.get_weights())}
        data = warehouse.send(msg)
    elif mpi_role == "worker":
        msg = {"proc_num" : warehouse.Entry(action="get", value=None)}
        data = warehouse.send(msg)
        # Produce rollouts until a newer node registers or the run ends.
        while proc_num >= data["proc_num"].value and not warehouse.is_work_done:
            all_trans = []
            # Pick a random environment for this rollout.
            env = all_env[np.random.randint(len(all_env))]
            obs = env.reset ()
            obs = np.asarray(obs).reshape((1,1,-1))
            done = [False]
            i = 0
            # Roll out at most rollout_len steps, stopping early on done.
            while i < rollout_len and not done[0]:
                i += 1
                trans = [obs.flatten()]
                act = actor.model(obs).numpy()
                # Exploration noise with per-dimension scale exp(log_std).
                act = act + np.random.normal(size=act.flatten().shape[0]).reshape(act.shape) * np.exp(log_std)
                obs, rew, done = env.step(act)
                obs = np.asarray(obs).reshape((1,1,-1))
                # Transition = concat(obs, (next_obs - obs) * 10, action).
                trans.append((obs.flatten()-trans[0])*10)
                trans.append(act.flatten())
                all_trans.append(np.concatenate(trans))
            all_trans = np.asarray(all_trans).reshape((i,-1))
            # Ship the rollout and refresh the current process number.
            msg = {
                pnid+"trans" : warehouse.Entry(action="add", value=all_trans),
                "proc_num" : warehouse.Entry(action="get", value=None)}
            data = warehouse.send(msg)
        # Pull the final trained weights published by the main process.
        msg = {pnid+":disc_weight":warehouse.Entry(action="get", value=None)}
        data = warehouse.send(msg)
        disc.set_weights(data[pnid+":disc_weight"].value)
def train_ppo(actor, env, epoch_nb, rollout_per_epoch, rollout_len, train_step_per_epoch, init_log_std, model_save_interval, adr_test_prob, tensorboard_path):
    """Run distributed PPO training.

    Main process: each epoch, publish the current network weights, gather
    ``rollout_per_epoch`` rollouts (states, actions, rewards, neglogs,
    masks) from the workers, compute GAE and run
    ``train_step_per_epoch`` updates. Optionally drives ADR (automatic
    domain randomisation) if the env exposes ``adr``.

    Workers: repeatedly pull the latest weights, simulate a rollout
    (occasionally an ADR test rollout with probability ``adr_test_prob``),
    and push the results back to the warehouse.

    NOTE(review): statement nesting is inferred from a collapsed source —
    confirm the per-epoch ``env.adr.save()`` placement against the original.
    """
    mpi_role = nodes.mpi_role
    proc_num = nodes.proc_num
    pnid = nodes.pnid
    import os
    import time
    from ppo import PPO
    from models.critic import Critic
    # ADR is active only if the env supports it AND testing is enabled.
    USE_ADR = hasattr(env, 'adr') and adr_test_prob > 1e-7
    if mpi_role == 'main':
        os.makedirs(tensorboard_path)
        critic = Critic(env)
        trainer = PPO(env, actor, critic, tensorboard_path, init_log_std=init_log_std)
        trainer.model_save_interval = model_save_interval
        start_time = time.time()
        for n in range(epoch_nb):
            # send the network weights
            # and get the latest rollouts
            msg = {
                pnid+"weights" : warehouse.Entry(action="set", value=trainer.get_weights()),
                pnid+"adr" : warehouse.Entry(action="get", value=None),
                pnid+"s" : warehouse.Entry(action="get_l", value=rollout_per_epoch),
                pnid+"a" : warehouse.Entry(action="get_l", value=rollout_per_epoch),
                pnid+"r" : warehouse.Entry(action="get_l", value=rollout_per_epoch),
                pnid+"neglog" : warehouse.Entry(action="get_l", value=rollout_per_epoch),
                pnid+"mask" : warehouse.Entry(action="get_l", value=rollout_per_epoch),
                "dumped" : warehouse.Entry(action="get", value=None)
            }
            data = warehouse.send(msg)
            all_s = np.concatenate(data[pnid+"s"].value, axis=0)
            all_a = np.concatenate(data[pnid+"a"].value, axis=0)
            all_r = np.concatenate(data[pnid+"r"].value, axis=0)
            all_neglog = np.concatenate(data[pnid+"neglog"].value, axis=0)
            all_masks = np.concatenate(data[pnid+"mask"].value, axis=0)
            dumped_rollout_nb = data["dumped"].value
            if USE_ADR:
                # Fold the workers' ADR test results into the env's ADR state.
                env.adr.update(data[pnid+"adr"].value)
                env.adr.log()
            # update the network weights
            all_last_values, all_gae, all_new_value = trainer.calc_gae(all_s, all_r, all_masks)
            trainer.train_networks(n, all_s, all_a, all_r, all_neglog, all_masks, train_step_per_epoch, all_last_values, all_gae, all_new_value)
            # debug / progress logging
            n_rollouts = all_s.shape[0]
            cur_rollout_len = all_s.shape[1]
            print("Epoch {} :".format(n), flush=True)
            print("Loaded {} rollouts for training while dumping {}.".format(n_rollouts, dumped_rollout_nb), flush=True)
            dt = time.time() - start_time
            start_time = time.time()
            if dt > 0:
                print("fps : {}".format(n_rollouts*cur_rollout_len/dt), flush=True)
            # Masks weight the rewards: mean reward over valid (masked) steps.
            print("mean_rew : {}".format(np.sum(all_r * all_masks)/np.sum(all_masks)), flush=True)
            if USE_ADR:
                env.adr.save()
    elif mpi_role == 'worker':
        trainer = PPO(env, actor, Critic(env), init_log_std=init_log_std)
        # Initial handshake: fetch weights, register an empty ADR message,
        # and learn the current process number.
        msg = {
            pnid+"weights" : warehouse.Entry(action="get", value=None),
            pnid+"adr" : warehouse.Entry(action="set", value={}),
            "proc_num" : warehouse.Entry(action="get", value=None)}
        data = warehouse.send(msg)
        # Produce rollouts until a newer node registers or the run ends.
        while proc_num >= data["proc_num"].value and not warehouse.is_work_done:
            # Occasionally run an ADR boundary-test rollout instead of a
            # regular training rollout.
            test_adr = USE_ADR and np.random.random() < adr_test_prob
            env.test_adr = test_adr
            trainer.set_weights (data[pnid+"weights"].value)
            if test_adr:
                # simulate rollout (ADR test; its samples are not sent back)
                all_s, all_a, all_r, all_neglog, all_mask = trainer.get_rollout(env.adr_rollout_len)
                msg = {
                    pnid+"adr" : warehouse.Entry(action="update", value=env.adr.get_msg()),
                    pnid+"weights" : warehouse.Entry(action="get", value=None),
                    "proc_num" : warehouse.Entry(action="get", value=None)}
            else:
                # simulate rollout
                all_s, all_a, all_r, all_neglog, all_mask = trainer.get_rollout(rollout_len)
                # send rollout back to warehouse
                # and get network weights to update actor
                msg = {
                    pnid+"s" : warehouse.Entry(action="add", value=all_s),
                    pnid+"a" : warehouse.Entry(action="add", value=all_a),
                    pnid+"r" : warehouse.Entry(action="add", value=all_r),
                    pnid+"neglog" : warehouse.Entry(action="add", value=all_neglog),
                    pnid+"mask" : warehouse.Entry(action="add", value=all_mask),
                    pnid+"weights" : warehouse.Entry(action="get", value=None),
                    pnid+"adr" : warehouse.Entry(action="get", value=None),
                    "proc_num" : warehouse.Entry(action="get", value=None)}
            data = warehouse.send(msg)
            if USE_ADR:
                env.adr.update(data[pnid+"adr"].value)