def save_results_to_disk(p, val_loader, model, crf_postprocess=False):
    """Run ``model`` over ``val_loader`` and write one PNG label map per image.

    Args:
        p: Mapping of settings; only ``p['save_dir']`` (output directory) is read.
        val_loader: Iterable of batches. Each batch provides ``batch['image']``
            (moved to CUDA and fed to the model) and ``batch['meta']`` with the
            entries ``'image'`` (output file stem), ``'im_size'`` and, when CRF
            is enabled, ``'image_file'``. ``val_loader.dataset`` is only used
            for the progress message.
        model: Network called on the image batch; switched to eval mode here.
        crf_postprocess: If true, refine each prediction with ``dense_crf``
            before taking the per-pixel argmax.

    Returns:
        None. Results are written to ``<save_dir>/<meta['image']>.png``.
    """
    print('Save results to disk ...')
    model.eval()

    # CRF (imported lazily so the dependency is only needed when requested)
    if crf_postprocess:
        from utils.crf import dense_crf

    counter = 0
    # Pure inference: disable autograd so no graph is kept alive per batch.
    with torch.no_grad():
        for batch in val_loader:
            output = model(batch['image'].cuda(non_blocking=True))
            meta = batch['meta']

            for jj in range(output.shape[0]):
                counter += 1

                if crf_postprocess:
                    # CRF post-process
                    image_file = meta['image_file'][jj]
                    probs = dense_crf(image_file, output[jj])
                    pred = np.argmax(probs, axis=0).astype(np.uint8)
                else:
                    # Regular: per-pixel argmax over dim 0 of the sample output
                    pred = torch.argmax(output[jj], dim=0).cpu().numpy().astype(np.uint8)

                # NOTE(review): cv2 takes dsize as (width, height), so
                # meta['im_size'] is presumably (heights, widths) - confirm.
                # Nearest-neighbour keeps the values valid label ids.
                result = cv2.resize(pred,
                                    dsize=(meta['im_size'][1][jj], meta['im_size'][0][jj]),
                                    interpolation=cv2.INTER_NEAREST)
                imageio.imwrite(os.path.join(p['save_dir'], meta['image'][jj] + '.png'), result)

                # Progress report every 250 saved images.
                if counter % 250 == 0:
                    print('Saving results: {} of {} objects'.format(counter, len(val_loader.dataset)))
