Example #1
0
def train():
    """Train an SRNs model with a coarse-to-fine image-sidelength schedule.

    All configuration is read from the module-level ``opt`` namespace, and
    helpers (``util``, ``dataio``, ``SRNsModel``) are module-level as well.
    TensorBoard events and checkpoints are written under ``opt.logging_root``;
    a final checkpoint is always saved when the schedule completes.
    """
    # Parses indices of specific observations from comma-separated list.
    if opt.specific_observation_idcs is not None:
        specific_observation_idcs = util.parse_comma_separated_integers(
            opt.specific_observation_idcs)
    else:
        specific_observation_idcs = None

    # Curriculum schedule: the i-th entries of these three lists together
    # describe one training stage (sidelength, batch size, step budget).
    img_sidelengths = util.parse_comma_separated_integers(opt.img_sidelengths)
    batch_size_per_sidelength = util.parse_comma_separated_integers(
        opt.batch_size_per_img_sidelength)
    max_steps_per_sidelength = util.parse_comma_separated_integers(
        opt.max_steps_per_img_sidelength)

    train_dataset = dataio.SceneClassDataset(
        root_dir=opt.data_root,
        max_num_instances=opt.max_num_instances_train,
        max_observations_per_instance=opt.max_num_observations_train,
        img_sidelength=img_sidelengths[0],
        specific_observation_idcs=specific_observation_idcs,
        samples_per_instance=1)

    assert (len(img_sidelengths) == len(batch_size_per_sidelength)), \
        "Different number of image sidelengths passed than batch sizes."
    assert (len(img_sidelengths) == len(max_steps_per_sidelength)), \
        "Different number of image sidelengths passed than max steps."

    if not opt.no_validation:
        assert (opt.val_root is not None), "No validation directory passed."

        val_dataset = dataio.SceneClassDataset(
            root_dir=opt.val_root,
            max_num_instances=opt.max_num_instances_val,
            max_observations_per_instance=opt.max_num_observations_val,
            img_sidelength=img_sidelengths[0],
            samples_per_instance=1)
        val_dataloader = DataLoader(val_dataset,
                                    batch_size=2,
                                    shuffle=False,
                                    drop_last=True,
                                    collate_fn=val_dataset.collate_fn)

    model = SRNsModel(num_instances=train_dataset.num_instances,
                      latent_dim=opt.embedding_size,
                      has_params=opt.has_params,
                      fit_single_srn=opt.fit_single_srn,
                      use_unet_renderer=opt.use_unet_renderer,
                      tracing_steps=opt.tracing_steps,
                      freeze_networks=opt.freeze_networks)
    model.train()
    model.cuda()

    if opt.checkpoint_path is not None:
        print("Loading model from %s" % opt.checkpoint_path)
        util.custom_load(model,
                         path=opt.checkpoint_path,
                         discriminator=None,
                         optimizer=None,
                         overwrite_embeddings=opt.overwrite_embeddings)

    ckpt_dir = os.path.join(opt.logging_root, 'checkpoints')
    events_dir = os.path.join(opt.logging_root, 'events')

    util.cond_mkdir(opt.logging_root)
    util.cond_mkdir(ckpt_dir)
    util.cond_mkdir(events_dir)

    # Save command-line parameters log directory.
    with open(os.path.join(opt.logging_root, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(
            ["%s: %s" % (key, value) for key, value in vars(opt).items()]))

    # Save text summary of model into log directory.
    with open(os.path.join(opt.logging_root, "model.txt"), "w") as out_file:
        out_file.write(str(model))

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    writer = SummaryWriter(events_dir)
    # `global_step` counts optimizer steps across all curriculum stages.
    # (Renamed from `iter`, which shadowed the builtin.)
    global_step = opt.start_step
    epoch = global_step // len(train_dataset)

    print('Beginning training...')
    # This loop implements training with an increasing image sidelength.
    cum_max_steps = 0  # Tracks max_steps cumulatively over all image sidelengths.
    for img_sidelength, max_steps, batch_size in zip(
            img_sidelengths, max_steps_per_sidelength,
            batch_size_per_sidelength):
        print("\n" + "#" * 10)
        print("Training with sidelength %d for %d steps with batch size %d" %
              (img_sidelength, max_steps, batch_size))
        print("#" * 10 + "\n")
        train_dataset.set_img_sidelength(img_sidelength)

        # Need to instantiate DataLoader every time to set new batch size.
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      drop_last=True,
                                      collate_fn=train_dataset.collate_fn,
                                      pin_memory=opt.preload)

        cum_max_steps += max_steps

        # Loops over epochs.
        while True:
            for model_input, ground_truth in train_dataloader:
                model_outputs = model(model_input)

                optimizer.zero_grad()

                dist_loss = model.get_image_loss(model_outputs, ground_truth)
                reg_loss = model.get_regularization_loss(
                    model_outputs, ground_truth)
                latent_loss = model.get_latent_loss()

                weighted_dist_loss = opt.l1_weight * dist_loss
                weighted_reg_loss = opt.reg_weight * reg_loss
                weighted_latent_loss = opt.kl_weight * latent_loss

                total_loss = (weighted_dist_loss + weighted_reg_loss +
                              weighted_latent_loss)

                total_loss.backward()

                optimizer.step()

                print(
                    "Iter %07d   Epoch %03d   L_img %0.4f   L_latent %0.4f   L_depth %0.4f"
                    % (global_step, epoch, weighted_dist_loss,
                       weighted_latent_loss, weighted_reg_loss))

                model.write_updates(writer, model_outputs, ground_truth,
                                    global_step)
                writer.add_scalar("scaled_distortion_loss", weighted_dist_loss,
                                  global_step)
                writer.add_scalar("scaled_regularization_loss",
                                  weighted_reg_loss, global_step)
                writer.add_scalar("scaled_latent_loss", weighted_latent_loss,
                                  global_step)
                writer.add_scalar("total_loss", total_loss, global_step)

                if (global_step % opt.steps_til_val == 0
                        and not opt.no_validation):
                    print("Running validation set...")

                    model.eval()
                    with torch.no_grad():
                        psnrs = []
                        ssims = []
                        dist_losses = []
                        for val_input, val_gt in val_dataloader:
                            val_outputs = model(val_input)

                            val_dist_loss = model.get_image_loss(
                                val_outputs, val_gt).cpu().numpy()
                            psnr, ssim = model.get_psnr(val_outputs, val_gt)
                            psnrs.append(psnr)
                            ssims.append(ssim)
                            dist_losses.append(val_dist_loss)

                            model.write_updates(writer,
                                                val_outputs,
                                                val_gt,
                                                global_step,
                                                prefix='val_')

                        writer.add_scalar("val_dist_loss",
                                          np.mean(dist_losses), global_step)
                        writer.add_scalar("val_psnr", np.mean(psnrs),
                                          global_step)
                        writer.add_scalar("val_ssim", np.mean(ssims),
                                          global_step)
                    model.train()

                global_step += 1

                # `>=` (not `==`) so the stage still terminates when
                # opt.start_step already exceeds a stage's cumulative budget.
                if global_step >= cum_max_steps:
                    break

                if global_step % opt.steps_til_ckpt == 0:
                    util.custom_save(model,
                                     os.path.join(
                                         ckpt_dir, 'epoch_%04d_iter_%06d.pth' %
                                         (epoch, global_step)),
                                     discriminator=None,
                                     optimizer=optimizer)

            if global_step >= cum_max_steps:
                break
            epoch += 1

    # Final checkpoint after the full schedule.
    util.custom_save(model,
                     os.path.join(ckpt_dir,
                                  'epoch_%04d_iter_%06d.pth' %
                                  (epoch, global_step)),
                     discriminator=None,
                     optimizer=optimizer)
