Example No. 1
def nearest_sampler(imgs, coords, mask_value):
    """Construct a new image by nearest sampling from the input image.
    Points falling outside the source image boundary have value of mask_value.
    Args:
        imgs: source image to be sampled from [b, h, w, c]
        coords: coordinates of source pixels to sample from [b, h, w, 2].
            height_t/width_t correspond to the dimensions of the output
            image (don't need to be the same as height_s/width_s).
            The two channels correspond to x and y coordinates respectively.
        mask_value: value of points outside of image. -1 for edge sampling.
        Returns:
            A new sampled image [height_t, width_t, channels]
    """
    coords_x, coords_y = jnp.split(coords, 2, axis=2)
    inp_size = imgs.shape
    out_size = list(coords.shape)
    out_size[2] = imgs.shape[2]

    coords_x = jnp.array(coords_x, dtype='float32')
    coords_y = jnp.array(coords_y, dtype='float32')

    y_max = jnp.array(jnp.shape(imgs)[0] - 1, dtype='float32')
    x_max = jnp.array(jnp.shape(imgs)[1] - 1, dtype='float32')
    zero = jnp.zeros([1], dtype='float32')
    eps = jnp.array([0.5], dtype='float32')

    coords_x_clipped = jnp.clip(coords_x, zero - eps, x_max + eps)
    coords_y_clipped = jnp.clip(coords_y, zero - eps, y_max + eps)

    x0 = jnp.round(coords_x_clipped)
    y0 = jnp.round(coords_y_clipped)

    x0_safe = jnp.clip(x0, zero, x_max)
    y0_safe = jnp.clip(y0, zero, y_max)

    # indices in the flat image to sample from
    dim2 = jnp.array(inp_size[1], dtype='float32')

    base_y0 = y0_safe * dim2
    idx00 = jnp.reshape(x0_safe + base_y0, [-1])

    # sample from imgs
    imgs_flat = jnp.reshape(imgs, [-1, inp_size[2]])
    imgs_flat = imgs_flat.astype('float32')
    output = jnp.reshape(
        jnp.take(imgs_flat, idx00.astype('int32'), axis=0),
        out_size
    )

    return jnp.where(
        jnp.any(mask_value > 0),
        jnp.where(
            compute_mask(coords_x, coords_y, x_max, y_max),
            output,
            jnp.ones_like(output) *
            jnp.reshape(jnp.array(mask_value), [1, 1, -1])
        ),
        output)
Example No. 2
def loss_fn(param_dict, signal):
    
    params = param_dict["nn"]
    sigma = param_dict["s"]
    hf = 1
    N = int(jnp.round(6*sigma))
    # Adding some more noise during training to prevent classifier from overfitting on irrelevant aspects of the spectra
    signal = signal + 0.2*np.random.randn(signal.shape[0])
    x = diff_stft(signal, s=sigma, hf=hf)

    li = []
    l1 = jnp.array([[1,0]])
    l2 = jnp.array([[0,1]])
    l_c = []
    for i in range(x.shape[1]):
        timi = i*int(hf*N)/fs
        d1 = np.min(np.abs(I1 - timi))
        d2 = np.min(np.abs(I2 - timi))
        if d1 < d2:
            li.append(1)
            l_c.append(l1)
        else:
            li.append(2)
            l_c.append(l2)

    li = np.array(li)
    l_c = np.concatenate(l_c, axis=0).T

    xzp = jnp.concatenate([x, jnp.zeros((Nzp - (N // 2 + 1), x.shape[1]))], axis=0)
    logits = net.apply(params, xzp.T)

    # Regularized loss (Cross entropy + regularizer to avoid small windows)
    cel = -jnp.mean(logits*l_c.T) + (0.1/sigma)
    
    return cel
Example No. 3
    def testLossAndGradientsAreFinite(self):
        # Test that the loss and its approximation both give finite losses and
        # derivatives everywhere that they should for a wide range of values.
        num_samples = 100000
        rng = random.PRNGKey(0)

        # Normally distributed inputs.
        rng, key = random.split(rng)
        x = random.normal(key, shape=[num_samples])

        # Uniformly distributed values in (-16, 3), quantized to the nearest 0.1
        # to ensure that we hit the special cases at 0, 2.
        rng, key = random.split(rng)
        alpha = jnp.round(
            random.uniform(key, shape=[num_samples], minval=-16, maxval=3) *
            10) / 10.

        # Random log-normally distributed values in approx (1e-5, 100000):
        rng, key = random.split(rng)
        scale = jnp.exp(random.normal(key, shape=[num_samples]) * 4.) + 1e-5

        fn = self.variant(general.lossfun)
        loss = fn(x, alpha, scale)
        d_x, d_alpha, d_scale = (jax.grad(lambda x, a, s: jnp.sum(fn(x, a, s)),
                                          [0, 1, 2])(x, alpha, scale))

        for v in [loss, d_x, d_alpha, d_scale]:
            chex.assert_tree_all_finite(v)
Example No. 4
def create_stepped_learning_rate_fn(base_learning_rate,
                                    steps_per_epoch,
                                    lr_sched_steps,
                                    warmup_length=0.0):
    """Create a stepped learning rate function.
    Args:
    base_learning_rate: base learning rate
    steps_per_epoch: number of steps per epoch
    lr_sched_steps: learning rate schedule as a list of pairs where each
      pair is `[step, lr_factor]`
    warmup_length: linear LR warmup length; 0 for no warmup
    Returns:
    function of the form f(step) -> learning_rate
    """
    boundaries = [step[0] for step in lr_sched_steps]
    decays = [step[1] for step in lr_sched_steps]
    boundaries = jnp.array(boundaries) * steps_per_epoch
    boundaries = jnp.round(boundaries).astype(jnp.int32)
    values = jnp.array([1.0] + decays) * base_learning_rate

    def step_fn(step):
        lr = piecewise_constant(boundaries, values, step)
        if warmup_length > 0.0:
            lr = lr * jnp.minimum(
                1., step / float(warmup_length) / steps_per_epoch)
        return lr

    return step_fn
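A minimal usage sketch for the schedule above. The piecewise_constant helper below is a hypothetical stand-in for the one the original module provides, written only so the example runs on its own:

import jax.numpy as jnp

def piecewise_constant(boundaries, values, t):
    # Stand-in helper: return values[i] where i counts boundaries <= t.
    return jnp.take(values, jnp.sum(boundaries <= t))

lr_fn = create_stepped_learning_rate_fn(
    base_learning_rate=0.1,
    steps_per_epoch=100,
    lr_sched_steps=[[30, 0.1], [60, 0.01]],  # decay by 10x at epochs 30 and 60
    warmup_length=5.0)                       # 5-epoch linear warmup
print(lr_fn(0), lr_fn(1000), lr_fn(7000))    # 0.0 (warming up), 0.1, 0.001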
Example No. 5
    def input_wavefront(self, wavelength=1e-6):
        """Create a Wavefront object suitable for sending through a given optical system.

        Uses self.source_offset to assign an off-axis tilt, if requested.
        (FIXME does not work for Fresnel yet)

        Parameters
        ----------
        wavelength : float
            Wavelength in meters

        Returns
        -------
        wavefront : morphine.fresnel.FresnelWavefront instance
            A wavefront appropriate for passing through this optical system.

        """
        oversample = int(np.round(1 / self.beam_ratio))
        inwave = FresnelWavefront(self.pupil_diameter / 2,
                                  wavelength=wavelength,
                                  npix=self.npix,
                                  oversample=oversample)
        # _log.debug(
        # "Creating input wavefront with wavelength={0} microns,"
        # "npix={1}, diam={3}, pixel scale={2}".format(
        #     wavelength * 1e6, self.npix, self.pupil_diameter / (self.npix), self.pupil_diameter
        # ))
        inwave._display_hint_expected_nplanes = len(
            self)  # For displaying a multi-step calculation nicely
        return inwave
Example No. 6
 def topk_mask_internal(value):
     assert value.ndim == 1
     indices = jnp.argsort(value)
     k = jnp.round(density_fraction * jnp.size(value)).astype(jnp.int32)
     mask = jnp.greater_equal(np.arange(value.size), value.size - k)
     mask = jnp.zeros_like(mask).at[indices].set(mask)
     return mask.astype(np.int32)
Example No. 7
def Epot(pos, *args):
    M, L = args
    pos = pos.reshape((3, M))
    energy = 0
    for i in range(M - 1):
        for j in range(i + 1, M):
            deltaX = pos[0, i] - pos[0, j]
            deltaXmi = deltaX - L * np.round(deltaX / L)
            deltaY = pos[1, i] - pos[1, j]
            deltaYmi = deltaY - L * np.round(deltaY / L)
            deltaZ = pos[2, i] - pos[2, j]
            deltaZmi = deltaZ - L * np.round(deltaZ / L)
            r = np.linalg.norm([deltaXmi, deltaYmi, deltaZmi])
            energy += Vlj(r)

    return energy
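A quick sanity-check sketch. Vlj below is a hypothetical Lennard-Jones pair potential standing in for the one defined elsewhere in the original code:

import numpy as np

def Vlj(r, epsilon=1.0, sigma=1.0):
    # Hypothetical pair potential; the original Vlj is defined elsewhere.
    return 4 * epsilon * ((sigma / r) ** 12 - (sigma / r) ** 6)

pos = np.random.rand(3 * 4) * 5.0   # 4 particles in a cubic box of side L = 5
print(Epot(pos, 4, 5.0))            # total pair energy under the minimum-image convention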
Example No. 8
 def get_attn():
     return stax.GlobalSelfAttention(
         n_chan_out=width,
         n_chan_key=width,
         n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
         n_heads=int(np.sqrt(width)),
     ) if proj == 'avg' else stax.Identity()
Example No. 9
    def test_logistic_regression(self):
        key = random.PRNGKey(0)

        N, n = 5, 2

        key, k1, k2, k3 = random.split(key, num=4)
        X_np = random.normal(k1, shape=(N, n))
        a_true = random.normal(k2, shape=(n, 1))
        y_np = jnp.round(
            sigmoid(X_np @ a_true + random.normal(k3, shape=(N, 1)) * 0.5))

        X_jax = jnp.array(X_np)
        lam_jax = 0.1 * jnp.ones(1)

        a = cp.Variable((n, 1))
        X = cp.Parameter((N, n))
        lam = cp.Parameter(1, nonneg=True)
        y = y_np

        log_likelihood = cp.sum(
            cp.multiply(y, X @ a) - cp.log_sum_exp(
                cp.hstack([np.zeros((N, 1)), X @ a]).T,
                axis=0, keepdims=True).T)
        prob = cp.Problem(
            cp.Minimize(-log_likelihood + lam * cp.sum_squares(a)))

        fit_logreg = CvxpyLayer(prob, [X, lam], [a])

        check_grads(fit_logreg, (X_jax, lam_jax), order=1, modes=['rev'])
Example No. 10
def split_spectrum(H, split_point, V0=None, precision=lax.Precision.HIGHEST):
  """ The Hermitian matrix `H` is split into two matrices `Hm`
  `Hp`, respectively sharing its eigenspaces beneath and above
  its `split_point`th eigenvalue.

  Returns, in addition, `Vm` and `Vp`, isometries such that
  `Hi = Vi.conj().T @ H @ Vi`. If `V0` is not None, `V0 @ Vi` are
  returned instead; this allows the overall isometries mapping from
  an initial input matrix to progressively smaller blocks to be formed.

  Args:
    H: The Hermitian matrix to split.
    split_point: The eigenvalue to split along.
    V0: Matrix of isometries to be updated.
    precision: TPU matmul precision.
  Returns:
    Hm: A Hermitian matrix sharing the eigenvalues of `H` beneath
      `split_point`.
    Vm: An isometry from the input space of `V0` to `Hm`.
    Hp: A Hermitian matrix sharing the eigenvalues of `H` above
      `split_point`.
    Vp: An isometry from the input space of `V0` to `Hp`.
  """
  def _fill_diagonal(X, vals):
    return jax.ops.index_update(X, jnp.diag_indices(X.shape[0]), vals)

  H_shift = _fill_diagonal(H, H.diagonal() - split_point)
  U, _ = jsp.linalg.polar_unitary(H_shift)
  P = -0.5 * _fill_diagonal(U, U.diagonal() - 1.)
  rank = jnp.round(jnp.trace(P)).astype(jnp.int32)
  rank = int(rank)
  return _split_spectrum_jittable(P, H, V0, rank, precision)
Example No. 11
def grassman_distance(y1, y2):
  """Grassman distance between subspaces spanned by Y1 and Y2."""
  q1, _ = jnp.linalg.qr(y1)
  q2, _ = jnp.linalg.qr(y2)

  _, sigma, _ = jnp.linalg.svd(q1.T @ q2)
  sigma = jnp.round(sigma, decimals=6)
  return jnp.linalg.norm(jnp.arccos(sigma))
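A small illustrative call, assuming jax.numpy as jnp: identical subspaces give distance ~0, while fully orthogonal planes contribute pi/2 per principal angle.

import jax.numpy as jnp

Y1 = jnp.eye(4)[:, :2]               # spans the first two coordinate axes of R^4
Y2 = jnp.eye(4)[:, 2:]               # spans the remaining, orthogonal plane
print(grassman_distance(Y1, Y1))     # ~0.0
print(grassman_distance(Y1, Y2))     # sqrt(2) * pi/2 ~ 2.22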
Example No. 12
def compute_grassman_distance(Y1, Y2):
    """Grassman distance between subspaces spanned by Y1 and Y2."""
    Q1, _ = jnp.linalg.qr(Y1)
    Q2, _ = jnp.linalg.qr(Y2)

    _, sigma, _ = jnp.linalg.svd(Q1.T @ Q2)
    sigma = jnp.round(sigma, decimals=6)
    return jnp.linalg.norm(jnp.arccos(sigma))
Example No. 13
 def testRoundStaticDecimals(self, shape, dtype, decimals, rng):
   if onp.issubdtype(dtype, onp.integer) and decimals < 0:
     self.skipTest("Integer rounding with decimals < 0 not implemented")
   onp_fun = lambda x: onp.round(x, decimals=decimals)
   lnp_fun = lambda x: lnp.round(x, decimals=decimals)
   args_maker = lambda: [rng(shape, dtype)]
   self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
   self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
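For reference, a minimal illustration of the decimals behavior this test exercises (not part of the test suite; onp/lnp mirror the aliases used above):

import numpy as onp
import jax.numpy as lnp

x = onp.array([1.234, 5.678])
print(onp.round(x, decimals=1))   # [1.2 5.7]
print(lnp.round(x, decimals=1))   # matches NumPy (up to dtype/precision), as the test asserts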
Example No. 14
 def testRoundStaticDecimals(self, shape, dtype, decimals, rng):
     onp_fun = lambda x: onp.round(x, decimals=decimals)
     lnp_fun = lambda x: lnp.round(x, decimals=decimals)
     args_maker = lambda: [rng(shape, dtype)]
     self._CheckAgainstNumpy(onp_fun,
                             lnp_fun,
                             args_maker,
                             check_dtypes=True)
     self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
Example No. 15
def grassman_distance(Y1, Y2):  # pylint: disable=invalid-name
  """Grassman distance between subspaces spanned by Y1 and Y2."""
  Q1, _ = jnp.linalg.qr(Y1)  # pylint: disable=invalid-name
  Q2, _ = jnp.linalg.qr(Y2)  # pylint: disable=invalid-name

  _, sigma, _ = jnp.linalg.svd(Q1.T @ Q2)
  # sigma = jnp.clip(sigma, -1., 1.)
  sigma = jnp.round(sigma, decimals=6)
  return jnp.linalg.norm(jnp.arccos(sigma))
Example No. 16
def jax_invech(v):
    '''
    Inverse half vectorization operator
    '''
    rows = int(jnp.round(.5 * (-1 + jnp.sqrt(1 + 8 * len(v)))))
    res = jnp.zeros((rows, rows))
    res = jax.ops.index_update(res, jnp.triu_indices(rows), v)
    res = res + res.T - jnp.diag(jnp.diag(res))
    return res
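A small worked example, assuming an older JAX release where jax.ops.index_update is still available (it has since been removed in favor of the .at[].set() syntax):

import jax
import jax.numpy as jnp

v = jnp.array([1., 2., 3., 4., 5., 6.])   # upper triangle of a 3x3 symmetric matrix
print(jax_invech(v))
# [[1. 2. 3.]
#  [2. 4. 5.]
#  [3. 5. 6.]]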
Example No. 17
def save_images(batch, fname):
    n_rows = batch.shape[0] // 16
    batch = onp.uint8(jnp.round((batch + 1) * 127.5))
    out = onp.full((1 + 33 * n_rows, 1 + 33 * 16, 3), 255, 'uint8')
    for i, im in enumerate(batch):
        top = 1 + 33 * (i // 16)
        left = 1 + 33 * (i % 16)
        out[top:top + 32, left:left + 32] = im
    Image.fromarray(out).save(fname)
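A hedged usage sketch, assuming onp is NumPy, jnp is jax.numpy and Image is PIL.Image as the snippet implies; it writes one row of sixteen 32x32 images:

import numpy as onp
import jax.numpy as jnp
from PIL import Image

batch = jnp.zeros((16, 32, 32, 3)) - 0.5   # dummy images in [-1, 1]
save_images(batch, "samples.png")          # output grid is (1 + 33*n_rows) x (1 + 33*16) pixels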
Example No. 18
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, layer_norm,
             parameterization, use_dropout):
    fc = partial(stax.Dense,
                 W_std=W_std,
                 b_std=b_std,
                 parameterization=parameterization)
    conv = partial(stax.Conv,
                   filter_shape=filter_shape,
                   strides=strides,
                   padding=padding,
                   W_std=W_std,
                   b_std=b_std,
                   parameterization=parameterization)
    affine = conv(width) if is_conv else fc(width)
    rate = np.onp.random.uniform(0.5, 0.9)
    dropout = stax.Dropout(rate, mode='train')
    ave_pool = stax.AvgPool((2, 3), None,
                            'SAME' if padding == 'SAME' else 'CIRCULAR')
    ave_pool_or_identity = ave_pool if use_pooling else stax.Identity()
    dropout_or_identity = dropout if use_dropout else stax.Identity()
    layer_norm_or_identity = (stax.Identity() if layer_norm is None else
                              stax.LayerNorm(axis=layer_norm))
    res_unit = stax.serial(ave_pool_or_identity, phi, dropout_or_identity,
                           affine)
    if is_res:
        block = stax.serial(affine, stax.FanOut(2),
                            stax.parallel(stax.Identity(), res_unit),
                            stax.FanInSum(), layer_norm_or_identity)
    else:
        block = stax.serial(affine, res_unit, layer_norm_or_identity)

    if proj_into_2d == 'FLAT':
        proj_layer = stax.Flatten()
    elif proj_into_2d == 'POOL':
        proj_layer = stax.GlobalAvgPool()
    elif proj_into_2d.startswith('ATTN'):
        n_heads = int(np.sqrt(width))
        n_chan_val = int(np.round(float(width) / n_heads))
        fixed = proj_into_2d == 'ATTN_FIXED'
        proj_layer = stax.serial(
            stax.GlobalSelfAttention(width,
                                     n_chan_key=width,
                                     n_chan_val=n_chan_val,
                                     n_heads=n_heads,
                                     fixed=fixed,
                                     W_key_std=W_std,
                                     W_value_std=W_std,
                                     W_query_std=W_std,
                                     W_out_std=1.0,
                                     b_std=b_std), stax.Flatten())
    else:
        raise ValueError(proj_into_2d)
    readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

    return stax.serial(block, readout)
Example No. 19
def predict(params,
            state,
            action_field=None,
            action_size=2,
            action_layer=[2, 3]):
    """
    Predict the next state given the args:
        params: network parameters
        state:  current state

    Returns:
        predicted_state: next state prediction
        action: action to take
    """

    # data is a mutable variable that will hold the intermediate states between layers
    data = state

    i = 0
    action_data = []  # list of per-layer activations (np arrays have no .append)
    for w, b in params[:-1]:
        data = np.add(np.dot(w, data), b)
        i += 1
        if action_field and i in action_layer:
            action_data.append(data)

    try:
        sin_cut = wandb.config.sin_cut
    except KeyError:
        sin_cut = wandb.config.sin_cut = 0.001

    if action_field:
        assert len(action_field) == 2
        for i in range(0, len(action_field or [])):
            # All action_fields arrays must equal action_data size
            assert len(action_field[i]) == len(action_data)

        # action_field has two sets of parameters

        # Fast GPU noise
        # http://people.compute.dtu.dk/jerf/papers/abstracts/noise_abstract.pdf
        action = np.round(
            np.dot(
                1 / len(action_field),
                np.sin(np.dot(action_data, action_field[1])),
            ))
    else:
        action = None

    final_w, final_b = params[-1]
    predicted_state = np.tanh(np.dot(final_w, data))

    # TODO: Make this come out of a noise function
    action = 0.0

    return predicted_state, action
Example No. 20
 def f(params, pot_ini):
     x = params[:-1]
     v = params[-1]
     vprint = jnp.round(v, 2)
     print("evaluating for V = {:.2f}".format(vprint))
     eff, pot = simulator.eff_at_bias(convr(x),
                                      v,
                                      pot_ini,
                                      verbose=False)
     return -eff, pot
Example No. 21
def letter_seq(arr: np.ndarray) -> str:
    """
    Convert a 2D one-hot array into a string representation.

    TODO: More docstrings needed.
    """
    sequence = ""
    for letter in arr:
        sequence += arr_to_letter(np.round(letter))
    return sequence.strip("start").strip("stop")
Example No. 22
  def sample(self, key, params):
    sample_x = distribution_utils.sample_from_discretized_mix_logistic_rgb(
        key, params, self.n_mixtures)  # range [-1., 1.]

    sample_x = (sample_x + 1.) / 2.  # range [0, 1.]
    sample_x = sample_x * (self.n_classes - 1.)  # range [0, n_classes - 1]

    # Better to round now, otherwise the values get floored when cast to int32.
    sample_x = jnp.round(sample_x)
    return sample_x
Example No. 23
def sample_time_jump_with_linear_increase(step, num_train_steps, min_jump,
                                          max_jump, rng):
    """Returns a stochastic jump size, with linearly increasing mean."""
    max_time_jump_for_step = min_jump + (step /
                                         (num_train_steps - 1)) * (max_jump -
                                                                   min_jump)
    max_time_jump_for_step = jnp.round(max_time_jump_for_step)
    jump = jax.random.randint(rng, (), min_jump, max_time_jump_for_step + 1)
    jump = int(jump)
    return jump
Example No. 24
    def uoro_grad(self, key, theta, state, s_tilde=None, theta_tilde=None):
        epsilon_perturbation = 1e-7
        epsilon_stability = 1e-7

        total_theta_grad = 0
        total_loss = 0.0

        if s_tilde is None:
            s_tilde = jnp.zeros(state.inner_state.shape)

        if theta_tilde is None:
            theta_tilde = jnp.zeros(theta.shape)

        state_old = state
        # TODO: How do we handle key here? Do we want to split again?
        loss, state_new = self.unroll_fn(key, theta, state_old, self.T, 1)
        total_loss += loss

        dl_dstate_old = self.compute_dL_dstate_old(theta, state_old)
        dl_dtheta_direct = self.compute_dL_dtheta_direct(theta, state_old)

        dl_dstate_old = dl_dstate_old.inner_state

        indirect_grad = (dl_dstate_old * s_tilde).sum() * theta_tilde
        pseudograds = indirect_grad + dl_dtheta_direct

        state_old_perturbed = state_old._replace(
            inner_state=state_old.inner_state + s_tilde * epsilon_perturbation)
        state_new_perturbed = self.f(theta, state_old_perturbed)

        state_deriv_in_direction_s_tilde = (
            (state_new_perturbed - state_new.inner_state) /
            epsilon_perturbation)

        nus = jnp.round(jax.random.uniform(
            key, state_old.inner_state.shape)) * 2 - 1

        custom_f = lambda param_vector: self.f(param_vector, state_old)
        primals, f_vjp = jax.vjp(custom_f, theta)
        direct_theta_tilde_contribution, = f_vjp(nus)

        rho_0 = jnp.sqrt((jnp.linalg.norm(theta_tilde) + epsilon_stability) /
                         (jnp.linalg.norm(state_deriv_in_direction_s_tilde) +
                          epsilon_stability))
        rho_1 = jnp.sqrt(
            (jnp.linalg.norm(direct_theta_tilde_contribution) +
             epsilon_stability) / (jnp.linalg.norm(nus) + epsilon_stability))

        theta_grad = pseudograds
        total_theta_grad += theta_grad

        s_tilde = rho_0 * state_deriv_in_direction_s_tilde + rho_1 * nus
        theta_tilde = theta_tilde / rho_0 + direct_theta_tilde_contribution / rho_1

        return (total_loss, state_new, s_tilde, theta_tilde), total_theta_grad
Example No. 25
def apply_bond_charge_corrections(initial_charges, bond_idxs, deltas):
    """For an arbitrary collection of ordered bonds and associated increments `(a, b, delta)`,
    update `charges` by `charges[a] += delta`, `charges[b] -= delta`

    Notes
    -----
    * preserves sum(initial_charges) for arbitrary values of bond_idxs or deltas
    * order within each row of bond_idxs is meaningful
        `(..., bond_idxs, deltas)`
        means the opposite of
        `(..., bond_idxs[:, ::-1], deltas)`
    * order within the first axis of bond_idxs, deltas is not meaningful
        `(..., bond_idxs[perm], deltas[perm])`
        means the same thing for any permutation `perm`
    """

    # apply bond charge corrections
    incremented = ops.index_add(initial_charges, bond_idxs[:, 0], +deltas)
    decremented = ops.index_add(incremented, bond_idxs[:, 1], -deltas)
    final_charges = decremented

    # make some safety assertions
    assert bond_idxs.shape[1] == 2
    assert len(deltas) == len(bond_idxs)

    net_charge = jnp.sum(initial_charges)
    net_charge_is_integral = jnp.isclose(net_charge,
                                         jnp.round(net_charge),
                                         atol=1e-5)

    final_net_charge = jnp.sum(final_charges)
    net_charge_is_unchanged = jnp.isclose(final_net_charge,
                                          net_charge,
                                          atol=1e-5)

    assert net_charge_is_integral
    assert net_charge_is_unchanged

    # print some safety warnings
    directed_bonds = Counter([tuple(b) for b in bond_idxs])
    undirected_bonds = Counter([tuple(sorted(b)) for b in bond_idxs])

    if max(directed_bonds.values()) > 1:
        duplicates = [
            bond for (bond, count) in directed_bonds.items() if count > 1
        ]
        print(UserWarning(f"Duplicate directed bonds! {duplicates}"))
    elif max(undirected_bonds.values()) > 1:
        duplicates = [
            bond for (bond, count) in undirected_bonds.items() if count > 1
        ]
        print(UserWarning(f"Duplicate undirected bonds! {duplicates}"))

    return final_charges
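A tiny worked example, assuming the module-level imports below and an older JAX release where jax.ops.index_add still exists:

from collections import Counter
import numpy as np
import jax.numpy as jnp
from jax import ops

initial_charges = jnp.array([0.5, -0.5, 0.0])
bond_idxs = np.array([[0, 1], [1, 2]])     # directed bonds (a, b)
deltas = jnp.array([0.1, -0.2])
print(apply_bond_charge_corrections(initial_charges, bond_idxs, deltas))
# -> approximately [0.6, -0.8, 0.2]; the total charge stays 0.0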
Example No. 26
def sample_mask_indices(input_dim, hidden_dim):
    """
    Samples the indices assigned to hidden units during the construction of MADE masks

    :param input_dim: the dimensionality of the input variable
    :type input_dim: int
    :param hidden_dim: the dimensionality of the hidden layer
    :type hidden_dim: int
    """
    indices = jnp.linspace(1, input_dim, num=hidden_dim)
    # Simple procedure tries to space fractional indices evenly by rounding to nearest int
    return jnp.round(indices)
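For instance, spreading 3 input indices over 4 hidden units (a minimal check, assuming jax.numpy as jnp):

import jax.numpy as jnp

print(sample_mask_indices(3, 4))   # [1. 2. 2. 3.]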
Example No. 27
def diff_stft(xinp, s, hf=0.5):
    """
    Inputs
    ------
    xinp: jnp.array
        Input audio signal in time domain
    s: jnp.float
        The standard deviation of the Gaussian window to be used
    hf: jnp.float
        The fraction of window size that will be overlapped within consecutive frames
    
    Outputs
    -------
    a: jnp.array
        The computed magnitude spectrogram
    """

    # Effective window length of Gaussian is 6\sigma
    sz = s * 6
    hp = hf * sz

    # Truncating to integers for use in jnp functions
    intsz = int(jnp.round(sz))
    inthp = int(jnp.round(hp))

    m = jnp.arange(0, intsz, dtype=jnp.float32)

    # Obtaining the "differentiable" window function by using the real valued \sigma
    window = jnp.exp(-0.5 * jnp.power((m - sz / 2) / (s + 1e-5), 2))
    window_norm = window / jnp.sum(window)

    # Computing the STFT, and taking its magnitude
    stft = jnp.sqrt(1 / (2 * window_norm.shape[0] + 1)) * jnp.stack([
        jnp.fft.rfft(window_norm * xinp[i:i + intsz])
        for i in range(0,
                       len(xinp) - intsz, inthp)
    ], 1)
    a = jnp.abs(stft)

    return a
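A minimal sketch of calling the spectrogram above on a synthetic tone; the sample rate fs is an assumption made only for this example:

import jax.numpy as jnp

fs = 16000                                  # assumed sample rate
t = jnp.arange(fs) / fs
signal = jnp.sin(2 * jnp.pi * 440.0 * t)    # one second of a 440 Hz tone
spec = diff_stft(signal, s=50.0, hf=0.5)    # sigma = 50 samples -> 300-sample window
print(spec.shape)                           # (151, n_frames): rfft bins x frames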
Example No. 28
    def _precompute_lossfun_inputs(self):
        """Precompute a loss and its derivatives for random inputs and parameters.

    Generates a large number of random inputs to the loss, and random
    shape/scale parameters for the loss function at each sample, and
    computes the loss and its derivative with respect to all inputs and
    parameters, returning everything to be used to assert various properties
    in our unit tests.

    Returns:
      A tuple containing:
       (the number (int) of samples, and the length of all following arrays,
        A tensor of losses for each sample,
        A tensor of residuals of each sample (the loss inputs),
        A tensor of shape parameters of each loss,
        A tensor of scale parameters of each loss,
        A tensor of derivatives of each loss wrt each x,
        A tensor of derivatives of each loss wrt each alpha,
        A tensor of derivatives of each loss wrt each scale)

    Typical usage example:
    (num_samples, loss, x, alpha, scale, d_x, d_alpha, d_scale)
        = self._precompute_lossfun_inputs()
    """
        num_samples = 100000
        rng = random.PRNGKey(0)

        # Normally distributed inputs.
        rng, key = random.split(rng)
        x = random.normal(key, shape=[num_samples])

        # Uniformly distributed values in (-16, 3), quantized to the nearest 0.1
        # to ensure that we hit the special cases at 0, 2.
        rng, key = random.split(rng)
        alpha = jnp.round(
            random.uniform(key, shape=[num_samples], minval=-16, maxval=3) *
            10) / 10.
        # Push the sampled alphas at the extents of the range to +/- infinity, so
        # that we probe those cases too.
        alpha = jnp.where(alpha == 3, jnp.inf, alpha)
        alpha = jnp.where(alpha == -16, -jnp.inf, alpha)

        # Random log-normally distributed values in approx (1e-5, 100000):
        rng, key = random.split(rng)
        scale = jnp.exp(random.normal(key, shape=[num_samples]) * 4.) + 1e-5

        fn = self.variant(general.lossfun)
        loss = fn(x, alpha, scale)
        d_x, d_alpha, d_scale = (jax.grad(lambda x, a, s: jnp.sum(fn(x, a, s)),
                                          [0, 1, 2])(x, alpha, scale))

        return (num_samples, loss, x, alpha, scale, d_x, d_alpha, d_scale)
Example No. 29
 def body_func(args):
     xi, accumulated_sum = args
     xi_float = jnp.asarray(xi, dtype=dtype)
     log_xi_factorial = lax.lgamma(xi_float + 1.)
     log_comb_n_xi = (log_n_factorial - log_xi_factorial -
                      lax.lgamma(total_count - xi_float + 1.))
     comb_n_xi = jnp.round(jnp.exp(log_comb_n_xi))
     likelihood1 = math.power_no_nan(probs, xi)
     likelihood2 = math.power_no_nan(1. - probs, total_count - xi)
     likelihood = likelihood1 * likelihood2
     comb_term = comb_n_xi * log_xi_factorial * likelihood  # [K]
     chex.assert_shape(comb_term, (probs.shape[-1], ))
     return xi + 1, accumulated_sum + comb_term
Example No. 30
def uoro_grad(key, theta, state, s_tilde=None, theta_tilde=None):
    epsilon_perturbation = 1e-7
    epsilon_stability = 1e-7

    total_theta_grad = 0
    total_loss = 0.0

    if s_tilde is None:
        s_tilde = jnp.zeros(state.shape)

    if theta_tilde is None:
        theta_tilde = jnp.zeros(theta.shape)

    state_old = state  # (23,)
    state_new = f(theta, state_old)  # (23,)
    loss = L(theta, state_old)
    total_loss += loss

    dl_dstate_old = compute_dL_dstate_old(theta, state_old)  # (23,)
    dl_dtheta_direct = compute_dL_dtheta_direct(theta, state_old)  # (1,)

    indirect_grad = (dl_dstate_old * s_tilde).sum() * theta_tilde  # (1,)
    pseudograds = indirect_grad + dl_dtheta_direct  # (1,)

    state_old_perturbed = state_old + s_tilde * epsilon_perturbation  # (23,)
    state_new_perturbed = f(theta, state_old_perturbed)  # (23,)

    state_deriv_in_direction_s_tilde = (
        state_new_perturbed - state_new) / epsilon_perturbation  # (23,)

    nus = jnp.round(jax.random.uniform(key, state_old.shape)) * 2 - 1  # (23,)

    # Tricky part is this first line
    custom_f = lambda param_vector: f(param_vector, state_old)
    primals, f_vjp = jax.vjp(custom_f, theta)
    direct_theta_tilde_contribution, = f_vjp(nus)  # (1,)

    rho_0 = jnp.sqrt((jnp.linalg.norm(theta_tilde) + epsilon_stability) /
                     (jnp.linalg.norm(state_deriv_in_direction_s_tilde) +
                      epsilon_stability))
    rho_1 = jnp.sqrt(
        (jnp.linalg.norm(direct_theta_tilde_contribution) + epsilon_stability)
        / (jnp.linalg.norm(nus) + epsilon_stability))

    theta_grad = pseudograds
    total_theta_grad += theta_grad

    s_tilde = rho_0 * state_deriv_in_direction_s_tilde + rho_1 * nus
    theta_tilde = theta_tilde / rho_0 + direct_theta_tilde_contribution / rho_1

    return (total_loss, state_new, s_tilde, theta_tilde), total_theta_grad