Code Example #1
def loss_batch(
    model: nn.Module,
    xb: Tensor,
    yb: Tensor,
    loss_func: OptionalLossFunction = None,
    opt: OptionalOptimizer = None,
    cb_handler: Optional[CallbackHandler] = None,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    "Calculate loss and metrics for a batch, call out to callbacks as necessary."
    cb_handler = if_none(cb_handler, CallbackHandler())
    if not is_listy(xb):
        xb = [xb]
    if not is_listy(yb):
        yb = [yb]
    out = model(*xb)
    out = cb_handler.on_loss_begin(out)

    if not loss_func:
        # No loss function: return detached predictions and targets.
        return to_detach(out), yb[0].detach()
    loss = loss_func(out, *yb)

    if opt is not None:
        # Training: run backward and step the optimizer, firing the
        # matching callbacks around each stage.
        loss = cb_handler.on_backward_begin(loss)
        loss.backward()
        cb_handler.on_backward_end()
        opt.step()
        cb_handler.on_step_end()
        opt.zero_grad()

    return loss.detach().cpu()
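
These snippets all lean on a couple of tiny helpers from pytorch-toolbox. As a minimal sketch, inferred from how they are used here (the project's actual definitions may differ slightly):

def is_listy(x) -> bool:
    # True for list/tuple containers that should be handled element-wise
    return isinstance(x, (list, tuple))

def if_none(a, b):
    # Return `a` unless it is None, otherwise fall back to `b`
    return b if a is None else a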
Code Example #2
File: hooks.py Project: klu1211/pytorch-toolbox
def hook_fn_wrapper(self, module: nn.Module, input: Tensors,
                    output: Tensors):
    "Applies `hook_fn` to `module`, `input`, `output`."
    if self.detach:
        # Detach so stored activations do not keep the autograd graph alive;
        # lists are materialized rather than left as one-shot generators.
        input = [o.detach() for o in input] if is_listy(input) else input.detach()
        output = [o.detach() for o in output] if is_listy(output) else output.detach()
    self.stored = self.hook_fn(module, input, output)
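
For context, a wrapper like this is normally attached through PyTorch's standard forward-hook API. A hypothetical registration, assuming a Hook class that stores `hook_fn` and `detach`:

hook = Hook(module, hook_fn=lambda m, i, o: o, detach=True)  # hypothetical constructor
handle = module.register_forward_hook(hook.hook_fn_wrapper)
model(x)         # after a forward pass, hook.stored holds the detached output
handle.remove()  # unregister the hook when finished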
Code Example #3
def lr_find(
    learn,
    start_lr: Floats = 1e-7,
    end_lr: Floats = 10,
    num_it: int = 100,
    stop_div: bool = True,
    **kwargs: Any
):
    "Explore lr from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div`, stops when loss explodes."
    start_lr = np.array(start_lr) if is_listy(start_lr) else start_lr
    end_lr = np.array(end_lr) if is_listy(end_lr) else end_lr
    cb = LRFinder(learn, start_lr, end_lr, num_it, stop_div)
    # Convert the iteration budget into a whole number of epochs.
    epochs = int(np.ceil(num_it / len(learn.data.train_dl)))
    learn.fit(epochs, start_lr, callbacks=[cb], **kwargs)
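
A hypothetical call, assuming a `learn` object that exposes `data.train_dl` and `fit` as above:

lr_find(learn, start_lr=1e-6, end_lr=1.0, num_it=200)
# The LRFinder callback records the loss at each learning rate; plotting
# loss against lr shows where training starts to diverge.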
Code Example #4
def validate(
    model: nn.Module,
    dl: DataLoader,
    loss_func: OptionalLossFunction = None,
    cb_handler: Optional[CallbackHandler] = None,
    pbar: Optional[PBar] = None,
    average: bool = True,
    n_batch: Optional[int] = None,
) -> Union[float, List[Tensor]]:
    "Calculate loss and metrics for the validation set."
    model.eval()
    with torch.no_grad():
        val_losses, nums = [], []
        for xb, yb in progress_bar(dl, parent=pbar, leave=(pbar is not None)):
            if cb_handler:
                xb, yb = cb_handler.on_batch_begin(xb, yb, train=False, phase=Phase.VAL)
            val_losses.append(
                loss_batch(model, xb, yb, loss_func, cb_handler=cb_handler))
            if not is_listy(yb):
                yb = [yb]
            nums.append(yb[0].shape[0])
            if cb_handler and cb_handler.on_batch_end(val_losses[-1]):
                break
            if n_batch and (len(nums) >= n_batch):
                break
        nums = np.array(nums, dtype=np.float32)
        if average:
            # Weight each batch loss by its batch size before averaging.
            return (to_numpy(torch.stack(val_losses)) * nums).sum() / nums.sum()
        else:
            return val_losses
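
A hypothetical call, assuming a trained `model`, a validation DataLoader `valid_dl`, and a standard loss such as cross-entropy:

import torch.nn.functional as F

avg_loss = validate(model, valid_dl, loss_func=F.cross_entropy)  # size-weighted mean
per_batch = validate(model, valid_dl, loss_func=F.cross_entropy,
                     average=False)                              # list of per-batch losses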
Code Example #5
    def _predict_on_dl(self,
                       dl,
                       phase,
                       callbacks=None,
                       callback_fns=None,
                       metrics=None):
        assert dl is not None, "A dataloader must be provided"
        assert (
            phase is not None
        ), "A phase must be provided: {Phase.TRAIN, Phase.VAL, Phase.TEST}"
        callback_fns = [cb(self) for cb in if_none(callback_fns, [])]
        callbacks = self.callbacks + if_none(callbacks, []) + callback_fns

        if phase is not Phase.TEST:
            metrics = if_none(metrics, self.metrics)
        else:
            # No labels at test time, so metrics are skipped.
            metrics = []
        cb_handler = CallbackHandler(callbacks=callbacks, metrics=metrics)
        with torch.no_grad():
            self.model.eval()
            cb_handler.on_epoch_begin()
            for xb, yb in progbar(dl):
                xb, yb = cb_handler.on_batch_begin(xb, yb, train=False, phase=phase)
                if not is_listy(xb):
                    xb = [xb]
                out = self.model(*xb)
                out = cb_handler.on_loss_begin(out)
                # Predictions are collected by the callbacks, not returned here.
                out = cb_handler.on_batch_end(out)
Code Example #6
def freeze_layer_groups(self, layer_group_idxs):
    if not is_listy(layer_group_idxs):
        layer_group_idxs = [layer_group_idxs]
    super().unfreeze()
    for i in layer_group_idxs:
        for l in self.layer_groups[i]:
            # Leave batch-norm layers trainable when `train_bn` is set.
            if not self.train_bn or not isinstance(l, bn_types):
                requires_grad(l, False)
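
The freeze/unfreeze methods depend on a `requires_grad` helper. A minimal sketch in the fastai style (assumed; the pytorch-toolbox version may differ):

def requires_grad(m: nn.Module, b: bool) -> None:
    # Set requires_grad on every parameter of module `m`
    for p in m.parameters():
        p.requires_grad = b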
Code Example #7
File: core.py Project: klu1211/pytorch-toolbox
def on_batch_end(self, last_output, last_target, train, **kwargs):
    if not is_listy(last_target):
        last_target = [last_target]
    self.count += last_target[0].size(0)
    try:
        self.val += (last_target[0].size(0) *
                     self.func(last_output, *last_target).detach().cpu())
    except TypeError:  # the metric takes a single target rather than several
        self.val += (last_target[0].size(0) *
                     self.func(last_output, last_target[0]).detach().cpu())
Code Example #8
File: core.py Project: klu1211/pytorch-toolbox
def __init__(self,
             vals: StartOptEnd,
             n_iter: int,
             func: Optional[AnnealFunc] = None):
    if is_tuple(vals):
        self.start, self.end = (vals[0], vals[1])
    elif is_listy(vals):
        # A bare list anneals each value toward zero.
        self.start, self.end = vals, listify(0, vals)
    else:
        self.start, self.end = vals, 0
    self.n_iter = max(1, n_iter)
    if func is None:
        # Tuples default to linear annealing; single values stay constant.
        self.func = annealing_linear if is_tuple(vals) else annealing_no
    else:
        self.func = func
    self.n = 0
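
The two default schedules referenced above are simple annealing functions. Sketches following the fastai signatures (assumed, not taken from this project):

def annealing_linear(start, end, pct: float):
    # Linearly interpolate from `start` to `end` as `pct` runs 0 -> 1
    return start + pct * (end - start)

def annealing_no(start, end, pct: float):
    # Constant schedule: always return the starting value
    return start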
Code Example #9
def __init__(
    self,
    learn,
    lr_max: float,
    moms: Floats = (0.95, 0.85),
    div_factor: float = 25.0,
    pct_start: float = 0.3,
):
    super().__init__()
    self.learn = learn
    self.lr_max = lr_max
    # Normalize `moms` to a 2-tuple and a listy `lr_max` to an array.
    self.moms = tuple(listify(moms, 2))
    self.div_factor = div_factor
    self.pct_start = pct_start
    if is_listy(self.lr_max):
        self.lr_max = np.array(self.lr_max)
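
`listify(p, q)` as used here and in the scheduler above normalizes `p` into a list sized to match `q`. A sketch following the fastai pattern (assumed):

def listify(p=None, q=None):
    if p is None:
        p = []
    elif not is_listy(p):
        p = [p]
    # Target length: an int `q` directly, or len(q) for a collection
    n = q if isinstance(q, int) else (1 if q is None else len(q))
    if len(p) == 1:
        p = p * n
    return list(p)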
Code Example #10
def unfreeze_layer_groups(self, layer_group_idxs):
    if not is_listy(layer_group_idxs):
        layer_group_idxs = [layer_group_idxs]
    # Freeze the complement: every group not selected for unfreezing.
    layer_group_idxs_to_freeze = list(
        set(range(len(self.layer_groups))) - set(layer_group_idxs))
    self.freeze_layer_groups(layer_group_idxs_to_freeze)
Code Example #11
File: hooks.py Project: klu1211/pytorch-toolbox
def _hook_inner(m, i, o):
    # Tensors and listy containers pass through unchanged; anything else
    # iterable (e.g. a generator) is coerced to a list so it can be stored.
    return o if isinstance(o, Tensor) else o if is_listy(o) else list(o)
Code Example #12
def to_detach(b: Tensors):
    "Recursively detach lists of tensors in `b`."
    if is_listy(b):
        return [to_detach(o) for o in b]
    return b.detach() if isinstance(b, Tensor) else b
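
A quick demonstration of the recursive behavior on a nested structure:

import torch

t = torch.ones(2, requires_grad=True)
out = to_detach([t, [t * 2, "label"]])
# -> [detached tensor, [detached tensor, "label"]]; non-tensors pass through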