Example #1
File: sac_v.py Project: afansi/rlpyt
class SAC_V(RlAlgorithm):
    """TO BE DEPRECATED."""

    opt_info_fields = tuple(f for f in OptInfo._fields)  # copy

    def __init__(
            self,
            discount=0.99,
            batch_size=256,
            min_steps_learn=int(1e4),
            replay_size=int(1e6),
            replay_ratio=256,  # data_consumption / data_generation
            target_update_tau=0.005,  # tau=1 for hard update.
            target_update_interval=1,  # 1000 for hard update, 1 for soft.
            learning_rate=3e-4,
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            initial_optim_state_dict=None,  # for all of them.
            action_prior="uniform",  # or "gaussian"
            reward_scale=1,
            reparameterize=True,
            clip_grad_norm=1e9,
            policy_output_regularization=0.001,
            n_step_return=1,
            updates_per_sync=1,  # For async mode only.
            bootstrap_timelimit=True,
            ReplayBufferCls=None,  #  Leave None to select by above options.
    ):
        if optim_kwargs is None:
            optim_kwargs = dict()
        assert action_prior in ["uniform", "gaussian"]
        self._batch_size = batch_size
        del batch_size  # Property.
        save__init__args(locals())

    def initialize(self,
                   agent,
                   n_itr,
                   batch_spec,
                   mid_batch_reset,
                   examples,
                   world_size=1,
                   rank=0):
        """Used in basic or synchronous multi-GPU runners, not async."""
        self.agent = agent
        self.n_itr = n_itr
        self.mid_batch_reset = mid_batch_reset
        self.sampler_bs = sampler_bs = batch_spec.size
        self.updates_per_optimize = int(self.replay_ratio * sampler_bs /
                                        self.batch_size)
        logger.log(
            f"From sampler batch size {sampler_bs}, training "
            f"batch size {self.batch_size}, and replay ratio "
            f"{self.replay_ratio}, computed {self.updates_per_optimize} "
            f"updates per iteration.")
        self.min_itr_learn = self.min_steps_learn // sampler_bs
        agent.give_min_itr_learn(self.min_itr_learn)
        self.initialize_replay_buffer(examples, batch_spec)
        self.optim_initialize(rank)

    def async_initialize(self,
                         agent,
                         sampler_n_itr,
                         batch_spec,
                         mid_batch_reset,
                         examples,
                         world_size=1):
        """Used in async runner only."""
        self.agent = agent
        self.n_itr = sampler_n_itr
        self.initialize_replay_buffer(examples, batch_spec, async_=True)
        self.mid_batch_reset = mid_batch_reset
        self.sampler_bs = sampler_bs = batch_spec.size
        self.updates_per_optimize = self.updates_per_sync
        self.min_itr_learn = int(self.min_steps_learn // sampler_bs)
        agent.give_min_itr_learn(self.min_itr_learn)
        return self.replay_buffer

    def optim_initialize(self, rank=0):
        """Called by async runner."""
        self.rank = rank
        self.pi_optimizer = self.OptimCls(self.agent.pi_parameters(),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
        self.q1_optimizer = self.OptimCls(self.agent.q1_parameters(),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
        self.q2_optimizer = self.OptimCls(self.agent.q2_parameters(),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
        self.v_optimizer = self.OptimCls(self.agent.v_parameters(),
                                         lr=self.learning_rate,
                                         **self.optim_kwargs)
        if self.initial_optim_state_dict is not None:
            self.load_optim_state_dict(self.initial_optim_state_dict)
        if self.action_prior == "gaussian":
            self.action_prior_distribution = Gaussian(
                dim=self.agent.env_spaces.action.size, std=1.)

    def initialize_replay_buffer(self, examples, batch_spec, async_=False):
        example_to_buffer = self.examples_to_buffer(examples)
        replay_kwargs = dict(
            example=example_to_buffer,
            size=self.replay_size,
            B=batch_spec.B,
            n_step_return=self.n_step_return,
        )
        if not self.bootstrap_timelimit:
            ReplayCls = AsyncUniformReplayBuffer if async_ else UniformReplayBuffer
        else:
            ReplayCls = AsyncTlUniformReplayBuffer if async_ else TlUniformReplayBuffer
        if self.ReplayBufferCls is not None:
            ReplayCls = self.ReplayBufferCls
            logger.log(
                f"WARNING: ignoring internal selection logic and using"
                f" input replay buffer class: {ReplayCls} -- compatibility not"
                " guaranteed.")
        self.replay_buffer = ReplayCls(**replay_kwargs)

    def optimize_agent(self, itr, samples=None, sampler_itr=None):
        itr = itr if sampler_itr is None else sampler_itr  # Async uses sampler_itr.
        if samples is not None:
            samples_to_buffer = self.samples_to_buffer(samples)
            self.replay_buffer.append_samples(samples_to_buffer)
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        if itr < self.min_itr_learn:
            return opt_info
        for _ in range(self.updates_per_optimize):
            samples_from_replay = self.replay_buffer.sample_batch(
                self.batch_size)
            losses, values = self.loss(samples_from_replay)
            q1_loss, q2_loss, v_loss, pi_loss = losses

            self.v_optimizer.zero_grad()
            v_loss.backward()
            v_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.v_parameters(), self.clip_grad_norm)
            self.v_optimizer.step()

            self.pi_optimizer.zero_grad()
            pi_loss.backward()
            pi_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.pi_parameters(), self.clip_grad_norm)
            self.pi_optimizer.step()

            # Step Q's last because pi_loss.backward() uses them?
            self.q1_optimizer.zero_grad()
            q1_loss.backward()
            q1_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.q1_parameters(), self.clip_grad_norm)
            self.q1_optimizer.step()

            self.q2_optimizer.zero_grad()
            q2_loss.backward()
            q2_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.q2_parameters(), self.clip_grad_norm)
            self.q2_optimizer.step()

            grad_norms = (q1_grad_norm, q2_grad_norm, v_grad_norm,
                          pi_grad_norm)

            self.append_opt_info_(opt_info, losses, grad_norms, values)
            self.update_counter += 1
            if self.update_counter % self.target_update_interval == 0:
                self.agent.update_target(self.target_update_tau)
        return opt_info

    def samples_to_buffer(self, samples):
        return SamplesToBuffer(
            observation=samples.env.observation,
            action=samples.agent.action,
            reward=samples.env.reward,
            done=samples.env.done,
            timeout=getattr(samples.env.env_info, "timeout", None),
        )

    def examples_to_buffer(self, examples):
        """Defines how to initialize the replay buffer from examples. Called
        in initialize_replay_buffer().
        """
        return SamplesToBuffer(
            observation=examples["observation"],
            action=examples["action"],
            reward=examples["reward"],
            done=examples["done"],
            timeout=getattr(examples["env_info"], "timeout", None),
        )

    def loss(self, samples):
        """Samples have leading batch dimension [B,..] (but not time)."""
        agent_inputs, target_inputs, action = buffer_to(
            (samples.agent_inputs, samples.target_inputs, samples.action))
        q1, q2 = self.agent.q(*agent_inputs, action)
        with torch.no_grad():
            target_v = self.agent.target_v(*target_inputs)
        disc = self.discount**self.n_step_return
        y = (self.reward_scale * samples.return_ +
             (1 - samples.done_n.float()) * disc * target_v)
        if self.mid_batch_reset and not self.agent.recurrent:
            valid = torch.ones_like(samples.done, dtype=torch.float)
        else:
            valid = valid_from_done(samples.done)

        if self.bootstrap_timelimit:
            # To avoid non-use of bootstrap when environment is 'done' due to
            # time-limit, turn off training on these samples.
            valid *= (1 - samples.timeout_n.float())

        q1_loss = 0.5 * valid_mean((y - q1)**2, valid)
        q2_loss = 0.5 * valid_mean((y - q2)**2, valid)

        v = self.agent.v(*agent_inputs)
        new_action, log_pi, (pi_mean,
                             pi_log_std) = self.agent.pi(*agent_inputs)
        if not self.reparameterize:
            new_action = new_action.detach()  # No grad.
        log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
        min_log_target = torch.min(log_target1, log_target2)
        prior_log_pi = self.get_action_prior(new_action.cpu())
        v_target = (min_log_target - log_pi +
                    prior_log_pi).detach()  # No grad.

        v_loss = 0.5 * valid_mean((v - v_target)**2, valid)

        if self.reparameterize:
            pi_losses = log_pi - min_log_target
        else:
            pi_factor = (v - v_target).detach()
            pi_losses = log_pi * pi_factor
        if self.policy_output_regularization > 0:
            pi_losses += self.policy_output_regularization * torch.mean(
                0.5 * pi_mean**2 + 0.5 * pi_log_std**2, dim=-1)
        pi_loss = valid_mean(pi_losses, valid)

        losses = (q1_loss, q2_loss, v_loss, pi_loss)
        values = tuple(val.detach()
                       for val in (q1, q2, v, pi_mean, pi_log_std))
        return losses, values

    # def q_loss(self, samples):
    #     """Samples have leading batch dimension [B,..] (but not time)."""
    #     agent_inputs, target_inputs, action = buffer_to(
    #         (samples.agent_inputs, samples.target_inputs, samples.action),
    #         device=self.agent.device)  # Move to device once, re-use.
    #     q1, q2 = self.agent.q(*agent_inputs, action)
    #     with torch.no_grad():
    #         target_v = self.agent.target_v(*target_inputs)
    #     disc = self.discount ** self.n_step_return
    #     y = (self.reward_scale * samples.return_ +
    #         (1 - samples.done_n.float()) * disc * target_v)
    #     if self.mid_batch_reset and not self.agent.recurrent:
    #         valid = None  # OR: torch.ones_like(samples.done, dtype=torch.float)
    #     else:
    #         valid = valid_from_done(samples.done)

    #     q1_loss = 0.5 * valid_mean((y - q1) ** 2, valid)
    #     q2_loss = 0.5 * valid_mean((y - q2) ** 2, valid)

    #     losses = (q1_loss, q2_loss)
    #     values = tuple(val.detach() for val in (q1, q2))
    #     return losses, values, agent_inputs, valid

    # def pi_v_loss(self, agent_inputs, valid):
    #     v = self.agent.v(*agent_inputs)
    #     new_action, log_pi, (pi_mean, pi_log_std) = self.agent.pi(*agent_inputs)
    #     if not self.reparameterize:
    #         new_action = new_action.detach()  # No grad.
    #     log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
    #     min_log_target = torch.min(log_target1, log_target2)
    #     prior_log_pi = self.get_action_prior(new_action.cpu())
    #     v_target = (min_log_target - log_pi + prior_log_pi).detach()  # No grad.
    #     v_loss = 0.5 * valid_mean((v - v_target) ** 2, valid)

    #     if self.reparameterize:
    #         pi_losses = log_pi - min_log_target  # log_target1  # min_log_target
    #     else:
    #         pi_factor = (v - v_target).detach()  # No grad.
    #         pi_losses = log_pi * pi_factor
    #     if self.policy_output_regularization > 0:
    #         pi_losses += self.policy_output_regularization * torch.sum(
    #             0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1)
    #     pi_loss = valid_mean(pi_losses, valid)

    #     losses = (v_loss, pi_loss)
    #     values = tuple(val.detach() for val in (v, pi_mean, pi_log_std))
    #     return losses, values

    # def loss(self, samples):
    #     """Samples have leading batch dimension [B,..] (but not time)."""
    #     agent_inputs, target_inputs, action = buffer_to(
    #         (samples.agent_inputs, samples.target_inputs, samples.action),
    #         device=self.agent.device)  # Move to device once, re-use.
    #     q1, q2 = self.agent.q(*agent_inputs, action)
    #     with torch.no_grad():
    #         target_v = self.agent.target_v(*target_inputs)
    #     disc = self.discount ** self.n_step_return
    #     y = (self.reward_scale * samples.return_ +
    #         (1 - samples.done_n.float()) * disc * target_v)
    #     if self.mid_batch_reset and not self.agent.recurrent:
    #         valid = None  # OR: torch.ones_like(samples.done, dtype=torch.float)
    #     else:
    #         valid = valid_from_done(samples.done)

    #     q1_loss = 0.5 * valid_mean((y - q1) ** 2, valid)
    #     q2_loss = 0.5 * valid_mean((y - q2) ** 2, valid)

    #     v = self.agent.v(*agent_inputs)
    #     new_action, log_pi, (pi_mean, pi_log_std) = self.agent.pi(*agent_inputs)
    #     if not self.reparameterize:
    #         new_action = new_action.detach()  # No grad.
    #     log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
    #     min_log_target = torch.min(log_target1, log_target2)
    #     prior_log_pi = self.get_action_prior(new_action.cpu())
    #     v_target = (min_log_target - log_pi + prior_log_pi).detach()  # No grad.
    #     v_loss = 0.5 * valid_mean((v - v_target) ** 2, valid)

    #     if self.reparameterize:
    #         pi_losses = log_pi - min_log_target  # log_target1
    #     else:
    #         pi_factor = (v - v_target).detach()  # No grad.
    #         pi_losses = log_pi * pi_factor
    #     if self.policy_output_regularization > 0:
    #         pi_losses += torch.sum(self.policy_output_regularization * 0.5 *
    #             pi_mean ** 2 + pi_log_std ** 2, dim=-1)
    #     pi_loss = valid_mean(pi_losses, valid)

    #     losses = (q1_loss, q2_loss, v_loss, pi_loss)
    #     values = tuple(val.detach() for val in (q1, q2, v, pi_mean, pi_log_std))
    #     return losses, values

    def get_action_prior(self, action):
        if self.action_prior == "uniform":
            prior_log_pi = 0.0
        elif self.action_prior == "gaussian":
            prior_log_pi = self.action_prior_distribution.log_likelihood(
                action, GaussianDistInfo(mean=torch.zeros_like(action)))
        return prior_log_pi

    def append_opt_info_(self, opt_info, losses, grad_norms, values):
        """In-place."""
        q1_loss, q2_loss, v_loss, pi_loss = losses
        q1_grad_norm, q2_grad_norm, v_grad_norm, pi_grad_norm = grad_norms
        q1, q2, v, pi_mean, pi_log_std = values
        opt_info.q1Loss.append(q1_loss.item())
        opt_info.q2Loss.append(q2_loss.item())
        opt_info.vLoss.append(v_loss.item())
        opt_info.piLoss.append(pi_loss.item())
        opt_info.q1GradNorm.append(
            torch.tensor(q1_grad_norm).item())  # backwards compatible
        opt_info.q2GradNorm.append(
            torch.tensor(q2_grad_norm).item())  # backwards compatible
        opt_info.vGradNorm.append(
            torch.tensor(v_grad_norm).item())  # backwards compatible
        opt_info.piGradNorm.append(
            torch.tensor(pi_grad_norm).item())  # backwards compatible
        opt_info.q1.extend(q1[::10].numpy())  # Downsample for stats.
        opt_info.q2.extend(q2[::10].numpy())
        opt_info.v.extend(v[::10].numpy())
        opt_info.piMu.extend(pi_mean[::10].numpy())
        opt_info.piLogStd.extend(pi_log_std[::10].numpy())
        opt_info.qMeanDiff.append(torch.mean(abs(q1 - q2)).item())

    def optim_state_dict(self):
        return dict(
            pi_optimizer=self.pi_optimizer.state_dict(),
            q1_optimizer=self.q1_optimizer.state_dict(),
            q2_optimizer=self.q2_optimizer.state_dict(),
            v_optimizer=self.v_optimizer.state_dict(),
        )

    def load_optim_state_dict(self, state_dict):
        self.pi_optimizer.load_state_dict(state_dict["pi_optimizer"])
        self.q1_optimizer.load_state_dict(state_dict["q1_optimizer"])
        self.q2_optimizer.load_state_dict(state_dict["q2_optimizer"])
        self.v_optimizer.load_state_dict(state_dict["v_optimizer"])
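
The number of gradient updates per iteration that initialize() logs above follows directly from replay_ratio, the sampler batch size, and the training batch size. A minimal sketch of that arithmetic, assuming a hypothetical sampler batch size of 64 steps per iteration (not taken from the example itself):

replay_ratio = 256           # data consumed / data generated (default above)
batch_size = 256             # training batch size (default above)
sampler_bs = 64              # assumed sampler batch size, i.e. batch_T * batch_B
updates_per_optimize = int(replay_ratio * sampler_bs / batch_size)
print(updates_per_optimize)  # -> 64 gradient updates per sampler iteration
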
Example #2
class SACDiscrete(RlAlgorithm):
    """Soft actor critic algorithm, training from a replay buffer."""

    opt_info_fields = tuple(f for f in OptInfo._fields)  # copy

    def __init__(
            self,
            discount=0.99,
            batch_size=256,
            min_steps_learn=int(1e4),
            replay_size=int(1e6),
            replay_ratio=256,  # data_consumption / data_generation
            target_update_tau=0.005,  # tau=1 for hard update.
            target_update_interval=1,  # 1000 for hard update, 1 for soft.
            learning_rate=3e-4,
            fixed_alpha=None,  # None for adaptive alpha, float for any fixed value
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            initial_optim_state_dict=None,  # for all of them.
            action_prior="uniform",  # or "gaussian"
            reward_scale=1,
            target_entropy="auto",  # "auto", float, or None
            reparameterize=True,
            clip_grad_norm=1e9,
            # policy_output_regularization=0.001,
            n_step_return=1,
            updates_per_sync=1,  # For async mode only.
            bootstrap_timelimit=False,
            ReplayBufferCls=None,  # Leave None to select by above options.
    ):
        """Save input arguments."""
        if optim_kwargs is None:
            optim_kwargs = dict()
        assert action_prior in ["uniform", "gaussian"]
        self._batch_size = batch_size
        del batch_size  # Property.
        save__init__args(locals())

    def initialize(self,
                   agent,
                   n_itr,
                   batch_spec,
                   mid_batch_reset,
                   examples,
                   world_size=1,
                   rank=0):
        """Stores input arguments and initializes replay buffer and optimizer.
        Use in non-async runners.  Computes number of gradient updates per
        optimization iteration as `(replay_ratio * sampler-batch-size /
        training-batch-size)`."""
        self.agent = agent
        self.n_itr = n_itr
        self.mid_batch_reset = mid_batch_reset
        self.sampler_bs = sampler_bs = batch_spec.size
        self.updates_per_optimize = int(self.replay_ratio * sampler_bs /
                                        self.batch_size)
        logger.log(
            f"From sampler batch size {sampler_bs}, training "
            f"batch size {self.batch_size}, and replay ratio "
            f"{self.replay_ratio}, computed {self.updates_per_optimize} "
            f"updates per iteration.")
        self.min_itr_learn = self.min_steps_learn // sampler_bs
        agent.give_min_itr_learn(self.min_itr_learn)
        self.initialize_replay_buffer(examples, batch_spec)
        self.optim_initialize(rank)

    def async_initialize(self,
                         agent,
                         sampler_n_itr,
                         batch_spec,
                         mid_batch_reset,
                         examples,
                         world_size=1):
        """Used in async runner only; returns replay buffer allocated in shared
        memory, does not instantiate optimizer. """
        self.agent = agent
        self.n_itr = sampler_n_itr
        self.initialize_replay_buffer(examples, batch_spec, async_=True)
        self.mid_batch_reset = mid_batch_reset
        self.sampler_bs = sampler_bs = batch_spec.size
        self.updates_per_optimize = self.updates_per_sync
        self.min_itr_learn = int(self.min_steps_learn // sampler_bs)
        agent.give_min_itr_learn(self.min_itr_learn)
        return self.replay_buffer

    def optim_initialize(self, rank=0):
        """Called in initilize or by async runner after forking sampler."""
        self.rank = rank
        self.pi_optimizer = self.OptimCls(self.agent.pi_parameters(),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
        self.q1_optimizer = self.OptimCls(self.agent.q1_parameters(),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
        self.q2_optimizer = self.OptimCls(self.agent.q2_parameters(),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
        if self.fixed_alpha is None:
            self.target_entropy = -np.log(
                (1.0 / self.agent.env_spaces.action.n)) * 0.98
            self._log_alpha = torch.zeros(1, requires_grad=True)
            self._alpha = self._log_alpha.exp()
            self.alpha_optimizer = self.OptimCls((self._log_alpha, ),
                                                 lr=self.learning_rate,
                                                 **self.optim_kwargs)
        else:
            self._log_alpha = torch.tensor([np.log(self.fixed_alpha)])
            self._alpha = torch.tensor([self.fixed_alpha])
            self.alpha_optimizer = None
        if self.target_entropy == "auto":
            self.target_entropy = -np.prod(self.agent.env_spaces.action.n)

        if self.initial_optim_state_dict is not None:
            self.load_optim_state_dict(self.initial_optim_state_dict)
        if self.action_prior == "gaussian":
            self.action_prior_distribution = Gaussian(dim=np.prod(
                self.agent.env_spaces.action.shape),
                                                      std=1.)

    def initialize_replay_buffer(self, examples, batch_spec, async_=False):
        """
        Allocates replay buffer using examples and with the fields in `SamplesToBuffer`
        namedarraytuple.
        """
        example_to_buffer = SamplesToBuffer(
            observation=examples["observation"],
            action=examples["action"],
            reward=examples["reward"],
            done=examples["done"],
        )
        if not self.bootstrap_timelimit:
            ReplayCls = AsyncUniformReplayBuffer if async_ else UniformReplayBuffer
        else:
            example_to_buffer = SamplesToBufferTl(
                *example_to_buffer, timeout=examples["env_info"].timeout)
            ReplayCls = AsyncTlUniformReplayBuffer if async_ else TlUniformReplayBuffer
        replay_kwargs = dict(
            example=example_to_buffer,
            size=self.replay_size,
            B=batch_spec.B,
            n_step_return=self.n_step_return,
        )
        if self.ReplayBufferCls is not None:
            ReplayCls = self.ReplayBufferCls
            logger.log(
                f"WARNING: ignoring internal selection logic and using"
                f" input replay buffer class: {ReplayCls} -- compatibility not"
                " guaranteed.")
        self.replay_buffer = ReplayCls(**replay_kwargs)

    def optimize_agent(self, itr, samples=None, sampler_itr=None):
        """
        Extracts the needed fields from input samples and stores them in the 
        replay buffer.  Then samples from the replay buffer to train the agent
        by gradient updates (with the number of updates determined by replay
        ratio, sampler batch size, and training batch size).
        """
        itr = itr if sampler_itr is None else sampler_itr  # Async uses sampler_itr.
        if samples is not None:
            samples_to_buffer = self.samples_to_buffer(samples)
            self.replay_buffer.append_samples(samples_to_buffer)
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        if itr < self.min_itr_learn:
            return opt_info
        for _ in range(self.updates_per_optimize):
            samples_from_replay = self.replay_buffer.sample_batch(
                self.batch_size)
            losses, values = self.loss(samples_from_replay)
            q1_loss, q2_loss, pi_loss, alpha_loss = losses

            if alpha_loss is not None:
                self.alpha_optimizer.zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer.step()
                self._alpha = torch.exp(self._log_alpha.detach())

            self.pi_optimizer.zero_grad()
            pi_loss.backward()
            pi_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.pi_parameters(), self.clip_grad_norm)
            self.pi_optimizer.step()

            # Step Q's last because pi_loss.backward() uses them?
            self.q1_optimizer.zero_grad()
            q1_loss.backward()
            q1_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.q1_parameters(), self.clip_grad_norm)
            self.q1_optimizer.step()

            self.q2_optimizer.zero_grad()
            q2_loss.backward()
            q2_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.q2_parameters(), self.clip_grad_norm)
            self.q2_optimizer.step()

            grad_norms = (q1_grad_norm, q2_grad_norm, pi_grad_norm)

            self.append_opt_info_(opt_info, losses, grad_norms, values)
            self.update_counter += 1
            if self.update_counter % self.target_update_interval == 0:
                self.agent.update_target(self.target_update_tau)

        return opt_info

    def samples_to_buffer(self, samples):
        """Defines how to add data from sampler into the replay buffer. Called
        in optimize_agent() if samples are provided to that method."""
        samples_to_buffer = SamplesToBuffer(
            observation=samples.env.observation,
            action=samples.agent.action,
            reward=samples.env.reward,
            done=samples.env.done,
        )
        if self.bootstrap_timelimit:
            samples_to_buffer = SamplesToBufferTl(
                *samples_to_buffer, timeout=samples.env.env_info.timeout)
        return samples_to_buffer

    def loss(self, samples):
        """
        Computes losses for twin Q-values against the min of twin target Q-values
        and an entropy term.  Computes reparameterized policy loss, and loss for
        tuning entropy weighting, alpha.  
        
        Input samples have leading batch dimension [B,..] (but not time).
        """
        agent_inputs, target_inputs, action = buffer_to(
            (samples.agent_inputs, samples.target_inputs, samples.action))

        if self.mid_batch_reset and not self.agent.recurrent:
            valid = torch.ones_like(samples.done, dtype=torch.float)  # or None
        else:
            valid = valid_from_done(samples.done)
        if self.bootstrap_timelimit:
            # To avoid non-use of bootstrap when environment is 'done' due to
            # time-limit, turn off training on these samples.
            valid *= (1 - samples.timeout_n.float())

        with torch.no_grad():
            target_action, target_action_probs, target_log_pi, _ = self.agent.pi(
                *target_inputs)
            target_q1, target_q2 = self.agent.target_q(*target_inputs,
                                                       target_action)
            min_target_q = torch.min(target_q1, target_q2)
            target_value = target_action_probs * (min_target_q -
                                                  self._alpha * target_log_pi)
            target_value = target_value.sum(dim=1).unsqueeze(-1)
            disc = self.discount**self.n_step_return
            y = self.reward_scale * samples.return_ + (
                1 - samples.done_n.float()) * disc * target_value

        q1, q2 = self.agent.q(*agent_inputs, action)
        q1 = torch.gather(q1, 1, action.unsqueeze(1).long())
        q2 = torch.gather(q2, 1, action.unsqueeze(1).long())

        q1_loss = 0.5 * valid_mean((y - q1)**2, valid)
        q2_loss = 0.5 * valid_mean((y - q2)**2, valid)

        action, action_probs, log_pi, _ = self.agent.pi(*agent_inputs)
        q1_pi, q2_pi = self.agent.q(*agent_inputs, action)
        min_pi_target = torch.min(q1_pi, q2_pi)
        inside_term = self._alpha * log_pi - min_pi_target
        policy_loss = (action_probs * inside_term).sum(dim=1).mean()
        log_pi = torch.sum(log_pi * action_probs, dim=1)

        # if self.policy_output_regularization > 0:
        #     pi_losses += self.policy_output_regularization * torch.mean(
        #         0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1)
        pi_loss = valid_mean(policy_loss, valid)

        if self.target_entropy is not None and self.fixed_alpha is None:
            alpha_losses = -self._log_alpha * (log_pi.detach() +
                                               self.target_entropy)
            alpha_loss = valid_mean(alpha_losses, valid)
        else:
            alpha_loss = None

        losses = (q1_loss, q2_loss, pi_loss, alpha_loss)
        values = tuple(val.detach() for val in (q1, q2, action_probs))
        return losses, values

    def get_action_prior(self, action):
        if self.action_prior == "uniform":
            prior_log_pi = 0.0
        elif self.action_prior == "gaussian":
            prior_log_pi = self.action_prior_distribution.log_likelihood(
                action, GaussianDistInfo(mean=torch.zeros_like(action)))
        return prior_log_pi

    def append_opt_info_(self, opt_info, losses, grad_norms, values):
        """In-place."""
        q1_loss, q2_loss, pi_loss, alpha_loss = losses
        q1_grad_norm, q2_grad_norm, pi_grad_norm = grad_norms
        q1, q2, action_probs = values
        opt_info.q1Loss.append(q1_loss.item())
        opt_info.q2Loss.append(q2_loss.item())
        opt_info.piLoss.append(pi_loss.item())
        opt_info.q1GradNorm.append(
            torch.tensor(q1_grad_norm).item())  # backwards compatible
        opt_info.q2GradNorm.append(
            torch.tensor(q2_grad_norm).item())  # backwards compatible
        opt_info.piGradNorm.append(
            torch.tensor(pi_grad_norm).item())  # backwards compatible
        opt_info.q1.extend(q1[::10].numpy())  # Downsample for stats.
        opt_info.q2.extend(q2[::10].numpy())
        opt_info.qMeanDiff.append(torch.mean(abs(q1 - q2)).item())
        opt_info.alpha.append(self._alpha.item())

    def optim_state_dict(self):
        return dict(
            pi_optimizer=self.pi_optimizer.state_dict(),
            q1_optimizer=self.q1_optimizer.state_dict(),
            q2_optimizer=self.q2_optimizer.state_dict(),
            alpha_optimizer=self.alpha_optimizer.state_dict()
            if self.alpha_optimizer else None,
            log_alpha=self._log_alpha.detach().item(),
        )

    def load_optim_state_dict(self, state_dict):
        self.pi_optimizer.load_state_dict(state_dict["pi_optimizer"])
        self.q1_optimizer.load_state_dict(state_dict["q1_optimizer"])
        self.q2_optimizer.load_state_dict(state_dict["q2_optimizer"])
        if self.alpha_optimizer is not None and state_dict[
                "alpha_optimizer"] is not None:
            self.alpha_optimizer.load_state_dict(state_dict["alpha_optimizer"])
        with torch.no_grad():
            self._log_alpha[:] = state_dict["log_alpha"]
            self._alpha = torch.exp(self._log_alpha.detach())
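
In loss() above, the Bellman target for the discrete case is an expectation over actions: the action probabilities weight the minimum of the twin target Q-values minus the entropy term, and torch.gather picks out the Q-value of the action actually taken. A standalone sketch with assumed shapes and values, independent of rlpyt:

import torch

B, A = 4, 3                                      # assumed batch size and number of discrete actions
target_q1, target_q2 = torch.randn(B, A), torch.randn(B, A)
probs = torch.softmax(torch.randn(B, A), dim=1)  # pi(a|s) over all actions
log_pi = torch.log(probs)
alpha, discount = 0.2, 0.99
reward, done = torch.randn(B, 1), torch.zeros(B, 1)

min_q = torch.min(target_q1, target_q2)
# Discrete SAC replaces sampling with an expectation over the action set.
target_value = (probs * (min_q - alpha * log_pi)).sum(dim=1, keepdim=True)
y = reward + (1 - done) * discount * target_value

action = torch.randint(A, (B,))                  # actions stored in the replay buffer
q1 = torch.randn(B, A)                           # one Q-head's output for the batch
q1_taken = torch.gather(q1, 1, action.unsqueeze(1))  # Q(s, a) regressed toward y
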
Example #3
class SAC(RlAlgorithm):

    opt_info_fields = None

    def __init__(
            self,
            discount=0.99,
            batch_size=256,
            min_steps_learn=int(1e4),
            replay_size=int(6e5),
            replay_ratio=256,  # data_consumption / data_generation
            target_update_tau=0.005,  # tau=1 for hard update.
            target_update_interval=1,  # interval=1000 for hard update.
            learning_rate=3e-4,
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            initial_optim_state_dict=None,  # for pi only.
            action_prior="uniform",  # or "gaussian"
            reward_scale=1,
            reparameterize=True,
            clip_grad_norm=1e9,
            policy_output_regularization=0.001,
            n_step_return=1,
            updates_per_sync=1,  # For async mode only.
            target_entropy='auto',
            ):
        if optim_kwargs is None:
            optim_kwargs = dict()
        assert action_prior in ["uniform", "gaussian"]
        self._batch_size = batch_size
        del batch_size  # Property.
        save__init__args(locals())

    def initialize(self, agent, n_itr, batch_spec, mid_batch_reset, examples,
            world_size=1, rank=0):
        """Used in basic or synchronous multi-GPU runners, not async."""
        self.agent = agent
        self.n_itr = n_itr
        self.mid_batch_reset = mid_batch_reset
        self.sampler_bs = sampler_bs = batch_spec.size
        self.updates_per_optimize = int(self.replay_ratio * sampler_bs /
            self.batch_size)
        logger.log(f"From sampler batch size {sampler_bs}, training "
            f"batch size {self.batch_size}, and replay ratio "
            f"{self.replay_ratio}, computed {self.updates_per_optimize} "
            f"updates per iteration.")
        self.min_itr_learn = self.min_steps_learn // sampler_bs
        agent.give_min_itr_learn(self.min_itr_learn)
        print('batch_spec:', batch_spec, '\n\n')
        self.initialize_replay_buffer(examples, batch_spec)
        self.optim_initialize(rank)

        if self.target_entropy == 'auto':
            self.target_entropy = -np.prod(self.agent.env_spaces.action.shape)

        keys = ["piLoss", "alphaLoss",
                "piMu", "piLogStd", "alpha", "piGradNorm"]
        keys += [f'q{i}GradNorm' for i in range(self.agent.n_qs)]
        keys += [f'q{i}' for i in range(self.agent.n_qs)]
        keys += [f'q{i}Loss' for i in range(self.agent.n_qs)]
        global OptInfo
        OptInfo = namedtuple('OptInfo', keys)

        SAC.opt_info_fields = tuple(f for f in OptInfo._fields)  # copy

    def async_initialize(self, agent, sampler_n_itr, batch_spec, mid_batch_reset,
            examples, world_size=1):
        """Used in async runner only."""
        self.agent = agent
        self.n_itr = sampler_n_itr
        self.initialize_replay_buffer(examples, batch_spec, async_=True)
        self.mid_batch_reset = mid_batch_reset
        self.sampler_bs = sampler_bs = batch_spec.size
        self.updates_per_optimize = self.updates_per_sync
        self.min_itr_learn = int(self.min_steps_learn // sampler_bs)
        agent.give_min_itr_learn(self.min_itr_learn)
        return self.replay_buffer

    def optim_initialize(self, rank=0):
        """Called by async runner."""
        self.rank = rank
        self.pi_optimizer = self.OptimCls(self.agent.pi_parameters(),
            lr=self.learning_rate, **self.optim_kwargs)
        self.q_optimizers = [self.OptimCls(q_param,
            lr=self.learning_rate, **self.optim_kwargs)
            for q_param in self.agent.q_parameters()]
        self.alpha_optimizer = self.OptimCls([self.agent.log_alpha],
            lr=self.learning_rate, **self.optim_kwargs)
        if self.initial_optim_state_dict is not None:
            self.pi_optimizer.load_state_dict(self.initial_optim_state_dict)
        if self.action_prior == "gaussian":
            self.action_prior_distribution = Gaussian(
                dim=self.agent.env_spaces.action.size, std=1.)

    def initialize_replay_buffer(self, examples, batch_spec, async_=False):
        example_to_buffer = SamplesToBuffer(
            observation=examples["observation"],
            action=examples["action"],
            reward=examples["reward"],
            done=examples["done"],
        )
        replay_kwargs = dict(
            example=example_to_buffer,
            size=self.replay_size,
            B=batch_spec.B,
            n_step_return=self.n_step_return,
        )
        ReplayCls = AsyncUniformReplayBuffer if async_ else UniformReplayBuffer
        self.replay_buffer = ReplayCls(**replay_kwargs)

    def optimize_agent(self, itr, samples=None, sampler_itr=None):
        itr = itr if sampler_itr is None else sampler_itr  # Async uses sampler_itr.
        if samples is not None:
            samples_to_buffer = self.samples_to_buffer(samples)
            self.replay_buffer.append_samples(samples_to_buffer)
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        if itr < self.min_itr_learn:
            return opt_info
        for _ in range(self.updates_per_optimize):
            samples_from_replay = self.replay_buffer.sample_batch(self.batch_size)
            q_losses, losses, values, q_values = self.loss(samples_from_replay)
            pi_loss, alpha_loss = losses

            self.pi_optimizer.zero_grad()
            pi_loss.backward()
            pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.agent.pi_parameters(),
                self.clip_grad_norm)
            self.pi_optimizer.step()

            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()

            q_grad_norms = []
            for q_opt, q_loss, q_param in zip(self.q_optimizers, q_losses,
                                               self.agent.q_parameters()):
                q_opt.zero_grad()
                q_loss.backward()
                q_grad_norm = torch.nn.utils.clip_grad_norm_(q_param,
                                                             self.clip_grad_norm)
                q_opt.step()
                q_grad_norms.append(q_grad_norm)

            self.append_opt_info_(opt_info, q_losses, losses,
                                  q_grad_norms, pi_grad_norm, q_values, values)
            self.update_counter += 1
            if self.update_counter % self.target_update_interval == 0:
                self.agent.update_target(self.target_update_tau)
        return opt_info

    def samples_to_buffer(self, samples):
        return SamplesToBuffer(
            observation=samples.env.observation,
            action=samples.agent.action,
            reward=samples.env.reward,
            done=samples.env.done,
        )

    def loss(self, samples):
        """Samples have leading batch dimension [B,..] (but not time)."""
        agent_inputs, target_inputs, action = buffer_to(
            (samples.agent_inputs, samples.target_inputs, samples.action))
        qs = self.agent.q(*agent_inputs, action)
        with torch.no_grad():
            target_v = self.agent.target_v(*target_inputs).detach()
        disc = self.discount ** self.n_step_return
        y = (self.reward_scale * samples.return_ +
             (1 - samples.done_n.float()) * disc * target_v)
        if self.mid_batch_reset and not self.agent.recurrent:
            valid = None  # OR: torch.ones_like(samples.done, dtype=torch.float)
        else:
            valid = valid_from_done(samples.done)

        q_losses = [0.5 * valid_mean((y - q) ** 2, valid) for q in qs]

        new_action, log_pi, (pi_mean, pi_log_std) = self.agent.pi(*agent_inputs)
        if not self.reparameterize:
            new_action = new_action.detach()  # No grad.
        log_targets = self.agent.q(*agent_inputs, new_action)
        min_log_target = torch.min(torch.stack(log_targets, dim=0), dim=0)[0]
        prior_log_pi = self.get_action_prior(new_action.cpu())

        if self.reparameterize:
            alpha = self.agent.log_alpha.exp().detach()
            pi_losses = alpha * log_pi - min_log_target - prior_log_pi
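        # NOTE: pi_losses (and alpha) are only assigned in the branch above, so
        # this loss assumes reparameterize=True; the non-reparameterized policy
        # loss is not implemented in this class.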

        if self.policy_output_regularization > 0:
            pi_losses += self.policy_output_regularization * torch.sum(
                0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1)

        pi_loss = valid_mean(pi_losses, valid)

        # Calculate log_alpha loss
        alpha_loss = -valid_mean(
            self.agent.log_alpha * (log_pi + self.target_entropy).detach())

        losses = (pi_loss, alpha_loss)
        values = tuple(val.detach() for val in (pi_mean, pi_log_std, alpha))
        q_values = tuple(q.detach() for q in qs)
        return q_losses, losses, values, q_values


    def get_action_prior(self, action):
        if self.action_prior == "uniform":
            prior_log_pi = 0.0
        elif self.action_prior == "gaussian":
            prior_log_pi = self.action_prior_distribution.log_likelihood(
                action, GaussianDistInfo(mean=torch.zeros_like(action)))
        return prior_log_pi

    def append_opt_info_(self, opt_info, q_losses, losses, q_grad_norms,
                         pi_grad_norm, q_values, values):
        """In-place."""
        pi_loss, alpha_loss = losses
        pi_mean, pi_log_std, alpha = values

        for i in range(self.agent.n_qs):
            getattr(opt_info, f'q{i}Loss').append(q_losses[i].item())
            getattr(opt_info, f'q{i}').extend(q_values[i][::10].numpy())
            getattr(opt_info, f'q{i}GradNorm').append(q_grad_norms[i])

        opt_info.piLoss.append(pi_loss.item())
        opt_info.alphaLoss.append(alpha_loss.item())
        opt_info.piGradNorm.append(pi_grad_norm)
        opt_info.piMu.extend(pi_mean[::10].numpy())
        opt_info.piLogStd.extend(pi_log_std[::10].numpy())
        opt_info.alpha.append(alpha.numpy())

    def optim_state_dict(self):
        rtn = dict(
            pi_optimizer=self.pi_optimizer.state_dict(),
            alpha_optimizer=self.alpha_optimizer.state_dict(),
        )
        rtn.update({f'q{i}_optimizer': q_opt.state_dict()
                    for i, q_opt in enumerate(self.q_optimizers)})
        return rtn

    def load_optim_state_dict(self, state_dict):
        self.pi_optimizer.load_state_dict(state_dict["pi_optimizer"])
        self.alpha_optimizer.load_state_dict(state_dict["alpha_optimizer"])
        [q_opt.load_state_dict(state_dict[f'q{i}_optimizer'])
         for i, q_opt in enumerate(self.q_optimizers)]
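
The alpha_loss in loss() above implements the adaptive entropy temperature: log_alpha is trained so that the policy's entropy tracks target_entropy. A minimal standalone sketch of one such update, with assumed numbers rather than rlpyt objects:

import torch

target_entropy = -6.0                 # e.g. -prod(action shape) for a 6-D continuous action space
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

log_pi = torch.full((256,), -4.0)     # batch of policy log-probs, treated as constants
alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()

alpha_optimizer.zero_grad()
alpha_loss.backward()
alpha_optimizer.step()
alpha = log_alpha.exp().detach()      # entropy (~4 nats) exceeds the target here, so alpha shrinks
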
Example #4
File: sac_lstm.py Project: zren96/rlpyt
class SAC_LSTM(RlAlgorithm):
    """Soft actor critic algorithm, training from a replay buffer."""

    opt_info_fields = tuple(f for f in OptInfo._fields)  # copy

    def __init__(
            self,
            discount=0.99,
            batch_T=80,
            batch_B=16,
            warmup_T=40,
            min_steps_learn=int(1e5),
            replay_size=int(1e6),
            replay_ratio=4,  # data_consumption / data_generation
            store_rnn_state_interval=40,
            target_update_tau=0.005,  # tau=1 for hard update.
            target_update_interval=1,  # 1000 for hard update, 1 for soft.
            learning_rate=3e-4,
            fixed_alpha=None,  # None for adaptive alpha, float for any fixed value
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            initial_optim_state_dict=None,  # for all of them.
            initial_replay_buffer_dict=None,
            action_prior="uniform",  # or "gaussian"
            reward_scale=1,
            target_entropy="auto",  # "auto", float, or None
            reparameterize=True,
            clip_grad_norm=1e3,
            n_step_return=5,
            ReplayBufferCls=None,  # Leave None to select by above options.
    ):
        """ Save input arguments.
        Args:
            store_rnn_state_interval (int): store RNN state only once this many steps, to reduce memory usage; replay sequences will only begin at the steps with stored recurrent state.
        Note:
            Typically ran with ``store_rnn_state_interval`` equal to the sampler's ``batch_T``, 40.  Then every 40 steps
            can be the beginning of a replay sequence, and will be guaranteed to start with a valid RNN state.  Only reset
            the RNN state (and env) at the end of the sampler batch, so that the beginnings of episodes are trained on.
        """
        if optim_kwargs is None:
            optim_kwargs = dict()
        assert action_prior in ["uniform", "gaussian"]
        save__init__args(locals())
        self._batch_size = (self.batch_T + self.warmup_T) * self.batch_B

    def initialize(self,
                   agent,
                   n_itr,
                   batch_spec,
                   mid_batch_reset,
                   examples,
                   world_size=1,
                   rank=0):
        """Stores input arguments and initializes replay buffer and optimizer.
        Use in non-async runners.  Computes number of gradient updates per
        optimization iteration as `(replay_ratio * sampler-batch-size /
        training-batch-size)`."""
        self.agent = agent
        self.n_itr = n_itr  # num_itr
        self.sampler_bs = sampler_bs = batch_spec.size  # num_step_per_batch
        self.mid_batch_reset = mid_batch_reset  # True
        self.updates_per_optimize = max(
            1, round(self.replay_ratio * sampler_bs / self._batch_size))
        logger.log(
            f"From sampler batch size {batch_spec.size}, training "
            f"batch size {self._batch_size}, and replay ratio "
            f"{self.replay_ratio}, computed {self.updates_per_optimize} "
            f"updates per iteration.")
        self.min_itr_learn = int(self.min_steps_learn // sampler_bs)
        agent.give_min_itr_learn(
            self.min_itr_learn)  # filling up replay_buffer
        self.initialize_replay_buffer(examples, batch_spec)
        self.optim_initialize(rank)

    def optim_initialize(self, rank=0):
        """Called in initilize or by async runner after forking sampler."""
        self.rank = rank
        self.pi_optimizer = self.OptimCls(self.agent.pi_parameters(),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
        self.q1_optimizer = self.OptimCls(self.agent.q1_parameters(),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
        self.q2_optimizer = self.OptimCls(self.agent.q2_parameters(),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
        if self.fixed_alpha is None:
            self._log_alpha = torch.zeros(1, requires_grad=True)
            self._alpha = torch.exp(self._log_alpha.detach())
            self.alpha_optimizer = self.OptimCls((self._log_alpha, ),
                                                 lr=self.learning_rate,
                                                 **self.optim_kwargs)
        else:
            self._log_alpha = torch.tensor([np.log(self.fixed_alpha)])
            self._alpha = torch.tensor([self.fixed_alpha])
            self.alpha_optimizer = None
        if self.target_entropy == "auto":
            self.target_entropy = -np.prod(self.agent.env_spaces.action.shape)
        if self.initial_optim_state_dict is not None:
            self.load_optim_state_dict(self.initial_optim_state_dict)
        if self.action_prior == "gaussian":
            self.action_prior_distribution = Gaussian(dim=np.prod(
                self.agent.env_spaces.action.shape),
                                                      std=1.)

    def initialize_replay_buffer(self, examples, batch_spec):
        """
        Allocates replay buffer using examples and with the fields in `SamplesToBuffer` namedarraytuple.
        """
        # hidden_in, hidden_out, state, action, last_action, reward, next_state, done = self.replay_buffer.sample(batch_size)
        # print(examples["agent_info"])
        example_to_buffer = SamplesToBufferLSTM(
            observation=examples["observation"],
            action=examples["action"],
            reward=examples["reward"],
            done=examples["done"],
            prev_rnn_state=examples["agent_info"].prev_rnn_state,
        )
        ReplayCls = UniformSequenceReplayFrameBuffer
        replay_kwargs = dict(
            example=example_to_buffer,
            size=self.replay_size,
            B=batch_spec.B,
            discount=self.discount,
            n_step_return=self.n_step_return,
            rnn_state_interval=self.store_rnn_state_interval,
            initial_replay_buffer_dict=self.initial_replay_buffer_dict,
            batch_T=self.batch_T + self.warmup_T,
        )
        self.replay_buffer = ReplayCls(**replay_kwargs)

    def optimize_agent(self, itr, samples=None):
        """
        Extracts the needed fields from input samples and stores them in the 
        replay buffer.  Then samples from the replay buffer to train the agent
        by gradient updates (with the number of updates determined by replay
        ratio, sampler batch size, and training batch size).
        """

        # Update replay buffer
        if samples is not None:
            samples_to_buffer = self.samples_to_buffer(samples)
            self.replay_buffer.append_samples(samples_to_buffer)

        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        if itr < self.min_itr_learn:
            return opt_info

        for _ in range(self.updates_per_optimize):

            samples_from_replay = self.replay_buffer.sample_batch(self.batch_B)
            losses, values = self.loss(samples_from_replay)
            q1_loss, q2_loss, pi_loss, alpha_loss = losses

            if alpha_loss is not None:
                self.alpha_optimizer.zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer.step()
                self._alpha = torch.exp(self._log_alpha.detach())

            self.pi_optimizer.zero_grad()
            pi_loss.backward()
            pi_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.pi_parameters(), self.clip_grad_norm)
            self.pi_optimizer.step()

            # Step Q's last because pi_loss.backward() uses them?
            self.q1_optimizer.zero_grad()
            q1_loss.backward()
            q1_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.q1_parameters(), self.clip_grad_norm)
            self.q1_optimizer.step()

            self.q2_optimizer.zero_grad()
            q2_loss.backward()
            q2_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.q2_parameters(), self.clip_grad_norm)
            self.q2_optimizer.step()

            grad_norms = (q1_grad_norm, q2_grad_norm, pi_grad_norm)

            self.append_opt_info_(opt_info, losses, grad_norms, values)
            self.update_counter += 1
            if self.update_counter % self.target_update_interval == 0:
                self.agent.update_target(self.target_update_tau)

        return opt_info

    def samples_to_buffer(self, samples):
        """Defines how to add data from sampler into the replay buffer. Called
        in optimize_agent() if samples are provided to that method."""
        samples_to_buffer = SamplesToBufferLSTM(
            observation=samples.env.observation,
            action=samples.agent.action,
            reward=samples.env.reward,
            done=samples.env.done,
            prev_rnn_state=samples.agent.agent_info.prev_rnn_state,
        )
        return samples_to_buffer

    def loss(self, samples):
        """
        Computes losses for twin Q-values against the min of twin target
        Q-values and an entropy term.  Computes reparameterized policy loss,
        and loss for tuning entropy weighting, alpha.

        Input samples have leading batch dimension [B,..] (but not time).
        """
        # SamplesFromReplay = namedarraytuple("SamplesFromReplay",
        # ["all_observation", "all_action", "all_reward", "return_", "done", "done_n", "init_rnn_state"])
        all_observation, all_action, all_reward = buffer_to(
            (samples.all_observation, samples.all_action, samples.all_reward),
            device=self.agent.device)  # all have (wT + bT + nsr) x bB
        wT, bT, nsr = self.warmup_T, self.batch_T, self.n_step_return
        if wT > 0:
            warmup_slice = slice(None, wT)  # Same for agent and target.
            warmup_inputs = AgentInputs(
                observation=all_observation[warmup_slice],
                prev_action=all_action[warmup_slice],
                prev_reward=all_reward[warmup_slice],
            )
        agent_slice = slice(wT, wT + bT)
        agent_inputs = AgentInputs(
            observation=all_observation[agent_slice],
            prev_action=all_action[agent_slice],
            prev_reward=all_reward[agent_slice],
        )
        target_slice = slice(wT,
                             None)  # Same start t as agent. (wT + bT + nsr)
        target_inputs = AgentInputs(
            observation=all_observation[target_slice],
            prev_action=all_action[target_slice],
            prev_reward=all_reward[target_slice],
        )
        warmup_action = samples.all_action[1:wT + 1]
        action = samples.all_action[
            wT + 1:wT + 1 +
            bT]  # 'current' action by shifting index by 1 from prev_action
        return_ = samples.return_[wT:wT + bT]
        done_n = samples.done_n[wT:wT + bT]
        if self.store_rnn_state_interval == 0:
            init_rnn_state = None
        else:
            # [B,N,H]-->[N,B,H] cudnn.
            init_rnn_state = buffer_method(samples.init_rnn_state, "transpose",
                                           0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        if wT > 0:  # Do warmup.
            with torch.no_grad():
                _, target_q1_rnn_state, _, target_q2_rnn_state = self.agent.target_q(
                    *warmup_inputs, warmup_action, init_rnn_state,
                    init_rnn_state)
                _, _, _, init_rnn_state = self.agent.pi(
                    *warmup_inputs, init_rnn_state)
            # Recommend aligning sampling batch_T and store_rnn_state_interval
            # with warmup_T (and no mid_batch_reset), so that end of trajectory
            # during warmup leads to new trajectory beginning at start of
            # training segment of replay.
            warmup_invalid_mask = valid_from_done(
                samples.done[:wT])[-1] == 0  # [B]
            init_rnn_state[:, warmup_invalid_mask] = 0  # [N,B,H] (cudnn)
            target_q1_rnn_state[:, warmup_invalid_mask] = 0
            target_q2_rnn_state[:, warmup_invalid_mask] = 0
        else:
            target_q1_rnn_state = init_rnn_state
            target_q2_rnn_state = init_rnn_state

        valid = valid_from_done(samples.done)[-bT:]

        q1, _, q2, _ = self.agent.q(*agent_inputs, action, init_rnn_state,
                                    init_rnn_state)
        with torch.no_grad():
            target_action, target_log_pi, _, _ = self.agent.pi(
                *target_inputs, init_rnn_state)
            target_q1, _, target_q2, _ = self.agent.target_q(
                *target_inputs, target_action, target_q1_rnn_state,
                target_q2_rnn_state)
            target_q1 = target_q1[-bT:]  # Same length as q.
            target_q2 = target_q2[-bT:]
            target_log_pi = target_log_pi[-bT:]

        min_target_q = torch.min(target_q1, target_q2)
        target_value = min_target_q - self._alpha * target_log_pi
        disc = self.discount**self.n_step_return
        y = (self.reward_scale * return_ +
             (1 - done_n.float()) * disc * target_value)
        q1_loss = 0.5 * valid_mean((y - q1)**2, valid)
        q2_loss = 0.5 * valid_mean((y - q2)**2, valid)

        new_action, log_pi, (pi_mean, pi_log_std), _ = self.agent.pi(
            *agent_inputs, init_rnn_state)
        log_target1, _, log_target2, _ = self.agent.q(*agent_inputs,
                                                      new_action,
                                                      init_rnn_state,
                                                      init_rnn_state)
        min_log_target = torch.min(log_target1, log_target2)
        prior_log_pi = self.get_action_prior(new_action.cpu())

        pi_losses = self._alpha * log_pi - min_log_target - prior_log_pi
        pi_loss = valid_mean(pi_losses, valid)

        if self.target_entropy is not None and self.fixed_alpha is None:
            alpha_losses = -self._log_alpha * (log_pi.detach() +
                                               self.target_entropy)
            alpha_loss = valid_mean(alpha_losses, valid)
        else:
            alpha_loss = None

        losses = (q1_loss, q2_loss, pi_loss, alpha_loss)
        values = tuple(val.detach() for val in (q1, q2, pi_mean, pi_log_std))
        return losses, values

    def get_action_prior(self, action):
        if self.action_prior == "uniform":
            prior_log_pi = 0.0
        elif self.action_prior == "gaussian":
            prior_log_pi = self.action_prior_distribution.log_likelihood(
                action, GaussianDistInfo(mean=torch.zeros_like(action)))
        return prior_log_pi

    def append_opt_info_(self, opt_info, losses, grad_norms, values):
        """In-place."""
        q1_loss, q2_loss, pi_loss, alpha_loss = losses
        q1_grad_norm, q2_grad_norm, pi_grad_norm = grad_norms
        q1, q2, pi_mean, pi_log_std = values
        opt_info.q1Loss.append(q1_loss.item())
        opt_info.q2Loss.append(q2_loss.item())
        opt_info.piLoss.append(pi_loss.item())
        opt_info.q1GradNorm.append(
            q1_grad_norm.clone().detach().item())  # backwards compatible
        opt_info.q2GradNorm.append(
            q2_grad_norm.clone().detach().item())  # backwards compatible
        opt_info.piGradNorm.append(
            pi_grad_norm.clone().detach().item())  # backwards compatible
        opt_info.q1.extend(q1[::10].numpy())  # Downsample for stats.
        opt_info.q2.extend(q2[::10].numpy())
        opt_info.piMu.extend(pi_mean[::10].numpy())
        opt_info.piLogStd.extend(pi_log_std[::10].numpy())
        opt_info.qMeanDiff.append(torch.mean(abs(q1 - q2)).item())
        opt_info.alpha.append(self._alpha.item())

    def optim_state_dict(self):
        return dict(
            pi_optimizer=self.pi_optimizer.state_dict(),
            q1_optimizer=self.q1_optimizer.state_dict(),
            q2_optimizer=self.q2_optimizer.state_dict(),
            alpha_optimizer=self.alpha_optimizer.state_dict()
            if self.alpha_optimizer else None,
            log_alpha=self._log_alpha.detach().item(),
        )

    def load_optim_state_dict(self, state_dict):
        self.pi_optimizer.load_state_dict(state_dict["pi_optimizer"])
        self.q1_optimizer.load_state_dict(state_dict["q1_optimizer"])
        self.q2_optimizer.load_state_dict(state_dict["q2_optimizer"])
        if self.alpha_optimizer is not None and state_dict[
                "alpha_optimizer"] is not None:
            self.alpha_optimizer.load_state_dict(state_dict["alpha_optimizer"])
        with torch.no_grad():
            self._log_alpha[:] = state_dict["log_alpha"]
            self._alpha = torch.exp(self._log_alpha.detach())

    def replay_buffer_dict(self):
        return dict(buffer=self.replay_buffer.samples)
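
For reference, the critic target assembled inside the recurrent loss above reduces to the following pure-tensor computation (a minimal standalone sketch; the function name and default argument values are illustrative, not part of rlpyt):

import torch

def soft_bellman_target(return_, done_n, target_q1, target_q2, target_log_pi,
                        alpha, reward_scale=1.0, discount=0.99, n_step_return=1):
    # Soft value of the bootstrap state: twin-Q minimum minus the entropy penalty.
    target_value = torch.min(target_q1, target_q2) - alpha * target_log_pi
    # Bootstrap only where the n-step segment did not terminate.
    disc = discount ** n_step_return
    return reward_scale * return_ + (1 - done_n.float()) * disc * target_value
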
Example #5
0
class SAC(RlAlgorithm):

    opt_info_fields = tuple(f for f in OptInfo._fields)  # copy

    def __init__(
        self,
        discount=0.99,
        batch_size=256,
        min_steps_learn=int(1e4),
        replay_size=int(1e6),
        replay_ratio=256,  # data_consumption / data_generation
        target_update_tau=0.005,  # tau=1 for hard update.
        target_update_interval=1,  # interval=1000 for hard update.
        learning_rate=3e-4,
        OptimCls=torch.optim.Adam,
        optim_kwargs=None,
        initial_optim_state_dict=None,  # for pi only.
        action_prior="uniform",  # or "gaussian"
        policy_output_regularization=0.001,
        reward_scale=1,
        reparameterize=True,
        clip_grad_norm=1e9,
        n_step_return=1,
        updates_per_sync=1,  # For async mode only.
        target_entropy='auto',
    ):
        if optim_kwargs is None:
            optim_kwargs = dict()
        assert action_prior in ["uniform", "gaussian"]
        self._batch_size = batch_size
        del batch_size  # Property.
        save__init__args(locals())

    def initialize(self,
                   agent,
                   n_itr,
                   batch_spec,
                   mid_batch_reset,
                   examples,
                   world_size=1,
                   rank=0):
        """Used in basic or synchronous multi-GPU runners, not async."""
        self.agent = agent
        self.n_itr = n_itr
        self.mid_batch_reset = mid_batch_reset
        self.sampler_bs = sampler_bs = batch_spec.size
        self.updates_per_optimize = int(self.replay_ratio * sampler_bs /
                                        self.batch_size)
        logger.log(
            f"From sampler batch size {sampler_bs}, training "
            f"batch size {self.batch_size}, and replay ratio "
            f"{self.replay_ratio}, computed {self.updates_per_optimize} "
            f"updates per iteration.")
        self.min_itr_learn = self.min_steps_learn // sampler_bs
        agent.give_min_itr_learn(self.min_itr_learn)
        self.initialize_replay_buffer(examples, batch_spec)
        self.optim_initialize(rank)

        if self.target_entropy == 'auto':
            self.target_entropy = -np.prod(self.agent.env_spaces.action.shape)

    def async_initialize(self,
                         agent,
                         sampler_n_itr,
                         batch_spec,
                         mid_batch_reset,
                         examples,
                         world_size=1):
        """Used in async runner only."""
        self.agent = agent
        self.n_itr = sampler_n_itr
        self.initialize_replay_buffer(examples, batch_spec, async_=True)
        self.mid_batch_reset = mid_batch_reset
        self.sampler_bs = sampler_bs = batch_spec.size
        self.updates_per_optimize = self.updates_per_sync
        self.min_itr_learn = int(self.min_steps_learn // sampler_bs)
        agent.give_min_itr_learn(self.min_itr_learn)
        return self.replay_buffer

    def optim_initialize(self, rank=0):
        """Called by async runner."""
        self.rank = rank
        self.pi_optimizer = self.OptimCls(self.agent.pi_parameters(),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
        self.q1_optimizer = self.OptimCls(self.agent.q1_parameters(),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
        self.q2_optimizer = self.OptimCls(self.agent.q2_parameters(),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
        self.alpha_optimizer = self.OptimCls([self.agent.log_alpha],
                                             lr=self.learning_rate,
                                             **self.optim_kwargs)
        if self.initial_optim_state_dict is not None:
            self.pi_optimizer.load_state_dict(self.initial_optim_state_dict)
        if self.action_prior == "gaussian":
            self.action_prior_distribution = Gaussian(
                dim=self.agent.env_spaces.action.size, std=1.)

    def initialize_replay_buffer(self, examples, batch_spec, async_=False):
        example_to_buffer = SamplesToBuffer(
            observation=examples["observation"],
            action=examples["action"],
            reward=examples["reward"],
            done=examples["done"],
        )
        replay_kwargs = dict(
            example=example_to_buffer,
            size=self.replay_size,
            B=batch_spec.B,
            n_step_return=self.n_step_return,
        )
        ReplayCls = AsyncUniformReplayBuffer if async_ else UniformReplayBuffer
        self.replay_buffer = ReplayCls(**replay_kwargs)

    def optimize_agent(self, itr, samples=None, sampler_itr=None):
        itr = itr if sampler_itr is None else sampler_itr  # Async uses sampler_itr.
        if samples is not None:
            samples_to_buffer = self.samples_to_buffer(samples)
            self.replay_buffer.append_samples(samples_to_buffer)
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        if itr < self.min_itr_learn:
            return opt_info
        for _ in range(self.updates_per_optimize):
            samples_from_replay = self.replay_buffer.sample_batch(
                self.batch_size)
            losses, values = self.loss(samples_from_replay)
            q1_loss, q2_loss, pi_loss, alpha_loss = losses

            self.pi_optimizer.zero_grad()
            pi_loss.backward()
            pi_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.pi_parameters(), self.clip_grad_norm)
            self.pi_optimizer.step()

            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()

            self.q1_optimizer.zero_grad()
            q1_loss.backward()
            q1_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.q1_parameters(), self.clip_grad_norm)
            self.q1_optimizer.step()

            self.q2_optimizer.zero_grad()
            q2_loss.backward()
            q2_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.q2_parameters(), self.clip_grad_norm)
            self.q2_optimizer.step()

            grad_norms = (q1_grad_norm, q2_grad_norm, pi_grad_norm)

            self.append_opt_info_(opt_info, losses, grad_norms, values)
            self.update_counter += 1
            if self.update_counter % self.target_update_interval == 0:
                self.agent.update_target(self.target_update_tau)
        return opt_info

    def samples_to_buffer(self, samples):
        return SamplesToBuffer(
            observation=samples.env.observation,
            action=samples.agent.action,
            reward=samples.env.reward,
            done=samples.env.done,
        )

    def loss(self, samples):
        """Samples have leading batch dimension [B,..] (but not time)."""
        agent_inputs, target_inputs, action = buffer_to(
            (samples.agent_inputs, samples.target_inputs, samples.action))
        q1, q2 = self.agent.q(*agent_inputs, action)
        with torch.no_grad():
            target_v = self.agent.target_v(*target_inputs).detach()
        disc = self.discount**self.n_step_return
        y = (self.reward_scale * samples.return_ +
             (1 - samples.done_n.float()) * disc * target_v)
        if self.mid_batch_reset and not self.agent.recurrent:
            valid = None  # OR: torch.ones_like(samples.done, dtype=torch.float)
        else:
            valid = valid_from_done(samples.done)

        q1_loss = 0.5 * valid_mean((y - q1)**2, valid)
        q2_loss = 0.5 * valid_mean((y - q2)**2, valid)

        new_action, log_pi, _ = self.agent.pi(*agent_inputs)
        if not self.reparameterize:
            new_action = new_action.detach()  # No grad.
        log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
        min_log_target = torch.min(log_target1, log_target2)
        prior_log_pi = self.get_action_prior(new_action.cpu())

        if self.reparameterize:
            alpha = self.agent.log_alpha.exp().detach()
            pi_losses = alpha * log_pi - min_log_target - prior_log_pi
        else:
            raise NotImplementedError  # Non-reparameterized gradient not implemented.

        pi_loss = valid_mean(pi_losses, valid)

        # Calculate log_alpha loss
        alpha_loss = -valid_mean(self.agent.log_alpha *
                                 (log_pi + self.target_entropy).detach())

        losses = (q1_loss, q2_loss, pi_loss, alpha_loss)
        values = tuple(val.detach() for val in (q1, q2, alpha))
        return losses, values

    def get_action_prior(self, action):
        if self.action_prior == "uniform":
            prior_log_pi = 0.0
        elif self.action_prior == "gaussian":
            prior_log_pi = self.action_prior_distribution.log_likelihood(
                action, GaussianDistInfo(mean=torch.zeros_like(action)))
        return prior_log_pi

    def append_opt_info_(self, opt_info, losses, grad_norms, values):
        """In-place."""
        q1_loss, q2_loss, pi_loss, alpha_loss = losses
        q1_grad_norm, q2_grad_norm, pi_grad_norm = grad_norms
        q1, q2, alpha = values
        opt_info.q1Loss.append(q1_loss.item())
        opt_info.q2Loss.append(q2_loss.item())
        opt_info.piLoss.append(pi_loss.item())
        opt_info.alphaLoss.append(alpha_loss.item())
        opt_info.q1GradNorm.append(q1_grad_norm)
        opt_info.q2GradNorm.append(q2_grad_norm)
        opt_info.piGradNorm.append(pi_grad_norm)
        opt_info.q1.extend(q1[::10].numpy())  # Downsample for stats.
        opt_info.q2.extend(q2[::10].numpy())
        opt_info.alpha.append(alpha.numpy())
        opt_info.qMeanDiff.append(torch.mean(abs(q1 - q2)).item())

    def optim_state_dict(self):
        return dict(
            pi_optimizer=self.pi_optimizer.state_dict(),
            q1_optimizer=self.q1_optimizer.state_dict(),
            q2_optimizer=self.q2_optimizer.state_dict(),
            alpha_optimizer=self.alpha_optimizer.state_dict(),
        )

    def load_optim_state_dict(self, state_dict):
        self.pi_optimizer.load_state_dict(state_dict["pi_optimizer"])
        self.q1_optimizer.load_state_dict(state_dict["q1_optimizer"])
        self.q2_optimizer.load_state_dict(state_dict["q2_optimizer"])
        self.alpha_optimizer.load_state_dict(state_dict["alpha_optimizer"])
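
The automatic entropy tuning in this example (the alpha_loss above) amounts to a gradient step on log_alpha; a minimal standalone sketch, where the optimizer choice, learning rate, and the example batch of log-probabilities are placeholder assumptions:

import torch

log_alpha = torch.zeros(1, requires_grad=True)
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

def update_alpha(log_pi, target_entropy):
    # Push E[log_pi] toward -target_entropy by adjusting log_alpha.
    alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()
    return log_alpha.exp().detach()  # alpha as used in the Q and pi losses

# e.g. update_alpha(log_pi=torch.randn(256), target_entropy=-6.0)
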
Example #6
0
class SAC(RlAlgorithm):

    opt_info_fields = tuple(f for f in OptInfo._fields)  # copy

    def __init__(
            self,
            discount=0.99,
            batch_size=256,
            min_steps_learn=int(1e4),  # Min timesteps to collect before learning starts.
            replay_size=int(1e6),
            replay_ratio=256,  # data_consumption (timesteps used across optim.step() calls) / data_generation (sampler batch size)
            target_update_tau=0.005,  # tau=1 for hard update.
            target_update_interval=1,  # 1000 for hard update, 1 for soft.
            learning_rate=3e-4,
            fixed_alpha=None,  # None for adaptive alpha, float for any fixed value
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            initial_optim_state_dict=None,  # for all of them.
            action_prior="uniform",  # or "gaussian"
            reward_scale=1,
            target_entropy="auto",  # "auto", float, or None
            reparameterize=True,
            clip_grad_norm=1e9,
            # policy_output_regularization=0.001,
            n_step_return=1,
            updates_per_sync=1,  # For async mode only.
            bootstrap_timelimit=True,
            ReplayBufferCls=None,  # Leave None to select by above options.
    ):
        if optim_kwargs is None:
            optim_kwargs = dict()
        assert action_prior in ["uniform", "gaussian"]
        self._batch_size = batch_size
        del batch_size  # Property.
        save__init__args(locals())

    def initialize(self,
                   agent,
                   n_itr,
                   batch_spec,
                   mid_batch_reset,
                   examples,
                   world_size=1,
                   rank=0):
        """Used in basic or synchronous multi-GPU runners, not async.
        Parameters
        ----------
            agent: SacAgent
        """
        self.agent = agent
        self.n_itr = n_itr
        self.mid_batch_reset = mid_batch_reset
        self.sampler_bs = sampler_bs = batch_spec.size
        self.updates_per_optimize = int(self.replay_ratio * sampler_bs /
                                        self.batch_size)
        logger.log(
            f"From sampler batch size {sampler_bs}, training "
            f"batch size {self.batch_size}, and replay ratio "
            f"{self.replay_ratio}, computed {self.updates_per_optimize} "
            f"updates per iteration.")
        self.min_itr_learn = self.min_steps_learn // sampler_bs
        agent.give_min_itr_learn(self.min_itr_learn)
        self.initialize_replay_buffer(examples, batch_spec)
        self.optim_initialize(rank)

    def async_initialize(self,
                         agent,
                         sampler_n_itr,
                         batch_spec,
                         mid_batch_reset,
                         examples,
                         world_size=1):
        """Used in async runner only."""
        self.agent = agent
        self.n_itr = sampler_n_itr
        self.initialize_replay_buffer(examples, batch_spec, async_=True)
        self.mid_batch_reset = mid_batch_reset
        self.sampler_bs = sampler_bs = batch_spec.size
        self.updates_per_optimize = self.updates_per_sync
        self.min_itr_learn = int(self.min_steps_learn // sampler_bs)
        agent.give_min_itr_learn(self.min_itr_learn)
        return self.replay_buffer

    def optim_initialize(self, rank=0):
        """Called by async runner."""
        self.rank = rank
        self.pi_optimizer = self.OptimCls(self.agent.pi_parameters(),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
        self.q1_optimizer = self.OptimCls(self.agent.q1_parameters(),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
        self.q2_optimizer = self.OptimCls(self.agent.q2_parameters(),
                                          lr=self.learning_rate,
                                          **self.optim_kwargs)
        if self.fixed_alpha is None:
            self._log_alpha = torch.zeros(1, requires_grad=True)
            self._alpha = torch.exp(self._log_alpha.detach())
            self.alpha_optimizer = self.OptimCls((self._log_alpha, ),
                                                 lr=self.learning_rate,
                                                 **self.optim_kwargs)
        else:
            self._log_alpha = torch.tensor([np.log(self.fixed_alpha)])
            self._alpha = torch.tensor([self.fixed_alpha])
            self.alpha_optimizer = None
        if self.target_entropy == "auto":
            self.target_entropy = -np.prod(self.agent.env_spaces.action.shape)
        if self.initial_optim_state_dict is not None:
            self.load_optim_state_dict(self.initial_optim_state_dict)
        if self.action_prior == "gaussian":
            self.action_prior_distribution = Gaussian(dim=np.prod(
                self.agent.env_spaces.action.shape),
                                                      std=1.)

    def initialize_replay_buffer(self, examples, batch_spec, async_=False):
        example_to_buffer = SamplesToBuffer(
            observation=examples["observation"],
            action=examples["action"],
            reward=examples["reward"],
            done=examples["done"],
            next_observation=examples["next_observation"],
        )
        if not self.bootstrap_timelimit:
            ReplayCls = AsyncUniformReplayBuffer if async_ else UniformReplayBuffer
        else:
            example_to_buffer = SamplesToBufferTl(
                *example_to_buffer, timeout=examples["env_info"].timeout)
            ReplayCls = AsyncTlUniformReplayBuffer if async_ else TlUniformReplayBuffer
        replay_kwargs = dict(
            example=example_to_buffer,
            size=self.replay_size,
            B=batch_spec.B,
            n_step_return=self.n_step_return,
        )
        if self.ReplayBufferCls is not None:
            ReplayCls = self.ReplayBufferCls
            logger.log(
                f"WARNING: ignoring internal selection logic and using"
                f" input replay buffer class: {ReplayCls} -- compatibility not"
                " guaranteed.")
        self.replay_buffer = ReplayCls(**replay_kwargs)

    def optimize_agent(self, itr, samples=None, sampler_itr=None):
        itr = itr if sampler_itr is None else sampler_itr  # Async uses sampler_itr.
        if samples is not None:
            samples_to_buffer = self.samples_to_buffer(samples)
            self.replay_buffer.append_samples(samples_to_buffer)
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        if itr < self.min_itr_learn:
            return opt_info
        for _ in range(self.updates_per_optimize):
            samples_from_replay = self.replay_buffer.sample_batch(
                self.batch_size)
            losses, values = self.loss(samples_from_replay)
            q1_loss, q2_loss, pi_loss, alpha_loss = losses

            if alpha_loss is not None:
                self.alpha_optimizer.zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer.step()
                self._alpha = torch.exp(self._log_alpha.detach())

            self.pi_optimizer.zero_grad()
            pi_loss.backward()
            pi_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.pi_parameters(), self.clip_grad_norm)
            self.pi_optimizer.step()

            # Step the Q optimizers last because pi_loss.backward() uses them.
            self.q1_optimizer.zero_grad()
            q1_loss.backward()
            q1_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.q1_parameters(), self.clip_grad_norm)
            self.q1_optimizer.step()

            self.q2_optimizer.zero_grad()
            q2_loss.backward()
            q2_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.q2_parameters(), self.clip_grad_norm)
            self.q2_optimizer.step()

            grad_norms = (q1_grad_norm, q2_grad_norm, pi_grad_norm)

            self.append_opt_info_(opt_info, losses, grad_norms, values)
            self.update_counter += 1
            if self.update_counter % self.target_update_interval == 0:
                self.agent.update_target(self.target_update_tau)

        return opt_info

    def samples_to_buffer(self, samples):
        samples_to_buffer = SamplesToBuffer(
            observation=samples.env.observation,
            action=samples.agent.action,
            reward=samples.env.reward,
            done=samples.env.done,
            next_observation=samples.env.next_observation,
        )
        if self.bootstrap_timelimit:
            samples_to_buffer = SamplesToBufferTl(
                *samples_to_buffer, timeout=samples.env.env_info.timeout)
        return samples_to_buffer

    def loss(self, samples):
        """Samples have leading batch dimension [B,..] (but not time)."""
        agent_inputs, target_inputs, action = buffer_to(
            (samples.agent_inputs, samples.target_inputs, samples.action))

        if self.mid_batch_reset and not self.agent.recurrent:
            valid = torch.ones_like(samples.done, dtype=torch.float)  # or None
        else:
            valid = valid_from_done(samples.done)
        if self.bootstrap_timelimit:
            # To avoid non-use of bootstrap when environment is 'done' due to
            # time-limit, turn off training on these samples.
            valid *= (1 - samples.timeout_n.float())

        q1, q2 = self.agent.q(*agent_inputs, action)
        with torch.no_grad():
            target_action, target_log_pi, _ = self.agent.pi(*target_inputs)
            target_q1, target_q2 = self.agent.target_q(*target_inputs,
                                                       target_action)
        min_target_q = torch.min(target_q1, target_q2)
        target_value = min_target_q - self._alpha * target_log_pi
        disc = self.discount**self.n_step_return
        y = (self.reward_scale * samples.return_ +
             (1 - samples.done_n.float()) * disc * target_value)
        # y: the regression target for both Q functions, built from target_value.

        q1_loss = 0.5 * valid_mean((y - q1)**2, valid)
        q2_loss = 0.5 * valid_mean((y - q2)**2, valid)

        new_action, log_pi, (pi_mean,
                             pi_log_std) = self.agent.pi(*agent_inputs)
        if not self.reparameterize:
            new_action = new_action.detach()  # No grad.
        log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
        min_log_target = torch.min(log_target1, log_target2)
        prior_log_pi = self.get_action_prior(new_action.cpu())

        if self.reparameterize:
            pi_losses = self._alpha * log_pi - min_log_target - prior_log_pi
        else:
            raise NotImplementedError

        # if self.policy_output_regularization > 0:
        #     pi_losses += self.policy_output_regularization * torch.mean(
        #         0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1)
        pi_loss = valid_mean(pi_losses, valid)

        if self.target_entropy is not None and self.fixed_alpha is None:
            alpha_losses = -self._log_alpha * (log_pi.detach() +
                                               self.target_entropy)
            alpha_loss = valid_mean(alpha_losses, valid)
        else:
            alpha_loss = None

        losses = (q1_loss, q2_loss, pi_loss, alpha_loss)
        values = tuple(val.detach() for val in (q1, q2, pi_mean, pi_log_std))
        return losses, values

    def get_action_prior(self, action):
        if self.action_prior == "uniform":
            prior_log_pi = 0.0
        elif self.action_prior == "gaussian":
            prior_log_pi = self.action_prior_distribution.log_likelihood(
                action, GaussianDistInfo(mean=torch.zeros_like(action)))
        return prior_log_pi

    def append_opt_info_(self, opt_info, losses, grad_norms, values):
        """ append all the `losses` and `grad_norms` and `values` into each attribute 
            of `opt_info`
        """
        q1_loss, q2_loss, pi_loss, alpha_loss = losses
        q1_grad_norm, q2_grad_norm, pi_grad_norm = grad_norms
        q1, q2, pi_mean, pi_log_std = values
        opt_info.q1Loss.append(q1_loss.item())
        opt_info.q2Loss.append(q2_loss.item())
        opt_info.piLoss.append(pi_loss.item())
        opt_info.q1GradNorm.append(
            torch.tensor(q1_grad_norm).item())  # backwards compatible
        opt_info.q2GradNorm.append(
            torch.tensor(q2_grad_norm).item())  # backwards compatible
        opt_info.piGradNorm.append(
            torch.tensor(pi_grad_norm).item())  # backwards compatible
        opt_info.q1.extend(q1[::10].numpy())  # Downsample for stats.
        opt_info.q2.extend(q2[::10].numpy())
        opt_info.piMu.extend(pi_mean[::10].numpy())
        opt_info.piLogStd.extend(pi_log_std[::10].numpy())
        opt_info.qMeanDiff.append(torch.mean(abs(q1 - q2)).item())
        opt_info.alpha.append(self._alpha.item())

    def optim_state_dict(self):
        return dict(
            pi_optimizer=self.pi_optimizer.state_dict(),
            q1_optimizer=self.q1_optimizer.state_dict(),
            q2_optimizer=self.q2_optimizer.state_dict(),
            alpha_optimizer=self.alpha_optimizer.state_dict()
            if self.alpha_optimizer else None,
            log_alpha=self._log_alpha.detach().item(),
        )

    def load_optim_state_dict(self, state_dict):
        self.pi_optimizer.load_state_dict(state_dict["pi_optimizer"])
        self.q1_optimizer.load_state_dict(state_dict["q1_optimizer"])
        self.q2_optimizer.load_state_dict(state_dict["q2_optimizer"])
        if self.alpha_optimizer is not None and state_dict[
                "alpha_optimizer"] is not None:
            self.alpha_optimizer.load_state_dict(state_dict["alpha_optimizer"])
        with torch.no_grad():
            self._log_alpha[:] = state_dict["log_alpha"]
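
The replay-ratio bookkeeping shared by these examples is plain arithmetic; a worked sketch with assumed numbers (not values taken from any of the configs above):

# updates_per_optimize = replay_ratio * sampler_batch_size / training_batch_size
replay_ratio = 256   # timesteps consumed by training per timestep generated
sampler_bs = 64      # sampler batch size per iteration (batch_T * batch_B)
batch_size = 256     # training minibatch drawn from the replay buffer
updates_per_optimize = int(replay_ratio * sampler_bs / batch_size)  # -> 64
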
Example #7
0
class SacWithUl(RlAlgorithm):

    opt_info_fields = tuple(f for f in OptInfo._fields)

    def __init__(
        self,
        discount=0.99,
        batch_size=512,
        # replay_ratio=512,  # data_consumption / data_generation
        # min_steps_learn=int(1e4),
        replay_size=int(1e5),
        target_update_tau=0.01,  # tau=1 for hard update.
        target_update_interval=2,
        actor_update_interval=2,
        OptimCls=torch.optim.Adam,
        initial_optim_state_dict=None,  # for all of them.
        action_prior="uniform",  # or "gaussian"
        reward_scale=1,
        target_entropy="auto",  # "auto", float, or None
        reparameterize=True,
        clip_grad_norm=1e6,
        n_step_return=1,
        bootstrap_timelimit=True,
        q_lr=1e-3,
        pi_lr=1e-3,
        alpha_lr=1e-4,
        q_beta=0.9,
        pi_beta=0.9,
        alpha_beta=0.5,
        alpha_init=0.1,
        encoder_update_tau=0.05,
        random_shift_prob=1.0,
        random_shift_pad=4,  # how much to pad in each direction (DrQ-style)
        stop_rl_conv_grad=False,
        min_steps_rl=int(1e4),
        min_steps_ul=int(1e4),
        max_steps_ul=None,
        ul_learning_rate=7e-4,
        ul_optim_kwargs=None,
        # ul_replay_size=1e5,
        ul_update_schedule=None,
        ul_lr_schedule=None,
        ul_lr_warmup=0,
        # ul_delta_T=1,  # Always 1
        # ul_batch_B=512,
        # ul_batch_T=1,  # Always 1
        ul_batch_size=512,
        ul_random_shift_prob=1.0,
        ul_random_shift_pad=4,
        ul_target_update_interval=1,
        ul_target_update_tau=0.01,
        ul_latent_size=128,
        ul_anchor_hidden_sizes=512,
        ul_clip_grad_norm=10.0,
        ul_pri_alpha=0.0,
        ul_pri_beta=1.0,
        ul_pri_n_step_return=1,
        ul_use_rl_samples=False,
        UlEncoderCls=UlEncoderModel,
        UlContrastCls=ContrastModel,
    ):
        # assert replay_ratio == batch_size  # Unless I want to change it.
        self._batch_size = batch_size
        del batch_size
        if ul_optim_kwargs is None:
            ul_optim_kwargs = dict()
        save__init__args(locals())
        self.replay_ratio = self.batch_size  # standard 1 update per itr.
        # assert ul_delta_T == n_step_return  # Just use the same replay buffer
        # assert ul_batch_T == 1  # This was fine in DMControl in RlFromUl

    def initialize(self,
                   agent,
                   n_itr,
                   batch_spec,
                   mid_batch_reset,
                   examples,
                   world_size=1,
                   rank=0):
        """Stores input arguments and initializes replay buffer and optimizer.
        Use in non-async runners.  Computes number of gradient updates per
        optimization iteration as `(replay_ratio * sampler-batch-size /
        training-batch_size)`."""
        self.agent = agent
        self.n_itr = n_itr
        self.mid_batch_reset = mid_batch_reset
        self.sampler_bs = sampler_bs = batch_spec.size
        self.updates_per_optimize = int(self.replay_ratio * sampler_bs /
                                        self.batch_size)
        logger.log(
            f"From sampler batch size {sampler_bs}, training "
            f"batch size {self.batch_size}, and replay ratio "
            f"{self.replay_ratio}, computed {self.updates_per_optimize} "
            f"updates per iteration.")
        self.min_itr_rl = self.min_steps_rl // sampler_bs
        self.min_itr_ul = self.min_steps_ul // sampler_bs
        self.max_itr_ul = (self.n_itr + 1 if self.max_steps_ul is None else
                           self.max_steps_ul // sampler_bs)
        if self.min_itr_rl == self.min_itr_ul:
            self.min_itr_rl += 1  # Wait until the next iteration to start RL.
        agent.give_min_itr_learn(self.min_itr_rl)
        self.initialize_replay_buffer(examples, batch_spec)

        self.ul_encoder = self.UlEncoderCls(
            conv=self.agent.conv,
            latent_size=self.ul_latent_size,
            conv_out_size=self.agent.conv.output_size,
        )
        self.ul_target_encoder = copy.deepcopy(self.ul_encoder)
        self.ul_contrast = self.UlContrastCls(
            latent_size=self.ul_latent_size,
            anchor_hidden_sizes=self.ul_anchor_hidden_sizes,
        )
        self.ul_encoder.to(self.agent.device)
        self.ul_target_encoder.to(self.agent.device)
        self.ul_contrast.to(self.agent.device)

        self.optim_initialize(rank)

    def async_initialize(*args, **kwargs):
        raise NotImplementedError

    def optim_initialize(self, rank=0):
        """Called in initilize or by async runner after forking sampler."""
        self.rank = rank

        # Be very explicit about which parameters are optimized where.
        self.pi_optimizer = self.OptimCls(
            chain(
                self.agent.pi_fc1.parameters(),  # No conv.
                self.agent.pi_mlp.parameters(),
            ),
            lr=self.pi_lr,
            betas=(self.pi_beta, 0.999),
        )
        self.q_optimizer = self.OptimCls(
            chain(
                () if self.stop_rl_conv_grad else self.agent.conv.parameters(),
                self.agent.q_fc1.parameters(),
                self.agent.q_mlps.parameters(),
            ),
            lr=self.q_lr,
            betas=(self.q_beta, 0.999),
        )

        self._log_alpha = torch.tensor(np.log(self.alpha_init),
                                       requires_grad=True)
        self._alpha = torch.exp(self._log_alpha.detach())
        self.alpha_optimizer = self.OptimCls((self._log_alpha, ),
                                             lr=self.alpha_lr,
                                             betas=(self.alpha_beta, 0.999))

        if self.target_entropy == "auto":
            self.target_entropy = -np.prod(self.agent.env_spaces.action.shape)
        if self.initial_optim_state_dict is not None:
            self.load_optim_state_dict(self.initial_optim_state_dict)
        if self.action_prior == "gaussian":
            self.action_prior_distribution = Gaussian(dim=np.prod(
                self.agent.env_spaces.action.shape),
                                                      std=1.0)

        self.ul_optimizer = self.OptimCls(self.ul_parameters(),
                                          lr=self.ul_learning_rate,
                                          **self.ul_optim_kwargs)

        self.total_ul_updates = sum([
            self.compute_ul_update_schedule(itr) for itr in range(self.n_itr)
        ])
        logger.log(
            f"Total number of UL updates to do: {self.total_ul_updates}.")
        self.ul_update_counter = 0
        self.ul_lr_scheduler = None
        if self.total_ul_updates > 0:
            if self.ul_lr_schedule == "linear":
                self.ul_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
                    optimizer=self.ul_optimizer,
                    lr_lambda=lambda upd:
                    (self.total_ul_updates - upd) / self.total_ul_updates,
                )
            elif self.ul_lr_schedule == "cosine":
                self.ul_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    optimizer=self.ul_optimizer,
                    T_max=self.total_ul_updates - self.ul_lr_warmup,
                )
            elif self.ul_lr_schedule is not None:
                raise NotImplementedError

            if self.ul_lr_warmup > 0:
                self.ul_lr_scheduler = GradualWarmupScheduler(
                    self.ul_optimizer,
                    multiplier=1,
                    total_epoch=self.ul_lr_warmup,  # actually n_updates
                    after_scheduler=self.ul_lr_scheduler,
                )

            if self.ul_lr_scheduler is not None:
                self.ul_optimizer.zero_grad()
                self.ul_optimizer.step()

            self.c_e_loss = torch.nn.CrossEntropyLoss(
                ignore_index=IGNORE_INDEX)

    def initialize_replay_buffer(self, examples, batch_spec, async_=False):
        """
        Allocates replay buffer using examples and with the fields in `SamplesToBuffer`
        namedarraytuple.
        POSSIBLY CHANGE TO FRAME-BASED BUFFER (only if memory becomes an issue; speed is fine).
        """
        if async_:
            raise NotImplementedError
        example_to_buffer = self.examples_to_buffer(examples)
        ReplayCls = (TlUniformReplayBuffer
                     if self.bootstrap_timelimit else UniformReplayBuffer)
        replay_kwargs = dict(
            example=example_to_buffer,
            size=self.replay_size,
            B=batch_spec.B,
            n_step_return=self.n_step_return,
        )
        self.replay_buffer = ReplayCls(**replay_kwargs)
        if self.ul_pri_alpha > 0.0:
            self.replay_buffer = RlWithUlPrioritizedReplayWrapper(
                replay_buffer=self.replay_buffer,
                n_step_return=self.ul_pri_n_step_return,
                alpha=self.ul_pri_alpha,
                beta=self.ul_pri_beta,
            )

    def optimize_agent(self, itr, samples):
        """
        Extracts the needed fields from input samples and stores them in the
        replay buffer.  Then samples from the replay buffer to train the agent
        by gradient updates (with the number of updates determined by replay
        ratio, sampler batch size, and training batch size).

        DIFFERENCES FROM SAC:
          - Organizes the optimizers a little differently and is explicit about
            which parameters each one updates.
        """
        samples_to_buffer = self.samples_to_buffer(samples)
        self.replay_buffer.append_samples(samples_to_buffer)
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        rl_samples = None
        if itr >= self.min_itr_rl:
            opt_info_rl, rl_samples = self.rl_optimize(itr)
            opt_info = opt_info._replace(**opt_info_rl._asdict())
        if itr >= self.min_itr_ul:
            opt_info_ul = self.ul_optimize(itr, rl_samples)
            opt_info = opt_info._replace(**opt_info_ul._asdict())
        else:
            opt_info.ulUpdates.append(0)
        return opt_info

    def rl_optimize(self, itr):
        opt_info_rl = OptInfoRl(*([] for _ in range(len(OptInfoRl._fields))))
        for _ in range(self.updates_per_optimize):
            # Sample from the replay buffer, center crop, and move to GPU.
            samples_from_replay = self.replay_buffer.sample_batch(
                self.batch_size)
            rl_samples = self.random_shift_rl_samples(samples_from_replay)
            rl_samples = self.samples_to_device(rl_samples)

            # Q-loss includes computing some values used in pi-loss.
            q1_loss, q2_loss, valid, conv_out, q1, q2 = self.q_loss(rl_samples)

            if self.update_counter % self.actor_update_interval == 0:
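                # Delayed actor/alpha update: only every `actor_update_interval`
                # critic updates (update_counter increments once per Q step).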
                pi_loss, alpha_loss, pi_mean, pi_log_std = self.pi_alpha_loss(
                    rl_samples, valid, conv_out)
                if alpha_loss is not None:
                    self.alpha_optimizer.zero_grad()
                    alpha_loss.backward()
                    self.alpha_optimizer.step()
                    self._alpha = torch.exp(self._log_alpha.detach())
                    opt_info_rl.alpha.append(self._alpha.item())

                self.pi_optimizer.zero_grad()
                pi_loss.backward()
                pi_grad_norm = torch.nn.utils.clip_grad_norm_(
                    chain(
                        self.agent.pi_fc1.parameters(),
                        self.agent.pi_mlp.parameters(),
                    ),
                    self.clip_grad_norm,
                )
                self.pi_optimizer.step()
                opt_info_rl.piLoss.append(pi_loss.item())
                opt_info_rl.piGradNorm.append(pi_grad_norm.item())
                opt_info_rl.piMu.extend(pi_mean[::10].numpy())
                opt_info_rl.piLogStd.extend(pi_log_std[::10].numpy())

            # Step the Q optimizer last because pi_loss.backward() uses the Q networks.
            self.q_optimizer.zero_grad()
            q_loss = q1_loss + q2_loss
            q_loss.backward()
            q_grad_norm = torch.nn.utils.clip_grad_norm_(
                chain(
                    () if self.stop_rl_conv_grad else
                    self.agent.conv.parameters(),
                    self.agent.q_fc1.parameters(),
                    self.agent.q_mlps.parameters(),
                ),
                self.clip_grad_norm,
            )
            self.q_optimizer.step()
            opt_info_rl.q1Loss.append(q1_loss.item())
            opt_info_rl.q2Loss.append(q2_loss.item())
            opt_info_rl.qGradNorm.append(q_grad_norm.item())
            opt_info_rl.q1.extend(q1[::10].numpy())  # Downsample for stats.
            opt_info_rl.q2.extend(q2[::10].numpy())
            opt_info_rl.qMeanDiff.append(torch.mean(abs(q1 - q2)).item())

            self.update_counter += 1
            if self.update_counter % self.target_update_interval == 0:
                self.agent.update_targets(
                    q_tau=self.target_update_tau,
                    encoder_tau=self.encoder_update_tau,
                )

        return opt_info_rl, rl_samples

    def ul_optimize(self, itr, rl_samples=None):
        opt_info_ul = OptInfoUl(*([] for _ in range(len(OptInfoUl._fields))))
        n_ul_updates = self.compute_ul_update_schedule(itr)
        ul_bs = self.ul_batch_size
        n_rl_samples = (0 if rl_samples is None else len(
            rl_samples.agent_inputs.observation))
        for i in range(n_ul_updates):
            self.ul_update_counter += 1
            if self.ul_lr_scheduler is not None:
                self.ul_lr_scheduler.step(self.ul_update_counter)
            if n_rl_samples >= self.ul_batch_size * (i + 1):
                ul_samples = rl_samples[i * ul_bs:(i + 1) * ul_bs]
            else:
                ul_samples = None
            ul_loss, ul_accuracy, grad_norm = self.ul_optimize_one_step(
                ul_samples)
            opt_info_ul.ulLoss.append(ul_loss.item())
            opt_info_ul.ulAccuracy.append(ul_accuracy.item())
            opt_info_ul.ulGradNorm.append(grad_norm.item())
            if self.ul_update_counter % self.ul_target_update_interval == 0:
                update_state_dict(
                    self.ul_target_encoder,
                    self.ul_encoder.state_dict(),
                    self.ul_target_update_tau,
                )
        opt_info_ul.ulUpdates.append(self.ul_update_counter)
        return opt_info_ul

    def ul_optimize_one_step(self, samples=None):
        self.ul_optimizer.zero_grad()
        if samples is None:
            if self.ul_pri_alpha > 0:
                samples = self.replay_buffer.sample_batch(self.ul_batch_size,
                                                          mode="UL")
            else:
                samples = self.replay_buffer.sample_batch(self.ul_batch_size)

            # This is why we need ul_delta_T == n_step_return (usually 1):
            anchor = samples.agent_inputs.observation
            positive = samples.target_inputs.observation

            if self.ul_random_shift_prob > 0.0:
                anchor = random_shift(
                    imgs=anchor,
                    pad=self.ul_random_shift_pad,
                    prob=self.ul_random_shift_prob,
                )
                positive = random_shift(
                    imgs=positive,
                    pad=self.ul_random_shift_pad,
                    prob=self.ul_random_shift_prob,
                )

            anchor, positive = buffer_to((anchor, positive),
                                         device=self.agent.device)

        else:
            # Assume samples were already augmented in the RL loss.
            anchor = samples.agent_inputs.observation
            positive = samples.target_inputs.observation

        with torch.no_grad():
            c_positive, _pos_conv = self.ul_target_encoder(positive)
        c_anchor, _anc_conv = self.ul_encoder(anchor)
        logits = self.ul_contrast(c_anchor, c_positive)  # anchor mlp in here.

        labels = torch.arange(c_anchor.shape[0],
                              dtype=torch.long,
                              device=self.agent.device)
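        # InfoNCE-style labels: each anchor's positive is the entry at the same
        # batch index in the contrast logits.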
        invalid = samples.done  # shape: [B], if done, following state invalid
        labels[invalid] = IGNORE_INDEX
        ul_loss = self.c_e_loss(logits, labels)
        ul_loss.backward()
        if self.ul_clip_grad_norm is None:
            grad_norm = 0.0
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.ul_parameters(),
                                                       self.ul_clip_grad_norm)
        self.ul_optimizer.step()

        correct = torch.argmax(logits.detach(), dim=1) == labels
        accuracy = torch.mean(correct[~invalid].float())

        return ul_loss, accuracy, grad_norm

    def samples_to_buffer(self, samples):
        """Defines how to add data from sampler into the replay buffer. Called
        in optimize_agent() if samples are provided to that method."""
        observation = samples.env.observation
        samples_to_buffer = SamplesToBuffer(
            observation=observation,
            action=samples.agent.action,
            reward=samples.env.reward,
            done=samples.env.done,
        )
        if self.bootstrap_timelimit:
            samples_to_buffer = SamplesToBufferTl(
                *samples_to_buffer, timeout=samples.env.env_info.timeout)
        return samples_to_buffer

    def examples_to_buffer(self, examples):
        observation = examples["observation"]
        example_to_buffer = SamplesToBuffer(
            observation=observation,
            action=examples["action"],
            reward=examples["reward"],
            done=examples["done"],
        )
        if self.bootstrap_timelimit:
            example_to_buffer = SamplesToBufferTl(
                *example_to_buffer, timeout=examples["env_info"].timeout)
        return example_to_buffer

    def samples_to_device(self, samples):
        """Only move the parts of samples which need to go to GPU."""
        agent_inputs, target_inputs, action = buffer_to(
            (samples.agent_inputs, samples.target_inputs, samples.action),
            device=self.agent.device,
        )
        device_samples = samples._replace(
            agent_inputs=agent_inputs,
            target_inputs=target_inputs,
            action=action,
        )
        return device_samples

    def random_shift_rl_samples(self, samples):
        if self.random_shift_prob == 0.0:
            return samples
        obs = samples.agent_inputs.observation
        target_obs = samples.target_inputs.observation
        aug_obs = random_shift(
            imgs=obs,
            pad=self.random_shift_pad,
            prob=self.random_shift_prob,
        )
        aug_target_obs = random_shift(
            imgs=target_obs,
            pad=self.random_shift_pad,
            prob=self.random_shift_prob,
        )
        aug_samples = samples._replace(
            agent_inputs=samples.agent_inputs._replace(observation=aug_obs),
            target_inputs=samples.target_inputs._replace(
                observation=aug_target_obs),
        )
        return aug_samples

    def q_loss(self, samples):
        if self.mid_batch_reset and not self.agent.recurrent:
            valid = torch.ones_like(samples.done, dtype=torch.float)  # or None
        else:
            valid = valid_from_done(samples.done)
        if self.bootstrap_timelimit:
            # To avoid non-use of bootstrap when environment is 'done' due to
            # time-limit, turn off training on these samples.
            valid *= 1 - samples.timeout_n.float()

        # Run the convolution only once, return so pi_loss can use it.
        conv_out = self.agent.conv(samples.agent_inputs.observation)
        if self.stop_rl_conv_grad:
            conv_out = conv_out.detach()
        q_inputs = samples.agent_inputs._replace(observation=conv_out)

        # Q LOSS.
        q1, q2 = self.agent.q(*q_inputs, samples.action)
        with torch.no_grad():
            # Run the target convolution only once.
            target_conv_out = self.agent.target_conv(
                samples.target_inputs.observation)
            target_inputs = samples.target_inputs._replace(
                observation=target_conv_out)
            target_action, target_log_pi, _ = self.agent.pi(*target_inputs)
            target_q1, target_q2 = self.agent.target_q(*target_inputs,
                                                       target_action)
            min_target_q = torch.min(target_q1, target_q2)
            target_value = min_target_q - self._alpha * target_log_pi
        disc = self.discount**self.n_step_return
        y = (self.reward_scale * samples.return_ +
             (1 - samples.done_n.float()) * disc * target_value)
        q1_loss = 0.5 * valid_mean((y - q1)**2, valid)
        q2_loss = 0.5 * valid_mean((y - q2)**2, valid)

        return q1_loss, q2_loss, valid, conv_out, q1.detach(), q2.detach()

    def pi_alpha_loss(self, samples, valid, conv_out):
        # PI LOSS.
        # Uses detached conv out; avoid re-computing.
        conv_detach = conv_out.detach()
        agent_inputs = samples.agent_inputs._replace(observation=conv_detach)
        new_action, log_pi, (pi_mean,
                             pi_log_std) = self.agent.pi(*agent_inputs)
        if not self.reparameterize:
            # new_action = new_action.detach()  # No grad.
            raise NotImplementedError
        # Re-use the detached latent.
        log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
        min_log_target = torch.min(log_target1, log_target2)
        prior_log_pi = self.get_action_prior(new_action.cpu())
        if self.reparameterize:
            pi_losses = self._alpha * log_pi - min_log_target - prior_log_pi
        else:
            raise NotImplementedError
        # if self.policy_output_regularization > 0:
        #     pi_losses += self.policy_output_regularization * torch.mean(
        #         0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1)
        pi_loss = valid_mean(pi_losses, valid)

        # ALPHA LOSS.
        if self.target_entropy is not None:
            alpha_losses = -self._log_alpha * (log_pi.detach() +
                                               self.target_entropy)
            alpha_loss = valid_mean(alpha_losses, valid)
        else:
            alpha_loss = None

        return pi_loss, alpha_loss, pi_mean.detach(), pi_log_std.detach()

    def get_action_prior(self, action):
        if self.action_prior == "uniform":
            prior_log_pi = 0.0
        elif self.action_prior == "gaussian":
            prior_log_pi = self.action_prior_distribution.log_likelihood(
                action, GaussianDistInfo(mean=torch.zeros_like(action)))
        return prior_log_pi

    def optim_state_dict(self):
        return dict(
            pi=self.pi_optimizer.state_dict(),
            q=self.q_optimizer.state_dict(),
            alpha=self.alpha_optimizer.state_dict(),
            log_alpha_value=self._log_alpha.detach().item(),
            ul=self.ul_optimizer.state_dict(),
        )

    def load_optim_state_dict(self, state_dict):
        self.pi_optimizer.load_state_dict(state_dict["pi"])
        self.q_optimizer.load_state_dict(state_dict["q"])
        self.alpha_optimizer.load_state_dict(state_dict["alpha"])
        self.ul_optimizer.load_state_dict(state_dict["ul"])
        with torch.no_grad():
            self._log_alpha[:] = state_dict["log_alpha_value"]
            self._alpha = torch.exp(self._log_alpha.detach())

    def ul_parameters(self):
        yield from self.ul_encoder.parameters()
        yield from self.ul_contrast.parameters()

    def ul_named_parameters(self):
        yield from self.ul_encoder.named_parameters()
        yield from self.ul_contrast.named_parameters()

    def compute_ul_update_schedule(self, itr):
        if itr < self.min_itr_ul or itr > self.max_itr_ul:
            return 0
        remaining = (self.max_itr_ul - itr) / (
            self.max_itr_ul - self.min_itr_ul)  # from 1 to 0
        if "constant" in self.ul_update_schedule:
            # Format: "constant_X", for X num updates per RL itr.
            n_ul_updates = int(self.ul_update_schedule.split("_")[1])
        elif "front" in self.ul_update_schedule:
            # Format: "front_X_Y", for X updates first itr, Y updates rest.
            entries = self.ul_update_schedule.split("_")
            if itr == self.min_itr_ul:
                n_ul_updates = int(entries[1])
            else:
                n_ul_updates = int(entries[2])
        elif "linear" in self.ul_update_schedule:
            first = int(self.ul_update_schedule.split("_")[1])
            n_ul_updates = int(np.round(first * remaining))
        elif "quadratic" in self.ul_update_schedule:
            first = int(self.ul_update_schedule.split("_")[1])
            n_ul_updates = int(np.round(first * remaining**2))
        elif "cosine" in self.ul_update_schedule:
            first = int(self.ul_update_schedule.split("_")[1])
            n_ul_updates = int(
                np.round(first * math.sin(math.pi / 2 * remaining)))
        return n_ul_updates
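
The `ul_update_schedule` strings accepted by this example follow a simple `name_number` convention; the helper below re-implements the branches of `compute_ul_update_schedule` as a standalone sketch for quick inspection (the function name and the example call are illustrative):

import math
import numpy as np

def ul_updates(schedule, itr, min_itr_ul, max_itr_ul):
    # Number of UL updates for a given iteration, mirroring the method above.
    if itr < min_itr_ul or itr > max_itr_ul:
        return 0
    remaining = (max_itr_ul - itr) / (max_itr_ul - min_itr_ul)  # 1 -> 0
    name, *nums = schedule.split("_")
    if name == "constant":    # "constant_X": X updates every RL iteration.
        return int(nums[0])
    if name == "front":       # "front_X_Y": X on the first UL iteration, Y after.
        return int(nums[0]) if itr == min_itr_ul else int(nums[1])
    if name == "linear":      # "linear_X": decay from X to 0 linearly.
        return int(np.round(int(nums[0]) * remaining))
    if name == "quadratic":   # "quadratic_X": decay from X to 0 quadratically.
        return int(np.round(int(nums[0]) * remaining ** 2))
    if name == "cosine":      # "cosine_X": decay from X to 0 along a quarter sine.
        return int(np.round(int(nums[0]) * math.sin(math.pi / 2 * remaining)))
    raise ValueError(f"Unknown schedule: {schedule}")

# e.g. ul_updates("linear_8", itr=50, min_itr_ul=0, max_itr_ul=100) -> 4
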
Example #8
0
class RadSacFromUl(RlAlgorithm):

    opt_info_fields = tuple(f for f in OptInfo._fields)

    def __init__(
        self,
        discount=0.99,
        batch_size=512,
        # replay_ratio=512,  # data_consumption / data_generation
        min_steps_learn=int(1e4),
        replay_size=int(1e5),
        target_update_tau=0.01,  # tau=1 for hard update.
        target_update_interval=2,
        actor_update_interval=2,
        OptimCls=torch.optim.Adam,
        initial_optim_state_dict=None,  # for all of them.
        action_prior="uniform",  # or "gaussian"
        reward_scale=1,
        target_entropy="auto",  # "auto", float, or None
        reparameterize=True,
        clip_grad_norm=1e6,
        n_step_return=1,
        bootstrap_timelimit=True,
        q_lr=1e-3,
        pi_lr=1e-3,
        alpha_lr=1e-4,
        q_beta=0.9,
        pi_beta=0.9,
        alpha_beta=0.5,
        alpha_init=0.1,
        encoder_update_tau=0.05,
        augmentation="random_shift",  # [None, "random_shift", "subpixel_shift"]
        random_shift_pad=4,  # how much to pad in each direction (DrQ-style)
        random_shift_prob=1.0,
        stop_conv_grad=False,
        max_pixel_shift=1.0,
    ):
        self.replay_ratio = batch_size  # Unless you want to change it.
        self._batch_size = batch_size
        del batch_size
        assert augmentation in [None, "random_shift", "subpixel_shift"]
        save__init__args(locals())

    def initialize(
        self, agent, n_itr, batch_spec, mid_batch_reset, examples, world_size=1, rank=0
    ):
        """Stores input arguments and initializes replay buffer and optimizer.
        Use in non-async runners.  Computes number of gradient updates per
        optimization iteration as `(replay_ratio * sampler-batch-size /
        training-batch_size)`."""
        self.agent = agent
        self.n_itr = n_itr
        self.mid_batch_reset = mid_batch_reset
        self.sampler_bs = sampler_bs = batch_spec.size
        self.updates_per_optimize = int(
            self.replay_ratio * sampler_bs / self.batch_size
        )
        logger.log(
            f"From sampler batch size {sampler_bs}, training "
            f"batch size {self.batch_size}, and replay ratio "
            f"{self.replay_ratio}, computed {self.updates_per_optimize} "
            f"updates per iteration."
        )
        self.min_itr_learn = self.min_steps_learn // sampler_bs
        agent.give_min_itr_learn(self.min_itr_learn)
        self.store_latent = agent.store_latent
        if self.store_latent:
            assert self.stop_conv_grad
        self.initialize_replay_buffer(examples, batch_spec)
        self.optim_initialize(rank)

    def async_initialize(*args, **kwargs):
        raise NotImplementedError

    def optim_initialize(self, rank=0):
        """Called in initilize or by async runner after forking sampler."""
        self.rank = rank

        # Be very explicit about which parameters are optimized where.
        self.pi_optimizer = self.OptimCls(
            chain(
                self.agent.pi_fc1.parameters(),  # No conv.
                self.agent.pi_mlp.parameters(),
            ),
            lr=self.pi_lr,
            betas=(self.pi_beta, 0.999),
        )
        self.q_optimizer = self.OptimCls(
            chain(
                () if self.stop_conv_grad else self.agent.conv.parameters(),
                self.agent.q_fc1.parameters(),
                self.agent.q_mlps.parameters(),
            ),
            lr=self.q_lr,
            betas=(self.q_beta, 0.999),
        )

        self._log_alpha = torch.tensor(np.log(self.alpha_init), requires_grad=True)
        self._alpha = torch.exp(self._log_alpha.detach())
        self.alpha_optimizer = self.OptimCls(
            (self._log_alpha,), lr=self.alpha_lr, betas=(self.alpha_beta, 0.999)
        )

        if self.target_entropy == "auto":
            self.target_entropy = -np.prod(self.agent.env_spaces.action.shape)
        if self.initial_optim_state_dict is not None:
            self.load_optim_state_dict(self.initial_optim_state_dict)
        if self.action_prior == "gaussian":
            self.action_prior_distribution = Gaussian(
                dim=np.prod(self.agent.env_spaces.action.shape), std=1.0
            )
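
Optimizing `log_alpha` (rather than `alpha` directly) keeps the entropy temperature strictly positive without any constraint, and `target_entropy="auto"` applies the usual SAC heuristic of minus the action dimensionality. A minimal sketch under an assumed 6-dimensional action space:

import numpy as np
import torch

action_shape = (6,)                                # assumed, e.g., a 6-DoF arm
target_entropy = -float(np.prod(action_shape))     # "auto" heuristic -> -6.0

log_alpha = torch.tensor(np.log(0.1), requires_grad=True)   # alpha_init = 0.1
alpha = torch.exp(log_alpha.detach())              # positive by construction
alpha_optimizer = torch.optim.Adam([log_alpha], lr=1e-4, betas=(0.5, 0.999))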

    def initialize_replay_buffer(self, examples, batch_spec, async_=False):
        """
        Allocates replay buffer using examples and with the fields in `SamplesToBuffer`
        namedarraytuple.
        POSSIBLY CHANGE TO FRAME-BASED BUFFER (only if need memory, speed is fine).
        """
        if async_:
            raise NotImplementedError
        example_to_buffer = self.examples_to_buffer(examples)
        ReplayCls = (
            TlUniformReplayBuffer if self.bootstrap_timelimit else UniformReplayBuffer
        )
        replay_kwargs = dict(
            example=example_to_buffer,
            size=self.replay_size,
            B=batch_spec.B,
            n_step_return=self.n_step_return,
        )
        self.replay_buffer = ReplayCls(**replay_kwargs)
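
`UniformReplayBuffer` and `TlUniformReplayBuffer` are rlpyt's array-backed buffers, which preserve the [T, B] leading dimensions and support n-step returns. Purely as a mental model (not rlpyt's implementation), a uniform replay buffer reduces to a bounded store with uniform random sampling:

import random
from collections import deque

class ToyUniformReplay:
    """Simplified stand-in for illustration only; ignores [T, B] structure,
    n-step returns, and the time-limit bookkeeping of TlUniformReplayBuffer."""

    def __init__(self, size):
        self.storage = deque(maxlen=size)   # oldest entries overwritten first

    def append_samples(self, transitions):
        self.storage.extend(transitions)

    def sample_batch(self, batch_size):
        return random.sample(list(self.storage), batch_size)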

    def optimize_agent(self, itr, samples=None, sampler_itr=None):
        """
        Extracts the needed fields from input samples and stores them in the
        replay buffer.  Then samples from the replay buffer to train the agent
        by gradient updates (with the number of updates determined by replay
        ratio, sampler batch size, and training batch size).

        DIFFERENCES FROM SAC:
          - Organizes the optimizers per network group, making explicit which
            parameters each one updates.
        """
        itr = itr if sampler_itr is None else sampler_itr  # Async uses sampler_itr.
        if samples is not None:
            samples_to_buffer = self.samples_to_buffer(samples)
            self.replay_buffer.append_samples(samples_to_buffer)
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        if itr < self.min_itr_learn:
            return opt_info
        for _ in range(self.updates_per_optimize):
            # Sample from the replay buffer, apply data augmentation, and move to GPU.
            samples_from_replay = self.replay_buffer.sample_batch(self.batch_size)
            loss_samples = self.data_aug_loss_samples(samples_from_replay)
            loss_samples = self.samples_to_device(loss_samples)

            # Q-loss includes computing some values used in pi-loss.
            q1_loss, q2_loss, valid, conv_out, q1, q2 = self.q_loss(loss_samples)

            if self.update_counter % self.actor_update_interval == 0:
                pi_loss, alpha_loss, pi_mean, pi_log_std = self.pi_alpha_loss(
                    loss_samples, valid, conv_out
                )
                if alpha_loss is not None:
                    self.alpha_optimizer.zero_grad()
                    alpha_loss.backward()
                    self.alpha_optimizer.step()
                    self._alpha = torch.exp(self._log_alpha.detach())
                    opt_info.alpha.append(self._alpha.item())

                self.pi_optimizer.zero_grad()
                pi_loss.backward()
                pi_grad_norm = torch.nn.utils.clip_grad_norm_(
                    chain(
                        self.agent.pi_fc1.parameters(),
                        self.agent.pi_mlp.parameters(),
                    ),
                    self.clip_grad_norm,
                )
                self.pi_optimizer.step()
                opt_info.piLoss.append(pi_loss.item())
                opt_info.piGradNorm.append(pi_grad_norm.item())
                opt_info.piMu.extend(pi_mean[::10].numpy())
                opt_info.piLogStd.extend(pi_log_std[::10].numpy())

            # Step Q's last because pi_loss.backward() uses them.
            self.q_optimizer.zero_grad()
            q_loss = q1_loss + q2_loss
            q_loss.backward()
            q_grad_norm = torch.nn.utils.clip_grad_norm_(
                chain(
                    () if self.stop_conv_grad else self.agent.conv.parameters(),
                    self.agent.q_fc1.parameters(),
                    self.agent.q_mlps.parameters(),
                ),
                self.clip_grad_norm,
            )
            self.q_optimizer.step()
            opt_info.q1Loss.append(q1_loss.item())
            opt_info.q2Loss.append(q2_loss.item())
            opt_info.qGradNorm.append(q_grad_norm.item())
            opt_info.q1.extend(q1[::10].numpy())  # Downsample for stats.
            opt_info.q2.extend(q2[::10].numpy())
            opt_info.qMeanDiff.append(torch.mean(abs(q1 - q2)).item())

            self.update_counter += 1
            if self.update_counter % self.target_update_interval == 0:
                self.agent.update_targets(
                    q_tau=self.target_update_tau,
                    encoder_tau=self.encoder_update_tau,
                )

        return opt_info
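
`agent.update_targets()` presumably performs a soft (Polyak) update of the target Q networks and target encoder, driven by `target_update_tau` and `encoder_update_tau`; the actor and temperature are only updated every `actor_update_interval` critic updates, TD3-style. A generic sketch of such a soft update (not the agent's actual method):

import torch

def soft_update_sketch(target_net, online_net, tau):
    # Polyak averaging: target <- tau * online + (1 - tau) * target.
    with torch.no_grad():
        for tp, p in zip(target_net.parameters(), online_net.parameters()):
            tp.mul_(1 - tau).add_(tau * p)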

    def samples_to_buffer(self, samples):
        """Defines how to add data from sampler into the replay buffer. Called
        in optimize_agent() if samples are provided to that method."""
        if self.store_latent:
            observation = samples.agent.agent_info.conv
        else:
            observation = samples.env.observation
        samples_to_buffer = SamplesToBuffer(
            observation=observation,
            action=samples.agent.action,
            reward=samples.env.reward,
            done=samples.env.done,
        )
        if self.bootstrap_timelimit:
            samples_to_buffer = SamplesToBufferTl(
                *samples_to_buffer, timeout=samples.env.env_info.timeout
            )
        return samples_to_buffer

    def examples_to_buffer(self, examples):
        if self.store_latent:
            observation = examples["agent_info"].conv
        else:
            observation = examples["observation"]
        example_to_buffer = SamplesToBuffer(
            observation=observation,
            action=examples["action"],
            reward=examples["reward"],
            done=examples["done"],
        )
        if self.bootstrap_timelimit:
            example_to_buffer = SamplesToBufferTl(
                *example_to_buffer, timeout=examples["env_info"].timeout
            )
        return example_to_buffer

    def samples_to_device(self, samples):
        """Only move the parts of samples which need to go to GPU."""
        agent_inputs, target_inputs, action = buffer_to(
            (samples.agent_inputs, samples.target_inputs, samples.action),
            device=self.agent.device,
        )
        device_samples = samples._replace(
            agent_inputs=agent_inputs,
            target_inputs=target_inputs,
            action=action,
        )
        return device_samples
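
`buffer_to` is rlpyt's utility for moving arbitrarily nested (named)tuples of tensors onto a device in one call. A rough stand-in, assuming only nested tuples of tensors:

import torch

def buffer_to_sketch(buf, device):
    """Recursively move tensors inside nested tuples/namedtuples to `device`."""
    if isinstance(buf, torch.Tensor):
        return buf.to(device)
    if isinstance(buf, tuple):
        moved = [buffer_to_sketch(b, device) for b in buf]
        # Rebuild namedtuples with their own constructor, plain tuples otherwise.
        return type(buf)(*moved) if hasattr(buf, "_fields") else tuple(moved)
    return buf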

    def data_aug_loss_samples(self, samples):
        """Perform data augmentation (on CPU)."""
        if self.augmentation is None:
            return samples

        obs = samples.agent_inputs.observation
        target_obs = samples.target_inputs.observation

        if self.augmentation == "random_shift":
            aug_obs = random_shift(
                imgs=obs,
                pad=self.random_shift_pad,
                prob=self.random_shift_prob,
            )
            aug_target_obs = random_shift(
                imgs=target_obs,
                pad=self.random_shift_pad,
                prob=self.random_shift_prob,
            )
        elif self.augmentation == "subpixel_shift":
            aug_obs = subpixel_shift(
                imgs=obs,
                max_shift=self.max_pixel_shift,
            )
            aug_target_obs = subpixel_shift(
                imgs=target_obs,
                max_shift=self.max_pixel_shift,
            )
        else:
            raise NotImplementedError

        aug_samples = samples._replace(
            agent_inputs=samples.agent_inputs._replace(observation=aug_obs),
            target_inputs=samples.target_inputs._replace(observation=aug_target_obs),
        )

        return aug_samples
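
`random_shift` and `subpixel_shift` come from the surrounding codebase and are not shown here. For orientation, a DrQ-style random shift amounts to replicate-padding by `pad` pixels and cropping back to the original size at a random offset; a simplified sketch (the real helper also handles leading time/batch dims and a per-image `prob`):

import torch
import torch.nn.functional as F

def random_shift_sketch(imgs, pad=4):
    """imgs: float tensor [B, C, H, W]; applies an independent shift per image."""
    b, c, h, w = imgs.shape
    padded = F.pad(imgs, [pad] * 4, mode="replicate")   # -> [B, C, H+2p, W+2p]
    out = torch.empty_like(imgs)
    offsets = torch.randint(0, 2 * pad + 1, (b, 2))
    for i, (dy, dx) in enumerate(offsets.tolist()):
        out[i] = padded[i, :, dy:dy + h, dx:dx + w]
    return out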

    def q_loss(self, samples):
        if self.mid_batch_reset and not self.agent.recurrent:
            valid = torch.ones_like(samples.done, dtype=torch.float)  # or None
        else:
            valid = valid_from_done(samples.done)
        if self.bootstrap_timelimit:
            # When the environment is 'done' only because of a time limit, the
            # bootstrap should not be suppressed; mask those samples out of
            # training instead of handling them specially.
            valid *= 1 - samples.timeout_n.float()

        # Run the convolution only once, return so pi_loss can use it.
        if self.store_latent:
            conv_out = None
            q_inputs = samples.agent_inputs
        else:
            conv_out = self.agent.conv(samples.agent_inputs.observation)
            if self.stop_conv_grad:
                conv_out = conv_out.detach()
            q_inputs = samples.agent_inputs._replace(observation=conv_out)

        # Q LOSS.
        q1, q2 = self.agent.q(*q_inputs, samples.action)
        with torch.no_grad():
            # Run the target convolution only once.
            if self.store_latent:
                target_inputs = samples.target_inputs
            else:
                target_conv_out = self.agent.target_conv(
                    samples.target_inputs.observation
                )
                target_inputs = samples.target_inputs._replace(
                    observation=target_conv_out
                )
            target_action, target_log_pi, _ = self.agent.pi(*target_inputs)
            target_q1, target_q2 = self.agent.target_q(*target_inputs, target_action)
            min_target_q = torch.min(target_q1, target_q2)
            target_value = min_target_q - self._alpha * target_log_pi
        disc = self.discount ** self.n_step_return
        y = (
            self.reward_scale * samples.return_
            + (1 - samples.done_n.float()) * disc * target_value
        )
        q1_loss = 0.5 * valid_mean((y - q1) ** 2, valid)
        q2_loss = 0.5 * valid_mean((y - q2) ** 2, valid)

        return q1_loss, q2_loss, valid, conv_out, q1.detach(), q2.detach()
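
The target computed above is the standard clipped-double-Q soft Bellman backup, y = reward_scale * return_n + (1 - done_n) * discount^n * (min(Q1', Q2') - alpha * log_pi'). A scalar walk-through with assumed values:

import torch

# Assumed scalar values, for illustration only.
alpha, discount, n_step, reward_scale = 0.1, 0.99, 1, 1.0
return_n = torch.tensor(1.0)     # n-step discounted reward sum from the buffer
done_n = torch.tensor(0.0)
target_q1, target_q2 = torch.tensor(5.0), torch.tensor(4.8)
target_log_pi = torch.tensor(-1.2)

target_value = torch.min(target_q1, target_q2) - alpha * target_log_pi  # 4.92
y = reward_scale * return_n + (1 - done_n) * discount ** n_step * target_value
# y = 1.0 + 0.99 * 4.92 = 5.8708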

    def pi_alpha_loss(self, samples, valid, conv_out):
        # PI LOSS.
        # Uses detached conv; avoid re-computing.
        if self.store_latent:
            agent_inputs = samples.agent_inputs
        else:
            conv_detach = conv_out.detach()  # Always detached in actor.
            agent_inputs = samples.agent_inputs._replace(observation=conv_detach)

        new_action, log_pi, (pi_mean, pi_log_std) = self.agent.pi(*agent_inputs)
        if not self.reparameterize:
            # new_action = new_action.detach()  # No grad.
            raise NotImplementedError
        # Re-use the detached latent.
        log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
        min_log_target = torch.min(log_target1, log_target2)
        prior_log_pi = self.get_action_prior(new_action.cpu())
        if self.reparameterize:
            pi_losses = self._alpha * log_pi - min_log_target - prior_log_pi
        else:
            raise NotImplementedError
        # if self.policy_output_regularization > 0:
        #     pi_losses += self.policy_output_regularization * torch.mean(
        #         0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1)
        pi_loss = valid_mean(pi_losses, valid)

        # ALPHA LOSS.
        if self.target_entropy is not None:
            alpha_losses = -self._log_alpha * (log_pi.detach() + self.target_entropy)
            alpha_loss = valid_mean(alpha_losses, valid)
        else:
            alpha_loss = None

        return pi_loss, alpha_loss, pi_mean.detach(), pi_log_std.detach()
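
The temperature loss pushes the policy entropy toward `target_entropy`: if the policy is already more stochastic than the target, gradient descent on `log_alpha` lowers alpha, and vice versa. A tiny check with assumed numbers:

import torch

target_entropy = -6.0                      # e.g., 6-dim action space ("auto")
log_alpha = torch.tensor(0.0, requires_grad=True)   # alpha = 1.0

log_pi = torch.tensor([-2.0])              # current entropy ~2.0, above target
alpha_loss = torch.mean(-log_alpha * (log_pi.detach() + target_entropy))
alpha_loss.backward()
print(log_alpha.grad)   # tensor(8.): positive grad -> descent lowers alpha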

    def get_action_prior(self, action):
        if self.action_prior == "uniform":
            prior_log_pi = 0.0
        elif self.action_prior == "gaussian":
            prior_log_pi = self.action_prior_distribution.log_likelihood(
                action, GaussianDistInfo(mean=torch.zeros_like(action))
            )
        return prior_log_pi

    def optim_state_dict(self):
        return dict(
            pi=self.pi_optimizer.state_dict(),
            q=self.q_optimizer.state_dict(),
            alpha=self.alpha_optimizer.state_dict(),
            log_alpha_value=self._log_alpha.detach().item(),
        )

    def load_optim_state_dict(self, state_dict):
        self.pi_optimizer.load_state_dict(state_dict["pi"])
        self.q_optimizer.load_state_dict(state_dict["q"])
        self.alpha_optimizer.load_state_dict(state_dict["alpha"])
        with torch.no_grad():
            self._log_alpha[:] = state_dict["log_alpha_value"]
            self._alpha = torch.exp(self._log_alpha.detach())
Example #9
0
class SAC(RlAlgorithm):

    opt_info_fields = tuple(f for f in OptInfo._fields)  # copy

    def __init__(
        self,
        discount=0.99,
        batch_size=256,
        min_steps_learn=int(1e4),
        replay_size=int(1e6),
        training_ratio=256,  # data_consumption / data_generation
        target_update_tau=0.005,  # tau=1 for hard update.
        target_update_interval=1,  # interval=1000 for hard update.
        learning_rate=3e-4,
        OptimCls=torch.optim.Adam,
        optim_kwargs=None,
        initial_optim_state_dict=None,
        action_prior="uniform",  # or "gaussian"
        reward_scale=1,
        reparameterize=True,
        clip_grad_norm=1e6,
        policy_output_regularization=0.001,
        n_step_return=1,
    ):
        if optim_kwargs is None:
            optim_kwargs = dict()
        assert action_prior in ["uniform", "gaussian"]
        save__init__args(locals())
        self.update_counter = 0

    def initialize(self, agent, n_itr, batch_spec, mid_batch_reset, examples):
        if agent.recurrent:
            raise NotImplementedError
        self.agent = agent
        self.n_itr = n_itr
        self.mid_batch_reset = mid_batch_reset
        self.optimizer = self.OptimCls(agent.parameters(),
                                       lr=self.learning_rate,
                                       **self.optim_kwargs)
        if self.initial_optim_state_dict is not None:
            self.optimizer.load_state_dict(self.initial_optim_state_dict)

        sample_bs = batch_spec.size
        train_bs = self.batch_size
        assert (self.training_ratio * sample_bs) % train_bs == 0
        self.updates_per_optimize = int(
            (self.training_ratio * sample_bs) // train_bs)
        logger.log(
            f"From sampler batch size {sample_bs}, training "
            f"batch size {train_bs}, and training ratio "
            f"{self.training_ratio}, computed {self.updates_per_optimize} "
            f"updates per iteration.")
        self.min_itr_learn = self.min_steps_learn // sample_bs
        self.agent.give_min_itr_learn(self.min_itr_learn)

        example_to_buffer = SamplesToBuffer(
            observation=examples["observation"],
            action=examples["action"],
            reward=examples["reward"],
            done=examples["done"],
        )
        replay_kwargs = dict(
            example=example_to_buffer,
            size=self.replay_size,
            B=batch_spec.B,
            n_step_return=self.n_step_return,
        )
        self.replay_buffer = UniformReplayBuffer(**replay_kwargs)

        if self.action_prior == "gaussian":
            self.action_prior_distribution = Gaussian(
                dim=agent.env_spaces.action.size, std=1.)

    def optimize_agent(self, itr, samples=None):
        if samples is not None:
            samples_to_buffer = SamplesToBuffer(
                observation=samples.env.observation,
                action=samples.agent.action,
                reward=samples.env.reward,
                done=samples.env.done,
            )
            self.replay_buffer.append_samples(samples_to_buffer)
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        if itr < self.min_itr_learn:
            return opt_info
        for _ in range(self.updates_per_optimize):
            self.update_counter += 1
            samples_from_replay = self.replay_buffer.sample_batch(
                self.batch_size)
            self.optimizer.zero_grad()
            losses, values = self.loss(samples_from_replay)
            for loss in losses:
                loss.backward()
            grad_norms = [
                torch.nn.utils.clip_grad_norm_(ps, self.clip_grad_norm)
                for ps in self.agent.parameters_by_model()
            ]
            self.optimizer.step()
            self.append_opt_info_(opt_info, losses, grad_norms, values)
            if self.update_counter % self.target_update_interval == 0:
                self.agent.update_target(self.target_update_tau)
        return opt_info

    def loss(self, samples):
        """Samples have leading batch dimension [B,..] (but not time)."""
        agent_inputs, target_inputs, action = buffer_to(
            (samples.agent_inputs, samples.target_inputs, samples.action),
            device=self.agent.device)  # Move to device once, re-use.
        q1, q2 = self.agent.q(*agent_inputs, action)
        with torch.no_grad():
            target_v = self.agent.target_v(*target_inputs)
        disc = self.discount**self.n_step_return
        y = (self.reward_scale * samples.return_ +
             (1 - samples.done_n.float()) * disc * target_v)
        if self.mid_batch_reset and not self.agent.recurrent:
            valid = None  # OR: torch.ones_like(samples.done, dtype=torch.float)
        else:
            valid = valid_from_done(samples.done)

        q1_loss = 0.5 * valid_mean((y - q1)**2, valid)
        q2_loss = 0.5 * valid_mean((y - q2)**2, valid)

        v = self.agent.v(*agent_inputs)
        new_action, log_pi, (pi_mean,
                             pi_log_std) = self.agent.pi(*agent_inputs)
        if not self.reparameterize:
            new_action = new_action.detach()  # No grad.
        log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
        min_log_target = torch.min(log_target1, log_target2)
        prior_log_pi = self.get_action_prior(new_action.cpu())
        v_target = (min_log_target - log_pi +
                    prior_log_pi).detach()  # No grad.
        v_loss = 0.5 * valid_mean((v - v_target)**2, valid)

        if self.reparameterize:
            pi_losses = log_pi - min_log_target
        else:
            pi_factor = (v - v_target).detach()  # No grad.
            pi_losses = log_pi * pi_factor
        if self.policy_output_regularization > 0:
            pi_losses += self.policy_output_regularization * torch.sum(
                0.5 * pi_mean**2 + 0.5 * pi_log_std**2, dim=-1)
        pi_loss = valid_mean(pi_losses, valid)

        losses = (q1_loss, q2_loss, v_loss, pi_loss)
        values = tuple(val.detach()
                       for val in (q1, q2, v, pi_mean, pi_log_std))
        return losses, values
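
Unlike the first example, this older SAC variant keeps a separate state-value network; its regression target is the soft value under the current policy, v_target = min(Q1, Q2) - log_pi (+ prior term), rather than folding the entropy bonus into the Q backup. A scalar sketch with assumed values:

import torch

# Assumed scalar values, for illustration only.
q1, q2 = torch.tensor(3.0), torch.tensor(2.9)
log_pi = torch.tensor(-1.5)
prior_log_pi = 0.0                         # "uniform" action prior

v_target = (torch.min(q1, q2) - log_pi + prior_log_pi).detach()
# v_target = 2.9 + 1.5 = 4.4; the V network regresses onto this soft value.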

    def get_action_prior(self, action):
        if self.action_prior == "uniform":
            prior_log_pi = 0.0
        elif self.action_prior == "gaussian":
            prior_log_pi = self.action_prior_distribution.log_likelihood(
                action, GaussianDistInfo(mean=torch.zeros_like(action)))
        return prior_log_pi

    def append_opt_info_(self, opt_info, losses, grad_norms, values):
        """In-place."""
        q1_loss, q2_loss, v_loss, pi_loss = losses
        q1_grad_norm, q2_grad_norm, v_grad_norm, pi_grad_norm = grad_norms
        q1, q2, v, pi_mean, pi_log_std = values
        opt_info.q1Loss.append(q1_loss.item())
        opt_info.q2Loss.append(q2_loss.item())
        opt_info.vLoss.append(v_loss.item())
        opt_info.piLoss.append(pi_loss.item())
        opt_info.q1GradNorm.append(q1_grad_norm)
        opt_info.q2GradNorm.append(q2_grad_norm)
        opt_info.vGradNorm.append(v_grad_norm)
        opt_info.piGradNorm.append(pi_grad_norm)
        opt_info.q1.extend(q1[::10].numpy())  # Downsample for stats.
        opt_info.q2.extend(q2[::10].numpy())
        opt_info.v.extend(v[::10].numpy())
        opt_info.piMu.extend(pi_mean[::10].numpy())
        opt_info.piLogStd.extend(pi_log_std[::10].numpy())
        opt_info.qMeanDiff.append(torch.mean(abs(q1 - q2)).item())
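
`OptInfo` is defined elsewhere in the module as a namedtuple of list-valued diagnostic fields, which the runner aggregates per logging iteration. The append/extend pattern above reduces to the following sketch (field set assumed, not the module's actual definition):

from collections import namedtuple

# Reduced field set for illustration; the real OptInfo has more fields.
ToyOptInfo = namedtuple("ToyOptInfo", ["q1Loss", "piLoss", "alpha"])

opt_info = ToyOptInfo(*([] for _ in range(len(ToyOptInfo._fields))))
opt_info.q1Loss.append(0.42)
opt_info.piLoss.append(-1.3)
opt_info.alpha.append(0.1)
# The runner later averages each list into one logged scalar per iteration.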