def after_loss_fn_new(_input: torch.Tensor, _label: torch.Tensor, _output: torch.Tensor,
                      loss: torch.Tensor, optimizer: Optimizer,
                      loss_fn: Callable[..., torch.Tensor] = None,
                      amp: bool = False, scaler: torch.cuda.amp.GradScaler = None, **kwargs):
    """Adversarial-training hook run after the clean loss of a batch.

    For ``self.pgd.iteration`` rounds it:

    * steps ``optimizer`` (through ``scaler`` when ``amp`` is set);
    * runs one PGD iteration with the model in eval mode to produce ``adv_x``;
    * recomputes ``loss = loss_fn(adv_x, _label)`` back in train mode;
    * passes that loss to ``after_loss_fn_old`` if it is callable;
    * back-propagates it (scaled by ``scaler`` when ``amp`` is set).

    If ``self.pgd.iteration`` is 0 nothing happens.

    ``self``, ``adv_train_epsilon`` and ``after_loss_fn_old`` are free variables
    from the enclosing scope, which is not shown here.

    :param _input: clean input batch; the perturbation starts as zeros of its shape
    :param _label: labels, bound into ``self.adv_loss`` via ``functools.partial``
    :param _output: not used here except being forwarded to ``after_loss_fn_old``
    :param loss: the incoming value is never read; it is overwritten each round
    :param optimizer: stepped once at the start of every round
    :param loss_fn: called as ``loss_fn(adv_x, _label)`` to get the adversarial loss
    :param amp: if True, use ``scaler.step``/``scaler.scale(...).backward()``
        instead of ``optimizer.step()``/``loss.backward()``
    :param scaler: GradScaler used when ``amp`` is True
    :param kwargs: forwarded unchanged to ``after_loss_fn_old``
    """
    noise = torch.zeros_like(_input)
    # Bind the labels into the attack's loss function.
    adv_loss_fn = functools.partial(self.adv_loss, _label=_label)
    for m in range(self.pgd.iteration):
        # NOTE(review): the first step consumes whatever gradients are already
        # stored when the hook is entered (presumably from the caller's backward
        # on the clean loss -- confirm against the caller). Gradients are never
        # zeroed inside this loop, so later steps use gradients accumulated over
        # all previous backward passes -- confirm this is intended.
        if amp:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        # One PGD iteration with the model in eval mode. The same ``noise`` tensor
        # is passed every round (presumably updated in place by ``optimize`` --
        # TODO confirm).
        self.eval()
        adv_x, _ = self.pgd.optimize(_input=_input, noise=noise, loss_fn=adv_loss_fn, iteration=1, epsilon=adv_train_epsilon)
        self.train()
        loss = loss_fn(adv_x, _label)
        if callable(after_loss_fn_old):
            after_loss_fn_old(_input=_input, _label=_label, _output=_output, loss=loss, optimizer=optimizer, loss_fn=loss_fn, amp=amp, scaler=scaler, **kwargs)
        # These gradients are consumed by the next round's step. The last round's
        # gradients are left for the caller (presumably it steps the optimizer).
        if amp:
            scaler.scale(loss).backward()
        else:
            loss.backward()
