def test_vmap_not_batched(self):
    """Tap a (closed-over constant, batched value) pair under vmap.

    Checks both the traced jaxpr text and what the tap wrote to
    ``testing_stream``.
    """
    x = 3.

    def func(y):
        # Only y is mapped by vmap; x is captured from the enclosing scope.
        _, y = hcb.id_print((x, y), output_stream=testing_stream)
        return x + y

    batched_func = api.vmap(func)
    batch_args = jnp.array([jnp.float32(4.), jnp.float32(5.)])

    # NOTE(review): expected texts are reproduced verbatim from this view;
    # confirm their line breaks against the upstream test.
    expected_jaxpr = """ { lambda ; a. let b c = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*]) func=_print transforms=(('batch', (None, 0)),) ] 3.00 a d = add c 3.00 in (d,) }"""
    assertMultiLineStrippedEqual(
        self, expected_jaxpr, str(api.make_jaxpr(batched_func)(batch_args)))

    with hcb.outfeed_receiver():
        _ = batched_func(batch_args)

    expected_output = """ transforms: ({'name': 'batch', 'batch_dims': (None, 0)},) [ 3.00 [4.00 5.00] ]"""
    assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
    testing_stream.reset()
def test_vmap(self):
    """Run ``fun1`` under vmap and check the jaxpr and the tapped output."""
    batched_fun1 = api.vmap(fun1)
    batch_args = np.array([np.float32(4.), np.float32(5.)])

    # NOTE(review): expected texts are reproduced verbatim from this view;
    # confirm their line breaks against the upstream test.
    expected_jaxpr = """ { lambda ; a. let b = mul a 2.00 c = id_tap[ arg_treedef=* batch_dims=(0,) func=_print transforms=('batch',) what=a * 2 ] b d = mul c 3.00 e f = id_tap[ arg_treedef=* batch_dims=(0, 0) func=_print nr_untapped=1 transforms=('batch',) what=y * 3 ] d c g = pow f 2.00 in (g,) }"""
    assertMultiLineStrippedEqual(
        self, expected_jaxpr, str(api.make_jaxpr(batched_fun1)(batch_args)))

    with hcb.outfeed_receiver():
        # The result is not inspected; only the tap output matters here.
        _ = batched_fun1(batch_args)

    expected_output = """ batch_dims: (0,) transforms: ('batch',) what: a * 2 [ 8.00 10.00] batch_dims: (0, 0) transforms: ('batch',) what: y * 3 [24.00 30.00]"""
    assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
    testing_stream.reset()
def update(state):
    """Perform one damped, regularised least-squares step.

    Unpacks the previous state, proposes new parameters ``p = p_ - dp``,
    adapts the damping ``mu``, reverts to the previous parameters on a bad
    step, refreshes ``alpha`` through ``update_hyperparams`` on an accepted
    step, and returns a new ``LevenbergMarquardtBRState``.

    Relies on names defined outside this block (``jacobian``, ``error``,
    ``sum_squares``, ``solve``, ``cond``, ``update_hyperparams``, ``rho_c``,
    ``rho_min``, ``c``, ``mu_min``, ``mu_max``).
    """
    data, p_, e_, C_, mu, alpha, iters, _ = state
    x, y = data
    mu = np.float32(mu)
    alpha_ = np.float32(alpha)
    # Normal-equation pieces: J^T J and the regularised gradient term.
    J = jacobian(p_, x, y)
    H = J.T @ J
    Je = J.T @ e_ + alpha_ * p_
    I = np.diag_indices_from(H)
    # Solve with (alpha_ + mu) added on the diagonal of H, then evaluate the
    # candidate parameters, their errors and the regularised cost.
    dp = solve(H.at[I].add(alpha_ + mu), Je, sym_pos=True)
    p = p_ - dp
    e = error(p, x, y)
    C = (sum_squares(e) + alpha * sum_squares(p)) / 2
    # Ratio of the achieved cost decrease to dp^T (mu * dp + Je).
    rho = (C_ - C) / (dp.T @ (mu * dp + Je))
    # Shrink the damping (never below mu_min) when rho exceeds rho_c.
    mu = np.where(rho > rho_c, np.maximum(mu / c, mu_min), mu)
    # A step is bad when rho is below rho_min or p contains NaN: grow the
    # damping (capped at mu_max) and fall back to the previous p and e.
    bad_step = (rho < rho_min) | np.any(np.isnan(p))
    mu = np.where(bad_step, np.minimum(c * mu, mu_max), mu)
    p = cond(bad_step, lambda t: t[0], lambda t: t[1], (p_, p))
    e = cond(bad_step, lambda t: t[0], lambda t: t[1], (e_, e))
    # Sums of squares for the retained p/e; `improved` is also set on a bad
    # step because of the `| bad_step` term.
    sse = sum_squares(e)
    ssp = sum_squares(p)
    C = np.where(bad_step, C_, C)
    improved = (C_ > C) | bad_step
    # On a bad step the bundle passes through unchanged, otherwise
    # update_hyperparams produces it; only its first element (alpha) is kept.
    # The cost is then recomputed with that alpha, replacing the C above.
    bundle = (alpha, H, I, sse, ssp, x.size)
    alpha, *_ = cond(bad_step, lambda t: t, update_hyperparams, bundle)
    C = (sse + alpha * ssp) / 2
    # NOTE(review): `iters + ~bad_step` presumably counts accepted steps only;
    # this assumes bad_step is a boolean array so `~` is a logical not.
    return LevenbergMarquardtBRState(data, p, e, C, mu, alpha, iters + ~bad_step, improved)
