Exemplo n.º 1
0
def _sort(values, axis=-1, direction='ASCENDING', stable=False, name=None):  # pylint: disable=unused-argument
    """Numpy implementation of `tf.sort`.

    Args:
      values: array-like to sort.
      axis: axis along which to sort; passed straight to `np.sort`.
      direction: either `'ASCENDING'` or `'DESCENDING'`.
      stable: if true, `np.sort` is asked for its `'stable'` kind, otherwise
        `'quicksort'`.
      name: accepted for signature compatibility and ignored.

    Returns:
      The sorted array.

    Raises:
      ValueError: if `direction` is neither `'ASCENDING'` nor `'DESCENDING'`.
    """
    if direction not in ('ASCENDING', 'DESCENDING'):
        raise ValueError('Unrecognized direction: {}.'.format(direction))
    descending = direction == 'DESCENDING'
    algorithm = 'stable' if stable else 'quicksort'
    # A descending sort is done by sorting the negated values in ascending
    # order and negating the result back.
    keys = np.negative(values) if descending else values
    ordered = np.sort(keys, axis, kind=algorithm)
    return np.negative(ordered) if descending else ordered
Exemplo n.º 2
0
def eval_for(op):
  """Evaluates one op, dispatching on `op.op_name`.

  Operand values are read from `arg.atom.val` and, for most ops, expanded with
  `broadcast_dims(op.all_idxs, arg.idxs, value)` before the jnp primitive is
  applied.

  Args:
    op: object exposing `op_name`, `all_idxs`, and (depending on the op)
      `args` or `size_args`.

  Returns:
    The computed value for the op.

  Raises:
    Exception: if `op.op_name` is not one of the handled op names.
  """
  kind = op.op_name

  if kind in ("IAdd", "IMul", "FAdd", "FMul", "FDiv"):
    lhs, rhs = op.args
    lhs_bc = broadcast_dims(op.all_idxs, lhs.idxs, lhs.atom.val)
    rhs_bc = broadcast_dims(op.all_idxs, rhs.idxs, rhs.atom.val)
    if kind in ("IAdd", "FAdd"):
      return jnp.add(lhs_bc, rhs_bc)
    if kind in ("IMul", "FMul"):
      return jnp.multiply(lhs_bc, rhs_bc)
    if kind in ("FDiv",):
      return jnp.divide(lhs_bc, rhs_bc)
    raise Exception("Not implemented: " + str(kind))

  if kind == "Iota":
    size, = op.size_args
    return broadcast_dims(op.all_idxs, [], jnp.arange(size))

  if kind == "Id":
    operand, = op.args
    return broadcast_dims(op.all_idxs, operand.idxs, operand.atom.val)

  if kind == "Get":
    table, index = op.args
    full_shape = [i.size for i in op.all_idxs]
    used_flags = get_stack_idxs_used(op.all_idxs, table.idxs)
    # One iota index array per stack index that the table actually uses.
    leading = [nth_iota(full_shape, pos)
               for pos, used in enumerate(used_flags) if used]
    payload = broadcast_dims(op.all_idxs, index.idxs, index.atom.val)
    return table.atom.val[tuple(leading) + (payload,)]

  if kind == "IntToReal":
    operand, = op.args
    as_float = jnp.array(operand.atom.val, dtype="float32")
    return broadcast_dims(op.all_idxs, operand.idxs, as_float)

  if kind in ("FNeg", "INeg"):
    operand, = op.args
    negated = jnp.negative(operand.atom.val)
    return broadcast_dims(op.all_idxs, operand.idxs, negated)

  if kind == "ThreeFry2x32":
    def split_to_uint32(word):
      # Reinterpret the wrapped scalar's bytes as uint32 values.
      return np.array([word]).view(np.uint32)

    def join_to_int64(words):
      # Reinterpret the bytes as a single int64 scalar.
      return np.int64(np.array(words).view(np.int64).item())

    key_arg, count_arg = op.args
    key = split_to_uint32(key_arg.atom.val)
    count = split_to_uint32(count_arg.atom.val)
    hashed = join_to_int64(random.threefry_2x32(key, count))
    # The result is broadcast using the first operand's indices.
    return broadcast_dims(op.all_idxs, key_arg.idxs, hashed)

  raise Exception("Unrecognized op: {}".format(kind))
Exemplo n.º 3
0
def objective(t,
              enc_params,
              dec_params,
              log_px_estimator,
              maximize=False,
              num_samples=32):
    """Evaluates the batched reverse-KL objective for step `t`.

    Args:
      t: integer used to seed `random.PRNGKey`.
      enc_params: encoder parameters forwarded to `batch_reverse_kl`.
      dec_params: decoder parameters forwarded to `batch_reverse_kl`.
      log_px_estimator: estimator forwarded to `batch_reverse_kl`.
      maximize: when exactly `True`, the negated objective is returned.
      num_samples: sample count forwarded to `batch_reverse_kl`.

    Returns:
      The value from `batch_reverse_kl`, negated if `maximize is True`.
    """
    key = random.PRNGKey(t)
    kl_value = batch_reverse_kl(
        funnel_log_density, log_px_estimator, key, enc_params, dec_params,
        num_samples)
    # Identity check on purpose: only the literal `True` flips the sign.
    return jnp.negative(kl_value) if maximize is True else kl_value
