Example #1
0
File: xla.py Project: romanngg/jax
 def __missing__(self, key):
     """Build, cache and return the translation-rule adapter for ``key``.

     Called on lookup of a key that is not yet present (the ``dict``
     ``__missing__`` hook; presumably this class subclasses ``dict`` — confirm
     against the class definition). ``key`` is expanded via
     ``xb.expand_platform_alias`` into one or more platform names, and the
     per-platform translation tables for those names are combined into a
     ``_TranslationRuleAdapter``.

     Args:
       key: platform name (or alias) being looked up.

     Returns:
       The newly built adapter, which is also stored under ``key`` so later
       lookups hit the cache instead of rebuilding it.
     """
     platforms = xb.expand_platform_alias(key)
     tables = [_backend_specific_translations[name] for name in platforms]
     adapter = _TranslationRuleAdapter(tables, _wrap_old_translation)
     # Memoize so subsequent lookups of ``key`` do not re-enter __missing__.
     self[key] = adapter
     return adapter
Example #2
0
def check_backend_matches(inner_backend, outer_backend):
  """Ensure an inner call's explicit backend agrees with the outer backend.

  For nested calls the outermost call decides the backend for everything
  inside it, so an inner call may only name a backend that is compatible
  with the outer one.

  Args:
    inner_backend: explicit backend of the inner call, or None if the inner
      call did not specify one.
    outer_backend: backend chosen by the enclosing (outer) call.

  Raises:
    ValueError: if ``inner_backend`` is set, differs from ``outer_backend``,
      and ``outer_backend`` is not one of the platforms that
      ``inner_backend`` expands to via ``xb.expand_platform_alias``.
  """
  # An inner call without an explicit backend simply inherits the outer one.
  if inner_backend is None:
    return
  # Identical specs trivially agree; skip the alias expansion entirely.
  if inner_backend == outer_backend:
    return
  # An alias on the inner call is satisfied by any platform it stands for.
  compatible = xb.expand_platform_alias(inner_backend)
  if outer_backend in compatible:
    return
  raise ValueError(
      f"Outer-jit backend specification {outer_backend} must match explicit "
      f"inner-jit backend specification {inner_backend}.")
Example #3
0
def register_translation(prim: core.Primitive, rule: TranslationRule, *,
                         platform: Optional[str] = None) -> None:
  """Register ``rule`` as the translation rule for primitive ``prim``.

  Args:
    prim: the primitive whose translation rule is being registered.
    rule: the translation rule to associate with ``prim``.
    platform: when None, ``rule`` goes into the platform-independent
      ``_translations`` table. Otherwise it is written into the
      per-platform table of every platform that ``platform`` expands to
      via ``xb.expand_platform_alias``.
  """
  if platform is not None:
    # For backward compatibility reasons, we allow rules to be registered
    # under "gpu" even though the platforms are now called "cuda" and "rocm".
    # TODO(phawkins): fix up users to specify either "cuda" or "rocm" and remove
    # this expansion.
    for target in xb.expand_platform_alias(platform):
      _backend_specific_translations[target][prim] = rule
    return
  _translations[prim] = rule