コード例 #1
0
 def run(self, train_loader, validate_loader, epochs):
     """Drive the full training procedure for ``epochs`` epochs.

     Each epoch performs one training pass over ``train_loader`` and one
     evaluation pass over ``validate_loader`` (each with a fresh
     ``History``), then persists a checkpoint.
     """
     for current_epoch in range(epochs):
         # Training pass: put the model into train mode first.
         self.model.train()
         self.train(train_loader, current_epoch, History())
         # Validation pass: switch to eval mode.
         self.model.eval()
         self.evaluate(validate_loader, current_epoch, History())
         # Persist weights at the end of every epoch.
         self.save_checkpoint(current_epoch)
コード例 #2
0
    def __init__(self,
                 model,
                 optimizer=None,
                 lr_scheduler=None,
                 logger=None,
                 saver=None,
                 default_handlers=False):
        """Initialize the estimator.

        Args:
            model (torch.nn.Module): model to train/evaluate.
            optimizer (torch.optim.Optimizer, optional): optimizer; can be
                left empty for inference-only use.
            lr_scheduler (optional): learning-rate scheduler.
            logger (optional): training state logger.
            saver (optional): checkpoint persistence helper.
            default_handlers (bool): register the default event handlers
                (default: False).
        """
        self.model = model

        # can leave empty (e.g. inference-only usage)
        self.optimizer = optimizer

        # optional collaborators
        self.lr_scheduler = lr_scheduler
        self.saver = saver
        self.logger = logger

        # internal state
        self.history = History()
        self.state = AttrDict(epoch=0)
        self._events = defaultdict(list)
        self._hist_dict = defaultdict(list)
        self._log = logging.getLogger('onegan.OneEstimator')

        if default_handlers:
            self.add_default_event_handlers()
        # FIX: plain string literal -- the original used an f-string with no
        # placeholders (flake8 F541); runtime text is unchanged.
        self._log.info('OneEstimator is initialized')
コード例 #3
0
    def __init__(self, model, optimizer=None, lr_scheduler=None, logger=None, saver=None, name=None):
        """Set up the estimator with a model and optional training helpers."""
        # required
        self.model = model
        self.name = name

        # optional collaborators (each may be None)
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.logger = logger
        self.saver = saver

        # internal bookkeeping
        self.history = History()
        self.state = {}
        self._log = logging.getLogger(f'OneGAN.{name}')
        self._log.info(f'OneEstimator<{name}> is initialized')
コード例 #4
0
    def __init__(self, model, optimizer=None, lr_scheduler=None, logger=None, saver=None, name=None):
        """Initialize a GAN estimator from (generator, discriminator) pairs.

        Args:
            model: pair ``(model_g, model_d)``.
            optimizer (optional): pair ``(optim_g, optim_d)``.
            lr_scheduler (optional): pair ``(sched_g, sched_d)``.
            logger (optional): scalar logger.
            saver (optional): checkpoint persistence helper.
            name (str, optional): display name used in log messages.
        """
        self.models = model
        self.schedulers = lr_scheduler
        # BUGFIX: probe len() only on non-None values -- lr_scheduler (and a
        # potentially-None model) previously hit ``len(None)`` -> TypeError
        # with the default arguments.
        self.model_g, self.model_d = model if model is not None and len(model) == 2 else (None, None)
        self.optim_g, self.optim_d = optimizer if optimizer else (None, None)
        self.saver = saver
        self.logger = logger
        self.sched_g, self.sched_d = (
            lr_scheduler if lr_scheduler is not None and len(lr_scheduler) == 2 else (None, None))
        self.name = name

        self.history = History()
        self.history_val = History()
        self.state = {}
        self._log = logging.getLogger(f'OneGAN.{name}')
        self._log.info(f'OneGANEstimator<{name}> is initialized')
