def jsd(p: "tf.Tensor", q: "tf.Tensor", base: float = np.e):
    """Pairwise Jensen-Shannon divergence between score tensors ``p`` and ``q``.

    Based on https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence

    The last axis of each input holds unnormalized scores; they are
    normalized to probabilities before the divergence is computed.
    All-zero rows (unlabeled) divide by zero and yield NaNs by design.

    Args:
        p, q: tensors exposing a ``.numpy()`` method (e.g. TF eager or
            torch tensors); last axis is the distribution axis.
        base: logarithm base for the entropy terms (default natural log).

    Returns:
        numpy array of divergences, clipped to the theoretical range
        ``[0, log_base(2)]`` to absorb floating-point noise.

    Note: annotations are quoted so defining this function does not
    require ``tensorflow`` to be imported.
    """
    import scipy.stats

    p, q = p.numpy(), q.numpy()
    # Normalize scores to probabilities along the distribution axis.
    # All-zero rows produce NaN here on purpose (see docstring).
    p = p / p.sum(axis=-1, keepdims=True)
    q = q / q.sum(axis=-1, keepdims=True)
    # scipy.stats.entropy reduces along axis 0, so move the distribution
    # axis to the front; the final .transpose() moves results back.
    p, q = p.transpose(), q.transpose()
    m = 0.5 * (p + q)
    divergence = (scipy.stats.entropy(p, m, base=base) / 2.0
                  + scipy.stats.entropy(q, m, base=base) / 2.0)
    # JSD is bounded above by log_base(2), not by 1.0: the previous
    # hard-coded clip to 1.0 wrongly truncated valid values for any
    # base < 2. For the default base=e (max ln 2 ≈ 0.693) and for
    # base=2 (max 1.0) the behavior is unchanged.
    upper = np.log(2) / np.log(base)
    return np.clip(divergence, 0.0, upper).transpose()
def get_output(self, train=False):
    """Symbolic k-max pooling over the time axis (axis 1).

    Keeps the ``self.pooling_size`` highest activations per feature
    channel while preserving their original temporal order. Timesteps
    that are masked out are pushed to -inf first, so they can never be
    selected (unless fewer than k valid steps exist, in which case the
    gathered values include -inf entries).
    """
    data = self.get_input(train)
    mask = self.get_input_mask(train)
    if mask is None:
        # No mask supplied: treat every timestep as valid.
        mask = T.sum(T.ones_like(data), axis=-1)
    mask = mask.dimshuffle(0, 1, "x")
    # Exclude masked positions from the ranking by setting them to -inf.
    neg_inf_masked = T.switch(T.eq(mask, 0), -np.inf, data)
    # argsort ranks timesteps per channel; the last pooling_size rows are
    # the top-k positions. Re-sorting those indices restores temporal
    # order before gathering.
    topk_positions = T.sort(
        T.argsort(neg_inf_masked, axis=1)[:, -self.pooling_size:, :],
        axis=1)
    batch_idx = T.arange(neg_inf_masked.shape[0]).dimshuffle(0, "x", "x")
    channel_idx = T.arange(neg_inf_masked.shape[2]).dimshuffle("x", "x", 0)
    return neg_inf_masked[batch_idx, topk_positions, channel_idx]