def test_mixturegausscdf_bijector_shape(n_samples, n_features, n_components):
    """Check the output shapes of MixtureGaussianCDF in both directions.

    Args:
        n_samples: number of rows in the random input batch.
        n_features: number of columns in the random input batch.
        n_components: number of mixture components handed to the layer.
    """
    input_shape = (n_samples, n_features)
    x = objax.random.normal(input_shape, generator=generator)

    # build the layer under test
    model = MixtureGaussianCDF(n_features, n_components)

    # forward pass: output must match the input shape, with one
    # log-abs-det entry per sample
    z, log_abs_det = model(x)
    chex.assert_equal_shape([z, x])
    chex.assert_shape(log_abs_det, (n_samples,))

    # inverse pass: must map back to the input shape
    x_approx = model.inverse(z)
    chex.assert_equal_shape([x, x_approx])
예제 #2
0
    def testNonPolynomialFunctionConsistencyWithPathwise(
            self, effective_mean, effective_log_scale, function, coupling):
        """Compare measure-valued and pathwise gradient estimates of `function`.

        Both estimators are run with independent RNG keys on the same
        `[mean, log_scale]` parameters; the per-sample jacobians are
        shape-checked, averaged over the sample axis and then compared
        with loose tolerances.
        """
        num_samples = 10**5
        measure_rng, pathwise_rng = jax.random.split(jax.random.PRNGKey(1))

        mean = jnp.array(effective_mean, dtype=jnp.float32)
        log_scale = jnp.array(effective_log_scale, dtype=jnp.float32)
        expected_shape = (num_samples, len(effective_mean))

        def _checked_sample_mean(jacobian):
            # Validate the per-sample jacobian shape, then average it.
            chex.assert_shape(jacobian, expected_shape)
            return np.mean(jacobian, axis=0)

        # Measure-valued estimator.
        measure_valued_fn = _measure_valued_variant(self.variant)
        measure_valued_jacobians = measure_valued_fn(
            function, [mean, log_scale], utils.multi_normal, measure_rng,
            num_samples, coupling)
        measure_valued_mean_grads = _checked_sample_mean(
            measure_valued_jacobians[0])
        measure_valued_log_scale_grads = _checked_sample_mean(
            measure_valued_jacobians[1])

        # Pathwise estimator.
        pathwise_fn = _estimator_variant(self.variant, sge.pathwise_jacobians)
        pathwise_jacobians = pathwise_fn(
            function, [mean, log_scale], utils.multi_normal, pathwise_rng,
            num_samples)
        pathwise_mean_grads = _checked_sample_mean(pathwise_jacobians[0])
        pathwise_log_scale_grads = _checked_sample_mean(pathwise_jacobians[1])

        # The two estimators should agree up to the tolerances below.
        tolerances = dict(rtol=5e-1, atol=1e-1)
        _assert_equal(pathwise_mean_grads, measure_valued_mean_grads,
                      **tolerances)
        _assert_equal(pathwise_log_scale_grads,
                      measure_valued_log_scale_grads, **tolerances)
예제 #3
0
def test_householder_bijector_shape(n_samples, n_features, n_reflections):
    """Check the output shapes of the HouseHolder forward and inverse passes.

    Args:
        n_samples: number of rows in the random input batch.
        n_features: number of columns in the random input batch.
        n_reflections: number of reflections handed to the layer factory.
    """
    # independent keys for parameter initialisation and for the data
    params_rng, data_rng = jax.random.split(KEY, 2)
    x = jax.random.normal(data_rng, shape=(n_samples, n_features))

    # build the layer: parameters plus forward / inverse callables
    layer_init = HouseHolder(n_reflections=n_reflections)
    params, forward_f, inverse_f = layer_init(
        rng=params_rng, n_features=n_features
    )

    # forward pass
    z, log_abs_det = forward_f(params, x)
    chex.assert_equal_shape([z, x])
    chex.assert_shape(log_abs_det, (n_samples,))

    # inverse pass
    x_approx, log_abs_det = inverse_f(params, z)
    chex.assert_equal_shape([x_approx, x])
예제 #4
0
 def loss_fn(online_params, target_params, transitions, rng_key):
   """Calculates loss given network parameters and transitions.

   Q-values are evaluated with the online parameters on `s_tm1` and `s_t`
   and with the target parameters on `s_t`; the resulting TD errors are
   gradient-clipped, passed through an L2 loss and averaged.
   """
   # Four keys are generated; the first is discarded, the rest are used
   # for the three network applications.
   keys = jax.random.split(rng_key, 4)
   key_online_tm1, key_online_t, key_target_t = keys[1], keys[2], keys[3]

   online_q_tm1 = network.apply(
       online_params, key_online_tm1, transitions.s_tm1).q_values
   online_q_t = network.apply(
       online_params, key_online_t, transitions.s_t).q_values
   target_q_t = network.apply(
       target_params, key_target_t, transitions.s_t).q_values

   td_errors = _batch_double_q_learning(
       online_q_tm1,
       transitions.a_tm1,
       transitions.r_t,
       transitions.discount_t,
       target_q_t,
       online_q_t,
   )
   clipped_td_errors = rlax.clip_gradient(
       td_errors, -grad_error_bound, grad_error_bound)
   per_example_losses = rlax.l2_loss(clipped_td_errors)
   # One loss value per transition in the batch.
   chex.assert_shape(per_example_losses, (self._batch_size,))
   return jnp.mean(per_example_losses)
예제 #5
0
 def loss_fn(online_params, target_params, transitions, weights,
             rng_key):
     """Calculates loss given network parameters and transitions.

     Returns:
       A pair `(loss, losses)`: the mean of the per-example losses scaled
       by `weights`, and the unweighted per-example losses.
     """
     # Four keys are generated; the first is discarded, the rest are used
     # for the three network applications.
     keys = jax.random.split(rng_key, 4)
     key_online_tm1, key_online_t, key_target_t = keys[1], keys[2], keys[3]

     online_logits_tm1 = network.apply(
         online_params, key_online_tm1, transitions.s_tm1).q_logits
     online_q_t = network.apply(
         online_params, key_online_t, transitions.s_t).q_values
     target_logits_t = network.apply(
         target_params, key_target_t, transitions.s_t).q_logits

     # `support` is passed twice: once for the `s_tm1` distribution and
     # once for the target distribution.
     losses = _batch_categorical_double_q_learning(
         support,
         online_logits_tm1,
         transitions.a_tm1,
         transitions.r_t,
         transitions.discount_t,
         support,
         target_logits_t,
         online_q_t,
     )
     weighted_mean_loss = jnp.mean(losses * weights)
     # Both arrays must hold one entry per transition in the batch.
     chex.assert_shape((losses, weights), (self._batch_size, ))
     return weighted_mean_loss, losses
예제 #6
0
    def _entropy_scalar(
            total_count: int, probs: Array,
            log_of_probs: Array) -> Union[jnp.float32, jnp.float64]:
        """Calculates the entropy for a Multinomial with integer `total_count`.

        Args:
          total_count: integer number of trials.
          probs: event probabilities; the shape assertions below require a
            rank-1 array.
          log_of_probs: logarithm of `probs`.

        Returns:
          A scalar entropy value.
        """
        num_counts = total_count + 1

        # Constant factors in the entropy: combinatorial terms over the
        # possible counts 0..total_count.
        counts = jnp.arange(num_counts, dtype=probs.dtype)
        log_count_factorial = lax.lgamma(counts + 1)
        log_complement_factorial = jnp.flip(log_count_factorial, axis=-1)
        log_n_factorial = log_count_factorial[..., -1]
        log_comb = (log_n_factorial[..., None] - log_count_factorial -
                    log_complement_factorial)
        comb = jnp.round(jnp.exp(log_comb))
        chex.assert_shape(comb, (num_counts, ))

        # Per-event likelihood factors, tabulated for every possible count.
        per_event_shape = (probs.shape[-1], num_counts)
        success_term = math.power_no_nan(probs[..., None], counts)
        failure_term = math.power_no_nan(1. - probs[..., None],
                                         total_count - counts)
        chex.assert_shape(success_term, per_event_shape)
        chex.assert_shape(failure_term, per_event_shape)

        # Sum over events, then over counts.
        likelihood = jnp.sum(success_term * failure_term, axis=-2)
        chex.assert_shape(likelihood, (num_counts, ))
        comb_term = jnp.sum(comb * log_count_factorial * likelihood, axis=-1)
        chex.assert_shape(comb_term, ())

        # Probs factors in the entropy.
        n_probs_factor = jnp.sum(
            total_count * math.multiply_no_nan(log_of_probs, probs), axis=-1)

        return -log_n_factorial - n_probs_factor + comb_term
  def testLinearFunction(self, effective_mean, effective_log_scale, estimator):
    """Check gradient estimates of `np.sum` under a multi-normal.

    The averaged jacobian w.r.t. the mean is expected to be all ones and
    the averaged jacobian w.r.t. the log scale all zeros.
    """
    data_dims = 3
    num_samples = _estimator_to_num_samples[estimator]
    rng = jax.random.PRNGKey(1)
    expected_shape = (num_samples, data_dims)

    mean = effective_mean * _ones(data_dims)
    log_scale = effective_log_scale * _ones(data_dims)

    estimator_fn = _estimator_variant(self.variant, estimator)
    jacobians = estimator_fn(
        np.sum, [mean, log_scale], utils.multi_normal, rng, num_samples)

    # Gradient with respect to the mean.
    mean_jacobians = jacobians[0]
    chex.assert_shape(mean_jacobians, expected_shape)
    mean_grads = np.mean(mean_jacobians, axis=0)
    expected_mean_grads = np.ones(data_dims, dtype=np.float32)

    # Gradient with respect to the log scale.
    log_scale_jacobians = jacobians[1]
    chex.assert_shape(log_scale_jacobians, expected_shape)
    log_scale_grads = np.mean(log_scale_jacobians, axis=0)
    expected_log_scale_grads = np.zeros(data_dims, dtype=np.float32)

    _assert_equal(mean_grads, expected_mean_grads)
    _assert_equal(log_scale_grads, expected_log_scale_grads)
예제 #8
0
def test_logit_shape(n_samples, n_features):
    """Check the output shapes of the InverseGaussCDF forward and inverse passes.

    Args:
        n_samples: number of rows in the random input batch.
        n_features: number of columns in the random input batch.
    """
    # independent keys for parameter initialisation and for the data
    params_rng, data_rng = jax.random.split(KEY, 2)

    # uniform samples are used as input
    x = jax.random.uniform(data_rng, shape=(n_samples, n_features))

    # build the layer: parameters plus forward / inverse callables
    layer_init = InverseGaussCDF(eps=1e-5)
    params, forward_f, inverse_f = layer_init(
        rng=params_rng, n_features=n_features
    )

    # forward pass
    z, log_abs_det = forward_f(params, x)
    chex.assert_equal_shape([z, x])
    chex.assert_shape(log_abs_det, (n_samples,))

    # inverse pass
    x_approx, log_abs_det = inverse_f(params, z)
    chex.assert_equal_shape([x_approx, x])
