def test_model(model,
               device,
               criterion_mask,
               criterion_depth,
               test_loader,
               depthweight=0.5):
    """Evaluate ``model`` on ``test_loader`` and display the last batch.

    Every batch is expected to unpack as ``(bg, image, mask, depthmap)``.
    The model is called as ``model(bg, image)`` and must return
    ``(predmask, preddepth)``. The per-batch loss
    ``(1 - depthweight) * loss_mask + depthweight * loss_depth`` is appended
    to ``test_losses``, a name this function does not define and therefore
    expects to exist at module scope.

    Args:
        model: Network to evaluate; switched to eval mode here.
        device: Device the batch tensors are moved to.
        criterion_mask: Callable invoked as ``criterion_mask(predmask, mask)``.
        criterion_depth: Callable invoked as
            ``criterion_depth(preddepth, depthmap)``.
        test_loader: Sized iterable of batches.
        depthweight: Weight of the depth loss; the mask loss gets
            ``1 - depthweight``.

    Returns:
        ``(test_losses, final_dice_mask, final_dice_depth)`` where the two
        Dice values are the per-batch coefficients averaged over
        ``len(test_loader)``. For an empty loader both averages are ``0`` and
        nothing is displayed.
    """
    model.eval()
    datalength = len(test_loader)
    if datalength == 0:
        # Without batches there is nothing to average (division by zero) and
        # no tensors to visualise, so report and bail out early.
        print('Test loader is empty; skipping evaluation.')
        return test_losses, 0, 0

    dice_mask = 0
    dice_depth = 0
    with torch.no_grad():
        pbar = notebook.tqdm(test_loader)
        for bg, image, mask, depthmap in pbar:
            bg, image, mask, depthmap = bg.to(device), image.to(
                device), mask.to(device), depthmap.to(device)
            predmask, preddepth = model(bg, image)

            loss_mask = criterion_mask(predmask, mask)
            loss_depth = criterion_depth(preddepth, depthmap)

            total_loss = (
                (1 - depthweight) * loss_mask) + (depthweight * loss_depth)
            test_losses.append(total_loss)

            # Dice for the mask is computed on the binarised prediction:
            # sigmoid followed by a 0.5 threshold.
            pred_m = torch.sigmoid(predmask)
            pred_m = (pred_m > 0.5).float()
            dice_mask += dice.dice_coeff(pred_m, mask).item()

            # Dice for the depth map uses the raw model output.
            dice_depth += dice.dice_coeff(preddepth, depthmap).item()

            pbar.set_description(desc=f'Loss={total_loss}')

        final_dice_mask = dice_mask / datalength
        final_dice_depth = dice_depth / datalength
        print('*********************** TEST ***********************')
        print('Mask Dice Coeff: ', final_dice_mask, 'Depthmap Dice Coeff: ',
              final_dice_depth)
        # The grids below show samples 1..4 of the LAST batch processed.
        print('======================= IMAGE ======================')
        print('image:', image.shape)
        visualize.show_img(
            torchvision.utils.make_grid(image.detach().cpu()[1:5]), 8)
        print('======================= MASK =======================')
        print('actual:', mask.shape)
        visualize.show_img(
            torchvision.utils.make_grid(mask.detach().cpu()[1:5]), 8)
        print('predicted:', predmask.shape)
        visualize.show_img(
            torchvision.utils.make_grid(predmask.detach().cpu()[1:5]), 8)
        print('======================= DEPTHMAP ===================')
        print('actual:', depthmap.shape)
        visualize.show_img(
            torchvision.utils.make_grid(depthmap.detach().cpu()[1:5]), 8)
        print('predicted:', preddepth.shape)
        visualize.show_img(
            torchvision.utils.make_grid(preddepth.detach().cpu()[1:5]), 8)
        return test_losses, final_dice_mask, final_dice_depth
