def train(dataloader: torch.utils.data.DataLoader, model: torch.nn.Module, loss_func: Any, optimizer: torch.optim.Optimizer, epoch: int, args: argparse.Namespace):
    """
    Train the given model for a single epoch using the given dataloader.

    Args:
        dataloader: The dataloader containing the training data.
        model: Instance of the model that is being trained.
        loss_func: A loss function to compute the error between the actual
            and the desired output of the model.
        optimizer: An instance of an optimizer that is used to compute and
            perform the updates to the weights of the network.
        epoch: The current training epoch.
        args: Namespace object containing global variables (e.g., command
            line arguments, such as the batch size).
    """

    # -------------------------------------------------------------------------
    # Preliminaries
    # -------------------------------------------------------------------------

    # Activate training mode
    model.train()

    # Keep track of the time to process a batch, as well as the batch losses
    batch_times = AverageMeter()
    batch_losses = AverageMeter()

    # -------------------------------------------------------------------------
    # Process the training dataset in mini-batches
    # -------------------------------------------------------------------------
    for batch_idx, (data, target) in enumerate(dataloader):

        # Initialize start time of the batch
        batch_start = time.time()

        # Fetch data and move to device
        data, target = data.to(args.device), target.to(args.device)
        target = target.squeeze()

        # Clear gradients
        optimizer.zero_grad()

        # Compute forward pass through model
        output = model(data).squeeze()

        # Calculate the loss for the batch
        loss = loss_func(output, target)

        # Back-propagate the loss and update the weights
        loss.backward()
        optimizer.step()

        # ---------------------------------------------------------------------
        # Log information about current batch to TensorBoard
        # ---------------------------------------------------------------------
        if args.tensorboard:
            # Compute how many examples we have processed already and log the
            # loss value for the current batch
            global_step = ((epoch - 1) * args.n_train_batches + batch_idx) * \
                args.batch_size
            args.logger.add_scalar(tag='loss/train',
                                   scalar_value=loss.item(),
                                   global_step=global_step)

        # ---------------------------------------------------------------------
        # Additional logging to console
        # ---------------------------------------------------------------------
        # Store the loss and processing time for the current batch
        batch_losses.update(loss.item())
        batch_times.update(time.time() - batch_start)

        # Print information to console, if applicable
        if batch_idx % args.log_interval == 0:
            # Which fraction of batches have we already processed this epoch?
            percent = 100. * batch_idx / args.n_train_batches

            # Print some information about how the training is going
            print(f'Epoch: {epoch:>3}/{args.epochs}', end=' | ', flush=True)
            print(f'Batch: {batch_idx:>3}/{args.n_train_batches}',
                  flush=True, end=' ')
            print(f'({percent:>4.1f}%)', end=' | ', flush=True)
            print(f'Loss: {loss.item():.6f}', end=' | ', flush=True)
            print(f'Time: {batch_times.value:>6.3f}s', flush=True)
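# The AverageMeter helper used in train() above is defined elsewhere in this
# code base. A minimal sketch that is consistent with the .update(...) and
# .value usage above is given below; this is an assumption about its shape
# (the upstream implementation may track more statistics), not the verbatim
# helper.
class AverageMeterSketch:
    """Tracks the most recent value and a running average of a series."""

    def __init__(self):
        self.value = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, value: float, n: int = 1):
        # remember the latest value and fold it into the running average
        self.value = value
        self.sum += value * n
        self.count += n

    @property
    def average(self) -> float:
        return self.sum / max(self.count, 1)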
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, max_norm: float = 0, neptune=None): model.train() criterion.train() metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter( 'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) metric_logger.add_meter( 'class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) header = 'Epoch: [{}]'.format(epoch) print_freq = 10 for samples, targets in metric_logger.log_every(data_loader, print_freq, header): samples = samples.to(device) targets = [{k: v.to(device) for k, v in t.items()} for t in targets] outputs = model(samples) loss_dict = criterion(outputs, targets) weight_dict = criterion.weight_dict losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) # reduce losses over all GPUs for logging purposes loss_dict_reduced = utils.reduce_dict(loss_dict) loss_dict_reduced_unscaled = { f'{k}_unscaled': v for k, v in loss_dict_reduced.items() } loss_dict_reduced_scaled = { k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict } losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) loss_value = losses_reduced_scaled.item() if neptune: neptune.log_metric('train/loss', loss_value) if not math.isfinite(loss_value): print("Loss is {}, stopping training".format(loss_value)) print(loss_dict_reduced) sys.exit(1) optimizer.zero_grad() losses.backward() if max_norm > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) optimizer.step() metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) metric_logger.update(class_error=loss_dict_reduced['class_error']) metric_logger.update(lr=optimizer.param_groups[0]["lr"]) # gather the stats from all processes metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
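# `utils.reduce_dict` above is imported from elsewhere. A minimal sketch of
# the common DETR-style pattern it follows (average every scalar loss across
# all distributed processes, for logging only) is shown below; this is an
# assumption about its behaviour, not the verbatim upstream utility.
import torch
import torch.distributed as dist

def reduce_dict_sketch(input_dict, average=True):
    """All-reduce the values of a dict of 0-dim tensors across processes."""
    if not (dist.is_available() and dist.is_initialized()):
        return input_dict
    world_size = dist.get_world_size()
    with torch.no_grad():
        names = sorted(input_dict.keys())  # keep key order identical on all ranks
        values = torch.stack([input_dict[k] for k in names])
        dist.all_reduce(values)
        if average:
            values /= world_size
        return {k: v for k, v in zip(names, values)}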
def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs): optimizer.step(closure=lambda_closure, **kwargs)
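# `optimizer_step` above forwards `lambda_closure` because some optimizers,
# such as torch.optim.LBFGS, re-evaluate the model several times per step and
# therefore require a closure. A hypothetical usage sketch (model, data,
# target and loss_func are stand-ins, not names from this code base):
import torch

def _lbfgs_closure_sketch(model, data, target, loss_func):
    optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)

    def closure():
        # LBFGS calls this repeatedly to re-evaluate loss and gradients
        optimizer.zero_grad()
        loss = loss_func(model(data), target)
        loss.backward()
        return loss

    optimizer.step(closure=closure)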
def train_epoch_kontschieder(tree: ProtoTree,
                             train_loader: DataLoader,
                             optimizer: torch.optim.Optimizer,
                             epoch: int,
                             disable_derivative_free_leaf_optim: bool,
                             device,
                             log: Log = None,
                             log_prefix: str = 'log_train_epochs',
                             progress_prefix: str = 'Train Epoch') -> dict:

    tree = tree.to(device)

    # Store info about the procedure
    train_info = dict()
    total_loss = 0.
    total_acc = 0.

    # Create a log if required
    log_loss = f'{log_prefix}_losses'
    if log is not None and epoch == 1:
        log.create_log(log_loss, 'epoch', 'batch', 'loss', 'batch_train_acc')

    # Reset the gradients
    optimizer.zero_grad()

    if disable_derivative_free_leaf_optim:
        print("WARNING: kontschieder arguments will be ignored when training leaves with gradient descent")
    else:
        if tree._kontschieder_normalization:
            # Iterate over the dataset multiple times to learn leaves following Kontschieder's approach
            for _ in range(10):
                # Train leaves with derivative-free algorithm using normalization factor
                train_leaves_epoch(tree, train_loader, epoch, device)
        else:
            # Train leaves with Kontschieder's derivative-free algorithm, but using softmax
            train_leaves_epoch(tree, train_loader, epoch, device)

    # Train prototypes and network.
    # If disable_derivative_free_leaf_optim, leaves are optimized with gradient descent as well.
    # Show progress on progress bar
    train_iter = tqdm(enumerate(train_loader),
                      total=len(train_loader),
                      desc=progress_prefix + ' %s' % epoch,
                      ncols=0)

    # Make sure the model is in train mode
    tree.train()
    for i, (xs, ys) in train_iter:
        xs, ys = xs.to(device), ys.to(device)

        # Reset the gradients
        optimizer.zero_grad()
        # Perform a forward pass through the network
        ys_pred, _ = tree.forward(xs)
        # Compute the loss
        if tree._log_probabilities:
            loss = F.nll_loss(ys_pred, ys)
        else:
            loss = F.nll_loss(torch.log(ys_pred), ys)
        # Compute the gradient
        loss.backward()
        # Update model parameters
        optimizer.step()

        # Count the number of correct classifications
        ys_pred = torch.argmax(ys_pred, dim=1)
        correct = torch.sum(torch.eq(ys_pred, ys))
        acc = correct.item() / float(len(xs))

        train_iter.set_postfix_str(
            f'Batch [{i + 1}/{len(train_loader)}], Loss: {loss.item():.3f}, Acc: {acc:.3f}'
        )
        # Compute metrics over this batch
        total_loss += loss.item()
        total_acc += acc

        if log is not None:
            log.log_values(log_loss, epoch, i + 1, loss.item(), acc)

    train_info['loss'] = total_loss / float(i + 1)
    train_info['train_accuracy'] = total_acc / float(i + 1)
    return train_info
def train(model: torch.nn.Module,
          train_dls: List[DataLoader],
          optimizer: torch.optim.Optimizer,
          scheduler: LambdaLR,
          validation_evaluator: MultiDatasetClassificationEvaluator,
          n_epochs: int,
          device: AnyStr,
          log_interval: int = 1,
          patience: int = 10,
          model_dir: str = "wandb_local",
          gradient_accumulation: int = 1,
          domain_name: str = ''):
    best_f1 = 0.0
    patience_counter = 0
    epoch_counter = 0
    total = sum(len(dl) for dl in train_dls)

    # Main loop
    while epoch_counter < n_epochs:
        dl_iters = [iter(dl) for dl in train_dls]
        dl_idx = list(range(len(dl_iters)))
        finished = [0] * len(dl_iters)
        i = 0
        with tqdm(total=total, desc="Training") as pbar:
            while sum(finished) < len(dl_iters):
                random.shuffle(dl_idx)
                for d in dl_idx:
                    domain_dl = dl_iters[d]
                    batches = []
                    try:
                        for j in range(gradient_accumulation):
                            batches.append(next(domain_dl))
                    except StopIteration:
                        finished[d] = 1
                        if len(batches) == 0:
                            continue
                    optimizer.zero_grad()
                    for batch in batches:
                        model.train()
                        batch = tuple(t.to(device) for t in batch)
                        input_ids = batch[0]
                        masks = batch[1]
                        labels = batch[2]
                        # Null the labels if it's the test data
                        if d == len(train_dls) - 1:
                            labels = None
                        domains = batch[3]
                        loss, logits, alpha = model(input_ids,
                                                    attention_mask=masks,
                                                    domains=domains,
                                                    labels=labels,
                                                    ret_alpha=True)
                        loss = loss.mean() / gradient_accumulation
                        if i % log_interval == 0:
                            wandb.log({"Loss": loss.item()})

                        loss.backward()
                        i += 1
                        pbar.update(1)

                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()

        gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        print(f"Validation F1: {F1}")

        # Saving the best model and early stopping
        if F1 > best_f1:
            best_model = model.state_dict()
            best_f1 = F1
            torch.save(
                model.state_dict(),
                f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth'
            )
            patience_counter = 0
            # Log to wandb
            wandb.log({
                'Validation accuracy': acc,
                'Validation Precision': P,
                'Validation Recall': R,
                'Validation F1': F1,
                'Validation loss': val_loss
            })
        else:
            patience_counter += 1
            # Stop training once we have lost patience
            if patience_counter == patience:
                break

        gc.collect()
        epoch_counter += 1
def train_one_epoch(model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    max_norm: float = 0,
                    writer=None,
                    args=None):
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    losses_items = []
    FNs, FPs, TPs, AVGs, TAR = [], [], [], [], []
    for i, (samples, targets, info) in enumerate(data_loader):
        samples = samples.to(device)
        targets = [t.to(device) for t in targets]

        outputs = model(samples)
        loss_dict, indices = criterion(outputs, targets)

        if epoch % 50 == 0 or epoch == (args.epochs - 1):
            step = (epoch * len(data_loader) + i) * args.batch_size
            plot_images(writer,
                        step,
                        samples,
                        outputs,
                        targets,
                        indices,
                        epoch,
                        i,
                        tag='train',
                        folder=args.comment)

        for d in range(len(samples)):
            FN, FP, TP, in_dist = spine_evaluation(outputs['pred_boxes'][d],
                                                   outputs['pred_logits'][d],
                                                   targets[d][:, 1:3], info[d],
                                                   args)
            FNs.append(FN)
            FPs.append(FP)
            TPs.append(TP)
            TAR.append(len(targets[d]))
            AVGs.append(in_dist)

        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys()
                     if k in weight_dict).float()
        not_used_keys = [
            k for k in loss_dict.keys() if k not in weight_dict.keys()
        ]
        if len(not_used_keys) > 0 and i == 0:
            print(
                f'[WARNING] these keys are not used to calculate the loss: {not_used_keys}'
            )

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {
            f'{k}_unscaled': v
            for k, v in loss_dict_reduced.items()
        }
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items() if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
        loss_value = losses_reduced_scaled.item()
        losses_items.append(loss_value)
        print(
            f"{epoch:03d}_{i:03d} loss_value: {loss_value:.04f} mean {mean(losses_items):.04f} loss_centers {loss_dict['loss_centers'].item():.04f} loss_bce {loss_dict['loss_bce'].item():.04f} loss_spine_l1 {loss_dict['loss_spine_l1'].item():.04f} id: {info[0]['patient_id']}"
        )

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        metric_logger.update(loss=loss_value,
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    writer.add_scalar('train_metric/FN', sum(FNs) / sum(TAR), global_step=epoch)
    writer.add_scalar('train_metric/FP', sum(FPs) / sum(TAR), global_step=epoch)
    writer.add_scalar('train_metric/TP', sum(TPs) / sum(TAR), global_step=epoch)
    if len(torch.cat(AVGs)) > 0:
        writer.add_scalar('train_metric/avg_dist',
                          torch.cat(AVGs).mean(),
                          global_step=epoch)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def train_epoch(model: nn.Module, optimizer: torch.optim.Optimizer,
                loss_func: nn.Module, loader: DataLoader, cfg: Dict,
                epoch: int, use_mse: bool):
    """Train model for a single epoch.

    Parameters
    ----------
    model : nn.Module
        The PyTorch model to train
    optimizer : torch.optim.Optimizer
        Optimizer used for weight updating
    loss_func : nn.Module
        The loss function, implemented as a PyTorch Module
    loader : DataLoader
        PyTorch DataLoader containing the training data in batches.
    cfg : Dict
        Dictionary containing the run config
    epoch : int
        The current epoch number
    use_mse : bool
        If True, loss_func is nn.MSELoss(), else NSELoss(), which expects an
        additional std-of-discharge vector
    """
    model.train()

    # progress bar handle
    pbar = tqdm(loader, file=sys.stdout)
    pbar.set_description(f'# Epoch {epoch}')

    # Iterate in batches over training set
    for data in pbar:
        # delete old gradients
        optimizer.zero_grad()

        # forward pass through LSTM
        if len(data) == 3:
            x, y, q_stds = data
            x, y, q_stds = x.to(DEVICE), y.to(DEVICE), q_stds.to(DEVICE)
            predictions = model(x)[0]

        # forward pass through EALSTM
        elif len(data) == 4:
            x_d, x_s, y, q_stds = data
            x_d, x_s, y = x_d.to(DEVICE), x_s.to(DEVICE), y.to(DEVICE)
            predictions = model(x_d, x_s)[0]

        # MSELoss
        mask = ~torch.isnan(predictions)
        if use_mse:
            loss = loss_func(predictions[mask], y[mask])

        # NSELoss needs std of each basin for each sample
        else:
            q_stds = q_stds.to(DEVICE)
            loss = loss_func(predictions[mask], y[mask], q_stds)

        # calculate gradients
        loss.backward()

        if cfg["clip_norm"]:
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           cfg["clip_value"])

        # perform parameter update
        optimizer.step()

        pbar.set_postfix_str(f"Loss: {loss.item():5f}")
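# The NSELoss class referenced in train_epoch's docstring lives elsewhere in
# this code base. A minimal sketch of the usual basin-wise formulation
# (squared errors scaled by each basin's discharge std, as in Kratzert et
# al.'s NSE-based loss) is given below; the exact upstream class may differ
# in how it broadcasts q_stds over the batch.
import torch
import torch.nn as nn

class NSELossSketch(nn.Module):
    """Squared error weighted by 1 / (std + eps)^2 per sample."""

    def __init__(self, eps: float = 0.1):
        super().__init__()
        self.eps = eps

    def forward(self, prediction: torch.Tensor, target: torch.Tensor,
                q_stds: torch.Tensor) -> torch.Tensor:
        squared_error = (prediction - target) ** 2
        # eps keeps the weights finite for basins with near-zero variance
        weights = 1.0 / (q_stds + self.eps) ** 2
        return torch.mean(weights * squared_error)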
def attach(self, optimizer: torch.optim.Optimizer):
    r"""
    Attaches the privacy engine to the optimizer.

    Attaches the ``PrivacyEngine`` to an optimizer object, and injects itself
    into the optimizer's step. To do so, it:

    1. Validates that the model does not have unsupported layers.

    2. Adds a pointer to this object (the ``PrivacyEngine``) inside the optimizer.

    3. Moves the optimizer's original ``step()`` function to ``original_step()``.

    4. Monkeypatches the optimizer's ``step()`` function to call ``step()`` on
    the privacy engine automatically whenever it would call ``step()`` for itself.

    Args:
        optimizer: The optimizer to which the privacy engine will attach
    """
    if hasattr(optimizer, "privacy_engine"):
        if optimizer.privacy_engine != self:
            raise ValueError(
                f"Trying to attach to optimizer: {optimizer}, but that optimizer is "
                f"already attached to a different Privacy Engine: {optimizer.privacy_engine}."
            )
        else:
            warnings.warn(
                "Trying to attach twice to the same optimizer. Nothing to do."
            )
            return

    self.validator.validate(self.module)
    norm_clipper = (clipping.ConstantFlatClipper(self.max_grad_norm)
                    if not isinstance(self.max_grad_norm, list) else
                    clipping.ConstantPerLayerClipper(self.max_grad_norm))

    if self.misc_settings.get("experimental", False):
        norm_clipper = clipping._Dynamic_Clipper_(
            [self.max_grad_norm],
            self.misc_settings.get("clip_per_layer", False),
            self.misc_settings.get("clipping_method",
                                   clipping.ClippingMethod.STATIC),
            self.misc_settings.get("clipping_ratio", 0.0),
            self.misc_settings.get("clipping_momentum", 0.0),
        )

    self.clipper = PerSampleGradientClipper(
        self.module,
        norm_clipper,
        self.batch_first,
        self.loss_reduction,
    )

    def dp_zero_grad(self):
        self.privacy_engine.zero_grad()
        self.original_zero_grad()

    def dp_step(self, closure=None, is_empty=False):
        self.privacy_engine.step(is_empty)
        if isinstance(self.privacy_engine.module,
                      DifferentiallyPrivateDistributedDataParallel):
            average_gradients(self.privacy_engine.module)
        self.original_step(closure)

    def poisson_dp_step(self, closure=None):
        # Perform one step as usual
        self.dp_step(closure)

        # Taking empty steps to simulate empty batches
        num_empty_batches = self.privacy_engine._sample_poisson_empty_batches()
        for _ in range(num_empty_batches):
            self.zero_grad()
            self.dp_step(closure, is_empty=True)

    optimizer.privacy_engine = self

    optimizer.dp_step = types.MethodType(dp_step, optimizer)
    optimizer.original_step = optimizer.step
    optimizer.step = types.MethodType(
        poisson_dp_step if self.poisson else dp_step, optimizer)

    optimizer.original_zero_grad = optimizer.zero_grad
    optimizer.zero_grad = types.MethodType(dp_zero_grad, optimizer)

    def virtual_step(self):
        self.privacy_engine.virtual_step()

    optimizer.virtual_step = types.MethodType(virtual_step, optimizer)

    # create a cross reference for detaching
    self.optimizer = optimizer

    if self.poisson:
        # Optional initial step on empty batch
        num_empty_batches = self._sample_poisson_empty_batches()
        for _ in range(num_empty_batches):
            self.optimizer.zero_grad()
            for p in self.module.parameters():
                if p.requires_grad:
                    p.grad = torch.zeros_like(p)
            self.optimizer.dp_step(closure=None, is_empty=True)
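# The core mechanism in attach() above is swapping a bound method at runtime
# with types.MethodType. A self-contained toy demonstrating just that pattern
# on a plain object (no privacy logic, illustration only):
import types

class _ToyOptimizer:
    def step(self):
        print("original step")

def _demo_monkeypatch_step():
    toy = _ToyOptimizer()
    toy.original_step = toy.step  # keep a handle on the original bound method

    def wrapped_step(self):
        print("before step")  # extra behaviour injected around the original
        self.original_step()

    # rebind step() so all future calls go through the wrapper
    toy.step = types.MethodType(wrapped_step, toy)
    toy.step()  # prints "before step", then "original step"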
def train(
    dataset: torch.utils.data.Dataset,
    model: torch.nn.Module,
    epochs: int,
    batch_size: int,
    optimizer: torch.optim.Optimizer,
    stopping_delta: Optional[float] = None,
    collate_fn=default_collate,
    cuda: bool = True,
    sampler: Optional[torch.utils.data.sampler.Sampler] = None,
    silent: bool = False,
    update_freq: int = 10,
    evaluate_batch_size: int = 1024,
    update_callback: Optional[Callable[[float, float, float], None]] = None,
    epoch_callback: Optional[Callable[[int, torch.nn.Module], None]] = None,
) -> None:
    """
    Train the DEC model given a dataset, a model instance and various configuration parameters.

    :param dataset: instance of Dataset to use for training
    :param model: instance of DEC model to train
    :param epochs: number of training epochs
    :param batch_size: size of the batch to train with
    :param optimizer: instance of optimizer to use
    :param stopping_delta: label delta as a proportion to use for stopping, None to disable, default None
    :param collate_fn: function to merge a list of samples into mini-batch
    :param cuda: whether to use CUDA, defaults to True
    :param sampler: optional sampler to use in the DataLoader, defaults to None
    :param silent: set to True to prevent printing out summary statistics, defaults to False
    :param update_freq: frequency of batches with which to update counter, None disables, default 10
    :param evaluate_batch_size: batch size for evaluation stage, default 1024
    :param update_callback: optional function of accuracy, loss and label delta, called on update, default None
    :param epoch_callback: optional function of epoch and model, default None
    :return: None
    """
    static_dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        pin_memory=False,
        sampler=sampler,
        shuffle=False,
    )
    train_dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        sampler=sampler,
        shuffle=True,
    )
    data_iterator = tqdm(
        static_dataloader,
        leave=True,
        unit="batch",
        postfix={
            "epo": -1,
            "acc": "%.4f" % 0.0,
            "lss": "%.8f" % 0.0,
            "dlb": "%.4f" % -1,
        },
        disable=silent,
    )
    kmeans = KMeans(n_clusters=model.cluster_number, n_init=20)
    model.train()
    features = []
    actual = []
    # form initial cluster centres
    for index, batch in enumerate(data_iterator):
        if (isinstance(batch, tuple) or isinstance(batch, list)) and len(batch) == 2:
            batch, value = batch  # if we have a prediction label, separate it to actual
            actual.append(value)
        if cuda:
            batch = batch.cuda(non_blocking=True)
        features.append(model.encoder(batch).detach().cpu())
    actual = torch.cat(actual).long()
    predicted = kmeans.fit_predict(torch.cat(features).numpy())
    predicted_previous = torch.tensor(np.copy(predicted), dtype=torch.long)
    _, accuracy = cluster_accuracy(predicted, actual.cpu().numpy())
    cluster_centers = torch.tensor(
        kmeans.cluster_centers_, dtype=torch.float, requires_grad=True
    )
    if cuda:
        cluster_centers = cluster_centers.cuda(non_blocking=True)
    with torch.no_grad():
        # initialise the cluster centers
        model.state_dict()["assignment.cluster_centers"].copy_(cluster_centers)
    # reduction="sum" is the non-deprecated equivalent of size_average=False
    loss_function = nn.KLDivLoss(reduction="sum")
    delta_label = None
    for epoch in range(epochs):
        features = []
        data_iterator = tqdm(
            train_dataloader,
            leave=True,
            unit="batch",
            postfix={
                "epo": epoch,
                "acc": "%.4f" % (accuracy or 0.0),
                "lss": "%.8f" % 0.0,
                "dlb": "%.4f" % (delta_label or 0.0),
            },
            disable=silent,
        )
        model.train()
        for index, batch in enumerate(data_iterator):
            if (isinstance(batch, tuple) or isinstance(batch, list)) and len(
                batch
            ) == 2:
                batch, _ = batch  # if we have a prediction label, strip it away
            if cuda:
                batch = batch.cuda(non_blocking=True)
            output = model(batch)
            target = target_distribution(output).detach()
            loss = loss_function(output.log(), target) / output.shape[0]
            data_iterator.set_postfix(
                epo=epoch,
                acc="%.4f" % (accuracy or 0.0),
                lss="%.8f" % float(loss.item()),
                dlb="%.4f" % (delta_label or 0.0),
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            features.append(model.encoder(batch).detach().cpu())
            if update_freq is not None and index % update_freq == 0:
                loss_value = float(loss.item())
                data_iterator.set_postfix(
                    epo=epoch,
                    acc="%.4f" % (accuracy or 0.0),
                    lss="%.8f" % loss_value,
                    dlb="%.4f" % (delta_label or 0.0),
                )
                if update_callback is not None:
                    update_callback(accuracy, loss_value, delta_label)
        predicted, actual = predict(
            dataset,
            model,
            batch_size=evaluate_batch_size,
            collate_fn=collate_fn,
            silent=True,
            return_actual=True,
            cuda=cuda,
        )
        delta_label = (
            float((predicted != predicted_previous).float().sum().item())
            / predicted_previous.shape[0]
        )
        if stopping_delta is not None and delta_label < stopping_delta:
            print(
                'Early stopping as label delta "%1.5f" less than "%1.5f".'
                % (delta_label, stopping_delta)
            )
            break
        predicted_previous = predicted
        _, accuracy = cluster_accuracy(predicted.cpu().numpy(), actual.cpu().numpy())
        data_iterator.set_postfix(
            epo=epoch,
            acc="%.4f" % (accuracy or 0.0),
            lss="%.8f" % 0.0,
            dlb="%.4f" % (delta_label or 0.0),
        )
        if epoch_callback is not None:
            epoch_callback(epoch, model)
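# `target_distribution` above is imported from the DEC code base. The
# standard definition from Xie et al. (2016) sharpens the soft assignments
# q_ij by squaring and normalising per cluster; a sketch under the assumption
# that the imported helper follows that formula:
import torch

def target_distribution_sketch(q: torch.Tensor) -> torch.Tensor:
    """p_ij = (q_ij^2 / f_j) / sum_j' (q_ij'^2 / f_j'), with f_j = sum_i q_ij."""
    weight = q ** 2 / q.sum(dim=0)  # square and divide by cluster frequency
    return (weight.t() / weight.sum(dim=1)).t()  # renormalise per sample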
def fine_tune_train_and_eval(
    input_ids: torch.Tensor,
    token_type_ids: torch.Tensor,
    attention_masks: torch.Tensor,
    start_positions: torch.Tensor,
    end_positions: torch.Tensor,
    batch_size: Tuple[int, int],
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    train_ratio: float = 0.9,
    training_epochs: int = 3,
    lr_scheduler_warmup_steps: int = 0,
    save_model_path: Optional[str] = None,
    save_stats_dict_path: Optional[str] = None,
    device_: Optional[str] = None  # if None, it automatically detects if a GPU is available, if not uses a CPU
) -> Tuple[torch.nn.Module, Dict[str, Dict[str, Union[float, str]]]]:
    """
    Performs the fine tuning of the model and returns the trained model as well as a dictionary with
    evaluation statistics at each epoch, which can be used to check overfitting and training time.
    :param input_ids: torch.tensor of shape (N, max_len) representing the ids of each token of the N encoded
        sequence pairs, with padding at the end up to max_len. If decoded, the input_ids will consist of a
        "[CLS]" token, followed by the question's tokens, followed by a "[SEP]" token, followed by the
        context's tokens, followed by a "[SEP]" token, followed by "[PAD]" tokens, if relevant, up to max_len.
    :param token_type_ids: torch.tensor of shape (N, max_len) where each Nth dimension is filled with 1 for
        token positions in the context text, 0 elsewhere (i.e. in question and padding)
    :param attention_masks: torch.tensor of shape (N, max_len) where each Nth dimension is filled with 1 for
        non-"[PAD]" tokens, 0 for "[PAD]" tokens.
    :param start_positions: torch.tensor of shape (N) containing the index of the first answer token for each answer
    :param end_positions: torch.tensor of shape (N) containing the index of the last answer token for each answer
    :param batch_size: a tuple of 2 integers, representing the batch size of the train and validation
        dataloaders respectively.
    :param model: the model to use (must be instance of torch.nn.Module). For question answering,
        transformers.BertForQuestionAnswering is recommended.
    :param optimizer: the optimizer to use for the model (must be instance of torch.optim.Optimizer).
    :param train_ratio: the train / (train + validation) split ratio. Default: 0.9 (i.e. 90% of the input
        data will go to the train dataloader and 10% to the validation dataloader). The split is random.
    :param training_epochs: the number of training epochs. Default: 3.
    :param lr_scheduler_warmup_steps: the number of warmup steps of the learning rate scheduler. Default: 0.
        Note: the purpose of this scheduler is to update the learning rate over the course of the training.
        It is preferable for the learning rate to gradually get smaller and smaller so that training makes
        gradually finer adjustments to the weights as the loss gets smaller.
    :param save_model_path: if specified, the path where to save the model (should have '.pt' extension).
        The model will be saved at every epoch with the epoch suffix, for easy comparison. Default: None.
    :param save_stats_dict_path: if specified, the path where to save the dictionary of statistics
        (should have '.json' extension). Default: None.
    :param device_: if specified, the device used for the computations. Can be one of cpu, cuda, mkldnn,
        opengl, opencl, ideep, hip, msnpu. If set to None, it will default to GPU (cuda) if one is
        available, else it will use a CPU. Default: None
    :return: model: the fine tuned model.
        training_stats: a dictionary with a number of statistics. For each epoch, the training loss,
        validation loss, validation accuracy, training time and validation time are included.
    """
    assert all(
        [
            isinstance(i, torch.Tensor) for i in [
                input_ids, token_type_ids, attention_masks, start_positions,
                end_positions
            ]
        ]
    ), "Some inputs are not tensors. When training, start_positions and end_positions must be tensors, not lists."
    assert input_ids.shape == token_type_ids.shape == attention_masks.shape, "Some input shapes are incompatible."
    assert input_ids.shape[0] == len(start_positions) == len(
        end_positions), "Some input shapes are incompatible"

    train_dataloader, valid_dataloader = _build_dataloaders(
        input_ids, token_type_ids, attention_masks, start_positions,
        end_positions, batch_size, train_ratio)

    training_steps = training_epochs * len(
        train_dataloader)  # epochs * number of batches
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=lr_scheduler_warmup_steps,
        num_training_steps=training_steps)

    device = set_hardware_acceleration(default=device_)
    model = model.to(device)
    training_stats = {}
    for epoch in range(training_epochs):
        logger.info(
            f"Training epoch {epoch + 1} of {training_epochs}. Running training."
        )
        t_i = time()
        model.train()
        cumulative_train_loss_per_epoch = 0.
        for batch_num, batch in tqdm(enumerate(train_dataloader),
                                     total=len(train_dataloader)):
            logger.debug(
                f"Running training batch {batch_num + 1} of {len(train_dataloader)}."
            )
            batch_input_ids, batch_token_type_ids, batch_attention_masks, batch_start_positions, batch_end_positions = \
                batch[0].to(device), batch[1].to(device), batch[2].to(device), batch[3].to(device), batch[4].to(device)

            model.zero_grad()
            # model.zero_grad() and optimizer.zero_grad() are the same IF all model parameters are in that optimizer.
            # It could be safer to call model.zero_grad() if you have two or more optimizers for one model.
            loss, start_logits, end_logits = model(
                input_ids=batch_input_ids,
                attention_mask=batch_attention_masks,
                token_type_ids=batch_token_type_ids,
                start_positions=batch_start_positions,
                end_positions=batch_end_positions
            )  # BertForQuestionAnswering uses CrossEntropyLoss by default, no need to calculate explicitly

            cumulative_train_loss_per_epoch += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=1
            )  # clipping the norm of the gradients to 1.0 to help prevent the "exploding gradients" issue
            optimizer.step()  # update model parameters
            lr_scheduler.step()  # update the learning rate

        average_training_loss_per_batch = cumulative_train_loss_per_epoch / len(
            train_dataloader)
        training_time = format_time(time() - t_i)
        logger.info(f"Epoch {epoch + 1} took {training_time} to train.")
        logger.info(
            f"Average training loss: {average_training_loss_per_batch}. \n Running validation."
        )
        if torch.cuda.is_available():
            logger.info(f"GPU memory usage: \n{gpu_memory_usage()}")

        t_i = time()
        model.eval()
        pred_start = torch.tensor(
            [], dtype=torch.long, device=device)  # initialising tensors for storing results
        pred_end = torch.tensor([], dtype=torch.long, device=device)
        true_start = torch.tensor([], dtype=torch.long, device=device)
        true_end = torch.tensor([], dtype=torch.long, device=device)
        cumulative_eval_loss_per_epoch = 0
        for batch_num, batch in tqdm(enumerate(valid_dataloader),
                                     total=len(valid_dataloader)):
            logger.info(
                f"Running validation batch {batch_num + 1} of {len(valid_dataloader)}."
            )
            batch_input_ids, batch_token_type_ids, batch_attention_masks, batch_start_positions, batch_end_positions = \
                batch[0].to(device), batch[1].to(device), batch[2].to(device), batch[3].to(device), batch[4].to(device)
            with torch.no_grad():
                loss, start_logits, end_logits = model(
                    input_ids=batch_input_ids,
                    attention_mask=batch_attention_masks,
                    token_type_ids=batch_token_type_ids,
                    start_positions=batch_start_positions,
                    end_positions=batch_end_positions
                )  # if we pass it the true labels, i.e. start_positions and end_positions, it will also return the loss
                cumulative_eval_loss_per_epoch += loss.item()
                pred_start_positions = torch.argmax(start_logits, dim=1)
                pred_end_positions = torch.argmax(end_logits, dim=1)
                pred_start = torch.cat((pred_start, pred_start_positions))
                pred_end = torch.cat((pred_end, pred_end_positions))
                true_start = torch.cat((true_start, batch_start_positions))
                true_end = torch.cat((true_end, batch_end_positions))
            if torch.cuda.is_available():
                logger.debug(f"GPU memory usage: \n{gpu_memory_usage()}")
        total_correct_start = int(sum(pred_start == true_start))
        total_correct_end = int(sum(pred_end == true_end))
        total_correct = total_correct_start + total_correct_end
        total_indices = len(true_start) + len(true_end)
        average_validation_accuracy_per_epoch = total_correct / total_indices
        average_validation_loss_per_batch = cumulative_eval_loss_per_epoch / len(
            valid_dataloader)
        valid_time = format_time(time() - t_i)

        logger.info(f"Epoch {epoch + 1} took {valid_time} to validate.")
        logger.info(
            f"Average validation loss: {average_validation_loss_per_batch}.")
        logger.info(
            f"Average validation accuracy (out of 1): {average_validation_accuracy_per_epoch}."
        )
        if torch.cuda.is_available():
            logger.info(f"GPU memory usage: \n{gpu_memory_usage()}")

        training_stats[f"epoch_{epoch + 1}"] = {
            "training_loss": average_training_loss_per_batch,
            "valid_loss": average_validation_loss_per_batch,
            "valid_accuracy": average_validation_accuracy_per_epoch,
            "training_time": training_time,
            "valid_time": valid_time
        }
        if save_model_path is not None:
            save_model_path = save_model_path.split(".")[0]  # removing extension if present
            torch.save(model.state_dict(),
                       f"{save_model_path}_epoch_{epoch + 1}.pt")  # re-add .pt extension
        if save_stats_dict_path is not None:
            with open(save_stats_dict_path, "w") as file:
                json.dump(training_stats, file)
    return model, training_stats
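# `format_time` above is imported from elsewhere in this code base. A
# plausible minimal version consistent with how it is used (seconds elapsed
# -> readable string) would be the following sketch; the real helper may
# format differently.
import datetime

def format_time_sketch(elapsed_seconds: float) -> str:
    """Round to whole seconds and render as hh:mm:ss."""
    return str(datetime.timedelta(seconds=int(round(elapsed_seconds))))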
def train_model(epoch: int, opt: argparse.Namespace, conf: Dict, model: BiaffineParser, optimizer: torch.optim.Optimizer, train_batch: BatcherBase, valid_batch: Batcher, test_batch: Batcher, ix2label: Dict, best_valid: float, test_result: float): model.reset_timer() model.train() cnt = 0 start_time = time.time() witnessed_improved_valid_result = False total_loss, total_arc_loss, total_tag_loss, total_n_tags = 0., 0., 0., 0. for inputs, head_indices, head_tags, _ in train_batch.get(): cnt += 1 model.zero_grad() forward_output_dict = model.forward(inputs, head_tags, head_indices) n_tags = inputs['length'].sum().item() loss = forward_output_dict['loss'] total_loss += loss.item() * n_tags total_arc_loss += forward_output_dict['arc_loss'].item() * n_tags total_tag_loss += forward_output_dict['tag_loss'].item() * n_tags total_n_tags += n_tags loss.backward() if 'clip_grad' in conf['optimizer']: torch.nn.utils.clip_grad_norm_(model.parameters(), conf['optimizer']['clip_grad']) optimizer.step() if cnt % opt.report_steps == 0: logger.info( "| epoch {:3d} | step {:>6d} | lr {:.3g} | ms/batch {:5.2f} | loss {:.4f} " "(arc {:.4f} rel {:.4f}) |".format( epoch, cnt, optimizer.param_groups[0]['lr'], 1000 * (time.time() - start_time) / opt.report_steps, total_loss / total_n_tags, total_arc_loss / total_n_tags, total_tag_loss / total_n_tags)) start_time = time.time() if cnt % opt.eval_steps == 0: eval_time = time.time() valid_result = eval_model(model, valid_batch, ix2label, opt, opt.gold_valid_path) logging_str = "| epoch {:3d} | step {:>6d} | lr {:.3g} | loss {:.4f} (arc {:.4f} rel {:.4f}) " \ "| dev {:.4f} |".format(epoch, cnt, optimizer.param_groups[0]['lr'], total_loss / total_n_tags, total_arc_loss / total_n_tags, total_tag_loss / total_n_tags, valid_result) if valid_result > best_valid: logging_str = logging_str + ' NEW |' logger.info(logging_str) if valid_result > best_valid: witnessed_improved_valid_result = True torch.save(model.state_dict(), os.path.join(opt.model, 'model.pkl')) best_valid = valid_result if test_batch is not None: test_result = eval_model(model, test_batch, ix2label, opt, opt.gold_test_path) logging_str = "| epoch {:3d} | step {:>6d} | lr {:.3g} | " \ "| test {:.4f} |".format(epoch, cnt, optimizer.param_groups[0]['lr'], test_result) logger.info(logging_str) eval_time = time.time() - eval_time start_time += eval_time logging_str = "| epoch {:3d} | step {:>6d} | lr {:.3g} | loss {:.4f} " \ "(arc {:.4f} rel {:.4f}) |".format(epoch, cnt, optimizer.param_groups[0]['lr'], total_loss / total_n_tags, total_arc_loss / total_n_tags, total_tag_loss / total_n_tags) logger.info(logging_str) logger.info( "| time tracking | input {:.2f}s | context {:.2f}s | classification {:.2f}s" .format(model.input_encoding_timer.total_eclipsed_time(), model.context_encoding_timer.total_eclipsed_time(), model.classification_timer.total_eclipsed_time())) return best_valid, test_result, witnessed_improved_valid_result
def train_one_epoch(model: torch.nn.Module, data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    criterion: torch.nn.Module, device: torch.device,
                    epoch: int, summary: TensorboardSummary,
                    max_norm: float = 0, amp: object = None):
    """
    train model for 1 epoch
    """
    model.train()
    criterion.train()

    # initialize stats
    train_stats = {
        'l1': 0.0,
        'occ_be': 0.0,
        'l1_raw': 0.0,
        'iou': 0.0,
        'rr': 0.0,
        'epe': 0.0,
        'error_px': 0.0,
        'total_px': 0.0
    }

    tbar = tqdm(data_loader)
    for idx, data in enumerate(tbar):
        # forward pass
        _, losses, sampled_disp = forward_pass(model, data, device, criterion,
                                               train_stats)

        # terminate training if the loss exploded
        if not math.isfinite(losses['aggregated'].item()):
            print("Loss is {}, stopping training".format(
                losses['aggregated'].item()))
            sys.exit(1)

        # backprop
        optimizer.zero_grad()
        if amp is not None:
            with amp.scale_loss(losses['aggregated'], optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            losses['aggregated'].backward()

        # clip norm
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        # step optimizer
        optimizer.step()

        print('pixel_error', losses['error_px'] / losses['total_px'])

        # clear cache
        torch.cuda.empty_cache()

    # compute avg
    train_stats['px_error_rate'] = train_stats['error_px'] / train_stats['total_px']

    # log to tensorboard
    write_summary(train_stats, summary, epoch, 'train')

    print('Training loss', train_stats['l1'], 'pixel error rate',
          train_stats['px_error_rate'])
    print('RR loss', train_stats['rr'])
    return
def train_one_epoch(model: torch.nn.Module, d_vae: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, log_writer=None, lr_scheduler=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None): model.train() metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter( 'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) metric_logger.add_meter( 'min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) header = 'Epoch: [{}]'.format(epoch) print_freq = 10 for step, (batch, _) in enumerate( metric_logger.log_every(data_loader, print_freq, header)): # assign learning rate & weight decay for each step it = start_steps + step # global training iteration if lr_schedule_values is not None or wd_schedule_values is not None: for i, param_group in enumerate(optimizer.param_groups): if lr_schedule_values is not None: param_group["lr"] = lr_schedule_values[it] * param_group[ "lr_scale"] if wd_schedule_values is not None and param_group[ "weight_decay"] > 0: param_group["weight_decay"] = wd_schedule_values[it] samples, images, bool_masked_pos = batch images = images.to(device, non_blocking=True) samples = samples.to(device, non_blocking=True) bool_masked_pos = bool_masked_pos.to(device, non_blocking=True) with torch.no_grad(): input_ids = d_vae.get_codebook_indices(images).flatten(1) bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool) labels = input_ids[bool_masked_pos] with torch.cuda.amp.autocast(): outputs = model(samples, bool_masked_pos=bool_masked_pos, return_all_tokens=False) loss = nn.CrossEntropyLoss()(input=outputs, target=labels) loss_value = loss.item() if not math.isfinite(loss_value): print("Loss is {}, stopping training".format(loss_value)) sys.exit(1) optimizer.zero_grad() # this attribute is added by timm on one optimizer (adahessian) is_second_order = hasattr( optimizer, 'is_second_order') and optimizer.is_second_order grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, parameters=model.parameters(), create_graph=is_second_order) loss_scale_value = loss_scaler.state_dict()["scale"] torch.cuda.synchronize() mlm_acc = (outputs.max(-1)[1] == labels).float().mean().item() metric_logger.update(mlm_acc=mlm_acc) if log_writer is not None: log_writer.update(mlm_acc=mlm_acc, head="loss") metric_logger.update(loss=loss_value) metric_logger.update(loss_scale=loss_scale_value) min_lr = 10. max_lr = 0. for group in optimizer.param_groups: min_lr = min(min_lr, group["lr"]) max_lr = max(max_lr, group["lr"]) metric_logger.update(lr=max_lr) metric_logger.update(min_lr=min_lr) weight_decay_value = None for group in optimizer.param_groups: if group["weight_decay"] > 0: weight_decay_value = group["weight_decay"] metric_logger.update(weight_decay=weight_decay_value) metric_logger.update(grad_norm=grad_norm) if log_writer is not None: log_writer.update(loss=loss_value, head="loss") log_writer.update(loss_scale=loss_scale_value, head="opt") log_writer.update(lr=max_lr, head="opt") log_writer.update(min_lr=min_lr, head="opt") log_writer.update(weight_decay=weight_decay_value, head="opt") log_writer.update(grad_norm=grad_norm, head="opt") log_writer.set_step() if lr_scheduler is not None: lr_scheduler.step_update(start_steps + step) # gather the stats from all processes metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
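# `loss_scaler` above is called as loss_scaler(loss, optimizer, clip_grad=...,
# parameters=..., create_graph=...) and exposes state_dict()["scale"], which
# matches timm's NativeScaler convention. A minimal sketch of such a callable
# built on torch.cuda.amp.GradScaler is shown below; this is an approximation
# of the pattern, not the timm source.
import torch

class LossScalerSketch:
    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None,
                 create_graph=False):
        # scale the loss before backward so small gradients survive fp16
        self._scaler.scale(loss).backward(create_graph=create_graph)
        norm = None
        if clip_grad is not None and parameters is not None:
            self._scaler.unscale_(optimizer)  # unscale grads in-place before clipping
            norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        self._scaler.step(optimizer)  # skips the step if grads overflowed
        self._scaler.update()  # adjust the scale for the next iteration
        return norm

    def state_dict(self):
        return self._scaler.state_dict()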
def run_epoch(experiment: comet_ml.Experiment,
              network: ProbabilisticExternalAgentCurvePredictor,
              optimizer: torch.optim.Optimizer,
              dataloader: data_utils.DataLoader,
              debug: bool = False,
              use_tqdm=False):
    num_samples = 0.0
    if use_tqdm:
        t = tqdm(enumerate(dataloader), total=len(dataloader))
    else:
        t = enumerate(dataloader)
    network.train()  # This is important to call before training!
    # we are only doing single-device training for now, so this works fine.
    dev = next(network.parameters()).device
    dtype = next(network.parameters()).dtype
    lossf = 0
    d = network.output_dim
    for (i, datadict) in t:
        valid_mask = datadict["valid_mask"]
        past_positions = datadict["past_positions"]
        past_velocities = datadict["past_velocities"]
        past_quaternions = datadict["past_quaternions"]
        future_positions = datadict["future_positions"]
        tfuture = datadict["tfuture"]

        valid_past_positions = (
            past_positions[valid_mask].type(dtype).to(dev))[:, :, [0, 2]]
        valid_past_velocities = (
            past_velocities[valid_mask].type(dtype).to(dev))[:, :, [0, 2]]
        valid_past_quaternions = past_quaternions[valid_mask].type(dtype).to(dev)
        valid_future_positions = (
            future_positions[valid_mask].type(dtype).to(dev))[:, :, [0, 2]]
        valid_tfuture = tfuture[valid_mask].type(dtype).to(dev)

        if network.input_dim == 4:
            networkinput = torch.cat(
                [valid_past_positions, valid_past_velocities], dim=2)
        elif network.input_dim == 8:
            networkinput = torch.cat([
                valid_past_positions, valid_past_velocities,
                valid_past_quaternions
            ], dim=2)
        else:
            raise ValueError(
                "Currently, only input dimensions of 4 and 8 are supported")

        batch_size = networkinput.shape[0]
        means, varfactors, covarfactors = network(networkinput)
        meancurves = torch.cat(
            [valid_future_positions[:, 0].unsqueeze(1), means], dim=1)

        dt = valid_tfuture[:, -1] - valid_tfuture[:, 0]
        s_torch_cur = (valid_tfuture - valid_tfuture[:, 0, None]) / dt[:, None]
        Mpos = mu.bezierM(s_torch_cur, network.bezier_order)
        Msquare = torch.square(Mpos)

        pred_points = torch.matmul(Mpos, meancurves)
        deltas = pred_points - valid_future_positions
        squared_norms = torch.sum(torch.square(deltas), dim=2)
        point_estimate_loss = torch.mean(squared_norms)

        scale_trils = torch.diag_embed(varfactors) + torch.diag_embed(
            covarfactors, offset=-1)
        covars = torch.matmul(scale_trils, scale_trils.transpose(2, 3))
        covars_expand = covars.unsqueeze(1).expand(batch_size,
                                                   Msquare.shape[1],
                                                   Msquare.shape[2], d, d)
        poscovar = torch.sum(Msquare[:, :, :, None, None] * covars_expand, dim=2)

        distpos = torch.distributions.MultivariateNormal(
            pred_points, covariance_matrix=poscovar, validate_args=False)
        log_probs = distpos.log_prob(valid_future_positions)
        NLL = 0.0001 * torch.mean(-log_probs)

        loss = point_estimate_loss + NLL
        # Skip the update if the loss went NaN.
        if torch.isnan(loss):
            continue

        optimizer.zero_grad()
        loss.backward()
        # Weight and bias updates.
        optimizer.step()

        # logging information
        if (i % 15) == 0:
            experiment.log_metric("point_estimate_loss",
                                  point_estimate_loss.item())
            experiment.log_metric("NLL", NLL.item())
        if use_tqdm:
            curr_loss = loss.item()
            lossf += curr_loss
            t.set_postfix({
                "point_estimate_loss": point_estimate_loss.item(),
                "NLL": NLL.item()
            })
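# `mu.bezierM(s, n)` above evaluates the Bernstein basis of a degree-n Bezier
# curve at normalised times s in [0, 1]. The standard formula is
# B_{j,n}(s) = C(n, j) * s^j * (1 - s)^(n - j); a sketch under the assumption
# that mu follows it (the actual module may vectorise differently):
import math
import torch

def bezier_matrix_sketch(s: torch.Tensor, n: int) -> torch.Tensor:
    """s: (batch, num_points) in [0, 1] -> basis matrix (batch, num_points, n + 1)."""
    terms = [
        math.comb(n, j) * torch.pow(s, j) * torch.pow(1.0 - s, n - j)
        for j in range(n + 1)
    ]
    return torch.stack(terms, dim=2)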
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, args, postprocessors=None): model.train() criterion.train() metric_logger = utils.MetricLogger(delimiter=' ') metric_logger.add_meter( 'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) if args.stage != 3: metric_logger.add_meter( 'class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) header = 'Epoch: [{}]'.format(epoch) print_freq = 10 max_norm = args.clip_max_norm for vid_name_list, locations, samples, targets, num_frames, base, s_e_scores \ in metric_logger.log_every(data_loader, print_freq, header): samples = samples.to(device) s_e_scores = s_e_scores.to(device) targets = [{k: v.to(device) for k, v in t.items()} for t in targets] outputs = model(locations, samples, s_e_scores) loss_dict = criterion(outputs, targets) weight_dict = criterion.weight_dict losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) # reduce losses over all GPUs for logging purposes loss_dict_reduced = utils.reduce_dict(loss_dict) loss_dict_reduced_unscaled = { f'{k}_unscaled': v for k, v in loss_dict_reduced.items() } loss_dict_reduced_scaled = { k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict } losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) loss_value = losses_reduced_scaled.item() if not math.isfinite(loss_value): print('Loss is {}, stopping training'.format(loss_value)) print(loss_dict_reduced) sys.exit(1) optimizer.zero_grad() losses.backward() if max_norm > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) optimizer.step() metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) if args.stage != 3: metric_logger.update(class_error=loss_dict_reduced['class_error']) metric_logger.update(lr=optimizer.param_groups[0]['lr']) metric_logger.synchronize_between_processes() return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, loss_dict
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, amp: bool = True, teacher_model: torch.nn.Module = None, teach_loss: torch.nn.Module = None, distill_token: bool=False, choices=None, mode='super', retrain_config=None): model.train() criterion.train() # set random seed random.seed(epoch) metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) header = 'Epoch: [{}]'.format(epoch) print_freq = 10 if mode == 'retrain': config = retrain_config model_module = unwrap_model(model) print(config) model_module.set_sample_config(config=config) print(model_module.get_sampled_params_numel(config)) for samples, targets in metric_logger.log_every(data_loader, print_freq, header): samples = samples.to(device, non_blocking=True) targets = targets.to(device, non_blocking=True) # sample random config if mode == 'super': config = sample_configs(choices=choices) model_module = unwrap_model(model) model_module.set_sample_config(config=config) elif mode == 'retrain': config = retrain_config model_module = unwrap_model(model) model_module.set_sample_config(config=config) if mixup_fn is not None: samples, targets = mixup_fn(samples, targets) if amp: with torch.cuda.amp.autocast(): if teacher_model: with torch.no_grad(): teach_output = teacher_model(samples) _, teacher_label = teach_output.topk(1, 1, True, True) if distill_token: output_cls, output_dis = model(samples) loss = 1/2 * criterion(output_cls, targets) + 1/2 * teach_loss(output_dis, teacher_label.squeeze()) else: outputs = model(samples) loss = 1/2 * criterion(outputs, targets) + 1/2 * teach_loss(outputs, teacher_label.squeeze()) else: outputs = model(samples) loss = criterion(outputs, targets) else: outputs = model(samples) if teacher_model: with torch.no_grad(): teach_output = teacher_model(samples) _, teacher_label = teach_output.topk(1, 1, True, True) loss = 1 / 2 * criterion(outputs, targets) + 1 / 2 * teach_loss(outputs, teacher_label.squeeze()) else: loss = criterion(outputs, targets) loss_value = loss.item() if not math.isfinite(loss_value): print("Loss is {}, stopping training".format(loss_value)) sys.exit(1) optimizer.zero_grad() # this attribute is added by timm on one optimizer (adahessian) if amp: is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order loss_scaler(loss, optimizer, clip_grad=max_norm, parameters=model.parameters(), create_graph=is_second_order) else: loss.backward() optimizer.step() torch.cuda.synchronize() if model_ema is not None: model_ema.update(model) metric_logger.update(loss=loss_value) metric_logger.update(lr=optimizer.param_groups[0]["lr"]) # gather the stats from all processes metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def train(encoder: EncoderBILSTM, decoder: DecoderLSTM, epoch_count: int,
          train_loader: DataLoader, criterion,
          optimizer_enc: torch.optim.Optimizer,
          optimizer_dec: torch.optim.Optimizer, is_cuda: bool,
          teacher_forcing: bool = False, debug: bool = False,
          lr_schedule=False, start_epoch_at: int = 0):
    losses = []
    best_loss = float('inf')
    for epoch in range(start_epoch_at, start_epoch_at + epoch_count):
        total_batch_loss = 0
        for ind, batch in enumerate(train_loader):
            loss = 0
            questions, questions_org_len, answers, answers_org_len, pID = batch
            if questions.shape[1] > 1000:
                break
            if is_cuda:
                questions = questions.cuda()
                answers = answers.cuda()
            encoder_input, encoder_len = answers, answers_org_len
            decoder_input, decoder_len = questions, questions_org_len

            encoder_len_tensor = torch.LongTensor(encoder_len)
            if is_cuda:
                encoder_len_tensor = encoder_len_tensor.cuda()
            encoder_out, encoder_hidden = encoder(encoder_input,
                                                  encoder_len_tensor, False)
            encoder_len = torch.FloatTensor(encoder_len)
            if not teacher_forcing:
                decoder_inp = torch.ones((len(questions), 1), dtype=torch.long)
                if is_cuda:
                    decoder_inp = decoder_inp.cuda()

            if teacher_forcing:
                decoder_out, decoder_hidden, attn_scores = decoder(
                    decoder_input[:, :-1], encoder_hidden, encoder_out,
                    encoder_len, False)
                decoder_out = decoder_out.transpose(0, 1).contiguous()
                decoder_out = decoder_out.transpose(1, 2).contiguous()
                loss = criterion(decoder_out, questions[:, :-1])
            else:
                decoder_hidden = (encoder_hidden[0].clone(),
                                  encoder_hidden[1].clone())
                eval_mode = False
                for j in range(questions.shape[1]):
                    decoder_out, decoder_hidden, attn_scores = decoder(
                        decoder_inp, decoder_hidden, encoder_out, encoder_len,
                        eval_mode=eval_mode)
                    # decoder_out holds log-softmax scores; we minimize the
                    # negative log likelihood over the whole decoded span.
                    decoder_out = decoder_out.squeeze(0)
                    prediction = torch.argmax(decoder_out, 1).unsqueeze(1)
                    loss_val = criterion(decoder_out, questions[:, j])
                    loss += loss_val / questions.shape[1]
                    decoder_inp = prediction.clone().detach()
                    eval_mode = True

            optimizer_enc.zero_grad()
            optimizer_dec.zero_grad()
            loss.backward()
            clip_grad_norm_(encoder.parameters(), 5)
            clip_grad_norm_(decoder.parameters(), 5)
            optimizer_enc.step()
            optimizer_dec.step()
            if lr_schedule:
                optimizer_enc = exp_lr_scheduler(optimizer_enc, epoch,
                                                 lr_decay_epoch=8)
                optimizer_dec = exp_lr_scheduler(optimizer_dec, epoch,
                                                 lr_decay_epoch=8)
            total_batch_loss += loss.item()
            if debug:
                print("Batch Loss: %f" % loss.item())
            if ind % 1000 == 0:
                print("Batch %d Loss: %f" % (ind, loss.item()))
        losses.append(total_batch_loss)
        print("Epoch[%d] Loss: %f" % (epoch, total_batch_loss))
        if total_batch_loss < best_loss:
            torch.save(encoder.state_dict(),
                       "model_weights/%d-encoder-SGD-small.pth" % epoch)
            torch.save(decoder.state_dict(),
                       "model_weights/%d-decoder-SGD-small.pth" % epoch)
            best_loss = total_batch_loss
    torch.save(encoder.state_dict(), "model_weights/final-encoder-SGD-small.pth")
    torch.save(decoder.state_dict(), "model_weights/final-decoder-SGD-small.pth")
    return losses
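# `exp_lr_scheduler` above is defined elsewhere; it resembles the classic
# tutorial-style helper that decays every param group's learning rate by a
# fixed factor every `lr_decay_epoch` epochs. A sketch under that assumption
# (init_lr and factor are illustrative defaults, not values from this repo):
def exp_lr_scheduler_sketch(optimizer, epoch, init_lr=0.01, lr_decay_epoch=8,
                            factor=0.1):
    """Set lr to init_lr decayed by `factor` every `lr_decay_epoch` epochs."""
    lr = init_lr * (factor ** (epoch // lr_decay_epoch))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    return optimizer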
def _torch_step(optimizer: torch.optim.Optimizer,
                scaler: Optional[torch.cuda.amp.GradScaler] = None) -> None:
    if scaler is None:
        optimizer.step()
    else:
        scaler.step(optimizer)
        scaler.update()  # GradScaler requires an update after each step to adjust the loss scale
    optimizer.zero_grad()
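# A hypothetical caller of `_torch_step`, showing the surrounding AMP pattern
# it assumes: forward under autocast, backward through scaler.scale when a
# scaler is present (model, batch, target and loss_func are stand-ins):
import torch

def _amp_usage_sketch(model, batch, target, loss_func, optimizer, scaler):
    with torch.cuda.amp.autocast(enabled=scaler is not None):
        loss = loss_func(model(batch), target)
    if scaler is None:
        loss.backward()
    else:
        scaler.scale(loss).backward()  # scaled backward keeps fp16 grads representable
    _torch_step(optimizer, scaler)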
def train_person_segmentor(
    model: torch.nn.Module,
    train_loader: torch.utils.data.DataLoader,
    valid_loader: torch.utils.data.DataLoader,
    criterion: callable,
    optimiser: torch.optim.Optimizer,
    *,
    save_model_path: Path,
    learning_rate: Number = 6e-2,
    scheduler: torch.optim.lr_scheduler = None,
    n_epochs: int = 100,
    writer: ImageWriterMixin = MockWriter(),
):
    """
    Train a person-segmentation model, keeping the checkpoint with the lowest validation loss.

    :param model: the segmentation model to train
    :param train_loader: dataloader with the training data
    :param valid_loader: dataloader with the validation data
    :param criterion: loss function applied to model output and target
    :param optimiser: optimiser used for the weight updates
    :param save_model_path: where to save the best-performing model state dict
    :param learning_rate: starting learning rate, used when rescheduling
    :param scheduler: optional learning rate scheduler
    :param n_epochs: number of epochs to train for
    :param writer: writer used for scalar and image logging
    :return: the trained model
    :rtype: torch.nn.Module"""
    valid_loss_min = numpy.inf  # track change in validation loss
    assert n_epochs > 0, n_epochs
    E = tqdm(range(1, n_epochs + 1))
    for epoch_i in E:
        train_loss = 0.0
        valid_loss = 0.0

        with TorchTrainSession(model):
            for data, target in tqdm(train_loader):
                output, *_ = model(data.to(global_torch_device()))
                loss = criterion(output, target.to(global_torch_device()).float())
                optimiser.zero_grad()
                loss.backward()
                optimiser.step()
                train_loss += loss.cpu().item() * data.size(0)

        with TorchEvalSession(model):
            with torch.no_grad():
                for data, target in tqdm(valid_loader):
                    # forward pass: compute predicted outputs by passing inputs to the model
                    output, *_ = model(data.to(global_torch_device()))
                    # calculate the batch loss
                    validation_loss = criterion(
                        output, target.to(global_torch_device()).float())
                    writer.scalar(
                        "dice_validation",
                        dice_loss(output, target.to(global_torch_device()).float()),
                    )
                    # update average validation loss
                    valid_loss += validation_loss.detach().cpu().item() * data.size(0)
                writer.image("prediction", torch.sigmoid(output), epoch_i)  # write the last batch

        # calculate average losses
        train_loss = train_loss / len(train_loader.dataset)
        valid_loss = valid_loss / len(valid_loader.dataset)

        # save model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print(
                f"Validation loss decreased ({valid_loss_min:.6f} --> {valid_loss:.6f}). Saving model ..."
            )
            torch.save(model.state_dict(), save_model_path)
            valid_loss_min = valid_loss

        if scheduler:
            scheduler.step()
            optimiser, scheduler = reschedule_learning_rate(
                model,
                optimiser,
                epoch_i,
                scheduler,
                starting_learning_rate=learning_rate,
            )

        # print training/validation statistics
        current_lr = next(iter(optimiser.param_groups))["lr"]
        E.set_description(f"Epoch: {epoch_i} "
                          f"Training Loss: {train_loss:.6f} "
                          f"Validation Loss: {valid_loss:.6f} "
                          f"Learning rate: {current_lr:.6f}")
        writer.scalar("training_loss", train_loss)
        writer.scalar("validation_loss", valid_loss)
        writer.scalar("learning_rate", current_lr)

    return model
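# `dice_loss` logged above is imported from elsewhere; a minimal soft-Dice
# sketch for binary masks, 1 - 2|X∩Y| / (|X| + |Y|), is given below. This is
# an assumption about the imported helper rather than its verbatim source.
import torch

def dice_loss_sketch(logits: torch.Tensor, target: torch.Tensor,
                     epsilon: float = 1e-6) -> torch.Tensor:
    probs = torch.sigmoid(logits)  # map raw logits to per-pixel probabilities
    intersection = (probs * target).sum()
    union = probs.sum() + target.sum()
    return 1.0 - (2.0 * intersection + epsilon) / (union + epsilon)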
def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None,
                    mixup_fn: Optional[Mixup] = None,
                    teacher=None, set_training_mode=True):
    # TODO: honour set_training_mode (i.e. call model.train(set_training_mode))
    # once finetuning is fixed
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 100

    for samples, targets in metric_logger.log_every(data_loader, print_freq,
                                                    header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        samples, targets, mix_rate, aux_targets = two_mix(
            samples, targets, num_patch=samples.shape[-1] // 16)

        with torch.cuda.amp.autocast():
            outputs, r_loss, s_loss, proj = model(samples, aux_targets)
            loss = torch.sum(-targets * (1e-8 + outputs.softmax(dim=-1)).log(),
                             dim=-1).mean()

        # log only the classification term; the auxiliary terms are tracked separately
        loss_value = loss.item()
        loss += 1. * (r_loss + 1. * s_loss)

        if not math.isfinite(loss.item()):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(
            optimizer, 'is_second_order') and optimizer.is_second_order
        loss_scaler(loss,
                    optimizer,
                    clip_grad=max_norm,
                    parameters=model.parameters(),
                    create_graph=is_second_order)

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters['r'].update(r_loss.item(), n=targets.shape[0])
        metric_logger.meters['s'].update(s_loss.item(), n=targets.shape[0])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def train_domain_classifier(
        model: torch.nn.Module,
        train_dl: DataLoader,
        optimizer: torch.optim.Optimizer,
        scheduler: LambdaLR,
        validation_evaluator: MultiDatasetClassificationEvaluator,
        n_epochs: int,
        device: AnyStr,
        log_interval: int = 1,
        patience: int = 10,
        model_dir: str = "wandb_local",
        gradient_accumulation: int = 1,
        domain_name: str = ''):
    #best_loss = float('inf')
    best_acc = 0.0
    patience_counter = 0

    epoch_counter = 0
    total = len(train_dl)

    # Main loop
    while epoch_counter < n_epochs:
        for i, batch in enumerate(tqdm(train_dl)):
            model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids = batch[0]
            masks = batch[1]
            labels = batch[2]
            # Testing with random domains to see if it has any effect
            #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
            domains = batch[3]

            loss, logits = model(input_ids, attention_mask=masks, labels=domains)
            loss = loss / gradient_accumulation
            # accumulate gradients every step; only update weights every
            # `gradient_accumulation` steps
            loss.backward()

            if (i + 1) % gradient_accumulation == 0:
                optimizer.step()
                optimizer.zero_grad()
                if scheduler is not None:
                    scheduler.step()

        gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        print(f"Validation acc: {acc}")

        # Saving the best model and early stopping
        #if val_loss < best_loss:
        if acc > best_acc:
            best_model = model.state_dict()
            best_acc = acc
            torch.save(
                model.state_dict(),
                f'{model_dir}/{Path(wandb.run.dir).name}/model_domainclassifier_{domain_name}.pth'
            )
            patience_counter = 0
        else:
            patience_counter += 1

        # Stop training once we have lost patience
        if patience_counter == patience:
            break

        gc.collect()
        epoch_counter += 1
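# --- Hedged sketch (added; not part of the original source) ----------------
# The accumulation pattern used above in miniature, plus a final flush so a
# trailing partial group of batches still produces an update. All names here
# are illustrative.
def _grad_accumulation_sketch(model, loader, optimizer, accum_steps=4):
    model.train()
    optimizer.zero_grad()
    for i, (x, y) in enumerate(loader):
        loss = torch.nn.functional.cross_entropy(model(x), y)
        (loss / accum_steps).backward()  # scale so the update averages batches
        if (i + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
    if (i + 1) % accum_steps != 0:  # flush leftover gradients
        optimizer.step()
        optimizer.zero_grad()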
def train_one_epoch(args,
                    model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    dataloader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    max_norm: float = 0):
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 50

    for samples, targets, support_images, support_class_ids, support_targets in metric_logger.log_every(dataloader, print_freq, header):
        # * Sample support categories;
        # * Filter targets (only keep GTs within support categories);
        # * Sample support images and targets
        targets, support_images, support_class_ids, support_targets = \
            sample_support_categories(args, targets, support_images, support_class_ids, support_targets)

        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        support_images = support_images.to(device)
        support_class_ids = support_class_ids.to(device)
        support_targets = [{k: v.to(device) for k, v in t.items()} for t in support_targets]

        outputs = model(samples, targets=targets, supp_samples=support_images, supp_class_ids=support_class_ids, supp_targets=support_targets)
        loss_dict = criterion(outputs)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v for k, v in loss_dict_reduced.items()}
        loss_dict_reduced_scaled = {k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict}
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is not finite - {}.\nTraining terminated unexpectedly.\n".format(loss_value))
            print("loss dict:")
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        else:
            grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm)
        optimizer.step()

        metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(grad_norm=grad_total_norm)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    # drop references to the last batch's tensors before returning
    del support_images
    del support_class_ids
    del support_targets
    del samples
    del targets
    del outputs
    del weight_dict
    del grad_total_norm
    del loss_value
    del losses
    del loss_dict
    del loss_dict_reduced
    del loss_dict_reduced_scaled
    del loss_dict_reduced_unscaled

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
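# --- Hedged sketch (added; not part of the original source) ----------------
# utils.get_total_grad_norm is not shown in this file; a plausible
# implementation, mirroring the norm computation in
# torch.nn.utils.clip_grad_norm_ without the clipping, might look like this:
def _total_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor:
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.)
    return torch.norm(
        torch.stack([torch.norm(g.detach(), norm_type) for g in grads]),
        norm_type)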
def train_epoch(tree: ProtoTree,
                train_loader: DataLoader,
                optimizer: torch.optim.Optimizer,
                epoch: int,
                disable_derivative_free_leaf_optim: bool,
                device,
                log: Log = None,
                log_prefix: str = 'log_train_epochs',
                progress_prefix: str = 'Train Epoch'
                ) -> dict:

    tree = tree.to(device)

    # Make sure the model is in eval mode
    tree.eval()
    # Store info about the procedure
    train_info = dict()
    total_loss = 0.
    total_acc = 0.

    # Create a log if required
    log_loss = f'{log_prefix}_losses'

    nr_batches = float(len(train_loader))
    with torch.no_grad():
        _old_dist_params = dict()
        for leaf in tree.leaves:
            _old_dist_params[leaf] = leaf._dist_params.detach().clone()
        # Optimize class distributions in leaves
        eye = torch.eye(tree._num_classes).to(device)

    # Show progress on progress bar
    train_iter = tqdm(enumerate(train_loader),
                      total=len(train_loader),
                      desc=f'{progress_prefix} {epoch}',
                      ncols=0)

    # Iterate through the data set to update leaves, prototypes and network
    for i, (xs, ys) in train_iter:
        # Make sure the model is in train mode
        tree.train()
        # Reset the gradients
        optimizer.zero_grad()

        xs, ys = xs.to(device), ys.to(device)

        # Perform a forward pass through the network
        ys_pred, info = tree.forward(xs)

        # Learn prototypes and network with gradient descent.
        # If disable_derivative_free_leaf_optim, leaves are optimized with gradient descent as well.
        # Compute the loss
        if tree._log_probabilities:
            loss = F.nll_loss(ys_pred, ys)
        else:
            loss = F.nll_loss(torch.log(ys_pred), ys)

        # Compute the gradient
        loss.backward()
        # Update model parameters
        optimizer.step()

        if not disable_derivative_free_leaf_optim:
            # Update leaves with the derivative-free algorithm
            # Make sure the tree is in eval mode
            tree.eval()
            with torch.no_grad():
                target = eye[ys]  # shape (batchsize, num_classes)
                for leaf in tree.leaves:
                    if tree._log_probabilities:
                        # log version
                        update = torch.exp(
                            torch.logsumexp(info['pa_tensor'][leaf.index] +
                                            leaf.distribution() +
                                            torch.log(target) - ys_pred,
                                            dim=0))
                    else:
                        update = torch.sum(
                            (info['pa_tensor'][leaf.index] *
                             leaf.distribution() * target) / ys_pred,
                            dim=0)
                    leaf._dist_params -= (_old_dist_params[leaf] / nr_batches)
                    # dist_params values can get slightly negative because of
                    # floating point issues, so they are clamped to zero.
                    F.relu_(leaf._dist_params)
                    leaf._dist_params += update

        # Count the number of correct classifications
        ys_pred_max = torch.argmax(ys_pred, dim=1)
        correct = torch.sum(torch.eq(ys_pred_max, ys))
        acc = correct.item() / float(len(xs))

        train_iter.set_postfix_str(
            f'Batch [{i + 1}/{len(train_loader)}], Loss: {loss.item():.3f}, Acc: {acc:.3f}'
        )
        # Compute metrics over this batch
        total_loss += loss.item()
        total_acc += acc

        if log is not None:
            log.log_values(log_loss, epoch, i + 1, loss.item(), acc)

    train_info['loss'] = total_loss / float(i + 1)
    train_info['train_accuracy'] = total_acc / float(i + 1)
    return train_info
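# --- Hedged illustration (added; not part of the original source) ----------
# The derivative-free leaf update above, in isolation. Per leaf l and batch B
# (non-log version), it computes
#   update = sum_{(x, y) in B} pa_l(x) * dist_l(y) * onehot(y) / yhat(x)
#   dist_l <- relu(dist_l - old_dist_l / n_batches) + update
# where pa_l(x) is the probability of routing x to leaf l. A standalone
# version of one update step, with illustrative tensor shapes:
def _leaf_update_step(dist_params, old_dist, leaf_dist, pa, target_onehot,
                      ys_pred, nr_batches):
    # pa: (batch, 1), leaf_dist: (C,), target_onehot/ys_pred: (batch, C)
    update = torch.sum(pa * leaf_dist * target_onehot / ys_pred, dim=0)
    dist_params = torch.relu(dist_params - old_dist / nr_batches)
    return dist_params + update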
def resume_training(args: argparse.Namespace, hp: HParams, tier: int,
                    model: Tier, optimizer: torch.optim.Optimizer,
                    logger: logging.Logger) \
        -> Tuple[Tier, torch.optim.Optimizer]:
    """
    Loads the model specified in args.checkpoint_path to resume training from that point.

    Args:
        args (argparse.Namespace): parameters to set up the training. At least, args must contain:
                                   args = {"path_config": ...,
                                           "tier": ...,
                                           "checkpoint_path": ...}
        hp (HParams): hyperparameters for the model and other parameters (training, dataset, ...)
        tier (int): number of the tier to load.
        model (Tier): model where the weights will be loaded.
        optimizer (torch.optim.Optimizer): optimizer where the information will be loaded.
        logger (logging.Logger): to log general information about resuming the training.

    Returns:
        model (Tier) and optimizer (torch.optim.Optimizer)
    """
    if not Path(args.checkpoint_path).exists():
        logger.error(
            f"Path for resuming training {args.checkpoint_path} does not exist."
        )
        raise Exception(
            f"Path for resuming training {args.checkpoint_path} does not exist."
        )

    logger.info(f"Resuming training with weights from: {args.checkpoint_path}")
    checkpoint = torch.load(args.checkpoint_path)
    hp_chkpt = checkpoint["hp"]

    # Check if current hyperparameters and the ones from the saved model are the same
    if hp_chkpt.audio != hp.audio:
        logger.warning("New params for audio are different from checkpoint. "
                       "The new params will be used.")

    if hp_chkpt.network != hp.network:
        logger.error(
            "New params for network structure are different from checkpoint.")
        raise Exception(
            "New params for network structure are different from checkpoint.")

    if checkpoint["tier_idx"] != tier:
        logger.error(
            f"New tier to train ({tier}) is different from checkpoint ({checkpoint['tier_idx']})."
        )
        raise Exception(
            f"New tier to train ({tier}) is different from checkpoint ({checkpoint['tier_idx']})."
        )

    if hp_chkpt.data != hp.data:
        logger.warning("New params for dataset are different from checkpoint. "
                       "The new params will be used.")

    if hp_chkpt.training != hp.training:
        logger.warning(
            "New params for training are different from checkpoint. "
            "The new params will be used.")

    # epoch_chkpt = checkpoint["epoch"]
    # iterations_chkpt = checkpoint["iterations"]
    # total_iterations_chkpt = checkpoint["total_iterations"]
    model.load_state_dict(checkpoint["tier"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    return model, optimizer
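# --- Hedged sketch (added; not part of the original source) ----------------
# resume_training expects a checkpoint dict with at least the keys read above.
# A save routine producing a compatible file might look like this; the
# epoch/iteration keys mirror the commented-out reads and are assumptions.
def _save_tier_checkpoint(path, hp, tier_idx, model, optimizer, epoch,
                          iterations, total_iterations):
    torch.save(
        {
            "hp": hp,
            "tier_idx": tier_idx,
            "tier": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "iterations": iterations,
            "total_iterations": total_iterations,
        }, path)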
def run_epoch(
    data_iterator: DataLoader,
    model: nn.Module,
    optimizer: torch.optim.Optimizer = None,
    is_test: bool = False,
    is_metadata: bool = False
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """
    Runs an epoch of training (or testing)

    Parameters
    ----------
    data_iterator: DataLoader object
    model: Pytorch model
    optimizer: Pytorch optimizer
    is_test: Set true if the function is used in testing the model with the
        test set, so it also returns the confusion matrix.
    is_metadata: Is there additional metadata besides images and labels (for
        models using metadata)

    Returns
    -------
    Mean loss and accuracy of the epoch, and optionally the confusion matrix
    """
    loss = []
    acc = []
    confusion_m = torch.zeros((3, 3))  # 3-class problem
    with tqdm(total=len(data_iterator)) as t:
        for idx, data in enumerate(data_iterator):
            t.update(1)
            # `device` and `weights_train` are module-level globals
            labels = data[1].to(device)
            image = data[0].to(device)
            if is_metadata:
                metadata = data[2].to(device)
                inputs = (image, metadata)
            else:
                inputs = (image,)

            if model.training:
                model_out = model(*inputs)
            else:
                with torch.no_grad():
                    model_out = model(*inputs)

            indiv_loss = nn.functional.cross_entropy(
                model_out, labels,
                weight=torch.FloatTensor(weights_train).to(device))
            loss.append(indiv_loss.item())

            prediction = torch.argmax(model_out, dim=1)
            correct = np.equal(prediction.cpu().numpy(), labels.cpu().numpy())
            accuracy = np.mean(correct)
            acc.append(accuracy)

            if is_test:
                for idx2, label in enumerate(labels):
                    confusion_m[label.item(), prediction[idx2].item()] += 1

            if model.training:
                optimizer.zero_grad()
                indiv_loss.backward()
                optimizer.step()

    if is_test:
        return np.mean(loss), np.mean(acc), confusion_m.numpy()
    return np.mean(loss), np.mean(acc), None
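# --- Hedged sketch (added; not part of the original source) ----------------
# weights_train above is a module-level list of per-class loss weights. One
# common way to derive such weights (an assumption here, not confirmed by
# this code) is inverse class frequency over the training labels:
def _inverse_frequency_weights(labels: np.ndarray, n_classes: int = 3):
    counts = np.bincount(labels, minlength=n_classes).astype(np.float64)
    weights = counts.sum() / (n_classes * np.maximum(counts, 1))
    return weights.tolist()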
def train_one_epoch(model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    max_norm: float = 0,
                    accumulate_batches=1):
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 100
    num_batches = len(data_loader)

    warmup_scheduler = None
    if epoch == 0:
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len(data_loader) - 1)
        warmup_scheduler = warmup_lr_scheduler(optimizer, warmup_iters,
                                               warmup_factor)

    optimizer.zero_grad()
    for step, (samples, targets) in enumerate(metric_logger.log_every(
            data_loader, print_freq, header), start=1):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k]
                     for k in loss_dict.keys() if k in weight_dict)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {
            f'{k}_unscaled': v
            for k, v in loss_dict_reduced.items()
        }
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items() if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        # note: per-batch losses are summed (not averaged) across the
        # accumulation window, since they are not divided by accumulate_batches
        losses.backward()
        if ((step % accumulate_batches) == 0) or (step == num_batches):
            if max_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            optimizer.zero_grad()

        if warmup_scheduler is not None:
            warmup_scheduler.step()

        metric_logger.update(
            loss=loss_value,
            **loss_dict_reduced_scaled)  #, **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    print("Global scale", model.global_scale)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
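# --- Hedged sketch (added; not part of the original source) ----------------
# warmup_lr_scheduler is not defined in this file; a plausible linear-warmup
# implementation, matching how it is used above (stepped once per batch during
# epoch 0), is:
def _warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
    def f(x):
        # ramp the lr multiplier linearly from warmup_factor up to 1
        if x >= warmup_iters:
            return 1.
        alpha = float(x) / warmup_iters
        return warmup_factor * (1 - alpha) + alpha

    return torch.optim.lr_scheduler.LambdaLR(optimizer, f)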
def attach(self, optimizer: torch.optim.Optimizer):
    r"""
    Attaches the privacy engine to the optimizer.

    Attaches to the ``PrivacyEngine`` an optimizer object, and injects
    itself into the optimizer's step. To do so, it:

    1. Validates that the model does not have unsupported layers.

    2. Adds a pointer to this object (the ``PrivacyEngine``) inside the optimizer.

    3. Moves the optimizer's original ``step()`` function to ``original_step()``.

    4. Monkeypatches the optimizer's ``step()`` function to call ``step()``
    on the privacy engine automatically whenever it would call ``step()``
    for itself.

    Args:
        optimizer: The optimizer to which the privacy engine will attach
    """

    self.validator.validate(self.module)
    norm_clipper = (
        # pyre-fixme[6]: Expected `float` for 1st param but got
        #  `Union[List[float], float]`.
        clipping.ConstantFlatClipper(self.max_grad_norm)
        if not isinstance(self.max_grad_norm, list)
        # pyre-fixme[6]: Expected `List[float]` for 1st param but got
        #  `Union[List[float], float]`.
        else clipping.ConstantPerLayerClipper(self.max_grad_norm))

    if self.misc_settings.get("experimental", False):
        norm_clipper = clipping._Dynamic_Clipper_(
            # pyre-fixme[6]: Expected `List[float]` for 1st param but got
            #  `List[Union[List[float], float]]`.
            [self.max_grad_norm],
            self.misc_settings.get("clip_per_layer", False),
            self.misc_settings.get("clipping_method",
                                   clipping.ClippingMethod.STATIC),
            self.misc_settings.get("clipping_ratio", 0.0),
            self.misc_settings.get("clipping_momentum", 0.0),
        )

    self.clipper = PerSampleGradientClipper(self.module, norm_clipper,
                                            self.batch_first)

    def dp_step(self, closure=None):
        self.privacy_engine.step()
        self.original_step(closure)

    # Pyre doesn't like monkeypatching. But we'll do it anyway :)
    optimizer.privacy_engine = self  # pyre-ignore
    optimizer.original_step = optimizer.step  # pyre-ignore
    optimizer.step = types.MethodType(dp_step, optimizer)  # pyre-ignore

    def virtual_step(self):
        self.privacy_engine.virtual_step()

    # pyre-ignore
    optimizer.virtual_step = types.MethodType(virtual_step, optimizer)

    # create a cross reference for detaching
    self.optimizer = optimizer  # pyre-ignore
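# --- Hedged usage sketch (added; not part of the original source) ----------
# Typical attach flow: construct the engine, attach it to an optimizer, and
# train as usual; `step()` now runs the privacy engine (per-sample gradient
# clipping plus noise) before delegating to the original optimizer step. The
# constructor arguments shown here are illustrative, not taken from this file.
#
# engine = PrivacyEngine(model, batch_size=64, sample_size=50_000,
#                        alphas=[10, 100], noise_multiplier=1.3,
#                        max_grad_norm=1.0)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
# engine.attach(optimizer)
# ...
# loss.backward()
# optimizer.step()  # monkeypatched: privacy engine step runs first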
def step_optimizer(
    self,
    optimizer: torch.optim.Optimizer,
    clip_grads: Optional[Callable[[Iterator], None]] = None,
    auto_zero_grads: bool = True,
    scaler: Optional[Any] = None,
    # Should be torch.cuda.amp.GradScaler, but:
    #   * other implementations might be possible
    #   * requiring this type forces upgrades to PyTorch 1.6+
) -> None:
    """
    Perform a single optimization step.

    This function must be called once for each optimizer. However, the order of
    different optimizers' steps can be specified by calling this function in different
    orders. Also, gradient accumulation across iterations is performed by the Determined
    training loop by setting the experiment configuration field
    :ref:`optimizations.aggregation_frequency <config-aggregation-frequency>`.

    Here is a code example:

    .. code-block:: python

        def clip_grads(params):
            torch.nn.utils.clip_grad_norm_(params, 0.0001)

        self.context.step_optimizer(self.opt1, clip_grads)

    Arguments:
        optimizer(``torch.optim.Optimizer``): Which optimizer should be stepped.
        clip_grads(a function, optional): This function should have one argument for
            parameters in order to clip the gradients.
        auto_zero_grads(bool, optional): Automatically zero out gradients after stepping
            the optimizer. If false, you need to call ``optimizer.zero_grad()`` manually.
            Note that if :ref:`optimizations.aggregation_frequency
            <config-aggregation-frequency>` is greater than 1, ``auto_zero_grads`` must
            be true.
        scaler(``torch.cuda.amp.GradScaler``, optional): The scaler to use for stepping
            the optimizer. This should be unset if not using AMP, and is necessary if
            ``wrap_scaler()`` was called directly.
    """

    check.true(
        auto_zero_grads or self.hvd_config.aggregation_frequency == 1,
        "if optimizations.aggregation_frequency is larger than 1, "
        "you can only set auto_zero_grads to be true. ",
    )

    if not self._should_communicate_and_update():
        return

    # Communication needs to be synchronized so that it is completed
    # before we apply gradient clipping and `step()`. In the case of APEX
    # this is called in backward() instead, so that it's inside the context
    # manager and before unscaling.
    if self.hvd_config.use and not self._use_apex:
        optimizer.synchronize()  # type: ignore

    parameters = ([
        p for group in optimizer.param_groups
        for p in group.get("params", [])
    ] if not self._use_apex else apex.amp.master_params(optimizer))

    if self.hvd_config.average_aggregated_gradients:
        self._average_gradients(
            parameters=parameters,
            divisor=self.hvd_config.aggregation_frequency)

    if clip_grads is not None:
        if self._scaler and self.experimental._auto_amp:
            self._scaler.unscale_(optimizer)
        clip_grads(parameters)

    # For stepping the optimizer we will operate on the scaler passed
    # in, or fall back to the wrapped scaler (if any).
    if scaler is None and self.experimental._auto_amp:
        scaler = self._scaler
    if scaler:

        def step_fn() -> None:
            scaler.step(optimizer)  # type: ignore
    else:
        step_fn = optimizer.step  # type: ignore

    if self.hvd_config.use:
        with optimizer.skip_synchronize():  # type: ignore
            step_fn()
    else:
        step_fn()

    if auto_zero_grads:
        optimizer.zero_grad()
def step_optimizer(
    self,
    optimizer: torch.optim.Optimizer,  # type: ignore
    clip_grads: Optional[Callable[[Iterator], None]] = None,
    auto_zero_grads: bool = True,
) -> None:
    """
    Perform a single optimization step.

    This function must be called once for each optimizer. However, the order of
    different optimizers' steps can be specified by calling this function in different
    orders. Also, gradient accumulation across iterations is performed by the Determined
    training loop by setting the experiment configuration field
    :ref:`optimizations.aggregation_frequency <config-aggregation-frequency>`.

    Here is a code example:

    .. code-block:: python

        def clip_grads(params):
            torch.nn.utils.clip_grad_norm_(params, 0.0001)

        self.context.step_optimizer(self.opt1, clip_grads)

    Arguments:
        optimizer(``torch.optim.Optimizer``): Which optimizer should be stepped.
        clip_grads(a function, optional): This function should have one argument for
            parameters in order to clip the gradients.
        auto_zero_grads(bool, optional): Automatically zero out gradients after stepping
            the optimizer. If false, you need to call ``optimizer.zero_grad()`` manually.
            Note that if :ref:`optimizations.aggregation_frequency
            <config-aggregation-frequency>` is greater than 1, ``auto_zero_grads`` must
            be true.
    """
    check.true(
        auto_zero_grads or self.hvd_config.aggregation_frequency == 1,
        "if optimizations.aggregation_frequency is larger than 1, "
        "you can only set auto_zero_grads to be true. ",
    )

    if self._should_communicate_and_update():
        # Communication needs to be synchronized so that it is completed
        # before we apply gradient clipping and `step()`.
        if self.hvd_config.use and not self._use_amp:
            optimizer.synchronize()

        parameters = (
            [p for group in optimizer.param_groups for p in group.get("params", [])]
            if not self._use_amp
            else apex.amp.master_params(optimizer)
        )

        if self.hvd_config.average_aggregated_gradients:
            self._average_gradients(
                parameters=parameters, divisor=self.hvd_config.aggregation_frequency
            )

        if clip_grads is not None:
            clip_grads(parameters)

        if self.hvd_config.use:
            with optimizer.skip_synchronize():
                optimizer.step()
        else:
            optimizer.step()

        if auto_zero_grads:
            optimizer.zero_grad()
def _train_epoch(
    train_device: torch.device,
    model: torch.jit.ScriptModule,
    ddpmodel: ModelWrapperForDDP,
    model_path: Path,
    optim: torch.optim.Optimizer,
    assembler: tube.ChannelAssembler,
    stat: utils.MultiCounter,
    epoch: int,
    optim_params: OptimParams,
    sync_period: int,
) -> None:
    global _train_epoch_waiting_time
    #global _perfect_player
    pre_num_add = assembler.buffer_num_add()
    pre_num_sample = assembler.buffer_num_sample()
    sync_s = 0.
    num_sync = 0
    t = time.time()
    time.sleep(_train_epoch_waiting_time)
    lossmodel = DDPWrapperForModel(ddpmodel) if ddpmodel is not None else model
    for eid in range(optim_params.epoch_len):
        batch = assembler.sample(optim_params.batchsize)
        batch = utils.to_device(batch, train_device)
        loss, pred_pi, pred_v = model.loss(lossmodel, batch["s"], batch["v"],
                                           batch["pi"], batch["pi_mask"],
                                           stat)
        # _perfect_player.loss(batch['m_h'], pred_pi, pred_v, batch["pi"], batch["v"], stat)
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(),
                                             optim_params.grad_clip)
        optim.step()
        optim.zero_grad()

        if (epoch * optim_params.epoch_len + eid + 1) % sync_period == 0:
            sync_t0 = time.time()
            assembler.update_model(model.state_dict())
            sync_s += time.time() - sync_t0
            num_sync += 1

        stat["loss"].feed(loss.detach().item())
        stat["grad_norm"].feed(grad_norm)

    post_num_add = assembler.buffer_num_add()
    post_num_sample = assembler.buffer_num_sample()
    time_elapsed = time.time() - t
    delta_add = post_num_add - pre_num_add
    print("buffer add rate: %.2f / s" % (delta_add / time_elapsed))
    delta_sample = post_num_sample - pre_num_sample
    if delta_sample > 8 * delta_add:
        # If we sampled more than 8x what was added, training is outpacing
        # the actors: sleep longer before the next epoch so the buffer can
        # refill; otherwise reset the waiting time.
        _train_epoch_waiting_time += time_elapsed
    else:
        _train_epoch_waiting_time = 0
    print("buffer sample rate: %.2f / s" % (delta_sample / time_elapsed))
    print(
        f"syncing duration: {sync_s:.2f}s for {num_sync} syncs ({int(100 * sync_s / time_elapsed)}% of train time)"
    )

    stat.summary(epoch)
    wandb.log({
        "epoch": epoch,
        "loss": stat["loss"].mean(),
        "grad_norm": stat["grad_norm"].mean()
    })
    stat.reset()