Example #1
0
    def __init__(self, inputs, model, *, child_distributions, input_lens,
                 action_space):
        """Builds one child distribution per column-slice of `inputs`.

        Args:
            inputs (torch.Tensor): A single tensor of shape [BATCH, size].
                Anything that is not already a torch.Tensor is converted
                with `torch.Tensor(inputs)` first.
            model (ModelV2): The ModelV2 object used to produce inputs for
                this distribution. It is also passed to every child
                distribution that gets instantiated here.
            child_distributions (any[type]): A flat list or a (possibly
                nested) struct holding the child distribution classes.
                Each class is called as `cls(input_slice, model)`.
            input_lens (any[int]): A flat list or a nested struct of split
                lengths; flattened and used to split `inputs` along dim 1.
            action_space (Union[gym.spaces.Dict,gym.spaces.Tuple]): The
                complex and possibly nested action space.
        """
        if isinstance(inputs, torch.Tensor):
            input_tensor = inputs
        else:
            input_tensor = torch.Tensor(inputs)
        super().__init__(input_tensor, model)

        self.action_space_struct = get_base_struct_from_space(action_space)

        # Both structs are flattened so they line up one-to-one with the
        # slices produced by torch.split.
        split_sizes = tree.flatten(input_lens)
        dist_classes = tree.flatten(child_distributions)
        slices = list(torch.split(input_tensor, split_sizes, dim=1))

        def _make_child(dist_cls, input_slice):
            return dist_cls(input_slice, model)

        self.flat_child_distributions = tree.map_structure(
            _make_child, dist_classes, slices)
Example #2
0
    def __init__(self, obs_space, action_space, config):
        """Creates the TF session, the policy network and its sampler.

        Args:
            obs_space: Observation space; used to look up the preprocessor.
            action_space: Action space; used to pick the action
                distribution and to build the model.
            config (dict): Must provide the keys "action_noise_std",
                "observation_filter" and "model". The key
                "single_threaded" is optional and defaults to False.
        """
        self.observation_space = obs_space
        self.action_space = action_space
        self.action_space_struct = get_base_struct_from_space(action_space)
        self.action_noise_std = config["action_noise_std"]

        preprocessor = ModelCatalog.get_preprocessor_for_space(obs_space)
        self.preprocessor = preprocessor
        self.observation_filter = get_filter(
            config["observation_filter"], preprocessor.shape)

        self.single_threaded = config.get("single_threaded", False)
        self.sess = make_session(single_threaded=self.single_threaded)

        # Placeholder with a leading dimension of unspecified size followed
        # by the preprocessor's shape.
        input_shape = [None, *preprocessor.shape]
        self.inputs = tf.placeholder(tf.float32, input_shape)

        # Policy network.
        model_config = config["model"]
        dist_class, dist_dim = ModelCatalog.get_action_dist(
            self.action_space, model_config, dist_type="deterministic")
        self.model = ModelCatalog.get_model_v2(
            obs_space=preprocessor.observation_space,
            action_space=action_space,
            num_outputs=dist_dim,
            model_config=model_config)
        dist_inputs, _ = self.model({SampleBatch.CUR_OBS: self.inputs})
        self.sampler = dist_class(dist_inputs, self.model).sample()

        self.variables = ray.experimental.tf_utils.TensorFlowVariables(
            dist_inputs, self.sess)

        # Total number of scalar entries over all tracked variables.
        self.num_params = sum(
            np.prod(var.shape.as_list())
            for var in self.variables.variables.values())
        self.sess.run(tf.global_variables_initializer())
    def __init__(self, inputs, model, *, child_distributions, input_lens,
                 action_space):
        """Initializes the distribution by splitting `inputs` among children.

        Args:
            inputs: The tensor that is split along axis 1; each slice feeds
                one child distribution.
            model: Passed to `ActionDistribution.__init__` and to every
                child distribution constructor.
            child_distributions: A flat list or a (possibly nested) struct
                of child distribution classes. Each one is called as
                `dist(input_slice, model)`.
            input_lens: A flat list or a (possibly nested) struct of ints
                giving the size of each slice of `inputs`.
            action_space: The space handed to `get_base_struct_from_space`
                to derive `self.action_space_struct`.
        """
        ActionDistribution.__init__(self, inputs, model)

        self.action_space_struct = get_base_struct_from_space(action_space)

        # Flatten both structs so that nested arguments line up one-to-one
        # with the flat list returned by tf.split. Already-flat lists pass
        # through unchanged. reshape(-1) keeps the lengths one-dimensional
        # when `input_lens` is a single ndarray (one leaf for tree.flatten).
        flat_input_lens = np.array(
            tree.flatten(input_lens), dtype=np.int32).reshape(-1)
        flat_child_distributions = tree.flatten(child_distributions)
        split_inputs = tf.split(inputs, flat_input_lens, axis=1)
        self.flat_child_distributions = tree.map_structure(
            lambda dist, input_: dist(input_, model),
            flat_child_distributions, split_inputs)
Example #4
0
    def __init__(self, action_space, *, model, framework, **kwargs):
        """Sets up the exploration and caches the action space's base struct.

        Args:
            action_space (Space): The gym action space used by the environment.
            model: Forwarded unchanged to the parent constructor.
            framework (Optional[str]): One of None, "tf", "torch".
            **kwargs: Further keyword arguments for the parent constructor.
        """
        super().__init__(
            action_space=action_space,
            model=model,
            framework=framework,
            **kwargs)

        # The space is read back from `self` rather than from the argument;
        # presumably the parent constructor stores it there.
        space = self.action_space
        self.action_space_struct = get_base_struct_from_space(space)
Example #5
0
    def __init__(self, observation_space, action_space, config):
        """Stores the spaces and config and resets the bookkeeping fields.

        This is the standard constructor signature for policies: the policy
        class passed into RolloutWorker is constructed with exactly these
        arguments.

        Args:
            observation_space (gym.Space): Observation space of the policy.
            action_space (gym.Space): Action space of the policy.
            config (dict): Policy-specific configuration data.
        """
        self.observation_space = observation_space
        self.action_space = action_space

        # Base struct derived from the action space.
        base_struct = get_base_struct_from_space(action_space)
        self.action_space_struct = base_struct

        self.config = config

        # Global timestep; starts at zero and is broadcast down from the
        # driver from time to time.
        self.global_timestep = 0

        # Action distribution class used for action sampling, if any.
        # Left as None here; child classes may set it.
        self.dist_class = None