Example #2
0
def train():
    """Adversarially train a novel-view synthesis model (generator + patch
    discriminator) on view triplets.

    Relies on module-level state: ``opt``, ``model``, ``discriminator``,
    ``optimizerD``/``optimizerG``, ``criterionL1``/``criterionGAN``,
    ``projection``, ``grid_origin``, ``grid_dims``, ``voxel_size``,
    ``near_plane``, ``input_image_dims`` and ``device``.
    """
    discriminator.train()
    model.train()

    # Optionally resume generator and discriminator weights.
    if opt.checkpoint:
        util.custom_load(model, opt.checkpoint, discriminator)

    # Create the training dataset loader
    train_dataset = NovelViewTriplets(root_dir=opt.data_root,
                                      img_size=input_image_dims,
                                      sampling_pattern=opt.sampling_pattern)
    dataloader = DataLoader(train_dataset,
                            batch_size=1,
                            shuffle=True,
                            num_workers=8)

    # directory name contains some info about hyperparameters.
    dir_name = os.path.join(
        datetime.datetime.now().strftime('%m_%d'),
        datetime.datetime.now().strftime('%H-%M-%S_') +
        (opt.sampling_pattern + '_') + ('%0.2f_l1_weight_' % opt.l1_weight) +
        ('%d_trgt_' % opt.num_trgt) + '_' +
        opt.data_root.strip('/').split('/')[-1] + opt.experiment_name)

    log_dir = os.path.join(opt.logging_root, 'logs', dir_name)
    run_dir = os.path.join(opt.logging_root, 'runs', dir_name)

    data_util.cond_mkdir(log_dir)
    data_util.cond_mkdir(run_dir)

    # Save all command line arguments into a txt file in the logging directory for later reference.
    with open(os.path.join(log_dir, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(
            ["%s: %s" % (key, value) for key, value in vars(opt).items()]))

    writer = SummaryWriter(run_dir)

    # Global step counter, resumed from the starting epoch.
    iter = opt.start_epoch * len(train_dataset)

    print('Begin training...')
    for epoch in range(opt.start_epoch, opt.max_epoch):
        for trgt_views, nearest_view in dataloader:
            # Lifting indices map the nearest view's pixels into the voxel
            # grid; projection indices map grid voxels into each target view.
            backproj_mapping = projection.comp_lifting_idcs(
                camera_to_world=nearest_view['pose'].squeeze().to(device),
                grid2world=grid_origin)

            # NOTE(review): projections are computed before the lifting
            # validity check below, so this work is wasted whenever lifting
            # fails — confirm intentional.
            proj_mappings = list()
            for i in range(len(trgt_views)):
                proj_mappings.append(
                    projection.compute_proj_idcs(
                        trgt_views[i]['pose'].squeeze().to(device),
                        grid2world=grid_origin))

            if backproj_mapping is None:
                print("Lifting invalid")
                continue
            else:
                lift_volume_idcs, lift_img_coords = backproj_mapping

            if None in proj_mappings:
                print('Projection invalid')
                continue

            proj_frustrum_idcs, proj_grid_coords = list(zip(*proj_mappings))

            outputs, depth_maps = model(nearest_view['gt_rgb'].to(device),
                                        proj_frustrum_idcs,
                                        proj_grid_coords,
                                        lift_volume_idcs,
                                        lift_img_coords,
                                        writer=writer)

            # Convert the depth maps to metric
            for i in range(len(depth_maps)):
                depth_maps[i] = (
                    (depth_maps[i] + 0.5) *
                    int(np.ceil(np.sqrt(3) * grid_dims[-1])) * voxel_size +
                    near_plane)

            # We don't enforce a loss on the outermost 5 pixels to alleviate boundary errors
            for i in range(len(trgt_views)):
                outputs[i] = outputs[i][:, :, 5:-5, 5:-5]
                trgt_views[i]['gt_rgb'] = trgt_views[i]['gt_rgb'][:, :, 5:-5,
                                                                  5:-5]

            # Per-target-view L1 reconstruction losses (flattened pixels).
            l1_losses = list()
            for idx in range(len(trgt_views)):
                l1_losses.append(
                    criterionL1(
                        outputs[idx].contiguous().view(-1).float(),
                        trgt_views[idx]['gt_rgb'].to(device).view(-1).float()))

            losses_d = []
            losses_g = []

            optimizerD.zero_grad()
            optimizerG.zero_grad()

            for idx in range(len(trgt_views)):
                #######
                ## Train Discriminator
                #######
                out_perm = outputs[idx]  # batch, ndf, height, width

                # Fake forward step
                pred_fake = discriminator.forward(out_perm.detach(
                ))  # Detach to make sure no gradients go into generator
                loss_d_fake = criterionGAN(pred_fake, False)

                # Real forward step
                real_input = trgt_views[idx]['gt_rgb'].float().to(device)
                pred_real = discriminator.forward(real_input)
                loss_d_real = criterionGAN(pred_real, True)

                # Combined Loss
                losses_d.append((loss_d_fake + loss_d_real) * 0.5)

                #######
                ## Train generator
                #######
                # Try to fake discriminator
                pred_fake = discriminator.forward(out_perm)
                loss_g_gan = criterionGAN(pred_fake, True)

                loss_g_l1 = l1_losses[idx] * opt.l1_weight
                losses_g.append(loss_g_gan + loss_g_l1)

            loss_d = torch.stack(losses_d, dim=0).mean()
            loss_g = torch.stack(losses_g, dim=0).mean()

            # NOTE(review): optimizerD.step() runs before loss_g.backward(),
            # so the generator gradient flows through discriminator
            # activations recorded with pre-update weights — confirm this
            # ordering is intended.
            loss_d.backward()
            optimizerD.step()
            loss_g.backward()
            optimizerG.step()

            print(
                "Iter %07d   Epoch %03d   loss_gen %0.4f   loss_discrim %0.4f"
                % (iter, epoch, loss_g, loss_d))

            # Image logs every 100 iterations (iter % 100 == 0).
            if not iter % 100:
                # Write tensorboard logs.
                writer.add_image(
                    "Depth",
                    torchvision.utils.make_grid(
                        [
                            depth_map.squeeze(dim=0).repeat(3, 1, 1)
                            for depth_map in depth_maps
                        ],
                        scale_each=True,
                        normalize=True).cpu().detach().numpy(), iter)
                writer.add_image(
                    "Nearest_neighbors_rgb",
                    torchvision.utils.make_grid(
                        nearest_view['gt_rgb'],
                        scale_each=True,
                        normalize=True).detach().numpy(), iter)
                output_vs_gt = torch.cat(
                    (torch.cat(outputs, dim=0),
                     torch.cat([i['gt_rgb'].to(device) for i in trgt_views],
                               dim=0)),
                    dim=0)
                writer.add_image(
                    "Output_vs_gt",
                    torchvision.utils.make_grid(
                        output_vs_gt, scale_each=True,
                        normalize=True).cpu().detach().numpy(), iter)

            writer.add_scalar("out_min", outputs[0].min(), iter)
            writer.add_scalar("out_max", outputs[0].max(), iter)

            writer.add_scalar("trgt_min", trgt_views[0]['gt_rgb'].min(), iter)
            writer.add_scalar("trgt_max", trgt_views[0]['gt_rgb'].max(), iter)

            writer.add_scalar("discrim_loss", loss_d, iter)
            writer.add_scalar("gen_loss_total", loss_g, iter)
            # NOTE(review): these two scalars log only the LAST target view's
            # components from the loop above, not the mean — confirm intended.
            writer.add_scalar("gen_loss_l1", loss_g_l1, iter)
            writer.add_scalar("gen_loss_g", loss_g_gan, iter)

            iter += 1

            if iter % 10000 == 0:
                util.custom_save(
                    model,
                    os.path.join(log_dir,
                                 'model-epoch_%d_iter_%s.pth' % (epoch, iter)),
                    discriminator)

    # Final checkpoint after the last epoch.
    util.custom_save(
        model,
        os.path.join(log_dir, 'model-epoch_%d_iter_%s.pth' % (epoch, iter)),
        discriminator)
def train():
    """Train an SRNsModel3 on the PBW dataset with a stepwise schedule.

    Configuration comes from the module-level ``opt`` namespace; ``util``,
    ``dataio``, ``SRNsModel3`` and ``test`` are module-level as well.
    Checkpoints and TensorBoard events are written under ``opt.logging_root``.
    """
    # Parses indices of specific observations from comma-separated list.
    if opt.specific_observation_idcs is not None:
        specific_observation_idcs = util.parse_comma_separated_integers(
            opt.specific_observation_idcs)
    else:
        specific_observation_idcs = None

    # NOTE(review): unlike the SceneClassDataset variant of this script,
    # PBWDataset takes no sidelength here, so img_sidelengths only drives
    # the (batch size, max steps) schedule below — confirm intended.
    img_sidelengths = util.parse_comma_separated_integers(opt.img_sidelengths)
    batch_size_per_sidelength = util.parse_comma_separated_integers(
        opt.batch_size_per_img_sidelength)
    max_steps_per_sidelength = util.parse_comma_separated_integers(
        opt.max_steps_per_img_sidelength)

    train_dataset = dataio.PBWDataset(train=True)

    assert (len(img_sidelengths) == len(batch_size_per_sidelength)), \
        "Different number of image sidelengths passed than batch sizes."
    assert (len(img_sidelengths) == len(max_steps_per_sidelength)), \
        "Different number of image sidelengths passed than max steps."

    if not opt.no_validation:
        assert (opt.val_root is not None), "No validation directory passed."

        val_dataset = dataio.PBWDataset(train=False)
        val_dataloader = DataLoader(val_dataset,
                                    batch_size=16,
                                    shuffle=False,
                                    drop_last=True,
                                    collate_fn=val_dataset.collate_fn)

    model = SRNsModel3(latent_dim=opt.embedding_size,
                       has_params=opt.has_params,
                       fit_single_srn=True,
                       tracing_steps=opt.tracing_steps,
                       freeze_networks=opt.freeze_networks)
    model.train()
    model.cuda()
    if opt.checkpoint_path is not None:
        print("Loading model from %s" % opt.checkpoint_path)
        util.custom_load(model,
                         path=opt.checkpoint_path,
                         discriminator=None,
                         optimizer=None,
                         overwrite_embeddings=opt.overwrite_embeddings)

    ckpt_dir = os.path.join(opt.logging_root, 'checkpoints')
    events_dir = os.path.join(opt.logging_root, 'events')

    util.cond_mkdir(opt.logging_root)
    util.cond_mkdir(ckpt_dir)
    util.cond_mkdir(events_dir)

    # Save command-line parameters log directory.
    with open(os.path.join(opt.logging_root, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(
            ["%s: %s" % (key, value) for key, value in vars(opt).items()]))

    # Save text summary of model into log directory.
    with open(os.path.join(opt.logging_root, "model.txt"), "w") as out_file:
        out_file.write(str(model))

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    writer = SummaryWriter(events_dir)
    # `iter` (shadows the builtin) is the global step counter; `step` is
    # incremented alongside but never read.
    iter = opt.start_step
    epoch = iter // len(train_dataset)
    step = 0

    print('Beginning training...')
    # This loop implements training with an increasing image sidelength.
    cum_max_steps = 0  # Tracks max_steps cumulatively over all image sidelengths.
    for img_sidelength, max_steps, batch_size in zip(
            img_sidelengths, max_steps_per_sidelength,
            batch_size_per_sidelength):
        print("\n" + "#" * 10)
        print("Training with sidelength %d for %d steps with batch size %d" %
              (img_sidelength, max_steps, batch_size))
        print("#" * 10 + "\n")

        # Need to instantiate DataLoader every time to set new batch size.
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            collate_fn=train_dataset.collate_fn,
        )

        cum_max_steps += max_steps

        # Loops over epochs.
        while True:
            for batch in train_dataloader:
                rgb, ext_mat, info, rgb_mat = batch
                ground_truth = {"rgb": rgb}
                model_input = (ext_mat, rgb_mat, info
                               )  # color, pix coord, location, box
                model_outputs = model(model_input)
                optimizer.zero_grad()

                # Image reconstruction loss is the only training objective.
                total_loss = model.get_image_loss(model_outputs, ground_truth)
                total_loss.backward()

                optimizer.step()
                if iter % 100 == 0:
                    print("Iter %07d   Epoch %03d   L_img %0.4f" %
                          (iter, epoch, total_loss))

                if iter % opt.steps_til_val == 0 and not opt.no_validation:
                    print("Running validation set...")
                    acc = test(model, val_dataloader, str(iter))
                    print("Accuracy:", acc)

                iter += 1
                step += 1

                # NOTE(review): `==` never fires if opt.start_step already
                # exceeds this stage's cumulative budget — confirm inputs.
                if iter == cum_max_steps:
                    break

            if iter == cum_max_steps:
                break
            epoch += 1

    # Final checkpoint after the full schedule.
    util.custom_save(model,
                     os.path.join(ckpt_dir,
                                  'epoch_%04d_iter_%06d.pth' % (epoch, iter)),
                     discriminator=None,
                     optimizer=optimizer)
Example #4
0
def main():
    """Train a neural-texture rendering pipeline and periodically validate.

    Uses module-level state throughout: ``opt``, ``view_dataset``,
    ``view_val_dataset``, ``texture_mapper``, ``render_net``, ``optimizerG``,
    ``criterionL1``, ``metric``, ``part_list``/``part_name_list`` and
    ``device``. Training images, error metrics and validation outputs are
    logged to TensorBoard and written as PNGs under the run's log directory.
    """
    print('Start buffering data for training views...')
    view_dataset.buffer_all()
    view_dataloader = DataLoader(view_dataset,
                                 batch_size=opt.batch_size,
                                 shuffle=True,
                                 num_workers=8)

    print('Start buffering data for validation views...')
    view_val_dataset.buffer_all()
    view_val_dataloader = DataLoader(view_val_dataset,
                                     batch_size=opt.batch_size,
                                     shuffle=False,
                                     num_workers=8)

    # directory name contains some info about hyperparameters.
    dir_name = os.path.join(datetime.datetime.now().strftime('%m-%d') + '_' +
                            datetime.datetime.now().strftime('%H-%M-%S') +
                            '_' + opt.sampling_pattern + '_' +
                            opt.data_root.strip('/').split('/')[-1])
    # Fix: `is not ''` compared object identity, not string equality
    # (a SyntaxWarning on modern CPython); use `!=` for value comparison.
    if opt.exp_name != '':
        dir_name += '_' + opt.exp_name

    # directory for logging
    log_dir = os.path.join(opt.logging_root, dir_name)
    data_util.cond_mkdir(log_dir)

    # directory for saving validation data on view synthesis
    val_out_dir = os.path.join(log_dir, 'val_out')
    val_gt_dir = os.path.join(log_dir, 'val_gt')
    val_err_dir = os.path.join(log_dir, 'val_err')
    data_util.cond_mkdir(val_out_dir)
    data_util.cond_mkdir(val_gt_dir)
    data_util.cond_mkdir(val_err_dir)

    # Save all command line arguments into a txt file in the logging directory for later reference.
    with open(os.path.join(log_dir, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(
            ["%s: %s" % (key, value) for key, value in vars(opt).items()]))

    writer = SummaryWriter(log_dir)

    # Global step counter, resumed from the starting epoch.
    iter = opt.start_epoch * len(view_dataset)

    print('Begin training...')
    val_log_batch_id = 0
    first_val = True
    for epoch in range(opt.start_epoch, opt.max_epoch):
        for view_trgt in view_dataloader:
            start = time.time()
            # get view data
            uv_map = view_trgt[0]['uv_map'].to(device)  # [N, H, W, 2]
            sh_basis_map = view_trgt[0]['sh_basis_map'].to(
                device)  # [N, H, W, 9]
            alpha_map = view_trgt[0]['alpha_map'][:, None, :, :].to(
                device)  # [N, 1, H, W]
            img_gt = []
            for i in range(len(view_trgt)):
                img_gt.append(view_trgt[i]['img_gt'].to(device))

            # sample texture
            neural_img = texture_mapper(uv_map, sh_basis_map)

            # rendering net
            outputs = render_net(neural_img, None)
            img_max_val = 2.0
            outputs = (outputs * 0.5 +
                       0.5) * img_max_val  # map to [0, img_max_val]
            if not isinstance(outputs, list):
                outputs = [outputs]

            # We don't enforce a loss on the outermost 5 pixels to alleviate boundary errors, also weight loss by alpha
            alpha_map_central = alpha_map[:, :, 5:-5, 5:-5]
            for i in range(len(view_trgt)):
                outputs[i] = outputs[i][:, :, 5:-5, 5:-5] * alpha_map_central
                img_gt[i] = img_gt[i][:, :, 5:-5, 5:-5] * alpha_map_central

            # loss on final image
            loss_rn = list()
            for idx in range(len(view_trgt)):
                loss_rn.append(
                    criterionL1(outputs[idx].contiguous().view(-1).float(),
                                img_gt[idx].contiguous().view(-1).float()))
            loss_rn = torch.stack(loss_rn, dim=0).mean()

            # total loss
            loss_g = loss_rn

            optimizerG.zero_grad()
            loss_g.backward()
            optimizerG.step()

            # error metrics
            with torch.no_grad():
                err_metrics_batch_i = metric.compute_err_metrics_batch(
                    outputs[0] * 255.0,
                    img_gt[0] * 255.0,
                    alpha_map_central,
                    compute_ssim=False)

            # tensorboard scalar logs of training data
            writer.add_scalar("loss_g", loss_g, iter)
            writer.add_scalar("loss_rn", loss_rn, iter)
            writer.add_scalar("final_mae_valid",
                              err_metrics_batch_i['mae_valid_mean'], iter)
            writer.add_scalar("final_psnr_valid",
                              err_metrics_batch_i['psnr_valid_mean'], iter)

            end = time.time()
            print(
                "Iter %07d   Epoch %03d   loss_g %0.4f   mae_valid %0.4f   psnr_valid %0.4f   t_total %0.4f"
                % (iter, epoch, loss_g, err_metrics_batch_i['mae_valid_mean'],
                   err_metrics_batch_i['psnr_valid_mean'], end - start))

            # tensorboard figure logs of training data
            if not iter % opt.log_freq:
                output_final_vs_gt = []
                for i in range(len(view_trgt)):
                    output_final_vs_gt.append(outputs[i].clamp(min=0., max=1.))
                    output_final_vs_gt.append(img_gt[i].clamp(min=0., max=1.))
                    output_final_vs_gt.append(
                        (outputs[i] - img_gt[i]).abs().clamp(min=0., max=1.))
                output_final_vs_gt = torch.cat(output_final_vs_gt, dim=0)
                writer.add_image(
                    "output_final_vs_gt",
                    torchvision.utils.make_grid(
                        output_final_vs_gt,
                        nrow=outputs[0].shape[0],
                        range=(0, 1),
                        scale_each=False,
                        normalize=False).cpu().detach().numpy(), iter)

            # validation
            if not iter % opt.val_freq:
                start_val = time.time()
                with torch.no_grad():
                    # error metrics
                    err_metrics_val = {}
                    err_metrics_val['mae_valid'] = []
                    err_metrics_val['mse_valid'] = []
                    err_metrics_val['psnr_valid'] = []
                    err_metrics_val['ssim_valid'] = []
                    # loop over batches
                    batch_id = 0
                    for view_val_trgt in view_val_dataloader:
                        start_val_i = time.time()

                        # get view data
                        uv_map = view_val_trgt[0]['uv_map'].to(
                            device)  # [N, H, W, 2]
                        sh_basis_map = view_val_trgt[0]['sh_basis_map'].to(
                            device)  # [N, H, W, 9]
                        alpha_map = view_val_trgt[0][
                            'alpha_map'][:,
                                         None, :, :].to(device)  # [N, 1, H, W]
                        view_idx = view_val_trgt[0]['idx']

                        batch_size = alpha_map.shape[0]
                        img_h = alpha_map.shape[2]
                        img_w = alpha_map.shape[3]
                        num_view = len(view_val_trgt)
                        img_gt = []
                        for i in range(num_view):
                            img_gt.append(
                                view_val_trgt[i]['img_gt'].to(device))

                        # sample texture
                        neural_img = texture_mapper(uv_map, sh_basis_map)

                        # rendering net
                        outputs = render_net(neural_img, None)
                        img_max_val = 2.0
                        outputs = (outputs * 0.5 + 0.5
                                   ) * img_max_val  # map to [0, img_max_val]
                        if not isinstance(outputs, list):
                            outputs = [outputs]

                        # apply alpha
                        for i in range(num_view):
                            outputs[i] = outputs[i] * alpha_map
                            img_gt[i] = img_gt[i] * alpha_map

                        # tensorboard figure logs of validation data
                        if batch_id == val_log_batch_id:
                            output_final_vs_gt = []
                            for i in range(num_view):
                                output_final_vs_gt.append(outputs[i].clamp(
                                    min=0., max=1.))
                                output_final_vs_gt.append(img_gt[i].clamp(
                                    min=0., max=1.))
                                output_final_vs_gt.append(
                                    (outputs[i] - img_gt[i]).abs().clamp(
                                        min=0., max=1.))
                            output_final_vs_gt = torch.cat(output_final_vs_gt,
                                                           dim=0)
                            writer.add_image(
                                "output_final_vs_gt_val",
                                torchvision.utils.make_grid(
                                    output_final_vs_gt,
                                    nrow=batch_size,
                                    range=(0, 1),
                                    scale_each=False,
                                    normalize=False).cpu().detach().numpy(),
                                iter)

                        # error metrics
                        err_metrics_batch_i_final = metric.compute_err_metrics_batch(
                            outputs[0] * 255.0,
                            img_gt[0] * 255.0,
                            alpha_map,
                            compute_ssim=True)

                        for i in range(batch_size):
                            for key in list(err_metrics_val.keys()):
                                if key in err_metrics_batch_i_final.keys():
                                    err_metrics_val[key].append(
                                        err_metrics_batch_i_final[key][i])

                        # save images
                        for i in range(batch_size):
                            cv2.imwrite(
                                os.path.join(
                                    val_out_dir,
                                    str(iter).zfill(8) +
                                    '_' + str(view_idx[i].cpu().detach().numpy(
                                    )).zfill(5) + '.png'),
                                outputs[0][i, :].permute(
                                    (1, 2,
                                     0)).cpu().detach().numpy()[:, :, ::-1] *
                                255.)
                            cv2.imwrite(
                                os.path.join(
                                    val_err_dir,
                                    str(iter).zfill(8) +
                                    '_' + str(view_idx[i].cpu().detach().numpy(
                                    )).zfill(5) + '.png'),
                                (outputs[0] - img_gt[0]).abs().clamp(
                                    min=0., max=1.)[i, :].permute(
                                        (1, 2,
                                         0)).cpu().detach().numpy()[:, :, ::-1]
                                * 255.)
                            # Ground truth is identical across validation
                            # passes, so only write it once.
                            if first_val:
                                cv2.imwrite(
                                    os.path.join(
                                        val_gt_dir,
                                        str(view_idx[i].cpu().detach().numpy()
                                            ).zfill(5) + '.png'),
                                    img_gt[0][i, :].permute(
                                        (1, 2,
                                         0)).cpu().detach().numpy()[:, :, ::-1]
                                    * 255.)

                        end_val_i = time.time()
                        print(
                            "Val   batch %03d   mae_valid %0.4f   psnr_valid %0.4f   ssim_valid %0.4f   t_total %0.4f"
                            % (batch_id,
                               err_metrics_batch_i_final['mae_valid_mean'],
                               err_metrics_batch_i_final['psnr_valid_mean'],
                               err_metrics_batch_i_final['ssim_valid_mean'],
                               end_val_i - start_val_i))

                        batch_id += 1

                    for key in list(err_metrics_val.keys()):
                        if err_metrics_val[key]:
                            err_metrics_val[key] = np.vstack(
                                err_metrics_val[key])
                            err_metrics_val[
                                key + '_mean'] = err_metrics_val[key].mean()
                        else:
                            err_metrics_val[key + '_mean'] = np.nan

                    # tensorboard scalar logs of validation data
                    writer.add_scalar("final_mae_valid_val",
                                      err_metrics_val['mae_valid_mean'], iter)
                    writer.add_scalar("final_psnr_valid_val",
                                      err_metrics_val['psnr_valid_mean'], iter)
                    writer.add_scalar("final_ssim_valid_val",
                                      err_metrics_val['ssim_valid_mean'], iter)

                    first_val = False
                    # Cycle which validation batch gets image logs; guard
                    # against division by zero when the val loader is empty.
                    val_log_batch_id = (val_log_batch_id + 1) % max(
                        batch_id, 1)

                    end_val = time.time()
                    print(
                        "Val   mae_valid %0.4f   psnr_valid %0.4f   ssim_valid %0.4f   t_total %0.4f"
                        % (err_metrics_val['mae_valid_mean'],
                           err_metrics_val['psnr_valid_mean'],
                           err_metrics_val['ssim_valid_mean'],
                           end_val - start_val))

            iter += 1

            if iter % opt.ckp_freq == 0:
                util.custom_save(
                    os.path.join(log_dir,
                                 'model_epoch-%d_iter-%s.pth' % (epoch, iter)),
                    part_list, part_name_list)

    # Final checkpoint after the last epoch.
    util.custom_save(
        os.path.join(log_dir, 'model_epoch-%d_iter-%s.pth' % (epoch, iter)),
        part_list, part_name_list)