def to_shape_proto(shape: utils.Shape) -> feature_pb2.Shape:
    """Build a `feature_pb2.Shape` proto from a TFDS shape.

    Unknown dimensions (`None`) and negative dimensions are both encoded
    as -1 in the proto. Every other dimension is copied through unchanged,
    including 0.

    Args:
        shape: The TFDS shape to convert; iterated once, in order.

    Returns:
        A `feature_pb2.Shape` whose `dimensions` field mirrors `shape`.
    """
    # The proto field cannot hold None, so -1 is the sentinel for "unspecified".
    encoded = [-1 if (dim is None or dim < 0) else dim for dim in shape]
    return feature_pb2.Shape(dimensions=encoded)
def test_from_shape_proto_unspecified():
    # A -1 entry in the proto marks an unknown dimension and must decode to None.
    proto = feature_pb2.Shape(dimensions=[28, 28, -1])
    expected = [28, 28, None]
    assert expected == feature.from_shape_proto(proto)
def test_from_shape_proto_normal():
    # Fully specified, positive dimensions decode unchanged.
    proto = feature_pb2.Shape(dimensions=[28, 28, 1])
    expected = [28, 28, 1]
    assert expected == feature.from_shape_proto(proto)
def test_from_shape_proto_single_zero_dimension():
    # 0 is a real (empty) dimension, not the -1 "unspecified" sentinel,
    # so it must decode as 0 rather than None.
    proto = feature_pb2.Shape(dimensions=[0])
    expected = [0]
    assert expected == feature.from_shape_proto(proto)