예제 #1
0
    def __init__(self, config, train_dl, val_dl, device):
        """Load a trained classifier and keep its final linear weights.

        Args:
            config: experiment configuration. ``config["model"]`` is passed to
                ``get_network`` and its ``"weights_path"`` entry names a
                checkpoint whose ``"model"`` key holds the state dict.
                ``config["interpolation"]`` is stored as the interpolation mode.
            train_dl: training data loader, stored as-is.
            val_dl: validation data loader, stored as-is.
            device: device the model is moved to; also used as
                ``map_location`` when loading the checkpoint.
        """
        self.config = config
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.device = device

        self.model = get_network(config["model"])
        # Without map_location, torch.load restores tensors onto the device
        # they were saved from: this fails on hosts lacking that device
        # (e.g. CPU-only) and otherwise allocates on it needlessly.
        state_dict = torch.load(config["model"]["weights_path"],
                                map_location=device)
        self.model.load_state_dict(state_dict["model"])
        self.model = self.model.to(device)
        self.model.eval()

        self.maps_weights = (
            self.model.linear.weight
        )  # shape: [num_classes, K] K - number of kernel filters
        self.interpolation_mode = config["interpolation"]
예제 #2
0
    def __init__(self, config, device, return_one_always=True):
        """Build the network, restore its weights and hook the target layer.

        Args:
            config: experiment configuration. ``config["model"]`` is passed to
                ``get_network`` and must provide ``"weights_path"`` (a
                checkpoint whose ``"model"`` entry is the state dict) and
                ``"target_layer"`` (handed to ``set_target_layer_hook``).
                ``config["prediction_threshold"]`` is optional (default None).
            device: used both as ``map_location`` for the checkpoint and as
                the device the model is moved to.
            return_one_always: stored on the instance unchanged.
        """
        model_cfg = config["model"]
        self.model = get_network(model_cfg)
        checkpoint = torch.load(model_cfg["weights_path"], map_location=device)
        self.model.load_state_dict(checkpoint["model"])

        self.config, self.device = config, device
        self.model.to(self.device)
        self.model.eval()

        self.prediction_threshold = config.get("prediction_threshold", None)
        # The hook is registered after the model/config/device/threshold are
        # set and before the attributes below, exactly as before.
        self.set_target_layer_hook(model_cfg["target_layer"])
        self.return_one_always = return_one_always

        # Start empty; presumably filled in by the registered hooks — verify.
        self.activation_maps, self.grad = None, None

        # Identity by default.
        self.func_on_target = lambda x: x
예제 #3
0
    def __init__(self, config, train_dl, val_dl):
        """Restore a trained model from the experiment directory and set up metrics.

        Args:
            config: experiment configuration. Reads ``"devices"`` (the first
                entry is used), ``"experiment_path"`` and ``config["model"]``
                (passed to ``get_network``; provides ``"weights_path"``
                relative to the experiment path, ``"classes"`` and the list of
                ``"metrics"``).
            train_dl: training data loader, stored as-is.
            val_dl: validation data loader, stored as-is.
        """
        self.config = config
        self.train_dl, self.val_dl = train_dl, val_dl
        self.device = config["devices"][0]

        model_cfg = config["model"]
        checkpoint_file = os.path.join(config["experiment_path"],
                                       model_cfg["weights_path"])
        self.model = get_network(model_cfg)
        checkpoint = torch.load(checkpoint_file, map_location=self.device)
        self.model.load_state_dict(checkpoint["model"])
        self.model = self.model.to(self.device)
        self.model.eval()

        # One metric object per configured metric name, keyed by that name.
        metrics = {}
        for metric_name in model_cfg["metrics"]:
            metrics[metric_name] = get_metric(metric_name, model_cfg["classes"],
                                              self.device)
        self.metrics = metrics
예제 #4
0
    def __init__(self, config, train_dl, val_dl, device):
        """Load a trained classifier and hook a layer for activation maps.

        Args:
            config: experiment configuration. ``config["model"]`` is passed to
                ``get_network`` and its ``"weights_path"`` names a checkpoint
                whose ``"model"`` entry is the state dict. Also reads
                ``"weights_layer"`` (attribute whose ``.weight`` is kept),
                ``"maps_layer"`` (attribute a forward hook is attached to),
                ``"interpolation"`` and ``"use_predicted_labels"``.
            train_dl: training data loader, stored as-is.
            val_dl: validation data loader, stored as-is.
            device: device the model is moved to; also used as
                ``map_location`` when loading the checkpoint.
        """
        self.config = config
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.device = device

        self.model = get_network(config["model"])
        # Without map_location, torch.load restores tensors onto the device
        # they were saved from: this fails on hosts lacking that device
        # (e.g. CPU-only) and otherwise allocates on it needlessly.
        state_dict = torch.load(config["model"]["weights_path"],
                                map_location=device)
        self.model.load_state_dict(state_dict["model"])
        self.model = self.model.to(device)
        self.model.eval()

        self.maps_weights = getattr(
            self.model, config["weights_layer"]
        ).weight  # shape: [num_classes, K] K - number of kernel filters
        self.interpolation_mode = config["interpolation"]
        self.use_predicted_labels = config["use_predicted_labels"]
        self.denorm = denormalization["default"]

        # Presumably written by save_maps_forward when the hook fires — verify.
        self.maps = None
        # Keep the handle so the hook can be detached later via .remove();
        # discarding it would make the hook permanent on this model.
        self._maps_hook_handle = getattr(
            self.model, config["maps_layer"]).register_forward_hook(
                self.save_maps_forward)
