Пример #1
0
    def __init__(self, expert_margin=0.5, supervised_weight=1.0, double_q=True, dueling_q=True,
                 huber_loss=False, n_step=1, shared_container_action_target=True,
                 memory_spec=None, demo_memory_spec=None,
                 demo_sample_ratio=0.2, store_last_memory_batch=False, store_last_q_table=False, **kwargs):
        """
        Creates the agent's sub-components (replay memory, demo memory, target policy, loss function),
        registers them with the root component and, if `auto_build` is set, builds the graph.

        Args:
            expert_margin (float): The expert margin enforces a distance in Q-values between expert action and
                all other actions.
            supervised_weight (float): Indicates weight of the expert loss.
            double_q (bool): Whether to use the double DQN loss function (see [2]).
            dueling_q (bool): Whether to use a dueling layer in the ActionAdapter  (see [3]).
            huber_loss (bool) : Whether to apply a Huber loss. (see [4]).
            n_step (Optional[int]): n-step adjustment to discounting.
            shared_container_action_target (bool): Stored on the agent and forwarded unchanged to the
                DQFDLossFunction.
            memory_spec (Optional[dict,Memory]): The spec for the Memory to use.
            demo_memory_spec (Optional[dict,Memory]): The spec for the Demo-Memory to use. A dict is copied
                before its "scope" entry is set to "demo-memory"; None is treated as an empty spec.
            demo_sample_ratio (float): Must be in [0.0, 1.0). Determines the demo batch size as
                `int(demo_sample_ratio * batch_size / (1.0 - demo_sample_ratio))`.
            store_last_memory_batch (bool): Whether to store the last pulled batch from the memory in
                `self.last_memory_batch` for debugging purposes.
                Default: False.
            store_last_q_table (bool): Whether to store the Q(s,a) values for the last received batch
                (memory or external) in `self.last_q_table` for debugging purposes.
                Default: False.
            **kwargs: Forwarded to the parent constructor. `policy_spec` and `name` (default: "dqfd-agent")
                are extracted here first.

        Raises:
            RLGraphError: If `sync_interval` is not a multiple of `update_interval` or if
                `demo_sample_ratio` lies outside of [0.0, 1.0).
        """
        # TODO Most of this is DQN duplicate but the way the loss function is instantiated, inheriting
        # from DQN does not work well.

        # Fix action-adapter before passing it to the super constructor.
        policy_spec = kwargs.pop("policy_spec", dict())
        agent_name = kwargs.pop("name", "dqfd-agent")
        # Use a DuelingPolicy (instead of a basic Policy) if option is set.
        if dueling_q is True:
            # Work on a copy so that the caller's dict is never modified; an explicit None becomes {}.
            policy_spec = dict(policy_spec) if policy_spec is not None else {}
            policy_spec["type"] = "dueling-policy"
            # Give us some default state-value nodes.
            policy_spec.setdefault("units_state_value_stream", 128)
        super(DQFDAgent, self).__init__(policy_spec=policy_spec, name=agent_name, **kwargs)

        # Assert that the synch interval is a multiple of the update_interval.
        # Modulo is used (instead of comparing `/` with `//`) so that the check also works under
        # integer-division semantics and does not depend on float rounding.
        sync_interval = self.update_spec["sync_interval"]
        update_interval = self.update_spec["update_interval"]
        if sync_interval % update_interval != 0:
            raise RLGraphError(
                "ERROR: sync_interval ({}) must be multiple of update_interval "
                "({})!".format(sync_interval, update_interval)
            )

        # A ratio of 1.0 would divide by zero below, anything outside [0, 1) yields a negative batch size.
        if not 0.0 <= demo_sample_ratio < 1.0:
            raise RLGraphError(
                "ERROR: demo_sample_ratio ({}) must be in [0.0, 1.0)!".format(demo_sample_ratio)
            )

        self.double_q = double_q
        self.dueling_q = dueling_q
        self.huber_loss = huber_loss
        self.demo_batch_size = int(demo_sample_ratio * self.update_spec['batch_size'] / (1.0 - demo_sample_ratio))
        self.shared_container_action_target = shared_container_action_target

        # Debugging tools.
        self.store_last_memory_batch = store_last_memory_batch
        self.last_memory_batch = None
        self.store_last_q_table = store_last_q_table
        self.last_q_table = None

        # Extend input Space definitions to this Agent's specific API-methods.
        preprocessed_state_space = self.preprocessed_state_space.with_batch_rank()
        reward_space = FloatBox(add_batch_rank=True)
        terminal_space = BoolBox(add_batch_rank=True)
        weight_space = FloatBox(add_batch_rank=True)

        self.input_spaces.update(dict(
            actions=self.action_space.with_batch_rank(),
            policy_weights="variables:{}".format(self.policy.scope),
            time_step=int,
            use_exploration=bool,
            demo_batch_size=int,
            apply_demo_loss=bool,
            preprocessed_states=preprocessed_state_space,
            rewards=reward_space,
            terminals=terminal_space,
            next_states=preprocessed_state_space,
            preprocessed_next_states=preprocessed_state_space,
            importance_weights=weight_space
        ))

        # The merger to merge inputs into one record Dict going into the memory.
        self.merger = ContainerMerger("states", "actions", "rewards", "next_states", "terminals")

        # The replay memory.
        self.memory = Memory.from_spec(memory_spec)
        # Cannot have same default name.
        # NOTE(review): None is mapped to an empty spec so that the declared default no longer crashes
        # on the item assignment below - confirm that `Memory.from_spec` accepts a scope-only spec.
        if demo_memory_spec is None:
            demo_memory_spec = {}
        elif isinstance(demo_memory_spec, dict):
            # Copy, so the caller's spec dict does not get a "scope" entry injected.
            demo_memory_spec = dict(demo_memory_spec)
        demo_memory_spec["scope"] = "demo-memory"
        self.demo_memory = Memory.from_spec(demo_memory_spec)

        # The splitter for splitting up the records from the memories.
        self.splitter = ContainerSplitter("states", "actions", "rewards", "terminals", "next_states")

        # Copy our Policy (target-net), make target-net synchronizable.
        self.target_policy = self.policy.copy(scope="target-policy", trainable=False)
        # Number of steps since the last target-net synching from the main policy.
        self.steps_since_target_net_sync = 0

        # Importance weights are only switched on in the loss if the replay memory is prioritized.
        use_importance_weights = isinstance(self.memory, PrioritizedReplay)
        self.loss_function = DQFDLossFunction(
            expert_margin=expert_margin, supervised_weight=supervised_weight,
            discount=self.discount, double_q=self.double_q, huber_loss=self.huber_loss,
            shared_container_action_target=shared_container_action_target,
            importance_weights=use_importance_weights, n_step=n_step
        )

        # Add all our sub-components to the core.
        self.root_component.add_components(
            self.preprocessor, self.merger, self.memory, self.demo_memory, self.splitter, self.policy,
            self.target_policy, self.exploration, self.loss_function, self.optimizer
        )

        # Define the Agent's (root-Component's) API.
        self.define_graph_api()

        if self.auto_build:
            self._build_graph([self.root_component], self.input_spaces, optimizer=self.optimizer,
                              batch_size=self.update_spec["batch_size"])
            self.graph_built = True
