Example #1
0
def model(X: DeviceArray) -> DeviceArray:
    """Gamma-Poisson hierarchical model for daily sales forecasting.

    Per-store dispersion parameters and per-store regression coefficients are
    drawn around shared hyperpriors using a non-centred parameterisation
    (``mu + offset * sigma``). Each (store, day) target is then modelled as a
    GammaPoisson whose mean is ``exp(sum(coefs * features))``. This works
    because ``alphas = means * betas`` makes ``alpha / beta == means``.

    Args:
        X: rank-3 array of shape ``(n_stores, n_days, n_features + 1)``. The
            last entry of the final axis is the target; the remaining entries
            are features. NaN marks missing/padded cells (see the masking
            below).

    Returns:
        The value of the ``Site.days`` sample site. Because ``obs`` is given,
        this is presumably the NaN-zeroed targets of shape
        ``(n_stores, n_days)``; confirm against the numpyro version in use.
    """
    n_stores, n_days, n_features = X.shape
    n_features -= 1  # remove one dim for target
    eps = 1e-12  # tiny positive mean used for unobserved cells

    # Features and days both live on the rightmost batch dim; stores on dim -2.
    plate_features = numpyro.plate(Plate.features, n_features, dim=-1)
    plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)
    plate_days = numpyro.plate(Plate.days, n_days, dim=-1)

    # Population-level hyperpriors for the per-store dispersion parameter,
    # which enters the model as the Gamma rate exp(-disp_params) below.
    disp_param_mu = numpyro.sample(Site.disp_param_mu,
                                   dist.Normal(loc=4., scale=1.))
    disp_param_sigma = numpyro.sample(Site.disp_param_sigma,
                                      dist.HalfNormal(scale=1.))

    with plate_stores:
        # Non-centred per-store dispersion, shape (n_stores, 1).
        disp_param_offsets = numpyro.sample(
            Site.disp_param_offsets,
            dist.Normal(loc=jnp.zeros((n_stores, 1)), scale=0.1))
        disp_params = disp_param_mu + disp_param_offsets * disp_param_sigma
        # Observed Delta site: records the derived value under its own site
        # name (presumably so it shows up in traces/posterior samples).
        disp_params = numpyro.sample(Site.disp_params,
                                     dist.Delta(disp_params),
                                     obs=disp_params)

    with plate_features:
        # Population-level mean and scale for each regression coefficient.
        coef_mus = numpyro.sample(
            Site.coef_mus,
            dist.Normal(loc=jnp.zeros(n_features), scale=jnp.ones(n_features)))
        coef_sigmas = numpyro.sample(
            Site.coef_sigmas, dist.HalfNormal(scale=2. * jnp.ones(n_features)))

        with plate_stores:
            # Non-centred per-store coefficients, shape (n_stores, n_features).
            coef_offsets = numpyro.sample(
                Site.coef_offsets,
                dist.Normal(loc=jnp.zeros((n_stores, n_features)), scale=1.))
            coefs = coef_mus + coef_offsets * coef_sigmas
            coefs = numpyro.sample(Site.coefs, dist.Delta(coefs), obs=coefs)

    with plate_days, plate_stores:
        targets = X[..., -1]
        features = jnp.nan_to_num(X[..., :-1])  # NaN (padded) features -> 0
        # 1.0 where the target is present, 0.0 where it is NaN.
        is_observed = jnp.where(jnp.isnan(targets), jnp.zeros_like(targets),
                                jnp.ones_like(targets))
        not_observed = 1 - is_observed
        # Log-linear mean per (store, day); coefs gets a day axis inserted at
        # position 1 so it broadcasts over days. Unobserved cells get mean eps.
        means = (is_observed * jnp.exp(
            jnp.sum(jnp.expand_dims(coefs, axis=1) * features, axis=2)) +
                 not_observed * eps)

        # Unobserved cells use rate 1 and (with the target zeroed below) a
        # near-zero shape, so their log-likelihood contribution is ~0.
        betas = is_observed * jnp.exp(-disp_params) + not_observed
        alphas = means * betas
        return numpyro.sample(Site.days,
                              dist.GammaPoisson(alphas, betas),
                              obs=jnp.nan_to_num(targets))
Example #2
0
def main(cfg_path: Path, log_level: int):
    """Write binary and apodized copies of the ``ns0512`` mask to HDF5.

    Reads ``nside`` and ``hdf5_path`` from the YAML config at ``cfg_path``
    and loads the HEALPix map ``data/n0512.fits`` (NaNs -> 0). It then writes
    two float32 datasets of length ``hp.nside2npix(nside)``:

    * ``ns0512/binary``: 1.0 where the mask is positive, else 0.0.
    * ``ns0512/apodized/basic``: ``nmt.mask_apodization(mask, 5., "C2")``.

    Args:
        cfg_path: Path to the YAML configuration file.
        log_level: Level passed to ``logging.basicConfig``.

    NOTE(review): the dataset length comes from the config's ``nside`` while
    the mask file name suggests nside 512; a mismatch would fail on
    assignment. Confirm the config always uses 512.
    """
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=log_level,
                        datefmt='%Y-%m-%d %H:%M',
                        format=log_format)

    with open(cfg_path) as cfg_stream:
        # NOTE(review): consider yaml.safe_load if the config is not trusted.
        config = yaml.load(cfg_stream, Loader=yaml.FullLoader)
    nside, hdf5_path = config['nside'], config['hdf5_path']

    raw_mask = old_np.nan_to_num(
        hp.read_map('data/n0512.fits', verbose=False, dtype=np.float32))
    apodized_mask = nmt.mask_apodization(raw_mask, 5., apotype="C2")
    binary_mask = old_np.where(raw_mask > 0.0, 1.0, 0.0)

    with h5py.File(hdf5_path, 'w') as h5:
        top = h5.require_group('ns0512')
        binary_ds = top.require_dataset('binary',
                                        shape=(hp.nside2npix(nside), ),
                                        dtype=np.float32)
        binary_ds[...] = np.nan_to_num(binary_mask)

        apod_group = top.require_group('apodized')
        apod_ds = apod_group.require_dataset('basic',
                                             shape=(hp.nside2npix(nside), ),
                                             dtype=np.float32)
        apod_ds[...] = np.nan_to_num(apodized_mask)
