示例#1
0
def get_agent(**kwargs):
    """Build a TorchAgent from parser defaults plus the given overrides.

    :param kwargs: option overrides forwarded to ``parser.set_params``.
        ``no_cuda`` is set to True unless the caller supplies it.
    :returns: a ``TorchAgent`` constructed from the parsed options.
    """
    from parlai.core.params import ParlaiParser

    # Only fill in no_cuda when the caller did not choose a value.
    kwargs.setdefault('no_cuda', True)

    arg_parser = ParlaiParser()
    TorchAgent.add_cmdline_args(arg_parser)
    arg_parser.set_params(**kwargs)
    return TorchAgent(arg_parser.parse_args(print_args=False))
示例#2
0
    def test_maintain_dialog_history(self):
        """Check that maintain_dialog_history fills and truncates history.

        Feeds the same observation to an agent configured with ``truncate=5``
        twice and asserts on the ``dialog``, ``episode_done`` and ``labels``
        entries of ``agent.history``. The expected token ids come from
        MockDict, which is defined outside this block.
        """
        try:
            from parlai.core.torch_agent import TorchAgent
        except ImportError as e:
            if 'pytorch' in e.msg:
                print(
                    'Skipping TestTorchAgent.test_maintain_dialog_history, no pytorch.'
                )
                return
            # Any other import failure is a real error: surface it here
            # instead of failing later on an unbound TorchAgent name.
            raise

        from parlai.core.params import ParlaiParser
        parser = ParlaiParser()
        TorchAgent.add_cmdline_args(parser)
        parser.set_params(no_cuda=True, truncate=5)
        opt = parser.parse_args(print_args=False)
        mdict = MockDict()

        shared = {'opt': opt, 'dict': mdict}
        agent = TorchAgent(opt, shared)

        observation = {
            "text": "What is a painting?",
            "labels": ["Paint on a canvas."],
            "episode_done": False,
        }

        agent.maintain_dialog_history(observation)

        for key in ('dialog', 'episode_done', 'labels'):
            self.assertIn(key, agent.history,
                          "Failed initializing self.history.")
        self.assertEqual(
            list(agent.history['dialog']), [7, 8, 9],
            "Failed adding vectorized text to dialog.")
        self.assertFalse(agent.history['episode_done'],
                         "Failed to properly store episode_done field.")
        self.assertEqual(agent.history['labels'], observation['labels'],
                         "Failed saving labels.")

        # Second call: the expected history holds only five token ids.
        observation['text_vec'] = agent.maintain_dialog_history(observation)
        self.assertEqual(
            list(agent.history['dialog']), [8, 9, 7, 8, 9],
            "Failed adding vectorized text to dialog.")
