class SAC_V(RlAlgorithm): """TO BE DEPRECATED.""" opt_info_fields = tuple(f for f in OptInfo._fields) # copy def __init__( self, discount=0.99, batch_size=256, min_steps_learn=int(1e4), replay_size=int(1e6), replay_ratio=256, # data_consumption / data_generation target_update_tau=0.005, # tau=1 for hard update. target_update_interval=1, # 1000 for hard update, 1 for soft. learning_rate=3e-4, OptimCls=torch.optim.Adam, optim_kwargs=None, initial_optim_state_dict=None, # for all of them. action_prior="uniform", # or "gaussian" reward_scale=1, reparameterize=True, clip_grad_norm=1e9, policy_output_regularization=0.001, n_step_return=1, updates_per_sync=1, # For async mode only. bootstrap_timelimit=True, ReplayBufferCls=None, # Leave None to select by above options. ): if optim_kwargs is None: optim_kwargs = dict() assert action_prior in ["uniform", "gaussian"] self._batch_size = batch_size del batch_size # Property. save__init__args(locals()) def initialize(self, agent, n_itr, batch_spec, mid_batch_reset, examples, world_size=1, rank=0): """Used in basic or synchronous multi-GPU runners, not async.""" self.agent = agent self.n_itr = n_itr self.mid_batch_reset = mid_batch_reset self.sampler_bs = sampler_bs = batch_spec.size self.updates_per_optimize = int(self.replay_ratio * sampler_bs / self.batch_size) logger.log( f"From sampler batch size {sampler_bs}, training " f"batch size {self.batch_size}, and replay ratio " f"{self.replay_ratio}, computed {self.updates_per_optimize} " f"updates per iteration.") self.min_itr_learn = self.min_steps_learn // sampler_bs agent.give_min_itr_learn(self.min_itr_learn) self.initialize_replay_buffer(examples, batch_spec) self.optim_initialize(rank) def async_initialize(self, agent, sampler_n_itr, batch_spec, mid_batch_reset, examples, world_size=1): """Used in async runner only.""" self.agent = agent self.n_itr = sampler_n_itr self.initialize_replay_buffer(examples, batch_spec, async_=True) self.mid_batch_reset = mid_batch_reset self.sampler_bs = sampler_bs = batch_spec.size self.updates_per_optimize = self.updates_per_sync self.min_itr_learn = int(self.min_steps_learn // sampler_bs) agent.give_min_itr_learn(self.min_itr_learn) return self.replay_buffer def optim_initialize(self, rank=0): """Called by async runner.""" self.rank = rank self.pi_optimizer = self.OptimCls(self.agent.pi_parameters(), lr=self.learning_rate, **self.optim_kwargs) self.q1_optimizer = self.OptimCls(self.agent.q1_parameters(), lr=self.learning_rate, **self.optim_kwargs) self.q2_optimizer = self.OptimCls(self.agent.q2_parameters(), lr=self.learning_rate, **self.optim_kwargs) self.v_optimizer = self.OptimCls(self.agent.v_parameters(), lr=self.learning_rate, **self.optim_kwargs) if self.initial_optim_state_dict is not None: self.load_optim_state_dict(self.initial_optim_state_dict) if self.action_prior == "gaussian": self.action_prior_distribution = Gaussian( dim=self.agent.env_spaces.action.size, std=1.) 
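    # Worked example (hypothetical numbers) of the schedule computed in
    # initialize() above: with a sampler batch of 1000 environment steps per
    # iteration, batch_size=256, and replay_ratio=256,
    # updates_per_optimize = int(256 * 1000 / 256) = 1000 gradient updates per
    # iteration, so roughly replay_ratio samples are drawn from the replay
    # buffer for every sample collected.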
def initialize_replay_buffer(self, examples, batch_spec, async_=False): example_to_buffer = self.examples_to_buffer(examples) replay_kwargs = dict( example=example_to_buffer, size=self.replay_size, B=batch_spec.B, n_step_return=self.n_step_return, ) if not self.bootstrap_timelimit: ReplayCls = AsyncUniformReplayBuffer if async_ else UniformReplayBuffer else: ReplayCls = AsyncTlUniformReplayBuffer if async_ else TlUniformReplayBuffer if self.ReplayBufferCls is not None: ReplayCls = self.ReplayBufferCls logger.log( f"WARNING: ignoring internal selection logic and using" f" input replay buffer class: {ReplayCls} -- compatibility not" " guaranteed.") self.replay_buffer = ReplayCls(**replay_kwargs) def optimize_agent(self, itr, samples=None, sampler_itr=None): itr = itr if sampler_itr is None else sampler_itr # Async uses sampler_itr. if samples is not None: samples_to_buffer = self.samples_to_buffer(samples) self.replay_buffer.append_samples(samples_to_buffer) opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields)))) if itr < self.min_itr_learn: return opt_info for _ in range(self.updates_per_optimize): samples_from_replay = self.replay_buffer.sample_batch( self.batch_size) losses, values = self.loss(samples_from_replay) q1_loss, q2_loss, v_loss, pi_loss = losses self.v_optimizer.zero_grad() v_loss.backward() v_grad_norm = torch.nn.utils.clip_grad_norm_( self.agent.v_parameters(), self.clip_grad_norm) self.v_optimizer.step() self.pi_optimizer.zero_grad() pi_loss.backward() pi_grad_norm = torch.nn.utils.clip_grad_norm_( self.agent.pi_parameters(), self.clip_grad_norm) self.pi_optimizer.step() # Step Q's last because pi_loss.backward() uses them? self.q1_optimizer.zero_grad() q1_loss.backward() q1_grad_norm = torch.nn.utils.clip_grad_norm_( self.agent.q1_parameters(), self.clip_grad_norm) self.q1_optimizer.step() self.q2_optimizer.zero_grad() q2_loss.backward() q2_grad_norm = torch.nn.utils.clip_grad_norm_( self.agent.q2_parameters(), self.clip_grad_norm) self.q2_optimizer.step() grad_norms = (q1_grad_norm, q2_grad_norm, v_grad_norm, pi_grad_norm) self.append_opt_info_(opt_info, losses, grad_norms, values) self.update_counter += 1 if self.update_counter % self.target_update_interval == 0: self.agent.update_target(self.target_update_tau) return opt_info def samples_to_buffer(self, samples): return SamplesToBuffer( observation=samples.env.observation, action=samples.agent.action, reward=samples.env.reward, done=samples.env.done, timeout=getattr(samples.env.env_info, "timeout", None), ) def examples_to_buffer(self, examples): """Defines how to initialize the replay buffer from examples. Called in initialize_replay_buffer(). """ return SamplesToBuffer( observation=examples["observation"], action=examples["action"], reward=examples["reward"], done=examples["done"], timeout=getattr(examples["env_info"], "timeout", None), ) def loss(self, samples): """Samples have leading batch dimension [B,..] 
(but not time).""" agent_inputs, target_inputs, action = buffer_to( (samples.agent_inputs, samples.target_inputs, samples.action)) q1, q2 = self.agent.q(*agent_inputs, action) with torch.no_grad(): target_v = self.agent.target_v(*target_inputs) disc = self.discount**self.n_step_return y = (self.reward_scale * samples.return_ + (1 - samples.done_n.float()) * disc * target_v) if self.mid_batch_reset and not self.agent.recurrent: valid = torch.ones_like(samples.done, dtype=torch.float) else: valid = valid_from_done(samples.done) if self.bootstrap_timelimit: # To avoid non-use of bootstrap when environment is 'done' due to # time-limit, turn off training on these samples. valid *= (1 - samples.timeout_n.float()) q1_loss = 0.5 * valid_mean((y - q1)**2, valid) q2_loss = 0.5 * valid_mean((y - q2)**2, valid) v = self.agent.v(*agent_inputs) new_action, log_pi, (pi_mean, pi_log_std) = self.agent.pi(*agent_inputs) if not self.reparameterize: new_action = new_action.detach() # No grad. log_target1, log_target2 = self.agent.q(*agent_inputs, new_action) min_log_target = torch.min(log_target1, log_target2) prior_log_pi = self.get_action_prior(new_action.cpu()) v_target = (min_log_target - log_pi + prior_log_pi).detach() # No grad. v_loss = 0.5 * valid_mean((v - v_target)**2, valid) if self.reparameterize: pi_losses = log_pi - min_log_target else: pi_factor = (v - v_target).detach() pi_losses = log_pi * pi_factor if self.policy_output_regularization > 0: pi_losses += self.policy_output_regularization * torch.mean( 0.5 * pi_mean**2 + 0.5 * pi_log_std**2, dim=-1) pi_loss = valid_mean(pi_losses, valid) losses = (q1_loss, q2_loss, v_loss, pi_loss) values = tuple(val.detach() for val in (q1, q2, v, pi_mean, pi_log_std)) return losses, values # def q_loss(self, samples): # """Samples have leading batch dimension [B,..] (but not time).""" # agent_inputs, target_inputs, action = buffer_to( # (samples.agent_inputs, samples.target_inputs, samples.action), # device=self.agent.device) # Move to device once, re-use. # q1, q2 = self.agent.q(*agent_inputs, action) # with torch.no_grad(): # target_v = self.agent.target_v(*target_inputs) # disc = self.discount ** self.n_step_return # y = (self.reward_scale * samples.return_ + # (1 - samples.done_n.float()) * disc * target_v) # if self.mid_batch_reset and not self.agent.recurrent: # valid = None # OR: torch.ones_like(samples.done, dtype=torch.float) # else: # valid = valid_from_done(samples.done) # q1_loss = 0.5 * valid_mean((y - q1) ** 2, valid) # q2_loss = 0.5 * valid_mean((y - q2) ** 2, valid) # losses = (q1_loss, q2_loss) # values = tuple(val.detach() for val in (q1, q2)) # return losses, values, agent_inputs, valid # def pi_v_loss(self, agent_inputs, valid): # v = self.agent.v(*agent_inputs) # new_action, log_pi, (pi_mean, pi_log_std) = self.agent.pi(*agent_inputs) # if not self.reparameterize: # new_action = new_action.detach() # No grad. # log_target1, log_target2 = self.agent.q(*agent_inputs, new_action) # min_log_target = torch.min(log_target1, log_target2) # prior_log_pi = self.get_action_prior(new_action.cpu()) # v_target = (min_log_target - log_pi + prior_log_pi).detach() # No grad. # v_loss = 0.5 * valid_mean((v - v_target) ** 2, valid) # if self.reparameterize: # pi_losses = log_pi - min_log_target # log_target1 # min_log_target # else: # pi_factor = (v - v_target).detach() # No grad. 
# pi_losses = log_pi * pi_factor # if self.policy_output_regularization > 0: # pi_losses += self.policy_output_regularization * torch.sum( # 0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1) # pi_loss = valid_mean(pi_losses, valid) # losses = (v_loss, pi_loss) # values = tuple(val.detach() for val in (v, pi_mean, pi_log_std)) # return losses, values # def loss(self, samples): # """Samples have leading batch dimension [B,..] (but not time).""" # agent_inputs, target_inputs, action = buffer_to( # (samples.agent_inputs, samples.target_inputs, samples.action), # device=self.agent.device) # Move to device once, re-use. # q1, q2 = self.agent.q(*agent_inputs, action) # with torch.no_grad(): # target_v = self.agent.target_v(*target_inputs) # disc = self.discount ** self.n_step_return # y = (self.reward_scale * samples.return_ + # (1 - samples.done_n.float()) * disc * target_v) # if self.mid_batch_reset and not self.agent.recurrent: # valid = None # OR: torch.ones_like(samples.done, dtype=torch.float) # else: # valid = valid_from_done(samples.done) # q1_loss = 0.5 * valid_mean((y - q1) ** 2, valid) # q2_loss = 0.5 * valid_mean((y - q2) ** 2, valid) # v = self.agent.v(*agent_inputs) # new_action, log_pi, (pi_mean, pi_log_std) = self.agent.pi(*agent_inputs) # if not self.reparameterize: # new_action = new_action.detach() # No grad. # log_target1, log_target2 = self.agent.q(*agent_inputs, new_action) # min_log_target = torch.min(log_target1, log_target2) # prior_log_pi = self.get_action_prior(new_action.cpu()) # v_target = (min_log_target - log_pi + prior_log_pi).detach() # No grad. # v_loss = 0.5 * valid_mean((v - v_target) ** 2, valid) # if self.reparameterize: # pi_losses = log_pi - min_log_target # log_target1 # else: # pi_factor = (v - v_target).detach() # No grad. # pi_losses = log_pi * pi_factor # if self.policy_output_regularization > 0: # pi_losses += torch.sum(self.policy_output_regularization * 0.5 * # pi_mean ** 2 + pi_log_std ** 2, dim=-1) # pi_loss = valid_mean(pi_losses, valid) # losses = (q1_loss, q2_loss, v_loss, pi_loss) # values = tuple(val.detach() for val in (q1, q2, v, pi_mean, pi_log_std)) # return losses, values def get_action_prior(self, action): if self.action_prior == "uniform": prior_log_pi = 0.0 elif self.action_prior == "gaussian": prior_log_pi = self.action_prior_distribution.log_likelihood( action, GaussianDistInfo(mean=torch.zeros_like(action))) return prior_log_pi def append_opt_info_(self, opt_info, losses, grad_norms, values): """In-place.""" q1_loss, q2_loss, v_loss, pi_loss = losses q1_grad_norm, q2_grad_norm, v_grad_norm, pi_grad_norm = grad_norms q1, q2, v, pi_mean, pi_log_std = values opt_info.q1Loss.append(q1_loss.item()) opt_info.q2Loss.append(q2_loss.item()) opt_info.vLoss.append(v_loss.item()) opt_info.piLoss.append(pi_loss.item()) opt_info.q1GradNorm.append( torch.tensor(q1_grad_norm).item()) # backwards compatible opt_info.q2GradNorm.append( torch.tensor(q2_grad_norm).item()) # backwards compatible opt_info.vGradNorm.append( torch.tensor(v_grad_norm).item()) # backwards compatible opt_info.piGradNorm.append( torch.tensor(pi_grad_norm).item()) # backwards compatible opt_info.q1.extend(q1[::10].numpy()) # Downsample for stats. 
        opt_info.q2.extend(q2[::10].numpy())
        opt_info.v.extend(v[::10].numpy())
        opt_info.piMu.extend(pi_mean[::10].numpy())
        opt_info.piLogStd.extend(pi_log_std[::10].numpy())
        opt_info.qMeanDiff.append(torch.mean(abs(q1 - q2)).item())

    def optim_state_dict(self):
        return dict(
            pi_optimizer=self.pi_optimizer.state_dict(),
            q1_optimizer=self.q1_optimizer.state_dict(),
            q2_optimizer=self.q2_optimizer.state_dict(),
            v_optimizer=self.v_optimizer.state_dict(),
        )

    def load_optim_state_dict(self, state_dict):
        self.pi_optimizer.load_state_dict(state_dict["pi_optimizer"])
        self.q1_optimizer.load_state_dict(state_dict["q1_optimizer"])
        self.q2_optimizer.load_state_dict(state_dict["q2_optimizer"])
        self.v_optimizer.load_state_dict(state_dict["v_optimizer"])
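# Reference for the updates implemented in SAC_V.loss() above (SAC with a
# separate state-value function; disc = discount ** n_step_return):
#   y        = reward_scale * return_ + (1 - done_n) * disc * V_target(s')
#   Q_i loss = 0.5 * mean[(y - Q_i(s, a)) ** 2],  i = 1, 2
#   V target = min_i Q_i(s, a~) - log pi(a~|s) + prior_log_pi,  a~ ~ pi(.|s)
#   V loss   = 0.5 * mean[(V(s) - V target) ** 2]
#   pi loss  = mean[log pi(a~|s) - min_i Q_i(s, a~)]  (reparameterized path,
#              plus optional policy_output_regularization)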
class SACDiscrete(RlAlgorithm): """Soft actor critic algorithm, training from a replay buffer.""" opt_info_fields = tuple(f for f in OptInfo._fields) # copy def __init__( self, discount=0.99, batch_size=256, min_steps_learn=int(1e4), replay_size=int(1e6), replay_ratio=256, # data_consumption / data_generation target_update_tau=0.005, # tau=1 for hard update. target_update_interval=1, # 1000 for hard update, 1 for soft. learning_rate=3e-4, fixed_alpha=None, # None for adaptive alpha, float for any fixed value OptimCls=torch.optim.Adam, optim_kwargs=None, initial_optim_state_dict=None, # for all of them. action_prior="uniform", # or "gaussian" reward_scale=1, target_entropy="auto", # "auto", float, or None reparameterize=True, clip_grad_norm=1e9, # policy_output_regularization=0.001, n_step_return=1, updates_per_sync=1, # For async mode only. bootstrap_timelimit=False, ReplayBufferCls=None, # Leave None to select by above options. ): """Save input arguments.""" if optim_kwargs is None: optim_kwargs = dict() assert action_prior in ["uniform", "gaussian"] self._batch_size = batch_size del batch_size # Property. save__init__args(locals()) def initialize(self, agent, n_itr, batch_spec, mid_batch_reset, examples, world_size=1, rank=0): """Stores input arguments and initializes replay buffer and optimizer. Use in non-async runners. Computes number of gradient updates per optimization iteration as `(replay_ratio * sampler-batch-size / training-batch_size)`.""" self.agent = agent self.n_itr = n_itr self.mid_batch_reset = mid_batch_reset self.sampler_bs = sampler_bs = batch_spec.size self.updates_per_optimize = int(self.replay_ratio * sampler_bs / self.batch_size) logger.log( f"From sampler batch size {sampler_bs}, training " f"batch size {self.batch_size}, and replay ratio " f"{self.replay_ratio}, computed {self.updates_per_optimize} " f"updates per iteration.") self.min_itr_learn = self.min_steps_learn // sampler_bs agent.give_min_itr_learn(self.min_itr_learn) self.initialize_replay_buffer(examples, batch_spec) self.optim_initialize(rank) def async_initialize(self, agent, sampler_n_itr, batch_spec, mid_batch_reset, examples, world_size=1): """Used in async runner only; returns replay buffer allocated in shared memory, does not instantiate optimizer. 
""" self.agent = agent self.n_itr = sampler_n_itr self.initialize_replay_buffer(examples, batch_spec, async_=True) self.mid_batch_reset = mid_batch_reset self.sampler_bs = sampler_bs = batch_spec.size self.updates_per_optimize = self.updates_per_sync self.min_itr_learn = int(self.min_steps_learn // sampler_bs) agent.give_min_itr_learn(self.min_itr_learn) return self.replay_buffer def optim_initialize(self, rank=0): """Called in initilize or by async runner after forking sampler.""" self.rank = rank self.pi_optimizer = self.OptimCls(self.agent.pi_parameters(), lr=self.learning_rate, **self.optim_kwargs) self.q1_optimizer = self.OptimCls(self.agent.q1_parameters(), lr=self.learning_rate, **self.optim_kwargs) self.q2_optimizer = self.OptimCls(self.agent.q2_parameters(), lr=self.learning_rate, **self.optim_kwargs) if self.fixed_alpha is None: self.target_entropy = -np.log( (1.0 / self.agent.env_spaces.action.n)) * 0.98 self._log_alpha = torch.zeros(1, requires_grad=True) self._alpha = self._log_alpha.exp() self.alpha_optimizer = self.OptimCls((self._log_alpha, ), lr=self.learning_rate, **self.optim_kwargs) else: self._log_alpha = torch.tensor([np.log(self.fixed_alpha)]) self._alpha = torch.tensor([self.fixed_alpha]) self.alpha_optimizer = None if self.target_entropy == "auto": self.target_entropy = -np.prod(self.agent.env_spaces.action.n) if self.initial_optim_state_dict is not None: self.load_optim_state_dict(self.initial_optim_state_dict) if self.action_prior == "gaussian": self.action_prior_distribution = Gaussian(dim=np.prod( self.agent.env_spaces.action.shape), std=1.) def initialize_replay_buffer(self, examples, batch_spec, async_=False): """ Allocates replay buffer using examples and with the fields in `SamplesToBuffer` namedarraytuple. """ example_to_buffer = SamplesToBuffer( observation=examples["observation"], action=examples["action"], reward=examples["reward"], done=examples["done"], ) if not self.bootstrap_timelimit: ReplayCls = AsyncUniformReplayBuffer if async_ else UniformReplayBuffer else: example_to_buffer = SamplesToBufferTl( *example_to_buffer, timeout=examples["env_info"].timeout) ReplayCls = AsyncTlUniformReplayBuffer if async_ else TlUniformReplayBuffer replay_kwargs = dict( example=example_to_buffer, size=self.replay_size, B=batch_spec.B, n_step_return=self.n_step_return, ) if self.ReplayBufferCls is not None: ReplayCls = self.ReplayBufferCls logger.log( f"WARNING: ignoring internal selection logic and using" f" input replay buffer class: {ReplayCls} -- compatibility not" " guaranteed.") self.replay_buffer = ReplayCls(**replay_kwargs) def optimize_agent(self, itr, samples=None, sampler_itr=None): """ Extracts the needed fields from input samples and stores them in the replay buffer. Then samples from the replay buffer to train the agent by gradient updates (with the number of updates determined by replay ratio, sampler batch size, and training batch size). """ itr = itr if sampler_itr is None else sampler_itr # Async uses sampler_itr. 
if samples is not None: samples_to_buffer = self.samples_to_buffer(samples) self.replay_buffer.append_samples(samples_to_buffer) opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields)))) if itr < self.min_itr_learn: return opt_info for _ in range(self.updates_per_optimize): samples_from_replay = self.replay_buffer.sample_batch( self.batch_size) losses, values = self.loss(samples_from_replay) q1_loss, q2_loss, pi_loss, alpha_loss = losses if alpha_loss is not None: self.alpha_optimizer.zero_grad() alpha_loss.backward() self.alpha_optimizer.step() self._alpha = torch.exp(self._log_alpha.detach()) self.pi_optimizer.zero_grad() pi_loss.backward() pi_grad_norm = torch.nn.utils.clip_grad_norm_( self.agent.pi_parameters(), self.clip_grad_norm) self.pi_optimizer.step() # Step Q's last because pi_loss.backward() uses them? self.q1_optimizer.zero_grad() q1_loss.backward() q1_grad_norm = torch.nn.utils.clip_grad_norm_( self.agent.q1_parameters(), self.clip_grad_norm) self.q1_optimizer.step() self.q2_optimizer.zero_grad() q2_loss.backward() q2_grad_norm = torch.nn.utils.clip_grad_norm_( self.agent.q2_parameters(), self.clip_grad_norm) self.q2_optimizer.step() grad_norms = (q1_grad_norm, q2_grad_norm, pi_grad_norm) self.append_opt_info_(opt_info, losses, grad_norms, values) self.update_counter += 1 if self.update_counter % self.target_update_interval == 0: self.agent.update_target(self.target_update_tau) return opt_info def samples_to_buffer(self, samples): """Defines how to add data from sampler into the replay buffer. Called in optimize_agent() if samples are provided to that method.""" samples_to_buffer = SamplesToBuffer( observation=samples.env.observation, action=samples.agent.action, reward=samples.env.reward, done=samples.env.done, ) if self.bootstrap_timelimit: samples_to_buffer = SamplesToBufferTl( *samples_to_buffer, timeout=samples.env.env_info.timeout) return samples_to_buffer def loss(self, samples): """ Computes losses for twin Q-values against the min of twin target Q-values and an entropy term. Computes reparameterized policy loss, and loss for tuning entropy weighting, alpha. Input samples have leading batch dimension [B,..] (but not time). """ agent_inputs, target_inputs, action = buffer_to( (samples.agent_inputs, samples.target_inputs, samples.action)) if self.mid_batch_reset and not self.agent.recurrent: valid = torch.ones_like(samples.done, dtype=torch.float) # or None else: valid = valid_from_done(samples.done) if self.bootstrap_timelimit: # To avoid non-use of bootstrap when environment is 'done' due to # time-limit, turn off training on these samples. 
valid *= (1 - samples.timeout_n.float()) with torch.no_grad(): target_action, target_action_probs, target_log_pi, _ = self.agent.pi( *target_inputs) target_q1, target_q2 = self.agent.target_q(*target_inputs, target_action) min_target_q = torch.min(target_q1, target_q2) target_value = target_action_probs * (min_target_q - self._alpha * target_log_pi) target_value = target_value.sum(dim=1).unsqueeze(-1) disc = self.discount**self.n_step_return y = self.reward_scale * samples.return_ + ( 1 - samples.done_n.float()) * disc * target_value q1, q2 = self.agent.q(*agent_inputs, action) q1 = torch.gather(q1, 1, action.unsqueeze(1).long()) q2 = torch.gather(q2, 1, action.unsqueeze(1).long()) q1_loss = 0.5 * valid_mean((y - q1)**2, valid) q2_loss = 0.5 * valid_mean((y - q2)**2, valid) action, action_probs, log_pi, _ = self.agent.pi(*agent_inputs) q1_pi, q2_pi = self.agent.q(*agent_inputs, action) min_pi_target = torch.min(q1_pi, q2_pi) inside_term = self._alpha * log_pi - min_pi_target policy_loss = (action_probs * inside_term).sum(dim=1).mean() log_pi = torch.sum(log_pi * action_probs, dim=1) # if self.policy_output_regularization > 0: # pi_losses += self.policy_output_regularization * torch.mean( # 0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1) pi_loss = valid_mean(policy_loss, valid) if self.target_entropy is not None and self.fixed_alpha is None: alpha_losses = -self._log_alpha * (log_pi.detach() + self.target_entropy) alpha_loss = valid_mean(alpha_losses, valid) else: alpha_loss = None losses = (q1_loss, q2_loss, pi_loss, alpha_loss) values = tuple(val.detach() for val in (q1, q2, action_probs)) return losses, values def get_action_prior(self, action): if self.action_prior == "uniform": prior_log_pi = 0.0 elif self.action_prior == "gaussian": prior_log_pi = self.action_prior_distribution.log_likelihood( action, GaussianDistInfo(mean=torch.zeros_like(action))) return prior_log_pi def append_opt_info_(self, opt_info, losses, grad_norms, values): """In-place.""" q1_loss, q2_loss, pi_loss, alpha_loss = losses q1_grad_norm, q2_grad_norm, pi_grad_norm = grad_norms q1, q2, action_probs = values opt_info.q1Loss.append(q1_loss.item()) opt_info.q2Loss.append(q2_loss.item()) opt_info.piLoss.append(pi_loss.item()) opt_info.q1GradNorm.append( torch.tensor(q1_grad_norm).item()) # backwards compatible opt_info.q2GradNorm.append( torch.tensor(q2_grad_norm).item()) # backwards compatible opt_info.piGradNorm.append( torch.tensor(pi_grad_norm).item()) # backwards compatible opt_info.q1.extend(q1[::10].numpy()) # Downsample for stats. opt_info.q2.extend(q2[::10].numpy()) opt_info.qMeanDiff.append(torch.mean(abs(q1 - q2)).item()) opt_info.alpha.append(self._alpha.item()) def optim_state_dict(self): return dict( pi_optimizer=self.pi_optimizer.state_dict(), q1_optimizer=self.q1_optimizer.state_dict(), q2_optimizer=self.q2_optimizer.state_dict(), alpha_optimizer=self.alpha_optimizer.state_dict() if self.alpha_optimizer else None, log_alpha=self._log_alpha.detach().item(), ) def load_optim_state_dict(self, state_dict): self.pi_optimizer.load_state_dict(state_dict["pi_optimizer"]) self.q1_optimizer.load_state_dict(state_dict["q1_optimizer"]) self.q2_optimizer.load_state_dict(state_dict["q2_optimizer"]) if self.alpha_optimizer is not None and state_dict[ "alpha_optimizer"] is not None: self.alpha_optimizer.load_state_dict(state_dict["alpha_optimizer"]) with torch.no_grad(): self._log_alpha[:] = state_dict["log_alpha"] self._alpha = torch.exp(self._log_alpha.detach())
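# Reference for SACDiscrete.loss() above: with a discrete action space the
# expectations over actions are taken exactly using the policy's action
# probabilities rather than by sampling:
#   target_value = sum_a pi(a|s') * (min_i Q_target_i(s', a) - alpha * log pi(a|s'))
#   y            = reward_scale * return_ + (1 - done_n) * discount**n * target_value
#   pi loss      = mean_s sum_a pi(a|s) * (alpha * log pi(a|s) - min_i Q_i(s, a))
#   alpha loss   = -mean[log_alpha * (sum_a pi(a|s) log pi(a|s) + target_entropy)]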
class SAC(RlAlgorithm): opt_info_fields = None def __init__( self, discount=0.99, batch_size=256, min_steps_learn=int(1e4), replay_size=int(6e5), replay_ratio=256, # data_consumption / data_generation target_update_tau=0.005, # tau=1 for hard update. target_update_interval=1, # interval=1000 for hard update. learning_rate=3e-4, OptimCls=torch.optim.Adam, optim_kwargs=None, initial_optim_state_dict=None, # for pi only. action_prior="uniform", # or "gaussian" reward_scale=1, reparameterize=True, clip_grad_norm=1e9, policy_output_regularization=0.001, n_step_return=1, updates_per_sync=1, # For async mode only. target_entropy='auto', ): if optim_kwargs is None: optim_kwargs = dict() assert action_prior in ["uniform", "gaussian"] self._batch_size = batch_size del batch_size # Property. save__init__args(locals()) def initialize(self, agent, n_itr, batch_spec, mid_batch_reset, examples, world_size=1, rank=0): """Used in basic or synchronous multi-GPU runners, not async.""" self.agent = agent self.n_itr = n_itr self.mid_batch_reset = mid_batch_reset self.sampler_bs = sampler_bs = batch_spec.size self.updates_per_optimize = int(self.replay_ratio * sampler_bs / self.batch_size) logger.log(f"From sampler batch size {sampler_bs}, training " f"batch size {self.batch_size}, and replay ratio " f"{self.replay_ratio}, computed {self.updates_per_optimize} " f"updates per iteration.") self.min_itr_learn = self.min_steps_learn // sampler_bs agent.give_min_itr_learn(self.min_itr_learn) print('batch_spec:', batch_spec, '\n\n') self.initialize_replay_buffer(examples, batch_spec) self.optim_initialize(rank) if self.target_entropy == 'auto': self.target_entropy = -np.prod(self.agent.env_spaces.action.shape) keys = ["piLoss", "alphaLoss", "piMu", "piLogStd", "alpha", "piGradNorm"] keys += [f'q{i}GradNorm' for i in range(self.agent.n_qs)] keys += [f'q{i}' for i in range(self.agent.n_qs)] keys += [f'q{i}Loss' for i in range(self.agent.n_qs)] global OptInfo OptInfo = namedtuple('OptInfo', keys) SAC.opt_info_fields = tuple(f for f in OptInfo._fields) # copy def async_initialize(self, agent, sampler_n_itr, batch_spec, mid_batch_reset, examples, world_size=1): """Used in async runner only.""" self.agent = agent self.n_itr = sampler_n_itr self.initialize_replay_buffer(examples, batch_spec, async_=True) self.mid_batch_reset = mid_batch_reset self.sampler_bs = sampler_bs = batch_spec.size self.updates_per_optimize = self.updates_per_sync self.min_itr_learn = int(self.min_steps_learn // sampler_bs) agent.give_min_itr_learn(self.min_itr_learn) return self.replay_buffer def optim_initialize(self, rank=0): """Called by async runner.""" self.rank = rank self.pi_optimizer = self.OptimCls(self.agent.pi_parameters(), lr=self.learning_rate, **self.optim_kwargs) self.q_optimizers = [self.OptimCls(q_param) for q_param in self.agent.q_parameters()] self.alpha_optimizer = self.OptimCls([self.agent.log_alpha], lr=self.learning_rate, **self.optim_kwargs) if self.initial_optim_state_dict is not None: self.pi_optimizer.load_state_dict(self.initial_optim_state_dict) if self.action_prior == "gaussian": self.action_prior_distribution = Gaussian( dim=self.agent.env_spaces.action.size, std=1.) 
def initialize_replay_buffer(self, examples, batch_spec, async_=False): example_to_buffer = SamplesToBuffer( observation=examples["observation"], action=examples["action"], reward=examples["reward"], done=examples["done"], ) replay_kwargs = dict( example=example_to_buffer, size=self.replay_size, B=batch_spec.B, n_step_return=self.n_step_return, ) ReplayCls = AsyncUniformReplayBuffer if async_ else UniformReplayBuffer self.replay_buffer = ReplayCls(**replay_kwargs) def optimize_agent(self, itr, samples=None, sampler_itr=None): itr = itr if sampler_itr is None else sampler_itr # Async uses sampler_itr. if samples is not None: samples_to_buffer = self.samples_to_buffer(samples) self.replay_buffer.append_samples(samples_to_buffer) opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields)))) if itr < self.min_itr_learn: return opt_info for _ in range(self.updates_per_optimize): samples_from_replay = self.replay_buffer.sample_batch(self.batch_size) q_losses, losses, values, q_values = self.loss(samples_from_replay) pi_loss, alpha_loss = losses self.pi_optimizer.zero_grad() pi_loss.backward() pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.agent.pi_parameters(), self.clip_grad_norm) self.pi_optimizer.step() self.alpha_optimizer.zero_grad() alpha_loss.backward() self.alpha_optimizer.step() q_grad_norms = [] for q_opt, q_loss, q_param in zip(self.q_optimizers, q_losses, self.agent.q_parameters()): q_opt.zero_grad() q_loss.backward() q_grad_norm = torch.nn.utils.clip_grad_norm_(q_param, self.clip_grad_norm) q_opt.step() q_grad_norms.append(q_grad_norm) self.append_opt_info_(opt_info, q_losses, losses, q_grad_norms, pi_grad_norm, q_values, values) self.update_counter += 1 if self.update_counter % self.target_update_interval == 0: self.agent.update_target(self.target_update_tau) return opt_info def samples_to_buffer(self, samples): return SamplesToBuffer( observation=samples.env.observation, action=samples.agent.action, reward=samples.env.reward, done=samples.env.done, ) def loss(self, samples): """Samples have leading batch dimension [B,..] (but not time).""" agent_inputs, target_inputs, action = buffer_to( (samples.agent_inputs, samples.target_inputs, samples.action)) qs = self.agent.q(*agent_inputs, action) with torch.no_grad(): target_v = self.agent.target_v(*target_inputs).detach() disc = self.discount ** self.n_step_return y = (self.reward_scale * samples.return_ + (1 - samples.done_n.float()) * disc * target_v) if self.mid_batch_reset and not self.agent.recurrent: valid = None # OR: torch.ones_like(samples.done, dtype=torch.float) else: valid = valid_from_done(samples.done) q_losses = [0.5 * valid_mean((y - q) ** 2, valid) for q in qs] new_action, log_pi, (pi_mean, pi_log_std) = self.agent.pi(*agent_inputs) if not self.reparameterize: new_action = new_action.detach() # No grad. 
        log_targets = self.agent.q(*agent_inputs, new_action)
        # Element-wise minimum over the ensemble of Q estimates.
        min_log_target = torch.min(torch.stack(log_targets, dim=0), dim=0)[0]
        prior_log_pi = self.get_action_prior(new_action.cpu())
        if self.reparameterize:
            alpha = self.agent.log_alpha.exp().detach()
            pi_losses = alpha * log_pi - min_log_target - prior_log_pi
        else:
            raise NotImplementedError  # Non-reparameterized pi loss not supported here.
        if self.policy_output_regularization > 0:
            pi_losses += self.policy_output_regularization * torch.sum(
                0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1)
        pi_loss = valid_mean(pi_losses, valid)
        # Calculate log_alpha loss.
        alpha_loss = -valid_mean(self.agent.log_alpha *
            (log_pi + self.target_entropy).detach(), valid)
        losses = (pi_loss, alpha_loss)
        values = tuple(val.detach() for val in (pi_mean, pi_log_std, alpha))
        q_values = tuple(q.detach() for q in qs)
        return q_losses, losses, values, q_values

    def get_action_prior(self, action):
        if self.action_prior == "uniform":
            prior_log_pi = 0.0
        elif self.action_prior == "gaussian":
            prior_log_pi = self.action_prior_distribution.log_likelihood(
                action, GaussianDistInfo(mean=torch.zeros_like(action)))
        return prior_log_pi

    def append_opt_info_(self, opt_info, q_losses, losses, q_grad_norms,
            pi_grad_norm, q_values, values):
        """In-place."""
        pi_loss, alpha_loss = losses
        pi_mean, pi_log_std, alpha = values
        for i in range(self.agent.n_qs):
            getattr(opt_info, f'q{i}Loss').append(q_losses[i].item())
            getattr(opt_info, f'q{i}').extend(q_values[i][::10].numpy())
            getattr(opt_info, f'q{i}GradNorm').append(q_grad_norms[i])
        opt_info.piLoss.append(pi_loss.item())
        opt_info.alphaLoss.append(alpha_loss.item())
        opt_info.piGradNorm.append(pi_grad_norm)
        opt_info.piMu.extend(pi_mean[::10].numpy())
        opt_info.piLogStd.extend(pi_log_std[::10].numpy())
        opt_info.alpha.append(alpha.numpy())

    def optim_state_dict(self):
        rtn = dict(
            pi_optimizer=self.pi_optimizer.state_dict(),
            alpha_optimizer=self.alpha_optimizer.state_dict(),
        )
        rtn.update({f'q{i}_optimizer': q_opt.state_dict()
            for i, q_opt in enumerate(self.q_optimizers)})
        return rtn

    def load_optim_state_dict(self, state_dict):
        self.pi_optimizer.load_state_dict(state_dict["pi_optimizer"])
        self.alpha_optimizer.load_state_dict(state_dict["alpha_optimizer"])
        for i, q_opt in enumerate(self.q_optimizers):
            q_opt.load_state_dict(state_dict[f'q{i}_optimizer'])
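# The SAC variant above keeps an ensemble of n_qs critics (one optimizer each)
# and generalizes the usual twin-Q minimum to an element-wise minimum over the
# whole ensemble. A minimal illustration of the stack-and-min pattern used in
# loss(), with made-up tensors standing in for the agent's Q outputs:
#   qs = [torch.randn(256) for _ in range(n_qs)]          # one value per batch element, per critic
#   min_q = torch.min(torch.stack(qs, dim=0), dim=0)[0]   # element-wise min across critics, shape [256]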
class SAC_LSTM(RlAlgorithm): """Soft actor critic algorithm, training from a replay buffer.""" opt_info_fields = tuple(f for f in OptInfo._fields) # copy def __init__( self, discount=0.99, batch_T=80, batch_B=16, warmup_T=40, min_steps_learn=int(1e5), replay_size=int(1e6), replay_ratio=4, # data_consumption / data_generation store_rnn_state_interval=40, target_update_tau=0.005, # tau=1 for hard update. target_update_interval=1, # 1000 for hard update, 1 for soft. learning_rate=3e-4, fixed_alpha=None, # None for adaptive alpha, float for any fixed value OptimCls=torch.optim.Adam, optim_kwargs=None, initial_optim_state_dict=None, # for all of them. initial_replay_buffer_dict=None, action_prior="uniform", # or "gaussian" reward_scale=1, target_entropy="auto", # "auto", float, or None reparameterize=True, clip_grad_norm=1e3, n_step_return=5, ReplayBufferCls=None, # Leave None to select by above options. ): """ Save input arguments. Args: store_rnn_state_interval (int): store RNN state only once this many steps, to reduce memory usage; replay sequences will only begin at the steps with stored recurrent state. Note: Typically ran with ``store_rnn_state_interval`` equal to the sampler's ``batch_T``, 40. Then every 40 steps can be the beginning of a replay sequence, and will be guaranteed to start with a valid RNN state. Only reset the RNN state (and env) at the end of the sampler batch, so that the beginnings of episodes are trained on. """ if optim_kwargs is None: optim_kwargs = dict() assert action_prior in ["uniform", "gaussian"] save__init__args(locals()) self._batch_size = (self.batch_T + self.warmup_T) * self.batch_B def initialize(self, agent, n_itr, batch_spec, mid_batch_reset, examples, world_size=1, rank=0): """Stores input arguments and initializes replay buffer and optimizer. Use in non-async runners. 
Computes number of gradient updates per optimization iteration as `(replay_ratio * sampler-batch-size / training-batch_size)`.""" self.agent = agent self.n_itr = n_itr # num_itr self.sampler_bs = sampler_bs = batch_spec.size # num_step_per_batch self.mid_batch_reset = mid_batch_reset # True self.updates_per_optimize = max( 1, round(self.replay_ratio * sampler_bs / self._batch_size)) logger.log( f"From sampler batch size {batch_spec.size}, training " f"batch size {self._batch_size}, and replay ratio " f"{self.replay_ratio}, computed {self.updates_per_optimize} " f"updates per iteration.") self.min_itr_learn = int(self.min_steps_learn // sampler_bs) agent.give_min_itr_learn( self.min_itr_learn) # filling up replay_buffer self.initialize_replay_buffer(examples, batch_spec) self.optim_initialize(rank) def optim_initialize(self, rank=0): """Called in initilize or by async runner after forking sampler.""" self.rank = rank self.pi_optimizer = self.OptimCls(self.agent.pi_parameters(), lr=self.learning_rate, **self.optim_kwargs) self.q1_optimizer = self.OptimCls(self.agent.q1_parameters(), lr=self.learning_rate, **self.optim_kwargs) self.q2_optimizer = self.OptimCls(self.agent.q2_parameters(), lr=self.learning_rate, **self.optim_kwargs) if self.fixed_alpha is None: self._log_alpha = torch.zeros(1, requires_grad=True) self._alpha = torch.exp(self._log_alpha.detach()) self.alpha_optimizer = self.OptimCls((self._log_alpha, ), lr=self.learning_rate, **self.optim_kwargs) else: self._log_alpha = torch.tensor([np.log(self.fixed_alpha)]) self._alpha = torch.tensor([self.fixed_alpha]) self.alpha_optimizer = None if self.target_entropy == "auto": self.target_entropy = -np.prod(self.agent.env_spaces.action.shape) if self.initial_optim_state_dict is not None: self.load_optim_state_dict(self.initial_optim_state_dict) if self.action_prior == "gaussian": self.action_prior_distribution = Gaussian(dim=np.prod( self.agent.env_spaces.action.shape), std=1.) def initialize_replay_buffer(self, examples, batch_spec): """ Allocates replay buffer using examples and with the fields in `SamplesToBuffer` namedarraytuple. """ # hidden_in, hidden_out, state, action, last_action, reward, next_state, done = self.replay_buffer.sample(batch_size) # print(examples["agent_info"]) example_to_buffer = SamplesToBufferLSTM( observation=examples["observation"], action=examples["action"], reward=examples["reward"], done=examples["done"], prev_rnn_state=examples["agent_info"].prev_rnn_state, ) ReplayCls = UniformSequenceReplayFrameBuffer replay_kwargs = dict( example=example_to_buffer, size=self.replay_size, B=batch_spec.B, discount=self.discount, n_step_return=self.n_step_return, rnn_state_interval=self.store_rnn_state_interval, initial_replay_buffer_dict=self.initial_replay_buffer_dict, batch_T=self.batch_T + self.warmup_T, ) self.replay_buffer = ReplayCls(**replay_kwargs) def optimize_agent(self, itr, samples=None): """ Extracts the needed fields from input samples and stores them in the replay buffer. Then samples from the replay buffer to train the agent by gradient updates (with the number of updates determined by replay ratio, sampler batch size, and training batch size). 
""" # Update replay buffer if samples is not None: samples_to_buffer = self.samples_to_buffer(samples) self.replay_buffer.append_samples(samples_to_buffer) opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields)))) if itr < self.min_itr_learn: return opt_info for _ in range(self.updates_per_optimize): samples_from_replay = self.replay_buffer.sample_batch(self.batch_B) losses, values = self.loss(samples_from_replay) q1_loss, q2_loss, pi_loss, alpha_loss = losses if alpha_loss is not None: self.alpha_optimizer.zero_grad() alpha_loss.backward() self.alpha_optimizer.step() self._alpha = torch.exp(self._log_alpha.detach()) self.pi_optimizer.zero_grad() pi_loss.backward() pi_grad_norm = torch.nn.utils.clip_grad_norm_( self.agent.pi_parameters(), self.clip_grad_norm) self.pi_optimizer.step() # Step Q's last because pi_loss.backward() uses them? self.q1_optimizer.zero_grad() q1_loss.backward() q1_grad_norm = torch.nn.utils.clip_grad_norm_( self.agent.q1_parameters(), self.clip_grad_norm) self.q1_optimizer.step() self.q2_optimizer.zero_grad() q2_loss.backward() q2_grad_norm = torch.nn.utils.clip_grad_norm_( self.agent.q2_parameters(), self.clip_grad_norm) self.q2_optimizer.step() grad_norms = (q1_grad_norm, q2_grad_norm, pi_grad_norm) self.append_opt_info_(opt_info, losses, grad_norms, values) self.update_counter += 1 if self.update_counter % self.target_update_interval == 0: self.agent.update_target(self.target_update_tau) return opt_info def samples_to_buffer(self, samples): """Defines how to add data from sampler into the replay buffer. Called in optimize_agent() if samples are provided to that method.""" samples_to_buffer = SamplesToBufferLSTM( observation=samples.env.observation, action=samples.agent.action, reward=samples.env.reward, done=samples.env.done, prev_rnn_state=samples.agent.agent_info.prev_rnn_state, ) return samples_to_buffer def loss(self, samples): """ Computes losses for twin Q-values against the min of twin target Q-values and an entropy term. Computes reparameterized policy loss, and loss for tuning entropy weighting, alpha. Input samples have leading batch dimension [B,..] (but not time). """ # SamplesFromReplay = namedarraytuple("SamplesFromReplay", # ["all_observation", "all_action", "all_reward", "return_", "done", "done_n", "init_rnn_state"]) all_observation, all_action, all_reward = buffer_to( (samples.all_observation, samples.all_action, samples.all_reward), device=self.agent.device) # all have (wT + bT + nsr) x bB wT, bT, nsr = self.warmup_T, self.batch_T, self.n_step_return if wT > 0: warmup_slice = slice(None, wT) # Same for agent and target. warmup_inputs = AgentInputs( observation=all_observation[warmup_slice], prev_action=all_action[warmup_slice], prev_reward=all_reward[warmup_slice], ) agent_slice = slice(wT, wT + bT) agent_inputs = AgentInputs( observation=all_observation[agent_slice], prev_action=all_action[agent_slice], prev_reward=all_reward[agent_slice], ) target_slice = slice(wT, None) # Same start t as agent. (wT + bT + nsr) target_inputs = AgentInputs( observation=all_observation[target_slice], prev_action=all_action[target_slice], prev_reward=all_reward[target_slice], ) warmup_action = samples.all_action[1:wT + 1] action = samples.all_action[ wT + 1:wT + 1 + bT] # 'current' action by shifting index by 1 from prev_action return_ = samples.return_[wT:wT + bT] done_n = samples.done_n[wT:wT + bT] if self.store_rnn_state_interval == 0: init_rnn_state = None else: # [B,N,H]-->[N,B,H] cudnn. 
init_rnn_state = buffer_method(samples.init_rnn_state, "transpose", 0, 1) init_rnn_state = buffer_method(init_rnn_state, "contiguous") if wT > 0: # Do warmup. with torch.no_grad(): _, target_q1_rnn_state, _, target_q2_rnn_state = self.agent.target_q( *warmup_inputs, warmup_action, init_rnn_state, init_rnn_state) _, _, _, init_rnn_state = self.agent.pi( *warmup_inputs, init_rnn_state) # Recommend aligning sampling batch_T and store_rnn_interval with # warmup_T (and no mid_batch_reset), so that end of trajectory # during warmup leads to new trajectory beginning at start of # training segment of replay. warmup_invalid_mask = valid_from_done( samples.done[:wT])[-1] == 0 # [B] init_rnn_state[:, warmup_invalid_mask] = 0 # [N,B,H] (cudnn) target_q1_rnn_state[:, warmup_invalid_mask] = 0 target_q2_rnn_state[:, warmup_invalid_mask] = 0 else: target_q1_rnn_state = init_rnn_state target_q2_rnn_state = init_rnn_state valid = valid_from_done(samples.done)[-bT:] q1, _, q2, _ = self.agent.q(*agent_inputs, action, init_rnn_state, init_rnn_state) with torch.no_grad(): target_action, target_log_pi, _, _ = self.agent.pi( *target_inputs, init_rnn_state) target_q1, _, target_q2, _ = self.agent.target_q( *target_inputs, target_action, target_q1_rnn_state, target_q2_rnn_state) target_q1 = target_q1[-bT:] # Same length as q. target_q2 = target_q2[-bT:] target_log_pi = target_log_pi[-bT:] min_target_q = torch.min(target_q1, target_q2) target_value = min_target_q - self._alpha * target_log_pi disc = self.discount**self.n_step_return y = (self.reward_scale * return_ + (1 - done_n.float()) * disc * target_value) q1_loss = 0.5 * valid_mean((y - q1)**2, valid) q2_loss = 0.5 * valid_mean((y - q2)**2, valid) new_action, log_pi, (pi_mean, pi_log_std), _ = self.agent.pi( *agent_inputs, init_rnn_state) log_target1, _, log_target2, _ = self.agent.q(*agent_inputs, new_action, init_rnn_state, init_rnn_state) min_log_target = torch.min(log_target1, log_target2) prior_log_pi = self.get_action_prior(new_action.cpu()) pi_losses = self._alpha * log_pi - min_log_target - prior_log_pi pi_loss = valid_mean(pi_losses, valid) if self.target_entropy is not None and self.fixed_alpha is None: alpha_losses = -self._log_alpha * (log_pi.detach() + self.target_entropy) alpha_loss = valid_mean(alpha_losses, valid) else: alpha_loss = None losses = (q1_loss, q2_loss, pi_loss, alpha_loss) values = tuple(val.detach() for val in (q1, q2, pi_mean, pi_log_std)) return losses, values def get_action_prior(self, action): if self.action_prior == "uniform": prior_log_pi = 0.0 elif self.action_prior == "gaussian": prior_log_pi = self.action_prior_distribution.log_likelihood( action, GaussianDistInfo(mean=torch.zeros_like(action))) return prior_log_pi def append_opt_info_(self, opt_info, losses, grad_norms, values): """In-place.""" q1_loss, q2_loss, pi_loss, alpha_loss = losses q1_grad_norm, q2_grad_norm, pi_grad_norm = grad_norms q1, q2, pi_mean, pi_log_std = values opt_info.q1Loss.append(q1_loss.item()) opt_info.q2Loss.append(q2_loss.item()) opt_info.piLoss.append(pi_loss.item()) opt_info.q1GradNorm.append( q1_grad_norm.clone().detach().item()) # backwards compatible opt_info.q2GradNorm.append( q2_grad_norm.clone().detach().item()) # backwards compatible opt_info.piGradNorm.append( pi_grad_norm.clone().detach().item()) # backwards compatible opt_info.q1.extend(q1[::10].numpy()) # Downsample for stats. 
opt_info.q2.extend(q2[::10].numpy()) opt_info.piMu.extend(pi_mean[::10].numpy()) opt_info.piLogStd.extend(pi_log_std[::10].numpy()) opt_info.qMeanDiff.append(torch.mean(abs(q1 - q2)).item()) opt_info.alpha.append(self._alpha.item()) def optim_state_dict(self): return dict( pi_optimizer=self.pi_optimizer.state_dict(), q1_optimizer=self.q1_optimizer.state_dict(), q2_optimizer=self.q2_optimizer.state_dict(), alpha_optimizer=self.alpha_optimizer.state_dict() if self.alpha_optimizer else None, log_alpha=self._log_alpha.detach().item(), ) def load_optim_state_dict(self, state_dict): self.pi_optimizer.load_state_dict(state_dict["pi_optimizer"]) self.q1_optimizer.load_state_dict(state_dict["q1_optimizer"]) self.q2_optimizer.load_state_dict(state_dict["q2_optimizer"]) if self.alpha_optimizer is not None and state_dict[ "alpha_optimizer"] is not None: self.alpha_optimizer.load_state_dict(state_dict["alpha_optimizer"]) with torch.no_grad(): self._log_alpha[:] = state_dict["log_alpha"] self._alpha = torch.exp(self._log_alpha.detach()) def replay_buffer_dict(self): return dict(buffer=self.replay_buffer.samples)
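# Replay layout used by SAC_LSTM.loss() above: each replayed sequence holds
# warmup_T + batch_T + n_step_return time steps. The first warmup_T steps are
# passed through the networks under torch.no_grad() only to produce RNN states;
# losses are computed on the following batch_T steps, and the trailing
# n_step_return steps feed only the n-step bootstrap target. If an episode
# ends during warmup, the resulting RNN states are zeroed so that training
# never starts from hidden state carried across a reset.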
class SAC(RlAlgorithm): opt_info_fields = tuple(f for f in OptInfo._fields) # copy def __init__( self, discount=0.99, batch_size=256, min_steps_learn=int(1e4), replay_size=int(1e6), replay_ratio=256, # data_consumption / data_generation target_update_tau=0.005, # tau=1 for hard update. target_update_interval=1, # interval=1000 for hard update. learning_rate=3e-4, OptimCls=torch.optim.Adam, optim_kwargs=None, initial_optim_state_dict=None, # for pi only. action_prior="uniform", # or "gaussian" policy_output_regularization=0.001, reward_scale=1, reparameterize=True, clip_grad_norm=1e9, n_step_return=1, updates_per_sync=1, # For async mode only. target_entropy='auto', ): if optim_kwargs is None: optim_kwargs = dict() assert action_prior in ["uniform", "gaussian"] self._batch_size = batch_size del batch_size # Property. save__init__args(locals()) def initialize(self, agent, n_itr, batch_spec, mid_batch_reset, examples, world_size=1, rank=0): """Used in basic or synchronous multi-GPU runners, not async.""" self.agent = agent self.n_itr = n_itr self.mid_batch_reset = mid_batch_reset self.sampler_bs = sampler_bs = batch_spec.size self.updates_per_optimize = int(self.replay_ratio * sampler_bs / self.batch_size) logger.log( f"From sampler batch size {sampler_bs}, training " f"batch size {self.batch_size}, and replay ratio " f"{self.replay_ratio}, computed {self.updates_per_optimize} " f"updates per iteration.") self.min_itr_learn = self.min_steps_learn // sampler_bs agent.give_min_itr_learn(self.min_itr_learn) self.initialize_replay_buffer(examples, batch_spec) self.optim_initialize(rank) if self.target_entropy == 'auto': self.target_entropy = -np.prod(self.agent.env_spaces.action.shape) def async_initialize(self, agent, sampler_n_itr, batch_spec, mid_batch_reset, examples, world_size=1): """Used in async runner only.""" self.agent = agent self.n_itr = sampler_n_itr self.initialize_replay_buffer(examples, batch_spec, async_=True) self.mid_batch_reset = mid_batch_reset self.sampler_bs = sampler_bs = batch_spec.size self.updates_per_optimize = self.updates_per_sync self.min_itr_learn = int(self.min_steps_learn // sampler_bs) agent.give_min_itr_learn(self.min_itr_learn) return self.replay_buffer def optim_initialize(self, rank=0): """Called by async runner.""" self.rank = rank self.pi_optimizer = self.OptimCls(self.agent.pi_parameters(), lr=self.learning_rate, **self.optim_kwargs) self.q1_optimizer = self.OptimCls(self.agent.q1_parameters(), lr=self.learning_rate, **self.optim_kwargs) self.q2_optimizer = self.OptimCls(self.agent.q2_parameters(), lr=self.learning_rate, **self.optim_kwargs) self.alpha_optimizer = self.OptimCls([self.agent.log_alpha], lr=self.learning_rate, **self.optim_kwargs) if self.initial_optim_state_dict is not None: self.pi_optimizer.load_state_dict(self.initial_optim_state_dict) if self.action_prior == "gaussian": self.action_prior_distribution = Gaussian( dim=self.agent.env_spaces.action.size, std=1.) 
def initialize_replay_buffer(self, examples, batch_spec, async_=False): example_to_buffer = SamplesToBuffer( observation=examples["observation"], action=examples["action"], reward=examples["reward"], done=examples["done"], ) replay_kwargs = dict( example=example_to_buffer, size=self.replay_size, B=batch_spec.B, n_step_return=self.n_step_return, ) ReplayCls = AsyncUniformReplayBuffer if async_ else UniformReplayBuffer self.replay_buffer = ReplayCls(**replay_kwargs) def optimize_agent(self, itr, samples=None, sampler_itr=None): itr = itr if sampler_itr is None else sampler_itr # Async uses sampler_itr. if samples is not None: samples_to_buffer = self.samples_to_buffer(samples) self.replay_buffer.append_samples(samples_to_buffer) opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields)))) if itr < self.min_itr_learn: return opt_info for _ in range(self.updates_per_optimize): samples_from_replay = self.replay_buffer.sample_batch( self.batch_size) losses, values = self.loss(samples_from_replay) q1_loss, q2_loss, pi_loss, alpha_loss = losses self.pi_optimizer.zero_grad() pi_loss.backward() pi_grad_norm = torch.nn.utils.clip_grad_norm_( self.agent.pi_parameters(), self.clip_grad_norm) self.pi_optimizer.step() self.alpha_optimizer.zero_grad() alpha_loss.backward() self.alpha_optimizer.step() self.q1_optimizer.zero_grad() q1_loss.backward() q1_grad_norm = torch.nn.utils.clip_grad_norm_( self.agent.q1_parameters(), self.clip_grad_norm) self.q1_optimizer.step() self.q2_optimizer.zero_grad() q2_loss.backward() q2_grad_norm = torch.nn.utils.clip_grad_norm_( self.agent.q2_parameters(), self.clip_grad_norm) self.q2_optimizer.step() grad_norms = (q1_grad_norm, q2_grad_norm, pi_grad_norm) self.append_opt_info_(opt_info, losses, grad_norms, values) self.update_counter += 1 if self.update_counter % self.target_update_interval == 0: self.agent.update_target(self.target_update_tau) return opt_info def samples_to_buffer(self, samples): return SamplesToBuffer( observation=samples.env.observation, action=samples.agent.action, reward=samples.env.reward, done=samples.env.done, ) def loss(self, samples): """Samples have leading batch dimension [B,..] (but not time).""" agent_inputs, target_inputs, action = buffer_to( (samples.agent_inputs, samples.target_inputs, samples.action)) q1, q2 = self.agent.q(*agent_inputs, action) with torch.no_grad(): target_v = self.agent.target_v(*target_inputs).detach() disc = self.discount**self.n_step_return y = (self.reward_scale * samples.return_ + (1 - samples.done_n.float()) * disc * target_v) if self.mid_batch_reset and not self.agent.recurrent: valid = None # OR: torch.ones_like(samples.done, dtype=torch.float) else: valid = valid_from_done(samples.done) q1_loss = 0.5 * valid_mean((y - q1)**2, valid) q2_loss = 0.5 * valid_mean((y - q2)**2, valid) new_action, log_pi, _ = self.agent.pi(*agent_inputs) if not self.reparameterize: new_action = new_action.detach() # No grad. 
        log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
        min_log_target = torch.min(log_target1, log_target2)
        prior_log_pi = self.get_action_prior(new_action.cpu())
        if self.reparameterize:
            alpha = self.agent.log_alpha.exp().detach()
            pi_losses = alpha * log_pi - min_log_target - prior_log_pi
        else:
            raise NotImplementedError  # Non-reparameterized pi loss not supported here.
        pi_loss = valid_mean(pi_losses, valid)
        # Calculate log_alpha loss.
        alpha_loss = -valid_mean(self.agent.log_alpha *
            (log_pi + self.target_entropy).detach(), valid)
        losses = (q1_loss, q2_loss, pi_loss, alpha_loss)
        values = tuple(val.detach() for val in (q1, q2, alpha))
        return losses, values

    def get_action_prior(self, action):
        if self.action_prior == "uniform":
            prior_log_pi = 0.0
        elif self.action_prior == "gaussian":
            prior_log_pi = self.action_prior_distribution.log_likelihood(
                action, GaussianDistInfo(mean=torch.zeros_like(action)))
        return prior_log_pi

    def append_opt_info_(self, opt_info, losses, grad_norms, values):
        """In-place."""
        q1_loss, q2_loss, pi_loss, alpha_loss = losses
        q1_grad_norm, q2_grad_norm, pi_grad_norm = grad_norms
        q1, q2, alpha = values
        opt_info.q1Loss.append(q1_loss.item())
        opt_info.q2Loss.append(q2_loss.item())
        opt_info.piLoss.append(pi_loss.item())
        opt_info.alphaLoss.append(alpha_loss.item())
        opt_info.q1GradNorm.append(q1_grad_norm)
        opt_info.q2GradNorm.append(q2_grad_norm)
        opt_info.piGradNorm.append(pi_grad_norm)
        opt_info.q1.extend(q1[::10].numpy())  # Downsample for stats.
        opt_info.q2.extend(q2[::10].numpy())
        opt_info.alpha.append(alpha.numpy())
        opt_info.qMeanDiff.append(torch.mean(abs(q1 - q2)).item())

    def optim_state_dict(self):
        return dict(
            pi_optimizer=self.pi_optimizer.state_dict(),
            q1_optimizer=self.q1_optimizer.state_dict(),
            q2_optimizer=self.q2_optimizer.state_dict(),
            alpha_optimizer=self.alpha_optimizer.state_dict(),
        )

    def load_optim_state_dict(self, state_dict):
        self.pi_optimizer.load_state_dict(state_dict["pi_optimizer"])
        self.q1_optimizer.load_state_dict(state_dict["q1_optimizer"])
        self.q2_optimizer.load_state_dict(state_dict["q2_optimizer"])
        self.alpha_optimizer.load_state_dict(state_dict["alpha_optimizer"])
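# Entropy-temperature tuning in the SAC class above: with target_entropy="auto"
# the target is -prod(action.shape) (e.g. -6 for a 6-dimensional action), and
# alpha is adapted through
#   alpha_loss = -mean[log_alpha * (log pi(a~|s) + target_entropy)]
# so alpha grows when the policy entropy falls below the target and shrinks
# when it is above, following Haarnoja et al. (2018), "Soft Actor-Critic
# Algorithms and Applications".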
class SAC(RlAlgorithm): opt_info_fields = tuple(f for f in OptInfo._fields) # copy def __init__( self, discount=0.99, batch_size=256, min_steps_learn=int( 1e4 ), # the min timesteps to collect before actually start learning. replay_size=int(1e6), replay_ratio=256, # data_consumption (one timestep with one optim.step() called) / data_generation (batch.size) target_update_tau=0.005, # tau=1 for hard update. target_update_interval=1, # 1000 for hard update, 1 for soft. learning_rate=3e-4, fixed_alpha=None, # None for adaptive alpha, float for any fixed value OptimCls=torch.optim.Adam, optim_kwargs=None, initial_optim_state_dict=None, # for all of them. action_prior="uniform", # or "gaussian" reward_scale=1, target_entropy="auto", # "auto", float, or None reparameterize=True, clip_grad_norm=1e9, # policy_output_regularization=0.001, n_step_return=1, updates_per_sync=1, # For async mode only. bootstrap_timelimit=True, ReplayBufferCls=None, # Leave None to select by above options. ): if optim_kwargs is None: optim_kwargs = dict() assert action_prior in ["uniform", "gaussian"] self._batch_size = batch_size del batch_size # Property. save__init__args(locals()) def initialize(self, agent, n_itr, batch_spec, mid_batch_reset, examples, world_size=1, rank=0): """Used in basic or synchronous multi-GPU runners, not async. Parameters ---------- agent: SacAgent """ self.agent = agent self.n_itr = n_itr self.mid_batch_reset = mid_batch_reset self.sampler_bs = sampler_bs = batch_spec.size self.updates_per_optimize = int(self.replay_ratio * sampler_bs / self.batch_size) logger.log( f"From sampler batch size {sampler_bs}, training " f"batch size {self.batch_size}, and replay ratio " f"{self.replay_ratio}, computed {self.updates_per_optimize} " f"updates per iteration.") self.min_itr_learn = self.min_steps_learn // sampler_bs agent.give_min_itr_learn(self.min_itr_learn) self.initialize_replay_buffer(examples, batch_spec) self.optim_initialize(rank) def async_initialize(self, agent, sampler_n_itr, batch_spec, mid_batch_reset, examples, world_size=1): """Used in async runner only.""" self.agent = agent self.n_itr = sampler_n_itr self.initialize_replay_buffer(examples, batch_spec, async_=True) self.mid_batch_reset = mid_batch_reset self.sampler_bs = sampler_bs = batch_spec.size self.updates_per_optimize = self.updates_per_sync self.min_itr_learn = int(self.min_steps_learn // sampler_bs) agent.give_min_itr_learn(self.min_itr_learn) return self.replay_buffer def optim_initialize(self, rank=0): """Called by async runner.""" self.rank = rank self.pi_optimizer = self.OptimCls(self.agent.pi_parameters(), lr=self.learning_rate, **self.optim_kwargs) self.q1_optimizer = self.OptimCls(self.agent.q1_parameters(), lr=self.learning_rate, **self.optim_kwargs) self.q2_optimizer = self.OptimCls(self.agent.q2_parameters(), lr=self.learning_rate, **self.optim_kwargs) if self.fixed_alpha is None: self._log_alpha = torch.zeros(1, requires_grad=True) self._alpha = torch.exp(self._log_alpha.detach()) self.alpha_optimizer = self.OptimCls((self._log_alpha, ), lr=self.learning_rate, **self.optim_kwargs) else: self._log_alpha = torch.tensor([np.log(self.fixed_alpha)]) self._alpha = torch.tensor([self.fixed_alpha]) self.alpha_optimizer = None if self.target_entropy == "auto": self.target_entropy = -np.prod(self.agent.env_spaces.action.shape) if self.initial_optim_state_dict is not None: self.load_optim_state_dict(self.initial_optim_state_dict) if self.action_prior == "gaussian": self.action_prior_distribution = Gaussian(dim=np.prod( 
self.agent.env_spaces.action.shape), std=1.) def initialize_replay_buffer(self, examples, batch_spec, async_=False): example_to_buffer = SamplesToBuffer( observation=examples["observation"], action=examples["action"], reward=examples["reward"], done=examples["done"], next_observation=examples["next_observation"], ) if not self.bootstrap_timelimit: ReplayCls = AsyncUniformReplayBuffer if async_ else UniformReplayBuffer else: example_to_buffer = SamplesToBufferTl( *example_to_buffer, timeout=examples["env_info"].timeout) ReplayCls = AsyncTlUniformReplayBuffer if async_ else TlUniformReplayBuffer replay_kwargs = dict( example=example_to_buffer, size=self.replay_size, B=batch_spec.B, n_step_return=self.n_step_return, ) if self.ReplayBufferCls is not None: ReplayCls = self.ReplayBufferCls logger.log( f"WARNING: ignoring internal selection logic and using" f" input replay buffer class: {ReplayCls} -- compatibility not" " guaranteed.") self.replay_buffer = ReplayCls(**replay_kwargs) def optimize_agent(self, itr, samples=None, sampler_itr=None): itr = itr if sampler_itr is None else sampler_itr # Async uses sampler_itr. if samples is not None: samples_to_buffer = self.samples_to_buffer(samples) self.replay_buffer.append_samples(samples_to_buffer) opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields)))) if itr < self.min_itr_learn: return opt_info for _ in range(self.updates_per_optimize): samples_from_replay = self.replay_buffer.sample_batch( self.batch_size) losses, values = self.loss(samples_from_replay) q1_loss, q2_loss, pi_loss, alpha_loss = losses if alpha_loss is not None: self.alpha_optimizer.zero_grad() alpha_loss.backward() self.alpha_optimizer.step() self._alpha = torch.exp(self._log_alpha.detach()) self.pi_optimizer.zero_grad() pi_loss.backward() pi_grad_norm = torch.nn.utils.clip_grad_norm_( self.agent.pi_parameters(), self.clip_grad_norm) self.pi_optimizer.step() # Step Q's last because pi_loss.backward() uses them? self.q1_optimizer.zero_grad() q1_loss.backward() q1_grad_norm = torch.nn.utils.clip_grad_norm_( self.agent.q1_parameters(), self.clip_grad_norm) self.q1_optimizer.step() self.q2_optimizer.zero_grad() q2_loss.backward() q2_grad_norm = torch.nn.utils.clip_grad_norm_( self.agent.q2_parameters(), self.clip_grad_norm) self.q2_optimizer.step() grad_norms = (q1_grad_norm, q2_grad_norm, pi_grad_norm) self.append_opt_info_(opt_info, losses, grad_norms, values) self.update_counter += 1 if self.update_counter % self.target_update_interval == 0: self.agent.update_target(self.target_update_tau) return opt_info def samples_to_buffer(self, samples): samples_to_buffer = SamplesToBuffer( observation=samples.env.observation, action=samples.agent.action, reward=samples.env.reward, done=samples.env.done, next_observation=samples.env.next_observation, ) if self.bootstrap_timelimit: samples_to_buffer = SamplesToBufferTl( *samples_to_buffer, timeout=samples.env.env_info.timeout) return samples_to_buffer def loss(self, samples): """Samples have leading batch dimension [B,..] (but not time).""" agent_inputs, target_inputs, action = buffer_to( (samples.agent_inputs, samples.target_inputs, samples.action)) if self.mid_batch_reset and not self.agent.recurrent: valid = torch.ones_like(samples.done, dtype=torch.float) # or None else: valid = valid_from_done(samples.done) if self.bootstrap_timelimit: # To avoid non-use of bootstrap when environment is 'done' due to # time-limit, turn off training on these samples. 
valid *= (1 - samples.timeout_n.float()) q1, q2 = self.agent.q(*agent_inputs, action) with torch.no_grad(): target_action, target_log_pi, _ = self.agent.pi(*target_inputs) target_q1, target_q2 = self.agent.target_q(*target_inputs, target_action) min_target_q = torch.min(target_q1, target_q2) target_value = min_target_q - self._alpha * target_log_pi disc = self.discount**self.n_step_return y = (self.reward_scale * samples.return_ + (1 - samples.done_n.float()) * disc * target_value) # y: target for Q functions, target_value q1_loss = 0.5 * valid_mean((y - q1)**2, valid) q2_loss = 0.5 * valid_mean((y - q2)**2, valid) new_action, log_pi, (pi_mean, pi_log_std) = self.agent.pi(*agent_inputs) if not self.reparameterize: new_action = new_action.detach() # No grad. log_target1, log_target2 = self.agent.q(*agent_inputs, new_action) min_log_target = torch.min(log_target1, log_target2) prior_log_pi = self.get_action_prior(new_action.cpu()) if self.reparameterize: pi_losses = self._alpha * log_pi - min_log_target - prior_log_pi else: raise NotImplementedError # if self.policy_output_regularization > 0: # pi_losses += self.policy_output_regularization * torch.mean( # 0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1) pi_loss = valid_mean(pi_losses, valid) if self.target_entropy is not None and self.fixed_alpha is None: alpha_losses = -self._log_alpha * (log_pi.detach() + self.target_entropy) alpha_loss = valid_mean(alpha_losses, valid) else: alpha_loss = None losses = (q1_loss, q2_loss, pi_loss, alpha_loss) values = tuple(val.detach() for val in (q1, q2, pi_mean, pi_log_std)) return losses, values def get_action_prior(self, action): if self.action_prior == "uniform": prior_log_pi = 0.0 elif self.action_prior == "gaussian": prior_log_pi = self.action_prior_distribution.log_likelihood( action, GaussianDistInfo(mean=torch.zeros_like(action))) return prior_log_pi def append_opt_info_(self, opt_info, losses, grad_norms, values): """ append all the `losses` and `grad_norms` and `values` into each attribute of `opt_info` """ q1_loss, q2_loss, pi_loss, alpha_loss = losses q1_grad_norm, q2_grad_norm, pi_grad_norm = grad_norms q1, q2, pi_mean, pi_log_std = values opt_info.q1Loss.append(q1_loss.item()) opt_info.q2Loss.append(q2_loss.item()) opt_info.piLoss.append(pi_loss.item()) opt_info.q1GradNorm.append( torch.tensor(q1_grad_norm).item()) # backwards compatible opt_info.q2GradNorm.append( torch.tensor(q2_grad_norm).item()) # backwards compatible opt_info.piGradNorm.append( torch.tensor(pi_grad_norm).item()) # backwards compatible opt_info.q1.extend(q1[::10].numpy()) # Downsample for stats. 
        opt_info.q2.extend(q2[::10].numpy())
        opt_info.piMu.extend(pi_mean[::10].numpy())
        opt_info.piLogStd.extend(pi_log_std[::10].numpy())
        opt_info.qMeanDiff.append(torch.mean(abs(q1 - q2)).item())
        opt_info.alpha.append(self._alpha.item())

    def optim_state_dict(self):
        return dict(
            pi_optimizer=self.pi_optimizer.state_dict(),
            q1_optimizer=self.q1_optimizer.state_dict(),
            q2_optimizer=self.q2_optimizer.state_dict(),
            alpha_optimizer=self.alpha_optimizer.state_dict()
                if self.alpha_optimizer else None,
            log_alpha=self._log_alpha.detach().item(),
        )

    def load_optim_state_dict(self, state_dict):
        self.pi_optimizer.load_state_dict(state_dict["pi_optimizer"])
        self.q1_optimizer.load_state_dict(state_dict["q1_optimizer"])
        self.q2_optimizer.load_state_dict(state_dict["q2_optimizer"])
        if (self.alpha_optimizer is not None
                and state_dict["alpha_optimizer"] is not None):
            self.alpha_optimizer.load_state_dict(state_dict["alpha_optimizer"])
        with torch.no_grad():
            self._log_alpha[:] = state_dict["log_alpha"]
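

# --------------------------------------------------------------------------- #
# Worked example (illustrative, dummy numbers) of the adaptive-temperature
# update used in SAC.loss() / SAC.optimize_agent() above:
#     alpha_loss = -log_alpha * (log_pi.detach() + target_entropy)
# Gradient descent on log_alpha raises alpha when the policy's entropy
# (~ -log_pi) falls below target_entropy and lowers it otherwise.  Plain
# .mean() stands in for valid_mean(); the batch of log_pi values is made up.
# --------------------------------------------------------------------------- #
def _alpha_update_example():
    target_entropy = -6.0  # e.g. target_entropy="auto" with a 6-dim action space.
    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optimizer = torch.optim.Adam((log_alpha,), lr=3e-4)
    log_pi = torch.full((256,), -2.0)  # Policy entropy ~2, above the target of -6.
    alpha_loss = (-log_alpha * (log_pi + target_entropy)).mean()
    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()
    # log_alpha moved down: entropy is above target, so shrink the entropy bonus.
    return torch.exp(log_alpha.detach())
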
class SacWithUl(RlAlgorithm): opt_info_fields = tuple(f for f in OptInfo._fields) def __init__( self, discount=0.99, batch_size=512, # replay_ratio=512, # data_consumption / data_generation # min_steps_learn=int(1e4), replay_size=int(1e5), target_update_tau=0.01, # tau=1 for hard update. target_update_interval=2, actor_update_interval=2, OptimCls=torch.optim.Adam, initial_optim_state_dict=None, # for all of them. action_prior="uniform", # or "gaussian" reward_scale=1, target_entropy="auto", # "auto", float, or None reparameterize=True, clip_grad_norm=1e6, n_step_return=1, bootstrap_timelimit=True, q_lr=1e-3, pi_lr=1e-3, alpha_lr=1e-4, q_beta=0.9, pi_beta=0.9, alpha_beta=0.5, alpha_init=0.1, encoder_update_tau=0.05, random_shift_prob=1.0, random_shift_pad=4, # how much to pad on each direction (like DrQ style) stop_rl_conv_grad=False, min_steps_rl=int(1e4), min_steps_ul=int(1e4), max_steps_ul=None, ul_learning_rate=7e-4, ul_optim_kwargs=None, # ul_replay_size=1e5, ul_update_schedule=None, ul_lr_schedule=None, ul_lr_warmup=0, # ul_delta_T=1, # Always 1 # ul_batch_B=512, # ul_batch_T=1, # Always 1 ul_batch_size=512, ul_random_shift_prob=1.0, ul_random_shift_pad=4, ul_target_update_interval=1, ul_target_update_tau=0.01, ul_latent_size=128, ul_anchor_hidden_sizes=512, ul_clip_grad_norm=10.0, ul_pri_alpha=0.0, ul_pri_beta=1.0, ul_pri_n_step_return=1, ul_use_rl_samples=False, UlEncoderCls=UlEncoderModel, UlContrastCls=ContrastModel, ): # assert replay_ratio == batch_size # Unless I want to change it. self._batch_size = batch_size del batch_size if ul_optim_kwargs is None: ul_optim_kwargs = dict() save__init__args(locals()) self.replay_ratio = self.batch_size # standard 1 update per itr. # assert ul_delta_T == n_step_return # Just use the same replay buffer # assert ul_batch_T == 1 # This was fine in DMControl in RlFromUl def initialize(self, agent, n_itr, batch_spec, mid_batch_reset, examples, world_size=1, rank=0): """Stores input arguments and initializes replay buffer and optimizer. Use in non-async runners. 
Computes number of gradient updates per optimization iteration as `(replay_ratio * sampler-batch-size / training-batch_size)`.""" self.agent = agent self.n_itr = n_itr self.mid_batch_reset = mid_batch_reset self.sampler_bs = sampler_bs = batch_spec.size self.updates_per_optimize = int(self.replay_ratio * sampler_bs / self.batch_size) logger.log( f"From sampler batch size {sampler_bs}, training " f"batch size {self.batch_size}, and replay ratio " f"{self.replay_ratio}, computed {self.updates_per_optimize} " f"updates per iteration.") self.min_itr_rl = self.min_steps_rl // sampler_bs self.min_itr_ul = self.min_steps_ul // sampler_bs self.max_itr_ul = (self.n_itr + 1 if self.max_steps_ul is None else self.max_steps_ul // sampler_bs) if self.min_itr_rl == self.min_itr_ul: self.min_itr_rl += 1 # Wait until the next agent.give_min_itr_learn(self.min_itr_rl) self.initialize_replay_buffer(examples, batch_spec) self.ul_encoder = self.UlEncoderCls( conv=self.agent.conv, latent_size=self.ul_latent_size, conv_out_size=self.agent.conv.output_size, ) self.ul_target_encoder = copy.deepcopy(self.ul_encoder) self.ul_contrast = self.UlContrastCls( latent_size=self.ul_latent_size, anchor_hidden_sizes=self.ul_anchor_hidden_sizes, ) self.ul_encoder.to(self.agent.device) self.ul_target_encoder.to(self.agent.device) self.ul_contrast.to(self.agent.device) self.optim_initialize(rank) def async_initialize(*args, **kwargs): raise NotImplementedError def optim_initialize(self, rank=0): """Called in initilize or by async runner after forking sampler.""" self.rank = rank # Be very explicit about which parameters are optimized where. self.pi_optimizer = self.OptimCls( chain( self.agent.pi_fc1.parameters(), # No conv. self.agent.pi_mlp.parameters(), ), lr=self.pi_lr, betas=(self.pi_beta, 0.999), ) self.q_optimizer = self.OptimCls( chain( () if self.stop_rl_conv_grad else self.agent.conv.parameters(), self.agent.q_fc1.parameters(), self.agent.q_mlps.parameters(), ), lr=self.q_lr, betas=(self.q_beta, 0.999), ) self._log_alpha = torch.tensor(np.log(self.alpha_init), requires_grad=True) self._alpha = torch.exp(self._log_alpha.detach()) self.alpha_optimizer = self.OptimCls((self._log_alpha, ), lr=self.alpha_lr, betas=(self.alpha_beta, 0.999)) if self.target_entropy == "auto": self.target_entropy = -np.prod(self.agent.env_spaces.action.shape) if self.initial_optim_state_dict is not None: self.load_optim_state_dict(self.initial_optim_state_dict) if self.action_prior == "gaussian": self.action_prior_distribution = Gaussian(dim=np.prod( self.agent.env_spaces.action.shape), std=1.0) self.ul_optimizer = self.OptimCls(self.ul_parameters(), lr=self.ul_learning_rate, **self.ul_optim_kwargs) self.total_ul_updates = sum([ self.compute_ul_update_schedule(itr) for itr in range(self.n_itr) ]) logger.log( f"Total number of UL updates to do: {self.total_ul_updates}.") self.ul_update_counter = 0 self.ul_lr_scheduler = None if self.total_ul_updates > 0: if self.ul_lr_schedule == "linear": self.ul_lr_scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer=self.ul_optimizer, lr_lambda=lambda upd: (self.total_ul_updates - upd) / self.total_ul_updates, ) elif self.ul_lr_schedule == "cosine": self.ul_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer=self.ul_optimizer, T_max=self.total_ul_updates - self.ul_lr_warmup, ) elif self.ul_lr_schedule is not None: raise NotImplementedError if self.ul_lr_warmup > 0: self.ul_lr_scheduler = GradualWarmupScheduler( self.ul_optimizer, multiplier=1, total_epoch=self.ul_lr_warmup, # actually 
n_updates after_scheduler=self.ul_lr_scheduler, ) if self.ul_lr_scheduler is not None: self.ul_optimizer.zero_grad() self.ul_optimizer.step() self.c_e_loss = torch.nn.CrossEntropyLoss( ignore_index=IGNORE_INDEX) def initialize_replay_buffer(self, examples, batch_spec, async_=False): """ Allocates replay buffer using examples and with the fields in `SamplesToBuffer` namedarraytuple. POSSIBLY CHANGE TO FRAME-BASED BUFFER (only if need memory, speed is fine). """ if async_: raise NotImplementedError example_to_buffer = self.examples_to_buffer(examples) ReplayCls = (TlUniformReplayBuffer if self.bootstrap_timelimit else UniformReplayBuffer) replay_kwargs = dict( example=example_to_buffer, size=self.replay_size, B=batch_spec.B, n_step_return=self.n_step_return, ) self.replay_buffer = ReplayCls(**replay_kwargs) if self.ul_pri_alpha > 0.0: self.replay_buffer = RlWithUlPrioritizedReplayWrapper( replay_buffer=self.replay_buffer, n_step_return=self.ul_pri_n_step_return, alpha=self.ul_pri_alpha, beta=self.ul_pri_beta, ) def optimize_agent(self, itr, samples): """ Extracts the needed fields from input samples and stores them in the replay buffer. Then samples from the replay buffer to train the agent by gradient updates (with the number of updates determined by replay ratio, sampler batch size, and training batch size). DIFFERENCES FROM SAC: -Organizes optimizers a little differently, clarifies which parameters. """ samples_to_buffer = self.samples_to_buffer(samples) self.replay_buffer.append_samples(samples_to_buffer) opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields)))) rl_samples = None if itr >= self.min_itr_rl: opt_info_rl, rl_samples = self.rl_optimize(itr) opt_info = opt_info._replace(**opt_info_rl._asdict()) if itr >= self.min_itr_ul: opt_info_ul = self.ul_optimize(itr, rl_samples) opt_info = opt_info._replace(**opt_info_ul._asdict()) else: opt_info.ulUpdates.append(0) return opt_info def rl_optimize(self, itr): opt_info_rl = OptInfoRl(*([] for _ in range(len(OptInfoRl._fields)))) for _ in range(self.updates_per_optimize): # Sample from the replay buffer, center crop, and move to GPU. samples_from_replay = self.replay_buffer.sample_batch( self.batch_size) rl_samples = self.random_shift_rl_samples(samples_from_replay) rl_samples = self.samples_to_device(rl_samples) # Q-loss includes computing some values used in pi-loss. q1_loss, q2_loss, valid, conv_out, q1, q2 = self.q_loss(rl_samples) if self.update_counter % self.actor_update_interval == 0: pi_loss, alpha_loss, pi_mean, pi_log_std = self.pi_alpha_loss( rl_samples, valid, conv_out) if alpha_loss is not None: self.alpha_optimizer.zero_grad() alpha_loss.backward() self.alpha_optimizer.step() self._alpha = torch.exp(self._log_alpha.detach()) opt_info_rl.alpha.append(self._alpha.item()) self.pi_optimizer.zero_grad() pi_loss.backward() pi_grad_norm = torch.nn.utils.clip_grad_norm_( chain( self.agent.pi_fc1.parameters(), self.agent.pi_mlp.parameters(), ), self.clip_grad_norm, ) self.pi_optimizer.step() opt_info_rl.piLoss.append(pi_loss.item()) opt_info_rl.piGradNorm.append(pi_grad_norm.item()) opt_info_rl.piMu.extend(pi_mean[::10].numpy()) opt_info_rl.piLogStd.extend(pi_log_std[::10].numpy()) # Step Q's last because pi_loss.backward() uses them. 
self.q_optimizer.zero_grad() q_loss = q1_loss + q2_loss q_loss.backward() q_grad_norm = torch.nn.utils.clip_grad_norm_( chain( () if self.stop_rl_conv_grad else self.agent.conv.parameters(), self.agent.q_fc1.parameters(), self.agent.q_mlps.parameters(), ), self.clip_grad_norm, ) self.q_optimizer.step() opt_info_rl.q1Loss.append(q1_loss.item()) opt_info_rl.q2Loss.append(q2_loss.item()) opt_info_rl.qGradNorm.append(q_grad_norm.item()) opt_info_rl.q1.extend(q1[::10].numpy()) # Downsample for stats. opt_info_rl.q2.extend(q2[::10].numpy()) opt_info_rl.qMeanDiff.append(torch.mean(abs(q1 - q2)).item()) self.update_counter += 1 if self.update_counter % self.target_update_interval == 0: self.agent.update_targets( q_tau=self.target_update_tau, encoder_tau=self.encoder_update_tau, ) return opt_info_rl, rl_samples def ul_optimize(self, itr, rl_samples=None): opt_info_ul = OptInfoUl(*([] for _ in range(len(OptInfoUl._fields)))) n_ul_updates = self.compute_ul_update_schedule(itr) ul_bs = self.ul_batch_size n_rl_samples = (0 if rl_samples is None else len( rl_samples.agent_inputs.observation)) for i in range(n_ul_updates): self.ul_update_counter += 1 if self.ul_lr_scheduler is not None: self.ul_lr_scheduler.step(self.ul_update_counter) if n_rl_samples >= self.ul_batch_size * (i + 1): ul_samples = rl_samples[i * ul_bs:(i + 1) * ul_bs] else: ul_samples = None ul_loss, ul_accuracy, grad_norm = self.ul_optimize_one_step( ul_samples) opt_info_ul.ulLoss.append(ul_loss.item()) opt_info_ul.ulAccuracy.append(ul_accuracy.item()) opt_info_ul.ulGradNorm.append(grad_norm.item()) if self.ul_update_counter % self.ul_target_update_interval == 0: update_state_dict( self.ul_target_encoder, self.ul_encoder.state_dict(), self.ul_target_update_tau, ) opt_info_ul.ulUpdates.append(self.ul_update_counter) return opt_info_ul def ul_optimize_one_step(self, samples=None): self.ul_optimizer.zero_grad() if samples is None: if self.ul_pri_alpha > 0: samples = self.replay_buffer.sample_batch(self.ul_batch_size, mode="UL") else: samples = self.replay_buffer.sample_batch(self.ul_batch_size) # This is why need ul_delta_T == n_step_return, usually == 1; anchor = samples.agent_inputs.observation positive = samples.target_inputs.observation if self.ul_random_shift_prob > 0.0: anchor = random_shift( imgs=anchor, pad=self.ul_random_shift_pad, prob=self.ul_random_shift_prob, ) positive = random_shift( imgs=positive, pad=self.ul_random_shift_pad, prob=self.ul_random_shift_prob, ) anchor, positive = buffer_to((anchor, positive), device=self.agent.device) else: # Assume samples were already augmented in the RL loss. anchor = samples.agent_inputs.observation positive = samples.target_inputs.observation with torch.no_grad(): c_positive, _pos_conv = self.ul_target_encoder(positive) c_anchor, _anc_conv = self.ul_encoder(anchor) logits = self.ul_contrast(c_anchor, c_positive) # anchor mlp in here. labels = torch.arange(c_anchor.shape[0], dtype=torch.long, device=self.agent.device) invalid = samples.done # shape: [B], if done, following state invalid labels[invalid] = IGNORE_INDEX ul_loss = self.c_e_loss(logits, labels) ul_loss.backward() if self.ul_clip_grad_norm is None: grad_norm = 0.0 else: grad_norm = torch.nn.utils.clip_grad_norm_(self.ul_parameters(), self.ul_clip_grad_norm) self.ul_optimizer.step() correct = torch.argmax(logits.detach(), dim=1) == labels accuracy = torch.mean(correct[~invalid].float()) return ul_loss, accuracy, grad_norm def samples_to_buffer(self, samples): """Defines how to add data from sampler into the replay buffer. 
Called in optimize_agent() if samples are provided to that method.""" observation = samples.env.observation samples_to_buffer = SamplesToBuffer( observation=observation, action=samples.agent.action, reward=samples.env.reward, done=samples.env.done, ) if self.bootstrap_timelimit: samples_to_buffer = SamplesToBufferTl( *samples_to_buffer, timeout=samples.env.env_info.timeout) return samples_to_buffer def examples_to_buffer(self, examples): observation = examples["observation"] example_to_buffer = SamplesToBuffer( observation=observation, action=examples["action"], reward=examples["reward"], done=examples["done"], ) if self.bootstrap_timelimit: example_to_buffer = SamplesToBufferTl( *example_to_buffer, timeout=examples["env_info"].timeout) return example_to_buffer def samples_to_device(self, samples): """Only move the parts of samples which need to go to GPU.""" agent_inputs, target_inputs, action = buffer_to( (samples.agent_inputs, samples.target_inputs, samples.action), device=self.agent.device, ) device_samples = samples._replace( agent_inputs=agent_inputs, target_inputs=target_inputs, action=action, ) return device_samples def random_shift_rl_samples(self, samples): if self.random_shift_prob == 0.0: return samples obs = samples.agent_inputs.observation target_obs = samples.target_inputs.observation aug_obs = random_shift( imgs=obs, pad=self.random_shift_pad, prob=self.random_shift_prob, ) aug_target_obs = random_shift( imgs=target_obs, pad=self.random_shift_pad, prob=self.random_shift_prob, ) aug_samples = samples._replace( agent_inputs=samples.agent_inputs._replace(observation=aug_obs), target_inputs=samples.target_inputs._replace( observation=aug_target_obs), ) return aug_samples def q_loss(self, samples): if self.mid_batch_reset and not self.agent.recurrent: valid = torch.ones_like(samples.done, dtype=torch.float) # or None else: valid = valid_from_done(samples.done) if self.bootstrap_timelimit: # To avoid non-use of bootstrap when environment is 'done' due to # time-limit, turn off training on these samples. valid *= 1 - samples.timeout_n.float() # Run the convolution only once, return so pi_loss can use it. conv_out = self.agent.conv(samples.agent_inputs.observation) if self.stop_rl_conv_grad: conv_out = conv_out.detach() q_inputs = samples.agent_inputs._replace(observation=conv_out) # Q LOSS. q1, q2 = self.agent.q(*q_inputs, samples.action) with torch.no_grad(): # Run the target convolution only once. target_conv_out = self.agent.target_conv( samples.target_inputs.observation) target_inputs = samples.target_inputs._replace( observation=target_conv_out) target_action, target_log_pi, _ = self.agent.pi(*target_inputs) target_q1, target_q2 = self.agent.target_q(*target_inputs, target_action) min_target_q = torch.min(target_q1, target_q2) target_value = min_target_q - self._alpha * target_log_pi disc = self.discount**self.n_step_return y = (self.reward_scale * samples.return_ + (1 - samples.done_n.float()) * disc * target_value) q1_loss = 0.5 * valid_mean((y - q1)**2, valid) q2_loss = 0.5 * valid_mean((y - q2)**2, valid) return q1_loss, q2_loss, valid, conv_out, q1.detach(), q2.detach() def pi_alpha_loss(self, samples, valid, conv_out): # PI LOSS. # Uses detached conv out; avoid re-computing. conv_detach = conv_out.detach() agent_inputs = samples.agent_inputs._replace(observation=conv_detach) new_action, log_pi, (pi_mean, pi_log_std) = self.agent.pi(*agent_inputs) if not self.reparameterize: # new_action = new_action.detach() # No grad. 
raise NotImplementedError # Re-use the detached latent. log_target1, log_target2 = self.agent.q(*agent_inputs, new_action) min_log_target = torch.min(log_target1, log_target2) prior_log_pi = self.get_action_prior(new_action.cpu()) if self.reparameterize: pi_losses = self._alpha * log_pi - min_log_target - prior_log_pi else: raise NotImplementedError # if self.policy_output_regularization > 0: # pi_losses += self.policy_output_regularization * torch.mean( # 0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1) pi_loss = valid_mean(pi_losses, valid) # ALPHA LOSS. if self.target_entropy is not None: alpha_losses = -self._log_alpha * (log_pi.detach() + self.target_entropy) alpha_loss = valid_mean(alpha_losses, valid) else: alpha_loss = None return pi_loss, alpha_loss, pi_mean.detach(), pi_log_std.detach() def get_action_prior(self, action): if self.action_prior == "uniform": prior_log_pi = 0.0 elif self.action_prior == "gaussian": prior_log_pi = self.action_prior_distribution.log_likelihood( action, GaussianDistInfo(mean=torch.zeros_like(action))) return prior_log_pi def optim_state_dict(self): return dict( pi=self.pi_optimizer.state_dict(), q=self.q_optimizer.state_dict(), alpha=self.alpha_optimizer.state_dict(), log_alpha_value=self._log_alpha.detach().item(), ul=self.ul_optimizer.state_dict(), ) def load_optim_state_dict(self, state_dict): self.pi_optimizer.load_state_dict(state_dict["pi"]) self.q_optimizer.load_state_dict(state_dict["q"]) self.alpha_optimizer.load_state_dict(state_dict["alpha"]) self.ul_optimizer.load_state_dict(state_dict["ul"]) with torch.no_grad(): self._log_alpha[:] = state_dict["log_alpha_value"] self._alpha = torch.exp(self._log_alpha.detach()) def ul_parameters(self): yield from self.ul_encoder.parameters() yield from self.ul_contrast.parameters() def ul_named_parameters(self): yield from self.ul_encoder.named_parameters() yield from self.ul_contrast.named_parameters() def compute_ul_update_schedule(self, itr): if itr < self.min_itr_ul or itr > self.max_itr_ul: return 0 remaining = (self.max_itr_ul - itr) / ( self.max_itr_ul - self.min_itr_ul) # from 1 to 0 if "constant" in self.ul_update_schedule: # Format: "constant_X", for X num updates per RL itr. n_ul_updates = int(self.ul_update_schedule.split("_")[1]) elif "front" in self.ul_update_schedule: # Format: "front_X_Y", for X updates first itr, Y updates rest. entries = self.ul_update_schedule.split("_") if itr == self.min_itr_ul: n_ul_updates = int(entries[1]) else: n_ul_updates = int(entries[2]) elif "linear" in self.ul_update_schedule: first = int(self.ul_update_schedule.split("_")[1]) n_ul_updates = int(np.round(first * remaining)) elif "quadratic" in self.ul_update_schedule: first = int(self.ul_update_schedule.split("_")[1]) n_ul_updates = int(np.round(first * remaining**2)) elif "cosine" in self.ul_update_schedule: first = int(self.ul_update_schedule.split("_")[1]) n_ul_updates = int( np.round(first * math.sin(math.pi / 2 * remaining))) return n_ul_updates
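

# --------------------------------------------------------------------------- #
# Illustrative sketch of the contrastive objective in
# SacWithUl.ul_optimize_one_step() above: each anchor latent should score
# highest against its own positive (the diagonal of a [B, B] logits matrix),
# and transitions whose `done` flag is set are dropped from the loss via
# ignore_index.  A plain dot-product stands in for the ContrastModel scoring
# head, and the batch/latent sizes are made up for the example.
# --------------------------------------------------------------------------- #
def _ul_contrast_example():
    B, latent_size = 32, 128
    c_anchor = torch.randn(B, latent_size)    # Encoded (augmented) observations.
    c_positive = torch.randn(B, latent_size)  # Encoded paired observations (target encoder).
    logits = torch.matmul(c_anchor, c_positive.t())  # [B, B] similarity scores.
    labels = torch.arange(B, dtype=torch.long)       # Correct match is the diagonal.
    done = torch.zeros(B, dtype=torch.bool)
    done[::8] = True                                  # Pretend some transitions ended an episode.
    labels[done] = IGNORE_INDEX                       # Mask invalid pairs out of the loss.
    ul_loss = torch.nn.functional.cross_entropy(
        logits, labels, ignore_index=IGNORE_INDEX)
    accuracy = (torch.argmax(logits, dim=1) == labels)[~done].float().mean()
    return ul_loss, accuracy
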
class RadSacFromUl(RlAlgorithm): opt_info_fields = tuple(f for f in OptInfo._fields) def __init__( self, discount=0.99, batch_size=512, # replay_ratio=512, # data_consumption / data_generation min_steps_learn=int(1e4), replay_size=int(1e5), target_update_tau=0.01, # tau=1 for hard update. target_update_interval=2, actor_update_interval=2, OptimCls=torch.optim.Adam, initial_optim_state_dict=None, # for all of them. action_prior="uniform", # or "gaussian" reward_scale=1, target_entropy="auto", # "auto", float, or None reparameterize=True, clip_grad_norm=1e6, n_step_return=1, bootstrap_timelimit=True, q_lr=1e-3, pi_lr=1e-3, alpha_lr=1e-4, q_beta=0.9, pi_beta=0.9, alpha_beta=0.5, alpha_init=0.1, encoder_update_tau=0.05, augmentation="random_shift", # [None, "random_shift", "subpixel_shift"] random_shift_pad=4, # how much to pad on each direction (like DrQ style) random_shift_prob=1.0, stop_conv_grad=False, max_pixel_shift=1.0, ): self.replay_ratio = batch_size # Unless you want to change it. self._batch_size = batch_size del batch_size assert augmentation in [None, "random_shift", "subpixel_shift"] save__init__args(locals()) def initialize( self, agent, n_itr, batch_spec, mid_batch_reset, examples, world_size=1, rank=0 ): """Stores input arguments and initializes replay buffer and optimizer. Use in non-async runners. Computes number of gradient updates per optimization iteration as `(replay_ratio * sampler-batch-size / training-batch_size)`.""" self.agent = agent self.n_itr = n_itr self.mid_batch_reset = mid_batch_reset self.sampler_bs = sampler_bs = batch_spec.size self.updates_per_optimize = int( self.replay_ratio * sampler_bs / self.batch_size ) logger.log( f"From sampler batch size {sampler_bs}, training " f"batch size {self.batch_size}, and replay ratio " f"{self.replay_ratio}, computed {self.updates_per_optimize} " f"updates per iteration." ) self.min_itr_learn = self.min_steps_learn // sampler_bs agent.give_min_itr_learn(self.min_itr_learn) self.store_latent = agent.store_latent if self.store_latent: assert self.stop_conv_grad self.initialize_replay_buffer(examples, batch_spec) self.optim_initialize(rank) def async_initialize(*args, **kwargs): raise NotImplementedError def optim_initialize(self, rank=0): """Called in initilize or by async runner after forking sampler.""" self.rank = rank # Be very explicit about which parameters are optimized where. self.pi_optimizer = self.OptimCls( chain( self.agent.pi_fc1.parameters(), # No conv. self.agent.pi_mlp.parameters(), ), lr=self.pi_lr, betas=(self.pi_beta, 0.999), ) self.q_optimizer = self.OptimCls( chain( () if self.stop_conv_grad else self.agent.conv.parameters(), self.agent.q_fc1.parameters(), self.agent.q_mlps.parameters(), ), lr=self.q_lr, betas=(self.q_beta, 0.999), ) self._log_alpha = torch.tensor(np.log(self.alpha_init), requires_grad=True) self._alpha = torch.exp(self._log_alpha.detach()) self.alpha_optimizer = self.OptimCls( (self._log_alpha,), lr=self.alpha_lr, betas=(self.alpha_beta, 0.999) ) if self.target_entropy == "auto": self.target_entropy = -np.prod(self.agent.env_spaces.action.shape) if self.initial_optim_state_dict is not None: self.load_optim_state_dict(self.initial_optim_state_dict) if self.action_prior == "gaussian": self.action_prior_distribution = Gaussian( dim=np.prod(self.agent.env_spaces.action.shape), std=1.0 ) def initialize_replay_buffer(self, examples, batch_spec, async_=False): """ Allocates replay buffer using examples and with the fields in `SamplesToBuffer` namedarraytuple. 
POSSIBLY CHANGE TO FRAME-BASED BUFFER (only if need memory, speed is fine). """ if async_: raise NotImplementedError example_to_buffer = self.examples_to_buffer(examples) ReplayCls = ( TlUniformReplayBuffer if self.bootstrap_timelimit else UniformReplayBuffer ) replay_kwargs = dict( example=example_to_buffer, size=self.replay_size, B=batch_spec.B, n_step_return=self.n_step_return, ) self.replay_buffer = ReplayCls(**replay_kwargs) def optimize_agent(self, itr, samples=None, sampler_itr=None): """ Extracts the needed fields from input samples and stores them in the replay buffer. Then samples from the replay buffer to train the agent by gradient updates (with the number of updates determined by replay ratio, sampler batch size, and training batch size). DIFFERENCES FROM SAC: -Organizes optimizers a little differently, clarifies which parameters. """ itr = itr if sampler_itr is None else sampler_itr # Async uses sampler_itr. if samples is not None: samples_to_buffer = self.samples_to_buffer(samples) self.replay_buffer.append_samples(samples_to_buffer) opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields)))) if itr < self.min_itr_learn: return opt_info for _ in range(self.updates_per_optimize): # Sample from the replay buffer, center crop, and move to GPU. samples_from_replay = self.replay_buffer.sample_batch(self.batch_size) loss_samples = self.data_aug_loss_samples(samples_from_replay) loss_samples = self.samples_to_device(loss_samples) # Q-loss includes computing some values used in pi-loss. q1_loss, q2_loss, valid, conv_out, q1, q2 = self.q_loss(loss_samples) if self.update_counter % self.actor_update_interval == 0: pi_loss, alpha_loss, pi_mean, pi_log_std = self.pi_alpha_loss( loss_samples, valid, conv_out ) if alpha_loss is not None: self.alpha_optimizer.zero_grad() alpha_loss.backward() self.alpha_optimizer.step() self._alpha = torch.exp(self._log_alpha.detach()) opt_info.alpha.append(self._alpha.item()) self.pi_optimizer.zero_grad() pi_loss.backward() pi_grad_norm = torch.nn.utils.clip_grad_norm_( chain( self.agent.pi_fc1.parameters(), self.agent.pi_mlp.parameters(), ), self.clip_grad_norm, ) self.pi_optimizer.step() opt_info.piLoss.append(pi_loss.item()) opt_info.piGradNorm.append(pi_grad_norm.item()) opt_info.piMu.extend(pi_mean[::10].numpy()) opt_info.piLogStd.extend(pi_log_std[::10].numpy()) # Step Q's last because pi_loss.backward() uses them. self.q_optimizer.zero_grad() q_loss = q1_loss + q2_loss q_loss.backward() q_grad_norm = torch.nn.utils.clip_grad_norm_( chain( () if self.stop_conv_grad else self.agent.conv.parameters(), self.agent.q_fc1.parameters(), self.agent.q_mlps.parameters(), ), self.clip_grad_norm, ) self.q_optimizer.step() opt_info.q1Loss.append(q1_loss.item()) opt_info.q2Loss.append(q2_loss.item()) opt_info.qGradNorm.append(q_grad_norm.item()) opt_info.q1.extend(q1[::10].numpy()) # Downsample for stats. opt_info.q2.extend(q2[::10].numpy()) opt_info.qMeanDiff.append(torch.mean(abs(q1 - q2)).item()) self.update_counter += 1 if self.update_counter % self.target_update_interval == 0: self.agent.update_targets( q_tau=self.target_update_tau, encoder_tau=self.encoder_update_tau, ) return opt_info def samples_to_buffer(self, samples): """Defines how to add data from sampler into the replay buffer. 
Called in optimize_agent() if samples are provided to that method.""" if self.store_latent: observation = samples.agent.agent_info.conv else: observation = samples.env.observation samples_to_buffer = SamplesToBuffer( observation=observation, action=samples.agent.action, reward=samples.env.reward, done=samples.env.done, ) if self.bootstrap_timelimit: samples_to_buffer = SamplesToBufferTl( *samples_to_buffer, timeout=samples.env.env_info.timeout ) return samples_to_buffer def examples_to_buffer(self, examples): if self.store_latent: observation = examples["agent_info"].conv else: observation = examples["observation"] example_to_buffer = SamplesToBuffer( observation=observation, action=examples["action"], reward=examples["reward"], done=examples["done"], ) if self.bootstrap_timelimit: example_to_buffer = SamplesToBufferTl( *example_to_buffer, timeout=examples["env_info"].timeout ) return example_to_buffer def samples_to_device(self, samples): """Only move the parts of samples which need to go to GPU.""" agent_inputs, target_inputs, action = buffer_to( (samples.agent_inputs, samples.target_inputs, samples.action), device=self.agent.device, ) device_samples = samples._replace( agent_inputs=agent_inputs, target_inputs=target_inputs, action=action, ) return device_samples def data_aug_loss_samples(self, samples): """Perform data augmentation (on CPU).""" if self.augmentation is None: return samples obs = samples.agent_inputs.observation target_obs = samples.target_inputs.observation if self.augmentation == "random_shift": aug_obs = random_shift( imgs=obs, pad=self.random_shift_pad, prob=self.random_shift_prob, ) aug_target_obs = random_shift( imgs=target_obs, pad=self.random_shift_pad, prob=self.random_shift_prob, ) elif self.augmentation == "subpixel_shift": aug_obs = subpixel_shift( imgs=obs, max_shift=self.max_pixel_shift, ) aug_target_obs = subpixel_shift( imgs=target_obs, max_shift=self.max_pixel_shift, ) else: raise NotImplementedError aug_samples = samples._replace( agent_inputs=samples.agent_inputs._replace(observation=aug_obs), target_inputs=samples.target_inputs._replace(observation=aug_target_obs), ) return aug_samples def q_loss(self, samples): if self.mid_batch_reset and not self.agent.recurrent: valid = torch.ones_like(samples.done, dtype=torch.float) # or None else: valid = valid_from_done(samples.done) if self.bootstrap_timelimit: # To avoid non-use of bootstrap when environment is 'done' due to # time-limit, turn off training on these samples. valid *= 1 - samples.timeout_n.float() # Run the convolution only once, return so pi_loss can use it. if self.store_latent: conv_out = None q_inputs = samples.agent_inputs else: conv_out = self.agent.conv(samples.agent_inputs.observation) if self.stop_conv_grad: conv_out = conv_out.detach() q_inputs = samples.agent_inputs._replace(observation=conv_out) # Q LOSS. q1, q2 = self.agent.q(*q_inputs, samples.action) with torch.no_grad(): # Run the target convolution only once. 
if self.store_latent: target_inputs = samples.target_inputs else: target_conv_out = self.agent.target_conv( samples.target_inputs.observation ) target_inputs = samples.target_inputs._replace( observation=target_conv_out ) target_action, target_log_pi, _ = self.agent.pi(*target_inputs) target_q1, target_q2 = self.agent.target_q(*target_inputs, target_action) min_target_q = torch.min(target_q1, target_q2) target_value = min_target_q - self._alpha * target_log_pi disc = self.discount ** self.n_step_return y = ( self.reward_scale * samples.return_ + (1 - samples.done_n.float()) * disc * target_value ) q1_loss = 0.5 * valid_mean((y - q1) ** 2, valid) q2_loss = 0.5 * valid_mean((y - q2) ** 2, valid) return q1_loss, q2_loss, valid, conv_out, q1.detach(), q2.detach() def pi_alpha_loss(self, samples, valid, conv_out): # PI LOSS. # Uses detached conv; avoid re-computing. if self.store_latent: agent_inputs = samples.agent_inputs else: conv_detach = conv_out.detach() # Always detached in actor. agent_inputs = samples.agent_inputs._replace(observation=conv_detach) new_action, log_pi, (pi_mean, pi_log_std) = self.agent.pi(*agent_inputs) if not self.reparameterize: # new_action = new_action.detach() # No grad. raise NotImplementedError # Re-use the detached latent. log_target1, log_target2 = self.agent.q(*agent_inputs, new_action) min_log_target = torch.min(log_target1, log_target2) prior_log_pi = self.get_action_prior(new_action.cpu()) if self.reparameterize: pi_losses = self._alpha * log_pi - min_log_target - prior_log_pi else: raise NotImplementedError # if self.policy_output_regularization > 0: # pi_losses += self.policy_output_regularization * torch.mean( # 0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1) pi_loss = valid_mean(pi_losses, valid) # ALPHA LOSS. if self.target_entropy is not None: alpha_losses = -self._log_alpha * (log_pi.detach() + self.target_entropy) alpha_loss = valid_mean(alpha_losses, valid) else: alpha_loss = None return pi_loss, alpha_loss, pi_mean.detach(), pi_log_std.detach() def get_action_prior(self, action): if self.action_prior == "uniform": prior_log_pi = 0.0 elif self.action_prior == "gaussian": prior_log_pi = self.action_prior_distribution.log_likelihood( action, GaussianDistInfo(mean=torch.zeros_like(action)) ) return prior_log_pi def optim_state_dict(self): return dict( pi=self.pi_optimizer.state_dict(), q=self.q_optimizer.state_dict(), alpha=self.alpha_optimizer.state_dict(), log_alpha_value=self._log_alpha.detach().item(), ) def load_optim_state_dict(self, state_dict): self.pi_optimizer.load_state_dict(state_dict["pi"]) self.q_optimizer.load_state_dict(state_dict["q"]) self.alpha_optimizer.load_state_dict(state_dict["alpha"]) with torch.no_grad(): self._log_alpha[:] = state_dict["log_alpha_value"] self._alpha = torch.exp(self._log_alpha.detach())
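

# --------------------------------------------------------------------------- #
# Illustrative sketch of the DrQ-style random-shift augmentation that the
# imported random_shift() provides for the `augmentation="random_shift"`
# option above: pad each image by `pad` pixels per side and take a random
# crop back to the original size, independently per batch element with
# probability `prob`.  This loop version only shows the idea; the real helper
# may vectorize the crop or use a different padding mode.
# --------------------------------------------------------------------------- #
def _random_shift_example(imgs, pad=4, prob=1.0):
    """imgs: float tensor [B, C, H, W]; returns randomly shifted copies."""
    B, C, H, W = imgs.shape
    padded = torch.nn.functional.pad(imgs, (pad, pad, pad, pad), mode="replicate")
    out = imgs.clone()
    for b in range(B):
        if torch.rand(()).item() < prob:
            h0 = int(torch.randint(0, 2 * pad + 1, ()))
            w0 = int(torch.randint(0, 2 * pad + 1, ()))
            out[b] = padded[b, :, h0:h0 + H, w0:w0 + W]
    return out
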
class SAC(RlAlgorithm): opt_info_fields = tuple(f for f in OptInfo._fields) # copy def __init__( self, discount=0.99, batch_size=256, min_steps_learn=int(1e4), replay_size=int(1e6), training_ratio=256, # data_consumption / data_generation target_update_tau=0.005, # tau=1 for hard update. target_update_interval=1, # interval=1000 for hard update. learning_rate=3e-4, OptimCls=torch.optim.Adam, optim_kwargs=None, initial_optim_state_dict=None, action_prior="uniform", # or "gaussian" reward_scale=1, reparameterize=True, clip_grad_norm=1e6, policy_output_regularization=0.001, n_step_return=1, ): if optim_kwargs is None: optim_kwargs = dict() assert action_prior in ["uniform", "gaussian"] save__init__args(locals()) self.update_counter = 0 def initialize(self, agent, n_itr, batch_spec, mid_batch_reset, examples): if agent.recurrent: raise NotImplementedError self.agent = agent self.n_itr = n_itr self.mid_batch_reset = mid_batch_reset self.optimizer = self.OptimCls(agent.parameters(), lr=self.learning_rate, **self.optim_kwargs) if self.initial_optim_state_dict is not None: self.optimizer.load_state_dict(self.initial_optim_state_dict) sample_bs = batch_spec.size train_bs = self.batch_size assert (self.training_ratio * sample_bs) % train_bs == 0 self.updates_per_optimize = int( (self.training_ratio * sample_bs) // train_bs) logger.log( f"From sampler batch size {sample_bs}, training " f"batch size {train_bs}, and training ratio " f"{self.training_ratio}, computed {self.updates_per_optimize} " f"updates per iteration.") self.min_itr_learn = self.min_steps_learn // sample_bs self.agent.give_min_itr_learn(self.min_itr_learn) example_to_buffer = SamplesToBuffer( observation=examples["observation"], action=examples["action"], reward=examples["reward"], done=examples["done"], ) replay_kwargs = dict( example=example_to_buffer, size=self.replay_size, B=batch_spec.B, n_step_return=self.n_step_return, ) self.replay_buffer = UniformReplayBuffer(**replay_kwargs) if self.action_prior == "gaussian": self.action_prior_distribution = Gaussian( dim=agent.env_spaces.action.size, std=1.) def optimize_agent(self, itr, samples=None): if samples is not None: samples_to_buffer = SamplesToBuffer( observation=samples.env.observation, action=samples.agent.action, reward=samples.env.reward, done=samples.env.done, ) self.replay_buffer.append_samples(samples_to_buffer) opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields)))) if itr < self.min_itr_learn: return opt_info for _ in range(self.updates_per_optimize): self.update_counter += 1 samples_from_replay = self.replay_buffer.sample_batch( self.batch_size) self.optimizer.zero_grad() losses, values = self.loss(samples_from_replay) for loss in losses: loss.backward() grad_norms = [ torch.nn.utils.clip_grad_norm_(ps, self.clip_grad_norm) for ps in self.agent.parameters_by_model() ] self.optimizer.step() self.append_opt_info_(opt_info, losses, grad_norms, values) if self.update_counter % self.target_update_interval == 0: self.agent.update_target(self.target_update_tau) return opt_info def loss(self, samples): """Samples have leading batch dimension [B,..] (but not time).""" agent_inputs, target_inputs, action = buffer_to( (samples.agent_inputs, samples.target_inputs, samples.action), device=self.agent.device) # Move to device once, re-use. 
        q1, q2 = self.agent.q(*agent_inputs, action)
        with torch.no_grad():
            target_v = self.agent.target_v(*target_inputs)
        disc = self.discount ** self.n_step_return
        y = (self.reward_scale * samples.return_ +
            (1 - samples.done_n.float()) * disc * target_v)
        if self.mid_batch_reset and not self.agent.recurrent:
            valid = None  # OR: torch.ones_like(samples.done, dtype=torch.float)
        else:
            valid = valid_from_done(samples.done)

        q1_loss = 0.5 * valid_mean((y - q1) ** 2, valid)
        q2_loss = 0.5 * valid_mean((y - q2) ** 2, valid)

        v = self.agent.v(*agent_inputs)
        new_action, log_pi, (pi_mean, pi_log_std) = self.agent.pi(*agent_inputs)
        if not self.reparameterize:
            new_action = new_action.detach()  # No grad.
        log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
        min_log_target = torch.min(log_target1, log_target2)
        prior_log_pi = self.get_action_prior(new_action.cpu())
        v_target = (min_log_target - log_pi + prior_log_pi).detach()  # No grad.
        v_loss = 0.5 * valid_mean((v - v_target) ** 2, valid)

        if self.reparameterize:
            pi_losses = log_pi - min_log_target
        else:
            pi_factor = (v - v_target).detach()  # No grad.
            pi_losses = log_pi * pi_factor
        if self.policy_output_regularization > 0:
            pi_losses += self.policy_output_regularization * torch.sum(
                0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1)
        pi_loss = valid_mean(pi_losses, valid)

        losses = (q1_loss, q2_loss, v_loss, pi_loss)
        values = tuple(val.detach() for val in (q1, q2, v, pi_mean, pi_log_std))
        return losses, values

    def get_action_prior(self, action):
        if self.action_prior == "uniform":
            prior_log_pi = 0.0
        elif self.action_prior == "gaussian":
            prior_log_pi = self.action_prior_distribution.log_likelihood(
                action, GaussianDistInfo(mean=torch.zeros_like(action)))
        return prior_log_pi

    def append_opt_info_(self, opt_info, losses, grad_norms, values):
        """In-place."""
        q1_loss, q2_loss, v_loss, pi_loss = losses
        q1_grad_norm, q2_grad_norm, v_grad_norm, pi_grad_norm = grad_norms
        q1, q2, v, pi_mean, pi_log_std = values
        opt_info.q1Loss.append(q1_loss.item())
        opt_info.q2Loss.append(q2_loss.item())
        opt_info.vLoss.append(v_loss.item())
        opt_info.piLoss.append(pi_loss.item())
        opt_info.q1GradNorm.append(q1_grad_norm)
        opt_info.q2GradNorm.append(q2_grad_norm)
        opt_info.vGradNorm.append(v_grad_norm)
        opt_info.piGradNorm.append(pi_grad_norm)
        opt_info.q1.extend(q1[::10].numpy())  # Downsample for stats.
        opt_info.q2.extend(q2[::10].numpy())
        opt_info.v.extend(v[::10].numpy())
        opt_info.piMu.extend(pi_mean[::10].numpy())
        opt_info.piLogStd.extend(pi_log_std[::10].numpy())
        opt_info.qMeanDiff.append(torch.mean(abs(q1 - q2)).item())
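

# --------------------------------------------------------------------------- #
# Worked example (illustrative, dummy numbers) of the bootstrapped n-step
# target shared by every loss() above:
#     y = reward_scale * return_ + (1 - done_n) * discount ** n_step_return * bootstrap
# where `return_` is the replay buffer's precomputed n-step discounted reward
# sum, `done_n` marks termination within those n steps (switching the
# bootstrap off), and `bootstrap` is the target value (target V, or
# min target Q minus alpha * log_pi, depending on the variant).
# --------------------------------------------------------------------------- #
def _nstep_target_example():
    discount, n_step_return, reward_scale = 0.99, 1, 1.0
    return_ = torch.tensor([0.5, 1.0, -0.2])      # n-step reward sums from replay.
    done_n = torch.tensor([0.0, 0.0, 1.0])        # Third sample terminated: no bootstrap.
    bootstrap = torch.tensor([10.0, -3.0, 7.0])   # Target value estimates.
    disc = discount ** n_step_return
    y = reward_scale * return_ + (1 - done_n) * disc * bootstrap
    return y  # ~[10.40, -1.97, -0.20]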