Esempio n. 1
0
    def test_interactive_mode(self):
        """Check single-example act() across batchsize / interactive_mode settings.

        Agents built with ``batchsize=1``, with ``interactive_mode=True``, or
        with both must answer a lone ``observe()``/``act()`` round trip with a
        reply containing ``'Evaluating 0'``, and so must copies created from
        ``agent.share()``.  Running the same round trip with ``batchsize=16``
        and ``interactive_mode=False`` is expected to raise ``RuntimeError``.
        """

        def check_single_act(actor, text):
            """Send one single-turn episode to ``actor`` and check its reply."""
            actor.observe(Message({'text': text, 'episode_done': True}))
            response = actor.act()
            self.assertIn(
                'Evaluating 0', response['text'], 'Incorrect output in single act()'
            )

        def check_agent_and_shared(agent):
            """Run the round trip on ``agent``, then on a copy built from its share."""
            check_single_act(agent, 'foo')
            check_single_act(create_agent_from_shared(agent.share()), 'bar')

        # both manually setting bs to 1 and interactive mode true
        check_agent_and_shared(get_agent(batchsize=1, interactive_mode=True))

        # now just bs 1
        check_agent_and_shared(get_agent(batchsize=1, interactive_mode=False))

        # now just interactive: this case needs its own agent with a batch size
        # above one, otherwise the batchsize=1 configuration is simply re-tested
        check_agent_and_shared(get_agent(batchsize=16, interactive_mode=True))

        # finally, the expected failure
        # NOTE(review): which of these calls raises is not visible from this
        # test alone, so the whole round trip stays inside the context manager.
        with self.assertRaises(RuntimeError):
            agent = get_agent(batchsize=16, interactive_mode=False)
            agent.observe(Message({'text': 'foo', 'episode_done': True}))
            agent.act()
            shared = create_agent_from_shared(agent.share())
            shared.observe(Message({'text': 'bar', 'episode_done': True}))
            shared.act()
Esempio n. 2
0
    def test_batch_act(self):
        """Verify that batch_act answers labelled and eval-labelled batches.

        A batch whose observations carry ``labels`` must come back with
        ``'Training <i>!'`` replies, while a batch carrying ``eval_labels``
        must come back with replies containing ``'Evaluating <i>'``, where
        ``<i>`` is the position of the observation within the batch.
        """
        agent = get_agent()

        # (text, single label) pairs shared by the training and eval batches
        exchanges = (
            ("It's only a flesh wound.", 'Yield!'),
            ('The needs of the many outweigh...', 'The needs of the few.'),
            ('Hello there.', 'General Kenobi.'),
        )

        def vectorized_batch(label_field):
            """Build the observations under ``label_field`` and vectorize them."""
            messages = [
                Message({'text': prompt, label_field: [answer], 'episode_done': True})
                for prompt, answer in exchanges
            ]
            vectorized = []
            for message in messages:
                # every observation is vectorized against a freshly reset history
                agent.history.reset()
                agent.history.update_history(message)
                vectorized.append(agent.vectorize(message, agent.history))
            return vectorized

        train_batch = vectorized_batch('labels')
        train_replies = agent.batch_act(train_batch)
        for position, _ in enumerate(train_batch):
            expected_text = 'Training {}!'.format(position)
            self.assertEqual(train_replies[position]['text'], expected_text)

        eval_batch = vectorized_batch('eval_labels')
        eval_replies = agent.batch_act(eval_batch)
        for position, _ in enumerate(eval_batch):
            expected_fragment = 'Evaluating {}'.format(position)
            self.assertIn(expected_fragment, eval_replies[position]['text'])