示例#3
0
    def test_maintain_dialog_history(self):
        """Check that maintain_dialog_history fills and truncates history.

        Builds the agent from a hand-written option dict with ``truncate``
        set to 5, feeds the same observation twice and asserts on the
        ``dialog``, ``episode_done`` and ``labels`` entries of
        ``agent.history``. The expected token ids come from MockDict, which
        is defined outside this block.
        """
        try:
            from parlai.core.torch_agent import TorchAgent
        except ImportError as e:
            if 'pytorch' in e.msg:
                print(
                    'Skipping TestTorchAgent.test_maintain_dialog_history, no pytorch.'
                )
                return
            # Any other import failure is a real error: surface it here
            # instead of failing later on an unbound TorchAgent name.
            raise

        opt = {
            'no_cuda': True,
            'truncate': 5,
            'history_dialog': 10,
            'history_replies': 'label_else_model',
        }
        mdict = MockDict()

        shared = {'opt': opt, 'dict': mdict}
        agent = TorchAgent(opt, shared)

        observation = {
            "text": "What is a painting?",
            "labels": ["Paint on a canvas."],
            "episode_done": False,
        }

        agent.maintain_dialog_history(observation)

        for key in ('dialog', 'episode_done', 'labels'):
            self.assertIn(key, agent.history,
                          "Failed initializing self.history.")
        self.assertEqual(
            list(agent.history['dialog']), [7, 8, 9],
            "Failed adding vectorized text to dialog.")
        self.assertFalse(agent.history['episode_done'],
                         "Failed to properly store episode_done field.")
        self.assertEqual(agent.history['labels'], observation['labels'],
                         "Failed saving labels.")

        # Second call: the expected history holds only five token ids.
        observation['text_vec'] = agent.maintain_dialog_history(observation)
        self.assertEqual(
            list(agent.history['dialog']), [8, 9, 7, 8, 9],
            "Failed adding vectorized text to dialog.")
    def test_vectorize(self):
        """Check that vectorize adds text and label vectors to an observation.

        Covers both the ``labels`` and the ``eval_labels`` field. The expected
        token ids come from MockDict, which is defined outside this block.
        """
        opt = {
            'no_cuda': True,
            'history_tokens': 10000,
            'history_dialog': 10,
            'history_replies': 'label_else_model',
        }
        # Named mdict so the builtin ``dict`` is not shadowed.
        mdict = MockDict()

        shared = {'opt': opt, 'dict': mdict}
        agent = TorchAgent(opt, shared)
        observation = {
            "text": "What does the dog do?",
            "labels": ["The dog jumps over the cat."],
        }

        obs_vec = agent.vectorize(observation)
        self.assertIn(
            'text_vec', obs_vec,
            "Field 'text_vec' missing from vectorized observation")
        self.assertEqual(obs_vec['text_vec'].numpy().tolist(), [1, 3, 5],
                         "Vectorized text is incorrect.")
        self.assertIn(
            'labels_vec', obs_vec,
            "Field 'labels_vec' missing from vectorized observation")
        self.assertEqual(
            obs_vec['labels_vec'][0].numpy().tolist(),
            [1, 3, 5, mdict.END_IDX],
            "Vectorized label is incorrect.")

        observation = {
            "text": "What does the dog do?",
            "eval_labels": ["The dog jumps over the cat."],
        }

        obs_vec = agent.vectorize(observation)

        self.assertIn(
            'eval_labels_vec', obs_vec,
            "Field 'eval_labels_vec' missing from vectorized observation")
        self.assertEqual(
            obs_vec['eval_labels_vec'][0].numpy().tolist(),
            [1, 3, 5, mdict.END_IDX],
            "Vectorized label is incorrect.")
    def test_maintain_dialog_history(self):
        """Check that maintain_dialog_history fills and truncates history.

        Builds the agent with ``history_tokens`` set to 5, feeds the same
        observation twice and asserts on the ``dialog``, ``episode_done`` and
        ``labels`` entries of ``agent.history``. The expected token ids come
        from MockDict, which is defined outside this block.
        """
        opt = {
            'no_cuda': True,
            'history_tokens': 5,
            'history_dialog': 10,
            'history_replies': 'label_else_model',
        }
        # Named mdict so the builtin ``dict`` is not shadowed.
        mdict = MockDict()

        shared = {'opt': opt, 'dict': mdict}
        agent = TorchAgent(opt, shared)

        observation = {
            "text": "What is a painting?",
            "labels": ["Paint on a canvas."],
            "episode_done": False,
        }

        agent.maintain_dialog_history(observation)

        for key in ('dialog', 'episode_done', 'labels'):
            self.assertIn(key, agent.history,
                          "Failed initializing self.history.")
        self.assertEqual(
            list(agent.history['dialog']), [1, 3, 5],
            "Failed adding vectorized text to dialog.")
        self.assertFalse(agent.history['episode_done'],
                         "Failed to properly store episode_done field.")
        self.assertEqual(agent.history['labels'], observation['labels'],
                         "Failed saving labels.")

        # Second call: the expected history holds only five token ids.
        observation['text_vec'] = agent.maintain_dialog_history(observation)
        self.assertEqual(
            list(agent.history['dialog']), [3, 5, 1, 3, 5],
            "Failed adding vectorized text to dialog.")
