def __init__(self, optimize=True):
    """Initialise the base ``Module`` and install the ordered container wrappers.

    After ``Module.__init__`` has run, the ``_parameters``, ``_buffers`` and
    ``_modules`` attributes are replaced by ``OrderedParameterDict``,
    ``OrderedBufferDict`` and ``OrderedModuleDict`` instances, each constructed
    with ``self`` as its only argument.

    Args:
        optimize: Forwarded unchanged to ``self._set_optimized``.
    """
    # must be before Module.init since the field is used in __getattr__
    # NOTE(review): no statement actually precedes ``Module.__init__`` here, so
    # the comment above may be stale -- confirm which field it refers to.
    Module.__init__(self)
    self._set_optimized(optimize)
    # Overwrite the containers with the ordered wrapper objects built around
    # this instance.
    self._parameters = OrderedParameterDict(self)
    self._buffers = OrderedBufferDict(self)
    self._modules = OrderedModuleDict(self)
def __getattr__(self, attr):
    """Resolve ``attr`` as a script method first, falling back to ``Module``.

    If ``self._has_method(attr)`` is false the lookup is delegated to
    ``Module.__getattr__``. Otherwise the method returned by
    ``self._get_method(attr)`` is handed back; when the class recorded an
    original Python method under the same name in ``_original_methods``, the
    script method is passed through ``functools.wraps`` of that original
    before being returned.
    """
    if not self._has_method(attr):
        return Module.__getattr__(self, attr)

    originals = self.__class__._original_methods
    if attr not in originals:
        return self._get_method(attr)

    # Same order as before: fetch the original first, then the script method.
    python_method = originals[attr]
    compiled_method = self._get_method(attr)
    decorate = functools.wraps(python_method)
    return decorate(compiled_method)
def set_params_with_array(
    module: Module, x: np.ndarray, property_dict: Dict[str, TorchAttr]
) -> Module:
    r"""Write the values of a flat numpy array back into a module's parameters.

    The entries of ``property_dict`` are walked in order; each one consumes
    the next slice of ``x`` (one element for a zero-dimensional shape,
    ``prod(shape)`` elements otherwise), which is converted to a tensor with
    the recorded dtype and device and copied into the parameter of the same
    name.

    Args:
        module: Module with parameters to be set
        x: Numpy array with parameter values
        property_dict: Dictionary of parameter names and torch attributes as
            returned by module_to_array.

    Returns:
        Module: module with parameters updated in-place.

    Example:
        >>> mll = ExactMarginalLogLikelihood(model.likelihood, model)
        >>> parameter_array, property_dict, bounds_out = module_to_array(mll)
        >>> parameter_array += 0.1  # perturb parameters (for example only)
        >>> mll = set_params_with_array(mll, parameter_array, property_dict)
    """
    named_params = OrderedDict(module.named_parameters())
    offset = 0
    for name, attrs in property_dict.items():
        is_scalar = len(attrs.shape) == 0
        if is_scalar:
            # A scalar tensor takes exactly one entry of the flat array.
            stop = offset + 1
            raw_values = x[offset]
        else:
            stop = offset + np.prod(attrs.shape)
            raw_values = x[offset:stop]

        values = torch.tensor(
            raw_values, dtype=attrs.dtype, device=attrs.device
        )
        if not is_scalar:
            values = values.view(*attrs.shape)
        offset = stop

        # In-place copy needs autograd switched off on the leaf; it is
        # switched back on afterwards.
        target = named_params[name]
        target.requires_grad_(False)
        target.copy_(values)
        target.requires_grad_(True)
    return module
def count_params(model: nn.Module) -> int:
    """
    Return the total number of scalar elements across all parameters of
    ``model`` (as yielded by ``model.parameters()``).
    """
    assert isinstance(model, nn.Module)
    total = 0
    for tensor in model.parameters():
        total += tensor.nelement()
    return total
def init_weight(m: nn.Module):
    """Apply Kaiming-normal initialisation to every non-bias parameter of ``m``.

    A parameter is treated as a bias (and left untouched) when the substring
    ``'bias'`` occurs anywhere in its name.
    """
    non_bias_params = (
        tensor for label, tensor in m.named_parameters()
        if 'bias' not in label
    )
    for tensor in non_bias_params:
        nn.init.kaiming_normal_(tensor.data)
def _update_target_func(_target_func: nn.Module, _func: nn.Module):
    """Copy the state dict of ``_func`` into ``_target_func``.

    Does nothing when ``_target_func`` is ``None``. When a target is given,
    ``_func`` is asserted to be non-``None`` before its state is loaded into
    the target.
    """
    if _target_func is None:
        return
    assert _func is not None
    source_state = _func.state_dict()
    _target_func.load_state_dict(source_state)
def val_sanity_fit(model: nn.Module, val_loader, criterion, device,
                   num_batches: int = None, log_interval: int = 100):
    """
    Performs a sanity fit over the validation loader.

    Use this to dummy-check your val_step function.
    It does not calculate metrics or timing per batch, and does no
    checkpointing. It iterates over ``val_loader`` only, for the given number
    of batches.

    Note:
        - It does not do loss.backward(); the whole pass runs under
          ``torch.no_grad()`` so no autograd graph is built.
        - The weighted loss is still formed for every batch, so a mismatch
          between the criterion's outputs and its ``weight_dict`` surfaces
          here.

    Args:
        model : A PyTorch Detr Model.
        val_loader : Validation loader. Must support ``len()`` and yield
            ``(inputs, targets)`` pairs, where ``targets`` is a sequence of
            dicts of tensors.
        criterion : Loss function to be optimized. Must expose a
            ``weight_dict`` attribute and return a dict of losses.
        device : "cuda" or "cpu"
        num_batches : (optional) Integer to limit the sanity fit to a certain
            number of batches. Useful if the data is too big even for a
            sanity check.
        log_interval : (optional) Default 100. A progress line is printed for
            every batch index divisible by this value and for the last batch.

    Returns:
        True once the requested batches have been processed without error.
    """
    model = model.to(device)
    criterion = criterion.to(device)
    sanity_start = time.time()
    model.eval()
    criterion.eval()
    last_idx = len(val_loader) - 1
    cnt = 0

    # Validation never back-propagates, so skip building the autograd graph
    # (this mirrors what val_step does).
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            images = [image.to(device) for image in inputs]
            targets = [{k: v.to(device) for k, v in t.items()}
                       for t in targets]
            outputs = model(images)
            loss_dict = criterion(outputs, targets)
            weight_dict = criterion.weight_dict
            # The value is discarded on purpose: computing it only proves
            # that the weighted sum can be formed.
            sum(loss_dict[k] * weight_dict[k]
                for k in loss_dict if k in weight_dict)
            cnt += 1

            if batch_idx == last_idx or batch_idx % log_interval == 0:
                print(
                    f"Val sanity check passed for batch till {batch_idx} batches"
                )

            if num_batches is not None and cnt >= num_batches:
                print(f"Done till {num_batches} val batches")
                break

    print("All specified batches done")
    sanity_end = time.time()
    print(
        f"Val sanity fit check passed in time {sanity_end - sanity_start}"
    )
    return True
