Example #1
0
    def __init__(self, model, metrics, optimizer, config, train_dataset, valid_datasets, lr_scheduler=None):
        super().__init__(model, metrics, optimizer, config, train_dataset)
        self.config = config
        self.config['data_loaders']['valid']['args']['batch_size'] = self.data_loader.batch_size
        self.valid_data_loaders = {}
        for corpus, valid_dataset in valid_datasets.items():
            self.valid_data_loaders[corpus] = config.init_obj('data_loaders.valid', module_loader, valid_dataset)
        self.lr_scheduler = lr_scheduler
        self.log_step = math.ceil(len(self.data_loader.dataset) / np.sqrt(self.data_loader.batch_size) / 200)
        self.writer = TensorboardWriter(config.log_dir)
Example #2
0
class BaseTrainer:
    """
    base class for all trainers
    """

    def __init__(self, config, data_loader, losses, transformation_module, registration_module, metrics):
        self.config = config
        self.logger = config.logger

        self.device = 'cuda:0'

        self.data_loader = data_loader
        self.structures_dict = self.config.structures_dict
        self.save_dirs = self.data_loader.save_dirs

        # losses
        self.losses = dict()

        self.losses['data'] = {k: loss.to(self.device) for k, loss in losses['data'].items()}
        self.losses['reg'] = {k: loss.to(self.device) for k, loss in losses['reg'].items()}
        self.losses['entropy'] = losses['entropy'].to(self.device)

        # transformation and registration modules
        self.transformation_module = transformation_module.to(self.device)
        self.registration_module = registration_module.to(self.device)
        
        # differential operator for use with the transformation Jacobian
        self.diff_op = self.losses['reg']['loss'].diff_op

        # model logic
        cfg_trainer = config['trainer']

        self.VI = cfg_trainer['VI']
        self.start_iter_VI, self.no_iters_VI = 1, int(cfg_trainer['no_iters_VI'])
        self.no_samples_VI_test = int(cfg_trainer['no_samples_VI_test'])
        self.log_period_VI = cfg_trainer['log_period_VI']

        self.MCMC = cfg_trainer['MCMC']
        self.MCMC_init = cfg_trainer['MCMC_init']  # NOTE (DG): one of ['VI', 'identity', 'noise']
        self.no_chains = int(cfg_trainer['no_chains'])
        self.no_samples_MCMC = int(cfg_trainer['no_samples_MCMC'])
        self.no_iters_burn_in = int(cfg_trainer['no_iters_burn_in'])
        self.log_period_MCMC = cfg_trainer['log_period_MCMC']

        # metrics and prints
        self.writer = TensorboardWriter(config.log_dir, cfg_trainer['tensorboard'])
        self.writer.write_hparams(config.config_str)
        self.metrics = MetricTracker(*[m for m in metrics], writer=self.writer)

    @abstractmethod
    def _run_model(self):
        raise NotImplementedError

    def run(self):
        self._run_model()
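# Illustrative sketch (not part of the example above): BaseTrainer.run simply delegates to
# the abstract _run_model, so a concrete trainer only needs to override that hook. The batch
# keys ('fixed', 'moving') and the loss handling below are hypothetical placeholders.
class ExampleRegistrationTrainer(BaseTrainer):
    def _run_model(self):
        for batch in self.data_loader:
            fixed = batch['fixed'].to(self.device)
            moving = batch['moving'].to(self.device)
            transformation = self.transformation_module(moving)        # predict a transformation
            warped = self.registration_module(moving, transformation)  # apply it to the moving image
            data_loss = sum(loss_fn(warped, fixed) for loss_fn in self.losses['data'].values())
            # optimisation, regularisation terms, and metric logging are omitted in this sketch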
Example #3
0
    def __init__(self, config, logger, generator, discriminator,
                 gif_generator):
        self.config = config
        self.logger = logger
        self.gif_generator = gif_generator

        # setup GPU device if available, move model into configured device
        self.device, self.device_ids = self._prepare_device(config['n_gpu'])

        self.generator = self.initialize_training(generator)
        self.discriminator = self.initialize_training(discriminator)

        self.generator['config'] = config['generator']
        self.discriminator['config'] = config['discriminator']

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.early_stop = cfg_trainer.get('early_stop', inf)
        self.start_epoch = 1
        self.n_critic = cfg_trainer["n_critic"]

        self.num_samples = cfg_trainer.get('num_samples', 0)
        self.z_size = config['generator']['arch']['args']['input_size']
        if cfg_trainer.get('fixed_image', False):
            self.default_z = Variable(torch.randn(self.num_samples,
                                                  self.z_size)).to(self.device)
        else:
            self.default_z = None

        self.generator['monitor'] = self.initialize_monitor(
            cfg_trainer['monitor'].get('generator', 'off'))
        self.discriminator['monitor'] = self.initialize_monitor(
            cfg_trainer['monitor'].get('discriminator', 'off'))

        self.generator['checkpoint_dir'] = config.save_dir / "generator"
        self.generator['checkpoint_dir'].mkdir(parents=True, exist_ok=True)
        self.discriminator[
            'checkpoint_dir'] = config.save_dir / "discriminator"
        self.discriminator['checkpoint_dir'].mkdir(parents=True, exist_ok=True)

        self.generator['writer'] = TensorboardWriter(
            config.log_dir / 'generator', self.logger,
            cfg_trainer['tensorboard'])
        self.discriminator['writer'] = TensorboardWriter(
            config.log_dir / 'discriminator', self.logger,
            cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint('generator', config.resume['generator'])
            self._resume_checkpoint('discriminator',
                                    config.resume['discriminator'])
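# Illustrative sketch (not part of the example above): a hypothetical config['trainer']
# fragment matching the lookups in the GAN trainer above. Note that 'monitor' is a
# per-network mapping here, unlike the single "min val_loss" string used by the plain trainers.
cfg_trainer_example = {
    'epochs': 200,
    'save_period': 10,
    'early_stop': 50,
    'n_critic': 5,            # discriminator updates per generator update
    'num_samples': 16,        # latent samples drawn for visualization
    'fixed_image': True,      # reuse one fixed latent batch (default_z) across epochs
    'monitor': {'generator': 'min loss_g', 'discriminator': 'off'},
    'tensorboard': True,
}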
Example #4
0
    def __init__(self, model, criterion: TrackingLoss, optimizer, config):
        self.config = config
        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
            if self.early_stop <= 0:
                self.early_stop = inf

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance                
        self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
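# Illustrative sketch (not part of the example above): the 'monitor' entry used by these
# trainers is a single "<mode> <metric>" string, e.g. "min val_loss"; the values below are
# hypothetical and only show how the string is parsed into mnt_mode / mnt_metric / mnt_best.
from math import inf

monitor = 'min val_loss'
mnt_mode, mnt_metric = monitor.split()
assert mnt_mode in ['min', 'max']
mnt_best = inf if mnt_mode == 'min' else -inf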
Example #5
0
    def __init__(self, config, disc_model, disc_loss, disc_optimizer,
                 gen_model, gen_loss, gen_optimizer):

        self.config = config
        self.logger = config.get_logger('trainer', config['trainer']['level'])
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.epochs = config['trainer']['epochs']
        self.disc_gen_ratio = config['trainer']['disc_gen_ratio']

        self.save_period = config['trainer']['save_period']
        self.checkpoint_dir = config.save_dir
        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        config['trainer']['tensorboard'])

        BaseDiscriminator.__init__(self, disc_model, disc_loss, disc_optimizer)
        BaseGenerator.__init__(self, gen_model, gen_loss, gen_optimizer)

        self.disc_model = self.disc_model.to(self.device)
        self.gen_model = self.gen_model.to(self.device)
        if len(device_ids) > 1:
            self.disc_model = torch.nn.DataParallel(self.disc_model,
                                                    device_ids=device_ids)
            self.gen_model = torch.nn.DataParallel(self.gen_model,
                                                   device_ids=device_ids)
Example #6
0
    def __init__(self, models, criterion, metric_ftns, optimizers, config):
        self.config = config
        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

        self.n_ensembles = len(models)

        log_dict = {}
        self.keys = []
        self.keys = [*['loss_' + str(i) for i in range(self.n_ensembles)], *[m.__name__ + '_' + str(i) for m in metric_ftns for i in range(self.n_ensembles)]]
        for key in self.keys:
            log_dict[key] = []
        self.log = log_dict

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])

        self.models = []
        for model in models:
            model = model.to(self.device)
            if len(device_ids) > 1:
                model = torch.nn.DataParallel(model, device_ids=device_ids)
            self.models.append(model)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizers = optimizers

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()  # e.g. mnt_mode: 'min', mnt_metric: 'val_loss'
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = np.empty(self.n_ensembles)
            if self.mnt_mode == 'min':
                self.mnt_best.fill(np.inf)
            else:
                self.mnt_best.fill(-np.inf)
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1
        self.start_epochs = np.ones(self.n_ensembles)

        self.checkpoint_dir = config.save_dir
        self.log_dir = config.log_dir

        # setup visualization writer instance                
        self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
Example #7
0
    def __init__(self, model, loss, metrics, optimizer, config):
        """
        params:
            model: neural network model
            loss: function to optimize
            metrics: list of metrics to validate during training
            optimizer: optimizer to use, e.g. Adam
            config: ConfigParser object
        """
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        # Loss to optimize
        self.loss = loss
        # Metric to validate
        self.metrics = metrics
        # Optimizer to use, e.g. Adam
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        # Number of epochs to train
        self.epochs = cfg_trainer['epochs']
        # Save model every # epochs
        self.save_period = cfg_trainer['save_period']
        # Whether to monitor best model and apply early stopping
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            # Num epochs to track
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1
        # Last model in the model directory
        self.keep_last = config['trainer']['keep_last']
        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
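# Illustrative sketch (not part of the example above): in these pytorch-template style
# trainers, _prepare_device usually caps the requested GPU count at what is actually
# available and returns the primary device plus the device ids used for DataParallel.
# Exact warnings and behaviour vary per repository; this is an assumption, not the source.
import torch

class DevicePreparationSketch:
    def _prepare_device(self, n_gpu_use):
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            n_gpu_use = 0      # no GPU available: fall back to CPU
        if n_gpu_use > n_gpu:
            n_gpu_use = n_gpu  # cannot use more GPUs than the machine has
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        return device, list(range(n_gpu_use))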
Example #8
0
    def __init__(self, model, criterion, metric_ftns, optimizer, config, train_sampler=None):
        self.config = config
        if dist.is_initialized():
            logger_name = "{}{}".format(__name__, dist.get_rank())
        else:
            logger_name = __name__
        self.logger = get_logger(name=logger_name, log_dir=config.log_dir, verbosity=config['trainer']['verbosity'])

        self.model = model
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')
        self.train_sampler = train_sampler

        # configuration to monitor model performance and save best
        if not is_master() or self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
            if self.early_stop <= 0:
                self.early_stop = inf

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance                
        if is_master():
            self.writer = TensorboardWriter(config, cfg_trainer['tensorboard'])
        else:
            self.writer = TensorboardWriter(config, False)

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
Example #9
0
    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.device_ids = device_ids
        self.model = model
        self.model = self.model.to(self.device)

        self.real_model = self.model
        if len(self.device_ids) > 1:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=device_ids)

        self.criterion = criterion.to(self.device)
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])

        if config.load_crt is not None:
            print("Loading from cRT pretrain: {}".format(config.load_crt))
            self._load_crt(config.load_crt)

        if config.resume is not None:
            state_dict_only = config._config.get("resume_state_dict_only",
                                                 False)
            self._resume_checkpoint(config.resume,
                                    state_dict_only=state_dict_only)
Example #10
0
    def __init__(
        self,
        model: torch.nn.Module,
        criterion: torch.nn.modules.loss._Loss,
        metric_ftns: List[Callable[..., float]],
        optimizer: torch.optim.Optimizer,
        config: ConfigParser,
        lr_scheduler: Union[torch.optim.lr_scheduler._LRScheduler,
                            torch.optim.lr_scheduler.ReduceLROnPlateau,
                            None, ] = None,
    ):
        self.config = config
        self.logger = config.get_logger("trainer",
                                        config["trainer"]["verbosity"])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config["n_gpu"])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        cfg_trainer = config["trainer"]
        self.epochs = cfg_trainer["epochs"]
        self.save_period = cfg_trainer["save_period"]
        self.monitor = cfg_trainer.get("monitor", "off")
        self.save_last = cfg_trainer.get("save_last", False)

        # configuration to monitor model performance and save best
        if self.monitor == "off":
            self.mnt_mode = "off"
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ["min", "max"]

            self.mnt_best = inf if self.mnt_mode == "min" else -inf
            self.early_stop = cfg_trainer.get("early_stop", inf)

        self.start_epoch = 1

        self.checkpoint_dir = config.model_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer["tensorboard"])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
Example #11
0
    def __init__(self, model, loss, metrics, optimizer, config, mini_train,
                 num_keep_ckpts, skip_tboard):
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.loss = loss
        self.metrics = metrics
        self.optimizer = optimizer
        self.num_keep_ckpts = num_keep_ckpts
        self.skip_tboard = skip_tboard or mini_train

        # This property can be overridden in the subclass
        self.skip_first_n_saves = 0

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')
        self.save_only_best = cfg_trainer.get("save_only_best", True)

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        if not self.skip_tboard:
            summary_dir = config.log_dir / f"seed-{config['seed']}"
            self.writer = TensorboardWriter(summary_dir, self.logger,
                                            cfg_trainer['tensorboard'])

        self.include_optim_in_ckpts = config["trainer"].get(
            "include_optim_in_ckpts", 1)
        if config.resume is not None:
            self._resume_checkpoint(config.resume)
Example #12
0
    def __init__(self, generator, discriminator, config):
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, self.device_ids = self._prepare_device(config['n_gpu'])

        self.generator = self.initialize_training(generator)
        self.discriminator = self.initialize_training(discriminator)

        self.generator['config'] = config['generator']
        self.discriminator['config'] = config['discriminator']

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.early_stop = cfg_trainer.get('early_stop', inf)
        self.start_epoch = 1

        self.generator['monitor'] = self.initialize_monitor(
            cfg_trainer['monitor'].get('generator', 'off'))
        self.discriminator['monitor'] = self.initialize_monitor(
            cfg_trainer['monitor'].get('discriminator', 'off'))

        self.generator['checkpoint_dir'] = os.path.join(
            config.save_dir, "generator")
        self.discriminator['checkpoint_dir'] = os.path.join(
            config.save_dir, "discriminator")

        self.generator['writer'] = TensorboardWriter(
            config.log_dir / 'generator', self.generator['logger'],
            cfg_trainer['tensorboard'])
        self.discriminator['writer'] = TensorboardWriter(
            config.log_dir / 'discriminator', self.discriminator['logger'],
            cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
Example #13
0
    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        # If criterion is a stateful object (for now considered to be an nn.Module) and not
        # a functional, move it to the computing device:
        # if isinstance(self.criterion, torch.nn.Module):
        #     self.criterion = self.criterion.to(self.device)

        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
        elif config.load is not None:
            self._load_checkpoint(config.load)
Example #14
0
    def __init__(self, model, criterion, metrics, optimizer, config):
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        # setup GPU device if available, move models into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        # For training
        self.criterion = criterion.to(self.device)
        self.metrics = metrics
        self.optimizer = optimizer
        self.metric_names = [x.__name__ for x in self.metrics]

        # Hyperparameters for the trainer
        cfg_trainer = config['trainer']
        self.accumulation_steps = cfg_trainer['accumulation_steps']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')
        self.do_validation = cfg_trainer['do_validation']
        self.do_validation_interval = cfg_trainer['do_validation_interval']
        self.log_step = cfg_trainer['log_step']
        if cfg_trainer['save_for_track']:
            self.save_for_track = cfg_trainer['save_for_track']
        else:
            self.save_for_track = None

        self.start_epoch = 1
        self.checkpoint_dir = config.save_dir

        # configuration to monitor models performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])
Example #15
0
    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        log_dict = {}
        self.keys = ['loss', *[m.__name__ for m in metric_ftns]]
        for key in self.keys:
            log_dict[key] = []
        self.log = log_dict

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir
        self.log_dir = config.log_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
Example #16
0
    def __init__(self, model, criterion, optimizer, metric_fcns, config):
        self.config = config
        self.logger = config.get_logger(
            name="trainer", verbosity=config["trainer"]["verbosity"])
        # setup GPU device if available
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        # move model to configured device
        self.model = model.to(self.device)
        # if multiple devices, parallelize module
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(module=model,
                                               device_ids=device_ids)
        self.criterion = criterion
        self.optimizer = optimizer
        self.metric_fcns = metric_fcns

        # trainer configuration
        cfg_trainer = config["trainer"]
        self.epochs = cfg_trainer["epochs"]  # set number of epochs
        # save a checkpoint every save_period epochs
        self.save_period = cfg_trainer["save_period"]
        self.monitor = cfg_trainer.get("monitor", "off")

        # configuration to monitor model performance and save best model
        if self.monitor == "off":
            self.mnt_mode = "off"
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ["min", "max"]

            self.mnt_best = inf if self.mnt_mode == "min" else -inf
            self.early_stop = cfg_trainer.get("early_stop", inf)

        self.start_epoch = 1  # start epoch from 1

        self.checkpoint_dir = config.save_dir  # directory where checkpoints are saved

        # setup visualization writer instance
        self.writer = TensorboardWriter(
            log_dir=config.log_dir,
            logger=self.logger,
            enabled=cfg_trainer["tensorboard"],
        )

        # resume training from a checkpoint if one is provided
        if config.resume is not None:
            self._resume_checkpoint(config.resume)
Example #17
0
    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(
            config['n_gpu'])  # device_ids is used for DataParallel
        self.model = model.to(self.device)  # move the model to the device
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns  # evaluation metric functions
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']  # save a checkpoint every save_period epochs
        self.monitor = cfg_trainer.get(
            'monitor', 'off'
        )  # dict.get returns the 'monitor' entry of cfg_trainer, or the default 'off' if it is missing

        # configuration to monitor model performance and save best
        # monitor controls which metric is tracked and how
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in [
                'min', 'max'
            ]  # raises AssertionError if the mode is invalid; otherwise execution continues

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)  # early stopping patience

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
Example #18
0
    def __init__(self, model, criterion, metric_ftns, config):
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        model.initialize(config, self.device)

        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        trainable_params = filter(lambda p: p.requires_grad,
                                  model.parameters())
        self.optimizer = config.init_obj('optimizer', torch.optim,
                                         trainable_params)

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
Example #19
0
    def __init__(self, model, loss, metrics, optimizer, config):
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        # TODO: Check if this helps
        torch.backends.cudnn.benchmark = True
        self.loss = loss
        self.metrics = metrics
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        self.writer = None
        # TODO: should just move the TensorBoard writer construction directly in here
        if cfg_trainer['tensorboard']:
            self.writer = TensorboardWriter(config.log_dir, self.logger,
                                            cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
Example #20
0
    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

        self.logger.info("Model params: {}".format(self.model))
        self.logger.info("Current Opitmizer state: {}".format(self.optimizer))
        self.logger.info("Criterion: {}".format(self.criterion.__name__))
        for metric in self.metric_ftns:
            self.logger.info("Tracking metric {}".format(metric.__name__))
Example #21
0
    def __init__(self, torch_objs: dict, save_dir, **kwargs):
        # data_loaders
        self.train_data_loaders = torch_objs["data_loaders"]["train"]
        self.valid_data_loaders = torch_objs["data_loaders"]["valid"]
        # models
        self.models = torch_objs["models"]
        # losses
        self.losses = torch_objs["losses"]
        # metrics
        self.metrics_iter = torch_objs["metrics"]["iter"]
        self.metrics_epoch = torch_objs["metrics"]["epoch"]
        self.metrics_threshold = torch_objs["metrics"]["threshold"]
        # optimizers
        self.optimizers = torch_objs["optimizers"]
        # lr_schedulers
        self.lr_schedulers = torch_objs["lr_schedulers"]
        # amp
        self.amp = torch_objs["amp"]

        self.model_dir = save_dir["model"]
        # expose the remaining JSON kwargs as attributes on self
        for key, value in kwargs.items():
            setattr(self, key, value)

        self.logger = get_logger("trainer", verbosity=self.verbosity)
        if self.early_stop is None or self.early_stop <= 0:
            self.early_stop = np.inf
        self.start_epoch = 1

        # configuration to monitor model performance and save best
        self.num_best = 0
        if self.monitor == "off":
            self.mnt_mode = "off"
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ["min", "max"]
            self.mnt_best = np.inf if self.mnt_mode == "min" else -np.inf

        # setup visualization writer instance
        self.writer = TensorboardWriter(save_dir["log"], self.logger,
                                        self.tensorboard)

        if self.metrics_threshold is not None:
            self.threshold = 0.5
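# Illustrative sketch (not part of the example above): the shape of the torch_objs dictionary
# this trainer unpacks. The keys are taken from the attribute lookups above; the values are
# hypothetical placeholders, and the extra JSON kwargs (verbosity, monitor, early_stop,
# tensorboard, ...) become attributes on the trainer via setattr.
torch_objs_example = {
    'data_loaders': {'train': [], 'valid': []},
    'models': {},
    'losses': {},
    'metrics': {'iter': [], 'epoch': [], 'threshold': None},
    'optimizers': {},
    'lr_schedulers': {},
    'amp': None,
}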
Example #22
0
    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        # create a logger
        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

        # get the GPU device (or fall back to CPU)
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)
        # set the loss function, evaluation metrics, and optimizer
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        # total number of training epochs
        self.epochs = cfg_trainer['epochs']
        # how often (in epochs) to save the model
        self.save_period = cfg_trainer['save_period']
        # default to 'off' if no monitor is configured
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            # mnt_best holds the best monitored value so far
            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance                
        self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
Example #23
0
    def __init__(self, model, optimizer, config):
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device = self._prepare_device(config['n_gpu'])
        self.n_gpu_use = len(self.device)
        if self.device[0] != -1:
            self.model = model.to(self.device[0])
        else:
            self.model = model
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')
        self.clip = cfg_trainer.get('clip')

        # configuration to monitor model performance and save best
        self.best_log = None
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

        self.best_state = None
Example #24
0
    def __init__(self, model, loss, metrics, optimizer, config, use_apex=True):
        self.config = config
        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)
        self.loss = loss
        self.metrics = metrics
        self.optimizer = optimizer
        if use_apex:
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level="O1")

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
Example #25
0
    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        self.logger = config.get_logger("trainer",
                                        config["trainer"]["verbosity"])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config["n_gpu"])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config["trainer"]
        self.epochs = cfg_trainer["epochs"]
        self.save_period = cfg_trainer["save_period"]
        self.monitor = cfg_trainer.get("monitor", "off")

        # configuration to monitor model performance and save best
        if self.monitor == "off":
            self.mnt_mode = "off"
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ["min", "max"]

            self.mnt_best = inf if self.mnt_mode == "min" else -inf
            self.early_stop = cfg_trainer.get("early_stop", inf)

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer["tensorboard"])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
Example #26
0
    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config  # ConfigParser instance
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        self.model = model
        self.criterion = criterion  # loss function
        self.metric_ftns = metric_ftns  # evaluation metrics
        self.optimizer = optimizer  # the optimizer to use

        cfg_trainer = config['trainer']  # trainer configuration
        self.epochs = cfg_trainer['epochs']  # total number of training epochs
        self.save_period = cfg_trainer['save_period']  # checkpoint save frequency
        self.monitor = cfg_trainer.get('monitor', 'off')  # monitored metric setting

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()  # metric used to judge whether the model improves
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)  # early stopping
            if self.early_stop <= 0:
                self.early_stop = inf

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(
            config.log_dir, self.logger,
            cfg_trainer['tensorboard'])  # log results with TensorBoard

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
Example #27
0
    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        self.logger = config.get_logger("trainer",
                                        config["trainer"]["verbosity"])

        self.model = model
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config["trainer"]
        self.epochs = cfg_trainer["epochs"]
        self.save_period = cfg_trainer["save_period"]
        self.monitor = cfg_trainer.get("monitor", "off")

        # configuration to monitor model performance and save best
        if self.monitor == "off":
            self.mnt_mode = "off"
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ["min", "max"]

            self.mnt_best = inf if self.mnt_mode == "min" else -inf
            self.early_stop = cfg_trainer.get("early_stop", inf)
            if self.early_stop <= 0:
                self.early_stop = inf

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer["tensorboard"])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
Example #28
0
    def __init__(self,
                 model,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 max_len_step=None):
        '''
        :param model: model to train
        :param optimizer: optimizer for training
        :param config: ConfigParser holding the run configuration
        :param data_loader: training data loader
        :param valid_data_loader: validation data loader (optional)
        :param lr_scheduler: learning-rate scheduler (optional)
        :param max_len_step: controls the number of batches (steps) in each epoch
        '''
        self.config = config
        self.distributed = config['distributed']
        if self.distributed:
            self.local_master = (config['local_rank'] == 0)
            self.global_master = (dist.get_rank() == 0)
        else:
            self.local_master = True
            self.global_master = True
        self.logger = config.get_logger(
            'trainer',
            config['trainer']['log_verbosity']) if self.local_master else None

        # setup GPU device if available, move model into configured device
        self.device, self.device_ids = self._prepare_device(
            config['local_rank'], config['local_world_size'])
        self.model = model.to(self.device)

        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        monitor_open = cfg_trainer['monitor_open']
        if monitor_open:
            self.monitor = cfg_trainer.get('monitor', 'off')
        else:
            self.monitor = 'off'

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.monitor_mode = 'off'
            self.monitor_best = 0
        else:
            self.monitor_mode, self.monitor_metric = self.monitor.split()
            assert self.monitor_mode in ['min', 'max']

            self.monitor_best = inf if self.monitor_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
            self.early_stop = inf if self.early_stop == -1 else self.early_stop

        self.start_epoch = 1

        if self.local_master:
            self.checkpoint_dir = config.save_dir
            # setup visualization writer instance
            self.writer = TensorboardWriter(config.log_dir, self.logger,
                                            cfg_trainer['tensorboard'])

        # load checkpoint for resume training
        if config.resume is not None:
            self._resume_checkpoint(config.resume)

        # wrap the model for multi-GPU only after loading the checkpoint, to avoid the 'module.' prefix in saved keys
        if self.config['trainer']['sync_batch_norm'] and self.distributed:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model)

        if self.distributed:
            self.model = DDP(self.model,
                             device_ids=self.device_ids,
                             output_device=self.device_ids[0],
                             find_unused_parameters=True)

        self.data_loader = data_loader
        if max_len_step is None:  # max number of iteration steps per epoch
            # epoch-based training
            self.len_step = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_step = max_len_step
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler

        log_step = self.config['trainer']['log_step_interval']
        self.log_step = log_step if log_step != -1 and 0 < log_step < self.len_step else int(
            np.sqrt(data_loader.batch_size))

        val_step_interval = self.config['trainer']['val_step_interval']
        # self.val_step_interval = val_step_interval if val_step_interval!= -1 and 0 < val_step_interval < self.len_step\
        #                                             else int(np.sqrt(data_loader.batch_size))
        self.val_step_interval = val_step_interval

        self.gl_loss_lambda = self.config['trainer']['gl_loss_lambda']

        self.train_loss_metrics = MetricTracker(
            'loss',
            'gl_loss',
            'crf_loss',
            writer=self.writer if self.local_master else None)
        self.valid_f1_metrics = SpanBasedF1MetricTracker(iob_labels_vocab_cls)
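# Illustrative sketch (not part of the example above): inf_loop is the usual pytorch-template
# helper that turns a finite DataLoader into an endless iterator, which is what enables the
# iteration-based branch (max_len_step) above. The exact implementation may differ per repo.
from itertools import repeat

def inf_loop(data_loader):
    """Wrap a data loader so it can be iterated over indefinitely."""
    for loader in repeat(data_loader):
        yield from loader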
Example #29
0
class Trainer:
    """
    Trainer class
    """
    def __init__(self,
                 model,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 max_len_step=None):
        '''
        :param model: model to train
        :param optimizer: optimizer for training
        :param config: ConfigParser holding the run configuration
        :param data_loader: training data loader
        :param valid_data_loader: validation data loader (optional)
        :param lr_scheduler: learning-rate scheduler (optional)
        :param max_len_step: controls the number of batches (steps) in each epoch
        '''
        self.config = config
        self.distributed = config['distributed']
        if self.distributed:
            self.local_master = (config['local_rank'] == 0)
            self.global_master = (dist.get_rank() == 0)
        else:
            self.local_master = True
            self.global_master = True
        self.logger = config.get_logger(
            'trainer',
            config['trainer']['log_verbosity']) if self.local_master else None

        # setup GPU device if available, move model into configured device
        self.device, self.device_ids = self._prepare_device(
            config['local_rank'], config['local_world_size'])
        self.model = model.to(self.device)

        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        monitor_open = cfg_trainer['monitor_open']
        if monitor_open:
            self.monitor = cfg_trainer.get('monitor', 'off')
        else:
            self.monitor = 'off'

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.monitor_mode = 'off'
            self.monitor_best = 0
        else:
            self.monitor_mode, self.monitor_metric = self.monitor.split()
            assert self.monitor_mode in ['min', 'max']

            self.monitor_best = inf if self.monitor_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
            self.early_stop = inf if self.early_stop == -1 else self.early_stop

        self.start_epoch = 1

        if self.local_master:
            self.checkpoint_dir = config.save_dir
            # setup visualization writer instance
            self.writer = TensorboardWriter(config.log_dir, self.logger,
                                            cfg_trainer['tensorboard'])

        # load checkpoint for resume training
        if config.resume is not None:
            self._resume_checkpoint(config.resume)

        # wrap the model for multi-GPU only after loading the checkpoint, to avoid the 'module.' prefix in saved keys
        if self.config['trainer']['sync_batch_norm'] and self.distributed:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model)

        if self.distributed:
            self.model = DDP(self.model,
                             device_ids=self.device_ids,
                             output_device=self.device_ids[0],
                             find_unused_parameters=True)

        self.data_loader = data_loader
        if max_len_step is None:  # max number of iteration steps per epoch
            # epoch-based training
            self.len_step = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_step = max_len_step
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler

        log_step = self.config['trainer']['log_step_interval']
        self.log_step = log_step if log_step != -1 and 0 < log_step < self.len_step else int(
            np.sqrt(data_loader.batch_size))

        val_step_interval = self.config['trainer']['val_step_interval']
        self.val_step_interval = val_step_interval

        self.gl_loss_lambda = self.config['trainer']['gl_loss_lambda']

        self.train_loss_metrics = MetricTracker(
            'loss',
            'gl_loss',
            'crf_loss',
            writer=self.writer if self.local_master else None)
        self.valid_f1_metrics = SpanBasedF1MetricTracker(iob_labels_vocab_cls)

    def train(self):
        """
        Full training logic, including train and validation.
        """

        if self.distributed:
            dist.barrier()  # Syncing machines before training

        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):

            # ensure each distributed worker samples different data by
            # seeding the sampler with the current epoch
            if self.distributed:
                self.data_loader.sampler.set_epoch(epoch)
            result_dict = self._train_epoch(epoch)

            # print logged informations to the screen
            if self.do_validation:
                val_result_dict = result_dict['val_result_dict']
                val_res = SpanBasedF1MetricTracker.dict2str(val_result_dict)
            else:
                val_res = ''
            # every epoch log information
            self.logger_info(
                '[Epoch Validation] Epoch:[{}/{}] Total Loss: {:.6f} '
                'GL_Loss: {:.6f} CRF_Loss: {:.6f} \n{}'.format(
                    epoch, self.epochs, result_dict['loss'],
                    result_dict['gl_loss'] * self.gl_loss_lambda,
                    result_dict['crf_loss'], val_res))

            # evaluate model performance according to configured metric, check early stop, and
            # save best checkpoint as model_best
            best = False
            if self.monitor_mode != 'off' and self.do_validation:
                best, not_improved_count = self._is_best_monitor_metric(
                    best, not_improved_count, val_result_dict)
                if not_improved_count > self.early_stop:
                    self.logger_info(
                        "Validation performance didn\'t improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

    def _is_best_monitor_metric(self, best, not_improved_count,
                                val_result_dict):
        '''
        Check whether the monitored validation metric improved and update the early-stop counter.
        :param best: current best flag
        :param not_improved_count: number of consecutive validations without improvement
        :param val_result_dict: validation result dict, indexed as [entity_name][metric]
        :return: (best, not_improved_count)
        '''
        try:
            # check whether model performance improved or not, according to the specified metric (monitor_metric)
            entity_name, metric = self.monitor_metric.split('-')
            val_monitor_metric_res = val_result_dict[entity_name][metric]
            improved = (self.monitor_mode == 'min' and val_monitor_metric_res <= self.monitor_best) or \
                       (self.monitor_mode == 'max' and val_monitor_metric_res >= self.monitor_best)
        except KeyError:
            self.logger_warning(
                "Warning: Metric '{}' is not found. "
                "Model performance monitoring is disabled.".format(
                    self.monitor_metric))
            self.monitor_mode = 'off'
            improved = False
        if improved:
            self.monitor_best = val_monitor_metric_res
            not_improved_count = 0
            best = True
        else:
            not_improved_count += 1
        return best, not_improved_count

    def _train_epoch(self, epoch):
        '''
        Training logic for an epoch
        :param epoch: Integer, current training epoch.
        :return: A log dict that contains average loss and metric in this epoch.
        '''
        self.model.train()
        self.train_loss_metrics.reset()
        ## step iteration start ##
        for step_idx, input_data_item in enumerate(self.data_loader, start=1):
            for key, input_value in input_data_item.items():
                if input_value is not None and isinstance(
                        input_value, torch.Tensor):
                    input_data_item[key] = input_value.to(self.device,
                                                          non_blocking=True)
            if self.config['trainer']['anomaly_detection']:
                # This mode will increase the runtime and should only be enabled for debugging
                with torch.autograd.detect_anomaly():
                    self.optimizer.zero_grad()
                    # model forward
                    output = self.model(**input_data_item)
                    # calculate loss
                    gl_loss = output['gl_loss']
                    crf_loss = output['crf_loss']
                    total_loss = torch.sum(
                        crf_loss) + self.gl_loss_lambda * torch.sum(gl_loss)
                    # backward
                    total_loss.backward()
                    # self.average_gradients(self.model)
                    self.optimizer.step()
            else:
                self.optimizer.zero_grad()
                # model forward
                output = self.model(**input_data_item)
                # calculate loss
                gl_loss = output['gl_loss']
                crf_loss = output['crf_loss']
                total_loss = torch.sum(
                    crf_loss) + self.gl_loss_lambda * torch.sum(gl_loss)
                # backward
                total_loss.backward()
                # self.average_gradients(self.model)
                self.optimizer.step()

            # Use a barrier() to make sure that all processes have finished forward and backward
            if self.distributed:
                dist.barrier()
                # sum the losses over all processes (assumes per-process loss tensors have matching shapes),
                # then divide by the world size below so the logged values are global averages
                dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
                dist.all_reduce(gl_loss, op=dist.ReduceOp.SUM)
                dist.all_reduce(crf_loss, op=dist.ReduceOp.SUM)

                size = dist.get_world_size()
            else:
                size = 1
            gl_loss /= size  # average gl_loss across the whole world
            crf_loss /= size  # average crf_loss across the whole world

            # calculate average loss across the batch size
            avg_gl_loss = torch.mean(gl_loss)
            avg_crf_loss = torch.mean(crf_loss)
            avg_loss = avg_crf_loss + self.gl_loss_lambda * avg_gl_loss
            # update metrics
            if self.local_master:
                self.writer.set_step((epoch - 1) * self.len_step + step_idx - 1)
            self.train_loss_metrics.update('loss', avg_loss.item())
            self.train_loss_metrics.update(
                'gl_loss',
                avg_gl_loss.item() * self.gl_loss_lambda)
            self.train_loss_metrics.update('crf_loss', avg_crf_loss.item())

            # log messages
            if step_idx % self.log_step == 0:
                self.logger_info(
                    'Train Epoch:[{}/{}] Step:[{}/{}] Total Loss: {:.6f} GL_Loss: {:.6f} CRF_Loss: {:.6f}'
                    .format(epoch, self.epochs, step_idx, self.len_step,
                            avg_loss.item(),
                            avg_gl_loss.item() * self.gl_loss_lambda,
                            avg_crf_loss.item()))
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

            # do validation after val_step_interval iteration
            if self.do_validation and step_idx % self.val_step_interval == 0:
                val_result_dict = self._valid_epoch(epoch)
                self.logger_info(
                    '[Step Validation] Epoch:[{}/{}] Step:[{}/{}]  \n{}'.
                    format(epoch, self.epochs, step_idx, self.len_step,
                           SpanBasedF1MetricTracker.dict2str(val_result_dict)))

                # check if best metric, if true, then save as model_best checkpoint.
                best, not_improved_count = self._is_best_monitor_metric(
                    False, 0, val_result_dict)
                if best:
                    self._save_checkpoint(epoch, best)

            # stop once len_step batches have been processed (relevant for iteration-based training with inf_loop)
            if step_idx >= self.len_step:
                break

        ## step iteration end ##

        # {'loss': avg_loss, 'gl_loss': avg_gl_loss, 'crf_loss': avg_crf_loss}
        log = self.train_loss_metrics.result()

        # do validation after training an epoch
        if self.do_validation:
            val_result_dict = self._valid_epoch(epoch)
            log['val_result_dict'] = val_result_dict

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        '''
        Validate after training an epoch or at a regular step interval; this can be time-consuming
        when the validation set is large.
        :param epoch: Integer, current training epoch.
        :return: A dict that contains information about validation
        '''

        self.model.eval()
        self.valid_f1_metrics.reset()
        with torch.no_grad():
            for step_idx, input_data_item in enumerate(self.valid_data_loader):
                for key, input_value in input_data_item.items():
                    if input_value is not None and isinstance(
                            input_value, torch.Tensor):
                        input_data_item[key] = input_value.to(
                            self.device, non_blocking=True)

                output = self.model(**input_data_item)
                logits = output['logits']
                new_mask = output['new_mask']
                if hasattr(self.model, 'module'):
                    #  each entry is a (List[int], torch.Tensor) pair: the tag indices of the maximum-likelihood
                    #  tag sequence and the score of its Viterbi path
                    best_paths = self.model.module.decoder.crf_layer.viterbi_tags(
                        logits, mask=new_mask, logits_batch_first=True)
                else:
                    best_paths = self.model.decoder.crf_layer.viterbi_tags(
                        logits, mask=new_mask, logits_batch_first=True)
                predicted_tags = []
                for path, score in best_paths:
                    predicted_tags.append(path)

                if self.local_master:
                    self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + step_idx, 'valid')

                # calculate and update f1 metrics
                # (B, N*T, out_dim)
                predicted_tags_hard_prob = logits * 0
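                # build a hard one-hot distribution over tags from the Viterbi paths,
                # which is what the span-based F1 tracker consumes below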
                for i, instance_tags in enumerate(predicted_tags):
                    for j, tag_id in enumerate(instance_tags):
                        predicted_tags_hard_prob[i, j, tag_id] = 1

                golden_tags = input_data_item['iob_tags_label']
                mask = input_data_item['mask']
                union_iob_tags = iob_tags_to_union_iob_tags(golden_tags, mask)

                if self.distributed:
                    dist.barrier()
                self.valid_f1_metrics.update(predicted_tags_hard_prob.long(),
                                             union_iob_tags, new_mask)

        # add histogram of model parameters to the tensorboard
        # for name, p in self.model.named_parameters():
        #     self.writer.add_histogram(name, p, bins='auto')

        f1_result_dict = self.valid_f1_metrics.result()

        # rollback to train mode
        self.model.train()

        return f1_result_dict

    def average_gradients(self, model):
        '''
        Manually average gradients across all processes (an alternative to DDP's built-in gradient sync).
        :param model: model whose parameter gradients are averaged in place
        :return:
        '''
        size = float(dist.get_world_size())
        for param in model.parameters():
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= size
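    # NOTE: a hypothetical call site (see the commented-out lines in _train_epoch) would roughly be:
    #   total_loss.backward()
    #   self.average_gradients(self.model)
    #   self.optimizer.step()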

    def logger_info(self, msg):
        if self.local_master:
            self.logger.info(msg)

    def logger_warning(self, msg):
        if self.local_master:
            self.logger.warning(msg)

    def _prepare_device(self, local_rank, local_world_size):
        '''
         setup GPU device if available, move model into configured device
        :param local_rank:
        :param local_world_size:
        :return:
        '''
        if self.distributed:
            ngpu_per_process = torch.cuda.device_count() // local_world_size
            device_ids = list(
                range(local_rank * ngpu_per_process,
                      (local_rank + 1) * ngpu_per_process))
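            # e.g. (hypothetical) 8 GPUs with local_world_size=4 -> 2 GPUs per process;
            # rank 0 gets [0, 1], rank 1 gets [2, 3], and so on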

            if torch.cuda.is_available() and local_rank != -1:
                torch.cuda.set_device(
                    device_ids[0]
                )  # device_ids[0] == local_rank when local_world_size equals the number of GPUs per node
                device = 'cuda'
                self.logger_info(
                    f"[Process {os.getpid()}] world_size = {dist.get_world_size()}, "
                    +
                    f"rank = {dist.get_rank()}, n_gpu/process = {ngpu_per_process}, device_ids = {device_ids}"
                )
            else:
                self.logger_warning('Training will be using CPU!')
                device = 'cpu'
            device = torch.device(device)
            return device, device_ids
        else:
            n_gpu = torch.cuda.device_count()
            print(f"NUMBER GPU {n_gpu}")
            n_gpu_use = local_world_size
            if n_gpu_use > 0 and n_gpu == 0:
                self.logger_warning(
                    "Warning: There's no GPU available on this machine, "
                    "training will be performed on CPU.")
                n_gpu_use = 0
            if n_gpu_use > n_gpu:
                self.logger_warning(
                    "Warning: The number of GPUs configured to use is {}, but only {} are available "
                    "on this machine.".format(n_gpu_use, n_gpu))
                n_gpu_use = n_gpu

            list_ids = list(range(n_gpu_use))
            if n_gpu_use > 0:
                torch.cuda.set_device(
                    list_ids[0])  # only use the first available GPU as the default device
                self.logger_info(f'Training is using GPU {list_ids[0]}!')
                device = 'cuda'
            else:
                self.logger_warning('Training is using CPU!')
                device = 'cpu'
            device = torch.device(device)
            return device, list_ids

    def _save_checkpoint(self, epoch, save_best=False):
        '''
        Saving checkpoints
        :param epoch:  current epoch number
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        :return:
        '''
        # only local master process do save model
        if not self.local_master:
            return

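        # unwrap DDP/DataParallel (the 'module.' wrapper) so the saved state_dict keys
        # can be loaded into a plain, un-wrapped model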
        if hasattr(self.model, 'module'):
            arch = type(self.model.module).__name__
            state_dict = self.model.module.state_dict()
        else:
            arch = type(self.model).__name__
            state_dict = self.model.state_dict()
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': state_dict,
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.monitor_best,
            'config': self.config
        }
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger_info("Saving current best: model_best.pth ...")
        else:
            filename = str(self.checkpoint_dir /
                           'checkpoint-epoch{}.pth'.format(epoch))
            torch.save(state, filename)
            self.logger_info("Saving checkpoint: {} ...".format(filename))

    def _resume_checkpoint(self, resume_path):
        '''
        Resume from saved checkpoints
        :param resume_path: Checkpoint path to be resumed
        :return:
        '''
        resume_path = str(resume_path)
        self.logger_info("Loading checkpoint: {} ...".format(resume_path))
        # map the checkpoint onto this process's device so every rank does not load onto cuda:0
        checkpoint = torch.load(resume_path, map_location=self.device)
        self.start_epoch = checkpoint['epoch'] + 1
        self.monitor_best = checkpoint['monitor_best']

        # load architecture params from checkpoint.
        if checkpoint['config']['model_arch'] != self.config['model_arch']:
            self.logger_warning(
                "Warning: Architecture configuration given in config file is different from that of "
                "checkpoint. This may yield an exception while state_dict is being loaded."
            )
        self.model.load_state_dict(checkpoint['state_dict'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint['config']['optimizer']['type'] != self.config[
                'optimizer']['type']:
            self.logger_warning(
                "Warning: Optimizer type given in config file is different from that of checkpoint. "
                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger_info(
            "Checkpoint loaded. Resume training from epoch {}".format(
                self.start_epoch))
Example #30
0
    def __init__(
        self,
        model: Module,
        loss_fn: Callable,
        loss_args: Dict[str, Any],
        metric_fns: List[Callable],
        metric_args: List[Dict[str, Any]],
        optimizer: Optimizer,
        config: ConfigParser,
    ):

        self.config: ConfigParser = config
        self.logger: Logger = config.get_logger("trainer",
                                                config["trainer"]["verbosity"])

        # Setup GPU device if available.
        self.device: torch.device
        device_ids: List[int]
        self.device, device_ids = self._prepare_device(config["n_gpu"])

        # Move model into configured device(s).
        self.model: Module = model.to(self.device)
        if len(device_ids) > 1:
            self.model = DataParallel(model, device_ids=device_ids)
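            # DataParallel splits each input batch across the listed devices and gathers
            # the outputs back on the primary device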

        # Set loss function and arguments.
        self.loss_fn: Callable = loss_fn
        self.loss_args: Dict[str, Any] = loss_args

        # Set all metric functions and associated arguments.
        self.metric_fns: List[Callable] = metric_fns
        self.metric_args: List[Dict[str, Any]] = metric_args

        # Set optimizer.
        self.optimizer: Optimizer = optimizer

        # Set training configuration.
        cfg_trainer: Dict[str, Any] = config["trainer"]
        self.epochs: int = cfg_trainer["epochs"]
        self.save_period: int = cfg_trainer["save_period"]
        self.monitor: str = cfg_trainer.get("monitor", "off")

        # Configuration to monitor model performance and save best.
        if self.monitor == "off":
            self.mnt_mode: str = "off"
            self.mnt_best: float = 0
        else:
            self.mnt_metric: str
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ["min", "max"]

            self.mnt_best = inf if self.mnt_mode == "min" else -inf
            self.early_stop: float = cfg_trainer.get("early_stop", inf)

        self.start_epoch: int = 1
        self.checkpoint_dir: Path = config.save_dir

        # Setup visualization writer instance.
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer["tensorboard"])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)