Example #1
0
 def test_use_reply(self):
     """
     Check that self-observe is correctly acting on labels.

     Covers the --use-reply values exercised below: the default (use the label
     when one is available, otherwise the model's own reply), 'model' (always
     the model's reply, even when a label exists) and 'none' (the reply is not
     added to the history at all).
     """
     # default is hybrid label-model, which uses the label if it's available,
     # and otherwise the model's reply
     # first check if there is a label available
     agent = get_agent()
     obs = Message({'text': 'Call', 'labels': ['Response'], 'episode_done': False})
     agent.observe(obs)
     agent.act()
     self.assertEqual(agent.history.get_history_str(), 'Call\nResponse')
     # check if there is no label: the model's reply is used instead
     agent.reset()
     obs = Message({'text': 'Call', 'episode_done': False})
     agent.observe(obs)
     agent.act()
     self.assertEqual(
         agent.history.get_history_str(), 'Call\nEvaluating 0 (responding to Call)!'
     )
     # now some of the other possible values of --use-reply
     # --use-reply model: even if there is a label, we should see the model's
     # output
     agent = get_agent(use_reply='model')
     obs = Message({'text': 'Call', 'labels': ['Response'], 'episode_done': False})
     agent.observe(obs)
     agent.act()
     self.assertEqual(agent.history.get_history_str(), 'Call\nTraining 0!')
     # --use-reply none doesn't hear itself
     agent = get_agent(use_reply='none')
     obs = Message({'text': 'Call', 'labels': ['Response'], 'episode_done': False})
     agent.observe(obs)
     agent.act()
     self.assertEqual(agent.history.get_history_str(), 'Call')
Example #2
0
    def test_interactive_mode(self):
        """
        Check single act() replies across batching / interactive-mode settings.

        For each configuration, a freshly built agent and an agent created from
        its shared state each observe one complete episode. Both must reply
        with text containing 'Evaluating 0', i.e. the reply for the first (and
        only) slot of the batch, regardless of the configured batch size.
        """
        configs = [
            # both manually setting bs to 1 and interactive mode true
            {'batchsize': 1, 'interactive_mode': True},
            # now just bs 1
            {'batchsize': 1, 'interactive_mode': False},
            # now just interactive. This agent is built fresh: the old code
            # reused the bs=1 non-interactive agent from the previous case,
            # so interactive mode on its own was never exercised.
            {'interactive_mode': True},
            # finally, actively attempt to sabotage with a large batch size
            {'batchsize': 16, 'interactive_mode': False},
        ]
        for config in configs:
            with self.subTest(**config):
                agent = get_agent(**config)
                agent.observe(Message({'text': 'foo', 'episode_done': True}))
                response = agent.act()
                self.assertIn(
                    'Evaluating 0', response['text'], 'Incorrect output in single act()'
                )
                # a copy built from the shared state must behave the same way
                shared = create_agent_from_shared(agent.share())
                shared.observe(Message({'text': 'bar', 'episode_done': True}))
                response = shared.act()
                self.assertIn(
                    'Evaluating 0', response['text'], 'Incorrect output in single act()'
                )
Example #3
0
 def test_respond(self):
     """
     Exercise respond() from the base Agent class.

     respond() takes a single message and hands back the agent's reply as a
     plain string: a training message yields the training reply, an
     evaluation message yields the evaluation reply.
     """
     agent = get_agent()
     line = "It's only a flesh wound."

     # training-time message (carries 'labels')
     train_msg = Message({'text': line, 'labels': ['Yield!'], 'episode_done': True})
     self.assertEqual(agent.respond(train_msg), 'Training 0!')

     # evaluation-time message (carries 'eval_labels')
     eval_msg = Message(
         {'text': line, 'eval_labels': ['Yield!'], 'episode_done': True}
     )
     self.assertIn('Evaluating 0', agent.respond(eval_msg))