예제 #9
0
 def test_adam(self):
   """Check that fields can be extracted from an Adam optimizer state."""
   adam_config = ConfigDict({
       'optimizer': 'adam',
       'l2_decay_factor': None,
       'batch_size': 50,
       'total_accumulated_batch_size': 100,  # Use gradient accumulation.
       'opt_hparams': {
           'beta1': 0.9,
           'beta2': 0.999,
           'epsilon': 1e-7,
           'weight_decay': 0.0,
       }
   })
   # Only the init function is needed; the update function is discarded.
   init_fn, _ = optimizers.get_optimizer(adam_config)
   optimizer_state = init_fn({'foo': jnp.ones(10)})

   # 'count' should be extractable and have an integer type.
   chex.assert_type(extract_field(optimizer_state, 'count'), int)
   # The moment estimates 'nu' and 'mu' should mirror the parameter shape.
   chex.assert_shape(extract_field(optimizer_state, 'nu')['foo'], (10,))
   chex.assert_shape(extract_field(optimizer_state, 'mu')['foo'], (10,))
   # Attempting to extract a nonexistent field "abc" should return None.
   chex.assert_equal(extract_field(optimizer_state, 'abc'), None)
예제 #10
0
def test_mixture_logistic_cdf_bijector_shape(n_samples, n_features, n_components):
    """Check the output shapes of MixtureLogisticCDF in both directions.

    Args:
        n_samples: number of rows in the random input batch.
        n_features: number of columns in the random input batch.
        n_components: number of mixture components handed to the layer factory.
    """
    # independent keys for parameter initialisation and for the data
    params_rng, data_rng = jax.random.split(KEY, 2)
    x = jax.random.normal(data_rng, shape=(n_samples, n_features))

    # build the layer: parameters plus forward / inverse callables
    layer_init = MixtureLogisticCDF(n_components=n_components)
    params, forward_f, inverse_f = layer_init(
        rng=params_rng, n_features=n_features
    )

    # forward pass
    z, log_abs_det = forward_f(params, x)
    chex.assert_equal_shape([z, x])
    chex.assert_shape(log_abs_det, (n_samples,))

    # inverse pass
    x_approx, log_abs_det = inverse_f(params, z)
    chex.assert_equal_shape([x, x_approx])
예제 #11
0
    def critic_loss(
        critic_params: hk.Params,
        generator_params: hk.Params,
        img_batch: ImgBatch,
        latent_batch: LatentBatch,
    ) -> Tuple[jnp.ndarray, Log]:
        """Critic loss with a gradient penalty evaluated on generated images.

        Returns:
            A pair `(total_loss, log)`, where `total_loss` is the difference
            of the mean critic outputs (generated minus real) plus ten times
            the mean gradient penalty, and `log` holds the negated
            difference and the penalty.
        """
        num_images = img_batch.shape[0]
        generated = generator.apply(generator_params, latent_batch)

        def critic_fn(images):
            return critic.apply(critic_params, images)

        real_scores = critic_fn(img_batch)

        # A single vector-jacobian product against a vector of ones gives
        # the gradient of every per-example critic output w.r.t. its input.
        generated_scores, pullback = jax.vjp(critic_fn, generated)
        input_grad = pullback(jnp.ones(generated_scores.shape))[0]

        assert_shape(real_scores, (num_images, 1))
        assert_shape(generated_scores, (num_images, 1))
        assert_shape(input_grad, img_batch.shape)

        # Penalise the deviation of each per-example gradient norm from 1.
        grad_norms = jnp.linalg.norm(
            input_grad.reshape(num_images, -1), axis=1)
        penalties = jnp.square(1 - grad_norms)
        assert_shape(penalties, (num_images, ))

        loss = jnp.mean(generated_scores) - jnp.mean(real_scores)
        gradient_penalty = jnp.mean(penalties)

        log = {
            "wasserstein": -loss,
            "gradient_penalty": gradient_penalty,
        }

        return loss + 10 * gradient_penalty, log
예제 #12
0
        def affine_transform(dist_params, scale, shift, value_transform=None):
            """Implements the "Categorical Algorithm" from https://arxiv.org/abs/1707.06887

            Args:
                dist_params: dict whose 'logits' entry is a rank-2 array; the
                    shape assertions below require (batch_size, num_bins).
                scale: scalar or rank-1 multiplicative factor applied to the
                    (inverse-transformed) atoms.
                shift: scalar or rank-1 additive offset.
                value_transform: optional pair `(f, f_inv)`; identity
                    functions are used when omitted.

            Returns:
                A dict with a single 'logits' entry holding the log of the
                projected probability mass (floored at 1e-16 before the log).
            """

            # check inputs
            chex.assert_rank([dist_params['logits'], scale, shift],
                             [2, {0, 1}, {0, 1}])
            p = jax.nn.softmax(dist_params['logits'])
            batch_size = p.shape[0]

            # broadcast scalar scale / shift to one value per batch element
            if isscalar(scale):
                scale = jnp.full(shape=(batch_size, ),
                                 fill_value=jnp.squeeze(scale))
            if isscalar(shift):
                shift = jnp.full(shape=(batch_size, ),
                                 fill_value=jnp.squeeze(shift))

            chex.assert_shape(p, (batch_size, self.num_bins))
            chex.assert_shape([scale, shift], (batch_size, ))

            if value_transform is None:
                f = f_inv = lambda x: x
            else:
                f, f_inv = value_transform

            # variable names correspond to those defined in: https://arxiv.org/abs/1707.06887
            z = self.__atoms
            Vmin, Vmax, Δz = z[0], z[-1], z[1] - z[0]
            Tz = f(jax.vmap(jnp.add)(jnp.outer(scale, f_inv(z)), shift))
            Tz = jnp.clip(Tz, Vmin, Vmax)  # keep values in valid range
            chex.assert_shape(Tz, (batch_size, self.num_bins))

            b = (Tz - Vmin) / Δz  # float in [0, num_bins - 1]
            l = jnp.floor(b).astype(
                'int32')  # noqa: E741   # int in {0, 1, ..., num_bins - 1}
            u = jnp.ceil(b).astype('int32')  # int in {0, 1, ..., num_bins - 1}
            chex.assert_shape([p, b, l, u], (batch_size, self.num_bins))

            # Scatter-add the probability mass onto the neighbouring bins.
            # `jax.ops.index_add` has been removed from JAX; the `.at[...].add`
            # form is its drop-in replacement and takes the same keywords.
            m = jnp.zeros_like(p)
            i = jnp.expand_dims(jnp.arange(batch_size), axis=1)  # batch index
            m = m.at[i, l].add(p * (u - b), indices_are_sorted=True)
            m = m.at[i, u].add(p * (b - l), indices_are_sorted=True)
            # when b is an integer, l == u and both weights above are zero,
            # so the full mass is assigned to that single bin here
            m = m.at[i, l].add(p * (l == u), indices_are_sorted=True)
            # chex.assert_tree_all_close(jnp.sum(m, axis=1), jnp.ones(batch_size), rtol=1e-6)

            # # The above index trickery is equivalent to:
            # m_alt = onp.zeros((batch_size, self.num_bins))
            # for i in range(batch_size):
            #     for j in range(self.num_bins):
            #         if l[i, j] == u[i, j]:
            #             m_alt[i, l[i, j]] += p[i, j]  # don't split if b[i, j] is an integer
            #         else:
            #             m_alt[i, l[i, j]] += p[i, j] * (u[i, j] - b[i, j])
            #             m_alt[i, u[i, j]] += p[i, j] * (b[i, j] - l[i, j])
            # chex.assert_tree_all_close(m, m_alt, rtol=1e-6)
            return {'logits': jnp.log(jnp.maximum(m, 1e-16))}
예제 #13
0
    def epipolar_projection(self, key, target_rays, ref_worldtocamera,
                            intrinsic_matrix, randomized):
        """Map rays onto the epipolar lines of reference views at multiple depths.

        ** Visually verified in colab.

        (The depth values are often set to the near and far plane of the
        camera.)

        Args:
          key: prngkey
          target_rays: The rays that we want to project onto nearby cameras.
            Often these are the rays that we want to render; contains
            origins and directions of shape (#bs, rays, 3).
          ref_worldtocamera: (#near_cam, 4, 4) The worldtocamera matrices of
            the nearby cameras that we want to project onto.
          intrinsic_matrix: (1, 3, 4) The intrinsic matrix for the dataset.
          randomized: if True, use randomized depths for projection.

        Returns:
          pcoords: (#near_cam, batch_size, num_projections, 2)
          valid_proj_mask: (#near_cam, batch_size, num_projections) specifying
            which of the projections are valid, i.e. in front of the camera
            and within the image bounds.
          wcoords: (batch_size, num_projections, 3)
        """
        # Only a single shared intrinsic matrix is supported, i.e. all views
        # come from the same camera.
        chex.assert_shape(intrinsic_matrix, (1, 3, 4))

        depths = jnp.linspace(self.min_depth, self.max_depth,
                              self.num_samples)

        if randomized:
            # Draw one depth uniformly inside each interval delimited by the
            # midpoints between consecutive depth values.
            midpoints = .5 * (depths[..., 1:] + depths[..., :-1])
            interval_hi = jnp.concatenate([midpoints, depths[..., -1:]], -1)
            interval_lo = jnp.concatenate([depths[..., :1], midpoints], -1)
            num_rays = target_rays.batch_shape[0]
            jitter = jax.random.uniform(key, [num_rays, self.num_samples])
            depths = interval_lo + (interval_hi - interval_lo) * jitter

        # World coordinates of every ray at every depth:
        # (#rays, num_projections, 3)
        ray_origins = target_rays.origins[:, None]
        ray_directions = target_rays.directions[:, None]
        wcoords = ray_origins + ray_directions * depths[..., None]

        # Convert to homogeneous coordinates:
        # (#rays, num_samples, 3) -> (#rays, num_samples, 4)
        homogeneous_ones = jnp.ones_like(wcoords[..., 0:1])
        wcoords = jnp.concatenate([wcoords, homogeneous_ones], axis=-1)

        pcoords, proj_frontof_cam_mask = self.project2camera(
            wcoords, ref_worldtocamera, intrinsic_matrix)

        # Find projections that are inside the image (evaluated before
        # clipping).
        within_image_mask = self.inside_image(pcoords)

        # Clip coordinates to be within the image.
        first_coord = jnp.clip(pcoords[..., 0:1], 0, self.image_height - 1)
        other_coords = jnp.clip(pcoords[..., 1:], 0, self.image_width - 1)
        pcoords = jnp.concatenate([first_coord, other_coords], axis=-1)

        # A projection is valid if it is in front of the camera and within
        # the image boundaries.
        valid_proj_mask = proj_frontof_cam_mask * within_image_mask

        return pcoords, valid_proj_mask, wcoords[..., :3]
