def generate_data():
    """Evaluate every noise model listed in ``MODELS`` on the validation split.

    Each ``name -> checkpoint path`` entry of ``MODELS`` is restored as a
    ``NoiseModel`` on top of one shared utility model (loaded from
    ``UTILITY_MODEL``). It is then scored by ``evaluate`` over
    ``NUM_THRESHOLDS + 1`` evenly spaced thresholds in ``[0, 1]``. The
    collected results are pickled to ``RESULTS`` and also returned.

    Returns:
        dict mapping each model name to a dict with the keys
        ``"thresholds"``, ``"num_params"`` (number of elements across
        ``model.noise_model.parameters()``), and ``"dice"``, ``"coverage"``,
        ``"dice_at_half_coverage"`` (the three values returned by
        ``evaluate``, in that order).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"using device: {device}")

    images = np.load(IMAGES)
    box_array = np.load(BOXES, allow_pickle=True)
    mask_array = np.load(MASKS)

    threshold_grid = np.linspace(0.0, 1.0, num=NUM_THRESHOLDS + 1)

    # Only the validation loader is needed for evaluation.
    _, valid_loader, _ = dataloaders(images, box_array, mask_array, BATCH_SIZE)
    utility = UtilityModel.load_from_checkpoint(UTILITY_MODEL)

    results = {}
    for model_name, checkpoint_path in MODELS.items():
        noise_model = NoiseModel.load_from_checkpoint(
            checkpoint_path, util_model=utility
        )
        param_count = sum(
            param.numel() for param in noise_model.noise_model.parameters()
        )
        dice, coverage, dice_half = evaluate(
            valid_loader, noise_model, threshold_grid, device
        )
        results[model_name] = {
            "thresholds": threshold_grid,
            "num_params": param_count,
            "dice": dice,
            "coverage": coverage,
            "dice_at_half_coverage": dice_half,
        }
        print(f"done: {model_name}")

    with open(RESULTS, "wb") as handle:
        pickle.dump(results, handle)
    return results
def main(args):
    """Train a ``NoiseModel`` against a utility model restored from a checkpoint.

    Seeds everything with 0 and loads the image/box/mask arrays from the
    ``.npy`` paths in ``args``. It then fits a ``NoiseModel`` built from the
    CLI hyperparameters. The best checkpoint is chosen by lowest ``val_loss``.

    Side effect: when ``args.pretrained`` is set, it is overwritten in place
    with the ``state_dict`` of that checkpoint's ``.model``. That state dict is
    the value passed to ``NoiseModel(pretrained=...)``.
    """
    pl.seed_everything(0)

    images = np.load(args.imgs)
    box_array = np.load(args.boxes, allow_pickle=True)
    mask_array = np.load(args.masks)
    train_loader, valid_loader, _ = dataloaders(
        images, box_array, mask_array, args.batch_size
    )

    utility = UtilityModel.load_from_checkpoint(args.utility_model)

    if args.pretrained is not None:
        # Use a pretrained utility network as the noise network's
        # initialization (replaces the checkpoint path with its weights).
        pretrained_utility = UtilityModel.load_from_checkpoint(args.pretrained)
        args.pretrained = pretrained_utility.model.state_dict()

    noise_model = NoiseModel(
        utility,
        args.depth,
        args.channel_factor,
        args.learning_rate,
        min_scale=args.min_scale,
        max_scale=args.max_scale,
        noise_coeff=args.noise_coeff,
        pretrained=args.pretrained,
    )

    trainer = pl.Trainer(
        gpus=args.gpus,
        max_epochs=args.epochs,
        checkpoint_callback=pl.callbacks.ModelCheckpoint(
            monitor="val_loss", mode="min"
        ),
    )
    trainer.fit(noise_model, train_loader, valid_loader)
def test(args):
    """Run ``trainer.test`` for an ImageGPT checkpoint on the test split.

    The YAML file at ``args.config`` supplies ``gpus`` and ``batch_size``.
    The model is restored from ``args.checkpoint``, and the test loader comes
    from ``dataloaders(args.dataset, batch_size)``.
    """
    with open(args.config, "rb") as config_file:
        cfg = yaml.safe_load(config_file)

    trainer = pl.Trainer(gpus=cfg["gpus"])
    _, _, test_loader = dataloaders(args.dataset, cfg["batch_size"])
    restored = ImageGPT.load_from_checkpoint(args.checkpoint)
    trainer.test(restored, test_dataloaders=test_loader)
def train(args):
    """Pretrain or finetune an ImageGPT model.

    Hyperparameters come from the YAML file at ``args.config``.

    When ``args.pretrained`` is given, the model is restored from that
    checkpoint and its ``learning_rate`` and ``classify`` attributes are
    overwritten from the config (for finetuning). Otherwise a fresh
    ``ImageGPT`` is built from ``args.centroids`` plus the config.

    If ``config["classify"]`` is true, training stops early and checkpoints
    on the highest ``val_acc``. Otherwise it checkpoints on the lowest
    ``val_loss``. TensorBoard logs are written under
    ``logs/<config name>_<dataset>``.
    """
    with open(args.config, "rb") as f:
        config = yaml.safe_load(f)

    # experiment name
    name = f"{config['name']}_{args.dataset}"

    if args.pretrained is not None:
        model = ImageGPT.load_from_checkpoint(args.pretrained)
        # potentially modify model for finetuning
        model.learning_rate = config["learning_rate"]
        model.classify = config["classify"]
    else:
        model = ImageGPT(centroids=args.centroids, **config)

    train_dl, valid_dl, _ = dataloaders(args.dataset, config["batch_size"])
    logger = pl_loggers.TensorBoardLogger("logs", name=name)

    # Settings shared by the classification and pretraining trainers.
    trainer_kwargs = {
        "max_steps": config["steps"],
        "gpus": config["gpus"],
        "precision": config["precision"],
        "accumulate_grad_batches": config["accumulate_grad_batches"],
        "logger": logger,
    }

    if config["classify"]:
        # Classification finetuning: stop early and checkpoint on the best
        # validation accuracy. mode="max" is explicit on BOTH callbacks so the
        # checkpoint does not depend on the callback's default mode. With a
        # "min" default it would keep the *worst*-accuracy model.
        trainer_kwargs["early_stop_callback"] = pl.callbacks.EarlyStopping(
            monitor="val_acc", patience=3, mode="max"
        )
        trainer_kwargs["checkpoint_callback"] = pl.callbacks.ModelCheckpoint(
            monitor="val_acc", mode="max"
        )
    else:
        # Pretraining: keep the checkpoint with the lowest validation loss.
        trainer_kwargs["checkpoint_callback"] = pl.callbacks.ModelCheckpoint(
            monitor="val_loss", mode="min"
        )

    trainer = pl.Trainer(**trainer_kwargs)
    trainer.fit(model, train_dl, valid_dl)
def main(args):
    """Train a ``UtilityModel`` on the image/box/mask arrays named in ``args``.

    The model is built from ``args.depth``, ``args.channel_factor`` and
    ``args.learning_rate`` and trained for ``args.epochs`` epochs. The
    checkpoint with the highest ``val_dice`` is kept.
    """
    images = np.load(args.imgs)
    box_array = np.load(args.boxes, allow_pickle=True)
    mask_array = np.load(args.masks)
    train_loader, valid_loader, _ = dataloaders(
        images, box_array, mask_array, args.batch_size
    )

    utility = UtilityModel(args.depth, args.channel_factor, args.learning_rate)

    best_dice = pl.callbacks.ModelCheckpoint(monitor="val_dice", mode="max")
    trainer = pl.Trainer(
        gpus=args.gpus, max_epochs=args.epochs, checkpoint_callback=best_dice
    )
    trainer.fit(utility, train_loader, valid_loader)
def main(args):
    """Build a figure of ImageGPT image completions and save it to ``figure.png``.

    For each of ``args.num_examples`` randomly drawn validation images:

    - the image is quantized with the centroids stored at ``args.centroids``;
    - the first half of its pixel sequence is kept as context;
    - ``args.num_samples`` completions of the remaining pixels are sampled.

    Each figure row is ``[zero-padded context, samples..., ground truth]``.
    """
    model = ImageGPT.load_from_checkpoint(args.checkpoint).gpt.eval().cuda()
    centroids = np.load(args.centroids)
    # Loop-invariant: the centroid tensor is identical for every example.
    centroids_t = torch.from_numpy(centroids)

    _, valid_dl, _ = dataloaders(args.dataset, 1)
    dl = iter(DataLoader(valid_dl.dataset, shuffle=True))

    # rows for figure
    rows = []
    for _ in tqdm(range(args.num_examples)):
        img, _label = next(dl)
        h, w = img.shape[-2:]
        img = quantize(img, centroids_t).numpy()[0]
        seq = img.reshape(-1)

        # First half of the image is context; the remainder is predicted.
        # The predicted length is derived as len(seq) - n_context (not a
        # second floor(len / 2)), so context + prediction always equals h * w.
        # That keeps the reshapes to (h, w) valid even when h * w is odd.
        n_context = len(seq) // 2
        n_predict = len(seq) - n_context
        context = seq[:n_context]
        context_img = np.pad(context, (0, n_predict)).reshape(h, w)
        context = torch.from_numpy(context).cuda()

        # Predict the second half of the image. The reshape below relies on
        # `sample` returning context + generated tokens as (seq_len, samples);
        # the transpose makes that (samples, seq_len).
        preds = (
            sample(model, context, n_predict, num_samples=args.num_samples)
            .cpu()
            .numpy()
            .transpose()
        )
        preds = preds.reshape(-1, h, w)

        # combine context, preds, and truth for figure
        rows.append(
            np.concatenate([context_img[None, ...], preds, img[None, ...]], axis=0)
        )

    figure = make_figure(rows, centroids)
    figure.save("figure.png")
def _save_if_enabled(image, path):
    """Save PIL ``image`` to ``path`` only when the module-level ``SAVE`` flag is true."""
    if SAVE:
        image.save(path)


def _sync_device(device):
    """Block until all queued CUDA work on ``device`` has finished; no-op on CPU.

    CUDA kernels launch asynchronously. A wall-clock interval that is not
    bracketed by synchronizations therefore mostly measures kernel-launch
    overhead rather than the computation itself.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def _unoise_heatmap(noise_model, img):
    """Return ``relu(-n) / max(-n)``, where ``n = noise_model.noise_model(img)[0, 0]``.

    NOTE(review): the result is NaN if ``max(-n)`` is 0 — confirm this cannot
    happen for the trained noise models.
    """
    heatmap = -noise_model.noise_model(img)[0, 0]
    return torch.relu(heatmap) / heatmap.max()


