def training_step(self, batch, batch_idx):
        """Run one training iteration; return the loss and a logging dict.

        If ``self.attack_class`` is set, the batch is first replaced by the
        output of ``self.attack`` (adversarial training). If
        ``self.mixaugment`` is set it computes the loss and the ``retdict``
        used for logging; otherwise a plain forward pass through
        ``self.model`` followed by ``self.criterion`` is used.

        When ``batch_idx == 1`` the first 32 entries of ``retdict["x"]`` are
        written to ``train_img_sample.png`` in the working directory.

        Returns:
            dict with ``loss`` (still attached, so Lightning can run backward)
            and ``log`` holding ``train_loss``, ``train_err1`` and
            ``train_err5``.
        """
        inputs, targets = batch

        if self.attack_class:
            # Lightning picks the device after construction, so the attack is
            # instantiated lazily here on first use.
            if self.attack is None:
                self.attack = self.attack_class(device=self.device)
            inputs = self.attack(self.model, inputs, targets)
            # Drop gradients accumulated while crafting the attack; the
            # original author notes that skipping this yields NaN gradients.
            self.model.zero_grad()

        if not self.mixaugment:
            logits = self.model(inputs)
            loss = self.criterion(logits, targets)
            retdict = {
                "x": inputs.detach(),
                "output": logits.detach(),
                "loss": loss.detach(),
            }
        else:
            loss, retdict = self.mixaugment(
                self.model, self.criterion, inputs, targets
            )

        # Save a sample of the inputs recorded in retdict for inspection.
        if batch_idx == 1:
            torchvision.utils.save_image(
                retdict["x"][:32], "train_img_sample.png"
            )

        err1, err5 = calc_error(retdict["output"], targets.detach(), topk=(1, 5))
        log = {
            "train_loss": retdict["loss"],
            "train_err1": err1,
            "train_err5": err5,
        }
        # ``loss`` is returned un-detached because it drives the backward pass.
        return {"loss": loss, "log": log}
    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and return its loss and error rates.

        Returns:
            dict with ``val_loss`` (detached, since no backward pass is run
            during validation), ``val_err1`` and ``val_err5``.
        """
        inputs, targets = batch

        logits = self.model(inputs)
        loss = self.criterion(logits, targets)

        err1, err5 = calc_error(logits.detach(), targets.detach(), topk=(1, 5))
        return {
            "val_loss": loss.detach(),
            "val_err1": err1,
            "val_err5": err5,
        }
def calc_mean_error(model: torch.nn.Module,
                    loader: torch.utils.data.DataLoader,
                    device: str) -> Tuple[float, float]:
    """
    Calculate the top-1 and top-5 error of ``model`` over ``loader``.

    Each batch is moved to ``device`` and evaluated under ``torch.no_grad()``.
    The per-batch errors returned by ``shared.calc_error`` are weighted by
    batch size, so a smaller final batch does not skew the dataset-level mean
    (assumes ``calc_error`` returns the batch-averaged error). The model's
    train/eval mode is not changed here; call ``model.eval()`` beforehand if
    needed.

    Args:
        model: network to evaluate.
        loader: yields ``(inputs, targets)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(mean_err1, mean_err5)`` in the same units ``calc_error`` uses.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    err1_total = err5_total = 0.0
    num_samples = 0
    with torch.no_grad():
        for x, t in loader:
            x, t = x.to(device), t.to(device)
            output = model(x)
            err1, err5 = shared.calc_error(output, t, topk=(1, 5))

            # Weight each batch's mean error by its size -> per-sample mean.
            batch_size = x.size(0)
            err1_total += err1.item() * batch_size
            err5_total += err5.item() * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("loader yielded no samples; cannot compute mean error")
    return err1_total / num_samples, err5_total / num_samples
Esempio n. 4
0
    def test_valid_input(self):
        """calc_error yields 0% / 100% for fully correct / fully wrong inputs.

        Every case scores 16 samples over 10 classes with class 0 highest
        (1.0); in the top-5 cases class 1 is the runner-up (0.1).
        """

        def make_output(with_runner_up):
            scores = torch.zeros(16, 10)
            scores[:, 0] = 1.0
            if with_runner_up:
                scores[:, 1] = 0.1
            return scores

        # (description, runner-up present?, target, topk, expected error per k)
        cases = [
            ("top-1 (all correct)",
             False, torch.zeros(16), (1,), (0.0,)),
            ("top-1 (all wrong)",
             False, torch.ones(16), (1,), (100.0,)),
            ("top-1 (all correct) and top-5 (all correct)",
             True, torch.zeros(16), (1, 5), (0.0, 0.0)),
            ("top-1 (all wrong) and top-5 (all correct)",
             True, torch.ones(16), (1, 5), (100.0, 0.0)),
        ]
        for description, with_runner_up, target, topk, expected in cases:
            errors = calc_error(make_output(with_runner_up), target, topk=topk)
            for k_index, value in enumerate(expected):
                assert errors[k_index].equal(
                    torch.Tensor([value]).float()
                ), description
