def setupScreenShading(settings: inference.RenderSettings):
    """Create a ScreenSpaceShading on the module-level ``device`` from render settings.

    Camera field of view, light colors, specular exponent, light direction and
    material color are copied from ``settings`` (list-like values are converted
    to NumPy arrays). Ambient occlusion strength is fixed to 1.0 and
    ``inverse_ao`` is disabled.
    """
    screen_shading = ScreenSpaceShading(device)
    # (setter, value) pairs; applied in exactly this order.
    configuration = [
        (screen_shading.fov, settings.CAM_FOV),
        (screen_shading.ambient_light_color, np.array(settings.AMBIENT_LIGHT_COLOR)),
        (screen_shading.diffuse_light_color, np.array(settings.DIFFUSE_LIGHT_COLOR)),
        (screen_shading.specular_light_color, np.array(settings.SPECULAR_LIGHT_COLOR)),
        (screen_shading.specular_exponent, settings.SPECULAR_EXPONENT),
        (screen_shading.light_direction, np.array(settings.LIGHT_DIRECTION)),
        (screen_shading.material_color, np.array(settings.MATERIAL_COLOR)),
        (screen_shading.ambient_occlusion, 1.0),
    ]
    for apply_setting, value in configuration:
        apply_setting(value)
    screen_shading.inverse_ao = False
    return screen_shading
# Example no. 2
# 0
    renderer_path = RENDERER_CPU if DATASETS[i]['file'].endswith(
        'vdb') else RENDERER_GPU
    data_dir = DATA_DIR_CPU if DATASETS[i]['file'].endswith(
        'vdb') else DATA_DIR_GPU
    datasetfile = os.path.join(data_dir, DATASETS[i]['file'])
    print('Open', datasetfile)
    renderer = inference.Renderer(renderer_path, datasetfile, material, camera)
    time.sleep(5)
    renderer.send_command("aoradius=%5.3f\n" % float(AO_RADIUS))
    # create shading
    shading = ScreenSpaceShading(torch.device('cpu'))
    shading.fov(30)
    shading.light_direction(np.array([0.1, 0.1, 1.0]))
    shading.ambient_light_color(np.array(DATASETS[i]['ambient']) / 255.0)
    shading.diffuse_light_color(np.array(DATASETS[i]['diffuse']) / 255.0)
    shading.specular_light_color(np.array(DATASETS[i]['specular']) / 255.0)
    shading.specular_exponent(SPECULAR_EXPONENT)
    shading.material_color(np.array(DATASETS[i]['material']) / 255.0)
    shading.ambient_occlusion(AO_STRENGTH)
    shading.background(np.array(BACKGROUND))

    # render each model
    for k, m in enumerate(MODELS):
        print('Render', m['name'])
        p = m['path']
        outputName = os.path.join(
            OUTPUT_FOLDER, "%s.%s.mp4" % (DATASETS[i]['name'], m['name']))
        writer = imageio.get_writer(outputName, fps=FPS)

        camera.currentYaw = 0
        previous_image = None
import datasetVideo
from utils import ScreenSpaceShading

def _to_numpy(tensor):
    """Return a NumPy copy of *tensor*, detached from autograd and moved to the CPU.

    ``Tensor.numpy()`` alone fails for tensors on a CUDA device or tensors
    that require grad, and the shading below runs on the data's own device.
    """
    return tensor.detach().cpu().numpy()


# Load one sample (deferred shading enabled) through datasetVideo
opt = {'samples': 4, 'numberOfImages': 1}
dataset_data = datasetVideo.collect_samples_clouds_video(4,
                                                         opt,
                                                         deferred_shading=True)
test_full_set = datasetVideo.DatasetFromFullImages(dataset_data, 1)
input_data, input_flow = test_full_set[0]
print(input_data.shape)

# Shade the sample on the device the data lives on
device = input_data.device
shading = ScreenSpaceShading(device)
shading.fov(30)
shading.ambient_light_color(np.array([0.1, 0.1, 0.1]))
shading.diffuse_light_color(np.array([1.0, 1.0, 1.0]))
shading.specular_light_color(np.array([0.2, 0.2, 0.2]))
shading.specular_exponent(16)
shading.light_direction(np.array([0.2, 0.2, 1.0]))
shading.material_color(np.array([1.0, 0.3, 0.3]))

output_rgb = shading(input_data)

# Show mask (channel 0), normal (channels 1-3) and the shaded color side by side
f, axarr = plt.subplots(1, 3)
axarr[0].imshow(_to_numpy(input_data[0, 0, :, :]) * 0.5 + 0.5)  #mask
axarr[1].imshow(
    _to_numpy(input_data[0, 1:4, :, :] * 0.5 + 0.5).transpose((1, 2, 0)))  #normal
