def __init__(
    self,
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    epochs: int,
    device: torch.device,
    train_loader: DataLoader,
    val_loader: Optional[DataLoader] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    writer: Optional[SummaryWriter] = None,
    save_path: Optional[str] = None,
    checkpoint_path: Optional[str] = None,
    show_pbar: bool = True,
) -> None:
    # Logging
    self.writer = writer

    # Saving
    self.save_path = save_path

    # Device
    self.device = device

    # Data
    self.train_loader = train_loader
    self.val_loader = val_loader

    # Model
    self.model = model
    self.loss_fn = loss_fn
    self.optimizer = optimizer
    self.scheduler = scheduler

    # Training schedule
    self.epochs = epochs
    self.start_epoch = 0

    # Optionally resume from a saved checkpoint
    if checkpoint_path:
        self._load_from_checkpoint(checkpoint_path)

    # Metrics
    self.train_loss_metric = LossMetric()
    self.val_loss_metric = LossMetric()
    self.train_acc_metric = BinaryAccuracyMetric(threshold=0.5)
    self.val_acc_metric = BinaryAccuracyMetric(threshold=0.5)

    # Progress bar
    self.show_pbar = show_pbar
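A minimal construction sketch for this trainer, not taken from the original source: it assumes the constructor above belongs to a class named Trainer, that build_model(), train_loader, and val_loader are defined elsewhere in the project, and that BCELoss plus a 0.5-threshold accuracy metric match the model's sigmoid output.

# Hypothetical usage sketch; Trainer, build_model(), and the loaders are assumptions.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = build_model().to(device)  # placeholder model factory
trainer = Trainer(
    model=model,
    loss_fn=torch.nn.BCELoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    epochs=10,
    device=device,
    train_loader=train_loader,
    val_loader=val_loader,
    save_path="checkpoints/best.pt",
)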
def __init__(
    self,
    model: torch.nn.Module,
    device: torch.device,
    loader: DataLoader,
    checkpoint_path: Optional[str] = None,
) -> None:
    # Device
    self.device = device

    # Data
    self.loader = loader

    # Model
    self.model = model

    # Optionally restore weights from a saved checkpoint
    if checkpoint_path:
        self._load_from_checkpoint(checkpoint_path)

    # Metrics
    self.acc_metric = BinaryAccuracyMetric(threshold=0.5)
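A companion sketch for the evaluator, again assuming class and helper names not shown in this excerpt: the class is taken to be called Evaluator, test_loader is defined elsewhere, and the checkpoint path points at a file written by the trainer above.

# Hypothetical usage sketch; Evaluator and test_loader are assumptions.
evaluator = Evaluator(
    model=build_model().to(device),
    device=device,
    loader=test_loader,
    checkpoint_path="checkpoints/best.pt",
)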