def train_inner(train_data: List[Tuple[List[int], int]],
                valid_data: List[Tuple[List[int], int]],
                model: Module,
                num_classes: int,
                epochs: int,
                evaluation_period: int,
                only_epoch_eval: bool,
                model_log_directory: str,
                learning_rate: float,
                batch_size: int,
                disable_scheduler: bool = False,
                scheduler_patience: int = 10,
                scheduler_factor: float = 0.1,
                gpu_device: Optional[torch.device] = None,
                clip_threshold: Optional[float] = None,
                max_doc_len: Optional[int] = None,
                word_dropout: float = 0,
                patience: int = 30,
                resume_training: bool = False,
                disable_tqdm: bool = False,
                tqdm_update_period: int = 1) -> None:
    """Run (or resume) the full training loop with periodic evaluation.

    The function installs termination-signal handlers, optionally restores a
    "last" checkpoint found in ``model_log_directory``, trains ``model`` with
    Adam and a sum-reduced ``NLLLoss``, evaluates on the training and
    validation data at the configured cadence, writes scalars to a
    tensorboard ``SummaryWriter``, saves "best" and "last" checkpoints and
    stops early once the patience counter reaches ``patience``. An exit code
    is written through ``save_exit_code`` on every regular way out.

    Args:
        train_data: Training samples; each item is indexed as ``x[0]`` for the
            document and ``x[1]`` for the gold label.
        valid_data: Validation samples, same layout as ``train_data``.
        model: Model to train; its ``embeddings`` attribute is passed to
            ``Batch``.
        num_classes: Forwarded to ``train_batch`` and ``compute_loss``.
        epochs: Upper bound (exclusive) of the epoch index range.
        evaluation_period: Evaluate every this many updates, unless
            ``only_epoch_eval`` is set.
        only_epoch_eval: If true, evaluate only at the last update of an epoch.
        model_log_directory: Directory holding checkpoints, the ``exit_code``
            file and the tensorboard ``events`` sub-directory.
        learning_rate: Learning rate given to Adam.
        batch_size: Batch size used for chunking and for
            ``updates_per_epoch``.
        disable_scheduler: If true, no ``ReduceLROnPlateau`` is created.
        scheduler_patience: ``patience`` argument of the scheduler.
        scheduler_factor: ``factor`` argument of the scheduler.
        gpu_device: If not ``None``, the model is moved to this device.
        clip_threshold: If not ``None`` and positive, gradient clipping is
            enabled through ``enable_gradient_clipping``.
        max_doc_len: Forwarded to ``Batch`` and ``evaluate_metric``.
        word_dropout: Forwarded to ``Batch`` for training batches only
            (validation batches use ``0.``).
        patience: Number of consecutive non-improving evaluations after which
            training stops.
        resume_training: Try to resume from a ``*last*.pt`` checkpoint; reset
            to ``False`` internally when none exists.
        disable_tqdm: Forwarded to ``tqdm(disable=...)``.
        tqdm_update_period: Progress-bar postfix is refreshed every this many
            updates (and on the last one).

    Returns:
        None.
    """
    # create signal handlers in case script receives termination signals
    # adapted from: https://stackoverflow.com/a/31709094
    for specific_signal in [
            signal.SIGINT, signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT
    ]:
        signal.signal(
            specific_signal,
            partial(signal_handler,
                    os.path.join(model_log_directory, "exit_code")))

    # initialize general local variables
    updates_per_epoch = ceil(len(train_data) / batch_size)
    patience_reached = False

    # load model checkpoint if training is being resumed
    # (the first file matching "*last*.pt" is used, loaded onto the CPU)
    if resume_training and len(
            glob(os.path.join(model_log_directory, "*last*.pt"))) > 0:
        model_checkpoint = torch.load(glob(
            os.path.join(model_log_directory, "*last*.pt"))[0],
                                      map_location=torch.device("cpu"))
        model.load_state_dict(
            model_checkpoint["model_state_dict"])  # type: ignore
        # if the checkpoint was taken at the final update of its epoch,
        # resume at the start of the next epoch; otherwise resume at the
        # update following the checkpointed one
        if (model_checkpoint["update"] +  # type: ignore
                1) == updates_per_epoch:  # type: ignore
            current_epoch: int = model_checkpoint["epoch"] + 1  # type: ignore
            current_update: int = 0
        else:
            current_epoch: int = model_checkpoint["epoch"]  # type: ignore
            current_update: int = model_checkpoint["update"] + 1  # type: ignore
        best_valid_loss: float = model_checkpoint[  # type: ignore
            "best_valid_loss"]  # type: ignore
        best_valid_loss_index: int = model_checkpoint[  # type: ignore
            "best_valid_loss_index"]  # type: ignore
        best_valid_acc: float = model_checkpoint[  # type: ignore
            "best_valid_acc"]  # type: ignore

        # check for edge-case failures
        if current_epoch >= epochs:
            # log information at the end of training
            LOGGER.info("%s training epoch(s) previously completed, exiting" %
                        epochs)
            # save exit-code and final processes
            save_exit_code(os.path.join(model_log_directory, "exit_code"),
                           FINISHED_EPOCHS)
            return None
        elif best_valid_loss_index >= patience:
            LOGGER.info("Patience threshold previously reached, exiting")
            # save exit-code and final processes
            save_exit_code(os.path.join(model_log_directory, "exit_code"),
                           PATIENCE_REACHED)
            return None
    else:
        # nothing to resume from: start from scratch
        resume_training = False
        current_epoch = 0
        current_update = 0
        best_valid_loss_index = 0
        best_valid_loss = float("inf")
        best_valid_acc = float("-inf")

    # send model to correct device
    if gpu_device is not None:
        LOGGER.info("Transferring model to GPU device: %s" % gpu_device)
        model.to(gpu_device)

    # instantiate Adam optimizer
    LOGGER.info("Initializing Adam optimizer with LR: %s" % learning_rate)
    optimizer = Adam(model.parameters(), lr=learning_rate)

    # load optimizer state dictionary
    if resume_training:
        optimizer.load_state_dict(
            model_checkpoint["optimizer_state_dict"])  # type: ignore

    # instantiate negative log-likelihood loss which is summed over batch
    LOGGER.info("Using NLLLoss with sum reduction")
    loss_function = NLLLoss(weight=None, reduction="sum")

    # enable gradient clipping in-place if provided
    if clip_threshold is not None and clip_threshold > 0:
        LOGGER.info("Enabling gradient clipping with threshold: %s" %
                    clip_threshold)
        enable_gradient_clipping(model, clip_threshold)

    # initialize learning rate scheduler if relevant
    if not disable_scheduler:
        LOGGER.info(("Initializing learning rate scheduler with "
                     "factor=%s and patience=%s") %
                    (scheduler_factor, scheduler_patience))
        scheduler: Optional[ReduceLROnPlateau]
        scheduler = ReduceLROnPlateau(optimizer,
                                      mode='min',
                                      factor=scheduler_factor,
                                      patience=scheduler_patience,
                                      verbose=True)
        if resume_training:
            scheduler.load_state_dict(
                model_checkpoint["scheduler_state_dict"])  # type: ignore
    else:
        scheduler = None

    # initialize tensorboard writer (always created, under "events")
    # NOTE(review): no writer.close()/flush() call is visible in this
    # function -- confirm whether pending events can be lost on return.
    LOGGER.info("Initializing tensorboard writer in directory: %s" %
                os.path.join(model_log_directory, "events"))
    writer = SummaryWriter(os.path.join(model_log_directory, "events"))

    # set numpy and torch RNG back to previous states before training
    if resume_training:
        if current_update == 0:
            np.random.set_state(
                model_checkpoint["numpy_last_random_state"])  # type: ignore
        else:
            # mid-epoch resume: start from the state cached at epoch start;
            # the "last" state is restored inside the batch loop once the
            # already-processed updates have been skipped
            np.random.set_state(
                model_checkpoint["numpy_epoch_random_state"])  # type: ignore
        torch.random.set_rng_state(
            model_checkpoint["torch_last_random_state"])  # type: ignore

    # loop over epochs
    for epoch in range(current_epoch, epochs):
        # set model on train mode and enable autograd
        model.train()
        torch.autograd.set_grad_enabled(True)

        # initialize loop variables
        # (restored from the checkpoint only for a mid-epoch resume)
        if resume_training and epoch == current_epoch and current_update != 0:
            train_loss: Union[float, torch.Tensor] = model_checkpoint[  # type: ignore
                "train_loss"]  # type: ignore
            samples_seen: int = model_checkpoint[  # type: ignore
                "samples_seen"]  # type: ignore
        else:
            train_loss = 0.
            samples_seen = 0

        # cache numpy random state for model checkpoint
        numpy_epoch_random_state = np.random.get_state()

        # main training loop
        LOGGER.info("Training SoPa++ model")
        with tqdm(shuffled_chunked_sorted(train_data, batch_size),
                  position=0,
                  mininterval=0.05,
                  disable=disable_tqdm,
                  unit="batch",
                  desc="Training [Epoch %s/%s]" %
                  (epoch + 1, epochs)) as train_tqdm_batches:
            # loop over train batches
            for update, batch in enumerate(train_tqdm_batches):
                # return to previous update and random state, if relevant
                if (resume_training and epoch == current_epoch
                        and current_update != 0):
                    if update < current_update:
                        # already processed before the checkpoint: skip
                        continue
                    elif update == current_update:
                        np.random.set_state(model_checkpoint[  # type: ignore
                            "numpy_last_random_state"])  # type: ignore

                # create batch object and parse out gold labels
                batch, gold = Batch(
                    [x[0] for x in batch],
                    model.embeddings,  # type: ignore
                    to_cuda(gpu_device),
                    word_dropout,
                    max_doc_len), [x[1] for x in batch]

                # find aggregate loss across samples in batch
                train_batch_loss = train_batch(model, batch, num_classes, gold,
                                               optimizer, loss_function,
                                               gpu_device)

                # add batch loss to train_loss
                train_loss += train_batch_loss  # type: ignore

                # increment samples seen
                samples_seen += batch.size()

                # update tqdm progress bar
                if (update + 1) % tqdm_update_period == 0 or (
                        update + 1) == len(train_tqdm_batches):
                    train_tqdm_batches.set_postfix(
                        batch_loss=train_batch_loss.item() / batch.size())

                # start evaluation routine
                # (every evaluation_period updates unless only_epoch_eval,
                # and always on the last update of the epoch)
                if (not only_epoch_eval and (update + 1) % evaluation_period
                        == 0) or (update + 1) == len(train_tqdm_batches):
                    # update tqdm batches counter
                    train_tqdm_batches.update()

                    # set valid loss to zero
                    update_number = (epoch * updates_per_epoch) + (update + 1)
                    valid_loss: Union[float, torch.Tensor] = 0.

                    # set model on eval mode and disable autograd
                    model.eval()
                    torch.autograd.set_grad_enabled(False)

                    # compute mean train loss over updates and accuracy
                    # NOTE: mean_train_loss contains stochastic noise
                    LOGGER.info("Evaluating SoPa++ on training set")
                    train_loss = cast(torch.Tensor, train_loss)
                    mean_train_loss = train_loss.item() / samples_seen
                    train_acc = evaluate_metric(model, train_data, batch_size,
                                                gpu_device, accuracy_score,
                                                max_doc_len)

                    # add training loss data
                    writer.add_scalar("loss/train_loss", mean_train_loss,
                                      update_number)
                    writer.add_scalar("accuracy/train_accuracy", train_acc,
                                      update_number)

                    # add named parameter data
                    for name, param in model.named_parameters():
                        writer.add_scalar("parameter_mean/" + name,
                                          param.detach().mean(), update_number)
                        writer.add_scalar("parameter_std/" + name,
                                          param.detach().std(), update_number)
                        if param.grad is not None:
                            writer.add_scalar("gradient_mean/" + name,
                                              param.grad.detach().mean(),
                                              update_number)
                            writer.add_scalar("gradient_std/" + name,
                                              param.grad.detach().std(),
                                              update_number)

                    # loop over static valid set
                    LOGGER.info("Evaluating SoPa++ on validation set")
                    with tqdm(chunked_sorted(valid_data, batch_size),
                              position=0,
                              mininterval=0.05,
                              disable=disable_tqdm,
                              unit="batch",
                              desc="Validating [Epoch %s/%s] [Batch %s/%s]" %
                              (epoch + 1, epochs, update + 1,
                               updates_per_epoch)) as valid_tqdm_batches:
                        for valid_update, batch in enumerate(
                                valid_tqdm_batches):
                            # create batch object and parse out gold labels
                            # (word dropout is fixed to 0. for validation)
                            batch, gold = Batch(
                                [x[0] for x in batch],
                                model.embeddings,  # type: ignore
                                to_cuda(gpu_device),
                                0.,
                                max_doc_len), [x[1] for x in batch]

                            # find aggregate loss across valid samples in batch
                            valid_batch_loss = compute_loss(
                                model, batch, num_classes, gold, loss_function,
                                gpu_device)

                            # add batch loss to valid_loss
                            valid_loss += valid_batch_loss  # type: ignore

                            if (valid_update + 1) % tqdm_update_period == 0 or (
                                    valid_update + 1) == len(valid_tqdm_batches):
                                valid_tqdm_batches.set_postfix(
                                    batch_loss=valid_batch_loss.item() /
                                    batch.size())

                    # compute mean valid loss and accuracy
                    valid_loss = cast(torch.Tensor, valid_loss)
                    mean_valid_loss = valid_loss.item() / len(valid_data)
                    valid_acc = evaluate_metric(model, valid_data, batch_size,
                                                gpu_device, accuracy_score,
                                                max_doc_len)

                    # set model on train mode and enable autograd
                    model.train()
                    torch.autograd.set_grad_enabled(True)

                    # add valid loss data to tensorboard
                    writer.add_scalar("loss/valid_loss", mean_valid_loss,
                                      update_number)
                    writer.add_scalar("accuracy/valid_accuracy", valid_acc,
                                      update_number)

                    # log out report of current evaluation state
                    LOGGER.info("Epoch: {}/{}, Batch: {}/{}".format(
                        epoch + 1, epochs, (update + 1), updates_per_epoch))
                    LOGGER.info("Mean training loss: {:.3f}, "
                                "Training accuracy: {:.3f}%".format(
                                    mean_train_loss, train_acc * 100))
                    LOGGER.info("Mean validation loss: {:.3f}, "
                                "Validation accuracy: {:.3f}%".format(
                                    mean_valid_loss, valid_acc * 100))

                    # apply learning rate scheduler after evaluation
                    # (stepped on the summed validation loss tensor)
                    if scheduler is not None:
                        scheduler.step(valid_loss)

                    # check for loss improvement and save model if necessary
                    # optionally increment patience counter or stop training
                    # NOTE: loss values are summed over all data (not mean)
                    if valid_loss.item() < best_valid_loss:
                        # log information and update records
                        LOGGER.info("New best validation loss")
                        if valid_acc > best_valid_acc:
                            best_valid_acc = valid_acc
                            LOGGER.info("New best validation accuracy")

                        # update patience related diagnostics
                        best_valid_loss = valid_loss.item()
                        best_valid_loss_index = 0
                        LOGGER.info("Patience counter: %s/%s" %
                                    (best_valid_loss_index, patience))

                        # find previous best checkpoint(s)
                        # (collected before saving so the new file is not
                        # part of the deletion list)
                        legacy_checkpoints = glob(
                            os.path.join(model_log_directory, "*_best_*.pt"))

                        # save new best checkpoint
                        model_save_file = os.path.join(
                            model_log_directory,
                            "spp_checkpoint_best_{}_{}.pt".format(
                                epoch, (update + 1)))
                        LOGGER.info("Saving best checkpoint: %s" %
                                    model_save_file)
                        save_checkpoint(epoch, update, samples_seen, model,
                                        optimizer, scheduler,
                                        numpy_epoch_random_state,
                                        train_loss.item(), best_valid_loss,
                                        best_valid_loss_index, best_valid_acc,
                                        model_save_file)

                        # delete previous best checkpoint(s)
                        for legacy_checkpoint in legacy_checkpoints:
                            os.remove(legacy_checkpoint)
                    else:
                        # update patience related diagnostics
                        best_valid_loss_index += 1
                        LOGGER.info("Patience counter: %s/%s" %
                                    (best_valid_loss_index, patience))

                        # create hook to exit training if patience reached
                        if best_valid_loss_index == patience:
                            patience_reached = True

                    # find previous last checkpoint(s)
                    legacy_checkpoints = glob(
                        os.path.join(model_log_directory, "*_last_*.pt"))

                    # save latest checkpoint
                    model_save_file = os.path.join(
                        model_log_directory,
                        "spp_checkpoint_last_{}_{}.pt".format(
                            epoch, (update + 1)))
                    LOGGER.info("Saving last checkpoint: %s" % model_save_file)
                    save_checkpoint(epoch, update, samples_seen, model,
                                    optimizer, scheduler,
                                    numpy_epoch_random_state,
                                    train_loss.item(), best_valid_loss,
                                    best_valid_loss_index, best_valid_acc,
                                    model_save_file)

                    # delete previous last checkpoint(s)
                    for legacy_checkpoint in legacy_checkpoints:
                        os.remove(legacy_checkpoint)

                    # hook to stop training in case patience was reached
                    # if it was reached strictly before last epoch and update
                    # (max(range(epochs)) is the last epoch index, epochs - 1)
                    if patience_reached:
                        if not (epoch == max(range(epochs)) and
                                (update + 1) == len(train_tqdm_batches)):
                            LOGGER.info("Patience threshold reached, "
                                        "stopping training")
                            # save exit-code and final processes
                            save_exit_code(
                                os.path.join(model_log_directory, "exit_code"),
                                PATIENCE_REACHED)
                            return None

    # log information at the end of training
    LOGGER.info("%s training epoch(s) completed, stopping training" % epochs)
    # save exit-code and final processes
    save_exit_code(os.path.join(model_log_directory, "exit_code"),
                   FINISHED_EPOCHS)