axarr[2].imshow(_to_numpy(output_rgb[0, :, :, :]).transpose((1, 2, 0)))  #rgb
plt.show()
class LossNetUnshaded(nn.Module):
    """
    Main Loss Module for unshaded data.

    device: cpu or cuda
    input_channels: must be 5 (mask, normalX, normalY, normalZ, depth)
    output_channels: must be 6 (mask, normalX, normalY, normalZ, depth, ambient occlusion)
    high_res: passed on to builder.gan_loss for the discriminators
    padding: width of the image border that is zeroed before computing losses
    opt: command line arguments (Namespace object) with:
     - losses: comma-separated list of loss terms with weighting as string
       Format: <loss>:<target>[:<weighting>]
       with: loss in {mse/l2/l2_loss, l1/l1_loss, tl2/temp-l2, l2-ds, l1-ds,
                      perceptual, texture, adv/gan, tgan, sgan}
             target in {mask, normal, color, ao, depth, all};
               'all' is required for the GAN losses (adv/gan, tgan, sgan)
               and contributes no loss term for the other losses.
               The downsample losses (l2-ds, l1-ds) cannot target 'ao'
               because the low-resolution input has no ambient occlusion.
             weighting a positive number, defaults to 1.0
     - further parameters depending on the losses

    """

    # Targets accepted in the loss specification.
    _VALID_TARGETS = ('mask', 'normal', 'color', 'ao', 'depth', 'all')
    # Attribute names of the optional discriminator networks.
    _DISCRIMINATOR_NAMES = ('discriminator', 'temp_discriminator', 'spatial_discriminator')

    def __init__(self, device, input_channels, output_channels, high_res, padding, opt):
        super().__init__()
        self.padding = padding
        self.upsample = opt.upsample
        assert input_channels == 5 # mask, normalX, normalY, normalZ, depth
        assert output_channels == 6 # mask, normalX, normalY, normalZ, depth, ambient occlusion
        # List of [name, target] or [name, target, weight], one per loss term
        self.loss_list = [s.split(':') for s in opt.losses.split(',')]
        # Validate the whole specification before building any loss network,
        # so that configuration errors are reported immediately.
        for entry in self.loss_list:
            LossNetUnshaded._validate_loss_entry(entry)

        # Build losses and weights
        builder = LossBuilder(device)
        self.loss_dict = {}
        self.weight_dict = {}

        self.loss_dict['mse'] = builder.mse() #always use mse for psnr
        self.weight_dict[('mse','color')] = 0.0

        content_layers = []
        style_layers = []
        self.has_discriminator = False
        self.has_style_or_content_loss = False
        self.has_temporal_l2_loss = False
        for entry in self.loss_list:
            name = entry[0]
            target = entry[1]
            weight = float(entry[2]) if len(entry) > 2 else 1.0

            if name in ('mse', 'l2', 'l2_loss'):
                self.weight_dict[('mse',target)] = weight
            elif name in ('l1', 'l1_loss'):
                self.loss_dict['l1'] = builder.l1_loss()
                self.weight_dict[('l1',target)] = weight
            elif name in ('tl2', 'temp-l2'):
                self.loss_dict['temp-l2'] = builder.mse()
                self.weight_dict[('temp-l2',target)] = weight
                self.has_temporal_l2_loss = True
            elif name == 'l2-ds':
                self.loss_dict['l2-ds'] = builder.downsample_loss(
                    'l2', opt.upscale_factor, 'bilinear')
                self.weight_dict[('l2-ds',target)] = weight
            elif name == 'l1-ds':
                self.loss_dict['l1-ds'] = builder.downsample_loss(
                    'l1', opt.upscale_factor, 'bilinear')
                self.weight_dict[('l1-ds',target)] = weight
            elif name == 'perceptual':
                content_layers = LossNetUnshaded._parse_layers(opt.perceptualLossLayers)
                self.weight_dict[('perceptual',target)] = weight
                self.has_style_or_content_loss = True
            elif name == 'texture':
                style_layers = LossNetUnshaded._parse_layers(opt.textureLossLayers)
                self.weight_dict[('texture',target)] = weight
                self.has_style_or_content_loss = True
            elif name in ('adv', 'gan'): #spatial-temporal adversary
                assert target=='all'
                self.discriminator, self.adv_loss = builder.gan_loss(
                    opt.discriminator, high_res,
                    26, #5+5+8+8
                    opt)
                self.weight_dict[('adv',target)] = weight
                self.discriminator_use_previous_image = True
                self.discriminator_clip_weights = False
                self.has_discriminator = True
            elif name == 'tgan': #temporal adversary, current high-res + previous high-res
                assert target=='all'
                self.temp_discriminator, self.temp_adv_loss = builder.gan_loss(
                    opt.discriminator, high_res,
                    8+8,
                    opt)
                self.weight_dict[('tgan',target)] = weight
                self.has_discriminator = True
            elif name == 'sgan': #spatial adversary, current high-res + current input
                assert target=='all'
                self.spatial_discriminator, self.spatial_adv_loss = builder.gan_loss(
                    opt.discriminator, high_res,
                    5+8,
                    opt)
                self.weight_dict[('sgan',target)] = weight
                self.has_discriminator = True
            else:
                raise ValueError('unknown loss %s'%name)

        if self.has_style_or_content_loss:
            self.pt_loss, self.style_losses, self.content_losses = \
                    builder.get_style_and_content_loss(dict(content_layers), dict(style_layers))

        self.loss_dict = nn.ModuleDict(self.loss_dict)
        print('Loss weights:', self.weight_dict)

        self.shading = ScreenSpaceShading(device)
        self.shading.fov(30)
        self.shading.ambient_light_color(np.array([opt.lossAmbient, opt.lossAmbient, opt.lossAmbient]))
        self.shading.diffuse_light_color(np.array([opt.lossDiffuse, opt.lossDiffuse, opt.lossDiffuse]))
        self.shading.specular_light_color(np.array([opt.lossSpecular, opt.lossSpecular, opt.lossSpecular]))
        self.shading.specular_exponent(16)
        self.shading.enable_specular = False
        self.shading.light_direction(np.array([0.0, 0.0, 1.0]))
        self.shading.material_color(np.array([1.0, 1.0, 1.0]))
        self.shading.ambient_occlusion(opt.lossAO)
        print('LossNet: ambient occlusion strength:', opt.lossAO)

    @staticmethod
    def _validate_loss_entry(entry):
        """Raise ValueError if a split '<loss>:<target>[:<weight>]' entry is invalid."""
        if len(entry) < 2:
            raise ValueError("illegal format for loss list: " + ':'.join(entry))
        name, target = entry[0], entry[1]
        if target not in LossNetUnshaded._VALID_TARGETS:
            raise ValueError("Unknown target: " + target)
        if name in ('l2-ds', 'l1-ds') and target == 'ao':
            raise ValueError(
                "loss %s cannot target 'ao': the low-resolution input "
                "has no ambient occlusion channel" % name)

    @staticmethod
    def _parse_layers(spec):
        """Parse 'conv_1:0.5,conv_3' into [('conv_1', 0.5), ('conv_3', 1)]."""
        layers = []
        for s in spec.split(','):
            if ':' in s:
                parts = s.split(':')
                layers.append((parts[0], float(parts[1])))
            else:
                layers.append((s, 1))
        return layers

    def _discriminators(self):
        """Yield the configured discriminator networks in a fixed order."""
        for attr in LossNetUnshaded._DISCRIMINATOR_NAMES:
            if hasattr(self, attr):
                yield getattr(self, attr)

    def get_discr_parameters(self):
        """Return the parameters of all configured discriminators as one list."""
        return [p for d in self._discriminators() for p in d.parameters()]

    def discr_eval(self):
        """Switch all configured discriminators to evaluation mode."""
        for d in self._discriminators():
            d.eval()

    def discr_train(self):
        """Switch all configured discriminators to training mode."""
        for d in self._discriminators():
            d.train()

    def print_summary(self, gt_shape, pred_shape, input_shape, prev_pred_warped_shape, num_batches, device):
        """Print torchsummary tables for the VGG loss network and the main discriminator."""
        #Print networks for VGG + Discriminator
        import torchsummary
        res = gt_shape[1]
        # weight_dict keys are (loss, target) tuples, so test the flag instead
        if self.has_style_or_content_loss:
            print('VGG (Perceptual + Style loss)')
            torchsummary.summary(self.pt_loss, (3, res, res), 2*num_batches, device=device.type)
        if hasattr(self, 'discriminator'):
            print('Discriminator:')
            if self.discriminator_use_previous_image:
                # current + previous input (5+5), current + previous prediction with color (8+8),
                # matching the channel count the discriminator was built with
                input_images_shape = (5+5+8+8, res, res)
            else:
                # current input (5) + current prediction with color (8)
                input_images_shape = (5+8, res, res)
            torchsummary.summary(
                self.discriminator,
                input_images_shape,
                2*num_batches,
                device=device.type)

    @staticmethod
    def pad(img, border):
        """
        overwrites the border of 'img' with zeros.
        The size of the border is specified by 'border'.
        The output size is not changed.
        """
        if border==0:
            return img
        _, _, h, w = img.shape
        # crop height and width independently (images need not be square)
        img_crop = img[:, :, border:h-border, border:w-border]
        img_pad = F.pad(img_crop, (border, border, border, border), 'constant', 0)
        assert img_pad.shape[2:] == img.shape[2:]
        return img_pad

    def _with_color(self, tensor):
        """
        Append the shaded color to a 6-channel tensor
        (mask, normal, depth, ao).
        Returns mask + normalized normal + ao + color; this is the single
        channel layout shared by the generator loss and discriminator training.
        """
        assert tensor.shape[1]==6
        mask = tensor[:,0:1,:,:]
        normal = ScreenSpaceShading.normalize(tensor[:,1:4,:,:], dim=1)
        depth_ao = tensor[:,4:6,:,:]
        ao = tensor[:,5:6,:,:]
        color = self.shading(torch.cat([mask, normal, depth_ao], dim=1))
        return torch.cat([mask, normal, ao, color], dim=1)

    #@profile
    def forward(self, gt, pred, input, prev_input_warped, prev_pred_warped):
        """
        gt: ground truth high resolution image (B x C=output_channels x 4W x 4H)
        pred: predicted high resolution image (B x C=output_channels x 4W x 4H)
        input: upsampled low resolution input image (B x C=input_channels x 4W x 4H)
               Required by the downsample losses and by the 'adv'/'sgan'
               discriminators, can be None if only the other losses are used
        prev_input_warped: upsampled, warped previous input image
               Only used by the spatial-temporal discriminator ('adv')
        prev_pred_warped: predicted image from the previous frame warped by the flow
               Shape: B x Cout x 4W x 4H
               Only used for temporal losses ('temp-l2', 'adv', 'tgan'),
               can be None if only the other losses are used

        Returns (generator_loss, loss_values), where loss_values maps every
        evaluated term to its unweighted scalar value.
        """
        Cout = gt.shape[1]
        assert Cout == 6
        assert gt.shape == pred.shape

        loss_names = {name for name, _ in self.weight_dict}
        needs_input = bool(loss_names & {'l2-ds', 'l1-ds', 'adv', 'sgan'})
        if input is None and needs_input:
            raise ValueError("'input' is required by the downsample and adversarial losses")
        use_adv = ('adv','all') in self.weight_dict
        use_tgan = ('tgan','all') in self.weight_dict
        use_sgan = ('sgan','all') in self.weight_dict
        if prev_pred_warped is None and (use_adv or use_tgan or self.has_temporal_l2_loss):
            raise ValueError("'prev_pred_warped' is required by the temporal losses")

        # zero border padding
        gt = LossNetUnshaded.pad(gt, self.padding)
        pred = LossNetUnshaded.pad(pred, self.padding)
        if prev_pred_warped is not None:
            prev_pred_warped = LossNetUnshaded.pad(prev_pred_warped, self.padding)

        # extract mask, normal, depth and ambient occlusion
        gt_mask = gt[:,0:1,:,:]
        gt_mask_clamp = torch.clamp(gt_mask*0.5 + 0.5, 0, 1)
        gt_normal = ScreenSpaceShading.normalize(gt[:,1:4,:,:], dim=1)
        gt_depth = gt[:,4:5,:,:]
        gt_ao = gt[:,5:6,:,:]
        pred_mask = pred[:,0:1,:,:]
        pred_normal = ScreenSpaceShading.normalize(pred[:,1:4,:,:], dim=1)
        pred_depth = pred[:,4:5,:,:]
        pred_ao = pred[:,5:6,:,:]

        # compute color output
        gt_color = self.shading(gt)
        pred_color = self.shading(pred)

        # the input is optional and has no ambient occlusion channel
        if input is not None:
            input_mask = input[:,0:1,:,:]
            input_mask_clamp = torch.clamp(input_mask*0.5 + 0.5, 0, 1)
            input_normal = ScreenSpaceShading.normalize(input[:,1:4,:,:], dim=1)
            input_depth = input[:,4:5,:,:]
            input_color = self.shading(input)

        generator_loss = 0.0
        loss_values = {}

        # normal, simple losses, uses gt+pred
        for name in ['mse','l1']:
            if (name,'mask') in self.weight_dict:
                loss = self.loss_dict[name](gt_mask, pred_mask)
                loss_values[(name,'mask')] = loss.item()
                generator_loss += self.weight_dict[(name,'mask')] * loss
            if (name,'normal') in self.weight_dict:
                loss = self.loss_dict[name](gt_normal*gt_mask_clamp, pred_normal*gt_mask_clamp)
                loss_values[(name,'normal')] = loss.item()
                generator_loss += self.weight_dict[(name,'normal')] * loss
            if (name,'ao') in self.weight_dict:
                loss = self.loss_dict[name](gt_ao*gt_mask_clamp, pred_ao*gt_mask_clamp)
                loss_values[(name,'ao')] = loss.item()
                generator_loss += self.weight_dict[(name,'ao')] * loss
            if (name,'depth') in self.weight_dict:
                loss = self.loss_dict[name](gt_depth*gt_mask_clamp, pred_depth*gt_mask_clamp)
                loss_values[(name,'depth')] = loss.item()
                generator_loss += self.weight_dict[(name,'depth')] * loss
            if (name,'color') in self.weight_dict:
                loss = self.loss_dict[name](gt_color, pred_color)
                loss_values[(name,'color')] = loss.item()
                generator_loss += self.weight_dict[(name,'color')] * loss

        # downsample loss, use input+pred ('ao' is rejected in __init__)
        # TODO: input is passed in upsampled version, but this introduces more errors
        # Better: input low-res input and use the 'low_res_gt=True' option in downsample_loss
        # This requires the input to be upsampled here for the GAN.
        for name in ['l2-ds', 'l1-ds']:
            if (name,'mask') in self.weight_dict:
                loss = self.loss_dict[name](input_mask, pred_mask)
                loss_values[(name,'mask')] = loss.item()
                generator_loss += self.weight_dict[(name,'mask')] * loss
            if (name,'normal') in self.weight_dict:
                loss = self.loss_dict[name](input_normal*input_mask_clamp, pred_normal*input_mask_clamp)
                loss_values[(name,'normal')] = loss.item()
                generator_loss += self.weight_dict[(name,'normal')] * loss
            if (name,'depth') in self.weight_dict:
                loss = self.loss_dict[name](input_depth*input_mask_clamp, pred_depth*input_mask_clamp)
                loss_values[(name,'depth')] = loss.item()
                generator_loss += self.weight_dict[(name,'depth')] * loss
            if (name,'color') in self.weight_dict:
                loss = self.loss_dict[name](input_color, pred_color)
                loss_values[(name,'color')] = loss.item()
                generator_loss += self.weight_dict[(name,'color')] * loss

        # special losses: perceptual+texture, uses gt+pred
        def compute_perceptual(target, in_pred, in_gt):
            if ('perceptual',target) in self.weight_dict \
                    or ('texture',target) in self.weight_dict:
                style_weight=self.weight_dict.get(('texture',target), 0)
                content_weight=self.weight_dict.get(('perceptual',target), 0)
                style_score = 0
                content_score = 0

                # the VGG network records the per-layer losses in the loss modules
                input_images = torch.cat([in_gt, in_pred], dim=0)
                self.pt_loss(input_images)

                for sl in self.style_losses:
                    style_score += sl.loss
                for cl in self.content_losses:
                    content_score += cl.loss

                if ('perceptual',target) in self.weight_dict:
                    loss_values[('perceptual',target)] = content_score.item()
                if ('texture',target) in self.weight_dict:
                    loss_values[('texture',target)] = style_score.item()
                return style_weight * style_score + content_weight * content_score
            return 0
        generator_loss += compute_perceptual('mask', pred_mask.expand(-1,3,-1,-1)*0.5+0.5, gt_mask.expand(-1,3,-1,-1)*0.5+0.5)
        generator_loss += compute_perceptual('normal', (pred_normal*gt_mask_clamp)*0.5+0.5, (gt_normal*gt_mask_clamp)*0.5+0.5)
        generator_loss += compute_perceptual('color', pred_color, gt_color)
        generator_loss += compute_perceptual('ao', pred_ao.expand(-1,3,-1,-1), gt_ao.expand(-1,3,-1,-1))
        generator_loss += compute_perceptual('depth', pred_depth.expand(-1,3,-1,-1), gt_depth.expand(-1,3,-1,-1))

        # special: discriminator, uses input+pred+prev_pred_warped
        if self.has_discriminator:
            # same channel layout (mask, normal, ao, color) as in train_discriminator
            pred_with_color_pad = LossNetUnshaded.pad(self._with_color(pred), self.padding)
            if use_adv or use_sgan:
                input_pad = LossNetUnshaded.pad(input, self.padding)
            if use_adv or use_tgan:
                prev_pred_warped_pad = LossNetUnshaded.pad(
                    self._with_color(prev_pred_warped), self.padding)

            if use_adv: # spatial-temporal
                prev_input_warped_pad = LossNetUnshaded.pad(prev_input_warped, self.padding)
                discr_input = torch.cat([input_pad, prev_input_warped_pad, pred_with_color_pad, prev_pred_warped_pad], dim=1)
                gen_loss = self.adv_loss(self.discriminator(discr_input))
                loss_values['discr_pred'] = gen_loss.item()
                generator_loss += self.weight_dict[('adv','all')] * gen_loss

            if use_tgan: #temporal
                discr_input = torch.cat([pred_with_color_pad, prev_pred_warped_pad], dim=1)
                gen_loss = self.temp_adv_loss(self.temp_discriminator(discr_input))
                loss_values['temp_discr_pred'] = gen_loss.item()
                generator_loss += self.weight_dict[('tgan','all')] * gen_loss

            if use_sgan: #spatial
                discr_input = torch.cat([input_pad, pred_with_color_pad], dim=1)
                gen_loss = self.spatial_adv_loss(self.spatial_discriminator(discr_input))
                loss_values['spatial_discr_pred'] = gen_loss.item()
                generator_loss += self.weight_dict[('sgan','all')] * gen_loss

        # special: temporal l2 loss, uses input (for the mask) + pred + prev_warped
        if self.has_temporal_l2_loss:
            prev_pred_mask = prev_pred_warped[:,0:1,:,:]
            prev_pred_normal = ScreenSpaceShading.normalize(prev_pred_warped[:,1:4,:,:], dim=1)
            if ('temp-l2','mask') in self.weight_dict:
                loss = self.loss_dict['temp-l2'](pred_mask, prev_pred_mask)
                loss_values[('temp-l2','mask')] = loss.item()
                generator_loss += self.weight_dict[('temp-l2','mask')] * loss
            if ('temp-l2','normal') in self.weight_dict:
                loss = self.loss_dict['temp-l2'](
                    pred_normal*gt_mask_clamp,
                    prev_pred_normal*gt_mask_clamp)
                loss_values[('temp-l2','normal')] = loss.item()
                generator_loss += self.weight_dict[('temp-l2','normal')] * loss
            if ('temp-l2','ao') in self.weight_dict:
                prev_pred_ao = prev_pred_warped[:,5:6,:,:]
                loss = self.loss_dict['temp-l2'](
                    pred_ao*gt_mask_clamp,
                    prev_pred_ao*gt_mask_clamp)
                loss_values[('temp-l2','ao')] = loss.item()
                generator_loss += self.weight_dict[('temp-l2','ao')] * loss
            if ('temp-l2','depth') in self.weight_dict:
                prev_pred_depth = prev_pred_warped[:,4:5,:,:]
                loss = self.loss_dict['temp-l2'](
                    pred_depth*gt_mask_clamp,
                    prev_pred_depth*gt_mask_clamp)
                loss_values[('temp-l2','depth')] = loss.item()
                generator_loss += self.weight_dict[('temp-l2','depth')] * loss
            if ('temp-l2','color') in self.weight_dict:
                prev_pred_color = self.shading(prev_pred_warped)
                loss = self.loss_dict['temp-l2'](pred_color, prev_pred_color)
                loss_values[('temp-l2','color')] = loss.item()
                generator_loss += self.weight_dict[('temp-l2','color')] * loss

        return generator_loss, loss_values

    # deprecated
    def evaluate_discriminator(self,
                               current_input, previous_input,
                               current_output, previous_output):
        """
        Deprecated.
        Discriminator takes the following inputs:
            - current input upsampled (B x 5 (mask+normal+depth) x H x W)
            - previous input warped upsampled (B x 5 x H x W)
            - current prediction with color (B x 8 (mask+normal+ao+color) x H x W)
            - previous prediction warped with color (B x 8 x H x W)
        All tensors are already padded
        Returns the score of the discriminator, averaged over the batch
        """
        assert current_input.shape[1] == 5
        assert previous_input.shape[1] == 5
        assert current_output.shape[1] == 8
        assert previous_output.shape[1] == 8

        input = torch.cat([current_input, previous_input, current_output, previous_output], dim=1)
        return self.adv_loss(self.discriminator(input))

    def train_discriminator(self, input, gt_high,
                            previous_input, gt_prev_warped,
                            pred_high, pred_prev_warped):
        """
        All inputs are in high resolution.
        input: B x 5 x H x W (mask+normal+depth)
        gt_high: B x 6 x H x W
        previous_input: B x 5 x H x W
        gt_prev_warped: B x 6 x H x W
        pred_high: B x 6 x H x W
        pred_prev_warped: B x 6 x H x W

        Returns (discr_loss, gt_score, pred_score), each summed over the
        configured discriminators and scaled by their weights.
        """

        assert self.has_discriminator

        def colorize_and_pad(tensor):
            return LossNetUnshaded.pad(self._with_color(tensor), self.padding)

        # assemble input
        input = LossNetUnshaded.pad(input, self.padding)
        gt_high = colorize_and_pad(gt_high)
        pred_high = colorize_and_pad(pred_high)
        previous_input = LossNetUnshaded.pad(previous_input, self.padding)
        gt_prev_warped = colorize_and_pad(gt_prev_warped)
        pred_prev_warped = colorize_and_pad(pred_prev_warped)

        # compute losses
        discr_loss = 0
        gt_score = 0
        pred_score = 0

        if ('adv','all') in self.weight_dict: # spatial-temporal
            gt_input = torch.cat([
                input,
                previous_input,
                gt_high,
                gt_prev_warped], dim=1)
            pred_input = torch.cat([
                input,
                previous_input,
                pred_high,
                pred_prev_warped], dim=1)
            discr_loss0, gt_score0, pred_score0 = self.adv_loss.train_discr(
                gt_input, pred_input, self.discriminator)
            discr_loss += self.weight_dict[('adv','all')] * discr_loss0
            gt_score += self.weight_dict[('adv','all')] * gt_score0
            pred_score += self.weight_dict[('adv','all')] * pred_score0

        if ('tgan','all') in self.weight_dict: # temporal
            gt_input = torch.cat([
                gt_high,
                gt_prev_warped], dim=1)
            pred_input = torch.cat([
                pred_high,
                pred_prev_warped], dim=1)
            discr_loss0, gt_score0, pred_score0 = self.temp_adv_loss.train_discr(
                gt_input, pred_input, self.temp_discriminator)
            discr_loss += self.weight_dict[('tgan','all')] * discr_loss0
            gt_score += self.weight_dict[('tgan','all')] * gt_score0
            pred_score += self.weight_dict[('tgan','all')] * pred_score0

        if ('sgan','all') in self.weight_dict: # spatial
            gt_input = torch.cat([
                input,
                gt_high], dim=1)
            pred_input = torch.cat([
                input,
                pred_high], dim=1)
            discr_loss0, gt_score0, pred_score0 = self.spatial_adv_loss.train_discr(
                gt_input, pred_input, self.spatial_discriminator)
            discr_loss += self.weight_dict[('sgan','all')] * discr_loss0
            gt_score += self.weight_dict[('sgan','all')] * gt_score0
            pred_score += self.weight_dict[('sgan','all')] * pred_score0

        return discr_loss, gt_score, pred_score
