Example #1
def train():
    # Parses indices of specific observations from a comma-separated list.
    if opt.specific_observation_idcs is not None:
        specific_observation_idcs = util.parse_comma_separated_integers(
            opt.specific_observation_idcs)
    else:
        specific_observation_idcs = None

    img_sidelengths = util.parse_comma_separated_integers(opt.img_sidelengths)
    batch_size_per_sidelength = util.parse_comma_separated_integers(
        opt.batch_size_per_img_sidelength)
    max_steps_per_sidelength = util.parse_comma_separated_integers(
        opt.max_steps_per_img_sidelength)

    train_dataset = dataio.SceneClassDataset(
        root_dir=opt.data_root,
        max_num_instances=opt.max_num_instances_train,
        max_observations_per_instance=opt.max_num_observations_train,
        img_sidelength=img_sidelengths[0],
        specific_observation_idcs=specific_observation_idcs,
        samples_per_instance=1)

    assert (len(img_sidelengths) == len(batch_size_per_sidelength)), \
        "Different number of image sidelengths passed than batch sizes."
    assert (len(img_sidelengths) == len(max_steps_per_sidelength)), \
        "Different number of image sidelengths passed than max steps."

    if not opt.no_validation:
        assert (opt.val_root is not None), "No validation directory passed."

        val_dataset = dataio.SceneClassDataset(
            root_dir=opt.val_root,
            max_num_instances=opt.max_num_instances_val,
            max_observations_per_instance=opt.max_num_observations_val,
            img_sidelength=img_sidelengths[0],
            samples_per_instance=1)
        val_dataloader = DataLoader(val_dataset,
                                    batch_size=2,
                                    shuffle=False,
                                    drop_last=True,
                                    collate_fn=val_dataset.collate_fn)

    model = SRNsModel(num_instances=train_dataset.num_instances,
                      latent_dim=opt.embedding_size,
                      has_params=opt.has_params,
                      fit_single_srn=opt.fit_single_srn,
                      use_unet_renderer=opt.use_unet_renderer,
                      tracing_steps=opt.tracing_steps,
                      freeze_networks=opt.freeze_networks)
    model.train()
    model.cuda()

    if opt.checkpoint_path is not None:
        print("Loading model from %s" % opt.checkpoint_path)
        util.custom_load(model,
                         path=opt.checkpoint_path,
                         discriminator=None,
                         optimizer=None,
                         overwrite_embeddings=opt.overwrite_embeddings)

    ckpt_dir = os.path.join(opt.logging_root, 'checkpoints')
    events_dir = os.path.join(opt.logging_root, 'events')

    util.cond_mkdir(opt.logging_root)
    util.cond_mkdir(ckpt_dir)
    util.cond_mkdir(events_dir)

    # Save command-line parameters to the log directory.
    with open(os.path.join(opt.logging_root, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(
            ["%s: %s" % (key, value) for key, value in vars(opt).items()]))

    # Save text summary of model into log directory.
    with open(os.path.join(opt.logging_root, "model.txt"), "w") as out_file:
        out_file.write(str(model))

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    writer = SummaryWriter(events_dir)
    iter = opt.start_step
    epoch = iter // len(train_dataset)
    step = 0

    print('Beginning training...')
    # This loop implements training with an increasing image sidelength.
    cum_max_steps = 0  # Tracks max_steps cumulatively over all image sidelengths.
    for img_sidelength, max_steps, batch_size in zip(
            img_sidelengths, max_steps_per_sidelength,
            batch_size_per_sidelength):
        print("\n" + "#" * 10)
        print("Training with sidelength %d for %d steps with batch size %d" %
              (img_sidelength, max_steps, batch_size))
        print("#" * 10 + "\n")
        train_dataset.set_img_sidelength(img_sidelength)

        # Need to instantiate DataLoader every time to set new batch size.
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      drop_last=True,
                                      collate_fn=train_dataset.collate_fn,
                                      pin_memory=opt.preload)

        cum_max_steps += max_steps

        # Loops over epochs.
        while True:
            for model_input, ground_truth in train_dataloader:
                model_outputs = model(model_input)

                optimizer.zero_grad()

                dist_loss = model.get_image_loss(model_outputs, ground_truth)
                reg_loss = model.get_regularization_loss(
                    model_outputs, ground_truth)
                latent_loss = model.get_latent_loss()

                weighted_dist_loss = opt.l1_weight * dist_loss
                weighted_reg_loss = opt.reg_weight * reg_loss
                weighted_latent_loss = opt.kl_weight * latent_loss

                total_loss = (weighted_dist_loss + weighted_reg_loss +
                              weighted_latent_loss)

                total_loss.backward()

                optimizer.step()

                print(
                    "Iter %07d   Epoch %03d   L_img %0.4f   L_latent %0.4f   L_depth %0.4f"
                    % (iter, epoch, weighted_dist_loss, weighted_latent_loss,
                       weighted_reg_loss))

                model.write_updates(writer, model_outputs, ground_truth, iter)
                writer.add_scalar("scaled_distortion_loss", weighted_dist_loss,
                                  iter)
                writer.add_scalar("scaled_regularization_loss",
                                  weighted_reg_loss, iter)
                writer.add_scalar("scaled_latent_loss", weighted_latent_loss,
                                  iter)
                writer.add_scalar("total_loss", total_loss, iter)

                if iter % opt.steps_til_val == 0 and not opt.no_validation:
                    print("Running validation set...")

                    model.eval()
                    with torch.no_grad():
                        psnrs = []
                        ssims = []
                        dist_losses = []
                        for model_input, ground_truth in val_dataloader:
                            model_outputs = model(model_input)

                            dist_loss = model.get_image_loss(
                                model_outputs, ground_truth).cpu().numpy()
                            psnr, ssim = model.get_psnr(
                                model_outputs, ground_truth)
                            psnrs.append(psnr)
                            ssims.append(ssim)
                            dist_losses.append(dist_loss)

                            model.write_updates(writer,
                                                model_outputs,
                                                ground_truth,
                                                iter,
                                                prefix='val_')

                        writer.add_scalar("val_dist_loss",
                                          np.mean(dist_losses), iter)
                        writer.add_scalar("val_psnr", np.mean(psnrs), iter)
                        writer.add_scalar("val_ssim", np.mean(ssims), iter)
                    model.train()

                iter += 1
                step += 1

                if iter == cum_max_steps:
                    break

                if iter % opt.steps_til_ckpt == 0:
                    util.custom_save(model,
                                     os.path.join(
                                         ckpt_dir, 'epoch_%04d_iter_%06d.pth' %
                                         (epoch, iter)),
                                     discriminator=None,
                                     optimizer=optimizer)

            if iter == cum_max_steps:
                break
            epoch += 1

    util.custom_save(model,
                     os.path.join(ckpt_dir,
                                  'epoch_%04d_iter_%06d.pth' % (epoch, iter)),
                     discriminator=None,
                     optimizer=optimizer)
Example #2
def test():
    if opt.specific_observation_idcs is not None:
        specific_observation_idcs = list(
            map(int, opt.specific_observation_idcs.split(',')))
    else:
        specific_observation_idcs = None

    dataset = dataio.SceneClassDataset(
        root_dir=opt.data_root,
        max_num_instances=opt.max_num_instances,
        specific_observation_idcs=specific_observation_idcs,
        max_observations_per_instance=-1,
        samples_per_instance=1,
        img_sidelength=opt.img_sidelength)
    dataloader = DataLoader(dataset,
                            collate_fn=dataset.collate_fn,
                            batch_size=1,
                            shuffle=False,
                            drop_last=False)

    model = SRNsModel(num_instances=opt.num_instances,
                      latent_dim=opt.embedding_size,
                      has_params=opt.has_params,
                      fit_single_srn=opt.fit_single_srn,
                      use_unet_renderer=opt.use_unet_renderer,
                      tracing_steps=opt.tracing_steps)

    assert (opt.checkpoint_path is not None), "Have to pass checkpoint!"

    print("Loading model from %s" % opt.checkpoint_path)
    util.custom_load(model,
                     path=opt.checkpoint_path,
                     discriminator=None,
                     overwrite_embeddings=False)

    model.eval()
    model.cuda()

    # Output directories for renderings and ground-truth comparisons.
    renderings_dir = os.path.join(opt.logging_root, 'renderings')
    gt_comparison_dir = os.path.join(opt.logging_root, 'gt_comparisons')
    util.cond_mkdir(opt.logging_root)
    util.cond_mkdir(gt_comparison_dir)
    util.cond_mkdir(renderings_dir)

    # Save command-line parameters to log directory.
    with open(os.path.join(opt.logging_root, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(
            ["%s: %s" % (key, value) for key, value in vars(opt).items()]))

    print('Beginning evaluation...')
    with torch.no_grad():
        instance_idx = 0
        idx = 0
        psnrs, ssims = list(), list()
        for model_input, ground_truth in dataloader:
            model_outputs = model(model_input)
            psnr, ssim = model.get_psnr(model_outputs, ground_truth)

            psnrs.extend(psnr)
            ssims.extend(ssim)

            instance_idcs = model_input['instance_idx']
            print("Object instance %d. Running mean PSNR %0.6f SSIM %0.6f" %
                  (instance_idcs[-1], np.mean(psnrs), np.mean(ssims)))

            if instance_idx < opt.save_out_first_n:
                output_imgs = model.get_output_img(model_outputs).cpu().numpy()
                comparisons = model.get_comparisons(model_input, model_outputs,
                                                    ground_truth)
                for i in range(len(output_imgs)):
                    prev_instance_idx = instance_idx
                    instance_idx = instance_idcs[i]

                    if prev_instance_idx != instance_idx:
                        idx = 0

                    img_only_path = os.path.join(renderings_dir,
                                                 "%06d" % instance_idx)
                    comp_path = os.path.join(gt_comparison_dir,
                                             "%06d" % instance_idx)

                    util.cond_mkdir(img_only_path)
                    util.cond_mkdir(comp_path)

                    pred = util.convert_image(output_imgs[i].squeeze())
                    comp = util.convert_image(comparisons[i].squeeze())

                    util.write_img(
                        pred, os.path.join(img_only_path, "%06d.png" % idx))
                    util.write_img(comp,
                                   os.path.join(comp_path, "%06d.png" % idx))

                    idx += 1

    with open(os.path.join(opt.logging_root, "results.txt"), "w") as out_file:
        out_file.write("%0.6f, %0.6f" % (np.mean(psnrs), np.mean(ssims)))

    print("Final mean PSNR %0.6f SSIM %0.6f" %
          (np.mean(psnrs), np.mean(ssims)))
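
Both train() and test() above lean on small helpers from the project's util module that the snippets never define. A minimal sketch of what two of them plausibly do, inferred from their call sites (the real implementations may differ):

    import os

    def parse_comma_separated_integers(string):
        # "64,128" -> [64, 128]
        return [int(token) for token in string.split(',')]

    def cond_mkdir(path):
        # Create the directory only if it does not already exist.
        if not os.path.exists(path):
            os.makedirs(path)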
Example #3
def test():
    # Create the test dataset loader.
    dataset = TestDataset(pose_dir=os.path.join(opt.data_root, 'pose'))

    util.custom_load(model, opt.checkpoint)
    model.eval()

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=4)

    dir_name = os.path.join(
        datetime.datetime.now().strftime('%m_%d'),
        datetime.datetime.now().strftime('%H-%M-%S_') +
        '_'.join(opt.checkpoint.strip('/').split('/')[-2:]) + '_' +
        opt.data_root.strip('/').split('/')[-1])

    traj_dir = os.path.join(opt.logging_root, 'test_traj', dir_name)
    depth_dir = os.path.join(traj_dir, 'depth')

    data_util.cond_mkdir(traj_dir)
    data_util.cond_mkdir(depth_dir)

    forward_time = 0.

    print('starting testing...')
    with torch.no_grad():
        iter = 0
        depth_imgs = []
        for trgt_pose in dataloader:
            trgt_pose = trgt_pose.squeeze().to(device)

            start = time.time()
            # compute projection mapping
            proj_mapping = projection.compute_proj_idcs(
                trgt_pose, grid_origin)
            if proj_mapping is None:  # invalid sample
                print('(invalid sample)')
                continue

            proj_ind_3d, proj_ind_2d = proj_mapping

            # Run through model
            output, depth_maps = model(None, [proj_ind_3d], [proj_ind_2d],
                                       None, None, None)
            end = time.time()
            forward_time += end - start

            output[0] = output[0][:, :, 5:-5, 5:-5]
            print("Iter %d" % iter)

            output_img = output[0].squeeze().cpu().detach().numpy()
            output_img = output_img.transpose(1, 2, 0)
            output_img += 0.5
            output_img *= 2**16 - 1
            output_img = output_img.round().clip(0, 2**16 - 1)

            depth_img = depth_maps[0].squeeze(0).cpu().detach().numpy()
            depth_img = depth_img.transpose(1, 2, 0)
            depth_imgs.append(depth_img)

            cv2.imwrite(os.path.join(traj_dir, "img_%05d.png" % iter),
                        output_img.astype(np.uint16)[:, :, ::-1])

            iter += 1

        depth_imgs = np.stack(depth_imgs, axis=0)
        depth_imgs = (depth_imgs - np.amin(depth_imgs)) / (
            np.amax(depth_imgs) - np.amin(depth_imgs))
        depth_imgs *= 2**16 - 1
        depth_imgs = depth_imgs.round()

        for i in range(len(depth_imgs)):
            cv2.imwrite(os.path.join(depth_dir, "img_%05d.png" % i),
                        depth_imgs[i].astype(np.uint16))

    print("Average forward pass time over %d examples is %f" %
          (iter, forward_time / iter))
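
The depth-writing block above normalizes the whole stack jointly, so every frame shares one gray scale before 16-bit quantization. The same logic as a standalone helper (illustrative only, not part of the source):

    import numpy as np

    def depth_stack_to_uint16(depth_imgs):
        # Joint min-max normalization across all frames, then quantization
        # to the full 16-bit range.
        stack = np.stack(depth_imgs, axis=0).astype(np.float64)
        stack = (stack - stack.min()) / (stack.max() - stack.min())
        return (stack * (2**16 - 1)).round().astype(np.uint16)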
Example #4
def train():
    # Parses indices of specific observations from a comma-separated list.
    if opt.specific_observation_idcs is not None:
        specific_observation_idcs = util.parse_comma_separated_integers(
            opt.specific_observation_idcs)
    else:
        specific_observation_idcs = None

    img_sidelengths = util.parse_comma_separated_integers(opt.img_sidelengths)
    batch_size_per_sidelength = util.parse_comma_separated_integers(
        opt.batch_size_per_img_sidelength)
    max_steps_per_sidelength = util.parse_comma_separated_integers(
        opt.max_steps_per_img_sidelength)

    train_dataset = dataio.PBWDataset(train=True)

    assert (len(img_sidelengths) == len(batch_size_per_sidelength)), \
        "Different number of image sidelengths passed than batch sizes."
    assert (len(img_sidelengths) == len(max_steps_per_sidelength)), \
        "Different number of image sidelengths passed than max steps."

    if not opt.no_validation:
        assert (opt.val_root is not None), "No validation directory passed."

        val_dataset = dataio.PBWDataset(train=False)
        val_dataloader = DataLoader(val_dataset,
                                    batch_size=16,
                                    shuffle=False,
                                    drop_last=True,
                                    collate_fn=val_dataset.collate_fn)

    model = SRNsModel3(latent_dim=opt.embedding_size,
                       has_params=opt.has_params,
                       fit_single_srn=True,
                       tracing_steps=opt.tracing_steps,
                       freeze_networks=opt.freeze_networks)
    model.train()
    model.cuda()
    if opt.checkpoint_path is not None:
        print("Loading model from %s" % opt.checkpoint_path)
        util.custom_load(model,
                         path=opt.checkpoint_path,
                         discriminator=None,
                         optimizer=None,
                         overwrite_embeddings=opt.overwrite_embeddings)

    ckpt_dir = os.path.join(opt.logging_root, 'checkpoints')
    events_dir = os.path.join(opt.logging_root, 'events')

    util.cond_mkdir(opt.logging_root)
    util.cond_mkdir(ckpt_dir)
    util.cond_mkdir(events_dir)

    # Save command-line parameters to the log directory.
    with open(os.path.join(opt.logging_root, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(
            ["%s: %s" % (key, value) for key, value in vars(opt).items()]))

    # Save text summary of model into log directory.
    with open(os.path.join(opt.logging_root, "model.txt"), "w") as out_file:
        out_file.write(str(model))

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    writer = SummaryWriter(events_dir)
    iter = opt.start_step
    epoch = iter // len(train_dataset)
    step = 0

    print('Beginning training...')
    # This loop implements training with an increasing image sidelength.
    cum_max_steps = 0  # Tracks max_steps cumulatively over all image sidelengths.
    for img_sidelength, max_steps, batch_size in zip(
            img_sidelengths, max_steps_per_sidelength,
            batch_size_per_sidelength):
        print("\n" + "#" * 10)
        print("Training with sidelength %d for %d steps with batch size %d" %
              (img_sidelength, max_steps, batch_size))
        print("#" * 10 + "\n")

        # Need to instantiate DataLoader every time to set new batch size.
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            collate_fn=train_dataset.collate_fn,
        )

        cum_max_steps += max_steps

        # Loops over epochs.
        while True:
            for batch in train_dataloader:
                rgb, ext_mat, info, rgb_mat = batch
                ground_truth = {"rgb": rgb}
                # model_input: color, pix coord, location, box.
                model_input = (ext_mat, rgb_mat, info)
                model_outputs = model(model_input)
                optimizer.zero_grad()

                total_loss = model.get_image_loss(model_outputs, ground_truth)
                total_loss.backward()

                optimizer.step()
                if iter % 100 == 0:
                    print("Iter %07d   Epoch %03d   L_img %0.4f" %
                          (iter, epoch, total_loss))

                if iter % opt.steps_til_val == 0 and not opt.no_validation:
                    print("Running validation set...")
                    acc = test(model, val_dataloader, str(iter))
                    print("Accuracy:", acc)

                iter += 1
                step += 1

                if iter == cum_max_steps:
                    break

            if iter == cum_max_steps:
                break
            epoch += 1

    util.custom_save(model,
                     os.path.join(ckpt_dir,
                                  'epoch_%04d_iter_%06d.pth' % (epoch, iter)),
                     discriminator=None,
                     optimizer=optimizer)
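
In both train() variants, the zip over the three parsed lists defines a coarse-to-fine curriculum: position i of each list describes one training phase. A worked example with made-up flag values:

    # --img_sidelengths 32,64,128 \
    # --batch_size_per_img_sidelength 64,16,4 \
    # --max_steps_per_img_sidelength 5000,10000,20000
    #
    # Phase 1: sidelength  32, batch size 64, until global step  5000
    # Phase 2: sidelength  64, batch size 16, until global step 15000
    # Phase 3: sidelength 128, batch size  4, until global step 35000
    #
    # cum_max_steps accumulates per phase, so each budget is expressed in
    # global steps; the inner loops break once iter reaches it.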
Example #5
def train():
    discriminator.train()
    model.train()

    if opt.checkpoint:
        util.custom_load(model, opt.checkpoint, discriminator)

    # Create the training dataset loader
    train_dataset = NovelViewTriplets(root_dir=opt.data_root,
                                      img_size=input_image_dims,
                                      sampling_pattern=opt.sampling_pattern)
    dataloader = DataLoader(train_dataset,
                            batch_size=1,
                            shuffle=True,
                            num_workers=8)

    # directory name contains some info about hyperparameters.
    dir_name = os.path.join(
        datetime.datetime.now().strftime('%m_%d'),
        datetime.datetime.now().strftime('%H-%M-%S_') +
        (opt.sampling_pattern + '_') + ('%0.2f_l1_weight_' % opt.l1_weight) +
        ('%d_trgt_' % opt.num_trgt) + '_' +
        opt.data_root.strip('/').split('/')[-1] + opt.experiment_name)

    log_dir = os.path.join(opt.logging_root, 'logs', dir_name)
    run_dir = os.path.join(opt.logging_root, 'runs', dir_name)

    data_util.cond_mkdir(log_dir)
    data_util.cond_mkdir(run_dir)

    # Save all command line arguments into a txt file in the logging directory for later reference.
    with open(os.path.join(log_dir, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(
            ["%s: %s" % (key, value) for key, value in vars(opt).items()]))

    writer = SummaryWriter(run_dir)

    iter = opt.start_epoch * len(train_dataset)

    print('Begin training...')
    for epoch in range(opt.start_epoch, opt.max_epoch):
        for trgt_views, nearest_view in dataloader:
            backproj_mapping = projection.comp_lifting_idcs(
                camera_to_world=nearest_view['pose'].squeeze().to(device),
                grid2world=grid_origin)

            proj_mappings = list()
            for i in range(len(trgt_views)):
                proj_mappings.append(
                    projection.compute_proj_idcs(
                        trgt_views[i]['pose'].squeeze().to(device),
                        grid2world=grid_origin))

            if backproj_mapping is None:
                print("Lifting invalid")
                continue
            else:
                lift_volume_idcs, lift_img_coords = backproj_mapping

            if None in proj_mappings:
                print('Projection invalid')
                continue

            proj_frustrum_idcs, proj_grid_coords = list(zip(*proj_mappings))

            outputs, depth_maps = model(nearest_view['gt_rgb'].to(device),
                                        proj_frustrum_idcs,
                                        proj_grid_coords,
                                        lift_volume_idcs,
                                        lift_img_coords,
                                        writer=writer)

            # Convert the depth maps to metric
            for i in range(len(depth_maps)):
                depth_maps[i] = (
                    (depth_maps[i] + 0.5) *
                    int(np.ceil(np.sqrt(3) * grid_dims[-1])) * voxel_size +
                    near_plane)

            # We don't enforce a loss on the outermost 5 pixels to alleviate boundary errors
            for i in range(len(trgt_views)):
                outputs[i] = outputs[i][:, :, 5:-5, 5:-5]
                trgt_views[i]['gt_rgb'] = trgt_views[i]['gt_rgb'][:, :, 5:-5,
                                                                  5:-5]

            l1_losses = list()
            for idx in range(len(trgt_views)):
                l1_losses.append(
                    criterionL1(
                        outputs[idx].contiguous().view(-1).float(),
                        trgt_views[idx]['gt_rgb'].to(device).view(-1).float()))

            losses_d = []
            losses_g = []

            optimizerD.zero_grad()
            optimizerG.zero_grad()

            for idx in range(len(trgt_views)):
                #######
                ## Train Discriminator
                #######
                out_perm = outputs[idx]  # batch, ndf, height, width

                # Fake forward step
                # Detach to make sure no gradients go into the generator.
                pred_fake = discriminator.forward(out_perm.detach())
                loss_d_fake = criterionGAN(pred_fake, False)

                # Real forward step
                real_input = trgt_views[idx]['gt_rgb'].float().to(device)
                pred_real = discriminator.forward(real_input)
                loss_d_real = criterionGAN(pred_real, True)

                # Combined Loss
                losses_d.append((loss_d_fake + loss_d_real) * 0.5)

                #######
                ## Train generator
                #######
                # Try to fake discriminator
                pred_fake = discriminator.forward(out_perm)
                loss_g_gan = criterionGAN(pred_fake, True)

                loss_g_l1 = l1_losses[idx] * opt.l1_weight
                losses_g.append(loss_g_gan + loss_g_l1)

            loss_d = torch.stack(losses_d, dim=0).mean()
            loss_g = torch.stack(losses_g, dim=0).mean()

            loss_d.backward()
            optimizerD.step()
            loss_g.backward()
            optimizerG.step()

            print(
                "Iter %07d   Epoch %03d   loss_gen %0.4f   loss_discrim %0.4f"
                % (iter, epoch, loss_g, loss_d))

            if not iter % 100:
                # Write tensorboard logs.
                writer.add_image(
                    "Depth",
                    torchvision.utils.make_grid(
                        [
                            depth_map.squeeze(dim=0).repeat(3, 1, 1)
                            for depth_map in depth_maps
                        ],
                        scale_each=True,
                        normalize=True).cpu().detach().numpy(), iter)
                writer.add_image(
                    "Nearest_neighbors_rgb",
                    torchvision.utils.make_grid(
                        nearest_view['gt_rgb'],
                        scale_each=True,
                        normalize=True).detach().numpy(), iter)
                output_vs_gt = torch.cat(
                    (torch.cat(outputs, dim=0),
                     torch.cat([i['gt_rgb'].to(device) for i in trgt_views],
                               dim=0)),
                    dim=0)
                writer.add_image(
                    "Output_vs_gt",
                    torchvision.utils.make_grid(
                        output_vs_gt, scale_each=True,
                        normalize=True).cpu().detach().numpy(), iter)

            writer.add_scalar("out_min", outputs[0].min(), iter)
            writer.add_scalar("out_max", outputs[0].max(), iter)

            writer.add_scalar("trgt_min", trgt_views[0]['gt_rgb'].min(), iter)
            writer.add_scalar("trgt_max", trgt_views[0]['gt_rgb'].max(), iter)

            writer.add_scalar("discrim_loss", loss_d, iter)
            writer.add_scalar("gen_loss_total", loss_g, iter)
            writer.add_scalar("gen_loss_l1", loss_g_l1, iter)
            writer.add_scalar("gen_loss_g", loss_g_gan, iter)

            iter += 1

            if iter % 10000 == 0:
                util.custom_save(
                    model,
                    os.path.join(log_dir,
                                 'model-epoch_%d_iter_%s.pth' % (epoch, iter)),
                    discriminator)

    util.custom_save(
        model,
        os.path.join(log_dir, 'model-epoch_%d_iter_%s.pth' % (epoch, iter)),
        discriminator)
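
criterionGAN above follows the pix2pix convention of scoring the discriminator output against an all-ones or all-zeros target. It is not defined in the snippet; a minimal BCE-based version of that pattern could look like this (an assumption, not the project's actual class):

    import torch
    import torch.nn as nn

    class GANLoss(nn.Module):
        # BCE-style GAN loss that builds the target tensor on the fly.
        def __init__(self):
            super().__init__()
            self.loss = nn.BCEWithLogitsLoss()

        def forward(self, prediction, target_is_real):
            target = (torch.ones_like(prediction) if target_is_real
                      else torch.zeros_like(prediction))
            return self.loss(prediction, target)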
Example #6
                                  use_gcn=False)

# interpolater
interpolater = network.Interpolater()

# L1 loss
criterionL1 = nn.L1Loss(reduction='mean').to(device)

# Optimizer
optimizerG = torch.optim.Adam(list(texture_mapper.parameters()) +
                              list(render_net.parameters()),
                              lr=opt.lr)

# load checkpoint
if opt.checkpoint:
    util.custom_load([texture_mapper, render_net],
                     ['texture_mapper', 'render_net'], opt.checkpoint)

# move to device
texture_mapper.to(device)
render_net.to(device)
interpolater.to(device)

# Keep handles to the underlying modules (before any DataParallel wrapping).
texture_mapper_module = texture_mapper
render_net_module = render_net

# use multi-GPU
if opt.gpu_id != '':
    texture_mapper = nn.DataParallel(texture_mapper)
    render_net = nn.DataParallel(render_net)
    interpolater = nn.DataParallel(interpolater)
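
Keeping texture_mapper_module and render_net_module around before the nn.DataParallel wrap is the usual checkpointing trick: the wrapper stores the real network under .module, and a state_dict saved from the wrapper carries a "module." prefix on every key. Illustratively (key names are hypothetical):

    # wrapped = nn.DataParallel(render_net_module)
    # wrapped.module is render_net_module       -> True
    # wrapped.state_dict() keys:      "module.conv1.weight", ...
    # render_net_module.state_dict(): "conv1.weight", ...
    # Saving from the unwrapped module keeps checkpoints loadable on a
    # single GPU without stripping the prefix.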
Example #7
    input_dims=input_shape,
    hidden_dim=args.hidden_dim,
    num_slots=args.num_slots,
    encoder=args.encoder,
    cnn_size=args.cnn_size,
    trans_model=args.trans_model,
    decoder=args.decoder,
    identity_action=args.identity_action,
    residual=args.residual,
    canonical=args.canonical_rep)
model.to(device)
print('Number of parameters in model', util.count_params(model))

if args.checkpoint_path is not None:
    print("Loading model from %s" % args.checkpoint_path)
    util.custom_load(model, path=args.checkpoint_path)
else:
    print("Initialising random weights")
    model.apply(util.weights_init)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=args.learning_rate)

# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.4)

now = datetime.datetime.now()
timestamp = now.isoformat()

if args.name == 'none':
    exp_name = timestamp
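
util.weights_init is not shown; model.apply() calls it once per submodule, so it is presumably a per-module initializer along these lines (a sketch under that assumption, not the project's code):

    import torch.nn as nn

    def weights_init(m):
        # Called on every submodule via model.apply(weights_init).
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)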
Example #8
num_vertex = mesh.num_vertex

# interpolater
interpolater = network.Interpolater()

# texture mapper
texture_mapper = network.TextureMapper(texture_size=opt.texture_size,
                                       texture_num_ch=opt.texture_num_ch,
                                       mipmap_level=opt.mipmap_level,
                                       texture_init=None,
                                       fix_texture=True,
                                       apply_sh=opt.apply_sh)

# load checkpoint
checkpoint_dict = util.custom_load([texture_mapper], ['texture_mapper'],
                                   checkpoint_fp,
                                   strict=True)

# trained lighting model
new_state_dict = checkpoint_dict['lighting_model']
lighting_model_train = network.LightingSH(l_dir,
                                          lmax=int(params['sh_lmax']),
                                          num_lighting=2,
                                          num_channel=num_channel,
                                          fix_params=True)
lighting_model_train.coeff.data = new_state_dict['coeff']
lighting_model_train.l_samples.data = new_state_dict['l_samples']

# lighting model lp
lighting_model_lp = network.LightingLP(l_dir,
                                       num_channel=num_channel,
Example #9
def test():

    test_dataset = dataio.TwoViewsDataset(
        data_dir=args.test_dir,
        num_pairs_per_scene=args.test_pairs_per_scene,
        num_scenes=args.num_test_scenes,
        sidelength=args.sidelength)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=4)

    print(f'Size of test dataset {len(test_dataset)}')

    obs = next(iter(test_loader))
    data_util.show_batch_pairs(obs)
    input_shape = obs['image1'].size()[1:]

    # Load training params
    with open(args.train_log_dir + '/params.txt', 'r') as f:
        train_params = yaml.safe_load(f)

    model = nod.NodModel(embedding_dim=train_params['embedding_dim'],
                         input_dims=input_shape,
                         hidden_dim=train_params['hidden_dim'],
                         num_slots=train_params['num_slots'],
                         encoder=train_params['encoder'],
                         decoder=train_params['decoder'])

    print("Loading model from %s" % args.checkpoint_path)
    util.custom_load(model, path=args.checkpoint_path)
    print("Evaluation to be saved to %s" % args.results_dir)

    model.to(device)
    model.eval()

    gt_comparison_dir = os.path.join(args.results_dir, 'gt_comparisons')
    sv_comps_dir = os.path.join(args.results_dir, 'components_same_view')
    dv_comps_dir = os.path.join(args.results_dir, 'components_diff_view')
    util.cond_mkdir(args.results_dir)
    util.cond_mkdir(gt_comparison_dir)
    util.cond_mkdir(sv_comps_dir)
    util.cond_mkdir(dv_comps_dir)

    # Save command-line parameters to log directory.
    with open(os.path.join(args.results_dir, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(
            ["%s: %s" % (key, value) for key, value in vars(args).items()]))

    l2_loss = nn.MSELoss(reduction="mean")

    print('Beginning evaluation...')
    with torch.no_grad():
        same_view_losses = []
        diff_view_losses = []
        total_losses = []
        for batch_idx, data_batch in enumerate(test_loader):
            img1, img2 = data_batch['image1'].to(
                device), data_batch['image2'].to(device)
            batch_size = img1.shape[0]
            imgs = torch.cat((img1, img2), dim=0)
            w, h = imgs.size(-2), imgs.size(-1)
            images_gt = torch.cat((img1.unsqueeze(1), img2.unsqueeze(1)),
                                  dim=1)

            action1, action2 = data_batch['transf21'].to(
                device), data_batch['transf12'].to(device)
            actions = torch.cat((action1, action2), dim=0)

            out = model(imgs, actions)
            masks, masked_comps, recs = model.compose_image(out)

            rec_views = recs[:batch_size * 2]
            novel_views = recs[batch_size * 2:]

            same_view_loss = l2_loss(rec_views, imgs)
            novel_view_loss = l2_loss(novel_views, imgs)
            total_loss = same_view_loss + novel_view_loss
            same_view_losses.append(same_view_loss.item())
            diff_view_losses.append(novel_view_loss.item())
            total_losses.append(total_loss.item())
            print(
                f"Number input images {batch_idx * args.batch_size}  |  Running l2 loss: {np.mean(total_losses)}"
            )

            if batch_idx * args.batch_size < args.save_out_first_n:

                rec_views = rec_views.reshape(2, args.batch_size, 3, w,
                                              h).transpose(0, 1)
                novel_views = novel_views.reshape(2, args.batch_size, 3, w,
                                                  h).transpose(0, 1)
                same_view_masked_comps = masked_comps[:args.batch_size *
                                                      2].reshape(
                                                          2, args.batch_size,
                                                          model.num_slots, 3,
                                                          w,
                                                          h).transpose(0, 1)
                diff_view_masked_comps = masked_comps[args.batch_size *
                                                      2:].reshape(
                                                          2, args.batch_size,
                                                          model.num_slots, 3,
                                                          w,
                                                          h).transpose(0, 1)
                same_view_masks = masks[:args.batch_size * 2].reshape(
                    2, args.batch_size, model.num_slots, w, h).transpose(0, 1)
                diff_view_masks = masks[args.batch_size * 2:].reshape(
                    2, args.batch_size, model.num_slots, w, h).transpose(0, 1)
                # Expand to have 3 channels so can concat with rgb images
                same_view_masks = same_view_masks.unsqueeze(3).repeat(
                    1, 1, 1, 3, 1, 1)
                diff_view_masks = diff_view_masks.unsqueeze(3).repeat(
                    1, 1, 1, 3, 1, 1)
                # Shift to be in range [-1, 1] like rgb
                same_view_masks = same_view_masks * 2 - 1
                diff_view_masks = diff_view_masks * 2 - 1

                for i in range(args.batch_size):
                    gt = images_gt[i]
                    same_view_rec = rec_views[i]
                    diff_view_rec = novel_views[i]

                    # Save ground truth reconstruction comparison
                    gt_vs_rec_vs_nv = torch.cat(
                        (gt, same_view_rec, diff_view_rec), dim=0)
                    gt_comparison_imgs = torchvision.utils.make_grid(
                        gt_vs_rec_vs_nv,
                        nrow=2,
                        scale_each=False,
                        normalize=True,
                        range=(-1, 1)).cpu().detach().numpy()
                    plt.imsave(
                        os.path.join(
                            gt_comparison_dir,
                            f'{i + batch_idx * args.batch_size:04d}.png'),
                        np.transpose(gt_comparison_imgs, (1, 2, 0)))

                    # Save components
                    sv_images = torch.cat(
                        (images_gt[i].unsqueeze(1), same_view_rec.unsqueeze(1),
                         same_view_masked_comps[i], same_view_masks[i]),
                        dim=1)
                    dv_images = torch.cat(
                        (images_gt[i].unsqueeze(1), diff_view_rec.unsqueeze(1),
                         diff_view_masked_comps[i], diff_view_masks[i]),
                        dim=1)

                    comps_same_view_images = torchvision.utils.make_grid(
                        sv_images.reshape(-1, 3, h, w),
                        nrow=2 * model.num_slots + 2,
                        scale_each=False,
                        normalize=True,
                        range=(-1, 1)).cpu().detach().numpy()
                    comps_diff_view_images = torchvision.utils.make_grid(
                        dv_images.reshape(-1, 3, h, w),
                        nrow=2 * model.num_slots + 2,
                        scale_each=False,
                        normalize=True,
                        range=(-1, 1)).cpu().detach().numpy()
                    plt.imsave(
                        os.path.join(
                            sv_comps_dir,
                            f'{i + batch_idx * args.batch_size:04d}.png'),
                        np.transpose(comps_same_view_images, (1, 2, 0)))
                    plt.imsave(
                        os.path.join(
                            dv_comps_dir,
                            f'{i + batch_idx * args.batch_size:04d}.png'),
                        np.transpose(comps_diff_view_images, (1, 2, 0)))

        save_circles(model, args.results_dir,
                     args.circle_source_img_path.split())

    with open(os.path.join(args.results_dir, "results.txt"), "w") as out_file:
        out_file.write("Evaluation Metric: score \n\n")
        out_file.write(
            f"Same view rec l2 loss: {np.mean(same_view_losses):10f} \n")
        out_file.write(
            f"Diff view rec l2 loss: {np.mean(diff_view_losses):10f} \n")
        out_file.write(f"Rec l2 loss: {np.mean(total_losses):10f} \n")

    print("\nFinal score: ")