Example #1
def test_cartpole_policy_model():
    flat_action_space = flat_structured_space(_mock_action_spaces_dict())
    distribution_mapper = DistributionMapper(action_space=flat_action_space,
                                             distribution_mapper_config={})

    action_logits_shapes = {step_key: {action_head: distribution_mapper.required_logits_shape(action_head)
                                       for action_head in _mock_action_spaces_dict()[step_key].spaces.keys()}
                            for step_key in _mock_action_spaces_dict().keys()}

    obs_shapes = observation_spaces_to_in_shapes(_mock_observation_spaces_dict())

    policy = CustomComplexPolicyNet(obs_shapes[0], action_logits_shapes[0], non_lin='torch.nn.ReLU',
                                    hidden_units=[128])

    critic = CustomComplexCriticNet(obs_shapes[0], non_lin='torch.nn.ReLU',
                                    hidden_units=[128])

    obs_np = _mock_observation_spaces_dict()[0].sample()
    obs = {k: torch.from_numpy(v) for k, v in obs_np.items()}

    actions = policy(obs)
    values = critic(obs)

    assert 'action_move' in actions
    assert 'action_use' in actions
    assert 'value' in values
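A hedged follow-up sketch, reusing the names from the snippet above (the same pattern appears in the dummy-network test further below): the policy's logit dict can be turned into sampled actions via the same distribution mapper.

# hedged sketch: sample actions from the policy's logits
prob_dist = distribution_mapper.logits_dict_to_distribution(logits_dict=actions,
                                                            temperature=1.0)
sampled_actions = prob_dist.sample()
assert 'action_move' in sampled_actions
assert 'action_use' in sampled_actions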
Example #2
def test_distribution_mapper():
    """ distribution test """

    # action space
    act_space = spaces.Dict(
        spaces={
            "selection":
            spaces.Discrete(10),
            "order":
            spaces.MultiBinary(15),
            "scale_input":
            spaces.Box(shape=(5, ), low=0, high=100, dtype=np.float64),
            "order_by_weight":
            spaces.Box(shape=(5, ), low=0, high=100, dtype=np.float64)
        })

    # default config
    config = [{
        "action_space":
        spaces.Box,
        "distribution":
        "maze.distributions.squashed_gaussian.SquashedGaussianProbabilityDistribution"
    }, {
        "action_head":
        "order_by_weight",
        "distribution":
        "maze.distributions.beta.BetaProbabilityDistribution"
    }]

    # initialize distribution mapper
    distribution_mapper = DistributionMapper(action_space=act_space,
                                             distribution_mapper_config=config)
    repr(distribution_mapper)

    # assign action heads to registered distributions
    logits_dict = dict()
    for action_head in act_space.spaces.keys():
        logits_shape = distribution_mapper.required_logits_shape(action_head)

        logits_tensor = torch.from_numpy(np.random.randn(*logits_shape))
        torch_dist = distribution_mapper.action_head_distribution(
            action_head=action_head, logits=logits_tensor, temperature=1.0)
        logits_dict[action_head] = logits_tensor

        # check if distributions are correctly assigned
        if action_head == "selection":
            assert isinstance(torch_dist, CategoricalProbabilityDistribution)
        elif action_head == "order":
            assert isinstance(torch_dist, BernoulliProbabilityDistribution)
        elif action_head == "scale_input":
            assert isinstance(torch_dist,
                              SquashedGaussianProbabilityDistribution)
        elif action_head == "order_by_weight":
            assert isinstance(torch_dist, BetaProbabilityDistribution)

    # test dictionary distribution mapping
    dict_dist = distribution_mapper.logits_dict_to_distribution(
        logits_dict=logits_dict, temperature=1.0)
    assert isinstance(dict_dist, DictProbabilityDistribution)
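A hedged extension of the snippet above: the resulting DictProbabilityDistribution exposes per-head sampling, log-probabilities and entropy (the same calls are exercised in the RLlib comparison test at the bottom of this page).

# hedged sketch: sample from the dict distribution and score the sample per head
sampled_action = dict_dist.sample()              # dict: one entry per action head
logp_dict = dict_dist.log_prob(sampled_action)   # dict of per-head log-probabilities
entropy = dict_dist.entropy()                    # entropy of the dict distribution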
Example #3
def main(n_epochs: int) -> None:
    """Trains the cart pole environment with the ES implementation.
    """

    env = GymMazeEnv(env="CartPole-v0")
    distribution_mapper = DistributionMapper(action_space=env.action_space,
                                             distribution_mapper_config={})

    obs_shapes = observation_spaces_to_in_shapes(env.observation_spaces_dict)
    action_shapes = {
        step_key: {
            action_head: distribution_mapper.required_logits_shape(action_head)
            for action_head in env.action_spaces_dict[step_key].spaces.keys()
        }
        for step_key in env.action_spaces_dict.keys()
    }

    # initialize policies
    policies = [
        PolicyNet(obs_shapes=obs_shapes[0],
                  action_logits_shapes=action_shapes[0],
                  non_lin=nn.SELU)
    ]

    # initialize the torch policy
    policy = TorchPolicy(networks=list_to_dict(policies),
                         distribution_mapper=distribution_mapper,
                         device="cpu")

    shared_noise = SharedNoiseTable(count=1_000_000)

    algorithm_config = ESAlgorithmConfig(n_rollouts_per_update=100,
                                         n_timesteps_per_update=0,
                                         max_steps=0,
                                         optimizer=Adam(step_size=0.01),
                                         l2_penalty=0.005,
                                         noise_stddev=0.02,
                                         n_epochs=n_epochs,
                                         policy_wrapper=None)

    trainer = ESTrainer(algorithm_config=algorithm_config,
                        torch_policy=policy,
                        shared_noise=shared_noise,
                        normalization_stats=None)

    setup_logging(job_config=None)

    maze_rng = np.random.RandomState(None)

    # run with pseudo-distribution, without worker processes
    trainer.train(ESDummyDistributedRollouts(
        env=env,
        n_eval_rollouts=10,
        shared_noise=shared_noise,
        agent_instance_seed=MazeSeeding.generate_seed_from_random_state(
            maze_rng)),
                  model_selection=None)
Example #4
def log_probs_from_logits_and_actions_and_spaces(
        policy_logits: List[TorchActionType],
        actions: List[TorchActionType],
        distribution_mapper: DistributionMapper) \
        -> Tuple[List[TorchActionType], List[DictProbabilityDistribution]]:
    """Computes action log-probs from policy logits, actions and acton_spaces.

    In the notation used throughout documentation and comments, T refers to the
    time dimension ranging from 0 to T-1. B refers to the batch size and
    NUM_ACTIONS refers to the number of actions.

    :param policy_logits: A list (w.r.t. the substeps of the env) of dicts (w.r.t. the actions) of tensors
        of un-normalized log-probabilities (shape list[dict[str,[T, B, NUM_ACTIONS]]])
    :param actions: A list (w.r.t. the substeps of the env) of dicts (w.r.t. the actions) of tensors
        (list[dict[str,[T, B]]])
    :param distribution_mapper: A distribution mapper providing a mapping of action heads to distributions.

    :return: A list (w.r.t. the substeps of the env) of dicts (w.r.t. the actions) of tensors of shape [T, B]
        corresponding to the sampling log probability of the chosen action w.r.t. the policy.
        And a list (w.r.t. the substeps of the env) of DictProbabilityDistribution objects corresponding to the
        per-step action distributions.
    """
    log_probs = list()
    step_action_dists = list()
    for step_policy_logits, step_actions in zip(policy_logits, actions):
        step_action_dist = distribution_mapper.logits_dict_to_distribution(
            logits_dict=step_policy_logits, temperature=1.0)
        log_probs.append(step_action_dist.log_prob(step_actions))
        step_action_dists.append(step_action_dist)
    return log_probs, step_action_dists
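A minimal hedged usage sketch for a single sub-step with one discrete action head, following the [T, B, NUM_ACTIONS] convention from the docstring (the same call is exercised in the V-trace tests further below); imports are assumed to match the other snippets on this page.

# hedged sketch: one substep, one discrete action head 'action1'
action_space = gym.spaces.Dict({'action1': gym.spaces.Discrete(3)})
distribution_mapper = DistributionMapper(action_space=action_space,
                                         distribution_mapper_config={})
policy_logits = torch.randn(7, 8, 3)                 # [T, B, NUM_ACTIONS]
actions = torch.randint(low=0, high=3, size=(7, 8))  # [T, B]
log_probs, _ = log_probs_from_logits_and_actions_and_spaces(
    policy_logits=[{'action1': policy_logits}],
    actions=[{'action1': actions}],
    distribution_mapper=distribution_mapper)
assert log_probs[0]['action1'].shape == (7, 8)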
Example #5
def _policy(env: GymMazeEnv):
    distribution_mapper = DistributionMapper(action_space=env.action_space,
                                             distribution_mapper_config={})
    policies = {
        0:
        FlattenConcatPolicyNet({'observation': (4, )}, {'action': (2, )},
                               hidden_units=[16],
                               non_lin=nn.Tanh)
    }
    critics = {
        0:
        FlattenConcatStateValueNet({'observation': (4, )},
                                   hidden_units=[16],
                                   non_lin=nn.Tanh)
    }

    policy = TorchPolicy(networks=policies,
                         distribution_mapper=distribution_mapper,
                         device="cpu")

    critic = TorchSharedStateCritic(
        networks=critics,
        obs_spaces_dict=env.observation_spaces_dict,
        device="cpu",
        stack_observations=False)

    return TorchActorCritic(policy=policy, critic=critic, device="cpu")
Example #6
def train_function(n_epochs: int, distributed_env_cls) -> A2C:
    """Trains the cart pole environment with the multi-step a2c implementation.
    """

    # initialize distributed env
    envs = distributed_env_cls([lambda: GymMazeEnv(env="CartPole-v0") for _ in range(2)])

    # initialize the env and enable statistics collection
    eval_env = distributed_env_cls([lambda: GymMazeEnv(env="CartPole-v0") for _ in range(2)],
                                   logging_prefix='eval')

    # init distribution mapper
    env = GymMazeEnv(env="CartPole-v0")
    distribution_mapper = DistributionMapper(action_space=env.action_space, distribution_mapper_config={})

    # initialize policies
    policies = {0: FlattenConcatPolicyNet({'observation': (4,)}, {'action': (2,)}, hidden_units=[16], non_lin=nn.Tanh)}

    # initialize critic
    critics = {0: FlattenConcatStateValueNet({'observation': (4,)}, hidden_units=[16], non_lin=nn.Tanh)}

    # algorithm configuration
    algorithm_config = A2CAlgorithmConfig(
        n_epochs=n_epochs,
        epoch_length=2,
        patience=10,
        critic_burn_in_epochs=0,
        n_rollout_steps=20,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.0,
        max_grad_norm=0.0,
        device="cpu",
        rollout_evaluator=RolloutEvaluator(eval_env=eval_env, n_episodes=1, model_selection=None, deterministic=True)
    )

    # initialize actor critic model
    model = TorchActorCritic(
        policy=TorchPolicy(networks=policies, distribution_mapper=distribution_mapper, device=algorithm_config.device),
        critic=TorchSharedStateCritic(networks=critics, obs_spaces_dict=env.observation_spaces_dict,
                                      device=algorithm_config.device,
                                      stack_observations=False),
        device=algorithm_config.device)

    a2c = A2C(rollout_generator=RolloutGenerator(envs),
              algorithm_config=algorithm_config,
              evaluator=algorithm_config.rollout_evaluator,
              model=model,
              model_selection=None)

    # train agent
    a2c.train()

    return a2c
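A hedged invocation sketch, assuming SequentialVectorEnv (used in later examples on this page) is importable here and accepted as the vector-env class:

# hedged sketch: run the A2C training function with a sequential vector env
a2c = train_function(n_epochs=2, distributed_env_cls=SequentialVectorEnv)
assert isinstance(a2c, A2C)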
Example #7
def test_dummy_model_with_dummy_network():
    """
    Unit test for the DummyStructuredEnvironment
    """
    maze_env = build_dummy_maze_env()

    # init the distribution_mapper with the flat action space
    distribution_mapper_config = [{
        "action_space":
        spaces.Box,
        "distribution":
        "maze.distributions.squashed_gaussian.SquashedGaussianProbabilityDistribution"
    }]
    distribution_mapper = DistributionMapper(
        action_space=maze_env.action_space,
        distribution_mapper_config=distribution_mapper_config)

    obs_shapes = observation_spaces_to_in_shapes(
        maze_env.observation_spaces_dict)

    dummy_actor = DummyPolicyNet(
        obs_shapes=obs_shapes[0],
        action_logits_shapes={
            key: distribution_mapper.required_logits_shape(key)
            for key in maze_env.action_space.spaces.keys()
        },
        non_lin=nn.Tanh)

    dummy_critic = DummyValueNet(obs_shapes=obs_shapes[0], non_lin=nn.Tanh)

    obs_np = maze_env.reset()
    obs = {k: torch.from_numpy(v) for k, v in obs_np.items()}

    for i in range(100):
        logits_dict = dummy_actor(obs)
        prob_dist = distribution_mapper.logits_dict_to_distribution(
            logits_dict=logits_dict, temperature=1.0)
        sampled_actions = prob_dist.sample()

        obs_np, _, _, _ = maze_env.step(sampled_actions)
        obs = {k: torch.from_numpy(v) for k, v in obs_np.items()}

        _ = dummy_critic(obs)
    maze_env.close()
Example #8
    def __init__(self, action_spaces_dict: Dict[StepKeyType, gym.spaces.Dict],
                 observation_spaces_dict: Dict[StepKeyType, gym.spaces.Dict],
                 agent_counts_dict: Dict[StepKeyType, int],
                 distribution_mapper_config: ConfigType):
        self.action_spaces_dict = action_spaces_dict
        self.observation_spaces_dict = observation_spaces_dict
        self.agent_counts_dict = agent_counts_dict

        # initialize DistributionMapper
        flat_action_space = flat_structured_space(action_spaces_dict)
        self._distribution_mapper = DistributionMapper(
            action_space=flat_action_space,
            distribution_mapper_config=distribution_mapper_config)
Example #9
    def required_model_output_shape(action_space: gym.spaces.Dict,
                                    model_config: Dict) -> int:
        """Returns the required logits shape (network output shape) for a given action head.

        :param action_space: The action space of the env.
        :param model_config: The rllib model config.
        :return: The total number of flattened logits across all action heads.
        """
        # Retrieve the distribution_mapper_config from the model config
        method_distribution_mapper_config = \
            model_config['custom_model_config']['maze_model_composer_config']['distribution_mapper_config']
        # Build the distribution mapper
        method_distribution_mapper = DistributionMapper(
            action_space,
            distribution_mapper_config=method_distribution_mapper_config)
        # Compute the flattened number of logits
        num_outputs = sum([
            np.prod(
                method_distribution_mapper.required_logits_shape(action_head))
            for action_head in method_distribution_mapper.action_space.spaces
        ])
        return num_outputs
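A hedged usage sketch; the model-config structure and the owning class MazeRLlibActionDistribution follow the RLlib comparison test at the bottom of this page.

# hedged sketch: query the flattened logits count for a simple Dict action space
model_config = {
    'custom_model_config': {
        'maze_model_composer_config': {
            'distribution_mapper_config': {}
        }
    }
}
action_space = gym.spaces.Dict({'action': gym.spaces.Discrete(2)})
num_outputs = MazeRLlibActionDistribution.required_model_output_shape(
    action_space, model_config)
assert num_outputs == 2  # one categorical head with two logits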
Example #10
def test_cartpole_policy_model():
    env = GymMazeEnv(env='CartPole-v0')
    observation_spaces_dict = env.observation_spaces_dict
    action_spaces_dict = env.action_spaces_dict

    flat_action_space = flat_structured_space(action_spaces_dict)
    distribution_mapper = DistributionMapper(action_space=flat_action_space,
                                             distribution_mapper_config={})

    action_logits_shapes = {
        step_key: {
            action_head: distribution_mapper.required_logits_shape(action_head)
            for action_head in action_spaces_dict[step_key].spaces.keys()
        }
        for step_key in action_spaces_dict.keys()
    }

    obs_shapes = observation_spaces_to_in_shapes(observation_spaces_dict)

    policy = CustomPlainCartpolePolicyNet(obs_shapes[0],
                                          action_logits_shapes[0],
                                          hidden_layer_0=16,
                                          hidden_layer_1=32,
                                          use_bias=True)

    critic = CustomPlainCartpoleCriticNet(obs_shapes[0],
                                          hidden_layer_0=16,
                                          hidden_layer_1=32,
                                          use_bias=True)

    obs_np = env.reset()
    obs = {k: torch.from_numpy(v) for k, v in obs_np.items()}

    actions = policy(obs)
    values = critic(obs)

    assert 'action' in actions
    assert 'value' in values
Example #11
def _log_probs_from_logits_and_actions(batch_size):
    """Tests log_probs_from_logits_and_actions."""
    seq_len = 7
    num_actions = 3

    action_space = gym.spaces.Dict(
        {'action1': gym.spaces.Discrete(num_actions)})

    policy_logits = convert_to_torch(
        _shaped_arange(seq_len, batch_size, num_actions) + 10,
        cast=None,
        device=None,
        in_place='try')
    actions = convert_to_torch(np.random.randint(0,
                                                 num_actions,
                                                 size=(seq_len, batch_size),
                                                 dtype=np.int32),
                               cast=None,
                               device=None,
                               in_place='try')

    distribution_mapper = DistributionMapper(action_space=action_space,
                                             distribution_mapper_config={})

    action_log_probs_tensor, _ = impala_vtrace.log_probs_from_logits_and_actions_and_spaces(
        policy_logits=[{
            'action1': policy_logits
        }],
        actions=[{
            'action1': actions
        }],
        distribution_mapper=distribution_mapper)
    action_log_probs_tensor = action_log_probs_tensor[0]['action1']
    # Ground Truth
    # Using broadcasting to create a mask that indexes action logits
    action_index_mask = np.array(actions[..., None]) == np.arange(num_actions)

    def index_with_mask(array, mask):
        return array[mask].reshape(*array.shape[:-1])

    # Note: Normally log(softmax) is not a good idea because it's not
    # numerically stable. However, in this test we have well-behaved values.
    ground_truth_v = index_with_mask(np.log(_softmax(np.array(policy_logits))),
                                     action_index_mask)

    assert np.allclose(ground_truth_v, action_log_probs_tensor)
Example #12
def train_setup(
        n_epochs: int,
        policy_wrapper=None) -> Tuple[TorchPolicy, StructuredEnv, ESTrainer]:
    """Trains the cart pole environment with the multi-step a2c implementation.
    """

    # initialize the env
    env = GymMazeEnv(env="CartPole-v0")

    # initialize distribution mapper
    distribution_mapper = DistributionMapper(action_space=env.action_space,
                                             distribution_mapper_config={})

    # initialize policies
    policies = {
        0:
        FlattenConcatPolicyNet({'observation': (4, )}, {'action': (2, )},
                               hidden_units=[16],
                               non_lin=nn.Tanh)
    }

    # initialize the torch policy
    policy = TorchPolicy(networks=policies,
                         distribution_mapper=distribution_mapper,
                         device="cpu")

    # reduce the noise table size to speed up testing
    shared_noise = SharedNoiseTable(count=1_000_000)

    algorithm_config = ESAlgorithmConfig(n_rollouts_per_update=100,
                                         n_timesteps_per_update=0,
                                         max_steps=0,
                                         optimizer=Adam(step_size=0.01),
                                         l2_penalty=0.005,
                                         noise_stddev=0.02,
                                         n_epochs=n_epochs,
                                         policy_wrapper=policy_wrapper)

    # train agent
    trainer = ESTrainer(algorithm_config=algorithm_config,
                        shared_noise=shared_noise,
                        torch_policy=policy,
                        normalization_stats=None)

    return policy, env, trainer
Example #13
def _get_cartpole_setup_components(
) -> Tuple[CustomModelComposer, ProbabilisticPolicyComposer,
           SharedStateCriticComposer, TorchPolicy, TorchActorCritic]:
    """
    Returns various instantiated components for environment CartPole-v0.
    :return: Various instantiated components for the CartPole-v0 setting.
    """

    env = GymMazeEnv(env=gym.make("CartPole-v0"))
    observation_space = env.observation_space
    action_space = env.action_space

    policy_net = FlattenConcatPolicyNet({'observation': (4, )},
                                        {'action': (2, )},
                                        hidden_units=[16],
                                        non_lin=nn.Tanh)
    maze_wrapped_policy_net = TorchModelBlock(
        in_keys='observation',
        out_keys='action',
        in_shapes=observation_space.spaces['observation'].shape,
        in_num_dims=[2],
        out_num_dims=2,
        net=policy_net)

    policy_networks = {0: maze_wrapped_policy_net}

    # Policy Distribution
    # ^^^^^^^^^^^^^^^^^^^
    distribution_mapper = DistributionMapper(action_space=action_space,
                                             distribution_mapper_config={})

    # Instantiating the Policy
    # ^^^^^^^^^^^^^^^^^^^^^^^^
    torch_policy = TorchPolicy(networks=policy_networks,
                               distribution_mapper=distribution_mapper,
                               device='cpu')

    policy_composer = ProbabilisticPolicyComposer(
        action_spaces_dict=env.action_spaces_dict,
        observation_spaces_dict=env.observation_spaces_dict,
        distribution_mapper=distribution_mapper,
        networks=[{
            '_target_':
            'maze.perception.models.built_in.flatten_concat.FlattenConcatPolicyNet',
            'non_lin': 'torch.nn.Tanh',
            'hidden_units': [222, 222]
        }],
        substeps_with_separate_agent_nets=[],
        agent_counts_dict={0: 1})

    # Value Function Setup
    # --------------------

    # Value Network
    # ^^^^^^^^^^^^^
    value_net = FlattenConcatStateValueNet({'observation': (4, )},
                                           hidden_units=[16],
                                           non_lin=nn.Tanh)
    maze_wrapped_value_net = TorchModelBlock(
        in_keys='observation',
        out_keys='value',
        in_shapes=observation_space.spaces['observation'].shape,
        in_num_dims=[2],
        out_num_dims=2,
        net=value_net)

    value_networks = {0: maze_wrapped_value_net}

    # Instantiate the Value Function
    # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    torch_critic = TorchSharedStateCritic(
        networks=value_networks,
        obs_spaces_dict=env.observation_spaces_dict,
        device='cpu',
        stack_observations=True)

    # Critic composer.
    critic_composer = SharedStateCriticComposer(
        observation_spaces_dict=env.observation_spaces_dict,
        agent_counts_dict={0: 1},
        networks=value_networks,
        stack_observations=True)

    # Initializing the ActorCritic Model.
    # -----------------------------------
    actor_critic_model = TorchActorCritic(policy=torch_policy,
                                          critic=torch_critic,
                                          device='cpu')

    model_composer = CustomModelComposer(
        action_spaces_dict=env.action_spaces_dict,
        observation_spaces_dict=env.observation_spaces_dict,
        distribution_mapper_config={},
        policy=policy_composer,
        critic=None,
        agent_counts_dict={0: 1})

    return model_composer, policy_composer, critic_composer, torch_policy, actor_critic_model
Example #14
def main(n_epochs: int) -> None:
    """Trains the cart pole environment with the multi-step a2c implementation.
    """

    # initialize distributed env
    envs = SequentialVectorEnv(
        [lambda: GymMazeEnv(env="CartPole-v0") for _ in range(8)],
        logging_prefix="train")

    # initialize the env and enable statistics collection
    eval_env = SequentialVectorEnv(
        [lambda: GymMazeEnv(env="CartPole-v0") for _ in range(8)],
        logging_prefix="eval")

    # init distribution mapper
    env = GymMazeEnv(env="CartPole-v0")

    # init default distribution mapper
    distribution_mapper = DistributionMapper(action_space=env.action_space,
                                             distribution_mapper_config={})

    # initialize policies
    policies = {
        0: PolicyNet({'observation': (4, )}, {'action': (2, )},
                     non_lin=nn.Tanh)
    }

    # initialize critic
    critics = {0: ValueNet({'observation': (4, )})}

    # algorithm configuration
    algorithm_config = A2CAlgorithmConfig(n_epochs=n_epochs,
                                          epoch_length=10,
                                          patience=10,
                                          critic_burn_in_epochs=0,
                                          n_rollout_steps=20,
                                          lr=0.0005,
                                          gamma=0.98,
                                          gae_lambda=1.0,
                                          policy_loss_coef=1.0,
                                          value_loss_coef=0.5,
                                          entropy_coef=0.0,
                                          max_grad_norm=0.0,
                                          device="cpu",
                                          rollout_evaluator=RolloutEvaluator(
                                              eval_env=eval_env,
                                              n_episodes=1,
                                              model_selection=None,
                                              deterministic=True))

    # initialize actor critic model
    model = TorchActorCritic(policy=TorchPolicy(
        networks=policies,
        distribution_mapper=distribution_mapper,
        device=algorithm_config.device),
                             critic=TorchSharedStateCritic(
                                 networks=critics,
                                 obs_spaces_dict=env.observation_spaces_dict,
                                 device=algorithm_config.device,
                                 stack_observations=False),
                             device=algorithm_config.device)

    a2c = A2C(rollout_generator=RolloutGenerator(envs),
              evaluator=algorithm_config.rollout_evaluator,
              algorithm_config=algorithm_config,
              model=model,
              model_selection=None)

    setup_logging(job_config=None)

    # train agent
    a2c.train()

    # final evaluation run
    print("Final Evaluation Run:")
    a2c.evaluate()
Example #15
def train_function(n_epochs: int, epoch_length: int, deterministic_eval: bool,
                   eval_repeats: int, distributed_env_cls,
                   split_rollouts_into_transitions: bool) -> SAC:
    """Implements the lunar lander continuous env and performs tests on it w.r.t. the sac trainer.
    """

    # initialize distributed env
    env_factory = lambda: GymMazeEnv(env="LunarLanderContinuous-v2")

    # initialize the env and enable statistics collection
    eval_env = distributed_env_cls([env_factory for _ in range(2)],
                                   logging_prefix='eval')

    env = env_factory()
    # init distribution mapper
    distribution_mapper = DistributionMapper(
        action_space=env.action_space,
        distribution_mapper_config=[{
            'action_space':
            'gym.spaces.Box',
            'distribution':
            'maze.distributions.squashed_gaussian.SquashedGaussianProbabilityDistribution'
        }])

    action_shapes = {
        step_key: {
            action_head:
            tuple(distribution_mapper.required_logits_shape(action_head))
            for action_head in env.action_spaces_dict[step_key].spaces.keys()
        }
        for step_key in env.action_spaces_dict.keys()
    }

    obs_shapes = observation_spaces_to_in_shapes(env.observation_spaces_dict)
    # initialize policies
    policies = {
        ii: PolicyNet(obs_shapes=obs_shapes[ii],
                      action_logits_shapes=action_shapes[ii],
                      non_lin=nn.Tanh)
        for ii in obs_shapes.keys()
    }

    # add the action shapes to the critic input shapes (the Q-critic consumes observations and actions)
    for key, value in env.action_spaces_dict.items():
        for act_key, act_space in value.spaces.items():
            obs_shapes[key][act_key] = act_space.sample().shape
    # initialize critic
    critics = {
        ii: QCriticNetContinuous(obs_shapes[ii],
                                 non_lin=nn.Tanh,
                                 action_spaces_dict=env.action_spaces_dict)
        for ii in obs_shapes.keys()
    }

    # algorithm configuration
    algorithm_config = SACAlgorithmConfig(
        n_rollout_steps=5,
        lr=0.001,
        entropy_coef=0.2,
        gamma=0.99,
        max_grad_norm=0.5,
        batch_size=100,
        num_actors=2,
        tau=0.005,
        target_update_interval=1,
        entropy_tuning=False,
        device='cpu',
        replay_buffer_size=10000,
        initial_buffer_size=100,
        initial_sampling_policy={
            '_target_': 'maze.core.agent.random_policy.RandomPolicy'
        },
        rollouts_per_iteration=1,
        split_rollouts_into_transitions=split_rollouts_into_transitions,
        entropy_coef_lr=0.0007,
        num_batches_per_iter=1,
        n_epochs=n_epochs,
        epoch_length=epoch_length,
        rollout_evaluator=RolloutEvaluator(eval_env=eval_env,
                                           n_episodes=eval_repeats,
                                           model_selection=None,
                                           deterministic=deterministic_eval),
        patience=50,
        target_entropy_multiplier=1.0)

    actor_policy = TorchPolicy(networks=policies,
                               distribution_mapper=distribution_mapper,
                               device='cpu')

    replay_buffer = UniformReplayBuffer(
        buffer_size=algorithm_config.replay_buffer_size, seed=1234)
    SACRunner.init_replay_buffer(
        replay_buffer=replay_buffer,
        initial_sampling_policy=algorithm_config.initial_sampling_policy,
        initial_buffer_size=algorithm_config.initial_buffer_size,
        replay_buffer_seed=1234,
        split_rollouts_into_transitions=split_rollouts_into_transitions,
        n_rollout_steps=algorithm_config.n_rollout_steps,
        env_factory=env_factory)
    distributed_actors = DummyDistributedWorkersWithBuffer(
        env_factory=env_factory,
        worker_policy=actor_policy,
        n_rollout_steps=algorithm_config.n_rollout_steps,
        n_workers=algorithm_config.num_actors,
        batch_size=algorithm_config.batch_size,
        rollouts_per_iteration=algorithm_config.rollouts_per_iteration,
        split_rollouts_into_transitions=split_rollouts_into_transitions,
        env_instance_seeds=list(range(algorithm_config.num_actors)),
        replay_buffer=replay_buffer)

    critics_policy = TorchStepStateActionCritic(
        networks=critics,
        num_policies=1,
        device='cpu',
        only_discrete_spaces={0: False},
        action_spaces_dict=env.action_spaces_dict)

    learner_model = TorchActorCritic(policy=actor_policy,
                                     critic=critics_policy,
                                     device='cpu')

    # initialize trainer
    sac = SAC(learner_model=learner_model,
              distributed_actors=distributed_actors,
              algorithm_config=algorithm_config,
              evaluator=algorithm_config.rollout_evaluator,
              model_selection=None)

    # train agent
    sac.train(n_epochs=algorithm_config.n_epochs)

    return sac
Example #16
def test_concepts_and_structures_run_context_overview():
    """
    Tests snippets in docs/source/concepts_and_structure/run_context_overview.rst.
    """

    # Default overrides for faster tests. Shouldn't change functionality.
    ac_overrides = {"runner.concurrency": 1}
    es_overrides = {"algorithm.n_epochs": 1, "algorithm.n_rollouts_per_update": 1}

    # Training
    # --------

    rc = RunContext(
        algorithm="a2c",
        overrides={"env.name": "CartPole-v0", **ac_overrides},
        model="vector_obs",
        critic="template_state",
        runner="dev",
        configuration="test"
    )
    rc.train(n_epochs=1)

    alg_config = A2CAlgorithmConfig(
        n_epochs=1,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=RolloutEvaluator(
            eval_env=SequentialVectorEnv([lambda: GymMazeEnv("CartPole-v0")]),
            n_episodes=1,
            model_selection=None,
            deterministic=True
        )
    )

    rc = RunContext(
        algorithm=alg_config,
        overrides={"env.name": "CartPole-v0", **ac_overrides},
        model="vector_obs",
        critic="template_state",
        runner="dev",
        configuration="test"
    )
    rc.train(n_epochs=1)

    rc = RunContext(env=lambda: GymMazeEnv('CartPole-v0'), overrides=es_overrides, runner="dev", configuration="test")
    rc.train(n_epochs=1)

    policy_composer_config = {
        '_target_': 'maze.perception.models.policies.ProbabilisticPolicyComposer',
        'networks': [{
            '_target_': 'maze.perception.models.built_in.flatten_concat.FlattenConcatPolicyNet',
            'non_lin': 'torch.nn.Tanh',
            'hidden_units': [256, 256]
        }],
        "substeps_with_separate_agent_nets": [],
        "agent_counts_dict": {0: 1}
    }
    rc = RunContext(
        overrides={"model.policy": policy_composer_config, **es_overrides}, runner="dev", configuration="test"
    )
    rc.train(n_epochs=1)

    env = GymMazeEnv('CartPole-v0')
    policy_composer = ProbabilisticPolicyComposer(
        action_spaces_dict=env.action_spaces_dict,
        observation_spaces_dict=env.observation_spaces_dict,
        distribution_mapper=DistributionMapper(action_space=env.action_space, distribution_mapper_config={}),
        networks=[{
            '_target_': 'maze.perception.models.built_in.flatten_concat.FlattenConcatPolicyNet',
            'non_lin': 'torch.nn.Tanh',
            'hidden_units': [222, 222]
        }],
        substeps_with_separate_agent_nets=[],
        agent_counts_dict={0: 1}
    )
    rc = RunContext(overrides={"model.policy": policy_composer, **es_overrides}, runner="dev", configuration="test")
    rc.train(n_epochs=1)

    rc = RunContext(algorithm=alg_config, overrides=ac_overrides, runner="dev", configuration="test")
    rc.train(n_epochs=1)
    rc.train()

    # Rollout
    # -------

    obs = env.reset()
    for i in range(10):
        action = rc.compute_action(obs)
        obs, rewards, dones, info = env.step(action)

    # Evaluation
    # ----------

    env.reset()
    evaluator = RolloutEvaluator(
        # Environment has to have statistics logging capabilities for RolloutEvaluator.
        eval_env=LogStatsWrapper.wrap(env, logging_prefix="eval"),
        n_episodes=1,
        model_selection=None
    )
    evaluator.evaluate(rc.policy)
Example #17
class PolicyNet(nn.Module):
    """ Simple feed-forward policy network. """

    def __init__(self, obs_shapes, action_logits_shapes):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features=obs_shapes[OBSERVATION_NAME][0],
                      out_features=16), nn.Tanh(),
            nn.Linear(in_features=16,
                      out_features=action_logits_shapes[ACTION_NAME][0]))

    def forward(self, in_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """ forward pass. """
        return {ACTION_NAME: self.net(in_dict[OBSERVATION_NAME])}


# init default distribution mapper
distribution_mapper = DistributionMapper(
    action_space=spaces.Dict(spaces={ACTION_NAME: spaces.Discrete(2)}),
    distribution_mapper_config={})

# request required action logits shape and init a policy net
logits_shape = distribution_mapper.required_logits_shape(ACTION_NAME)
policy_net = PolicyNet(obs_shapes={OBSERVATION_NAME: (4, )},
                       action_logits_shapes={ACTION_NAME: logits_shape})

# compute action logits (here from random input)
logits_dict = policy_net({OBSERVATION_NAME: torch.randn(4)})

# init action sampling distribution from model output
dist = distribution_mapper.logits_dict_to_distribution(logits_dict,
                                                       temperature=1.0)

# sample action (e.g., {my_action: 1})
action = dist.sample()
Example #18
def test_examples_part1():
    """
    Tests snippets in maze/docs/source/concepts_and_structure/run_context_overview.rst.
    Adds some performance-specific configuration that should not influence snippets' functionality.
    Split for runtime reasons.
    """

    a2c_overrides = {"runner.concurrency": 1}
    es_overrides = {
        "algorithm.n_epochs": 1,
        "algorithm.n_rollouts_per_update": 1
    }
    env_factory = lambda: GymMazeEnv('CartPole-v0')
    alg_config = _get_alg_config("CartPole-v0", "dev")

    # ------------------------------------------------------------------

    rc = RunContext(algorithm="a2c",
                    overrides={
                        "env.name": "CartPole-v0",
                        **a2c_overrides
                    },
                    model="vector_obs",
                    critic="template_state",
                    runner="dev",
                    configuration="test")
    rc.train(n_epochs=1)

    # ------------------------------------------------------------------

    rc = RunContext(algorithm=alg_config,
                    overrides={
                        "env.name": "CartPole-v0",
                        **a2c_overrides
                    },
                    model="vector_obs",
                    critic="template_state",
                    runner="dev",
                    configuration="test")
    rc.train(n_epochs=1)

    # ------------------------------------------------------------------

    rc = RunContext(env=lambda: GymMazeEnv('CartPole-v0'),
                    overrides=es_overrides,
                    runner="dev",
                    configuration="test")
    rc.train(n_epochs=1)

    # ------------------------------------------------------------------

    policy_composer_config = {
        '_target_':
        'maze.perception.models.policies.ProbabilisticPolicyComposer',
        'networks': [{
            '_target_':
            'maze.perception.models.built_in.flatten_concat.FlattenConcatPolicyNet',
            'non_lin': 'torch.nn.Tanh',
            'hidden_units': [256, 256]
        }],
        "substeps_with_separate_agent_nets": [],
        "agent_counts_dict": {
            0: 1
        }
    }
    rc = RunContext(overrides={
        "model.policy": policy_composer_config,
        **es_overrides
    },
                    runner="dev",
                    configuration="test")
    rc.train(n_epochs=1)

    # ------------------------------------------------------------------

    env = env_factory()
    policy_composer = ProbabilisticPolicyComposer(
        action_spaces_dict=env.action_spaces_dict,
        observation_spaces_dict=env.observation_spaces_dict,
        distribution_mapper=DistributionMapper(action_space=env.action_space,
                                               distribution_mapper_config={}),
        networks=[{
            '_target_':
            'maze.perception.models.built_in.flatten_concat.FlattenConcatPolicyNet',
            'non_lin': 'torch.nn.Tanh',
            'hidden_units': [222, 222]
        }],
        substeps_with_separate_agent_nets=[],
        agent_counts_dict={0: 1})
    rc = RunContext(overrides={
        "model.policy": policy_composer,
        **es_overrides
    },
                    runner="dev",
                    configuration="test")
    rc.train(n_epochs=1)
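A hedged rollout sketch that could follow the snippet above, reusing its rc and env and mirroring the rollout loop in Example #16:

# hedged sketch: roll out the trained RunContext policy on the plain env
obs = env.reset()
for _ in range(10):
    action = rc.compute_action(obs)
    obs, reward, done, info = env.step(action)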
Example #19
def train(n_epochs: int) -> int:
    """
    Trains agent in pure Python.

    :param n_epochs: Number of epochs to train.

    :return: 0 if successful.

    """

    # Environment setup
    # -----------------

    env = cartpole_env_factory()

    # Algorithm setup
    # ---------------

    algorithm_config = A2CAlgorithmConfig(
        n_epochs=5,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=RolloutEvaluator(
            eval_env=SequentialVectorEnv([cartpole_env_factory]),
            n_episodes=1,
            model_selection=None,
            deterministic=True
        )
    )

    # Custom model setup
    # ------------------

    # Policy customization
    # ^^^^^^^^^^^^^^^^^^^^

    # Policy network.
    policy_net = CartpolePolicyNet(
        obs_shapes={'observation': env.observation_space.spaces['observation'].shape},
        action_logit_shapes={'action': (env.action_space.spaces['action'].n,)}
    )
    policy_networks = [policy_net]

    # Policy distribution.
    distribution_mapper = DistributionMapper(action_space=env.action_space, distribution_mapper_config={})

    # Policy composer.
    policy_composer = ProbabilisticPolicyComposer(
        action_spaces_dict=env.action_spaces_dict,
        observation_spaces_dict=env.observation_spaces_dict,
        # Derive distribution from environment's action space.
        distribution_mapper=distribution_mapper,
        networks=policy_networks,
        # We have only one agent and network, thus this is an empty list.
        substeps_with_separate_agent_nets=[],
        # We have only one step and one agent.
        agent_counts_dict={0: 1}
    )

    # Critic customization
    # ^^^^^^^^^^^^^^^^^^^^

    # Value networks.
    value_networks = {
        0: TorchModelBlock(
            in_keys='observation', out_keys='value',
            in_shapes=env.observation_space.spaces['observation'].shape,
            in_num_dims=[2],
            out_num_dims=2,
            net=CartpoleValueNet({'observation': env.observation_space.spaces['observation'].shape})
        )
    }

    # Critic composer.
    critic_composer = SharedStateCriticComposer(
        observation_spaces_dict=env.observation_spaces_dict,
        agent_counts_dict={0: 1},
        networks=value_networks,
        stack_observations=True
    )

    # Training
    # ^^^^^^^^

    rc = run_context.RunContext(
        env=cartpole_env_factory,
        algorithm=algorithm_config,
        policy=policy_composer,
        critic=critic_composer,
        runner="dev"
    )
    rc.train(n_epochs=n_epochs)

    # Distributed training
    # ^^^^^^^^^^^^^^^^^^^^

    algorithm_config.rollout_evaluator.eval_env = SubprocVectorEnv([cartpole_env_factory])
    rc = run_context.RunContext(
        env=cartpole_env_factory,
        algorithm=algorithm_config,
        policy=policy_composer,
        critic=critic_composer,
        runner="local"
    )
    rc.train(n_epochs=n_epochs)

    # Evaluation
    # ^^^^^^^^^^

    print("-----------------")
    evaluator = RolloutEvaluator(
        eval_env=LogStatsWrapper.wrap(cartpole_env_factory(), logging_prefix="eval"),
        n_episodes=1,
        model_selection=None
    )
    evaluator.evaluate(rc.policy)

    return 0
Example #20
def _vtrace_from_logits(batch_size):
    """Tests V-trace calculated from logits."""
    seq_len = 5
    num_actions = 3
    clip_rho_threshold = None  # No clipping.
    clip_pg_rho_threshold = None  # No clipping.

    # Intentionally leaving shapes unspecified to test if V-trace can
    # deal with that.

    values = {
        'behaviour_policy_logits': [{
            'action1':
            convert_to_torch(_shaped_arange(seq_len, batch_size, num_actions),
                             device=None,
                             cast=None,
                             in_place='try')
        }],
        'target_policy_logits': [{
            'action1':
            convert_to_torch(_shaped_arange(seq_len, batch_size, num_actions),
                             device=None,
                             cast=None,
                             in_place='try')
        }],
        'actions': [{
            'action1':
            convert_to_torch(np.random.randint(0,
                                               num_actions - 1,
                                               size=(seq_len, batch_size)),
                             device=None,
                             cast=None,
                             in_place='try')
        }],
        'discounts':
        convert_to_torch(
            np.array(  # T, B where B_i: [0.9 / (i+1)] * T
                [[0.9 / (b + 1) for b in range(batch_size)]
                 for _ in range(seq_len)]),
            device=None,
            cast=None,
            in_place='try'),
        'rewards':
        convert_to_torch(_shaped_arange(seq_len, batch_size),
                         device=None,
                         cast=None,
                         in_place='try'),
        'values': [
            convert_to_torch(_shaped_arange(seq_len, batch_size) / batch_size,
                             device=None,
                             cast=None,
                             in_place='try')
        ],
        'bootstrap_value': [
            convert_to_torch(_shaped_arange(batch_size) + 1.0,
                             device=None,
                             cast=None,
                             in_place='try')
        ],
    }
    action_space = {
        0: gym.spaces.Dict({'action1': gym.spaces.Discrete(num_actions)})
    }
    # initialize distribution mapper
    distribution_mapper = DistributionMapper(action_space=action_space[0],
                                             distribution_mapper_config={})

    from_logits_output = impala_vtrace.from_logits(
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold,
        device=None,
        distribution_mapper=distribution_mapper,
        **values)

    target_log_probs, _ = impala_vtrace.log_probs_from_logits_and_actions_and_spaces(
        values['target_policy_logits'],
        values['actions'],
        distribution_mapper=distribution_mapper)
    behaviour_log_probs, _ = impala_vtrace.log_probs_from_logits_and_actions_and_spaces(
        values['behaviour_policy_logits'],
        values['actions'],
        distribution_mapper=distribution_mapper)
    log_rhos = impala_vtrace.get_log_rhos(target_log_probs,
                                          behaviour_log_probs)
    ground_truth_log_rhos, ground_truth_behaviour_action_log_probs, ground_truth_target_action_log_probs = \
        log_rhos, behaviour_log_probs, target_log_probs

    # Calculate V-trace using the ground truth logits.
    from_iw = impala_vtrace.from_importance_weights(
        log_rhos=ground_truth_log_rhos[0],
        discounts=values['discounts'],
        rewards=values['rewards'],
        values=values['values'][0],
        bootstrap_value=values['bootstrap_value'][0],
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold)

    assert np.allclose(from_iw.vs, from_logits_output.vs[0])
    assert np.allclose(from_iw.pg_advantages,
                       from_logits_output.pg_advantages[0])
    assert np.allclose(
        ground_truth_behaviour_action_log_probs[0]['action1'],
        from_logits_output.behaviour_action_log_probs[0]['action1'])
    assert np.allclose(
        ground_truth_target_action_log_probs[0]['action1'],
        from_logits_output.target_action_log_probs[0]['action1'])
    assert np.allclose(ground_truth_log_rhos[0],
                       from_logits_output.log_rhos[0])
Example #21
def train(n_epochs: int) -> int:
    # Instantiate one environment. This will be used for convenient access to observation
    # and action spaces.
    env = cartpole_env_factory()
    observation_space = env.observation_space
    action_space = env.action_space

    # Policy Setup
    # ------------

    # Policy Network
    # ^^^^^^^^^^^^^^
    # Instantiate policy with the correct shapes of observation and action spaces.
    policy_net = CartpolePolicyNet(
        obs_shapes={'observation': observation_space.spaces['observation'].shape},
        action_logit_shapes={'action': (action_space.spaces['action'].n,)})

    maze_wrapped_policy_net = TorchModelBlock(
        in_keys='observation', out_keys='action',
        in_shapes=observation_space.spaces['observation'].shape, in_num_dims=[2],
        out_num_dims=2, net=policy_net)

    policy_networks = {0: maze_wrapped_policy_net}

    # Policy Distribution
    # ^^^^^^^^^^^^^^^^^^^
    distribution_mapper = DistributionMapper(
        action_space=action_space,
        distribution_mapper_config={})

    # Optionally, you can specify a different distribution with the distribution_mapper_config argument. Using a
    # Categorical distribution for a discrete action space would be done via
    distribution_mapper = DistributionMapper(
        action_space=action_space,
        distribution_mapper_config=[{
            "action_space": gym.spaces.Discrete,
            "distribution": "maze.distributions.categorical.CategoricalProbabilityDistribution"}])

    # Instantiating the Policy
    # ^^^^^^^^^^^^^^^^^^^^^^^^
    torch_policy = TorchPolicy(networks=policy_networks, distribution_mapper=distribution_mapper, device='cpu')

    # Value Function Setup
    # --------------------

    # Value Network
    # ^^^^^^^^^^^^^
    value_net = CartpoleValueNet(obs_shapes={'observation': observation_space.spaces['observation'].shape})

    maze_wrapped_value_net = TorchModelBlock(
        in_keys='observation', out_keys='value',
        in_shapes=observation_space.spaces['observation'].shape, in_num_dims=[2],
        out_num_dims=2, net=value_net)

    value_networks = {0: maze_wrapped_value_net}

    # Instantiate the Value Function
    # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    torch_critic = TorchSharedStateCritic(networks=value_networks, obs_spaces_dict=env.observation_spaces_dict,
                                          device='cpu', stack_observations=False)

    # Initializing the ActorCritic Model.
    # -----------------------------------
    actor_critic_model = TorchActorCritic(policy=torch_policy, critic=torch_critic, device='cpu')

    # Instantiating the Trainer
    # =========================

    algorithm_config = A2CAlgorithmConfig(
        n_epochs=n_epochs,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=RolloutEvaluator(
            eval_env=SequentialVectorEnv([cartpole_env_factory]),
            n_episodes=1,
            model_selection=None,
            deterministic=True
        )
    )

    # Distributed Environments
    # ------------------------
    # In order to use the distributed trainers, the previously created env factory is supplied to one of Maze's
    # distribution classes:
    train_envs = SequentialVectorEnv([cartpole_env_factory for _ in range(2)], logging_prefix="train")
    eval_envs = SequentialVectorEnv([cartpole_env_factory for _ in range(2)], logging_prefix="eval")

    # Initialize best model selection.
    model_selection = BestModelSelection(dump_file="params.pt", model=actor_critic_model)

    a2c_trainer = A2C(rollout_generator=RolloutGenerator(train_envs),
                      evaluator=algorithm_config.rollout_evaluator,
                      algorithm_config=algorithm_config,
                      model=actor_critic_model,
                      model_selection=model_selection)

    # Train the Agent
    # ===============
    # Before starting the training, we will enable logging by calling
    log_dir = '.'
    setup_logging(job_config=None, log_dir=log_dir)

    # Now, we can train the agent.
    a2c_trainer.train()

    return 0
Example #22
def perform_test_maze_rllib_action_distribution(batch_dim: int):
    """ distribution test """
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    # action space
    act_space = spaces.Dict(spaces=dict(
        sorted({
            "selection":
            spaces.Discrete(10),
            "scale_input":
            spaces.Box(shape=(5, ), low=0, high=100, dtype=np.float64),
            "order_by_weight":
            spaces.Box(shape=(5, ), low=0, high=100, dtype=np.float64)
        }.items())))

    # default config
    config = [{
        "action_space":
        spaces.Box,
        "distribution":
        "maze.distributions.squashed_gaussian.SquashedGaussianProbabilityDistribution"
    }, {
        "action_head":
        "order_by_weight",
        "distribution":
        "maze.distributions.beta.BetaProbabilityDistribution"
    }]

    # initialize distribution mapper
    distribution_mapper = DistributionMapper(action_space=act_space,
                                             distribution_mapper_config=config)

    num_outputs = sum([
        np.prod(distribution_mapper.required_logits_shape(action_head))
        for action_head in distribution_mapper.action_space.spaces
    ])
    model_config = {
        'custom_model_config': {
            'maze_model_composer_config': {
                'distribution_mapper_config': config
            }
        }
    }
    assert num_outputs == MazeRLlibActionDistribution.required_model_output_shape(
        act_space, model_config)

    # assign action heads to registered distributions
    logits_dict = dict()
    for action_head in act_space.spaces.keys():

        logits_shape = distribution_mapper.required_logits_shape(action_head)
        if batch_dim > 0:
            logits_shape = (batch_dim, *logits_shape)

        logits_tensor = torch.from_numpy(np.random.randn(*logits_shape))
        logits_dict[action_head] = logits_tensor

    flat_input = torch.cat([tt for tt in logits_dict.values()], dim=-1)
    if batch_dim == 0:
        flat_input = flat_input.unsqueeze(0)
    fake_model = FakeRLLibModel(distribution_mapper)
    rllib_dist = MazeRLlibActionDistribution(flat_input,
                                             fake_model,
                                             temperature=0.5)

    # test dictionary distribution mapping
    maze_dist = distribution_mapper.logits_dict_to_distribution(
        logits_dict=logits_dict, temperature=0.5)

    for action_head in act_space.spaces.keys():
        maze_distribution = maze_dist.distribution_dict[action_head]
        maze_rllib_distribution = rllib_dist.maze_dist.distribution_dict[
            action_head]
        if hasattr(maze_distribution, 'logits'):
            assert torch.allclose(maze_distribution.logits,
                                  maze_rllib_distribution.logits)
        if hasattr(maze_distribution, 'low'):
            assert torch.allclose(maze_distribution.low,
                                  maze_rllib_distribution.low)
            assert torch.allclose(maze_distribution.high,
                                  maze_rllib_distribution.high)

    test_action_maze = maze_dist.sample()
    test_action_rllib = rllib_dist.sample()

    for action_head in act_space.spaces.keys():
        assert test_action_maze[action_head].shape == test_action_rllib[
            action_head].shape[int(batch_dim == 0):]

    maze_action = maze_dist.deterministic_sample()
    rllib_action = rllib_dist.deterministic_sample()

    for action_head in act_space.spaces.keys():
        assert torch.all(maze_action[action_head] == rllib_action[action_head])

    maze_action = convert_to_torch(maze_action,
                                   device=None,
                                   cast=torch.float64,
                                   in_place=True)
    rllib_action = convert_to_torch(rllib_action,
                                    device=None,
                                    cast=torch.float64,
                                    in_place=True)

    # This unsqueeze is performed by RLlib before passing an action to log_prob
    for action_head in act_space.spaces.keys():
        if len(rllib_action[action_head].shape) == 0:
            rllib_action[action_head] = rllib_action[action_head].unsqueeze(0)

    logp_maze_dict = maze_dist.log_prob(maze_action)
    action_concat = torch.cat(
        [v.unsqueeze(-1) for v in logp_maze_dict.values()], dim=-1)
    logp_maze = torch.sum(action_concat, dim=-1)

    logp_rllib = rllib_dist.logp(rllib_action)
    if batch_dim == 0:
        logp_rllib = logp_rllib[0]

    assert torch.equal(logp_maze, logp_rllib)

    logp_rllib_2 = rllib_dist.sampled_action_logp()
    if batch_dim == 0:
        logp_rllib_2 = logp_rllib_2[0]

    assert torch.equal(logp_maze, logp_rllib_2)

    maze_entropy = maze_dist.entropy()
    rllib_entropy = rllib_dist.entropy()
    if batch_dim == 0:
        rllib_entropy = rllib_entropy[0]

    assert torch.equal(maze_entropy, rllib_entropy)

    logits_dict2 = dict()
    for action_head in act_space.spaces.keys():
        logits_shape = distribution_mapper.required_logits_shape(action_head)
        if batch_dim > 0:
            logits_shape = (batch_dim, *logits_shape)

        logits_tensor = torch.from_numpy(np.random.randn(*logits_shape))
        logits_dict2[action_head] = logits_tensor

    flat_input = torch.cat([tt for tt in logits_dict2.values()], dim=-1)
    if batch_dim == 0:
        flat_input = flat_input.unsqueeze(0)
    fake_model = FakeRLLibModel(distribution_mapper)
    rllib_dist_2 = MazeRLlibActionDistribution(flat_input,
                                               fake_model,
                                               temperature=0.5)

    # test dictionary distribution mapping
    maze_dist_2 = distribution_mapper.logits_dict_to_distribution(
        logits_dict=logits_dict2, temperature=0.5)

    maze_kl = maze_dist.kl(maze_dist_2)
    rllib_kl = rllib_dist.kl(rllib_dist_2)
    if batch_dim == 0:
        rllib_kl = rllib_kl[0]

    assert torch.equal(maze_kl, rllib_kl)