def compare(module1: nn.Module, module2: nn.Module,
            loader: torch.utils.data.DataLoader,
            print_prefix='Validate', indent=0, verbose=True,
            get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
            criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = nn.CrossEntropyLoss(),
            **kwargs) -> float:
    module1.eval()
    module2.eval()
    get_data_fn = get_data_fn if get_data_fn is not None else (lambda x: x)

    logger = MetricLogger()
    logger.create_meters(loss=None)
    loader_epoch = loader
    if verbose:
        header: str = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
        if env['tqdm']:
            loader_epoch = tqdm(loader_epoch, leave=False)
        loader_epoch = logger.log_every(loader_epoch, header=header,
                                        indent=indent)
    for data in loader_epoch:
        _input, _label = get_data_fn(data, **kwargs)
        # pure evaluation: no gradients are needed for either module
        with torch.no_grad():
            _output1: torch.Tensor = module1(_input)
            _output2: torch.Tensor = module2(_input)
            loss = criterion(_output1, _output2.softmax(1)).item()
        batch_size = int(_label.size(0))
        logger.update(n=batch_size, loss=loss)
    return logger.meters['loss'].global_avg

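# A minimal usage sketch for `compare` (the names `model_a`, `model_b` and
# `dataset` are hypothetical; `get_data_fn` is assumed to move batches to the
# right device, e.g. a trojanzoo model's `get_data`):
#
#     soft_ce = compare(model_a, model_b, loader=dataset.loader['valid'],
#                       get_data_fn=model_a.get_data, print_prefix='Compare')
#     print(f'average soft-label cross-entropy: {soft_ce:.4f}')
#
# The default `criterion=nn.CrossEntropyLoss()` is called with the probability
# target `_output2.softmax(1)`, which requires PyTorch >= 1.10 (soft targets).
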
def validate_mask_generator(self):
    loader = self.dataset.loader['valid']
    dataset = loader.dataset
    logger = MetricLogger()
    logger.create_meters(loss=None, div=None, norm=None)

    idx = torch.randperm(len(dataset))
    pos = 0

    print_prefix = 'Validate'
    header: str = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
    header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
    for data in logger.log_every(loader, header=header):
        _input, _label = self.model.get_data(data)
        batch_size = len(_input)
        data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
        _input2, _label2 = self.model.get_data(data2)
        pos += batch_size

        _mask = self.get_mask(_input)
        _mask2 = self.get_mask(_input2)

        input_dist: torch.Tensor = (_input - _input2).flatten(1).norm(p=2, dim=1)
        mask_dist: torch.Tensor = (_mask - _mask2).flatten(1).norm(p=2, dim=1) + 1e-5

        loss_div = input_dist.div(mask_dist).mean()
        loss_norm = _mask.sub(self.mask_density).relu().mean()

        loss = self.lambda_norm * loss_norm + self.lambda_div * loss_div
        logger.update(n=batch_size, loss=loss.item(),
                      div=loss_div.item(), norm=loss_norm.item())

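# Both `validate_mask_generator` above and `train_mask_generator` below score
# the mask generator with the same two terms, computed against a second,
# randomly shuffled batch:
#
#     loss_div  = mean_i( ||x_i - x'_i||_2 / (||m_i - m'_i||_2 + 1e-5) )
#     loss_norm = mean( relu(m - mask_density) )
#     loss      = lambda_norm * loss_norm + lambda_div * loss_div
#
# i.e. masks are pushed to differ between distinct inputs (diversity) while
# their average activation is kept below `mask_density`.
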
def validate(module: nn.Module, num_classes: int,
             loader: torch.utils.data.DataLoader,
             print_prefix: str = 'Validate', indent: int = 0,
             verbose: bool = True,
             get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
             forward_fn: Callable[..., torch.Tensor] = None,
             loss_fn: Callable[..., torch.Tensor] = None,
             writer=None, main_tag: str = 'valid',
             tag: str = '', _epoch: int = None,
             accuracy_fn: Callable[..., list[float]] = None,
             **kwargs) -> tuple[float, float]:
    r"""Evaluate the model.

    Returns:
        (float, float): Accuracy and loss.
    """
    module.eval()
    get_data_fn = get_data_fn or (lambda x: x)
    forward_fn = forward_fn or module.__call__
    # The default loss function wraps F.cross_entropy so that it matches the
    # keyword call below: loss_fn(_input, _label, _output=..., **kwargs).
    loss_fn = loss_fn or (lambda _input, _label, _output=None, **kwargs:
                          F.cross_entropy(_output if _output is not None
                                          else forward_fn(_input), _label))
    accuracy_fn = accuracy_fn or accuracy

    logger = MetricLogger()
    logger.create_meters(loss=None, top1=None, top5=None)
    loader_epoch = loader
    if verbose:
        header: str = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
        loader_epoch = logger.log_every(loader, header=header,
                                        tqdm_header='Batch', indent=indent)
    for data in loader_epoch:
        _input, _label = get_data_fn(data, mode='valid', **kwargs)
        with torch.no_grad():
            _output = forward_fn(_input)
            loss = float(loss_fn(_input, _label, _output=_output, **kwargs))
            acc1, acc5 = accuracy_fn(_output, _label,
                                     num_classes=num_classes, topk=(1, 5))
            batch_size = int(_label.size(0))
            logger.update(n=batch_size, loss=float(loss), top1=acc1, top5=acc5)
    acc, loss = (logger.meters['top1'].global_avg,
                 logger.meters['loss'].global_avg)
    if writer is not None and _epoch is not None and main_tag:
        from torch.utils.tensorboard import SummaryWriter
        assert isinstance(writer, SummaryWriter)
        writer.add_scalars(main_tag='Acc/' + main_tag,
                           tag_scalar_dict={tag: acc}, global_step=_epoch)
        writer.add_scalars(main_tag='Loss/' + main_tag,
                           tag_scalar_dict={tag: loss}, global_step=_epoch)
    return acc, loss

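# A minimal usage sketch for `validate` (hypothetical `model` and `dataset`
# objects; any custom `loss_fn` must accept the keyword form
# `loss_fn(_input, _label, _output=...)`, as a trojanzoo model's `loss` does):
#
#     acc, loss = validate(model, num_classes=10,
#                          loader=dataset.loader['valid'],
#                          get_data_fn=model.get_data, loss_fn=model.loss)
#     print(f'top-1 acc {acc:.3f}, loss {loss:.4f}')
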
def _validate(self, full=True, print_prefix='Validate', indent=0,
              verbose=True, loader: torch.utils.data.DataLoader = None,
              get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
              loss_fn: Callable[..., torch.Tensor] = None,
              writer=None, main_tag: str = 'valid',
              tag: str = '', _epoch: int = None,
              **kwargs) -> tuple[float, float]:
    self.eval()
    if loader is None:
        loader = self.dataset.loader['valid'] if full \
            else self.dataset.loader['valid2']
    get_data_fn = get_data_fn if get_data_fn is not None else self.get_data
    loss_fn = loss_fn if loss_fn is not None else self.loss

    logger = MetricLogger()
    logger.meters['loss'] = SmoothedValue()
    logger.meters['top1'] = SmoothedValue()
    logger.meters['top5'] = SmoothedValue()
    loader_epoch = loader
    if verbose:
        header = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
        if env['tqdm']:
            header = '{upline}{clear_line}'.format(**ansi) + header
            loader_epoch = tqdm(loader_epoch)
        loader_epoch = logger.log_every(loader_epoch, header=header,
                                        indent=indent)
    for data in loader_epoch:
        _input, _label = get_data_fn(data, mode='valid', **kwargs)
        with torch.no_grad():
            _output = self(_input)
            loss = float(loss_fn(_input, _label, _output=_output, **kwargs))
            acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
            batch_size = int(_label.size(0))
            logger.meters['loss'].update(loss, batch_size)
            logger.meters['top1'].update(acc1, batch_size)
            logger.meters['top5'].update(acc5, batch_size)
    loss, acc = (logger.meters['loss'].global_avg,
                 logger.meters['top1'].global_avg)
    if writer is not None and _epoch is not None and main_tag:
        from torch.utils.tensorboard import SummaryWriter
        assert isinstance(writer, SummaryWriter)
        writer.add_scalars(main_tag='Loss/' + main_tag,
                           tag_scalar_dict={tag: loss}, global_step=_epoch)
        writer.add_scalars(main_tag='Acc/' + main_tag,
                           tag_scalar_dict={tag: acc}, global_step=_epoch)
    return loss, acc

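# Note on return order: this older `_validate` returns ``(loss, acc)`` and is
# consumed as ``_, best_acc = validate_fn(...)`` in `_train` further below,
# whereas the newer standalone `validate` above returns ``(acc, loss)``.
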
def train_mask_generator(self, verbose: bool = True):
    r"""Train :attr:`self.mask_generator`."""
    optimizer = torch.optim.Adam(self.mask_generator.parameters(),
                                 lr=1e-2, betas=(0.5, 0.9))
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=self.train_mask_epochs)
    loader = self.dataset.loader['train']
    dataset = loader.dataset
    logger = MetricLogger()
    logger.create_meters(loss=None, div=None, norm=None)
    print_prefix = 'Mask Epoch'
    for _epoch in range(self.train_mask_epochs):
        _epoch += 1
        idx = torch.randperm(len(dataset))
        pos = 0
        logger.reset()
        header: str = '{blue_light}{0}: {1}{reset}'.format(
            print_prefix, output_iter(_epoch, self.train_mask_epochs), **ansi)
        header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
        self.mask_generator.train()
        for data in logger.log_every(loader, header=header) if verbose else loader:
            optimizer.zero_grad()
            _input, _label = self.model.get_data(data)
            batch_size = len(_input)
            data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
            _input2, _label2 = self.model.get_data(data2)
            pos += batch_size

            _mask = self.get_mask(_input)
            _mask2 = self.get_mask(_input2)

            input_dist: torch.Tensor = (_input - _input2).flatten(1).norm(p=2, dim=1)
            mask_dist: torch.Tensor = (_mask - _mask2).flatten(1).norm(p=2, dim=1) + 1e-5

            loss_div = input_dist.div(mask_dist).mean()
            loss_norm = _mask.sub(self.mask_density).relu().mean()

            loss = self.lambda_norm * loss_norm + self.lambda_div * loss_div
            loss.backward()
            optimizer.step()
            logger.update(n=batch_size, loss=loss.item(),
                          div=loss_div.item(), norm=loss_norm.item())
        lr_scheduler.step()
        self.mask_generator.eval()
        if verbose and (_epoch % (max(self.train_mask_epochs // 5, 1)) == 0
                        or _epoch == self.train_mask_epochs):
            self.validate_mask_generator()
    optimizer.zero_grad()

def compare(module1: nn.Module, module2: nn.Module,
            loader: torch.utils.data.DataLoader,
            print_prefix='Validate', indent=0, verbose=True,
            get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
            **kwargs) -> float:
    logsoftmax = nn.LogSoftmax(dim=1)
    softmax = nn.Softmax(dim=1)

    module1.eval()
    module2.eval()
    get_data_fn = get_data_fn if get_data_fn is not None else (lambda x: x)

    def cross_entropy(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
        result: torch.Tensor = -softmax(p) * logsoftmax(q)
        return result.sum(1).mean()

    logger = MetricLogger()
    logger.meters['loss'] = SmoothedValue()
    loader_epoch = loader
    if verbose:
        header = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
        if env['tqdm']:
            header = '{upline}{clear_line}'.format(**ansi) + header
            loader_epoch = tqdm(loader_epoch)
        loader_epoch = logger.log_every(loader_epoch, header=header,
                                        indent=indent)
    with torch.no_grad():
        for data in loader_epoch:
            _input, _label = get_data_fn(data, **kwargs)
            _output1, _output2 = module1(_input), module2(_input)
            loss = float(cross_entropy(_output1, _output2))
            batch_size = int(_label.size(0))
            logger.meters['loss'].update(loss, batch_size)
    return logger.meters['loss'].global_avg

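# This older `compare` variant computes the soft-label cross-entropy by hand:
#
#     cross_entropy(p, q) = -(softmax(p) * log_softmax(q)).sum(1).mean()
#
# Note the argument order in the call `cross_entropy(_output1, _output2)`:
# module1's softmax output acts as the target distribution and module2's
# logits as the prediction, the reverse of the newer variant above, where
# `criterion(_output1, _output2.softmax(1))` treats module2 as the target.
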
def attack(self, epochs: int, save: bool = False, **kwargs):
    # train generator
    match self.generator_mode:
        case 'resnet':
            resnet_model = torchvision.models.resnet18(pretrained=True).to(device=env['device'])
            model_extractor = nn.Sequential(*list(resnet_model.children())[:5])
        case _:
            resnet_model = trojanvision.models.create('resnet18_comp',
                                                      dataset=self.dataset,
                                                      pretrained=True)
            model_extractor = nn.Sequential(*list(resnet_model._model.features.children())[:4])
    model_extractor.requires_grad_(False)
    model_extractor.train()

    self.generator.train()
    if self.generator_mode == 'default':
        self.generator.requires_grad_()
        parameters = self.generator.parameters()
    else:
        self.generator.bottleneck.requires_grad_()
        self.generator.decoder.requires_grad_()
        parameters = self.generator[1:].parameters()
    optimizer = torch.optim.Adam(parameters, lr=1e-3)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=self.train_generator_epochs, eta_min=1e-5)

    logger = MetricLogger()
    logger.create_meters(loss=None, acc=None)
    for _epoch in range(self.train_generator_epochs):
        _epoch += 1
        logger.reset()
        header: str = '{blue_light}{0}: {1}{reset}'.format(
            'Epoch', output_iter(_epoch, self.train_generator_epochs), **ansi)
        header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
        loader = logger.log_every(self.dataset.loader['train'],
                                  header=header, tqdm_header='Batch')
        for data in loader:
            optimizer.zero_grad()
            _input, _label = self.model.get_data(data)
            adv_input = (self.generator(_input) + 1) / 2
            _feats = model_extractor(_input)
            adv_feats = model_extractor(adv_input)
            loss = F.l1_loss(_feats, self.noise_coeff * adv_feats)
            loss.backward()
            optimizer.step()

            batch_size = len(_label)
            with torch.no_grad():
                org_class = self.model.get_class(_input)
                adv_class = self.model.get_class(adv_input)
                acc = (org_class == adv_class).float().sum().item() * 100.0 / batch_size
            logger.update(n=batch_size, loss=loss.item(), acc=acc)
        lr_scheduler.step()
        self.save_generator()
    optimizer.zero_grad()
    self.generator.eval()
    self.generator.requires_grad_(False)

    self.mark.mark[:-1] = (self.generator(self.mark.mark[:-1].unsqueeze(0))[0] + 1) / 2
    self.poison_set = self.get_poison_dataset(load_mark=False)
    return super().attack(epochs, save=save, **kwargs)

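# Reading of the generator pre-training loop above (from the code itself):
# a frozen early-layer feature extractor is taken from a ResNet-18, and the
# generator is trained to minimize the L1 distance between the clean-image
# features and `noise_coeff` times the generated-image features. The logged
# `acc` is the percentage of samples whose predicted class is left unchanged
# by the generator. After training, the RGB channels of the stored mark are
# replaced by the generator's output on the mark (rescaled from [-1, 1] to
# [0, 1]) and the poison set is rebuilt.
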
def attack(self, epochs: int, optimizer: torch.optim.Optimizer,
           lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
           validate_interval: int = 1, save: bool = False,
           verbose: bool = True, **kwargs):
    if verbose:
        print('train mask generator')
    self.mark_generator.requires_grad_(False)
    self.mask_generator.requires_grad_()
    self.model.requires_grad_(False)
    self.train_mask_generator(verbose=verbose)
    if verbose:
        print()
        print('train mark generator and model')

    self.mark_generator.requires_grad_()
    self.mask_generator.requires_grad_(False)
    if not self.natural:
        params: list[nn.Parameter] = []
        for param_group in optimizer.param_groups:
            params.extend(param_group['params'])
        self.model.activate_params(params)

    mark_optimizer = torch.optim.Adam(self.mark_generator.parameters(),
                                      lr=1e-2, betas=(0.5, 0.9))
    mark_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        mark_optimizer, T_max=epochs)
    loader = self.dataset.loader['train']
    dataset = loader.dataset
    logger = MetricLogger()
    logger.create_meters(loss=None, div=None, ce=None)

    if validate_interval != 0:
        best_validate_result = self.validate_fn(verbose=verbose)
        best_asr = best_validate_result[0]
    for _epoch in range(epochs):
        _epoch += 1
        idx = torch.randperm(len(dataset))
        pos = 0
        logger.reset()
        if not self.natural:
            self.model.train()
        self.mark_generator.train()
        header: str = '{blue_light}{0}: {1}{reset}'.format(
            'Epoch', output_iter(_epoch, epochs), **ansi)
        header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
        for data in logger.log_every(loader, header=header) if verbose else loader:
            if not self.natural:
                optimizer.zero_grad()
            mark_optimizer.zero_grad()
            _input, _label = self.model.get_data(data)
            batch_size = len(_input)
            data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
            _input2, _label2 = self.model.get_data(data2)
            pos += batch_size
            final_input, final_label = _input.clone(), _label.clone()

            # generate trigger input
            trigger_dec, trigger_int = math.modf(len(_label) * self.poison_percent)
            trigger_int = int(trigger_int)
            if random.uniform(0, 1) < trigger_dec:
                trigger_int += 1
            x = _input[:trigger_int]
            trigger_mark, trigger_mask = self.get_mark(x), self.get_mask(x)
            trigger_input = x + trigger_mask * (trigger_mark - x)
            final_input[:trigger_int] = trigger_input
            final_label[:trigger_int] = self.target_class

            # generate cross input
            cross_dec, cross_int = math.modf(len(_label) * self.cross_percent)
            cross_int = int(cross_int)
            if random.uniform(0, 1) < cross_dec:
                cross_int += 1
            x = _input[trigger_int:trigger_int + cross_int]
            x2 = _input2[trigger_int:trigger_int + cross_int]
            cross_mark, cross_mask = self.get_mark(x2), self.get_mask(x2)
            cross_input = x + cross_mask * (cross_mark - x)
            final_input[trigger_int:trigger_int + cross_int] = cross_input

            # div loss
            if len(trigger_input) <= len(cross_input):
                length = len(trigger_input)
                cross_input = cross_input[:length]
                cross_mark = cross_mark[:length]
                cross_mask = cross_mask[:length]
            else:
                length = len(cross_input)
                trigger_input = trigger_input[:length]
                trigger_mark = trigger_mark[:length]
                trigger_mask = trigger_mask[:length]
            input_dist: torch.Tensor = (trigger_input - cross_input).flatten(1).norm(p=2, dim=1)
            mark_dist: torch.Tensor = (trigger_mark - cross_mark).flatten(1).norm(p=2, dim=1) + 1e-5

            loss_ce = self.model.loss(final_input, final_label)
            loss_div = input_dist.div(mark_dist).mean()

            loss = loss_ce + self.lambda_div * loss_div
            loss.backward()
            if not self.natural:
                optimizer.step()
            mark_optimizer.step()
            logger.update(n=batch_size, loss=loss.item(),
                          div=loss_div.item(), ce=loss_ce.item())
        if not self.natural and lr_scheduler:
            lr_scheduler.step()
        mark_scheduler.step()
        if not self.natural:
            self.model.eval()
        self.mark_generator.eval()
        if validate_interval != 0 and (_epoch % validate_interval == 0
                                       or _epoch == epochs):
            validate_result = self.validate_fn(verbose=verbose)
            cur_asr = validate_result[0]
            if cur_asr >= best_asr:
                best_validate_result = validate_result
                best_asr = cur_asr
                if save:
                    self.save()
    if not self.natural:
        optimizer.zero_grad()
    mark_optimizer.zero_grad()
    self.mark_generator.requires_grad_(False)
    self.mask_generator.requires_grad_(False)
    self.model.requires_grad_(False)
    return best_validate_result

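# How the per-batch trigger/cross split above works, as a concrete example:
# with batch_size = 100, poison_percent = 0.055 and cross_percent = 0.02,
# math.modf(5.5) -> (0.5, 5.0), so 5 samples get the trigger plus one more
# with probability 0.5 (5.5 in expectation); math.modf(2.0) -> (0.0, 2.0)
# gives exactly 2 cross samples. The stochastic rounding keeps the *expected*
# trigger/cross fractions equal to the configured percentages even when they
# do not divide the batch size evenly.
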
def train(module: nn.Module, num_classes: int,
          epochs: int, optimizer: Optimizer,
          lr_scheduler: _LRScheduler = None,
          lr_warmup_epochs: int = 0,
          model_ema: ExponentialMovingAverage = None,
          model_ema_steps: int = 32,
          grad_clip: float = None,
          pre_conditioner: None | KFAC | EKFAC = None,
          print_prefix: str = 'Train', start_epoch: int = 0, resume: int = 0,
          validate_interval: int = 10, save: bool = False, amp: bool = False,
          loader_train: torch.utils.data.DataLoader = None,
          loader_valid: torch.utils.data.DataLoader = None,
          epoch_fn: Callable[..., None] = None,
          get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
          forward_fn: Callable[..., torch.Tensor] = None,
          loss_fn: Callable[..., torch.Tensor] = None,
          after_loss_fn: Callable[..., None] = None,
          validate_fn: Callable[..., tuple[float, float]] = None,
          save_fn: Callable[..., None] = None, file_path: str = None,
          folder_path: str = None, suffix: str = None,
          writer=None, main_tag: str = 'train', tag: str = '',
          accuracy_fn: Callable[..., list[float]] = None,
          verbose: bool = True, output_freq: str = 'iter', indent: int = 0,
          change_train_eval: bool = True, lr_scheduler_freq: str = 'epoch',
          backward_and_step: bool = True, **kwargs):
    r"""Train the model."""
    if epochs <= 0:
        return
    get_data_fn = get_data_fn or (lambda x: x)
    forward_fn = forward_fn or module.__call__
    # The default loss function must not test a tensor for truthiness and must
    # accept the keyword call used below, hence the explicit `is not None`
    # check and the **kwargs.
    loss_fn = loss_fn or (lambda _input, _label, _output=None, **kwargs:
                          F.cross_entropy(_output if _output is not None
                                          else forward_fn(_input), _label))
    validate_fn = validate_fn or validate
    accuracy_fn = accuracy_fn or accuracy

    scaler: torch.cuda.amp.GradScaler = None
    if not env['num_gpus']:
        amp = False
    if amp:
        scaler = torch.cuda.amp.GradScaler()
    best_validate_result = (0.0, float('inf'))
    if validate_interval != 0:
        best_validate_result = validate_fn(module=module,
                                           num_classes=num_classes,
                                           loader=loader_valid,
                                           get_data_fn=get_data_fn,
                                           forward_fn=forward_fn,
                                           loss_fn=loss_fn,
                                           writer=None, tag=tag,
                                           _epoch=start_epoch,
                                           verbose=verbose, indent=indent,
                                           **kwargs)
        best_acc = best_validate_result[0]

    params: list[nn.Parameter] = []
    for param_group in optimizer.param_groups:
        params.extend(param_group['params'])
    len_loader_train = len(loader_train)
    total_iter = (epochs - resume) * len_loader_train

    logger = MetricLogger()
    logger.create_meters(loss=None, top1=None, top5=None)
    if resume and lr_scheduler:
        for _ in range(resume):
            lr_scheduler.step()
    iterator = range(resume, epochs)
    if verbose and output_freq == 'epoch':
        header: str = '{blue_light}{0}: {reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(header), 30) + get_ansi_len(header))
        iterator = logger.log_every(range(resume, epochs), header=header,
                                    tqdm_header='Epoch', indent=indent)
    for _epoch in iterator:
        _epoch += 1
        logger.reset()
        if callable(epoch_fn):
            activate_params(module, [])
            epoch_fn(optimizer=optimizer, lr_scheduler=lr_scheduler,
                     _epoch=_epoch, epochs=epochs, start_epoch=start_epoch)
        loader_epoch = loader_train
        if verbose and output_freq == 'iter':
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                'Epoch', output_iter(_epoch, epochs), **ansi)
            header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
            loader_epoch = logger.log_every(loader_train, header=header,
                                            tqdm_header='Batch', indent=indent)
        if change_train_eval:
            module.train()
        activate_params(module, params)
        for i, data in enumerate(loader_epoch):
            _iter = _epoch * len_loader_train + i
            # data_time.update(time.perf_counter() - end)
            _input, _label = get_data_fn(data, mode='train')
            if pre_conditioner is not None and not amp:
                pre_conditioner.track.enable()
            _output = forward_fn(_input, amp=amp, parallel=True)
            loss = loss_fn(_input, _label, _output=_output, amp=amp)
            if backward_and_step:
                optimizer.zero_grad()
                if amp:
                    scaler.scale(loss).backward()
                    if callable(after_loss_fn) or grad_clip is not None:
                        scaler.unscale_(optimizer)
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input, _label=_label,
                                      _output=_output, loss=loss,
                                      optimizer=optimizer, loss_fn=loss_fn,
                                      amp=amp, scaler=scaler,
                                      _iter=_iter, total_iter=total_iter)
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input, _label=_label,
                                      _output=_output, loss=loss,
                                      optimizer=optimizer, loss_fn=loss_fn,
                                      amp=amp, scaler=scaler,
                                      _iter=_iter, total_iter=total_iter)
                        # start_epoch=start_epoch, _epoch=_epoch, epochs=epochs)
                    if pre_conditioner is not None:
                        pre_conditioner.track.disable()
                        pre_conditioner.step()
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    optimizer.step()
            if model_ema and i % model_ema_steps == 0:
                model_ema.update_parameters(module)
                if _epoch <= lr_warmup_epochs:
                    # Reset ema buffer to keep copying weights
                    # during warmup period
                    model_ema.n_averaged.fill_(0)
            if lr_scheduler and lr_scheduler_freq == 'iter':
                lr_scheduler.step()
            acc1, acc5 = accuracy_fn(_output, _label,
                                     num_classes=num_classes, topk=(1, 5))
            batch_size = int(_label.size(0))
            logger.update(n=batch_size, loss=float(loss), top1=acc1, top5=acc5)
            empty_cache()
        optimizer.zero_grad()
        if lr_scheduler and lr_scheduler_freq == 'epoch':
            lr_scheduler.step()
        if change_train_eval:
            module.eval()
        activate_params(module, [])
        loss, acc = (logger.meters['loss'].global_avg,
                     logger.meters['top1'].global_avg)
        if writer is not None:
            from torch.utils.tensorboard import SummaryWriter
            assert isinstance(writer, SummaryWriter)
            writer.add_scalars(main_tag='Loss/' + main_tag,
                               tag_scalar_dict={tag: loss},
                               global_step=_epoch + start_epoch)
            writer.add_scalars(main_tag='Acc/' + main_tag,
                               tag_scalar_dict={tag: acc},
                               global_step=_epoch + start_epoch)
        if validate_interval != 0 and (_epoch % validate_interval == 0
                                       or _epoch == epochs):
            validate_result = validate_fn(module=module,
                                          num_classes=num_classes,
                                          loader=loader_valid,
                                          get_data_fn=get_data_fn,
                                          forward_fn=forward_fn,
                                          loss_fn=loss_fn,
                                          writer=writer, tag=tag,
                                          _epoch=_epoch + start_epoch,
                                          verbose=verbose, indent=indent,
                                          **kwargs)
            cur_acc = validate_result[0]
            if cur_acc >= best_acc:
                best_validate_result = validate_result
                if verbose:
                    prints('{purple}best result update!{reset}'.format(**ansi),
                           indent=indent)
                    prints(f'Current Acc: {cur_acc:.3f} '
                           f'Previous Best Acc: {best_acc:.3f}', indent=indent)
                best_acc = cur_acc
                if save:
                    save_fn(file_path=file_path, folder_path=folder_path,
                            suffix=suffix, verbose=verbose)
            if verbose:
                prints('-' * 50, indent=indent)
    module.zero_grad()
    return best_validate_result

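# A minimal usage sketch for `train` (hypothetical names; with trojanzoo
# objects the wrapped model usually supplies the data/forward/loss/save hooks,
# while `optimizer` and `lr_scheduler` come from the trainer configuration):
#
#     best_acc, best_loss = train(model._model, num_classes=dataset.num_classes,
#                                 epochs=50, optimizer=optimizer,
#                                 lr_scheduler=lr_scheduler,
#                                 loader_train=dataset.loader['train'],
#                                 loader_valid=dataset.loader['valid'],
#                                 get_data_fn=model.get_data,
#                                 forward_fn=model.__call__, loss_fn=model.loss,
#                                 save_fn=model.save, validate_interval=10,
#                                 save=True)
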
def _train(self, epoch: int, optimizer: Optimizer,
           lr_scheduler: _LRScheduler = None,
           grad_clip: float = None,
           print_prefix: str = 'Epoch', start_epoch: int = 0,
           validate_interval: int = 10, save: bool = False, amp: bool = False,
           loader_train: torch.utils.data.DataLoader = None,
           loader_valid: torch.utils.data.DataLoader = None,
           epoch_fn: Callable[..., None] = None,
           get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
           loss_fn: Callable[..., torch.Tensor] = None,
           after_loss_fn: Callable[..., None] = None,
           validate_fn: Callable[..., tuple[float, float]] = None,
           save_fn: Callable[..., None] = None, file_path: str = None,
           folder_path: str = None, suffix: str = None,
           writer=None, main_tag: str = 'train', tag: str = '',
           verbose: bool = True, indent: int = 0, **kwargs):
    loader_train = loader_train if loader_train is not None \
        else self.dataset.loader['train']
    get_data_fn = get_data_fn if callable(get_data_fn) else self.get_data
    loss_fn = loss_fn if callable(loss_fn) else self.loss
    validate_fn = validate_fn if callable(validate_fn) else self._validate
    save_fn = save_fn if callable(save_fn) else self.save
    # if not callable(iter_fn) and hasattr(self, 'iter_fn'):
    #     iter_fn = getattr(self, 'iter_fn')
    if not callable(epoch_fn) and hasattr(self, 'epoch_fn'):
        epoch_fn = getattr(self, 'epoch_fn')
    if not callable(after_loss_fn) and hasattr(self, 'after_loss_fn'):
        after_loss_fn = getattr(self, 'after_loss_fn')

    scaler: torch.cuda.amp.GradScaler = None
    if not env['num_gpus']:
        amp = False
    if amp:
        scaler = torch.cuda.amp.GradScaler()
    _, best_acc = validate_fn(loader=loader_valid, get_data_fn=get_data_fn,
                              loss_fn=loss_fn, writer=None, tag=tag,
                              _epoch=start_epoch, verbose=verbose,
                              indent=indent, **kwargs)

    params: list[nn.Parameter] = []
    for param_group in optimizer.param_groups:
        params.extend(param_group['params'])
    total_iter = epoch * len(loader_train)
    for _epoch in range(epoch):
        _epoch += 1
        if callable(epoch_fn):
            self.activate_params([])
            epoch_fn(optimizer=optimizer, lr_scheduler=lr_scheduler,
                     _epoch=_epoch, epoch=epoch, start_epoch=start_epoch)
            self.activate_params(params)
        logger = MetricLogger()
        logger.meters['loss'] = SmoothedValue()
        logger.meters['top1'] = SmoothedValue()
        logger.meters['top5'] = SmoothedValue()
        loader_epoch = loader_train
        if verbose:
            header = '{blue_light}{0}: {1}{reset}'.format(
                print_prefix, output_iter(_epoch, epoch), **ansi)
            header = header.ljust(30 + get_ansi_len(header))
            if env['tqdm']:
                header = '{upline}{clear_line}'.format(**ansi) + header
                loader_epoch = tqdm(loader_epoch)
            loader_epoch = logger.log_every(loader_epoch, header=header,
                                            indent=indent)
        self.train()
        self.activate_params(params)
        optimizer.zero_grad()
        for i, data in enumerate(loader_epoch):
            _iter = _epoch * len(loader_train) + i
            # data_time.update(time.perf_counter() - end)
            _input, _label = get_data_fn(data, mode='train')
            _output = self(_input, amp=amp)
            loss = loss_fn(_input, _label, _output=_output, amp=amp)
            if amp:
                scaler.scale(loss).backward()
                if callable(after_loss_fn):
                    after_loss_fn(_input=_input, _label=_label,
                                  _output=_output, loss=loss,
                                  optimizer=optimizer, loss_fn=loss_fn,
                                  amp=amp, scaler=scaler,
                                  _iter=_iter, total_iter=total_iter)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                if grad_clip is not None:
                    nn.utils.clip_grad_norm_(params, grad_clip)
                if callable(after_loss_fn):
                    after_loss_fn(_input=_input, _label=_label,
                                  _output=_output, loss=loss,
                                  optimizer=optimizer, loss_fn=loss_fn,
                                  amp=amp, scaler=scaler,
                                  _iter=_iter, total_iter=total_iter)
                    # start_epoch=start_epoch, _epoch=_epoch, epoch=epoch)
                optimizer.step()
            optimizer.zero_grad()
            acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
            batch_size = int(_label.size(0))
            logger.meters['loss'].update(float(loss), batch_size)
            logger.meters['top1'].update(acc1, batch_size)
            logger.meters['top5'].update(acc5, batch_size)
            empty_cache()  # TODO: should it be outside of the dataloader loop?
        self.eval()
        self.activate_params([])
        loss, acc = (logger.meters['loss'].global_avg,
                     logger.meters['top1'].global_avg)
        if writer is not None:
            from torch.utils.tensorboard import SummaryWriter
            assert isinstance(writer, SummaryWriter)
            writer.add_scalars(main_tag='Loss/' + main_tag,
                               tag_scalar_dict={tag: loss},
                               global_step=_epoch + start_epoch)
            writer.add_scalars(main_tag='Acc/' + main_tag,
                               tag_scalar_dict={tag: acc},
                               global_step=_epoch + start_epoch)
        if lr_scheduler:
            lr_scheduler.step()
        if validate_interval != 0:
            if _epoch % validate_interval == 0 or _epoch == epoch:
                _, cur_acc = validate_fn(loader=loader_valid,
                                         get_data_fn=get_data_fn,
                                         loss_fn=loss_fn, writer=writer,
                                         tag=tag, _epoch=_epoch + start_epoch,
                                         verbose=verbose, indent=indent,
                                         **kwargs)
                if cur_acc >= best_acc:
                    if verbose:
                        prints('{green}best result update!{reset}'.format(**ansi),
                               indent=indent)
                        prints(f'Current Acc: {cur_acc:.3f} '
                               f'Previous Best Acc: {best_acc:.3f}',
                               indent=indent)
                    best_acc = cur_acc
                    if save:
                        save_fn(file_path=file_path, folder_path=folder_path,
                                suffix=suffix, verbose=verbose)
                if verbose:
                    prints('-' * 50, indent=indent)
    self.zero_grad()