Example #1
    def apply_fun(params, inputs):
        conv_params, pair_params, conv_block_params, serial_params = params

        # Apply the primary convolutional layer.
        conv_out = conv_apply(conv_params, inputs)
        conv_out = relu(conv_out)

        # Group all possible pairs: each conv below uses rhs_dilation (1, k),
        # so the dilated kernel pairs features k positions apart.
        W, b = pair_params
        pair_1 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,1), dim_nums) + b
        pair_2 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,2), dim_nums) + b
        pair_3 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,3), dim_nums) + b
        pair_4 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,4), dim_nums) + b
        pair_5 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,5), dim_nums) + b
        pair_out = jnp.dstack([pair_1, pair_2, pair_3, pair_4, pair_5])
        pair_out = relu(pair_out)

        # Convolutional block.
        conv_block_out = conv_block_apply(conv_block_params, pair_out)

        # Residual connection.
        res_out = conv_block_out + pair_out
        res_out = relu(res_out)

        # Forward pass.
        out = serial_apply(serial_params, res_out)
        return out
Example #2
def rgb_to_hsv(rgb_image):
  # Adapted from the numpy implementation here: https://gist.github.com/PolarNick239/691387158ff1c41ad73c#file-rgb_to_hsv_np-py
  input_shape = rgb_image.shape
  rgb_image = rgb_image.reshape(-1, 3)
  r, g, b = rgb_image[:, 0], rgb_image[:, 1], rgb_image[:, 2]

  maxc = jnp.maximum(jnp.maximum(r, g), b)
  minc = jnp.minimum(jnp.minimum(r, g), b)
  v = maxc

  deltac = maxc - minc
  # Epsilon keeps the division defined when maxc == 0.
  s = deltac / (maxc + 1e-9)

  deltac = jnp.where(deltac == 0, 1, deltac)
  # The epsilon below bounds the denominator away from zero. The exact-equality
  # where above misses tiny nonzero deltac values, whose reciprocals can
  # overflow and turn the gradients into NaNs (see the sketch after this
  # function).
  rc = (maxc - r) / (deltac + 1e-9)
  gc = (maxc - g) / (deltac + 1e-9)
  bc = (maxc - b) / (deltac + 1e-9)

  # Hue depends on which channel attains the max; achromatic pixels
  # (minc == maxc) get hue 0. The branch values are already NaN-safe, so no
  # nested jnp.where guards are needed here.
  h = 4.0 + gc - rc
  h = jnp.where(g == maxc, 2.0 + rc - bc, h)
  h = jnp.where(r == maxc, bc - gc, h)
  h = jnp.where(minc == maxc, 0.0, h)

  h = (h / 6.0) % 1.0
  res = jnp.dstack([h, s, v])
  return res.reshape(input_shape)
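The NaN issue flagged in the example relates to a standard JAX pitfall: jnp.where differentiates both branches, so a division that blows up in the untaken branch can still poison gradients. A minimal sketch of the usual double-jnp.where fix (the names `unsafe`/`safe` are illustrative, not from this example):

import jax
import jax.numpy as jnp

def unsafe(x):
    # Forward pass is fine, but the 1/x branch is differentiated even where
    # it is not selected, and 0 * inf = nan in the backward pass.
    return jnp.where(x == 0.0, 0.0, 1.0 / x)

def safe(x):
    # The inner where keeps the untaken branch finite, so its zeroed
    # cotangent stays zero instead of becoming 0 * inf.
    return jnp.where(x == 0.0, 0.0, 1.0 / jnp.where(x == 0.0, 1.0, x))

print(jax.grad(unsafe)(0.0))  # nan
print(jax.grad(safe)(0.0))    # 0.0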
Example #3
 def sample(rng, params, num_samples=1):
     # `means`, `covariances`, and `weights` come from the enclosing scope;
     # random.categorical treats `weights` as unnormalized log-probabilities.
     cluster_samples = []
     for mean, cov in zip(means, covariances):
         rng, temp_rng = random.split(rng)
         cluster_sample = random.multivariate_normal(
             temp_rng, mean, cov, (num_samples, ))
         cluster_samples.append(cluster_sample)
     # Stack per-cluster draws to (num_samples, dim, n_clusters).
     samples = np.dstack(cluster_samples)
     idx = random.categorical(rng, weights, shape=(num_samples, 1, 1))
     return np.squeeze(np.take_along_axis(samples, idx, -1))
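A small shape sketch of the stack-then-select pattern above (sizes are made up): the clusters land on the last axis of a (num_samples, dim, n_clusters) array, and take_along_axis picks one cluster per row by broadcasting the index array.

import jax.numpy as jnp

samples = jnp.arange(30.0).reshape(5, 2, 3)  # 5 draws, 2-D, 3 clusters
idx = jnp.ones((5, 1, 1), dtype=jnp.int32)   # pick cluster 1 everywhere
picked = jnp.squeeze(jnp.take_along_axis(samples, idx, -1))
print(picked.shape)  # (5, 2)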
Example #4
  def torsionVecs_(self, P):
      p0 = P[0]
      p1 = P[1]
      p2 = P[2]
      p3 = P[3]

      r1 = p0 - p1
      r2 = p1 - p2
      r3 = p3 - p2
      cp_12 = np.cross(r1, r2)
      cp_32 = np.cross(r3, r2)
      return np.dstack((cp_12, np.zeros(cp_12.shape), cp_32)) \
        .squeeze() \
        .transpose([1, 0])
Example #5
def contour_grid(xmin, xmax, ymin, ymax, n_x, n_y, n_importance_samples=None):
    # Honor the requested grid resolution instead of a hard-coded 100.
    x_range, y_range = jnp.linspace(xmin, xmax,
                                    n_x), jnp.linspace(ymin, ymax, n_y)
    X, Y = jnp.meshgrid(x_range, y_range)
    XY = jnp.dstack([X, Y]).reshape((-1, 2))

    if n_importance_samples is not None:
        XY = jnp.broadcast_to(XY[None, ...],
                              (n_importance_samples, ) + XY.shape)

    def reshape_to_grid(Z):
        return Z.reshape(X.shape)

    return X, Y, XY, reshape_to_grid
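A hypothetical call (values illustrative): the helper returns the meshgrid pair, the flattened coordinate array, and a closure that folds flat model outputs back onto the grid.

import jax.numpy as jnp

X, Y, XY, reshape_to_grid = contour_grid(-1.0, 1.0, -1.0, 1.0, 50, 40)
print(X.shape, XY.shape)  # (40, 50) (2000, 2) with meshgrid's default "xy" indexing
Z = reshape_to_grid(jnp.sum(XY ** 2, axis=-1))  # (40, 50), ready for contourf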
Example #6
  def torsionVecs(self, P):
      # Fancy indexing picks point k's (x, y, z) coordinates: shape (..., 3).
      p0 = P[...,[0],[0,1,2]]
      p1 = P[...,[1],[0,1,2]]
      p2 = P[...,[2],[0,1,2]]
      p3 = P[...,[3],[0,1,2]]

      r1 = p0 - p1
      r2 = p1 - p2
      r3 = p3 - p2
      cp_12 = np.cross(r1, r2)
      cp_32 = np.cross(r3, r2)
      return np.dstack((cp_12, np.zeros(cp_12.shape), cp_32)) \
        .squeeze() \
        .transpose([0, 2, 1])
Example #7
def generate_image(
    height,
    width,
    scene_camera,
    world,
    config,
):
    """Generates an image of dimensions (height x width x 3) from the given camera."""
    def process_pixel(
        position,
        num_samples,
        rng,
    ):
        j, i = position

        def get_color_at_sample(u, v, sample_rng):
            ray = scene_camera.get_ray(u, v)
            return compute_color_fn(ray, rng=sample_rng).array()

        pixel_rng = jax.random.fold_in(rng, width * i + j)
        pixel_rng, i_rng, j_rng = jax.random.split(pixel_rng, num=3)

        # Random samples for anti-aliasing.
        random_is = jax.random.uniform(i_rng, shape=(num_samples, ))
        random_js = jax.random.uniform(j_rng, shape=(num_samples, ))

        us = (j + random_js) / width
        vs = (i + random_is) / height
        sample_rngs = jax.random.split(pixel_rng, num=num_samples)

        colors = jax.vmap(get_color_at_sample)(us, vs, sample_rngs)
        colors = jnp.mean(colors, axis=0)
        return colors

    num_samples = config.num_antialiasing_samples
    rng = jax.random.PRNGKey(config.rng_seed)

    compute_color_fn = functools.partial(compute_color,
                                         world=world,
                                         config=config)
    process_pixel_fn = functools.partial(process_pixel,
                                         num_samples=num_samples,
                                         rng=rng)
    process_pixel_fn = jax.vmap(jax.vmap(process_pixel_fn))

    grid = jnp.dstack(jnp.meshgrid(jnp.arange(width), jnp.arange(height)))
    image = process_pixel_fn(grid)
    return image
Example #8
def _dstack_product(x, y):
    """Returns the cartesian product of the elements of x and y vectors.

  Args:
    x: 1d array
    y: 1d array of the same dtype as x.

  Returns:
    a 2D array containing the elements of [x]x[y].
  Example:
    x = jnp.array([1, 2, 3])
    y = jnp.array([4, 5])

    _dstack_product(x, y)
    >>> [[1, 4], [1, 5], [2, 4], [2, 5], [3, 4], [3, 5]]
  """
    return jnp.dstack(jnp.meshgrid(x, y, indexing="ij")).reshape(-1, 2)
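The row ordering depends on meshgrid's `indexing` argument, which is why the docstring output lists x varying slowest; a quick comparison (not from the source):

import jax.numpy as jnp

x = jnp.array([1, 2, 3])
y = jnp.array([4, 5])

# "ij" indexing (used by _dstack_product): x varies slowest.
print(jnp.dstack(jnp.meshgrid(x, y, indexing="ij")).reshape(-1, 2))
# [[1 4] [1 5] [2 4] [2 5] [3 4] [3 5]]

# Default "xy" indexing: x varies fastest.
print(jnp.dstack(jnp.meshgrid(x, y)).reshape(-1, 2))
# [[1 4] [2 4] [3 4] [1 5] [2 5] [3 5]]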
Example #9
def generate_grid(key, n_samples, min_val, max_val, n_clusters_per_axis):
    x, y = jnp.linspace(min_val, max_val, n_clusters_per_axis), jnp.linspace(
        min_val, max_val, n_clusters_per_axis)
    X, Y = jnp.meshgrid(x, y)
    xy = jnp.dstack([X, Y]).reshape((-1, 2))

    # Repeat the data so that we can add noise to different copies
    n_repeats = n_samples // (n_clusters_per_axis**2)
    data = jnp.repeat(xy, repeats=n_repeats, axis=0)

    # Add just enough noise so that we see each cluster without overlapping
    std = (max_val - min_val) / n_clusters_per_axis * 0.25

    # Use independent keys for the noise and the permutation; reusing `key`
    # for both draws would correlate them.
    noise_key, perm_key = random.split(key)
    noise = random.normal(noise_key, data.shape) * std
    data += noise

    data = random.permutation(perm_key, data)
    return data
Example #10
    def predict(self, X, y=None, p=None):
        """

        Parameters
        ==========

        X : array_like, shape (n_samples, n_features)
            Stimulus design matrix.

        y : None or array_like, shape (n_samples, )
            Recorded response. Needed when post-spike filter is fitted.

        p : None or dict
            Model parameters. Only needed if model performance is monitored
            during training.

        """

        if self.n_c > 1:
            XS = jnp.dstack([X[:, :, i] @ self.S for i in range(self.n_c)
                             ]).reshape(X.shape[0], -1)
        else:
            XS = X @ self.S

        extra = {'X': X, 'XS': XS, 'y': y}

        if self.h_spl is not None:

            if y is None:
                raise ValueError(
                    '`y` is needed for calculating response history.')

            yh = jnp.array(
                build_design_matrix(extra['y'][:, jnp.newaxis],
                                    self.Sh.shape[0],
                                    shift=self.shift_h))
            yS = yh @ self.Sh
            extra.update({'yS': yS})

        params = self.p_opt if p is None else p
        y_pred = self.forwardpass(params, extra=extra)

        return y_pred
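A shape sketch of the multi-channel branch above (sizes illustrative): dstack puts the channels on a third axis, and the reshape lays them out side by side with the channel index varying fastest.

import jax.numpy as jnp

A = jnp.zeros((4, 3))  # stands in for X[:, :, 0] @ S
B = jnp.ones((4, 3))   # stands in for X[:, :, 1] @ S
stacked = jnp.dstack([A, B])         # (4, 3, 2)
print(stacked.reshape(4, -1).shape)  # (4, 6)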
Example #11
def make_gradient_field(function,
                        xrange=(-1, 2),
                        yrange=(-1, 2),
                        n_points=30,
                        shape=(2, 1)):
    W = jnp.linspace(*xrange, n_points)
    B = jnp.linspace(*yrange, n_points)
    U, V = jnp.meshgrid(W, B)
    pairs = jnp.dstack([U, V]).reshape(-1, *shape)

    vectorized_fun = jit(vmap(function))
    Z = vectorized_fun(pairs).reshape(n_points, n_points)

    grad_fun = jit(vmap(grad(function)))
    gradvals = grad_fun(pairs)

    gradx = gradvals[:, 0].reshape(n_points, n_points)
    grady = gradvals[:, 1].reshape(n_points, n_points)

    gradnorm = jnp.sqrt(gradx**2 + grady**2)

    return U, V, Z, pairs, gradvals, gradx, grady, gradnorm
Example #12
    def _sample_next(sampler, machine, parameters: PyTree,
                     state: MetropolisPtSamplerState):
        new_rng, rng = jax.random.split(state.rng)
        # def cbr(data):
        #    new_rng, rng = data
        #    print("sample_next newrng:\n", new_rng,  "\nand rng:\n", rng)
        #    return new_rng
        # new_rng = hcb.call(
        #   cbr,
        #   (new_rng, rng),
        #   result_shape=jax.ShapeDtypeStruct(new_rng.shape, new_rng.dtype),
        # )

        with loops.Scope() as s:
            s.key = rng
            s.σ = state.σ
            s.log_prob = sampler.machine_pow * machine(parameters,
                                                       state.σ).real
            s.beta = state.beta

            # for logging
            s.beta_0_index = state.beta_0_index
            s.n_accepted_per_beta = state.n_accepted_per_beta
            s.beta_position = state.beta_position
            s.beta_diffusion = state.beta_diffusion

            for i in s.range(sampler.n_sweeps):
                # 1 to propagate for next iteration, 1 for uniform rng and n_chains for transition kernel
                s.key, key1, key2, key3, key4 = jax.random.split(s.key, 5)

                # def cbi(data):
                #    i, beta = data
                #    print("sweep #", i, " for beta=\n", beta)
                #    return beta
                # beta = hcb.call(
                #   cbi,
                #   (i, s.beta),
                #   result_shape=jax.ShapeDtypeStruct(s.beta.shape, s.beta.dtype),
                # )
                beta = s.beta

                σp, log_prob_correction = sampler.rule.transition(
                    sampler, machine, parameters, state, key1, s.σ)
                proposal_log_prob = sampler.machine_pow * machine(
                    parameters, σp).real

                uniform = jax.random.uniform(key2, shape=(sampler.n_batches, ))
                if log_prob_correction is not None:
                    do_accept = uniform < jnp.exp(
                        beta.reshape((-1, )) *
                        (proposal_log_prob - s.log_prob + log_prob_correction))
                else:
                    do_accept = uniform < jnp.exp(
                        beta.reshape(
                            (-1, )) * (proposal_log_prob - s.log_prob))

                # do_accept must match ndim of proposal and state (which is 2)
                s.σ = jnp.where(do_accept.reshape(-1, 1), σp, s.σ)
                n_accepted_per_beta = s.n_accepted_per_beta + do_accept.reshape(
                    (sampler.n_chains, sampler.n_replicas))

                s.log_prob = jax.numpy.where(do_accept.reshape(-1),
                                             proposal_log_prob, s.log_prob)

                # exchange betas

                # randomly decide if every set of replicas should be swapped in even or odd order
                swap_order = jax.random.randint(
                    key3,
                    minval=0,
                    maxval=2,
                    shape=(sampler.n_chains, ),
                )  # 0 or 1
                iswap_order = jnp.mod(swap_order + 1, 2)  #  1 or 0

                # indices of even swapped elements (per-row)
                idxs = jnp.arange(0, sampler.n_replicas, 2).reshape(
                    (1, -1)) + swap_order.reshape((-1, 1))
                # indices of odd swapped elements (per-row)
                inn = (idxs + 1) % sampler.n_replicas

                # for every row of the input, swap elements at idxs with elements at inn
                @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
                def swap_rows(beta_row, idxs, inn):
                    proposed_beta = jax.ops.index_update(
                        beta_row,
                        idxs,
                        beta_row[inn],
                        unique_indices=True,
                        indices_are_sorted=True,
                    )
                    proposed_beta = jax.ops.index_update(
                        proposed_beta,
                        inn,
                        beta_row[idxs],
                        unique_indices=True,
                        indices_are_sorted=False,
                    )
                    return proposed_beta

                proposed_beta = swap_rows(beta, idxs, inn)

                @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
                def compute_proposed_prob(prob, idxs, inn):
                    prob_rescaled = prob[idxs] + prob[inn]
                    return prob_rescaled

                # compute the probability of the swaps
                log_prob = (proposed_beta - state.beta) * s.log_prob.reshape(
                    (sampler.n_chains, sampler.n_replicas))

                prob_rescaled = jnp.exp(
                    compute_proposed_prob(log_prob, idxs, inn))


                uniform = jax.random.uniform(key4,
                                             shape=(sampler.n_chains,
                                                    sampler.n_replicas // 2))

                do_swap = uniform < prob_rescaled

                do_swap = jnp.dstack((do_swap, do_swap)).reshape(
                    (-1, sampler.n_replicas))  #  concat along last dimension
                # roll if swap_order is odd
                @partial(jax.vmap, in_axes=(0, 0), out_axes=0)
                def fix_swap(do_swap, swap_order):
                    return jax.lax.cond(swap_order == 0, lambda x: x,
                                        lambda x: jnp.roll(x, 1), do_swap)

                do_swap = fix_swap(do_swap, swap_order)
                # jax.experimental.host_callback.id_print(state.beta)
                # jax.experimental.host_callback.id_print(proposed_beta)

                new_beta = jax.numpy.where(do_swap, proposed_beta, beta)

                def cb(data):
                    _bt, _pbt, new_beta, so, do_swap, log_prob, prob = data
                    print("--------.---------.---------.--------")
                    print("     cur beta:\n", _bt)
                    print("proposed beta:\n", _pbt)
                    print("     new beta:\n", new_beta)
                    print("swaporder :", so)
                    print("do_swap :\n", do_swap)
                    print("log_prob;\n", log_prob)
                    print("prob_rescaled;\n", prob)
                    return new_beta

                # new_beta = hcb.call(
                #    cb,
                #    (
                #        beta,
                #        proposed_beta,
                #        new_beta,
                #        swap_order,
                #        do_swap,
                #        log_prob,
                #        prob_rescaled,
                #    ),
                #    result_shape=jax.ShapeDtypeStruct(new_beta.shape, new_beta.dtype),
                # )
                # Commit the accepted swaps so the next sweep sees the exchanged betas.
                s.beta = new_beta

                swap_order = swap_order.reshape(-1)

                beta_0_moved = jax.vmap(lambda do_swap, i: do_swap[i],
                                        in_axes=(0, 0),
                                        out_axes=0)(do_swap,
                                                    state.beta_0_index)
                proposed_beta_0_index = jnp.mod(
                    state.beta_0_index + (-jnp.mod(swap_order, 2) * 2 + 1) *
                    (-jnp.mod(state.beta_0_index, 2) * 2 + 1),
                    sampler.n_replicas,
                )

                s.beta_0_index = jnp.where(beta_0_moved, proposed_beta_0_index,
                                           s.beta_0_index)

                # swap acceptances
                swapped_n_accepted_per_beta = swap_rows(
                    n_accepted_per_beta, idxs, inn)
                s.n_accepted_per_beta = jax.numpy.where(
                    do_swap,
                    swapped_n_accepted_per_beta,
                    n_accepted_per_beta,
                )

                # Update statistics to compute diffusion coefficient of replicas
                # Total exchange steps performed
                delta = s.beta_0_index - s.beta_position
                s.beta_position = s.beta_position + delta / (
                    state.exchange_steps + i)
                delta2 = s.beta_0_index - s.beta_position
                s.beta_diffusion = s.beta_diffusion + delta * delta2

            new_state = state.replace(
                rng=new_rng,
                σ=s.σ,
                # n_accepted=s.accepted,
                n_samples=state.n_samples +
                sampler.n_sweeps * sampler.n_chains,
                beta=s.beta,
                beta_0_index=s.beta_0_index,
                beta_position=s.beta_position,
                beta_diffusion=s.beta_diffusion,
                exchange_steps=state.exchange_steps + sampler.n_sweeps,
                n_accepted_per_beta=s.n_accepted_per_beta,
            )

        offsets = jnp.arange(0, sampler.n_chains * sampler.n_replicas,
                             sampler.n_replicas)

        return new_state, new_state.σ[new_state.beta_0_index + offsets, :]
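The `jnp.dstack((do_swap, do_swap))` step above duplicates each per-pair swap decision onto both replicas of the pair; a tiny sketch with made-up sizes:

import jax.numpy as jnp

do_swap_pairs = jnp.array([[True, False]])  # 1 chain, 4 replicas -> 2 pairs
print(jnp.dstack((do_swap_pairs, do_swap_pairs)).reshape(-1, 4))
# [[ True  True False False]]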
Example #13
        def loop_body(i, s):
            # 1 to propagate for next iteration, 1 for uniform rng and n_chains for transition kernel
            s["key"], key1, key2, key3, key4 = jax.random.split(s["key"], 5)

            # def cbi(data):
            #    i, beta = data
            #    print("sweep #", i, " for beta=\n", beta)
            #    return beta
            #
            # beta = hcb.call(
            #   cbi,
            #   (i, s["beta"]),
            #   result_shape=jax.ShapeDtypeStruct(s["beta"].shape, s["beta"].dtype),
            # )

            beta = s["beta"]

            σp, log_prob_correction = sampler.rule.transition(
                sampler, machine, parameters, state, key1, s["σ"])
            proposal_log_prob = sampler.machine_pow * machine.apply(
                parameters, σp).real

            uniform = jax.random.uniform(key2, shape=(sampler.n_batches, ))
            if log_prob_correction is not None:
                do_accept = uniform < jnp.exp(
                    beta.reshape((-1, )) *
                    (proposal_log_prob - s["log_prob"] + log_prob_correction))
            else:
                do_accept = uniform < jnp.exp(
                    beta.reshape((-1, )) * (proposal_log_prob - s["log_prob"]))

            # do_accept must match ndim of proposal and state (which is 2)
            s["σ"] = jnp.where(do_accept.reshape(-1, 1), σp, s["σ"])
            n_accepted_per_beta = s["n_accepted_per_beta"] + do_accept.reshape(
                (sampler.n_chains, sampler.n_replicas))

            s["log_prob"] = jax.numpy.where(do_accept.reshape(-1),
                                            proposal_log_prob, s["log_prob"])

            # exchange betas

            # randomly decide if every set of replicas should be swapped in even or odd order
            swap_order = jax.random.randint(
                key3,
                minval=0,
                maxval=2,
                shape=(sampler.n_chains, ),
            )  # 0 or 1
            # iswap_order = jnp.mod(swap_order + 1, 2)  #  1 or 0

            # indices of even swapped elements (per-row)
            idxs = jnp.arange(0, sampler.n_replicas, 2).reshape(
                (1, -1)) + swap_order.reshape((-1, 1))
            # indices of odd swapped elements (per-row)
            inn = (idxs + 1) % sampler.n_replicas

            # for every row of the input, swap elements at idxs with elements at inn
            @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
            def swap_rows(beta_row, idxs, inn):
                proposed_beta = beta_row.at[idxs].set(beta_row[inn],
                                                      unique_indices=True,
                                                      indices_are_sorted=True)
                proposed_beta = proposed_beta.at[inn].set(
                    beta_row[idxs],
                    unique_indices=True,
                    indices_are_sorted=False)
                return proposed_beta

            proposed_beta = swap_rows(beta, idxs, inn)

            @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
            def compute_proposed_prob(prob, idxs, inn):
                prob_rescaled = prob[idxs] + prob[inn]
                return prob_rescaled

            # compute the probability of the swaps
            log_prob = (proposed_beta - state.beta) * s["log_prob"].reshape(
                (sampler.n_chains, sampler.n_replicas))

            prob_rescaled = jnp.exp(compute_proposed_prob(log_prob, idxs, inn))

            uniform = jax.random.uniform(key4,
                                         shape=(sampler.n_chains,
                                                sampler.n_replicas // 2))

            do_swap = uniform < prob_rescaled

            do_swap = jnp.dstack((do_swap, do_swap)).reshape(
                (-1, sampler.n_replicas))  # concat along last dimension

            # roll if swap_order is odd
            @partial(jax.vmap, in_axes=(0, 0), out_axes=0)
            def fix_swap(do_swap, swap_order):
                return jax.lax.cond(swap_order == 0, lambda x: x,
                                    lambda x: jnp.roll(x, 1), do_swap)

            do_swap = fix_swap(do_swap, swap_order)
            # jax.experimental.host_callback.id_print(state.beta)
            # jax.experimental.host_callback.id_print(proposed_beta)

            # new_beta = jax.numpy.where(do_swap, proposed_beta, beta)

            # def cb(data):
            #     _bt, _pbt, new_beta, so, do_swap, log_prob, prob = data
            #     print("--------.---------.---------.--------")
            #     print("     cur beta:\n", _bt)
            #     print("proposed beta:\n", _pbt)
            #     print("     new beta:\n", new_beta)
            #     print("swaporder :", so)
            #     print("do_swap :\n", do_swap)
            #     print("log_prob;\n", log_prob)
            #     print("prob_rescaled;\n", prob)
            #     return new_beta
            #
            # new_beta = hcb.call(
            #    cb,
            #    (
            #        beta,
            #        proposed_beta,
            #        new_beta,
            #        swap_order,
            #        do_swap,
            #        log_prob,
            #        prob_rescaled,
            #    ),
            #    result_shape=jax.ShapeDtypeStruct(new_beta.shape, new_beta.dtype),
            # )
            # s["beta"] = new_beta

            swap_order = swap_order.reshape(-1)

            beta_0_moved = jax.vmap(lambda do_swap, i: do_swap[i],
                                    in_axes=(0, 0),
                                    out_axes=0)(do_swap, state.beta_0_index)
            proposed_beta_0_index = jnp.mod(
                state.beta_0_index + (-jnp.mod(swap_order, 2) * 2 + 1) *
                (-jnp.mod(state.beta_0_index, 2) * 2 + 1),
                sampler.n_replicas,
            )

            s["beta_0_index"] = jnp.where(beta_0_moved, proposed_beta_0_index,
                                          s["beta_0_index"])

            # swap acceptances
            swapped_n_accepted_per_beta = swap_rows(n_accepted_per_beta, idxs,
                                                    inn)
            s["n_accepted_per_beta"] = jax.numpy.where(
                do_swap,
                swapped_n_accepted_per_beta,
                n_accepted_per_beta,
            )

            # Update statistics to compute diffusion coefficient of replicas
            # Total exchange steps performed
            delta = s["beta_0_index"] - s["beta_position"]
            s["beta_position"] = s["beta_position"] + delta / (
                state.exchange_steps + jnp.asarray(i, dtype=jnp.int64))
            delta2 = s["beta_0_index"] - s["beta_position"]
            s["beta_diffusion"] = s["beta_diffusion"] + delta * delta2

            return s
Example #14
def dstack(arrays):
  arrays = [a.value if isinstance(a, JaxArray) else a for a in arrays]
  return JaxArray(jnp.dstack(arrays))
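Hypothetical usage, assuming `JaxArray` is the project's thin array wrapper exposing the raw buffer as `.value`, as the definition above suggests:

import jax.numpy as jnp

a = JaxArray(jnp.zeros((2, 3)))
b = JaxArray(jnp.ones((2, 3)))
c = dstack([a, b])  # JaxArray wrapping an array of shape (2, 3, 2)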
Example #15
File: evaluate.py Project: C-J-Cundy/NuX
def evaluate_2d_model(create_model, args, classification=False):
    assert not args.save_path.endswith(".pickle")

    init_key = random.PRNGKey(args.init_key_seed)
    train_key = random.PRNGKey(args.train_key_seed)
    eval_key = random.PRNGKey(args.eval_key_seed)

    train_ds, get_test_ds = get_dataset(args.dataset,
                                        args.batch_size,
                                        args.n_batches,
                                        args.test_batch_size,
                                        args.test_n_batches,
                                        quantize_bits=args.quantize_bits,
                                        classification=classification,
                                        label_keep_percent=1.0,
                                        random_label_percent=0.0)

    doubly_batched_inputs = next(train_ds)
    inputs = {"x": doubly_batched_inputs["x"][0]}

    if "y" in doubly_batched_inputs:
        inputs["y"] = doubly_batched_inputs["y"][0]

    flow = nux.Flow(create_model, init_key, inputs, batch_axes=(0, ))

    outputs = flow.apply(init_key, inputs)

    print("n_params", flow.n_params)

    trainer = initialize_trainer(flow,
                                 clip=args.clip,
                                 lr=args.lr,
                                 warmup=args.warmup,
                                 cosine_decay_steps=args.cosine_decay_steps,
                                 save_path=args.save_path,
                                 retrain=args.retrain,
                                 train_args=args.train_args,
                                 classification=classification)

    test_losses = sorted(trainer.test_losses.items(), key=lambda x: x[0])
    test_losses = jnp.array(test_losses)

    test_ds = get_test_ds()
    res = trainer.evaluate_test(eval_key, test_ds)
    print("test", trainer.summarize_losses_and_aux(res))

    # Plot samples
    samples = flow.sample(eval_key, n_samples=5000, manifold_sample=True)

    # Find the spread of the data
    data = doubly_batched_inputs["x"].reshape((-1, 2))
    (xmin, ymin), (xmax, ymax) = data.min(axis=0), data.max(axis=0)
    xspread, yspread = xmax - xmin, ymax - ymin
    xmin -= 0.25 * xspread
    xmax += 0.25 * xspread
    ymin -= 0.25 * yspread
    ymax += 0.25 * yspread

    # Plot the samples against the true samples and also a density plot
    if "prediction" in samples:
        fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(28, 7))
        ax1.scatter(*data.T)
        ax1.set_title("True Samples")
        ax2.scatter(*samples["x"].T, alpha=0.2, s=3, c=samples["prediction"])
        ax2.set_title("Learned Samples")
        ax1.set_xlim(xmin, xmax)
        ax1.set_ylim(ymin, ymax)
        ax2.set_xlim(xmin, xmax)
        ax2.set_ylim(ymin, ymax)

        n_importance_samples = 100
        x_range, y_range = jnp.linspace(xmin, xmax,
                                        100), jnp.linspace(ymin, ymax, 100)
        X, Y = jnp.meshgrid(x_range, y_range)
        XY = jnp.dstack([X, Y]).reshape((-1, 2))
        XY = jnp.broadcast_to(XY[None, ...],
                              (n_importance_samples, ) + XY.shape)
        outputs = flow.scan_apply(eval_key, {"x": XY})
        outputs["log_px"] = jax.scipy.special.logsumexp(
            outputs["log_px"], axis=0) - jnp.log(n_importance_samples)
        outputs["prediction"] = jnp.mean(outputs["prediction"], axis=0)

        Z = jnp.exp(outputs["log_px"])
        ax3.contourf(X, Y, Z.reshape(X.shape))
        ax4.contourf(X, Y, outputs["prediction"].reshape(X.shape))
        plt.show()
    else:
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(21, 7))
        ax1.scatter(*data.T)
        ax1.set_title("True Samples")
        ax2.scatter(*samples["x"].T, alpha=0.2, s=3)
        ax2.set_title("Learned Samples")
        ax1.set_xlim(xmin, xmax)
        ax1.set_ylim(ymin, ymax)
        ax2.set_xlim(xmin, xmax)
        ax2.set_ylim(ymin, ymax)

        n_importance_samples = 100
        x_range, y_range = jnp.linspace(xmin, xmax,
                                        100), jnp.linspace(ymin, ymax, 100)
        X, Y = jnp.meshgrid(x_range, y_range)
        XY = jnp.dstack([X, Y]).reshape((-1, 2))
        XY = jnp.broadcast_to(XY[None, ...],
                              (n_importance_samples, ) + XY.shape)
        outputs = flow.scan_apply(eval_key, {"x": XY})
        outputs["log_px"] = jax.scipy.special.logsumexp(
            outputs["log_px"], axis=0) - jnp.log(n_importance_samples)

        Z = jnp.exp(outputs["log_px"])
        ax3.contourf(X, Y, Z.reshape(X.shape))
        plt.show()

    assert 0
Example #16
    def fit(self,
            p0=None,
            extra=None,
            num_subunits=2,
            num_epochs=1,
            num_iters=3000,
            initialize='random',
            metric=None,
            alpha=1,
            beta=0.05,
            fit_linear_filter=True,
            fit_intercept=True,
            fit_R=True,
            fit_history_filter=False,
            fit_nonlinearity=False,
            step_size=1e-2,
            tolerance=10,
            verbose=100,
            random_seed=2046,
            return_model=None):

        self.metric = metric

        self.alpha = alpha  # elastic net parameter (1=L1, 0=L2)
        self.beta = beta  # elastic net parameter - global penalty weight

        self.n_s = num_subunits
        self.num_iters = num_iters

        self.fit_linear_filter = fit_linear_filter
        self.fit_history_filter = fit_history_filter
        self.fit_nonlinearity = fit_nonlinearity
        self.fit_intercept = fit_intercept
        self.fit_R = fit_R

        # initialize parameters
        if p0 is None:
            p0 = {}

        dict_keys = p0.keys()
        if 'b' not in dict_keys:
            if initialize == 'random':  # not necessary, but for consistency with others.
                key = random.PRNGKey(random_seed)
                b0 = 0.01 * random.normal(
                    key, shape=(self.n_b * self.n_c * self.n_s, )).flatten()
                p0.update({'b': b0})

        if 'intercept' not in dict_keys:
            p0.update({'intercept': jnp.zeros(1)})

        if 'R' not in dict_keys:
            p0.update({'R': jnp.array([1.])})

        if 'bh' not in dict_keys:
            try:
                p0.update({'bh': self.bh_spl})
            except AttributeError:
                p0.update({'bh': None})

        if 'nl_params' not in dict_keys:
            if self.nl_params is not None:
                p0.update({
                    'nl_params': [self.nl_params for _ in range(self.n_s + 1)]
                })
            else:
                p0.update({'nl_params': [None for _ in range(self.n_s + 1)]})

        if extra is not None:

            if self.n_c > 1:
                XS_ext = jnp.dstack([
                    extra['X'][:, :, i] @ self.S for i in range(self.n_c)
                ]).reshape(extra['X'].shape[0], -1)
                extra.update({'XS': XS_ext})
            else:
                extra.update({'XS': extra['X'] @ self.S})

            if self.h_spl is not None:
                yh = jnp.array(
                    build_design_matrix(extra['y'][:, jnp.newaxis],
                                        self.Sh.shape[0],
                                        shift=1))
                yS = yh @ self.Sh
                extra.update({'yS': yS})

            extra = {key: jnp.array(extra[key]) for key in extra.keys()}

        self.p0 = p0
        self.p_opt = self.optimize_params(p0, extra, num_epochs, num_iters,
                                          metric, step_size, tolerance,
                                          verbose, return_model)

        self.R = self.p_opt['R'] if fit_R else jnp.array([1.])

        if fit_linear_filter:
            self.b_opt = self.p_opt['b']

            if self.n_c > 1:
                self.w_opt = jnp.stack([(self.S @ self.b_opt.reshape(
                    self.n_b, self.n_c, self.n_s)[:, :, i])
                                        for i in range(self.n_s)],
                                       axis=-1)
            else:
                self.w_opt = self.S @ self.b_opt.reshape(self.n_b, self.n_s)

        if fit_history_filter:
            self.bh_opt = self.p_opt['bh']
            self.h_opt = self.Sh @ self.bh_opt

        if fit_intercept:
            self.intercept = self.p_opt['intercept']

        if fit_nonlinearity:
            self.nl_params_opt = self.p_opt['nl_params']
Example #17
    def __init__(self,
                 X,
                 y,
                 dims,
                 df,
                 smooth='cr',
                 compute_mle=False,
                 **kwargs):
        """

        Parameters
        ==========
        X : array_like, shape (n_samples, n_features)
            Stimulus design matrix.

        y : array_like, shape (n_samples, )
            Recorded response.

        dims : list or array_like, shape (ndims, )
            Dimensions or shape of the RF to estimate. Assumed order [t, sx, sy].

        df : list or array_like, shape (ndims, )
            Degrees of freedom, i.e. the number of basis functions used for
            each RF dimension.

        smooth : str
            Type of basis.
            * cr: natural cubic spline (default)
            * cc: cyclic cubic spline
            * bs: B-spline
            * tp: thin plate spline

        compute_mle : bool
            Compute sta and maximum likelihood optionally.

        """

        super().__init__(X, y, dims, compute_mle, **kwargs)

        # Optimization
        self.bh_opt = None
        self.b_opt = None
        self.extra = None
        self.h_spl = None
        self.bh_spl = None
        self.yS = None
        self.Sh = None

        # Parameters
        self.df = df  # number basis / degree of freedom
        self.smooth = smooth  # type of basis

        S = jnp.array(build_spline_matrix(self.dims, df, smooth))  # for w

        if self.n_c > 1:
            XS = jnp.dstack([self.X[:, :, i] @ S for i in range(self.n_c)
                             ]).reshape(self.n_samples, -1)
        else:
            XS = self.X @ S

        self.S = S  # spline matrix
        self.XS = XS

        self.n_b = S.shape[1]  # number of spline coefficients

        # compute spline-based maximum likelihood
        self.b_spl = jnp.linalg.lstsq(XS.T @ XS, XS.T @ y, rcond=None)[0]

        if self.n_c > 1:
            self.w_spl = S @ self.b_spl.reshape(self.n_b, self.n_c)
        else:
            self.w_spl = S @ self.b_spl
Example #18
    def fit(self,
            p0=None,
            extra=None,
            initialize='random',
            num_epochs=1,
            num_iters=3000,
            metric=None,
            alpha=1,
            beta=0.05,
            fit_linear_filter=True,
            fit_intercept=True,
            fit_R=True,
            fit_history_filter=False,
            fit_nonlinearity=False,
            step_size=1e-2,
            tolerance=10,
            verbose=100,
            random_seed=2046,
            return_model=None):
        """

        Parameters
        ==========

        p0 : dict
            * 'b': Initial spline coefficients.
            * 'bh': Initial response history filter coefficients

        initialize : None or str
            Parametric initialization.
            * if `initialize=None`, `b` will be initialized by b_spl.
            * if `initialize='random'`, `b` will be randomly initialized.

        num_iters : int
            Max number of optimization iterations.

        metric : None or str
            Extra cross-validation metric. Default is `None`. Or
            * 'mse': mean squared error
            * 'r2': R2 score
            * 'corrcoef': Correlation coefficient

        alpha : float, from 0 to 1.
            Elastic net parameter, balance between L1 and L2 regularization.
            * 0.0 -> only L2
            * 1.0 -> only L1

        beta : float
            Elastic net parameter, overall weight of regularization for receptive field.

        step_size : float
            Initial step size for JAX optimizer.

        tolerance : int
            Early-stopping tolerance. Optimization stops when the cost
            increases monotonically or stops decreasing for `tolerance` steps.

        verbose: int
            When `verbose=0`, progress is not printed. When `verbose=n`,
            progress will be printed in every n steps.

        """

        self.metric = metric

        self.alpha = alpha
        self.beta = beta  # elastic net parameter - global penalty weight for linear filter
        self.num_iters = num_iters

        self.fit_R = fit_R
        self.fit_linear_filter = fit_linear_filter
        self.fit_history_filter = fit_history_filter
        self.fit_nonlinearity = fit_nonlinearity
        self.fit_intercept = fit_intercept

        # initial parameters

        if p0 is None:
            p0 = {}

        dict_keys = p0.keys()
        if 'b' not in dict_keys:
            if initialize is None:
                p0.update({'b': self.b_spl})
            else:
                if initialize == 'random':
                    key = random.PRNGKey(random_seed)
                    b0 = 0.01 * random.normal(
                        key, shape=(self.n_b * self.n_c, )).flatten()
                    p0.update({'b': b0})

        if 'intercept' not in dict_keys:
            p0.update({'intercept': jnp.array([0.])})

        if 'R' not in dict_keys:
            p0.update({'R': jnp.array([1.])})

        if 'bh' not in dict_keys:
            if initialize is None and self.bh_spl is not None:
                p0.update({'bh': self.bh_spl})
            elif initialize == 'random' and self.bh_spl is not None:
                key = random.PRNGKey(random_seed)
                bh0 = 0.01 * random.normal(key, shape=(len(
                    self.bh_spl), )).flatten()
                p0.update({'bh': bh0})
            else:
                p0.update({'bh': None})

        if 'nl_params' not in dict_keys:
            if self.nl_params is not None:
                p0.update({'nl_params': self.nl_params})
            else:
                p0.update({'nl_params': None})

        if extra is not None:

            if self.n_c > 1:
                XS_ext = jnp.dstack([
                    extra['X'][:, :, i] @ self.S for i in range(self.n_c)
                ]).reshape(extra['X'].shape[0], -1)
                extra.update({'XS': XS_ext})
            else:
                extra.update({'XS': extra['X'] @ self.S})

            if self.h_spl is not None:
                yh_ext = jnp.array(
                    build_design_matrix(extra['y'][:, jnp.newaxis],
                                        self.Sh.shape[0],
                                        shift=1))
                yS_ext = yh_ext @ self.Sh
                extra.update({'yS': yS_ext})

            extra = {key: jnp.array(extra[key]) for key in extra.keys()}

            self.extra = extra  # store for cross-validation

        # store optimized parameters
        self.p0 = p0
        self.p_opt = self.optimize_params(p0, extra, num_epochs, num_iters,
                                          metric, step_size, tolerance,
                                          verbose, return_model)
        self.R = self.p_opt['R'] if fit_R else jnp.array([1.])

        if fit_linear_filter:
            self.b_opt = self.p_opt['b']  # optimized RF basis coefficients
            if self.n_c > 1:
                self.w_opt = self.S @ self.b_opt.reshape(self.n_b, self.n_c)
            else:
                self.w_opt = self.S @ self.b_opt  # optimized RF

        if fit_history_filter:
            self.bh_opt = self.p_opt['bh']
            self.h_opt = self.Sh @ self.bh_opt

        if fit_nonlinearity:
            self.nl_params_opt = self.p_opt['nl_params']

        if fit_intercept:
            self.intercept = self.p_opt['intercept']
Example #19
    def _plot(fig,
              ax1,
              ax2,
              mean,
              sigma,
              array_samples_theta,
              interactive=False):
        colorbar = None
        colorbar_2 = None

        plt.gca()
        # plt.cla()
        # plt.clf()
        fig.clear()
        fig.add_axes(ax1)
        fig.add_axes(ax2)

        plt.cla()

        xlim = (-5., 5.)
        ylim = (-5., 5.)
        xlist = np.linspace(*xlim, 100)
        ylist = np.linspace(*ylim, 100)
        X_, Y_ = np.meshgrid(xlist, ylist)
        Z = np.dstack((X_, Y_))
        Z = Z.reshape(-1, 2)
        # probability_class_1, X_1, and X_2 come from the enclosing scope.
        predictions = onp.mean(probability_class_1(Z, array_samples_theta),
                               axis=1)
        predictions = predictions.reshape(100, 100)
        # print("finished")
        ax1.clear()
        if np.size(predictions):
            CS = ax1.contourf(X_, Y_, predictions, cmap="cividis")
        ax1.scatter(X_1[:, 0], X_1[:, 1])
        ax1.scatter(X_2[:, 0], X_2[:, 1])
        ax1.set_xlim(*xlim)
        ax1.set_ylim(*ylim)
        ax1.set_title("Predicted probability of belonging to C_1")
        ax3 = fig.add_axes(Bbox([[0.43, 0.11], [0.453, 0.88]]))
        if np.size(predictions):
            colorbar = fig.colorbar(
                CS,
                cax=ax3,
            )
        ax1.set_position(Bbox([[0.125, 0.11], [0.39, 0.88]]))

        x_prior = np.linspace(-3, 3, 100)
        y_prior = np.linspace(-3, 3, 100)
        X_prior, Y_prior = np.meshgrid(x_prior, y_prior)
        Z = np.dstack((X_prior, Y_prior))
        Z = Z.reshape(-1, 2)
        prior_values = multivariate_normal.pdf(Z, np.zeros(2), np.identity(2))
        prior_values = prior_values.reshape(100, 100)

        std_x = onp.sqrt(sigma[0, 0])
        std_y = onp.sqrt(sigma[1, 1])
        x_posterior = np.linspace(mean[0] - 3 * std_x, mean[0] + 3 * std_x,
                                  100)
        y_posterior = np.linspace(mean[1] - 3 * std_y, mean[1] + 3 * std_y,
                                  100)
        X_post, Y_post = np.meshgrid(x_posterior, y_posterior)

        Z_post = np.dstack((X_post, Y_post)).reshape(-1, 2)
        posterior_values = multivariate_normal.pdf(Z_post, mean, sigma)
        posterior_values = posterior_values.reshape(100, 100)

        ax2.contour(X_post, Y_post, posterior_values)
        # The prior was evaluated on the X_prior/Y_prior grid, so plot it
        # against that grid rather than X_/Y_.
        ax2.contour(X_prior, Y_prior, prior_values, cmap="inferno")
        ax2.set_title("Two contour plots respectively showing\n"
                      "the prior and the approximated posterior distributions")

        plt.pause(0.001)
        if interactive:
            if np.size(predictions):
                colorbar.remove()

        return True