Example #1
def evaluate(**kwargs):
    torch.set_default_dtype(torch.float32)

    conf = ConfigFactory.parse_file(kwargs['conf'])
    exps_folder_name = kwargs['exps_folder_name']
    evals_folder_name = kwargs['evals_folder_name']
    eval_rendering = kwargs['eval_rendering']
    eval_animation = kwargs['eval_animation']

    expname = conf.get_string('train.expname') + kwargs['expname']
    scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else conf.get_int(
        'dataset.scan_id', default=-1)
    if scan_id != -1:
        expname = expname + '_{0}'.format(scan_id)

    if kwargs['timestamp'] == 'latest':
        if os.path.exists(
                os.path.join('../', kwargs['exps_folder_name'], expname)):
            timestamps = os.listdir(
                os.path.join('../', kwargs['exps_folder_name'], expname))
            if len(timestamps) == 0:
                print('WRONG EXP FOLDER')
                exit()
            else:
                timestamp = sorted(timestamps)[-1]
        else:
            print('WRONG EXP FOLDER')
            exit()
    else:
        timestamp = kwargs['timestamp']

    utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name))
    expdir = os.path.join('../', exps_folder_name, expname)
    evaldir = os.path.join('../', evals_folder_name, expname)
    utils.mkdir_ifnotexists(evaldir)

    dataset_conf = conf.get_config('dataset')
    model = utils.get_class(conf.get_string('train.model_class'))(
        conf=conf.get_config('model'), id=scan_id, datadir=dataset_conf['data_dir'])
    if torch.cuda.is_available():
        model.cuda()

    if kwargs['scan_id'] != -1:
        dataset_conf['scan_id'] = kwargs['scan_id']
    eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(
        False, **dataset_conf)

    if eval_rendering:
        eval_dataloader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=eval_dataset.collate_fn)
        total_pixels = eval_dataset.total_pixels
        img_res = eval_dataset.img_res

    old_checkpnts_dir = os.path.join(expdir, timestamp, 'checkpoints')

    saved_model_state = torch.load(
        os.path.join(old_checkpnts_dir, 'ModelParameters',
                     str(kwargs['checkpoint']) + ".pth"))
    model.load_state_dict(saved_model_state["model_state_dict"])
    epoch = saved_model_state['epoch']

    ####################################################################################################################
    print("evaluating...")

    model.eval()

    detail_3dmm, detail_3dmm_subdivision_full = plt.get_displacement_mesh(
        model)
    detail_3dmm.export('{0}/Detailed_3dmm_{1}.obj'.format(evaldir, epoch),
                       'obj')
    detail_3dmm_subdivision_full.export(
        '{0}/Subdivide_full_{1}.obj'.format(evaldir, epoch), 'obj')

    if eval_animation:
        sdf_np0, sdf_np1 = plt.get_displacement_animation(model)
        np.save('{0}/Cropped_Detailed_sdf_{1}.npy'.format(evaldir, epoch),
                sdf_np0)
        np.save('{0}/Cropped_Subdivide_full_{1}.npy'.format(evaldir, epoch),
                sdf_np1)

    if eval_rendering:
        images_dir = '{0}/rendering'.format(evaldir)
        utils.mkdir_ifnotexists(images_dir)

        psnrs = []
        for data_index, (indices, model_input,
                         ground_truth) in enumerate(eval_dataloader):
            model_input["intrinsics"] = model_input["intrinsics"].cuda()
            model_input["uv"] = model_input["uv"].cuda()
            model_input["object_mask"] = model_input["object_mask"].cuda()
            model_input['pose'] = model_input['pose'].cuda()

            split = utils.split_input(model_input, total_pixels)
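            # Evaluate in pixel chunks to bound GPU memory; merge_output reassembles the full image below.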
            res = []
            for s in split:
                out = model(s)
                res.append({
                    'rgb_values': out['rgb_values'].detach(),
                    'diffuse_values': out['diffuse_values'].detach(),
                    'specular_values': out['specular_values'].detach(),
                    'albedo_values': out['albedo_values'].detach(),
                })

            batch_size = ground_truth['rgb'].shape[0]
            model_outputs = utils.merge_output(res, total_pixels, batch_size)
            rgb_eval = model_outputs['rgb_values']
            rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3)
            rgb_eval = (rgb_eval + 1.) / 2.
            rgb_eval = plt.lin2img(rgb_eval, img_res).detach().cpu().numpy()[0]
            rgb_eval = rgb_eval.transpose(1, 2, 0)
            img = Image.fromarray((rgb_eval * 255).astype(np.uint8))
            img.save('{0}/eval_{1}.png'.format(images_dir,
                                               '%03d' % indices[0]))

            diffuse_eval = model_outputs['diffuse_values']
            diffuse_eval = diffuse_eval.reshape(batch_size, total_pixels, 3)
            diffuse_eval = (diffuse_eval + 1.) / 2.
            diffuse_eval = plt.lin2img(diffuse_eval,
                                       img_res).detach().cpu().numpy()[0]
            diffuse_eval = diffuse_eval.transpose(1, 2, 0)
            img = Image.fromarray((diffuse_eval * 255).astype(np.uint8))
            img.save('{0}/eval_{1}_diffuse.png'.format(images_dir,
                                                       '%03d' % indices[0]))

            specular_eval = model_outputs['specular_values']
            specular_eval = specular_eval.reshape(batch_size, total_pixels, 3)
            specular_eval = (specular_eval + 1.) / 2.
            specular_eval = plt.lin2img(specular_eval,
                                        img_res).detach().cpu().numpy()[0]
            specular_eval = specular_eval.transpose(1, 2, 0)
            img = Image.fromarray((specular_eval * 255).astype(np.uint8))
            img.save('{0}/eval_{1}_specular.png'.format(
                images_dir, '%03d' % indices[0]))

            albedo_eval = model_outputs['albedo_values']
            albedo_eval = albedo_eval.reshape(batch_size, total_pixels, 3)
            albedo_eval = (albedo_eval + 1.) / 2.
            albedo_eval = plt.lin2img(albedo_eval,
                                      img_res).detach().cpu().numpy()[0]
            albedo_eval = albedo_eval.transpose(1, 2, 0)
            img = Image.fromarray((albedo_eval * 255).astype(np.uint8))
            img.save('{0}/eval_{1}_albedo.png'.format(images_dir,
                                                      '%03d' % indices[0]))

            rgb_gt = ground_truth['rgb']
            rgb_gt = (rgb_gt + 1.) / 2.
            rgb_gt = plt.lin2img(rgb_gt, img_res).numpy()[0]
            rgb_gt = rgb_gt.transpose(1, 2, 0)

            mask = model_input['object_mask']
            mask = plt.lin2img(mask.unsqueeze(-1), img_res).cpu().numpy()[0]
            mask = mask.transpose(1, 2, 0)

            rgb_eval_masked = rgb_eval * mask
            rgb_gt_masked = rgb_gt * mask

            psnr = calculate_psnr(rgb_eval_masked, rgb_gt_masked, mask)
            psnrs.append(psnr)

        psnrs = np.array(psnrs, dtype=np.float64)
        print('RENDERING EVALUATION {0}: psnr mean = {1:.2f} ; psnr std = {2:.2f}'.format(
            scan_id, psnrs.mean(), psnrs.std()))
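
For reference, a minimal sketch of the masked-PSNR computation the loop above relies on. This is an assumption about calculate_psnr, not the project's actual implementation: it averages the squared error over masked pixels only and assumes inputs in [0, 1].

import math

import numpy as np


def masked_psnr(img_pred, img_gt, mask):
    # Hypothetical stand-in for calculate_psnr. img_pred/img_gt: (H, W, 3)
    # floats in [0, 1], already zeroed outside the mask; mask: (H, W, 1) bool.
    img_pred = img_pred.astype(np.float64)
    img_gt = img_gt.astype(np.float64)
    n_channels = img_pred.shape[-1]
    mse = ((img_pred - img_gt) ** 2).sum() / (mask.sum() * n_channels)
    if mse == 0:
        return float('inf')
    return 20.0 * math.log10(1.0 / math.sqrt(mse))
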
Example #2
def evaluate(**kwargs):
    torch.set_default_dtype(torch.float32)

    conf = ConfigFactory.parse_file(kwargs['conf'])
    exps_folder_name = kwargs['exps_folder_name']
    evals_folder_name = kwargs['evals_folder_name']
    timestamp = '2020'
    checkpoint = '2000'

    expname = conf.get_string('train.expname')

    geometry_id = kwargs['geometry_id']
    appearance_id = kwargs['appearance_id']

    utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name))
    expdir_geometry = os.path.join('../', exps_folder_name, expname + '_{0}'.format(geometry_id))
    expdir_appearance = os.path.join('../', exps_folder_name, expname + '_{0}'.format(appearance_id))
    evaldir = os.path.join('../', evals_folder_name, expname + '_{0}_{1}'.format(geometry_id, appearance_id))
    utils.mkdir_ifnotexists(evaldir)

    model = utils.get_class(conf.get_string('train.model_class'))(conf=conf.get_config('model'))
    if torch.cuda.is_available():
        model.cuda()

    # Load geometry network model
    old_checkpnts_dir = os.path.join(expdir_geometry, timestamp, 'checkpoints')
    saved_model_state = torch.load(os.path.join(old_checkpnts_dir, 'ModelParameters', checkpoint + ".pth"))
    model.load_state_dict(saved_model_state["model_state_dict"])

    # Load rendering network model
    model_fake = utils.get_class(conf.get_string('train.model_class'))(conf=conf.get_config('model'))
    if torch.cuda.is_available():
        model_fake.cuda()
    old_checkpnts_dir = os.path.join(expdir_appearance, timestamp, 'checkpoints')
    saved_model_state = torch.load(os.path.join(old_checkpnts_dir, 'ModelParameters', checkpoint + ".pth"))
    model_fake.load_state_dict(saved_model_state["model_state_dict"])

    model.rendering_network = model_fake.rendering_network

    dataset_conf = conf.get_config('dataset')
    dataset_conf['scan_id'] = geometry_id
    eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(False, **dataset_conf)

    eval_dataloader = torch.utils.data.DataLoader(eval_dataset,
                                                  batch_size=1,
                                                  shuffle=True,
                                                  collate_fn=eval_dataset.collate_fn
                                                  )
    total_pixels = eval_dataset.total_pixels
    img_res = eval_dataset.img_res

    ####################################################################################################################
    print("evaluating...")

    model.eval()

    gt_pose = eval_dataset.get_gt_pose(scaled=True).cuda()
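    # Encode each ground-truth camera as a 7-vector: unit quaternion (4) followed by translation (3).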
    gt_quat = rend_util.rot_to_quat(gt_pose[:, :3, :3])
    gt_pose_vec = torch.cat([gt_quat, gt_pose[:, :3, 3]], 1)

    indices_all = [11, 16, 34, 28, 11]
    pose = gt_pose_vec[indices_all, :]
    t_in = np.array([0, 2, 3, 5, 6]).astype(np.float32)
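    # The key poses form a closed loop (the first index repeats at the end), so the
    # periodic cubic splines below can interpolate a camera path that wraps smoothly.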

    n_inter = 5
    t_out = np.linspace(t_in[0], t_in[-1], int(n_inter * t_in[-1])).astype(np.float32)

    scales = np.array([4.2, 4.2, 3.8, 3.8, 4.2]).astype(np.float32)

    s_new = CubicSpline(t_in, scales, bc_type='periodic')
    s_new = s_new(t_out)

    q_new = CubicSpline(t_in, pose[:, :4].detach().cpu().numpy(), bc_type='periodic')
    q_new = q_new(t_out)
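    # Cubic splines do not preserve unit norm, so project the interpolated quaternions back onto the unit sphere.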
    q_new = q_new / np.linalg.norm(q_new, 2, 1)[:, None]
    q_new = torch.from_numpy(q_new).cuda().float()

    images_dir = '{0}/novel_views_rendering'.format(evaldir)
    utils.mkdir_ifnotexists(images_dir)

    indices, model_input, ground_truth = next(iter(eval_dataloader))
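    # Any single batch suffices here: only its uv grid, intrinsics and mask shape are reused for the synthesized poses.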

    for i, (new_q, scale) in enumerate(zip(q_new, s_new)):
        torch.cuda.empty_cache()

        new_q = new_q.unsqueeze(0)
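        # Put the camera at distance `scale` from the origin, opposite its view axis, so it keeps looking at the origin.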
        new_t = -rend_util.quat_to_rot(new_q)[:, :, 2] * scale

        new_p = torch.eye(4).float().cuda().unsqueeze(0)
        new_p[:, :3, :3] = rend_util.quat_to_rot(new_q)
        new_p[:, :3, 3] = new_t

        sample = {
            "object_mask": torch.zeros_like(model_input['object_mask']).cuda().bool(),
            "uv": model_input['uv'].cuda(),
            "intrinsics": model_input['intrinsics'].cuda(),
            "pose": new_p
        }

        split = utils.split_input(sample, total_pixels)
        res = []
        for s in split:
            out = model(s)
            res.append({
                'rgb_values': out['rgb_values'].detach(),
            })

        batch_size = 1
        model_outputs = utils.merge_output(res, total_pixels, batch_size)
        rgb_eval = model_outputs['rgb_values']
        rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3)

        rgb_eval = (rgb_eval + 1.) / 2.
        rgb_eval = plt.lin2img(rgb_eval, img_res).detach().cpu().numpy()[0]
        rgb_eval = rgb_eval.transpose(1, 2, 0)
        img = Image.fromarray((rgb_eval * 255).astype(np.uint8))
        img.save('{0}/eval_{1}.png'.format(images_dir,'%03d' % i))
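
The interpolation above works in quaternion space and converts back with rend_util.quat_to_rot. As a point of reference, here is a minimal sketch of that conversion, assuming a (w, x, y, z) component ordering and unit-norm input; the project's actual convention may differ.

import torch


def quat_to_rot_sketch(q):
    # q: (B, 4) unit quaternions (w, x, y, z) -> (B, 3, 3) rotation matrices.
    w, x, y, z = q.unbind(dim=1)
    return torch.stack([
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y),
        2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y),
    ], dim=1).reshape(-1, 3, 3)
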
Example #3
    def run(self):
        print("training...")

        pbar = tqdm(range(self.start_epoch, self.nepochs + 1))
        pbar.set_description('Training IDR')
        for epoch in pbar:

            if epoch in self.alpha_milestones:
                self.loss.alpha = self.loss.alpha * self.alpha_factor

            if epoch % 100 == 0:
                self.save_checkpoints(epoch)

            if epoch % self.plot_freq == 0:
                self.model.eval()
                if self.train_cameras:
                    self.pose_vecs.eval()
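                # A sampling index of -1 appears to disable pixel subsampling, so the plot renders
                # full frames (cf. change_sampling_idx(self.num_pixels) before the training loop).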
                self.train_dataset.change_sampling_idx(-1)
                indices, model_input, ground_truth = next(iter(self.plot_dataloader))

                model_input["intrinsics"] = utils.to_cuda(model_input["intrinsics"])
                model_input["uv"] = utils.to_cuda(model_input["uv"])
                model_input["object_mask"] = utils.to_cuda(model_input["object_mask"])

                if self.train_cameras:
                    pose_input = self.pose_vecs(utils.to_cuda(indices))
                    model_input['pose'] = pose_input
                else:
                    model_input['pose'] = utils.to_cuda(model_input['pose'])

                split = utils.split_input(model_input, self.total_pixels)
                res = []
                for s in split:
                    out = self.model(s)
                    res.append({
                        'points': out['points'].detach(),
                        'rgb_values': out['rgb_values'].detach(),
                        'network_object_mask': out['network_object_mask'].detach(),
                        'object_mask': out['object_mask'].detach()
                    })

                batch_size = ground_truth['rgb'].shape[0]
                model_outputs = utils.merge_output(res, self.total_pixels, batch_size)

                plt.plot(self.model,
                         indices,
                         model_outputs,
                         model_input['pose'],
                         ground_truth['rgb'],
                         self.plots_dir,
                         epoch,
                         self.img_res,
                         **self.plot_conf
                         )

                self.model.train()
                if self.train_cameras:
                    self.pose_vecs.train()

            self.train_dataset.change_sampling_idx(self.num_pixels)

            for data_index, (indices, model_input, ground_truth) in enumerate(self.train_dataloader):

                model_input["intrinsics"] = utils.to_cuda(model_input["intrinsics"])
                model_input["uv"] = utils.to_cuda(model_input["uv"])
                model_input["object_mask"] = utils.to_cuda(model_input["object_mask"])

                if self.train_cameras:
                    pose_input = self.pose_vecs(utils.to_cuda(indices))
                    model_input['pose'] = pose_input
                else:
                    model_input['pose'] = utils.to_cuda(model_input['pose'])

                model_outputs = self.model(model_input)
                loss_output = self.loss(model_outputs, ground_truth)

                loss = loss_output['loss']

                self.optimizer.zero_grad()
                if self.train_cameras:
                    self.optimizer_cam.zero_grad()

                loss.backward()

                self.optimizer.step()
                if self.train_cameras:
                    self.optimizer_cam.step()

            pbar.set_postfix({
                'loss': loss.item(),
                'rgb_loss': loss_output['rgb_loss'].item(),
                'eikonal_loss': loss_output['eikonal_loss'].item(),
                'mask_loss': loss_output['mask_loss'].item(),
                'alpha': self.loss.alpha,
                'lr': self.scheduler.get_last_lr()[0]
            })

            self.scheduler.step()
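
Example #3 routes tensors through utils.to_cuda instead of calling .cuda() directly (compare Example #4). A plausible minimal implementation of such a helper, offered as an assumption rather than the project's actual code:

import torch


def to_cuda(t):
    # Move a tensor to the GPU when one is available; otherwise return it
    # unchanged, keeping the training loop runnable on CPU-only machines.
    return t.cuda() if torch.cuda.is_available() else t
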
Example #4
    def run(self):
        print("training...")

        for epoch in range(self.start_epoch, self.nepochs + 1):

            if epoch in self.alpha_milestones:
                self.loss.alpha = self.loss.alpha * self.alpha_factor

            if epoch % 100 == 0:
                self.save_checkpoints(epoch)

            if epoch % self.plot_freq == 0:
                self.model.eval()
                if self.train_cameras:
                    self.pose_vecs.eval()
                self.train_dataset.change_sampling_idx(-1)
                indices, model_input, ground_truth = next(
                    iter(self.plot_dataloader))

                model_input["intrinsics"] = model_input["intrinsics"].cuda()
                model_input["uv"] = model_input["uv"].cuda()
                model_input["object_mask"] = model_input["object_mask"].cuda()

                if self.train_cameras:
                    pose_input = self.pose_vecs(indices.cuda())
                    model_input['pose'] = pose_input
                else:
                    model_input['pose'] = model_input['pose'].cuda()

                split = utils.split_input(model_input, self.total_pixels)
                res = []
                for s in split:
                    out = self.model(s)
                    res.append({
                        'points': out['points'].detach(),
                        'rgb_values': out['rgb_values'].detach(),
                        'network_object_mask': out['network_object_mask'].detach(),
                        'object_mask': out['object_mask'].detach()
                    })

                batch_size = ground_truth['rgb'].shape[0]
                model_outputs = utils.merge_output(res, self.total_pixels,
                                                   batch_size)

                plt.plot(self.model, indices, model_outputs,
                         model_input['pose'], ground_truth['rgb'],
                         self.plots_dir, epoch, self.img_res, **self.plot_conf)

                self.model.train()
                if self.train_cameras:
                    self.pose_vecs.train()

            self.train_dataset.change_sampling_idx(self.num_pixels)

            for data_index, (indices, model_input,
                             ground_truth) in enumerate(self.train_dataloader):

                model_input["intrinsics"] = model_input["intrinsics"].cuda()
                model_input["uv"] = model_input["uv"].cuda()
                model_input["object_mask"] = model_input["object_mask"].cuda()

                if self.train_cameras:
                    pose_input = self.pose_vecs(indices.cuda())
                    model_input['pose'] = pose_input
                else:
                    model_input['pose'] = model_input['pose'].cuda()

                model_outputs = self.model(model_input)
                loss_output = self.loss(model_outputs, ground_truth)

                loss = loss_output['loss']

                self.optimizer.zero_grad()
                if self.train_cameras:
                    self.optimizer_cam.zero_grad()

                loss.backward()

                self.optimizer.step()
                if self.train_cameras:
                    self.optimizer_cam.step()

                print(
                    '{0} [{1}] ({2}/{3}): loss = {4}, rgb_loss = {5}, eikonal_loss = {6}, mask_loss = {7}, alpha = {8}, lr = {9}'
                    .format(self.expname, epoch, data_index, self.n_batches,
                            loss.item(), loss_output['rgb_loss'].item(),
                            loss_output['eikonal_loss'].item(),
                            loss_output['mask_loss'].item(), self.loss.alpha,
                            self.scheduler.get_last_lr()[0]))

            self.scheduler.step()
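
Both training loops log an eikonal term alongside the RGB and mask losses. A minimal sketch of the standard eikonal regularizer used by IDR-style models, which drives the implicit network's gradient toward unit norm; the exact sampling and weighting are assumptions here.

import torch


def eikonal_loss_sketch(sdf_grad):
    # sdf_grad: (N, 3) gradients of the signed distance field at sampled
    # points; a true SDF satisfies |grad| = 1 everywhere.
    return ((sdf_grad.norm(2, dim=1) - 1.0) ** 2).mean()
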
Example #5
def evaluate(**kwargs):
    torch.set_default_dtype(torch.float32)

    conf = ConfigFactory.parse_file(kwargs['conf'])
    exps_folder_name = kwargs['exps_folder_name']
    evals_folder_name = kwargs['evals_folder_name']
    eval_cameras = kwargs['eval_cameras']
    eval_rendering = kwargs['eval_rendering']

    expname = conf.get_string('train.expname') + kwargs['expname']
    scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else conf.get_int(
        'dataset.scan_id', default=-1)
    if scan_id != -1:
        expname = expname + '_{0}'.format(scan_id)

    if kwargs['timestamp'] == 'latest':
        if os.path.exists(
                os.path.join('../', kwargs['exps_folder_name'], expname)):
            timestamps = os.listdir(
                os.path.join('../', kwargs['exps_folder_name'], expname))
            if len(timestamps) == 0:
                print('WRONG EXP FOLDER')
                exit()
            else:
                timestamp = sorted(timestamps)[-1]
        else:
            print('WRONG EXP FOLDER')
            exit()
    else:
        timestamp = kwargs['timestamp']

    utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name))
    expdir = os.path.join('../', exps_folder_name, expname)
    evaldir = os.path.join('../', evals_folder_name, expname)
    utils.mkdir_ifnotexists(evaldir)

    model = utils.get_class(
        conf.get_string('train.model_class'))(conf=conf.get_config('model'))
    if torch.cuda.is_available():
        model.cuda()

    dataset_conf = conf.get_config('dataset')
    if kwargs['scan_id'] != -1:
        dataset_conf['scan_id'] = kwargs['scan_id']
    eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(
        eval_cameras, **dataset_conf)

    # settings for camera optimization
    scale_mat = eval_dataset.get_scale_mat()
    if eval_cameras:
        num_images = len(eval_dataset)
        pose_vecs = torch.nn.Embedding(num_images, 7, sparse=True).cuda()
        pose_vecs.weight.data.copy_(eval_dataset.get_pose_init())

        gt_pose = eval_dataset.get_gt_pose()

    if eval_rendering:
        eval_dataloader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=eval_dataset.collate_fn)
        total_pixels = eval_dataset.total_pixels
        img_res = eval_dataset.img_res

    old_checkpnts_dir = os.path.join(expdir, timestamp, 'checkpoints')

    saved_model_state = torch.load(
        os.path.join(old_checkpnts_dir, 'ModelParameters',
                     str(kwargs['checkpoint']) + ".pth"))
    model.load_state_dict(saved_model_state["model_state_dict"])
    epoch = saved_model_state['epoch']

    if eval_cameras:
        data = torch.load(
            os.path.join(old_checkpnts_dir, 'CamParameters',
                         str(kwargs['checkpoint']) + ".pth"))
        pose_vecs.load_state_dict(data["pose_vecs_state_dict"])

    ####################################################################################################################
    print("evaluating...")

    model.eval()
    if eval_cameras:
        pose_vecs.eval()

    with torch.no_grad():
        if eval_cameras:
            gt_Rs = gt_pose[:, :3, :3].double()
            gt_ts = gt_pose[:, :3, 3].double()

            pred_Rs = rend_util.quat_to_rot(
                pose_vecs.weight.data[:, :4]).cpu().double()
            pred_ts = pose_vecs.weight.data[:, 4:].cpu().double()

            R_opt, t_opt, c_opt, R_fixed, t_fixed = get_cameras_accuracy(
                pred_Rs, gt_Rs, pred_ts, gt_ts)

            cams_transformation = np.eye(4, dtype=np.double)
            cams_transformation[:3, :3] = c_opt * R_opt
            cams_transformation[:3, 3] = t_opt

        mesh = plt.get_surface_high_res_mesh(
            sdf=lambda x: model.implicit_network(x)[:, 0],
            resolution=kwargs['resolution'])

        # Transform to world coordinates
        if eval_cameras:
            mesh.apply_transform(cams_transformation)
        else:
            mesh.apply_transform(scale_mat)

        # Taking the biggest connected component
        components = mesh.split(only_watertight=False)
        areas = np.array([c.area for c in components], dtype=np.float64)
        mesh_clean = components[areas.argmax()]
        mesh_clean.export(
            '{0}/surface_world_coordinates_{1}.ply'.format(evaldir, epoch),
            'ply')

    if eval_rendering:
        images_dir = '{0}/rendering'.format(evaldir)
        utils.mkdir_ifnotexists(images_dir)

        psnrs = []
        for data_index, (indices, model_input,
                         ground_truth) in enumerate(eval_dataloader):
            model_input["intrinsics"] = model_input["intrinsics"].cuda()
            model_input["uv"] = model_input["uv"].cuda()
            model_input["object_mask"] = model_input["object_mask"].cuda()

            if eval_cameras:
                pose_input = pose_vecs(indices.cuda())
                model_input['pose'] = pose_input
            else:
                model_input['pose'] = model_input['pose'].cuda()

            split = utils.split_input(model_input, total_pixels)
            res = []
            for s in split:
                out = model(s)
                res.append({
                    'rgb_values': out['rgb_values'].detach(),
                })

            batch_size = ground_truth['rgb'].shape[0]
            model_outputs = utils.merge_output(res, total_pixels, batch_size)
            rgb_eval = model_outputs['rgb_values']
            rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3)

            rgb_eval = (rgb_eval + 1.) / 2.
            rgb_eval = plt.lin2img(rgb_eval, img_res).detach().cpu().numpy()[0]
            rgb_eval = rgb_eval.transpose(1, 2, 0)
            img = Image.fromarray((rgb_eval * 255).astype(np.uint8))
            img.save('{0}/eval_{1}.png'.format(images_dir,
                                               '%03d' % indices[0]))

            rgb_gt = ground_truth['rgb']
            rgb_gt = (rgb_gt + 1.) / 2.
            rgb_gt = plt.lin2img(rgb_gt, img_res).numpy()[0]
            rgb_gt = rgb_gt.transpose(1, 2, 0)

            mask = model_input['object_mask']
            mask = plt.lin2img(mask.unsqueeze(-1), img_res).cpu().numpy()[0]
            mask = mask.transpose(1, 2, 0)

            rgb_eval_masked = rgb_eval * mask
            rgb_gt_masked = rgb_gt * mask

            psnr = calculate_psnr(rgb_eval_masked, rgb_gt_masked, mask)
            psnrs.append(psnr)

        psnrs = np.array(psnrs, dtype=np.float64)
        print('RENDERING EVALUATION {0}: psnr mean = {1:.2f} ; psnr std = {2:.2f}'.format(
            scan_id, psnrs.mean(), psnrs.std()))
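
When eval_cameras is set, get_cameras_accuracy returns a similarity transform (scale c_opt, rotation R_opt, translation t_opt) aligning the predicted cameras to ground truth; the code above packs c_opt * R_opt into the upper-left block of a 4x4 matrix and t_opt into its last column. Below is a minimal sketch of the underlying Umeyama alignment over corresponding 3D points such as camera centers; the helper name and exact outputs are assumptions, not the project's API.

import numpy as np


def umeyama_alignment_sketch(src, dst):
    # Find scale c, rotation R, translation t minimizing
    # sum_i ||dst_i - (c * R @ src_i + t)||^2 (Umeyama, 1991).
    # src, dst: (N, 3) corresponding 3D points.
    mu_src, mu_dst = src.mean(axis=0), dst.mean(axis=0)
    src_c, dst_c = src - mu_src, dst - mu_dst
    cov = dst_c.T @ src_c / src.shape[0]
    U, D, Vt = np.linalg.svd(cov)
    S = np.eye(3)
    if np.linalg.det(U) * np.linalg.det(Vt) < 0:
        S[2, 2] = -1.0  # guard against reflections
    R = U @ S @ Vt
    c = np.trace(np.diag(D) @ S) / src_c.var(axis=0).sum()
    t = mu_dst - c * R @ mu_src
    return R, t, c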