def train(
    self,
    problem: Problem,
    startEpoch: int,
    nEpochs: int,
    batchSize: int,
    scheduler: _LRScheduler,
) -> Iterator[FractionalPerformanceSummary]:
    """Run the training loop, yielding one performance summary per epoch.

    Args:
        problem: the problem definition; must include a TRAIN dataset.
        startEpoch: epoch counter to resume from (0 for a fresh run).
        nEpochs: final epoch index to stop at.
            NOTE(review): the loop runs (nEpochs - startEpoch) epochs, so
            `nEpochs` is a target epoch index, not a count — confirm intended.
        batchSize: batch size passed through to the data loaders.
        scheduler: LR scheduler, stepped once per epoch.

    Yields:
        A FractionalPerformanceSummary after each completed epoch.
    """
    self.cur_epoch = startEpoch
    d_types = [d.data_type for d in problem.datasets]
    assert Split.TRAIN in d_types, "training dataset should be included"
    ################### Start Training #####################################
    logger.info("Model layers")
    logger.info(str(self.model.modules))
    loaders = self._get_loaders(problem, batchSize=batchSize)
    # start epochs
    while self.cur_epoch < nEpochs:
        self.cur_epoch += 1
        logger.info("Starting epoch %d" % self.cur_epoch)
        epoch_stats = self._pass_one_epoch(problem, loaders, Mode.TRAIN)
        logger.info("Finished epoch %d" % self.cur_epoch)
        # epoch-level LR schedule: one scheduler step per epoch
        scheduler.step()
        yield self._get_worker_performance_summary(epoch_stats)
def fit_one(
    self,
    model,
    tasks: List[Task],
    dataloader: DataLoader,
    optimizer: Optimizer,
    scheduler: _LRScheduler,
    training_logger: ResultLogger,
    train_model=True,
) -> float:
    """Run one pass over `dataloader` and return the mean per-output loss.

    Args:
        model: the model; put in train or eval mode via `train_model`.
        tasks: task heads, each switched to train mode regardless.
        dataloader: yields (x, y) batches.
        optimizer: passed through to `self._step` for the update.
        scheduler: LR scheduler, stepped once at the end of the pass.
        training_logger: receives per-batch step outputs.
        train_model: when False the backbone stays in eval mode while
            the tasks still train (used by `fit_support`).

    Returns:
        Total loss divided by the number of task outputs seen.
        NOTE(review): divides by zero if the dataloader is empty — confirm
        callers guarantee a non-empty loader.
    """
    model.train(train_model)
    for task in tasks:
        task.train()
    task_names = [t.name for t in tasks]
    losses = 0.0
    total = 0.0
    for i, (x, y) in enumerate(dataloader):
        outputs = self._step(model, tasks, x, y, optimizer)
        training_logger.log(outputs, task_names, i + 1, len(dataloader))
        # one output per task head; accumulate each head's loss separately
        for o in outputs:
            losses += o.loss.item()
            total += 1.0
    # NOTE(review): reconstructed as a per-call (per-epoch) scheduler step;
    # the original formatting was ambiguous — confirm it is not per-batch.
    scheduler.step()
    return losses / total
