Example #1
import logging

import tqdm

# History is assumed to be importable from the surrounding project
# (a running-metric accumulator feeding the progress bar and logger).


class OneGANEstimator:

    def __init__(self, model, optimizer=None, lr_scheduler=None, logger=None, saver=None, name=None):
        self.models = model
        self.schedulers = lr_scheduler
        # The estimator expects (generator, discriminator) pairs; fall back to
        # None placeholders when a matching pair is not provided.
        self.model_g, self.model_d = model if len(model) == 2 else (None, None)
        self.optim_g, self.optim_d = (optimizer
                                      if optimizer and len(optimizer) == 2
                                      else (None, None))
        self.sched_g, self.sched_d = (lr_scheduler
                                      if lr_scheduler and len(lr_scheduler) == 2
                                      else (None, None))
        self.saver = saver
        self.logger = logger
        self.name = name

        self.history = History()
        self.history_val = History()
        self.state = {}
        self._log = logging.getLogger(f'OneGAN.{name}')
        self._log.info(f'OneGANEstimator<{name}> is initialized')

    def run(self, train_loader, validate_loader, update_fn, inference_fn, epochs):
        for epoch in range(epochs):
            self.state['epoch'] = epoch

            self.train(train_loader, update_fn)
            if self.logger is not None:
                self.logger.scalar(self.history.metric(), epoch)

            self.evaluate(validate_loader, inference_fn)
            if self.logger is not None:
                self.logger.scalar(self.history.metric(), epoch)

            self.save_checkpoint()
            self.adjust_learning_rate(('loss/loss_g_val', 'loss/loss_d_val'))
            self._log.debug(f'OneGANEstimator<{self.name}> epoch#{epoch} end')

    def load_checkpoint(self, weight_path, resume=False):
        if self.saver is None:
            return
        self.saver.load(self, weight_path, resume)

    def save_checkpoint(self):
        if self.saver is None:
            return
        self.saver.save(self, self.state['epoch'])

    def adjust_learning_rate(self, monitor_vals):
        # __init__ stores the scheduler pair as self.schedulers.
        if self.schedulers is None:
            return
        for sched, monitor_val in zip(self.schedulers, monitor_vals):
            try:
                # Plateau-style schedulers consume the monitored metric.
                sched.step(self.history[monitor_val])
            except TypeError:
                # Plain schedulers take no metric argument.
                sched.step()

    def train(self, data_loader, update_fn):
        self.model_g.train()
        self.model_d.train()
        self.history.clear()

        progress = tqdm.tqdm(data_loader)
        progress.set_description(f'Epoch#{self.state["epoch"] + 1}')

        for data in progress:
            # update_fn is a generator yielding four stages: discriminator
            # losses, generator losses, metrics, and one trailing yield.
            staged_closure = update_fn(self.model_g, self.model_d, data)

            # Stage 1: update the discriminator on 'loss/loss_d'.
            self.optim_d.zero_grad()
            loss_d = next(staged_closure)
            loss_d['loss/loss_d'].backward()
            self.optim_d.step()

            # Stage 2: update the generator on 'loss/loss_g'.
            self.optim_g.zero_grad()
            loss_g = next(staged_closure)
            loss_g['loss/loss_g'].backward()
            self.optim_g.step()

            # Stage 3: collect metrics for the progress bar and History;
            # stage 4 lets the generator run to completion.
            accuracy = next(staged_closure)
            progress.set_postfix(self.history.add({**loss_d, **loss_g, **accuracy}))
            next(staged_closure)
        return self.history.metric()

    def evaluate(self, data_loader, inference_fn):
        self.model_g.eval()
        self.model_d.eval()
        self.history.clear()

        progress = tqdm.tqdm(data_loader)
        progress.set_description('Evaluate')

        for data in progress:
            staged_closure = inference_fn(self.model_g, self.model_d, data)
            # No backward pass here: just unpack the four yielded stages.
            loss_d, loss_g, accuracy, _ = staged_closure

            progress.set_postfix(self.history.add({**loss_d, **loss_g, **accuracy}, log_suffix='_val'))
        return self.history.metric()

    def dummy_run(self, train_loader, validate_loader, update_fn, inference_fn, epoch_fn, epochs):
        for epoch in range(epochs):
            self.state['epoch'] = epoch
            self.dummy_train(train_loader, update_fn)
            self.dummy_evaluate(validate_loader, inference_fn)
            epoch_fn(epoch)
            self._log.debug(f'OneGANEstimator<{self.name}> epoch#{epoch} end')

    def dummy_train(self, data_loader, update_fn):
        for m in self.models:
            m.train()
        self.history.clear()

        progress = tqdm.tqdm(data_loader)
        progress.set_description(f'Epoch#{self.state["epoch"] + 1}')

        for data in progress:
            stat = {}
            for staged_closure in update_fn(self.models, data):
                if isinstance(staged_closure, tuple):
                    # (losses, (optimizer, key)): step that optimizer on losses[key].
                    loss, (optim, key_loss) = staged_closure
                    optim.zero_grad()
                    loss[key_loss].backward()
                    optim.step()
                    stat.update(loss)
                elif isinstance(staged_closure, dict):
                    # A bare dict carries metrics only.
                    stat.update(staged_closure)
            progress.set_postfix(self.history.add(stat))

    def dummy_evaluate(self, data_loader, update_fn):
        for m in self.models:
            m.eval()
        self.history_val.clear()

        progress = tqdm.tqdm(data_loader)
        progress.set_description(f'Epoch#{self.state["epoch"] + 1}')

        for data in progress:
            stat = {}
            for staged_closure in update_fn(self.models, data):
                if isinstance(staged_closure, tuple):
                    # Evaluation keeps the losses but never steps an optimizer.
                    loss, _ = staged_closure
                    stat.update(loss)
                elif isinstance(staged_closure, dict):
                    stat.update(staged_closure)
            progress.set_postfix(self.history_val.add(stat, log_suffix='_val'))
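
A minimal sketch of how this estimator might be driven; it is an illustration,
not part of the original example. gan_update_fn follows the four-stage
generator protocol that train and evaluate consume, and d_criterion,
g_criterion, psnr, the models, loaders, optimizers, and schedulers are all
assumed names.

def gan_update_fn(model_g, model_d, data):
    # d_criterion, g_criterion, and psnr are hypothetical helpers.
    source, target = data
    fake = model_g(source)
    # Stage 1: discriminator losses; train() backpropagates 'loss/loss_d'.
    yield {'loss/loss_d': d_criterion(model_d, fake.detach(), target)}
    # Stage 2: generator losses; train() backpropagates 'loss/loss_g'.
    yield {'loss/loss_g': g_criterion(model_d, fake, target)}
    # Stage 3: metrics for the progress bar and History.
    yield {'acc/psnr': psnr(fake, target)}
    # Stage 4: trailing yield so the final next() in train() succeeds.
    yield None

estimator = OneGANEstimator((model_g, model_d),
                            optimizer=(optim_g, optim_d),
                            lr_scheduler=(sched_g, sched_d),
                            logger=logger, saver=saver, name='gan-demo')
# The same closure serves as inference_fn: evaluate() unpacks the same stages.
estimator.run(train_loader, validate_loader, gan_update_fn, gan_update_fn,
              epochs=30)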
Example #2
import logging

import tqdm

# History is assumed to be importable from the surrounding project.


class OneEstimator:

    def __init__(self, model, optimizer=None, lr_scheduler=None, logger=None, saver=None, name=None):
        self.model = model
        self.optimizer = optimizer
        self.saver = saver
        self.logger = logger
        self.lr_scheduler = lr_scheduler
        self.name = name

        self.history = History()
        self.state = {}
        self._log = logging.getLogger(f'OneGAN.{name}')
        self._log.info(f'OneEstimator<{name}> is initialized')

    def run(self, train_loader, validate_loader, update_fn, inference_fn, epochs):
        for epoch in range(epochs):
            self.state['epoch'] = epoch

            self.train(train_loader, update_fn)
            if self.logger is not None:
                self.logger.scalar(self.history.metric(), epoch)

            self.evaluate(validate_loader, inference_fn)
            if self.logger is not None:
                self.logger.scalar(self.history.metric(), epoch)

            self.save_checkpoint()
            self.adjust_learning_rate(self.history.metric()['loss/loss_val'])
            self._log.info(f'OneEstimator<{self.name}> epoch#{epoch} end')

    def load_checkpoint(self, weight_path, resume=False):
        if self.saver is None:
            return
        self.saver.load(self, weight_path, resume)

    def save_checkpoint(self):
        if self.saver is None:
            return
        self.saver.save(self, self.state['epoch'])

    def adjust_learning_rate(self, monitor_val):
        if self.lr_scheduler is None:
            return
        # Expects a plateau-style scheduler whose step() takes the monitored value.
        self.lr_scheduler.step(monitor_val)

    def train(self, data_loader, update_fn):
        self.model.train()
        self.history.clear()

        progress = tqdm.tqdm(data_loader)
        progress.set_description(f'Epoch#{self.state["epoch"] + 1}')

        for data in progress:
            loss, accuracy = update_fn(self.model, data)
            progress.set_postfix(self.history.add({**loss, **accuracy}))
            self.optimizer.zero_grad()
            loss['loss/loss'].backward()
            self.optimizer.step()
        return self.history.metric()

    def evaluate(self, data_loader, inference_fn):
        self.model.eval()
        self.history.clear()

        progress = tqdm.tqdm(data_loader)
        progress.set_description('Evaluate')

        for data in progress:
            log_values = inference_fn(self.model, data)
            # inference_fn may return (losses, accuracies) or a bare loss dict.
            loss, accuracy = log_values if isinstance(log_values, tuple) else (log_values, {})
            progress.set_postfix(self.history.add({**loss, **accuracy}, log_suffix='_val'))
        return self.history.metric()

    def dummy_run(self, train_loader, validate_loader, update_fn, inference_fn, epoch_fn, epochs):
        for epoch in range(epochs):
            self.history.clear()
            self.state['epoch'] = epoch
            self.dummy_train(train_loader, update_fn)
            self.dummy_evaluate(validate_loader, inference_fn)

            if isinstance(epoch_fn, list):
                for fn in epoch_fn:
                    fn(epoch, self.history)
            elif callable(epoch_fn):
                epoch_fn(epoch, self.history)

            self._log.debug(f'OneEstimator<{self.name}> epoch#{epoch} end')

    def dummy_train(self, data_loader, update_fn):
        self.model.train()
        progress = tqdm.tqdm(data_loader)
        progress.set_description(f'Epoch#{self.state["epoch"] + 1}')

        for data in progress:
            _stat = update_fn(self.model, data)
            # update_fn may return a (loss, metrics) tuple or just the loss tensor.
            loss, stat = _stat if isinstance(_stat, tuple) and len(_stat) == 2 else (_stat, {})
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            progress.set_postfix(self.history.add(stat))

    def dummy_evaluate(self, data_loader, inference_fn):
        self.model.eval()
        progress = tqdm.tqdm(data_loader)
        for data in progress:
            _stat = inference_fn(self.model, data)
            # Accept a (loss, metrics) tuple or a bare metric dict; check the
            # tuple case first so a two-item dict is not unpacked by mistake.
            if isinstance(_stat, tuple) and len(_stat) == 2:
                _, stat = _stat
            elif isinstance(_stat, dict):
                stat = _stat
            else:
                stat = {}
            progress.set_postfix(self.history.add(stat, log_suffix='_val'))
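
A minimal sketch of how this estimator might be driven; it is an illustration,
not part of the original example. criterion, accuracy_fn, the model, loaders,
optimizer, and scheduler are all assumed names.

def update_fn(model, data):
    # criterion and accuracy_fn are hypothetical helpers.
    source, target = data
    output = model(source)
    # train() backpropagates the 'loss/loss' entry; the second dict holds
    # metrics that only feed the progress bar and History.
    return ({'loss/loss': criterion(output, target)},
            {'acc/accuracy': accuracy_fn(output, target)})

estimator = OneEstimator(model, optimizer=optimizer, lr_scheduler=scheduler,
                         logger=logger, saver=saver, name='baseline')
# update_fn doubles as inference_fn: evaluate() accepts the same tuple shape.
estimator.run(train_loader, validate_loader, update_fn, update_fn, epochs=30)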