class FastWordCentroidRetrieval(BaseEstimator, RetriEvalMixin):
    """Retrieval model ranking documents by cosine distance of centroids.

    Documents and queries are turned into vectors by an
    ``EmbeddedVectorizer`` (configured with ``norm='l2'``) and compared with
    a brute-force ``NearestNeighbors`` search using the cosine metric.
    Optionally, a ``Matching`` instance pre-selects the candidate documents
    for every query.
    """

    def __init__(self, embedding, analyzer='word', matching=None,
                 name="FWCD", n_jobs=1, use_idf=True):
        """Set up the vectorizer, the neighbour search and the matching.

        Parameters
        ----------
        embedding :
            Word embedding handed to ``EmbeddedVectorizer``.
        analyzer :
            Analyzer argument forwarded to ``EmbeddedVectorizer``.
        matching :
            Falsy to disable matching; otherwise a mapping (or iterable of
            key/value pairs) expanded into keyword arguments for
            ``Matching``.
        name : str
            Name of this retrieval model.
        n_jobs : int
            Forwarded to ``NearestNeighbors``.
        use_idf : bool
            Forwarded to ``EmbeddedVectorizer``.
        """
        self.name = name
        self.matching = Matching(**dict(matching)) if matching else None
        self.vect = EmbeddedVectorizer(embedding,
                                       analyzer=analyzer,
                                       norm='l2',
                                       use_idf=use_idf)
        self.nn = NearestNeighbors(n_jobs=n_jobs,
                                   metric='cosine',
                                   algorithm='brute')

    def fit(self, X_raw, y=None):
        """Vectorize the documents and prepare the search structures.

        Parameters
        ----------
        X_raw :
            Raw documents, passed to the vectorizer (and to the matching
            operation when one is configured).
        y :
            Labels returned by :meth:`query`, one per document. They are
            stored as a numpy array so they can be indexed with index
            arrays. ``query`` indexes into them, so they are required
            before querying.

        Returns
        -------
        self
        """
        cents = self.vect.fit_transform(X_raw)
        # NOTE: an experimental "all but the top" post-processing step with
        # re-normalisation was sketched here once; it is not applied.
        self.centroids = cents
        print(' FIT centroids shape', self.centroids.shape)

        # Plain sequences cannot be indexed with index arrays at query time.
        self._y = None if y is None else np.asarray(y)

        if self.matching:
            self.matching.fit(X_raw)
        else:
            # Without matching, the index over all centroids is built once.
            self.nn.fit(cents)
        return self

    def query(self, query, k=None, indices=None):
        """Return the labels of the ``k`` documents closest to ``query``.

        Parameters
        ----------
        query :
            Raw query, vectorized with the fitted vectorizer.
        k : int, optional
            Number of results; defaults to the number of fitted documents
            and is capped at the number of candidate documents.
        indices : optional
            Candidate documents to restrict the search to. Only consulted
            when no matching operation is configured.

        Returns
        -------
        The labels of the retrieved documents in the order returned by the
        neighbour search, or an empty list when there is no candidate.
        """
        centroids = self.centroids
        if k is None:
            k = centroids.shape[0]
        q_centroid = self.vect.transform([query])

        # Determine the candidate subset (None means: all documents).
        if self.matching:
            subset = self.matching.predict(query)
        elif indices is not None:
            subset = np.asarray(indices)
            if subset.size == 0:
                return []
        else:
            subset = None

        if subset is None:
            labels = self._y
            nn = self.nn
        else:
            centroids, labels = centroids[subset], self._y[subset]
            if centroids.shape[0] == 0:
                return []
            # Fit a throw-away estimator with the same parameters so that
            # the index built over the full collection in fit() is not
            # replaced by one over this subset.
            nn = type(self.nn)(**self.nn.get_params())
            nn.fit(centroids)

        n_ret = min(k, centroids.shape[0])
        if n_ret == 0:
            return []

        ind = nn.kneighbors(q_centroid,
                            n_neighbors=n_ret,
                            return_distance=False)[0]
        return labels[ind]
