def test_mixturegausscdf_bijector_shape(n_samples, n_features, n_components):
    """Check that MixtureGaussianCDF preserves shapes in both directions.

    The forward call must return an output shaped like the input plus one
    log|det J| value per sample; ``model.inverse`` must map the output back to
    an array shaped like the input.
    """
    x = objax.random.normal((n_samples, n_features), generator=generator)

    # create layer
    model = MixtureGaussianCDF(n_features, n_components)

    # forward transformation
    z, log_abs_det = model(x)

    # checks
    chex.assert_equal_shape([z, x])
    chex.assert_shape(log_abs_det, (n_samples,))

    # inverse transformation
    x_approx = model.inverse(z)

    # checks
    chex.assert_equal_shape([x, x_approx])
def testNonPolynomialFunctionConsistencyWithPathwise(
    self, effective_mean, effective_log_scale, function, coupling):
  """Measure-valued and pathwise gradient estimates of `function` must agree.

  Both estimators are run under `utils.multi_normal` parameterised by
  (mean, log_scale). Each returns per-sample Jacobians of shape
  (num_samples, data_dims); these are averaged over samples and the two
  resulting gradient estimates are compared with loose tolerances.
  """
  n_draws = 10**5
  measure_key, pathwise_key = jax.random.split(jax.random.PRNGKey(1))
  loc = jnp.array(effective_mean, dtype=jnp.float32)
  log_scale = jnp.array(effective_log_scale, dtype=jnp.float32)
  dims = len(effective_mean)

  def averaged(jacobians):
    """Shape-check the (mean, log_scale) Jacobians and average over samples."""
    grads = []
    for jac in (jacobians[0], jacobians[1]):
      chex.assert_shape(jac, (n_draws, dims))
      grads.append(np.mean(jac, axis=0))
    return grads

  mv_estimator = _measure_valued_variant(self.variant)
  mv_mean_grads, mv_log_scale_grads = averaged(
      mv_estimator(function, [loc, log_scale], utils.multi_normal,
                   measure_key, n_draws, coupling))

  pw_estimator = _estimator_variant(self.variant, sge.pathwise_jacobians)
  pw_mean_grads, pw_log_scale_grads = averaged(
      pw_estimator(function, [loc, log_scale], utils.multi_normal,
                   pathwise_key, n_draws))

  _assert_equal(pw_mean_grads, mv_mean_grads, rtol=5e-1, atol=1e-1)
  _assert_equal(pw_log_scale_grads, mv_log_scale_grads, rtol=5e-1, atol=1e-1)
def test_householder_bijector_shape(n_samples, n_features, n_reflections):
    """HouseHolder layer: forward and inverse outputs keep the input's shape."""
    key_params, key_data = jax.random.split(KEY, 2)
    inputs = jax.random.normal(key_data, shape=(n_samples, n_features))

    # Building the layer yields its parameters plus forward/inverse callables.
    layer = HouseHolder(n_reflections=n_reflections)
    params, forward_fn, inverse_fn = layer(rng=key_params, n_features=n_features)

    # Forward pass: output shaped like the input, one log|det J| per sample.
    outputs, fwd_log_det = forward_fn(params, inputs)
    chex.assert_equal_shape([outputs, inputs])
    chex.assert_shape(fwd_log_det, (n_samples,))

    # Inverse pass: reconstruction shaped like the input.
    reconstructed, _ = inverse_fn(params, outputs)
    chex.assert_equal_shape([reconstructed, inputs])
def loss_fn(online_params, target_params, transitions, rng_key):
  """Mean clipped double-Q L2 loss for a batch of transitions.

  The online network scores s_tm1 and s_t; the target network scores s_t.
  Each network call gets its own PRNG key.
  """
  # Four keys are split; the first is discarded.
  _, key_tm1, key_t, key_target = jax.random.split(rng_key, 4)
  q_tm1 = network.apply(online_params, key_tm1, transitions.s_tm1).q_values
  q_t = network.apply(online_params, key_t, transitions.s_t).q_values
  q_target_t = network.apply(
      target_params, key_target, transitions.s_t).q_values

  td_errors = _batch_double_q_learning(
      q_tm1,
      transitions.a_tm1,
      transitions.r_t,
      transitions.discount_t,
      q_target_t,
      q_t,
  )
  # rlax.clip_gradient bounds the gradient flowing back through the TD errors.
  clipped_errors = rlax.clip_gradient(
      td_errors, -grad_error_bound, grad_error_bound)
  per_example = rlax.l2_loss(clipped_errors)
  chex.assert_shape(per_example, (self._batch_size,))
  return jnp.mean(per_example)
def loss_fn(online_params, target_params, transitions, weights, rng_key):
  """Weighted categorical double-Q loss for a batch of transitions.

  Returns:
    A pair (mean of losses * weights, unweighted per-example losses).
  """
  # Four keys are split; the first is discarded.
  _, key_tm1, key_t, key_target = jax.random.split(rng_key, 4)
  logits_q_tm1 = network.apply(
      online_params, key_tm1, transitions.s_tm1).q_logits
  q_t = network.apply(online_params, key_t, transitions.s_t).q_values
  logits_q_target_t = network.apply(
      target_params, key_target, transitions.s_t).q_logits

  losses = _batch_categorical_double_q_learning(
      support,
      logits_q_tm1,
      transitions.a_tm1,
      transitions.r_t,
      transitions.discount_t,
      support,
      logits_q_target_t,
      q_t,
  )
  weighted_mean = jnp.mean(losses * weights)
  chex.assert_shape((losses, weights), (self._batch_size, ))
  return weighted_mean, losses
def _entropy_scalar(
    total_count: int, probs: Array,
    log_of_probs: Array) -> Union[jnp.float32, jnp.float64]:
  """Calculates the entropy for a Multinomial with integer `total_count`.

  With n = total_count, categories i and counts x = 0..n, this evaluates

    H = -log(n!) - n * sum_i p_i log p_i
        + sum_x C(n, x) log(x!) * sum_i p_i^x (1 - p_i)^(n - x).
  """
  n = total_count
  counts = jnp.arange(n + 1, dtype=probs.dtype)  # x = 0, 1, ..., n
  log_factorials = lax.lgamma(counts + 1)  # log(x!)
  log_n_fact = log_factorials[..., -1]  # log(n!)
  # log C(n, x) = log n! - log x! - log (n - x)!, where log (n - x)! is the
  # log-factorial table read backwards.
  log_binom = (log_n_fact[..., None] - log_factorials
               - jnp.flip(log_factorials, axis=-1))
  binom = jnp.round(jnp.exp(log_binom))
  chex.assert_shape(binom, (n + 1, ))

  p_pow = math.power_no_nan(probs[..., None], counts)
  q_pow = math.power_no_nan(1. - probs[..., None], n - counts)
  chex.assert_shape(p_pow, (probs.shape[-1], n + 1))
  chex.assert_shape(q_pow, (probs.shape[-1], n + 1))
  # Sum the per-category binomial terms over categories -> (n + 1,).
  marginal = jnp.sum(p_pow * q_pow, axis=-2)
  chex.assert_shape(marginal, (n + 1, ))
  comb_term = jnp.sum(binom * log_factorials * marginal, axis=-1)
  chex.assert_shape(comb_term, ())

  # Term involving the probabilities: n * sum_i p_i log p_i.
  n_probs_factor = jnp.sum(
      n * math.multiply_no_nan(log_of_probs, probs), axis=-1)
  return -log_n_fact - n_probs_factor + comb_term
def testLinearFunction(self, effective_mean, effective_log_scale, estimator):
  """For f(x) = sum(x): expected d/dmean is 1 and d/dlog_scale is 0."""
  dims = 3
  n_draws = _estimator_to_num_samples[estimator]
  key = jax.random.PRNGKey(1)
  loc = effective_mean * _ones(dims)
  log_scale = effective_log_scale * _ones(dims)

  estimate = _estimator_variant(self.variant, estimator)
  jacobians = estimate(
      np.sum, [loc, log_scale], utils.multi_normal, key, n_draws)

  mean_jac = jacobians[0]
  chex.assert_shape(mean_jac, (n_draws, dims))
  mean_grads = np.mean(mean_jac, axis=0)

  log_scale_jac = jacobians[1]
  chex.assert_shape(log_scale_jac, (n_draws, dims))
  log_scale_grads = np.mean(log_scale_jac, axis=0)

  _assert_equal(mean_grads, np.ones(dims, dtype=np.float32))
  _assert_equal(log_scale_grads, np.zeros(dims, dtype=np.float32))
def test_logit_shape(n_samples, n_features):
    """InverseGaussCDF layer keeps the input's shape in both directions."""
    key_params, key_data = jax.random.split(KEY, 2)
    # Inputs are drawn uniformly (jax default range [0, 1)).
    inputs = jax.random.uniform(key_data, shape=(n_samples, n_features))

    # Building the layer yields its parameters plus forward/inverse callables.
    layer = InverseGaussCDF(eps=1e-5)
    params, forward_fn, inverse_fn = layer(rng=key_params, n_features=n_features)

    # Forward pass: output shaped like the input, one log|det J| per sample.
    outputs, fwd_log_det = forward_fn(params, inputs)
    chex.assert_equal_shape([outputs, inputs])
    chex.assert_shape(fwd_log_det, (n_samples,))

    # Inverse pass: reconstruction shaped like the input.
    reconstructed, _ = inverse_fn(params, outputs)
    chex.assert_equal_shape([reconstructed, inputs])
def test_adam(self):
  """Fields of an Adam optimizer state can be extracted by name."""
  config = ConfigDict({
      'optimizer': 'adam',
      'l2_decay_factor': None,
      'batch_size': 50,
      'total_accumulated_batch_size': 100,  # Use gradient accumulation.
      'opt_hparams': {
          'beta1': 0.9,
          'beta2': 0.999,
          'epsilon': 1e-7,
          'weight_decay': 0.0,
      },
  })
  init_fn, _ = optimizers.get_optimizer(config)
  state = init_fn({'foo': jnp.ones(10)})

  # The 'count' field is extractable and integer-typed.
  chex.assert_type(extract_field(state, 'count'), int)
  # The 'nu' and 'mu' accumulators mirror the parameter tree's shapes.
  for field in ('nu', 'mu'):
    chex.assert_shape(extract_field(state, field)['foo'], (10,))
  # Asking for a field that does not exist ("abc") yields None.
  chex.assert_equal(extract_field(state, 'abc'), None)
