def __init__(self, expert_margin=0.5, supervised_weight=1.0, double_q=True, dueling_q=True,
             huber_loss=False, n_step=1, shared_container_action_target=True, memory_spec=None,
             demo_memory_spec=None, demo_sample_ratio=0.2, store_last_memory_batch=False,
             store_last_q_table=False, **kwargs):
    """
    Args:
        expert_margin (float): The expert margin enforces a distance in Q-values between the
            expert action and all other actions.
        supervised_weight (float): Weight of the supervised (expert) loss.
        double_q (bool): Whether to use the double-DQN loss function (see [2]).
        dueling_q (bool): Whether to use a dueling layer in the ActionAdapter (see [3]).
        huber_loss (bool): Whether to apply a Huber loss (see [4]).
        n_step (Optional[int]): n-step adjustment to discounting.
        shared_container_action_target (bool): Whether the loss should use one shared Q-value
            target across container actions (only relevant for container action Spaces).
        memory_spec (Optional[dict,Memory]): The spec for the Memory to use.
        demo_memory_spec (Optional[dict,Memory]): The spec for the demo-Memory to use.
        demo_sample_ratio (float): Fraction of each combined update batch that should come
            from the demo memory.
        store_last_memory_batch (bool): Whether to store the last pulled batch from the memory
            in `self.last_memory_batch` for debugging purposes. Default: False.
        store_last_q_table (bool): Whether to store the Q(s,a) values for the last received
            batch (memory or external) in `self.last_q_table` for debugging purposes.
            Default: False.
    """
    # TODO: Most of this duplicates the DQNAgent c'tor, but due to the way the loss function
    # is instantiated, inheriting from DQNAgent does not work well.

    # Fix the action-adapter before passing it to the super constructor.
    policy_spec = kwargs.pop("policy_spec", dict())
    # Use a DuelingPolicy (instead of a basic Policy) if the option is set.
    if dueling_q is True:
        policy_spec["type"] = "dueling-policy"
        # Provide a default number of state-value stream nodes.
        if "units_state_value_stream" not in policy_spec:
            policy_spec["units_state_value_stream"] = 128

    super(DQFDAgent, self).__init__(
        policy_spec=policy_spec,
        name=kwargs.pop("name", "dqfd-agent"),
        **kwargs
    )

    # Assert that the sync interval is a multiple of the update interval.
    if self.update_spec["sync_interval"] % self.update_spec["update_interval"] != 0:
        raise RLGraphError(
            "ERROR: sync_interval ({}) must be a multiple of update_interval ({})!".format(
                self.update_spec["sync_interval"], self.update_spec["update_interval"]
            )
        )

    self.double_q = double_q
    self.dueling_q = dueling_q
    self.huber_loss = huber_loss
    self.shared_container_action_target = shared_container_action_target

    # Number of demo transitions added to each update batch so that demos make up
    # `demo_sample_ratio` of the combined (online + demo) batch.
    self.demo_batch_size = int(
        demo_sample_ratio * self.update_spec["batch_size"] / (1.0 - demo_sample_ratio)
    )
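    # Worked example for `self.demo_batch_size` above (illustrative numbers, not from the
    # original code): with the default demo_sample_ratio=0.2 and update_spec["batch_size"]=64,
    # demo_batch_size = int(0.2 * 64 / (1.0 - 0.2)) = 16, so each update draws 64 online
    # transitions plus 16 demo transitions, and demos make up 16 / (64 + 16) = 20% of the
    # combined batch.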
    # Debugging tools.
    self.store_last_memory_batch = store_last_memory_batch
    self.last_memory_batch = None
    self.store_last_q_table = store_last_q_table
    self.last_q_table = None

    # Extend input Space definitions to this Agent's specific API-methods.
    preprocessed_state_space = self.preprocessed_state_space.with_batch_rank()
    reward_space = FloatBox(add_batch_rank=True)
    terminal_space = BoolBox(add_batch_rank=True)
    weight_space = FloatBox(add_batch_rank=True)
    self.input_spaces.update(dict(
        actions=self.action_space.with_batch_rank(),
        policy_weights="variables:{}".format(self.policy.scope),
        time_step=int,
        use_exploration=bool,
        demo_batch_size=int,
        apply_demo_loss=bool,
        preprocessed_states=preprocessed_state_space,
        rewards=reward_space,
        terminals=terminal_space,
        next_states=preprocessed_state_space,
        preprocessed_next_states=preprocessed_state_space,
        importance_weights=weight_space
    ))

    # The merger to merge inputs into one record Dict going into the memory.
    self.merger = ContainerMerger("states", "actions", "rewards", "next_states", "terminals")
    # The replay memory.
    self.memory = Memory.from_spec(memory_spec)
    # NOTE: `demo_memory_spec` must be a spec dict here; we override its scope so it does not
    # collide with the main memory's default name.
    demo_memory_spec["scope"] = "demo-memory"
    self.demo_memory = Memory.from_spec(demo_memory_spec)
    # The splitter for splitting up the records coming from the memories.
    self.splitter = ContainerSplitter("states", "actions", "rewards", "terminals", "next_states")

    # Copy our Policy (target-net) and make the target-net synchronizable.
    self.target_policy = self.policy.copy(scope="target-policy", trainable=False)
    # Number of steps since the last target-net sync from the main policy.
    self.steps_since_target_net_sync = 0

    use_importance_weights = isinstance(self.memory, PrioritizedReplay)
    self.loss_function = DQFDLossFunction(
        expert_margin=expert_margin, supervised_weight=supervised_weight,
        discount=self.discount, double_q=self.double_q, huber_loss=self.huber_loss,
        shared_container_action_target=shared_container_action_target,
        importance_weights=use_importance_weights, n_step=n_step
    )

    # Add all our sub-components to the core.
    self.root_component.add_components(
        self.preprocessor, self.merger, self.memory, self.demo_memory, self.splitter,
        self.policy, self.target_policy, self.exploration, self.loss_function, self.optimizer
    )

    # Define the Agent's (root-Component's) API.
    self.define_graph_api()

    if self.auto_build:
        self._build_graph(
            [self.root_component], self.input_spaces, optimizer=self.optimizer,
            batch_size=self.update_spec["batch_size"]
        )
        self.graph_built = True
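# Illustrative usage sketch (all spec values below are assumptions, not part of this file):
# a DQfD agent pairs a regular replay memory with a separate demo memory and is typically
# pre-trained on demonstration data before any environment interaction, e.g.:
#
#     agent = DQFDAgent(
#         state_space=FloatBox(shape=(4,)),
#         action_space=IntBox(2),
#         memory_spec=dict(type="replay", capacity=10000),
#         demo_memory_spec=dict(type="replay", capacity=10000),
#         demo_sample_ratio=0.2
#     )
#
# Expert transitions would first be inserted into the demo memory and trained on alone;
# afterwards, every regular update mixes online and demo samples (as in the DQfD paper).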
def __init__(
    self,
    state_space,
    action_space,
    discount=0.98,
    preprocessing_spec=None,
    network_spec=None,
    internal_states_space=None,
    policy_spec=None,
    exploration_spec=None,
    execution_spec=None,
    optimizer_spec=None,
    observe_spec=None,
    update_spec=None,
    summary_spec=None,
    saver_spec=None,
    auto_build=True,
    name="dqfd-agent",
    expert_margin=0.5,
    supervised_weight=1.0,
    double_q=True,
    dueling_q=True,
    huber_loss=False,
    n_step=1,
    shared_container_action_target=False,
    memory_spec=None,
    demo_memory_spec=None,
    demo_sample_ratio=0.2,
):
    """
    Args:
        state_space (Union[dict,Space]): Spec dict for the state Space or a direct Space object.
        action_space (Union[dict,Space]): Spec dict for the action Space or a direct Space object.
        discount (float): The discount factor (gamma).
        preprocessing_spec (Optional[list,PreprocessorStack]): The spec list for the different
            necessary state-preprocessing steps or a PreprocessorStack object itself.
        network_spec (Optional[list,NeuralNetwork]): Spec list for a NeuralNetwork Component or
            the NeuralNetwork object itself.
        internal_states_space (Optional[Union[dict,Space]]): Spec dict for the internal-states
            Space or a direct Space object for the Space(s) of the internal (RNN) states.
        policy_spec (Optional[dict]): An optional dict of further kwargs to pass to the Policy
            constructor.
        exploration_spec (Optional[dict]): The spec-dict to create the Exploration Component.
        execution_spec (Optional[dict,Execution]): The spec-dict specifying execution settings.
        optimizer_spec (Optional[dict,Optimizer]): The spec-dict to create the Optimizer for
            this Agent.
        observe_spec (Optional[dict]): Spec-dict to specify `Agent.observe()` settings.
        update_spec (Optional[dict]): Spec-dict to specify `Agent.update()` settings.
        summary_spec (Optional[dict]): Spec-dict to specify summary settings.
        saver_spec (Optional[dict]): Spec-dict to specify saver settings.
        auto_build (Optional[bool]): If True (default), immediately builds the graph using the
            agent's graph builder. If False, the user must separately call `agent.build()`.
            Useful for debugging or analyzing components before building.
        name (str): Some name for this Agent object.
        expert_margin (float): The expert margin enforces a distance in Q-values between the
            expert action and all other actions.
        supervised_weight (float): Weight of the supervised (expert) loss.
        double_q (bool): Whether to use the double-DQN loss function (see [2]).
        dueling_q (bool): Whether to use a dueling layer in the ActionAdapter (see [3]).
        huber_loss (bool): Whether to apply a Huber loss (see [4]).
        n_step (Optional[int]): n-step adjustment to discounting.
        shared_container_action_target (bool): Whether the loss should use one shared Q-value
            target across container actions (only relevant for container action Spaces).
        memory_spec (Optional[dict,Memory]): The spec for the Memory to use.
        demo_memory_spec (Optional[dict,Memory]): The spec for the demo-Memory to use.
        demo_sample_ratio (float): Fraction of each combined update batch that should come
            from the demo memory.
    """
    # Fix the action-adapter before passing it to the super constructor.
    # Use a DuelingPolicy (instead of a basic Policy) if the option is set.
    if dueling_q is True:
        if policy_spec is None:
            policy_spec = {}
        policy_spec["type"] = "dueling-policy"
        # Provide a default number of state-value stream nodes.
        if "units_state_value_stream" not in policy_spec:
            policy_spec["units_state_value_stream"] = 128
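    # For reference: a dueling architecture decomposes the Q-function into a state-value
    # stream V(s) and an advantage stream A(s, a), recombined as
    #   Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a')
    # (see [3]). The `units_state_value_stream` default above presumably sizes the hidden
    # layer of the V(s) stream; 128 is an arbitrary default, not a tuned value.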
if "units_state_value_stream" not in policy_spec: policy_spec["units_state_value_stream"] = 128 super(DQFDAgent, self).__init__( state_space=state_space, action_space=action_space, discount=discount, preprocessing_spec=preprocessing_spec, network_spec=network_spec, internal_states_space=internal_states_space, policy_spec=policy_spec, exploration_spec=exploration_spec, execution_spec=execution_spec, optimizer_spec=optimizer_spec, observe_spec=observe_spec, update_spec=update_spec, summary_spec=summary_spec, saver_spec=saver_spec, auto_build=auto_build, name=name ) # Assert that the synch interval is a multiple of the update_interval. if self.update_spec["sync_interval"] / self.update_spec["update_interval"] != \ self.update_spec["sync_interval"] // self.update_spec["update_interval"]: raise RLGraphError( "ERROR: sync_interval ({}) must be multiple of update_interval " "({})!".format(self.update_spec["sync_interval"], self.update_spec["update_interval"]) ) self.double_q = double_q self.dueling_q = dueling_q self.huber_loss = huber_loss self.expert_margin = expert_margin self.batch_size = self.update_spec["batch_size"] self.default_margins = np.asarray([self.expert_margin] * self.batch_size) self.demo_batch_size = int(demo_sample_ratio * self.update_spec["batch_size"] / (1.0 - demo_sample_ratio)) self.demo_margins = np.asarray([self.expert_margin] * self.demo_batch_size) self.shared_container_action_target = shared_container_action_target # Extend input Space definitions to this Agent's specific API-methods. preprocessed_state_space = self.preprocessed_state_space.with_batch_rank() reward_space = FloatBox(add_batch_rank=True) terminal_space = BoolBox(add_batch_rank=True) weight_space = FloatBox(add_batch_rank=True) self.input_spaces.update(dict( actions=self.action_space.with_batch_rank(), policy_weights="variables:{}".format(self.policy.scope), time_step=int, use_exploration=bool, demo_batch_size=int, apply_demo_loss=bool, preprocessed_states=preprocessed_state_space, rewards=reward_space, terminals=terminal_space, expert_margins=FloatBox(add_batch_rank=True), next_states=preprocessed_state_space, preprocessed_next_states=preprocessed_state_space, importance_weights=weight_space )) # The merger to merge inputs into one record Dict going into the memory. self.merger = ContainerMerger("states", "actions", "rewards", "next_states", "terminals") # The replay memory. self.memory = Memory.from_spec(memory_spec) # Cannot have same default name. demo_memory_spec["scope"] = "demo-memory" self.demo_memory = Memory.from_spec(demo_memory_spec) # The splitter for splitting up the records from the memories. self.splitter = ContainerSplitter("states", "actions", "rewards", "terminals", "next_states") # Copy our Policy (target-net), make target-net synchronizable. self.target_policy = self.policy.copy(scope="target-policy", trainable=False) # Number of steps since the last target-net synching from the main policy. self.steps_since_target_net_sync = 0 self.use_importance_weights = isinstance(self.memory, PrioritizedReplay) self.loss_function = DQFDLossFunction( supervised_weight=supervised_weight, discount=self.discount, double_q=self.double_q, huber_loss=self.huber_loss, shared_container_action_target=shared_container_action_target, importance_weights=self.use_importance_weights, n_step=n_step ) # Add all our sub-components to the core. 
    # Add all our sub-components to the core.
    self.root_component.add_components(
        self.preprocessor, self.merger, self.memory, self.demo_memory, self.splitter,
        self.policy, self.target_policy, self.exploration, self.loss_function, self.optimizer
    )

    # Define the Agent's (root-Component's) API.
    self.define_graph_api()

    if self.auto_build:
        self._build_graph(
            [self.root_component], self.input_spaces, optimizer=self.optimizer,
            batch_size=self.update_spec["batch_size"]
        )
        self.graph_built = True
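    # Worked example for the sync-interval check above (illustrative numbers): an
    # update_spec of dict(update_interval=4, sync_interval=32, batch_size=64) passes,
    # since 32 % 4 == 0, whereas sync_interval=30 with update_interval=4 would raise
    # RLGraphError at construction time.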