Example no. 1
    def __init__(self,
                 model,
                 criterion,
                 metrics_name,
                 optimizer,
                 train_loader,
                 logger,
                 log_dir,
                 nb_epochs,
                 save_dir,
                 device="cuda:0",
                 log_step=10,
                 start_epoch=0,
                 enable_tensorboard=True,
                 valid_loader=None,
                 lr_scheduler=None,
                 monitor="min val_loss",
                 early_stop=10,
                 save_epoch_period=1,
                 resume=""):
        self.model = model
        self.criterion = criterion
        self.metrics_name = metrics_name
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader

        self.len_epoch = len(self.train_loader)
        self.do_validation = (self.valid_loader is not None)
        self.lr_scheduler = lr_scheduler
        self.log_step = log_step
        self.epochs = nb_epochs
        self.start_epoch = start_epoch + 1

        self.logger = logger
        self.device = device
        self.save_period = save_epoch_period

        self.writer = TensorboardWriter(log_dir, self.logger,
                                        enable_tensorboard)
        self.train_metrics = MetricTracker('loss',
                                           *self.metrics_name,
                                           writer=self.writer)
        self.valid_metrics = MetricTracker('loss',
                                           *self.metrics_name,
                                           writer=self.writer)
        self.checkpoint_dir = save_dir
        if monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = monitor.split()
            assert self.mnt_mode in ['min', 'max']
            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = early_stop
        if resume != "":
            self._resume_checkpoint(resume_path=resume)
        self.model.to(self.device)
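For clarity, here is a minimal, self-contained sketch (hypothetical values, not code from the example above) of how a monitor string such as "min val_loss" is split into a mode and a metric name, and how the initial best value is chosen:

from math import inf

monitor = "min val_loss"             # hypothetical config value
mnt_mode, mnt_metric = monitor.split()
assert mnt_mode in ['min', 'max']
mnt_best = inf if mnt_mode == 'min' else -inf

def improved(value, best, mode):
    # True when `value` beats the current best under the given mode
    return value < best if mode == 'min' else value > best

print(improved(0.42, mnt_best, mnt_mode))  # True: any finite loss improves on inf

With mode 'min' the best value starts at +inf, so the first logged metric always counts as an improvement; with 'max' it starts at -inf for the same reason.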
Example no. 2
    def __init__(self, model, criterion, metric_ftns, config, device, data_loader, evaluation=True):
        """
        Initializes the base tester.
        :param model:       The model to test.
        :param criterion:   The loss function.
        :param metric_ftns: The metrics on which the model will be evaluated during test time.
        :param config:      Configuration file.
        :param device:      The device to use for the computations.
        :param data_loader: Dataloader for the dataset.
        :param evaluation:  True if the tester is used as an evaluator during training, False if it is used for testing the model.
        """
        self.config = config
        self.logger = config.get_logger('tester', config['tester']['verbosity'])
        self.predictions_file_name = config.get_predictions_file_name()

        self.model = model
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.device = device
        self.data_loader = data_loader
        self.evaluation = evaluation

        # Get testing configurations
        cfg_tester = config['tester']

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_tester['tensorboard'])
Example no. 3
    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])

        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
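_prepare_device is not shown in this example; a plausible sketch of such a helper, inferred from the call above and written here only as an assumption (not the example's actual implementation), could be:

import torch

def _prepare_device(n_gpu_use):
    # Pick the training device and the GPU ids to hand to DataParallel (sketch)
    n_gpu = torch.cuda.device_count()
    if n_gpu_use > 0 and n_gpu == 0:
        n_gpu_use = 0      # no GPU available, fall back to CPU
    if n_gpu_use > n_gpu:
        n_gpu_use = n_gpu  # cap at the number of visible GPUs
    device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
    list_ids = list(range(n_gpu_use))
    return device, list_ids

With a helper like this, moving the model with model.to(self.device) rather than an unconditional .cuda() keeps CPU-only runs working.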
Example no. 4
	def __init__(self, model, criterion, metric_ftns, optimizer, config):
		self.config = config
		self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

		self.model = model
		self.criterion = criterion
		self.metric_ftns = metric_ftns
		self.optimizer = optimizer

		cfg_trainer = config['trainer']
		self.epochs = cfg_trainer['epochs']
		self.save_period = cfg_trainer['save_period']
		self.monitor = cfg_trainer.get('monitor', 'off')
		self.save = cfg_trainer['save']

		# configuration to monitor model performance and save best
		if self.monitor == 'off':
			self.mnt_mode = 'off'
			self.mnt_best = 0
		else:
			self.mnt_mode, self.mnt_metric = self.monitor.split()
			assert self.mnt_mode in ['min', 'max']

			self.mnt_best = inf if self.mnt_mode == 'min' else -inf
			self.early_stop = cfg_trainer.get('early_stop', inf)
			if self.early_stop <= 0:
				self.early_stop = inf

		self.start_epoch = 1

		self.checkpoint_dir = config.save_dir
		self.tensorboard = cfg_trainer['tensorboard']

		# setup visualization writer instance                
		self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])

		if config.resume is not None:
			self._resume_checkpoint(config.resume)
Example no. 5
    def __init__(self, t_c, m_c):

        if "t_c" not in vars(self):
            self.t_c = edict()

            self.t_c.want_log = t_c.want_log
            self.t_c.use_early_stopping = t_c.use_early_stopping

            self.t_c.img_size = t_c.img_size

            self.t_c.save = t_c.save
            self.t_c.save_every = t_c.save_every
            self.t_c.save_ext = t_c.save_ext

            self.t_c.load = t_c.load
            if self.t_c.load:
                self.t_c.load_config = t_c.load_config
                self.t_c.load_model = t_c.load_model
                self.t_c.load_netD = t_c.load_netD
                self.t_c.load_netG = t_c.load_netG
                self.t_c.load_optimD = t_c.load_optimD
                self.t_c.load_optimG = t_c.load_optimG

            self.t_c.test = t_c.test
            self.t_c.test_every = t_c.test_every
            self.t_c.sample_size = t_c.sample_size

            self.t_c.batch_size = t_c.batch_size
            self.t_c.shuffle = t_c.shuffle
            self.t_c.num_workers = t_c.num_workers
            self.t_c.epochs = t_c.epochs

            self.t_c.summary_dir = t_c.summary_dir
            self.t_c.checkpoint_dir = t_c.checkpoint_dir
            self.t_c.log_dir = t_c.log_dir
            self.t_c.out_dir = t_c.out_dir

            self.t_c.data_roots = t_c.data_roots

        if "m_c" not in vars(self):
            self.m_c = edict(m_c)

        self.summary_writer = TensorboardWriter(self.t_c.summary_dir)

        # self._stop_training = False
        if self.t_c.use_early_stopping:
            self.early_stopper_D = EarlyStopping2(patience=20,
                                                  low_threshold=0.009,
                                                  up_threshold=0.99,
                                                  verbose=False)
            self.early_stopper_G = EarlyStopping2(patience=20,
                                                  low_threshold=0.009,
                                                  up_threshold=0.99,
                                                  verbose=False)

        self.init_model()

        if self.t_c.load:
            self.model.load(path=self.t_c.load_model,
                            load_config=self.t_c.load_config,
                            load_netD=self.t_c.load_netD,
                            load_netG=self.t_c.load_netG,
                            load_optimD=self.t_c.load_optimD,
                            load_optimG=self.t_c.load_optimG)

            print("NEW MODEL LOADED CONFIG")
            pprint(self.model.config)

        self.fixed_noise = self.model.generate_fixed_noise(sample_size=32)
Example no. 6
class BaseGANTrainer():
    """
    .
    """
    def __init__(self, t_c, m_c):

        if "t_c" not in vars(self):
            self.t_c = edict()

            self.t_c.want_log = t_c.want_log
            self.t_c.use_early_stopping = t_c.use_early_stopping

            self.t_c.img_size = t_c.img_size

            self.t_c.save = t_c.save
            self.t_c.save_every = t_c.save_every
            self.t_c.save_ext = t_c.save_ext

            self.t_c.load = t_c.load
            if self.t_c.load:
                self.t_c.load_config = t_c.load_config
                self.t_c.load_model = t_c.load_model
                self.t_c.load_netD = t_c.load_netD
                self.t_c.load_netG = t_c.load_netG
                self.t_c.load_optimD = t_c.load_optimD
                self.t_c.load_optimG = t_c.load_optimG

            self.t_c.test = t_c.test
            self.t_c.test_every = t_c.test_every
            self.t_c.sample_size = t_c.sample_size

            self.t_c.batch_size = t_c.batch_size
            self.t_c.shuffle = t_c.shuffle
            self.t_c.num_workers = t_c.num_workers
            self.t_c.epochs = t_c.epochs

            self.t_c.summary_dir = t_c.summary_dir
            self.t_c.checkpoint_dir = t_c.checkpoint_dir
            self.t_c.log_dir = t_c.log_dir
            self.t_c.out_dir = t_c.out_dir

            self.t_c.data_roots = t_c.data_roots

        if "m_c" not in vars(self):
            self.m_c = edict(m_c)

        self.summary_writer = TensorboardWriter(self.t_c.summary_dir)

        # self._stop_training = False
        if self.t_c.use_early_stopping:
            self.early_stopper_D = EarlyStopping2(patience=20,
                                                  low_threshold=0.009,
                                                  up_threshold=0.99,
                                                  verbose=False)
            self.early_stopper_G = EarlyStopping2(patience=20,
                                                  low_threshold=0.009,
                                                  up_threshold=0.99,
                                                  verbose=False)

        self.init_model()

        if self.t_c.load:
            self.model.load(path=self.t_c.load_model,
                            load_config=self.t_c.load_config,
                            load_netD=self.t_c.load_netD,
                            load_netG=self.t_c.load_netG,
                            load_optimD=self.t_c.load_optimD,
                            load_optimG=self.t_c.load_optimG)

            print("NEW MODEL LOADED CONFIG")
            pprint(self.model.config)

        self.fixed_noise = self.model.generate_fixed_noise(sample_size=32)

    def init_model(self):
        raise NotImplementedError

    def run(self):
        """
        The main operator
        """
        try:
            self.train()
        except KeyboardInterrupt:
            print("")
            print(70 * "-")
            print("You have entered CTRL+C... Wait to finalize")
            # Prompt user if he wants to save the model params
            answer = yes_or_no("What to save model parameters before quiting?")
            if answer:
                self.model.save(self.t_c.checkpoint_dir, self.t_c.save_ext)

            exit(-1)

    def train(self):
        self.start = self.model.epochs_trained
        self.end = self.t_c.epochs + self.start

        for epoch in range(self.start, self.end):
            dataloader = self._get_dataloader()
            self._train_one_epoch(dataloader, epoch)

    def _train_one_epoch(self, dataloader, epoch):

        all_batches = len(dataloader)

        self.model.reset_meters()

        # For each batch in the dataloader
        for batch_num, batch_data in enumerate(dataloader):

            errD, errG, D_x, D_G_z1, D_G_z2 = self.model._train_step(
                batch_data)
            # Log batch stats into terminal
            self._log_train_step_stats(epoch, self.end, batch_num, all_batches,
                                       errD, errG, D_x, D_G_z1, D_G_z2)

            if self.t_c.use_early_stopping:
                self._stop_training = self.early_stopper_D.feed(
                    D_x) or self.early_stopper_G.feed(D_G_z2)
                if self._stop_training:
                    exit(-1)

        # Save model
        if self.t_c.save and epoch % self.t_c.save_every == 0:
            self.model.save(self.t_c.checkpoint_dir, self.t_c.save_ext)

        # Test model
        if self.t_c.test and epoch % self.t_c.test_every == 0:
            fake_samples = self.model.generate_images(
                sample_size=self.t_c.sample_size)
            fixed_samples = self.model.generater_fixed_images(self.fixed_noise)
            self.summary_writer.image_summary(f"Fake", fake_samples, epoch)
            self.summary_writer.image_summary(f"FixedNoise", fixed_samples,
                                              epoch)

        if self.t_c.want_log:
            d_mean, d_std = self.model.meterD.value()
            g_mean, g_std = self.model.meterG.value()
            self.summary_writer.plot_losses("LossesMeans", "D", "G", d_mean,
                                            g_mean, epoch)
            self.summary_writer.plot_losses("LossesStds", "D", "G", d_std,
                                            g_std, epoch)

        self.model.epochs_trained += 1

    def _get_dataloader(self):

        dataset = self._get_dataset()
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.t_c.batch_size,
            shuffle=self.t_c.shuffle,
            num_workers=self.t_c.num_workers)

        return dataloader

    def _get_dataset(self):
        dataset = ArtDataset(self.t_c.data_roots,
                             transforms_=[
                                 transforms.Resize(self.t_c.img_size),
                                 transforms.CenterCrop(self.t_c.img_size),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5),
                                                      (0.5, 0.5, 0.5))
                             ])

        return dataset

    def _log_train_step_stats(self, epoch_num, all_epochs, batch_num,
                              all_batches, errD, errG, D_x, D_G_z1, D_G_z2):

        # Output training stats
        print(
            '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
            % (epoch_num, all_epochs, batch_num, all_batches, errD, errG, D_x,
               D_G_z1, D_G_z2))
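EarlyStopping2 is not defined in these examples. Judging only from how it is used above (a constructor taking patience, low_threshold, up_threshold and verbose, plus a feed() method that returns a stop flag), a minimal sketch of a threshold-based stopper with a compatible interface might look as follows; this is an assumption, not the original implementation:

class EarlyStopping2:
    # Signals a stop once the fed value has stayed outside
    # [low_threshold, up_threshold] for `patience` consecutive calls (assumed behaviour).
    def __init__(self, patience=20, low_threshold=0.009, up_threshold=0.99, verbose=False):
        self.patience = patience
        self.low_threshold = low_threshold
        self.up_threshold = up_threshold
        self.verbose = verbose
        self.counter = 0

    def feed(self, value):
        if value < self.low_threshold or value > self.up_threshold:
            self.counter += 1
            if self.verbose:
                print("out-of-range value {:.4f} ({}/{})".format(value, self.counter, self.patience))
        else:
            self.counter = 0
        return self.counter >= self.patience

In the trainer above it is fed D(x) and D(G(z)), so values collapsing towards 0 or 1 (a sign of a collapsed discriminator or generator) would eventually trigger the stop.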
Example no. 7
class Trainer():
    def __init__(self,
                 model,
                 criterion,
                 metrics_name,
                 optimizer,
                 train_loader,
                 logger,
                 log_dir,
                 nb_epochs,
                 save_dir,
                 device="cuda:0",
                 log_step=10,
                 start_epoch=0,
                 enable_tensorboard=True,
                 valid_loader=None,
                 lr_scheduler=None,
                 monitor="min val_loss",
                 early_stop=10,
                 save_epoch_period=1,
                 resume=""):
        self.model = model
        self.criterion = criterion
        self.metrics_name = metrics_name
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader

        self.len_epoch = len(self.train_loader)
        self.do_validation = (self.valid_loader is not None)
        self.lr_scheduler = lr_scheduler
        self.log_step = log_step
        self.epochs = nb_epochs
        self.start_epoch = start_epoch + 1

        self.logger = logger
        self.device = device
        self.save_period = save_epoch_period

        self.writer = TensorboardWriter(log_dir, self.logger,
                                        enable_tensorboard)
        self.train_metrics = MetricTracker('loss',
                                           *self.metrics_name,
                                           writer=self.writer)
        self.valid_metrics = MetricTracker('loss',
                                           *self.metrics_name,
                                           writer=self.writer)
        self.checkpoint_dir = save_dir
        if monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = monitor.split()
            assert self.mnt_mode in ['min', 'max']
            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = early_stop
        if resume != "":
            self._resume_checkpoint(resume_path=resume)
        self.model.to(self.device)

    def train(self):
        not_improved_count = 0

        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)
            log = {'epoch': epoch}
            log.update(result)
            self.logger.info('    {:15s}: {}'.format(str("mnt best"),
                                                     self.mnt_best))
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] < self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] > self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(
                            self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False
                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1
                if (not_improved_count > self.early_stop) and (self.early_stop
                                                               > 0):
                    self.logger.info(
                        "Validation performance didn\'t improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, best)

    def _train_epoch(self, epoch):
        self.model.train()
        self.train_metrics.reset()
        start_time = time.time()

        for batch_idx, sample in enumerate(self.train_loader):
            data = sample['image']
            target = sample['mask']
            data, target = data.to(self.device), target.to(self.device)
            current_lr = self.lr_scheduler(self.optimizer, batch_idx, epoch)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met_name in self.metrics_name:
                self.train_metrics.update(
                    met_name,
                    getattr(metrics, met_name)(output, target))
            if batch_idx % self.log_step == 0:
                time_to_run = time.time() - start_time
                start_time = time.time()
                speed = self.log_step / time_to_run
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f} LR: {:.6f}  Speed: {:.4f}iters/s' \
                                  .format(epoch, self._progress(batch_idx), loss.item(), current_lr, speed))
                for met_name in self.metrics_name:
                    self.writer.add_scalar(met_name,
                                           self.train_metrics.avg(met_name))
                self.writer.add_scalar('loss', self.train_metrics.avg('loss'))
                self.writer.add_scalar("lr", current_lr)
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
            assert batch_idx <= self.len_epoch
        log = self.train_metrics.result()
        if self.do_validation:
            print("Start validation")
            val_log, iou_classes = self._valid_epoch(epoch)

            log.update(**{'val_' + k: v for k, v in val_log.items()})
            for key, value in iou_classes.items():
                log.update({key: value})
        return log

    def _valid_epoch(self, epoch):
        self.model.eval()
        self.valid_metrics.reset()
        iou_tracker = metrics.IoU(2)
        with torch.no_grad():
            for batch_idx, sample in enumerate(self.valid_loader):
                data = sample['image']
                target = sample['mask']
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_loader) + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
                target = target.cpu().numpy()
                output = output[:, 0]
                output = output.data.cpu().numpy()
                pred = np.zeros_like(output)
                pred[output > 0.5] = 1
                pred = pred.astype(np.int64)
                for i in range(len(target)):
                    iou_tracker.add_batch(target[i], pred[i])
        iou_classes = iou_tracker.get_iou()
        for key, value in iou_classes.items():
            self.writer.add_scalar(key, value)
        self.writer.add_scalar('val_loss', self.valid_metrics.avg('loss'))

        for met_name in self.metrics_name:
            self.writer.add_scalar(met_name, self.valid_metrics.avg(met_name))

        # for name, p in self.model.named_parameters():
        #     print(name, p)
        #     self.writer.add_histogram(name, p.cpu().data.numpy(), bins='auto')
        #
        return self.valid_metrics.result(), iou_classes

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        current = batch_idx
        total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param log: logging information of the epoch
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            # 'config': self.config
        }
        filename = str(self.checkpoint_dir /
                       'checkpoint-epoch{:06d}.pth'.format(epoch))
        torch.save(state, filename)
        self.delete_checkpoint()
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def delete_checkpoint(self):
        checkpoints_file = list(
            self.checkpoint_dir.glob("checkpoint-epoch*.pth"))
        checkpoints_file.sort()
        for checkpoint_file in checkpoints_file[:-5]:
            os.remove(str(checkpoint_file.absolute()))

    def _resume_checkpoint(self, resume_path):
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(
                self.start_epoch))
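MetricTracker is imported from elsewhere and not shown. Based purely on the calls made in this example (construction with metric names and a writer keyword, then reset/update/avg/result), a minimal compatible sketch, offered as an assumption rather than the original class, could be:

class MetricTracker:
    # Keeps running averages for a fixed set of keys; optionally forwards
    # every update to a TensorBoard-style writer (assumed behaviour).
    def __init__(self, *keys, writer=None):
        self.writer = writer
        self._totals = {k: 0.0 for k in keys}
        self._counts = {k: 0 for k in keys}

    def reset(self):
        for k in self._totals:
            self._totals[k] = 0.0
            self._counts[k] = 0

    def update(self, key, value, n=1):
        if self.writer is not None:
            self.writer.add_scalar(key, value)
        self._totals[key] += value * n
        self._counts[key] += n

    def avg(self, key):
        return self._totals[key] / max(self._counts[key], 1)

    def result(self):
        return {k: self.avg(k) for k in self._totals}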
Example no. 8
    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        """

        :param model: 模型
        :param criterion: 损失标准
        :param metric_ftns: 度量工具函数(评价指标)
        :param optimizer: 优化器
        :param config: 配置
        """

        # Configuration
        self.config = config
        # logger
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        # Prepare the compute device; returns self.device and the list of GPU ids
        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        # Move the model onto the device
        self.model = model.to(self.device)

        # # 1. Single GPU: leave the block below commented out; 2. Multi-GPU: uncomment it.
        # # If more than one GPU is available
        # if len(device_ids) > 1:
        #     # Enable parallel computation
        #     # DataParallel() implements data parallelism at the module level.
        #     self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        # Loss criterion
        self.criterion = criterion
        # Metric functions
        """
        For example:
        "metrics": [
            "binary_accuracy",
            "binary_f1",
            "binary_auc",
        ],
        """
        self.metric_ftns = metric_ftns
        # Optimizer
        self.optimizer = optimizer

        # Trainer configuration
        cfg_trainer = config['trainer']
        # Epoch to start from
        self.start_epoch = 1
        # Total number of epochs to train in this run
        self.epochs = cfg_trainer['epochs']
        # Checkpoint save period
        self.save_period = cfg_trainer['save_period']

        # Monitored metric
        """
        "monitor": "min val_loss"
        cfg_trainer.get('monitor', 'off'): if there is no 'monitor' key, it returns 'off'
        """
        self.monitor = cfg_trainer.get('monitor', 'off')

        # Configure the monitored metric used to keep the best model
        # configuration to monitor model performance and save best

        # If monitoring is turned off
        if self.monitor == 'off':
            # Monitor mode: off
            self.mnt_mode = 'off'
            # Current best value: 0
            self.mnt_best = 0

        # If monitoring is enabled
        else:
            # Mode "min",
            # metric "val_loss",
            # e.g. "monitor": "min val_loss"
            self.mnt_mode, self.mnt_metric = self.monitor.split()

            # Make sure the mode is either 'min' or 'max'
            assert self.mnt_mode in ['min', 'max']

            # Initialize the current best value
            """
            If the mode is 'min', start from inf (positive infinity);
            if the mode is 'max', start from -inf.
            """
            self.mnt_best = inf if self.mnt_mode == 'min' else -inf

            # Early stopping
            """
            Number of epochs without improvement after which training stops, e.g.
            "early_stop": 10,
            Defaults to inf if 'early_stop' is not set.
            """
            self.early_stop = cfg_trainer.get('early_stop', inf)

        # Visualization writer instance
        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])

        # Directory where model checkpoints are saved
        self.checkpoint_dir = config.save_dir
        # print(self.checkpoint_dir)

        # If a resume path is given
        if config.resume is not None:
            # Resume from the checkpoint
            self._resume_checkpoint(config.resume)
Example no. 9
    def __init__(self, model, criterion, train_metric_ftns, eval_metric_ftns,
                 optimizer, config, device, data_loader, valid_data_loader,
                 lr_scheduler):
        """
        Initializes the base trainer.
        :param model:               The model to train.
        :param criterion:           The loss function.
        :param train_metric_ftns:   The metrics on which the model will be evaluated during training.
        :param eval_metric_ftns:    The metrics on which the model will be evaluated during evaluation.
        :param optimizer:           The optimizer to use for optimizing the parameters of the model.
        :param config:              Configuration file.
        :param device:              The device to use for computations.
        :param data_loader:         Dataloader for the train dataset.
        :param valid_data_loader:   Dataloader for the validation dataset.
        :param lr_scheduler:        Scheduler for the learning rate.
        """
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        self.model = model
        self.criterion = criterion
        self.train_metric_ftns = train_metric_ftns
        self.eval_metric_ftns = eval_metric_ftns
        self.optimizer = optimizer
        self.device = device
        self.data_loader = data_loader
        self.valid_data_loader = valid_data_loader
        self.lr_scheduler = lr_scheduler

        # Get training configurations
        cfg_trainer = config['trainer']

        # Metrics to display for the best model.
        self.best_model_metrics_log = cfg_trainer[
            'best_model_metrics_log'].split()
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer[
            'save_period']  # Once in how many epochs to save the models parameters.

        # Metric which will be used to choose the best model
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best model.
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in [
                'min', 'max'
            ], "Invalid monitor mode, should be min or max"

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
            if self.early_stop <= 0:
                self.early_stop = inf

        # Dictionary to keep the metrics results of the best model.
        self.model_best_metrics = {}

        # The epoch to start working from.
        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir
        print(self.checkpoint_dir)

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])

        # If resume path is given, resume training from checkpoint.
        if config.resume is not None:
            self._resume_checkpoint(config.resume)
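None of the examples show the metric functions themselves. The names listed in Example no. 8 ("binary_accuracy", "binary_f1", "binary_auc") and calls of the form metric(output, target) elsewhere suggest plain functions taking model outputs and targets. A minimal sketch of one such metric, assuming that signature (hypothetical, not taken from the original projects):

import torch

def binary_accuracy(output, target, threshold=0.5):
    # Fraction of predictions that match the binary target (assumed signature)
    with torch.no_grad():
        preds = (torch.sigmoid(output) >= threshold).long()
        correct = (preds == target.long()).sum().item()
    return correct / target.numel()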