def infer_batch(self, images: torch.Tensor, labels: torch.Tensor) -> Tuple[int, int]:
    """Score one batch and count correct / incorrect top-1 predictions.

    The batch is moved to the model's device and passed through
    ``self.model``. The arg-max over dimension 1 of the output is compared
    element-wise with ``labels``. ``Tracker.progress()`` is called once the
    batch has been scored.

    Args:
        images: Input batch, fed directly to ``self.model``.
        labels: Ground-truth class index per sample.

    Returns:
        ``(n_correct, n_incorrect)`` for this batch; the two values sum to
        ``labels.size(0)``.
    """
    device = model_device(self.model)
    images = images.to(device)
    labels = labels.to(device)
    # Pure evaluation: no gradients are needed, so don't build an autograd
    # graph, whatever grad mode the caller happens to be in.
    with torch.no_grad():
        log_probs = self.model(images)
    predictions = torch.argmax(log_probs, dim=1)
    n_correct = torch.sum(labels == predictions).item()
    n_incorrect = labels.size(0) - n_correct
    Tracker.progress()
    return n_correct, n_incorrect
def load_best_epoch(self) -> None:
    """Restore ``self.model`` weights from the best recorded checkpoint.

    The epoch number is read from ``checkpoints/best_epoch.txt`` in the
    output directory. The matching ``checkpoints/checkpoint_<epoch>.pth`` is
    then loaded onto the model's current device. Only
    ``"model_state_dict"`` is applied; optimizer, scheduler and ``Tracker``
    state are left untouched.
    """
    best_epoch_file = self.from_out("checkpoints/best_epoch.txt")
    with open(best_epoch_file, "r") as handle:
        best_epoch = int(handle.read())  # int() tolerates surrounding whitespace

    best_checkpoint_file = self.from_out(f"checkpoints/checkpoint_{best_epoch}.pth")
    saved_state = torch.load(
        best_checkpoint_file,
        map_location=model_device(self.model),
    )
    self.model.load_state_dict(saved_state["model_state_dict"])
def run_batch(self, images: torch.Tensor, labels: torch.Tensor) -> List[Dict[str, Any]]:
    """Evaluate one batch, optionally after an adversarial perturbation.

    If ``self.adversary`` is set, it is called as
    ``self.adversary(self.model, images, labels)``. The model is then
    evaluated on ``output.perturbed_images`` instead of the clean images.

    On the first batch (``Tracker.batch == 1``), up to
    ``self.visualize_adversary`` image triptychs are written to
    ``inference_debug/<out_name>/image_<i>.png`` under the output directory.
    Each triptych shows the clean image, the noise + 0.5 and the perturbed
    image.

    Args:
        images: Input batch.
        labels: Ground-truth label per sample.

    Returns:
        One dict per sample. Every dict has ``"true_label"`` and
        ``"log_probs"``. ``"log_likelihoods"`` is added for
        ``ResNet_Gaussian`` models, and ``"noise_l2_norm"`` is added when
        the adversary returns a ``CarliniAndWagnerOutput``.

    Raises:
        AssertionError: If the model is neither ``ResNet_Gaussian`` nor
            ``ResNet_Softmax``.
    """
    import os  # local import: only used to create the debug output directory

    device = model_device(self.model)
    images = images.to(device)
    labels = labels.to(device)

    if self.adversary is None:
        output = None
    else:
        # The attack runs with grad enabled; it may need gradients.
        output = self.adversary(self.model, images, labels)
        if Tracker.batch == 1:
            for i in range(min(images.size(0), self.visualize_adversary)):
                debug_img = torch.stack([
                    images[i],
                    # NOTE(review): offset presumably centres the signed
                    # noise for display — confirm.
                    output.noises[i] + 0.5,
                    output.perturbed_images[i],
                ], dim=0)
                filepath = self.from_out(
                    f"inference_debug/{self.out_name}/image_{i + 1}.png")
                # save_image does not create missing parent directories.
                os.makedirs(os.path.dirname(filepath), exist_ok=True)
                torchvision.utils.save_image(debug_img, filepath, nrow=3)
        images = output.perturbed_images

    result = [{"true_label": label.item()} for label in labels]

    # The scoring pass only extracts Python numbers (.tolist()), so no
    # autograd graph is needed here.
    with torch.no_grad():
        if isinstance(self.model, ResNet_Gaussian):
            gaussian_output = self.model(images, GaussianMode.BOTH)
            for im_result, im_log_probs, im_log_likelihoods in zip(
                    result, gaussian_output.log_posteriors,
                    gaussian_output.log_likelihoods):
                im_result["log_probs"] = im_log_probs.tolist()
                im_result["log_likelihoods"] = im_log_likelihoods.tolist()
        else:
            assert isinstance(self.model, ResNet_Softmax)
            log_probs = self.model(images, SoftmaxMode.LOG_SOFTMAX)
            for im_result, im_log_probs in zip(result, log_probs):
                im_result["log_probs"] = im_log_probs.tolist()

    if isinstance(output, CarliniAndWagnerOutput):
        for im_result, noise_l2_norm in zip(result, output.l2_norm):
            im_result["noise_l2_norm"] = noise_l2_norm.item()
    Tracker.progress()
    return result