def load(self, path_to_checkpoint: str, optimizer: Optimizer = None,
         scheduler: _LRScheduler = None) -> int:
    """Restore model (and optionally optimizer/scheduler) state from a checkpoint.

    Args:
        path_to_checkpoint: path to a checkpoint produced by `save` (contains
            'state_dict', 'step', and optionally optimizer/scheduler states).
        optimizer: if given, its state is restored from the checkpoint.
        scheduler: if given, its state is restored from the checkpoint.

    Returns:
        The training step stored in the checkpoint.
        (Fixed: the original annotation claimed `-> 'Model'`, but the function
        has always returned the integer step.)
    """
    checkpoint = torch.load(path_to_checkpoint)
    self.load_state_dict(checkpoint['state_dict'])
    step = checkpoint['step']
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return step
def project(self, latents: Latents, images: torch.Tensor, optimizer: Optimizer,
            num_steps: int, loss_function: Callable,
            lr_scheduler: _LRScheduler = None) -> Tuple[LatentPaths, Latents]:
    """Optimize `latents` so the generator reproduces `images`.

    Args:
        latents: latent/noise container being optimized in place.
        images: target images the generated output is compared against.
        optimizer: optimizer over the latent (and noise) tensors.
        num_steps: number of optimization iterations.
        loss_function: callable (generated, target) -> (loss, loss_dict).
        lr_scheduler: optional, stepped once per iteration when given.

    Returns:
        (LatentPaths of periodic snapshots, Latents of the best-PSNR state).
    """
    pbar = tqdm(range(num_steps), leave=False)
    latent_path = []
    noise_path = []
    best_latent = best_noise = best_psnr = None
    for i in pbar:
        img_gen, _ = self.generate(latents)
        batch, channel, height, width = img_gen.shape
        # downsample large outputs to 256px by block-averaging before the loss
        if height > 256:
            factor = height // 256
            img_gen = img_gen.reshape(
                batch, channel, height // factor, factor, width // factor, factor
            )
            img_gen = img_gen.mean([3, 5])
        # # n_loss = noise_regularize(noises)
        loss, loss_dict = loss_function(img_gen, images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_dict['psnr'] = self.psnr(img_gen, images).item()
        loss_dict['lr'] = optimizer.param_groups[0]["lr"]
        if lr_scheduler is not None:
            lr_scheduler.step()
        self.log.append(loss_dict)
        # track the best state by PSNR (copies detached to CPU)
        if best_psnr is None or best_psnr < loss_dict['psnr']:
            best_psnr = loss_dict['psnr']
            best_latent = latents.latent.detach().clone().cpu()
            best_noise = [noise.detach().clone().cpu() for noise in latents.noise]
        # periodic snapshot of the optimization trajectory
        if i % self.debug_step == 0:
            latent_path.append(latents.latent.detach().clone().cpu())
            noise_path.append([noise.detach().clone().cpu() for noise in latents.noise])
        loss_description = "; ".join(f"{key}: {value:.6f}" for key, value in loss_dict.items())
        pbar.set_description(loss_description)
        # added after set_description so 'iteration' is not float-formatted
        loss_dict['iteration'] = i
        if self.abort_condition is not None and self.abort_condition(loss_dict):
            break
    # always record the final state, even if not on a debug step
    latent_path.append(latents.latent.detach().clone().cpu())
    noise_path.append([noise.detach().clone().cpu() for noise in latents.noise])
    return LatentPaths(latent_path, noise_path), Latents(best_latent, best_noise)
def enumerate_scheduler(scheduler: _LRScheduler, steps: int) -> List[float]:
    """Record the scheduler's learning rate over `steps` consecutive steps.

    For each step: read the current LR via `get_last_lr`, record it, then
    advance the scheduler. Assumes a single parameter group.

    Args:
        scheduler: scheduler to advance; it is mutated by this call.
        steps: how many LR values to collect.

    Returns:
        The recorded LR values, one per step.
    """
    recorded: List[float] = []
    for _ in range(steps):
        current = scheduler.get_last_lr()  # type: ignore
        assert isinstance(current, list)
        assert len(current) == 1
        recorded.append(current[0])
        scheduler.step()
    return recorded
def train_step(net: nn.Module, crit: _Loss, optim: Optimizer, sched: _LRScheduler,
               sched_on_epoch: bool, inputs: torch.Tensor,
               targets: torch.Tensor, grad_clip: float) -> Tuple[torch.Tensor, float]:
    """Run one forward/backward/update step on a single batch.

    Args:
        net: model to train.
        crit: loss criterion applied to (outputs, targets).
        optim: optimizer performing the parameter update.
        sched: LR scheduler; stepped here only for per-batch scheduling.
        sched_on_epoch: when True the scheduler is stepped elsewhere
            (per epoch), so it is skipped here.
        inputs: batch inputs.
        targets: batch targets.
        grad_clip: max gradient norm for clipping.

    Returns:
        (model outputs for the batch, scalar loss value).
    """
    preds = net(inputs)
    batch_loss = crit(preds, targets)

    # standard backward pass with gradient-norm clipping
    optim.zero_grad()
    batch_loss.backward()
    nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
    optim.step()

    # per-batch scheduler stepping only when scheduling is not epoch-based
    if sched and not sched_on_epoch:
        sched.step()
    return preds, batch_loss.item()
def simulate_values(  # type: ignore[override]
    cls, num_events: int, lr_scheduler: _LRScheduler, **kwargs: Any
) -> List[List[int]]:
    """Method to simulate scheduled values during num_events events.

    Args:
        num_events (int): number of events during the simulation.
        lr_scheduler (subclass of `torch.optim.lr_scheduler._LRScheduler`):
            lr_scheduler object to wrap.
        **kwargs: forwarded to the wrapper's constructor.

    Returns:
        list of pairs: [event_index, value]
    """
    if not isinstance(lr_scheduler, _LRScheduler):
        raise TypeError(
            "Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
            f"but given {type(lr_scheduler)}")
    # This scheduler uses `torch.optim.lr_scheduler._LRScheduler` which
    # should be replicated in order to simulate LR values and
    # not perturb original scheduler.
    with tempfile.TemporaryDirectory() as tmpdirname:
        cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt"
        # snapshot both the scheduler AND its optimizer: stepping the wrapper
        # below mutates the optimizer's param-group LRs too
        obj = {
            "lr_scheduler": lr_scheduler.state_dict(),
            "optimizer": lr_scheduler.optimizer.state_dict(
            ),  # type: ignore[attr-defined]
        }
        torch.save(obj, cache_filepath.as_posix())
        values = []
        scheduler = cls(save_history=False, lr_scheduler=lr_scheduler,
                        **kwargs)  # type: ignore[call-arg]
        for i in range(num_events):
            # read the value BEFORE stepping so index i maps to the value in effect at event i
            params = [
                p[scheduler.param_name] for p in scheduler.optimizer_param_groups
            ]
            values.append([i] + params)
            scheduler(engine=None)
        # restore the original (pre-simulation) states
        obj = torch.load(cache_filepath.as_posix())
        lr_scheduler.load_state_dict(obj["lr_scheduler"])
        lr_scheduler.optimizer.load_state_dict(
            obj["optimizer"])  # type: ignore[attr-defined]
    return values
def train(model: nn.Module, loader: DataLoader, class_loss: nn.Module,
          optimizer: Optimizer, scheduler: _LRScheduler, epoch: int,
          callback: VisdomLogger, freq: int, ex: Experiment = None) -> None:
    """Train `model` for one epoch with per-batch LR scheduling and logging.

    Args:
        model: model returning (logits, features) per batch.
        loader: yields (batch, labels, indices) triples.
        class_loss: per-sample classification loss (mean-reduced here).
        optimizer: optimizer stepped every batch.
        scheduler: LR scheduler stepped every batch.
        epoch: current epoch index (used for fractional log steps).
        callback: visdom logger for running averages; may be None.
        freq: log to `callback` every `freq` batches.
        ex: optional sacred Experiment for per-batch scalar logging.
    """
    model.train()
    device = next(model.parameters()).device
    to_device = lambda x: x.to(device, non_blocking=True)
    loader_length = len(loader)
    train_losses = AverageMeter(device=device, length=loader_length)
    train_accs = AverageMeter(device=device, length=loader_length)

    pbar = tqdm(loader, ncols=80, desc='Training [{:03d}]'.format(epoch))
    for i, (batch, labels, indices) in enumerate(pbar):
        batch, labels, indices = map(to_device, (batch, labels, indices))
        logits, features = model(batch)
        loss = class_loss(logits, labels).mean()
        acc = (logits.detach().argmax(1) == labels).float().mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # per-batch schedule (e.g. warmup/cyclic), not per-epoch
        scheduler.step()

        train_losses.append(loss)
        train_accs.append(acc)
        # `not (i + 1) % freq` fires on every freq-th batch
        if callback is not None and not (i + 1) % freq:
            step = epoch + i / loader_length
            callback.scalar('xent', step, train_losses.last_avg,
                            title='Train Losses')
            callback.scalar('train_acc', step, train_accs.last_avg,
                            title='Train Acc')
    # replay the full epoch's values into the experiment tracker at the end
    if ex is not None:
        for i, (loss, acc) in enumerate(
                zip(train_losses.values_list, train_accs.values_list)):
            step = epoch + i / loader_length
            ex.log_scalar('train.loss', loss, step=step)
            ex.log_scalar('train.acc', acc, step=step)
def save_checkpoint(
    model: nn.Module = None,
    optimizer: optim.Optimizer = None,
    scheduler: sche._LRScheduler = None,
    amp=None,
    exp_name: str = "",
    current_epoch: int = 1,
    full_net_path: str = "",
    state_net_path: str = "",
):
    """Save a full checkpoint (large) and a weights-only state dict (small).

    Args:
        model (nn.Module): model object
        optimizer (optim.Optimizer): optimizer object
        scheduler (sche._LRScheduler): scheduler object
        amp (): apex.amp; its state is saved only when given
        exp_name (str): experiment name, stored as 'arch' for later matching
        current_epoch (int): the epoch the model **will** be trained at next
        full_net_path (str): path for saving the full checkpoint
        state_net_path (str): path for saving the state dict only
    """
    state_dict = {
        "arch": exp_name,
        "epoch": current_epoch,
        "net_state": model.state_dict(),
        "opti_state": optimizer.state_dict(),
        "sche_state": scheduler.state_dict(),
        "amp_state": amp.state_dict() if amp else None,
    }
    torch.save(state_dict, full_net_path)
    torch.save(model.state_dict(), state_net_path)
def snapshot(self, net: torch.nn.Module, opt: Optimizer,
             sched: _LRScheduler = None, epoch: int = None, subdir='.'):
    """Write a snapshot of the training to a file in the log directory.

    Saves network weights, optimizer state and (if given) scheduler state.

    :param net: the neural network
    :param opt: the optimizer used
    :param sched: the learning rate scheduler used; may be None
    :param epoch: the current epoch; may be None
    :param subdir: if given, creates a subdirectory in the log directory.
        The data is written to a file in this subdirectory instead.
    :return: the path of the written snapshot file
    """
    outfile = pt.join(self.dir, subdir, 'snapshot.pt')
    # exist_ok avoids a race between the exists() check and makedirs()
    os.makedirs(os.path.dirname(outfile), exist_ok=True)
    torch.save(
        {
            'net': net.state_dict(),
            'opt': opt.state_dict(),
            # fix: sched defaults to None, so guard before state_dict()
            # (the original crashed with AttributeError when sched was omitted)
            'sched': sched.state_dict() if sched is not None else None,
            'epoch': epoch
        }, outfile)
    return outfile
def resume_checkpoint(
    model: nn.Module = None,
    optimizer: optim.Optimizer = None,
    scheduler: sche._LRScheduler = None,
    exp_name: str = "",
    load_path: str = "",
    mode: str = "all",
):
    """Restore the model from a saved checkpoint.

    Args:
        model (nn.Module): model object
        optimizer (optim.Optimizer): optimizer object
        scheduler (sche._LRScheduler): scheduler object
        exp_name (str): experiment name; must match the checkpoint's 'arch'
        load_path (str): path to the checkpoint file
        mode (str): which restore mode to use:

            - 'all': restore the full training state (model, optimizer,
              scheduler and epoch counter);
            - 'onlynet': restore only the model's weight parameters.

    Returns:
        mode 'all': start_epoch; mode 'onlynet': None
    """
    if os.path.exists(load_path) and os.path.isfile(load_path):
        construct_print(f"Loading checkpoint '{load_path}'")
        checkpoint = torch.load(load_path)
        if mode == "all":
            # the checkpoint must belong to the same experiment
            if exp_name == checkpoint["arch"]:
                start_epoch = checkpoint["epoch"]
                model.load_state_dict(checkpoint["net_state"])
                optimizer.load_state_dict(checkpoint["opti_state"])
                scheduler.load_state_dict(checkpoint["sche_state"])
                construct_print(f"Loaded '{load_path}' "
                                f"(will train at epoch"
                                f" {checkpoint['epoch']})")
                return start_epoch
            else:
                raise Exception(f"{load_path} does not match.")
        elif mode == "onlynet":
            # checkpoint is a bare state dict in this mode
            model.load_state_dict(checkpoint)
            construct_print(f"Loaded checkpoint '{load_path}' "
                            f"(only has the model's weight params)")
        else:
            raise NotImplementedError
    else:
        raise Exception(f"{load_path}路径不正常,请检查")
def save(self, path_to_checkpoints_dir: str, step: int, optimizer: Optimizer,
         scheduler: _LRScheduler) -> str:
    """Serialize model, optimizer and scheduler state to `model-{step}.pth`.

    Args:
        path_to_checkpoints_dir: directory the checkpoint file is written to.
        step: training step, embedded in the filename and the checkpoint.
        optimizer: its state dict is stored under 'optimizer_state_dict'.
        scheduler: its state dict is stored under 'scheduler_state_dict'.

    Returns:
        The full path of the written checkpoint file.
    """
    target = os.path.join(path_to_checkpoints_dir, f'model-{step}.pth')
    torch.save(
        {
            'state_dict': self.state_dict(),
            'step': step,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        },
        target,
    )
    return target
def _better_lr_sched_repr(lr_sched: _LRScheduler) -> str: return ( lr_sched.__class__.__name__ + "(\n " + "\n ".join( f"{k}: {v}" for k, v in lr_sched.state_dict().items() if not k.startswith("_") ) + "\n)" )
def fit_support(
    self,
    model,
    tasks: List[Task],
    dataloader: DataLoader,
    optimizer: Optimizer,
    scheduler: _LRScheduler,
    training_logger: ResultLogger,
):
    """Fit only the task heads on a support set, with the backbone frozen.

    Repeats `fit_one` (with `train_model=False`) until the support loss
    drops below `self.support_min_loss` or `self.support_max_epochs` is
    reached, then restores the optimizer/scheduler to their pre-call state
    and unfreezes the backbone.

    Args:
        model: model exposing freeze_weights / defreeze_weights.
        tasks: task heads; each is reset before fitting.
        dataloader: support-set batches.
        optimizer: used for the support updates; state restored afterwards.
        scheduler: stepped during support fitting; state restored afterwards.
        training_logger: receives per-epoch logging scopes.
    """
    support_loss = 1.0
    support_epoch = 0
    # Don't change default optimizer and scheduler states
    optimizer_state_dict = deepcopy(optimizer.state_dict())
    scheduler_state_dict = deepcopy(scheduler.state_dict())
    # Reset tasks states
    for task in tasks:
        task.reset()
    model.freeze_weights()
    while (support_loss > self.support_min_loss
           and support_epoch < self.support_max_epochs):
        support_epoch += 1
        support_loss = self.fit_one(
            model,
            tasks,
            dataloader,
            optimizer,
            scheduler,
            training_logger.epoch(support_epoch, self.support_max_epochs),
            train_model=False,
        )
    # roll back optimizer/scheduler so support fitting leaves no trace
    optimizer.load_state_dict(optimizer_state_dict)
    scheduler.load_state_dict(scheduler_state_dict)
    model.defreeze_weights()
def collect_state_dict(
    self,
    iteration: Union[float, int],
    model: EmmentalModel,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    metric_dict: Dict[str, float],
) -> Dict[str, Any]:
    r"""Assemble a checkpointable snapshot of the current training state.

    Args:
        iteration(float or int): The current iteration.
        model(EmmentalModel): The model to checkpoint.
        optimizer(Optimizer): The optimizer used during training process.
        lr_scheduler(_LRScheduler): Learning rate scheduler; may be falsy,
            in which case no scheduler state is stored.
        metric_dict(dict): the metric dict.

    Returns:
        dict: The state dict.
    """
    # Only the model name and module weights are captured here; task
    # metadata (flows, loss/output funcs, scorers) is intentionally omitted.
    snapshot_model = {
        "name": model.name,
        "module_pool": model.collect_state_dict(),
    }
    scheduler_state = lr_scheduler.state_dict() if lr_scheduler else None
    return {
        "iteration": iteration,
        "model": snapshot_model,
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": scheduler_state,
        "metric_dict": metric_dict,
    }
def train(model: MoleculeModel,
          data_loader: MoleculeDataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: TrainArgs,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: A :class:`~chemprop.models.model.MoleculeModel`.
    :param data_loader: A :class:`~chemprop.data.data.MoleculeDataLoader`.
    :param loss_func: Loss function.
    :param optimizer: An optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for recording output.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    loss_sum, iter_count = 0, 0
    for batch in tqdm(data_loader, total=len(data_loader)):
        # Prepare batch
        batch: MoleculeDataset
        mol_batch, features_batch, target_batch = batch.batch_graph(
        ), batch.features(), batch.targets()
        # mask zeroes out loss terms for missing (None) targets
        mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])

        # Run model
        model.zero_grad()
        preds = model(mol_batch, features_batch)

        # Move tensors to correct device
        mask = mask.to(preds.device)
        targets = targets.to(preds.device)
        class_weights = torch.ones(targets.shape, device=preds.device)

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            # one cross-entropy term per target column, masked and re-stacked
            loss = torch.cat([
                loss_func(preds[:, target_index, :],
                          targets[:, target_index]).unsqueeze(1)
                for target_index in range(preds.size(1))
            ], dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        # average over the observed (non-masked) entries only
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(batch)

        loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        # NoamLR is stepped per batch; other schedulers are stepped elsewhere
        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
def resume_checkpoint(
    model: nn.Module = None,
    optimizer: optim.Optimizer = None,
    scheduler: sche._LRScheduler = None,
    amp=None,
    exp_name: str = "",
    load_path: str = "",
    mode: str = "all",
):
    """Restore the model from a saved checkpoint.

    Args:
        model (nn.Module): model object (may be wrapped in DataParallel)
        optimizer (optim.Optimizer): optimizer object
        scheduler (sche._LRScheduler): scheduler object
        amp (): apex.amp; its state is restored when both the checkpoint
            and the caller provide it
        exp_name (str): experiment name, checked against checkpoint 'arch'
        load_path (str): path to the checkpoint file
        mode (str): which restore mode to use:

            - 'all': restore the full training state, including the
              optimizer, scheduler, amp and epoch counter;
            - 'onlynet': restore only the model's weight parameters.

    Returns:
        mode 'all': start_epoch; mode 'onlynet': None
    """
    if os.path.exists(load_path) and os.path.isfile(load_path):
        construct_print(f"Loading checkpoint '{load_path}'")
        checkpoint = torch.load(load_path)
        if mode == "all":
            # if exp_name is given it must match the checkpoint's "arch";
            # an empty exp_name skips the check
            if exp_name and exp_name != checkpoint["arch"]:
                raise Exception(
                    f"We can not match {exp_name} with {load_path}.")
            start_epoch = checkpoint["epoch"]
            # unwrap DataParallel-style wrappers before loading weights
            if hasattr(model, "module"):
                model.module.load_state_dict(checkpoint["net_state"])
            else:
                model.load_state_dict(checkpoint["net_state"])
            optimizer.load_state_dict(checkpoint["opti_state"])
            scheduler.load_state_dict(checkpoint["sche_state"])
            if checkpoint.get("amp_state", None):
                if amp:
                    amp.load_state_dict(checkpoint["amp_state"])
                else:
                    construct_print("You are not using amp.")
            else:
                construct_print("The state_dict of amp is None.")
            construct_print(f"Loaded '{load_path}' "
                            f"(will train at epoch"
                            f" {checkpoint['epoch']})")
            return start_epoch
        elif mode == "onlynet":
            # checkpoint is a bare state dict in this mode
            if hasattr(model, "module"):
                model.module.load_state_dict(checkpoint)
            else:
                model.load_state_dict(checkpoint)
            construct_print(f"Loaded checkpoint '{load_path}' "
                            f"(only has the model's weight params)")
        else:
            raise NotImplementedError
    else:
        raise Exception(f"{load_path}路径不正常,请检查")
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          metric_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a multi-target (atom/bond) model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param metric_func: Metric function applied per task on numpy arrays.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    data.shuffle()

    # one running loss/metric accumulator per atom/bond target
    loss_sum, metric_sum, iter_count = [0]*(len(args.atom_targets) + len(args.bond_targets)), \
                                       [0]*(len(args.atom_targets) + len(args.bond_targets)), 0
    loss_weights = args.loss_weights

    num_iters = len(
        data
    ) // args.batch_size * args.batch_size  # don't use the last batch if it's small, for stability
    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = mol_batch.smiles(
        ), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch
        #mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        # FIXME assign 0 to None in target
        # targets = [[0 if x is None else x for x in tb] for tb in target_batch]
        # one concatenated target tensor per task (atoms/bonds flattened)
        targets = [torch.Tensor(np.concatenate(x)) for x in zip(*target_batch)]
        if next(model.parameters()).is_cuda:
            # mask, targets = mask.cuda(), targets.cuda()
            targets = [x.cuda() for x in targets]

        # FIXME
        #class_weights = torch.ones(targets.shape)
        #if args.cuda:
        #    class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)
        targets = [x.reshape([-1, 1]) for x in targets]
        #FIXME mutlticlass
        '''
        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1) for target_index in range(preds.size(1))], dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        '''
        # weighted per-task losses; metrics computed on detached numpy copies
        loss_multi_task = []
        metric_multi_task = []
        for target, pred, lw in zip(targets, preds, loss_weights):
            loss = loss_func(pred, target)
            loss = loss.sum() / target.shape[0]
            loss_multi_task.append(loss * lw)
            if args.cuda:
                metric = metric_func(pred.data.cpu().numpy(),
                                     target.data.cpu().numpy())
            else:
                metric = metric_func(pred.data.numpy(), target.data.numpy())
            metric_multi_task.append(metric)

        loss_sum = [x + y for x, y in zip(loss_sum, loss_multi_task)]
        iter_count += 1

        # backprop the sum of all weighted task losses at once
        sum(loss_multi_task).backward()
        optimizer.step()
        metric_sum = [x + y for x, y in zip(metric_sum, metric_multi_task)]

        # these schedulers are stepped per batch; others are stepped elsewhere
        if isinstance(scheduler, NoamLR) or isinstance(scheduler, SinexpLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = [x / iter_count for x in loss_sum]
            metric_avg = [x / iter_count for x in metric_sum]
            loss_sum, iter_count, metric_sum = [0]*(len(args.atom_targets) + len(args.bond_targets)), \
                                               0, \
                                               [0]*(len(args.atom_targets) + len(args.bond_targets))

            loss_str = ', '.join(f'lss_{i} = {lss:.4e}'
                                 for i, lss in enumerate(loss_avg))
            metric_str = ', '.join(f'mc_{i} = {mc:.4e}'
                                   for i, mc in enumerate(metric_avg))
            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'{loss_str}, {metric_str}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                for i, lss in enumerate(loss_avg):
                    writer.add_scalar(f'train_loss_{i}', lss, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch, with optional uncertainty heads.

    :param model: Model. When `model.uncertainty` is set, its output
        interleaves predicted means (even columns) and variances (odd columns).
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    # deep-copied so bootstrap subsampling does not mutate the caller's data
    data = deepcopy(data)
    data.shuffle()
    if args.uncertainty == 'bootstrap':
        data.sample(int(4 * len(data) / args.ensemble_size))
    loss_sum, iter_count = 0, 0

    num_iters = len(
        data
    ) // args.batch_size * args.batch_size  # don't use the last batch if it's small, for stability
    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = mol_batch.smiles(
        ), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch
        # mask zeroes out loss terms for missing (None) targets
        mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        class_weights = torch.ones(targets.shape)

        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)
        if model.uncertainty:
            # even columns: predicted means; odd columns: predicted variances
            pred_targets = preds[:, [
                j for j in range(len(preds[0])) if j % 2 == 0
            ]]
            pred_var = preds[:, [j for j in range(len(preds[0])) if j % 2 == 1]]
            loss = loss_func(pred_targets, pred_var, targets)
            # sigma = ((pred_targets - targets) ** 2).detach()
            # loss = loss_func(pred_targets, targets) * class_weights * mask
            # loss += nn.MSELoss(reduction='none')(pred_sigma, sigma) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        # NOTE(review): reconstructed with this normalization applied to both
        # branches; confirm the uncertainty loss is elementwise, otherwise it
        # should sit inside the else branch only.
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        loss.backward()
        optimizer.step()

        # NoamLR is stepped per batch; other schedulers are stepped elsewhere
        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch, with (currently disabled) L1 regularization
    on the model's FFN sub-networks.

    :param model: Model; must expose ffn_d0/ffn_d1/ffn_d2/ffn_final/ffn_mol.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    data.shuffle()
    loss_sum, iter_count = 0, 0

    num_iters = len(
        data
    ) // args.batch_size * args.batch_size  # don't use the last batch if it's small, for stability
    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = mol_batch.smiles(
        ), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch
        # mask zeroes out loss terms for missing (None) targets
        mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        class_weights = torch.ones(targets.shape)
        #print('class_weight',class_weights.size(),class_weights)
        #print('mask',mask.size(),mask)
        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([
                loss_func(preds[:, target_index, :],
                          targets[:, target_index]).unsqueeze(1)
                for target_index in range(preds.size(1))
            ], dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        # average over the observed (non-masked) entries only
        loss = loss.sum() / mask.sum()

        ############ add L1 regularization ############
        # NOTE(review): every lamda_* coefficient below is 0, so this whole
        # block is currently a no-op that only costs compute — confirm
        # whether the coefficients were meant to be tuned or the block removed.
        ffn_d0_L1_reg_loss = 0
        ffn_d1_L1_reg_loss = 0
        ffn_d2_L1_reg_loss = 0
        ffn_final_L1_reg_loss = 0
        ffn_mol_L1_reg_loss = 0
        lamda_ffn_d0 = 0
        lamda_ffn_d1 = 0
        lamda_ffn_d2 = 0
        lamda_ffn_final = 0
        lamda_ffn_mol = 0
        for param in model.ffn_d0.parameters():
            ffn_d0_L1_reg_loss += torch.sum(torch.abs(param))
        for param in model.ffn_d1.parameters():
            ffn_d1_L1_reg_loss += torch.sum(torch.abs(param))
        for param in model.ffn_d2.parameters():
            ffn_d2_L1_reg_loss += torch.sum(torch.abs(param))
        for param in model.ffn_final.parameters():
            ffn_final_L1_reg_loss += torch.sum(torch.abs(param))
        for param in model.ffn_mol.parameters():
            ffn_mol_L1_reg_loss += torch.sum(torch.abs(param))
        loss += lamda_ffn_d0 * ffn_d0_L1_reg_loss + lamda_ffn_d1 * ffn_d1_L1_reg_loss + lamda_ffn_d2 * ffn_d2_L1_reg_loss + lamda_ffn_final * ffn_final_L1_reg_loss + lamda_ffn_mol * ffn_mol_L1_reg_loss
        ############ add L1 regularization ############

        ############ add L2 regularization ############
        '''
        ffn_d0_L2_reg_loss = 0
        ffn_d1_L2_reg_loss = 0
        ffn_d2_L2_reg_loss = 0
        ffn_final_L2_reg_loss = 0
        ffn_mol_L2_reg_loss = 0
        lamda_ffn_d0 = 1e-6
        lamda_ffn_d1 = 1e-6
        lamda_ffn_d2 = 1e-5
        lamda_ffn_final = 1e-4
        lamda_ffn_mol = 1e-3
        for param in model.ffn_d0.parameters():
            ffn_d0_L2_reg_loss += torch.sum(torch.square(param))
        for param in model.ffn_d1.parameters():
            ffn_d1_L2_reg_loss += torch.sum(torch.square(param))
        for param in model.ffn_d2.parameters():
            ffn_d2_L2_reg_loss += torch.sum(torch.square(param))
        for param in model.ffn_final.parameters():
            ffn_final_L2_reg_loss += torch.sum(torch.square(param))
        for param in model.ffn_mol.parameters():
            ffn_mol_L2_reg_loss += torch.sum(torch.square(param))
        loss += lamda_ffn_d0 * ffn_d0_L2_reg_loss + lamda_ffn_d1 * ffn_d1_L2_reg_loss + lamda_ffn_d2 * ffn_d2_L2_reg_loss + lamda_ffn_final * ffn_final_L2_reg_loss + lamda_ffn_mol * ffn_mol_L2_reg_loss
        '''
        ############ add L2 regularization ############

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        #loss.backward(retain_graph=True) # wei, retain_graph=True
        loss.backward()
        optimizer.step()

        # NoamLR is stepped per batch; other schedulers are stepped elsewhere
        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    #print(model)
    return n_iter
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          n_iter: int = 0,
          logger: logging.Logger = None) -> int:
    """Train `model` for one epoch using settings from the module-level `config`.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler (NoamLR stepped per batch).
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: currently unused in this variant.
    :return: The total number of iterations (training examples) trained on so far.
    """
    model.train()
    data.shuffle()
    # NOTE(review): loss_sum/iter_count are accumulated but never reported
    # in this variant — dead bookkeeping unless logging is added back.
    loss_sum, iter_count = 0, 0

    num_iters = len(
        data
    ) // config.batch_size * config.batch_size  # don't use the last batch if it's small, for stability
    iter_size = config.batch_size

    # show a progress bar only in verbose mode
    if config.verbose:
        generater = trange
    else:
        generater = range
    for i in generater(0, num_iters, iter_size):
        # Prepare batch
        if i + config.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + config.batch_size])
        smiles_batch, target_batch = mol_batch.smiles(), mol_batch.targets()
        batch = smiles_batch
        # mask zeroes out loss terms for missing (None) targets
        mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        class_weights = torch.ones(targets.shape)

        if config.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch)

        if config.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([
                loss_func(preds[:, target_index, :],
                          targets[:, target_index]).unsqueeze(1)
                for target_index in range(preds.size(1))
            ], dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        # average over the observed (non-masked) entries only
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

    return n_iter
def fit(
        step: FunctionType,
        epochs: int,
        model: ModuleType,
        optimizer: OptimizerType,
        scheduler: SchedulerType,
        data_loader: DataLoader,
        model_path: str,
        logger: object,
        early_stop: bool = False,
        pgd_kwargs: Optional[dict] = None,
        verbose: bool = False
) -> Tuple[ModuleType, dict]:
    """Standard pytorch boilerplate for training a model.

    Allows for early stopping wrt robust accuracy rather than the usual
    clean accuracy: after every epoch the model's PGD robustness is probed
    on a single minibatch (the first one seen that epoch), and training
    stops when robustness drops by more than 0.2 relative to the previous
    epoch. The state dict saved to ``model_path`` is then the one captured
    at the end of the previous (still-robust) epoch.

    :param step: callable performing one optimization step; must accept
        (X, y, model=, optimizer=, scheduler=) and return (loss, logits).
    :param epochs: number of epochs to train.
    :param early_stop: enable the robust-accuracy early-stopping check;
        requires ``pgd_kwargs``.
    :param pgd_kwargs: kwargs forwarded to ``attack_pgd``; must also contain
        'lower_limit' and 'upper_limit' used to clamp adversarial inputs.
    :return: (model, best_state_dict) — the state dict that was written
        to ``model_path``.
    """
    device = next(model.parameters()).device
    prev_robust_acc = 0.
    start_train_time = time.time()
    logger.info('Epoch \t Seconds \t LR \t \t Train Loss \t Train Acc')
    for epoch in range(epochs):
        start_epoch_time = time.time()
        train_loss = 0
        train_acc = 0
        train_n = 0
        data_generator = enumerate(data_loader)
        if verbose:
            data_generator = tqdm(data_generator, total=len(data_loader),
                                  desc=f'Epoch {epoch + 1}')
        for i, (X, y) in data_generator:
            X, y = X.to(device), y.to(device)
            if i == 0:
                # Remembered for the per-epoch PGD robustness probe below.
                # NOTE(review): only defined when the loader is non-empty —
                # an empty loader with early_stop would raise NameError.
                first_batch = (X, y)
            loss, logits = step(X, y, model=model, optimizer=optimizer,
                                scheduler=scheduler)
            # Accumulate sums weighted by batch size so the epoch averages
            # below are per-example, not per-batch.
            train_loss += loss.item() * y.size(0)
            train_acc += (logits.argmax(dim=1) == y).sum().item()
            train_n += y.size(0)
        if early_stop:
            assert pgd_kwargs is not None
            # Check current PGD robustness of model using random minibatch
            X, y = first_batch
            pgd_delta = attack_pgd(model=model, X=X, y=y, opt=optimizer,
                                   **pgd_kwargs)
            model.eval()
            with torch.no_grad():
                output = model(clamp(X + pgd_delta[:X.size(0)],
                                     pgd_kwargs['lower_limit'],
                                     pgd_kwargs['upper_limit']))
                # softmax is monotonic, so argmax(softmax(x)) == argmax(x);
                # kept as written for clarity of intent.
                robust_acc = (output.softmax(dim=1).argmax(dim=1) == y
                              ).sum().item() / y.size(0)
            if robust_acc - prev_robust_acc < -0.2:
                # Sharp drop in robustness: keep the snapshot taken at the
                # end of the previous epoch and stop training.
                logger.info('EARLY STOPPING TRIGGERED')
                break
            prev_robust_acc = robust_acc
            # Deep copy so later updates to the live model don't mutate it.
            best_state_dict = copy.deepcopy(model.state_dict())
            model.train()
        epoch_time = time.time()
        lr = scheduler.get_last_lr()[0]
        logger.info('%d \t %.1f \t \t %.4f \t %.4f \t %.4f',
                    epoch, epoch_time - start_epoch_time, lr,
                    train_loss / train_n, train_acc / train_n)
    train_time = time.time()
    if not early_stop:
        # No early stopping: the final weights are the ones to save
        # (a live reference is fine here since training is finished).
        best_state_dict = model.state_dict()
    torch.save(best_state_dict, model_path)
    logger.info('Total train time: %.4f minutes',
                (train_time - start_train_time)/60)
    return model, best_state_dict
def train_epoch(epoch, epochs_count, model: nn.Module, device: torch.device,
                optimizer: optim.Optimizer, scaler: Optional[GradScaler],
                criterion_list: List[nn.Module], train_dataloader: DataLoader,
                cfg, train_metrics: Dict, progress_bar: tqdm, run_folder,
                limit_samples, fold_index, folds_count,
                scheduler: _LRScheduler, schedule_lr_each_step):
    """Run one mixed-precision training epoch with optional gradient accumulation.

    Per-batch losses/metrics are appended into ``epoch_metrics`` and their
    epoch means pushed into ``train_metrics``; progress is mirrored to the
    tqdm bar and (every 5 iterations) to wandb.

    NOTE(review): the full-precision path (``scaler is None``) is currently
    disabled — both branches are ``pass`` with the real code commented out —
    so a GradScaler is effectively required for any parameter update during
    the loop (the trailing flush still calls ``optimizer.step()`` for the
    scaler-less case).
    """
    epoch_metrics = {'loss': []}
    epoch_metrics.update({metric: [] for metric in cfg['metrics']})
    # Gradient accumulation: number of micro-batches per optimizer step.
    grad_acc_iters = max(cfg['train_params'].get('grad_acc_iters', 1), 1)
    grad_clipping = cfg['train_params'].get('grad_clipping', 0)
    pseudo_label = cfg['train_params'].get('pseudo_label', False)
    use_fp = cfg['train_data_loader'].get('use_fp', False)
    aux_weights = cfg['train_params'].get('aux_weights', None)
    batch_size = cfg['train_data_loader']['batch_size']
    epoch_samples = len(train_dataloader) * batch_size
    seen_samples = 0
    model.train()
    # reset CustomMetrics
    for metric in criterion_list:  # type:CustomMetrics
        if isinstance(metric, CustomMetrics):
            metric.reset()
    optimizer.zero_grad()
    # Tracks whether gradients from a partial accumulation window are still
    # pending an optimizer step when the loop exits.
    is_optimizer_update_finished = True
    for iteration_id, data in enumerate(train_dataloader):
        if limit_samples is not None and seen_samples >= limit_samples:  # DEBUG
            break
        data_img, data_class, data_record_ids = data
        # Last batch may be smaller than the configured batch size.
        batch_size = len(data_img)
        if progress_bar.n == 0:
            progress_bar.reset()  # reset start time to remove time while DataLoader was populating processes
        if scaler is None:
            pass
            # outputs, loss, metrics = forward_pass(data_img, data_class, model, device, criterion_list, pseudo_label,
            #                                       use_fp, aux_weights=aux_weights)
            # if grad_acc_iters > 1:
            #     loss = loss / grad_acc_iters
            # loss.backward()
        else:
            with autocast():
                outputs, loss, metrics = forward_pass(data_img, data_class, model,
                                                      device, criterion_list,
                                                      pseudo_label, use_fp,
                                                      aux_weights=aux_weights)
                # Scale the loss down so accumulated gradients average
                # correctly over the accumulation window.
                if grad_acc_iters > 1:
                    loss = loss / grad_acc_iters
            scaler.scale(loss).backward()
            is_optimizer_update_finished = False
        # Step the optimizer at the end of each accumulation window.
        if grad_acc_iters <= 1 or (iteration_id + 1) % grad_acc_iters == 0:
            if scaler is None:
                pass
                # optimizer.step()
            else:
                if grad_clipping > 0:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(optimizer)
                    # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clipping)
                scaler.step(optimizer)
                scaler.update()
            optimizer.zero_grad()
            is_optimizer_update_finished = True
        # Metrics and progress
        seen_samples += batch_size
        # Multiply back by grad_acc_iters so the reported loss is the
        # un-divided per-batch value.
        collect_metrics(cfg, epoch_metrics, loss * grad_acc_iters, metrics,
                        criterion_list, '')
        lr = [group['lr'] for group in optimizer.param_groups][0]
        print_info(progress_bar, 'Train', epoch_metrics, loss * grad_acc_iters,
                   lr, '', batch_size, epoch, epochs_count, seen_samples,
                   epoch_samples, fold_index, folds_count)
        if (iteration_id + 1) % 5 == 0 and iteration_id + 1 < len(train_dataloader):
            log_dict = dict([(key, value[-1]) for key, value in epoch_metrics.items()])
            if schedule_lr_each_step and scheduler is not None:
                log_dict['lr'] = optimizer.param_groups[0]['lr']
            # Second positional arg is the wandb step; assumes `epoch` is
            # 1-based — TODO confirm against the caller.
            wandb.log(log_dict, (epoch - 1) * epoch_samples + seen_samples,
                      commit=True)
        # Step lr scheduler
        if schedule_lr_each_step and scheduler is not None:
            scheduler.step()
    # Finish optimizer step after the not completed gradient accumulation batch
    if not is_optimizer_update_finished:
        if scaler is None:
            optimizer.step()
        else:
            if grad_clipping > 0:
                # Unscales the gradients of optimizer's assigned params in-place
                scaler.unscale_(optimizer)
                # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clipping)
            scaler.step(optimizer)
            scaler.update()
    # Fold per-batch values into per-epoch means; CustomMetrics instead
    # report their internally accumulated value.
    # NOTE(review): this pairs epoch_metrics entries with criterion_list by
    # position — assumes both have the same length and order; verify.
    for idx, item in enumerate(epoch_metrics.items()):
        key, value = item
        if isinstance(criterion_list[idx], CustomMetrics):
            value = [criterion_list[idx].compute()]
        train_metrics[key].append(np.mean(value))
    progress_bar.write('')
def _train(self, epoch: int, optimizer: Optimizer,
           lr_scheduler: _LRScheduler = None,
           validate_interval: int = 10, save: bool = False, amp: bool = False,
           verbose: bool = True, indent: int = 0,
           loader_train: torch.utils.data.DataLoader = None,
           loader_valid: torch.utils.data.DataLoader = None,
           get_data_fn: Callable[..., tuple[InputType, torch.Tensor]] = None,
           loss_fn: Callable[..., torch.Tensor] = None,
           validate_func: Callable[..., tuple[float, ...]] = None,
           epoch_func: Callable[[], None] = None,
           save_fn: Callable = None, file_path: str = None,
           folder_path: str = None, suffix: str = None, **kwargs):
    """Train this model for ``epoch`` epochs with optional AMP and periodic validation.

    All callables default to this instance's own hooks (``self.get_data``,
    ``self.loss``, ``self._validate``, ``self.save``). Validation runs every
    ``validate_interval`` epochs (and on the last epoch); when accuracy
    improves and ``save`` is True, the model is saved via ``save_fn``.
    """
    # Fall back to instance defaults for anything not supplied.
    loader_train = loader_train if loader_train is not None else self.dataset.loader['train']
    get_data_fn = get_data_fn if get_data_fn is not None else self.get_data
    loss_fn = loss_fn if loss_fn is not None else self.loss
    validate_func = validate_func if validate_func is not None else self._validate
    save_fn = save_fn if save_fn is not None else self.save
    scaler: torch.cuda.amp.GradScaler = None
    # AMP is only usable when GPUs are available.
    if amp and env['num_gpus']:
        scaler = torch.cuda.amp.GradScaler()
    # Baseline accuracy before training, so "best" comparisons are meaningful.
    _, best_acc, _ = validate_func(loader=loader_valid, get_data_fn=get_data_fn,
                                   loss_fn=loss_fn, verbose=verbose,
                                   indent=indent, **kwargs)
    losses = AverageMeter('Loss')
    top1 = AverageMeter('Acc@1')
    top5 = AverageMeter('Acc@5')
    # One parameter list per optimizer param group.
    params: list[list[nn.Parameter]] = [param_group['params']
                                        for param_group in optimizer.param_groups]
    for _epoch in range(epoch):
        if epoch_func is not None:
            # Run the per-epoch hook with all params deactivated, then restore.
            self.activate_params([])
            epoch_func()
            self.activate_params(params)
        losses.reset()
        top1.reset()
        top5.reset()
        epoch_start = time.perf_counter()
        loader = loader_train
        if verbose and env['tqdm']:
            loader = tqdm(loader_train)
        self.train()
        self.activate_params(params)
        optimizer.zero_grad()
        for data in loader:
            # data_time.update(time.perf_counter() - end)
            _input, _label = get_data_fn(data, mode='train')
            if amp and env['num_gpus']:
                loss = loss_fn(_input, _label, amp=True)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss = loss_fn(_input, _label)
                loss.backward()
                optimizer.step()
            # Shared by both branches: gradients cleared after every step.
            optimizer.zero_grad()
            with torch.no_grad():
                # Recompute logits for metrics (forward above may be fused
                # inside loss_fn).
                _output = self.get_logits(_input)
                acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
            batch_size = int(_label.size(0))
            losses.update(loss.item(), batch_size)
            top1.update(acc1, batch_size)
            top5.update(acc5, batch_size)
            empty_cache()
        epoch_time = str(datetime.timedelta(seconds=int(
            time.perf_counter() - epoch_start)))
        self.eval()
        self.activate_params([])
        if verbose:
            pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                output_iter(_epoch + 1, epoch),
                **ansi).ljust(64 if env['color'] else 35)
            _str = ' '.join([
                f'Loss: {losses.avg:.4f},'.ljust(20),
                f'Top1 Acc: {top1.avg:.3f}, '.ljust(20),
                f'Top5 Acc: {top5.avg:.3f},'.ljust(20),
                f'Time: {epoch_time},'.ljust(20),
            ])
            prints(pre_str, _str,
                   prefix='{upline}{clear_line}'.format(**ansi) if env['tqdm'] else '',
                   indent=indent)
        if lr_scheduler:
            lr_scheduler.step()
        if validate_interval != 0:
            # Validate periodically and always on the final epoch.
            if (_epoch + 1) % validate_interval == 0 or _epoch == epoch - 1:
                _, cur_acc, _ = validate_func(loader=loader_valid,
                                              get_data_fn=get_data_fn,
                                              loss_fn=loss_fn, verbose=verbose,
                                              indent=indent, **kwargs)
                # ">=" so a tie still refreshes the saved checkpoint.
                if cur_acc >= best_acc:
                    prints('best result update!', indent=indent)
                    prints(f'Current Acc: {cur_acc:.3f} Previous Best Acc: {best_acc:.3f}',
                           indent=indent)
                    best_acc = cur_acc
                    if save:
                        save_fn(file_path=file_path, folder_path=folder_path,
                                suffix=suffix, verbose=verbose)
                if verbose:
                    print('-' * 50)
    self.zero_grad()
