コード例 #1
0
    def forward(self, sample_low, previous_warped_flattened):
        """
        Upsampling baseline: interpolates the low-resolution input to the
        high resolution and shades the result.

        sample_low: low-resolution sample; only its first self.input_channels
            channels are used.
        previous_warped_flattened: unused here, kept so the signature matches
            the network-based forward.

        Returns (color, prediction). prediction has self.output_channels
        channels: the interpolated input channels followed by channels filled
        with ones, post-processed as mask clamped to [-1, 1] (channel 0),
        normal normalized (channels 1:4), depth+AO clamped to [0, 1]
        (channels 4:6); any further channels are passed through unchanged.
        color = self.shading(prediction).
        """
        inputs = sample_low[:, 0:self.input_channels, :, :]
        resized_inputs = F.interpolate(
            inputs,
            size=[
                inputs.shape[2] * self.upscale_factor,
                inputs.shape[3] * self.upscale_factor
            ],
            mode=self.upsample)

        # channels the input cannot provide are filled with ones
        B, _, H, W = resized_inputs.shape
        fill = resized_inputs.new_ones(
            (B, self.output_channels - self.input_channels, H, W))
        raw = torch.cat([resized_inputs, fill], dim=1)

        # Post-process out-of-place: writing into slices of a tensor whose
        # slices were just read by clamp/normalize would invalidate the values
        # autograd saved for backward.
        prediction = torch.cat([
            torch.clamp(raw[:, 0:1, :, :], -1, +1),  # mask
            ScreenSpaceShading.normalize(raw[:, 1:4, :, :], dim=1),  # normal
            torch.clamp(raw[:, 4:6, :, :], 0, +1),  # depth+ao
            raw[:, 6:, :, :],  # remaining channels, unchanged
        ], dim=1)

        color = self.shading(prediction)
        return color, prediction
コード例 #2
0
 def colorize_and_pad(tensor):
     """
     Converts a 6-channel unshaded image (channel 0 mask, 1:4 normal,
     4 depth, 5 AO) into [mask, normalized normal, shaded color, AO]
     and applies the zero-border padding via LossNetUnshaded.pad.

     The channel order [mask, normal, color, AO] is the same as the
     `pred_with_color` / `prev_pred_with_color` tensors assembled for the
     discriminator in the loss network's forward pass.
     NOTE(review): assumes this output feeds that same discriminator -- confirm.

     `self` is captured from the enclosing method.
     """
     assert tensor.shape[1] == 6
     mask = tensor[:, 0:1, :, :]
     normal = ScreenSpaceShading.normalize(tensor[:, 1:4, :, :], dim=1)
     depth_ao = tensor[:, 4:6, :, :]
     ao = tensor[:, 5:6, :, :]
     color = self.shading(torch.cat([mask, normal, depth_ao], dim=1))
     # color before AO, matching the generator-side discriminator input
     tensor_with_color = torch.cat([mask, normal, color, ao], dim=1)
     return LossNetUnshaded.pad(tensor_with_color, self.padding)
コード例 #3
0
    def forward(self, sample_low, previous_warped_flattened):
        """
        Runs the wrapped network on the current low-resolution sample and the
        flattened, warped previous output, then post-processes and shades it.

        sample_low: current low-resolution input.
        previous_warped_flattened: previous high-resolution output warped to
            the current frame and flattened to the low resolution; concatenated
            with sample_low along the channel dimension.

        Returns (color, prediction): prediction is the network output with the
        mask clamped to [-1, 1] (channel 0), the normal normalized (1:4) and
        depth+AO clamped to [0, 1] (4:6); further channels pass through.
        color = self.shading(prediction).
        """
        single_input = torch.cat((sample_low, previous_warped_flattened),
                                 dim=1)

        raw, _ = self.mdl(single_input)

        # Post-process out-of-place instead of writing into the network
        # output: in-place slice writes would invalidate tensors saved by
        # clamp/normalize for backward and mutate the model's returned tensor.
        prediction = torch.cat([
            torch.clamp(raw[:, 0:1, :, :], -1, +1),  # mask
            ScreenSpaceShading.normalize(raw[:, 1:4, :, :], dim=1),  # normal
            torch.clamp(raw[:, 4:6, :, :], 0, +1),  # depth+ao
            raw[:, 6:, :, :],  # remaining channels, unchanged
        ], dim=1)

        color = self.shading(prediction)
        return color, prediction
コード例 #4
0
                pass  # nothing to do
            else:  # network
                if not m['temporal']: previous_image = None
                imageInput = np.copy(image)
                #image = cv.resize(image.transpose((2, 1, 0)),
                #                    dsize=None,
                #                    fx=UPSCALING,
                #                    fy=UPSCALING,
                #                    interpolation=cv.INTER_LINEAR).transpose((2, 1, 0))
                if models[k].unshaded:
                    # unshaded input
                    imageRaw = models[k].inference(imageInput, previous_image)

                    imageRaw = torch.cat([
                        torch.clamp(imageRaw[:, 0:1, :, :], -1, +1),
                        ScreenSpaceShading.normalize(imageRaw[:, 1:4, :, :],
                                                     dim=1),
                        torch.clamp(imageRaw[:, 4:, :, :], 0, 1)
                    ],
                                         dim=1)
                    previous_image = imageRaw

                    image = image.transpose((2, 1, 0))
                    image = cv.resize(image,
                                      dsize=None,
                                      fx=UPSCALING,
                                      fy=UPSCALING,
                                      interpolation=cv.INTER_LINEAR)
                    image = image.transpose((2, 1, 0))
                    base_mask = np.copy(image[3, :, :])
                    imageRawCpu = imageRaw.cpu()
                    image[0:3, :, :] = shading(imageRawCpu)[0].numpy()
コード例 #5
0
    def add_timestep_sample(self, pred_mnda, gt_mnda, input_mnda):
        """
        Adds a timestep sample to the accumulated statistics.

        pred_mnda: high-resolution prediction: mask, normal, depth, AO
        gt_mnda: high-resolution ground truth: mask, normal, depth, AO
        input_mnda: low-resolution input; its normal channels (1:4) and its
                    shaded color are compared against the downsampled
                    prediction (channel layout presumably the same -- confirm)

        Colors are computed here with the global `shading`, once with AO
        strength AMBIENT_OCCLUSION_STRENGTH and once with AO disabled.
        NOTE: the global shading is left with AO strength 0.0 afterwards.

        A border of BORDER low-resolution pixels (BORDER * UPSCALING
        high-resolution pixels) is cropped before evaluation. Samples whose
        ground-truth fill rate is below MIN_FILLING are ignored and not
        counted in self.n.
        """

        def crop(t, border):
            # Explicit end indices: 't[..., b:-b]' would be empty for b == 0.
            H, W = t.shape[2], t.shape[3]
            return t[:, :, border:H - border, border:W - border]

        def histogram_increment(diff, current):
            """Running-mean update of 'current' towards the normalized histogram of 'diff' over [0, 1]."""
            histogram, _ = np.histogram(diff.cpu().numpy(),
                                        bins=NUM_BINS,
                                        range=(0, 1),
                                        density=True)
            # with range (0, 1) the bin width is 1/NUM_BINS, so
            # density / NUM_BINS is the fraction of values per bin
            return (histogram / NUM_BINS - current) / self.histogram_counter

        # shading
        shading.ambient_occlusion(AMBIENT_OCCLUSION_STRENGTH)
        pred_color_withAO = shading(pred_mnda)
        gt_color_withAO = shading(gt_mnda)
        shading.ambient_occlusion(0.0)
        pred_color_noAO = shading(pred_mnda)
        gt_color_noAO = shading(gt_mnda)
        input_color_noAO = shading(input_mnda)

        # apply border
        BORDER2 = BORDER * UPSCALING
        pred_mnda = crop(pred_mnda, BORDER2)
        pred_color_withAO = crop(pred_color_withAO, BORDER2)
        pred_color_noAO = crop(pred_color_noAO, BORDER2)
        gt_mnda = crop(gt_mnda, BORDER2)
        gt_color_withAO = crop(gt_color_withAO, BORDER2)
        gt_color_noAO = crop(gt_color_noAO, BORDER2)
        input_mnda = crop(input_mnda, BORDER)
        input_color_noAO = crop(input_color_noAO, BORDER)

        # check fill rate (mask mapped from [-1, 1] to [0, 1])
        mask = gt_mnda[:, 0:1, :, :] * 0.5 + 0.5
        B, C, H, W = mask.shape
        if H * W == 0:
            return  # the border removed the whole image, nothing to evaluate
        factor = torch.sum(mask).item() / (H * W)
        if factor < MIN_FILLING:
            return  # too few filled points

        self.n += 1

        # PSNR
        self.psnr_normal += psnrLoss(pred_mnda[:, 1:4, :, :],
                                     gt_mnda[:, 1:4, :, :],
                                     mask=mask).item()
        self.psnr_depth += psnrLoss(pred_mnda[:, 4:5, :, :],
                                    gt_mnda[:, 4:5, :, :],
                                    mask=mask).item()
        self.psnr_ao += psnrLoss(pred_mnda[:, 5:6, :, :],
                                 gt_mnda[:, 5:6, :, :],
                                 mask=mask).item()
        self.psnr_color_withAO += psnrLoss(pred_color_withAO,
                                           gt_color_withAO,
                                           mask=mask).item()
        self.psnr_color_noAO += psnrLoss(pred_color_noAO,
                                         gt_color_noAO,
                                         mask=mask).item()

        # SSIM: outside the mask the prediction is replaced by the ground truth.
        # NOTE(review): this blended prediction is also used by the downsample
        # loss and the histograms below -- confirm that is intended.
        pred_mnda = gt_mnda + mask * (pred_mnda - gt_mnda)
        self.ssim_normal += ssimLoss(pred_mnda[:, 1:4, :, :],
                                     gt_mnda[:, 1:4, :, :]).item()
        self.ssim_depth += ssimLoss(pred_mnda[:, 4:5, :, :],
                                    gt_mnda[:, 4:5, :, :]).item()
        self.ssim_ao += ssimLoss(pred_mnda[:, 5:6, :, :],
                                 gt_mnda[:, 5:6, :, :]).item()
        self.ssim_color_withAO += ssimLoss(pred_color_withAO,
                                           gt_color_withAO).item()
        self.ssim_color_noAO += ssimLoss(pred_color_noAO, gt_color_noAO).item()

        # Downsample Loss
        ds_normal = self.downsample_loss(
            input_mnda[:, 1:4, :, :],
            ScreenSpaceShading.normalize(
                self.downsample(pred_mnda[:, 1:4, :, :]), dim=1))
        ds_color = self.downsample_loss(input_color_noAO,
                                        self.downsample(pred_color_noAO))
        self.l2ds_normal_mean += torch.mean(ds_normal).item()
        self.l2ds_normal_max = max(self.l2ds_normal_max,
                                   torch.max(ds_normal).item())
        self.l2ds_colorNoAO_mean += torch.mean(ds_color).item()
        self.l2ds_colorNoAO_max = max(self.l2ds_colorNoAO_max,
                                      torch.max(ds_color).item())

        # Histograms (first sample of the batch only), kept as running means
        self.histogram_counter += 1

        mask_diff = F.l1_loss(gt_mnda[0, 0, :, :],
                              pred_mnda[0, 0, :, :],
                              reduction='none')
        self.histogram_mask += histogram_increment(mask_diff,
                                                   self.histogram_mask)

        # per-pixel L1 summed over the 3 normal components, divided by 6 to
        # keep it within the histogram range [0, 1]
        normal_diff = F.l1_loss(gt_mnda[0, 1:4, :, :],
                                pred_mnda[0, 1:4, :, :],
                                reduction='none').sum(dim=0) / 6
        self.histogram_normal += histogram_increment(normal_diff,
                                                     self.histogram_normal)

        depth_diff = F.l1_loss(gt_mnda[0, 4, :, :],
                               pred_mnda[0, 4, :, :],
                               reduction='none')
        self.histogram_depth += histogram_increment(depth_diff,
                                                    self.histogram_depth)

        ao_diff = F.l1_loss(gt_mnda[0, 5, :, :],
                            pred_mnda[0, 5, :, :],
                            reduction='none')
        self.histogram_ao += histogram_increment(ao_diff, self.histogram_ao)

        # color histograms use only the first color channel
        color_diff = F.l1_loss(gt_color_withAO[0, 0, :, :],
                               pred_color_withAO[0, 0, :, :],
                               reduction='none')
        self.histogram_color_withAO += histogram_increment(
            color_diff, self.histogram_color_withAO)

        color_diff = F.l1_loss(gt_color_noAO[0, 0, :, :],
                               pred_color_noAO[0, 0, :, :],
                               reduction='none')
        self.histogram_color_noAO += histogram_increment(
            color_diff, self.histogram_color_noAO)
コード例 #6
0
                                sample_flow[j - 1:j, :, :, :],
                                UPSCALING,
                                special_mask=True)
                            #previous_warped = previous_output
                        previous_warped_flattened = models.VideoTools.flatten_high(
                            previous_warped, UPSCALING)
                        single_input = torch.cat((sample_low[j:j + 1, :, :, :],
                                                  previous_warped_flattened),
                                                 dim=1)
                        # run model
                        prediction, _ = model_list[model_index](single_input)
                        prediction[:, 0:1, :, :] = torch.clamp(
                            prediction[:, 0:1, :, :], -1, +1)  #mask
                        prediction[:,
                                   1:4, :, :] = ScreenSpaceShading.normalize(
                                       prediction[:,
                                                  1:4, :, :], dim=1)  # normal
                        prediction[:, 4:6, :, :] = torch.clamp(
                            prediction[:, 4:6, :, :], 0, +1)  # depth+ao

                        ##DEBUG: save the images
                        #img_rgb = (prediction[:,1:4,:,:] * 0.5 + 0.5)[0].cpu().numpy()
                        ##img_rgb = img_rgb.transpose((2,1,0))
                        #scipy.misc.toimage(img_rgb, cmin=0.0, cmax=1.0).save(
                        #    os.path.join(OUTPUT_FOLDER, "Img_%s_%s_%d.png"%(dataset_name, MODELS[model_index]['name'], j)))
                        #if model_index==0:
                        #    img_rgb = (sample_high[j,1:4,:,:] * 0.5 + 0.5).cpu().numpy()
                        #    scipy.misc.toimage(img_rgb, cmin=0.0, cmax=1.0).save(
                        #        os.path.join(OUTPUT_FOLDER, "Img_%s_GT_%d.png"%(dataset_name,j)))

                        # STATISTICS