def process_batch(model: nn.Module, batch: Tuple, criterion: nn.modules.loss._Loss,
                  optimizer: torch.optim.Optimizer, scaler: torch.cuda.amp.GradScaler,
                  train: bool, device: torch.device) \
        -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[torch.Tensor]]:
    """
    Run the model on one batch and, when training, take one optimizer step.

    The forward pass runs under CUDA autocast. The autograd graph is only
    recorded when ``train`` is True, so evaluation does not keep activations
    alive for a backward pass that never happens.

    :param model: model to train; its output is permuted with ``permute(2, 0, 1)``
        before being passed to ``criterion``
    :param batch: tuple ``(input_lengths, target_lengths, waveforms, spectrograms,
        sample_rates, utterances)``; ``sample_rates`` is not used
    :param criterion: criterion called as ``criterion(logprobs.permute(2, 0, 1),
        utterances, input_lengths, target_lengths)``
    :param optimizer: optimizer to step
    :param scaler: GradScaler for mixed precision training
    :param train: perform a gradient step (zero_grad, scaled backward, step, update)
    :param device: device that spectrograms and utterances are moved to
    :return: (loss, logprobs, utterances, waveforms); the first three are detached
        tensors, ``waveforms`` is returned unchanged from ``batch``
    """
    input_lengths, target_lengths, waveforms, spectrograms, _sample_rates, utterances = batch
    spectrograms = spectrograms.to(device)
    utterances = utterances.to(device)

    # Record the graph only when we will backpropagate through it. This also makes
    # training work when the caller happens to be inside ``torch.no_grad()``.
    with torch.set_grad_enabled(train):
        with torch.cuda.amp.autocast():
            logprobs = model(spectrograms)
            loss = criterion(logprobs.permute(2, 0, 1), utterances, input_lengths, target_lengths)

        if train:
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

    return loss.detach(), logprobs.detach(), utterances.detach(), waveforms
def after_loss_fn_new(_input: torch.Tensor, _label: torch.Tensor, _output: torch.Tensor,
                      loss: torch.Tensor, optimizer: Optimizer,
                      loss_fn: Callable[..., torch.Tensor] = None,
                      amp: bool = False, scaler: torch.cuda.amp.GradScaler = None, **kwargs):
    """Adversarial-training hook run after the clean loss of a batch.

    It first clears the gradients held by ``optimizer`` and by ``self``, and
    resets ``pre_conditioner`` when it is not None. Then it branches:

    * ``self.adv_train == 'free'``: builds an initial perturbation with
      ``self.pgd.init_noise`` (random init controlled by
      ``self.adv_train_random_init``). For ``self.adv_train_iter`` rounds it
      back-propagates ``loss_fn(adv_x, _label)``, steps ``optimizer``, clears
      the gradients and advances ``adv_x`` by one PGD iteration.
    * otherwise: ``loss = self.adv_loss(_input=_input, _label=_label, loss_fn=loss_fn)``.

    Finally it back-propagates the resulting ``loss`` (scaled by ``scaler``
    when ``amp`` is set) and passes it to ``after_loss_fn_old`` if that is
    callable.

    ``self``, ``pre_conditioner``, ``add_noise`` and ``after_loss_fn_old`` come
    from the enclosing scope/module, which is not shown here. ``self`` is
    presumably the model wrapper; its class is not visible.

    :param _input: clean input batch
    :param _label: labels for both the training loss and the PGD target
    :param _output: not used here except being forwarded to ``after_loss_fn_old``
    :param loss: normally overwritten. NOTE(review): in the 'free' branch with
        ``self.adv_train_iter == 0`` this incoming value is what gets
        back-propagated at the end.
    :param optimizer: zeroed on entry; stepped once per 'free' round
    :param loss_fn: called as ``loss_fn(adv_x, _label)``
    :param amp: if True, backward/step go through ``scaler``
    :param scaler: GradScaler used when ``amp`` is True
    :param kwargs: forwarded unchanged to ``after_loss_fn_old``
    """
    optimizer.zero_grad()
    self.zero_grad()
    if pre_conditioner is not None:
        pre_conditioner.reset()
    if self.adv_train == 'free':
        noise = self.pgd.init_noise(_input.shape, pgd_eps=self.adv_train_eps, random_init=self.adv_train_random_init, device=_input.device)
        adv_x = add_noise(x=_input, noise=noise, universal=self.pgd.universal, clip_min=self.pgd.clip_min, clip_max=self.pgd.clip_max)
        # Re-derive the noise from the clipped adversarial input (presumably so
        # that ``noise`` matches the perturbation actually applied -- confirm).
        noise.data = self.pgd.valid_noise(adv_x, _input)
        for m in range(self.adv_train_iter):
            # Train step on the current adversarial batch.
            loss = loss_fn(adv_x, _label)
            if amp:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            # Clear the gradients before PGD, so the attack's backward passes
            # do not mix with the parameter gradients.
            optimizer.zero_grad()
            self.zero_grad()
            # self.eval()  -- disabled: PGD below runs without switching the model to eval mode
            adv_x, _ = self.pgd.optimize(_input=_input, noise=noise, target=_label, iteration=1, pgd_alpha=self.adv_train_alpha, pgd_eps=self.adv_train_eps)
            # self.train()
            # NOTE(review): except on the last round, this loss is discarded because
            # the next round recomputes loss_fn on the same adv_x, costing one extra
            # forward pass per round. Left as-is because the extra forward may
            # matter for stateful layers (e.g. batch-norm statistics) -- confirm.
            loss = loss_fn(adv_x, _label)
    else:
        loss = self.adv_loss(_input=_input, _label=_label, loss_fn=loss_fn)
    # Gradients from this final backward are left for the caller.
    if amp:
        scaler.scale(loss).backward()
    else:
        loss.backward()
    if callable(after_loss_fn_old):
        after_loss_fn_old(_input=_input, _label=_label, _output=_output, loss=loss, optimizer=optimizer, loss_fn=loss_fn, amp=amp, scaler=scaler, **kwargs)
