def __call__(self, module: torch.nn.Module) -> None:
    """
    Applies an initializer to all parameters in a module that match one of the regexes we were
    given in this object's constructor. Does nothing to parameters that do not match.

    Parameters
    ----------
    module : torch.nn.Module, required.
        The Pytorch module to apply the initializers to.
    """
    logger.info("Initializing parameters")
    unused_regexes = set([initializer[0] for initializer in self._initializers])
    uninitialized_parameters = set()
    # Store which initializers were applied to which parameters.
    for name, parameter in module.named_parameters():
        for initializer_regex, initializer in self._initializers:
            allow = self._prevent_regex is None or not bool(re.search(self._prevent_regex, name))
            if allow and re.search(initializer_regex, name):
                logger.info("Initializing %s using %s initializer", name, initializer_regex)
                initializer(parameter)
                unused_regexes.discard(initializer_regex)
                break
        else:  # no break
            uninitialized_parameters.add(name)
    for regex in unused_regexes:
        logger.warning("Did not use initialization regex that was passed: %s", regex)
    logger.info("Done initializing parameters; the following parameters are using their "
                "default initialization from their code")
    uninitialized_parameter_list = list(uninitialized_parameters)
    uninitialized_parameter_list.sort()
    for name in uninitialized_parameter_list:
        logger.info("   %s", name)
def get_frozen_and_tunable_parameter_names(model: torch.nn.Module) -> List:
    frozen_parameter_names = []
    tunable_parameter_names = []
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            frozen_parameter_names.append(name)
        else:
            tunable_parameter_names.append(name)
    return [frozen_parameter_names, tunable_parameter_names]
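# Hedged usage sketch (not from the original source): freeze part of a model by parameter
# name, then split the names into frozen and tunable groups with the helper above. The
# "encoder." prefix and the model argument are assumptions for illustration.
def _example_freeze_and_report(model: torch.nn.Module):
    import re
    for name, parameter in model.named_parameters():
        if re.search(r"^encoder\.", name):  # assumed submodule prefix
            parameter.requires_grad_(False)
    frozen, tunable = get_frozen_and_tunable_parameter_names(model)
    print("Frozen parameter names:", frozen)
    print("Tunable parameter names:", tunable)
    return frozen, tunable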
def __call__(self, module: torch.nn.Module) -> torch.Tensor:
    """
    Parameters
    ----------
    module : torch.nn.Module, required
        The module to regularize.
    """
    accumulator = 0.0
    # For each parameter find the first matching regex.
    for name, parameter in module.named_parameters():
        for regex, regularizer in self._regularizers:
            if re.search(regex, name):
                penalty = regularizer(parameter)
                accumulator = accumulator + penalty
                break
    return accumulator
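# Hedged usage sketch (not from the original source): the regularizer applicator above
# returns a scalar penalty that is simply added to the task loss before backprop. All
# argument names below are assumptions for illustration.
def _example_regularized_training_step(model, criterion, regularizer_applicator, optimizer, inputs, labels):
    optimizer.zero_grad()
    outputs = model(inputs)
    # task loss plus the regex-matched parameter penalties
    loss = criterion(outputs, labels) + regularizer_applicator(model)
    loss.backward()
    optimizer.step()
    return float(loss.item())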
def train( dataset: torch.utils.data.Dataset, model: torch.nn.Module, epochs: int, batch_size: int, optimizer: torch.optim.Optimizer, stopping_delta: Optional[float] = None, collate_fn=default_collate, cuda: bool = True, sampler: Optional[torch.utils.data.sampler.Sampler] = None, silent: bool = False, update_freq: int = 10, evaluate_batch_size: int = 1024, update_callback: Optional[Callable[[float, float], None]] = None, epoch_callback: Optional[Callable[[int, torch.nn.Module], None]] = None, ) -> None: """ Train the DEC model given a dataset, a model instance and various configuration parameters. :param dataset: instance of Dataset to use for training :param model: instance of DEC model to train :param epochs: number of training epochs :param batch_size: size of the batch to train with :param optimizer: instance of optimizer to use :param stopping_delta: label delta as a proportion to use for stopping, None to disable, default None :param collate_fn: function to merge a list of samples into mini-batch :param cuda: whether to use CUDA, defaults to True :param sampler: optional sampler to use in the DataLoader, defaults to None :param silent: set to True to prevent printing out summary statistics, defaults to False :param update_freq: frequency of batches with which to update counter, None disables, default 10 :param evaluate_batch_size: batch size for evaluation stage, default 1024 :param update_callback: optional function of accuracy and loss to update, default None :param epoch_callback: optional function of epoch and model, default None :return: None """ static_dataloader = DataLoader( dataset, batch_size=batch_size, collate_fn=collate_fn, pin_memory=False, sampler=sampler, shuffle=False, ) train_dataloader = DataLoader( dataset, batch_size=batch_size, collate_fn=collate_fn, sampler=sampler, shuffle=True, ) data_iterator = tqdm( static_dataloader, leave=True, unit="batch", postfix={ "epo": -1, "acc": "%.4f" % 0.0, "lss": "%.8f" % 0.0, "dlb": "%.4f" % -1, }, disable=silent, ) kmeans = KMeans(n_clusters=model.cluster_number, n_init=20) model.train() features = [] actual = [] # form initial cluster centres for index, batch in enumerate(data_iterator): if (isinstance(batch, tuple) or isinstance(batch, list)) and len(batch) == 2: batch, value = batch # if we have a prediction label, separate it to actual actual.append(value) if cuda: batch = batch.cuda(non_blocking=True) features.append(model.encoder(batch).detach().cpu()) actual = torch.cat(actual).long() predicted = kmeans.fit_predict(torch.cat(features).numpy()) predicted_previous = torch.tensor(np.copy(predicted), dtype=torch.long) _, accuracy = cluster_accuracy(predicted, actual.cpu().numpy()) cluster_centers = torch.tensor( kmeans.cluster_centers_, dtype=torch.float, requires_grad=True ) if cuda: cluster_centers = cluster_centers.cuda(non_blocking=True) with torch.no_grad(): # initialise the cluster centers model.state_dict()["assignment.cluster_centers"].copy_(cluster_centers) loss_function = nn.KLDivLoss(size_average=False) delta_label = None for epoch in range(epochs): features = [] data_iterator = tqdm( train_dataloader, leave=True, unit="batch", postfix={ "epo": epoch, "acc": "%.4f" % (accuracy or 0.0), "lss": "%.8f" % 0.0, "dlb": "%.4f" % (delta_label or 0.0), }, disable=silent, ) model.train() for index, batch in enumerate(data_iterator): if (isinstance(batch, tuple) or isinstance(batch, list)) and len( batch ) == 2: batch, _ = batch # if we have a prediction label, strip it away if cuda: batch = batch.cuda(non_blocking=True) 
output = model(batch) target = target_distribution(output).detach() loss = loss_function(output.log(), target) / output.shape[0] data_iterator.set_postfix( epo=epoch, acc="%.4f" % (accuracy or 0.0), lss="%.8f" % float(loss.item()), dlb="%.4f" % (delta_label or 0.0), ) optimizer.zero_grad() loss.backward() optimizer.step(closure=None) features.append(model.encoder(batch).detach().cpu()) if update_freq is not None and index % update_freq == 0: loss_value = float(loss.item()) data_iterator.set_postfix( epo=epoch, acc="%.4f" % (accuracy or 0.0), lss="%.8f" % loss_value, dlb="%.4f" % (delta_label or 0.0), ) if update_callback is not None: update_callback(accuracy, loss_value, delta_label) predicted, actual = predict( dataset, model, batch_size=evaluate_batch_size, collate_fn=collate_fn, silent=True, return_actual=True, cuda=cuda, ) delta_label = ( float((predicted != predicted_previous).float().sum().item()) / predicted_previous.shape[0] ) if stopping_delta is not None and delta_label < stopping_delta: print( 'Early stopping as label delta "%1.5f" less than "%1.5f".' % (delta_label, stopping_delta) ) break predicted_previous = predicted _, accuracy = cluster_accuracy(predicted.cpu().numpy(), actual.cpu().numpy()) data_iterator.set_postfix( epo=epoch, acc="%.4f" % (accuracy or 0.0), lss="%.8f" % 0.0, dlb="%.4f" % (delta_label or 0.0), ) if epoch_callback is not None: epoch_callback(epoch, model)
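# Hedged helper sketch (not part of the original source): the training loop above calls
# `target_distribution(output)`. In the DEC paper this is the auxiliary distribution
# p_ij = (q_ij^2 / f_j) / sum_j' (q_ij'^2 / f_j'), where q are the soft cluster assignments
# and f_j = sum_i q_ij are the soft cluster frequencies. A minimal implementation consistent
# with that definition could look like this; the project's actual helper may differ.
def target_distribution_sketch(batch: torch.Tensor) -> torch.Tensor:
    # batch: soft assignments q of shape (batch_size, n_clusters)
    weight = (batch ** 2) / torch.sum(batch, 0)      # q^2 / f, per cluster
    return (weight.t() / torch.sum(weight, 1)).t()   # renormalize each row to sum to 1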
def predict(
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attention_masks: torch.Tensor,
        model: torch.nn.Module,
        batch_size: int,
        device_: Optional[str] = None,  # if None, automatically uses a GPU if available, else a CPU
        disable_progress_bar: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given a trained model and unseen data, performs the predictions and returns the results.

    Unlike in the fine-tuning and training stages, during prediction there is no need to build a
    dataloader which splits the set into train and validation and randomly shuffles the training
    samples. We can just pass the items through directly. As we're not training, there are no
    training epochs either.

    :param input_ids: torch.Tensor of shape (N, max_len) with the ids of each token of the N
        encoded sequence pairs, padded at the end up to max_len. If decoded, the input_ids consist
        of a "[CLS]" token, followed by the question's tokens, a "[SEP]" token, the context's
        tokens, a "[SEP]" token and, if relevant, "[PAD]" tokens up to max_len.
    :param token_type_ids: torch.Tensor of shape (N, max_len) where each Nth dimension is filled
        with 1 for token positions in the context text and 0 elsewhere (i.e. in question and padding).
    :param attention_masks: torch.Tensor of shape (N, max_len) where each Nth dimension is filled
        with 1 for non-"[PAD]" tokens and 0 for "[PAD]" tokens.
    :param model: the model to use (must be an instance of torch.nn.Module). As we're performing
        predictions, this must be a trained model.
    :param batch_size: the batch size to use for predictions. Batching samples speeds up processing.
    :param device_: if specified, the device used for the computations. Can be one of cpu, cuda,
        mkldnn, opengl, opencl, ideep, hip, msnpu. If set to None, it will default to GPU (cuda)
        if one is available, else it will use a CPU. Default: None
    :param disable_progress_bar: whether to disable the tqdm progress bar. When used in production
        for quickly returning answers to a single or small set of questions, the bar might be
        distracting. Default: False
    :return: pred_start: torch.Tensor of shape (N) with the predicted indices of the first answer
        token for each answer.
        pred_end: torch.Tensor of shape (N) with the predicted indices of the last answer token
        for each answer.
    """
    assert input_ids.shape == token_type_ids.shape == attention_masks.shape, "Some input shapes are wrong"
    device = set_hardware_acceleration(default=device_)
    model = model.to(device)
    model.eval()
    pred_start = torch.tensor([], dtype=torch.long, device=device)  # initialising tensors for storing results
    pred_end = torch.tensor([], dtype=torch.long, device=device)
    t_i = time()
    # Batch the samples to speed up processing. We do the batching manually here to avoid using a DataLoader.
    for batch_i in tqdm(range(0, len(input_ids), batch_size), disable=disable_progress_bar):
        batch_input_ids = input_ids[batch_i:batch_i + batch_size].to(device)
        batch_token_type_ids = token_type_ids[batch_i:batch_i + batch_size].to(device)
        batch_attention_masks = attention_masks[batch_i:batch_i + batch_size].to(device)
        with torch.no_grad():
            # Forward pass. Unlike in the training loop we pass no answer positions, so (assuming a
            # Hugging Face BertForQuestionAnswering-style model) the first two outputs are the
            # start and end logits rather than a loss.
            outputs = model(
                input_ids=batch_input_ids,
                token_type_ids=batch_token_type_ids,
                attention_mask=batch_attention_masks,
            )
            start_logits, end_logits = outputs[0], outputs[1]
            pred_start_positions = torch.argmax(start_logits, dim=1)
            pred_end_positions = torch.argmax(end_logits, dim=1)
            pred_start = torch.cat((pred_start, pred_start_positions))
            pred_end = torch.cat((pred_end, pred_end_positions))
        if torch.cuda.is_available():
            logger.debug("GPU memory usage: \n%s", gpu_memory_usage())
    logger.info(f"All predictions calculated in {format_time(time() - t_i)}.")
    if torch.cuda.is_available():
        logger.info("GPU memory usage: \n%s", gpu_memory_usage())
    return pred_start, pred_end
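# Hedged usage sketch (not from the original source): one way the encoded inputs expected by
# `predict` could be produced with a Hugging Face tokenizer. The checkpoint name, max_length
# and batch_size are assumptions for illustration; any BERT-style QA model should fit.
def _example_qa_prediction(questions, contexts, batch_size: int = 16):
    from transformers import BertTokenizerFast, BertForQuestionAnswering
    checkpoint = "bert-large-uncased-whole-word-masking-finetuned-squad"  # assumed checkpoint
    tokenizer = BertTokenizerFast.from_pretrained(checkpoint)
    qa_model = BertForQuestionAnswering.from_pretrained(checkpoint)
    encoded = tokenizer(
        questions, contexts,
        padding="max_length", truncation=True, max_length=384, return_tensors="pt",
    )
    return predict(
        encoded["input_ids"], encoded["token_type_ids"], encoded["attention_mask"],
        model=qa_model, batch_size=batch_size,
    )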
def get_default_optimizer_params( model: torch.nn.Module, base_lr: Optional[float] = None, weight_decay: Optional[float] = None, weight_decay_norm: Optional[float] = None, bias_lr_factor: Optional[float] = 1.0, weight_decay_bias: Optional[float] = None, overrides: Optional[Dict[str, Dict[str, float]]] = None, ) -> List[Dict[str, Any]]: """ Get default param list for optimizer, with support for a few types of overrides. If no overrides needed, this is equivalent to `model.parameters()`. Args: base_lr: lr for every group by default. Can be omitted to use the one in optimizer. weight_decay: weight decay for every group by default. Can be omitted to use the one in optimizer. weight_decay_norm: override weight decay for params in normalization layers bias_lr_factor: multiplier of lr for bias parameters. weight_decay_bias: override weight decay for bias parameters overrides: if not `None`, provides values for optimizer hyperparameters (LR, weight decay) for module parameters with a given name; e.g. ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and weight decay values for all module parameters named `embedding`. For common detection models, ``weight_decay_norm`` is the only option needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings from Detectron1 that are not found useful. Example: :: torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0), lr=0.01, weight_decay=1e-4, momentum=0.9) """ if overrides is None: overrides = {} defaults = {} if base_lr is not None: defaults["lr"] = base_lr if weight_decay is not None: defaults["weight_decay"] = weight_decay bias_overrides = {} if bias_lr_factor is not None and bias_lr_factor != 1.0: # NOTE: unlike Detectron v1, we now by default make bias hyperparameters # exactly the same as regular weights. if base_lr is None: raise ValueError("bias_lr_factor requires base_lr") bias_overrides["lr"] = base_lr * bias_lr_factor if weight_decay_bias is not None: bias_overrides["weight_decay"] = weight_decay_bias if len(bias_overrides): if "bias" in overrides: raise ValueError("Conflicting overrides for 'bias'") overrides["bias"] = bias_overrides norm_module_types = ( torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm, # NaiveSyncBatchNorm inherits from BatchNorm2d torch.nn.GroupNorm, torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d, torch.nn.LayerNorm, torch.nn.LocalResponseNorm, ) params: List[Dict[str, Any]] = [] memo: Set[torch.nn.parameter.Parameter] = set() for module in model.modules(): for module_param_name, value in module.named_parameters(recurse=False): if not value.requires_grad: continue # Avoid duplicating parameters if value in memo: continue memo.add(value) hyperparams = copy.copy(defaults) if isinstance(module, norm_module_types) and weight_decay_norm is not None: hyperparams["weight_decay"] = weight_decay_norm hyperparams.update(overrides.get(module_param_name, {})) params.append({"params": [value], **hyperparams}) return reduce_param_groups(params)
def plot_attention( cls, model: torch.nn.Module, output_dir: Optional[Path], summary_writer: Optional[SummaryWriter], iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]], reporter: SubReporter, options: TrainerOptions, ) -> None: assert check_argument_types() import matplotlib ngpu = options.ngpu no_forward_run = options.no_forward_run matplotlib.use("Agg") import matplotlib.pyplot as plt from matplotlib.ticker import MaxNLocator model.eval() for ids, batch in iterator: assert isinstance(batch, dict), type(batch) assert len(next(iter(batch.values()))) == len(ids), ( len(next(iter(batch.values()))), len(ids), ) batch = to_device(batch, "cuda" if ngpu > 0 else "cpu") if no_forward_run: continue # 1. Forwarding model and gathering all attentions # calculate_all_attentions() uses single gpu only. att_dict = calculate_all_attentions(model, batch) # 2. Plot attentions: This part is slow due to matplotlib for k, att_list in att_dict.items(): assert len(att_list) == len(ids), (len(att_list), len(ids)) for id_, att_w in zip(ids, att_list): if isinstance(att_w, torch.Tensor): att_w = att_w.detach().cpu().numpy() if att_w.ndim == 2: att_w = att_w[None] elif att_w.ndim > 3 or att_w.ndim == 1: raise RuntimeError( f"Must be 2 or 3 dimension: {att_w.ndim}") w, h = plt.figaspect(1.0 / len(att_w)) fig = plt.Figure(figsize=(w * 1.3, h * 1.3)) axes = fig.subplots(1, len(att_w)) if len(att_w) == 1: axes = [axes] for ax, aw in zip(axes, att_w): ax.imshow(aw.astype(np.float32), aspect="auto") ax.set_title(f"{k}_{id_}") ax.set_xlabel("Input") ax.set_ylabel("Output") ax.xaxis.set_major_locator(MaxNLocator(integer=True)) ax.yaxis.set_major_locator(MaxNLocator(integer=True)) if output_dir is not None: p = output_dir / id_ / f"{k}.{reporter.get_epoch()}ep.png" p.parent.mkdir(parents=True, exist_ok=True) fig.savefig(p) if summary_writer is not None: summary_writer.add_figure(f"{k}_{id_}", fig, reporter.get_epoch()) # Dummy register() stimulates to increment the counter reporter.register({})
def my_correct_bias(model: torch.nn.Module, quant_params, num_quant_samples: int, data_loader, num_bias_correct_samples: int, conv_bn_dict=None, perform_only_empirical_bias_corr: bool = True, layers_to_ignore=None, quantizer_modifications=None): if layers_to_ignore is None: layers_to_ignore = [] # Find batch size and shape of input tensor batch_size, input_shape = aimet_torch.utils.get_input_shape_batch_size( data_loader) # Rounding up number of samples to batch size n_batches_bias_correction = int( np.ceil(num_bias_correct_samples / batch_size)) n_batches_quantization = int(np.ceil(num_quant_samples / batch_size)) data_loader_n_samples_bias_corr = aimet_torch.utils.IterFirstX( data_loader, n_batches_bias_correction) data_loader_n_samples_quant = aimet_torch.utils.IterFirstX( data_loader, n_batches_quantization) # TODO: Remove wrapper function # Create a wrapping function for data loader for quantization def pass_data_through_model(model, early_stopping_iterations=None, use_cuda=False): # pylint: disable=unused-argument # forward pass for given number of batches for model for (images_in_one_batch, _) in data_loader_n_samples_quant: aimet_torch.bias_correction.forward_pass(model, images_in_one_batch) ordered_conv_linear_nodes = aimet_torch.utils.get_ordered_lists_of_conv_fc( model, input_shape) if conv_bn_dict is None: conv_bn_dict = aimet_torch.bias_correction.find_all_conv_bn_with_activation( model, input_shape) # Create a copy of the model as reference model model_copy = copy.deepcopy(model) # Add bias for all the layers whose bias is None for name, module in ordered_conv_linear_nodes: if module.bias is None: if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)): output_size = module.out_channels elif isinstance(module, torch.nn.Linear): output_size = module.out_features module.bias = torch.nn.Parameter(torch.zeros(output_size)) module.bias.data = module.bias.data.to(device=module.weight.device) # Quantize full model dummy_tensors = aimet_torch.utils.create_rand_tensors_given_shapes( input_shape) dummy_tensors = [ tensor.to(aimet_torch.utils.get_device(model)) for tensor in dummy_tensors ] q = aimet_torch.quantsim.QuantizationSimModel( model=model, quant_scheme=quant_params.quant_scheme, rounding_mode=quant_params.round_mode, default_output_bw=quant_params.act_bw, default_param_bw=quant_params.weight_bw, in_place=True, dummy_input=dummy_tensors, config_file=quant_params.config_file) # make sure model got updated in-place before we use it for bc updates assert (q.model is model) if quantizer_modifications is not None: quantizer_modifications(q) # updates to skip_output_activation and layers_to_ignore for name, module in model.named_modules(): # Skip all layer's output quantization if isinstance(module, QcQuantizeWrapper): module.output_quantizers[0].enabled = False q.compute_encodings(pass_data_through_model, None) # For first conv layer, perform analytical bc if perform_only_empirical_bias_corr is set to False # and layer is not marked to be ignored during bc. 
if not perform_only_empirical_bias_corr: module_name, module = ordered_conv_linear_nodes[0] if module not in layers_to_ignore: aimet_torch.bias_correction.logger.info( 'Correcting layer %s using Analytical Bias Correction', module_name) quantize_layer = aimet_torch.utils.get_layer_by_name( model, module_name) aimet_torch.bias_correction.call_analytical_mo_correct_bias( quantize_layer, None, None) aimet_torch.bias_correction.logger.info( 'Corrected bias for the layer') ordered_conv_linear_nodes.pop(0) for module_name, module in ordered_conv_linear_nodes: # Ignore all layers which are skipped by user if module in layers_to_ignore or module_name in layers_to_ignore: continue else: # make sure module is in the model used by qsim. assert (module in list(q.model.modules())) # Analytical Bias Correction is only done for Conv layers reference_layer = aimet_torch.utils.get_layer_by_name( model_copy, module_name) quantize_layer = aimet_torch.utils.get_layer_by_name( model, module_name) if module in conv_bn_dict.keys(): bn_layer_info = conv_bn_dict[module] if perform_only_empirical_bias_corr or bn_layer_info is None or bn_layer_info.input_bn is None: aimet_torch.bias_correction.logger.info( 'Correcting layer %s using Empirical Bias Correction', module_name) bias_correction = libpymo.BiasCorrection() # Get output from quantized model and reference model for images_in_one_batch, _ in data_loader_n_samples_bias_corr: reference_output_batch = aimet_torch.bias_correction.get_output_data( reference_layer, model_copy, images_in_one_batch) quantized_model_output_batch = aimet_torch.bias_correction.get_output_data( quantize_layer, model, images_in_one_batch) if isinstance(reference_layer, torch.nn.Linear): extended_shape = np.concatenate( (reference_output_batch.shape, np.array([1, 1]))) reference_output_batch = reference_output_batch.reshape( extended_shape) quantized_model_output_batch = quantized_model_output_batch.reshape( extended_shape) bias_correction.storePreActivationOutput( reference_output_batch) bias_correction.storeQuantizedPreActivationOutput( quantized_model_output_batch) aimet_torch.bias_correction.call_empirical_mo_correct_bias( module, bias_correction) else: aimet_torch.bias_correction.logger.info( 'Correcting layer %s using Analytical Bias Correction', module_name) aimet_torch.bias_correction.call_analytical_mo_correct_bias( quantize_layer, bn_layer_info.input_bn, bn_layer_info.in_activation_type) aimet_torch.bias_correction.logger.info( 'Corrected bias for the layer') aimet_torch.save_utils.SaveUtils.remove_quantization_wrappers(model) aimet_torch.bias_correction.logger.info('Completed bias correction')
def _prepare_model(model: torch.nn.Module) -> torch.nn.Module:
    # Remove last layer
    layers_without_fc = list(model.children())[:-1]
    return torch.nn.Sequential(*layers_without_fc)
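# Hedged usage sketch (not from the original source): turning a torchvision ResNet into a
# feature extractor by dropping its final fully connected layer. This only works for models
# whose top-level children run sequentially; the model choice here is an assumption.
def _example_resnet_feature_extractor():
    import torchvision
    # weights=None requires torchvision >= 0.13; older versions use pretrained=False
    backbone = _prepare_model(torchvision.models.resnet18(weights=None))
    images = torch.randn(2, 3, 224, 224)
    features = backbone(images)        # shape (2, 512, 1, 1) after global average pooling
    return torch.flatten(features, 1)  # shape (2, 512)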
def _find_module(root: torch.nn.Module, m: torch.nn.Module):
    for n, p in root.named_modules():
        if m is p:
            return n
    raise NameError('module is not installed as a submodule')
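# Hedged usage sketch (not from the original source): recovering the qualified name under
# which a module instance is registered. The small container model here is an assumption.
def _example_find_module():
    inner = torch.nn.Linear(4, 2)
    root = torch.nn.Sequential(torch.nn.Sequential(inner))
    return _find_module(root, inner)  # returns "0.0"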
def add_auto_convert(module: torch.nn.Module) -> torch.nn.Module: def convert_to_dispatch_proxy(x): if isinstance(x, torch.Tensor): return x.as_subclass( QuantizationConvertTensorProxy) # type: ignore[arg-type] else: return x module_id_to_fqn: Dict[int, str] = {} # Counter for global quantizeable ops, useful for intermediate activation # logging. global_op_idx = [0] global_disable_torch_function_override = False class QuantizationConvertTensorProxy(torch.Tensor): """ An override of `torch.Tensor` to enable dynamic dispatch for quantization inference. For each function with a `__torch_fuction__` override, this proxy does the following for functions which need quantization: 1. calls `_auto_quant_state.validate_cur_op` to validate that the currently seen op is the same as what was recorded during tracing 2. calls `_auto_quant_state.op_convert_before_hook`. 3. executes the function, with target, args and kwargs possibly modified by (2) 4. calls `_auto_quant_state.inference_function_after_hook`. 5. calls `_auto_quant_state.mark_cur_op_complete` to increment the current op index in preparation for the next op Otherwise, calls the original function. """ @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): nonlocal global_disable_torch_function_override if ( # global override means disable the override here global_disable_torch_function_override or # to prevent printing things from going into an infinite loop func == torch.Tensor.__repr__ or # we don't need to override getters in this framework func.__name__ == '__get__'): return super().__torch_function__(func, types, args, kwargs) kwargs = kwargs if kwargs else {} # if we are in a function, the current module is always a parent parent_module = cur_module hook_type = get_torch_function_hook_type(parent_module, func) if enable_logging: fqn_for_logging = module_id_to_fqn.get( id(parent_module), 'unknown') if parent_module else None logger.debug( f" fqn:{fqn_for_logging} _tf_ {func} " + f"hook_type {hook_type} " + # f"arg_types {[type(arg) for arg in args]}) " + f"arg_dtypes {[arg.dtype if isinstance(arg, torch.Tensor) else None for arg in args]}" ) if hook_type is HookType.OP_HOOKS: qstate: AutoQuantizationState = parent_module._auto_quant_state # type: ignore[union-attr] # before hooks qstate.validate_cur_op(func) func, args, kwargs = qstate.op_convert_before_hook( func, args, kwargs, parent_module) # type: ignore[arg-type] # forward output = super().__torch_function__(func, types, args, kwargs) # after hooks output = qstate.op_convert_after_hook(func, output, global_op_idx) qstate.mark_cur_op_complete(func) elif hook_type is HookType.ARG_DEQUANTS: # TODO(future PR): handle more dtypes new_args = [] for arg in args: if isinstance(arg, torch.Tensor) and arg.is_quantized: new_args.append(arg.dequantize()) else: new_args.append(arg) args = tuple(new_args) output = super().__torch_function__(func, types, args, kwargs) else: # HookType.NONE output = super().__torch_function__(func, types, args, kwargs) # TODO: is this right? 
Don't really understand this if output is NotImplemented: with torch._C.DisableTorchFunction(): output = func( *args, **kwargs).as_subclass(QuantizationConvertTensorProxy) assert output is not NotImplemented if enable_logging: fqn_for_logging = module_id_to_fqn.get( id(parent_module), 'unknown') if parent_module else None out_dtype = None if isinstance(output, torch.Tensor): out_dtype = output.dtype logger.debug( f" fqn:{fqn_for_logging} _tf_ {func} out {out_dtype} end") return output def __repr__(self): return f'QuantizationConvertTensorProxy({super().__repr__()})' cur_module = None module_stack: List[torch.nn.Module] = [] assert len(module.__class__.__bases__) == 1 class QuantizationDispatchModule(module.__class__.__bases__[0] ): # type: ignore[name-defined] """ An override of user defined subclass of `nn.Module` to enable dynamic tracing for quantization, after model conversion to quantized domain. `cur_module` keeps track of the current module in the stack. Tensor arguments are converted to `QuantizationConvertTensorProxy`. We override the `__call__` function to do the following for each module: If the module is an op which needs quantization: 1. calls `_auto_quant_state.validate_cur_op` to validate that the currently seen op is the same as what was recorded during tracing 2. calls parent module's `._auto_quant_state.op_convert_before_hook` 3. executes the original module forward 4. calls parent module's `_auto_quant_state.op_convert_after_hook` 5. calls `_auto_quant_state.mark_cur_op_complete` to increment the current op index in preparation for the next op If the module can contain children ops that need quantization: 1. calls `_auto_quant_state.inputs_convert_hook` (not implemented yet) 2. executes the original module forward 3. calls `_auto_quant_state.outputs_convert_hook` Otherwise, calls the original module forward. """ def __call__(self, *args, **kwargs): new_args = map_aggregate(args, convert_to_dispatch_proxy) new_kwargs = map_aggregate(kwargs, convert_to_dispatch_proxy) orig_module_call = torch.nn.Module.__call__ orig_nn_sequential_forward = torch.nn.Sequential.forward def _patched_module_call(self, *args, **kwargs): nonlocal cur_module old_module = cur_module cur_module = self nonlocal global_disable_torch_function_override try: parent_module = module_stack[-1] if len( module_stack) else None module_stack.append(self) hook_type = get_module_hook_type(parent_module, cur_module) if enable_logging: fqn_for_logging = module_id_to_fqn.get(id(self), None) logger.debug( f" fqn: {fqn_for_logging} " + f"_cl_ {type(self)} " + f"arg_dtypes {[arg.dtype if isinstance(arg, torch.Tensor) else None for arg in args]} " + f"hook_type {hook_type}") if hook_type is HookType.OP_HOOKS: # before hooks qstate: AutoQuantizationState = \ parent_module._auto_quant_state # type: ignore[union-attr, assignment] qstate.validate_cur_op(cur_module) # If we are in this hook, `cur_module` is a leaf module. # Therefore, we do not need to override any of its # children. Disabling the overrides for performance. old_global_disable_torch_function_override = \ global_disable_torch_function_override global_disable_torch_function_override = True _, args, kwargs = qstate.op_convert_before_hook( cur_module, args, kwargs, cur_module) # forward output = orig_module_call(self, *args, **kwargs) # after hooks output = qstate.op_convert_after_hook( cur_module, output, global_op_idx) # Re-enable the override. 
global_disable_torch_function_override = \ old_global_disable_torch_function_override qstate.mark_cur_op_complete(cur_module) elif hook_type is HookType.MODULE_IO_HOOKS: cur_qstate: AutoQuantizationState = cur_module._auto_quant_state cur_qstate.reset_to_new_call() # before hooks (TODO) # forward output = orig_module_call(self, *args, **kwargs) # after hooks # For the sake of performance, we assume no overrides # are needed for quantizing/dequantizing things old_global_disable_torch_function_override = \ global_disable_torch_function_override global_disable_torch_function_override = True output = cur_qstate.outputs_convert_hook(output) global_disable_torch_function_override = \ old_global_disable_torch_function_override cur_qstate.validate_is_at_last_seen_idx() elif hook_type is HookType.ARG_DEQUANTS: # TODO(future PR): handle more dtypes new_args = [] for arg in args: if isinstance(arg, torch.Tensor) and arg.is_quantized: dequant = arg.dequantize().as_subclass( QuantizationConvertTensorProxy ) # type: ignore[arg-type] new_args.append(dequant) else: new_args.append(arg) args = tuple(new_args) output = orig_module_call(self, *args, **kwargs) else: output = orig_module_call(self, *args, **kwargs) if enable_logging: fqn_for_logging = module_id_to_fqn.get(id(self), None) logger.debug( f" fqn: {fqn_for_logging} " + f"_cl_ {type(self)} " + f"dtype {output.dtype if isinstance(output, torch.Tensor) else None} " + "end") return output finally: module_stack.pop() cur_module = old_module torch.nn.Module.__call__ = _patched_module_call torch.nn.Sequential.forward = _nn_sequential_patched_forward # type: ignore[assignment] try: global_op_idx[0] = 0 output = super().__call__(*new_args, **new_kwargs) def unwrap_proxy(a): if isinstance(a, QuantizationConvertTensorProxy): a.__class__ = torch.Tensor # type: ignore[assignment] return a output = map_aggregate(output, unwrap_proxy) return output finally: torch.nn.Module.__call__ = orig_module_call torch.nn.Sequential.forward = orig_nn_sequential_forward # type: ignore[assignment] def rewrite_for_scripting(self): return auto_trace_rewriter.rewrite_for_scripting(self) pack_weights_for_functionals(module) attach_scale_zp_values_to_model(module) attach_op_convert_info_to_model(module) attach_output_convert_info_to_model(module) # Since eager mode convert could have changed the IDs of some modules, # populate the FQN map again for k, v in module.named_modules(): module_id_to_fqn[id(v)] = k module.__class__ = QuantizationDispatchModule return module
def add_auto_observation( model: torch.nn.Module, qconfig_dict: Dict[str, Any], example_inputs: Tuple[Any], input_dtypes: Any = ( torch.float, ), # must be same structure as model inputs output_dtypes: Any = ( torch.float, ), # must be same structure as model outputs prepare_custom_config_dict: Dict[str, Any] = None, ) -> torch.nn.Module: def convert_to_interception_proxy(x): if isinstance(x, torch.Tensor): return x.as_subclass( QuantizationPrepareTensorProxy) # type: ignore[arg-type] else: return x cur_module = None first_call = True module_stack: List[torch.nn.Module] = [] # Counter for tensor IDs, will be modified inplace by quant state. # This is used to track tensors from output ops to input ops. For example, # if op_n had a tensor output with id=1, and op_n+2 had a tensor input with # id=1, we know that the output of op_n is the input to op_n+2. Note, # this is a list because it needs to incremented inplace. qtensor_id = [0] module_id_to_fqn: Dict[int, str] = {} # Counter for global quantizeable ops, useful for intermediate activation # logging. global_op_idx = [0] global_disable_torch_function_override = False class QuantizationPrepareTensorProxy(torch.Tensor): """ An override of `torch.Tensor` to enable dynamic tracing for quantization. For each function with a `__torch_function__` override, this proxy does the following for functions which need quantization: 1. calls `_auto_quant_state.validate_cur_op` to validate that the currently seen op is the same as what was recorded during tracing 2. calls `_auto_quant_state.op_prepare_before_hook` 3. executes the original function 4. calls `_auto_quant_state.op_prepare_after_hook` 5. calls `_auto_quant_state.mark_cur_op_complete` to increment the current op index in preparation for the next op Otherwise, calls the original function. 
""" @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): nonlocal global_disable_torch_function_override if ( # global override means disable the override here global_disable_torch_function_override or # to prevent printing things from going into an infinite loop func == torch.Tensor.__repr__ or # we don't need to override getters in this framework func.__name__ == '__get__'): return super().__torch_function__(func, types, args, kwargs) # if we are in a function, the current module is always a parent nonlocal cur_module parent_module = cur_module if enable_logging: if not is_activation_post_process(parent_module): # logging for insides of obs/fq is not useful for this framework # fqn map does not contain observers, which is why we # cannot always assume that FQN exists fqn_for_logging = module_id_to_fqn.get( id(parent_module), 'unknown') if parent_module else None logger.debug( f' fqn:{fqn_for_logging} _tf_ {str(func)} len_args {len(args)}' ) nonlocal qtensor_id kwargs = kwargs if kwargs else {} hook_type = get_torch_function_hook_type(parent_module, func) if first_call and hook_type is not HookType.OP_HOOKS: qstate = getattr(parent_module, '_auto_quant_state', None) if qstate: qstate.add_seen_op_type_without_op_hooks(func) if hook_type is HookType.OP_HOOKS: qstate = parent_module._auto_quant_state # type: ignore[attr-defined] fqn = module_id_to_fqn[id( parent_module)] if parent_module else None if not first_call: qstate.validate_cur_op(func) # run "before" hook args, kwargs = qstate.op_prepare_before_hook( func, args, kwargs, first_call, qtensor_id, fqn, parent_module) # forward output = super().__torch_function__(func, types, args, kwargs) # run "after" hook output = qstate.op_prepare_after_hook(func, output, args, first_call, qtensor_id, parent_module, global_op_idx) qstate.mark_cur_op_complete(func) else: output = super().__torch_function__(func, types, args, kwargs) # TODO: is this right? Don't really understand this if output is NotImplemented: with torch._C.DisableTorchFunction(): output = func( *args, **kwargs).as_subclass(QuantizationPrepareTensorProxy) assert output is not NotImplemented return output def __repr__(self): return f'QuantizationPrepareTensorProxy({super().__repr__()})' # TODO(future PR): add other math overrides class QuantizationInterceptionModule(type(model)): # type: ignore[misc] """ An override of user defined subclass of `nn.Module` to enable dynamic tracing for quantization. `cur_module` keeps track of the current module in the stack. During the fist call, an `AutoQuantizationState` object is created and attached to each non-leaf modules which we need to check for quantizeable operations. We override the `__call__` function to do the following for each module: If the module is an op which needs quantization: 1. calls `_auto_quant_state.validate_cur_op` to validate that the currently seen op is the same as what was recorded during tracing 2. calls parent module's `._auto_quant_state.op_prepare_before_hook` 3. executes the original module forward 4. calls parent module's `_auto_quant_state.op_prepare_after_hook` 5. calls `_auto_quant_state.mark_cur_op_complete` to increment the current op index in preparation for the next op If the module can contain children ops that need quantization: 1. calls `_auto_quant_state.inputs_prepare_hook` (not implemented yet) 2. executes the original module forward 3. calls `_auto_quant_state.outputs_prepare_hook` Otherwise, calls the original module forward. 
""" def __call__(self, *args, **kwargs): new_args = map_aggregate(args, convert_to_interception_proxy) new_kwargs = map_aggregate(kwargs, convert_to_interception_proxy) orig_module_call = torch.nn.Module.__call__ orig_nn_sequential_forward = torch.nn.Sequential.forward def _patched_module_call(self, *args, **kwargs): if enable_logging: fqn = module_id_to_fqn.get(id(self), None) logger.debug(f" fqn:{fqn} _cl_: {type(self)} start") nonlocal cur_module old_module = cur_module cur_module = self try: parent_module = module_stack[-1] if len( module_stack) else None module_stack.append(self) fqn = module_id_to_fqn.get(id(self), None) hook_type = get_module_hook_type(parent_module, cur_module) if first_call and hook_type is not HookType.OP_HOOKS and \ parent_module is not None: parent_qstate_fc = getattr(parent_module, '_auto_quant_state', None) if parent_qstate_fc: parent_qstate_fc.add_seen_op_type_without_op_hooks( type(cur_module)) if hook_type is HookType.OP_HOOKS: parent_qstate: AutoQuantizationState = \ parent_module._auto_quant_state # type: ignore[union-attr, assignment] # before hooks if not first_call: parent_qstate.validate_cur_op(cur_module) # If we are in this hook, `cur_module` is a leaf module. # Therefore, we do not need to override any of its # children. Disabling the overrides for performance. nonlocal global_disable_torch_function_override old_global_disable_torch_function_override = \ global_disable_torch_function_override global_disable_torch_function_override = True # mypy ignore is used instead of assert because this # runs on every forward and assert has a performance cost args, kwargs = parent_qstate.op_prepare_before_hook( cur_module, args, kwargs, first_call, qtensor_id, fqn, cur_module) # type: ignore[arg-type] # original forward output = orig_module_call(self, *args, **kwargs) # Re-enable the overrides. global_disable_torch_function_override = \ old_global_disable_torch_function_override # after hooks # TODO is it correct to call_cur_module twice here? output = parent_qstate.op_prepare_after_hook( cur_module, output, args, first_call, qtensor_id, cur_module, global_op_idx) parent_qstate.mark_cur_op_complete(cur_module) elif hook_type is HookType.MODULE_IO_HOOKS: # TODO(future PR): add inputs io hook cur_qstate = cur_module._auto_quant_state cur_qstate.reset_to_new_call() # original forward output = orig_module_call(self, *args, **kwargs) # after hooks output = cur_qstate.outputs_prepare_hook( output, first_call, qtensor_id) cur_qstate.validate_is_at_last_seen_idx() elif hook_type is HookType.ARG_DEQUANTS: output = orig_module_call(self, *args, **kwargs) # if this fp32 was inplace, make sure to set the output dtype # back to torch.float if hasattr(output, '_qtensor_info'): del output._qtensor_info else: output = orig_module_call(self, *args, **kwargs) if enable_logging: fqn = module_id_to_fqn.get(id(self), None) logger.debug(f" fqn:{fqn} _cl_: {type(self)} end") return output finally: module_stack.pop() cur_module = old_module torch.nn.Module.__call__ = _patched_module_call torch.nn.Sequential.forward = _nn_sequential_patched_forward # type: ignore[assignment] nonlocal first_call try: if first_call: # Create a list before iterating because we are adding new # named modules inside the loop. 
named_modules = list(self.named_modules()) # Record module instances which are leaves or children of leaves leaves = set() for fqn, child in named_modules: if is_leaf(child, prepare_custom_config_dict): for _, child_child in child.named_modules(): leaves.add(child_child) for fqn, v in named_modules: # fqn is the global FQN, i.e. 'foo.bar.baz' # v is the module instance # # we need to associate the global FQN with SeenOp # for modules, this is the module FQN # for functions, this is the parent module FQN module_id_to_fqn[id(v)] = fqn if v in leaves: continue if v is self: # for the top level module only, specify input # and output dtypes v._auto_quant_state = AutoQuantizationState( qconfig_dict, fqn, input_dtypes, output_dtypes) pass else: v._auto_quant_state = AutoQuantizationState( qconfig_dict, fqn) global_op_idx[0] = 0 output = super().__call__(*new_args, **new_kwargs) return output finally: torch.nn.Module.__call__ = orig_module_call torch.nn.Sequential.forward = orig_nn_sequential_forward # type: ignore[assignment] first_call = False model.__class__ = QuantizationInterceptionModule # create the graph trace_with_inputs(model, example_inputs) return model
def to_device(module: t.nn.Module, use_cuda: bool = True): return module.cuda() if use_cuda and t.cuda.is_available() else module
def plot_grad_flow(
        model: torch.nn.Module,
        lines: bool = True,
        alpha: float = 0.5,
        line_width: float = 1.0,
) -> None:
    """
    Plots the gradients flowing through the different layers of the net during training.
    Can be used to check for possible gradient vanishing / exploding problems.

    Usage: after loss.backward(), call plot_grad_flow(model) to visualize the gradient flow of model.

    :param model: the model whose parameter gradients are plotted
    :type model: torch.nn.Module
    :param lines: if True draw line plots, otherwise draw bar plots
    :type lines: bool
    :param alpha: transparency of the plotted lines/bars, in (0, 1]
    :type alpha: float
    :param line_width: width of the plotted lines/bars
    :type line_width: float"""
    assert 0.0 < alpha <= 1.0
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in model.named_parameters():
        if p.requires_grad and ("bias" not in n):
            layers.append(n)
            grad_abs = p.grad.abs()
            ave_grads.append(grad_abs.mean())
            max_grads.append(grad_abs.max())
    if lines:
        pyplot.plot(max_grads, alpha=alpha, linewidth=line_width, color="r")
        pyplot.plot(ave_grads, alpha=alpha, linewidth=line_width, color="g")
    else:
        pyplot.bar(
            numpy.arange(len(max_grads)),
            max_grads,
            alpha=alpha,
            linewidth=line_width,
            color="r",
        )
        pyplot.bar(
            numpy.arange(len(max_grads)),
            ave_grads,
            alpha=alpha,
            linewidth=line_width,
            color="g",
        )
    pyplot.hlines(0, 0, len(ave_grads) + 1, linewidth=1, color="k")
    pyplot.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    pyplot.xlim(left=0, right=len(ave_grads))
    max_g = max(max_grads)
    margin = max_g * 1.1
    pyplot.ylim(
        bottom=max_g - margin, top=margin
    )  # zoom in on the lower gradient regions
    pyplot.xlabel("Layers")
    pyplot.ylabel("Gradient Magnitude")
    pyplot.title("Gradient Flow")
    pyplot.grid(True)
    pyplot.legend(
        [
            Line2D([0], [0], color="r", lw=4),
            Line2D([0], [0], color="g", lw=4),
            Line2D([0], [0], color="k", lw=4),
        ],
        ["max-gradient", "mean-gradient", "zero-gradient"],
    )
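# Hedged usage sketch (not from the original source): call plot_grad_flow right after the
# backward pass and before optimizer.step() / zero_grad(), so that .grad is still populated.
# The toy model, random data and pyplot.show() call are assumptions for illustration.
def _example_plot_grad_flow():
    model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1))
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    plot_grad_flow(model)
    pyplot.show()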
def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module, message: str = "") -> None:
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        assert torch.allclose(p_a, p_b, atol=1e-3), f"Model parameters differ\n{p_a} {p_b}\n" + message
    for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
        assert torch.allclose(b_a, b_b), f"Model buffers differ {b_a} - {b_b}\n" + message
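# Hedged usage sketch (not from the original source): a deepcopy of a model matches it exactly,
# so the check below passes; after an optimizer step on only one of the copies it would fail.
def _example_check_same_model_params():
    import copy
    model = torch.nn.Linear(4, 4)
    replica = copy.deepcopy(model)
    check_same_model_params(model, replica, message="fresh deepcopy should match")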
def generate_param_groups( network: torch.nn.Module, layer_matches: Sequence[Callable], match_types: Sequence[str], lr_values: Sequence[float], include_others: bool = True, ): """ Utility function to generate parameter groups with different LR values for optimizer. The output parameter groups have the same order as `layer_match` functions. Args: network: source network to generate parameter groups from. layer_matches: a list of callable functions to select or filter out network layer groups, for "select" type, the input will be the `network`, for "filter" type, the input will be every item of `network.named_parameters()`. for "select", the parameters will be `select_func(network).parameters()`. for "filter", the parameters will be `map(lambda x: x[1], filter(filter_func, network.named_parameters()))` match_types: a list of tags to identify the matching type corresponding to the `layer_matches` functions, can be "select" or "filter". lr_values: a list of LR values corresponding to the `layer_matches` functions. include_others: whether to include the rest layers as the last group, default to True. It's mainly used to set different LR values for different network elements, for example: .. code-block:: python net = Unet(spatial_dims=3, in_channels=1, out_channels=3, channels=[2, 2, 2], strides=[1, 1, 1]) print(net) # print out network components to select expected items print(net.named_parameters()) # print out all the named parameters to filter out expected items params = generate_param_groups( network=net, layer_matches=[lambda x: x.model[0], lambda x: "2.0.conv" in x[0]], match_types=["select", "filter"], lr_values=[1e-2, 1e-3], ) # the groups will be a list of dictionaries: # [{'params': <generator object Module.parameters at 0x7f9090a70bf8>, 'lr': 0.01}, # {'params': <filter object at 0x7f9088fd0dd8>, 'lr': 0.001}, # {'params': <filter object at 0x7f9088fd0da0>}] optimizer = torch.optim.Adam(params, 1e-4) """ layer_matches = ensure_tuple(layer_matches) match_types = ensure_tuple_rep(match_types, len(layer_matches)) lr_values = ensure_tuple_rep(lr_values, len(layer_matches)) def _get_select(f): def _select(): return f(network).parameters() return _select def _get_filter(f): def _filter(): # should eventually generate a list of network parameters return map(lambda x: x[1], filter(f, network.named_parameters())) return _filter params = [] _layers = [] for func, ty, lr in zip(layer_matches, match_types, lr_values): if ty.lower() == "select": layer_params = _get_select(func) elif ty.lower() == "filter": layer_params = _get_filter(func) else: raise ValueError(f"unsupported layer match type: {ty}.") params.append({"params": layer_params(), "lr": lr}) _layers.extend(list(map(id, layer_params()))) if include_others: params.append({ "params": filter(lambda p: id(p) not in _layers, network.parameters()) }) return params
def train_one_epoch(args, model: torch.nn.Module, criterion: torch.nn.Module, dataloader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, max_norm: float = 0): model.train() criterion.train() metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) header = 'Epoch: [{}]'.format(epoch) print_freq = 50 for samples, targets, support_images, support_class_ids, support_targets in metric_logger.log_every(dataloader, print_freq, header): # * Sample Support Categories; # * Filters Targets (only keep GTs within support categories); # * Samples Support Images and Targets targets, support_images, support_class_ids, support_targets = \ sample_support_categories(args, targets, support_images, support_class_ids, support_targets) samples = samples.to(device) targets = [{k: v.to(device) for k, v in t.items()} for t in targets] support_images = support_images.to(device) support_class_ids = support_class_ids.to(device) support_targets = [{k: v.to(device) for k, v in t.items()} for t in support_targets] outputs = model(samples, targets=targets, supp_samples=support_images, supp_class_ids=support_class_ids, supp_targets=support_targets) loss_dict = criterion(outputs) weight_dict = criterion.weight_dict losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) # reduce losses over all GPUs for logging purposes loss_dict_reduced = utils.reduce_dict(loss_dict) loss_dict_reduced_unscaled = {f'{k}_unscaled': v for k, v in loss_dict_reduced.items()} loss_dict_reduced_scaled = {k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict} losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) loss_value = losses_reduced_scaled.item() if not math.isfinite(loss_value): print("Loss is NaN - {}. \nTraining terminated unexpectedly.\n".format(loss_value)) print("loss dict:") print(loss_dict_reduced) sys.exit(1) optimizer.zero_grad() losses.backward() if max_norm > 0: grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) else: grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm) optimizer.step() metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) metric_logger.update(class_error=loss_dict_reduced['class_error']) metric_logger.update(lr=optimizer.param_groups[0]["lr"]) metric_logger.update(grad_norm=grad_total_norm) # gather the stats from all processes metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) del support_images del support_class_ids del support_targets del samples del targets del outputs del weight_dict del grad_total_norm del loss_value del losses del loss_dict del loss_dict_reduced del loss_dict_reduced_scaled del loss_dict_reduced_unscaled return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def export_gradient_graph(model: torch.nn.Module, loss_fn: Callable[[Any, Any], Any], example_input: torch.Tensor, example_labels: torch.Tensor, gradient_graph_path: Union[Path, str], opset_version=12) -> None: r""" Build a gradient graph for `model` so that you can output gradients in an inference session when given specific input and corresponding labels. Args: model (torch.nn.Module): A gradient graph will be built for this model. loss_fn (Callable[[Any, Any], Any]): A function to compute the loss given the model's output and the `example_labels`. Predefined loss functions such as `torch.nn.CrossEntropyLoss()` will work but you might not be able to load the graph in other environments such as an InferenceSession in ONNX Runtime Web, instead, use a custom Python method. example_input (torch.Tensor): Example input that you would give your model for inference/prediction. example_labels (torch.Tensor): The expected labels for `example_input`. This could be the output of your model when given `example_input` but it might be different if your loss function expects labels to be different (e.g. when using cross entropy loss). gradient_graph_path (Union[Path, str]): The path to where you would like to save the gradient graph. opset_version (int): See `torch.onnx.export`. """ # Make sure that loss nodes that expect multiple outputs are set up. CustomOpSymbolicRegistry.register_all() if not isinstance(gradient_graph_path, str): gradient_graph_path = str(gradient_graph_path) class WrapperModule(torch.nn.Module): def forward(self, model_input, expected_labels, *model_params): for param, set_param in zip(model.parameters(), model_params): param.data = set_param.data output = model(model_input) loss = loss_fn(output, expected_labels) return output, loss wrapped_model = WrapperModule() dynamic_axes = { 'input': { 0: 'batch_size', }, 'labels': { 0: 'batch_size', }, 'output': { 0: 'batch_size', }, } args = (example_input, example_labels, *tuple(model.parameters())) model_param_names = tuple(name for name, _ in model.named_parameters()) input_names = ['input', 'labels', *model_param_names] nodes_needing_gradients = set(name for name, param in model.named_parameters() if param.requires_grad) f = io.BytesIO() torch.onnx.export(wrapped_model, args, f, export_params=True, opset_version=opset_version, do_constant_folding=False, training=TrainingMode.TRAINING, input_names=input_names, output_names=['output', 'loss'], dynamic_axes=dynamic_axes) exported_model = f.getvalue() builder = GradientGraphBuilder(exported_model, {'loss'}, nodes_needing_gradients, 'loss') builder.build() builder.save(gradient_graph_path)
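# Hedged usage sketch (not from the original source): exporting a gradient graph for a tiny
# classifier with the function above. The model, loss, shapes and output path are assumptions
# for illustration; as the docstring notes, a custom Python loss may port better than
# torch.nn.CrossEntropyLoss() to some runtime environments.
def _example_export_gradient_graph():
    model = torch.nn.Linear(10, 3)
    loss_fn = torch.nn.CrossEntropyLoss()
    example_input = torch.randn(2, 10)
    example_labels = torch.tensor([0, 2])
    export_gradient_graph(model, loss_fn, example_input, example_labels, "gradient_graph.onnx")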
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, amp: bool = True, teacher_model: torch.nn.Module = None, teach_loss: torch.nn.Module = None, distill_token: bool=False, choices=None, mode='super', retrain_config=None): model.train() criterion.train() # set random seed random.seed(epoch) metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) header = 'Epoch: [{}]'.format(epoch) print_freq = 10 if mode == 'retrain': config = retrain_config model_module = unwrap_model(model) print(config) model_module.set_sample_config(config=config) print(model_module.get_sampled_params_numel(config)) for samples, targets in metric_logger.log_every(data_loader, print_freq, header): samples = samples.to(device, non_blocking=True) targets = targets.to(device, non_blocking=True) # sample random config if mode == 'super': config = sample_configs(choices=choices) model_module = unwrap_model(model) model_module.set_sample_config(config=config) elif mode == 'retrain': config = retrain_config model_module = unwrap_model(model) model_module.set_sample_config(config=config) if mixup_fn is not None: samples, targets = mixup_fn(samples, targets) if amp: with torch.cuda.amp.autocast(): if teacher_model: with torch.no_grad(): teach_output = teacher_model(samples) _, teacher_label = teach_output.topk(1, 1, True, True) if distill_token: output_cls, output_dis = model(samples) loss = 1/2 * criterion(output_cls, targets) + 1/2 * teach_loss(output_dis, teacher_label.squeeze()) else: outputs = model(samples) loss = 1/2 * criterion(outputs, targets) + 1/2 * teach_loss(outputs, teacher_label.squeeze()) else: outputs = model(samples) loss = criterion(outputs, targets) else: outputs = model(samples) if teacher_model: with torch.no_grad(): teach_output = teacher_model(samples) _, teacher_label = teach_output.topk(1, 1, True, True) loss = 1 / 2 * criterion(outputs, targets) + 1 / 2 * teach_loss(outputs, teacher_label.squeeze()) else: loss = criterion(outputs, targets) loss_value = loss.item() if not math.isfinite(loss_value): print("Loss is {}, stopping training".format(loss_value)) sys.exit(1) optimizer.zero_grad() # this attribute is added by timm on one optimizer (adahessian) if amp: is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order loss_scaler(loss, optimizer, clip_grad=max_norm, parameters=model.parameters(), create_graph=is_second_order) else: loss.backward() optimizer.step() torch.cuda.synchronize() if model_ema is not None: model_ema.update(model) metric_logger.update(loss=loss_value) metric_logger.update(lr=optimizer.param_groups[0]["lr"]) # gather the stats from all processes metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def generic_loop(epochs: int, device: torch.device, opt: torch.optim, loss_fn: torch.nn, model: torch.nn.Module, train_fact_iter, valid_fact_iter, weight_decay: float = 0.0, clip_grads_at: float = -1.0, lr_schedule=None, eval_fn: Callable = None) -> (list, list, list): train_loss = [] train_acc = [] val_acc = [] lrs = [] # Epoch level for e in range(epochs): per_epoch_loss = [] per_epoch_tr_acc = [] # Train with Timer() as timer: # Make data # trn_dl, val_dl = data_fn(data['train']), data_fn(data['valid']) # trn_dl = train_fact_iter for x, y in tqdm(train_fact_iter): # if batch_start_hook: batch_start_hook() opt.zero_grad() if lr_schedule: lrs.append(update_lr(opt, lr_schedule.get())) _x = torch.tensor(x, dtype=torch.long, device=device) _y = torch.tensor(y, dtype=torch.long, device=device) try: y_pred = model(_x, _y) except: print(traceback.print_exc()) return (_x, _y) try: loss = loss_fn(y_pred=y_pred, y_true=_y) except: print(y_pred.shape, _y.shape) return (_x, _y) per_epoch_tr_acc.append( eval_fn(y_pred=y_pred, y_true=_y).item()) per_epoch_loss.append(loss.item()) loss.backward() if clip_grads_at > 0.0: torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grads_at) # for group in opt.param_groups: # for param in group['params']: # param.data = param.data.add(-weight_decay * group['lr'], param.data) opt.step() # Val with torch.no_grad(): per_epoch_vl_acc = [] for x, y in tqdm(valid_fact_iter): _x = torch.tensor(x, dtype=torch.long, device=device) _y = torch.tensor(y, dtype=torch.long, device=device) model.eval() y_pred = model(_x, _y) loss = loss_fn(y_pred=y_pred, y_true=_y) per_epoch_vl_acc.append( eval_fn(y_pred=y_pred, y_true=_y).item()) model.train() # per_epoch_vl_acc.append(loss.item()) # Bookkeep # per_epoch_vl_acc = [0] # @TODO:Remove this once we start calculating accuracy. train_acc.append(np.mean(per_epoch_tr_acc)) train_loss.append(np.mean(per_epoch_loss)) val_acc.append(np.mean(per_epoch_vl_acc)) print( "Epoch: %(epo)03d | Loss: %(loss).5f | Tr_c: %(tracc)0.5f | Vl_c: %(vlacc)0.5f | Time: %(time).3f min" % { 'epo': e, 'loss': float(np.mean(per_epoch_loss)), 'tracc': float(np.mean(per_epoch_tr_acc)), 'vlacc': float(np.mean(per_epoch_vl_acc)), 'time': timer.interval / 60.0 }) return train_acc, train_loss, val_acc, lrs
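# Hedged helper sketch (not part of the original source): generic_loop above calls
# `update_lr(opt, lr_schedule.get())` and records the returned value. A minimal implementation
# consistent with that usage could set the new learning rate on every parameter group and
# return it; the project's actual helper may differ.
def update_lr_sketch(opt: torch.optim.Optimizer, new_lr: float) -> float:
    for group in opt.param_groups:
        group['lr'] = new_lr
    return new_lr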
def save_samples(generator: torch.nn.Module, cp_name: str, cuda_mode: bool, prefix: str,
                 save_dir='./', nc=3, im_size=64, fig_size=(5, 5), enhance=True, SNGAN=False):
    generator.eval()

    n_tests = fig_size[0] * fig_size[1]

    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()

    if SNGAN:
        noise = torch.randn(n_tests, 128).view(-1, 128, 1, 1)
    else:
        noise = torch.randn(n_tests, 100).view(-1, 100, 1, 1)

    if cuda_mode:
        noise = noise.cuda()

    gen_image = generator(noise).view(-1, nc, im_size, im_size)
    gen_image = denorm(gen_image)

    n_cols, n_rows = fig_size
    fig, axes = plt.subplots(n_cols, n_rows, figsize=(n_rows, n_cols))
    for ax, img in zip(axes.flatten(), gen_image):
        ax.axis('off')
        ax.set_adjustable('box')  # 'box-forced' was removed in recent matplotlib versions
        img = img.cpu().data
        if enhance:
            img_E = ImageEnhance.Sharpness(to_pil(img)).enhance(1.)
            img = to_tensor(img_E)
        # Scale to 0-255
        img = (((img - img.min()) * 255) / (img.max() - img.min())).numpy().transpose(1, 2, 0).astype(np.uint8).squeeze()
        if nc == 1:
            ax.imshow(img, cmap="gray", aspect='equal')
        else:
            ax.imshow(img, cmap=None, aspect='equal')
    plt.subplots_adjust(wspace=0, hspace=0)

    # Save the figure.
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    save_fn = os.path.join(save_dir, prefix + '_' + cp_name + '.pdf')
    plt.savefig(save_fn)

    plt.close()
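# Minimal usage sketch for save_samples. The toy generator below is illustrative only; any
# module mapping (N, 100, 1, 1) noise to (N, nc, im_size, im_size) images in [-1, 1] works,
# since the function denormalizes, tiles and saves the samples itself.
toy_generator = torch.nn.Sequential(
    torch.nn.ConvTranspose2d(100, 64, kernel_size=8),           # (N, 64, 8, 8)
    torch.nn.ReLU(),
    torch.nn.ConvTranspose2d(64, 3, kernel_size=8, stride=8),   # (N, 3, 64, 64)
    torch.nn.Tanh(),
)
save_samples(toy_generator, cp_name='epoch_010', cuda_mode=False, prefix='toy',
             save_dir='./samples/', nc=3, im_size=64, fig_size=(3, 3))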
def plot_predictions(model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, num_samples: int = 1,
                     shade: bool = False, x_names: List = None, y_names: List = None,
                     axes=None, fig=None, save_filename=None):
    """
    Plot the model predictions and a confidence band.

    :param model: model exposing ``predict``, ``sample``, ``x_train`` and ``y_train``
    :param x: input features
    :param y: ground-truth targets
    :param num_samples: number of posterior samples to draw (<= 0 disables sampling)
    :param shade: whether to shade the band between mean +/- std
    :param x_names: list of names for the features
    :param y_names: list of names for the targets
    :param axes: optional existing axes to draw on
    :param fig: optional existing figure (used for layout and saving)
    :param save_filename: if given, save the figure to this path
    :return: (fig, axes)
    """
    # Plot the prediction of the GP
    mean_cond, cov_cond = model.predict(x)

    # Sample some predictions
    if num_samples <= 0:
        samples = None
    else:
        samples = model.sample(x, num_samples=num_samples)

    if len(mean_cond.shape) == 1:
        mean_cond = mean_cond.unsqueeze(1)

    num_features = x.shape[1]

    mean_add_std = mean_cond + cov_cond.mm(torch.ones((mean_cond.shape[0], 1)).to(DEVICE))
    mean_sub_std = mean_cond - cov_cond.mm(torch.ones((mean_cond.shape[0], 1)).to(DEVICE))

    x_train = model.x_train
    y_train = model.y_train

    if axes is None:
        fig, axes = plt.subplots(num_features, 1,
                                 figsize=(3 * 1, 3 * num_features),
                                 squeeze=False, sharey=False, sharex=False)
    else:
        if len(axes.shape) == 1:
            axes = np.expand_dims(axes, 1)
        if num_features == 1:
            axes = np.asarray(axes)

    for ix in range(num_features):
        axes[ix, 0].plot(x_train[:, ix].cpu().data.numpy(), y_train.cpu().data.numpy(),
                         marker='.', ms=2, label="Training", linestyle='')
        if samples is not None:
            axes[ix, 0].plot(x[:, ix].cpu().data.numpy(), samples.cpu().data.numpy(),
                             ms=4, marker=".", linestyle="", label="Sample")
        axes[ix, 0].plot(x[:, ix].cpu().data.numpy(), mean_cond.cpu().data.numpy(),
                         marker='o', ms=4, label="Pred.", linestyle="")
        axes[ix, 0].plot(x[:, ix].cpu().data.numpy(), y.cpu().data.numpy(),
                         '.', ms=4, label="Ground", linestyle='')
        if shade:
            ind_sort = x[:, ix].sort(descending=False)[-1]
            axes[ix, 0].fill_between(
                x[ind_sort, ix].cpu().data.numpy(),
                mean_sub_std.cpu().data.numpy().squeeze()[ind_sort],
                mean_add_std.cpu().data.numpy().squeeze()[ind_sort],
                color="#dddddd")
        if x_names is not None:
            axes[ix, 0].set_xlabel(f"{x_names[ix]}")
        if y_names is not None:
            axes[ix, 0].set_ylabel(f"{y_names}")

    axes[-1, 0].legend()
    if fig is not None:
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    if save_filename is not None:
        fig.savefig(save_filename, bbox_inches='tight', format='png', dpi=200)
        plt.close(fig)
    return fig, axes
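# Hedged usage sketch for plot_predictions: the stub below only mimics the interface the
# function relies on (predict/sample methods plus x_train/y_train attributes) and is not a
# real GP; shapes and names are illustrative. DEVICE is assumed to be defined in this module,
# as the function itself uses it.
class _StubGP:
    def __init__(self, x, y):
        self.x_train, self.y_train = x, y

    def predict(self, x):
        mean = x.sum(dim=1, keepdim=True)               # (N, 1) fake posterior mean
        cov = 0.1 * torch.eye(x.shape[0]).to(DEVICE)    # (N, N) fake posterior covariance
        return mean, cov

    def sample(self, x, num_samples=1):
        return self.predict(x)[0] + 0.05 * torch.randn(x.shape[0], num_samples).to(DEVICE)

_xs = torch.linspace(0, 1, 50).unsqueeze(1).to(DEVICE)
_ys = _xs.squeeze(1) + 0.1 * torch.randn(50).to(DEVICE)
fig, axes = plot_predictions(_StubGP(_xs, _ys), _xs, _ys, num_samples=3,
                             shade=False, x_names=['x0'], y_names=['y'])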
def evaluate_cor( model: torch.nn.Module, test_file: ty.Union[str, pathlib.Path], span_digitizer: ty.Callable[[ty.Mapping[str, ty.Any]], datatools.FeaturefulSpan], pair_feats_digitizer: ty.Callable[[ty.Mapping[str, str]], ty.Iterable[int]], loss_fun: ty.Callable = libdecofre.masked_multi_cross_entropy, device=None, num_workers: int = 0, ) -> ty.Tuple[torch.FloatTensor]: model = model.to(device) with tempfile.TemporaryDirectory( prefix="decofre_antecedents_") as temp_dir: test_set = datatools.AntecedentsDataset.from_json( test_file, span_digitizer=span_digitizer, pair_feats_digitizer=pair_feats_digitizer, cache_dir=temp_dir, set_name="test", ) test_loader = torch.utils.data.DataLoader( dataset=test_set, sampler=torch.utils.data.BatchSampler( torch.utils.data.RandomSampler(test_set), batch_size=8, drop_last=False), collate_fn=lambda x: x[0], num_workers=num_workers, ) logger.info("Evaluating on the test set") model.eval() evaluator = runners.Evaluator( model, loss=loss_fun, metrics={ "antecedent_accuracy": runners.MultiLoss( loss_fn=lambda x, y: _summable_antecedent_accuracy( x.to(device), y.to(device)), output_transform=runners.extract_output, device=device, loss_names=( "total_accuracy", "mention_new_accuracy", "anaphora_accuracy", ), ), "attributions": runners.MultiLoss( loss_fn=lambda x, y: attributions(x.to(device), y.to(device )), output_transform=runners.extract_output, averaged=False, device=device, loss_names=( "true_new", "false_new", "correct_link", "false_link", "wrong_link", ), ), }, ) state = evaluator.run(test_loader) for name, value in state.metrics.items(): logger.info(f"{name}: {value}") return state.metrics["loss"]
def train_one_epoch( cls, model: torch.nn.Module, iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]], optimizers: Sequence[torch.optim.Optimizer], schedulers: Sequence[Optional[AbsScheduler]], reporter: SubReporter, options: TrainerOptions, ) -> bool: assert check_argument_types() # Note(kamo): assumes one optimizer assert cls.num_optimizers == 1, cls.num_optimizers assert len(optimizers) == 1, len(optimizers) optimizer = optimizers[0] scheduler = schedulers[0] grad_noise = options.grad_noise accum_grad = options.accum_grad grad_clip = options.grad_clip log_interval = options.log_interval no_forward_run = options.no_forward_run ngpu = options.ngpu distributed = isinstance(model, torch.nn.parallel.DistributedDataParallel) use_apex = options.train_dtype in ("O0", "O1", "O2", "O3") if log_interval is None: try: log_interval = max(len(iterator) // 20, 10) except TypeError: log_interval = 100 model.train() all_steps_are_invalid = True # [For distributed] Because iteration counts are not always equals between # processes, send stop-flag to the other processes if iterator is finished iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu") start_time = time.perf_counter() for iiter, (_, batch) in enumerate( reporter.measure_iter_time(iterator, "iter_time"), 1): assert isinstance(batch, dict), type(batch) if distributed: torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM) if iterator_stop > 0: break batch = to_device(batch, "cuda" if ngpu > 0 else "cpu") if no_forward_run: all_steps_are_invalid = False reporter.register({}) continue with reporter.measure_time("forward_time"): loss, stats, weight = model(**batch) if ngpu > 1 or distributed: # Apply weighted averaging for loss and stats loss = (loss * weight.type(loss.dtype)).sum() # if distributed, this method can also apply all_reduce() stats, weight = recursive_average(stats, weight, distributed) # Now weight is summation over all workers loss /= weight if distributed: # NOTE(kamo): Multiply world_size because DistributedDataParallel # automatically normalizes the gradient by world_size. loss *= torch.distributed.get_world_size() reporter.register(stats, weight) loss /= accum_grad with reporter.measure_time("backward_time"): if use_apex: try: from apex import amp except ImportError: logging.error( "You need to install apex. " "See https://github.com/NVIDIA/apex#linux") with amp.scale_loss(loss, optimizers) as scaled_loss: scaled_loss.backward() else: loss.backward() if iiter % accum_grad == 0: # gradient noise injection if grad_noise: add_gradient_noise( model, reporter.get_total_count(), duration=100, eta=1.0, scale_factor=0.55, ) # compute the gradient norm to check if it is normal or not grad_norm = torch.nn.utils.clip_grad_norm_( model.parameters(), grad_clip) # PyTorch<=1.4, clip_grad_norm_ returns float value if not isinstance(grad_norm, torch.Tensor): grad_norm = torch.tensor(grad_norm) if not torch.isfinite(grad_norm): logging.warning( f"The grad norm is {grad_norm}. Skipping updating the model." ) else: all_steps_are_invalid = False with reporter.measure_time("optim_step_time"): optimizer.step() if isinstance(scheduler, AbsBatchStepScheduler): scheduler.step() optimizer.zero_grad() # Register lr and train/load time[sec/step], # where step refers to accum_grad * mini-batch reporter.register( dict( { f"lr_{i}": pg["lr"] for i, pg in enumerate(optimizer.param_groups) if "lr" in pg }, train_time=time.perf_counter() - start_time, ), # Suppress to increment the internal counter. 
not_increment_count=True, ) start_time = time.perf_counter() if iiter % log_interval == 0: logging.info(reporter.log_message()) else: if distributed: iterator_stop.fill_(1) torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM) return all_steps_are_invalid
def train_cor( model: torch.nn.Module, train_file: ty.Union[str, pathlib.Path], span_digitizer: ty.Callable[[ty.Mapping[str, ty.Any]], datatools.FeaturefulSpan], pair_feats_digitizer: ty.Callable[[ty.Mapping[str, str]], ty.Iterable[int]], out_dir: ty.Union[str, pathlib.Path], temp_dir: ty.Union[str, pathlib.Path], device: torch.device, epochs: int, patience: int, train_batch_size: int = 8, dev_file: ty.Optional[ty.Union[str, pathlib.Path]] = None, optimizer=None, loss_fun: ty.Callable = libdecofre.masked_multi_cross_entropy, trainer_cls=runners.SinkTrainer, num_workers: int = 0, debug: bool = False, config: ty.Optional[ty.Dict[str, ty.Any]] = None, **kwargs, ) -> ty.Tuple[ignite.engine.Engine, ty.Iterable, ty.Dict[str, ty.Any]]: logger.info("Training antecedent scoring") model = model.to(device) train_set = datatools.AntecedentsDataset.from_json( train_file, span_digitizer=span_digitizer, pair_feats_digitizer=pair_feats_digitizer, cache_dir=temp_dir, set_name="train_cor", ) train_loader = torch.utils.data.DataLoader( dataset=train_set, sampler=torch.utils.data.BatchSampler( torch.utils.data.RandomSampler(train_set), batch_size=train_batch_size, drop_last=False, ), collate_fn=lambda x: x[0], num_workers=num_workers, ) if dev_file is not None: dev_set = datatools.AntecedentsDataset.from_json( dev_file, span_digitizer=span_digitizer, pair_feats_digitizer=pair_feats_digitizer, cache_dir=temp_dir, set_name="dev_cor", ) dev_loader = torch.utils.data.DataLoader( dataset=dev_set, sampler=torch.utils.data.BatchSampler( torch.utils.data.RandomSampler(dev_set), batch_size=train_batch_size, drop_last=False, ), collate_fn=lambda x: x[0], num_workers=num_workers, ) # type: ty.Optional[torch.data.DataLoader] else: dev_loader = None cor_trainer = trainer_cls( model, checkpointed_models={"cor": model}, loss_fun=loss_fun, optimizer=optimizer, dev_loss=loss_fun, dev_metrics={ "antecedent_accuracy": runners.MultiLoss( loss_fn=lambda x, y: _summable_antecedent_accuracy( x.to(device), y.to(device)), num_loss=3, output_transform=runners.extract_output, device=device, loss_names=( "total_accuracy", "mention_new_accuracy", "anaphora_accuracy", ), ), "attributions": runners.MultiLoss( loss_fn=lambda x, y: attributions(x.to(device), y.to(device)), output_transform=runners.extract_output, averaged=False, device=device, loss_names=( "true_new", "false_new", "correct_link", "false_link", "wrong_link", ), ), }, save_path=out_dir, debug=debug, **kwargs, ) if config["lr-schedule"] == "step": logger.debug("Training with 'step' LR schedule, using γ=0.95") torch_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, len(train_loader), gamma=0.95) scheduler = ignite.contrib.handlers.create_lr_scheduler_with_warmup( torch_lr_scheduler, warmup_start_value=0.0, warmup_end_value=optimizer.defaults["lr"], warmup_duration=1000, ) cor_trainer.add_event_handler(ignite.engine.Events.ITERATION_STARTED, scheduler) return ( cor_trainer, train_loader, { "max_epochs": epochs, "patience": patience, "dev_loader": dev_loader, "run_name": "antecedent_scoring", }, )
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, max_norm: float = 0): model.train() criterion.train() metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter( 'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) metric_logger.add_meter( 'class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) header = 'Epoch: [{}]'.format(epoch) print_freq = 10 for samples, targets in metric_logger.log_every(data_loader, print_freq, header): samples = samples.to(device) targets = [{k: v.to(device) for k, v in t.items()} for t in targets] outputs = model(samples) loss_dict = criterion(outputs, targets) weight_dict = criterion.weight_dict losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) # reduce losses over all GPUs for logging purposes loss_dict_reduced = utils.reduce_dict(loss_dict) loss_dict_reduced_unscaled = { f'{k}_unscaled': v for k, v in loss_dict_reduced.items() } loss_dict_reduced_scaled = { k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict } losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) loss_value = losses_reduced_scaled.item() if not math.isfinite(loss_value): print("Loss is {}, stopping training".format(loss_value)) print(loss_dict_reduced) sys.exit(1) optimizer.zero_grad() losses.backward() if max_norm > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) optimizer.step() metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) metric_logger.update(class_error=loss_dict_reduced['class_error']) metric_logger.update(lr=optimizer.param_groups[0]["lr"]) # gather the stats from all processes metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def count_parameters(model: torch.nn.Module): return sum(p.numel() for p in model.parameters() if p.requires_grad)
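# Quick sanity check for count_parameters (throwaway model, arbitrary layer sizes): frozen
# parameters are excluded because their requires_grad flag is False.
_toy = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.Linear(20, 5))
_toy[0].weight.requires_grad = False       # freeze the first weight matrix (10 * 20 = 200 params)
assert count_parameters(_toy) == (10 * 20 + 20) + (20 * 5 + 5) - 200   # 125 trainable params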
def __init__(self, model: torch.nn.Module, trained_param_path: Path, cuda: int):
    # `cuda` is a device index; a negative value selects the CPU.
    self.device = torch.device(f'cuda:{cuda}' if cuda >= 0 else 'cpu')
    self.trained_param_path = trained_param_path
    self.model = model.to(self.device)
    self._load_param()
    self.model.eval()
def fit(self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, criterion, train_dataloader: DataLoader, val_dataloader: DataLoader, epochs: int, metrics=None, device='cpu', **kwargs): self.device = device self.model = model.to(self.device) self.optimizer = optimizer self.criterion = criterion self.train_dataloader = train_dataloader self.val_dataloader = val_dataloader self.metrics = metrics or dict() self.epochs = epochs self.config = { 'model': repr(self.model), 'optimizer': repr(self.optimizer), 'criterion': repr(self.criterion), 'metrics': ', '.join(metric_name for metric_name in self.metrics.keys()), 'epochs': self.epochs, } self.train_losses = [] self.val_losses = [] for func in self.on_fit_start: func(self, self.config) try: with tqdm.tqdm(range(self.epochs), desc="Training epochs", unit="epoch") as epoch_progress_bar: for epoch in epoch_progress_bar: for func in self.on_epoch_start: func(self, epoch) # train step self.model.train() with tqdm.tqdm(self.train_dataloader, desc="Train", unit="batch", leave=False) as train_progress_bar: for batch_idx, ( batch_x, batch_y) in enumerate(train_progress_bar): for func in self.on_training_batch_start: func(self, batch_x, batch_y) self.optimizer.zero_grad() loss, out = self.criterion(model, batch_x.to(self.device), batch_y.to(self.device)) loss.backward() self.optimizer.step() self.train_losses.append(loss.item()) for func in self.on_training_batch_end: func(self, batch_x, batch_y, loss) # validation step self.model.eval() metric_results = defaultdict(list) with tqdm.tqdm(self.val_dataloader, desc='Validation', unit='batch', leave=False) as validation_progress_bar: for batch_idx, ( batch_x, batch_y) in enumerate(validation_progress_bar): for func in self.on_validation_batch_start: func(self, batch_x, batch_y) with torch.no_grad(): loss, out = self.criterion( model, batch_x.to(self.device), batch_y.to(self.device)) self.val_losses.append(loss.item()) for metric_name, metric_func in self.metrics.items( ): metric_results[metric_name].append( metric_func(batch_y, out)) for func in self.on_validation_batch_end: func(self, batch_x, batch_y, loss, metric_results) for func in self.on_epoch_end: func(self, epoch) except KeyboardInterrupt: print('Keyboard Interruption, teardown') finally: for func in self.on_fit_end: func(self) return model
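# The fit method above calls the criterion as criterion(model, batch_x, batch_y) and expects
# a (loss, output) pair back. A minimal compatible wrapper might look like the sketch below
# (hypothetical helper, not part of the original trainer):
def cross_entropy_criterion(model: torch.nn.Module, batch_x: torch.Tensor, batch_y: torch.Tensor):
    out = model(batch_x)                                    # forward pass
    loss = torch.nn.functional.cross_entropy(out, batch_y)  # standard classification loss
    return loss, out                                        # fit() consumes (loss, output)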
def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, data_loader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, teacher=None, set_training_mode=True): # TODO fix this for finetuning # model.train(set_training_mode) model.train() metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter( 'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) header = 'Epoch: [{}]'.format(epoch) print_freq = 100 for samples, targets in metric_logger.log_every(data_loader, print_freq, header): samples = samples.to(device, non_blocking=True) targets = targets.to(device, non_blocking=True) if mixup_fn is not None: samples, targets = mixup_fn(samples, targets) samples, targets, mix_rate, aux_targets = two_mix( samples, targets, num_patch=samples.shape[-1] // 16) with torch.cuda.amp.autocast(): # outputs, r_loss = model(samples) outputs, r_loss, s_loss, proj = model(samples, aux_targets) loss = torch.sum(-targets * (1e-8 + outputs.softmax(dim=-1)).log(), dim=-1).mean() loss_value = loss.item() loss += 1. * (r_loss + 1. * s_loss) if not math.isfinite(loss.item()): print("Loss is {}, stopping training".format(loss_value)) sys.exit(1) optimizer.zero_grad() # this attribute is added by timm on one optimizer (adahessian) is_second_order = hasattr( optimizer, 'is_second_order') and optimizer.is_second_order loss_scaler(loss, optimizer, clip_grad=max_norm, parameters=model.parameters(), create_graph=is_second_order) torch.cuda.synchronize() if model_ema is not None: model_ema.update(model) metric_logger.update(loss=loss_value) metric_logger.update(lr=optimizer.param_groups[0]["lr"]) metric_logger.meters['r'].update(r_loss.item(), n=targets.shape[0]) # metric_logger.meters['p'].update(proj.item(), n=targets.shape[0]) metric_logger.meters['s'].update(s_loss.item(), n=targets.shape[0]) # metric_logger.meters['cos'].update(cos.item(), n=targets.shape[0]) # gather the stats from all processes metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def dump_sender_receiver(game: torch.nn.Module, dataset: 'torch.utils.data.DataLoader', gs: bool, variable_length: bool, device: Optional[torch.device] = None): """ A tool to dump the interaction between Sender and Receiver :param game: A Game instance :param dataset: Dataset of inputs to be used when analyzing the communication :param gs: whether Gumbel-Softmax relaxation was used during training :param variable_length: whether variable-length communication is used :param device: device (e.g. 'cuda') to be used :return: """ train_state = game.training # persist so we restore it back game.eval() device = device if device is not None else common_opts.device sender_inputs, messages, receiver_inputs, receiver_outputs = [], [], [], [] labels = [] with torch.no_grad(): for batch in dataset: # by agreement, each batch is (sender_input, labels) plus optional (receiver_input) sender_input = move_to(batch[0], device) print("sender_input: " + str(sender_input)) print("batch" + str(batch)) print("batch length" + str(len(batch))) receiver_input = None if len(batch) == 2 else move_to( batch[2], device) print("receiver_input: " + str(receiver_input)) message = game.sender(sender_input) # Under GS, the only output is a message; under Reinforce, two additional tensors are returned. # We don't need them. if not gs: message = message[0] output = game.receiver(message, receiver_input) if not gs: output = output[0] if batch[1] is not None: labels.extend(batch[1]) if isinstance(sender_input, list) or isinstance( sender_input, tuple): sender_inputs.extend(zip(*sender_input)) else: sender_inputs.extend(sender_input) if receiver_input is not None: receiver_inputs.extend(receiver_input) if gs: message = message.argmax( dim=-1) # actual symbols instead of one-hot encoded if not variable_length: messages.extend(message) receiver_outputs.extend(output) else: # A trickier part is to handle EOS in the messages. It also might happen that not every message has EOS. # We cut messages at EOS if it is present or return the entire message otherwise. Note, EOS id is always # set to 0. for i in range(message.size(0)): eos_positions = (message[i, :] == 0).nonzero() message_end = eos_positions[0].item( ) if eos_positions.size(0) > 0 else -1 assert message_end == -1 or message[i, message_end] == 0 if message_end < 0: messages.append(message[i, :]) else: messages.append(message[i, :message_end + 1]) if gs: receiver_outputs.append(output[i, message_end, ...]) else: receiver_outputs.append(output[i, ...]) game.train(mode=train_state) return sender_inputs, messages, receiver_inputs, receiver_outputs, labels
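# Hedged call pattern for dump_sender_receiver; `game` and `test_loader` are placeholders
# standing in for a trained EGG game and its evaluation DataLoader:
# sender_inputs, messages, receiver_inputs, receiver_outputs, labels = dump_sender_receiver(
#     game, test_loader, gs=True, variable_length=False, device=torch.device('cpu'))
# for s_in, msg in zip(sender_inputs[:5], messages[:5]):
#     print(s_in, '->', msg)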
def train_model(model: torch.nn.Module,
                device: torch.device,
                optimizer: AdamW,
                file_path: Text,
                train_loader: DataLoader,
                valid_loader: DataLoader,
                label_to_cellid: Dict[int, int],
                num_epochs: int,
                best_valid_loss: float = float("Inf")):
    '''Main function for training the model.'''
    # Initialize running values.
    running_loss = 0.0
    global_step = 0
    valid_accuracy_list, train_loss_list, valid_loss_list = [], [], []
    global_steps_list, true_points_list, pred_points_list = [], [], []

    # Training loop.
    model.train()

    for epoch in range(num_epochs):
        for batch, _ in train_loader:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

            loss = outputs.loss
            loss.mean().backward()
            optimizer.step()

            # Update running values.
            running_loss += loss.mean().item()
            global_step += 1

        # Evaluation step.
        valid_loss, predictions, true_vals, true_points, pred_points = evaluate(
            model, valid_loader, device, label_to_cellid)

        # Average over the number of batches in the epoch.
        average_train_loss = running_loss / len(train_loader)
        accuracy = accuracy_cells(true_vals, predictions)
        train_loss_list.append(average_train_loss)
        valid_loss_list.append(valid_loss)
        global_steps_list.append(global_step)
        valid_accuracy_list.append(accuracy)
        true_points_list.append(true_points)
        pred_points_list.append(pred_points)

        # Reset running values.
        running_loss = 0.0

        logging.info('Epoch [{}/{}], Step [{}/{}], Accuracy: {:.4f}, Train Loss: {:.4f}, Valid Loss: {:.4f}'.format(
            epoch + 1, num_epochs, global_step, num_epochs * len(train_loader),
            accuracy, average_train_loss, valid_loss))

        # Save model and results in checkpoint.
        if best_valid_loss > valid_loss:
            best_valid_loss = valid_loss
            util.save_checkpoint(file_path + '/' + 'model.pt', model, best_valid_loss)
            util.save_metrics(file_path + '/' + 'metrics.pt', train_loss_list, valid_loss_list,
                              global_steps_list, valid_accuracy_list, true_points_list,
                              pred_points_list)

        model.train()

    logging.info('Finished Training.')
def train(classifier: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, test_x: torch.Tensor = torch.Tensor(), test_y: torch.Tensor = torch.Tensor(), batch_size: int = 16, num_epochs: int = 2, run_device: str = "cpu", learning_rate: float = 0.001, beta_1: float = 0.9, beta_2: float = 0.999, random_state: torch.ByteTensor = torch.get_rng_state().clone(), verbose: bool = False) -> Tuple[torch.nn.Module, torch.ByteTensor]: """ Function to train classifiers and save the trained classifiers. Parameters ---------- classifier: torch.nn.Module x: torch.Tensor y: torch.Tensor test_x: torch.Tensor test_y: torch.Tensor batch_size: int num_epochs: int run_device: str learning_rate: float beta_1: float beta_2: float random_state: torch.ByteTensor verbose: bool Returns ------- Tuple[torch.nn.Module, torch.Tensor] """ assert isinstance(classifier, torch.nn.Module) assert isinstance(x, torch.Tensor) assert isinstance(y, torch.Tensor) assert isinstance(test_x, torch.Tensor) assert isinstance(test_y, torch.Tensor) assert isinstance(batch_size, int) and (batch_size > 0) assert isinstance(num_epochs, int) and (num_epochs > 0) assert isinstance(run_device, str) and (run_device.lower() in ["cpu", "cuda"]) assert isinstance(learning_rate, float) and (learning_rate > 0.0) assert isinstance(beta_1, float) and (0.0 <= beta_1 < 1.0) assert isinstance(beta_2, float) and (0.0 <= beta_2 < 1.0) assert isinstance(random_state, torch.ByteTensor) assert isinstance(verbose, bool) # Set the seed for generating random numbers. random_state_previous: torch.ByteTensor = torch.get_rng_state().clone() torch.set_rng_state(random_state) # Set the classifier. classifier_train: torch.nn.Module = copy.deepcopy(classifier.cpu()) run_device_train: str = run_device.lower() if run_device_train == "cuda": assert torch.cuda.is_available() classifier_train = classifier_train.cuda() if torch.cuda.device_count() > 1: num_gpus: int = torch.cuda.device_count() classifier_train = torch.nn.DataParallel(classifier_train, device_ids=list( range(0, num_gpus))) # Set a criterion and optimizer. criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(params=classifier_train.parameters(), lr=learning_rate, betas=(beta_1, beta_2)) # Covert PyTorch's Tensor to TensorDataset. x_train, y_train = x.clone(), y.clone() dataset_train = torch.utils.data.TensorDataset(x_train, y_train) dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, num_workers=0, shuffle=True) has_test: bool = False if (test_x.size(0) > 0) and (test_y.size(0) > 0): x_test, y_test = test_x.clone(), test_y.clone() dataset_test = torch.utils.data.TensorDataset(x_test, y_test) dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, num_workers=0, shuffle=False) has_test = True # Initialize the early_stopping object. early_stopping: _EarlyStopping = _EarlyStopping(patience=10, delta=0.0, verbose=False) log_template: str = "[{0}/{1}] Loss: {2:.4f}, Time: {3:.2f}s" log_template_test: str = "[{0}/{1}] Loss (Train): {2:.4f}, Loss (Test): {3:.4f}, Time: {4:.2f}s" list_loss: list = list() list_loss_test: list = list() # Train the classifiers. 
    classifier_train.train()
    for epoch in range(1, num_epochs + 1):
        start_time: float = time.time()
        for (_, batch) in enumerate(dataloader_train, 0):
            batch_x, batch_y = batch
            if run_device_train == "cuda":
                batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

            optimizer.zero_grad()
            output: torch.Tensor = classifier_train(batch_x)
            loss: torch.Tensor = criterion(output, batch_y)
            loss.backward()
            optimizer.step()

            list_loss.append(loss.detach().cpu().item())
        end_time: float = time.time()

        if has_test:
            # Evaluate on the held-out set without tracking gradients.
            classifier_train.eval()
            with torch.no_grad():
                for (_, batch) in enumerate(dataloader_test, 0):
                    batch_x, batch_y = batch
                    if run_device_train == "cuda":
                        batch_x, batch_y = batch_x.cuda(), batch_y.cuda()
                    output = classifier_train(batch_x)
                    loss = criterion(output, batch_y)
                    list_loss_test.append(loss.detach().cpu().item())
            classifier_train.train()
            early_stopping(loss=np.mean(list_loss_test), model=classifier_train)
        else:
            early_stopping(loss=np.mean(list_loss), model=classifier_train)

        if verbose:
            if has_test:
                print(log_template_test.format(epoch, num_epochs, np.mean(list_loss),
                                               np.mean(list_loss_test), end_time - start_time))
            else:
                print(log_template.format(epoch, num_epochs, np.mean(list_loss),
                                          end_time - start_time))

        if early_stopping.early_stop:
            # Restore the best model and RNG state seen so far, then stop.
            state_dict, rng_state = early_stopping.get_best_model()
            classifier_train.load_state_dict(state_dict)
            torch.set_rng_state(rng_state)
            break

    if isinstance(classifier_train, torch.nn.DataParallel):
        classifier_train = classifier_train.module

    random_state_after: torch.ByteTensor = torch.get_rng_state().clone()
    torch.set_rng_state(random_state_previous)

    return classifier_train.cpu(), random_state_after.clone()
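# Minimal smoke test for train() above, using a throwaway linear classifier and random data
# (all names below are illustrative):
_clf = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(8, 3))
_x = torch.randn(64, 8)
_y = torch.randint(0, 3, (64,))
trained_clf, rng_state = train(_clf, _x, _y, batch_size=16, num_epochs=2,
                               run_device="cpu", verbose=True)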