Пример #1
0
def _registered_kl(type_a, type_b):
  """Get the KL function registered for classes a and b.

  Every pairing of an ancestor of `type_a` with an ancestor of `type_b` is
  looked up in `_DIVERGENCES`. The registered function whose pair has the
  smallest combined MRO depth is returned; among equally deep pairs the one
  reached first (type_a's MRO drives the outer loop) is kept.

  Args:
    type_a: The first class.
    type_b: The second class.

  Returns:
    The registered KL function, or None if no ancestor pair is registered.
  """
  ancestors_a = tf_inspect.getmro(type_a)
  ancestors_b = tf_inspect.getmro(type_b)
  best_fn = None
  best_depth = None
  for depth_a, ancestor_a in enumerate(ancestors_a):
    for depth_b, ancestor_b in enumerate(ancestors_b):
      depth = depth_a + depth_b
      fn = _DIVERGENCES.get((ancestor_a, ancestor_b), None)
      # Once a function has been found, only a strictly shallower
      # registration may replace it.
      if best_fn and not (fn and depth < best_depth):
        continue
      best_fn, best_depth = fn, depth
  return best_fn
def _SanitizedMRO(obj):
  """Get a list of superclasses with minimal amount of non-TF classes.

  Based on many parameters like python version, OS, protobuf implementation
  or changes in google core libraries the list of superclasses of a class
  can change. We only return the first non-TF class to be robust to non API
  affecting changes. The Method Resolution Order returned by `tf_inspect.getmro`
  is still maintained in the return value.

  Classes named `_NewClass` (created by the @deprecated_alias decorator) are
  skipped. The walk stops right after the first class whose normalized name
  does not contain 'tensorflow', or right after `StubOutForTesting`.

  Args:
    obj: The class whose superclasses should be listed.

  Returns:
    list of strings, the `_NormalizeType(str(cls))` representation of each
    retained class, in MRO order.
  """
  return_list = []
  for cls in tf_inspect.getmro(obj):
    if cls.__name__ == '_NewClass':
      # Ignore class created by @deprecated_alias decorator.
      continue
    str_repr = _NormalizeType(str(cls))
    return_list.append(str_repr)
    if 'tensorflow' not in str_repr:
      break

    # Hack - tensorflow.test.StubOutForTesting may or may not be type <object>
    # depending on the environment. To avoid inconsistency, break after we add
    # StubOutForTesting to the return_list.
    if 'StubOutForTesting' in str_repr:
      break

  return return_list
Пример #3
0
def _SanitizedMRO(obj):
    """Get a list of superclasses with minimal amount of non-TF classes.

    Based on many parameters like python version, OS, protobuf implementation
    or changes in google core libraries the list of superclasses of a class
    can change. We only return the first non-TF class to be robust to non API
    affecting changes. The Method Resolution Order returned by
    `tf_inspect.getmro` is still maintained in the return value.

    Classes named `_NewClass` (created by the @deprecated_alias decorator) are
    skipped. The walk stops right after the first class whose normalized name
    does not contain 'tensorflow', or right after `StubOutForTesting`.

    Args:
      obj: The class whose superclasses should be listed.

    Returns:
      list of strings, the `_NormalizeType(str(cls))` representation of each
      retained class, in MRO order.
    """
    return_list = []
    for cls in tf_inspect.getmro(obj):
        if cls.__name__ == '_NewClass':
            # Ignore class created by @deprecated_alias decorator.
            continue
        str_repr = _NormalizeType(str(cls))
        return_list.append(str_repr)
        if 'tensorflow' not in str_repr:
            break

        # Hack - tensorflow.test.StubOutForTesting may or may not be type <object>
        # depending on the environment. To avoid inconsistency, break after we add
        # StubOutForTesting to the return_list.
        if 'StubOutForTesting' in str_repr:
            break

    return return_list
Пример #4
0
def is_supported_type_for_deprecation(symbol):
    """Return whether `symbol` may be wrapped by a deprecation alias.

    Functions are always supported. Classes are supported unless they derive
    from Exception (users must be able to `except` the same type that was
    raised, so it must not be wrapped) or have `tuple` in their MRO (meant to
    exclude namedtuple subclasses; it excludes every tuple subclass).
    Anything else is unsupported.
    """
    if tf_inspect.isfunction(symbol):
        return True
    if not tf_inspect.isclass(symbol):
        return False
    if issubclass(symbol, Exception):
        return False
    return tuple not in tf_inspect.getmro(symbol)