def test_mixture_logistic_cdf_bijector_shape(n_samples, n_features, n_components):
    """Check that MixtureLogisticCDF preserves shapes in both directions.

    The forward function must return an output shaped like the input plus one
    log|det J| value per sample; the inverse function must map the output back
    to an array shaped like the input.
    """
    params_rng, data_rng = jax.random.split(KEY, 2)
    x = jax.random.normal(data_rng, shape=(n_samples, n_features))

    # create layer
    init_func = MixtureLogisticCDF(n_components=n_components)
    params, forward_f, inverse_f = init_func(rng=params_rng, n_features=n_features)

    # forward transformation
    z, log_abs_det = forward_f(params, x)

    # checks
    chex.assert_equal_shape([z, x])
    chex.assert_shape(log_abs_det, (n_samples,))

    # inverse transformation
    x_approx, log_abs_det = inverse_f(params, z)

    # checks
    chex.assert_equal_shape([x, x_approx])
def critic_loss(
    critic_params: hk.Params,
    generator_params: hk.Params,
    img_batch: ImgBatch,
    latent_batch: LatentBatch,
) -> Tuple[jnp.ndarray, Log]:
    """WGAN critic loss with a gradient-norm penalty.

    Returns ``(loss, log)`` where::

        loss = mean(f(fake)) - mean(f(real)) + 10 * mean((1 - ||df/dx(fake)||)^2)

    and ``log`` records the Wasserstein estimate (``-(mean(f(fake)) -
    mean(f(real)))``) and the penalty. Note that the penalty is evaluated at
    the generated samples themselves, not at real/fake interpolates.
    """
    n = img_batch.shape[0]
    fake = generator.apply(generator_params, latent_batch)

    def critic_fn(images):
        return critic.apply(critic_params, images)

    real_scores = critic_fn(img_batch)
    # A single VJP with an all-ones cotangent gives d f_i / d x_i for every
    # sample at once (assuming the critic scores samples independently).
    fake_scores, vjp_fn = jax.vjp(critic_fn, fake)
    input_grads = vjp_fn(jnp.ones(fake_scores.shape))[0]

    assert_shape(real_scores, (n, 1))
    assert_shape(fake_scores, (n, 1))
    assert_shape(input_grads, img_batch.shape)

    grad_norms = jnp.linalg.norm(input_grads.reshape(n, -1), axis=1)
    penalty_terms = jnp.square(1 - grad_norms)
    assert_shape(penalty_terms, (n, ))

    loss = jnp.mean(fake_scores) - jnp.mean(real_scores)
    gradient_penalty = jnp.mean(penalty_terms)
    log = {"wasserstein": -loss, "gradient_penalty": gradient_penalty}
    return loss + 10 * gradient_penalty, log
def affine_transform(dist_params, scale, shift, value_transform=None):
    """Project an affinely transformed categorical distribution onto the atoms.

    Implements the "Categorical Algorithm" from https://arxiv.org/abs/1707.06887.
    Each atom z_j is mapped to Tz_j = f(scale * f_inv(z_j) + shift) and clipped
    to [Vmin, Vmax]. Its probability mass is then split between the two
    neighbouring atoms, weighted linearly by distance.

    Args:
        dist_params: dict with key 'logits' of shape (batch_size, num_bins).
        scale: scalar or array of shape (batch_size,).
        shift: scalar or array of shape (batch_size,).
        value_transform: optional pair (f, f_inv); identity when None.

    Returns:
        dict with key 'logits' of shape (batch_size, num_bins), holding the log
        of the projected probabilities floored at 1e-16 (avoids log(0)).
    """
    # check inputs
    chex.assert_rank([dist_params['logits'], scale, shift], [2, {0, 1}, {0, 1}])
    p = jax.nn.softmax(dist_params['logits'])
    batch_size = p.shape[0]

    if isscalar(scale):
        scale = jnp.full(shape=(batch_size, ), fill_value=jnp.squeeze(scale))
    if isscalar(shift):
        shift = jnp.full(shape=(batch_size, ), fill_value=jnp.squeeze(shift))

    chex.assert_shape(p, (batch_size, self.num_bins))
    chex.assert_shape([scale, shift], (batch_size, ))

    if value_transform is None:
        f = f_inv = lambda x: x
    else:
        f, f_inv = value_transform

    # variable names correspond to those defined in: https://arxiv.org/abs/1707.06887
    z = self.__atoms
    Vmin, Vmax, Δz = z[0], z[-1], z[1] - z[0]
    Tz = f(jax.vmap(jnp.add)(jnp.outer(scale, f_inv(z)), shift))
    Tz = jnp.clip(Tz, Vmin, Vmax)  # keep values in valid range
    chex.assert_shape(Tz, (batch_size, self.num_bins))

    b = (Tz - Vmin) / Δz  # float in [0, num_bins - 1]
    l = jnp.floor(b).astype('int32')  # noqa: E741  # int in {0, 1, ..., num_bins - 1}
    u = jnp.ceil(b).astype('int32')  # int in {0, 1, ..., num_bins - 1}
    chex.assert_shape([p, b, l, u], (batch_size, self.num_bins))

    # Scatter-add each atom's mass onto its lower/upper neighbours. When b is an
    # integer (l == u), both interpolation weights are zero, so the third update
    # puts the whole mass on l. Equivalent reference loop:
    #   for i in range(batch_size):
    #       for j in range(num_bins):
    #           if l[i, j] == u[i, j]:
    #               m[i, l[i, j]] += p[i, j]
    #           else:
    #               m[i, l[i, j]] += p[i, j] * (u[i, j] - b[i, j])
    #               m[i, u[i, j]] += p[i, j] * (b[i, j] - l[i, j])
    # The `.at[...].add` scatter API is used because jax.ops.index_add does not
    # exist in current JAX. `indices_are_sorted` is deliberately not passed:
    # l/u only ascend along a row when Tz increases with j, which fails e.g.
    # for scale < 0, and a false hint may give wrong results on some backends.
    i = jnp.expand_dims(jnp.arange(batch_size), axis=1)  # batch index
    m = jnp.zeros_like(p)
    m = m.at[i, l].add(p * (u - b))
    m = m.at[i, u].add(p * (b - l))
    m = m.at[i, l].add(p * (l == u))
    return {'logits': jnp.log(jnp.maximum(m, 1e-16))}
def epipolar_projection(self, key, target_rays, ref_worldtocamera,
                        intrinsic_matrix, randomized):
    """Project target rays onto reference views at several depths.

    Each ray is sampled at ``self.num_samples`` depths between
    ``self.min_depth`` and ``self.max_depth``. The resulting 3D points are
    projected into every reference camera, tracing the ray's epipolar line in
    that view. (Visually verified in colab.)

    Args:
      key: PRNG key, only used when ``randomized`` is True.
      target_rays: rays to project onto nearby cameras (often the rays being
        rendered). Provides ``origins``, ``directions`` and ``batch_shape``.
      ref_worldtocamera: (#near_cam, 4, 4) world-to-camera matrices of the
        reference cameras to project onto.
      intrinsic_matrix: (1, 3, 4) intrinsic matrix shared by all views.
      randomized: if True, jitter each depth uniformly within its bin.

    Returns:
      pcoords: (#near_cam, batch_size, num_projections, 2) projected
        coordinates, clipped to the image.
      valid_proj_mask: (#near_cam, batch_size, num_projections) marking
        projections that are in front of the camera and inside the image.
      wcoords: (batch_size, num_projections, 3) world coordinates of the
        sampled points.
    """
    # Only a single intrinsic matrix shared by all views is supported.
    chex.assert_shape(intrinsic_matrix, (1, 3, 4))

    depths = jnp.linspace(self.min_depth, self.max_depth, self.num_samples)
    if randomized:
        # Stratified jitter: draw one depth uniformly inside each bin.
        midpoints = .5 * (depths[Ellipsis, 1:] + depths[Ellipsis, :-1])
        upper = jnp.concatenate([midpoints, depths[Ellipsis, -1:]], -1)
        lower = jnp.concatenate([depths[Ellipsis, :1], midpoints], -1)
        n_rays = target_rays.batch_shape[0]
        u = jax.random.uniform(key, [n_rays, self.num_samples])
        depths = lower + (upper - lower) * u

    # World coordinates of every ray at every depth: (#rays, num_samples, 3).
    points = (target_rays.origins[:, None]
              + target_rays.directions[:, None] * depths[Ellipsis, None])
    # Append w = 1 for homogeneous coordinates: (#rays, num_samples, 4).
    points = jnp.concatenate(
        [points, jnp.ones_like(points[Ellipsis, 0:1])], axis=-1)

    pcoords, in_front_mask = self.project2camera(
        points, ref_worldtocamera, intrinsic_matrix)

    # Projections that fall inside the image.
    inside_mask = self.inside_image(pcoords)

    # Clamp channel 0 to [0, image_height - 1] and channel 1 to
    # [0, image_width - 1]. NOTE(review): assumes pcoords is (row, col);
    # confirm against project2camera.
    pcoords = jnp.concatenate([
        jnp.clip(pcoords[Ellipsis, 0:1], 0, self.image_height - 1),
        jnp.clip(pcoords[Ellipsis, 1:], 0, self.image_width - 1),
    ], axis=-1)

    # Valid = in front of the camera AND within the image boundaries.
    valid_proj_mask = in_front_mask * inside_mask
    return pcoords, valid_proj_mask, points[Ellipsis, :3]