예제 #14
0
def test_logmarglike_lineargaussianmodel_threetransfers_basics():
    """Sanity checks for the three-transfer linear-Gaussian marginal likelihood.

    For a single object this verifies that:
      * the outputs are finite and have the expected shapes,
      * the jit variant matches the reference implementation,
      * the posterior factorises into a Gaussian around `theta_map`,
      * `theta_map` is a stationary point of the negative log posterior,
      * `theta_cov` equals the inverse Hessian of the negative log posterior,
      * the evidence agrees with a numerical integration on a grid.
    """

    n_components = 2
    theta_truth = jax.random.normal(key, (n_components,))
    mu = theta_truth * 1.1
    muinvvar = (theta_truth) ** -2
    logmuinvvar = np.where(muinvvar == 0, 0, np.log(muinvvar))

    M_T = design_matrix_polynomials(n_components, n_pix_y)  # (n_components, n_pix_y)
    y_truth = np.matmul(theta_truth, M_T)  # (n_pix_y,)
    y, yinvvar, logyinvvar = make_masked_noisy_data(y_truth)  # (n_pix_y,)

    ell = 1

    R_T = design_matrix_polynomials(n_components, n_pix_z)  # (n_components, n_pix_z)
    z_truth = np.matmul(theta_truth, R_T)  # (n_pix_z,)
    z, zinvvar, logzinvvar = make_masked_noisy_data(z_truth)  # (n_pix_z,)

    # first run
    logfml, theta_map, theta_cov = logmarglike_lineargaussianmodel_threetransfers(
        ell, M_T, R_T, y, yinvvar, z, zinvvar, mu, muinvvar
    )
    # check result is finite and shapes are correct
    assert np.isfinite(logfml)
    assert_shape(theta_map, (n_components,))
    assert_shape(theta_cov, (n_components, n_components))

    # now trying jit
    (
        logfml2,
        theta_map2,
        theta_cov2,
    ) = logmarglike_lineargaussianmodel_threetransfers_jit(
        ell,
        M_T,
        R_T,
        y,
        yinvvar,
        logyinvvar,
        z,
        zinvvar,
        logzinvvar,
        mu,
        muinvvar,
        logmuinvvar,
    )
    assert_fml_thetamap_thetacov(
        logfml, theta_map, theta_cov, logfml2, theta_map2, theta_cov2, relative_accuracy
    )

    # check that posterior distribution is equal to product of gaussians too
    def log_posterior(theta):
        y_mod = np.matmul(theta, M_T)  # model prediction for y
        like_y = batch_gaussian_loglikelihood(y_mod - y, yinvvar)
        z_mod = ell * np.matmul(theta, R_T)  # model prediction for z
        like_z = batch_gaussian_loglikelihood(z_mod - z, zinvvar)
        prior = batch_gaussian_loglikelihood(theta - mu, muinvvar)
        return like_y + like_z + prior

    def log_posterior2(theta):
        dt = theta - theta_map
        s, logdet = np.linalg.slogdet(theta_cov * 2 * np.pi)
        chi2 = np.dot(dt.T, np.linalg.solve(theta_cov, dt))
        return logfml - 0.5 * (s * logdet + chi2)

    logpostv = log_posterior(theta_truth)
    logpostv2 = log_posterior2(theta_truth)
    assert abs(logpostv2 / logpostv - 1) < 0.01

    def loss_fn(theta):
        return -log_posterior(theta)

    # A few gradient-descent steps started at theta_map must not move away
    # from it. The whole gradient vector is applied at every step; zipping a
    # one-element parameter list with the gradient array would pair the
    # parameter vector with only the first gradient component.
    theta_opt = 1 * theta_map
    learning_rate = 1e-5
    loss_grad = grad(loss_fn)
    for _ in range(10):
        theta_opt = theta_opt - learning_rate * loss_grad(theta_opt)
    assert np.allclose(theta_map, theta_opt, rtol=1e-6)

    # Testing analytic covariance is correct
    theta_cov2 = np.linalg.inv(np.reshape(hessian(loss_fn)(theta_map), theta_cov.shape))
    assert np.allclose(theta_cov, theta_cov2, rtol=1e-6)

    loss_fn_vmap = jit(vmap(loss_fn))

    # numerical evidence from a grid of samples around the MAP estimate
    n = 15
    theta_std = np.diag(theta_cov) ** 0.5
    theta_samples, vol_element = generate_sample_grid(theta_map, theta_std, n)
    logpost = -loss_fn_vmap(theta_samples)
    logfml_numerical = logsumexp(np.log(vol_element) + logpost)
    # print("logfml, logfml_numerical", logfml, logfml_numerical)
    assert abs(logfml_numerical / logfml - 1) < 0.01
예제 #15
0
def test_logmarglike_lineargaussianmodel_threetransfers_vmap():
    """Check the jit+vmap three-transfer marginal likelihood on a batch.

    Verifies output shapes and finiteness for `nobj` objects, then runs a
    few Adam steps on the summed negative log evidence to confirm that the
    function is differentiable with respect to `ells` and both design
    matrices.
    """

    theta_truth = jax.random.normal(key, (nobj, n_components))
    mu = theta_truth * 1.1
    muinvvar = (theta_truth) ** -2
    logmuinvvar = np.where(muinvvar == 0, 0, np.log(muinvvar))

    M_T = design_matrix_polynomials(n_components, n_pix_y)  # (n_components, n_pix_y)
    M_T_y = M_T[None, :, :] * np.ones((nobj, 1, 1))  # (nobj, n_components, n_pix_y)
    y_truth = np.matmul(theta_truth, M_T)  # (nobj, n_pix_y)
    y, yinvvar, logyinvvar = make_masked_noisy_data(y_truth)  # (nobj, n_pix_y)

    ells = 10 ** jax.random.normal(key, (nobj,))  # one scaling per object
    R_T = design_matrix_polynomials(n_components, n_pix_z)
    R_T_z = R_T[None, :, :] * np.ones((nobj, 1, 1))  # (nobj, n_components, n_pix_z)
    # ells needs a trailing axis so that each object's scaling is broadcast
    # over its pixels rather than being aligned with the pixel axis.
    z_truth = ells[:, None] * np.matmul(theta_truth, R_T)  # (nobj, n_pix_z)
    z, zinvvar, logzinvvar = make_masked_noisy_data(z_truth)  # (nobj, n_pix_z)

    # first run
    (
        logfml,
        theta_map,
        theta_cov,
    ) = logmarglike_lineargaussianmodel_threetransfers_jitvmap(
        ells,
        M_T_y,
        R_T_z,
        y,
        yinvvar,
        logyinvvar,
        z,
        zinvvar,
        logzinvvar,
        mu,
        muinvvar,
        logmuinvvar,
    )
    # check result is finite and shapes are correct
    assert np.all(np.isfinite(logfml))
    assert_shape(logfml, (nobj,))
    assert_shape(theta_map, (nobj, n_components))
    assert_shape(theta_cov, (nobj, n_components, n_components))

    theta_std = np.diagonal(theta_cov, axis1=1, axis2=2) ** 0.5
    assert_shape(theta_std, (nobj, n_components))

    def loss(param_list, data_list):
        """Negative summed log evidence as a function of the parameters."""
        (ells, M_T_y, R_T_z) = param_list
        (
            y,
            yinvvar,
            logyinvvar,
            z,
            zinvvar,
            logzinvvar,
            mu,
            muinvvar,
            logmuinvvar,
        ) = data_list
        (
            logfml,
            theta_map,
            theta_cov,
        ) = logmarglike_lineargaussianmodel_threetransfers_jitvmap(
            ells,
            M_T_y,
            R_T_z,
            y,
            yinvvar,
            logyinvvar,
            z,
            zinvvar,
            logzinvvar,
            mu,
            muinvvar,
            logmuinvvar,
        )
        return -np.sum(logfml)

    params = [ells, M_T_y, R_T_z]

    data = [
        y,
        yinvvar,
        logyinvvar,
        z,
        zinvvar,
        logzinvvar,
        mu,
        muinvvar,
        logmuinvvar,
    ]

    assert np.isfinite(loss(params, data))

    # a handful of optimiser steps to exercise the gradient path
    learning_rate = 1e-3
    num_steps = 10
    opt_init, opt_update, get_params = optimizers.adam(learning_rate)
    opt_state = opt_init(params)

    @jit
    def update(step, opt_state, data):
        params = get_params(opt_state)
        value, grads = jax.value_and_grad(loss)(params, data)
        opt_state = opt_update(step, grads, opt_state)
        return value, opt_state

    for step in range(num_steps):
        value, opt_state = update(step, opt_state, data)
예제 #16
0
def test_logmarglike_lineargaussianmodel_onetransfer_batched():
    """Check the jit+vmap one-transfer marginal likelihood on a batch.

    Fake masked, noisy data is generated for several objects; the output
    shapes are verified, and a few Adam steps are then run on a randomly
    initialised design matrix to exercise the gradient path.
    """
    # Array shapes are given in parentheses throughout.

    # --- fake data (noisy and masked) ---
    nobj = 14
    n_components = 2
    n_pix_y = 100  # number of pixels for each object

    # true linear parameters, one row per object
    theta_truth = jax.random.normal(key, (nobj, n_components))

    # shared design matrix, replicated for every object
    M_T = design_matrix_polynomials(n_components, n_pix_y)  # (n_components, n_pix_y)
    M_T_y = M_T[None, :, :] * np.ones((nobj, 1, 1))  # (nobj, n_components, n_pix_y)

    # noiseless data array
    y_truth = np.matmul(theta_truth, M_T)  # (nobj, n_pix_y)

    # noisy data together with the inverse noise variance and its logarithm;
    # zeros in yinvvar / logyinvvar mark ignored (masked) pixels
    y, yinvvar, logyinvvar = make_masked_noisy_data(y_truth)  # (nobj, n_pix_y)
    assert_equal_shape([y_truth, y, yinvvar, logyinvvar])

    # --- evidences, best-fit thetas and covariances for all objects ---
    batched_outputs = logmarglike_lineargaussianmodel_onetransfer_jitvmap(
        M_T_y, y, yinvvar, logyinvvar
    )
    logfml, theta_map, theta_cov = batched_outputs

    # shapes of the output arrays
    assert_shape(logfml, (nobj,))
    assert_shape(theta_map, (nobj, n_components))
    assert_shape(theta_cov, (nobj, n_components, n_components))

    # --- optimisation of the design matrix given the data ---
    @partial(jit, static_argnums=())
    def loss_fn(params, data):
        # params is a list whose single entry is the design matrix
        design_matrix = params[0]
        y, yinvvar, logyinvvar = data
        outputs = logmarglike_lineargaussianmodel_onetransfer_jitvmap(
            design_matrix, y, yinvvar, logyinvvar
        )
        logfml, theta_map, theta_cov = outputs
        return -np.sum(logfml)

    M_T_new_initial = jax.random.normal(key, (nobj, n_components, n_pix_y))
    param_list = [1 * M_T_new_initial]

    learning_rate = 1e-5
    opt_init, opt_update, get_params = jax.experimental.optimizers.adam(learning_rate)
    opt_state = opt_init(param_list)

    @partial(jit, static_argnums=())
    def update(step, opt_state, data):
        current_params = get_params(opt_state)
        value, grads = jax.value_and_grad(loss_fn)(current_params, data)
        return value, opt_update(step, grads, opt_state)

    # TODO: use better optimizer
    num_iterations = 10
    data = (y, yinvvar, logyinvvar)
    for step in range(num_iterations):
        # Could potentially also iterate over batches of data
        loss_value, opt_state = update(step, opt_state, data)

    # optimised matrix:
    M_T_new_optimised = get_params(opt_state)