def _bl_update(H, C, R, state):
    """Re-estimate the two regularisation weights from ``trace(H^-1)``.

    ``state`` unpacks as ``(G, (alpha, _), mu, tau)``. Only the weight pair
    is replaced; ``G``, ``mu`` and ``tau`` are passed through untouched.
    Uses ``solve``, ``I``, ``n`` and ``x`` from the enclosing scope.
    """
    objective, (alpha_prev, _), damping, tail = state
    # NOTE(review): sym_pos is given the string "sym" (merely truthy);
    # confirm whether a boolean or `assume_a="sym"` was intended.
    inv_trace = np.trace(solve(H, I, sym_pos="sym"))
    # gamma is computed with the previous alpha, before alpha is replaced.
    gamma = n - alpha_prev * inv_trace
    alpha_next = np.float32(n / (2 * R + inv_trace))
    beta_next = np.float32((x.shape[0] - gamma) / (2 * C))
    return objective, (alpha_next, beta_next), damping, tail
def test_vmap(self):
    """Run ``fun1`` under vmap and check the jaxpr and the tapped output."""
    batched_fun1 = api.vmap(fun1)
    batch_args = jnp.array([jnp.float32(4.), jnp.float32(5.)])

    # NOTE(review): expected texts are reproduced verbatim from this view;
    # confirm their line breaks against the upstream test.
    expected_jaxpr = """ { lambda ; a. let b = mul a 2.00 c = id_tap[ arg_treedef=* func=_print transforms=(('batch', (0,)),) what=a * 2 ] b d = mul c 3.00 e f = id_tap[ arg_treedef=* func=_print nr_untapped=1 transforms=(('batch', (0, 0)),) what=y * 3 ] d c g = integer_pow[ y=2 ] f in (g,) }"""
    assertMultiLineStrippedEqual(
        self, expected_jaxpr, str(api.make_jaxpr(batched_fun1)(batch_args)))

    with hcb.outfeed_receiver():
        _ = batched_fun1(batch_args)

    expected_output = """ transforms: ({'name': 'batch', 'batch_dims': (0,)},) what: a * 2 [ 8.00 10.00] transforms: ({'name': 'batch', 'batch_dims': (0, 0)},) what: y * 3 [24.00 30.00]"""
    assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
    testing_stream.reset()
def get_datasets(name):
    """Load the 'train' and 'test' splits of a TFDS dataset into memory.

    Each split is fetched as a single batch (``batch_size=-1``) and its
    'image' entry is converted to float32 and divided by 255.

    Returns:
      A ``(train_ds, test_ds)`` pair.
    """
    builder = tfds.builder(name)
    builder.download_and_prepare()

    loaded = {}
    for split in ('train', 'test'):
        loaded[split] = tfds.as_numpy(
            builder.as_dataset(split=split, batch_size=-1))

    for ds in loaded.values():
        ds['image'] = jnp.float32(ds['image']) / 255.

    return loaded['train'], loaded['test']
class LevenbergMaquardtBayes(
        namedtuple(
            "LevenbergMaquardtBayes",
            ("μi", "μs", "μmin", "μmax"),
            defaults=(
                np.float32(0.005),  # μi
                np.float32(10),     # μs
                np.float32(5e-16),  # μmin
                np.float32(1e10),   # μmax
            ))):
    """Immutable bundle of four float32 damping settings.

    Fields, in order: ``μi``, ``μs``, ``μmin``, ``μmax``. Every field has a
    default, so ``LevenbergMaquardtBayes()`` is a valid call.

    NOTE(review): the names suggest initial value, scale factor and lower /
    upper bounds of a damping parameter — confirm against the code that
    consumes this tuple.
    """

    # Without empty __slots__ each instance of a namedtuple subclass gets its
    # own __dict__, which wastes memory and allows stray attributes.
    __slots__ = ()