def test_logmarglike_lineargaussianmodel_threetransfers_basics():
    """Single-object checks for the three-transfer linear Gaussian model.

    Checks finiteness and shapes, agreement of the jit version, that the
    posterior equals the reported Gaussian, that theta_map is a stationary
    point of the loss, that theta_cov is the inverse Hessian, and that logfml
    matches a numerical grid integral of the posterior.
    """
    n_components = 2
    theta_truth = jax.random.normal(key, (n_components,))
    mu = theta_truth * 1.1
    muinvvar = (theta_truth) ** -2
    logmuinvvar = np.where(muinvvar == 0, 0, np.log(muinvvar))
    M_T = design_matrix_polynomials(n_components, n_pix_y)  # (n_components, n_pix_y)
    y_truth = np.matmul(theta_truth, M_T)  # (n_pix_y,)
    y, yinvvar, logyinvvar = make_masked_noisy_data(y_truth)  # (n_pix_y,)
    ell = 1
    R_T = design_matrix_polynomials(n_components, n_pix_z)  # (n_components, n_pix_z)
    z_truth = np.matmul(theta_truth, R_T)  # (n_pix_z,)
    z, zinvvar, logzinvvar = make_masked_noisy_data(z_truth)  # (n_pix_z,)

    # first run
    logfml, theta_map, theta_cov = logmarglike_lineargaussianmodel_threetransfers(
        ell, M_T, R_T, y, yinvvar, z, zinvvar, mu, muinvvar
    )
    # check result is finite and shapes are correct
    assert np.isfinite(logfml)
    assert_shape(theta_map, (n_components,))
    assert_shape(theta_cov, (n_components, n_components))

    # now trying jit
    logfml2, theta_map2, theta_cov2 = logmarglike_lineargaussianmodel_threetransfers_jit(
        ell, M_T, R_T,
        y, yinvvar, logyinvvar,
        z, zinvvar, logzinvvar,
        mu, muinvvar, logmuinvvar,
    )
    assert_fml_thetamap_thetacov(
        logfml, theta_map, theta_cov, logfml2, theta_map2, theta_cov2, relative_accuracy
    )

    # check that posterior distribution is equal to product of gaussians too
    def log_posterior(theta):
        y_mod = np.matmul(theta, M_T)  # (n_pix_y,) per theta
        like_y = batch_gaussian_loglikelihood(y_mod - y, yinvvar)
        z_mod = ell * np.matmul(theta, R_T)  # (n_pix_z,) per theta
        like_z = batch_gaussian_loglikelihood(z_mod - z, zinvvar)
        prior = batch_gaussian_loglikelihood(theta - mu, muinvvar)
        return like_y + like_z + prior

    def log_posterior2(theta):
        dt = theta - theta_map
        s, logdet = np.linalg.slogdet(theta_cov * 2 * np.pi)
        chi2 = np.dot(dt.T, np.linalg.solve(theta_cov, dt))
        return logfml - 0.5 * (s * logdet + chi2)

    logpostv = log_posterior(theta_truth)
    logpostv2 = log_posterior2(theta_truth)
    assert abs(logpostv2 / logpostv - 1) < 0.01

    def loss_fn(theta):
        return -log_posterior(theta)

    # A few gradient-descent steps must leave theta_map unchanged (it is the
    # optimum). grad(loss_fn) returns an array shaped like theta, so theta is
    # updated with the full gradient vector.
    theta = 1 * theta_map
    learning_rate = 1e-5
    for _ in range(10):
        theta = theta - learning_rate * grad(loss_fn)(theta)
    assert np.allclose(theta_map, theta, rtol=1e-6)

    # Testing analytic covariance is correct (inverse of the loss Hessian)
    theta_cov2 = np.linalg.inv(np.reshape(hessian(loss_fn)(theta_map), theta_cov.shape))
    assert np.allclose(theta_cov, theta_cov2, rtol=1e-6)

    # Numerical evidence: integrate the posterior on a grid around theta_map
    loss_fn_vmap = jit(vmap(loss_fn))
    n = 15
    theta_std = np.diag(theta_cov) ** 0.5
    theta_samples, vol_element = generate_sample_grid(theta_map, theta_std, n)
    logpost = -loss_fn_vmap(theta_samples)
    logfml_numerical = logsumexp(np.log(vol_element) + logpost)
    assert abs(logfml_numerical / logfml - 1) < 0.01
def test_logmarglike_lineargaussianmodel_threetransfers_vmap():
    """Batched (jit + vmap) three-transfer model checks.

    Checks finiteness and output shapes, then runs a few Adam steps on
    (ells, M_T_y, R_T_z) and requires the loss to stay finite.
    """
    theta_truth = jax.random.normal(key, (nobj, n_components))
    mu = theta_truth * 1.1
    muinvvar = (theta_truth) ** -2
    logmuinvvar = np.where(muinvvar == 0, 0, np.log(muinvvar))
    M_T = design_matrix_polynomials(n_components, n_pix_y)  # (n_components, n_pix_y)
    M_T_y = M_T[None, :, :] * np.ones((nobj, 1, 1))  # (nobj, n_components, n_pix_y)
    y_truth = np.matmul(theta_truth, M_T)  # (nobj, n_pix_y)
    y, yinvvar, logyinvvar = make_masked_noisy_data(y_truth)  # (nobj, n_pix_y)
    ells = 10 ** jax.random.normal(key, (nobj,))
    R_T = design_matrix_polynomials(n_components, n_pix_z)  # (n_components, n_pix_z)
    R_T_z = R_T[None, :, :] * np.ones((nobj, 1, 1))  # (nobj, n_components, n_pix_z)
    # One scaling factor per object: broadcast ells along the pixel axis.
    z_truth = ells[:, None] * np.matmul(theta_truth, R_T)  # (nobj, n_pix_z)
    z, zinvvar, logzinvvar = make_masked_noisy_data(z_truth)  # (nobj, n_pix_z)

    # first run
    logfml, theta_map, theta_cov = logmarglike_lineargaussianmodel_threetransfers_jitvmap(
        ells, M_T_y, R_T_z,
        y, yinvvar, logyinvvar,
        z, zinvvar, logzinvvar,
        mu, muinvvar, logmuinvvar,
    )
    # check result is finite and shapes are correct
    assert np.all(np.isfinite(logfml))
    assert_shape(logfml, (nobj,))
    assert_shape(theta_map, (nobj, n_components))
    assert_shape(theta_cov, (nobj, n_components, n_components))
    theta_std = np.diagonal(theta_cov, axis1=1, axis2=2) ** 0.5
    assert_shape(theta_std, (nobj, n_components))

    def loss(param_list, data_list):
        """Negative total log evidence as a function of (ells, M_T_y, R_T_z)."""
        ells_, M_T_y_, R_T_z_ = param_list
        (
            y_, yinvvar_, logyinvvar_,
            z_, zinvvar_, logzinvvar_,
            mu_, muinvvar_, logmuinvvar_,
        ) = data_list
        logfml_, _, _ = logmarglike_lineargaussianmodel_threetransfers_jitvmap(
            ells_, M_T_y_, R_T_z_,
            y_, yinvvar_, logyinvvar_,
            z_, zinvvar_, logzinvvar_,
            mu_, muinvvar_, logmuinvvar_,
        )
        return -np.sum(logfml_)

    params = [ells, M_T_y, R_T_z]
    data = [
        y, yinvvar, logyinvvar,
        z, zinvvar, logzinvvar,
        mu, muinvvar, logmuinvvar,
    ]
    assert np.isfinite(loss(params, data))

    learning_rate = 1e-3
    num_steps = 10
    opt_init, opt_update, get_params = optimizers.adam(learning_rate)
    opt_state = opt_init(params)

    @jit
    def update(step, opt_state, data):
        params = get_params(opt_state)
        value, grads = jax.value_and_grad(loss)(params, data)
        opt_state = opt_update(step, grads, opt_state)
        return value, opt_state

    for step in range(num_steps):
        value, opt_state = update(step, opt_state, data)
        # the loss must stay finite throughout the optimisation
        assert np.isfinite(value)
def test_logmarglike_lineargaussianmodel_onetransfer_batched():
    """Batched one-transfer model: output shapes, then optimisation of a design matrix.

    The design matrix is per object. The loss must stay finite and the
    optimised matrix must keep its shape.
    """
    # shapes of arrays are in parentheses.
    # generate some fake data (noisy and masked)
    nobj = 14
    n_components = 2
    # true linear parameters, per object
    theta_truth = jax.random.normal(key, (nobj, n_components))
    n_pix_y = 100  # number of pixels for each object
    M_T = design_matrix_polynomials(n_components, n_pix_y)  # (n_components, n_pix_y)
    M_T_y = M_T[None, :, :] * np.ones((nobj, 1, 1))  # (nobj, n_components, n_pix_y)
    # data array, truth
    y_truth = np.matmul(theta_truth, M_T)  # (nobj, n_pix_y)
    # data array with noise, plus the inverse noise variance and its logarithm
    y, yinvvar, logyinvvar = make_masked_noisy_data(y_truth)  # (nobj, n_pix_y)
    assert_equal_shape([y_truth, y, yinvvar, logyinvvar])
    # importantly, yinvvar and logyinvvar have zeros, symbolising ignored/masked pixels

    # given the data and model matrix M_T, compute the evidences, best-fit
    # thetas, and their covariances for all objects at once.
    logfml, theta_map, theta_cov = logmarglike_lineargaussianmodel_onetransfer_jitvmap(
        M_T_y, y, yinvvar, logyinvvar
    )
    # checking shapes of output arrays
    assert_shape(logfml, (nobj,))
    assert_shape(theta_map, (nobj, n_components))
    assert_shape(theta_cov, (nobj, n_components, n_components))

    # optimise a per-object design matrix to maximise the total evidence
    @jit
    def loss_fn(params, data):
        M_T_new = params[0]  # params is a list
        y, yinvvar, logyinvvar = data
        logfml, _, _ = logmarglike_lineargaussianmodel_onetransfer_jitvmap(
            M_T_new, y, yinvvar, logyinvvar
        )
        return -np.sum(logfml)

    M_T_new_initial = jax.random.normal(key, (nobj, n_components, n_pix_y))
    param_list = [1 * M_T_new_initial]
    learning_rate = 1e-5
    opt_init, opt_update, get_params = jax.experimental.optimizers.adam(learning_rate)
    opt_state = opt_init(param_list)

    @jit
    def update(step, opt_state, data):
        params = get_params(opt_state)
        value, grads = jax.value_and_grad(loss_fn)(params, data)
        opt_state = opt_update(step, grads, opt_state)
        return value, opt_state

    num_iterations = 10  # TODO: use better optimizer
    data = (y, yinvvar, logyinvvar)
    for step in range(num_iterations):
        # could potentially also iterate over batches of data
        loss_value, opt_state = update(step, opt_state, data)
        assert np.isfinite(loss_value)

    # optimised matrix keeps the per-object design-matrix shape
    M_T_new_optimised = get_params(opt_state)
    assert_shape(M_T_new_optimised[0], (nobj, n_components, n_pix_y))