예제 #17
0
def test_logmarglike_lineargaussianmodel_onetransfer_basics():
    """Sanity checks for the one-transfer linear-Gaussian marginal likelihood.

    For a single object this verifies that:
      * the outputs are finite and have the expected shapes,
      * the MAP estimate is close to the truth in a chi2 sense,
      * the posterior factorises into a Gaussian around `theta_map`,
      * the jit variant matches the reference implementation,
      * `theta_map` is a stationary point of the negative log posterior,
      * `theta_cov` equals the inverse Hessian of the negative log posterior,
      * the evidence agrees with a numerical integration on a grid,
      * the two-transfer variant with a wide prior gives similar results.
    """

    theta_truth = jax.random.normal(key, (n_components,))

    M_T = design_matrix_polynomials(n_components, n_pix_y)  # (n_components, n_pix_y)
    y_truth = np.matmul(theta_truth, M_T)  # (n_pix_y,)
    y, yinvvar, logyinvvar = make_masked_noisy_data(y_truth)  # (n_pix_y,)
    assert_equal_shape([y_truth, y, yinvvar, logyinvvar])

    logfml, theta_map, theta_cov = logmarglike_lineargaussianmodel_onetransfer(
        M_T, y, yinvvar
    )

    # check result is finite and shapes are correct
    assert_shape(theta_map, (n_components,))
    assert_shape(theta_cov, (n_components, n_components))
    assert np.isfinite(logfml)
    assert np.all(np.isfinite(theta_map))
    assert np.all(np.isfinite(theta_cov))

    # check that result isn't too far off the truth, in chi2 sense
    dt = theta_map - theta_truth
    chi2 = 0.5 * np.ravel(np.matmul(dt.T, np.linalg.solve(theta_cov, dt)))
    assert chi2 < 100

    # check that normalised posterior distribution factorises into product of gaussians
    def log_posterior(theta):
        y_mod = np.matmul(theta, M_T)  # model prediction for y
        return batch_gaussian_loglikelihood(y_mod - y, yinvvar)

    def log_posterior2(theta):
        dt = theta - theta_map
        s, logdet = np.linalg.slogdet(theta_cov * 2 * np.pi)
        chi2 = np.dot(dt.T, np.linalg.solve(theta_cov, dt))
        return logfml - 0.5 * (s * logdet + chi2)

    logpostv = log_posterior(theta_truth)
    logpostv2 = log_posterior2(theta_truth)
    assert abs(logpostv2 / logpostv - 1) < 0.01

    # now trying jit version of function
    logfml2, theta_map2, theta_cov2 = logmarglike_lineargaussianmodel_onetransfer_jit(
        M_T, y, yinvvar, logyinvvar
    )
    # check that outputs match original implementation
    assert_fml_thetamap_thetacov(
        logfml, theta_map, theta_cov, logfml2, theta_map2, theta_cov2, relative_accuracy
    )

    # now running simple optimiser to check that result is indeed optimum
    def loss_fn(theta):
        return -log_posterior(theta)

    # A few gradient-descent steps started at theta_map must not move away
    # from it. The whole gradient vector is applied at every step; zipping a
    # one-element parameter list with the gradient array would pair the
    # parameter vector with only the first gradient component.
    theta_opt = 1 * theta_map
    learning_rate = 1e-5
    loss_grad = grad(loss_fn)
    for _ in range(10):
        theta_opt = theta_opt - learning_rate * loss_grad(theta_opt)
    assert np.allclose(theta_map, theta_opt, rtol=1e-6)

    # Testing analytic covariance is correct and equals inverse of hessian
    theta_cov2 = np.linalg.inv(np.reshape(hessian(loss_fn)(theta_map), theta_cov.shape))
    assert np.allclose(theta_cov, theta_cov2, rtol=1e-6)

    # create vectorised loss
    loss_fn_vmap = jit(vmap(loss_fn))

    # now computes the evidence numerically
    n = 15
    theta_std = np.diag(theta_cov) ** 0.5
    theta_samples, vol_element = generate_sample_grid(theta_map, theta_std, n)
    loglikelihoods = -loss_fn_vmap(theta_samples)
    logfml_numerical = logsumexp(np.log(vol_element) + loglikelihoods)
    # print("logfml, logfml_numerical", logfml, logfml_numerical)
    assert abs(logfml_numerical / logfml - 1) < 0.01

    # Compare with case including gaussian prior with large variance
    mu = theta_map * 0
    # NOTE(review): this is 1 / (1e4 * std), i.e. not the reciprocal of a
    # squared width -- confirm this is the intended "inverse variance".
    muinvvar = 1 / (1e4 * np.diag(theta_cov) ** 0.5)
    logfml2, theta_map2, theta_cov2 = logmarglike_lineargaussianmodel_twotransfers(
        M_T, y, yinvvar, mu, muinvvar
    )
    assert_fml_thetamap_thetacov(
        logfml, theta_map, theta_cov, logfml2, theta_map2, theta_cov2, 0.2
    )
예제 #18
0
 def test_sample_returns_batch(self):
     """Sampling returns a batch whose `value` has `sample_size` entries."""
     sample_size = 2
     replay = make_replay()
     add(replay, [1, 2, 3])
     # sample() returns three items; only the samples are inspected here.
     samples, _, _ = replay.sample(sample_size)
     chex.assert_shape(samples.value, (sample_size, ))
예제 #19
0
 def test_sample(self):
     """Sampling returns a batch whose `a` field has one entry per sample."""
     num_samples = 2
     batch = self.replay.sample(num_samples)
     expected_shape = (num_samples, )
     chex.assert_shape(batch.a, expected_shape)
