    def __init__(self, config, ob_space, ac_space, actor, critic):
        super().__init__(config, ob_space)

        self._ob_space = ob_space
        self._ac_space = ac_space

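        # SAC temperature: the target entropy is set to -|A| (the usual heuristic), and the
        # temperature alpha is optimized through its log to keep it positive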
        self._target_entropy = -ac_space.size
        self._log_alpha = torch.zeros(1, requires_grad=True, device=config.device)
        self._alpha_optim = optim.Adam([self._log_alpha], lr=config.lr_actor)

        # build up networks
        self._build_actor(actor)
        self._critic1 = critic(config, ob_space, ac_space)
        self._critic2 = critic(config, ob_space, ac_space)

        # build up target networks
        self._critic1_target = critic(config, ob_space, ac_space)
        self._critic2_target = critic(config, ob_space, ac_space)
        self._critic1_target.load_state_dict(self._critic1.state_dict())
        self._critic2_target.load_state_dict(self._critic2.state_dict())
        self._network_cuda(config.device)

        self._actor_optim = optim.Adam(self._actor.parameters(), lr=config.lr_actor)
        self._critic1_optim = optim.Adam(self._critic1.parameters(), lr=config.lr_critic)
        self._critic2_optim = optim.Adam(self._critic2.parameters(), lr=config.lr_critic)

        sampler = RandomSampler()
        buffer_keys = ['ob', 'ac', 'done', 'rew']
        self._buffer = ReplayBuffer(buffer_keys,
                                    config.buffer_size,
                                    sampler.sample_func)

        self._log_creation()
Example #2
    def __init__(self, config, ob_space, ac_space,
                 actor, critic):
        super().__init__(config, ob_space)

        self._ob_space = ob_space
        self._ac_space = ac_space

        self._log_alpha = [torch.zeros(1, requires_grad=True, device=config.device)]
        self._alpha_optim = [optim.Adam([self._log_alpha[0]], lr=config.lr_actor)]

        self._actor = actor(self._config, self._ob_space,
                              self._ac_space, self._config.tanh_policy, deterministic=True)
        self._actor_target = actor(self._config, self._ob_space,
                              self._ac_space, self._config.tanh_policy, deterministic=True)
        self._actor_target.load_state_dict(self._actor.state_dict())
        self._critic = critic(config, ob_space, ac_space)
        self._critic_target = critic(config, ob_space, ac_space)
        self._critic_target.load_state_dict(self._critic.state_dict())

        self._network_cuda(config.device)

        self._actor_optim = optim.Adam(self._actor.parameters(), lr=config.lr_actor)
        self._critic_optim = optim.Adam(self._critic.parameters(), lr=config.lr_critic)

        sampler = RandomSampler()
        self._buffer = ReplayBuffer(config,
                                    sampler.sample_func,
                                    ob_space,
                                    ac_space)

        self._ounoise = OUNoise(action_size(ac_space))

        self._log_creation()
Example #3
    def __init__(self, config, ob_space, ac_space, actor, critic):
        super().__init__(config, ob_space)

        self._ac_space = ac_space  # e.g. ActionSpace(shape=OrderedDict([('default', 8)]), minimum=-1.0, maximum=1.0)

        # build up networks
        self._actor = actor(config, ob_space, ac_space,
                            config.tanh_policy)  # actor model (MLP)
        self._old_actor = actor(config, ob_space, ac_space, config.tanh_policy)
        self._critic = critic(config, ob_space)  # critic model (MLP)
        self._network_cuda(config.device)

        self._actor_optim = optim.Adam(self._actor.parameters(),
                                       lr=config.lr_actor)
        self._critic_optim = optim.Adam(self._critic.parameters(),
                                        lr=config.lr_critic)

        sampler = RandomSampler()
        buffer_keys = [
            'ob', 'ac', 'done', 'rew', 'ret', 'adv', 'ac_before_activation'
        ]
        self._buffer = ReplayBuffer(buffer_keys, config.buffer_size,
                                    sampler.sample_func)

        if config.is_chef:
            logger.info('Creating a PPO agent')
            logger.info('The actor has %d parameters',
                        count_parameters(self._actor))
            logger.info('The critic has %d parameters',
                        count_parameters(self._critic))
Example #4
    def __init__(self, config, ob_space, ac_space, dqn):
        super().__init__(config, ob_space)

        self._ob_space = ob_space
        self._ac_space = ac_space
        # build up networks
        self._dqn = dqn(config, ob_space, ac_space)
        self._network_cuda(config.device)

        self._dqn_optim = optim.Adam(self._dqn.parameters(),
                                     lr=config.lr_actor)
        sampler = RandomSampler()
        self._buffer = ReplayBuffer(config.buffer_size, sampler.sample_func,
                                    ob_space, ac_space)
Example #5
    def __init__(self, config, ob_space, ac_space, actor, critic):
        super().__init__(config, ob_space)

        self._ob_space = ob_space
        self._ac_space = ac_space

        self._log_alpha = torch.tensor(np.log(config.alpha),
                                       requires_grad=True,
                                       device=config.device)
        self._alpha_optim = optim.Adam([self._log_alpha], lr=config.lr_actor)

        # build up networks
        self._actor = actor(config, ob_space, ac_space, config.tanh_policy)
        self._critic1 = critic(config, ob_space, ac_space)
        self._critic2 = critic(config, ob_space, ac_space)

        self._target_entropy = -action_size(self._actor._ac_space)

        # build up target networks
        self._critic1_target = critic(config, ob_space, ac_space)
        self._critic2_target = critic(config, ob_space, ac_space)
        self._critic1_target.load_state_dict(self._critic1.state_dict())
        self._critic2_target.load_state_dict(self._critic2.state_dict())

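        # pixel observations: critic2 and the actor reuse critic1's conv encoder weights; with
        # CURL, add a contrastive auxiliary task with its own encoder and CPC optimizers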
        if config.policy == 'cnn':
            self._critic2.base.copy_conv_weights_from(self._critic1.base)
            self._actor.base.copy_conv_weights_from(self._critic1.base)

            if config.unsup_algo == 'curl':
                self._curl = CURL(config, ob_space, ac_space, self._critic1,
                                  self._critic1_target)
                self._encoder_optim = optim.Adam(
                    self._critic1.base.parameters(), lr=config.lr_encoder)
                self._cpc_optim = optim.Adam(self._curl.parameters(),
                                             lr=config.lr_encoder)

        self._network_cuda(config.device)

        self._actor_optim = optim.Adam(self._actor.parameters(),
                                       lr=config.lr_actor)
        self._critic1_optim = optim.Adam(self._critic1.parameters(),
                                         lr=config.lr_critic)
        self._critic2_optim = optim.Adam(self._critic2.parameters(),
                                         lr=config.lr_critic)

        self._buffer = ReplayBuffer(config, ob_space, ac_space)
Example #6
    def __init__(self, config, ob_space, ac_space,
                 actor, critic):
        super().__init__(config, ob_space)

        self._ac_space = ac_space

        # build up networks
        self._actor = actor(config, ob_space, ac_space, config.tanh_policy)
        self._old_actor = actor(config, ob_space, ac_space, config.tanh_policy)
        self._critic = critic(config, ob_space)
        self._network_cuda(config.device)

        self._actor_optim = optim.Adam(self._actor.parameters(), lr=config.lr_actor)
        self._critic_optim = optim.Adam(self._critic.parameters(), lr=config.lr_critic)

        sampler = RandomSampler()
        self._buffer = ReplayBuffer(['ob', 'ac', 'done', 'rew', 'ret', 'adv', 'ac_before_activation'],
                                    config.buffer_size,
                                    sampler.sample_func)

        if config.is_chef:
            logger.info('Creating a PPO agent')
            logger.info('The actor has %d parameters', count_parameters(self._actor))
            logger.info('The critic has %d parameters', count_parameters(self._critic))
Example #7
class SACAgent(BaseAgent):
    def __init__(self, config, ob_space, ac_space, actor, critic):
        super().__init__(config, ob_space)

        self._ob_space = ob_space
        self._ac_space = ac_space

        self._log_alpha = torch.tensor(np.log(config.alpha),
                                       requires_grad=True,
                                       device=config.device)
        self._alpha_optim = optim.Adam([self._log_alpha], lr=config.lr_actor)

        # build up networks
        self._actor = actor(config, ob_space, ac_space, config.tanh_policy)
        self._critic1 = critic(config, ob_space, ac_space)
        self._critic2 = critic(config, ob_space, ac_space)

        self._target_entropy = -action_size(self._actor._ac_space)

        # build up target networks
        self._critic1_target = critic(config, ob_space, ac_space)
        self._critic2_target = critic(config, ob_space, ac_space)
        self._critic1_target.load_state_dict(self._critic1.state_dict())
        self._critic2_target.load_state_dict(self._critic2.state_dict())

        if config.policy == 'cnn':
            self._critic2.base.copy_conv_weights_from(self._critic1.base)
            self._actor.base.copy_conv_weights_from(self._critic1.base)

            if config.unsup_algo == 'curl':
                self._curl = CURL(config, ob_space, ac_space, self._critic1,
                                  self._critic1_target)
                self._encoder_optim = optim.Adam(
                    self._critic1.base.parameters(), lr=config.lr_encoder)
                self._cpc_optim = optim.Adam(self._curl.parameters(),
                                             lr=config.lr_encoder)

        self._network_cuda(config.device)

        self._actor_optim = optim.Adam(self._actor.parameters(),
                                       lr=config.lr_actor)
        self._critic1_optim = optim.Adam(self._critic1.parameters(),
                                         lr=config.lr_critic)
        self._critic2_optim = optim.Adam(self._critic2.parameters(),
                                         lr=config.lr_critic)

        self._buffer = ReplayBuffer(config, ob_space, ac_space)

    def _log_creation(self):
        logger.info("Creating a SAC agent")
        logger.info("The actor has %d parameters".format(
            count_parameters(self._actor)))
        logger.info('The critic1 has %d parameters',
                    count_parameters(self._critic1))
        logger.info('The critic2 has %d parameters',
                    count_parameters(self._critic2))

    def store_sample(self, rollouts):
        self._buffer.store_sample(rollouts)

    def _network_cuda(self, device):
        self._actor.to(device)
        self._critic1.to(device)
        self._critic2.to(device)
        self._critic1_target.to(device)
        self._critic2_target.to(device)
        if self._config.policy == 'cnn' and self._config.unsup_algo == 'curl':
            self._curl.to(device)

    def state_dict(self):
        ret = {
            'log_alpha': self._log_alpha.cpu().detach().numpy(),
            'actor_state_dict': self._actor.state_dict(),
            'critic1_state_dict': self._critic1.state_dict(),
            'critic2_state_dict': self._critic2.state_dict(),
            'alpha_optim_state_dict': self._alpha_optim.state_dict(),
            'actor_optim_state_dict': self._actor_optim.state_dict(),
            'critic1_optim_state_dict': self._critic1_optim.state_dict(),
            'critic2_optim_state_dict': self._critic2_optim.state_dict(),
        }
        if self._config.policy == 'cnn' and self._config.unsup_algo == 'curl':
            ret['curl_state_dict'] = self._curl.state_dict()
            ret['encoder_optim_state_dict'] = self._encoder_optim.state_dict()
            ret['cpc_optim_state_dict'] = self._cpc_optim.state_dict()
        return ret

    def load_state_dict(self, ckpt):
        self._log_alpha.data = torch.tensor(ckpt['log_alpha'],
                                            requires_grad=True,
                                            device=self._config.device)
        self._actor.load_state_dict(ckpt['actor_state_dict'])
        self._critic1.load_state_dict(ckpt['critic1_state_dict'])
        self._critic2.load_state_dict(ckpt['critic2_state_dict'])
        self._critic1_target.load_state_dict(self._critic1.state_dict())
        self._critic2_target.load_state_dict(self._critic2.state_dict())

        self._alpha_optim.load_state_dict(ckpt['alpha_optim_state_dict'])
        self._actor_optim.load_state_dict(ckpt['actor_optim_state_dict'])
        self._critic1_optim.load_state_dict(ckpt['critic1_optim_state_dict'])
        self._critic2_optim.load_state_dict(ckpt['critic2_optim_state_dict'])
        optimizer_cuda(self._alpha_optim, self._config.device)
        optimizer_cuda(self._actor_optim, self._config.device)
        optimizer_cuda(self._critic1_optim, self._config.device)
        optimizer_cuda(self._critic2_optim, self._config.device)

        if self._config.policy == 'cnn' and self._config.unsup_algo == 'curl':
            self._curl.load_state_dict(ckpt['curl_state_dict'])
            self._encoder_optim.load_state_dict(
                ckpt['encoder_optim_state_dict'])
            self._cpc_optim.load_state_dict(ckpt['cpc_optim_state_dict'])
            optimizer_cuda(self._encoder_optim, self._config.device)
            optimizer_cuda(self._cpc_optim, self._config.device)

        self._network_cuda(self._config.device)

    def train(self):
        for _ in range(self._config.num_batches):
            if self._config.policy == 'cnn' and self._config.unsup_algo == 'curl':
                transitions = self._buffer.sample_cpc(self._config.batch_size)
            else:
                transitions = self._buffer.sample(self._config.batch_size)
            train_info = self._update_network(transitions)
            self._soft_update_target_network(self._critic1_target,
                                             self._critic1,
                                             self._config.polyak)
            self._soft_update_target_network(self._critic2_target,
                                             self._critic2,
                                             self._config.polyak)

        return train_info

    def act_log(self, o):
        return self._actor.act_log(o)

    def _update_critic(self, o, ac, rew, o_next, done):
        info = {}
        alpha = self._log_alpha.exp()
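        # soft Bellman target: scaled reward + gamma * (1 - done) * (min_i Q'_i(s', a') - alpha * log pi(a'|s'))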
        with torch.no_grad():
            actions_next, log_pi_next = self.act_log(o_next)
            q_next_value1 = self._critic1_target(o_next, actions_next)
            q_next_value2 = self._critic2_target(o_next, actions_next)
            q_next_value = torch.min(q_next_value1,
                                     q_next_value2) - alpha * log_pi_next
            target_q_value = rew * self._config.reward_scale + \
                (1 - done) * self._config.discount_factor * q_next_value
            target_q_value = target_q_value.detach()

        # q loss
        real_q_value1 = self._critic1(o, ac)
        real_q_value2 = self._critic2(o, ac)
        critic1_loss = 0.5 * (target_q_value - real_q_value1).pow(2).mean()
        critic2_loss = 0.5 * (target_q_value - real_q_value2).pow(2).mean()

        info['min_target_q'] = target_q_value.min().cpu().item()
        info['target_q'] = target_q_value.mean().cpu().item()
        info['min_real1_q'] = real_q_value1.min().cpu().item()
        info['min_real2_q'] = real_q_value2.min().cpu().item()
        info['real1_q'] = real_q_value1.mean().cpu().item()
        info['real2_q'] = real_q_value2.mean().cpu().item()
        info['critic1_loss'] = critic1_loss.cpu().item()
        info['critic2_loss'] = critic2_loss.cpu().item()

        # update the critics
        self._critic1_optim.zero_grad()
        critic1_loss.backward()
        self._critic1_optim.step()

        self._critic2_optim.zero_grad()
        critic2_loss.backward()
        self._critic2_optim.step()

        return info

    def _update_actor_and_alpha(self, o):
        info = {}
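        # temperature update: alpha is tuned so the policy entropy tracks the target entropy;
        # the actor then maximizes min_i Q_i(s, a) - alpha * log pi(a|s) for a ~ pi(.|s)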
        actions_real, log_pi = self.act_log(o)
        alpha_loss = -(self._log_alpha *
                       (log_pi + self._target_entropy).detach()).mean()
        self._alpha_optim.zero_grad()
        alpha_loss.backward()
        self._alpha_optim.step()
        alpha = self._log_alpha.exp()

        # actor loss
        entropy_loss = (alpha * log_pi).mean()
        actor_loss = -torch.min(self._critic1(o, actions_real),
                                self._critic2(o, actions_real)).mean()
        info['entropy_alpha'] = alpha.cpu().item()
        info['entropy_loss'] = entropy_loss.cpu().item()
        info['actor_loss'] = actor_loss.cpu().item()
        actor_loss += entropy_loss

        # update the actor
        self._actor_optim.zero_grad()
        actor_loss.backward()
        self._actor_optim.step()
        return info

    def _update_cpc(self, o_anchor, o_pos, cpc_kwargs):
        info = {}
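        # CURL (InfoNCE) loss: the anchor is encoded online, the positive with the momentum
        # (ema=True) encoder; matching pairs lie on the diagonal of the logits matrix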
        z_a = self._curl.encode(o_anchor)
        z_pos = self._curl.encode(o_pos, ema=True)
        logits = self._curl.compute_logits(z_a, z_pos)
        labels = torch.arange(logits.shape[0]).long().to(self._config.device)
        cpc_loss = F.cross_entropy(logits, labels)
        info['cpc_loss'] = cpc_loss.cpu().item()

        self._encoder_optim.zero_grad()
        self._cpc_optim.zero_grad()
        cpc_loss.backward()
        self._encoder_optim.step()
        self._cpc_optim.step()
        return info

    def _update_network(self, transitions):
        info = {}

        # pre-process observations
        o, o_next = transitions['ob'], transitions['ob_next']

        bs = len(transitions['done'])
        _to_tensor = lambda x: to_tensor(x, self._config.device)
        o = _to_tensor(o)
        o_next = _to_tensor(o_next)
        ac = _to_tensor(transitions['ac'])

        done = _to_tensor(transitions['done']).reshape(bs, 1)
        rew = _to_tensor(transitions['rew']).reshape(bs, 1)
        # update alpha
        critic_info = self._update_critic(o, ac, rew, o_next, done)
        info.update(critic_info)
        actor_alpha_info = self._update_actor_and_alpha(o)
        info.update(actor_alpha_info)

        if self._config.policy == 'cnn' and self._config.unsup_algo == 'curl':

            cpc_kwargs = transitions['cpc_kwargs']
            o_anchor = _to_tensor(cpc_kwargs['ob_anchor'])
            o_pos = _to_tensor(cpc_kwargs['ob_pos'])
            cpc_info = self._update_cpc(o_anchor, o_pos, cpc_kwargs)
            info.update(cpc_info)

        return info
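
# A minimal usage sketch for the SACAgent above. `make_config`, `env`, `Actor`, and `Critic`
# are hypothetical stand-ins (not part of this snippet); the constructor signature and the
# store_sample()/train() calls follow the class definition above:
#
#   config = make_config(device='cuda', lr_actor=3e-4, lr_critic=3e-4, alpha=0.1,
#                        policy='mlp', tanh_policy=True)
#   agent = SACAgent(config, env.observation_space, env.action_space, Actor, Critic)
#   agent.store_sample(rollouts)   # push collected transitions into the replay buffer
#   train_info = agent.train()     # sample batches; update critics, actor, and alpha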
Example #8
class DDPGAgent(BaseAgent):
    def __init__(self, config, ob_space, ac_space,
                 actor, critic):
        super().__init__(config, ob_space)

        self._ob_space = ob_space
        self._ac_space = ac_space

        self._log_alpha = [torch.zeros(1, requires_grad=True, device=config.device)]
        self._alpha_optim = [optim.Adam([self._log_alpha[0]], lr=config.lr_actor)]

        self._actor = actor(self._config, self._ob_space,
                              self._ac_space, self._config.tanh_policy, deterministic=True)
        self._actor_target = actor(self._config, self._ob_space,
                              self._ac_space, self._config.tanh_policy, deterministic=True)
        self._actor_target.load_state_dict(self._actor.state_dict())
        self._critic = critic(config, ob_space, ac_space)
        self._critic_target = critic(config, ob_space, ac_space)
        self._critic_target.load_state_dict(self._critic.state_dict())

        self._network_cuda(config.device)

        self._actor_optim = optim.Adam(self._actor.parameters(), lr=config.lr_actor)
        self._critic_optim = optim.Adam(self._critic.parameters(), lr=config.lr_critic)

        sampler = RandomSampler()
        self._buffer = ReplayBuffer(config,
                                    sampler.sample_func,
                                    ob_space,
                                    ac_space)

        self._ounoise = OUNoise(action_size(ac_space))

        self._log_creation()

    def _log_creation(self):
        logger.info('creating a DDPG agent')
        logger.info('the actor has %d parameters', count_parameters(self._actor))
        logger.info('the critic has %d parameters', count_parameters(self._critic))

    def store_episode(self, rollouts):
        self._buffer.store_episode(rollouts)

    def state_dict(self):
        return {
            'actor_state_dict': self._actor.state_dict(),
            'critic_state_dict': self._critic.state_dict(),
            'actor_optim_state_dict': self._actor_optim.state_dict(),
            'critic_optim_state_dict': self._critic_optim.state_dict(),
        }

    def load_state_dict(self, ckpt):
        self._actor.load_state_dict(ckpt['actor_state_dict'])
        self._actor_target.load_state_dict(self._actor.state_dict())
        self._critic.load_state_dict(ckpt['critic_state_dict'])
        self._critic_target.load_state_dict(self._critic.state_dict())
        self._network_cuda(self._config.device)

        self._actor_optim.load_state_dict(ckpt['actor_optim_state_dict'])
        self._critic_optim.load_state_dict(ckpt['critic_optim_state_dict'])
        optimizer_cuda(self._actor_optim, self._config.device)
        optimizer_cuda(self._critic_optim, self._config.device)

    def _network_cuda(self, device):
        self._actor.to(device)
        self._actor_target.to(device)
        self._critic.to(device)
        self._critic_target.to(device)

    def train(self):
        config = self._config
        for i in range(config.num_batches):
            transitions = self._buffer.sample(config.batch_size)
            train_info = self._update_network(transitions, step=i)
            self._soft_update_target_network(self._actor_target, self._actor, self._config.polyak)
            self._soft_update_target_network(self._critic_target, self._critic, self._config.polyak)
        return train_info

    def act_log(self, ob):
        return self._actor.act_log(ob)

    def act(self, ob, is_train=True):
        ob = to_tensor(ob, self._config.device)
        ac, activation = self._actor.act(ob, is_train=is_train)
        if is_train:
            for k, space in self._ac_space.spaces.items():
                if isinstance(space, spaces.Box):
                    ac[k] += self._config.noise_scale*np.random.randn(len(ac[k]))
                    ac[k] = np.clip(ac[k], self._ac_space[k].low, self._ac_space[k].high)
        return ac, activation

    def target_act(self, ob, is_train=True):
        ac, activation = self._actor_target.act(ob, is_train=is_train)
        return ac, activation

    def target_act_log(self, ob):
        return self._actor_target.act_log(ob)

    def _update_network(self, transitions, step=0):
        config = self._config
        info = {}

        o, o_next = transitions['ob'], transitions['ob_next']
        bs = len(transitions['done'])
        _to_tensor = lambda x: to_tensor(x, config.device)
        o = _to_tensor(o)
        o_next = _to_tensor(o_next)
        ac = _to_tensor(transitions['ac'])

        done = _to_tensor(transitions['done']).reshape(bs, 1)
        rew = _to_tensor(transitions['rew']).reshape(bs, 1)

        ## Actor loss
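        # deterministic policy gradient: maximize Q(s, pi(s)) by minimizing its negative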
        actions_real, _ = self.act_log(o)
        actor_loss = -self._critic(o, actions_real).mean()
        info['actor_loss'] = actor_loss.cpu().item()

        ## Critic loss
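        # one-step TD target from the target networks: r + gamma * (1 - done) * Q'(s', pi'(s'))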
        with torch.no_grad():
            actions_next, _ = self.target_act_log(o_next)
            q_next_value = self._critic_target(o_next, actions_next)
            target_q_value = rew + (1.-done) * config.discount_factor * q_next_value
            target_q_value = target_q_value.detach()

        real_q_value = self._critic(o, ac)

        critic_loss = 0.5 * (target_q_value - real_q_value).pow(2).mean()

        info['min_target_q'] = target_q_value.min().cpu().item()
        info['target_q'] = target_q_value.mean().cpu().item()
        info['min_real1_q'] = real_q_value.min().cpu().item()
        info['real_q'] = real_q_value.mean().cpu().item()
        info['critic_loss'] = critic_loss.cpu().item()

        # update the actor
        self._actor_optim.zero_grad()
        actor_loss.backward()
        self._actor_optim.step()

        # update the critics
        self._critic_optim.zero_grad()
        critic_loss.backward()
        self._critic_optim.step()

        return info
Example #9
    def __init__(
        self,
        config,
        ob_space,
        ac_space,
        actor,
        critic,
        non_limited_idx=None,
        ref_joint_pos_indexes=None,
        joint_space=None,
        is_jnt_limited=None,
        jnt_indices=None,
    ):
        super().__init__(config, ob_space)

        self._ob_space = ob_space
        self._ac_space = ac_space
        self._jnt_indices = jnt_indices
        self._ref_joint_pos_indexes = ref_joint_pos_indexes
        self._joint_space = joint_space
        self._is_jnt_limited = is_jnt_limited
        if joint_space is not None:
            self._jnt_minimum = joint_space["default"].low
            self._jnt_maximum = joint_space["default"].high

        self._log_alpha = [
            torch.zeros(1, requires_grad=True, device=config.device)
        ]
        self._alpha_optim = [
            optim.Adam([self._log_alpha[0]], lr=config.lr_actor)
        ]

        self._actor = actor(
            self._config,
            self._ob_space,
            self._ac_space,
            self._config.tanh_policy,
            deterministic=True,
        )
        self._actor_target = actor(
            self._config,
            self._ob_space,
            self._ac_space,
            self._config.tanh_policy,
            deterministic=True,
        )
        self._actor_target.load_state_dict(self._actor.state_dict())
        self._critic1 = critic(config, ob_space, ac_space)
        self._critic2 = critic(config, ob_space, ac_space)
        self._critic1_target = critic(config, ob_space, ac_space)
        self._critic2_target = critic(config, ob_space, ac_space)
        self._critic1_target.load_state_dict(self._critic1.state_dict())
        self._critic2_target.load_state_dict(self._critic2.state_dict())

        self._network_cuda(config.device)

        self._actor_optim = optim.Adam(self._actor.parameters(),
                                       lr=config.lr_actor)
        self._critic1_optim = optim.Adam(self._critic1.parameters(),
                                         lr=config.lr_critic)
        self._critic2_optim = optim.Adam(self._critic2.parameters(),
                                         lr=config.lr_critic)

        self._update_steps = 0

        sampler = RandomSampler()
        buffer_keys = ["ob", "ac", "done", "rew"]
        if config.mopa or config.expand_ac_space:
            buffer_keys.append("intra_steps")
        self._buffer = ReplayBuffer(buffer_keys, config.buffer_size,
                                    sampler.sample_func)

        self._log_creation()

        self._planner = None
        self._is_planner_initialized = False
        if config.mopa:
            self._planner = PlannerAgent(
                config,
                ac_space,
                non_limited_idx,
                planner_type=config.planner_type,
                passive_joint_idx=config.passive_joint_idx,
                ignored_contacts=config.ignored_contact_geom_ids,
                is_simplified=config.is_simplified,
                simplified_duration=config.simplified_duration,
                allow_approximate=config.allow_approximate,
                range_=config.range,
            )
            self._simple_planner = PlannerAgent(
                config,
                ac_space,
                non_limited_idx,
                planner_type=config.simple_planner_type,
                passive_joint_idx=config.passive_joint_idx,
                ignored_contacts=config.ignored_contact_geom_ids,
                goal_bias=1.0,
                allow_approximate=False,
                is_simplified=config.simple_planner_simplified,
                simplified_duration=config.simple_planner_simplified_duration,
                range_=config.simple_planner_range,
            )
            self._omega = config.omega
Example #10
class TD3Agent(BaseAgent):
    def __init__(
        self,
        config,
        ob_space,
        ac_space,
        actor,
        critic,
        non_limited_idx=None,
        ref_joint_pos_indexes=None,
        joint_space=None,
        is_jnt_limited=None,
        jnt_indices=None,
    ):
        super().__init__(config, ob_space)

        self._ob_space = ob_space
        self._ac_space = ac_space
        self._jnt_indices = jnt_indices
        self._ref_joint_pos_indexes = ref_joint_pos_indexes
        self._joint_space = joint_space
        self._is_jnt_limited = is_jnt_limited
        if joint_space is not None:
            self._jnt_minimum = joint_space["default"].low
            self._jnt_maximum = joint_space["default"].high

        self._log_alpha = [
            torch.zeros(1, requires_grad=True, device=config.device)
        ]
        self._alpha_optim = [
            optim.Adam([self._log_alpha[0]], lr=config.lr_actor)
        ]

        self._actor = actor(
            self._config,
            self._ob_space,
            self._ac_space,
            self._config.tanh_policy,
            deterministic=True,
        )
        self._actor_target = actor(
            self._config,
            self._ob_space,
            self._ac_space,
            self._config.tanh_policy,
            deterministic=True,
        )
        self._actor_target.load_state_dict(self._actor.state_dict())
        self._critic1 = critic(config, ob_space, ac_space)
        self._critic2 = critic(config, ob_space, ac_space)
        self._critic1_target = critic(config, ob_space, ac_space)
        self._critic2_target = critic(config, ob_space, ac_space)
        self._critic1_target.load_state_dict(self._critic1.state_dict())
        self._critic2_target.load_state_dict(self._critic2.state_dict())

        self._network_cuda(config.device)

        self._actor_optim = optim.Adam(self._actor.parameters(),
                                       lr=config.lr_actor)
        self._critic1_optim = optim.Adam(self._critic1.parameters(),
                                         lr=config.lr_critic)
        self._critic2_optim = optim.Adam(self._critic2.parameters(),
                                         lr=config.lr_critic)

        self._update_steps = 0

        sampler = RandomSampler()
        buffer_keys = ["ob", "ac", "done", "rew"]
        if config.mopa or config.expand_ac_space:
            buffer_keys.append("intra_steps")
        self._buffer = ReplayBuffer(buffer_keys, config.buffer_size,
                                    sampler.sample_func)

        self._log_creation()

        self._planner = None
        self._is_planner_initialized = False
        if config.mopa:
            self._planner = PlannerAgent(
                config,
                ac_space,
                non_limited_idx,
                planner_type=config.planner_type,
                passive_joint_idx=config.passive_joint_idx,
                ignored_contacts=config.ignored_contact_geom_ids,
                is_simplified=config.is_simplified,
                simplified_duration=config.simplified_duration,
                allow_approximate=config.allow_approximate,
                range_=config.range,
            )
            self._simple_planner = PlannerAgent(
                config,
                ac_space,
                non_limited_idx,
                planner_type=config.simple_planner_type,
                passive_joint_idx=config.passive_joint_idx,
                ignored_contacts=config.ignored_contact_geom_ids,
                goal_bias=1.0,
                allow_approximate=False,
                is_simplified=config.simple_planner_simplified,
                simplified_duration=config.simple_planner_simplified_duration,
                range_=config.simple_planner_range,
            )
            self._omega = config.omega

    def _log_creation(self):
        logger.info("creating a TD3 agent")
        logger.info("the actor has %d parameters",
                    count_parameters(self._actor))
        logger.info("the critic1 has %d parameters",
                    count_parameters(self._critic1))
        logger.info("the critic2 has %d parameters",
                    count_parameters(self._critic2))

    def store_episode(self, rollouts):
        self._buffer.store_episode(rollouts)

    def valid_action(self, ac):
        return np.all(ac["default"] >= -1.0) and np.all(ac["default"] <= 1.0)

    def is_planner_ac(self, ac):
        if np.any(
                ac["default"][:len(self._ref_joint_pos_indexes)] < -self._omega
        ) or np.any(ac["default"][:len(self._ref_joint_pos_indexes)] >
                    self._omega):
            return True
        return False

    def isValidState(self, state):
        return self._planner.isValidState(state)

    def convert2planner_displacement(self, ac, ac_scale):
        ac_space_type = self._config.ac_space_type
        action_range = self._config.action_range
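        # map a normalized action to a joint displacement: 'piecewise' scales |ac| < omega
        # linearly up to ac_scale and stretches the remainder to cover the full action_range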
        if ac_space_type == "normal":
            return ac * action_range
        elif ac_space_type == "piecewise":
            return np.where(
                np.abs(ac) < self._omega,
                ac / (self._omega / ac_scale),
                np.sign(ac) * (ac_scale + (action_range - ac_scale) *
                               ((np.abs(ac) - self._omega) /
                                (1 - self._omega))),
            )
        else:
            raise NotImplementedError

    def invert_displacement(self, displacement, ac_scale):
        ac_space_type = self._config.ac_space_type
        action_range = self._config.action_range
        if ac_space_type == "normal":
            return displacement / action_range
        elif ac_space_type == "piecewise":
            return np.where(
                np.abs(displacement) < ac_scale,
                displacement * (self._omega / ac_scale),
                np.sign(displacement) *
                ((np.abs(displacement) - ac_scale) /
                 ((action_range - ac_scale) / (1.0 - ac_scale)) /
                 ((1.0 - ac_scale) / (1.0 - self._omega)) + self._omega),
            )
        else:
            raise NotImplementedError

    # Calls motion planner to plan a path
    def plan(
        self,
        curr_qpos,
        target_qpos,
        ac_scale=None,
        meta_ac=None,
        ob=None,
        is_train=True,
        random_exploration=False,
        ref_joint_pos_indexes=None,
    ):

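        # strategy: try cheap interpolation (or the simple planner) first, then fall back to the
        # full motion planner; optionally re-interpolate any planned step that exceeds ac_scale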
        curr_qpos = self.clip_qpos(curr_qpos)
        interpolation = True
        if self._config.interpolate_type == "planner":
            traj, success, valid, exact = self._simple_planner.plan(
                curr_qpos, target_qpos, self._config.simple_planner_timelimit)
        else:
            traj, success, valid, exact = self.simple_interpolate(
                curr_qpos, target_qpos, ac_scale)
        if not success:
            if not exact or self._config.allow_approximate:
                traj, success, valid, exact = self._planner.plan(
                    curr_qpos, target_qpos, self._config.timelimit)
                interpolation = False
                if self._config.use_interpolation and success:
                    new_traj = []
                    start = curr_qpos
                    for i in range(len(traj)):
                        diff = traj[i] - start
                        if np.any(diff[:len(self._ref_joint_pos_indexes)] <
                                  -ac_scale) or np.any(
                                      diff[:len(self._ref_joint_pos_indexes)] >
                                      ac_scale):
                            if self._config.interpolate_type == "planner":
                                (
                                    inner_traj,
                                    inner_success,
                                    inner_valid,
                                    inner_exact,
                                ) = self._simple_planner.plan(
                                    start,
                                    traj[i],
                                    self._config.simple_planner_timelimit,
                                )
                                if inner_success:
                                    new_traj.extend(inner_traj)
                            else:
                                inner_traj, _, _, _ = self.simple_interpolate(
                                    start, traj[i], ac_scale, use_planner=True)
                                new_traj.extend(inner_traj)
                        else:
                            new_traj.append(traj[i])
                        start = traj[i]
                    traj = np.array(new_traj)

        return traj, success, interpolation, valid, exact

    def interpolate(self, curr_qpos, target_qpos):
        interpolation = True
        traj, success, valid, exact = self._simple_planner.plan(
            curr_qpos, target_qpos, self._config.simple_planner_timelimit)
        return traj, success, interpolation, valid, exact

    def clip_qpos(self, curr_qpos):
        tmp_pos = curr_qpos.copy()
        if np.any(curr_qpos[self._is_jnt_limited[self._jnt_indices]] <
                  self._jnt_minimum[self._jnt_indices][self._is_jnt_limited[
                      self._jnt_indices]]) or np.any(
                          curr_qpos[self._is_jnt_limited[self._jnt_indices]] >
                          self._jnt_maximum[self._jnt_indices][
                              self._is_jnt_limited[self._jnt_indices]]):
            new_curr_qpos = np.clip(
                curr_qpos.copy(),
                self._jnt_minimum[self._jnt_indices] +
                self._config.joint_margin,
                self._jnt_maximum[self._jnt_indices] -
                self._config.joint_margin,
            )
            new_curr_qpos[np.invert(
                self._is_jnt_limited[self._jnt_indices])] = tmp_pos[np.invert(
                    self._is_jnt_limited[self._jnt_indices])]
            curr_qpos = new_curr_qpos
        return curr_qpos

    # interpolation function
    def simple_interpolate(self,
                           curr_qpos,
                           target_qpos,
                           ac_scale,
                           use_planner=False):
        success = True
        exact = True
        curr_qpos = self.clip_qpos(curr_qpos)

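        # split the joint-space displacement into equal steps that fit inside a slightly shrunk
        # (80%) action range, validating every intermediate configuration with the planner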
        traj = []
        min_action = self._ac_space["default"].low[0] * ac_scale * 0.8
        max_action = self._ac_space["default"].high[0] * ac_scale * 0.8
        assert max_action > min_action, "action space box is ill defined"
        assert (max_action > 0 and min_action < 0
                ), "action space MAY be ill defined. Check this assertion"

        diff = (target_qpos[:len(self._ref_joint_pos_indexes)] -
                curr_qpos[:len(self._ref_joint_pos_indexes)])
        out_of_bounds = np.where((diff > max_action) | (diff < min_action))[0]
        out_diff = diff[out_of_bounds]

        scales = np.where(out_diff > max_action, out_diff / max_action,
                          out_diff / min_action)
        if len(scales) == 0:
            scaling_factor = 1.0
        else:
            scaling_factor = max(max(scales), 1.0)
        scaled_ac = diff[:len(self._ref_joint_pos_indexes)] / scaling_factor

        valid = True
        interp_qpos = curr_qpos.copy()
        for i in range(int(scaling_factor)):
            interp_qpos[:len(self._ref_joint_pos_indexes)] += scaled_ac
            if not self._planner.isValidState(interp_qpos):
                valid = False
                break
            traj.append(interp_qpos.copy())

        if not valid and use_planner:
            traj, success, valid, exact = self._simple_planner.plan(
                curr_qpos, target_qpos, self._config.simple_planner_timelimit)
            if not success:
                traj, success, valid, exact = self._planner.plan(
                    curr_qpos, target_qpos, self._config.timelimit)
                if not success:
                    traj = [target_qpos]
                    success = False
                    exact = False
        else:
            if not valid:
                success = False
                exact = False
            traj.append(target_qpos)

        return np.array(traj), success, valid, exact

    def state_dict(self):
        return {
            "actor_state_dict": self._actor.state_dict(),
            "critic1_state_dict": self._critic1.state_dict(),
            "critic2_state_dict": self._critic2.state_dict(),
            "actor_optim_state_dict": self._actor_optim.state_dict(),
            "critic1_optim_state_dict": self._critic1_optim.state_dict(),
            "critic2_optim_state_dict": self._critic2_optim.state_dict(),
        }

    def load_state_dict(self, ckpt):
        self._actor.load_state_dict(ckpt["actor_state_dict"])
        self._actor_target.load_state_dict(self._actor.state_dict())
        self._critic1.load_state_dict(ckpt["critic1_state_dict"])
        self._critic2.load_state_dict(ckpt["critic2_state_dict"])
        self._critic1_target.load_state_dict(self._critic1.state_dict())
        self._critic2_target.load_state_dict(self._critic2.state_dict())
        self._network_cuda(self._config.device)

        self._actor_optim.load_state_dict(ckpt["actor_optim_state_dict"])
        self._critic1_optim.load_state_dict(ckpt["critic1_optim_state_dict"])
        self._critic2_optim.load_state_dict(ckpt["critic2_optim_state_dict"])
        optimizer_cuda(self._actor_optim, self._config.device)
        optimizer_cuda(self._critic1_optim, self._config.device)
        optimizer_cuda(self._critic2_optim, self._config.device)

    def _network_cuda(self, device):
        self._actor.to(device)
        self._actor_target.to(device)
        self._critic1.to(device)
        self._critic2.to(device)
        self._critic1_target.to(device)
        self._critic2_target.to(device)

    def sync_networks(self):
        if self._config.is_mpi:
            sync_networks(self._actor)
            sync_networks(self._critic1)
            sync_networks(self._critic2)

    def train(self):
        config = self._config
        for i in range(config.num_batches):
            transitions = self._buffer.sample(config.batch_size)
            train_info = self._update_network(transitions, step=i)
            if self._update_steps % self._config.actor_update_freq == 0:
                self._soft_update_target_network(self._actor_target,
                                                 self._actor,
                                                 self._config.polyak)
                self._soft_update_target_network(self._critic1_target,
                                                 self._critic1,
                                                 self._config.polyak)
                self._soft_update_target_network(self._critic2_target,
                                                 self._critic2,
                                                 self._config.polyak)
        return train_info

    def act_log(self, ob, meta_ac=None):
        return self._actor.act_log(ob)

    def act(self, ob, is_train=True, return_stds=False):
        ob = to_tensor(ob, self._config.device)
        if return_stds:
            ac, activation, stds = self._actor.act(ob,
                                                   is_train=is_train,
                                                   return_stds=return_stds)
        else:
            ac, activation = self._actor.act(ob, is_train=is_train)
        if is_train:
            for k, space in self._ac_space.spaces.items():
                if isinstance(space, spaces.Box):
                    ac[k] += np.random.normal(0,
                                              self._config.action_noise,
                                              size=len(ac[k]))
                    ac[k] = np.clip(ac[k], self._ac_space[k].low,
                                    self._ac_space[k].high)
        if return_stds:
            return ac, activation, stds
        else:
            return ac, activation

    def target_act(self, ob, is_train=True):
        ac, activation = self._actor_target.act(ob, is_train=is_train)
        return ac, activation

    def target_act_log(self, ob):
        return self._actor_target.act_log(ob)

    def _update_network(self, transitions, step=0):
        config = self._config
        info = {}

        o, o_next = transitions["ob"], transitions["ob_next"]
        bs = len(transitions["done"])
        _to_tensor = lambda x: to_tensor(x, config.device)
        o = _to_tensor(o)
        o_next = _to_tensor(o_next)
        ac = _to_tensor(transitions["ac"])

        done = _to_tensor(transitions["done"]).reshape(bs, 1)
        rew = _to_tensor(transitions["rew"]).reshape(bs, 1)

        ## Actor loss
        actions_real, _ = self.act_log(o)
        actor_loss = -self._critic1(o, actions_real).mean()
        info["actor_loss"] = actor_loss.cpu().item()

        ## Critic loss
        with torch.no_grad():
            actions_next, _ = self.target_act_log(o_next)
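            # TD3 target policy smoothing: perturb the target action with clipped Gaussian noise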
            for k, space in self._ac_space.spaces.items():
                if isinstance(space, spaces.Box):
                    epsilon = (torch.randn_like(actions_next[k]) *
                               self._config.target_noise)
                    epsilon = torch.clamp(epsilon, -config.noise_clip,
                                          config.noise_clip)
                    actions_next[k] += epsilon
                    actions_next[k] = actions_next[k].clamp(-1.0, 1.0)
            q_next_value1 = self._critic1_target(o_next, actions_next)
            q_next_value2 = self._critic2_target(o_next, actions_next)
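            # clipped double-Q: use the minimum of the two target critics to curb overestimation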
            q_next_value = torch.min(q_next_value1, q_next_value2)
            target_q_value = (
                rew * self._config.reward_scale +
                (1.0 - done) * config.discount_factor * q_next_value)
            target_q_value = target_q_value.detach()

        real_q_value1 = self._critic1(o, ac)
        real_q_value2 = self._critic2(o, ac)

        critic1_loss = 0.5 * (target_q_value - real_q_value1).pow(2).mean()
        critic2_loss = 0.5 * (target_q_value - real_q_value2).pow(2).mean()

        info["min_target_q"] = target_q_value.min().cpu().item()
        info["target_q"] = target_q_value.mean().cpu().item()
        info["min_real1_q"] = real_q_value1.min().cpu().item()
        info["min_real2_q"] = real_q_value2.min().cpu().item()
        info["real1_q"] = real_q_value1.mean().cpu().item()
        info["rea2_q"] = real_q_value2.mean().cpu().item()
        info["critic1_loss"] = critic1_loss.cpu().item()
        info["critic2_loss"] = critic2_loss.cpu().item()

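        # delayed policy updates: the actor is stepped only every actor_update_freq critic updates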
        if self._update_steps % self._config.actor_update_freq == 0:
            # update the actor
            self._actor_optim.zero_grad()
            actor_loss.backward()
            self._actor_optim.step()

        # update the critics
        self._critic1_optim.zero_grad()
        critic1_loss.backward()
        self._critic1_optim.step()

        self._critic2_optim.zero_grad()
        critic2_loss.backward()
        self._critic2_optim.step()
        self._update_steps += 1

        return info
Example #11
class PPOAgent(BaseAgent):
    def __init__(self, config, ob_space, ac_space, actor, critic):
        super().__init__(config, ob_space)

        self._ac_space = ac_space  # e.g. ActionSpace(shape=OrderedDict([('default', 8)]), minimum=-1.0, maximum=1.0)

        # build up networks
        self._actor = actor(config, ob_space, ac_space,
                            config.tanh_policy)  # actor model (MLP)
        self._old_actor = actor(config, ob_space, ac_space, config.tanh_policy)
        self._critic = critic(config, ob_space)  # critic model (MLP)
        self._network_cuda(config.device)

        self._actor_optim = optim.Adam(self._actor.parameters(),
                                       lr=config.lr_actor)
        self._critic_optim = optim.Adam(self._critic.parameters(),
                                        lr=config.lr_critic)

        sampler = RandomSampler()
        buffer_keys = [
            'ob', 'ac', 'done', 'rew', 'ret', 'adv', 'ac_before_activation'
        ]
        self._buffer = ReplayBuffer(buffer_keys, config.buffer_size,
                                    sampler.sample_func)

        if config.is_chef:
            logger.info('Creating a PPO agent')
            logger.info('The actor has %d parameters',
                        count_parameters(self._actor))
            logger.info('The critic has %d parameters',
                        count_parameters(self._critic))

    def store_episode(self, rollouts):
        self._compute_gae(rollouts)
        self._buffer.store_episode(rollouts)

    def _compute_gae(self, rollouts):
        T = len(rollouts['done'])
        ob = rollouts['ob']
        ob = self.normalize(ob)
        ob = obs2tensor(ob, self._config.device)
        vpred = self._critic(ob).detach().cpu().numpy()[:, 0]
        assert len(vpred) == T + 1

        done = rollouts['done']
        rew = rollouts['rew']
        adv = np.empty((T, ), 'float32')
        lastgaelam = 0
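        # GAE(lambda): delta_t = r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t);
        # A_t = delta_t + gamma * lambda * (1 - d_t) * A_{t+1}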
        for t in reversed(range(T)):
            nonterminal = 1 - done[t]
            delta = rew[t] + self._config.discount_factor * vpred[
                t + 1] * nonterminal - vpred[t]
            adv[t] = lastgaelam = delta + self._config.discount_factor * self._config.gae_lambda * nonterminal * lastgaelam

        ret = adv + vpred[:-1]

        assert np.isfinite(adv).all()
        assert np.isfinite(ret).all()

        # update rollouts
        rollouts['adv'] = ((adv - adv.mean()) / adv.std()).tolist()
        rollouts['ret'] = ret.tolist()

    def state_dict(self):
        return {
            'actor_state_dict': self._actor.state_dict(
            ),  # state_dict contains info about model params (learnable tensors: weights&biases)
            'critic_state_dict': self._critic.state_dict(),
            'actor_optim_state_dict': self._actor_optim.state_dict(
            ),  # state_dict contains info about optim state and hyperparams
            'critic_optim_state_dict': self._critic_optim.state_dict(),
            'ob_norm_state_dict': self._ob_norm.state_dict(),
        }

    def load_state_dict(self, ckpt):
        self._actor.load_state_dict(ckpt['actor_state_dict'])
        self._critic.load_state_dict(ckpt['critic_state_dict'])
        self._ob_norm.load_state_dict(ckpt['ob_norm_state_dict'])
        self._network_cuda(self._config.device)

        self._actor_optim.load_state_dict(ckpt['actor_optim_state_dict'])
        self._critic_optim.load_state_dict(ckpt['critic_optim_state_dict'])
        optimizer_cuda(self._actor_optim, self._config.device
                       )  # required when loading optim state from checkpoint
        optimizer_cuda(self._critic_optim, self._config.device)

    def _network_cuda(self, device):
        self._actor.to(device)
        self._old_actor.to(device)
        self._critic.to(device)

    def sync_networks(self):
        sync_networks(self._actor)
        sync_networks(self._critic)

    def train(self):
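        # snapshot the current policy into _old_actor (a soft update with polyak 0.0 amounts to a
        # hard copy here) so the PPO ratio below is taken against the pre-update policy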
        self._soft_update_target_network(self._old_actor, self._actor, 0.0)

        for _ in range(self._config.num_batches):
            transitions = self._buffer.sample(self._config.batch_size)
            train_info = self._update_network(transitions)

        self._buffer.clear()

        train_info.update({
            'actor_grad_norm':
            compute_gradient_norm(self._actor),
            'actor_weight_norm':
            compute_weight_norm(self._actor),
            'critic_grad_norm':
            compute_gradient_norm(self._critic),
            'critic_weight_norm':
            compute_weight_norm(self._critic),
        })
        return train_info

    def _update_network(self, transitions):
        info = {}

        # pre-process observations
        o = transitions[
            'ob']  #  dict {'robot-state': array(), 'object-state': array()} coming from Sampler
        o = self.normalize(o)

        bs = len(transitions['done'])  # batch size
        _to_tensor = lambda x: to_tensor(x, self._config.device)
        o = _to_tensor(o)
        ac = _to_tensor(transitions['ac'])
        a_z = _to_tensor(
            transitions['ac_before_activation']
        )  # is OrderedDict([('default', tensor()]; tensor shape is (batch_size, 8)
        ret = _to_tensor(transitions['ret']).reshape(bs, 1)
        adv = _to_tensor(transitions['adv']).reshape(bs, 1)

        log_pi, ent = self._actor.act_log(
            o, a_z)  # log_pi is tensor of shape (batch_size, 1)
        old_log_pi, _ = self._old_actor.act_log(o, a_z)

        if old_log_pi.min() < -100:
            import ipdb
            ipdb.set_trace()

        # the actor loss
        entropy_loss = self._config.entropy_loss_coeff * ent.mean()
        ratio = torch.exp(log_pi - old_log_pi)
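        # PPO clipped surrogate: take the pessimistic minimum of the clipped and unclipped
        # importance-weighted advantages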
        surr1 = ratio * adv
        surr2 = torch.clamp(ratio, 1.0 - self._config.clip_param,
                            1.0 + self._config.clip_param) * adv
        actor_loss = -torch.min(surr1, surr2).mean()

        if not np.isfinite(ratio.cpu().detach()).all() or not np.isfinite(
                adv.cpu().detach()).all():
            import ipdb
            ipdb.set_trace()
        info['entropy_loss'] = entropy_loss.cpu().item()
        info['actor_loss'] = actor_loss.cpu().item()
        actor_loss += entropy_loss

        #custom_loss = self._actor.custom_loss()
        #if custom_loss is not None:
        #actor_loss += custom_loss * self._config.custom_loss_weight
        #info['custom_loss'] = custom_loss.cpu().item()

        # the q loss
        value_pred = self._critic(o)
        value_loss = self._config.value_loss_coeff * (
            ret - value_pred).pow(2).mean()

        info['value_target'] = ret.mean().cpu().item()
        info['value_predicted'] = value_pred.mean().cpu().item()
        info['value_loss'] = value_loss.cpu().item()

        # update the actor
        self._actor_optim.zero_grad()
        actor_loss.backward()
        #torch.nn.utils.clip_grad_norm_(self._actor.parameters(), self._config.max_grad_norm)
        sync_grads(self._actor)
        self._actor_optim.step()

        # update the critic
        self._critic_optim.zero_grad()
        value_loss.backward()
        #torch.nn.utils.clip_grad_norm_(self._critic1.parameters(), self._config.max_grad_norm)
        sync_grads(self._critic)
        self._critic_optim.step()

        # include info from policy
        info.update(self._actor.info)

        return mpi_average(info)
Example #12
class SACAgent(BaseAgent):
    def __init__(self, config, ob_space, ac_space, actor, critic):
        super().__init__(config, ob_space)

        self._ob_space = ob_space
        self._ac_space = ac_space

        self._target_entropy = -ac_space.size
        self._log_alpha = torch.zeros(1,
                                      requires_grad=True,
                                      device=config.device)
        self._alpha_optim = optim.Adam([self._log_alpha], lr=config.lr_actor)

        # build up networks
        self._build_actor(actor)
        self._critic1 = critic(config, ob_space, ac_space)
        self._critic2 = critic(config, ob_space, ac_space)

        # build up target networks
        self._critic1_target = critic(config, ob_space, ac_space)
        self._critic2_target = critic(config, ob_space, ac_space)
        self._critic1_target.load_state_dict(self._critic1.state_dict())
        self._critic2_target.load_state_dict(self._critic2.state_dict())
        self._network_cuda(config.device)

        self._actor_optim = optim.Adam(self._actor.parameters(),
                                       lr=config.lr_actor)
        self._critic1_optim = optim.Adam(self._critic1.parameters(),
                                         lr=config.lr_critic)
        self._critic2_optim = optim.Adam(self._critic2.parameters(),
                                         lr=config.lr_critic)

        sampler = RandomSampler()
        buffer_keys = ['ob', 'ac', 'done', 'rew']
        self._buffer = ReplayBuffer(buffer_keys, config.buffer_size,
                                    sampler.sample_func)

        self._log_creation()

    def _log_creation(self):
        if self._config.is_chef:
            logger.info('Creating a SAC agent')
            logger.info('The actor has %d parameters',
                        count_parameters(self._actor))
            logger.info('The critic1 has %d parameters',
                        count_parameters(self._critic1))
            logger.info('The critic2 has %d parameters',
                        count_parameters(self._critic2))

    def _build_actor(self, actor):
        self._actor = actor(self._config, self._ob_space, self._ac_space,
                            self._config.tanh_policy)

    def store_episode(self, rollouts):
        self._buffer.store_episode(rollouts)

    def state_dict(self):
        return {
            'log_alpha': self._log_alpha.cpu().detach().numpy(),
            'actor_state_dict': self._actor.state_dict(),
            'critic1_state_dict': self._critic1.state_dict(),
            'critic2_state_dict': self._critic2.state_dict(),
            'alpha_optim_state_dict': self._alpha_optim.state_dict(),
            'actor_optim_state_dict': self._actor_optim.state_dict(),
            'critic1_optim_state_dict': self._critic1_optim.state_dict(),
            'critic2_optim_state_dict': self._critic2_optim.state_dict(),
            'ob_norm_state_dict': self._ob_norm.state_dict(),
        }

    def load_state_dict(self, ckpt):
        self._log_alpha.data = torch.tensor(ckpt['log_alpha'],
                                            requires_grad=True,
                                            device=self._config.device)
        self._actor.load_state_dict(ckpt['actor_state_dict'])
        self._critic1.load_state_dict(ckpt['critic1_state_dict'])
        self._critic2.load_state_dict(ckpt['critic2_state_dict'])
        self._critic1_target.load_state_dict(self._critic1.state_dict())
        self._critic2_target.load_state_dict(self._critic2.state_dict())
        self._ob_norm.load_state_dict(ckpt['ob_norm_state_dict'])
        self._network_cuda(self._config.device)

        self._alpha_optim.load_state_dict(ckpt['alpha_optim_state_dict'])
        self._actor_optim.load_state_dict(ckpt['actor_optim_state_dict'])
        self._critic1_optim.load_state_dict(ckpt['critic1_optim_state_dict'])
        self._critic2_optim.load_state_dict(ckpt['critic2_optim_state_dict'])
        optimizer_cuda(self._alpha_optim, self._config.device)
        optimizer_cuda(self._actor_optim, self._config.device)
        optimizer_cuda(self._critic1_optim, self._config.device)
        optimizer_cuda(self._critic2_optim, self._config.device)

    def _network_cuda(self, device):
        self._actor.to(device)
        self._critic1.to(device)
        self._critic2.to(device)
        self._critic1_target.to(device)
        self._critic2_target.to(device)

    def sync_networks(self):
        sync_networks(self._actor)
        sync_networks(self._critic1)
        sync_networks(self._critic2)

    def train(self):
        for _ in range(self._config.num_batches):
            transitions = self._buffer.sample(self._config.batch_size)
            train_info = self._update_network(transitions)
            self._soft_update_target_network(self._critic1_target,
                                             self._critic1,
                                             self._config.polyak)
            self._soft_update_target_network(self._critic2_target,
                                             self._critic2,
                                             self._config.polyak)

        train_info.update({
            'actor_grad_norm':
            np.mean(compute_gradient_norm(self._actor)),
            'actor_weight_norm':
            np.mean(compute_weight_norm(self._actor)),
            'critic1_grad_norm':
            compute_gradient_norm(self._critic1),
            'critic2_grad_norm':
            compute_gradient_norm(self._critic2),
            'critic1_weight_norm':
            compute_weight_norm(self._critic1),
            'critic2_weight_norm':
            compute_weight_norm(self._critic2),
        })
        return train_info

    def act_log(self, ob):
        return self._actor.act_log(ob)

    def _update_network(self, transitions):
        info = {}

        # pre-process observations
        o, o_next = transitions['ob'], transitions['ob_next']
        o = self.normalize(o)
        o_next = self.normalize(o_next)

        bs = len(transitions['done'])
        _to_tensor = lambda x: to_tensor(x, self._config.device)
        o = _to_tensor(o)
        o_next = _to_tensor(o_next)
        ac = _to_tensor(transitions['ac'])
        done = _to_tensor(transitions['done']).reshape(bs, 1)
        rew = _to_tensor(transitions['rew']).reshape(bs, 1)

        # update alpha
        actions_real, log_pi = self.act_log(o)
        alpha_loss = -(self._log_alpha *
                       (log_pi + self._target_entropy).detach()).mean()
        self._alpha_optim.zero_grad()
        alpha_loss.backward()
        self._alpha_optim.step()
        alpha = self._log_alpha.exp()
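        # Temperature note: the update above tunes alpha so the policy entropy
        # tracks the target entropy (set as -ac_space.size in __init__):
        # log_alpha is pushed up when the entropy -E[log_pi] falls below the
        # target and pushed down when it exceeds it.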

        # the actor loss
        entropy_loss = (alpha * log_pi).mean()
        actor_loss = -torch.min(self._critic1(o, actions_real),
                                self._critic2(o, actions_real)).mean()
        info['entropy_alpha'] = alpha.cpu().item()
        info['entropy_loss'] = entropy_loss.cpu().item()
        info['actor_loss'] = actor_loss.cpu().item()
        actor_loss += entropy_loss

        # calculate the target Q value function
        with torch.no_grad():
            actions_next, log_pi_next = self.act_log(o_next)
            q_next_value1 = self._critic1_target(o_next, actions_next)
            q_next_value2 = self._critic2_target(o_next, actions_next)
            q_next_value = torch.min(q_next_value1,
                                     q_next_value2) - alpha * log_pi_next
            target_q_value = rew * self._config.reward_scale + \
                (1 - done) * self._config.discount_factor * q_next_value
            target_q_value = target_q_value.detach()
            # clip the target Q value
            clip_return = 10 / (1 - self._config.discount_factor)
            target_q_value = torch.clamp(target_q_value, -clip_return,
                                         clip_return)
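            # In summary, the target used above is the soft Bellman backup
            #   y = reward_scale * r + (1 - done) * gamma * (min(Q1', Q2') - alpha * log pi(a'|s')),
            # clamped to +-10 / (1 - gamma) to keep the critic targets bounded.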

        # the q loss
        real_q_value1 = self._critic1(o, ac)
        real_q_value2 = self._critic2(o, ac)
        critic1_loss = 0.5 * (target_q_value - real_q_value1).pow(2).mean()
        critic2_loss = 0.5 * (target_q_value - real_q_value2).pow(2).mean()

        info['min_target_q'] = target_q_value.min().cpu().item()
        info['target_q'] = target_q_value.mean().cpu().item()
        info['min_real1_q'] = real_q_value1.min().cpu().item()
        info['min_real2_q'] = real_q_value2.min().cpu().item()
        info['real1_q'] = real_q_value1.mean().cpu().item()
        info['real2_q'] = real_q_value2.mean().cpu().item()
        info['critic1_loss'] = critic1_loss.cpu().item()
        info['critic2_loss'] = critic2_loss.cpu().item()

        # update the actor
        self._actor_optim.zero_grad()
        actor_loss.backward()
        #torch.nn.utils.clip_grad_norm_(self._actor.parameters(), self._config.max_grad_norm)
        sync_grads(self._actor)
        self._actor_optim.step()

        # update the critic
        self._critic1_optim.zero_grad()
        critic1_loss.backward()
        #torch.nn.utils.clip_grad_norm_(self._critic1.parameters(), self._config.max_grad_norm)
        sync_grads(self._critic1)
        self._critic1_optim.step()

        self._critic2_optim.zero_grad()
        critic2_loss.backward()
        #torch.nn.utils.clip_grad_norm_(self._critic2.parameters(), self._config.max_grad_norm)
        sync_grads(self._critic2)
        self._critic2_optim.step()

        # include info from policy
        info.update(self._actor.info)
        return mpi_average(info)
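# The soft target update called in SACAgent.train() above is provided by the
# base agent; the sketch below shows the usual polyak form, assuming the
# convention target <- polyak * target + (1 - polyak) * source with
# polyak = config.polyak (the helper name and convention are assumptions, not
# taken from this file).
def soft_update_target_sketch(target_net, source_net, polyak):
    for t_param, s_param in zip(target_net.parameters(),
                                source_net.parameters()):
        t_param.data.copy_(polyak * t_param.data +
                           (1.0 - polyak) * s_param.data)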
Example #13
class DQNAgent(BaseAgent):
    def __init__(self, config, ob_space, ac_space, dqn):
        super().__init__(config, ob_space)

        self._ob_space = ob_space
        self._ac_space = ac_space
        # build up networks
        self._dqn = dqn(config, ob_space, ac_space)
        self._network_cuda(config.device)

        self._dqn_optim = optim.Adam(self._dqn.parameters(),
                                     lr=config.lr_actor)
        sampler = RandomSampler()
        self._buffer = ReplayBuffer(config.buffer_size, sampler.sample_func,
                                    ob_space, ac_space)

    def _log_creation(self):
        logger.info("Creating a DQN agent")
        logger.info("The DQN has %d parameters".format(
            count_parameters(self._dqn)))

    def store_episode(self, rollouts):
        self._buffer.store_episode(rollouts)

    def store_sample(self, rollouts):
        self._buffer.store_sample(rollouts)

    def _network_cuda(self, device):
        self._dqn.to(device)

    def state_dict(self):
        return {
            'dqn_state_dict': self._dqn.state_dict(),
            'dqn_optim_state_dict': self._dqn_optim.state_dict(),
        }

    def load_state_dict(self, ckpt):
        self._dqn.load_state_dict(ckpt['dqn_state_dict'])

        self._network_cuda(self._config.device)
        self._dqn_optim.load_state_dict(ckpt['dqn_optim_state_dict'])
        optimizer_cuda(self._dqn_optim, self._config.device)

    def train(self):
        for _ in range(self._config.num_batches):
            transitions = self._buffer.sample(self._config.batch_size)
            train_info = self._update_network(transitions)

        return train_info

    def act_log(self, o):
        raise NotImplementedError

    def act(self, o):
        o = to_tensor(o, self._config.device)
        q_value = self._dqn(o)
        action = OrderedDict([('default', q_value.max(1)[1].item())])
        return action, None

    def _update_network(self, transitions):
        info = {}

        # pre-process observations
        o, o_next = transitions['ob'], transitions['ob_next']

        bs = len(transitions['done'])
        _to_tensor = lambda x: to_tensor(x, self._config.device)
        o = _to_tensor(o)
        o_next = _to_tensor(o_next)
        ac = _to_tensor(transitions['ac'])
        ac = ac['default'].to(torch.long)

        done = _to_tensor(transitions['done']).reshape(bs, 1)
        rew = _to_tensor(transitions['rew']).reshape(bs, 1)

        with torch.no_grad():
            q_next_values = self._dqn(o_next)
            q_next_value = q_next_values.max(1, keepdim=True)[0]
            target_q_value = rew + \
                (1 - done) * self._config.discount_factor * q_next_value
            target_q_value = target_q_value.detach()

        q_values = self._dqn(o)
        q_value = q_values.gather(1, ac[:, 0].unsqueeze(1))
        info['target_q'] = target_q_value.mean().cpu().item()
        info['real_q'] = q_value.mean().cpu().item()
        loss = (q_value - target_q_value).pow(2).mean()
        self._dqn_optim.zero_grad()
        loss.backward()
        self._dqn_optim.step()
        return info
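# A minimal sketch of the TD-target / gather pattern used in the update above,
# assuming `q_net` maps a batch of observations to per-action Q-values, `ac`
# holds integer action indices of shape (batch, 1), and `rew`/`done` are
# (batch, 1) tensors; names and the 0.99 default are illustrative, not taken
# from this repository.
import torch


def dqn_td_loss(q_net, o, ac, rew, done, o_next, gamma=0.99):
    with torch.no_grad():
        q_next = q_net(o_next).max(1, keepdim=True)[0]  # max_a Q(s', a)
        target = rew + (1.0 - done) * gamma * q_next  # r + gamma * max_a Q(s', a)
    q_value = q_net(o).gather(1, ac)  # Q(s, a) for the taken actions
    return (q_value - target).pow(2).mean()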
Example #14
    def __init__(
        self,
        config,
        ob_space,
        ac_space,
        actor,
        critic,
        non_limited_idx=None,
        ref_joint_pos_indexes=None,
        joint_space=None,
        is_jnt_limited=None,
        jnt_indices=None,
    ):
        super().__init__(config, ob_space)

        self._ob_space = ob_space
        self._ac_space = ac_space
        self._jnt_indices = jnt_indices
        self._ref_joint_pos_indexes = ref_joint_pos_indexes
        self._log_alpha = torch.tensor(np.log(config.alpha),
                                       requires_grad=True,
                                       device=config.device)
        self._alpha_optim = optim.Adam([self._log_alpha], lr=config.lr_actor)
        self._joint_space = joint_space
        self._is_jnt_limited = is_jnt_limited
        if joint_space is not None:
            self._jnt_minimum = joint_space["default"].low
            self._jnt_maximum = joint_space["default"].high

        # build up networks
        self._build_actor(actor)
        self._build_critic(critic)
        self._network_cuda(config.device)

        self._target_entropy = -action_size(self._actor._ac_space)

        self._actor_optim = optim.Adam(self._actor.parameters(),
                                       lr=config.lr_actor)
        self._critic1_optim = optim.Adam(self._critic1.parameters(),
                                         lr=config.lr_critic)
        self._critic2_optim = optim.Adam(self._critic2.parameters(),
                                         lr=config.lr_critic)

        sampler = RandomSampler()
        buffer_keys = ["ob", "ac", "meta_ac", "done", "rew"]
        if config.mopa or config.expand_ac_space:
            buffer_keys.append("intra_steps")
        self._buffer = ReplayBuffer(buffer_keys, config.buffer_size,
                                    sampler.sample_func)

        self._log_creation()

        self._planner = None
        self._is_planner_initialized = False
        if config.mopa:
            self._planner = PlannerAgent(
                config,
                ac_space,
                non_limited_idx,
                planner_type=config.planner_type,
                passive_joint_idx=config.passive_joint_idx,
                ignored_contacts=config.ignored_contact_geom_ids,
                is_simplified=config.is_simplified,
                simplified_duration=config.simplified_duration,
                range_=config.range,
            )
            self._simple_planner = PlannerAgent(
                config,
                ac_space,
                non_limited_idx,
                planner_type=config.simple_planner_type,
                passive_joint_idx=config.passive_joint_idx,
                ignored_contacts=config.ignored_contact_geom_ids,
                goal_bias=1.0,
                is_simplified=config.simple_planner_simplified,
                simplified_duration=config.simple_planner_simplified_duration,
                range_=config.simple_planner_range,
            )
            self._omega = config.omega
Example #15
class SACAgent(BaseAgent):
    def __init__(
        self,
        config,
        ob_space,
        ac_space,
        actor,
        critic,
        non_limited_idx=None,
        ref_joint_pos_indexes=None,
        joint_space=None,
        is_jnt_limited=None,
        jnt_indices=None,
    ):
        super().__init__(config, ob_space)

        self._ob_space = ob_space
        self._ac_space = ac_space
        self._jnt_indices = jnt_indices
        self._ref_joint_pos_indexes = ref_joint_pos_indexes
        self._log_alpha = torch.tensor(np.log(config.alpha),
                                       requires_grad=True,
                                       device=config.device)
        self._alpha_optim = optim.Adam([self._log_alpha], lr=config.lr_actor)
        self._joint_space = joint_space
        self._is_jnt_limited = is_jnt_limited
        if joint_space is not None:
            self._jnt_minimum = joint_space["default"].low
            self._jnt_maximum = joint_space["default"].high

        # build up networks
        self._build_actor(actor)
        self._build_critic(critic)
        self._network_cuda(config.device)

        self._target_entropy = -action_size(self._actor._ac_space)

        self._actor_optim = optim.Adam(self._actor.parameters(),
                                       lr=config.lr_actor)
        self._critic1_optim = optim.Adam(self._critic1.parameters(),
                                         lr=config.lr_critic)
        self._critic2_optim = optim.Adam(self._critic2.parameters(),
                                         lr=config.lr_critic)

        sampler = RandomSampler()
        buffer_keys = ["ob", "ac", "meta_ac", "done", "rew"]
        if config.mopa or config.expand_ac_space:
            buffer_keys.append("intra_steps")
        self._buffer = ReplayBuffer(buffer_keys, config.buffer_size,
                                    sampler.sample_func)

        self._log_creation()

        self._planner = None
        self._is_planner_initialized = False
        if config.mopa:
            self._planner = PlannerAgent(
                config,
                ac_space,
                non_limited_idx,
                planner_type=config.planner_type,
                passive_joint_idx=config.passive_joint_idx,
                ignored_contacts=config.ignored_contact_geom_ids,
                is_simplified=config.is_simplified,
                simplified_duration=config.simplified_duration,
                range_=config.range,
            )
            self._simple_planner = PlannerAgent(
                config,
                ac_space,
                non_limited_idx,
                planner_type=config.simple_planner_type,
                passive_joint_idx=config.passive_joint_idx,
                ignored_contacts=config.ignored_contact_geom_ids,
                goal_bias=1.0,
                is_simplified=config.simple_planner_simplified,
                simplified_duration=config.simple_planner_simplified_duration,
                range_=config.simple_planner_range,
            )
            self._omega = config.omega

    def _log_creation(self):
        if self._config.is_chef:
            logger.info("creating a sac agent")
            logger.info("the actor has %d parameters",
                        count_parameters(self._actor))
            logger.info("the critic1 has %d parameters",
                        count_parameters(self._critic1))
            logger.info("the critic2 has %d parameters",
                        count_parameters(self._critic2))

    def _build_actor(self, actor):
        self._actor = actor(
            self._config,
            self._ob_space,
            self._ac_space,
            self._config.tanh_policy,
        )

    def _build_critic(self, critic):
        config = self._config
        self._critic1 = critic(config, self._ob_space, self._ac_space)
        self._critic2 = critic(config, self._ob_space, self._ac_space)

        # build up target networks
        self._critic1_target = critic(config, self._ob_space, self._ac_space)
        self._critic2_target = critic(config, self._ob_space, self._ac_space)
        self._critic1_target.load_state_dict(self._critic1.state_dict())
        self._critic2_target.load_state_dict(self._critic2.state_dict())

    def store_episode(self, rollouts):
        self._buffer.store_episode(rollouts)

    def valid_action(self, ac):
        return np.all(ac["default"] >= -1.0) and np.all(ac["default"] <= 1.0)

    def is_planner_ac(self, ac):
        joint_ac = ac["default"][:len(self._ref_joint_pos_indexes)]
        return bool(
            np.any(joint_ac < -self._omega) or np.any(joint_ac > self._omega))

    def isValidState(self, state):
        return self._planner.isValidState(state)

    def convert2planner_displacement(self, ac, ac_scale):
        ac_space_type = self._config.ac_space_type
        action_range = self._config.action_range
        if ac_space_type == "normal":
            return ac * action_range
        elif ac_space_type == "piecewise":
            return np.where(
                np.abs(ac) < self._omega,
                ac / (self._omega / ac_scale),
                np.sign(ac) * (ac_scale + (action_range - ac_scale) *
                               ((np.abs(ac) - self._omega) /
                                (1 - self._omega))),
            )
        else:
            raise NotImplementedError
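    # Note on the "piecewise" mapping above: raw actions with |ac| < omega are
    # rescaled into the direct-control range [-ac_scale, ac_scale], while the
    # remaining tail (omega <= |ac| <= 1) is mapped linearly onto
    # [ac_scale, action_range], i.e. the displacement handed to the planner;
    # invert_displacement below applies the inverse of this map.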

    def invert_displacement(self, displacement, ac_scale):
        ac_space_type = self._config.ac_space_type
        action_range = self._config.action_range
        if ac_space_type == "normal":
            return displacement / action_range
        elif ac_space_type == "piecewise":
            return np.where(
                np.abs(displacement) < ac_scale,
                displacement * (self._omega / ac_scale),
                np.sign(displacement) *
                ((np.abs(displacement) - ac_scale) /
                 ((action_range - ac_scale) / (1.0 - ac_scale)) /
                 ((1.0 - ac_scale) / (1.0 - self._omega)) + self._omega),
            )
        else:
            raise NotImplementedError

    # Calls motion planner to plan a path
    def plan(
        self,
        curr_qpos,
        target_qpos,
        ac_scale=None,
    ):

        curr_qpos = self.clip_qpos(curr_qpos)
        interpolation = True
        traj, success, valid, exact = self.simple_interpolate(
            curr_qpos, target_qpos, ac_scale)
        if not success:
            if not exact:
                traj, success, valid, exact = self._planner.plan(
                    curr_qpos, target_qpos, self._config.timelimit)
                interpolation = False
                if self._config.interpolation and success:
                    new_traj = []
                    start = curr_qpos
                    for i in range(len(traj)):
                        diff = traj[i] - start
                        if np.any(diff[:len(self._ref_joint_pos_indexes)] <
                                  -ac_scale) or np.any(
                                      diff[:len(self._ref_joint_pos_indexes)] >
                                      ac_scale):
                            inner_traj, _, _, _ = self.simple_interpolate(
                                start, traj[i], ac_scale, use_planner=True)
                            new_traj.extend(inner_traj)
                        else:
                            new_traj.append(traj[i])
                        start = traj[i]
                    traj = np.array(new_traj)

        return traj, success, interpolation, valid, exact
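    # plan() first tries simple_interpolate() to reach target_qpos with bounded
    # per-step joint displacements; if that fails, it falls back to the motion
    # planner, and (when config.interpolation is set) re-interpolates any
    # planner waypoints whose joint displacement exceeds ac_scale.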

    def clip_qpos(self, curr_qpos):
        tmp_pos = curr_qpos.copy()
        limited = self._is_jnt_limited[self._jnt_indices]
        jnt_min = self._jnt_minimum[self._jnt_indices]
        jnt_max = self._jnt_maximum[self._jnt_indices]
        if np.any(curr_qpos[limited] < jnt_min[limited]) or np.any(
                curr_qpos[limited] > jnt_max[limited]):
            new_curr_qpos = np.clip(
                curr_qpos.copy(),
                jnt_min + self._config.joint_margin,
                jnt_max - self._config.joint_margin,
            )
            new_curr_qpos[np.invert(limited)] = tmp_pos[np.invert(limited)]
            curr_qpos = new_curr_qpos
        return curr_qpos

    # interpolation function
    def simple_interpolate(self,
                           curr_qpos,
                           target_qpos,
                           ac_scale,
                           use_planner=False):
        success = True
        exact = True
        curr_qpos = self.clip_qpos(curr_qpos)

        traj = []
        min_action = self._ac_space["default"].low[0] * ac_scale * 0.8
        max_action = self._ac_space["default"].high[0] * ac_scale * 0.8
        assert max_action > min_action, "action space box is ill defined"
        assert (max_action > 0 and min_action < 0
                ), "action space MAY be ill defined. Check this assertion"

        diff = (target_qpos[:len(self._ref_joint_pos_indexes)] -
                curr_qpos[:len(self._ref_joint_pos_indexes)])
        out_of_bounds = np.where((diff > max_action) | (diff < min_action))[0]
        out_diff = diff[out_of_bounds]

        scales = np.where(out_diff > max_action, out_diff / max_action,
                          out_diff / min_action)
        if len(scales) == 0:
            scaling_factor = 1.0
        else:
            scaling_factor = max(max(scales), 1.0)
        scaled_ac = diff[:len(self._ref_joint_pos_indexes)] / scaling_factor

        valid = True
        interp_qpos = curr_qpos.copy()
        for i in range(int(scaling_factor)):
            interp_qpos[:len(self._ref_joint_pos_indexes)] += scaled_ac
            if not self._planner.isValidState(interp_qpos):
                valid = False
                break
            traj.append(interp_qpos.copy())

        if not valid and use_planner:
            traj, success, valid, exact = self._simple_planner.plan(
                curr_qpos, target_qpos, self._config.simple_planner_timelimit)
            if not success:
                traj, success, valid, exact = self._planner.plan(
                    curr_qpos, target_qpos, self._config.timelimit)
                if not success:
                    traj = [target_qpos]
                    success = False
                    exact = False
        else:
            if not valid:
                success = False
                exact = False
            traj.append(target_qpos)

        return np.array(traj), success, valid, exact
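    # A small numeric illustration of the step splitting above (assumed values):
    # with max_action = 0.2 and a joint displacement of 0.5, the per-joint scale
    # is 0.5 / 0.2 = 2.5, so scaling_factor = 2.5, each interpolation step moves
    # 0.5 / 2.5 = 0.2, int(2.5) = 2 intermediate states are validity-checked,
    # and target_qpos is appended as the final waypoint.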

    def interpolate_ac(self, ac, ac_scale, diff):
        out_of_bounds = np.where((diff > ac_scale) | (diff < -ac_scale))[0]
        out_diff = diff[out_of_bounds]
        scales = np.where(out_diff > ac_scale, out_diff / ac_scale,
                          out_diff / (-ac_scale))
        if len(scales) == 0:
            scaling_factor = 1.0
        else:
            scaling_factor = max(max(scales), 1.0)
        scaled_ac = diff[:len(self._ref_joint_pos_indexes)] / scaling_factor
        actions = []
        for j in range(ceil(scaling_factor)):
            inter_ac = copy.deepcopy(ac)
            if j < int(scaling_factor):
                inter_ac["default"][self._ref_joint_pos_indexes] = \
                    scaled_ac / ac_scale
            else:
                inter_ac["default"][self._ref_joint_pos_indexes] -= \
                    scaled_ac * int(scaling_factor)

            actions.append(inter_ac)
        return actions

    def state_dict(self):
        return {
            "log_alpha": self._log_alpha.cpu().detach().numpy(),
            "actor_state_dict": self._actor.state_dict(),
            "critic1_state_dict": self._critic1.state_dict(),
            "critic2_state_dict": self._critic2.state_dict(),
            "alpha_optim_state_dict": self._alpha_optim.state_dict(),
            "actor_optim_state_dict": self._actor_optim.state_dict(),
            "critic1_optim_state_dict": self._critic1_optim.state_dict(),
            "critic2_optim_state_dict": self._critic2_optim.state_dict(),
        }

    def load_state_dict(self, ckpt):
        self._log_alpha.data = torch.tensor(ckpt["log_alpha"],
                                            requires_grad=True,
                                            device=self._config.device)
        self._actor.load_state_dict(ckpt["actor_state_dict"])
        self._critic1.load_state_dict(ckpt["critic1_state_dict"])
        self._critic2.load_state_dict(ckpt["critic2_state_dict"])

        self._critic1_target.load_state_dict(self._critic1.state_dict())
        self._critic2_target.load_state_dict(self._critic2.state_dict())

        self._network_cuda(self._config.device)

        self._alpha_optim.load_state_dict(ckpt["alpha_optim_state_dict"])
        self._actor_optim.load_state_dict(ckpt["actor_optim_state_dict"])
        self._critic1_optim.load_state_dict(ckpt["critic1_optim_state_dict"])
        self._critic2_optim.load_state_dict(ckpt["critic2_optim_state_dict"])

        optimizer_cuda(self._alpha_optim, self._config.device)
        optimizer_cuda(self._actor_optim, self._config.device)
        optimizer_cuda(self._critic1_optim, self._config.device)
        optimizer_cuda(self._critic2_optim, self._config.device)

    def _network_cuda(self, device):
        self._actor.to(device)
        self._critic1.to(device)
        self._critic2.to(device)
        self._critic1_target.to(device)
        self._critic2_target.to(device)

    def sync_networks(self):
        if self._config.is_mpi:
            sync_networks(self._actor)
            sync_networks(self._critic1)
            sync_networks(self._critic2)

    def train(self):
        for i in range(self._config.num_batches):
            transitions = self._buffer.sample(self._config.batch_size)
            train_info = self._update_network(transitions, i)
            self._soft_update_target_network(self._critic1_target,
                                             self._critic1,
                                             self._config.polyak)
            self._soft_update_target_network(self._critic2_target,
                                             self._critic2,
                                             self._config.polyak)
        return train_info

    def act_log(self, ob):
        return self._actor.act_log(ob)

    def _update_network(self, transitions, step=0):
        info = {}

        # pre-process observations
        _to_tensor = lambda x: to_tensor(x, self._config.device)
        o, o_next = transitions["ob"], transitions["ob_next"]
        bs = len(transitions["done"])
        o = _to_tensor(o)
        o_next = _to_tensor(o_next)
        ac = _to_tensor(transitions["ac"])

        if "intra_steps" in transitions.keys(
        ) and self._config.use_smdp_update:
            intra_steps = _to_tensor(transitions["intra_steps"])

        done = _to_tensor(transitions["done"]).reshape(bs, 1)
        rew = _to_tensor(transitions["rew"]).reshape(bs, 1)

        actions_real, log_pi = self.act_log(o)
        alpha_loss = -(self._log_alpha.exp() *
                       (log_pi + self._target_entropy).detach()).mean()

        self._alpha_optim.zero_grad()
        alpha_loss.backward()
        self._alpha_optim.step()
        alpha = self._log_alpha.exp()
        info["alpha_loss"] = alpha_loss.cpu().item()
        info["entropy_alpha"] = alpha.cpu().item()

        # the actor loss
        entropy_loss = (alpha * log_pi).mean()
        actor_loss = -torch.min(self._critic1(o, actions_real),
                                self._critic2(o, actions_real)).mean()
        info["log_pi"] = log_pi.mean().cpu().item()
        info["entropy_loss"] = entropy_loss.cpu().item()
        info["actor_loss"] = actor_loss.cpu().item()
        actor_loss += entropy_loss

        # calculate the target Q value function
        with torch.no_grad():
            actions_next, log_pi_next = self.act_log(o_next)
            q_next_value1 = self._critic1_target(o_next, actions_next)
            q_next_value2 = self._critic2_target(o_next, actions_next)
            q_next_value = torch.min(q_next_value1,
                                     q_next_value2) - alpha * log_pi_next
            if self._config.use_smdp_update:
                target_q_value = (self._config.reward_scale * rew +
                                  (1 - done) *
                                  (self._config.discount_factor**
                                   (intra_steps + 1)) * q_next_value)
            else:
                target_q_value = (
                    self._config.reward_scale * rew +
                    (1 - done) * self._config.discount_factor * q_next_value)
            target_q_value = target_q_value.detach()
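            # SMDP discounting: when a planner rollout consumed intra_steps
            # low-level steps, the next state is treated as arriving
            # intra_steps + 1 steps later, so the bootstrap term is discounted
            # by gamma ** (intra_steps + 1) instead of a single gamma.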

        # the q loss
        for k, space in self._ac_space.spaces.items():
            if isinstance(space, spaces.Discrete):
                ac[k] = (F.one_hot(ac[k].long(), action_size(
                    self._ac_space[k])).float().squeeze(1))
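        # Discrete sub-actions are one-hot encoded above so that the critics,
        # which expect a fixed-size continuous action vector, can consume them.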
        real_q_value1 = self._critic1(o, ac)
        real_q_value2 = self._critic2(o, ac)
        critic1_loss = 0.5 * (target_q_value - real_q_value1).pow(2).mean()
        critic2_loss = 0.5 * (target_q_value - real_q_value2).pow(2).mean()

        info["min_target_q"] = target_q_value.min().cpu().item()
        info["target_q"] = target_q_value.mean().cpu().item()
        info["min_real1_q"] = real_q_value1.min().cpu().item()
        info["min_real2_q"] = real_q_value2.min().cpu().item()
        info["real1_q"] = real_q_value1.mean().cpu().item()
        info["real2_q"] = real_q_value2.mean().cpu().item()
        info["critic1_loss"] = critic1_loss.cpu().item()
        info["critic2_loss"] = critic2_loss.cpu().item()

        # update the actor
        self._actor_optim.zero_grad()
        actor_loss.backward()
        if self._config.is_mpi:
            sync_grads(self._actor)
        self._actor_optim.step()

        # update the critic
        self._critic1_optim.zero_grad()
        critic1_loss.backward()
        if self._config.is_mpi:
            sync_grads(self._critic1)
        self._critic1_optim.step()

        self._critic2_optim.zero_grad()
        critic2_loss.backward()
        if self._config.is_mpi:
            sync_grads(self._critic2)
        self._critic2_optim.step()

        if self._config.is_mpi:
            return mpi_average(info)
        else:
            return info