def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    scaler: torch.cuda.amp.GradScaler,
    amp,
    device: torch.device,
    epoch: int,
    max_norm: float = 0,
):
    """Train ``model`` for one pass over ``data_loader``.

    The forward pass and loss run under ``autocast(enabled=amp)``. Gradients
    are scaled through ``scaler`` and, when ``max_norm > 0``, unscaled and
    clipped to ``max_norm``. The process exits with status 1 if the reduced
    loss is not finite.

    :param model: network to train; switched to train mode
    :param criterion: loss module; its ``weight_dict`` weights the loss terms and
        its output must contain ``'class_error'`` (read for logging)
    :param data_loader: iterable of ``(samples, targets)`` batches, where
        ``targets`` is a list of dicts of tensors
    :param optimizer: optimizer stepped through ``scaler`` once per batch
    :param scaler: GradScaler for mixed precision training
    :param amp: whether autocast is enabled
    :param device: device the batch is moved to
    :param epoch: epoch index, used only in the log header
    :param max_norm: gradient-norm clipping threshold; clipping is off when <= 0
    :return: dict mapping each meter name to its global average over the epoch
    """
    model.train()
    criterion.train()

    metric_logger = utils.MetricLogger(delimiter=" ")
    for meter_name, meter_fmt in (('lr', '{value:.6f}'), ('class_error', '{value:.2f}')):
        metric_logger.add_meter(meter_name, utils.SmoothedValue(window_size=1, fmt=meter_fmt))
    header = 'Epoch: [{}]'.format(epoch)

    for samples, targets in metric_logger.log_every(data_loader, 10, header):
        optimizer.zero_grad()

        with autocast(enabled=amp):
            samples = samples.to(device)
            targets = [{name: tensor.to(device) for name, tensor in target.items()}
                       for target in targets]
            outputs = model(samples)
            loss_dict = criterion(outputs, targets)
            weight_dict = criterion.weight_dict
            losses = sum(loss_dict[name] * weight_dict[name]
                         for name in loss_dict if name in weight_dict)

        # Reduce across processes only for logging; ``losses`` is what gets backpropagated.
        reduced = utils.reduce_dict(loss_dict)
        reduced_unscaled = {f'{name}_unscaled': value for name, value in reduced.items()}
        reduced_scaled = {name: value * weight_dict[name]
                          for name, value in reduced.items() if name in weight_dict}
        loss_value = sum(reduced_scaled.values()).item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(reduced)
            sys.exit(1)

        scaler.scale(losses).backward()
        if max_norm > 0:
            # Clip the true (unscaled) gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)
        scaler.update()

        metric_logger.update(loss=loss_value, **reduced_scaled, **reduced_unscaled)
        metric_logger.update(class_error=reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # Gather the stats from all processes before reporting.
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {name: meter.global_avg for name, meter in metric_logger.meters.items()}
def _grad_l2_norm(parameters: Iterable) -> torch.Tensor:
    """Return the total L2 norm of the existing gradients of ``parameters``.

    This is the same quantity ``torch.nn.utils.clip_grad_norm_`` returns (the L2
    norm of the per-parameter L2 norms), but no gradient is modified.
    Parameters without a ``.grad`` are skipped, and ``tensor(0.)`` is returned
    if none has one.
    """
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    device = grads[0].device
    per_param = torch.stack([torch.norm(g, 2.0).to(device) for g in grads])
    return torch.norm(per_param, 2.0)


def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, scaler: torch.cuda.amp.GradScaler,
                    epoch: int, max_norm: float = 0, fp16=False):
    """Train ``model`` for one epoch, reading batches through ``data_prefetcher``.

    Each of the ``len(data_loader)`` iterations does the following:

    * casts ``samples.tensors`` and ``samples.mask`` to ``tensor_type``;
    * runs ``model([samples, targets])`` and ``criterion`` under
      ``autocast(enabled=fp16)``;
    * back-propagates the weighted loss through ``scaler`` and unscales the
      gradients;
    * clips the gradients to ``max_norm`` when it is positive;
    * logs the loss terms, ``class_error``, the learning rate and the total
      gradient L2 norm.

    The process exits with status 1 if the reduced loss is not finite.

    :param model: network to train; called with ``[samples, targets]`` and expected
        to return ``(outputs, pre_outputs, pre_targets)``
    :param criterion: loss module; its ``weight_dict`` weights the loss terms and
        its output must contain ``'class_error'`` (read for logging)
    :param data_loader: loader wrapped by ``data_prefetcher``; its length sets the
        number of iterations
    :param optimizer: optimizer stepped through ``scaler`` once per batch
    :param device: device passed to ``data_prefetcher``
    :param scaler: GradScaler for mixed precision training
    :param epoch: epoch index, used only in the log header
    :param max_norm: if > 0, clip the total gradient norm to this value; otherwise
        gradients are left unclipped and only their L2 norm is logged
    :param fp16: cast inputs to half precision and enable autocast
    :return: dict mapping each meter name to its global average over the epoch
    """
    model.train()
    criterion.train()
    # NOTE(review): these legacy tensor types are CUDA-only, so ``.type(...)`` below
    # always yields CUDA tensors. The mask is cast to a floating type as well,
    # presumably because the model expects that -- confirm.
    tensor_type = torch.cuda.HalfTensor if fp16 else torch.cuda.FloatTensor

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    prefetcher = data_prefetcher(data_loader, device, prefetch=True)
    samples, targets = prefetcher.next()

    # Iterate over a range of the loader's length; the batches themselves come
    # from the prefetcher.
    for _ in metric_logger.log_every(range(len(data_loader)), print_freq, header):
        samples.tensors = samples.tensors.type(tensor_type)
        samples.mask = samples.mask.type(tensor_type)

        with torch.cuda.amp.autocast(enabled=fp16):
            outputs, pre_outputs, pre_targets = model([samples, targets])
            loss_dict = criterion(outputs, targets, pre_outputs, pre_targets)
            weight_dict = criterion.weight_dict
            losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        # Reduce losses over all GPUs for logging purposes.
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v for k, v in loss_dict_reduced.items()}
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        scaler.scale(losses).backward()
        # Unscale before measuring/clipping so the norm refers to the true gradients.
        scaler.unscale_(optimizer)
        if max_norm > 0:
            grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        else:
            # No clipping requested. max_norm is <= 0 in this branch, so it must not
            # be forwarded to a norm computation (it previously was). Log the same
            # L2 total norm that the clipping branch reports.
            grad_total_norm = _grad_l2_norm(model.parameters())
        scaler.step(optimizer)
        scaler.update()

        metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(grad_norm=grad_total_norm)

        samples, targets = prefetcher.next()

    # Gather the stats from all processes.
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}