コード例 #7
0
    def forward(self, gt, pred, input, prev_input_warped, prev_pred_warped):
        """
        Computes the generator loss for unshaded outputs.

        Channel layout of gt / pred / prev_pred_warped (as sliced below):
        0 mask (mapped to [0, 1] via *0.5+0.5), 1:4 normal, 4 depth, 5 AO.

        gt: ground truth high resolution image (B x C=output_channels x 4W x 4H)
        pred: predicted high resolution image (B x C=output_channels x 4W x 4H)
        input: upsampled low resolution input image (B x C=input_channels x 4W x 4H)
               Used by the downsample losses ('l2-ds', 'l1-ds') and the discriminator,
               can be None if only the other losses are used
        prev_input_warped: upsampled, warped previous input image
               Only used for the discriminator
        prev_pred_warped: predicted image from the previous frame warped by the flow
               Shape: B x Cout x 4W x 4H
               Only used for the discriminator and the temporal losses,
               can be None if only the other losses are used

        Returns:
            generator_loss: weighted sum of all losses configured in self.weight_dict
            loss_values: dict mapping each computed loss key to its unweighted value

        Raises:
            ValueError: if a downsample loss on 'ao' is configured (the input has
                no AO channel), or if other downsample losses are configured
                but input is None
        """

        _, Cout, _, _ = gt.shape
        assert Cout == 6
        assert gt.shape == pred.shape

        # zero border padding
        gt = LossNetUnshaded.pad(gt, self.padding)
        pred = LossNetUnshaded.pad(pred, self.padding)
        if prev_pred_warped is not None:
            prev_pred_warped = LossNetUnshaded.pad(prev_pred_warped, self.padding)

        # extract mask, normal, depth and ao
        gt_mask = gt[:, 0:1, :, :]
        gt_mask_clamp = torch.clamp(gt_mask * 0.5 + 0.5, 0, 1)
        gt_normal = ScreenSpaceShading.normalize(gt[:, 1:4, :, :], dim=1)
        gt_depth = gt[:, 4:5, :, :]
        gt_ao = gt[:, 5:6, :, :]
        pred_mask = pred[:, 0:1, :, :]
        pred_normal = ScreenSpaceShading.normalize(pred[:, 1:4, :, :], dim=1)
        pred_depth = pred[:, 4:5, :, :]
        pred_ao = pred[:, 5:6, :, :]
        if input is not None:
            # the input carries no AO channel
            input_mask = input[:, 0:1, :, :]
            input_mask_clamp = torch.clamp(input_mask * 0.5 + 0.5, 0, 1)
            input_normal = ScreenSpaceShading.normalize(input[:, 1:4, :, :], dim=1)
            input_depth = input[:, 4:5, :, :]

        # compute color output
        gt_color = self.shading(gt)
        pred_color = self.shading(pred)
        input_color = self.shading(input) if input is not None else None

        generator_loss = 0.0
        loss_values = {}

        def weighted(name, target, a, b):
            """Evaluates loss `name` on (a, b), records its raw value and returns it weighted."""
            loss = self.loss_dict[name](a, b)
            loss_values[(name, target)] = loss.item()
            return self.weight_dict[(name, target)] * loss

        # normal, simple losses, uses gt+pred
        for name in ['mse', 'l1']:
            if (name, 'mask') in self.weight_dict:
                generator_loss += weighted(name, 'mask', gt_mask, pred_mask)
            if (name, 'normal') in self.weight_dict:
                generator_loss += weighted(name, 'normal',
                                           gt_normal * gt_mask_clamp,
                                           pred_normal * gt_mask_clamp)
            if (name, 'ao') in self.weight_dict:
                generator_loss += weighted(name, 'ao',
                                           gt_ao * gt_mask_clamp,
                                           pred_ao * gt_mask_clamp)
            if (name, 'depth') in self.weight_dict:
                generator_loss += weighted(name, 'depth',
                                           gt_depth * gt_mask_clamp,
                                           pred_depth * gt_mask_clamp)
            if (name, 'color') in self.weight_dict:
                generator_loss += weighted(name, 'color', gt_color, pred_color)

        # downsample loss, use input+pred
        # TODO: input is passed in upsampled version, but this introduces more errors
        # Better: input low-res input and use the 'low_res_gt=True' option in downsample_loss
        # This requires the input to be upsampled here for the GAN.
        ds_names = ['l2-ds', 'l1-ds']
        if any((name, 'ao') in self.weight_dict for name in ds_names):
            raise ValueError(
                "downsample loss on 'ao' is not supported: the input has no ambient occlusion channel")
        if input is None and any((name, target) in self.weight_dict
                                 for name in ds_names
                                 for target in ['mask', 'normal', 'depth', 'color']):
            raise ValueError(
                "downsample losses ('l2-ds', 'l1-ds') require 'input', but it is None")
        for name in ds_names:
            if (name, 'mask') in self.weight_dict:
                generator_loss += weighted(name, 'mask', input_mask, pred_mask)
            if (name, 'normal') in self.weight_dict:
                generator_loss += weighted(name, 'normal',
                                           input_normal * input_mask_clamp,
                                           pred_normal * input_mask_clamp)
            if (name, 'depth') in self.weight_dict:
                generator_loss += weighted(name, 'depth',
                                           input_depth * input_mask_clamp,
                                           pred_depth * input_mask_clamp)
            if (name, 'color') in self.weight_dict:
                generator_loss += weighted(name, 'color', input_color, pred_color)

        # special losses: perceptual+texture, uses gt+pred
        def compute_perceptual(target, in_pred, in_gt):
            """Runs the perceptual network once on [gt; pred] and returns the weighted style+content loss."""
            if ('perceptual', target) in self.weight_dict \
                    or ('texture', target) in self.weight_dict:
                style_weight = self.weight_dict.get(('texture', target), 0)
                content_weight = self.weight_dict.get(('perceptual', target), 0)
                style_score = 0
                content_score = 0

                input_images = torch.cat([in_gt, in_pred], dim=0)
                # the style/content loss modules record their losses as a side effect
                self.pt_loss(input_images)

                for sl in self.style_losses:
                    style_score += sl.loss
                for cl in self.content_losses:
                    content_score += cl.loss

                if ('perceptual', target) in self.weight_dict:
                    loss_values[('perceptual', target)] = content_score.item()
                if ('texture', target) in self.weight_dict:
                    loss_values[('texture', target)] = style_score.item()
                return style_weight * style_score + content_weight * content_score
            return 0
        generator_loss += compute_perceptual('mask', pred_mask.expand(-1, 3, -1, -1) * 0.5 + 0.5, gt_mask.expand(-1, 3, -1, -1) * 0.5 + 0.5)
        generator_loss += compute_perceptual('normal', (pred_normal * gt_mask_clamp) * 0.5 + 0.5, (gt_normal * gt_mask_clamp) * 0.5 + 0.5)
        generator_loss += compute_perceptual('color', pred_color, gt_color)
        generator_loss += compute_perceptual('ao', pred_ao.expand(-1, 3, -1, -1), gt_ao.expand(-1, 3, -1, -1))
        generator_loss += compute_perceptual('depth', pred_depth.expand(-1, 3, -1, -1), gt_depth.expand(-1, 3, -1, -1))

        # special: discriminator, uses input+pred+prev_pred_warped
        if self.has_discriminator:
            # discriminator layout: [mask, normal, color, ao]
            pred_with_color = torch.cat([
                pred_mask,
                pred_normal,
                pred_color,
                pred_ao], dim=1)

            prev_pred_normal = ScreenSpaceShading.normalize(prev_pred_warped[:, 1:4, :, :], dim=1)
            prev_pred_with_color = torch.cat([
                prev_pred_warped[:, 0:1, :, :],
                prev_pred_normal,
                self.shading(torch.cat([
                    prev_pred_warped[:, 0:1, :, :],
                    prev_pred_normal,
                    prev_pred_warped[:, 4:6, :, :]
                    ], dim=1)),
                prev_pred_warped[:, 5:6, :, :]
                ], dim=1)

            input_pad = LossNetUnshaded.pad(input, self.padding)
            prev_input_warped_pad = LossNetUnshaded.pad(prev_input_warped, self.padding)
            pred_with_color_pad = LossNetUnshaded.pad(pred_with_color, self.padding)
            prev_pred_warped_pad = LossNetUnshaded.pad(prev_pred_with_color, self.padding)

            if ('adv', 'all') in self.weight_dict:  # spatial-temporal
                discr_input = torch.cat([input_pad, prev_input_warped_pad, pred_with_color_pad, prev_pred_warped_pad], dim=1)
                gen_loss = self.adv_loss(self.discriminator(discr_input))
                loss_values['discr_pred'] = gen_loss.item()
                generator_loss += self.weight_dict[('adv', 'all')] * gen_loss

            if ('tgan', 'all') in self.weight_dict:  # temporal
                discr_input = torch.cat([pred_with_color_pad, prev_pred_warped_pad], dim=1)
                gen_loss = self.temp_adv_loss(self.temp_discriminator(discr_input))
                loss_values['temp_discr_pred'] = gen_loss.item()
                generator_loss += self.weight_dict[('tgan', 'all')] * gen_loss

            if ('sgan', 'all') in self.weight_dict:  # spatial
                discr_input = torch.cat([input_pad, pred_with_color_pad], dim=1)
                gen_loss = self.spatial_adv_loss(self.spatial_discriminator(discr_input))
                loss_values['spatial_discr_pred'] = gen_loss.item()
                generator_loss += self.weight_dict[('sgan', 'all')] * gen_loss

        # special: temporal l2 loss, uses gt (for the mask) + pred + prev_warped
        if self.has_temporal_l2_loss:
            prev_pred_mask = prev_pred_warped[:, 0:1, :, :]
            prev_pred_normal = ScreenSpaceShading.normalize(prev_pred_warped[:, 1:4, :, :], dim=1)
            if ('temp-l2', 'mask') in self.weight_dict:
                generator_loss += weighted('temp-l2', 'mask', pred_mask, prev_pred_mask)
            if ('temp-l2', 'normal') in self.weight_dict:
                generator_loss += weighted('temp-l2', 'normal',
                                           pred_normal * gt_mask_clamp,
                                           prev_pred_normal * gt_mask_clamp)
            if ('temp-l2', 'ao') in self.weight_dict:
                prev_pred_ao = prev_pred_warped[:, 5:6, :, :]
                generator_loss += weighted('temp-l2', 'ao',
                                           pred_ao * gt_mask_clamp,
                                           prev_pred_ao * gt_mask_clamp)
            if ('temp-l2', 'depth') in self.weight_dict:
                prev_pred_depth = prev_pred_warped[:, 4:5, :, :]
                generator_loss += weighted('temp-l2', 'depth',
                                           pred_depth * gt_mask_clamp,
                                           prev_pred_depth * gt_mask_clamp)
            if ('temp-l2', 'color') in self.weight_dict:
                prev_pred_color = self.shading(prev_pred_warped)
                generator_loss += weighted('temp-l2', 'color', pred_color, prev_pred_color)

        return generator_loss, loss_values
コード例 #8
0
def trainAdv_v2(epoch):
    """
    Second version of adversarial training:
    for each batch, train both the discriminator and the generator,
    instead of a full epoch for each of them separately.

    Per minibatch, the discriminator is updated `disc_steps` times
    (opt.advDiscrInitialSteps when epoch == 1 and it is set, otherwise
    opt.advDiscrMaxSteps). The generator is then updated `gen_steps`
    (= opt.advGenMaxSteps) times. Every step unrolls all
    dataset_data.num_frames frames, feeding the clamped and warped previous
    prediction back into the generator.

    epoch: epoch index; epoch 1 uses the initial discriminator step count.

    Uses module-level state: opt, model, criterion, device, writer,
    training_data_loader, dataset_data, the optimizers and the schedulers.
    Per-epoch averages are written to `writer`.
    """
    print("===> Epoch %d Training" % epoch)
    # NOTE(review): schedulers are stepped before the optimizers (pre-1.1 PyTorch
    # convention) and logged via get_lr(); kept to preserve the LR schedule.
    discr_scheduler.step()
    writer.add_scalar('train/lr_discr', discr_scheduler.get_lr()[0], epoch)
    gen_scheduler.step()
    writer.add_scalar('train/lr_gen', gen_scheduler.get_lr()[0], epoch)

    if opt.advDiscrInitialSteps is not None and epoch == 1:
        disc_steps = opt.advDiscrInitialSteps
    else:
        disc_steps = opt.advDiscrMaxSteps
    gen_steps = opt.advGenMaxSteps
    num_frames = dataset_data.num_frames

    def is_first_frame(j):
        # frames without a usable temporal predecessor
        return j == 0 or opt.disableTemporal

    def warp(image, flow, j):
        """Warps 'image' from frame j-1 into frame j using flow[:, j-1]."""
        return models.VideoTools.warp_upscale(image,
                                              flow[:, j - 1, :, :, :],
                                              opt.upscale_factor,
                                              special_mask=True)

    def prepare_frame(low_res, flow, target, j, previous_output):
        """
        Returns (generator input, previous image for the loss,
        warped upsampled previous input, upsampled current input) for frame j.
        """
        _, _, Cout, Hhigh, Whigh = target.shape
        size = (Hhigh, Whigh)
        if is_first_frame(j):
            previous_warped = initialImage(low_res[:, j, :, :, :], Cout,
                                           opt.initialImage, False,
                                           opt.upscale_factor)
            # loss takes the ground truth current image as warped previous image,
            # to not introduce a bias and big loss for the first image
            previous_warped_loss = target[:, j, :, :, :]
            previous_input = F.interpolate(low_res[:, j, :, :, :],
                                           size=size,
                                           mode=opt.upsample)
        else:
            previous_warped = warp(previous_output, flow, j)
            previous_warped_loss = previous_warped
            previous_input = warp(
                F.interpolate(low_res[:, j - 1, :, :, :],
                              size=size,
                              mode=opt.upsample), flow, j)
        previous_warped_flattened = models.VideoTools.flatten_high(
            previous_warped, opt.upscale_factor)
        single_input = torch.cat(
            (low_res[:, j, :, :, :], previous_warped_flattened), dim=1)
        input_high = F.interpolate(low_res[:, j, :, :, :],
                                   size=size,
                                   mode=opt.upsample)
        return single_input, previous_warped_loss, previous_input, input_high

    def clamp_output(prediction):
        """Post-processes a raw prediction before it is fed back as the next frame's previous output."""
        return torch.cat(
            [
                torch.clamp(prediction[:, 0:1, :, :], -1, +1),  # mask
                ScreenSpaceShading.normalize(prediction[:, 1:4, :, :],
                                             dim=1),  # normal
                torch.clamp(prediction[:, 4:5, :, :], 0, +1),  # depth
                torch.clamp(prediction[:, 5:6, :, :], 0, +1)  # ao
            ],
            dim=1)

    num_minibatch = len(training_data_loader)
    model.train()
    criterion.discr_train()

    total_discr_loss = 0
    total_gen_loss = 0
    total_gt_score = 0
    total_pred_score = 0

    pg = ProgressBar(num_minibatch, 'Train', length=50)
    for iteration, batch in enumerate(training_data_loader):
        pg.print_progress_bar(iteration)
        low_res, flow, target = batch[0].to(device), batch[1].to(
            device), batch[2].to(device)

        # DISCRIMINATOR
        for _ in range(disc_steps):
            discr_optimizer.zero_grad()
            gen_optimizer.zero_grad()
            loss = 0
            previous_output = None
            for j in range(num_frames):
                single_input, previous_warped_loss, previous_input, input_high = \
                    prepare_frame(low_res, flow, target, j, previous_output)
                # the generator is frozen in this phase
                with torch.no_grad():
                    prediction, _ = model(single_input)
                # real counterpart of previous_warped_loss: for the first frame
                # the current ground truth (target[:, j-1] would wrap to the last frame)
                if is_first_frame(j):
                    gt_prev_warped = target[:, j, :, :, :]
                else:
                    gt_prev_warped = warp(target[:, j - 1, :, :, :], flow, j)
                disc_loss, gt_score, pred_score = criterion.train_discriminator(
                    input_high, target[:, j, :, :, :], previous_input,
                    gt_prev_warped, prediction, previous_warped_loss)
                loss += disc_loss
                total_gt_score += float(gt_score)
                total_pred_score += float(pred_score)
                previous_output = clamp_output(prediction)
            loss.backward()
            discr_optimizer.step()
        if disc_steps > 0:
            total_discr_loss += loss.item()  # loss of the last discriminator step

        # GENERATOR
        for _ in range(gen_steps):
            discr_optimizer.zero_grad()
            gen_optimizer.zero_grad()
            loss = 0
            previous_output = None
            for j in range(num_frames):
                single_input, previous_warped_loss, previous_input, input_high = \
                    prepare_frame(low_res, flow, target, j, previous_output)
                prediction, _ = model(single_input)
                frame_loss, _loss_map = criterion(target[:, j, :, :, :],
                                                  prediction, input_high,
                                                  previous_input,
                                                  previous_warped_loss)
                loss += frame_loss
                previous_output = clamp_output(prediction)
            loss.backward()
            gen_optimizer.step()
        if gen_steps > 0:
            total_gen_loss += loss.item()  # loss of the last generator step
    pg.print_progress_bar(num_minibatch)

    # losses are per-frame sums of the last step per minibatch
    frame_count = max(num_minibatch * num_frames, 1)
    total_discr_loss /= frame_count
    total_gen_loss /= frame_count
    # scores were accumulated in every discriminator step
    score_count = max(num_minibatch * num_frames * disc_steps, 1)
    total_gt_score /= score_count
    total_pred_score /= score_count

    writer.add_scalar('train/discr_loss', total_discr_loss, epoch)
    writer.add_scalar('train/gen_loss', total_gen_loss, epoch)
    writer.add_scalar('train/gt_score', total_gt_score, epoch)
    writer.add_scalar('train/pred_score', total_pred_score, epoch)
    print("===> Epoch {} Complete".format(epoch))
