nli_sentences = list(nli_sentences)
random.shuffle(nli_sentences)

# To determine the PCA matrix, we need some example sentence embeddings.
# Here, we compute the embeddings for 20k random sentences from the AllNLI dataset
pca_train_sentences = nli_sentences[0:20000]
train_embeddings = model.encode(pca_train_sentences, convert_to_numpy=True)

# Compute PCA on the train embeddings matrix
pca = PCA(n_components=new_dimension)
pca.fit(train_embeddings)
pca_comp = np.asarray(pca.components_)
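# Optional: check how much of the original embedding variance the new_dimension principal
# components retain (a value close to 1.0 means little information is lost)
logger.info("PCA explained variance: {:.3f}".format(pca.explained_variance_ratio_.sum()))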

# We add a dense layer to the model so that it directly produces embeddings with the new (reduced) size
dense = models.Dense(in_features=model.get_sentence_embedding_dimension(),
                     out_features=new_dimension,
                     bias=False,
                     activation_function=torch.nn.Identity())
dense.linear.weight = torch.nn.Parameter(torch.tensor(pca_comp))
model.add_module('dense', dense)
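# Optional sanity check: the extended model should now report embeddings of the new size
assert model.get_sentence_embedding_dimension() == new_dimension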

# Evaluate the model with the reduced embedding size
logger.info("Model with {} dimensions:".format(new_dimension))
stsb_evaluator(model)

# If you like, you can store the model on disk by uncommenting the following line
#model.save('models/bert-base-nli-stsb-mean-tokens-128dim')

# You can then load the adapted model that produces 128 dimensional embeddings like this:
#model = SentenceTransformer('models/bert-base-nli-stsb-mean-tokens-128dim')
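
# Evaluate the original, full-dimensional teacher model on the STS dev set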
logging.info("Teacher Performance:")
dev_evaluator_sts(teacher_model)

# If the student model has fewer dimensions than the teacher, compute a PCA projection to reduce the teacher's embedding dimensionality
if student_model.get_sentence_embedding_dimension() < teacher_model.get_sentence_embedding_dimension():
    logging.info("Student model has fewer dimensions than the teacher. Compute PCA for down projection")
    pca_sentences = train_sentences_nli[0:20000] + train_sentences_wikipedia[0:20000]
    pca_embeddings = teacher_model.encode(pca_sentences, convert_to_numpy=True)
    pca = PCA(n_components=student_model.get_sentence_embedding_dimension())
    pca.fit(pca_embeddings)

    # Add a Dense layer to the teacher that projects its embeddings down to the student embedding size
    dense = models.Dense(in_features=teacher_model.get_sentence_embedding_dimension(),
                         out_features=student_model.get_sentence_embedding_dimension(),
                         bias=False,
                         activation_function=torch.nn.Identity())
    dense.linear.weight = torch.nn.Parameter(torch.tensor(pca.components_))
    teacher_model.add_module('dense', dense)
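    # Optional sanity check: the teacher should now report the student's embedding dimensionality
    assert teacher_model.get_sentence_embedding_dimension() == student_model.get_sentence_embedding_dimension()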

    logging.info("Teacher Performance with {} dimensions:".format(teacher_model.get_sentence_embedding_dimension()))
    dev_evaluator_sts(teacher_model)



# We train the student_model such that it produces sentence embeddings similar to those of the teacher_model.
# For this, we need a large set of sentences. These sentences are embedded with the teacher model,
# and the student tries to mimic these embeddings. This is the same approach as used in: https://arxiv.org/abs/2004.09813
train_data = ParallelSentencesDataset(student_model=student_model,
                                      teacher_model=teacher_model,
                                      batch_size=inference_batch_size,
                                      use_embedding_cache=False)
train_data.add_dataset([[sent] for sent in train_sentences_nli], max_sentence_length=256)
train_data.add_dataset([[sent] for sent in train_sentences_wikipedia], max_sentence_length=256)

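# Wrap the distillation dataset in a standard DataLoader. MSELoss minimizes the mean squared
# error between the student embeddings and the teacher embeddings provided as labels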
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
train_loss = losses.MSELoss(model=student_model)
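
# Minimal sketch of how training could be started with these objects (evaluator, epochs,
# warmup_steps and output_path below are illustrative assumptions, not fixed values):
#student_model.fit(train_objectives=[(train_dataloader, train_loss)],
#                  evaluator=dev_evaluator_sts,
#                  epochs=1,
#                  warmup_steps=1000,
#                  output_path='output/model-distillation')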