def testMapDefunWithInvalidInput(self):
  """Checks that scalar inputs to map_defun are rejected.

  A scalar constant (statically known shape) is rejected while the graph is
  being built, raising ValueError. A scalar fed through an unshaped
  placeholder is only rejected when the op runs, raising
  InvalidArgumentError.
  """

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    return x * 2

  c = constant_op.constant(2)
  with self.assertRaises(ValueError):
    # Fails at graph construction time for inputs with known shapes. The
    # call is expected to raise, so its result is deliberately not bound.
    map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])
  p = array_ops.placeholder(dtypes.int32)
  r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0]
  with session.Session() as sess:
    with self.assertRaises(errors.InvalidArgumentError):
      # Feeding a scalar defers the failure to run time.
      sess.run(r, feed_dict={p: 0})
def testMapDefunWithInvalidInput(self):
  """Checks that scalar inputs to map_defun are rejected.

  A scalar constant (statically known shape) is rejected while the graph is
  being built, raising ValueError. A scalar fed through an unshaped
  placeholder is only rejected when the op runs, raising
  InvalidArgumentError.
  """

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    return x * 2

  c = constant_op.constant(2)
  with self.assertRaises(ValueError):
    # Fails at graph construction time for inputs with known shapes. The
    # call is expected to raise, so its result is deliberately not bound.
    map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])
  p = array_ops.placeholder(dtypes.int32)
  r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0]
  with session.Session() as sess:
    with self.assertRaises(errors.InvalidArgumentError):
      # Feeding a scalar defers the failure to run time.
      sess.run(r, feed_dict={p: 0})
def testMapDefunPartialShapeInference(self):
  """Checks the static output shape when the leading dimension is unknown.

  The input's leading dimension is None. The inferred output shape is
  expected to keep it unknown, followed by the declared per-element shape
  (2,).
  """

  @function.Defun(dtypes.int32)
  def fn(x):
    return x

  # int32 so the input dtype agrees with both the Defun signature and the
  # declared output dtype below.
  elems = array_ops.placeholder(dtypes.int32, (None, 2))
  result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
  self.assertEqual(result[0].get_shape().as_list(), [None, 2])
def testMapDefunShapeInference(self):
  """Checks the static output shape when the input shape is fully known."""

  @function.Defun(dtypes.int32)
  def fn(x):
    return x

  data = constant_op.constant(
      [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
  outputs = map_defun.map_defun(fn, [data], [dtypes.int32], [(2,)])
  # Three rows, each with declared shape (2,), should give a static (3, 2).
  self.assertEqual(outputs[0].get_shape(), (3, 2))
def testMapDefun_PartialShapeInference(self):
  """Checks the static output shape when the leading dimension is unknown.

  The input's leading dimension is None. The inferred output shape is
  expected to keep it unknown, followed by the declared per-element shape
  (2,).
  """

  @function.Defun(dtypes.int32)
  def fn(x):
    return x

  # int32 so the input dtype agrees with both the Defun signature and the
  # declared output dtype below.
  elems = array_ops.placeholder(dtypes.int32, (None, 2))
  result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
  self.assertEqual(result[0].get_shape().as_list(), [None, 2])
def testMapDefunMismatchedTypes(self):
  """Checks that fn returning an undeclared dtype fails on evaluation.

  fn produces float64 while the op declares int32 as its output dtype. The
  op is built without error; the failure is expected when it is evaluated.
  """

  @function.Defun(dtypes.int32)
  def fn(x):
    return math_ops.cast(x, dtypes.float64)

  data = constant_op.constant(
      [1, 2, 3, 4, 5, 6], dtype=dtypes.int32, name="data")
  mapped = map_defun.map_defun(fn, [data], [dtypes.int32], [()])
  with self.assertRaises(errors.InvalidArgumentError):
    self.evaluate(mapped[0])
def testMapDefunRaisesDefunError(self):
  """Checks that an assertion failing inside the mapped function propagates."""

  @function.Defun(dtypes.int32)
  def fn(x):
    # The identity only runs after the check that x == 0.
    is_zero = check_ops.assert_equal(x, 0)
    with ops.control_dependencies([is_zero]):
      return array_ops.identity(x)

  # The fourth element (37) violates the assertion inside fn.
  values = constant_op.constant([0, 0, 0, 37, 0])
  outputs = map_defun.map_defun(fn, [values], [dtypes.int32], [()])
  with self.assertRaises(errors.InvalidArgumentError):
    self.evaluate(outputs)
def testMapDefun_ShapeInference(self):
  """Checks the static output shape when the input shape is fully known."""

  @function.Defun(dtypes.int32)
  def fn(x):
    return x

  data = constant_op.constant(
      [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
  outputs = map_defun.map_defun(fn, [data], [dtypes.int32], [(2,)])
  # Three rows, each with declared shape (2,), should give a static (3, 2).
  self.assertEqual(outputs[0].get_shape(), (3, 2))
def testMapDefunWithWrongOutputShape(self):
  """Checks that a declared output shape fn does not produce fails on run.

  Each input row has two elements and fn preserves that, but (1,) is
  declared as the per-element output shape.
  """

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    doubled = x * 2
    return doubled + 3

  rows = [[1, 2], [3, 4], [5, 6]]
  data = constant_op.constant(rows, dtype=dtypes.int32, name="data")
  declared_shapes = [(1,)]
  mapped = map_defun.map_defun(
      simple_fn, [data], [dtypes.int32], declared_shapes)[0]
  with self.assertRaises(errors.InvalidArgumentError):
    self.evaluate(mapped)
def testMapDefunSimple(self):
  """Checks that mapping x * 2 + 3 row-wise matches the vectorized result."""

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    doubled = x * 2
    return doubled + 3

  rows = [[1, 2], [3, 4], [5, 6]]
  data = constant_op.constant(rows, dtype=dtypes.int32, name="data")
  mapped = map_defun.map_defun(simple_fn, [data], [dtypes.int32], [(2,)])[0]
  # The same arithmetic applied to the whole tensor serves as the reference.
  reference = data * 2 + 3
  self.assertAllEqual(self.evaluate(mapped), self.evaluate(reference))
def testMapDefunMultipleOutputs(self):
  """Checks a mapped function that returns two outputs of different dtypes."""

  @function.Defun(dtypes.int32)
  def fn(x):
    transformed = math_ops.cast(x * 2 + 3, dtypes.float64)
    return (x, transformed)

  rows = [[1, 2], [3, 4], [5, 6]]
  data = constant_op.constant(rows, dtype=dtypes.int32, name="data")
  out_types = [dtypes.int32, dtypes.float64]
  out_shapes = [(2,), (2,)]
  mapped = map_defun.map_defun(fn, [data], out_types, out_shapes)
  reference = [data, data * 2 + 3]
  self.assertAllEqual(self.evaluate(mapped), self.evaluate(reference))
def testMapDefunSimple(self):
  """Checks that mapping x * 2 + 3 row-wise matches the vectorized result."""

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    doubled = x * 2
    return doubled + 3

  rows = [[1, 2], [3, 4], [5, 6]]
  data = constant_op.constant(rows, dtype=dtypes.int32, name="data")
  mapped = map_defun.map_defun(simple_fn, [data], [dtypes.int32], [(2,)])[0]
  # The same arithmetic applied to the whole tensor serves as the reference.
  reference = data * 2 + 3
  self.assertAllEqual(self.evaluate(mapped), self.evaluate(reference))
def testMapDefunWithWrongOutputShape(self):
  """Checks that a declared output shape fn does not produce fails on run.

  Each input row has two elements and fn preserves that, but (1,) is
  declared as the per-element output shape.
  """

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    doubled = x * 2
    return doubled + 3

  rows = [[1, 2], [3, 4], [5, 6]]
  data = constant_op.constant(rows, dtype=dtypes.int32, name="data")
  declared_shapes = [(1,)]
  mapped = map_defun.map_defun(
      simple_fn, [data], [dtypes.int32], declared_shapes)[0]
  with self.assertRaises(errors.InvalidArgumentError):
    self.evaluate(mapped)
def testMapDefunMismatchedTypes(self):
  """Checks that fn returning an undeclared dtype fails on evaluation.

  fn produces float64 while the op declares int32 as its output dtype. The
  op is built without error; the failure is expected when it is evaluated.
  """

  @function.Defun(dtypes.int32)
  def fn(x):
    return math_ops.cast(x, dtypes.float64)

  data = constant_op.constant(
      [1, 2, 3, 4, 5, 6], dtype=dtypes.int32, name="data")
  mapped = map_defun.map_defun(fn, [data], [dtypes.int32], [()])
  with self.assertRaises(errors.InvalidArgumentError):
    self.evaluate(mapped[0])
def testMapDefunWithDifferentOutputShapeEachRun(self):
  """Checks that one op handles differently shaped feeds across runs.

  The output shape is declared as None, so each run may produce a result
  whose rank follows the fed input.
  """

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    return x * 2 + 3

  inputs = array_ops.placeholder(dtypes.int32, name="data")
  mapped = map_defun.map_defun(simple_fn, [inputs], [dtypes.int32], [None])[0]
  cases = (
      ([0], [3]),
      ([[0], [1]], [[3], [5]]),
  )
  with session.Session() as sess:
    for fed, want in cases:
      self.assertAllEqual(sess.run(mapped, feed_dict={inputs: fed}), want)
def testMapDefunRaisesDefunError(self):
  """Checks that an assertion failing inside the mapped function propagates."""

  @function.Defun(dtypes.int32)
  def fn(x):
    # The identity only runs after the check that x == 0.
    is_zero = check_ops.assert_equal(x, 0)
    with ops.control_dependencies([is_zero]):
      return array_ops.identity(x)

  # The fourth element (37) violates the assertion inside fn.
  values = constant_op.constant([0, 0, 0, 37, 0])
  outputs = map_defun.map_defun(fn, [values], [dtypes.int32], [()])
  with self.assertRaises(errors.InvalidArgumentError):
    self.evaluate(outputs)
def testMapDefunWithDifferentOutputShapeEachRun(self):
  """Checks that one op handles differently shaped feeds across runs.

  The output shape is declared as None, so each run may produce a result
  whose rank follows the fed input.
  """

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    return x * 2 + 3

  inputs = array_ops.placeholder(dtypes.int32, name="data")
  mapped = map_defun.map_defun(simple_fn, [inputs], [dtypes.int32], [None])[0]
  cases = (
      ([0], [3]),
      ([[0], [1]], [[3], [5]]),
  )
  with session.Session() as sess:
    for fed, want in cases:
      self.assertAllEqual(sess.run(mapped, feed_dict={inputs: fed}), want)
def testMapDefunMultipleOutputs(self):
  """Checks a mapped function that returns two outputs of different dtypes."""

  @function.Defun(dtypes.int32)
  def fn(x):
    transformed = math_ops.cast(x * 2 + 3, dtypes.float64)
    return (x, transformed)

  rows = [[1, 2], [3, 4], [5, 6]]
  data = constant_op.constant(rows, dtype=dtypes.int32, name="data")
  out_types = [dtypes.int32, dtypes.float64]
  out_shapes = [(2,), (2,)]
  mapped = map_defun.map_defun(fn, [data], out_types, out_shapes)
  reference = [data, data * 2 + 3]
  self.assertAllEqual(self.evaluate(mapped), self.evaluate(reference))
def testMapDefunReduceDim(self):
  """Checks a mapped function whose output rank differs from its input rank.

  fn takes the first element of each row, so each scalar output comes from a
  length-2 row.
  """

  @function.Defun(dtypes.int32)
  def fn(x):
    return array_ops.gather(x, 0)

  rows = [[1, 2], [3, 4], [5, 6]]
  data = constant_op.constant(rows, dtype=dtypes.int32, name="data")
  firsts = map_defun.map_defun(fn, [data], [dtypes.int32], [()])[0]
  reference = constant_op.constant([1, 3, 5])
  self.assertAllEqual(self.evaluate(firsts), self.evaluate(reference))
def testMapDefunReduceDim(self):
  """Checks a mapped function whose output rank differs from its input rank.

  fn takes the first element of each row, so each scalar output comes from a
  length-2 row.
  """

  @function.Defun(dtypes.int32)
  def fn(x):
    return array_ops.gather(x, 0)

  rows = [[1, 2], [3, 4], [5, 6]]
  data = constant_op.constant(rows, dtype=dtypes.int32, name="data")
  firsts = map_defun.map_defun(fn, [data], [dtypes.int32], [()])[0]
  reference = constant_op.constant([1, 3, 5])
  self.assertAllEqual(self.evaluate(firsts), self.evaluate(reference))
def testMapDefunCancelledCorrectly(self):
  """Checks that an error inside the mapped function fails the whole op.

  The input is a [100, 5] tensor (a length-5 vector tiled 100 times). Every
  invocation gathers index 10 from a length-5 row and therefore fails. The
  test expects the original gather error message to surface. Judging by the
  test name, this presumably also covers cancelling the other in-flight
  invocations — confirm against the MapDefun kernel.
  """

  @function.Defun(dtypes.int64)
  def defun(x):
    # x has leading dimension 5, this will raise an error
    return array_ops.gather(x, 10)

  c = array_ops.tile(
      array_ops.expand_dims(
          constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0),
      [100, 1])
  map_defun_op = map_defun.map_defun(defun, [c], [dtypes.int64], [()])[0]
  # assertRaisesRegexp is a deprecated unittest alias (removed in Python
  # 3.12). assertRaisesWithPredicateMatch searches the error message with a
  # string argument used as a regex.
  with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                           r"indices = 10 is not in \[0, 5\)"):
    self.evaluate(map_defun_op)
def testMapDefunCancelledCorrectly(self):
  """Checks that an error inside the mapped function fails the whole op.

  The input is a [100, 5] tensor (a length-5 vector tiled 100 times). Every
  invocation gathers index 10 from a length-5 row and therefore fails. The
  test expects the original gather error message to surface. Judging by the
  test name, this presumably also covers cancelling the other in-flight
  invocations — confirm against the MapDefun kernel.
  """

  @function.Defun(dtypes.int64)
  def defun(x):
    # x has leading dimension 5, this will raise an error
    return array_ops.gather(x, 10)

  c = array_ops.tile(
      array_ops.expand_dims(
          constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0),
      [100, 1])
  map_defun_op = map_defun.map_defun(defun, [c], [dtypes.int64], [()])[0]
  # assertRaisesRegexp is a deprecated unittest alias (removed in Python
  # 3.12). assertRaisesWithPredicateMatch searches the error message with a
  # string argument used as a regex.
  with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                           r"indices = 10 is not in \[0, 5\)"):
    self.evaluate(map_defun_op)
def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):
  """Checks that inputs with different leading dimensions fail at run time.

  Both placeholders have no static shape, so the mismatch (5 vs 3 elements)
  is only visible once values are fed.
  """

  @function.Defun(dtypes.int32, dtypes.int32)
  def fn(x, y):
    return x, y

  lhs = array_ops.placeholder(dtypes.int32)
  rhs = array_ops.placeholder(dtypes.int32)
  out_types = [dtypes.int32, dtypes.int32]
  result = map_defun.map_defun(fn, [lhs, rhs], out_types, [(), ()])
  feeds = {lhs: [1, 2, 3, 4, 5], rhs: [1, 2, 3]}
  with self.cached_session() as sess:
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "All inputs must have the same dimension 0."):
      sess.run(result, feed_dict=feeds)
def testMapDefun_RaisesErrorOnRuntimeShapeMismatch(self):
  """Checks that inputs with different leading dimensions fail at run time.

  Both placeholders have no static shape, so the mismatch (5 vs 3 elements)
  is only visible once values are fed.
  """

  @function.Defun(dtypes.int32, dtypes.int32)
  def fn(x, y):
    return x, y

  elems1 = array_ops.placeholder(dtypes.int32)
  elems2 = array_ops.placeholder(dtypes.int32)
  result = map_defun.map_defun(fn, [elems1, elems2],
                               [dtypes.int32, dtypes.int32], [(), ()])
  # test_session() is deprecated in TensorFlow's test utilities;
  # cached_session() is its replacement.
  with self.cached_session() as sess:
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "All inputs must have the same dimension 0."):
      sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]})
def testMapDefunWithUnspecifiedOutputShape(self):
  """Checks outputs declared with unknown or partially known shapes.

  The three outputs are declared with shapes None (unknown rank), (None,)
  (unknown length) and (2,) (fully known). All of them should still produce
  the correct values.
  """

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    base = x * 2 + 3
    return (base, base + 1, base + 2)

  rows = [[1, 2], [3, 4], [5, 6]]
  data = constant_op.constant(rows, dtype=dtypes.int32, name="data")
  out_types = [dtypes.int32, dtypes.int32, dtypes.int32]
  out_shapes = [None, (None,), (2,)]
  mapped = map_defun.map_defun(simple_fn, [data], out_types, out_shapes)
  expected = data * 2 + 3
  targets = [expected, expected + 1, expected + 2]
  for idx, target in enumerate(targets):
    self.assertAllEqual(self.evaluate(mapped[idx]), self.evaluate(target))
def testMapDefunWithUnspecifiedOutputShape(self):
  """Checks outputs declared with unknown or partially known shapes.

  The three outputs are declared with shapes None (unknown rank), (None,)
  (unknown length) and (2,) (fully known). All of them should still produce
  the correct values.
  """

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    base = x * 2 + 3
    return (base, base + 1, base + 2)

  rows = [[1, 2], [3, 4], [5, 6]]
  data = constant_op.constant(rows, dtype=dtypes.int32, name="data")
  out_types = [dtypes.int32, dtypes.int32, dtypes.int32]
  out_shapes = [None, (None,), (2,)]
  mapped = map_defun.map_defun(simple_fn, [data], out_types, out_shapes)
  expected = data * 2 + 3
  targets = [expected, expected + 1, expected + 2]
  for idx, target in enumerate(targets):
    self.assertAllEqual(self.evaluate(mapped[idx]), self.evaluate(target))
def testMapDefunWithParentCancellation(self):
  """Checks that cancelling the parent graph is threaded through to MapDefunOp.

  The mapped function dequeues from a queue that nothing enqueues into, so
  the run blocks. A worker thread runs the op via _assert_op_cancelled
  (defined elsewhere; presumably it asserts a cancellation error — confirm).
  The main thread then closes the session.
  """

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    del x
    queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ())
    # Blocking
    return queue.dequeue_many(5)

  inputs = constant_op.constant([1, 2, 3, 4, 5])
  blocked_op = map_defun.map_defun(
      simple_fn, [inputs], [dtypes.int32], [()])[0]
  with self.cached_session() as sess:
    worker = self.checkedThread(
        self._assert_op_cancelled, args=(sess, blocked_op))
    worker.start()
    # Give the worker a moment to enter the blocking run before closing.
    time.sleep(0.1)
    sess.close()
    worker.join()
def testMapDefunWithParentCancellation(self):
  """Checks that cancelling the parent graph is threaded through to MapDefunOp.

  The mapped function dequeues from a queue that nothing enqueues into, so
  the run blocks. A worker thread runs the op via _assert_op_cancelled
  (defined elsewhere; presumably it asserts a cancellation error — confirm).
  The main thread then closes the session.
  """

  @function.Defun(dtypes.int32)
  def simple_fn(x):
    del x
    queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ())
    # Blocking
    return queue.dequeue_many(5)

  inputs = constant_op.constant([1, 2, 3, 4, 5])
  blocked_op = map_defun.map_defun(
      simple_fn, [inputs], [dtypes.int32], [()])[0]
  with self.cached_session() as sess:
    worker = self.checkedThread(
        self._assert_op_cancelled, args=(sess, blocked_op))
    worker.start()
    # Give the worker a moment to enter the blocking run before closing.
    time.sleep(0.1)
    sess.close()
    worker.join()
def benchmarkDefunVsMapFn(self):
  """Benchmarks to compare the performance of MapDefun vs tf.map_fn.

  For each input size, both ops map an identity function over a vector of
  input_size elements. num_iters is scaled so that num_iters * input_size
  stays at 100000.
  """

  @function.Defun(dtypes.int32)
  def defun(x):
    return array_ops.identity(x)

  def map_fn(x):
    return array_ops.identity(x)

  for input_size in [10, 100, 1000, 10000]:
    # Build the input per size so its length matches the size in the
    # benchmark name. A single range(100) shared across iterations would
    # make every entry measure the same 100-element workload.
    base = math_ops.range(input_size)
    num_iters = 100000 // input_size
    map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
    map_fn_op = functional_ops.map_fn(map_fn, base)

    self._run(
        map_defun_op, "benchmarkMapDefun_size_%d" % input_size,
        num_iters=num_iters)
    self._run(
        map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters)
def benchmarkDefunVsMapFn(self):
  """Benchmarks to compare the performance of MapDefun vs tf.map_fn.

  For each input size, both ops map an identity function over a vector of
  input_size elements. num_iters is scaled so that num_iters * input_size
  stays at 100000.
  """

  @function.Defun(dtypes.int32)
  def defun(x):
    return array_ops.identity(x)

  def map_fn(x):
    return array_ops.identity(x)

  for input_size in [10, 100, 1000, 10000]:
    # Build the input per size so its length matches the size in the
    # benchmark name. A single range(100) shared across iterations would
    # make every entry measure the same 100-element workload.
    base = math_ops.range(input_size)
    num_iters = 100000 // input_size
    map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
    map_fn_op = functional_ops.map_fn(map_fn, base)

    self._run(
        map_defun_op, "benchmarkMapDefun_size_%d" % input_size,
        num_iters=num_iters)
    self._run(
        map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters)