def test_map_unmap(self):
        """Check ``map_valid`` / ``unmap_valid`` on a sparsely filled batch.

        A batch of six observations is built in which only indices 0 and 3
        carry text and labels; the rest are empty dicts.  ``map_valid`` is
        expected to return text vectors, label vectors, label strings and
        the valid indices for just those two entries, both when the labels
        are given as ``labels`` and as ``eval_labels``.  ``unmap_valid`` is
        then expected to place two predictions back at indices 0 and 3 of a
        six-element list and fill every other slot with ``None``.
        """

        def make_observations(label_field):
            # Fresh dict objects on every call so that no observation is
            # shared between slots or between the two batches.
            return [
                {"text": "What is a painting?",
                 label_field: ["Paint on a canvas."]},
                {},
                {},
                {"text": "What is a painting?",
                 label_field: ["Paint on a canvas."]},
                {},
                {},
            ]

        opt = {
            'no_cuda': True,
            'history_tokens': 10000,
            'history_dialog': 10,
            'history_replies': 'label_else_model',
        }
        # Named mock_dict rather than ``dict`` to avoid shadowing the builtin.
        mock_dict = MockDict()
        agent = TorchAgent(opt, {'opt': opt, 'dict': mock_dict})

        expected_text_vecs = [[1, 3, 5], [1, 3, 5]]
        expected_label_vecs = [[1, 3, 5, 2], [1, 3, 5, 2]]

        # --- training-style observations ("labels") ---
        vec_observations = [
            agent.vectorize(obs) for obs in make_observations("labels")
        ]
        text_vecs, label_vecs, labels, valid_inds = agent.map_valid(
            vec_observations)

        self.assertIsNotNone(text_vecs, "Missing 'text_vecs' field.")
        self.assertEqual(text_vecs.numpy().tolist(), expected_text_vecs,
                         "Incorrectly vectorized text field of obs_batch.")
        self.assertIsNotNone(label_vecs, "Missing 'label_vec' field.")
        self.assertEqual(label_vecs.numpy().tolist(), expected_label_vecs,
                         "Incorrectly vectorized label field of obs_batch.")
        self.assertEqual(labels,
                         ["Paint on a canvas.", "Paint on a canvas."],
                         "Doesn't return correct labels.")
        self.assertEqual(valid_inds, [0, 3],
                         "Returns incorrect indices of valid observations.")

        # --- evaluation-style observations ("eval_labels") ---
        vec_observations = [
            agent.vectorize(obs) for obs in make_observations("eval_labels")
        ]
        _, label_vecs, _, valid_inds = agent.map_valid(vec_observations)

        self.assertIsNotNone(label_vecs, "Missing 'label_vec' field.")
        self.assertEqual(label_vecs.numpy().tolist(), expected_label_vecs,
                         "Incorrectly vectorized label field of obs_batch.")

        # --- scatter predictions back to their original batch positions ---
        predictions = ["Oil on a canvas.", "Oil on a canvas."]
        expected_unmapped = [
            "Oil on a canvas.", None, None, "Oil on a canvas.", None, None
        ]
        self.assertEqual(
            agent.unmap_valid(predictions, valid_inds, 6), expected_unmapped,
            "Unmapped predictions do not match expected results.")
示例#2
0
    def test_map_unmap(self):
        """Check ``map_valid`` / ``unmap_valid`` on a sparsely filled batch.

        A batch of six observations is built in which only indices 0 and 3
        carry text and labels; the rest are empty dicts.  ``map_valid`` is
        expected to return text vectors, their lengths, label vectors
        (wrapped in the dictionary's start/end indices), label strings and
        the valid indices for just those two entries, both when the labels
        are given as ``labels`` and as ``eval_labels``.  ``unmap_valid`` is
        then expected to place two predictions back at indices 0 and 3 of a
        six-element list and fill every other slot with ``None``.

        The test returns early (after printing a notice) when importing
        ``TorchAgent`` fails with a message mentioning 'pytorch'.
        """
        try:
            from parlai.core.torch_agent import TorchAgent
        except ImportError as e:
            if 'pytorch' in e.msg:
                print('Skipping TestTorchAgent.test_map_unmap, no pytorch.')
                return
            # Any other import failure is a real problem: surface it here
            # instead of failing later on an unbound ``TorchAgent`` name.
            raise

        def make_observations(label_field):
            # Fresh dict objects on every call so that no observation is
            # shared between slots or between the two batches.
            return [
                {"text": "What is a painting?",
                 label_field: ["Paint on a canvas."]},
                {},
                {},
                {"text": "What is a painting?",
                 label_field: ["Paint on a canvas."]},
                {},
                {},
            ]

        opt = {
            'no_cuda': True,
            'truncate': 10000,
            'history_dialog': 10,
            'history_replies': 'label_else_model',
        }
        mdict = MockDict()
        agent = TorchAgent(opt, {'opt': opt, 'dict': mdict})

        expected_text_vecs = [[7, 8, 9], [7, 8, 9]]
        expected_label_vecs = [
            [mdict.START_IDX, 7, 8, 9, mdict.END_IDX],
            [mdict.START_IDX, 7, 8, 9, mdict.END_IDX],
        ]

        # --- training-style observations ("labels") ---
        vec_observations = [
            agent.vectorize(obs) for obs in make_observations("labels")
        ]
        text_vecs, text_lengths, label_vecs, labels, valid_inds = (
            agent.map_valid(vec_observations))

        self.assertIsNotNone(text_vecs, "Missing 'text_vecs' field.")
        self.assertEqual(text_vecs.numpy().tolist(), expected_text_vecs,
                         "Incorrectly vectorized text field of obs_batch.")
        self.assertEqual(text_lengths.numpy().tolist(), [3, 3],
                         "Incorrect text vector lengths returned.")
        self.assertIsNotNone(label_vecs, "Missing 'label_vec' field.")
        self.assertEqual(label_vecs.numpy().tolist(), expected_label_vecs,
                         "Incorrectly vectorized label field of obs_batch.")
        self.assertEqual(labels,
                         ["Paint on a canvas.", "Paint on a canvas."],
                         "Doesn't return correct labels.")
        self.assertEqual(valid_inds, [0, 3],
                         "Returns incorrect indices of valid observations.")

        # --- evaluation-style observations ("eval_labels") ---
        vec_observations = [
            agent.vectorize(obs) for obs in make_observations("eval_labels")
        ]
        _, _, label_vecs, _, valid_inds = agent.map_valid(vec_observations)

        self.assertIsNotNone(label_vecs, "Missing 'eval_label_vec' field.")
        self.assertEqual(label_vecs.numpy().tolist(), expected_label_vecs,
                         "Incorrectly vectorized label field of obs_batch.")

        # --- scatter predictions back to their original batch positions ---
        predictions = ["Oil on a canvas.", "Oil on a canvas."]
        expected_unmapped = [
            "Oil on a canvas.", None, None, "Oil on a canvas.", None, None
        ]
        self.assertEqual(
            agent.unmap_valid(predictions, valid_inds, 6), expected_unmapped,
            "Unmapped predictions do not match expected results.")