Esempio n. 5
0
    loader = torch.utils.data.DataLoader(dataset,
                                         32,
                                         shuffle=False,
                                         num_workers=8)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    weightpath = "../testdata/weight_cifar10_wideresnet40_100ep.pth"
    model = shared.get_model(name="wideresnet40", num_classes=10)
    shared.load_model(model, weightpath)
    model = model.to(device)
    model.eval()

    for x, t in loader:
        x, t = x.to(device), t.to(device)
        output = model(x)
        err1 = shared.calc_error(output.detach(), t, topk=(1, ))

        attack = PgdAttack(
            input_size,
            mean,
            std,
            num_iteration,
            eps_max,
            step_size,
            norm,
            rand_init,
            scale_eps,
            scale_each,
            avoid_target,
            criterion,
            device,
Esempio n. 6
0
def _sample_rgb_gaussian(x: torch.Tensor, device: str) -> torch.Tensor:
    """Draw zero-mean Gaussian noise with a random std per colour channel.

    Three stds are drawn from U(0, 1) with ``random.uniform`` (r, g, b order),
    then one ``(B, H, W)`` noise map per channel is drawn with ``torch.normal``
    (same order) and moved to ``device``. The maps are stacked to
    ``(B, 3, H, W)``.
    """
    scales = [random.uniform(0, 1) for _ in range(3)]
    size = [x.size(0), x.size(-2), x.size(-1)]
    channels = [
        torch.normal(mean=0.0, std=scale, size=size).to(device)
        for scale in scales
    ]
    return torch.stack(channels, dim=1)


def _scale_to_channel_norm(gaussian: torch.Tensor, eps: float) -> torch.Tensor:
    """Return ``gaussian`` with every (sample, channel) slice at L2 norm ``eps``."""
    batch, num_channels = gaussian.size(0), gaussian.size(1)
    norms = gaussian.reshape(batch, num_channels, -1).norm(dim=-1)  # (B, C)
    return gaussian / norms[:, :, None, None] * eps


def calc_mean_error(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    normalizer,
    denormalizer,
    device: str,
    bandwidth: int,
    filter_mode: str,
    eps: float,
) -> Tuple[float, float, "torch.Tensor | None"]:
    """
    Calculate top-1 and top-5 error of ``model`` on ``loader`` under noise.

    For every batch, zero-mean Gaussian noise with a per-channel std drawn
    from U(0, 1) is generated (3 channels; assumes inputs are RGB — TODO
    confirm). With ``W = x.size(-1)``, ``bandwidth`` selects how it is used:

    * ``0``: the noise is unused and the clean input is evaluated.
    * ``0 < bandwidth < W``: the noise is filtered with
      ``fourier.bandpass_filter(noise, bandwidth, filter_mode, eps)`` and added
      to ``denormalizer(x)``, clamped to [0, 1], then re-normalized.
    * ``bandwidth == W``: each (sample, channel) noise slice is rescaled to
      L2 norm ``eps`` and added the same way.

    Args:
        model: network to evaluate.
        loader: yields ``(inputs, targets)`` batches.
        normalizer / denormalizer: map between model-input and [0, 1] space.
        device: device the batches and noise are moved to.
        bandwidth: see above.
        filter_mode, eps: forwarded to ``fourier.bandpass_filter``; ``eps`` is
            also the target norm in the full-band case.

    Returns:
        ``(mean_err1, mean_err5, x_sample)``. Errors are per-sample means
        (each batch weighted by its size; assumes ``shared.calc_error`` returns
        the batch-averaged error). ``x_sample`` is, in the band-pass case, the
        first 16 rows of ``[denormalizer(x), noise, band-passed noise, w]``
        from the first batch, concatenated along dim -2; otherwise ``None``.

    Raises:
        ValueError: if ``bandwidth`` is outside ``[0, x.size(-1)]`` or
            ``loader`` yields no samples.
    """
    err1_total = err5_total = 0.0
    num_samples = 0
    x_sample = None
    with torch.no_grad():
        for i, (x, t) in enumerate(loader):
            x, t = x.to(device), t.to(device)
            width = x.size(-1)
            if not 0 <= bandwidth <= width:
                raise ValueError(
                    f"bandwidth must be in [0, {width}] (x.size(-1)), "
                    f"got {bandwidth!r}"
                )

            # Noise is drawn even when bandwidth == 0 (and then unused) so the
            # Python/torch RNG streams advance identically for every setting.
            gaussian = _sample_rgb_gaussian(x, device)

            is_bandpass = 0 < bandwidth < width
            if is_bandpass:
                bandpassed_gaussian, w = fourier.bandpass_filter(
                    gaussian, bandwidth, filter_mode, eps)
                x = normalizer(
                    torch.clamp(denormalizer(x) + bandpassed_gaussian, 0.0, 1.0))
            elif bandwidth == width:
                noise = _scale_to_channel_norm(gaussian, eps)
                x = normalizer(torch.clamp(denormalizer(x) + noise, 0.0, 1.0))

            output = model(x)
            err1, err5 = shared.calc_error(output, t, topk=(1, 5))

            batch_size = x.size(0)
            err1_total += err1.item() * batch_size
            err5_total += err5.item() * batch_size
            num_samples += batch_size

            # Only the band-pass branch defines bandpassed_gaussian / w.
            if is_bandpass and i == 0:
                x_sample = torch.cat(
                    [denormalizer(x), gaussian, bandpassed_gaussian, w],
                    dim=-2)[0:16]

    if num_samples == 0:
        raise ValueError("loader yielded no samples; cannot compute mean error")
    return err1_total / num_samples, err5_total / num_samples, x_sample