示例#6
0
    def test_vectorize(self):
        """Check that vectorize adds the expected vector fields.

        Exercises every combination of ``add_start`` / ``add_end`` on the
        ``labels`` field, then the ``eval_labels`` field with the default
        arguments and with two ``truncate`` values. The expected token ids
        come from MockDict, which is defined outside this block.
        """
        try:
            from parlai.core.torch_agent import TorchAgent
        except ImportError as e:
            if 'pytorch' in e.msg:
                print('Skipping TestTorchAgent.test_vectorize, no pytorch.')
                return
            # Any other import failure is a real error: surface it here
            # instead of failing later on an unbound TorchAgent name.
            raise

        opt = {
            'no_cuda': True,
            'truncate': 10000,
            'history_dialog': 10,
            'history_replies': 'label_else_model',
        }
        mdict = MockDict()

        shared = {'opt': opt, 'dict': mdict}
        agent = TorchAgent(opt, shared)
        observation = {
            "text": "What does the dog do?",
            "labels": ["The dog jumps over the cat."],
        }

        # add start and end
        obs_vec = agent.vectorize(observation, add_start=True, add_end=True)
        self.assertIn(
            'text_vec', obs_vec,
            "Field 'text_vec' missing from vectorized observation")
        self.assertEqual(obs_vec['text_vec'].numpy().tolist(), [7, 8, 9],
                         "Vectorized text is incorrect.")
        self.assertIn(
            'labels_vec', obs_vec,
            "Field 'labels_vec' missing from vectorized observation")
        self.assertEqual(
            obs_vec['labels_vec'].numpy().tolist(),
            [mdict.START_IDX, 7, 8, 9, mdict.END_IDX],
            "Vectorized label is incorrect.")
        # no start, add end
        obs_vec = agent.vectorize(observation, add_start=False, add_end=True)
        self.assertEqual(
            obs_vec['labels_vec'].numpy().tolist(),
            [7, 8, 9, mdict.END_IDX],
            "Vectorized label is incorrect.")
        # add start, no end
        obs_vec = agent.vectorize(observation, add_start=True, add_end=False)
        self.assertEqual(
            obs_vec['labels_vec'].numpy().tolist(),
            [mdict.START_IDX, 7, 8, 9],
            "Vectorized label is incorrect.")
        # no start, no end
        obs_vec = agent.vectorize(observation, add_start=False, add_end=False)
        self.assertEqual(obs_vec['labels_vec'].numpy().tolist(), [7, 8, 9],
                         "Vectorized label is incorrect.")

        observation = {
            "text": "What does the dog do?",
            "eval_labels": ["The dog jumps over the cat."],
        }

        # eval_labels
        obs_vec = agent.vectorize(observation)
        self.assertIn(
            'eval_labels_vec', obs_vec,
            "Field 'eval_labels_vec' missing from vectorized observation")
        self.assertEqual(
            obs_vec['eval_labels_vec'].numpy().tolist(),
            [mdict.START_IDX, 7, 8, 9, mdict.END_IDX],
            "Vectorized label is incorrect.")
        # truncate shorter than the label: only the last three ids expected
        obs_vec = agent.vectorize(observation, truncate=3)
        self.assertIn(
            'eval_labels_vec', obs_vec,
            "Field 'eval_labels_vec' missing from vectorized observation")
        self.assertEqual(
            obs_vec['eval_labels_vec'].numpy().tolist(),
            [8, 9, mdict.END_IDX],
            "Vectorized label is incorrect.")

        # truncate longer than the label: the full vector is expected
        obs_vec = agent.vectorize(observation, truncate=10)
        self.assertIn(
            'eval_labels_vec', obs_vec,
            "Field 'eval_labels_vec' missing from vectorized observation")
        self.assertEqual(
            obs_vec['eval_labels_vec'].numpy().tolist(),
            [mdict.START_IDX, 7, 8, 9, mdict.END_IDX],
            "Vectorized label is incorrect.")