Exemplo n.º 4
0
  def __init__(self,
               shift: Numeric,
               scale: Optional[Numeric] = None,
               log_scale: Optional[Numeric] = None):
    """Initializes a ScalarAffine bijector.

    Args:
      shift: the bijector's shift parameter. Can also be batched.
      scale: the bijector's scale parameter. Can also be batched. NOTE: `scale`
        must be non-zero, otherwise the bijector is not invertible. No check
        is made here; keeping `scale` non-zero is up to the caller.
      log_scale: the log of the scale parameter. Can also be batched. If
        given, the bijector's scale is `exp(log_scale)`, so `log_scale` is
        unconstrained. NOTE: pass at most one of `scale` and `log_scale`; if
        neither is given the scale defaults to 1.

    Raises:
      ValueError: if both `scale` and `log_scale` are not None.
    """
    super().__init__(event_ndims_in=0, is_constant_jacobian=True)
    self._shift = shift

    has_scale = scale is not None
    has_log_scale = log_scale is not None
    if has_scale and has_log_scale:
      raise ValueError(
          'Only one of `scale` and `log_scale` can be specified, not both.')

    if has_log_scale:
      # Parameterized by log-scale: derive scale and its inverse from it.
      self._scale = jnp.exp(log_scale)
      self._inv_scale = jnp.exp(jnp.negative(log_scale))
      self._log_scale = log_scale
    elif has_scale:
      # Parameterized by scale: the stored log-scale uses |scale|.
      self._scale = scale
      self._inv_scale = 1. / scale
      self._log_scale = jnp.log(jnp.abs(scale))
    else:
      # Neither given: identity scaling.
      self._scale = 1.
      self._inv_scale = 1.
      self._log_scale = 0.

    shift_shape = jnp.shape(self._shift)
    scale_shape = jnp.shape(self._scale)
    self._batch_shape = jax.lax.broadcast_shapes(shift_shape, scale_shape)
Exemplo n.º 5
0
def negative(x):
  """Returns the elementwise negation of `x` wrapped in a `JaxArray`.

  A `JaxArray` argument is unwrapped through its `.value` attribute first;
  anything else is handed to `jnp.negative` as is.
  """
  raw = x.value if isinstance(x, JaxArray) else x
  return JaxArray(jnp.negative(raw))
Exemplo n.º 6
0
def negative(a: Numeric):
    """Returns `jnp.negative(a)`, the elementwise negation of `a`."""
    negated = jnp.negative(a)
    return negated
Exemplo n.º 7
0
# Register `_todense_sparse_rule` as the `sparse_rules` entry keyed by
# `sparse.todense_p`.
sparse_rules[sparse.todense_p] = _todense_sparse_rule

#------------------------------------------------------------------------------
# BCOO methods derived from sparsify
# defined here to avoid circular imports


def _sum(self, *args, **kwargs):
    """Sum array along axis.

    Positional and keyword arguments are forwarded unchanged to the `.sum`
    method of the value that the `sparsify`-wrapped function receives.
    """
    def _call_sum(x):
        return x.sum(*args, **kwargs)

    return sparsify(_call_sum)(self)


# Maps attribute names to the implementations installed on BCOO below. Apart
# from `sum`, each entry is a jnp function (or a small lambda around one)
# wrapped in `sparsify`.
_bcoo_methods = {
    'sum': _sum,
    "__neg__": sparsify(jnp.negative),
    "__pos__": sparsify(jnp.positive),
    "__matmul__": sparsify(jnp.matmul),
    # Reflected variants swap the operands so `other` is on the left.
    "__rmatmul__": sparsify(lambda self, other: jnp.matmul(other, self)),
    "__mul__": sparsify(jnp.multiply),
    "__rmul__": sparsify(lambda self, other: jnp.multiply(other, self)),
    "__add__": sparsify(jnp.add),
    "__radd__": sparsify(lambda self, other: jnp.add(other, self)),
    # Subtraction is expressed as addition of the negated operand.
    "__sub__":
    sparsify(lambda self, other: jnp.add(self, jnp.negative(other))),
    "__rsub__":
    sparsify(lambda self, other: jnp.add(other, jnp.negative(self))),
}

# Attach every entry to the BCOO class under its mapped name.
for method, impl in _bcoo_methods.items():
    setattr(BCOO, method, impl)
