def testMapDefunWithInvalidInput(self):
  """Scalar (rank-0) inputs must be rejected by MapDefun.

  A constant whose scalar shape is known statically is rejected while the
  graph is built (ValueError). A placeholder of unknown shape can only be
  rejected when it is fed a scalar at run time (InvalidArgumentError).
  """

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    return x * 2

  scalar_const = constant_op.constant(2)
  with self.assertRaises(ValueError):
    # Fails at graph construction time for inputs with known shapes.
    _ = map_defun.map_defun(
        simple_fn, [scalar_const], [dtypes.int32], [None])[0]

  unknown_shape_input = array_ops.placeholder(dtypes.int32)
  mapped = map_defun.map_defun(
      simple_fn, [unknown_shape_input], [dtypes.int32], [None])[0]
  with session.Session() as sess:
    with self.assertRaises(errors.InvalidArgumentError):
      sess.run(mapped, feed_dict={unknown_shape_input: 0})
def testMapDefunWithInvalidInput(self):
  """Scalar (rank-0) inputs must be rejected by MapDefun.

  A constant whose scalar shape is known statically is rejected while the
  graph is built (ValueError). A placeholder of unknown shape can only be
  rejected when it is fed a scalar at run time (InvalidArgumentError).
  """

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    return x * 2

  scalar_const = constant_op.constant(2)
  with self.assertRaises(ValueError):
    # Fails at graph construction time for inputs with known shapes.
    _ = map_defun.map_defun(
        simple_fn, [scalar_const], [dtypes.int32], [None])[0]

  unknown_shape_input = array_ops.placeholder(dtypes.int32)
  mapped = map_defun.map_defun(
      simple_fn, [unknown_shape_input], [dtypes.int32], [None])[0]
  with session.Session() as sess:
    with self.assertRaises(errors.InvalidArgumentError):
      sess.run(mapped, feed_dict={unknown_shape_input: 0})