示例#7
0
    def test_map_unmap(self):
        """Check batchify on sparse observations and match_batch on replies.

        Six observations are used, of which only indices 0 and 3 carry text
        and labels. The test asserts on the batch fields produced by
        ``batchify`` (for both ``labels`` and ``eval_labels``) and then on
        the list returned by ``match_batch``. The expected token ids come
        from MockDict, which is defined outside this block.
        """
        try:
            from parlai.core.torch_agent import TorchAgent, Output
        except ImportError as e:
            if 'pytorch' in e.msg:
                print('Skipping TestTorchAgent.test_map_unmap, no pytorch.')
                return
            # Any other import failure is a real error: surface it here
            # instead of failing later on an unbound TorchAgent name.
            raise

        observations = [
            {"text": "What is a painting?", "labels": ["Paint on a canvas."]},
            {},
            {},
            {"text": "What is a painting?", "labels": ["Paint on a canvas."]},
            {},
            {},
        ]

        opt = {
            'no_cuda': True,
            'truncate': 10000,
            'history_dialog': 10,
            'history_replies': 'label_else_model',
        }
        mdict = MockDict()

        shared = {'opt': opt, 'dict': mdict}
        agent = TorchAgent(opt, shared)

        vec_observations = [agent.vectorize(obs) for obs in observations]

        batch = agent.batchify(vec_observations)

        expected_labels = [
            [mdict.START_IDX, 7, 8, 9, mdict.END_IDX],
            [mdict.START_IDX, 7, 8, 9, mdict.END_IDX],
        ]

        self.assertIsNotNone(batch.text_vec, "Missing 'text_vecs' field.")
        self.assertEqual(
            batch.text_vec.numpy().tolist(), [[7, 8, 9], [7, 8, 9]],
            "Incorrectly vectorized text field of obs_batch.")
        self.assertIsNotNone(batch.label_vec, "Missing 'label_vec' field.")
        self.assertEqual(
            batch.label_vec.numpy().tolist(), expected_labels,
            "Incorrectly vectorized text field of obs_batch.")
        self.assertEqual(
            batch.labels, ["Paint on a canvas.", "Paint on a canvas."],
            "Doesn't return correct labels: " + str(batch.labels))
        true_i = [0, 3]
        self.assertTrue(
            all(batch.valid_indices[i] == true_i[i] for i in range(2)),
            "Returns incorrect indices of valid observations.")

        observations = [
            {"text": "What is a painting?",
             "eval_labels": ["Paint on a canvas."]},
            {},
            {},
            {"text": "What is a painting?",
             "eval_labels": ["Paint on a canvas."]},
            {},
            {},
        ]

        vec_observations = [agent.vectorize(obs) for obs in observations]

        batch = agent.batchify(vec_observations)

        self.assertIsNotNone(batch.label_vec,
                             "Missing 'eval_label_vec' field.")
        self.assertEqual(
            batch.label_vec.numpy().tolist(), expected_labels,
            "Incorrectly vectorized text field of obs_batch.")

        batch_reply = [{} for _ in range(6)]
        predictions = ["Oil on a canvas.", "Oil on a canvas."]
        output = Output(predictions, None)
        # Build the expectation from fresh dicts. A shallow list copy of
        # batch_reply would share its dict objects, so filling in the
        # expected text would also pre-populate the input to match_batch.
        expected_unmapped = [{} for _ in range(6)]
        expected_unmapped[0]["text"] = "Oil on a canvas."
        expected_unmapped[3]["text"] = "Oil on a canvas."
        self.assertEqual(
            agent.match_batch(batch_reply, batch.valid_indices, output),
            expected_unmapped,
            "Unmapped predictions do not match expected results.")
