def to(self, device=None):
    """Put all the networks within the model on device.

    Args:
        device (str): ID of GPU or CPU.

    """
    device = device or global_device()
    for net in self.networks:
        net.to(device)
    if self._use_automatic_entropy_tuning:
        if self._single_alpha:
            self._log_alpha = torch.Tensor(
                [self._initial_log_entropy]).to(device).requires_grad_()
        else:
            self._log_alpha = torch.Tensor(
                [self._initial_log_entropy] *
                self._num_train_tasks).to(device).requires_grad_()
        self._alpha_optimizer = self._optimizer_class([self._log_alpha],
                                                      lr=self._policy_lr)
    else:
        if self._single_alpha:
            self._log_alpha = torch.Tensor([self._fixed_alpha
                                            ]).log().to(device)
        else:
            self._log_alpha = torch.Tensor(
                [self._fixed_alpha] * self._num_train_tasks).log().to(device)

def update_context(self, timestep):
    """Append single transition to the current context.

    Args:
        timestep (garage._dtypes.TimeStep): Timestep containing transition
            information to be added to context.

    """
    o = torch.as_tensor(timestep.observation[None, None, ...],
                        device=global_device()).float()
    a = torch.as_tensor(timestep.action[None, None, ...],
                        device=global_device()).float()
    r = torch.as_tensor(np.array([timestep.reward])[None, None, ...],
                        device=global_device()).float()
    no = torch.as_tensor(timestep.next_observation[None, None, ...],
                         device=global_device()).float()

    if self._use_next_obs:
        data = torch.cat([o, a, r, no], dim=2)
    else:
        data = torch.cat([o, a, r], dim=2)

    if self._context is None:
        self._context = data
    else:
        self._context = torch.cat([self._context, data], dim=1)

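# A shape-only sketch (toy dimensions, plain torch, no garage objects) of how
# the context above grows: each transition becomes a (1, 1, O + A + 1) slice,
# and successive transitions are concatenated along dim=1.
import torch

O, A = 3, 2
o = torch.zeros(1, 1, O)
a = torch.zeros(1, 1, A)
r = torch.zeros(1, 1, 1)
step = torch.cat([o, a, r], dim=2)           # shape (1, 1, O + A + 1)
context = step                               # first transition
context = torch.cat([context, step], dim=1)  # shape (1, 2, O + A + 1)
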
def test_utils_set_gpu_mode():
    """Test setting gpu mode to False to force CPU."""
    if torch.cuda.is_available():
        set_gpu_mode(mode=True)
        assert global_device() == torch.device('cuda:0')
        assert tu._USE_GPU
    else:
        set_gpu_mode(mode=False)
        assert global_device() == torch.device('cpu')
        assert not tu._USE_GPU
    assert not tu._GPU_ID

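# A minimal usage sketch of the device helpers exercised by the test above,
# assuming the same garage.torch import path; training code typically sets the
# mode once and then reads global_device() wherever tensors are created.
import torch

from garage.torch import global_device, set_gpu_mode

set_gpu_mode(torch.cuda.is_available())      # use GPU 0 when one is available
device = global_device()                     # torch.device('cuda:0') or 'cpu'
model = torch.nn.Linear(4, 2).to(device)
batch = torch.zeros(8, 4, device=device)
out = model(batch)
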
def get_actions(self, observations):
    r"""Get actions given observations.

    Args:
        observations (np.ndarray): Observations from the environment.
            Shape is :math:`batch_dim \bullet env_spec.observation_space`.

    Returns:
        tuple:
            * np.ndarray: Predicted actions.
                :math:`batch_dim \bullet env_spec.action_space`.
            * dict:
                * np.ndarray[float]: Mean of the distribution.
                * np.ndarray[float]: Log of the standard deviation of the
                    distribution.

    """
    with torch.no_grad():
        if not isinstance(observations, torch.Tensor):
            observations = torch.as_tensor(observations).float().to(
                global_device())
        if isinstance(self._env_spec.observation_space, akro.Image):
            observations /= 255.0  # scale image
        dist, info = self.forward(observations)
        return dist.sample().cpu().numpy(), {
            k: v.detach().cpu().numpy()
            for (k, v) in info.items()
        }

def get_action(self, observation):
    r"""Get a single action given an observation.

    Args:
        observation (np.ndarray): Observation from the environment.
            Shape is :math:`env_spec.observation_space`.

    Returns:
        tuple:
            * np.ndarray: Predicted action. Shape is
                :math:`env_spec.action_space`.
            * dict:
                * np.ndarray[float]: Mean of the distribution.
                * np.ndarray[float]: Log of the standard deviation of the
                    distribution.

    """
    with torch.no_grad():
        if not isinstance(observation, torch.Tensor):
            observation = torch.as_tensor(observation).float().to(
                global_device())
        if isinstance(self._env_spec.observation_space, akro.Image):
            observation /= 255.0  # scale image
        observation = observation.unsqueeze(0)
        dist, info = self.forward(observation)
        return dist.sample().squeeze(0).cpu().numpy(), {
            k: v.squeeze(0).detach().cpu().numpy()
            for (k, v) in info.items()
        }

def train_sac(ctxt=None):
    trainer = Trainer(ctxt)
    env = MyGymEnv(gym_env, max_episode_length=100)
    policy = CategoricalGRUPolicy(name='policy',
                                  env_spec=env.spec,
                                  state_include_action=False).to(
                                      global_device())
    qf1 = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(8, 5))
    qf2 = DiscreteMLPQFunction(env_spec=env.spec, hidden_sizes=(8, 5))
    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
    sampler = LocalSampler(agents=policy,
                           envs=env,
                           max_episode_length=env.spec.max_episode_length,
                           worker_class=FragmentWorker)
    algo = LoggedSAC(env=env,
                     env_spec=env.spec,
                     policy=policy,
                     qf1=qf1,
                     qf2=qf2,
                     sampler=sampler,
                     gradient_steps_per_itr=1000,
                     max_episode_length_eval=100,
                     replay_buffer=replay_buffer,
                     min_buffer_size=1e4,
                     target_update_tau=5e-3,
                     discount=0.99,
                     buffer_batch_size=256,
                     reward_scale=1.,
                     steps_per_epoch=1)
    trainer.setup(algo, env)
    trainer.train(n_epochs=n_eps, batch_size=4000)
    return algo.rew_chkpts

def get_actions(self, observations):
    """Get actions given observations.

    Args:
        observations (np.ndarray): Observations from the environment.

    Returns:
        tuple:
            * np.ndarray: Predicted actions.
            * dict:
                * np.ndarray[float]: Mean of the distribution.
                * np.ndarray[float]: Log of standard deviation of the
                    distribution.

    """
    if not isinstance(observations[0], np.ndarray) and not isinstance(
            observations[0], torch.Tensor):
        observations = self._env_spec.observation_space.flatten_n(
            observations)
    # frequently users like to pass lists of torch tensors or lists of
    # numpy arrays. This handles those conversions.
    if isinstance(observations, list):
        if isinstance(observations[0], np.ndarray):
            observations = np.stack(observations)
        elif isinstance(observations[0], torch.Tensor):
            observations = torch.stack(observations)
    if isinstance(self._env_spec.observation_space, akro.Image) and \
            len(observations.shape) < \
            len(self._env_spec.observation_space.shape):
        observations = self._env_spec.observation_space.unflatten_n(
            observations)
    with torch.no_grad():
        x = self(torch.Tensor(observations).to(global_device()))
        return x.cpu().numpy(), dict()

def _get_log_alpha(self, indices):
    """Return the value of log_alpha for the sampled tasks.

    Args:
        indices (list): List of task indices for which transitions were
            sampled from the replay buffer.

    Returns:
        torch.Tensor: log_alpha values, one per sampled transition. Shape
            is (len(indices) * self._batch_size,).

    """
    log_alpha = self._log_alpha
    one_hots = np.zeros(
        (len(indices) * self._batch_size, self._num_train_tasks),
        dtype=np.float32)
    for i in range(len(indices)):
        one_hots[self._batch_size * i:self._batch_size * (i + 1),
                 indices[i]] = 1
    one_hots = torch.as_tensor(one_hots, device=global_device())
    ret = torch.mm(one_hots, log_alpha.unsqueeze(0).t()).squeeze()
    return ret

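# A minimal numeric sketch (hypothetical sizes, plain torch) of the one-hot
# selection trick used in _get_log_alpha above: with 3 training tasks and a
# batch size of 2, sampling tasks [2, 0] builds a (4, 3) one-hot matrix whose
# product with log_alpha repeats the matching per-task entry for every
# transition in the batch.
import torch

log_alpha = torch.tensor([0.1, 0.2, 0.3])  # one entry per task
one_hots = torch.tensor([[0., 0., 1.],     # two rows for task 2
                         [0., 0., 1.],
                         [1., 0., 0.],     # two rows for task 0
                         [1., 0., 0.]])
per_step = torch.mm(one_hots, log_alpha.unsqueeze(0).t()).squeeze()
# per_step == tensor([0.3000, 0.3000, 0.1000, 0.1000])
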
def get_action(self, observation):
    r"""Get a single action given an observation.

    Args:
        observation (np.ndarray): Observation from the environment.
            Shape is :math:`env_spec.observation_space`.

    Returns:
        tuple:
            * np.ndarray: Predicted action. Shape is
                :math:`env_spec.action_space`.
            * dict:
                * np.ndarray[float]: Mean of the distribution.
                * np.ndarray[float]: Log of the standard deviation of the
                    distribution.

    """
    if not isinstance(observation, np.ndarray) and not isinstance(
            observation, torch.Tensor):
        observation = self._env_spec.observation_space.flatten(observation)
    with torch.no_grad():
        if not isinstance(observation, torch.Tensor):
            observation = torch.as_tensor(observation).float().to(
                global_device())
        observation = observation.unsqueeze(0)
        action, agent_infos = self.get_actions(observation)
        return action[0], {k: v[0] for k, v in agent_infos.items()}

def to(self, device=None):
    """Put all the networks within the model on device.

    Args:
        device (str): ID of GPU or CPU.

    """
    if device is None:
        device = global_device()
    for net in self.networks:
        net.to(device)
    if not self._use_automatic_entropy_tuning:
        self._log_alpha = list_to_tensor([self._fixed_alpha
                                          ]).log().to(device)
    else:
        self._log_alpha = self._log_alpha.detach().to(
            device).requires_grad_()
        self._alpha_optimizer = self._optimizer([self._log_alpha],
                                                lr=self._policy_lr)
        self._alpha_optimizer.load_state_dict(
            state_dict_to(self._alpha_optimizer.state_dict(), device))
    self._qf1_optimizer.load_state_dict(
        state_dict_to(self._qf1_optimizer.state_dict(), device))
    self._qf2_optimizer.load_state_dict(
        state_dict_to(self._qf2_optimizer.state_dict(), device))
    self._policy_optimizer.load_state_dict(
        state_dict_to(self._policy_optimizer.state_dict(), device))

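# A generic sketch (plain torch, not the state_dict_to helper used above) of
# why optimizer state has to be moved along with the parameters: Adam keeps
# per-parameter moment buffers on whatever device the first step ran on.
import torch

lin = torch.nn.Linear(2, 1)
opt = torch.optim.Adam(lin.parameters())
lin(torch.zeros(1, 2)).sum().backward()
opt.step()  # creates exp_avg / exp_avg_sq buffers on the CPU

target = 'cuda:0' if torch.cuda.is_available() else 'cpu'
lin.to(target)
for state in opt.state.values():  # move the buffers to match the parameters
    for k, v in state.items():
        if torch.is_tensor(v):
            state[k] = v.to(target)
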
def adapt_policy(self, exploration_policy, exploration_episodes):
    """Produce a policy adapted for a task.

    Args:
        exploration_policy (Policy): A policy which was returned from
            get_exploration_policy(), and which generated
            exploration_episodes by interacting with an environment. The
            caller may not use this object after passing it into this
            method.
        exploration_episodes (EpisodeBatch): Episodes to which to adapt,
            generated by exploration_policy exploring the environment.

    Returns:
        Policy: A policy adapted to the task represented by the
            exploration_episodes.

    """
    total_steps = sum(exploration_episodes.lengths)
    o = exploration_episodes.observations
    a = exploration_episodes.actions
    r = exploration_episodes.rewards.reshape(total_steps, 1)
    ctxt = np.hstack((o, a, r)).reshape(1, total_steps, -1)
    context = torch.as_tensor(ctxt, device=global_device()).float()
    self._policy.infer_posterior(context)

    return self._policy

def train(self, trainer):
    """Obtain samplers and start actual training for each epoch.

    Args:
        trainer (Trainer): Gives the algorithm access to
            :meth:`~Trainer.step_epochs()`, which provides services such as
            snapshotting and sampler control.

    Returns:
        float: The average return in last epoch cycle.

    """
    last_return = None

    for i, _ in enumerate(trainer.step_epochs()):
        if not self._multitask:
            trainer.step_path = trainer.obtain_episodes(trainer.step_itr)
        else:
            env_updates = None
            assert self._train_task_sampler is not None
            if (not i % self._task_update_frequency) or (
                    self._task_update_frequency == 1):
                env_updates = self._train_task_sampler.sample(
                    self._num_tasks)
            trainer.step_path = self.obtain_exact_trajectories(
                trainer, env_update=env_updates)

        # do training on GPU
        if self._gpu_training:
            prefer_gpu()
            self.to(device=global_device())

        log_dict, last_return = self._train_once(trainer.step_itr,
                                                 trainer.step_path)

        # move back to CPU for collection
        set_gpu_mode(False)
        self.to(device=global_device())

        if self._wandb_logging:
            # log dict should be a dict, not None
            log_dict['total_env_steps'] = trainer.total_env_steps
            wandb.log(log_dict)
        trainer.step_itr += 1

    return last_return

def test_to():
    """Test the torch function that moves modules to GPU.

    Test that the policy and qfunctions are moved to gpu if gpu is
    available.

    """
    env_names = ['CartPole-v0', 'CartPole-v1']
    task_envs = [GarageEnv(env_name=name) for name in env_names]
    env = MultiEnvWrapper(task_envs, sample_strategy=round_robin_strategy)
    deterministic.set_seed(0)
    policy = TanhGaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=[1, 1],
        hidden_nonlinearity=torch.nn.ReLU,
        output_nonlinearity=None,
        min_std=np.exp(-20.),
        max_std=np.exp(2.),
    )
    qf1 = ContinuousMLPQFunction(env_spec=env.spec,
                                 hidden_sizes=[1, 1],
                                 hidden_nonlinearity=F.relu)
    qf2 = ContinuousMLPQFunction(env_spec=env.spec,
                                 hidden_sizes=[1, 1],
                                 hidden_nonlinearity=F.relu)
    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
    num_tasks = 2
    buffer_batch_size = 2
    mtsac = MTSAC(policy=policy,
                  qf1=qf1,
                  qf2=qf2,
                  gradient_steps_per_itr=150,
                  max_path_length=150,
                  eval_env=env,
                  env_spec=env.spec,
                  num_tasks=num_tasks,
                  steps_per_epoch=5,
                  replay_buffer=replay_buffer,
                  min_buffer_size=1e3,
                  target_update_tau=5e-3,
                  discount=0.99,
                  buffer_batch_size=buffer_batch_size)
    set_gpu_mode(torch.cuda.is_available())
    mtsac.to()
    device = global_device()
    for param in mtsac._qf1.parameters():
        assert param.device == device
    for param in mtsac._qf2.parameters():
        assert param.device == device
    for param in mtsac.policy.parameters():
        assert param.device == device
    assert mtsac._log_alpha.device == device

def to(self, device=None):
    """Put all the networks within the model on device.

    Args:
        device (str): ID of GPU or CPU.

    """
    device = device or global_device()
    for net in self.networks:
        net.to(device)

def compute_kl_div(self):
    r"""Compute :math:`KL(q(z|c) \| p(z))`.

    Returns:
        torch.Tensor: :math:`KL(q(z|c) \| p(z))`, summed over all tasks.

    """
    prior = torch.distributions.Normal(
        torch.zeros(self._latent_dim).to(global_device()),
        torch.ones(self._latent_dim).to(global_device()))
    posteriors = [
        torch.distributions.Normal(mu, torch.sqrt(var)) for mu, var in zip(
            torch.unbind(self.z_means), torch.unbind(self.z_vars))
    ]
    kl_divs = [
        torch.distributions.kl.kl_divergence(post, prior)
        for post in posteriors
    ]
    kl_div_sum = torch.sum(torch.stack(kl_divs))
    return kl_div_sum

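# A small check (illustrative values only) of the per-dimension closed form
# behind the KL term above,
# KL(N(mu, s^2) || N(0, 1)) = -log(s) + (s^2 + mu^2) / 2 - 1/2,
# against torch.distributions.
import torch

mu = torch.tensor([0.5])
s = torch.tensor([0.8])
post = torch.distributions.Normal(mu, s)
prior = torch.distributions.Normal(torch.zeros(1), torch.ones(1))
closed_form = -torch.log(s) + (s**2 + mu**2) / 2 - 0.5
assert torch.allclose(
    torch.distributions.kl.kl_divergence(post, prior), closed_form)
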
def _sample_data(self, indices):
    """Sample batch of training data from a list of tasks.

    Args:
        indices (list): List of task indices to sample from.

    Returns:
        torch.Tensor: Observations, with shape :math:`(X, N, O^*)` where X
            is the number of tasks. N is batch size.
        torch.Tensor: Actions, with shape :math:`(X, N, A^*)`.
        torch.Tensor: Rewards, with shape :math:`(X, N, 1)`.
        torch.Tensor: Next observations, with shape :math:`(X, N, O^*)`.
        torch.Tensor: Dones, with shape :math:`(X, N, 1)`.

    """
    # transitions sampled randomly from replay buffer
    initialized = False
    for idx in indices:
        batch = self._replay_buffers[idx].sample_transitions(
            self._batch_size)
        if not initialized:
            o = batch['observations'][np.newaxis]
            a = batch['actions'][np.newaxis]
            r = batch['rewards'][np.newaxis]
            no = batch['next_observations'][np.newaxis]
            d = batch['dones'][np.newaxis]
            initialized = True
        else:
            o = np.vstack((o, batch['observations'][np.newaxis]))
            a = np.vstack((a, batch['actions'][np.newaxis]))
            r = np.vstack((r, batch['rewards'][np.newaxis]))
            no = np.vstack((no, batch['next_observations'][np.newaxis]))
            d = np.vstack((d, batch['dones'][np.newaxis]))

    o = torch.as_tensor(o, device=global_device()).float()
    a = torch.as_tensor(a, device=global_device()).float()
    r = torch.as_tensor(r, device=global_device()).float()
    no = torch.as_tensor(no, device=global_device()).float()
    d = torch.as_tensor(d, device=global_device()).float()

    return o, a, r, no, d

def to(self, device=None):
    """Put all the networks within the model on device.

    Args:
        device (str): ID of GPU or CPU.

    """
    if device is None:
        device = global_device()
    logger.log('Using device: ' + str(device))
    self._qf = self._qf.to(device)
    self._target_qf = self._target_qf.to(device)

def reset_belief(self, num_tasks=1):
    r"""Reset :math:`q(z \| c)` to the prior and sample a new z from the prior.

    Args:
        num_tasks (int): Number of tasks.

    """
    # reset distribution over z to the prior
    mu = torch.zeros(num_tasks, self._latent_dim).to(global_device())
    if self._use_information_bottleneck:
        var = torch.ones(num_tasks, self._latent_dim).to(global_device())
    else:
        var = torch.zeros(num_tasks, self._latent_dim).to(global_device())
    self.z_means = mu
    self.z_vars = var
    # sample a new z from the prior
    self.sample_from_belief()
    # reset the context collected so far
    self._context = None
    # reset any hidden state in the encoder network (relevant for RNN)
    self._context_encoder.reset()

def get_actions(self, observations):
    r"""Get actions given observations.

    Args:
        observations (np.ndarray): Observations from the environment.
            Shape is :math:`batch_dim \bullet env_spec.observation_space`.

    Returns:
        tuple:
            * np.ndarray: Predicted actions.
                :math:`batch_dim \bullet env_spec.action_space`.
            * dict:
                * np.ndarray[float]: Mean of the distribution.
                * np.ndarray[float]: Log of the standard deviation of the
                    distribution.

    """
    if not isinstance(observations[0], np.ndarray) and not isinstance(
            observations[0], torch.Tensor):
        observations = self._env_spec.observation_space.flatten_n(
            observations)

    # frequently users like to pass lists of torch tensors or lists of
    # numpy arrays. This handles those conversions.
    if isinstance(observations, list):
        if isinstance(observations[0], np.ndarray):
            observations = np.stack(observations)
        elif isinstance(observations[0], torch.Tensor):
            observations = torch.stack(observations)

    if isinstance(observations[0],
                  np.ndarray) and len(observations[0].shape) > 1:
        observations = self._env_spec.observation_space.flatten_n(
            observations)
    elif isinstance(observations[0],
                    torch.Tensor) and len(observations[0].shape) > 1:
        observations = torch.flatten(observations, start_dim=1)

    if isinstance(self._env_spec.observation_space, akro.Image) and \
            len(observations.shape) < \
            len(self._env_spec.observation_space.shape):
        observations = self._env_spec.observation_space.unflatten_n(
            observations)
    with torch.no_grad():
        if not isinstance(observations, torch.Tensor):
            observations = torch.as_tensor(observations).float().to(
                global_device())
        if isinstance(self._env_spec.observation_space, akro.Image):
            observations /= 255.0  # scale image
        dist, info = self.forward(observations)
        return dist.sample().cpu().numpy(), {
            k: v.detach().cpu().numpy()
            for (k, v) in info.items()
        }

def _sample_context(self, indices):
    """Sample batch of context from a list of tasks.

    Args:
        indices (list): List of task indices to sample from.

    Returns:
        torch.Tensor: Context data, with shape :math:`(X, N, C)`. X is the
            number of tasks. N is batch size. C is the combined size of
            observation, action, reward, and next observation if next
            observation is used in context. Otherwise, C is the combined
            size of observation, action, and reward.

    """
    # make method work given a single task index
    if not hasattr(indices, '__iter__'):
        indices = [indices]

    initialized = False
    for idx in indices:
        path = self._context_replay_buffers[idx].sample_path()
        batch = self.augment_path(
            path,
            self._embedding_batch_size,
            in_sequence=self._embedding_batch_in_sequence)
        o = batch['observations']
        a = batch['actions']
        r = batch['rewards']
        context = np.hstack((np.hstack((o, a)), r))
        if self._use_next_obs_in_context:
            context = np.hstack((context, batch['next_observations']))

        if not initialized:
            final_context = context[np.newaxis]
            initialized = True
        else:
            new_context = context[np.newaxis]
            if final_context.shape[1] != new_context.shape[1]:
                # sampled paths can differ in length; truncate both to the
                # shorter one before stacking
                min_length = min(final_context.shape[1],
                                 new_context.shape[1])
                new_context = new_context[:, :min_length, :]
                final_context = np.vstack(
                    (final_context[:, :min_length, :], new_context))
            else:
                final_context = np.vstack((final_context, new_context))

    final_context = torch.as_tensor(final_context,
                                    device=global_device()).float()
    if len(indices) == 1:
        final_context = final_context.unsqueeze(0)

    return final_context

def _compute_contrastive_loss_new(self, indices):
    # Optimize CURL encoder
    context_augs = self._sample_contrastive_pairs(indices, num_aug=2)
    aug1 = torch.as_tensor(context_augs[0], device=global_device())
    aug2 = torch.as_tensor(context_augs[1], device=global_device())
    loss_fun = torch.nn.CrossEntropyLoss()

    # similar_contrastive
    query = self._context_encoder(aug1, query=True)
    key = self._context_encoder(aug2, query=False)
    t, b, d = query.size()
    query = query.view(t * b, d)
    key = key.view(t * b, d)

    if self._use_wasserstein_distance:
        query_mean = query[:, :self._latent_dim]
        query_mean_norm = torch.sum((query_mean**2), dim=1).view(-1, 1)
        key_mean = key[:, :self._latent_dim]
        key_mean_norm = torch.sum((key_mean**2), dim=1).view(1, -1)
        mean_dist = query_mean_norm + key_mean_norm - 2.0 * torch.mm(
            query_mean, key_mean.T)

        query_var = query[:, self._latent_dim:]
        query_var_norm = torch.sum((query_var**2), dim=1).view(-1, 1)
        key_var = key[:, self._latent_dim:]
        key_var_norm = torch.sum((key_var**2), dim=1).view(1, -1)
        var_dist = query_var_norm + key_var_norm - 2.0 * torch.mm(
            query_var, key_var.T)

        wasserstein_distance = mean_dist + var_dist
        # subtract the row-wise max for numerical stability
        wasserstein_distance = wasserstein_distance - torch.max(
            wasserstein_distance, dim=1, keepdim=True)[0]
        labels = torch.as_tensor(np.repeat(indices[None],
                                           self._embedding_batch_size,
                                           axis=1).flatten(),
                                 device=global_device())
        # labels = torch.arange(wasserstein_distance.shape[0]).to(
        #     global_device())
        # Use the negative wasserstein distance as logits: a lower distance
        # means more similar
        loss = loss_fun(-wasserstein_distance, labels)
    else:
        left_product = torch.matmul(
            query, self._contrastive_weight.to(global_device()))
        logits = torch.matmul(left_product, key.T)
        # subtract the row-wise max for numerical stability
        logits = logits - torch.max(logits, dim=1, keepdim=True)[0]
        if self._use_task_index_label:
            labels = torch.as_tensor(indices, device=global_device()).view(
                t, 1).repeat(1, b).view(t * b)
        else:
            labels = torch.arange(logits.shape[0]).to(global_device())
        loss = loss_fun(logits, labels)

    return loss

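# A small sketch (toy tensors, plain torch) of the pairwise squared-distance
# expansion used above for mean_dist and var_dist:
# ||q_i - k_j||^2 = ||q_i||^2 + ||k_j||^2 - 2 * q_i . k_j for every pair (i, j).
import torch

q = torch.randn(4, 3)
k = torch.randn(4, 3)
expanded = (q**2).sum(dim=1).view(-1, 1) + (k**2).sum(dim=1).view(1, -1) \
    - 2.0 * torch.mm(q, k.T)
direct = torch.cdist(q, k)**2
assert torch.allclose(expanded, direct, atol=1e-4)
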
def to(self, device=None):
    """Put all the networks within the model on device.

    Args:
        device (str): ID of GPU or CPU.

    """
    if device is None:
        device = global_device()
    for net in self.networks:
        net.to(device)
    if not self._use_automatic_entropy_tuning:
        self._log_alpha = torch.Tensor([self._fixed_alpha
                                        ]).log().to(device)
    else:
        self._log_alpha = torch.Tensor([self._initial_log_entropy
                                        ]).to(device).requires_grad_()
        self._alpha_optimizer = self._optimizer([self._log_alpha],
                                                lr=self._policy_lr)

def get_action(self, obs):
    """Sample action from the policy, conditioned on the task embedding.

    Args:
        obs (torch.Tensor): Observation values, with shape :math:`(1, O)`.
            O is the size of the flattened observation space.

    Returns:
        torch.Tensor: Output action value, with shape :math:`(1, A)`.
            A is the size of the flattened action space.
        dict:
            * np.ndarray[float]: Mean of the distribution.
            * np.ndarray[float]: Log of the standard deviation of the
                distribution.

    """
    z = self.z
    obs = torch.as_tensor(obs[None], device=global_device()).float()
    obs_in = torch.cat([obs, z], dim=1)
    action, info = self._policy.get_action(obs_in)

    return action, info

def _sample_contrastive_pairs(self, indices, num_aug=2):
    # make method work given a single task index
    if not hasattr(indices, '__iter__'):
        indices = [indices]

    path_augs = []
    for _ in range(num_aug):
        initialized = False
        for idx in indices:
            path = self._context_replay_buffers[idx].sample_path()
            # conduct random path augmentations
            batch_aug = self.augment_path(
                path,
                self._embedding_batch_size,
                in_sequence=self._embedding_batch_in_sequence)
            o = batch_aug['observations']
            a = batch_aug['actions']
            r = batch_aug['rewards']
            context = np.hstack((np.hstack((o, a)), r))
            if self._use_next_obs_in_context:
                context = np.hstack(
                    (context, batch_aug['next_observations']))

            if not initialized:
                final_context = context[np.newaxis]
                initialized = True
            else:
                final_context = np.vstack(
                    (final_context, context[np.newaxis]))

        final_context = torch.as_tensor(final_context,
                                        device=global_device()).float()
        if len(indices) == 1:
            final_context = final_context.unsqueeze(0)
        path_augs.append(final_context)

    return path_augs

def compute_advantages(discount, gae_lambda, max_episode_length, baselines,
                       rewards):
    """Calculate advantages.

    Advantages are a discounted cumulative sum.

    Calculate advantages using a baseline according to Generalized
    Advantage Estimation (GAE).

    The discounted cumulative sum can be computed using conv2d with the
    filter

        filter = [1, (discount * gae_lambda), (discount * gae_lambda) ^ 2, ...]

    whose length equals max_episode_length. baselines and rewards share the
    same shape:

        baselines:
        [ [b_11, b_12, b_13, ... b_1n],
          [b_21, b_22, b_23, ... b_2n],
          ...
          [b_m1, b_m2, b_m3, ... b_mn] ]

        rewards:
        [ [r_11, r_12, r_13, ... r_1n],
          [r_21, r_22, r_23, ... r_2n],
          ...
          [r_m1, r_m2, r_m3, ... r_mn] ]

    Args:
        discount (float): RL discount factor (i.e. gamma).
        gae_lambda (float): Lambda, as used for Generalized Advantage
            Estimation (GAE).
        max_episode_length (int): Maximum length of a single episode.
        baselines (torch.Tensor): A 2D vector of value function estimates
            with shape (N, T), where N is the batch dimension (number of
            episodes) and T is the maximum episode length experienced by
            the agent. If an episode terminates in fewer than T time steps,
            the remaining elements in that episode should be set to 0.
        rewards (torch.Tensor): A 2D vector of per-step rewards with shape
            (N, T), where N is the batch dimension (number of episodes) and
            T is the maximum episode length experienced by the agent. If an
            episode terminates in fewer than T time steps, the remaining
            elements in that episode should be set to 0.

    Returns:
        torch.Tensor: A 2D vector of calculated advantage values with shape
            (N, T), where N is the batch dimension (number of episodes) and
            T is the maximum episode length experienced by the agent. If an
            episode terminates in fewer than T time steps, the remaining
            values in that episode should be set to 0.

    """
    adv_filter = torch.full((1, 1, 1, max_episode_length - 1),
                            discount * gae_lambda,
                            dtype=torch.float,
                            device=global_device())
    adv_filter = torch.cumprod(F.pad(adv_filter, (1, 0), value=1), dim=-1)

    deltas = (rewards + discount * F.pad(baselines, (0, 1))[:, 1:] -
              baselines)
    deltas = F.pad(deltas,
                   (0, max_episode_length - 1)).unsqueeze(0).unsqueeze(0)

    advantages = F.conv2d(deltas, adv_filter, stride=1).reshape(rewards.shape)
    return advantages

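# A CPU-only numeric sketch (hypothetical values) of what the conv2d trick in
# compute_advantages evaluates: the (discount * gae_lambda)-discounted
# cumulative sum of the TD residuals
# delta_t = r_t + discount * V(s_{t+1}) - V(s_t), i.e. the usual GAE recursion.
import torch

discount, gae_lambda = 0.99, 0.95
rewards = torch.tensor([[1.0, 1.0, 1.0, 0.0]])    # one episode, T = 4
baselines = torch.tensor([[0.5, 0.5, 0.5, 0.0]])

# TD residuals, with V(s_T) taken to be 0 past the episode boundary.
next_values = torch.cat([baselines[:, 1:], torch.zeros(1, 1)], dim=1)
deltas = rewards + discount * next_values - baselines

# Backward recursion A_t = delta_t + discount * gae_lambda * A_{t+1}.
advantages = torch.zeros_like(deltas)
running = 0.0
for t in reversed(range(deltas.shape[1])):
    running = deltas[0, t] + discount * gae_lambda * running
    advantages[0, t] = running
# `advantages` matches compute_advantages(discount, gae_lambda, 4, baselines,
# rewards) when the global device is the CPU.
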
def __init__(self,
             env,
             inner_policy,
             qf1,
             qf2,
             sampler,
             num_train_tasks,
             num_test_tasks,
             latent_dim,
             encoder_hidden_sizes,
             test_env_sampler,
             policy_class=CurlPolicy,
             encoder_class=MLPEncoder,
             policy_lr=3E-4,
             qf_lr=3E-4,
             context_lr=3E-4,
             policy_mean_reg_coeff=1E-3,
             policy_std_reg_coeff=1E-3,
             policy_pre_activation_coeff=0.,
             soft_target_tau=0.005,
             kl_lambda=.1,
             fixed_alpha=None,
             target_entropy=None,
             initial_log_entropy=0.,
             optimizer_class=torch.optim.Adam,
             use_information_bottleneck=True,
             use_next_obs_in_context=False,
             use_kl_loss=False,
             use_q_loss=True,
             meta_batch_size=64,
             num_steps_per_epoch=1000,
             num_initial_steps=100,
             num_tasks_sample=100,
             num_steps_prior=100,
             num_steps_posterior=0,
             num_extra_rl_steps_posterior=100,
             batch_size=1024,
             embedding_batch_size=1024,
             embedding_mini_batch_size=1024,
             max_path_length=1000,
             discount=0.99,
             replay_buffer_size=1000000,
             reward_scale=1,
             embedding_batch_in_sequence=False,
             encoder_common_net=True,
             single_alpha=False,
             use_task_index_label=False,
             use_wasserstein_distance=True,
             update_post_train=1):
    self._env = env
    self._qf1 = qf1
    self._qf2 = qf2
    # use 2 target q networks
    self._target_qf1 = copy.deepcopy(self._qf1)
    self._target_qf2 = copy.deepcopy(self._qf2)

    # Contrastive Encoder setting
    self._embedding_batch_in_sequence = embedding_batch_in_sequence

    self._num_train_tasks = num_train_tasks
    self._num_test_tasks = num_test_tasks
    self._latent_dim = latent_dim

    self._policy_lr = policy_lr
    self._qf_lr = qf_lr
    self._context_lr = context_lr
    self._policy_mean_reg_coeff = policy_mean_reg_coeff
    self._policy_std_reg_coeff = policy_std_reg_coeff
    self._policy_pre_activation_coeff = policy_pre_activation_coeff
    self._soft_target_tau = soft_target_tau
    self._kl_lambda = kl_lambda
    self._use_information_bottleneck = use_information_bottleneck
    self._use_next_obs_in_context = use_next_obs_in_context
    self._use_kl_loss = use_kl_loss
    self._use_q_loss = use_q_loss

    self._meta_batch_size = meta_batch_size
    self._num_steps_per_epoch = num_steps_per_epoch
    self._num_initial_steps = num_initial_steps
    self._num_tasks_sample = num_tasks_sample
    self._num_steps_prior = num_steps_prior
    self._num_steps_posterior = num_steps_posterior
    self._num_extra_rl_steps_posterior = num_extra_rl_steps_posterior
    self._batch_size = batch_size
    self._embedding_batch_size = embedding_batch_size
    self._embedding_mini_batch_size = embedding_mini_batch_size
    self.max_path_length = max_path_length
    self._discount = discount
    self._replay_buffer_size = replay_buffer_size
    self._reward_scale = reward_scale
    self._update_post_train = update_post_train
    self._task_idx = None
    self._is_resuming = False

    self._sampler = sampler
    self._optimizer_class = optimizer_class

    # Architecture choice
    self._encoder_common_net = encoder_common_net
    self._single_alpha = single_alpha
    self._use_task_index_label = use_task_index_label
    self._use_wasserstein_distance = use_wasserstein_distance

    worker_args = dict(deterministic=True, accum_context=True)
    self._evaluator = MetaEvaluator(test_task_sampler=test_env_sampler,
                                    worker_class=TCLPEARLWorker,
                                    worker_args=worker_args,
                                    n_test_tasks=num_test_tasks)

    env_spec = env[0]()
    encoder_spec = self.get_env_spec(env_spec, latent_dim, 'encoder',
                                     use_information_bottleneck)
    encoder_in_dim = int(np.prod(encoder_spec.input_space.shape))
    if self._use_next_obs_in_context:
        encoder_in_dim += int(np.prod(env[0]().observation_space.shape))
    encoder_out_dim = int(np.prod(encoder_spec.output_space.shape))
    self._context_encoder = encoder_class(
        input_dim=encoder_in_dim,
        output_dim=encoder_out_dim,
        common_network=self._encoder_common_net,
        hidden_sizes=encoder_hidden_sizes)

    if not self._use_wasserstein_distance:
        self._contrastive_weight = torch.rand(encoder_out_dim,
                                              encoder_out_dim,
                                              device=global_device(),
                                              requires_grad=True)

    # Automatic entropy coefficient tuning
    self._use_automatic_entropy_tuning = fixed_alpha is None
    self._initial_log_entropy = initial_log_entropy
    self._fixed_alpha = fixed_alpha
    if self._use_automatic_entropy_tuning:
        if target_entropy:
            self._target_entropy = target_entropy
        else:
            self._target_entropy = -np.prod(
                env_spec.action_space.shape).item()
        if self._single_alpha:
            self._log_alpha = torch.Tensor([self._initial_log_entropy
                                            ]).requires_grad_()
        else:
            self._log_alpha = torch.Tensor(
                [self._initial_log_entropy] *
                self._num_train_tasks).requires_grad_()
        self._alpha_optimizer = self._optimizer_class([self._log_alpha],
                                                      lr=self._policy_lr)
    else:
        if self._single_alpha:
            self._log_alpha = torch.Tensor([self._fixed_alpha]).log()
        else:
            self._log_alpha = torch.Tensor([self._fixed_alpha] *
                                           self._num_train_tasks).log()

    self._context_lr = context_lr

    self._policy = policy_class(
        latent_dim=latent_dim,
        context_encoder=self._context_encoder,
        policy=inner_policy,
        use_information_bottleneck=use_information_bottleneck,
        use_next_obs=use_next_obs_in_context)

    # buffer for training RL update
    self._replay_buffers = {
        i: PathBuffer(replay_buffer_size)
        for i in range(num_train_tasks)
    }

    self._context_replay_buffers = {
        i: PathBuffer(replay_buffer_size)
        for i in range(num_train_tasks)
    }

    self._policy_optimizer = self._optimizer_class(
        self._policy.networks[1].parameters(),
        lr=self._policy_lr,
    )
    self.qf1_optimizer = self._optimizer_class(
        self._qf1.parameters(),
        lr=self._qf_lr,
    )
    self.qf2_optimizer = self._optimizer_class(
        self._qf2.parameters(),
        lr=self._qf_lr,
    )
    if self._encoder_common_net:
        self.context_optimizer = self._optimizer_class(
            self._context_encoder.networks[0].parameters(),
            lr=self._context_lr,
        )
    if not self._use_wasserstein_distance:
        self.contrastive_weight_optimizer = self._optimizer_class(
            [self._contrastive_weight],
            lr=self._context_lr,
        )
    self.query_optimizer = self._optimizer_class(
        self._context_encoder.networks[1].parameters(),
        lr=self._context_lr,
    )

def to(self, device=None):
    if device is None:
        device = global_device()
    for net in self.networks:
        net.to(device)

def __init__(self,
             input_dim,
             output_dim,
             hidden_sizes=(32, 32),
             hidden_nonlinearity=torch.tanh,
             hidden_w_init=nn.init.xavier_uniform_,
             hidden_b_init=nn.init.zeros_,
             output_nonlinearity=None,
             output_w_init=nn.init.xavier_uniform_,
             output_b_init=nn.init.zeros_,
             learn_std=True,
             init_std=1.0,
             min_std=1e-6,
             max_std=None,
             std_hidden_sizes=(32, 32),
             std_hidden_nonlinearity=torch.tanh,
             std_hidden_w_init=nn.init.xavier_uniform_,
             std_hidden_b_init=nn.init.zeros_,
             std_output_nonlinearity=None,
             std_output_w_init=nn.init.xavier_uniform_,
             std_parameterization='exp',
             layer_normalization=False,
             normal_distribution_cls=Normal):
    super().__init__()

    self._input_dim = input_dim
    self._hidden_sizes = hidden_sizes
    self._action_dim = output_dim
    self._learn_std = learn_std
    self._std_hidden_sizes = std_hidden_sizes
    self._min_std = min_std
    self._max_std = max_std
    self._std_hidden_nonlinearity = std_hidden_nonlinearity
    self._std_hidden_w_init = std_hidden_w_init
    self._std_hidden_b_init = std_hidden_b_init
    self._std_output_nonlinearity = std_output_nonlinearity
    self._std_output_w_init = std_output_w_init
    self._std_parameterization = std_parameterization
    self._hidden_nonlinearity = hidden_nonlinearity
    self._hidden_w_init = hidden_w_init
    self._hidden_b_init = hidden_b_init
    self._output_nonlinearity = output_nonlinearity
    self._output_w_init = output_w_init
    self._output_b_init = output_b_init
    self._layer_normalization = layer_normalization
    self._norm_dist_class = normal_distribution_cls

    if self._std_parameterization not in ('exp', 'softplus'):
        raise NotImplementedError

    init_std_param = torch.Tensor([init_std]).log()
    if self._learn_std:
        self._init_std = torch.nn.Parameter(init_std_param).to(
            global_device())
    else:
        self._init_std = init_std_param.to(global_device())
        self.register_buffer('init_std', self._init_std)

    self._min_std_param = self._max_std_param = None
    if min_std is not None:
        self._min_std_param = torch.Tensor([min_std]).log()
        self.register_buffer('min_std_param', self._min_std_param)
    if max_std is not None:
        self._max_std_param = torch.Tensor([max_std]).log()
        self.register_buffer('max_std_param', self._max_std_param)

def _optimize_policy(self, samples_data, grad_step_timer):
    """Perform algorithm optimization.

    Args:
        samples_data (dict): Processed batch data.
        grad_step_timer (int): Number of gradient steps taken so far, used
            to delay the actor (policy) updates.

    Returns:
        float: Loss predicted by the q networks (critic networks).
        float: Q value (min) predicted by one of the target q networks.
        float: Q value (min) predicted by one of the current q networks.
        float: Loss predicted by the policy (action network).

    """
    rewards = samples_data['rewards'].to(global_device()).reshape(-1, 1)
    terminals = samples_data['terminals'].to(global_device()).reshape(
        -1, 1)
    actions = samples_data['actions'].to(global_device())
    observations = samples_data['observations'].to(global_device())
    next_observations = samples_data['next_observations'].to(
        global_device())

    next_inputs = next_observations
    inputs = observations
    with torch.no_grad():
        # Select action according to policy and add clipped noise
        noise = (torch.randn_like(actions) * self._policy_noise).clamp(
            -self._policy_noise_clip, self._policy_noise_clip)
        next_actions = (self._target_policy(next_inputs) + noise).clamp(
            -self._max_action, self._max_action)

        # Compute the target Q value
        target_Q1 = self._target_qf_1(next_inputs, next_actions)
        target_Q2 = self._target_qf_2(next_inputs, next_actions)
        target_q = torch.min(target_Q1, target_Q2)
        target_Q = rewards * self._reward_scaling + (
            1. - terminals) * self._discount * target_q

    # Get current Q values
    current_Q1 = self._qf_1(inputs, actions)
    current_Q2 = self._qf_2(inputs, actions)
    current_Q = torch.min(current_Q1, current_Q2)

    # Compute critic loss
    critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
        current_Q2, target_Q)

    # Optimize critic
    self._qf_optimizer_1.zero_grad()
    self._qf_optimizer_2.zero_grad()
    critic_loss.backward()
    self._qf_optimizer_1.step()
    self._qf_optimizer_2.step()

    # Delay policy updates
    if grad_step_timer % self._update_actor_interval == 0:
        # Compute actor loss
        actions = self.policy(inputs)
        self._actor_loss = -self._qf_1(inputs, actions).mean()

        # Optimize actor
        self._policy_optimizer.zero_grad()
        self._actor_loss.backward()
        self._policy_optimizer.step()

        # update target networks
        self._update_network_parameters()

    return (critic_loss.detach(), target_Q, current_Q.detach(),
            self._actor_loss.detach())

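# A tiny numeric sketch (hypothetical values, plain torch) of the clipped
# double-Q target assembled above:
# y = r * reward_scaling + (1 - done) * discount * min(Q1'(s', a'), Q2'(s', a')),
# where a' is the target policy's action plus clipped noise.
import torch

reward, done = 1.0, 0.0
discount, reward_scaling = 0.99, 1.0
q1_target = torch.tensor(5.0)
q2_target = torch.tensor(4.5)
y = reward * reward_scaling + (1.0 - done) * discount * torch.min(
    q1_target, q2_target)
# y == 1.0 + 0.99 * 4.5 == 5.455
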