예제 #1
0
 def ensure_batch_predictions_are_consistent(self):
     """Assert that batched and one-at-a-time model outputs agree.

     Runs ``self.model`` once per instance in ``self.instances`` (each wrapped
     in its own single-element ``Batch``) and once on a ``Batch`` holding all
     instances, then compares every output key except those whose name
     contains ``'loss'``. Tensor outputs are compared with ``assert_allclose``
     (``atol=1e-6``); any other output must compare equal with ``==``.

     Raises:
         AssertionError: If a compared output differs between the two runs.
     """
     self.model.eval()

     # Forward each instance on its own.
     single_predictions = []
     for instance in self.instances:
         dataset = Batch([instance])
         tensors = dataset.as_tensor_dict(dataset.get_padding_lengths())
         single_predictions.append(self.model(**tensors))

     # Forward all instances together.
     full_dataset = Batch(self.instances)
     batch_tensors = full_dataset.as_tensor_dict(
         full_dataset.get_padding_lengths())
     batch_predictions = self.model(**batch_tensors)

     tolerance = 1e-6
     for i, instance_predictions in enumerate(single_predictions):
         for key, single_predicted in instance_predictions.items():
             if 'loss' in key:
                 # Loss is particularly unstable; we'll just be satisfied if everything else is
                 # close.
                 continue
             # The single-instance run still has a leading dimension of size one.
             single_predicted = single_predicted[0]
             batch_predicted = batch_predictions[key][i]
             if isinstance(single_predicted, torch.Tensor):
                 if single_predicted.size() != batch_predicted.size():
                     # The batched output can be larger than the single one
                     # (presumably padded to the longest instance -- TODO
                     # confirm); crop it to the single-instance shape.
                     slices = tuple(
                         slice(0, size) for size in single_predicted.size())
                     batch_predicted = batch_predicted[slices]
                 # detach().cpu() keeps the conversion valid for tensors that
                 # track gradients or live on another device.
                 assert_allclose(single_predicted.detach().cpu().numpy(),
                                 batch_predicted.detach().cpu().numpy(),
                                 atol=tolerance,
                                 err_msg=key)
             else:
                 assert single_predicted == batch_predicted, key
예제 #2
0
    def test_as_tensor_dict(self):
        """Check the token arrays produced by ``Batch.as_tensor_dict``.

        Builds a ``Batch`` from ``self.instances``, indexes it with
        ``self.vocab`` and compares the ``"tokens"`` entries of the ``"text1"``
        and ``"text2"`` fields against hard-coded expected id arrays.
        """
        batch = Batch(self.instances)
        batch.index_instances(self.vocab)
        tensor_dict = batch.as_tensor_dict(batch.get_padding_lengths())

        # Pull both fields out as numpy arrays before asserting on either.
        actual_text1 = tensor_dict["text1"]["tokens"].detach().cpu().numpy()
        actual_text2 = tensor_dict["text2"]["tokens"].detach().cpu().numpy()

        expected_text1 = numpy.array([[2, 3, 4, 5, 6],
                                      [1, 3, 4, 5, 6]])
        # The second row ends in zeros (presumably padding -- confirm against
        # the fixture instances).
        expected_text2 = numpy.array([[2, 3, 4, 1, 5, 6],
                                      [2, 3, 1, 0, 0, 0]])

        numpy.testing.assert_array_almost_equal(actual_text1, expected_text1)
        numpy.testing.assert_array_almost_equal(actual_text2, expected_text2)
예제 #3
0
 def test_padding_lengths_uses_max_instance_lengths(self):
     """Check the padding lengths reported for the indexed batch.

     After indexing ``self.instances`` with ``self.vocab``, the batch is
     expected to report 5 tokens for ``"text1"`` and 6 for ``"text2"``.
     """
     batch = Batch(self.instances)
     batch.index_instances(self.vocab)
     expected_lengths = {"text1": {"num_tokens": 5}, "text2": {"num_tokens": 6}}
     assert batch.get_padding_lengths() == expected_lengths