Example #1
0
    def test_parse_tensor(self):
        """parse_attr converts tensor-valued AttrValues into mil_types values.

        Covers a zero-rank INT32 tensor (no shape, no values) and a rank-3
        INT32 tensor with shape (1, 2, 3).
        """
        # Zero-rank tensor: neither a shape nor any values are set, and the
        # parsed result is expected to be an int32 instance holding 0.
        attr = attr_value.AttrValue()
        attr.tensor.version_number = 1
        attr.tensor.dtype = types.DataType.DT_INT32
        t = parse.parse_attr(attr)
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)), which only reports "False is not true".
        self.assertIsInstance(t, mil_types.int32)
        self.assertEqual(0, t.val)

        # Non-zero rank: shape (1, 2, 3) built from (size, name) pairs.
        attr = attr_value.AttrValue()
        attr.tensor.version_number = 1
        attr.tensor.dtype = types.DataType.DT_INT32
        shaped_attr = self._attr_with_shape([(1, "outer"), (2, "middle"),
                                             (3, "inner")])
        attr.tensor.tensor_shape.dim.extend(shaped_attr.shape.dim)
        attr.tensor.int_val.extend([55, 56, 57])

        # NOTE(review): the shape declares 1 * 2 * 3 = 6 elements, but only
        # three int_val entries are supplied, and `val` is compared as a flat
        # list. This pins the parser's current behavior; confirm it is intended.
        t = parse.parse_attr(attr)
        self.assertEqual([55, 56, 57], t.val.tolist())
        self.assertEqual("tensor", mil_types.get_type_info(t).name)

        # In the rank-zero case parse_attr returned an int32 *instance*.
        # Here, t.get_primitive() returns the element type itself, so it is
        # called to obtain an instance before checking the type.
        self.assertIsInstance(t.get_primitive()(), mil_types.int32)
        self.assertEqual((1, 2, 3), t.get_shape())
Example #2
0
        def compare(expected, lst, field_name):
            """Check that a list-valued AttrValue parses to *expected*.

            *lst* is appended to the repeated field ``attr.list.<field_name>``
            of a fresh AttrValue before parsing.
            """
            list_attr = attr_value.AttrValue()
            getattr(list_attr.list, field_name).extend(lst)
            parsed = parse.parse_attr(list_attr)
            self.assertEqual(expected, parsed)
Example #3
0
    def test_parse_scalar(self):
        """Scalar AttrValue fields (s, i, b, f) parse to plain Python values."""

        def compare(expected, val, field_name):
            # Populate one scalar field on a fresh AttrValue and check the result.
            scalar_attr = attr_value.AttrValue()
            setattr(scalar_attr, field_name, val)
            self.assertEqual(expected, parse.parse_attr(scalar_attr))

        compare("a String", b"a String", "s")
        compare(55, 55, "i")
        compare(True, True, "b")

        # The float field is only compared to 2 decimal places; presumably the
        # stored value is not exactly 12.3 (e.g. single precision) -- confirm.
        float_attr = attr_value.AttrValue()
        float_attr.f = 12.3
        parsed_float = parse.parse_attr(float_attr)
        self.assertAlmostEqual(12.3, parsed_float, places=2)
Example #4
0
 def compare(expected, tf_type):
     """Check that an AttrValue whose ``type`` field is *tf_type* parses to *expected*."""
     type_attr = attr_value.AttrValue()
     type_attr.type = tf_type
     parsed = parse.parse_attr(type_attr)
     self.assertEqual(expected, parsed)
Example #5
0
 def compare(expected, dims, unknown_rank=0):
     """Check that the shape attr built by ``self._attr_with_shape`` parses to *expected*.

     *dims* and *unknown_rank* are forwarded unchanged to ``_attr_with_shape``.
     """
     shape_attr = self._attr_with_shape(dims, unknown_rank)
     self.assertEqual(expected, parse.parse_attr(shape_attr))
Example #6
0
 def compare(expected, val, field_name):
     """Check that an AttrValue with scalar field *field_name* set to *val* parses to *expected*."""
     scalar_attr = attr_value.AttrValue()
     setattr(scalar_attr, field_name, val)
     self.assertEqual(expected, parse.parse_attr(scalar_attr))