def load_checkpoint(self, epoch: int) -> None:
    """Resume training state from ``checkpoints/checkpoint_<epoch>.pth``.

    This restores the model, optimizer and scheduler state dicts. It sets
    ``Tracker.epoch`` to the stored epoch + 1 and restores
    ``Tracker.global_batch`` and ``Tracker.best_val_accuracy``. Finally it
    replaces the ``logs`` directory with the snapshot saved in
    ``checkpoints/logs_<epoch>``.

    Args:
        epoch: Epoch number of the checkpoint to resume from.

    Raises:
        FileNotFoundError: If the ``logs_<epoch>`` snapshot or the
            checkpoint file is missing. The snapshot is checked before any
            state is loaded or any directory is removed.
    """
    import os  # local import: only used for the directory checks below

    checkpoint_path = self.from_out(f"checkpoints/checkpoint_{epoch}.pth")
    logs_snapshot = self.from_out(f"checkpoints/logs_{epoch}")
    logs_dir = self.from_out("logs")
    # Validate up front so a missing snapshot can never leave the run with
    # its current logs deleted and its state half-restored.
    if not os.path.isdir(logs_snapshot):
        raise FileNotFoundError(f"Log snapshot not found: {logs_snapshot}")

    checkpoint = torch.load(checkpoint_path, map_location=model_device(self.model))
    self.model.load_state_dict(checkpoint["model_state_dict"])
    self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    # Training continues with the epoch after the one that was saved.
    Tracker.epoch = checkpoint["epoch"] + 1
    Tracker.global_batch = checkpoint["global_batch"]
    Tracker.best_val_accuracy = checkpoint["best_val_accuracy"]

    # copytree requires the destination not to exist.
    if os.path.isdir(logs_dir):
        shutil.rmtree(logs_dir)
    shutil.copytree(logs_snapshot, logs_dir)
def train_batch(self, images: torch.Tensor, labels: torch.Tensor) -> None:
    """Run a single optimisation step on one batch.

    Steps, in order:
    - Move the batch to the model's device and zero the gradients.
    - While ``Tracker.global_batch <= cfg.debug.visualize_inputs``, dump the
      input images (``.png``) and their labels (``.txt``, one per line)
      under ``debug/``.
    - Compute ``self.criterion`` on the model output, backpropagate and step
      the optimizer.
    - Log the loss as scalar ``"loss"`` at ``Tracker.global_batch``, then
      call ``Tracker.progress()``.
    """
    device = model_device(self.model)
    images, labels = images.to(device), labels.to(device)
    self.optimizer.zero_grad()

    dump_inputs = Tracker.global_batch <= self.cfg.debug.visualize_inputs
    if dump_inputs:
        image_file = self.from_out(f"debug/inputs_batch_{Tracker.global_batch}.png")
        torchvision.utils.save_image(images, image_file)
        label_file = self.from_out(f"debug/inputs_batch_{Tracker.global_batch}.txt")
        write_lines(label_file, list(map(str, labels.tolist())))

    batch_loss = self.criterion(self.model(images), labels)
    batch_loss.backward()
    self.optimizer.step()

    self.logger.add_scalar("loss", batch_loss.item(), global_step=Tracker.global_batch)
    Tracker.progress()
def __init__(self, cfg_filepath: str, *, backbone_weights_path: str = "TODO",
             device: str = "cuda"):
    """Load the config, create output dirs, build the model and data loaders.

    Args:
        cfg_filepath: Path, resolved through ``from_root``, to a JSON config
            file. The file is wrapped in ``DictObject`` as ``self.cfg``.
        backbone_weights_path: File passed to ``torch.load``. Its state dict
            is loaded into ``self.model.backbone``. The default ``"TODO"`` is
            the placeholder previously hard-coded here, kept for backward
            compatibility; callers should pass the real checkpoint path.
        device: Device the model is moved to. The default ``"cuda"`` matches
            the previous ``.cuda()`` call.
    """
    set_deterministic_seed()
    with open(from_root(cfg_filepath), "r") as file:
        self.cfg = DictObject(json.load(file))
    for sub_dirname in ("logs", "checkpoints", "debug"):
        os.makedirs(self.from_out(sub_dirname), exist_ok=True)

    self.model = RefineNet_4Cascaded().to(device)
    # Only the backbone sub-module receives weights from the file.
    state_dict = torch.load(backbone_weights_path,
                            map_location=model_device(self.model))
    self.model.backbone.load_state_dict(state_dict)

    self.train_loader = load_pascal_voc_train(16)
    self.infer_train_loader = load_pascal_voc_infer("train", 16)
    self.infer_val_loader = load_pascal_voc_infer("val", 16)
def init_surrogate_model(self, surrogate_cfg_filepath: str) -> None:
    """Build the surrogate ResNet from another run's config and freeze it.

    The JSON config at ``surrogate_cfg_filepath`` (resolved via
    ``from_root``) is used to create the model with ``create_resnet``. The
    model is placed on ``cfg.model.device``. The checkpoint named in
    ``<cfg.out_dirpath>/checkpoints/best_epoch.txt`` is then loaded. The
    result is stored as ``self.surrogate_model``, switched to eval mode, and
    all of its parameters get ``requires_grad=False``.
    """
    with open(from_root(surrogate_cfg_filepath), "r") as cfg_file:
        surrogate_cfg = DictObject(json.load(cfg_file))

    self.surrogate_model = create_resnet(surrogate_cfg).to(surrogate_cfg.model.device)

    run_dirpath = from_root(surrogate_cfg.out_dirpath)
    with open(os.path.join(run_dirpath, "checkpoints/best_epoch.txt"), "r") as epoch_file:
        best_epoch = int(epoch_file.read())

    saved_state = torch.load(
        os.path.join(run_dirpath, f"checkpoints/checkpoint_{best_epoch}.pth"),
        map_location=model_device(self.surrogate_model),
    )
    self.surrogate_model.load_state_dict(saved_state["model_state_dict"])

    # Frozen, evaluation-mode model: its weights never accumulate gradients.
    self.surrogate_model.eval()
    self.surrogate_model.requires_grad_(False)