class WordCentroidRetrieval(BaseEstimator, RetriEvalMixin):
    """Retrieval model based on Word Centroid Distance.

    Every document is represented by the mean of the embedding vectors of
    its in-vocabulary words. Queries are represented the same way and the
    closest document centroids are looked up with ``NearestNeighbors``.
    """

    def __init__(self, embedding, analyzer, name="WCD", n_jobs=1,
                 normalize=True, verbose=0, oov=None, matching=True,
                 **kwargs):
        """Store the configuration and build the helper objects.

        Parameters
        ----------
        embedding :
            Word embedding, indexed by word to obtain a vector.
        analyzer : callable
            Turns a raw document or query into tokens.
        name : str
            Name of this retrieval model.
        n_jobs : int
            Forwarded to ``NearestNeighbors``.
        normalize : bool
            Whether centroids are l2-normalised.
        verbose : int
            Values above zero enable progress output.
        oov :
            Out-of-vocabulary token handed to ``filter_vocab``; its
            embedding is also used as the centroid of texts without any
            remaining word.
        matching :
            ``True`` for a default ``Matching``, ``False``/``None`` to
            disable matching, otherwise a mapping expanded into keyword
            arguments for ``Matching``.
        **kwargs :
            Further keyword arguments for ``NearestNeighbors``.
        """
        self.name = name
        self._embedding = embedding
        self._normalize = normalize
        self._oov = oov
        self.verbose = verbose
        self.n_jobs = n_jobs
        # n_jobs is a named parameter and therefore never part of kwargs.
        self._neighbors = NearestNeighbors(n_jobs=n_jobs, **kwargs)
        self._analyzer = analyzer

        if matching is True:
            self._matching = Matching()
        elif matching is False or matching is None:
            self._matching = None
        else:
            self._matching = Matching(**dict(matching))

    def _compute_centroid(self, words):
        """Return the mean embedding of ``words`` as a ``(1, dim)`` array.

        When ``words`` is empty, the embedding of the out-of-vocabulary
        token is used instead (this requires ``oov`` to be a key of the
        embedding).
        """
        E = self._embedding
        if len(words) == 0:
            # No words left at all; keep the same 2-D shape as below so the
            # result can be normalised and queried like any other centroid.
            return np.asarray(E[self._oov]).reshape(1, -1)

        embedded_words = np.vstack([E[word] for word in words])
        return np.mean(embedded_words, axis=0).reshape(1, -1)

    def fit(self, docs, labels):
        """Compute one centroid per document and prepare the search.

        Parameters
        ----------
        docs :
            Raw documents. They are iterated here and, when matching is
            enabled, passed to the matching operation as well.
        labels :
            One label per document; returned by :meth:`query`.

        Returns
        -------
        self
        """
        E, analyze = self._embedding, self._analyzer

        analyzed_docs = (analyze(doc) for doc in docs)
        # Out-of-vocabulary words do not have to contribute to the centroid.
        filtered_docs = (filter_vocab(E, d, self._oov) for d in analyzed_docs)
        # np.vstack needs a concrete sequence, hence the list.
        centroids = np.vstack([self._compute_centroid(doc)
                               for doc in filtered_docs])

        if self.verbose > 0:
            print("Centroids shape:", centroids.shape)

        if self._normalize:
            # copy=False is not guaranteed to work in place, so always use
            # the returned array.
            centroids = normalize(centroids, norm='l2', copy=False)

        self._y = np.asarray(labels)
        # Always keep the centroids: query() needs their count for the
        # default k, with or without matching.
        self._centroids = centroids

        if self._matching:
            self._matching.fit(docs)
        else:
            # If we do not do matching, it is enough to fit the nearest
            # neighbours on all centroids once, before query time.
            self._neighbors.fit(centroids)

        return self

    def query(self, query, k=None, return_distance=False):
        """Return the labels of the ``k`` documents closest to ``query``.

        Parameters
        ----------
        query :
            Raw query.
        k : int, optional
            Number of results; defaults to the number of fitted documents
            and is capped at the number of candidate documents.
        return_distance : bool
            When true, a ``(labels, distances)`` pair is returned.

        Returns
        -------
        The labels of the retrieved documents (and their distances when
        requested). Without any matching document the result is an empty
        list (or a pair of empty lists when distances are requested).
        """
        E, analyze, nn = self._embedding, self._analyzer, self._neighbors

        tokens = analyze(query)
        words = filter_vocab(E, tokens, self._oov)
        query_centroid = self._compute_centroid(words)
        if self._normalize:
            query_centroid = normalize(query_centroid, norm='l2', copy=False)

        if self.verbose > 0:
            print("Analyzed query", words)

        if k is None:
            k = len(self._centroids)

        if self._matching:
            matched = self._matching.predict(query)
            centroids, labels = self._centroids[matched], self._y[matched]
            if len(centroids) == 0:
                # Nothing to fit here; keep the shape of the result
                # consistent with the requested return type.
                return ([], []) if return_distance else []
            nn.fit(centroids)
        else:
            # The neighbour index was fitted on all centroids in fit().
            centroids, labels = self._centroids, self._y

        # k must not exceed the number of candidates.
        n_ret = min(k, len(centroids))

        dist, ind = nn.kneighbors(query_centroid,
                                  n_neighbors=n_ret,
                                  return_distance=True)
        # We only had one query in the first place.
        dist, ind = dist[0], ind[0]

        if return_distance:
            return labels[ind], dist
        return labels[ind]