import logging

from ray.rllib.agents.impala.vtrace_torch_policy import make_time_major, \
    choose_optimizer
from ray.rllib.agents.ppo.appo_tf_policy import build_appo_model, \
    postprocess_trajectory
from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin, \
    KLCoeffMixin
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import LearningRateSchedule
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import explained_variance, global_norm, \
    sequence_mask

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)


class PPOSurrogateLoss:
    """Loss used when V-trace is disabled.

    Arguments:
        prev_actions_logp: A float32 tensor of shape [T, B].
        actions_logp: A float32 tensor of shape [T, B].
        action_kl: A float32 tensor of shape [T, B].
        actions_entropy: A float32 tensor of shape [T, B].
        values: A float32 tensor of shape [T, B].
        valid_mask: A bool tensor of valid RNN input elements (#2992).
        advantages: A float32 tensor of shape [T, B].
    """
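
# --- Illustrative sketch (added for exposition, not part of the original
# file): how a clipped PPO surrogate could be assembled from the tensors
# documented above. The helper name `_example_surrogate_loss` and the
# `clip_param` default are assumptions for demonstration only.
def _example_surrogate_loss(prev_actions_logp, actions_logp, advantages,
                            valid_mask, clip_param=0.3):
    mask = valid_mask.float()
    # Importance ratio pi_new(a|s) / pi_old(a|s), computed in log space.
    logp_ratio = torch.exp(actions_logp - prev_actions_logp)
    # Clipped surrogate objective from the PPO paper.
    surrogate = torch.min(
        advantages * logp_ratio,
        advantages * torch.clamp(logp_ratio, 1 - clip_param, 1 + clip_param))
    # Maximize the surrogate over valid (non-padded) RNN timesteps only.
    return -torch.sum(surrogate * mask) / torch.sum(mask)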
import logging

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
    convert_to_torch_tensor
from ray.rllib.utils.tracking_dict import UsageTrackingDict
from ray.rllib.utils.typing import ModelGradients, ModelWeights, \
    TensorType, TrainerConfigDict

torch, _ = try_import_torch()

logger = logging.getLogger(__name__)


@DeveloperAPI
class TorchPolicy(Policy):
    """Template for a PyTorch policy and loss to use with RLlib.

    Attributes:
        observation_space (gym.Space): observation space of the policy.
        action_space (gym.Space): action space of the policy.
        config (dict): config of the policy.
        model (TorchModel): Torch model instance.
        dist_class (type): Torch action distribution class.
    """
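
# --- Illustrative sketch (added for exposition, not part of the original
# file): TorchPolicy is usually consumed through the `build_torch_policy`
# template rather than subclassed directly. `MyPGTorchPolicy` and
# `_example_pg_loss` are hypothetical names; the loss is a minimal vanilla
# policy gradient, not an RLlib built-in.
from ray.rllib.policy.torch_policy_template import build_torch_policy


def _example_pg_loss(policy, model, dist_class, train_batch):
    # Re-run the model on the sampled batch and score the actions taken.
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    logp = action_dist.logp(train_batch[SampleBatch.ACTIONS])
    # REINFORCE-style objective: maximize logp weighted by raw rewards.
    return -(logp * train_batch[SampleBatch.REWARDS]).mean()


MyPGTorchPolicy = build_torch_policy(
    name="MyPGTorchPolicy", loss_fn=_example_pg_loss)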
import logging

from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
from ray.rllib.agents.qmix.model import RNNModel, _get_size
from ray.rllib.env.multi_agent_env import ENV_STATE
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import _unpack_obs
from ray.rllib.env.constants import GROUP_REWARDS
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.annotations import override

# Torch must be installed.
torch, nn = try_import_torch(error=True)

logger = logging.getLogger(__name__)


class QMixLoss(nn.Module):
    def __init__(self,
                 model,
                 target_model,
                 mixer,
                 target_mixer,
                 n_agents,
                 n_actions,
                 double_q=True,
                 gamma=0.99):
        nn.Module.__init__(self)
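
# --- Illustrative sketch (added for exposition, not part of the original
# file): the double-Q target a QMIX-style loss typically builds when
# `double_q=True`. The helper name, tensor names, and shapes are
# assumptions for demonstration only.
def _example_double_q_target(rewards, terminated, gamma, cur_q_values,
                             target_q_values):
    # Double Q-learning: pick greedy actions with the online network ...
    best_actions = cur_q_values.argmax(dim=-1, keepdim=True)
    # ... but evaluate those actions with the target network.
    target_max = target_q_values.gather(-1, best_actions).squeeze(-1)
    # One-step TD target, truncated at terminal steps.
    return rewards + gamma * (1.0 - terminated) * target_max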