def test_get_subslice():
    """get_sub_slice composes an outer indexer (slice/array/None) with an inner slice."""
    inner = slice(20, 30)
    # Outer given as a slice: inner indices are offset by the outer start.
    assert_array_equal(get_sub_slice(slice(10, 100), inner),
                       np.arange(30, 40))
    # Outer of None acts as the identity mapping.
    assert_array_equal(get_sub_slice(None, inner),
                       np.arange(20, 30))
    # Outer given as an explicit index array behaves like the equivalent slice.
    assert_array_equal(get_sub_slice(np.arange(10, 100), inner),
                       np.arange(30, 40))
def partial_fit(self, X, sample_indices=None):
    """Update the factorization using rows from X.

    Parameters
    ----------
    X: ndarray, shape (n_samples, n_features)
        Input data

    sample_indices:
        Indices for each row of X. If None, consider that row i index is i
        (useful when providing the whole data to the function)

    Returns
    -------
    self
    """
    X = check_array(X, dtype=[np.float32, np.float64], order='C')
    n_samples = X.shape[0]
    # Feed the data to the online solver one mini-batch at a time,
    # translating each batch into the caller's row-index space.
    for batch in gen_batches(n_samples, self.batch_size):
        self._single_batch_fit(X[batch],
                               get_sub_slice(sample_indices, batch))
    return self
def transform(self, X):
    """Compute the codes associated to input matrix X, decomposing it onto
    the dictionary.

    Parameters
    ----------
    X: ndarray, shape = (n_samples, n_features)

    Returns
    -------
    code: ndarray, shape = (n_samples, n_components)
    """
    check_is_fitted(self, 'components_')

    dtype = self.components_.dtype
    X = check_array(X, order='C', dtype=dtype.type)
    if not X.flags['WRITEABLE']:
        # The regression routine mutates its inputs; work on a private copy.
        X = X.copy()
    n_samples = X.shape[0]

    # Reuse the cached Gram matrix only when it reflects the full dictionary;
    # otherwise recompute it from the current components.
    if getattr(self, 'G_agg', None) == 'full':
        G = self.G_
    else:
        G = self.components_.dot(self.components_.T)

    Dx = X.dot(self.components_.T)
    # Warm-start value for the elastic-net solver (filled in-place below).
    code = np.ones((n_samples, self.n_components), dtype=dtype)
    sample_indices = np.arange(n_samples)

    if self.n_threads > 1:
        # Split the rows evenly across workers; each job writes its own
        # row range of `code` in-place.
        size_job = ceil(n_samples / self.n_threads)
        batches = list(gen_batches(n_samples, size_job))
        par_func = lambda batch: _enet_regression_single_gram(
            G, Dx[batch], X[batch], code,
            get_sub_slice(sample_indices, batch),
            self.code_l1_ratio, self.code_alpha, self.code_pos,
            self.tol, self.max_iter)
        # Drain the map iterator so every job actually runs.
        _ = list(self._pool.map(par_func, batches))
    else:
        _enet_regression_single_gram(G, Dx, X, code, sample_indices,
                                     self.code_l1_ratio, self.code_alpha,
                                     self.code_pos, self.tol, self.max_iter)
    return code
def _compute_code(self, X, sample_indices, w_sample, subset):
    """Update regression statistics if necessary and compute code from X[:, subset]

    Writes the resulting codes in-place into ``self.code_`` at the rows
    given by ``sample_indices``; returns nothing.

    Parameters
    ----------
    X: ndarray, shape (batch_size, n_features)
        Mini-batch of input rows.
    sample_indices:
        Row indices of this batch within the full dataset; used to index
        the per-sample running averages and ``self.code_``.
    w_sample: ndarray, shape (batch_size,)
        Per-sample step weights used for the exponential moving averages
        (the code forms ``1 - w_sample`` factors) — presumably in [0, 1];
        TODO(review): confirm against the caller.
    subset:
        Column (feature) subset used when aggregation is not 'full';
        ``reduction`` rescales the subsampled products to compensate.
    """
    batch_size, n_features = X.shape
    reduction = self.reduction
    # Pre-compute row batches once; reused by both parallel sections below.
    if self.n_threads > 1:
        size_job = ceil(batch_size / self.n_threads)
        batches = list(gen_batches(batch_size, size_job))
    # Column-restricted view of the dictionary, needed whenever either
    # statistic is estimated from the feature subset rather than fully.
    if self.Dx_agg != 'full' or self.G_agg != 'full':
        components_subset = self.components_[:, subset]
    if self.Dx_agg == 'full':
        # Exact correlation of the batch with the dictionary.
        Dx = X.dot(self.components_.T)
    else:
        # Subsampled estimate, rescaled by `reduction` to stay unbiased.
        X_subset = X[:, subset]
        Dx = X_subset.dot(components_subset.T) * reduction
        if self.Dx_agg == 'average':
            # In-place exponential moving average of Dx per sample.
            self.Dx_average_[sample_indices] \
                *= 1 - w_sample[:, np.newaxis]
            self.Dx_average_[sample_indices] \
                += Dx * w_sample[:, np.newaxis]
            Dx = self.Dx_average_[sample_indices]
    if self.G_agg != 'full':
        # Subsampled Gram matrix estimate (rescaled as above).
        G = components_subset.dot(components_subset.T) * reduction
        if self.G_agg == 'average':
            # Copy out the per-sample Gram averages, update them (possibly
            # in parallel over row batches), then write them back.
            G_average = np.array(self.G_average_[sample_indices],
                                 copy=True)
            if self.n_threads > 1:
                par_func = lambda batch: _update_G_average(
                    G_average[batch], G,
                    w_sample[batch], )
                res = self._pool.map(par_func, batches)
                # Drain the iterator so every job actually runs.
                _ = list(res)
            else:
                _update_G_average(G_average, G, w_sample)
            self.G_average_[sample_indices] = G_average
    else:
        # Cached Gram matrix of the full dictionary.
        G = self.G_
    if self.n_threads > 1:
        # Parallel elastic-net regression; each job writes its row range of
        # self.code_ in-place.
        if self.G_agg == 'average':
            # One Gram matrix per sample.
            par_func = lambda batch: _enet_regression_multi_gram(
                G_average[batch], Dx[batch], X[batch], self.code_,
                get_sub_slice(sample_indices, batch),
                self.code_l1_ratio, self.code_alpha, self.code_pos,
                self.tol, self.max_iter)
        else:
            # Single shared Gram matrix.
            par_func = lambda batch: _enet_regression_single_gram(
                G, Dx[batch], X[batch], self.code_,
                get_sub_slice(sample_indices, batch),
                self.code_l1_ratio, self.code_alpha, self.code_pos,
                self.tol, self.max_iter)
        res = self._pool.map(par_func, batches)
        _ = list(res)
    else:
        if self.G_agg == 'average':
            _enet_regression_multi_gram(G_average, Dx, X, self.code_,
                                        sample_indices,
                                        self.code_l1_ratio, self.code_alpha,
                                        self.code_pos, self.tol,
                                        self.max_iter)
        else:
            _enet_regression_single_gram(G, Dx, X, self.code_,
                                         sample_indices,
                                         self.code_l1_ratio, self.code_alpha,
                                         self.code_pos, self.tol,
                                         self.max_iter)