def train_on_loader(model: nn.Module,
                    train_gen: DataLoader,
                    val_gen: Optional[DataLoader],
                    loss_fn: Any,
                    optimizer: Optimizer,
                    n_epochs: int,
                    batch_first: bool = False,
                    device: Optional[torch.device] = torch.device('cpu'),
                    callbacks: Optional[List[Callback]] = None,
                    before_step=None,
                    verbosity: int = 2) -> ModelHistory:
    """Trains a model using data from a DataLoader.

    # Arguments
        model: The PyTorch model. It must also provide
            `model.metric(output, y_true)` returning a dict of metric values.
        train_gen: A DataLoader containing the training data. The last item of
            every batch is taken as the labels, everything before it as the
            model inputs.
        val_gen: A DataLoader containing the validation data. Skipped when
            falsy.
        loss_fn: The loss function from which gradients are computed. Its
            expected signature is `loss_fn(model_output, y_true)`.
        optimizer: The optimizer used in the backpropagation step.
        n_epochs: How many passes should be performed over the train_gen.
        batch_first: Selects which dimension of the first feature tensor holds
            the batch size: dimension 0 when True, dimension 1 otherwise.
        device: Device the batch features and labels are moved to.
        callbacks: List of utility callbacks to help training the model.
        before_step: Optional callable invoked as
            `before_step(model, loss, optimizer)` after `loss.backward()` and
            before `optimizer.step()`.
        verbosity: 0: silent, 1:show epoch progress bar, 2: show batch progress bar.

    # Return
        A ModelHistory object representing the model training history.
    """
    hooks = CallbacksContainer(callbacks or [])
    sample_axis = 0 if batch_first else 1
    history = ModelHistory(model)

    epochs = range(1, n_epochs + 1)
    if verbosity == 1:
        epochs = tqdm.tqdm(epochs, desc='Epoch')
    elif verbosity == 2:
        hooks.append(ProgressBar(len(train_gen), n_epochs))

    for epoch in epochs:
        model.train()
        hooks.on_epoch_begin(epoch, history)

        running_loss = 0
        n_seen = 0
        metric_totals = defaultdict(int)

        for step, batch_data in enumerate(train_gen):
            hooks.on_batch_begin(step, history)

            # Slicing keeps the inputs a sequence even for a plain [x, y]
            # batch, so the model can always be called with *inputs.
            raw_inputs: list = batch_data[:-1]
            raw_labels = batch_data[-1]
            inputs = [_move_to_device(item, device) for item in raw_inputs]
            labels = raw_labels.to(device)

            optimizer.zero_grad()
            output = model(*inputs)
            loss = loss_fn(output, labels)
            loss.backward()
            if before_step:
                before_step(model, loss, optimizer)
            optimizer.step()

            # Every feature tensor is expected to hold the same number of
            # samples, so the first one is enough to read the batch size.
            n_seen += inputs[0].size(sample_axis)
            running_loss += loss.item()

            # Accumulate the epoch's metrics; the loss entry is the running
            # mean over the batches seen so far.
            for metric_name, metric_value in model.metric(output,
                                                          labels).items():
                metric_totals[metric_name] += metric_value
            metric_totals['loss'] = running_loss / (step + 1)

            # Normalised snapshot used for the progress display.
            history.append_batch_data(
                _normalize_metrics(metric_totals, n_seen))
            hooks.on_batch_end(step, history)

        history.append_trn_logs(_normalize_metrics(metric_totals, n_seen))

        if val_gen:
            raw_val_logs = evaluate_on_loader(model,
                                              val_gen,
                                              loss_fn,
                                              batch_first,
                                              device,
                                              verbosity=0)
            # Validation metrics are stored under a 'val_' prefix.
            history.append_dev_logs({
                'val_' + metric_name: metric_value
                for metric_name, metric_value in raw_val_logs.items()
            })

        hooks.on_epoch_end(epoch, history)
        if history.should_stop_training():
            break

    history.close(n_epochs)
    hooks.on_train_end()
    return history
