예제 #1
0
    def sum_from_unique(
            cls,
            input: np.array,
            mean: bool = True) -> Tuple[np.array, np.array, "SparseReduce"]:
        """Group the positions of each distinct value, bucketed by group size.

        The distinct values of ``input`` are reordered by ``np.argsort`` of
        their occurrence counts (ascending), and the per-value index arrays
        are bucketed so that each bucket holds groups of identical size.

        Parameters
        ----------
        input : np.array
            Values to group. Presumably 1-D -- TODO confirm; ``np.argwhere``
            on higher-rank input would yield flattened coordinate tuples.
        mean : bool, default=True
            Passed through unchanged to ``SparseReduce``.

        Returns
        -------
        un_sorted : np.array
            Distinct values of ``input``, reordered by ascending group size.
        cts_sorted : np.array
            Occurrence count of each value in ``un_sorted``.
        reducer : SparseReduce
            Built from a list with one array per distinct group size; each
            array stacks (as rows) the positions of the values of that size.
        """
        values, counts = np.unique(input, return_counts=True)
        # positions[k] lists, in increasing order, where values[k] occurs.
        positions = [np.argwhere(input == v).flatten() for v in values]
        group_sizes = np.array([p.size for p in positions])
        order = np.argsort(group_sizes)
        ordered_sizes = group_sizes[order]
        ordered_positions = [positions[k] for k in order]

        # Split points between runs of equal group size in the sorted order,
        # framed by 0 and the total number of groups.
        interior = np.argwhere(
            ordered_sizes[:-1] - ordered_sizes[1:] != 0).flatten() + 1
        bounds = np.array([0, *interior, len(group_sizes)])

        buckets = [
            np.array(ordered_positions[lo:hi])
            for lo, hi in zip(bounds[:-1], bounds[1:])
        ]
        return values[order], counts[order], SparseReduce(buckets, mean)
예제 #2
0
 def scatter(self,
             indices,
             values,
             shape,
             duplicates_handling='undefined',
             outside_handling='undefined'):
     """Scatter ``values`` into a new zero-initialised array of ``shape``.

     JAX arrays are immutable, so every write goes through the functional
     ``array.at[...]`` API, which returns a new array.

     Args:
       indices: Integer index array whose last axis holds one coordinate per
         dimension of ``shape`` (it is unstacked along ``axis=-1``).
       values: Values to write; presumably one row per index row, or a
         leading dimension of 1 that broadcasts -- TODO confirm with callers.
       shape: Spatial shape of the output array.
       duplicates_handling: 'add' sums duplicate writes, 'mean' averages
         them, 'any'/'undefined' keep a single write per cell (which one is
         unspecified when indices repeat).
       outside_handling: 'clamp' clips indices into ``[0, shape - 1]``,
         'discard' drops index rows with any coordinate outside
         ``[0, shape)``, 'undefined' applies no check.

     Returns:
       Array of shape ``tuple(shape) + values.shape[indices.ndim - 1:]`` and
       dtype ``to_numpy_dtype(self.float_type)``.
     """
     assert duplicates_handling in ('undefined', 'add', 'mean', 'any')
     assert outside_handling in ('discard', 'clamp', 'undefined')
     shape = jnp.array(shape, jnp.int32)
     if outside_handling == 'clamp':
         indices = jnp.maximum(0, jnp.minimum(indices, shape - 1))
     elif outside_handling == 'discard':
         # Keep only index rows whose every coordinate lies in [0, shape).
         indices_inside = jnp.all((indices >= 0) & (indices < shape), axis=-1)
         filter_indices = jnp.argwhere(indices_inside)
         indices = indices[filter_indices][..., 0, :]
         if values.shape[0] > 1:
             values = values[filter_indices.reshape(-1)]
     spatial_shape = tuple(int(s) for s in shape)
     array = jnp.zeros(
         spatial_shape + values.shape[indices.ndim - 1:],
         to_numpy_dtype(self.float_type))
     index = tuple(self.unstack(indices, axis=-1))
     if duplicates_handling == 'add':
         return array.at[index].add(values)
     if duplicates_handling == 'mean':
         total = array.at[index].add(values)
         count = jnp.zeros(spatial_shape, jnp.int32).at[index].add(1)
         # Untouched cells keep 0 / 1 == 0 instead of dividing by zero.
         count = jnp.maximum(1, count)
         # Append singleton axes so per-cell counts broadcast over any
         # trailing value dimensions.
         count = count.reshape(count.shape + (1,) * (total.ndim - count.ndim))
         return total / count
     # 'any' / 'undefined': duplicates resolve to one unspecified write.
     return array.at[index].set(values)
