def test_parse_tensor(self):
    """Check parse.parse_attr on AttrValue messages carrying a DT_INT32 tensor.

    Two cases are exercised:
      * a rank-zero tensor with no values, which is expected to parse to a
        ``mil_types.int32`` instance whose ``val`` is 0;
      * a rank-3 tensor of shape (1, 2, 3), which is expected to parse to a
        MIL type named "tensor" exposing the supplied values and that shape.
    """
    # Zero-rank tensor: no dims and no int_val entries.
    attr = attr_value.AttrValue()
    attr.tensor.version_number = 1
    attr.tensor.dtype = types.DataType.DT_INT32
    t = parse.parse_attr(attr)
    self.assertIsInstance(t, mil_types.int32)
    self.assertEqual(0, t.val)

    # Non-zero rank: borrow the dims built by the _attr_with_shape helper
    # (whose result exposes .shape.dim) and copy them into the tensor shape.
    attr = attr_value.AttrValue()
    attr.tensor.version_number = 1
    attr.tensor.dtype = types.DataType.DT_INT32
    shaped_attr = self._attr_with_shape(
        [(1, "outer"), (2, "middle"), (3, "inner")]
    )
    attr.tensor.tensor_shape.dim.extend(shaped_attr.shape.dim)
    # NOTE(review): shape (1, 2, 3) implies 6 elements but only 3 values are
    # supplied, and the assertion below pins the raw 3-element flat list.
    # Confirm this is the intended parser behavior rather than an accident
    # of the fixture.
    attr.tensor.int_val.extend([55, 56, 57])
    t = parse.parse_attr(attr)
    self.assertEqual([55, 56, 57], t.val.tolist())
    self.assertEqual("tensor", mil_types.get_type_info(t).name)
    # Unlike the rank-zero case, where the parsed value itself is an int32
    # instance, t.get_primitive() yields a callable (presumably the element
    # type itself) that must be invoked to obtain an int32 instance.
    self.assertIsInstance(t.get_primitive()(), mil_types.int32)
    self.assertEqual((1, 2, 3), t.get_shape())
def compare(expected, lst, field_name):
    """Assert that an AttrValue whose list field ``field_name`` holds ``lst``
    parses to ``expected``.

    ``self`` is not a parameter; it must come from an enclosing scope
    (presumably a TestCase method this helper is nested in).
    """
    list_attr = attr_value.AttrValue()
    # Repeated list fields cannot be assigned directly, so extend in place.
    getattr(list_attr.list, field_name).extend(lst)
    self.assertEqual(expected, parse.parse_attr(list_attr))
def test_parse_scalar(self):
    """Check parse.parse_attr on AttrValue messages holding a single scalar."""

    def compare(expected, val, field_name):
        # Set exactly one scalar field on a fresh AttrValue and check what
        # the parser returns for it.
        scalar_attr = attr_value.AttrValue()
        setattr(scalar_attr, field_name, val)
        self.assertEqual(expected, parse.parse_attr(scalar_attr))

    # Exact-match cases; note the bytes in field "s" are expected back as str.
    for expected, val, field_name in (
        ("a String", b"a String", "s"),
        (55, 55, "i"),
        (True, True, "b"),
    ):
        compare(expected, val, field_name)

    # Floats are compared approximately: 12.3 has no exact binary
    # representation (the field may also be single precision -- unconfirmed).
    float_attr = attr_value.AttrValue()
    float_attr.f = 12.3
    self.assertAlmostEqual(12.3, parse.parse_attr(float_attr), places=2)
def compare(expected, tf_type):
    """Assert that an AttrValue whose ``type`` field is ``tf_type`` parses to
    ``expected``.

    ``self`` is not a parameter; it must come from an enclosing scope.
    """
    type_attr = attr_value.AttrValue()
    type_attr.type = tf_type
    parsed = parse.parse_attr(type_attr)
    self.assertEqual(expected, parsed)
def compare(expected, dims, unknown_rank=0):
    """Assert that the shape attribute built from ``dims``/``unknown_rank``
    parses to ``expected``.

    The attribute is built by ``self._attr_with_shape``; ``self`` must come
    from an enclosing scope.
    """
    shape_attr = self._attr_with_shape(dims, unknown_rank)
    self.assertEqual(expected, parse.parse_attr(shape_attr))
def compare(expected, val, field_name):
    """Assert that an AttrValue with scalar field ``field_name`` set to
    ``val`` parses to ``expected``.

    ``self`` is not a parameter; it must come from an enclosing scope.
    """
    scalar_attr = attr_value.AttrValue()
    setattr(scalar_attr, field_name, val)
    self.assertEqual(expected, parse.parse_attr(scalar_attr))