Exemplo n.º 1
0
    def __init__(self, obs_space, action_space, config):
        """Assemble the session, input placeholder and policy network.

        Args:
            obs_space: Observation space; stored on the instance and used to
                look up the preprocessor and to build the model.
            action_space: Action space; stored on the instance and used to
                look up the action distribution and to build the model.
            config: Mapping providing "action_noise_std",
                "observation_filter" and "model"; "single_threaded" is
                optional and defaults to False.
        """
        self.observation_space = obs_space
        self.action_space = action_space
        self.action_noise_std = config["action_noise_std"]

        # The preprocessor's shape sizes both the filter and the placeholder.
        self.preprocessor = ModelCatalog.get_preprocessor_for_space(
            self.observation_space)
        filter_name = config["observation_filter"]
        self.observation_filter = get_filter(filter_name,
                                             self.preprocessor.shape)

        self.single_threaded = config.get("single_threaded", False)
        self.sess = make_session(single_threaded=self.single_threaded)

        # Leading None leaves the first dimension of the input unconstrained.
        input_shape = [None] + list(self.preprocessor.shape)
        self.inputs = tf.placeholder(tf.float32, input_shape)

        # Policy network with a deterministic action distribution.
        dist_class, dist_dim = ModelCatalog.get_action_dist(
            self.action_space, config["model"], dist_type="deterministic")
        input_dict = {SampleBatch.CUR_OBS: self.inputs}
        model = ModelCatalog.get_model(
            input_dict, self.observation_space, self.action_space, dist_dim,
            config["model"])
        self.sampler = dist_class(model.outputs, model).sample()

        self.variables = ray.experimental.tf_utils.TensorFlowVariables(
            model.outputs, self.sess)

        # Total parameter count: element counts summed over tracked variables.
        self.num_params = sum(
            np.prod(var.shape.as_list())
            for var in self.variables.variables.values())
        self.sess.run(tf.global_variables_initializer())
Exemplo n.º 2
0
    def __init__(self, obs_space, action_space, config):
        """Set up the policy for either TF graph mode or eager execution.

        With framework "tf" a session and an input placeholder are created
        and the sampling op is built up front; otherwise eager execution is
        enabled and both `self.sess` and `self.inputs` are None.

        Args:
            obs_space: Observation space, forwarded to the parent constructor.
            action_space: Action space, forwarded to the parent constructor.
            config: Mapping forwarded to the parent constructor; also read
                directly here for the optional "seed" and for "framework".
                NOTE(review): `self.config` is presumably populated by the
                parent constructor -- confirm against the base class.
        """
        super().__init__(obs_space, action_space, config)
        self.action_noise_std = self.config["action_noise_std"]
        self.preprocessor = ModelCatalog.get_preprocessor_for_space(
            self.observation_space)
        filter_name = self.config["observation_filter"]
        self.observation_filter = get_filter(filter_name,
                                             self.preprocessor.shape)

        self.single_threaded = self.config.get("single_threaded", False)
        if self.config["framework"] == "tf":
            # Graph mode: dedicated session plus a feedable input placeholder.
            self.sess = make_session(single_threaded=self.single_threaded)

            # Graph-level seed, applied with the new session as default.
            seed = config.get("seed")
            if seed is not None:
                with self.sess.as_default():
                    tf1.set_random_seed(seed)

            # Leading None leaves the first dimension unconstrained.
            self.inputs = tf1.placeholder(
                tf.float32, [None, *self.preprocessor.shape])
        else:
            # Eager mode: no session and no placeholder.
            if not tf1.executing_eagerly():
                tf1.enable_eager_execution()
            self.sess = None
            self.inputs = None

            seed = config.get("seed")
            if seed is not None:
                framework = config.get("framework")
                if framework == "tf2":
                    # Tf2.x.
                    tf.random.set_seed(seed)
                elif tf1 and framework == "tfe":
                    # Tf-eager.
                    tf1.set_random_seed(seed)

        # Policy network with a deterministic action distribution.
        self.dist_class, dist_dim = ModelCatalog.get_action_dist(
            self.action_space, self.config["model"], dist_type="deterministic")

        model_kwargs = dict(
            obs_space=self.preprocessor.observation_space,
            action_space=self.action_space,
            num_outputs=dist_dim,
            model_config=self.config["model"],
        )
        self.model = ModelCatalog.get_model_v2(**model_kwargs)

        self.sampler = None
        if not self.sess:
            # Eager: track the model's variables directly, without a session.
            self.variables = ray.experimental.tf_utils.TensorFlowVariables(
                [], None, self.model.variables())
        else:
            # Graph: build the sampling op once, then initialize variables.
            dist_inputs, _ = self.model({SampleBatch.CUR_OBS: self.inputs})
            self.sampler = self.dist_class(dist_inputs, self.model).sample()
            self.variables = ray.experimental.tf_utils.TensorFlowVariables(
                dist_inputs, self.sess)
            self.sess.run(tf1.global_variables_initializer())

        # Total parameter count: element counts summed over tracked variables.
        self.num_params = sum(
            np.prod(var.shape.as_list())
            for var in self.variables.variables.values())