Example #2
0
    def func_per_iteration(self, data, device, iter=None):
        """Evaluate one sample and optionally save/log/visualise the result.

        Args:
            data: Mapping providing ``'data'`` (input image), ``'label'``
                (ground truth) and ``'fn'`` (sample name).
            device: Passed through to ``whole_eval`` / ``sliding_eval``.
            iter: Step index used for ``self.logger.add_image``; image
                logging is skipped when it is ``None``.

        Returns:
            Dict with keys ``'hist'``, ``'labeled'`` and ``'correct'`` taken
            from ``hist_info(num_classes, pred, label)``.
        """
        # Use a distinct local name. Assigning to ``config`` inside this
        # method made it function-local, so the original raised
        # UnboundLocalError whenever ``self.config`` was None instead of
        # falling back to the module-level ``config``.
        cfg = self.config if self.config is not None else config
        img = data['data']
        label = data['label']
        name = data['fn']

        # A single evaluation scale uses whole-image inference; otherwise
        # sliding-window inference is used.
        if len(cfg.eval_scale_array) == 1:
            pred = self.whole_eval(img, None, device)
        else:
            pred = self.sliding_eval(img, cfg.eval_crop_size,
                                     cfg.eval_stride_rate, device)
        hist_tmp, labeled_tmp, correct_tmp = hist_info(cfg.num_classes,
                                                       pred, label)
        results_dict = {
            'hist': hist_tmp,
            'labeled': labeled_tmp,
            'correct': correct_tmp
        }

        if self.save_path is not None:
            fn = name + '.png'
            cv2.imwrite(os.path.join(self.save_path, fn), pred)
            logger.info('Save the image ' + fn)

        # tensorboard logger does not fit multiprocess
        if self.logger is not None and iter is not None:
            colors = self.dataset.get_class_colors()
            image = img
            clean = np.zeros(label.shape)
            comp_img = show_img(colors, cfg.background, image, clean, label,
                                pred)
            # Two axis swaps move the last axis to the front
            # (presumably HWC -> CHW for the logger; confirm).
            self.logger.add_image(
                'vis', np.swapaxes(np.swapaxes(comp_img, 0, 2), 1, 2), iter)

        print("self.show_prediction = ", self.show_prediction)
        if self.show_image or self.show_prediction:
            colors = self.dataset.get_class_colors()
            image = img
            clean = np.zeros(label.shape)
            if self.show_image:
                comp_img = show_img(colors, cfg.background, image, clean,
                                    label, pred)
            else:
                comp_img = show_prediction(colors, cfg.background, image,
                                           pred)
            # ``cfg.save`` replaces ``self.config.save`` so the fallback
            # configuration also works here. The last axis is reversed before
            # writing (presumably RGB -> BGR for OpenCV; confirm).
            cv2.imwrite(
                os.path.join(os.path.realpath('.'), cfg.save, "eval",
                             name + ".vis.png"), comp_img[:, :, ::-1])
            # cv2.imwrite(name + ".png", comp_img[:,:,::-1])

        return results_dict
Example #3
0
    def func_per_iteration(self, data, device):
        """Run sliding-window evaluation on one sample and collect statistics.

        ``data`` must provide ``'data'`` (input image), ``'label'`` (ground
        truth) and ``'fn'`` (sample name). The prediction is written to
        ``self.save_path`` when that is set and shown in an OpenCV window
        when ``self.show_image`` is truthy.

        Returns a dict with the ``'hist'``, ``'labeled'`` and ``'correct'``
        values produced by ``hist_info``.
        """
        rgb, gt, sample_name = data['data'], data['label'], data['fn']

        prediction = self.sliding_eval(rgb,
                                       config.eval_crop_size,
                                       config.eval_stride_rate,
                                       device=device)
        confusion, n_labeled, n_correct = hist_info(config.num_classes,
                                                    prediction, gt)

        if self.save_path is not None:
            out_name = sample_name + '.png'
            cv2.imwrite(os.path.join(self.save_path, out_name), prediction)
            logger.info('Save the image ' + out_name)

        if self.show_image:
            # NOTE(review): the attribute is passed on without being called;
            # confirm ``get_class_colors`` is not meant to be invoked here.
            palette = self.dataset.get_class_colors
            blank = np.zeros(gt.shape)
            comparison = show_img(palette, config.background, rgb, blank, gt,
                                  prediction)
            cv2.imshow('comp_image', comparison)
            # Block until a key is pressed.
            cv2.waitKey(0)

        return {
            'hist': confusion,
            'labeled': n_labeled,
            'correct': n_correct,
        }
    def func_per_iteration(self, data, device):
        """Evaluate one RGB + HHA sample and collect statistics.

        ``data`` must provide ``'data'``, ``'label'``, ``'hha_img'`` and
        ``'fn'``. When ``self.save_path`` is set, the raw prediction is
        written there and a palettised copy is written to the sibling
        ``<save_path>_color`` directory. When ``self.show_image`` is truthy
        the comparison image is shown in an OpenCV window.

        Returns a dict with the ``'hist'``, ``'labeled'`` and ``'correct'``
        values produced by ``hist_info``.
        """
        rgb = data['data']
        gt = data['label']
        depth_hha = data['hha_img']
        sample_name = data['fn']
        prediction = self.sliding_eval_rgbdepth(rgb, depth_hha,
                                                config.eval_crop_size,
                                                config.eval_stride_rate,
                                                device)
        confusion, n_labeled, n_correct = hist_info(config.num_classes,
                                                    prediction, gt)

        if self.save_path is not None:
            ensure_dir(self.save_path)
            color_dir = self.save_path + '_color'
            ensure_dir(color_dir)

            out_name = sample_name + '.png'

            # Save the colored result as a palette-mode image.
            colored = Image.fromarray(prediction.astype(np.uint8), mode='P')
            palette = list(np.array(get_class_colors()).flat)
            # Zero-pad the flattened palette up to 768 entries.
            missing = 768 - len(palette)
            if missing > 0:
                palette += [0] * missing
            colored.putpalette(palette)
            colored.save(os.path.join(color_dir, out_name))

            # Save the raw result.
            cv2.imwrite(os.path.join(self.save_path, out_name), prediction)
            logger.info('Save the image ' + out_name)

        if self.show_image:
            # NOTE(review): the attribute is passed on without being called;
            # confirm ``get_class_colors`` is not meant to be invoked here.
            class_palette = self.dataset.get_class_colors
            blank = np.zeros(gt.shape)
            comparison = show_img(class_palette, config.background, rgb,
                                  blank, gt, prediction)
            cv2.imshow('comp_image', comparison)
            # Block until a key is pressed.
            cv2.waitKey(0)

        return {
            'hist': confusion,
            'labeled': n_labeled,
            'correct': n_correct,
        }