Example #4
0
    def test_mturk_racehistory(self):
        """
        Emulate a setting where batch_act mishandles mturk.

        Two shared copies of one interactive, echoing agent hold interleaved
        conversations (observe/act calls alternate between them in varying
        order). Afterwards neither copy's history may contain a message that
        was sent to the other copy.
        """
        parent = get_agent(batchsize=16, interactive_mode=True, echo=True)

        def turn(text):
            # every message in this test continues an open episode
            return Message({'text': text, 'episode_done': False})

        first = create_agent_from_shared(parent.share())
        first.observe(turn('thread1-msg1'))
        second = create_agent_from_shared(parent.share())
        second.observe(turn('thread2-msg1'))
        first.act()
        second.act()

        # round two: both observe, then they act in the opposite order
        first.observe(turn('thread1-msg2'))
        second.observe(turn('thread2-msg2'))
        second.act()
        first.act()

        # round three: observe only, again in swapped order
        second.observe(turn('thread2-msg3'))
        first.observe(turn('thread1-msg3'))

        # (message that must not leak, agent whose history is inspected)
        leak_checks = (
            ('thread1-msg1', second),
            ('thread2-msg1', first),
            ('thread1-msg2', second),
            ('thread2-msg2', first),
        )
        for foreign_text, holder in leak_checks:
            self.assertNotIn(foreign_text, holder.history.get_history_str())
Example #5
0
    def test_batch_act(self):
        """
        Make sure batch act calls the right step.

        A vectorized batch carrying ``labels`` must be answered by the training
        step ('Training {i}!'), and one carrying ``eval_labels`` by the
        evaluation step (text containing 'Evaluating {i}'). ``i`` is the
        observation's position in the batch, and there must be exactly one
        reply per observation.
        """
        agent = get_agent()

        # (text, label) pairs shared by the training and evaluation batches
        dialogues = [
            ("It's only a flesh wound.", 'Yield!'),
            ('The needs of the many outweigh...', 'The needs of the few.'),
            ('Hello there.', 'General Kenobi.'),
        ]

        def vectorized_batch(label_key):
            # Build one single-turn observation per dialogue under `label_key`
            # and vectorize it against its own freshly reset history.
            batch = []
            for text, label in dialogues:
                obs = Message({'text': text, label_key: [label], 'episode_done': True})
                agent.history.reset()
                agent.history.update_history(obs)
                batch.append(agent.vectorize(obs, agent.history))
            return batch

        obs_labs_vecs = vectorized_batch('labels')
        reply = agent.batch_act(obs_labs_vecs)
        # enumerate() alone would not notice missing or extra replies
        self.assertEqual(len(reply), len(obs_labs_vecs))
        for i, rep in enumerate(reply):
            self.assertEqual(rep['text'], 'Training {}!'.format(i))

        obs_elabs_vecs = vectorized_batch('eval_labels')
        reply = agent.batch_act(obs_elabs_vecs)
        self.assertEqual(len(reply), len(obs_elabs_vecs))
        for i, rep in enumerate(reply):
            self.assertIn('Evaluating {}'.format(i), rep['text'])
Example #6
0
    def test_batch_respond(self):
        """
        Tests batch_respond() in the base Agent class, where the agent provides a batch
        response to a batch of messages.

        A batch carrying ``labels`` must yield 'Training {i}!' strings and one
        carrying ``eval_labels`` strings containing 'Evaluating {i}', with ``i``
        the message's position and exactly one response per message.
        """
        agent = get_agent()

        # (text, label) pairs shared by the training and evaluation batches
        dialogues = [
            ("It's only a flesh wound.", 'Yield!'),
            ('The needs of the many outweigh...', 'The needs of the few.'),
            ('Hello there.', 'General Kenobi.'),
        ]

        def make_batch(label_key):
            # one single-turn message per dialogue, labelled under `label_key`
            return [
                Message({'text': text, label_key: [label], 'episode_done': True})
                for text, label in dialogues
            ]

        obs_labs = make_batch('labels')
        response = agent.batch_respond(obs_labs)
        # without this, an empty/short response would pass the loop vacuously
        self.assertEqual(len(response), len(obs_labs))
        for i, resp in enumerate(response):
            self.assertEqual(resp, 'Training {}!'.format(i))

        obs_elabs = make_batch('eval_labels')
        response = agent.batch_respond(obs_elabs)
        self.assertEqual(len(response), len(obs_elabs))
        for i, resp in enumerate(response):
            self.assertIn('Evaluating {}'.format(i), resp)