Пример #2
0
    def __init__(
        self,
        state_space,
        action_space,
        discount=0.98,
        preprocessing_spec=None,
        network_spec=None,
        internal_states_space=None,
        policy_spec=None,
        exploration_spec=None,
        execution_spec=None,
        optimizer_spec=None,
        observe_spec=None,
        update_spec=None,
        summary_spec=None,
        saver_spec=None,
        auto_build=True,
        name="dqfd-agent",
        expert_margin=0.5,
        supervised_weight=1.0,
        double_q=True,
        dueling_q=True,
        huber_loss=False,
        n_step=1,
        shared_container_action_target=False,
        memory_spec=None,
        demo_memory_spec=None,
        demo_sample_ratio=0.2,
    ):

        """
        Creates the agent's sub-components (replay memory, demo memory, target policy, loss function),
        registers them with the root component and, if `auto_build` is set, builds the graph.

        Args:
            state_space (Union[dict,Space]): Spec dict for the state Space or a direct Space object.
            action_space (Union[dict,Space]): Spec dict for the action Space or a direct Space object.
            preprocessing_spec (Optional[list,PreprocessorStack]): The spec list for the different necessary states
                preprocessing steps or a PreprocessorStack object itself.
            discount (float): The discount factor (gamma).
            network_spec (Optional[list,NeuralNetwork]): Spec list for a NeuralNetwork Component or the NeuralNetwork
                object itself.
            internal_states_space (Optional[Union[dict,Space]]): Spec dict for the internal-states Space or a direct
                Space object for the Space(s) of the internal (RNN) states.
            policy_spec (Optional[dict]): An optional dict for further kwargs passing into the Policy c'tor.
                Copied (never modified in place) when the dueling defaults are applied.
            exploration_spec (Optional[dict]): The spec-dict to create the Exploration Component.
            execution_spec (Optional[dict,Execution]): The spec-dict specifying execution settings.
            optimizer_spec (Optional[dict,Optimizer]): The spec-dict to create the Optimizer for this Agent.
            observe_spec (Optional[dict]): Spec-dict to specify `Agent.observe()` settings.
            update_spec (Optional[dict]): Spec-dict to specify `Agent.update()` settings.
            summary_spec (Optional[dict]): Spec-dict to specify summary settings.
            saver_spec (Optional[dict]): Spec-dict to specify saver settings.
            auto_build (Optional[bool]): If True (default), immediately builds the graph using the agent's
                graph builder. If false, users must separately call agent.build(). Useful for debugging or analyzing
                components before building.
            name (str): Some name for this Agent object.
            expert_margin (float): The expert margin enforces a distance in Q-values between expert action and
                all other actions.
            supervised_weight (float): Indicates weight of the expert loss.
            double_q (bool): Whether to use the double DQN loss function (see [2]).
            dueling_q (bool): Whether to use a dueling layer in the ActionAdapter  (see [3]).
            huber_loss (bool) : Whether to apply a Huber loss. (see [4]).
            n_step (Optional[int]): n-step adjustment to discounting.
            shared_container_action_target (bool): Stored on the agent and forwarded unchanged to the
                DQFDLossFunction.
            memory_spec (Optional[dict,Memory]): The spec for the Memory to use.
            demo_memory_spec (Optional[dict,Memory]): The spec for the Demo-Memory to use. A dict is copied
                before its "scope" entry is set to "demo-memory"; None is treated as an empty spec.
            demo_sample_ratio (float): Must be in [0.0, 1.0). Determines the demo batch size as
                `int(demo_sample_ratio * batch_size / (1.0 - demo_sample_ratio))`.

        Raises:
            RLGraphError: If `sync_interval` is not a multiple of `update_interval` or if
                `demo_sample_ratio` lies outside of [0.0, 1.0).
        """
        # Fix action-adapter before passing it to the super constructor.
        # Use a DuelingPolicy (instead of a basic Policy) if option is set.
        if dueling_q is True:
            # Work on a copy so that the caller's dict is never modified; None becomes {}.
            policy_spec = dict(policy_spec) if policy_spec is not None else {}
            policy_spec["type"] = "dueling-policy"
            # Give us some default state-value nodes.
            policy_spec.setdefault("units_state_value_stream", 128)
        super(DQFDAgent, self).__init__(
            state_space=state_space,
            action_space=action_space,
            discount=discount,
            preprocessing_spec=preprocessing_spec,
            network_spec=network_spec,
            internal_states_space=internal_states_space,
            policy_spec=policy_spec,
            exploration_spec=exploration_spec,
            execution_spec=execution_spec,
            optimizer_spec=optimizer_spec,
            observe_spec=observe_spec,
            update_spec=update_spec,
            summary_spec=summary_spec,
            saver_spec=saver_spec,
            auto_build=auto_build,
            name=name
        )
        # Assert that the synch interval is a multiple of the update_interval.
        # Modulo is used (instead of comparing `/` with `//`) so that the check also works under
        # integer-division semantics and does not depend on float rounding.
        sync_interval = self.update_spec["sync_interval"]
        update_interval = self.update_spec["update_interval"]
        if sync_interval % update_interval != 0:
            raise RLGraphError(
                "ERROR: sync_interval ({}) must be multiple of update_interval "
                "({})!".format(sync_interval, update_interval)
            )

        # A ratio of 1.0 would divide by zero below, anything outside [0, 1) yields a negative batch size.
        if not 0.0 <= demo_sample_ratio < 1.0:
            raise RLGraphError(
                "ERROR: demo_sample_ratio ({}) must be in [0.0, 1.0)!".format(demo_sample_ratio)
            )

        self.double_q = double_q
        self.dueling_q = dueling_q
        self.huber_loss = huber_loss
        self.expert_margin = expert_margin

        self.batch_size = self.update_spec["batch_size"]
        # One margin value per sample of a regular batch.
        self.default_margins = np.asarray([self.expert_margin] * self.batch_size)

        self.demo_batch_size = int(demo_sample_ratio * self.update_spec["batch_size"] / (1.0 - demo_sample_ratio))
        # One margin value per sample of a demo batch.
        self.demo_margins = np.asarray([self.expert_margin] * self.demo_batch_size)
        self.shared_container_action_target = shared_container_action_target

        # Extend input Space definitions to this Agent's specific API-methods.
        preprocessed_state_space = self.preprocessed_state_space.with_batch_rank()
        reward_space = FloatBox(add_batch_rank=True)
        terminal_space = BoolBox(add_batch_rank=True)
        weight_space = FloatBox(add_batch_rank=True)

        self.input_spaces.update(dict(
            actions=self.action_space.with_batch_rank(),
            policy_weights="variables:{}".format(self.policy.scope),
            time_step=int,
            use_exploration=bool,
            demo_batch_size=int,
            apply_demo_loss=bool,
            preprocessed_states=preprocessed_state_space,
            rewards=reward_space,
            terminals=terminal_space,
            expert_margins=FloatBox(add_batch_rank=True),
            next_states=preprocessed_state_space,
            preprocessed_next_states=preprocessed_state_space,
            importance_weights=weight_space
        ))

        # The merger to merge inputs into one record Dict going into the memory.
        self.merger = ContainerMerger("states", "actions", "rewards", "next_states", "terminals")

        # The replay memory.
        self.memory = Memory.from_spec(memory_spec)
        # Cannot have same default name.
        # NOTE(review): None is mapped to an empty spec so that the declared default no longer crashes
        # on the item assignment below - confirm that `Memory.from_spec` accepts a scope-only spec.
        if demo_memory_spec is None:
            demo_memory_spec = {}
        elif isinstance(demo_memory_spec, dict):
            # Copy, so the caller's spec dict does not get a "scope" entry injected.
            demo_memory_spec = dict(demo_memory_spec)
        demo_memory_spec["scope"] = "demo-memory"
        self.demo_memory = Memory.from_spec(demo_memory_spec)

        # The splitter for splitting up the records from the memories.
        self.splitter = ContainerSplitter("states", "actions", "rewards", "terminals", "next_states")

        # Copy our Policy (target-net), make target-net synchronizable.
        self.target_policy = self.policy.copy(scope="target-policy", trainable=False)
        # Number of steps since the last target-net synching from the main policy.
        self.steps_since_target_net_sync = 0

        # Importance weights are only switched on in the loss if the replay memory is prioritized.
        self.use_importance_weights = isinstance(self.memory, PrioritizedReplay)
        self.loss_function = DQFDLossFunction(
            supervised_weight=supervised_weight,
            discount=self.discount, double_q=self.double_q, huber_loss=self.huber_loss,
            shared_container_action_target=shared_container_action_target,
            importance_weights=self.use_importance_weights, n_step=n_step
        )

        # Add all our sub-components to the core.
        self.root_component.add_components(
            self.preprocessor, self.merger, self.memory, self.demo_memory, self.splitter, self.policy,
            self.target_policy, self.exploration, self.loss_function, self.optimizer
        )

        # Define the Agent's (root-Component's) API.
        self.define_graph_api()

        if self.auto_build:
            self._build_graph([self.root_component], self.input_spaces, optimizer=self.optimizer,
                              batch_size=self.update_spec["batch_size"])
            self.graph_built = True