Example #1
import datetime
import logging
import math
import os

import torch
from torch.utils import tensorboard

# Project-specific helpers assumed to be importable from the repository:
# check_dir, get_instance, Logger, and the custom lr_scheduler module.


class BaseTrainer:
    def __init__(self,
                 model,
                 loss,
                 resume,
                 config,
                 train_loader,
                 val_loader=None,
                 train_logger=None):
        self.model = model
        self.loss = loss
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_logger = train_logger
        self.logger = logging.getLogger(self.__class__.__name__)
        self.do_validation = self.config['val']['is_val']
        self.start_epoch = 1
        self.improved = False
        self.MEAN = config['dataset']['mean']
        self.STD = config['dataset']['std']

        # SETTING THE DEVICE
        self.device, available_gpus = self._get_available_devices(
            config['train']['n_gpu'])
        if len(available_gpus) > 1:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=available_gpus)
            self.loss = torch.nn.DataParallel(self.loss,
                                              device_ids=available_gpus)
        self.model.to(self.device)
        self.loss.to(self.device)

        # CONFIG
        self.epochs = self.config['train']['epochs']
        self.save_period = self.config['train']['save_period']

        writer_base_dir = config['train']['writer_dir']
        check_dir(writer_base_dir)
        writer_dir = os.path.join(
            writer_base_dir, config['name'] + "_" + config['dataset']['name'])
        check_dir(writer_dir)
        self.writer = tensorboard.SummaryWriter(str(writer_dir))

        # OPTIMIZER
        if self.config['optimizer']['differential_lr']:
            if isinstance(self.model, torch.nn.DataParallel):
                trainable_params = [{
                    'params':
                    filter(lambda p: p.requires_grad,
                           self.model.module.get_decoder_params())
                }, {
                    'params':
                    filter(lambda p: p.requires_grad,
                           self.model.module.get_backbone_params()),
                    'lr':
                    config['optimizer']['args']['lr'] / 10
                }]
            else:
                trainable_params = [{
                    'params':
                    filter(lambda p: p.requires_grad,
                           self.model.get_decoder_params())
                }, {
                    'params':
                    filter(lambda p: p.requires_grad,
                           self.model.get_backbone_params()),
                    'lr':
                    config['optimizer']['args']['lr'] / 10
                }]
        else:
            trainable_params = filter(lambda p: p.requires_grad,
                                      self.model.parameters())

        self.optimizer = get_instance(torch.optim, 'optimizer', config,
                                      trainable_params)
        self.lr_scheduler = getattr(
            lr_scheduler, config['lr_scheduler']['name'])(self.optimizer,
                                                          self.epochs,
                                                          len(train_loader))

        # monitor
        self.monitor = self.config['train']['monitor']['is_monitor']
        if not self.monitor:
            self.mnt_mode = None
            self.mnt_best = 0
        else:
            self.mnt_mode = self.config['train']['monitor']['type']
            self.mnt_metric = self.config['train']['monitor']['metric']
            self.mnt_best = -math.inf if self.mnt_mode == 'max' else math.inf
            self.early_stopping = self.config['train']['monitor']['early_stop']
        self.not_improved_count = 0

        # CHECKPOINTS
        start_time = datetime.datetime.now().strftime('%m-%d_%H-%M')
        checkpoint_base_dir = config['train']['save_dir']
        check_dir(checkpoint_base_dir)
        self.checkpoint_dir = os.path.join(
            checkpoint_base_dir, config['name'] + "_" +
            config['dataset']['name'] + "_" + start_time)
        check_dir(self.checkpoint_dir)

        # Resume training from a checkpoint if a path was provided; the
        # `resume` argument was otherwise unused in the original snippet.
        if resume:
            self._resume_checkpoint(resume)

    def _get_available_devices(self, n_gpu):
        sys_gpu = torch.cuda.device_count()
        if sys_gpu == 0:
            self.logger.warning('No GPUs detected, using the CPU')
            n_gpu = 0
        elif n_gpu > sys_gpu:
            self.logger.warning(
                'Nbr of GPU requested is {} but only {} are available'.format(
                    n_gpu, sys_gpu))
            n_gpu = sys_gpu

        device = torch.device('cuda:0' if n_gpu > 0 else 'cpu')
        self.logger.info('Detected GPUs: {} Requested: {}'.format(
            sys_gpu, n_gpu))
        available_gpus = list(range(n_gpu))
        return device, available_gpus

    def _save_checkpoint(self, epoch, save_best=False):
        state = {
            'arch': type(self.model).__name__,
            'epoch': epoch,
            'logger': self.train_logger,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            'config': self.config
        }
        filename = str(
            self.checkpoint_dir) + "/checkpoint-epoch" + str(epoch) + ".pth"
        self.logger.info('\nSaving a checkpoint: {} ***'.format(filename))
        torch.save(state, filename)

        if save_best:
            filename = os.path.join(self.checkpoint_dir, "best_epoch.pth")
            self.logger.info("Saveing a best epoch")
            torch.save(state, str(filename))

    def _resume_checkpoint(self, resume_path):
        self.logger.info('Loading checkpoint : {}'.format(resume_path))
        checkpoint = torch.load(resume_path)

        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']
        self.not_improved_count = 0

        if checkpoint['config']['model'] != self.config['model']:
            self.logger.warning(
                'Warning! Current model is not the same as the one in the checkpoint'
            )
        self.model.load_state_dict(checkpoint['state_dict'])

        if checkpoint['config']['optimizer']['name'] != self.config[
                'optimizer']['name']:
            self.logger.warning(
                'Warning! Current optimizer is not the same as the one in the checkpoint'
            )
        self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.train_logger = checkpoint['logger']
        self.logger.info('Checkpoint <{}> (epoch {}) was loaded'.format(
            resume_path, self.start_epoch))

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            results = self._train_epoch(epoch)

            if self.do_validation and epoch % self.config['val'][
                    'val_per_epochs'] == 0:
                results = self._valid_epoch(epoch)
                self.logger.info('\n*** val {} epoch ***'.format(epoch))

                for k, v in results.items():
                    self.logger.info(' {}: {}'.format(k, v))

            if self.train_logger is None:
                self.train_logger = Logger()
            log = {'epoch': epoch, **results}
            self.train_logger.add_entry(log)

            if self.mnt_mode is not None and epoch % self.config['val'][
                    'val_per_epochs'] == 0:
                try:
                    if self.mnt_mode == 'min':
                        self.improved = (log[self.mnt_metric] < self.mnt_best)
                    else:
                        self.improved = (log[self.mnt_metric] > self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        'The metric being tracked ({}) has not been calculated. Training stops.'
                        .format(self.mnt_metric))
                    break

                if self.improved:
                    self.mnt_best = log[self.mnt_metric]
                    self.not_improved_count = 0
                else:
                    self.not_improved_count += 1

                if self.not_improved_count > self.early_stopping:
                    self.logger.info(
                        '\nPerformance didn\'t improve for {} epochs'.format(
                            self.early_stopping))
                    self.logger.warning('Training stopped')
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=self.improved)

    def _train_epoch(self, epoch):
        raise NotImplementedError

    def _valid_epoch(self, epoch):
        raise NotImplementedError

    def _eval_metrics(self, output, target):
        raise NotImplementedError
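
A minimal sketch of how this base class is meant to be subclassed, assuming a segmentation-style loss called as loss(output, target); the SegTrainer name and the loops below are illustrative only and not part of the original snippet:

class SegTrainer(BaseTrainer):
    def _train_epoch(self, epoch):
        self.model.train()
        total_loss = 0.0
        for images, targets in self.train_loader:
            images = images.to(self.device)
            targets = targets.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(images)
            # .mean() is a no-op for a scalar loss and also covers a
            # DataParallel-wrapped loss that returns one value per GPU
            loss = self.loss(output, targets).mean()
            loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()  # assumes a per-iteration scheduler
            total_loss += loss.item()
        return {'loss': total_loss / len(self.train_loader)}

    def _valid_epoch(self, epoch):
        self.model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for images, targets in self.val_loader:
                images = images.to(self.device)
                targets = targets.to(self.device)
                total_loss += self.loss(self.model(images), targets).mean().item()
        return {'val_loss': total_loss / len(self.val_loader)}

    def _eval_metrics(self, output, target):
        return {}

# trainer = SegTrainer(model, loss, resume=None, config=config,
#                      train_loader=train_loader, val_loader=val_loader)
# trainer.train()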
Example #2
import datetime
import json
import math
import os
from time import time

import numpy as np
import torch
from torch.nn.utils import clip_grad
from tqdm import tqdm

# Project-specific helpers assumed to be importable from the repository:
# Logger and WriterTensorboardX.


class BaseTrainer(object):
	def __init__(self, model, losses, metrics, optimizer, resume, config, data_loader,
				valid_data_loader=None, lr_scheduler=None, grad_clip=None):

		self.config = config
		self.data_loader = data_loader
		self.valid_data_loader = valid_data_loader
		self.do_validation = self.valid_data_loader is not None
		self.lr_scheduler = lr_scheduler
		self.lr_schedule_by_epoch = config['lr_scheduler']['by_epoch']
		self.grad_clip = grad_clip

		# Setup directory for checkpoint saving
		start_time = datetime.datetime.now().strftime('%m%d_%H%M%S')
		self.checkpoint_dir = os.path.join(config['trainer']['save_dir'], config['name'], start_time)
		os.makedirs(self.checkpoint_dir, exist_ok=True)

		# Build logger
		logname = self.config['name']
		logfile = os.path.join(self.checkpoint_dir, "logging.log")
		self.logger = Logger(logname, logfile)

		# Setup GPU device if available, move model into configured device
		self.device, device_ids = self._prepare_device(config['n_gpu'])
		self.model = model.to(self.device)
		if len(device_ids) > 1:
			self.model = torch.nn.DataParallel(model, device_ids=device_ids)

		self.losses = losses
		self.metrics = metrics
		self.optimizer = optimizer

		self.epochs = config['trainer']['epochs']
		self.save_freq = config['trainer']['save_freq']
		self.verbosity = config['trainer']['verbosity']
		self.logger.info("Total epochs: {}".format(self.epochs))

		# configuration to monitor model performance and save best
		self.monitor = config['trainer']['monitor']
		self.monitor_mode = config['trainer']['monitor_mode']
		assert self.monitor_mode in ['min', 'max', 'off']
		self.monitor_best = math.inf if self.monitor_mode == 'min' else -math.inf
		self.start_epoch = 1

		# setup visualization writer instance
		writer_train_dir = os.path.join(config['visualization']['log_dir'], config['name'], start_time, "train")
		writer_valid_dir = os.path.join(config['visualization']['log_dir'], config['name'], start_time, "valid")
		self.writer_train = WriterTensorboardX(writer_train_dir, self.logger, config['visualization']['tensorboardX'])
		self.writer_valid = WriterTensorboardX(writer_valid_dir, self.logger, config['visualization']['tensorboardX'])

		# Save configuration file into checkpoint directory
		config_save_path = os.path.join(self.checkpoint_dir, 'config.json')
		with open(config_save_path, 'w') as handle:
			json.dump(config, handle, indent=4, sort_keys=False)

		# Resume
		if resume:
			self._resume_checkpoint(resume)

	def train(self):
		for epoch in range(self.start_epoch, self.epochs + 1):
			print("----------------------------------------------------------------")
			self.logger.info("[EPOCH %d/%d]" % (epoch, self.epochs))
			start_time = time()
			result = self._train_epoch(epoch)
			finish_time = time()
			self.logger.info("Finish at {}, Runtime: {:.3f} [s]".format(datetime.datetime.now(), finish_time-start_time))

			# save logged informations into log dict
			log = {}
			for key, value in result.items():
				if key == 'train_metrics':
					log.update({'train_' + mtr.func.__name__ : value[i] for i, mtr in enumerate(self.metrics)})
				elif key == 'valid_metrics':
					log.update({'valid_' + mtr.func.__name__ : value[i] for i, mtr in enumerate(self.metrics)})
				else:
					log[key] = value

			# print logged informations to the screen
			if self.logger is not None:
				self.logger.add_entry(log)
				if self.verbosity >= 1:
					for key, value in sorted(list(log.items())):
						self.logger.info('{:15s}: {}'.format(str(key), value))

			# evaluate model performance according to configured metric, save best checkpoint as model_best
			best = False
			if self.monitor_mode != 'off':
				try:
					if  (self.monitor_mode == 'min' and log[self.monitor] < self.monitor_best) or\
						(self.monitor_mode == 'max' and log[self.monitor] > self.monitor_best):
						self.logger.info("Monitor improved from %f to %f" % (self.monitor_best, log[self.monitor]))
						self.monitor_best = log[self.monitor]
						best = True
				except KeyError:
					if epoch == 1:
						msg = "Warning: Can\'t recognize metric named '{}' ".format(self.monitor)\
							+ "for performance monitoring. model_best checkpoint won\'t be updated."
						self.logger.warning(msg)

			# Save checkpoint
			self._save_checkpoint(epoch, save_best=best)

	def _train_epoch(self, epoch):
		self.logger.info("Train on epoch...")
		self.model.train()
		self.writer_train.set_step(epoch)

		# Perform training
		total_loss = 0
		total_metrics = np.zeros(len(self.metrics))
		n_iter = len(self.data_loader)
		train_pbar = tqdm(enumerate(self.data_loader), total=n_iter)
		for batch_idx, data in train_pbar:
			# Send data to device
			for key, value in data.items():
				data[key] = value.to(self.device)

			# Forward and Backward
			output, losses, loss = self._forward(data)
			self._backward(loss)

			# Learning rate scheduler by iteration
			if self.lr_scheduler is not None and not self.lr_schedule_by_epoch:
				self.lr_scheduler.step()

			# Accumulate loss and metrics
			loss_iter = loss.item()
			total_loss += loss_iter
			metrics_iter = self._eval_metrics(output, data)
			total_metrics += metrics_iter

			# Visualize results
			if (batch_idx==n_iter-2) and (self.verbosity>=2):
				self._visualize_results(output, data)

			# tqdm progress bar
			if self.verbosity>=1:
				pbar_dict = self._get_progress_bar_dict(losses, loss_iter, metrics_iter)
				train_pbar.set_postfix(**pbar_dict)

		# Learning rate scheduler by epoch
		if self.lr_scheduler is not None and self.lr_schedule_by_epoch:
			self.lr_scheduler.step()

		# Record log
		total_loss /= len(self.data_loader)
		total_metrics /= len(self.data_loader)
		log = {'train_loss': total_loss, 'train_metrics': total_metrics.tolist()}

		# Write training result to TensorboardX
		self.writer_train.add_scalar('loss', total_loss)
		for i, metric in enumerate(self.metrics):
			self.writer_train.add_scalar('metrics/%s'%(metric.func.__name__), total_metrics[i])

		if self.verbosity>=2:
			for i in range(len(self.optimizer.param_groups)):
				self.writer_train.add_scalar('lr/group%d'%(i), self.optimizer.param_groups[i]['lr'])

		# Perform validating
		if self.do_validation:
			self.logger.info("Validate on epoch...")
			val_log = self._valid_epoch(epoch)
			log = {**log, **val_log}
		return log

	def _valid_epoch(self, epoch):
		self.model.eval()
		total_val_loss = 0
		total_val_metrics = np.zeros(len(self.metrics))
		n_iter = len(self.valid_data_loader)
		self.writer_valid.set_step(epoch)
		with torch.no_grad():
			# Validate
			for batch_idx, data in tqdm(enumerate(self.valid_data_loader), total=n_iter):
				# Send data to device
				for key, value in data.items():
					data[key] = value.to(self.device)

				# Forward
				output, _, loss = self._forward(data)

				# Accumulate loss and metrics
				total_val_loss += loss.item()
				total_val_metrics += self._eval_metrics(output, data)

				# Visualize results
				if (batch_idx==n_iter-2) and (self.verbosity>=2):
					self._visualize_results(output, data)

			# Record log
			total_val_loss /= len(self.valid_data_loader)
			total_val_metrics /= len(self.valid_data_loader)
			val_log = {
				'valid_loss': total_val_loss,
				'valid_metrics': total_val_metrics.tolist(),
			}

			# Write validating result to TensorboardX
			self.writer_valid.add_scalar('loss', total_val_loss)
			for i, metric in enumerate(self.metrics):
				self.writer_valid.add_scalar('metrics/%s'%(metric.func.__name__), total_val_metrics[i])

		return val_log

	def _forward(self, data):
		self.optimizer.zero_grad()
		output = self.model(**data)
		losses = self.losses(**output, **data)
		loss = self._sum_losses(losses)
		return output, losses, loss

	def _backward(self, loss):
		loss.backward()
		if self.grad_clip is not None:
			clip_grad.clip_grad_norm_(filter(lambda p: p.requires_grad, self.model.parameters()), **self.grad_clip)
		self.optimizer.step()

	def _eval_metrics(self, output, data):
		acc_metrics = np.zeros(len(self.metrics))
		for i, metric in enumerate(self.metrics):
			acc_metrics[i] += metric(**output, **data)
		return acc_metrics

	def _sum_losses(self, losses):
		return sum(losses.values())

	def _visualize_results(self, output, data):
		pass

	def _get_progress_bar_dict(self, losses, loss, metrics):
		pbar_dict = dict()
		pbar_dict['lr'] = self.optimizer.param_groups[0]['lr']
		for key, val in losses.items():
			pbar_dict['loss_%s'%(key)] = val.item()
		pbar_dict['loss'] = loss
		if self.verbosity>=3:
			for i, metric in enumerate(self.metrics):
				pbar_dict['%s'%(metric.func.__name__)] = metrics[i]
		return pbar_dict

	def _prepare_device(self, n_gpu_use):
		""" 
		setup GPU device if available, move model into configured device
		""" 
		n_gpu = torch.cuda.device_count()
		if n_gpu_use > 0 and n_gpu == 0:
			self.logger.warning("Warning: There\'s no GPU available on this machine, training will be performed on CPU.")
			n_gpu_use = 0
		if n_gpu_use > n_gpu:
			msg = "Warning: The number of GPU\'s configured to use is {}, but only {} are available on this machine.".format(n_gpu_use, n_gpu)
			self.logger.warning(msg)
			n_gpu_use = n_gpu
		device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
		list_ids = list(range(n_gpu_use))
		return device, list_ids

	def _save_checkpoint(self, epoch, save_best=False):
		"""
		Saving checkpoints

		:param epoch: current epoch number
		:param log: logging information of the epoch
		:param save_best: if True, rename the saved checkpoint to 'model_best.pth'
		"""
		# Construct savedict
		arch = type(self.model).__name__
		state = {
			'arch': arch,
			'epoch': epoch,
			'state_dict': self.model.state_dict(),
			'optimizer': self.optimizer.state_dict(),
			'monitor_best': self.monitor_best,
			'config': self.config
		}

		# Save checkpoint for each epoch
		if self.save_freq is not None:	# save_freq=None avoids filling disk space with large models
			if epoch % self.save_freq == 0:
				filename = os.path.join(self.checkpoint_dir, 'epoch{}.pth'.format(epoch))
				torch.save(state, filename)
				self.logger.info("Saving checkpoint at {}".format(filename))

		# Save the best checkpoint
		if save_best:
			best_path = os.path.join(self.checkpoint_dir, 'model_best.pth')
			torch.save(state, best_path)
			self.logger.info("Saving current best at {}".format(best_path))
		else:
			self.logger.info("Monitor is not improved from %f" % (self.monitor_best))

	def _resume_checkpoint(self, resume_path):
		"""
		Resume from saved checkpoints

		:param resume_path: Checkpoint path to be resumed
		"""
		self.logger.info("Loading checkpoint: {}".format(resume_path))
		checkpoint = torch.load(resume_path)
		self.start_epoch = checkpoint['epoch'] + 1
		self.monitor_best = checkpoint['monitor_best']

		# load architecture params from checkpoint.
		if checkpoint['config']['arch'] != self.config['arch']:
			self.logger.warning('Warning: Architecture configuration given in config file is different from that of checkpoint. ' + \
								'This may yield an exception while state_dict is being loaded.')
		self.model.load_state_dict(checkpoint['state_dict'], strict=True)

		# load optimizer state from checkpoint only when optimizer type is not changed. 
		if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
			self.logger.warning('Warning: Optimizer type given in config file is different from that of checkpoint. ' + \
								'Optimizer parameters not being resumed.')
		else:
			self.optimizer.load_state_dict(checkpoint['optimizer'])
	
		self.logger.info("Checkpoint '{}' (epoch {}) loaded".format(resume_path, self.start_epoch-1))