예제 #3
0
파일: hsgp.py 프로젝트: pyro-ppl/numpyro
def plot_1988(data, samples, ax=None):
    """Plot the posterior-median births curve for every day of 1988.

    The curve adds the day-of-year effect to the floating-holiday effects
    (memorial, labour and thanksgiving days). It rescales the sum with the
    standard deviation and mean of ``data["births_relative"]`` and takes the
    median over axis 0 of the draws. Selected special days are labelled, and
    every 13th of the month is marked with a hollow circle.

    Args:
        data: DataFrame with ``date``, ``day_of_year`` and
            ``births_relative`` columns.
        samples: Posterior draws keyed by site name; presumably draws lie on
            the leading axis of each array -- TODO confirm.
        ax: Axes to draw on; defaults to ``plt.gca()``.

    Returns:
        The axes that were drawn on.
    """
    flags = get_floating_days_indicators(data["date"])
    floating_days = (
        flags["memorial_days_indicator"] * samples["memorial/beta"][:, None]
        + flags["labour_days_indicator"] * samples["labour/beta"][:, None]
        + flags["thanksgiving_days_indicator"]
        * samples["thanksgiving/beta"][:, None])

    in_1988 = data["date"].dt.year == 1988
    # Zero-based day-of-year positions of the 1988 rows.
    day_index = data["day_of_year"][in_1988] - 1
    row_index = jnp.argwhere(in_1988.values).ravel()
    effect = (samples["day/beta"][:, day_index.values]
              + floating_days[:, row_index])

    births = data["births_relative"]
    f_median = jnp.median(effect * births.std() + births.mean(), axis=0)

    if ax is None:
        ax = plt.gca()
    ax.plot(day_index, f_median, color="k", lw=2)

    labelled_days = [
        ("Valentine's", "1988-02-14"),
        ("Leap day", "1988-02-29"),
        ("Halloween", "1988-10-31"),
        ("Christmas eve", "1988-12-24"),
        ("Christmas day", "1988-12-25"),
        ("New year", "1988-01-01"),
        ("New year's eve", "1988-12-31"),
        ("April 1st", "1988-04-01"),
        ("Independence day", "1988-07-04"),
        ("Labour day", "1988-09-05"),
        ("Memorial day", "1988-05-30"),
        ("Thanksgiving", "1988-11-24"),
    ]
    for label, iso_date in labelled_days:
        x = pd.to_datetime(iso_date).day_of_year - 1
        annotation = ax.text(x - 3, f_median[x], label,
                             horizontalalignment="right")
        annotation.set_bbox(
            dict(facecolor="white", alpha=0.5, edgecolor="none"))

    on_13th = in_1988 & (data["date"].dt.day == 13)
    thirteenths = data.loc[on_13th, "day_of_year"] - 1
    ax.plot(
        thirteenths,
        f_median[thirteenths.values],
        marker="o",
        mec="gray",
        c="none",
        ms=10,
        lw=0,
    )
    return ax
예제 #4
0
    def get_grad_accumulator(self):
        """Return a zero gradient accumulator for a leaf variable.

        Returns
        -------
        jnp.ndarray or None
            Zeros with the same shape as ``self.gamma`` when ``gamma``
            contains at least one exact zero and ``requires_grad`` is set;
            ``None`` otherwise.

        Raises
        ------
        RuntimeError
            If called on a non-leaf variable (``grad_fn`` is not None).
        """
        if self.grad_fn is not None:
            raise RuntimeError(
                "get_grad_accumulator() should be only called on leaf Variables"
            )

        # jnp.any avoids materialising every zero position (argwhere) just
        # to test whether there is at least one.
        if bool(jnp.any(self.gamma == 0)) and self.requires_grad:
            # `shape` must be a tuple of ints, not the array itself.
            return jnp.zeros(shape=jnp.shape(self.gamma))
        return None
