def global_transform(X_batch, pad_data, encoder_, n_atoms, nonneg):
    '''Encode a batch of patches with a trained Fourier-domain encoder.

    Arguments
    ---------
    X_batch  -- (ndarray) n-by-h-by-w batch of patches
    pad_data -- (bool) forwarded to the _cdl vectorizers
    encoder_ -- (callable) encoder applied to the vectorized data
    n_atoms  -- (int) number of dictionary atoms
    nonneg   -- (bool) if true, clip activations at zero

    Returns
    -------
    X_new    -- (ndarray) encoded activations; following the axis swaps
        below this comes out n-by-n_atoms-by-h-by-w
        (NOTE(review): shape derived from the swapaxes sequence -- confirm
        against callers)
    '''
    # reorder to h-by-w-by-n for the vectorizer
    X_batch = X_batch.swapaxes(0, 1).swapaxes(1, 2)
    _, w, n = X_batch.shape  # height is not needed below

    # Fourier transform
    X_new = _cdl.patches_to_vectors(X_batch, pad_data=pad_data)

    # Encode
    X_new = encoder_(X_new)

    # fold the atom dimension into the columns, then invert the transform
    X_new = _cdl.real2_to_complex(X_new)
    X_new = X_new.reshape((-1, X_new.shape[1] * n_atoms), order='F')
    X_new = _cdl.complex_to_real2(X_new)
    X_new = _cdl.vectors_to_patches(X_new, w, pad_data=pad_data, real=True)

    # split the atom and batch dimensions back out into separate axes
    X_new = X_new.reshape((X_new.shape[0], X_new.shape[1], n_atoms, n),
                          order='F')
    X_new = X_new.swapaxes(3, 0).swapaxes(3, 2).swapaxes(3, 1)

    if nonneg:
        X_new = np.maximum(X_new, 0.0)

    return X_new
def set_codebook(self, D):
    '''Clobber the existing codebook and encoder with a new one.'''
    self.components_ = D

    # penalty-specific keyword arguments for the dictionary learner
    if self.penalty == 'l1_space':
        extra_args = {'height': D.shape[1],
                      'width': D.shape[2],
                      'pad_data': self.pad_data,
                      'nonneg': self.nonneg}
    else:
        extra_args = {}

    # move the codebook into the Fourier domain and cache it
    fft_D = D.swapaxes(0, 1).swapaxes(1, 2)
    fft_D = _cdl.patches_to_vectors(fft_D, pad_data=self.pad_data)
    fft_D = _cdl.columns_to_diags(fft_D)
    self.fft_components_ = fft_D

    # max_steps=0 and empty data: no learning happens, this just builds
    # an encoder around the fixed codebook
    encoder, _, _ = _cdl.learn_dictionary([],
                                          self.n_atoms,
                                          reg=self.penalty,
                                          alpha=self.alpha,
                                          max_steps=0,
                                          verbose=False,
                                          D=fft_D,
                                          **extra_args)
    self.encoder_ = encoder
    return self
def global_transform(X_batch, pad_data, encoder_, n_atoms, nonneg):
    '''Run a batch of patches through the trained Fourier-domain encoder.'''
    # bring the batch axis last: h-by-w-by-n
    patches = X_batch.swapaxes(0, 1).swapaxes(1, 2)
    h, w, n = patches.shape

    # Fourier transform
    codes = _cdl.patches_to_vectors(patches, pad_data=pad_data)

    # Encode
    codes = encoder_(codes)

    # collapse the atom dimension into columns, undo the transform
    codes = _cdl.real2_to_complex(codes)
    codes = codes.reshape((-1, codes.shape[1] * n_atoms), order='F')
    codes = _cdl.complex_to_real2(codes)
    codes = _cdl.vectors_to_patches(codes, w, pad_data=pad_data, real=True)

    # separate the atom and batch axes again
    new_shape = (codes.shape[0], codes.shape[1], n_atoms, n)
    codes = codes.reshape(new_shape, order='F')
    codes = codes.swapaxes(3, 0).swapaxes(3, 2).swapaxes(3, 1)

    # optionally clip activations at zero
    return np.maximum(codes, 0.0) if nonneg else codes
def set_codebook(self, D):
    '''Clobber the existing codebook and encoder with a new one.'''
    self.components_ = D

    extra_args = {}
    if self.penalty == 'l1_space':
        # the spatial penalty needs the patch geometry
        extra_args['height'] = D.shape[1]
        extra_args['width'] = D.shape[2]
        extra_args['pad_data'] = self.pad_data
        extra_args['nonneg'] = self.nonneg

    # Fourier-transform the codebook and cache the result
    D = D.swapaxes(0, 1).swapaxes(1, 2)
    D = _cdl.patches_to_vectors(D, pad_data=self.pad_data)
    D = _cdl.columns_to_diags(D)
    self.fft_components_ = D

    # zero training steps: only construct an encoder for the given D
    encoder, D, diagnostics = _cdl.learn_dictionary(
        [], self.n_atoms, reg=self.penalty, alpha=self.alpha,
        max_steps=0, verbose=False, D=D, **extra_args)
    self.encoder_ = encoder
    return self
def data_generator(self, X_full):
    '''Make a CDL data generator from an input array

    Arguments
    ---------
    X_full -- (ndarray) n-by-h-by-w data array

    Returns
    -------
    batch_gen -- (generator) batches of CDL-transformed input data
        the generator will loop infinitely
    '''
    # 1. initialize the RNG
    if isinstance(self.random_state, int):
        random.seed(self.random_state)
    elif self.random_state is not None:
        random.setstate(self.random_state)

    n = X_full.shape[0]
    # BUG FIX: `range(n)` cannot be shuffled in Python 3 (TypeError), and
    # the original sliced X_full directly, so even a successful shuffle
    # would have had no effect on the yielded batches.  Materialize the
    # index list and select rows through it instead.
    indices = list(range(n))

    while True:
        if self.shuffle:
            random.shuffle(indices)

        for i in range(0, n, self.chunk_size):
            # drop the final partial chunk so every batch is full-sized
            if i + self.chunk_size > n:
                break

            X = X_full[indices[i:i + self.chunk_size]]

            # Swap the axes around; X is now h-by-w-by-chunk_size
            X = X.swapaxes(0, 1).swapaxes(1, 2)
            yield _cdl.patches_to_vectors(X, pad_data=self.pad_data)
def data_generator(self, X_full):
    '''Make a CDL data generator from an input array

    Arguments
    ---------
    X_full -- (ndarray) n-by-h-by-w data array

    Returns
    -------
    batch_gen -- (generator) batches of CDL-transformed input data
        the generator will loop infinitely
    '''
    # 1. initialize the RNG
    if isinstance(self.random_state, int):
        random.seed(self.random_state)
    elif self.random_state is not None:
        random.setstate(self.random_state)

    n = X_full.shape[0]
    # BUG FIX: random.shuffle raises TypeError on a range object under
    # Python 3, and the shuffled indices were never used when slicing
    # X_full, so shuffling silently did nothing.  Keep a real list of
    # indices and index the data through it.
    indices = list(range(n))

    while True:
        if self.shuffle:
            random.shuffle(indices)

        # step through full-sized chunks; a trailing partial chunk is dropped
        for i in range(0, n, self.chunk_size):
            if i + self.chunk_size > n:
                break

            X = X_full[indices[i:i + self.chunk_size]]
            # Swap the axes around; X is now h-by-w-by-chunk_size
            X = X.swapaxes(0, 1).swapaxes(1, 2)
            yield _cdl.patches_to_vectors(X, pad_data=self.pad_data)