def test_fmatrix():
    """Smoke test: build a three-component ``FMatrix`` and evaluate it once.

    The result is not inspected; the test only checks that evaluation with
    these keyword arguments does not raise.
    """
    sed = FMatrix(['dustmbb', 'syncpl', 'cmb'])
    frequencies = np.array([27., 39., 93., 145., 225., 280.])
    _ = sed(
        nu=frequencies,
        nu_ref_d=np.float32(353),
        nu_ref_s=np.float32(23.),
        beta_d=np.float32(1.5),
        beta_s=np.float32(-3.),
        T_d=np.float32(20),
    )
def test_vmap(self):
    """Run ``fun1`` under vmap and check only the tapped output."""
    batched_fun1 = api.vmap(fun1)
    batch_args = jnp.array([jnp.float32(4.), jnp.float32(5.)])
    # The jaxpr text is deliberately not asserted in this variant.

    with hcb.outfeed_receiver():
        _ = batched_fun1(batch_args)

    # NOTE(review): expected text is reproduced verbatim from this view;
    # confirm its line breaks against the upstream test.
    expected_output = """ transforms: ({'name': 'batch', 'batch_dims': (0,)},) what: a * 2 [ 8.00 10.00] transforms: ({'name': 'batch', 'batch_dims': (0, 0)},) what: y * 3 [24.00 30.00]"""
    assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
    testing_stream.reset()
def initialize_prior(
    self,
) -> Callable[[Tuple[Any, Any]], List[jnp.ndarray]]:
    """Return a factory that builds constant prior arrays of a given shape.

    ``self.prior_mean`` and ``self.prior_cov`` are converted to float32 once,
    when this method is called. The returned ``prior_fn(shape)`` yields
    ``[mean, cov]``, each an array of ``shape`` filled with the matching
    constant.
    """
    mean_value = jnp.float32(self.prior_mean)
    cov_value = jnp.float32(self.prior_cov)

    def prior_fn(shape):
        filled_mean = mean_value * jnp.ones(shape)
        filled_cov = cov_value * jnp.ones(shape)
        return [filled_mean, filled_cov]

    return prior_fn
def MSAWeight_PB(msa):
    """Compute per-sequence weights for an alignment and store them in ``msa.wgt``.

    Reads ``msa.abc.charmap``, ``msa.abc.q``, ``msa.ax`` and ``msa.ax_1hot``;
    writes ``msa.wgt``; returns ``None``. The per-position term is
    ``1 / (r[i] * c[i, ax[n, i]])`` with gap positions zeroed. Terms are
    summed over positions, divided by the ungapped length and rescaled so
    the weights sum to ``N``.
    """
    gap_idx = msa.abc.charmap['-']
    n_symbols = msa.abc.q  # read as in the original; not used below
    ax = msa.ax
    n_seqs, n_cols = ax.shape

    # Step 1: per-column residue counts, with the gap character zeroed.
    counts = np.sum(msa.ax_1hot, axis=0)
    counts = index_update(counts, index[:, gap_idx], 0)
    # column_ids[n, i] = i, in the shape expected by the count lookup.
    column_ids = np.int16(
        np.tensordot(np.ones(n_seqs), np.arange(n_cols), axes=0))
    # residue_counts[n, i] = counts[i, ax[n, i]]
    residue_counts = Get_Henikoff_Counts_Residue(column_ids, ax, counts)

    # Step 2: number of distinct (non-gap) characters per column, tiled so
    # that distinct_tiled[n, i] = distinct[i].
    distinct = np.float32(np.sum(np.array(counts > 0), axis=1))
    distinct_tiled = np.tensordot(np.ones(n_seqs), distinct, axes=0)

    # Step 3: ungapped length of every sequence.
    not_gap = np.array(ax != gap_idx)
    ungapped_len = np.float32(np.sum(not_gap, axis=1))

    # Step 4: unnormalised weights. Gap positions yield inf * 0 = nan, which
    # nan_to_num turns into 0 before summing across positions.
    terms = np.reciprocal(residue_counts * distinct_tiled)
    terms = np.nan_to_num(terms * not_gap)
    raw_weights = np.sum(terms, axis=1) / ungapped_len

    # Step 5: rescale so the weights sum to the number of sequences.
    msa.wgt = (raw_weights * np.float32(n_seqs)) / np.sum(raw_weights)
    return