示例#8
0
    def test_vectorize(self):
        """Check that vectorize adds the expected vector fields.

        Exercises every combination of ``add_start`` / ``add_end`` on the
        ``labels`` field, then the ``eval_labels`` field with the default
        arguments and with two ``truncate`` values. The expected token ids
        come from MockDict, which is defined outside this block.
        """
        try:
            from parlai.core.torch_agent import TorchAgent
        except ImportError as e:
            if 'pytorch' in e.msg:
                print('Skipping TestTorchAgent.test_vectorize, no pytorch.')
                return
            # Any other import failure is a real error: surface it here
            # instead of failing later on an unbound TorchAgent name.
            raise

        from parlai.core.params import ParlaiParser
        parser = ParlaiParser()
        TorchAgent.add_cmdline_args(parser)
        parser.set_params(no_cuda=True)
        opt = parser.parse_args(print_args=False)
        mdict = MockDict()

        shared = {'opt': opt, 'dict': mdict}
        agent = TorchAgent(opt, shared)
        observation = {
            "text": "What does the dog do?",
            "labels": ["The dog jumps over the cat."],
        }

        # add start and end
        obs_vec = agent.vectorize(observation, add_start=True, add_end=True)
        self.assertIn(
            'text_vec', obs_vec,
            "Field 'text_vec' missing from vectorized observation")
        self.assertEqual(obs_vec['text_vec'].numpy().tolist(), [7, 8, 9],
                         "Vectorized text is incorrect.")
        self.assertIn(
            'labels_vec', obs_vec,
            "Field 'labels_vec' missing from vectorized observation")
        self.assertEqual(
            obs_vec['labels_vec'].numpy().tolist(),
            [mdict.START_IDX, 7, 8, 9, mdict.END_IDX],
            "Vectorized label is incorrect.")
        # no start, add end
        obs_vec = agent.vectorize(observation, add_start=False, add_end=True)
        self.assertEqual(
            obs_vec['labels_vec'].numpy().tolist(),
            [7, 8, 9, mdict.END_IDX],
            "Vectorized label is incorrect.")
        # add start, no end
        obs_vec = agent.vectorize(observation, add_start=True, add_end=False)
        self.assertEqual(
            obs_vec['labels_vec'].numpy().tolist(),
            [mdict.START_IDX, 7, 8, 9],
            "Vectorized label is incorrect.")
        # no start, no end
        obs_vec = agent.vectorize(observation, add_start=False, add_end=False)
        self.assertEqual(obs_vec['labels_vec'].numpy().tolist(), [7, 8, 9],
                         "Vectorized label is incorrect.")

        observation = {
            "text": "What does the dog do?",
            "eval_labels": ["The dog jumps over the cat."],
        }

        # eval_labels
        obs_vec = agent.vectorize(observation)
        self.assertIn(
            'eval_labels_vec', obs_vec,
            "Field 'eval_labels_vec' missing from vectorized observation")
        self.assertEqual(
            obs_vec['eval_labels_vec'].numpy().tolist(),
            [mdict.START_IDX, 7, 8, 9, mdict.END_IDX],
            "Vectorized label is incorrect.")
        # truncate shorter than the label: only the first two ids expected
        obs_vec = agent.vectorize(observation, truncate=2)
        self.assertIn(
            'eval_labels_vec', obs_vec,
            "Field 'eval_labels_vec' missing from vectorized observation")
        self.assertEqual(
            obs_vec['eval_labels_vec'].numpy().tolist(),
            [mdict.START_IDX, 7],
            "Vectorized label is incorrect: " +
            str(obs_vec['eval_labels_vec']))

        # truncate longer than the label: the full vector is expected
        obs_vec = agent.vectorize(observation, truncate=10)
        self.assertIn(
            'eval_labels_vec', obs_vec,
            "Field 'eval_labels_vec' missing from vectorized observation")
        self.assertEqual(
            obs_vec['eval_labels_vec'].numpy().tolist(),
            [mdict.START_IDX, 7, 8, 9, mdict.END_IDX],
            "Vectorized label is incorrect.")
    def test_map_unmap(self):
        """Check map_valid on sparse observations and unmap_valid on replies.

        Six observations are used, of which only indices 0 and 3 carry text
        and labels. The test asserts on the 4-tuple returned by ``map_valid``
        (for both ``labels`` and ``eval_labels``) and then on the list
        returned by ``unmap_valid``. The expected token ids come from
        MockDict, which is defined outside this block.
        """
        observations = [
            {"text": "What is a painting?", "labels": ["Paint on a canvas."]},
            {},
            {},
            {"text": "What is a painting?", "labels": ["Paint on a canvas."]},
            {},
            {},
        ]

        opt = {
            'no_cuda': True,
            'history_tokens': 10000,
            'history_dialog': 10,
            'history_replies': 'label_else_model',
        }
        # Named mdict so the builtin ``dict`` is not shadowed.
        mdict = MockDict()

        shared = {'opt': opt, 'dict': mdict}
        agent = TorchAgent(opt, shared)

        vec_observations = [agent.vectorize(obs) for obs in observations]

        mapped_valid = agent.map_valid(vec_observations)

        text_vecs, label_vecs, labels, valid_inds = mapped_valid

        self.assertIsNotNone(text_vecs, "Missing 'text_vecs' field.")
        self.assertEqual(text_vecs.numpy().tolist(), [[1, 3, 5], [1, 3, 5]],
                         "Incorrectly vectorized text field of obs_batch.")
        self.assertIsNotNone(label_vecs, "Missing 'label_vec' field.")
        self.assertEqual(
            label_vecs.numpy().tolist(), [[1, 3, 5, 2], [1, 3, 5, 2]],
            "Incorrectly vectorized text field of obs_batch.")
        self.assertEqual(labels, ["Paint on a canvas.", "Paint on a canvas."],
                         "Doesn't return correct labels.")
        self.assertEqual(valid_inds, [0, 3],
                         "Returns incorrect indices of valid observations.")

        observations = [
            {"text": "What is a painting?",
             "eval_labels": ["Paint on a canvas."]},
            {},
            {},
            {"text": "What is a painting?",
             "eval_labels": ["Paint on a canvas."]},
            {},
            {},
        ]

        vec_observations = [agent.vectorize(obs) for obs in observations]

        mapped_valid = agent.map_valid(vec_observations)

        text_vecs, label_vecs, labels, valid_inds = mapped_valid

        self.assertIsNotNone(label_vecs, "Missing 'label_vec' field.")
        self.assertEqual(
            label_vecs.numpy().tolist(), [[1, 3, 5, 2], [1, 3, 5, 2]],
            "Incorrectly vectorized text field of obs_batch.")

        predictions = ["Oil on a canvas.", "Oil on a canvas."]
        expected_unmapped = [
            "Oil on a canvas.", None, None, "Oil on a canvas.", None, None
        ]
        self.assertEqual(
            agent.unmap_valid(predictions, valid_inds, 6), expected_unmapped,
            "Unmapped predictions do not match expected results.")