def main():
    """Render explanation visualizations for random validation images and time them.

    For each of ``NUM_IMAGES`` validation images (indices drawn with
    ``np.random.seed(42)``) this renders:

    - the plain image and the image with its ground-truth mask;
    - a Grad-CAM heatmap;
    - un-noise heatmaps from the large and small noise models;
    - a noised image from the large noise model;
    - an occlusion-sensitivity heatmap.

    For the first image only, thresholded versions based on the large noise
    model's second output are also rendered. Images are written under
    ``visualizations/`` when ``SAVE`` is true.

    Finally, the average per-image wall-clock time of the large un-noise
    heatmap, Grad-CAM and occlusion sensitivity is printed. Each timed region
    is synchronized on CUDA so it includes the GPU work, not just kernel
    launches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"using device: {device}")

    imgs = np.load(IMAGES)
    boxes = np.load(BOXES, allow_pickle=True)
    masks = np.load(MASKS)
    # only care about validation dataloader
    _, valid_dl, _ = dataloaders(imgs, boxes, masks, BATCH_SIZE)
    ds = valid_dl.dataset

    np.random.seed(42)
    # NOTE: np.random.choice samples with replacement by default, so the same
    # image can be chosen more than once.
    choices = np.random.choice(np.arange(len(ds)), NUM_IMAGES)

    grad_cam_time = 0.0
    unoise_time = 0.0
    occlusion_sensitivity_time = 0.0

    for i, choice in enumerate(choices):
        util_model = UtilityModel.load_from_checkpoint(UTILITY_MODEL)
        model = util_model.model.to(device)

        img, mask = ds[choice]
        img = img.unsqueeze(0).to(device)
        mask = mask.unsqueeze(0).to(device)

        if i == 0:
            # Thresholded copies of the first image: pixels whose value in the
            # large noise model's second output exceeds the threshold are zeroed.
            noise_model = NoiseModel.load_from_checkpoint(
                LARGE_NOISE_MODEL, util_model=util_model
            ).to(device)
            B = noise_model(img)[1]
            for threshold in np.linspace(0.0, 1.0, 11):
                thresh_img = Image.fromarray(impose((img * (B <= threshold))[0]))
                _save_if_enabled(
                    thresh_img, f"visualizations/threshold_{threshold:.1f}.png"
                )
            # for some reason these need to be loaded again
            # otherwise gradient hook for grad cam gets messed up
            util_model = UtilityModel.load_from_checkpoint(UTILITY_MODEL)
            model = util_model.model.to(device)

        # plain image
        plain = Image.fromarray(impose(img[0]))
        _save_if_enabled(plain, f"visualizations/plain_{i}.png")

        # original mask
        og = Image.fromarray(impose(img[0], mask=mask[0]))
        _save_if_enabled(og, f"visualizations/original_{i}.png")

        # Grad-CAM, anchored at the first foreground pixel torch.where returns
        # (an empty mask raises IndexError here).
        y, x = torch.where(mask[0, 0] > 0)
        _sync_device(device)
        start = time.perf_counter()
        heatmap = grad_cam(model, img, x[0], y[0])
        _sync_device(device)
        grad_cam_time += time.perf_counter() - start
        gc = Image.fromarray(impose(img[0], heatmap))
        _save_if_enabled(gc, f"visualizations/grad_cam_{i}.png")

        with torch.no_grad():
            # unoise large
            noise_model = NoiseModel.load_from_checkpoint(
                LARGE_NOISE_MODEL, util_model=util_model
            ).to(device)
            _sync_device(device)
            start = time.perf_counter()
            heatmap = _unoise_heatmap(noise_model, img)
            _sync_device(device)
            unoise_time += time.perf_counter() - start
            large = Image.fromarray(impose(img[0], heatmap))
            _save_if_enabled(large, f"visualizations/unoise_large_{i}.png")

            # unoise large, noised image (reuses the large noise model loaded
            # just above instead of loading the same checkpoint again)
            noised = Image.fromarray(impose(noise_model(img)[0][0]))
            _save_if_enabled(noised, f"visualizations/noised_{i}.png")

            # unoise small
            noise_model = NoiseModel.load_from_checkpoint(
                SMALL_NOISE_MODEL, util_model=util_model
            ).to(device)
            heatmap = _unoise_heatmap(noise_model, img)
            small = Image.fromarray(impose(img[0], heatmap))
            _save_if_enabled(small, f"visualizations/unoise_small_{i}.png")

            # occlusion sensitivity, min-max normalized
            _sync_device(device)
            start = time.perf_counter()
            heatmap = occlusion_sensitivity(util_model, img, mask, patch=15, stride=2)[0]
            heatmap = heatmap - heatmap.min()
            heatmap /= heatmap.max()
            _sync_device(device)
            occlusion_sensitivity_time += time.perf_counter() - start
            # named `occlusion`, not `os`, to avoid shadowing the stdlib module
            occlusion = Image.fromarray(impose(img[0], heatmap))
            _save_if_enabled(occlusion, f"visualizations/occlusion_{i}.png")

    print("unoise_time:", unoise_time / NUM_IMAGES)
    print("grad_cam_time:", grad_cam_time / NUM_IMAGES)
    print("occlusion_sensitivity_time:", occlusion_sensitivity_time / NUM_IMAGES)