コード例 #1
0
 def testToFromComponents(self, shape, fields, field_specs):
     """Round-trips a StructuredTensor through its spec's components.

     A tensor is built from `fields` and `shape`, and a spec is built from
     `shape` and `field_specs`. The tensor is decomposed with
     `_to_components` and rebuilt with `_from_components`; the rebuilt
     value must equal the original.

     Args:
       shape: Passed to `StructuredTensor.from_fields` as the tensor shape
         and to `DynamicRaggedShape.Spec` as `static_inner_shape`.
       fields: Field values handed to `StructuredTensor.from_fields`.
       field_specs: Field specs handed to `StructuredTensor.Spec`.
     """
     original = StructuredTensor.from_fields(fields, shape)
     # The shape spec has no row partitions; `shape` is its inner shape.
     shape_spec = DynamicRaggedShape.Spec(
         row_partitions=[],
         static_inner_shape=shape,
         dtype=dtypes.int64,
     )
     spec = StructuredTensor.Spec(
         _ragged_shape=shape_spec,
         _fields=field_specs,
     )
     round_tripped = spec._from_components(spec._to_components(original))
     self.assertAllEqual(original, round_tripped)
コード例 #2
0
    def testConstruction(self):
        """Checks that a Spec exposes the shape and fields it was built with.

        Two specs are constructed: one with a single field, and one whose
        fields include the first spec as a nested entry. For each, `_shape`
        and `_field_specs` must match the construction arguments.
        """

        def build_spec(dims, field_specs):
            # Wrap `dims` in a TensorShape and use it as the static inner
            # shape of a partition-free ragged shape spec.
            shape_spec = DynamicRaggedShape.Spec(
                row_partitions=[],
                static_inner_shape=tensor_shape.TensorShape(dims),
                dtype=dtypes.int64,
            )
            return StructuredTensor.Spec(
                _ragged_shape=shape_spec,
                _fields=field_specs,
            )

        inner_fields = {"a": T_1_2_3_4}
        inner_spec = build_spec([1, 2, 3], inner_fields)
        self.assertEqual(inner_spec._shape, (1, 2, 3))
        self.assertEqual(inner_spec._field_specs, inner_fields)

        # The second spec nests the first one under the key "s".
        outer_fields = {
            "a": T_1_2,
            "b": T_1_2_8,
            "c": R_1_N,
            "d": R_1_N_N,
            "s": inner_spec,
        }
        outer_spec = build_spec([1, 2], outer_fields)
        self.assertEqual(outer_spec._shape, (1, 2))
        self.assertEqual(outer_spec._field_specs, outer_fields)
コード例 #3
0
 def testValueType(self):
     """Checks that the spec's `value_type` is `StructuredTensor`."""
     shape_spec = DynamicRaggedShape.Spec(
         row_partitions=[],
         static_inner_shape=[1, 2],
         dtype=dtypes.int64,
     )
     spec = StructuredTensor.Spec(
         _ragged_shape=shape_spec,
         _fields={"a": T_1_2},
     )
     self.assertEqual(spec.value_type, StructuredTensor)