def val_step(model: nn.Module, val_loader, criterion, device,
             num_batches: int = None, log_interval: int = 100):
    """
    Runs one validation pass and returns the averaged losses.

    The model and criterion are put in eval mode and every batch is evaluated
    under ``torch.no_grad()``. The weighted total loss and the ``loss_bbox``,
    ``loss_giou`` and ``loss_ce`` entries of the criterion output are tracked
    with running-average meters.

    Args:
        model : PyTorch Detr Model.
        val_loader : Validation loader.
        criterion : Detr Loss function to be optimized.
        device : "cuda" or "cpu"
        num_batches : (optional) Integer To limit validation to certain number of batches.
        log_interval : (optional) Default 100. The batch timing is printed for
            every batch index divisible by this value and for the last batch.

    Returns:
        An OrderedDict with the keys ``total_loss``, ``bbox_loss``,
        ``giou_loss`` and ``labels_loss`` holding the meter averages.
    """
    model = model.to(device)
    started_at = time.time()
    final_idx = len(val_loader) - 1
    batch_timer = utils.AverageMeter()
    processed = 0
    model.eval()
    criterion.eval()
    tick = time.time()
    metrics = OrderedDict()

    # One meter per reported metric, in the order they appear in the result.
    meters = OrderedDict()
    for metric_name in ("total_loss", "bbox_loss", "giou_loss",
                        "labels_loss"):
        meters[metric_name] = utils.AverageMeter()

    stopped_early = False
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            images = [image.to(device) for image in inputs]
            targets = [{k: v.to(device) for k, v in t.items()}
                       for t in targets]

            outputs = model(images)
            loss_dict = criterion(outputs, targets)
            weight_dict = criterion.weight_dict
            loss = sum(loss_dict[k] * weight_dict[k]
                       for k in loss_dict.keys() if k in weight_dict)
            processed += 1

            meters["total_loss"].update(loss.item())
            meters["bbox_loss"].update(loss_dict["loss_bbox"].item())
            meters["giou_loss"].update(loss_dict["loss_giou"].item())
            meters["labels_loss"].update(loss_dict["loss_ce"].item())

            batch_timer.update(time.time() - tick)
            tick = time.time()

            if batch_idx == final_idx or batch_idx % log_interval == 0:
                print(
                    "Batch Validation Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  "
                    .format(batch_time=batch_timer))

            if num_batches is not None and processed >= num_batches:
                stopped_early = True
                break

    finished_at = time.time()
    for metric_name, meter in meters.items():
        metrics[metric_name] = meter.avg

    if stopped_early:
        print(f"Done till {num_batches} Validation batches")
    print(
        f"Time taken for validation step = {finished_at - started_at} sec"
    )
    return metrics
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None,
          chunk_names: bool = False,
          val_smiles: List[str] = None,
          test_smiles: List[str] = None) -> int:
    """
    Trains a model for an epoch.

    Depending on ``args`` one of three per-batch paths is taken: the
    mixture-of-experts path (``args.moe``), the MAML path (``args.maml``), or
    the default supervised path. When ``chunk_names`` is true the function
    instead loads each pickled chunk in turn and calls itself on it.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
        With ``chunk_names`` it is iterated as ``(path, memo_path)`` pairs.
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler. It is stepped once per batch
        only when it is a ``NoamLR`` instance.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results. ``print`` is
        used when it is ``None``.
    :param writer: A tensorboardX SummaryWriter.
    :param chunk_names: Whether to train on the data in chunks. In this case,
        data must be a list of paths to the data chunks.
    :param val_smiles: Validation smiles strings without targets.
    :param test_smiles: Test smiles strings without targets, used for adversarial setting.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    if args.dataset_type == 'bert_pretraining':
        features_loss = nn.MSELoss()

    if chunk_names:
        # Chunked mode: load every chunk, recurse on it, then return.
        for path, memo_path in tqdm(data, total=len(data)):
            featurization.SMILES_TO_FEATURES = dict()
            if os.path.isfile(memo_path):
                found_memo = True
                # NOTE(review): unpickling is only safe for trusted files.
                with open(memo_path, 'rb') as f:
                    featurization.SMILES_TO_FEATURES = pickle.load(f)
            else:
                found_memo = False
            with open(path, 'rb') as f:
                chunk = pickle.load(f)
            if args.moe:
                for source in chunk:
                    source.shuffle()
            else:
                chunk.shuffle()
            n_iter = train(model=model,
                           data=chunk,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           logger=logger,
                           writer=writer,
                           chunk_names=False,
                           val_smiles=val_smiles,
                           test_smiles=test_smiles)
            if not found_memo:
                # NOTE(review): the memo file is read into SMILES_TO_FEATURES
                # above but SMILES_TO_GRAPH is what gets written here --
                # confirm that this asymmetry is intended.
                with open(memo_path, 'wb') as f:
                    pickle.dump(featurization.SMILES_TO_GRAPH,
                                f,
                                protocol=pickle.HIGHEST_PROTOCOL)
        return n_iter

    if not args.moe:
        data.shuffle()

    loss_sum, iter_count = 0, 0
    if args.adversarial:
        if args.moe:
            train_smiles = []
            for d in data:
                train_smiles += d.smiles()
        else:
            train_smiles = data.smiles()
        train_val_smiles = train_smiles + val_smiles
        d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0

    # Decide how many iterations the epoch has for the selected mode.
    if args.moe:
        test_smiles = list(test_smiles)
        random.shuffle(test_smiles)
        # train_smiles is rebuilt here as one smiles list per source dataset
        train_smiles = []
        for d in data:
            d.shuffle()
            train_smiles.append(d.smiles())
        num_iters = min(len(test_smiles), min([len(d) for d in data]))
    elif args.maml:
        num_iters = args.maml_batches_per_epoch * args.maml_batch_size
        model.zero_grad()
        maml_sum_loss = 0
    else:
        # without last_batch, round down to a whole number of batches
        num_iters = len(data) if args.last_batch else len(
            data) // args.batch_size * args.batch_size

    if args.parallel_featurization:
        # A separate process fills batch_queue through async_mol2graph;
        # exit_queue is used at the end to signal that training is done.
        batch_queue = Queue(args.batch_queue_max_size)
        exit_queue = Queue(1)
        batch_process = Process(target=async_mol2graph,
                                args=(batch_queue, data, args, num_iters,
                                      args.batch_size, exit_queue,
                                      args.last_batch))
        batch_process.start()
        currently_loaded_batches = []

    # MAML advances one task at a time; the other modes advance by batch.
    iter_size = 1 if args.maml else args.batch_size

    for i in trange(0, num_iters, iter_size):
        if args.moe:
            if not args.batch_domain_encs:
                model.compute_domain_encs(
                    train_smiles)  # want to recompute every batch
            mol_batch = [
                MoleculeDataset(d[i:i + args.batch_size]) for d in data
            ]
            train_batch, train_targets = [], []
            for b in mol_batch:
                tb, tt = b.smiles(), b.targets()
                train_batch.append(tb)
                train_targets.append(tt)
            test_batch = test_smiles[i:i + args.batch_size]
            loss = model.compute_loss(train_batch, train_targets, test_batch)
            model.zero_grad()

            loss_sum += loss.item()
            iter_count += len(mol_batch)
        elif args.maml:
            task_train_data, task_test_data, task_idx = data.sample_maml_task(
                args)
            mol_batch = task_test_data
            smiles_batch, features_batch, target_batch = task_train_data.smiles(
            ), task_train_data.features(), task_train_data.targets(task_idx)
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor(target_batch).unsqueeze(1)
            if next(model.parameters()).is_cuda:
                targets = targets.cuda()
            preds = model(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            # Inner step: build adapted parameters theta_prime from the
            # gradient of the task-train loss.
            grad = torch.autograd.grad(
                loss, [p for p in model.parameters() if p.requires_grad])
            theta = [
                p for p in model.named_parameters() if p[1].requires_grad
            ]  # comes in same order as grad
            theta_prime = {
                p[0]: p[1] - args.maml_lr * grad[i]
                for i, p in enumerate(theta)
            }
            for name, nongrad_param in [
                    p for p in model.named_parameters()
                    if not p[1].requires_grad
            ]:
                theta_prime[name] = nongrad_param + torch.zeros(
                    nongrad_param.size()).to(nongrad_param)
        else:
            # Prepare batch
            if args.parallel_featurization:
                if len(currently_loaded_batches) == 0:
                    currently_loaded_batches = batch_queue.get()
                mol_batch, featurized_mol_batch = currently_loaded_batches.pop(
                )
            else:
                if not args.last_batch and i + args.batch_size > len(data):
                    break
                mol_batch = MoleculeDataset(data[i:i + args.batch_size])
            smiles_batch, features_batch, target_batch = mol_batch.smiles(
            ), mol_batch.features(), mol_batch.targets()

            if args.dataset_type == 'bert_pretraining':
                batch = mol2graph(smiles_batch, args)
                mask = mol_batch.mask()
                batch.bert_mask(mask)
                mask = 1 - torch.FloatTensor(mask)  # num_atoms
                features_targets = torch.FloatTensor(
                    target_batch['features']
                ) if target_batch[
                    'features'] is not None else None  # num_molecules x features_size
                targets = torch.FloatTensor(target_batch['vocab'])  # num_atoms
                if args.bert_vocab_func == 'feature_vector':
                    mask = mask.reshape(-1, 1)
                else:
                    targets = targets.long()
            else:
                batch = smiles_batch
                # mask is 1 where a target is present; missing targets are
                # replaced by 0 in the targets tensor
                mask = torch.Tensor([[x is not None for x in tb]
                                     for tb in target_batch])
                targets = torch.Tensor([[0 if x is None else x for x in tb]
                                        for tb in target_batch])

            if next(model.parameters()).is_cuda:
                mask, targets = mask.cuda(), targets.cuda()

                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    features_targets = features_targets.cuda()

            if args.class_balance:
                class_weights = []
                for task_num in range(data.num_tasks()):
                    class_weights.append(
                        args.class_weights[task_num][targets[:,
                                                             task_num].long()])
                class_weights = torch.stack(
                    class_weights).t()  # num_molecules x num_tasks
            else:
                class_weights = torch.ones(targets.shape)

            if args.cuda:
                class_weights = class_weights.cuda()

            # Run model
            model.zero_grad()
            if args.parallel_featurization:
                previous_graph_input_mode = model.encoder.graph_input
                model.encoder.graph_input = True  # force model to accept already processed input
                preds = model(featurized_mol_batch, features_batch)
                model.encoder.graph_input = previous_graph_input_mode
            else:
                preds = model(batch, features_batch)
            if args.dataset_type == 'regression_with_binning':
                preds = preds.view(targets.size(0), targets.size(1), -1)
                targets = targets.long()
                loss = 0
                for task in range(targets.size(1)):
                    loss += loss_func(
                        preds[:, task, :], targets[:, task]
                    ) * class_weights[:, task] * mask[:, task]  # for some reason cross entropy doesn't support multi target
                loss = loss.sum() / mask.sum()
            else:
                if args.dataset_type == 'unsupervised':
                    targets = targets.long().reshape(-1)

                if args.dataset_type == 'bert_pretraining':
                    features_preds, preds = preds['features'], preds['vocab']

                if args.dataset_type == 'kernel':
                    preds = preds.view(int(preds.size(0) / 2), 2,
                                       preds.size(1))
                    preds = model.kernel_output_layer(preds)

                loss = loss_func(preds, targets) * class_weights * mask
                if args.predict_features_and_task:
                    # columns before the last features_size ones get extra
                    # weight (task_weight - 1) in numerator and denominator
                    loss = (loss.sum() + loss[:, :-args.features_size].sum() * (args.task_weight-1)) \
                        / (mask.sum() + mask[:, :-args.features_size].sum() * (args.task_weight-1))
                else:
                    loss = loss.sum() / mask.sum()

                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    loss += features_loss(features_preds, features_targets)

            loss_sum += loss.item()
            iter_count += len(mol_batch)

        if args.maml:
            # Outer step: evaluate the adapted model on the task-test data and
            # accumulate its loss; the optimizer steps once every
            # maml_batch_size tasks.
            model_prime = build_model(args=args, params=theta_prime)
            smiles_batch, features_batch, target_batch = task_test_data.smiles(
            ), task_test_data.features(), [
                t[task_idx] for t in task_test_data.targets()
            ]
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor([[t] for t in target_batch])
            if next(model_prime.parameters()).is_cuda:
                targets = targets.cuda()
            model_prime.zero_grad()
            preds = model_prime(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            loss_sum += loss.item()
            iter_count += len(
                smiles_batch
            )  # TODO check that this makes sense, but it's just for display
            maml_sum_loss += loss
            if i % args.maml_batch_size == args.maml_batch_size - 1:
                maml_sum_loss.backward()
                optimizer.step()
                model.zero_grad()
                maml_sum_loss = 0
        else:
            loss.backward()
            if args.max_grad_norm is not None:
                clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()

        if args.adjust_weight_decay:
            # Lower the weight decay (not below 0) while the parameter norm is
            # under the target, raise it otherwise.
            # NOTE(review): the loops below rebind the outer loop variable
            # ``i``; it is reassigned by trange on the next iteration and no
            # later statement in this iteration appears to read the batch
            # offset -- confirm.
            current_pnorm = compute_pnorm(model)
            if current_pnorm < args.pnorm_target:
                for i in range(len(optimizer.param_groups)):
                    optimizer.param_groups[i]['weight_decay'] = max(
                        0, optimizer.param_groups[i]['weight_decay'] -
                        args.adjust_weight_decay_step)
            else:
                for i in range(len(optimizer.param_groups)):
                    optimizer.param_groups[i][
                        'weight_decay'] += args.adjust_weight_decay_step

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        if args.adversarial:
            # gan_d_per_g discriminator updates per generator update
            for _ in range(args.gan_d_per_g):
                train_val_smiles_batch = random.sample(train_val_smiles,
                                                       args.batch_size)
                test_smiles_batch = random.sample(test_smiles,
                                                  args.batch_size)
                d_loss, gp_norm = model.train_D(train_val_smiles_batch,
                                                test_smiles_batch)
            train_val_smiles_batch = random.sample(train_val_smiles,
                                                   args.batch_size)
            test_smiles_batch = random.sample(test_smiles, args.batch_size)
            g_loss = model.train_G(train_val_smiles_batch, test_smiles_batch)

            # we probably only care about the g_loss honestly
            d_loss_sum += d_loss * args.batch_size
            gp_norm_sum += gp_norm * args.batch_size
            g_loss_sum += g_loss * args.batch_size

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            if args.adversarial:
                d_loss_avg, g_loss_avg, gp_norm_avg = d_loss_sum / iter_count, g_loss_sum / iter_count, gp_norm_sum / iter_count
                d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0
            # running sums restart after every log line
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )
            if args.adversarial:
                debug(
                    f'D Loss = {d_loss_avg:.4e}, G Loss = {g_loss_avg:.4e}, GP Norm = {gp_norm_avg:.4}'
                )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    if args.parallel_featurization:
        exit_queue.put(
            0)  # dummy var to get the subprocess to know that we're done
        batch_process.join()

    return n_iter
def infer_model_device(model: nn.Module):
    """Return the device that holds the most parameter and buffer tensors.

    Each parameter or buffer counts once, regardless of its size. Raises
    ``ValueError`` when the model has neither parameters nor buffers.
    """
    tensors_per_device = Counter()
    for candidate in chain(model.parameters(), model.buffers()):
        if torch.is_tensor(candidate):
            tensors_per_device[candidate.device] += 1
    # On a tie the device encountered first wins.
    winner, _ = max(tensors_per_device.items(), key=lambda entry: entry[1])
    return winner
def train(args, model: nn.Module, criterion, *, params,
          train_loader, valid_loader, init_optimizer, use_cuda,
          n_epochs=None, patience=4, max_lr_changes=2) -> bool:
    """Train ``model`` epoch by epoch with validation, checkpoints and LR decay.

    After each epoch a checkpoint ``model-{epoch}.pt`` is written under
    ``args.run_root``, the model is validated and the metrics are appended to
    ``train.log``. When the validation loss improves, ``best-model.pt`` is
    rewritten. When it has not improved for more than ``patience`` epochs the
    learning rate is divided by 5 and the optimizer is rebuilt; training stops
    once that has happened more than ``max_lr_changes`` times.

    Args:
        args: namespace providing ``lr``, ``n_epochs``, ``optimizer``,
            ``run_root``, ``epoch_size``, ``batch_size``, ``step`` and ``model``.
        model: network to train.
        criterion: loss callable; its result is passed to ``_reduce_loss``.
        params: iterable of parameters handed to ``init_optimizer``.
        train_loader: iterable of ``(inputs, targets)`` batches.
        valid_loader: loader forwarded to ``validation``.
        init_optimizer: callable ``(optimizer_name, params, lr) -> optimizer``.
        use_cuda: move every batch to the GPU when true.
        n_epochs: overrides ``args.n_epochs`` when truthy.
        patience: epochs without improvement tolerated before decaying the LR.
        max_lr_changes: maximum number of LR decays before stopping.

    Returns:
        ``False`` when interrupted with Ctrl+C (a snapshot is written to
        ``model-interrupted.pt`` first), ``True`` otherwise.
    """
    lr = args.lr
    n_epochs = n_epochs or args.n_epochs
    params = list(params)
    optimizer = init_optimizer(args.optimizer, params, lr)

    run_root = Path(args.run_root)
    model_path = run_root / 'model-1.pt'
    if model_path.exists():
        state = load_model(model, model_path)
        epoch = state['epoch']
        step = state['step']
        # Restore the best loss stored in the checkpoint. Starting a resumed
        # run from 0.0 would mean no epoch could ever count as an improvement.
        best_valid_loss = state.get('best_valid_loss', float('inf'))
    else:
        epoch = 1
        step = 0
        best_valid_loss = float('inf')
    lr_changes = 0

    def save(ep, save_name):
        # Reads ``step`` and ``best_valid_loss`` at call time, so every
        # snapshot records the current values.
        torch.save({
            'model': model.state_dict(),
            'epoch': ep,
            'step': step,
            'best_valid_loss': best_valid_loss,
        }, str(run_root / save_name))

    report_each = 10
    valid_losses = []
    lr_reset_epoch = epoch
    # The context manager closes the log file on every exit path.
    with run_root.joinpath('train.log').open('at', encoding='utf8') as log:
        for epoch in range(epoch, n_epochs + 1):
            model.train()
            tq = tqdm.tqdm(total=(args.epoch_size or
                                  len(train_loader) * args.batch_size))
            tq.set_description(f'Epoch {epoch}, lr {lr}')
            losses = []
            tl = train_loader
            if args.epoch_size:
                tl = islice(tl, args.epoch_size // args.batch_size)
            try:
                for i, (inputs, targets) in enumerate(tl):
                    if use_cuda:
                        inputs, targets = inputs.cuda(), targets.cuda()
                    outputs = model(inputs)
                    loss = _reduce_loss(criterion(outputs, targets))
                    batch_size = inputs.size(0)
                    (batch_size * loss).backward()
                    # Gradients accumulate over ``args.step`` batches.
                    if (i + 1) % args.step == 0:
                        optimizer.step()
                        optimizer.zero_grad()
                        step += 1
                    tq.update(batch_size)
                    losses.append(loss.item())
                    mean_loss = np.mean(losses[-report_each:])
                    tq.set_postfix(loss=f'{mean_loss:.3f}')
                tq.close()
                save(epoch + 1, f'model-{epoch}.pt')
                valid_metrics = validation(model, criterion, valid_loader,
                                           use_cuda, args.model)
                write_event(log, step, epoch, **valid_metrics)
                valid_loss = valid_metrics['valid_loss']
                valid_losses.append(valid_loss)
                if valid_loss < best_valid_loss:
                    best_valid_loss = valid_loss
                    save(epoch + 1, 'best-model.pt')
                elif (patience and epoch - lr_reset_epoch > patience and
                      min(valid_losses[-patience:]) > best_valid_loss):
                    # "patience" epochs without improvement
                    lr_changes += 1
                    if lr_changes > max_lr_changes:
                        break
                    lr /= 5
                    print(f'lr updated to {lr}')
                    lr_reset_epoch = epoch
                    optimizer = init_optimizer(args.optimizer, params, lr)
            except KeyboardInterrupt:
                tq.close()
                print('Ctrl+C, saving snapshot')
                save(epoch, 'model-interrupted.pt')
                print('done.')
                return False
    return True
def annotate_video(movie_file_path: str, dataset_path: str, output_path: str, model: nn.Module, device,
                   max_frame: int = 100000, tracker_max_age: int = 10,
                   plotter: utils.plotter_utils.VisdomPlotter = None, name: str = '',
                   compute_track_mean: bool = False):
    """Cluster the regions of interest of a video and write annotated videos.

    Bounding boxes are read from ``bbx.txt`` in ``dataset_path``; when that
    list is empty they are computed with ``get_bounding_boxes``. Features are
    extracted from the cropped regions with ``model``, clustered with k-means,
    spectral and hierarchical clustering, and one video per method is written
    through ``write_video``. If ``bbx_gt.txt`` exists in ``dataset_path`` a
    ground-truth video is written as well.

    Args:
        movie_file_path: path of the input video.
        dataset_path: directory holding ``bbx.txt`` and optionally ``bbx_gt.txt``.
        output_path: destination forwarded to ``write_video``.
        model: feature extractor; moved to ``device`` before use.
        device: device used for feature extraction.
        max_frame: frame limit forwarded to the video helpers.
        tracker_max_age: forwarded to ``get_bounding_boxes``.
        plotter: optional plotter; when ``None`` no scatter plots are produced.
        name: prefix for the generated video and plot names.
        compute_track_mean: also cluster the per-track mean features.
    """
    filename = os.path.join(dataset_path, 'bbx.txt')
    print('Getting annotations from {}'.format(filename))
    bbx_list = utils.read_file_to_list(filename)
    if bbx_list:
        bounding_boxes_list = bbx_list
    else:
        bounding_boxes_list = get_bounding_boxes(movie_file_path, max_frame=max_frame,
                                                 tracker_max_age=tracker_max_age)

    print('Extracting ROI of the video.')
    cropped_image_list = get_cropped_images(movie_file_path, bounding_boxes_list, max_frame=max_frame)
    track_dict = get_track_dict(bounding_boxes_list)
    frame_dict = get_frame_dict(bounding_boxes_list)
    bbx_dict = get_bbx_dict(bounding_boxes_list)

    # Data transform
    data_transform = transforms.Compose([
        transforms.ToTensor()
    ])
    dataset = NumpyDataset(cropped_image_list, transform=data_transform)
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=2, batch_size=100)

    print('Extracting features.')
    model = model.to(device)
    features = ml_utils.extract_features(dataloader, model, device)

    cluster_techniques_list = ['kmeans', 'spectral', 'hac']
    tsne_features, tsne_chosen_samples = projection_utils.tsne_projection(features)
    pca_features, pca_chosen_samples = projection_utils.pca_projection(features)

    def plot_projections(plot_name, labels):
        # ``plotter`` defaults to None, so plotting must be skippable.
        if plotter is None:
            return
        plotter.scatter_plot(plot_name + '_tsne', tsne_features, labels[tsne_chosen_samples])
        plotter.scatter_plot(plot_name + '_pca', pca_features, labels[pca_chosen_samples])

    # Frame level clustering
    print('Performing frame level clustering.')
    for cluster_method in cluster_techniques_list:
        cluster_name = '{}_frame_level_{}'.format(name, cluster_method)
        predictions, _ = clustering.cluster_techniques(features, cluster_method, max_clusters=10)
        write_video(movie_file_path, output_path, predictions, frame_dict,
                    name=cluster_name, max_frame=max_frame)
        plot_projections(cluster_name, predictions)

    # Add ground truth if it exists
    gt_file_path = os.path.join(dataset_path, 'bbx_gt.txt')
    if os.path.isfile(gt_file_path):
        print('Creating ground truth video and plots.')
        bbx_to_gt_list = utils.read_file_to_list(gt_file_path)
        bbx_to_gt_dict = utils.list_to_dict(bbx_to_gt_list)
        # Map every ground-truth identity to a dense index, in order of
        # first appearance.
        gt_to_idx_dict = {}
        groundtruth = []
        for bbx in bounding_boxes_list:
            gt = bbx_to_gt_dict[bbx[2]]
            if gt not in gt_to_idx_dict:
                gt_to_idx_dict[gt] = len(gt_to_idx_dict)
            groundtruth.append(gt_to_idx_dict[gt])
        groundtruth = np.array(groundtruth)
        gt_name = '{}_gt'.format(name)
        write_video(movie_file_path, output_path, groundtruth, frame_dict,
                    name=gt_name, max_frame=max_frame)
        plot_projections(gt_name, groundtruth)

    # Track level clustering
    if compute_track_mean:
        print('Performing track level clustering.')
        mean_features = []
        track_to_idx_dict = {}
        for idx, (track_idx, members) in enumerate(track_dict.items()):
            mean_features.append(np.mean(features[members], axis=0))
            track_to_idx_dict[track_idx] = idx
        mean_features = np.asarray(mean_features)

        for cluster_method in cluster_techniques_list:
            cluster_name = '{}_track_level_{}'.format(name, cluster_method)
            mean_predictions, _ = clustering.cluster_techniques(mean_features, cluster_method,
                                                                max_clusters=10)
            # Every bounding box inherits the cluster of the track it belongs
            # to; the track id is the first element of its bbx_dict entry.
            predictions = np.array([
                mean_predictions[track_to_idx_dict[entry[0]]]
                for entry in bbx_dict.values()
            ])
            write_video(movie_file_path, output_path, predictions, frame_dict,
                        name=cluster_name, max_frame=max_frame)
            plot_projections(cluster_name, predictions)
def efficientnet_init_weights(model: nn.Module, init_fn=None):
    """Apply a weight initializer to every sub-module of ``model``.

    ``init_fn`` is called as ``init_fn(module, name)`` for each entry of
    ``model.named_modules()``; when it is falsy ``_init_weight_goog`` is used.
    """
    if not init_fn:
        init_fn = _init_weight_goog
    for module_name, module in model.named_modules():
        init_fn(module, module_name)
def count_parameters(model: nn.Module) -> int:
    """Return the number of scalar elements in the trainable parameters.

    Parameters whose ``requires_grad`` is false are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def _to_ntuple(value: Any, n: int) -> tuple:
    """Repeat an int hyper-parameter ``n`` times; turn any other sequence into a tuple."""
    if isinstance(value, int):
        return (value,) * n
    return tuple(value)


