class OneGANEstimator:
    """Training/evaluation driver for a GAN: a (generator, discriminator)
    pair with paired optimizers and, optionally, paired LR schedulers.

    Args:
        model: sequence of two modules, ``(generator, discriminator)``.
        optimizer: optional pair ``(generator_optim, discriminator_optim)``.
        lr_scheduler: optional pair of LR schedulers, same order as ``model``.
        logger: object exposing ``scalar(metrics, epoch)`` — presumably a
            tensorboard-style writer; confirm against caller.
        saver: object exposing ``save(estimator, epoch)`` and
            ``load(estimator, path, resume)`` for checkpointing.
        name: human-readable name used in log messages.
    """

    def __init__(self, model, optimizer=None, lr_scheduler=None, logger=None, saver=None, name=None):
        self.models = model
        self.schedulers = lr_scheduler
        self.model_g, self.model_d = model if len(model) == 2 else (None, None)
        self.optim_g, self.optim_d = optimizer if optimizer else (None, None)
        self.saver = saver
        self.logger = logger
        # BUG FIX: the original evaluated `len(lr_scheduler)` unconditionally,
        # which raises TypeError when lr_scheduler is left at its default None.
        if lr_scheduler is not None and len(lr_scheduler) == 2:
            self.sched_g, self.sched_d = lr_scheduler
        else:
            self.sched_g, self.sched_d = None, None
        self.name = name
        self.history = History()
        self.history_val = History()
        self.state = {}
        self._log = logging.getLogger(f'OneGAN.{name}')
        self._log.info(f'OneGANEstimator<{name}> is initialized')

    def run(self, train_loader, validate_loader, update_fn, inference_fn, epochs):
        """Full loop: each epoch trains, evaluates, checkpoints, and steps the
        schedulers on the monitored validation losses."""
        for epoch in range(epochs):
            self.state['epoch'] = epoch
            self.train(train_loader, update_fn)
            self.logger.scalar(self.history.metric(), epoch)
            self.evaluate(validate_loader, inference_fn)
            self.logger.scalar(self.history.metric(), epoch)
            self.save_checkpoint()
            self.adjust_learning_rate(('loss/loss_g_val', 'loss/loss_d_val'))
            # BUG FIX: message previously named the wrong class ('OneEstimator').
            self._log.debug(f'OneGANEstimator<{self.name}> epoch#{epoch} end')

    def load_checkpoint(self, weight_path, resume=False):
        """Restore weights (and training state when ``resume``) via the saver;
        no-op when no saver was configured."""
        if not hasattr(self, 'saver') or self.saver is None:
            return
        self.saver.load(self, weight_path, resume)

    def save_checkpoint(self):
        """Persist the current epoch's state via the saver; no-op without one."""
        if not hasattr(self, 'saver') or self.saver is None:
            return
        self.saver.save(self, self.state['epoch'])

    def adjust_learning_rate(self, monitor_vals):
        """Step each scheduler, feeding plateau-style schedulers the monitored
        history value keyed by the matching entry of ``monitor_vals``.

        BUG FIX: the original guarded on ``hasattr(self, 'lr_scheduler')``, an
        attribute this class never sets (``__init__`` stores
        ``self.schedulers``), so the schedulers were silently never stepped.
        """
        if self.schedulers is None:
            return
        try:
            for sched, monitor_val in zip(self.schedulers, monitor_vals):
                sched.step(self.history[monitor_val])
        except (TypeError, KeyError):
            # Scheduler's step() takes no metric argument, or the monitored
            # value is absent from the history: plain argument-less step.
            for sched in self.schedulers:
                sched.step()

    def train(self, data_loader, update_fn):
        """One training epoch.

        ``update_fn(model_g, model_d, data)`` must be a generator yielding, in
        order: discriminator losses (dict), generator losses (dict), accuracy
        metrics (dict); a fourth ``next()`` lets it run trailing side effects.
        """
        self.model_g.train()
        self.model_d.train()
        self.history.clear()
        progress = tqdm.tqdm(data_loader)
        progress.set_description(f'Epoch#{self.state["epoch"] + 1}')
        for data in progress:
            staged_closure = update_fn(self.model_g, self.model_d, data)
            # Stage 1: discriminator update.
            self.optim_d.zero_grad()
            loss_d = next(staged_closure)
            loss_d['loss/loss_d'].backward()
            self.optim_d.step()
            # Stage 2: generator update.
            self.optim_g.zero_grad()
            loss_g = next(staged_closure)
            loss_g['loss/loss_g'].backward()
            self.optim_g.step()
            # Stage 3: metrics for the history / progress bar.
            accuracy = next(staged_closure)
            progress.set_postfix(self.history.add({**loss_d, **loss_g, **accuracy}))
            # Stage 4: drive the generator to its final stage (side effects only).
            next(staged_closure)
        return self.history.metric()

    def evaluate(self, data_loader, inference_fn):
        """One validation pass; metrics are recorded with a ``_val`` suffix.

        ``inference_fn`` must yield exactly four items: discriminator losses,
        generator losses, accuracy metrics, and a trailing sentinel.
        """
        self.model_g.eval()
        self.model_d.eval()
        self.history.clear()
        progress = tqdm.tqdm(data_loader)
        progress.set_description('Evaluate')
        for data in progress:
            staged_closure = inference_fn(self.model_g, self.model_d, data)
            loss_d, loss_g, accuracy, _ = list(staged_closure)
            progress.set_postfix(self.history.add({**loss_d, **loss_g, **accuracy}, log_suffix='_val'))
        return self.history.metric()

    def dummy_run(self, train_loader, validate_loader, update_fn, inference_fn, epoch_fn, epochs):
        """Alternative loop that delegates per-stage optimization to the
        closures produced by ``update_fn`` (see ``dummy_train``); ``epoch_fn``
        is invoked with the epoch index after each epoch."""
        for epoch in range(epochs):
            self.state['epoch'] = epoch
            self.dummy_train(train_loader, update_fn)
            self.dummy_evaluate(validate_loader, inference_fn)
            epoch_fn(epoch)
            # BUG FIX: message previously named the wrong class ('OneEstimator').
            self._log.debug(f'OneGANEstimator<{self.name}> epoch#{epoch} end')

    def dummy_train(self, data_loader, update_fn):
        """One training epoch driven by staged closures.

        ``update_fn(models, data)`` yields either an optimization stage
        ``(loss_dict, (optimizer, loss_key))`` or a plain metrics dict.
        """
        for m in self.models:
            m.train()
        self.history.clear()
        progress = tqdm.tqdm(data_loader)
        progress.set_description(f'Epoch#{self.state["epoch"] + 1}')
        for data in progress:
            stat = {}
            for staged_closure in update_fn(self.models, data):
                if isinstance(staged_closure, tuple):
                    loss, (optim, key_loss) = staged_closure
                    optim.zero_grad()
                    loss[key_loss].backward()
                    optim.step()
                    stat.update(loss)
                elif isinstance(staged_closure, dict):
                    stat.update(staged_closure)
            progress.set_postfix(self.history.add(stat))

    def dummy_evaluate(self, data_loader, update_fn):
        """One validation pass for the staged-closure protocol; losses and
        metrics are recorded (``_val`` suffix) but no optimizer is stepped."""
        for m in self.models:
            m.eval()
        self.history_val.clear()
        progress = tqdm.tqdm(data_loader)
        progress.set_description(f'Epoch#{self.state["epoch"] + 1}')
        for data in progress:
            stat = {}
            for staged_closure in update_fn(self.models, data):
                if isinstance(staged_closure, tuple):
                    loss, _ = staged_closure
                    stat.update(loss)
                elif isinstance(staged_closure, dict):
                    stat.update(staged_closure)
            progress.set_postfix(self.history_val.add(stat, log_suffix='_val'))