Exemplo n.º 8
0
def _sparse_rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
                           mode=None, fill_value=None):
  """Indexes `arr` with `idx` by running `lax_numpy._gather` under `sparsify`.

  Mirrors lax_numpy._rewriting_take: the index is split with
  `lax_numpy._split_index_for_jit`, and only the dynamic part is passed as an
  argument to the sparsified gather. The remaining keyword arguments are
  forwarded to `lax_numpy._gather` unchanged.

  Returns:
    The gather result. A result that is not a BCOO and has size zero is
    converted with `BCOO.fromdense`.
  """
  treedef, static_idx, dynamic_idx = lax_numpy._split_index_for_jit(idx, arr.shape)
  sparse_gather = sparsify(
      lambda operand, dyn_idx: lax_numpy._gather(
          operand, treedef, static_idx, dyn_idx, indices_are_sorted,
          unique_indices, mode, fill_value))
  out = sparse_gather(arr, dynamic_idx)
  # Account for a corner case in the rewriting_take implementation.
  needs_wrapping = not isinstance(out, BCOO) and np.size(out) == 0
  if needs_wrapping:
    out = BCOO.fromdense(out)
  return out

# Maps attribute names to the implementations installed on BCOO below. Most
# entries are jnp functions (or small lambdas around them) wrapped in
# `sparsify`; `sum` and `__getitem__` use functions defined in this module.
_bcoo_methods = {
  'sum': _sum,
  "__neg__": sparsify(jnp.negative),
  "__pos__": sparsify(jnp.positive),
  "__matmul__": sparsify(jnp.matmul),
  # Reflected variants swap the operands so `other` is on the left.
  "__rmatmul__": sparsify(lambda self, other: jnp.matmul(other, self)),
  "__mul__": sparsify(jnp.multiply),
  "__rmul__": sparsify(lambda self, other: jnp.multiply(other, self)),
  "__add__": sparsify(jnp.add),
  "__radd__": sparsify(lambda self, other: jnp.add(other, self)),
  # Subtraction is expressed as addition of the negated operand.
  "__sub__": sparsify(lambda self, other: jnp.add(self, jnp.negative(other))),
  "__rsub__": sparsify(lambda self, other: jnp.add(other, jnp.negative(self))),
  "__getitem__": _sparse_rewriting_take,
}

# Attach every entry to the BCOO class under its mapped name.
for method, impl in _bcoo_methods.items():
  setattr(BCOO, method, impl)
Exemplo n.º 9
0
 def inverse_log_det_jacobian(self, y: Array) -> Array:
   """Computes log|det J(f^{-1})(y)|.

   The value is the negated stored log-scale, broadcast to the shape obtained
   by broadcasting the bijector's batch shape against `y.shape`.
   """
   out_shape = jax.lax.broadcast_shapes(self._batch_shape, y.shape)
   neg_log_scale = jnp.negative(self._log_scale)
   return jnp.broadcast_to(neg_log_scale, out_shape)
Exemplo n.º 10
0
# Stand-in for `tf.math.minimum` built on `np.minimum`; `name` is accepted
# but ignored.
minimum = utils.copy_docstring(tf.math.minimum,
                               lambda x, y, name=None: np.minimum(x, y))

# Stand-in for `tf.math.multiply` built on `np.multiply`; `name` is accepted
# but ignored.
multiply = utils.copy_docstring(tf.math.multiply,
                                lambda x, y, name=None: np.multiply(x, y))

# Stand-in for `tf.math.multiply_no_nan`: wherever `y == 0` the result is 0,
# elsewhere it is `x * y`; `name` is accepted but ignored.
# The `y == 0` mask is broadcast to the shape of `x` (not to the joint
# broadcast shape of `x` and `y`).
# NOTE(review): `onp` is presumably plain NumPy while `np` may be a
# NumPy-compatible backend -- confirm against the file's imports.
multiply_no_nan = utils.copy_docstring(
    tf.math.multiply_no_nan,
    lambda x, y, name=None: np.where(  # pylint: disable=g-long-lambda
        onp.broadcast_to(np.equal(y, 0.),
                         np.array(x).shape), np.zeros_like(np.multiply(x, y)),
        np.multiply(x, y)))

# Stand-in for `tf.math.negative` built on `np.negative`; `name` is accepted
# but ignored.
negative = utils.copy_docstring(tf.math.negative,
                                lambda x, name=None: np.negative(x))

# nextafter = utils.copy_docstring(
#     tf.math.nextafter,
#     lambda x1, x2, name=None: np.nextafter)

# Stand-in for `tf.math.not_equal` built on `np.not_equal`; `name` is
# accepted but ignored.
not_equal = utils.copy_docstring(tf.math.not_equal,
                                 lambda x, y, name=None: np.not_equal(x, y))

# Stand-in for `tf.math.polygamma` that delegates to
# `scipy_special.polygamma(a, x)`; `name` is accepted but ignored.
polygamma = utils.copy_docstring(
    tf.math.polygamma, lambda a, x, name=None: scipy_special.polygamma(a, x))

# Stand-in for `tf.math.polyval` that delegates to `np.polyval(coeffs, x)`;
# `name` is accepted but ignored.
polyval = utils.copy_docstring(
    tf.math.polyval, lambda coeffs, x, name=None: np.polyval(coeffs, x))

pow = utils.copy_docstring(  # pylint: disable=redefined-builtin
Exemplo n.º 11
0
 def _neg(x):
     """Returns `jnp.negative(x)`, the elementwise negation of `x`."""
     negated = jnp.negative(x)
     return negated