def _layer_flops(layer: nn.Module, layer_args: List[Any], y: Any) -> int:
    """
    Computes the number of FLOPs required for a single layer.

    For common layers, such as Conv1d, the flop compute is implemented in this
    centralized place. Any other layer is supported if it defines a ``flops``
    method, which is then called with the layer's arguments:

        class MyModule(nn.Module):
            def flops(self, x):
                ...

    Args:
        layer: the layer to profile; its type is read from ``repr(layer)``.
        layer_args: positional inputs of the layer; the first one is used as
            the input tensor ``x``.
        y: the layer's output; only used for the debug log message.

    Returns:
        The FLOP count as an int.

    Raises:
        ClassyProfilerNotImplementedError: if the layer type is unknown and
            has no ``flops`` method, or if an adaptive pooling layer would
            have to upsample.
    """
    x = layer_args[0]
    # The layer type is whatever precedes the first "(" in repr(layer).
    typestr = layer.__repr__()
    layer_type = typestr[: typestr.find("(")].strip()
    batchsize_per_replica = get_batchsize_per_replica(x)

    flops = None
    # 1D convolution:
    if layer_type in ["Conv1d"]:
        # x shape is N x C x W
        out_w = int(
            (x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0])
            / layer.stride[0]
            + 1
        )
        flops = (
            batchsize_per_replica
            * layer.in_channels
            * layer.out_channels
            * layer.kernel_size[0]
            * out_w
            / layer.groups
        )

    # 2D convolution:
    elif layer_type in ["Conv2d"]:
        out_h = int(
            (x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0])
            / layer.stride[0]
            + 1
        )
        out_w = int(
            (x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1])
            / layer.stride[1]
            + 1
        )
        flops = (
            batchsize_per_replica
            * layer.in_channels
            * layer.out_channels
            * layer.kernel_size[0]
            * layer.kernel_size[1]
            * out_h
            * out_w
            / layer.groups
        )

    # learned group convolution:
    elif layer_type in ["LearnedGroupConv"]:
        conv = layer.conv
        out_h = int(
            (x.size()[2] + 2 * conv.padding[0] - conv.kernel_size[0]) / conv.stride[0]
            + 1
        )
        out_w = int(
            (x.size()[3] + 2 * conv.padding[1] - conv.kernel_size[1]) / conv.stride[1]
            + 1
        )
        # The recursive calls need the full (layer, layer_args, y) signature;
        # x stands in for the output, which is only used for logging.
        count1 = _layer_flops(layer.relu, [x], x) + _layer_flops(layer.norm, [x], x)
        count2 = (
            batchsize_per_replica
            * conv.in_channels
            * conv.out_channels
            * conv.kernel_size[0]
            * conv.kernel_size[1]
            * out_h
            * out_w
            / layer.condense_factor
        )
        flops = count1 + count2

    # non-linearities:
    elif layer_type in ["ReLU", "ReLU6", "Tanh", "Sigmoid", "Softmax", "SiLU"]:
        flops = x.numel()

    # 2D pooling layers:
    elif layer_type in ["AvgPool2d", "MaxPool2d"]:
        in_h = x.size()[2]
        in_w = x.size()[3]
        # Normalize into locals: the hyper-parameters may be ints or tuples,
        # and the profiled layer must not be modified.
        kernel_size = _to_ntuple(layer.kernel_size, 2)
        padding = _to_ntuple(layer.padding, 2)
        stride = _to_ntuple(layer.stride, 2)
        kernel_ops = kernel_size[0] * kernel_size[1]
        out_h = 1 + int((in_h + 2 * padding[0] - kernel_size[0]) / stride[0])
        out_w = 1 + int((in_w + 2 * padding[1] - kernel_size[1]) / stride[1])
        flops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops

    # adaptive avg pool2d
    # This is approximate and works only for downsampling without padding
    # based on aten/src/ATen/native/AdaptiveAveragePooling.cpp
    elif layer_type in ["AdaptiveAvgPool2d"]:
        in_h = x.size()[2]
        in_w = x.size()[3]
        if isinstance(layer.output_size, int):
            out_h, out_w = layer.output_size, layer.output_size
        elif len(layer.output_size) == 1:
            out_h, out_w = layer.output_size[0], layer.output_size[0]
        else:
            out_h, out_w = layer.output_size
        if out_h > in_h or out_w > in_w:
            raise ClassyProfilerNotImplementedError(layer)
        batchsize_per_replica = x.size()[0]
        num_channels = x.size()[1]
        kh = in_h - out_h + 1
        kw = in_w - out_w + 1
        kernel_ops = kh * kw
        flops = batchsize_per_replica * num_channels * out_h * out_w * kernel_ops

    # linear layer:
    elif layer_type in ["Linear"]:
        weight_ops = layer.weight.numel()
        bias_ops = layer.bias.numel() if layer.bias is not None else 0
        # Inputs with more than 2 dims apply the layer once per leading index.
        flops = ((x.numel() / x.size(-1)) if x.ndim > 2 else x.size(0)) * (
            weight_ops + bias_ops
        )

    # batch normalization / layer normalization:
    elif layer_type in [
        "BatchNorm1d",
        "BatchNorm2d",
        "BatchNorm3d",
        "SyncBatchNorm",
        "LayerNorm",
    ]:
        flops = 2 * x.numel()

    # 3D convolution
    elif layer_type in ["Conv3d"]:
        out_t = int(
            (x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0])
            // layer.stride[0]
            + 1
        )
        out_h = int(
            (x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1])
            // layer.stride[1]
            + 1
        )
        out_w = int(
            (x.size()[4] + 2 * layer.padding[2] - layer.kernel_size[2])
            // layer.stride[2]
            + 1
        )
        flops = (
            batchsize_per_replica
            * layer.in_channels
            * layer.out_channels
            * layer.kernel_size[0]
            * layer.kernel_size[1]
            * layer.kernel_size[2]
            * out_t
            * out_h
            * out_w
            / layer.groups
        )

    # 3D pooling layers
    elif layer_type in ["AvgPool3d", "MaxPool3d"]:
        in_t = x.size()[2]
        in_h = x.size()[3]
        in_w = x.size()[4]
        # Normalized copies; the layer's own attributes are left untouched.
        kernel_size = _to_ntuple(layer.kernel_size, 3)
        padding = _to_ntuple(layer.padding, 3)
        stride = _to_ntuple(layer.stride, 3)
        kernel_ops = kernel_size[0] * kernel_size[1] * kernel_size[2]
        out_t = 1 + int((in_t + 2 * padding[0] - kernel_size[0]) / stride[0])
        out_h = 1 + int((in_h + 2 * padding[1] - kernel_size[1]) / stride[1])
        out_w = 1 + int((in_w + 2 * padding[2] - kernel_size[2]) / stride[2])
        flops = batchsize_per_replica * x.size()[1] * out_t * out_h * out_w * kernel_ops

    # adaptive avg pool3d
    # This is approximate and works only for downsampling without padding
    # based on aten/src/ATen/native/AdaptiveAveragePooling3d.cpp
    elif layer_type in ["AdaptiveAvgPool3d"]:
        in_t = x.size()[2]
        in_h = x.size()[3]
        in_w = x.size()[4]
        # An int output size applies to all three dimensions.
        out_t, out_h, out_w = _to_ntuple(layer.output_size, 3)
        if out_t > in_t or out_h > in_h or out_w > in_w:
            raise ClassyProfilerNotImplementedError(layer)
        batchsize_per_replica = x.size()[0]
        num_channels = x.size()[1]
        kt = in_t - out_t + 1
        kh = in_h - out_h + 1
        kw = in_w - out_w + 1
        kernel_ops = kt * kh * kw
        flops = (
            batchsize_per_replica * num_channels * out_t * out_w * out_h * kernel_ops
        )

    elif layer_type in ["Dropout", "Identity"]:
        flops = 0

    elif hasattr(layer, "flops"):
        # The module defines its own flops method, taking the same
        # positional inputs as the layer, e.g.
        #   def flops(self, x): ...
        #   def flops(self, x1, x2): ...
        flops = layer.flops(*layer_args)

    if flops is None:
        raise ClassyProfilerNotImplementedError(layer)

    message = [
        f"module type: {typestr}",
        f"input size: {get_shape(x)}",
        f"output size: {get_shape(y)}",
        f"params(M): {count_params(layer) / 1e6}",
        f"flops(M): {int(flops) / 1e6}",
    ]
    logging.debug("\t".join(message))
    return int(flops)