Example #3
0
def uniform_stochastic_quantize(v: jnp.ndarray,
                                num_levels: int,
                                rng: PRNGKey,
                                v_min: Optional[float] = None,
                                v_max: Optional[float] = None) -> jnp.ndarray:
  """Stochastically round `v` onto a uniform grid (arXiv:1611.00429).

  The input is mapped onto [0, 1] using `v_min`/`v_max` (clipping values that
  fall outside). Each value is then rounded up to the next of the
  `num_levels` evenly spaced grid points with probability equal to its
  fractional position between the two neighbouring points, otherwise down.
  The result is mapped back to the original range.

  Args:
    v: vector to be quantized.
    num_levels: Number of levels of quantization.
    rng: jax random key.
    v_min: minimum threshold for quantization. If None, sets it to jnp.amin(v).
    v_max: maximum threshold for quantization. If None, sets it to jnp.amax(v).

  Returns:
    Quantized array with the same shape as `v`.
  """
  lo = jnp.amin(v) if v_min is None else v_min
  hi = jnp.amax(v) if v_max is None else v_max
  span = hi - lo

  # Normalise to [0, 1]; a zero span yields NaN, which nan_to_num maps to 0.
  unit = jnp.clip(jnp.nan_to_num((v - lo) / span), 0., 1.)

  # Neighbouring grid points below and above each value.
  steps = num_levels - 1
  upper = jnp.ceil(unit * steps) / steps
  lower = jnp.floor(unit * steps) / steps

  # Chance of rounding up; values exactly on a grid point give 0/0 -> 0.
  prob_up = jnp.nan_to_num((unit - lower) / (upper - lower))
  draws = jax.random.uniform(key=rng, shape=unit.shape)
  quantized = jnp.where(draws <= prob_up, upper, lower)

  return lo + quantized * span
Example #4
0
def _giou(boxes1: jnp.ndarray, boxes2: jnp.ndarray) -> jnp.ndarray:
    """Generalized IoU between corresponding boxes of `boxes1` and `boxes2`.

    Boxes are laid out as ``[ymin, xmin, ymax, xmax]`` along axis 1 (each
    input is split into four column groups with ``jnp.hsplit``). Inverted
    coordinates are treated as zero width/height. Divisions by zero area
    (0/0) are mapped to 0 with ``jnp.nan_to_num``.

    Returns:
        GIoU values with singleton dimensions squeezed away.
    """
    ymin_a, xmin_a, ymax_a, xmax_a = jnp.hsplit(boxes1, 4)
    ymin_b, xmin_b, ymax_b, xmax_b = jnp.hsplit(boxes2, 4)

    def extent(lo, hi):
        # Non-negative length of the interval [lo, hi].
        return jnp.maximum(0, hi - lo)

    area_a = extent(xmin_a, xmax_a) * extent(ymin_a, ymax_a)
    area_b = extent(xmin_b, xmax_b) * extent(ymin_b, ymax_b)

    # Overlap of the two boxes.
    overlap = (extent(jnp.maximum(xmin_a, xmin_b), jnp.minimum(xmax_a, xmax_b)) *
               extent(jnp.maximum(ymin_a, ymin_b), jnp.minimum(ymax_a, ymax_b)))
    union = area_a + area_b - overlap
    iou = jnp.nan_to_num(overlap / union)

    # Smallest axis-aligned box containing both.
    hull = (extent(jnp.minimum(xmin_a, xmin_b), jnp.maximum(xmax_a, xmax_b)) *
            extent(jnp.minimum(ymin_a, ymin_b), jnp.maximum(ymax_a, ymax_b)))
    penalty = jnp.nan_to_num((hull - union) / hull)
    return (iou - penalty).squeeze()
Example #5
0
 def recursive_check(tuple_object, dict_object):
     """Recursively compare two nested outputs element-wise.

     Lists/tuples are walked in lockstep with ``zip``, so extra trailing
     elements of the longer one are ignored. ``None`` leaves on the left are
     skipped. Any other leaf pair is compared with
     ``self.assert_almost_equals`` (tolerance 1e-5) after NaN/inf are mapped
     to finite values.
     """
     if tuple_object is None:
         return
     if not isinstance(tuple_object, (List, Tuple)):
         self.assert_almost_equals(jnp.nan_to_num(tuple_object), jnp.nan_to_num(dict_object), 1e-5)
         return
     for left, right in zip(tuple_object, dict_object):
         recursive_check(left, right)
def mutual_information(net, samples):
    """Mutual information, in bits, between predictions and model samples.

    ``net.prob_factors(samples)`` must return Bernoulli probabilities whose
    axis 2 indexes the model samples (means are taken over ``axis=2``). The
    result is the entropy of the sample-averaged prediction minus the average
    of the per-sample entropies.
    """
    def neg_entropy_bits(p):
        # p*log2(p) + (1-p)*log2(1-p). nan_to_num turns log(0) = -inf into a
        # large finite number, so the 0 * log(0) terms come out as 0.
        return (p * jnp.nan_to_num(jnp.log(p)) +
                (1. - p) * jnp.nan_to_num(jnp.log(1. - p))) / jnp.log(2.)

    probs = net.prob_factors(samples)
    marginal = neg_entropy_bits(jnp.mean(probs, axis=2))
    conditional = jnp.mean(neg_entropy_bits(probs), axis=2)
    return conditional - marginal
Example #7
0
 def calc_delta_correlated_error_jax(areas_a, ematrix_a, areas_b, errors_b):
     """Propagate uncertainties of two histograms through the delta metric.

     The derivatives of delta w.r.t. both histograms come from
     ``delta_differential_jax``. Histogram ``a`` contributes the quadratic
     form ``d_a^T E_a d_a`` (presumably ``ematrix_a`` is its covariance
     matrix). Histogram ``b`` contributes ``sum((d_b * errors_b)**2)``,
     i.e. independent per-bin errors. The two are assumed uncorrelated, so
     the returned standard error is the square root of their sum.
     """
     raw_a, raw_b = delta_differential_jax(areas_a, areas_b)
     # The differential can return NaN; treat those entries as zero. (Not a
     # principled fix, kept from the original implementation.)
     grad_a = jnp.nan_to_num(raw_a)
     grad_b = jnp.nan_to_num(raw_b)

     var_a = jnp.matmul(grad_a.T, jnp.matmul(ematrix_a, grad_a))
     var_b = jnp.sum(jnp.square((grad_b * errors_b)))
     return jnp.sqrt(var_a + var_b)
Example #8
0
def update(delta_x, delta_gx, Us, VTs, n_step):
    """Store one rank-one correction to the low-rank Jacobian approximation.

    Computes a new row ``vT`` and column ``u`` from the step ``delta_x`` and
    the change in function value ``delta_gx``, then writes them into slot
    ``n_step - 1`` of ``VTs`` and ``Us`` respectively. This is presumably a
    Broyden-style quasi-Newton update; confirm against the calling solver.

    Args:
        delta_x: change in the iterate; rank-3, batched along axis 0 (the
            ``'bij'`` einsum below fixes the rank).
        delta_gx: change in the function value, same shape as ``delta_x``.
        Us: stacked ``u`` columns; the slot index is the last (4th) axis.
        VTs: stacked ``vT`` rows; the slot index is axis 1.
        n_step: 1-based update counter; slot ``n_step - 1`` is written.

    Returns:
        ``(Us, VTs)`` as new arrays (JAX arrays are immutable).
    """
    # Calculate J_i.
    vT = rmatvec(Us, VTs, delta_x)
    # Per-batch denominator vT . delta_gx, broadcast over the remaining axes.
    denom = _einsum('bij, bij -> b', vT, delta_gx)[:, None, None]
    u = (delta_x - matvec(Us, VTs, delta_gx)) / denom

    # A zero denominator (or NaNs from upstream) must not poison later solves.
    vT = jnp.nan_to_num(vT)
    u = jnp.nan_to_num(u)

    # `.at[...].set(...)` replaces jax.ops.index_update / jax.ops.index,
    # which have been removed from JAX.
    VTs = VTs.at[:, n_step - 1].set(vT)
    Us = Us.at[:, :, :, n_step - 1].set(u)

    return Us, VTs
Example #9
0
def ps_ir(*, draws, proposal_densities, target_densities, num_resampled, key):
    """Importance-resample `draws` using target/proposal log-density ratios.

    Indices are drawn with replacement from a categorical distribution whose
    logits are ``target_densities - proposal_densities``. Draws whose logit
    is NaN or -inf are never selected.

    Returns:
        ``draws`` indexed by the ``num_resampled`` sampled indices.
    """
    log_ratio = target_densities - proposal_densities
    # NaN and -inf both mean "zero weight".
    log_ratio = jnp.nan_to_num(log_ratio, nan=-jnp.inf, neginf=-jnp.inf)
    chosen = jax.random.categorical(key=key,
                                    logits=log_ratio,
                                    shape=(num_resampled, ))
    return draws[chosen]
Example #10
0
def reweighted_stddev(f_n: Array, target_logpdf_n: Array,
                      source_logpdf_n: Array) -> Float:
    """Compute reweighted estimate of
    stddev(f(x)) under x ~ p_target
    based on samples   x ~ p_source

    where
        p_target(x) = exp(target_logpdf(x)) / Z_target

    using samples from a different source
        x_n ~ p_source
        where
        p_source(x) = exp(source_logpdf(x)) / Z_source

    The inputs are arrays "{fxn_name}_n" containing the result of
    calling each fxn on a fixed array of samples:

    * f_n = [f(x_n) for x_n in samples]
    * target_logpdf_n = [target_logpdf(x_n) for x_n in samples]
    * source_logpdf_n = [source_logpdf(x_n) for x_n in samples]

    All three are flattened, so column vectors such as ``(n, 1)`` are
    accepted. Self-normalised importance weights are used, so the
    normalising constants Z_target and Z_source are not needed.
    """

    log_weights_n = target_logpdf_n - source_logpdf_n
    # Self-normalised importance weights, computed stably in log space.
    weights = np.exp(log_weights_n - logsumexp(log_weights_n)).flatten()
    # Flatten f_n as well so it lines up elementwise with the flat weights;
    # otherwise an (n, 1) f_n would broadcast with (n,) into an (n, n) product.
    f_n = np.ravel(f_n)

    f_mean = np.sum(weights * f_n)
    squared_deviations = (f_n - f_mean)**2

    # sanitize 0 * inf -> 0 (instead of nan)
    weighted_squared_deviations = weights * squared_deviations
    sanitized = np.nan_to_num(weighted_squared_deviations, nan=0)
    stddev = np.sqrt(np.sum(sanitized))

    return stddev
Example #11
0
def precision(
    y_true: jnp.ndarray,
    y_pred: jnp.ndarray,
    threshold: jnp.ndarray,
    class_id: jnp.ndarray,
    sample_weight: jnp.ndarray,
    true_positives: ReduceConfusionMatrix,
    false_positives: ReduceConfusionMatrix,
) -> jnp.ndarray:
    """Precision TP / (TP + FP) of thresholded predictions.

    ``y_pred`` is binarised with ``y_pred > threshold``. It is produced as
    float32 and then cast to ``y_true``'s dtype when the two differ. Each
    reducer is called with keyword arguments ``y_true``, ``y_pred`` and
    ``sample_weight``. A 0/0 ratio (no predicted positives) is reported as 0.

    Note: ``class_id`` is accepted but currently ignored.
    """
    # TODO: class_id behavior
    binary_pred = (y_pred > threshold).astype(jnp.float32)
    if binary_pred.dtype != y_true.dtype:
        binary_pred = binary_pred.astype(y_true.dtype)

    reducer_kwargs = dict(y_true=y_true,
                          y_pred=binary_pred,
                          sample_weight=sample_weight)
    tp = true_positives(**reducer_kwargs)
    fp = false_positives(**reducer_kwargs)

    return jnp.nan_to_num(jnp.divide(tp, tp + fp))