def test_jvp(self):
    """Check the jaxpr and the tapped output of ``fun1`` under ``api.jvp``.

    The primal and tangent results are asserted as well (100. and 4. for
    primal 5. and tangent 0.1).
    """
    jvp_fun1 = lambda x, xt: api.jvp(fun1, (x, ), (xt, ))
    # NOTE(review): expected texts are kept exactly as they appear in this
    # view; confirm their line breaks against the upstream test.
    assertMultiLineStrippedEqual(
        self,
        """ { lambda ; a b. let c = mul a 2.00 d = id_tap[ arg_treedef=* func=_print nr_untapped=0 what=a * 2 ] c e = mul d 3.00 f g = id_tap[ arg_treedef=* func=_print nr_untapped=1 what=y * 3 ] e d h = mul g g i = mul b 2.00 j k = id_tap[ arg_treedef=* func=_print nr_untapped=1 transforms=('jvp',) what=a * 2 ] i d l = mul j 3.00 m n o = id_tap[ arg_treedef=* func=_print nr_untapped=2 transforms=('jvp',) what=y * 3 ] l j f p = mul n g q = mul g n r = add_any p q in (h, r) }""",
        str(api.make_jaxpr(jvp_fun1)(jnp.float32(5.), jnp.float32(0.1))))
    # Run inside the receiver so the taps are delivered to testing_stream.
    with hcb.outfeed_receiver():
        res_primals, res_tangents = jvp_fun1(jnp.float32(5.), jnp.float32(0.1))
    self.assertAllClose(100., res_primals, check_dtypes=False)
    self.assertAllClose(4., res_tangents, check_dtypes=False)
    # Each tap point prints its primal value and then its tangent value.
    assertMultiLineStrippedEqual(
        self,
        """ what: a * 2 10.00 transforms: ('jvp',) what: a * 2 0.20 what: y * 3 30.00 transforms: ('jvp',) what: y * 3 0.60""",
        testing_stream.output)
    testing_stream.reset()
def Prior_Laplace(f1, f2, N, C):
    """Blend single-site and pairwise frequencies with uniform pseudocounts.

    Args:
      f1: array of shape ``(L, q)``.
      f2: array broadcastable against shape ``(L, q, L, q)``.
      N: weight given to the observed frequencies.
      C: weight given to the uniform pseudocount.

    Returns:
      ``(f1_prior, f2_prior)``, both normalised by ``1 / (N + C)``. The
      pairwise pseudocount is omitted wherever the two site indices are
      equal.
    """
    n_sites, n_states = f1.shape
    states_f32 = np.float32(n_states)
    weight_f32 = np.float32(N)

    # New normalisation: 1 / (N + C).
    scale = 1. / (weight_f32 + C)

    # Mask of shape (L, 1, L, 1): zero where the site indices coincide.
    off_diagonal = np.reshape(1 - np.eye(n_sites), (n_sites, 1, n_sites, 1))

    single_pseudo = C / states_f32
    pair_pseudo = C / (states_f32 * states_f32)

    f1_prior = scale * (single_pseudo + weight_f32 * f1)
    f2_prior = scale * ((pair_pseudo * off_diagonal) + weight_f32 * f2)
    return f1_prior, f2_prior
def test_grad_double(self):
    """Check the tapped output of a second-order gradient through ``id_print``.

    Output is checked twice: after merely building the jaxpr, and after
    actually evaluating the second derivative at 5.
    """
    def func(x):
        y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
        return x * (y * 3.)

    grad_func = api.grad(api.grad(func))
    # Just making the Jaxpr invokes the id_print twice
    _ = api.make_jaxpr(grad_func)(5.)
    # Wait for outstanding callbacks before reading the stream.
    hcb.barrier_wait()
    # NOTE(review): expected texts are kept exactly as they appear in this
    # view; confirm their line breaks against the upstream test.
    assertMultiLineStrippedEqual(
        self,
        """ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 3.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 2.00""",
        testing_stream.output)
    testing_stream.reset()
    # func(x) = 6 * x**2, so the second derivative is 12.
    res_grad = grad_func(jnp.float32(5.))
    self.assertAllClose(12., res_grad, check_dtypes=False)
    hcb.barrier_wait()
    assertMultiLineStrippedEqual(
        self,
        """ what: x * 2 10.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 15.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 2.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 3.00""",
        testing_stream.output)
    testing_stream.reset()
def test_grad_simple(self):
    """Differentiate through two chained ``id_print`` taps and check the output."""
    def func(x):
        doubled = hcb.id_print(
            x * 2., what="x * 2", output_stream=testing_stream)
        tripled = hcb.id_print(
            doubled * 3., what="y * 3", output_stream=testing_stream)
        return x * tripled

    grad_func = api.grad(func)
    # The jaxpr text is deliberately not asserted in this test.

    gradient = grad_func(jnp.float32(5.))
    # func(x) = 6 * x**2, so the gradient at 5 is 2 * 5 * 6.
    self.assertAllClose(2. * 5. * 6., gradient, check_dtypes=False)
    hcb.barrier_wait()

    # NOTE(review): expected text is reproduced verbatim from this view;
    # confirm its line breaks against the upstream test.
    expected_output = """ what: x * 2 10.00 what: y * 3 30.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: y * 3 5.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 15.00"""
    assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
    testing_stream.reset()
