def __init__(self, model, criterion, metric_ftns, config, device, data_loader, evaluation=True):
    """
    Initiates the Base tester.

    :param model: The model to test.
    :param criterion: The loss function.
    :param metric_ftns: The metrics on which the model will be evaluated during test time.
    :param config: Configuration file.
    :param device: The device to use for the computations.
    :param data_loader: Dataloader for the dataset.
    :param evaluation: True if the tester is used as an evaluator while training, False if used for testing the model.
    """
    self.config = config
    self.logger = config.get_logger('tester', config['tester']['verbosity'])
    self.predictions_file_name = config.get_predictions_file_name()

    self.model = model
    self.criterion = criterion
    self.metric_ftns = metric_ftns
    self.device = device
    self.data_loader = data_loader
    self.evaluation = evaluation

    # Get testing configurations
    cfg_tester = config['tester']

    # Setup visualization writer instance
    self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_tester['tensorboard'])
def __init__(self, model, criterion, metric_ftns, optimizer, config):
    self.config = config
    self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

    # Setup GPU device if available and move the model onto the configured device.
    self.device, device_ids = self._prepare_device(config['n_gpu'])
    self.model = model.to(self.device)
    if len(device_ids) > 1:
        self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)

    self.criterion = criterion
    self.metric_ftns = metric_ftns
    self.optimizer = optimizer

    cfg_trainer = config['trainer']
    self.epochs = cfg_trainer['epochs']
    self.save_period = cfg_trainer['save_period']
    self.monitor = cfg_trainer.get('monitor', 'off')

    # Configuration to monitor model performance and save the best checkpoint.
    if self.monitor == 'off':
        self.mnt_mode = 'off'
        self.mnt_best = 0
    else:
        self.mnt_mode, self.mnt_metric = self.monitor.split()
        assert self.mnt_mode in ['min', 'max']

        self.mnt_best = inf if self.mnt_mode == 'min' else -inf
        self.early_stop = cfg_trainer.get('early_stop', inf)

    self.start_epoch = 1
    self.checkpoint_dir = config.save_dir

    # Setup visualization writer instance.
    self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])

    if config.resume is not None:
        self._resume_checkpoint(config.resume)
def __init__(self, model, criterion, metric_ftns, optimizer, config):
    self.config = config
    self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

    self.model = model
    self.criterion = criterion
    self.metric_ftns = metric_ftns
    self.optimizer = optimizer

    cfg_trainer = config['trainer']
    self.epochs = cfg_trainer['epochs']
    self.save_period = cfg_trainer['save_period']
    self.monitor = cfg_trainer.get('monitor', 'off')
    self.save = cfg_trainer['save']

    # Configuration to monitor model performance and save the best checkpoint.
    if self.monitor == 'off':
        self.mnt_mode = 'off'
        self.mnt_best = 0
    else:
        self.mnt_mode, self.mnt_metric = self.monitor.split()
        assert self.mnt_mode in ['min', 'max']

        self.mnt_best = inf if self.mnt_mode == 'min' else -inf
        self.early_stop = cfg_trainer.get('early_stop', inf)
        if self.early_stop <= 0:
            self.early_stop = inf

    self.start_epoch = 1
    self.checkpoint_dir = config.save_dir
    self.tensorboard = cfg_trainer['tensorboard']

    # Setup visualization writer instance.
    self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])

    if config.resume is not None:
        self._resume_checkpoint(config.resume)
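# A minimal sketch (not part of the original snippets) of the 'trainer' section that the
# trainer constructors above read from `config`. The key names mirror the lookups in the
# code (epochs, save_period, verbosity, monitor, early_stop, tensorboard, save); the
# values shown are illustrative assumptions only.
example_config = {
    'trainer': {
        'epochs': 100,              # total number of training epochs
        'save_period': 1,           # save a checkpoint every N epochs
        'verbosity': 2,             # logger verbosity passed to config.get_logger
        'monitor': 'min val_loss',  # "<mode> <metric>" or 'off' to disable monitoring
        'early_stop': 10,           # non-positive values are treated as "never stop early"
        'tensorboard': True,        # enables the TensorboardWriter
        'save': True,               # extra flag read by the second constructor above
    }
}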
class BaseGANTrainer():
    """
    Base trainer for GAN models: wires up the training and model configs, logging,
    optional early stopping, checkpoint loading, and the per-epoch train loop.
    """

    def __init__(self, t_c, m_c):
        if "t_c" not in vars(self):
            self.t_c = edict()
            self.t_c.want_log = t_c.want_log
            self.t_c.use_early_stopping = t_c.use_early_stopping
            self.t_c.img_size = t_c.img_size
            self.t_c.save = t_c.save
            self.t_c.save_every = t_c.save_every
            self.t_c.save_ext = t_c.save_ext
            self.t_c.load = t_c.load
            if self.t_c.load:
                self.t_c.load_config = t_c.load_config
                self.t_c.load_model = t_c.load_model
                self.t_c.load_netD = t_c.load_netD
                self.t_c.load_netG = t_c.load_netG
                self.t_c.load_optimD = t_c.load_optimD
                self.t_c.load_optimG = t_c.load_optimG
            self.t_c.test = t_c.test
            self.t_c.test_every = t_c.test_every
            self.t_c.sample_size = t_c.sample_size
            self.t_c.batch_size = t_c.batch_size
            self.t_c.shuffle = t_c.shuffle
            self.t_c.num_workers = t_c.num_workers
            self.t_c.epochs = t_c.epochs
            self.t_c.summary_dir = t_c.summary_dir
            self.t_c.checkpoint_dir = t_c.checkpoint_dir
            self.t_c.log_dir = t_c.log_dir
            self.t_c.out_dir = t_c.out_dir
            self.t_c.data_roots = t_c.data_roots

        if "m_c" not in vars(self):
            self.m_c = edict(m_c)

        self.summary_writer = TensorboardWriter(self.t_c.summary_dir)
        # self._stop_training = False

        if self.t_c.use_early_stopping:
            self.early_stopper_D = EarlyStopping2(patience=20, low_threshold=0.009, up_threshold=0.99, verbose=False)
            self.early_stopper_G = EarlyStopping2(patience=20, low_threshold=0.009, up_threshold=0.99, verbose=False)

        self.init_model()

        if self.t_c.load:
            self.model.load(path=self.t_c.load_model,
                            load_config=self.t_c.load_config,
                            load_netD=self.t_c.load_netD,
                            load_netG=self.t_c.load_netG,
                            load_optimD=self.t_c.load_optimD,
                            load_optimG=self.t_c.load_optimG)
            print("NEW MODEL LOADED CONFIG")
            pprint(self.model.config)

        self.fixed_noise = self.model.generate_fixed_noise(sample_size=32)

    def init_model(self):
        raise NotImplementedError

    def run(self):
        """The main operator: run training and handle CTRL+C gracefully."""
        try:
            self.train()
        except KeyboardInterrupt:
            print("")
            print(70 * "-")
            print("You have entered CTRL+C... Wait to finalize")
            # Ask the user whether to save the model parameters before quitting.
            answer = yes_or_no("Want to save model parameters before quitting?")
            if answer:
                self.model.save(self.t_c.checkpoint_dir, self.t_c.save_ext)
            exit(-1)

    def train(self):
        self.start = self.model.epochs_trained
        self.end = self.t_c.epochs + self.start
        for epoch in range(self.start, self.end):
            dataloader = self._get_dataloader()
            self._train_one_epoch(dataloader, epoch)

    def _train_one_epoch(self, dataloader, epoch):
        all_batches = len(dataloader)
        self.model.reset_meters()

        # For each batch in the dataloader
        for batch_num, batch_data in enumerate(dataloader):
            errD, errG, D_x, D_G_z1, D_G_z2 = self.model._train_step(batch_data)

            # Log batch stats to the terminal
            self._log_train_step_stats(epoch, self.end, batch_num, all_batches,
                                       errD, errG, D_x, D_G_z1, D_G_z2)

            if self.t_c.use_early_stopping:
                self._stop_training = self.early_stopper_D.feed(D_x) or self.early_stopper_G.feed(D_G_z2)
                if self._stop_training:
                    exit(-1)

        # Save model
        if self.t_c.save and epoch % self.t_c.save_every == 0:
            self.model.save(self.t_c.checkpoint_dir, self.t_c.save_ext)

        # Test model
        if self.t_c.test and epoch % self.t_c.test_every == 0:
            fake_samples = self.model.generate_images(sample_size=self.t_c.sample_size)
            fixed_samples = self.model.generater_fixed_images(self.fixed_noise)
            self.summary_writer.image_summary("Fake", fake_samples, epoch)
            self.summary_writer.image_summary("FixedNoise", fixed_samples, epoch)

        if self.t_c.want_log:
            d_mean, d_std = self.model.meterD.value()
            g_mean, g_std = self.model.meterG.value()
            self.summary_writer.plot_losses("LossesMeans", "D", "G", d_mean, g_mean, epoch)
            self.summary_writer.plot_losses("LossesStds", "D", "G", d_std, g_std, epoch)

        self.model.epochs_trained += 1

    def _get_dataloader(self):
        dataset = self._get_dataset()
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=self.t_c.batch_size,
                                                 shuffle=self.t_c.shuffle,
                                                 num_workers=self.t_c.num_workers)
        return dataloader

    def _get_dataset(self):
        dataset = ArtDataset(self.t_c.data_roots,
                             transforms_=[
                                 transforms.Resize(self.t_c.img_size),
                                 transforms.CenterCrop(self.t_c.img_size),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                             ])
        return dataset

    def _log_train_step_stats(self, epoch_num, all_epochs, batch_num, all_batches,
                              errD, errG, D_x, D_G_z1, D_G_z2):
        # Output training stats
        print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
              % (epoch_num, all_epochs, batch_num, all_batches, errD, errG, D_x, D_G_z1, D_G_z2))
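# A minimal subclassing sketch (illustrative only, not from the original code): the base
# class leaves init_model() unimplemented, so a concrete trainer is expected to construct
# self.model there. `DCGANModel` below is a hypothetical model class exposing the interface
# the base trainer relies on (_train_step, meters, save/load, generate_* helpers).
class DCGANTrainer(BaseGANTrainer):
    def init_model(self):
        # Build the concrete GAN model from the model config held in self.m_c.
        self.model = DCGANModel(self.m_c)

# Hypothetical usage, assuming `train_config` and `model_config` edicts exist:
# trainer = DCGANTrainer(t_c=train_config, m_c=model_config)
# trainer.run()  # trains for t_c.epochs, saving/testing at the configured intervals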
class Trainer():

    def __init__(self, model, criterion, metrics_name, optimizer, train_loader, logger,
                 log_dir, nb_epochs, save_dir, device="cuda:0", log_step=10, start_epoch=0,
                 enable_tensorboard=True, valid_loader=None, lr_scheduler=None,
                 monitor="min val_loss", early_stop=10, save_epoch_period=1, resume=""):
        self.model = model
        self.criterion = criterion
        self.metrics_name = metrics_name
        self.optimizer = optimizer

        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.len_epoch = len(self.train_loader)
        self.do_validation = (self.valid_loader is not None)
        self.lr_scheduler = lr_scheduler

        self.log_step = log_step
        self.epochs = nb_epochs
        self.start_epoch = start_epoch + 1
        self.logger = logger
        self.device = device
        self.save_period = save_epoch_period

        self.writer = TensorboardWriter(log_dir, self.logger, enable_tensorboard)
        self.train_metrics = MetricTracker('loss', *self.metrics_name, writer=self.writer)
        self.valid_metrics = MetricTracker('loss', *self.metrics_name, writer=self.writer)

        self.checkpoint_dir = save_dir

        if monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = monitor.split()
            assert self.mnt_mode in ['min', 'max']
            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = early_stop

        if resume != "":
            self._resume_checkpoint(resume_path=resume)

        self.model.to(self.device)

    def train(self):
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            log = {'epoch': epoch}
            log.update(result)

            self.logger.info('    {:15s}: {}'.format(str("mnt best"), self.mnt_best))
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))

            best = False
            if self.mnt_mode != 'off':
                try:
                    # Check whether model performance improved, according to the monitored metric (mnt_metric).
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] < self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] > self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if (not_improved_count > self.early_stop) and (self.early_stop > 0):
                    self.logger.info("Validation performance didn't improve for {} epochs. "
                                     "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, best)

    def _train_epoch(self, epoch):
        self.model.train()
        self.train_metrics.reset()
        start_time = time.time()

        for batch_idx, sample in enumerate(self.train_loader):
            data = sample['image']
            target = sample['mask']
            data, target = data.to(self.device), target.to(self.device)

            current_lr = self.lr_scheduler(self.optimizer, batch_idx, epoch)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met_name in self.metrics_name:
                self.train_metrics.update(met_name, getattr(metrics, met_name)(output, target))

            if batch_idx % self.log_step == 0:
                time_to_run = time.time() - start_time
                start_time = time.time()
                speed = self.log_step / time_to_run
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f} LR: {:.6f} Speed: {:.4f}iters/s'
                                  .format(epoch, self._progress(batch_idx), loss.item(), current_lr, speed))

                for met_name in self.metrics_name:
                    self.writer.add_scalar(met_name, self.train_metrics.avg(met_name))
                self.writer.add_scalar('loss', self.train_metrics.avg('loss'))
                self.writer.add_scalar("lr", current_lr)
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

            assert batch_idx <= self.len_epoch

        log = self.train_metrics.result()

        if self.do_validation:
            print("Start validation")
            val_log, iou_classes = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
            for key, value in iou_classes.items():
                log.update({key: value})

        return log

    def _valid_epoch(self, epoch):
        self.model.eval()
        self.valid_metrics.reset()
        iou_tracker = metrics.IoU(2)

        with torch.no_grad():
            for batch_idx, sample in enumerate(self.valid_loader):
                data = sample['image']
                target = sample['mask']
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                self.writer.set_step((epoch - 1) * len(self.valid_loader) + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

                # Threshold the first output channel at 0.5 to get binary predictions for the IoU tracker.
                target = target.cpu().numpy()
                output = output[:, 0]
                output = output.data.cpu().numpy()
                pred = np.zeros_like(output)
                pred[output > 0.5] = 1
                pred = pred.astype(np.int64)
                for i in range(len(target)):
                    iou_tracker.add_batch(target[i], pred[i])

        iou_classes = iou_tracker.get_iou()
        for key, value in iou_classes.items():
            self.writer.add_scalar(key, value)
        self.writer.add_scalar('val_loss', self.valid_metrics.avg('loss'))
        for met_name in self.metrics_name:
            self.writer.add_scalar(met_name, self.valid_metrics.avg(met_name))

        # for name, p in self.model.named_parameters():
        #     print(name, p)
        #     self.writer.add_histogram(name, p.cpu().data.numpy(), bins='auto')

        return self.valid_metrics.result(), iou_classes

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        current = batch_idx
        total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, also save a copy of the checkpoint as 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            # 'config': self.config
        }
        filename = str(self.checkpoint_dir / 'checkpoint-epoch{:06d}.pth'.format(epoch))
        torch.save(state, filename)
        self.delete_checkpoint()
        self.logger.info("Saving checkpoint: {} ...".format(filename))

        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def delete_checkpoint(self):
        # Keep only the five most recent epoch checkpoints.
        checkpoints_file = list(self.checkpoint_dir.glob("checkpoint-epoch*.pth"))
        checkpoints_file.sort()
        for checkpoint_file in checkpoints_file[:-5]:
            os.remove(str(checkpoint_file.absolute()))

    def _resume_checkpoint(self, resume_path):
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
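# An illustrative wiring sketch (not part of the original code). It assumes the Trainer
# class above and its project dependencies (TensorboardWriter, MetricTracker, metrics)
# are importable; the toy dataset, stand-in model, and constant-LR "scheduler" below are
# hypothetical and exist only to show the expected argument shapes.
import logging
from pathlib import Path

import torch
from torch.utils.data import DataLoader, Dataset


class ToyMaskDataset(Dataset):
    """Yields the {'image': ..., 'mask': ...} samples the Trainer's epoch loops expect."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {'image': torch.randn(3, 64, 64),
                'mask': torch.randint(0, 2, (1, 64, 64)).float()}


def constant_lr(optimizer, batch_idx, epoch, lr=1e-3):
    # The Trainer calls lr_scheduler(optimizer, batch_idx, epoch) and logs the returned value.
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr


logging.basicConfig(level=logging.DEBUG)
save_dir = Path('checkpoints/exp0')           # a pathlib.Path: "/" joins and .glob() are used on it
save_dir.mkdir(parents=True, exist_ok=True)

model = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)   # stand-in segmentation "model"
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

trainer = Trainer(model=model,
                  criterion=torch.nn.BCEWithLogitsLoss(),
                  metrics_name=[],            # names would be resolved via getattr(metrics, name)
                  optimizer=optimizer,
                  train_loader=DataLoader(ToyMaskDataset(), batch_size=4),
                  logger=logging.getLogger('trainer'),
                  log_dir='runs/exp0',
                  nb_epochs=2,
                  save_dir=save_dir,
                  device='cpu',
                  enable_tensorboard=False,
                  lr_scheduler=constant_lr,
                  monitor='off')              # no validation loader, so disable monitoring
trainer.train()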
def __init__(self, model, criterion, metric_ftns, optimizer, config):
    """
    :param model: the model
    :param criterion: the loss criterion
    :param metric_ftns: the metric functions (evaluation metrics)
    :param optimizer: the optimizer
    :param config: the configuration
    """
    # Configuration
    self.config = config
    # Logger
    self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

    # Prepare the compute device: returns self.device and the list of GPU ids.
    # Setup GPU device if available, move the model onto the configured device.
    self.device, device_ids = self._prepare_device(config['n_gpu'])
    # Move the model onto the device.
    self.model = model.to(self.device)
    # # 1. Single GPU: keep the lines below commented out.
    # # 2. Multi-GPU data parallelism: uncomment them.
    # # If more than one GPU is available:
    # if len(device_ids) > 1:
    #     # DataParallel() implements data parallelism at the module level.
    #     self.model = torch.nn.DataParallel(model, device_ids=device_ids)

    # Loss criterion
    self.criterion = criterion
    # Metric functions, e.g.
    #   "metrics": [
    #       "binary_accuracy",
    #       "binary_f1",
    #       "binary_auc",
    #   ],
    self.metric_ftns = metric_ftns
    # Optimizer
    self.optimizer = optimizer

    # Trainer configuration
    cfg_trainer = config['trainer']
    # The epoch to start from.
    self.start_epoch = 1
    # Total number of epochs to train.
    self.epochs = cfg_trainer['epochs']
    # Checkpoint save period.
    self.save_period = cfg_trainer['save_period']
    # Monitored metric, e.g. "monitor": "min val_loss";
    # cfg_trainer.get('monitor', 'off') returns 'off' if the 'monitor' key is missing.
    self.monitor = cfg_trainer.get('monitor', 'off')

    # Configuration to monitor model performance and save the best model.
    if self.monitor == 'off':
        # Monitoring disabled.
        self.mnt_mode = 'off'
        # Current best value: 0.
        self.mnt_best = 0
    else:
        # Split "min val_loss" into mode "min" and metric "val_loss".
        self.mnt_mode, self.mnt_metric = self.monitor.split()
        # Make sure the mode is either 'min' or 'max'.
        assert self.mnt_mode in ['min', 'max']
        # Initialize the current best value: inf for 'min' mode, -inf for 'max' mode.
        self.mnt_best = inf if self.mnt_mode == 'min' else -inf

        # Early stopping: the number of epochs without improvement after which training stops,
        # e.g. "early_stop": 10; defaults to inf if 'early_stop' is not set.
        self.early_stop = cfg_trainer.get('early_stop', inf)

    # Setup visualization writer instance.
    self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])

    # Checkpoint save directory.
    self.checkpoint_dir = config.save_dir
    # print(self.checkpoint_dir)

    # If a resume path was given, resume from that checkpoint.
    if config.resume is not None:
        self._resume_checkpoint(config.resume)
def __init__(self, model, criterion, train_metric_ftns, eval_metric_ftns, optimizer, config,
             device, data_loader, valid_data_loader, lr_scheduler):
    """
    Initiates the Base trainer.

    :param model: The model to train.
    :param criterion: The loss function.
    :param train_metric_ftns: The metrics on which the model will be evaluated during training.
    :param eval_metric_ftns: The metrics on which the model will be evaluated during evaluation.
    :param optimizer: The optimizer to use for optimizing the parameters of the model.
    :param config: Configuration file.
    :param device: The device to use for computations.
    :param data_loader: Dataloader for the train dataset.
    :param valid_data_loader: Dataloader for the validation dataset.
    :param lr_scheduler: Scheduler for the learning rate.
    """
    self.config = config
    self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

    self.model = model
    self.criterion = criterion
    self.train_metric_ftns = train_metric_ftns
    self.eval_metric_ftns = eval_metric_ftns
    self.optimizer = optimizer
    self.device = device
    self.data_loader = data_loader
    self.valid_data_loader = valid_data_loader
    self.lr_scheduler = lr_scheduler

    # Get training configurations
    cfg_trainer = config['trainer']
    # Metrics to display for the best model.
    self.best_model_metrics_log = cfg_trainer['best_model_metrics_log'].split()
    self.epochs = cfg_trainer['epochs']
    # Once in how many epochs to save the model's parameters.
    self.save_period = cfg_trainer['save_period']
    # Metric which will be used to choose the best model.
    self.monitor = cfg_trainer.get('monitor', 'off')

    # Configuration to monitor model performance and save the best model.
    if self.monitor == 'off':
        self.mnt_mode = 'off'
        self.mnt_best = 0
    else:
        self.mnt_mode, self.mnt_metric = self.monitor.split()
        assert self.mnt_mode in ['min', 'max'], "Invalid monitor mode, should be min or max"

        self.mnt_best = inf if self.mnt_mode == 'min' else -inf
        self.early_stop = cfg_trainer.get('early_stop', inf)
        if self.early_stop <= 0:
            self.early_stop = inf

    # Dictionary to keep the metric results of the best model.
    self.model_best_metrics = {}

    # The epoch to start working from.
    self.start_epoch = 1

    self.checkpoint_dir = config.save_dir
    print(self.checkpoint_dir)

    # Setup visualization writer instance.
    self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])

    # If a resume path is given, resume training from the checkpoint.
    if config.resume is not None:
        self._resume_checkpoint(config.resume)
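# A standalone sketch (not from the original code) of the monitoring pattern the
# constructors above share: a "min <metric>" / "max <metric>" string decides which
# direction counts as an improvement, and the best value starts at +inf or -inf accordingly.
from math import inf


def is_improvement(mnt_mode, mnt_best, current):
    """Return True if `current` improves on `mnt_best` under the given mode."""
    if mnt_mode == 'min':
        return current < mnt_best
    if mnt_mode == 'max':
        return current > mnt_best
    return False  # mode 'off': never report an improvement


# Example: "min val_loss" starts from inf, so any finite loss counts as an improvement.
mode, metric = "min val_loss".split()
best = inf if mode == 'min' else -inf
assert is_improvement(mode, best, 0.42)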