class OneEstimator:
    """Training/evaluation driver for a single model with one optimizer and an
    optional LR scheduler.

    Args:
        model: the module to train.
        optimizer: optimizer for ``model``'s parameters.
        lr_scheduler: optional scheduler stepped with the monitored value.
        logger: object exposing ``scalar(metrics, epoch)``.
        saver: object exposing ``save``/``load`` for checkpointing.
        name: human-readable name used in log messages.
    """

    def __init__(self, model, optimizer=None, lr_scheduler=None, logger=None, saver=None, name=None):
        self.model = model
        self.optimizer = optimizer
        self.saver = saver
        self.logger = logger
        self.lr_scheduler = lr_scheduler
        self.name = name
        self.history = History()
        self.state = {}
        self._log = logging.getLogger(f'OneGAN.{name}')
        self._log.info(f'OneEstimator<{name}> is initialized')

    def run(self, train_loader, validate_loader, update_fn, inference_fn, epochs):
        """Full loop: each epoch trains, logs, evaluates, checkpoints, and
        adjusts the learning rate on the monitored validation loss."""
        for epoch in range(epochs):
            self.state['epoch'] = epoch
            self.train(train_loader, update_fn)
            self.logger.scalar(self.history.metric(), epoch)
            self.evaluate(validate_loader, inference_fn)
            self.logger.scalar(self.history.metric(), epoch)
            self.save_checkpoint()
            self.adjust_learning_rate(self.history.metric()['loss/loss_val'])
            self._log.info(f'OneEstimator<{self.name}> epoch#{epoch} end')

    def load_checkpoint(self, weight_path, resume=False):
        """Restore weights (and training state when ``resume``) via the saver;
        no-op when no saver was configured."""
        if not hasattr(self, 'saver') or self.saver is None:
            return
        self.saver.load(self, weight_path, resume)

    def save_checkpoint(self):
        """Persist the current epoch's state via the saver; no-op without one."""
        if not hasattr(self, 'saver') or self.saver is None:
            return
        self.saver.save(self, self.state['epoch'])

    def adjust_learning_rate(self, monitor_val):
        """Step the scheduler with the monitored value (ReduceLROnPlateau-style
        contract); no-op when no scheduler was configured."""
        if not hasattr(self, 'lr_scheduler') or self.lr_scheduler is None:
            return
        self.lr_scheduler.step(monitor_val)

    def train(self, data_loader, update_fn):
        """One training epoch.

        ``update_fn(model, data)`` returns ``(loss_dict, accuracy_dict)``;
        the total loss to back-propagate lives under key ``'loss/loss'``.
        """
        self.model.train()
        self.history.clear()
        progress = tqdm.tqdm(data_loader)
        progress.set_description(f'Epoch#{self.state["epoch"] + 1}')
        for data in progress:
            loss, accuracy = update_fn(self.model, data)
            progress.set_postfix(self.history.add({**loss, **accuracy}))
            self.optimizer.zero_grad()
            loss['loss/loss'].backward()
            self.optimizer.step()
        return self.history.metric()

    def evaluate(self, data_loader, inference_fn):
        """One validation pass; metrics are recorded with a ``_val`` suffix.

        ``inference_fn`` may return either ``(loss_dict, accuracy_dict)`` or a
        bare loss dict.
        """
        self.model.eval()
        self.history.clear()
        progress = tqdm.tqdm(data_loader)
        progress.set_description('Evaluate')
        for data in progress:
            log_values = inference_fn(self.model, data)
            loss, accuracy = log_values if isinstance(log_values, tuple) else (log_values, {})
            progress.set_postfix(self.history.add({**loss, **accuracy}, log_suffix='_val'))
        return self.history.metric()

    def dummy_run(self, train_loader, validate_loader, update_fn, inference_fn, epoch_fn, epochs):
        """Lightweight loop without logger/saver; ``epoch_fn`` (a callable or a
        list of callables) is invoked with ``(epoch, history)`` each epoch."""
        for epoch in range(epochs):
            self.history.clear()
            self.state['epoch'] = epoch
            self.dummy_train(train_loader, update_fn)
            self.dummy_evaluate(validate_loader, inference_fn)
            if isinstance(epoch_fn, list):
                for fn in epoch_fn:
                    fn(epoch, self.history)
            elif callable(epoch_fn):
                epoch_fn(epoch, self.history)
            self._log.debug(f'OneEstimator<{self.name}> epoch#{epoch} end')

    def dummy_train(self, data_loader, update_fn):
        """One training epoch for the lightweight loop.

        ``update_fn(model, data)`` returns either a bare loss tensor or a
        ``(loss, stat_dict)`` pair.
        """
        self.model.train()
        progress = tqdm.tqdm(data_loader)
        progress.set_description(f'Epoch#{self.state["epoch"] + 1}')
        for data in progress:
            _stat = update_fn(self.model, data)
            # BUG FIX: the original tested `len(_stat) == 2`, which misfires on
            # sized non-tuple returns (e.g. a length-2 tensor) and raises
            # TypeError on unsized ones; test for the tuple shape explicitly.
            if isinstance(_stat, tuple) and len(_stat) == 2:
                loss, stat = _stat
            else:
                loss, stat = _stat, {}
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            progress.set_postfix(self.history.add(stat))

    def dummy_evaluate(self, data_loader, inference_fn):
        """One validation pass for the lightweight loop; metrics are recorded
        with a ``_val`` suffix and no optimizer is stepped."""
        self.model.eval()
        progress = tqdm.tqdm(data_loader)
        # Consistency: siblings label their progress bars; evaluate uses the
        # same description as OneEstimator.evaluate.
        progress.set_description('Evaluate')
        for data in progress:
            _stat = inference_fn(self.model, data)
            # BUG FIX: the original checked `len(_stat) == 2` BEFORE the dict
            # check, so a stat dict with exactly two keys was tuple-unpacked
            # into its keys; check for the tuple shape first.
            if isinstance(_stat, tuple) and len(_stat) == 2:
                _, stat = _stat
            elif isinstance(_stat, dict):
                stat = _stat
            else:
                stat = {}
            progress.set_postfix(self.history.add(stat, log_suffix='_val'))