    def infer_batch(self, images: torch.Tensor,
                    labels: torch.Tensor) -> Tuple[int, int]:
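        """Classify one batch and return (n_correct, n_incorrect) counts."""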
        images = images.to(model_device(self.model))
        labels = labels.to(model_device(self.model))

        log_probs = self.model(images)
        predictions = torch.argmax(log_probs, dim=1)

        n_correct = torch.sum(labels == predictions).item()
        n_incorrect = labels.size(0) - n_correct

        Tracker.progress()
        return n_correct, n_incorrect
    def load_best_epoch(self) -> None:
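        """Restore the model weights from the best epoch recorded during training."""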
        with open(self.from_out("checkpoints/best_epoch.txt"), "r") as file:
            epoch = int(file.read())

        checkpoint_path = self.from_out(f"checkpoints/checkpoint_{epoch}.pth")
        checkpoint = torch.load(checkpoint_path,
                                map_location=model_device(self.model))
        self.model.load_state_dict(checkpoint["model_state_dict"])
    def run_batch(self, images: torch.Tensor,
                  labels: torch.Tensor) -> List[Dict[str, Any]]:
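        """Run inference on one batch (optionally adversarially perturbed)
        and return per-image results."""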
        images = images.to(model_device(self.model))
        labels = labels.to(model_device(self.model))

        if self.adversary is None:
            output = None
        else:
            output = self.adversary(self.model, images, labels)
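            # On the first batch, save a few (original, noise, perturbed)
            # triptychs for inspection; the +0.5 presumably shifts
            # zero-centered noise into a displayable range.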
            if Tracker.batch == 1:
                for i in range(min(images.size(0), self.visualize_adversary)):
                    debug_img = torch.stack([
                        images[i], output.noises[i] + 0.5,
                        output.perturbed_images[i]
                    ],
                                            dim=0)
                    filepath = self.from_out(
                        f"inference_debug/{self.out_name}/image_{i + 1}.png")
                    torchvision.utils.save_image(debug_img, filepath, nrow=3)
            images = output.perturbed_images

        result = [{"true_label": label.item()} for label in labels]
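        # Record per-image scores: the Gaussian head exposes per-class log
        # likelihoods alongside the log posteriors, while the softmax head
        # provides only log probabilities.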
        if isinstance(self.model, ResNet_Gaussian):
            gaussian_output = self.model(images, GaussianMode.BOTH)
            for im_result, im_log_probs, im_log_likelihoods in zip(
                    result, gaussian_output.log_posteriors,
                    gaussian_output.log_likelihoods):
                im_result["log_probs"] = im_log_probs.tolist()
                im_result["log_likelihoods"] = im_log_likelihoods.tolist()
        else:
            assert isinstance(self.model, ResNet_Softmax)
            log_probs = self.model(images, SoftmaxMode.LOG_SOFTMAX)
            for im_result, im_log_probs in zip(result, log_probs):
                im_result["log_probs"] = im_log_probs.tolist()

        if isinstance(output, CarliniAndWagnerOutput):
            for im_result, noise_l2_norm in zip(result, output.l2_norm):
                im_result["noise_l2_norm"] = noise_l2_norm.item()

        Tracker.progress()
        return result
    def load_checkpoint(self, epoch: int) -> None:
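        """Resume model, optimizer, scheduler, and tracker state from a
        checkpoint."""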
        checkpoint_path = self.from_out(f"checkpoints/checkpoint_{epoch}.pth")
        checkpoint = torch.load(checkpoint_path,
                                map_location=model_device(self.model))

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        Tracker.epoch = checkpoint["epoch"] + 1
        Tracker.global_batch = checkpoint["global_batch"]
        Tracker.best_val_accuracy = checkpoint["best_val_accuracy"]

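        # Roll the logs directory back to the snapshot saved alongside this
        # checkpoint so that logging resumes from a consistent state.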
        shutil.rmtree(self.from_out("logs"))
        shutil.copytree(self.from_out(f"checkpoints/logs_{epoch}"),
                        self.from_out("logs"))
    def train_batch(self, images: torch.Tensor, labels: torch.Tensor) -> None:
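        """Run one optimization step on a training batch and log the loss."""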
        images = images.to(model_device(self.model))
        labels = labels.to(model_device(self.model))

        self.optimizer.zero_grad()
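        # For the first cfg.debug.visualize_inputs batches, dump the raw input
        # images and their labels to disk for visual debugging.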
        if Tracker.global_batch <= self.cfg.debug.visualize_inputs:
            torchvision.utils.save_image(
                images,
                self.from_out(
                    f"debug/inputs_batch_{Tracker.global_batch}.png"))
            lines = [str(label) for label in labels.tolist()]
            write_lines(
                self.from_out(
                    f"debug/inputs_batch_{Tracker.global_batch}.txt"), lines)

        outputs = self.model(images)
        loss = self.criterion(outputs, labels)
        loss.backward()
        self.optimizer.step()

        self.logger.add_scalar("loss",
                               loss.item(),
                               global_step=Tracker.global_batch)
        Tracker.progress()
    def __init__(self, cfg_filepath: str):
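        """Load the config, create output directories, build the model, and
        set up the PASCAL VOC data loaders."""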
        set_deterministic_seed()
        with open(from_root(cfg_filepath), "r") as file:
            self.cfg = DictObject(json.load(file))

        for sub_dirname in ("logs", "checkpoints", "debug"):
            os.makedirs(self.from_out(sub_dirname), exist_ok=True)

        self.model = RefineNet_4Cascaded()
        self.model = self.model.cuda()
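        # Placeholder path ("TODO"): pretrained backbone weights are expected
        # to be loaded here once a real checkpoint path is provided.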
        state_dict = torch.load("TODO", map_location=model_device(self.model))
        self.model.backbone.load_state_dict(state_dict)

        self.train_loader = load_pascal_voc_train(16)
        self.infer_train_loader = load_pascal_voc_infer("train", 16)
        self.infer_val_loader = load_pascal_voc_infer("val", 16)
    def init_surrogate_model(self, surrogate_cfg_filepath: str) -> None:
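        """Load a trained classifier from its config to serve as a frozen
        surrogate model."""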
        with open(from_root(surrogate_cfg_filepath), "r") as file:
            cfg = DictObject(json.load(file))

        self.surrogate_model = create_resnet(cfg)
        self.surrogate_model = self.surrogate_model.to(cfg.model.device)

        best_epoch_filepath = os.path.join(from_root(cfg.out_dirpath),
                                           "checkpoints/best_epoch.txt")
        with open(best_epoch_filepath, "r") as file:
            epoch = int(file.read())

        checkpoint_filepath = os.path.join(
            from_root(cfg.out_dirpath), f"checkpoints/checkpoint_{epoch}.pth")
        checkpoint = torch.load(checkpoint_filepath,
                                map_location=model_device(
                                    self.surrogate_model))
        self.surrogate_model.load_state_dict(checkpoint["model_state_dict"])

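        # The surrogate is used for inference only: switch it to eval mode and
        # freeze all of its parameters.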
        self.surrogate_model.eval()
        for param in self.surrogate_model.parameters():
            param.requires_grad = False
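
Note: the snippets above rely on a small model_device helper to keep input
tensors on the same device as the model. Its implementation is not shown here;
a minimal sketch of what such a helper typically looks like (an assumption,
not the repository's actual code) is:

import torch


def model_device(model: torch.nn.Module) -> torch.device:
    # Take the device of the first parameter as the model's device; this
    # assumes the model is not sharded across multiple devices.
    return next(model.parameters()).device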