コード例 #9
0
        def render(rerender=True, resuperres=True):
            """
            Main render function: renders one frame and writes one image per (model, channel).

            rerender: if True, the volumes are retraced into rendered_low (and into
                rendered_high unless PREVIEW). If False, the previous images are kept
            resuperres: if True, the superresolution networks are evaluated again. If False,
                the previous network result stored in previous_frames[model_idx] is used

            Side effects visible here: updates the globals oldFile, frameIndex,
            global_depth_max, global_depth_min and previous_frames; sends commands to
            `renderer`; temporarily sets the ambient occlusion of `shading` to 0.0 and
            restores it; writes every output image through writers[model_idx][channel_idx].
            Raises BreakRenderingException after the first frame if OUTPUT_FIRST_IMAGE is set.
            """
            global oldFile, frameIndex, global_depth_max, global_depth_min, previous_frames
            # check if file was changed (reload the volume only when the scene file differs)
            if oldFile != scene.file:
                oldFile = scene.file
                renderer.load(os.path.join(DATA_DIR_GPU, scene.file))
            # send render parameters (camera, field of view, iso-value, AO settings)
            currentOrigin = camera.getOrigin()
            currentLookAt = camera.getLookAt()
            currentUp = camera.getUp()
            renderer.send_command(
                "cameraOrigin", "%5.3f,%5.3f,%5.3f" %
                (currentOrigin[0], currentOrigin[1], currentOrigin[2]))
            renderer.send_command(
                "cameraLookAt", "%5.3f,%5.3f,%5.3f" %
                (currentLookAt[0], currentLookAt[1], currentLookAt[2]))
            renderer.send_command(
                "cameraUp", "%5.3f,%5.3f,%5.3f" %
                (currentUp[0], currentUp[1], currentUp[2]))
            renderer.send_command("cameraFoV", "%.3f" % shading.get_fov())
            renderer.send_command("isovalue", "%5.3f" % float(scene.isovalue))
            renderer.send_command("aoradius", "%5.3f" % float(scene.aoRadius))
            if PREVIEW:
                # preview mode: no ambient occlusion samples are requested
                renderer.send_command("aosamples", "0")
            else:
                renderer.send_command("aosamples", "%d" % scene.aoSamples)

            if rerender:

                # render low resolution (input of the super-resolution models)
                renderer.send_command(
                    "resolution",
                    "%d,%d" % (RESOLUTION_LOW[0], RESOLUTION_LOW[1]))
                renderer.send_command(
                    "viewport", "%d,%d,%d,%d" %
                    (0, 0, RESOLUTION_LOW[0], RESOLUTION_LOW[1]))
                renderer.render_direct(rendered_low)

                # render high resolution (reference image); skipped in preview mode
                if not PREVIEW:
                    renderer.send_command(
                        "resolution", "%d,%d" % (RESOLUTION[0], RESOLUTION[1]))
                    renderer.send_command(
                        "viewport",
                        "%d,%d,%d,%d" % (0, 0, RESOLUTION[0], RESOLUTION[1]))
                    renderer.render_direct(rendered_high)

            # preprocessing
            def preprocess(input):
                # Move a rendered image to the super-resolution device, bring the last
                # axis to the front and add a batch axis -> (1, C, H, W).
                # NOTE(review): assumes `input` is channels-last (H, W, C), as implied by
                # permute(2, 0, 1) — confirm against the renderer.
                # Channel 3 (the mask) is mapped x -> 2x - 1 (from [0,1] to [-1,+1]).
                input = input.to(
                    cpuDevice if CPU_SUPERRES else cudaDevice).permute(
                        2, 0, 1)
                output = torch.unsqueeze(input, 0)
                output = torch.cat(
                    (
                        output[:, 0:3, :, :],
                        output[:, 3:4, :, :] * 2 -
                        1,  #transform mask into -1,+1
                        output[:, 4:, :, :]),
                    dim=1)
                #image_shaded_input = torch.cat((output[:,3:4,:,:], output[:,4:8,:,:], output[:,10:11,:,:]), dim=1)
                #image_shaded = torch.clamp(shading(image_shaded_input), 0, 1)
                #output[:,0:3,:,:] = image_shaded
                return output

            # NOTE: in PREVIEW mode rendered_high is not re-rendered above, so
            # processed_high holds whatever rendered_high last contained.
            processed_low = preprocess(rendered_low)
            processed_high = preprocess(rendered_high)
            # image now contains all channels:
            # 0:3 - color (shaded)
            # 3:4 - mask in -1,+1
            # 4:7 - normal
            # 7:8 - depth
            # 8:10 - flow
            # 10:11 - AO

            # prepare bounds for depth
            depthForBounds = processed_low[:, 7:8, :, :]
            maxDepth = torch.max(depthForBounds)
            # depths <= 1e-5 get +1 added so they do not determine the minimum
            # (presumably those are empty/background pixels — confirm)
            minDepth = torch.min(
                depthForBounds +
                torch.le(depthForBounds, 1e-5).type_as(depthForBounds))
            # running depth range over all frames rendered so far
            global_depth_max = max(global_depth_max, maxDepth.item())
            global_depth_min = min(global_depth_min, minDepth.item())
            # bounds given in the scene override the per-frame bounds
            if scene.depthMin is not None:
                minDepth = scene.depthMin
            if scene.depthMax is not None:
                maxDepth = scene.depthMax

            # mask: from the high-resolution rendering, or bilinearly upsampled from the
            # low-resolution rendering in preview mode; then mapped from [-1,+1] to [0,1]
            if PREVIEW:
                base_mask = F.interpolate(processed_low,
                                          scale_factor=UPSCALING,
                                          mode='bilinear')[:, 3:4, :, :]
            else:
                base_mask = processed_high[:, 3:4, :, :]
            base_mask = (base_mask * 0.5 + 0.5)

            # loop through the models
            for model_idx, model in enumerate(MODELS):
                # perform super-resolution: interpolation baselines upsample processed_low,
                # the ground truth uses processed_high, everything else runs a network
                if model['path'] == MODEL_NEAREST:
                    image = F.interpolate(processed_low,
                                          scale_factor=UPSCALING,
                                          mode='nearest')
                elif model['path'] == MODEL_BILINEAR:
                    image = F.interpolate(processed_low,
                                          scale_factor=UPSCALING,
                                          mode='bilinear')
                elif model['path'] == MODEL_BICUBIC:
                    image = F.interpolate(processed_low,
                                          scale_factor=UPSCALING,
                                          mode='bicubic')
                elif model['path'] == MODEL_GROUND_TRUTH:
                    image = processed_high
                else:
                    # NETWORK
                    if resuperres:
                        # previous frame (only passed on when temporal consistency is enabled)
                        if scene.temporalConsistency:
                            previous_frame = previous_frames[model_idx]
                        else:
                            previous_frame = None
                        # apply network
                        imageRaw = models[model_idx].inference(
                            processed_low, previous_frame)
                        # post-process: clamp the mask to [-1,+1], normalize channels 1:4
                        # (normal), clamp all remaining channels to [0,1]
                        imageRaw = torch.cat([
                            torch.clamp(imageRaw[:, 0:1, :, :], -1, +1),
                            ScreenSpaceShading.normalize(
                                imageRaw[:, 1:4, :, :], dim=1),
                            torch.clamp(imageRaw[:, 4:, :, :], 0, 1)
                        ],
                                             dim=1)
                        # remembered for resuperres=False and for temporal consistency
                        previous_frames[model_idx] = imageRaw
                    else:
                        imageRaw = previous_frames[model_idx]
                    # start from a bilinear upsampling so that channels the network does
                    # not write (color 0:3, flow 8:10) are still filled
                    image = F.interpolate(processed_low,
                                          scale_factor=UPSCALING,
                                          mode='bilinear')
                    #image[:,0:3,:,:] = shading(imageRaw)
                    # all but the last network channel -> image channels 3:8
                    # (mask, normal, depth); the last network channel -> AO (10).
                    # NOTE(review): assumes imageRaw has exactly 6 channels — confirm
                    image[:, 3:8, :, :] = imageRaw[:, 0:-1, :, :]
                    image[:, 10, :, :] = imageRaw[:, -1, :, :]
                    #masking: replace the predicted mask by base_mask and blend AO
                    # towards 1 where base_mask is 0
                    if model['masking']:
                        image[:, 3:4, :, :] = base_mask * 2 - 1
                        #image[:,7:8,:,:] = 0 + base_mask * (image[:,7:8,:,:] - 0)
                        image[:, 10:11, :, :] = 1 + base_mask * (
                            image[:, 10:11, :, :] - 1)

                # shading: input is mask, normal, depth (image[:,3:8]) and AO (image[:,10:11])
                image_shaded_input = torch.cat(
                    (image[:, 3:4, :, :], image[:, 4:8, :, :],
                     image[:, 10:11, :, :]),
                    dim=1)
                image_shaded_withAO = torch.clamp(shading(image_shaded_input),
                                                  0, 1)
                # shade a second time with ambient occlusion disabled, then restore the
                # previous AO strength of the shared `shading` object
                ao = shading._ao
                shading.ambient_occlusion(0.0)
                image_shaded_noAO = torch.clamp(shading(image_shaded_input), 0,
                                                1)
                shading.ambient_occlusion(ao)

                # perform channel selection: one output image per entry of CHANNEL_NAMES
                for channel_idx in range(len(CHANNEL_NAMES)):
                    if channel_idx == CHANNEL_AO:
                        # NOTE(review): unlike depth/normal, this "difference" does not
                        # compare against processed_high, and it overwrites image[:,10:11]
                        # in place — confirm intent
                        if SHOW_DIFFERENCE and model[
                                'path'] != MODEL_GROUND_TRUTH:
                            image[:, 10:11, :, :] = 1 - torch.abs(
                                image[:, 10:11, :, :])
                        imageRGB = torch.cat(
                            (image[:, 10:11, :, :], image[:, 10:11, :, :],
                             image[:, 10:11, :, :]),
                            dim=1)
                    elif channel_idx == CHANNEL_COLOR_NOAO:
                        imageRGB = image_shaded_noAO
                    elif channel_idx == CHANNEL_COLOR_WITHAO:
                        imageRGB = image_shaded_withAO
                    elif channel_idx == CHANNEL_DEPTH:
                        # difference mode: absolute depth error w.r.t. processed_high;
                        # otherwise depth normalized to [minDepth, maxDepth].
                        # NOTE(review): yields inf/nan if maxDepth == minDepth
                        if SHOW_DIFFERENCE and model[
                                'path'] != MODEL_GROUND_TRUTH:
                            depthVal = torch.abs(
                                image[:, 7:8, :, :] -
                                processed_high[:, 7:8, :, :]
                            )  # / (2*(maxDepth - minDepth))
                        else:
                            depthVal = (image[:, 7:8, :, :] -
                                        minDepth) / (maxDepth - minDepth)
                        imageRGB = torch.cat((depthVal, depthVal, depthVal),
                                             dim=1)
                        # inverted gray scale; values below 0.05 after inversion become white
                        imageRGB = 1 - imageRGB
                        imageRGB[imageRGB < 0.05] = 1.0
                        #imageRGB = BACKGROUND[0] + base_mask * (imageRGB - BACKGROUND[0])
                    elif channel_idx == CHANNEL_NORMAL:
                        # difference mode: cosine similarity to processed_high mapped to
                        # [0,1]; otherwise normals mapped from [-1,1] to [0,1]
                        if SHOW_DIFFERENCE and model[
                                'path'] != MODEL_GROUND_TRUTH:
                            diffVal = F.cosine_similarity(
                                image[:, 4:7, :, :],
                                processed_high[:, 4:7, :, :],
                                dim=1) * 0.5 + 0.5
                            imageRGB = torch.stack((diffVal, diffVal, diffVal),
                                                   dim=1)
                            #imageRGB = 1 - torch.abs(image[:,4:7,:,:])
                        else:
                            imageRGB = image[:, 4:7, :, :] * 0.5 + 0.5
                        # blend towards the background color where base_mask is 0
                        imageRGB = BACKGROUND[0] + base_mask * (imageRGB -
                                                                BACKGROUND[0])
                    imageRGB = torch.clamp(imageRGB, 0, 1)

                    # copy to numpy (first batch element, HWC, uint8) and write to video
                    imageRGB_cpu = imageRGB.cpu().numpy()[0].transpose(
                        (1, 2, 0))
                    imageRGB_cpu = np.clip(imageRGB_cpu * 255, 0,
                                           255).astype(np.uint8)
                    if OUTPUT_FIRST_IMAGE:
                        # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2;
                        # this branch requires an older SciPy
                        scipy.misc.imsave(writers[model_idx][channel_idx],
                                          imageRGB_cpu)
                    else:
                        writers[model_idx][channel_idx].append_data(
                            imageRGB_cpu)
            # done with this frame; print progress every 10 frames
            frameIndex += 1
            if frameIndex % 10 == 0:
                print(" %d" % frameIndex)

            # only a single frame is wanted: abort the rendering loop via exception
            if OUTPUT_FIRST_IMAGE:
                raise BreakRenderingException()