Example #7
0
    def test_batchify(self):
        """
        Make sure the batchify function sets up the right fields.

        The same checks run for a batch carrying training ``labels`` and for one
        carrying ``eval_labels``. ``agent.is_valid`` is replaced by small test
        doubles to control which observations batchify() keeps. A final section
        checks batching of label candidates on this ``rank_candidates=True``
        agent.
        """
        agent = get_agent(rank_candidates=True)
        obs_labs = [
            Message({
                'text': "It's only a flesh wound.",
                'labels': ['Yield!'],
                'episode_done': True,
            }),
            Message({
                'text': 'The needs of the many outweigh...',
                'labels': ['The needs of the few.'],
                'episode_done': True,
            }),
            Message({
                'text': 'Hello there.',
                'labels': ['General Kenobi.'],
                'episode_done': True,
            }),
        ]
        obs_elabs = [
            Message({
                'text': "It's only a flesh wound.",
                'eval_labels': ['Yield!'],
                'episode_done': True,
            }),
            Message({
                'text': 'The needs of the many outweigh...',
                'eval_labels': ['The needs of the few.'],
                'episode_done': True,
            }),
            Message({
                'text': 'Hello there.',
                'eval_labels': ['General Kenobi.'],
                'episode_done': True,
            }),
        ]
        for obs_batch in (obs_labs, obs_elabs):
            lab_key = 'labels' if 'labels' in obs_batch[0] else 'eval_labels'

            # nothing has been vectorized yet so should be empty
            # (on the second pass agent.is_valid is still the last test double
            # installed below, which also rejects un-vectorized observations)
            batch = agent.batchify(obs_batch)
            self.assertIsNone(batch.text_vec)
            self.assertIsNone(batch.text_lengths)
            self.assertIsNone(batch.label_vec)
            self.assertIsNone(batch.label_lengths)
            self.assertIsNone(batch.labels)
            self.assertIsNone(batch.valid_indices)
            self.assertIsNone(batch.candidates)
            self.assertIsNone(batch.candidate_vecs)
            self.assertIsNone(batch.image)

            # Vectorize each observation against its own fresh history and
            # without start/end tokens. The expected vectors further down come
            # from the mock dictionary defined elsewhere: presumably the k-th
            # token of a string maps to index k and 0 is padding. TODO confirm
            # against MockDict.
            obs_vecs = []
            for o in obs_batch:
                agent.history.reset()
                agent.history.update_history(o)
                obs_vecs.append(
                    agent.vectorize(o, agent.history, add_start=False, add_end=False)
                )

            # is_valid should map to nothing
            def is_valid(obs):
                return False

            agent.is_valid = is_valid

            batch = agent.batchify(obs_batch)
            self.assertIsNone(batch.text_vec)
            self.assertIsNone(batch.text_lengths)
            self.assertIsNone(batch.label_vec)
            self.assertIsNone(batch.label_lengths)
            self.assertIsNone(batch.labels)
            self.assertIsNone(batch.valid_indices)
            self.assertIsNone(batch.candidates)
            self.assertIsNone(batch.candidate_vecs)
            self.assertIsNone(batch.image)

            # is_valid should check for text_vec
            def is_valid(obs):
                return 'text_vec' in obs

            agent.is_valid = is_valid

            batch = agent.batchify(obs_vecs)
            # which fields were filled vs should be empty?
            self.assertIsNotNone(batch.text_vec)
            self.assertIsNotNone(batch.text_lengths)
            self.assertIsNotNone(batch.label_vec)
            self.assertIsNotNone(batch.label_lengths)
            self.assertIsNotNone(batch.labels)
            self.assertIsNotNone(batch.valid_indices)
            self.assertIsNone(batch.candidates)
            self.assertIsNone(batch.candidate_vecs)
            self.assertIsNone(batch.image)

            # contents of certain fields (unsorted: original batch order,
            # right-padded to the longest sequence):
            self.assertEqual(
                batch.text_vec.tolist(),
                [[1, 2, 3, 4, 5, 0], [1, 2, 3, 4, 5, 6], [1, 2, 0, 0, 0, 0]],
            )
            self.assertEqual(batch.text_lengths, [5, 6, 2])
            self.assertEqual(
                batch.label_vec.tolist(),
                [[1, 0, 0, 0, 0], [1, 2, 3, 4, 5], [1, 2, 0, 0, 0]],
            )
            self.assertEqual(batch.label_lengths, [1, 5, 2])
            self.assertEqual(batch.labels, [o[lab_key][0] for o in obs_batch])
            self.assertEqual(list(batch.valid_indices), [0, 1, 2])

            # now sort the batch, make sure fields are in sorted order
            # (longest text first; valid_indices maps back to input positions)
            batch = agent.batchify(obs_vecs, sort=True)
            self.assertEqual(
                batch.text_vec.tolist(),
                [[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 0], [1, 2, 0, 0, 0, 0]],
            )
            self.assertEqual(batch.text_lengths, [6, 5, 2])
            self.assertEqual(
                batch.label_vec.tolist(),
                [[1, 2, 3, 4, 5], [1, 0, 0, 0, 0], [1, 2, 0, 0, 0]],
            )
            self.assertEqual(batch.label_lengths, [5, 1, 2])
            labs = [o[lab_key][0] for o in obs_batch]
            self.assertEqual(batch.labels, [labs[i] for i in [1, 0, 2]])
            self.assertEqual(list(batch.valid_indices), [1, 0, 2])

            # now sort just on ys: drop the text fields from shallow copies so
            # only the label vectors are left to sort by (longest label first)
            new_vecs = [vecs.copy() for vecs in obs_vecs]
            for vec in new_vecs:
                vec.pop('text')
                vec.pop('text_vec')

            def is_valid(obs):
                return 'labels_vec' in obs or 'eval_labels_vec' in obs

            agent.is_valid = is_valid

            batch = agent.batchify(new_vecs, sort=True)
            self.assertIsNone(batch.text_vec)
            self.assertIsNone(batch.text_lengths)
            self.assertIsNotNone(batch.label_vec)
            self.assertIsNotNone(batch.label_lengths)
            self.assertEqual(
                batch.label_vec.tolist(),
                [[1, 2, 3, 4, 5], [1, 2, 0, 0, 0], [1, 0, 0, 0, 0]],
            )
            self.assertEqual(batch.label_lengths, [5, 2, 1])
            labs = [o[lab_key][0] for o in new_vecs]
            self.assertEqual(batch.labels, [labs[i] for i in [1, 2, 0]])
            self.assertEqual(list(batch.valid_indices), [1, 2, 0])

            # test is_valid: only the third observation (text length 2)
            # passes, so the batch must contain just that one entry
            def is_valid(obs):
                return 'text_vec' in obs and len(obs['text_vec']) < 3

            agent.is_valid = is_valid

            batch = agent.batchify(obs_vecs)
            self.assertEqual(batch.text_vec.tolist(), [[1, 2]])
            self.assertEqual(batch.text_lengths, [2])
            self.assertEqual(batch.label_vec.tolist(), [[1, 2]])
            self.assertEqual(batch.label_lengths, [2])
            # one (first) label per kept observation, as in the checks above
            self.assertEqual(batch.labels, [obs_batch[2][lab_key][0]])
            self.assertEqual(list(batch.valid_indices), [2])

        # label candidates only: observations without text or labels
        agent.history.reset()
        obs_cands = [
            agent.vectorize(
                Message({'label_candidates': ['A', 'B', 'C']}), agent.history
            ),
            agent.vectorize(
                Message({'label_candidates': ['1', '2', '5', '3', 'Sir']}),
                agent.history,
            ),
            agent.vectorize(
                Message({'label_candidates': ['Do', 'Re', 'Mi']}), agent.history
            ),
            agent.vectorize(
                Message({'label_candidates': ['Fa', 'So', 'La', 'Ti']}),
                agent.history,
            ),
        ]

        # is_valid should check for label candidates vecs
        def is_valid(obs):
            return 'label_candidates_vecs' in obs

        agent.is_valid = is_valid

        batch = agent.batchify(obs_cands)
        self.assertTrue(agent.rank_candidates, 'Agent not set up to rank.')
        self.assertIsNone(batch.text_vec)
        self.assertIsNone(batch.text_lengths)
        self.assertIsNone(batch.label_vec)
        self.assertIsNone(batch.label_lengths)
        self.assertIsNone(batch.labels)
        self.assertIsNotNone(batch.valid_indices)
        self.assertIsNotNone(batch.candidates)
        self.assertIsNotNone(batch.candidate_vecs)
        self.assertEqual(list(batch.valid_indices), [0, 1, 2, 3])
        self.assertEqual(batch.candidates,
                         [o['label_candidates'] for o in obs_cands])
        # one list of candidate vectors per observation, one vector per candidate
        self.assertEqual(len(batch.candidate_vecs), len(obs_cands))
        for i, cs in enumerate(batch.candidate_vecs):
            self.assertEqual(len(cs), len(obs_cands[i]['label_candidates']))
