def impl(data, wavelet, mode, axis):
    """Single-level discrete wavelet decomposition of ``data`` along ``axis``.

    Every 1-D slice of ``data`` taken along ``axis`` is convolved with the
    wavelet's decomposition filters (``dec_lo`` / ``dec_hi``) via
    ``downsampling_convolution`` with a step of 2.

    Parameters
    ----------
    data : array
        Input samples.
    wavelet : object
        Object exposing ``dec_lo`` and ``dec_hi`` filter arrays.
    mode : object
        Signal-extension mode, forwarded unchanged to ``dwt_coeff_length``
        and ``downsampling_convolution`` (exact type not visible here).
    axis : int
        Axis to transform. Must satisfy ``0 <= axis < data.ndim``; negative
        axes are rejected, not normalised.

    Returns
    -------
    tuple
        ``(ca, cd)``: approximation and detail coefficients. Both have
        ``data``'s dtype and ``data``'s shape, except that dimension
        ``axis`` has length
        ``dwt_coeff_length(data.shape[axis], len(wavelet.dec_hi), mode)``.

    Raises
    ------
    ValueError
        If ``axis`` is outside ``[0, data.ndim)``.
    """
    # Validate the axis *before* using it to index ``data.shape`` or to
    # rebuild the shape tuple with ``tuple_setitem``. Previously, the check
    # ran only after both uses, so a bad axis failed there first.
    if axis < 0 or axis >= data.ndim:
        raise ValueError("0 <= axis < data.ndim failed")

    coeff_len = dwt_coeff_length(data.shape[axis], len(wavelet.dec_hi), mode)
    out_shape = tuple_setitem(data.shape, axis, coeff_len)

    ca = np.empty(out_shape, dtype=data.dtype)
    cd = np.empty(out_shape, dtype=data.dtype)

    # Iterate over all points except along the slicing axis
    for idx in np.ndindex(*tuple_setitem(data.shape, axis, 1)):
        initial_in_row = slice_axis(data, idx, axis, None)
        initial_ca_row = slice_axis(ca, idx, axis, None)
        initial_cd_row = slice_axis(cd, idx, axis, None)

        # The numba array type returned by slice_axis assumes
        # non-contiguity in the general case.
        # However, the slice may actually be contiguous in layout.
        # If so, cast the array type to obtain type contiguity;
        # else, copy the slice to obtain contiguity in
        # both type and layout.
        if initial_in_row.flags.c_contiguous:
            in_row = force_type_contiguity(initial_in_row)
        else:
            in_row = initial_in_row.copy()

        if initial_ca_row.flags.c_contiguous:
            ca_row = force_type_contiguity(initial_ca_row)
        else:
            ca_row = initial_ca_row.copy()

        if initial_cd_row.flags.c_contiguous:
            cd_row = force_type_contiguity(initial_cd_row)
        else:
            cd_row = initial_cd_row.copy()

        # Compute the approximation and detail coefficients
        downsampling_convolution(in_row, ca_row, wavelet.dec_lo, mode, 2)
        downsampling_convolution(in_row, cd_row, wavelet.dec_hi, mode, 2)

        # If a temporary copy was used, write its results back into the
        # strided output view.
        if not initial_ca_row.flags.c_contiguous:
            initial_ca_row[:] = ca_row[:]

        if not initial_cd_row.flags.c_contiguous:
            initial_cd_row[:] = cd_row[:]

    return ca, cd
def fn(A):
    """Verify that ``slice_axis`` reports contiguity consistently.

    For every axis of ``A``, and every 1-D slice of ``A`` along that axis,
    the slice's ``flags.c_contiguous`` must equal whether its first stride
    equals its item size.

    Raises
    ------
    ValueError
        On the first slice where the flag and the stride disagree.
    """
    for ax in range(A.ndim):
        # Shape of the iteration space: every dimension except ``ax``.
        outer_shape = tuple_setitem(A.shape, ax, 1)

        for index in np.ndindex(*outer_shape):
            row = slice_axis(A, index, ax, None)
            unit_stride = row.strides[0] == row.itemsize

            if row.flags.c_contiguous != unit_stride:
                raise ValueError("contiguity flag doesn't match layout")
