def test_meta_evaluator_with_tf():
    set_seed(100)
    tasks = SetTaskSampler(PointEnv, wrapper=set_length)
    max_episode_length = 200
    env = PointEnv()
    n_eps = 3
    with tempfile.TemporaryDirectory() as log_dir_name:
        ctxt = SnapshotConfig(snapshot_dir=log_dir_name,
                              snapshot_mode='none',
                              snapshot_gap=1)
        with TFTrainer(ctxt) as trainer:
            meta_eval = MetaEvaluator(test_task_sampler=tasks,
                                      n_test_tasks=10,
                                      n_exploration_eps=n_eps)
            policy = GaussianMLPPolicy(env.spec)
            algo = MockAlgo(env, policy, max_episode_length, n_eps, meta_eval)
            trainer.setup(algo, env)
            log_file = tempfile.NamedTemporaryFile()
            csv_output = CsvOutput(log_file.name)
            logger.add_output(csv_output)
            meta_eval.evaluate(algo)
            algo_pickle = cloudpickle.dumps(algo)
        tf.compat.v1.reset_default_graph()
        with TFTrainer(ctxt) as trainer:
            algo2 = cloudpickle.loads(algo_pickle)
            trainer.setup(algo2, env)
            trainer.train(10, 0)
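
# The test above and the one below build their task sampler as
# SetTaskSampler(PointEnv, wrapper=set_length), but set_length itself is
# defined elsewhere in the test module. The sketch below is a plausible,
# hypothetical stand-in: it just rebuilds the environment with the episode
# cap these tests use. It assumes SetTaskSampler calls the wrapper with the
# constructed environment and the sampled task.
def set_length(env, _task):
    """Return a PointEnv whose episodes are capped at 200 steps (sketch)."""
    env.close()
    return PointEnv(max_episode_length=200)
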
def test_meta_evaluator():
    set_seed(100)
    tasks = SetTaskSampler(PointEnv, wrapper=set_length)
    max_episode_length = 200
    with tempfile.TemporaryDirectory() as log_dir_name:
        trainer = Trainer(
            SnapshotConfig(snapshot_dir=log_dir_name,
                           snapshot_mode='last',
                           snapshot_gap=1))
        env = PointEnv(max_episode_length=max_episode_length)
        algo = OptimalActionInference(env=env,
                                      max_episode_length=max_episode_length)
        trainer.setup(algo, env)
        meta_eval = MetaEvaluator(test_task_sampler=tasks, n_test_tasks=10)
        log_file = tempfile.NamedTemporaryFile()
        csv_output = CsvOutput(log_file.name)
        logger.add_output(csv_output)
        meta_eval.evaluate(algo)
        logger.log(tabular)
        meta_eval.evaluate(algo)
        logger.log(tabular)
        logger.dump_output_type(CsvOutput)
        logger.remove_output_type(CsvOutput)
        with open(log_file.name, 'r') as file:
            rows = list(csv.DictReader(file))
        assert len(rows) == 2
        assert float(
            rows[0]['MetaTest/__unnamed_task__/TerminationRate']) < 1.0
        assert float(rows[0]['MetaTest/__unnamed_task__/Iteration']) == 0
        assert (float(rows[0]['MetaTest/__unnamed_task__/MaxReturn']) >= float(
            rows[0]['MetaTest/__unnamed_task__/AverageReturn']))
        assert (float(rows[0]['MetaTest/__unnamed_task__/AverageReturn']) >=
                float(rows[0]['MetaTest/__unnamed_task__/MinReturn']))
        assert float(rows[1]['MetaTest/__unnamed_task__/Iteration']) == 1
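
# OptimalActionInference, used by test_meta_evaluator above, is another helper
# defined elsewhere in this test module. Below is a hedged sketch of its shape:
# a mock meta-RL algorithm whose exploration policy already steers straight
# toward the PointEnv goal, so MetaEvaluator has something sensible to roll
# out. The attribute names (policy, max_episode_length) and the _goal
# attribute on PointEnv are assumptions inferred from how the test calls it.
class OptimalPolicy:
    """Move toward a fixed goal in PointEnv (sketch)."""

    def __init__(self, env_spec, goal):
        self.env_spec = env_spec
        self.goal = goal

    def reset(self):
        pass

    def get_action(self, observation):
        # PointEnv observations are assumed to start with the (x, y) position.
        return self.goal - observation[:2], {}


class OptimalActionInference:
    """Mock meta-RL algorithm whose policy is already optimal (sketch)."""

    def __init__(self, env, max_episode_length):
        self.env = env
        self.policy = OptimalPolicy(env.spec, goal=env._goal)  # _goal assumed
        self.max_episode_length = max_episode_length

    def train(self, trainer):
        del trainer

    def get_exploration_policy(self):
        return self.policy

    def adapt_policy(self, exploration_policy, exploration_episodes):
        # The real helper presumably infers the goal from the exploration
        # data; returning the already-optimal policy is enough for a sketch.
        del exploration_policy, exploration_episodes
        return self.policy
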
def test_pickle_meta_evaluator():
    set_seed(100)
    tasks = SetTaskSampler(lambda: GarageEnv(PointEnv()))
    max_path_length = 200
    env = GarageEnv(PointEnv())
    n_traj = 3
    with tempfile.TemporaryDirectory() as log_dir_name:
        runner = LocalRunner(
            SnapshotConfig(snapshot_dir=log_dir_name,
                           snapshot_mode='last',
                           snapshot_gap=1))
        meta_eval = MetaEvaluator(test_task_sampler=tasks,
                                  max_path_length=max_path_length,
                                  n_test_tasks=10,
                                  n_exploration_traj=n_traj)
        policy = RandomPolicy(env.spec.action_space)
        algo = MockAlgo(env, policy, max_path_length, n_traj, meta_eval)
        runner.setup(algo, env)
        log_file = tempfile.NamedTemporaryFile()
        csv_output = CsvOutput(log_file.name)
        logger.add_output(csv_output)
        meta_eval.evaluate(algo)
        meta_eval_pickle = cloudpickle.dumps(meta_eval)
        meta_eval2 = cloudpickle.loads(meta_eval_pickle)
        meta_eval2.evaluate(algo)
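
# RandomPolicy and MockAlgo, used by the tests in this section, also live
# elsewhere in the test module. A hedged sketch follows, written against the
# older garage API used here (max_path_length / n_exploration_traj). The
# sampler_cls attribute and constructor layout are assumptions inferred from
# how the tests call MockAlgo(env, policy, max_path_length, n_traj, meta_eval).
from garage.sampler import LocalSampler  # import path assumed


class RandomPolicy:
    """Sample actions uniformly from the action space (sketch)."""

    def __init__(self, action_space):
        self._action_space = action_space

    def reset(self):
        pass

    def get_action(self, observation):
        del observation
        return self._action_space.sample(), {}


class MockAlgo:
    """Minimal algorithm stub that MetaEvaluator can evaluate (sketch)."""

    sampler_cls = LocalSampler  # assumed

    def __init__(self, env, policy, max_path_length, n_traj, meta_eval):
        self.env = env
        self.policy = policy
        self.max_path_length = max_path_length
        self.n_exploration_traj = n_traj
        self.meta_eval = meta_eval

    def train(self, runner):
        # The tests only exercise evaluation; training is a no-op here.
        del runner

    def get_exploration_policy(self):
        return self.policy

    def adapt_policy(self, exploration_policy, exploration_trajectories):
        # A real algorithm would adapt using the exploration trajectories.
        del exploration_trajectories
        return exploration_policy
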
def test_meta_evaluator_with_tf():
    set_seed(100)
    tasks = SetTaskSampler(lambda: GarageEnv(PointEnv()))
    max_path_length = 200
    env = GarageEnv(PointEnv())
    n_traj = 3
    with tempfile.TemporaryDirectory() as log_dir_name:
        ctxt = SnapshotConfig(snapshot_dir=log_dir_name,
                              snapshot_mode='none',
                              snapshot_gap=1)
        with LocalTFRunner(ctxt) as runner:
            meta_eval = MetaEvaluator(test_task_sampler=tasks,
                                      max_path_length=max_path_length,
                                      n_test_tasks=10,
                                      n_exploration_traj=n_traj)
            policy = GaussianMLPPolicy(env.spec)
            algo = MockAlgo(env, policy, max_path_length, n_traj, meta_eval)
            runner.setup(algo, env)
            log_file = tempfile.NamedTemporaryFile()
            csv_output = CsvOutput(log_file.name)
            logger.add_output(csv_output)
            meta_eval.evaluate(algo)
            algo_pickle = cloudpickle.dumps(algo)
        with tf.Graph().as_default():
            with LocalTFRunner(ctxt) as runner:
                algo2 = cloudpickle.loads(algo_pickle)
                runner.setup(algo2, env)
                runner.train(10, 0)
class PEARL(MetaRLAlgorithm): r"""A PEARL model based on https://arxiv.org/abs/1903.08254. PEARL, which stands for Probablistic Embeddings for Actor-Critic Reinforcement Learning, is an off-policy meta-RL algorithm. It is built on top of SAC using two Q-functions and a value function with an addition of an inference network that estimates the posterior :math:`q(z \| c)`. The policy is conditioned on the latent variable Z in order to adpat its behavior to specific tasks. Args: env (list[GarageEnv]): Batch of sampled environment updates(EnvUpdate), which, when invoked on environments, will configure them with new tasks. policy_class (garage.torch.policies.Policy): Context-conditioned policy class. encoder_class (garage.torch.embeddings.ContextEncoder): Encoder class for the encoder in context-conditioned policy. inner_policy (garage.torch.policies.Policy): Policy. qf (torch.nn.Module): Q-function. vf (torch.nn.Module): Value function. num_train_tasks (int): Number of tasks for training. num_test_tasks (int): Number of tasks for testing. latent_dim (int): Size of latent context vector. encoder_hidden_sizes (list[int]): Output dimension of dense layer(s) of the context encoder. test_env_sampler (garage.experiment.SetTaskSampler): Sampler for test tasks. policy_lr (float): Policy learning rate. qf_lr (float): Q-function learning rate. vf_lr (float): Value function learning rate. context_lr (float): Inference network learning rate. policy_mean_reg_coeff (float): Policy mean regulation weight. policy_std_reg_coeff (float): Policy std regulation weight. policy_pre_activation_coeff (float): Policy pre-activation weight. soft_target_tau (float): Interpolation parameter for doing the soft target update. kl_lambda (float): KL lambda value. optimizer_class (callable): Type of optimizer for training networks. use_information_bottleneck (bool): False means latent context is deterministic. use_next_obs_in_context (bool): Whether or not to use next observation in distinguishing between tasks. meta_batch_size (int): Meta batch size. num_steps_per_epoch (int): Number of iterations per epoch. num_initial_steps (int): Number of transitions obtained per task before training. num_tasks_sample (int): Number of random tasks to obtain data for each iteration. num_steps_prior (int): Number of transitions to obtain per task with z ~ prior. num_steps_posterior (int): Number of transitions to obtain per task with z ~ posterior. num_extra_rl_steps_posterior (int): Number of additional transitions to obtain per task with z ~ posterior that are only used to train the policy and NOT the encoder. batch_size (int): Number of transitions in RL batch. embedding_batch_size (int): Number of transitions in context batch. embedding_mini_batch_size (int): Number of transitions in mini context batch; should be same as embedding_batch_size for non-recurrent encoder. max_path_length (int): Maximum path length. discount (float): RL discount factor. replay_buffer_size (int): Maximum samples in replay buffer. reward_scale (int): Reward scale. update_post_train (int): How often to resample context when obtaining data during training (in trajectories). 
""" # pylint: disable=too-many-statements def __init__(self, env, inner_policy, qf, vf, num_train_tasks, num_test_tasks, latent_dim, encoder_hidden_sizes, test_env_sampler, policy_class=ContextConditionedPolicy, encoder_class=MLPEncoder, policy_lr=3E-4, qf_lr=3E-4, vf_lr=3E-4, context_lr=3E-4, policy_mean_reg_coeff=1E-3, policy_std_reg_coeff=1E-3, policy_pre_activation_coeff=0., soft_target_tau=0.005, kl_lambda=.1, optimizer_class=torch.optim.Adam, use_information_bottleneck=True, use_next_obs_in_context=False, meta_batch_size=64, num_steps_per_epoch=1000, num_initial_steps=100, num_tasks_sample=100, num_steps_prior=100, num_steps_posterior=0, num_extra_rl_steps_posterior=100, batch_size=1024, embedding_batch_size=1024, embedding_mini_batch_size=1024, max_path_length=1000, discount=0.99, replay_buffer_size=1000000, reward_scale=1, update_post_train=1): self._env = env self._qf1 = qf self._qf2 = copy.deepcopy(qf) self._vf = vf self._num_train_tasks = num_train_tasks self._num_test_tasks = num_test_tasks self._latent_dim = latent_dim self._policy_mean_reg_coeff = policy_mean_reg_coeff self._policy_std_reg_coeff = policy_std_reg_coeff self._policy_pre_activation_coeff = policy_pre_activation_coeff self._soft_target_tau = soft_target_tau self._kl_lambda = kl_lambda self._use_information_bottleneck = use_information_bottleneck self._use_next_obs_in_context = use_next_obs_in_context self._meta_batch_size = meta_batch_size self._num_steps_per_epoch = num_steps_per_epoch self._num_initial_steps = num_initial_steps self._num_tasks_sample = num_tasks_sample self._num_steps_prior = num_steps_prior self._num_steps_posterior = num_steps_posterior self._num_extra_rl_steps_posterior = num_extra_rl_steps_posterior self._batch_size = batch_size self._embedding_batch_size = embedding_batch_size self._embedding_mini_batch_size = embedding_mini_batch_size self.max_path_length = max_path_length self._discount = discount self._replay_buffer_size = replay_buffer_size self._reward_scale = reward_scale self._update_post_train = update_post_train self._task_idx = None self._is_resuming = False worker_args = dict(deterministic=True, accum_context=True) self._evaluator = MetaEvaluator(test_task_sampler=test_env_sampler, max_path_length=max_path_length, worker_class=PEARLWorker, worker_args=worker_args, n_test_tasks=num_test_tasks) encoder_spec = self.get_env_spec(env[0](), latent_dim, 'encoder') encoder_in_dim = int(np.prod(encoder_spec.input_space.shape)) encoder_out_dim = int(np.prod(encoder_spec.output_space.shape)) context_encoder = encoder_class(input_dim=encoder_in_dim, output_dim=encoder_out_dim, hidden_sizes=encoder_hidden_sizes) self._policy = policy_class( latent_dim=latent_dim, context_encoder=context_encoder, policy=inner_policy, use_information_bottleneck=use_information_bottleneck, use_next_obs=use_next_obs_in_context) # buffer for training RL update self._replay_buffers = { i: PathBuffer(replay_buffer_size) for i in range(num_train_tasks) } self._context_replay_buffers = { i: PathBuffer(replay_buffer_size) for i in range(num_train_tasks) } self.target_vf = copy.deepcopy(self._vf) self.vf_criterion = torch.nn.MSELoss() self._policy_optimizer = optimizer_class( self._policy.networks[1].parameters(), lr=policy_lr, ) self.qf1_optimizer = optimizer_class( self._qf1.parameters(), lr=qf_lr, ) self.qf2_optimizer = optimizer_class( self._qf2.parameters(), lr=qf_lr, ) self.vf_optimizer = optimizer_class( self._vf.parameters(), lr=vf_lr, ) self.context_optimizer = optimizer_class( 
self._policy.networks[0].parameters(), lr=context_lr, ) def __getstate__(self): """Object.__getstate__. Returns: dict: the state to be pickled for the instance. """ data = self.__dict__.copy() del data['_replay_buffers'] del data['_context_replay_buffers'] return data def __setstate__(self, state): """Object.__setstate__. Args: state (dict): unpickled state. """ self.__dict__.update(state) self._replay_buffers = { i: PathBuffer(self._replay_buffer_size) for i in range(self._num_train_tasks) } self._context_replay_buffers = { i: PathBuffer(self._replay_buffer_size) for i in range(self._num_train_tasks) } self._is_resuming = True def train(self, runner): """Obtain samples, train, and evaluate for each epoch. Args: runner (LocalRunner): LocalRunner is passed to give algorithm the access to runner.step_epochs(), which provides services such as snapshotting and sampler control. """ for _ in runner.step_epochs(): epoch = runner.step_itr / self._num_steps_per_epoch # obtain initial set of samples from all train tasks if epoch == 0 or self._is_resuming: for idx in range(self._num_train_tasks): self._task_idx = idx self._obtain_samples(runner, epoch, self._num_initial_steps, np.inf) self._is_resuming = False # obtain samples from random tasks for _ in range(self._num_tasks_sample): idx = np.random.randint(self._num_train_tasks) self._task_idx = idx self._context_replay_buffers[idx].clear() # obtain samples with z ~ prior if self._num_steps_prior > 0: self._obtain_samples(runner, epoch, self._num_steps_prior, np.inf) # obtain samples with z ~ posterior if self._num_steps_posterior > 0: self._obtain_samples(runner, epoch, self._num_steps_posterior, self._update_post_train) # obtain extras samples for RL training but not encoder if self._num_extra_rl_steps_posterior > 0: self._obtain_samples(runner, epoch, self._num_extra_rl_steps_posterior, self._update_post_train, add_to_enc_buffer=False) logger.log('Training...') # sample train tasks and optimize networks self._train_once() runner.step_itr += 1 logger.log('Evaluating...') # evaluate self._policy.reset_belief() self._evaluator.evaluate(self) def _train_once(self): """Perform one iteration of training.""" for _ in range(self._num_steps_per_epoch): indices = np.random.choice(range(self._num_train_tasks), self._meta_batch_size) self._optimize_policy(indices) def _optimize_policy(self, indices): """Perform algorithm optimizing. Args: indices (list): Tasks used for training. 
""" num_tasks = len(indices) context = self._sample_context(indices) # clear context and reset belief of policy self._policy.reset_belief(num_tasks=num_tasks) # data shape is (task, batch, feat) obs, actions, rewards, next_obs, terms = self._sample_data(indices) policy_outputs, task_z = self._policy(obs, context) new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4] # flatten out the task dimension t, b, _ = obs.size() obs = obs.view(t * b, -1) actions = actions.view(t * b, -1) next_obs = next_obs.view(t * b, -1) # optimize qf and encoder networks q1_pred = self._qf1(torch.cat([obs, actions], dim=1), task_z) q2_pred = self._qf2(torch.cat([obs, actions], dim=1), task_z) v_pred = self._vf(obs, task_z.detach()) with torch.no_grad(): target_v_values = self.target_vf(next_obs, task_z) # KL constraint on z if probabilistic self.context_optimizer.zero_grad() if self._use_information_bottleneck: kl_div = self._policy.compute_kl_div() kl_loss = self._kl_lambda * kl_div kl_loss.backward(retain_graph=True) self.qf1_optimizer.zero_grad() self.qf2_optimizer.zero_grad() rewards_flat = rewards.view(self._batch_size * num_tasks, -1) rewards_flat = rewards_flat * self._reward_scale terms_flat = terms.view(self._batch_size * num_tasks, -1) q_target = rewards_flat + ( 1. - terms_flat) * self._discount * target_v_values qf_loss = torch.mean((q1_pred - q_target)**2) + torch.mean( (q2_pred - q_target)**2) qf_loss.backward() self.qf1_optimizer.step() self.qf2_optimizer.step() self.context_optimizer.step() # compute min Q on the new actions q1 = self._qf1(torch.cat([obs, new_actions], dim=1), task_z.detach()) q2 = self._qf2(torch.cat([obs, new_actions], dim=1), task_z.detach()) min_q = torch.min(q1, q2) # optimize vf v_target = min_q - log_pi vf_loss = self.vf_criterion(v_pred, v_target.detach()) self.vf_optimizer.zero_grad() vf_loss.backward() self.vf_optimizer.step() self._update_target_network() # optimize policy log_policy_target = min_q policy_loss = (log_pi - log_policy_target).mean() mean_reg_loss = self._policy_mean_reg_coeff * (policy_mean**2).mean() std_reg_loss = self._policy_std_reg_coeff * (policy_log_std**2).mean() pre_tanh_value = policy_outputs[-1] pre_activation_reg_loss = self._policy_pre_activation_coeff * ( (pre_tanh_value**2).sum(dim=1).mean()) policy_reg_loss = (mean_reg_loss + std_reg_loss + pre_activation_reg_loss) policy_loss = policy_loss + policy_reg_loss self._policy_optimizer.zero_grad() policy_loss.backward() self._policy_optimizer.step() def _obtain_samples(self, runner, itr, num_samples, update_posterior_rate, add_to_enc_buffer=True): """Obtain samples. Args: runner (LocalRunner): LocalRunner. itr (int): Index of iteration (epoch). num_samples (int): Number of samples to obtain. update_posterior_rate (int): How often (in trajectories) to infer posterior of policy. add_to_enc_buffer (bool): Whether or not to add samples to encoder buffer. 
""" self._policy.reset_belief() total_samples = 0 if update_posterior_rate != np.inf: num_samples_per_batch = (update_posterior_rate * self.max_path_length) else: num_samples_per_batch = num_samples while total_samples < num_samples: paths = runner.obtain_samples(itr, num_samples_per_batch, self._policy, self._env[self._task_idx]) total_samples += sum([len(path['rewards']) for path in paths]) for path in paths: p = { 'observations': path['observations'], 'actions': path['actions'], 'rewards': path['rewards'].reshape(-1, 1), 'next_observations': path['next_observations'], 'dones': path['dones'].reshape(-1, 1) } self._replay_buffers[self._task_idx].add_path(p) if add_to_enc_buffer: self._context_replay_buffers[self._task_idx].add_path(p) if update_posterior_rate != np.inf: context = self._sample_context(self._task_idx) self._policy.infer_posterior(context) def _sample_data(self, indices): """Sample batch of training data from a list of tasks. Args: indices (list): List of task indices to sample from. Returns: torch.Tensor: Obervations, with shape :math:`(X, N, O^*)` where X is the number of tasks. N is batch size. torch.Tensor: Actions, with shape :math:`(X, N, A^*)`. torch.Tensor: Rewards, with shape :math:`(X, N, 1)`. torch.Tensor: Next obervations, with shape :math:`(X, N, O^*)`. torch.Tensor: Dones, with shape :math:`(X, N, 1)`. """ # transitions sampled randomly from replay buffer initialized = False for idx in indices: batch = self._replay_buffers[idx].sample_transitions( self._batch_size) if not initialized: o = batch['observations'][np.newaxis] a = batch['actions'][np.newaxis] r = batch['rewards'][np.newaxis] no = batch['next_observations'][np.newaxis] d = batch['dones'][np.newaxis] initialized = True else: o = np.vstack((o, batch['observations'][np.newaxis])) a = np.vstack((a, batch['actions'][np.newaxis])) r = np.vstack((r, batch['rewards'][np.newaxis])) no = np.vstack((no, batch['next_observations'][np.newaxis])) d = np.vstack((d, batch['dones'][np.newaxis])) o = torch.as_tensor(o, device=tu.global_device()).float() a = torch.as_tensor(a, device=tu.global_device()).float() r = torch.as_tensor(r, device=tu.global_device()).float() no = torch.as_tensor(no, device=tu.global_device()).float() d = torch.as_tensor(d, device=tu.global_device()).float() return o, a, r, no, d def _sample_context(self, indices): """Sample batch of context from a list of tasks. Args: indices (list): List of task indices to sample from. Returns: torch.Tensor: Context data, with shape :math:`(X, N, C)`. X is the number of tasks. N is batch size. C is the combined size of observation, action, reward, and next observation if next observation is used in context. Otherwise, C is the combined size of observation, action, and reward. 
""" # make method work given a single task index if not hasattr(indices, '__iter__'): indices = [indices] initialized = False for idx in indices: batch = self._context_replay_buffers[idx].sample_transitions( self._embedding_batch_size) o = batch['observations'] a = batch['actions'] r = batch['rewards'] context = np.hstack((np.hstack((o, a)), r)) if self._use_next_obs_in_context: context = np.hstack((context, batch['next_observations'])) if not initialized: final_context = context[np.newaxis] initialized = True else: final_context = np.vstack((final_context, context[np.newaxis])) final_context = torch.as_tensor(final_context, device=tu.global_device()).float() if len(indices) == 1: final_context = final_context.unsqueeze(0) return final_context def _update_target_network(self): """Update parameters in the target vf network.""" for target_param, param in zip(self.target_vf.parameters(), self._vf.parameters()): target_param.data.copy_(target_param.data * (1.0 - self._soft_target_tau) + param.data * self._soft_target_tau) @property def policy(self): """Return all the policy within the model. Returns: garage.torch.policies.Policy: Policy within the model. """ return self._policy @property def networks(self): """Return all the networks within the model. Returns: list: A list of networks. """ return self._policy.networks + [self._policy] + [ self._qf1, self._qf2, self._vf, self.target_vf ] def get_exploration_policy(self): """Return a policy used before adaptation to a specific task. Each time it is retrieved, this policy should only be evaluated in one task. Returns: garage.Policy: The policy used to obtain samples that are later used for meta-RL adaptation. """ return self._policy def adapt_policy(self, exploration_policy, exploration_trajectories): """Produce a policy adapted for a task. Args: exploration_policy (garage.Policy): A policy which was returned from get_exploration_policy(), and which generated exploration_trajectories by interacting with an environment. The caller may not use this object after passing it into this method. exploration_trajectories (garage.TrajectoryBatch): Trajectories to adapt to, generated by exploration_policy exploring the environment. Returns: garage.Policy: A policy adapted to the task represented by the exploration_trajectories. """ total_steps = sum(exploration_trajectories.lengths) o = exploration_trajectories.observations a = exploration_trajectories.actions r = exploration_trajectories.rewards.reshape(total_steps, 1) ctxt = np.hstack((o, a, r)).reshape(1, total_steps, -1) context = torch.as_tensor(ctxt, device=tu.global_device()).float() self._policy.infer_posterior(context) return self._policy def to(self, device=None): """Put all the networks within the model on device. Args: device (str): ID of GPU or CPU. """ device = device or tu.global_device() for net in self.networks: net.to(device) @classmethod def augment_env_spec(cls, env_spec, latent_dim): """Augment environment by a size of latent dimension. Args: env_spec (garage.envs.EnvSpec): Environment specs to be augmented. latent_dim (int): Latent dimension. Returns: garage.envs.EnvSpec: Augmented environment specs. 
""" obs_dim = int(np.prod(env_spec.observation_space.shape)) action_dim = int(np.prod(env_spec.action_space.shape)) aug_obs = akro.Box(low=-1, high=1, shape=(obs_dim + latent_dim, ), dtype=np.float32) aug_act = akro.Box(low=-1, high=1, shape=(action_dim, ), dtype=np.float32) return EnvSpec(aug_obs, aug_act) @classmethod def get_env_spec(cls, env_spec, latent_dim, module): """Get environment specs of encoder with latent dimension. Args: env_spec (garage.envs.EnvSpec): Environment specs. latent_dim (int): Latent dimension. module (str): Module to get environment specs for. Returns: garage.envs.InOutSpec: Module environment specs with latent dimension. """ obs_dim = int(np.prod(env_spec.observation_space.shape)) action_dim = int(np.prod(env_spec.action_space.shape)) if module == 'encoder': in_dim = obs_dim + action_dim + 1 out_dim = latent_dim * 2 elif module == 'vf': in_dim = obs_dim out_dim = latent_dim in_space = akro.Box(low=-1, high=1, shape=(in_dim, ), dtype=np.float32) out_space = akro.Box(low=-1, high=1, shape=(out_dim, ), dtype=np.float32) if module == 'encoder': spec = InOutSpec(in_space, out_space) elif module == 'vf': spec = EnvSpec(in_space, out_space) return spec
class PEARLSAC2(MetaRLAlgorithm): """A PEARL model based on https://arxiv.org/abs/1903.08254. PEARL, which stands for Probablistic Embeddings for Actor-Critic Reinforcement Learning, is an off-policy meta-RL algorithm. It is built on top of SAC using two Q-functions and a value function with an addition of an inference network that estimates the posterior :math:`q(z \| c)`. The policy is conditioned on the latent variable Z in order to adpat its behavior to specific tasks. Args: env (list[GarageEnv]): Batch of sampled environment updates(EnvUpdate), which, when invoked on environments, will configure them with new tasks. policy_class (garage.torch.policies.Policy): Context-conditioned policy class. encoder_class (garage.torch.embeddings.ContextEncoder): Encoder class for the encoder in context-conditioned policy. inner_policy (garage.torch.policies.Policy): Policy. qf (torch.nn.Module): Q-function. vf (torch.nn.Module): Value function. num_train_tasks (int): Number of tasks for training. num_test_tasks (int): Number of tasks for testing. latent_dim (int): Size of latent context vector. encoder_hidden_sizes (list[int]): Output dimension of dense layer(s) of the context encoder. test_env_sampler (garage.experiment.SetTaskSampler): Sampler for test tasks. policy_lr (float): Policy learning rate. qf_lr (float): Q-function learning rate. vf_lr (float): Value function learning rate. context_lr (float): Inference network learning rate. policy_mean_reg_coeff (float): Policy mean regulation weight. policy_std_reg_coeff (float): Policy std regulation weight. policy_pre_activation_coeff (float): Policy pre-activation weight. soft_target_tau (float): Interpolation parameter for doing the soft target update. kl_lambda (float): KL lambda value. optimizer_class (callable): Type of optimizer for training networks. use_information_bottleneck (bool): False means latent context is deterministic. use_next_obs_in_context (bool): Whether or not to use next observation in distinguishing between tasks. meta_batch_size (int): Meta batch size. num_steps_per_epoch (int): Number of iterations per epoch. num_initial_steps (int): Number of transitions obtained per task before training. num_tasks_sample (int): Number of random tasks to obtain data for each iteration. num_steps_prior (int): Number of transitions to obtain per task with z ~ prior. num_steps_posterior (int): Number of transitions to obtain per task with z ~ posterior. num_extra_rl_steps_posterior (int): Number of additional transitions to obtain per task with z ~ posterior that are only used to train the policy and NOT the encoder. batch_size (int): Number of transitions in RL batch. embedding_batch_size (int): Number of transitions in context batch. embedding_mini_batch_size (int): Number of transitions in mini context batch; should be same as embedding_batch_size for non-recurrent encoder. max_path_length (int): Maximum path length. discount (float): RL discount factor. replay_buffer_size (int): Maximum samples in replay buffer. reward_scale (int): Reward scale. update_post_train (int): How often to resample context when obtaining data during training (in trajectories). 
""" # pylint: disable=too-many-statements def __init__(self, env, inner_policy, qf1, qf2, sampler, num_train_tasks, num_test_tasks, latent_dim, encoder_hidden_sizes, test_env_sampler, policy_class=ContextConditionedPolicy, encoder_class=MLPEncoder, policy_lr=3E-4, qf_lr=3E-4, context_lr=3E-4, policy_mean_reg_coeff=1E-3, policy_std_reg_coeff=1E-3, policy_pre_activation_coeff=0., soft_target_tau=0.005, kl_lambda=.1, fixed_alpha=None, target_entropy=None, initial_log_entropy=0., optimizer_class=torch.optim.Adam, use_information_bottleneck=True, use_next_obs_in_context=False, meta_batch_size=64, num_steps_per_epoch=1000, num_initial_steps=100, num_tasks_sample=100, num_steps_prior=100, num_steps_posterior=0, num_extra_rl_steps_posterior=100, batch_size=1024, embedding_batch_size=1024, embedding_mini_batch_size=1024, max_path_length=500, discount=0.99, replay_buffer_size=1000000, single_alpha=False, reward_scale=1, update_post_train=1): self._env = env self._qf1 = qf1 self._qf2 = qf2 # use 2 target q networks self._target_qf1 = copy.deepcopy(self._qf1) self._target_qf2 = copy.deepcopy(self._qf2) self._num_train_tasks = num_train_tasks self._num_test_tasks = num_test_tasks self._latent_dim = latent_dim self._policy_lr = policy_lr self._qf_lr = qf_lr self._context_lr = context_lr self._policy_mean_reg_coeff = policy_mean_reg_coeff self._policy_std_reg_coeff = policy_std_reg_coeff self._policy_pre_activation_coeff = policy_pre_activation_coeff self._soft_target_tau = soft_target_tau self._kl_lambda = kl_lambda self._use_information_bottleneck = use_information_bottleneck self._use_next_obs_in_context = use_next_obs_in_context self._meta_batch_size = meta_batch_size self._num_steps_per_epoch = num_steps_per_epoch self._num_initial_steps = num_initial_steps self._num_tasks_sample = num_tasks_sample self._num_steps_prior = num_steps_prior self._num_steps_posterior = num_steps_posterior self._num_extra_rl_steps_posterior = num_extra_rl_steps_posterior self._batch_size = batch_size self._embedding_batch_size = embedding_batch_size self._embedding_mini_batch_size = embedding_mini_batch_size self.max_path_length = max_path_length self._discount = discount self._replay_buffer_size = replay_buffer_size self._reward_scale = reward_scale self._update_post_train = update_post_train self._task_idx = None self._is_resuming = False self._sampler = sampler self._optimizer_class = optimizer_class self._single_alpha = single_alpha worker_args = dict(deterministic=True, accum_context=True) self._evaluator = MetaEvaluator(test_task_sampler=test_env_sampler, worker_class=PEARLSAC2Worker, worker_args=worker_args, n_test_tasks=num_test_tasks) env_spec = env[0]() encoder_spec = self.get_env_spec(env_spec, latent_dim, 'encoder') encoder_in_dim = int(np.prod(encoder_spec.input_space.shape)) encoder_out_dim = int(np.prod(encoder_spec.output_space.shape)) context_encoder = encoder_class(input_dim=encoder_in_dim, output_dim=encoder_out_dim, hidden_sizes=encoder_hidden_sizes) # Automatic entropy coefficient tuning self._use_automatic_entropy_tuning = fixed_alpha is None self._initial_log_entropy = initial_log_entropy self._fixed_alpha = fixed_alpha if self._use_automatic_entropy_tuning: if target_entropy: self._target_entropy = target_entropy else: self._target_entropy = -np.prod( env_spec.action_space.shape).item() if self._single_alpha: self._log_alpha = torch.Tensor([self._initial_log_entropy ]).requires_grad_() else: self._log_alpha = torch.Tensor( [self._initial_log_entropy] * self._num_train_tasks).requires_grad_() 
self._alpha_optimizer = self._optimizer_class([self._log_alpha], lr=self._policy_lr) else: if self._single_alpha: self._log_alpha = torch.Tensor([self._fixed_alpha]).log() else: self._log_alpha = torch.Tensor([self._fixed_alpha] * self._num_train_tasks).log() self._policy = policy_class( latent_dim=latent_dim, context_encoder=context_encoder, policy=inner_policy, use_information_bottleneck=use_information_bottleneck, use_next_obs=use_next_obs_in_context) # buffer for training RL update self._replay_buffers = { i: PathBuffer(self._replay_buffer_size) for i in range(num_train_tasks) } self._context_replay_buffers = { i: PathBuffer(self._replay_buffer_size) for i in range(num_train_tasks) } self._policy_optimizer = optimizer_class( self._policy.networks[1].parameters(), lr=self._policy_lr, ) self.qf1_optimizer = optimizer_class( self._qf1.parameters(), lr=self._qf_lr, ) self.qf2_optimizer = optimizer_class( self._qf2.parameters(), lr=self._qf_lr, ) self.context_optimizer = optimizer_class( self._policy.networks[0].parameters(), lr=self._context_lr, ) def __getstate__(self): """Object.__getstate__. Returns: dict: the state to be pickled for the instance. """ data = self.__dict__.copy() del data['_replay_buffers'] del data['_context_replay_buffers'] return data def __setstate__(self, state): """Object.__setstate__. Args: state (dict): unpickled state. """ self.__dict__.update(state) self._replay_buffers = { i: PathBuffer(self._replay_buffer_size) for i in range(self._num_train_tasks) } self._context_replay_buffers = { i: PathBuffer(self._replay_buffer_size) for i in range(self._num_train_tasks) } self._is_resuming = True def update_env(self, env, evaluator, num_train_tasks, num_test_tasks): print("Updating environments") self._env = env self._evaluator = evaluator self._num_train_tasks = num_train_tasks self._num_test_tasks = num_test_tasks # buffer for training RL update self._replay_buffers = { i: PathBuffer(self._replay_buffer_size) for i in range(num_train_tasks) } self._context_replay_buffers = { i: PathBuffer(self._replay_buffer_size) for i in range(num_train_tasks) } self._task_idx = 0 print("Updated with new envipickleronment setup") self._policy_optimizer = torch.optim.Adam( self._policy.networks[1].parameters(), lr=3E-4, ) self.qf1_optimizer = torch.optim.Adam( self._qf1.parameters(), lr=3E-4, ) self.qf2_optimizer = torch.optim.Adam( self._qf2.parameters(), lr=3E-4, ) self.vf_optimizer = torch.optim.Adam( self._vf.parameters(), lr=3E-4, ) self.context_optimizer = torch.optim.Adam( self._policy.networks[0].parameters(), lr=3E-4, ) print('Reset optimizer state') def fill_expert_traj(self, expert_traj_dir): print("Filling Expert trajectory to replay buffer") from os import listdir from os.path import isfile, join expert_traj_paths = [ join(expert_traj_dir, f) for f in listdir(expert_traj_dir) if isfile(join(expert_traj_dir, f)) ] expert_trajs = [] for exp_path in expert_traj_paths: with open(exp_path, 'rb') as handle: data = pickle.load(handle) expert_trajs.append(data) for path in expert_trajs: p = { 'observations': path['observations'], 'actions': path['actions'], 'rewards': path['rewards'].reshape(-1, 1), 'next_observations': path['next_observations'], 'dones': path['dones'].reshape(-1, 1) } self._replay_buffers[self._task_idx].add_path(p) self._context_replay_buffers[self._task_idx].add_path(p) def adapt_expert_traj(self, runner): """Obtain samples, train, and evaluate for each epoch. 
Args: runner (LocalRunner): LocalRunner is passed to give algorithm the access to runner.step_epochs(), which provides services such as snapshotting and sampler control. """ for _ in runner.step_epochs(): logger.log('Adapting Policy {}...'.format(runner.step_itr)) self._train_once() runner.step_itr += 1 logger.log('Evaluating...') # evaluate self._policy.reset_belief() self._evaluator.evaluate(self) def reset(self): self._policy.reset_belief() def train(self, runner): """Obtain samples, train, and evaluate for each epoch. Args: runner (LocalRunner): LocalRunner is passed to give algorithm the access to runner.step_epochs(), which provides services such as snapshotting and sampler control. """ for _ in runner.step_epochs(): epoch = runner.step_itr / self._num_steps_per_epoch # obtain initial set of samples from all train tasks if epoch == 0 or self._is_resuming: for idx in range(self._num_train_tasks): self._task_idx = idx self._obtain_samples(runner, epoch, self._num_initial_steps, np.inf) self._is_resuming = False # obtain samples from random tasks for _ in range(self._num_tasks_sample): idx = np.random.randint(self._num_train_tasks) self._task_idx = idx self._context_replay_buffers[idx].clear() # obtain samples with z ~ prior if self._num_steps_prior > 0: self._obtain_samples(runner, epoch, self._num_steps_prior, np.inf) # obtain samples with z ~ posterior if self._num_steps_posterior > 0: self._obtain_samples(runner, epoch, self._num_steps_posterior, self._update_post_train) # obtain extras samples for RL training but not encoder if self._num_extra_rl_steps_posterior > 0: self._obtain_samples(runner, epoch, self._num_extra_rl_steps_posterior, self._update_post_train, add_to_enc_buffer=False) logger.log('Training...') # sample train tasks and optimize networks self._train_once() runner.step_itr += 1 logger.log('Evaluating...') # evaluate self._policy.reset_belief() self._evaluator.evaluate(self) def _train_once(self): """Perform one iteration of training.""" policy_loss_list = [] qf_loss_list = [] alpha_loss_list = [] alpha_list = [] for _ in range(self._num_steps_per_epoch): indices = np.random.choice(range(self._num_train_tasks), self._meta_batch_size) policy_loss, qf_loss, alpha_loss, alpha = self._optimize_policy( indices) policy_loss_list.append(policy_loss) qf_loss_list.append(qf_loss) alpha_loss_list.append(alpha_loss) alpha_list.append(alpha) with tabular.prefix('MetaTrain/Average/'): tabular.record('PolicyLoss', np.average(np.array(policy_loss_list))) tabular.record('QfLoss', np.average(np.array(qf_loss_list))) tabular.record('AlphaLoss', np.average(np.array(alpha_loss_list))) tabular.record('AlphaLoss', np.average(np.array(alpha_loss_list))) tabular.record('Alpha', np.average(np.array(alpha_list))) def _optimize_policy(self, indices): """Perform algorithm optimizing. Args: indices (list): Tasks used for training. 
""" num_tasks = len(indices) context = self._sample_context(indices) # clear context and reset belief of policy self._policy.reset_belief(num_tasks=num_tasks) # data shape is (task, batch, feat) obs, actions, rewards, next_obs, terms = self._sample_data(indices) # flatten out the task dimension t, b, _ = obs.size() batch_obs = obs.view(t * b, -1) batch_action = actions.view(t * b, -1) batch_next_obs = next_obs.view(t * b, -1) policy_outputs, task_z = self._policy(next_obs, context) new_next_actions, policy_mean, policy_log_std, log_pi, pre_tanh = policy_outputs # ===== Critic Objective ===== with torch.no_grad(): alpha = self._get_log_alpha(indices).exp() q1_pred = self._qf1(torch.cat([batch_obs, batch_action], dim=1), task_z) q2_pred = self._qf2(torch.cat([batch_obs, batch_action], dim=1), task_z) target_q_values = torch.min( self._target_qf1( torch.cat([batch_next_obs, new_next_actions], dim=1), task_z), self._target_qf2( torch.cat([batch_next_obs, new_next_actions], dim=1), task_z)).flatten() - (alpha * log_pi.flatten()) rewards_flat = rewards.view(self._batch_size * num_tasks, -1).flatten() rewards_flat = rewards_flat * self._reward_scale terms_flat = terms.view(self._batch_size * num_tasks, -1).flatten() with torch.no_grad(): q_target = rewards_flat + ( (1. - terms_flat) * self._discount) * target_q_values qf1_loss = F.mse_loss(q1_pred.flatten(), q_target) qf2_loss = F.mse_loss(q2_pred.flatten(), q_target) qf_loss = qf1_loss + qf2_loss # KL constraint on z if probabilistic self.context_optimizer.zero_grad() if self._use_information_bottleneck: kl_div = self._policy.compute_kl_div() kl_loss = self._kl_lambda * kl_div kl_loss.backward(retain_graph=True) # Optimize Q network and context encoder self.qf1_optimizer.zero_grad() self.qf2_optimizer.zero_grad() qf_loss.backward() self.qf1_optimizer.step() self.qf2_optimizer.step() self.context_optimizer.step() # ===== Actor Objective ===== policy_outputs, task_z = self._policy(obs, context) new_actions, policy_mean, policy_log_std, log_pi, pre_tanh = policy_outputs # compute min Q on the new actions q1 = self._qf1(torch.cat([batch_obs, new_actions], dim=1), task_z.detach()) q2 = self._qf2(torch.cat([batch_obs, new_actions], dim=1), task_z.detach()) min_q = torch.min(q1, q2) # optimize policy policy_loss = ((alpha * log_pi) - min_q.flatten()).mean() mean_reg_loss = self._policy_mean_reg_coeff * (policy_mean**2).mean() std_reg_loss = self._policy_std_reg_coeff * (policy_log_std**2).mean() pre_tanh_value = policy_outputs[-1] pre_activation_reg_loss = self._policy_pre_activation_coeff * ( (pre_tanh_value**2).sum(dim=1).mean()) policy_reg_loss = (mean_reg_loss + std_reg_loss + pre_activation_reg_loss) policy_loss += policy_reg_loss self._policy_optimizer.zero_grad() policy_loss.backward() self._policy_optimizer.step() # ===== Temperature Objective ===== if self._use_automatic_entropy_tuning: alpha = (self._get_log_alpha(indices)).exp() alpha_loss = (-alpha * (log_pi.detach() + self._target_entropy)).mean() self._alpha_optimizer.zero_grad() alpha_loss.backward() self._alpha_optimizer.step() alpha_avg_cpu = np.average(alpha.detach().cpu().numpy()) alpha_loss_cpu = alpha_loss.detach().cpu().numpy() # ===== Update Target Network ===== target_qfs = [self._target_qf1, self._target_qf2] qfs = [self._qf1, self._qf2] for target_qf, qf in zip(target_qfs, qfs): for t_param, param in zip(target_qf.parameters(), qf.parameters()): t_param.data.copy_(t_param.data * (1.0 - self._soft_target_tau) + param.data * self._soft_target_tau) qf_loss_cpu = 
qf_loss.detach().cpu().numpy() policy_loss_cpu = policy_loss.detach().cpu().numpy() return policy_loss_cpu, qf_loss_cpu, alpha_loss_cpu, alpha_avg_cpu def _obtain_samples(self, runner, itr, num_samples, update_posterior_rate, add_to_enc_buffer=True): """Obtain samples. Args: runner (LocalRunner): LocalRunner. itr (int): Index of iteration (epoch). num_samples (int): Number of samples to obtain. update_posterior_rate (int): How often (in trajectories) to infer posterior of policy. add_to_enc_buffer (bool): Whether or not to add samples to encoder buffer. """ self._policy.reset_belief() total_samples = 0 if update_posterior_rate != np.inf: num_samples_per_batch = (update_posterior_rate * self.max_path_length) else: num_samples_per_batch = num_samples while total_samples < num_samples: paths = runner.obtain_samples(itr, num_samples_per_batch, self._policy, self._env[self._task_idx]) total_samples += sum([len(path['rewards']) for path in paths]) for path in paths: terminations = np.array([ step_type == StepType.TERMINAL for step_type in path['step_types'] ]).reshape(-1, 1) p = { 'observations': path['observations'], 'actions': path['actions'], 'rewards': path['rewards'].reshape(-1, 1), 'next_observations': path['next_observations'], 'dones': terminations } self._replay_buffers[self._task_idx].add_path(p) if add_to_enc_buffer: self._context_replay_buffers[self._task_idx].add_path(p) if update_posterior_rate != np.inf: context = self._sample_context(self._task_idx) self._policy.infer_posterior(context) def _sample_data(self, indices): """Sample batch of training data from a list of tasks. Args: indices (list): List of task indices to sample from. Returns: torch.Tensor: Obervations, with shape :math:`(X, N, O^*)` where X is the number of tasks. N is batch size. torch.Tensor: Actions, with shape :math:`(X, N, A^*)`. torch.Tensor: Rewards, with shape :math:`(X, N, 1)`. torch.Tensor: Next obervations, with shape :math:`(X, N, O^*)`. torch.Tensor: Dones, with shape :math:`(X, N, 1)`. """ # transitions sampled randomly from replay buffer initialized = False for idx in indices: batch = self._replay_buffers[idx].sample_transitions( self._batch_size) if not initialized: o = batch['observations'][np.newaxis] a = batch['actions'][np.newaxis] r = batch['rewards'][np.newaxis] no = batch['next_observations'][np.newaxis] d = batch['dones'][np.newaxis] initialized = True else: o = np.vstack((o, batch['observations'][np.newaxis])) a = np.vstack((a, batch['actions'][np.newaxis])) r = np.vstack((r, batch['rewards'][np.newaxis])) no = np.vstack((no, batch['next_observations'][np.newaxis])) d = np.vstack((d, batch['dones'][np.newaxis])) o = torch.as_tensor(o, device=global_device()).float() a = torch.as_tensor(a, device=global_device()).float() r = torch.as_tensor(r, device=global_device()).float() no = torch.as_tensor(no, device=global_device()).float() d = torch.as_tensor(d, device=global_device()).float() return o, a, r, no, d def _sample_context(self, indices): """Sample batch of context from a list of tasks. Args: indices (list): List of task indices to sample from. Returns: torch.Tensor: Context data, with shape :math:`(X, N, C)`. X is the number of tasks. N is batch size. C is the combined size of observation, action, reward, and next observation if next observation is used in context. Otherwise, C is the combined size of observation, action, and reward. 
""" # make method work given a single task index if not hasattr(indices, '__iter__'): indices = [indices] initialized = False for idx in indices: batch = self._context_replay_buffers[idx].sample_transitions( self._embedding_batch_size) o = batch['observations'] a = batch['actions'] r = batch['rewards'] context = np.hstack((np.hstack((o, a)), r)) if self._use_next_obs_in_context: context = np.hstack((context, batch['next_observations'])) if not initialized: final_context = context[np.newaxis] initialized = True else: final_context = np.vstack((final_context, context[np.newaxis])) final_context = torch.as_tensor(final_context, device=global_device()).float() if len(indices) == 1: final_context = final_context.unsqueeze(0) return final_context @property def policy(self): """Return all the policy within the model. Returns: garage.torch.policies.Policy: Policy within the model. """ return self._policy @property def networks(self): """Return all the networks within the model. Returns: list: A list of networks. """ return self._policy.networks + [ self._policy, self._qf1, self._qf2, self._target_qf1, self._target_qf2 ] def get_exploration_policy(self): """Return a policy used before adaptation to a specific task. Each time it is retrieved, this policy should only be evaluated in one task. Returns: garage.Policy: The policy used to obtain samples that are later used for meta-RL adaptation. """ return self._policy def adapt_policy(self, exploration_policy, exploration_trajectories): """Produce a policy adapted for a task. Args: exploration_policy (garage.Policy): A policy which was returned from get_exploration_policy(), and which generated exploration_trajectories by interacting with an environment. The caller may not use this object after passing it into this method. exploration_trajectories (garage.TrajectoryBatch): Trajectories to adapt to, generated by exploration_policy exploring the environment. Returns: garage.Policy: A policy adapted to the task represented by the exploration_trajectories. """ total_steps = sum(exploration_trajectories.lengths) o = exploration_trajectories.observations a = exploration_trajectories.actions r = exploration_trajectories.rewards.reshape(total_steps, 1) ctxt = np.hstack((o, a, r)).reshape(1, total_steps, -1) context = torch.as_tensor(ctxt, device=global_device()).float() self._policy.infer_posterior(context) return self._policy def to(self, device=None): """Put all the networks within the model on device. Args: device (str): ID of GPU or CPU. """ device = device or global_device() for net in self.networks: net.to(device) if self._use_automatic_entropy_tuning: if self._single_alpha: self._log_alpha = torch.cuda.FloatTensor( [self._initial_log_entropy], device=global_device()).requires_grad_() else: self._log_alpha = torch.cuda.FloatTensor( [self._initial_log_entropy] * self._num_train_tasks, device=global_device()).requires_grad_() self._alpha_optimizer = self._optimizer_class([self._log_alpha], lr=self._policy_lr) else: if self._single_alpha: self._log_alpha = torch.cuda.FloatTensor( [self._fixed_alpha], device=global_device()).log() else: self._log_alpha = torch.cuda.FloatTensor( [self._fixed_alpha] * self._num_train_tasks, device=global_device()).log() @classmethod def augment_env_spec(cls, env_spec, latent_dim): """Augment environment by a size of latent dimension. Args: env_spec (garage.envs.EnvSpec): Environment specs to be augmented. latent_dim (int): Latent dimension. Returns: garage.envs.EnvSpec: Augmented environment specs. 
""" obs_dim = int(np.prod(env_spec.observation_space.shape)) action_dim = int(np.prod(env_spec.action_space.shape)) aug_obs = akro.Box(low=-1, high=1, shape=(obs_dim + latent_dim, ), dtype=np.float32) aug_act = akro.Box(low=-1, high=1, shape=(action_dim, ), dtype=np.float32) return EnvSpec(aug_obs, aug_act) @classmethod def get_env_spec(cls, env_spec, latent_dim, module): """Get environment specs of encoder with latent dimension. Args: env_spec (garage.envs.EnvSpec): Environment specs. latent_dim (int): Latent dimension. module (str): Module to get environment specs for. Returns: garage.envs.InOutSpec: Module environment specs with latent dimension. """ obs_dim = int(np.prod(env_spec.observation_space.shape)) action_dim = int(np.prod(env_spec.action_space.shape)) if module == 'encoder': in_dim = obs_dim + action_dim + 1 out_dim = latent_dim * 2 elif module == 'vf': in_dim = obs_dim out_dim = latent_dim in_space = akro.Box(low=-1, high=1, shape=(in_dim, ), dtype=np.float32) out_space = akro.Box(low=-1, high=1, shape=(out_dim, ), dtype=np.float32) if module == 'encoder': spec = InOutSpec(in_space, out_space) elif module == 'vf': spec = EnvSpec(in_space, out_space) return spec def _get_log_alpha(self, indices): """Return the value of log_alpha. Args: samples_data (dict): Transitions(S,A,R,S') that are sampled from the replay buffer. It should have the keys 'observation', 'action', 'reward', 'terminal', and 'next_observations'. Note: samples_data's entries should be torch.Tensor's with the following shapes: observation: :math:`(N, O^*)` action: :math:`(N, A^*)` reward: :math:`(N, 1)` terminal: :math:`(N, 1)` next_observation: :math:`(N, O^*)` Returns: torch.Tensor: log_alpha. shape is (1, self.buffer_batch_size) """ log_alpha = self._log_alpha one_hots = np.zeros( (len(indices) * self._batch_size, self._num_train_tasks), dtype=np.float32) for i in range(len(indices)): one_hots[self._batch_size * i:self._batch_size * (i + 1), indices[i]] = 1 one_hots = torch.as_tensor(one_hots, device=global_device()) ret = torch.mm(one_hots, log_alpha.unsqueeze(0).t()).squeeze() return ret
class MetaBasicHierch(MetaRLAlgorithm): def __init__( self, env, skill_env, controller_policy, skill_actor, qf, vf, num_skills, num_train_tasks, num_test_tasks, test_env_sampler, sampler_class, # to avoid cycling import controller_class=ControllerPolicy, controller_lr=3E-4, qf_lr=3E-4, vf_lr=3E-4, policy_mean_reg_coeff=1E-3, policy_std_reg_coeff=1E-3, policy_pre_activation_coeff=0., soft_target_tau=0.005, optimizer_class=torch.optim.Adam, meta_batch_size=64, num_steps_per_epoch=1000, num_tasks_sample=5, batch_size=1024, embedding_batch_size=1024, embedding_mini_batch_size=1024, max_path_length=1000, discount=0.99, replay_buffer_size=1000000, ): self._env = env self._skill_env = skill_env self._qf1 = qf self._qf2 = copy.deepcopy(qf) self._vf = vf self._skill_actor = skill_actor self._num_skills = num_skills self._num_train_tasks = num_train_tasks self._num_test_tasks = num_test_tasks self._policy_mean_reg_coeff = policy_mean_reg_coeff self._policy_std_reg_coeff = policy_std_reg_coeff self._policy_pre_activation_coeff = policy_pre_activation_coeff self._soft_target_tau = soft_target_tau self._meta_batch_size = meta_batch_size self._num_steps_per_epoch = num_steps_per_epoch self._num_tasks_sample = num_tasks_sample self._batch_size = batch_size self._embedding_batch_size = embedding_batch_size self._embedding_mini_batch_size = embedding_mini_batch_size self.max_path_length = max_path_length self._discount = discount self._replay_buffer_size = replay_buffer_size self._task_idx = None self._skill_idx = None # do we really need it self._is_resuming = False worker_args = dict(num_skills=num_skills, skill_actor_class=type(skill_actor), controller_class=controller_class, deterministic=False, accum_context=False) self._evaluator = MetaEvaluator( test_task_sampler=test_env_sampler, max_path_length=max_path_length, worker_class=BasicHierachWorker, worker_args=worker_args, n_test_tasks=num_test_tasks, sampler_class=sampler_class, trajectory_batch_class=SkillTrajectoryBatch) self._average_rewards = [] self._controller = controller_class( controller_policy=controller_policy, num_skills=num_skills, sub_actor=skill_actor) self._skills_replay_buffer = PathBuffer(replay_buffer_size) self._replay_buffers = { i: PathBuffer(replay_buffer_size) for i in range(num_train_tasks) } self.target_vf = copy.deepcopy(self._vf) self.vf_criterion = torch.nn.MSELoss() self._controller_optimizer = optimizer_class( self._controller.networks[1].parameters(), lr=controller_lr, ) self.qf1_optimizer = optimizer_class( self._qf1.parameters(), lr=qf_lr, ) self.qf2_optimizer = optimizer_class( self._qf2.parameters(), lr=qf_lr, ) self.vf_optimizer = optimizer_class( self._vf.parameters(), lr=vf_lr, ) def train(self, runner): for _ in runner.step_epochs(): epoch = runner.step_itr / self._num_steps_per_epoch if epoch == 0 or self._is_resuming: for idx in range(self._num_train_tasks): self._task_idx = idx self._obtain_task_samples(runner, epoch, self._num_initial_steps, np.inf) self._is_resuming = False logger.log('Sampling tasks') for _ in range(self._num_tasks_sample): idx = np.random.randint(self._num_train_tasks) self._task_idx = idx self._replay_buffers[idx].clear() self._obtain_task_samples(runner, epoch, self._num_steps_per_epoch) logger.log('Training task adapting...') self._tasks_adapt_train_once() runner.step_itr += 1 logger.log('Evaluating...') # evaluate self._controller.reset_belief() self._average_rewards.append(self._evaluator.evaluate(self)) return self._average_rewards def _tasks_adapt_train_once(self): for _ in 
range(self._num_steps_per_epoch): indices = np.random.choice(range(self._num_train_tasks), self._meta_batch_size) self._tasks_adapt_optimize_policy(indices) def _tasks_adapt_optimize_policy(self, indices): num_tasks = len(indices) obs, actions, rewards, skills, next_obs, terms = self._sample_task_path( indices) self._controller.reset_belief(num_tasks=num_tasks) # data shape is (task, batch, feat) # new_skills_pred is distribution policy_outputs, new_skills_pred, task_z = self._controller(obs) new_actions, policy_mean, policy_log_std, policy_log_pi = policy_outputs[: 4] # flatten out the task dimension t, b, _ = obs.size() obs = obs.view(t * b, -1) # actions = actions.view(t * b, -1) skills = skills.view(t * b, -1) next_obs = next_obs.view(t * b, -1) # optimize qf and encoder networks # TODO try [obs, skills, actions] or [obs, skills, task_z] # FIXME prob need to reshape or tile task_z obs = obs.to(tu.global_device()) skills = skills.to(tu.global_device()) next_obs = next_obs.to(tu.global_device()) q1_pred = self._qf1(torch.cat([obs, skills], dim=1)) q2_pred = self._qf2(torch.cat([obs, skills], dim=1)) self.qf1_optimizer.zero_grad() self.qf2_optimizer.zero_grad() rewards_flat = rewards.view(b * t, -1) terms_flat = terms.view(b * t, -1) q_target = rewards_flat + (1. - terms_flat) * self._discount qf_loss = torch.mean((q1_pred - q_target)**2) + torch.mean( (q2_pred - q_target)**2) qf_loss.backward() self.qf1_optimizer.step() self.qf2_optimizer.step() new_skills_pred = new_skills_pred.to(tu.global_device()) # compute min Q on the new actions q1 = self._qf1(torch.cat([obs, new_skills_pred], dim=1), task_z.detach()) q2 = self._qf2(torch.cat([obs, new_skills_pred], dim=1), task_z.detach()) min_q = torch.min(q1, q2) # optimize vf policy_log_pi = policy_log_pi.to(tu.global_device()) # optimize policy log_policy_target = min_q policy_loss = (policy_log_pi - log_policy_target).mean() mean_reg_loss = self._policy_mean_reg_coeff * (policy_mean**2).mean() std_reg_loss = self._policy_std_reg_coeff * (policy_log_std**2).mean() # took away pre-activation reg policy_reg_loss = mean_reg_loss + std_reg_loss policy_loss = policy_loss + policy_reg_loss self._controller_optimizer.zero_grad() policy_loss.backward() self._controller_optimizer.step() def _obtain_task_samples(self, runner, itr, num_paths): self._controller.reset_belief() total_paths = 0 num_paths_per_batch = num_paths while total_paths < num_paths: num_samples = num_paths_per_batch * self.max_path_length paths = runner.obtain_samples(itr, num_samples, self._controller, self._env[self._task_idx]) total_paths += len(paths) for path in paths: p = { 'states': path['states'], 'actions': path['actions'], 'env_rewards': path['env_rewards'].reshape(-1, 1), 'skills_onehot': path['skills_onehot'], 'next_states': path['next_states'], 'dones': path['dones'].reshape(-1, 1) } self._replay_buffers[self._task_idx].add_path(p) def _sample_task_path(self, indices): if not hasattr(indices, '__iter__'): indices = [indices] initialized = False for idx in indices: path = self._replay_buffers[idx].sample_path() # TODO: trim or extend batch to the same size if not initialized: o = path['states'][np.newaxis] a = path['actions'][np.newaxis] r = path['env_rewards'][np.newaxis] z = path['skills_onehot'][np.newaxis] no = path['next_states'][np.newaxis] d = path['dones'][np.newaxis] initialized = True else: o = np.vstack((o, path['states'][np.newaxis])) a = np.vstack((a, path['actions'][np.newaxis])) r = np.vstack((r, path['env_rewards'][np.newaxis])) z = np.vstack((z, 
path['skills_onehot'][np.newaxis])) no = np.vstack((no, path['next_states'][np.newaxis])) d = np.vstack((d, path['dones'][np.newaxis])) o = torch.as_tensor(o, device=tu.global_device()).float() a = torch.as_tensor(a, device=tu.global_device()).float() r = torch.as_tensor(r, device=tu.global_device()).float() z = torch.as_tensor(z, device=tu.global_device()).float() no = torch.as_tensor(no, device=tu.global_device()).float() d = torch.as_tensor(d, device=tu.global_device()).float() return o, a, r, z, no, d def _update_target_network(self): for target_param, param in zip(self.target_vf.parameters(), self._vf.parameters()): target_param.data.copy_(target_param.data * (1.0 - self._soft_target_tau) + param.data * self._soft_target_tau) def __getstate__(self): data = self.__dict__.copy() del data['_replay_buffers'] return data def __setstate__(self, state): self.__dict__.update(state) self._replay_buffers = { i: PathBuffer(self._replay_buffer_size) for i in range(self._num_train_tasks) } self._is_resuming = True def to(self, device=None): device = device or tu.global_device() for net in self.networks: # print(net) net.to(device) @classmethod def get_env_spec(cls, env_spec, num_skills, module): obs_dim = int(np.prod(env_spec.observation_space.shape)) action_dim = int(np.prod(env_spec.action_space.shape)) if module == 'controller_policy': in_dim = obs_dim out_dim = num_skills if module == 'qf': in_dim = obs_dim out_dim = num_skills in_space = akro.Box(low=-1, high=1, shape=(in_dim, ), dtype=np.float32) out_space = akro.Box(low=-1, high=1, shape=(out_dim, ), dtype=np.float32) if module == 'controller_policy': spec = EnvSpec(in_space, out_space) if module == 'qf': spec = EnvSpec(in_space, out_space) return spec @property def policy(self): return self._controller @property def networks(self): return self._controller.networks + [self._controller] + [ self._qf1, self._qf2, self._vf, self.target_vf ] def get_exploration_policy(self): return self._controller def adapt_policy(self, exploration_policy, exploration_trajectories): total_steps = sum(exploration_trajectories.lengths) o = exploration_trajectories.states a = exploration_trajectories.actions r = exploration_trajectories.env_rewards.reshape(total_steps, 1) s = exploration_trajectories.skills_onehot return self._controller
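
# _update_target_network above (like its counterpart in PEARL) performs the
# standard polyak / soft target update,
#     theta_target <- (1 - tau) * theta_target + tau * theta.
# A one-step illustration with hypothetical tensors and the default
# soft_target_tau of 0.005:
import torch

tau = 0.005
target_param = torch.tensor([1.0, -2.0])
live_param = torch.tensor([0.0, 0.0])
target_param = target_param * (1.0 - tau) + live_param * tau
# Each target weight moves 0.5% of the way toward its live counterpart:
# tensor([ 0.9950, -1.9900])
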
class MetaKant(MetaRLAlgorithm):
    """Meta-RL algorithm that trains a skill-conditioned controller.

    The controller is conditioned on a latent task context produced by a
    (optionally information-bottlenecked) context encoder and is optimized
    with SAC-style Q/V critics.
    """

    def __init__(
            self,
            env,
            skill_env,
            controller_policy,
            skill_actor,
            qf,
            vf,
            num_skills,
            num_train_tasks,
            num_test_tasks,
            latent_dim,
            encoder_hidden_sizes,
            test_env_sampler,
            sampler_class,  # to avoid circular import
            controller_class=OpenContextConditionedControllerPolicy,
            encoder_class=GaussianContextEncoder,
            encoder_module_class=MLPEncoder,
            is_encoder_recurrent=False,
            controller_lr=3E-4,
            qf_lr=3E-4,
            vf_lr=3E-4,
            context_lr=3E-4,
            policy_mean_reg_coeff=1E-3,
            policy_std_reg_coeff=1E-3,
            policy_pre_activation_coeff=0.,
            soft_target_tau=0.005,
            kl_lambda=.1,
            optimizer_class=torch.optim.Adam,
            use_next_obs_in_context=False,
            meta_batch_size=64,
            num_steps_per_epoch=1000,
            num_skills_reason_steps=1000,
            num_skills_sample=10,
            num_initial_steps=1500,
            num_tasks_sample=5,
            num_steps_prior=400,
            num_steps_posterior=0,
            num_extra_rl_steps_posterior=600,
            batch_size=1024,
            embedding_batch_size=1024,
            embedding_mini_batch_size=1024,
            max_path_length=1000,
            discount=0.99,
            replay_buffer_size=1000000,
            # TODO: the ratio needs to be tuned
            skills_reason_reward_scale=1,
            tasks_adapt_reward_scale=1.2,
            use_information_bottleneck=True,
            update_post_train=1,
    ):
        self._env = env
        self._skill_env = skill_env
        self._qf1 = qf
        self._qf2 = copy.deepcopy(qf)
        self._vf = vf
        self._skill_actor = skill_actor
        self._num_skills = num_skills
        self._num_train_tasks = num_train_tasks
        self._num_test_tasks = num_test_tasks
        self._latent_dim = latent_dim
        self._policy_mean_reg_coeff = policy_mean_reg_coeff
        self._policy_std_reg_coeff = policy_std_reg_coeff
        self._policy_pre_activation_coeff = policy_pre_activation_coeff
        self._soft_target_tau = soft_target_tau
        self._kl_lambda = kl_lambda
        self._use_next_obs_in_context = use_next_obs_in_context
        self._meta_batch_size = meta_batch_size
        self._num_skills_reason_steps = num_skills_reason_steps
        self._num_steps_per_epoch = num_steps_per_epoch
        self._num_initial_steps = num_initial_steps
        self._num_tasks_sample = num_tasks_sample
        self._num_skills_sample = num_skills_sample
        self._num_steps_prior = num_steps_prior
        self._num_steps_posterior = num_steps_posterior
        self._num_extra_rl_steps_posterior = num_extra_rl_steps_posterior
        self._batch_size = batch_size
        self._embedding_batch_size = embedding_batch_size
        self._embedding_mini_batch_size = embedding_mini_batch_size
        self.max_path_length = max_path_length
        self._discount = discount
        self._replay_buffer_size = replay_buffer_size
        self._is_encoder_recurrent = is_encoder_recurrent
        self._use_information_bottleneck = use_information_bottleneck
        self._skills_reason_reward_scale = skills_reason_reward_scale
        self._tasks_adapt_reward_scale = tasks_adapt_reward_scale
        self._update_post_train = update_post_train
        self._task_idx = None
        self._skill_idx = None  # do we really need it
        self._is_resuming = False

        worker_args = dict(num_skills=num_skills,
                           skill_actor_class=type(skill_actor),
                           controller_class=controller_class,
                           deterministic=False,
                           accum_context=True)
        self._evaluator = MetaEvaluator(
            test_task_sampler=test_env_sampler,
            max_path_length=max_path_length,
            worker_class=KantWorker,
            worker_args=worker_args,
            n_test_tasks=num_test_tasks,
            sampler_class=sampler_class,
            trajectory_batch_class=SkillTrajectoryBatch)
        self._average_rewards = []

        encoder_spec = self.get_env_spec(env[0](), latent_dim, num_skills,
                                         'encoder')
        encoder_in_dim = int(np.prod(encoder_spec.input_space.shape))
        encoder_out_dim = int(np.prod(encoder_spec.output_space.shape))
        encoder_module = encoder_module_class(
            input_dim=encoder_in_dim,
            output_dim=encoder_out_dim,
            hidden_sizes=encoder_hidden_sizes)
        context_encoder = encoder_class(encoder_module,
                                        use_information_bottleneck,
                                        latent_dim)

        self._controller = controller_class(
            latent_dim=latent_dim,
            context_encoder=context_encoder,
            controller_policy=controller_policy,
            num_skills=num_skills,
            sub_actor=skill_actor,
            use_information_bottleneck=use_information_bottleneck,
            use_next_obs=use_next_obs_in_context)

        self._skills_replay_buffer = PathBuffer(replay_buffer_size)
        self._replay_buffers = {
            i: PathBuffer(replay_buffer_size)
            for i in range(num_train_tasks)
        }
        self._context_replay_buffers = {
            i: PathBuffer(replay_buffer_size)
            for i in range(num_train_tasks)
        }

        self.target_vf = copy.deepcopy(self._vf)
        self.vf_criterion = torch.nn.MSELoss()

        self._controller_optimizer = optimizer_class(
            self._controller.networks[1].parameters(),
            lr=controller_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self._qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self._qf2.parameters(),
            lr=qf_lr,
        )
        self.vf_optimizer = optimizer_class(
            self._vf.parameters(),
            lr=vf_lr,
        )
        self.context_optimizer = optimizer_class(
            self._controller.networks[0].parameters(),
            lr=context_lr,
        )

    def train(self, runner):
        for _ in runner.step_epochs():
            epoch = runner.step_itr / self._num_steps_per_epoch

            if epoch == 0 or self._is_resuming:
                for idx in range(self._num_skills):
                    self._skill_idx = idx
                    self._obtain_skill_samples(runner, epoch,
                                               self._num_initial_steps)
                for idx in range(self._num_train_tasks):
                    self._task_idx = idx
                    self._obtain_task_samples(runner, epoch,
                                              self._num_initial_steps, np.inf)
                self._is_resuming = False

            logger.log('Sampling skills')
            for idx in range(self._num_skills_sample):
                self._skill_idx = idx
                # self._skills_replay_buffer.clear()
                self._obtain_skill_samples(runner, epoch,
                                           self._num_skills_reason_steps)
            logger.log('Training skill reasoning...')
            self._skills_reason_train_once()

            logger.log('Sampling tasks')
            for _ in range(self._num_tasks_sample):
                idx = np.random.randint(self._num_train_tasks)
                self._task_idx = idx
                self._context_replay_buffers[idx].clear()
                # obtain samples with z ~ prior
                logger.log('Obtaining samples with z ~ prior')
                if self._num_steps_prior > 0:
                    self._obtain_task_samples(runner, epoch,
                                              self._num_steps_prior, np.inf)
                # obtain samples with z ~ posterior
                logger.log('Obtaining samples with z ~ posterior')
                if self._num_steps_posterior > 0:
                    self._obtain_task_samples(runner, epoch,
                                              self._num_steps_posterior,
                                              self._update_post_train)
                # obtain extra samples for RL training but not the encoder
                logger.log('Obtaining extra samples for RL training '
                           'but not the encoder')
                if self._num_extra_rl_steps_posterior > 0:
                    self._obtain_task_samples(
                        runner, epoch,
                        self._num_extra_rl_steps_posterior,
                        self._update_post_train,
                        add_to_enc_buffer=False)

            logger.log('Training task adapting...')
            self._tasks_adapt_train_once()

            runner.step_itr += 1

            logger.log('Evaluating...')
            # evaluate
            self._controller.reset_belief()
            self._average_rewards.append(self._evaluator.evaluate(self))

        return self._average_rewards

    def _skills_reason_train_once(self):
        for _ in range(self._num_steps_per_epoch):
            self._skills_reason_optimize_policy()

    def _tasks_adapt_train_once(self):
        for _ in range(self._num_steps_per_epoch):
            indices = np.random.choice(range(self._num_train_tasks),
                                       self._meta_batch_size)
            self._tasks_adapt_optimize_policy(indices)

    def _skills_reason_optimize_policy(self):
        self._controller.reset_belief()
        # data shape is (task, batch, feat)
        obs, actions, rewards, skills, next_obs, terms, context = \
            self._sample_skill_path()
        # skills_pred is a distribution
        policy_outputs, skills_pred, task_z = self._controller(obs, context)
        _, policy_mean, policy_log_std, policy_log_pi = policy_outputs[:4]

        self.context_optimizer.zero_grad()
        if self._use_information_bottleneck:
            kl_div = self._controller.compute_kl_div()
            kl_loss = self._kl_lambda * kl_div
            kl_loss.backward(retain_graph=True)

        skills_target = skills.clone().detach().requires_grad_(True) \
            .to(tu.global_device())
        skills_pred = skills_pred.to(tu.global_device())
        policy_loss = F.mse_loss(skills_pred.flatten(),
                                 skills_target.flatten()) \
            * self._skills_reason_reward_scale

        mean_reg_loss = self._policy_mean_reg_coeff * (policy_mean**2).mean()
        std_reg_loss = self._policy_std_reg_coeff * (policy_log_std**2).mean()
        # took away the pre-activation reg term
        policy_reg_loss = mean_reg_loss + std_reg_loss
        policy_loss = policy_loss + policy_reg_loss

        self._controller_optimizer.zero_grad()
        policy_loss.backward()
        self._controller_optimizer.step()

    def _tasks_adapt_optimize_policy(self, indices):
        num_tasks = len(indices)
        obs, actions, rewards, skills, next_obs, terms, context = \
            self._sample_task_path(indices)
        self._controller.reset_belief(num_tasks=num_tasks)

        # data shape is (task, batch, feat)
        # new_skills_pred is a distribution
        policy_outputs, new_skills_pred, task_z = self._controller(
            obs, context)
        new_actions, policy_mean, policy_log_std, policy_log_pi = \
            policy_outputs[:4]

        # flatten out the task dimension
        t, b, _ = obs.size()
        obs = obs.view(t * b, -1)
        skills = skills.view(t * b, -1)
        next_obs = next_obs.view(t * b, -1)

        # optimize qf and encoder networks
        # TODO try [obs, skills, actions] or [obs, skills, task_z]
        # FIXME prob need to reshape or tile task_z
        obs = obs.to(tu.global_device())
        skills = skills.to(tu.global_device())
        next_obs = next_obs.to(tu.global_device())
        q1_pred = self._qf1(torch.cat([obs, skills], dim=1), task_z)
        q2_pred = self._qf2(torch.cat([obs, skills], dim=1), task_z)
        v_pred = self._vf(obs, task_z.detach())
        with torch.no_grad():
            target_v_values = self.target_vf(next_obs, task_z)

        # KL constraint on z if probabilistic
        self.context_optimizer.zero_grad()
        if self._use_information_bottleneck:
            kl_div = self._controller.compute_kl_div()
            kl_loss = self._kl_lambda * kl_div
            kl_loss.backward(retain_graph=True)

        self.qf1_optimizer.zero_grad()
        self.qf2_optimizer.zero_grad()
        rewards_flat = rewards.view(b * t, -1)
        rewards_flat = rewards_flat * self._tasks_adapt_reward_scale
        terms_flat = terms.view(b * t, -1)
        q_target = rewards_flat + (
            1. - terms_flat) * self._discount * target_v_values
        qf_loss = torch.mean((q1_pred - q_target)**2) + torch.mean(
            (q2_pred - q_target)**2)
        qf_loss.backward()
        self.qf1_optimizer.step()
        self.qf2_optimizer.step()
        self.context_optimizer.step()

        new_skills_pred = new_skills_pred.to(tu.global_device())
        # compute min Q on the new actions
        q1 = self._qf1(torch.cat([obs, new_skills_pred], dim=1),
                       task_z.detach())
        q2 = self._qf2(torch.cat([obs, new_skills_pred], dim=1),
                       task_z.detach())
        min_q = torch.min(q1, q2)

        # optimize vf
        policy_log_pi = policy_log_pi.to(tu.global_device())
        v_target = min_q - policy_log_pi
        vf_loss = self.vf_criterion(v_pred, v_target.detach())
        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()
        self._update_target_network()

        # optimize policy
        log_policy_target = min_q
        policy_loss = (policy_log_pi - log_policy_target).mean()
        mean_reg_loss = self._policy_mean_reg_coeff * (policy_mean**2).mean()
        std_reg_loss = self._policy_std_reg_coeff * (policy_log_std**2).mean()
        # took away pre-activation reg
        policy_reg_loss = mean_reg_loss + std_reg_loss
        policy_loss = policy_loss + policy_reg_loss

        self._controller_optimizer.zero_grad()
        policy_loss.backward()
        self._controller_optimizer.step()

    def _obtain_skill_samples(self, runner, itr, num_paths):
        self._controller.reset_belief()
        total_paths = 0

        while total_paths < num_paths:
            num_samples = num_paths * self.max_path_length
            paths = runner.obtain_samples(itr, num_samples, self._skill_actor,
                                          self._skill_env)
            total_paths += len(paths)
            for path in paths:
                p = {
                    'states': path['states'],
                    'actions': path['actions'],
                    'env_rewards': path['env_rewards'].reshape(-1, 1),
                    'skills_onehot': path['skills_onehot'],
                    'next_states': path['next_states'],
                    'dones': path['dones'].reshape(-1, 1)
                }
                self._skills_replay_buffer.add_path(p)

    def _obtain_task_samples(self,
                             runner,
                             itr,
                             num_paths,
                             update_posterior_rate,
                             add_to_enc_buffer=True):
        self._controller.reset_belief()
        total_paths = 0
        if update_posterior_rate != np.inf:
            num_paths_per_batch = update_posterior_rate
        else:
            num_paths_per_batch = num_paths

        while total_paths < num_paths:
            num_samples = num_paths_per_batch * self.max_path_length
            paths = runner.obtain_samples(itr, num_samples, self._controller,
                                          self._env[self._task_idx])
            total_paths += len(paths)
            for path in paths:
                p = {
                    'states': path['states'],
                    'actions': path['actions'],
                    'env_rewards': path['env_rewards'].reshape(-1, 1),
                    'skills_onehot': path['skills_onehot'],
                    'next_states': path['next_states'],
                    'dones': path['dones'].reshape(-1, 1)
                }
                self._replay_buffers[self._task_idx].add_path(p)
                if add_to_enc_buffer:
                    self._context_replay_buffers[self._task_idx].add_path(p)

            if update_posterior_rate != np.inf:
                context = self._sample_path_context(self._task_idx)
                self._controller.infer_posterior(context)

    def _sample_task_path(self, indices):
        if not hasattr(indices, '__iter__'):
            indices = [indices]

        initialized = False
        for idx in indices:
            path = self._context_replay_buffers[idx].sample_path()
            # should be replay_buffers[]
            # TODO: trim or extend batch to the same size
            context_o = path['states']
            context_a = path['actions']
            context_r = path['env_rewards']
            context_z = path['skills_onehot']
            context = np.hstack((np.hstack((np.hstack(
                (context_o, context_a)), context_r)), context_z))
            if self._use_next_obs_in_context:
                context = np.hstack((context, path['next_states']))

            if not initialized:
                final_context = context[np.newaxis]
                o = path['states'][np.newaxis]
                a = path['actions'][np.newaxis]
                r = path['env_rewards'][np.newaxis]
                z = path['skills_onehot'][np.newaxis]
                no = path['next_states'][np.newaxis]
                d = path['dones'][np.newaxis]
                initialized = True
            else:
                # print(o.shape)
                # print(path['states'].shape)
                o = np.vstack((o, path['states'][np.newaxis]))
                a = np.vstack((a, path['actions'][np.newaxis]))
                r = np.vstack((r, path['env_rewards'][np.newaxis]))
                z = np.vstack((z, path['skills_onehot'][np.newaxis]))
                no = np.vstack((no, path['next_states'][np.newaxis]))
                d = np.vstack((d, path['dones'][np.newaxis]))
                final_context = np.vstack(
                    (final_context, context[np.newaxis]))

        o = torch.as_tensor(o, device=tu.global_device()).float()
        a = torch.as_tensor(a, device=tu.global_device()).float()
        r = torch.as_tensor(r, device=tu.global_device()).float()
        z = torch.as_tensor(z, device=tu.global_device()).float()
        no = torch.as_tensor(no, device=tu.global_device()).float()
        d = torch.as_tensor(d, device=tu.global_device()).float()
        final_context = torch.as_tensor(final_context,
                                        device=tu.global_device()).float()
        if len(indices) == 1:
            final_context = final_context.unsqueeze(0)

        return o, a, r, z, no, d, final_context

    def _sample_skill_path(self):
        path = self._skills_replay_buffer.sample_path()
        # TODO: trim or extend batch to the same size
        o = path['states']
        a = path['actions']
        r = path['env_rewards']
        z = path['skills_onehot']
        context = np.hstack((np.hstack((np.hstack((o, a)), r)), z))
        if self._use_next_obs_in_context:
            context = np.hstack((context, path['next_states']))
        context = context[np.newaxis]

        o = path['states'][np.newaxis]
        a = path['actions'][np.newaxis]
        r = path['env_rewards'][np.newaxis]
        z = path['skills_onehot'][np.newaxis]
        no = path['next_states'][np.newaxis]
        d = path['dones'][np.newaxis]

        o = torch.as_tensor(o, device=tu.global_device()).float()
        a = torch.as_tensor(a, device=tu.global_device()).float()
        r = torch.as_tensor(r, device=tu.global_device()).float()
        z = torch.as_tensor(z, device=tu.global_device()).float()
        no = torch.as_tensor(no, device=tu.global_device()).float()
        d = torch.as_tensor(d, device=tu.global_device()).float()
        context = torch.as_tensor(context, device=tu.global_device()).float()
        context = context.unsqueeze(0)

        return o, a, r, z, no, d, context

    def _sample_path_context(self, indices):
        if not hasattr(indices, '__iter__'):
            indices = [indices]

        initialized = False
        for idx in indices:
            path = self._context_replay_buffers[idx].sample_path()
            o = path['states']
            a = path['actions']
            r = path['env_rewards']
            z = path['skills_onehot']
            context = np.hstack((np.hstack((np.hstack((o, a)), r)), z))
            if self._use_next_obs_in_context:
                context = np.hstack((context, path['next_states']))

            if not initialized:
                final_context = context[np.newaxis]
                initialized = True
            else:
                final_context = np.vstack(
                    (final_context, context[np.newaxis]))

        final_context = torch.as_tensor(final_context,
                                        device=tu.global_device()).float()
        if len(indices) == 1:
            final_context = final_context.unsqueeze(0)

        return final_context

    def _update_target_network(self):
        for target_param, param in zip(self.target_vf.parameters(),
                                       self._vf.parameters()):
            target_param.data.copy_(target_param.data *
                                    (1.0 - self._soft_target_tau) +
                                    param.data * self._soft_target_tau)

    def __getstate__(self):
        data = self.__dict__.copy()
        del data['_skills_replay_buffer']
        del data['_replay_buffers']
        del data['_context_replay_buffers']
        return data

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._skills_replay_buffer = PathBuffer(self._replay_buffer_size)
        self._replay_buffers = {
            i: PathBuffer(self._replay_buffer_size)
            for i in range(self._num_train_tasks)
        }
        self._context_replay_buffers = {
            i: PathBuffer(self._replay_buffer_size)
            for i in range(self._num_train_tasks)
        }
        self._is_resuming = True

    def to(self, device=None):
        device = device or tu.global_device()
        for net in self.networks:
            # print(net)
            net.to(device)

    @classmethod
    def get_env_spec(cls, env_spec, latent_dim, num_skills, module):
        obs_dim = int(np.prod(env_spec.observation_space.shape))
        # print("obs_dim is")
        # print(obs_dim)
        action_dim = int(np.prod(env_spec.action_space.shape))
        if module == 'encoder':
            in_dim = obs_dim + action_dim + num_skills + 1
            out_dim = latent_dim * 2
        elif module == 'vf':
            in_dim = obs_dim
            out_dim = latent_dim
        elif module == 'controller_policy':
            in_dim = obs_dim + latent_dim
            out_dim = num_skills
        elif module == 'qf':
            in_dim = obs_dim + latent_dim
            out_dim = num_skills
        in_space = akro.Box(low=-1,
                            high=1,
                            shape=(in_dim, ),
                            dtype=np.float32)
        out_space = akro.Box(low=-1,
                             high=1,
                             shape=(out_dim, ),
                             dtype=np.float32)
        if module == 'encoder':
            spec = InOutSpec(in_space, out_space)
        elif module == 'vf':
            spec = EnvSpec(in_space, out_space)
        elif module == 'controller_policy':
            spec = EnvSpec(in_space, out_space)
        elif module == 'qf':
            spec = EnvSpec(in_space, out_space)
        return spec

    @property
    def policy(self):
        return self._controller

    @property
    def networks(self):
        return self._controller.networks + [self._controller] + [
            self._qf1, self._qf2, self._vf, self.target_vf
        ]

    def get_exploration_policy(self):
        return self._controller

    def adapt_policy(self, exploration_policy, exploration_trajectories):
        total_steps = sum(exploration_trajectories.lengths)
        o = exploration_trajectories.states
        a = exploration_trajectories.actions
        r = exploration_trajectories.env_rewards.reshape(total_steps, 1)
        s = exploration_trajectories.skills_onehot
        ctxt = np.hstack((o, a, r, s)).reshape(1, total_steps, -1)
        context = torch.as_tensor(ctxt, device=tu.global_device()).float()
        self._controller.infer_posterior(context)
        return self._controller
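
Both `_sample_skill_path` and `adapt_policy` build the controller's context by concatenating per-step (state, action, reward, skill one-hot) features and adding a leading task dimension. The following minimal sketch shows that layout in isolation; the shapes (5 steps, 3-dim states, 2-dim actions, 4 skills) are made up for illustration and only numpy and torch are assumed.

import numpy as np
import torch

states = np.random.randn(5, 3)
actions = np.random.randn(5, 2)
env_rewards = np.random.randn(5, 1)
skills_onehot = np.eye(4)[np.random.randint(4, size=5)]

# Per-step features side by side, then a leading task/batch dimension,
# mirroring the np.hstack chain and reshape used in adapt_policy.
context = np.hstack((states, actions, env_rewards, skills_onehot))
context = context.reshape(1, len(states), -1)
context = torch.as_tensor(context).float()
print(context.shape)  # torch.Size([1, 5, 10])
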