Code example #1
def sample_ellipsoid(key, center, radii, rotation, unit_cube_constraint=False):
    """
    Sample uniformly inside an ellipsoid.
    When unit_cube_constraint=True then during the sampling when a random radius is chosen, the radius is constrained.

    u(t) = R @ (t * n) + c
    u(t) == 1
    1-c = t * R@n
    t = (1 - c)/R@n take minimum t satisfying this
    likewise for zero intersection
    Args:
        key:
        center: [D]
        radii: [D]
        rotation: [D,D]

    Returns: [D]

    """
    direction_key, radii_key = random.split(key, 2)
    direction = random.normal(direction_key, shape=radii.shape)
    if unit_cube_constraint:
        direction = direction / jnp.linalg.norm(direction)
        R = rotation * radii
        D = R @ direction
        t0 = -center / D  # ray parameter where each coordinate hits the u=0 face
        t1 = jnp.reciprocal(D) + t0  # ray parameter for the u=1 face
        t0 = jnp.where(t0 < 0., jnp.inf, t0)  # keep forward intersections only
        t1 = jnp.where(t1 < 0., jnp.inf, t1)
        t = jnp.minimum(jnp.min(t0), jnp.min(t1))
        t = jnp.minimum(t, 1.)
        return jnp.exp(
            jnp.log(random.uniform(radii_key, minval=0., maxval=t)) /
            radii.size) * D + center
    log_norm = jnp.log(jnp.linalg.norm(direction))
    log_radius = jnp.log(random.uniform(radii_key)) / radii.size
    # x = direction * (radius/norm)
    x = direction * jnp.exp(log_radius - log_norm)
    return circle_to_ellipsoid(x, center, radii, rotation)
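A minimal usage sketch for the constrained branch (the unconstrained path additionally needs circle_to_ellipsoid from the same module, which is not shown):

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
center = jnp.full((3,), 0.5)   # ellipsoid centered inside the unit cube
radii = jnp.array([0.4, 0.3, 0.2])
rotation = jnp.eye(3)          # axis-aligned
point = sample_ellipsoid(key, center, radii, rotation,
                         unit_cube_constraint=True)
assert jnp.all((point >= 0.) & (point <= 1.))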
Code example #2
File: densityratio.py Project: cdsnlab/AIoTService
    def __init__(self, x, y, alpha=0., sigma=None, lamb=None, kernel_num=100):
        """[summary]

        Args:
            x (array-like of float): 
                Numerator samples array. x is generated from p(x).
            y (array-like of float): 
                Denumerator samples array. y is generated from q(x).
            alpha (float or array-like, optional): 
                The alpha is a parameter that can adjust the mixing ratio r(x) = p(x)/(alpha*p(x)+(1-alpha)q(x))
                , and is set in the range of 0-1. 
                Defaults to 0.
            sigma (float or array-like, optional): 
                Bandwidth of kernel. If a value is set for sigma, that value is used for kernel bandwidth
                , and if a numerical array is set for sigma, Densratio selects the optimum value by using CV.
                Defaults to array of 10e-4 to 10e+9 divided into 14 on the log scale.
            lamb (float or array-like, optional): 
                Regularization parameter. If a value is set for lamb, that value is used for hyperparameter
                , and if a numerical array is set for lamb, Densratio selects the optimum value by using CV.
                Defaults to array of 10e-4 to 10e+9 divided into 14 on the log scale.
            kernel_num (int, optional): The number of kernels in the linear model. Defaults to 100.

        Raises:
            ValueError: [description]
        """        

        self.__x = transform_data(x)
        self.__y = transform_data(y)

        if self.__x.shape[1] != self.__y.shape[1]:
            raise ValueError("x and y must have the same dimensions.")

        if sigma is None:
            sigma = np.logspace(-3,1,9)

        if lamb is None:
            lamb = np.logspace(-3,1,9)

        self.__x_num_row = self.__x.shape[0]
        self.__y_num_row = self.__y.shape[0]
        self.__kernel_num = np.min(np.array([kernel_num, self.__x_num_row])).item()  # the number of kernels is capped by the number of rows of x
        self.__centers = np.array(rand.sample(list(self.__x), k=self.__kernel_num))  # randomly choose RBF kernel centers from x
        self.__n_minimum = min(self.__x_num_row, self.__y_num_row)
        # self.__kernel  = jit(partial(gauss_kernel,centers=self.__centers))

        self._RuLSIF(x = self.__x,
                     y = self.__y,
                     alpha = alpha,
                     s_sigma = np.atleast_1d(sigma),
                     s_lambda = np.atleast_1d(lamb),
                    )
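A hypothetical usage sketch (the enclosing class is not shown in this snippet; the name Densratio is assumed from the docstring):

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=(500, 1))   # samples from p(x)
y = rng.normal(0.5, 1.2, size=(500, 1))   # samples from q(x)
# alpha=0.1 targets the relative ratio p(x) / (0.1*p(x) + 0.9*q(x));
# sigma and lamb default to np.logspace(-3, 1, 9) grids searched by CV
dr = Densratio(x, y, alpha=0.1, kernel_num=100)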
Code example #3
File: partition_test.py Project: rochusschmid/jax-md
    def test_neighbor_list_build_time_dependent(self, dtype, dim):
        key = random.PRNGKey(1)

        if dim == 2:
            box_fn = lambda t: np.array([[9.0, t], [0.0, 3.75]], f32)
        elif dim == 3:
            box_fn = lambda t: np.array([[9.0, 0.0, t], [0.0, 4.0, 0.0],
                                         [0.0, 0.0, 7.25]], f32)
        min_length = np.min(np.diag(box_fn(0.)))
        cutoff = f32(1.23)
        # TODO(schsam): Get cell-list working with anisotropic cell sizes.
        cell_size = cutoff / min_length

        displacement, _ = space.periodic_general(box_fn)
        metric = space.metric(displacement)

        R = random.uniform(key, (PARTICLE_COUNT, dim), dtype=dtype)
        N = R.shape[0]
        neighbor_list_fn = partition.neighbor_list(metric,
                                                   1.,
                                                   cutoff,
                                                   0.0,
                                                   1.1,
                                                   cell_size=cell_size,
                                                   t=np.array(0.))

        idx = neighbor_list_fn(R, t=np.array(0.25)).idx
        R_neigh = R[idx]
        mask = idx < N

        metric = partial(metric, t=f32(0.25))
        d = vmap(vmap(metric, (None, 0)))
        dR = d(R, R_neigh)

        d_exact = space.map_product(metric)
        dR_exact = d_exact(R, R)

        dR = np.where(dR < cutoff, dR, 0) * mask
        dR_exact = np.where(dR_exact < cutoff, dR_exact, 0)

        dR = np.sort(dR, axis=1)
        dR_exact = np.sort(dR_exact, axis=1)

        for i in range(dR.shape[0]):
            dR_row = dR[i]
            dR_row = dR_row[dR_row > 0.]

            dR_exact_row = dR_exact[i]
            dR_exact_row = dR_exact_row[dR_exact_row > 0.]

            self.assertAllClose(dR_row, dR_exact_row)
Code example #4
        def loss_func(params, target_params, state, target_state, rng,
                      transition_batch):
            rngs = hk.PRNGSequence(rng)
            S = self.q.observation_preprocessor(next(rngs), transition_batch.S)
            A = self.q.action_preprocessor(next(rngs), transition_batch.A)
            W = jnp.clip(transition_batch.W, 0.1,
                         10.)  # clip importance weights to reduce variance

            # regularization term
            if self.policy_regularizer is None:
                regularizer = 0.
            else:
                # flip sign (typical example: regularizer = -beta * entropy)
                regularizer = -self.policy_regularizer.batch_eval(
                    target_params['reg'], target_params['reg_hparams'],
                    target_state['reg'], next(rngs), transition_batch)

            Q, state_new = self.q.function_type1(params, state, next(rngs), S,
                                                 A, True)
            G = self.target_func(target_params, target_state, next(rngs),
                                 transition_batch)
            G += regularizer
            loss = self.loss_function(G, Q, W)

            dLoss_dQ = jax.grad(self.loss_function, argnums=1)
            td_error = -Q.shape[0] * dLoss_dQ(
                G, Q)  # e.g. (G - Q) if loss function is MSE

            # target-network estimate (is this worth computing?)
            Q_targ_list = []
            qs = list(
                zip(self.q_targ_list, target_params['q_targ'],
                    target_state['q_targ']))
            for q, pm, st in qs:
                Q_targ, _ = q.function_type1(pm, st, next(rngs), S, A, False)
                assert Q_targ.ndim == 1, f"bad shape: {Q_targ.shape}"
                Q_targ_list.append(Q_targ)
            Q_targ_list = jnp.stack(Q_targ_list, axis=-1)
            assert Q_targ_list.ndim == 2, f"bad shape: {Q_targ_list.shape}"
            Q_targ = jnp.min(Q_targ_list, axis=-1)

            chex.assert_equal_shape([td_error, W, Q_targ])
            metrics = {
                f'{self.__class__.__name__}/loss': loss,
                f'{self.__class__.__name__}/td_error': jnp.mean(W * td_error),
                f'{self.__class__.__name__}/td_error_targ':
                    jnp.mean(-dLoss_dQ(Q, Q_targ, W)),
            }
            return loss, (td_error, state_new, metrics)
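A quick standalone check of the td_error identity under the MSE assumption mentioned in the comment (the actual self.loss_function may also take weights):

import jax
import jax.numpy as jnp

mse = lambda G, Q: jnp.mean((G - Q) ** 2)
G = jnp.array([1.0, 2.0])
Q = jnp.array([0.5, 2.5])
dLoss_dQ = jax.grad(mse, argnums=1)
td_error = -Q.shape[0] * dLoss_dQ(G, Q)   # 2 * (G - Q) = [1., -1.]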
Code example #5
def svd_truncated(A,
                  chi_max=None,
                  cutoff=DEFAULT_CUTOFF,
                  epsilon=DEFAULT_EPS,
                  return_norm_change=False):
    """
    Like `svd`, but keeps at most `chi_max` many singular values and ignores singular values below `cutoff`
    """
    if return_norm_change:
        U, S, Vh, norm_change = svd_reduced(A,
                                            tolerance=cutoff,
                                            epsilon=epsilon,
                                            return_norm_change=True)

        if chi_max is not None:
            k = np.min([chi_max, len(S)])
            old_S_norm = np.linalg.norm(S)
            U = dynamic_slice_in_dim(U, 0, k, 1)
            S = dynamic_slice_in_dim(S, 0, k, 0)
            Vh = dynamic_slice_in_dim(Vh, 0, k, 0)
            norm_change *= np.linalg.norm(
                S) / old_S_norm  # FIXME is this correct?

        return U, S, Vh, norm_change

    else:
        U, S, Vh = svd_reduced(A,
                               tolerance=cutoff,
                               epsilon=epsilon,
                               return_norm_change=False)

        if chi_max is not None:
            k = np.min([chi_max, len(S)])
            U = dynamic_slice_in_dim(U, 0, k, 1)
            S = dynamic_slice_in_dim(S, 0, k, 0)
            Vh = dynamic_slice_in_dim(Vh, 0, k, 0)

        return U, S, Vh
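The truncation step restated as a self-contained sketch using jnp.linalg.svd (svd_reduced and dynamic_slice_in_dim come from the surrounding module and are not shown; cutoff handling is omitted):

import jax.numpy as jnp

def svd_truncated_sketch(A, chi_max):
    U, S, Vh = jnp.linalg.svd(A, full_matrices=False)
    k = min(chi_max, S.shape[0])
    # keep the k largest singular values (jnp.linalg.svd sorts them descending)
    return U[:, :k], S[:k], Vh[:k, :]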
Code example #6
 def actor_loss(policy_params: networks_lib.Params,
                critic_params: networks_lib.Params,
                sac_alpha: jnp.ndarray, transitions: types.Transition,
                key: jnp.ndarray) -> jnp.ndarray:
     """Computes the loss for the policy."""
     dist_params = networks.policy_network.apply(
         policy_params, transitions.observation)
     action = networks.sample(dist_params, key)
     log_prob = networks.log_prob(dist_params, action)
     q_action = networks.critic_network.apply(critic_params,
                                              transitions.observation,
                                              action)
     min_q = jnp.min(q_action, axis=-1)
     return jnp.mean(sac_alpha * log_prob - min_q)
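The jnp.min over the last axis is the usual pessimistic minimum over critic heads (clipped double-Q); a tiny standalone illustration:

import jax.numpy as jnp

q_action = jnp.array([[1.2, 0.9],    # two critic heads, first transition
                      [0.4, 0.7]])   # second transition
min_q = jnp.min(q_action, axis=-1)   # -> [0.9, 0.4]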
Code example #7
File: banana_model.py Project: oraisa/masters-thesis
 def plot_posterior(self, X, ax, T):
     mu1, mu2, _, sigma1_p, sigma2_p, __ = self.compute_posterior_params(
         X, T)
     samples = self.generate_posterior_samples(40, X, T)
     spread_samples = samples.mean(
         axis=0) + (samples - samples.mean(axis=0)) * 2.5
     mi = np.min(spread_samples, axis=0)
     ma = np.max(spread_samples, axis=0)
     xs = np.linspace(mi[0], ma[0], 1000)
     ys = np.linspace(mi[1], ma[1], 1000)
     X, Y = np.meshgrid(xs, ys)
     Z = self.banana_density(X, Y, mu1, mu2, sigma1_p, sigma2_p, self.a,
                             self.b, self.m)
     ax.contour(X, Y, Z)
Code example #8
 def single_interp(xstar):
     # dx: [N] angular separation from each training point to xstar
     dx = great_circle_sep(ra1=x[:, 0] * jnp.pi / 180.,
                           dec1=x[:, 1] * jnp.pi / 180.,
                           ra2=xstar[0] * jnp.pi / 180.,
                           dec2=xstar[1] * jnp.pi / 180.)
     # dx = jnp.linalg.norm(xstar - x, axis=-1)
     nn_dist = jnp.min(jnp.where(dx == 0., jnp.inf, dx))
     dx = dx / nn_dist
     weight = jnp.exp(-0.5 * dx**2)
     if outliers is not None:
         weight = jnp.where(outliers, 0., weight)
     weight /= jnp.sum(weight)
     return jnp.sum(y * weight)
Code example #9
    def improvement(self, y_batch):
        """Return how much a batch of y values can improve over the incumbent.

        Args:
          y_batch: (q, t) shaped array of t y values corresponding to a batch
            of size q of x-locations. Each column of this array corresponds
            to one realization of y values at q x-locations
            evaluated/predicted t times.

        Returns:
          (t,) shaped array of non-negative float values indicating the
          improvement that the best of the q y values within each of the t
          realizations achieves over the incumbent.
        """
        difference = self.incumbent - jnp.min(y_batch, axis=0)
        return jnp.maximum(0.0, difference)
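A worked example of the shapes (with an incumbent of 0.7):

import jax.numpy as jnp

incumbent = 0.7
y_batch = jnp.array([[1.0, 0.2, 0.9],    # q = 2 locations,
                     [0.5, 0.8, 1.1]])   # t = 3 realizations
difference = incumbent - jnp.min(y_batch, axis=0)   # [0.2, 0.5, -0.2]
improvement = jnp.maximum(0.0, difference)          # [0.2, 0.5, 0.0]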
Code example #10
 def actor_loss(policy_params: networks_lib.Params,
                q_params: networks_lib.Params, alpha: jnp.ndarray,
                transitions: types.Transition,
                key: networks_lib.PRNGKey) -> jnp.ndarray:
     dist_params = networks.policy_network.apply(
         policy_params, transitions.observation)
     action = networks.sample(dist_params, key)
     log_prob = networks.log_prob(dist_params, action)
     q_action = networks.q_network.apply(q_params,
                                         transitions.observation,
                                         action)
     min_q = jnp.min(q_action, axis=-1)
     actor_loss = alpha * log_prob - min_q
     return jnp.mean(actor_loss)
Code example #11
 def testShapesAndValues(self):
     agent = self._create_test_agent()
     self.assertEqual(agent._support.shape[0], self._num_atoms)
     self.assertEqual(jnp.min(agent._support), -self._vmax)
     self.assertEqual(jnp.max(agent._support), self._vmax)
     state = onp.ones((1, 28224))
     net_output = agent.online_network(state)
     self.assertEqual(net_output.logits.shape,
                      (1, self._num_actions, self._num_atoms))
     self.assertEqual(net_output.probabilities.shape,
                      net_output.logits.shape)
     self.assertEqual(net_output.logits.shape[1], self._num_actions)
     self.assertEqual(net_output.logits.shape[2], self._num_atoms)
     self.assertEqual(net_output.q_values.shape, (1, self._num_actions))
Code example #12
File: train.py Project: kokizzu/google-research
def update_preconditioner(config, optimizer, p_update_grad_vars, rng, state,
                          train_iter):
    """Computes preconditioner state using samples from dataloader."""
    # TODO(basv): support multiple hosts.
    values = jax.tree_map(jnp.zeros_like, optimizer.target)

    eps = config.precon_est_eps
    n_batches = config.precon_est_batches
    for _ in range(n_batches):
        rng, est_key = jax.random.split(rng)
        batch = next(train_iter)
        batch = input_pipeline.load_and_shard_tf_batch(config, batch)
        if not config.debug_run:
            # Shard the step PRNG key
            sharded_keys = common_utils.shard_prng_key(est_key)
        else:
            sharded_keys = est_key
        values = p_update_grad_vars(optimizer, state, batch, sharded_keys,
                                    values)
    stds = jax.tree_map(
        lambda v: jnp.sqrt(eps + (1 / n_batches) * jnp.mean(v)), values)
    std_min = jnp.min(jnp.asarray(jax.tree_leaves(stds)))
    # TODO(basv): verify preconditioner estimate.
    new_precon = jax.tree_map(lambda s, x: jnp.ones_like(x) * (s / std_min),
                              stds, optimizer.target)

    def convert_momentum(new_precon, state):
        """Converts momenta to new preconditioner."""
        if config.weight_norm == 'learned':
            state = state.direction_state
        old_precon = state.preconditioner
        momentum = state.momentum

        m_c = jnp.power(old_precon, -.5) * momentum
        m = jnp.power(new_precon, .5) * m_c
        return m

    # TODO(basv): verify momentum convert.
    new_momentum = jax.tree_map(convert_momentum, new_precon,
                                optimizer.state.param_states)
    # TODO(basv): verify this is replaced correctly, check replicated.
    optimizer = replace_param_state(config,
                                    optimizer,
                                    preconditioner=new_precon,
                                    momentum=new_momentum)
    return optimizer, rng
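What convert_momentum does per leaf, as a scalar sketch: the momentum is mapped into the preconditioned frame with the old preconditioner and back out with the new one:

import jax.numpy as jnp

old_precon, new_precon = jnp.array(4.0), jnp.array(1.0)
momentum = jnp.array(2.0)
m_c = jnp.power(old_precon, -0.5) * momentum   # -> 1.0
m = jnp.power(new_precon, 0.5) * m_c           # -> 1.0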
Code example #13
def get_peaks_single(history,tvec,int=0,Tint=0):

  """
  calculates the peak prevalence for a single run, with or without an intervention
  history: 2D array of values for each variable at each timepoint
  tvec: 1D vector of timepoints
  int: Optional, 1 or 0 for whether or not there was an intervention. Defaults to 0
  Tint: Optional, timepoint (days) at which intervention was started
  """

  delta_t=tvec[1]-tvec[0]

  if int==0:
    time_int=0
  else:
    time_int=Tint
      
  # Final values
  print('Final recovered: {:3.1f}%'.format(100 * history[-1][6]))
  print('Final deaths: {:3.1f}%'.format(100 * history[-1][5]))
  print('Remaining infections: {:3.1f}%'.format(
      100 * np.sum(history[-1][1:5], axis=-1)))

  # Peak prevalence
  print('Peak I1: {:3.1f}%'.format(
      100 * np.max(history[:, 2])))
  print('Peak I2: {:3.1f}%'.format(
      100 * np.max(history[:, 3])))
  print('Peak I3: {:3.1f}%'.format(
      100 * np.max(history[:, 4])))

  # Time of peaks
  print('Time of peak I1: {:3.1f} days'.format(
      np.argmax(history[:, 2])*delta_t - time_int))
  print('Time of peak I2: {:3.1f} days'.format(
      np.argmax(history[:, 3])*delta_t - time_int))
  print('Time of peak I3: {:3.1f} days'.format(
      np.argmax(history[:, 4])*delta_t - time_int))
  
  # First time when all infections go extinct
  all_cases=history[:, 1]+history[:, 2]+history[:, 3]+history[:, 4]
  extinct=np.where(all_cases == 0)[0]
  if len(extinct) != 0:
    extinction_time=np.min(extinct)*delta_t - time_int
    print('Time of extinction of all infections: {:3.1f} days'.format(extinction_time))
  else:
     print('Infections did not go extinct by end of simulation')
 
  return
Code example #14
    def _test_online_smoothing_pf_full(self):
        if not hasattr(self, 'sim_samps'):
            self.sim_samps = self.ssm_scenario.simulate(self.t, random.PRNGKey(0))
        pf = BootstrapFilter()
        len_t = len(self.t)
        rkeys = random.split(random.PRNGKey(0), len_t)

        particles = initiate_particles(self.ssm_scenario, pf, self.n,
                                       rkeys[0], self.sim_samps.y[0], self.t[0])
        for i in range(1, len_t):
            particles = propagate_particle_smoother(self.ssm_scenario, pf, particles,
                                                    self.sim_samps.y[i], self.t[i], rkeys[i], 3,
                                                    False)
        npt.assert_array_less(((self.sim_samps.x[:, 0] - jnp.max(particles.value, axis=1)[:, 0]) > 0).mean(), 0.1)
        npt.assert_array_less(((self.sim_samps.x[:, 0] - jnp.min(particles.value, axis=1)[:, 0]) < 0).mean(), 0.1)
Code example #15
def get_domain_extension(
    data: np.ndarray, extension: Union[float, int],
) -> Tuple[float, float]:
    """Gets the extension for the support
    
    Parameters
    ----------
    data : np.ndarray
        the input data to get max and minimum

    extension : Union[float, int]
        the extension
    
    Returns
    -------
    lb : float
        the new extended lower bound for the data
    ub : float
        the new extended upper bound for the data
    """

    # case of int, convert to float
    if isinstance(extension, int):
        extension = float(extension / 100)

    # get the domain
    domain = np.abs(np.max(data) - np.min(data))

    # extend the domain
    domain_ext = extension * domain

    # get the extended domain
    lb = np.min(data) - domain_ext
    ub = np.max(data) + domain_ext

    return lb, ub
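Usage example:

import numpy as np

data = np.array([0.0, 2.0, 10.0])
lb, ub = get_domain_extension(data, 10)   # int 10 -> 10% extension
# domain = 10.0, so lb = -1.0 and ub = 11.0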
Code example #16
File: sinkhorn_test.py Project: netw0rkf10w/ott
 def test_euclidean_point_cloud_parallel_weights(self, lse_mode):
   """Two point clouds, parallel execution for batched histograms."""
   self.rng, *rngs = jax.random.split(self.rng, 2)
   batch = 4
   a = jax.random.uniform(rngs[0], (batch, self.n))
   b = jax.random.uniform(rngs[0], (batch, self.m))
   a = a / jnp.sum(a, axis=1)[:, jnp.newaxis]
   b = b / jnp.sum(b, axis=1)[:, jnp.newaxis]
   threshold = 1e-3
   geom = pointcloud.PointCloud(
       self.x, self.y, epsilon=0.1, online=True)
   errors = sinkhorn.sinkhorn(
       geom, a=self.a, b=self.b, threshold=threshold, lse_mode=lse_mode).errors
   err = errors[errors > -1][-1]
   self.assertGreater(jnp.min(threshold - err), 0)
Code example #17
    def minmax(self, data, max_val=None, min_val=None):

        if min_val is None:
            self.min = np.min(data, axis=0)
        else:
            self.min = min_val

        if max_val is None:
            self.max = np.max(data, axis=0)
        else:
            self.max = max_val

        minmax_data = (data - self.min) / (self.max - self.min)

        return minmax_data
Code example #18
    def _assert_is_diagonal(self, j, axis1, axis2, constant_diagonal: bool):
        c = j.shape[axis1]
        self.assertEqual(c, j.shape[axis2])
        mask_shape = [c if i in (axis1, axis2) else 1 for i in range(j.ndim)]
        mask = np.eye(c, dtype=np.bool_).reshape(mask_shape)

        # Check that removing the diagonal makes the array all 0.
        j_masked = np.where(mask, np.zeros((), j.dtype), j)
        self.assertAllClose(np.zeros_like(j, j.dtype), j_masked)

        if constant_diagonal:
            # Check that diagonal is constant.
            if j.size != 0:
                j_diagonals = np.diagonal(j, axis1=axis1, axis2=axis2)
                self.assertAllClose(np.min(j_diagonals, -1),
                                    np.max(j_diagonals, -1))
Code example #19
def plot_mean_response(mus: np.ndarray,
                       x: np.ndarray,
                       y: np.ndarray,
                       response) -> plt.Figure:
  fig = plt.figure()
  plt.plot(x, y, '.', alpha=0.3, label='data')
  xx = np.linspace(np.min(x), np.max(x), 100)
  yy = []
  for xi in xx:
    yy.append(response(mus[-1, :], xi))
  yy = np.array(yy).squeeze()
  plt.plot(xx, yy, label='mean response')
  plt.xlabel("x")
  plt.ylabel("y")
  plt.title("Mean response")
  return fig
Code example #20
    def compute_periodic_snr_debug_info(state):
      debug_info = OrderedDict()
      snr_state = state.snr_state

      if snr_state.snr_matrix is not None:
        w, v = jax_eig(snr_state.snr_matrix)
        D = jnp.linalg.inv(v) @ snr_state.snr_matrix @ v
        D = jnp.diag(D).real
        debug_info['debug_info/max_real_eig'] = jnp.max(D)
        debug_info['debug_info/min_real_eig'] = jnp.min(D)
        debug_info['debug_info/avg_real_eig'] = jnp.mean(D)
        debug_info['debug_info/std_real_eig'] = jnp.std(D)
        debug_info['debug_info/num_real_eig_gt_zero'] = jnp.sum(D > 0.)
        debug_info['debug_info/ratio_real_eig_gt_zero'] = jnp.mean(D > 0.)

      return debug_info
Code example #21
 def get_bmu_distance_squares(self, bmu_loc):
     if self.periodic:
         offsets = jnp.array([[0, 0], [-self.m, -self.n], [self.m, self.n],
                              [-self.m, 0], [self.m, 0], [0, -self.n],
                              [0, self.n], [-self.m, self.n],
                              [self.m, -self.n]])
         # offsets = jax.device_put(offsets)
         offset_locations = jnp.array(
             [self.locations + offset for offset in offsets])
         batch_cdist = jax.vmap(jax_cdist_0, (None, 0))
         distances = batch_cdist(bmu_loc, offset_locations)
         min_dists = jnp.min(distances, axis=0)
         return min_dists
     else:
         distances = jax_cdist_0(bmu_loc, self.locations)
         return distances
Code example #22
File: vis.py Project: wx-b/mipnerf
def visualize_depth(depth,
                    acc=None,
                    near=None,
                    far=None,
                    curve_fn=lambda x: jnp.log(x + jnp.finfo(jnp.float32).eps),
                    modulus=0,
                    colormap=None):
    """Visualize a depth map.

  Args:
    depth: A depth map.
    acc: An accumulation map, in [0, 1].
    near: The depth of the near plane, if None then just use the min().
    far: The depth of the far plane, if None then just use the max().
    curve_fn: A curve function that gets applied to `depth`, `near`, and `far`
      before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps).
    modulus: If > 0, mod the normalized depth by `modulus`. Use (0, 1].
    colormap: A colormap function. If None (default), will be set to
      matplotlib's viridis if modulus==0, sinebow otherwise.

  Returns:
    An RGB visualization of `depth`.
  """
    # If `near` or `far` are None, identify the min/max non-NaN values.
    eps = jnp.finfo(jnp.float32).eps
    near = near or jnp.min(jnp.nan_to_num(depth, jnp.inf)) - eps
    far = far or jnp.max(jnp.nan_to_num(depth, -jnp.inf)) + eps

    # Curve all values.
    depth, near, far = [curve_fn(x) for x in [depth, near, far]]

    # Wrap the values around if requested.
    if modulus > 0:
        value = jnp.mod(depth, modulus) / modulus
        colormap = colormap or sinebow
    else:
        # Scale to [0, 1].
        value = jnp.nan_to_num(jnp.clip((depth - near) / (far - near), 0, 1))
        colormap = colormap or cm.get_cmap('viridis')

    vis = colormap(value)[:, :, :3]

    # Set non-accumulated pixels to white.
    if acc is not None:
        vis = vis * acc[:, :, None] + (1 - acc)[:, :, None]

    return vis
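A hypothetical usage sketch on a synthetic depth map (assumes the module's own imports, including matplotlib's cm for the default colormap):

import jax.numpy as jnp

depth = jnp.linspace(1.0, 10.0, 64 * 64).reshape(64, 64)
rgb = visualize_depth(depth)   # near/far default to the depth's min/max
assert rgb.shape == (64, 64, 3)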
Code example #23
def resolve_clashes(x0, box0, min_dist=0.1):
    def urt(x, box):
        distance_matrix = distance(x, box)
        i, j = np.triu_indices(len(distance_matrix), k=1)
        return distance_matrix[i, j]

    dij = urt(x0, box0)
    x_shape = x0.shape
    box_shape = box0.shape

    if np.min(dij) < min_dist:
        # print('some distances too small')
        print(
            f"before optimization: min(dij) = {np.min(dij)} < min_dist threshold ({min_dist})"
        )

        # print('smallest few distances', sorted(dij)[:10])

        def unflatten(xbox):
            n = x_shape[0] * x_shape[1]
            x = xbox[:n].reshape(x_shape)
            box = xbox[n:].reshape(box_shape)
            return x, box

        def U_repulse(xbox):
            x, box = unflatten(xbox)
            dij = urt(x, box)
            return np.sum(np.where(dij < min_dist, (dij - min_dist)**2, 0))

        def fun(xbox):
            v, g = value_and_grad(U_repulse)(xbox)
            return float(v), onp.array(g, onp.float64)

        initial_state = np.hstack([x0.flatten(), box0.flatten()])
        # print(f'penalty before: {U_repulse(initial_state)}')
        result = minimize(fun, initial_state, jac=True, method="L-BFGS-B")
        # print(f'penalty after minimization: {U_repulse(result.x)}')

        x, box = unflatten(result.x)
        dij = urt(x, box)

        print(f"after optimization: min(dij) = {np.min(dij)}")

        return x, box

    else:
        return x0, box0
Code example #24
File: physionet_data.py Project: jirufengyu/ode
def get_data_min_max(records):
    """
    Get min and max for each feature across the dataset.
    """

    cache_path = os.path.join(records.processed_folder,
                              "minmax_" + str(records.quantization) + '.pt')

    if os.path.exists(cache_path):
        with open(cache_path, "rb") as cache_file:
            data = pickle.load(cache_file)
        data_min, data_max = data
        return data_min, data_max

    data_min, data_max = None, None

    for b, (tt, vals, mask) in enumerate(records):
        if b % 100 == 0:
            print(b, len(records))
        n_features = vals.shape[-1]

        batch_min = []
        batch_max = []
        for i in range(n_features):
            non_missing_vals = vals[:, i][mask[:, i] == 1]
            if len(non_missing_vals) == 0:
                batch_min.append(jnp.inf)
                batch_max.append(-jnp.inf)
            else:
                batch_min.append(jnp.min(non_missing_vals))
                batch_max.append(jnp.max(non_missing_vals))

        batch_min = jnp.stack(batch_min)
        batch_max = jnp.stack(batch_max)

        if (data_min is None) and (data_max is None):
            data_min = batch_min
            data_max = batch_max
        else:
            data_min = jnp.minimum(data_min, batch_min)
            data_max = jnp.maximum(data_max, batch_max)

    with open(cache_path, "wb") as cache_file:
        pickle.dump((data_min, data_max), cache_file)

    return data_min, data_max
Code example #25
File: made.py Project: jxzhangjhu/NuX
  def next_mask(self, prev_sel, size, rng):

    # Choose the degrees of the next layer
    max_connection = self.dim - 1 if not self.triangular_jacobian else self.dim

    if self.method == "random":
      # bound the minimum degree by the previous layer's degrees
      sel = random.randint(rng, shape=(size,),
                           minval=min(jnp.min(prev_sel), max_connection),
                           maxval=self.dim)
    elif "sequential" in self.method:
      sel = jnp.arange(size) % max(1, max_connection) + min(1, max_connection)
      if self.method == "shuffled_sequential":
        sel = random.permutation(rng, sel)
    else:
      assert 0, "Invalid mask method"

    # Create the new mask
    mask = (prev_sel[:,None] <= sel).astype(jnp.int32)
    return mask, sel
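The returned mask keeps connection (i -> j) exactly when the input degree prev_sel[i] is at most the output degree sel[j]; a small standalone check:

import jax.numpy as jnp

prev_sel = jnp.array([1, 2, 3])
sel = jnp.array([2, 1])
mask = (prev_sel[:, None] <= sel).astype(jnp.int32)
# mask == [[1, 1],
#          [1, 0],
#          [0, 0]]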
Code example #26
File: metric_utils.py Project: tensorflow/lingvo
def top_k_accuracy(top_k: int,
                   logits: JTensor,
                   label_ids: Optional[JTensor] = None,
                   label_probs: Optional[JTensor] = None,
                   weights: Optional[JTensor] = None) -> JTensor:
    """Computes the top-k accuracy given the logits and labels.

  Args:
    top_k: An int scalar, specifying the value of top-k.
    logits: A [..., C] float tensor corresponding to the logits.
    label_ids: A [...] int vector corresponding to the class labels. One of
      label_ids and label_probs should be provided.
    label_probs: A [..., C] float vector corresponding to the class
      probabilities. Must be provided if label_ids is None.
    weights: A [...] float vector corresponding to the weight to assign to each
      example.

  Returns:
    The top-k accuracy represented as a `JTensor`.

  Raises:
    ValueError if neither `label_ids` nor `label_probs` are provided.
  """
    if label_ids is None and label_probs is None:
        raise ValueError("One of label_ids and label_probs should be given.")
    if label_ids is None:
        label_ids = jnp.argmax(label_probs, axis=-1)

    values, _ = jax.lax.top_k(logits, k=top_k)
    threshold = jnp.min(values, axis=-1)

    # Reshape logits to [-1, C].
    logits_reshaped = jnp.reshape(logits, [-1, logits.shape[-1]])

    # Reshape label_ids to [-1, 1].
    label_ids_reshaped = jnp.reshape(label_ids, [-1, 1])
    logits_slice = jnp.take_along_axis(logits_reshaped,
                                       label_ids_reshaped,
                                       axis=-1)[..., 0]

    # Reshape logits_slice back to original shape to be compatible with weights.
    logits_slice = jnp.reshape(logits_slice, label_ids.shape)
    correct = jnp.greater_equal(logits_slice, threshold)
    correct_sum = jnp.sum(correct * weights)
    all_sum = jnp.maximum(1.0, jnp.sum(weights))
    return correct_sum / all_sum
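A worked example (both labels fall within the top-2 logits of their rows, so the accuracy is 1.0):

import jax.numpy as jnp

logits = jnp.array([[2.0, 1.0, 0.1],
                    [0.2, 0.3, 2.5]])
label_ids = jnp.array([1, 2])
weights = jnp.ones(2)
acc = top_k_accuracy(2, logits, label_ids=label_ids, weights=weights)  # -> 1.0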
Code example #27
        def beam_search_cond_fn(state):
            """beam search state termination condition fn."""

            # 1. is less than max length?
            not_max_length_yet = state.cur_len < max_length

            # 2. can the new beams still improve?
            best_running_score = state.running_scores[:, -1:] / (max_length ** length_penalty)
            worst_finished_score = jnp.where(
                state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
            )
            improvement_still_possible = jnp.all(worst_finished_score < best_running_score)

            # 3. is there still a beam that has not finished?
            still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)

            return not_max_length_yet & still_open_beam & improvement_still_possible
Code example #28
def spin_wiener_filter(data_q,
                       data_u,
                       ncov_diag_Q,
                       ncov_diag_U,
                       input_ps_map_E,
                       input_ps_map_B,
                       iterations=10):
    """
    Wiener filter Elsner-Wandelt messenger field adapted for spin-2 fields (CMB polarization or galaxy weak lensing(
    Parameters
    ----------
    data_q : Q square image data (e.g. gamma1)
    data_u : U square image data (e.g. gamma2)
    ncov_diag_Q : Q noise variance per pixel (assumed uncorrelated)
    ncov_diag_U : U noise variance per pixel (assumed uncorrelated)
    input_ps_map_E : E-mode signal power spectrum P(k), evaluated on the 2D frequencies (k1, k2) as a square image
    input_ps_map_B : B-mode signal power spectrum P(k), evaluated on the 2D frequencies (k1, k2) as a square image
    iterations : number of iterations

    Returns
    -------
    s_q,s_u : Wiener filtered q and u signals
    """
    tcov_diag = jnp.min(jnp.array([ncov_diag_Q, ncov_diag_U]))
    scov_ft_E = jnp.fft.fftshift(input_ps_map_E)
    scov_ft_B = jnp.fft.fftshift(input_ps_map_B)
    s_q = jnp.zeros(data_q.shape)
    s_u = jnp.zeros(data_q.shape)

    for i in jnp.arange(iterations):
        # in Q, U representation
        t_Q = (tcov_diag / ncov_diag_Q) * data_q + (
            (ncov_diag_Q - tcov_diag) / ncov_diag_Q) * s_q
        t_U = (tcov_diag / ncov_diag_U) * data_u + (
            (ncov_diag_U - tcov_diag) / ncov_diag_U) * s_u
        # in E, B representation
        t_E, t_B = ks93(t_Q, t_U)
        s_E = (scov_ft_E / (scov_ft_E + tcov_diag)) * jnp.fft.fft2(t_E)
        s_B = (scov_ft_B / (scov_ft_B + tcov_diag)) * jnp.fft.fft2(t_B)
        s_E = jnp.fft.ifft2(s_E)
        s_B = jnp.fft.ifft2(s_B)
        # in Q, U representation
        s_q, s_u = ks93inv(s_E, s_B)

    return s_q, s_u
Code example #29
 def _decode(self, feats):
   b, t, d = feats.shape
   mean_val = jnp.mean(feats)
   max_val = jnp.max(feats)
   min_val = jnp.min(feats)
   frames = jnp.reshape(feats, [b * t, d])
   # indices of entries above 0.8; `size=` makes the result statically shaped,
   # with unused slots padded with index 0
   val = jnp.where(frames > 0.8, size=b * t * d)
   hist = jnp.zeros([d])
   hist = hist.at[val[1]].add(1)  # histogram over the feature dimension
   hist = hist.at[0].set(1)  # overwrite bin 0, which also absorbs the padding
   nframes = jnp.array(b * t)
   metrics = {
       'mean': (mean_val, nframes),
       'max': (max_val, nframes),
       'min': (min_val, nframes),
       'hist': (hist, nframes)
   }
   return metrics
Code example #30
def initialize_slices(T, b):
    # `index` below is assumed to be jax.ops.index (legacy JAX slice helper)
    B1s = []
    B2s = []
    B3s = []
    B2B3s = []
    bj0 = 0
    mindim = jnp.min(T.shape)

    for bj in range(0, mindim - b, b):
        bj0 = bj
        B1s.append(index[:bj])
        B2s.append(index[bj:bj + b])
        B3s.append(index[bj + b:])
        B2B3s.append(index[bj:])
    for bj in range(bj0 + b, mindim, b):
        B1s.append(index[:bj])
        B2B3s.append(index[bj:])
    return [B1s, B2s, B3s, B2B3s]