class BatchSampler(object):
    def __init__(self, env_name, batch_size, num_workers=mp.cpu_count() - 1):
        self.env_name = env_name
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.queue = mp.Queue()
        self.envs = SubprocVecEnv(
            [make_env(env_name) for _ in range(num_workers)],
            queue=self.queue)
        self._env = gym.make(env_name)

    def sample(self, policy, params=None, gamma=0.95, device='cpu'):
        episodes = BatchEpisodes(batch_size=self.batch_size, gamma=gamma,
                                 device=device)
        for i in range(self.batch_size):
            self.queue.put(i)
        for _ in range(self.num_workers):
            self.queue.put(None)

        observations, batch_ids = self.envs.reset()
        dones = [False]
        while (not all(dones)) or (not self.queue.empty()):
            with torch.no_grad():
                observations_tensor = torch.from_numpy(observations).to(
                    device=device)
                actions_tensor = policy(observations_tensor,
                                        params=params).sample()
                actions = actions_tensor.cpu().numpy()
            new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(
                actions)
            # observation layout (from debug notes): obs[:, :2] is the robot
            # position, obs[:, 4:6] is the goal
            # additional per-step pedestrian observations exposed by this env
            new_hid_observations = self.envs.get_peds()
            episodes.append(observations, new_hid_observations, actions,
                            rewards, batch_ids)
            observations, batch_ids = new_observations, new_batch_ids
        return episodes

    def reset_task(self, task):
        tasks = [task for _ in range(self.num_workers)]
        reset = self.envs.reset_task(tasks)
        return all(reset)

    def sample_tasks(self, num_tasks):
        tasks = self._env.unwrapped.sample_tasks(num_tasks)
        return tasks

class BatchSampler(object):
    def __init__(self, env_name, batch_size, num_workers=mp.cpu_count() - 2):
        self.env_name = env_name
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.queue = mp.Queue()
        self.envs = SubprocVecEnv(
            [make_env(env_name) for _ in range(num_workers)],
            queue=self.queue)
        self._env = gym.make(env_name)

    def sample(self, policy, params=None, gamma=0.95, device='cpu'):
        episodes = BatchEpisodes(batch_size=self.batch_size, gamma=gamma,
                                 device=device)
        for i in range(self.batch_size):
            self.queue.put(i)
        for _ in range(self.num_workers):
            self.queue.put(None)

        observations, batch_ids = self.envs.reset()
        dones = [False]
        while (not all(dones)) or (not self.queue.empty()):
            with torch.no_grad():
                observations_tensor = torch.from_numpy(observations).to(
                    device=device)
                actions_tensor = policy(observations_tensor,
                                        params=params).sample()
                actions = actions_tensor.cpu().numpy()
            new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(
                actions)
            episodes.append(observations, actions, rewards, batch_ids)
            observations, batch_ids = new_observations, new_batch_ids
        return episodes

    def reset_task(self, task):
        tasks = [task for _ in range(self.num_workers)]
        reset = self.envs.reset_task(tasks)
        return all(reset)

    def sample_tasks(self, num_tasks):
        tasks = self._env.unwrapped.sample_tasks(num_tasks)
        return tasks

    def sample_target_task(self, N):
        tasks = self._env.unwrapped.sample_target_task(N)
        return tasks

class BatchSampler(object):
    def __init__(self, env_name, batch_size, num_workers=None, test_env=False):
        self.env_name = env_name
        self.batch_size = batch_size
        self.num_workers = num_workers or mp.cpu_count() - 1
        self.test_env = test_env

        self.queue = mp.Queue()
        self.envs = SubprocVecEnv(
            [make_env(env_name, test_env=test_env)
             for _ in range(self.num_workers)],
            queue=self.queue)
        self._env = make_env(env_name, test_env=test_env)()

    def sample(self, policy, params=None, gamma=0.95, device='cpu'):
        episodes = BatchEpisodes(batch_size=self.batch_size, gamma=gamma,
                                 device=device)
        for i in range(self.batch_size):
            self.queue.put(i)
        for _ in range(self.num_workers):
            self.queue.put(None)

        observations, batch_ids = self.envs.reset()
        dones = [False]
        while (not all(dones)) or (not self.queue.empty()):
            with torch.no_grad():
                observations_tensor = torch.from_numpy(observations).to(
                    device=device, dtype=torch.float32)
                actions_tensor = policy(observations_tensor,
                                        params=params).sample()
                actions = actions_tensor.cpu().numpy()
            new_observations, rewards, dones, new_batch_ids, infos = self.envs.step(
                actions)
            # info keys: reachDist, pickRew, epRew, goalDist, success, goal, task_name
            # NOTE: the last infos will be absent if batch_size % num_workers != 0
            episodes.append(observations, actions, rewards, batch_ids, infos)
            observations, batch_ids = new_observations, new_batch_ids
        return episodes

    def reset_task(self, task):
        tasks = [task for _ in range(self.num_workers)]
        reset = self.envs.reset_task(tasks)
        return all(reset)

    def sample_tasks(self, num_tasks, task2prob=None):
        tasks = self._env.unwrapped.sample_tasks(num_tasks, task2prob)
        return tasks

class BatchSampler(object):
    def __init__(self, env_name, batch_size, num_workers=mp.cpu_count() - 1):
        self.env_name = env_name
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.queue = mp.Queue()
        self.envs = SubprocVecEnv(
            [make_env(env_name) for _ in range(num_workers)],
            queue=self.queue)
        self._env = gym.make(env_name)

    def sample(self, policy, params=None, gamma=0.9):
        episodes = BatchEpisodes(batch_size=self.batch_size, gamma=gamma)
        for i in range(self.batch_size):
            self.queue.put(i)
        for _ in range(self.num_workers):
            self.queue.put(None)

        observations, batch_ids = self.envs.reset()
        dones = [False]
        while (not all(dones)) or (not self.queue.empty()):
            # observations and actions must share the same leading dimension
            # observations_tensor = observations.reshape(observations.shape[0], -1)
            observations_tensor = observations
            actions_tensor = policy(observations_tensor, params=params).sample()
            # convert actions back to numpy on /CPU:0
            with tf.device('/CPU:0'):
                actions = actions_tensor.numpy()
            new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(
                actions)
            episodes.append(observations, actions, rewards, batch_ids)
            observations, batch_ids = new_observations, new_batch_ids
        return episodes

    def reset_task(self, task):
        tasks = [task for _ in range(self.num_workers)]
        reset = self.envs.reset_task(tasks)
        return all(reset)

    def sample_tasks(self, num_tasks):
        tasks = self._env.unwrapped.sample_tasks(num_tasks)
        return tasks

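# ---------------------------------------------------------------------------
# Usage sketch (assumptions: the env name, `policy`, and `adapt` below are
# placeholders, not defined by the samplers above). It shows how the variants
# above are typically driven in a MAML-style outer loop: sample_tasks picks a
# batch of tasks, reset_task broadcasts one task to every worker, and sample
# collects pre- and post-adaptation rollouts.
#
#   sampler = BatchSampler('2DNavigation-v0', batch_size=20, num_workers=8)
#   for task in sampler.sample_tasks(num_tasks=40):
#       sampler.reset_task(task)
#       train_episodes = sampler.sample(policy, gamma=0.95)
#       params = adapt(policy, train_episodes)   # assumed inner-loop update
#       valid_episodes = sampler.sample(policy, params=params)
# ---------------------------------------------------------------------------
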
class BanditLearner(object):
    """LSTM learner using A2C/PPO."""

    def __init__(self, k, n, batch_size, num_workers, num_batches=1000,
                 gamma=0.95, lr=0.01, tau=1.0, ent_coef=.01, vf_coef=0.5,
                 lstm_size=50, clip_frac=0.2, device='cpu', surr_epochs=3,
                 surr_batches=4, max_grad_norm=0.5, D=1):
        self.k = k
        self.n = n
        self.vf_coef = vf_coef
        self.ent_coef = ent_coef
        self.gamma = gamma
        self.D = D

        # Sampler variables
        self.num_batches = num_batches
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.queue = mp.Queue()
        self.env_name = 'Bandit-K{0}-N{1}-v0'.format(self.k, self.n)
        self.envs = SubprocVecEnv(
            [make_env(self.env_name) for _ in range(num_workers)],
            queue=self.queue)
        self.obs_shape = self.envs.observation_space.shape
        self.num_actions = self.envs.action_space.n
        self.lstm_size = lstm_size
        self.policy = GRUPolicy(input_size=self.obs_shape[0],
                                output_size=self.num_actions,
                                lstm_size=self.lstm_size, D=self.D)

        # Optimization variables
        self.lr = lr
        self.tau = tau
        self.clip_frac = clip_frac
        self.optimizer = optim.Adam(self.policy.parameters(), lr=self.lr,
                                    eps=1e-5)

        # PPO variables
        self.surrogate_epochs = surr_epochs
        self.surrogate_batches = surr_batches
        self.surrogate_batch_size = self.batch_size // self.surrogate_batches

        self.to(device)
        self.max_grad_norm = max_grad_norm

    def _forward_policy(self, episodes, ratio=False):
        T = episodes.observations.size(0)
        values, log_probs, entropy = [], [], []
        hx = torch.zeros(self.D, self.batch_size,
                         self.lstm_size).to(device=self.device)
        for t in range(T):
            pi, v, hx = self.policy(episodes.observations[t], hx,
                                    episodes.embeds[t])
            values.append(v)
            entropy.append(pi.entropy())
            if ratio:
                log_probs.append(pi.log_prob(episodes.actions[t])
                                 - episodes.logprobs[t])
            else:
                log_probs.append(pi.log_prob(episodes.actions[t]))

        log_probs = torch.stack(log_probs)
        values = torch.stack(values)
        entropy = torch.stack(entropy)

        advantages = episodes.gae(values, tau=self.tau)
        advantages = weighted_normalize(advantages, weights=episodes.mask)
        if log_probs.dim() > 2:
            log_probs = torch.sum(log_probs, dim=2)
        return log_probs, advantages, values, entropy

    def loss(self, episodes):
        """REINFORCE gradient with baseline [2], computed on advantages
        estimated with Generalized Advantage Estimation (GAE, [3]).
        """
        log_probs, advantages, values, entropy = self._forward_policy(episodes)
        pg_loss = -weighted_mean(log_probs * advantages, dim=0,
                                 weights=episodes.mask)
        vf_loss = 0.5 * weighted_mean((values.squeeze() - episodes.returns) ** 2,
                                      dim=0, weights=episodes.mask)
        entropy_loss = weighted_mean(entropy, dim=0, weights=episodes.mask)
        return pg_loss + self.vf_coef * vf_loss - self.ent_coef * entropy_loss

    def surrogate_loss(self, episodes, inds=None):
        """PPO surrogate loss."""
        log_ratios, advantages, values, entropy = self._forward_policy(
            episodes, ratio=True)

        # clipped pg loss
        ratio = torch.exp(log_ratios)
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(ratio, min=1.0 - self.clip_frac,
                                             max=1.0 + self.clip_frac)

        # clipped value loss
        values_clipped = episodes.old_values + torch.clamp(
            values.squeeze() - episodes.old_values,
            min=-self.clip_frac, max=self.clip_frac)
        vf_loss1 = (values.squeeze() - episodes.returns) ** 2
        vf_loss2 = (values_clipped - episodes.returns) ** 2

        if inds is None:
            inds = np.arange(self.batch_size)
        masks = episodes.mask[:, inds]
        pg_loss = weighted_mean(torch.max(pg_loss1, pg_loss2)[:, inds], dim=0,
                                weights=masks)
        vf_loss = 0.5 * weighted_mean(torch.max(vf_loss1, vf_loss2)[:, inds],
                                      dim=0, weights=masks)
        entropy_loss = weighted_mean(entropy[:, inds], dim=0, weights=masks)
        return pg_loss + self.vf_coef * vf_loss - self.ent_coef * entropy_loss

    def step(self, episodes):
        """Adapt the parameters of the policy network to a new set of examples."""
        self.optimizer.zero_grad()
        loss = self.loss(episodes)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(),
                                       self.max_grad_norm)
        self.optimizer.step()

    def surrogate_step(self, episodes):
        for _ in range(self.surrogate_epochs):
            for k in range(self.surrogate_batches):
                sample_inds = np.random.choice(self.batch_size,
                                               self.surrogate_batch_size,
                                               replace=False)
                self.optimizer.zero_grad()
                loss = self.surrogate_loss(episodes, inds=sample_inds)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(),
                                               self.max_grad_norm)
                self.optimizer.step()

    def sample(self):
        """Sample trajectories."""
        episodes = LSTMBatchEpisodes(batch_size=self.batch_size,
                                     gamma=self.gamma, device=self.device)
        for i in range(self.batch_size):
            self.queue.put(i)
        for _ in range(self.num_workers):
            self.queue.put(None)

        observations, batch_ids = self.envs.reset()
        dones = [False]
        timers = np.zeros(self.num_workers)
        embed_tensor = torch.zeros(self.num_workers,
                                   self.num_actions + 3).to(device=self.device)
        embed_tensor[:, 0] = 1.
        hx = torch.zeros(self.D, self.num_workers,
                         self.lstm_size).to(device=self.device)
        while (not all(dones)) or (not self.queue.empty()):
            with torch.no_grad():
                obs_tensor = torch.from_numpy(observations).to(device=self.device)
                act_dist, values_tensor, hx = self.policy(obs_tensor, hx,
                                                          embed_tensor)
                act_tensor = act_dist.sample()

                # cpu variables for logging
                log_probs = act_dist.log_prob(act_tensor).cpu().numpy()
                actions = act_tensor.cpu().numpy()
                old_values = values_tensor.squeeze().cpu().numpy()
                embed = embed_tensor.cpu().numpy()

            new_observations, rewards, dones, new_batch_ids, infos = self.envs.step(
                actions)
            timers += 1.0

            # Update embeddings when episode is done
            embed_temp = np.hstack((one_hot(actions, self.num_actions),
                                    rewards[:, None], dones[:, None],
                                    timers[:, None]))
            embed_tensor = torch.from_numpy(embed_temp).float().to(
                device=self.device)

            # Update hidden states
            dones_tensor = torch.from_numpy(dones.astype(np.float32)).to(
                device=self.device)
            timers[dones] = 0.
            hx[:, dones_tensor == 1, :] = 0.
            embed_tensor[dones_tensor == 1] = 0.
            embed_tensor[dones_tensor == 1, 0] = 1.

            episodes.append(observations, actions, rewards, batch_ids,
                            log_probs, old_values, embed)
            observations, batch_ids = new_observations, new_batch_ids
        return episodes

    def to(self, device, **kwargs):
        self.policy.to(device, **kwargs)
        self.device = device

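# Sketch of the masking helpers used by the losses above (weighted_mean,
# weighted_normalize). These are assumptions about their behaviour -- averaging
# and normalising only over valid (unpadded) timesteps of [T, batch] tensors --
# not the verbatim implementations shipped with the learners.
import torch


def weighted_mean(tensor, dim=0, weights=None, epsilon=1e-8):
    # Mean along `dim`, counting only entries where the mask is non-zero.
    if weights is None:
        return tensor.mean(dim=dim)
    return (tensor * weights).sum(dim=dim) / (weights.sum(dim=dim) + epsilon)


def weighted_normalize(tensor, weights=None, epsilon=1e-8):
    # Standardise using mean/std computed over masked entries only.
    if weights is None:
        return (tensor - tensor.mean()) / (tensor.std() + epsilon)
    mean = (tensor * weights).sum() / (weights.sum() + epsilon)
    var = (weights * (tensor - mean) ** 2).sum() / (weights.sum() + epsilon)
    return weights * (tensor - mean) / (torch.sqrt(var) + epsilon)
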
class LSTMLearner(object):
    """LSTM learner using PPO."""

    def __init__(self, env_name, num_workers, num_batches=1000, n_step=5,
                 gamma=0.95, lr=0.01, tau=1.0, ent_coef=.01, vf_coef=0.5,
                 lstm_size=256, clip_frac=0.2, device='cpu', surr_epochs=3,
                 clstm=False, surr_batches=4, max_grad_norm=0.5,
                 cnn_type='nature'):
        self.vf_coef = vf_coef
        self.ent_coef = ent_coef
        self.gamma = gamma
        self.use_clstm = clstm
        self.n_step = n_step
        self.reward_log = deque(maxlen=100)
        self.lstm_size = lstm_size

        # Sampler variables
        self.env_name = env_name
        self.num_batches = num_batches
        self.num_workers = num_workers

        self.envs = SubprocVecEnv(
            [make_env(env_name) for _ in range(self.num_workers)])
        self.obs_shape = self.envs.observation_space.shape
        self.num_actions = self.envs.action_space.n

        self.obs = np.zeros((self.num_workers,) + self.obs_shape)
        self.obs[:] = self.envs.reset()
        self.dones = [False for _ in range(self.num_workers)]
        self.embed = torch.zeros(self.num_workers,
                                 self.num_actions + 2).to(device=device)
        self.embed[:, 0] = 1.

        if not self.use_clstm:
            self.hx = torch.zeros(self.num_workers,
                                  self.lstm_size).to(device=device)
            self.policy = ConvGRUPolicy(input_size=self.obs_shape,
                                        output_size=self.num_actions,
                                        use_bn=False, cnn_type=cnn_type,
                                        lstm_size=self.lstm_size)
        else:
            self.hx = torch.zeros(self.num_workers, self.lstm_size,
                                  7, 7).to(device=device)
            self.policy = ConvCGRUPolicy(input_size=self.obs_shape,
                                         output_size=self.num_actions,
                                         use_bn=False, cnn_type=cnn_type,
                                         lstm_size=self.lstm_size)

        # Optimization variables
        self.lr = lr
        self.tau = tau
        self.clip_frac = clip_frac
        self.optimizer = optim.Adam(self.policy.parameters(), lr=self.lr,
                                    eps=1e-5)

        # PPO variables
        self.surrogate_epochs = surr_epochs
        self.surrogate_batches = surr_batches
        self.surrogate_batch_size = self.num_workers // self.surrogate_batches

        self.to(device)
        self.max_grad_norm = max_grad_norm

    def _forward_policy(self, episodes, ratio=False):
        T = episodes.observations.size(0)
        values, log_probs, entropy = [], [], []
        if not self.use_clstm:
            hx = torch.zeros(self.num_workers,
                             self.lstm_size).to(device=self.device)
        else:
            hx = torch.zeros(self.num_workers, self.lstm_size,
                             7, 7).to(device=self.device)
        for t in range(T):
            pi, v, hx = self.policy(episodes.observations[t], hx,
                                    episodes.embeds[t])
            values.append(v)
            entropy.append(pi.entropy())
            if ratio:
                log_probs.append(pi.log_prob(episodes.actions[t])
                                 - episodes.logprobs[t])
            else:
                log_probs.append(pi.log_prob(episodes.actions[t]))

        log_probs = torch.stack(log_probs)
        values = torch.stack(values)
        entropy = torch.stack(entropy)

        advantages = episodes.gae(values, tau=self.tau)
        advantages = weighted_normalize(advantages, weights=episodes.mask)
        if log_probs.dim() > 2:
            log_probs = torch.sum(log_probs, dim=2)
        return log_probs, advantages, values, entropy

    def loss(self, episodes, inds=None):
        """PPO surrogate loss."""
        log_ratios, advantages, values, entropy = self._forward_policy(
            episodes, ratio=True)

        # clipped pg loss
        ratio = torch.exp(log_ratios)
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(ratio, min=1.0 - self.clip_frac,
                                             max=1.0 + self.clip_frac)

        # clipped value loss
        values_clipped = episodes.old_values + torch.clamp(
            values.squeeze() - episodes.old_values,
            min=-self.clip_frac, max=self.clip_frac)
        vf_loss1 = (values.squeeze() - episodes.returns) ** 2
        vf_loss2 = (values_clipped - episodes.returns) ** 2

        if inds is None:
            inds = np.arange(self.num_workers)
        masks = episodes.mask[:, inds]
        pg_loss = weighted_mean(torch.max(pg_loss1, pg_loss2)[:, inds], dim=0,
                                weights=masks)
        vf_loss = 0.5 * weighted_mean(torch.max(vf_loss1, vf_loss2)[:, inds],
                                      dim=0, weights=masks)
        entropy_loss = weighted_mean(entropy[:, inds], dim=0, weights=masks)
        return pg_loss + self.vf_coef * vf_loss - self.ent_coef * entropy_loss

    def step(self, episodes):
        """Adapt the parameters of the policy network to a new set of examples."""
        for i in range(self.surrogate_epochs):
            for j in range(self.surrogate_batches):
                sample_inds = np.random.choice(self.num_workers,
                                               self.surrogate_batch_size,
                                               replace=False)
                self.optimizer.zero_grad()
                loss = self.loss(episodes, inds=sample_inds)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(),
                                               self.max_grad_norm)
                self.optimizer.step()

    def _get_term_flags(self, infos):
        if 'v0' in self.env_name:
            return np.array(self.dones)
        else:
            return np.array([np.sign(v['done']) for k, v in enumerate(infos)])

    def sample(self):
        """Sample trajectories."""
        episodes = LSTMBatchEpisodes(batch_size=self.num_workers,
                                     gamma=self.gamma, device=self.device)
        for ns in range(self.n_step):
            with torch.no_grad():
                obs_tensor = torch.from_numpy(self.obs).float().to(
                    device=self.device)
                act_dist, values_tensor, self.hx = self.policy(
                    obs_tensor, self.hx, self.embed)
                act_tensor = act_dist.sample()

                # cpu variables for logging
                log_probs = act_dist.log_prob(act_tensor).cpu().numpy()
                actions = act_tensor.cpu().numpy()
                old_values = values_tensor.squeeze().cpu().numpy()
                embed = self.embed.cpu().numpy()

            new_observations, rewards, self.dones, infos = self.envs.step(
                actions)

            # Update embeddings when episode is done
            term_flags = self._get_term_flags(infos)
            embed_temp = np.hstack((one_hot(actions, self.num_actions),
                                    rewards[:, None], term_flags[:, None]))
            self.embed = torch.from_numpy(embed_temp).float().to(
                device=self.device)

            # Logging episode rewards
            for dr in rewards[self.dones == 1]:
                self.reward_log.append(dr)

            # Update hidden states
            dones_tensor = torch.from_numpy(self.dones.astype(np.float32)).to(
                device=self.device)
            if not self.use_clstm:
                self.hx[dones_tensor == 1, :] = 0.
            else:
                self.hx[dones_tensor == 1, :, :, :] = 0.
            self.embed[dones_tensor == 1] = 0.
            self.embed[dones_tensor == 1, 0] = 1.

            episodes.append(self.obs, actions, rewards, log_probs, old_values,
                            embed)
            self.obs[:] = new_observations
        return episodes

    def to(self, device, **kwargs):
        self.policy.to(device, **kwargs)
        self.device = device

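# Sketch of the one_hot helper both learners use when building the recurrent
# embedding (one-hot action, reward, termination flag, and optionally a timer).
# The actual helper is not shown above; this is an assumed minimal version with
# the call signature one_hot(actions, depth) -> (len(actions), depth) float32.
import numpy as np


def one_hot(indices, depth):
    # indices: 1-D integer array of sampled actions; depth: number of actions
    out = np.zeros((len(indices), depth), dtype=np.float32)
    out[np.arange(len(indices)), np.asarray(indices, dtype=np.int64)] = 1.0
    return out


# Assumed outer loop for either learner: alternate rollouts and PPO updates.
#   learner = LSTMLearner(env_name, num_workers=16)
#   for _ in range(learner.num_batches):
#       episodes = learner.sample()
#       learner.step(episodes)
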
class BatchSampler(object):
    def __init__(self, env_name, batch_size, num_workers=mp.cpu_count() - 1):
        self.env_name = env_name
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.queue = mp.Queue()
        self.envs = SubprocVecEnv(
            [make_env(env_name) for _ in range(num_workers)],
            queue=self.queue)
        self._env = gym.make(env_name)

    def sample(self, policy, task, tree=None, params=None, gamma=0.95,
               device='cpu'):
        episodes = BatchEpisodes(batch_size=self.batch_size, gamma=gamma,
                                 device=device)
        for i in range(self.batch_size):
            self.queue.put(i)
        for _ in range(self.num_workers):
            self.queue.put(None)

        observations, batch_ids = self.envs.reset()
        dones = [False]
        while (not all(dones)) or (not self.queue.empty()):
            with torch.no_grad():
                input = torch.from_numpy(observations).float().to(device=device)
                # Embed the task description with the tree model, then
                # concatenate the embedding onto every observation.
                if self.env_name == 'AntPos-v0':
                    _, embedding = tree.forward(
                        torch.from_numpy(task["position"]).float().to(
                            device=device))
                if self.env_name == 'AntVel-v1':
                    _, embedding = tree.forward(
                        torch.from_numpy(np.array([task["velocity"]])).float().to(
                            device=device))
                observations_tensor = torch.t(
                    torch.stack([
                        torch.cat([
                            torch.from_numpy(np.array(teo)).to(device=device),
                            embedding[0]
                        ], 0) for teo in input
                    ], 1))
                actions_tensor = policy(observations_tensor, task=task,
                                        params=params, enhanced=False).sample()
                actions = actions_tensor.cpu().numpy()
            new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(
                actions)
            episodes.append(observations_tensor.cpu().numpy(), actions,
                            rewards, batch_ids)
            observations, batch_ids = new_observations, new_batch_ids
        return episodes

    def reset_task(self, task):
        tasks = [task for _ in range(self.num_workers)]
        reset = self.envs.reset_task(tasks)
        return all(reset)

    def sample_tasks(self, num_tasks):
        tasks = self._env.unwrapped.sample_tasks(num_tasks)
        return tasks

class BatchSampler:
    def __init__(self, env_name, batch_size, num_workers=mp.cpu_count()):
        """
        :param env_name: name of the Gym environment to sample from
        :param batch_size: fast batch size (trajectories per task)
        :param num_workers: number of worker processes
        """
        self.env_name = env_name
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.queue = mp.Queue()
        # [lambda function]
        env_factorys = [make_env(env_name) for _ in range(num_workers)]
        # this is the main process manager, and it will be in charge of
        # num_workers sub-processes interacting with the environment.
        self.envs = SubprocVecEnv(env_factorys, queue_=self.queue)
        self._env = gym.make(env_name)

    def sample(self, policy, params=None, gamma=0.95, device='cpu'):
        """
        :param policy: policy network returning an action distribution
        :param params: optional adapted parameters to run the policy with
        :param gamma: discount factor
        :param device: torch device
        :return: a BatchEpisodes object with the sampled trajectories
        """
        episodes = BatchEpisodes(batch_size=self.batch_size, gamma=gamma,
                                 device=device)
        for i in range(self.batch_size):
            self.queue.put(i)
        for _ in range(self.num_workers):
            self.queue.put(None)

        observations, batch_ids = self.envs.reset()
        dones = [False]
        # loop until all workers are done and the queue is empty
        while (not all(dones)) or (not self.queue.empty()):
            # for reinforcement learning, the forward pass requires no gradient
            with torch.no_grad():
                # convert observations to a tensor on the target device,
                # run the policy there, then move actions back to cpu
                observations_tensor = torch.from_numpy(observations).to(
                    device=device)
                # forward via the policy network, which returns
                # Categorical(logits=logits)
                actions_tensor = policy(observations_tensor,
                                        params=params).sample()
                actions = actions_tensor.cpu().numpy()
            new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(
                actions)
            # here we append observations, NOT new_observations, and batch_ids,
            # NOT new_batch_ids
            episodes.append(observations, actions, rewards, batch_ids)
            observations, batch_ids = new_observations, new_batch_ids
        return episodes

    def reset_task(self, task):
        tasks = [task for _ in range(self.num_workers)]
        reset = self.envs.reset_task(tasks)
        return all(reset)

    def sample_tasks(self, num_tasks):
        tasks = self._env.unwrapped.sample_tasks(num_tasks)
        return tasks

class BatchSampler(object):
    def __init__(self, env_name, batch_size, num_workers=mp.cpu_count() - 1,
                 args=None):
        self.env_name = env_name
        self.batch_size = batch_size  # NOTE number of trajectories in each env
        self.num_workers = num_workers
        self.args = args

        self.queue = mp.Queue()
        self.envs = SubprocVecEnv(
            [make_env(args, i_worker) for i_worker in range(num_workers)],
            queue=self.queue)
        self._env = make_env(args, i_worker=99)()

    def sample(self, policy, params=None, prey=None, gamma=0.95, device='cpu'):
        """Sample the number of trajectories defined by "self.batch_size".
        The length of each trajectory is defined by the Gym env registration
        at: ./maml_rl/envs/__init__.py
        """
        assert prey is not None
        episodes = BatchEpisodes(batch_size=self.batch_size, gamma=gamma,
                                 device=device)
        for i in range(self.batch_size):
            self.queue.put(i)
        for _ in range(self.num_workers):
            self.queue.put(None)

        observations, worker_ids = self.envs.reset()  # TODO reset needs to be fixed
        dones = [False]
        while (not all(dones)) or (not self.queue.empty()):
            with torch.no_grad():
                # Get observations
                predator_observations, prey_observations = self.split_observations(
                    observations)
                predator_observations_torch = torch.from_numpy(
                    predator_observations).to(device=device)
                prey_observations_torch = torch.from_numpy(
                    prey_observations).to(device=device)

                # Get actions
                predator_actions = policy(predator_observations_torch,
                                          params=params).sample()
                predator_actions = predator_actions.cpu().numpy()

                prey_actions = prey.select_deterministic_action(
                    prey_observations_torch)
                prey_actions = prey_actions.cpu().numpy()

            actions = np.concatenate([predator_actions, prey_actions], axis=1)
            new_observations, rewards, dones, new_worker_ids, _ = self.envs.step(
                copy.deepcopy(actions))
            assert np.sum(dones[:, 0]) == np.sum(dones[:, 1])
            dones = dones[:, 0]

            # Get new observations
            new_predator_observations, _ = self.split_observations(
                new_observations)

            # Get rewards
            predator_rewards = rewards[:, 0]
            episodes.append(predator_observations, predator_actions,
                            predator_rewards, worker_ids)
            observations, worker_ids = new_observations, new_worker_ids
        return episodes

    def reset_task(self, task):
        tasks = [task for _ in range(self.num_workers)]
        reset = self.envs.reset_task(tasks)
        return all(reset)

    def sample_tasks(self, num_tasks, test=False):
        if test is False:
            i_agents = np.random.randint(low=0, high=16, size=(num_tasks,))
        else:
            i_agents = np.random.randint(low=16, high=21, size=(num_tasks,))
        tasks = [{"i_agent": i_agent} for i_agent in i_agents]
        return tasks

    def split_observations(self, observations):
        predator_observations = []
        prey_observations = []
        for obs in observations:
            assert len(obs) == 2
            predator_observations.append(obs[0])
            prey_observations.append(obs[1])
        return \
            np.asarray(predator_observations, dtype=np.float32), \
            np.asarray(prey_observations, dtype=np.float32)

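# Sketch of the make_env factory that the samplers above pass to SubprocVecEnv:
# the vectorized wrapper expects zero-argument callables (thunks) so that each
# worker process constructs its own environment instance. This is an assumed
# minimal version; the predator-prey variant directly above uses a different
# signature, make_env(args, i_worker), for its multi-agent setup.
import gym


def make_env(env_name):
    def _make_env():
        return gym.make(env_name)
    return _make_env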