# Example no. 5
# 0
def run():
    torch.ops.load_library("./Renderer.dll")

    #########################
    # CONFIGURATION
    #########################

    if 0:
        OUTPUT_FOLDER = "../result-stats/adaptiveIso2/"
        DATASET_PREFIX = "D:/VolumeSuperResolution-InputData/"
        DATASETS = [
            #("Ejecta", "gt-rendering-ejecta-v2-test.hdf5"),
            ("RM", "gt-rendering-rm-v1.hdf5"),
            #("Human", "gt-rendering-human-v1.hdf5"),
            #("Thorax", "gt-rendering-thorax-v1.hdf5"),
        ]

        NETWORK_DIR = "D:/VolumeSuperResolution/adaptive-modeldir/"
        NETWORKS = [  #suffixed with _importance.pt and _recon.pt
            #("adaptive011", "adaptive011_epoch500"), #title, file prefix
            ("adaptive019", "adaptive019_epoch470"),
            ("adaptive023", "adaptive023_epoch300")
        ]

        SAMPLING_FILE = "D:/VolumeSuperResolution-InputData/samplingPattern.hdf5"
        SAMPLING_PATTERNS = ['halton', 'plastic', 'random']

        HEATMAP_MIN = [0.01, 0.05, 0.2]
        HEATMAP_MEAN = [0.05, 0.1, 0.2, 0.5]

        UPSCALING = 8  # = networkUp * postUp

        IMPORTANCE_BORDER = 8
        LOSS_BORDER = 32
        BATCH_SIZE = 8

    elif 0:
        OUTPUT_FOLDER = "../result-stats/adaptiveIsoEnhance6/"
        DATASET_PREFIX = "D:/VolumeSuperResolution-InputData/"
        DATASETS = [
            ("Ejecta", "gt-rendering-ejecta-v2-test.hdf5"),
            #("RM", "gt-rendering-rm-v1.hdf5"),
            #("Human", "gt-rendering-human-v1.hdf5"),
            #("Thorax", "gt-rendering-thorax-v1.hdf5"),
            #("Head", "gt-rendering-head.hdf5"),
        ]

        NETWORK_DIR = "D:/VolumeSuperResolution/adaptive-modeldir/"
        NETWORKS = [  #suffixed with _importance.pt and _recon.pt
            #title, file prefix
            #("U-Net (5-4)", "sizes/size5-4_epoch500"),
            #("Enhance-Net (epoch 50)", "enhance2_imp050_epoch050"),
            #("Enhance-Net (epoch 400)", "enhance2_imp050_epoch400"),
            #("Enhance-Net (Thorax)", "enhance_imp050_Thorax_epoch200"),
            #("Enhance-Net (RM)", "enhance_imp050_RM_epoch200"),
            #("Imp100", "enhance4_imp100_epoch300"),
            #("Imp100res", "enhance4_imp100res_epoch230"),
            ("Imp100res+N", "enhance4_imp100res+N_epoch300"),
            ("Imp100+N", "enhance4_imp100+N_epoch300"),
            #("Imp100+N-res", "enhance4_imp100+N-res_epoch300"),
            #("Imp100+N-resInterp", "enhance4_imp100+N-resInterp_epoch300"),
            #("U-Net (5-4)", "size5-4_epoch500"),
            #("U-Net (5-3)", "size5-3_epoch500"),
            #("U-Net (4-4)", "size4-4_epoch500"),
        ]

        # Test if it is better to post-train with dense networks and PDE inpainting
        POSTTRAIN_NETWORK_DIR = "D:/VolumeSuperResolution/dense-modeldir/"
        POSTTRAIN_NETWORKS = [
            # title, file suffix to POSTTRAIN_NETWORK_DIR, inpainting {'fast', 'pde'}
            #("Enhance PDE (post)", "inpHv2-pde05-epoch200.pt", "pde")
        ]

        SAMPLING_FILE = "D:/VolumeSuperResolution-InputData/samplingPattern.hdf5"
        SAMPLING_PATTERNS = ['plastic']

        HEATMAP_MIN = [0.002]
        HEATMAP_MEAN = [
            0.05
        ]  #[0.01, 0.02, 0.03, 0.04, 0.06, 0.08, 0.1, 0.2, 0.3, 0.5, 0.8, 1.0]
        USE_BINARY_SEARCH_ON_MEAN = True

        UPSCALING = 8  # = networkUp * postUp

        IMPORTANCE_BORDER = 8
        LOSS_BORDER = 32
        BATCH_SIZE = 4

    elif 0:
        OUTPUT_FOLDER = "../result-stats/adaptiveIsoEnhance5Sampling/"
        DATASET_PREFIX = "D:/VolumeSuperResolution-InputData/"
        DATASETS = [
            ("Ejecta", "gt-rendering-ejecta-v2-test.hdf5"),
        ]

        NETWORK_DIR = "D:/VolumeSuperResolution/adaptive-modeldir/"
        NETWORKS = [  #suffixed with _importance.pt and _recon.pt
            #title, file prefix
            ("Enhance-Net (epoch 400)", "enhance2_imp050_epoch400"),
        ]

        # Test if it is better to post-train with dense networks and PDE inpainting
        POSTTRAIN_NETWORK_DIR = "D:/VolumeSuperResolution/dense-modeldir/"
        POSTTRAIN_NETWORKS = [
            # title, file suffix to POSTTRAIN_NETWORK_DIR, inpainting {'fast', 'pde'}
            #("Enhance PDE (post)", "inpHv2-pde05-epoch200.pt", "pde")
        ]

        SAMPLING_FILE = "D:/VolumeSuperResolution-InputData/samplingPattern.hdf5"
        SAMPLING_PATTERNS = ['halton', 'plastic', 'random', 'regular']
        #SAMPLING_PATTERNS = ['regular']

        HEATMAP_MIN = [0.002]
        HEATMAP_MEAN = [0.05]
        USE_BINARY_SEARCH_ON_MEAN = True

        UPSCALING = 8  # = networkUp * postUp

        IMPORTANCE_BORDER = 8
        LOSS_BORDER = 32
        BATCH_SIZE = 4

    elif 1:
        OUTPUT_FOLDER = "../result-stats/adaptiveIsoEnhance8Sampling/"
        DATASET_PREFIX = "D:/VolumeSuperResolution-InputData/"
        DATASETS = [
            ("Ejecta", "gt-rendering-ejecta-v2-test.hdf5"),
            #("RM", "gt-rendering-rm-v1.hdf5"),
            #("Human", "gt-rendering-human-v1.hdf5"),
            #("Thorax", "gt-rendering-thorax-v1.hdf5"),
            #("Head", "gt-rendering-head.hdf5"),
        ]

        NETWORK_DIR = "D:/VolumeSuperResolution/adaptive-modeldir/"
        NETWORKS = [  #suffixed with _importance.pt and _recon.pt
            #title, file prefix
            ("regular", "enhance7_regular_epoch190"),
            ("random", "enhance7_random_epoch190"),
            ("halton", "enhance7_halton_epoch190"),
            ("plastic", "enhance7_plastic_epoch190"),
        ]

        # Test if it is better to post-train with dense networks and PDE inpainting
        POSTTRAIN_NETWORK_DIR = "D:/VolumeSuperResolution/dense-modeldir/"
        POSTTRAIN_NETWORKS = [
            # title, file suffix to POSTTRAIN_NETWORK_DIR, inpainting {'fast', 'pde'}
            #("Enhance PDE (post)", "inpHv2-pde05-epoch200.pt", "pde")
        ]

        SAMPLING_FILE = "D:/VolumeSuperResolution-InputData/samplingPattern.hdf5"
        SAMPLING_PATTERNS = ['regular', 'random', 'halton', 'plastic']

        HEATMAP_MIN = [0.002]
        HEATMAP_MEAN = [
            0.05
        ]  #[0.01, 0.02, 0.03, 0.04, 0.06, 0.08, 0.1, 0.2, 0.3, 0.5, 0.8, 1.0]
        USE_BINARY_SEARCH_ON_MEAN = True

        UPSCALING = 8  # = networkUp * postUp

        IMPORTANCE_BORDER = 8
        LOSS_BORDER = 32
        BATCH_SIZE = 4

    elif 0:
        OUTPUT_FOLDER = "../result-stats/adaptiveIsoImp/"
        DATASET_PREFIX = "D:/VolumeSuperResolution-InputData/"
        DATASETS = [
            ("Ejecta", "gt-rendering-ejecta-v2-test.hdf5"),
            #("RM", "gt-rendering-rm-v1.hdf5"),
            #("Human", "gt-rendering-human-v1.hdf5"),
            #("Thorax", "gt-rendering-thorax-v1.hdf5"),
        ]

        NETWORK_DIR = "D:/VolumeSuperResolution/adaptive-modeldir/imp/"
        NETWORKS = [  #suffixed with _importance.pt and _recon.pt
            #("adaptive011", "adaptive011_epoch500"), #title, file prefix
            ("imp005", "imp005_epoch500"),
            ("imp010", "imp010_epoch500"),
            ("imp020", "imp020_epoch500"),
            ("imp050", "imp050_epoch500"),
        ]

        SAMPLING_FILE = "D:/VolumeSuperResolution-InputData/samplingPattern.hdf5"
        SAMPLING_PATTERNS = ['halton']

        HEATMAP_MIN = [0.002]
        HEATMAP_MEAN = [0.005, 0.01, 0.02, 0.05, 0.1]
        USE_BINARY_SEARCH_ON_MEAN = True

        UPSCALING = 8  # = networkUp * postUp

        IMPORTANCE_BORDER = 8
        LOSS_BORDER = 32
        BATCH_SIZE = 16

    #########################
    # LOADING
    #########################

    device = torch.device("cuda")

    # Load Networks
    IMPORTANCE_BASELINE1 = "ibase1"
    IMPORTANCE_BASELINE2 = "ibase2"
    IMPORTANCE_BASELINE3 = "ibase3"
    RECON_BASELINE = "rbase"

    # load importance model
    print("load importance networks")

    class ImportanceModel:
        """Importance-map generator: a hand-crafted baseline or a trained network.

        ``file`` is IMPORTANCE_BASELINE1 (uniform map, named "constant"),
        IMPORTANCE_BASELINE2 (gradient-based map, named "curvature") or a
        ``(title, file_prefix)`` entry of NETWORKS, which loads the TorchScript
        module ``NETWORK_DIR/<file_prefix>_importance.pt`` together with the
        ``settings.json`` stored inside it.
        """

        def __init__(self, file):
            if file == IMPORTANCE_BASELINE1:
                self._use_baseline(importance.UniformImportanceMap(1, 0.5),
                                   "constant")
            elif file == IMPORTANCE_BASELINE2:
                self._use_baseline(
                    importance.GradientImportanceMap(1, (1, 1), (2, 1), (3, 1)),
                    "curvature")
            else:
                self._load_network(file)

        def _use_baseline(self, net, title):
            """Configure a baseline map: upscaling 1, no temporal input."""
            self._net = net
            self._upscaling = 1
            self._name = title
            self.disableTemporal = True
            self._requiresPrevious = False

        def _load_network(self, entry):
            """Load a trained importance network and read its settings."""
            self._name = entry[0]
            model_path = os.path.join(NETWORK_DIR, entry[1] + "_importance.pt")
            extra_files = torch._C.ExtraFilesMap()
            # torch.jit.load fills this entry with the archived settings file
            extra_files['settings.json'] = ""
            self._net = torch.jit.load(model_path,
                                       map_location=device,
                                       _extra_files=extra_files)
            settings = json.loads(extra_files['settings.json'])
            self._upscaling = settings['networkUpscale']
            self._requiresPrevious = settings.get("requiresPrevious", False)
            self.disableTemporal = settings.get("disableTemporal", True)

        def networkUpscaling(self):
            return self._upscaling

        def name(self):
            return self._name

        def __repr__(self):
            return self.name()

        def call(self, input, prev_warped_out):
            """Compute the importance map for ``input``.

            If the network was trained with the previous frame, the warped
            previous output is passed through models.VideoTools.flatten_high
            (with this model's upscaling factor) and appended as extra
            channels. The input is zero-padded by IMPORTANCE_BORDER on every
            side and the correspondingly upscaled border is cropped from the
            output again (F.pad with negative amounts crops).
            """
            if self._requiresPrevious:
                flattened_prev = models.VideoTools.flatten_high(
                    prev_warped_out, self._upscaling)
                input = torch.cat([input, flattened_prev], dim=1)
            padded_input = F.pad(input, [IMPORTANCE_BORDER] * 4, 'constant', 0)
            raw_map = self._net(padded_input)  # the network call
            crop_amount = -IMPORTANCE_BORDER * self._upscaling
            return F.pad(raw_map, [crop_amount] * 4, 'constant', 0)

    class LuminanceImportanceModel:
        """Importance "network" that replays precomputed luminance-contrast maps.

        For a test file ``X.hdf5`` the maps are read from the ``importance``
        dataset of ``X-luminanceImportance.hdf5``. If that file does not exist,
        ``call`` returns a constant importance of one.
        """

        def __init__(self):
            self.disableTemporal = True
            # Start in the "no file selected" state so that isAvailable() and
            # call() work (returning False / a constant map) even before
            # setTestFile() has been called.
            self._exist = False
            self._file = None
            self._dset = None

        def _close(self):
            """Close the currently opened importance file, if any."""
            if getattr(self, '_file', None) is not None:
                self._file.close()
            self._exist = False
            self._file = None
            self._dset = None

        def setTestFile(self, filename):
            """Select the importance file that belongs to test file ``filename``.

            The file opened for a previous test file is closed first, so that
            iterating over several datasets does not leak HDF5 handles.
            """
            self._close()
            # filename[:-5] strips the extension; assumes it is ".hdf5"
            importance_file = filename[:-5] + "-luminanceImportance.hdf5"
            if os.path.exists(importance_file):
                self._file = h5py.File(importance_file, 'r')
                self._dset = self._file['importance']
                self._exist = True

        def isAvailable(self):
            return self._exist

        def setIndices(self, indices: torch.Tensor):
            """Set the dataset sample indices of the next batch (1D tensor)."""
            assert len(indices.shape) == 1
            self._indices = list(indices.cpu().numpy())

        def setTime(self, time):
            """Set the timestep that is read by the next call()."""
            self._time = time

        def networkUpscaling(self):
            return UPSCALING

        def name(self):
            return "luminance-contrast"

        def __repr__(self):
            return self.name()

        def call(self, input, prev_warped_out):
            """Return the stored importance maps for the current indices/time.

            ``prev_warped_out`` is unused. Without an importance file, a
            tensor of ones with shape (B, 1, H, W) matching ``input`` is
            returned.
            """
            B, C, H, W = input.shape
            if not self._exist:
                return torch.ones(B,
                                  1,
                                  H,
                                  W,
                                  dtype=input.dtype,
                                  device=input.device)
            outputs = []
            for idx in self._indices:
                outputs.append(
                    torch.from_numpy(self._dset[idx, self._time,
                                                ...]).to(device=input.device))
            return torch.stack(outputs, dim=0)

    importanceBaseline1 = ImportanceModel(IMPORTANCE_BASELINE1)
    importanceBaseline2 = ImportanceModel(IMPORTANCE_BASELINE2)
    importanceBaselineLuminance = LuminanceImportanceModel()
    importanceModels = [ImportanceModel(f) for f in NETWORKS]

    # load reconstruction networks
    print("load reconstruction networks")

    class ReconstructionModel:
        """Reconstruction of dense output from sparse samples.

        ``file`` is either RECON_BASELINE, which uses plain fast inpainting,
        or a ``(title, file_prefix)`` entry of NETWORKS, which loads the
        TorchScript module ``NETWORK_DIR/<file_prefix>_recon.pt`` and its
        ``settings.json``. Networks whose settings request
        ``interpolateInput`` are wrapped so that fast inpainting runs first.
        """

        def __init__(self, file):
            if file == RECON_BASELINE:
                self._setup_inpainting_baseline()
            else:
                self._setup_network(file)

        def _setup_inpainting_baseline(self):
            """Use fast inpainting of the samples as the reconstruction."""

            class Inpainting(nn.Module):
                def forward(self, x, mask):
                    # the mask argument is ignored; channel 6 of x is used
                    samples = x[:, 0:6, :, :].contiguous(
                    )  # mask, normal xyz, depth, ao
                    sampling_mask = x[:, 6, :, :].contiguous()
                    return torch.ops.renderer.fast_inpaint(
                        sampling_mask, samples)

            self._net = Inpainting()
            self._upscaling = 1
            self._name = "inpainting"
            self.disableTemporal = True

        def _setup_network(self, entry):
            """Load a trained reconstruction network and its settings."""
            self._name = entry[0]
            model_path = os.path.join(NETWORK_DIR, entry[1] + "_recon.pt")
            extra_files = torch._C.ExtraFilesMap()
            # torch.jit.load fills this entry with the archived settings file
            extra_files['settings.json'] = ""
            self._net = torch.jit.load(model_path,
                                       map_location=device,
                                       _extra_files=extra_files)
            self._settings = json.loads(extra_files['settings.json'])
            self.disableTemporal = False
            expects_mask = self._settings.get('expectMask', False)
            if not self._settings.get("interpolateInput", False):
                return

            # The network expects pre-inpainted input: wrap it so that the
            # sample channels are filled by fast inpainting beforehand.
            self._originalNet = self._net

            class Inpainting2(nn.Module):
                def __init__(self, original_net, requires_mask):
                    super().__init__()
                    self._n = original_net
                    self._requiresMask = requires_mask

                def forward(self, x, mask):
                    samples = x[:, 0:6, :, :].contiguous(
                    )  # mask, normal xyz, depth, ao
                    sampling_mask = x[:, 6, :, :].contiguous()
                    filled = torch.ops.renderer.fast_inpaint(
                        sampling_mask, samples)
                    # overwrite the sample channels in place
                    x[:, 0:6, :, :] = filled
                    if not self._requiresMask:
                        return self._n(x)
                    return self._n(x, sampling_mask)

            self._net = Inpainting2(self._originalNet, expects_mask)

        def name(self):
            return self._name

        def __repr__(self):
            return self.name()

        def call(self, input, mask, prev_warped_out):
            """Run the reconstruction on ``input`` + warped previous output."""
            network_input = torch.cat([input, prev_warped_out], dim=1)
            return self._net(network_input, mask)

    class ReconstructionModelPostTrain:
        """
        Reconstruction model that is trained as a dense reconstruction network
        after the adaptive training.
        It does not receive the sampling mask as input; instead the sparse
        samples are first filled by fast or PDE-based inpainting.
        """
        def __init__(self, name: str, model_path: str, inpainting: str):
            assert inpainting == 'fast' or inpainting == 'pde', "inpainting must be either 'fast' or 'pde', but got %s" % inpainting
            self._inpainting = inpainting

            self._name = name
            file = os.path.join(POSTTRAIN_NETWORK_DIR, model_path)
            extra_files = torch._C.ExtraFilesMap()
            # torch.jit.load fills this entry with the archived settings file
            extra_files['settings.json'] = ""
            self._net = torch.jit.load(file,
                                       map_location=device,
                                       _extra_files=extra_files)
            self._settings = json.loads(extra_files['settings.json'])
            assert self._settings.get(
                'upscale_factor', None) == 1, "selected file is not a 1x SRNet"
            self.disableTemporal = False

        def name(self):
            return self._name

        def __repr__(self):
            return self.name()

        def call(self, input, *args, prev_warped_out=None):
            """
            Inpaint the sparse samples in ``input`` and run the dense network.

            Accepts both ``call(input, prev_warped_out)`` and
            ``call(input, mask, prev_warped_out)``. The second form matches
            ReconstructionModel.call, which is mixed with this class in
            NETWORK_COMBINATIONS. The explicit mask is ignored; the sampling
            mask is read from channel 6 of ``input``.
            """
            if prev_warped_out is None:
                if len(args) not in (1, 2):
                    raise TypeError(
                        "call() expects (input, prev_warped_out) or "
                        "(input, mask, prev_warped_out)")
                prev_warped_out = args[-1]
            elif len(args) > 1:
                raise TypeError(
                    "call() got too many positional arguments")

            # no sampling and no AO
            input_no_sampling = input[:, 0:5, :, :].contiguous(
            )  # mask, normal xyz, depth
            sampling_mask = input[:, 6, :, :].contiguous()
            # perform inpainting
            if self._inpainting == 'pde':
                inpainted = torch.ops.renderer.pde_inpaint(
                    sampling_mask,
                    input_no_sampling,
                    200,
                    1e-4,
                    5,
                    2,  # m0, epsilon, m1, m2
                    0,  # mc -> multigrid recursion count. =0 disables the multigrid hierarchy
                    9,
                    0)  # ms, m3
            else:
                inpainted = torch.ops.renderer.fast_inpaint(
                    sampling_mask, input_no_sampling)
            # run network
            network_input = torch.cat([inpainted, prev_warped_out], dim=1)
            output = self._net(network_input)
            # some networks return (output, extra...); keep the prediction
            if isinstance(output, tuple):
                output = output[0]
            return output

    reconBaseline = ReconstructionModel(RECON_BASELINE)
    reconModels = [ReconstructionModel(f) for f in NETWORKS]
    reconPostModels = [
        ReconstructionModelPostTrain(name, file, inpainting)
        for (name, file, inpainting) in POSTTRAIN_NETWORKS
    ]
    allReconModels = reconModels + reconPostModels

    NETWORK_COMBINATIONS = \
        [(importanceBaseline1, reconBaseline), (importanceBaseline2, reconBaseline)] + \
        [(importanceBaselineLuminance, reconBaseline)] + \
        [(importanceBaseline1, reconNet) for reconNet in allReconModels] + \
        [(importanceBaseline2, reconNet) for reconNet in allReconModels] + \
        [(importanceBaselineLuminance, reconNet) for reconNet in allReconModels] + \
        [(importanceNet, reconBaseline) for importanceNet in importanceModels] + \
        list(zip(importanceModels, reconModels)) + \
        [(importanceNet, reconPostModel) for importanceNet in importanceModels for reconPostModel in reconPostModels]
    #NETWORK_COMBINATIONS = list(zip(importanceModels, reconModels))
    print("Network combinations:")
    for (i, r) in NETWORK_COMBINATIONS:
        print("  %s - %s" % (i.name(), r.name()))

    # load sampling patterns
    print("load sampling patterns")
    with h5py.File(SAMPLING_FILE, 'r') as f:
        sampling_pattern = dict([(name, torch.from_numpy(f[name][...]).to(device)) \
            for name in SAMPLING_PATTERNS])

    # create shading
    shading = ScreenSpaceShading(device)
    shading.fov(30)
    shading.ambient_light_color(np.array([0.1, 0.1, 0.1]))
    shading.diffuse_light_color(np.array([1.0, 1.0, 1.0]))
    shading.specular_light_color(np.array([0.0, 0.0, 0.0]))
    shading.specular_exponent(16)
    shading.light_direction(np.array([0.1, 0.1, 1.0]))
    shading.material_color(np.array([1.0, 0.3, 0.0]))
    AMBIENT_OCCLUSION_STRENGTH = 1.0
    shading.ambient_occlusion(1.0)
    shading.inverse_ao = False

    #heatmap
    HEATMAP_CFG = [(min, mean) for min in HEATMAP_MIN for mean in HEATMAP_MEAN
                   if min < mean]
    print("heatmap configs:", HEATMAP_CFG)

    #########################
    # DEFINE STATISTICS
    #########################
    ssimLoss = SSIM(size_average=False)
    ssimLoss.to(device)
    psnrLoss = PSNR()
    psnrLoss.to(device)
    lpipsColor = lpips.PerceptualLoss(model='net-lin',
                                      net='alex',
                                      use_gpu=True)
    MIN_FILLING = 0.05
    NUM_BINS = 200

    class Statistics:
        """Collects error metrics and error histograms for one configuration.

        Per-sample metrics (PSNR, SSIM, LPIPS, sample ratio) are written row by
        row into an HDF5 dataset. Per-pixel error distributions are
        accumulated as running means of histograms over ``NUM_BINS`` bins in
        [0, 1] and written by ``write_histogram``.
        """

        def __init__(self):
            self.histogram_color_withAO = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_color_noAO = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_depth = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_normal = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_mask = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_ao = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_counter = 0

        def create_datasets(self, hdf5_file: h5py.File, stats_name: str,
                            histo_name: str, num_samples: int,
                            extra_info: dict):
            """Create (or reopen) the stats and histogram datasets.

            Field names are stored as attributes ``Field<i>``, and every entry
            of ``extra_info`` is attached as an attribute of both datasets.
            """
            self.expected_num_samples = num_samples
            stats_shape = (num_samples, len(list(StatField)))
            self.stats_file = hdf5_file.require_dataset(stats_name,
                                                        stats_shape,
                                                        dtype='f',
                                                        exact=True)
            self.stats_file.attrs['NumFields'] = len(list(StatField))
            for field in list(StatField):
                self.stats_file.attrs['Field%d' % field.value] = field.name
            for key, value in extra_info.items():
                self.stats_file.attrs[key] = value
            self.stats_index = 0

            histo_shape = (NUM_BINS, len(list(HistoField)))
            self.histo_file = hdf5_file.require_dataset(histo_name,
                                                        histo_shape,
                                                        dtype='f',
                                                        exact=True)
            self.histo_file.attrs['NumFields'] = len(list(HistoField))
            for field in list(HistoField):
                self.histo_file.attrs['Field%d' % field.value] = field.name
            for key, value in extra_info.items():
                self.histo_file.attrs[key] = value

        @staticmethod
        def _crop(t):
            """Remove LOSS_BORDER pixels on every side of a B x C x H x W tensor."""
            if LOSS_BORDER <= 0:
                # t[..., 0:-0] would be empty, so a zero border must not slice
                return t
            return t[:, :, LOSS_BORDER:-LOSS_BORDER, LOSS_BORDER:-LOSS_BORDER]

        def _accumulate_histogram(self, running, diff, B):
            """Fold the histogram of ``diff`` into the running mean ``running``.

            ``running`` is updated in place. ``np.histogram(density=True)``
            over [0, 1] with NUM_BINS bins sums to NUM_BINS, so each batch
            contributes a histogram summing to 1/B.
            NOTE(review): because of the 1/B factor, a smaller last batch is
            weighted differently; this is kept so that results stay comparable
            with existing output files.
            """
            histogram, _ = np.histogram(diff.cpu().numpy(),
                                        bins=NUM_BINS,
                                        range=(0, 1),
                                        density=True)
            running += (histogram /
                        (NUM_BINS * B) - running) / self.histogram_counter

        def add_timestep_sample(self, pred_mnda, gt_mnda, sampling_mask):
            """
            Adds the statistics of one batch of timestep samples.

            pred_mnda: prediction: mask, normal, depth, AO
            gt_mnda: ground truth: mask, normal, depth, AO
            sampling_mask: sampling mask; its per-sample mean is stored as the
                number of samples.
            """
            B = pred_mnda.shape[0]

            # shading, once with and once without ambient occlusion
            shading.ambient_occlusion(AMBIENT_OCCLUSION_STRENGTH)
            pred_color_withAO = shading(pred_mnda)
            gt_color_withAO = shading(gt_mnda)
            shading.ambient_occlusion(0.0)
            pred_color_noAO = shading(pred_mnda)
            gt_color_noAO = shading(gt_mnda)

            # apply border
            pred_mnda = self._crop(pred_mnda)
            pred_color_withAO = self._crop(pred_color_withAO)
            pred_color_noAO = self._crop(pred_color_noAO)
            gt_mnda = self._crop(gt_mnda)
            gt_color_withAO = self._crop(gt_color_withAO)
            gt_color_noAO = self._crop(gt_color_noAO)

            # map the ground-truth mask channel (presumably in [-1, 1]) to a
            # [0, 1] weight
            mask = gt_mnda[:, 0:1, :, :] * 0.5 + 0.5

            # PSNR
            psnr_mask = psnrLoss(pred_mnda[:, 0:1, :, :],
                                 gt_mnda[:, 0:1, :, :]).cpu().numpy()
            psnr_normal = psnrLoss(pred_mnda[:, 1:4, :, :],
                                   gt_mnda[:, 1:4, :, :],
                                   mask=mask).cpu().numpy()
            psnr_depth = psnrLoss(pred_mnda[:, 4:5, :, :],
                                  gt_mnda[:, 4:5, :, :],
                                  mask=mask).cpu().numpy()
            psnr_ao = psnrLoss(pred_mnda[:, 5:6, :, :],
                               gt_mnda[:, 5:6, :, :],
                               mask=mask).cpu().numpy()
            psnr_color_withAO = psnrLoss(pred_color_withAO,
                                         gt_color_withAO,
                                         mask=mask).cpu().numpy()
            psnr_color_noAO = psnrLoss(pred_color_noAO,
                                       gt_color_noAO,
                                       mask=mask).cpu().numpy()

            # SSIM
            ssim_mask = ssimLoss(pred_mnda[:, 0:1, :, :],
                                 gt_mnda[:, 0:1, :, :]).cpu().numpy()
            # Keep the raw predicted mask. The blend below replaces the
            # prediction by the ground truth wherever the ground truth is
            # empty, which would hide false-positive mask errors in the
            # mask histogram (PSNR/SSIM of the mask also use the raw value).
            pred_mask_raw = pred_mnda[:, 0, :, :]
            # restrict normal/depth/AO errors to pixels inside the gt mask
            pred_mnda = gt_mnda + mask * (pred_mnda - gt_mnda)
            ssim_normal = ssimLoss(pred_mnda[:, 1:4, :, :],
                                   gt_mnda[:, 1:4, :, :]).cpu().numpy()
            ssim_depth = ssimLoss(pred_mnda[:, 4:5, :, :],
                                  gt_mnda[:, 4:5, :, :]).cpu().numpy()
            ssim_ao = ssimLoss(pred_mnda[:, 5:6, :, :],
                               gt_mnda[:, 5:6, :, :]).cpu().numpy()
            ssim_color_withAO = ssimLoss(pred_color_withAO,
                                         gt_color_withAO).cpu().numpy()
            ssim_color_noAO = ssimLoss(pred_color_noAO,
                                       gt_color_noAO).cpu().numpy()

            # Perceptual
            lpips_color_withAO = torch.cat([
                lpipsColor(
                    pred_color_withAO[b], gt_color_withAO[b], normalize=True)
                for b in range(B)
            ],
                                           dim=0).cpu().numpy()
            lpips_color_noAO = torch.cat([
                lpipsColor(
                    pred_color_noAO[b], gt_color_noAO[b], normalize=True)
                for b in range(B)
            ],
                                         dim=0).cpu().numpy()

            # Samples
            samples = torch.mean(sampling_mask, dim=(1, 2, 3)).cpu().numpy()

            # Write samples to file
            for b in range(B):
                assert self.stats_index < self.expected_num_samples, "Adding more samples than specified"
                self.stats_file[self.stats_index, :] = np.array([
                    psnr_mask[b], psnr_normal[b], psnr_depth[b], psnr_ao[b],
                    psnr_color_noAO[b], psnr_color_withAO[b], ssim_mask[b],
                    ssim_normal[b], ssim_depth[b], ssim_ao[b],
                    ssim_color_noAO[b], ssim_color_withAO[b],
                    lpips_color_noAO[b], lpips_color_withAO[b], samples[b]
                ],
                                                                dtype='f')
                self.stats_index += 1

            # Histograms
            self.histogram_counter += 1

            self._accumulate_histogram(
                self.histogram_mask,
                F.l1_loss(gt_mnda[:, 0, :, :], pred_mask_raw,
                          reduction='none'), B)

            # L1 error summed over the three normal components (channel
            # axis, dim=1 -- dim=0 would be the batch axis) and divided by 6,
            # the maximum sum if the components lie in [-1, 1].
            normal_diff = F.l1_loss(gt_mnda[:, 1:4, :, :],
                                    pred_mnda[:, 1:4, :, :],
                                    reduction='none').sum(dim=1) / 6
            self._accumulate_histogram(self.histogram_normal, normal_diff, B)

            self._accumulate_histogram(
                self.histogram_depth,
                F.l1_loss(gt_mnda[:, 4, :, :],
                          pred_mnda[:, 4, :, :],
                          reduction='none'), B)

            self._accumulate_histogram(
                self.histogram_ao,
                F.l1_loss(gt_mnda[:, 5, :, :],
                          pred_mnda[:, 5, :, :],
                          reduction='none'), B)

            self._accumulate_histogram(
                self.histogram_color_withAO,
                F.l1_loss(gt_color_withAO[:, 0, :, :],
                          pred_color_withAO[:, 0, :, :],
                          reduction='none'), B)

            self._accumulate_histogram(
                self.histogram_color_noAO,
                F.l1_loss(gt_color_noAO[:, 0, :, :],
                          pred_color_noAO[:, 0, :, :],
                          reduction='none'), B)

        def close_stats_file(self):
            """Record how many stats rows were actually written."""
            self.stats_file.attrs['NumEntries'] = self.stats_index

        def write_histogram(self):
            """
            After every sample for the current dataset was processed, write
            the histograms of the errors into the histogram dataset: one row
            per bin with columns lower bound, upper bound, mask, normal,
            depth, AO, color with AO, color without AO.
            """
            lower = np.arange(NUM_BINS) / NUM_BINS
            upper = np.arange(1, NUM_BINS + 1) / NUM_BINS
            table = np.stack([
                lower, upper, self.histogram_mask, self.histogram_normal,
                self.histogram_depth, self.histogram_ao,
                self.histogram_color_withAO, self.histogram_color_noAO
            ],
                             axis=1)
            self.histo_file[...] = table

    #########################
    # DATASET
    #########################
    class FullResDataset(torch.utils.data.Dataset):
        """Map-style dataset over the ``gt`` dataset of an HDF5 file.

        The first axis of ``gt`` indexes samples (``len``), the second one
        timesteps (``num_timesteps``). Item ``i`` is the pair
        ``(gt[i, ...], np.array(i))`` so that batches carry their indices.
        """

        def __init__(self, file):
            handle = h5py.File(file, 'r')
            self.hdf5_file = handle
            self.dset = handle['gt']
            print("Dataset shape:", self.dset.shape)

        def __len__(self):
            shape = self.dset.shape
            return shape[0]

        def num_timesteps(self):
            shape = self.dset.shape
            return shape[1]

        def __getitem__(self, idx):
            sample = self.dset[idx, ...]
            index = np.array(idx)
            return sample, index

    #########################
    # COMPUTE STATS for each dataset
    #########################
    # One evaluated configuration together with its statistics writer.
    # The record type is the same for every dataset, so create it only once.
    StatsCfg = collections.namedtuple(
        "StatsCfg", "stats importance recon heatmin heatmean pattern")

    for dataset_name, dataset_file in DATASETS:
        dataset_file = os.path.join(DATASET_PREFIX, dataset_file)
        print("Compute statistics for", dataset_name)

        # init luminance importance map
        importanceBaselineLuminance.setTestFile(dataset_file)
        if importanceBaselineLuminance.isAvailable():
            print("Luminance-contrast importance map is available")

        # create output file
        os.makedirs(OUTPUT_FOLDER, exist_ok=True)
        output_file = os.path.join(OUTPUT_FOLDER, dataset_name + '.hdf5')
        print("Save to", output_file)
        with h5py.File(output_file, 'a') as output_hdf5_file:

            # load dataset (named 'dataset' to avoid shadowing builtin 'set')
            dataset = FullResDataset(dataset_file)
            data_loader = torch.utils.data.DataLoader(dataset,
                                                      batch_size=BATCH_SIZE,
                                                      shuffle=False)

            # define statistics: one entry per combination of
            # (importance net, reconstruction net), heatmap config and pattern
            statistics = []
            for (inet, rnet) in NETWORK_COMBINATIONS:
                for (heatmin, heatmean) in HEATMAP_CFG:
                    for pattern in SAMPLING_PATTERNS:
                        stats_info = {
                            'importance': inet.name(),
                            'reconstruction': rnet.name(),
                            'heatmin': heatmin,
                            'heatmean': heatmean,
                            'pattern': pattern
                        }
                        # Round before "%03d": it truncates, so e.g.
                        # 0.29 * 100 == 28.999... would give "028" and
                        # collide with the name generated for 0.28.
                        name_suffix = "%s_%s_%03d_%03d_%s" % (
                            inet.name(), rnet.name(),
                            int(round(heatmin * 100)),
                            int(round(heatmean * 100)), pattern)
                        stats_filename = "Stats_" + name_suffix
                        histo_filename = "Histogram_" + name_suffix
                        s = Statistics()
                        s.create_datasets(output_hdf5_file, stats_filename,
                                          histo_filename,
                                          len(dataset) *
                                          dataset.num_timesteps(),
                                          stats_info)
                        statistics.append(
                            StatsCfg(stats=s,
                                     importance=inet,
                                     recon=rnet,
                                     heatmin=heatmin,
                                     heatmean=heatmean,
                                     pattern=pattern))
            print(len(statistics),
                  " different combinations are performed per sample")

            # compute statistics
            try:
                with torch.no_grad():
                    num_minibatch = len(data_loader)
                    pg = ProgressBar(num_minibatch, 'Evaluation', length=50)
                    for iteration, (batch, batch_indices) in enumerate(
                            data_loader, 0):
                        pg.print_progress_bar(iteration)
                        batch = batch.to(device)
                        importanceBaselineLuminance.setIndices(batch_indices)
                        B, T, C, H, W = batch.shape
                        crop_high = batch
                        # Low-resolution input; it does not depend on the
                        # evaluated configuration, so compute it once per batch.
                        crop_low = torch.nn.functional.interpolate(
                            batch.reshape(B * T, C, H, W),
                            scale_factor=1 / UPSCALING,
                            mode='area').reshape(B, T, C, H // UPSCALING,
                                                 W // UPSCALING)

                        # try out each combination
                        for s in statistics:
                            importanceNetUpscale = s.importance.networkUpscaling(
                            )
                            importancePostUpscale = UPSCALING // importanceNetUpscale
                            pattern = sampling_pattern[s.pattern][:H, :W]
                            pattern = pattern.unsqueeze(0).unsqueeze(0)

                            # loop over timesteps
                            previous_importance = None
                            previous_output = None
                            reconstructions = []
                            # sample mask of every timestep, aligned with
                            # 'reconstructions'
                            timestep_sample_masks = []
                            for j in range(T):
                                importanceBaselineLuminance.setTime(j)
                                # extract flow (always the last two channels of crop_high)
                                flow = crop_high[:, j, C - 2:, :, :]

                                # compute importance map
                                importance_input = crop_low[:, j, :5, :, :]
                                if j == 0 or s.importance.disableTemporal:
                                    previous_input = torch.zeros(
                                        B,
                                        1,
                                        importance_input.shape[2] *
                                        importanceNetUpscale,
                                        importance_input.shape[3] *
                                        importanceNetUpscale,
                                        dtype=crop_high.dtype,
                                        device=crop_high.device)
                                else:
                                    flow_low = F.interpolate(
                                        flow,
                                        scale_factor=1 / importancePostUpscale)
                                    previous_input = models.VideoTools.warp_upscale(
                                        previous_importance, flow_low, 1,
                                        False)
                                importance_map = s.importance.call(
                                    importance_input, previous_input)
                                if len(importance_map.shape) == 3:
                                    importance_map = importance_map.unsqueeze(
                                        1)
                                previous_importance = importance_map

                                target_mean = s.heatmean
                                if USE_BINARY_SEARCH_ON_MEAN:
                                    # For regular sampling, the normalization does not work properly,
                                    # use binary search on the heatmean instead
                                    def f(x):
                                        postprocess = importance.PostProcess(
                                            s.heatmin, x,
                                            importancePostUpscale,
                                            LOSS_BORDER //
                                            importancePostUpscale, 'basic')
                                        importance_map2 = postprocess(
                                            importance_map)[0].unsqueeze(1)
                                        sampling_mask = (
                                            importance_map2 >= pattern).to(
                                                dtype=importance_map.dtype)
                                        samples = torch.mean(
                                            sampling_mask).item()
                                        return samples

                                    target_mean = binarySearch(
                                        f, s.heatmean, s.heatmean, 10, 0, 1)

                                # normalize and upscale importance map
                                postprocess = importance.PostProcess(
                                    s.heatmin, target_mean,
                                    importancePostUpscale,
                                    LOSS_BORDER // importancePostUpscale,
                                    'basic')
                                importance_map = postprocess(
                                    importance_map)[0].unsqueeze(1)

                                # create samples
                                sample_mask = (importance_map >= pattern).to(
                                    dtype=importance_map.dtype)
                                timestep_sample_masks.append(sample_mask)

                                reconstruction_input = torch.cat(
                                    (
                                        sample_mask *
                                        crop_high[:, j, 0:
                                                  5, :, :],  # mask, normal x, normal y, normal z, depth
                                        sample_mask * torch.ones(
                                            B,
                                            1,
                                            H,
                                            W,
                                            dtype=crop_high.dtype,
                                            device=crop_high.device),  # ao
                                        sample_mask),  # sample mask
                                    dim=1)

                                # warp previous output
                                if j == 0 or s.recon.disableTemporal:
                                    previous_input = torch.zeros(
                                        B,
                                        6,
                                        H,
                                        W,
                                        dtype=crop_high.dtype,
                                        device=crop_high.device)
                                else:
                                    previous_input = models.VideoTools.warp_upscale(
                                        previous_output, flow, 1, False)

                                # run reconstruction network
                                reconstruction = s.recon.call(
                                    reconstruction_input, sample_mask,
                                    previous_input)

                                # clamp
                                reconstruction_clamped = torch.cat(
                                    [
                                        torch.clamp(
                                            reconstruction[:, 0:1, :, :], -1,
                                            +1),  # mask
                                        ScreenSpaceShading.normalize(
                                            reconstruction[:, 1:4, :, :],
                                            dim=1),
                                        torch.clamp(
                                            reconstruction[:, 4:5, :, :], 0,
                                            +1),  # depth
                                        torch.clamp(reconstruction[:,
                                                                   5:6, :, :],
                                                    0, +1)  # ao
                                    ],
                                    dim=1)
                                reconstructions.append(reconstruction_clamped)

                                # save for next frame
                                previous_output = reconstruction_clamped

                            #endfor: timesteps

                            # compute statistics; all three tensors are
                            # concatenated timestep-major along dim 0, so
                            # every timestep is paired with its own mask
                            reconstructions = torch.cat(reconstructions, dim=0)
                            crops_high = torch.cat(
                                [crop_high[:, j, :6, :, :] for j in range(T)],
                                dim=0)
                            sample_masks = torch.cat(timestep_sample_masks,
                                                     dim=0)
                            s.stats.add_timestep_sample(
                                reconstructions, crops_high, sample_masks)

                        # endfor: statistic
                    # endfor: batch

                    pg.print_progress_bar(num_minibatch)
                # end no_grad()
            finally:
                # close files
                for s in statistics:
                    s.stats.write_histogram()
                    s.stats.close_stats_file()
                # release the input HDF5 file opened by FullResDataset
                dataset.hdf5_file.close()
# Example no. 6
# 0
class LossNetSparse(nn.Module):
    """
    Main Loss Module for unshaded data.

    device: cpu or cuda
    channels: number of channels of the images passed to forward():
       8 (mask, normalX, normalY, normalZ, depth, ao, flowX, flowY) if has_flow,
       6 (mask, normalX, normalY, normalZ, depth, ao) otherwise
    losses: comma-separated list of loss terms with weighting as string
       Format: <loss>:<target>[:<weighting>]
       with: loss in {mse|l2|l2_loss, l1|l1_loss, tl2|temp-l2, tl1|temp-l1,
                      bounded, bce, ssim, dssim, lpips}
             target in {mask, normal, color, ao, depth, flow}
             weighting a positive number (defaults to 1.0)
       'bounded' and 'bce' can only be applied to the mask.
    opt: command line arguments (Namespace object) with:
     - lossBorderPadding: width of the border that is zeroed before the losses
     - lossAmbient, lossDiffuse, lossSpecular, lossAO: shading parameters used
       for the 'color' target
     - further parameters depending on the losses
    has_flow: True if the images carry the two flow channels
    """
    def __init__(self, device, channels, losses, opt, has_flow):
        super().__init__()
        self.padding = opt.lossBorderPadding
        self.has_flow = has_flow
        if has_flow:
            assert channels == 8  # mask, normalX, normalY, normalZ, depth, ao, flowX, flowY
        else:
            assert channels == 6  # mask, normalX, normalY, normalZ, depth, ao
        # List of [name, target] or [name, target, weight] entries
        self.loss_list = [s.split(':') for s in losses.split(',')]

        # Build losses and weights
        builder = LossBuilder(device)
        self.loss_dict = {}
        self.weight_dict = {}

        self.loss_dict['mse'] = builder.mse()  # always use mse for psnr
        # Weight 0: the color MSE is always evaluated and reported, but does
        # not contribute to the loss unless 'mse:color' is requested.
        self.weight_dict[('mse', 'color')] = 0.0

        TARGET_INFO = {
            # target name : (#channels, expected min, expected max)
            "mask": (1, -1, +1),
            "normal": (3, -1, +1),
            "color": (3, 0, 1),
            "ao": (1, 0, 1),
            "depth": (1, 0, 1)
        }
        valid_targets = ('mask', 'normal', 'color', 'ao', 'depth', 'flow')

        self.has_discriminator = False
        self.has_style_or_content_loss = False
        # True for any temporal loss (temp-l2 or temp-l1)
        self.has_temporal_l2_loss = False
        for entry in self.loss_list:
            if len(entry) < 2:
                # 'entry' is a list of strings; join it for the message
                # (concatenating the list itself would raise a TypeError)
                raise ValueError("illegal format for loss list: " +
                                 ':'.join(entry))
            name = entry[0]
            target = entry[1]
            if target not in valid_targets:
                raise ValueError("Unknown target: " + target)
            weight = float(entry[2]) if len(entry) > 2 else 1.0

            if name in ('mse', 'l2', 'l2_loss'):
                self.weight_dict[('mse', target)] = weight
            elif name in ('l1', 'l1_loss'):
                self.loss_dict['l1'] = builder.l1_loss()
                self.weight_dict[('l1', target)] = weight
            elif name in ('tl2', 'temp-l2'):
                self.loss_dict['temp-l2'] = builder.mse()
                self.weight_dict[('temp-l2', target)] = weight
                self.has_temporal_l2_loss = True
            elif name in ('tl1', 'temp-l1'):
                self.loss_dict['temp-l1'] = builder.l1_loss()
                self.weight_dict[('temp-l1', target)] = weight
                self.has_temporal_l2_loss = True
            elif name == 'bounded':
                if target != 'mask':
                    raise ValueError(
                        "'bounded' loss can only be applied on the mask")
                self.weight_dict[("bounded", "mask")] = weight
            elif name == 'bce':
                if target != 'mask':
                    raise ValueError(
                        "'bce' loss can only be applied on the mask")
                self.weight_dict[("bce", "mask")] = weight
            elif name == 'ssim':
                self.loss_dict['ssim'] = builder.ssim_loss(
                    TARGET_INFO[target][0])
                self.weight_dict[('ssim', target)] = weight
            elif name == 'dssim':
                self.loss_dict['dssim'] = builder.dssim_loss(
                    TARGET_INFO[target][0])
                self.weight_dict[('dssim', target)] = weight
            elif name == 'lpips':
                self.loss_dict['lpips'] = builder.lpips_loss(
                    *TARGET_INFO[target])
                self.weight_dict[('lpips', target)] = weight
            else:
                raise ValueError('unknown loss %s' % name)

        self.loss_dict = nn.ModuleDict(self.loss_dict)
        print('Loss weights:', self.weight_dict)

        self.shading = ScreenSpaceShading(device)
        self.shading.fov(30)
        self.shading.ambient_light_color(
            np.array([opt.lossAmbient, opt.lossAmbient, opt.lossAmbient]))
        self.shading.diffuse_light_color(
            np.array([opt.lossDiffuse, opt.lossDiffuse, opt.lossDiffuse]))
        self.shading.specular_light_color(
            np.array([opt.lossSpecular, opt.lossSpecular, opt.lossSpecular]))
        self.shading.specular_exponent(16)
        self.shading.enable_specular = False
        self.shading.light_direction(np.array([0.0, 0.0, 1.0]))
        self.shading.material_color(np.array([1.0, 1.0, 1.0]))
        self.shading.ambient_occlusion(opt.lossAO)
        print('LossNet: ambient occlusion strength:', opt.lossAO)

    @staticmethod
    def pad(img, border):
        """
        overwrites the border of 'img' with zeros.
        The size of the border is specified by 'border'.
        The output size is not changed.

        img: tensor of shape B x C x H x W (H and W may differ)
        """
        if border == 0:
            return img
        _, _, h, w = img.shape
        # crop rows with the height and columns with the width
        img_crop = img[:, :, border:h - border, border:w - border]
        img_pad = F.pad(img_crop, (border, border, border, border), 'constant',
                        0)
        _, _, h2, w2 = img_pad.shape
        assert (h == h2)
        assert (w == w2)
        return img_pad

    #@profile
    def forward(self,
                gt,
                pred,
                prev_pred_warped,
                no_temporal_loss: bool,
                use_checkpoints=False):
        """
        gt: ground truth high resolution image (B x C x H x W)
        pred: predicted high resolution image (B x C x H x W)
        prev_pred_warped: predicted image from the previous frame warped by the flow
               Shape: B x C x H x W
               Only used for temporal losses, can be None if only the other losses are used
        no_temporal_loss: True to skip the temporal losses
        use_checkpoints: True if checkpointing should be used.
               This does not apply here since we don't have GANs or perceptual losses

        Returns (generator_loss, loss_values): the weighted sum of all loss
        terms (0.0 if none applies) and a dict mapping (loss name, target)
        to the unweighted scalar value of that term.
        """

        # TODO: loss that penalizes deviation from the input samples

        B, C, H, W = gt.shape
        if self.has_flow:
            assert C == 8
        else:
            assert C == 6
        assert gt.shape == pred.shape

        # zero border padding
        gt = LossNetSparse.pad(gt, self.padding)
        pred = LossNetSparse.pad(pred, self.padding)
        if prev_pred_warped is not None:
            prev_pred_warped = LossNetSparse.pad(prev_pred_warped,
                                                 self.padding)

        # extract mask and normal
        gt_mask = gt[:, 0:1, :, :]
        gt_mask_clamp = torch.clamp(gt_mask * 0.5 + 0.5, 0, 1)
        gt_normal = ScreenSpaceShading.normalize(gt[:, 1:4, :, :], dim=1)
        gt_depth = gt[:, 4:5, :, :]
        gt_ao = gt[:, 5:6, :, :]
        if self.has_flow:
            gt_flow = gt[:, 6:8, :, :]
        pred_mask = pred[:, 0:1, :, :]
        pred_normal = ScreenSpaceShading.normalize(pred[:, 1:4, :, :], dim=1)
        pred_depth = pred[:, 4:5, :, :]
        pred_ao = pred[:, 5:6, :, :]
        if self.has_flow:
            pred_flow = pred[:, 6:8, :, :]
        if prev_pred_warped is not None and self.has_temporal_l2_loss:
            prev_pred_mask = prev_pred_warped[:, 0:1, :, :]
            prev_pred_normal = ScreenSpaceShading.normalize(
                prev_pred_warped[:, 1:4, :, :], dim=1)
            prev_pred_depth = prev_pred_warped[:, 4:5, :, :]
            prev_pred_ao = prev_pred_warped[:, 5:6, :, :]
            if self.has_flow:
                prev_pred_flow = prev_pred_warped[:, 6:8, :, :]

        generator_loss = 0.0
        loss_values = {}

        # normal, simple losses, uses gt+pred
        for name in ['mse', 'l1', 'ssim', 'dssim', 'lpips']:
            if (name, 'mask') in self.weight_dict:
                loss = self.loss_dict[name](gt_mask, pred_mask)
                loss_values[(name, 'mask')] = loss.item()
                generator_loss += self.weight_dict[(name, 'mask')] * loss
            if (name, 'normal') in self.weight_dict:
                loss = self.loss_dict[name](gt_normal * gt_mask_clamp,
                                            pred_normal * gt_mask_clamp)
                loss_values[(name, 'normal')] = loss.item()
                generator_loss += self.weight_dict[(name, 'normal')] * loss
            if (name, 'ao') in self.weight_dict:
                loss = self.loss_dict[name](gt_ao * gt_mask_clamp,
                                            pred_ao * gt_mask_clamp)
                loss_values[(name, 'ao')] = loss.item()
                generator_loss += self.weight_dict[(name, 'ao')] * loss
            if (name, 'depth') in self.weight_dict:
                loss = self.loss_dict[name](gt_depth * gt_mask_clamp,
                                            pred_depth * gt_mask_clamp)
                loss_values[(name, 'depth')] = loss.item()
                generator_loss += self.weight_dict[(name, 'depth')] * loss
            if (name, 'flow') in self.weight_dict:
                # note: flow is not restricted to inside regions
                loss = self.loss_dict[name](gt_flow, pred_flow)
                loss_values[(name, 'flow')] = loss.item()
                generator_loss += self.weight_dict[(name, 'flow')] * loss
            if (name, 'color') in self.weight_dict:
                gt_color = self.shading(gt)
                pred_color = self.shading(pred)
                loss = self.loss_dict[name](gt_color, pred_color)
                loss_values[(name, 'color')] = loss.item()
                generator_loss += self.weight_dict[(name, 'color')] * loss

        if ("bounded", "mask") in self.weight_dict:
            # penalizes mask values with |mask| > sqrt(2), i.e. values far
            # outside the expected mask range [-1, +1]
            zero = torch.zeros(1,
                               1,
                               1,
                               1,
                               dtype=pred_mask.dtype,
                               device=pred_mask.device)
            loss = torch.mean(torch.max(zero, pred_mask * pred_mask - 2))
            loss_values[("bounded", "mask")] = loss.item()
            generator_loss += self.weight_dict[("bounded", "mask")] * loss

        if ("bce", "mask") in self.weight_dict:
            # binary cross entropy loss between the unclamped masks
            # NOTE(review): the prediction mapped to [0,1] is passed as logits
            # (a sigmoid is applied on top) -- confirm this is intended.
            loss = F.binary_cross_entropy_with_logits(pred_mask * 0.5 + 0.5,
                                                      gt_mask * 0.5 + 0.5)
            loss_values[("bce", "mask")] = loss.item()
            generator_loss += self.weight_dict[("bce", "mask")] * loss

        # temporal l2 loss, uses input (for the mask) + pred + prev_warped
        if self.has_temporal_l2_loss and not no_temporal_loss:
            assert prev_pred_warped is not None
            for name in ['temp-l2', 'temp-l1']:
                if (name, 'mask') in self.weight_dict:
                    loss = self.loss_dict[name](pred_mask, prev_pred_mask)
                    loss_values[(name, 'mask')] = loss.item()
                    generator_loss += self.weight_dict[(name, 'mask')] * loss
                if (name, 'normal') in self.weight_dict:
                    loss = self.loss_dict[name](pred_normal * gt_mask_clamp,
                                                prev_pred_normal *
                                                gt_mask_clamp)
                    loss_values[(name, 'normal')] = loss.item()
                    generator_loss += self.weight_dict[(name, 'normal')] * loss
                if (name, 'ao') in self.weight_dict:
                    loss = self.loss_dict[name](pred_ao * gt_mask_clamp,
                                                prev_pred_ao * gt_mask_clamp)
                    loss_values[(name, 'ao')] = loss.item()
                    generator_loss += self.weight_dict[(name, 'ao')] * loss
                if (name, 'depth') in self.weight_dict:
                    loss = self.loss_dict[name](pred_depth * gt_mask_clamp,
                                                prev_pred_depth *
                                                gt_mask_clamp)
                    loss_values[(name, 'depth')] = loss.item()
                    generator_loss += self.weight_dict[(name, 'depth')] * loss
                if (name, 'flow') in self.weight_dict:
                    loss = self.loss_dict[name](
                        pred_flow,  #note: no restriction to inside areas
                        prev_pred_flow)
                    loss_values[(name, 'flow')] = loss.item()
                    generator_loss += self.weight_dict[(name, 'flow')] * loss
                if (name, 'color') in self.weight_dict:
                    # pred_color was computed in the simple-loss loop above:
                    # ('mse', 'color') is always present in weight_dict
                    prev_pred_color = self.shading(prev_pred_warped)
                    loss = self.loss_dict[name](pred_color, prev_pred_color)
                    loss_values[(name, 'color')] = loss.item()
                    generator_loss += self.weight_dict[(name, 'color')] * loss

        return generator_loss, loss_values
class GUI(tk.Tk):

    MODE_LOW = 1
    MODE_HIGH = 2
    MODE_FLOW = 3

    CHANNEL_MASK = 1
    CHANNEL_NORMAL = 2
    CHANNEL_DEPTH = 3
    CHANNEL_AO = 4
    CHANNEL_COLOR = 5

    def __init__(self):
        tk.Tk.__init__(self)
        self.title('Dataset Viewer')
        self.input_name = None

        # Members that are filled once a folder / entry has been loaded
        self.pilImage = None
        self.tikImage = None
        self.current_high = None
        self.current_low = None
        self.current_flow = None
        self.dataset_folder = None
        self.last_folder = INITIAL_DIR
        self.entries = []
        self.num_frames = 0
        self.selected_time = 0

        # Root frame: image panel on the left, options column next to it
        self.root_panel = tk.Frame(self)
        self.root_panel.pack(side="bottom", fill="both", expand="yes")
        self.panel = tk.Label(self.root_panel)
        self.black = np.zeros((512, 512, 3))
        self.setImage(self.black)
        self.panel.pack(side="left", fill="both", expand="yes")

        options_panel = tk.Label(self.root_panel)

        def on_variable_write(var_name, var_index, operation):
            # Tk variable trace callback: any radio button change redraws
            self.updateImage()

        def add_slider(parent, caption, on_change):
            """Create, pack and return a horizontal integer slider."""
            slider = tk.Scale(parent,
                              from_=0,
                              to=0,
                              orient=tk.HORIZONTAL,
                              resolution=1,
                              label=caption,
                              showvalue=0,
                              command=on_change)
            slider.pack(anchor=tk.W, fill=tk.X)
            return slider

        def add_radio_group(title, variable, initial, choices):
            """Build a labelled frame of radio buttons bound to 'variable'."""
            frame = ttk.LabelFrame(options_panel, text=title, relief=tk.RIDGE)
            variable.set(initial)
            variable.trace_add('write', on_variable_write)
            for caption, value in choices:
                tk.Radiobutton(frame,
                               text=caption,
                               variable=variable,
                               value=value).pack(anchor=tk.W)
            frame.pack(fill=tk.X)

        # Input: folder selection plus entry / time sliders
        input_frame = ttk.LabelFrame(options_panel,
                                     text="Input",
                                     relief=tk.RIDGE)
        tk.Button(input_frame,
                  text="Open Folder",
                  command=lambda: self.openFolder()).pack()
        self.dataset_entry = 0
        self.dataset_entry_slider = add_slider(
            input_frame, "Entry", lambda e: self.setEntry(int(e)))
        self.dataset_time = 0
        self.dataset_time_slider = add_slider(
            input_frame, "Time", lambda e: self.setTime(int(e)))
        input_frame.pack(fill=tk.X)

        # Mode: which resolution / data is displayed
        self.mode = tk.IntVar()
        add_radio_group("Mode", self.mode, GUI.MODE_LOW, (
            ("Low", GUI.MODE_LOW),
            ("High", GUI.MODE_HIGH),
            ("Flow", GUI.MODE_FLOW),
        ))

        # Channel: which channel of the data is displayed
        self.channel_mode = tk.IntVar()
        add_radio_group("Channel", self.channel_mode, GUI.CHANNEL_COLOR, (
            ("Mask", GUI.CHANNEL_MASK),
            ("Normal", GUI.CHANNEL_NORMAL),
            ("Depth", GUI.CHANNEL_DEPTH),
            ("AO", GUI.CHANNEL_AO),
            ("Color", GUI.CHANNEL_COLOR),
        ))

        # Shading used to render the color channel
        self.ambient_light_color = np.array([0.1, 0.1, 0.1])
        self.diffuse_light_color = np.array([0.8, 0.8, 0.8])
        self.specular_light_color = np.array([0.02, 0.02, 0.02])
        self.material_color = np.array([1.0, 1.0, 1.0])
        shading = ScreenSpaceShading('cpu')
        shading.fov(30)
        shading.ambient_light_color(self.ambient_light_color)
        shading.diffuse_light_color(self.diffuse_light_color)
        shading.specular_light_color(self.specular_light_color)
        shading.specular_exponent(16)
        shading.light_direction(np.array([0.1, 0.1, 1.0]))
        shading.material_color(self.material_color)
        shading.ambient_occlusion(0.5)
        shading.background(np.array([1.0, 1.0, 1.0]))
        self.shading = shading

        # Save
        tk.Button(options_panel,
                  text="Save Image",
                  command=lambda: self.saveImage()).pack()
        self.saveFolder = "/"

        options_panel.pack(side="left")

    def openFolder(self):
        """
        Let the user pick a dataset folder and index the entries inside it.

        Entry ``i`` exists if ``low_%05d.npy`` is present in the folder;
        scanning stops at the first missing index (at most 10000 entries).
        For each entry the paths of the low, high and flow files are stored,
        then the first entry is loaded.
        """
        dataset_folder = filedialog.askdirectory(initialdir=self.last_folder)
        print(dataset_folder)
        # On cancel, askdirectory returns an empty value (an empty string),
        # not None, so test for truthiness.
        if not dataset_folder:
            print("No folder selected")
            return

        self.dataset_folder = str(dataset_folder)
        print("New folder selected:", self.dataset_folder)
        self.last_folder = self.dataset_folder
        self.current_low = None
        self.current_high = None
        self.current_flow = None

        # find number of entries: consecutive indices starting at 0
        entries = []
        for i in range(10000):
            file_low = os.path.join(self.dataset_folder, "low_%05d.npy" % i)
            if not os.path.isfile(file_low):
                break
            entries.append(
                (file_low,
                 os.path.join(self.dataset_folder, "high_%05d.npy" % i),
                 os.path.join(self.dataset_folder, "flow_%05d.npy" % i)))
        print("Number of entries found:", len(entries))
        self.entries = entries
        self.dataset_entry = 0
        # the slider selects an index, so its maximum is the last valid index
        self.dataset_entry_slider.config(to=max(len(entries) - 1, 0))
        self.setEntry(0)

    def setEntry(self, entry):
        """Load dataset entry ``entry``, clamped to the valid index range.

        Loads the ``(low, high, flow)`` arrays of ``self.entries[entry]``
        with ``np.load``. It then resizes the time slider to the number of
        frames (``shape[0]`` of the low-res array) and jumps to frame 0.
        When there are no entries, the loaded arrays are cleared and the
        view is redrawn (showing the placeholder).
        """
        # clamp into [0, len - 1]; with no entries this yields 0 instead of -1
        entry = max(0, min(int(entry), len(self.entries) - 1))
        self.dataset_entry_slider.config(label='Entry: %d' % entry)
        self.dataset_entry = entry
        if len(self.entries) == 0:
            self.current_low = None
            self.current_high = None
            self.current_flow = None
            # otherwise the image of the previous folder would stay visible
            self.updateImage()
            return
        file_low, file_high, file_flow = self.entries[entry]
        self.current_low = np.load(file_low)
        self.current_high = np.load(file_high)
        self.current_flow = np.load(file_flow)
        self.num_frames = self.current_low.shape[0]
        # the slider range is inclusive: the last valid frame is num_frames - 1
        self.dataset_time_slider.config(to=max(self.num_frames - 1, 0))
        print("Entry loaded")
        self.setTime(0)

    def setTime(self, entry):
        """Select frame ``entry`` (never beyond the last frame) and redraw."""
        last_frame = self.num_frames - 1
        frame = last_frame if last_frame < entry else entry
        self.dataset_time_slider.config(label='Time: %d' % int(frame))
        self.selected_time = frame
        self.updateImage()

    def setImage(self, img):
        """Show ``img`` in the preview panel.

        The first two axes of ``img`` are swapped for display. Values are
        scaled by 255 and clipped to 0..255, i.e. treated as intensities in
        [0, 1].
        """
        scaled = (img * 255).transpose((1, 0, 2))
        as_bytes = np.clip(scaled, 0, 255).astype(np.uint8)
        self.pilImage = pl.Image.fromarray(as_bytes)
        # keep the PhotoImage referenced on self so it stays alive while shown
        self.tikImage = ImageTk.PhotoImage(self.pilImage)
        self.panel.configure(image=self.tikImage)

    def saveImage(self):
        """Ask for a target path and write the currently shown image there.

        Cancelling the dialog does nothing. A path without an extension gets
        ``.jpg`` appended. The chosen directory becomes the starting
        directory of the next save dialog.
        """
        target = filedialog.asksaveasfilename(
            initialdir=self.saveFolder,
            title="Save as",
            filetypes=(("jpeg files", "*.jpg"), ("png files", "*.png"),
                       ("all files", "*.*")))
        if not len(target):
            return
        _, extension = os.path.splitext(target)
        if not extension:
            target += ".jpg"
        self.pilImage.save(target)
        self.saveFolder = os.path.dirname(target)

    def updateImage(self):
        """Redraw the preview for the current entry, time step, mode and channel.

        Low-res and flow views are enlarged by ``UPSCALE`` with
        nearest-neighbour interpolation. The high-res view is shown at its
        native size. While nothing is loaded, the black placeholder is shown.
        """
        if self.current_low is None:
            # no image loaded
            self.setImage(self.black)
            return

        def to_rgb(frame):
            # Turn one channel-first frame into a 3-channel display image.
            # Channel indices used: 0 mask, 1-3 normal, 4 depth, 5 AO (only
            # read when the frame has exactly 6 channels).
            # Returns None for an unrecognised channel mode.
            wanted = self.channel_mode.get()
            if wanted == GUI.CHANNEL_MASK:
                mask = frame[0:1, :, :] * 0.5 + 0.5
                return np.concatenate((mask, mask, mask))
            if wanted == GUI.CHANNEL_NORMAL:
                return frame[1:4, :, :] * 0.5 + 0.5
            if wanted == GUI.CHANNEL_DEPTH:
                depth = frame[4:5, :, :]
                return np.concatenate((depth, depth, depth))
            if wanted == GUI.CHANNEL_AO:
                if frame.shape[0] != 6:
                    # no AO channel stored: show black
                    return np.zeros((3, frame.shape[1], frame.shape[2]),
                                    dtype=np.float32)
                ao = frame[5:6, :, :]
                return np.concatenate((ao, ao, ao))
            if wanted == GUI.CHANNEL_COLOR:
                batch = torch.unsqueeze(torch.from_numpy(frame), 0)
                shaded = self.shading(batch)[0]
                return torch.clamp(shaded, 0, 1).cpu().numpy()
            return None

        def enlarge(picture):
            # nearest-neighbour keeps every low-res pixel a crisp block
            return cv.resize(picture,
                             dsize=None,
                             fx=UPSCALE,
                             fy=UPSCALE,
                             interpolation=cv.INTER_NEAREST)

        view = self.mode.get()
        t = self.selected_time
        if view == GUI.MODE_LOW:
            rgb = to_rgb(self.current_low[t, :, :, :])
            self.setImage(enlarge(rgb.transpose((2, 1, 0))))
        elif view == GUI.MODE_HIGH:
            rgb = to_rgb(self.current_high[t, :, :, :])
            self.setImage(rgb.transpose((2, 1, 0)))
        elif view == GUI.MODE_FLOW:
            # The two flow components go into the first two colour channels
            # (third channel zero). Everything is scaled by 10 and offset by
            # 0.5, so zero flow maps to 0.5 in every channel.
            flow = self.current_flow[t, 0:2, :, :]
            blank = np.zeros(
                (1, self.current_flow.shape[2], self.current_flow.shape[3]),
                dtype=np.float32)
            rgb = np.concatenate((flow, blank)) * 10 + 0.5
            self.setImage(enlarge(rgb.transpose((2, 1, 0))))
Exemplo n.º 8
0
class GUI(tk.Tk):

    CHANNEL_MASK = 1
    CHANNEL_NORMAL = 2
    CHANNEL_DEPTH = 3
    CHANNEL_AO = 4
    CHANNEL_COLOR = 5
    CHANNEL_FLOW = 6

    def __init__(self):
        """Build the dataset-viewer window and its controls.

        The window holds the image preview (``self.panel``, packed first on
        the left) and a control column packed to its right with:

        * "Input": an "Open HDF5" button, a listbox of the file's datasets
          (double-click selects one), and "Entry" / "Time" sliders whose
          range stays 0..0 until ``setDset`` resizes them;
        * "Channel": radio buttons choosing what ``updateImage`` shows
          (default ``CHANNEL_COLOR``); changing the choice redraws;
        * a "Save Image" button.

        It also configures the ``ScreenSpaceShading`` instance that
        ``updateImage`` uses for the Color channel.
        """
        tk.Tk.__init__(self)
        self.title('Dataset Viewer')
        self.input_name = None

        # members to be filled later
        self.pilImage = None  # PIL image currently shown (saved by saveImage)
        self.tikImage = None  # Tk PhotoImage displayed in self.panel
        self.num_frames = 0
        self.selected_entry = 0
        self.selected_time = 0
        self.last_folder = None  # initial directory of the open dialog
        self.dataset_file = None
        self.hdf5_file = None
        self.dset_keys = []
        self.dset = None  # selected HDF5 dataset; None means nothing loaded
        self.mode = None  # "Mode" attribute of the selected dataset
        # NOTE(review): self.entries is first assigned in setDset but read by
        # setEntry; the entry slider's initial 0..0 range presumably keeps
        # setEntry from running before a dataset is selected -- confirm.

        # root
        self.root_panel = tk.Frame(self)
        self.root_panel.pack(side="bottom", fill="both", expand="yes")
        self.panel = tk.Label(self.root_panel)
        # black placeholder, shown whenever no dataset is selected
        self.black = np.zeros((512, 512, 3))
        self.setImage(self.black)
        self.panel.pack(side="left", fill="both", expand="yes")

        # container for all controls (a Label used as a plain frame)
        options1 = tk.Label(self.root_panel)

        # Input
        inputFrame = ttk.LabelFrame(options1, text="Input", relief=tk.RIDGE)
        tk.Button(inputFrame,
                  text="Open HDF5",
                  command=lambda: self.openHDF5()).pack()
        listbox_frame = tk.Frame(inputFrame)
        self.dset_listbox_scrollbar = tk.Scrollbar(listbox_frame,
                                                   orient=tk.VERTICAL)
        self.dset_listbox = tk.Listbox(
            listbox_frame,
            selectmode=tk.SINGLE,
            yscrollcommand=self.dset_listbox_scrollbar.set)
        self.dset_listbox_scrollbar.config(command=self.dset_listbox.yview)
        self.dset_listbox_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
        self.dset_listbox.pack(side=tk.LEFT, anchor=tk.W, fill=tk.X, expand=1)
        # double-clicking a dataset name makes it the current dataset
        self.dset_listbox.bind("<Double-Button-1>", self.setDsetCallback)
        listbox_frame.pack(anchor=tk.W, fill=tk.X)
        self.selected_dset = None
        # Entry/time sliders: setDset sets their ranges; moving them calls
        # setEntry / setTime with the integer slider value.
        self.dataset_entry_slider = tk.Scale(
            inputFrame,
            from_=0,
            to=0,
            orient=tk.HORIZONTAL,
            resolution=1,
            label="Entry",
            showvalue=0,
            command=lambda e: self.setEntry(int(e)))
        self.dataset_entry_slider.pack(anchor=tk.W, fill=tk.X)
        # not read elsewhere in this class; selected_time holds the step
        self.dataset_time = 0
        self.dataset_time_slider = tk.Scale(
            inputFrame,
            from_=0,
            to=0,
            orient=tk.HORIZONTAL,
            resolution=1,
            label="Time",
            showvalue=0,
            command=lambda e: self.setTime(int(e)))
        self.dataset_time_slider.pack(anchor=tk.W, fill=tk.X)
        inputFrame.pack(fill=tk.X)

        # Channel
        channelsFrame = ttk.LabelFrame(options1,
                                       text="Channel",
                                       relief=tk.RIDGE)
        self.channel_mode = tk.IntVar()
        # set before the trace is installed, so this does not trigger a redraw
        self.channel_mode.set(GUI.CHANNEL_COLOR)
        # every later change of the selected channel redraws the preview
        self.channel_mode.trace_add('write',
                                    lambda a, b, c: self.updateImage())
        tk.Radiobutton(channelsFrame,
                       text="Mask",
                       variable=self.channel_mode,
                       value=GUI.CHANNEL_MASK).pack(anchor=tk.W)
        tk.Radiobutton(channelsFrame,
                       text="Normal",
                       variable=self.channel_mode,
                       value=GUI.CHANNEL_NORMAL).pack(anchor=tk.W)
        tk.Radiobutton(channelsFrame,
                       text="Depth",
                       variable=self.channel_mode,
                       value=GUI.CHANNEL_DEPTH).pack(anchor=tk.W)
        tk.Radiobutton(channelsFrame,
                       text="AO",
                       variable=self.channel_mode,
                       value=GUI.CHANNEL_AO).pack(anchor=tk.W)
        tk.Radiobutton(channelsFrame,
                       text="Color",
                       variable=self.channel_mode,
                       value=GUI.CHANNEL_COLOR).pack(anchor=tk.W)
        tk.Radiobutton(channelsFrame,
                       text="Flow",
                       variable=self.channel_mode,
                       value=GUI.CHANNEL_FLOW).pack(anchor=tk.W)
        channelsFrame.pack(fill=tk.X)

        # Shading: used by updateImage for the Color channel. The values are
        # passed straight to ScreenSpaceShading, which defines their meaning.
        self.shading = ScreenSpaceShading('cpu')
        self.shading.fov(30)
        self.ambient_light_color = np.array([0.1, 0.1, 0.1])
        self.shading.ambient_light_color(self.ambient_light_color)
        self.diffuse_light_color = np.array([0.8, 0.8, 0.8])
        self.shading.diffuse_light_color(self.diffuse_light_color)
        self.specular_light_color = np.array([0.02, 0.02, 0.02])
        self.shading.specular_light_color(self.specular_light_color)
        self.shading.specular_exponent(16)
        self.shading.light_direction(np.array([0.1, 0.1, 1.0]))
        self.material_color = np.array([1.0, 1.0, 1.0])  # alternative: [1.0, 0.3, 0.0]
        self.shading.material_color(self.material_color)
        self.shading.ambient_occlusion(0.5)
        self.shading.background(np.array([1.0, 1.0, 1.0]))

        # Save: saveImage writes self.pilImage; its dialog starts in saveFolder
        tk.Button(options1,
                  text="Save Image",
                  command=lambda: self.saveImage()).pack()
        self.saveFolder = "/"

        options1.pack(side="left")

    def openHDF5(self):
        """Let the user pick an HDF5 file and list its datasets.

        The file is opened read-only with ``h5py.File`` and kept open in
        ``self.hdf5_file``, because ``setDset`` and ``updateImage`` read the
        selected dataset from it. A previously opened file is closed first.
        Every top-level key is shown in the listbox. If there is at least
        one, the first is selected and displayed; otherwise the placeholder
        is shown.
        """
        dataset_file = filedialog.askopenfilename(initialdir=self.last_folder,
                                                  title="Select HDF5 file",
                                                  filetypes=(("HDF5 files",
                                                              "*.hdf5"), ))
        print(dataset_file)
        # askopenfilename() reports "cancelled" with an empty string (an
        # empty tuple on some platforms), never with None.
        if not dataset_file:
            print("No file selected")
            return
        dataset_file = str(dataset_file)
        # open the new file first so a failure leaves the current state intact
        hdf5_file = h5py.File(dataset_file, "r")
        # release the handle of the previously opened file
        if self.hdf5_file is not None:
            self.hdf5_file.close()
        self.hdf5_file = hdf5_file
        # datasets of the closed file must not be used any more
        self.dset = None
        self.selected_dset = None

        self.dataset_file = dataset_file
        print("New hdf5 selected:", self.dataset_file)
        self.title(self.dataset_file)
        self.last_folder = os.path.dirname(self.dataset_file)
        # list datasets
        self.dset_listbox.delete(0, tk.END)
        self.dset_keys = list(self.hdf5_file.keys())
        for key in self.dset_keys:
            self.dset_listbox.insert(tk.END, key)
        if not self.dset_keys:
            print("No datasets found in", self.dataset_file)
            self.updateImage()
            return
        self.dset_listbox.selection_set(first=0)
        self.setDset(self.dset_keys[0])

    def setDsetCallback(self, *args):
        """Listbox double-click handler: show the dataset that is selected.

        ``args`` (the Tk event) is ignored. Does nothing when the listbox
        has no selection, e.g. after a double-click on an empty listbox.
        """
        selection = self.dset_listbox.curselection()
        if not selection:
            return
        self.setDset(self.dset_keys[int(selection[0])])

    def setDset(self, name):
        """Make dataset ``name`` of the open HDF5 file the current one.

        Reads the dataset's "Mode" attribute (default "IsoUnshaded") and
        resizes the entry/time sliders to the dataset's first two axes.
        The current entry and time step are kept when they are still in
        range and reset to 0 otherwise; then the view is redrawn via
        ``setEntry``.
        """
        self.selected_dset = name
        print("Select dataset '%s'" % (self.selected_dset))
        self.dset = self.hdf5_file[self.selected_dset]
        self.mode = self.dset.attrs.get("Mode", "IsoUnshaded")
        print("Mode:", self.mode)
        # find number of entries; updateImage indexes the dataset as
        # [entry, time, channel, :, :]
        entries = self.dset.shape[0]
        num_frames = self.dset.shape[1]
        print("Number of entries found:", entries, "with", num_frames,
              "timesteps")
        print("Image size:", self.dset.shape[3], "x", self.dset.shape[4])
        self.entries = entries
        self.num_frames = num_frames
        self.dataset_entry_slider.config(to=entries - 1)
        self.dataset_time_slider.config(to=num_frames - 1)
        # The previous dataset may have had more time steps. Reset an
        # out-of-range step so the redraw in setEntry stays inside this one.
        if self.selected_time >= num_frames:
            self.selected_time = 0
            self.dataset_time_slider.config(label='Time: %d' % 0)
        self.setEntry(
            self.selected_entry if self.selected_entry < entries else 0)

    def setEntry(self, entry):
        """Select entry ``entry`` of the current dataset (capped at the last
        one) and redraw."""
        chosen = entry
        if self.entries - 1 < chosen:
            chosen = self.entries - 1
        self.dataset_entry_slider.config(label='Entry: %d' % int(chosen))
        self.selected_entry = chosen
        self.updateImage()

    def setTime(self, entry):
        """Jump to time step ``entry``, never past the final one, then redraw."""
        newest = self.num_frames - 1
        if newest < entry:
            entry = newest
        self.dataset_time_slider.config(label='Time: %d' % int(entry))
        self.selected_time = entry
        self.updateImage()

    def setImage(self, img):
        """Display ``img`` in the preview panel.

        The first two axes of ``img`` are swapped for display. ``uint8``
        arrays are used unchanged. Any other dtype is scaled by 255 and
        clipped to 0..255, i.e. treated as intensities in [0, 1]. Images
        whose PIL width is at most 256 pixels are shown at double size with
        nearest-neighbour resampling.
        """
        swapped = img.transpose((1, 0, 2))
        if swapped.dtype != np.uint8:
            swapped = np.clip(swapped * 255, 0, 255).astype(np.uint8)
        self.pilImage = pl.Image.fromarray(swapped)
        width, height = self.pilImage.size
        if width <= 256:
            # small images are enlarged 2x so they stay legible
            self.pilImage = self.pilImage.resize((width * 2, height * 2),
                                                 pl.Image.NEAREST)
        # keep the PhotoImage referenced on self so it stays alive while shown
        self.tikImage = ImageTk.PhotoImage(self.pilImage)
        self.panel.configure(image=self.tikImage)

    def saveImage(self):
        """Prompt for a file name and save the image currently on display.

        Cancelling the dialog is a no-op. Names without an extension get
        ``.png`` appended. The folder of the chosen file is remembered as
        the starting point of the next dialog.
        """
        chosen = filedialog.asksaveasfilename(
            initialdir=self.saveFolder,
            title="Save as",
            filetypes=(("png files", "*.png"), ("jpeg files", "*.jpg"),
                       ("all files", "*.*")))
        if len(chosen) != 0:
            if os.path.splitext(chosen)[1] == "":
                chosen = chosen + ".png"
            self.pilImage.save(chosen)
            self.saveFolder = os.path.dirname(chosen)

    def updateImage(self):
        if self.dset is None:
            # no image loaded
            self.setImage(self.black)
            return

        def selectChannel(img):
            if self.mode == "DVR":
                mask = img[3:4, :, :]
                if self.channel_mode.get() == GUI.CHANNEL_MASK:
                    return np.concatenate((mask, mask, mask))
                elif self.channel_mode.get() == GUI.CHANNEL_COLOR:
                    return img[0:3, :, :]
                elif self.channel_mode.get() == GUI.CHANNEL_NORMAL:
                    return (img[4:7, :, :] * 0.5 + 0.5) * mask
                elif self.channel_mode.get() == GUI.CHANNEL_DEPTH:
                    return np.concatenate([img[7:8, :, :]] * 3, axis=0)
                elif self.channel_mode.get() == GUI.CHANNEL_FLOW:
                    return np.concatenate(
                        (img[8:9, :, :] * 10, img[9:10, :, :] * 10,
                         np.zeros_like(img[6:7, :, :]))) + 0.5
                else:
                    return self.black

            else:  # IsoUnshaded
                if self.channel_mode.get() == GUI.CHANNEL_MASK:
                    if img.dtype == np.uint8:
                        mask = img[0:1, :, :]
                    else:
                        mask = img[0:1, :, :] * 0.5 + 0.5
                    return np.concatenate((mask, mask, mask))
                elif self.channel_mode.get() == GUI.CHANNEL_NORMAL:
                    if img.dtype == np.uint8:
                        return img[1:4, :, :]
                    else:
                        return img[1:4, :, :] * 0.5 + 0.5
                elif self.channel_mode.get() == GUI.CHANNEL_DEPTH:
                    return np.concatenate(
                        (img[4:5, :, :], img[4:5, :, :], img[4:5, :, :]))
                elif self.channel_mode.get() == GUI.CHANNEL_AO:
                    return np.concatenate(
                        (img[5:6, :, :], img[5:6, :, :], img[5:6, :, :]))
                elif self.channel_mode.get() == GUI.CHANNEL_FLOW:
                    return np.concatenate(
                        (img[6:7, :, :] * 10, img[7:8, :, :] * 10,
                         np.zeros_like(img[6:7, :, :]))) + 0.5
                elif self.channel_mode.get() == GUI.CHANNEL_COLOR:
                    if img.dtype == np.uint8:
                        shading_input = torch.unsqueeze(
                            torch.from_numpy(img.astype(np.float32) / 255.0),
                            0)
                        shading_input[:, 1:
                                      4, :, :] = shading_input[:, 1:
                                                               4, :, :] * 2 - 1
                    else:
                        shading_input = torch.unsqueeze(
                            torch.from_numpy(img), 0)
                    shading_output = self.shading(shading_input)[0]
                    return torch.clamp(shading_output, 0, 1).cpu().numpy()

        img = self.dset[self.selected_entry, self.selected_time, :, :, :]
        img = selectChannel(img)
        img = img.transpose((2, 1, 0))
        self.setImage(img)