import os

import torch
from torch.utils.data import DataLoader
from colorama import Fore, Style
from ray.tune import Trainable

import data
import helpers
from helpers import flatten_dict

# Project-local names used below but defined elsewhere in the repository
# (exact import paths are not visible in this excerpt): setup_training,
# get_optimizers, training_step, inject_tuned_hyperparameters, Evaluation,
# EWMA, text_proc.


def normal_training(config):
    device = torch.device(config['device'])
    print('Using device', device)

    exp, model, train_dataloader, eval_dataloader, loss_func = setup_training(config)
    exp.set_name(config['experiment_name'])

    model.train()
    model = model.to(device)
    optimizers = get_optimizers(model, config)
    evaluator = Evaluation(eval_dataloader, config)

    num_examples = 0
    for epoch in range(config['training']['training_epochs']):
        for idx, batch in enumerate(train_dataloader):
            batch = (batch[0].to(device), batch[1].to(device))
            num_examples += len(batch[0])
            loss, train_accuracy = training_step(batch, model, optimizers, loss_func)

            # Periodically log the training loss against examples seen.
            if idx % config['training']['log_every_n_batches'] == 0:
                print(epoch, num_examples, loss.detach().cpu().numpy())
                exp.log_metric('train_loss', loss.detach().cpu().numpy(),
                               step=num_examples, epoch=epoch)

            # Periodically run the evaluator and log every returned metric.
            if idx % config['training']['eval_every_n_batches'] == 0:
                results = evaluator.eval_model(model, loss_func)
                for metric in results:
                    print(metric, results[metric])
                    exp.log_metric(metric, results[metric],
                                   step=num_examples, epoch=epoch)
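# --- Hypothetical sketch (not in the original source) -----------------------
# `training_step` is referenced above but defined elsewhere. A minimal
# version consistent with both call sites (the plain call in normal_training
# and the gradient-accumulation call in TuneTrainable, which passes step=...)
# might look like the sketch below; the real implementation may differ.
def training_step_sketch(batch, model, optimizers, loss_func=None, step=True):
    inputs, targets = batch
    logits = model(inputs)
    criterion = loss_func if loss_func is not None else torch.nn.functional.cross_entropy
    loss = criterion(logits, targets)
    loss.backward()  # accumulate gradients into .grad
    if step:
        # Apply and clear gradients only on designated batches, so that
        # back_prop_every_n_batches > 1 accumulates gradients in between.
        for opt in optimizers:
            opt.step()
            opt.zero_grad()
    accuracy = (logits.argmax(dim=-1) == targets).float().mean().item()
    return loss, accuracy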
def normal_testing(config):
    # NOTE: the original header of this routine and the code that builds
    # `exp` and `model` (experiment setup plus a checkpoint restore, judging
    # by the print below) are missing from this excerpt. The two lines below
    # are an assumed reconstruction, not the original code.
    exp, model, _, _, _ = setup_training(config)
    model = model.to(torch.device(config['device']))
    print('Model loaded.')

    # Log the flattened config plus the trainable-parameter count.
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log_dict = helpers.flatten_dict(config)
    log_dict.update({'trainable_params': n_params})
    exp.log_parameters(log_dict)

    # Build the test dataloader over the merged test CSVs.
    test_dataset = data.CSVDatasetsMerger(helpers.get_datasets_paths(config, 'test'))
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=config['evaluation']['eval_batch_size'],
                                 shuffle=False,
                                 drop_last=False,
                                 num_workers=config['evaluation']['n_eval_workers'],
                                 collate_fn=text_proc)
    evaluator = Evaluation(test_dataloader, config)

    print('Testing ...')
    results, assets, image_fns = evaluator.eval_model(model, finished_training=True)
    print('Finished testing. Uploading ...')

    exp.log_metrics(results, step=0, epoch=0)
    for asset in assets:
        exp.log_asset_data(asset, step=0)
    for fn in image_fns:
        exp.log_image(fn, step=0)
    print('Finished uploading.')
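# --- Hypothetical sketch (not in the original source) -----------------------
# helpers.flatten_dict is used above to turn the nested config into flat
# key/value pairs that exp.log_parameters can display. Its real definition
# lives in helpers; an implementation along these lines is assumed:
def flatten_dict_sketch(d, parent_key='', sep='.'):
    """Flatten {'a': {'b': 1}} into {'a.b': 1} (illustrative only)."""
    items = {}
    for key, value in d.items():
        new_key = f'{parent_key}{sep}{key}' if parent_key else key
        if isinstance(value, dict):
            items.update(flatten_dict_sketch(value, new_key, sep=sep))
        else:
            items[new_key] = value
    return items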
class TuneTrainable(Trainable):

    def _setup(self, config):
        # Inject the hyperparameters sampled by Ray Tune into the nested
        # config before anything else reads it.
        inject_tuned_hyperparameters(config, config)
        os.chdir(os.path.dirname(os.path.realpath(__file__)))
        print('Trainable got the following config after injection', config)

        self.config = config
        self.device = self.config['device']
        self.exp, self.model, self.train_dataloader, self.eval_dataloader = setup_training(self.config)
        self.exp.set_name(config['experiment_name'] + self._experiment_id)
        self.exp_name = config['experiment_name'] + self._experiment_id
        self.exp.send_notification(title='Experiment ' + str(self._experiment_id) + ' ended')

        self.train_data_iter = iter(self.train_dataloader)
        self.model = self.model.to(self.device)
        self.model.train()

        # Log the flattened config plus the trainable-parameter count.
        n_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        log_dict = flatten_dict(config)
        log_dict.update({'trainable_params': n_params})
        self.exp.log_parameters(log_dict)

        self.optimizers = get_optimizers(self.model, self.config)
        self.evaluator = Evaluation(self.eval_dataloader, self.config)

        # Bookkeeping for logging, gradient accumulation, checkpointing,
        # and the accuracy-plateau heuristic.
        self.num_examples = 0
        self.batch_idx = 0
        self.epoch = 1
        self.ewma = EWMA(beta=0.75)
        self.last_accu = -1.0
        self.max_accu = -1.0
        self.back_prop_every_n_batches = config['training']['back_prop_every_n_batches']
        self.checkpoint_best = config['training']['checkpoint_best']

    def get_batch(self):
        # Draw the next batch, restarting the dataloader at epoch boundaries.
        try:
            return next(self.train_data_iter)
        except StopIteration:
            self.train_data_iter = iter(self.train_dataloader)
            self.batch_idx = 0
            self.epoch += 1
            return next(self.train_data_iter)

    def _train(self):
        total_log_step_loss = 0
        total_log_step_train_accu = 0
        total_log_step_n = 0
        for opt in self.optimizers:
            opt.zero_grad()

        while True:
            batch = self.get_batch()
            self.batch_idx += 1
            self.num_examples += len(batch[0])
            batch = (batch[0].to(self.device), batch[1].to(self.device))

            # Gradients are applied only every back_prop_every_n_batches
            # batches; in between they accumulate.
            loss, train_accu = training_step(
                batch, self.model, self.optimizers,
                step=(self.batch_idx % self.back_prop_every_n_batches == 0))

            total_log_step_loss += loss.detach().cpu().numpy()
            total_log_step_train_accu += train_accu
            total_log_step_n += 1

            if self.batch_idx % self.config['training']['log_every_n_batches'] == 0:
                avg_loss = total_log_step_loss / total_log_step_n
                avg_accu = total_log_step_train_accu / total_log_step_n
                print(f'{Fore.YELLOW}Total number of seen examples:', self.num_examples,
                      'Average loss of current log step:', avg_loss,
                      'Average train accuracy of current log step:', avg_accu,
                      f'{Style.RESET_ALL}')
                self.exp.log_metric('train_loss', avg_loss,
                                    step=self.num_examples, epoch=self.epoch)
                self.exp.log_metric('train_accuracy', avg_accu,
                                    step=self.num_examples, epoch=self.epoch)
                total_log_step_loss = 0
                total_log_step_train_accu = 0
                total_log_step_n = 0

            if (self.batch_idx + 1) % self.config['training']['eval_every_n_batches'] == 0:
                results, assets, image_fns = self.evaluator.eval_model(self.model)
                metric_value = results[self.config['tune']['discriminating_metric']]
                print(self.config['tune']['discriminating_metric'], metric_value)
                self.exp.log_metrics(results, step=self.num_examples, epoch=self.epoch)
                for asset in assets:
                    self.exp.log_asset_data(asset, step=self.num_examples)
                for fn in image_fns:
                    self.exp.log_image(fn, step=self.num_examples)

                # Plateau heuristic: flag runs whose discriminating metric
                # moved neither relative to its moving average nor to the
                # previous evaluation.
                accu_diff_avg = abs(metric_value - self.ewma.get())
                accu_diff_cons = abs(metric_value - self.last_accu)
                no_change_in_accu = 1 if (accu_diff_avg < 0.0005
                                          and accu_diff_cons < 0.002
                                          and self.num_examples > 70000) else 0
                self.ewma.update(metric_value)
                self.last_accu = metric_value
                if self.max_accu < metric_value:
                    self.max_accu = metric_value
                    if self.checkpoint_best:
                        self.save_checkpoint('checkpoints', self.exp_name + '.pt')
                        print(f'{Fore.GREEN}New best model saved.{Style.RESET_ALL}')
                self.exp.log_metric('max_accuracy', self.max_accu,
                                    step=self.num_examples, epoch=self.epoch)

                # Report back to Ray Tune once per evaluation.
                training_results = {
                    self.config['tune']['discriminating_metric']: self.max_accu,
                    'num_examples': self.num_examples,
                    'no_change_in_accu': no_change_in_accu
                }
                return training_results

    def _save(self, checkpoint_dir):
        return self.save_checkpoint(checkpoint_dir, 'checkpoint_file.pt')

    def save_checkpoint(self, checkpoint_dir, fname='checkpoint_file.pt'):
        print(f'{Fore.CYAN}Saving model ...{Style.RESET_ALL}')
        # e.g. the 'checkpoints' directory may not exist on the first save.
        os.makedirs(checkpoint_dir, exist_ok=True)
        save_dict = {'model_state_dict': self.model.state_dict()}
        for i, optimizer in enumerate(self.optimizers):
            save_dict['op_' + str(i) + '_state_dict'] = optimizer.state_dict()
        checkpoint_path = os.path.join(checkpoint_dir, fname)
        torch.save(save_dict, checkpoint_path)
        return checkpoint_path

    def _restore(self, checkpoint_path):
        # map_location keeps restores working when the checkpoint was written
        # on a different device.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        for i, optimizer in enumerate(self.optimizers):
            optimizer.load_state_dict(checkpoint['op_' + str(i) + '_state_dict'])

    def stop(self):
        # Run one final, full evaluation before the trainable shuts down.
        results, assets, image_fns = self.evaluator.eval_model(
            self.model, finished_training=True)
        self.exp.log_metrics(results, step=self.num_examples, epoch=self.epoch)
        for asset in assets:
            self.exp.log_asset_data(asset, step=self.num_examples)
        for fn in image_fns:
            self.exp.log_image(fn, step=self.num_examples)
        return super().stop()
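# --- Hypothetical sketch (not in the original source) -----------------------
# EWMA(beta=0.75) above tracks an exponentially weighted moving average of
# the discriminating metric for the plateau heuristic in _train. The real
# class is defined elsewhere; a minimal bias-corrected version is assumed:
class EWMASketch:
    def __init__(self, beta=0.9):
        self.beta = beta   # weight given to the running average
        self.avg = 0.0     # biased running average
        self.steps = 0     # update count, used for bias correction

    def update(self, value):
        self.steps += 1
        self.avg = self.beta * self.avg + (1 - self.beta) * value

    def get(self):
        # Bias-corrected estimate; 0.0 before the first update (matching the
        # first-eval call order in _train, where get() precedes update()).
        if self.steps == 0:
            return 0.0
        return self.avg / (1 - self.beta ** self.steps)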