コード例 #1
0
    def train(self):
        """Train ``self.task_model`` epoch by epoch with early stopping.

        Each epoch:
          1. runs ``TaskManager.step`` over ``self.generator`` with ``self.optimizer``;
          2. saves a checkpoint of ``task_model.mwe_f`` to ``{self.save_path}_{epoch}.pt``;
          3. evaluates with the ``mwe_f`` parameters frozen. If ``'heldout_data'``
             is in ``self.params``, the score is the *negated* result of
             ``evaluateOnHeldoutDataset``; otherwise it is the score returned by
             ``evaluateOnTratz``;
          4. if the score beats the best so far (``performance < score``), saves
             ``mwe_f`` to ``{self.save_path}.pt``. For the
             ``JointTrainingSkipGramMinimization`` objective it also saves the
             embeddings to ``{self.save_path}_embeddings.pt``.

        Once ``self.params['early_stopping']`` consecutive evaluations pass
        without improvement, the next epoch does not train. Instead it appends a
        summary line (best score, elapsed seconds, dev file, seed, timestamp) to
        ``{self.params['save_path']}_report``, writing a header first if the file
        is new, and returns.

        NOTE(review): if ``self.num_epochs`` is exhausted before early stopping
        triggers, no report line is written. Confirm this is intended.
        """
        tm = TaskManager()
        begin_time = time.time()
        performance = None  # best evaluation score seen so far
        epochs_since_last_improvement = 0

        # NOTE(review): 'output' is created/truncated but never written to. The
        # original kept the handle open for the whole run and leaked it on the
        # early-stopping return. Keep the filesystem side effect but release the
        # handle immediately; remove if nothing downstream needs the empty file.
        with open('output', 'w+'):
            pass

        for epoch in range(self.num_epochs):
            if epochs_since_last_improvement >= self.params['early_stopping']:
                # Stop, because no improvement for more than threshold
                print(
                    f"No improvements for {epochs_since_last_improvement}. Training stopped. Report saved at: {self.params['save_path']}_report")
                report_path = f'{self.params["save_path"]}_report'
                needs_header = not os.path.exists(report_path)
                with open(report_path, 'a+') as report_f:
                    if needs_header:
                        # NOTE(review): "Time" appears twice. The last column holds
                        # datetime.now(). Kept as-is so existing reports stay uniform.
                        report_f.write("Performance\tTime\tEval\tSeed\tTime\n")
                    report_f.write(
                        f"{performance}\t{format_number(time.time() - begin_time)}\t{self.params['evaluation']['evaluation_dev_file']}\t{self.params['random_seed']}\t{datetime.now()}\n")
                return

            batch_construction = self.minimization_types[self.params['train_objective']]
            running_epoch_loss = tm.step(task_model=self.task_model, generator=self.generator,
                                         optimizer=self.optimizer, batch_construction=batch_construction)
            torch.save(self.task_model.mwe_f.state_dict(), f"{self.save_path}_{epoch}.pt")

            # Prepare for evaluation: clear gradients and freeze the network.
            self.optimizer.zero_grad()
            for param in self.task_model.mwe_f.parameters():
                param.requires_grad = False
            self.task_model.eval()
            try:
                if 'heldout_data' in self.params:
                    # Negated because improvement is detected with `performance < score`
                    # (presumably the held-out metric is a loss -- confirm).
                    score = -tm.evaluateOnHeldoutDataset(params=self.params, task_model=self.task_model,
                                                         generator=self.dev_generator,
                                                         batch_construction=batch_construction)
                else:
                    score, model_number = tm.evaluateOnTratz(self.params, self.task_model.mwe_f, self.sg_embeddings,
                                                             self.embedding_device, self.device)
                    print(f'Max was with: {model_number}')
            finally:
                # Always unfreeze, even if evaluation raised.
                for param in self.task_model.mwe_f.parameters():
                    param.requires_grad = True

            # Zero whatever gradients might have been computed
            self.optimizer.zero_grad()

            # Keep the better model. `performance is None` only on the first
            # evaluation, when the counter is still 0, so resetting it is a no-op.
            if performance is None or performance < score:
                epochs_since_last_improvement = 0
                performance = score
                print(f"Save new best: {score}")
                torch.save(self.task_model.mwe_f.state_dict(), f"{self.save_path}.pt")
                if self.params['train_objective'] == 'JointTrainingSkipGramMinimization':
                    self.task_model.embedding_function.to_saved_file(f"{self.save_path}_embeddings.pt")
            else:
                epochs_since_last_improvement += 1

            # L1 norm of all mwe_f parameters. Reduce each tensor to a Python float
            # first: np.sum over a list of tensors goes through Tensor.__array__ /
            # .numpy(), which raises for tensors requiring grad (as these do again
            # here) or living on a GPU.
            model_norm = sum((p.detach().abs().sum().item() for p in self.task_model.mwe_f.parameters()), 0.0)
            print(
                f"{epoch} - {format_number(score)}; {self.optimizer.param_groups[0]['lr']}. Current score {format_number(np.mean(running_epoch_loss))}. Took a total of {format_number(time.time() - begin_time)}s; {self.optimizer.param_groups[0]['lr']}. Model norm: {format_number(model_norm)}. Best is {format_number(performance)}. Epochs {epochs_since_last_improvement}")

            self.task_model.train()