Example #1
import pickle
import warnings

import torch
from fastai.torch_core import rank_distrib


def export(model,
           filename_or_stream='export.pkl',
           pickle_module=pickle,
           pickle_protocol=2):
    "Export the content of `model` without the items and the optimizer state for inference"
    if rank_distrib(): return  # don't export from a distributed child process
    model._end_cleanup()
    old_dbunch = model.dls
    model.dls = model.dls.new_empty()
    state = model.opt.state_dict() if model.opt is not None else None
    model.opt = None
    # `is_pathlike` is assumed to come from the surrounding fastai module
    target = open(
        model.path / filename_or_stream,
        'wb') if is_pathlike(filename_or_stream) else filename_or_stream
    with warnings.catch_warnings():
        # Silence the PyTorch warning about the model not being checked
        warnings.simplefilter("ignore")
        torch.save(model,
                   target,
                   pickle_module=pickle_module,
                   pickle_protocol=pickle_protocol)
    model.create_opt()
    if state is not None:
        model.opt.load_state_dict(state)
    model.dls = old_dbunch
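
For context, a minimal usage sketch: the tiny `Learner` built here and the reload via fastai's `load_learner` are illustrative assumptions, not part of the example above.

from fastai.vision.all import (ImageDataLoaders, URLs, load_learner,
                               resnet18, untar_data, vision_learner)

# Build a small Learner, export it with the function above, then reload
# the pickled copy for inference. All names here are illustrative.
path = untar_data(URLs.MNIST_TINY)
dls = ImageDataLoaders.from_folder(path)
learn = vision_learner(dls, resnet18)

export(learn, 'export.pkl')                          # no-op on non-zero ranks
learn_inf = load_learner(learn.path / 'export.pkl')  # inference-ready copy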
Example #2
    def before_fit(self):
        # Run only on the main process of a real training run
        # (not lr_find, not gather_preds, not a distributed child process)
        self.run = (not hasattr(self.learn, 'lr_finder')
                    and not hasattr(self, "gather_preds")
                    and rank_distrib() == 0)
        if not self.run:
            return

        # Log hyperparameters
        params = self.learn.gather_args()
        self.logger.log_hyperparams(params)
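
Every example here gates side effects on `rank_distrib() == 0`. For reference, the sketch below mirrors what fastai's helper boils down to, a lookup of the `RANK` environment variable that distributed launchers set; it is an illustrative reimplementation, not the library source.

import os

def rank_distrib() -> int:
    # torchrun / torch.distributed.launch set RANK for each worker;
    # single-process runs leave it unset, so default to rank 0.
    return int(os.environ.get('RANK', 0))

if rank_distrib() == 0:
    print('main process: safe to log and save checkpoints')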
Example #3
    def before_fit(self):
        # Log only from the main process of a training run
        self.run = (rank_distrib() == 0) and not (
            hasattr(self.learn, 'lr_finder') or hasattr(self, "gather_preds"))

        if self.run:
            self.writer = SummaryWriter(log_dir=self.log_dir)
            # Fall back to the loss function if no training metrics were given
            self.train_metrics = listify(
                ifnone(self.train_metrics, self.learn.loss_func))
            self.train_metric_names = listify(
                ifnone(self.train_metric_names,
                       [self.get_name(s) for s in self.train_metrics]))
            self.train_metric_names = [
                'train_' + name for name in self.train_metric_names
            ]
            # Seed (moving_average, update_count) pairs for metric smoothing
            self.smooth_dict = {
                name: (0, 0) for name in self.train_metric_names
            }
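
The `(0, 0)` pairs suggest each metric tracks a (moving average, update count) pair for debiased exponential smoothing. A hypothetical update step consistent with that layout; the `update_smooth` helper and the `beta` value are assumptions, not part of the callback above.

def update_smooth(smooth_dict, name, value, beta=0.98):
    # Exponential moving average with bias correction, stored as
    # (moving_average, update_count) just like the (0, 0) seeds above.
    avg, count = smooth_dict[name]
    avg = beta * avg + (1 - beta) * value
    count += 1
    smooth_dict[name] = (avg, count)
    return avg / (1 - beta ** count)

smooth = {'train_loss': (0, 0)}
for loss in [2.3, 1.9, 1.5]:
    print(update_smooth(smooth, 'train_loss', loss))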
Example #4
    def on_batch_end(self, **kwargs) -> None:
        # Skip logging on non-main processes; `dist` is torch.distributed
        if dist.is_initialized() and rank_distrib() != 0: return
        if kwargs['train']:
            self.tb_writer.add_scalar('Train/loss', kwargs['last_loss'].item(),
                                      kwargs['iteration'])
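
This fastai-v1-style hook presumes `dist` is `torch.distributed` and `self.tb_writer` is a TensorBoard `SummaryWriter`. A minimal, self-contained sketch of those pieces together; the log directory and logged values are placeholders.

import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/example')      # placeholder directory
is_main = not dist.is_initialized() or dist.get_rank() == 0
if is_main:
    writer.add_scalar('Train/loss', 0.42, 1)        # tag, value, global step
writer.close()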
Example #5
    def before_fit(self):
        # Run only on the main process of a training run
        self.run = (not hasattr(self.learn, 'lr_finder')
                    and not hasattr(self, "gather_preds")
                    and rank_distrib() == 0)
        if self.run:
            self._setup()
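
The run-guard shared by these callbacks is easy to exercise in isolation. A quick hypothetical check with stand-in objects; the `should_run` helper and the `SimpleNamespace` stand-ins are illustrative only.

from types import SimpleNamespace

def should_run(learn, cb, rank):
    # Mirror of the guard: training run, no prediction gathering, main rank
    return (not hasattr(learn, 'lr_finder')
            and not hasattr(cb, 'gather_preds')
            and rank == 0)

assert should_run(SimpleNamespace(), SimpleNamespace(), rank=0)
assert not should_run(SimpleNamespace(), SimpleNamespace(), rank=1)
assert not should_run(SimpleNamespace(lr_finder=None), SimpleNamespace(), rank=0)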