コード例 #5
0
class OneEstimator:
    """Estimator that drives the train/evaluate loop for a single model.

    Args:
        model (torch.nn.Module): the model to train.
        optimizer (torch.optim.Optimizer, optional): optimizer stepped during
            training (default: None).
        lr_scheduler (optional): learning-rate scheduler; stepped once per
            epoch with a monitored value (default: None).
        logger (optional): scalar logger, e.g. a TensorBoard wrapper
            (default: None).
        saver (optional): checkpoint load/save helper (default: None).
        name (str, optional): display name used in log messages.

    Attributes:
        history (History): running statistics of the current epoch.
        state (dict): mutable loop state; currently holds ``epoch``.
    """

    def __init__(self, model, optimizer=None, lr_scheduler=None, logger=None, saver=None, name=None):
        self.model = model
        self.optimizer = optimizer
        self.saver = saver
        self.logger = logger
        self.lr_scheduler = lr_scheduler
        self.name = name

        self.history = History()
        self.state = {}
        self._log = logging.getLogger(f'OneGAN.{name}')
        self._log.info(f'OneEstimator<{name}> is initialized')

    def run(self, train_loader, validate_loader, update_fn, inference_fn, epochs):
        """Train, evaluate and checkpoint for ``epochs`` epochs.

        Args:
            train_loader: iterable of training batches.
            validate_loader: iterable of validation batches.
            update_fn: closure ``(model, data) -> (loss_dict, accuracy_dict)``.
            inference_fn: closure ``(model, data) -> loss_dict`` or
                ``(loss_dict, accuracy_dict)``.
            epochs (int): number of epochs to run.
        """
        for epoch in range(epochs):
            self.state['epoch'] = epoch

            self.train(train_loader, update_fn)
            # BUGFIX: logger is optional (defaults to None); guard before use,
            # consistent with save_checkpoint()/adjust_learning_rate().
            if self.logger is not None:
                self.logger.scalar(self.history.metric, epoch)

            self.evaluate(validate_loader, inference_fn)
            if self.logger is not None:
                self.logger.scalar(self.history.metric, epoch)

            self.save_checkpoint()
            self.adjust_learning_rate(self.history.metric['loss/loss'])
            self._log.info(f'OneEstimator<{self.name}> epoch#{epoch} end')

    def load_checkpoint(self, weight_path, resume=False):
        """Restore weights (and optionally training state) via the saver."""
        if not hasattr(self, 'saver') or self.saver is None:
            return  # saver is optional; silently skip when absent
        self.saver.load(self, weight_path, resume)

    def save_checkpoint(self):
        """Persist a checkpoint for the current epoch via the saver."""
        if not hasattr(self, 'saver') or self.saver is None:
            return  # saver is optional; silently skip when absent
        self.saver.save(self, epoch=self.state['epoch'])

    def adjust_learning_rate(self, monitor_val):
        """Step the LR scheduler with the monitored value (e.g. the loss)."""
        if not hasattr(self, 'lr_scheduler') or self.lr_scheduler is None:
            return  # scheduler is optional; silently skip when absent
        self.lr_scheduler.step(monitor_val)

    def train(self, data_loader, update_fn):
        """Run one training epoch: forward/backward/step for every batch.

        Returns:
            The accumulated epoch metrics from ``self.history``.
        """
        self.model.train()
        self.history.clear()

        progress = tqdm.tqdm(data_loader)
        progress.set_description(f'Epoch#{self.state["epoch"] + 1}')

        for data in progress:
            loss, accuracy = update_fn(self.model, data)
            progress.set_postfix(self.history.add({**loss, **accuracy}))
            self.optimizer.zero_grad()
            loss['loss/loss'].backward()
            self.optimizer.step()
        # ``metric`` is a @property (not callable); it is refreshed by
        # history.add() during the loop above.
        return self.history.metric

    def evaluate(self, data_loader, inference_fn):
        """Run one evaluation pass; accumulates ``*_val`` metrics in history."""
        self.model.eval()
        self.history.clear()

        progress = tqdm.tqdm(data_loader)
        progress.set_description('Evaluate')

        for data in progress:
            log_values = inference_fn(self.model, data)
            # inference_fn may return only losses, or (losses, accuracies)
            loss, accuracy = log_values if isinstance(log_values, tuple) else (log_values, {})
            progress.set_postfix(self.history.add({**loss, **accuracy}, log_suffix='_val'))
        # ``metric`` is a @property (not callable); refreshed by history.add().
        return self.history.metric

    def dummy_run(self, train_loader, validate_loader, update_fn, inference_fn, epoch_fn, epochs):
        """Lightweight loop: train/evaluate, then invoke ``epoch_fn`` hooks.

        ``epoch_fn`` may be a single callable or a list of callables; each is
        invoked as ``fn(epoch, history)`` after the evaluation pass.
        """
        for epoch in range(epochs):
            self.history.clear()
            self.state['epoch'] = epoch
            self.dummy_train(train_loader, update_fn)
            self.dummy_evaluate(validate_loader, inference_fn)

            if isinstance(epoch_fn, list):
                for fn in epoch_fn:
                    fn(epoch, self.history)
            elif callable(epoch_fn):
                epoch_fn(epoch, self.history)

            self._log.debug(f'OneEstimator<{self.name}> epoch#{epoch} end')

    def dummy_train(self, data_loader, update_fn):
        """Training epoch where ``update_fn`` returns a raw loss.

        ``update_fn`` may return either ``loss`` alone or ``(loss, stats)``;
        only ``stats`` is recorded into history.
        """
        self.model.train()
        progress = tqdm.tqdm(data_loader)
        progress.set_description(f'Epoch#{self.state["epoch"] + 1}')

        for data in progress:
            _stat = update_fn(self.model, data)
            loss, stat = _stat if len(_stat) == 2 else (_stat, {})
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            progress.set_postfix(self.history.add(stat))

    def dummy_evaluate(self, data_loader, inference_fn):
        """Evaluation epoch mirroring :meth:`dummy_train` (no backward pass)."""
        self.model.eval()
        progress = tqdm.tqdm(data_loader)
        for data in progress:
            _stat = inference_fn(self.model, data)
            if len(_stat) == 2:
                _, stat = _stat
            elif isinstance(_stat, dict):
                stat = _stat
            else:
                stat = {}
            progress.set_postfix(self.history.add(stat, log_suffix='_val'))