def test(self, net: nn.Module, clean_data: CSVDataset, triggered_data: CSVDataset,
         clean_test_triggered_labels_data: CSVDataset, progress_bar_disable: bool = False,
         torch_dataloader_kwargs: dict = None) -> dict:
    """
    Evaluate a trained network on the clean and (optionally) triggered data.

    :param net: the trained module to run the test data through
    :param clean_data: the clean Dataset
    :param triggered_data: the triggered Dataset; skipped when None
    :param clean_test_triggered_labels_data: clean samples of the labels that
        also have triggered examples; skipped when None
    :param progress_bar_disable: not used by this method
    :param torch_dataloader_kwargs: keyword arguments merged into the
        DataLoader settings of the clean-data loader only
    :return: a dictionary with the accuracy and sample count of every
        evaluated dataset
    """
    def _accuracy_of(loader):
        # Only the accuracy and the total sample count are reported.
        accuracy, n_total, _, _ = _eval_acc(loader, net, self.device,
                                            self.soft_to_hard_fn,
                                            self.soft_to_hard_fn_kwargs)
        return accuracy, n_total

    results = {}
    net.eval()

    use_pinned_memory = self.device.type != 'cpu'

    # drop_last=True is from: https://stackoverflow.com/questions/56576716
    clean_loader_kwargs = dict(batch_size=1, pin_memory=use_pinned_memory,
                               drop_last=True, shuffle=False)
    if torch_dataloader_kwargs:
        clean_loader_kwargs.update(torch_dataloader_kwargs)
    logger.info('DataLoader[Test] kwargs=' + str(torch_dataloader_kwargs))

    # Accuracy on clean data only, for all labels.
    clean_acc, clean_total = _accuracy_of(DataLoader(clean_data, **clean_loader_kwargs))
    results['clean_accuracy'] = clean_acc
    results['clean_n_total'] = clean_total
    logger.info("Accuracy on clean test data: %0.02f" %
                (results['clean_accuracy'], ))

    if triggered_data is not None:
        # Accuracy on triggered data only, for all labels. This loader always
        # uses batch_size=1 and ignores the caller-supplied kwargs.
        trig_acc, trig_total = _accuracy_of(
            DataLoader(triggered_data, batch_size=1, pin_memory=use_pinned_memory))
        results['triggered_accuracy'] = trig_acc
        results['triggered_n_total'] = trig_total
        logger.info("Accuracy on triggered test data: %0.02f for n=%s" %
                    (results['triggered_accuracy'], str(trig_total)))

    if clean_test_triggered_labels_data is not None:
        # Accuracy on clean samples of the labels that have triggered
        # counterparts. E.g. if triggers exist only for MNIST labels 4 and 5,
        # this is the trigger-free data carrying labels 4 and 5.
        subset_acc, subset_total = _accuracy_of(
            DataLoader(clean_test_triggered_labels_data, batch_size=1,
                       pin_memory=use_pinned_memory))
        results['clean_test_triggered_label_accuracy'] = subset_acc
        results['clean_test_triggered_label_n_total'] = subset_total
        logger.info("Accuracy on clean-data-triggered-labels: %0.02f for n=%s" %
                    (results['clean_test_triggered_label_accuracy'], str(subset_total)))

    return results
def summary(model: nn.Module,
            input_data: INPUT_DATA_TYPE = None,
            *args: Any,
            batch_dim: Optional[int] = 0,
            branching: bool = True,
            col_names: Optional[Sequence[str]] = None,
            col_width: int = 25,
            depth: int = 3,
            device: Optional[torch.device] = None,
            dtypes: Optional[List[torch.dtype]] = None,
            verbose: int = 1,
            print_step: bool = True,
            print_func=print,
            **kwargs: Any) -> ModelStatistics:
    """
    Build (and by default print) a layer-by-layer summary of a PyTorch model.

    The summary can cover layer names, input/output shapes, kernel shapes,
    parameter counts and the number of operations (Mult-Adds).

    Args:
        model (nn.Module): the model to summarize.
        input_data (Sequence of Sizes or Tensors): either example input
            tensor(s) for the model (dtypes inferred from the input), or the
            input shape(s) as a List/Tuple/torch.Size without the batch size
            (dtypes must match the model input, FloatTensors by default).
            When omitted no forward pass is run and only layer names and
            parameter counts are available.
        batch_dim (int): batch dimension of the input data. If None, the
            input data is assumed to already contain the batch dimension.
            WARNING: a future version of torch-summary will default to None.
            Default: 0
        branching (bool): use the branching layout for the printed output.
            Default: True
        col_names (Sequence[str]): columns to show. Supported: ("input_size",
            "output_size", "num_params", "kernel_size", "mult_adds").
            Default: ("num_params",) without input_data, otherwise
            DEFAULT_COLUMN_NAMES.
        col_width (int): width of each column. Default: 25
        depth (int): number of nested layers to traverse (e.g. Sequentials).
            Default: 3
        device (torch.Device): device for the model and the input data; when
            None, CUDA is used if available, else the CPU. Default: None
        dtypes (List[torch.dtype]): dtypes of the inputs when several input
            sizes are given. Default: None
        verbose (int): 0 (quiet): no output; 1 (default): print the summary;
            2 (verbose): show weight and bias layers in full detail.
        print_step (bool): forwarded to ModelStatistics. Default: True
        print_func: callable that receives the results (and the failure
            message when the forward pass raises). Default: print
        *args, **kwargs: further arguments for `model.forward`.

    Return:
        ModelStatistics object
            See torchsummary/model_statistics.py for more information.
    """
    has_input = input_data is not None

    if col_names is None:
        col_names = DEFAULT_COLUMN_NAMES if has_input else ("num_params", )

    validate_user_params(input_data, col_names, verbose)

    input_size = []  # type: CORRECTED_INPUT_SIZE_TYPE
    summary_list = []  # type: List[LayerInfo]
    # Hook handles are only collected when a forward pass will be run.
    hooks = [] if has_input else None  # type: Optional[List[RemovableHandle]]
    idx = {}  # type: Dict[int, int]
    apply_hooks(model, model, batch_dim, depth, summary_list, idx, hooks)

    if has_input:
        if device is None:
            device_name = "cuda" if torch.cuda.is_available() else "cpu"
            device = torch.device(device_name)

        x, input_size = process_input_data(input_data, batch_dim, device, dtypes)
        args = set_device(args, device)
        kwargs = set_device(kwargs, device)

        try:
            with torch.no_grad():
                _ = model.to(device)(*x, *args, **kwargs)  # type: ignore
        except Exception:
            # Report how far the forward pass got before re-raising.
            executed_layers = [info for info in summary_list if info.executed]
            print_func(
                "Failed to run torchsummary, executed layers up to: {}".format(
                    executed_layers))
            raise
        finally:
            # Always detach the hooks, even when the forward pass failed.
            for handle in (hooks or []):
                handle.remove()

    formatting = FormattingOptions(branching, depth, verbose, col_names, col_width)
    formatting.set_layer_name_width(summary_list)
    results = ModelStatistics(summary_list, input_size, formatting, print_step)
    if verbose > Verbosity.QUIET.value:
        print_func(results)
    return results
def evaluate_classifier(model: nn.Module, test_dl: DataLoader, loss_func: Callable,
                        classes: List[int] = [0, 1]) -> None:
    """Evaluate a PyTorch graph model for classification and print the metrics.

    Prints the mean test loss, the accuracy and the classification report.
    With exactly two classes the ROC-AUC and its bootstrapped variant are
    printed as well; otherwise the per-class and micro-averaged AUC scores.

    Args:
        model: classifier called as ``model(bg)`` on each batched graph.
        test_dl: iterable of ``(batched_graph, labels)`` pairs.
        loss_func: callable ``(logits, labels) -> loss``.
        classes: class labels, forwarded to ``label_binarize`` in the
            multi-class case. The default list is never mutated here.
    """
    y_pred = []
    y_true = []
    y_prob = []
    prob_arr = []
    test_loss = 0

    # Switching to eval mode is loop-invariant, so do it once up front.
    model.eval()
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for bg, labels in test_dl:
            bg.set_e_initializer(dgl.init.zero_initializer)
            bg.set_n_initializer(dgl.init.zero_initializer)
            logit = model(bg)
            # .cpu() keeps the conversion valid when the logits are not on the CPU.
            probs = torch.softmax(logit, 1).detach().cpu().numpy()
            prob_arr.append(probs)
            predictions = np.argmax(probs, 1)
            y_pred += list(predictions)
            y_true += list(labels)
            y_prob += list(probs[:, 1])
            loss = loss_func(logit, labels)
            test_loss += loss.detach().item()

    print('test_loss: ', test_loss / len(test_dl))
    print('accuracy: ', accuracy_score(y_true, y_pred))
    print('classification report: \n', classification_report(y_true, y_pred))

    if len(classes) == 2:
        print('roc-auc: ', roc_auc_score(y_true, y_prob))
        print('bootstrapped roc-auc: ', bs_roc_auc_score(y_true, y_prob))
    else:
        y_test = label_binarize(y_true, classes=classes)
        n_classes = y_test.shape[1]
        prob_arr = np.concatenate(prob_arr, axis=0)

        # Compute ROC curve and ROC area for each class
        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        bs_roc_auc = dict()
        for i in range(n_classes):
            fpr[i], tpr[i], _ = roc_curve(y_test[:, i], prob_arr[:, i])
            roc_auc[i] = auc(fpr[i], tpr[i])
            bs_roc_auc[i] = bs_roc_auc_score(y_test[:, i], prob_arr[:, i])

        fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), prob_arr.ravel())
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
        bs_roc_auc['micro'] = bs_roc_auc_score(y_test.ravel(), prob_arr.ravel())

        print("micro auc score and score for each class: ")
        for key in roc_auc:
            print(key, ' : ', roc_auc[key])
        print("bootstrapped micro auc score and score for each class: ")
        for key in bs_roc_auc:
            print(key, ' : ', bs_roc_auc[key])
def free(self, module: nn.Module):
    """Enable gradient tracking for every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad_(True)
def _build_opt(self, model: nn.Module) -> optim.Optimizer:
    """Instantiate the configured optimizer for the parameters of ``model``.

    When ``self.opt`` is a string it is first resolved through
    ``self._interp_opt`` and the result is stored back on ``self.opt``.
    """
    optimizer_factory = self.opt
    if isinstance(optimizer_factory, str):
        # Backwards compatibility with pre-v0.3.1 saves
        optimizer_factory = self._interp_opt(optimizer_factory)
        self.opt = optimizer_factory
    return optimizer_factory(model.parameters(), **self.opt_args)
def frozen(self, module: nn.Module):
    """Disable gradient tracking for every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad_(False)
def train_step(
    model: nn.Module,
    train_loader,
    criterion,
    device: str,
    optimizer,
    scheduler=None,
    num_batches: int = None,
    log_interval: int = 100,
    scaler=None,
):
    """
    Train for one pass over ``train_loader`` and return the averaged losses.

    Args:
        model : PyTorch Detr Model.
        train_loader : Train loader yielding ``(inputs, targets)`` batches.
        criterion : Detr loss; returns a dict of losses and exposes
            ``weight_dict`` with the weight of each loss term.
        device : "cuda" or "cpu"
        optimizer : Torch optimizer to train.
        scheduler : (optional) Learning rate scheduler, stepped once per batch.
        num_batches : (optional) Stop after this many batches.
        log_interval : (optional) Default 100. Print the batch time every
            ``log_interval`` batches and on the last batch.
        scaler : (optional) Pass torch.cuda.amp.GradScaler() for fp16
            precision training.

    Returns:
        OrderedDict with the running averages ``total_loss``, ``bbox_loss``,
        ``giou_loss`` and ``labels_loss``.
    """
    model = model.to(device)
    criterion = criterion.to(device)

    start_train_step = time.time()

    model.train()
    last_idx = len(train_loader) - 1
    batch_time_m = utils.AverageMeter()
    criterion.train()
    cnt = 0
    batch_start = time.time()
    metrics = OrderedDict()

    total_loss = utils.AverageMeter()
    bbox_loss = utils.AverageMeter()
    giou_loss = utils.AverageMeter()
    labels_loss = utils.AverageMeter()

    def forward_loss(batch_images, batch_targets):
        # Weighted sum of the loss terms that have an entry in weight_dict.
        predictions = model(batch_images)
        losses = criterion(predictions, batch_targets)
        weights = criterion.weight_dict
        combined = sum(losses[key] * weights[key] for key in losses.keys() if key in weights)
        return losses, combined

    def fill_metrics():
        metrics["total_loss"] = total_loss.avg
        metrics["bbox_loss"] = bbox_loss.avg
        metrics["giou_loss"] = giou_loss.avg
        metrics["labels_loss"] = labels_loss.avg

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        last_batch = batch_idx == last_idx
        images = [image.to(device) for image in inputs]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        if scaler is None:
            loss_dict, loss = forward_loss(images, targets)
            loss.backward()
            optimizer.step()
        else:
            with amp.autocast():
                loss_dict, loss = forward_loss(images, targets)
            scaler.scale(loss).backward()
            # Step using scaler.step()
            scaler.step(optimizer)
            # Update for next iteration
            scaler.update()

        if scheduler is not None:
            scheduler.step()

        cnt += 1
        total_loss.update(loss.item())
        bbox_loss.update(loss_dict["loss_bbox"].item())
        giou_loss.update(loss_dict["loss_giou"].item())
        labels_loss.update(loss_dict["loss_ce"].item())

        batch_time_m.update(time.time() - batch_start)
        batch_start = time.time()

        if last_batch or batch_idx % log_interval == 0:
            # The log interval (or the final batch) has been reached.
            print(
                "Batch Train Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) "
                .format(batch_time=batch_time_m, ))

        if num_batches is not None and cnt >= num_batches:
            end_train_step = time.time()
            fill_metrics()
            print(f"Done till {num_batches} train batches")
            print(
                f"Time taken for Training step = {end_train_step - start_train_step} sec"
            )
            return metrics

    end_train_step = time.time()
    fill_metrics()
    print(
        f"Time taken for Training step = {end_train_step - start_train_step} sec"
    )
    return metrics
