import itertools

import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment


def average_cosine(complex_sentence, simple_sentence):
    # to_embeddings (defined elsewhere in the module) is assumed to map a
    # sentence to a (num_tokens, embedding_dim) tensor of token embeddings.
    complex_embeddings = to_embeddings(complex_sentence)
    simple_embeddings = to_embeddings(simple_sentence)
    # Cosine similarity between the mean (bag-of-words) sentence embeddings.
    return float(F.cosine_similarity(complex_embeddings.mean(dim=0),
                                     simple_embeddings.mean(dim=0),
                                     dim=0))
def hungarian_dot(complex_sentence, simple_sentence):
    complex_embeddings = to_embeddings(complex_sentence)
    simple_embeddings = to_embeddings(simple_sentence)
    # Pairwise dot products between every complex token and every simple token.
    similarity_matrix = torch.mm(complex_embeddings, simple_embeddings.t())
    # linear_sum_assignment minimizes cost, so negate to maximize similarity.
    row_indexes, col_indexes = linear_sum_assignment(-similarity_matrix)
    # TODO: Penalize the deletion of unimportant words less
    # Normalize by the longer token count so unmatched tokens count as zero
    # similarity (note: token counts, not raw sentence lengths).
    return float(similarity_matrix[row_indexes, col_indexes].sum()
                 / max(len(complex_embeddings), len(simple_embeddings)))
def hungarian_cosine(complex_sentence, simple_sentence):
    complex_embeddings = to_embeddings(complex_sentence)
    simple_embeddings = to_embeddings(simple_sentence)
    # Pairwise cosine similarities between every complex and simple token.
    similarity_matrix = torch.zeros(len(complex_embeddings), len(simple_embeddings))
    for (i, complex_embedding), (j, simple_embedding) in itertools.product(
            enumerate(complex_embeddings), enumerate(simple_embeddings)):
        similarity_matrix[i, j] = F.cosine_similarity(complex_embedding,
                                                      simple_embedding, dim=0)
    # linear_sum_assignment minimizes cost, so negate to maximize similarity.
    row_indexes, col_indexes = linear_sum_assignment(-similarity_matrix)
    # TODO: Penalize the deletion of unimportant words less
    # Normalize by the longer token count so unmatched tokens count as zero
    # similarity (note: token counts, not raw sentence lengths).
    return float(similarity_matrix[row_indexes, col_indexes].sum()
                 / max(len(complex_embeddings), len(simple_embeddings)))
def average_dot(complex_sentence, simple_sentence):
    complex_embeddings = to_embeddings(complex_sentence)
    simple_embeddings = to_embeddings(simple_sentence)
    # Dot product between the mean (bag-of-words) sentence embeddings; unlike
    # average_cosine, this score is not bounded between -1 and 1.
    return float(torch.dot(complex_embeddings.mean(dim=0),
                           simple_embeddings.mean(dim=0)))
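

# --- Usage sketch (illustrative only) ---
# A minimal example of how these scorers might be called, assuming that
# to_embeddings is provided elsewhere in the module and accepts a tokenized
# sentence (a list of words). The sentences below are hypothetical.
if __name__ == '__main__':
    complex_sentence = ('The incident resulted in the hospitalization '
                        'of three individuals .').split()
    simple_sentence = 'Three people were taken to hospital .'.split()
    for scorer in (average_cosine, average_dot, hungarian_cosine, hungarian_dot):
        print(f'{scorer.__name__}: {scorer(complex_sentence, simple_sentence):.4f}')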