def __init__(self, encoder_major=None, encoder_minor=None, hparams=None):
    """Sets up the major and the minor encoder of this module.

    An encoder passed in by the caller is stored as-is. A missing encoder
    is instantiated from hparams inside this module's variable scope, in a
    sub-scope named `'encoder_major'` or `'encoder_minor'`.

    Args:
        encoder_major (optional): Object to use as the major encoder. If
            `None`, one is created from `hparams.encoder_major_type` and
            `hparams.encoder_major_hparams`.
        encoder_minor (optional): Object to use as the minor encoder. If
            `None` and `hparams.config_share` is truthy, one is created
            from the *major* type and hparams; otherwise it is created from
            `hparams.encoder_minor_type` and
            `hparams.encoder_minor_hparams`.
        hparams (optional): Forwarded to `EncoderBase.__init__`.
    """
    EncoderBase.__init__(self, hparams)

    # Both kwargs sets are resolved up front, whether or not they end up
    # being used below.
    major_kwargs = utils.get_instance_kwargs(
        None, self._hparams.encoder_major_hparams)
    minor_kwargs = utils.get_instance_kwargs(
        None, self._hparams.encoder_minor_hparams)

    def _instantiate(scope_name, encoder_type, encoder_kwargs):
        # Creates one encoder under `<module scope>/<scope_name>`.
        with tf.variable_scope(self.variable_scope.name):
            with tf.variable_scope(scope_name):
                return utils.check_or_get_instance(
                    encoder_type,
                    encoder_kwargs,
                    ['texar.modules.encoders', 'texar.custom'])

    if encoder_major is None:
        encoder_major = _instantiate(
            'encoder_major',
            self._hparams.encoder_major_type,
            major_kwargs)
    self._encoder_major = encoder_major

    if encoder_minor is None:
        if self._hparams.config_share:
            # Shared configuration: reuse the major type and kwargs, but
            # still build a separate instance in the minor scope.
            encoder_minor = _instantiate(
                'encoder_minor',
                self._hparams.encoder_major_type,
                major_kwargs)
        else:
            encoder_minor = _instantiate(
                'encoder_minor',
                self._hparams.encoder_minor_type,
                minor_kwargs)
    self._encoder_minor = encoder_minor
def __init__(self, env_config, sess=None, policy=None, policy_kwargs=None,
             policy_caller_kwargs=None, learning_rate=None, hparams=None):
    """Initializes the agent's policy, episode buffers and graph.

    Args:
        env_config: Forwarded to `EpisodicAgentBase.__init__`.
        sess (optional): Stored on the agent as `_sess`.
        policy (optional): Object to use as the policy. If `None`, one is
            created from `hparams.policy_type`, with kwargs resolved from
            `policy_kwargs` and `hparams.policy_hparams`.
        policy_kwargs (optional): Passed to `utils.get_instance_kwargs`
            when the policy has to be created.
        policy_caller_kwargs (optional): Stored as `_policy_caller_kwargs`;
            any falsy value is replaced by an empty dict.
        learning_rate (optional): Stored on the agent as `_lr`.
        hparams (optional): Forwarded to `EpisodicAgentBase.__init__`.
    """
    EpisodicAgentBase.__init__(self, env_config, hparams)

    self._sess = sess
    self._lr = learning_rate
    self._discount_factor = self._hparams.discount_factor

    with tf.variable_scope(self.variable_scope):
        if policy is None:
            policy_init_kwargs = utils.get_instance_kwargs(
                policy_kwargs, self._hparams.policy_hparams)
            policy = utils.check_or_get_instance(
                self._hparams.policy_type,
                policy_init_kwargs,
                module_paths=['texar.modules', 'texar.custom'])
        self._policy = policy
        self._policy_caller_kwargs = policy_caller_kwargs or {}

    # Per-episode buffers, initially empty.
    self._observs, self._actions, self._rewards = [], [], []

    self._train_outputs = None

    self._build_graph()
def _build_network(self, network, kwargs):
    """Stores `network` on the instance, creating it first if needed.

    Args:
        network: Object to store as `self._network`. If `None`, a network
            is created from `hparams.network_type`.
        kwargs: Passed, together with `hparams.network_hparams`, to
            `utils.get_instance_kwargs` to resolve the constructor
            arguments. Only used when `network` is `None`.
    """
    if network is None:
        network_kwargs = utils.get_instance_kwargs(
            kwargs, self._hparams.network_hparams)
        network = utils.check_or_get_instance(
            self._hparams.network_type,
            network_kwargs,
            module_paths=['texar.modules', 'texar.custom'])
    self._network = network