예제 #5
0
    def sum_from_block(cls,
                       input: np.array,
                       mean: bool = True) -> "BlockReduce":
        """Build a ``BlockReduce`` from runs of equal consecutive values.

        Parameters
        ----------
        input : np.array
            Sequence whose runs of identical adjacent values define the
            blocks; presumably 1-D -- TODO confirm.
        mean : bool, default=True
            Passed through unchanged to ``BlockReduce``.

        Returns
        -------
        BlockReduce
            Constructed from the block boundaries: 0, every position whose
            value differs from its predecessor, and ``len(input)``.
        """
        # A new block begins wherever the value differs from its predecessor.
        run_starts = np.argwhere(input[:-1] - input[1:] != 0).flatten() + 1
        boundaries = np.array([0, *run_starts, len(input)])
        return BlockReduce(boundaries, mean)
예제 #6
0
def get_minimum_zeroth_element(x: Array, window_size: int = 10) -> int:
    """Return the first window position where the moving average of |x| is 0.

    The moving average of ``np.abs(x)`` over ``window_size`` samples is
    computed with ``np.convolve(..., "valid")``. The result is the smallest
    output position at which that average is exactly ``0.0``.

    Args:
        x: Input signal.
        window_size: Length of the averaging window.

    Returns:
        Index (into the ``"valid"`` convolution output) of the first window
        whose average is exactly zero.

    Raises:
        IndexError: If no window averages to exactly zero.
    """
    # Uniform averaging kernel.
    window = np.ones(window_size) / window_size

    # Moving average of the magnitude; "valid" keeps only full overlaps.
    x_window_mean = np.convolve(np.abs(x), window, "valid")

    # flatnonzero returns ascending positions, so the first one is the minimum.
    zero_positions = np.flatnonzero(x_window_mean == 0.0)
    if zero_positions.size == 0:
        raise IndexError(
            f"no window of length {window_size} has a moving average of exactly 0")
    return int(zero_positions[0])
예제 #7
0
 def sum_from_unique(
         cls,
         input: Array,
         mean: bool = True) -> Tuple[np.array, np.array, "LinearReduce"]:
     """Build a ``LinearReduce`` whose rows select each distinct value.

     Row ``r`` of the reduction matrix is non-zero exactly at the columns
     where ``input`` equals the ``r``-th unique value. Those entries are
     ``1 / count`` when ``mean`` is true (averaging) and ``1`` otherwise
     (summing).

     Parameters
     ----------
     input : Array
         Values to group; indexed along axis 0 -- presumably 1-D, TODO confirm.
     mean : bool, default=True
         Whether rows average (``1 / count``) or sum (``1``).

     Returns
     -------
     un, cts, reducer
         Sorted unique values, their counts, and ``LinearReduce(m)`` with
         ``m`` of shape ``(len(un), input.shape[0])``.

     Notes
     -----
     Uses the functional ``m.at[...].set`` update, so ``np`` here is
     presumably ``jax.numpy`` -- TODO confirm against the file's imports.
     """
     un, cts = np.unique(input, return_counts=True)
     m = np.zeros((un.size, input.shape[0]))
     for row in range(un.size):
         cols = np.argwhere(input == un[row]).flatten()
         count = cts[row].squeeze()
         ones = np.ones(int(count)).squeeze()
         weights = ones / count if mean else ones
         m = m.at[row, cols.squeeze()].set(weights)
     return un, cts, LinearReduce(m)
예제 #8
0
    def calc_accuracy(
        params: hk.Params, sample: Tuple[np.ndarray, np.ndarray, np.ndarray,
                                         np.ndarray]) -> float:
        """Accuracy of ``model`` restricted to examples with a non-zero label.

        Args:
            params: Model parameters passed to ``model.apply``.
            sample: ``(doc, context, target, labels)`` tuple.

        Returns:
            Mean of ``argmax(probs, -1) == target`` over the positions given
            by ``jnp.squeeze(jnp.argwhere(labels))``, returned as a JAX
            scalar despite the ``float`` hint. Presumably ``labels`` is 1-D
            and the result is NaN when it has no non-zero entry -- TODO
            confirm.
        """
        doc, context, target, labels = sample
        target_on_device = jax.device_put(target)
        predictions = jnp.argmax(model.apply(params, doc, context), axis=-1)
        # squeeze flattens argwhere's (k, 1) result into an index vector
        # (for 1-D labels).
        selected = jnp.squeeze(jnp.argwhere(labels))
        is_correct = predictions[selected] == target_on_device[selected]
        return jnp.mean(is_correct)
예제 #9
0
def argwhere(x):
  """Return the indices of the non-zero elements of ``x`` as a ``JaxArray``.

  A ``JaxArray`` argument is first unwrapped through its ``.value``; any
  other argument is handed to ``jnp.argwhere`` unchanged.
  """
  raw = x.value if isinstance(x, JaxArray) else x
  return JaxArray(jnp.argwhere(raw))