コード例 #6
0
class OneGANEstimator:
    """Estimator driving adversarial training of a (generator, discriminator) pair.

    Args:
        model: pair ``(model_g, model_d)``.
        optimizer (optional): pair ``(optim_g, optim_d)``.
        lr_scheduler (optional): pair ``(sched_g, sched_d)``.
        logger (optional): scalar logger, e.g. a TensorBoard wrapper.
        saver (optional): checkpoint persistence helper.
        name (str, optional): display name used in log messages.

    Attributes:
        history (History): training-phase statistics for the current epoch.
        history_val (History): validation-phase statistics.
        state (dict): mutable loop state; currently holds ``epoch``.
    """

    def __init__(self, model, optimizer=None, lr_scheduler=None, logger=None, saver=None, name=None):
        self.models = model
        self.schedulers = lr_scheduler
        # BUGFIX: probe len() only on non-None values -- with the default
        # arguments ``len(None)`` raised TypeError.
        self.model_g, self.model_d = model if model is not None and len(model) == 2 else (None, None)
        self.optim_g, self.optim_d = optimizer if optimizer else (None, None)
        self.saver = saver
        self.logger = logger
        self.sched_g, self.sched_d = (
            lr_scheduler if lr_scheduler is not None and len(lr_scheduler) == 2 else (None, None))
        self.name = name

        self.history = History()
        self.history_val = History()
        self.state = {}
        self._log = logging.getLogger(f'OneGAN.{name}')
        self._log.info(f'OneGANEstimator<{name}> is initialized')

    def run(self, train_loader, validate_loader, update_fn, inference_fn, epochs):
        """Train, evaluate and checkpoint the GAN pair for ``epochs`` epochs."""
        for epoch in range(epochs):
            self.state['epoch'] = epoch

            self.train(train_loader, update_fn)
            # BUGFIX: logger is optional (defaults to None); guard before use.
            # NOTE(review): the sibling OneEstimator treats History.metric as
            # a @property -- confirm it is callable for this class.
            if self.logger is not None:
                self.logger.scalar(self.history.metric(), epoch)

            self.evaluate(validate_loader, inference_fn)
            if self.logger is not None:
                self.logger.scalar(self.history.metric(), epoch)

            self.save_checkpoint()
            self.adjust_learning_rate(('loss/loss_g_val', 'loss/loss_d_val'))
            self._log.debug(f'OneEstimator<{self.name}> epoch#{epoch} end')

    def load_checkpoint(self, weight_path, resume=False):
        """Restore weights (and optionally training state) via the saver."""
        if not hasattr(self, 'saver') or self.saver is None:
            return  # saver is optional; silently skip when absent
        self.saver.load(self, weight_path, resume)

    def save_checkpoint(self):
        """Persist a checkpoint for the current epoch via the saver."""
        if not hasattr(self, 'saver') or self.saver is None:
            return  # saver is optional; silently skip when absent
        self.saver.save(self, self.state['epoch'])

    def adjust_learning_rate(self, monitor_vals):
        """Step each scheduler with its monitored history value.

        Args:
            monitor_vals: history keys, one per scheduler, in the same order
                as ``self.schedulers``.
        """
        # BUGFIX: __init__ stores the schedulers on ``self.schedulers``; the
        # old guard looked for ``self.lr_scheduler`` (never set), so this
        # method always returned early and the schedulers never stepped.
        if getattr(self, 'schedulers', None) is None:
            return
        try:
            for sched, monitor_val in zip(self.schedulers, monitor_vals):
                sched.step(self.history[monitor_val])
        except Exception:
            # Best-effort fallback for schedulers that take no metric
            # argument (kept deliberately broad, as in the original).
            for sched in self.schedulers:
                sched.step()

    def train(self, data_loader, update_fn):
        """Run one adversarial training epoch.

        ``update_fn(model_g, model_d, data)`` must be a generator yielding,
        in order: the D-loss dict, the G-loss dict, an accuracy dict, and a
        final stage that is only advanced (its value is discarded).
        """
        self.model_g.train()
        self.model_d.train()
        self.history.clear()

        progress = tqdm.tqdm(data_loader)
        progress.set_description(f'Epoch#{self.state["epoch"] + 1}')

        for data in progress:
            staged_closure = update_fn(self.model_g, self.model_d, data)

            # stage 1: discriminator update
            self.optim_d.zero_grad()
            loss_d = next(staged_closure)
            loss_d['loss/loss_d'].backward()
            self.optim_d.step()

            # stage 2: generator update
            self.optim_g.zero_grad()
            loss_g = next(staged_closure)
            loss_g['loss/loss_g'].backward()
            self.optim_g.step()

            # stage 3: metrics; stage 4: let the closure run to completion
            accuracy = next(staged_closure)
            progress.set_postfix(self.history.add({**loss_d, **loss_g, **accuracy}))
            next(staged_closure)
        return self.history.metric()

    def evaluate(self, data_loader, inference_fn):
        """Run one evaluation pass; accumulates ``*_val`` metrics in history."""
        self.model_g.eval()
        self.model_d.eval()
        self.history.clear()

        progress = tqdm.tqdm(data_loader)
        progress.set_description('Evaluate')

        for data in progress:
            staged_closure = inference_fn(self.model_g, self.model_d, data)
            # drain all four stages; no backward pass during evaluation
            loss_d, loss_g, accuracy, _ = [r for r in staged_closure]

            progress.set_postfix(self.history.add({**loss_d, **loss_g, **accuracy}, log_suffix='_val'))
        return self.history.metric()

    def dummy_run(self, train_loader, validate_loader, update_fn, inference_fn, epoch_fn, epochs):
        """Lightweight loop: train/evaluate then invoke ``epoch_fn(epoch)``."""
        for epoch in range(epochs):
            self.state['epoch'] = epoch
            self.dummy_train(train_loader, update_fn)
            self.dummy_evaluate(validate_loader, inference_fn)
            epoch_fn(epoch)
            self._log.debug(f'OneEstimator<{self.name}> epoch#{epoch} end')

    def dummy_train(self, data_loader, update_fn):
        """Training epoch driven by a generic staged closure.

        Each yielded stage is either ``(loss_dict, (optimizer, key))`` --
        triggering backward/step on ``loss_dict[key]`` -- or a plain dict of
        metrics to record.
        """
        [m.train() for m in self.models]
        self.history.clear()

        progress = tqdm.tqdm(data_loader)
        progress.set_description(f'Epoch#{self.state["epoch"] + 1}')

        for data in progress:
            stat = {}
            for staged_closure in update_fn(self.models, data):
                if isinstance(staged_closure, tuple):
                    loss, (optim, key_loss) = staged_closure
                    optim.zero_grad()
                    loss[key_loss].backward()
                    optim.step()
                    stat.update(loss)
                elif isinstance(staged_closure, dict):
                    accuracy = staged_closure
                    stat.update(accuracy)
            progress.set_postfix(self.history.add(stat))

    def dummy_evaluate(self, data_loader, update_fn):
        """Evaluation epoch mirroring :meth:`dummy_train` (no backward pass)."""
        [m.eval() for m in self.models]
        self.history_val.clear()

        progress = tqdm.tqdm(data_loader)
        progress.set_description(f'Epoch#{self.state["epoch"] + 1}')

        for data in progress:
            stat = {}
            for staged_closure in update_fn(self.models, data):
                if isinstance(staged_closure, tuple):
                    loss, _ = staged_closure
                    stat.update(loss)
                elif isinstance(staged_closure, dict):
                    accuracy = staged_closure
                    stat.update(accuracy)
            progress.set_postfix(self.history_val.add(stat, log_suffix='_val'))
