def model(X: DeviceArray) -> DeviceArray:
    """Gamma-Poisson hierarchical model for daily sales forecasting.

    Args:
        X: input array of shape (n_stores, n_days, n_features + 1); the last
            feature column holds the (possibly NaN) daily sales targets.

    Returns:
        Sampled daily sales counts.
    """
    n_stores, n_days, n_features = X.shape
    n_features -= 1  # the last column is the target, not a feature
    eps = 1e-12

    plate_features = numpyro.plate(Plate.features, n_features, dim=-1)
    plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)
    plate_days = numpyro.plate(Plate.days, n_days, dim=-1)

    disp_param_mu = numpyro.sample(Site.disp_param_mu,
                                   dist.Normal(loc=4., scale=1.))
    disp_param_sigma = numpyro.sample(Site.disp_param_sigma,
                                      dist.HalfNormal(scale=1.))

    with plate_stores:
        disp_param_offsets = numpyro.sample(
            Site.disp_param_offsets,
            dist.Normal(loc=jnp.zeros((n_stores, 1)), scale=0.1))
        disp_params = disp_param_mu + disp_param_offsets * disp_param_sigma
        disp_params = numpyro.sample(Site.disp_params,
                                     dist.Delta(disp_params),
                                     obs=disp_params)

    with plate_features:
        coef_mus = numpyro.sample(
            Site.coef_mus,
            dist.Normal(loc=jnp.zeros(n_features), scale=jnp.ones(n_features)))
        coef_sigmas = numpyro.sample(
            Site.coef_sigmas,
            dist.HalfNormal(scale=2. * jnp.ones(n_features)))

        with plate_stores:
            coef_offsets = numpyro.sample(
                Site.coef_offsets,
                dist.Normal(loc=jnp.zeros((n_stores, n_features)), scale=1.))
            coefs = coef_mus + coef_offsets * coef_sigmas
            coefs = numpyro.sample(Site.coefs, dist.Delta(coefs), obs=coefs)

    with plate_days, plate_stores:
        targets = X[..., -1]
        features = jnp.nan_to_num(X[..., :-1])  # padded features become 0
        is_observed = jnp.where(jnp.isnan(targets),
                                jnp.zeros_like(targets),
                                jnp.ones_like(targets))
        not_observed = 1 - is_observed
        means = (is_observed * jnp.exp(
            jnp.sum(jnp.expand_dims(coefs, axis=1) * features, axis=2)) +
                 not_observed * eps)

        betas = is_observed * jnp.exp(-disp_params) + not_observed
        alphas = means * betas
        return numpyro.sample(Site.days,
                              dist.GammaPoisson(alphas, betas),
                              obs=jnp.nan_to_num(targets))

def main(cfg_path: Path, log_level: int):
    logging.basicConfig(
        stream=sys.stdout,
        level=log_level,
        datefmt='%Y-%m-%d %H:%M',
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    with open(cfg_path) as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)

    nside = cfg['nside']
    outpath = cfg['hdf5_path']

    mask_file = 'data/n0512.fits'
    mask_data = old_np.nan_to_num(
        hp.read_map(mask_file, verbose=False, dtype=np.float32))
    mask_apodized = nmt.mask_apodization(mask_data, 5., apotype="C2")
    mask_binary = old_np.where(mask_data > 0.0, 1.0, 0.0)

    with h5py.File(outpath, 'w') as f:
        ns0512 = f.require_group('ns0512')
        data_dset = ns0512.require_dataset('binary',
                                           shape=(hp.nside2npix(nside), ),
                                           dtype=np.float32)
        data_dset[...] = np.nan_to_num(mask_binary)

        apodized = ns0512.require_group('apodized')
        data_dset2 = apodized.require_dataset('basic',
                                              shape=(hp.nside2npix(nside), ),
                                              dtype=np.float32)
        data_dset2[...] = np.nan_to_num(mask_apodized)

def uniform_stochastic_quantize(v: jnp.ndarray,
                                num_levels: int,
                                rng: PRNGKey,
                                v_min: Optional[float] = None,
                                v_max: Optional[float] = None) -> jnp.ndarray:
    """Uniform stochastic quantization from https://arxiv.org/pdf/1611.00429.pdf.

    Args:
        v: vector to be quantized.
        num_levels: number of quantization levels.
        rng: jax random key.
        v_min: minimum threshold for quantization. If None, uses jnp.amin(v).
        v_max: maximum threshold for quantization. If None, uses jnp.amax(v).

    Returns:
        Quantized array.
    """
    # Rescale the vector to lie in [0, 1].
    if v_min is None:
        v_min = jnp.amin(v)
    if v_max is None:
        v_max = jnp.amax(v)
    v = jnp.nan_to_num((v - v_min) / (v_max - v_min))
    v = jnp.maximum(0., jnp.minimum(v, 1.))

    # Compute the upper and lower quantization boundary of each value.
    v_ceil = jnp.ceil(v * (num_levels - 1)) / (num_levels - 1)
    v_floor = jnp.floor(v * (num_levels - 1)) / (num_levels - 1)

    # Stochastically round each value to v_ceil or v_floor.
    rand = jax.random.uniform(key=rng, shape=v.shape)
    threshold = jnp.nan_to_num((v - v_floor) / (v_ceil - v_floor))
    quantized = jnp.where(rand > threshold, v_floor, v_ceil)

    # Rescale back to the original range and return.
    return v_min + quantized * (v_max - v_min)

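# A minimal usage sketch (not from the original source): quantizing a small,
# hypothetical vector to 4 evenly spaced levels with the function above.
import jax
import jax.numpy as jnp

v = jnp.array([0.05, 0.4, 0.77, 1.0])
key = jax.random.PRNGKey(0)
q = uniform_stochastic_quantize(v, num_levels=4, rng=key)
# Each entry of `q` is one of the 4 levels spanning [min(v), max(v)]
# ({0.05, ~0.367, ~0.683, 1.0}); rounding up happens with probability equal
# to the fractional position between levels, so the quantizer is unbiased.
print(q)
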
def _giou(boxes1: jnp.ndarray, boxes2: jnp.ndarray) -> jnp.ndarray:
    b1_ymin, b1_xmin, b1_ymax, b1_xmax = jnp.hsplit(boxes1, 4)
    b2_ymin, b2_xmin, b2_ymax, b2_xmax = jnp.hsplit(boxes2, 4)
    b1_width = jnp.maximum(0, b1_xmax - b1_xmin)
    b1_height = jnp.maximum(0, b1_ymax - b1_ymin)
    b2_width = jnp.maximum(0, b2_xmax - b2_xmin)
    b2_height = jnp.maximum(0, b2_ymax - b2_ymin)
    b1_area = b1_width * b1_height
    b2_area = b2_width * b2_height

    intersect_ymin = jnp.maximum(b1_ymin, b2_ymin)
    intersect_xmin = jnp.maximum(b1_xmin, b2_xmin)
    intersect_ymax = jnp.minimum(b1_ymax, b2_ymax)
    intersect_xmax = jnp.minimum(b1_xmax, b2_xmax)
    intersect_width = jnp.maximum(0, intersect_xmax - intersect_xmin)
    intersect_height = jnp.maximum(0, intersect_ymax - intersect_ymin)
    intersect_area = intersect_width * intersect_height

    union_area = b1_area + b2_area - intersect_area
    iou = jnp.nan_to_num(intersect_area / union_area)

    enclose_ymin = jnp.minimum(b1_ymin, b2_ymin)
    enclose_xmin = jnp.minimum(b1_xmin, b2_xmin)
    enclose_ymax = jnp.maximum(b1_ymax, b2_ymax)
    enclose_xmax = jnp.maximum(b1_xmax, b2_xmax)
    enclose_width = jnp.maximum(0, enclose_xmax - enclose_xmin)
    enclose_height = jnp.maximum(0, enclose_ymax - enclose_ymin)
    enclose_area = enclose_width * enclose_height

    giou = iou - jnp.nan_to_num((enclose_area - union_area) / enclose_area)
    return giou.squeeze()

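# A small usage sketch with hypothetical boxes. The `hsplit` above implies a
# row-per-box [ymin, xmin, ymax, xmax] layout, which is what is assumed here.
import jax.numpy as jnp

boxes1 = jnp.array([[0., 0., 2., 2.],
                    [0., 0., 1., 1.]])
boxes2 = jnp.array([[1., 1., 3., 3.],
                    [2., 2., 3., 3.]])
print(_giou(boxes1, boxes2))
# The first pair overlaps partially; the second pair is disjoint, so its IoU
# is 0 and its GIoU is strongly negative (penalized by the enclosing box).
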
def recursive_check(tuple_object, dict_object):
    if isinstance(tuple_object, (List, Tuple)):
        for tuple_iterable_value, dict_iterable_value in zip(
                tuple_object, dict_object):
            recursive_check(tuple_iterable_value, dict_iterable_value)
    elif tuple_object is None:
        return
    else:
        self.assert_almost_equals(jnp.nan_to_num(tuple_object),
                                  jnp.nan_to_num(dict_object), 1e-5)

def mutual_information(net, samples):
    probs = net.prob_factors(samples)
    avgP = jnp.mean(probs, axis=2)
    S = (avgP * jnp.nan_to_num(jnp.log(avgP)) +
         (1. - avgP) * jnp.nan_to_num(jnp.log(1. - avgP))) / jnp.log(2.)
    condS = jnp.mean(
        (probs * jnp.nan_to_num(jnp.log(probs)) +
         (1. - probs) * jnp.nan_to_num(jnp.log(1. - probs))) / jnp.log(2.),
        axis=2)
    return condS - S

def calc_delta_correlated_error_jax(areas_a, ematrix_a, areas_b, errors_b):
    diffs_a, diffs_b = delta_differential_jax(areas_a, areas_b)
    # Need nan_to_num since the differential can return NaN;
    # not sure how to fix this "properly" though.
    diffs_a = jnp.nan_to_num(diffs_a)
    diffs_b = jnp.nan_to_num(diffs_b)
    # For the total we need
    #   diffs_a * ematrix_a * diffs_a + diffs_b * errors_b * diffs_b;
    # since the errors on a and b are uncorrelated, we can get away with this.
    # err_a_sq = diffs_a.T @ ematrix_a @ diffs_a
    err_a_sq = jnp.matmul(diffs_a.T, jnp.matmul(ematrix_a, diffs_a))
    err_b_sq = jnp.sum(jnp.square(diffs_b * errors_b))
    return jnp.sqrt(err_a_sq + err_b_sq)

def update(delta_x, delta_gx, Us, VTs, n_step):
    # Add a column/row to Us/VTs with the updated approximation.
    # Calculate J_i.
    vT = rmatvec(Us, VTs, delta_x)
    u = (delta_x - matvec(Us, VTs, delta_gx)) / _einsum(
        'bij, bij -> b', vT, delta_gx)[:, None, None]
    vT = jnp.nan_to_num(vT)
    u = jnp.nan_to_num(u)
    # Store in Us and VTs for calculating J.
    VTs = jax.ops.index_update(VTs, jax.ops.index[:, n_step - 1], vT)
    Us = jax.ops.index_update(Us, jax.ops.index[:, :, :, n_step - 1], u)
    return Us, VTs

def ps_ir(*, draws, proposal_densities, target_densities, num_resampled, key):
    logits = target_densities - proposal_densities
    logits = jnp.nan_to_num(logits, nan=-jnp.inf, neginf=-jnp.inf)
    idxs = jax.random.categorical(key=key,
                                  logits=logits,
                                  shape=(num_resampled, ))
    return draws[idxs]

def reweighted_stddev(f_n: Array, target_logpdf_n: Array,
                      source_logpdf_n: Array) -> Float:
    """Compute a reweighted estimate of stddev(f(x)) under x ~ p_target,
    where p_target(x) = exp(target_logpdf(x)) / Z_target,
    using samples x_n ~ p_source from a different source distribution,
    where p_source(x) = exp(source_logpdf(x)) / Z_source.

    The inputs are arrays "{fxn_name}_n" containing the result of calling each
    fxn on a fixed array of samples:
    * f_n = [f(x_n) for x_n in samples]
    * target_logpdf_n = [target_logpdf(x_n) for x_n in samples]
    * source_logpdf_n = [source_logpdf(x_n) for x_n in samples]
    """
    log_weights_n = target_logpdf_n - source_logpdf_n
    weights = np.exp(log_weights_n - logsumexp(log_weights_n)).flatten()

    f_mean = np.sum(weights * f_n)
    squared_deviations = (f_n - f_mean)**2

    # sanitize 0 * inf -> 0 (instead of nan)
    weighted_squared_deviations = weights * squared_deviations
    sanitized = np.nan_to_num(weighted_squared_deviations, nan=0)

    stddev = np.sqrt(np.sum(sanitized))
    return stddev

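# A minimal sketch of calling this importance-weighting estimator, assuming
# `np` above is jax.numpy. The Gaussians below are hypothetical: the target
# is N(1, 2), samples come from the heavier-tailed source N(0, 3), and
# f(x) = x, so the estimate should land near the target stddev of 2.
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

key = jax.random.PRNGKey(0)
x_n = 3.0 * jax.random.normal(key, (10_000,))
f_n = x_n
target_logpdf_n = norm.logpdf(x_n, loc=1.0, scale=2.0)
source_logpdf_n = norm.logpdf(x_n, loc=0.0, scale=3.0)
print(reweighted_stddev(f_n, target_logpdf_n, source_logpdf_n))  # ~2.0
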
def precision(
    y_true: jnp.ndarray,
    y_pred: jnp.ndarray,
    threshold: jnp.ndarray,
    class_id: jnp.ndarray,
    sample_weight: jnp.ndarray,
    true_positives: ReduceConfusionMatrix,
    false_positives: ReduceConfusionMatrix,
) -> jnp.ndarray:
    # TODO: class_id behavior
    y_pred = (y_pred > threshold).astype(jnp.float32)

    if y_true.dtype != y_pred.dtype:
        y_pred = y_pred.astype(y_true.dtype)

    true_positives = true_positives(y_true=y_true,
                                    y_pred=y_pred,
                                    sample_weight=sample_weight)
    false_positives = false_positives(y_true=y_true,
                                      y_pred=y_pred,
                                      sample_weight=sample_weight)

    return jnp.nan_to_num(
        jnp.divide(true_positives, true_positives + false_positives))

def rnn_cell(carry, x):
    newCarry, logits = jax.vmap(eval_cell)(carry[0], carry[1])
    sampleOut = jax.random.categorical(x, logits)
    sample = jax.nn.one_hot(sampleOut, inputDim)
    logProb = jnp.sum(nn.log_softmax(logits) * sample, axis=1)
    return (newCarry, sample), (jnp.nan_to_num(logProb, nan=-35), sampleOut)

def loss(self, params, batch):
    """Binary cross-entropy loss."""
    inputs, targets = batch
    preds = self.predict(params, inputs)
    return jnp.mean(
        jnp.nan_to_num(-targets * jnp.log(preds) -
                       (1 - targets) * jnp.log(1 - preds)))

def syn2post_softmax(syn_values, post_ids, post_num: int,
                     indices_are_sorted=True):
    """The syn-to-post softmax computation.

    Parameters
    ----------
    syn_values: jax.numpy.ndarray, JaxArray, Variable
        The synaptic values.
    post_ids: jax.numpy.ndarray, JaxArray
        The post-synaptic neuron ids. If ``post_ids`` is generated by
        ``brainpy.conn.TwoEndConnector``, then it has sorted indices.
        Otherwise, this function cannot guarantee the indices are sorted,
        and you had better set ``indices_are_sorted=False``.
    post_num: int
        The number of the post-synaptic neurons.
    indices_are_sorted: bool
        Whether ``post_ids`` is known to be sorted.

    Returns
    -------
    post_val: jax.numpy.ndarray, JaxArray
        The post-synaptic value.
    """
    post_ids = as_device_array(post_ids)
    syn_values = as_device_array(syn_values)
    if syn_values.dtype == jnp.bool_:
        syn_values = jnp.asarray(syn_values, dtype=jnp.int32)
    syn_maxs = _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted)
    syn_values = syn_values - syn_maxs[post_ids]
    syn_values = jnp.exp(syn_values)
    normalizers = _jit_seg_sum(syn_values, post_ids, post_num,
                               indices_are_sorted)
    softmax = syn_values / normalizers[post_ids]
    return jnp.nan_to_num(softmax)

def top_k_error_rate_metric(logits: jnp.ndarray,
                            one_hot_labels: jnp.ndarray,
                            k: int = 5,
                            mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
    """Returns the top-K error rate between some predictions and some labels.

    Args:
        logits: Output of the model.
        one_hot_labels: One-hot encoded labels. Dimensions should match the
            logits.
        k: Number of classes the model is allowed to predict for each example.
        mask: Mask to apply to the loss to ignore some samples (usually, the
            padding of the batch). Array of ones and zeros.

    Returns:
        The error rate (1 - accuracy), averaged over the first dimension
        (samples).
    """
    if mask is None:
        mask = jnp.ones([logits.shape[0]])
    mask = mask.reshape([logits.shape[0]])
    true_labels = jnp.argmax(one_hot_labels, -1).reshape([-1, 1])
    top_k_preds = jnp.argsort(logits, axis=-1)[:, -k:]
    hit = jax.vmap(jnp.isin)(true_labels, top_k_preds)
    error_rate = 1 - ((hit * mask).sum() / mask.sum())
    # Set to zero if there are no non-masked samples.
    return jnp.nan_to_num(error_rate)

def biot_savart_oncoil(r_eval, dl, ll, I_arr):
    """
    Calculate the Biot-Savart integral over the coils, including the segment
    of the coil the evaluation point lies ON, specified by ll and dl.

    Arguments:
    *r_eval*: (length-3 array) the point where the field is to be evaluated,
        in cartesian coordinates. Has to be on a coil.
    *dl*: (n_coils, n_segments, 3)-array of the line-segment vector of every
        coil segment.
    *ll*: (n_coils, n_segments, 3)-array of the position of each coil segment.
    *I_arr*: (n_coils,)-array of coil currents.

    Note on algorithm: indexing with None adds new axes, to in-line cast the
    arrays into the proper shape. The Biot-Savart integral is calculated as a
    sum over all segments.

    Returns:
    *B*: magnetic field at position r_eval.
    """
    top = np.cross(dl, r_eval[None, None, :] - ll) * I_arr[:, None, None]  # unchecked
    bottom = np.linalg.norm(r_eval[None, None, :] - ll, axis=-1)**3
    # Sum over all infinitesimal line segments, replacing the NaN (from the
    # segment containing r_eval itself) with zero.
    B = np.sum(np.nan_to_num(top / bottom[:, :, None]), axis=(0, 1))
    return B

def visualize_depth(depth,
                    acc=None,
                    near=None,
                    far=None,
                    curve_fn=lambda x: jnp.log(x + jnp.finfo(jnp.float32).eps),
                    modulus=0,
                    colormap=None):
    """Visualize a depth map.

    Args:
        depth: A depth map.
        acc: An accumulation map, in [0, 1].
        near: The depth of the near plane; if None then just use the min().
        far: The depth of the far plane; if None then just use the max().
        curve_fn: A curve function that gets applied to `depth`, `near`, and
            `far` before the rest of visualization. Good choices: x, 1/(x+eps),
            log(x+eps).
        modulus: If > 0, mod the normalized depth by `modulus`. Use (0, 1].
        colormap: A colormap function. If None (default), will be set to
            matplotlib's viridis if modulus==0, sinebow otherwise.

    Returns:
        An RGB visualization of `depth`.
    """
    # If `near` or `far` are None, identify the min/max non-NaN values.
    eps = jnp.finfo(jnp.float32).eps
    near = near or jnp.min(jnp.nan_to_num(depth, jnp.inf)) - eps
    far = far or jnp.max(jnp.nan_to_num(depth, -jnp.inf)) + eps

    # Curve all values.
    depth, near, far = [curve_fn(x) for x in [depth, near, far]]

    # Wrap the values around if requested.
    if modulus > 0:
        value = jnp.mod(depth, modulus) / modulus
        colormap = colormap or sinebow
    else:
        # Scale to [0, 1].
        value = jnp.nan_to_num(jnp.clip((depth - near) / (far - near), 0, 1))
        colormap = colormap or cm.get_cmap('viridis')

    vis = colormap(value)[:, :, :3]

    # Set non-accumulated pixels to white.
    if acc is not None:
        vis = vis * acc[:, :, None] + (1 - acc)[:, :, None]

    return vis

def piecewise_constant_pdf(key, bins, weights, num_samples, randomized):
    """Piecewise-Constant PDF sampling.

    Args:
        key: jnp.ndarray(float32), [2,], random number generator.
        bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
        weights: jnp.ndarray(float32), [batch_size, num_bins].
        num_samples: int, the number of samples.
        randomized: bool, use randomized samples.

    Returns:
        z_samples: jnp.ndarray(float32), [batch_size, num_samples].
    """
    # Pad each weight vector (only if necessary) to bring its sum to `eps`.
    # This avoids NaNs when the input is zeros or small, but has no effect
    # otherwise.
    eps = 1e-5
    weight_sum = jnp.sum(weights, axis=-1, keepdims=True)
    padding = jnp.maximum(0, eps - weight_sum)
    weights += padding / weights.shape[-1]
    weight_sum += padding

    # Compute the PDF and CDF for each weight vector, while ensuring that the
    # CDF starts with exactly 0 and ends with exactly 1.
    pdf = weights / weight_sum
    cdf = jnp.minimum(1, jnp.cumsum(pdf[Ellipsis, :-1], axis=-1))
    cdf = jnp.concatenate([
        jnp.zeros(list(cdf.shape[:-1]) + [1]), cdf,
        jnp.ones(list(cdf.shape[:-1]) + [1])
    ],
                          axis=-1)

    # Draw uniform samples.
    if randomized:
        # Note that `u` is in [0, 1) --- it can be zero, but it can never be 1.
        u = random.uniform(key, list(cdf.shape[:-1]) + [num_samples])
    else:
        # Match the behavior of random.uniform() by spanning [0, 1-eps].
        u = jnp.linspace(0., 1. - jnp.finfo('float32').eps, num_samples)
        u = jnp.broadcast_to(u, list(cdf.shape[:-1]) + [num_samples])

    # Identify the location in `cdf` that corresponds to a random sample.
    # The final `True` index in `mask` will be the start of the sampled
    # interval.
    mask = u[Ellipsis, None, :] >= cdf[Ellipsis, :, None]

    def find_interval(x):
        # Grab the value where `mask` switches from True to False, and vice
        # versa. This approach takes advantage of the fact that `x` is sorted.
        x0 = jnp.max(jnp.where(mask, x[Ellipsis, None], x[Ellipsis, :1, None]),
                     -2)
        x1 = jnp.min(
            jnp.where(~mask, x[Ellipsis, None], x[Ellipsis, -1:, None]), -2)
        return x0, x1

    bins_g0, bins_g1 = find_interval(bins)
    cdf_g0, cdf_g1 = find_interval(cdf)

    t = jnp.clip(jnp.nan_to_num((u - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1)
    samples = bins_g0 + t * (bins_g1 - bins_g0)

    # Prevent gradient from backprop-ing through `samples`.
    return lax.stop_gradient(samples)

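# A small usage sketch (not from the original source), assuming the same
# `jax.numpy as jnp` and `from jax import random` imports the function above
# relies on. The histogram below is hypothetical: 4 bins on [0, 1] with most
# of the probability mass in the second bin.
import jax.numpy as jnp
from jax import random

bins = jnp.array([[0.0, 0.25, 0.5, 0.75, 1.0]])
weights = jnp.array([[0.1, 0.7, 0.1, 0.1]])
key = random.PRNGKey(0)
z = piecewise_constant_pdf(key, bins, weights, num_samples=8, randomized=True)
# z has shape (1, 8); most samples should fall in the heavy bin [0.25, 0.5).
print(z)
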
def arithmetic_encoding_num_bits(v: jnp.ndarray) -> int:
    """Computes number of bits needed to store v via arithmetic coding."""
    v = jnp.nan_to_num(v)
    v = v.flatten()
    uniq = jnp.unique(v)
    entropy = _entropy(v, uniq)
    hist_bits = _hist_bits(v, uniq)
    return hist_bits + (v.size * entropy) + (2 * 32) + 2

def replace_inf_with_zero(x):
    return jnp.nan_to_num(x, copy=False, nan=0.0, posinf=0.0, neginf=0.0)

def _log_prob(self, inputs: np.ndarray) -> np.ndarray:
    """Log prob for arrays."""
    # calculate log_prob
    u, log_det = self._inverse(self._params, inputs)
    log_prob = self.prior.log_prob(u) + log_det
    # set NaNs to negative infinity (i.e. zero probability)
    log_prob = np.nan_to_num(log_prob, nan=np.NINF)
    return log_prob

def calc_delta_jax(areas_a, areas_b):
    # Do I need bin areas or densities? I guess since by definition
    # sum(areas_a) = 1, areas are needed?!
    integrand = jnp.true_divide(jnp.square(areas_a - areas_b),
                                areas_a + areas_b)
    # nan_to_num is important, as the division gives NaN if both inputs are 0.
    delta = 0.5 * jnp.sum(jnp.nan_to_num(integrand))
    return delta

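# A quick usage sketch with two hypothetical normalized histograms: delta is
# 0 for identical inputs and grows as the histograms separate, and empty bins
# (0/0) are sanitized to 0 by the nan_to_num above.
import jax.numpy as jnp

areas_a = jnp.array([0.2, 0.5, 0.3, 0.0])
areas_b = jnp.array([0.1, 0.6, 0.3, 0.0])
print(calc_delta_jax(areas_a, areas_a))  # 0.0
print(calc_delta_jax(areas_a, areas_b))  # small positive value (~0.021)
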
def find_minimum_theta_scalar(fc_new, r_fil, theta_i):
    f = partial(objective_scalar, fc_new, r_fil)
    f_prime = grad(f)
    f_primeprime = grad(f_prime)
    for n in range(n_iter):
        new_ep = epsilon * np.exp(-n / 15)
        theta_i = theta_i - alpha * np.nan_to_num(
            f_prime(theta_i) / (f_primeprime(theta_i) + new_ep))
    return theta_i

def apply(self, x, L=10, units=[10], inputDim=2, actFun=nn.elu, initScale=1.0):
    initFunctionCell = jax.nn.initializers.variance_scaling(
        scale=1.0, mode="fan_avg", distribution="uniform")
    initFunctionOut = jax.nn.initializers.variance_scaling(
        scale=initScale, mode="fan_in", distribution="uniform")
    # initFunction = jax.nn.initializers.lecun_uniform()

    cellInV = nn.Dense.shared(features=units[0],
                              name='rnn_cell_in_v',
                              bias=False)
    cellInH = nn.Dense.shared(features=units[0],
                              name='rnn_cell_in_h',
                              bias=False)
    cellCarryV = nn.Dense.shared(features=units[0],
                                 name='rnn_cell_carry_v',
                                 bias=False,
                                 kernel_init=initFunctionCell)
    cellCarryH = nn.Dense.shared(features=units[0],
                                 name='rnn_cell_carry_h',
                                 bias=True,
                                 kernel_init=initFunctionCell)
    outputDense = nn.Dense.shared(features=inputDim,
                                  name='rnn_output_dense',
                                  kernel_init=initFunctionOut)

    batchSize = x.shape[0]
    outputs = jnp.asarray(np.zeros((batchSize, L, L)))
    states = jnp.asarray(np.zeros((L, batchSize, units[0])))
    inputs = jnp.asarray(np.zeros((L + 1, L + 2, batchSize, inputDim)))

    # Scan directions for the zigzag path.
    direction = np.ones(L, dtype=np.int32)
    direction[1::2] = -1
    direction = jnp.asarray(direction)

    x = jnp.transpose(x, axes=[1, 2, 0])
    inputs = jax.ops.index_update(inputs, jax.ops.index[1:, 1:-1],
                                  jax.nn.one_hot(x, inputDim))

    def rnn_dim2(carry, x):
        newCarry = actFun(
            cellInH(x[0]) + cellInV(x[1]) + cellCarryH(carry) +
            cellCarryV(x[2]))
        out = jnp.concatenate((newCarry, nn.softmax(outputDense(newCarry))),
                              axis=1)
        return newCarry, out

    def rnn_dim1(carry, x):
        _, out = jax.lax.scan(
            rnn_dim2, jnp.zeros((batchSize, units[0]), dtype=np.float32),
            (self.reverse_line(x[0], x[2])[:-2],
             self.reverse_line(x[1], x[2])[1:-1],
             self.reverse_line(carry, x[2])))
        carry = jax.ops.index_update(carry, jax.ops.index[:, :],
                                     out[:, :, :units[0]])
        outputs = jnp.log(
            jnp.sum(out[:, :, units[0]:] *
                    self.reverse_line(x[0], x[2])[1:-1, :],
                    axis=2))
        return self.reverse_line(carry, x[2]), jnp.sum(outputs, axis=0)

    _, prob = jax.lax.scan(rnn_dim1, states,
                           (inputs[1:], inputs[:-1], direction))
    return jnp.nan_to_num(jnp.sum(prob, axis=0))

def maximize_saturation(rgb):
    """Rescale the maximum saturation in `rgb` to be 1."""
    hsv = pix.rgb_to_hsv(rgb)
    scaling = jnp.maximum(1,
                          jnp.nan_to_num(1 / jnp.max(hsv[Ellipsis, 1]), nan=1))
    rgb_scaled = pix.hsv_to_rgb(
        jnp.stack(
            [hsv[Ellipsis, 0], scaling * hsv[Ellipsis, 1], hsv[Ellipsis, 2]],
            axis=-1))
    return rgb_scaled

def update(i, g, state):
    x, s, v = state
    v = (1 - b2) * np.square(g) + b2 * v              # Update 2nd moment.
    vhat = v / (1 - b2**(i + 1))                      # Bias correction.
    g_norm = np.nan_to_num(g / np.sqrt(vhat))         # Normalise gradient.
    g_norm = np.clip(g_norm, -g_bound, g_bound)       # Bound g.
    x *= np.exp(-step_size(i) * g_norm * np.sign(x))  # Multiplicative update.
    x = np.clip(x, -s, s)                             # Bound parameters.
    return x, s, v

def visualize_cmap(value,
                   weight,
                   colormap,
                   lo=None,
                   hi=None,
                   percentile=99.,
                   curve_fn=lambda x: x,
                   modulus=None,
                   matte_background=True):
    """Visualize a 1D image and a 1D weighting according to some colormap.

    Args:
        value: A 1D image.
        weight: A weight map, in [0, 1].
        colormap: A colormap function.
        lo: The lower bound to use when rendering; if None, use a percentile.
        hi: The upper bound to use when rendering; if None, use a percentile.
        percentile: What percentile of the value map to crop to when
            automatically generating `lo` and `hi`. Depends on `weight` as
            well as `value`.
        curve_fn: A curve function that gets applied to `value`, `lo`, and
            `hi` before the rest of visualization. Good choices: x, 1/(x+eps),
            log(x+eps).
        modulus: If not None, mod the normalized value by `modulus`.
            Use (0, 1]. If `modulus` is not None, `lo`, `hi`, and `percentile`
            will have no effect.
        matte_background: If True, matte the image over a checkerboard.

    Returns:
        A colormap rendering.
    """
    # Identify the values that bound the middle of `value` according to
    # `weight`.
    lo_auto, hi_auto = math.weighted_percentile(
        value, weight, [50 - percentile / 2, 50 + percentile / 2])

    # If `lo` or `hi` are None, use the automatically-computed bounds above.
    eps = jnp.finfo(jnp.float32).eps
    lo = lo or (lo_auto - eps)
    hi = hi or (hi_auto + eps)

    # Curve all values.
    value, lo, hi = [curve_fn(x) for x in [value, lo, hi]]

    # Wrap the values around if requested.
    if modulus:
        value = jnp.mod(value, modulus) / modulus
    else:
        # Otherwise, just scale to [0, 1].
        value = jnp.nan_to_num(
            jnp.clip((value - jnp.minimum(lo, hi)) / jnp.abs(hi - lo), 0, 1))

    if colormap:
        colorized = colormap(value)[:, :, :3]
    else:
        assert len(value.shape) == 3 and value.shape[-1] == 3
        colorized = value

    return matte(colorized, weight) if matte_background else colorized

def piecewise_constant_pdf(key, bins, weights, num_samples, randomized):
    """Piecewise-Constant PDF sampling.

    Args:
        key: jnp.ndarray(float32), [2,], random number generator.
        bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
        weights: jnp.ndarray(float32), [batch_size, num_bins].
        num_samples: int, the number of samples.
        randomized: bool, use randomized samples.

    Returns:
        z_samples: jnp.ndarray(float32), [batch_size, num_samples].
    """
    # Pad each weight vector (only if necessary) to bring its sum to `eps`.
    # This avoids NaNs when the input is zeros or small, but has no effect
    # otherwise.
    eps = 1e-5
    weight_sum = jnp.sum(weights, axis=-1, keepdims=True)
    padding = jnp.maximum(0, eps - weight_sum)
    weights += padding / weights.shape[-1]
    weight_sum += padding

    # Compute the PDF and CDF for each weight vector.
    pdf = weights / weight_sum
    cdf = jnp.cumsum(pdf, axis=-1)
    cdf = jnp.concatenate([jnp.zeros(list(cdf.shape[:-1]) + [1]), cdf],
                          axis=-1)

    # Take uniform samples.
    if randomized:
        u = random.uniform(key, list(cdf.shape[:-1]) + [num_samples])
    else:
        u = jnp.linspace(0., 1., num_samples)
        u = jnp.broadcast_to(u, list(cdf.shape[:-1]) + [num_samples])

    # Invert CDF. This takes advantage of the fact that `bins` is sorted.
    mask = (u[Ellipsis, None, :] >= cdf[Ellipsis, :, None])

    def minmax(x):
        x0 = jnp.max(jnp.where(mask, x[Ellipsis, None], x[Ellipsis, :1, None]),
                     -2)
        x1 = jnp.min(
            jnp.where(~mask, x[Ellipsis, None], x[Ellipsis, -1:, None]), -2)
        x0 = jnp.minimum(x0, x[Ellipsis, -2:-1])
        x1 = jnp.maximum(x1, x[Ellipsis, 1:2])
        return x0, x1

    bins_g0, bins_g1 = minmax(bins)
    cdf_g0, cdf_g1 = minmax(cdf)

    t = jnp.clip(jnp.nan_to_num((u - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1)
    samples = bins_g0 + t * (bins_g1 - bins_g0)

    # Prevent gradient from backprop-ing through `samples`.
    return lax.stop_gradient(samples)

def excute(fn, grad_in=None):
    if fn is not None:
        if isinstance(fn, AccumulateGrad):
            if fn.variable.requires_grad and grad_in is not None:
                if fn.variable.grad is None:
                    fn.variable.grad = jnp.zeros(fn.variable.data.shape)
                grad_in = jnp.where(grad_in == jnp.inf, 0, grad_in)
                grad_in = jnp.nan_to_num(grad_in, copy=False)
                if len(grad_in.shape) != 4:
                    if len(fn.variable.grad.shape) == 4:
                        gamma = gammapops(grad_in, fn.variable.data.shape[2],
                                          fn.variable.data.shape[3])
                        grad_in = jnp.matmul(gamma, grad_in.T)
                        grad_in = grad_in.reshape(fn.variable.grad.shape)
                    else:
                        if grad_in.shape != fn.variable.grad.shape:
                            gamma = linearpops(fn.variable.grad.shape[1])
                            grad_in = jnp.transpose(
                                jnp.matmul(gamma.T, grad_in))
                grad_in = jnp.where(grad_in == jnp.inf, 0, grad_in)
                grad_in = jnp.nan_to_num(grad_in, copy=False)
                fn.variable.grad = index_add(fn.variable.grad, index[:],
                                             grad_in)
            return

        grad_outs, gamma = fn.apply(grad_in)
        if gamma is not None:
            gamma_stack.append(gamma)
        if type(grad_outs) is not tuple:
            grad_outs = (grad_outs, )
        for i, next_func in enumerate(fn.next_functions):
            excute(next_func, grad_outs[i])

def segment_mean(data, segment_ids, num_segments):
    """Returns the mean for each segment.

    Args:
        data: the values which are averaged segment-wise.
        segment_ids: indices for the segments.
        num_segments: total number of segments.
    """
    nominator = jax.ops.segment_sum(data, segment_ids, num_segments)
    denominator = jax.ops.segment_sum(jnp.ones_like(data), segment_ids,
                                      num_segments)
    return jnp.nan_to_num(nominator / denominator)

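# A small usage sketch with hypothetical data: segments with no members would
# produce 0/0, which the nan_to_num above sanitizes to 0.
import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0, 10.0])
segment_ids = jnp.array([0, 0, 1, 1])
print(segment_mean(data, segment_ids, num_segments=3))
# -> [1.5, 6.5, 0.0]; segment 2 is empty, so its mean is reported as 0.
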