예제 #10
0
 def nonzero(self, values):
     """Return the coordinates of the non-zero entries of ``values``.

     Delegates to ``jnp.argwhere``. The result has one row per non-zero
     entry and one column per dimension of ``values``.
     """
     coordinates = jnp.argwhere(values)
     return coordinates
예제 #11
0
 def where(self, condition, x=None, y=None):
     """Select from ``x``/``y`` by ``condition``, or locate true entries.

     With both ``x`` and ``y`` given, this returns
     ``jnp.where(condition, x, y)``. If either is missing, it returns the
     coordinates of the truthy entries of ``condition`` (``jnp.argwhere``).
     """
     if x is not None and y is not None:
         return jnp.where(condition, x, y)
     return jnp.argwhere(condition)
예제 #12
0
def plot_bars(design: PVDesign, filename=None) -> None:
    """Draw the band diagram of ``design`` as one shaded bar per region.

    Regions are maximal runs of grid points over which ``Ec`` is constant.
    Each bar spans ``Ev``..``Ec`` and is labelled with both energies rounded
    to 2 decimals. The Fermi level is drawn dashed. Contacts (when
    ``design.PhiM0`` / ``design.PhiML`` are positive) appear as shaded bands
    beyond the device edges, and changes in ``design.Ndop`` as dashed
    vertical lines.

    Args:
        design: Device description consumed by ``physics.Ec/Ev/EF``.
        filename: If given, the figure is saved to this path before
            ``plt.show()`` is called.
    """
    _, ax = plt.subplots()
    ax.set_zorder(1)
    ax.patch.set_visible(False)

    ec = scales.energy * physics.Ec(design)
    ev = scales.energy * physics.Ev(design)
    ef = scales.energy * physics.EF(design)
    # Positions along the device, plotted against the "position / μm" axis.
    xgrid = scales.length * design.grid * 1e4
    if design.PhiM0 > 0 and design.PhiML > 0:
        ax.margins(x=.2, y=.5)
    else:
        ax.margins(y=.5)
        ax.set_xlim(0, xgrid[-1])

    # First grid index of every region of constant Ec.
    starts = jnp.concatenate(
        [jnp.array([0]),
         jnp.argwhere(ec[:-1] != ec[1:]).flatten() + 1])

    band_top = ec[starts]
    band_bottom = ev[starts]
    xs = xgrid[starts]
    heights = band_top - band_bottom
    # Each region extends to the next region's start; the last to the grid end.
    widths = jnp.diff(jnp.append(xs, xgrid[-1]))

    for k in range(xs.size):
        x0, y0, w, h = xs[k], band_bottom[k], widths[k], heights[k]
        ax.add_patch(
            Rectangle((x0, y0),
                      w,
                      h,
                      color=COLORS[k % len(COLORS)],
                      linewidth=0,
                      alpha=.2))
        ax.text(x0 + w / 2,
                y0 + h + .1,
                round(y0 + h, 2),
                ha="center",
                va="bottom")
        ax.text(x0 + w / 2, y0 - .1, round(y0, 2), ha="center", va="top")

    ax.plot(xgrid, ef, linestyle="--", color="black", label="$E_{F}$")

    def draw_contact(phi, xstart, width, color, xmin, xmax, edge_x):
        # Shaded band of height 0.2 centred on the contact level, with the
        # level annotated above and "contact" below, a dashed level line
        # over the [xmin, xmax] axes fraction, and a marker at the edge.
        band_height = .2
        ystart = phi - band_height / 2
        ax.add_patch(
            Rectangle((xstart, ystart),
                      width,
                      band_height,
                      color=color,
                      linewidth=0,
                      alpha=.2))
        ax.text(xstart + width / 2,
                ystart + band_height + 0.1,
                round(phi, 2),
                ha="center",
                va="bottom")
        ax.text(xstart + width / 2,
                ystart - 0.1,
                "contact",
                ha="center",
                va="top")
        ax.axhline(y=phi,
                   xmin=xmin,
                   xmax=xmax,
                   linestyle="--",
                   color="black",
                   linewidth=2)
        ax.axvline(edge_x, color="white", linewidth=2)
        ax.axvline(edge_x,
                   color="lightgray",
                   linewidth=2,
                   linestyle="dashed")

    if design.PhiM0 > 0:
        phim0 = -design.PhiM0 * scales.energy
        left, _ = ax.get_xlim()
        draw_contact(phim0, left, -left, "red", 0, 1 / 7, xgrid[0])

    if design.PhiML > 0:
        phiml = -design.PhiML * scales.energy
        right_edge = xgrid[-1]
        _, right = ax.get_xlim()
        draw_contact(phiml, right_edge, right - right_edge, "blue", 6 / 7, 1,
                     right_edge)

    # Mark every position where the doping changes between neighbours.
    for b in jnp.argwhere(design.Ndop[:-1] != design.Ndop[1:]).flatten():
        boundary = (xgrid[b] + xgrid[b + 1]) / 2
        ax.axvline(boundary, color="white", linewidth=4)
        ax.axvline(boundary, color="lightgray", linewidth=2, linestyle="dashed")

    ax.set_ylim(jnp.min(band_bottom) * 1.5, 0)
    ax.set_xlabel("position / μm")
    ax.set_ylabel("energy / eV")
    ax.legend()

    plt.tight_layout()
    if filename is not None:
        plt.savefig(filename)
    plt.show()
    def set_accepted(self, ϵ, smoothing=None):
        """Sets the accepted and rejected attributes of the containers

        The accepted and rejected parameter values (and summaries) are
        defined by a distance cutoff (or one cutoff per target) between
        simulation summaries and the summaries of each target. These points
        are then used to build approximate marginal distributions for
        plotting by histogramming them. Smoothing can be applied to the
        histograms to avoid undersampling artefacts.

        Parameters
        ----------
        ϵ : float or sequence/array of float(n_targets)
            The acceptance distance between summaries from simulations and the
            summary of the target data. A scalar applies to every target; a
            list, tuple or non-scalar array gives one epsilon per target.
        smoothing : float or None, default=None
            A Gaussian smoothing for the marginal distributions

        Methods
        -------
        get_accepted:
            Returns a boolean array with whether summaries are within `ϵ`
        """
        def get_accepted(distances, ϵ):
            """ Returns a boolean array with whether summaries are within `ϵ`

            Parameters
            ----------
            distances : float(any)
                The distances between the summary of a target and the summaries
                of the run simulations
            ϵ : float
                The acceptance distance between summaries from simulations and
                the summary of the target data
            """
            return np.less(distances, ϵ)

        # A non-scalar ϵ supplies one cutoff per target. jax.vmap cannot map
        # over a Python list (its leaves are rank-0 floats), so convert it to
        # an array whose leading axis is mapped alongside the targets.
        per_target = isinstance(ϵ, (list, tuple)) or getattr(ϵ, "ndim", 0) > 0
        if per_target:
            accepted = jax.vmap(get_accepted)(
                self.distances.all, np.asarray(ϵ))
        else:
            accepted = jax.vmap(lambda distances: get_accepted(distances, ϵ))(
                self.distances.all)

        rejected = ~accepted
        # Row i of accepted/rejected is the mask for target i.
        accepted_inds = [np.argwhere(accept)[:, 0] for accept in accepted]
        rejected_inds = [np.argwhere(reject)[:, 0] for reject in rejected]
        self.parameters.ϵ = ϵ
        self.summaries.ϵ = ϵ
        self.distances.ϵ = ϵ
        self.parameters.accepted = [
            self.parameters.all[inds] for inds in accepted_inds
        ]
        self.parameters.n_accepted = np.array([
            self.parameters.accepted[i].shape[0] for i in range(self.n_targets)
        ])
        self.summaries.accepted = [
            self.summaries.all[inds] for inds in accepted_inds
        ]
        self.summaries.n_accepted = np.array([
            self.summaries.accepted[i].shape[0] for i in range(self.n_targets)
        ])
        self.distances.accepted = [
            self.distances.all[i, inds] for i, inds in enumerate(accepted_inds)
        ]
        self.distances.n_accepted = np.array([
            self.distances.accepted[i].shape[0] for i in range(self.n_targets)
        ])
        self.parameters.rejected = [
            self.parameters.all[inds] for inds in rejected_inds
        ]
        self.summaries.rejected = [
            self.summaries.all[inds] for inds in rejected_inds
        ]
        self.distances.rejected = [
            self.distances.all[i, inds] for i, inds in enumerate(rejected_inds)
        ]
        self.marginals = self.get_marginals(smoothing=smoothing)