def align_examples(rng, x, x_index, y):
    """Randomly pair an item of ``x`` with an equal-valued item of ``y``.

    The rows of ``y`` are visited in a random order drawn from ``rng`` and the
    first position where the comparison with ``x`` holds is selected. The
    mapping is not one-to-one. In the LTC task, labels are passed in so the
    alignment is label-based. ``x_index`` is the batch position of ``x``; it
    is returned unchanged, which is what makes the function usable under vmap.

    Args:
      rng: an array of jax PRNG keys.
      x: jnp.array; Matrix of shape `[N, M]`.
      x_index: jnp.array; Vector of shape `[N,]`.
      y: jnp.array; Matrix of shape `[N, M]`.

    Returns:
      indices of aligned pairs.
    """
    x, x_index, y = jnp.array(x), jnp.array(x_index), jnp.array(y)

    visiting_order = jax.random.permutation(rng, jnp.arange(len(y)))
    matches = jnp.float32(x == y[visiting_order])
    # argmax works on the flattened comparison and picks the first maximum.
    first_hit = jnp.argmax(matches)
    return x_index, visiting_order[first_hit]
def target_m_dqn(model, target_network, states, next_states, actions, rewards,
                 terminals, cumulative_gamma, tau, alpha, clip_value_min):
    """Compute the target Q-value. Munchausen DQN

    The target is ``rewards + alpha * clip(tau_log_pi(a|s))
    + cumulative_gamma * soft_next_value * (1 - terminals)``, wrapped in
    ``stop_gradient``. ``model`` is accepted but not used here.
    """
    # Q-values of the target network on current and next states.
    batched_target = jax.vmap(target_network, in_axes=(0))
    q_current = jnp.squeeze(batched_target(states).q_values)
    q_next = jnp.squeeze(batched_target(next_states).q_values)

    # Scaled log-policies and the policy implied by the next-state Q-values.
    next_tau_log_pi = stable_scaled_log_softmax(q_next, tau, axis=1)
    next_pi = stable_softmax(q_next, tau, axis=1)
    current_tau_log_pi = stable_scaled_log_softmax(q_current, tau, axis=1)

    # Next-state value under that policy.
    soft_next_value = jnp.sum((q_next - next_tau_log_pi) * next_pi, axis=1)

    # Scaled log-policy of the action actually taken, clipped from below.
    action_mask = nn.one_hot(actions, q_current.shape[-1])
    taken_tau_log_pi = jnp.sum(current_tau_log_pi * action_mask, axis=1)
    taken_tau_log_pi = jnp.clip(
        taken_tau_log_pi, a_min=clip_value_min, a_max=1)

    bonus = alpha * taken_tau_log_pi
    not_terminal = 1. - jnp.float32(terminals)
    target = (rewards + bonus +
              cumulative_gamma * soft_next_value * not_terminal)
    return jax.lax.stop_gradient(target)
def update(state):
    """Perform one damped least-squares step and return the next state.

    Proposes ``p = p_ - dp``, adapts the damping ``mu`` from the ratio
    ``rho``, reverts to the previous parameters on a bad step, and returns a
    new ``LevenbergMarquardtState``.

    Relies on names defined outside this block (``jacobian``,
    ``damped_hessian``, ``jac_err_prod``, ``error``, ``cost``, ``solve``,
    ``cond``, ``rho_c``, ``rho_min``, ``c``, ``mu_min``, ``mu_max``).
    """
    data, p_, e_, C_, mu, iters, _ = state
    x, y = data
    mu = np.float32(mu)
    # Linear system for the step, built at the previous parameters.
    J = jacobian(p_, x, y)
    H = damped_hessian(J, mu)
    Je = jac_err_prod(J, e_, p_)
    # Candidate parameters, their errors and cost.
    dp = solve(H, Je, sym_pos=True)
    p = p_ - dp
    e = error(p, x, y)
    C = cost(e, p)
    # Ratio of the achieved cost decrease to dp^T (mu * dp + Je).
    rho = (C_ - C) / (dp.T @ (mu * dp + Je))
    # Shrink the damping (never below mu_min) when rho exceeds rho_c.
    mu = np.where(rho > rho_c, np.maximum(mu / c, mu_min), mu)
    # A step is bad when rho is below rho_min or p contains NaN: grow the
    # damping (capped at mu_max) and fall back to the previous p, e and C.
    bad_step = (rho < rho_min) | np.any(np.isnan(p))
    mu = np.where(bad_step, np.minimum(c * mu, mu_max), mu)
    p = cond(bad_step, lambda t: t[0], lambda t: t[1], (p_, p))
    e = cond(bad_step, lambda t: t[0], lambda t: t[1], (e_, e))
    C = np.where(bad_step, C_, C)
    # `improved` is also set on a bad step because of the `| bad_step` term.
    improved = (C_ > C) | bad_step
    # NOTE(review): `iters + ~bad_step` presumably counts accepted steps only;
    # this assumes bad_step is a boolean array so `~` is a logical not.
    return LevenbergMarquardtState(data, p, e, C, mu, iters + ~bad_step, improved)