def initialize_model(model: nn.Module, cfg: dict, padding_idx: int) -> None:
    """
    Initialize the parameters of a model according to its configuration.

    All initializer settings live in the `model` section of the configuration
    file. For an example, see e.g. `https://github.com/joeynmt/joeynmt/
    blob/master/configs/iwslt_envi_xnmt.yaml#L47`

    `initializer` selects the main initializer: `xavier` (default),
    `uniform`, `normal` or `zeros`. With `uniform`, `init_weight` gives the
    range (-init_weight, init_weight); with `normal` it gives the standard
    deviation (mean 0).

    `embed_initializer` does the same for embeddings (default `normal` with
    `embed_init_weight = 0.01`), and `bias_initializer` for biases (default
    `zeros`).

    Set `init_rnn_orthogonal` to True for orthogonal initialization of the
    RNNs (default False). `lstm_forget_gate` sets the initial LSTM forget
    gate value (default `1`).

    :param model: model to initialize
    :param cfg: the model configuration
    :param padding_idx: row that is zeroed in every parameter whose name
        contains "embed"
    """
    # defaults: xavier, embeddings: normal 0.01, biases: zeros, no orthogonal
    gain = float(cfg.get("init_gain", 1.0))  # for xavier
    init = cfg.get("initializer", "xavier")
    init_weight = float(cfg.get("init_weight", 0.01))

    embed_init = cfg.get("embed_initializer", "normal")
    embed_init_weight = float(cfg.get("embed_init_weight", 0.01))
    embed_gain = float(cfg.get("embed_init_gain", 1.0))  # for xavier

    bias_init = cfg.get("bias_initializer", "zeros")
    bias_init_weight = float(cfg.get("bias_init_weight", 0.01))

    init_fn_ = _parse_init(init, init_weight, gain)
    embed_init_fn_ = _parse_init(embed_init, embed_init_weight, embed_gain)
    bias_init_fn_ = _parse_init(bias_init, bias_init_weight, gain)

    def _rnn_for(param_name):
        # "encoders" has to be tested first because "encoder" matches it too.
        if "encoders" in param_name:
            return next(iter(model.encoders.values())).rnn
        if "encoder" in param_name:
            return model.encoder.rnn
        if "decoder" in param_name:
            return model.decoder.rnn
        return None

    def _init_rnn(component, use_orthogonal, forget_gate):
        # Orthogonal init and LSTM forget gate for components owning an rnn.
        if not hasattr(component, "rnn"):
            return
        if use_orthogonal:
            orthogonal_rnn_init_(component.rnn)
        if isinstance(component.rnn, nn.LSTM):
            lstm_forget_gate_init_(component.rnn, forget_gate)

    with torch.no_grad():
        for name, p in model.named_parameters():
            if "embed" in name:
                embed_init_fn_(p)
                # zero out paddings; assumes all fields have same pad
                p.data[padding_idx].zero_()
                continue

            if "bias" in name:
                if bias_init != "xavier":
                    bias_init_fn_(p)
                continue

            if len(p.size()) <= 1:
                continue

            if init == "xavier" and "rnn" in name:
                # RNNs pack several matrices into one parameter, which would
                # distort plain xavier initialization.
                rnn = _rnn_for(name)
                if rnn is None:
                    n = 1
                elif isinstance(rnn, nn.LSTM):
                    n = 4
                else:
                    n = 3
                xavier_uniform_n_(p.data, gain=gain, n=n)
            else:
                init_fn_(p)

        orthogonal = cfg.get("init_rnn_orthogonal", False)
        lstm_forget_gate = cfg.get("lstm_forget_gate", 1.)

        # encoder rnn orthogonal initialization & LSTM forget gate
        if hasattr(model, "encoders"):
            encoders = list(model.encoders.values())
        else:
            encoders = [model.encoder]
        for encoder in encoders:
            _init_rnn(encoder, orthogonal, lstm_forget_gate)

        # decoder rnn orthogonal initialization & LSTM forget gate
        _init_rnn(model.decoder, orthogonal, lstm_forget_gate)
def save_checkpoint(model: nn.Module, buffer: SampleBuffer, save_dir: Path, tag):
    """Write the model's state dict and the sample buffer to ``save_dir / tag``."""
    payload = dict(model_state=model.state_dict(), sample_buffer=buffer)
    destination = save_dir / tag
    torch.save(payload, destination)
def module_to_array(
    module: Module,
    bounds: Optional[ParameterBounds] = None,
    exclude: Optional[Set[str]] = None,
) -> Tuple[np.ndarray, Dict[str, TorchAttr], Optional[np.ndarray]]:
    r"""Flatten a module's trainable parameters into one numpy array.

    Only parameters with `requires_grad` are extracted, since the result is
    meant for optimization.

    Args:
        module: A module with parameters. May specify parameter constraints in
            a `named_parameters_and_constraints` method.
        bounds: A ParameterBounds dictionary mapping parameter names to tuples
            of lower and upper bounds. These take precedence over bounds taken
            from the constraints registered with the module.
        exclude: Names of parameters to leave out of the extraction.

    Returns:
        3-element tuple containing

        - The parameter values as a numpy array.
        - An ordered dictionary with the name and tensor attributes of each
          extracted parameter.
        - A `2 x n_params` numpy array with lower and upper bounds if at
          least one bound is finite, and None otherwise.

    Example:
        >>> mll = ExactMarginalLogLikelihood(model.likelihood, model)
        >>> parameter_array, property_dict, bounds_out = module_to_array(mll)
    """
    def _to_host(bound):
        # Tensor bounds are moved to the CPU and detached before numpy use.
        return bound.cpu().detach() if torch.is_tensor(bound) else bound

    values = []
    lower_parts = []
    upper_parts = []
    property_dict = OrderedDict()
    if exclude is None:
        exclude = set()

    # Bounds taken from the module's non-enforced constraints (if any).
    merged_bounds = {}
    if hasattr(module, "named_parameters_and_constraints"):
        for param_name, _, constraint in module.named_parameters_and_constraints():
            if constraint is not None and not constraint.enforced:
                merged_bounds[param_name] = (
                    constraint.lower_bound,
                    constraint.upper_bound,
                )
    # User-supplied bounds overwrite the ones found on the module.
    if bounds is not None:
        merged_bounds.update(bounds)

    for p_name, t in module.named_parameters():
        if p_name in exclude or not t.requires_grad:
            continue
        property_dict[p_name] = TorchAttr(
            shape=t.shape, dtype=t.dtype, device=t.device
        )
        values.append(t.detach().view(-1).cpu().double().clone().numpy())
        if not merged_bounds:
            continue
        lo, hi = merged_bounds.get(p_name, (-inf, inf))
        lo = _to_host(lo)
        hi = _to_host(hi)
        # A manually passed bound may be None, meaning "unbounded".
        lower_parts.append(np.full(t.nelement(), -inf if lo is None else lo))
        upper_parts.append(np.full(t.nelement(), inf if hi is None else hi))

    x_out = np.concatenate(values)

    bounds_out = None
    if merged_bounds:
        every_bound_infinite = all(
            np.isinf(part).all()
            for group in (lower_parts, upper_parts)
            for part in group
        )
        if not every_bound_infinite:
            bounds_out = np.stack(
                [np.concatenate(lower_parts), np.concatenate(upper_parts)]
            )
    return x_out, property_dict, bounds_out
def _training_step(self, model: nn.Module, inputs: Dict[str, torch.Tensor]) -> float:
    """Run one training step by calling ``model.train_step`` with ``inputs`` as keyword arguments."""
    return model.train_step(**inputs)
def __getattr__(self, attr):
    """Resolve ``attr`` via ``_get_method`` when ``_has_method`` reports it, else defer to ``Module``."""
    if not self._has_method(attr):
        return Module.__getattr__(self, attr)
    return self._get_method(attr)
def soft_update(self, source: nn.Module, target: nn.Module, tau: float) -> None:
    """Blend the parameters of ``source`` into ``target`` in place.

    Every target parameter becomes ``target * (1 - tau) + source * tau``.
    Parameters are paired positionally, in ``parameters()`` order.
    """
    keep = 1.0 - tau
    for src, dst in zip(source.parameters(), target.parameters()):
        blended = dst.data * keep + src.data * tau
        dst.data.copy_(blended)
def initialize_model(model: nn.Module, cfg: dict, txt_padding_idx: int) -> None:
    """
    Initialize the parameters of ``model`` in place, driven by ``cfg``.

    All initializer settings live in the `model` section of the configuration
    file. For an example, see e.g. `https://github.com/joeynmt/joeynmt/
    blob/master/configs/iwslt_envi_xnmt.yaml#L47`

    Recognised keys (with their defaults):

    - `initializer` (`xavier`): main initializer for parameters with more
      than one dimension. One of `xavier`, `uniform`, `normal`, `zeros`.
    - `init_weight` (0.01): range (-init_weight, init_weight) for `uniform`,
      standard deviation (mean 0) for `normal`.
    - `init_gain` (1.0): gain used by `xavier`.
    - `embed_initializer` (`normal`), `embed_init_weight` (0.01),
      `embed_init_gain` (1.0): the same settings for the word embedding.
    - `bias_initializer` (`zeros`), `bias_init_weight` (0.01): the same
      settings for biases.
    - `init_rnn_orthogonal` (False): apply orthogonal initialization to the
      encoder/decoder RNNs.
    - `lstm_forget_gate` (1.0): value handed to the LSTM forget gate
      initializer.

    After initialization the embedding row at `txt_padding_idx` is zeroed.

    :param model: model to initialize
    :param cfg: the model configuration
    :param txt_padding_idx: index of spoken language text padding token
    :raises ValueError: if an initializer name is not recognised
    """
    # defaults: xavier, embeddings: normal 0.01, biases: zeros, no orthogonal
    gain = float(cfg.get("init_gain", 1.0))  # for xavier
    init = cfg.get("initializer", "xavier")
    init_weight = float(cfg.get("init_weight", 0.01))

    embed_init = cfg.get("embed_initializer", "normal")
    embed_init_weight = float(cfg.get("embed_init_weight", 0.01))
    embed_gain = float(cfg.get("embed_init_gain", 1.0))  # for xavier

    bias_init = cfg.get("bias_initializer", "zeros")
    bias_init_weight = float(cfg.get("bias_init_weight", 0.01))

    # pylint: disable=unnecessary-lambda
    def _parse_init(s, scale, _gain):
        """Map an initializer name to a callable that initializes a tensor."""
        scale = float(scale)
        assert scale > 0.0, "incorrect init_weight"
        initializers = {
            "xavier": lambda p: nn.init.xavier_uniform_(p, gain=_gain),
            "uniform": lambda p: nn.init.uniform_(p, a=-scale, b=scale),
            "normal": lambda p: nn.init.normal_(p, mean=0.0, std=scale),
            "zeros": lambda p: nn.init.zeros_(p),
        }
        chosen = initializers.get(s.lower())
        if chosen is None:
            raise ValueError("unknown initializer")
        return chosen

    init_fn_ = _parse_init(init, init_weight, gain)
    embed_init_fn_ = _parse_init(embed_init, embed_init_weight, embed_gain)
    bias_init_fn_ = _parse_init(bias_init, bias_init_weight, gain)

    with torch.no_grad():
        for name, p in model.named_parameters():
            # Text embedding parameters: only the lookup table is initialized.
            if "txt_embed" in name:
                if "lut" in name:
                    embed_init_fn_(p)
                continue

            if "bias" in name:
                bias_init_fn_(p)
                continue

            # Remaining one-dimensional parameters are left untouched.
            if p.dim() < 2:
                continue

            # RNNs combine multiple matrices in one, which messes up
            # xavier initialization.
            # NOTE(review): this comparison is case-sensitive although
            # _parse_init lowercases the name - confirm that is intended.
            if init == "xavier" and "rnn" in name:
                if "encoder" in name:
                    n = 4 if isinstance(model.encoder.rnn, nn.LSTM) else 3
                elif "decoder" in name:
                    n = 4 if isinstance(model.decoder.rnn, nn.LSTM) else 3
                else:
                    n = 1
                xavier_uniform_n_(p.data, gain=gain, n=n)
            else:
                init_fn_(p)

        # zero out paddings
        if model.txt_embed is not None:
            model.txt_embed.lut.weight.data[txt_padding_idx].zero_()

        orthogonal = cfg.get("init_rnn_orthogonal", False)
        lstm_forget_gate = cfg.get("lstm_forget_gate", 1.0)

        # encoder, then decoder: rnn orthogonal initialization & LSTM forget gate
        for part_name in ("encoder", "decoder"):
            part = getattr(model, part_name)
            if not hasattr(part, "rnn"):
                continue
            if orthogonal:
                orthogonal_rnn_init_(part.rnn)
            if isinstance(part.rnn, nn.LSTM):
                lstm_forget_gate_init_(part.rnn, lstm_forget_gate)
def __dir__(self):
    """Return the sorted attribute names, including registered method names.

    The list from ``Module.__dir__`` is concatenated with
    ``self._method_names()`` and sorted.
    """
    names = Module.__dir__(self) + self._method_names()
    names.sort()
    return names