Example #5
0
    def func_per_iteration(self, data, device):
        """Resize one sample, run whole-image evaluation, collect statistics.

        The image is resized to ``(config.image_width, config.image_height)``
        with linear interpolation. The label is resized to the same size
        integer-divided by ``config.gt_down_sampling``, using
        nearest-neighbour interpolation. The prediction is written to
        ``self.save_path`` when that is set and shown in an OpenCV window
        when ``self.show_image`` is truthy.

        Returns a dict with the ``'hist'``, ``'labeled'`` and ``'correct'``
        values produced by ``hist_info``.
        """
        rgb, gt, sample_name = data['data'], data['label'], data['fn']

        rgb = cv2.resize(rgb, (config.image_width, config.image_height),
                         interpolation=cv2.INTER_LINEAR)

        # cv2.resize takes (width, height) ...
        gt_size_wh = (config.image_width // config.gt_down_sampling,
                      config.image_height // config.gt_down_sampling)
        gt = cv2.resize(gt, gt_size_wh, interpolation=cv2.INTER_NEAREST)

        # ... whereas whole_eval is handed (height, width).
        out_size_hw = (config.image_height // config.gt_down_sampling,
                       config.image_width // config.gt_down_sampling)
        prediction = self.whole_eval(rgb, out_size_hw,
                                     config.eval_stride_rate, device)
        confusion, n_labeled, n_correct = hist_info(config.num_classes,
                                                    prediction, gt)

        if self.save_path is not None:
            out_name = sample_name + '.png'
            cv2.imwrite(os.path.join(self.save_path, out_name), prediction)
            logger.info('Save the image ' + out_name)

        if self.show_image:
            # NOTE(review): the attribute is passed on without being called;
            # confirm ``get_class_colors`` is not meant to be invoked here.
            palette = self.dataset.get_class_colors
            blank = np.zeros(gt.shape)
            comparison = show_img(palette, config.background, rgb, blank, gt,
                                  prediction)
            cv2.imshow('comp_image', comparison)
            # Block until a key is pressed.
            cv2.waitKey(0)

        return {
            'hist': confusion,
            'labeled': n_labeled,
            'correct': n_correct,
        }
Example #6
0
def predict_images(modelpath, filepath, bgfile, imgfile, dim=64,
                   dispnorm=False):
    """Load a checkpoint, predict mask and depth for one pair, and show them.

    Args:
        modelpath: Path to a checkpoint whose ``'state_dict'`` entry holds
            the weights for ``dnn.CustomNet15``. Required.
        filepath: Forwarded to ``load_image`` / ``load_image_nonorm``.
        bgfile: Background image file, forwarded to the loaders.
        imgfile: Foreground image file, forwarded to the loaders.
        dim: Size argument forwarded to the loaders (previously hard-coded
            to 64).
        dispnorm: When true, also display the inputs fed to the model
            (previously hard-coded to ``False``).

    Raises:
        ValueError: If ``modelpath`` is empty or ``None``. The original code
            skipped model construction in that case and then failed with a
            NameError on ``model.eval()``.
    """
    if not modelpath:
        raise ValueError(
            'modelpath is required: no model is available for prediction')

    bg, image = load_image(filepath, bgfile, imgfile, dim)
    bg_disp, image_disp = load_image_nonorm(filepath, bgfile, imgfile, dim)

    inchannels = 3
    model = dnn.CustomNet15(inchannels)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # map_location lets a checkpoint saved on a GPU be restored on a
    # CPU-only host instead of failing while deserialising CUDA tensors.
    checkpoint = torch.load(modelpath, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device=device)

    model.eval()
    bg, image = bg.to(device), image.to(device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        predmask, preddepth = model(bg, image)

    visualize.show_img(torchvision.utils.make_grid(bg_disp), 3)
    visualize.show_img(torchvision.utils.make_grid(image_disp), 3)
    if dispnorm:
        visualize.show_img(torchvision.utils.make_grid(bg.detach().cpu()), 3)
        visualize.show_img(
            torchvision.utils.make_grid(image.detach().cpu()), 3)
    visualize.show_img(
        torchvision.utils.make_grid(predmask.detach().cpu()), 3)
    visualize.show_img(
        torchvision.utils.make_grid(preddepth.detach().cpu()), 3)

#if __name__ == "__main__":
#    args = get_args()
#    bg_file = args.background
#    img_file = args.image
#    modelpath = args.model

#    inchannels = 3
#    net = dnn.CustomNet15(inchannels)

#    logging.info("Loading model...")

#    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#    logging.info(f'Using device {device}')
#    net.to(device=device)
#    checkpoint = torch.load(modelpath)
#    net.load_state_dict(checkpoint['state_dict'])

#    logging.info("Model loaded !")
#    logging.info("\nPredicting image...")
#    predict_images(bg_file, img_file, net, device, dim=64, dispnorm=False)
def train_model(status,
                epoch,
                model,
                device,
                train_loader,
                criterion_mask,
                criterion_depth,
                optimizer,
                depthweight=0.5,
                printtestimg=False,
                printinterval=2000,
                scheduler=False):
    """Train ``model`` for one pass over ``train_loader``.

    Every batch is expected to unpack as ``(bg, image, mask, depthmap)``.
    The model is called as ``model(bg, image)`` and must return
    ``(predmask, preddepth)``. The optimised loss is
    ``(1 - depthweight) * loss_mask + depthweight * loss_depth`` where the
    depth term comes from ``customloss.depth_loss``.

    Args:
        status: Object whose ``value`` attribute is set to a progress string
            after every batch.
        epoch: Epoch number, used only in the progress string.
        model: Network to train; switched to train mode here.
        device: Device the batch tensors are moved to.
        train_loader: Iterable of batches.
        criterion_mask: Callable invoked as ``criterion_mask(predmask, mask)``.
        criterion_depth: Passed to ``customloss.depth_loss`` as third argument.
        optimizer: Optimizer stepped once per batch.
        depthweight: Weight of the depth loss; the mask loss gets
            ``1 - depthweight``.
        printtestimg: When true, show sample grids every ``printinterval``
            batches.
        printinterval: Batch interval for the sample grids.
        scheduler: Optional scheduler; when truthy it is stepped with the
            batch loss after every optimizer step.

    Returns:
        ``train_losses``, a name this function does not define and therefore
        expects to exist at module scope, extended with one detached loss
        tensor per batch.
    """
    model.train()
    pbar = notebook.tqdm(train_loader)
    for batch_idx, (bg, image, mask, depthmap) in enumerate(pbar):
        bg, image, mask, depthmap = bg.to(device), image.to(device), mask.to(
            device), depthmap.to(device)
        # Init
        optimizer.zero_grad()
        # Predict
        predmask, preddepth = model(bg, image)

        loss_mask = criterion_mask(predmask, mask)
        loss_depth = customloss.depth_loss(preddepth, depthmap,
                                           criterion_depth)
        loss = ((1 - depthweight) * loss_mask) + (depthweight * loss_depth)
        # Store a detached tensor: keeping the graph-attached loss in a
        # long-lived list holds every iteration's autograd graph in memory.
        train_losses.append(loss.detach())

        # Backpropagation
        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step(loss)

        pbar.set_description(desc=f'Loss={loss.item()} Batch_id={batch_idx}')
        status.value = f'epoch={epoch}, Batch_id={batch_idx}, Loss={loss}, Mask={loss_mask}, Depth={loss_depth}'

        # Release cached CUDA memory every 500 batches.
        if batch_idx % 500 == 0:
            torch.cuda.empty_cache()

        if printtestimg and batch_idx % printinterval == 0:
            # The grids below show samples 1..4 of the current batch.
            print('*********************** TRAINING *******************')
            print('======================= IMAGE ======================')
            print('image:', image.shape)
            visualize.show_img(
                torchvision.utils.make_grid(image.detach().cpu()[1:5]), 8)
            print('======================= MASK =======================')
            print('actual:', mask.shape)
            visualize.show_img(
                torchvision.utils.make_grid(mask.detach().cpu()[1:5]), 8)
            print('predicted:', predmask.shape)
            visualize.show_img(
                torchvision.utils.make_grid(predmask.detach().cpu()[1:5]),
                8)
            print('======================= DEPTHMAP ===================')
            print('actual:', depthmap.shape)
            visualize.show_img(
                torchvision.utils.make_grid(depthmap.detach().cpu()[1:5]),
                8)
            print('predicted:', preddepth.shape)
            visualize.show_img(
                torchvision.utils.make_grid(preddepth.detach().cpu()[1:5]),
                8)
    return train_losses