def normalize_vecs(vecs: np.ndarray, normalize_mode: str) -> np.ndarray:
    """Center and/or L2-normalize the rows of an embedding matrix in place.

    ``normalize_mode`` is a comma-separated list of steps that are applied
    in the order given. Empty tokens are skipped, so ``''`` is a no-op.

    Supported steps:
        ``center``: subtract the mean taken over axis 0 from every row.
        ``renorm``: divide every row by its L2 norm (taken over axis 1).

    Args:
        vecs: 2-D array (``renorm`` reduces over axis 1). It is modified
            in place via ``-=`` / ``/=``, so it needs a floating dtype.
        normalize_mode: e.g. ``'center'``, ``'renorm'``,
            ``'center,renorm'`` or ``''``.

    Returns:
        The same ``vecs`` object that was passed in, after modification.

    Raises:
        ValueError: if ``normalize_mode`` contains an unknown step name.

    Note:
        Rows whose norm is zero are not special-cased; dividing them
        follows NumPy's floating-point division semantics.
    """
    for step in normalize_mode.split(','):
        if step == '':
            continue
        if step == 'center':
            # keepdims keeps the mean as shape (1, d) so it broadcasts per row.
            vecs -= vecs.mean(0, keepdims=True)
        elif step == 'renorm':
            # ndarray has no .norm() method; the row norms come from linalg.
            vecs /= np.linalg.norm(vecs, ord=2, axis=1, keepdims=True)
        else:
            raise ValueError(f'Unknown normalization type: "{step}"')
    return vecs