Exemplo n.º 1
0
    def test_unbound_instance_cache(self):
        """An entry stored via one instance's bound method is visible to others."""

        class Sample(object):
            def method(self):
                pass

        instance_cache = cache.UnboundInstanceCache()
        first = Sample()
        sentinel = object()

        # Store an entry keyed by the bound method of one instance.
        instance_cache[first.method][1] = sentinel

        self.assertTrue(instance_cache.has(first.method, 1))
        self.assertFalse(instance_cache.has(first.method, 2))
        self.assertIs(instance_cache[first.method][1], sentinel)
        self.assertEqual(len(instance_cache), 1)

        # A different instance of the same class resolves to the same entry,
        # and no new entry is created.
        second = Sample()
        self.assertTrue(instance_cache.has(second.method, 1))
        self.assertIs(instance_cache[second.method][1], sentinel)
        self.assertEqual(len(instance_cache), 1)
Exemplo n.º 2
0
from __future__ import division
from __future__ import print_function

import functools
import inspect
import sys
import unittest

from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.pyct import cache
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.utils import ag_logging as logging
from tensorflow.python.eager import function
from tensorflow.python.util import tf_inspect

# Memoizes per-function allowlisting decisions. UnboundInstanceCache keys on
# the underlying unbound function, so bound methods of different instances of
# the same class share a single cache entry.
_ALLOWLIST_CACHE = cache.UnboundInstanceCache()


def _is_of_known_loaded_module(f, module_name):
    mod = sys.modules.get(module_name, None)
    if mod is None:
        return False
    if any(v is not None for v in mod.__dict__.values() if f is v):
        return True
    return False


def _is_known_loaded_type(f, module_name, entity_name):
    """Tests whether the function or method is an instance of a known type."""
    if (module_name not in sys.modules
            or not hasattr(sys.modules[module_name], entity_name)):
Exemplo n.º 3
0
        # dealing with the extra loop increment operation that the for
        # canonicalization creates.
        node = continue_statements.transform(node, ctx)
        node = return_statements.transform(node, ctx)
        if ctx.user.options.uses(converter.Feature.LISTS):
            node = lists.transform(node, ctx)
            node = slices.transform(node, ctx)
        node = call_trees.transform(node, ctx)
        node = control_flow.transform(node, ctx)
        node = conditional_expressions.transform(node, ctx)
        node = logical_expressions.transform(node, ctx)
        return node


# Module-level singleton transpiler reused across all conversions.
_TRANSPILER = AutoGraphTranspiler()
# Memoizes whitelisting decisions per function; UnboundInstanceCache keys on
# the unbound function, so instances of the same class share one entry.
_WHITELIST_CACHE = cache.UnboundInstanceCache()

# Extra symbols exposed to converted code. NOTE(review): presumably populated
# by _create_custom_vars (called from convert below) — the assignment site is
# not visible in this chunk; confirm before relying on it.
custom_vars = None


# TODO(mdan): Superfluous function, remove.
# TODO(mdan): Put these extra fields inside __autograph_info__.
def convert(entity, program_ctx):
    """Applies AutoGraph to entity."""

    if not hasattr(entity, '__code__'):
        raise ValueError('Cannot apply autograph to a function that doesn\'t '
                         'expose a __code__ object. If this is a @tf.function,'
                         ' try passing f.python_function instead.')

    _create_custom_vars(program_ctx)