コード例 #7
0
class OneEstimator(EstimatorEventMixin, Estimator):
    r""" Estimator for network training and evaluation.

    Args:
        model (torch.nn.Module): defined model for estimator.
        optimizer (torch.optim, optional): optimizer for model training.
        lr_scheduler (torch.optim.lr_scheduler, optional): learning rate scheduler for
            model training.
        logger (extension.TensorBoardLogger, optional): training state logger (default: None).
        saver (extension.Checkpoint, optional): checkpoint persistence (default: None).
        default_handlers (bool): turn on/off the default handlers (default: False).

    Attributes:
        history (extension.History): internal statistics of training state.
    """
    def __init__(self,
                 model,
                 optimizer=None,
                 lr_scheduler=None,
                 logger=None,
                 saver=None,
                 default_handlers=False):
        self.model = model

        # can leave empty (e.g. inference-only usage)
        self.optimizer = optimizer

        # optional collaborators
        self.lr_scheduler = lr_scheduler
        self.saver = saver
        self.logger = logger

        # internal state
        self.history = History()
        self.state = AttrDict(epoch=0)
        self._events = defaultdict(list)     # event -> list of handlers
        self._hist_dict = defaultdict(list)  # tag -> tensors accumulated for histogram logging
        self._log = logging.getLogger('onegan.OneEstimator')

        if default_handlers:
            self.add_default_event_handlers()
        self._log.info(f'OneEstimator is initialized')

    def add_default_event_handlers(self):
        ''' register the built-in logging, checkpoint and LR-scheduling
        handlers on the iteration-end / epoch-end events. '''
        self.add_event_handler(Events.ITERATION_END, iteration_end_logging)
        self.add_event_handler(Events.EPOCH_END, epoch_end_logging)
        self.add_event_handler(Events.EPOCH_END, save_checkpoint)
        self.add_event_handler(Events.EPOCH_END, adjust_learning_rate)

    def tensorboard_logging(self, image=None, histogram=None, prefix=None):
        ''' wrapper in estimator for Tensorboard logger.

        Images are logged immediately; histograms are only accumulated here
        and flushed later by :meth:`tensorboard_epoch_logging`.
        No-op when no logger is attached.

        Args:
            image: dict() of a list of images
            histogram: dict() of tensors for accumulated histogram
            prefix: prefix string for keyword-image
        '''
        if not hasattr(self, 'logger') or self.logger is None:
            return

        if image and prefix:
            self.logger.image(image, self.state.epoch, prefix)
            self._log.debug('tensorboard_logging logs images')

        if histogram and prefix:
            # clone so later in-place changes to the tensors don't alter
            # the accumulated snapshots
            for tag, tensor in histogram.items():
                self._hist_dict[f'{prefix}{tag}'].append(tensor.clone())
            self._log.debug('tensorboard_logging accumulate histograms')

    def tensorboard_epoch_logging(self, scalar=None):
        ''' wrapper in estimator for Tensorboard logger.

        Logs the given scalars for the current epoch and flushes any
        histograms accumulated by :meth:`tensorboard_logging`.
        No-op when no logger is attached.

        Args:
            scalar: dict() of a list of scalars
        '''
        if not hasattr(self, 'logger') or self.logger is None:
            return

        self.logger.scalar(scalar, self.state.epoch)
        self._log.debug('tensorboard_epoch_logging logs scalars')

        if self._hist_dict:
            # concatenate the per-iteration snapshots into one tensor per tag
            kw_histograms = {
                tag: torch.cat(tensors)
                for tag, tensors in self._hist_dict.items()
            }
            self.logger.histogram(kw_histograms, self.state.epoch)
            # reset the accumulator for the next epoch
            self._hist_dict = defaultdict(list)
            self._log.debug('tensorboard_epoch_logging logs histograms')

    def run(self,
            train_loader,
            validate_loader,
            closure_fn,
            epochs,
            longtime_pbar=False):
        ''' run the training/evaluation loop for ``epochs`` epochs.

        Fires EPOCH_START/EPOCH_END events around each epoch; the same
        ``closure_fn`` is used for both training and evaluation.

        Args:
            train_loader: iterable of training batches.
            validate_loader: iterable of validation batches.
            closure_fn: callable ``(model, data) -> dict`` (see train/evaluate
                for the expected keys).
            epochs (int): number of epochs to run.
            longtime_pbar (bool): use one outer progress bar over epochs and
                non-persistent inner bars.
        '''
        epoch_range = tqdm.trange(
            epochs,
            desc='Training Procedure') if longtime_pbar else range(epochs)

        for epoch in epoch_range:
            self.history.clear()
            self.state.epoch = epoch
            self._trigger(Events.EPOCH_START)

            self.train(train_loader, closure_fn, longtime_pbar)
            self.evaluate(validate_loader, closure_fn, longtime_pbar)

            self._trigger(Events.EPOCH_END)
            self._log.debug(f'OneEstimator epoch#{epoch} end')

    def train(self, data_loader, update_fn, longtime_pbar=False):
        ''' one training epoch: forward/backward/step for every batch,
        firing ITERATION_START/ITERATION_END around each batch.

        ``update_fn(model, data)`` must return a dict containing at least
        ``loss`` (backpropagated) and ``status`` (recorded into history);
        any remaining keys are merged into ``self.state``.
        '''
        self.model.train()
        progress = tqdm.tqdm(data_loader,
                             desc=f'Epoch#{self.state.epoch + 1}',
                             leave=not longtime_pbar)

        for data in progress:
            self._trigger(Events.ITERATION_START)

            result = update_fn(self.model, data)
            # `loss`, `status` should be in result (dict)

            loss = result.pop('loss')
            # NOTE(review): `assert` is stripped under -O; consider raising
            # instead for hard contract checks.
            assert loss, 'Returned result from closure must contain key `loss` to backward()'
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            status = result.pop('status')
            assert status, 'Returned result from closure must contain key `status` for history'
            current_status = self.history.update(status)

            progress.set_postfix(current_status)
            # leftover keys from the closure are kept on the estimator state
            self.state.update(result)

            self._trigger(Events.ITERATION_END)

    def evaluate(self, data_loader, inference_fn, longtime_pbar=False):
        ''' one evaluation pass under ``torch.no_grad()``.

        ``inference_fn(model, data)`` must return a dict containing at least
        ``status`` (recorded into history with the ``_val`` suffix); any
        remaining keys are merged into ``self.state``.
        '''
        self.model.eval()
        progress = tqdm.tqdm(data_loader,
                             desc='evaluating',
                             leave=not longtime_pbar)

        with torch.no_grad():
            for data in progress:
                self._trigger(Events.ITERATION_START)

                result = inference_fn(self.model, data)
                # `status` should be in result (dict)

                status = result.pop('status')
                assert status, 'Returned result from closure must contain key `status` for history'
                current_status = self.history.update(status, log_suffix='_val')

                progress.set_postfix(current_status)
                self.state.update(result)

                self._trigger(Events.ITERATION_END)