Code Example #1
File: appo_torch_policy.py Project: zhuohan123/ray
import logging

from ray.rllib.agents.impala.vtrace_torch_policy import make_time_major, \
    choose_optimizer
from ray.rllib.agents.ppo.appo_tf_policy import build_appo_model, \
    postprocess_trajectory
from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin, \
    KLCoeffMixin
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import LearningRateSchedule
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import explained_variance, global_norm, \
    sequence_mask

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)


class PPOSurrogateLoss:
    """Loss used when V-trace is disabled.

    Arguments:
        prev_actions_logp: A float32 tensor of shape [T, B].
        actions_logp: A float32 tensor of shape [T, B].
        action_kl: A float32 tensor of shape [T, B].
        actions_entropy: A float32 tensor of shape [T, B].
        values: A float32 tensor of shape [T, B].
        valid_mask: A bool tensor of valid RNN input elements (#2992).
        advantages: A float32 tensor of shape [T, B].
    """
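
For reference, below is a minimal sketch of the clipped surrogate objective such a loss computes over [T, B] tensors. The dummy tensor values and the clip_param setting are illustrative assumptions, not code from the file above.

import torch

T, B = 5, 4
clip_param = 0.3  # assumed hyperparameter, not taken from the file above

prev_actions_logp = torch.randn(T, B)
actions_logp = torch.randn(T, B)
advantages = torch.randn(T, B)
valid_mask = torch.ones(T, B, dtype=torch.bool)

# Importance ratio between the current policy and the behavior policy.
logp_ratio = torch.exp(actions_logp - prev_actions_logp)

# Standard PPO clipped surrogate, averaged over valid timesteps only.
surrogate = torch.min(
    advantages * logp_ratio,
    advantages * torch.clamp(logp_ratio, 1 - clip_param, 1 + clip_param))
pi_loss = -torch.sum(surrogate[valid_mask]) / valid_mask.sum()
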
Code Example #2
File: torch_policy.py Project: daanklijn/ray
import logging

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
    convert_to_torch_tensor
from ray.rllib.utils.tracking_dict import UsageTrackingDict
from ray.rllib.utils.typing import ModelGradients, ModelWeights, \
    TensorType, TrainerConfigDict

torch, _ = try_import_torch()

logger = logging.getLogger(__name__)


@DeveloperAPI
class TorchPolicy(Policy):
    """Template for a PyTorch policy and loss to use with RLlib.

    Attributes:
        observation_space (gym.Space): observation space of the policy.
        action_space (gym.Space): action space of the policy.
        config (dict): config of the policy.
        model (TorchModel): Torch model instance.
        dist_class (type): Torch action distribution class.
    """
Code Example #3
import logging

from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
from ray.rllib.agents.qmix.model import RNNModel, _get_size
from ray.rllib.env.multi_agent_env import ENV_STATE
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import _unpack_obs
from ray.rllib.env.constants import GROUP_REWARDS
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.annotations import override

# Torch must be installed.
torch, nn = try_import_torch(error=True)

logger = logging.getLogger(__name__)


class QMixLoss(nn.Module):
    def __init__(self,
                 model,
                 target_model,
                 mixer,
                 target_mixer,
                 n_agents,
                 n_actions,
                 double_q=True,
                 gamma=0.99):
        nn.Module.__init__(self)
        # Store the networks and hyperparameters used by the loss.
        self.model = model
        self.target_model = target_model
        self.mixer = mixer
        self.target_mixer = target_mixer
        self.n_agents = n_agents
        self.n_actions = n_actions
        self.double_q = double_q
        self.gamma = gamma
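
The double_q flag switches the loss to double Q-learning targets: the online network selects the greedy action and the target network evaluates it. Below is a minimal sketch of that target computation; the tensor names, shapes, and values are illustrative assumptions, not code from QMixLoss itself.

import torch

B, n_actions, gamma = 8, 4, 0.99
q_online = torch.randn(B, n_actions)  # online-network Q-values at s'
q_target = torch.randn(B, n_actions)  # target-network Q-values at s'
rewards = torch.randn(B)
dones = torch.zeros(B)

# Double Q-learning: the online net picks the action, the target net
# evaluates it, which reduces the overestimation bias of vanilla DQN.
best_actions = q_online.argmax(dim=1, keepdim=True)
next_q = q_target.gather(1, best_actions).squeeze(1)

# One-step TD target.
targets = rewards + gamma * (1.0 - dones) * next_q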