def test_side_effect_on_tensor(self):
        """Check that an unused tf.Assert still takes effect after the transform.

        The converted function is called with -1, which violates the
        assert's `a > 0` condition, so running it must raise
        InvalidArgumentError carrying the assert's message.
        """
        def test_fn(a):
            tf.Assert(a > 0, ['expected in throw'])
            return a

        node, ctx = self.prepare(test_fn, {})
        node = side_effect_guards.transform(node, ctx)

        self.assertEqual(len(node.body), 1)

        with self.compiled(node, {}, (control_flow_ops.Assert, )) as result:
            with self.cached_session() as sess:
                # assertRaisesRegexp is a deprecated alias of assertRaisesRegex
                # and was removed from unittest in Python 3.12.
                with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                            'expected in throw'):
                    sess.run(result.test_fn(constant_op.constant(-1)))
  def test_side_effect_on_tensor(self):
    """Check that an unused tf.Assert still takes effect after the transform.

    The converted function is called with -1, which violates the assert's
    `a > 0` condition, so running it must raise InvalidArgumentError
    carrying the assert's message.
    """

    def test_fn(a):
      tf.Assert(a > 0, ['expected in throw'])
      return a

    node, ctx = self.prepare(test_fn, {})
    node = side_effect_guards.transform(node, ctx)

    self.assertEqual(len(node.body), 1)

    with self.compiled(node, {}, control_flow_ops.Assert) as result:
      with self.cached_session() as sess:
        # assertRaisesRegexp is a deprecated alias of assertRaisesRegex and
        # was removed from unittest in Python 3.12.
        with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                    'expected in throw'):
          sess.run(result.test_fn(constant_op.constant(-1)))
    def test_side_effect_on_used_variable(self):
        """An assign on a variable that also feeds the return value is applied.

        The variable starts at 2 and must read 3 after one converted call.
        """
        def test_fn(a):
            tf.assign(a, a + 1)
            return a + 1

        tree, context = self.prepare(test_fn, {})
        tree = side_effect_guards.transform(tree, context)
        self.assertEqual(len(tree.body), 1)

        symbols = (state_ops.assign, )
        with self.compiled(tree, {}, symbols) as result:
            with self.cached_session():
                var = variable_scope.get_variable('test', initializer=2)
                self.evaluate(var.initializer)
                output = result.test_fn(var)
                self.evaluate(output)
                # TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
                # Right now it's 3 or 4 based on whether the read is synchronized.
                self.assertEqual(3, self.evaluate(var))
    def test_side_effect_on_return_only_variable(self):
        """Pin current behavior when the assigned variable is returned as-is.

        The converted function returns the variable object itself, and the
        variable is expected to still read its initial value of 2.
        """
        def test_fn(a):
            tf.assign(a, a + 1)
            return a

        tree, context = self.prepare(test_fn, {})
        tree = side_effect_guards.transform(tree, context)
        self.assertEqual(len(tree.body), 1)

        symbols = (state_ops.assign, )
        with self.compiled(tree, {}, symbols) as result:
            with self.cached_session():
                var = variable_scope.get_variable('test', initializer=2)
                self.evaluate(var.initializer)
                output = result.test_fn(var)
                self.evaluate(output)
                # TODO(mdan): Add support for this use case.
                # Right now the variable `a` is not conditioned on the `assign` because
                # there's no way to add control dependencies to a variable object.
                self.assertEqual(2, self.evaluate(var))
  def test_side_effect_on_used_variable(self):
    """An assign on a variable that also feeds the return value is applied.

    The variable starts at 2 and must read 3 after one converted call.
    """

    def test_fn(a):
      tf.assign(a, a + 1)
      return a + 1

    tree, context = self.prepare(test_fn, {})
    tree = side_effect_guards.transform(tree, context)
    self.assertEqual(len(tree.body), 1)

    with self.compiled(tree, {}, state_ops.assign) as result:
      with self.cached_session():
        var = variable_scope.get_variable('test', initializer=2)
        self.evaluate(var.initializer)
        self.evaluate(result.test_fn(var))
        # TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
        # Right now it's 3 or 4 based on whether the read is synchronized.
        self.assertEqual(3, self.evaluate(var))
    def test_multiline_block(self):
        """Two assign_add calls separated by reads are both applied.

        Each call adds 1 to a variable initialized to 2, so the variable
        must read 4 afterwards.
        """
        def test_fn(a):
            tf.assign_add(a, 1)
            b = a + 1
            tf.assign_add(a, 1)
            b += 1
            return b

        tree, context = self.prepare(test_fn, {})
        tree = side_effect_guards.transform(tree, context)
        self.assertEqual(len(tree.body), 1)

        symbols = (state_ops.assign_add, )
        with self.compiled(tree, {}, symbols) as result:
            with self.cached_session():
                var = variable_scope.get_variable('test', initializer=2)
                self.evaluate(var.initializer)
                output = result.test_fn(var)
                self.evaluate(output)
                # TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
                self.assertEqual(4, self.evaluate(var))
  def test_side_effect_on_return_only_variable(self):
    """Pin current behavior when the assigned variable is returned as-is.

    The converted function returns the variable object itself, and the
    variable is expected to still read its initial value of 2.
    """

    def test_fn(a):
      tf.assign(a, a + 1)
      return a

    tree, context = self.prepare(test_fn, {})
    tree = side_effect_guards.transform(tree, context)
    self.assertEqual(len(tree.body), 1)

    with self.compiled(tree, {}, state_ops.assign) as result:
      with self.cached_session():
        var = variable_scope.get_variable('test', initializer=2)
        self.evaluate(var.initializer)
        self.evaluate(result.test_fn(var))
        # TODO(mdan): Add support for this use case.
        # Right now the variable `a` is not conditioned on the `assign` because
        # there's no way to add control dependencies to a variable object.
        self.assertEqual(2, self.evaluate(var))
  def test_multiline_nested_block(self):
    """A side effect nested inside a `with tf.name_scope` block is applied.

    The variable starts at 2 and must read 3 after one converted call.
    """

    def test_fn(a):
      with tf.name_scope('foo'):
        tf.assign(a, a + 1)
        b = a + 1
      return b

    tree, context = self.prepare(test_fn, {})
    tree = side_effect_guards.transform(tree, context)
    self.assertEqual(len(tree.body[0].body), 1)

    symbols = (state_ops.assign, ops.name_scope)
    with self.compiled(tree, {}, *symbols) as result:
      with self.cached_session():
        var = variable_scope.get_variable('test', initializer=2)
        self.evaluate(var.initializer)
        self.evaluate(result.test_fn(var))
        # TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
        self.assertEqual(3, self.evaluate(var))
    def test_multiline_nested_block(self):
        """A side effect nested inside a `with tf.name_scope` block is applied.

        After the transform, `node.body[0].body` must hold one statement, and
        the variable (initialized to 2) must read 3 after one converted call.
        """
        def test_fn(a):
            with tf.name_scope('foo'):
                tf.assign(a, a + 1)
                b = a + 1
            return b

        node, ctx = self.prepare(test_fn, {})
        node = side_effect_guards.transform(node, ctx)

        self.assertEqual(len(node.body[0].body), 1)

        # Pass the symbols as a single tuple, like every other `self.compiled`
        # call in this 4-space copy of the tests. Passing them as extra
        # positional arguments does not fit a `symbols=()` parameter.
        with self.compiled(node, {},
                           (state_ops.assign, ops.name_scope)) as result:
            with self.cached_session():
                v = variable_scope.get_variable('test', initializer=2)
                self.evaluate(v.initializer)
                self.evaluate(result.test_fn(v))
                # TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
                self.assertEqual(3, self.evaluate(v))
  def test_multiline_block_unsafe(self):
    """An assign and an assign_add interleaved with reads are both applied.

    The variable starts at 2 and must read 4 after one converted call.
    """

    def test_fn(a):
      tf.assign(a, a + 1)
      b = a + 1
      tf.assign_add(a, 1)
      c = b + 1
      return c

    tree, context = self.prepare(test_fn, {})
    tree = side_effect_guards.transform(tree, context)
    self.assertEqual(len(tree.body), 1)

    symbols = (state_ops.assign, state_ops.assign_add)
    with self.compiled(tree, {}, *symbols) as result:
      with self.cached_session():
        var = variable_scope.get_variable('test', initializer=2)
        self.evaluate(var.initializer)
        self.evaluate(result.test_fn(var))
        # TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
        self.assertEqual(4, self.evaluate(var))