예제 #5
0

def get_cam_grad_extractor(config, device):
    """Build the gradient-based CAM extractor selected by the config.

    ``config["extraction_method"]`` chooses the class: ``"grad-cam"`` gives
    ``CamGrad`` and ``"grad-cam++"`` gives ``CamGradPlusPlus``. The full
    ``config`` and ``device`` are forwarded to its constructor.

    Raises:
        ValueError: if the method name is not one of the two supported ones.
    """
    method = config["extraction_method"]
    if method not in ("grad-cam", "grad-cam++"):
        raise ValueError(f"Unrecognized mask generation method {method}.")
    extractor_cls = CamGradPlusPlus if method == "grad-cam++" else CamGrad
    return extractor_cls(config, device)


if __name__ == "__main__":
    # Debug helper: print the module hierarchy of the configured network so a
    # valid hook target can be chosen for the CAM extractors.
    model_config = {
        "arch": "resnet50",
        "pretrained": False,
        "classes": 201,
        # The grad-CAM extractor reads the hook target from
        # config["model"]["target_layer"], so it lives in the model section.
        "target_layer": "layer4.1.conv3",
    }
    config = {
        "model": model_config,
        "prediction_threshold": 0.5,
    }

    # Only module names are inspected, so the model stays on the CPU; moving
    # it to CUDA made this script fail on machines without a GPU.
    model = get_network(model_config)
    for name, module in model.named_children():
        print("Direct child name: ", name)
        for n, _ in module.named_children():
            print("\t", n)
    print(list(model.named_modules()))

    target_layer = config["model"]["target_layer"]
    print("Target layer present:",
          target_layer in dict(model.named_modules()))
def _work(process_id, config):
    """Refine the CAMs of one dataset shard by propagating them along edges.

    One call handles shard ``process_id`` of ``len(config["devices"])`` on
    device ``config["devices"][process_id]``. For every image whose output
    file does not exist yet it:

    1. predicts an edge map, averaging the model output for the image and its
       horizontally flipped copy (flipped back) before a sigmoid;
    2. loads the stored CAM dict ``<cam_path>/<name>.npy`` and resizes the
       CAMs to the edge-map resolution;
    3. propagates them with ``indexing.propagate_to_edge``;
    4. upsamples the result to the original image size, clamps negatives to
       zero, divides by the maximum (skipped when the maximum is 0 to avoid
       saving NaNs) and saves ``{"keys", "map"}`` to ``<output_path>/<name>.npy``.

    Args:
        process_id: index of this worker; selects both the device and the shard.
        config: experiment configuration; reads ``devices``, ``model``,
            ``data`` (``transform``, ``path``, ``cam_path``, ``output_path``),
            ``beta`` and ``exp_times``.
    """
    device = config["devices"][process_id]

    model = get_network(config["model"])
    model.load_state_dict(
        torch.load(config["model"]["weights_path"],
                   map_location=device)["model"])
    model.to(device)
    model.eval()

    data_cfg = config["data"]
    transform = get_transforms(data_cfg["transform"])
    dataset = ImageNetMLC(data_cfg["path"], transform, return_size=True)
    subsets = split_dataset(dataset, len(config["devices"]))
    subset = subsets[process_id]
    dataloader = DataLoader(subset, shuffle=False, pin_memory=False)

    with torch.no_grad():
        for iteration, (X, _, name, orig_size) in enumerate(dataloader):
            name = name[0]
            out_file = os.path.join(data_cfg["output_path"], name + ".npy")
            # Resume support: skip finished images before any device transfer.
            if os.path.exists(out_file):
                continue

            X = X.to(device, non_blocking=True)
            # F.interpolate needs plain ints for the target size.
            size = orig_size[0]
            orig_hw = (int(size[0]), int(size[1]))

            # Flip test-time augmentation: un-flip the mirrored prediction
            # and average it with the plain one.
            X_tta = torch.cat([X, X.flip(-1)], dim=0)
            edge, _ = model(X_tta)
            edge = torch.sigmoid(edge[0] / 2 + edge[1].flip(-1) / 2)

            # allow_pickle is required because the file stores a dict; only
            # load CAM files produced by this pipeline.
            cam_dict = np.load(os.path.join(data_cfg["cam_path"], name + ".npy"),
                               allow_pickle=True).item()
            cams = torch.from_numpy(cam_dict["cam"]).to(device)
            # Shift class keys by one and prepend 0 (as done before saving).
            keys = np.pad(cam_dict["keys"] + 1, (1, 0), mode="constant")

            cams = F.interpolate(
                cams.unsqueeze(1),
                size=edge.shape[1:],
                mode="bilinear",
                align_corners=False,
            ).squeeze(1)
            rw = indexing.propagate_to_edge(
                cams,
                edge,
                beta=config["beta"],
                exp_times=config["exp_times"],
                radius=5,
                device=device,
            )

            rw_up = F.interpolate(
                rw,
                size=orig_hw,
                mode="bilinear",
                align_corners=False,
            )
            rw_up = rw_up.relu_()
            max_val = torch.max(rw_up)
            # An all-zero map would otherwise be divided by 0 and saved as NaN.
            if max_val > 0:
                rw_up = rw_up / max_val
            np.save(
                out_file,
                {
                    "keys": keys,
                    "map": rw_up.squeeze(0).cpu().numpy()
                },
            )

            if iteration % 100 == 0:
                print(
                    f"Device: {process_id}, Iteration: {iteration}/{len(subset)}"
                )