Example #1
    def _backward(self) -> None:
        """Trainer backward pass updating network parameters"""

        # Calculate the gradients and update the networks
        for agent in self._agents:
            agent_key = self.agent_net_keys[agent]
            # Get trainable variables.
            trainable_variables = self._q_networks[agent_key].trainable_variables

            # Compute gradients.
            gradients = self.tape.gradient(self.loss, trainable_variables)

            # Clip gradients.
            gradients = tf.clip_by_global_norm(gradients, self._max_gradient_norm)[0]

            # Apply gradients.
            self._optimizers[agent_key].apply(gradients, trainable_variables)

        # Delete the tape manually because of the persistent=True flag.
        train_utils.safe_del(self, "tape")
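
The backward pass above relies on a loss and a gradient tape stored by the forward pass. A minimal sketch of that pairing, assuming a _forward method and a _compute_loss helper (both names are assumptions; only self.loss and self.tape appear in the example), might look like this:

    import tensorflow as tf

    def _forward(self, inputs):
        # Record the loss under a persistent tape so _backward can call
        # tape.gradient once per agent on the same recording.
        with tf.GradientTape(persistent=True) as tape:
            self.loss = self._compute_loss(inputs)  # assumed helper
        # Persistent tapes keep references to every watched tensor, which is
        # why _backward deletes self.tape explicitly once it is done.
        self.tape = tape
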
Example #2
    def _backward(self) -> None:
        """Trainer backward pass updating network parameters"""

        log_current_timestep = self._log_step()
        for agent in self._agents:
            agent_key = self.agent_net_keys[agent]

            # Update agent networks
            variables = [*self._q_networks[agent_key].trainable_variables]
            gradients = self.tape.gradient(self.loss, variables)
            gradients = tf.clip_by_global_norm(gradients, self._max_gradient_norm)[0]
            self._optimizers[agent_key].apply(gradients, variables)

            if log_current_timestep:
                if self.log_weights:
                    self._log_weights(label="Policy", agent=agent, weights=variables)
                if self.log_gradients:
                    self._log_gradients(
                        label="Policy",
                        agent=agent,
                        variables_names=[var.name for var in variables],
                        gradients=gradients,
                    )

        # Update mixing network
        variables = self.get_mixing_trainable_vars()
        gradients = self.tape.gradient(self.loss, variables)

        gradients = tf.clip_by_global_norm(gradients, self._max_gradient_norm)[0]
        self._optimizer.apply(gradients, variables)

        if log_current_timestep:
            if self.log_weights:
                # Note: `agent` still holds the last agent from the loop above;
                # the mixing network itself is not agent-specific.
                self._log_weights(label="Mixing", agent=agent, weights=variables)
            if self.log_gradients:
                self._log_gradients(
                    label="Mixing",
                    agent=agent,
                    variables_names=[var.name for var in variables],
                    gradients=gradients,
                )

        train_utils.safe_del(self, "tape")
Example #3
    def _backward(self) -> None:
        """Trainer backward pass updating network parameters"""

        for agent in self._agents:
            agent_key = self.agent_net_keys[agent]
            # Update agent networks
            variables = self._q_networks[agent_key].trainable_variables
            gradients = self.tape.gradient(self.loss, variables)
            gradients = tf.clip_by_global_norm(gradients,
                                               self._max_gradient_norm)[0]
            self._optimizers[agent_key].apply(gradients, variables)

        # Update mixing network
        variables = self.get_mixing_trainable_vars("mixing")
        gradients = self.tape.gradient(self.loss, variables)

        gradients = tf.clip_by_global_norm(gradients,
                                           self._max_gradient_norm)[0]
        self._optimizer.apply(gradients, variables)

        train_utils.safe_del(self, "tape")
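
A note on the [0] indexing used in these examples: tf.clip_by_global_norm returns a (clipped_gradients, global_norm) tuple, so [0] keeps only the clipped list. A small self-contained illustration:

    import tensorflow as tf

    grads = [tf.constant([3.0, 4.0]), tf.constant([0.0, 12.0])]
    clipped, global_norm = tf.clip_by_global_norm(grads, clip_norm=5.0)
    # global_norm is sqrt(3**2 + 4**2 + 12**2) = 13.0; since it exceeds
    # clip_norm, every gradient is rescaled by clip_norm / global_norm.
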
Example #4
    def _backward(self) -> None:
        """Trainer backward pass updating network parameters"""

        # Calculate the gradients and update the networks
        policy_losses = self.policy_losses
        critic_losses = self.critic_losses
        tape = self.tape

        for agent in self._agents:
            # Get agent_key.
            agent_key = agent.split("_")[0] if self._shared_weights else agent

            # Get trainable variables.
            policy_variables = (
                self._observation_networks[agent_key].trainable_variables +
                self._policy_networks[agent_key].trainable_variables)
            critic_variables = self._critic_networks[
                agent_key].trainable_variables

            # Get gradients.
            policy_gradients = tape.gradient(policy_losses[agent],
                                             policy_variables)
            critic_gradients = tape.gradient(critic_losses[agent],
                                             critic_variables)

            # Clip gradients; the returned global norms are not used here.
            critic_grads, _ = tf.clip_by_global_norm(
                critic_gradients, self._max_gradient_norm)
            policy_grads, _ = tf.clip_by_global_norm(
                policy_gradients, self._max_gradient_norm)

            # Apply gradients.
            self._critic_optimizers[agent_key].apply(critic_grads,
                                                     critic_variables)
            self._policy_optimizers[agent_key].apply(policy_grads,
                                                     policy_variables)

        train_utils.safe_del(self, "tape")
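
Example #4 derives agent_key differently from the others: with shared weights, every agent ID of the form "agent_<i>" collapses to the prefix before the underscore, so all agents index the same network and optimizer. A quick illustration (the agent IDs below are made up for the example):

    agents = ["agent_0", "agent_1", "agent_2"]
    shared_weights = True
    keys = [agent.split("_")[0] if shared_weights else agent for agent in agents]
    assert keys == ["agent", "agent", "agent"]
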
Example #5
    def _backward(self) -> None:
        """Trainer backward pass updating network parameters"""

        policy_losses = self.policy_losses
        tape = self.tape
        log_current_timestep = self._log_step()

        # Calculate the gradients and update the networks
        for agent in self._agents:
            agent_key = self.agent_net_keys[agent]

            # Get trainable variables.
            policy_variables = (
                self._observation_networks[agent_key].trainable_variables
                + self._policy_networks[agent_key].trainable_variables
            )

            # Compute gradients.
            policy_gradients = tape.gradient(policy_losses[agent], policy_variables)

            # Maybe clip gradients.
            if self._clipping:
                policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.0)[0]

            # Apply gradients.
            self._policy_optimizers[agent_key].apply(policy_gradients, policy_variables)

            if log_current_timestep:
                if self.log_weights:
                    self._log_weights(
                        label="Policy", agent=agent, weights=policy_variables
                    )
                if self.log_gradients:
                    self._log_gradients(
                        label="Policy",
                        agent=agent,
                        variables_names=[var.name for var in policy_variables],
                        gradients=policy_gradients,
                    )
        train_utils.safe_del(self, "tape")
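
Every example ends with train_utils.safe_del(self, "tape"). The sketch below shows what such a helper typically does; it is an assumption about the utility's behaviour, not the library's actual source:

    def safe_del(obj, attr_name):
        # Remove the attribute if it exists; silently ignore the case where
        # the forward pass never set it.
        try:
            if hasattr(obj, attr_name):
                delattr(obj, attr_name)
        except AttributeError:
            pass
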
Example #6
    def _backward(self) -> None:
        """Trainer backward pass updating network parameters"""

        # Calculate the gradients and update the networks
        policy_losses = self.policy_losses
        critic_losses = self.critic_losses
        tape = self.tape
        log_current_timestep = self._log_step()

        for agent in self._agents:
            agent_key = self.agent_net_keys[agent]

            # Get trainable variables.
            policy_variables = (
                self._observation_networks[agent_key].trainable_variables
                + self._policy_networks[agent_key].trainable_variables
            )
            critic_variables = (
                # In this agent, the critic loss trains the observation network.
                self._observation_networks[agent_key].trainable_variables
                + self._critic_networks[agent_key].trainable_variables
            )

            # Compute gradients.
            policy_gradients = tape.gradient(policy_losses[agent], policy_variables)
            critic_gradients = tape.gradient(critic_losses[agent], critic_variables)

            # Clip gradients.
            policy_gradients = tf.clip_by_global_norm(
                policy_gradients, self._max_gradient_norm
            )[0]
            critic_gradients = tf.clip_by_global_norm(
                critic_gradients, self._max_gradient_norm
            )[0]

            # Apply gradients.
            self._policy_optimizers[agent_key].apply(policy_gradients, policy_variables)
            self._critic_optimizers[agent_key].apply(critic_gradients, critic_variables)

            if log_current_timestep:
                if self.log_weights:
                    self._log_weights(
                        label="Policy", agent=agent, weights=policy_variables
                    )
                    self._log_weights(
                        label="Critic", agent=agent, weights=critic_variables
                    )

                if self.log_gradients:
                    self._log_gradients(
                        label="Policy",
                        agent=agent,
                        variables_names=[var.name for var in policy_variables],
                        gradients=policy_gradients,
                    )
                    self._log_gradients(
                        label="Critic",
                        agent=agent,
                        variables_names=[var.name for var in critic_variables],
                        gradients=critic_gradients,
                    )

        train_utils.safe_del(self, "tape")
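
The _log_weights and _log_gradients helpers are not shown in these examples. A hypothetical sketch of a gradient logger compatible with the call sites above, assuming a tf.summary writer on self._summary_writer and a step counter self._num_steps (both assumptions):

    import tensorflow as tf

    def _log_gradients(self, label, agent, variables_names, gradients):
        # Write one histogram per (variable name, gradient) pair; gradients
        # can be None for variables that did not influence the loss.
        with self._summary_writer.as_default():
            for name, grad in zip(variables_names, gradients):
                if grad is not None:
                    tf.summary.histogram(
                        f"{label}/{agent}/gradients/{name}",
                        grad,
                        step=self._num_steps,
                    )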