def extract_patches_from_indicators(x,
                                    indicators,
                                    patch_size,
                                    patch_dropout,
                                    grid_shape,
                                    train,
                                    iterative=False):
    """Extract patches from a batch of images.

  Each of the k output patches is the indicator-weighted sum of the patches
  located at every position of the score grid.

  Args:
    x: The batch of images of shape (batch, height, width, channels).
    indicators: The one hot indicators of shape (batch, num_patches, k).
    patch_size: The size of the (squared) patches to extract.
    patch_dropout: Probability to replace a patch by 0 values. Only supported
      when `iterative` is True; the non-iterative path asserts it is 0.
    grid_shape: Pair of height, width of the disposition of the num_patches
      patches.
    train: If the model is being trained. Disable dropout if not.
    iterative: If True, extracts the patches with a for loop rather than
      instantiating the "all patches" tensor and extracting by dot product with
      indicators. `iterative` is more memory efficient.

  Returns:
    The patches extracted from x with shape
      (batch, k, patch_size, patch_size, channels).

  Raises:
    AssertionError: If `iterative` is False and `patch_dropout` is not 0.
  """
    batch_size, height, width, channels = x.shape
    scores_h, scores_w = grid_shape
    k = indicators.shape[-1]
    # Unflatten the patch axis onto the (h, w) score grid and bring k forward.
    indicators = einops.rearrange(indicators,
                                  "b (h w) k -> b k h w",
                                  h=scores_h,
                                  w=scores_w)

    # Integer stride, in pixels, between two neighbouring grid positions.
    scale_height = height // scores_h
    scale_width = width // scores_w
    # Padded extent chosen so that windows of size patch_size taken every
    # scale_* pixels give scores_* positions per axis (the non-iterative path
    # checks this with assert_shape below).
    padded_height = scale_height * scores_h + patch_size - 1
    padded_width = scale_width * scores_w + patch_size - 1
    # Half of the excess of the patch over the stride goes before the image
    # (presumably to centre each patch on its grid cell), the rest after.
    top_pad = (patch_size - scale_height) // 2
    left_pad = (patch_size - scale_width) // 2
    bottom_pad = padded_height - top_pad - height
    right_pad = padded_width - left_pad - width

    # TODO(jbcdnr): assert padding is positive.

    # Only the two spatial axes are padded; batch and channels are untouched.
    padded_x = jnp.pad(x, [(0, 0), (top_pad, bottom_pad),
                           (left_pad, right_pad), (0, 0)])

    # Extract the patches. Iterative fits better in memory as it does not
    # instantiate the "all patches" tensor but iterates over them to compute the
    # weighted sum with the indicator variables from topk.
    if not iterative:
        assert patch_dropout == 0., "Patch dropout not implemented."
        patches = utils.extract_images_patches(padded_x,
                                               window_size=(patch_size,
                                                            patch_size),
                                               stride=(scale_height,
                                                       scale_width))

        shape = (batch_size, scores_h, scores_w, patch_size, patch_size,
                 channels)
        chex.assert_shape(patches, shape)

        # Contract the grid axes (h, w) against the indicators to get k patches.
        patches = jnp.einsum("b k h w, b h w i j c -> b k i j c", indicators,
                             patches)

    else:
        # One mask entry per grid position and batch element; dropout is
        # disabled (deterministic) when not training.
        mask = jnp.ones((batch_size, scores_h, scores_w))
        mask = nn.dropout(mask, patch_dropout, deterministic=not train)

        def accumulate_patches(acc, index_i_j):
            # Scan body: adds the patch at grid position (i, j), weighted by its
            # indicators and its mask value, to the running accumulator.
            i, j = index_i_j
            patch = jax.lax.dynamic_slice(
                padded_x, (0, i * scale_height, j * scale_width, 0),
                (batch_size, patch_size, patch_size, channels))
            weights = indicators[:, :, i, j]

            is_masked = mask[:, i, j]
            weighted_patch = jnp.einsum("b, bk, bijc -> bkijc", is_masked,
                                        weights, patch)
            chex.assert_equal_shape([acc, weighted_patch])
            acc += weighted_patch
            return acc, None

        # All (i, j) grid coordinates, flattened to shape (scores_h * scores_w, 2).
        indices = jnp.stack(jnp.meshgrid(jnp.arange(scores_h),
                                         jnp.arange(scores_w),
                                         indexing="ij"),
                            axis=-1)
        indices = indices.reshape((-1, 2))
        init_patches = jnp.zeros(
            (batch_size, k, patch_size, patch_size, channels))
        patches, _ = jax.lax.scan(accumulate_patches, init_patches, indices)

    return patches
    def apply(self,
              x,
              *,
              patch_size,
              k,
              downscale,
              scorer_has_se,
              normalization_str="identity",
              selection_method,
              selection_method_kwargs=None,
              selection_method_inference=None,
              patch_dropout=0.,
              hard_topk_probability=0.,
              random_patch_probability=0.,
              use_iterative_extraction,
              append_position_to_input,
              feature_network,
              aggregation_method,
              aggregation_method_kwargs=None,
              train):
        """Process a high resolution image by selecting a subset of useful patches.

    This model processes the input as follows:
    1. Compute scores per patch on a downscaled version of the input.
    2. Select "important" patches using sampling or top-k methods.
    3. Extract the patches from the high-resolution image.
    4. Compute representation vector for each patch with a feature network.
    5. Aggregate the patch representation to obtain an image representation.

    Args:
      x: Input tensor of shape (batch, height, width, channels).
      patch_size: Size of the (squared) patches to extract.
      k: Number of patches to extract per image.
      downscale: Downscale multiplier for the input of the scorer network.
      scorer_has_se: Whether scorer network has Squeeze-excite layers.
      normalization_str: String specifying the normalization of the scores.
      selection_method: Method that selects which patches should be extracted,
        based on their scores. Either returns indices (hard selection) or
        indicators vectors (which could yield interpolated patches).
      selection_method_kwargs: Keyword args for the selection_method.
      selection_method_inference: Selection method used at inference.
      patch_dropout: Probability to replace a patch by 0 values.
      hard_topk_probability: Probability to use the true topk on the scores to
        select the patches. This operation has no gradient so scorer's weights
        won't be trained.
      random_patch_probability: Probability to replace each patch by a random
        patch in the image during training.
      use_iterative_extraction: If True, uses a for loop instead of patch
        indexing for memory efficiency.
      append_position_to_input: Append normalized (height, width) position to
        the channels of the input.
      feature_network: Network to be applied on each patch individually to
        obtain patch representation vectors.
      aggregation_method: Method to aggregate the representations of the k
        patches of each image to obtain the image representation.
      aggregation_method_kwargs: Keywords arguments for aggregation_method.
      train: If the model is being trained. Disable dropout otherwise.

    Returns:
      A tuple `(representations, stats)`: the representation computed for each
      image in the batch, and a dict of intermediate values collected along the
      way (e.g. "scores", "entropy", "extracted_patches",
      "patch_representations", depending on the selection method).
    """
        # Accept either enum members or their values for the method arguments.
        selection_method = SelectionMethod(selection_method)
        aggregation_method = AggregationMethod(aggregation_method)
        if selection_method_inference:
            selection_method_inference = SelectionMethod(
                selection_method_inference)

        selection_method_kwargs = selection_method_kwargs or {}
        aggregation_method_kwargs = aggregation_method_kwargs or {}

        stats = {}

        # Compute new dimension of the scoring image.
        b, h, w, c = x.shape
        scoring_shape = (b, h // downscale, w // downscale, c)

        # === Compute the scores with a small CNN.
        # NOTE(review): `flatten_scores` is only defined when the configured
        # (training) selection_method is not RANDOM. If
        # selection_method_inference then switches to a score-based method at
        # inference, the selection branches below would reference it while it is
        # undefined - confirm this combination is never configured.
        if selection_method == SelectionMethod.RANDOM:
            # No scorer is run; only the size of the score grid is needed.
            scores_h, scores_w = Scorer.compute_output_size(
                h // downscale, w // downscale)
            num_patches = scores_h * scores_w
        else:
            # Downscale input to run scorer on.
            scoring_x = jax.image.resize(x, scoring_shape, method="bilinear")
            scores = Scorer(scoring_x,
                            use_squeeze_excite=scorer_has_se,
                            name="scorer")
            flatten_scores = einops.rearrange(scores, "b h w -> b (h w)")
            num_patches = flatten_scores.shape[-1]
            scores_h, scores_w = scores.shape[1:3]

            # Compute entropy before normalization
            prob_scores = jax.nn.softmax(flatten_scores)
            stats["entropy_before_normalization"] = jax.scipy.special.entr(
                prob_scores).sum(axis=1).mean(axis=0)

            # Normalize the flatten scores
            normalization_fn = create_normalization_fn(normalization_str)
            flatten_scores = normalization_fn(flatten_scores)
            scores = flatten_scores.reshape(scores.shape)
            stats["scores"] = scores[Ellipsis, None]

        # Concatenate height and width position to the input channels.
        if append_position_to_input:
            coords = utils.create_grid([h, w], value_range=(0., 1.))
            x = jnp.concatenate(
                [x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)], axis=-1)
            # Keep the channel count in sync for the shape check on the patches.
            c += 2

        # Overwrite the selection method at inference
        if selection_method_inference and not train:
            selection_method = selection_method_inference

        # === Patch selection

        # Select the patches by sampling or top-k. Some methods return the indices
        # of the selected patches, other methods return indicator vectors.
        extract_by_indices = selection_method in [
            SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM
        ]
        if selection_method is SelectionMethod.SINKHORN_TOPK:
            indicators = select_patches_sinkhorn_topk(
                flatten_scores, k=k, **selection_method_kwargs)
        elif selection_method is SelectionMethod.PERTURBED_TOPK:
            sigma = selection_method_kwargs["sigma"]
            num_samples = selection_method_kwargs["num_samples"]
            # Scale sigma by a value held in module state (initialised to one).
            sigma *= self.state("sigma_mutiplier",
                                shape=(),
                                initializer=nn.initializers.ones).value
            stats["sigma"] = sigma
            indicators = select_patches_perturbed_topk(flatten_scores,
                                                       k=k,
                                                       sigma=sigma,
                                                       num_samples=num_samples)
        elif selection_method is SelectionMethod.HARD_TOPK:
            indices = select_patches_hard_topk(flatten_scores, k=k)
        elif selection_method is SelectionMethod.RANDOM:
            # Draw k distinct patch indices per batch element, one RNG key each.
            batch_random_indices_fn = jax.vmap(
                functools.partial(jax.random.choice,
                                  a=num_patches,
                                  shape=(k, ),
                                  replace=False))
            indices = batch_random_indices_fn(
                jax.random.split(nn.make_rng(), b))

        # Compute scores entropy for regularization
        if selection_method not in [SelectionMethod.RANDOM]:
            prob_scores = flatten_scores
            # Normalize the scores if it is not already done.
            if "softmax" not in normalization_str:
                prob_scores = jax.nn.softmax(prob_scores)
            stats["entropy"] = jax.scipy.special.entr(prob_scores).sum(
                axis=1).mean(axis=0)

        # Randomly use hard topk at training.
        if (train and hard_topk_probability > 0 and selection_method
                not in [SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM]):
            true_indices = select_patches_hard_topk(flatten_scores, k=k)
            # One draw per batch element: the whole image switches to hard topk.
            random_values = jax.random.uniform(nn.make_rng(), (b, ))
            use_hard = random_values < hard_topk_probability
            # NOTE(review): the enclosing condition excludes HARD_TOPK and RANDOM,
            # which are exactly the methods that set extract_by_indices, so the
            # first branch below looks unreachable - confirm.
            if extract_by_indices:
                indices = jnp.where(use_hard[:, None], true_indices, indices)
            else:
                true_indicators = make_indicators(true_indices, num_patches)
                indicators = jnp.where(use_hard[:, None, None],
                                       true_indicators, indicators)

        # Sample some random patches during training with random_patch_probability.
        if (train and random_patch_probability > 0
                and selection_method is not SelectionMethod.RANDOM):
            single_random_patches = functools.partial(jax.random.choice,
                                                      a=num_patches,
                                                      shape=(k, ),
                                                      replace=False)
            random_indices = jax.vmap(single_random_patches)(jax.random.split(
                nn.make_rng(), b))
            # One draw per (image, patch slot): replacement is decided per patch.
            random_values = jax.random.uniform(nn.make_rng(), (b, k))
            use_random = random_values < random_patch_probability
            if extract_by_indices:
                indices = jnp.where(use_random, random_indices, indices)
            else:
                random_indicators = make_indicators(random_indices,
                                                    num_patches)
                indicators = jnp.where(use_random[:, None, :],
                                       random_indicators, indicators)

        # === Patch extraction
        if extract_by_indices:
            patches = extract_patches_from_indices(x,
                                                   indices,
                                                   patch_size=patch_size,
                                                   grid_shape=(scores_h,
                                                               scores_w))
            # Indicators are still needed later by the TRANSFORMER aggregation.
            indicators = make_indicators(indices, num_patches)
        else:
            patches = extract_patches_from_indicators(
                x,
                indicators,
                patch_size,
                grid_shape=(scores_h, scores_w),
                iterative=use_iterative_extraction,
                patch_dropout=patch_dropout,
                train=train)

        chex.assert_shape(patches, (b, k, patch_size, patch_size, c))

        # Lay the k patches of each image side by side along the width axis.
        stats["extracted_patches"] = einops.rearrange(
            patches, "b k i j c -> b i (k j) c")
        # Remove position channels for plotting.
        if append_position_to_input:
            stats["extracted_patches"] = (
                stats["extracted_patches"][Ellipsis, :-2])

        # === Compute patch features
        flatten_patches = einops.rearrange(patches, "b k i j c -> (b k) i j c")
        representations = feature_network(flatten_patches, train=train)
        # Average over every axis between the first and the last one so that a
        # single feature vector per patch remains.
        if representations.ndim > 2:
            collapse_axis = tuple(range(1, representations.ndim - 1))
            representations = representations.mean(axis=collapse_axis)
        representations = einops.rearrange(representations,
                                           "(b k) d -> b k d",
                                           k=k)

        stats["patch_representations"] = representations

        # === Aggregate the k patches

        # - for sampling we are forced to take an expectation
        # - for topk we have multiple options: mean, max, transformer.
        if aggregation_method is AggregationMethod.TRANSFORMER:
            # Dense projection of each patch's indicator vector, used as a
            # position encoding added to the patch representation.
            patch_pos_encoding = nn.Dense(einops.rearrange(
                indicators, "b d k -> b k d"),
                                          features=representations.shape[-1])

            chex.assert_equal_shape([representations, patch_pos_encoding])
            representations += patch_pos_encoding
            representations = transformer.Transformer(
                representations,
                **aggregation_method_kwargs,
                is_training=train)

        elif aggregation_method is AggregationMethod.MEANPOOLING:
            representations = representations.mean(axis=1)
        elif aggregation_method is AggregationMethod.MAXPOOLING:
            representations = representations.max(axis=1)
        elif aggregation_method is AggregationMethod.SUM_LAYERNORM:
            representations = representations.sum(axis=1)
            representations = nn.LayerNorm(representations)

        # Final dense layer (same width as its input) followed by swish.
        representations = nn.Dense(representations,
                                   features=representations.shape[-1],
                                   name="classification_dense1")
        representations = nn.swish(representations)

        return representations, stats
예제 #22
0
    def __call__(self,
                 queries: jnp.ndarray,
                 hm_memory: HierarchicalMemory,
                 hm_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
        """Do hierarchical attention over the stored memories.

    The queries first attend over one key per memory chunk to pick the top-k
    chunks, then attend within each selected chunk; the per-chunk results are
    combined using the softmax of the top-k logits as weights.

    Args:
       queries: Tensor [B, Q, E] Query(ies) in, for batch size B, query length
         Q, and embedding dimension E.
       hm_memory: Hierarchical Memory.
       hm_mask: Optional boolean mask tensor of shape [B, Q, M]. Where false,
         the corresponding query timepoints cannot attend to the corresponding
         memory chunks. This can be used for enforcing causal attention on the
         learner, not attending to memories from prior episodes, etc.

    Returns:
      Value updates for each query slot: [B, Q, D]
    """
        # some shape checks
        batch_size, query_length, _ = queries.shape
        (memory_batch_size, num_memories, memory_chunk_size,
         mem_embbedding_size) = hm_memory.contents.shape
        assert batch_size == memory_batch_size
        chex.assert_shape(hm_memory.keys,
                          (batch_size, num_memories, mem_embbedding_size))
        chex.assert_shape(
            hm_memory.accumulator,
            (memory_batch_size, memory_chunk_size, mem_embbedding_size))
        chex.assert_shape(hm_memory.steps_since_last_write,
                          (memory_batch_size, ))
        if hm_mask is not None:
            chex.assert_type(hm_mask, bool)
            chex.assert_shape(hm_mask,
                              (batch_size, query_length, num_memories))
        # Project queries and memory keys to a common size; gradients are not
        # propagated into the stored memory keys.
        query_head = self._singlehead_linear(queries, self._size, "query")
        key_head = self._singlehead_linear(
            jax.lax.stop_gradient(hm_memory.keys), self._size, "key")

        # What times in the input [t] attend to what times in the memories [T].
        logits = jnp.einsum("btd,bTd->btT", query_head, key_head)

        scaled_logits = logits / np.sqrt(self._size)

        # Mask last dimension, replacing invalid logits with large negative values.
        # This allows e.g. enforcing causal attention on learner, or blocking
        # attention across episodes
        if hm_mask is not None:
            masked_logits = jnp.where(hm_mask, scaled_logits, -1e6)
        else:
            masked_logits = scaled_logits

        # identify the top-k memories and their relevance weights
        # (top_k selects along the last, i.e. memory, axis; the softmax then
        # normalises over the k selected logits only).
        top_k_logits, top_k_indices = jax.lax.top_k(masked_logits, self._k)
        weights = jax.nn.softmax(top_k_logits)

        # set up the within-memory attention
        assert self._size % self._num_heads == 0
        mha_key_size = self._size // self._num_heads
        attention_layer = hk.MultiHeadAttention(key_size=mha_key_size,
                                                model_size=self._size,
                                                num_heads=self._num_heads,
                                                w_init_scale=self._init_scale,
                                                name="within_mem_attn")

        # position encodings
        augmented_contents = hm_memory.contents
        if self._memory_position_encoding:
            position_embs = sinusoid_position_encoding(memory_chunk_size,
                                                       mem_embbedding_size)
            # Broadcast the per-position encoding over batch and memory axes.
            # NOTE(review): assumes contents is an immutable jnp array so `+=`
            # rebinds the local name instead of mutating hm_memory - confirm.
            augmented_contents += position_embs[None, None, :, :]

        def _within_memory_attention(sub_inputs, sub_memory_contents,
                                     sub_weights, sub_top_k_indices):
            # Handles a single batch element (the batch axis is vmapped away
            # below). Gather the contents of the selected chunks per query.
            top_k_contents = sub_memory_contents[sub_top_k_indices, :, :]

            # Now we go deeper, with another vmap over **tokens**, because each token
            # can each attend to different memories.
            def do_attention(sub_sub_inputs, sub_sub_top_k_contents):
                # Repeat the single query k times (adding a length-1 sequence
                # axis) so each copy attends over one of the k selected chunks.
                tiled_inputs = jnp.tile(sub_sub_inputs[None, None, :],
                                        reps=(self._k, 1, 1))
                sub_attention_results = attention_layer(
                    query=tiled_inputs,
                    key=sub_sub_top_k_contents,
                    value=sub_sub_top_k_contents)
                return sub_attention_results

            do_attention = hk_vmap(do_attention, in_axes=0, split_rng=False)
            attention_results = do_attention(sub_inputs, top_k_contents)
            # Drop the length-1 query axis introduced by the tiling above.
            attention_results = jnp.squeeze(attention_results, axis=2)
            # Now collapse results across k memories
            attention_results = sub_weights[:, :, None] * attention_results
            attention_results = jnp.sum(attention_results, axis=1)
            return attention_results

        # vmap across batch
        batch_within_memory_attention = hk_vmap(_within_memory_attention,
                                                in_axes=0,
                                                split_rng=False)
        # Memory contents are treated as constants: no gradient flows into them.
        outputs = batch_within_memory_attention(
            queries, jax.lax.stop_gradient(augmented_contents), weights,
            top_k_indices)

        return outputs
def test_bayesianpca_spec_and_specandphot():
    """Checks the layout of a spec-and-phot batch drawn from freshly saved fake data."""
    n_obj = 122
    n_pix_sed = 100
    n_pix_spec = 47
    n_pix_phot = 5
    n_pix_transfer = 50

    # Write the fake data set, then load it back through the regular constructor.
    DataPipeline.save_fake_data(
        n_obj, n_pix_sed, n_pix_spec, n_pix_phot, n_pix_transfer
    )
    pipeline = DataPipeline("data/fake/fake_")

    batchsize = 20
    # The batch is a 16-tuple; entries that are not checked get a leading underscore.
    (
        start_index,
        batch_len,
        _index_wave,
        _index_transfer_redshift,
        spec,
        spec_invvar,
        spec_loginvvar,
        _specphotscaling,
        phot,
        phot_invvar,
        phot_loginvvar,
        redshifts,
        transferfunctions,
        _index_wave_ext,
        _interprightindices,
        _interpweights,
    ) = pipeline.next_batch_specandphot(pipeline.indices, batchsize)

    assert batch_len == batchsize
    assert start_index == 0

    expected_shapes = [
        (spec, (batch_len, n_pix_spec)),
        (spec_invvar, (batch_len, n_pix_spec)),
        (spec_loginvvar, (batch_len, n_pix_spec)),
        (phot, (batch_len, n_pix_phot)),
        (phot_invvar, (batch_len, n_pix_phot)),
        (phot_loginvvar, (batch_len, n_pix_phot)),
        (redshifts, (batch_len,)),
        (transferfunctions, (batch_len, n_pix_sed, n_pix_phot)),
    ]
    for array, shape in expected_shapes:
        assert_shape(array, shape)

    n_components = 3
    pcacomponents_speconly = jax.random.normal(key, (n_components, n_pix_sed))
예제 #24
0
def measure_valued_estimation_std(function: Callable[[chex.Array], float],
                                  dist: Any,
                                  rng: chex.PRNGKey,
                                  num_samples: int,
                                  coupling: bool = True) -> chex.Array:
    """Measure valued grads of a Gaussian expectation of `function` wrt the std.

  For every sample and every dimension, the estimate is the difference of
  `function` evaluated at a "positive" and a "negative" perturbation of that
  single dimension, divided by the std of that dimension.

  Args:
    function: Function f(x) for which to estimate grads_{std} E_dist f(x).
      It takes one argument (a sample from the distribution) and returns a
      floating point value.
    dist: a distribution exposing `params` (unpacked as mean and log std) and
      on which we can call `sample`.
    rng: a PRNGKey key.
    num_samples: Int, the number of samples used to compute the grads.
    coupling: A boolean. Whether or not to use coupling for the positive and
      negative samples. Recommended: True, as this reduces variance.

  Returns:
    A `num_samples x D` array with one gradient estimate per sample. Its mean
    can be used to update the scale parameter; the full array can be used to
    assess estimator variance.
  """
    mean, log_std = dist.params
    std = jnp.exp(log_std)

    base_samples = dist.sample((num_samples, ), seed=rng)
    noise_shape = base_samples.shape

    pos_rng, neg_rng = jax.random.split(rng)

    # Positive perturbations are double-sided Maxwell draws. Negative ones are
    # either a uniform rescaling of the positive draws (coupled) or
    # independent standard normal draws.
    pos_noise = jax.random.double_sided_maxwell(pos_rng,
                                                loc=0.0,
                                                scale=1.0,
                                                shape=noise_shape)
    if coupling:
        neg_noise = jax.random.uniform(neg_rng, noise_shape) * pos_noise
    else:
        neg_noise = jax.random.normal(neg_rng, noise_shape)

    # N x D x D: each base sample repeated D times, so that row d can have its
    # d-th coordinate replaced by the perturbed value.
    tiled_samples = utils.tile_second_to_last_dim(base_samples)

    # Maps `function` over the batch and over the D perturbed copies: N x D out.
    batched_function = jax.vmap(jax.vmap(function, 1, 0))

    def _evaluate(noise):
        # N x D perturbed coordinates written onto the diagonals of the copies.
        perturbed = utils.set_diags(tiled_samples, mean + std * noise)
        return batched_function(perturbed)

    # The normalising constant for the scale is the std itself (shape D),
    # broadcast over the sample axis.
    grads = (_evaluate(pos_noise) - _evaluate(neg_noise)) / std

    chex.assert_shape(grads, (num_samples, ) + std.shape)
    return grads
def test_bayesianpca_photonly():
    """Checks the layout of a photometry-only batch drawn from freshly saved fake data."""
    n_obj = 122
    n_pix_sed = 100
    n_pix_spec = 47
    n_pix_phot = 5
    n_pix_transfer = 50

    # Write the fake data set, then load it back in photometry-only mode.
    DataPipeline.save_fake_data(
        n_obj, n_pix_sed, n_pix_spec, n_pix_phot, n_pix_transfer
    )
    pipeline = DataPipeline("data/fake/fake_", phot=True, spec=False)

    batchsize = 20
    # The batch is a 9-tuple: start index, batch size, then the arrays.
    (
        start_index,
        batch_len,
        phot,
        phot_invvar,
        phot_loginvvar,
        redshifts,
        transferfunctions,
        interprightindices_transfer,
        interpweights_transfer,
    ) = pipeline.next_batch_photonly(pipeline.indices, batchsize)

    assert batch_len == batchsize
    assert start_index == 0

    expected_shapes = [
        (phot, (batch_len, n_pix_phot)),
        (phot_invvar, (batch_len, n_pix_phot)),
        (phot_loginvvar, (batch_len, n_pix_phot)),
        (redshifts, (batch_len,)),
        (transferfunctions, (n_pix_transfer, n_pix_sed, n_pix_phot)),
        (interprightindices_transfer, (batch_len,)),
        (interpweights_transfer, (batch_len,)),
    ]
    for array, shape in expected_shapes:
        assert_shape(array, shape)
 def test_logits_shape(self, img_resolution, num_classes, stage_sizes):
     """Checks that the model output has shape (batch, num_classes) for square RGB inputs."""
     batch = 2
     net = BoTNet(num_classes=num_classes, stage_sizes=stage_sizes)
     init_rngs = {"params": random.PRNGKey(0)}
     inputs = jnp.ones((batch, img_resolution, img_resolution, 3))
     # init_with_output returns a pair; only the model output is checked.
     logits, _ = net.init_with_output(init_rngs, inputs, is_training=True)
     chex.assert_shape(logits, (batch, num_classes))
예제 #27
0
def main(argv):
  """Trains Prioritized DQN agent on Atari.

  Builds the environment, network, prioritized replay buffer, optimizer and
  the train/eval agents from `FLAGS`, then alternates a training phase and an
  evaluation phase for each iteration, logging one row of statistics per
  iteration to the log and to the results writer.

  Args:
    argv: Command-line arguments; unused (deleted immediately), all
      configuration is read from `FLAGS`.
  """
  del argv
  logging.info('Prioritized DQN on Atari on %s.',
               jax.lib.xla_bridge.get_backend().platform)
  # A single NumPy RandomState seeded from FLAGS.seed is the source of every
  # seed below (JAX key, environment seeds, replay sampling), so the order in
  # which it is drawn from matters for reproducibility.
  random_state = np.random.RandomState(FLAGS.seed)
  rng_key = jax.random.PRNGKey(
      random_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64))

  # Results go to a CSV file when a path is configured, otherwise to a
  # NullWriter (presumably a no-op sink -- confirm in `parts`).
  if FLAGS.results_csv_path:
    writer = parts.CsvWriter(FLAGS.results_csv_path)
  else:
    writer = parts.NullWriter()

  def environment_builder():
    """Creates Atari environment."""
    # Each call draws two fresh seeds from `random_state`: one for the
    # environment itself and one for the random no-op wrapper.
    env = gym_atari.GymAtari(
        FLAGS.environment_name, seed=random_state.randint(1, 2**32))
    return gym_atari.RandomNoopsEnvironmentWrapper(
        env,
        min_noop_steps=1,
        max_noop_steps=30,
        seed=random_state.randint(1, 2**32),
    )

  env = environment_builder()

  logging.info('Environment: %s', FLAGS.environment_name)
  logging.info('Action spec: %s', env.action_spec())
  logging.info('Observation spec: %s', env.observation_spec())
  # The network's output size is tied to the number of discrete actions.
  num_actions = env.action_spec().num_values
  network_fn = networks.double_dqn_atari_network(num_actions)
  network = hk.transform(network_fn)

  def preprocessor_builder():
    # Returns a new preprocessor on every call; it is called three times below
    # (sample input, train agent, eval agent).
    return processors.atari(
        additional_discount=FLAGS.additional_discount,
        max_abs_reward=FLAGS.max_abs_reward,
        resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
        num_action_repeats=FLAGS.num_action_repeats,
        num_pooled_frames=2,
        zero_discount_on_life_loss=True,
        num_stacked_frames=FLAGS.num_stacked_frames,
        grayscaling=True,
    )

  # Create sample network input from sample preprocessor output.
  sample_processed_timestep = preprocessor_builder()(env.reset())
  # typing.cast only informs the type checker; it does nothing at runtime.
  sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                          sample_processed_timestep)
  sample_network_input = sample_processed_timestep.observation
  # Sanity check: the processed observation must be
  # (height, width, stacked frames) as configured in FLAGS.
  chex.assert_shape(sample_network_input,
                    (FLAGS.environment_height, FLAGS.environment_width,
                     FLAGS.num_stacked_frames))

  # Linear schedule for the exploration epsilon, from the begin value to the
  # end value. `begin_t` and `decay_steps` are both computed from FLAGS
  # (replay capacity fraction / action repeats, and total training frames).
  exploration_epsilon_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity *
                  FLAGS.num_action_repeats),
      decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                      FLAGS.num_iterations * FLAGS.num_train_frames),
      begin_value=FLAGS.exploration_epsilon_begin_value,
      end_value=FLAGS.exploration_epsilon_end_value)

  # Note the t in the replay is not exactly aligned with the agent t.
  importance_sampling_exponent_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity),
      end_t=(FLAGS.num_iterations *
             int(FLAGS.num_train_frames / FLAGS.num_action_repeats)),
      begin_value=FLAGS.importance_sampling_exponent_begin_value,
      end_value=FLAGS.importance_sampling_exponent_end_value)

  # Optionally compress the two state fields of each transition when it is
  # stored in replay and uncompress them when it is read back.
  if FLAGS.compress_state:

    def encoder(transition):
      return transition._replace(
          s_tm1=replay_lib.compress_array(transition.s_tm1),
          s_t=replay_lib.compress_array(transition.s_t))

    def decoder(transition):
      return transition._replace(
          s_tm1=replay_lib.uncompress_array(transition.s_tm1),
          s_t=replay_lib.uncompress_array(transition.s_t))
  else:
    encoder = None
    decoder = None

  # Template transition with all fields set to None; passed to the replay as
  # its structure argument.
  replay_structure = replay_lib.Transition(
      s_tm1=None,
      a_tm1=None,
      r_t=None,
      discount_t=None,
      s_t=None,
  )

  replay = replay_lib.PrioritizedTransitionReplay(
      FLAGS.replay_capacity, replay_structure, FLAGS.priority_exponent,
      importance_sampling_exponent_schedule, FLAGS.uniform_sample_probability,
      FLAGS.normalize_weights, random_state, encoder, decoder)

  optimizer = optax.rmsprop(
      learning_rate=FLAGS.learning_rate,
      decay=0.95,
      eps=FLAGS.optimizer_epsilon,
      centered=True,
  )

  # Independent JAX keys for the training and the evaluation agent.
  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  train_agent = agent.PrioritizedDqn(
      preprocessor=preprocessor_builder(),
      sample_network_input=sample_network_input,
      network=network,
      optimizer=optimizer,
      transition_accumulator=replay_lib.TransitionAccumulator(),
      replay=replay,
      batch_size=FLAGS.batch_size,
      exploration_epsilon=exploration_epsilon_schedule,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      grad_error_bound=FLAGS.grad_error_bound,
      rng_key=train_rng_key,
  )
  # The evaluation agent shares the network definition; its parameters are
  # copied from the training agent at the start of every evaluation phase.
  eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key,
  )

  # Set up checkpointing.
  # NOTE(review): NullCheckpoint presumably never saves or restores anything;
  # confirm in `parts` before relying on resumption.
  checkpoint = parts.NullCheckpoint()

  # Everything that must survive a restart is attached to `state`.
  state = checkpoint.state
  state.iteration = 0
  state.train_agent = train_agent
  state.eval_agent = eval_agent
  state.random_state = random_state
  state.writer = writer
  if checkpoint.can_be_restored():
    checkpoint.restore()

  # The bound is inclusive, so iterations 0..FLAGS.num_iterations all run.
  while state.iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if preempted.
    env = environment_builder()

    logging.info('Training iteration %d.', state.iteration)
    train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode)
    # Iteration 0 consumes zero training frames, so it only evaluates the
    # freshly initialised agent.
    num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
    train_seq_truncated = itertools.islice(train_seq, num_train_frames)
    train_trackers = parts.make_default_trackers(train_agent)
    train_stats = parts.generate_statistics(train_trackers, train_seq_truncated)

    logging.info('Evaluation iteration %d.', state.iteration)
    # Evaluate with the training agent's current online parameters.
    eval_agent.network_params = train_agent.online_params
    eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode)
    eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
    eval_trackers = parts.make_default_trackers(eval_agent)
    eval_stats = parts.generate_statistics(eval_trackers, eval_seq_truncated)

    # Logging and checkpointing.
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats['episode_return'])
    # Cap the normalized score at 1.0.
    capped_human_normalized_score = np.amin([1., human_normalized_score])
    # Each entry is (column name, value, printf-style format for the log line).
    log_output = [
        ('iteration', state.iteration, '%3d'),
        ('frame', state.iteration * FLAGS.num_train_frames, '%5d'),
        ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'),
        ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
        ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
        ('train_num_episodes', train_stats['num_episodes'], '%3d'),
        ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
        ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
        ('train_exploration_epsilon', train_agent.exploration_epsilon, '%.3f'),
        ('train_state_value', train_stats['state_value'], '%.3f'),
        ('importance_sampling_exponent',
         train_agent.importance_sampling_exponent, '%.3f'),
        ('max_seen_priority', train_agent.max_seen_priority, '%.3f'),
        ('normalized_return', human_normalized_score, '%.3f'),
        ('capped_normalized_return', capped_human_normalized_score, '%.3f'),
        ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
    ]
    log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    # The writer receives the raw values (without the format strings), in the
    # same column order as the log line.
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    state.iteration += 1
    checkpoint.save()

  writer.close()