コード例 #10
0
    def test_images(epoch):
        """
        Evaluate the model on the full test set and send images to TensorBoard.

        For every test batch and every frame, the heat map, the predicted normals,
        the shaded prediction, the mask and the ambient occlusion of the first batch
        element are logged under 'image%03d/frame%03d_<kind>'.

        epoch: global step passed to writer.add_image.
        """
        def write_image(img, filename):
            # Scale a tensor with values in [0,1] to uint8 and log it.
            # The scaling is done out-of-place: for a CPU tensor .numpy() shares memory
            # with the tensor, so an in-place `*= 255` would silently rescale the
            # caller's tensor (e.g. the view prediction[0, 5:6], which is then kept as
            # previous_output).
            out_img = img.detach().cpu().numpy() * 255.0
            out_img = np.uint8(out_img.clip(0, 255))
            writer.add_image(filename, out_img, epoch)

        with torch.no_grad():
            num_minibatch = len(testing_full_data_loader)
            pg = ProgressBar(num_minibatch,
                             'Test %d Images' % num_minibatch,
                             length=50)
            model.eval()
            if criterion.has_discriminator:
                criterion.discr_eval()
            for i, batch in enumerate(testing_full_data_loader):
                pg.print_progress_bar(i)
                # batch = (input, flow, target), indexed below as [:, frame, ...]
                input, flow, target = batch[0].to(device), batch[1].to(
                    device), batch[2].to(device)
                B, _, Cin, H, W = input.shape  # unpacking also enforces a 5-D input
                Cout = output_channels

                channel_mask = [1, 2, 3]  #normal

                previous_output = None
                for j in range(dataset_data.num_frames):
                    # prepare input
                    # NOTE: previous_warped is currently unused (temporal input disabled,
                    # see TODO below) but is kept so it can be re-enabled easily.
                    if j == 0 or opt.disableTemporal:
                        previous_warped = initialImage(input[:, 0, :, :, :],
                                                       Cout, opt.initialImage,
                                                       False, upscale_factor)
                    else:
                        previous_warped = models.VideoTools.warp_upscale(
                            previous_output,
                            flow[:, j - 1, :, :, :],
                            upscale_factor,
                            special_mask=True)
                    # TODO: enable temporal component again
                    #previous_warped_flattened = models.VideoTools.flatten_high(previous_warped, opt.upscale_factor)
                    #single_input = torch.cat((
                    #        input[:,j,:,:,:],
                    #        previous_warped_flattened),
                    #    dim=1)
                    single_input = input[:, j, :, :, :]
                    # run generator: predicts an importance heat map
                    heatMap = model(single_input)
                    heatMap = postprocess(heatMap)
                    # reconstruct by adaptive smoothing of the target, driven by 1 / heatMap
                    prediction = importance.adaptiveSmoothing(
                        target[:, j, :, :, :].contiguous(),
                        1 / heatMap.unsqueeze(1),
                        opt.distanceToStandardDeviation)
                    # write heatmap
                    write_image(heatMap[0].unsqueeze(0),
                                'image%03d/frame%03d_heatmap' % (i, j))
                    ## write warped previous frame
                    #write_image(previous_warped[0, channel_mask], 'image%03d/frame%03d_warped' % (i, j))
                    # write predicted normals (normalized in place first)
                    prediction[:, 1:4, :, :] = ScreenSpaceShading.normalize(
                        prediction[:, 1:4, :, :], dim=1)
                    write_image(prediction[0, channel_mask],
                                'image%03d/frame%03d_prediction' % (i, j))
                    # write shaded image if network runs in deferredShading mode
                    shaded_image = shading(prediction)
                    write_image(shaded_image[0],
                                'image%03d/frame%03d_shaded' % (i, j))
                    # write mask (mapped from [-1,+1] to [0,1])
                    write_image(prediction[0, 0:1, :, :] * 0.5 + 0.5,
                                'image%03d/frame%03d_mask' % (i, j))
                    # write ambient occlusion
                    write_image(prediction[0, 5:6, :, :],
                                'image%03d/frame%03d_ao' % (i, j))
                    # save output for next frame
                    previous_output = prediction
            pg.print_progress_bar(num_minibatch)

        print("Test images sent to Tensorboard for visualization")
