def testIsSequence(self):
  """Checks which kinds of values `nest.is_sequence` reports as sequences.

  The expectations encoded here: tuples and dicts count as sequences, while
  strings, lists (even empty or nested ones), sets, tensors, NumPy arrays and
  `SparseTensorValue` instances do not.
  """
  is_seq = nest.is_sequence

  # Plain Python containers.
  text = "1234"
  nested_list = [1, 3, [4, 5]]
  nested_tuple = ((7, 8), (5, 6))
  self.assertFalse(is_seq(text))
  self.assertFalse(is_seq(nested_list))
  self.assertTrue(is_seq(nested_tuple))
  self.assertFalse(is_seq([]))
  self.assertFalse(is_seq(set([1, 2])))

  # Tensors (and ops derived from them) and NumPy arrays are leaves.
  ones = array_ops.ones([2, 3])
  self.assertFalse(is_seq(ones))
  self.assertFalse(is_seq(math_ops.tanh(ones)))
  self.assertFalse(is_seq(np.ones((4, 5))))

  # Dicts are sequences; a SparseTensorValue is expected to be a leaf.
  self.assertTrue(is_seq({"foo": 1, "bar": 2}))
  sparse_value = sparse_tensor.SparseTensorValue([[0]], [0], [1])
  self.assertFalse(is_seq(sparse_value))
def testIsSequence(self):
  """Checks which kinds of values `nest.is_sequence` reports as sequences.

  The expectations encoded here: tuples and dicts count as sequences, while
  strings, lists (even empty or nested ones), sets, tensors, NumPy arrays and
  `SparseTensorValue` instances do not.
  """
  is_seq = nest.is_sequence

  # Plain Python containers.
  text = "1234"
  nested_list = [1, 3, [4, 5]]
  nested_tuple = ((7, 8), (5, 6))
  self.assertFalse(is_seq(text))
  self.assertFalse(is_seq(nested_list))
  self.assertTrue(is_seq(nested_tuple))
  self.assertFalse(is_seq([]))
  self.assertFalse(is_seq(set([1, 2])))

  # Tensors (and ops derived from them) and NumPy arrays are leaves.
  ones = array_ops.ones([2, 3])
  self.assertFalse(is_seq(ones))
  self.assertFalse(is_seq(math_ops.tanh(ones)))
  self.assertFalse(is_seq(np.ones((4, 5))))

  # Dicts are sequences; a SparseTensorValue is expected to be a leaf.
  self.assertTrue(is_seq({"foo": 1, "bar": 2}))
  sparse_value = sparse_tensor.SparseTensorValue([[0]], [0], [1])
  self.assertFalse(is_seq(sparse_value))
def tf_map_func(*args):
  """Wraps `map_func` for use with Defun, propagating static shapes.

  Relies on `input_dataset`, `map_func` and `self` from the enclosing scope.

  Args:
    *args: Flat arguments, one per component of `input_dataset`.

  Returns:
    The result of `make_dataset_resource()` on the dataset that `map_func`
    produced.

  Raises:
    TypeError: If `map_func` returns something other than a
      `dataset_ops.Dataset`.
  """
  # Stamp each flat argument with the matching shape from the input dataset.
  flat_shapes = nest.flatten(input_dataset.output_shapes)
  for component, component_shape in zip(args, flat_shapes):
    component.set_shape(component_shape)

  # Restore the nested structure; a sequence is splatted into `map_func`,
  # anything else is passed as a single argument.
  nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
  dataset = (
      map_func(*nested_args)
      if nest.is_sequence(nested_args)
      else map_func(nested_args))

  if not isinstance(dataset, dataset_ops.Dataset):
    raise TypeError("`map_func` must return a `Dataset` object.")

  # Record the produced dataset's structure on the enclosing object.
  self._output_types = dataset.output_types
  self._output_shapes = dataset.output_shapes

  return dataset.make_dataset_resource()
def tf_map_func(*args):
  """Wraps `map_func` for use with Defun, propagating static shapes.

  Relies on `input_dataset`, `map_func` and `self` from the enclosing scope.

  Args:
    *args: Flat arguments, one per component of `input_dataset`.

  Returns:
    The result of `_as_variant_tensor()` on the dataset that `map_func`
    produced.

  Raises:
    TypeError: If `map_func` returns something other than a
      `dataset_ops.Dataset`.
  """
  # Stamp each flat argument with the matching shape from the input dataset.
  flat_shapes = nest.flatten(input_dataset.output_shapes)
  for component, component_shape in zip(args, flat_shapes):
    component.set_shape(component_shape)

  # Restore the nested structure; a sequence is splatted into `map_func`,
  # anything else is passed as a single argument.
  nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
  dataset = (
      map_func(*nested_args)
      if nest.is_sequence(nested_args)
      else map_func(nested_args))

  if not isinstance(dataset, dataset_ops.Dataset):
    raise TypeError("`map_func` must return a `Dataset` object.")

  # Record the produced dataset's structure on the enclosing object.
  self._output_types = dataset.output_types
  self._output_shapes = dataset.output_shapes

  return dataset._as_variant_tensor()  # pylint: disable=protected-access