def predict_img(net, full_img, gpu=False):
    """Predict a boolean mask for ``full_img`` and refine it with dense CRF.

    Args:
        net: Network called on the normalized ``(1, 3, 256, 255)`` input.
        full_img: Source image; passed to ``resize`` for the network input and
            converted to a ``uint8`` array as the CRF reference image.
        gpu: If true, the input tensor is moved to CUDA before the forward pass.

    Returns:
        Boolean array: the ``dense_crf`` output thresholded at 0.5.
    """
    img = resize(full_img)
    img = np.array(img)
    img = torch.FloatTensor(img)

    x = img.permute(2, 0, 1).contiguous()  # transform to (C x H x W)
    # NOTE(review): the batch shape is hard-coded, so this only works when
    # ``resize`` yields exactly 3 * 256 * 255 values - confirm the intended size.
    x = x.view(1, 3, 256, 255)  # image (N x C x H x W)
    if gpu:
        x = x.cuda()

    # The whole forward pass runs without autograd; previously only the tensor
    # wrapping was inside ``no_grad`` and ``net(x)`` still tracked gradients.
    with torch.no_grad():
        x = normalize(x)  # normalize values to [0, 1]
        x = net(x)  # feed into the net
        x = torch.sigmoid(x)
        x = F.upsample_bilinear(x, scale_factor=2)  # rescale the image to full size

    # First sample, first channel, as a NumPy array for the CRF.
    x = x[0][0].cpu().numpy()

    yy = dense_crf(np.array(full_img).astype(np.uint8), x)
    return yy > 0.5
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5, use_dense_crf=False):
    """Return a boolean mask predicted by ``net`` for ``full_img``.

    The image is preprocessed through a ``BasicDataset`` instance built with
    ``scale=scale_factor``, pushed through the network, squashed with a
    sigmoid, resized with ``transforms.Resize(full_img.size[1])`` and finally
    compared against ``out_threshold``. When ``use_dense_crf`` is true the mask
    is refined by ``dense_crf`` (using the ``uint8`` image) before thresholding.
    """
    net.eval()

    dataset = BasicDataset('data/training/img/', 'data/training/full_mask/', scale=scale_factor)
    batch = torch.from_numpy(dataset.preprocess(full_img))
    batch = torch.unsqueeze(batch, 0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        logits = net(batch)
        scores = torch.sigmoid(logits).squeeze(0)

    # Tensor -> PIL -> resize -> tensor, so the mask follows the source image size.
    to_source_size = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(full_img.size[1]),
        transforms.ToTensor()
    ])
    resized = to_source_size(scores.cpu())
    full_mask = resized.squeeze().cpu().numpy()

    if use_dense_crf:
        reference = np.array(full_img).astype(np.uint8)
        full_mask = dense_crf(reference, full_mask)

    return full_mask > out_threshold
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5, use_dense_crf=False):
    """Return a boolean mask predicted by ``net`` for a fixed-size input.

    ``full_img`` is resized with ``resize((320, 240))``, values that are not
    below 1200 are zeroed, and the result is divided by its maximum before
    being fed to the network as a single-channel batch. The network output goes
    through softmax (``net.n_classes > 1``) or sigmoid, is resized with
    ``transforms.Resize((240, 320))`` and thresholded at ``out_threshold``.
    ``dense_crf`` refines the mask first when ``use_dense_crf`` is true.

    ``scale_factor`` is accepted but not used by this implementation.
    """
    net.eval()

    pixels = np.array(full_img.resize((320, 240)))
    pixels = (pixels < 1200) * pixels  # zero out everything at or above 1200
    pixels = pixels / pixels.max()
    # Append a channel axis, then move it to the front.
    pixels = np.expand_dims(pixels, axis=2).transpose((2, 0, 1))

    batch = torch.from_numpy(pixels.astype(np.float32)).unsqueeze(0)
    batch = batch.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        logits = net(batch)
        scores = F.softmax(logits, dim=1) if net.n_classes > 1 else torch.sigmoid(logits)
        scores = scores.squeeze(0)

    to_fixed_size = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((240, 320)),
        transforms.ToTensor()
    ])
    resized = to_fixed_size(scores.cpu())
    full_mask = resized.squeeze().cpu().numpy()

    if use_dense_crf:
        reference = np.array(full_img).astype(np.uint8)
        full_mask = dense_crf(reference, full_mask)

    return full_mask > out_threshold
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5, use_dense_crf=False):
    """Return a boolean mask predicted by ``net`` for ``full_img``.

    The input is prepared by ``BasicDataset.preprocess(full_img, scale_factor)``.
    The network output goes through softmax when ``net.n_classes > 1`` and
    sigmoid otherwise, is resized with ``transforms.Resize(full_img.size[1])``
    and compared against ``out_threshold``. When ``use_dense_crf`` is true the
    mask is refined by ``dense_crf`` (using the ``uint8`` image) first.
    """
    net.eval()

    prepared = BasicDataset.preprocess(full_img, scale_factor)
    batch = torch.from_numpy(prepared).unsqueeze(0)
    batch = batch.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        logits = net(batch)
        scores = F.softmax(logits, dim=1) if net.n_classes > 1 else torch.sigmoid(logits)
        scores = scores.squeeze(0)

    # Tensor -> PIL -> resize -> tensor, so the mask follows the source image size.
    to_source_size = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(full_img.size[1]),
        transforms.ToTensor()
    ])
    resized = to_source_size(scores.cpu())
    full_mask = resized.squeeze().cpu().numpy()

    if use_dense_crf:
        reference = np.array(full_img).astype(np.uint8)
        full_mask = dense_crf(reference, full_mask)

    return full_mask > out_threshold
def main():
    """Load a WNet checkpoint, encode one sample image and plot its CRF labels.

    Reads ``args.squeeze`` and ``args.model`` from the module-level parser,
    shows the input and the first three encoder channels, runs ``dense_crf`` on
    the encoder output and displays the per-pixel argmax of the result.
    """
    args = parser.parse_args()

    model = WNet.WNet(args.squeeze)
    state = torch.load(args.model, map_location=torch.device('cpu'))
    model.load_state_dict(state)
    model.eval()

    to_tensor_64 = transforms.Compose(
        [transforms.Resize((64, 64)), transforms.ToTensor()])
    pil_image = Image.open("data2/images/train/1head.png").convert('RGB')
    x = to_tensor_64(pil_image)[None, :, :, :]  # add the batch axis

    enc, _dec = model(x)

    show_image(x[0])
    # TODO: torch sum/ stack?
    show_image(enc[0, :3, :, :].detach())

    # Feed the encoder output of the single sample into the CRF.
    encoded = enc[0, :, :, :].detach()
    original = imread("data2/images/train/1head.png")
    reference = resize(original, (64, 64))
    Q = dense_crf(reference, encoded.numpy())
    print(type(Q))

    labels = np.argmax(Q, axis=0)
    print(len(labels))
    print(np.unique(labels))
    plt.imshow(labels)
    plt.show()
