def train(self):
    # Assumes the module-level imports used throughout this file: time, os,
    # numpy as np, torch, datetime (from datetime), plus the project's
    # TaskManager and format_number helpers.
    tm = TaskManager()
    begin_time = time.time()
    f = open('output', 'w+')
    performance = None
    epochs_since_last_improvement = 0
    for epoch in range(self.num_epochs):
        running_epoch_loss = []
        if epochs_since_last_improvement < self.params['early_stopping']:
            running_epoch_loss = tm.step(
                task_model=self.task_model,
                generator=self.generator,
                optimizer=self.optimizer,
                batch_construction=self.minimization_types[self.params['train_objective']])
        else:
            # Stop: no improvement for more epochs than the early-stopping threshold.
            print(f"No improvements for {epochs_since_last_improvement} epochs. Training stopped. "
                  f"Report saved at: {self.params['save_path']}_report")
            report_path = f"{self.params['save_path']}_report"
            if not os.path.exists(report_path):
                # Write the header once. The last column holds the wall-clock date,
                # so it is labeled "Date" rather than a second "Time".
                with open(report_path, 'w+') as report_f:
                    report_f.write("Performance\tTime\tEval\tSeed\tDate\n")
            with open(report_path, 'a+') as report_f:
                # Append one result line.
                report_f.write(
                    f"{performance}\t{format_number(time.time() - begin_time)}\t"
                    f"{self.params['evaluation']['evaluation_dev_file']}\t"
                    f"{self.params['random_seed']}\t{datetime.now()}\n")
            # Save a final checkpoint tagged with the stopping epoch, then return.
            torch.save(self.task_model.mwe_f.state_dict(), f"{self.save_path}_{epoch}.pt")
            f.close()  # close the log handle on the early-return path as well
            return

        # Evaluate. Separate branch for clarity.
        if epochs_since_last_improvement < self.params['early_stopping']:
            # Prepare for evaluation: zero the gradients, then freeze the network.
            self.optimizer.zero_grad()
            for param in self.task_model.mwe_f.parameters():
                param.requires_grad = False
            self.task_model.eval()
            if 'heldout_data' in self.params:
                # The held-out loss is minimized, so negate it to get a score to maximize.
                score = -tm.evaluateOnHeldoutDataset(
                    params=self.params,
                    task_model=self.task_model,
                    generator=self.dev_generator,
                    batch_construction=self.minimization_types[self.params['train_objective']])
            else:
                score, model_number = tm.evaluateOnTratz(
                    self.params, self.task_model.mwe_f, self.sg_embeddings,
                    self.embedding_device, self.device)
                print(f'Max was with: {model_number}')
            # Unfreeze the network and discard any gradients computed during evaluation.
            for param in self.task_model.mwe_f.parameters():
                param.requires_grad = True
            self.optimizer.zero_grad()

            # Keep the better model.
            if performance is None or performance < score:
                epochs_since_last_improvement = 0
                performance = score
                print(f"Save new best: {score}")
                torch.save(self.task_model.mwe_f.state_dict(), f"{self.save_path}.pt")
                if self.params['train_objective'] == 'JointTrainingSkipGramMinimization':
                    self.task_model.embedding_function.to_saved_file(
                        f"{self.save_path}_embeddings.pt")
            else:
                epochs_since_last_improvement += 1
            # self.scheduler.step(metrics=-score)
            model_norm = sum(torch.sum(torch.abs(x)).item()
                             for x in self.task_model.mwe_f.parameters())
            print(f"{epoch} - {format_number(score)}; lr {self.optimizer.param_groups[0]['lr']}. "
                  f"Mean epoch loss: {format_number(np.mean(running_epoch_loss))}. "
                  f"Took a total of {format_number(time.time() - begin_time)}s. "
                  f"Model norm: {format_number(model_norm)}. "
                  f"Best is {format_number(performance)}. "
                  f"Epochs since last improvement: {epochs_since_last_improvement}")
            # Switch back to training mode for the next epoch.
            self.task_model.train()
    f.close()
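# --- Example: the freeze -> eval -> unfreeze pattern used in train() above. ---
# A minimal self-contained sketch of that idiom, not part of this repo's API:
# a hypothetical nn.Linear stands in for task_model.mwe_f, and a dummy SGD
# optimizer stands in for self.optimizer.
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in for the MWE composition function
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

optimizer.zero_grad()
for p in model.parameters():
    p.requires_grad = False   # freeze: no gradients accumulate during evaluation
model.eval()                  # disable dropout / switch batch-norm to eval statistics

with torch.no_grad():         # also skip building the autograd graph during the forward pass
    score = model(torch.randn(4, 10)).sum().item()

for p in model.parameters():
    p.requires_grad = True    # unfreeze for the next training epoch
model.train()
optimizer.zero_grad()         # discard anything that may have been accumulated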