Example #1
    def _update_network(self, transitions):
        info = {}

        # pre-process observations
        o, o_next = transitions['ob'], transitions['ob_next']

        bs = len(transitions['done'])
        _to_tensor = lambda x: to_tensor(x, self._config.device)
        o = _to_tensor(o)
        o_next = _to_tensor(o_next)
        ac = _to_tensor(transitions['ac'])
        ac = ac['default'].to(torch.long)

        done = _to_tensor(transitions['done']).reshape(bs, 1)
        rew = _to_tensor(transitions['rew']).reshape(bs, 1)

        with torch.no_grad():
            # bootstrap the target from the next observation
            q_next_values = self._dqn(o_next)
            q_next_value = q_next_values.max(1, keepdim=True)[0]
            target_q_value = rew + \
                (1 - done) * self._config.discount_factor * q_next_value
            target_q_value = target_q_value.detach()

        q_values = self._dqn(o)
        q_value = q_values.gather(1, ac[:, 0].unsqueeze(1))
        info['target_q'] = target_q_value.mean().cpu().item()
        info['real_q'] = q_value.mean().cpu().item()
        loss = (q_value - target_q_value).pow(2).mean()
        self._dqn_optim.zero_grad()
        loss.backward()
        self._dqn_optim.step()
        return info
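For orientation, here is a hypothetical sketch of the transition batch that this DQN update consumes. The key names ('ob', 'ob_next', 'ac', 'done', 'rew') and the nested 'default' action key are taken from the snippet above; the shapes, dimensions, and the assumption that to_tensor maps nested dicts of arrays to tensors are illustrative guesses, since the replay buffer itself is not shown.

import numpy as np

batch_size, ob_dim, n_actions = 32, 8, 4  # made-up sizes
transitions = {
    'ob': np.random.randn(batch_size, ob_dim).astype(np.float32),
    'ob_next': np.random.randn(batch_size, ob_dim).astype(np.float32),
    # one discrete action index per sample, stored under the 'default' key
    'ac': {'default': np.random.randint(n_actions, size=(batch_size, 1))},
    'done': np.zeros(batch_size, dtype=np.float32),
    'rew': np.random.randn(batch_size).astype(np.float32),
}
# info = agent._update_network(transitions)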
Example #2
    def _update_network(self, transitions):
        info = {}

        # pre-process observations
        o, o_next = transitions['ob'], transitions['ob_next']

        bs = len(transitions['done'])
        _to_tensor = lambda x: to_tensor(x, self._config.device)
        o = _to_tensor(o)
        o_next = _to_tensor(o_next)
        ac = _to_tensor(transitions['ac'])

        done = _to_tensor(transitions['done']).reshape(bs, 1)
        rew = _to_tensor(transitions['rew']).reshape(bs, 1)
        # update critic, actor, and alpha
        critic_info = self._update_critic(o, ac, rew, o_next, done)
        info.update(critic_info)
        actor_alpha_info = self._update_actor_and_alpha(o)
        info.update(actor_alpha_info)

        if self._config.policy == 'cnn' and self._config.unsup_algo == 'curl':
            cpc_kwargs = transitions['cpc_kwargs']
            o_anchor = _to_tensor(cpc_kwargs['ob_anchor'])
            o_pos = _to_tensor(cpc_kwargs['ob_pos'])
            cpc_info = self._update_cpc(o_anchor, o_pos, cpc_kwargs)
            info.update(cpc_info)

        return info
Example #3
    def act(self, ob, is_train=True, return_log_prob=False):
        ob = to_tensor(ob, self._config.device)
        means, stds = self.forward(ob, self._deterministic)

        dists = OrderedDict()
        for k, space in self._ac_space.spaces.items():
            if isinstance(space, spaces.Box):
                if self._deterministic:
                    stds[k] = torch.zeros_like(means[k])
                dists[k] = FixedNormal(means[k], stds[k])
            else:
                if self._config.meta_algo == 'sac' or self._config.algo == 'sac':
                    dists[k] = FixedGumbelSoftmax(torch.tensor(
                        self._config.temperature),
                                                  logits=means[k])
                else:
                    dists[k] = FixedCategorical(logits=means[k])

        actions = OrderedDict()
        mixed_dist = MixedDistribution(dists)
        if not is_train or self._deterministic:
            activations = mixed_dist.mode()
        else:
            activations = mixed_dist.sample()

        if return_log_prob:
            log_probs = mixed_dist.log_probs(activations)

        for k, space in self._ac_space.spaces.items():
            z = activations[k]
            if self._tanh and isinstance(space, spaces.Box):
                # action_scale = to_tensor((self._ac_space[k].high), self._config.device).detach()
                # action = torch.tanh(z) * action_scale
                action = torch.tanh(z)
                if return_log_prob:
                    # follow the Appendix C. Enforcing Action Bounds
                    # log_det_jacobian = 2 * (np.log(2.) - z - F.softplus(-2. * z)).sum(dim=1, keepdim=True)
                    log_det_jacobian = 2 * (np.log(2.) - z - F.softplus(
                        -2. * z)).sum(dim=-1, keepdim=True)
                    # log_det_jacobian = torch.log((1-torch.tanh(z).pow(2))+1e-6).sum(dim=1, keepdim=True)
                    log_probs[k] = log_probs[k] - log_det_jacobian
            else:
                action = z
            if action.shape[0] == 1:
                actions[k] = action.detach().cpu().numpy().squeeze(0)
            else:
                actions[k] = action.detach().cpu().numpy()

        if return_log_prob:
            log_probs_ = torch.cat(list(log_probs.values()),
                                   -1).sum(-1, keepdim=True)
            # if log_probs_.min() < -100:
            #     print('sampling an action with a probability of 1e-100')
            #     import ipdb; ipdb.set_trace()

            log_probs_ = log_probs_.detach().cpu().numpy().squeeze(0)
            return actions, activations, log_probs_
        else:
            return actions, activations
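The log_det_jacobian term above is the change-of-variables correction for the tanh squashing referenced in the "Enforcing Action Bounds" comment: log(1 - tanh(z)^2) rewritten in the numerically stable form 2 * (log 2 - z - softplus(-2z)). A small self-contained check of that identity (the values are arbitrary and only serve to verify the formula):

import numpy as np
import torch
import torch.nn.functional as F

z = torch.linspace(-5.0, 5.0, steps=101)
stable = 2 * (np.log(2.0) - z - F.softplus(-2.0 * z))
naive = torch.log(1.0 - torch.tanh(z).pow(2) + 1e-12)
assert torch.allclose(stable, naive, atol=1e-4)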
Example #4
    def act(self, ob, is_train=True):
        ob = to_tensor(ob, self._config.device)
        ac, activation = self._actor.act(ob, is_train=is_train)
        if is_train:
            for k, space in self._ac_space.spaces.items():
                if isinstance(space, spaces.Box):
                    ac[k] += self._config.noise_scale * np.random.randn(len(ac[k]))
                    ac[k] = np.clip(ac[k], self._ac_space[k].low, self._ac_space[k].high)
        return ac, activation
Example #5
    def _update_network(self, transitions, step=0):
        config = self._config
        info = {}

        o, o_next = transitions['ob'], transitions['ob_next']
        bs = len(transitions['done'])
        _to_tensor = lambda x: to_tensor(x, config.device)
        o = _to_tensor(o)
        o_next = _to_tensor(o_next)
        ac = _to_tensor(transitions['ac'])

        done = _to_tensor(transitions['done']).reshape(bs, 1)
        rew = _to_tensor(transitions['rew']).reshape(bs, 1)

        ## Actor loss
        actions_real, _ = self.act_log(o)
        actor_loss = -self._critic(o, actions_real).mean()
        info['actor_loss'] = actor_loss.cpu().item()

        ## Critic loss
        with torch.no_grad():
            actions_next, _ = self.target_act_log(o_next)
            q_next_value = self._critic_target(o_next, actions_next)
            target_q_value = rew + (1.-done) * config.discount_factor * q_next_value
            target_q_value = target_q_value.detach()

        real_q_value = self._critic(o, ac)

        critic_loss = 0.5 * (target_q_value - real_q_value).pow(2).mean()

        info['min_target_q'] = target_q_value.min().cpu().item()
        info['target_q'] = target_q_value.mean().cpu().item()
        info['min_real1_q'] = real_q_value.min().cpu().item()
        info['real_q'] = real_q_value.mean().cpu().item()
        info['critic_loss'] = critic_loss.cpu().item()

        # update the actor
        self._actor_optim.zero_grad()
        actor_loss.backward()
        self._actor_optim.step()

        # update the critics
        self._critic_optim.zero_grad()
        critic_loss.backward()
        self._critic_optim.step()

        return info
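This update, like Examples #9, #11, and #12 below, reads from target networks (self._critic_target, self._critic1_target, ...) that must be kept in sync elsewhere. A minimal sketch of the usual Polyak (soft) target update, assuming such a step runs after each optimizer step; the helper name and tau value are illustrative, not taken from these snippets:

import torch

def soft_update(target_net, source_net, tau=0.005):
    # move each target parameter a small step toward the online network
    with torch.no_grad():
        for t_param, s_param in zip(target_net.parameters(),
                                    source_net.parameters()):
            t_param.data.mul_(1.0 - tau)
            t_param.data.add_(tau * s_param.data)

# e.g. soft_update(self._critic_target, self._critic)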
Example #6
    def act(self, ob, is_train=True, return_log_prob=False):
        ob = to_tensor(ob, self._config.device)
        self._ob = ob
        means, stds = self.forward(ob)

        dists = OrderedDict()
        for k in self._ac_space.keys():
            if self._ac_space.is_continuous(k):
                dists[k] = FixedNormal(means[k], stds[k])
            else:
                dists[k] = FixedCategorical(logits=means[k])

        actions = OrderedDict()
        mixed_dist = MixedDistribution(dists)
        if not is_train:
            activations = mixed_dist.mode()
        else:
            activations = mixed_dist.sample()

        if return_log_prob:
            log_probs = mixed_dist.log_probs(activations)

        for k in self._ac_space.keys():
            z = activations[k]
            if self._tanh and self._ac_space.is_continuous(k):
                action = torch.tanh(z)
                if return_log_prob:
                    # follow the Appendix C. Enforcing Action Bounds
                    log_det_jacobian = 2 * (np.log(2.) - z - F.softplus(-2. * z)).sum(dim=-1, keepdim=True)
                    log_probs[k] = log_probs[k] - log_det_jacobian
            else:
                action = z

            actions[k] = action.detach().cpu().numpy().squeeze(0)
            activations[k] = z.detach().cpu().numpy().squeeze(0)

        if return_log_prob:
            log_probs_ = torch.cat(list(log_probs.values()), -1).sum(-1, keepdim=True)
            if log_probs_.min() < -100:
                print('sampling an action with a probability of 1e-100')
                import ipdb; ipdb.set_trace()

            log_probs_ = log_probs_.detach().cpu().numpy().squeeze(0)
            return actions, activations, log_probs_
        else:
            return actions, activations
Example #7
    def act_log_debug(self, ob, activations=None):
        means, stds = self.forward(ob)

        dists = OrderedDict()
        actions = OrderedDict()
        for k, space in self._ac_space.spaces.items():
            if isinstance(space, spaces.Box):
                dists[k] = FixedNormal(means[k], stds[k])
            else:
                dists[k] = FixedCategorical(logits=means[k])

        mixed_dist = MixedDistribution(dists)

        activations_ = (mixed_dist.rsample()
                        if activations is None else activations)
        log_probs = mixed_dist.log_probs(activations_)

        for k, space in self._ac_space.spaces.items():
            z = activations_[k]
            if self._tanh and isinstance(space, spaces.Box):
                action = torch.tanh(z) * to_tensor(
                    (self._ac_space[k].high), self._config.device)
                # follow the Appendix C. Enforcing Action Bounds
                log_det_jacobian = 2 * (np.log(2.) - z - F.softplus(
                    -2. * z)).sum(dim=-1, keepdim=True)
                log_probs[k] = log_probs[k] - log_det_jacobian
            else:
                action = z

            actions[k] = action

        ents = mixed_dist.entropy()
        #print(torch.cat(list(log_probs.values()), -1))
        log_probs_ = torch.cat(list(log_probs.values()), -1).sum(-1,
                                                                 keepdim=True)
        if log_probs_.min() < -100:
            print(ob)
            print(log_probs_.min())
            import ipdb
            ipdb.set_trace()
        if activations is None:
            return actions, log_probs_
        else:
            return log_probs_, ents, log_probs, means, stds
Example #8
    def act(self, ob, is_train=True, return_stds=False):
        ob = to_tensor(ob, self._config.device)
        if return_stds:
            ac, activation, stds = self._actor.act(ob,
                                                   is_train=is_train,
                                                   return_stds=return_stds)
        else:
            ac, activation = self._actor.act(ob, is_train=is_train)
        if is_train:
            for k, space in self._ac_space.spaces.items():
                if isinstance(space, spaces.Box):
                    ac[k] += np.random.normal(0,
                                              self._config.action_noise,
                                              size=len(ac[k]))
                    ac[k] = np.clip(ac[k], self._ac_space[k].low,
                                    self._ac_space[k].high)
        if return_stds:
            return ac, activation, stds
        else:
            return ac, activation
Example #9
    def _update_network(self, transitions, step=0):
        config = self._config
        info = {}

        o, o_next = transitions["ob"], transitions["ob_next"]
        bs = len(transitions["done"])
        _to_tensor = lambda x: to_tensor(x, config.device)
        o = _to_tensor(o)
        o_next = _to_tensor(o_next)
        ac = _to_tensor(transitions["ac"])

        done = _to_tensor(transitions["done"]).reshape(bs, 1)
        rew = _to_tensor(transitions["rew"]).reshape(bs, 1)

        ## Actor loss
        actions_real, _ = self.act_log(o)
        actor_loss = -self._critic1(o, actions_real).mean()
        info["actor_loss"] = actor_loss.cpu().item()

        ## Critic loss
        with torch.no_grad():
            actions_next, _ = self.target_act_log(o_next)
            for k, space in self._ac_space.spaces.items():
                if isinstance(space, spaces.Box):
                    epsilon = (torch.randn_like(actions_next[k]) *
                               self._config.target_noise)
                    epsilon = torch.clamp(epsilon, -config.noise_clip,
                                          config.noise_clip)
                    actions_next[k] += epsilon
                    # clamp is not in-place; keep the clipped result
                    actions_next[k] = actions_next[k].clamp(-1.0, 1.0)
            q_next_value1 = self._critic1_target(o_next, actions_next)
            q_next_value2 = self._critic2_target(o_next, actions_next)
            q_next_value = torch.min(q_next_value1, q_next_value2)
            target_q_value = (
                rew * self._config.reward_scale +
                (1.0 - done) * config.discount_factor * q_next_value)
            target_q_value = target_q_value.detach()

        real_q_value1 = self._critic1(o, ac)
        real_q_value2 = self._critic2(o, ac)

        critic1_loss = 0.5 * (target_q_value - real_q_value1).pow(2).mean()
        critic2_loss = 0.5 * (target_q_value - real_q_value2).pow(2).mean()

        info["min_target_q"] = target_q_value.min().cpu().item()
        info["target_q"] = target_q_value.mean().cpu().item()
        info["min_real1_q"] = real_q_value1.min().cpu().item()
        info["min_real2_q"] = real_q_value2.min().cpu().item()
        info["real1_q"] = real_q_value1.mean().cpu().item()
        info["rea2_q"] = real_q_value2.mean().cpu().item()
        info["critic1_loss"] = critic1_loss.cpu().item()
        info["critic2_loss"] = critic2_loss.cpu().item()

        if self._update_steps % self._config.actor_update_freq == 0:
            # update the actor
            self._actor_optim.zero_grad()
            actor_loss.backward()
            self._actor_optim.step()

        # update the critics
        self._critic1_optim.zero_grad()
        critic1_loss.backward()
        self._critic1_optim.step()

        self._critic2_optim.zero_grad()
        critic2_loss.backward()
        self._critic2_optim.step()
        self._update_steps += 1

        return info
Example #10
    def value(self, ob):
        return self._critic(to_tensor(
            ob, self._config.device)).detach().cpu().numpy()
Example #11
    def _update_network(self, transitions, step=0):
        config = self._config
        info = {}

        o, o_next = transitions['ob'], transitions['ob_next']
        bs = len(transitions['done'])
        _to_tensor = lambda x: to_tensor(x, config.device)
        o = _to_tensor(o)
        o_next = _to_tensor(o_next)
        ac = _to_tensor(transitions['ac'])

        done = _to_tensor(transitions['done']).reshape(bs, 1)
        rew = _to_tensor(transitions['rew']).reshape(bs, 1)

        ## Actor loss
        actions_real, _ = self.act_log(o)
        actor_loss = -self._critic1(o, actions_real).mean()
        info['actor_loss'] = actor_loss.cpu().item()

        ## Critic loss
        with torch.no_grad():
            actions_next, _ = self.target_act_log(o_next)
            for k, space in self._ac_space.spaces.items():
                if isinstance(space, spaces.Box):
                    epsilon = torch.randn_like(
                        actions_next[k]) * self._config.noise_scale
                    epsilon = torch.clamp(epsilon, -config.noise_clip,
                                          config.noise_clip)
                    actions_next[k] += epsilon
            q_next_value1 = self._critic1_target(o_next, actions_next)
            q_next_value2 = self._critic2_target(o_next, actions_next)
            q_next_value = torch.min(q_next_value1, q_next_value2)
            target_q_value = rew + (
                1. - done) * config.discount_factor * q_next_value
            target_q_value = target_q_value.detach()

        real_q_value1 = self._critic1(o, ac)
        real_q_value2 = self._critic2(o, ac)

        critic1_loss = 0.5 * (target_q_value - real_q_value1).pow(2).mean()
        critic2_loss = 0.5 * (target_q_value - real_q_value2).pow(2).mean()

        info['min_target_q'] = target_q_value.min().cpu().item()
        info['target_q'] = target_q_value.mean().cpu().item()
        info['min_real1_q'] = real_q_value1.min().cpu().item()
        info['min_real2_q'] = real_q_value2.min().cpu().item()
        info['real1_q'] = real_q_value1.mean().cpu().item()
        info['real2_q'] = real_q_value2.mean().cpu().item()
        info['critic1_loss'] = critic1_loss.cpu().item()
        info['critic2_loss'] = critic2_loss.cpu().item()

        # update the actor
        self._actor_optim.zero_grad()
        actor_loss.backward()
        self._actor_optim.step()

        # update the critics
        self._critic1_optim.zero_grad()
        critic1_loss.backward()
        self._critic1_optim.step()

        self._critic2_optim.zero_grad()
        critic2_loss.backward()
        self._critic2_optim.step()

        return info
Example #12
    def _update_network(self, transitions):
        info = {}

        # pre-process observations
        o, o_next = transitions['ob'], transitions['ob_next']
        o = self.normalize(o)
        o_next = self.normalize(o_next)

        bs = len(transitions['done'])
        _to_tensor = lambda x: to_tensor(x, self._config.device)
        o = _to_tensor(o)
        o_next = _to_tensor(o_next)
        ac = _to_tensor(transitions['ac'])
        done = _to_tensor(transitions['done']).reshape(bs, 1)
        rew = _to_tensor(transitions['rew']).reshape(bs, 1)

        # update alpha
        actions_real, log_pi = self.act_log(o)
        alpha_loss = -(self._log_alpha *
                       (log_pi + self._target_entropy).detach()).mean()
        self._alpha_optim.zero_grad()
        alpha_loss.backward()
        self._alpha_optim.step()
        alpha = self._log_alpha.exp()

        # the actor loss
        entropy_loss = (alpha * log_pi).mean()
        actor_loss = -torch.min(self._critic1(o, actions_real),
                                self._critic2(o, actions_real)).mean()
        info['entropy_alpha'] = alpha.cpu().item()
        info['entropy_loss'] = entropy_loss.cpu().item()
        info['actor_loss'] = actor_loss.cpu().item()
        actor_loss += entropy_loss

        # calculate the target Q value function
        with torch.no_grad():
            actions_next, log_pi_next = self.act_log(o_next)
            q_next_value1 = self._critic1_target(o_next, actions_next)
            q_next_value2 = self._critic2_target(o_next, actions_next)
            q_next_value = torch.min(q_next_value1,
                                     q_next_value2) - alpha * log_pi_next
            target_q_value = rew * self._config.reward_scale + \
                (1 - done) * self._config.discount_factor * q_next_value
            target_q_value = target_q_value.detach()
            ## clip the q value
            clip_return = 10 / (1 - self._config.discount_factor)
            target_q_value = torch.clamp(target_q_value, -clip_return,
                                         clip_return)

        # the q loss
        real_q_value1 = self._critic1(o, ac)
        real_q_value2 = self._critic2(o, ac)
        critic1_loss = 0.5 * (target_q_value - real_q_value1).pow(2).mean()
        critic2_loss = 0.5 * (target_q_value - real_q_value2).pow(2).mean()

        info['min_target_q'] = target_q_value.min().cpu().item()
        info['target_q'] = target_q_value.mean().cpu().item()
        info['min_real1_q'] = real_q_value1.min().cpu().item()
        info['min_real2_q'] = real_q_value2.min().cpu().item()
        info['real1_q'] = real_q_value1.mean().cpu().item()
        info['real2_q'] = real_q_value2.mean().cpu().item()
        info['critic1_loss'] = critic1_loss.cpu().item()
        info['critic2_loss'] = critic2_loss.cpu().item()

        # update the actor
        self._actor_optim.zero_grad()
        actor_loss.backward()
        #torch.nn.utils.clip_grad_norm_(self._actor.parameters(), self._config.max_grad_norm)
        sync_grads(self._actor)
        self._actor_optim.step()

        # update the critic
        self._critic1_optim.zero_grad()
        critic1_loss.backward()
        #torch.nn.utils.clip_grad_norm_(self._critic1.parameters(), self._config.max_grad_norm)
        sync_grads(self._critic1)
        self._critic1_optim.step()

        self._critic2_optim.zero_grad()
        critic2_loss.backward()
        #torch.nn.utils.clip_grad_norm_(self._critic2.parameters(), self._config.max_grad_norm)
        sync_grads(self._critic2)
        self._critic2_optim.step()

        # include info from policy
        info.update(self._actor.info)
        return mpi_average(info)
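The temperature (alpha) update above assumes self._log_alpha and self._target_entropy are created elsewhere. A sketch of a common initialization; the -|A| heuristic for the target entropy and the learning rate are assumptions, not read from the snippet:

import torch

action_dim = 6  # placeholder: the environment's action dimensionality
target_entropy = -float(action_dim)
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)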
Example #13
    def act(self, o):
        o = to_tensor(o, self._config.device)
        q_value = self._dqn(o)
        action = OrderedDict([('default', q_value.max(1)[1].item())])
        return action, None
Example #14
    def _update_network(self, transitions):
        info = {}

        # pre-process observations
        o = transitions['ob']
        o = self.normalize(o)

        bs = len(transitions['done'])
        _to_tensor = lambda x: to_tensor(x, self._config.device)
        o = _to_tensor(o)
        ac = _to_tensor(transitions['ac'])
        a_z = _to_tensor(transitions['ac_before_activation'])
        ret = _to_tensor(transitions['ret']).reshape(bs, 1)
        adv = _to_tensor(transitions['adv']).reshape(bs, 1)

        log_pi, ent = self._actor.act_log(o, a_z)
        old_log_pi, _ = self._old_actor.act_log(o, a_z)
        if old_log_pi.min() < -100:
            import ipdb; ipdb.set_trace()

        # the actor loss
        entropy_loss = self._config.entropy_loss_coeff * ent.mean()
        ratio = torch.exp(log_pi - old_log_pi)
        surr1 = ratio * adv
        surr2 = torch.clamp(ratio, 1.0 - self._config.clip_param,
                            1.0 + self._config.clip_param) * adv
        actor_loss = -torch.min(surr1, surr2).mean()

        if not np.isfinite(ratio.cpu().detach()).all() or not np.isfinite(adv.cpu().detach()).all():
            import ipdb; ipdb.set_trace()
        info['entropy_loss'] = entropy_loss.cpu().item()
        info['actor_loss'] = actor_loss.cpu().item()
        actor_loss += entropy_loss

        custom_loss = self._actor.custom_loss()
        if custom_loss is not None:
            actor_loss += custom_loss * self._config.custom_loss_weight
            info['custom_loss'] = custom_loss.cpu().item()

        # the q loss
        value_pred = self._critic(o)
        value_loss = self._config.value_loss_coeff * (ret - value_pred).pow(2).mean()

        info['value_target'] = ret.mean().cpu().item()
        info['value_predicted'] = value_pred.mean().cpu().item()
        info['value_loss'] = value_loss.cpu().item()

        # update the actor
        self._actor_optim.zero_grad()
        actor_loss.backward()
        #torch.nn.utils.clip_grad_norm_(self._actor.parameters(), self._config.max_grad_norm)
        sync_grads(self._actor)
        self._actor_optim.step()

        # update the critic
        self._critic_optim.zero_grad()
        value_loss.backward()
        #torch.nn.utils.clip_grad_norm_(self._critic1.parameters(), self._config.max_grad_norm)
        sync_grads(self._critic)
        self._critic_optim.step()

        # include info from policy
        info.update(self._actor.info)

        return mpi_average(info)
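To see the PPO-style clipped surrogate from this update in isolation from the actor/critic plumbing, a self-contained sketch with made-up numbers; clip_param plays the role of self._config.clip_param:

import torch

clip_param = 0.2
log_pi = torch.tensor([[-1.0], [-2.0], [-0.5]])
old_log_pi = torch.tensor([[-1.2], [-1.0], [-0.5]])
adv = torch.tensor([[1.0], [-0.5], [2.0]])

ratio = torch.exp(log_pi - old_log_pi)                                # pi_new / pi_old
surr1 = ratio * adv
surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv
actor_loss = -torch.min(surr1, surr2).mean()                          # clipped objective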
Example #15
    def _update_network(self, transitions, step=0):
        info = {}

        # pre-process observations
        _to_tensor = lambda x: to_tensor(x, self._config.device)
        o, o_next = transitions["ob"], transitions["ob_next"]
        bs = len(transitions["done"])
        o = _to_tensor(o)
        o_next = _to_tensor(o_next)
        ac = _to_tensor(transitions["ac"])

        if "intra_steps" in transitions.keys(
        ) and self._config.use_smdp_update:
            intra_steps = _to_tensor(transitions["intra_steps"])

        done = _to_tensor(transitions["done"]).reshape(bs, 1)
        rew = _to_tensor(transitions["rew"]).reshape(bs, 1)

        actions_real, log_pi = self.act_log(o)
        alpha_loss = -(self._log_alpha.exp() *
                       (log_pi + self._target_entropy).detach()).mean()

        self._alpha_optim.zero_grad()
        alpha_loss.backward()
        self._alpha_optim.step()
        alpha = self._log_alpha.exp()
        info["alpha_loss"] = alpha_loss.cpu().item()
        info["entropy_alpha"] = alpha.cpu().item()

        # the actor loss
        entropy_loss = (alpha * log_pi).mean()
        actor_loss = -torch.min(self._critic1(o, actions_real),
                                self._critic2(o, actions_real)).mean()
        info["log_pi"] = log_pi.mean().cpu().item()
        info["entropy_loss"] = entropy_loss.cpu().item()
        info["actor_loss"] = actor_loss.cpu().item()
        actor_loss += entropy_loss

        # calculate the target Q value function
        with torch.no_grad():
            actions_next, log_pi_next = self.act_log(o_next)
            q_next_value1 = self._critic1_target(o_next, actions_next)
            q_next_value2 = self._critic2_target(o_next, actions_next)
            q_next_value = torch.min(q_next_value1,
                                     q_next_value2) - alpha * log_pi_next
            if self._config.use_smdp_update:
                target_q_value = (self._config.reward_scale * rew +
                                  (1 - done) *
                                  (self._config.discount_factor**
                                   (intra_steps + 1)) * q_next_value)
            else:
                target_q_value = (
                    self._config.reward_scale * rew +
                    (1 - done) * self._config.discount_factor * q_next_value)
            target_q_value = target_q_value.detach()

        # the q loss
        for k, space in self._ac_space.spaces.items():
            if isinstance(space, spaces.Discrete):
                ac[k] = (F.one_hot(ac[k].long(), action_size(
                    self._ac_space[k])).float().squeeze(1))
        real_q_value1 = self._critic1(o, ac)
        real_q_value2 = self._critic2(o, ac)
        critic1_loss = 0.5 * (target_q_value - real_q_value1).pow(2).mean()
        critic2_loss = 0.5 * (target_q_value - real_q_value2).pow(2).mean()

        info["min_target_q"] = target_q_value.min().cpu().item()
        info["target_q"] = target_q_value.mean().cpu().item()
        info["min_real1_q"] = real_q_value1.min().cpu().item()
        info["min_real2_q"] = real_q_value2.min().cpu().item()
        info["real1_q"] = real_q_value1.mean().cpu().item()
        info["real2_q"] = real_q_value2.mean().cpu().item()
        info["critic1_loss"] = critic1_loss.cpu().item()
        info["critic2_loss"] = critic2_loss.cpu().item()

        # update the actor
        self._actor_optim.zero_grad()
        actor_loss.backward()
        if self._config.is_mpi:
            sync_grads(self._actor)
        self._actor_optim.step()

        # update the critic
        self._critic1_optim.zero_grad()
        critic1_loss.backward()
        if self._config.is_mpi:
            sync_grads(self._critic1)
        self._critic1_optim.step()

        self._critic2_optim.zero_grad()
        critic2_loss.backward()
        if self._config.is_mpi:
            sync_grads(self._critic2)
        self._critic2_optim.step()

        if self._config.is_mpi:
            return mpi_average(info)
        else:
            return info