import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

# NOTE: `t2d`, `zero_grad`, `make_checkpoint`, `save_checkpoint`, and the
# `DEBUG` constant are project-level utilities/constants defined elsewhere
# in the repo.


def train_fn(
    model,
    loader,
    device,
    loss_fn,
    optimizer,
    scheduler=None,
    accumulation_steps=1,
    verbose=True,
    tensorboard_logger=None,
    logdir=None,
):
    """Train step.

    Args:
        model (nn.Module): model to train
        loader (DataLoader): loader with data
        device (str or torch.device): device to use for placing batches
        loss_fn (nn.Module): loss function, should be callable
        optimizer (torch.optim.Optimizer): model parameters optimizer
        scheduler ([type], optional): batch scheduler to use.
            Default is `None`.
        accumulation_steps (int, optional): number of steps to accumulate
            gradients. Default is `1`.
        verbose (bool, optional): verbosity mode. Default is `True`.
        tensorboard_logger ([type], optional): tensorboard logger to use
            for writing batch metrics. Default is `None`.
        logdir (str, optional): directory where to save intermediate
            checkpoints. Default is `None`.

    Returns:
        dict with metrics computed during the training on loader
    """
    model.train()
    metrics = {"loss": 0.0}
    n_batches = len(loader)
    # checkpoint the model at every 10% of the loader
    indices_to_save = [int(n_batches * pcnt) for pcnt in np.arange(0.1, 1, 0.1)]
    with tqdm(total=len(loader), desc="train", disable=not verbose) as progress:
        for idx, batch in enumerate(loader):
            images, targets, target_availabilities = t2d(
                (
                    batch["image"],
                    batch["target_positions"],
                    batch["target_availabilities"],
                ),
                device,
            )
            zero_grad(optimizer)
            predictions, confidences = model(images)
            loss = loss_fn(targets, predictions, confidences, target_availabilities)
            _loss = loss.detach().item()
            metrics["loss"] += _loss
            if tensorboard_logger is not None:
                tensorboard_logger.metric("loss", _loss, idx)
            loss.backward()
            progress.set_postfix_str(f"loss - {_loss:.5f}")
            progress.update(1)
            if (idx + 1) in indices_to_save and logdir is not None:
                checkpoint = make_checkpoint("train", idx + 1, model)
                save_checkpoint(checkpoint, logdir, f"train_{idx}.pth")
            # step the optimizer (and scheduler) only on accumulation boundaries
            if (idx + 1) % accumulation_steps == 0:
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()
            if idx == DEBUG:
                break
    # average the accumulated loss over the number of seen batches
    metrics["loss"] /= idx + 1
    return metrics
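
# A minimal sketch of what the helpers used above are expected to do. These
# are hypothetical stand-ins for illustration (hence the `_sketch` suffix),
# not the project's actual implementations.


def t2d_sketch(tensors, device):
    """Move a tensor or a tuple/list of tensors to `device` (mirrors `t2d`)."""
    if isinstance(tensors, torch.Tensor):
        return tensors.to(device)
    return type(tensors)(t2d_sketch(t, device) for t in tensors)


def zero_grad_sketch(optimizer):
    """Reset gradients of all optimized tensors (mirrors `zero_grad`).

    Setting `.grad` to `None` instead of zeroing avoids an extra memset and
    behaves the same from the training loop's point of view.
    """
    for group in optimizer.param_groups:
        for p in group["params"]:
            p.grad = None
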
def train_fn(
    model,
    loader,
    device,
    loss_fn,
    optimizer,
    scheduler=None,
    accumulation_steps=1,
    verbose=True,
    tensorboard_logger=None,
    logdir=None,
    validation_fn=None,
):
    """Train step for the multi-task model (trajectory regression + mask).

    Args:
        model (nn.Module): model to train
        loader (DataLoader): loader with data
        device (str or torch.device): device to use for placing batches
        loss_fn (nn.Module): loss function, should be callable
        optimizer (torch.optim.Optimizer): model parameters optimizer
        scheduler ([type], optional): batch scheduler to use.
            Default is `None`.
        accumulation_steps (int, optional): number of steps to accumulate
            gradients. Default is `1`.
        verbose (bool, optional): verbosity mode. Default is `True`.
        tensorboard_logger ([type], optional): tensorboard logger to use
            for writing batch metrics. Default is `None`.
        logdir (str, optional): directory where to save intermediate
            checkpoints. Default is `None`.
        validation_fn (Callable, optional): function for periodic validation,
            called as `validation_fn(model=..., device=...)` and expected
            to return a scalar score. Default is `None`.

    Returns:
        dict with metrics computed during the training on loader
    """
    model.train()
    metrics = {"regression_loss": 0.0, "mask_loss": 0.0, "loss": 0.0}
    n_batches = len(loader)
    indices_to_save = [int(n_batches * pcnt) for pcnt in np.arange(0.1, 1, 0.1)]
    last_score = 0.0
    with tqdm(total=len(loader), desc="train", disable=not verbose) as progress:
        for idx, batch in enumerate(loader):
            images, targets, target_availabilities, masks = t2d(
                (
                    batch["image"],
                    batch["target_positions"],
                    batch["target_availabilities"],
                    batch["mask"],
                ),
                device,
            )
            zero_grad(optimizer)
            predictions, confidences, masks_logits = model(images)
            rloss = loss_fn(targets, predictions, confidences, target_availabilities)
            # up-weight the mask loss so it contributes on a scale
            # comparable to the regression loss
            mloss = 1e4 * F.binary_cross_entropy_with_logits(masks_logits, masks)
            loss = rloss + mloss
            _rloss = rloss.detach().item()
            _mloss = mloss.detach().item()
            _loss = loss.detach().item()
            metrics["regression_loss"] += _rloss
            metrics["mask_loss"] += _mloss
            metrics["loss"] += _loss
            if (idx + 1) % 30_000 == 0 and validation_fn is not None:
                score = validation_fn(model=model, device=device)
                # validation may switch the model to eval mode - switch it back
                model.train()
                last_score = score
                if logdir is not None:
                    checkpoint = make_checkpoint("train", idx + 1, model)
                    save_checkpoint(checkpoint, logdir, f"train_{idx}.pth")
            else:
                score = None
            if tensorboard_logger is not None:
                tensorboard_logger.metric("regression_loss", _rloss, idx)
                tensorboard_logger.metric("mask_loss", _mloss, idx)
                tensorboard_logger.metric("loss", _loss, idx)
                if score is not None:
                    tensorboard_logger.metric("score", score, idx)
                if (idx + 1) % 1_000 == 0:
                    # masks - (bs)x(1)x(h)x(w) ground truth
                    # masks_logits - (bs)x(1)x(h)x(w) predicted logits
                    # concatenated along width: ground truth | prediction
                    tensorboard_logger.writer.add_images(
                        "gt_vs_mask",
                        torch.cat([masks, torch.sigmoid(masks_logits)], dim=-1),
                        idx,
                    )
            loss.backward()
            progress.set_postfix_str(
                f"rloss - {_rloss:.5f}, "
                f"mloss - {_mloss:.5f}, "
                f"loss - {_loss:.5f}, "
                f"score - {last_score:.5f}"
            )
            progress.update(1)
            if (idx + 1) % accumulation_steps == 0:
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()
            if idx == DEBUG:
                break
    # average the accumulated losses over the number of seen batches
    metrics["regression_loss"] /= idx + 1
    metrics["mask_loss"] /= idx + 1
    metrics["loss"] /= idx + 1
    return metrics
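
# A minimal sketch of a `validation_fn` compatible with the periodic
# validation hook above. The only contract the training loop relies on is
# the keyword signature `validation_fn(model=..., device=...)` returning a
# scalar score; the factory, the loader, and `metric_fn` below are
# hypothetical stand-ins, assuming the same batch layout as in training.


def make_validation_fn(valid_loader, metric_fn):
    def validation_fn(model, device):
        # leave the model in eval mode - `train_fn` switches it back itself
        model.eval()
        scores = []
        with torch.no_grad():
            for batch in valid_loader:
                images, targets, target_availabilities = t2d(
                    (
                        batch["image"],
                        batch["target_positions"],
                        batch["target_availabilities"],
                    ),
                    device,
                )
                predictions, confidences, _ = model(images)
                scores.append(
                    metric_fn(
                        targets, predictions, confidences, target_availabilities
                    ).item()
                )
        return float(np.mean(scores))

    return validation_fn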