예제 #1
0
def main(args):
    """Zero-shot evaluation of a BERT sentence encoder on a paired-text CSV.

    Reads ``args.file`` with pandas. The columns ``text_1``, ``text_2`` and
    ``similarity`` are accessed below. Builds a sentence encoder:
    ``BertWrapper`` (max_seq_length=256) followed by a mean-token
    ``PoolingLayer`` using ``args.layer``. Encodes both text columns and
    prints the score returned by a cosine ``EmbeddingSimilarityEvaluator``
    (presumably Spearman correlation, per the variable name; confirm).
    Writes the per-row cosine similarity to
    ``cache/annotated_zero_shot_pred.csv``.

    Args:
        args: namespace providing ``file``, ``model_path``, ``layer`` and
            ``t2s``.

    Returns:
        Tuple ``(preds, labels)``. ``preds`` is the per-pair cosine
        similarity (``1 - paired_cosine_distances``). ``labels`` is
        ``df["similarity"].values``.
    """
    import os

    output_path = "cache/annotated_zero_shot_pred.csv"

    # Read the dataset
    df = pd.read_csv(args.file)
    embedder = BertWrapper(args.model_path, max_seq_length=256)
    pooler = PoolingLayer(embedder.get_word_embedding_dimension(),
                          pooling_mode_mean_tokens=True,
                          pooling_mode_cls_token=False,
                          pooling_mode_max_tokens=False,
                          layer_to_use=args.layer)
    model = SentenceEncoder(modules=[embedder, pooler])
    model.eval()

    evaluator = EmbeddingSimilarityEvaluator(
        main_similarity=SimilarityFunction.COSINE)

    if args.t2s:
        # Optional text normalisation of both columns. Per the helper name,
        # this is likely traditional -> simplified Chinese; confirm.
        df["text_1"] = df["text_1"].apply(convert_t2s)
        df["text_2"] = df["text_2"].apply(convert_t2s)

    n_pairs = len(df)
    # Encode both columns in a single call, then split the result back:
    # the first n_pairs rows belong to text_1, the rest to text_2.
    embeddings = model.encode(df["text_1"].tolist() + df["text_2"].tolist(),
                              batch_size=16,
                              show_progress_bar=True)
    embeddings1, embeddings2 = embeddings[:n_pairs], embeddings[n_pairs:]

    labels = df["similarity"].values
    spearman_score = evaluator(embeddings1,
                               embeddings2,
                               labels=labels)
    print(spearman_score)

    # Cosine similarity is 1 - cosine distance, one value per row pair.
    preds = 1 - paired_cosine_distances(embeddings1, embeddings2)
    df["pred"] = preds

    # to_csv raises OSError if the target directory does not exist yet.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    df.to_csv(output_path, index=False)

    print(f"Pred {pd.Series(preds).describe()}")
    return preds, labels
예제 #2
0
def main(args):
    """Run the ``raw`` prediction pass with a saved SentenceEncoder.

    The encoder is loaded from ``args.model_path`` and put in eval mode.
    Optional acceleration:

    * AMP via apex: applied only when ``APEX`` is truthy, ``args.amp`` is
      set and TorchScript is *not* requested. ``args.amp`` is used as the
      opt level.
    * TorchScript: ``encoder[0].bert`` is replaced by a ``torch.jit.trace``
      of itself. When ``args.amp`` is also set, it is cast to half
      precision before tracing.

    Afterwards the max sequence length is fixed to 256 and the pooling
    module ``encoder[1]`` is switched from CLS-token to mean-token pooling.
    Its config is printed before and after the switch. Finally
    ``raw(args, encoder)`` is run and summary statistics of the predictions
    are printed.
    """
    encoder = SentenceEncoder(model_path=args.model_path)
    encoder.eval()

    # Apex AMP and TorchScript tracing are mutually exclusive here.
    if APEX and args.amp and not args.torchscript:
        encoder = amp.initialize(encoder, opt_level=args.amp)

    if args.torchscript:
        transformer = encoder[0]
        if args.amp:
            transformer.bert = transformer.bert.half()
        # Dummy (8, 256) int64 CUDA tensors, used only to record the trace.
        # Which forward argument each one feeds depends on the wrapped
        # model's signature -- TODO confirm.
        dummy_shape = (8, 256)
        example_inputs = (
            torch.zeros(*dummy_shape).long().cuda(),
            torch.zeros(*dummy_shape).long().cuda(),
            torch.ones(*dummy_shape).long().cuda(),
        )
        transformer.bert = torch.jit.trace(transformer.bert, example_inputs)
        assert isinstance(transformer.bert, torch.jit.TopLevelTracedModule)

    encoder.max_seq_length = 256

    pooling = encoder[1]
    print(pooling.get_config_dict())
    # Force mean-over-tokens pooling instead of the CLS token.
    pooling.pooling_mode_cls_token = False
    pooling.pooling_mode_mean_tokens = True
    print(pooling.get_config_dict())

    preds, _ = raw(args, encoder)

    print(f"Pred {pd.Series(preds).describe()}")