示例#1
0
 def test_copy_task_max_length(self):
     """Check that SequenceDataEnv cuts the reply off at `max_length`.

     The env is built from `copy_stream(10, n=1)` with `max_length=2`.
     The first ten steps are expected to give zero reward with
     `done == False` and to end on observation 1 (EOS). The next two
     actions should be echoed back as observations, and every later step
     should report `done == True` together with an EOS observation.
     """
     env = data_envs.SequenceDataEnv(
         data_envs.copy_stream(10, n=1), 16, max_length=2)
     obs = env.reset()

     # Input phase: step with action 0 until the env emits EOS.
     for _ in range(10):
         obs, reward, done, _ = env.step(0)
         self.assertEqual(reward, 0.0)
         self.assertEqual(done, False)
     self.assertEqual(obs, 1)  # produces EOS

     # Reply phase: the first `max_length` actions are repeated back.
     for action in (7, 8):
         obs, reward, done, _ = env.step(action)
         self.assertEqual(obs, action)  # repeats action
         self.assertEqual(reward, 0.0)
         self.assertEqual(done, False)

     # Exceeding max_length stops the episode; stepping again keeps
     # producing done = True and EOS.
     for action in (9, 10):
         obs, reward, done, _ = env.step(action)
         self.assertEqual(done, True)
         self.assertEqual(obs, 1)
示例#2
0
    def test_copy_task_short_sequence_correct_actions(self):
        """Run one copy-task episode in which every reply token is correct.

        With input (x1, x2), the expected trace of
        `observation, reward, done, info = env.step(action)` is:

            x1                 = env.reset()
            x2,   0.0,  F, _   = env.step(ignored_action)
            eos,  0.0,  F, _   = env.step(ignored_action)
            x1,   0.0,  F, _   = env.step(x1)
            x2,   0.0,  F, _   = env.step(x2)
            eos,  1.0,  T, _   = env.step(eos)
        """
        eos = 1
        env = data_envs.SequenceDataEnv(data_envs.copy_stream(2, n=1), 16)

        # Input phase: the actions passed here are ignored by the env.
        first = env.reset()
        second, reward, done, _ = env.step(0)
        self.assertEqual(reward, 0.0)
        self.assertEqual(done, False)

        obs, reward, done, _ = env.step(0)
        self.assertEqual(obs, eos)
        self.assertEqual(reward, 0.0)
        self.assertEqual(done, False)

        # Reply phase: replay the input tokens; each is echoed back.
        for token in (first, second):
            obs, reward, done, _ = env.step(token)
            self.assertEqual(obs, token)
            self.assertEqual(reward, 0.0)
            self.assertEqual(done, False)

        # Closing the reply with EOS ends the episode with full reward.
        obs, reward, done, _ = env.step(eos)
        self.assertEqual(obs, eos)
        self.assertEqual(reward, 1.0)
        self.assertEqual(done, True)
示例#3
0
    def test_copy_task_longer_sequnece_mixed_actions(self):
        """Run a two-sequence copy-task episode with partly wrong replies.

        With inputs (x1, x2) and (y1, y2), the expected trace of
        `observation, reward, done, info = env.step(action)` is:

            x1                 = env.reset()
            x2,   0.0,  F, _   = env.step(ignored_action)
            eos,  0.0,  F, _   = env.step(ignored_action)
            x1,   0.0,  F, _   = env.step(x1)
            x2+1, 0.0,  F, _   = env.step(x2+1)
            y1,   0.5,  F, _   = env.step(eos)
            y2,   0.0,  F, _   = env.step(ignored_action)
            eos,  0.0,  F, _   = env.step(ignored_action)
            y1+1, 0.0,  F, _   = env.step(y1+1)
            y2+1, 0.0,  F, _   = env.step(y2+1)
            eos,  0.0,  T, _   = env.step(eos)

        Only the reward and done flag of the two EOS replies, plus the
        final observation, are asserted.
        """
        eos = 1
        env = data_envs.SequenceDataEnv(data_envs.copy_stream(2, n=2), 16)

        # First sequence: read the input, then reply with one correct and
        # one incorrect token.
        first_a = env.reset()
        second_a, _, _, _ = env.step(0)
        _, _, _, _ = env.step(0)
        for reply in (first_a, second_a + 1):  # second reply is incorrect
            _, _, _, _ = env.step(reply)
        first_b, reward, done, _ = env.step(eos)
        self.assertEqual(reward, 0.5)
        self.assertEqual(done, False)

        # Second sequence: read the input, then reply with two incorrect
        # tokens.
        second_b, _, _, _ = env.step(0)
        _, _, _, _ = env.step(0)
        for reply in (first_b + 1, second_b + 1):  # both incorrect
            _, _, _, _ = env.step(reply)
        obs, reward, done, _ = env.step(eos)
        self.assertEqual(obs, eos)
        self.assertEqual(reward, 0.0)
        self.assertEqual(done, True)
示例#4
0
    def test_number_of_active_masks(self):
        """Count control and discount masks over one full episode.

        The data stream yields tuples of five sequences: three inputs of
        length 4 interleaved with two outputs of length 5. The env is
        stepped with action 0 until it reports done, while summing
        `info['discount_mask']` and `info['control_mask']` and counting
        the steps taken.
        """
        input_len, output_len = 4, 5
        n_input_seqs, n_output_seqs = 3, 2

        def data_stream():
            inp = 2 * np.ones(input_len)
            out = np.zeros(output_len)
            while True:
                # 3 input, 2 output sequences.
                yield (inp, out, inp, out, inp)

        env = data_envs.SequenceDataEnv(
            data_stream(), 16, max_length=output_len)
        env.reset()

        discount_total = 0
        control_total = 0
        step_total = 0
        while True:
            _, _, finished, info = env.step(action=0)
            discount_total += info['discount_mask']
            control_total += info['control_mask']
            step_total += 1
            if finished:
                break

        # One discount_mask=1 per output sequence.
        self.assertEqual(discount_total, n_output_seqs)
        # One control_mask=1 per output token, including EOS, because it's
        # also controlled by the agent.
        self.assertEqual(control_total, (output_len + 1) * n_output_seqs)
        # One control_mask=0 per input token, excluding EOS, because when
        # the env emits it, control transfers to the agent immediately.
        uncontrolled_steps = step_total - control_total
        self.assertEqual(uncontrolled_steps, input_len * n_input_seqs)
示例#5
0
 def test_copy_task_action_observation_space(self):
     """Check the sizes of the env's action and observation spaces.

     The env is constructed with 16 as its second argument; both
     `action_space.n` and `observation_space.n` are expected to equal it.
     """
     space_size = 16
     stream = data_envs.copy_stream(2, n=1)
     env = data_envs.SequenceDataEnv(stream, space_size)

     action_space = env.action_space
     self.assertEqual(action_space.n, space_size)
     observation_space = env.observation_space
     self.assertEqual(observation_space.n, space_size)