Example No. 1
def compare(module1: nn.Module,
            module2: nn.Module,
            loader: torch.utils.data.DataLoader,
            print_prefix='Validate',
            indent=0,
            verbose=True,
            get_data_fn: Callable[..., tuple[torch.Tensor,
                                             torch.Tensor]] = None,
            criterion: Callable[[torch.Tensor, torch.Tensor],
                                torch.Tensor] = nn.CrossEntropyLoss(),
            **kwargs) -> float:
    module1.eval()
    module2.eval()
    get_data_fn = get_data_fn if get_data_fn is not None else lambda x: x

    logger = MetricLogger()
    logger.create_meters(loss=None)
    loader_epoch = loader
    if verbose:
        header: str = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(
            max(len(print_prefix), 30) + get_ansi_len(header))
        if env['tqdm']:
            loader_epoch = tqdm(loader_epoch, leave=False)
        loader_epoch = logger.log_every(loader_epoch,
                                        header=header,
                                        indent=indent)
    with torch.no_grad():  # pure evaluation; no gradients are needed
        for data in loader_epoch:
            _input, _label = get_data_fn(data, **kwargs)
            _output1: torch.Tensor = module1(_input)
            _output2: torch.Tensor = module2(_input)
            # soft-label cross-entropy (probability targets need PyTorch >= 1.10)
            loss = criterion(_output1, _output2.softmax(1)).item()
            batch_size = int(_label.size(0))
            logger.update(n=batch_size, loss=loss)
    return logger.meters['loss'].global_avg
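A minimal usage sketch, assuming the helpers used above (MetricLogger, ansi, get_ansi_len, env, tqdm) are importable from the surrounding library; the two ResNets and the FakeData loader are placeholders chosen only so the call runs end to end:

import torch
import torchvision
from torchvision import transforms

# Two hypothetical classifiers to compare on the same validation data.
teacher = torchvision.models.resnet18(num_classes=10)
student = torchvision.models.resnet18(num_classes=10)

# Throw-away data so the sketch is self-contained.
valid_set = torchvision.datasets.FakeData(size=256, image_size=(3, 224, 224),
                                          num_classes=10,
                                          transform=transforms.ToTensor())
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=64)

# Average soft-label cross-entropy between the two models' predictions
# (probability targets for nn.CrossEntropyLoss need PyTorch >= 1.10).
avg_loss = compare(teacher, student, valid_loader, print_prefix='Compare')
print(f'average soft cross-entropy: {avg_loss:.4f}')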
Example No. 2
    def validate_mask_generator(self):
        loader = self.dataset.loader['valid']
        dataset = loader.dataset
        logger = MetricLogger()
        logger.create_meters(loss=None, div=None, norm=None)
        idx = torch.randperm(len(dataset))
        pos = 0

        print_prefix = 'Validate'
        header: str = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
        for data in logger.log_every(loader, header=header):
            _input, _label = self.model.get_data(data)
            batch_size = len(_input)
            data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
            _input2, _label2 = self.model.get_data(data2)
            pos += batch_size

            _mask = self.get_mask(_input)
            _mask2 = self.get_mask(_input2)

            input_dist: torch.Tensor = (_input - _input2).flatten(1).norm(p=2, dim=1)
            mask_dist: torch.Tensor = (_mask - _mask2).flatten(1).norm(p=2, dim=1) + 1e-5

            loss_div = input_dist.div(mask_dist).mean()
            loss_norm = _mask.sub(self.mask_density).relu().mean()

            loss = self.lambda_norm * loss_norm + self.lambda_div * loss_div
            logger.update(n=batch_size, loss=loss.item(), div=loss_div.item(), norm=loss_norm.item())
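The two regularizers above can be read in isolation: the diversity loss keeps masks of different inputs from collapsing onto each other (relative to how far the inputs themselves are apart), and the norm loss penalizes mask values above the target density. A minimal standalone sketch of the same computation, with hypothetical tensor shapes and density value:

import torch

def diversity_and_norm_losses(x1: torch.Tensor, x2: torch.Tensor,
                              m1: torch.Tensor, m2: torch.Tensor,
                              mask_density: float) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-sample L2 distances between the inputs and between their masks.
    input_dist = (x1 - x2).flatten(1).norm(p=2, dim=1)
    mask_dist = (m1 - m2).flatten(1).norm(p=2, dim=1) + 1e-5  # avoid division by zero
    loss_div = input_dist.div(mask_dist).mean()               # large when masks collapse
    loss_norm = (m1 - mask_density).relu().mean()             # penalize overly dense masks
    return loss_div, loss_norm

# Hypothetical batch of 8 RGB 32x32 inputs and single-channel masks.
x1, x2 = torch.rand(8, 3, 32, 32), torch.rand(8, 3, 32, 32)
m1, m2 = torch.rand(8, 1, 32, 32), torch.rand(8, 1, 32, 32)
loss_div, loss_norm = diversity_and_norm_losses(x1, x2, m1, m2, mask_density=0.032)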
Example No. 3
def validate(module: nn.Module,
             num_classes: int,
             loader: torch.utils.data.DataLoader,
             print_prefix: str = 'Validate',
             indent: int = 0,
             verbose: bool = True,
             get_data_fn: Callable[..., tuple[torch.Tensor,
                                              torch.Tensor]] = None,
             forward_fn: Callable[..., torch.Tensor] = None,
             loss_fn: Callable[..., torch.Tensor] = None,
             writer=None,
             main_tag: str = 'valid',
             tag: str = '',
             _epoch: int = None,
             accuracy_fn: Callable[..., list[float]] = None,
             **kwargs) -> tuple[float, float]:
    r"""Evaluate the model.

    Returns:
        (float, float): Accuracy and loss.
    """
    module.eval()
    get_data_fn = get_data_fn or (lambda x, **kwargs: x)
    forward_fn = forward_fn or module.__call__
    # default loss accepts the keyword signature used in the loop below
    loss_fn = loss_fn or (lambda _input, _label, _output=None, **kwargs:
                          nn.functional.cross_entropy(
                              _output if _output is not None
                              else forward_fn(_input), _label))
    accuracy_fn = accuracy_fn or accuracy
    logger = MetricLogger()
    logger.create_meters(loss=None, top1=None, top5=None)
    loader_epoch = loader
    if verbose:
        header: str = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(
            max(len(print_prefix), 30) + get_ansi_len(header))
        loader_epoch = logger.log_every(loader,
                                        header=header,
                                        tqdm_header='Batch',
                                        indent=indent)
    for data in loader_epoch:
        _input, _label = get_data_fn(data, mode='valid', **kwargs)
        with torch.no_grad():
            _output = forward_fn(_input)
            loss = float(loss_fn(_input, _label, _output=_output, **kwargs))
            acc1, acc5 = accuracy_fn(_output,
                                     _label,
                                     num_classes=num_classes,
                                     topk=(1, 5))
            batch_size = int(_label.size(0))
            logger.update(n=batch_size, loss=float(loss), top1=acc1, top5=acc5)
    acc, loss = (logger.meters['top1'].global_avg,
                 logger.meters['loss'].global_avg)
    if writer is not None and _epoch is not None and main_tag:
        from torch.utils.tensorboard import SummaryWriter
        assert isinstance(writer, SummaryWriter)
        writer.add_scalars(main_tag='Acc/' + main_tag,
                           tag_scalar_dict={tag: acc},
                           global_step=_epoch)
        writer.add_scalars(main_tag='Loss/' + main_tag,
                           tag_scalar_dict={tag: loss},
                           global_step=_epoch)
    return acc, loss
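A minimal usage sketch of validate, assuming the library helpers it relies on (MetricLogger, accuracy, ansi, get_ansi_len) are importable; the model and FakeData loader are placeholders so the call runs end to end on CPU:

import torch
import torchvision
from torchvision import transforms

# Placeholder classifier and data.
model = torchvision.models.resnet18(num_classes=10)
valid_set = torchvision.datasets.FakeData(size=128, image_size=(3, 224, 224),
                                          num_classes=10,
                                          transform=transforms.ToTensor())
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=32)

acc, loss = validate(model, num_classes=10, loader=valid_loader)
print(f'top-1 accuracy: {acc:.2f}%, loss: {loss:.4f}')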
Example No. 4
    def _validate(self,
                  full=True,
                  print_prefix='Validate',
                  indent=0,
                  verbose=True,
                  loader: torch.utils.data.DataLoader = None,
                  get_data_fn: Callable[..., tuple[torch.Tensor,
                                                   torch.Tensor]] = None,
                  loss_fn: Callable[..., torch.Tensor] = None,
                  writer=None,
                  main_tag: str = 'valid',
                  tag: str = '',
                  _epoch: int = None,
                  **kwargs) -> tuple[float, float]:
        self.eval()
        if loader is None:
            loader = (self.dataset.loader['valid'] if full
                      else self.dataset.loader['valid2'])
        get_data_fn = get_data_fn if get_data_fn is not None else self.get_data
        loss_fn = loss_fn if loss_fn is not None else self.loss
        logger = MetricLogger()
        logger.meters['loss'] = SmoothedValue()
        logger.meters['top1'] = SmoothedValue()
        logger.meters['top5'] = SmoothedValue()
        loader_epoch = loader
        if verbose:
            header = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
            header = header.ljust(
                max(len(print_prefix), 30) + get_ansi_len(header))
            if env['tqdm']:
                header = '{upline}{clear_line}'.format(**ansi) + header
                loader_epoch = tqdm(loader_epoch)
            loader_epoch = logger.log_every(loader_epoch,
                                            header=header,
                                            indent=indent)
        for data in loader_epoch:
            _input, _label = get_data_fn(data, mode='valid', **kwargs)
            with torch.no_grad():
                _output = self(_input)
                loss = float(loss_fn(_input, _label, _output=_output,
                                     **kwargs))
                acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                logger.meters['loss'].update(loss, batch_size)
                logger.meters['top1'].update(acc1, batch_size)
                logger.meters['top5'].update(acc5, batch_size)
        loss, acc = (logger.meters['loss'].global_avg,
                     logger.meters['top1'].global_avg)
        if writer is not None and _epoch is not None and main_tag:
            from torch.utils.tensorboard import SummaryWriter
            assert isinstance(writer, SummaryWriter)
            writer.add_scalars(main_tag='Loss/' + main_tag,
                               tag_scalar_dict={tag: loss},
                               global_step=_epoch)
            writer.add_scalars(main_tag='Acc/' + main_tag,
                               tag_scalar_dict={tag: acc},
                               global_step=_epoch)
        return loss, acc
Example No. 5
    def train_mask_generator(self, verbose: bool = True):
        r"""Train :attr:`self.mask_generator`."""
        optimizer = torch.optim.Adam(self.mask_generator.parameters(), lr=1e-2, betas=(0.5, 0.9))
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.train_mask_epochs)
        loader = self.dataset.loader['train']
        dataset = loader.dataset
        logger = MetricLogger()
        logger.create_meters(loss=None, div=None, norm=None)
        print_prefix = 'Mask Epoch'
        for _epoch in range(self.train_mask_epochs):
            _epoch += 1
            idx = torch.randperm(len(dataset))
            pos = 0
            logger.reset()
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                print_prefix, output_iter(_epoch, self.train_mask_epochs), **ansi)
            header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
            self.mask_generator.train()
            for data in logger.log_every(loader, header=header) if verbose else loader:
                optimizer.zero_grad()
                _input, _label = self.model.get_data(data)
                batch_size = len(_input)
                data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
                _input2, _label2 = self.model.get_data(data2)
                pos += batch_size

                _mask = self.get_mask(_input)
                _mask2 = self.get_mask(_input2)

                input_dist: torch.Tensor = (_input - _input2).flatten(1).norm(p=2, dim=1)
                mask_dist: torch.Tensor = (_mask - _mask2).flatten(1).norm(p=2, dim=1) + 1e-5

                loss_div = input_dist.div(mask_dist).mean()
                loss_norm = _mask.sub(self.mask_density).relu().mean()

                loss = self.lambda_norm * loss_norm + self.lambda_div * loss_div
                loss.backward()
                optimizer.step()
                logger.update(n=batch_size, loss=loss.item(), div=loss_div.item(), norm=loss_norm.item())
            lr_scheduler.step()
            self.mask_generator.eval()
            if verbose and (_epoch % (max(self.train_mask_epochs // 5, 1)) == 0 or _epoch == self.train_mask_epochs):
                self.validate_mask_generator()
        optimizer.zero_grad()
Example No. 6
def compare(module1: nn.Module,
            module2: nn.Module,
            loader: torch.utils.data.DataLoader,
            print_prefix='Validate',
            indent=0,
            verbose=True,
            get_data_fn: Callable[..., tuple[torch.Tensor,
                                             torch.Tensor]] = None,
            **kwargs) -> float:
    logsoftmax = nn.LogSoftmax(dim=1)
    softmax = nn.Softmax(dim=1)
    module1.eval()
    module2.eval()
    get_data_fn = get_data_fn if get_data_fn is not None else lambda x: x

    def cross_entropy(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
        result: torch.Tensor = -softmax(p) * logsoftmax(q)
        return result.sum(1).mean()

    logger = MetricLogger()
    logger.meters['loss'] = SmoothedValue()
    loader_epoch = loader
    if verbose:
        header = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
        header = header.ljust(
            max(len(print_prefix), 30) + get_ansi_len(header))
        if env['tqdm']:
            header = '{upline}{clear_line}'.format(**ansi) + header
            loader_epoch = tqdm(loader_epoch)
        loader_epoch = logger.log_every(loader_epoch,
                                        header=header,
                                        indent=indent)
    with torch.no_grad():
        for data in loader_epoch:
            _input, _label = get_data_fn(data, **kwargs)
            _output1, _output2 = module1(_input), module2(_input)
            loss = float(cross_entropy(_output1, _output2))
            batch_size = int(_label.size(0))
            logger.meters['loss'].update(loss, batch_size)
    return logger.meters['loss'].global_avg
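On PyTorch 1.10 and newer, F.cross_entropy also accepts class probabilities as the target, so the hand-rolled soft cross-entropy above can be sanity-checked against the built-in with random logits:

import torch
import torch.nn.functional as F

p = torch.randn(4, 10)  # logits of module1
q = torch.randn(4, 10)  # logits of module2

manual = (-p.softmax(dim=1) * F.log_softmax(q, dim=1)).sum(1).mean()
builtin = F.cross_entropy(q, p.softmax(dim=1))  # soft-label target
assert torch.allclose(manual, builtin, atol=1e-6)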
Example No. 7
    def attack(self, epochs: int, save: bool = False, **kwargs):
        # train generator
        match self.generator_mode:
            case 'resnet':
                resnet_model = torchvision.models.resnet18(pretrained=True).to(device=env['device'])
                model_extractor = nn.Sequential(*list(resnet_model.children())[:5])
            case _:
                resnet_model = trojanvision.models.create('resnet18_comp',
                                                          dataset=self.dataset,
                                                          pretrained=True)
                model_extractor = nn.Sequential(*list(resnet_model._model.features.children())[:4])
        model_extractor.requires_grad_(False)
        model_extractor.train()

        self.generator.train()
        if self.generator_mode == 'default':
            self.generator.requires_grad_()
            parameters = self.generator.parameters()
        else:
            self.generator.bottleneck.requires_grad_()
            self.generator.decoder.requires_grad_()
            parameters = self.generator[1:].parameters()
        optimizer = torch.optim.Adam(parameters, lr=1e-3)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.train_generator_epochs,
            eta_min=1e-5)

        logger = MetricLogger()
        logger.create_meters(loss=None, acc=None)
        for _epoch in range(self.train_generator_epochs):
            _epoch += 1
            logger.reset()
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                'Epoch', output_iter(_epoch, self.train_generator_epochs), **ansi)
            header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
            loader = logger.log_every(self.dataset.loader['train'], header=header,
                                      tqdm_header='Batch')
            for data in loader:
                optimizer.zero_grad()
                _input, _label = self.model.get_data(data)
                adv_input = (self.generator(_input) + 1) / 2

                _feats = model_extractor(_input)
                adv_feats = model_extractor(adv_input)

                loss = F.l1_loss(_feats, self.noise_coeff * adv_feats)
                loss.backward()
                optimizer.step()
                batch_size = len(_label)
                with torch.no_grad():
                    org_class = self.model.get_class(_input)
                    adv_class = self.model.get_class(adv_input)
                    acc = (org_class == adv_class).float().sum().item() * 100.0 / batch_size
                logger.update(n=batch_size, loss=loss.item(), acc=acc)
            lr_scheduler.step()
        self.save_generator()
        optimizer.zero_grad()
        self.generator.eval()
        self.generator.requires_grad_(False)

        self.mark.mark[:-1] = (self.generator(self.mark.mark[:-1].unsqueeze(0))[0] + 1) / 2
        self.poison_set = self.get_poison_dataset(load_mark=False)
        return super().attack(epochs, save=save, **kwargs)
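The generator objective above is a feature-matching L1 loss against a frozen shallow feature extractor: the adversarial images are pushed to produce features that match a rescaled version of the clean features. A minimal sketch of just that loss with a toy extractor (the layer sizes and noise coefficient are placeholders, not values from the example):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy frozen extractor standing in for the first ResNet stages.
extractor = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                          nn.Conv2d(16, 32, 3, stride=2, padding=1))
extractor.requires_grad_(False)

noise_coeff = 0.35  # placeholder value
_input = torch.rand(4, 3, 32, 32)
adv_input = torch.rand(4, 3, 32, 32, requires_grad=True)  # stands in for (generator(x) + 1) / 2

loss = F.l1_loss(extractor(_input), noise_coeff * extractor(adv_input))
loss.backward()  # gradients reach adv_input, i.e. the generator in the example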
Example No. 8
    def attack(self, epochs: int, optimizer: torch.optim.Optimizer,
               lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
               validate_interval: int = 1, save: bool = False,
               verbose: bool = True, **kwargs):
        if verbose:
            print('train mask generator')
        self.mark_generator.requires_grad_(False)
        self.mask_generator.requires_grad_()
        self.model.requires_grad_(False)
        self.train_mask_generator(verbose=verbose)
        if verbose:
            print()
            print('train mark generator and model')

        self.mark_generator.requires_grad_()
        self.mask_generator.requires_grad_(False)
        if not self.natural:
            params: list[nn.Parameter] = []
            for param_group in optimizer.param_groups:
                params.extend(param_group['params'])
            self.model.activate_params(params)

        mark_optimizer = torch.optim.Adam(self.mark_generator.parameters(), lr=1e-2, betas=(0.5, 0.9))
        mark_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            mark_optimizer, T_max=epochs)
        loader = self.dataset.loader['train']
        dataset = loader.dataset
        logger = MetricLogger()
        logger.create_meters(loss=None, div=None, ce=None)

        if validate_interval != 0:
            best_validate_result = self.validate_fn(verbose=verbose)
            best_asr = best_validate_result[0]
        for _epoch in range(epochs):
            _epoch += 1
            idx = torch.randperm(len(dataset))
            pos = 0
            logger.reset()
            if not self.natural:
                self.model.train()
            self.mark_generator.train()
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                'Epoch', output_iter(_epoch, epochs), **ansi)
            header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
            for data in logger.log_every(loader, header=header) if verbose else loader:
                if not self.natural:
                    optimizer.zero_grad()
                mark_optimizer.zero_grad()
                _input, _label = self.model.get_data(data)
                batch_size = len(_input)
                data2 = sample_batch(dataset, idx=idx[pos:pos + batch_size])
                _input2, _label2 = self.model.get_data(data2)
                pos += batch_size
                final_input, final_label = _input.clone(), _label.clone()

                # generate trigger input
                trigger_dec, trigger_int = math.modf(len(_label) * self.poison_percent)
                trigger_int = int(trigger_int)
                if random.uniform(0, 1) < trigger_dec:
                    trigger_int += 1
                x = _input[:trigger_int]
                trigger_mark, trigger_mask = self.get_mark(x), self.get_mask(x)
                trigger_input = x + trigger_mask * (trigger_mark - x)
                final_input[:trigger_int] = trigger_input
                final_label[:trigger_int] = self.target_class

                # generate cross input
                cross_dec, cross_int = math.modf(len(_label) * self.cross_percent)
                cross_int = int(cross_int)
                if random.uniform(0, 1) < cross_dec:
                    cross_int += 1
                x = _input[trigger_int:trigger_int + cross_int]
                x2 = _input2[trigger_int:trigger_int + cross_int]
                cross_mark, cross_mask = self.get_mark(x2), self.get_mask(x2)
                cross_input = x + cross_mask * (cross_mark - x)
                final_input[trigger_int:trigger_int + cross_int] = cross_input

                # div loss
                if len(trigger_input) <= len(cross_input):
                    length = len(trigger_input)
                    cross_input = cross_input[:length]
                    cross_mark = cross_mark[:length]
                    cross_mask = cross_mask[:length]
                else:
                    length = len(cross_input)
                    trigger_input = trigger_input[:length]
                    trigger_mark = trigger_mark[:length]
                    trigger_mask = trigger_mask[:length]
                input_dist: torch.Tensor = (trigger_input - cross_input).flatten(1).norm(p=2, dim=1)
                mark_dist: torch.Tensor = (trigger_mark - cross_mark).flatten(1).norm(p=2, dim=1) + 1e-5

                loss_ce = self.model.loss(final_input, final_label)
                loss_div = input_dist.div(mark_dist).mean()

                loss = loss_ce + self.lambda_div * loss_div
                loss.backward()
                if not self.natural:
                    optimizer.step()
                mark_optimizer.step()
                logger.update(n=batch_size, loss=loss.item(), div=loss_div.item(), ce=loss_ce.item())
            if not self.natural and lr_scheduler:
                lr_scheduler.step()
            mark_scheduler.step()
            if not self.natural:
                self.model.eval()
            self.mark_generator.eval()
            if validate_interval != 0 and (_epoch % validate_interval == 0 or _epoch == epochs):
                validate_result = self.validate_fn(verbose=verbose)
                cur_asr = validate_result[0]
                if cur_asr >= best_asr:
                    best_validate_result = validate_result
                    best_asr = cur_asr
                    if save:
                        self.save()
        if not self.natural:
            optimizer.zero_grad()
        mark_optimizer.zero_grad()
        self.mark_generator.requires_grad_(False)
        self.mask_generator.requires_grad_(False)
        self.model.requires_grad_(False)
        return best_validate_result
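The math.modf / random.uniform pattern used for trigger_int and cross_int is stochastic rounding: the per-batch poison count is rounded up or down at random so that its expectation equals the exact fractional value batch_size * percent. A minimal sketch of just that step, with a hypothetical helper name:

import math
import random

def stochastic_count(batch_size: int, fraction: float) -> int:
    # Split batch_size * fraction into fractional and integer parts,
    # then round up with probability equal to the fractional part.
    frac_part, int_part = math.modf(batch_size * fraction)
    count = int(int_part)
    if random.uniform(0, 1) < frac_part:
        count += 1
    return count

# e.g. 10% of a 64-sample batch: 6 or 7 samples, 6.4 in expectation.
print(stochastic_count(64, 0.1))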
Example No. 9
def train(module: nn.Module,
          num_classes: int,
          epochs: int,
          optimizer: Optimizer,
          lr_scheduler: _LRScheduler = None,
          lr_warmup_epochs: int = 0,
          model_ema: ExponentialMovingAverage = None,
          model_ema_steps: int = 32,
          grad_clip: float = None,
          pre_conditioner: None | KFAC | EKFAC = None,
          print_prefix: str = 'Train',
          start_epoch: int = 0,
          resume: int = 0,
          validate_interval: int = 10,
          save: bool = False,
          amp: bool = False,
          loader_train: torch.utils.data.DataLoader = None,
          loader_valid: torch.utils.data.DataLoader = None,
          epoch_fn: Callable[..., None] = None,
          get_data_fn: Callable[..., tuple[torch.Tensor, torch.Tensor]] = None,
          forward_fn: Callable[..., torch.Tensor] = None,
          loss_fn: Callable[..., torch.Tensor] = None,
          after_loss_fn: Callable[..., None] = None,
          validate_fn: Callable[..., tuple[float, float]] = None,
          save_fn: Callable[..., None] = None,
          file_path: str = None,
          folder_path: str = None,
          suffix: str = None,
          writer=None,
          main_tag: str = 'train',
          tag: str = '',
          accuracy_fn: Callable[..., list[float]] = None,
          verbose: bool = True,
          output_freq: str = 'iter',
          indent: int = 0,
          change_train_eval: bool = True,
          lr_scheduler_freq: str = 'epoch',
          backward_and_step: bool = True,
          **kwargs):
    r"""Train the model"""
    if epochs <= 0:
        return
    get_data_fn = get_data_fn or (lambda x, **kwargs: x)
    forward_fn = forward_fn or module.__call__
    loss_fn = loss_fn or (lambda _input, _label, _output=None, **kwargs:
                          F.cross_entropy(_output if _output is not None
                                          else forward_fn(_input), _label))
    validate_fn = validate_fn or validate
    accuracy_fn = accuracy_fn or accuracy

    scaler: torch.cuda.amp.GradScaler = None
    if not env['num_gpus']:
        amp = False
    if amp:
        scaler = torch.cuda.amp.GradScaler()
    best_validate_result = (0.0, float('inf'))
    if validate_interval != 0:
        best_validate_result = validate_fn(module=module,
                                           num_classes=num_classes,
                                           loader=loader_valid,
                                           get_data_fn=get_data_fn,
                                           forward_fn=forward_fn,
                                           loss_fn=loss_fn,
                                           writer=None,
                                           tag=tag,
                                           _epoch=start_epoch,
                                           verbose=verbose,
                                           indent=indent,
                                           **kwargs)
        best_acc = best_validate_result[0]

    params: list[nn.Parameter] = []
    for param_group in optimizer.param_groups:
        params.extend(param_group['params'])
    len_loader_train = len(loader_train)
    total_iter = (epochs - resume) * len_loader_train

    logger = MetricLogger()
    logger.create_meters(loss=None, top1=None, top5=None)

    if resume and lr_scheduler:
        for _ in range(resume):
            lr_scheduler.step()
    iterator = range(resume, epochs)
    if verbose and output_freq == 'epoch':
        header: str = '{blue_light}{0}: {reset}'.format(print_prefix, **ansi)
        header = header.ljust(max(len(print_prefix), 30) + get_ansi_len(header))
        iterator = logger.log_every(range(resume, epochs),
                                    header=header,
                                    tqdm_header='Epoch',
                                    indent=indent)
    for _epoch in iterator:
        _epoch += 1
        logger.reset()
        if callable(epoch_fn):
            activate_params(module, [])
            epoch_fn(optimizer=optimizer,
                     lr_scheduler=lr_scheduler,
                     _epoch=_epoch,
                     epochs=epochs,
                     start_epoch=start_epoch)
        loader_epoch = loader_train
        if verbose and output_freq == 'iter':
            header: str = '{blue_light}{0}: {1}{reset}'.format(
                'Epoch', output_iter(_epoch, epochs), **ansi)
            header = header.ljust(max(len('Epoch'), 30) + get_ansi_len(header))
            loader_epoch = logger.log_every(loader_train,
                                            header=header,
                                            tqdm_header='Batch',
                                            indent=indent)
        if change_train_eval:
            module.train()
        activate_params(module, params)
        for i, data in enumerate(loader_epoch):
            _iter = _epoch * len_loader_train + i
            # data_time.update(time.perf_counter() - end)
            _input, _label = get_data_fn(data, mode='train')
            if pre_conditioner is not None and not amp:
                pre_conditioner.track.enable()
            _output = forward_fn(_input, amp=amp, parallel=True)
            loss = loss_fn(_input, _label, _output=_output, amp=amp)
            if backward_and_step:
                optimizer.zero_grad()
                if amp:
                    scaler.scale(loss).backward()
                    if callable(after_loss_fn) or grad_clip is not None:
                        scaler.unscale_(optimizer)
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                        # start_epoch=start_epoch, _epoch=_epoch, epochs=epochs)
                    if pre_conditioner is not None:
                        pre_conditioner.track.disable()
                        pre_conditioner.step()
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    optimizer.step()

            if model_ema and i % model_ema_steps == 0:
                model_ema.update_parameters(module)
                if _epoch <= lr_warmup_epochs:
                    # Reset ema buffer to keep copying weights
                    # during warmup period
                    model_ema.n_averaged.fill_(0)

            if lr_scheduler and lr_scheduler_freq == 'iter':
                lr_scheduler.step()
            acc1, acc5 = accuracy_fn(_output,
                                     _label,
                                     num_classes=num_classes,
                                     topk=(1, 5))
            batch_size = int(_label.size(0))
            logger.update(n=batch_size, loss=float(loss), top1=acc1, top5=acc5)
            empty_cache()
        optimizer.zero_grad()
        if lr_scheduler and lr_scheduler_freq == 'epoch':
            lr_scheduler.step()
        if change_train_eval:
            module.eval()
        activate_params(module, [])
        loss, acc = (logger.meters['loss'].global_avg,
                     logger.meters['top1'].global_avg)
        if writer is not None:
            from torch.utils.tensorboard import SummaryWriter
            assert isinstance(writer, SummaryWriter)
            writer.add_scalars(main_tag='Loss/' + main_tag,
                               tag_scalar_dict={tag: loss},
                               global_step=_epoch + start_epoch)
            writer.add_scalars(main_tag='Acc/' + main_tag,
                               tag_scalar_dict={tag: acc},
                               global_step=_epoch + start_epoch)
        if validate_interval != 0 and (_epoch % validate_interval == 0
                                       or _epoch == epochs):
            validate_result = validate_fn(module=module,
                                          num_classes=num_classes,
                                          loader=loader_valid,
                                          get_data_fn=get_data_fn,
                                          forward_fn=forward_fn,
                                          loss_fn=loss_fn,
                                          writer=writer,
                                          tag=tag,
                                          _epoch=_epoch + start_epoch,
                                          verbose=verbose,
                                          indent=indent,
                                          **kwargs)
            cur_acc = validate_result[0]
            if cur_acc >= best_acc:
                best_validate_result = validate_result
                if verbose:
                    prints('{purple}best result update!{reset}'.format(**ansi),
                           indent=indent)
                    prints(
                        f'Current Acc: {cur_acc:.3f}    '
                        f'Previous Best Acc: {best_acc:.3f}',
                        indent=indent)
                best_acc = cur_acc
                if save:
                    save_fn(file_path=file_path,
                            folder_path=folder_path,
                            suffix=suffix,
                            verbose=verbose)
            if verbose:
                prints('-' * 50, indent=indent)
    module.zero_grad()
    return best_validate_result
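A minimal usage sketch of train, assuming the library helpers it relies on (MetricLogger, activate_params, accuracy, empty_cache, env, and the validate function above) are importable. The model and FakeData loaders are placeholders, and the forward_fn wrapper simply drops the amp/parallel keywords that the training loop forwards to it:

import torch
import torchvision
from torchvision import transforms

# Placeholder classifier and loaders so the call runs end to end on CPU.
model = torchvision.models.resnet18(num_classes=10)
train_set = torchvision.datasets.FakeData(size=256, image_size=(3, 224, 224),
                                          num_classes=10,
                                          transform=transforms.ToTensor())
valid_set = torchvision.datasets.FakeData(size=64, image_size=(3, 224, 224),
                                          num_classes=10,
                                          transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=32)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)
best_acc, best_loss = train(model,
                            num_classes=10,
                            epochs=5,
                            optimizer=optimizer,
                            lr_scheduler=scheduler,
                            loader_train=train_loader,
                            loader_valid=valid_loader,
                            forward_fn=lambda x, **kwargs: model(x),
                            validate_interval=5)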
Example No. 10
    def _train(self,
               epoch: int,
               optimizer: Optimizer,
               lr_scheduler: _LRScheduler = None,
               grad_clip: float = None,
               print_prefix: str = 'Epoch',
               start_epoch: int = 0,
               validate_interval: int = 10,
               save: bool = False,
               amp: bool = False,
               loader_train: torch.utils.data.DataLoader = None,
               loader_valid: torch.utils.data.DataLoader = None,
               epoch_fn: Callable[..., None] = None,
               get_data_fn: Callable[..., tuple[torch.Tensor,
                                                torch.Tensor]] = None,
               loss_fn: Callable[..., torch.Tensor] = None,
               after_loss_fn: Callable[..., None] = None,
               validate_fn: Callable[..., tuple[float, float]] = None,
               save_fn: Callable[..., None] = None,
               file_path: str = None,
               folder_path: str = None,
               suffix: str = None,
               writer=None,
               main_tag: str = 'train',
               tag: str = '',
               verbose: bool = True,
               indent: int = 0,
               **kwargs):
        loader_train = (loader_train if loader_train is not None
                        else self.dataset.loader['train'])
        get_data_fn = get_data_fn if callable(get_data_fn) else self.get_data
        loss_fn = loss_fn if callable(loss_fn) else self.loss
        validate_fn = validate_fn if callable(validate_fn) else self._validate
        save_fn = save_fn if callable(save_fn) else self.save
        # if not callable(iter_fn) and hasattr(self, 'iter_fn'):
        #     iter_fn = getattr(self, 'iter_fn')
        if not callable(epoch_fn) and hasattr(self, 'epoch_fn'):
            epoch_fn = getattr(self, 'epoch_fn')
        if not callable(after_loss_fn) and hasattr(self, 'after_loss_fn'):
            after_loss_fn = getattr(self, 'after_loss_fn')

        scaler: torch.cuda.amp.GradScaler = None
        if not env['num_gpus']:
            amp = False
        if amp:
            scaler = torch.cuda.amp.GradScaler()
        _, best_acc = validate_fn(loader=loader_valid,
                                  get_data_fn=get_data_fn,
                                  loss_fn=loss_fn,
                                  writer=None,
                                  tag=tag,
                                  _epoch=start_epoch,
                                  verbose=verbose,
                                  indent=indent,
                                  **kwargs)

        params: list[nn.Parameter] = []
        for param_group in optimizer.param_groups:
            params.extend(param_group['params'])
        total_iter = epoch * len(loader_train)
        for _epoch in range(epoch):
            _epoch += 1
            if callable(epoch_fn):
                self.activate_params([])
                epoch_fn(optimizer=optimizer,
                         lr_scheduler=lr_scheduler,
                         _epoch=_epoch,
                         epoch=epoch,
                         start_epoch=start_epoch)
                self.activate_params(params)
            logger = MetricLogger()
            logger.meters['loss'] = SmoothedValue()
            logger.meters['top1'] = SmoothedValue()
            logger.meters['top5'] = SmoothedValue()
            loader_epoch = loader_train
            if verbose:
                header = '{blue_light}{0}: {1}{reset}'.format(
                    print_prefix, output_iter(_epoch, epoch), **ansi)
                header = header.ljust(30 + get_ansi_len(header))
                if env['tqdm']:
                    header = '{upline}{clear_line}'.format(**ansi) + header
                    loader_epoch = tqdm(loader_epoch)
                loader_epoch = logger.log_every(loader_epoch,
                                                header=header,
                                                indent=indent)
            self.train()
            self.activate_params(params)
            optimizer.zero_grad()
            for i, data in enumerate(loader_epoch):
                _iter = _epoch * len(loader_train) + i
                # data_time.update(time.perf_counter() - end)
                _input, _label = get_data_fn(data, mode='train')
                _output = self(_input, amp=amp)
                loss = loss_fn(_input, _label, _output=_output, amp=amp)
                if amp:
                    scaler.scale(loss).backward()
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    if grad_clip is not None:
                        nn.utils.clip_grad_norm_(params, grad_clip)
                    if callable(after_loss_fn):
                        after_loss_fn(_input=_input,
                                      _label=_label,
                                      _output=_output,
                                      loss=loss,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      amp=amp,
                                      scaler=scaler,
                                      _iter=_iter,
                                      total_iter=total_iter)
                        # start_epoch=start_epoch, _epoch=_epoch, epoch=epoch)
                    optimizer.step()
                optimizer.zero_grad()
                acc1, acc5 = self.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                logger.meters['loss'].update(float(loss), batch_size)
                logger.meters['top1'].update(acc1, batch_size)
                logger.meters['top5'].update(acc5, batch_size)
                empty_cache()  # TODO: should it be outside of the dataloader loop?
            self.eval()
            self.activate_params([])
            loss, acc = (logger.meters['loss'].global_avg,
                         logger.meters['top1'].global_avg)
            if writer is not None:
                from torch.utils.tensorboard import SummaryWriter
                assert isinstance(writer, SummaryWriter)
                writer.add_scalars(main_tag='Loss/' + main_tag,
                                   tag_scalar_dict={tag: loss},
                                   global_step=_epoch + start_epoch)
                writer.add_scalars(main_tag='Acc/' + main_tag,
                                   tag_scalar_dict={tag: acc},
                                   global_step=_epoch + start_epoch)
            if lr_scheduler:
                lr_scheduler.step()
            if validate_interval != 0:
                if _epoch % validate_interval == 0 or _epoch == epoch:
                    _, cur_acc = validate_fn(loader=loader_valid,
                                             get_data_fn=get_data_fn,
                                             loss_fn=loss_fn,
                                             writer=writer,
                                             tag=tag,
                                             _epoch=_epoch + start_epoch,
                                             verbose=verbose,
                                             indent=indent,
                                             **kwargs)
                    if cur_acc >= best_acc:
                        if verbose:
                            prints('{green}best result update!{reset}'.format(**ansi),
                                   indent=indent)
                            prints(
                                f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}',
                                indent=indent)
                        best_acc = cur_acc
                        if save:
                            save_fn(file_path=file_path,
                                    folder_path=folder_path,
                                    suffix=suffix,
                                    verbose=verbose)
                    if verbose:
                        prints('-' * 50, indent=indent)
        self.zero_grad()