Example #1
    def _create_losses(
        self,
        q1_streams: Dict[str, tf.Tensor],
        q2_streams: Dict[str, tf.Tensor],
        lr: tf.Tensor,
        max_step: int,
        stream_names: List[str],
        discrete: bool = False,
    ) -> None:
        """
        Creates training-specific TensorFlow ops for SAC models.
        :param q1_streams: Q1 streams from policy network
        :param q2_streams: Q2 streams from policy network
        :param lr: Learning rate
        :param max_step: Total number of training steps.
        :param stream_names: List of reward stream names.
        :param discrete: Whether or not to use discrete action losses.
        """

        if discrete:
            self.target_entropy = [
                self.discrete_target_entropy_scale *
                np.log(i).astype(np.float32) for i in self.act_size
            ]
            discrete_action_probs = tf.exp(self.policy.all_log_probs)
            per_action_entropy = discrete_action_probs * self.policy.all_log_probs
        else:
            self.target_entropy = (
                -1 * self.continuous_target_entropy_scale *
                np.prod(self.act_size[0]).astype(np.float32))

        self.rewards_holders = {}
        self.min_policy_qs = {}

        for name in stream_names:
            if discrete:
                _branched_mpq1 = ModelUtils.break_into_branches(
                    self.policy_network.q1_pheads[name] *
                    discrete_action_probs,
                    self.act_size,
                )
                branched_mpq1 = tf.stack([
                    tf.reduce_sum(_br, axis=1, keep_dims=True)
                    for _br in _branched_mpq1
                ])
                _q1_p_mean = tf.reduce_mean(branched_mpq1, axis=0)

                _branched_mpq2 = ModelUtils.break_into_branches(
                    self.policy_network.q2_pheads[name] *
                    discrete_action_probs,
                    self.act_size,
                )
                branched_mpq2 = tf.stack([
                    tf.reduce_sum(_br, axis=1, keep_dims=True)
                    for _br in _branched_mpq2
                ])
                _q2_p_mean = tf.reduce_mean(branched_mpq2, axis=0)

                self.min_policy_qs[name] = tf.minimum(_q1_p_mean, _q2_p_mean)
            else:
                self.min_policy_qs[name] = tf.minimum(
                    self.policy_network.q1_pheads[name],
                    self.policy_network.q2_pheads[name],
                )

            rewards_holder = tf.placeholder(shape=[None],
                                            dtype=tf.float32,
                                            name=f"{name}_rewards")
            self.rewards_holders[name] = rewards_holder

        q1_losses = []
        q2_losses = []
        # Multiple q losses per stream
        expanded_dones = tf.expand_dims(self.dones_holder, axis=-1)
        for i, name in enumerate(stream_names):
            _expanded_rewards = tf.expand_dims(self.rewards_holders[name],
                                               axis=-1)

            q_backup = tf.stop_gradient(
                _expanded_rewards +
                (1.0 - self.use_dones_in_backup[name] * expanded_dones) *
                self.gammas[i] * self.target_network.value_heads[name])

            if discrete:
                # We need to break up the Q functions by branch, and update them individually.
                branched_q1_stream = ModelUtils.break_into_branches(
                    self.policy.selected_actions * q1_streams[name],
                    self.act_size)
                branched_q2_stream = ModelUtils.break_into_branches(
                    self.policy.selected_actions * q2_streams[name],
                    self.act_size)

                # Reduce each branch into scalar
                branched_q1_stream = [
                    tf.reduce_sum(_branch, axis=1, keep_dims=True)
                    for _branch in branched_q1_stream
                ]
                branched_q2_stream = [
                    tf.reduce_sum(_branch, axis=1, keep_dims=True)
                    for _branch in branched_q2_stream
                ]

                q1_stream = tf.reduce_mean(branched_q1_stream, axis=0)
                q2_stream = tf.reduce_mean(branched_q2_stream, axis=0)

            else:
                q1_stream = q1_streams[name]
                q2_stream = q2_streams[name]

            _q1_loss = 0.5 * tf.reduce_mean(
                tf.to_float(self.policy.mask) *
                tf.squared_difference(q_backup, q1_stream))

            _q2_loss = 0.5 * tf.reduce_mean(
                tf.to_float(self.policy.mask) *
                tf.squared_difference(q_backup, q2_stream))

            q1_losses.append(_q1_loss)
            q2_losses.append(_q2_loss)

        self.q1_loss = tf.reduce_mean(q1_losses)
        self.q2_loss = tf.reduce_mean(q2_losses)

        # Learn entropy coefficient
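        # The coefficient is learned in log space so that ent_coef = exp(log_ent_coef)
        # (computed below) stays strictly positive throughout training.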
        if discrete:
            # Create a log_ent_coef for each branch
            self.log_ent_coef = tf.get_variable(
                "log_ent_coef",
                dtype=tf.float32,
                initializer=np.log([self.init_entcoef] *
                                   len(self.act_size)).astype(np.float32),
                trainable=True,
            )
        else:
            self.log_ent_coef = tf.get_variable(
                "log_ent_coef",
                dtype=tf.float32,
                initializer=np.log(self.init_entcoef).astype(np.float32),
                trainable=True,
            )

        self.ent_coef = tf.exp(self.log_ent_coef)
        if discrete:
            # We also have to do a different entropy and target_entropy per branch.
            branched_per_action_ent = ModelUtils.break_into_branches(
                per_action_entropy, self.act_size)
            branched_ent_sums = tf.stack(
                [
                    tf.reduce_sum(_lp, axis=1, keep_dims=True) + _te
                    for _lp, _te in zip(branched_per_action_ent, self.target_entropy)
                ],
                axis=1,
            )
            self.entropy_loss = -tf.reduce_mean(
                tf.to_float(self.policy.mask) * tf.reduce_mean(
                    self.log_ent_coef *
                    tf.squeeze(tf.stop_gradient(branched_ent_sums), axis=2),
                    axis=1,
                ))

            # Same with policy loss, we have to do the loss per branch and average them,
            # so that larger branches don't get more weight.
            # The equivalent KL divergence from Eq 10 of Haarnoja et al. is also pi*log(pi) - Q
            branched_q_term = ModelUtils.break_into_branches(
                discrete_action_probs * self.policy_network.q1_p,
                self.act_size)

            branched_policy_loss = tf.stack([
                tf.reduce_sum(self.ent_coef[i] * _lp - _qt,
                              axis=1,
                              keep_dims=True)
                for i, (_lp, _qt) in enumerate(
                    zip(branched_per_action_ent, branched_q_term))
            ])
            self.policy_loss = tf.reduce_mean(
                tf.to_float(self.policy.mask) *
                tf.squeeze(branched_policy_loss))

            # Compute the v_backup entropy bonus per branch as well.
            branched_ent_bonus = tf.stack([
                tf.reduce_sum(self.ent_coef[i] * _lp, axis=1, keep_dims=True)
                for i, _lp in enumerate(branched_per_action_ent)
            ])
            value_losses = []
            for name in stream_names:
                v_backup = tf.stop_gradient(
                    self.min_policy_qs[name] -
                    tf.reduce_mean(branched_ent_bonus, axis=0))
                value_losses.append(0.5 * tf.reduce_mean(
                    tf.to_float(self.policy.mask) * tf.squared_difference(
                        self.policy_network.value_heads[name], v_backup)))

        else:
            self.entropy_loss = -tf.reduce_mean(
                self.log_ent_coef * tf.to_float(self.policy.mask) *
                tf.stop_gradient(
                    tf.reduce_sum(
                        self.policy.all_log_probs + self.target_entropy,
                        axis=1,
                        keep_dims=True,
                    )))
            batch_policy_loss = tf.reduce_mean(
                self.ent_coef * self.policy.all_log_probs -
                self.policy_network.q1_p,
                axis=1,
            )
            self.policy_loss = tf.reduce_mean(
                tf.to_float(self.policy.mask) * batch_policy_loss)

            value_losses = []
            for name in stream_names:
                v_backup = tf.stop_gradient(
                    self.min_policy_qs[name] - tf.reduce_sum(
                        self.ent_coef * self.policy.all_log_probs, axis=1))
                value_losses.append(0.5 * tf.reduce_mean(
                    tf.to_float(self.policy.mask) * tf.squared_difference(
                        self.policy_network.value_heads[name], v_backup)))
        self.value_loss = tf.reduce_mean(value_losses)

        self.total_value_loss = self.q1_loss + self.q2_loss + self.value_loss

        self.entropy = self.policy_network.entropy
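
The target entropy values set up at the top of _create_losses follow the usual SAC heuristics: in the discrete case one target per action branch, proportional to the log of that branch's size, and in the continuous case a single negative target proportional to the action dimensionality. A minimal, self-contained sketch of that arithmetic, using assumed example values for the scales and action sizes:

import numpy as np

# Assumed example values; in the class these come from the trainer configuration.
discrete_target_entropy_scale = 0.2
continuous_target_entropy_scale = 1.0
discrete_act_size = [3, 5]   # two discrete branches with 3 and 5 actions
continuous_act_size = [4]    # a 4-dimensional continuous action space

# Discrete: one entropy target per branch, a scaled log of the branch size.
discrete_targets = [
    discrete_target_entropy_scale * np.log(n).astype(np.float32)
    for n in discrete_act_size
]

# Continuous: a single negative target proportional to the action dimension.
continuous_target = (
    -1 * continuous_target_entropy_scale
    * np.prod(continuous_act_size[0]).astype(np.float32)
)

print(discrete_targets)   # approximately [0.2197, 0.3219]
print(continuous_target)  # -4.0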
Example #2
    def _create_dc_critic(self, h_size: int, num_layers: int,
                          vis_encode_type: EncoderType) -> None:
        """
        Creates the discrete-control critic (value) network.
        :param h_size: Size of hidden linear layers.
        :param num_layers: Number of hidden linear layers.
        :param vis_encode_type: The type of visual encoder to use.
        """
        hidden_stream = ModelUtils.create_observation_streams(
            self.policy.visual_in,
            self.policy.processed_vector_in,
            1,
            h_size,
            num_layers,
            vis_encode_type,
        )[0]

        if self.policy.use_recurrent:
            hidden_value, memory_value_out = ModelUtils.create_recurrent_encoder(
                hidden_stream,
                self.memory_in,
                self.policy.sequence_length_ph,
                name="lstm_value",
            )
            self.memory_out = memory_value_out
        else:
            hidden_value = hidden_stream

        self.value_heads, self.value = ModelUtils.create_value_heads(
            self.stream_names, hidden_value)

        self.all_old_log_probs = tf.placeholder(
            shape=[None, sum(self.policy.act_size)],
            dtype=tf.float32,
            name="old_probabilities",
        )

        # Break old log probs into separate branches
        old_log_prob_branches = ModelUtils.break_into_branches(
            self.all_old_log_probs, self.policy.act_size)

        _, _, old_normalized_logits = ModelUtils.create_discrete_action_masking_layer(
            old_log_prob_branches, self.policy.action_masks,
            self.policy.act_size)

        action_idx = [0] + list(np.cumsum(self.policy.act_size))

        self.old_log_probs = tf.reduce_sum(
            tf.stack(
                [
                    -tf.nn.softmax_cross_entropy_with_logits_v2(
                        labels=self.policy.selected_actions[
                            :, action_idx[i]:action_idx[i + 1]
                        ],
                        logits=old_normalized_logits[
                            :, action_idx[i]:action_idx[i + 1]
                        ],
                    )
                    for i in range(len(self.policy.act_size))
                ],
                axis=1,
            ),
            axis=1,
            keepdims=True,
        )