Esempio n. 3
0
    def test_mturk_racehistory(self):
        """Interleave two shared agents and check their histories stay separate.

        Two copies built from the same parent's ``share()`` alternate
        ``observe()`` and ``act()`` calls in varying order.  Afterwards neither
        copy's history string may contain text that was sent to the other copy.
        """
        parent = get_agent(batchsize=16, interactive_mode=True, echo=True)

        def hear(listener, text):
            """Deliver one mid-episode message to ``listener``."""
            listener.observe(Message({'text': text, 'episode_done': False}))

        # round 1: the second copy is only created after the first has observed
        share1 = create_agent_from_shared(parent.share())
        hear(share1, 'thread1-msg1')
        share2 = create_agent_from_shared(parent.share())
        hear(share2, 'thread2-msg1')
        share1.act()
        share2.act()

        # round 2: same observe order, act order flipped
        hear(share1, 'thread1-msg2')
        hear(share2, 'thread2-msg2')
        share2.act()
        share1.act()

        # round 3: observe order flipped, nobody acts
        hear(share2, 'thread2-msg3')
        hear(share1, 'thread1-msg3')

        # text addressed to one copy must not show up in the other's history
        leak_checks = (
            ('thread1-msg1', share2),
            ('thread2-msg1', share1),
            ('thread1-msg2', share2),
            ('thread2-msg2', share1),
        )
        for foreign_text, clone in leak_checks:
            self.assertNotIn(foreign_text, clone.history.get_history_str())