def adjust_weight_to_zero(model: nn.Module, thresh):
    """Apply ``mask_value(param, thresh)`` to every weight parameter of ``model``.

    A parameter is selected when its name (as reported by
    ``model.named_parameters()``) contains the substring ``'weight'``.
    Presumably ``mask_value`` zeroes entries below ``thresh`` - confirm
    against its definition.
    """
    weight_params = (
        param for name, param in model.named_parameters() if 'weight' in name
    )
    for param in weight_params:
        mask_value(param, thresh)
def train_epoch(self, model: nn.Module, train_loader: DataLoader,
                val_clean_loader: DataLoader, val_triggered_loader: DataLoader,
                epoch_num: int, progress_bar_disable: bool = False, use_amp: bool = False):
    """
    Runs one epoch of training on the specified model, then evaluates it on
    the clean and triggered validation sets and steps the LR scheduler.

    :param model: the model to train for one epoch
    :param train_loader: a DataLoader object pointing to the training dataset
    :param val_clean_loader: a DataLoader object pointing to the validation dataset that is clean
    :param val_triggered_loader: a DataLoader object pointing to the validation dataset that is triggered
    :param epoch_num: the epoch number that is being trained
    :param progress_bar_disable: if True, disables the progress bar
    :param use_amp: if True, runs the forward pass under torch.cuda.amp.autocast and
        scales the loss / optimizer step with a torch.cuda.amp.GradScaler
    :return: a tuple (train_stats, validation_stats) holding the EpochTrainStatistics and
        EpochValidationStatistics computed for this epoch
    :raises ValueError: if the configured gradient clipping type or LR scheduler call
        argument is not recognised
    """
    pid = os.getpid()
    train_dataset_len = len(train_loader.dataset)
    loop = tqdm(train_loader, disable=progress_bar_disable)

    scaler = None
    if use_amp:
        scaler = torch.cuda.amp.GradScaler()

    train_n_correct, train_n_total = None, None
    sum_batchmean_train_loss = 0
    running_train_acc = 0
    num_batches = len(train_loader)
    model.train()
    for batch_idx, (x, y_truth) in enumerate(loop):
        x = x.to(self.device)
        y_truth = y_truth.to(self.device)

        # zero out previous gradient computations
        self.optimizer.zero_grad()

        # get predictions based on input & weights learned so far
        if use_amp:
            with torch.cuda.amp.autocast():
                y_hat = model(x)
                # compute metrics
                batch_train_loss = self._eval_loss_function(y_hat, y_truth)
        else:
            y_hat = model(x)
            # compute metrics
            batch_train_loss = self._eval_loss_function(y_hat, y_truth)

        sum_batchmean_train_loss += batch_train_loss.item()

        running_train_acc, train_n_total, train_n_correct = _running_eval_acc(
            y_hat, y_truth,
            n_total=train_n_total,
            n_correct=train_n_correct,
            soft_to_hard_fn=self.soft_to_hard_fn,
            soft_to_hard_fn_kwargs=self.soft_to_hard_fn_kwargs)

        if np.isnan(sum_batchmean_train_loss) or np.isnan(running_train_acc):
            _save_nandata(x, y_hat, y_truth, batch_train_loss, sum_batchmean_train_loss,
                          running_train_acc, train_n_total, train_n_correct, model)

        # compute gradient
        if use_amp:
            # Scales loss. Calls backward() on scaled loss to create scaled gradients.
            # Backward passes under autocast are not recommended.
            # Backward ops run in the same dtype autocast chose for corresponding forward ops.
            scaler.scale(batch_train_loss).backward()
        else:
            batch_train_loss.backward()

        # perform gradient clipping if configured
        if self.optimizer_cfg.training_cfg.clip_grad:
            if use_amp:
                # Unscales the gradients of optimizer's assigned params in-place
                scaler.unscale_(self.optimizer)

            if self.optimizer_cfg.training_cfg.clip_type == 'norm':
                # clip_grad_norm_ modifies gradients in place
                #  see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                torch_clip_grad.clip_grad_norm_(
                    model.parameters(),
                    self.optimizer_cfg.training_cfg.clip_val,
                    **self.optimizer_cfg.training_cfg.clip_kwargs)
            elif self.optimizer_cfg.training_cfg.clip_type == 'val':
                # clip_grad_val_ modifies gradients in place
                #  see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                torch_clip_grad.clip_grad_value_(
                    model.parameters(),
                    self.optimizer_cfg.training_cfg.clip_val)
            else:
                msg = "Unknown clipping type for gradient clipping!"
                logger.error(msg)
                raise ValueError(msg)

        if use_amp:
            # scaler.step() first unscales the gradients of the optimizer's assigned params.
            # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
            # otherwise, optimizer.step() is skipped.
            scaler.step(self.optimizer)
            # Updates the scale for next iteration.
            scaler.update()
        else:
            self.optimizer.step()

        loop.set_description('Epoch {}/{}'.format(epoch_num + 1, self.num_epochs))
        loop.set_postfix(avg_train_loss=batch_train_loss.item())

        # report batch statistics to TensorBoard (best effort: a reporting
        # failure must not abort training, but it should not be silent and
        # must not swallow KeyboardInterrupt/SystemExit either)
        if self.tb_writer:
            try:
                batch_num = int(epoch_num * num_batches + batch_idx)
                self.tb_writer.add_scalar(
                    self.optimizer_cfg.reporting_cfg.experiment_name + '-train_loss',
                    batch_train_loss.item(), global_step=batch_num)
                self.tb_writer.add_scalar(
                    self.optimizer_cfg.reporting_cfg.experiment_name + '-running_train_acc',
                    running_train_acc, global_step=batch_num)
            except Exception as err:
                logger.warning(
                    "Failed to report training batch statistics: {}".format(err))

        if batch_idx % self.num_batches_per_logmsg == 0:
            logger.info(
                '{}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tTrainLoss: {:.6f}\tTrainAcc: {:.6f}'
                .format(pid, epoch_num, batch_idx * len(x), train_dataset_len,
                        100. * batch_idx / num_batches, batch_train_loss.item(),
                        running_train_acc))

    train_stats = EpochTrainStatistics(
        running_train_acc, sum_batchmean_train_loss / float(num_batches))

    # if we have validation data, we compute on the validation dataset
    num_val_batches_clean = len(val_clean_loader)
    if num_val_batches_clean > 0:
        logger.info('Running Validation on Clean Data')
        running_val_clean_acc, _, _, val_clean_loss = \
            _eval_acc(val_clean_loader, model, self.device,
                      self.soft_to_hard_fn, self.soft_to_hard_fn_kwargs,
                      self._eval_loss_function)
    else:
        logger.info("No dataset computed for validation on clean dataset!")
        running_val_clean_acc = None
        val_clean_loss = None

    num_val_batches_triggered = len(val_triggered_loader)
    if num_val_batches_triggered > 0:
        logger.info('Running Validation on Triggered Data')
        running_val_triggered_acc, _, _, val_triggered_loss = \
            _eval_acc(val_triggered_loader, model, self.device,
                      self.soft_to_hard_fn, self.soft_to_hard_fn_kwargs,
                      self._eval_loss_function)
    else:
        logger.info(
            "No dataset computed for validation on triggered dataset!")
        running_val_triggered_acc = None
        val_triggered_loss = None

    validation_stats = EpochValidationStatistics(
        running_val_clean_acc, val_clean_loss,
        running_val_triggered_acc, val_triggered_loss)

    if num_val_batches_clean > 0:
        logger.info(
            '{}\tTrain Epoch: {} \tCleanValLoss: {:.6f}\tCleanValAcc: {:.6f}'
            .format(pid, epoch_num, val_clean_loss, running_val_clean_acc))
    if num_val_batches_triggered > 0:
        logger.info(
            '{}\tTrain Epoch: {} \tTriggeredValLoss: {:.6f}\tTriggeredValAcc: {:.6f}'
            .format(pid, epoch_num, val_triggered_loss, running_val_triggered_acc))

    # report validation statistics to TensorBoard (best effort, see above)
    if self.tb_writer:
        try:
            batch_num = int((epoch_num + 1) * num_batches)
            if num_val_batches_clean > 0:
                self.tb_writer.add_scalar(
                    self.optimizer_cfg.reporting_cfg.experiment_name + '-clean-val-loss',
                    val_clean_loss, global_step=batch_num)
                self.tb_writer.add_scalar(
                    self.optimizer_cfg.reporting_cfg.experiment_name + '-clean-val_acc',
                    running_val_clean_acc, global_step=batch_num)
            if num_val_batches_triggered > 0:
                self.tb_writer.add_scalar(
                    self.optimizer_cfg.reporting_cfg.experiment_name + '-triggered-val-loss',
                    val_triggered_loss, global_step=batch_num)
                self.tb_writer.add_scalar(
                    self.optimizer_cfg.reporting_cfg.experiment_name + '-triggered-val_acc',
                    running_val_triggered_acc, global_step=batch_num)
        except Exception as err:
            logger.warning(
                "Failed to report validation statistics: {}".format(err))

    # update the lr-scheduler if necessary
    if self.lr_scheduler is not None:
        if self.optimizer_cfg.training_cfg.lr_scheduler_call_arg is None:
            self.lr_scheduler.step()
        elif self.optimizer_cfg.training_cfg.lr_scheduler_call_arg.lower() == 'val_acc':
            val_acc = validation_stats.get_val_acc()
            if val_acc is not None:
                self.lr_scheduler.step(val_acc)
            else:
                msg = "val_clean_acc not defined b/c validation dataset is not defined! Ignoring LR step!"
                logger.warning(msg)
        elif self.optimizer_cfg.training_cfg.lr_scheduler_call_arg.lower() == 'val_loss':
            val_loss = validation_stats.get_val_loss()
            if val_loss is not None:
                self.lr_scheduler.step(val_loss)
            else:
                msg = "val_clean_loss not defined b/c validation dataset is not defined! Ignoring LR step!"
                logger.warning(msg)
        else:
            msg = "Unknown mode for calling lr_scheduler!"
            logger.error(msg)
            raise ValueError(msg)

    return train_stats, validation_stats
def tie_encoder_to_decoder_recursively(
    decoder_pointer: nn.Module,
    encoder_pointer: nn.Module,
    module_name: str,
    uninitialized_encoder_weights: List[str],
    depth=0,
):
    """Share the decoder's weights with the encoder, walking both module trees.

    A decoder module that has a ``weight`` attribute is treated as a leaf: the
    encoder module's ``weight`` (and ``bias``, when the decoder has one) is
    rebound to the decoder's tensor. Otherwise the child modules are matched
    by name and visited recursively. Encoder children that were never matched
    are appended, as ``module_name/child_name``, to
    ``uninitialized_encoder_weights``.

    :param decoder_pointer: current decoder module
    :param encoder_pointer: current encoder module
    :param module_name: path prefix used for reporting unmatched children
    :param uninitialized_encoder_weights: list extended in place with the
        paths of encoder children that have no decoder counterpart
    :param depth: current recursion depth
    :raises ValueError: if the recursion depth exceeds 500 on a name-matched child
    """
    assert isinstance(decoder_pointer, nn.Module) and isinstance(
        encoder_pointer, nn.Module
    ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"

    # Leaf module: tie the tensors and stop descending.
    if hasattr(decoder_pointer, "weight"):
        assert hasattr(encoder_pointer, "weight")
        encoder_pointer.weight = decoder_pointer.weight
        if hasattr(decoder_pointer, "bias"):
            assert hasattr(encoder_pointer, "bias")
            encoder_pointer.bias = decoder_pointer.bias
        return

    enc_children = encoder_pointer._modules
    dec_children = decoder_pointer._modules
    if not dec_children:
        return

    assert (
        len(enc_children) > 0
    ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"

    prefix = module_name + "/"
    unmatched = {prefix + child_name for child_name in enc_children.keys()}
    layer_offset = 0

    for name in dec_children:
        if name.isdigit():
            encoder_name = str(int(name) + layer_offset)
            decoder_name = name
            dec_child = dec_children[decoder_name]
            enc_child = enc_children[encoder_name]
            if not isinstance(dec_child, type(enc_child)) and len(
                enc_children
            ) != len(dec_children):
                # this can happen if the name corresponds to the position in a module list of layers
                # in this case the decoder has added a cross-attention that the encoder does not have
                # thus skip this step and subtract one layer pos from encoder
                layer_offset -= 1
                continue
        elif name not in enc_children:
            continue
        elif depth > 500:
            raise ValueError(
                "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a "
                "circular dependency between two or more `nn.Modules` of your model. "
            )
        else:
            decoder_name = encoder_name = name
            dec_child = dec_children[decoder_name]
            enc_child = enc_children[encoder_name]

        tie_encoder_to_decoder_recursively(
            dec_child,
            enc_child,
            prefix + name,
            uninitialized_encoder_weights,
            depth=depth + 1,
        )
        unmatched.remove(prefix + encoder_name)

    uninitialized_encoder_weights += list(unmatched)