def train(self):
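        # Assumes module-level imports: os, time, numpy as np, torch,
        # from datetime import datetime, plus TaskManager, AutoEncoderPreTraining,
        # and the format_number helper.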
        tm = TaskManager()
        begin_time = time.time()
        running_loss = []
        train_iter = 0
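        # Plain-text output handle; kept open for the whole run and closed on exit.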
        f = open('output', 'w+')
        performance = None
        epochs_since_last_improvement = 0

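        # Warm up the composition function with an auto-encoder pre-training pass.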
        ptm = AutoEncoderPreTraining(self.sg_embeddings, self.mwe_f)
        optimizer = torch.optim.SGD(ptm.mwe_f.parameters(),
                                    lr=self.learning_rate,
                                    weight_decay=self.params['weight_decay'])
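        # Run the pre-training pass through the task manager.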
        tm.step(task_model=ptm,
                generator=self.generator,
                optimizer=optimizer,
                batch_construction=self.prepare_autoencoder_batch)
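        # Main training loop: stop once the dev score has not improved for
        # `early_stopping` consecutive epochs.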
        for epoch in range(self.num_epochs):
            running_epoch_loss = []
            if epochs_since_last_improvement < self.params['early_stopping']:
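                # One epoch of the configured training objective.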
                running_epoch_loss = tm.step(
                    task_model=self.task_model,
                    generator=self.generator,
                    optimizer=self.optimizer,
                    batch_construction=self.minimization_types[
                        self.params['train_objective']])
            else:
                print(
                    f"No improvement for {epochs_since_last_improvement} epochs. "
                    f"Training stopped. Report saved at: {self.params['save_path']}_report"
                )
                if not os.path.exists(f'{self.params["save_path"]}_report'):
                    with open(f'{self.params["save_path"]}_report',
                              'w+') as report_f:
                        # write header
                        report_f.write(
                            "Performance\tTime\tEval\tSeed\tDate\n")

                with open(f'{self.params["save_path"]}_report',
                          'a+') as report_f:
                    # write lines
                    report_f.write(
                        f"{performance}\t"
                        f"{format_number(time.time() - begin_time)}\t"
                        f"{self.params['evaluation']['evaluation_dev_file']}\t"
                        f"{self.params['random_seed']}\t"
                        f"{datetime.now()}\n"
                    )

                f.close()
                return

            if epochs_since_last_improvement < self.params['early_stopping']:
                # Prepare for evaluation
                # Zero the gradients
                self.optimizer.zero_grad()
                # Freeze the network for evaluation
                for param in self.task_model.mwe_f.parameters():
                    param.requires_grad = False

                self.task_model.eval()

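                # Evaluate on the dev set; returns the score and an identifier
                # of the best-scoring model.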
                score, model_number = tm.evaluate(self.params,
                                                  self.task_model.mwe_f,
                                                  self.sg_embeddings,
                                                  self.device)
                print(f'Max score was with model: {model_number}')
                # Unfreeze the network after evaluation
                for param in self.task_model.mwe_f.parameters():
                    param.requires_grad = True

                # Zero whatever gradients might have been computed
                self.optimizer.zero_grad()

                # Keep the better model
                if performance is None or performance < score:
                    epochs_since_last_improvement = 0
                    performance = score
                    print(f"Save new best: {score}")
                    torch.save(self.task_model.mwe_f.state_dict(),
                               f"{self.save_path}.pt")
                    if self.params[
                            'train_objective'] == 'JointTrainingSkipGramMinimization':
                        self.task_model.embedding_function.to_saved_file(
                            f"{self.save_path}_embeddings.pt")
                else:
                    epochs_since_last_improvement += 1

                # self.scheduler.step(metrics=-score)
                print(
                    f"{epoch} - {format_number(score)}; lr {self.optimizer.param_groups[0]['lr']}. "
                    f"Mean epoch loss: {format_number(np.mean(running_epoch_loss))}. "
                    f"Took {format_number(time.time() - begin_time)}s total. "
                    f"Model norm: {format_number(sum(torch.sum(torch.abs(x)).item() for x in self.task_model.mwe_f.parameters()))}. "
                    f"Best is {format_number(performance)}. "
                    f"Epochs since improvement: {epochs_since_last_improvement}"
                )

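                # Switch back to training mode for the next epoch.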
                self.task_model.train()

        f.close()