Example #12
0
 def rnn_cell(carry, x):
     """One autoregressive sampling step in ``jax.lax.scan`` carry/x form.

     ``carry`` is ``(hidden, previous_sample)`` and ``x`` is the PRNG key
     for this step. ``eval_cell`` is vmapped over the batch to produce new
     hidden states and logits, and one symbol per row is drawn from the
     logits.

     Returns:
         ``((new_hidden, one_hot_sample), (log_prob, sample_index))``, where
         a NaN ``log_prob`` is replaced by -35.
     """
     new_hidden, logits = jax.vmap(eval_cell)(carry[0], carry[1])
     sample_idx = jax.random.categorical(x, logits)
     one_hot = jax.nn.one_hot(sample_idx, inputDim)
     # Log-probability of the drawn symbol: select it from the log-softmax.
     log_prob = jnp.sum(nn.log_softmax(logits) * one_hot, axis=1)
     safe_log_prob = jnp.nan_to_num(log_prob, nan=-35)
     return (new_hidden, one_hot), (safe_log_prob, sample_idx)
 def loss(self, params, batch):
     """Cross-entropy loss.

     Binary cross-entropy between ``targets`` and ``self.predict(params,
     inputs)``, averaged over all elements. NaN terms (e.g. 0 * log(0))
     become 0 and infinite terms become the largest finite float, via
     ``jnp.nan_to_num``.
     """
     inputs, targets = batch
     probs = self.predict(params, inputs)
     pos_term = targets * jnp.log(probs)
     neg_term = (1 - targets) * jnp.log(1 - probs)
     return jnp.mean(jnp.nan_to_num(-pos_term - neg_term))
Example #14
0
def syn2post_softmax(syn_values,
                     post_ids,
                     post_num: int,
                     indices_are_sorted=True):
    """The syn-to-post softmax computation.

    Applies a softmax to ``syn_values`` separately within each group of
    synapses that share the same post-synaptic id. Boolean inputs are
    converted to int32 first. Any NaN in the result is replaced by 0.

    Parameters
    ----------
    syn_values: jax.numpy.ndarray, JaxArray, Variable
      The synaptic values.
    post_ids: jax.numpy.ndarray, JaxArray
      The post-synaptic neuron ids. If ``post_ids`` is generated by
      ``brainpy.conn.TwoEndConnector``, then it has sorted indices.
      Otherwise, this function cannot guarantee indices are sorted.
      In that case you had better set ``indices_are_sorted=False``.
    post_num: int
      The number of the post-synaptic neurons.
    indices_are_sorted: bool
      Whether ``post_ids`` is known to be sorted.

    Returns
    -------
    post_val: jax.numpy.ndarray, JaxArray
      The post-synaptic value.
    """
    ids = as_device_array(post_ids)
    values = as_device_array(syn_values)
    if values.dtype == jnp.bool_:
        values = jnp.asarray(values, dtype=jnp.int32)

    # Subtract each group's maximum before exponentiating, for stability.
    group_max = _jit_seg_max(values, ids, post_num, indices_are_sorted)
    exps = jnp.exp(values - group_max[ids])

    group_sum = _jit_seg_sum(exps, ids, post_num, indices_are_sorted)
    return jnp.nan_to_num(exps / group_sum[ids])