Esempio n. 4
0
    def test_batchify(self):
        """Check which Batch fields batchify fills, and in what order.

        For a labelled and an eval-labelled set of observations this covers:
        un-vectorized input, an ``is_valid`` that rejects everything, padded
        contents for vectorized input, ``sort=True`` ordering (with and without
        text vectors) and filtering through ``is_valid``.  A final section
        batches observations that only carry label candidates.
        """
        agent = get_agent(rank_candidates=True)

        # (text, single label) pairs used for both label flavours
        exchanges = (
            ('It\'s only a flesh wound.', 'Yield!'),
            ('The needs of the many outweigh...', 'The needs of the few.'),
            ('Hello there.', 'General Kenobi.'),
        )

        def build_messages(label_field):
            """Create one Message per exchange, storing the label under ``label_field``."""
            return [
                Message({'text': text, label_field: [reply], 'episode_done': True})
                for text, reply in exchanges
            ]

        obs_labs = build_messages('labels')
        obs_elabs = build_messages('eval_labels')

        every_field = (
            'text_vec',
            'text_lengths',
            'label_vec',
            'label_lengths',
            'labels',
            'valid_indices',
            'candidates',
            'candidate_vecs',
            'image',
        )

        def expect_none(batch, *fields):
            """Assert that each named attribute of ``batch`` is None."""
            for field in fields:
                self.assertIsNone(getattr(batch, field))

        def expect_filled(batch, *fields):
            """Assert that each named attribute of ``batch`` is not None."""
            for field in fields:
                self.assertIsNotNone(getattr(batch, field))

        # validity predicates swapped onto the agent at different stages below
        def reject_all(obs):
            return False

        def has_text_vec(obs):
            return 'text_vec' in obs

        def has_label_vec(obs):
            return 'labels_vec' in obs or 'eval_labels_vec' in obs

        def has_short_text_vec(obs):
            return 'text_vec' in obs and len(obs['text_vec']) < 3

        def has_candidate_vecs(obs):
            return 'label_candidates_vecs' in obs

        for obs_batch in (obs_labs, obs_elabs):
            label_field = 'labels' if 'labels' in obs_batch[0] else 'eval_labels'

            def first_labels(observations):
                """Return the first label of every observation."""
                return [o[label_field][0] for o in observations]

            # nothing has been vectorized yet so should be empty
            expect_none(agent.batchify(obs_batch), *every_field)

            obs_vecs = []
            for message in obs_batch:
                agent.history.reset()
                agent.history.update_history(message)
                vectorized = agent.vectorize(
                    message, agent.history, add_start=False, add_end=False
                )
                obs_vecs.append(vectorized)

            # is_valid should map to nothing
            agent.is_valid = reject_all
            expect_none(agent.batchify(obs_batch), *every_field)

            # is_valid should check for text_vec
            agent.is_valid = has_text_vec
            batch = agent.batchify(obs_vecs)

            # which fields were filled vs should be empty?
            expect_filled(
                batch,
                'text_vec',
                'text_lengths',
                'label_vec',
                'label_lengths',
                'labels',
                'valid_indices',
            )
            expect_none(batch, 'candidates', 'candidate_vecs', 'image')

            # contents of certain fields:
            self.assertEqual(
                batch.text_vec.tolist(),
                [[1, 2, 3, 4, 5, 0], [1, 2, 3, 4, 5, 6], [1, 2, 0, 0, 0, 0]],
            )
            self.assertEqual(batch.text_lengths, [5, 6, 2])
            self.assertEqual(
                batch.label_vec.tolist(),
                [[1, 0, 0, 0, 0], [1, 2, 3, 4, 5], [1, 2, 0, 0, 0]],
            )
            self.assertEqual(batch.label_lengths, [1, 5, 2])
            self.assertEqual(batch.labels, first_labels(obs_batch))
            self.assertEqual(list(batch.valid_indices), [0, 1, 2])

            # now sort the batch, make sure fields are in sorted order
            batch = agent.batchify(obs_vecs, sort=True)
            self.assertEqual(
                batch.text_vec.tolist(),
                [[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 0], [1, 2, 0, 0, 0, 0]],
            )
            self.assertEqual(batch.text_lengths, [6, 5, 2])
            self.assertEqual(
                batch.label_vec.tolist(),
                [[1, 2, 3, 4, 5], [1, 0, 0, 0, 0], [1, 2, 0, 0, 0]],
            )
            self.assertEqual(batch.label_lengths, [5, 1, 2])
            unsorted_labels = first_labels(obs_batch)
            self.assertEqual(batch.labels, [unsorted_labels[i] for i in [1, 0, 2]])
            self.assertEqual(list(batch.valid_indices), [1, 0, 2])

            # now sort just on ys: work on copies with the text fields removed
            label_only = []
            for vec in obs_vecs:
                stripped = vec.copy()
                stripped.pop('text')
                stripped.pop('text_vec')
                label_only.append(stripped)

            agent.is_valid = has_label_vec
            batch = agent.batchify(label_only, sort=True)
            expect_none(batch, 'text_vec', 'text_lengths')
            expect_filled(batch, 'label_vec', 'label_lengths')
            self.assertEqual(
                batch.label_vec.tolist(),
                [[1, 2, 3, 4, 5], [1, 2, 0, 0, 0], [1, 0, 0, 0, 0]],
            )
            self.assertEqual(batch.label_lengths, [5, 2, 1])
            unsorted_labels = first_labels(label_only)
            self.assertEqual(batch.labels, [unsorted_labels[i] for i in [1, 2, 0]])
            self.assertEqual(list(batch.valid_indices), [1, 2, 0])

            # test is_valid: only observations with a short text_vec survive
            agent.is_valid = has_short_text_vec
            batch = agent.batchify(obs_vecs)
            self.assertEqual(batch.text_vec.tolist(), [[1, 2]])
            self.assertEqual(batch.text_lengths, [2])
            self.assertEqual(batch.label_vec.tolist(), [[1, 2]])
            self.assertEqual(batch.label_lengths, [2])
            self.assertEqual(batch.labels, obs_batch[2][label_field])
            self.assertEqual(list(batch.valid_indices), [2])

        # observations that carry nothing but label candidates
        agent.history.reset()
        candidate_sets = (
            ['A', 'B', 'C'],
            ['1', '2', '5', '3', 'Sir'],
            ['Do', 'Re', 'Mi'],
            ['Fa', 'So', 'La', 'Ti'],
        )
        obs_cands = [
            agent.vectorize(Message({'label_candidates': cands}), agent.history)
            for cands in candidate_sets
        ]

        # is_valid should check for label candidates vecs
        agent.is_valid = has_candidate_vecs

        batch = agent.batchify(obs_cands)
        self.assertTrue(agent.rank_candidates, 'Agent not set up to rank.')
        expect_none(
            batch, 'text_vec', 'text_lengths', 'label_vec', 'label_lengths', 'labels'
        )
        expect_filled(batch, 'valid_indices', 'candidates', 'candidate_vecs')
        self.assertEqual(list(batch.valid_indices), [0, 1, 2, 3])
        self.assertEqual(batch.candidates, [o['label_candidates'] for o in obs_cands])
        self.assertEqual(len(batch.candidate_vecs), len(obs_cands))
        # lengths were just asserted equal, so pairing the two up loses nothing
        for cand_vecs, source_obs in zip(batch.candidate_vecs, obs_cands):
            self.assertEqual(len(cand_vecs), len(source_obs['label_candidates']))
