def test_insert_demos(self):
        """
        Checks that demo transitions can be written into the demo memory,
        first as a single transition and then as a batch of ten.
        """
        gym_env = OpenAIGymEnv.from_spec(self.env_spec)

        dqfd_config = config_from_path("configs/dqfd_agent_for_cartpole.json")
        dqfd_agent = DQFDAgent.from_spec(
            dqfd_config,
            state_space=gym_env.state_space,
            action_space=gym_env.action_space
        )
        terminal_space = BoolBox(add_batch_rank=True)
        reward_space = FloatBox(add_batch_rank=True)

        # One transition, sampled through with_batch_rank() spaces with size 1.
        # Entries are built in the same order the samples must be drawn.
        single_transition = dict(
            preprocessed_states=dqfd_agent.preprocessed_state_space.with_batch_rank().sample(1),
            actions=gym_env.action_space.with_batch_rank().sample(1),
            rewards=reward_space.sample(1),
            next_states=dqfd_agent.preprocessed_state_space.with_batch_rank().sample(1),
            terminals=terminal_space.sample(1),
        )
        dqfd_agent.observe_demos(**single_transition)

        # Ten transitions, sampled directly from the spaces (no with_batch_rank()).
        ten_transitions = dict(
            preprocessed_states=dqfd_agent.preprocessed_state_space.sample(10),
            actions=gym_env.action_space.sample(10),
            rewards=FloatBox().sample(10),
            terminals=terminal_space.sample(10),
            next_states=dqfd_agent.preprocessed_state_space.sample(10)
        )
        dqfd_agent.observe_demos(**ten_transitions)
    def test_update_online(self):
        """
        Exercises a joint update over demo and online data: one batch of 32
        samples goes into each memory, then update() is called once.
        """
        gym_env = OpenAIGymEnv.from_spec(self.env_spec)
        dqfd_config = config_from_path("configs/dqfd_agent_for_cartpole.json")
        dqfd_agent = DQFDAgent.from_spec(
            dqfd_config,
            state_space=gym_env.state_space,
            action_space=gym_env.action_space
        )
        terminal_space = BoolBox(add_batch_rank=True)
        batch_size = 32
        state_space = dqfd_agent.preprocessed_state_space

        # Demo data goes in through the public observe_demos() API.
        dqfd_agent.observe_demos(
            preprocessed_states=state_space.sample(batch_size),
            actions=gym_env.action_space.sample(batch_size),
            rewards=FloatBox().sample(batch_size),
            terminals=terminal_space.sample(batch_size),
            next_states=state_space.sample(batch_size)
        )

        # Online data is written through the private _observe_graph() hook;
        # internals is passed as an empty list.
        dqfd_agent._observe_graph(
            preprocessed_states=state_space.sample(batch_size),
            actions=gym_env.action_space.sample(batch_size),
            rewards=FloatBox().sample(batch_size),
            internals=[],
            terminals=terminal_space.sample(batch_size),
            next_states=state_space.sample(batch_size)
        )

        # Single joint update drawing on both memories.
        dqfd_agent.update()
    # Example #3
    def test_demos_with_container_actions(self):
        """
        Checks that DQfD can fit a fixed set of 20 demo states to their
        demonstrated container (Dict) actions after training on demos only.
        """
        vocab_size = 100
        embed_dim = 128
        num_outputs = 3
        num_demos = 20

        # Integer ID state space with shape (10,).
        state_space = IntBox(vocab_size, shape=(10, ))
        # Container action space: three IntBox(low=0, high=num_outputs) sub-actions.
        actions_space = Dict({
            'action_{}'.format(idx): IntBox(low=0, high=num_outputs)
            for idx in range(3)
        })

        # Replace the configured network with embedding -> flatten -> dense.
        agent_config = config_from_path("configs/dqfd_container.json")
        agent_config["network_spec"] = [
            dict(type="embedding", embed_dim=embed_dim, vocab_size=vocab_size),
            dict(type="reshape", flatten=True),
            dict(type="dense", units=embed_dim, activation="relu", scope="dense_1"),
        ]
        agent = DQFDAgent.from_spec(
            agent_config,
            state_space=state_space,
            action_space=actions_space
        )
        terminals = BoolBox(add_batch_rank=True)
        rewards = FloatBox(add_batch_rank=True)

        # Demo set: random states and actions, reward 1.0, never terminal.
        demo_states = agent.preprocessed_state_space.with_batch_rank().sample(num_demos)
        demo_actions = actions_space.with_batch_rank().sample(num_demos)
        demo_rewards = rewards.sample(num_demos, fill_value=1.0)
        demo_next_states = agent.preprocessed_state_space.with_batch_rank().sample(num_demos)
        demo_terminals = terminals.sample(num_demos, fill_value=False)

        agent.observe_demos(
            preprocessed_states=demo_states,
            actions=demo_actions,
            rewards=demo_rewards,
            next_states=demo_next_states,
            terminals=demo_terminals,
        )

        # Train on demos only; batch size equals the number of demos.
        agent.update_from_demos(num_updates=5000, batch_size=num_demos)

        # Greedy actions on the already-preprocessed demo states must
        # reproduce the demonstrated actions.
        predicted_actions = agent.get_action(
            demo_states,
            apply_preprocessing=False,
            use_exploration=False
        )
        recursive_assert_almost_equal(predicted_actions, demo_actions)
    def test_update_from_demos(self):
        """
        Tests update_from_demos(): after training only on demos in which two
        fixed states always come with fixed actions, the greedy policy must
        return those actions for those states.
        """
        gym_env = OpenAIGymEnv.from_spec(self.env_spec)
        agent_config = config_from_path("configs/dqfd_agent_for_cartpole.json")
        agent = DQFDAgent.from_spec(
            agent_config,
            state_space=gym_env.state_space,
            action_space=gym_env.action_space
        )
        terminals = BoolBox(add_batch_rank=True)
        rewards = FloatBox(add_batch_rank=True)

        # Two fixed (state, action) pairs.
        state_1 = agent.preprocessed_state_space.with_batch_rank().sample(1)
        action_1 = [1]
        state_2 = agent.preprocessed_state_space.with_batch_rank().sample(1)
        action_2 = [0]
        fixed_pairs = ((state_1, action_1), (state_2, action_2))

        # Ten rounds; each round inserts one demo per fixed pair with a
        # random reward, next state and terminal flag.
        for _ in range(10):
            for demo_state, demo_action in fixed_pairs:
                agent.observe_demos(
                    preprocessed_states=demo_state,
                    actions=demo_action,
                    rewards=rewards.sample(1),
                    next_states=agent.preprocessed_state_space.with_batch_rank().sample(1),
                    terminals=terminals.sample(1),
                )

        agent.update_from_demos(num_updates=100, batch_size=8)

        # Each fixed state must map to its demonstrated action.
        for demo_state, demo_action in fixed_pairs:
            greedy_action = agent.get_action(
                states=demo_state,
                apply_preprocessing=False,
                use_exploration=False
            )
            self.assertEqual(greedy_action, demo_action)
    # Example #5
    def test_container_actions(self):
        """
        Smoke test: a DQfD agent with an embedding network and a Dict of three
        IntBox sub-actions can be built and can observe one demo transition.
        """
        vocab_size = 100
        embed_dim = 128
        num_outputs = 3

        # Integer ID state space with shape (10,).
        state_space = IntBox(vocab_size, shape=(10, ))
        # Container action space: three IntBox(low=0, high=num_outputs) sub-actions.
        actions_space = Dict({
            'action_{}'.format(idx): IntBox(low=0, high=num_outputs)
            for idx in range(3)
        })

        # Replace the configured network with embedding -> flatten -> dense.
        agent_config = config_from_path("configs/dqfd_container.json")
        agent_config["network_spec"] = [
            dict(type="embedding", embed_dim=embed_dim, vocab_size=vocab_size),
            dict(type="reshape", flatten=True),
            dict(type="dense", units=embed_dim, activation="relu", scope="dense_1"),
        ]
        agent = DQFDAgent.from_spec(
            agent_config,
            state_space=state_space,
            action_space=actions_space
        )
        terminals = BoolBox(add_batch_rank=True)
        rewards = FloatBox(add_batch_rank=True)

        # One demo transition, batch size 1.
        agent.observe_demos(
            preprocessed_states=agent.preprocessed_state_space.with_batch_rank().sample(1),
            actions=actions_space.with_batch_rank().sample(1),
            rewards=rewards.sample(1),
            next_states=agent.preprocessed_state_space.with_batch_rank().sample(1),
            terminals=terminals.sample(1),
        )
    def test_custom_margin_demos_with_container_actions(self):
        """
        Tests passing per-sample expert margins to update().

        Two demo transitions share the same state but carry different
        container actions: every sub-action is 0 in demo 0 and 1 in demo 1.
        Both rewards are zero, so the only signal that separates the two demos
        is the per-sample margin: +0.5 for demo 0 and -0.5 for demo 1.
        """
        vocab_size = 100
        embed_dim = 8
        # ID/state space.
        state_space = IntBox(vocab_size, shape=(10,))
        # Container action space with three IntBox sub-actions.
        actions_space = {}
        num_outputs = 3
        for i in range(3):
            actions_space['action_{}'.format(i)] = IntBox(
                low=0,
                high=num_outputs
            )
        actions_space = Dict(actions_space)

        # Replace the configured network with embedding -> flatten -> dense.
        agent_config = config_from_path("configs/dqfd_container.json")
        agent_config["network_spec"] = [
            dict(type="embedding", embed_dim=embed_dim, vocab_size=vocab_size),
            dict(type="reshape", flatten=True),
            dict(type="dense", units=embed_dim, activation="relu", scope="dense_1")
        ]
        agent = DQFDAgent.from_spec(
            agent_config,
            state_space=state_space,
            action_space=actions_space
        )
        terminals = BoolBox(add_batch_rank=True)
        rewards = FloatBox(add_batch_rank=True)

        # Two demos that share the same state...
        demo_states = agent.preprocessed_state_space.with_batch_rank().sample(2)
        demo_states[1] = demo_states[0]

        # ...but take different actions: all sub-actions 0 in demo 0, all 1 in demo 1.
        demo_actions = actions_space.with_batch_rank().sample(2)
        for name in actions_space.keys():
            demo_actions[name][0] = 0
            demo_actions[name][1] = 1

        # Both demos get zero reward, so rewards favour neither action; the
        # preference is expressed only through the margins below.
        demo_rewards = rewards.sample(2, fill_value=.0)
        demo_rewards[0] = 0
        demo_rewards[1] = 0

        # Positive margin for demo 0 (encouraged), negative for demo 1 (discouraged).
        margins = np.asarray([0.5, -0.5])

        demo_next_states = agent.preprocessed_state_space.with_batch_rank().sample(2)
        demo_terminals = terminals.sample(2, fill_value=False)

        # Per-sample margins require passing an external batch to update().
        batch = dict(
            states=demo_states,
            actions=demo_actions,
            rewards=demo_rewards,
            next_states=demo_next_states,
            importance_weights=np.ones_like(demo_rewards),
            terminals=demo_terminals,
        )
        # Fit the external batch with the demo loss applied and custom margins.
        for _ in range(10000):
            agent.update(batch=batch, update_from_demos=False, apply_demo_loss_to_batch=True, expert_margins=margins)

        # Greedy action for the shared state. The intended outcome is the
        # action of the positive-margin demo (demo 0).
        # NOTE(review): this only prints; add an assertion once the expected
        # action for these margins is confirmed.
        agent_actions = agent.get_action(np.array([demo_states[0]]), apply_preprocessing=False, use_exploration=False)
        print("learned action = ", agent_actions)