def test_logmarglike_lineargaussianmodel_onetransfer_basics():
    """Single-object checks for the one-transfer linear Gaussian model.

    Checks finiteness and shapes, closeness to the truth, the Gaussian form of
    the posterior, agreement of the jit version, optimality of theta_map, and
    theta_cov against the inverse Hessian. It also compares logfml with a
    numerical grid integral and with the two-transfer model under a broad prior.
    """
    theta_truth = jax.random.normal(key, (n_components,))
    M_T = design_matrix_polynomials(n_components, n_pix_y)  # (n_components, n_pix_y)
    y_truth = np.matmul(theta_truth, M_T)  # (n_pix_y,)
    y, yinvvar, logyinvvar = make_masked_noisy_data(y_truth)  # (n_pix_y,)
    assert_equal_shape([y_truth, y, yinvvar, logyinvvar])

    logfml, theta_map, theta_cov = logmarglike_lineargaussianmodel_onetransfer(
        M_T, y, yinvvar
    )
    # check result is finite and shapes are correct
    assert_shape(theta_map, (n_components,))
    assert_shape(theta_cov, (n_components, n_components))
    assert np.isfinite(logfml)
    assert np.all(np.isfinite(theta_map))
    assert np.all(np.isfinite(theta_cov))

    # check that result isn't too far off the truth, in chi2 sense
    dt = theta_map - theta_truth
    chi2 = 0.5 * np.ravel(np.matmul(dt.T, np.linalg.solve(theta_cov, dt)))
    assert chi2 < 100

    # check that normalised posterior distribution factorises into product of gaussians
    def log_posterior(theta):
        y_mod = np.matmul(theta, M_T)  # (n_pix_y,) per theta
        return batch_gaussian_loglikelihood(y_mod - y, yinvvar)

    def log_posterior2(theta):
        dt = theta - theta_map
        s, logdet = np.linalg.slogdet(theta_cov * 2 * np.pi)
        chi2 = np.dot(dt.T, np.linalg.solve(theta_cov, dt))
        return logfml - 0.5 * (s * logdet + chi2)

    logpostv = log_posterior(theta_truth)
    logpostv2 = log_posterior2(theta_truth)
    assert abs(logpostv2 / logpostv - 1) < 0.01

    # now trying jit version of function
    logfml2, theta_map2, theta_cov2 = logmarglike_lineargaussianmodel_onetransfer_jit(
        M_T, y, yinvvar, logyinvvar
    )
    # check that outputs match original implementation
    assert_fml_thetamap_thetacov(
        logfml, theta_map, theta_cov, logfml2, theta_map2, theta_cov2, relative_accuracy
    )

    # now running simple optimiser to check that result is indeed optimum
    def loss_fn(theta):
        return -log_posterior(theta)

    # grad(loss_fn) returns an array shaped like theta, so theta is updated with
    # the full gradient vector.
    theta = 1 * theta_map
    learning_rate = 1e-5
    for _ in range(10):
        theta = theta - learning_rate * grad(loss_fn)(theta)
    assert np.allclose(theta_map, theta, rtol=1e-6)

    # Testing analytic covariance is correct and equals inverse of hessian
    theta_cov2 = np.linalg.inv(np.reshape(hessian(loss_fn)(theta_map), theta_cov.shape))
    assert np.allclose(theta_cov, theta_cov2, rtol=1e-6)

    # create vectorised loss
    loss_fn_vmap = jit(vmap(loss_fn))

    # now computes the evidence numerically
    n = 15
    theta_std = np.diag(theta_cov) ** 0.5
    theta_samples, vol_element = generate_sample_grid(theta_map, theta_std, n)
    loglikelihoods = -loss_fn_vmap(theta_samples)
    logfml_numerical = logsumexp(np.log(vol_element) + loglikelihoods)
    assert abs(logfml_numerical / logfml - 1) < 0.01

    # Compare with case including gaussian prior with large variance.
    # NOTE(review): muinvvar is built from the std, not the variance; the prior
    # is very broad either way, which the 0.2 tolerance relies on.
    mu = theta_map * 0
    muinvvar = 1 / (1e4 * np.diag(theta_cov) ** 0.5)
    logfml2, theta_map2, theta_cov2 = logmarglike_lineargaussianmodel_twotransfers(
        M_T, y, yinvvar, mu, muinvvar
    )
    assert_fml_thetamap_thetacov(
        logfml, theta_map, theta_cov, logfml2, theta_map2, theta_cov2, 0.2
    )
def test_sample_returns_batch(self):
    """Sampling `n` items from the replay yields a batch whose `value` has length n."""
    buffer = make_replay()
    add(buffer, [1, 2, 3])
    n = 2
    batch, _, _ = buffer.sample(n)
    chex.assert_shape(batch.value, (n,))
def test_sample(self):
    """A sample of size k has a field `a` with leading dimension k."""
    k = 2
    batch = self.replay.sample(k)
    chex.assert_shape(batch.a, (k,))
def extract_patches_from_indicators(x, indicators, patch_size, patch_dropout,
                                    grid_shape, train, iterative=False):
  """Extract patches from a batch of images.

  Args:
    x: The batch of images of shape (batch, height, width, channels).
    indicators: The one hot indicators of shape (batch, num_patches, k).
    patch_size: The size of the (squared) patches to extract.
    patch_dropout: Probability to replace a patch by 0 values. Only supported
      when `iterative` is True.
    grid_shape: Pair of height, width of the disposition of the num_patches
      patches.
    train: If the model is being trained. Disable dropout if not.
    iterative: If True, extracts the patches with a scan over grid positions
      rather than instantiating the "all patches" tensor and contracting it
      with the indicators. `iterative` is more memory efficient.

  Returns:
    The patches extracted from x with shape
    (batch, k, patch_size, patch_size, channels).

  Raises:
    ValueError: If the padding computed from patch_size, the image size and
      the grid shape is negative on any side.
    NotImplementedError: If patch_dropout is non-zero and iterative is False.
  """
  batch_size, height, width, channels = x.shape
  scores_h, scores_w = grid_shape
  k = indicators.shape[-1]
  indicators = einops.rearrange(indicators, "b (h w) k -> b k h w",
                                h=scores_h, w=scores_w)

  # Stride between consecutive patches on the image.
  scale_height = height // scores_h
  scale_width = width // scores_w
  # Padded size such that a patch_size window at stride scale_* fits exactly
  # scores_* times along each axis.
  padded_height = scale_height * scores_h + patch_size - 1
  padded_width = scale_width * scores_w + patch_size - 1
  top_pad = (patch_size - scale_height) // 2
  left_pad = (patch_size - scale_width) // 2
  bottom_pad = padded_height - top_pad - height
  right_pad = padded_width - left_pad - width
  pads = {"top": top_pad, "bottom": bottom_pad,
          "left": left_pad, "right": right_pad}
  negative_pads = {side: pad for side, pad in pads.items() if pad < 0}
  if negative_pads:
    raise ValueError(
        f"patch_size={patch_size} gives negative padding {negative_pads} for "
        f"images of size {height}x{width} on a {scores_h}x{scores_w} grid.")
  padded_x = jnp.pad(x, [(0, 0), (top_pad, bottom_pad), (left_pad, right_pad),
                         (0, 0)])

  # Extract the patches. Iterative fits better in memory as it does not
  # instantiate the "all patches" tensor but iterates over grid positions to
  # compute the weighted sum with the indicator variables from topk.
  if not iterative:
    if patch_dropout != 0.:
      raise NotImplementedError("Patch dropout not implemented.")
    patches = utils.extract_images_patches(
        padded_x,
        window_size=(patch_size, patch_size),
        stride=(scale_height, scale_width))
    shape = (batch_size, scores_h, scores_w, patch_size, patch_size, channels)
    chex.assert_shape(patches, shape)
    patches = jnp.einsum("b k h w, b h w i j c -> b k i j c",
                         indicators, patches)
  else:
    # Per-(image, grid cell) keep mask; dropout zeroes some cells in training.
    mask = jnp.ones((batch_size, scores_h, scores_w))
    mask = nn.dropout(mask, patch_dropout, deterministic=not train)

    def accumulate_patches(acc, index_i_j):
      i, j = index_i_j
      patch = jax.lax.dynamic_slice(
          padded_x,
          (0, i * scale_height, j * scale_width, 0),
          (batch_size, patch_size, patch_size, channels))
      weights = indicators[:, :, i, j]
      keep = mask[:, i, j]
      weighted_patch = jnp.einsum("b, bk, bijc -> bkijc",
                                  keep, weights, patch)
      chex.assert_equal_shape([acc, weighted_patch])
      acc += weighted_patch
      return acc, None

    # All (i, j) grid positions in row-major order, shape (scores_h * scores_w, 2).
    indices = jnp.stack(
        jnp.meshgrid(jnp.arange(scores_h), jnp.arange(scores_w), indexing="ij"),
        axis=-1)
    indices = indices.reshape((-1, 2))
    init_patches = jnp.zeros((batch_size, k, patch_size, patch_size, channels))
    patches, _ = jax.lax.scan(accumulate_patches, init_patches, indices)
  return patches