Пример #5
0
def is_supported_type_for_deprecation(symbol):
  """Return whether `symbol` may be wrapped by a deprecation alias.

  Functions are always supported. Classes are supported unless they derive
  from Exception (users must be able to `except` the same type that was
  raised, so it must not be wrapped) or have `tuple` in their MRO (meant to
  exclude namedtuple subclasses; it excludes every tuple subclass).
  Anything else is unsupported.
  """
  if tf_inspect.isfunction(symbol):
    return True
  if not tf_inspect.isclass(symbol):
    return False
  if issubclass(symbol, Exception):
    return False
  return tuple not in tf_inspect.getmro(symbol)
Пример #6
0
def getdefiningclass(m, owner_class):
  """Resolves the class (e.g. one of the superclasses) that defined a method.

  Walks `owner_class`'s MRO from most- to least-derived and remembers the
  last class whose attribute of the same name unwraps (via
  `six.get_unbound_function`) to the same function as `m`.

  Args:
    m: The method to locate.
    owner_class: The class whose MRO is searched.

  Returns:
    The least-derived matching class, or `owner_class` if none matches.
  """
  target = six.get_unbound_function(m)
  name = target.__name__
  result = owner_class
  for klass in tf_inspect.getmro(owner_class):
    if not hasattr(klass, name):
      continue
    if six.get_unbound_function(getattr(klass, name)) == target:
      result = klass
  return result
Пример #7
0
def _is_test_class(obj):
    """Check if arbitrary object is a test class (not a test object!).

    Args:
      obj: An arbitrary object from within a module.

    Returns:
      True iff obj is a class and some class in its MRO is named "TestCase".
      Matching on the class name rather than a specific type lets this
      recognize tests written with different underlying test libraries.
    """
    return (tf_inspect.isclass(obj)
            and 'TestCase' in (p.__name__ for p in tf_inspect.getmro(obj)))
def _registered_cholesky(type_a):
    """Get the Cholesky function registered for class a.

    Scans `type_a`'s MRO from the class itself outward and returns the
    `_CHOLESKY_DECOMPS` entry of the nearest ancestor that has one.

    Args:
      type_a: The class to resolve.

    Returns:
      The registered Cholesky function, or None if no ancestor is registered.
    """
    candidate = None
    for ancestor in tf_inspect.getmro(type_a):
        candidate = _CHOLESKY_DECOMPS.get(ancestor, None)
        # MRO depth only grows along the walk, so the first hit is also the
        # nearest one; nothing later could displace it.
        if candidate:
            return candidate
    return candidate
Пример #9
0
def _registered_function(type_list, registry):
  """Given a list of classes, finds the most specific function registered.

  Considers every combination formed by picking one ancestor from each
  type's MRO. Among the combinations present as keys in `registry`, the one
  with the smallest summed MRO depth wins; ties go to the combination that
  comes first in `itertools.product` order.

  Args:
    type_list: Sequence of classes.
    registry: Mapping keyed by tuples of classes, one class per entry of
      `type_list`.

  Returns:
    The registered value for the best combination, or None if no combination
    is registered.
  """
  hierarchies = [tf_inspect.getmro(t) for t in type_list]
  best_key = None
  best_distance = None
  # Only registered combinations compete; this avoids ranking unregistered
  # ones with a magic "large" distance that a deep hierarchy could exceed.
  for combination in itertools.product(*(enumerate(h) for h in hierarchies)):
    key = tuple(cls for _, cls in combination)
    if key not in registry:
      continue
    distance = sum(depth for depth, _ in combination)
    if best_distance is None or distance < best_distance:
      best_key, best_distance = key, distance
  if best_distance is None:
    return None
  return registry[best_key]
Пример #10
0
def _registered_function(type_list, registry):
  """Given a list of classes, finds the most specific function registered.

  Considers every combination formed by picking one ancestor from each
  type's MRO. Among the combinations present as keys in `registry`, the one
  with the smallest summed MRO depth wins; ties go to the combination that
  comes first in `itertools.product` order.

  Args:
    type_list: Sequence of classes.
    registry: Mapping keyed by tuples of classes, one class per entry of
      `type_list`.

  Returns:
    The registered value for the best combination, or None if no combination
    is registered.
  """
  hierarchies = [tf_inspect.getmro(t) for t in type_list]
  best_key = None
  best_distance = None
  # Only registered combinations compete; this avoids ranking unregistered
  # ones with a magic "large" distance that a deep hierarchy could exceed.
  for combination in itertools.product(*(enumerate(h) for h in hierarchies)):
    key = tuple(cls for _, cls in combination)
    if key not in registry:
      continue
    distance = sum(depth for depth, _ in combination)
    if best_distance is None or distance < best_distance:
      best_key, best_distance = key, distance
  if best_distance is None:
    return None
  return registry[best_key]
Пример #11
0
def _wrap_optimizer(opt, loss_scale, use_v1_behavior):
    """Wraps an optimizer with a LossScaleOptimizer.

    Args:
      opt: The optimizer to wrap. Either an `optimizer.Optimizer` (V1) or an
        instance of a class that has a base class named `OptimizerV2` (Keras).
      loss_scale: Forwarded unchanged to the wrapper's constructor.
      use_v1_behavior: Only selects which error message is raised when `opt`
        is not a supported optimizer.

    Returns:
      A `MixedPrecisionLossScaleOptimizer` wrapping a V1 optimizer, or a Keras
      `LossScaleOptimizer` wrapping an `OptimizerV2`.

    Raises:
      ValueError: If `opt` is already a loss-scale optimizer, or is not a
        supported optimizer type.
    """
    if isinstance(opt,
                  loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer):
        raise ValueError('"opt" must not already be an instance of a '
                         'MixedPrecisionLossScaleOptimizer. '
                         '`enable_mixed_precision_graph_rewrite` will '
                         'automatically wrap the optimizer with a '
                         'MixedPrecisionLossScaleOptimizer.')
    # To avoid a circular dependency, we cannot depend on tf.keras. Because
    # LossScaleOptimizer is in Keras, we cannot use isinstance, so instead check
    # the class name.
    if opt.__class__.__name__ == 'LossScaleOptimizer':
        raise ValueError('"opt" must not already be an instance of a '
                         'LossScaleOptimizer. '
                         '`enable_mixed_precision_graph_rewrite` will '
                         'automatically wrap the optimizer with a '
                         'LossScaleOptimizer.')

    if isinstance(opt, optimizer.Optimizer):
        # For convenience, we allow the V2 version of this function to wrap the V1
        # optimizer, even though we do not document this.
        return loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer(
            opt, loss_scale)

    # Because we cannot depend on tf.keras, we see if `opt` is an instance of the
    # Keras OptimizerV2 class by checking the subclass names.
    is_optimizer_v2 = any(cls.__name__ == 'OptimizerV2'
                          for cls in tf_inspect.getmro(opt.__class__))

    if is_optimizer_v2:
        # Because we cannot depend on tf.keras, we cannot unconditionally do this
        # import. But since `opt` is a Keras OptimizerV2, we know keras is
        # importable, so it is safe to do this import. (Technically, it's possible
        # to have a dependency on OptimizerV2 and not LossScaleOptimizer, but this
        # is not done in practice).
        from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as loss_scale_optimizer_v2  # pylint: disable=g-import-not-at-top
        return loss_scale_optimizer_v2.LossScaleOptimizer(opt, loss_scale)

    # Format with a 1-tuple: a bare `% opt` raises TypeError (or formats only
    # the first element) when `opt` itself is a tuple, masking the ValueError.
    if use_v1_behavior:
        raise ValueError(
            '"opt" must be an instance of a tf.train.Optimizer or a '
            'tf.keras.optimizers.Optimizer, but got: %s' % (opt,))
    raise ValueError('"opt" must be an instance of a '
                     'tf.keras.optimizers.Optimizer, but got: %s' % (opt,))
Пример #12
0
def _wrap_optimizer(opt, loss_scale, use_v1_behavior):
  """Wraps an optimizer with a LossScaleOptimizer.

  Args:
    opt: The optimizer to wrap. Either an `optimizer.Optimizer` (V1) or an
      instance of a class that has a base class named `OptimizerV2` (Keras).
    loss_scale: Forwarded unchanged to the wrapper's constructor.
    use_v1_behavior: Only selects which error message is raised when `opt` is
      not a supported optimizer.

  Returns:
    A `MixedPrecisionLossScaleOptimizer` wrapping a V1 optimizer, or a Keras
    `LossScaleOptimizer` wrapping an `OptimizerV2`.

  Raises:
    ValueError: If `opt` is already a loss-scale optimizer, or is not a
      supported optimizer type.
  """
  if isinstance(opt, loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer):
    raise ValueError('"opt" must not already be an instance of a '
                     'MixedPrecisionLossScaleOptimizer. '
                     '`enable_mixed_precision_graph_rewrite` will '
                     'automatically wrap the optimizer with a '
                     'MixedPrecisionLossScaleOptimizer.')
  # To avoid a circular dependency, we cannot depend on tf.keras. Because
  # LossScaleOptimizer is in Keras, we cannot use isinstance, so instead check
  # the class name.
  if opt.__class__.__name__ == 'LossScaleOptimizer':
    raise ValueError('"opt" must not already be an instance of a '
                     'LossScaleOptimizer. '
                     '`enable_mixed_precision_graph_rewrite` will '
                     'automatically wrap the optimizer with a '
                     'LossScaleOptimizer.')

  if isinstance(opt, optimizer.Optimizer):
    # For convenience, we allow the V2 version of this function to wrap the V1
    # optimizer, even though we do not document this.
    return loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer(opt,
                                                                    loss_scale)

  # Because we cannot depend on tf.keras, we see if `opt` is an instance of the
  # Keras OptimizerV2 class by checking the subclass names.
  is_optimizer_v2 = any(cls.__name__ == 'OptimizerV2'
                        for cls in tf_inspect.getmro(opt.__class__))

  if is_optimizer_v2:
    # Because we cannot depend on tf.keras, we cannot unconditionally do this
    # import. But since `opt` is a Keras OptimizerV2, we know keras is
    # importable, so it is safe to do this import. (Technically, it's possible
    # to have a dependency on OptimizerV2 and not LossScaleOptimizer, but this
    # is not done in practice).
    from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as loss_scale_optimizer_v2  # pylint: disable=g-import-not-at-top
    return loss_scale_optimizer_v2.LossScaleOptimizer(opt, loss_scale)

  # Format with a 1-tuple: a bare `% opt` raises TypeError (or formats only the
  # first element) when `opt` itself is a tuple, masking the ValueError.
  if use_v1_behavior:
    raise ValueError('"opt" must be an instance of a tf.train.Optimizer or a '
                     'tf.keras.optimizers.Optimizer, but got: %s' % (opt,))
  raise ValueError('"opt" must be an instance of a '
                   'tf.keras.optimizers.Optimizer, but got: %s' % (opt,))
Пример #13
0
def _SanitizedMRO(obj):
    """Get a list of superclasses with minimal amount of non-TF classes.

    Based on many parameters like python version, OS, protobuf implementation
    or changes in google core libraries the list of superclasses of a class
    can change. We only return the first non-TF class to be robust to non API
    affecting changes. The Method Resolution Order returned by
    `tf_inspect.getmro` is still maintained in the return value.

    Args:
      obj: The class whose superclasses should be listed.

    Returns:
      list of strings, `str(cls)` for each class in MRO order, ending with
      the first one whose string does not contain 'tensorflow'.
    """
    return_list = []
    for cls in tf_inspect.getmro(obj):
        str_repr = str(cls)
        return_list.append(str_repr)
        # Keep the first non-TF class, then stop.
        if 'tensorflow' not in str_repr:
            break

    return return_list
Пример #14
0
def _get_defining_class(py_class, name):
    """Return the nearest class in `py_class`'s MRO that declares `name` itself.

    Only each class's own `__dict__` is consulted, so an inherited attribute is
    attributed to the ancestor that actually declares it. Returns None if no
    class in the MRO declares `name`.
    """
    owners = (klass for klass in tf_inspect.getmro(py_class)
              if name in klass.__dict__)
    return next(owners, None)
Пример #15
0
    def SmartSet(self, obj, attr_name, new_attr):
        """Replace obj.attr_name with new_attr.

        This method is smart and works at the module, class, and instance level
        while preserving proper inheritance. It will not stub out C types however
        unless that has been explicitly allowed by the type.

        This method supports the case where attr_name is a staticmethod or a
        classmethod of obj.

        Notes:
          - If obj is an instance, then it is its class that will actually be
            stubbed, unless attr_name lives in the instance's own __dict__.
            Note that the method Set() does not do that: if obj is an
            instance, it (and not its class) will be stubbed.
          - The stubbing is using the builtin getattr and setattr. So, the __get__
            and __set__ will be called when stubbing (TODO: A better idea would
            probably be to manipulate obj.__dict__ instead of getattr() and
            setattr()).
          - A staticmethod found on the stubbed object (or, for a class,
            anywhere in its MRO) is recorded re-wrapped in staticmethod(), so
            the saved value stays a staticmethod even when obj is an instance
            or the staticmethod is inherited.

        Args:
          obj: The object whose attributes we want to modify.
          attr_name: The name of the attribute to modify.
          new_attr: The new value for the attribute.

        Raises:
          AttributeError: If the attribute cannot be found.
        """
        _, obj = tf_decorator.unwrap(obj)
        if (tf_inspect.ismodule(obj) or
            (not tf_inspect.isclass(obj) and attr_name in obj.__dict__)):
            # Modules, and instances holding attr_name in their own __dict__,
            # are stubbed in place.
            orig_obj = obj
            orig_attr = getattr(obj, attr_name)
        else:
            # Otherwise stub the class itself (obj, or obj's class), i.e. the
            # most-derived entry of its MRO.
            orig_obj = obj if tf_inspect.isclass(obj) else obj.__class__
            missing = object()
            orig_attr = getattr(obj, attr_name, missing)
            if orig_attr is missing:
                raise AttributeError('Attribute not found.')

        # Calling getattr() on a staticmethod transforms it to a 'normal'
        # function. We need to ensure that we put it back as a staticmethod.
        # Look for the raw attribute where it actually lives: on orig_obj
        # itself, or for a class anywhere along its MRO (it may be inherited,
        # and when obj is an instance the instance __dict__ never holds it).
        if tf_inspect.isclass(orig_obj):
            owners = tf_inspect.getmro(orig_obj)
        else:
            owners = (orig_obj,)
        old_attribute = None
        for owner in owners:
            if attr_name in owner.__dict__:
                old_attribute = owner.__dict__[attr_name]
                break
        if isinstance(old_attribute, staticmethod):
            orig_attr = staticmethod(orig_attr)

        # Record the original value (presumably restored by an unset method
        # elsewhere in this class), then install the stub.
        self.stubs.append((orig_obj, attr_name, orig_attr))
        setattr(orig_obj, attr_name, new_attr)
Пример #16
0
  def SmartSet(self, obj, attr_name, new_attr):
    """Replace obj.attr_name with new_attr.

    This method is smart and works at the module, class, and instance level
    while preserving proper inheritance. It will not stub out C types however
    unless that has been explicitly allowed by the type.

    This method supports the case where attr_name is a staticmethod or a
    classmethod of obj.

    Notes:
      - If obj is an instance, then it is its class that will actually be
        stubbed, unless attr_name lives in the instance's own __dict__.
        Note that the method Set() does not do that: if obj is an instance,
        it (and not its class) will be stubbed.
      - The stubbing is using the builtin getattr and setattr. So, the __get__
        and __set__ will be called when stubbing (TODO: A better idea would
        probably be to manipulate obj.__dict__ instead of getattr() and
        setattr()).
      - A staticmethod found on the stubbed object (or, for a class, anywhere
        in its MRO) is recorded re-wrapped in staticmethod(), so the saved
        value stays a staticmethod even when obj is an instance or the
        staticmethod is inherited.

    Args:
      obj: The object whose attributes we want to modify.
      attr_name: The name of the attribute to modify.
      new_attr: The new value for the attribute.

    Raises:
      AttributeError: If the attribute cannot be found.
    """
    _, obj = tf_decorator.unwrap(obj)
    if (tf_inspect.ismodule(obj) or
        (not tf_inspect.isclass(obj) and attr_name in obj.__dict__)):
      # Modules, and instances holding attr_name in their own __dict__, are
      # stubbed in place.
      orig_obj = obj
      orig_attr = getattr(obj, attr_name)
    else:
      # Otherwise stub the class itself (obj, or obj's class), i.e. the
      # most-derived entry of its MRO.
      orig_obj = obj if tf_inspect.isclass(obj) else obj.__class__
      missing = object()
      orig_attr = getattr(obj, attr_name, missing)
      if orig_attr is missing:
        raise AttributeError('Attribute not found.')

    # Calling getattr() on a staticmethod transforms it to a 'normal' function.
    # We need to ensure that we put it back as a staticmethod. Look for the raw
    # attribute where it actually lives: on orig_obj itself, or for a class
    # anywhere along its MRO (it may be inherited, and when obj is an instance
    # the instance __dict__ never holds it).
    if tf_inspect.isclass(orig_obj):
      owners = tf_inspect.getmro(orig_obj)
    else:
      owners = (orig_obj,)
    old_attribute = None
    for owner in owners:
      if attr_name in owner.__dict__:
        old_attribute = owner.__dict__[attr_name]
        break
    if isinstance(old_attribute, staticmethod):
      orig_attr = staticmethod(orig_attr)

    # Record the original value (presumably restored by an unset method
    # elsewhere in this class), then install the stub.
    self.stubs.append((orig_obj, attr_name, orig_attr))
    setattr(orig_obj, attr_name, new_attr)
Пример #17
0
def _get_defining_class(py_class, name):
  """Return the nearest class in `py_class`'s MRO that declares `name` itself.

  Only each class's own `__dict__` is consulted, so an inherited attribute is
  attributed to the ancestor that actually declares it. Returns None if no
  class in the MRO declares `name`.
  """
  owners = (klass for klass in tf_inspect.getmro(py_class)
            if name in klass.__dict__)
  return next(owners, None)