    def test_roberta_wkpooling(self):
        word_embedding_model = models.Transformer(
            'roberta-base', model_args={'output_hidden_states': True})
        pooling_model = models.WKPooling(
            word_embedding_model.get_word_embedding_dimension())
        model = SentenceTransformer(
            modules=[word_embedding_model, pooling_model])
        # Expected cosine similarities for the pairs in WKPoolingTest.sentence_pairs
        scores = [
            0.9594874382019043, 0.9928674697875977, 0.9241214990615845,
            0.9309519529342651, 0.9506515264511108
        ]

        for sentences, score in zip(WKPoolingTest.sentence_pairs, scores):
            embedding = model.encode(sentences, convert_to_numpy=True)

            similarity = 1 - scipy.spatial.distance.cosine(
                embedding[0], embedding[1])
            assert abs(similarity - score) < 0.01

    def test_bert_wkpooling(self):
        word_embedding_model = models.BERT(
            'bert-base-uncased', model_args={'output_hidden_states': True})
        pooling_model = models.WKPooling(
            word_embedding_model.get_word_embedding_dimension())
        model = SentenceTransformer(
            modules=[word_embedding_model, pooling_model])
        # Expected cosine similarities for the pairs in WKPoolingTest.sentence_pairs
        scores = [
            0.6906377742193329, 0.9910573945907297, 0.8395676755959804,
            0.7569234597143, 0.8324509121875274
        ]

        for sentences, score in zip(WKPoolingTest.sentence_pairs, scores):
            embedding = model.encode(sentences, convert_to_numpy=True)

            similarity = 1 - scipy.spatial.distance.cosine(
                embedding[0], embedding[1])
            assert abs(similarity - score) < 0.01
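
    # Hedged, illustrative refactoring (not part of the original tests): both test
    # methods repeat the same encode-and-compare loop, which could be factored into
    # a helper like the one below. The helper name and the tolerance default are
    # assumptions introduced here for illustration.
    def _assert_pair_similarities(self, model, expected_scores, tolerance=0.01):
        for sentences, score in zip(WKPoolingTest.sentence_pairs, expected_scores):
            embedding = model.encode(sentences, convert_to_numpy=True)
            similarity = 1 - scipy.spatial.distance.cosine(
                embedding[0], embedding[1])
            assert abs(similarity - score) < tolerance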
# Hidden-state layer indices to combine, parsed from a comma-separated string
layer_index = [int(i) for i in args.layer_index.split(',')]
# With --whitening, evaluate on the pre-whitened corpus files and "-white" targets
if args.whitening:
    args.sts_corpus += "white/"
    target_eval_files = [f + "-white" for f in target_eval_files]
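
# The --whitening flag above only switches to pre-whitened corpus/eval files; the
# whitening transform itself is applied elsewhere in the pipeline. Below is a
# minimal, hedged sketch of a standard embedding-whitening step (the function name
# and its use are illustrative assumptions, not this repository's implementation).
import numpy as np

def whiten_embeddings(embeddings):
    """Map an (n, d) embedding matrix to zero mean and (near-)identity covariance."""
    mu = embeddings.mean(axis=0, keepdims=True)
    cov = np.cov((embeddings - mu).T)          # (d, d) covariance matrix
    u, s, _ = np.linalg.svd(cov)               # SVD of the covariance matrix
    w = u @ np.diag(1.0 / np.sqrt(s))          # whitening matrix
    return (embeddings - mu) @ w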

# output_hidden_states=True is required so the pooling modules below can read
# the per-layer hidden states.
word_embedding_model = models.Transformer(args.encoder_name,
                                          model_args={
                                              'output_hidden_states': True,
                                              'batch_size': args.batch_size
                                          })

if args.last2avg:
    # "last2avg": combine the embedding-layer output (index 0) with the last layer (-1)
    layer_index = [0, -1]
if args.wk:
    pooling_model = models.WKPooling(
        word_embedding_model.get_word_embedding_dimension())
    logger.info("Using WKPooling")
else:
    pooling_model = LayerNPooling(
        args.pooling,
        word_embedding_model.get_word_embedding_dimension(),
        layers=layer_index)

model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
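
# LayerNPooling is defined elsewhere in this repository. The class below is a
# hedged, illustrative sketch of such a module (average the selected hidden-state
# layers, then mean-pool over tokens); its name, body, and the assumption that
# args.pooling == 'mean' are illustrative, not the repository's implementation.
import torch
from torch import nn

class IllustrativeLayerNPooling(nn.Module):
    def __init__(self, pooling, word_embedding_dimension, layers):
        super().__init__()
        self.pooling = pooling
        self.word_embedding_dimension = word_embedding_dimension
        self.layers = layers

    def forward(self, features):
        # 'all_layer_embeddings' is populated by models.Transformer when
        # output_hidden_states=True (the same key models.WKPooling reads);
        # each entry has shape (batch, seq_len, dim).
        hidden_states = features['all_layer_embeddings']
        selected = torch.stack([hidden_states[i] for i in self.layers]).mean(dim=0)

        # Masked mean pooling over tokens.
        mask = features['attention_mask'].unsqueeze(-1).float()
        summed = (selected * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        features['sentence_embedding'] = summed / counts
        return features

    def get_sentence_embedding_dimension(self):
        return self.word_embedding_dimension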

logger.info("Pool:{}, Encoder:{}, Whitening:{}".format(args.pooling,
                                                       args.encoder_name,
                                                       args.whitening))

evaluators = {
    task: []