def apply(self,
          x,
          *,
          patch_size,
          k,
          downscale,
          scorer_has_se,
          normalization_str="identity",
          selection_method,
          selection_method_kwargs=None,
          selection_method_inference=None,
          patch_dropout=0.,
          hard_topk_probability=0.,
          random_patch_probability=0.,
          use_iterative_extraction,
          append_position_to_input,
          feature_network,
          aggregation_method,
          aggregation_method_kwargs=None,
          train):
    """Process a high resolution image by selecting a subset of useful patches.

    This model processes the input as follows:
      1. Compute scores per patch on a downscaled version of the input
         (skipped for the RANDOM selection method).
      2. Select "important" patches using sampling or top-k methods.
      3. Extract the patches from the high-resolution image.
      4. Compute a representation vector for each patch with a feature
         network.
      5. Aggregate the patch representations to obtain an image
         representation, then apply a Dense layer ("classification_dense1")
         and a swish activation.

    Args:
      x: Input tensor of shape (batch, height, width, channels).
      patch_size: Size of the (squared) patches to extract.
      k: Number of patches to extract per image.
      downscale: Downscale multiplier for the input of the scorer network.
      scorer_has_se: Whether scorer network has Squeeze-excite layers.
      normalization_str: String specifying the normalization of the scores.
      selection_method: Method that selects which patches should be
        extracted, based on their scores. Either returns indices (hard
        selection) or indicators vectors (which could yield interpolated
        patches).
      selection_method_kwargs: Keyword args for the selection_method.
      selection_method_inference: Selection method used at inference.
      patch_dropout: Probability to replace a patch by 0 values.
      hard_topk_probability: Probability to use the true topk on the scores
        to select the patches. This operation has no gradient so scorer's
        weights won't be trained.
      random_patch_probability: Probability to replace each patch by a
        random patch in the image during training.
      use_iterative_extraction: If True, uses a for loop instead of patch
        indexing for memory efficiency.
      append_position_to_input: Append normalized (height, width) position
        to the channels of the input.
      feature_network: Network to be applied on each patch individually to
        obtain patch representation vectors.
      aggregation_method: Method to aggregate the representations of the k
        patches of each image to obtain the image representation.
      aggregation_method_kwargs: Keywords arguments for aggregation_method.
      train: If the model is being trained. Disable dropout otherwise.

    Returns:
      A tuple `(representations, stats)`: a representation vector for each
      image in the batch, and a dict of intermediate values (e.g. "scores",
      "entropy", "extracted_patches", "patch_representations") used for
      logging/plotting and entropy regularization.
    """
    selection_method = SelectionMethod(selection_method)
    aggregation_method = AggregationMethod(aggregation_method)
    if selection_method_inference:
        selection_method_inference = SelectionMethod(
            selection_method_inference)

    selection_method_kwargs = selection_method_kwargs or {}
    aggregation_method_kwargs = aggregation_method_kwargs or {}

    # Intermediate values returned alongside the representations.
    stats = {}

    # Compute new dimension of the scoring image.
    b, h, w, c = x.shape
    scoring_shape = (b, h // downscale, w // downscale, c)

    # === Compute the scores with a small CNN.
    if selection_method == SelectionMethod.RANDOM:
        # RANDOM selection ignores scores: the scorer is not run and only the
        # size of the grid it would produce is needed.
        scores_h, scores_w = Scorer.compute_output_size(
            h // downscale, w // downscale)
        num_patches = scores_h * scores_w
    else:
        # Downscale input to run scorer on.
        scoring_x = jax.image.resize(x, scoring_shape, method="bilinear")
        scores = Scorer(scoring_x, use_squeeze_excite=scorer_has_se,
                        name="scorer")
        flatten_scores = einops.rearrange(scores, "b h w -> b (h w)")
        num_patches = flatten_scores.shape[-1]
        scores_h, scores_w = scores.shape[1:3]

        # Compute entropy before normalization
        prob_scores = jax.nn.softmax(flatten_scores)
        stats["entropy_before_normalization"] = jax.scipy.special.entr(
            prob_scores).sum(axis=1).mean(axis=0)

        # Normalize the flatten scores
        normalization_fn = create_normalization_fn(normalization_str)
        flatten_scores = normalization_fn(flatten_scores)
        scores = flatten_scores.reshape(scores.shape)
        stats["scores"] = scores[Ellipsis, None]

    # Concatenate height and width position to the input channels.
    if append_position_to_input:
        coords = utils.create_grid([h, w], value_range=(0., 1.))
        x = jnp.concatenate(
            [x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)], axis=-1)
        c += 2

    # Overwrite the selection method at inference
    # NOTE(review): `flatten_scores` is only defined when the *training*
    # selection_method is not RANDOM (checked above, before this override).
    # RANDOM at training combined with a score-based inference method would
    # therefore fail at inference.
    if selection_method_inference and not train:
        selection_method = selection_method_inference

    # === Patch selection

    # Select the patches by sampling or top-k. Some methods returns the indices
    # of the selected patches, other methods return indicator vectors.
    extract_by_indices = selection_method in [
        SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM
    ]
    if selection_method is SelectionMethod.SINKHORN_TOPK:
        indicators = select_patches_sinkhorn_topk(
            flatten_scores, k=k, **selection_method_kwargs)
    elif selection_method is SelectionMethod.PERTURBED_TOPK:
        sigma = selection_method_kwargs["sigma"]
        num_samples = selection_method_kwargs["num_samples"]
        # The configured sigma is scaled by a state variable initialized to
        # one. NOTE: "mutiplier" is misspelled, but it names a state
        # variable, so renaming it would likely break existing checkpoints.
        sigma *= self.state("sigma_mutiplier", shape=(),
                            initializer=nn.initializers.ones).value
        stats["sigma"] = sigma
        indicators = select_patches_perturbed_topk(
            flatten_scores, k=k, sigma=sigma, num_samples=num_samples)
    elif selection_method is SelectionMethod.HARD_TOPK:
        indices = select_patches_hard_topk(flatten_scores, k=k)
    elif selection_method is SelectionMethod.RANDOM:
        # k distinct patch indices per image, one rng per image.
        batch_random_indices_fn = jax.vmap(
            functools.partial(jax.random.choice,
                              a=num_patches,
                              shape=(k, ),
                              replace=False))
        indices = batch_random_indices_fn(
            jax.random.split(nn.make_rng(), b))

    # Compute scores entropy for regularization
    if selection_method not in [SelectionMethod.RANDOM]:
        prob_scores = flatten_scores
        # Normalize the scores if it is not already done.
        if "softmax" not in normalization_str:
            prob_scores = jax.nn.softmax(prob_scores)
        stats["entropy"] = jax.scipy.special.entr(prob_scores).sum(
            axis=1).mean(axis=0)

    # Randomly use hard topk at training.
    # Per image, with probability hard_topk_probability, the (soft) selection
    # is replaced by the exact top-k selection.
    if (train and hard_topk_probability > 0 and
            selection_method not in [SelectionMethod.HARD_TOPK,
                                     SelectionMethod.RANDOM]):
        true_indices = select_patches_hard_topk(flatten_scores, k=k)
        random_values = jax.random.uniform(nn.make_rng(), (b, ))
        use_hard = random_values < hard_topk_probability
        # NOTE(review): extract_by_indices is always False here, because
        # HARD_TOPK and RANDOM are excluded by the condition above; the first
        # branch is unreachable.
        if extract_by_indices:
            indices = jnp.where(use_hard[:, None], true_indices, indices)
        else:
            true_indicators = make_indicators(true_indices, num_patches)
            indicators = jnp.where(use_hard[:, None, None], true_indicators,
                                   indicators)

    # Sample some random patches during training with random_patch_probability.
    # Each of the k slots of each image is replaced independently.
    # NOTE(review): random and selected indices are mixed per slot, so the
    # same patch can end up selected twice for one image.
    if (train and random_patch_probability > 0 and
            selection_method is not SelectionMethod.RANDOM):
        single_random_patches = functools.partial(jax.random.choice,
                                                  a=num_patches,
                                                  shape=(k, ),
                                                  replace=False)
        random_indices = jax.vmap(single_random_patches)(jax.random.split(
            nn.make_rng(), b))
        random_values = jax.random.uniform(nn.make_rng(), (b, k))
        use_random = random_values < random_patch_probability
        if extract_by_indices:
            indices = jnp.where(use_random, random_indices, indices)
        else:
            random_indicators = make_indicators(random_indices, num_patches)
            indicators = jnp.where(use_random[:, None, :], random_indicators,
                                   indicators)

    # === Patch extraction
    if extract_by_indices:
        patches = extract_patches_from_indices(x,
                                               indices,
                                               patch_size=patch_size,
                                               grid_shape=(scores_h, scores_w))
        # Indicators are still needed by the TRANSFORMER aggregation below.
        indicators = make_indicators(indices, num_patches)
    else:
        patches = extract_patches_from_indicators(
            x,
            indicators,
            patch_size,
            grid_shape=(scores_h, scores_w),
            iterative=use_iterative_extraction,
            patch_dropout=patch_dropout,
            train=train)

    chex.assert_shape(patches, (b, k, patch_size, patch_size, c))

    # Patches laid out side by side along the width for visualization.
    stats["extracted_patches"] = einops.rearrange(
        patches, "b k i j c -> b i (k j) c")
    # Remove position channels for plotting.
    if append_position_to_input:
        stats["extracted_patches"] = (
            stats["extracted_patches"][Ellipsis, :-2])

    # === Compute patch features
    flatten_patches = einops.rearrange(patches, "b k i j c -> (b k) i j c")
    representations = feature_network(flatten_patches, train=train)
    # Average over any intermediate (e.g. spatial) axes so that every patch
    # ends up as a single vector.
    if representations.ndim > 2:
        collapse_axis = tuple(range(1, representations.ndim - 1))
        representations = representations.mean(axis=collapse_axis)
    representations = einops.rearrange(representations,
                                       "(b k) d -> b k d",
                                       k=k)

    stats["patch_representations"] = representations

    # === Aggregate the k patches

    # - for sampling we are forced to take an expectation
    # - for topk we have multiple options: mean, max, transformer.
    if aggregation_method is AggregationMethod.TRANSFORMER:
        # Learned embedding of each patch's indicator vector, added to its
        # representation as a position encoding.
        patch_pos_encoding = nn.Dense(einops.rearrange(
            indicators, "b d k -> b k d"),
                                      features=representations.shape[-1])

        chex.assert_equal_shape([representations, patch_pos_encoding])
        representations += patch_pos_encoding
        representations = transformer.Transformer(
            representations, **aggregation_method_kwargs, is_training=train)

    elif aggregation_method is AggregationMethod.MEANPOOLING:
        representations = representations.mean(axis=1)
    elif aggregation_method is AggregationMethod.MAXPOOLING:
        representations = representations.max(axis=1)
    elif aggregation_method is AggregationMethod.SUM_LAYERNORM:
        representations = representations.sum(axis=1)
        representations = nn.LayerNorm(representations)

    # Applied after every aggregation method.
    representations = nn.Dense(representations,
                               features=representations.shape[-1],
                               name="classification_dense1")
    representations = nn.swish(representations)

    return representations, stats
def __call__(self,
             queries: jnp.ndarray,
             hm_memory: HierarchicalMemory,
             hm_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
    """Do hierarchical attention over the stored memories.

    Each query scores every memory chunk through its key and keeps the top
    `self._k` chunks. It then attends within each kept chunk with a
    multi-head attention layer. Finally, the k results are mixed with softmax
    weights computed from the chunk scores.

    Args:
      queries: Tensor [B, Q, E] Query(ies) in, for batch size B, query length
        Q, and embedding dimension E.
      hm_memory: Hierarchical Memory.
      hm_mask: Optional boolean mask tensor of shape [B, Q, M]. Where false,
        the corresponding query timepoints cannot attend to the corresponding
        memory chunks. This can be used for enforcing causal attention on the
        learner, not attending to memories from prior episodes, etc.

    Returns:
      Value updates for each query slot: [B, Q, D]
    """
    # Shape checks.
    batch_size, query_length, _ = queries.shape
    (memory_batch_size, num_memories, memory_chunk_size,
     mem_embedding_size) = hm_memory.contents.shape
    assert batch_size == memory_batch_size
    chex.assert_shape(hm_memory.keys,
                      (batch_size, num_memories, mem_embedding_size))
    chex.assert_shape(
        hm_memory.accumulator,
        (memory_batch_size, memory_chunk_size, mem_embedding_size))
    chex.assert_shape(hm_memory.steps_since_last_write, (memory_batch_size,))
    if hm_mask is not None:
        chex.assert_type(hm_mask, bool)
        chex.assert_shape(hm_mask, (batch_size, query_length, num_memories))

    query_head = self._singlehead_linear(queries, self._size, "query")
    # No gradient flows into the stored memory keys.
    key_head = self._singlehead_linear(
        jax.lax.stop_gradient(hm_memory.keys), self._size, "key")

    # What times in the input [t] attend to what times in the memories [T].
    logits = jnp.einsum("btd,bTd->btT", query_head, key_head)
    scaled_logits = logits / np.sqrt(self._size)

    # Mask last dimension, replacing invalid logits with large negative
    # values. This allows e.g. enforcing causal attention on learner, or
    # blocking attention across episodes.
    if hm_mask is not None:
        masked_logits = jnp.where(hm_mask, scaled_logits, -1e6)
    else:
        masked_logits = scaled_logits

    # Identify the top-k memories and their relevance weights.
    top_k_logits, top_k_indices = jax.lax.top_k(masked_logits, self._k)
    weights = jax.nn.softmax(top_k_logits)

    # Set up the within-memory attention.
    assert self._size % self._num_heads == 0
    mha_key_size = self._size // self._num_heads
    attention_layer = hk.MultiHeadAttention(
        key_size=mha_key_size,
        model_size=self._size,
        num_heads=self._num_heads,
        w_init_scale=self._init_scale,
        name="within_mem_attn")

    # Position encodings. Build a new array instead of using `+=`: augmented
    # assignment on a mutable (e.g. NumPy) `hm_memory.contents` would modify
    # the caller's memory in place.
    augmented_contents = hm_memory.contents
    if self._memory_position_encoding:
        position_embs = sinusoid_position_encoding(memory_chunk_size,
                                                   mem_embedding_size)
        augmented_contents = (
            augmented_contents + position_embs[None, None, :, :])

    def _within_memory_attention(sub_inputs, sub_memory_contents,
                                 sub_weights, sub_top_k_indices):
        """Per-batch-element attention of every query into its top-k chunks."""
        top_k_contents = sub_memory_contents[sub_top_k_indices, :, :]

        # Now we go deeper, with another vmap over **tokens**, because each
        # token can attend to different memories.
        def attend_one_token(token, token_top_k_contents):
            # One copy of the token per selected memory chunk.
            tiled_inputs = jnp.tile(token[None, None, :],
                                    reps=(self._k, 1, 1))
            return attention_layer(query=tiled_inputs,
                                   key=token_top_k_contents,
                                   value=token_top_k_contents)

        attend_all_tokens = hk_vmap(attend_one_token, in_axes=0,
                                    split_rng=False)
        attention_results = attend_all_tokens(sub_inputs, top_k_contents)
        attention_results = jnp.squeeze(attention_results, axis=2)

        # Now collapse results across k memories.
        attention_results = sub_weights[:, :, None] * attention_results
        return jnp.sum(attention_results, axis=1)

    # vmap across batch
    batch_within_memory_attention = hk_vmap(_within_memory_attention,
                                            in_axes=0,
                                            split_rng=False)
    outputs = batch_within_memory_attention(
        queries,
        jax.lax.stop_gradient(augmented_contents),
        weights,
        top_k_indices)

    return outputs
def test_bayesianpca_spec_and_specandphot():
    """Check the shapes of a spec+phot batch drawn from fake data."""
    n_obj, n_pix_sed, n_pix_spec, n_pix_phot, n_pix_transfer = 122, 100, 47, 5, 50
    # Called for its side effect only; its return value is not used.
    # NOTE(review): presumably writes the "data/fake/fake_*" files loaded on
    # the next line.
    DataPipeline.save_fake_data(
        n_obj, n_pix_sed, n_pix_spec, n_pix_phot, n_pix_transfer
    )
    dataPipeline = DataPipeline("data/fake/fake_")

    batchsize = 20
    data_batch = dataPipeline.next_batch_specandphot(dataPipeline.indices, batchsize)
    (
        si,
        bs,
        batch_index_wave,
        batch_index_transfer_redshift,
        spec,
        spec_invvar,
        spec_loginvvar,
        # batch_spec_mask,
        specphotscaling,
        phot,
        phot_invvar,
        phot_loginvvar,
        batch_redshifts,
        batch_transferfunctions,
        batch_index_wave_ext,
        batch_interprightindices,
        batch_interpweights,
    ) = data_batch

    assert bs == batchsize
    assert si == 0
    assert_shape(spec, (bs, n_pix_spec))
    assert_shape(spec_invvar, (bs, n_pix_spec))
    assert_shape(spec_loginvvar, (bs, n_pix_spec))
    assert_shape(phot, (bs, n_pix_phot))
    assert_shape(phot_invvar, (bs, n_pix_phot))
    assert_shape(phot_loginvvar, (bs, n_pix_phot))
    assert_shape(batch_redshifts, (bs,))
    assert_shape(batch_transferfunctions, (bs, n_pix_sed, n_pix_phot))

    n_components = 3
    # Local PRNG key so the test does not depend on a module-level `key`.
    key = jax.random.PRNGKey(0)
    # NOTE(review): not used below; the PCA part of this test looks
    # unfinished.
    pcacomponents_speconly = jax.random.normal(key, (n_components, n_pix_sed))
def measure_valued_estimation_std(function: Callable[[chex.Array], float],
                                  dist: Any,
                                  rng: chex.PRNGKey,
                                  num_samples: int,
                                  coupling: bool = True) -> chex.Array:
    """Measure-valued estimate of grad_std E_dist[f(x)] for a Gaussian `dist`.

    The estimator uses the weak derivative of a Gaussian with respect to its
    scale:
      - Positive part: double-sided Maxwell draws.
      - Negative part: standard normal draws. With coupling these are
        Uniform(0, 1) times the same Maxwell draws, which is again standard
        normal.
      - Normalising constant: 1 / std.

    For every dimension, only that coordinate of a base sample is replaced by
    its perturbed value. Per the helper utilities this gives an
    N x D x D batch, and `function` is evaluated on it.

    Args:
      function: Function f(x) whose expectation is differentiated. It takes
        a single sample and returns a floating point value.
      dist: A distribution exposing `params == (mean, log_std)` and
        `sample(shape, seed=...)`.
      rng: A PRNGKey.
      num_samples: Number N of samples used for the estimate.
      coupling: If True (recommended, lower variance), the negative draws are
        coupled to the positive ones. Otherwise they are independent standard
        normal draws.

    Returns:
      An array of shape `(num_samples,) + std.shape` holding one estimate
      per sample. Its mean can be used to update the scale parameter; its
      spread indicates the estimator variance.
    """
    loc, log_scale = dist.params
    scale = jnp.exp(log_scale)

    # NOTE(review): `rng` is used directly here and is also split below.
    # JAX's PRNG guidance is not to reuse a key that has been split. It is
    # kept as-is to preserve the existing random stream.
    base = dist.sample((num_samples, ), seed=rng)
    maxwell_key, gaussian_key = jax.random.split(rng)

    # Only the perturbation distributions differ from the mean estimator.
    plus = jax.random.double_sided_maxwell(
        maxwell_key, loc=0.0, scale=1.0, shape=base.shape)
    minus = (jax.random.uniform(gaussian_key, base.shape) * plus
             if coupling else jax.random.normal(gaussian_key, base.shape))

    # Copy each N x D base sample into an N x D x D batch. Then overwrite the
    # diagonal with the perturbed coordinates (an N x D array each).
    tiled = utils.tile_second_to_last_dim(base)
    plus_batch = utils.set_diags(tiled, loc + scale * plus)
    minus_batch = utils.set_diags(tiled, loc + scale * minus)

    # f maps one sample to a scalar, so the nested vmap yields N x D values.
    # Division by the D-shaped scale broadcasts over the batch.
    per_sample = jax.vmap(jax.vmap(function, 1, 0))
    grads = (per_sample(plus_batch) - per_sample(minus_batch)) / scale

    chex.assert_shape(grads, (num_samples, ) + scale.shape)
    return grads
def test_bayesianpca_photonly():
    """Check the shapes of a photometry-only batch drawn from fake data."""
    n_obj, n_pix_sed, n_pix_spec, n_pix_phot, n_pix_transfer = 122, 100, 47, 5, 50
    # Called for its side effect only; its return value is not used.
    # NOTE(review): presumably writes the "data/fake/fake_*" files loaded on
    # the next line.
    DataPipeline.save_fake_data(
        n_obj, n_pix_sed, n_pix_spec, n_pix_phot, n_pix_transfer
    )
    dataPipeline = DataPipeline("data/fake/fake_", phot=True, spec=False)

    batchsize = 20
    data_batch = dataPipeline.next_batch_photonly(dataPipeline.indices, batchsize)
    (
        si,
        bs,
        phot,
        phot_invvar,
        phot_loginvvar,
        batch_redshifts,
        transferfunctions,
        batch_interprightindices_transfer,
        batch_interpweights_transfer,
    ) = data_batch

    assert bs == batchsize
    assert si == 0
    assert_shape(phot, (bs, n_pix_phot))
    assert_shape(phot_invvar, (bs, n_pix_phot))
    assert_shape(phot_loginvvar, (bs, n_pix_phot))
    assert_shape(batch_redshifts, (bs,))
    # The full transfer-function table is returned, not a per-batch slice.
    assert_shape(transferfunctions, (n_pix_transfer, n_pix_sed, n_pix_phot))
    assert_shape(batch_interprightindices_transfer, (bs,))
    assert_shape(batch_interpweights_transfer, (bs,))
def test_logits_shape(self, img_resolution, num_classes, stage_sizes):
    """Initializing BoTNet on a batch yields logits of shape (batch, classes)."""
    batch = 2
    images = jnp.ones((batch, img_resolution, img_resolution, 3))
    net = BoTNet(num_classes=num_classes, stage_sizes=stage_sizes)
    logits, _ = net.init_with_output(
        {"params": random.PRNGKey(0)}, images, is_training=True)
    chex.assert_shape(logits, (batch, num_classes))
def main(argv):
    """Trains Prioritized DQN agent on Atari.

    Builds the environment, preprocessors, prioritized replay, optimizer and
    the training/evaluation agents from FLAGS. It then alternates training
    and evaluation for iterations 0..FLAGS.num_iterations (inclusive). After
    each iteration it logs and writes the statistics and saves a checkpoint.
    """
    del argv
    logging.info('Prioritized DQN on Atari on %s.',
                 jax.lib.xla_bridge.get_backend().platform)
    random_state = np.random.RandomState(FLAGS.seed)
    # Derive the JAX key from the seeded NumPy RandomState, so --seed
    # controls both the NumPy and the JAX randomness.
    rng_key = jax.random.PRNGKey(
        random_state.randint(-sys.maxsize - 1, sys.maxsize + 1,
                             dtype=np.int64))

    if FLAGS.results_csv_path:
        writer = parts.CsvWriter(FLAGS.results_csv_path)
    else:
        writer = parts.NullWriter()

    def environment_builder():
        """Creates Atari environment."""
        # Each call draws two seeds from the shared random_state, so the
        # number and order of calls affect the random stream.
        env = gym_atari.GymAtari(
            FLAGS.environment_name, seed=random_state.randint(1, 2**32))
        return gym_atari.RandomNoopsEnvironmentWrapper(
            env,
            min_noop_steps=1,
            max_noop_steps=30,
            seed=random_state.randint(1, 2**32),
        )

    env = environment_builder()

    logging.info('Environment: %s', FLAGS.environment_name)
    logging.info('Action spec: %s', env.action_spec())
    logging.info('Observation spec: %s', env.observation_spec())
    num_actions = env.action_spec().num_values
    network_fn = networks.double_dqn_atari_network(num_actions)
    network = hk.transform(network_fn)

    def preprocessor_builder():
        # A fresh preprocessor is built for the sample input, the train agent
        # and the eval agent.
        return processors.atari(
            additional_discount=FLAGS.additional_discount,
            max_abs_reward=FLAGS.max_abs_reward,
            resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
            num_action_repeats=FLAGS.num_action_repeats,
            num_pooled_frames=2,
            zero_discount_on_life_loss=True,
            num_stacked_frames=FLAGS.num_stacked_frames,
            grayscaling=True,
        )

    # Create sample network input from sample preprocessor output.
    sample_processed_timestep = preprocessor_builder()(env.reset())
    sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                            sample_processed_timestep)
    sample_network_input = sample_processed_timestep.observation
    chex.assert_shape(sample_network_input,
                      (FLAGS.environment_height, FLAGS.environment_width,
                       FLAGS.num_stacked_frames))

    # Epsilon starts decaying once the replay warm-up period is over; the
    # begin time is multiplied by num_action_repeats.
    exploration_epsilon_schedule = parts.LinearSchedule(
        begin_t=int(FLAGS.min_replay_capacity_fraction *
                    FLAGS.replay_capacity * FLAGS.num_action_repeats),
        decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                        FLAGS.num_iterations * FLAGS.num_train_frames),
        begin_value=FLAGS.exploration_epsilon_begin_value,
        end_value=FLAGS.exploration_epsilon_end_value)

    # Note the t in the replay is not exactly aligned with the agent t.
    importance_sampling_exponent_schedule = parts.LinearSchedule(
        begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity),
        end_t=(FLAGS.num_iterations *
               int(FLAGS.num_train_frames / FLAGS.num_action_repeats)),
        begin_value=FLAGS.importance_sampling_exponent_begin_value,
        end_value=FLAGS.importance_sampling_exponent_end_value)

    # Optionally compress the two observations of each stored transition.
    if FLAGS.compress_state:

        def encoder(transition):
            return transition._replace(
                s_tm1=replay_lib.compress_array(transition.s_tm1),
                s_t=replay_lib.compress_array(transition.s_t))

        def decoder(transition):
            return transition._replace(
                s_tm1=replay_lib.uncompress_array(transition.s_tm1),
                s_t=replay_lib.uncompress_array(transition.s_t))
    else:
        encoder = None
        decoder = None

    replay_structure = replay_lib.Transition(
        s_tm1=None,
        a_tm1=None,
        r_t=None,
        discount_t=None,
        s_t=None,
    )

    replay = replay_lib.PrioritizedTransitionReplay(
        FLAGS.replay_capacity, replay_structure, FLAGS.priority_exponent,
        importance_sampling_exponent_schedule, FLAGS.uniform_sample_probability,
        FLAGS.normalize_weights, random_state, encoder, decoder)

    optimizer = optax.rmsprop(
        learning_rate=FLAGS.learning_rate,
        decay=0.95,
        eps=FLAGS.optimizer_epsilon,
        centered=True,
    )

    train_rng_key, eval_rng_key = jax.random.split(rng_key)

    train_agent = agent.PrioritizedDqn(
        preprocessor=preprocessor_builder(),
        sample_network_input=sample_network_input,
        network=network,
        optimizer=optimizer,
        transition_accumulator=replay_lib.TransitionAccumulator(),
        replay=replay,
        batch_size=FLAGS.batch_size,
        exploration_epsilon=exploration_epsilon_schedule,
        min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
        learn_period=FLAGS.learn_period,
        target_network_update_period=FLAGS.target_network_update_period,
        grad_error_bound=FLAGS.grad_error_bound,
        rng_key=train_rng_key,
    )
    eval_agent = parts.EpsilonGreedyActor(
        preprocessor=preprocessor_builder(),
        network=network,
        exploration_epsilon=FLAGS.eval_exploration_epsilon,
        rng_key=eval_rng_key,
    )

    # Set up checkpointing.
    # NOTE(review): NullCheckpoint is presumably a no-op implementation, so
    # can_be_restored() is expected to be False; confirm before relying on
    # restore/save here.
    checkpoint = parts.NullCheckpoint()

    # Everything needed to resume a run is attached to the checkpoint state.
    state = checkpoint.state
    state.iteration = 0
    state.train_agent = train_agent
    state.eval_agent = eval_agent
    state.random_state = random_state
    state.writer = writer
    if checkpoint.can_be_restored():
        checkpoint.restore()

    while state.iteration <= FLAGS.num_iterations:
        # New environment for each iteration to allow for determinism if preempted.
        env = environment_builder()

        logging.info('Training iteration %d.', state.iteration)
        train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode)
        # Iteration 0 runs no training frames, so its evaluation measures the
        # initial network.
        num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
        train_seq_truncated = itertools.islice(train_seq, num_train_frames)
        train_trackers = parts.make_default_trackers(train_agent)
        train_stats = parts.generate_statistics(train_trackers,
                                                train_seq_truncated)

        logging.info('Evaluation iteration %d.', state.iteration)
        # Evaluate the current online parameters of the training agent.
        eval_agent.network_params = train_agent.online_params
        eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode)
        eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
        eval_trackers = parts.make_default_trackers(eval_agent)
        eval_stats = parts.generate_statistics(eval_trackers, eval_seq_truncated)

        # Logging and checkpointing.
        human_normalized_score = atari_data.get_human_normalized_score(
            FLAGS.environment_name, eval_stats['episode_return'])
        capped_human_normalized_score = np.amin([1., human_normalized_score])
        # (name, value, printf-style format) triples.
        log_output = [
            ('iteration', state.iteration, '%3d'),
            ('frame', state.iteration * FLAGS.num_train_frames, '%5d'),
            ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'),
            ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
            ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
            ('train_num_episodes', train_stats['num_episodes'], '%3d'),
            ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
            ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
            ('train_exploration_epsilon', train_agent.exploration_epsilon,
             '%.3f'),
            ('train_state_value', train_stats['state_value'], '%.3f'),
            ('importance_sampling_exponent',
             train_agent.importance_sampling_exponent, '%.3f'),
            ('max_seen_priority', train_agent.max_seen_priority, '%.3f'),
            ('normalized_return', human_normalized_score, '%.3f'),
            ('capped_normalized_return', capped_human_normalized_score,
             '%.3f'),
            ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
        ]
        log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output)
        logging.info(log_output_str)
        writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
        state.iteration += 1
        checkpoint.save()

    # NOTE(review): not reached if an exception escapes the loop; consider
    # try/finally.
    writer.close()