示例#10
0
    def test_map_unmap(self):
        """Check map_valid on sparse observations and unmap_valid on replies.

        Six observations are used, of which only indices 0 and 3 carry text
        and labels. The test asserts on the 5-tuple returned by ``map_valid``
        (for both ``labels`` and ``eval_labels``) and then on the list
        returned by ``unmap_valid``. The expected token ids come from
        MockDict, which is defined outside this block.
        """
        try:
            from parlai.core.torch_agent import TorchAgent
        except ImportError as e:
            if 'pytorch' in e.msg:
                print('Skipping TestTorchAgent.test_map_unmap, no pytorch.')
                return
            # Any other import failure is a real error: surface it here
            # instead of failing later on an unbound TorchAgent name.
            raise

        observations = [
            {"text": "What is a painting?", "labels": ["Paint on a canvas."]},
            {},
            {},
            {"text": "What is a painting?", "labels": ["Paint on a canvas."]},
            {},
            {},
        ]

        opt = {
            'no_cuda': True,
            'truncate': 10000,
            'history_dialog': 10,
            'history_replies': 'label_else_model',
        }
        mdict = MockDict()

        shared = {'opt': opt, 'dict': mdict}
        agent = TorchAgent(opt, shared)

        vec_observations = [agent.vectorize(obs) for obs in observations]

        mapped_valid = agent.map_valid(vec_observations)

        text_vecs, text_lengths, label_vecs, labels, valid_inds = mapped_valid

        expected_labels = [
            [mdict.START_IDX, 7, 8, 9, mdict.END_IDX],
            [mdict.START_IDX, 7, 8, 9, mdict.END_IDX],
        ]

        self.assertIsNotNone(text_vecs, "Missing 'text_vecs' field.")
        self.assertEqual(text_vecs.numpy().tolist(), [[7, 8, 9], [7, 8, 9]],
                         "Incorrectly vectorized text field of obs_batch.")
        self.assertEqual(text_lengths.numpy().tolist(), [3, 3],
                         "Incorrect text vector lengths returned.")
        self.assertIsNotNone(label_vecs, "Missing 'label_vec' field.")
        self.assertEqual(
            label_vecs.numpy().tolist(), expected_labels,
            "Incorrectly vectorized text field of obs_batch.")
        self.assertEqual(labels, ["Paint on a canvas.", "Paint on a canvas."],
                         "Doesn't return correct labels.")
        self.assertEqual(valid_inds, [0, 3],
                         "Returns incorrect indices of valid observations.")

        observations = [
            {"text": "What is a painting?",
             "eval_labels": ["Paint on a canvas."]},
            {},
            {},
            {"text": "What is a painting?",
             "eval_labels": ["Paint on a canvas."]},
            {},
            {},
        ]

        vec_observations = [agent.vectorize(obs) for obs in observations]

        mapped_valid = agent.map_valid(vec_observations)

        text_vecs, text_lengths, label_vecs, labels, valid_inds = mapped_valid

        self.assertIsNotNone(label_vecs,
                             "Missing 'eval_label_vec' field.")
        self.assertEqual(
            label_vecs.numpy().tolist(), expected_labels,
            "Incorrectly vectorized text field of obs_batch.")

        predictions = ["Oil on a canvas.", "Oil on a canvas."]
        expected_unmapped = [
            "Oil on a canvas.", None, None, "Oil on a canvas.", None, None
        ]
        self.assertEqual(
            agent.unmap_valid(predictions, valid_inds, 6), expected_unmapped,
            "Unmapped predictions do not match expected results.")