Example #8
0
    def test_vectorize(self):
        """
        Test the vectorization of observations.

        Make sure they do not recompute results, and respect the different param
        options.

        For both a ``labels`` and an ``eval_labels`` observation this checks
        every add_start/add_end combination, the reuse and truncation of an
        existing ``text_vec``, and finally that ``split_lines`` stores one
        history vector per line without re-adding text on a second vectorize.
        """
        agent = get_agent()
        obs_labs = Message({
            'text': 'No. Try not.',
            'labels': ['Do.', 'Do not.'],
            'episode_done': True
        })
        obs_elabs = Message({
            'text': 'No. Try not.',
            'eval_labels': ['Do.', 'Do not.'],
            'episode_done': True,
        })

        for obs in (obs_labs, obs_elabs):
            lab_key = 'labels' if 'labels' in obs else 'eval_labels'
            lab_vec = lab_key + '_vec'
            # key holding the label that was picked for vectorization
            lab_chc = lab_key + '_choice'

            inp = obs.copy()
            # test add_start=True, add_end=True
            agent.history.reset()
            agent.history.update_history(inp)
            out = agent.vectorize(inp, agent.history, add_start=True, add_end=True)
            # expected vectors depend on the mock dictionary defined elsewhere
            # (presumably token k of a string maps to index k) -- TODO confirm
            self.assertEqual(out['text_vec'].tolist(), [1, 2, 3])
            # note that label could be either label above
            self.assertEqual(out[lab_vec][0].item(), MockDict.BEG_IDX)
            self.assertEqual(out[lab_vec][1].item(), 1)
            self.assertEqual(out[lab_vec][-1].item(), MockDict.END_IDX)
            self.assertEqual(out[lab_chc][:2], 'Do')

            # test add_start=True, add_end=False
            inp = obs.copy()
            out = agent.vectorize(inp, agent.history, add_start=True, add_end=False)
            self.assertEqual(out['text_vec'].tolist(), [1, 2, 3])
            # note that label could be either label above
            self.assertEqual(out[lab_vec][0].item(), MockDict.BEG_IDX)
            self.assertNotEqual(out[lab_vec][-1].item(), MockDict.END_IDX)
            self.assertEqual(out[lab_chc][:2], 'Do')

            # test add_start=False, add_end=True
            inp = obs.copy()
            out = agent.vectorize(inp, agent.history, add_start=False, add_end=True)
            self.assertEqual(out['text_vec'].tolist(), [1, 2, 3])
            # note that label could be either label above
            self.assertNotEqual(out[lab_vec][0].item(), MockDict.BEG_IDX)
            self.assertEqual(out[lab_vec][-1].item(), MockDict.END_IDX)
            self.assertEqual(out[lab_chc][:2], 'Do')

            # test add_start=False, add_end=False
            inp = obs.copy()
            out = agent.vectorize(inp, agent.history, add_start=False, add_end=False)
            self.assertEqual(out['text_vec'].tolist(), [1, 2, 3])
            # note that label could be either label above
            self.assertNotEqual(out[lab_vec][0].item(), MockDict.BEG_IDX)
            self.assertNotEqual(out[lab_vec][-1].item(), MockDict.END_IDX)
            self.assertEqual(out[lab_chc][:2], 'Do')

            # test caching of tensors
            out_again = agent.vectorize(out, agent.history)
            # should have cached result from before (same tensor object)
            self.assertIs(out['text_vec'], out_again['text_vec'])
            self.assertEqual(out['text_vec'].tolist(), [1, 2, 3])
            # next: should truncate cached result, keeping only the last token
            prev_vec = out['text_vec']
            out_again = agent.vectorize(out, agent.history, text_truncate=1)
            self.assertIsNot(prev_vec, out_again['text_vec'])
            self.assertEqual(out['text_vec'].tolist(), [3])

        # test split_lines: the history keeps one vector per line of text
        agent = get_agent(split_lines=True)
        obs = Message({
            'text': 'Hello.\nMy name is Inogo Montoya.\n'
            'You killed my father.\nPrepare to die.',
            'episode_done': True,
        })
        agent.history.update_history(obs)
        vecs = agent.history.get_history_vec_list()
        self.assertEqual(vecs, [[1], [1, 2, 3, 4, 5], [1, 2, 3, 4], [1, 2, 3]])

        # check cache: vectorizing again must leave the history unchanged
        agent.vectorize(obs, agent.history)
        vecs = agent.history.get_history_vec_list()
        self.assertEqual(vecs, [[1], [1, 2, 3, 4, 5], [1, 2, 3, 4], [1, 2, 3]])