def apply(self, x, config, num_classes, train=True):
    """Creates a model definition.

    Steps:
      1. Score candidate regions with a proposal network on shared ResNet-50
         features.
      2. Pick k regions per image with perturbed top-k.
      3. Combine them, via `weighted_anchor_aggregator`, into k part images of
         size 224x224 per input image.
      4. Predict logits from the whole image, from each part, and from the
         two combined.

    Args:
      x: Input images; x.shape[0] is the batch size and x.shape[3] the number
        of channels.
      config: Config providing `k`, `ptopk_sigma`, `ptopk_num_samples` and
        `communication`.
      num_classes: Number of output classes.
      train: Whether the model is training. Dropout is active only when True.

    Returns:
      A tuple `(all_logits, stats)`.
        - `all_logits` has three keys:
          - "raw_logits": logits from the whole image.
          - "concat_logits": logits from the whole image plus the mean part
            features.
          - "part_logits": logits per part, shape (b, k, num_classes).
        - `stats` holds intermediate values, including the entropy of the
          proposal scores used for regularization.
    """
    b, c = x.shape[0], x.shape[3]
    k = config.k
    sigma = config.ptopk_sigma
    num_samples = config.ptopk_num_samples

    # NOTE: "mutiplier" is misspelled, but it names a state variable, so
    # renaming it would likely break restoring existing checkpoints.
    sigma *= self.state("sigma_mutiplier", shape=(),
                        initializer=nn.initializers.ones).value

    stats = {"x": x, "sigma": sigma}

    # The same (`.shared`) ResNet-50 embeds the full image and the parts.
    feature_extractor = models.ResNet50.shared(train=train, name="ResNet_0")

    rpn_feature = feature_extractor(x)
    # The proposal network does not backpropagate into the backbone.
    rpn_scores, rpn_stats = ProposalNet(
        jax.lax.stop_gradient(rpn_feature),
        communication=Communication(config.communication),
        train=train)
    stats.update(rpn_stats)

    # rpn_scores are a list of score images. We keep track of the structure
    # because it is used in the aggregation step later-on.
    rpn_scores_shapes = [s.shape for s in rpn_scores]
    rpn_scores_flat = jnp.concatenate(
        [jnp.reshape(s, [b, -1]) for s in rpn_scores], axis=1)
    top_k_indicators = sample_patches.select_patches_perturbed_topk(
        rpn_scores_flat, k=k, sigma=sigma, num_samples=num_samples)
    top_k_indicators = jnp.transpose(top_k_indicators, [0, 2, 1])

    # Split the flat indicators back into one (b, k, H, W) weight map per
    # score image.
    offset = 0
    weights = []
    for sh in rpn_scores_shapes:
        size = sh[1] * sh[2]
        cur = top_k_indicators[:, :, offset:offset + size]
        weights.append(jnp.reshape(cur, [b, k, sh[1], sh[2]]))
        offset += size
    chex.assert_equal(offset, top_k_indicators.shape[-1])

    part_imgs = weighted_anchor_aggregator(x, weights)
    chex.assert_shape(part_imgs, (b * k, 224, 224, c))
    stats["part_imgs"] = jnp.reshape(part_imgs, [b, k * 224, 224, c])

    part_features = feature_extractor(part_imgs)
    # GAP the spatial dims.
    part_features = jnp.mean(part_features, axis=(1, 2))

    # 2048 is the ResNet-50 feature width assumed by these reshapes.
    # nn.dropout draws its own rng only when dropout is active. Passing
    # rng=nn.make_rng() eagerly would require a stochastic rng context even
    # at evaluation time, when dropout is a no-op.
    part_features = nn.dropout(  # features from parts
        jnp.reshape(part_features, [b * k, 2048]),
        0.5,
        deterministic=not train)
    features = nn.dropout(  # features from whole image
        jnp.reshape(jnp.mean(rpn_feature, axis=(1, 2)), [b, -1]),
        0.5,
        deterministic=not train)

    # Mean pool all part features, add it to features and predict logits.
    concat_out = jnp.mean(jnp.reshape(part_features, [b, k, 2048]),
                          axis=1) + features
    concat_logits = nn.Dense(concat_out, num_classes)
    raw_logits = nn.Dense(features, num_classes)
    part_logits = jnp.reshape(nn.Dense(part_features, num_classes), [b, k, -1])

    all_logits = {
        "raw_logits": raw_logits,
        "concat_logits": concat_logits,
        "part_logits": part_logits,
    }
    # Add entropy of the proposal scores for entropy regularization.
    stats["rpn_scores_entropy"] = jax.scipy.special.entr(
        jax.nn.softmax(stats["raw_scores"])).sum(axis=1).mean(axis=0)
    return all_logits, stats
