Example #1
0
def run_kalman_smoother_for_marginals(lgssm_scenario: LinearGaussian,
                                      y: jnp.ndarray,
                                      t: jnp.ndarray,
                                      filter_output: Tuple[jnp.ndarray, jnp.ndarray] = None) \
        -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Smoothed marginal means and covariances via a backward (Rauch-Tung-Striebel) pass.

    Args:
        lgssm_scenario: linear-Gaussian model; queried for the transition matrix and
            the square-root transition covariance between consecutive times.
        y: observations; only used to run the filter when ``filter_output`` is None.
        t: times, one per filter marginal.
        filter_output: optional ``(means, covs)`` as returned by
            ``run_kalman_filter_for_marginals``.

    Returns:
        ``(means, covs)`` of the smoothed marginals in time order; the last entry is
        the final filter marginal unchanged.
    """
    if filter_output is None:
        filter_output = run_kalman_filter_for_marginals(lgssm_scenario, y, t)
    filt_means, filt_covs = filter_output

    def smooth_one_step(next_smoothed, idx):
        next_mean, next_cov = next_smoothed

        trans_mat = lgssm_scenario.get_transition_matrix(t[idx], t[idx + 1])
        trans_cov_sqrt = lgssm_scenario.get_transition_covariance_sqrt(t[idx], t[idx + 1])
        trans_cov = trans_cov_sqrt @ trans_cov_sqrt.T

        mean_f = filt_means[idx]
        cov_f = filt_covs[idx]
        # Filter covariance pushed through the transition (noise added separately).
        propagated_cov = trans_mat @ cov_f @ trans_mat.T
        gain = cov_f @ trans_mat.T @ jnp.linalg.inv(propagated_cov + trans_cov)

        mean_s = mean_f + gain @ (next_mean - trans_mat @ mean_f)
        cov_s = cov_f + gain @ (next_cov - propagated_cov - trans_cov) @ gain.T
        return (mean_s, cov_s), (mean_s, cov_s)

    # Walk from the penultimate time back to 0, seeded with the last filter marginal.
    reverse_idx = jnp.arange(len(t) - 2, -1, -1)
    _, (smoothed_means, smoothed_covs) = scan(smooth_one_step,
                                              (filt_means[-1], filt_covs[-1]),
                                              reverse_idx)

    means = jnp.concatenate([smoothed_means[::-1], filt_means[-1:]], axis=0)
    covs = jnp.concatenate([smoothed_covs[::-1], filt_covs[-1:]], axis=0)
    return means, covs
Example #2
0
 def __get__(self, obj, objtype=None):
     """Return the bordered matrix built from ``obj.cor`` and ``obj.delta``.

     ``obj.delta`` is appended to ``obj.cor`` as a final column, then the row
     ``[*obj.delta, 1.]`` is appended at the bottom. For ``cor`` of shape (n, n)
     and ``delta`` of shape (n,), the result is (n + 1, n + 1). The value is
     memoised in ``self.cache`` under ``id(obj)``.

     When accessed on the owner class (``obj is None``), the descriptor itself
     is returned, following the standard descriptor protocol.
     """
     if obj is None:
         # Class-level access (e.g. ``Owner.attr``): there is no instance to read.
         return self
     key = id(obj)
     # NOTE(review): id() values are reused after obj is garbage collected, so a
     # stale entry could be served to a new object; a WeakKeyDictionary avoids that.
     value = self.cache.get(key, None)
     if value is not None:
         return value
     top_rows = jnp.append(obj.cor, jnp.expand_dims(obj.delta, 1), axis=1)
     bottom_row = jnp.expand_dims(jnp.append(obj.delta, 1.), 0)
     value = jnp.append(top_rows, bottom_row, axis=0)
     self.cache[key] = value
     return value
def stream_vel_taud(h, n, dx, rhoi, g):
    """Per-node ``rhoi * g * h * dh/dx`` along a 1-D profile, plus an end term.

    The first and last nodes use one-sided differences; interior nodes
    ``1 .. n-2`` use a centred difference. ``fend`` is
    ``0.25 * (rhoi*g*h[n-1]**2 - rhow*g*R_bed**2)``, where ``rhow`` and ``R_bed``
    are module-level names defined elsewhere. (Presumably the driving stress
    and a terminal boundary force, judging by the name — confirm.)

    Returns:
        ``(f, fend)`` where ``f`` has one entry per node ``0 .. n-1``.
    """
    rho_g = rhoi * g
    prev_h = jnp.roll(h, 1)
    next_h = jnp.roll(h, -1)
    interior = slice(1, n - 1)

    head = rho_g * h[0] * (h[1] - h[0]) / dx
    body = rho_g * h[interior] * (next_h[interior] - prev_h[interior]) / 2. / dx
    tail = rho_g * h[n - 1] * (h[n - 1] - h[n - 2]) / dx

    f = jnp.append(jnp.append(head, body), tail)
    fend = .5 * (rho_g * h[n - 1]**2 - rhow * g * R_bed**2) * .5
    return f, fend
Example #4
0
    def run(self):
        """Worker loop: load data, build a jitted Hessian of ``self.likelihood`` and serve requests.

        Each vector received on ``self.pipein`` with length ``4 * t_`` is reshaped
        into four rows. Those rows are passed to the jitted Hessian and the result
        is sent on ``self.pipeout``. A vector of any other length makes the worker
        send ``0`` and return.
        """
        # Restrict this process to the CUDA device(s) named by self.card.
        os.environ["CUDA_VISIBLE_DEVICES"] = self.card

        print(self.id, 'Variable initialization is finished')

        event_num = 100
        # NOTE(review): mcnpz is defined elsewhere; each call returns three indexable
        # collections — meaning of the (start, stop, count) arguments to be confirmed.
        all_data_phif0, all_data_phi, all_data_f = self.mcnpz(
            0, 500000, event_num)
        all_mc_phif0, all_mc_phi, all_mc_f = self.mcnpz(500000, 700000, 1)
        # Only the first entry of the second (single-count) set is kept.
        self.mc_phif0 = np.squeeze(all_mc_phif0[0], axis=None)
        self.mc_phi = np.squeeze(all_mc_phi[0], axis=None)
        self.mc_f = np.squeeze(all_mc_f[0], axis=None)
        t_ = 7
        # Unseeded random starting values: four length-t_ vectors, flattened by
        # the np.append chain into one vector of length 4 * t_.
        m = onp.random.rand(t_)
        w = onp.random.rand(t_)
        c = onp.random.rand(t_)
        t = onp.random.rand(t_)
        wtarg = np.append(np.append(np.append(m, w), c), t)

        # Only entry 0 of the first data set is used.
        i = 0
        self.data_phif0 = np.squeeze(all_data_phif0[i], axis=None)
        self.data_phi = np.squeeze(all_data_phi[i], axis=None)
        self.data_f = np.squeeze(all_data_f[i], axis=None)

        self.wt = self.Weight(wtarg)
        # print(self.wt.size)
        # Part 1 differentiates w.r.t. the first three likelihood arguments;
        # any other part w.r.t. the fourth only.
        if self.part == 1:
            self.res = jit(hessian(self.likelihood, argnums=[0, 1, 2]))
            # self.pipeout.send(self.wt)
        else:
            self.res = jit(hessian(self.likelihood, argnums=[3]))

        while (True):
            # print(self.pipe)
            var = self.pipein.recv()
            # print(var.shape)
            # A vector of length 4 * t_ is a work request; anything else stops the loop.
            if var.shape[0] == t_ * 4:

                start = time.time()
                var_ = var.reshape(4, -1)
                result = self.res(var_[0], var_[1], var_[2], var_[3])
                # print('shape:',result.shape)
                # print('process ID -',self.id,result)
                # self.qout.put(result)
                # NOTE(review): `self.id + ' part'` only works if self.id is a str.
                print('process ID -', self.id + ' part' + str(self.part),
                      '(time):', float(time.time() - start))
                self.pipeout.send(result)

            else:
                # Send 0 back and leave the worker loop.
                self.pipeout.send(0)
                break
Example #5
0
    def predict(self, x):
        """
        Description: Takes the current observation and returns the next prediction.

        Side effects: updates ``self.X`` (via ``self._update_x``), ``self.X_sim`` and
        ``self.y_hat``.

        Args:
            x (float/numpy.ndarray): value at current time-step
        Returns:
            Predicted value for the next time-step: ``self.M.dot(self.X_sim)``,
            where ``self.X_sim`` is
            ``[flatten(X_sim_pre.T), X[:, 1], X[:, 0], Y[:, 1]]``.
        """
        # Fold the new observation into the history matrix.
        self.X = self._update_x(self.X, x)
        # Project the history onto k_vectors, then scale by eigen_diag.
        X_sim_pre = self.X.dot(self.k_vectors).dot(self.eigen_diag)
        # Columns 1 and 0 of X followed by column 1 of Y, concatenated into one vector.
        x_y_cols = np.append(np.append(self.X[:, 1], self.X[:, 0]), self.Y[:, 1])
        self.X_sim = np.append(X_sim_pre.T.flatten(), x_y_cols)
        self.y_hat = self.M.dot(self.X_sim)
        return self.y_hat
Example #6
0
    def run(self):
        """Benchmark worker: time ten calls of a jitted Hessian of ``self.likelihood``.

        Loads data and random starting weights. It then evaluates the jitted
        Hessian on the same weight vector ``wtarg`` ten times, printing each
        duration. Finally it prints the average of all durations except the first.
        """
        # Restrict this process to the CUDA device(s) named by self.card.
        os.environ["CUDA_VISIBLE_DEVICES"] = self.card

        print(self.id, 'Variable initialization is finished')

        event_num = 100
        # NOTE(review): mcnpz is defined elsewhere; each call returns three indexable
        # collections — meaning of the (start, stop, count) arguments to be confirmed.
        all_data_phif0, all_data_phi, all_data_f = self.mcnpz(
            0, 500000, event_num)
        all_mc_phif0, all_mc_phi, all_mc_f = self.mcnpz(500000, 700000, 1)
        self.mc_phif0 = np.squeeze(all_mc_phif0[0], axis=None)
        self.mc_phi = np.squeeze(all_mc_phi[0], axis=None)
        self.mc_f = np.squeeze(all_mc_f[0], axis=None)
        t_ = 7
        # Unseeded random starting values, flattened into one vector of length 4 * t_.
        m = onp.random.rand(t_)
        w = onp.random.rand(t_)
        c = onp.random.rand(t_)
        t = onp.random.rand(t_)
        wtarg = np.append(np.append(np.append(m, w), c), t)

        # Only entry 0 of the first data set is used.
        i = 0
        self.data_phif0 = np.squeeze(all_data_phif0[i], axis=None)
        self.data_phi = np.squeeze(all_data_phi[i], axis=None)
        self.data_f = np.squeeze(all_data_f[i], axis=None)

        self.wt = self.Weight(wtarg)
        # if self.part == 1:
        #     self.res = jit(jacfwd(self.part1))
        # else:
        #     self.res = jit(jacfwd(self.part2))
        # Hessian of the likelihood w.r.t. its single (flattened weight) argument.
        self.res = jit(hessian(self.likelihood))

        # Wall-clock duration of each call, in seconds.
        etime = []

        # NOTE: reuses the name `i`; the data index above is no longer needed here.
        for i in range(10):
            # var = self.qdata.get()
            # if var.shape[0] == 3:

            start = time.time()
            result = self.res(wtarg)
            # NOTE(review): labelled 'grad:' but `result` is a Hessian.
            print('process ID -', self.id, 'grad:', result.shape)
            # self.qout.put(result)
            etime_ = float(time.time() - start)
            print('process ID -', self.id, '(time):', etime_)
            etime.append(etime_)
            # self.qout.put(result)

            # else:
            # self.qout.put(0)
            # break
        # The first timing is excluded, presumably because it includes jit compilation.
        print('average cal time:', onp.average(etime[1:]))
Example #7
0
    def run(self):
        """Worker loop: build a jitted gradient of ``self.part1``/``self.part2`` and serve queue requests.

        In part 1 the weight object ``self.wt`` is first put on ``self.qout``.
        Each vector taken from ``self.qdata`` with length ``4 * t_`` is passed whole
        to the jitted gradient, and the result is put on ``self.qout``. Any other
        length puts ``0`` on ``self.qout`` and ends the loop.
        """
        # Restrict this process to the CUDA device(s) named by self.card.
        os.environ["CUDA_VISIBLE_DEVICES"] = self.card

        print(self.id, 'Variable initialization is finished')

        event_num = 100
        # NOTE(review): mcnpz is defined elsewhere; each call returns three indexable
        # collections — meaning of the (start, stop, count) arguments to be confirmed.
        all_data_phif0, all_data_phi, all_data_f = self.mcnpz(
            0, 500000, event_num)
        all_mc_phif0, all_mc_phi, all_mc_f = self.mcnpz(500000, 700000, 1)
        self.mc_phif0 = np.squeeze(all_mc_phif0[0], axis=None)
        self.mc_phi = np.squeeze(all_mc_phi[0], axis=None)
        self.mc_f = np.squeeze(all_mc_f[0], axis=None)

        t_ = 278
        # Unseeded random starting values, flattened into one vector of length 4 * t_.
        m = onp.random.rand(t_)
        w = onp.random.rand(t_)
        c = onp.random.rand(t_)
        t = onp.random.rand(t_)
        wtarg = np.append(np.append(np.append(m, w), c), t)

        # Only entry 0 of the first data set is used.
        i = 0
        self.data_phif0 = np.squeeze(all_data_phif0[i], axis=None)
        self.data_phi = np.squeeze(all_data_phi[i], axis=None)
        self.data_f = np.squeeze(all_data_f[i], axis=None)

        self.wt = self.Weight(wtarg)

        # Each part differentiates its own partial objective; part 1 also
        # publishes the weight object on the output queue.
        if self.part == 1:
            self.res = jit(grad(self.part1))
            self.qout.put(self.wt)
        else:
            self.res = jit(grad(self.part2))

        while (True):
            var = self.qdata.get()
            print(var.shape)
            # A vector of length 4 * t_ is a work request; anything else stops the loop.
            if var.shape[0] == t_ * 4:

                start = time.time()
                result = self.res(var)
                # print('process ID -',self.id,result)
                # self.qout.put(result)
                # NOTE(review): `self.id + ' part'` only works if self.id is a str.
                print('process ID -', self.id + ' part' + str(self.part),
                      '(time):', float(time.time() - start))
                self.qout.put(result)

            else:
                # Put 0 on the output queue and leave the worker loop.
                self.qout.put(0)
                break
Example #8
0
def backward_simulation_full(ssm_scenario: StateSpaceModel,
                             marginal_particles: cdict,
                             n_samps: int,
                             random_key: jnp.ndarray) -> cdict:
    """Draw ``n_samps`` smoothing trajectories by backward simulation.

    Terminal samples are drawn categorically from the final marginal log weights.
    Each earlier time step then calls ``full_resampling`` given the samples at
    the following time, scanning from time ``T - 2`` down to ``0``.

    Returns:
        A copy of ``marginal_particles`` whose ``value`` holds the sampled paths
        (time-major, with ``n_samps`` paths per time). ``num_transition_evals`` is
        ``0`` at the first time and ``n_pf * n_samps`` afterwards, and the
        ``log_weight`` field is removed.
    """
    particle_vals = marginal_particles.value
    times = marginal_particles.t
    log_weights = marginal_particles.log_weight

    num_times, num_particles, _ = particle_vals.shape

    step_keys = random.split(random_key, num_times)
    final_inds = random.categorical(step_keys[-1], log_weights[-1], shape=(n_samps,))
    final_vals = particle_vals[-1, final_inds]

    def step_back(x_next, idx):
        x_curr = full_resampling(ssm_scenario, particle_vals[idx], times[idx],
                                 x_next, times[idx + 1], log_weights[idx], step_keys[idx])
        return x_curr, x_curr

    _, earlier_vals = scan(step_back, final_vals, jnp.arange(num_times - 2, -1, -1))

    out_samps = marginal_particles.copy()
    out_samps.value = jnp.concatenate([earlier_vals[::-1], final_vals[jnp.newaxis]], axis=0)
    out_samps.num_transition_evals = jnp.append(0, jnp.ones(num_times - 1) * num_particles * n_samps)
    del out_samps.log_weight
    return out_samps
def partial_trace(A, A_label):
    """Partial trace of tensor ``A`` over repeated labels in ``A_label``.

    Every label that appears exactly twice in ``A_label`` names a pair of axes
    of ``A``; each such pair is contracted (summed along its shared diagonal).

    Args:
        A: tensor with one axis per entry of ``A_label``.
        A_label: 1-D array of axis labels (compared with ``==``).

    Returns:
        ``(B, B_label, cont_label)``: the traced tensor, the labels of its
        remaining axes in their original order, and the sorted contracted labels.
        If no label repeats, ``(A, A_label, [])`` is returned unchanged.

    Raises:
        ValueError: if a label appears more than twice (only pairwise traces
            are defined).
    """
    unique_labels = np.unique(A_label)
    num_cont = len(A_label) - len(unique_labels)
    if num_cont == 0:
        return A, A_label, []

    # Axis positions of the first and second occurrence of each repeated label,
    # in sorted-label order.
    first_ind, second_ind = [], []
    for ele in unique_labels:
        positions = np.where(A_label == ele)[0]
        if len(positions) > 2:
            raise ValueError(
                f"label {ele} appears {len(positions)} times; "
                "partial_trace only contracts pairs of axes")
        if len(positions) == 2:
            first_ind.append(int(positions[0]))
            second_ind.append(int(positions[1]))

    cont_ind = onp.array(first_ind + second_ind)
    contracted = set(cont_ind.tolist())
    free_ind = [i for i in range(len(A_label)) if i not in contracted]

    free_dim = tuple(A.shape[i] for i in free_ind)
    free_size = int(onp.prod(free_dim))
    cont_dim = int(onp.prod([A.shape[i] for i in first_ind]))

    B_label = onp.delete(A_label, cont_ind)
    cont_label = np.unique(A_label[cont_ind])

    # Order axes as (free..., firsts..., seconds...) so each contracted pair
    # lines up on the diagonal of the last two flattened axes.
    A = A.transpose(free_ind + cont_ind.tolist()).reshape(free_size, cont_dim, cont_dim)
    # Starting from float zeros keeps the original float promotion for int inputs.
    B = np.zeros(free_size) + np.trace(A, axis1=1, axis2=2)
    return B.reshape(free_dim), B_label, cont_label
Example #10
0
def _samplewise_log_loss(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
    """Log loss of a single sample, ``sum(y_true * -log(clip(y_pred)))``.

    Adapted from scikit-learn's ``log_loss``:
    https://github.com/scikit-learn/scikit-learn/blob/ffbb1b4a0bbb58fdca34a30856c6f7faace87c67/sklearn/metrics/_classification.py#L2123

    A scalar or one-element target is treated as a binary problem and expanded
    to the two-class form ``[1 - p, p]`` for both ``y_true`` and ``y_pred``.
    Predictions are cast to float32 and clipped to ``[1e-15, 1 - 1e-15]``.
    """
    if y_true.ndim == 0:
        # Scalar binary target: lift both inputs to one-element arrays.
        y_true, y_pred = y_true.reshape(1), y_pred.reshape(1)
    if y_true.shape[0] == 1:
        # Single-column binary case: expand to explicit two-class vectors.
        y_true = jnp.append(1 - y_true, y_true)
        y_pred = jnp.append(1 - y_pred, y_pred)

    clip_eps = 1e-15
    probs = jnp.clip(y_pred.astype(jnp.float32), clip_eps, 1 - clip_eps)
    return jnp.sum(y_true * -jnp.log(probs))
Example #11
0
def initialise_W(M, training_set, renormalisation):
    """Data-dependent initialisation of fully connected layer weights.

    For each consecutive pair ``(M[i-1], M[i])`` a uniform random matrix is drawn
    with ``numpy.random.uniform``. A bias column ``-mean`` is appended, and the
    matrix is divided by ``std * renormalisation``. As a result, the layer's
    pre-activations on the current input have zero mean and standard deviation
    ``1 / renormalisation`` per unit. The next layer's input is
    ``sigma(W @ [x; 1])``.

    Args:
        M: layer widths; ``M[0]`` must equal the number of features and
            ``M[-1]`` must be 1.
        training_set: array of shape (n_samples, M[0]).
        renormalisation: extra divisor applied to every layer.

    Returns:
        List of weight matrices, each of shape ``(M[i], M[i-1] + 1)``.
    """
    assert M[0] == training_set.shape[1]
    assert M[-1] == 1
    layer_input = training_set.T
    weights = []
    for fan_in, fan_out in zip(M[:-1], M[1:]):
        raw = np.array(numpy.random.uniform(size=(fan_out, fan_in)))
        pre_act = raw @ layer_input
        mu = np.mean(pre_act, axis=1)
        sd = np.std(pre_act, axis=1)
        augmented = np.append(raw, np.expand_dims(-mu, axis=1), axis=1)
        weights.append(augmented / (np.expand_dims(sd, axis=1) * renormalisation))
        bias_row = np.ones((1, layer_input.shape[1]))
        layer_input = sigma(weights[-1] @ np.append(layer_input, bias_row, axis=0))
    return weights
Example #12
0
    def __init__(
        self,
        xs,
        densities,
        scale: Scale,
        normalized=False,
        traceable=False,
        cumulative_normed_ps=None,
    ):
        """Store a density in normalized coordinates.

        Args:
            xs: sample locations.
            densities: density values at ``xs``.
            scale: object whose ``normalize_points`` / ``normalize_densities``
                map the inputs into normalized space. Must not be ``None``.
            normalized: if True, ``xs`` and ``densities`` are already normalized
                and are stored unchanged.
            traceable: not used by this initializer.
            cumulative_normed_ps: optional precomputed cumulative probabilities.
                When omitted, they are built as ``[0, *cumsum(self.bin_probs)]``.

        Raises:
            ValueError: if ``scale`` is ``None``.
        """
        if scale is None:
            raise ValueError("scale must not be None")

        self.scale = scale

        if normalized:
            self.normed_xs = xs
            self.normed_densities = densities

        else:
            self.normed_xs = scale.normalize_points(xs)
            self.normed_densities = scale.normalize_densities(
                self.normed_xs, densities)

        self.cumulative_normed_ps = cumulative_normed_ps
        if cumulative_normed_ps is None:
            # NOTE(review): bin_probs is defined elsewhere on the class (presumably
            # derived from the normalized densities); a leading 0 starts the running sum.
            self.cumulative_normed_ps = np.append(np.array([0]),
                                                  np.cumsum(self.bin_probs))
Example #13
0
    def update_metrics_dict(metrics_epoch, metrics_step, sample_num):
        """Merge one step's metric values into the per-epoch accumulator.

        For each metric, ``metrics_epoch[name][0]`` is a list with one slot per
        sample and ``metrics_epoch[name][1]`` is a flag.
        - If the list has exactly ``sample_num`` entries, the step value starts a
          new slot.
        - Otherwise the value is ``np.append``-ed onto slot ``sample_num``: the
          join is flattened when the flag is truthy and along axis 0 otherwise.

        Mutates and returns ``metrics_epoch``.
        """
        for name, step_value in metrics_step.items():
            entry = metrics_epoch[name]
            history = entry[0]
            if len(history) == sample_num:
                history.append(step_value)
                continue
            join_axis = None if entry[1] else 0
            history[sample_num] = np.append(history[sample_num], step_value,
                                            axis=join_axis)

        return metrics_epoch
Example #14
0
def eg5(p, testdata, *args):
    """Toy function mixing slicing, functional updates and concatenation.

    With ``n = len(p)``:
    - ``testdata2`` takes its first ``n // 2`` entries from ``testdata[10:]`` and
      its remaining ``n - n // 2`` entries from ``testdata[35:]``.
    - ``out = [p[0], *testdata2]``, and ``p`` is then added to ``out[1:]``.

    Returns the built-in ``sum`` of ``out``. Extra positional ``args`` are ignored.
    """
    out = jnp.zeros((1))

    n = len(p)

    testdata2 = jnp.zeros((n))

    halfway = n // 2
    # The second half has n - halfway slots: one more than the first when n is odd.
    second_len = n - halfway

    # note: p should be ~len 10-20 and testdata must hold at least 35 + second_len
    # entries, otherwise the slices below come up short and the updates fail.

    # functional equivalent of testdata2[0:halfway] = testdata[10:10+halfway]
    testdata2 = testdata2.at[0:halfway].set(testdata[10:10 + halfway])

    # functional equivalent of testdata2[halfway:] = testdata[35:35+second_len]
    testdata2 = testdata2.at[halfway:].set(testdata[35:35 + second_len])

    # functional equivalent of out[0] = p[0]
    out = out.at[0].set(p[0])

    out = jnp.append(out, testdata2)  # concatenation of nd arrays

    # functional equivalent of out[1:] = out[1:] + p
    out = out.at[1:].set(out[1:] + p)

    return sum(out)  # deliberately the built-in sum rather than jnp.sum
Example #15
0
    def __call__(self, state, action):
        """Advance the system by one step of length ``self.dt``.

        The action is appended to the state and integrated with ``rk4`` over
        ``[0, self.dt]`` (rk4 rather than scipy's odeint, which was too slow).
        The last time point is kept, truncated to its first four entries, and
        then limited as follows:
        - entries 0-1 go through ``wrap(..., -pi, pi)``;
        - entries 2-3 go through ``bound`` with ``MAX_VEL_1`` / ``MAX_VEL_2``.

        Returns:
            ``(new_state, observation)`` where the observation is
            ``[cos s0, sin s0, cos s1, sin s1, s2, s3]``.
        """
        trajectory = rk4(self._dsdt, jnp.append(state, action), [0, self.dt])
        # Final integration time only; drop the trailing action entries.
        new_state = trajectory[-1][:4]

        limiters = (
            (wrap, -jnp.pi, jnp.pi),
            (wrap, -jnp.pi, jnp.pi),
            (bound, -self.MAX_VEL_1, self.MAX_VEL_1),
            (bound, -self.MAX_VEL_2, self.MAX_VEL_2),
        )
        for idx, (limit, low, high) in enumerate(limiters):
            new_state = new_state.at[idx].set(limit(new_state[idx], low, high))

        observation = jnp.array([
            jnp.cos(new_state[0]),
            jnp.sin(new_state[0]),
            jnp.cos(new_state[1]),
            jnp.sin(new_state[1]),
            new_state[2],
            new_state[3],
        ])
        return new_state, observation
Example #16
0
def fixed_lag_stitching(ssm_scenario: StateSpaceModel,
                        early_block: jnp.ndarray,
                        t: float,
                        recent_block: jnp.ndarray,
                        recent_block_log_weight: jnp.ndarray,
                        tplus1: float,
                        random_key: jnp.ndarray,
                        maximum_rejections: int,
                        init_bound_param: float,
                        bound_inflation: float) -> Tuple[jnp.ndarray, int]:
    """Join a recent particle block onto an early block.

    Stitching indices are computed from:
    - the last slice of ``early_block``;
    - the second slice of ``recent_block``;
    - the recent log weights plus the transition potential between
      ``recent_block[0]`` and ``recent_block[1]``.

    ``rejection_stitching`` is used when ``maximum_rejections > 0``; otherwise
    ``full_stitch`` is used, with a cost of ``len(early_block[-1]) ** 2``
    transition evaluations.

    Returns:
        ``(stitched, num_transition_evals)`` where ``stitched`` is ``early_block``
        concatenated (axis 0) with ``recent_block[1:]`` reindexed along axis 1.
    """
    fixed_prev = early_block[-1]
    vary_prev = recent_block[0]
    vary_next = recent_block[1]

    transition_logpot = vmap(ssm_scenario.transition_potential,
                             (0, None, 0, None))(vary_prev, t, vary_next, tplus1)
    non_interacting_log_weight = recent_block_log_weight + transition_logpot

    def _rejection_branch(operands):
        return rejection_stitching(ssm_scenario, *operands,
                                   maximum_rejections=maximum_rejections,
                                   init_bound_param=init_bound_param,
                                   bound_inflation=bound_inflation)

    def _full_branch(operands):
        return full_stitch(ssm_scenario, *operands), len(fixed_prev) ** 2

    operands = (fixed_prev, t, vary_next, tplus1,
                non_interacting_log_weight, random_key)
    stitched_inds, num_transition_evals = cond(maximum_rejections > 0,
                                               _rejection_branch,
                                               _full_branch,
                                               operands)

    stitched = jnp.append(early_block, recent_block[1:, stitched_inds], axis=0)
    return stitched, num_transition_evals
Example #17
0
    def simulate(self, t_all: jnp.ndarray, random_key: jnp.ndarray) -> cdict:
        """Simulate a latent path and its observations at the times ``t_all``.

        ``random_key`` is split into one latent key and one observation key per
        time. The initial state comes from ``initial_sample``, later states come
        from ``transition_sample`` via ``scan``, and observations come from a
        vmapped ``likelihood_sample``.

        Returns:
            ``cdict`` with fields ``x`` (latent path), ``y``, ``t`` and ``name``.
        """
        num_times = len(t_all)

        keys = random.split(random_key, 2 * num_times)
        latent_keys = keys[:num_times]
        obs_keys = keys[num_times:]

        x_first = self.initial_sample(t_all[0], latent_keys[0])

        def step(x_prev, idx):
            x_next = self.transition_sample(x_prev, t_all[idx - 1], t_all[idx],
                                            latent_keys[idx])
            return x_next, x_next

        _, x_rest = scan(step, x_first, jnp.arange(1, num_times))
        x_all = jnp.concatenate([x_first[jnp.newaxis], x_rest], axis=0)

        y = vmap(self.likelihood_sample)(x_all, t_all, obs_keys)

        return cdict(x=x_all,
                     y=y,
                     t=t_all,
                     name=f'{self.name} simulation')
Example #18
0
    def model(T=10, q=1, r=1, phi=0.0, beta=0.0):
        """AR(1) latent ``x`` with an accumulated mean ``mu`` and noisy observations ``y``.

        Sample sites:
        - ``x_0`` / ``y_0`` for the first step;
        - ``x`` / ``y`` inside ``scan`` for ``T`` further steps;
        - a deterministic site ``y2 = 2 * y``.

        NOTE(review): sampling inside scan needs numpyro's control-flow ``scan``;
        confirm which ``scan`` is imported in this module.

        Returns:
            ``(x, y)``, each of length ``T + 1`` (initial value followed by scan outputs).
        """
        def transition(carry, _step):
            x_prev, mu_prev = carry
            x_curr = numpyro.sample("x", dist.Normal(phi * x_prev, q))
            mu_curr = beta * mu_prev + x_curr
            y_curr = numpyro.sample("y", dist.Normal(mu_curr, r))
            numpyro.deterministic("y2", y_curr * 2)
            return (x_curr, mu_curr), (x_curr, y_curr)

        x0 = numpyro.sample("x_0", dist.Normal(0, q))
        mu0 = x0
        y0 = numpyro.sample("y_0", dist.Normal(mu0, r))

        _, (xs, ys) = scan(transition, (x0, mu0), jnp.arange(T))

        return jnp.append(x0, xs), jnp.append(y0, ys)
    def read_process(self, vector: jnp.array):
        """Combine source/destination node tags into one flat vector.

        ``vector`` is indexed as ``vector[i][j]``. Rows 2-3, columns 1-3 form a
        2x3 tag block: rows are (srcNode, desNode), columns are (sTag, iTag, cTag).
        Elements that are torch ``Tensor`` or numpy ``ndarray`` objects are
        converted to Python floats in place first.

        Returns:
            1-D array of 12 values: the tag block flattened, followed by a copy in
            which the source row's iTag/cTag are replaced by the column-wise
            minimum over both rows.
        """
        # Replace tensor / ndarray elements with plain floats (mutates `vector`).
        for i, l in enumerate(vector):
            for j, t in enumerate(l):
                if isinstance(t, Tensor):
                    vector[i][j] = float(t.cpu().detach().numpy())
                elif isinstance(t, np.ndarray):
                    vector[i][j] = float(t)
        vector = vector.astype('float64')
        # 2x4 @ 4x4 @ 4x3 -> 2x3: selects rows 2-3 and columns 1-3 of `vector`.
        left_matrix = jnp.array([[0, 0, 1, 0], [0, 0, 0, 1]])
        right_matrix = jnp.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])

        tags = jnp.dot(jnp.dot(left_matrix, vector), right_matrix)
        # `.at[...].set` replaces the removed jax.ops.index_update.
        final_tags = tags.at[0, 1:3].set(jnp.min(tags[:, 1:3], axis=0))

        # jnp.append without `axis` flattens both operands, so no reshape is needed.
        return jnp.append(tags, final_tags)
Example #20
0
        def _dynamics(state, action):
            """One environment step of length ``self.dt``; also increments ``self.nsamples``.

            Returns the new 4-element state. Entries 0-1 are passed through
            ``wrap(..., -pi, pi)`` and entries 2-3 through ``bound`` with the
            ``MAX_VEL_1`` / ``MAX_VEL_2`` limits.
            """
            self.nsamples += 1
            # Augment the state with our force action so it can be passed to _dsdt
            augmented_state = jnp.append(state, action)

            # rk4 is used because scipy's odeint was too slow here.
            new_state = rk4(self._dsdt, augmented_state, [0, self.dt])
            # only care about final timestep of integration returned by integrator
            new_state = new_state[-1]
            new_state = new_state[:4]  # omit action

            # `.at[idx].set` is the functional update that replaces the
            # deprecated-and-removed jax.ops.index_update.
            new_state = new_state.at[0].set(wrap(new_state[0], -pi, pi))
            new_state = new_state.at[1].set(wrap(new_state[1], -pi, pi))
            new_state = new_state.at[2].set(
                bound(new_state[2], -self.MAX_VEL_1, self.MAX_VEL_1))
            new_state = new_state.at[3].set(
                bound(new_state[3], -self.MAX_VEL_2, self.MAX_VEL_2))

            return new_state
Example #21
0
def runge_kutta_step(func, y0, f0, t0, dt, tableau=None):
    """Take an explicit Runge-Kutta step and estimate its error.

    Args:
        func: time derivative, called as ``func(y, t)``.
        y0: initial value for the state.
        f0: initial value for the derivative, i.e. ``func(y0, t0)``.
        t0: initial time.
        dt: time step.
        tableau: optional Butcher tableau ``(alpha, beta, c_sol, c_error)``:
            - ``alpha[i]``: time fraction of stage ``i + 1``;
            - ``beta[i]``: its weights on the ``i + 1`` stages computed so far;
            - ``c_sol`` / ``c_error``: solution / error-estimate weights.
            When omitted, the module-level ``alpha``, ``beta``, ``c_sol`` and
            ``c_error`` are used.

    Returns:
        y1: estimated state at ``t1 = t0 + dt``.
        f1: the last stage derivative ``k[-1]`` (the derivative at ``t1`` for
            first-same-as-last tableaus).
        y1_error: estimated error at ``t1``.
        k: array of stage derivatives, one row per stage, ``f0`` first.
    """
    if tableau is None:
        tableau = (alpha, beta, c_sol, c_error)
    stage_times, stage_weights, sol_weights, err_weights = tableau

    k = np.array([f0])
    for alpha_i, beta_i in zip(stage_times, stage_weights):
        ti = t0 + dt * alpha_i
        yi = y0 + dt * np.dot(k.T, beta_i)
        ft = func(yi, ti)
        k = np.append(k, np.array([ft]), axis=0)

    y1 = dt * np.dot(sol_weights, k) + y0
    y1_error = dt * np.dot(err_weights, k)
    f1 = k[-1]
    return y1, f1, y1_error, k
Example #22
0
def _forecast(future, sample, Z_exp, n_coefs):
    """Register ``future`` autoregressive forecast sample sites ``yf[0] .. yf[future-1]``.

    Each step draws ``Normal(beta . last n_coefs values, sample['tau'])`` and
    appends the draw to the running history. Nothing is returned; the draws
    exist only as numpyro sample sites.
    """
    coef = sample['beta']
    history = Z_exp[-n_coefs:]
    for step in range(future):
        mean = np.dot(coef, history[-n_coefs:])
        draw = numpyro.sample("yf[{}]".format(step), dist.Normal(mean, sample['tau']))
        history = np.append(history, draw)
Example #23
0
def moco_loss(emb_query, emb_key, moco_dictionary, temperature):
    """Compute the per-sample MoCo loss.

    Args:
        emb_query: embeddings predicted by the query network, (n_samples, dim).
        emb_key: embeddings predicted by the key network, same shape as emb_query.
        moco_dictionary: embeddings from prior epochs, (n_codes, dim).
        temperature: softmax temperature dividing every logit.

    Returns:
        Array of shape (n_samples,): softmax cross-entropy of the positive
        (query, key) pair against the dictionary negatives.
    """
    # Similarity with the matching key -> (n_samples, 1)
    positive = jnp.sum(emb_query * emb_key, axis=1, keepdims=True) / temperature
    # Similarity with every dictionary entry -> (n_samples, n_codes)
    negatives = jnp.dot(emb_query, moco_dictionary.T) / temperature

    # The positive logit sits in column 0, so its negative log-probability is the loss.
    logits = jnp.concatenate([positive, negatives], axis=1)
    log_probs = jax.nn.log_softmax(logits)
    return -log_probs[:, 0]
Example #24
0
    def add_nd(self, params):
        """Adds a new hidden node to network.

        Inserts a node represented by `params` at the beginning of the
        hidden layer. Regarding the weighted adjacency matrix, arcs are
        assigned in row-major order such that it preserves topological
        ordering.

        Concretely:
        - ``params[:inb]``, zero-padded by ``hid``, becomes a new column 0 of ``self.θ``;
        - ``params[inb:]``, with one leading zero, becomes a new row at index ``inb``;
        - ``self.shape[1]`` grows by one and a 0 is appended to ``self.v``.

        Args:
            params: Sequence of weights; its length must equal ``sum(self.shape)``.

        Raises:
            AttributeError: If insufficient number of weights is given.
        """
        n_in, n_hidden, _ = self.shape
        n_required = np.sum(self.shape)

        if n_required != len(params):
            msg = "{} weights required to add a node to a {} network, got {}."
            raise AttributeError(msg.format(n_required, self.shape, len(params)))

        # NOTE(review): jax.ops.index_add has been removed from recent JAX
        # releases; `.at[1].add(1)` is the replacement if self.shape is a jax array.
        self.shape = jo.index_add(self.shape, 1, 1)
        self.v = np.append(self.v, 0)

        new_col = np.pad(params[:n_in], (0, n_hidden), constant_values=0)
        new_row = np.pad(params[n_in:], (1, 0), constant_values=0)
        self.θ = mo.insert(self.θ, 0, new_col, axis=1)
        self.θ = mo.insert(self.θ, n_in, new_row, axis=0)
Example #25
0
 def mods(self, m, w, c, t, wt, data_phif0, data_phi, data_f, mc_phif0,
          mc_phi, mc_f):
     """Weighted objective comparing model amplitudes on data and MC samples.

     The four parameter vectors are flattened together and split back into equal
     rows ``(f0m, f0w, const, theta)``; ``const`` gains a second row of ones.
     ``self.MOD`` is evaluated on the data and on the MC inputs.

     Returns:
         ``-sum(wt * (log(d) - log(mean(m)))) / log(sum(wt))``, where ``d`` is the
         per-row sum of ``dplex.dabs`` on the data and ``m`` the same on MC.
     """
     packed = np.append(np.append(np.append(m, w), c), t)
     f0m, f0w, const_row, theta = packed.reshape(4, -1)
     const = np.append(const_row, np.ones(const_row.shape)).reshape(2, -1)

     d_phif0 = self.MOD(f0m, f0w, const, theta, data_phif0, data_phi,
                        data_f)
     m_phif0 = self.MOD(f0m, f0w, const, theta, mc_phif0, mc_phi, mc_f)

     data_amp = np.sum(dplex.dabs(d_phif0), axis=1)
     mc_norm = np.average(np.sum(dplex.dabs(m_phif0), axis=1))
     total_weight = np.sum(wt)
     return -np.sum(wt * (np.log(data_amp) - np.log(mc_norm))) / np.log(total_weight)
Example #26
0
 def nat_to_probability(self) -> RealArray:
     """Map log-odds ``q = self.log_odds`` to a full probability vector.

     Along the last axis the result is
     ``[exp(q) / (1 + sum(exp(q))), 1 - sum of those]``. The computation shifts
     by ``max(0, max(q))`` for numerical stability.
     """
     q = self.log_odds
     shift = jnp.maximum(0.0, jnp.amax(q, axis=-1))
     shifted = q - shift[..., np.newaxis]
     # log(exp(-shift) + sum(exp(shifted))) == log(1 + sum(exp(q))) - shift
     log_norm = jnp.logaddexp(-shift, jss.logsumexp(shifted, axis=-1))
     probs = jnp.exp(shifted - log_norm[..., np.newaxis])
     remainder = 1.0 - jnp.sum(probs, axis=-1, keepdims=True)
     return jnp.concatenate([probs, remainder], axis=-1)
Example #27
0
 def __apply_Q_smaller(params, state, action, target):
     """Minimum over the outputs of the two Q heads for ``target``, shape ``(batch, 1)``.

     Heads ``2 * target`` and ``2 * target + 1`` are evaluated. Their outputs are
     joined along axis 1 and reduced with ``min``.
     """
     head_offset = 2 * int(target)
     q_first = SharedNetwork.apply_Q(params, state, action, head_offset)
     q_second = SharedNetwork.apply_Q(params, state, action, head_offset + 1)
     q_both = jnp.append(q_first, q_second, axis=1)
     q_t = q_both.min(axis=1, keepdims=True)
     assert (q_t.shape == (state.shape[0], 1))
     return q_t
def log_joint(Y, X, model_params, ca_obj, n_lats, nlags, Fourier=False, nxcirc=None, wwnrm=None, Bf=None, learn_model_params=False, learn_per_neuron=False):
    """Log joint of calcium traces ``Y`` and latents ``X`` (Fourier domain).

    The result is ``log_prior + ll``:
    - ``log_prior`` is the GP prior of ``X`` under a diagonal Fourier-domain
      covariance built from the length scales;
    - ``ll`` is ``ca_obj.log_likelihood`` of ``Y`` given the rates
      ``loadings @ (X @ Bf)``.

    ``model_params`` layout:
    - ``n_neurons * n_lats`` loadings;
    - then ``n_lats`` length scales;
    - then, when ``learn_model_params``, calcium parameters. These are either
      per neuron (``learn_per_neuron``) or shared and tiled across neurons, and
      are appended to the rates as extra columns.

    Args:
        Y: observations; ``Y.shape[0]`` is the number of neurons.
        X: latents, reshaped to ``(n_lats, -1)``.
        Fourier: must be True (see Raises).
        nxcirc, wwnrm, Bf: Fourier-domain arguments for the covariance and basis.

    Raises:
        NotImplementedError: if ``Fourier`` is False. That path never computed the
            rates passed to the likelihood (it always ended in UnboundLocalError),
            so it is rejected explicitly.
    """
    if not Fourier:
        # The time-domain path was only sketched: it built a prior with
        # make_cov(ca_obj.Tps, model_params[0], model_params[1]) but never
        # produced `rates`, so ca_obj.log_likelihood could not be evaluated.
        raise NotImplementedError(
            "log_joint only supports Fourier=True; the time-domain path "
            "does not compute rates")

    X = np.reshape(X, [n_lats, -1])
    n_neurons = np.shape(Y)[0]
    n_load = n_neurons * n_lats
    loadings_hat = np.reshape(model_params[0:n_load], [n_neurons, n_lats])
    ls_hat = model_params[n_load:n_load + n_lats]

    K = gpf.mkcovs.mkcovdiag_ASD_wellcond(ls_hat, np.ones(np.size(ls_hat)), nxcirc, wwnrm=wwnrm, addition=1e-4).T
    if n_lats == 1:
        K = np.expand_dims(K, axis=0)
    log_prior = calc_gp_prior(X, K, n_lats, Fourier=True)

    params = np.matmul(X, Bf)
    rates = loadings_hat @ params

    if learn_model_params:
        calcium_params = model_params[n_load + n_lats:]
        if learn_per_neuron:
            param_butt = np.reshape(calcium_params, [n_neurons, -1])
        else:
            param_butt = np.tile(calcium_params, [n_neurons, 1])
        rates = np.append(rates, np.array(param_butt), axis=1)

    ll = ca_obj.log_likelihood(Y, rates, nlags, learn_model_params=learn_model_params)

    return log_prior + ll
def cost_func_jvp(bb, u, func=None):
    """Directional derivatives of ``func`` along each coordinate of ``bb``.

    For every ``i`` in ``range(bb.size)``, a forward-mode JVP is taken with
    tangent ``e_i`` for ``bb`` and a zero tangent for ``u``. The output tangents
    are flattened and concatenated in order.

    Args:
        bb: primal input whose coordinates are perturbed (``bb.size`` entries).
        u: second primal input; its tangent is always zero.
        func: function of ``(bb, u)`` to differentiate. Defaults to the
            module-level ``cost_func``.

    Returns:
        1-D array in which block ``i`` is the derivative of ``func`` w.r.t.
        ``bb[i]``, flattened.
    """
    f = cost_func if func is None else func
    n = bb.size
    u_tangent = jnp.zeros_like(u)
    # Collect pieces and concatenate once instead of re-appending each iteration.
    pieces = [jnp.empty([0])]
    for i in range(n):
        seed = jnp.zeros(n).at[i].set(1)  # replaces the removed jax.ops.index_update
        _, tangent_out = jax.jvp(f, (bb, u), (seed, u_tangent))
        pieces.append(jnp.ravel(tangent_out))
    return jnp.concatenate(pieces)
Example #30
0
def arrayize_nn_params(nn_params):
    """
    Turn an iterable of ``(W, b)`` layer tuples into one 1-D array.

    Each layer contributes ``[W | b[:, None]]`` flattened row-major, i.e. every
    row of ``W`` followed by its bias entry. Layers are concatenated in order.
    An empty iterable yields an empty array.
    """
    # Iterate nn_params exactly once so generator inputs are not exhausted early.
    combined_wbs = [
        jnp.append(w, b[..., jnp.newaxis], axis=1) for w, b in nn_params
    ]
    if not combined_wbs:
        return jnp.zeros(0)
    return jnp.concatenate([a.flatten() for a in combined_wbs])