コード例 #11
0
def run():
    torch.ops.load_library("./Renderer.dll")

    #########################
    # CONFIGURATION
    #########################

    if 0:
        OUTPUT_FOLDER = "../result-stats/adaptiveIso2/"
        DATASET_PREFIX = "D:/VolumeSuperResolution-InputData/"
        DATASETS = [
            #("Ejecta", "gt-rendering-ejecta-v2-test.hdf5"),
            ("RM", "gt-rendering-rm-v1.hdf5"),
            #("Human", "gt-rendering-human-v1.hdf5"),
            #("Thorax", "gt-rendering-thorax-v1.hdf5"),
        ]

        NETWORK_DIR = "D:/VolumeSuperResolution/adaptive-modeldir/"
        NETWORKS = [  #suffixed with _importance.pt and _recon.pt
            #("adaptive011", "adaptive011_epoch500"), #title, file prefix
            ("adaptive019", "adaptive019_epoch470"),
            ("adaptive023", "adaptive023_epoch300")
        ]

        SAMPLING_FILE = "D:/VolumeSuperResolution-InputData/samplingPattern.hdf5"
        SAMPLING_PATTERNS = ['halton', 'plastic', 'random']

        HEATMAP_MIN = [0.01, 0.05, 0.2]
        HEATMAP_MEAN = [0.05, 0.1, 0.2, 0.5]

        UPSCALING = 8  # = networkUp * postUp

        IMPORTANCE_BORDER = 8
        LOSS_BORDER = 32
        BATCH_SIZE = 8

    elif 0:
        OUTPUT_FOLDER = "../result-stats/adaptiveIsoEnhance6/"
        DATASET_PREFIX = "D:/VolumeSuperResolution-InputData/"
        DATASETS = [
            ("Ejecta", "gt-rendering-ejecta-v2-test.hdf5"),
            #("RM", "gt-rendering-rm-v1.hdf5"),
            #("Human", "gt-rendering-human-v1.hdf5"),
            #("Thorax", "gt-rendering-thorax-v1.hdf5"),
            #("Head", "gt-rendering-head.hdf5"),
        ]

        NETWORK_DIR = "D:/VolumeSuperResolution/adaptive-modeldir/"
        NETWORKS = [  #suffixed with _importance.pt and _recon.pt
            #title, file prefix
            #("U-Net (5-4)", "sizes/size5-4_epoch500"),
            #("Enhance-Net (epoch 50)", "enhance2_imp050_epoch050"),
            #("Enhance-Net (epoch 400)", "enhance2_imp050_epoch400"),
            #("Enhance-Net (Thorax)", "enhance_imp050_Thorax_epoch200"),
            #("Enhance-Net (RM)", "enhance_imp050_RM_epoch200"),
            #("Imp100", "enhance4_imp100_epoch300"),
            #("Imp100res", "enhance4_imp100res_epoch230"),
            ("Imp100res+N", "enhance4_imp100res+N_epoch300"),
            ("Imp100+N", "enhance4_imp100+N_epoch300"),
            #("Imp100+N-res", "enhance4_imp100+N-res_epoch300"),
            #("Imp100+N-resInterp", "enhance4_imp100+N-resInterp_epoch300"),
            #("U-Net (5-4)", "size5-4_epoch500"),
            #("U-Net (5-3)", "size5-3_epoch500"),
            #("U-Net (4-4)", "size4-4_epoch500"),
        ]

        # Test if it is better to post-train with dense networks and PDE inpainting
        POSTTRAIN_NETWORK_DIR = "D:/VolumeSuperResolution/dense-modeldir/"
        POSTTRAIN_NETWORKS = [
            # title, file suffix to POSTTRAIN_NETWORK_DIR, inpainting {'fast', 'pde'}
            #("Enhance PDE (post)", "inpHv2-pde05-epoch200.pt", "pde")
        ]

        SAMPLING_FILE = "D:/VolumeSuperResolution-InputData/samplingPattern.hdf5"
        SAMPLING_PATTERNS = ['plastic']

        HEATMAP_MIN = [0.002]
        HEATMAP_MEAN = [
            0.05
        ]  #[0.01, 0.02, 0.03, 0.04, 0.06, 0.08, 0.1, 0.2, 0.3, 0.5, 0.8, 1.0]
        USE_BINARY_SEARCH_ON_MEAN = True

        UPSCALING = 8  # = networkUp * postUp

        IMPORTANCE_BORDER = 8
        LOSS_BORDER = 32
        BATCH_SIZE = 4

    elif 0:
        OUTPUT_FOLDER = "../result-stats/adaptiveIsoEnhance5Sampling/"
        DATASET_PREFIX = "D:/VolumeSuperResolution-InputData/"
        DATASETS = [
            ("Ejecta", "gt-rendering-ejecta-v2-test.hdf5"),
        ]

        NETWORK_DIR = "D:/VolumeSuperResolution/adaptive-modeldir/"
        NETWORKS = [  #suffixed with _importance.pt and _recon.pt
            #title, file prefix
            ("Enhance-Net (epoch 400)", "enhance2_imp050_epoch400"),
        ]

        # Test if it is better to post-train with dense networks and PDE inpainting
        POSTTRAIN_NETWORK_DIR = "D:/VolumeSuperResolution/dense-modeldir/"
        POSTTRAIN_NETWORKS = [
            # title, file suffix to POSTTRAIN_NETWORK_DIR, inpainting {'fast', 'pde'}
            #("Enhance PDE (post)", "inpHv2-pde05-epoch200.pt", "pde")
        ]

        SAMPLING_FILE = "D:/VolumeSuperResolution-InputData/samplingPattern.hdf5"
        SAMPLING_PATTERNS = ['halton', 'plastic', 'random', 'regular']
        #SAMPLING_PATTERNS = ['regular']

        HEATMAP_MIN = [0.002]
        HEATMAP_MEAN = [0.05]
        USE_BINARY_SEARCH_ON_MEAN = True

        UPSCALING = 8  # = networkUp * postUp

        IMPORTANCE_BORDER = 8
        LOSS_BORDER = 32
        BATCH_SIZE = 4

    elif 1:
        OUTPUT_FOLDER = "../result-stats/adaptiveIsoEnhance8Sampling/"
        DATASET_PREFIX = "D:/VolumeSuperResolution-InputData/"
        DATASETS = [
            ("Ejecta", "gt-rendering-ejecta-v2-test.hdf5"),
            #("RM", "gt-rendering-rm-v1.hdf5"),
            #("Human", "gt-rendering-human-v1.hdf5"),
            #("Thorax", "gt-rendering-thorax-v1.hdf5"),
            #("Head", "gt-rendering-head.hdf5"),
        ]

        NETWORK_DIR = "D:/VolumeSuperResolution/adaptive-modeldir/"
        NETWORKS = [  #suffixed with _importance.pt and _recon.pt
            #title, file prefix
            ("regular", "enhance7_regular_epoch190"),
            ("random", "enhance7_random_epoch190"),
            ("halton", "enhance7_halton_epoch190"),
            ("plastic", "enhance7_plastic_epoch190"),
        ]

        # Test if it is better to post-train with dense networks and PDE inpainting
        POSTTRAIN_NETWORK_DIR = "D:/VolumeSuperResolution/dense-modeldir/"
        POSTTRAIN_NETWORKS = [
            # title, file suffix to POSTTRAIN_NETWORK_DIR, inpainting {'fast', 'pde'}
            #("Enhance PDE (post)", "inpHv2-pde05-epoch200.pt", "pde")
        ]

        SAMPLING_FILE = "D:/VolumeSuperResolution-InputData/samplingPattern.hdf5"
        SAMPLING_PATTERNS = ['regular', 'random', 'halton', 'plastic']

        HEATMAP_MIN = [0.002]
        HEATMAP_MEAN = [
            0.05
        ]  #[0.01, 0.02, 0.03, 0.04, 0.06, 0.08, 0.1, 0.2, 0.3, 0.5, 0.8, 1.0]
        USE_BINARY_SEARCH_ON_MEAN = True

        UPSCALING = 8  # = networkUp * postUp

        IMPORTANCE_BORDER = 8
        LOSS_BORDER = 32
        BATCH_SIZE = 4

    elif 0:
        OUTPUT_FOLDER = "../result-stats/adaptiveIsoImp/"
        DATASET_PREFIX = "D:/VolumeSuperResolution-InputData/"
        DATASETS = [
            ("Ejecta", "gt-rendering-ejecta-v2-test.hdf5"),
            #("RM", "gt-rendering-rm-v1.hdf5"),
            #("Human", "gt-rendering-human-v1.hdf5"),
            #("Thorax", "gt-rendering-thorax-v1.hdf5"),
        ]

        NETWORK_DIR = "D:/VolumeSuperResolution/adaptive-modeldir/imp/"
        NETWORKS = [  #suffixed with _importance.pt and _recon.pt
            #("adaptive011", "adaptive011_epoch500"), #title, file prefix
            ("imp005", "imp005_epoch500"),
            ("imp010", "imp010_epoch500"),
            ("imp020", "imp020_epoch500"),
            ("imp050", "imp050_epoch500"),
        ]

        SAMPLING_FILE = "D:/VolumeSuperResolution-InputData/samplingPattern.hdf5"
        SAMPLING_PATTERNS = ['halton']

        HEATMAP_MIN = [0.002]
        HEATMAP_MEAN = [0.005, 0.01, 0.02, 0.05, 0.1]
        USE_BINARY_SEARCH_ON_MEAN = True

        UPSCALING = 8  # = networkUp * postUp

        IMPORTANCE_BORDER = 8
        LOSS_BORDER = 32
        BATCH_SIZE = 16

    #########################
    # LOADING
    #########################

    device = torch.device("cuda")

    # Load Networks
    IMPORTANCE_BASELINE1 = "ibase1"
    IMPORTANCE_BASELINE2 = "ibase2"
    IMPORTANCE_BASELINE3 = "ibase3"
    RECON_BASELINE = "rbase"

    # load importance model
    print("load importance networks")

    class ImportanceModel:
        """
        Importance-map generator that decides where to sample.

        `file` is either IMPORTANCE_BASELINE1 (uniform map, named "constant"),
        IMPORTANCE_BASELINE2 (gradient-based map, named "curvature"), or a
        (title, file prefix) pair. In the pair case the TorchScript module
        '<prefix>_importance.pt' is loaded from NETWORK_DIR and configured from
        its embedded 'settings.json' ('networkUpscale', optional
        'requiresPrevious' and 'disableTemporal').
        """
        def __init__(self, file):
            if file == IMPORTANCE_BASELINE1:
                self._setup_baseline(importance.UniformImportanceMap(1, 0.5),
                                     "constant")
            elif file == IMPORTANCE_BASELINE2:
                self._setup_baseline(
                    importance.GradientImportanceMap(1, (1, 1), (2, 1),
                                                     (3, 1)), "curvature")
            else:
                self._setup_network(file[0], file[1])

        def _setup_baseline(self, net, title):
            # Hand-crafted baselines: no upscaling and no temporal input.
            self._net = net
            self._upscaling = 1
            self._name = title
            self.disableTemporal = True
            self._requiresPrevious = False

        def _setup_network(self, title, prefix):
            # Trained network: its configuration is stored inside the .pt file.
            self._name = title
            extra_files = torch._C.ExtraFilesMap()
            extra_files['settings.json'] = ""
            path = os.path.join(NETWORK_DIR, prefix + "_importance.pt")
            self._net = torch.jit.load(path,
                                       map_location=device,
                                       _extra_files=extra_files)
            config = json.loads(extra_files['settings.json'])
            self._upscaling = config['networkUpscale']
            self._requiresPrevious = config.get("requiresPrevious", False)
            self.disableTemporal = config.get("disableTemporal", True)

        def networkUpscaling(self):
            return self._upscaling

        def name(self):
            return self._name

        def __repr__(self):
            return self.name()

        def call(self, input, prev_warped_out):
            """
            Evaluate the importance network.

            The input (optionally extended by the flattened warped previous output)
            is zero-padded by IMPORTANCE_BORDER on every side; the result is cropped
            by IMPORTANCE_BORDER * upscaling again via negative padding.
            """
            net_input = input
            if self._requiresPrevious:
                previous = models.VideoTools.flatten_high(prev_warped_out,
                                                          self._upscaling)
                net_input = torch.cat([net_input, previous], dim=1)
            border = IMPORTANCE_BORDER
            padded = F.pad(net_input, [border, border, border, border],
                           'constant', 0)
            heat = self._net(padded)  # the network call
            crop = -border * self._upscaling
            return F.pad(heat, [crop, crop, crop, crop], 'constant', 0)

    class LuminanceImportanceModel:
        """
        Importance "model" that reads precomputed luminance-contrast importance maps
        from '<dataset>-luminanceImportance.hdf5' instead of running a network.

        Expected call order: setTestFile(dataset file) -> setIndices(sample indices)
        -> setTime(frame index) -> call(...). If no importance file exists for the
        current dataset, call() returns a constant map of ones.
        """
        def __init__(self):
            self.disableTemporal = True
            # no importance file is open until setTestFile() is called
            self._exist = False
            self._file = None
            self._dset = None
            # must be set via setIndices() / setTime() before call()
            self._indices = None
            self._time = None

        def close(self):
            """Close the currently opened importance file, if any."""
            if self._file is not None:
                self._file.close()
            self._file = None
            self._dset = None
            self._exist = False

        def setTestFile(self, filename):
            """
            Select the dataset. Its importance file is found by replacing the last
            5 characters of `filename` (the '.hdf5' extension) with
            '-luminanceImportance.hdf5'. A previously opened importance file is
            closed first so that switching datasets does not leak file handles.
            """
            importance_file = filename[:-5] + "-luminanceImportance.hdf5"
            self.close()
            if os.path.exists(importance_file):
                self._file = h5py.File(importance_file, 'r')
                self._dset = self._file['importance']
                self._exist = True

        def isAvailable(self):
            return self._exist

        def setIndices(self, indices: torch.Tensor):
            assert len(indices.shape) == 1
            self._indices = list(indices.cpu().numpy())

        def setTime(self, time):
            self._time = time

        def networkUpscaling(self):
            return UPSCALING

        def name(self):
            return "luminance-contrast"

        def __repr__(self):
            return self.name()

        def call(self, input, prev_warped_out):
            """
            Return the stored importance maps of the selected samples at the selected
            time, or ones of shape (B, 1, H, W) if no importance file is available.
            prev_warped_out is ignored.
            """
            B, C, H, W = input.shape
            if not self._exist:
                return torch.ones(B,
                                  1,
                                  H,
                                  W,
                                  dtype=input.dtype,
                                  device=input.device)
            # NOTE(review): this returns stack(dset[idx, time, ...]); it matches the
            # (B, 1, H, W) fallback only if each entry is (1, H, W) — confirm.
            outputs = [
                torch.from_numpy(self._dset[idx, self._time,
                                            ...]).to(device=input.device)
                for idx in self._indices
            ]
            return torch.stack(outputs, dim=0)

    importanceBaseline1 = ImportanceModel(IMPORTANCE_BASELINE1)
    importanceBaseline2 = ImportanceModel(IMPORTANCE_BASELINE2)
    importanceBaselineLuminance = LuminanceImportanceModel()
    importanceModels = [ImportanceModel(f) for f in NETWORKS]

    # load reconstruction networks
    print("load reconstruction networks")

    class ReconstructionModel:
        """
        Reconstruction stage that turns sparsely sampled inputs into dense images.

        `file` is either RECON_BASELINE (fast inpainting only, named "inpainting")
        or a (title, file prefix) pair naming the TorchScript module
        '<prefix>_recon.pt' in NETWORK_DIR, configured via its embedded
        'settings.json'. If that settings entry has 'interpolateInput' set, the
        network is wrapped so that its input is fast-inpainted first.
        """
        def __init__(self, file):
            if file == RECON_BASELINE:

                class Inpainting(nn.Module):
                    """Baseline: fast inpainting of channels 0:6 using channel 6 of x as mask."""
                    def forward(self, x, mask):
                        # NOTE: the `mask` argument is ignored; channel 6 of x is used
                        input = x[:, 0:6, :, :].contiguous(
                        )  # mask, normal xyz, depth, ao
                        mask = x[:, 6, :, :].contiguous()
                        return torch.ops.renderer.fast_inpaint(mask, input)

                self._net = Inpainting()
                self._upscaling = 1
                self._name = "inpainting"
                self.disableTemporal = True
            else:
                self._name = file[0]
                file = os.path.join(NETWORK_DIR, file[1] + "_recon.pt")
                extra_files = torch._C.ExtraFilesMap()
                extra_files['settings.json'] = ""
                self._net = torch.jit.load(file,
                                           map_location=device,
                                           _extra_files=extra_files)
                self._settings = json.loads(extra_files['settings.json'])
                self.disableTemporal = False
                # NOTE(review): _upscaling is not set on this branch
                requiresMask = self._settings.get('expectMask', False)
                # NOTE(review): without 'interpolateInput', call() invokes the jit
                # network as self._net(input, mask) regardless of 'expectMask' —
                # confirm all such networks accept two arguments
                if self._settings.get("interpolateInput", False):
                    self._originalNet = self._net

                    class Inpainting2(nn.Module):
                        """Fast-inpaint channels 0:6 of x in place, then run the wrapped network."""
                        def __init__(self, orignalNet, requiresMask):
                            super().__init__()
                            self._n = orignalNet
                            self._requiresMask = requiresMask

                        def forward(self, x, mask):
                            # NOTE: `mask` is overwritten with channel 6 of x; that
                            # channel (not the argument) is passed to the network
                            input = x[:, 0:6, :, :].contiguous(
                            )  # mask, normal xyz, depth, ao
                            mask = x[:, 6, :, :].contiguous()
                            inpainted = torch.ops.renderer.fast_inpaint(
                                mask, input)
                            # modifies x in place
                            x[:, 0:6, :, :] = inpainted
                            if self._requiresMask:
                                return self._n(x, mask)
                            else:
                                return self._n(x)

                    self._net = Inpainting2(self._originalNet, requiresMask)

        def name(self):
            return self._name

        def __repr__(self):
            return self.name()

        def call(self, input, mask, prev_warped_out):
            # append the warped previous output along the channel axis, then reconstruct
            input = torch.cat([input, prev_warped_out], dim=1)
            output = self._net(input, mask)
            return output

    class ReconstructionModelPostTrain:
        """
        Reconstruction model that is trained as a dense reconstruction network
        after the adaptive training.

        It does not receive the sampling mask as network input; instead the
        sparse samples are first densified by inpainting ('fast' or PDE-based)
        and the inpainted image is fed to the network.
        """
        def __init__(self, name: str, model_path: str, inpainting: str):
            """
            :param name: display name, returned by name()
            :param model_path: TorchScript file, relative to POSTTRAIN_NETWORK_DIR
            :param inpainting: 'fast' or 'pde', the inpainting run before the network
            """
            assert inpainting == 'fast' or inpainting == 'pde', "inpainting must be either 'fast' or 'pde', but got %s" % inpainting
            self._inpainting = inpainting

            self._name = name
            file = os.path.join(POSTTRAIN_NETWORK_DIR, model_path)
            # the training settings are stored as an extra file in the archive
            extra_files = torch._C.ExtraFilesMap()
            extra_files['settings.json'] = ""
            self._net = torch.jit.load(file,
                                       map_location=device,
                                       _extra_files=extra_files)
            self._settings = json.loads(extra_files['settings.json'])
            assert self._settings.get(
                'upscale_factor', None) == 1, "selected file is not a 1x SRNet"
            self.disableTemporal = False

        def name(self):
            return self._name

        def __repr__(self):
            return self.name()

        def call(self, input, mask=None, prev_warped_out=None):
            """
            Reconstructs one frame.

            Accepts the same arguments as ReconstructionModel.call,
            ``call(input, mask, prev_warped_out)``, so both model kinds can be
            driven by the same evaluation loop. The legacy two-argument form
            ``call(input, prev_warped_out)`` is still supported.

            :param input: sampled frame; channels 0-4 are mask, normal xyz and
                depth, channel 6 is the sampling mask
            :param mask: unused; the sampling mask is read from channel 6 of `input`
            :param prev_warped_out: warped output of the previous frame, appended
                to the inpainted input along the channel axis
            :return: the network output (first element if the net returns a tuple)
            """
            if prev_warped_out is None:
                # legacy positional form: call(input, prev_warped_out)
                prev_warped_out = mask
            if prev_warped_out is None:
                raise TypeError(
                    "call() missing required argument 'prev_warped_out'")

            # no sampling and no AO
            input_no_sampling = input[:, 0:5, :, :].contiguous(
            )  # mask, normal xyz, depth
            sampling_mask = input[:, 6, :, :].contiguous()
            # perform inpainting
            if self._inpainting == 'pde':
                inpainted = torch.ops.renderer.pde_inpaint(
                    sampling_mask,
                    input_no_sampling,
                    200,
                    1e-4,
                    5,
                    2,  # m0, epsilon, m1, m2
                    0,  # mc -> multigrid recursion count. =0 disables the multigrid hierarchy
                    9,
                    0)  # ms, m3
            else:
                inpainted = torch.ops.renderer.fast_inpaint(
                    sampling_mask, input_no_sampling)
            # run network
            net_input = torch.cat([inpainted, prev_warped_out], dim=1)
            output = self._net(net_input)
            if isinstance(output, tuple):
                output = output[0]
            return output

    # Instantiate the reconstruction models: the baseline, one ReconstructionModel
    # per entry of NETWORKS, and one post-trained dense model per
    # (name, file, inpainting) tuple of POSTTRAIN_NETWORKS.
    reconBaseline = ReconstructionModel(RECON_BASELINE)
    reconModels = [ReconstructionModel(f) for f in NETWORKS]
    reconPostModels = [
        ReconstructionModelPostTrain(name, file, inpainting)
        for (name, file, inpainting) in POSTTRAIN_NETWORKS
    ]
    allReconModels = reconModels + reconPostModels

    # All (importance, reconstruction) pairs to evaluate, in this order:
    #  - baseline importance maps (importanceBaseline1/2, luminance) with the
    #    baseline reconstruction,
    #  - each baseline importance map with every learned reconstruction model,
    #  - every learned importance net with the baseline reconstruction,
    #  - learned importance nets paired index-wise (zip) with reconModels,
    #  - every learned importance net with every post-trained model.
    NETWORK_COMBINATIONS = \
        [(importanceBaseline1, reconBaseline), (importanceBaseline2, reconBaseline)] + \
        [(importanceBaselineLuminance, reconBaseline)] + \
        [(importanceBaseline1, reconNet) for reconNet in allReconModels] + \
        [(importanceBaseline2, reconNet) for reconNet in allReconModels] + \
        [(importanceBaselineLuminance, reconNet) for reconNet in allReconModels] + \
        [(importanceNet, reconBaseline) for importanceNet in importanceModels] + \
        list(zip(importanceModels, reconModels)) + \
        [(importanceNet, reconPostModel) for importanceNet in importanceModels for reconPostModel in reconPostModels]
    #NETWORK_COMBINATIONS = list(zip(importanceModels, reconModels))
    print("Network combinations:")
    for (i, r) in NETWORK_COMBINATIONS:
        print("  %s - %s" % (i.name(), r.name()))

    # load sampling patterns: one tensor per pattern name, moved to `device`
    print("load sampling patterns")
    with h5py.File(SAMPLING_FILE, 'r') as f:
        sampling_pattern = {
            name: torch.from_numpy(f[name][...]).to(device)
            for name in SAMPLING_PATTERNS
        }

    # create shading: screen-space shading that turns mask/normal/depth/ao
    # into colors for the color metrics
    shading = ScreenSpaceShading(device)
    shading.fov(30)
    shading.ambient_light_color(np.array([0.1, 0.1, 0.1]))
    shading.diffuse_light_color(np.array([1.0, 1.0, 1.0]))
    shading.specular_light_color(np.array([0.0, 0.0, 0.0]))
    shading.specular_exponent(16)
    shading.light_direction(np.array([0.1, 0.1, 1.0]))
    shading.material_color(np.array([1.0, 0.3, 0.0]))
    # AO strength for the "withAO" color metrics; the statistics toggle the
    # shading between this value and 0.0, so use the constant here as well
    AMBIENT_OCCLUSION_STRENGTH = 1.0
    shading.ambient_occlusion(AMBIENT_OCCLUSION_STRENGTH)
    shading.inverse_ao = False

    # constants used by the statistics below
    MIN_FILLING = 0.05
    NUM_BINS = 200  # number of bins of every error histogram

    # heatmap configurations: every (min, mean) pair with min < mean
    HEATMAP_CFG = [(heat_min, heat_mean)
                   for heat_min in HEATMAP_MIN
                   for heat_mean in HEATMAP_MEAN
                   if heat_min < heat_mean]
    print("heatmap configs:", HEATMAP_CFG)

    #########################
    # DEFINE STATISTICS
    #########################
    # SSIM / PSNR metric modules are moved to `device`;
    # LPIPS (AlexNet, linear heads) is created with use_gpu=True
    ssimLoss = SSIM(size_average=False)
    ssimLoss.to(device)
    psnrLoss = PSNR()
    psnrLoss.to(device)
    lpipsColor = lpips.PerceptualLoss(model='net-lin', net='alex', use_gpu=True)

    class Statistics:
        """
        Collects the evaluation results of one combination of importance
        network, reconstruction network, heatmap configuration and sampling
        pattern.

        Per-sample metrics (PSNR, SSIM, LPIPS, fraction of sampled pixels) are
        written row by row into an HDF5 "stats" dataset. Error histograms are
        kept as running averages over the calls to add_timestep_sample and
        written once by write_histogram.
        """
        def __init__(self):
            # running averages of the per-bin error probabilities
            self.histogram_color_withAO = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_color_noAO = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_depth = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_normal = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_mask = np.zeros(NUM_BINS, dtype=np.float64)
            self.histogram_ao = np.zeros(NUM_BINS, dtype=np.float64)
            # number of add_timestep_sample calls averaged so far
            self.histogram_counter = 0

        def create_datasets(self, hdf5_file: h5py.File, stats_name: str,
                            histo_name: str, num_samples: int,
                            extra_info: dict):
            """
            Creates (or reopens, via require_dataset) the stats and histogram
            datasets in `hdf5_file`.

            :param hdf5_file: open, writable output file
            :param stats_name: name of the stats dataset, one row per sample
                and one column per StatField
            :param histo_name: name of the histogram dataset, one row per bin
                and one column per HistoField
            :param num_samples: number of rows add_timestep_sample may write
            :param extra_info: stored as attributes on both datasets
            """
            self.expected_num_samples = num_samples
            stat_fields = list(StatField)
            stats_shape = (num_samples, len(stat_fields))
            self.stats_file = hdf5_file.require_dataset(stats_name,
                                                        stats_shape,
                                                        dtype='f',
                                                        exact=True)
            self.stats_file.attrs['NumFields'] = len(stat_fields)
            for field in stat_fields:
                self.stats_file.attrs['Field%d' % field.value] = field.name
            for key, value in extra_info.items():
                self.stats_file.attrs[key] = value
            self.stats_index = 0

            histo_fields = list(HistoField)
            histo_shape = (NUM_BINS, len(histo_fields))
            self.histo_file = hdf5_file.require_dataset(histo_name,
                                                        histo_shape,
                                                        dtype='f',
                                                        exact=True)
            self.histo_file.attrs['NumFields'] = len(histo_fields)
            for field in histo_fields:
                self.histo_file.attrs['Field%d' % field.value] = field.name
            for key, value in extra_info.items():
                self.histo_file.attrs[key] = value

        @staticmethod
        def _crop_border(tensor):
            """Remove LOSS_BORDER pixels from each side of the last two axes.

            A border of 0 keeps the whole image (a plain ``0:-0`` slice would
            be empty).
            """
            end = -LOSS_BORDER if LOSS_BORDER > 0 else None
            return tensor[:, :, LOSS_BORDER:end, LOSS_BORDER:end]

        def _accumulate_histogram(self, running, diff):
            """Fold the histogram of `diff` (binned on [0, 1]) into the running
            average `running`, in place."""
            histogram, _ = np.histogram(diff.cpu().numpy(),
                                        bins=NUM_BINS,
                                        range=(0, 1),
                                        density=True)
            # density=True already normalizes over every value of the batch:
            # sum(histogram) * (1 / NUM_BINS) == 1, so dividing by NUM_BINS
            # yields per-bin probabilities that sum to one, independent of the
            # batch size.
            running += (histogram / NUM_BINS - running) / self.histogram_counter

        def add_timestep_sample(self, pred_mnda, gt_mnda, sampling_mask):
            """
            Adds a batch of timestep samples.

            pred_mnda: prediction: mask, normal xyz, depth, AO (channels 0..5)
            gt_mnda: ground truth: mask, normal xyz, depth, AO (channels 0..5)
            sampling_mask: sampling mask per sample; its mean over axes
                (1, 2, 3) is stored as the sample density

            Writes one stats row per sample and updates the running histograms.
            Leaves the shading's ambient occlusion set to 0.
            """
            B = pred_mnda.shape[0]

            # shading, with and without ambient occlusion
            shading.ambient_occlusion(AMBIENT_OCCLUSION_STRENGTH)
            pred_color_withAO = shading(pred_mnda)
            gt_color_withAO = shading(gt_mnda)
            shading.ambient_occlusion(0.0)
            pred_color_noAO = shading(pred_mnda)
            gt_color_noAO = shading(gt_mnda)

            # apply border
            crop = self._crop_border
            pred_mnda = crop(pred_mnda)
            pred_color_withAO = crop(pred_color_withAO)
            pred_color_noAO = crop(pred_color_noAO)
            gt_mnda = crop(gt_mnda)
            gt_color_withAO = crop(gt_color_withAO)
            gt_color_noAO = crop(gt_color_noAO)

            # the mask channel is in [-1, 1]; map it to a [0, 1] weight
            mask = gt_mnda[:, 0:1, :, :] * 0.5 + 0.5

            # PSNR
            psnr_mask = psnrLoss(pred_mnda[:, 0:1, :, :],
                                 gt_mnda[:, 0:1, :, :]).cpu().numpy()
            psnr_normal = psnrLoss(pred_mnda[:, 1:4, :, :],
                                   gt_mnda[:, 1:4, :, :],
                                   mask=mask).cpu().numpy()
            psnr_depth = psnrLoss(pred_mnda[:, 4:5, :, :],
                                  gt_mnda[:, 4:5, :, :],
                                  mask=mask).cpu().numpy()
            psnr_ao = psnrLoss(pred_mnda[:, 5:6, :, :],
                               gt_mnda[:, 5:6, :, :],
                               mask=mask).cpu().numpy()
            psnr_color_withAO = psnrLoss(pred_color_withAO,
                                         gt_color_withAO,
                                         mask=mask).cpu().numpy()
            psnr_color_noAO = psnrLoss(pred_color_noAO,
                                       gt_color_noAO,
                                       mask=mask).cpu().numpy()

            # SSIM
            ssim_mask = ssimLoss(pred_mnda[:, 0:1, :, :],
                                 gt_mnda[:, 0:1, :, :]).cpu().numpy()
            # outside the ground-truth mask, replace the prediction by the
            # ground truth so those pixels do not count as errors
            pred_mnda = gt_mnda + mask * (pred_mnda - gt_mnda)
            ssim_normal = ssimLoss(pred_mnda[:, 1:4, :, :],
                                   gt_mnda[:, 1:4, :, :]).cpu().numpy()
            ssim_depth = ssimLoss(pred_mnda[:, 4:5, :, :],
                                  gt_mnda[:, 4:5, :, :]).cpu().numpy()
            ssim_ao = ssimLoss(pred_mnda[:, 5:6, :, :],
                               gt_mnda[:, 5:6, :, :]).cpu().numpy()
            ssim_color_withAO = ssimLoss(pred_color_withAO,
                                         gt_color_withAO).cpu().numpy()
            ssim_color_noAO = ssimLoss(pred_color_noAO,
                                       gt_color_noAO).cpu().numpy()

            # Perceptual
            lpips_color_withAO = torch.cat([
                lpipsColor(
                    pred_color_withAO[b], gt_color_withAO[b], normalize=True)
                for b in range(B)
            ], dim=0).cpu().numpy()
            lpips_color_noAO = torch.cat([
                lpipsColor(
                    pred_color_noAO[b], gt_color_noAO[b], normalize=True)
                for b in range(B)
            ], dim=0).cpu().numpy()

            # Samples
            samples = torch.mean(sampling_mask, dim=(1, 2, 3)).cpu().numpy()

            # Write samples to file
            for b in range(B):
                assert self.stats_index < self.expected_num_samples, "Adding more samples than specified"
                self.stats_file[self.stats_index, :] = np.array([
                    psnr_mask[b], psnr_normal[b], psnr_depth[b], psnr_ao[b],
                    psnr_color_noAO[b], psnr_color_withAO[b], ssim_mask[b],
                    ssim_normal[b], ssim_depth[b], ssim_ao[b],
                    ssim_color_noAO[b], ssim_color_withAO[b],
                    lpips_color_noAO[b], lpips_color_withAO[b], samples[b]
                ], dtype='f')
                self.stats_index += 1

            # Histograms (running average over calls)
            self.histogram_counter += 1

            self._accumulate_histogram(
                self.histogram_mask,
                F.l1_loss(gt_mnda[:, 0, :, :], pred_mnda[:, 0, :, :],
                          reduction='none'))
            # per-pixel L1 summed over the three normal channels (dim=1, not
            # the batch axis); for unit normals each channel differs by at
            # most 2, so dividing by 6 maps the error into [0, 1]
            self._accumulate_histogram(
                self.histogram_normal,
                F.l1_loss(gt_mnda[:, 1:4, :, :], pred_mnda[:, 1:4, :, :],
                          reduction='none').sum(dim=1) / 6)
            self._accumulate_histogram(
                self.histogram_depth,
                F.l1_loss(gt_mnda[:, 4, :, :], pred_mnda[:, 4, :, :],
                          reduction='none'))
            self._accumulate_histogram(
                self.histogram_ao,
                F.l1_loss(gt_mnda[:, 5, :, :], pred_mnda[:, 5, :, :],
                          reduction='none'))
            # color histograms use the first color channel only
            self._accumulate_histogram(
                self.histogram_color_withAO,
                F.l1_loss(gt_color_withAO[:, 0, :, :],
                          pred_color_withAO[:, 0, :, :],
                          reduction='none'))
            self._accumulate_histogram(
                self.histogram_color_noAO,
                F.l1_loss(gt_color_noAO[:, 0, :, :],
                          pred_color_noAO[:, 0, :, :],
                          reduction='none'))

        def close_stats_file(self):
            """Record how many stats rows were actually written."""
            self.stats_file.attrs['NumEntries'] = self.stats_index

        def write_histogram(self):
            """
            After every sample for the current dataset was processed, write
            the histogram of the errors into the histogram dataset: one row
            per bin with (bin start, bin end, mask, normal, depth, ao,
            color withAO, color noAO).
            """
            bins = np.arange(NUM_BINS, dtype=np.float64)
            table = np.stack([
                bins / NUM_BINS, (bins + 1) / NUM_BINS, self.histogram_mask,
                self.histogram_normal, self.histogram_depth,
                self.histogram_ao, self.histogram_color_withAO,
                self.histogram_color_noAO
            ], axis=1)
            # single write instead of one HDF5 write per bin
            self.histo_file[...] = table

    #########################
    # DATASET
    #########################
    class FullResDataset(torch.utils.data.Dataset):
        """
        Map-style dataset over the 'gt' dataset of an HDF5 file.

        Axis 0 indexes the samples and axis 1 the timesteps; each item is the
        full entry of one sample together with its index.
        """
        def __init__(self, file):
            self.hdf5_file = h5py.File(file, 'r')
            gt = self.hdf5_file['gt']
            self.dset = gt
            print("Dataset shape:", gt.shape)

        def __len__(self):
            shape = self.dset.shape
            return shape[0]

        def num_timesteps(self):
            """Number of timesteps, i.e. the size of axis 1."""
            shape = self.dset.shape
            return shape[1]

        def __getitem__(self, idx):
            """Returns (data of sample `idx`, `idx` as a 0-d numpy array)."""
            entry = self.dset[idx, ...]
            index_array = np.array(idx)
            return entry, index_array

    #########################
    # COMPUTE STATS for each dataset
    #########################
    # For every dataset, every configured combination is evaluated on every
    # batch: importance map -> normalization -> sampling -> reconstruction
    # (frame by frame, with temporal warping) -> statistics.
    for dataset_name, dataset_file in DATASETS:
        dataset_file = os.path.join(DATASET_PREFIX, dataset_file)
        print("Compute statistics for", dataset_name)

        # init luminance importance map
        importanceBaselineLuminance.setTestFile(dataset_file)
        if importanceBaselineLuminance.isAvailable():
            print("Luminance-contrast importance map is available")

        # create output file (append mode: an existing file is extended, not truncated)
        os.makedirs(OUTPUT_FOLDER, exist_ok=True)
        output_file = os.path.join(OUTPUT_FOLDER, dataset_name + '.hdf5')
        print("Save to", output_file)
        with h5py.File(output_file, 'a') as output_hdf5_file:

            # load dataset (named `dataset` so the builtin `set` is not shadowed)
            dataset = FullResDataset(dataset_file)
            data_loader = torch.utils.data.DataLoader(dataset,
                                                      batch_size=BATCH_SIZE,
                                                      shuffle=False)

            # define statistics: one Statistics object (with its own stats and
            # histogram datasets) per combination of networks, heatmap
            # configuration and sampling pattern
            StatsCfg = collections.namedtuple(
                "StatsCfg", "stats importance recon heatmin heatmean pattern")
            statistics = []
            for (inet, rnet) in NETWORK_COMBINATIONS:
                for (heatmin, heatmean) in HEATMAP_CFG:
                    for pattern in SAMPLING_PATTERNS:
                        stats_info = {
                            'importance': inet.name(),
                            'reconstruction': rnet.name(),
                            'heatmin': heatmin,
                            'heatmean': heatmean,
                            'pattern': pattern
                        }
                        stats_filename = "Stats_%s_%s_%03d_%03d_%s" % (
                            inet.name(), rnet.name(), heatmin * 100,
                            heatmean * 100, pattern)
                        histo_filename = "Histogram_%s_%s_%03d_%03d_%s" % (
                            inet.name(), rnet.name(), heatmin * 100,
                            heatmean * 100, pattern)
                        s = Statistics()
                        # every sample contributes one stats row per timestep
                        s.create_datasets(output_hdf5_file, stats_filename,
                                          histo_filename,
                                          len(dataset) * dataset.num_timesteps(),
                                          stats_info)
                        statistics.append(
                            StatsCfg(stats=s,
                                     importance=inet,
                                     recon=rnet,
                                     heatmin=heatmin,
                                     heatmean=heatmean,
                                     pattern=pattern))
            print(len(statistics),
                  " different combinations are performed per sample")

            # compute statistics
            try:
                with torch.no_grad():
                    num_minibatch = len(data_loader)
                    pg = ProgressBar(num_minibatch, 'Evaluation', length=50)
                    for iteration, (batch, batch_indices) in enumerate(
                            data_loader, 0):
                        pg.print_progress_bar(iteration)
                        batch = batch.to(device)
                        importanceBaselineLuminance.setIndices(batch_indices)
                        B, T, C, H, W = batch.shape

                        # try out each combination
                        for s in statistics:
                            # get input to evaluation
                            importanceNetUpscale = s.importance.networkUpscaling(
                            )
                            # upscaling still applied after the importance network
                            importancePostUpscale = UPSCALING // importanceNetUpscale
                            # low-resolution input: area-downsampled full-res batch
                            crop_low = torch.nn.functional.interpolate(
                                batch.reshape(B * T, C, H, W),
                                scale_factor=1 / UPSCALING,
                                mode='area').reshape(B, T, C, H // UPSCALING,
                                                     W // UPSCALING)
                            pattern = sampling_pattern[s.pattern][:H, :W]
                            crop_high = batch

                            # loop over timesteps
                            pattern = pattern.unsqueeze(0).unsqueeze(0)
                            previous_importance = None
                            previous_output = None
                            reconstructions = []
                            # sampling mask of every timestep, same order as
                            # `reconstructions`
                            sample_mask_per_frame = []
                            for j in range(T):
                                importanceBaselineLuminance.setTime(j)
                                # extract flow (always the last two channels of crop_high)
                                flow = crop_high[:, j, C - 2:, :, :]

                                # compute importance map
                                importance_input = crop_low[:, j, :5, :, :]
                                if j == 0 or s.importance.disableTemporal:
                                    previous_input = torch.zeros(
                                        B,
                                        1,
                                        importance_input.shape[2] *
                                        importanceNetUpscale,
                                        importance_input.shape[3] *
                                        importanceNetUpscale,
                                        dtype=crop_high.dtype,
                                        device=crop_high.device)
                                else:
                                    flow_low = F.interpolate(
                                        flow,
                                        scale_factor=1 / importancePostUpscale)
                                    previous_input = models.VideoTools.warp_upscale(
                                        previous_importance, flow_low, 1,
                                        False)
                                importance_map = s.importance.call(
                                    importance_input, previous_input)
                                if len(importance_map.shape) == 3:
                                    importance_map = importance_map.unsqueeze(
                                        1)
                                previous_importance = importance_map

                                target_mean = s.heatmean
                                if USE_BINARY_SEARCH_ON_MEAN:
                                    # For regular sampling, the normalization does not work properly,
                                    # use binary search on the heatmean instead
                                    def f(x):
                                        postprocess = importance.PostProcess(
                                            s.heatmin, x,
                                            importancePostUpscale,
                                            LOSS_BORDER //
                                            importancePostUpscale, 'basic')
                                        importance_map2 = postprocess(
                                            importance_map)[0].unsqueeze(1)
                                        sampling_mask = (
                                            importance_map2 >= pattern).to(
                                                dtype=importance_map.dtype)
                                        samples = torch.mean(
                                            sampling_mask).item()
                                        return samples

                                    target_mean = binarySearch(
                                        f, s.heatmean, s.heatmean, 10, 0, 1)

                                # normalize and upscale importance map
                                postprocess = importance.PostProcess(
                                    s.heatmin, target_mean,
                                    importancePostUpscale,
                                    LOSS_BORDER // importancePostUpscale,
                                    'basic')
                                importance_map = postprocess(
                                    importance_map)[0].unsqueeze(1)

                                # create samples
                                sample_mask = (importance_map >= pattern).to(
                                    dtype=importance_map.dtype)
                                sample_mask_per_frame.append(sample_mask)

                                reconstruction_input = torch.cat(
                                    (
                                        # mask, normal x, normal y, normal z, depth
                                        sample_mask * crop_high[:, j, 0:5, :, :],
                                        # ao
                                        sample_mask * torch.ones(
                                            B,
                                            1,
                                            H,
                                            W,
                                            dtype=crop_high.dtype,
                                            device=crop_high.device),
                                        # sample mask
                                        sample_mask),
                                    dim=1)

                                # warp previous output
                                if j == 0 or s.recon.disableTemporal:
                                    previous_input = torch.zeros(
                                        B,
                                        6,
                                        H,
                                        W,
                                        dtype=crop_high.dtype,
                                        device=crop_high.device)
                                else:
                                    previous_input = models.VideoTools.warp_upscale(
                                        previous_output, flow, 1, False)

                                # run reconstruction network
                                reconstruction = s.recon.call(
                                    reconstruction_input, sample_mask,
                                    previous_input)

                                # clamp
                                reconstruction_clamped = torch.cat(
                                    [
                                        torch.clamp(
                                            reconstruction[:, 0:1, :, :], -1,
                                            +1),  # mask
                                        ScreenSpaceShading.normalize(
                                            reconstruction[:, 1:4, :, :],
                                            dim=1),
                                        torch.clamp(
                                            reconstruction[:, 4:5, :, :], 0,
                                            +1),  # depth
                                        torch.clamp(
                                            reconstruction[:, 5:6, :, :], 0,
                                            +1)  # ao
                                    ],
                                    dim=1)
                                reconstructions.append(reconstruction_clamped)

                                # save for next frame
                                previous_output = reconstruction_clamped

                            #endfor: timesteps

                            # compute statistics; the timesteps are stacked
                            # along the batch axis in the order t = 0..T-1
                            reconstructions = torch.cat(reconstructions, dim=0)
                            crops_high = torch.cat(
                                [crop_high[:, j, :6, :, :] for j in range(T)],
                                dim=0)
                            # pair every frame with its own sampling mask so the
                            # stored sample density matches that frame
                            sample_masks = torch.cat(sample_mask_per_frame,
                                                     dim=0)
                            s.stats.add_timestep_sample(
                                reconstructions, crops_high, sample_masks)

                        # endfor: statistic
                    # endfor: batch

                    pg.print_progress_bar(num_minibatch)
                # end no_grad()
            finally:
                # close files
                for s in statistics:
                    s.stats.write_histogram()
                    s.stats.close_stats_file()
コード例 #12
0
    def test(epoch):
        """
        Runs `model` over the whole test set without gradients and logs the
        averaged losses, the PSNR and importance-heatmap statistics via
        `writer` under the 'test/' prefix.

        For every frame the generator predicts an importance heatmap. The
        heatmap is post-processed and used to adaptively smooth the ground
        truth, and the criterion scores that prediction.

        :param epoch: step value passed to `writer.add_scalar`
        """
        avg_losses = defaultdict(float)
        heatmap_min = 1e10
        heatmap_max = -1e10
        heatmap_avg = heatmap_count = 0
        # end index for cropping the loss border; a padding of 0 must keep the
        # whole heatmap (the slice `0:-0` would be empty and break torch.min)
        border = opt.lossBorderPadding
        border_end = -border if border > 0 else None
        with torch.no_grad():
            num_minibatch = len(testing_data_loader)
            pg = ProgressBar(num_minibatch, 'Testing', length=50)
            model.eval()
            if criterion.has_discriminator:
                criterion.discr_eval()
            for iteration, batch in enumerate(testing_data_loader, 0):
                pg.print_progress_bar(iteration)
                input, flow, target = batch[0].to(device), batch[1].to(
                    device), batch[2].to(device)
                B, _, Cout, Hhigh, Whigh = target.shape
                _, _, Cin, H, W = input.shape

                previous_output = None
                for j in range(dataset_data.num_frames):
                    # prepare input
                    if j == 0 or opt.disableTemporal:
                        previous_warped = initialImage(input[:, 0, :, :, :],
                                                       Cout, opt.initialImage,
                                                       False, upscale_factor)
                        # loss takes the ground truth current image as warped previous image,
                        # to not introduce a bias and big loss for the first image
                        previous_warped_loss = target[:, 0, :, :, :]
                        previous_input = F.interpolate(input[:, 0, :, :, :],
                                                       size=(Hhigh, Whigh),
                                                       mode='bilinear')
                    else:
                        previous_warped = models.VideoTools.warp_upscale(
                            previous_output,
                            flow[:, j - 1, :, :, :],
                            upscale_factor,
                            special_mask=True)
                        previous_warped_loss = previous_warped
                        previous_input = F.interpolate(input[:,
                                                             j - 1, :, :, :],
                                                       size=(Hhigh, Whigh),
                                                       mode='bilinear')
                        previous_input = models.VideoTools.warp_upscale(
                            previous_input,
                            flow[:, j - 1, :, :, :],
                            upscale_factor,
                            special_mask=True)
                    # TODO: enable temporal component again
                    #previous_warped_flattened = models.VideoTools.flatten_high(previous_warped, opt.upscale_factor)
                    #single_input = torch.cat((
                    #        input[:,j,:,:,:],
                    #        previous_warped_flattened),
                    #    dim=1)
                    single_input = input[:, j, :, :, :]
                    # run generator
                    heatMap = model(single_input)
                    heatMapCrop = heatMap[:, border:border_end,
                                          border:border_end]
                    heatmap_min = min(heatmap_min,
                                      torch.min(heatMapCrop).item())
                    heatmap_max = max(heatmap_max,
                                      torch.max(heatMapCrop).item())
                    heatmap_avg += torch.mean(heatMapCrop).item()
                    heatmap_count += 1
                    heatMap = postprocess(heatMap)
                    prediction = importance.adaptiveSmoothing(
                        target[:, j, :, :, :].contiguous(),
                        1 / heatMap.unsqueeze(1),
                        opt.distanceToStandardDeviation)
                    # evaluate cost
                    input_high = F.interpolate(input[:, j, :, :, :],
                                               size=(Hhigh, Whigh),
                                               mode='bilinear')
                    loss0, loss_values = criterion(target[:, j, :, :, :],
                                                   prediction, input_high,
                                                   previous_input,
                                                   previous_warped_loss)
                    avg_losses['total_loss'] += loss0.item()
                    psnr = 10 * log10(
                        1 / max(1e-10, loss_values[('mse', 'color')]))
                    avg_losses['psnr'] += psnr
                    for key, value in loss_values.items():
                        avg_losses[str(key)] += value

                    # save output for next frame
                    previous_output = torch.cat(
                        [
                            torch.clamp(prediction[:, 0:1, :, :], -1,
                                        +1),  # mask
                            ScreenSpaceShading.normalize(
                                prediction[:, 1:4, :, :], dim=1),
                            torch.clamp(prediction[:, 4:5, :, :], 0,
                                        +1),  # depth
                            torch.clamp(prediction[:, 5:6, :, :], 0, +1)  # ao
                        ],
                        dim=1)
            pg.print_progress_bar(num_minibatch)
        num_samples = num_minibatch * dataset_data.num_frames
        for key in avg_losses.keys():
            avg_losses[key] /= num_samples
        print("===> Avg. PSNR: {:.4f} dB".format(avg_losses['psnr']))
        print("  losses:", avg_losses)
        for key, value in avg_losses.items():
            writer.add_scalar('test/%s' % key, value, epoch)
        # an empty test set yields no heatmaps; avoid dividing by zero
        heatmap_mean = (heatmap_avg / heatmap_count
                        if heatmap_count > 0 else float('nan'))
        print("  heatmap: min=%f, max=%f, avg=%f" %
              (heatmap_min, heatmap_max, heatmap_mean))
        writer.add_scalar('test/heatmap_min', heatmap_min, epoch)
        writer.add_scalar('test/heatmap_max', heatmap_max, epoch)
        writer.add_scalar('test/heatmap_avg', heatmap_mean, epoch)
コード例 #13
0
def test_images(epoch):
    """Run the model on the full-size test sequences and log images to TensorBoard.

    For every batch of ``testing_full_data_loader`` the model is unrolled over
    ``dataset_data.num_frames`` frames; the previous output, warped by the
    flow, is fed back as extra input unless ``opt.disableTemporal`` is set.
    For the first sample of each batch and every frame, the warped previous
    frame, the prediction, the residual (if the model returns one), the shaded
    color, the mask and the ambient-occlusion channel are written via
    ``writer.add_image`` under ``image<batch>/frame<frame>_<kind>``.

    Args:
        epoch: global step passed to ``writer.add_image``.
    """
    def write_image(img, filename):
        # Scale out of place: for a CPU tensor, ``.cpu()`` and ``.numpy()``
        # share its storage, so an in-place ``*=`` would also multiply the
        # source tensor -- e.g. the AO view ``prediction[0, 5:6]``, which is
        # afterwards fed back into the next frame via ``previous_output``.
        out_img = img.cpu().detach().numpy() * 255.0
        out_img = np.uint8(out_img.clip(0, 255))
        writer.add_image(filename, out_img, epoch)

    with torch.no_grad():
        num_minibatch = len(testing_full_data_loader)
        pg = ProgressBar(num_minibatch, 'Test %d Images'%num_minibatch, length=50)
        model.eval()
        if criterion.has_discriminator:
            criterion.discr_eval()
        Cout = output_channels
        channel_mask = [1, 2, 3]  # normal channels, written as a 3-channel image
        for i, batch in enumerate(testing_full_data_loader):
            pg.print_progress_bar(i)
            low_res, flow = batch[0].to(device), batch[1].to(device)

            previous_output = None
            for j in range(dataset_data.num_frames):
                prefix = 'image%03d/frame%03d_' % (i, j)
                # prepare input: initial image for the first frame (or when
                # temporal feedback is disabled), otherwise the previous
                # output warped by the flow
                if j == 0 or opt.disableTemporal:
                    previous_warped = initialImage(low_res[:,0,:,:,:], Cout,
                                                   opt.initialImage, False, opt.upscale_factor)
                else:
                    previous_warped = models.VideoTools.warp_upscale(
                        previous_output,
                        flow[:, j-1, :, :, :],
                        opt.upscale_factor,
                        special_mask = True)
                previous_warped_flattened = models.VideoTools.flatten_high(previous_warped, opt.upscale_factor)
                single_input = torch.cat((
                        low_res[:,j,:,:,:],
                        previous_warped_flattened),
                    dim=1)
                # write warped previous frame
                write_image(previous_warped[0, channel_mask], prefix + 'warped')
                # run generator
                prediction, residual = model(single_input)
                # normalize normal in place, so the images below and the
                # fed-back output both see unit normals
                prediction[:,1:4,:,:] = ScreenSpaceShading.normalize(prediction[:,1:4,:,:], dim=1)
                write_image(prediction[0, channel_mask], prefix + 'prediction')
                if residual is not None:
                    write_image(residual[0, channel_mask], prefix + 'residual')
                # shaded color image
                shaded_image = shading(prediction)
                write_image(shaded_image[0], prefix + 'shaded')
                # mask: map [-1, 1] (the range it is clamped to below) to [0, 1]
                write_image(prediction[0, 0:1, :, :]*0.5+0.5, prefix + 'mask')
                # ambient occlusion
                write_image(prediction[0, 5:6, :, :], prefix + 'ao')
                # save output for next frame
                previous_output = torch.cat([
                    torch.clamp(prediction[:,0:1,:,:], -1, +1), # mask
                    prediction[:,1:4,:,:], # already normalized
                    torch.clamp(prediction[:,4:5,:,:], 0, +1), # depth
                    torch.clamp(prediction[:,5:6,:,:], 0, +1) # ao
                    ], dim=1)
        pg.print_progress_bar(num_minibatch)

    print("Test images sent to Tensorboard for visualization")
コード例 #14
0
def test(epoch):
    """Evaluate the model on ``testing_data_loader`` and report the mean losses.

    Each minibatch is a sequence of ``dataset_data.num_frames`` frames that is
    processed frame by frame. For the first frame (or always, when
    ``opt.disableTemporal`` is set) ``initialImage`` provides the previous
    frame. Otherwise the previous prediction, warped by the flow, is fed back
    as extra input.

    For every frame the criterion's total loss, each individual loss term and
    the PSNR derived from the ('mse', 'color') term are accumulated. The sums
    are divided by the number of evaluated frames, printed, and logged to
    TensorBoard under ``test/<key>`` at step ``epoch``.
    """
    def _feedback(pred):
        # Output handed to the next frame: mask clamped to [-1, 1], normal
        # re-normalized, depth and AO clamped to [0, 1].
        return torch.cat([
            torch.clamp(pred[:, 0:1, :, :], -1, +1),
            ScreenSpaceShading.normalize(pred[:, 1:4, :, :], dim=1),
            torch.clamp(pred[:, 4:5, :, :], 0, +1),
            torch.clamp(pred[:, 5:6, :, :], 0, +1),
        ], dim=1)

    totals = defaultdict(float)
    with torch.no_grad():
        batches = len(testing_data_loader)
        progress = ProgressBar(batches, 'Testing', length=50)
        model.eval()
        if criterion.has_discriminator:
            criterion.discr_eval()
        for batch_idx, batch in enumerate(testing_data_loader, 0):
            progress.print_progress_bar(batch_idx)
            low_res = batch[0].to(device)
            flow = batch[1].to(device)
            high_res = batch[2].to(device)
            _, _, out_channels, high_h, high_w = high_res.shape
            high_size = (high_h, high_w)
            scale = opt.upscale_factor

            prev_output = None
            for frame in range(dataset_data.num_frames):
                if frame == 0 or opt.disableTemporal:
                    prev_warped = initialImage(low_res[:, 0, :, :, :], out_channels,
                                               opt.initialImage, False, scale)
                    # The loss receives the first ground-truth frame as the
                    # "warped previous" image, so no bias / large temporal
                    # loss is introduced for the first frame.
                    warped_for_loss = high_res[:, 0, :, :, :]
                    prev_low = F.interpolate(low_res[:, 0, :, :, :],
                                             size=high_size, mode=opt.upsample)
                else:
                    step_flow = flow[:, frame - 1, :, :, :]
                    prev_warped = models.VideoTools.warp_upscale(
                        prev_output, step_flow, scale, special_mask=True)
                    warped_for_loss = prev_warped
                    upsampled_prev = F.interpolate(low_res[:, frame - 1, :, :, :],
                                                   size=high_size, mode=opt.upsample)
                    prev_low = models.VideoTools.warp_upscale(
                        upsampled_prev, step_flow, scale, special_mask=True)

                flat_prev = models.VideoTools.flatten_high(prev_warped, scale)
                current_low = low_res[:, frame, :, :, :]
                prediction, _ = model(torch.cat((current_low, flat_prev), dim=1))

                upsampled_current = F.interpolate(current_low, size=high_size,
                                                  mode=opt.upsample)
                frame_loss, term_values = criterion(high_res[:, frame, :, :, :],
                                                    prediction,
                                                    upsampled_current,
                                                    prev_low,
                                                    warped_for_loss)
                totals['total_loss'] += frame_loss.item()
                totals['psnr'] += 10 * log10(1 / max(1e-10, term_values[('mse', 'color')]))
                for term, val in term_values.items():
                    totals[str(term)] += val

                prev_output = _feedback(prediction)
        progress.print_progress_bar(batches)

    for key in totals.keys():
        totals[key] /= batches * dataset_data.num_frames
    print("===> Avg. PSNR: {:.4f} dB".format(totals['psnr']))
    print("  losses:", totals)
    for key, value in totals.items():
        writer.add_scalar('test/%s' % key, value, epoch)
コード例 #15
0
def trainNormal(epoch):
    """Train the model for one epoch on ``training_data_loader``.

    Every minibatch is a sequence of ``dataset_data.num_frames`` frames. The
    model is unrolled over the frames; the previous prediction, warped by the
    flow, is fed back as extra input unless ``opt.disableTemporal`` is set.
    The per-frame losses are summed, and one backward pass plus one optimizer
    step is done per minibatch.

    The mean per-frame loss and the learning rate used during this epoch are
    logged to TensorBoard at step ``epoch``. The learning-rate scheduler is
    advanced once, after all optimizer steps of the epoch.
    """
    epoch_loss = 0
    num_minibatch = len(training_data_loader)
    pg = ProgressBar(num_minibatch, 'Training', length=50)
    model.train()
    for iteration, batch in enumerate(training_data_loader, 0):
        pg.print_progress_bar(iteration)
        low_res, flow, target = batch[0].to(device), batch[1].to(device), batch[2].to(device)
        _, _, Cout, Hhigh, Whigh = target.shape
        _, _, Cin, H, W = low_res.shape
        assert(Cout == output_channels)
        assert(Cin == input_channels)
        assert(H == dataset_data.crop_size)
        assert(W == dataset_data.crop_size)
        assert(Hhigh == dataset_data.crop_size * opt.upscale_factor)
        assert(Whigh == dataset_data.crop_size * opt.upscale_factor)

        optimizer.zero_grad()

        previous_output = None
        loss = 0
        for j in range(dataset_data.num_frames):
            # prepare input
            if j == 0 or opt.disableTemporal:
                previous_warped = initialImage(low_res[:,0,:,:,:], Cout,
                                               opt.initialImage, False, opt.upscale_factor)
                # loss takes the first ground-truth frame as warped previous image,
                # to not introduce a bias and big loss for the first image
                previous_warped_loss = target[:,0,:,:,:]
                previous_input = F.interpolate(low_res[:,0,:,:,:], size=(Hhigh, Whigh), mode=opt.upsample)
            else:
                previous_warped = models.VideoTools.warp_upscale(
                    previous_output,
                    flow[:, j-1, :, :, :],
                    opt.upscale_factor,
                    special_mask = True)
                previous_warped_loss = previous_warped
                previous_input = F.interpolate(low_res[:,j-1,:,:,:], size=(Hhigh, Whigh), mode=opt.upsample)
                previous_input = models.VideoTools.warp_upscale(
                    previous_input,
                    flow[:, j-1, :, :, :],
                    opt.upscale_factor,
                    special_mask = True)
            previous_warped_flattened = models.VideoTools.flatten_high(previous_warped, opt.upscale_factor)
            single_input = torch.cat((
                    low_res[:,j,:,:,:],
                    previous_warped_flattened),
                dim=1)
            # run generator
            prediction, _ = model(single_input)
            # evaluate cost
            input_high = F.interpolate(low_res[:,j,:,:,:], size=(Hhigh, Whigh), mode=opt.upsample)
            loss0, _ = criterion(
                target[:,j,:,:,:],
                prediction,
                input_high,
                previous_input,
                previous_warped_loss)
            loss += loss0
            epoch_loss += loss0.item()
            # save output; keeps the graph so gradients flow through time
            previous_output = torch.cat([
                torch.clamp(prediction[:,0:1,:,:], -1, +1), # mask
                ScreenSpaceShading.normalize(prediction[:,1:4,:,:], dim=1),
                torch.clamp(prediction[:,4:5,:,:], 0, +1), # depth
                torch.clamp(prediction[:,5:6,:,:], 0, +1) # ao
                ], dim=1)

        loss.backward()
        optimizer.step()
    pg.print_progress_bar(num_minibatch)
    num_samples = num_minibatch * dataset_data.num_frames
    if num_samples > 0:
        epoch_loss /= num_samples
    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, epoch_loss))
    writer.add_scalar('train/total_loss', epoch_loss, epoch)
    # Read the rate actually used this epoch from the optimizer:
    # scheduler.get_lr() is deprecated (use get_last_lr()) and returns wrong
    # values when called outside scheduler.step().
    writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], epoch)
    # Since PyTorch 1.1 the scheduler must be stepped after optimizer.step();
    # stepping it first skips the first value of the learning-rate schedule.
    scheduler.step()
