def test_with_cache(self):
    """Check that emd_relaxed with an explicit cache matches the uncached value.

    The same cache object is passed once positionally and once as the
    ``cache`` keyword. Both calls must reproduce 0.6125112, the value
    ``test_no_cache`` expects for the same fixture: a cache must not change
    the result. The cache comes from ``emd_relaxed_cache_init`` and is always
    released in ``finally``, so a failing assertion does not leak it.
    """
    # NOTE(review): another method named test_with_cache exists in this file.
    # If both are in the same TestCase class, only the later definition is
    # collected. Rename or remove one of them.
    # The argument 4 presumably sizes the cache for the fixture's vocabulary
    # (the previous version used a 4-element buffer) -- TODO confirm.
    cache = libwmdrelax.emd_relaxed_cache_init(4)
    try:
        w1, w2, dist = self._get_w1_w2_dist()
        r = libwmdrelax.emd_relaxed(w1, w2, dist, cache)
        self.assertAlmostEqual(r, 0.6125112)
        r = libwmdrelax.emd_relaxed(w1, w2, dist, cache=cache)
        self.assertAlmostEqual(r, 0.6125112)
    finally:
        libwmdrelax.emd_relaxed_cache_fini(cache)
def test_with_cache(self):
    """Check that emd_relaxed with an explicit cache returns the reference value.

    The cache is passed once positionally and once as the ``cache`` keyword.
    Both calls must yield 0.6125112. The cache is released in ``finally`` so
    it is freed even when fixture creation or an assertion fails.
    """
    # The argument 4 presumably sizes the cache for the fixture -- TODO confirm.
    cache = libwmdrelax.emd_relaxed_cache_init(4)
    try:
        w1, w2, dist = self._get_w1_w2_dist()
        r = libwmdrelax.emd_relaxed(w1, w2, dist, cache)
        self.assertAlmostEqual(r, 0.6125112)
        r = libwmdrelax.emd_relaxed(w1, w2, dist, cache=cache)
        self.assertAlmostEqual(r, 0.6125112)
    finally:
        # Previously skipped when an assertion failed, leaking the cache.
        libwmdrelax.emd_relaxed_cache_fini(cache)
def _estimate_WMD_relaxation_batch(self, words1, weights1, i2):
    """Compute the relaxed EMD between a query and batch item ``i2``.

    Steps:

    * ``self._common_vocabulary_batch`` supplies the joint vocabulary and the
      two weight vectors. ``joint`` is presumably indices into
      ``self.embeddings``; verify against that helper.
    * Each weight vector is normalized in place to sum to 1.
    * The pairwise Euclidean distance matrix is built between the joint
      embeddings.
    * Everything is passed to ``libwmdrelax.emd_relaxed`` together with
      ``self._relax_cache``.

    Returns:
        A 4-tuple ``(relaxed_distance, w1, w2, dists)``.
    """
    joint, w1, w2 = self._common_vocabulary_batch(words1, weights1, i2)

    # In-place normalization: the arrays returned by the helper are mutated.
    # NOTE(review): a zero weight sum would divide by zero here.
    for weights in (w1, w2):
        weights /= weights.sum()

    vectors = self.embeddings[joint]
    sq_norms = (vectors * vectors).sum(axis=1)
    gram = vectors.dot(vectors.T)
    # ||a||^2 - 2 a.b + ||b||^2 gives the squared distance for every pair.
    sq_dists = sq_norms - 2 * gram + sq_norms[:, None]
    # Rounding can push near-zero entries slightly negative; clamp them
    # before the square root.
    sq_dists[sq_dists < 0] = 0
    dists = numpy.sqrt(sq_dists)

    relaxed = libwmdrelax.emd_relaxed(w1, w2, dists, self._relax_cache)
    return relaxed, w1, w2, dists
def test_no_cache(self):
    """emd_relaxed called without a cache returns the reference distance."""
    weights_a, weights_b, distances = self._get_w1_w2_dist()
    self.assertAlmostEqual(
        libwmdrelax.emd_relaxed(weights_a, weights_b, distances),
        0.6125112,
    )