def _correlate(in1, in2, mode='full', method='auto', convolution=False):
    """Correlate (or convolve) ``in1`` with ``in2`` using the chosen method.

    Parameters
    ----------
    in1, in2 : array
        Input arrays. They are forwarded to the ``_st_core`` helpers and,
        on the FFT path, to ``fftconvolve``.
    mode : str, optional
        Output-size mode; passed through unchanged to the helpers.
    method : {'auto', 'direct', 'fft'}, optional
        ``'auto'`` delegates the decision to ``choose_conv_method``.
    convolution : bool, optional
        When false, ``in2`` is passed through ``_st_core._reverse_and_conj``
        before the FFT convolution (correlation expressed as a convolution).

    Returns
    -------
    array
        Either the early result from ``_st_core._check_conv_inputs``, the
        result of ``_st_core._direct_correlate``, or the FFT-based result
        cast to ``cupy.result_type`` of the two operands.

    Raises
    ------
    ValueError
        If ``method`` is not ``'auto'``, ``'direct'`` or ``'fft'``.
    """
    # The input check may already produce the final answer.
    shortcut = _st_core._check_conv_inputs(in1, in2, mode, convolution)
    if shortcut is not None:
        return shortcut

    if method not in ('auto', 'direct', 'fft'):
        raise ValueError('acceptable methods are "auto", "direct", or "fft"')

    chosen = method
    if chosen == 'auto':
        chosen = choose_conv_method(in1, in2, mode=mode)

    if chosen == 'direct':
        return _st_core._direct_correlate(
            in1, in2, mode, in1.dtype, convolution)

    # Everything that is not 'direct' is handled by the FFT path below.
    swapped = _st_core._inputs_swap_needed(mode, in1.shape, in2.shape)
    first, second = (in2, in1) if swapped else (in1, in2)

    correlating = not convolution
    if correlating:
        second = _st_core._reverse_and_conj(second)

    result = fftconvolve(first, second, mode)

    # Cast back to the combined input dtype; integer kinds are rounded
    # first so the cast does not simply truncate the FFT output.
    out_dtype = cupy.result_type(first, second)
    if out_dtype.kind in 'ui':
        result = result.round()
    result = result.astype(out_dtype, copy=False)

    # Undo the operand swap for correlation by reversing/conjugating again.
    if correlating and swapped:
        result = cupy.ascontiguousarray(_st_core._reverse_and_conj(result))
    return result
def _correlate(in1, in2, mode='full', method='auto', convolution=False):
    """Correlate (or convolve) ``in1`` with ``in2``; direct method only.

    Parameters
    ----------
    in1, in2 : array
        Input arrays; forwarded to the ``_st_core`` helpers.
    mode : str, optional
        Output-size mode; passed through unchanged to the helpers.
    method : {'auto', 'direct', 'fft'}, optional
        ``'auto'`` delegates the decision to ``choose_conv_method``.
    convolution : bool, optional
        Forwarded to ``_st_core._check_conv_inputs`` and
        ``_st_core._direct_correlate``.

    Returns
    -------
    array
        The early result from ``_st_core._check_conv_inputs`` when it is
        not ``None``, otherwise the result of
        ``_st_core._direct_correlate``.

    Raises
    ------
    ValueError
        If ``method`` is not an accepted name, or if the method that ends
        up selected is anything other than ``'direct'`` (the FFT path is
        not available in this implementation).
    """
    # The input check may already produce the final answer.
    shortcut = _st_core._check_conv_inputs(in1, in2, mode, convolution)
    if shortcut is not None:
        return shortcut

    if method not in ('auto', 'direct', 'fft'):
        raise ValueError("acceptable methods are 'auto', 'direct', or 'fft'")

    chosen = method
    if chosen == 'auto':
        chosen = choose_conv_method(in1, in2, mode=mode)

    if chosen != 'direct':
        # Only the direct algorithm is implemented here.
        raise ValueError('fftconvolve currently not supported')

    return _st_core._direct_correlate(
        in1, in2, mode, in1.dtype, convolution)
def _correlate2d(in1, in2, mode, boundary, fillvalue, convolution=False):
    """Validate 2-D inputs and dispatch to the direct correlation routine.

    Parameters
    ----------
    in1, in2 : array
        Input arrays; both must report ``ndim == 2``.
    mode : str
        Output-size mode; passed through unchanged to the helpers.
    boundary : str
        Boundary flag. ``'fill'``/``'pad'`` map to ``'constant'``,
        ``'wrap'``/``'circular'`` map to ``'wrap'`` and
        ``'symm'``/``'symmetric'`` map to ``'reflect'`` before being handed
        to ``_st_core._direct_correlate``.
    fillvalue : scalar
        Forwarded unchanged to ``_st_core._direct_correlate``.
    convolution : bool, optional
        Selects the name used in the error message and is forwarded to the
        helpers; its negation is passed as the final positional argument of
        ``_st_core._direct_correlate``.

    Returns
    -------
    array
        The early result from ``_st_core._check_conv_inputs`` when it is
        not ``None``, otherwise the result of
        ``_st_core._direct_correlate``.

    Raises
    ------
    ValueError
        If either input is not 2-D or ``boundary`` is not a known flag.
    """
    if (in1.ndim, in2.ndim) != (2, 2):
        caller = 'convolve2d' if convolution else 'correlate2d'
        raise ValueError('{} inputs must both be 2-D arrays'.format(caller))

    # User-facing boundary aliases grouped by the padding mode they select.
    alias_groups = (
        ('constant', ('fill', 'pad')),
        ('wrap', ('wrap', 'circular')),
        ('reflect', ('symm', 'symmetric')),
    )
    pad_mode_for = {
        alias: pad_mode
        for pad_mode, aliases in alias_groups
        for alias in aliases
    }
    if boundary not in pad_mode_for:
        raise ValueError(
            'Acceptable boundary flags are "fill" (or "pad"), "circular" '
            '(or "wrap"), and "symmetric" (or "symm").')
    pad_mode = pad_mode_for[boundary]

    # The input check may already produce the final answer.
    shortcut = _st_core._check_conv_inputs(in1, in2, mode, convolution)
    if shortcut is not None:
        return shortcut

    return _st_core._direct_correlate(
        in1, in2, mode, in1.dtype, convolution,
        pad_mode, fillvalue, not convolution)