Example 1
    def test_vmap_not_batched(self):
        x = 3.

        def func(y):
            # x is not mapped, y is mapped
            _, y = hcb.id_print((x, y), output_stream=testing_stream)
            return x + y

        vmap_func = api.vmap(func)
        vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a.
  let b c = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
                    func=_print
                    transforms=(('batch', (None, 0)),) ] 3.00 a
      d = add c 3.00
  in (d,) }""", str(api.make_jaxpr(vmap_func)(vargs)))
        with hcb.outfeed_receiver():
            _ = vmap_func(vargs)
        assertMultiLineStrippedEqual(
            self, """
transforms: ({'name': 'batch', 'batch_dims': (None, 0)},)
[ 3.00
  [4.00 5.00] ]""", testing_stream.output)
        testing_stream.reset()
Example 2
    def test_vmap(self):
        vmap_fun1 = api.vmap(fun1)
        vargs = np.array([np.float32(4.), np.float32(5.)])
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a.
  let b = mul a 2.00
      c = id_tap[ arg_treedef=*
                  batch_dims=(0,)
                  func=_print
                  transforms=('batch',)
                  what=a * 2 ] b
      d = mul c 3.00
      e f = id_tap[ arg_treedef=*
                    batch_dims=(0, 0)
                    func=_print
                    nr_untapped=1
                    transforms=('batch',)
                    what=y * 3 ] d c
      g = pow f 2.00
  in (g,) }""", str(api.make_jaxpr(vmap_fun1)(vargs)))
        with hcb.outfeed_receiver():
            res_vmap = vmap_fun1(vargs)
        assertMultiLineStrippedEqual(
            self, """
batch_dims: (0,) transforms: ('batch',) what: a * 2
[ 8.00 10.00]
batch_dims: (0, 0) transforms: ('batch',) what: y * 3
[24.00 30.00]""", testing_stream.output)
        testing_stream.reset()
Example 3
 def update(state):
     data, p_, e_, C_, mu, alpha, iters, _ = state
     x, y = data
     mu = np.float32(mu)
     alpha_ = np.float32(alpha)
     #
     J = jacobian(p_, x, y)
     H = J.T @ J
     Je = J.T @ e_ + alpha_ * p_
     I = np.diag_indices_from(H)
     #
     dp = solve(H.at[I].add(alpha_ + mu), Je, sym_pos=True)
     p = p_ - dp
     e = error(p, x, y)
     C = (sum_squares(e) + alpha * sum_squares(p)) / 2
     rho = (C_ - C) / (dp.T @ (mu * dp + Je))
     #
     mu = np.where(rho > rho_c, np.maximum(mu / c, mu_min), mu)
     #
     bad_step = (rho < rho_min) | np.any(np.isnan(p))
     mu = np.where(bad_step, np.minimum(c * mu, mu_max), mu)
     p = cond(bad_step, lambda t: t[0], lambda t: t[1], (p_, p))
     e = cond(bad_step, lambda t: t[0], lambda t: t[1], (e_, e))
     #
     sse = sum_squares(e)
     ssp = sum_squares(p)
     C = np.where(bad_step, C_, C)
     improved = (C_ > C) | bad_step
     #
     bundle = (alpha, H, I, sse, ssp, x.size)
     alpha, *_ = cond(bad_step, lambda t: t, update_hyperparams, bundle)
     C = (sse + alpha * ssp) / 2
     #
     return LevenbergMarquardtBRState(data, p, e, C, mu, alpha,
                                      iters + ~bad_step, improved)
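A note on the accept/reject steps above: each cond(bad_step, ...) is a branchless select between the previous and the candidate value, which keeps the whole update jit-compatible. A minimal standalone sketch of the pattern (illustrative values):

 from jax import lax
 import jax.numpy as jnp

 p_old = jnp.array([1.0, 2.0])
 p_new = jnp.array([1.5, 2.5])
 bad_step = True
 # Keep the old parameters on a rejected step, otherwise take the new ones.
 p = lax.cond(bad_step, lambda t: t[0], lambda t: t[1], (p_old, p_new))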
Example 4
 def _bl_update(H, C, R, state):
     G, (α, _), μ, τ = state
     tr_inv_H = np.trace(solve(H, I, sym_pos="sym"))
     γ = n - α * tr_inv_H
     α = np.float32(n / (2 * R + tr_inv_H))
     β = np.float32((x.shape[0] - γ) / (2 * C))
     return G, (α, β), μ, τ
Example 5
    def test_vmap(self):
        vmap_fun1 = api.vmap(fun1)
        vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a.
  let b = mul a 2.00
      c = id_tap[ arg_treedef=*
                  func=_print
                  transforms=(('batch', (0,)),)
                  what=a * 2 ] b
      d = mul c 3.00
      e f = id_tap[ arg_treedef=*
                    func=_print
                    nr_untapped=1
                    transforms=(('batch', (0, 0)),)
                    what=y * 3 ] d c
      g = integer_pow[ y=2 ] f
  in (g,) }""", str(api.make_jaxpr(vmap_fun1)(vargs)))
        with hcb.outfeed_receiver():
            _ = vmap_fun1(vargs)
        assertMultiLineStrippedEqual(
            self, """
transforms: ({'name': 'batch', 'batch_dims': (0,)},) what: a * 2
[ 8.00 10.00]
transforms: ({'name': 'batch', 'batch_dims': (0, 0)},) what: y * 3
[24.00 30.00]""", testing_stream.output)
        testing_stream.reset()
Example 6
def get_datasets(name):
  """Load train and test datasets into memory."""
  ds_builder = tfds.builder(name)
  ds_builder.download_and_prepare()
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.
  return train_ds, test_ds
Example 7
class LevenbergMaquardtBayes(
        namedtuple(
            "LevenbergMaquardtBayes",
            ("μi", "μs", "μmin", "μmax"),
            defaults=(
                np.float32(0.005),  # μi
                np.float32(10),  # μs
                np.float32(5e-16),  # μmin
                np.float32(1e10)  # μmax
            ))):
    pass
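A minimal usage sketch for the hyperparameter bundle above (hp_fast is an illustrative name; np is jax.numpy, as in the snippet). Namedtuple defaults make the bundle cheap to construct and to override field by field:

hp = LevenbergMaquardtBayes()            # all defaults: μi=0.005, μs=10, ...
hp_fast = hp._replace(μs=np.float32(5))  # override one field, keep the rest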
Example 8
def test_fmatrix():
    sed = FMatrix(['dustmbb', 'syncpl', 'cmb'])
    parameters = {
        'nu': np.array([27., 39., 93., 145., 225., 280.]),
        'nu_ref_d': np.float32(353),
        'nu_ref_s': np.float32(23.),
        'beta_d': np.float32(1.5),
        'beta_s': np.float32(-3.),
        'T_d': np.float32(20)
    }
    evalu = sed(**parameters)
    return
Example 9
  def test_vmap(self):
    vmap_fun1 = api.vmap(fun1)
    vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
    #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(vmap_fun1)(vargs)))
    with hcb.outfeed_receiver():
      _ = vmap_fun1(vargs)
    assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'batch', 'batch_dims': (0,)},) what: a * 2
[ 8.00 10.00]
transforms: ({'name': 'batch', 'batch_dims': (0, 0)},) what: y * 3
[24.00 30.00]""", testing_stream.output)
    testing_stream.reset()
Example 10
    def initialize_prior(self) -> Callable[[Tuple[Any, Any]], List[jnp.ndarray]]:
        f32_prior_mean, f32_prior_cov = (
            jnp.float32(self.prior_mean),
            jnp.float32(self.prior_cov),
        )

        def prior_fn(shape):
            prior_mean = jnp.ones(shape) * f32_prior_mean
            prior_cov = jnp.ones(shape) * f32_prior_cov
            return [prior_mean, prior_cov]

        return prior_fn
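A usage sketch, assuming an instance `model` with `prior_mean` and `prior_cov` attributes (names taken from the method above; `model` is hypothetical):

        prior_fn = model.initialize_prior()
        prior_mean, prior_cov = prior_fn((3,))
        # prior_mean == jnp.ones(3) * model.prior_mean, and likewise for prior_cov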
Example 11
def MSAWeight_PB(msa):
    gap_idx = msa.abc.charmap['-']
    q = msa.abc.q
    ax = msa.ax
    (N, L) = ax.shape

    ## step 1: get counts:

    c = np.sum(msa.ax_1hot, axis=0)

    # set gap counts to 0
    c = index_update(c, index[:, gap_idx], 0)

    # get N x L array with the count value for the corresponding residue in the alignment
    # first, get an N x L "column id" array (convenient for vmap)
    # col_id[n,i] = i
    col_id = np.int16(np.tensordot(np.ones(N), np.arange(L), axes=0))
    # ax_c[n, i] = c[i, ax[n,i]]
    ax_c = Get_Henikoff_Counts_Residue(col_id, ax, c)

    ## step 2: get number of unique characters in each column
    r = np.float32(np.sum(np.array(c > 0), axis=1))

    # transform r from an L-vector to an N x L array, where r2[n,i] = r[i];
    # this allows easy elementwise operations with ax_c
    r2 = np.tensordot(np.ones(N), r, axes=0)

    ## step 3: get ungapped seq lengths
    nongap = np.array(ax != gap_idx)
    l = np.float32(np.sum(nongap, axis=1))

    ## step 4: calculate unnormalized weights
    # get array of the main terms in the Henikoff sum:
    # wgt_un[n,i] = 1 / (r[i] * c[i, ax[n,i]])
    wgt_un = np.reciprocal(np.multiply(ax_c, r2))

    # set all terms involving a gap to zero
    wgt_un = np.nan_to_num(np.multiply(wgt_un, nongap))

    # sum across all positions to get the preliminary unnormalized weight for each sequence
    wgt_un = np.sum(wgt_un, axis=1)

    # divide by gapless sequence length
    wgt_un = np.divide(wgt_un, l)

    ## step 5: normalize sequence weights
    wgt = (wgt_un * np.float32(N)) / np.sum(wgt_un)
    msa.wgt = wgt

    return
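For reference, a plain-NumPy sketch of the core Henikoff position-based weighting that the vectorized code above implements (no gap handling, so the per-sequence length normalization drops out; henikoff_weights is an illustrative name):

import numpy as onp

def henikoff_weights(ax, q):
    # ax: (N, L) integer-coded alignment with no gaps.
    N, L = ax.shape
    c = onp.zeros((L, q))
    for i in range(L):
        for a in ax[:, i]:
            c[i, a] += 1
    r = (c > 0).sum(axis=1)  # number of distinct residues per column
    w = onp.array([sum(1.0 / (r[i] * c[i, ax[n, i]]) for i in range(L))
                   for n in range(N)])
    return w * N / w.sum()   # normalize so the weights sum to N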
Example 12
    def test_jvp(self):
        jvp_fun1 = lambda x, xt: api.jvp(fun1, (x, ), (xt, ))
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a b.
  let c = mul a 2.00
      d = id_tap[ arg_treedef=*
                  func=_print
                  nr_untapped=0
                  what=a * 2 ] c
      e = mul d 3.00
      f g = id_tap[ arg_treedef=*
                    func=_print
                    nr_untapped=1
                    what=y * 3 ] e d
      h = mul g g
      i = mul b 2.00
      j k = id_tap[ arg_treedef=*
                    func=_print
                    nr_untapped=1
                    transforms=('jvp',)
                    what=a * 2 ] i d
      l = mul j 3.00
      m n o = id_tap[ arg_treedef=*
                      func=_print
                      nr_untapped=2
                      transforms=('jvp',)
                      what=y * 3 ] l j f
      p = mul n g
      q = mul g n
      r = add_any p q
  in (h, r) }""",
            str(api.make_jaxpr(jvp_fun1)(jnp.float32(5.), jnp.float32(0.1))))
        with hcb.outfeed_receiver():
            res_primals, res_tangents = jvp_fun1(jnp.float32(5.),
                                                 jnp.float32(0.1))
        self.assertAllClose(100., res_primals, check_dtypes=False)
        self.assertAllClose(4., res_tangents, check_dtypes=False)
        assertMultiLineStrippedEqual(
            self, """
what: a * 2
10.00
transforms: ('jvp',) what: a * 2
0.20
what: y * 3
30.00
transforms: ('jvp',) what: y * 3
0.60""", testing_stream.output)
        testing_stream.reset()
Example 13
def Prior_Laplace(f1, f2, N, C):

    (L, q) = f1.shape
    qf = np.float32(q)
    Nf = np.float32(N)

    # new normalization: 1 / (eff. seq. number)
    nrm = 1. / (Nf + C)

    # binary L x q x L x q term: keeps us from adding pseudocounts for f_ii
    no_diag = np.reshape(1 - np.eye(L), (L, 1, L, 1))

    f1_prior = nrm * ((C / qf) + Nf * f1)
    f2_prior = nrm * (((C / (qf * qf)) * no_diag) + Nf * f2)

    return f1_prior, f2_prior
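A shape-level usage sketch (illustrative values; f1 is the L×q single-site frequency table and f2 the L×q×L×q pairwise one, with np = jax.numpy as in the snippet):

import jax.numpy as np

L, q, N, C = 4, 21, 100, 1.0
f1 = np.full((L, q), 1.0 / q)
f2 = np.full((L, q, L, q), 1.0 / (q * q))
f1_prior, f2_prior = Prior_Laplace(f1, f2, N, C)
# Each row of f1_prior still sums to 1: (C + N) / (N + C) == 1.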
Example 14
    def test_grad_double(self):
        def func(x):
            y = hcb.id_print(x * 2.,
                             what="x * 2",
                             output_stream=testing_stream)
            return x * (y * 3.)

        grad_func = api.grad(api.grad(func))
        # Just making the Jaxpr invokes the id_print twice
        _ = api.make_jaxpr(grad_func)(5.)
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
3.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
2.00""", testing_stream.output)
        testing_stream.reset()
        res_grad = grad_func(jnp.float32(5.))

        self.assertAllClose(12., res_grad, check_dtypes=False)
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
what: x * 2
10.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
15.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
2.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
3.00""", testing_stream.output)
        testing_stream.reset()
Example 15
    def test_grad_simple(self):
        def func(x):
            y = hcb.id_print(x * 2.,
                             what="x * 2",
                             output_stream=testing_stream)
            return x * hcb.id_print(
                y * 3., what="y * 3", output_stream=testing_stream)

        grad_func = api.grad(func)
        #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(grad_func)(5.)))

        res_grad = grad_func(jnp.float32(5.))
        self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
what: x * 2
10.00
what: y * 3
30.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: y * 3
5.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
15.00""", testing_stream.output)
        testing_stream.reset()
Example 16
def align_examples(rng, x, x_index, y):
  """Random alignment based on labels.

  Randomly aligns items in x with items in y (the mapping is not one-to-one).
  The only constraint is that it tries to align items that have the same
  value. In the LTC Task, we pass labels to this function so that the
  alignment is based on labels.

  Here x and y are matrices whose rows are features to be compared
  elementwise for the alignment, and x_index is the index of the batch
  position of x, needed for vmap.

  Args:
    rng: an array of jax PRNG keys.
    x: jnp.array; Matrix of shape `[N, M]`.
    x_index: jnp.array; Vector of shape `[N,]`.
    y: jnp.array; Matrix of shape `[N, M]`.

  Returns:
    indices of aligned pairs.
  """
  x = jnp.array(x)
  x_index = jnp.array(x_index)
  y = jnp.array(y)

  y_indices = jnp.arange(len(y))
  shuffled_y_idx = jax.random.permutation(rng, y_indices)
  equalities = jnp.float32(x == y[shuffled_y_idx])
  aligned_idx = jnp.argmax(equalities)

  return x_index, shuffled_y_idx[aligned_idx]
Example 17
def target_m_dqn(model, target_network, states, next_states, actions, rewards,
                 terminals, cumulative_gamma, tau, alpha, clip_value_min):
    """Compute the target Q-value. Munchausen DQN"""

    #----------------------------------------
    q_state_values = jax.vmap(target_network, in_axes=(0))(states).q_values
    q_state_values = jnp.squeeze(q_state_values)

    next_q_values = jax.vmap(target_network, in_axes=(0))(next_states).q_values
    next_q_values = jnp.squeeze(next_q_values)
    #----------------------------------------

    tau_log_pi_next = stable_scaled_log_softmax(next_q_values, tau, axis=1)
    pi_target = stable_softmax(next_q_values, tau, axis=1)
    replay_log_policy = stable_scaled_log_softmax(q_state_values, tau, axis=1)

    #----------------------------------------

    replay_next_qt_softmax = jnp.sum(
        (next_q_values - tau_log_pi_next) * pi_target, axis=1)

    replay_action_one_hot = nn.one_hot(actions, q_state_values.shape[-1])
    tau_log_pi_a = jnp.sum(replay_log_policy * replay_action_one_hot, axis=1)

    tau_log_pi_a = jnp.clip(tau_log_pi_a, a_min=clip_value_min, a_max=1)

    munchausen_term = alpha * tau_log_pi_a
    modified_bellman = (rewards + munchausen_term +
                        cumulative_gamma * replay_next_qt_softmax *
                        (1. - jnp.float32(terminals)))

    return jax.lax.stop_gradient(modified_bellman)
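The helpers stable_scaled_log_softmax and stable_softmax are not shown above. A plausible sketch, assuming they compute the temperature-scaled quantities τ·log softmax(x/τ) and softmax(x/τ) used by Munchausen DQN (not necessarily the exact helpers this code imports):

import jax
from jax.scipy.special import logsumexp

def stable_scaled_log_softmax(x, tau, axis=-1):
    # tau * log_softmax(x / tau), computed stably via logsumexp
    return x - tau * logsumexp(x / tau, axis=axis, keepdims=True)

def stable_softmax(x, tau, axis=-1):
    # softmax of the temperature-scaled logits
    return jax.nn.softmax(x / tau, axis=axis)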
Example 18
 def update(state):
     data, p_, e_, C_, mu, iters, _ = state
     x, y = data
     mu = np.float32(mu)
     #
     J = jacobian(p_, x, y)
     H = damped_hessian(J, mu)
     Je = jac_err_prod(J, e_, p_)
     #
     dp = solve(H, Je, sym_pos=True)
     p = p_ - dp
     e = error(p, x, y)
     C = cost(e, p)
     rho = (C_ - C) / (dp.T @ (mu * dp + Je))
     #
     mu = np.where(rho > rho_c, np.maximum(mu / c, mu_min), mu)
     #
     bad_step = (rho < rho_min) | np.any(np.isnan(p))
     mu = np.where(bad_step, np.minimum(c * mu, mu_max), mu)
     p = cond(bad_step, lambda t: t[0], lambda t: t[1], (p_, p))
     e = cond(bad_step, lambda t: t[0], lambda t: t[1], (e_, e))
     C = np.where(bad_step, C_, C)
     improved = (C_ > C) | bad_step
     #
     return LevenbergMarquardtState(data, p, e, C, mu, iters + ~bad_step,
                                    improved)
Example 19
    def test_grad_primal_unused(self):
        # The output of id_print is not needed for the backwards pass
        def func(x):
            return 2. * hcb.id_print(
                x * 3., what="x * 3", output_stream=testing_stream)

        grad_func = api.grad(func)
        with hcb.outfeed_receiver():
            assertMultiLineStrippedEqual(
                self, """
{ lambda  ; a.
  let
  in (6.00,) }""", str(api.make_jaxpr(grad_func)(5.)))

        # Just making the Jaxpr invokes the id_print once
        assertMultiLineStrippedEqual(
            self, """
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
2.00""", testing_stream.output)
        testing_stream.reset()

        with hcb.outfeed_receiver():
            res_grad = grad_func(jnp.float32(5.))

        self.assertAllClose(6., res_grad, check_dtypes=False)
        assertMultiLineStrippedEqual(
            self, """
what: x * 3
15.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
2.00""", testing_stream.output)
        testing_stream.reset()
Example 20
def multinomial_mode(
    distribution_or_probs: Union[tfd.Distribution, jnp.DeviceArray]
) -> jnp.DeviceArray:
    """Calculates the (one-hot) mode of a multinomial distribution.

  Args:
    distribution_or_probs:
      `tfp.distributions.Distribution` | List[tensors].
      If the former, it is assumed that it has a `probs` property, and
      represents a distribution over categories. If the latter, these are
      taken to be the probabilities of categories directly.
      In either case, it is assumed that `probs` will be shape
      (batch_size, dim).

  Returns:
    `DeviceArray`, float32, (batch_size, dim).
    The mode of the distribution - this will be in one-hot form, but contain
    multiple non-zero entries in the event that more than one probability is
    joint-highest.
  """
    if isinstance(distribution_or_probs, tfd.Distribution):
        probs = distribution_or_probs.probs_parameter()
    else:
        probs = distribution_or_probs
    max_prob = jnp.max(probs, axis=1, keepdims=True)
    mode = jnp.int32(jnp.equal(probs, max_prob))
    return jnp.float32(mode / jnp.sum(mode, axis=1, keepdims=True))
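A usage sketch with a raw probability matrix (illustrative values, assuming the module's tfd import is in scope); note how the tie in the second row splits the one-hot mass:

import jax.numpy as jnp

probs = jnp.array([[0.2, 0.5, 0.3],
                   [0.4, 0.4, 0.2]])
multinomial_mode(probs)
# [[0.  1.  0. ]
#  [0.5 0.5 0. ]]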
Example 21
 def testCondTypeErrors(self):
   """Test typing error messages for  cond."""
   with self.assertRaisesRegex(TypeError,
       re.escape("Pred type must be either boolean or number, got <function")):
     lax.cond(lambda x: True,
              1., lambda top: 1., 2., lambda fop: 2.)
   with self.assertRaisesRegex(TypeError,
       re.escape("Pred type must be either boolean or number, got foo.")):
     lax.cond("foo",
              1., lambda top: 1., 2., lambda fop: 2.)
   with self.assertRaisesRegex(TypeError,
       re.escape("Pred must be a scalar, got (1.0, 1.0) of shape (2,).")):
     lax.cond((1., 1.),
              1., lambda top: 1., 2., lambda fop: 2.)
   with self.assertRaisesRegex(TypeError,
       re.escape("true_fun and false_fun output must have same type structure, got * and PyTreeDef(tuple, [*,*]).")):
     lax.cond(True,
              1., lambda top: 1., 2., lambda fop: (2., 2.))
   with self.assertRaisesWithLiteralMatch(
       TypeError,
       "true_fun and false_fun output must have identical types, got\n"
       "ShapedArray(float32[1])\n"
       "and\n"
       "ShapedArray(float32[])."):
     lax.cond(True,
              1., lambda top: np.array([1.], np.float32),
              2., lambda fop: np.float32(1.))
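For contrast with the failing calls, a well-typed call under the same five-argument cond signature this test exercises (an older JAX API; values are illustrative):

     res = lax.cond(True,
                    1., lambda top: top * 2.,   # true operand and branch
                    2., lambda fop: fop * 3.)   # false operand and branch
     # res == 2.0; pred is a scalar bool and both branches return the same structure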
Example 22
    def test_grad_primal_unused(self):
        raise SkipTest("broken by omnistaging")  # TODO(mattjj,gnecula): update

        # The output of id_print is not needed for the backwards pass
        def func(x):
            return 2. * hcb.id_print(
                x * 3., what="x * 3", output_stream=testing_stream)

        grad_func = api.grad(func)
        jaxpr = str(api.make_jaxpr(grad_func)(5.))
        # Just making the Jaxpr invokes the id_print once
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a.
  let
  in (6.00,) }""", jaxpr)
        assertMultiLineStrippedEqual(
            self, """
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
2.00""", testing_stream.output)
        testing_stream.reset()

        res_grad = grad_func(jnp.float32(5.))
        hcb.barrier_wait()

        self.assertAllClose(6., res_grad, check_dtypes=False)
        assertMultiLineStrippedEqual(
            self, """
what: x * 3
15.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
2.00""", testing_stream.output)
        testing_stream.reset()
Example 23
def mnist_images():
    import tensorflow_datasets as tfds
    prep = lambda d: np.reshape(
        np.float32(next(tfds.as_numpy(d))['image']) / 256, (-1, 784))
    dataset = tfds.load("mnist:1.0.0")
    return (prep(dataset['train'].shuffle(50000).batch(50000)),
            prep(dataset['test'].batch(10000)))
Example 24
def get_logit_snip_masks(params, nn_density_level, predict, x_batch,
                         batch_input_shape, GlOBAL_PRUNE_BOOL=True):

    def norm_square_logits(params, f, x):
        return np.sum(f(params, x) ** 2)

    init_grads = grad(norm_square_logits)(params, predict,
                                          x_batch.reshape(batch_input_shape))

    thres_list = [None] * len(params)

    if GlOBAL_PRUNE_BOOL:  # global pruning
        cs = [abs(init_grads[idx][0] * params[idx][0]).flatten()
              for idx in range(len(params)) if len(params[idx]) == 2]
        pooled_cs = np.hstack(cs)
        idx = int((1 - nn_density_level) * len(pooled_cs))
        # threshold: entries below it will be removed
        thres = np.sort(pooled_cs)[idx]
        thres_list = [thres] * len(params)
    else:  # layerwise pruning
        for layer_index in range(len(params)):
            if len(params[layer_index]) == 2:
                cs = abs(init_grads[layer_index][0] *
                         params[layer_index][0]).flatten()
                idx = int((1 - nn_density_level) * len(cs))
                # threshold: entries below it will be removed
                thres = np.sort(cs)[idx]
                thres_list[layer_index] = thres

    masks = []
    for layer_index in range(len(params)):
        if len(params[layer_index]) < 2:
            # The layer has no weight and bias parameters.
            masks.append([])
        elif len(params[layer_index]) == 2:
            # The layer holds a (weights, biases) tuple of parameters.
            weights = params[layer_index][0]
            weights_grad = init_grads[layer_index][0]
            layer_cs = np.abs(weights * weights_grad)
            # 0 for weights whose saliency falls below the threshold, 1 otherwise
            this_mask = np.float32(layer_cs >= thres_list[layer_index])
            masks.append(this_mask)
        else:
            raise NotImplementedError

    return masks
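The saliency above is the SNIP-style score |w ⊙ ∂L/∂w|. A compact standalone sketch of the global-threshold step (global_mask is an illustrative name, with np = jax.numpy as in the snippet):

import jax.numpy as np

def global_mask(weights_list, grads_list, density):
    # Pool |w * g| across all layers and keep the top `density` fraction.
    scores = [np.abs(w * g).flatten()
              for w, g in zip(weights_list, grads_list)]
    pooled = np.hstack(scores)
    thres = np.sort(pooled)[int((1 - density) * len(pooled))]
    return [np.float32(np.abs(w * g) >= thres)
            for w, g in zip(weights_list, grads_list)]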
Example 25
  def test_vmap_not_batched(self):
    x = 3.
    def func(y):
      # x is not mapped, y is mapped
      _, y = hcb.id_print((x, y), output_stream=testing_stream)
      return x + y

    vmap_func = api.vmap(func)
    vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
    #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(vmap_func)(vargs)))
    with hcb.outfeed_receiver():
      _ = vmap_func(vargs)
    assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'batch', 'batch_dims': (None, 0)},)
[ 3.00
  [4.00 5.00] ]""", testing_stream.output)
    testing_stream.reset()
Example 26
 def _lm_update(θ, H, Je, y, Λ, state):
     α, β = Λ
     p = θ - solve(H + state.μ * I, Je, sym_pos="sym").T
     e = errors(p, x, y)
     C = obj.cost(e)
     R = obj.regularizer(θ)
     G = np.float32(β * C + α * R)
     return LMState(p, e, G, C, R, state.μ * μs)
Example 27
 def _setup_toy_data(self, n=32768):
   x = jnp.float32(jnp.arange(n))
   rng = random.PRNGKey(0)
   rng, key = random.split(rng)
   values = random.normal(key, shape=[n])
   rng, key = random.split(rng)
   tangents = random.normal(key, shape=[n])
   return x, values, tangents
Example 28
def sample_random_powerlaw(key, N, power):
    coords = np.float32(
        np.fft.ifftshift(1 + N // 2 -
                         np.abs(np.fft.fftshift(np.arange(N)) - N // 2)))
    decay_vec = coords**-power
    decay_vec = onp.array(decay_vec)
    decay_vec[N // 4:] = 0
    return sample_random_signal(key, decay_vec)
Example 29
def test_mnist_data_load():
    def mean_pixels(i, mean_pix):
        batch, _ = fetch(i, idx)
        return mean_pix + jnp.sum(batch) / batch.size

    init, fetch = load_dataset(MNIST, batch_size=128, split='train')
    num_batches, idx = init()
    assert fori_loop(0, num_batches, mean_pixels, jnp.float32(0.)) / num_batches < 0.15
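For reference, fori_loop(lo, hi, body, init) as used above is equivalent to this pure-Python loop (a sketch of the documented jax.lax.fori_loop semantics):

def fori_loop_reference(lo, hi, body, init):
    val = init
    for i in range(lo, hi):
        val = body(i, val)
    return val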
Example 30
def mnist():
    import tensorflow_datasets as tfds
    dataset = tfds.load("mnist:1.0.0")
    images = lambda d: np.reshape(np.float32(d['image']) / 256, (-1, 784))
    labels = lambda d: _one_hot(d['label'], 10)
    train = next(tfds.as_numpy(dataset['train'].shuffle(50000).batch(50000)))
    test = next(tfds.as_numpy(dataset['test'].batch(10000)))
    return images(train), labels(train), images(test), labels(test)
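The _one_hot helper is not shown; a standard definition consistent with its use here (an assumption, matching the helper commonly used in the stock JAX examples this resembles):

import numpy as np

def _one_hot(x, k, dtype=np.float32):
    # One-hot encode the integer label vector x with k classes.
    return np.array(x[:, None] == np.arange(k), dtype)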