コード例 #1
0
def class_to_graph(c, conversion_map):
  """Specialization of `entity_to_graph` for classes.

  Converts every member of `c` that `tf_inspect.ismethod` accepts and wraps
  the converted nodes in a new `gast.ClassDef` with no bases, keywords or
  decorators.

  Args:
    c: The class to convert.
    conversion_map: Conversion context; forwarded to `function_to_graph` and
      used to create the namer.

  Returns:
    A tuple `(node, class_name)` holding the generated `gast.ClassDef` and the
    name chosen for the compiled class.

  Raises:
    ValueError: If `c` has no member methods.
  """
  converted_members = {}
  members = tf_inspect.getmembers(c, predicate=tf_inspect.ismethod)
  if not members:
    # The placeholder must actually be filled in, otherwise the message
    # contains a literal "%s" instead of the offending class.
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  class_globals = None
  for _, m in members:
    node, _ = function_to_graph(
        m,
        conversion_map=conversion_map,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c)
    # TODO(mdan): Do not assume all members have the same view of globals.
    # Only the globals of the first converted member are retained.
    if class_globals is None:
      class_globals = six.get_function_globals(m)
    converted_members[m] = node
  namer = conversion_map.new_namer(class_globals)
  class_name = namer.compiled_class_name(c.__name__, c)
  node = gast.ClassDef(
      class_name,
      bases=[],
      keywords=[],
      # Materialize the values: on Python 3 `dict.values()` is a view, not the
      # list that AST node bodies are expected to be.
      body=list(converted_members.values()),
      decorator_list=[])

  return node, class_name
コード例 #2
0
def class_to_graph(c, conversion_map, param_value_hints):
  """Specialization of `object_to_graph` for classes.

  Converts every member of `c` that `tf_inspect.ismethod` accepts and wraps
  the converted nodes in a new `gast.ClassDef` with no bases, keywords or
  decorators.

  Args:
    c: The class to convert.
    conversion_map: Conversion context; forwarded to `function_to_graph` and
      used to create the namer.
    param_value_hints: Mapping of parameter names to hints. It must not
      contain the key `'self'`; this function adds that entry itself, on a
      private copy, before forwarding the hints to `function_to_graph`.

  Returns:
    A tuple `(node, class_name)` holding the generated `gast.ClassDef` and the
    name chosen for the compiled class.

  Raises:
    ValueError: If `c` has no member methods, or if `param_value_hints`
      already contains an entry for `'self'`.
  """
  converted_members = {}
  members = tf_inspect.getmembers(c, predicate=tf_inspect.ismethod)
  if not members:
    # The placeholder must actually be filled in, otherwise the message
    # contains a literal "%s" instead of the offending class.
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  if 'self' in param_value_hints:
    raise ValueError('Hints may not be provided for reserved name "self".')
  # Work on a copy so the caller's dict is not modified; mutating it in place
  # would make a second call with the same hints fail the check above.
  param_value_hints = dict(param_value_hints)
  param_value_hints['self'] = (c.__name__, c)

  class_globals = None
  for _, m in members:
    node, _ = function_to_graph(m, conversion_map, param_value_hints, c)
    # TODO(mdan): Do not assume all members have the same view of globals.
    # Only the globals of the first converted member are retained.
    if class_globals is None:
      class_globals = six.get_function_globals(m)
    converted_members[m] = node
  namer = conversion_map.new_namer(class_globals)
  class_name = namer.compiled_class_name(c.__name__, c)
  node = gast.ClassDef(
      class_name,
      bases=[],
      keywords=[],
      # Materialize the values: on Python 3 `dict.values()` is a view, not the
      # list that AST node bodies are expected to be.
      body=list(converted_members.values()),
      decorator_list=[])

  return node, class_name
コード例 #3
0
ファイル: conversion.py プロジェクト: zxhx/tensorflow
def class_to_graph(c, conversion_map):
    """Specialization of `entity_to_graph` for classes.

    Converts every function or method member of `c` and wraps the converted
    nodes in a new `gast.ClassDef` with no bases, keywords or decorators.

    Args:
        c: The class to convert.
        conversion_map: Conversion context; forwarded to `function_to_graph`
            and used to create the namer.

    Returns:
        A tuple `(node, class_name, class_namespace)`: the generated
        `gast.ClassDef`, the name chosen for the compiled class, and the
        merged namespaces reported for the converted members.

    Raises:
        ValueError: If `c` has no function or method members.
    """

    def _is_method(m):
        # Accept plain functions as well as bound/unbound methods.
        return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

    converted_members = {}
    members = tf_inspect.getmembers(c, predicate=_is_method)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    class_namespace = {}
    for _, m in members:
        node, _, namespace = function_to_graph(
            m,
            conversion_map=conversion_map,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            owner_type=c)
        # `class_namespace` always starts as a dict, so every member's
        # namespace is merged in; later members overwrite duplicate keys.
        class_namespace.update(namespace)
        converted_members[m] = node
    namer = conversion_map.new_namer(class_namespace)
    class_name = namer.compiled_class_name(c.__name__, c)
    node = gast.ClassDef(class_name,
                         bases=[],
                         keywords=[],
                         body=list(converted_members.values()),
                         decorator_list=[])

    return node, class_name, class_namespace
コード例 #4
0
    def visit_ClassDef(self, node):
        """Rebuild a class definition as a `gast.ClassDef`.

        The name, bases, body and decorator list of `node` are each passed
        through `self._visit`; keywords are not carried over. The location of
        `node` is copied onto the result with `ast.copy_location`.
        """
        name = self._visit(node.name)
        bases = self._visit(node.bases)
        body = self._visit(node.body)
        decorators = self._visit(node.decorator_list)

        # The third positional argument is the keywords list, always empty.
        converted = gast.ClassDef(name, bases, [], body, decorators)

        ast.copy_location(converted, node)
        return converted
コード例 #5
0
    def visit_ClassDef(self, node):
        """Rebuild a class definition as a `gast.ClassDef`.

        The name, bases, body and decorator list of `node` are each passed
        through `self._visit`; keywords are not carried over. The location of
        `node` is copied with `gast.copy_location`, after which the end
        position attributes are reset to `None`.
        """
        name = self._visit(node.name)
        bases = self._visit(node.bases)
        body = self._visit(node.body)
        decorators = self._visit(node.decorator_list)

        # The third positional argument is the keywords list, always empty.
        converted = gast.ClassDef(name, bases, [], body, decorators)

        gast.copy_location(converted, node)
        # Clear the end position after the location has been copied.
        converted.end_col_offset = None
        converted.end_lineno = None
        return converted
コード例 #6
0
ファイル: conversion.py プロジェクト: wdma/tensorflow
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes.

    Converts the methods that `c` defines directly, resolves a name for each
    base class and assembles a `gast.ClassDef` in which references to the
    class and its bases are renamed.

    Args:
        c: The class to convert.
        program_ctx: Conversion context; forwarded to `function_to_graph`,
            used to create the namer, and updated with that namer through
            `update_name_map`.

    Returns:
        A tuple `(output_nodes, class_name, class_namespace)`: a list with any
        generated `gast.ImportFrom` nodes followed by the class definition,
        the name chosen for the compiled class, and the merged namespaces
        reported for the converted members.

    Raises:
        ValueError: If `c` has no function or method members.
    """

    def _is_method(m):
        # Accept plain functions as well as bound/unbound methods.
        return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

    converted_members = {}
    members = tf_inspect.getmembers(c, predicate=_is_method)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        node, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            owner_type=c)
        # `class_namespace` always starts as a dict, so every member's
        # namespace is merged in; later members overwrite duplicate keys.
        class_namespace.update(namespace)
        # Only the first returned node is kept for the class body.
        converted_members[m] = node[0]
    namer = program_ctx.new_namer(class_namespace)
    class_name = namer.compiled_class_name(c.__name__, c)

    # TODO(mdan): This needs to be explained more thoroughly.
    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated. Otherwise, it is marked for conversion
    # (as a side effect of the call to namer.compiled_class_name() followed by
    # program_ctx.update_name_map(namer)).
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # Compare by identity: `isinstance(object, base)` is also true for
        # `type` and for ABCs such as `collections.abc.Hashable`, which would
        # silently replace those bases with `object`.
        if base is object:
            base_names.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            # This will trigger a conversion into a class with this name.
            alias = namer.compiled_class_name(base.__name__, base)
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
    program_ctx.update_name_map(namer)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    return output_nodes, class_name, class_namespace
コード例 #7
0
def convert_class_to_ast(c, program_ctx):
    """Specialization of `convert_entity_to_ast` for classes.

    Converts the methods that `c` defines directly, generates an import for
    each base class and assembles a `gast.ClassDef` in which references to
    the class and its bases are renamed.

    Args:
        c: The class to convert.
        program_ctx: Conversion context; forwarded to `convert_func_to_ast`.

    Returns:
        A tuple `(output_nodes, class_name, entity_info)`: a list with the
        generated `gast.ImportFrom` nodes followed by the class definition,
        the name chosen for the converted class, and a
        `transformer.EntityInfo` carrying the merged namespace and the future
        features of the converted methods (source fields are `None`).

    Raises:
        ValueError: If `c` has no function or method members, or if its
            methods were built with differing future features.
        NotImplementedError: If a base class other than `object` is not
            accepted by `is_whitelisted_for_graph`.
    """
    # TODO(mdan): Revisit this altogether. Not sure we still need it.

    def _is_method(m):
        # Accept plain functions as well as bound/unbound methods.
        return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

    converted_members = {}
    members = tf_inspect.getmembers(c, predicate=_is_method)
    if not members:
        raise ValueError('cannot convert %s: no member methods' % c)

    # TODO(mdan): Don't clobber namespaces for each method in one class namespace.
    # The assumption that one namespace suffices for all methods only holds if
    # all methods were defined in the same module.
    # If, instead, functions are imported from multiple modules and then spliced
    # into the class, then each function has its own globals and __future__
    # imports that need to stay separate.

    # For example, C's methods could both have `global x` statements referring to
    # mod1.x and mod2.x, but using one namespace for C would cause a conflict.
    # from mod1 import f1
    # from mod2 import f2
    # class C(object):
    #   method1 = f1
    #   method2 = f2

    class_namespace = {}
    future_features = None
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        # Exactly one node is expected back for each converted method.
        (node, ), _, entity_info = convert_func_to_ast(m,
                                                       program_ctx=program_ctx,
                                                       do_rename=False)
        class_namespace.update(entity_info.namespace)
        converted_members[m] = node

        # TODO(mdan): Similarly check the globals.
        if future_features is None:
            future_features = entity_info.future_features
        elif frozenset(future_features) ^ frozenset(
                entity_info.future_features):
            # A non-empty symmetric difference means the two feature sets
            # differ.
            # Note: we can support this case if ever needed.
            raise ValueError(
                'cannot convert {}: if has methods built with mismatched future'
                ' features: {} and {}'.format(c, future_features,
                                              entity_info.future_features))
    namer = naming.Namer(class_namespace)
    class_name = namer.class_name(c.__name__)

    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated.
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # Compare by identity: `isinstance(object, base)` is also true for
        # `type` and for ABCs such as `collections.abc.Hashable`, which would
        # silently replace those bases with `object`.
        if base is object:
            base_names.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            raise NotImplementedError(
                'Conversion of classes that do not directly extend classes from'
                ' whitelisted modules is temporarily suspended. If this breaks'
                ' existing code please notify the AutoGraph team immediately.')
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    # TODO(mdan): Find a way better than forging this object.
    entity_info = transformer.EntityInfo(source_code=None,
                                         source_file=None,
                                         future_features=future_features,
                                         namespace=class_namespace)

    return output_nodes, class_name, entity_info
コード例 #8
0
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes.

    Converts the methods that `c` defines directly, generates an import for
    each base class and assembles a `gast.ClassDef` in which references to
    the class and its bases are renamed.

    Args:
        c: The class to convert.
        program_ctx: Conversion context; forwarded to `function_to_graph`.

    Returns:
        A tuple `(output_nodes, class_name, class_namespace)`: a list with the
        generated `gast.ImportFrom` nodes followed by the class definition,
        the name chosen for the converted class, and the merged namespaces
        reported for the converted members.

    Raises:
        ValueError: If `c` has no function or method members.
        NotImplementedError: If a base class other than `object` is not
            accepted by `is_whitelisted_for_graph`.
    """
    # TODO(mdan): Revisit this altogether. Not sure we still need it.

    def _is_method(m):
        # Accept plain functions as well as bound/unbound methods.
        return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

    converted_members = {}
    members = tf_inspect.getmembers(c, predicate=_is_method)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        nodes, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            do_rename=False)
        # `class_namespace` always starts as a dict, so every member's
        # namespace is merged in; later members overwrite duplicate keys.
        class_namespace.update(namespace)
        # Only the first returned node is kept for the class body.
        converted_members[m] = nodes[0]
    namer = naming.Namer(class_namespace)
    class_name = namer.class_name(c.__name__)

    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated.
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # Compare by identity: `isinstance(object, base)` is also true for
        # `type` and for ABCs such as `collections.abc.Hashable`, which would
        # silently replace those bases with `object`.
        if base is object:
            base_names.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            raise NotImplementedError(
                'Conversion of classes that do not directly extend classes from'
                ' whitelisted modules is temporarily suspended. If this breaks'
                ' existing code please notify the AutoGraph team immediately.')
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    return output_nodes, class_name, class_namespace