Example #1
0
def _make_harness(group_name: str,
                  name: str,
                  func: Callable,
                  args: primitive_harness.ArgDescriptor,
                  *,
                  poly_axes: Sequence[Optional[int]],
                  check_result=True,
                  tol=None,
                  **params) -> Harness:
    """Build a float32 `Harness` for a shape-polymorphism test.

    `poly_axes` must have one entry per non-static argument, and each entry
    is either None or an int.

    The harness name within the group always contains a `poly_axes=...`
    suffix built from the entries joined with "_"; when a non-empty `name`
    is given, it is prepended to that suffix.

    `check_result` controls whether the result of the shape-polymorphic
    conversion is checked against the result of the JAX function.
    """
    axes_suffix = "poly_axes=" + "_".join(str(axis) for axis in poly_axes)
    # An empty/falsy `name` means "use only the poly_axes suffix".
    harness_name = f"{name}_{axes_suffix}" if name else axes_suffix
    return Harness(group_name, harness_name, func, args,
                   dtype=np.float32,
                   poly_axes=poly_axes,
                   check_result=check_result,
                   tol=tol,
                   **params)
def _make_harness(group_name: str, name: str,
                  func: Callable,
                  args: primitive_harness.ArgDescriptor,
                  *,
                  poly_axes: Sequence[Optional[int]],
                  check_result=True,
                  **params) -> Harness:
  """Build a float32 `Harness` for a shape-polymorphism test.

  `poly_axes` must have one entry per non-static argument, and each entry is
  either None or an int. `name` is passed to `Harness` unchanged.

  `check_result` controls whether the result of the shape-polymorphic
  conversion is checked against the result of the JAX function.
  """
  positional = (group_name, name, func, args)
  return Harness(*positional,
                 dtype=np.float32,
                 poly_axes=poly_axes,
                 check_result=check_result,
                 **params)
def _make_harness(group_name: str, name: str,
                  func: Callable,
                  args: primitive_harness.ArgDescriptor,
                  *,
                  poly_axes: Sequence[Optional[Union[int, Sequence[int]]]],
                  check_result=True,
                  tol=None,
                  **params) -> Harness:
  """Build a float32 `Harness` for a shape-polymorphism test.

  `poly_axes` has one entry per non-static argument. Each entry is None, an
  int (the index of the single polymorphic axis), or a sequence of ints (the
  indices of several polymorphic axes). Bare ints are normalized to 1-tuples
  before being handed to `Harness`. None entries and sequences are passed
  through unchanged.

  Each argument's `poly_axes` entry is used to generate its
  polymorphic_shapes specification, with shape variables `b0`, `b1`, ...
  assigned to its polymorphic axes in the order they are listed. Different
  arguments therefore reuse the same dimension-variable names.

  The harness name within the group always contains `poly_axes=<repr>`,
  computed from the input before normalization. When a non-empty `name` is
  given, it is prepended.

  `check_result` controls whether the result of the shape-polymorphic
  conversion is checked against the result of the JAX function.
  """
  # Label from the caller's original spelling, before any normalization.
  axes_label = f"poly_axes={repr(poly_axes)}"
  assert isinstance(poly_axes, Sequence)

  def _as_axes_seq(entry):
    # None and sequences pass through as-is; a lone int becomes a 1-tuple.
    if isinstance(entry, Sequence) or entry is None:
      return entry
    return (entry,)

  normalized_axes = tuple(_as_axes_seq(entry) for entry in poly_axes)
  full_name = f"{name}_{axes_label}" if name else axes_label
  return Harness(group_name, full_name, func, args,
                 dtype=np.float32,
                 poly_axes=normalized_axes,
                 check_result=check_result,
                 tol=tol,
                 **params)