예제 #28
0
    def apply(self, x, config, num_classes, train=True):
        """Builds the patch-selecting classifier and returns logits and stats.

        A shared ResNet50 extracts features from the full image, a proposal
        network scores candidate regions on those (gradient-stopped) features,
        a perturbed top-k picks ``config.k`` regions, and the same ResNet50
        is applied again to the aggregated 224x224 region images. Three dense
        heads then classify from the whole image, from each part, and from
        their sum.

        Args:
            x: Input batch. Only ``x.shape[0]`` (batch) and ``x.shape[3]``
                (channels) are read here, so a 4-D, channels-last layout is
                assumed -- TODO confirm against the caller.
            config: Object providing ``k``, ``ptopk_sigma``,
                ``ptopk_num_samples`` and ``communication``.
            num_classes: Output width of every dense head.
            train: Forwarded to the sub-networks; dropout is active only when
                this is true.

        Returns:
            A pair ``(all_logits, stats)``. ``all_logits`` is a dict with keys
            ``"raw_logits"``, ``"concat_logits"`` and ``"part_logits"``;
            ``stats`` is a dict of intermediate values, including whatever the
            proposal network reports.
        """
        batch_size, channels = x.shape[0], x.shape[3]
        num_parts = config.k
        sigma = config.ptopk_sigma
        num_samples = config.ptopk_num_samples

        # The base sigma is rescaled by a stored multiplier that starts at 1.
        sigma_state = self.state("sigma_mutiplier",
                                 shape=(),
                                 initializer=nn.initializers.ones)
        sigma *= sigma_state.value

        stats = dict(x=x, sigma=sigma)

        # One shared backbone, used for the full image and for the parts.
        backbone = models.ResNet50.shared(train=train, name="ResNet_0")
        image_features = backbone(x)

        # The proposal network must not back-propagate into the backbone.
        detached_features = jax.lax.stop_gradient(image_features)
        communication = Communication(config.communication)
        rpn_scores, rpn_stats = ProposalNet(detached_features,
                                            communication=communication,
                                            train=train)
        stats.update(rpn_stats)

        # rpn_scores is a list of score maps. Remember each map's shape so the
        # flat top-k indicators can be cut back into per-map weights below.
        score_shapes = [score.shape for score in rpn_scores]
        flat_scores = jnp.concatenate(
            [jnp.reshape(score, [batch_size, -1]) for score in rpn_scores],
            axis=1)
        indicators = sample_patches.select_patches_perturbed_topk(
            flat_scores, k=num_parts, sigma=sigma, num_samples=num_samples)
        indicators = jnp.transpose(indicators, [0, 2, 1])

        weights = []
        start = 0
        for shape in score_shapes:
            height, width = shape[1], shape[2]
            stop = start + height * width
            chunk = indicators[:, :, start:stop]
            weights.append(
                jnp.reshape(chunk, [batch_size, num_parts, height, width]))
            start = stop
        # Every indicator column must have been assigned to exactly one map.
        chex.assert_equal(start, indicators.shape[-1])

        part_imgs = weighted_anchor_aggregator(x, weights)
        chex.assert_shape(part_imgs,
                          (batch_size * num_parts, 224, 224, channels))
        stats["part_imgs"] = jnp.reshape(
            part_imgs, [batch_size, num_parts * 224, 224, channels])

        # Global-average-pool the spatial dims of the part features.
        part_features = jnp.mean(backbone(part_imgs), axis=[1, 2])

        # Dropout on the features from the parts ...
        flat_part_features = jnp.reshape(part_features,
                                         [batch_size * num_parts, 2048])
        part_features = nn.dropout(flat_part_features,
                                   0.5,
                                   deterministic=not train,
                                   rng=nn.make_rng())
        # ... and on the pooled features from the whole image.
        pooled_image_features = jnp.reshape(
            jnp.mean(image_features, axis=[1, 2]), [batch_size, -1])
        features = nn.dropout(pooled_image_features,
                              0.5,
                              deterministic=not train,
                              rng=nn.make_rng())

        # Average the part features per image, add the whole-image features,
        # and classify from the sum as well as from each source on its own.
        mean_part_features = jnp.mean(
            jnp.reshape(part_features, [batch_size, num_parts, 2048]), axis=1)
        concat_out = mean_part_features + features
        concat_logits = nn.Dense(concat_out, num_classes)
        raw_logits = nn.Dense(features, num_classes)
        part_logits = jnp.reshape(nn.Dense(part_features, num_classes),
                                  [batch_size, num_parts, -1])

        all_logits = {
            "raw_logits": raw_logits,
            "concat_logits": concat_logits,
            "part_logits": part_logits,
        }

        # Entropy of the softmaxed raw scores, kept for entropy regularization.
        score_probs = jax.nn.softmax(stats["raw_scores"])
        entropy = jax.scipy.special.entr(score_probs).sum(axis=1)
        stats["rpn_scores_entropy"] = entropy.mean(axis=0)
        return all_logits, stats