def load_spectrophotometry(
    self,
    input_dir="./",
    write_subset=False,
    use_subset=False,
    subsampling=1,
    spec=True,
    phot=True,
):
    """Load photometric and/or spectroscopic `.npy` arrays from `input_dir`.

    Wavelength-grid files are always read. Per-object files use the suffix
    ".npy", or "2.npy" when `use_subset` is True.

    Args:
      input_dir: Directory prefix. It is string-concatenated with the file
        names, so it should end with a path separator.
      write_subset: If True, truncate every per-object array to its first
        50000 objects and save it under the "2.npy" suffix. Requires both
        `spec` and `phot`.
      use_subset: If True, read the "2.npy" subset files.
      subsampling: If > 1, keep every `subsampling`-th wavelength (and
        transfer-function redshift) and rescale the related indices and
        weights.
      spec: Whether to load the spectroscopic arrays.
      phot: Whether to load the photometric arrays.

    Raises:
      ValueError: If `write_subset` is requested without both `spec` and
        `phot`.
    """
    suffix = "2.npy" if use_subset else ".npy"

    self.input_dir = input_dir
    self.lamgrid = onp.load(self.input_dir + "lamgrid.npy")
    self.lam_phot_eff = onp.load(self.input_dir + "lam_phot_eff.npy")
    self.lam_phot_size_eff = onp.load(self.input_dir + "lam_phot_size_eff.npy")
    self.redshifts = onp.load(self.input_dir + "redshifts" + suffix)
    n_obj = self.redshifts.size
    self.n_obj = n_obj
    assert_shape(self.redshifts, (n_obj, ))

    if phot:
        # Transfer functions are rescaled by 1e-16 on load.
        self.transferfunctions = (
            onp.load(self.input_dir + "transferfunctions.npy") * 1e-16)
        self.transferfunctions_zgrid = onp.load(
            self.input_dir + "transferfunctions_zgrid.npy")
        # Axis 0 follows the redshift grid, axis 1 the wavelength grid.
        assert (self.transferfunctions.shape[0]
                == self.transferfunctions_zgrid.size)
        assert self.transferfunctions.shape[1] == self.lamgrid.size
        self.index_transfer_redshift = onp.load(
            self.input_dir + "index_transfer_redshift" + suffix)
        self.interprightindices_transfer = onp.load(
            self.input_dir + "interprightindices_transfer" + suffix)
        self.interpweights_transfer = onp.load(
            self.input_dir + "interpweights_transfer" + suffix)
        self.phot = onp.load(self.input_dir + "phot" + suffix)
        self.phot_invvar = onp.load(self.input_dir + "phot_invvar" + suffix)
        assert_shape(self.index_transfer_redshift, (n_obj, ))
        assert_shape(self.interprightindices_transfer, (n_obj, ))
        assert_shape(self.interpweights_transfer, (n_obj, ))
        n_pix_phot = self.phot.shape[1]
        assert_shape(self.phot, (n_obj, n_pix_phot))
        assert_shape(self.phot_invvar, (n_obj, n_pix_phot))
        self.n_pix_phot = n_pix_phot

    if spec:
        self.chi2s_sdss = onp.load(self.input_dir + "chi2s_sdss" + suffix)
        self.lamspec_waveoffset = int(
            onp.load(self.input_dir + "lamspec_waveoffset" + suffix))
        self.index_wave = onp.load(self.input_dir + "index_wave" + suffix)
        self.interprightindices = onp.load(
            self.input_dir + "interprightindices" + suffix)
        self.interpweights = onp.load(
            self.input_dir + "interpweights" + suffix)
        # Model spectra; the observed spectra are loaded below.
        self.specmod_sdss = onp.load(self.input_dir + "spec_mod" + suffix)
        self.spec = onp.load(self.input_dir + "spec" + suffix)
        self.spec_invvar = onp.load(self.input_dir + "spec_invvar" + suffix)
        n_pix_spec = self.spec.shape[1]
        self.n_pix_spec = n_pix_spec
        assert_shape(self.chi2s_sdss, (n_obj, ))
        assert_shape(self.index_wave, (n_obj, ))
        assert_shape(self.spec, (n_obj, n_pix_spec))
        assert_shape(self.specmod_sdss, (n_obj, n_pix_spec))
        assert_shape(self.spec_invvar, (n_obj, n_pix_spec))

    if write_subset:
        if not (spec and phot):
            raise ValueError(
                "write_subset requires both spec=True and phot=True.")
        n_keep = 50000
        suffix = "2.npy"
        # Truncate every per-object array (all asserted above to have n_obj
        # as their first axis) along the object axis only, so the saved
        # files agree on the number of objects. Several of these arrays are
        # 1-D, so they must not be sliced with [:n_keep, :].
        self.redshifts = self.redshifts[:n_keep]
        self.index_transfer_redshift = self.index_transfer_redshift[:n_keep]
        self.interprightindices_transfer = (
            self.interprightindices_transfer[:n_keep])
        self.interpweights_transfer = self.interpweights_transfer[:n_keep]
        self.phot = self.phot[:n_keep]
        self.phot_invvar = self.phot_invvar[:n_keep]
        self.chi2s_sdss = self.chi2s_sdss[:n_keep]
        self.index_wave = self.index_wave[:n_keep]
        self.spec = self.spec[:n_keep]
        self.spec_invvar = self.spec_invvar[:n_keep]
        self.specmod_sdss = self.specmod_sdss[:n_keep]
        self.n_obj = self.redshifts.size

        onp.save(self.input_dir + "index_wave" + suffix, self.index_wave)
        onp.save(self.input_dir + "interprightindices_transfer" + suffix,
                 self.interprightindices_transfer)
        onp.save(self.input_dir + "interpweights_transfer" + suffix,
                 self.interpweights_transfer)
        onp.save(self.input_dir + "index_transfer_redshift" + suffix,
                 self.index_transfer_redshift)
        onp.save(self.input_dir + "redshifts" + suffix, self.redshifts)
        onp.save(self.input_dir + "spec" + suffix, self.spec)
        onp.save(self.input_dir + "chi2s_sdss" + suffix, self.chi2s_sdss)
        onp.save(self.input_dir + "spec_invvar" + suffix, self.spec_invvar)
        onp.save(self.input_dir + "phot" + suffix, self.phot)
        onp.save(self.input_dir + "phot_invvar" + suffix, self.phot_invvar)
        onp.save(self.input_dir + "spec_mod" + suffix, self.specmod_sdss)
        # NOTE(review): "interprightindices", "interpweights" and
        # "lamspec_waveoffset" are read with the subset suffix when
        # use_subset=True, but are not written here.

    if subsampling > 1:
        self.lamgrid = self.lamgrid[::subsampling]
        if phot:
            # Subsample both the redshift axis (0) and wavelength axis (1).
            self.transferfunctions = self.transferfunctions[
                ::subsampling, ::subsampling, :]
            self.transferfunctions_zgrid = (
                self.transferfunctions_zgrid[::subsampling])
            self.index_transfer_redshift = (
                self.index_transfer_redshift // subsampling)
            # NOTE(review): the original author was unsure whether these two
            # rescalings are correct.
            self.interprightindices_transfer = (
                self.interprightindices_transfer // subsampling)
            self.interpweights_transfer = (
                self.interpweights_transfer / subsampling)
        if spec:
            self.lamspec_waveoffset = self.lamspec_waveoffset // subsampling
            self.spec = self.spec[:, ::subsampling]
            self.specmod_sdss = self.specmod_sdss[:, ::subsampling]
            self.spec_invvar = self.spec_invvar[:, ::subsampling]
            self.index_wave = self.index_wave // subsampling
            self.interprightindices = (
                self.interprightindices[:, ::subsampling] // subsampling)
            # Weights are diluted by the subsampling factor.
            self.interpweights = (
                self.interpweights[:, ::subsampling] / subsampling)
def epipolar_projection(
    self,
    key,  # pylint: disable=unused-argument
    target_rays,
    ref_worldtocamera,
    intrinsic_matrix,
    image_height,
    image_width,
    min_depth,
    max_depth,
):
  """Projects depth samples along target rays onto the reference views.

  Every target ray is sampled at `self.num_samples` depths spaced uniformly
  over [min_depth, max_depth] (endpoints included, as with `jnp.linspace`).
  The resulting 3D points are mapped into the reference camera(s) by
  `self.project2camera`, tracing the ray's epipolar line in each reference
  image. (The original author notes this was verified visually in colab.)

  Args:
    key: PRNG key. Accepted for interface compatibility; unused here.
    target_rays: object exposing `origins` and `directions` (typically the
      rays being rendered). The `[:, None]` indexing below assumes each is
      (num_rays, 3). NOTE(review): an earlier docstring claimed a leading
      batch dimension; with one the broadcast below would not line up.
    ref_worldtocamera: world-to-camera matrices of the reference views to
      project onto.
    intrinsic_matrix: camera intrinsics. Must have shape (1, 3, 4): a single
      camera shared by all views is the only supported case.
    image_height: image height in pixels.
    image_width: image width in pixels.
    min_depth: smallest sampling depth (often the camera's near plane).
    max_depth: largest sampling depth (often the camera's far plane).

  Returns:
    Tuple `(pcoords, valid_proj_mask, wcoords)`:
      pcoords: projected coordinates from `self.project2camera`, with the
        first coordinate clipped to [0, image_height - 1] and the remaining
        coordinate(s) to [0, image_width - 1]. Previously documented as
        (#near_cam, num_rays, num_samples, 2); depends on project2camera.
      valid_proj_mask: elementwise product of the in-front-of-camera mask
        and the inside-image mask. The inside-image test runs on the
        unclipped coordinates.
      wcoords: world-space sample points, (num_rays, num_samples, 3) under
        the (num_rays, 3) input assumption above.
  """
  # Multi-camera intrinsics are not supported; enforce the single shared one.
  chex.assert_shape(intrinsic_matrix, (1, 3, 4))

  depths = jnp.linspace(min_depth, max_depth, self.num_samples)

  # Point at depth t along a ray: origin + t * direction. Inserting an axis
  # after the ray axis lets the (num_samples, 1) depth column broadcast.
  ray_origins = target_rays.origins[:, None]
  ray_dirs = target_rays.directions[:, None]
  sample_points = ray_origins + ray_dirs * depths[Ellipsis, None]

  # Append w = 1 so the points can be multiplied by 4-column matrices.
  unit_w = jnp.ones_like(sample_points[Ellipsis, 0:1])
  sample_points_h = jnp.concatenate([sample_points, unit_w], axis=-1)

  projected, in_front_mask = self.project2camera(
      sample_points_h, ref_worldtocamera, intrinsic_matrix)

  # Bounds check is done before clipping, so out-of-image projections stay
  # flagged as invalid even after being clamped onto the border below.
  in_bounds_mask = self.inside_image(projected, image_height, image_width)

  first_coord = jnp.clip(projected[Ellipsis, 0:1], 0, image_height - 1)
  rest_coords = jnp.clip(projected[Ellipsis, 1:], 0, image_width - 1)
  clipped = jnp.concatenate([first_coord, rest_coords], axis=-1)

  # Valid = in front of the camera AND inside the image bounds.
  valid_proj_mask = in_front_mask * in_bounds_mask
  return clipped, valid_proj_mask, sample_points