Example #15
0
def top_k_error_rate_metric(logits: jnp.ndarray,
                            one_hot_labels: jnp.ndarray,
                            k: int = 5,
                            mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
  """Returns the top-K error rate between some predictions and some labels.

  Args:
    logits: Output of the model, shape (num_samples, num_classes).
    one_hot_labels: One-hot encoded labels. Dimensions should match the logits.
    k: Number of class the model is allowed to predict for each example.
    mask: Mask to apply to the loss to ignore some samples (usually, the padding
      of the batch). Array of ones and zeros, reshaped to (num_samples,).

  Returns:
    The error rate (1 - accuracy), averaged over the non-masked samples.
    Returns 0 when every sample is masked out.
  """
  num_samples = logits.shape[0]
  if mask is None:
    mask = jnp.ones([num_samples])
  mask = mask.reshape([num_samples])
  true_labels = jnp.argmax(one_hot_labels, -1).reshape([-1, 1])
  top_k_preds = jnp.argsort(logits, axis=-1)[:, -k:]
  # Reduce to one boolean per sample, shape (num_samples,), so it multiplies
  # `mask` elementwise. An (N, 1) hit array would broadcast against the (N,)
  # mask into an (N, N) matrix and count every hit N times.
  hit = jnp.any(top_k_preds == true_labels, axis=-1)
  error_rate = 1 - ((hit * mask).sum() / mask.sum())
  # Set to zero if there is no non-masked samples (0/0 -> NaN -> 0).
  return jnp.nan_to_num(error_rate)
Example #16
0
def biot_savart_oncoil(r_eval, dl, ll, I_arr):
    """
    Biot-Savart sum over all coil segments, usable for points ON a coil.

    Arguments:
    *r_eval*: (length 3 array) cartesian point where the field is evaluated;
        may coincide with a coil segment.
    *dl*: (n_coils, nsegments, 3)-array, crossed with the separation from
        each segment (presumably the segment tangent vectors).
    *ll*: (n_coils, nsegments, 3)-array of the position of each coil segment.
    *I_arr*: (n_coils,)-array of currents, one per coil.

    A segment located exactly at ``r_eval`` yields 0/0 = NaN, which is
    replaced by zero before summing (any +/-inf would become the largest
    finite float). No physical prefactor such as mu_0 / (4 pi) is applied.

    returns:
    *B*: (3,) magnetic field at position r_eval
    """
    sep = r_eval[None, None, :] - ll
    numerator = np.cross(dl, sep) * I_arr[:, None, None]  #unchecked
    dist_cubed = np.linalg.norm(sep, axis=-1)**3
    contributions = np.nan_to_num(numerator / dist_cubed[:, :, None])
    return np.sum(contributions, axis=(0, 1))
Example #17
0
File: vis.py Project: wx-b/mipnerf
def visualize_depth(depth,
                    acc=None,
                    near=None,
                    far=None,
                    curve_fn=lambda x: jnp.log(x + jnp.finfo(jnp.float32).eps),
                    modulus=0,
                    colormap=None):
    """Visualize a depth map.

    Args:
      depth: A depth map (2D; the colormap output is sliced as [:, :, :3]).
      acc: An accumulation map, in [0, 1].
      near: The depth of the near plane. If None, use the min of the non-NaN
        depths (minus float32 eps). An explicit 0 is respected.
      far: The depth of the far plane. If None, use the max of the non-NaN
        depths (plus float32 eps). An explicit 0 is respected.
      curve_fn: A curve function that gets applied to `depth`, `near`, and `far`
        before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps).
      modulus: If > 0, mod the normalized depth by `modulus`. Use (0, 1].
      colormap: A colormap function. If None (default), will be set to
        matplotlib's viridis if modulus==0, sinebow otherwise.

    Returns:
      An RGB visualization of `depth`.
    """
    eps = jnp.finfo(jnp.float32).eps
    # If `near` or `far` are None, identify the min/max non-NaN values.
    # - Compare with None explicitly: `near or ...` would discard near=0.
    # - Pass `nan=` by keyword: the 2nd positional arg of nan_to_num is `copy`,
    #   so the old positional call replaced NaN with 0 instead of +/-inf.
    if near is None:
        near = jnp.min(jnp.nan_to_num(depth, nan=jnp.inf)) - eps
    if far is None:
        far = jnp.max(jnp.nan_to_num(depth, nan=-jnp.inf)) + eps

    # Curve all values.
    depth, near, far = [curve_fn(x) for x in [depth, near, far]]

    # Wrap the values around if requested.
    if modulus > 0:
        value = jnp.mod(depth, modulus) / modulus
        colormap = colormap or sinebow
    else:
        # Scale to [0, 1]; NaN depths end up as 0.
        value = jnp.nan_to_num(jnp.clip((depth - near) / (far - near), 0, 1))
        colormap = colormap or cm.get_cmap('viridis')

    vis = colormap(value)[:, :, :3]

    # Set non-accumulated pixels to white.
    if acc is not None:
        vis = vis * acc[:, :, None] + (1 - acc)[:, :, None]

    return vis
def piecewise_constant_pdf(key, bins, weights, num_samples, randomized):
  """Sample from the piecewise-constant PDF given by `bins` and `weights`.

  Each row of `weights` assigns unnormalised mass to the intervals between
  consecutive entries of the matching row of `bins`. Samples are drawn by
  inverting the resulting CDF. Gradients are not propagated to the samples.

  Args:
    key: jnp.ndarray(float32), [2,], random number generator.
    bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
    weights: jnp.ndarray(float32), [batch_size, num_bins].
    num_samples: int, the number of samples.
    randomized: bool, use randomized samples.

  Returns:
    z_samples: jnp.ndarray(float32), [batch_size, num_samples].
  """
  # Top up rows whose total mass is below `eps` so the PDF never divides by
  # (almost) zero; rows that already have enough mass are left unchanged.
  eps = 1e-5
  total = jnp.sum(weights, axis=-1, keepdims=True)
  deficit = jnp.maximum(0, eps - total)
  weights += deficit / weights.shape[-1]
  total += deficit

  # PDF, then a CDF pinned to exactly 0 at the start and exactly 1 at the end.
  pdf = weights / total
  interior = jnp.minimum(1, jnp.cumsum(pdf[Ellipsis, :-1], axis=-1))
  batch_shape = list(interior.shape[:-1])
  cdf = jnp.concatenate(
      [jnp.zeros(batch_shape + [1]), interior, jnp.ones(batch_shape + [1])],
      axis=-1)

  sample_shape = batch_shape + [num_samples]
  if randomized:
    # `u` lies in [0, 1): it can be zero but never one.
    u = random.uniform(key, sample_shape)
  else:
    # Deterministic grid over [0, 1 - eps], mirroring random.uniform's range.
    grid = jnp.linspace(0., 1. - jnp.finfo('float32').eps, num_samples)
    u = jnp.broadcast_to(grid, sample_shape)

  # above[..., j, s] is True when sample s is at or beyond CDF knot j; the last
  # True knot opens the interval that contains the sample.
  above = u[Ellipsis, None, :] >= cdf[Ellipsis, :, None]

  def bracket(x):
    # Values of `x` at the knots just below and just above each sample
    # (relies on `x` being sorted along its last axis).
    col = x[Ellipsis, None]
    lower = jnp.max(jnp.where(above, col, x[Ellipsis, :1, None]), -2)
    upper = jnp.min(jnp.where(~above, col, x[Ellipsis, -1:, None]), -2)
    return lower, upper

  bin_lo, bin_hi = bracket(bins)
  cdf_lo, cdf_hi = bracket(cdf)

  # Relative position inside the bracketing interval; 0/0 -> 0.
  frac = jnp.clip(
      jnp.nan_to_num((u - cdf_lo) / (cdf_hi - cdf_lo), nan=0.), 0, 1)
  samples = bin_lo + frac * (bin_hi - bin_lo)

  # Prevent gradient from backprop-ing through `samples`.
  return lax.stop_gradient(samples)
Example #19
0
def arithmetic_encoding_num_bits(v: jnp.ndarray) -> int:
  """Estimates the number of bits needed to store `v` via arithmetic coding.

  NaN/inf entries are first replaced via `jnp.nan_to_num`, and the array is
  flattened. The estimate is computed as:

      _hist_bits(v, unique) + v.size * _entropy(v, unique) + 2 * 32 + 2

  NOTE(review): despite the `int` annotation, the result is presumably a JAX
  scalar when the helpers return JAX values; confirm with callers.
  """
  flat = jnp.nan_to_num(v).flatten()
  symbols = jnp.unique(flat)
  per_symbol = _entropy(flat, symbols)
  table_bits = _hist_bits(flat, symbols)
  return table_bits + (flat.size * per_symbol) + (2 * 32) + 2
Example #20
0
def replace_inf_with_zero(x):
    """Return `x` with every non-finite entry (NaN, +inf, -inf) set to 0.0.

    Non-floating inputs (e.g. integer arrays) are returned unchanged by
    `jnp.nan_to_num`.
    """
    zero = 0.0
    return jnp.nan_to_num(x, copy=False, nan=zero, posinf=zero, neginf=zero)
Example #21
0
 def _log_prob(self, inputs: np.ndarray) -> np.ndarray:
     """Log prob for arrays.

     Maps ``inputs`` through ``self._inverse(self._params, ...)``, which
     returns the latent ``u`` and a log-determinant term, and returns
     ``self.prior.log_prob(u) + log_det`` (a change-of-variables density,
     presumably of a normalizing flow). NaN results are mapped to -inf,
     i.e. zero probability.

     NOTE(review): ``nan_to_num``'s default ``neginf`` maps genuine ``-inf``
     log-probs to the most negative finite float; kept as-is.
     """
     # calculate log_prob
     u, log_det = self._inverse(self._params, inputs)
     log_prob = self.prior.log_prob(u) + log_det
     # Set NaNs to negative infinity (i.e. zero probability). `-np.inf` is
     # used because the `np.NINF` alias was removed in NumPy 2.0.
     log_prob = np.nan_to_num(log_prob, nan=-np.inf)
     return log_prob
Example #22
0
 def calc_delta_jax(areas_a, areas_b):
     """Return ``0.5 * sum((a - b)**2 / (a + b))`` over matching bins.

     Bins that are zero in both inputs give 0/0 = NaN and are counted as 0.
     """
     # Open question kept from the original: bin areas or densities? Since
     # sum(area_a) = 1 by definition, areas appear to be what is needed.
     numer = jnp.square(areas_a - areas_b)
     denom = areas_a + areas_b
     terms = jnp.nan_to_num(jnp.true_divide(numer, denom))
     return 0.5 * jnp.sum(terms)
Example #23
0
def find_minimum_theta_scalar(fc_new, r_fil, theta_i):
	"""Minimise ``objective_scalar(fc_new, r_fil, theta)`` over scalar theta.

	Runs ``n_iter`` damped Newton steps of the form
	``theta <- theta - alpha * f'(theta) / (f''(theta) + eps_n)``, where
	``eps_n = epsilon * exp(-n / 15)`` decays with the iteration count.
	Non-finite step ratios are sanitised with ``np.nan_to_num``.
	``n_iter``, ``alpha`` and ``epsilon`` are free names resolved outside this
	function (presumably module-level settings).
	"""
	objective = partial(objective_scalar, fc_new, r_fil)
	first = grad(objective)
	second = grad(first)
	theta = theta_i
	for step in range(n_iter):
		damping = epsilon * np.exp(-step / 15)
		ratio = first(theta) / (second(theta) + damping)
		theta = theta - alpha * np.nan_to_num(ratio)
	return theta
Example #24
0
    def apply(self, x, L=10, units=[10], inputDim=2, actFun=nn.elu, initScale=1.0):
        """2D recurrent network over an L x L lattice (legacy flax.nn module).

        Rows of the lattice are scanned with ``jax.lax.scan`` in a zigzag
        order (every other row reversed via ``self.reverse_line``). Each site
        is scanned within its row, and the log of the softmax probability
        assigned to that site's observed value is accumulated.

        Args:
            x: presumably integer configurations of shape (batch, L, L); it
                is transposed to (L, L, batch) and one-hot encoded with
                ``inputDim`` classes (shape inferred from the index_update
                into ``inputs[1:, 1:-1]``; TODO confirm).
            L: lattice side length.
            units: hidden sizes; only ``units[0]`` is used. The mutable
                default is only read, never mutated.
            inputDim: number of local states per site.
            actFun: activation applied to the summed cell inputs.
            initScale: scale for the output layer's variance-scaling init.

        Returns:
            Per-sample summed log-probability. ``log(0) = -inf`` becomes the
            most negative finite float via ``jnp.nan_to_num``.

        NOTE(review): ``jax.ops.index_update`` / ``jax.ops.index`` have been
        removed from recent JAX (``.at[...].set(...)`` is the replacement),
        and ``nn.Dense.shared`` is the legacy flax.nn API.
        """

        initFunctionCell = jax.nn.initializers.variance_scaling(scale=1.0, mode="fan_avg", distribution="uniform")
        initFunctionOut = jax.nn.initializers.variance_scaling(scale=initScale, mode="fan_in", distribution="uniform")
        #initFunction = jax.nn.initializers.lecun_uniform()

        # Shared dense layers: inputs from the horizontal (H) and vertical (V)
        # neighbours, and the corresponding neighbours' hidden states.
        cellInV = nn.Dense.shared(features=units[0],
                                    name='rnn_cell_in_v',
                                    bias=False)
        cellInH = nn.Dense.shared(features=units[0],
                                    name='rnn_cell_in_h',
                                    bias=False)
        cellCarryV = nn.Dense.shared(features=units[0],
                                    name='rnn_cell_carry_v',
                                    bias=False,
                                    kernel_init=initFunctionCell)
        cellCarryH = nn.Dense.shared(features=units[0],
                                    name='rnn_cell_carry_h',
                                    bias=True,
                                    kernel_init=initFunctionCell)

        # Maps a hidden state to logits over the inputDim local states.
        outputDense = nn.Dense.shared(features=inputDim,
                                      name='rnn_output_dense',
                                      kernel_init=initFunctionOut)

        batchSize = x.shape[0]

        # NOTE(review): unused; shadowed by the local `outputs` in rnn_dim1.
        outputs = jnp.asarray(np.zeros((batchSize,L,L)))

        # Hidden states of the previous row: (L, batch, units[0]).
        states = jnp.asarray(np.zeros((L,batchSize,units[0])))
        # Zero-padded one-hot inputs: one extra row on top and one extra
        # column on each side, so every site has (possibly zero) neighbours.
        inputs = jnp.asarray(np.zeros((L+1,L+2,batchSize,inputDim)))

        # Scan directions for zigzag path
        direction = np.ones(L,dtype=np.int32)
        direction[1::2] = -1
        direction = jnp.asarray(direction)

        x = jnp.transpose(x,axes=[1,2,0])
        inputs = jax.ops.index_update(inputs,jax.ops.index[1:,1:-1],jax.nn.one_hot(x,inputDim))
      
        # Site-level step: carry is the neighbouring site's hidden state along
        # the row; x = (horizontal-neighbour input, vertical-neighbour input,
        # vertical-neighbour hidden state). Emits [hidden | softmax probs].
        def rnn_dim2(carry,x):
            newCarry = actFun( cellInH(x[0]) + cellInV(x[1]) + cellCarryH(carry) + cellCarryV(x[2]) )
            out = jnp.concatenate((newCarry, nn.softmax(outputDense(newCarry))), axis=1)
            return newCarry, out
        # Row-level step: carry holds the previous row's hidden states;
        # x = (this row's inputs, previous row's inputs, scan direction).
        def rnn_dim1(carry,x):
            _, out = jax.lax.scan(rnn_dim2,jnp.zeros((batchSize,units[0]),dtype=np.float32),
                                    (self.reverse_line(x[0],x[2])[:-2],
                                     self.reverse_line(x[1],x[2])[1:-1],
                                     self.reverse_line(carry,x[2]))
                                 )
            carry = jax.ops.index_update(carry,jax.ops.index[:,:],out[:,:,:units[0]])
            # Log-probability of each site's actual value: the softmax part of
            # `out` dotted with the site's one-hot input.
            outputs = jnp.log( jnp.sum( out[:,:,units[0]:] * self.reverse_line(x[0],x[2])[1:-1,:], axis=2 ) )
            return self.reverse_line(carry,x[2]), jnp.sum(outputs,axis=0)
        
        _, prob = jax.lax.scan(rnn_dim1,states,(inputs[1:],inputs[:-1],direction))
        return jnp.nan_to_num(jnp.sum(prob,axis=0))
Example #25
0
def maximize_saturation(rgb):
    """Rescale the maximum saturation in `rgb` to be 1.

    Converts to HSV and multiplies the S channel by ``max(1, 1 / max(S))``,
    so saturation is only ever increased. It then converts back to RGB.
    A NaN ratio falls back to a factor of 1. If every saturation is 0, the
    infinite ratio is clamped to the largest finite float and the zeros
    stay zero.
    """
    hsv = pix.rgb_to_hsv(rgb)
    hue = hsv[Ellipsis, 0]
    sat = hsv[Ellipsis, 1]
    val = hsv[Ellipsis, 2]
    gain = jnp.maximum(1, jnp.nan_to_num(1 / jnp.max(sat), nan=1))
    return pix.hsv_to_rgb(jnp.stack([hue, gain * sat, val], axis=-1))
Example #26
0
 def update(i, g, state):
     """One step of a bounded, sign-aware multiplicative optimizer update.

     ``state`` is ``(x, s, v)``:
     - ``x`` holds the parameters, kept inside ``[-s, s]``.
     - ``s`` is the parameter bound.
     - ``v`` is an exponential moving average (rate ``b2``) of squared
       gradients.

     The gradient is normalised by the bias-corrected RMS and clipped to
     ``[-g_bound, g_bound]``. Each ``x`` is then scaled by
     ``exp(-step_size(i) * g_norm * sign(x))``. ``b2``, ``g_bound`` and
     ``step_size`` are free names resolved outside this function.
     """
     x, s, v = state
     v = b2 * v + (1 - b2) * np.square(g)  # 2nd-moment EMA.
     v_hat = v / (1 - b2**(i + 1))  # Bias correction.
     # Normalise (NaN/inf -> finite) and bound the gradient.
     step = np.clip(np.nan_to_num(g / np.sqrt(v_hat)), -g_bound, g_bound)
     x *= np.exp(-step_size(i) * step * np.sign(x))  # Multiplicative update.
     return np.clip(x, -s, s), s, v
Example #27
0
def visualize_cmap(value,
                   weight,
                   colormap,
                   lo=None,
                   hi=None,
                   percentile=99.,
                   curve_fn=lambda x: x,
                   modulus=None,
                   matte_background=True):
    """Visualize a 1D image and a 1D weighting according to some colormap.

    Args:
      value: A 1D image.
      weight: A weight map, in [0, 1].
      colormap: A colormap function. If falsy, `value` must already be an
        RGB image with a trailing axis of size 3.
      lo: The lower bound to use when rendering. If None, use a percentile.
        An explicit 0 is respected.
      hi: The upper bound to use when rendering. If None, use a percentile.
        An explicit 0 is respected.
      percentile: What percentile of the value map to crop to when automatically
        generating `lo` and `hi`. Depends on `weight` as well as `value`.
      curve_fn: A curve function that gets applied to `value`, `lo`, and `hi`
        before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps).
      modulus: If truthy (not None and non-zero), mod the normalized value by
        `modulus`. Use (0, 1]. In that case `lo`, `hi` and `percentile` have
        no effect.
      matte_background: If True, matte the image over a checkerboard.

    Returns:
      A colormap rendering.
    """
    # Identify the values that bound the middle of `value` according to `weight`.
    # NOTE(review): `math` must be a project module here; the stdlib `math`
    # has no `weighted_percentile`.
    lo_auto, hi_auto = math.weighted_percentile(
        value, weight, [50 - percentile / 2, 50 + percentile / 2])

    # Fall back to the automatic bounds only when `lo`/`hi` are None:
    # `lo or ...` would silently replace an explicit 0 bound.
    eps = jnp.finfo(jnp.float32).eps
    if lo is None:
        lo = lo_auto - eps
    if hi is None:
        hi = hi_auto + eps

    # Curve all values.
    value, lo, hi = [curve_fn(x) for x in [value, lo, hi]]

    # Wrap the values around if requested.
    if modulus:
        value = jnp.mod(value, modulus) / modulus
    else:
        # Otherwise, just scale to [0, 1]; the min/abs form tolerates lo > hi.
        value = jnp.nan_to_num(
            jnp.clip((value - jnp.minimum(lo, hi)) / jnp.abs(hi - lo), 0, 1))

    if colormap:
        colorized = colormap(value)[:, :, :3]
    else:
        assert len(value.shape) == 3 and value.shape[-1] == 3
        colorized = value

    return matte(colorized, weight) if matte_background else colorized
def piecewise_constant_pdf(key, bins, weights, num_samples, randomized):
    """Draw samples from the piecewise-constant PDF given by `bins`/`weights`.

    Each row of `weights` assigns unnormalised mass to the intervals between
    consecutive entries of the matching row of `bins`. Samples are obtained
    by inverting the CDF. Gradients are not propagated to the samples.

    Args:
      key: jnp.ndarray(float32), [2,], random number generator.
      bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
      weights: jnp.ndarray(float32), [batch_size, num_bins].
      num_samples: int, the number of samples.
      randomized: bool, use randomized samples.

    Returns:
      z_samples: jnp.ndarray(float32), [batch_size, num_samples].
    """
    # Top up near-empty rows to a total mass of `eps` so the normalisation
    # below cannot divide by zero; heavier rows are left unchanged.
    eps = 1e-5
    mass = jnp.sum(weights, axis=-1, keepdims=True)
    shortfall = jnp.maximum(0, eps - mass)
    weights += shortfall / weights.shape[-1]
    mass += shortfall

    # CDF of each row, prefixed with a leading 0.
    running = jnp.cumsum(weights / mass, axis=-1)
    batch_shape = list(running.shape[:-1])
    cdf = jnp.concatenate([jnp.zeros(batch_shape + [1]), running], axis=-1)

    sample_shape = batch_shape + [num_samples]
    if randomized:
        u = random.uniform(key, sample_shape)
    else:
        # Evenly spaced over the closed interval [0, 1].
        u = jnp.broadcast_to(jnp.linspace(0., 1., num_samples), sample_shape)

    # at_or_past[..., j, s] is True when sample s is at or beyond CDF knot j.
    # Since `bins` and `cdf` are sorted, this inverts the CDF.
    at_or_past = (u[Ellipsis, None, :] >= cdf[Ellipsis, :, None])

    def bracket(x):
        # Knot values just below / above each sample. The lower knot is capped
        # at the second-to-last entry and the upper knot floored at the second
        # entry, so samples at the extremes (e.g. u == 1) still land in a
        # proper end interval.
        col = x[Ellipsis, None]
        below = jnp.max(jnp.where(at_or_past, col, x[Ellipsis, :1, None]), -2)
        above = jnp.min(
            jnp.where(~at_or_past, col, x[Ellipsis, -1:, None]), -2)
        return (jnp.minimum(below, x[Ellipsis, -2:-1]),
                jnp.maximum(above, x[Ellipsis, 1:2]))

    bins_lo, bins_hi = bracket(bins)
    cdf_lo, cdf_hi = bracket(cdf)

    # Relative position inside the bracketing interval; 0/0 -> 0.
    frac = jnp.clip(
        jnp.nan_to_num((u - cdf_lo) / (cdf_hi - cdf_lo), nan=0.), 0, 1)

    # Prevent gradient from backprop-ing through the samples.
    return lax.stop_gradient(bins_lo + frac * (bins_hi - bins_lo))
Example #29
0
def excute(fn, grad_in=None):
    """Recursively push `grad_in` backwards through a graph of function nodes.

    - ``AccumulateGrad`` leaves add the (sanitised, possibly reshaped)
      gradient into ``fn.variable.grad``, creating it as zeros on first use.
    - Any other node calls ``fn.apply(grad_in)``, which returns
      ``(grad_outs, gamma)``. A non-None ``gamma`` is appended to the
      ``gamma_stack`` list defined outside this function, and the call
      recurses into ``fn.next_functions``, pairing the i-th child with the
      i-th output gradient.
    - A ``None`` node is ignored.

    NOTE(review): the name looks like a typo of "execute"; kept because
    callers (and the recursion below) use it.
    """
    if fn is not None:

        if isinstance(fn, AccumulateGrad):

            if fn.variable.requires_grad and grad_in is not None:

                if fn.variable.grad is None:
                    fn.variable.grad = jnp.zeros(fn.variable.data.shape)

                # Sanitise: +inf -> 0, then NaN -> 0. NOTE(review): -inf is not
                # zeroed here; nan_to_num maps it to the most negative finite
                # float. Confirm that asymmetry is intended.
                grad_in = jnp.where(grad_in == jnp.inf, 0, grad_in)
                grad_in = jnp.nan_to_num(grad_in, copy=False)
                # Reshape the incoming gradient to the accumulator's layout
                # using the project helpers gammapops/linearpops (semantics not
                # visible here; presumably conv- vs linear-layer projections).
                if len(grad_in.shape) != 4:
                    if len(fn.variable.grad.shape) == 4:
                        gamma = gammapops(grad_in, fn.variable.data.shape[2],
                                          fn.variable.data.shape[3])

                        grad_in = jnp.matmul(gamma, grad_in.T)
                        grad_in = grad_in.reshape(fn.variable.grad.shape)
                    else:
                        if grad_in.shape != fn.variable.grad.shape:
                            gamma = linearpops(fn.variable.grad.shape[1])
                            grad_in = jnp.transpose(
                                jnp.matmul(gamma.T, grad_in))

                # Sanitise again after the projections above.
                grad_in = jnp.where(grad_in == jnp.inf, 0, grad_in)
                grad_in = jnp.nan_to_num(grad_in, copy=False)
                # Accumulate into the stored gradient. NOTE(review): if these
                # are jax.ops.index_add/index, they were removed from recent
                # JAX; `.at[:].add(...)` is the replacement.
                fn.variable.grad = index_add(fn.variable.grad, index[:],
                                             grad_in)

            return
        grad_outs, gamma = fn.apply(grad_in)
        if gamma is not None:

            gamma_stack.append(gamma)

        # Normalise a single output gradient to a 1-tuple for the loop below.
        if type(grad_outs) is not tuple:
            grad_outs = (grad_outs, )

        for i, next_func in enumerate(fn.next_functions):
            excute(next_func, grad_outs[i])
Example #30
0
def segment_mean(data, segment_ids, num_segments):
    """Returns mean for each segment.

    Args:
      data: the values which are averaged segment-wise.
      segment_ids: indices for the segments.
      num_segments: total number of segments.

    Returns:
      Per-segment means (``num_segments`` entries along axis 0). Segments
      with no members produce 0/0, which ``jnp.nan_to_num`` maps to 0.
    """
    sums, counts = (
        jax.ops.segment_sum(column, segment_ids, num_segments)
        for column in (data, jnp.ones_like(data)))
    return jnp.nan_to_num(sums / counts)