def _train(self, epoch: int, optimizer: Optimizer,
           lr_scheduler: _LRScheduler = None,
           grad_clip: float = None, print_prefix: str = 'Epoch',
           start_epoch: int = 0,
           validate_interval: int = 10, save: bool = False, amp: bool = False,
           loader_train: torch.utils.data.DataLoader = None,
           loader_valid: torch.utils.data.DataLoader = None,
           epoch_fn: Callable[..., None] = None,
           get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
           loss_fn: Callable[..., torch.Tensor] = None,
           after_loss_fn: Callable[..., None] = None,
           validate_fn: Callable[..., tuple[float, float]] = None,
           save_fn: Callable[..., None] = None, file_path: str = None,
           folder_path: str = None, suffix: str = None,
           writer=None, main_tag: str = 'train', tag: str = '',
           verbose: bool = True, indent: int = 0, **kwargs):
    """Train this model for ``epoch`` epochs.

    Supports AMP, gradient clipping, per-epoch/-batch hooks, tensorboard
    logging and periodic validation with best-checkpoint saving. All hook
    callables fall back first to the explicit argument, then to an instance
    attribute of the same name, then to the instance defaults
    (``self.get_data``, ``self.loss``, ``self._validate``, ``self.save``).

    Fix over previous revision: the non-AMP gradient-clipping call passed no
    ``max_norm`` to ``nn.utils.clip_grad_norm_`` and would raise TypeError
    whenever ``grad_clip`` was set; it now clips to ``grad_clip``.
    NOTE(review): the AMP branch still performs no clipping at all — to clip
    there, gradients would need ``scaler.unscale_(optimizer)`` first; left
    unchanged to preserve existing AMP behavior.
    """
    loader_train = loader_train if loader_train is not None else self.dataset.loader[
        'train']
    get_data_fn = get_data_fn if callable(get_data_fn) else self.get_data
    loss_fn = loss_fn if callable(loss_fn) else self.loss
    validate_fn = validate_fn if callable(validate_fn) else self._validate
    save_fn = save_fn if callable(save_fn) else self.save
    # if not callable(iter_fn) and hasattr(self, 'iter_fn'):
    #     iter_fn = getattr(self, 'iter_fn')
    if not callable(epoch_fn) and hasattr(self, 'epoch_fn'):
        epoch_fn = getattr(self, 'epoch_fn')
    if not callable(after_loss_fn) and hasattr(self, 'after_loss_fn'):
        after_loss_fn = getattr(self, 'after_loss_fn')
    scaler: torch.cuda.amp.GradScaler = None
    # AMP requires a GPU; silently fall back to full precision without one.
    if not env['num_gpus']:
        amp = False
    if amp:
        scaler = torch.cuda.amp.GradScaler()
    # Baseline accuracy before training, so "best" comparisons are meaningful.
    _, best_acc = validate_fn(loader=loader_valid, get_data_fn=get_data_fn,
                              loss_fn=loss_fn, writer=None, tag=tag,
                              _epoch=start_epoch,
                              verbose=verbose, indent=indent, **kwargs)
    # Flat list of all trainable parameters across optimizer param groups.
    params: list[nn.Parameter] = []
    for param_group in optimizer.param_groups:
        params.extend(param_group['params'])
    total_iter = epoch * len(loader_train)
    for _epoch in range(epoch):
        _epoch += 1  # 1-based epoch index from here on
        if callable(epoch_fn):
            # Run the per-epoch hook with all params deactivated, then restore.
            self.activate_params([])
            epoch_fn(optimizer=optimizer, lr_scheduler=lr_scheduler,
                     _epoch=_epoch, epoch=epoch, start_epoch=start_epoch)
            self.activate_params(params)
        logger = MetricLogger()
        logger.meters['loss'] = SmoothedValue()
        logger.meters['top1'] = SmoothedValue()
        logger.meters['top5'] = SmoothedValue()
        loader_epoch = loader_train
        if verbose:
            header = '{blue_light}{0}: {1}{reset}'.format(
                print_prefix, output_iter(_epoch, epoch), **ansi)
            header = header.ljust(30 + get_ansi_len(header))
            if env['tqdm']:
                header = '{upline}{clear_line}'.format(**ansi) + header
                loader_epoch = tqdm(loader_epoch)
            loader_epoch = logger.log_every(loader_epoch, header=header,
                                            indent=indent)
        self.train()
        self.activate_params(params)
        optimizer.zero_grad()
        for i, data in enumerate(loader_epoch):
            _iter = _epoch * len(loader_train) + i
            # data_time.update(time.perf_counter() - end)
            _input, _label = get_data_fn(data, mode='train')
            _output = self(_input, amp=amp)
            loss = loss_fn(_input, _label, _output=_output, amp=amp)
            if amp:
                scaler.scale(loss).backward()
                if callable(after_loss_fn):
                    after_loss_fn(_input=_input, _label=_label,
                                  _output=_output, loss=loss,
                                  optimizer=optimizer, loss_fn=loss_fn,
                                  amp=amp, scaler=scaler,
                                  _iter=_iter, total_iter=total_iter)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                if grad_clip is not None:
                    # FIX: pass the clipping threshold (previously omitted,
                    # which raised TypeError — max_norm is required).
                    nn.utils.clip_grad_norm_(params, grad_clip)
                if callable(after_loss_fn):
                    after_loss_fn(_input=_input, _label=_label,
                                  _output=_output, loss=loss,
                                  optimizer=optimizer, loss_fn=loss_fn,
                                  amp=amp, scaler=scaler,
                                  _iter=_iter, total_iter=total_iter)
                    # start_epoch=start_epoch, _epoch=_epoch, epoch=epoch)
                optimizer.step()
            optimizer.zero_grad()
            acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
            batch_size = int(_label.size(0))
            logger.meters['loss'].update(float(loss), batch_size)
            logger.meters['top1'].update(acc1, batch_size)
            logger.meters['top5'].update(acc5, batch_size)
            empty_cache()  # TODO: should it be outside of the dataloader loop?
        self.eval()
        self.activate_params([])
        loss, acc = logger.meters['loss'].global_avg, logger.meters[
            'top1'].global_avg
        if writer is not None:
            from torch.utils.tensorboard import SummaryWriter
            assert isinstance(writer, SummaryWriter)
            writer.add_scalars(main_tag='Loss/' + main_tag,
                               tag_scalar_dict={tag: loss},
                               global_step=_epoch + start_epoch)
            writer.add_scalars(main_tag='Acc/' + main_tag,
                               tag_scalar_dict={tag: acc},
                               global_step=_epoch + start_epoch)
        if lr_scheduler:
            lr_scheduler.step()
        if validate_interval != 0:
            # Validate periodically and always on the final epoch.
            if _epoch % validate_interval == 0 or _epoch == epoch:
                _, cur_acc = validate_fn(loader=loader_valid,
                                         get_data_fn=get_data_fn,
                                         loss_fn=loss_fn, writer=writer,
                                         tag=tag, _epoch=_epoch + start_epoch,
                                         verbose=verbose, indent=indent,
                                         **kwargs)
                # ">=" so a tie still refreshes the saved checkpoint.
                if cur_acc >= best_acc:
                    if verbose:
                        prints('{green}best result update!{reset}'.format(
                            **ansi), indent=indent)
                        prints(
                            f'Current Acc: {cur_acc:.3f} Previous Best Acc: {best_acc:.3f}',
                            indent=indent)
                    best_acc = cur_acc
                    if save:
                        save_fn(file_path=file_path, folder_path=folder_path,
                                suffix=suffix, verbose=verbose)
                if verbose:
                    prints('-' * 50, indent=indent)
    self.zero_grad()
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or list of MoleculeDatasets if using moe).
    :param loss_func: Loss function applied per prediction element.
    :param optimizer: An Optimizer, stepped once per batch.
    :param scheduler: A learning rate scheduler (stepped per batch when NoamLR).
    :param args: Arguments.
    :param n_iter: Number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: Total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    data.shuffle()

    loss_sum, iter_count = 0, 0

    # Skip the trailing partial batch for stability.
    num_iters = len(data) // args.batch_size * args.batch_size

    for batch_start in trange(0, num_iters, args.batch_size):
        # Defensive guard: never slice past the end of the dataset.
        if batch_start + args.batch_size > len(data):
            break

        # Prepare batch
        mol_batch = MoleculeDataset(data[batch_start:batch_start + args.batch_size])
        smiles_batch = mol_batch.smiles()
        features_batch = mol_batch.features()
        target_batch = mol_batch.targets()

        # NaN targets mean "missing": masked out of the loss, placeholder 0.
        mask = torch.Tensor([[not np.isnan(x) for x in tb] for tb in target_batch])
        targets = torch.Tensor([[0 if np.isnan(x) else x for x in tb]
                                for tb in target_batch])
        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        class_weights = torch.ones(targets.shape)
        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(smiles_batch, features_batch)

        # todo: change the loss function for property prediction tasks
        if args.dataset_type == 'multiclass':
            targets = targets.long()
            per_task_losses = [
                loss_func(preds[:, task, :], targets[:, task]).unsqueeze(1)
                for task in range(preds.size(1))
            ]
            loss = torch.cat(per_task_losses, dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        loss.backward()
        optimizer.step()
        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            # Reset the running sums so each report averages a fresh window.
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{k} = {rate:.4e}'
                                for k, rate in enumerate(lrs))
            debug(f'\nLoss = {loss_avg:.4e}, PNorm = {pnorm:.4f},'
                  f' GNorm = {gnorm:.4f}, {lrs_str}')

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for k, rate in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{k}', rate, n_iter)

    return n_iter