Esempio n. 5
0
    def test_vectorize(self):
        """Check vectorize() options, its caching, and split_lines history.

        For an observation with ``labels`` and one with ``eval_labels``, every
        combination of ``add_start`` / ``add_end`` is checked against the first
        and last entries of the label vector.  Re-vectorizing an already
        vectorized observation must hand back the very same ``text_vec``
        object, whereas passing ``text_truncate=1`` must produce a new one.
        The last section checks the per-line history vectors of an agent
        created with ``split_lines=True``.
        """
        agent = get_agent()
        label_fields = ('labels', 'eval_labels')
        observations = tuple(
            Message(
                {
                    'text': 'No. Try not.',
                    field: ['Do.', 'Do not.'],
                    'episode_done': True,
                }
            )
            for field in label_fields
        )

        for label_field, obs in zip(label_fields, observations):
            vec_key = label_field + '_vec'
            choice_key = label_field + '_choice'

            # test add_start=True, add_end=True (the only variant that first
            # resets the history and feeds it the observation)
            first_input = obs.copy()
            agent.history.reset()
            agent.history.update_history(first_input)
            vectorized = agent.vectorize(
                first_input, agent.history, add_start=True, add_end=True
            )
            self.assertEqual(vectorized['text_vec'].tolist(), [1, 2, 3])
            # note that label could be either label above
            self.assertEqual(vectorized[vec_key][0].item(), MockDict.BEG_IDX)
            self.assertEqual(vectorized[vec_key][1].item(), 1)
            self.assertEqual(vectorized[vec_key][-1].item(), MockDict.END_IDX)
            self.assertEqual(vectorized[choice_key][:2], 'Do')

            # remaining add_start / add_end combinations, on fresh copies
            for add_start, add_end in ((True, False), (False, True), (False, False)):
                vectorized = agent.vectorize(
                    obs.copy(), agent.history, add_start=add_start, add_end=add_end
                )
                self.assertEqual(vectorized['text_vec'].tolist(), [1, 2, 3])
                # a boundary token must be present exactly when it was requested
                check_start = self.assertEqual if add_start else self.assertNotEqual
                check_end = self.assertEqual if add_end else self.assertNotEqual
                # note that label could be either label above
                check_start(vectorized[vec_key][0].item(), MockDict.BEG_IDX)
                check_end(vectorized[vec_key][-1].item(), MockDict.END_IDX)
                self.assertEqual(vectorized[choice_key][:2], 'Do')

            # test caching of tensors, using the last vectorized observation
            revectorized = agent.vectorize(vectorized, agent.history)
            # should have cached result from before
            self.assertIs(vectorized['text_vec'], revectorized['text_vec'])
            self.assertEqual(vectorized['text_vec'].tolist(), [1, 2, 3])
            # next: should truncate cached result
            cached_vec = vectorized['text_vec']
            revectorized = agent.vectorize(vectorized, agent.history, text_truncate=1)
            self.assertIsNot(cached_vec, revectorized['text_vec'])
            self.assertEqual(vectorized['text_vec'].tolist(), [3])

        # test split_lines
        agent = get_agent(split_lines=True)
        multiline_obs = Message(
            {
                'text': 'Hello.\nMy name is Inogo Montoya.\n'
                'You killed my father.\nPrepare to die.',
                'episode_done': True,
            }
        )
        expected_line_vecs = [[1], [1, 2, 3, 4, 5], [1, 2, 3, 4], [1, 2, 3]]
        agent.history.update_history(multiline_obs)
        self.assertEqual(agent.history.get_history_vec_list(), expected_line_vecs)

        # check cache: vectorizing afterwards must leave the history vectors alone
        agent.vectorize(multiline_obs, agent.history)
        self.assertEqual(agent.history.get_history_vec_list(), expected_line_vecs)