def __init__(self, config, train_dl, val_dl, device):
    """Load a trained classifier and keep the weights of its ``linear`` layer.

    Args:
        config: Config mapping. Reads ``config["model"]`` (network spec passed to
            ``get_network`` plus ``"weights_path"``) and ``config["interpolation"]``.
        train_dl: Training data loader, stored as-is.
        val_dl: Validation data loader, stored as-is.
        device: Device the checkpoint is mapped onto and the model is moved to.
    """
    self.config = config
    self.train_dl = train_dl
    self.val_dl = val_dl
    self.device = device

    self.model = get_network(config["model"])
    # map_location remaps tensors saved on another device (e.g. a different GPU)
    # onto `device`, so the checkpoint loads regardless of where it was saved.
    state_dict = torch.load(config["model"]["weights_path"], map_location=device)
    self.model.load_state_dict(state_dict["model"])
    self.model = self.model.to(device)
    self.model.eval()

    # shape: [num_classes, K] K - number of kernel filters
    # (presumably used to weight the feature maps when building CAMs — confirm)
    self.maps_weights = self.model.linear.weight
    # Stored for later use; presumably an F.interpolate mode — TODO confirm.
    self.interpolation_mode = config["interpolation"]
def __init__(self, config, device, return_one_always=True):
    """Set up a gradient-based extractor around a trained network.

    Args:
        config: Config mapping. Reads ``config["model"]`` (network spec,
            ``"weights_path"``, ``"target_layer"``) and the optional
            ``"prediction_threshold"`` key (``None`` when absent).
        device: Device the checkpoint is mapped onto and the model is moved to.
        return_one_always: Stored unchanged on the instance.
    """
    self.config = config
    self.device = device

    self.model = get_network(config["model"])
    checkpoint = torch.load(config["model"]["weights_path"], map_location=device)
    self.model.load_state_dict(checkpoint["model"])
    self.model.to(self.device)
    self.model.eval()

    self.prediction_threshold = config.get("prediction_threshold", None)

    # set_target_layer_hook is defined elsewhere in the class; it is called
    # before the attributes below are assigned, preserving the original order.
    self.set_target_layer_hook(config["model"]["target_layer"])

    self.return_one_always = return_one_always
    # Start empty; presumably populated by the hooks registered above.
    self.activation_maps = None
    self.grad = None
    # Identity transform by default.
    self.func_on_target = lambda x: x
def __init__(self, config, train_dl, val_dl):
    """Load a trained network from the experiment directory and build its metrics.

    Args:
        config: Config mapping. Reads ``"devices"`` (first entry is used),
            ``"experiment_path"`` and ``config["model"]`` (network spec,
            ``"weights_path"`` relative to the experiment path, ``"classes"``,
            ``"metrics"``).
        train_dl: Training data loader, stored as-is.
        val_dl: Validation data loader, stored as-is.
    """
    self.config = config
    self.train_dl = train_dl
    self.val_dl = val_dl
    # Only the first configured device is used.
    self.device = config["devices"][0]

    model_cfg = config["model"]
    weights_file = os.path.join(config["experiment_path"], model_cfg["weights_path"])

    self.model = get_network(model_cfg)
    checkpoint = torch.load(weights_file, map_location=self.device)
    self.model.load_state_dict(checkpoint["model"])
    self.model = self.model.to(self.device)
    self.model.eval()

    # One metric object per configured metric name, keyed by that name.
    metrics = {}
    for metric_name in model_cfg["metrics"]:
        metrics[metric_name] = get_metric(metric_name, model_cfg["classes"], self.device)
    self.metrics = metrics
def __init__(self, config, train_dl, val_dl, device):
    """Load a trained network, pick its weighting layer and hook its feature-map layer.

    Args:
        config: Config mapping. Reads ``config["model"]`` (network spec plus
            ``"weights_path"``), ``"weights_layer"`` and ``"maps_layer"`` (module
            names on the model; dotted paths such as ``"layer4.1"`` are accepted),
            ``"interpolation"`` and ``"use_predicted_labels"``.
        train_dl: Training data loader, stored as-is.
        val_dl: Validation data loader, stored as-is.
        device: Device the checkpoint is mapped onto and the model is moved to.
    """
    import operator

    self.config = config
    self.train_dl = train_dl
    self.val_dl = val_dl
    self.device = device

    self.model = get_network(config["model"])
    # map_location remaps tensors saved on another device onto `device`.
    state_dict = torch.load(config["model"]["weights_path"], map_location=device)
    self.model.load_state_dict(state_dict["model"])
    self.model = self.model.to(device)
    self.model.eval()

    # attrgetter follows dotted paths as chained getattr ("layer4.1" ->
    # model.layer4[1]); a plain name resolves exactly like getattr(model, name).
    self.maps_weights = operator.attrgetter(config["weights_layer"])(
        self.model
    ).weight  # shape: [num_classes, K] K - number of kernel filters
    # Stored for later use; presumably an F.interpolate mode — TODO confirm.
    self.interpolation_mode = config["interpolation"]
    self.use_predicted_labels = config["use_predicted_labels"]
    self.denorm = denormalization["default"]
    # Presumably filled by the save_maps_forward hook registered below
    # (defined elsewhere in the class) — confirm.
    self.maps = None
    operator.attrgetter(config["maps_layer"])(self.model).register_forward_hook(
        self.save_maps_forward)