def test_grad_primal_unused(self):
    """Check taps under ``grad`` when the tapped primal is not needed backwards.

    Both the jaxpr text (a constant 6.00) and the stream output are
    asserted, first for tracing alone and then for a real evaluation.
    """
    # The output of id_print is not needed for backwards pass
    def func(x):
        return 2. * hcb.id_print(
            x * 3., what="x * 3", output_stream=testing_stream)

    grad_func = api.grad(func)
    # NOTE(review): expected texts are kept exactly as they appear in this
    # view; confirm their line breaks against the upstream test.
    with hcb.outfeed_receiver():
        assertMultiLineStrippedEqual(
            self,
            """ { lambda ; a. let in (6.00,) }""",
            str(api.make_jaxpr(grad_func)(5.)))
    # Just making the Jaxpr invokes the id_print once
    assertMultiLineStrippedEqual(
        self,
        """ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3 2.00""",
        testing_stream.output)
    testing_stream.reset()
    # Now evaluate for real: the primal tap fires, then the transposed one.
    with hcb.outfeed_receiver():
        res_grad = grad_func(jnp.float32(5.))
    self.assertAllClose(6., res_grad, check_dtypes=False)
    assertMultiLineStrippedEqual(
        self,
        """ what: x * 3 15.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3 2.00""",
        testing_stream.output)
    testing_stream.reset()
def multinomial_mode(
    distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray]
) -> jnp.DeviceArray:
    """Return the one-hot style mode of a batch of categorical probabilities.

    Args:
      distribution_or_probs: either a `tfp.distributions.Distribution`, whose
        `probs_parameter()` is used, or the category probabilities
        themselves. In both cases the probabilities are assumed to have
        shape (batch_size, dim).

    Returns:
      `DeviceArray`, float32, (batch_size, dim). Each row is 1 at the most
      probable category. If several categories tie for the maximum, each of
      them gets an equal share, so a row can hold several non-zero entries.
    """
    probs = (
        distribution_or_probs.probs_parameter()
        if isinstance(distribution_or_probs, tfd.Distribution)
        else distribution_or_probs
    )

    row_max = jnp.max(probs, axis=1, keepdims=True)
    is_max = jnp.int32(jnp.equal(probs, row_max))
    tie_count = jnp.sum(is_max, axis=1, keepdims=True)
    return jnp.float32(is_max / tie_count)
def testCondTypeErrors(self):
    """Check the TypeError messages ``lax.cond`` raises for ill-typed arguments."""
    def returns_one(top):
        return 1.

    def returns_two(fop):
        return 2.

    def raises_matching(message):
        # Regex match against the escaped literal message.
        return self.assertRaisesRegex(TypeError, re.escape(message))

    # The predicate is a function instead of a boolean or number.
    with raises_matching(
            "Pred type must be either boolean or number, got <function"):
        lax.cond(lambda x: True, 1., returns_one, 2., returns_two)

    # The predicate is a string.
    with raises_matching(
            "Pred type must be either boolean or number, got foo."):
        lax.cond("foo", 1., returns_one, 2., returns_two)

    # The predicate is not a scalar.
    with raises_matching(
            "Pred must be a scalar, got (1.0, 1.0) of shape (2,)."):
        lax.cond((1., 1.), 1., returns_one, 2., returns_two)

    # The branches return different tree structures.
    with raises_matching(
            "true_fun and false_fun output must have same type structure, got * and PyTreeDef(tuple, [*,*])."):
        lax.cond(True, 1., returns_one, 2., lambda fop: (2., 2.))

    # The branches return the same structure but different shapes.
    shape_mismatch = (
        "true_fun and false_fun output must have identical types, got\n"
        "ShapedArray(float32[1])\n"
        "and\n"
        "ShapedArray(float32[]).")
    with self.assertRaisesWithLiteralMatch(TypeError, shape_mismatch):
        lax.cond(True,
                 1., lambda top: np.array([1.], np.float32),
                 2., lambda fop: np.float32(1.))
