def loss_batch(
    model: nn.Module,
    xb: Tensor,
    yb: Tensor,
    loss_func: OptionalLossFunction = None,
    opt: OptionalOptimizer = None,
    cb_handler: Optional[CallbackHandler] = None,
) -> Tuple[Union[Tensor, int, float, str]]:
    "Calculate loss and metrics for a batch, call out to callbacks as necessary."
    cb_handler = if_none(cb_handler, CallbackHandler())
    if not is_listy(xb):
        xb = [xb]
    if not is_listy(yb):
        yb = [yb]
    out = model(*xb)
    out = cb_handler.on_loss_begin(out)
    if not loss_func:
        return to_detach(out), yb[0].detach()
    loss = loss_func(out, *yb)
    if opt is not None:
        loss = cb_handler.on_backward_begin(loss)
        loss.backward()
        cb_handler.on_backward_end()
        opt.step()
        cb_handler.on_step_end()
        opt.zero_grad()
    return loss.detach().cpu()

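# Usage sketch (illustration only, not part of the original code): assuming a
# standard PyTorch `model`, a `train_dl` yielding (xb, yb) batches, a loss such
# as F.cross_entropy, and an optimizer `opt`, a bare training epoch driven by
# `loss_batch` could look like the helper below.
def _example_train_epoch(model, train_dl, loss_func, opt):
    "Run one plain epoch through `loss_batch`; returns the last batch loss."
    model.train()
    for xb, yb in train_dl:
        # With `opt` supplied, loss_batch also runs backward(), opt.step() and
        # opt.zero_grad(); the returned loss is detached and moved to CPU.
        last_loss = loss_batch(model, xb, yb, loss_func=loss_func, opt=opt)
    return last_loss
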
def hook_fn_wrapper(self, module: nn.Module, input: Tensors, output: Tensors):
    "Applies `hook_fn` to `module`, `input`, `output`."
    if self.detach:
        input = (o.detach() for o in input) if is_listy(input) else input.detach()
        output = (o.detach() for o in output) if is_listy(output) else output.detach()
    self.stored = self.hook_fn(module, input, output)

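# Note (assumption): a method with this (module, input, output) signature is
# typically registered via `module.register_forward_hook(...)`, so that after
# each forward pass the (optionally detached) result of `self.hook_fn` is kept
# in `self.stored`. A minimal sketch, assuming the surrounding class exposes
# `detach`, `hook_fn` and `stored`:
# handle = module.register_forward_hook(hook.hook_fn_wrapper)
# ...  # run a forward pass, then read hook.stored
# handle.remove()
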
def lr_find(
    learn,
    start_lr: Floats = 1e-7,
    end_lr: Floats = 10,
    num_it: int = 100,
    stop_div: bool = True,
    **kwargs: Any,
):
    "Explore lr from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div`, stops when loss explodes."
    start_lr = np.array(start_lr) if is_listy(start_lr) else start_lr
    end_lr = np.array(end_lr) if is_listy(end_lr) else end_lr
    cb = LRFinder(learn, start_lr, end_lr, num_it, stop_div)
    # enough epochs to cover `num_it` iterations of the training dataloader
    epochs = int(np.ceil(num_it / len(learn.data.train_dl)))
    learn.fit(epochs, start_lr, callbacks=[cb], **kwargs)

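# Usage sketch (assumed Learner-like `learn` with `data.train_dl`, `fit`, and a
# recorder callback; `learn.recorder.plot()` is the usual fastai-style way to
# inspect the LR/loss curve and is an assumption here, not a guarantee of this
# codebase):
# lr_find(learn)          # sweep 1e-7 -> 10 over ~100 iterations
# learn.recorder.plot()   # choose an LR a bit before the loss starts to explode
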
def validate(
    model: nn.Module,
    dl: DataLoader,
    loss_func: OptionalLossFunction = None,
    cb_handler: Optional[CallbackHandler] = None,
    pbar: Optional[PBar] = None,
    average: bool = True,
    n_batch: Optional[int] = None,
) -> Iterator[Tuple[Union[Tensor, int], ...]]:
    "Calculate loss and metrics for the validation set."
    model.eval()
    with torch.no_grad():
        val_losses, nums = [], []
        for xb, yb in progress_bar(dl, parent=pbar, leave=(pbar is not None)):
            if cb_handler:
                xb, yb = cb_handler.on_batch_begin(xb, yb, train=False, phase=Phase.VAL)
            val_losses.append(loss_batch(model, xb, yb, loss_func, cb_handler=cb_handler))
            if not is_listy(yb):
                yb = [yb]
            nums.append(yb[0].shape[0])
            if cb_handler and cb_handler.on_batch_end(val_losses[-1]):
                break
            if n_batch and (len(nums) >= n_batch):
                break
        nums = np.array(nums, dtype=np.float32)
        if average:
            return (to_numpy(torch.stack(val_losses)) * nums).sum() / nums.sum()
        else:
            return val_losses

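# Usage sketch (illustrative names): compute the size-weighted mean validation
# loss, or keep the per-batch losses by passing `average=False`.
# val_loss = validate(model, valid_dl, loss_func=loss_func)
# per_batch_losses = validate(model, valid_dl, loss_func=loss_func, average=False)
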
def _predict_on_dl(self, dl, phase, callbacks=None, callback_fns=None, metrics=None):
    assert dl is not None, "A dataloader must be provided"
    assert phase is not None, "A phase must be provided: {Phase.TRAIN, Phase.VAL, Phase.TEST}"
    callbacks_fns = [cb(self) for cb in if_none(callback_fns, [])]
    callbacks = self.callbacks + if_none(callbacks, []) + callbacks_fns
    if phase is not Phase.TEST:
        metrics = if_none(metrics, self.metrics)
    else:
        metrics = []
    cb_handler = CallbackHandler(callbacks=callbacks, metrics=metrics)
    with torch.no_grad():
        self.model.eval()
        cb_handler.on_epoch_begin()
        for xb, yb in progbar(dl):
            xb, yb = cb_handler.on_batch_begin(xb, yb, train=False, phase=phase)
            if not is_listy(xb):
                xb = [xb]
            out = self.model(*xb)
            out = cb_handler.on_loss_begin(out)
            out = cb_handler.on_batch_end(out)

def freeze_layer_groups(self, layer_group_idxs):
    if not is_listy(layer_group_idxs):
        layer_group_idxs = [layer_group_idxs]
    super().unfreeze()
    for i in layer_group_idxs:
        for l in self.layer_groups[i]:
            if not self.train_bn or not isinstance(l, bn_types):
                requires_grad(l, False)

def on_batch_end(self, last_output, last_target, train, **kwargs):
    if not is_listy(last_target):
        last_target = [last_target]
    self.count += last_target[0].size(0)
    try:
        self.val += (last_target[0].size(0)
                     * self.func(last_output, *last_target).detach().cpu())
    except TypeError:
        # fall back to a single target when the metric does not accept
        # multiple target arguments
        self.val += (last_target[0].size(0)
                     * self.func(last_output, last_target[0]).detach().cpu())

def __init__(self, vals: StartOptEnd, n_iter: int, func: Optional[AnnealFunc] = None):
    if is_tuple(vals):
        self.start, self.end = (vals[0], vals[1])
    elif is_listy(vals):
        self.start, self.end = vals, listify(0, vals)
    else:
        self.start, self.end = vals, 0
    self.n_iter = max(1, n_iter)
    if func is None:
        self.func = annealing_linear if is_tuple(vals) else annealing_no
    else:
        self.func = func
    self.n = 0

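# Illustrative construction (assuming this `__init__` belongs to a stepper/
# scheduler-style class, called `Scheduler` here purely for the example): a
# (start, end) tuple anneals from `start` to `end` over `n_iter` steps with
# `annealing_linear` by default, a bare scalar stays constant via
# `annealing_no`, and a list of start values is paired with zeros by `listify`.
# sched = Scheduler((0.95, 0.85), n_iter=100)   # linear 0.95 -> 0.85
# sched = Scheduler(1e-3, n_iter=100)           # constant 1e-3
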
def __init__(
    self,
    learn,
    lr_max: float,
    moms: Floats = (0.95, 0.85),
    div_factor: float = 25.0,
    pct_start: float = 0.3,
):
    super().__init__()
    self.learn = learn
    self.lr_max = lr_max
    self.div_factor = div_factor
    self.pct_start = pct_start
    self.moms = tuple(listify(moms, 2))
    if is_listy(self.lr_max):
        self.lr_max = np.array(self.lr_max)

def unfreeze_layer_groups(self, layer_group_idxs):
    if not is_listy(layer_group_idxs):
        layer_group_idxs = [layer_group_idxs]
    layer_group_idxs_to_freeze = list(
        set(range(len(self.layer_groups))) - set(layer_group_idxs))
    self.freeze_layer_groups(layer_group_idxs_to_freeze)

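# Usage sketch (hypothetical learner with three layer groups): unfreezing only
# the listed groups freezes every other group via `freeze_layer_groups`.
# learn.unfreeze_layer_groups(2)        # train only the last group (the head)
# learn.unfreeze_layer_groups([1, 2])   # train the last two groups
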
def _hook_inner(m, i, o):
    return o if isinstance(o, Tensor) else o if is_listy(o) else list(o)

def to_detach(b: Tensors):
    "Recursively detach lists of tensors in `b`."
    if is_listy(b):
        return [to_detach(o) for o in b]
    return b.detach() if isinstance(b, Tensor) else b

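# Quick illustration: `to_detach` recurses into (nested) lists, detaches any
# Tensor it finds, and passes other values through unchanged.
# to_detach([torch.randn(2, requires_grad=True), [torch.ones(3), "label"]])
# -> [tensor without grad, [tensor without grad, "label"]]
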