def export(model, filename_or_stream='export.pkl', pickle_module=pickle, pickle_protocol=2):
    """Export the content of `model` without the items and the optimizer state for inference.

    Returns immediately (writing nothing) when `rank_distrib()` is non-zero, so
    only the main process of a distributed run exports. Before saving, the data
    loaders are swapped for `model.dls.new_empty()` and the optimizer is
    dropped; both are restored afterwards, even if saving raises.

    Args:
        model: the learner to export; must provide `path`, `dls`, `opt`,
            `_end_cleanup()` and `create_opt()`.
        filename_or_stream: when `is_pathlike` accepts it, a file name joined
            onto `model.path` and opened here (and closed here); otherwise a
            writable binary stream owned by the caller, which is left open.
        pickle_module: forwarded to `torch.save`.
        pickle_protocol: forwarded to `torch.save`.
    """
    import torch
    from fastai.torch_core import rank_distrib
    if rank_distrib(): return  # don't export if child proc
    model._end_cleanup()
    old_dbunch = model.dls
    model.dls = model.dls.new_empty()
    state = model.opt.state_dict() if model.opt is not None else None
    model.opt = None
    try:
        with warnings.catch_warnings():
            # To avoid the warning that comes from PyTorch about the model not being checked
            warnings.simplefilter("ignore")
            if is_pathlike(filename_or_stream):
                # We opened the file, so we must close it; closing also flushes the pickle.
                with open(model.path / filename_or_stream, 'wb') as target:
                    torch.save(model, target, pickle_module=pickle_module,
                               pickle_protocol=pickle_protocol)
            else:
                # Caller-owned stream: write into it but leave it open.
                torch.save(model, filename_or_stream, pickle_module=pickle_module,
                           pickle_protocol=pickle_protocol)
    finally:
        # Always put the learner back in a usable state, even if saving failed.
        model.create_opt()
        if state is not None: model.opt.load_state_dict(state)
        model.dls = old_dbunch
def before_fit(self):
    "Decide whether this fit should be logged and, if so, log the learner's hyper-parameters."
    # Only a real training run on rank 0 is logged: bail out when the learner
    # carries `lr_finder` or this callback carries `gather_preds`.
    self.run = not (hasattr(self.learn, 'lr_finder') or hasattr(self, "gather_preds")) \
        and rank_distrib() == 0
    if self.run:
        hparams = self.learn.gather_args()
        self.logger.log_hyperparams(hparams)
def before_fit(self):
    "On the main process of a training fit, open the TensorBoard writer and prepare train-metric tracking."
    # rank_distrib() is checked first, then the lr_finder / gather_preds exclusions.
    self.run = (rank_distrib() == 0) and not (
        hasattr(self.learn, 'lr_finder') or hasattr(self, "gather_preds"))
    if not self.run:
        return
    self.writer = SummaryWriter(log_dir=self.log_dir)
    # Default to the loss function when no explicit train metrics were given.
    self.train_metrics = listify(ifnone(self.train_metrics, self.learn.loss_func))
    default_names = [self.get_name(metric) for metric in self.train_metrics]
    self.train_metric_names = listify(ifnone(self.train_metric_names, default_names))
    self.train_metric_names = ['train_' + label for label in self.train_metric_names]
    # One (0, 0) smoothing entry per prefixed metric name.
    self.smooth_dict = dict.fromkeys(self.train_metric_names, (0, 0))
def on_batch_end(self, **kwargs) -> None:
    "Write the batch loss to TensorBoard under 'Train/loss' for training batches on the main process."
    # In an initialized distributed run, only rank 0 writes.
    is_secondary_rank = dist.is_initialized() and rank_distrib() != 0
    if is_secondary_rank or not kwargs['train']:
        return
    loss_value = kwargs['last_loss'].item()
    self.tb_writer.add_scalar('Train/loss', loss_value, kwargs['iteration'])
def before_fit(self):
    """Enable this callback only for a real training fit on rank 0, then set it up.

    `self.run` is True only when the learner has no `lr_finder` attribute, this
    callback has no `gather_preds` attribute (presumably attached by the LR
    finder and by prediction gathering, respectively), and `rank_distrib()` is 0.
    When enabled, `self._setup()` is called.
    """
    # Both attribute checks must be negated. A missing `not` on gather_preds used
    # to enable the callback *only* while gathering predictions, which inverted the guard.
    # The cheap hasattr checks run first so rank_distrib() is skipped when either excludes the run.
    self.run = (not hasattr(self.learn, 'lr_finder')
                and not hasattr(self, "gather_preds")
                and rank_distrib() == 0)
    if self.run:
        self._setup()