コード例 #16
0
    def forward(self,
                gt,
                pred,
                prev_pred_warped,
                no_temporal_loss: bool,
                use_checkpoints=False):
        """
        Compute the weighted sum of the configured losses between pred and gt.

        gt: ground truth high resolution image (B x C x H x W)
        pred: predicted high resolution image (B x C x H x W)
        prev_pred_warped: predicted image from the previous frame warped by the flow
               Shape: B x C x H x W
               Only used for temporal losses, can be None if only the other losses are used
        no_temporal_loss: if True, the temporal ('temp-l2' / 'temp-l1') losses are skipped
        use_checkpoints: True if checkpointing should be used.
               This does not apply here since we don't have GANs or perceptual losses

        Channels are read as: 0 mask, 1-3 normal, 4 depth, 5 ao and, if
        self.has_flow, 6-7 flow (C must be 8 with flow, 6 without).

        Returns (generator_loss, loss_values):
            generator_loss: sum of self.weight_dict[key] * loss over all
                configured keys (0.0 if none applies).
            loss_values: maps each (loss name, channel) key to the unweighted
                loss as a Python float.
        """

        # TODO: loss that penalizes deviation from the input samples

        _, C, _, _ = gt.shape
        if self.has_flow:
            assert C == 8
        else:
            assert C == 6
        assert gt.shape == pred.shape

        # zero border padding
        gt = LossNetSparse.pad(gt, self.padding)
        pred = LossNetSparse.pad(pred, self.padding)
        if prev_pred_warped is not None:
            prev_pred_warped = LossNetSparse.pad(prev_pred_warped,
                                                 self.padding)

        # extract channels; the mask is mapped from [-1, 1] to [0, 1] and the
        # gt version is used to restrict the per-pixel losses to the inside
        gt_mask = gt[:, 0:1, :, :]
        gt_mask_clamp = torch.clamp(gt_mask * 0.5 + 0.5, 0, 1)
        gt_normal = ScreenSpaceShading.normalize(gt[:, 1:4, :, :], dim=1)
        gt_depth = gt[:, 4:5, :, :]
        gt_ao = gt[:, 5:6, :, :]
        if self.has_flow:
            gt_flow = gt[:, 6:8, :, :]
        pred_mask = pred[:, 0:1, :, :]
        pred_mask_clamp = torch.clamp(pred_mask * 0.5 + 0.5, 0, 1)
        pred_normal = ScreenSpaceShading.normalize(pred[:, 1:4, :, :], dim=1)
        pred_depth = pred[:, 4:5, :, :]
        pred_ao = pred[:, 5:6, :, :]
        if self.has_flow:
            pred_flow = pred[:, 6:8, :, :]
        if prev_pred_warped is not None and self.has_temporal_l2_loss:
            prev_pred_mask = prev_pred_warped[:, 0:1, :, :]
            prev_pred_mask_clamp = torch.clamp(prev_pred_mask * 0.5 + 0.5, 0,
                                               1)
            prev_pred_normal = ScreenSpaceShading.normalize(
                prev_pred_warped[:, 1:4, :, :], dim=1)
            prev_pred_depth = prev_pred_warped[:, 4:5, :, :]
            prev_pred_ao = prev_pred_warped[:, 5:6, :, :]
            if self.has_flow:
                prev_pred_flow = prev_pred_warped[:, 6:8, :, :]

        generator_loss = 0.0
        loss_values = {}

        # Shaded colors are computed lazily, at most once per call, and shared
        # between the simple and the temporal color losses.
        gt_color = None
        pred_color = None
        prev_pred_color = None

        # normal, simple losses, uses gt+pred
        for name in ['mse', 'l1', 'ssim', 'dssim', 'lpips']:
            if (name, 'mask') in self.weight_dict.keys():
                loss = self.loss_dict[name](gt_mask, pred_mask)
                loss_values[(name, 'mask')] = loss.item()
                generator_loss += self.weight_dict[(name, 'mask')] * loss
            if (name, 'normal') in self.weight_dict.keys():
                loss = self.loss_dict[name](gt_normal * gt_mask_clamp,
                                            pred_normal * gt_mask_clamp)
                loss_values[(name, 'normal')] = loss.item()
                generator_loss += self.weight_dict[(name, 'normal')] * loss
            if (name, 'ao') in self.weight_dict.keys():
                loss = self.loss_dict[name](gt_ao * gt_mask_clamp,
                                            pred_ao * gt_mask_clamp)
                loss_values[(name, 'ao')] = loss.item()
                generator_loss += self.weight_dict[(name, 'ao')] * loss
            if (name, 'depth') in self.weight_dict.keys():
                loss = self.loss_dict[name](gt_depth * gt_mask_clamp,
                                            pred_depth * gt_mask_clamp)
                loss_values[(name, 'depth')] = loss.item()
                generator_loss += self.weight_dict[(name, 'depth')] * loss
            if (name, 'flow') in self.weight_dict.keys():
                # note: flow is not restricted to inside regions
                loss = self.loss_dict[name](gt_flow, pred_flow)
                loss_values[(name, 'flow')] = loss.item()
                generator_loss += self.weight_dict[(name, 'flow')] * loss
            if (name, 'color') in self.weight_dict.keys():
                if gt_color is None:
                    gt_color = self.shading(gt)
                if pred_color is None:
                    pred_color = self.shading(pred)
                loss = self.loss_dict[name](gt_color, pred_color)
                loss_values[(name, 'color')] = loss.item()
                generator_loss += self.weight_dict[(name, 'color')] * loss

        if ("bounded", "mask") in self.weight_dict.keys():
            # penalizes mask values with |mask| > sqrt(2), i.e. a mask that
            # diverges too far outside its [-1, 1] range
            zero = torch.zeros(1,
                               1,
                               1,
                               1,
                               dtype=pred_mask.dtype,
                               device=pred_mask.device)
            loss = torch.mean(torch.max(zero, pred_mask * pred_mask - 2))
            loss_values[("bounded", "mask")] = loss.item()
            generator_loss += self.weight_dict[("bounded", "mask")] * loss

        if ("bce", "mask") in self.weight_dict.keys():
            # binary cross entropy loss between the unclamped masks
            # NOTE(review): binary_cross_entropy_with_logits applies a sigmoid
            # to its first argument, so pred_mask*0.5+0.5 is treated as a
            # logit here -- confirm this is intended.
            loss = F.binary_cross_entropy_with_logits(pred_mask * 0.5 + 0.5,
                                                      gt_mask * 0.5 + 0.5)
            loss_values[("bce", "mask")] = loss.item()
            generator_loss += self.weight_dict[("bce", "mask")] * loss

        # temporal losses, use the gt mask + pred + prev_pred_warped
        if self.has_temporal_l2_loss and not no_temporal_loss:
            assert prev_pred_warped is not None
            for name in ['temp-l2', 'temp-l1']:
                if (name, 'mask') in self.weight_dict.keys():
                    loss = self.loss_dict[name](pred_mask, prev_pred_mask)
                    loss_values[(name, 'mask')] = loss.item()
                    generator_loss += self.weight_dict[(name, 'mask')] * loss
                if (name, 'normal') in self.weight_dict.keys():
                    loss = self.loss_dict[name](pred_normal * gt_mask_clamp,
                                                prev_pred_normal *
                                                gt_mask_clamp)
                    loss_values[(name, 'normal')] = loss.item()
                    generator_loss += self.weight_dict[(name, 'normal')] * loss
                if (name, 'ao') in self.weight_dict.keys():
                    loss = self.loss_dict[name](pred_ao * gt_mask_clamp,
                                                prev_pred_ao * gt_mask_clamp)
                    loss_values[(name, 'ao')] = loss.item()
                    generator_loss += self.weight_dict[(name, 'ao')] * loss
                if (name, 'depth') in self.weight_dict.keys():
                    loss = self.loss_dict[name](pred_depth * gt_mask_clamp,
                                                prev_pred_depth *
                                                gt_mask_clamp)
                    loss_values[(name, 'depth')] = loss.item()
                    generator_loss += self.weight_dict[(name, 'depth')] * loss
                if (name, 'flow') in self.weight_dict.keys():
                    loss = self.loss_dict[name](
                        pred_flow,  #note: no restriction to inside areas
                        prev_pred_flow)
                    loss_values[(name, 'flow')] = loss.item()
                    generator_loss += self.weight_dict[(name, 'flow')] * loss
                if (name, 'color') in self.weight_dict.keys():
                    # pred_color is only computed above when a simple color
                    # loss is configured, so it may still be missing here
                    if pred_color is None:
                        pred_color = self.shading(pred)
                    if prev_pred_color is None:
                        prev_pred_color = self.shading(prev_pred_warped)
                    loss = self.loss_dict[name](pred_color, prev_pred_color)
                    loss_values[(name, 'color')] = loss.item()
                    generator_loss += self.weight_dict[(name, 'color')] * loss

        return generator_loss, loss_values