def probsample(n_samples, prob, return_counts=True): if return_counts: bins = torch.cat([torch.zeros(1), torch.cumsum(prob, dim=0)]) bins[-1] = 1 res = torch.histogram(torch.rand(n_samples), bins) return res.hist bins = torch.cumsum(prob, dim=0) bins[-1] = 1 return torch.bucketize(torch.rand((n_samples, )), bins)
def other_ops(self): a = torch.randn(4) b = torch.randn(4) c = torch.randint(0, 8, (5, ), dtype=torch.int64) e = torch.randn(4, 3) f = torch.randn(4, 4, 4) size = [0, 1] dims = [0, 1] return ( torch.atleast_1d(a), torch.atleast_2d(a), torch.atleast_3d(a), torch.bincount(c), torch.block_diag(a), torch.broadcast_tensors(a), torch.broadcast_to(a, (4)), # torch.broadcast_shapes(a), torch.bucketize(a, b), torch.cartesian_prod(a), torch.cdist(e, e), torch.clone(a), torch.combinations(a), torch.corrcoef(a), # torch.cov(a), torch.cross(e, e), torch.cummax(a, 0), torch.cummin(a, 0), torch.cumprod(a, 0), torch.cumsum(a, 0), torch.diag(a), torch.diag_embed(a), torch.diagflat(a), torch.diagonal(e), torch.diff(a), torch.einsum("iii", f), torch.flatten(a), torch.flip(e, dims), torch.fliplr(e), torch.flipud(e), torch.kron(a, b), torch.rot90(e), torch.gcd(c, c), torch.histc(a), torch.histogram(a), torch.meshgrid(a), torch.lcm(c, c), torch.logcumsumexp(a, 0), torch.ravel(a), torch.renorm(e, 1, 0, 5), torch.repeat_interleave(c), torch.roll(a, 1, 0), torch.searchsorted(a, b), torch.tensordot(e, e), torch.trace(e), torch.tril(e), torch.tril_indices(3, 3), torch.triu(e), torch.triu_indices(3, 3), torch.vander(a), torch.view_as_real(torch.randn(4, dtype=torch.cfloat)), torch.view_as_complex(torch.randn(4, 2)), torch.resolve_conj(a), torch.resolve_neg(a), )
def binSpikes(spikes, bin_edges): bin_counts, bin_edges_output = torch.histogram(input=spikes, bins=bin_edges) bin_centers = (bin_edges_output[:-1]+bin_edges_output[1:])/2.0 return bin_counts, bin_centers