Code example #1
    def __init__(self,
                 env_fn,
                 Actor=core.DiscreteMLPActor,
                 Critic=core.DiscreteMLPQFunction,
                 ac_kwargs=dict(),
                 seed=0,
                 steps_per_epoch=4000,
                 epochs=100,
                 replay_size=int(5e5),
                 gamma=0.99,
                 polyak=0.995,
                 lr=1e-5,
                 alpha=0.2,
                 batch_size=100,
                 start_steps=10000,
                 update_after=1000,
                 update_times_every_step=50,
                 num_test_episodes=10,
                 max_ep_len=2000,
                 logger_kwargs=dict(),
                 save_freq=1,
                 automatic_entropy_tuning=True,
                 use_gpu=False,
                 gpu_parallel=False,
                 show_test_render=False,
                 last_save_path=None,
                 state_of_art_model=False,
                 **kwargs):
        """
        Soft Actor-Critic (SAC)


        Args:
            env_fn : A function which creates a copy of the environment.
                The environment must satisfy the OpenAI Gym API.

            actor_critic: The constructor method for a PyTorch Module with an ``act``
                method, a ``pi`` module, a ``q1`` module, and a ``q2`` module.
                The ``act`` method and ``pi`` module should accept batches of
                observations as inputs, and ``q1`` and ``q2`` should accept a batch
                of observations and a batch of actions as inputs. When called,
                ``act``, ``q1``, and ``q2`` should return:

                ===========  ================  ======================================
                Call         Output Shape      Description
                ===========  ================  ======================================
                ``act``      (batch, act_dim)  | Numpy array of actions for each
                                               | observation.
                ``q1``       (batch,)          | Tensor containing one current estimate
                                               | of Q* for the provided observations
                                               | and actions. (Critical: make sure to
                                               | flatten this!)
                ``q2``       (batch,)          | Tensor containing the other current
                                               | estimate of Q* for the provided observations
                                               | and actions. (Critical: make sure to
                                               | flatten this!)
                ===========  ================  ======================================

                Calling ``pi`` should return:

                ===========  ================  ======================================
                Symbol       Shape             Description
                ===========  ================  ======================================
                ``a``        (batch, act_dim)  | Tensor containing actions from policy
                                               | given observations.
                ``logp_pi``  (batch,)          | Tensor containing log probabilities of
                                               | actions in ``a``. Importantly: gradients
                                               | should be able to flow back into ``a``.
                ===========  ================  ======================================

            ac_kwargs (dict): Any kwargs appropriate for the ActorCritic object
                you provided to SAC.

            seed (int): Seed for random number generators.

            steps_per_epoch (int): Number of steps of interaction (state-action pairs)
                for the agent and the environment in each epoch.

            epochs (int): Number of epochs to run and train agent.

            replay_size (int): Maximum length of replay buffer.

            gamma (float): Discount factor. (Always between 0 and 1.)

            polyak (float): Interpolation factor in polyak averaging for target
                networks. Target networks are updated towards main networks
                according to:

                .. math:: \\theta_{\\text{targ}} \\leftarrow
                    \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

                where :math:`\\rho` is polyak. (Always between 0 and 1, usually
                close to 1.)

            lr (float): Learning rate (used for both policy and value learning).

            alpha (float): Entropy regularization coefficient. (Equivalent to
                inverse of reward scale in the original SAC paper.)

            batch_size (int): Minibatch size for SGD.

            start_steps (int): Number of steps for uniform-random action selection,
                before running real policy. Helps exploration.

            update_after (int): Number of env interactions to collect before
                starting to do gradient descent updates. Ensures replay buffer
                is full enough for useful updates.

            update_times_every_step (int): Number of env interactions that should elapse
                between gradient descent updates. Note: Regardless of how long
                you wait between updates, the ratio of env steps to gradient steps
                is locked to 1.

            num_test_episodes (int): Number of episodes to test the deterministic
                policy at the end of each epoch.

            max_ep_len (int): Maximum length of trajectory / episode / rollout.

            logger_kwargs (dict): Keyword args for EpochLogger.

            save_freq (int): How often (in terms of gap between epochs) to save
                the current policy and value function.

        """
        self.ac_kwargs = ac_kwargs
        self.seed = seed
        self.steps_per_epoch = steps_per_epoch
        self.epochs = epochs
        self.replay_size = replay_size
        self.gamma = gamma
        self.polyak = polyak
        self.lr = lr
        self.alpha = alpha
        self.batch_size = batch_size
        self.start_steps = start_steps
        self.update_after = update_after
        self.update_times_every_step = update_times_every_step
        self.num_test_episodes = num_test_episodes
        self.max_ep_len = max_ep_len
        self.logger_kwargs = logger_kwargs
        self.save_freq = save_freq
        self.automatic_entropy_tuning = automatic_entropy_tuning
        self.use_gpu = use_gpu
        self.gpu_parallel = gpu_parallel
        self.show_test_render = show_test_render
        self.last_save_path = last_save_path
        self.kwargs = kwargs

        self.logger = EpochLogger(**logger_kwargs)
        self.logger.save_config(locals())

        torch.manual_seed(seed)
        np.random.seed(seed)

        self.env = env_fn()
        self.test_env = env_fn()

        self.env.seed(seed)
        # env.seed(seed)
        # test_env.seed(seed)
        self.obs_dim = self.env.observation_space.shape
        self.act_dim = self.env.action_space.n

        # Create actor-critic module and target networks
        self.state_of_art_model = state_of_art_model
        if self.state_of_art_model:
            self.actor = Actor(**ac_kwargs)
            self.critic1 = Critic(**ac_kwargs)
            self.critic2 = Critic(**ac_kwargs)

            self.critic1_targ = deepcopy(self.critic1)
            self.critic2_targ = deepcopy(self.critic2)
        else:
            self.actor = Actor(self.obs_dim, self.act_dim, **ac_kwargs)
            self.critic1 = Critic(self.obs_dim, self.act_dim, **ac_kwargs)
            self.critic2 = Critic(self.obs_dim, self.act_dim, **ac_kwargs)

            self.critic1_targ = deepcopy(self.critic1)
            self.critic2_targ = deepcopy(self.critic2)
        # whether to use the GPU
        if torch.cuda.is_available():
            self.device = torch.device("cuda" if self.use_gpu else "cpu")
            if gpu_parallel:
                self.actor = torch.nn.DataParallel(self.actor)
                self.critic1 = torch.nn.DataParallel(self.critic1)
                self.critic2 = torch.nn.DataParallel(self.critic2)
                self.critic1_targ = torch.nn.DataParallel(self.critic1_targ)
                self.critic2_targ = torch.nn.DataParallel(self.critic2_targ)
        else:
            self.use_gpu = False
            self.gpu_parallel = False
            self.device = torch.device("cpu")
        # Freeze target networks with respect to optimizers (only update via polyak averaging)
        for p in self.critic1_targ.parameters():
            p.requires_grad = False
        for p in self.critic2_targ.parameters():
            p.requires_grad = False
        self.actor.to(self.device)
        self.critic1.to(self.device)
        self.critic2.to(self.device)
        self.critic1_targ.to(self.device)
        self.critic2_targ.to(self.device)

        # Experience buffer
        self.replay_buffer = ReplayBuffer(obs_dim=self.obs_dim,
                                          act_dim=1,
                                          size=replay_size,
                                          device=self.device)

        # # List of parameters for both Q-networks (save this for convenience)
        # q_params = itertools.chain(critic1.parameters(), critic2.parameters())

        if self.automatic_entropy_tuning:
            # set the target entropy to 98% of the maximum possible entropy, log(|A|)
            self.target_entropy = -np.log((1.0 / self.act_dim)) * 0.98
            self.log_alpha = torch.zeros(1,
                                         requires_grad=True,
                                         device=self.device)
            self.alpha = self.log_alpha.exp()
            self.alpha_optim = Adam([self.log_alpha], lr=lr, eps=1e-4)

        # Count variables (protip: try to get a feel for how different size networks behave!)
        var_counts = tuple(
            core.count_vars(module)
            for module in [self.actor, self.critic1, self.critic2])
        self.logger.log(
            '\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d\n' %
            var_counts)

        # Set up optimizers for policy and q-function
        self.pi_optimizer = Adam(self.actor.parameters(), lr=lr)
        self.q1_optimizer = Adam(self.critic1.parameters(), lr=lr)
        self.q2_optimizer = Adam(self.critic2.parameters(), lr=lr)

        if last_save_path is not None:
            checkpoints = torch.load(last_save_path)
            self.epoch = checkpoints['epoch']
            self.actor.load_state_dict(checkpoints['actor'])
            self.critic1.load_state_dict(checkpoints['critic1'])
            self.critic2.load_state_dict(checkpoints['critic2'])
            self.pi_optimizer.load_state_dict(checkpoints['pi_optimizer'])
            self.q1_optimizer.load_state_dict(checkpoints['q1_optimizer'])
            self.q2_optimizer.load_state_dict(checkpoints['q2_optimizer'])
            self.critic1_targ.load_state_dict(checkpoints['critic1_targ'])
            self.critic2_targ.load_state_dict(checkpoints['critic2_targ'])

            # last_best_Return_per_local = checkpoints['last_best_Return_per_local']
            print("succesfully load last prameters")
        else:
            self.epoch = 0

            print("Dont load last prameters.")
Code example #2
File: sac.py  Project: eladsar/spinningup
def sac(env_fn,
        actor_critic=core.MLPActorCritic,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=4000,
        epochs=100,
        replay_size=int(1e6),
        gamma=0.99,
        polyak=0.995,
        lr=1e-3,
        alpha=0.2,
        batch_size=256,
        start_steps=10000,
        update_after=1000,
        update_every=50,
        num_test_episodes=10,
        max_ep_len=1000,
        logger_kwargs=dict(),
        save_freq=1,
        device='cuda',
        override=True):
    """
    Soft Actor-Critic (SAC)


    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: The constructor method for a PyTorch Module with an ``act``
            method, a ``pi`` module, a ``q1`` module, and a ``q2`` module.
            The ``act`` method and ``pi`` module should accept batches of
            observations as inputs, and ``q1`` and ``q2`` should accept a batch
            of observations and a batch of actions as inputs. When called,
            ``act``, ``q1``, and ``q2`` should return:

            ===========  ================  ======================================
            Call         Output Shape      Description
            ===========  ================  ======================================
            ``act``      (batch, act_dim)  | Numpy array of actions for each
                                           | observation.
            ``q1``       (batch,)          | Tensor containing one current estimate
                                           | of Q* for the provided observations
                                           | and actions. (Critical: make sure to
                                           | flatten this!)
            ``q2``       (batch,)          | Tensor containing the other current
                                           | estimate of Q* for the provided observations
                                           | and actions. (Critical: make sure to
                                           | flatten this!)
            ===========  ================  ======================================

            Calling ``pi`` should return:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``a``        (batch, act_dim)  | Tensor containing actions from policy
                                           | given observations.
            ``logp_pi``  (batch,)          | Tensor containing log probabilities of
                                           | actions in ``a``. Importantly: gradients
                                           | should be able to flow back into ``a``.
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the ActorCritic object
            you provided to SAC.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs)
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs to run and train agent.

        replay_size (int): Maximum length of replay buffer.

        gamma (float): Discount factor. (Always between 0 and 1.)

        polyak (float): Interpolation factor in polyak averaging for target
            networks. Target networks are updated towards main networks
            according to:

            .. math:: \\theta_{\\text{targ}} \\leftarrow
                \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

            where :math:`\\rho` is polyak. (Always between 0 and 1, usually
            close to 1.)

        lr (float): Learning rate (used for both policy and value learning).

        alpha (float): Entropy regularization coefficient. (Equivalent to
            inverse of reward scale in the original SAC paper.)

        batch_size (int): Minibatch size for SGD.

        start_steps (int): Number of steps for uniform-random action selection,
            before running real policy. Helps exploration.

        update_after (int): Number of env interactions to collect before
            starting to do gradient descent updates. Ensures replay buffer
            is full enough for useful updates.

        update_every (int): Number of env interactions that should elapse
            between gradient descent updates. Note: Regardless of how long
            you wait between updates, the ratio of env steps to gradient steps
            is locked to 1.

        num_test_episodes (int): Number of episodes to test the deterministic
            policy at the end of each epoch.

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

    """

    device = torch.device(device)
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    torch.manual_seed(seed)
    np.random.seed(seed)

    env, test_env = env_fn(), env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape[0]

    # Action limit for clamping: critically, assumes all dimensions share the same bound!
    act_limit = env.action_space.high[0]

    # Create actor-critic module and target networks
    ac = actor_critic(env.observation_space, env.action_space,
                      **ac_kwargs).to(device)
    ac_targ = deepcopy(ac)

    # Freeze target networks with respect to optimizers (only update via polyak averaging)
    for p in ac_targ.parameters():
        p.requires_grad = False

    # List of parameters for both Q-networks (save this for convenience)
    q_params = itertools.chain(ac.q1.parameters(), ac.q2.parameters())

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim,
                                 act_dim=act_dim,
                                 size=replay_size,
                                 device=device)

    # Count variables (protip: try to get a feel for how different size networks behave!)
    var_counts = tuple(
        core.count_vars(module) for module in [ac.pi, ac.q1, ac.q2])
    logger.log('\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d\n' %
               var_counts)

    # Set up function for computing SAC Q-losses
    def compute_loss_q(data):
        o, a, r, o2, d = data['obs'], data['act'], data['rew'], data[
            'obs2'], data['done']

        q1 = ac.q1(o, a)
        q2 = ac.q2(o, a)

        # Bellman backup for Q functions
        with torch.no_grad():
            # Target actions come from *current* policy
            a2, logp_a2 = ac.pi(o2)

            # Target Q-values
            q1_pi_targ = ac_targ.q1(o2, a2)
            q2_pi_targ = ac_targ.q2(o2, a2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = r + gamma * (1 - d) * (q_pi_targ - alpha * logp_a2)

        # MSE loss against Bellman backup
        loss_q1 = ((q1 - backup)**2).mean()
        loss_q2 = ((q2 - backup)**2).mean()
        loss_q = loss_q1 + loss_q2

        # Useful info for logging
        q_info = dict(Q1Vals=q1.detach().cpu().numpy(),
                      Q2Vals=q2.detach().cpu().numpy())

        return loss_q, q_info
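
    # In symbols, compute_loss_q above regresses both critics onto the clipped
    # double-Q, entropy-regularized target
    #     backup = r + gamma * (1 - d) * (min(Q1_targ(o2, a2), Q2_targ(o2, a2)) - alpha * logp_a2),
    # where a2 is sampled fresh from the current policy at o2.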

    # Set up function for computing SAC pi loss
    def compute_loss_pi(data):
        o = data['obs']
        pi, logp_pi = ac.pi(o)
        q1_pi = ac.q1(o, pi)
        q2_pi = ac.q2(o, pi)
        q_pi = torch.min(q1_pi, q2_pi)

        # Entropy-regularized policy loss
        loss_pi = (alpha * logp_pi - q_pi).mean()

        # Useful info for logging
        pi_info = dict(LogPi=logp_pi.detach().cpu().numpy())

        return loss_pi, pi_info

    # Set up optimizers for policy and q-function
    pi_optimizer = Adam(ac.pi.parameters(), lr=lr)
    q_optimizer = Adam(q_params, lr=lr)

    # Set up model saving
    logger.setup_pytorch_saver(ac)

    def update(data):
        # First run one gradient descent step for Q1 and Q2
        q_optimizer.zero_grad()
        loss_q, q_info = compute_loss_q(data)
        loss_q.backward()
        q_optimizer.step()

        # Record things
        logger.store(LossQ=loss_q.item(), **q_info)

        # Freeze Q-networks so you don't waste computational effort
        # computing gradients for them during the policy learning step.
        for p in q_params:
            p.requires_grad = False

        # Next run one gradient descent step for pi.
        pi_optimizer.zero_grad()
        loss_pi, pi_info = compute_loss_pi(data)
        loss_pi.backward()
        pi_optimizer.step()

        # Unfreeze Q-networks so you can optimize them at the next update step.
        for p in q_params:
            p.requires_grad = True

        # Record things
        logger.store(LossPi=loss_pi.item(), **pi_info)

        # Finally, update target networks by polyak averaging.
        with torch.no_grad():
            for p, p_targ in zip(ac.parameters(), ac_targ.parameters()):
                # NB: We use in-place operations "mul_" and "add_" to update target
                # params, as opposed to "mul" and "add", which would make new tensors.
                p_targ.data.mul_(polyak)
                p_targ.data.add_((1 - polyak) * p.data)

    def get_action(o, deterministic=False):
        return ac.act(torch.as_tensor(o, dtype=torch.float32, device=device),
                      deterministic)

    def test_agent():
        for j in range(num_test_episodes):
            o, d, ep_ret, ep_len = test_env.reset(), False, 0, 0
            while not (d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time
                o, r, d, _ = test_env.step(get_action(o, True))
                ep_ret += r
                ep_len += 1
            logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)

    # Prepare for interaction with environment
    total_steps = steps_per_epoch * epochs
    start_time = time.time()
    o, ep_ret, ep_len = env.reset(), 0, 0

    # Main loop: collect experience in env and update/log each epoch
    for t in tqdm(range(total_steps)):

        # Until start_steps have elapsed, randomly sample actions
        # from a uniform distribution for better exploration. Afterwards,
        # use the learned policy.
        if t > start_steps:
            a = get_action(o)
        else:
            a = env.action_space.sample()

        # Step the env
        o2, r, d, _ = env.step(a)
        ep_ret += r
        ep_len += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if ep_len == max_ep_len else d

        # Store experience to replay buffer
        replay_buffer.store(o, a, r, o2, d)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2

        # End of trajectory handling
        if d or (ep_len == max_ep_len):
            logger.store(EpRet=ep_ret, EpLen=ep_len)
            o, ep_ret, ep_len = env.reset(), 0, 0

        # Update handling
        if t >= update_after and t % update_every == 0:
            for j in range(update_every):
                batch = replay_buffer.sample_batch(batch_size)
                update(data=batch)

        # End of epoch handling
        if (t + 1) % steps_per_epoch == 0:
            epoch = (t + 1) // steps_per_epoch

            # Save model
            if (epoch % save_freq == 0) or (epoch == epochs):
                # logger.save_state({'env': env, 'rb': replay_buffer.get_state()}, None)
                logger.save_state({'env': env}, None if override else epoch)

            # Test the performance of the deterministic version of the agent.
            test_agent()

            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            logger.log_tabular('TestEpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('TestEpLen', average_only=True)
            logger.log_tabular('TotalEnvInteracts', t)
            logger.log_tabular('Q1Vals', with_min_and_max=True)
            logger.log_tabular('Q2Vals', with_min_and_max=True)
            logger.log_tabular('LogPi', with_min_and_max=True)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQ', average_only=True)
            logger.log_tabular('Time', time.time() - start_time)
            logger.dump_tabular()
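
For orientation, here is a usage sketch of the sac function above. How this fork is actually launched is not shown in the snippet, so everything here is an assumption: the HalfCheetah-v2 environment is illustrative, and it presumes the standard Spinning Up helper setup_logger_kwargs is available.

import gym
from spinup.utils.run_utils import setup_logger_kwargs

# Hypothetical invocation; every concrete value below is illustrative.
logger_kwargs = setup_logger_kwargs('sac_halfcheetah', seed=0)
sac(lambda: gym.make('HalfCheetah-v2'),
    actor_critic=core.MLPActorCritic,
    ac_kwargs=dict(hidden_sizes=(256, 256)),
    gamma=0.99,
    epochs=50,
    logger_kwargs=logger_kwargs,
    device='cuda')
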
Code example #3
File: sac.py  Project: schrammlb2/spinningup-clone
def sac(env_fn,
        actor_critic=core.MLPActorCritic,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=4000,
        epochs=250,
        replay_size=int(1e6),
        gamma=0.99,
        polyak=0.995,
        lr=1e-3,
        alpha=0.2,
        batch_size=100,
        start_steps=10000,
        update_after=1000,
        update_every=50,
        num_test_episodes=3,
        max_ep_len=1000,
        logger_kwargs=dict(),
        save_freq=1,
        use_grad_penalty=True,
        penalty_scale=.025):
    """
    Soft Actor-Critic (SAC)


    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: The constructor method for a PyTorch Module with an ``act`` 
            method, a ``pi`` module, a ``q1`` module, and a ``q2`` module.
            The ``act`` method and ``pi`` module should accept batches of 
            observations as inputs, and ``q1`` and ``q2`` should accept a batch 
            of observations and a batch of actions as inputs. When called, 
            ``act``, ``q1``, and ``q2`` should return:

            ===========  ================  ======================================
            Call         Output Shape      Description
            ===========  ================  ======================================
            ``act``      (batch, act_dim)  | Numpy array of actions for each 
                                           | observation.
            ``q1``       (batch,)          | Tensor containing one current estimate
                                           | of Q* for the provided observations
                                           | and actions. (Critical: make sure to
                                           | flatten this!)
            ``q2``       (batch,)          | Tensor containing the other current 
                                           | estimate of Q* for the provided observations
                                           | and actions. (Critical: make sure to
                                           | flatten this!)
            ===========  ================  ======================================

            Calling ``pi`` should return:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``a``        (batch, act_dim)  | Tensor containing actions from policy
                                           | given observations.
            ``logp_pi``  (batch,)          | Tensor containing log probabilities of
                                           | actions in ``a``. Importantly: gradients
                                           | should be able to flow back into ``a``.
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the ActorCritic object 
            you provided to SAC.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs) 
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs to run and train agent.

        replay_size (int): Maximum length of replay buffer.

        gamma (float): Discount factor. (Always between 0 and 1.)

        polyak (float): Interpolation factor in polyak averaging for target 
            networks. Target networks are updated towards main networks 
            according to:

            .. math:: \\theta_{\\text{targ}} \\leftarrow 
                \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

            where :math:`\\rho` is polyak. (Always between 0 and 1, usually 
            close to 1.)

        lr (float): Learning rate (used for both policy and value learning).

        alpha (float): Entropy regularization coefficient. (Equivalent to 
            inverse of reward scale in the original SAC paper.)

        batch_size (int): Minibatch size for SGD.

        start_steps (int): Number of steps for uniform-random action selection,
            before running real policy. Helps exploration.

        update_after (int): Number of env interactions to collect before
            starting to do gradient descent updates. Ensures replay buffer
            is full enough for useful updates.

        update_every (int): Number of env interactions that should elapse
            between gradient descent updates. Note: Regardless of how long 
            you wait between updates, the ratio of env steps to gradient steps 
            is locked to 1.

        num_test_episodes (int): Number of episodes to test the deterministic
            policy at the end of each epoch.

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

    """
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    torch.manual_seed(seed)
    np.random.seed(seed)

    env, test_env = env_fn(), env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape[0]

    # Action limit for clamping: critically, assumes all dimensions share the same bound!
    act_limit = env.action_space.high[0]

    # Create actor-critic module and target networks
    ac = actor_critic(env.observation_space, env.action_space,
                      **ac_kwargs).to(device=DEVICE)
    ac_targ = deepcopy(ac).to(device=DEVICE)

    # Freeze target networks with respect to optimizers (only update via polyak averaging)
    for p in ac_targ.parameters():
        p.requires_grad = False

    clip_val = 10
    for p in ac.parameters():
        p.register_hook(lambda grad: torch.clamp(grad, -clip_val, clip_val))
        p.register_hook(lambda grad: torch.where(
            grad != grad, torch.tensor(0., device=DEVICE), grad))

    # List of parameters for both Q-networks (save this for convenience)
    q_params = itertools.chain(ac.q1.parameters(), ac.q2.parameters())

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim,
                                 act_dim=act_dim,
                                 size=replay_size)

    # Count variables (protip: try to get a feel for how different size networks behave!)
    var_counts = tuple(
        core.count_vars(module) for module in [ac.pi, ac.q1, ac.q2])
    logger.log('\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d\n' %
               var_counts)

    # Set up function for computing SAC Q-losses
    def compute_loss_q(data):
        o, a, r, o2, d = data['obs'], data['act'], data['rew'], data[
            'obs2'], data['done']

        q1 = ac.q1(o, a)
        q2 = ac.q2(o, a)

        # Bellman backup for Q functions
        if use_grad_penalty:
            a2, logp_a2 = ac.pi(o2)

            q1_pi_targ = gradient_penalty(ac_targ.q1,
                                          o2,
                                          a2,
                                          epsilon=penalty_scale)
            q2_pi_targ = gradient_penalty(ac_targ.q2,
                                          o2,
                                          a2,
                                          epsilon=penalty_scale)

            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = r + gamma * (1 - d) * (q_pi_targ - alpha * logp_a2)
        else:
            with torch.no_grad():
                # Target actions come from *current* policy
                a2, logp_a2 = ac.pi(o2)

                # Target Q-values
                q1_pi_targ = ac_targ.q1(o2, a2)
                q2_pi_targ = ac_targ.q2(o2, a2)

                q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
                backup = r + gamma * (1 - d) * (q_pi_targ - alpha * logp_a2)

        # MSE loss against Bellman backup
        loss_q1 = ((q1 - backup)**2).mean()
        loss_q2 = ((q2 - backup)**2).mean()
        loss_q = loss_q1 + loss_q2

        # Useful info for logging
        q_info = dict(Q1Vals=q1.cpu().detach().numpy(),
                      Q2Vals=q2.cpu().detach().numpy())

        state_grad_norm, action_grad_norm = gradient_norm(ac.q1, o, a)
        logger.store(StateGradNorm=state_grad_norm.cpu().detach().numpy())
        logger.store(ActionGradNorm=action_grad_norm.cpu().detach().numpy())

        return loss_q, q_info

    # Set up function for computing SAC pi loss
    def compute_loss_pi(data):
        o = data['obs']
        pi, logp_pi = ac.pi(o)
        q1_pi = ac.q1(o, pi)
        q2_pi = ac.q2(o, pi)
        q_pi = torch.min(q1_pi, q2_pi)

        # Clip to prevent NaN errors: -87.3 approximates log(x) at the smallest
        # x whose reciprocal 1/x is still representable in float32, so clamp
        # with a little headroom just in case.
        log_min = -87.3 + 10
        clamped_logp_pi = torch.clamp(logp_pi, log_min, -log_min)

        # Entropy-regularized policy loss
        loss_pi = (alpha * clamped_logp_pi - q_pi).mean()

        # Useful info for logging
        pi_info = dict(LogPi=logp_pi.cpu().detach().numpy())

        return loss_pi, pi_info

    # Set up optimizers for policy and q-function
    pi_optimizer = Adam(ac.pi.parameters(), lr=lr)
    q_optimizer = Adam(q_params, lr=lr)

    # Set up model saving
    logger.setup_pytorch_saver(ac)

    def update(data):
        # First run one gradient descent step for Q1 and Q2
        q_optimizer.zero_grad()
        loss_q, q_info = compute_loss_q(data)
        loss_q.backward()
        torch.nn.utils.clip_grad_value_(q_params, clip_val)
        q_optimizer.step()

        # Record things
        logger.store(LossQ=loss_q.item(), **q_info)

        # Freeze Q-networks so you don't waste computational effort
        # computing gradients for them during the policy learning step.
        for p in q_params:
            p.requires_grad = False

        # Next run one gradient descent step for pi.
        pi_optimizer.zero_grad()
        loss_pi, pi_info = compute_loss_pi(data)
        loss_pi.backward()
        torch.nn.utils.clip_grad_value_(ac.pi.parameters(), clip_val)
        pi_optimizer.step()

        # Unfreeze Q-networks so you can optimize them at the next update step.
        for p in q_params:
            p.requires_grad = True

        # Record things
        logger.store(LossPi=loss_pi.item(), **pi_info)

        # Finally, update target networks by polyak averaging.
        with torch.no_grad():
            for p, p_targ in zip(ac.parameters(), ac_targ.parameters()):
                # NB: We use in-place operations "mul_" and "add_" to update target
                # params, as opposed to "mul" and "add", which would make new tensors.
                p_targ.data.mul_(polyak)
                p_targ.data.add_((1 - polyak) * p.data)

    def get_action(o, deterministic=False):
        return ac.act(torch.as_tensor(o, dtype=torch.float32, device=DEVICE),
                      deterministic)

    def test_agent():
        for j in range(num_test_episodes):
            try:
                test_env = env_fn()
                o, d, ep_ret, ep_len = test_env.reset(), False, 0, 0
            except:
                pdb.set_trace()

            while not (d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time
                o, r, d, _ = test_env.step(get_action(o, True))
                ep_ret += r
                ep_len += 1
            logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)

    def test_agent_transfer(
        scale=.1,
        log_lambda=lambda a, b: logger.store(TransferEpRet=a, TransferEpLen=b)
    ):
        import time
        for j in range(num_test_episodes):
            succeeded = False
            i = 0
            while not succeeded:
                try:
                    test_env = env_fn(transfer=True, scale=scale)
                    o, d, ep_ret, ep_len = test_env.reset(), False, 0, 0
                    clear_xml(test_env)
                    succeeded = True
                    # When parallelizing, sometimes different workers touch the
                    # same env assets, so retry on failure.
                except:
                    i += 1
                    if i > 10:
                        pdb.set_trace()
                    time.sleep(1)

            while not (d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time
                o, r, d, _ = test_env.step(get_action(o, True))
                ep_ret += r
                ep_len += 1
            log_lambda(ep_ret, ep_len)
            # logger.store(TransferEpRet=ep_ret, TransferEpLen=ep_len)
        # logger.store(WorstTransferEpRet=worst_case)

    def test_agent_random(
        scale=.1,
        log_lambda=lambda a, b: logger.store(RandomEpRet=a, RandomEpLen=b)):
        for j in range(num_test_episodes):
            try:
                o, d, ep_ret, ep_len = test_env.reset(), False, 0, 0
            except:
                pdb.set_trace()
            o += np.random.normal(0, scale, o.shape)
            while not (d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time
                o, r, d, _ = test_env.step(get_action(o, True))
                o += np.random.normal(0, scale, o.shape)
                ep_ret += r
                ep_len += 1

            log_lambda(ep_ret, ep_len)

    # def test_agent_adversarial_noise():
    #     def adv_step(o):
    #         tens_o = torch.as_tensor(o, device=DEVICE)
    #         def v(obs):
    #             action = ac.pi(obs, deterministic=True, with_logprob=False)[0]
    #             return ac.q1(tens_o, action)
    #         # v = lambda obs: ac.q1(tens_o, ac.pi(obs, deterministic=True, with_logprob=False))
    #             #Value of policy given perturbed observation
    #         adv_obs = state_gradient(v, tens_o, epsilon=2e-2)
    #             #Bounded adversarial perturbation to observation
    #         return adv_obs.cpu().numpy()
    #     for j in range(num_test_episodes):
    #         o, d, ep_ret, ep_len = test_env.reset(), False, 0, 0
    #         o = adv_step(o)
    #         while not(d or (ep_len == max_ep_len)):
    #             # Take deterministic actions at test time
    #             o, r, d, _ = test_env.step(get_action(o, True))
    #             o = adv_step(o)
    #             ep_ret += r
    #             ep_len += 1
    #         logger.store(AdvEpRet=ep_ret, AdvEpLen=ep_len)

    # Prepare for interaction with environment
    total_steps = steps_per_epoch * epochs
    start_time = time.time()
    o, ep_ret, ep_len = env.reset(), 0, 0

    # Main loop: collect experience in env and update/log each epoch
    for t in range(total_steps):

        # Until start_steps have elapsed, randomly sample actions
        # from a uniform distribution for better exploration. Afterwards,
        # use the learned policy.
        if t > start_steps:
            a = get_action(o)
        else:
            a = env.action_space.sample()

        # Step the env
        o2, r, d, _ = env.step(a)
        ep_ret += r
        ep_len += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if ep_len == max_ep_len else d

        # Store experience to replay buffer
        replay_buffer.store(o, a, r, o2, d)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2

        # End of trajectory handling
        if d or (ep_len == max_ep_len):
            logger.store(EpRet=ep_ret, EpLen=ep_len)
            try:
                o, ep_ret, ep_len = env.reset(), 0, 0
            except:
                pdb.set_trace()

        # Update handling
        if t >= update_after and t % update_every == 0:
            for j in range(update_every):
                batch = replay_buffer.sample_batch(batch_size)
                update(data=batch)

        # End of epoch handling
        if (t + 1) % steps_per_epoch == 0:
            epoch = (t + 1) // steps_per_epoch

            # Save model
            if (epoch % save_freq == 0) or (epoch == epochs):
                logger.save_state({'env': env}, None)

            # Test the performance of the deterministic version of the agent.
            # scales = [(i+1) for i in range(5)]
            test_agent()

            test_agent_transfer(scale=1 * .1,
                                log_lambda=lambda a, b: logger.store(
                                    Transfer1EpRet=a, Transfer1EpLen=b))
            test_agent_transfer(scale=3 * .1,
                                log_lambda=lambda a, b: logger.store(
                                    Transfer3EpRet=a, Transfer3EpLen=b))
            test_agent_transfer(scale=5 * .1,
                                log_lambda=lambda a, b: logger.store(
                                    Transfer5EpRet=a, Transfer5EpLen=b))
            test_agent_transfer(scale=7 * .1,
                                log_lambda=lambda a, b: logger.store(
                                    Transfer7EpRet=a, Transfer7EpLen=b))

            test_agent_random(scale=1 * .03,
                              log_lambda=lambda a, b: logger.store(
                                  Random1EpRet=a, Random1EpLen=b))
            test_agent_random(scale=3 * .03,
                              log_lambda=lambda a, b: logger.store(
                                  Random3EpRet=a, Random3EpLen=b))
            test_agent_random(scale=5 * .03,
                              log_lambda=lambda a, b: logger.store(
                                  Random5EpRet=a, Random5EpLen=b))
            test_agent_random(scale=7 * .03,
                              log_lambda=lambda a, b: logger.store(
                                  Random7EpRet=a, Random7EpLen=b))

            # transfer_ep_ret_names = ['Transfer' + str(scale) + 'EpRet' for scale in scales]
            # transfer_ep_len_names = ['Transfer' + str(scale) + 'EpLen' for scale in scales]
            # # transfer_logger_lambda = [lambda ep_ret, ep_len: \
            # #     logger.store(**{'Transfer' + str(scale) + 'EpRet': ep_ret, 'Transfer' + str(scale) + 'EpLen': ep_len}) \
            # #     for scale in scales]
            # transfer_logger_lambda = [lambda ep_ret, ep_len: \
            #     logger.store(**{trans_ep_ret: ep_ret, trans_ep_len: ep_len}) \
            #     for trans_ep_ret, trans_ep_len in zip(transfer_ep_ret_names, transfer_ep_len_names)]

            # for scale, lam in zip(scales, transfer_logger_lambda):
            #     test_agent_transfer(scale=scale*.1, log_lambda=lam)

            # random_logger_lambda = [lambda ep_ret, ep_len: \
            #     logger.store(**{'Random' + str(scale) + 'EpRet': ep_ret, 'Random' + str(scale) + 'EpLen': ep_len}) \
            #     for scale in scales]

            # for scale, lam in zip(scales, random_logger_lambda):
            #     test_agent_random(scale=scale*.03, log_lambda=lam)

            scales = [1, 3, 5, 7]
            ep_ret_names = [
                task + str(scale) + 'EpRet' for task in ['Transfer', 'Random']
                for scale in scales
            ]
            ep_len_names = [
                task + str(scale) + 'EpLen' for task in ['Transfer', 'Random']
                for scale in scales
            ]

            for ret in ep_ret_names:
                logger.log_tabular(ret, with_min_and_max=True)
            for length in ep_len_names:
                logger.log_tabular(length, average_only=True)
            # test_agent_adversarial_noise()

            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            # logger.log_tabular('TestEpRet', with_min_and_max=True)
            # logger.log_tabular('TransferEpRet', with_min_and_max=True)
            # logger.log_tabular('RandomEpRet', with_min_and_max=True)
            # logger.log_tabular('AdvEpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            # logger.log_tabular('TestEpLen', average_only=True)
            # logger.log_tabular('TransferEpLen', average_only=True)
            # logger.log_tabular('RandomEpLen', average_only=True)
            # logger.log_tabular('AdvEpLen', average_only=True)
            logger.log_tabular('TotalEnvInteracts', t)
            logger.log_tabular('Q1Vals', with_min_and_max=True)
            logger.log_tabular('Q2Vals', with_min_and_max=True)
            logger.log_tabular('LogPi', with_min_and_max=True)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQ', average_only=True)

            logger.log_tabular('StateGradNorm', with_min_and_max=True)
            logger.log_tabular('ActionGradNorm', with_min_and_max=True)

            logger.log_tabular('Time', time.time() - start_time)
            logger.dump_tabular()
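
A note on the register_hook calls near the top of this example: each parameter's gradient is first clamped to [-clip_val, clip_val] and any NaN entries are then replaced with zero (grad != grad is true only for NaNs; tensor hooks run in registration order, each receiving the previous hook's output). A minimal self-contained check of that behaviour, assuming CPU tensors instead of the project's DEVICE:

import torch

clip_val = 10
p = torch.nn.Parameter(torch.zeros(3))
# Same pair of hooks as in the SAC code above, minus the device argument.
p.register_hook(lambda grad: torch.clamp(grad, -clip_val, clip_val))
p.register_hook(lambda grad: torch.where(grad != grad, torch.tensor(0.), grad))

# Backpropagate a gradient containing a huge value and a NaN entry.
(p * torch.tensor([1e3, float('nan'), -1e3])).sum().backward()
print(p.grad)  # expected: tensor([10., 0., -10.]) -- clipped, NaN zeroed
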
Code example #4
def sac_discrete(env_fn,
                 Actor=core.DiscreteMLPActor,
                 Critic=core.DiscreteMLPQFunction,
                 ac_kwargs=dict(),
                 seed=0,
                 steps_per_epoch=4000,
                 epochs=100,
                 replay_size=int(1e6),
                 gamma=0.99,
                 polyak=0.995,
                 lr=0.0003,
                 alpha=0.2,
                 batch_size=100,
                 start_steps=1000,
                 update_after=1000,
                 update_times_every_step=50,
                 num_test_episodes=10,
                 max_ep_len=1000,
                 logger_kwargs=dict(),
                 save_freq=1,
                 automatic_entropy_tuning=True,
                 use_gpu=False,
                 gpu_parallel=False,
                 total_eps=100):
    """
    Soft Actor-Critic (SAC)


    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: The constructor method for a PyTorch Module with an ``act``
            method, a ``pi`` module, a ``q1`` module, and a ``q2`` module.
            The ``act`` method and ``pi`` module should accept batches of
            observations as inputs, and ``q1`` and ``q2`` should accept a batch
            of observations and a batch of actions as inputs. When called,
            ``act``, ``q1``, and ``q2`` should return:

            ===========  ================  ======================================
            Call         Output Shape      Description
            ===========  ================  ======================================
            ``act``      (batch, act_dim)  | Numpy array of actions for each
                                           | observation.
            ``q1``       (batch,)          | Tensor containing one current estimate
                                           | of Q* for the provided observations
                                           | and actions. (Critical: make sure to
                                           | flatten this!)
            ``q2``       (batch,)          | Tensor containing the other current
                                           | estimate of Q* for the provided observations
                                           | and actions. (Critical: make sure to
                                           | flatten this!)
            ===========  ================  ======================================

            Calling ``pi`` should return:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``a``        (batch, act_dim)  | Tensor containing actions from policy
                                           | given observations.
            ``logp_pi``  (batch,)          | Tensor containing log probabilities of
                                           | actions in ``a``. Importantly: gradients
                                           | should be able to flow back into ``a``.
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the ActorCritic object
            you provided to SAC.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs)
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs to run and train agent.

        replay_size (int): Maximum length of replay buffer.

        gamma (float): Discount factor. (Always between 0 and 1.)

        polyak (float): Interpolation factor in polyak averaging for target
            networks. Target networks are updated towards main networks
            according to:

            .. math:: \\theta_{\\text{targ}} \\leftarrow
                \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

            where :math:`\\rho` is polyak. (Always between 0 and 1, usually
            close to 1.)

        lr (float): Learning rate (used for both policy and value learning).

        alpha (float): Entropy regularization coefficient. (Equivalent to
            inverse of reward scale in the original SAC paper.)

        batch_size (int): Minibatch size for SGD.

        start_steps (int): Number of steps for uniform-random action selection,
            before running real policy to collect data. Helps exploration.

        update_after (int): Number of env interactions to collect before
            starting to do gradient descent updates. Ensures replay buffer
            is full enough for useful updates.

        update_times_every_step (int): Number of env interactions that should elapse
            between gradient descent updates. Note: Regardless of how long
            you wait between updates, the ratio of env steps to gradient steps
            is locked to 1.

        num_test_episodes (int): Number of episodes to test the deterministic
            policy at the end of each epoch.

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

    """

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    torch.manual_seed(seed)
    np.random.seed(seed)

    env, test_env = env_fn(), env_fn()
    env.seed(seed)
    # env.seed(seed)
    # test_env.seed(seed)
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.n

    # Create actor-critic module and target networks
    actor = Actor(obs_dim[0], act_dim, **ac_kwargs)
    critic1 = Critic(obs_dim[0], act_dim, **ac_kwargs)
    critic2 = Critic(obs_dim[0], act_dim, **ac_kwargs)
    critic1_targ = deepcopy(critic1)
    critic2_targ = deepcopy(critic2)
    # whether to use the GPU
    if torch.cuda.is_available():
        device = torch.device("cuda" if use_gpu else "cpu")
        if gpu_parallel:
            actor = torch.nn.DataParallel(actor)
            critic1 = torch.nn.DataParallel(critic1)
            critic2 = torch.nn.DataParallel(critic2)
            critic1_targ = torch.nn.DataParallel(critic1_targ)
            critic2_targ = torch.nn.DataParallel(critic2_targ)
    else:
        use_gpu = False
        gpu_parallel = False
        device = torch.device("cpu")
    # Freeze target networks with respect to optimizers (only update via polyak averaging)
    for p in critic1_targ.parameters():
        p.requires_grad = False
    for p in critic2_targ.parameters():
        p.requires_grad = False
    actor.to(device)
    critic1.to(device)
    critic2.to(device)
    critic1_targ.to(device)
    critic2_targ.to(device)

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim,
                                 act_dim=1,
                                 size=replay_size,
                                 device=device)

    if automatic_entropy_tuning:
        # set the target entropy to 98% of the maximum possible entropy, log(|A|)
        target_entropy = -np.log((1.0 / act_dim)) * 0.98
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha = log_alpha.exp()
        alpha_optim = Adam([log_alpha], lr=lr, eps=1e-4)

    # Count variables (protip: try to get a feel for how different size networks behave!)
    var_counts = tuple(
        core.count_vars(module) for module in [actor, critic1, critic2])
    logger.log('\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d\n' %
               var_counts)

    # Set up optimizers for policy and q-function
    pi_optimizer = Adam(actor.parameters(), lr=lr)
    q1_optimizer = Adam(critic1.parameters(), lr=lr)
    q2_optimizer = Adam(critic2.parameters(), lr=lr)

    # Set up function for computing SAC Q-losses
    def compute_loss_q(data):

        # Bellman backup for Q functions
        with torch.no_grad():
            o, a, r, o2, d = data['obs'], data['act'], data['rew'], data[
                'obs2'], data['done']
            if r.ndim == 1:
                r = r.unsqueeze(-1)
            if d.ndim == 1:
                d = d.unsqueeze(-1)
            # Target actions come from *current* policy
            a2, (a2_p, logp_a2), _ = get_action(o2)

            # Target Q-values
            q1_pi_targ = critic1_targ(o2)
            q2_pi_targ = critic2_targ(o2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            min_qf_next_target = a2_p * (q_pi_targ - alpha * logp_a2)
            min_qf_next_target = min_qf_next_target.mean(dim=1).unsqueeze(-1)
            backup = r + gamma * (1 - d) * min_qf_next_target

        q1 = critic1(o).gather(1, a.long())
        q2 = critic2(o).gather(1, a.long())
        # MSE loss against Bellman backup
        loss_q1 = F.mse_loss(q1, backup)
        loss_q2 = F.mse_loss(q2, backup)

        # Useful info for logging
        q_info = dict(Q1Vals=q1.detach().cpu().numpy(),
                      Q2Vals=q2.detach().cpu().numpy())

        return loss_q1, loss_q2, q_info

    # Set up function for computing SAC pi loss
    def compute_loss_pi(data):
        state_batch = data['obs']
        action, (action_probabilities,
                 log_action_probabilities), _ = get_action(state_batch)
        qf1_pi = critic1(state_batch)
        qf2_pi = critic2(state_batch)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)
        inside_term = alpha * log_action_probabilities - min_qf_pi
        policy_loss = action_probabilities * inside_term
        policy_loss = policy_loss.mean()
        log_action_probabilities = torch.sum(log_action_probabilities *
                                             action_probabilities,
                                             dim=1)
        # Useful info for logging
        pi_info = dict(LogPi=log_action_probabilities.detach().cpu().numpy())

        return policy_loss, log_action_probabilities, pi_info

    def take_optimisation_step(optimizer,
                               network,
                               loss,
                               clipping_norm=None,
                               retain_graph=False):
        if not isinstance(network, list):
            network = [network]
        optimizer.zero_grad()  # reset gradients to 0
        loss.backward(
            retain_graph=retain_graph)  # this calculates the gradients
        if clipping_norm is not None:
            for net in network:
                torch.nn.utils.clip_grad_norm_(
                    net.parameters(),
                    clipping_norm)  # clip gradients to help stabilise training
        optimizer.step()  # this applies the gradients

    def soft_update_of_target_network(local_model, target_model, tau):
        """Updates the target network in the direction of the local network but by taking a step size
        less than one so the target network's parameter values trail the local networks. This helps stabilise training"""
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
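    # Note: this helper is equivalent to the in-place polyak update performed in
    # update() below with tau = 1 - polyak; it is kept here as a reusable utility.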

    def update(data):
        nonlocal alpha  # so the entropy-tuned alpha below updates the value used in the loss functions
        # First run one gradient descent step for Q1 and Q2

        loss_q1, loss_q2, q_info = compute_loss_q(data)

        # Record things
        # logger.store(LossQ=(loss_q1.item()+loss_q2.item())/2., **q_info)

        # Freeze Q-networks so you don't waste computational effort
        # computing gradients for them during the policy learning step.
        # for p in q_params:
        #     p.requires_grad = False

        # Next run one gradient descent step for pi.

        loss_pi, log_pi, pi_info = compute_loss_pi(data)

        # Unfreeze Q-networks so you can optimize it at next DDPG step.
        # for p in q_params:
        #     p.requires_grad = True

        # Record things
        # logger.store(LossPi=loss_pi.item(), **pi_info)

        if automatic_entropy_tuning:

            alpha_loss = -(log_alpha *
                           (log_pi + target_entropy).detach()).mean()
            # logger.store(alpha_loss=alpha_loss.item())

        take_optimisation_step(q1_optimizer, critic1, loss_q1, clipping_norm=5)
        take_optimisation_step(q2_optimizer, critic2, loss_q2, clipping_norm=5)
        take_optimisation_step(pi_optimizer, actor, loss_pi, clipping_norm=5)

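        # Finally, update target networks by polyak averaging.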
        with torch.no_grad():
            for p, p_targ in zip(critic1.parameters(),
                                 critic1_targ.parameters()):
                # NB: We use in-place operations "mul_" and "add_" to update target
                # params, as opposed to "mul" and "add", which would create new tensors.
                p_targ.data.mul_(polyak)
                p_targ.data.add_((1 - polyak) * p.data)
            for p, p_targ in zip(critic2.parameters(),
                                 critic2_targ.parameters()):
                # NB: We use in-place operations "mul_" and "add_" to update target
                # params, as opposed to "mul" and "add", which would create new tensors.
                p_targ.data.mul_(polyak)
                p_targ.data.add_((1 - polyak) * p.data)

        if automatic_entropy_tuning:
            take_optimisation_step(alpha_optim, None, alpha_loss, None)
            alpha = log_alpha.exp()

    def get_action(state):
        """Given the state, produces an action, the probability of the action, the log probability of the action, and
        the argmax action"""
        action_probabilities = actor(state)
        max_probability_action = torch.argmax(action_probabilities, dim=-1)
        action_distribution = Categorical(action_probabilities)
        action = action_distribution.sample().cpu()
        # Have to deal with situation of 0.0 probabilities because we can't do log 0
        z = action_probabilities == 0.0
        z = z.float() * 1e-8
        log_action_probabilities = torch.log(action_probabilities + z)
        return action, (action_probabilities,
                        log_action_probabilities), max_probability_action

    def test_agent():
        for j in range(num_test_episodes):
            o, d, ep_ret, ep_len = test_env.reset(), False, 0, 0
            while not (d or (ep_len == max_ep_len)):
                test_env.render()
                # Take deterministic actions at test time
                o = torch.FloatTensor([o]).to(device)
                _, z, a = get_action(o)
                o, r, d, _ = test_env.step(a.item())
                ep_ret += r
                ep_len += 1
            logger.store(EpRet=ep_ret, EpLen=ep_len)

    # Prepare for interaction with environment
    total_steps = steps_per_epoch * epochs
    start_time = time.time()

    # Main loop: collect experience in env and update/log each epoch

    o, ep_ret, ep_len = env.reset(), 0, 0
    eps = 0
    t = 0
    while eps < total_eps:
        # Until start_steps have elapsed, randomly sample actions
        # from a uniform distribution for better exploration. Afterwards,
        # use the learned policy.
        # print(t)
        if t >= start_steps:
            with torch.no_grad():
                if o.ndim == 1:
                    a, _, _ = get_action(torch.FloatTensor([o]).to(device))
                else:
                    a, _, _ = get_action(o)
                a = a.cpu().item()
        else:
            a = np.random.randint(0, act_dim)

        # Step the env
        o2, r, d, _ = env.step(a)
        ep_ret += r
        ep_len += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if ep_len == max_ep_len else d

        # Store experience to replay buffer

        replay_buffer.store(o, a, r, o2, d)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2

        # Update handling
        if t >= update_after and t % update_times_every_step == 0 and t > batch_size:
            for j in range(update_times_every_step):
                batch = replay_buffer.sample_batch(batch_size)
                update(data=batch)

        # End of trajectory handling
        if d or (ep_len == max_ep_len):  # ep_len == max_ep_len is the minimum episode length when the game succeeds
            # logger.store(EpRet=ep_ret, EpLen=ep_len)
            o, ep_ret, ep_len = env.reset(), 0, 0
            if eps % 10 == 0 and eps != 0:
                actor.eval()
                test_agent()
                actor.train()
                logger.store(step=t)
                logger.log_tabular('Epoch', eps)
                logger.log_tabular('EpRet', with_min_and_max=True)
                logger.log_tabular('Time', time.time() - start_time)
                logger.dump_tabular()
            eps += 1

        t += 1
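
For reference, here is a minimal standalone sketch (not part of the original example; it only assumes the (batch, act_dim) tensor shapes used above) of the expected soft state value that compute_loss_q bootstraps from:

# Illustrative helper, assuming probs, q1 and q2 all have shape (batch, act_dim)
# and alpha is a scalar. It computes
#   V(s) = sum_a pi(a|s) * (min_i Q_i(s, a) - alpha * log pi(a|s)),
# i.e. the quantity used as min_qf_next_target above.
import torch

def soft_state_value(probs, q1, q2, alpha):
    min_q = torch.min(q1, q2)                     # clipped double-Q estimate
    log_probs = torch.log(probs.clamp_min(1e-8))  # guard against log(0)
    return (probs * (min_q - alpha * log_probs)).sum(dim=1, keepdim=True)
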
Code Example #5
0
def sac(env_fn,
        env_name,
        test_env_fns=[],
        actor_critic=core.MLPActorCritic,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=4000,
        epochs=100,
        replay_size=int(1e6),
        gamma=0.99,
        polyak=0.995,
        lr=1e-3,
        alpha=0.2,
        batch_size=100,
        start_steps=10000,
        update_after=1000,
        update_every=50,
        num_test_episodes=10,
        max_ep_len=1000,
        logger_kwargs=dict(),
        save_freq=1,
        load_dir=None,
        num_procs=1,
        clean_every=200):
    """
    Soft Actor-Critic (SAC)


    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: The constructor method for a PyTorch Module with an ``act``
            method, a ``pi`` module, a ``q1`` module, and a ``q2`` module.
            The ``act`` method and ``pi`` module should accept batches of
            observations as inputs, and ``q1`` and ``q2`` should accept a batch
            of observations and a batch of actions as inputs. When called,
            ``act``, ``q1``, and ``q2`` should return:

            ===========  ================  ======================================
            Call         Output Shape      Description
            ===========  ================  ======================================
            ``act``      (batch, act_dim)  | Numpy array of actions for each
                                           | observation.
            ``q1``       (batch,)          | Tensor containing one current estimate
                                           | of Q* for the provided observations
                                           | and actions. (Critical: make sure to
                                           | flatten this!)
            ``q2``       (batch,)          | Tensor containing the other current
                                           | estimate of Q* for the provided observations
                                           | and actions. (Critical: make sure to
                                           | flatten this!)
            ===========  ================  ======================================

            Calling ``pi`` should return:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``a``        (batch, act_dim)  | Tensor containing actions from policy
                                           | given observations.
            ``logp_pi``  (batch,)          | Tensor containing log probabilities of
                                           | actions in ``a``. Importantly: gradients
                                           | should be able to flow back into ``a``.
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the ActorCritic object
            you provided to SAC.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs)
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs to run and train agent.

        replay_size (int): Maximum length of replay buffer.

        gamma (float): Discount factor. (Always between 0 and 1.)

        polyak (float): Interpolation factor in polyak averaging for target
            networks. Target networks are updated towards main networks
            according to:

            .. math:: \\theta_{\\text{targ}} \\leftarrow
                \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

            where :math:`\\rho` is polyak. (Always between 0 and 1, usually
            close to 1.)

        lr (float): Learning rate (used for both policy and value learning).

        alpha (float): Entropy regularization coefficient. (Equivalent to
            inverse of reward scale in the original SAC paper.)

        batch_size (int): Minibatch size for SGD.

        start_steps (int): Number of steps for uniform-random action selection,
            before running real policy. Helps exploration.

        update_after (int): Number of env interactions to collect before
            starting to do gradient descent updates. Ensures replay buffer
            is full enough for useful updates.

        update_every (int): Number of env interactions that should elapse
            between gradient descent updates. Note: Regardless of how long
            you wait between updates, the ratio of env steps to gradient steps
            is locked to 1.

        num_test_episodes (int): Number of episodes to test the deterministic
            policy at the end of each epoch.

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

        env_name (str): Name of the environment; stored alongside the saved
            training state.

        test_env_fns (list): Functions that each create one evaluation
            environment; they are run together in a vectorized test env.

        load_dir (str): If not None, directory from which a previously saved
            policy is loaded instead of constructing a fresh actor-critic.

        num_procs (int): Number of parallel training environments.

        clean_every (int): Recreate the subprocess environments roughly every
            this many epochs to limit resource build-up; values <= 0 disable
            the cleanup.

    """
    from spinup.examples.pytorch.eval_sac import load_pytorch_policy

    print(f"SAC proc_id {proc_id()}")
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())
    if proc_id() == 0:
        writer = SummaryWriter(log_dir=os.path.join(
            logger.output_dir, str(datetime.datetime.now())),
                               comment=logger_kwargs["exp_name"])

    torch.manual_seed(seed)
    np.random.seed(seed)

    env = SubprocVecEnv([partial(env_fn, rank=i) for i in range(num_procs)],
                        "spawn")
    test_env = SubprocVecEnv(test_env_fns, "spawn")
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape[0]

    # Action limit for clamping: critically, assumes all dimensions share the same bound!
    act_limit = env.action_space.high[0]

    # Create actor-critic module and target networks
    if load_dir is not None:
        _, ac = load_pytorch_policy(load_dir, itr="", deterministic=False)
    else:
        ac = actor_critic(env.observation_space, env.action_space, **ac_kwargs)
    ac_targ = deepcopy(ac)

    # Freeze target networks with respect to optimizers (only update via polyak averaging)
    for p in ac_targ.parameters():
        p.requires_grad = False

    # List of parameters for both Q-networks (save this for convenience)
    q_params = itertools.chain(ac.q1.parameters(), ac.q2.parameters())

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim,
                                 act_dim=act_dim,
                                 size=replay_size)

    # Count variables (protip: try to get a feel for how different size networks behave!)
    var_counts = tuple(
        core.count_vars(module) for module in [ac.pi, ac.q1, ac.q2])
    logger.log('\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d\n' %
               var_counts)

    # Set up function for computing SAC Q-losses
    def compute_loss_q(data):
        o, a, r, o2, d = data['obs'], data['act'], data['rew'], data[
            'obs2'], data['done']

        q1 = ac.q1(o, a)
        q2 = ac.q2(o, a)

        # Bellman backup for Q functions
        with torch.no_grad():
            # Target actions come from *current* policy
            a2, logp_a2 = ac.pi(o2)

            # Target Q-values
            q1_pi_targ = ac_targ.q1(o2, a2)
            q2_pi_targ = ac_targ.q2(o2, a2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
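            # Soft Bellman target: r + gamma * (1 - d) * (min_i Q_i(s', a') - alpha * log pi(a'|s'))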
            backup = r + gamma * (1 - d) * (q_pi_targ - alpha * logp_a2)

        # MSE loss against Bellman backup
        loss_q1 = ((q1 - backup)**2).mean()
        loss_q2 = ((q2 - backup)**2).mean()
        loss_q = loss_q1 + loss_q2

        # Useful info for logging
        q_info = dict(Q1Vals=q1.detach().numpy(), Q2Vals=q2.detach().numpy())

        return loss_q, q_info

    # Set up function for computing TD feats-losses
    def compute_loss_feats(data):
        o, a, r, o2, d, feats = data['obs'], data['act'], data['rew'], data[
            'obs2'], data['done'], data["feats"]

        feats = torch.stack(list(feats.values())).T  # (nbatch, nfeats)
        feats1 = ac.q1.predict_feats(o, a)
        feats2 = ac.q2.predict_feats(o, a)

        feats_keys = replay_buffer.feats_keys

        # Bellman backup for feature functions
        with torch.no_grad():
            a2, _ = ac.pi(o2)

            # Target feature values
            feats1_targ = ac_targ.q1.predict_feats(o2, a2)
            feats2_targ = ac_targ.q2.predict_feats(o2, a2)
            feats_targ = torch.min(feats1_targ, feats2_targ)
            backup = feats + gamma * (1 - d[:, None]) * feats_targ

        # MSE loss against Bellman backup
        loss_feats1 = ((feats1 - backup)**2).mean(axis=0)
        loss_feats2 = ((feats2 - backup)**2).mean(axis=0)
        loss_feats = loss_feats1 + loss_feats2

        # Useful info for logging
        feats_info = dict(Feats1Vals=feats1.detach().numpy(),
                          Feats2Vals=feats2.detach().numpy())

        return loss_feats, feats_info
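    # compute_loss_feats builds a vector-valued TD target (one entry per feature):
    # backup = feats + gamma * (1 - d) * min(feats1_targ, feats2_targ),
    # i.e. the same bootstrapping as the Q backup but without the entropy term.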

    # Set up function for computing SAC pi loss
    def compute_loss_pi(data):
        o = data['obs']
        pi, logp_pi = ac.pi(o)
        q1_pi = ac.q1(o, pi)
        q2_pi = ac.q2(o, pi)
        q_pi = torch.min(q1_pi, q2_pi)

        # Entropy-regularized policy loss
        loss_pi = (alpha * logp_pi - q_pi).mean()

        # Useful info for logging
        pi_info = dict(LogPi=logp_pi.detach().numpy())

        return loss_pi, pi_info

    # Set up optimizers for policy and q-function
    pi_optimizer = Adam(ac.pi.parameters(), lr=lr)
    q_optimizer = Adam(q_params, lr=lr)

    # Set up model saving
    logger.setup_pytorch_saver(ac)

    def update(data, feats_keys):
        # First run one gradient descent step for Q1 and Q2
        q_optimizer.zero_grad()
        loss_q, q_info = compute_loss_q(data)
        loss_q.backward()
        loss_feats, feats_info = compute_loss_feats(data)
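        # Note: loss_feats is only logged below; no gradient step is taken on it here.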
        q_optimizer.step()

        # Record things
        logger.store(LossQ=loss_q.item(), **q_info)

        # Feature loss
        keys = [f"LossFeats_{key}" for key in feats_keys]
        for key, val in zip(keys, loss_feats):
            logger.store(**{key: val.item()})

        # Freeze Q-networks so you don't waste computational effort
        # computing gradients for them during the policy learning step.
        for p in q_params:
            p.requires_grad = False

        # Next run one gradient descent step for pi.
        pi_optimizer.zero_grad()
        loss_pi, pi_info = compute_loss_pi(data)
        loss_pi.backward()
        pi_optimizer.step()

        # Unfreeze Q-networks so you can optimize it at next DDPG step.
        for p in q_params:
            p.requires_grad = True

        # Record things
        logger.store(LossPi=loss_pi.item(), **pi_info)

        # Finally, update target networks by polyak averaging.
        with torch.no_grad():
            for p, p_targ in zip(ac.parameters(), ac_targ.parameters()):
                # NB: We use in-place operations "mul_" and "add_" to update target
                # params, as opposed to "mul" and "add", which would create new tensors.
                p_targ.data.mul_(polyak)
                p_targ.data.add_((1 - polyak) * p.data)

    def get_action(o, deterministic=False):
        return ac.act(torch.as_tensor(o, dtype=torch.float32), deterministic)

    def test_agent(feats_keys):
        num_envs = len(test_env_fns)
        env_ep_rets = np.zeros(num_envs)
        for j in range(num_test_episodes):
            o, d = test_env.reset(), np.zeros(num_envs, dtype=bool)
            ep_len = np.zeros(num_envs)
            while not (np.all(d) or np.all(ep_len == max_ep_len)):
                # Take deterministic actions at test time
                o, r, d, info = test_env.step(get_action(o, True))
                env_ep_rets += r
                ep_len += 1
            # logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
        for ti in range(num_envs):
            logger.store(
                **{f"TestEpRet_{ti}": env_ep_rets[ti] / num_test_episodes})

    # Prepare for interaction with environment
    total_steps = steps_per_epoch * epochs
    start_time = time.time()
    o, ep_ret, ep_len = env.reset(), np.zeros(num_procs), np.zeros(num_procs)

    # Main loop: collect experience in env and update/log each epoch
    epoch = 0
    update_times, clean_times = 0, 0
    t = 0
    while t <= total_steps:
        # Until start_steps have elapsed, randomly sample actions
        # from a uniform distribution for better exploration. Afterwards,
        # use the learned policy.
        if t > start_steps:
            a = get_action(o)
        else:
            a = np.stack([env.action_space.sample() for _ in range(num_procs)])

        # Step the env
        o2, r, d, info = env.step(a)
        ep_ret += r
        ep_len += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        if np.all(ep_len == max_ep_len):
            d.fill(False)

        # Store experience to replay buffer
        replay_buffer.store_vec(o, a, r, o2, d,
                                [inf["features"] for inf in info])

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2

        # End of trajectory handling, assumes all subenvs end at the same time
        if np.all(d) or np.all(ep_len == max_ep_len):
            logger.store(EpRet=ep_ret, EpLen=ep_len)

            if clean_every > 0 and epoch // clean_every >= clean_times:
                env.close()
                test_env.close()
                env = SubprocVecEnv(
                    [partial(env_fn, rank=i) for i in range(num_procs)],
                    "spawn")
                test_env = SubprocVecEnv(test_env_fns, "spawn")
                clean_times += 1

            o, ep_ret, ep_len = env.reset(), np.zeros(num_procs), np.zeros(
                num_procs)

        # Update handling
        if t >= update_after and t / update_every > update_times:
            for j in range(update_every):
                batch = replay_buffer.sample_batch(batch_size)
                update(data=batch, feats_keys=replay_buffer.feats_keys)
            update_times += 1

        # End of epoch handling
        if t // steps_per_epoch > epoch:
            epoch = t // steps_per_epoch

            # Save model
            if (epoch % save_freq == 0) or (epoch == epochs):
                # try:
                logger.save_state({'env_name': env_name}, None)
                # logger.save_state({'env': env}, None)
                #except:
                #logger.save_state({'env_name': env_name}, None)

            # Test the performance of the deterministic version of the agent.
            test_agent(replay_buffer.feats_keys)

            # Update tensorboard
            if proc_id() == 0:
                log_perf_board = ['EpRet', 'EpLen', 'Q1Vals', 'Q2Vals'] + [
                    f"TestEpRet_{ti}" for ti in range(len(test_env_fns))
                ]
                log_loss_board = ['LogPi', 'LossPi', 'LossQ'] + [
                    key
                    for key in logger.epoch_dict.keys() if "LossFeats" in key
                ]
                log_board = {
                    'Performance': log_perf_board,
                    'Loss': log_loss_board
                }
                for key, value in log_board.items():
                    for val in value:
                        mean, std = logger.get_stats(val)
                        if key == 'Performance':
                            writer.add_scalar(key + '/Average' + val, mean,
                                              epoch)
                            writer.add_scalar(key + '/Std' + val, std, epoch)
                        else:
                            writer.add_scalar(key + '/' + val, mean, epoch)

            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('TotalEnvInteracts', t)
            logger.log_tabular('Q1Vals', with_min_and_max=True)
            logger.log_tabular('Q2Vals', with_min_and_max=True)
            logger.log_tabular('LogPi', with_min_and_max=True)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQ', average_only=True)
            logger.log_tabular('Time', time.time() - start_time)
            logger.dump_tabular()

            if proc_id() == 0:
                writer.flush()

                import psutil
                # gives a single float value
                cpu_percent = psutil.cpu_percent()
                # gives an object with many fields
                mem_percent = psutil.virtual_memory().percent
                print(f"Used cpu avg {cpu_percent}% memory {mem_percent}%")
                cpu_separate = psutil.cpu_percent(percpu=True)
                for ci, cval in enumerate(cpu_separate):
                    print(f"\t cpu {ci}: {cval}%")
                # buf_size = replay_buffer.get_size()
                # print(f"Replay buffer size: {buf_size//1e6}MB {buf_size // 1e3} KB {buf_size % 1e3} B")
        t += num_procs

    if proc_id() == 0:
        writer.close()
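
As a rough usage sketch (not from the original example): make_env, FeatureInfoWrapper and the kwargs below are illustrative placeholders. The training loop expects each environment's step() info dict to contain a "features" mapping (it is consumed by replay_buffer.store_vec and the feature TD loss), so a plain Gym task needs a small wrapper that provides one.

# Hypothetical invocation; everything named here is a placeholder, not part of the example above.
from functools import partial
import gym

class FeatureInfoWrapper(gym.Wrapper):
    """Adds a 'features' entry to step() info so store_vec and the feats loss have data."""
    def step(self, action):
        o, r, d, info = self.env.step(action)
        info.setdefault("features", {"const": 0.0})  # at least one feature to predict
        return o, r, d, info

def make_env(rank=0):
    return FeatureInfoWrapper(gym.make("Pendulum-v0"))

if __name__ == "__main__":
    sac(env_fn=make_env,
        env_name="Pendulum-v0",
        test_env_fns=[partial(make_env, rank=i) for i in range(2)],
        num_procs=2,
        epochs=10,
        logger_kwargs=dict(exp_name="sac_sketch"))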