Example 1
    def get_preprocessor(cls, env_name, obs_shape, options=dict()):
        """Returns a suitable processor for the given environment.

        Args:
            env_name (str): The name of the environment.
            obs_shape (tuple): The shape of the env observation space.
            options (dict): Options to pass to the preprocessor constructor.

        Returns:
            preprocessor (Preprocessor): Preprocessor for the env observations.
        """

        ATARI_OBS_SHAPE = (210, 160, 3)
        ATARI_RAM_OBS_SHAPE = (128, )

        for k in options.keys():
            if k not in MODEL_CONFIGS:
                raise Exception("Unknown config key `{}`, all keys: {}".format(
                    k, MODEL_CONFIGS))

        if env_name in cls._registered_preprocessor:
            return cls._registered_preprocessor[env_name](options)

        if obs_shape == ATARI_OBS_SHAPE:
            print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
            return AtariPixelPreprocessor(options)
        elif obs_shape == ATARI_RAM_OBS_SHAPE:
            print("Assuming Atari ram env, using AtariRamPreprocessor.")
            return AtariRamPreprocessor(options)

        print("Non-atari env, not using any observation preprocessor.")
        return NoPreprocessor(options)
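
For illustration, a hedged usage sketch of the classmethod form above. Only the get_preprocessor(env_name, obs_shape, options) signature comes from the example; the import path, the environment id, and the transform() call are assumptions about a typical (older) Ray RLlib setup, not confirmed by the source:

    import gym
    # Assumed import path for the catalog; adjust to the Ray version in use.
    from ray.rllib.models import ModelCatalog

    # Illustrative Atari pixel env with observation shape (210, 160, 3).
    env = gym.make("Pong-v0")
    preprocessor = ModelCatalog.get_preprocessor(
        "Pong-v0", env.observation_space.shape, options={})
    obs = env.reset()
    # Preprocessors are assumed to expose transform(), which maps one raw
    # observation to the shaped array fed to the model.
    print(preprocessor.transform(obs).shape)
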
Example 2
    def get_preprocessor(env_name, obs_shape):
        """Returns a suitable processor for the given environment.

        Args:
            env_name (str): The name of the environment.
            obs_shape (tuple): The shape of the env observation space.

        Returns:
            preprocessor (Preprocessor): Preprocessor for the env observations.
        """

        ATARI_OBS_SHAPE = (210, 160, 3)
        ATARI_RAM_OBS_SHAPE = (128,)

        if obs_shape == ATARI_OBS_SHAPE:
            print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
            return AtariPixelPreprocessor()
        elif obs_shape == ATARI_RAM_OBS_SHAPE:
            print("Assuming Atari ram env, using AtariRamPreprocessor.")
            return AtariRamPreprocessor()

        print("Non-atari env, not using any observation preprocessor.")
        return NoPreprocessor()
Example 3
    def __init__(self, env_id, config, logdir, is_remote):
        self.is_remote = is_remote
        if is_remote:
            os.environ["CUDA_VISIBLE_DEVICES"] = ""
            devices = ["/cpu:0"]
        else:
            devices = config["devices"]
        self.devices = devices
        self.config = config
        self.logdir = logdir
        # self.env = ModelCatalog.get_preprocessor_as_wrapper(
        #     registry, env_creator(config["env_config"]), config["model"])
        env = gym.make(env_id)
        preprocessor = AtariPixelPreprocessor(
            env.observation_space, config["model"])
        self.env = _RLlibPreprocessorWrapper(env, preprocessor)
        if is_remote:
            config_proto = tf.ConfigProto()
        else:
            config_proto = tf.ConfigProto(**config["tf_session_args"])
        self.sess = tf.Session(config=config_proto)
        if config["tf_debug_inf_or_nan"] and not is_remote:
            self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
            self.sess.add_tensor_filter(
                "has_inf_or_nan", tf_debug.has_inf_or_nan)

        # Defines the training inputs:
        # The coefficient of the KL penalty.
        self.kl_coeff = tf.placeholder(
            name="newkl", shape=(), dtype=tf.float32)

        # The coefficient of the KL penalty for the exploration policy.
        self.e_kl_coeff = tf.placeholder(
            name="e_newkl", shape=(), dtype=tf.float32)

        # The input observations.
        self.observations = tf.placeholder(
            tf.float32, shape=(None,) + self.env.observation_space.shape)
        # Targets of the value function.
        self.value_targets = tf.placeholder(tf.float32, shape=(None,))
        # Advantage values in the policy gradient estimator.
        self.advantages = tf.placeholder(tf.float32, shape=(None,))

        # Exploration-policy counterparts: value targets and advantages.
        self.e_value_targets = tf.placeholder(tf.float32, shape=(None,))
        self.e_advantages = tf.placeholder(tf.float32, shape=(None,))

        action_space = self.env.action_space
        self.actions = ModelCatalog.get_action_placeholder(action_space)

        self.e_actions = ModelCatalog.get_action_placeholder(action_space)
        self.distribution_class, self.logit_dim = ModelCatalog.get_action_dist(
            action_space)
        # Log probabilities from the policy before the policy update.
        self.prev_logits = tf.placeholder(
            tf.float32, shape=(None, self.logit_dim))
        # Value function predictions before the policy update.
        self.prev_vf_preds = tf.placeholder(tf.float32, shape=(None,))

        # Exploration-policy counterparts: pre-update logits and value predictions.
        self.e_prev_logits = tf.placeholder(
            tf.float32, shape=(None, self.logit_dim))
        self.e_prev_vf_preds = tf.placeholder(tf.float32, shape=(None,))

        if is_remote:
            self.batch_size = config["rollout_batchsize"]
            self.per_device_batch_size = config["rollout_batchsize"]
        else:
            self.batch_size = int(
                config["sgd_batchsize"] / len(devices)) * len(devices)
            assert self.batch_size % len(devices) == 0
            self.per_device_batch_size = int(self.batch_size / len(devices))

        def build_loss(obs, vtargets, advs, acts, plog, pvf_preds,
                       e_vtargets, e_advs, e_plog, e_pvf_preds):
            return ProximalPolicyLoss(
                self.env.observation_space, self.env.action_space,
                obs, vtargets, advs, acts, plog, pvf_preds,
                e_vtargets, e_advs, e_plog, e_pvf_preds,
                self.logit_dim,
                self.kl_coeff, self.distribution_class, self.config,
                self.sess)

        self.par_opt = LocalSyncParallelOptimizer(
            tf.train.AdamOptimizer(self.config["sgd_stepsize"]),
            self.devices,
            [self.observations, self.value_targets, self.advantages,
             self.actions, self.prev_logits, self.prev_vf_preds,
             self.e_value_targets, self.e_advantages,
             self.e_prev_logits, self.e_prev_vf_preds],
            self.per_device_batch_size,
            build_loss,
            self.logdir)

        # References to the model weights
        self.common_policy = self.par_opt.get_common_loss()
        self.variables = ray.experimental.TensorFlowVariables(
            self.common_policy.loss, self.sess)
        self.obs_filter = get_filter(
            config["observation_filter"], self.env.observation_space.shape)
        self.rew_filter = MeanStdFilter((), clip=5.0)
        self.e_rew_filter = MeanStdFilter((), clip=5.0)
        self.filters = {"obs_filter": self.obs_filter,
                        "rew_filter": self.rew_filter,
                        "e_rew_filter": self.e_rew_filter}
        self.sampler = SyncSampler(
            self.env, self.common_policy, self.obs_filter,
            self.config["horizon"], self.config["horizon"])
        self.init_op = tf.global_variables_initializer()
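
To make the batch-splitting arithmetic in the non-remote branch concrete, here is a small standalone sketch; the numbers are illustrative and not taken from any real config:

    def split_batch(sgd_batchsize, num_devices):
        # Round the SGD batch down to a multiple of the device count, as in
        # the non-remote branch above, then shard it evenly across devices.
        batch_size = (sgd_batchsize // num_devices) * num_devices
        assert batch_size % num_devices == 0
        return batch_size, batch_size // num_devices

    # (126, 42): two samples are dropped so 128 divides evenly over 3 devices.
    print(split_batch(128, 3))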