(best_iou, (best_iou * 2) / (best_iou + 1))) else: raise ValueError("can't find model") crf = True with torch.no_grad(): img, mask = img.to(device), mask.to(device) output = model(img) #[1, 9, 256, 256] probs = F.softmax(output, dim=1) if crf: pred_crf = probs.cpu().data[0].numpy() # crf img = img.cpu().data[0].numpy() pred_crf = dense_crf(img * 255, pred_crf) pred_crf = np.asarray(pred_crf, dtype=np.int) # 合并特征 pred_crf = merge_classes(pred_crf) _, pred = torch.max(probs, dim=1) pred = pred.cpu().data[0].numpy() label = mask.cpu().data[0].numpy() pred = np.asarray(pred, dtype=np.int) label = np.asarray(label, dtype=np.int) pred = merge_classes(pred) label = merge_classes(label) cv2.namedWindow("image", 0) cv2.imshow("image", image) cv2.namedWindow("mask", 0) cv2.imshow("mask", encode(label, color_test)) cv2.namedWindow("pred", 0)
(best_iou, (best_iou * 2) / (best_iou + 1))) else: raise ValueError("can't find model") print(">>>Test After Dense CRF: ") model.eval() running_metrics.reset() with torch.no_grad(): for i, (img, mask) in tqdm(enumerate(val_loader)): img = img.to(device) output = model(img) #[-1, 9, 256, 256] probs = F.softmax(output, dim=1) pred = probs.cpu().data[0].numpy() label = mask.cpu().data[0].numpy() # crf img = img.cpu().data[0].numpy() pred = dense_crf(img * 255, pred) # print(pred.shape) # _, pred = torch.max(torch.tensor(pred), dim=-1) pred = np.asarray(pred, dtype=np.int) label = np.asarray(label, dtype=np.int) # 合并特征 pred = merge_classes(pred) label = merge_classes(label) # print(pred.shape,label.shape) running_metrics.update(label, pred) score, class_iou = running_metrics.get_scores() for k, v in score.items(): print(k, ':', v) print(i, class_iou)
def save_inference_results_on_disk(loader, network, name):
    """Run ``network`` over ``loader`` and save the outputs to disk in packs.

    Outputs of consecutive batches are concatenated on the GPU; every
    ``config['pack_volume']`` batches the accumulated tensor is saved with
    ``torch.save`` to ``<config['temp_folder']>/<name>/all_outputs_<i>``, where
    ``i`` is the 1-based index of the last batch in the pack.

    Args:
        loader: Data loader yielding ``(images, true_masks)`` pairs. It must
            expose ``config`` (keys read here: ``'pack_volume'``,
            ``'temp_folder'``, ``'with_depth'``, ``'101'``, ``'resize_128'``,
            ``'patch_size'``, ``'crf'``) and ``batch_sampler.sampler``.
        network: Model called as ``network(rgb, depths)``; set to eval mode and
            moved to CUDA here.
        name: Sub-folder name used to build the output path.

    Returns:
        ``len(loader) // pack_volume``, i.e. the number of complete packs.

    NOTE(review): batches left over after the last complete pack are never
    saved (``all_outputs`` is simply dropped at the end) - confirm this is
    intended.
    """
    config = loader.config
    pack_volume = config['pack_volume']
    # Joining with '' leaves a trailing separator so the file name can be appended directly.
    path = os.path.join(config['temp_folder'], name, '')
    print('path ', path)
    network.eval()
    network = network.cuda()
    # Empty CUDA tensor used as the accumulator for the current pack.
    all_outputs = torch.cuda.FloatTensor()
    i = 1  # 1-based batch counter
    print('Inference is in progress')
    print('loader ', loader.batch_sampler.sampler)
    for data in tqdm(loader):
        # true_masks is unpacked but not used below.
        images, true_masks = data
        images = images.cuda()
        # The first three channels are the image; the fourth (index 3) is
        # passed as depth when config['with_depth'] is set.
        images_themselves = images[:, :3]
        if config['with_depth']:
            depths = images[:, 3]
        else:
            depths = None
        size_101 = config['101']
        if config['resize_128']:
            outputs = network(images_themselves, depths).detach()
        else:
            # Run the network on four overlapping corner patches of side
            # size_patch and stitch them into a size_101 x size_101 map.
            size_patch = config['patch_size']
            size_37 = size_101 - size_patch
            outputs_1 = network(images_themselves[:, :, :size_patch, :size_patch], depths).detach()
            outputs_2 = network(images_themselves[:, :, size_37:, :size_patch], depths).detach()
            outputs_3 = network(images_themselves[:, :, :size_patch, size_37:], depths).detach()
            outputs_4 = network(images_themselves[:, :, size_37:, size_37:], depths).detach()
            outputs = torch.from_numpy(np.zeros((outputs_1.shape[0], outputs_1.shape[1], size_101, size_101))).float().cuda()
            # Accumulate each patch at its position ...
            outputs[:, :, :size_patch,:size_patch] += outputs_1
            outputs[:, :, size_37:,:size_patch] += outputs_2
            outputs[:, :, :size_patch,size_37:] += outputs_3
            outputs[:, :, size_37:,size_37:] += outputs_4
            # ... then average: the four edge strips are covered by two patches,
            # the central square by all four.
            outputs[:, :, size_37:size_patch, :size_37] /= 2.0
            outputs[:, :, size_37:size_patch, size_patch:] /= 2.0
            outputs[:, :, :size_37, size_37:size_patch] /= 2.0
            outputs[:, :, size_patch:, size_37:size_patch] /= 2.0
            outputs[:, :, size_37:size_patch, size_37:size_patch] /= 4.0
        outputs = F.sigmoid(outputs)
        # something like smoothing with conditional random fields
        if config['crf']:
            for j, (output, image) in enumerate(zip(outputs, images_themselves)):
                output = output.squeeze(dim=0)
                # print('output before ', output.shape)
                # NOTE(review): swapping dims 0 and 2 turns (C, H, W) into
                # (W, H, C), not (H, W, C) - confirm dense_crf expects this.
                image = torch.transpose(image, dim0=0, dim1=2)
                # print('image ', image.shape)
                output = dense_crf(image.data.cpu().numpy().astype(np.uint8), output.data.cpu().numpy())
                # print('output after', output)
                outputs[j] = torch.from_numpy(output).float()
        if config['resize_128']:
            # Crop every output to size_101 x size_101 with a fixed window.
            # NOTE(review): the offsets 27 / 14 / -13 presumably undo padding of
            # a 101-pixel image up to 128 - confirm against the dataset code.
            resized_outputs = np.zeros((outputs.shape[0], outputs.shape[1], size_101, size_101))
            for j, output in enumerate(outputs):
                output_as_array = output.data.cpu().numpy()
                resized_outputs[j] = output_as_array[:, 27:, 14:-13]
                # resize_image(output_as_array, (size_101, size_101))
            outputs = torch.from_numpy(resized_outputs).cuda().float()
        # Debug visualisation, kept disabled:
        # outputs_for_plot = outputs.cpu().numpy()[0][0]
        # print('outputs_for_plot ', outputs_for_plot, outputs_for_plot.shape)
        # print('true_masks ', true_masks.shape)
        # import matplotlib.pyplot as plt
        # plt.imshow(outputs_for_plot)
        # plt.show()
        # plt.imshow(true_masks[0], cmap='gray')
        # plt.show()
        # input()
        all_outputs = torch.cat((all_outputs, outputs.data), dim=0)
        if i % pack_volume == 0:
            # Flush the full pack to disk and start a new accumulator.
            torch.save(all_outputs, '%sall_outputs_%d' % (path, i))
            all_outputs = torch.cuda.FloatTensor()
            torch.cuda.empty_cache()
        i += 1
    batches_number = len(loader) // pack_volume
    print('batches_number = ', batches_number)
    # Drop the reference to the accumulator before emptying the CUDA cache.
    all_outputs = None
    torch.cuda.empty_cache()
    return batches_number
segment2 = torch.load('segment2.pt')
segment3 = torch.load('segment3.pt')
segment4 = torch.load('segment4.pt')

print(segment1)
# Stack the four saved segment tensors along a new leading axis.
segment = torch.stack([segment1, segment2, segment3, segment4])

orimg = imread("data2/images/train/8049.jpg")
img = resize(orimg, (224, 224))
Q = dense_crf(img, segment.numpy())
print(Q)

# One heatmap per stacked segment, shown one after another.
for segment_index in range(4):
    sns.heatmap(Q[segment_index], cmap="cubehelix")
    plt.show()