def get_cam_grad_extractor(config, device):
    """Build the gradient-based extractor named by ``config["extraction_method"]``.

    Args:
        config: Mapping with an ``"extraction_method"`` key; the whole mapping is
            forwarded to the selected extractor's constructor.
        device: Forwarded unchanged to the extractor's constructor.

    Returns:
        ``CamGrad`` for ``"grad-cam"``, ``CamGradPlusPlus`` for ``"grad-cam++"``.

    Raises:
        KeyError: if ``config`` has no ``"extraction_method"`` key.
        ValueError: if the method name is not one of the two above.
    """
    name = config["extraction_method"]
    if name == "grad-cam":
        return CamGrad(config, device)
    if name == "grad-cam++":
        return CamGradPlusPlus(config, device)
    raise ValueError(f"Unrecognized mask generation method {name}.")


if __name__ == "__main__":
    # Debug helper: print the network's module hierarchy (e.g. to choose a
    # target-layer name).
    model_config = {
        "arch": "resnet50",
        "pretrained": False,
        "classes": 201,
    }
    # NOTE(review): `config` is built but not used by this script; also note
    # "target_layer" sits at the top level here — confirm where extractors
    # expect it.
    config = {
        "model": model_config,
        "target_layer": "layer4.1.conv3",
        "prediction_threshold": 0.5,
    }
    # Listing module names needs no GPU; fall back to CPU so the script also
    # runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_network(model_config).to(device)
    for name, module in model.named_children():
        print("Direct child name: ", name)
        for n, m in module.named_children():
            print("\t", n)
    print(list(model.named_modules()))
def _work(process_id, config):
    """Refine saved CAMs for one data shard by propagating them along predicted edges.

    Worker ``process_id`` runs on ``config["devices"][process_id]`` and handles
    shard ``process_id`` of ``split_dataset(dataset, len(config["devices"]))``.
    For each image it:

    1. predicts an edge map, averaging the prediction on the image and on its
       horizontal mirror (un-flipped),
    2. loads ``<cam_path>/<name>.npy`` (a pickled dict with ``"cam"`` and
       ``"keys"``),
    3. resizes the CAMs to the edge-map resolution and propagates them with
       ``indexing.propagate_to_edge``,
    4. upsamples to the original image size, clamps negatives to zero, divides by
       the maximum (skipped when the maximum is 0, to avoid NaNs) and saves
       ``{"keys", "map"}`` to ``<output_path>/<name>.npy``.

    Images whose output file already exists are skipped, so an interrupted run
    can be resumed.

    Args:
        process_id: Index of this worker; selects both the device and the shard.
        config: Nested config mapping. Reads ``"devices"``, ``"model"``
            (network spec, ``"weights_path"``), ``"data"`` (``"transform"``,
            ``"path"``, ``"cam_path"``, ``"output_path"``), ``"beta"`` and
            ``"exp_times"``.
    """
    device = config["devices"][process_id]
    output_dir = config["data"]["output_path"]
    cam_dir = config["data"]["cam_path"]

    model = get_network(config["model"])
    model.load_state_dict(
        torch.load(config["model"]["weights_path"], map_location=device)["model"])
    model.to(device)
    model.eval()

    transform = get_transforms(config["data"]["transform"])
    dataset = ImageNetMLC(config["data"]["path"], transform, return_size=True)
    subsets = split_dataset(dataset, len(config["devices"]))
    subset = subsets[process_id]
    # DataLoader's default batch_size is 1; the per-image indexing below
    # (name[0], orig_size[0], edge[0]/edge[1]) relies on that.
    dataloader = DataLoader(subset, shuffle=False, pin_memory=False)

    with torch.no_grad():
        for iteration, (X, _, name, orig_size) in enumerate(dataloader):
            name = name[0]
            orig_size = orig_size[0]
            out_path = os.path.join(output_dir, name + ".npy")
            # Resume support: skip finished images before any device transfer.
            if os.path.exists(out_path):
                continue
            X = X.to(device, non_blocking=True)

            # Flip TTA: one batch of [image, mirrored image]; average the two
            # predictions after flipping the second back.
            X_tta = torch.cat([X, X.flip(-1)], dim=0)
            edge, _ = model(X_tta)
            edge = torch.sigmoid(edge[0] / 2 + edge[1].flip(-1) / 2)

            # NOTE(review): allow_pickle=True unpickles the file; only point
            # cam_path at trusted, self-generated CAM files.
            cam_dict = np.load(os.path.join(cam_dir, name + ".npy"),
                               allow_pickle=True).item()
            cams = torch.from_numpy(cam_dict["cam"]).to(device)
            # Shift keys by one and prepend 0 (presumably a background slot —
            # confirm against propagate_to_edge's output channels).
            keys = np.pad(cam_dict["keys"] + 1, (1, 0), mode="constant")

            cams = F.interpolate(
                cams.unsqueeze(1),
                size=edge.shape[1:],
                mode="bilinear",
                align_corners=False,
            ).squeeze(1)

            rw = indexing.propagate_to_edge(
                cams,
                edge,
                beta=config["beta"],
                exp_times=config["exp_times"],
                radius=5,
                device=device,
            )

            rw_up = F.interpolate(
                rw,
                size=(orig_size[0], orig_size[1]),
                mode="bilinear",
                align_corners=False,
            )
            rw_up = rw_up.relu_()
            max_val = torch.max(rw_up)
            # After relu the maximum is >= 0; dividing an all-zero map by 0
            # would save NaNs, so leave such a map as zeros.
            if max_val > 0:
                rw_up = rw_up / max_val

            np.save(
                out_path,
                {
                    "keys": keys,
                    "map": rw_up.squeeze(0).cpu().numpy()
                },
            )

            if iteration % 100 == 0:
                print(
                    f"Device: {process_id}, Iteration: {iteration}/{len(subset)}"
                )