Code example #1
 def __init__(self):
     space = Box(0.0, 1.0, shape=(84, 84, 3), dtype=np.float32)
     super().__init__(space)
Code example #2
    def __init__(self,
                 obj_low=None,
                 obj_high=None,
                 reward_type='hand_and_obj_distance',
                 indicator_threshold=0.06,
                 obj_init_positions=((0, 0.6, 0.02), ),
                 random_init=False,
                 fix_goal=False,
                 fixed_goal=(0.15, 0.6, 0.055, -0.15, 0.6),
                 goal_low=None,
                 goal_high=None,
                 reset_free=False,
                 hide_goal_markers=False,
                 oracle_reset_prob=0.0,
                 presampled_goals=None,
                 num_goals_presampled=10,
                 p_obj_in_hand=.75,
                 **kwargs):
        self.quick_init(locals())
        MultitaskEnv.__init__(self)
        SawyerXYZEnv.__init__(self, model_name=self.model_name, **kwargs)
        if obj_low is None:
            obj_low = self.hand_low
        if obj_high is None:
            obj_high = self.hand_high
        self.obj_low = obj_low
        self.obj_high = obj_high
        if goal_low is None:
            goal_low = np.hstack((self.hand_low, obj_low))
        if goal_high is None:
            goal_high = np.hstack((self.hand_high, obj_high))

        self.reward_type = reward_type
        self.random_init = random_init
        self.p_obj_in_hand = p_obj_in_hand
        self.indicator_threshold = indicator_threshold

        self.obj_init_z = obj_init_positions[0][2]
        self.obj_init_positions = np.array(obj_init_positions)
        self.last_obj_pos = self.obj_init_positions[0]

        self.fix_goal = fix_goal
        self.fixed_goal = np.array(fixed_goal)
        self._state_goal = None
        self.reset_free = reset_free
        self.oracle_reset_prob = oracle_reset_prob

        self.hide_goal_markers = hide_goal_markers

        self.action_space = Box(np.array([-1, -1, -1, -1]),
                                np.array([1, 1, 1, 1]),
                                dtype=np.float32)
        self.hand_and_obj_space = Box(np.hstack((self.hand_low, obj_low)),
                                      np.hstack((self.hand_high, obj_high)),
                                      dtype=np.float32)
        self.hand_space = Box(self.hand_low, self.hand_high, dtype=np.float32)
        self.gripper_and_hand_and_obj_space = Box(np.hstack(
            ([0.0], self.hand_low, obj_low)),
                                                  np.hstack(
                                                      ([0.04], self.hand_high,
                                                       obj_high)),
                                                  dtype=np.float32)

        self.observation_space = Dict([
            ('observation', self.gripper_and_hand_and_obj_space),
            ('desired_goal', self.hand_and_obj_space),
            ('achieved_goal', self.hand_and_obj_space),
            ('state_observation', self.gripper_and_hand_and_obj_space),
            ('state_desired_goal', self.hand_and_obj_space),
            ('state_achieved_goal', self.hand_and_obj_space),
            ('proprio_observation', self.hand_space),
            ('proprio_desired_goal', self.hand_space),
            ('proprio_achieved_goal', self.hand_space),
        ])
        self.hand_reset_pos = np.array([0, .6, .2])

        if presampled_goals is not None:
            self._presampled_goals = presampled_goals
            self.num_goals_presampled = len(
                list(self._presampled_goals.values())[0])
        else:
            # presampled_goals will be created when sample_goal is first called
            self._presampled_goals = None
            self.num_goals_presampled = num_goals_presampled
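
Note: the constructor above expects presampled_goals to be a dict of equally sized arrays. A minimal sketch of that layout (the key names and shapes here are assumptions, not taken from the source):

import numpy as np

# hypothetical presampled goals: 10 goals, 6-dim hand+obj goal vectors
presampled_goals = {
    'state_desired_goal': np.zeros((10, 6)),
    'desired_goal': np.zeros((10, 6)),
}
num_goals_presampled = len(list(presampled_goals.values())[0])  # -> 10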
Code example #3
File: action_masking.py Project: wuisawesome/ray
    return args


if __name__ == "__main__":
    args = get_cli_args()

    ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)

    # main part: configure the ActionMaskEnv and ActionMaskModel
    config = {
        # random env with 100 discrete actions and 5x [-1,1] observations
        # some actions are declared invalid and lead to errors
        "env": ActionMaskEnv,
        "env_config": {
            "action_space": Discrete(100),
            "observation_space": Box(-1.0, 1.0, (5,)),
        },
        # the ActionMaskModel retrieves the invalid actions and avoids them
        "model": {
            "custom_model": ActionMaskModel
            if args.framework != "torch"
            else TorchActionMaskModel,
            # disable action masking according to CLI
            "custom_model_config": {"no_masking": args.no_masking},
        },
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": args.framework,
        # Run with tracing enabled for tfe/tf2?
        "eager_tracing": args.eager_tracing,
    }
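
A minimal sketch of how such a config dict is typically consumed; the "PPO" algorithm choice and the stop criterion below are assumptions, not part of the example above:

from ray import tune

stop = {"training_iteration": 10}  # hypothetical stop criterion
results = tune.run("PPO", config=config, stop=stop)
ray.shutdown()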
Code example #4
    def __init__(self,
                 random_init=True,
                 obs_type='plain',
                 goal_low=None,
                 goal_high=None,
                 rotMode='fixed',
                 **kwargs):
        self.quick_init(locals())
        hand_low = (-0.5, 0.40, 0.05)
        hand_high = (0.5, 1, 0.5)
        obj_low = (-0.05, 0.85, 0.12)
        obj_high = (0.05, 0.9, 0.12)
        SawyerXYZEnv.__init__(self,
                              frame_skip=5,
                              action_scale=1. / 100,
                              hand_low=hand_low,
                              hand_high=hand_high,
                              model_name=self.model_name,
                              **kwargs)

        self.init_config = {
            'obj_init_pos': np.array([0., 0.9, 0.12], dtype=np.float32),
            'hand_init_pos': np.array([0, 0.6, 0.2], dtype=np.float32),
        }
        self.goal = np.array([0, 0.84, 0.12])
        self.obj_init_pos = self.init_config['obj_init_pos']
        self.hand_init_pos = self.init_config['hand_init_pos']

        assert obs_type in OBS_TYPE
        self.obs_type = obs_type

        if goal_low is None:
            goal_low = self.hand_low

        if goal_high is None:
            goal_high = self.hand_high

        self.random_init = random_init
        self.max_path_length = 150
        self.rotMode = rotMode

        if rotMode == 'fixed':
            self.action_space = Box(
                np.array([-1, -1, -1, -1]),
                np.array([1, 1, 1, 1]),
            )
        elif rotMode == 'rotz':
            self.action_rot_scale = 1. / 50
            self.action_space = Box(
                np.array([-1, -1, -1, -np.pi, -1]),
                np.array([1, 1, 1, np.pi, 1]),
            )
        elif rotMode == 'quat':
            self.action_space = Box(
                np.array([-1, -1, -1, 0, -1, -1, -1, -1]),
                np.array([1, 1, 1, 2 * np.pi, 1, 1, 1, 1]),
            )
        else:
            self.action_space = Box(
                np.array([-1, -1, -1, -np.pi / 2, -np.pi / 2, 0, -1]),
                np.array([1, 1, 1, np.pi / 2, np.pi / 2, np.pi * 2, 1]),
            )
        self.obj_and_goal_space = Box(
            np.array(obj_low),
            np.array(obj_high),
        )
        self.goal_space = Box(np.array(goal_low), np.array(goal_high))
        if self.obs_type == 'plain':
            self.observation_space = Box(
                np.hstack((
                    self.hand_low,
                    obj_low,
                )),
                np.hstack((
                    self.hand_high,
                    obj_high,
                )),
            )
        elif self.obs_type == 'with_goal':
            self.observation_space = Box(
                np.hstack((self.hand_low, obj_low, goal_low)),
                np.hstack((self.hand_high, obj_high, goal_high)),
            )
        else:
            raise NotImplementedError
        self.reset()
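
The rotMode branches above only change the bounds of the Box action space. A short, self-contained sketch of how such a space behaves (plain gym API, independent of the environment):

import numpy as np
from gym.spaces import Box

space = Box(np.array([-1, -1, -1, -np.pi, -1]), np.array([1, 1, 1, np.pi, 1]))
action = space.sample()  # uniform sample within the per-dimension bounds
assert space.contains(action)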
Code example #5
    def __init__(self, env, active_handles, names, map_size, max_cycles,
                 reward_range, minimap_mode, extra_features):
        self.map_size = map_size
        self.max_cycles = max_cycles
        self.minimap_mode = minimap_mode
        self.extra_features = extra_features
        self.env = env
        self.handles = active_handles
        self._all_handles = self.env.get_handles()
        env.reset()
        self.generate_map()
        self.team_sizes = team_sizes = [
            env.get_num(handle) for handle in self.handles
        ]
        self.agents = [
            f"{names[j]}_{i}" for j in range(len(team_sizes))
            for i in range(team_sizes[j])
        ]
        self.possible_agents = self.agents[:]

        num_actions = [
            env.get_action_space(handle)[0] for handle in self.handles
        ]
        action_spaces_list = [
            Discrete(num_actions[j]) for j in range(len(team_sizes))
            for i in range(team_sizes[j])
        ]
        # may change depending on environment config? Not sure.
        team_obs_shapes = self._calc_obs_shapes()
        state_shape = self._calc_state_shape()
        observation_space_list = [
            Box(low=0., high=2., shape=team_obs_shapes[j], dtype=np.float32)
            for j in range(len(team_sizes)) for i in range(team_sizes[j])
        ]

        self.state_space = Box(low=0.,
                               high=2.,
                               shape=state_shape,
                               dtype=np.float32)
        reward_low, reward_high = reward_range

        if extra_features:
            for space in observation_space_list:
                idx = space.shape[2] - 3 if minimap_mode else space.shape[2] - 1
                space.low[:, :, idx] = reward_low
                space.high[:, :, idx] = reward_high
            idx_state = self.state_space.shape[
                2] - 3 if minimap_mode else self.state_space.shape[2] - 1
            self.state_space.low[:, :, idx_state] = reward_low
            self.state_space.high[:, :, idx_state] = reward_high

        self.action_spaces = {
            agent: space
            for agent, space in zip(self.agents, action_spaces_list)
        }
        self.observation_spaces = {
            agent: space
            for agent, space in zip(self.agents, observation_space_list)
        }

        self._zero_obs = {
            agent: np.zeros_like(space.low)
            for agent, space in self.observation_spaces.items()
        }
        self.base_state = np.zeros(self.state_space.shape)
        walls = self.env._get_walls_info()
        wall_x, wall_y = zip(*walls)
        self.base_state[wall_x, wall_y, 0] = 1
        self._renderer = None
        self.frames = 0
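
The extra_features branch above patches the bounds of a Box in place. A minimal sketch of that pattern, with assumed shapes and an assumed reward range:

import numpy as np
from gym.spaces import Box

space = Box(low=0., high=2., shape=(5, 5, 4), dtype=np.float32)
# widen the last feature channel to a hypothetical reward range of [-10, 10]
space.low[:, :, -1] = -10.0
space.high[:, :, -1] = 10.0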
Code example #6
File: test_sac.py Project: zseymour/ray
 def __init__(self, config):
     self.action_space = Box(0.0, 1.0, (1, ))
     self.observation_space = Box(0.0, 1.0, (1, ))
     self.max_steps = config.get("max_steps", 100)
     self.state = None
     self.steps = None
Code example #7
 def __init__(self, env, training=True):
     super().__init__(env)
     H, W, C = self.env.render('rgb_array').shape
     self.observation_space = Box(0, 255, (H, W, C), dtype=np.uint8)
Code example #8
import numpy as np
import unittest

from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple

import ray
from ray.rllib.agents.registry import get_trainer_class
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork as FCNetV2
from ray.rllib.models.tf.visionnet import VisionNetwork as VisionNetV2
from ray.rllib.models.torch.visionnet import VisionNetwork as TorchVisionNetV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFCNetV2
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.test_utils import framework_iterator

ACTION_SPACES_TO_TEST = {
    "discrete": Discrete(5),
    "vector": Box(-1.0, 1.0, (5, ), dtype=np.float32),
    "vector2": Box(-1.0, 1.0, (5, ), dtype=np.float32),
    "int_actions": Box(0, 3, (2, 3), dtype=np.int32),
    "multidiscrete": MultiDiscrete([1, 2, 3, 4]),
    "tuple": Tuple(
        [Discrete(2),
         Discrete(3),
         Box(-1.0, 1.0, (5, ), dtype=np.float32)]),
    "dict": Dict({
        "action_choice": Discrete(3),
        "parameters": Box(-1.0, 1.0, (1, ), dtype=np.float32),
        "yet_another_nested_dict": Dict({
            "a": Tuple([Discrete(2), Discrete(3)])
        })
    }),
}
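
A quick way to exercise such a space table (plain gym API; this loop is a sketch, not part of the test file):

for name, space in ACTION_SPACES_TO_TEST.items():
    sample = space.sample()
    assert space.contains(sample), name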
Code example #9
File: particle_1d.py Project: maxiaoba/rlkit
 def observation_space(self):
     return Box(-np.ones(2 * self.agent_num),
                np.ones(2 * self.agent_num),
                dtype=np.float32)
Code example #10
 def __init__(self, env):
     super().__init__(env)
     assert isinstance(env.observation_space, Discrete)
     self.observation_space = Box(0.0, 1.0, (env.observation_space.n,), dtype=np.float32)
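
The wrapper above only rebuilds the observation space. A minimal sketch of the matching conversion method, assuming the class follows the gym.ObservationWrapper interface:

 def observation(self, obs):
     # one-hot encode the discrete observation into the Box defined above
     one_hot = np.zeros(self.env.observation_space.n, dtype=np.float32)
     one_hot[obs] = 1.0
     return one_hot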
Code example #11
"""

import sys

sys.path.append('../rt_erg_lib')

###########################################
# basic function test

import numpy as np
import numpy.random as npr
from basis import Basis
from gym.spaces import Box

# define the exploration space as gym.Box
explr_space = Box(np.array([0.0, 0.0]), np.array([1.0, 1.0]), dtype=np.float32)
# define the basis object
basis = Basis(explr_space=explr_space, num_basis=5)
# simulate/randomize a trajectory
xt = [explr_space.sample() for _ in range(10)]
# print indices for all basis functions
print('indices for all basis functions: ')
print(basis.k)  # the number of index vectors is num_basis squared
# test basis function, the input is a pose
print(basis.fk(xt[0]))
# test derivative of basis function wrt a pose
print(basis.dfk(xt[0]))
# hk is computed in the source code but never
# used in the end, so we ignore it here

###########################################
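
As a follow-up sanity check (an assumption based on the comment above that the number of index vectors is num_basis squared):

assert len(basis.k) == 5 ** 2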
Code example #12
class SubepisodedReferenceGenerator(ReferenceGenerator):
    """
    Base class for reference generators that change their parameters within certain ranges after a random
    number of time steps and can pre-calculate their references within these "sub-episodes".
    """

    reference_space = Box(-1, 1, shape=(1, ))
    _reference = None

    def __init__(self,
                 reference_state='omega',
                 episode_lengths=(500, 2000),
                 limit_margin=None,
                 *_,
                 **__):
        """
        Args:
            reference_state(str): Name of the state that this reference generator is referencing.
            episode_lengths(Tuple(int,int)): Minimum and maximum length of a sub episode.
            limit_margin(Tuple(float,float)/float/None):
                Factor specifying how close the references may get to the limits.
                If a tuple is passed, the lower [0] and upper [1] margins may differ.
                If a float is passed, both margins are equal.
                If None (default), the limit margin equals (nominal values / limits).
                In general, the limit margin should not exceed (-1, 1).
        """
        super().__init__()
        self._limit_margin = limit_margin
        self._reference_value = 0.0
        self._reference_state = reference_state.lower()
        self._episode_len_range = episode_lengths
        self._current_episode_length = int(
            self._get_current_value(episode_lengths))
        self._k = 0

    def set_modules(self, physical_system):
        super().set_modules(physical_system)
        self._referenced_states = set_state_array({
            self._reference_state: 1
        }, physical_system.state_names).astype(bool)
        rs = self._referenced_states
        ps = physical_system
        if self._limit_margin is None:
            upper_margin = (ps.nominal_state[rs] /
                            ps.limits[rs])[0] * ps.state_space.high[rs]
            lower_margin = (ps.nominal_state[rs] /
                            ps.limits[rs])[0] * ps.state_space.low[rs]
            self._limit_margin = lower_margin[0], upper_margin[0]
        elif type(self._limit_margin) in [float, int]:
            upper_margin = self._limit_margin * ps.state_space.high[rs]
            lower_margin = self._limit_margin * ps.state_space.low[rs]
            self._limit_margin = lower_margin[0], upper_margin[0]
        elif type(self._limit_margin) is tuple:
            lower_margin = self._limit_margin[0] * ps.state_space.low[rs]
            upper_margin = self._limit_margin[1] * ps.state_space.high[rs]
            self._limit_margin = lower_margin[0], upper_margin[0]
        else:
            raise Exception('Unknown type for the limit margin.')
        self.reference_space = Box(lower_margin[0],
                                   upper_margin[0],
                                   shape=(1, ))

    def reset(self, initial_state=None, initial_reference=None):
        """
        The references are reset. If an initial reference is passed, this value will be the first reference value of
        the next episode. Otherwise it will be 0.

        Args:
            initial_state(ndarray(float)): The initial state of the physical system.
            initial_reference(ndarray(float)): (Optional) The first reference value.

        Returns:
             initial_reference(ndarray(float)): initial reference array.
             initial_reference_observation(element of reference_space): An initial observation of the next reference.
             trajectory(None): No initial trajectory is passed.
        """
        if initial_reference is not None:
            self._reference_value = initial_reference[
                self._referenced_states][0]
        else:
            self._reference_value = 0.0
        self._current_episode_length = -1
        return super().reset(initial_state)

    def get_reference(self, *_, **__):
        reference = np.zeros_like(self._referenced_states, dtype=float)
        reference[self._referenced_states] = self._reference_value
        return reference

    def get_reference_observation(self, *_, **__):
        if self._k >= self._current_episode_length:
            self._k = 0
            self._current_episode_length = int(
                self._get_current_value(self._episode_len_range))
            self._reset_reference()
        self._reference_value = self._reference[self._k]
        self._k += 1
        return np.array([self._reference_value])

    def _reset_reference(self):
        """
        Subclasses implement the generation of the references for the next self._current_episode_length
        time steps in this method and write them into self._reference.
        """
        raise NotImplementedError

    @staticmethod
    def _get_current_value(value_range):
        """
        Return a uniformly distributed value for the next sub-episode.

        If a float or int is passed, this value is returned. Otherwise, a uniformly
        distributed value between value_range[0] and value_range[1] is returned.
        """
        if type(value_range) in [int, float]:
            return value_range
        elif type(value_range) in [list, tuple, np.ndarray]:
            return (value_range[1] -
                    value_range[0]) * np.random.rand() + value_range[0]
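
The class above leaves _reset_reference abstract. A minimal concrete subclass, sketched here as an assumption rather than taken from the source:

class ConstantReferenceGenerator(SubepisodedReferenceGenerator):
    """Sketch: holds one uniformly drawn reference value per sub-episode."""

    def _reset_reference(self):
        # _limit_margin is a (lower, upper) tuple after set_modules()
        value = self._get_current_value(self._limit_margin)
        self._reference = value * np.ones(self._current_episode_length)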
Code example #13
    def init_observation_space(self):
        arm = dict()
        arm['joint_pos_list'] = Box(low=-np.inf,
                                    high=np.inf,
                                    shape=(self.arm_dof_num, 3),
                                    dtype=np.float32)
        arm['joint_vel_list'] = Box(low=-np.inf,
                                    high=np.inf,
                                    shape=(self.arm_dof_num, 3),
                                    dtype=np.float32)
        arm['link_glb_pos_list'] = Box(low=-np.inf,
                                       high=np.inf,
                                       shape=(self.arm_link_num, 3),
                                       dtype=np.float32)
        arm['link_glb_orn_list'] = Box(low=-np.inf,
                                       high=np.inf,
                                       shape=(self.arm_link_num, 4),
                                       dtype=np.float32)
        arm['link_loc_pos_list'] = Box(low=-np.inf,
                                       high=np.inf,
                                       shape=(self.arm_link_num, 3),
                                       dtype=np.float32)
        arm['link_loc_orn_list'] = Box(low=-np.inf,
                                       high=np.inf,
                                       shape=(self.arm_link_num, 4),
                                       dtype=np.float32)

        gripper = dict()
        gripper['joint_pos_list'] = Box(low=-np.inf,
                                        high=np.inf,
                                        shape=(self.gripper_dof_num, 3),
                                        dtype=np.float32)
        gripper['joint_vel_list'] = Box(low=-np.inf,
                                        high=np.inf,
                                        shape=(self.gripper_dof_num, 3),
                                        dtype=np.float32)
        gripper['link_glb_pos_list'] = Box(low=-np.inf,
                                           high=np.inf,
                                           shape=(self.gripper_link_num, 3),
                                           dtype=np.float32)
        gripper['link_glb_orn_list'] = Box(low=-np.inf,
                                           high=np.inf,
                                           shape=(self.gripper_link_num, 4),
                                           dtype=np.float32)
        gripper['link_loc_pos_list'] = Box(low=-np.inf,
                                           high=np.inf,
                                           shape=(self.gripper_link_num, 3),
                                           dtype=np.float32)
        gripper['link_loc_orn_list'] = Box(low=-np.inf,
                                           high=np.inf,
                                           shape=(self.gripper_link_num, 4),
                                           dtype=np.float32)

        super().observation['arm'] = arm
        super().observation['gripper'] = gripper
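
One plausible way, not shown in the source, to collect these nested dicts of Boxes into a single gym space:

from gym.spaces import Dict

observation_space = Dict({'arm': Dict(arm), 'gripper': Dict(gripper)})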
Code example #14
File: test_preprocessors.py Project: wuisawesome/ray
 def __init__(self):
     self.observation_space = Tuple(
         [Discrete(5),
          Box(0, 5, shape=(3, ), dtype=np.float32)])
Code example #15
File: tfnet_2.py Project: Woitoxx/bomberman-rl
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        # TODO: (sven) Support Dicts as well.

        self.original_space = obs_space.original_space if \
            hasattr(obs_space, "original_space") else obs_space
        assert isinstance(self.original_space, (Tuple)), \
            "`obs_space.original_space` must be Tuple!"

        super().__init__(self.original_space, action_space, num_outputs,
                         model_config, name)
        self.new_obs_space = obs_space
        # Build the CNN(s) given obs_space's image components.
        self.cnns = {}
        self.one_hot = {}
        self.flatten = {}
        concat_size_p, concat_size_v = 0, 0
        for i, component in enumerate(self.original_space):
            # Image space.
            if len(component.shape) == 3:
                config = {
                    "conv_filters": model_config.get(
                        "conv_filters"),
                    "conv_activation": model_config.get("conv_activation"),
                    "post_fcnet_hiddens": [],
                }
                cnn = CustomVisionNetwork(component, action_space, None, config, "cnn_{}".format(i))
                '''
                cnn = ModelCatalog.get_model_v2(
                    component,
                    action_space,
                    num_outputs=None,
                    model_config=config,
                    framework="tf",
                    name="cnn_{}".format(i))
                '''
                cnn.base_model.summary()
                concat_size_p += cnn.num_outputs_p
                concat_size_v += cnn.num_outputs_v
                self.cnns[i] = cnn
            # Discrete inputs -> One-hot encode.
            elif isinstance(component, Discrete):
                self.one_hot[i] = True
                concat_size_p += component.n
                concat_size_v += component.n
            # TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers).
            # Everything else (1D Box).
            else:
                self.flatten[i] = int(np.product(component.shape))
                concat_size_p += self.flatten[i]
                concat_size_v += self.flatten[i]

        # Optional post-concat FC-stack.
        post_fc_stack_config = {
            "fcnet_hiddens": model_config.get("post_fcnet_hiddens", []),
            "fcnet_activation": model_config.get("post_fcnet_activation",
                                                 "relu"),
            "vf_share_layers": 'True'
        }
        self.post_fc_stack = ModelCatalog.get_model_v2(
            Box(float("-inf"),
                float("inf"),
                shape=(concat_size_p,),
                dtype=np.float32),
            self.action_space,
            None,
            post_fc_stack_config,
            framework="tf",
            name="post_fc_stack")

        self.post_fc_stack_vf = ModelCatalog.get_model_v2(
            Box(float("-inf"),
                float("inf"),
                shape=(concat_size_v,),
                dtype=np.float32),
            self.action_space,
            None,
            post_fc_stack_config,
            framework="tf",
            name="post_fc_stack_vf")
        self.post_fc_stack.base_model.summary()
        self.post_fc_stack_vf.base_model.summary()

        # Actions and value heads.
        self.logits_and_value_model = None
        self._value_out = None
        if num_outputs:
            # Action-distribution head.
            p_layer = tf.keras.layers.Input(
                (self.post_fc_stack.num_outputs,))
            v_layer = tf.keras.layers.Input(
                (self.post_fc_stack_vf.num_outputs,))
            logits_layer = tf.keras.layers.Dense(
                num_outputs,
                activation=tf.keras.activations.linear,
                name="logits")(p_layer)

            # Create the value branch model.
            value_layer = tf.keras.layers.Dense(
                1,
                name="value_out",
                activation=tf.keras.activations.tanh,
                kernel_initializer=normc_initializer(0.01))(v_layer)
            self.logits_model = tf.keras.models.Model(
                p_layer, [logits_layer])
            self.value_model = tf.keras.models.Model(
                v_layer, [value_layer]
            )
            self.logits_model.summary()
            self.value_model.summary()
        else:
            self.num_outputs = self.post_fc_stack.num_outputs
Code example #16
File: particle_1d.py Project: maxiaoba/rlkit
 def action_space(self):
     return Box(-np.ones(1), np.ones(1), dtype=np.float32)
Code example #17
    def __init__(
            self,
            random_init=False,
            task_types=['pick_place', 'reach', 'push'],
            task_type='pick_place',
            obs_type='plain',
            goal_low=(-0.1, 0.8, 0.05),
            goal_high=(0.1, 0.9, 0.3),
            liftThresh = 0.04,
            sampleMode='equal',
            rewMode = 'orig',
            rotMode='fixed',
            **kwargs
    ):
        self.quick_init(locals())

        hand_low=(-0.5, 0.40, 0.05)
        hand_high=(0.5, 1, 0.5)
        obj_low=(-0.1, 0.6, 0.02)
        obj_high=(0.1, 0.7, 0.02)

        SawyerXYZEnv.__init__(
            self,
            frame_skip=5,
            action_scale=1./100,
            hand_low=hand_low,
            hand_high=hand_high,
            model_name=self.model_name,
            **kwargs
        )
        self.task_type = task_type
        self.init_config = {
            'obj_init_angle': .3,
            'obj_init_pos': np.array([0, 0.6, 0.02]),
            'hand_init_pos': np.array([0, .6, .2]),
        }
        # we only do one task from [pick_place, reach, push]
        # per instance of SawyerReachPushPickPlaceEnv.
        # Set task_type only via the constructor.
        if self.task_type == 'pick_place':
            self.goal = np.array([0.1, 0.8, 0.2])
        elif self.task_type == 'reach':
            self.goal = np.array([-0.1, 0.8, 0.2])
        elif self.task_type == 'push':
            self.goal = np.array([0.1, 0.8, 0.02])
        else:
            raise NotImplementedError
        self.obj_init_angle = self.init_config['obj_init_angle']
        self.obj_init_pos = self.init_config['obj_init_pos']
        self.hand_init_pos = self.init_config['hand_init_pos']

        assert obs_type in OBS_TYPE
        self.obs_type = obs_type

        if goal_low is None:
            goal_low = self.hand_low
        
        if goal_high is None:
            goal_high = self.hand_high

        self.random_init = random_init
        self.liftThresh = liftThresh
        self.max_path_length = 150
        self.rewMode = rewMode
        self.rotMode = rotMode
        self.sampleMode = sampleMode
        self.task_types = task_types
        if rotMode == 'fixed':
            self.action_space = Box(
                np.array([-1, -1, -1, -1]),
                np.array([1, 1, 1, 1]),
            )
        elif rotMode == 'rotz':
            self.action_rot_scale = 1./50
            self.action_space = Box(
                np.array([-1, -1, -1, -np.pi, -1]),
                np.array([1, 1, 1, np.pi, 1]),
            )
        elif rotMode == 'quat':
            self.action_space = Box(
                np.array([-1, -1, -1, 0, -1, -1, -1, -1]),
                np.array([1, 1, 1, 2*np.pi, 1, 1, 1, 1]),
            )
        else:
            self.action_space = Box(
                np.array([-1, -1, -1, -np.pi/2, -np.pi/2, 0, -1]),
                np.array([1, 1, 1, np.pi/2, np.pi/2, np.pi*2, 1]),
            )
        self.obj_and_goal_space = Box(
            np.hstack((obj_low, goal_low)),
            np.hstack((obj_high, goal_high)),
        )
        self.goal_space = Box(np.array(goal_low), np.array(goal_high))
        if self.obs_type == 'plain':
            self.observation_space = Box(
                np.hstack((self.hand_low, obj_low,)),
                np.hstack((self.hand_high, obj_high,)),
            )
        elif self.obs_type == 'with_goal':
            self.observation_space = Box(
                np.hstack((self.hand_low, obj_low, goal_low)),
                np.hstack((self.hand_high, obj_high, goal_high)),
            )
        else:
            raise NotImplementedError(
                'If you want to use an observation with_obs_idx, please '
                'discretize the goal space after instantiating the environment.')
        self.num_resets = 0
        self.reset()
Code example #18
    def __init__(self,
                 random_init=False,
                 obs_type='plain',
                 goal_low=(0., 0.85, 0.05),
                 goal_high=(0.3, 0.9, 0.05),
                 liftThresh=0.09,
                 rotMode='fixed',
                 rewMode='orig',
                 **kwargs):
        self.quick_init(locals())
        hand_low = (-0.5, 0.40, 0.05)
        hand_high = (0.5, 1, 0.5)
        obj_low = (-0.1, 0.5, 0.02)
        obj_high = (0.1, 0.6, 0.02)
        SawyerXYZEnv.__init__(self,
                              frame_skip=5,
                              action_scale=1. / 100,
                              hand_low=hand_low,
                              hand_high=hand_high,
                              model_name=self.model_name,
                              **kwargs)
        # TODO should we put this to goal instead of initial config?
        self.init_config = {
            'hammer_init_pos': np.array([0, 0.6, 0.02]),
            'hand_init_pos': np.array([0, 0.6, 0.2]),
        }
        self.goal = self.init_config['hammer_init_pos']  # TODO: check this
        self.hammer_init_pos = self.init_config['hammer_init_pos']
        self.hand_init_pos = self.init_config['hand_init_pos']

        if goal_low is None:
            goal_low = self.hand_low

        if goal_high is None:
            goal_high = self.hand_high

        assert obs_type in OBS_TYPE
        self.obs_type = obs_type

        self.random_init = random_init
        self.liftThresh = liftThresh
        self.max_path_length = 200
        self.rewMode = rewMode
        self.rotMode = rotMode
        if rotMode == 'fixed':
            self.action_space = Box(
                np.array([-1, -1, -1, -1]),
                np.array([1, 1, 1, 1]),
            )
        elif rotMode == 'rotz':
            self.action_rot_scale = 1. / 50
            self.action_space = Box(
                np.array([-1, -1, -1, -np.pi, -1]),
                np.array([1, 1, 1, np.pi, 1]),
            )
        elif rotMode == 'quat':
            self.action_space = Box(
                np.array([-1, -1, -1, 0, -1, -1, -1, -1]),
                np.array([1, 1, 1, 2 * np.pi, 1, 1, 1, 1]),
            )
        else:
            self.action_space = Box(
                np.array([-1, -1, -1, -np.pi / 2, -np.pi / 2, 0, -1]),
                np.array([1, 1, 1, np.pi / 2, np.pi / 2, np.pi * 2, 1]),
            )
        self.obj_and_goal_space = Box(
            np.hstack((obj_low, goal_low)),
            np.hstack((obj_high, goal_high)),
        )
        self.goal_space = Box(np.array(goal_low), np.array(goal_high))
        if self.obs_type == 'plain':
            self.observation_space = Box(
                np.hstack((
                    self.hand_low,
                    obj_low,
                )),
                np.hstack((
                    self.hand_high,
                    obj_high,
                )),
            )
        elif self.obs_type == 'with_goal':
            self.observation_space = Box(
                np.hstack((self.hand_low, obj_low, goal_low)),
                np.hstack((self.hand_high, obj_high, goal_high)),
            )
        else:
            raise NotImplementedError
        self.reset()
Code example #19
class CityFlowEnvRay(MultiAgentEnv):
    """
    Multi-intersection CityFlow environment for the Ray framework.
    """
    observation_space = Box(0.0 * np.ones((29, )), 150 * np.ones((29, )))
    action_space = Discrete(8)  # number of signal phases per intersection
    config = json.load(
        open(
            '/home/skylark/PycharmRemote/Gamma-Reward-Perfect/config/global_config.json'
        ))
    cityflow_config = json.load(open(config['cityflow_config_file']))
    roadnetFile = cityflow_config['dir'] + cityflow_config['roadnetFile']
    config["lane_phase_info"] = parse_roadnet(roadnetFile)
    intersection_id = list(config['lane_phase_info'].keys())

    def __init__(self, env_config):
        config = json.load(
            open(
                '/home/skylark/PycharmRemote/Gamma-Reward-Perfect/config/global_config.json'
            ))
        cityflow_config = json.load(open(config['cityflow_config_file']))
        self.roadnetFile = cityflow_config['dir'] + cityflow_config[
            'roadnetFile']
        self.record_dir = '/home/skylark/PycharmRemote/Gamma-Reward-Perfect/record/' + env_config[
            "Name"]
        if not os.path.exists(self.record_dir):
            os.makedirs(self.record_dir)

        self.dic_traffic_env_conf = {
            'ADJACENCY_BY_CONNECTION_OR_GEO': False,
            'TOP_K_ADJACENCY': 5
        }

        self.dic_lane_waiting_vehicle_count_previous_step = {}
        self.dic_vehicle_speed_previous_step = {}
        self.dic_vehicle_distance_previous_step = {}
        self.dic_lane_vehicle_previous_step = {}
        self.traffic_light_node_dict = self._adjacency_extraction()
        self.dic_lane_vehicle_current_step = {}
        self.dic_lane_waiting_vehicle_count_current_step = {}
        self.dic_vehicle_speed_current_step = {}
        self.dic_vehicle_distance_current_step = {}
        self.list_lane_vehicle_previous_step = {}
        self.list_lane_vehicle_current_step = {}
        self.dic_vehicle_arrive_leave_time = {}

        config["lane_phase_info"] = parse_roadnet(self.roadnetFile)
        intersection_id = list(config['lane_phase_info'].keys())
        self.Gamma_Reward = env_config["Gamma_Reward"]
        self.threshold = env_config["threshold"]
        self.min_action_time = env_config["MIN_ACTION_TIME"]
        self.road_sort = env_config['road_sort']

        self.eng = cityflow.Engine(config['cityflow_config_file'],
                                   thread_num=config["thread_num"])
        # self.eng = config["eng"][0]
        self.num_step = 3600
        self.intersection_id = intersection_id  # list, [intersection_id, ...]
        self.num_agents = len(self.intersection_id)
        self.state_size = None
        self.lane_phase_info = config["lane_phase_info"]  # "intersection_1_1"

        # self.score = []
        # self.score_file = './utils/score_' + str(datetime.datetime.now()) + '.csv'
        self.current_phase = {}
        self.current_phase_time = {}
        self.start_lane = {}
        self.end_lane = {}
        self.phase_list = {}
        self.phase_startLane_mapping = {}
        self.intersection_lane_mapping = {}  # {id_:[lanes]}
        self.empty = {}
        self.dic_num_id_inter = {}

        num_id = 0
        for id_ in self.intersection_id:
            self.start_lane[id_] = self.lane_phase_info[id_]['start_lane']
            self.end_lane[id_] = self.lane_phase_info[id_]['end_lane']
            self.phase_startLane_mapping[id_] = self.lane_phase_info[id_][
                "phase_startLane_mapping"]

            self.phase_list[id_] = self.lane_phase_info[id_]["phase"]
            self.current_phase[id_] = self.phase_list[id_][0]
            self.current_phase_time[id_] = 0

            self.dic_lane_vehicle_current_step[id_] = {}
            self.dic_lane_waiting_vehicle_count_current_step[id_] = {}
            self.dic_vehicle_arrive_leave_time[id_] = {}
            self.empty[id_] = {}
            self.dic_num_id_inter[num_id] = id_
            num_id += 1

        self.reset_count = 0
        self.get_state()  # set self.state_size
        self.num_actions = len(self.phase_list[self.intersection_id[0]])

        self.count = 0
        self.congestion_count = 0
        # self.done = False
        self.congestion = False
        self.iteration_count = []

        self.reset()

    def reset(self):
        print("\033[1;34m=================================\033[0m")
        print("\033[1;34mreset_count: {0}, iteration: {1}\033[0m".format(
            self.reset_count, self.count))
        # self.iteration_count.append(self.count)
        # if self.reset_count >= 102:
        #     df = pd.DataFrame(self.iteration_count)
        #     df.to_csv(os.path.join(self.record_dir, 'iteration_count.csv'))

        if not operator.eq(self.dic_vehicle_arrive_leave_time, self.empty):
            path_to_log = self.record_dir + '/train_results/episode_{0}/'.format(
                self.reset_count)
            if not os.path.exists(path_to_log):
                os.makedirs(path_to_log)
            self.log(path_to_log)
            print("Log is saved !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
            print(path_to_log)

        self.eng.reset()
        self.dic_vehicle_arrive_leave_time = copy.deepcopy(self.empty)
        self.dic_lane_vehicle_current_step = copy.deepcopy(self.empty)
        self.dic_lane_waiting_vehicle_count_current_step = copy.deepcopy(
            self.empty)
        # self.traffic_light_node_dict = self._adjacency_extraction()
        # self.done = False
        self.congestion = False
        self.count = 0
        self.congestion_count = 0
        self.reset_count += 1
        return {
            id_: np.zeros((self.state_size, ))
            for id_ in self.intersection_id
        }

    def step(self, action):
        """
        Advance the simulation by min_action_time inner steps, then compute the new state and reward.
        """
        step_start_time = time.time()
        for i in range(self.min_action_time):
            self._inner_step(action)

        state = self.get_state()
        reward = self.get_raw_reward()

        # check whether congestion has occurred
        self.congestion = self.compute_congestion()
        self.done = {id_: False for id_ in self.intersection_id}
        self.done['__all__'] = False
        # if self.count >= self.num_step:
        #     self.done = {id_: True for id_ in self.intersection_id}
        #     self.done['__all__'] = True
        # if self.count == 3600:
        #     self.reset()
        return state, reward, self.done, {}
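
    # Note: per Ray's MultiAgentEnv convention, `state` and `reward` are dicts
    # keyed by intersection_id, and `done` carries the special '__all__' key
    # that signals episode termination for all agents at once.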

    def _inner_step(self, action):
        self.update_previous_measurements()

        for id_, a in action.items():  # intersection_id, corresponding action
            if self.current_phase[id_] == self.phase_list[id_][a]:
                self.current_phase_time[id_] += 1
            else:
                self.current_phase[id_] = self.phase_list[id_][a]
                self.current_phase_time[id_] = 1
            self.eng.set_tl_phase(
                id_, self.current_phase[id_])  # set phase of traffic light

        self.eng.next_step()
        self.count += 1

        # print(self.count)

        self.system_states = {
            "get_lane_vehicles":
            self.eng.get_lane_vehicles(),
            "get_lane_waiting_vehicle_count":
            self.eng.get_lane_waiting_vehicle_count(),
            "get_vehicle_speed":
            None,
            "get_vehicle_distance":
            None
        }
        for id_ in self.intersection_id:
            self.update_current_measurements_map(id_, self.system_states)

    def update_previous_measurements(self):
        self.dic_lane_vehicle_previous_step = copy.deepcopy(
            self.dic_lane_vehicle_current_step)
        self.dic_lane_waiting_vehicle_count_previous_step = copy.deepcopy(
            self.dic_lane_waiting_vehicle_count_current_step)
        self.dic_vehicle_speed_previous_step = copy.deepcopy(
            self.dic_vehicle_speed_current_step)
        self.dic_vehicle_distance_previous_step = copy.deepcopy(
            self.dic_vehicle_distance_current_step)

    def update_current_measurements_map(self, id_, simulator_state):
        # TODO: needs rework; inspect the data format when debugging
        def _change_lane_vehicle_dic_to_list(dic_lane_vehicle):
            list_lane_vehicle = []
            for value in dic_lane_vehicle.values():
                list_lane_vehicle.extend(value)
            return list_lane_vehicle

        for lane in self.lane_phase_info[id_]['start_lane']:
            self.dic_lane_vehicle_current_step[id_][lane] = simulator_state[
                "get_lane_vehicles"][lane]

            self.dic_lane_waiting_vehicle_count_current_step[id_][lane] = \
                simulator_state["get_lane_waiting_vehicle_count"][
                    lane]

        for lane in self.lane_phase_info[id_]['end_lane']:
            self.dic_lane_vehicle_current_step[id_][lane] = simulator_state[
                "get_lane_vehicles"][lane]
            self.dic_lane_waiting_vehicle_count_current_step[id_][lane] = \
                simulator_state["get_lane_waiting_vehicle_count"][
                    lane]

        self.dic_vehicle_speed_current_step[id_] = simulator_state[
            'get_vehicle_speed']
        self.dic_vehicle_distance_current_step[id_] = simulator_state[
            'get_vehicle_distance']

        # get vehicle list
        self.list_lane_vehicle_current_step[
            id_] = _change_lane_vehicle_dic_to_list(
                self.dic_lane_vehicle_current_step[id_])
        self.list_lane_vehicle_previous_step[
            id_] = _change_lane_vehicle_dic_to_list(
                self.dic_lane_vehicle_previous_step[id_])

        list_vehicle_new_arrive = list(
            set(self.list_lane_vehicle_current_step[id_]) -
            set(self.list_lane_vehicle_previous_step[id_]))
        # if id_ == 'intersection_6_1':
        #     print('list_lane_vehicle_current_step: ' + str(self.list_lane_vehicle_current_step))
        #     print('list_lane_vehicle_previous_step: ' + str(self.list_lane_vehicle_previous_step))

        list_vehicle_new_left = list(
            set(self.list_lane_vehicle_previous_step[id_]) -
            set(self.list_lane_vehicle_current_step[id_]))
        list_vehicle_new_left_entering_lane_by_lane = self._update_leave_entering_approach_vehicle(
            id_)
        list_vehicle_new_left_entering_lane = []
        for l in list_vehicle_new_left_entering_lane_by_lane:
            list_vehicle_new_left_entering_lane += l
        # print('list_vehicle_new_arrive' + str(list_vehicle_new_arrive))
        # print('list_vehicle_new_left_entering_lane' + str(list_vehicle_new_left_entering_lane))

        # update vehicle arrive and left time
        self._update_arrive_time(id_, list_vehicle_new_arrive)
        self._update_left_time(id_, list_vehicle_new_left_entering_lane)

        # update vehicle minimum speed in history, # to be implemented
        # self._update_vehicle_min_speed()

        # update feature
        # self._update_feature_map(simulator_state)

    def compute_congestion(self):
        index = False
        intersection_info = {}
        for id_ in self.intersection_id:
            intersection_info[id_] = self.intersection_info(id_)
        congestion = {id_: False for id_ in self.intersection_id}
        for id_ in self.intersection_id:
            if np.max(
                    list(intersection_info[id_]
                         ["start_lane_waiting_vehicle_count"].values())
            ) > self.threshold:
                congestion[id_] = True
                index = True
        return index

    def get_state(self):
        state = {id_: self._get_state(id_) for id_ in self.intersection_id}
        # self.score.append(self.get_score())
        # if self.reset_count >= 100:
        #     '''
        #     TODO: save fn may be a bug
        #     '''
        # score = pd.DataFrame(self.score)
        # score.to_csv(self.score_file)
        return state

    def _get_state(self, id_):
        state = self.intersection_info(id_)

        if self.Gamma_Reward:
            #### dw ####
            keys = state['end_lane_vehicle_count'].keys()
            start_index = id_.find('_')
            s0 = 'road' + id_[start_index:start_index + 4] + '_0'  # To East
            s1 = 'road' + id_[start_index:start_index + 4] + '_1'  # To North
            s2 = 'road' + id_[start_index:start_index + 4] + '_2'  # To West
            s3 = 'road' + id_[start_index:start_index + 4] + '_3'  # To South

            num_w_e = 0
            num_e_w = 0
            num_n_s = 0
            num_s_n = 0

            for i in keys:
                if i.startswith(s0):
                    num_w_e += state['end_lane_vehicle_count'][i]
                elif i.startswith(s1):
                    num_n_s += state['end_lane_vehicle_count'][i]
                elif i.startswith(s2):
                    num_e_w += state['end_lane_vehicle_count'][i]
                elif i.startswith(s3):
                    num_s_n += state['end_lane_vehicle_count'][i]

            end_lane_dict = {
                s0: num_w_e,
                s1: num_n_s,
                s2: num_e_w,
                s3: num_s_n
            }
            end_lane_sorted_keys = sorted(end_lane_dict.keys())

            state_dict_waiting = state['start_lane_waiting_vehicle_count']
            state_dict = state['start_lane_vehicle_count']

            # 12-dim start-lane vehicle counts + 12-dim start-lane waiting-vehicle
            # counts + current phase + 4-dim aggregated end-lane vehicle counts
            return_state = [state_dict[key] for key in self.road_sort[id_]] + [
                state_dict_waiting[key] for key in self.road_sort[id_]
            ] + [state['current_phase']
                 ] + [end_lane_dict[key] for key in end_lane_sorted_keys]
            # return_state = [state_dict[key] for key in sorted_keys] + [state['current_phase']] + \
            #                [0, 0, 0, 0]
        else:
            state_dict = state['start_lane_waiting_vehicle_count']
            return_state = [state_dict[key] for key in self.road_sort[id_]] + [0] * 12 + [state['current_phase']] + \
                           [0, 0, 0, 0]

        return self.preprocess_state(return_state)

    def intersection_info(self, id_):
        """
        info of intersection 'id_'
        """
        state = {}

        get_lane_vehicle_count = self.eng.get_lane_vehicle_count()
        get_lane_waiting_vehicle_count = self.eng.get_lane_waiting_vehicle_count(
        )
        # get_lane_vehicles = self.eng.get_lane_vehicles()
        # get_vehicle_speed = self.eng.get_vehicle_speed()

        state['start_lane_vehicle_count'] = {
            lane: get_lane_vehicle_count[lane]
            for lane in self.start_lane[id_]
        }
        state['end_lane_vehicle_count'] = {
            lane: get_lane_vehicle_count[lane]
            for lane in self.end_lane[id_]
        }

        # state['lane_vehicle_count'] = state['start_lane_vehicle_count'].copy()
        # state['lane_vehicle_count'].update(state['end_lane_vehicle_count'].items())
        state['start_lane_waiting_vehicle_count'] = {
            lane: get_lane_waiting_vehicle_count[lane]
            for lane in self.start_lane[id_]
        }
        # state['end_lane_waiting_vehicle_count'] = {lane: get_lane_waiting_vehicle_count[lane] for lane in
        #                                            self.end_lane[id_]}
        #
        # state['start_lane_vehicles'] = {lane: get_lane_vehicles[lane] for lane in self.start_lane[id_]}
        # state['end_lane_vehicles'] = {lane: get_lane_vehicles[lane] for lane in self.end_lane[id_]}
        #
        # state['start_lane_speed'] = {
        #     lane: np.sum(list(map(lambda vehicle: get_vehicle_speed[vehicle], get_lane_vehicles[lane]))) / (
        #             get_lane_vehicle_count[lane] + 1e-5) for lane in
        #     self.start_lane[id_]}  # compute start lane mean speed
        # state['end_lane_speed'] = {
        #     lane: np.sum(list(map(lambda vehicle: get_vehicle_speed[vehicle], get_lane_vehicles[lane]))) / (
        #             get_lane_vehicle_count[lane] + 1e-5) for lane in
        #     self.end_lane[id_]}  # compute end lane mean speed

        state['current_phase'] = self.current_phase[id_]
        # state['current_phase_time'] = self.current_phase_time[id_]

        state['adjacency_matrix'] = self.traffic_light_node_dict[id_][
            'adjacency_row']

        return state

    def preprocess_state(self, state):
        return_state = np.array(state)
        if self.state_size is None:
            self.state_size = len(return_state.flatten())
        return_state = np.reshape(np.array(return_state),
                                  [1, self.state_size]).flatten()
        return return_state

    def get_raw_reward(self):
        reward = {
            id_: self._get_raw_reward(id_)
            for id_ in self.intersection_id
        }
        # mean_global_sum = np.mean(list(reward.values()))

        return reward

    def _get_raw_reward(self, id_):
        # every agent/intersection's reward
        state = self.intersection_info(id_)
        # r = max(list(state['start_lane_vehicle_count'].values()))
        r = max(list(state['start_lane_waiting_vehicle_count'].values()))
        return -r

    # def get_score(self):
    #     score = {id_: self._get_score(id_) for id_ in self.intersection_id}
    #     score = self.dict_Avg(score)
    #     return score

    # def _get_score(self, id_):
    #     state = self.intersection_info(id_)
    #     start_lane_speed = state['start_lane_speed']
    #     end_lane_speed = state['end_lane_speed']
    #     score = (self.dict_Avg(start_lane_speed) + self.dict_Avg(end_lane_speed)) / 2
    #     # score = (1 / (1 + np.exp(-1 * x))) / self.num_step
    #     return score

    def sigmoid(self, x):
        return 1 / (1 + math.exp(-x))

    def dict_Avg(self, Dict):
        Len = len(Dict)  # number of key-value pairs in the dict
        Sum = sum(Dict.values())  # sum of all values in the dict
        Avg = Sum / Len
        return Avg

    def _adjacency_extraction(self):
        traffic_light_node_dict = {}
        file = self.roadnetFile
        with open('{0}'.format(file)) as json_data:
            net = json.load(json_data)
            # print(net)

            # build the info dict for intersections
            for inter in net['intersections']:
                if not inter['virtual']:
                    traffic_light_node_dict[inter['id']] = {
                        'location': {
                            'x': float(inter['point']['x']),
                            'y': float(inter['point']['y'])
                        },
                        "total_inter_num": None,
                        'adjacency_row': None,
                        "inter_id_to_index": None,
                        "neighbor_ENWS": None,
                        "entering_lane_ENWS": None,
                    }

            top_k = self.dic_traffic_env_conf["TOP_K_ADJACENCY"]
            total_inter_num = len(traffic_light_node_dict.keys())
            inter_id_to_index = {}

            edge_id_dict = {}
            for road in net['roads']:
                if road['id'] not in edge_id_dict.keys():
                    edge_id_dict[road['id']] = {}
                edge_id_dict[road['id']]['from'] = road['startIntersection']
                edge_id_dict[road['id']]['to'] = road['endIntersection']
                edge_id_dict[road['id']]['num_of_lane'] = len(road['lanes'])
                edge_id_dict[road['id']]['length'] = np.sqrt(
                    np.square(pd.DataFrame(road['points'])).sum(axis=1)).sum()

            index = 0
            for i in traffic_light_node_dict.keys():
                inter_id_to_index[i] = index
                index += 1

            for i in traffic_light_node_dict.keys():
                traffic_light_node_dict[i][
                    'inter_id_to_index'] = inter_id_to_index
                traffic_light_node_dict[i]['neighbor_ENWS'] = []
                traffic_light_node_dict[i]['entering_lane_ENWS'] = {
                    "lane_ids": [],
                    "lane_length": []
                }
                for j in range(4):
                    road_id = i.replace("intersection", "road") + "_" + str(j)
                    neighboring_node = edge_id_dict[road_id]['to']
                    # calculate the neighboring intersections
                    if neighboring_node not in traffic_light_node_dict.keys(
                    ):  # virtual node
                        traffic_light_node_dict[i]['neighbor_ENWS'].append(
                            None)
                    else:
                        traffic_light_node_dict[i]['neighbor_ENWS'].append(
                            neighboring_node)
                    # calculate the entering lanes ENWS
                    for key, value in edge_id_dict.items():
                        if value['from'] == neighboring_node and value[
                                'to'] == i:
                            neighboring_road = key

                            neighboring_lanes = []
                            for k in range(value['num_of_lane']):
                                neighboring_lanes.append(neighboring_road +
                                                         "_{0}".format(k))

                            traffic_light_node_dict[i]['entering_lane_ENWS'][
                                'lane_ids'].append(neighboring_lanes)
                            traffic_light_node_dict[i]['entering_lane_ENWS'][
                                'lane_length'].append(value['length'])

            for i in traffic_light_node_dict.keys():
                location_1 = traffic_light_node_dict[i]['location']

                # TODO return with Top K results
                if not self.dic_traffic_env_conf[
                        'ADJACENCY_BY_CONNECTION_OR_GEO']:  # use geo-distance
                    row = np.array([0] * total_inter_num)
                    # row = np.zeros((self.dic_traffic_env_conf["NUM_ROW"],self.dic_traffic_env_conf["NUM_col"]))
                    for j in traffic_light_node_dict.keys():
                        location_2 = traffic_light_node_dict[j]['location']
                        dist = self._cal_distance(location_1, location_2)
                        row[inter_id_to_index[j]] = dist
                    if len(row) == top_k:
                        adjacency_row_unsorted = np.argpartition(
                            row, -1)[:top_k].tolist()
                    elif len(row) > top_k:
                        adjacency_row_unsorted = np.argpartition(
                            row, top_k)[:top_k].tolist()
                    else:
                        adjacency_row_unsorted = [
                            k for k in range(total_inter_num)
                        ]
                    adjacency_row_unsorted.remove(inter_id_to_index[i])
                    traffic_light_node_dict[i]['adjacency_row'] = [
                        inter_id_to_index[i]
                    ] + adjacency_row_unsorted
                else:  # use connection information
                    traffic_light_node_dict[i]['adjacency_row'] = [
                        inter_id_to_index[i]
                    ]
                    for j in traffic_light_node_dict[i][
                            'neighbor_ENWS']:  ## TODO
                        if j is not None:
                            traffic_light_node_dict[i]['adjacency_row'].append(
                                inter_id_to_index[j])
                        else:
                            traffic_light_node_dict[i]['adjacency_row'].append(
                                -1)

                traffic_light_node_dict[i]['total_inter_num'] = total_inter_num

            path_to_save = os.path.join(self.record_dir,
                                        'traffic_light_node_dict.conf')
            with open(path_to_save, 'w') as f:
                f.write(str(traffic_light_node_dict))
                print("\033[1;33mSaved traffic_light_node_dict\033[0m")

        return traffic_light_node_dict

    def _update_leave_entering_approach_vehicle(self, id_):

        list_entering_lane_vehicle_left = []

        # update vehicles leaving entering lane
        if not self.dic_lane_vehicle_previous_step[id_]:
            for lane in self.lane_phase_info[id_]['start_lane']:
                list_entering_lane_vehicle_left.append([])
        else:
            last_step_vehicle_id_list = []
            current_step_vehicle_id_list = []
            for lane in self.lane_phase_info[id_]['start_lane']:
                last_step_vehicle_id_list.extend(
                    self.dic_lane_vehicle_previous_step[id_][lane])
                current_step_vehicle_id_list.extend(
                    self.dic_lane_vehicle_current_step[id_][lane])

            list_entering_lane_vehicle_left.append(
                list(
                    set(last_step_vehicle_id_list) -
                    set(current_step_vehicle_id_list)))

        return list_entering_lane_vehicle_left

    def _update_arrive_time(self, id_, list_vehicle_arrive):

        ts = self.get_current_time()
        # get dic vehicle enter leave time
        for vehicle in list_vehicle_arrive:
            if vehicle not in self.dic_vehicle_arrive_leave_time[id_]:
                self.dic_vehicle_arrive_leave_time[id_][vehicle] = \
                    {"enter_time": ts, "leave_time": np.nan}
            else:
                # print("vehicle: %s already exists in entering lane!"%vehicle)
                # sys.exit(-1)
                pass

    def _update_left_time(self, id_, list_vehicle_left):

        ts = self.get_current_time()
        # update the time for vehicle to leave entering lane
        for vehicle in list_vehicle_left:
            try:
                self.dic_vehicle_arrive_leave_time[id_][vehicle][
                    "leave_time"] = ts
            except KeyError:
                print("vehicle not recorded when entering")
                sys.exit(-1)

    @staticmethod
    def _cal_distance(loc_dict1, loc_dict2):
        a = np.array((loc_dict1['x'], loc_dict1['y']))
        b = np.array((loc_dict2['x'], loc_dict2['y']))
        return np.sqrt(np.sum((a - b)**2))

    def get_current_time(self):
        return self.eng.get_current_time()

    def log(self, path_to_log):
        for id_ in self.intersection_id:
            # print("log for ", id_)
            path_to_log_file = os.path.join(path_to_log,
                                            "vehicle_{0}.csv".format(id_))
            dic_vehicle = self.get_dic_vehicle_arrive_leave_time(id_)
            df = pd.DataFrame.from_dict(dic_vehicle, orient='index')
            df.to_csv(path_to_log_file, na_rep="nan")

            # path_to_log_file = os.path.join(self.path_to_log, "inter_{0}.pkl".format(inter_ind))
            # f = open(path_to_log_file, "wb")
            #
            # # Use pickle to pack data flow into
            # pickle.dump(self.list_inter_log[inter_ind], f)
            # f.close()

    def get_dic_vehicle_arrive_leave_time(self, id_):
        return self.dic_vehicle_arrive_leave_time[id_]
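The geo-distance branch above leans on np.argpartition for an O(n) top-k selection; a minimal standalone sketch of that pattern, using hypothetical distances:

import numpy as np

# Hypothetical distances from one intersection to five others.
row = np.array([4.0, 0.0, 2.5, 9.1, 1.2])
top_k = 3

# argpartition moves the k smallest entries into the first k slots
# (in arbitrary order) without fully sorting the array.
nearest = np.argpartition(row, top_k)[:top_k].tolist()
print(nearest)  # e.g. [1, 4, 2] -- indices of the three closest, unordered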
Code example #20
File: test_rollout.py Project: yiranwang52/ray
def learn_test_multi_agent_plus_rollout(algo):
    for fw in framework_iterator(frameworks=("tf", "torch")):
        tmp_dir = os.popen("mktemp -d").read()[:-1]
        if not os.path.exists(tmp_dir):
            # Last resort: resolve via the underlying tempdir (strip the leading "/tmp").
            tmp_dir = ray._private.utils.tempfile.gettempdir() + tmp_dir[4:]
            if not os.path.exists(tmp_dir):
                sys.exit(1)

        print("Saving results to {}".format(tmp_dir))

        rllib_dir = str(Path(__file__).parent.parent.absolute())
        print("RLlib dir = {}\nexists={}".format(rllib_dir,
                                                 os.path.exists(rllib_dir)))

        def policy_fn(agent_id, episode, **kwargs):
            return "pol{}".format(agent_id)

        observation_space = Box(float("-inf"), float("inf"), (4, ))
        action_space = Discrete(2)

        config = {
            "num_gpus": 0,
            "num_workers": 1,
            "evaluation_config": {
                "explore": False
            },
            "framework": fw,
            "env": MultiAgentCartPole,
            "multiagent": {
                "policies": {
                    "pol0": (None, observation_space, action_space, {}),
                    "pol1": (None, observation_space, action_space, {}),
                },
                "policy_mapping_fn": policy_fn,
            },
        }
        stop = {"episode_reward_mean": 150.0}
        tune.run(
            algo,
            config=config,
            stop=stop,
            checkpoint_freq=1,
            checkpoint_at_end=True,
            local_dir=tmp_dir,
            verbose=1)

        # Find last checkpoint and use that for the rollout.
        checkpoint_path = os.popen("ls {}/PPO/*/checkpoint_*/"
                                   "checkpoint-*".format(tmp_dir)).read()[:-1]
        checkpoint_paths = checkpoint_path.split("\n")
        assert len(checkpoint_paths) > 0
        checkpoints = [
            cp for cp in checkpoint_paths
            if re.match(r"^.+checkpoint-\d+$", cp)
        ]
        # Sort by number and pick last (which should be the best checkpoint).
        last_checkpoint = sorted(
            checkpoints,
            key=lambda x: int(re.match(r".+checkpoint-(\d+)", x).group(1)))[-1]
        assert re.match(r"^.+checkpoint_\d+/checkpoint-\d+$", last_checkpoint)
        if not os.path.exists(last_checkpoint):
            sys.exit(1)
        print("Best checkpoint={} (exists)".format(last_checkpoint))

        ray.shutdown()

        # Test rolling out n steps.
        result = os.popen(
            "python {}/rollout.py --run={} "
            "--steps=400 "
            "--out=\"{}/rollouts_n_steps.pkl\" --no-render \"{}\"".format(
                rllib_dir, algo, tmp_dir, last_checkpoint)).read()[:-1]
        if not os.path.exists(tmp_dir + "/rollouts_n_steps.pkl"):
            sys.exit(1)
        print("Rollout output exists -> Checking reward ...")
        episodes = result.split("\n")
        mean_reward = 0.0
        num_episodes = 0
        for ep in episodes:
            mo = re.match(r"Episode .+reward: ([\d\.\-]+)", ep)
            if mo:
                mean_reward += float(mo.group(1))
                num_episodes += 1
        mean_reward /= num_episodes
        print("Rollout's mean episode reward={}".format(mean_reward))
        assert mean_reward >= 150.0

        # Cleanup.
        os.popen("rm -rf \"{}\"".format(tmp_dir)).read()
Code example #21
File: test_flatten.py Project: zhuzhenping/gym
        self.observation_space = observation_space

    def reset(self,
              *,
              seed: Optional[int] = None,
              options: Optional[dict] = None):
        super().reset(seed=seed)
        self.observation = self.observation_space.sample()
        return self.observation


OBSERVATION_SPACES = (
    (
        Dict(
            OrderedDict([
                ("key1", Box(shape=(2, 3), low=0, high=0, dtype=np.float32)),
                ("key2", Box(shape=(), low=1, high=1, dtype=np.float32)),
                ("key3", Box(shape=(2, ), low=2, high=2, dtype=np.float32)),
            ])),
        True,
    ),
    (
        Dict(
            OrderedDict([
                ("key2", Box(shape=(), low=0, high=0, dtype=np.float32)),
                ("key3", Box(shape=(2, ), low=1, high=1, dtype=np.float32)),
                ("key1", Box(shape=(2, 3), low=2, high=2, dtype=np.float32)),
            ])),
        True,
    ),
    (
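Each entry above pairs a Dict space with an expected boolean; a minimal sketch of the flatten round trip such tests exercise, assuming gym.spaces.utils is available:

import numpy as np
from gym.spaces import Box, Dict
from gym.spaces.utils import flatdim, flatten, unflatten

space = Dict({
    "a": Box(low=0, high=1, shape=(2,), dtype=np.float32),
    "b": Box(low=0, high=1, shape=(3,), dtype=np.float32),
})
sample = space.sample()
flat = flatten(space, sample)           # 1-D float vector
assert flat.shape == (flatdim(space),)  # 2 + 3 = 5
restored = unflatten(space, flat)       # back to the Dict layout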
Code example #22
File: envs.py Project: thejose5/rl-project
    def __init__(self, env, width, height):
        assert isinstance(env.observation_space, Box)
        super(ResizeWrapper, self).__init__(env)
        self.width = width
        self.height = height
        self.observation_space = Box(
            0, 255, (height, width) + env.observation_space.low.shape[2:])
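The wrapper above only redeclares the space; a plausible companion observation hook, sketched under the assumption that ResizeWrapper extends gym.ObservationWrapper and OpenCV is available:

import cv2

def observation(self, obs):
    # cv2.resize takes (width, height); the space above is (height, width, C).
    return cv2.resize(obs, (self.width, self.height))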
Code example #23
    def __init__(
            self,
            obj_low=None,
            obj_high=None,
            random_init=False,
            tasks=[{
                'goal': np.array([0.1, 0.8, 0.2]),
                'obj_init_pos': np.array([0, 0.6, 0.02]),
                'obj_init_angle': 0.3
            }],
            goal_low=None,
            goal_high=None,
            hand_init_pos=(0, 0.6, 0.2),
            liftThresh=0.04,
            rewMode='orig',
            rotMode='rotz',  #'fixed',
            **kwargs):

        hand_low = (-0.5, 0.40, 0.05)
        hand_high = (0.5, 1, 0.5)
        obj_low = (-0.5, 0.40, 0.05)
        obj_high = (0.5, 1, 0.5)
        SawyerXYZEnv.__init__(self,
                              frame_skip=5,
                              action_scale=1. / 100,
                              hand_low=hand_low,
                              hand_high=hand_high,
                              model_name=self.model_name,
                              **kwargs)
        if obj_low is None:
            obj_low = self.hand_low

        if goal_low is None:
            goal_low = self.hand_low

        if obj_high is None:
            obj_high = self.hand_high

        if goal_high is None:
            goal_high = self.hand_high

        self.random_init = random_init
        self.liftThresh = liftThresh
        self.max_path_length = 200  #150
        self.tasks = tasks
        self.num_tasks = len(tasks)
        self.rewMode = rewMode
        self.rotMode = rotMode
        self.hand_init_pos = np.array(hand_init_pos)
        if rotMode == 'fixed':
            self.action_space = Box(
                np.array([-1, -1, -1, -1]),
                np.array([1, 1, 1, 1]),
            )
        elif rotMode == 'rotz':
            self.action_rot_scale = 1. / 50
            self.action_space = Box(
                np.array([-1, -1, -1, -np.pi, -1]),
                np.array([1, 1, 1, np.pi, 1]),
            )
        elif rotMode == 'quat':
            self.action_space = Box(
                np.array([-1, -1, -1, 0, -1, -1, -1, -1]),
                np.array([1, 1, 1, 2 * np.pi, 1, 1, 1, 1]),
            )
        else:
            self.action_space = Box(
                np.array([-1, -1, -1, -np.pi / 2, -np.pi / 2, 0, -1]),
                np.array([1, 1, 1, np.pi / 2, np.pi / 2, np.pi * 2, 1]),
            )
        self.hand_and_obj_space = Box(
            np.hstack((self.hand_low, obj_low)),
            np.hstack((self.hand_high, obj_high)),
        )
        self.goal_space = Box(goal_low, goal_high)
        self.observation_space = Box(
            np.hstack((self.hand_low, obj_low, obj_low)),
            np.hstack((self.hand_high, obj_high, obj_high)),
        )
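Reading the branches back: the action vector is 4-D for 'fixed' (xyz delta plus gripper), 5-D for 'rotz' (one extra z-rotation component), 8-D for 'quat', and 7-D in the fall-through case. As a quick reference (the 'euler' key is a hypothetical label for that fall-through branch):

expected_action_dims = {'fixed': 4, 'rotz': 5, 'quat': 8, 'euler': 7}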
Code example #24
File: env_config.py Project: xuyanbo03/ModelArts-Lab
import numpy as np
from gym.spaces import Discrete, Box

action_space = Discrete(8)
observation_space = Box(-np.inf, np.inf, shape=(240, 320, 1), dtype=np.float32)
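These two definitions can be exercised directly; a minimal sketch of sampling and containment checks:

a = action_space.sample()        # an int in [0, 8)
o = observation_space.sample()   # float32 array of shape (240, 320, 1)
assert action_space.contains(a)
assert observation_space.contains(o)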
Code example #25
File: sawyer_door.py Project: shikharbahl/multiworld
    def __init__(self,
                 goal_low=None,
                 goal_high=None,
                 action_reward_scale=0,
                 reward_type='angle_difference',
                 indicator_threshold=(.02, .03),
                 fix_goal=False,
                 fixed_goal=(0, .45, .12, -.25),
                 reset_free=False,
                 fixed_hand_z=0.12,
                 hand_low=(-0.25, 0.3, .12),
                 hand_high=(0.25, 0.6, .12),
                 target_pos_scale=1,
                 target_angle_scale=1,
                 min_angle=-1.5708,
                 max_angle=0,
                 xml_path='sawyer_xyz/sawyer_door_pull.xml',
                 **sawyer_xyz_kwargs):
        self.quick_init(locals())
        self.model_name = get_asset_full_path(xml_path)
        SawyerXYZEnv.__init__(self,
                              self.model_name,
                              hand_low=hand_low,
                              hand_high=hand_high,
                              **sawyer_xyz_kwargs)
        MultitaskEnv.__init__(self)

        self.reward_type = reward_type
        self.indicator_threshold = indicator_threshold

        self.fix_goal = fix_goal
        self.fixed_goal = np.array(fixed_goal)
        self._state_goal = None
        self.fixed_hand_z = fixed_hand_z

        self.action_space = Box(np.array([-1, -1]),
                                np.array([1, 1]),
                                dtype=np.float32)
        self.state_space = Box(
            np.concatenate((hand_low, [min_angle])),
            np.concatenate((hand_high, [max_angle])),
            dtype=np.float32,
        )
        if goal_low is None:
            goal_low = self.state_space.low
        if goal_high is None:
            goal_high = self.state_space.high
        self.goal_space = Box(
            np.array(goal_low),
            np.array(goal_high),
            dtype=np.float32,
        )
        self.observation_space = Dict([
            ('observation', self.state_space),
            ('desired_goal', self.goal_space),
            ('achieved_goal', self.state_space),
            ('state_observation', self.state_space),
            ('state_desired_goal', self.goal_space),
            ('state_achieved_goal', self.state_space),
        ])
        self.action_reward_scale = action_reward_scale
        self.target_pos_scale = target_pos_scale
        self.target_angle_scale = target_angle_scale
        self.reset_free = reset_free
        self.door_angle_idx = self.model.get_joint_qpos_addr('doorjoint')
        self.reset()
Code example #26
        ('/rl/on_pedestrian', 'std_msgs.msg.Bool'),
        ('/rl/obs_factor', 'std_msgs.msg.Float32'),
    ],
    defs_action=[('/autoDrive_KeyboardMode', 'std_msgs.msg.Char')],
    rate_action=10.0,
    window_sizes={'obs': 2, 'reward': 3},
    buffer_sizes={'obs': 2, 'reward': 3},
    func_compile_obs=func_compile_obs,
    func_compile_reward=func_compile_reward,
    func_compile_action=func_compile_action,
    step_delay_target=0.5,
    is_dummy_action=False)


# TODO: define these Gym-related params inside DrivingSimulatorEnv
env.observation_space = Box(low=0, high=255, shape=(350, 350, 3))
env.reward_range = (-np.inf, np.inf)
env.metadata = {}
env.action_space = Discrete(len(ACTIONS))
env = FrameStack(env, 3)

n_interactive = 0
n_skip = 1
n_additional_learn = 4
n_ep = 0  # last ep in the last run, if restart use 0
n_test = 10  # num of episode per test run (no exploration)
state_shape = (350, 350, 9)
tf.app.flags.DEFINE_string("logdir",
                           "/home/pirate03/PycharmProjects/hobotrl/playground/initialD/imitaion_learning/fnet_rename_learn_q_no_skip_tmp_debug",
                           """save tmp model""")
tf.app.flags.DEFINE_string("savedir",
Code example #27
File: misc.py Project: xiangxud/h-baselines
def get_manager_ac_space(ob_space, relative_goals, env_name, use_fingerprints,
                         fingerprint_dim):
    """Compute the action space for the Manager.

    If the fingerprint terms are being appended onto the observations, this
    should be removed from the action space.

    Parameters
    ----------
    ob_space : gym.spaces.*
        the observation space of the environment
    relative_goals : bool
        specifies whether the goal issued by the Manager is meant to be a
        relative or absolute goal, i.e. specific state or change in state
    env_name : str
        the name of the environment. Used for special cases to assign the
        Manager action space to only ego observations in the observation space.
    use_fingerprints : bool
        specifies whether to add a time-dependent fingerprint to the
        observations
    fingerprint_dim : tuple of int
        the shape of the fingerprint elements, if they are being used

    Returns
    -------
    gym.spaces.Box
        the action space of the Manager policy
    """
    if env_name in [
            "AntMaze", "AntPush", "AntFall", "AntGather", "AntFourRooms"
    ]:
        manager_ac_space = Box(
            low=np.array([
                -10, -10, -0.5, -1, -1, -1, -1, -0.5, -0.3, -0.5, -0.3, -0.5,
                -0.3, -0.5, -0.3
            ]),
            high=np.array([
                10, 10, 0.5, 1, 1, 1, 1, 0.5, 0.3, 0.5, 0.3, 0.5, 0.3, 0.5, 0.3
            ]),
            dtype=np.float32,
        )
    elif env_name == "UR5":
        manager_ac_space = Box(
            low=np.array([-2 * np.pi, -2 * np.pi, -2 * np.pi, -4, -4, -4]),
            high=np.array([2 * np.pi, 2 * np.pi, 2 * np.pi, 4, 4, 4]),
            dtype=np.float32,
        )
    elif env_name == "Pendulum":
        manager_ac_space = Box(low=np.array([-np.pi, -15]),
                               high=np.array([np.pi, 15]),
                               dtype=np.float32)
    elif env_name in ["ring0", "ring1"]:
        if relative_goals:
            manager_ac_space = Box(-.5, .5, shape=(1, ), dtype=np.float32)
        else:
            manager_ac_space = Box(0, 1, shape=(1, ), dtype=np.float32)
    elif env_name == "figureeight0":
        if relative_goals:
            manager_ac_space = Box(-.5, .5, shape=(1, ), dtype=np.float32)
        else:
            manager_ac_space = Box(0, 1, shape=(1, ), dtype=np.float32)
    elif env_name == "figureeight1":
        if relative_goals:
            manager_ac_space = Box(-.5, .5, shape=(7, ), dtype=np.float32)
        else:
            manager_ac_space = Box(0, 1, shape=(7, ), dtype=np.float32)
    elif env_name == "figureeight2":
        if relative_goals:
            manager_ac_space = Box(-.5, .5, shape=(14, ), dtype=np.float32)
        else:
            manager_ac_space = Box(0, 1, shape=(14, ), dtype=np.float32)
    elif env_name == "merge0":
        if relative_goals:
            manager_ac_space = Box(-.5, .5, shape=(5, ), dtype=np.float32)
        else:
            manager_ac_space = Box(0, 1, shape=(5, ), dtype=np.float32)
    elif env_name == "merge1":
        if relative_goals:
            manager_ac_space = Box(-.5, .5, shape=(13, ), dtype=np.float32)
        else:
            manager_ac_space = Box(0, 1, shape=(13, ), dtype=np.float32)
    elif env_name == "merge2":
        if relative_goals:
            manager_ac_space = Box(-.5, .5, shape=(17, ), dtype=np.float32)
        else:
            manager_ac_space = Box(0, 1, shape=(17, ), dtype=np.float32)
    elif env_name == "PD-Biped3D-HLC-Soccer-v1":
        manager_ac_space = Box(
            low=np.array(
                [0, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -1, -2]),
            high=np.array([1.5, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2]),
            dtype=np.float32)
    else:
        if use_fingerprints:
            low = np.array(ob_space.low)[:-fingerprint_dim[0]]
            high = ob_space.high[:-fingerprint_dim[0]]
            manager_ac_space = Box(low=low, high=high, dtype=np.float32)
        else:
            manager_ac_space = ob_space

    return manager_ac_space
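A minimal usage sketch of the helper above, with a hypothetical observation space (assumes numpy and gym.spaces.Box are in scope):

ob_space = Box(-np.inf, np.inf, shape=(10,), dtype=np.float32)

# "Pendulum" hits an explicit branch: goals over [-pi, pi] x [-15, 15].
ac = get_manager_ac_space(ob_space, relative_goals=False,
                          env_name="Pendulum", use_fingerprints=False,
                          fingerprint_dim=(1,))
assert ac.shape == (2,)

# An unrecognized env_name falls through to the observation space itself.
ac = get_manager_ac_space(ob_space, relative_goals=False,
                          env_name="unknown", use_fingerprints=False,
                          fingerprint_dim=(1,))
assert ac.shape == (10,)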
Code example #28
    def __init__(self, env):
        super().__init__(env)
        self.n = None
        if isinstance(self.env.observation_space, Discrete):
            self.n = self.env.observation_space.n
            self.observation_space = Box(0, 1, (self.n,))
Code example #29
    def __init__(self, random_init=False, task_type='pick_place'):
        liftThresh = 0.04
        goal_low = (-0.1, 0.8, 0.05)
        goal_high = (0.1, 0.9, 0.3)
        hand_low = (-0.5, 0.40, 0.05)
        hand_high = (0.5, 1, 0.5)
        obj_low = (-0.1, 0.6, 0.02)
        obj_high = (0.1, 0.7, 0.02)

        self.task_types = ['pick_place', 'reach', 'push']

        super().__init__(
            self.model_name,
            hand_low=hand_low,
            hand_high=hand_high,
        )

        self.task_type = task_type
        self.init_config = {
            'obj_init_angle': .3,
            'obj_init_pos': np.array([0, 0.6, 0.02]),
            'hand_init_pos': np.array([0, .6, .2]),
        }

        # we only do one task from [pick_place, reach, push]
        # per instance of SawyerReachPushPickPlaceEnv.
        # Please only set task_type from constructor.
        if self.task_type == 'pick_place':
            self.goal = np.array([0.1, 0.8, 0.2])
        elif self.task_type == 'reach':
            self.goal = np.array([-0.1, 0.8, 0.2])
        elif self.task_type == 'push':
            self.goal = np.array([0.1, 0.8, 0.02])
        else:
            raise NotImplementedError

        self.obj_init_angle = self.init_config['obj_init_angle']
        self.obj_init_pos = self.init_config['obj_init_pos']
        self.hand_init_pos = self.init_config['hand_init_pos']

        self.random_init = random_init
        self.liftThresh = liftThresh
        self.max_path_length = 150

        self.action_space = Box(
            np.array([-1, -1, -1, -1]),
            np.array([1, 1, 1, 1]),
        )

        self.obj_and_goal_space = Box(
            np.hstack((obj_low, goal_low)),
            np.hstack((obj_high, goal_high)),
        )
        self.goal_space = Box(np.array(goal_low), np.array(goal_high))

        self.observation_space = Box(
            np.hstack((
                self.hand_low,
                obj_low,
            )),
            np.hstack((
                self.hand_high,
                obj_high,
            )),
        )

        self.num_resets = 0
        self.reset()
Code example #30
class TestC51Network(TestBaseNetwork):
    __test__ = True

    network = C51Network

    list_work = [[Discrete(3), Discrete(1)], [Discrete(3),
                                              Discrete(3)],
                 [Discrete(10), Discrete(50)],
                 [MultiDiscrete([3]), MultiDiscrete([1])],
                 [MultiDiscrete([3, 3]),
                  MultiDiscrete([3, 3])],
                 [MultiDiscrete([4, 4, 4]),
                  MultiDiscrete([50, 4, 4])],
                 [
                     MultiDiscrete([[100, 3], [3, 5]]),
                     MultiDiscrete([[100, 3], [3, 5]])
                 ],
                 [
                     MultiDiscrete([[[100, 3], [3, 5]], [[100, 3], [3, 5]]]),
                     MultiDiscrete([[[100, 3], [3, 5]], [[100, 3], [3, 5]]])
                 ]]

    list_fail = [
        [None, None],
        ["dedrfe", "qdzq"],
        [1215.4154, 157.48],
        ["zdzd", (Discrete(1))],
        [Discrete(1), "zdzd"],
        ["zdzd", (1, 4, 7)],
        [(1, 4, 7), "zdzd"],
        [152, 485],
        [MultiBinary(1), MultiBinary(1)],
        [MultiBinary(3), MultiBinary(3)],
        # [MultiBinary([3, 2]), MultiBinary([3, 2])],  # Doesn't work yet because gym hasn't implemented this
        [Box(low=0, high=10, shape=[1]),
         Box(low=0, high=10, shape=[1])],
        [Box(low=0, high=10, shape=[2, 2]),
         Box(low=0, high=10, shape=[2, 2])],
        [
            Box(low=0, high=10, shape=[2, 2, 2]),
            Box(low=0, high=10, shape=[2, 2, 2])
        ],
        [
            Tuple([Discrete(1), MultiDiscrete([1, 1])]),
            Tuple([Discrete(1), MultiDiscrete([1, 1])])
        ],
        [
            Dict({
                "first": Discrete(1),
                "second": MultiDiscrete([1, 1])
            }),
            Dict({
                "first": Discrete(1),
                "second": MultiDiscrete([1, 1])
            })
        ]
    ]

    def test_init(self):
        for ob, ac in self.list_fail:
            with pytest.raises(TypeError):
                self.network(observation_space=ob, action_space=ac)

        for ob, ac in self.list_work:
            self.network(observation_space=ob, action_space=ac)

    def test_forward(self):
        for ob, ac in self.list_work:
            network = self.network(observation_space=ob, action_space=ac)
            network.forward(torch.rand((1, flatdim(ob))))

    def test_str_(self):
        for ob, ac in self.list_work:
            network = self.network(observation_space=ob, action_space=ac)

            assert 'C51Network-' + str(ob) + "-" + str(ac) == network.__str__()
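For reference, the flatdim call in test_forward sizes the network input as the flattened width of the observation space; two stable cases, assuming gym.spaces.utils.flatdim:

from gym.spaces import Box, Discrete
from gym.spaces.utils import flatdim

assert flatdim(Discrete(10)) == 10                     # one-hot width
assert flatdim(Box(low=0, high=1, shape=(2, 2))) == 4  # product of the shape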