def __init__(self, env_config, sess=None, actor=None, actor_kwargs=None,
             critic=None, critic_kwargs=None, hparams=None):
    """Initializes the agent with an actor and a critic sub-agent.

    Args:
        env_config: Forwarded to `EpisodicAgentBase.__init__`, and also
            injected into the kwargs of any sub-agent created here.
        sess (optional): Stored as `_sess`, and also injected into the
            kwargs of any sub-agent created here.
        actor (optional): Object to use as the actor. If `None`, one is
            created from `hparams.actor_type`.
        actor_kwargs (optional): Combined with `hparams.actor_hparams` via
            `utils.get_instance_kwargs` when the actor is created.
        critic (optional): Object to use as the critic. If `None`, one is
            created from `hparams.critic_type`.
        critic_kwargs (optional): Combined with `hparams.critic_hparams`
            via `utils.get_instance_kwargs` when the critic is created.
        hparams (optional): Forwarded to `EpisodicAgentBase.__init__`.

    Raises:
        ValueError: If the actor's and the critic's `_discount_factor`
            differ.
    """
    EpisodicAgentBase.__init__(self, env_config=env_config, hparams=hparams)

    self._sess = sess

    # Difference between the action space's `high` and `low`.
    action_space = self._env_config.action_space
    self._num_actions = action_space.high - action_space.low

    def _new_sub_agent(user_kwargs, default_hparams, agent_type):
        # Resolves the kwargs, adds the shared env config and session,
        # then instantiates the sub-agent.
        init_kwargs = utils.get_instance_kwargs(user_kwargs, default_hparams)
        init_kwargs.update(env_config=env_config, sess=sess)
        return utils.get_instance(
            class_or_name=agent_type,
            kwargs=init_kwargs,
            module_paths=['texar.agents', 'texar.custom'])

    with tf.variable_scope(self.variable_scope):
        if actor is None:
            actor = _new_sub_agent(
                actor_kwargs,
                self._hparams.actor_hparams,
                self._hparams.actor_type)
        self._actor = actor

        if critic is None:
            critic = _new_sub_agent(
                critic_kwargs,
                self._hparams.critic_hparams,
                self._hparams.critic_type)
        self._critic = critic

        actor_discount = self._actor._discount_factor
        if actor_discount != self._critic._discount_factor:
            raise ValueError('discount_factor of the actor and the critic '
                             'must be the same.')
        self._discount_factor = self._actor._discount_factor

    # Per-episode buffers, initially empty.
    self._observs, self._actions, self._rewards = [], [], []
def __init__(self, env_config, sess=None, actor=None, actor_kwargs=None,
             critic=None, critic_kwargs=None, hparams=None):
    """Initializes the agent with an actor and a critic sub-agent.

    Args:
        env_config: Forwarded to `EpisodicAgentBase.__init__`, and also
            injected into the kwargs of any sub-agent created here.
        sess (optional): Stored as `_sess`, and also injected into the
            kwargs of any sub-agent created here.
        actor (optional): Object to use as the actor. If `None`, one is
            created from `hparams.actor_type`.
        actor_kwargs (optional): Combined with `hparams.actor_kwargs` via
            `utils.get_instance_kwargs` when the actor is created.
        critic (optional): Object to use as the critic. If `None`, one is
            created from `hparams.critic_type`.
        critic_kwargs (optional): Combined with `hparams.critic_kwargs`
            via `utils.get_instance_kwargs` when the critic is created.
        hparams (optional): Forwarded to `EpisodicAgentBase.__init__`.

    Raises:
        ValueError: If the actor's and the critic's `_discount_factor`
            differ.
    """
    EpisodicAgentBase.__init__(self, env_config=env_config, hparams=hparams)

    self._sess = sess

    # Difference between the action space's `high` and `low`.
    self._num_actions = (self._env_config.action_space.high -
                         self._env_config.action_space.low)

    with tf.variable_scope(self.variable_scope):
        if actor is None:
            kwargs = utils.get_instance_kwargs(
                actor_kwargs, self._hparams.actor_kwargs)
            kwargs.update(dict(env_config=env_config, sess=sess))
            actor = utils.get_instance(
                class_or_name=self._hparams.actor_type,
                kwargs=kwargs,
                module_paths=['texar.agents', 'texar.custom'])
        self.actor = actor

        if critic is None:
            kwargs = utils.get_instance_kwargs(
                critic_kwargs, self._hparams.critic_kwargs)
            kwargs.update(dict(env_config=env_config, sess=sess))
            critic = utils.get_instance(
                class_or_name=self._hparams.critic_type,
                kwargs=kwargs,
                module_paths=['texar.agents', 'texar.custom'])
        self.critic = critic

        # Raise explicitly rather than `assert`: assertions are removed
        # when Python runs with -O, which would let a mismatched
        # actor/critic pair through unnoticed.
        if self.actor._discount_factor != self.critic._discount_factor:
            raise ValueError('discount_factor of the actor and the critic '
                             'must be the same.')
        self._discount_factor = self.actor._discount_factor
