def testFlatTensorSpecs(self):
  """Checks `_flat_tensor_specs` of the spec of a rank-4 StructuredTensor.

  Note that the batchable tensor list encoding for a StructuredTensor
  contains a separate tensor for each leaf field. In this example,
  `_flat_tensor_specs` in class StructuredTensorSpec is called three times
  and it returns results with length 2, 3 and 11 for "g", "h" and `struct`
  respectively.
  """
  field_source = self._lambda_for_fields()
  tensor_rank = 4
  # Deferred construction: fields may include tensors, so the fixture may
  # hand back a callable that builds them on demand.
  field_values = field_source() if callable(field_source) else field_source

  structured = StructuredTensor.from_fields_and_rank(field_values, tensor_rank)
  struct_spec = type_spec.type_spec_from_value(structured)
  actual_specs = struct_spec._flat_tensor_specs

  def make_variant_spec():
    # One shapeless variant spec, built fresh for each expected entry.
    return tensor_spec.TensorSpec(
        shape=None, dtype=dtypes.variant, name=None)

  expected_specs = [
      # a, b
      tensor_spec.TensorSpec(
          shape=(1, 2, 3, 1), dtype=dtypes.float64, name=None),
      tensor_spec.TensorSpec(
          shape=(1, 2, 3, 1, 5), dtype=dtypes.float64, name=None),
  ]
  # c, d, e, f (4 entries), then g (2 entries), then h (3 entries).
  variant_entry_count = 4 + 2 + 3
  expected_specs.extend(
      make_variant_spec() for _ in range(variant_entry_count))

  self.assertEqual(actual_specs, expected_specs)
def testFulTypesForFlatTensors(self):
  """Checks `fulltypes_for_flat_tensors` on a rank-4 StructuredTensor spec.

  Note that the batchable tensor list encoding for a StructuredTensor
  contains a separate tensor for each leaf field. In this example,
  `_flat_tensor_specs` in class StructuredTensorSpec is called three times
  and it returns results with length 2, 3 and 11 for "g", "h" and `struct`
  respectively. The expected full type list must therefore have one entry
  per flat tensor spec.
  """
  field_source = self._lambda_for_fields()
  tensor_rank = 4
  # Deferred construction: fields may include tensors, so the fixture may
  # hand back a callable that builds them on demand.
  field_values = field_source() if callable(field_source) else field_source

  structured = StructuredTensor.from_fields_and_rank(field_values, tensor_rank)
  struct_spec = type_spec.type_spec_from_value(structured)
  actual_specs = struct_spec._flat_tensor_specs
  actual_fulltypes = fulltypes_for_flat_tensors(struct_spec)

  def unset_type():
    # Full type entry whose type id is TFT_UNSET.
    return full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_UNSET)

  def ragged_type(element_type_id):
    # TFT_RAGGED full type whose single argument has the given type id.
    return full_type_pb2.FullTypeDef(
        type_id=full_type_pb2.TFT_RAGGED,
        args=[full_type_pb2.FullTypeDef(type_id=element_type_id)])

  expected_fulltypes = [
      # a, b
      unset_type(),
      unset_type(),
      # c, d, e, f
      ragged_type(full_type_pb2.TFT_UINT8),
      ragged_type(full_type_pb2.TFT_FLOAT),
      ragged_type(full_type_pb2.TFT_FLOAT),
      ragged_type(full_type_pb2.TFT_FLOAT),
      # g
      ragged_type(full_type_pb2.TFT_INT32),
      ragged_type(full_type_pb2.TFT_INT32),
      # h
      ragged_type(full_type_pb2.TFT_INT32),
      ragged_type(full_type_pb2.TFT_INT32),
      ragged_type(full_type_pb2.TFT_INT32),
  ]

  # Every flat tensor spec must have a corresponding full type entry.
  self.assertEqual(len(expected_fulltypes), len(actual_specs))
  self.assertEqual(actual_fulltypes, expected_fulltypes)