def testMapDefunWithVariantTensor(self):
  """Variant-typed elements pass through an identity MapDefun unchanged.

  Two copies of a serialized SparseTensor are stacked and mapped through an
  identity function. They are then deserialized into one batched
  SparseTensor with a new leading dimension of size 2.
  """

  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.variant)])
  def fn(x):
    return x

  sparse = sparse_tensor.SparseTensor(
      indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
  one_serialized = sparse_ops.serialize_sparse_v2(
      sparse, out_type=dtypes.variant)
  batch = array_ops.stack([one_serialized, one_serialized])

  mapped = map_defun.map_defun(fn, [batch], [dtypes.variant], [None])[0]
  restored = sparse_ops.deserialize_sparse(mapped, dtypes.int32)

  expected = sparse_tensor.SparseTensorValue(
      indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
      values=[1, 2, 1, 2],
      dense_shape=[2, 3, 4])
  self.assertValuesEqual(expected, self.evaluate(restored))
def testMapDefunWithVariantTensorAsCaptured(self):
  """A captured variant tensor is emitted once per input element.

  `fn` ignores its argument and returns a serialized SparseTensor captured
  from the enclosing graph. Mapping it over a 2-element input yields two
  copies, which deserialize into one batched SparseTensor.
  """
  sparse = sparse_tensor.SparseTensor(
      indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
  serialized = sparse_ops.serialize_sparse_v2(sparse, out_type=dtypes.variant)

  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
  def fn(x):
    del x  # Unused: the output depends only on the captured tensor.
    return serialized

  dummy_input = constant_op.constant([0, 0])
  mapped = map_defun.map_defun(fn, [dummy_input], [dtypes.variant], [None])[0]
  restored = sparse_ops.deserialize_sparse(mapped, dtypes.int32)

  expected = sparse_tensor.SparseTensorValue(
      indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
      values=[1, 2, 1, 2],
      dense_shape=[2, 3, 4])
  self.assertSparseValuesEqual(expected, self.evaluate(restored))
def benchmark_defun_vs_map_fn(self):
  """Benchmarks to compare the performance of MapDefun vs tf.map_fn.

  For each input size, an identity function is mapped over a range of exactly
  that many int32 elements. It runs once through MapDefun (benchmark_id=1)
  and once through `map_fn.map_fn` (benchmark_id=2).
  """

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
  def defun(x):
    return array_ops.identity(x)

  def fn(x):
    return array_ops.identity(x)

  for input_size in [10, 100, 1000, 10000]:
    # Build an input with exactly `input_size` elements so that each
    # "size_%d" result measures that many elements. Reusing a single
    # fixed-size tensor for every size would make the label meaningless.
    elems = math_ops.range(input_size)
    # Scale iterations inversely with input size so every configuration
    # processes the same total number of elements (10000).
    num_iters = 10000 // input_size
    map_defun_op = map_defun.map_defun(defun, [elems], [dtypes.int32], [()])
    map_fn_op = map_fn.map_fn(fn, elems)

    self._run(
        op=map_defun_op,
        name="with_defun_size_%d" % input_size,
        num_iters=num_iters,
        benchmark_id=1)
    self._run(
        op=map_fn_op,
        name="without_defun_size_%d" % input_size,
        num_iters=num_iters,
        benchmark_id=2)
def testMapDefunPartialShapeInference(self):
  """An unknown leading (batch) dimension is preserved as None.

  The input's first dimension is unknown and the declared per-element output
  shape is (2,), so the statically inferred output shape is [None, 2].
  """

  @function.Defun(dtypes.int32)
  def fn(x):
    return x

  # The placeholder dtype must match `fn`'s int32 signature and the declared
  # int32 output type. A mismatch would go unnoticed here because this test
  # only inspects static shapes and never runs the op.
  elems = array_ops.placeholder(dtypes.int32, (None, 2))
  result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
  self.assertEqual(result[0].get_shape().as_list(), [None, 2])
def testMapDefunShapeInference(self):
  """A fully known input shape yields a fully known output shape (3, 2)."""

  @function.Defun(dtypes.int32)
  def fn(x):
    return x

  data = constant_op.constant(
      [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
  output = map_defun.map_defun(fn, [data], [dtypes.int32], [(2,)])[0]
  self.assertEqual(output.get_shape(), (3, 2))
def testMapDefunPartialShapeInference(self):
  """An unknown leading (batch) dimension is preserved as None.

  The input's first dimension is unknown and the declared per-element output
  shape is (2,), so the statically inferred output shape is [None, 2].
  """

  @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
  def fn(x):
    return x

  # The placeholder dtype must match `fn`'s int32 input signature and the
  # declared int32 output type. A mismatch would go unnoticed here because
  # this test only inspects static shapes and never runs the op.
  elems = array_ops.placeholder(dtypes.int32, (None, 2))
  result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
  self.assertEqual(result[0].get_shape().as_list(), [None, 2])
def testMapDefunSimple(self):
  """MapDefun applies `x * 2 + 3` to every row of a 3x2 input."""

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    return x * 2 + 3

  rows = [[1, 2], [3, 4], [5, 6]]
  data = constant_op.constant(rows, dtype=dtypes.int32, name="data")
  mapped = map_defun.map_defun(simple_fn, [data], [dtypes.int32], [(2,)])[0]
  want = data * 2 + 3
  self.assertAllEqual(self.evaluate(mapped), self.evaluate(want))
def testMapDefunRaisesDefunError(self):
  """A failing op inside the mapped function surfaces when MapDefun runs."""

  @function.Defun(dtypes.int32)
  def fn(x):
    # Fails for any element that is not 0.
    with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
      return array_ops.identity(x)

  # The element 37 violates the assertion inside `fn`.
  values = constant_op.constant([0, 0, 0, 37, 0])
  outputs = map_defun.map_defun(fn, [values], [dtypes.int32], [()])
  with self.assertRaises(errors.InvalidArgumentError):
    self.evaluate(outputs)
def testMapDefunMismatchedTypes(self):
  """A declared output dtype that disagrees with the function is an error.

  `fn` returns float64 but the MapDefun call declares an int32 output. The
  mismatch is reported when the op is evaluated.
  """

  @function.Defun(dtypes.int32)
  def fn(x):
    return math_ops.cast(x, dtypes.float64)

  data = constant_op.constant(
      [1, 2, 3, 4, 5, 6], dtype=dtypes.int32, name="data")
  mapped = map_defun.map_defun(fn, [data], [dtypes.int32], [()])[0]
  with self.assertRaises(errors.InvalidArgumentError):
    self.evaluate(mapped)
def testMapDefunShapeInference(self):
  """A fully known input shape yields a fully known output shape (3, 2)."""

  @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
  def fn(x):
    return x

  data = constant_op.constant(
      [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
  output = map_defun.map_defun(fn, [data], [dtypes.int32], [(2,)])[0]
  self.assertEqual(output.get_shape(), (3, 2))
def testMapDefunWithWrongOutputShape(self):
  """A wrong declared per-element output shape fails at run time.

  Each element produced by `simple_fn` has shape (2,), but (1,) is declared.
  """

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    return x * 2 + 3

  data = constant_op.constant(
      [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
  mapped = map_defun.map_defun(simple_fn, [data], [dtypes.int32], [(1,)])[0]
  with self.assertRaises(errors.InvalidArgumentError):
    self.evaluate(mapped)
def testMapDefunWithDifferentOutputShapeEachRun(self):
  """With an unspecified (None) output shape, each run may differ in shape."""

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    return x * 2 + 3

  elems = array_ops.placeholder(dtypes.int32, name="data")
  mapped = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0]
  with session.Session() as sess:
    # A rank-1 feed produces a rank-1 result...
    self.assertAllEqual(sess.run(mapped, feed_dict={elems: [0]}), [3])
    # ...and a rank-2 feed on the same graph produces a rank-2 result.
    self.assertAllEqual(
        sess.run(mapped, feed_dict={elems: [[0], [1]]}), [[3], [5]])
def testMapDefunMultipleOutputs(self):
  """MapDefun supports functions that return several outputs of mixed dtype."""

  @function.Defun(dtypes.int32)
  def fn(x):
    return (x, math_ops.cast(x * 2 + 3, dtypes.float64))

  data = constant_op.constant(
      [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
  outputs = map_defun.map_defun(
      fn, [data], [dtypes.int32, dtypes.float64], [(2,), (2,)])
  want = [data, data * 2 + 3]
  self.assertAllEqual(self.evaluate(outputs), self.evaluate(want))
def testMapDefunWithCapturedInputs(self):
  """A tensor captured by the mapped function is added to every element."""
  offset = constant_op.constant(2)

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
  def fn(x):
    return x + offset

  values = constant_op.constant([1, 2, 3, 4])
  mapped = map_defun.map_defun(fn, [values], [dtypes.int32], [()])[0]
  want = values + offset
  self.assertAllEqual(self.evaluate(want), self.evaluate(mapped))
def testMapDefunWithWrongOutputShape(self):
  """A wrong declared per-element output shape fails at run time.

  Each element produced by `simple_fn` has shape (2,), but (1,) is declared.
  """

  @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
  def simple_fn(x):
    return x * 2 + 3

  data = constant_op.constant(
      [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
  mapped = map_defun.map_defun(simple_fn, [data], [dtypes.int32], [(1,)])[0]
  with self.assertRaises(errors.InvalidArgumentError):
    self.evaluate(mapped)
def testMapDefunWithCapturedInputs(self):
  """A tensor captured by the mapped function is added to every element."""
  offset = constant_op.constant(2)

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
  def fn(x):
    return x + offset

  values = constant_op.constant([1, 2, 3, 4])
  mapped = map_defun.map_defun(fn, [values], [dtypes.int32], [()])[0]
  want = values + offset
  self.assertAllEqual(self.evaluate(want), self.evaluate(mapped))
def testMapDefunRaisesDefunError(self):
  """A failing op inside the mapped function surfaces when MapDefun runs."""

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
  def fn(x):
    # Fails for any element that is not 0.
    with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
      return array_ops.identity(x)

  # The element 37 violates the assertion inside `fn`.
  values = constant_op.constant([0, 0, 0, 37, 0])
  outputs = map_defun.map_defun(fn, [values], [dtypes.int32], [()])
  with self.assertRaises(errors.InvalidArgumentError):
    self.evaluate(outputs)
def testMapDefunMismatchedTypes(self):
  """A declared output dtype that disagrees with the function is an error.

  `fn` returns float64 but the MapDefun call declares an int32 output. The
  mismatch is reported when the op is evaluated.
  """

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
  def fn(x):
    return math_ops.cast(x, dtypes.float64)

  data = constant_op.constant(
      [1, 2, 3, 4, 5, 6], dtype=dtypes.int32, name="data")
  mapped = map_defun.map_defun(fn, [data], [dtypes.int32], [()])[0]
  with self.assertRaises(errors.InvalidArgumentError):
    self.evaluate(mapped)
def testMapDefunSimple(self):
  """MapDefun applies `x * 2 + 3` to every row of a 3x2 input."""

  @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
  def simple_fn(x):
    return x * 2 + 3

  rows = [[1, 2], [3, 4], [5, 6]]
  data = constant_op.constant(rows, dtype=dtypes.int32, name="data")
  mapped = map_defun.map_defun(simple_fn, [data], [dtypes.int32], [(2,)])[0]
  want = data * 2 + 3
  self.assertAllEqual(self.evaluate(mapped), self.evaluate(want))
def testMapDefunReduceDim(self):
  """The per-element output may have a lower rank than the input element."""

  # Tests where the output has a different rank from the input: each
  # (2,)-shaped row is reduced to its first entry, a scalar.
  @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
  def fn(x):
    return array_ops.gather(x, 0)

  data = constant_op.constant(
      [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
  firsts = map_defun.map_defun(fn, [data], [dtypes.int32], [()])[0]
  want = constant_op.constant([1, 3, 5])
  self.assertAllEqual(self.evaluate(firsts), self.evaluate(want))
def testMapDefunMultipleOutputs(self):
  """MapDefun supports functions that return several outputs of mixed dtype."""

  @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
  def fn(x):
    return (x, math_ops.cast(x * 2 + 3, dtypes.float64))

  data = constant_op.constant(
      [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
  outputs = map_defun.map_defun(
      fn, [data], [dtypes.int32, dtypes.float64], [(2,), (2,)])
  want = [data, data * 2 + 3]
  self.assertAllEqual(self.evaluate(outputs), self.evaluate(want))
def testNoIntraOpLimit(self):
  """MapDefun yields correct results with max_intra_op_parallelism=0.

  NOTE(review): per the test name, 0 presumably means "no intra-op limit";
  confirm against the MapDefun op documentation.
  """

  @function.defun(
      input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
  def simple_fn(x):
    return x * 2 + 3

  data = constant_op.constant(
      [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
  mapped = map_defun.map_defun(
      simple_fn, [data], [dtypes.int32], [(2,)],
      max_intra_op_parallelism=0)[0]
  want = data * 2 + 3
  self.assertAllEqual(self.evaluate(mapped), self.evaluate(want))
def testMapDefunWithDifferentOutputShapeEachRun(self):
  """With an unspecified (None) output shape, each run may differ in shape."""

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    return x * 2 + 3

  elems = array_ops.placeholder(dtypes.int32, name="data")
  mapped = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0]
  with session.Session() as sess:
    # A rank-1 feed produces a rank-1 result...
    self.assertAllEqual(sess.run(mapped, feed_dict={elems: [0]}), [3])
    # ...and a rank-2 feed on the same graph produces a rank-2 result.
    self.assertAllEqual(
        sess.run(mapped, feed_dict={elems: [[0], [1]]}), [[3], [5]])
def testMapDefunReduceDim(self):
  """The per-element output may have a lower rank than the input element."""

  # Tests where the output has a different rank from the input: each
  # (2,)-shaped row is reduced to its first entry, a scalar.
  @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
  def fn(x):
    return array_ops.gather(x, 0)

  data = constant_op.constant(
      [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
  firsts = map_defun.map_defun(fn, [data], [dtypes.int32], [()])[0]
  want = constant_op.constant([1, 3, 5])
  self.assertAllEqual(self.evaluate(firsts), self.evaluate(want))
def testMapDefunCancelledCorrectly(self):
  """An error raised for the elements makes the whole MapDefun op fail.

  Every row of a [100, 5] input is indexed out of bounds. The gather error
  message is expected to propagate out of the MapDefun op.
  """

  @function.Defun(dtypes.int64)
  def defun(x):
    # x has leading dimension 5, this will raise an error
    return array_ops.gather(x, 10)

  row = constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64)
  # Repeat the 5-element row 100 times, giving shape [100, 5].
  tiled = array_ops.tile(array_ops.expand_dims(row, 0), [100, 1])
  map_defun_op = map_defun.map_defun(defun, [tiled], [dtypes.int64], [()])[0]
  with self.assertRaisesRegexp(errors.InvalidArgumentError,
                               r"indices = 10 is not in \[0, 5\)"):
    self.evaluate(map_defun_op)
def testMapDefunCancelledCorrectly(self):
  """An error raised for the elements makes the whole MapDefun op fail.

  Every row of a [100, 5] input is indexed out of bounds. The gather error
  message is expected to propagate out of the MapDefun op.
  """

  @function.defun(input_signature=[tensor_spec.TensorSpec([5], dtypes.int64)])
  def defun(x):
    # x has leading dimension 5, this will raise an error
    return array_ops.gather(x, 10)

  row = constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64)
  # Repeat the 5-element row 100 times, giving shape [100, 5].
  tiled = array_ops.tile(array_ops.expand_dims(row, 0), [100, 1])
  map_defun_op = map_defun.map_defun(defun, [tiled], [dtypes.int64], [()])[0]
  with self.assertRaisesRegexp(errors.InvalidArgumentError,
                               r"indices = 10 is not in \[0, 5\)"):
    self.evaluate(map_defun_op)
def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):
  """Inputs whose dimension 0 differs at run time are rejected."""

  @function.Defun(dtypes.int32, dtypes.int32)
  def fn(x, y):
    return x, y

  first = array_ops.placeholder(dtypes.int32)
  second = array_ops.placeholder(dtypes.int32)
  result = map_defun.map_defun(
      fn, [first, second], [dtypes.int32, dtypes.int32], [(), ()])
  # Dimension 0 is 5 for the first input but 3 for the second.
  mismatched_feed = {first: [1, 2, 3, 4, 5], second: [1, 2, 3]}
  with self.cached_session() as sess:
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "All inputs must have the same dimension 0."):
      sess.run(result, feed_dict=mismatched_feed)
def testMapDefunWithUnspecifiedOutputShape(self):
  """Declared output shapes may be unknown, partially known or fully known.

  The three outputs are declared as None, (None,) and (2,) respectively; all
  must still produce the correct values.
  """

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    res = x * 2 + 3
    return (res, res + 1, res + 2)

  data = constant_op.constant(
      [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
  outputs = map_defun.map_defun(
      simple_fn, [data], [dtypes.int32, dtypes.int32, dtypes.int32],
      [None, (None,), (2,)])
  base = data * 2 + 3
  self.assertAllEqual(self.evaluate(outputs[0]), self.evaluate(base))
  self.assertAllEqual(self.evaluate(outputs[1]), self.evaluate(base + 1))
  self.assertAllEqual(self.evaluate(outputs[2]), self.evaluate(base + 2))
def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):
  """Inputs whose dimension 0 differs at run time are rejected."""

  @function.Defun(dtypes.int32, dtypes.int32)
  def fn(x, y):
    return x, y

  first = array_ops.placeholder(dtypes.int32)
  second = array_ops.placeholder(dtypes.int32)
  result = map_defun.map_defun(
      fn, [first, second], [dtypes.int32, dtypes.int32], [(), ()])
  # Dimension 0 is 5 for the first input but 3 for the second.
  mismatched_feed = {first: [1, 2, 3, 4, 5], second: [1, 2, 3]}
  with self.cached_session() as sess:
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "All inputs must have the same dimension 0."):
      sess.run(result, feed_dict=mismatched_feed)
def testMapDefunWithUnspecifiedOutputShape(self):
  """Declared output shapes may be unknown, partially known or fully known.

  The three outputs are declared as None, (None,) and (2,) respectively; all
  must still produce the correct values.
  """

  @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
  def simple_fn(x):
    res = x * 2 + 3
    return (res, res + 1, res + 2)

  data = constant_op.constant(
      [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
  outputs = map_defun.map_defun(
      simple_fn, [data], [dtypes.int32, dtypes.int32, dtypes.int32],
      [None, (None,), (2,)])
  base = data * 2 + 3
  self.assertAllEqual(self.evaluate(outputs[0]), self.evaluate(base))
  self.assertAllEqual(self.evaluate(outputs[1]), self.evaluate(base + 1))
  self.assertAllEqual(self.evaluate(outputs[2]), self.evaluate(base + 2))
def testMapDefunWithVariantTensorAsCaptured(self):
  """A captured variant tensor is emitted once per input element.

  `fn` ignores its argument and returns a serialized SparseTensor captured
  from the enclosing graph. Mapping it over a 2-element input yields two
  copies, which deserialize into one batched SparseTensor.
  """
  sparse = sparse_tensor.SparseTensor(
      indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
  serialized = sparse_ops.serialize_sparse_v2(sparse, out_type=dtypes.variant)

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
  def fn(x):
    del x  # Unused: the output depends only on the captured tensor.
    return serialized

  dummy_input = constant_op.constant([0, 0])
  mapped = map_defun.map_defun(fn, [dummy_input], [dtypes.variant], [None])[0]
  restored = sparse_ops.deserialize_sparse(mapped, dtypes.int32)

  expected = sparse_tensor.SparseTensorValue(
      indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
      values=[1, 2, 1, 2],
      dense_shape=[2, 3, 4])
  self.assertSparseValuesEqual(expected, self.evaluate(restored))
def testMapDefunWithParentCancellation(self):
  """Closing the parent session must cancel a blocked MapDefun op.

  The mapped function dequeues from a freshly created (never filled)
  FIFOQueue. A checked thread runs `self._assert_op_cancelled` on the op
  while the main thread waits 0.2 s and then closes the session.
  """

  # Checks that a cancellation of the parent graph is threaded through to
  # MapDefunOp correctly.
  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
  def simple_fn(x):
    del x
    queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ())
    # Blocking
    return queue.dequeue_many(5)

  inputs = constant_op.constant([1, 2, 3, 4, 5])
  map_defun_op = map_defun.map_defun(
      simple_fn, [inputs], [dtypes.int32], [()])[0]

  with self.cached_session() as sess:
    worker = self.checkedThread(
        self._assert_op_cancelled, args=(sess, map_defun_op))
    worker.start()
    # Presumably gives the worker time to block inside the op before the
    # session is closed; the exact duration is not significant.
    time.sleep(0.2)
    sess.close()
    worker.join()
def testMapDefunWithStrTensor(self):
  """String-typed elements pass through an identity MapDefun unchanged.

  Two copies of a string-serialized SparseTensor are stacked, mapped through
  an identity function, and deserialized into a batched SparseTensor.
  """

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def fn(x):
    return x

  sparse = sparse_tensor.SparseTensor(
      indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
  one_serialized = sparse_ops.serialize_sparse_v2(
      sparse, out_type=dtypes.string)
  batch = array_ops.stack([one_serialized, one_serialized])

  mapped = map_defun.map_defun(fn, [batch], [dtypes.string], [None])[0]
  restored = sparse_ops.deserialize_sparse(mapped, dtypes.int32)

  expected = sparse_tensor.SparseTensorValue(
      indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
      values=[1, 2, 1, 2],
      dense_shape=[2, 3, 4])
  self.assertSparseValuesEqual(expected, self.evaluate(restored))
def benchmarkDefunVsMapFn(self):
  """Benchmarks to compare the performance of MapDefun vs tf.map_fn.

  For each input size, an identity function is mapped over a range of exactly
  that many int32 elements, once through MapDefun and once through
  `map_fn.map_fn`.
  """

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
  def defun(x):
    return array_ops.identity(x)

  def fn(x):
    return array_ops.identity(x)

  for input_size in [10, 100, 1000, 10000]:
    # Build an input with exactly `input_size` elements so that each
    # "size_%d" result measures that many elements. A single fixed-size
    # tensor reused for every size would make the label meaningless.
    elems = math_ops.range(input_size)
    # Scale iterations inversely with input size so every configuration
    # processes the same total number of elements (100000).
    num_iters = 100000 // input_size
    map_defun_op = map_defun.map_defun(defun, [elems], [dtypes.int32], [()])
    map_fn_op = map_fn.map_fn(fn, elems)

    self._run(
        map_defun_op, "with_defun_size_%d" % input_size, num_iters=num_iters)
    self._run(
        map_fn_op, "without_defun_size_%d" % input_size, num_iters=num_iters)
def testMapDefunWithParentCancellation(self):
  """Closing the parent session must cancel a blocked MapDefun op.

  The mapped function dequeues from a freshly created (never filled)
  FIFOQueue. A checked thread runs `self._assert_op_cancelled` on the op
  while the main thread waits 0.2 s and then closes the session.
  """

  # Checks that a cancellation of the parent graph is threaded through to
  # MapDefunOp correctly.
  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
  def simple_fn(x):
    del x
    queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ())
    # Blocking
    return queue.dequeue_many(5)

  inputs = constant_op.constant([1, 2, 3, 4, 5])
  map_defun_op = map_defun.map_defun(
      simple_fn, [inputs], [dtypes.int32], [()])[0]

  with self.cached_session() as sess:
    worker = self.checkedThread(
        self._assert_op_cancelled, args=(sess, map_defun_op))
    worker.start()
    # Presumably gives the worker time to block inside the op before the
    # session is closed; the exact duration is not significant.
    time.sleep(0.2)
    sess.close()
    worker.join()
def benchmarkDefunVsMapFn(self):
  """Benchmarks to compare the performance of MapDefun vs tf.map_fn.

  For each input size, an identity function is mapped over a range of exactly
  that many int32 elements, once through MapDefun and once through
  `functional_ops.map_fn`.
  """

  @function.Defun(dtypes.int32)
  def defun(x):
    return array_ops.identity(x)

  # Named so it is not confused with `functional_ops.map_fn` below.
  def identity_fn(x):
    return array_ops.identity(x)

  for input_size in [10, 100, 1000, 10000]:
    # Build an input with exactly `input_size` elements so that each
    # "size_%d" result measures that many elements. A single fixed-size
    # tensor reused for every size would make the label meaningless.
    elems = math_ops.range(input_size)
    # Scale iterations inversely with input size so every configuration
    # processes the same total number of elements (100000).
    num_iters = 100000 // input_size
    map_defun_op = map_defun.map_defun(defun, [elems], [dtypes.int32], [()])
    map_fn_op = functional_ops.map_fn(identity_fn, elems)

    self._run(
        map_defun_op, "benchmarkMapDefun_size_%d" % input_size,
        num_iters=num_iters)
    self._run(
        map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters)
def benchmarkDefunVsMapFn(self):
  """Benchmarks to compare the performance of MapDefun vs tf.map_fn.

  For each input size, an identity function is mapped over a range of exactly
  that many int32 elements, once through MapDefun and once through
  `functional_ops.map_fn`.
  """

  @function.Defun(dtypes.int32)
  def defun(x):
    return array_ops.identity(x)

  # Named so it is not confused with `functional_ops.map_fn` below.
  def identity_fn(x):
    return array_ops.identity(x)

  for input_size in [10, 100, 1000, 10000]:
    # Build an input with exactly `input_size` elements so that each
    # "size_%d" result measures that many elements. A single fixed-size
    # tensor reused for every size would make the label meaningless.
    elems = math_ops.range(input_size)
    # Scale iterations inversely with input size so every configuration
    # processes the same total number of elements (100000).
    num_iters = 100000 // input_size
    map_defun_op = map_defun.map_defun(defun, [elems], [dtypes.int32], [()])
    map_fn_op = functional_ops.map_fn(identity_fn, elems)

    self._run(
        map_defun_op, "benchmarkMapDefun_size_%d" % input_size,
        num_iters=num_iters)
    self._run(
        map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters)