def testMapDefunWithInvalidInput(self):
        """Checks that scalar inputs are rejected, statically when possible.

        A constant has a known scalar shape, so the error is raised while the
        graph is built; a shapeless placeholder defers the same failure to run
        time.
        """

        @function.Defun(dtypes.int32)
        def simple_fn(x):
            return x * 2

        known_scalar = constant_op.constant(2)
        with self.assertRaises(ValueError):
            # Fails at graph construction time for inputs with known shapes.
            map_defun.map_defun(simple_fn, [known_scalar], [dtypes.int32],
                                [None])[0]

        shapeless = array_ops.placeholder(dtypes.int32)
        mapped = map_defun.map_defun(simple_fn, [shapeless], [dtypes.int32],
                                     [None])[0]
        with session.Session() as sess:
            with self.assertRaises(errors.InvalidArgumentError):
                # Feeding a scalar is only detectable once the op runs.
                sess.run(mapped, feed_dict={shapeless: 0})
Exemplo n.º 2
0
  def testMapDefunWithInvalidInput(self):
    """Checks that scalar inputs are rejected, statically when possible.

    A constant has a known scalar shape, so the error is raised while the
    graph is built; a shapeless placeholder defers the failure to run time.
    """

    @function.Defun(dtypes.int32)
    def simple_fn(x):
      return x * 2

    known_scalar = constant_op.constant(2)
    with self.assertRaises(ValueError):
      # Fails at graph construction time for inputs with known shapes.
      map_defun.map_defun(simple_fn, [known_scalar], [dtypes.int32], [None])[0]

    shapeless = array_ops.placeholder(dtypes.int32)
    mapped = map_defun.map_defun(
        simple_fn, [shapeless], [dtypes.int32], [None])[0]
    with session.Session() as sess:
      with self.assertRaises(errors.InvalidArgumentError):
        # Feeding a scalar is only detectable once the op runs.
        sess.run(mapped, feed_dict={shapeless: 0})
Exemplo n.º 3
0
    def testMapDefunPartialShapeInference(self):
        """Checks static shape inference when the leading dimension is unknown.

        Only graph-construction-time shape inference is exercised; the op is
        never run. The result shape should combine the input's unknown leading
        dimension with the declared per-element shape (2,).
        """

        @function.Defun(dtypes.int32)
        def fn(x):
            return x

        # Use int32 so the input matches both the Defun signature and
        # `output_dtypes`; an int64 input would be a type mismatch if the op
        # were ever executed.
        elems = array_ops.placeholder(dtypes.int32, (None, 2))
        result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2, )])
        self.assertEqual(result[0].get_shape().as_list(), [None, 2])
Exemplo n.º 4
0
    def testMapDefunShapeInference(self):
        """Checks that a (3, 2) input with (2,) elements infers shape (3, 2)."""

        @function.Defun(dtypes.int32)
        def fn(x):
            return x

        rows = constant_op.constant([[1, 2], [3, 4], [5, 6]],
                                    dtype=dtypes.int32,
                                    name="data")
        outputs = map_defun.map_defun(fn, [rows], [dtypes.int32], [(2, )])
        # Three rows, each mapped to a (2,)-shaped result.
        self.assertEqual(outputs[0].get_shape(), (3, 2))
Exemplo n.º 5
0
  def testMapDefun_PartialShapeInference(self):
    """Checks static shape inference when the leading dimension is unknown.

    Only graph-construction-time shape inference is exercised; the op is
    never run. The result shape should combine the input's unknown leading
    dimension with the declared per-element shape (2,).
    """

    @function.Defun(dtypes.int32)
    def fn(x):
      return x

    # Use int32 so the input matches both the Defun signature and
    # `output_dtypes`; an int64 input would be a type mismatch if the op were
    # ever executed.
    elems = array_ops.placeholder(dtypes.int32, (None, 2))
    result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
    self.assertEqual(result[0].get_shape().as_list(), [None, 2])
Exemplo n.º 6
0
    def testMapDefunMismatchedTypes(self):
        """Checks that a wrong declared output dtype fails at evaluation."""

        @function.Defun(dtypes.int32)
        def fn(x):
            return math_ops.cast(x, dtypes.float64)

        values = constant_op.constant([1, 2, 3, 4, 5, 6],
                                      dtype=dtypes.int32,
                                      name="data")
        # `fn` produces float64, but int32 is declared for the output.
        outputs = map_defun.map_defun(fn, [values], [dtypes.int32], [()])
        mismatched = outputs[0]
        with self.assertRaises(errors.InvalidArgumentError):
            self.evaluate(mismatched)
Exemplo n.º 7
0
    def testMapDefunRaisesDefunError(self):
        """Checks that an assertion failing inside the function is surfaced."""

        @function.Defun(dtypes.int32)
        def fn(x):
            with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
                return array_ops.identity(x)

        # All elements except 37 satisfy the in-function `x == 0` assertion.
        values = constant_op.constant([0, 0, 0, 37, 0])
        outputs = map_defun.map_defun(fn, [values], [dtypes.int32], [()])
        with self.assertRaises(errors.InvalidArgumentError):
            self.evaluate(outputs)
Exemplo n.º 8
0
  def testMapDefun_ShapeInference(self):
    """Checks that a (3, 2) input with (2,) elements infers shape (3, 2)."""

    @function.Defun(dtypes.int32)
    def fn(x):
      return x

    matrix = constant_op.constant(
        [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
    outputs = map_defun.map_defun(fn, [matrix], [dtypes.int32], [(2,)])
    self.assertEqual(outputs[0].get_shape(), (3, 2))
    def testMapDefunWithWrongOutputShape(self):
        """Checks that a declared output shape mismatch fails at evaluation."""

        @function.Defun(dtypes.int32)
        def simple_fn(x):
            return x * 2 + 3

        rows = constant_op.constant([[1, 2], [3, 4], [5, 6]],
                                    dtype=dtypes.int32,
                                    name="data")
        # Each row maps to a (2,) tensor, yet (1,) is declared.
        outputs = map_defun.map_defun(simple_fn, [rows], [dtypes.int32],
                                      [(1, )])
        with self.assertRaises(errors.InvalidArgumentError):
            self.evaluate(outputs[0])
Exemplo n.º 10
0
    def testMapDefunSimple(self):
        """Checks that mapping `x * 2 + 3` per row equals the batched result."""

        @function.Defun(dtypes.int32)
        def simple_fn(x):
            return x * 2 + 3

        rows = constant_op.constant([[1, 2], [3, 4], [5, 6]],
                                    dtype=dtypes.int32,
                                    name="data")
        mapped = map_defun.map_defun(simple_fn, [rows], [dtypes.int32],
                                     [(2, )])[0]
        reference = rows * 2 + 3
        actual = self.evaluate(mapped)
        wanted = self.evaluate(reference)
        self.assertAllEqual(actual, wanted)
Exemplo n.º 11
0
    def testMapDefunMultipleOutputs(self):
        """Checks a function returning two tensors of different dtypes."""

        @function.Defun(dtypes.int32)
        def fn(x):
            return (x, math_ops.cast(x * 2 + 3, dtypes.float64))

        rows = constant_op.constant([[1, 2], [3, 4], [5, 6]],
                                    dtype=dtypes.int32,
                                    name="data")
        out_dtypes = [dtypes.int32, dtypes.float64]
        out_shapes = [(2, ), (2, )]
        mapped = map_defun.map_defun(fn, [rows], out_dtypes, out_shapes)
        reference = [rows, rows * 2 + 3]
        actual = self.evaluate(mapped)
        wanted = self.evaluate(reference)
        self.assertAllEqual(actual, wanted)
Exemplo n.º 12
0
  def testMapDefunSimple(self):
    """Checks that mapping `x * 2 + 3` per row equals the batched result."""

    @function.Defun(dtypes.int32)
    def simple_fn(x):
      return x * 2 + 3

    rows = constant_op.constant(
        [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
    mapped = map_defun.map_defun(
        simple_fn, [rows], [dtypes.int32], [(2,)])[0]
    reference = rows * 2 + 3
    actual = self.evaluate(mapped)
    wanted = self.evaluate(reference)
    self.assertAllEqual(actual, wanted)
Exemplo n.º 13
0
  def testMapDefunWithWrongOutputShape(self):
    """Checks that a declared output shape mismatch fails at evaluation."""

    @function.Defun(dtypes.int32)
    def simple_fn(x):
      return x * 2 + 3

    rows = constant_op.constant(
        [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
    # Each row maps to a (2,) tensor, yet (1,) is declared.
    outputs = map_defun.map_defun(simple_fn, [rows], [dtypes.int32], [(1,)])
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(outputs[0])
Exemplo n.º 14
0
  def testMapDefunMismatchedTypes(self):
    """Checks that a wrong declared output dtype fails at evaluation."""

    @function.Defun(dtypes.int32)
    def fn(x):
      return math_ops.cast(x, dtypes.float64)

    values = constant_op.constant(
        [1, 2, 3, 4, 5, 6], dtype=dtypes.int32, name="data")
    # `fn` produces float64, but int32 is declared for the output.
    outputs = map_defun.map_defun(fn, [values], [dtypes.int32], [()])
    mismatched = outputs[0]
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(mismatched)
Exemplo n.º 15
0
    def testMapDefunWithDifferentOutputShapeEachRun(self):
        """Checks that an unspecified output shape allows varying feed shapes."""

        @function.Defun(dtypes.int32)
        def simple_fn(x):
            return x * 2 + 3

        feed_input = array_ops.placeholder(dtypes.int32, name="data")
        mapped = map_defun.map_defun(simple_fn, [feed_input], [dtypes.int32],
                                     [None])[0]
        with session.Session() as sess:
            # Rank-1 feed: every element maps to a scalar.
            flat = sess.run(mapped, feed_dict={feed_input: [0]})
            self.assertAllEqual(flat, [3])
            # Rank-2 feed: every row maps to a length-1 vector.
            nested = sess.run(mapped, feed_dict={feed_input: [[0], [1]]})
            self.assertAllEqual(nested, [[3], [5]])
Exemplo n.º 16
0
  def testMapDefunRaisesDefunError(self):
    """Checks that an assertion failing inside the function is surfaced."""

    @function.Defun(dtypes.int32)
    def fn(x):
      with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
        return array_ops.identity(x)

    # All elements except 37 satisfy the in-function `x == 0` assertion.
    values = constant_op.constant([0, 0, 0, 37, 0])
    outputs = map_defun.map_defun(fn, [values], [dtypes.int32], [()])
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(outputs)
Exemplo n.º 17
0
  def testMapDefunWithDifferentOutputShapeEachRun(self):
    """Checks that an unspecified output shape allows varying feed shapes."""

    @function.Defun(dtypes.int32)
    def simple_fn(x):
      return x * 2 + 3

    feed_input = array_ops.placeholder(dtypes.int32, name="data")
    mapped = map_defun.map_defun(
        simple_fn, [feed_input], [dtypes.int32], [None])[0]
    with session.Session() as sess:
      # Rank-1 feed: every element maps to a scalar.
      flat = sess.run(mapped, feed_dict={feed_input: [0]})
      self.assertAllEqual(flat, [3])
      # Rank-2 feed: every row maps to a length-1 vector.
      nested = sess.run(mapped, feed_dict={feed_input: [[0], [1]]})
      self.assertAllEqual(nested, [[3], [5]])
Exemplo n.º 18
0
  def testMapDefunMultipleOutputs(self):
    """Checks a function returning two tensors of different dtypes."""

    @function.Defun(dtypes.int32)
    def fn(x):
      return (x, math_ops.cast(x * 2 + 3, dtypes.float64))

    rows = constant_op.constant(
        [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
    out_dtypes = [dtypes.int32, dtypes.float64]
    out_shapes = [(2,), (2,)]
    mapped = map_defun.map_defun(fn, [rows], out_dtypes, out_shapes)
    reference = [rows, rows * 2 + 3]
    actual = self.evaluate(mapped)
    wanted = self.evaluate(reference)
    self.assertAllEqual(actual, wanted)
Exemplo n.º 19
0
    def testMapDefunReduceDim(self):
        """Checks a mapped function whose output rank is lower than its input's."""

        @function.Defun(dtypes.int32)
        def fn(x):
            return array_ops.gather(x, 0)

        # Each (2,) row reduces to its first element, a scalar.
        rows = constant_op.constant([[1, 2], [3, 4], [5, 6]],
                                    dtype=dtypes.int32,
                                    name="data")
        firsts = map_defun.map_defun(fn, [rows], [dtypes.int32], [()])[0]
        wanted = constant_op.constant([1, 3, 5])
        actual = self.evaluate(firsts)
        self.assertAllEqual(actual, self.evaluate(wanted))
Exemplo n.º 20
0
  def testMapDefunReduceDim(self):
    """Checks a mapped function whose output rank is lower than its input's."""

    @function.Defun(dtypes.int32)
    def fn(x):
      return array_ops.gather(x, 0)

    # Each (2,) row reduces to its first element, a scalar.
    rows = constant_op.constant(
        [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
    firsts = map_defun.map_defun(fn, [rows], [dtypes.int32], [()])[0]
    wanted = constant_op.constant([1, 3, 5])
    actual = self.evaluate(firsts)
    self.assertAllEqual(actual, self.evaluate(wanted))
Exemplo n.º 21
0
    def testMapDefunCancelledCorrectly(self):
        """Checks that an out-of-range gather reports the original error.

        Presumably this verifies that once an iteration fails, the remaining
        iterations are cancelled cleanly and the root cause is what surfaces.
        """

        @function.Defun(dtypes.int64)
        def defun(x):
            # x has leading dimension 5, this will raise an error
            return array_ops.gather(x, 10)

        row = array_ops.expand_dims(
            constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0)
        # 100 identical length-5 rows, so every iteration's gather fails.
        batch = array_ops.tile(row, [100, 1])
        failing = map_defun.map_defun(defun, [batch], [dtypes.int64], [()])[0]
        with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                     r"indices = 10 is not in \[0, 5\)"):
            self.evaluate(failing)
Exemplo n.º 22
0
  def testMapDefunCancelledCorrectly(self):
    """Checks that an out-of-range gather reports the original error.

    Presumably this verifies that once an iteration fails, the remaining
    iterations are cancelled cleanly and the root cause is what surfaces.
    """

    @function.Defun(dtypes.int64)
    def defun(x):
      # x has leading dimension 5, this will raise an error
      return array_ops.gather(x, 10)

    row = array_ops.expand_dims(
        constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0)
    # 100 identical length-5 rows, so every iteration's gather fails.
    batch = array_ops.tile(row, [100, 1])
    failing = map_defun.map_defun(defun, [batch], [dtypes.int64], [()])[0]
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 r"indices = 10 is not in \[0, 5\)"):
      self.evaluate(failing)
Exemplo n.º 23
0
  def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):
    """Checks that inputs with different leading dimensions fail when run.

    Both placeholders have no static shape, so the mismatch can only be
    detected at run time.
    """

    @function.Defun(dtypes.int32, dtypes.int32)
    def fn(x, y):
      return x, y

    first = array_ops.placeholder(dtypes.int32)
    second = array_ops.placeholder(dtypes.int32)
    outputs = map_defun.map_defun(fn, [first, second],
                                  [dtypes.int32, dtypes.int32], [(), ()])
    # Five elements versus three: the leading dimensions disagree.
    mismatched_feed = {first: [1, 2, 3, 4, 5], second: [1, 2, 3]}
    with self.cached_session() as sess:
      with self.assertRaisesWithPredicateMatch(
          errors.InvalidArgumentError,
          "All inputs must have the same dimension 0."):
        sess.run(outputs, feed_dict=mismatched_feed)
Exemplo n.º 24
0
  def testMapDefun_RaisesErrorOnRuntimeShapeMismatch(self):
    """Checks that inputs with different leading dimensions fail when run.

    Both placeholders have no static shape, so the mismatch (5 vs. 3
    elements) can only be detected at run time.
    """

    @function.Defun(dtypes.int32, dtypes.int32)
    def fn(x, y):
      return x, y

    elems1 = array_ops.placeholder(dtypes.int32)
    elems2 = array_ops.placeholder(dtypes.int32)
    result = map_defun.map_defun(fn, [elems1, elems2],
                                 [dtypes.int32, dtypes.int32], [(), ()])
    # `test_session()` is deprecated; `cached_session()` is its documented
    # replacement and matches the other tests here.
    with self.cached_session() as sess:
      with self.assertRaisesWithPredicateMatch(
          errors.InvalidArgumentError,
          "All inputs must have the same dimension 0."):
        sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]})
Exemplo n.º 25
0
    def testMapDefunWithUnspecifiedOutputShape(self):
        """Checks outputs declared with unknown, partial and full shapes."""

        @function.Defun(dtypes.int32)
        def simple_fn(x):
            res = x * 2 + 3
            return (res, res + 1, res + 2)

        rows = constant_op.constant([[1, 2], [3, 4], [5, 6]],
                                    dtype=dtypes.int32,
                                    name="data")
        # Unknown shape, unknown-length vector, fully known (2,) shape.
        declared_shapes = [None, (None, ), (2, )]
        outputs = map_defun.map_defun(simple_fn, [rows],
                                      [dtypes.int32, dtypes.int32, dtypes.int32],
                                      declared_shapes)
        base = rows * 2 + 3
        for output, wanted in zip(outputs, (base, base + 1, base + 2)):
            self.assertAllEqual(self.evaluate(output), self.evaluate(wanted))
Exemplo n.º 26
0
  def testMapDefunWithUnspecifiedOutputShape(self):
    """Checks outputs declared with unknown, partial and full shapes."""

    @function.Defun(dtypes.int32)
    def simple_fn(x):
      res = x * 2 + 3
      return (res, res + 1, res + 2)

    rows = constant_op.constant(
        [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
    # Unknown shape, unknown-length vector, fully known (2,) shape.
    declared_shapes = [None, (None,), (2,)]
    outputs = map_defun.map_defun(simple_fn, [rows],
                                  [dtypes.int32, dtypes.int32, dtypes.int32],
                                  declared_shapes)
    base = rows * 2 + 3
    for output, wanted in zip(outputs, (base, base + 1, base + 2)):
      self.assertAllEqual(self.evaluate(output), self.evaluate(wanted))
Exemplo n.º 27
0
  def testMapDefunWithParentCancellation(self):
    """Checks that closing the parent session cancels a blocked MapDefunOp.

    The mapped function dequeues from a queue that nothing in this test
    enqueues into, so the op blocks until the session is closed.
    """

    @function.Defun(dtypes.int32)
    def simple_fn(x):
      del x
      queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ())
      # Blocking
      return queue.dequeue_many(5)

    values = constant_op.constant([1, 2, 3, 4, 5])
    blocked = map_defun.map_defun(simple_fn, [values], [dtypes.int32], [()])[0]

    with self.cached_session() as sess:
      # `_assert_op_cancelled` (defined elsewhere) presumably runs the op and
      # expects it to end with a cancellation error.
      waiter = self.checkedThread(
          self._assert_op_cancelled, args=(sess, blocked))
      waiter.start()
      # Give the helper thread a moment to start blocking before closing.
      time.sleep(0.1)
      sess.close()
      waiter.join()
Exemplo n.º 28
0
  def testMapDefunWithParentCancellation(self):
    """Checks that closing the parent session cancels a blocked MapDefunOp.

    The mapped function dequeues from a queue that nothing in this test
    enqueues into, so the op blocks until the session is closed.
    """

    @function.Defun(dtypes.int32)
    def simple_fn(x):
      del x
      queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ())
      # Blocking
      return queue.dequeue_many(5)

    values = constant_op.constant([1, 2, 3, 4, 5])
    blocked = map_defun.map_defun(simple_fn, [values], [dtypes.int32], [()])[0]

    with self.cached_session() as sess:
      # `_assert_op_cancelled` (defined elsewhere) presumably runs the op and
      # expects it to end with a cancellation error.
      waiter = self.checkedThread(
          self._assert_op_cancelled, args=(sess, blocked))
      waiter.start()
      # Give the helper thread a moment to start blocking before closing.
      time.sleep(0.1)
      sess.close()
      waiter.join()
Exemplo n.º 29
0
  def benchmarkDefunVsMapFn(self):
    """Benchmarks to compare the performance of MapDefun vs tf.map_fn.

    For each input size, both ops map an identity function over an int32
    vector of `input_size` elements. `num_iters` is scaled so that
    `num_iters * input_size` stays at 100000 across sizes.
    """

    @function.Defun(dtypes.int32)
    def defun(x):
      return array_ops.identity(x)

    def map_fn(x):
      return array_ops.identity(x)

    for input_size in [10, 100, 1000, 10000]:
      # Build the input per size so each benchmark really maps over
      # `input_size` elements, as its reported name claims.
      base = math_ops.range(input_size)
      num_iters = 100000 // input_size
      map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
      map_fn_op = functional_ops.map_fn(map_fn, base)

      self._run(
          map_defun_op,
          "benchmarkMapDefun_size_%d" % input_size,
          num_iters=num_iters)
      self._run(
          map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters)
Exemplo n.º 30
0
  def benchmarkDefunVsMapFn(self):
    """Benchmarks to compare the performance of MapDefun vs tf.map_fn.

    For each input size, both ops map an identity function over an int32
    vector of `input_size` elements. `num_iters` is scaled so that
    `num_iters * input_size` stays at 100000 across sizes.
    """

    @function.Defun(dtypes.int32)
    def defun(x):
      return array_ops.identity(x)

    def map_fn(x):
      return array_ops.identity(x)

    for input_size in [10, 100, 1000, 10000]:
      # Build the input per size so each benchmark really maps over
      # `input_size` elements, as its reported name claims.
      base = math_ops.range(input_size)
      num_iters = 100000 // input_size
      map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
      map_fn_op = functional_ops.map_fn(map_fn, base)

      self._run(
          map_defun_op,
          "benchmarkMapDefun_size_%d" % input_size,
          num_iters=num_iters)
      self._run(
          map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters)