Exemple #1
0
def benchmark_suite(prepare: Callable[..., Callable],
                    params_list: List[Dict],
                    name: str,
                    target_total_secs: int = None):
    """Benchmarks a function for several combinations of parameters.

    Prints the summarized results in a table.

    Args:
      prepare: given kwargs returns a benchmark function specialized to the
        kwargs.
      params_list: a non-empty list of kwargs on which to run the benchmark.
        Every entry must use the same set of keys.
      name: the name of this benchmark suite.
      target_total_secs: the ``target_total_secs`` to pass to ``benchmark``.

    Raises:
      ValueError: if ``params_list`` is empty, or if its entries do not all
        share the same keys.
    """
    # Validate up front: the summary table needs at least one row, and the
    # first entry's timing is the baseline of the "relative" column.
    if not params_list:
        raise ValueError(
            "benchmark_suite %r requires a non-empty params_list" % (name,))

    # Sort parameters alphabetically so benchmark results print consistently.
    params_list = [OrderedDict(sorted(p.items())) for p in params_list]
    expected_keys = params_list[0].keys()
    if any(p.keys() != expected_keys for p in params_list):
        raise ValueError(
            "all entries of params_list must have the same keys; expected %s"
            % (sorted(expected_keys),))

    times = []
    for params in params_list:
        f = prepare(**params)
        subname = name + "".join("_%s=%s" % (n, p) for n, p in params.items())
        times.append(
            benchmark(f, name=subname, target_total_secs=target_total_secs))

    print("---------Benchmark summary for %s---------" % name)
    param_names = list(expected_keys)
    # The first parameter combination is the reference for relative timings.
    baseline_mean = times[0].mean()
    rows = []
    for params, t in safe_zip(params_list, times):
        mean = t.mean()
        rows.append(
            tuple(params.values()) + (mean, _pstd(t), mean / baseline_mean))
    print(tabulate(rows, param_names + ["mean", "%std", "relative"]))
    print()
 def wrapped(*args, **kwargs):
     """Rewrite ``f``'s bound expression with ``rule`` and evaluate the result.

     The leaves of the positional ``args`` are bound, in order, to the names
     returned alongside the bound expression; the evaluated output is
     unflattened with the output tree structure.
     """
     build = make_bound_expression(f)
     expr, arg_names, (_, out_tree) = build(*args, **kwargs)
     rewritten_expr = rule(expr)
     # Only positional arguments contribute leaves to the bindings.
     leaves = tree_util.tree_leaves(args)
     env = dict(jax_util.safe_zip(arg_names, leaves))
     flat_out = evaluate(rewritten_expr, env)
     return tree_util.tree_unflatten(out_tree, flat_out)
Exemple #3
0
def _tfval_add_unit(
        vals: Sequence[TfVal],
        avals: Sequence[core.AbstractValue]) -> Sequence[TfValOrUnit]:
    """Replace values whose abstract value is ``core.abstract_unit`` by ``core.unit``.

    Args:
      vals: the values to convert.
      avals: the expected abstract values, paired positionally with ``vals``.

    Returns:
      A list with one entry per input: ``core.unit`` where the abstract value
      is ``core.abstract_unit``, the original value otherwise.
    """
    paired = util.safe_zip(vals, avals)
    return [
        val if expected is not core.abstract_unit else core.unit
        for val, expected in paired
    ]
Exemple #4
0
def eval_sparse(
    jaxpr: core.Jaxpr,
    consts: Sequence[Array],  # all consts are dense
    argspecs: Sequence[ArgSpec],  # mix of sparse and dense pointers into spenv
    spenv: SparseEnv,
) -> Sequence[ArgSpec]:
  """Interpret ``jaxpr`` equation by equation, tracking one ArgSpec per variable.

  Equations with at least one sparse input are dispatched to the matching
  entry of ``sparse_rules``. All other equations are bound on the dense data
  of their inputs, and the resulting buffers are pushed into ``spenv``.

  Args:
    jaxpr: the jaxpr to evaluate.
    consts: dense values for ``jaxpr.constvars``.
    argspecs: ArgSpecs for ``jaxpr.invars``.
    spenv: the environment that buffers are pushed into and read from.

  Returns:
    The ArgSpecs of ``jaxpr.outvars``.

  Raises:
    NotImplementedError: if an equation has a sparse input but its primitive
      has no entry in ``sparse_rules``.
  """
  env : Dict[core.Var, ArgSpec] = {}

  def read(var: core.Var) -> Union[Array, ArgSpec]:
    if not isinstance(var, core.Literal):
      return env[var]
    # all literals are dense
    return ArgSpec(np.shape(var.val), spenv.push(var.val), None)

  def write_buffer(var: core.Var, a: Array) -> None:
    # Dropped variables are never stored.
    if not isinstance(var, core.DropVar):
      env[var] = ArgSpec(a.shape, spenv.push(a), None)

  def write(var: core.Var, a: ArgSpec) -> None:
    if not isinstance(var, core.DropVar):
      assert a is not None
      env[var] = a

  # TODO: handle unitvar at all?
  #write_buffer(core.unitvar, core.unit)
  safe_map(write_buffer, jaxpr.constvars, consts)
  safe_map(write, jaxpr.invars, argspecs)

  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    invals = safe_map(read, eqn.invars)
    has_sparse_input = any(val.is_sparse() for val in invals)

    if has_sparse_input:
      if prim not in sparse_rules:
        raise NotImplementedError(f"sparse rule for {prim}")
      out = sparse_rules[prim](spenv, *invals, **eqn.params)
    else:
      if prim is xla.xla_call_p:
        # TODO(vanderplas,frostig): workaround for binding call primitives
        # within a jaxpr interpreter
        params = eqn.params.copy()
        closed = pe.ClosedJaxpr(params.pop('call_jaxpr'), ())
        fun = lu.wrap_init(core.jaxpr_as_fun(closed))
        out_bufs = prim.bind(fun, *[val.data(spenv) for val in invals], **params)
      else:
        out_bufs = prim.bind(*[val.data(spenv) for val in invals], **eqn.params)
      if not prim.multiple_results:
        out_bufs = [out_bufs]
      # Dropped outputs get a None placeholder; `write` skips them.
      out = [
          None if isinstance(outvar, core.DropVar)
          else ArgSpec(buf.shape, spenv.push(buf), None)
          for buf, outvar in safe_zip(out_bufs, eqn.outvars)
      ]
    safe_map(write, eqn.outvars, out)

  return safe_map(read, jaxpr.outvars)
Exemple #5
0
def _dedupe_bcoo(data, indices):
  """Apply ``_dedupe_bcoo_one`` across the leading batch dimensions.

  The number of batch dimensions is ``indices.ndim - 2``; the single-matrix
  routine is wrapped in one ``vmap`` per batch dimension.

  Raises:
    NotImplementedError: if a batch dimension of ``indices`` and ``data``
      differ in size.
  """
  dedupe = _dedupe_bcoo_one
  n_batch = indices.ndim - 2
  batch_sizes = safe_zip(indices.shape[:n_batch], data.shape[:n_batch])
  for indices_size, data_size in batch_sizes:
    if indices_size != data_size:
      # TODO: handle broadcasted dimensions.
      raise NotImplementedError("dedupe_bcoo for broadcasted dimensions.")
    dedupe = vmap(dedupe)
  return dedupe(data, indices)
Exemple #6
0
def eval_sparse(
    jaxpr: core.Jaxpr,
    consts: Sequence[Array],  # all consts are dense
    spvalues: Sequence[SparsifyValue],  # mix of sparse and dense pointers into spenv
    spenv: SparsifyEnv,
) -> Sequence[SparsifyValue]:
  """Interpret ``jaxpr`` equation by equation, tracking one SparsifyValue per variable.

  Equations with at least one sparse input are dispatched to the matching
  entry of ``sparse_rules``. All other equations are bound on the dense data
  of their inputs, and the results are registered as dense values in
  ``spenv``.

  Args:
    jaxpr: the jaxpr to evaluate.
    consts: dense values for ``jaxpr.constvars``.
    spvalues: SparsifyValues for ``jaxpr.invars``.
    spenv: the environment used to register and look up data.

  Returns:
    The SparsifyValues of ``jaxpr.outvars``.
  """
  env : Dict[core.Var, SparsifyValue] = {}

  def read(var: core.Var) -> Union[Array, SparsifyValue]:
    if not isinstance(var, core.Literal):
      return env[var]
    # all literals are dense
    return spenv.dense(var.val)

  def write_buffer(var: core.Var, a: Array) -> None:
    # Dropped variables are never stored.
    if not isinstance(var, core.DropVar):
      env[var] = spenv.dense(a)

  def write(var: core.Var, a: SparsifyValue) -> None:
    if not isinstance(var, core.DropVar):
      assert a is not None
      env[var] = a

  safe_map(write_buffer, jaxpr.constvars, consts)
  safe_map(write, jaxpr.invars, spvalues)

  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    invals = safe_map(read, eqn.invars)
    has_sparse_input = any(val.is_sparse() for val in invals)

    if has_sparse_input:
      if prim not in sparse_rules:
        _raise_unimplemented_primitive(prim)
      out = sparse_rules[prim](spenv, *invals, **eqn.params)
    else:
      if prim is xla.xla_call_p:
        # TODO(vanderplas,frostig): workaround for binding call primitives
        # within a jaxpr interpreter
        params = eqn.params.copy()
        closed = pe.ClosedJaxpr(params.pop('call_jaxpr'), ())
        fun = lu.wrap_init(core.jaxpr_as_fun(closed))
        out_bufs = prim.bind(fun, *[spenv.data(val) for val in invals], **params)
      else:
        out_bufs = prim.bind(*[spenv.data(val) for val in invals], **eqn.params)
      if not prim.multiple_results:
        out_bufs = [out_bufs]
      # Dropped outputs get a None placeholder; `write` skips them.
      out = [
          None if isinstance(outvar, core.DropVar) else spenv.dense(buf)
          for buf, outvar in safe_zip(out_bufs, eqn.outvars)
      ]
    safe_map(write, eqn.outvars, out)

  return safe_map(read, jaxpr.outvars)
Exemple #7
0
def _pad(operand, padding_value, padding_config):
  """Pad ``operand`` with ``padding_value`` according to ``padding_config``.

  ``padding_config`` holds one ``(low, high, interior)`` triple per entry.
  When every triple has non-negative low/high and zero interior padding,
  ``tf.pad`` is used; otherwise the call goes through ``tfxla.pad`` and the
  output shape is set explicitly from ``_pad_shape``.
  """
  low, high, interior = util.unzip3(padding_config)
  plain_edge_padding = all(
      lo >= 0 and hi >= 0 and i == 0 for lo, hi, i in padding_config)
  if not plain_edge_padding:
    # TODO(necula): implement shape inference for XlaPad
    out_shape = _pad_shape(operand, padding_value, padding_config)
    out = tfxla.pad(operand, padding_value, low, high, interior)
    out.set_shape(out_shape)
    return out
  paddings = util.safe_zip(low, high)
  return tf.pad(operand, paddings,
                mode="CONSTANT", constant_values=padding_value)
Exemple #8
0
def benchmark_suite(prepare: Callable[..., Callable],
                    params_list: List[Dict],
                    name: str,
                    target_total_secs: int = None):
    """Benchmarks a function for several combinations of parameters.

    Prints the summarized results in a table. When ``FLAGS.baseline_dir`` is
    set, a ``mean/baseline`` column is appended using the means returned by
    ``_get_baseline_means``. When ``FLAGS.export_dir`` is set, the table is
    also passed to ``_export_results``.

    Args:
      prepare: given kwargs returns a benchmark function specialized to the
        kwargs.
      params_list: a non-empty list of kwargs on which to run the benchmark.
        Every entry must use the same set of keys.
      name: the name of this benchmark suite.
      target_total_secs: the ``target_total_secs`` to pass to ``benchmark``.

    Raises:
      ValueError: if ``params_list`` is empty, or if its entries do not all
        share the same keys.
    """
    # Validate up front: the summary table needs at least one row, and the
    # first entry's timing is the baseline of the "relative" column.
    if not params_list:
        raise ValueError(
            "benchmark_suite %r requires a non-empty params_list" % (name,))

    # Sort parameters alphabetically so benchmark results print consistently.
    params_list = [OrderedDict(sorted(p.items())) for p in params_list]
    expected_keys = params_list[0].keys()
    if any(p.keys() != expected_keys for p in params_list):
        raise ValueError(
            "all entries of params_list must have the same keys; expected %s"
            % (sorted(expected_keys),))

    times = []
    for params in params_list:
        f = prepare(**params)
        subname = name + "".join("_%s=%s" % (n, _param_str(p))
                                 for n, p in params.items())
        times.append(
            benchmark(f, name=subname, target_total_secs=target_total_secs))

    param_names = list(expected_keys)
    data_header = param_names + ["mean", "%std", "relative"]
    # The first parameter combination is the reference for relative timings.
    baseline_mean = times[0].mean()
    data = []
    for params, t in safe_zip(params_list, times):
        mean = t.mean()
        data.append(
            list(map(_param_str, params.values())) +
            [mean, _pstd(t), mean / baseline_mean])

    if FLAGS.baseline_dir:
        # The "mean" column sits right after the parameter columns.
        mean_idx = len(param_names)
        means = _get_baseline_means(FLAGS.baseline_dir, name)
        assert len(means) == len(data), (means, data)
        data_header.append("mean/baseline")
        for idx, mean in enumerate(means):
            data[idx].append(data[idx][mean_idx] / mean)

    print("---------Benchmark summary for %s---------" % name)
    print(tabulate(data, data_header))
    print()

    if FLAGS.export_dir:
        filename = _export_results(data_header, data, FLAGS.export_dir, name)
        print("Wrote %s results to %s" % (name, filename))
        print()
Exemple #9
0
 def value_and_grad_fun(*args, **kwargs):
   """Check input dtypes, then return the value and index-adjusted gradient(s).

   Unless ``allow_int`` is set, every leaf of the differentiated arguments
   (with BCOO instances treated as leaves) must have a floating or complex
   floating dtype. The raw gradients are then passed through
   ``maybe_copy_index`` together with the matching argument.

   Raises:
     TypeError: if a differentiated input has a non-floating, non-complex
       dtype while ``allow_int`` is false.
   """
   if not allow_int:
     selected = [args[i] for i in _ensure_index_tuple(argnums)]
     leaves, _ = tree_util.tree_flatten(
         selected, is_leaf=lambda arg: isinstance(arg, BCOO))
     for leaf in leaves:
       dtype = np.dtype(leaf)
       is_real = np.issubdtype(dtype, np.floating)
       if not is_real and not np.issubdtype(dtype, np.complexfloating):
         raise TypeError("grad requires real- or complex-valued inputs (input dtype that "
                         "is a sub-dtype of np.floating or np.complexfloating), "
                         f"but got {dtype.name}. If you want to use integer-valued "
                         "inputs, set allow_int to True.")
   value, grad = raw_value_and_grad_fun(*args, **kwargs)
   if not isinstance(argnums, int):
     # One gradient per requested argument index.
     paired = safe_zip(argnums, grad)
     return value, tuple(maybe_copy_index(args[argnum], g) for argnum, g in paired)
   return value, maybe_copy_index(args[argnums], grad)
Exemple #10
0
        out = TFCDict(self)
        if isinstance(o, dict) or (type(o) is type(self)):
            for key in self._keys:
                out[key] -= o[key]
        elif isinstance(o, np.ndarray):
            o = o.flatten()
            for k in range(self._nKeys):
                out[self._keys[k]] -= o[self._slices[k]]
        return out


# Register TFCDict as a JAX pytree node type.
# Flattening returns the dictionary's values as the children and its keys as
# the auxiliary data; unflattening zips the keys back together with the values
# to rebuild a TFCDict.
register_pytree_node(
    TFCDict,
    lambda x: (list(x.values()), list(x.keys())),
    lambda keys, values: TFCDict(safe_zip(keys, values)),
)


class TFCDictRobust(OrderedDict):
    """This class is like the :class:`TFCDict <tfc.utils.TFCUtils.TFCDict>` class, but it handles non-flat arrays."""

    def __init__(self, *args):
        """Initialize TFCDictRobust using the OrderedDict method.

        Args:
            *args: Positional arguments forwarded unchanged to
                ``OrderedDict.__init__``.
        """

        # Store dictionary and keep a record of the keys. Keys will stay in same
        # order, so that adding and subtracting is repeatable.
        super().__init__(*args)
        self._keys = list(self.keys())
        self._nKeys = len(self._keys)
        # Finish setup via getSlices() (presumably builds the per-key slices
        # used when operating on flat arrays; confirm against its definition).
        self.getSlices()
Exemple #11
0
def _duplicate_for_sparse_spvalues(spvalues, params):
  """Yield each param once, or twice when its paired spvalue is sparse."""
  for spvalue, param in safe_zip(spvalues, params):
    copies = 2 if spvalue.is_sparse() else 1
    for _ in range(copies):
      yield param
Exemple #12
0
def _duplicate_for_sparse_argspecs(argspecs, params):
    """Yield each param once, or twice when its paired argspec is sparse."""
    for argspec, param in safe_zip(argspecs, params):
        copies = 2 if argspec.is_sparse() else 1
        for _ in range(copies):
            yield param
Exemple #13
0
 def solve_shape_vars(shape_spec: str,
                      shape: Sequence[int]) -> Dict[str, int]:
     """Parse ``shape_spec`` and solve its shape variables against ``shape``.

     The parsed spec entries are paired positionally with the entries of
     ``shape`` and handed to ``jax2tf.jax2tf._solve_shape_vars``.
     """
     parsed_spec = masking.parse_spec(shape_spec)
     solver = jax2tf.jax2tf._solve_shape_vars
     return solver(util.safe_zip(parsed_spec, shape))
 def f(*args):
     """Evaluate ``self.expression`` with ``self.variable_names`` bound to ``args``."""
     bindings = {
         name: value
         for name, value in jax_util.safe_zip(self.variable_names, args)
     }
     return evaluate(self.expression, bindings)
Exemple #15
0
 def _compatible(shape1, shape2):
   """Return True if each entry of ``shape1`` is 1 or equals its counterpart in ``shape2``."""
   for dim1, dim2 in safe_zip(shape1, shape2):
     if dim1 not in (1, dim2):
       return False
   return True