def tv(x: Tensor, norm: str = 'L1') -> Tensor: r"""Returns the TV of :math:`x`. With `'L1'`, .. math:: \text{TV}(x) = \sum_{i, j} \left| x_{i+1, j} - x_{i, j} \right| + \left| x_{i, j+1} - x_{i, j} \right| Alternatively, with `'L2'`, .. math:: \text{TV}(x) = \left( \sum_{c, i, j} (x_{c, i+1, j} - x_{c, i, j})^2 + (x_{c, i, j+1} - x_{c, i, j})^2 \right)^{\frac{1}{2}} Args: x: An input tensor, :math:`(*, C, H, W)`. norm: Specifies the norm funcion to apply: `'L1'` | `'L2'` | `'L2_squared'`. Returns: The TV tensor, :math:`(*,)`. Example: >>> x = torch.rand(5, 3, 256, 256) >>> l = tv(x) >>> l.size() torch.Size([5]) """ w_var = torch.diff(x, dim=-1) h_var = torch.diff(x, dim=-2) if norm == 'L1': w_var = w_var.abs() h_var = h_var.abs() else: # norm in ['L2', 'L2_squared'] w_var = w_var**2 h_var = h_var**2 var = w_var.sum(dim=(-1, -2, -3)) + h_var.sum(dim=(-1, -2, -3)) if norm == 'L2': var = torch.sqrt(var) return var
def implicit_1Dx(phi, xx, nu, gamma, h, beta, dt, use_delj_trick): dx = torch.diff(xx) L = xx.shape[0] dfactor = torch.zeros(L) xInt = torch.zeros(L - 1) V = torch.zeros(L) MInt = torch.zeros(L - 1) VInt = torch.zeros(L - 1) delj = torch.zeros(L - 1) a = torch.zeros(L) c = torch.zeros(L) b = torch.full(L, 1. / dt) Integration_shared.compute_dfactor(dx, L, dfactor) Integration_shared.compute_xInt(xx, L, xInt) Mfirst = Integration_shared.Mfunc1D(xx[0], gamma, h) Mlast = Integration_shared.Mfunc1D(xx[L - 1], gamma, h) V = Integration_shared.Vfunc_beta(xx, nu, beta) MInt = Integration_shared.Mfunc1D(xInt, gamma, h) VInt = Integration_shared.Vfunc_beta(xInt, nu, beta) Integration_shared.compute_delj(dx, MInt, VInt, L, delj, use_delj_trick) Integration_shared.compute_abc_nobc(dx, dfactor, delj, MInt, V, L, a, b, c) r = phi / dt # Boundary conditions if Mfirst <= 0: b[0] += (0.5 / nu - Mfirst) * 2. / dx[0] if (Mlast >= 0): b[L - 1] += -(-0.5 / nu - Mlast) * 2. / dx[L - 2] Integration.tridiag(a, b, c, r, phi, L)
def mask_center(x, mask, output=None): """ Center a batch of segments up to order 1 in the last dimension. If output=True, return a (centered, means, slopes) triplet. """ N = x.shape[-1] # means means = (x * mask).sum([1]) / mask.sum([1]) # boundary values dmask = torch.diff(mask) bndry = (dmask == -1).long() x0, x1 = x[:, 0], (x[:, :-1] * bndry).sum([1]) # slopes t = mask.sum([1]) - 1 v = (x1 - x0) / t vt = v[:, None] * torch.arange(N)[None, :] vm = (x1 - x0) / 2 out = ((x - vt) * mask + x0[:, None] * (1 - mask)) out += vm[:, None] - means[:, None] if not isinstance(output, str): return out if output == 'means': return out, means if output == 'slopes': return out, means, v
def diff(x: torch.Tensor, n: int = 1, dim=0) -> torch.Tensor: """ Find the differences in x. This will return a torch.Tensor of the same shape as `x` with the requisite number of zeros added to the end (`n`). Parameters ---------- x: torch.Tensor input tensor n: int, optional the number of rows between each row for which to calculate the diff dim: int, optional the dimension to find the diff on Returns ------- diffs : torch.Tensor the differences. This result is the same size as `x`. """ y = torch.zeros_like(x) ret = x for _ in range(n): ret = torch.diff(ret, dim=dim) y[: ret.shape[0]] = ret return y
def diff(t, step=1): """ Centered difference operator on (batched) signals. """ dim = t.dim() - 1 dt = torch.diff(t) dt0 = dt.select(dim, 0).unsqueeze(dim) dt1 = dt.select(dim, -1).unsqueeze(dim) return (1 / (2 * step))\ * (torch.cat([dt0, dt], dim) + torch.cat([dt, dt1], dim))
def __call__(self, t): dj_t = t means = [] N = t.shape[0] for j in range(self.order + 1): means += [dj_t.mean()] dj_t = torch.diff(t) * N q = self.basis(self.order, N) delta = sum(mj * q[j] for j, mj in enumerate(means)) return t - delta
def check_xx(xx): """ Check whether xx is monotonically increasing from 0 to 1. """ if not xx[0] == 0 and xx[-1] == 1: raise ValueError('Input xx argument does not run from 0 to 1.' 'Have you passed in an incorrect argument?') if not torch.all(torch.diff(xx) >= 0): raise ValueError('Input xx argument is not monotonically increasing. ' 'Have you passed in an incorrect argument?')
def forward(self, past_traj, past_lens): # Convert to relative dynamics sequence. rel_past_traj = torch.diff(past_traj, dim=0, prepend=past_traj[:1]) # Trajectory Encoding past_traj_enc = self.spatial_emb(rel_past_traj) obs_traj_embedding = nn.utils.rnn.pack_padded_sequence(past_traj_enc, past_lens, enforce_sorted=False) output, states = super(AgentEncoderLSTM, self).forward(obs_traj_embedding) return output, states
def implicit_2Dx(phi, xx, yy, nu1, m12, gamma1, h1, dt, use_delj_trick): L = phi.shape[0] M = phi.shape[1] dx = torch.diff(xx) dfactor = torch.zeros(L) xInt = torch.zeros(L - 1) Integration_shared.compute_dfactor(dx, L, dfactor) Integration_shared.compute_xInt(xx, L, xInt) MInt = torch.zeros(L - 1) V = torch.zeros(L) VInt = torch.zeros(L - 1) delj = torch.zeros(L - 1) for ii in range(0, L): V[ii] = Integration_shared.Vfunc(xx[ii], nu1) for ii in range(0, L-1): VInt[ii] = Integration_shared.Vfunc(xInt[ii], nu1) a = torch.zeros(L) b = torch.zeros(L) c = torch.zeros(L) r = torch.zeros(L) temp = torch.zeros(L) for jj in range(0, M): y = yy[jj] Mfirst = Integration_shared.Mfunc2D(xx[0], y, m12, gamma1, h1) Mlast = Integration_shared.Mfunc2D(xx[L - 1], y, m12, gamma1, h1) for ii in range(0, L-1): MInt[ii] = Integration_shared.Mfunc2D(xInt[ii], y, m12, gamma1, h1) Integration_shared.compute_delj(dx, MInt, VInt, L, delj, use_delj_trick) Integration_shared.compute_abc_nobc(dx, dfactor, delj, MInt, V, dt, L, a, b, c) for ii in range(0, L): r[ii] = phi[ii][jj] / dt if jj == 0 and Mfirst <= 0: b[0] += (0.5 / nu1 - Mfirst) * 2. / dx[0] if jj == M - 1 and Mlast >= 0: b[L - 1] += -(-0.5 / nu1 - Mlast) * 2. / dx[L - 2] Integration.tridiag(a, b, c, r, temp, L) for ii in range(0, L): phi[ii][jj] = temp[ii]
def implicit_2Dy(phi, xx, yy, nu2, m21, gamma2, h2, dt, use_delj_trick): L = phi.shape[0] M = phi.shape[1] dy = torch.diff(yy) dfactor = torch.zeros(M) yInt = torch.zeros(M - 1) Integration_shared.compute_dfactor(dy, M, dfactor) Integration_shared.compute_xInt(yy, M, yInt) MInt = torch.zeros(M - 1) V = torch.zeros(M) VInt = torch.zeros(M - 1) delj = torch.zeros(M - 1) for jj in range(0, M): V[jj] = Integration_shared.Vfunc(yy[jj], nu2) for jj in range(0, M - 1): VInt[jj] = Integration_shared.Vfunc(yInt[jj], nu2) a = torch.zeros(L) b = torch.zeros(L) c = torch.zeros(L) r = torch.zeros(L) for ii in range(0, L): x = xx[ii] Mfirst = Integration_shared.Mfunc2D(yy[0], x, m21, gamma2, h2) Mlast = Integration_shared.Mfunc2D(yy[M-1], x, m21, gamma2, h2) for jj in range(0, M-1): MInt[jj] = Integration_shared.Mfunc2D(yInt[jj], x, m21, gamma2, h2) Integration_shared.compute_delj(dy, MInt, VInt, M, delj, use_delj_trick) Integration_shared.compute_abc_nobc(dy, dfactor, delj, MInt, V, dt, M, a, b, c) for jj in range(0, M): r[jj] = phi[ii][jj] / dt if ii == 0 and Mfirst <= 0: b[0] += (0.5 / nu2 - Mfirst) * 2. / dy[0] if ii == L-1 and Mlast >= 0: b[M-1] += -(-0.5 / nu2 - Mlast) * 2. / dy[M-2] Integration.tridiag(a, b, c, r, phi[ii], M)
def vertical_difference_transpose(x: torch.Tensor): """ Find the column differences in x for a vector, with extra flavor. Parameters ---------- x : torch.Tensor Returns ------- diff: torch.Tensor """ # TODO: examples and characterize output of this function, also more specific up top u0 = (-1.0 * x[0]).unsqueeze(0) u1 = (-1.0 * torch.diff(x, dim=0))[:-1] u2 = (x[-2]).unsqueeze(0) ret = torch.cat([u0, u1, u2], dim=0) return ret
def update(self, y_pred: torch.Tensor, s_c: torch.Tensor): assert y_pred.shape == s_c.shape if self.k is None: order = torch.argsort(input=y_pred, descending=True) else: sequence_length = y_pred.shape[0] if sequence_length < self.k: k = sequence_length else: k = self.k _, order = torch.topk(input=y_pred, k=k, largest=True) senti_score = torch.take(s_c, order) senti_score = torch.combinations(senti_score) senti_score = torch.abs(torch.diff(senti_score, dim=-1)) / 2 senti_score = torch.sum(senti_score) / senti_score.size(0) self.ils_senti += senti_score self.count += 1.0
def diff_dim0_replace_last_row(x: torch.Tensor) -> torch.Tensor: """ Find the single row differences in x and then put the second to last row as the last row in the result Parameters ---------- x : torch.Tensor Returns ------- diff_0_last_row: torch.Tensor the single row differences with the second to last row and the last row """ u0 = (-1.0 * x[0]).unsqueeze(0) u1 = (-1.0 * torch.diff(x, dim=0))[:-1] u2 = (x[-2]).unsqueeze(0) ret = torch.cat([u0, u1, u2], dim=0) return ret
def find_spikes(x, bounds=[-50, 150]): """ Return NaN intervals and a mask vanishing outside NaNs. Returns: -------- spikes: (N, 2) torch.LongTensor mask: torch.BoolTensor """ # spike indices mask = (x < bounds[0]) + (x > bounds[1]) idx = mask.nonzero().flatten() if not len(idx): return torch.tensor([]), mask # get boundaries d_idx = torch.diff(idx) > 1 true = torch.tensor([True]) d1 = torch.cat([true, d_idx]) d0 = torch.cat([d_idx, true]) # spike intervals spikes = torch.stack([idx[d1], idx[d0]]).T return spikes, mask
def note_segments(lf0_score_denorm): """Compute note segments (start and end indices) from log-F0 Note that unvoiced frames must be set to 0 in advance. Args: lf0_score_denorm (Tensor): (B, T) Returns: list: list of note (start, end) indices """ segments = [] for s, e in nonzero_segments(lf0_score_denorm): out = torch.sign(torch.abs(torch.diff(lf0_score_denorm[s:e + 1]))) transitions = torch.where(out > 0)[0] note_start, note_end = s, -1 for pos in transitions: note_end = int(s + pos) segments.append((note_start, note_end)) note_start = note_end return segments
def vertical_difference(x: torch.Tensor, n: int = 1): """ Find the row differences in x. This will return a torch.Tensor of the same shape as `x` with the requisite number of zeros added to the end (`n`). Parameters ---------- x: torch.Tensor input tensor n: int the number of rows between each row for which to calculate the diff Returns ------- torch.Tensor with the column differences. This result is the same size as `x`. """ y = torch.zeros_like(x) ret = x for _ in range(n): ret = torch.diff(ret, dim=0) y[: ret.shape[0]] = ret return y
def cumulative_hazard(self, params, t): M = torch.minimum(self.breakpoints, t) M = torch.diff(M) return (M * params).sum()
def from_init_params( cls, depth: int, w_0: int, w_a: float, w_m: float, group_width: int, bottleneck_multiplier: float = 1.0, se_ratio: Optional[float] = None, **kwargs: Any, ) -> "BlockParams": """ Programatically compute all the per-block settings, given the RegNet parameters. The first step is to compute the quantized linear block parameters, in log space. Key parameters are: - `w_a` is the width progression slope - `w_0` is the initial width - `w_m` is the width stepping in the log space In other terms `log(block_width) = log(w_0) + w_m * block_capacity`, with `bock_capacity` ramping up following the w_0 and w_a params. This block width is finally quantized to multiples of 8. The second step is to compute the parameters per stage, taking into account the skip connection and the final 1x1 convolutions. We use the fact that the output width is constant within a stage. """ QUANT = 8 STRIDE = 2 if w_a < 0 or w_0 <= 0 or w_m <= 1 or w_0 % 8 != 0: raise ValueError("Invalid RegNet settings") # Compute the block widths. Each stage has one unique block width widths_cont = torch.arange(depth) * w_a + w_0 block_capacity = torch.round(torch.log(widths_cont / w_0) / math.log(w_m)) block_widths = (torch.round(torch.divide(w_0 * torch.pow(w_m, block_capacity), QUANT)) * QUANT).int().tolist() num_stages = len(set(block_widths)) # Convert to per stage parameters split_helper = zip( block_widths + [0], [0] + block_widths, block_widths + [0], [0] + block_widths, ) splits = [w != wp or r != rp for w, wp, r, rp in split_helper] stage_widths = [w for w, t in zip(block_widths, splits[:-1]) if t] stage_depths = torch.diff(torch.tensor([d for d, t in enumerate(splits) if t])).int().tolist() strides = [STRIDE] * num_stages bottleneck_multipliers = [bottleneck_multiplier] * num_stages group_widths = [group_width] * num_stages # Adjust the compatibility of stage widths and group widths stage_widths, group_widths = cls._adjust_widths_groups_compatibilty( stage_widths, bottleneck_multipliers, group_widths ) return cls( depths=stage_depths, widths=stage_widths, group_widths=group_widths, bottleneck_multipliers=bottleneck_multipliers, strides=strides, se_ratio=se_ratio, )
def other_ops(self): a = torch.randn(4) b = torch.randn(4) c = torch.randint(0, 8, (5, ), dtype=torch.int64) e = torch.randn(4, 3) f = torch.randn(4, 4, 4) size = [0, 1] dims = [0, 1] return ( torch.atleast_1d(a), torch.atleast_2d(a), torch.atleast_3d(a), torch.bincount(c), torch.block_diag(a), torch.broadcast_tensors(a), torch.broadcast_to(a, (4)), # torch.broadcast_shapes(a), torch.bucketize(a, b), torch.cartesian_prod(a), torch.cdist(e, e), torch.clone(a), torch.combinations(a), torch.corrcoef(a), # torch.cov(a), torch.cross(e, e), torch.cummax(a, 0), torch.cummin(a, 0), torch.cumprod(a, 0), torch.cumsum(a, 0), torch.diag(a), torch.diag_embed(a), torch.diagflat(a), torch.diagonal(e), torch.diff(a), torch.einsum("iii", f), torch.flatten(a), torch.flip(e, dims), torch.fliplr(e), torch.flipud(e), torch.kron(a, b), torch.rot90(e), torch.gcd(c, c), torch.histc(a), torch.histogram(a), torch.meshgrid(a), torch.lcm(c, c), torch.logcumsumexp(a, 0), torch.ravel(a), torch.renorm(e, 1, 0, 5), torch.repeat_interleave(c), torch.roll(a, 1, 0), torch.searchsorted(a, b), torch.tensordot(e, e), torch.trace(e), torch.tril(e), torch.tril_indices(3, 3), torch.triu(e), torch.triu_indices(3, 3), torch.vander(a), torch.view_as_real(torch.randn(4, dtype=torch.cfloat)), torch.view_as_complex(torch.randn(4, 2)), torch.resolve_conj(a), torch.resolve_neg(a), )
def get_timestep(t: Tensor): assert torch.isclose(torch.diff(t), t[1] - t[0]).all(), 'time values are not evenly spaced' return t[1] - t[0]
def diff(cls, t): return t.shape[0] * torch.cat( [torch.diff(t), torch.tensor([t[0] - t[-1]])])
def test_pytorch_scatter_test_cases(self, device, dtypes, reduce): val_dtype, length_dtype = dtypes # zero-length segments are filled with reduction inits contrary to pytorch_scatter. tests = [ { 'src': [1, 2, 3, 4, 5, 6], 'index': [0, 0, 1, 1, 1, 3], 'indptr': [0, 2, 5, 5, 6], 'sum': [3, 12, 0, 6], 'prod': [2, 60, 1, 6], 'mean': [1.5, 4, float('nan'), 6], 'min': [1, 3, float('inf'), 6], 'max': [2, 5, -float('inf'), 6], }, { 'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], 'index': [0, 0, 1, 1, 1, 3], 'indptr': [0, 2, 5, 5, 6], 'sum': [[4, 6], [21, 24], [0, 0], [11, 12]], 'prod': [[3, 8], [315, 480], [1, 1], [11, 12]], 'mean': [[2, 3], [7, 8], [float('nan'), float('nan')], [11, 12]], 'min': [[1, 2], [5, 6], [float('inf'), float('inf')], [11, 12]], 'max': [[3, 4], [9, 10], [-float('inf'), -float('inf')], [11, 12]], }, { 'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]], 'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]], 'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]], 'sum': [[4, 21, 0, 11], [12, 18, 12, 0]], 'prod': [[3, 315, 1, 11], [48, 80, 12, 1]], 'mean': [[2, 7, float('nan'), 11], [4, 9, 12, float('nan')]], 'min': [[1, 5, float('inf'), 11], [2, 8, 12, float('inf')]], 'max': [[3, 9, -float('inf'), 11], [6, 10, 12, -float('inf')]], }, { 'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]], 'index': [[0, 0, 1], [0, 2, 2]], 'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]], 'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], 'prod': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 143]]], 'mean': [[[2, 3], [5, 6], [float('nan'), float('nan')]], [[7, 9], [float('nan'), float('nan')], [11, 12]]], 'min': [[[1, 2], [5, 6], [float('inf'), float('inf')]], [[7, 9], [float('inf'), float('inf')], [10, 11]]], 'max': [[[3, 4], [5, 6], [-float('inf'), -float('inf')]], [[7, 9], [-float('inf'), -float('inf')], [12, 13]]], }, { 'src': [[1, 3], [2, 4]], 'index': [[0, 0], [0, 0]], 'indptr': [[0, 2], [0, 2]], 'sum': [[4], [6]], 'prod': [[3], [8]], 'mean': [[2], [3]], 'min': [[1], [2]], 'max': [[3], [4]], }, { 'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]], 'index': [[0, 0], [0, 0]], 'indptr': [[0, 2], [0, 2]], 'sum': [[[4, 4]], [[6, 6]]], 'prod': [[[3, 3]], [[8, 8]]], 'mean': [[[2, 2]], [[3, 3]]], 'min': [[[1, 1]], [[2, 2]]], 'max': [[[3, 3]], [[4, 4]]], }, ] for test in tests: data = torch.tensor(test['src'], dtype=val_dtype, device=device, requires_grad=True) indptr = torch.tensor(test['indptr'], dtype=length_dtype, device=device) dim = indptr.ndim - 1 # calculate lengths from indptr lengths = torch.diff(indptr, dim=dim) expected = torch.tensor(test[reduce], dtype=val_dtype, device=device) actual_result = torch.segment_reduce( data=data, reduce=reduce, lengths=lengths, axis=dim, unsafe=True, ) self.assertEqual(actual_result, expected) # test offsets actual_result = torch.segment_reduce( data=data, reduce=reduce, offsets=indptr, axis=dim, unsafe=True, ) self.assertEqual(actual_result, expected) if val_dtype == torch.float64: def fn(x, mode='lengths'): initial = 1 # supply initial values to prevent gradcheck from failing for 0 length segments # where nan/inf are reduction identities that produce nans when calculating the numerical jacobian if reduce == 'min': initial = 1000 elif reduce == 'max': initial = -1000 segment_reduce_args = {x, reduce} segment_reduce_kwargs = dict(axis=dim, unsafe=True, initial=initial) if mode == 'lengths': segment_reduce_kwargs[mode] = lengths elif mode == 'offsets': segment_reduce_kwargs[mode] = indptr return torch.segment_reduce(*segment_reduce_args, **segment_reduce_kwargs) self.assertTrue( gradcheck(partial(fn, mode='lengths'), (data.clone().detach().requires_grad_(True)))) self.assertTrue( gradcheck(partial(fn, mode='offsets'), (data.clone().detach().requires_grad_(True))))
def quantization_y(x): """ Return Y-quantization, should be below .1 """ dx = torch.diff(x).abs() nz = dx.nonzero().flatten() return dx[nz].min()
def parametrize_slopes(slopes: torch.Tensor) -> torch.Tensor: slopes_parametrized = torch.diff(slopes, dim=-1) slopes_parametrized = torch.cat( [slopes[..., 0:1], slopes_parametrized], dim=-1) return slopes_parametrized
def find_lambda(cvInd, lambda_range): cv_folds = cvInd.unique() K = cv_folds.shape[0] assert K > 1, 'Must have at least two CV folds' W_cv, est_noise_cv, val_noise_cv = [], [], [] def visit(intern_lamb): loss = 0 for cv_iter2 in range(K): estC = estimate_cov(est_noise_cv[cv_iter2], intern_lamb[0], intern_lamb[1], W_cv[cv_iter2]) vncv = val_noise_cv[cv_iter2] valC = torch.mm(vncv.t(), vncv) / vncv.shape[ 0] #sample covariance of validation data loss += fun_norm_loss(estC, valC) if loss.is_complex(): loss = torch.Tensor(inf) return loss # Pre-compute tuning weights and noise values to use in each cross-validation split for cv_iter in range(K): val_trials = cvInd == cv_folds[cv_iter] est_trials = val_trials.logical_not() est_samples = train_samples[est_trials, :] val_samples = train_samples[val_trials, :] this_W_cv, this_est_noise_cv, this_val_noise_cv = estimate_W( est_samples, Ctrain[est_trials, :, 0], False, val_samples, Ctrain[val_trials, :, 0]) W_cv.append(this_W_cv) est_noise_cv.append(this_est_noise_cv) val_noise_cv.append(this_val_noise_cv) # Grid search s = [x.numel() for x in lambda_range] Ngrid = [ min(max(2, ceil(sqrt(x))), x) for x in s ] #Number of values to visit in each dimension (has to be at least 2, except if there is only 1 value for that dimension) grid_vec = [ torch.linspace(0, y - 1, x).int() for x, y in zip(Ngrid, s) ] grid_x, grid_y = torch.meshgrid(grid_vec[0], grid_vec[1]) grid_l1, grid_l2 = torch.meshgrid(lambda_range[0], lambda_range[1]) grid_l1, grid_l2 = grid_l1.flatten(), grid_l2.flatten() sz = s.copy() sz.reverse() print('--GRID SEARCH--') losses = torch.empty(grid_x.numel(), 1) for grid_iter in range(grid_x.numel()): this_lambda = torch.Tensor( (lambda_range[0][grid_x.flatten()[grid_iter]], lambda_range[1][grid_y.flatten()[grid_iter]])) losses[grid_iter] = visit(this_lambda) print( "{:02d}/{:02d} -- lambda_var: {:3.2f}, lambda: {:3.2f}, loss: {:g}" .format(grid_iter, grid_x.numel(), *this_lambda, losses[grid_iter].item())) visited = sub2ind(sz, grid_y.flatten().tolist(), grid_x.flatten().tolist()) best_loss, best_idx = losses.min(0) best_idx = visited[best_idx] best_lambda_gridsearch = (grid_l1[best_idx], grid_l2[best_idx]) print( 'Best lambda setting from grid search: lambda_var = {:3.2f}, lambda = {:3.2f}, loss = {:g}' .format(*best_lambda_gridsearch, best_loss.item())) # Pattern search print('--PATTERN SEARCH--') step_size = int( 2**floor(log2(torch.diff(grid_y[0][0:2]) / 2)) ) #Round down to the nearest power of 2 (so we can keep dividing the step size in half while True: best_y, best_x = ind2sub(sz, best_idx) new_x = best_x + torch.Tensor((-1, 1, -1, 1)).int() * step_size new_y = best_y + torch.Tensor((-1, -1, 1, 1)).int() * step_size del_idx = torch.logical_or( torch.logical_or(new_x < 0, new_x >= lambda_range[0].numel()), torch.logical_or(new_y < 0, new_y >= lambda_range[1].numel())) new_x = new_x[del_idx.logical_not()] new_y = new_y[del_idx.logical_not()] new_idx = sub2ind(sz, new_y.tolist(), new_x.tolist()) not_visited = [x not in visited for x in new_idx] new_idx = [i for (i, v) in zip(new_idx, not_visited) if v] if len(new_idx) > 0: this_losses = torch.empty(len(new_idx)) for ii in range(len(new_idx)): this_lambda = torch.Tensor( (grid_l1[new_idx[ii]], grid_l2[new_idx[ii]])) this_losses[ii] = visit(this_lambda) print( "Step size: {:d}, lambda_var: {:3.2f}, lambda: {:3.2f}, loss: {:g}" .format(step_size, *this_lambda, this_losses[ii].item())) visited.extend(new_idx) # visited = torch.cat((visited, torch.tensor(new_idx)),0) losses = torch.cat((losses, this_losses.unsqueeze(-1)), 0) if (this_losses < best_loss).any(): best_loss, best_idx = losses.min(0) best_idx = visited[best_idx] elif step_size > 1: step_size = int(step_size / 2) else: break best_lambda = torch.Tensor((grid_l1[best_idx], grid_l2[best_idx])) print( "Best setting found: lambda_var = {:3.2f}, lambda = {:3.2f}, loss: {:g}" .format(*best_lambda, best_loss.item())) return best_lambda
def cwt( data: torch.Tensor, scales: Union[np.ndarray, torch.Tensor], # type: ignore wavelet: Union[ContinuousWavelet, str], sampling_period: float = 1.0, ) -> Tuple[torch.Tensor, np.ndarray]: # type: ignore """Compute the single dimensional continuous wavelet transform. This function is a PyTorch port of pywt.cwt as found at: https://github.com/PyWavelets/pywt/blob/master/pywt/_cwt.py Args: data (torch.Tensor): The input tensor of shape [batch_size, time]. scales (torch.Tensor or np.array): The wavelet scales to use. One can use ``f = pywt.scale2frequency(wavelet, scale)/sampling_period`` to determine what physical frequency, ``f``. Here, ``f`` is in hertz when the ``sampling_period`` is given in seconds. wavelet (str or Wavelet of ContinuousWavelet): The wavelet to work with. wavelet (ContinuousWavelet or str): The continuous wavelet to work with. sampling_period (float): Sampling period for the frequencies output (optional). The values computed for ``coefs`` are independent of the choice of ``sampling_period`` (i.e. ``scales`` is not scaled by the sampling period). Raises: ValueError: If a scale is too small for the input signal. Returns: Tuple[torch.Tensor, np.ndarray]: A tuple with the transformation matrix and frequencies in this order. """ # accept array_like input; make a copy to ensure a contiguous array if not isinstance(wavelet, (ContinuousWavelet, Wavelet)): wavelet = DiscreteContinuousWavelet(wavelet) if type(scales) is torch.Tensor: scales = scales.numpy() elif np.isscalar(scales): scales = np.array([scales]) # if not np.isscalar(axis): # raise np.AxisError("axis must be a scalar.") precision = 10 int_psi, x = integrate_wavelet(wavelet, precision=precision) if type(wavelet) is ContinuousWavelet: int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi int_psi = torch.tensor(int_psi, device=data.device) # convert int_psi, x to the same precision as the data x = np.asarray(x, dtype=data.cpu().numpy().real.dtype) size_scale0 = -1 fft_data = None out = [] for scale in scales: step = x[1] - x[0] j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step) j = j.astype(int) # floor if j[-1] >= len(int_psi): j = np.extract(j < len(int_psi), j) int_psi_scale = int_psi[j].flip(0) # The padding is selected for: # - optimal FFT complexity # - to be larger than the two signals length to avoid circular # convolution size_scale = _next_fast_len(data.shape[-1] + len(int_psi_scale) - 1) if size_scale != size_scale0: # Must recompute fft_data when the padding size changes. fft_data = fft(data, size_scale, dim=-1) size_scale0 = size_scale fft_wav = fft(int_psi_scale, size_scale, dim=-1) conv = ifft(fft_wav * fft_data, dim=-1) conv = conv[..., :data.shape[-1] + len(int_psi_scale) - 1] coef = -np.sqrt(scale) * torch.diff(conv, dim=-1) # transform axis is always -1 d = (coef.shape[-1] - data.shape[-1]) / 2.0 if d > 0: coef = coef[..., int(np.floor(d)):-int(np.ceil(d))] elif d < 0: raise ValueError("Selected scale of {} too small.".format(scale)) out.append(coef) out_tensor = torch.stack(out) if type(wavelet) is Wavelet: out_tensor = out_tensor.real else: out_tensor = out_tensor if wavelet.complex_cwt else out_tensor.real frequencies = scale2frequency(wavelet, scales, precision) if np.isscalar(frequencies): frequencies = np.array([frequencies]) frequencies /= sampling_period return out_tensor, frequencies