def test_grad_primal_unused(self): raise SkipTest("broken by omnistaging") # TODO(mattjj,gnecula): update # The output of id_print is not needed for backwards pass def func(x): return 2. * hcb.id_print( x * 3., what="x * 3", output_stream=testing_stream) grad_func = api.grad(func) jaxpr = str(api.make_jaxpr(grad_func)(5.)) # Just making the Jaxpr invokes the id_print once hcb.barrier_wait() assertMultiLineStrippedEqual( self, """ { lambda ; a. let in (6.00,) }""", jaxpr) assertMultiLineStrippedEqual( self, """ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3 2.00""", testing_stream.output) testing_stream.reset() res_grad = grad_func(jnp.float32(5.)) hcb.barrier_wait() self.assertAllClose(6., res_grad, check_dtypes=False) assertMultiLineStrippedEqual( self, """ what: x * 3 15.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3 2.00""", testing_stream.output) testing_stream.reset()
def mnist_images():
    """Load MNIST through TFDS and return ``(train_images, test_images)``.

    Each element is the first batch of its split, converted to float32,
    divided by 256 and reshaped to ``(-1, 784)``. The training split is
    shuffled before batching.
    """
    import tensorflow_datasets as tfds

    def first_batch_as_rows(batched):
        first = next(tfds.as_numpy(batched))
        pixels = np.float32(first['image']) / 256
        return np.reshape(pixels, (-1, 784))

    dataset = tfds.load("mnist:1.0.0")
    train_rows = first_batch_as_rows(
        dataset['train'].shuffle(50000).batch(50000))
    test_rows = first_batch_as_rows(dataset['test'].batch(10000))
    return (train_rows, test_rows)
def get_logit_snip_masks(params, nn_density_level, predict, x_batch,
                         batch_input_shape, GlOBAL_PRUNE_BOOL=True):
    """Build 0/1 pruning masks from the gradient of the squared logit norm.

    The saliency of a weight is ``|weight * d(sum(logits**2))/d(weight)|``.
    Weights whose saliency is at or above a threshold get mask value 1, the
    others 0. The threshold is the entry at position
    ``int((1 - nn_density_level) * count)`` of the sorted saliencies.

    Args:
      params: sequence of per-layer parameter containers. A container of
        length 2 is treated as ``(weights, biases)``; a container of length
        0 or 1 as a parameter-free layer.
      nn_density_level: fraction of weights to keep.
      predict: callable ``predict(params, x)`` returning the logits.
      x_batch: input batch; reshaped to ``batch_input_shape`` before use.
      batch_input_shape: shape passed to ``x_batch.reshape``.
      GlOBAL_PRUNE_BOOL: if truthy, one threshold is computed over all layers
        pooled together; otherwise one threshold per layer.

    Returns:
      A list with one entry per layer: a float32 mask shaped like the layer's
      weights, or ``[]`` for a parameter-free layer.

    Raises:
      NotImplementedError: if a layer holds more than two parameter entries.
    """
    def norm_square_logits(params, f, x):
        return np.sum(f(params, x) ** 2)

    init_grads = grad(norm_square_logits)(
        params, predict, x_batch.reshape(batch_input_shape))

    def threshold_for(flat_scores):
        # Entries below the value at this sorted position are pruned.
        cutoff = int((1 - nn_density_level) * len(flat_scores))
        return np.sort(flat_scores)[cutoff]

    # Compute each layer's saliency once; None marks a parameter-free layer.
    saliencies = []
    for layer_params, layer_grads in zip(params, init_grads):
        if len(layer_params) < 2:
            saliencies.append(None)
        elif len(layer_params) == 2:
            saliencies.append(np.abs(layer_params[0] * layer_grads[0]))
        else:
            raise NotImplementedError(
                "layers with more than two parameter entries are not supported")

    if GlOBAL_PRUNE_BOOL:
        # Global pruning: one threshold over the pooled saliencies.
        pooled = np.hstack([s.flatten() for s in saliencies if s is not None])
        thresholds = [threshold_for(pooled)] * len(saliencies)
    else:
        # Layerwise pruning: each weight layer gets its own threshold.
        thresholds = [None if s is None else threshold_for(s.flatten())
                      for s in saliencies]

    # 0 for weights with saliency below the threshold, 1 otherwise.
    return [[] if s is None else np.float32(s >= t)
            for s, t in zip(saliencies, thresholds)]