def __init__(self, env_config, sess=None, qnet=None, target=None,
             qnet_kwargs=None, qnet_caller_kwargs=None, replay_memory=None,
             replay_memory_kwargs=None, exploration=None,
             exploration_kwargs=None, hparams=None):
    """Initializes the Q-networks, replay memory, exploration and graph.

    Args:
        env_config: Forwarded to `EpisodicAgentBase.__init__`.
        sess (optional): Stored on the agent as `_sess`.
        qnet (optional): Object to use as the Q-network. If `None`, both
            the Q-network and the target network are created from
            `hparams.qnet_type` with the same resolved kwargs.
        target (optional): Object to use as the target network. Only kept
            when `qnet` is also given.
        qnet_kwargs (optional): Combined with `hparams.qnet_hparams` via
            `utils.get_instance_kwargs` when the networks are created.
        qnet_caller_kwargs (optional): Stored as `_qnet_caller_kwargs`;
            any falsy value is replaced by an empty dict.
        replay_memory (optional): Object to use as the replay memory. If
            `None`, one is created from `hparams.replay_memory_type`.
        replay_memory_kwargs (optional): Combined with
            `hparams.replay_memory_hparams` when the memory is created.
        exploration (optional): Object to use as the exploration strategy.
            If `None`, one is created from `hparams.exploration_type`.
        exploration_kwargs (optional): Combined with
            `hparams.exploration_hparams` when the strategy is created.
        hparams (optional): Forwarded to `EpisodicAgentBase.__init__`.
    """
    EpisodicAgentBase.__init__(self, env_config, hparams)

    self._sess = sess

    hp = self._hparams
    self._cold_start_steps = hp.cold_start_steps
    self._sample_batch_size = hp.sample_batch_size
    self._update_period = hp.update_period
    self._discount_factor = hp.discount_factor
    self._target_update_strategy = hp.target_update_strategy

    # Difference between the action space's `high` and `low`.
    action_space = self._env_config.action_space
    self._num_actions = action_space.high - action_space.low

    def _instantiate(type_spec, init_kwargs, module_paths):
        # Thin wrapper so the three components are created the same way.
        return utils.check_or_get_instance(
            ins_or_class_or_name=type_spec,
            kwargs=init_kwargs,
            module_paths=module_paths)

    with tf.variable_scope(self.variable_scope):
        # NOTE(review): when `qnet` is None, any caller-supplied `target`
        # is overwritten; when `qnet` is given, `target` is stored exactly
        # as passed (possibly None). Confirm this is the intended contract.
        if qnet is None:
            net_kwargs = utils.get_instance_kwargs(
                qnet_kwargs, self._hparams.qnet_hparams)
            qnet = _instantiate(
                self._hparams.qnet_type, net_kwargs,
                ['texar.modules', 'texar.custom'])
            target = _instantiate(
                self._hparams.qnet_type, net_kwargs,
                ['texar.modules', 'texar.custom'])
        self._qnet = qnet
        self._target = target
        self._qnet_caller_kwargs = qnet_caller_kwargs or {}

        if replay_memory is None:
            memory_kwargs = utils.get_instance_kwargs(
                replay_memory_kwargs, self._hparams.replay_memory_hparams)
            replay_memory = _instantiate(
                self._hparams.replay_memory_type, memory_kwargs,
                ['texar.core', 'texar.custom'])
        self._replay_memory = replay_memory

        if exploration is None:
            explore_kwargs = utils.get_instance_kwargs(
                exploration_kwargs, self._hparams.exploration_hparams)
            exploration = _instantiate(
                self._hparams.exploration_type, explore_kwargs,
                ['texar.core', 'texar.custom'])
        self._exploration = exploration

    self._build_graph()

    # Per-step state, reset to its initial values.
    self._observ = None
    self._action = None
    self._timestep = 0