示例#1
0
    def test_iterable_with_strings(self):
        """Strings nested inside the iterable are not treated as sequences.

        The nested structure holds two arrays and one string. Extraction
        must yield exactly the two arrays, as the very same objects, in the
        order they appear.
        """
        first = np.array([0.4, 0.1])
        second = np.array([1])
        nested = [first, ["abc", [second]]]

        extracted = list(np.extract_tensors(nested))

        # Check the count first so the unpacking below cannot fail differently.
        assert len(extracted) == 2
        got_first, got_second = extracted
        # Identity checks: the extraction must hand back the original objects.
        assert got_first is first
        assert got_second is second
示例#2
0
    def test_iterable_with_unpatched_numpy_arrays(self):
        """Arrays created with the unpatched module (`onp`) are skipped.

        Only the two arrays built with `np` should be yielded, as the same
        objects and in nesting order. The `onp.array` in between must not
        appear in the result.
        """
        first = np.array([0.4, 0.1])
        second = np.array([1])
        plain = onp.array([1, 2])
        nested = [first, [plain, [second]]]

        extracted = list(np.extract_tensors(nested))

        # Check the count first so the unpacking below cannot fail differently.
        assert len(extracted) == 2
        got_first, got_second = extracted
        # Identity checks: the extraction must hand back the original objects.
        assert got_first is first
        assert got_second is second
示例#3
0
 def test_empty_terable(self):
     """Extracting from an empty iterable yields nothing."""
     # NOTE: "terable" in the method name is a typo for "iterable"; the name
     # is kept unchanged to preserve the existing test identifier.
     extracted = list(np.extract_tensors([]))
     # `extracted` is always a list here, so falsiness means "empty".
     assert not extracted