def __init__(self,
             observation_spec,
             action_spec,
             tf_env,
             tf_context,
             step_cond_fn=cond_fn.env_transition,
             reset_episode_cond_fn=cond_fn.env_restart,
             reset_env_cond_fn=cond_fn.false_fn,
             every_n_steps=cond_fn.every_n_steps,
             metrics=None,
             **base_agent_kwargs):
  """Constructs a UVF agent.

  Args:
    observation_spec: A TensorSpec defining the observations.
    action_spec: A BoundedTensorSpec defining the actions.
    tf_env: A TensorFlow environment object.
    tf_context: A Context class.
    step_cond_fn: A function indicating whether to increment the number
      of steps.
    reset_episode_cond_fn: A function indicating whether to restart the
      episode, resampling the context.
    reset_env_cond_fn: A function indicating whether to perform a manual
      reset of the environment.
    every_n_steps: A function used as a periodic step condition; defaults
      to cond_fn.every_n_steps.
    metrics: A list of functions that evaluate metrics of the agent.
    **base_agent_kwargs: A dictionary of parameters for the base RL agent.

  Raises:
    ValueError: If 'dqda_clipping' is < 0.
  """
  self._step_cond_fn = step_cond_fn
  self._reset_episode_cond_fn = reset_episode_cond_fn
  self._reset_env_cond_fn = reset_env_cond_fn
  self._every_n_steps = every_n_steps
  self.metrics = metrics

  # Expose tf_context methods.
  self.tf_context = tf_context(tf_env=tf_env)
  self.set_replay = self.tf_context.set_replay
  self.sample_contexts = self.tf_context.sample_contexts
  self.compute_rewards = self.tf_context.compute_rewards
  self.gamma_index = self.tf_context.gamma_index
  self.context_specs = self.tf_context.context_specs
  self.context_as_action_specs = self.tf_context.context_as_action_specs
  self.init_context_vars = self.tf_context.create_vars

  self.env_observation_spec = observation_spec[0]
  merged_observation_spec = (uvf_utils.merge_specs(
      (self.env_observation_spec,) + self.context_specs),)
  self._context_vars = dict()
  self._action_vars = dict()

  self.BASE_AGENT_CLASS.__init__(
      self,
      observation_spec=merged_observation_spec,
      action_spec=action_spec,
      **base_agent_kwargs)
def __init__(self,
             observation_spec,
             action_spec,
             tf_env,
             tf_context,
             step_cond_fn=cond_fn.env_transition,
             reset_episode_cond_fn=cond_fn.env_restart,
             reset_env_cond_fn=cond_fn.false_fn,
             metrics=None,
             **base_agent_kwargs):
  """Constructs a UVF agent.

  Args:
    observation_spec: A TensorSpec defining the observations.
    action_spec: A BoundedTensorSpec defining the actions.
    tf_env: A TensorFlow environment object.
    tf_context: A Context class.
    step_cond_fn: A function indicating whether to increment the number
      of steps.
    reset_episode_cond_fn: A function indicating whether to restart the
      episode, resampling the context.
    reset_env_cond_fn: A function indicating whether to perform a manual
      reset of the environment.
    metrics: A list of functions that evaluate metrics of the agent.
    **base_agent_kwargs: A dictionary of parameters for the base RL agent.

  Raises:
    ValueError: If 'dqda_clipping' is < 0.
  """
  self._step_cond_fn = step_cond_fn
  self._reset_episode_cond_fn = reset_episode_cond_fn
  self._reset_env_cond_fn = reset_env_cond_fn
  self.metrics = metrics

  # Expose tf_context methods.
  self.tf_context = tf_context(tf_env=tf_env)
  self.set_replay = self.tf_context.set_replay
  self.sample_contexts = self.tf_context.sample_contexts
  self.compute_rewards = self.tf_context.compute_rewards
  self.gamma_index = self.tf_context.gamma_index
  self.context_specs = self.tf_context.context_specs
  self.context_as_action_specs = self.tf_context.context_as_action_specs
  self.init_context_vars = self.tf_context.create_vars

  self.env_observation_spec = observation_spec[0]
  merged_observation_spec = (uvf_utils.merge_specs(
      (self.env_observation_spec,) + self.context_specs),)
  self._context_vars = dict()
  self._action_vars = dict()

  self.BASE_AGENT_CLASS.__init__(
      self,
      observation_spec=merged_observation_spec,
      action_spec=action_spec,
      **base_agent_kwargs)
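# --- Illustrative sketch (not part of the original code) ---
# The merged observation spec built above comes from uvf_utils.merge_specs
# applied to the environment observation spec together with the context
# specs. Assuming merge_specs simply concatenates the flattened dimensions
# of its input specs into one flat TensorSpec (the real implementation may
# differ), a minimal stand-in could look like this:

import numpy as np
import tensorflow as tf


def merge_specs_sketch(specs):
  """Hypothetical stand-in for uvf_utils.merge_specs (assumption, see above).

  Combines a tuple of TensorSpecs into one flat TensorSpec whose dimension
  is the sum of the flattened dimensions of the inputs.
  """
  total_dim = sum(int(np.prod(spec.shape.as_list())) for spec in specs)
  return tf.TensorSpec(shape=[total_dim], dtype=specs[0].dtype)

# Example: an 8-dim observation spec plus a 3-dim goal-context spec would
# yield an 11-dim merged spec, which is what BASE_AGENT_CLASS.__init__
# receives as its observation_spec.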
def __init__(self,
             observation_spec,
             action_spec,
             tf_env,
             tf_context,
             sub_context,
             step_cond_fn=cond_fn.env_transition,
             reset_episode_cond_fn=cond_fn.env_restart,
             reset_env_cond_fn=cond_fn.false_fn,
             reward_net=additional_networks.reward_net,
             completion_net=additional_networks.completion_net,
             metrics=None,
             actions_reg=0.,
             k=2,
             **base_agent_kwargs):
  """Constructs a Meta agent.

  Args:
    observation_spec: A TensorSpec defining the observations.
    action_spec: A BoundedTensorSpec defining the actions.
    tf_env: A TensorFlow environment object.
    tf_context: A Context class for the meta agent.
    sub_context: A Context class for the lower-level (sub) agent.
    step_cond_fn: A function indicating whether to increment the number
      of steps.
    reset_episode_cond_fn: A function indicating whether to restart the
      episode, resampling the context.
    reset_env_cond_fn: A function indicating whether to perform a manual
      reset of the environment.
    reward_net: A function building the reward network.
    completion_net: A function building the completion network.
    metrics: A list of functions that evaluate metrics of the agent.
    actions_reg: A scalar coefficient for action regularization.
    k: An integer hyperparameter; defaults to 2.
    **base_agent_kwargs: A dictionary of parameters for the base RL agent.

  Raises:
    ValueError: If 'dqda_clipping' is < 0.
  """
  self.REWARD_NET_SCOPE = 'reward_net'
  self.COMPLETION_NET_SCOPE = 'completion_net'
  self._reward_net = tf.make_template(
      self.REWARD_NET_SCOPE, reward_net, create_scope_now_=True)
  self._completion_net = tf.make_template(
      self.COMPLETION_NET_SCOPE, completion_net, create_scope_now_=True)

  self._step_cond_fn = step_cond_fn
  self._reset_episode_cond_fn = reset_episode_cond_fn
  self._reset_env_cond_fn = reset_env_cond_fn
  self.metrics = metrics
  self._actions_reg = actions_reg
  self._k = k

  # Expose tf_context methods.
  self.tf_context = tf_context(tf_env=tf_env)
  self.sub_context = sub_context(tf_env=tf_env)
  self.set_replay = self.tf_context.set_replay
  self.sample_contexts = self.tf_context.sample_contexts
  self.compute_rewards = self.tf_context.compute_rewards
  self.gamma_index = self.tf_context.gamma_index
  self.context_specs = self.tf_context.context_specs
  self.context_as_action_specs = self.tf_context.context_as_action_specs
  self.sub_context_as_action_specs = self.sub_context.context_as_action_specs
  self.init_context_vars = self.tf_context.create_vars

  self.env_observation_spec = observation_spec[0]
  merged_observation_spec = (uvf_utils.merge_specs(
      (self.env_observation_spec,) + self.context_specs),)
  self._context_vars = dict()
  self._action_vars = dict()

  assert len(self.context_as_action_specs) == 1
  self.BASE_AGENT_CLASS.__init__(
      self,
      observation_spec=merged_observation_spec,
      action_spec=self.sub_context_as_action_specs,
      **base_agent_kwargs)
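# --- Illustrative sketch (not part of the original code) ---
# The reward and completion networks above are wrapped with tf.make_template,
# which guarantees that every call to the template reuses the same variables
# under one variable scope. The toy network below is hypothetical and only
# demonstrates that variable-sharing behaviour; it is not the actual
# additional_networks.reward_net or completion_net.

import tensorflow as tf


def _toy_net(states):
  # Hypothetical two-layer net; layer sizes and names are illustrative only.
  hidden = tf.layers.dense(states, 32, activation=tf.nn.relu, name='hidden')
  return tf.layers.dense(hidden, 1, name='output')

toy_template = tf.make_template('toy_net', _toy_net, create_scope_now_=True)
out_a = toy_template(tf.zeros([4, 8]))  # First call creates the variables.
out_b = toy_template(tf.ones([4, 8]))   # Second call reuses those variables.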