def ifft(self, fu, u):
    """Inverse Fourier transforms in y and z.

    The spectral data in ``fu`` is first exchanged between all processes
    with a single ``Alltoall``, then rearranged into the transposed layout
    given by ``self.complex_shape_T()``, and finally transformed by a 2D
    inverse real FFT over axes 1 and 2.

    Parameters
    ----------
    fu : array
        Send buffer for the ``Alltoall`` exchange (spectral data).
    u : array
        Output array handed to ``irfft2``.

    Returns
    -------
    array
        Whatever ``irfft2`` returns for the output argument ``u``.
    """
    transposed_shape = self.complex_shape_T()
    per_process_shape = (self.num_processes, self.Np[0], self.Np[1], self.Nf)

    # Scratch buffers are looked up in the work-array cache rather than
    # allocated on every call.
    recv_buf = self.work_arrays[(per_process_shape, self.complex, 0)]
    transposed = self.work_arrays[(transposed_shape, self.complex, 0)]

    # Global redistribution: each process receives one chunk from every
    # other process, stacked along the leading axis of recv_buf.
    self.comm.Alltoall([fu, self.mpitype], [recv_buf, self.mpitype])

    # Move the per-process axis behind the first local axis and collapse
    # the result into the transposed shape.
    transposed[:] = rollaxis(recv_buf, 1).reshape(transposed_shape)

    return irfft2(
        transposed,
        u,
        axes=(1, 2),
        threads=self.threads,
        planner_effort=self.planner_effort['irfft2'],
    )