def impl(approx_coeffs, detail_coeffs, wavelet, mode, axis):
    """Single-level inverse discrete wavelet transform along ``axis``.

    Reconstructs an array by applying ``upsampling_convolution_valid_sf``
    with ``wavelet.rec_lo`` to the approximation coefficients and with
    ``wavelet.rec_hi`` to the detail coefficients. This is done for every
    1-D slice along ``axis``, and both passes write into the same output
    row.

    Parameters
    ----------
    approx_coeffs : array or None
        Approximation coefficients.
    detail_coeffs : array or None
        Detail coefficients. When both coefficient arrays are present,
        every dimension is clipped to the smaller of the two sizes.
    wavelet : object
        Object exposing ``rec_lo`` and ``rec_hi`` filter arrays.
    mode : object
        Signal-extension mode, forwarded to ``idwt_buffer_length`` and
        ``upsampling_convolution_valid_sf``.
    axis : int
        Axis to reconstruct along. Must satisfy
        ``0 <= axis < len(coeff_shape)``.

    Returns
    -------
    array
        Array of dtype ``out_dtype``. Its shape is the (clipped) coefficient
        shape, except that dimension ``axis`` has length
        ``idwt_buffer_length(coeff_shape[axis], len(rec_lo), mode)``.

    Raises
    ------
    ValueError
        If neither coefficient array is present, or ``axis`` is out of range.

    Notes
    -----
    ``have_approx``, ``have_detail`` and ``out_dtype`` are not parameters.
    They come from the enclosing scope, which is not visible here.
    Presumably they are fixed when the kernel is specialised, based on
    whether each coefficient argument is None and on the result dtype.
    NOTE(review): confirm against the enclosing factory. The exact structure
    of these branches (and of the ``is not None`` tests below) likely
    matters for Numba's compile-time branch pruning, so do not restructure
    them casually.
    """
    if have_approx and have_detail:
        coeff_shape = approx_coeffs.shape
        it = enumerate(zip(approx_coeffs.shape, detail_coeffs.shape))

        # NOTE(sjperkins)
        # Clip the coefficient dimensions to the smallest dimensions
        # pywt clips in waverecn and fails in idwt and idwt_axis
        # on heterogenous coefficient shapes.
        # The actual clipping is performed in slice_axis
        for i, (asize, dsize) in it:
            size = asize if asize < dsize else dsize
            coeff_shape = tuple_setitem(coeff_shape, i, size)

    elif have_approx:
        coeff_shape = approx_coeffs.shape
    elif have_detail:
        coeff_shape = detail_coeffs.shape
    else:
        raise ValueError("Either approximation or detail must be present")

    # Reject negative and too-large axes; they are not normalised.
    if not (0 <= axis < len(coeff_shape)):
        raise ValueError(("0 <= axis < coeff.ndim does not hold"))

    idwt_len = idwt_buffer_length(coeff_shape[axis],
                                  wavelet.rec_lo.shape[0],
                                  mode)
    out_shape = tuple_setitem(coeff_shape, axis, idwt_len)
    output = np.empty(out_shape, dtype=out_dtype)

    # Iterate over all points except along the slicing axis.
    # The iteration space comes from ``output.shape`` (the clipped shape),
    # so trailing elements of a larger coefficient array are never visited
    # in the non-axis dimensions.
    for idx in np.ndindex(*tuple_setitem(output.shape, axis, 1)):
        initial_out_row = slice_axis(output, idx, axis, None)

        # Zero if we have a contiguous slice, else allocate.
        # The row must start at zero because both convolution calls below
        # write into the same out_row (presumably accumulating; see
        # upsampling_convolution_valid_sf).
        if initial_out_row.flags.c_contiguous:
            out_row = force_type_contiguity(initial_out_row)
            out_row[:] = 0
        else:
            out_row = np.zeros_like(initial_out_row)

        # Apply approximation coefficients if they exist
        if approx_coeffs is not None:
            # The extent argument clips the row to coeff_shape[axis]
            initial_ca_row = slice_axis(approx_coeffs, idx, axis,
                                        coeff_shape[axis])

            # Cast to a contiguous type if the layout allows, else copy
            if initial_ca_row.flags.c_contiguous:
                ca_row = force_type_contiguity(initial_ca_row)
            else:
                ca_row = initial_ca_row.copy()

            upsampling_convolution_valid_sf(ca_row, wavelet.rec_lo,
                                            out_row, mode)

        # Apply detail coefficients if they exist
        if detail_coeffs is not None:
            initial_cd_row = slice_axis(detail_coeffs, idx, axis,
                                        coeff_shape[axis])

            if initial_cd_row.flags.c_contiguous:
                cd_row = force_type_contiguity(initial_cd_row)
            else:
                cd_row = initial_cd_row.copy()

            upsampling_convolution_valid_sf(cd_row, wavelet.rec_hi,
                                            out_row, mode)

        # Copy back output row if the output space was non-contiguous
        if not initial_out_row.flags.c_contiguous:
            initial_out_row[:] = out_row

    return output
def fn(a, index, axis=1, extent=None):
    """Return ``slice_axis(a, index, axis, extent)``.

    Thin pass-through that supplies defaults of ``axis=1`` and
    ``extent=None``.
    """
    sliced = slice_axis(a, index, axis, extent)
    return sliced