def train(module: nn.Module, num_classes: int, epochs: int,
          optimizer: Optimizer, lr_scheduler: _LRScheduler = None,
          lr_warmup_epochs: int = 0,
          model_ema: ExponentialMovingAverage = None,
          model_ema_steps: int = 32,
          grad_clip: float = None,
          pre_conditioner: None | KFAC | EKFAC = None,
          print_prefix: str = 'Train', start_epoch: int = 0, resume: int = 0,
          validate_interval: int = 10, save: bool = False, amp: bool = False,
          loader_train: torch.utils.data.DataLoader = None,
          loader_valid: torch.utils.data.DataLoader = None,
          epoch_fn: Callable[..., None] = None,
          get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
          forward_fn: Callable[..., torch.Tensor] = None,
          loss_fn: Callable[..., torch.Tensor] = None,
          after_loss_fn: Callable[..., None] = None,
          validate_fn: Callable[..., tuple[float, float]] = None,
          save_fn: Callable[..., None] = None, file_path: str = None,
          folder_path: str = None, suffix: str = None,
          writer=None, main_tag: str = 'train', tag: str = '',
          accuracy_fn: Callable[..., list[float]] = None,
          verbose: bool = True, output_freq: str = 'iter', indent: int = 0,
          change_train_eval: bool = True, lr_scheduler_freq: str = 'epoch',
          backward_and_step: bool = True, **kwargs):
    r"""Train the model.

    Full-featured loop: AMP, gradient clipping, KFAC/EKFAC pre-conditioning,
    model EMA, resume, tensorboard logging, and periodic validation with
    best-checkpoint saving. Returns the best validation result seen
    (the tuple returned by ``validate_fn``; first element treated as accuracy).
    """
    if epochs <= 0:
        return
    # Default hooks: identity data fn, plain forward, cross-entropy loss.
    get_data_fn = get_data_fn or (lambda x: x)
    forward_fn = forward_fn or module.__call__
    # NOTE(review): `_output or forward_fn(_input)` relies on tensor
    # truthiness — a multi-element tensor passed as _output would raise
    # "Boolean value of Tensor ... is ambiguous". Only safe while callers
    # never pass _output to this default; confirm.
    loss_fn = loss_fn or (lambda _input, _label, _output=None: F.cross_entropy(
        _output or forward_fn(_input), _label))
    validate_fn = validate_fn or validate
    accuracy_fn = accuracy_fn or accuracy
    scaler: torch.cuda.amp.GradScaler = None
    # AMP requires a GPU; silently fall back to full precision without one.
    if not env['num_gpus']:
        amp = False
    if amp:
        scaler = torch.cuda.amp.GradScaler()
    # Baseline validation before training, so "best" comparisons are meaningful.
    best_validate_result = (0.0, float('inf'))
    if validate_interval != 0:
        best_validate_result = validate_fn(loader=loader_valid,
                                           get_data_fn=get_data_fn,
                                           forward_fn=forward_fn,
                                           loss_fn=loss_fn, writer=None,
                                           tag=tag, _epoch=start_epoch,
                                           verbose=verbose, indent=indent,
                                           **kwargs)
        best_acc = best_validate_result[0]
    # Flat list of all trainable parameters across optimizer param groups.
    params: list[nn.Parameter] = []
    for param_group in optimizer.param_groups:
        params.extend(param_group['params'])
    len_loader_train = len(loader_train)
    total_iter = (epochs - resume) * len_loader_train

    logger = MetricLogger()
    logger.create_meters(loss=None, top1=None, top5=None)

    # Fast-forward the scheduler when resuming mid-run.
    if resume and lr_scheduler:
        for _ in range(resume):
            lr_scheduler.step()
    iterator = range(resume, epochs)
    if verbose and output_freq == 'epoch':
        # NOTE(review): `header` is built here but log_every is given
        # `print_prefix` instead — looks unintentional; confirm upstream.
        header: str = '{blue_light}{0}: {reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(header), 30) + get_ansi_len(header))
        iterator = logger.log_every(range(resume, epochs),
                                    header=print_prefix,
                                    tqdm_header='Epoch', indent=indent)
    for _epoch in iterator:
        _epoch += 1  # 1-based epoch index from here on
        logger.reset()
        if callable(epoch_fn):
            # Per-epoch hook runs with all params deactivated; reactivated
            # below before training.
            activate_params(module, [])
            epoch_fn(optimizer=optimizer, lr_scheduler=lr_scheduler,
                     _epoch=_epoch, epochs=epochs, start_epoch=start_epoch)
        loader_epoch = loader_train
        if verbose and output_freq == 'iter':
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                'Epoch', output_iter(_epoch, epochs), **ansi)
            header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
            loader_epoch = logger.log_every(loader_train, header=header,
                                            tqdm_header='Batch',
                                            indent=indent)
        if change_train_eval:
            module.train()
        activate_params(module, params)
        for i, data in enumerate(loader_epoch):
            _iter = _epoch * len_loader_train + i
            # data_time.update(time.perf_counter() - end)
            _input, _label = get_data_fn(data, mode='train')
            # Pre-conditioner tracks curvature statistics during the
            # forward/backward pass (non-AMP path only).
            if pre_conditioner is not None and not amp:
                pre_conditioner.track.enable()
            _output = forward_fn(_input, amp=amp, parallel=True)
            loss = loss_fn(_input, _label, _output=_output, amp=amp)
            if backward_and_step:
                optimizer.zero_grad()
                if amp:
                    scaler.scale(loss).backward()
                    # Gradients must be unscaled before any hook or clipping
                    # can act on their true magnitudes.
                    if callable(after_loss_fn) or grad_clip is not None:
                        scaler.unscale_(optimizer)
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input, _label=_label,
                                      _output=_output, loss=loss,
                                      optimizer=optimizer, loss_fn=loss_fn,
                                      amp=amp, scaler=scaler,
                                      _iter=_iter, total_iter=total_iter)
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input, _label=_label,
                                      _output=_output, loss=loss,
                                      optimizer=optimizer, loss_fn=loss_fn,
                                      amp=amp, scaler=scaler,
                                      _iter=_iter, total_iter=total_iter)
                        # start_epoch=start_epoch, _epoch=_epoch, epochs=epochs)
                    if pre_conditioner is not None:
                        # Stop tracking, then apply the pre-conditioner to
                        # the raw gradients before clipping/stepping.
                        pre_conditioner.track.disable()
                        pre_conditioner.step()
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    optimizer.step()
            if model_ema and i % model_ema_steps == 0:
                model_ema.update_parameters(module)
                if _epoch <= lr_warmup_epochs:
                    # Reset ema buffer to keep copying weights
                    # during warmup period
                    model_ema.n_averaged.fill_(0)
            if lr_scheduler and lr_scheduler_freq == 'iter':
                lr_scheduler.step()
            acc1, acc5 = accuracy_fn(_output, _label,
                                     num_classes=num_classes, topk=(1, 5))
            batch_size = int(_label.size(0))
            logger.update(n=batch_size, loss=float(loss),
                          top1=acc1, top5=acc5)
            empty_cache()
        optimizer.zero_grad()
        if lr_scheduler and lr_scheduler_freq == 'epoch':
            lr_scheduler.step()
        if change_train_eval:
            module.eval()
        activate_params(module, [])
        loss, acc = (logger.meters['loss'].global_avg,
                     logger.meters['top1'].global_avg)
        if writer is not None:
            from torch.utils.tensorboard import SummaryWriter
            assert isinstance(writer, SummaryWriter)
            writer.add_scalars(main_tag='Loss/' + main_tag,
                               tag_scalar_dict={tag: loss},
                               global_step=_epoch + start_epoch)
            writer.add_scalars(main_tag='Acc/' + main_tag,
                               tag_scalar_dict={tag: acc},
                               global_step=_epoch + start_epoch)
        # Validate periodically and always on the final epoch.
        if validate_interval != 0 and (_epoch % validate_interval == 0
                                       or _epoch == epochs):
            validate_result = validate_fn(module=module,
                                          num_classes=num_classes,
                                          loader=loader_valid,
                                          get_data_fn=get_data_fn,
                                          forward_fn=forward_fn,
                                          loss_fn=loss_fn, writer=writer,
                                          tag=tag,
                                          _epoch=_epoch + start_epoch,
                                          verbose=verbose, indent=indent,
                                          **kwargs)
            cur_acc = validate_result[0]
            # ">=" so a tie still refreshes the saved checkpoint.
            if cur_acc >= best_acc:
                best_validate_result = validate_result
                if verbose:
                    prints('{purple}best result update!{reset}'.format(**ansi),
                           indent=indent)
                    prints(
                        f'Current Acc: {cur_acc:.3f} '
                        f'Previous Best Acc: {best_acc:.3f}',
                        indent=indent)
                best_acc = cur_acc
                if save:
                    save_fn(file_path=file_path, folder_path=folder_path,
                            suffix=suffix, verbose=verbose)
            if verbose:
                prints('-' * 50, indent=indent)
    module.zero_grad()
    return best_validate_result
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    Supports optional 50/50 positive/negative class rebalancing
    (``args.class_balance``, single-task binary classification only) and
    optional per-target sample weights (``args.enable_weight``).

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    data.shuffle()

    loss_sum, iter_count = 0, 0

    iter_size = args.batch_size

    if args.class_balance:
        # Reconstruct data so that each batch has equal number of positives and negatives
        # (will leave out a different random sample of negatives each epoch)
        assert len(
            data[0].targets) == 1  # only works for single class classification
        pos = [d for d in data if d.targets[0] == 1]
        neg = [d for d in data if d.targets[0] == 0]

        new_data = []
        pos_size = iter_size // 2
        pos_index = neg_index = 0
        # Interleave half-positive/half-negative windows; once either class
        # runs dry, top the batch up from the other class, and stop when a
        # class is exhausted (leftover examples are dropped for this epoch).
        while True:
            new_pos = pos[pos_index:pos_index + pos_size]
            new_neg = neg[neg_index:neg_index + iter_size - len(new_pos)]
            if len(new_pos) == 0 or len(new_neg) == 0:
                break
            if len(new_pos) + len(new_neg) < iter_size:
                new_pos = pos[pos_index:pos_index + iter_size - len(new_neg)]
            new_data += new_pos + new_neg
            pos_index += len(new_pos)
            neg_index += len(new_neg)
        data = new_data

    num_iters = len(
        data
    ) // args.batch_size * args.batch_size  # don't use the last batch if it's small, for stability

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch, weight_batch = mol_batch.smiles(
        ), mol_batch.features(), mol_batch.targets(), mol_batch.weights()
        batch = smiles_batch
        # Missing targets (None) contribute zero loss via the mask.
        mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])
        weights = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in weight_batch])
        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        # Per-target sample weights, or uniform weights (moved to GPU below).
        if args.enable_weight:
            class_weights = weights
        else:
            class_weights = torch.ones(targets.shape)
        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)
        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([
                loss_func(preds[:, target_index, :],
                          targets[:, target_index]).unsqueeze(1)
                for target_index in range(preds.size(1))
            ], dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        # Average over observed (unmasked) targets only.
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        loss.backward()
        optimizer.step()

        # Noam schedules step per batch; other schedulers are the caller's job.
        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            # Reset running sums so each report averages a fresh window.
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None,
          chunk_names: bool = False,
          val_smiles: List[str] = None,
          test_smiles: List[str] = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :param chunk_names: Whether to train on the data in chunks. In this case,
    data must be a list of paths to the data chunks.
    :param val_smiles: Validation smiles strings without targets.
    :param test_smiles: Test smiles strings without targets, used for adversarial setting.
    :return: The total number of iterations (training examples) trained on so far.
    """
    # Route progress messages to the logger when one is provided, else stdout.
    debug = logger.debug if logger is not None else print

    model.train()

    # Auxiliary loss for reconstructing molecule-level feature targets in the
    # BERT-style pretraining mode.
    if args.dataset_type == 'bert_pretraining':
        features_loss = nn.MSELoss()

    if chunk_names:
        # Chunked mode: `data` is a list of (chunk_path, memo_path) pairs.
        # Each chunk is unpickled, shuffled, and trained via a recursive call
        # with chunk_names=False; the featurization memo is restored from
        # memo_path when present and written back afterwards otherwise.
        for path, memo_path in tqdm(data, total=len(data)):
            featurization.SMILES_TO_FEATURES = dict()
            if os.path.isfile(memo_path):
                found_memo = True
                with open(memo_path, 'rb') as f:
                    featurization.SMILES_TO_FEATURES = pickle.load(f)
            else:
                found_memo = False
            with open(path, 'rb') as f:
                chunk = pickle.load(f)
            if args.moe:
                for source in chunk:
                    source.shuffle()
            else:
                chunk.shuffle()
            n_iter = train(model=model,
                           data=chunk,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           logger=logger,
                           writer=writer,
                           chunk_names=False,
                           val_smiles=val_smiles,
                           test_smiles=test_smiles)
            if not found_memo:
                # NOTE(review): the memo is loaded into SMILES_TO_FEATURES
                # above but SMILES_TO_GRAPH is dumped here — looks like a
                # mismatch in which cache is persisted; confirm which dict the
                # featurization module actually memoizes.
                with open(memo_path, 'wb') as f:
                    pickle.dump(featurization.SMILES_TO_GRAPH,
                                f,
                                protocol=pickle.HIGHEST_PROTOCOL)
        return n_iter

    # moe keeps `data` as a list of per-source datasets and shuffles each one
    # below; all other modes shuffle the single dataset here.
    if not args.moe:
        data.shuffle()

    # Running totals for the periodic log line (see log block at the bottom).
    loss_sum, iter_count = 0, 0

    if args.adversarial:
        if args.moe:
            train_smiles = []
            for d in data:
                train_smiles += d.smiles()
        else:
            train_smiles = data.smiles()
        train_val_smiles = train_smiles + val_smiles
        d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0

    # Determine how many loop iterations make up this epoch; depends on mode.
    if args.moe:
        test_smiles = list(test_smiles)
        random.shuffle(test_smiles)
        train_smiles = []
        for d in data:
            d.shuffle()
            train_smiles.append(d.smiles())
        # Bounded by the smallest source so every source can supply a batch.
        num_iters = min(len(test_smiles), min([len(d) for d in data]))
    elif args.maml:
        num_iters = args.maml_batches_per_epoch * args.maml_batch_size
        model.zero_grad()
        maml_sum_loss = 0
    else:
        num_iters = len(data) if args.last_batch else len(
            data) // args.batch_size * args.batch_size

    if args.parallel_featurization:
        # Featurize batches in a background process; batches arrive through
        # batch_queue and the process is told to stop via exit_queue.
        batch_queue = Queue(args.batch_queue_max_size)
        exit_queue = Queue(1)
        batch_process = Process(target=async_mol2graph,
                                args=(batch_queue, data, args, num_iters,
                                      args.batch_size, exit_queue,
                                      args.last_batch))
        batch_process.start()
        currently_loaded_batches = []

    # MAML samples one task per iteration; all other modes step by batch size.
    iter_size = 1 if args.maml else args.batch_size

    for i in trange(0, num_iters, iter_size):
        if args.moe:
            if not args.batch_domain_encs:
                model.compute_domain_encs(
                    train_smiles)  # want to recompute every batch
            mol_batch = [
                MoleculeDataset(d[i:i + args.batch_size]) for d in data
            ]
            train_batch, train_targets = [], []
            for b in mol_batch:
                tb, tt = b.smiles(), b.targets()
                train_batch.append(tb)
                train_targets.append(tt)
            test_batch = test_smiles[i:i + args.batch_size]
            loss = model.compute_loss(train_batch, train_targets, test_batch)
            model.zero_grad()

            loss_sum += loss.item()
            iter_count += len(mol_batch)
        elif args.maml:
            # Inner-loop step: evaluate the current model on the task's
            # training split and compute the adapted parameters theta_prime.
            task_train_data, task_test_data, task_idx = data.sample_maml_task(
                args)
            mol_batch = task_test_data
            smiles_batch, features_batch, target_batch = task_train_data.smiles(
            ), task_train_data.features(), task_train_data.targets(task_idx)
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor(target_batch).unsqueeze(1)
            if next(model.parameters()).is_cuda:
                targets = targets.cuda()
            preds = model(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            grad = torch.autograd.grad(
                loss, [p for p in model.parameters() if p.requires_grad])
            theta = [
                p for p in model.named_parameters() if p[1].requires_grad
            ]  # comes in same order as grad
            # One SGD step with the inner-loop learning rate produces the
            # adapted parameter dict.
            theta_prime = {
                p[0]: p[1] - args.maml_lr * grad[i]
                for i, p in enumerate(theta)
            }
            # Copy parameters that don't receive gradients through unchanged.
            for name, nongrad_param in [
                    p for p in model.named_parameters()
                    if not p[1].requires_grad
            ]:
                theta_prime[name] = nongrad_param + torch.zeros(
                    nongrad_param.size()).to(nongrad_param)
        else:
            # Prepare batch
            if args.parallel_featurization:
                if len(currently_loaded_batches) == 0:
                    currently_loaded_batches = batch_queue.get()
                mol_batch, featurized_mol_batch = currently_loaded_batches.pop(
                )
            else:
                if not args.last_batch and i + args.batch_size > len(data):
                    break
                mol_batch = MoleculeDataset(data[i:i + args.batch_size])
            smiles_batch, features_batch, target_batch = mol_batch.smiles(
            ), mol_batch.features(), mol_batch.targets()

            if args.dataset_type == 'bert_pretraining':
                batch = mol2graph(smiles_batch, args)
                mask = mol_batch.mask()
                batch.bert_mask(mask)
                # Invert: loss is computed on the atoms that were masked out.
                mask = 1 - torch.FloatTensor(mask)  # num_atoms
                features_targets = torch.FloatTensor(
                    target_batch['features']
                ) if target_batch[
                    'features'] is not None else None  # num_molecules x features_size
                targets = torch.FloatTensor(target_batch['vocab'])  # num_atoms
                if args.bert_vocab_func == 'feature_vector':
                    mask = mask.reshape(-1, 1)
                else:
                    targets = targets.long()
            else:
                batch = smiles_batch
                # Missing targets are None: mask them out and substitute 0 so
                # the tensors are rectangular.
                mask = torch.Tensor([[x is not None for x in tb]
                                     for tb in target_batch])
                targets = torch.Tensor([[0 if x is None else x for x in tb]
                                        for tb in target_batch])

            if next(model.parameters()).is_cuda:
                mask, targets = mask.cuda(), targets.cuda()

                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    features_targets = features_targets.cuda()

            if args.class_balance:
                # Look up the precomputed per-class weight for each molecule's
                # label in every task; result is num_molecules x num_tasks.
                class_weights = []
                for task_num in range(data.num_tasks()):
                    class_weights.append(
                        args.class_weights[task_num][targets[:,
                                                             task_num].long()])
                class_weights = torch.stack(
                    class_weights).t()  # num_molecules x num_tasks
            else:
                class_weights = torch.ones(targets.shape)

            if args.cuda:
                class_weights = class_weights.cuda()

            # Run model
            model.zero_grad()
            if args.parallel_featurization:
                previous_graph_input_mode = model.encoder.graph_input
                model.encoder.graph_input = True  # force model to accept already processed input
                preds = model(featurized_mol_batch, features_batch)
                model.encoder.graph_input = previous_graph_input_mode
            else:
                preds = model(batch, features_batch)

            if args.dataset_type == 'regression_with_binning':
                preds = preds.view(targets.size(0), targets.size(1), -1)
                targets = targets.long()
                loss = 0
                for task in range(targets.size(1)):
                    loss += loss_func(
                        preds[:, task, :], targets[:, task]
                    ) * class_weights[:, task] * mask[:, task]  # for some reason cross entropy doesn't support multi target
                loss = loss.sum() / mask.sum()
            else:
                if args.dataset_type == 'unsupervised':
                    targets = targets.long().reshape(-1)
                if args.dataset_type == 'bert_pretraining':
                    features_preds, preds = preds['features'], preds['vocab']
                if args.dataset_type == 'kernel':
                    # Pair up consecutive predictions before the kernel head.
                    preds = preds.view(int(preds.size(0) / 2), 2,
                                       preds.size(1))
                    preds = model.kernel_output_layer(preds)
                loss = loss_func(preds, targets) * class_weights * mask
                if args.predict_features_and_task:
                    # Up-weight the task columns (everything before the
                    # trailing features block) by task_weight.
                    loss = (loss.sum() + loss[:, :-args.features_size].sum() * (args.task_weight-1)) \
                        / (mask.sum() + mask[:, :-args.features_size].sum() * (args.task_weight-1))
                else:
                    loss = loss.sum() / mask.sum()

                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    loss += features_loss(features_preds, features_targets)

            loss_sum += loss.item()
            iter_count += len(mol_batch)

        if args.maml:
            # Outer-loop step: build a model with the adapted parameters and
            # evaluate it on the task's test split; gradients accumulate on
            # the original model until a MAML meta-batch is complete.
            model_prime = build_model(args=args, params=theta_prime)
            smiles_batch, features_batch, target_batch = task_test_data.smiles(
            ), task_test_data.features(), [
                t[task_idx] for t in task_test_data.targets()
            ]
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor([[t] for t in target_batch])
            if next(model_prime.parameters()).is_cuda:
                targets = targets.cuda()
            model_prime.zero_grad()
            preds = model_prime(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            loss_sum += loss.item()
            iter_count += len(
                smiles_batch
            )  # TODO check that this makes sense, but it's just for display
            maml_sum_loss += loss
            if i % args.maml_batch_size == args.maml_batch_size - 1:
                maml_sum_loss.backward()
                optimizer.step()
                model.zero_grad()
                maml_sum_loss = 0
        else:
            loss.backward()
            if args.max_grad_norm is not None:
                clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()

        if args.adjust_weight_decay:
            # Steer the parameter norm toward pnorm_target by loosening or
            # tightening weight decay each iteration.
            current_pnorm = compute_pnorm(model)
            if current_pnorm < args.pnorm_target:
                for i in range(len(optimizer.param_groups)):
                    optimizer.param_groups[i]['weight_decay'] = max(
                        0, optimizer.param_groups[i]['weight_decay'] -
                        args.adjust_weight_decay_step)
            else:
                for i in range(len(optimizer.param_groups)):
                    optimizer.param_groups[i][
                        'weight_decay'] += args.adjust_weight_decay_step

        # NoamLR is stepped per batch; other schedulers are presumably stepped
        # per epoch by the caller — TODO confirm against the training driver.
        if isinstance(scheduler, NoamLR):
            scheduler.step()

        if args.adversarial:
            # Train the discriminator gan_d_per_g times per generator step.
            for _ in range(args.gan_d_per_g):
                train_val_smiles_batch = random.sample(train_val_smiles,
                                                       args.batch_size)
                test_smiles_batch = random.sample(test_smiles,
                                                  args.batch_size)
                d_loss, gp_norm = model.train_D(train_val_smiles_batch,
                                                test_smiles_batch)
            train_val_smiles_batch = random.sample(train_val_smiles,
                                                   args.batch_size)
            test_smiles_batch = random.sample(test_smiles, args.batch_size)
            g_loss = model.train_G(train_val_smiles_batch, test_smiles_batch)

            # we probably only care about the g_loss honestly
            d_loss_sum += d_loss * args.batch_size
            gp_norm_sum += gp_norm * args.batch_size
            g_loss_sum += g_loss * args.batch_size

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            if args.adversarial:
                d_loss_avg, g_loss_avg, gp_norm_avg = d_loss_sum / iter_count, g_loss_sum / iter_count, gp_norm_sum / iter_count
                d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0
            loss_sum, iter_count = 0, 0  # reset the running totals for the next window

            lrs_str = ', '.join('lr_{} = {:.4e}'.format(i, lr)
                                for i, lr in enumerate(lrs))
            debug("Loss = {:.4e}, PNorm = {:.4f}, GNorm = {:.4f}, {}".format(
                loss_avg, pnorm, gnorm, lrs_str))
            if args.adversarial:
                debug(
                    "D Loss = {:.4e}, G Loss = {:.4e}, GP Norm = {:.4}".format(
                        d_loss_avg, g_loss_avg, gp_norm_avg))

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar('learning_rate_{}'.format(i), lr, n_iter)

    if args.parallel_featurization:
        exit_queue.put(
            0)  # dummy var to get the subprocess to know that we're done
        batch_process.join()

    return n_iter
def train(model: nn.Module,
          data_loader: MoleculeDataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: Optional[Logger] = None,
          writer: Optional = None) -> int:
    """Trains a model for an epoch

    Parameters
    ----------
    model : nn.Module
        the model to train
    data_loader : MoleculeDataLoader
        an iterable of MoleculeDatasets
    loss_func : Callable
        the loss function
    optimizer : Optimizer
        the optimizer
    scheduler : _LRScheduler
        the learning rate scheduler; stepped per batch only when it is a NoamLR
    args : Namespace
        a Namespace object the containing necessary attributes
        (currently unused; kept for interface compatibility)
    n_iter : int
        the current number of training iterations
    logger : Optional[Logger] (Default = None)
        a logger for printing intermediate results. If None, print
        intermediate results to stdout (currently unused)
    writer : Optional[SummaryWriter] (Default = None)
        A tensorboardX SummaryWriter (currently unused)

    Returns
    -------
    n_iter : int
        The total number of iterations (training examples) trained on so far
    """
    model.train()

    for batch in tqdm(data_loader, desc='Step', unit='minibatch', leave=False):
        # Prepare batch
        mol_batch = batch.batch_graph()
        features_batch = batch.features()

        # Run model
        model.zero_grad()
        preds = model(mol_batch, features_batch)

        targets = batch.targets()
        # Missing targets are represented as None. Test `is None` explicitly:
        # the previous `map(bool, ys)` mask treated a valid target of 0
        # (falsy) as missing, silently dropping every zero-valued label from
        # the loss.
        mask = torch.Tensor(
            [[y is not None for y in ys] for ys in targets]).to(preds.device)
        targets = torch.Tensor(
            [[0 if y is None else y for y in ys]
             for ys in targets]).to(preds.device)
        class_weights = torch.ones(targets.shape).to(preds.device)

        if model.uncertainty:
            # Uncertainty head interleaves means and variances along dim 1.
            pred_means = preds[:, 0::2]
            pred_vars = preds[:, 1::2]
            loss = loss_func(pred_means, pred_vars, targets)
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        # Mean over the unmasked entries only.
        loss = loss.sum() / mask.sum()

        loss.backward()
        optimizer.step()

        # NoamLR is stepped per batch; other schedulers are stepped elsewhere.
        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

    return n_iter