Example #1
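The main process builds the discriminator, saves it, and publishes its initial weights through the shared warehouse; every process then pulls those weights back so all replicas start out identical.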
def create_disc(sim, save_path):
    import os
    from rts.discriminator import Discriminator

    if save_path is None:
        raise ValueError(
            "A save_path should always be given to a discriminator")

    # one transition is (obs, scaled obs delta, action), hence obs_dim * 2 + act_dim inputs
    disc = Discriminator(sim.obs_dim * 2 + sim.act_dim)
    disc.save_path = save_path

    if nodes.mpi_role == 'main':
        os.makedirs(disc.save_path)
        disc.save(disc.save_path)

        # the main process publishes the initial weights to the warehouse
        data_out = {
            nodes.pnid + ":disc_weight":
            warehouse.Entry(action="set", value=disc.get_weights())
        }
        data = warehouse.send(data_out)

    # every process (main and workers) pulls the same weights back,
    # so all copies of the discriminator start out identical
    data_out = {
        nodes.pnid + ":disc_weight": warehouse.Entry(action="get", value=None)
    }
    data = warehouse.send(data_out)
    disc.set_weights(data[nodes.pnid + ":disc_weight"].value)

    return disc
Example #2
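Loads saved actor weights on the main process and broadcasts them to all workers through the warehouse (mpi_role, pnid and warehouse come from the surrounding module scope, as in Examples #4 to #7).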
def load_actor(actor, path):
	if mpi_role == 'main':
		# only the main process reads the saved weights from disk...
		actor.load(path)

		data_out = {pnid+":actor_weight": warehouse.Entry(action="set", value=actor.get_weights())}
		data = warehouse.send(data_out)
	
	# ...and every process then pulls them back so all actors stay in sync
	data_out = {pnid+":actor_weight": warehouse.Entry(action="get", value=None)}
	data = warehouse.send(data_out)
	actor.set_weights(data[pnid+":actor_weight"].value)
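The same set-then-get idiom appears throughout these examples: the main process writes a value under a key, and every process (main included) reads it back, which doubles as a synchronisation point. Below is a minimal single-process sketch of that idiom, using a hypothetical in-memory stand-in for the real MPI-backed warehouse module; only the Entry fields the examples actually use are modelled.

from collections import namedtuple

Entry = namedtuple("Entry", ["action", "value"])

class FakeWarehouse:
	def __init__(self):
		self.store = {}

	def send(self, msg):
		# apply every "set", then answer each key with its current value
		out = {}
		for key, entry in msg.items():
			if entry.action == "set":
				self.store[key] = entry.value
			out[key] = Entry("get", self.store.get(key))
		return out

warehouse = FakeWarehouse()

# main publishes the weights once...
warehouse.send({"1:actor_weight": Entry(action="set", value=[0.1, 0.2])})
# ...then every process reads the same value back
data = warehouse.send({"1:actor_weight": Entry(action="get", value=None)})
assert data["1:actor_weight"].value == [0.1, 0.2]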
Example #3
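The inner wrapper of a node decorator: each wrapped call receives a fresh process id, and the warehouse keeps track of the highest id issued so far.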
	def node_wrapper(*args, **kwargs):
		global proc_num
		global pnid
		# give this node a unique id and make it the current one
		proc_num += 1
		pnid = str(proc_num)+":"
		# let the warehouse keep the highest node id seen so far
		warehouse.send({"proc_num": warehouse.Entry(action="set_max", value=proc_num)})
		return func(*args, **kwargs)
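Since node_wrapper closes over func, it is presumably the body of a decorator. A minimal enclosing decorator consistent with this snippet could look like the sketch below; this is a hypothetical reconstruction, not the project's actual definition (warehouse is the shared-store module used throughout these examples).

proc_num = 0
pnid = ""

def node(func):
	# hypothetical decorator enclosing the wrapper shown above
	def node_wrapper(*args, **kwargs):
		global proc_num
		global pnid
		proc_num += 1
		pnid = str(proc_num)+":"
		warehouse.send({"proc_num": warehouse.Entry(action="set_max", value=proc_num)})
		return func(*args, **kwargs)
	return node_wrapper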
Example #4
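Workers roll out the actor with exploration noise and push the resulting transitions to the warehouse; the main process collects rollout_nb rollouts and saves them to disk.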
def generate_trans_batch(env, actor, rollout_nb, rollout_len, log_std, save_path):
	
	mpi_role = nodes.mpi_role
	proc_num = nodes.proc_num
	pnid = nodes.pnid
	
	if mpi_role == 'main':
		os.makedirs(save_path)

		# collect rollout_nb rollouts produced by the workers
		msg = {pnid+"trans" : warehouse.Entry(action="get_l", value=rollout_nb)}
		data = warehouse.send(msg)

		all_trans = np.stack(data[pnid+"trans"].value)

		np.save(os.path.join(save_path, "all_trans.npy"), all_trans)
		
	elif mpi_role == 'worker':

		msg = {"proc_num" : warehouse.Entry(action="get", value=None)}
		data = warehouse.send(msg)

		# keep producing rollouts while this node is still the most recent one
		while proc_num >= data["proc_num"].value and not warehouse.is_work_done:

			all_trans = []

			obs = env.reset()
			obs = np.asarray(obs).reshape((1,1,-1))
			
			for i in range(rollout_len):
				
				trans = [obs.flatten()]
				
				# act with Gaussian exploration noise scaled by exp(log_std)
				act = actor.model(obs).numpy()
				act = act + np.random.normal(size=act.flatten().shape[0]).reshape(act.shape) * np.exp(log_std)
				obs, rew, done = env.step(act)
				obs = np.asarray(obs).reshape((1,1,-1))

				# one transition row = (obs before, scaled obs delta, action)
				trans.append((obs.flatten()-trans[0])*10)
				trans.append(act.flatten())
				all_trans.append(np.concatenate(trans))
				
			all_trans = np.asarray(all_trans).reshape((rollout_len,-1))

			# push this rollout to the warehouse and re-read the node counter
			msg = {	pnid+"trans" : warehouse.Entry(action="add", value=all_trans),
					"proc_num" : warehouse.Entry(action="get", value=None)}
			data = warehouse.send(msg)
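Each stored row is the concatenation of the pre-step observation, ten times the observation delta, and the action, so its length is obs_dim * 2 + act_dim, which matches the discriminator input size in Example #1. A tiny self-contained check of that layout, with made-up dimensions:

import numpy as np

obs_dim, act_dim = 4, 2  # hypothetical dimensions, for illustration only
obs_before = np.zeros(obs_dim)
obs_after = np.ones(obs_dim)
act = np.zeros(act_dim)

# same layout as the loop above: (obs, scaled obs delta, action)
row = np.concatenate([obs_before, (obs_after - obs_before) * 10, act])
assert row.shape[0] == obs_dim * 2 + act_dim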
Example #5
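Builds a SimpleActor and synchronises its freshly initialised weights across all processes, following the same pattern as Example #1 (mpi_role, pnid and warehouse again come from module scope).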
def simple_actor(obs_dim, act_dim, obs_mean=None, obs_std=None, blindfold=None, inp_dim=None, save_path=None):
	import os
	from models.actor import SimpleActor
	
	if save_path is None:
		raise ValueError("A save_path should always be given to an actor")

	actor = SimpleActor(obs_dim, act_dim, obs_mean=obs_mean, obs_std=obs_std, blindfold=blindfold, inp_dim=inp_dim)
	actor.save_path = save_path
	
	if mpi_role == 'main':
		os.makedirs(actor.save_path)
		actor.save(actor.save_path)

		# the main process publishes the freshly initialised weights
		data_out = {pnid+":actor_weight": warehouse.Entry(action="set", value=actor.get_weights())}
		data = warehouse.send(data_out)
	
	# every process pulls the weights back so all copies start identical
	data_out = {pnid+":actor_weight": warehouse.Entry(action="get", value=None)}
	data = warehouse.send(data_out)
	actor.set_weights(data[pnid+":actor_weight"].value)
	
	return actor
Example #6
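Trains the discriminator to tell real transitions from synthetic ones: workers generate synthetic rollouts with the current actor across a set of environment variants, while the main process mixes them with pre-recorded real transitions and runs the training steps.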
def train_discrim(disc, real_trans_path, all_env, actor, epoch_nb, train_step_per_epoch, rollout_per_epoch, rollout_len, log_std, model_save_interval, tensorboard_path):

	mpi_role = nodes.mpi_role
	proc_num = nodes.proc_num
	pnid = nodes.pnid
	
	if mpi_role == 'main':
		os.makedirs(tensorboard_path)
		
		from rts.discriminator import Trainer
		trainer = Trainer(disc, tensorboard_path)
		trainer.model_save_interval = model_save_interval
		
		# one-hot labels: [1, 0] marks real transitions, [0, 1] synthetic ones
		real_lab = np.asarray([1, 0]).reshape((1,2))
		synth_lab = np.asarray([0, 1]).reshape((1,2))
		
		# real transitions were saved to disk beforehand (see generate_trans_batch)
		all_real_trans = np.load(os.path.join(real_trans_path, "all_trans.npy"))
		all_real_trans, all_real_labs = format_trans(all_real_trans, real_lab)
		
		start_time = time.time()
		for n in range(epoch_nb):
		
			# get the latest rollouts
			msg = {	pnid+"trans" : warehouse.Entry(action="get_l", value=rollout_per_epoch), 
					"dumped" : warehouse.Entry(action="get", value=None) }
			data = warehouse.send(msg)
			dumped_rollout_nb = data["dumped"].value

			all_synth_trans_raw = data[pnid+"trans"].value
			all_synth_trans, all_synth_labs = format_trans(np.concatenate(all_synth_trans_raw, axis=0), synth_lab)
			
			# put the training data together
			all_trans = np.concatenate([all_real_trans, all_synth_trans], axis=0)
			all_labs = np.concatenate([all_real_labs, all_synth_labs], axis=0)
			
			# random offset for regularisation
			all_trans += np.random.normal(size=all_trans.shape) * 0.003
			
			# update the network weights
			accuracy = trainer.train_network(n, all_trans, all_labs, train_step_per_epoch)
			
			# debug output
			n_rollouts = len(all_synth_trans_raw)
			print("Epoch {} :".format(n), flush=True)
			print("Loaded {} synthetic rollouts for training while dumping {} for a total of {} transitions.".format(n_rollouts, dumped_rollout_nb, all_synth_trans.shape[0]), flush=True)
			dt = time.time() - start_time
			start_time = time.time()
			if dt > 0:
				print("fps : {}".format(all_synth_trans.shape[0]/dt), flush=True)
			print("accuracy : {}/{}".format(accuracy, all_trans.shape[0]), flush=True)
		
		
		msg = {pnid+":disc_weight":warehouse.Entry(action="set", value=disc.get_weights())}
		data = warehouse.send(msg)
	
	elif mpi_role == "worker":
		
		msg = {"proc_num" : warehouse.Entry(action="get", value=None)}
		data = warehouse.send(msg)
		
		while proc_num >= data["proc_num"].value and not warehouse.is_work_done:
		
			all_trans = []

			# sample one of the environment variants for this rollout
			env = all_env[np.random.randint(len(all_env))]
			obs = env.reset()
			obs = np.asarray(obs).reshape((1,1,-1))
			done = [False]

			i = 0
			# roll out until the horizon or until the env signals done
			while i < rollout_len and not done[0]:
				i += 1
				trans = [obs.flatten()]
				
				# act with Gaussian exploration noise scaled by exp(log_std)
				act = actor.model(obs).numpy()
				act = act + np.random.normal(size=act.flatten().shape[0]).reshape(act.shape) * np.exp(log_std)
				obs, rew, done = env.step(act)
				obs = np.asarray(obs).reshape((1,1,-1))
				
				trans.append((obs.flatten()-trans[0])*10)
				trans.append(act.flatten())
				all_trans.append(np.concatenate(trans))
				
			all_trans = np.asarray(all_trans).reshape((i,-1))
			
			msg = {	pnid+"trans" : warehouse.Entry(action="add", value=all_trans),
					"proc_num" : warehouse.Entry(action="get", value=None)}
			data = warehouse.send(msg)
		
	
	# both main and workers leave with the final discriminator weights
	msg = {pnid+":disc_weight": warehouse.Entry(action="get", value=None)}
	data = warehouse.send(msg)
	disc.set_weights(data[pnid+":disc_weight"].value)
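The helper format_trans is not shown in these examples. Judging from its call sites (a stacked transition array plus a single one-hot label row, returning matched transition and label arrays), a plausible reconstruction might be:

import numpy as np

# hypothetical reconstruction of the unshown format_trans helper:
# pair every transition with a copy of the given one-hot label row
def format_trans(all_trans, lab):
	all_labs = np.repeat(lab, all_trans.shape[0], axis=0)
	return all_trans, all_labs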
Example #7
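Distributed PPO training: the main process optimises the networks and publishes fresh weights each epoch, while workers pull those weights, generate rollouts (occasionally ADR test rollouts), and push the results back.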
def train_ppo(actor, env, epoch_nb, rollout_per_epoch, rollout_len, train_step_per_epoch, init_log_std, model_save_interval, adr_test_prob, tensorboard_path):
	
	mpi_role = nodes.mpi_role
	proc_num = nodes.proc_num
	pnid = nodes.pnid
	
	import os
	import time
	
	from ppo import PPO
	from models.critic import Critic
	
	# ADR is used only when the env exposes it and testing is actually enabled
	USE_ADR = hasattr(env, 'adr') and adr_test_prob > 1e-7
	
	if mpi_role == 'main':
		os.makedirs(tensorboard_path)
		
		critic = Critic(env)
		
		trainer = PPO(env, actor, critic, tensorboard_path, init_log_std=init_log_std)
		trainer.model_save_interval = model_save_interval
		
		start_time = time.time()
		
		for n in range(epoch_nb):
			# send the network weights
			# and get the latest rollouts
			msg = {	pnid+"weights" : warehouse.Entry(action="set", value=trainer.get_weights()),
					pnid+"adr" : warehouse.Entry(action="get", value=None),
					pnid+"s" : warehouse.Entry(action="get_l", value=rollout_per_epoch),
					pnid+"a" : warehouse.Entry(action="get_l", value=rollout_per_epoch),
					pnid+"r" : warehouse.Entry(action="get_l", value=rollout_per_epoch),
					pnid+"neglog" : warehouse.Entry(action="get_l", value=rollout_per_epoch),
					pnid+"mask" : warehouse.Entry(action="get_l", value=rollout_per_epoch),
					"dumped" : warehouse.Entry(action="get", value=None)
					}
			data = warehouse.send(msg)
			all_s = np.concatenate(data[pnid+"s"].value, axis=0)
			all_a = np.concatenate(data[pnid+"a"].value, axis=0)
			all_r = np.concatenate(data[pnid+"r"].value, axis=0)
			all_neglog = np.concatenate(data[pnid+"neglog"].value, axis=0)
			all_masks = np.concatenate(data[pnid+"mask"].value, axis=0)
			dumped_rollout_nb = data["dumped"].value
			
			if USE_ADR:
				env.adr.update(data[pnid+"adr"].value)
				env.adr.log()
			
			# update the network weights
			all_last_values, all_gae, all_new_value = trainer.calc_gae(all_s, all_r, all_masks)
			trainer.train_networks(n, all_s, all_a, all_r, all_neglog, all_masks, train_step_per_epoch, all_last_values, all_gae, all_new_value)
			
			# debug output
			n_rollouts = all_s.shape[0]
			cur_rollout_len = all_s.shape[1]
			print("Epoch {} :".format(n), flush=True)
			print("Loaded {} rollouts for training while dumping {}.".format(n_rollouts, dumped_rollout_nb), flush=True)
			dt = time.time() - start_time
			start_time = time.time()
			if dt > 0:
				print("fps : {}".format(n_rollouts*cur_rollout_len/dt), flush=True)
			print("mean_rew : {}".format(np.sum(all_r * all_masks)/np.sum(all_masks)), flush=True)
			
			if USE_ADR:
				env.adr.save()
				
	elif mpi_role == 'worker':
		# workers build their own PPO instance purely to generate rollouts
		trainer = PPO(env, actor, Critic(env), init_log_std=init_log_std)
		
		msg = {	pnid+"weights" : warehouse.Entry(action="get", value=None),
				pnid+"adr" : warehouse.Entry(action="set", value={}),
				"proc_num" : warehouse.Entry(action="get", value=None)}
		data = warehouse.send(msg)
		
		# keep working while this node is still the most recent one
		while proc_num >= data["proc_num"].value and not warehouse.is_work_done:
			# occasionally dedicate a rollout to an ADR test instead of training data
			test_adr = USE_ADR and np.random.random() < adr_test_prob
			
			env.test_adr = test_adr
			
			trainer.set_weights(data[pnid+"weights"].value)
			
			if test_adr:
				# simulate rollout
				all_s, all_a, all_r, all_neglog, all_mask = trainer.get_rollout(env.adr_rollout_len)
				
				msg = {	pnid+"adr" : warehouse.Entry(action="update", value=env.adr.get_msg()),
						pnid+"weights" : warehouse.Entry(action="get", value=None),
						"proc_num" : warehouse.Entry(action="get", value=None)}
			else:
				# simulate rollout
				all_s, all_a, all_r, all_neglog, all_mask = trainer.get_rollout(rollout_len)
				
				# send rollout back to warehouse
				# and get network weights to update actor
				msg = {	pnid+"s" : warehouse.Entry(action="add", value=all_s),
						pnid+"a" : warehouse.Entry(action="add", value=all_a),
						pnid+"r" : warehouse.Entry(action="add", value=all_r),
						pnid+"neglog" : warehouse.Entry(action="add", value=all_neglog),
						pnid+"mask" : warehouse.Entry(action="add", value=all_mask),
						pnid+"weights" : warehouse.Entry(action="get", value=None),
						pnid+"adr" : warehouse.Entry(action="get", value=None), 
						"proc_num" : warehouse.Entry(action="get", value=None)}
				
			data = warehouse.send(msg)
			
			if USE_ADR:
				env.adr.update(data[pnid+"adr"].value)
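Across Examples #4, #6 and #7, workers push rollouts with "add" entries and the main process drains up to a fixed number of them with "get_l"; the "dumped" counter reports rollouts dropped because workers produced them faster than the trainer consumed them. A hypothetical in-memory stand-in for that list-valued behaviour (the real warehouse's dumping policy may differ):

class FakeRolloutQueue:
	def __init__(self, max_len=64):
		self.items = []
		self.max_len = max_len
		self.dumped = 0

	def add(self, value):
		# drop the oldest rollout when the queue overflows, counting it as dumped
		self.items.append(value)
		if len(self.items) > self.max_len:
			self.items.pop(0)
			self.dumped += 1

	def get_l(self, n):
		# hand at most n queued rollouts to the trainer
		taken, self.items = self.items[:n], self.items[n:]
		return taken

q = FakeRolloutQueue(max_len=2)
for r in range(3):
	q.add(r)  # the first rollout is dumped once the third arrives
assert q.dumped == 1 and q.get_l(2) == [1, 2]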