예제 #29
0
    def load_spectrophotometry(
        self,
        input_dir="./",
        write_subset=False,
        use_subset=False,
        subsampling=1,
        spec=True,
        phot=True,
    ):
        """Load the spectro-photometric arrays from ``.npy`` files onto ``self``.

        File names are built by plain string concatenation
        ``input_dir + name + extension``; ``input_dir`` is therefore a path
        *prefix* and must include its own trailing separator when it names a
        directory.

        Args:
            input_dir: Path prefix prepended to every file name.
            write_subset: If true, truncate every loaded per-object array to
                its first 50000 entries (in memory) and write the truncated
                arrays back with the ``"2.npy"`` extension.
            use_subset: If true, read the per-object arrays from the
                ``"2.npy"`` files instead of the ``".npy"`` ones. The grids
                (``lamgrid``, ``lam_phot_*``, ``transferfunctions*``) are
                always read from ``".npy"``.
            subsampling: Integer stride. When greater than 1 the wavelength
                grid, the loaded arrays and the index/weight arrays are
                thinned by this factor.
            spec: Whether to load the spectroscopic arrays.
            phot: Whether to load the photometric arrays.

        Raises:
            AssertionError: If the transfer-function array does not match the
                redshift grid or the wavelength grid in size.
        """
        suffix = "2.npy" if use_subset else ".npy"
        self.input_dir = input_dir

        def load(name, extension=".npy"):
            # Deliberately string concatenation: input_dir may be a file-name
            # prefix rather than a directory.
            return onp.load(self.input_dir + name + extension)

        self.lamgrid = load("lamgrid")
        self.lam_phot_eff = load("lam_phot_eff")
        self.lam_phot_size_eff = load("lam_phot_size_eff")

        self.redshifts = load("redshifts", suffix)
        n_obj = self.redshifts.size
        self.n_obj = n_obj
        assert_shape(self.redshifts, (n_obj, ))

        if phot:
            self.transferfunctions = load("transferfunctions") * 1e-16
            self.transferfunctions_zgrid = load("transferfunctions_zgrid")

            assert self.transferfunctions.shape[
                0] == self.transferfunctions_zgrid.size
            assert self.transferfunctions.shape[1] == self.lamgrid.size

            self.index_transfer_redshift = load("index_transfer_redshift",
                                                suffix)
            self.interprightindices_transfer = load(
                "interprightindices_transfer", suffix)
            self.interpweights_transfer = load("interpweights_transfer",
                                               suffix)
            self.phot = load("phot", suffix)
            self.phot_invvar = load("phot_invvar", suffix)
            assert_shape(self.index_transfer_redshift, (n_obj, ))
            assert_shape(self.interprightindices_transfer, (n_obj, ))
            assert_shape(self.interpweights_transfer, (n_obj, ))

            n_pix_phot = self.phot.shape[1]
            assert_shape(self.phot, (n_obj, n_pix_phot))
            assert_shape(self.phot_invvar, (n_obj, n_pix_phot))
            self.n_pix_phot = n_pix_phot

        if spec:
            self.chi2s_sdss = load("chi2s_sdss", suffix)
            self.lamspec_waveoffset = int(load("lamspec_waveoffset", suffix))
            self.index_wave = load("index_wave", suffix)
            self.interprightindices = load("interprightindices", suffix)
            self.interpweights = load("interpweights", suffix)

            self.specmod_sdss = load("spec_mod", suffix)
            self.spec = load("spec", suffix)
            self.spec_invvar = load("spec_invvar", suffix)

            n_pix_spec = self.spec.shape[1]
            self.n_pix_spec = n_pix_spec

            assert_shape(self.chi2s_sdss, (n_obj, ))
            assert_shape(self.index_wave, (n_obj, ))
            assert_shape(self.spec, (n_obj, n_pix_spec))
            assert_shape(self.specmod_sdss, (n_obj, n_pix_spec))
            assert_shape(self.spec_invvar, (n_obj, n_pix_spec))

        if write_subset:
            M = 50000
            subset_suffix = "2.npy"

            def truncate_and_save(attribute, file_stem):
                # Keep the first M objects in memory and on disk so that all
                # per-object arrays stay the same length as `redshifts`.
                truncated = getattr(self, attribute)[:M]
                setattr(self, attribute, truncated)
                onp.save(self.input_dir + file_stem + subset_suffix,
                         truncated)

            truncate_and_save("redshifts", "redshifts")
            self.n_obj = self.redshifts.size

            # Only touch the arrays that were actually loaded above.
            if phot:
                # These three are 1-D (one entry per object), so they are
                # sliced on their single axis.
                truncate_and_save("index_transfer_redshift",
                                  "index_transfer_redshift")
                truncate_and_save("interprightindices_transfer",
                                  "interprightindices_transfer")
                truncate_and_save("interpweights_transfer",
                                  "interpweights_transfer")
                truncate_and_save("phot", "phot")
                truncate_and_save("phot_invvar", "phot_invvar")

            if spec:
                truncate_and_save("index_wave", "index_wave")
                truncate_and_save("chi2s_sdss", "chi2s_sdss")
                truncate_and_save("spec", "spec")
                truncate_and_save("spec_invvar", "spec_invvar")
                truncate_and_save("specmod_sdss", "spec_mod")
                # NOTE(review): the use_subset path also reads
                # "interprightindices", "interpweights" and
                # "lamspec_waveoffset" with the "2.npy" extension, but they
                # are not written here (their per-object layout is not
                # established in this method) -- confirm how those subset
                # files are produced.

        if subsampling > 1:
            step = subsampling
            self.lamgrid = self.lamgrid[::step]

            if phot:
                # Thin both the redshift axis and the wavelength axis.
                self.transferfunctions = self.transferfunctions[::step, ::
                                                                step, :]
                self.transferfunctions_zgrid = (
                    self.transferfunctions_zgrid[::step])
                self.index_transfer_redshift = (
                    self.index_transfer_redshift // step)
                # NOTE(review): the original author flagged the next two
                # rescalings with "is it correct?" -- still unverified.
                self.interprightindices_transfer = (
                    self.interprightindices_transfer // step)
                self.interpweights_transfer = (self.interpweights_transfer /
                                               step)

            if spec:
                self.lamspec_waveoffset = self.lamspec_waveoffset // step
                self.spec = self.spec[:, ::step]
                self.specmod_sdss = self.specmod_sdss[:, ::step]
                self.spec_invvar = self.spec_invvar[:, ::step]
                self.index_wave = self.index_wave // step
                self.interprightindices = (
                    self.interprightindices[:, ::step] // step)
                self.interpweights = (self.interpweights[:, ::step] / step
                                      )  # dilution
                # NOTE(review): self.n_pix_spec was set before subsampling
                # and is not refreshed here (unchanged from the original).
예제 #30
0
  def epipolar_projection(
      self,
      key,  # pylint: disable=unused-argument
      target_rays,
      ref_worldtocamera,
      intrinsic_matrix,
      image_height,
      image_width,
      min_depth,
      max_depth,
  ):
    """Projects target rays onto the reference views at several depths.

    Each ray is sampled at `self.num_samples` depths spaced linearly between
    `min_depth` and `max_depth` (often the near and far planes of the camera),
    and the resulting world points are projected into the reference cameras,
    i.e. onto the epipolar line of each ray. The original author notes this
    was visually verified in colab.

    Args:
      key: prngkey (unused).
      target_rays: The rays to project onto nearby cameras, usually the rays
        being rendered. Provides `origins` and `directions`; documented by the
        original author as shape (#bs, rays, 3).
      ref_worldtocamera: The world-to-camera matrices of the nearby cameras to
        project onto.
      intrinsic_matrix: (1, 3, 4), the intrinsic matrix for the dataset.
      image_height: image height.
      image_width: image width.
      min_depth: min depth of projection.
      max_depth: max depth of projection.

    Returns:
      pcoords: (#near_cam, batch_size, num_projections, 2), clipped to lie
        inside the image.
      valid_proj_mask: (#near_cam, batch_size, num_projections) specifying
        which of the projections are valid, i.e. in front of the camera and
        within the image bounds.
      wcoords: (batch_size, num_projections, 3)
    """
    # Only a single shared camera is supported, so the intrinsics must be one
    # 3x4 matrix.
    chex.assert_shape(intrinsic_matrix, (1, 3, 4))

    depths = jnp.linspace(min_depth, max_depth, self.num_samples)

    # World-space point for every (ray, depth) pair:
    # (#rays, num_projections, 3).
    ray_origins = target_rays.origins[:, None]
    ray_directions = target_rays.directions[:, None]
    points = ray_origins + ray_directions * depths[..., None]

    # Append a unit coordinate to get homogeneous points:
    # (#rays, num_samples, 3) -> (#rays, num_samples, 4).
    unit_column = jnp.ones_like(points[..., 0:1])
    wcoords = jnp.concatenate([points, unit_column], axis=-1)

    pcoords, in_front_mask = self.project2camera(wcoords, ref_worldtocamera,
                                                 intrinsic_matrix)

    # The in-image test uses the unclipped coordinates.
    in_image_mask = self.inside_image(pcoords, image_height, image_width)

    # Clip so the coordinates can index the image: the first coordinate is
    # bounded by the height, the remaining one(s) by the width.
    first_coord = jnp.clip(pcoords[..., 0:1], 0, image_height - 1)
    other_coords = jnp.clip(pcoords[..., 1:], 0, image_width - 1)
    pcoords = jnp.concatenate([first_coord, other_coords], axis=-1)

    # Valid means in front of the camera and inside the image.
    valid_proj_mask = in_front_mask * in_image_mask

    return pcoords, valid_proj_mask, wcoords[..., :3]