def test_vmap_not_batched(self):
    """Tap a (closed-over constant, batched value) pair under vmap.

    Only the text written to ``testing_stream`` is checked.
    """
    x = 3.

    def func(y):
        # Only y is mapped by vmap; x is captured from the enclosing scope.
        _, y = hcb.id_print((x, y), output_stream=testing_stream)
        return x + y

    batched_func = api.vmap(func)
    batch_args = jnp.array([jnp.float32(4.), jnp.float32(5.)])
    # The jaxpr text is deliberately not asserted in this variant.

    with hcb.outfeed_receiver():
        _ = batched_func(batch_args)

    # NOTE(review): expected text is reproduced verbatim from this view;
    # confirm its line breaks against the upstream test.
    expected_output = """ transforms: ({'name': 'batch', 'batch_dims': (None, 0)},) [ 3.00 [4.00 5.00] ]"""
    assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
    testing_stream.reset()
def _lm_update(θ, H, Je, y, Λ, state):
    """Take one damped step from ``θ`` and bundle the result in an ``LMState``.

    ``Λ`` unpacks as the two weights ``(alpha, beta)`` of the combined
    objective ``beta * cost + alpha * regularizer``. The returned state
    carries the damping multiplied by ``μs``. Uses ``solve``, ``I``, ``x``,
    ``errors``, ``obj`` and ``μs`` from the enclosing scope.
    """
    reg_weight, cost_weight = Λ
    # NOTE(review): sym_pos is given the string "sym" (merely truthy);
    # confirm whether a boolean or `assume_a="sym"` was intended.
    step = solve(H + state.μ * I, Je, sym_pos="sym").T
    candidate = θ - step
    residuals = errors(candidate, x, y)
    data_cost = obj.cost(residuals)
    # NOTE(review): the regularizer is evaluated at the previous θ, not at
    # the new candidate — confirm this is intended.
    reg_cost = obj.regularizer(θ)
    total = np.float32(cost_weight * data_cost + reg_weight * reg_cost)
    return LMState(candidate, residuals, total, data_cost, reg_cost,
                   state.μ * μs)
def _setup_toy_data(self, n=32768):
    """Create deterministic toy arrays of length ``n``.

    Returns:
      ``(x, values, tangents)`` where ``x`` is ``0..n-1`` as float32, and
      ``values`` and ``tangents`` are independent standard-normal draws
      derived from ``PRNGKey(0)``.
    """
    positions = jnp.float32(jnp.arange(n))

    # Split a fresh subkey off the running key for each draw.
    running_key = random.PRNGKey(0)
    running_key, values_key = random.split(running_key)
    values = random.normal(values_key, shape=[n])
    running_key, tangents_key = random.split(running_key)
    tangents = random.normal(tangents_key, shape=[n])

    return positions, values, tangents
def sample_random_powerlaw(key, N, power):
    """Sample a random signal whose decay vector follows ``coords ** -power``.

    Entries of the decay vector from index ``N // 4`` onwards are zeroed
    before it is passed to ``sample_random_signal``.
    """
    half = N // 2
    offsets = np.abs(np.fft.fftshift(np.arange(N)) - half)
    coords = np.float32(np.fft.ifftshift(1 + half - offsets))

    # Copy into a mutable array (via `onp`) so the tail can be zeroed in place.
    decay = onp.array(coords**-power)
    decay[N // 4:] = 0
    return sample_random_signal(key, decay)
def test_mnist_data_load():
    """Check that the accumulated per-batch pixel mean of MNIST train stays below 0.15."""
    init, fetch = load_dataset(MNIST, batch_size=128, split='train')
    num_batches, idx = init()

    def mean_pixels(i, mean_pix):
        # Add the mean pixel value of batch i to the running total.
        images, _ = fetch(i, idx)
        return mean_pix + jnp.sum(images) / images.size

    summed_means = fori_loop(0, num_batches, mean_pixels, jnp.float32(0.))
    assert summed_means / num_batches < 0.15
def mnist():
    """Load MNIST through TFDS.

    Returns:
      ``(train_images, train_labels, test_images, test_labels)``. Images are
      float32, divided by 256 and reshaped to ``(-1, 784)``; labels are passed
      through ``_one_hot`` with 10 classes. The training split is shuffled
      before batching.
    """
    import tensorflow_datasets as tfds

    def images(batch):
        return np.reshape(np.float32(batch['image']) / 256, (-1, 784))

    def labels(batch):
        return _one_hot(batch['label'], 10)

    dataset = tfds.load("mnist:1.0.0")
    train_split = dataset['train'].shuffle(50000).batch(50000)
    test_split = dataset['test'].batch(10000)
    train = next(tfds.as_numpy(train_split))
    test = next(tfds.as_numpy(test_split))
    return images(train), labels(train), images(test), labels(test)