def backward(self, fu, u, fun, dealias=None):
    """Transform spectral coefficients ``fu`` to physical space ``u``.

    ``fun.backward`` is applied first (the transform in the x-direction,
    axis 0), followed by inverse Fourier transforms along axes 1 and 2
    (y and z). With more than one process the data is redistributed
    between the two stages using either ``Alltoall`` or ``Alltoallw``,
    selected by ``self.communication``.

    Parameters
    ----------
    fu : array
        Spectral coefficients. This method does not scale or filter ``fu``
        in place: the 2/3-rule filter is applied to a scratch copy and the
        3/2-rule scaling creates a temporary.
    u : array
        Output array handed to the final ``irfft2``/``irfft`` call.
    fun : object
        Object exposing ``backward(input_array, output_array)``; its return
        value is used as the transformed array.
    dealias : {None, '2/3-rule', '3/2-rule'}, optional
        * ``None``: no dealiasing.
        * ``'2/3-rule'``: the coefficients are multiplied by
          ``self.dealias``, which is built with
          ``self.get_dealias_filter()`` on first use.
        * ``'3/2-rule'``: the coefficients are scaled by
          ``self.padsize**2`` and copied into padded work arrays before
          the inverse FFTs. If ``self.dealias_cheb`` is true, padding is
          also applied along axis 0 before ``fun.backward``.

    Returns
    -------
    array
        The result of the final inverse real FFT (written to ``u``).

    Raises
    ------
    ValueError
        If more than one process is used and ``self.communication`` is
        neither ``'Alltoall'`` nor ``'Alltoallw'``.
    AssertionError
        For the 3/2-rule with ``self.dealias_cheb`` on more than
        ``self.N[0] / 2`` processes.
    """
    Uc_hat = self.work_arrays[(self.complex_shape(), self.complex, 0, False)]
    # Per-process view of the same scratch buffer. It is read after the
    # in-place Alltoall below.
    # NOTE(review): this relies on fun.backward returning the output array
    # it was given, so that the view still aliases Uc_hat - confirm.
    Uc_mpi = Uc_hat.reshape(
        (self.num_processes, self.Np[0], self.Np[1], self.Nf))
    fun = fun.backward

    # Build the 2/3-rule filter lazily; an array of shape (0,) marks
    # "not built yet".
    if dealias == '2/3-rule' and self.dealias.shape == (0,):
        self.dealias = self.get_dealias_filter()

    # ------------------------------------------------------------------
    # Single process: no communication required
    # ------------------------------------------------------------------
    if self.num_processes == 1:
        if dealias != '3/2-rule':
            fup = fu
            if dealias == '2/3-rule':
                # Filter a scratch copy so the caller's array is untouched.
                fup = self.work_arrays[(fu, 1, False)]
                fup[:] = fu
                fup *= self.dealias
            Uc_hat = fun(fup, Uc_hat)
            u = irfft2(Uc_hat, u, axes=(1, 2), overwrite_input=True,
                       threads=self.threads,
                       planner_effort=self.planner_effort['irfft2'])

        elif not self.dealias_cheb:
            Upad_hat = self.work_arrays[(self.complex_shape_padded(),
                                         self.complex, 0)]
            Upad_hat_z = self.work_arrays[
                ((self.Np[0], int(self.padsize * self.N[1]), self.Nf),
                 self.complex, 0)]
            Uc_hat = fun(fu * self.padsize**2, Uc_hat)
            # Pad in y-direction and transform along y
            Upad_hat_z = SlabShen_R2C.copy_to_padded(
                Uc_hat, Upad_hat_z, self.N, 1)
            Upad_hat_z[:] = ifft(Upad_hat_z, axis=1, threads=self.threads,
                                 planner_effort=self.planner_effort['ifft'])
            # Pad in z-direction and perform final irfft
            Upad_hat = SlabShen_R2C.copy_to_padded(
                Upad_hat_z, Upad_hat, self.N, 2)
            u = irfft(Upad_hat, u, axis=2, overwrite_input=True,
                      threads=self.threads,
                      planner_effort=self.planner_effort['irfft'])

        else:
            # Intermediate work arrays required for transform
            Upad_hat = self.work_arrays[(self.complex_shape_padded_0(),
                                         self.complex, 0, False)]
            Upad_hat0 = self.work_arrays[(self.complex_shape_padded_0(),
                                          self.complex, 1)]
            Upad_hat2 = self.work_arrays[(self.complex_shape_padded_2(),
                                          self.complex, 0)]
            Upad_hat3 = self.work_arrays[(self.complex_shape_padded_3(),
                                          self.complex, 0)]

            # Expand in x-direction and perform ifst
            Upad_hat0 = SlabShen_R2C.copy_to_padded(
                fu * self.padsize**2, Upad_hat0, self.N, 0)
            Upad_hat = fun(Upad_hat0, Upad_hat)

            # Pad in y-direction and transform along y
            Upad_hat2 = SlabShen_R2C.copy_to_padded(
                Upad_hat, Upad_hat2, self.N, 1)
            Upad_hat2[:] = ifft(Upad_hat2, axis=1, threads=self.threads,
                                planner_effort=self.planner_effort['ifft'])

            # pad in z-direction and perform final irfft
            Upad_hat3 = SlabShen_R2C.copy_to_padded(
                Upad_hat2, Upad_hat3, self.N, 2)
            u = irfft(Upad_hat3, u, axis=2, overwrite_input=True,
                      threads=self.threads,
                      planner_effort=self.planner_effort['irfft'])
        return u

    # ------------------------------------------------------------------
    # Several processes: a global redistribution sits between the stages
    # ------------------------------------------------------------------
    # Fail before any collective call: with an unsupported setting no
    # redistribution would run and the final FFT would transform whatever
    # the scratch array happened to contain.
    if self.communication not in ('Alltoall', 'Alltoallw'):
        raise ValueError(
            "Unsupported communication {!r}; expected 'Alltoall' or "
            "'Alltoallw'".format(self.communication))

    if dealias != '3/2-rule':
        Uc_hatT = self.work_arrays[(self.complex_shape_T(),
                                    self.complex, 0, False)]
        fup = fu
        if dealias == '2/3-rule':
            # Same as the single-process path: filter a scratch copy
            # instead of modifying the caller's coefficients in place.
            fup = self.work_arrays[(fu, 1, False)]
            fup[:] = fu
            fup *= self.dealias
        Uc_hat = fun(fup, Uc_hat)

        if self.communication == 'Alltoall':
            # In-place exchange, then transpose via the per-process view.
            self.comm.Alltoall(MPI.IN_PLACE, [Uc_hat, self.mpitype])
            Uc_hatT[:] = rollaxis(Uc_mpi, 1).reshape(self.complex_shape_T())

        elif self.communication == 'Alltoallw':
            # Subarray datatypes are created once and cached on self.
            if not self._subarraysA:
                (self._subarraysA,
                 self._subarraysB,
                 self._counts_displs) = self.get_subarrays()
            self.comm.Alltoallw(
                [Uc_hat, self._counts_displs, self._subarraysA],
                [Uc_hatT, self._counts_displs, self._subarraysB])

        u = irfft2(Uc_hatT, u, axes=(1, 2), overwrite_input=True,
                   threads=self.threads,
                   planner_effort=self.planner_effort['irfft2'])

    else:
        Uc_hatT = self.work_arrays[(self.complex_shape_T(),
                                    self.complex, 0, False)]

        if not self.dealias_cheb:
            Upad_hatT = self.work_arrays[(self.complex_shape_padded_T(),
                                          self.complex, 0)]
            Upad_hat_z = self.work_arrays[
                ((self.Np[0], int(self.padsize * self.N[1]), self.Nf),
                 self.complex, 0)]
            Uc_hat = fun(fu * self.padsize**2, Uc_hat)

            if self.communication == 'Alltoall':
                # Not in-place: receive into a separate scratch buffer.
                Uc_mpi = self.work_arrays[
                    ((self.num_processes, self.Np[0], self.Np[1], self.Nf),
                     self.complex, 0, False)]
                self.comm.Alltoall([Uc_hat, self.mpitype],
                                   [Uc_mpi, self.mpitype])
                Uc_hatT[:] = rollaxis(Uc_mpi, 1).reshape(
                    self.complex_shape_T())

            elif self.communication == 'Alltoallw':
                if not self._subarraysA:
                    (self._subarraysA,
                     self._subarraysB,
                     self._counts_displs) = self.get_subarrays()
                self.comm.Alltoallw(
                    [Uc_hat, self._counts_displs, self._subarraysA],
                    [Uc_hatT, self._counts_displs, self._subarraysB])

            # Pad in y-direction and transform along y
            Upad_hat_z = SlabShen_R2C.copy_to_padded(
                Uc_hatT, Upad_hat_z, self.N, 1)
            Upad_hat_z[:] = ifft(Upad_hat_z, axis=1, threads=self.threads,
                                 planner_effort=self.planner_effort['ifft'])
            # Pad in z-direction and perform final irfft
            Upad_hatT = SlabShen_R2C.copy_to_padded(
                Upad_hat_z, Upad_hatT, self.N, 2)
            u = irfft(Upad_hatT, u, axis=2, overwrite_input=True,
                      threads=self.threads,
                      planner_effort=self.planner_effort['irfft'])

        else:
            assert self.num_processes <= self.N[
                0] / 2, "Number of processors cannot be larger than N[0]/2 for 3/2-rule"

            # Intermediate work arrays required for transform
            Upad_hat = self.work_arrays[(self.complex_shape_padded_0(),
                                         self.complex, 0, False)]
            Upad_hat0 = self.work_arrays[(self.complex_shape_padded_0(),
                                          self.complex, 1)]
            Upad_hat1 = self.work_arrays[(self.complex_shape_padded_1(),
                                          self.complex, 0, False)]
            Upad_hat2 = self.work_arrays[(self.complex_shape_padded_2(),
                                          self.complex, 0)]
            Upad_hat3 = self.work_arrays[(self.complex_shape_padded_3(),
                                          self.complex, 0)]

            # Expand in x-direction and perform ifst
            Upad_hat0 = SlabShen_R2C.copy_to_padded(
                fu * self.padsize**2, Upad_hat0, self.N, 0)
            Upad_hat = fun(Upad_hat0, Upad_hat)

            if self.communication == 'Alltoall':
                # Communicate to distribute first dimension (like Fig. 2b
                # but padded in x-dir)
                self.comm.Alltoall(MPI.IN_PLACE, [Upad_hat, self.mpitype])

                # Transpose data and pad in y-direction before doing ifft.
                # Now data is padded in x and y
                U_mpi = Upad_hat.reshape(self.complex_shape_padded_0_I())
                Upad_hat1[:] = rollaxis(U_mpi, 1).reshape(Upad_hat1.shape)

            elif self.communication == 'Alltoallw':
                # NOTE(review): this overwrites self._counts_displs, which
                # the unpadded Alltoallw paths also use - presumably both
                # get_subarrays() variants return the same counts/displs;
                # confirm.
                if not self._subarraysA_pad:
                    (self._subarraysA_pad,
                     self._subarraysB_pad,
                     self._counts_displs) = self.get_subarrays(
                         padsize=self.padsize)
                self.comm.Alltoallw(
                    [Upad_hat, self._counts_displs, self._subarraysA_pad],
                    [Upad_hat1, self._counts_displs, self._subarraysB_pad])

            Upad_hat2 = SlabShen_R2C.copy_to_padded(
                Upad_hat1, Upad_hat2, self.N, 1)
            Upad_hat2[:] = ifft(Upad_hat2, axis=1, threads=self.threads,
                                planner_effort=self.planner_effort['ifft'])

            # pad in z-direction and perform final irfft
            Upad_hat3 = SlabShen_R2C.copy_to_padded(
                Upad_hat2, Upad_hat3, self.N, 2)
            u = irfft(Upad_hat3, u, axis=2, overwrite_input=True,
                      threads=self.threads,
                      planner_effort=self.planner_effort['irfft'])

    return u