def testToFromComponents(self, shape, fields, field_specs):
  """Checks that a StructuredTensor survives a component round trip.

  The tensor is encoded with `spec._to_components` and decoded with
  `spec._from_components`; the result must equal the original tensor.

  Args:
    shape: Shape passed both to `StructuredTensor.from_fields` and as the
      spec's `static_inner_shape` (presumably supplied by a parameterization
      decorator outside this view -- confirm).
    fields: Field values used to build the StructuredTensor.
    field_specs: Per-field specs handed to `StructuredTensor.Spec`.
  """
  original = StructuredTensor.from_fields(fields, shape)

  # No row partitions: every dimension is carried by the static inner shape.
  ragged_shape_spec = DynamicRaggedShape.Spec(
      row_partitions=[],
      static_inner_shape=shape,
      dtype=dtypes.int64)
  struct_spec = StructuredTensor.Spec(
      _ragged_shape=ragged_shape_spec, _fields=field_specs)

  encoded = struct_spec._to_components(original)
  decoded = struct_spec._from_components(encoded)
  self.assertAllEqual(original, decoded)
def testConstruction(self):
  """Checks that a Spec exposes the shape and field specs it was built with.

  A second spec nests the first one as a field, so this also covers a spec
  whose field is itself a `StructuredTensor.Spec`.
  """

  def build_spec(dims, field_specs):
    # Shared boilerplate: no row partitions, fully static int64 inner shape.
    ragged_shape_spec = DynamicRaggedShape.Spec(
        row_partitions=[],
        static_inner_shape=tensor_shape.TensorShape(dims),
        dtype=dtypes.int64)
    return StructuredTensor.Spec(
        _ragged_shape=ragged_shape_spec, _fields=field_specs)

  inner_fields = dict(a=T_1_2_3_4)
  inner_spec = build_spec([1, 2, 3], inner_fields)
  self.assertEqual(inner_spec._shape, (1, 2, 3))
  self.assertEqual(inner_spec._field_specs, inner_fields)

  outer_fields = dict(
      a=T_1_2, b=T_1_2_8, c=R_1_N, d=R_1_N_N, s=inner_spec)
  outer_spec = build_spec([1, 2], outer_fields)
  self.assertEqual(outer_spec._shape, (1, 2))
  self.assertEqual(outer_spec._field_specs, outer_fields)
def testValueType(self):
  """Checks that `StructuredTensor.Spec.value_type` is `StructuredTensor`."""
  ragged_shape_spec = DynamicRaggedShape.Spec(
      row_partitions=[], static_inner_shape=[1, 2], dtype=dtypes.int64)
  struct_spec = StructuredTensor.Spec(
      _ragged_shape=ragged_shape_spec, _fields={'a': T_1_2})
  self.assertEqual(struct_spec.value_type, StructuredTensor)