Example #1
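
The helper below emulates the SAC losses in plain NumPy for both the tf and
torch versions of the policy. The surrounding imports are not part of this
excerpt; a plausible set, assuming a recent RLlib release (module paths can
vary between versions), would be:

import numpy as np

from ray.rllib.models.tf.tf_action_dist import SquashedGaussian
from ray.rllib.models.torch.torch_action_dist import TorchSquashedGaussian
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.numpy import fc, relu
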
    def _sac_loss_helper(self, train_batch, weights, ks, log_alpha, fw, gamma,
                         sess):
        """Emulates SAC loss functions for tf and torch."""
        # ks:
        # 0=log_alpha
        # 1=target log-alpha (not used)

        # 2=action hidden bias
        # 3=action hidden kernel
        # 4=action out bias
        # 5=action out kernel

        # 6=Q hidden bias
        # 7=Q hidden kernel
        # 8=Q out bias
        # 9=Q out kernel

        # 14=target Q hidden bias
        # 15=target Q hidden kernel
        # 16=target Q out bias
        # 17=target Q out kernel
        alpha = np.exp(log_alpha)
        cls = TorchSquashedGaussian if fw == "torch" else SquashedGaussian
        model_out_t = train_batch[SampleBatch.CUR_OBS]
        model_out_tp1 = train_batch[SampleBatch.NEXT_OBS]
        target_model_out_tp1 = train_batch[SampleBatch.NEXT_OBS]

        # get_policy_output
        action_dist_t = cls(
            fc(
                relu(
                    fc(model_out_t,
                       weights[ks[3]],
                       weights[ks[2]],
                       framework=fw)), weights[ks[5]], weights[ks[4]]), None)
        policy_t = action_dist_t.deterministic_sample()
        log_pis_t = action_dist_t.logp(policy_t)
        if sess:
            log_pis_t = sess.run(log_pis_t)
            policy_t = sess.run(policy_t)
        log_pis_t = np.expand_dims(log_pis_t, -1)

        # Get policy output for t+1.
        action_dist_tp1 = cls(
            fc(
                relu(
                    fc(model_out_tp1,
                       weights[ks[3]],
                       weights[ks[2]],
                       framework=fw)), weights[ks[5]], weights[ks[4]]), None)
        policy_tp1 = action_dist_tp1.deterministic_sample()
        log_pis_tp1 = action_dist_tp1.logp(policy_tp1)
        if sess:
            log_pis_tp1 = sess.run(log_pis_tp1)
            policy_tp1 = sess.run(policy_tp1)
        log_pis_tp1 = np.expand_dims(log_pis_tp1, -1)

        # Q-values for the actually selected actions.
        # get_q_values
        q_t = fc(relu(
            fc(np.concatenate([model_out_t, train_batch[SampleBatch.ACTIONS]],
                              -1),
               weights[ks[7]],
               weights[ks[6]],
               framework=fw)),
                 weights[ks[9]],
                 weights[ks[8]],
                 framework=fw)

        # Q-values for current policy in given current state.
        # get_q_values
        q_t_det_policy = fc(relu(
            fc(np.concatenate([model_out_t, policy_t], -1),
               weights[ks[7]],
               weights[ks[6]],
               framework=fw)),
                            weights[ks[9]],
                            weights[ks[8]],
                            framework=fw)

        # Target q network evaluation.
        # target_model.get_q_values
        q_tp1 = fc(relu(
            fc(np.concatenate([target_model_out_tp1, policy_tp1], -1),
               weights[ks[15]],
               weights[ks[14]],
               framework=fw)),
                   weights[ks[17]],
                   weights[ks[16]],
                   framework=fw)

        q_t_selected = np.squeeze(q_t, axis=-1)
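        # Entropy-regularized (soft) target: subtract alpha * log_pi(a'|s')
        # from the target Q-values before building the Bellman backup.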
        q_tp1 -= alpha * log_pis_tp1
        q_tp1_best = np.squeeze(q_tp1, axis=-1)
        dones = train_batch[SampleBatch.DONES]
        rewards = train_batch[SampleBatch.REWARDS]
        if fw == "torch":
            dones = dones.float().numpy()
            rewards = rewards.numpy()
        q_tp1_best_masked = (1.0 - dones) * q_tp1_best
        q_t_selected_target = rewards + gamma * q_tp1_best_masked
        base_td_error = np.abs(q_t_selected - q_t_selected_target)
        td_error = base_td_error
        critic_loss = [
            0.5 * np.mean(np.power(q_t_selected_target - q_t_selected, 2.0))
        ]
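        # Temperature (alpha) and actor losses; the target entropy heuristic
        # is -prod(action_shape), i.e. -1 for this 1-D action space.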
        target_entropy = -np.prod((1, ))
        alpha_loss = -np.mean(log_alpha * (log_pis_t + target_entropy))
        actor_loss = np.mean(alpha * log_pis_t - q_t_det_policy)

        return critic_loss, actor_loss, alpha_loss, td_error
Example #2
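
As with the first example, the imports are not shown in this excerpt. Under
the same assumption about the RLlib version, they would look roughly like:

import re

import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.numpy import fc, huber_loss, l2_loss, relu, sigmoid
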
    def _ddpg_loss_helper(self, train_batch, weights, ks, fw, gamma,
                          huber_threshold, l2_reg, sess):
        """Emulates DDPG loss functions for tf and torch."""
        model_out_t = train_batch[SampleBatch.CUR_OBS]
        target_model_out_tp1 = train_batch[SampleBatch.NEXT_OBS]
        # get_policy_output
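        # Squash the raw policy output into [0, 1] via sigmoid(2.0 * x); the
        # doubled input emulates the tanh-style squashing of the DDPG model.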
        policy_t = sigmoid(2.0 * fc(
            relu(fc(model_out_t, weights[ks[1]], weights[ks[0]],
                    framework=fw)), weights[ks[5]], weights[ks[4]]))
        # Get policy output for t+1 (target model).
        policy_tp1 = sigmoid(2.0 * fc(
            relu(
                fc(target_model_out_tp1,
                   weights[ks[3]],
                   weights[ks[2]],
                   framework=fw)), weights[ks[7]], weights[ks[6]]))
        # Assume no smooth target policy.
        policy_tp1_smoothed = policy_tp1

        # Q-values for the actually selected actions.
        # get_q_values
        q_t = fc(relu(
            fc(np.concatenate([model_out_t, train_batch[SampleBatch.ACTIONS]],
                              -1),
               weights[ks[9]],
               weights[ks[8]],
               framework=fw)),
                 weights[ks[11]],
                 weights[ks[10]],
                 framework=fw)
        twin_q_t = fc(relu(
            fc(np.concatenate([model_out_t, train_batch[SampleBatch.ACTIONS]],
                              -1),
               weights[ks[13]],
               weights[ks[12]],
               framework=fw)),
                      weights[ks[15]],
                      weights[ks[14]],
                      framework=fw)

        # Q-values for current policy in given current state.
        # get_q_values
        q_t_det_policy = fc(relu(
            fc(np.concatenate([model_out_t, policy_t], -1),
               weights[ks[9]],
               weights[ks[8]],
               framework=fw)),
                            weights[ks[11]],
                            weights[ks[10]],
                            framework=fw)

        # Target q network evaluation.
        # target_model.get_q_values
        q_tp1 = fc(relu(
            fc(np.concatenate([target_model_out_tp1, policy_tp1_smoothed], -1),
               weights[ks[17]],
               weights[ks[16]],
               framework=fw)),
                   weights[ks[19]],
                   weights[ks[18]],
                   framework=fw)
        twin_q_tp1 = fc(relu(
            fc(np.concatenate([target_model_out_tp1, policy_tp1_smoothed], -1),
               weights[ks[21]],
               weights[ks[20]],
               framework=fw)),
                        weights[ks[23]],
                        weights[ks[22]],
                        framework=fw)

        q_t_selected = np.squeeze(q_t, axis=-1)
        twin_q_t_selected = np.squeeze(twin_q_t, axis=-1)
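        # Clipped double-Q (TD3-style): take the element-wise minimum of the
        # two target Q estimates to curb overestimation.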
        q_tp1 = np.minimum(q_tp1, twin_q_tp1)
        q_tp1_best = np.squeeze(q_tp1, axis=-1)

        dones = train_batch[SampleBatch.DONES]
        rewards = train_batch[SampleBatch.REWARDS]
        if fw == "torch":
            dones = dones.float().numpy()
            rewards = rewards.numpy()

        q_tp1_best_masked = (1.0 - dones) * q_tp1_best
        q_t_selected_target = rewards + gamma * q_tp1_best_masked

        td_error = q_t_selected - q_t_selected_target
        twin_td_error = twin_q_t_selected - q_t_selected_target
        td_error = td_error + twin_td_error
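        # Huber loss is applied per critic; the combined td_error above is
        # what the policy would report back, e.g. for prioritized replay.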
        errors = huber_loss(td_error, huber_threshold) + \
            huber_loss(twin_td_error, huber_threshold)

        critic_loss = np.mean(errors)
        actor_loss = -np.mean(q_t_det_policy)
        # Add l2-regularization if required.
        for name, var in weights.items():
            if re.match("default_policy/actor_(hidden_0|out)/kernel", name):
                actor_loss += (l2_reg * l2_loss(var))
            elif re.match("default_policy/sequential(_1)?/\\w+/kernel", name):
                critic_loss += (l2_reg * l2_loss(var))

        return critic_loss, actor_loss, td_error