# Imports assumed by this excerpt; `opt` is the parsed command-line options
# object defined at module scope elsewhere in the script. The module paths
# for dataio, util, and SRNsModel are assumptions.
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import dataio
import util
from srns import SRNsModel


def test():
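    # Parse indices of specific observations from a comma-separated list.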
    if opt.specific_observation_idcs is not None:
        specific_observation_idcs = util.parse_comma_separated_integers(
            opt.specific_observation_idcs)
    else:
        specific_observation_idcs = None

    dataset = dataio.SceneClassDataset(
        root_dir=opt.data_root,
        max_num_instances=opt.max_num_instances,
        specific_observation_idcs=specific_observation_idcs,
        max_observations_per_instance=-1,
        samples_per_instance=1,
        img_sidelength=opt.img_sidelength)
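    # Wrap the dataset in a loader that yields one observation at a time, in
    # dataset order.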
    dataloader = DataLoader(dataset,
                            collate_fn=dataset.collate_fn,
                            batch_size=1,
                            shuffle=False,
                            drop_last=False)

    model = SRNsModel(num_instances=opt.num_instances,
                      latent_dim=opt.embedding_size,
                      has_params=opt.has_params,
                      fit_single_srn=opt.fit_single_srn,
                      use_unet_renderer=opt.use_unet_renderer,
                      tracing_steps=opt.tracing_steps)

    assert (opt.checkpoint_path is not None), "A checkpoint path must be passed for testing."

    print("Loading model from %s" % opt.checkpoint_path)
    util.custom_load(model,
                     path=opt.checkpoint_path,
                     discriminator=None,
                     overwrite_embeddings=False)

    model.eval()
    model.cuda()

    # Output directory structure: renderings/ and gt_comparisons/ under the
    # logging root.
    renderings_dir = os.path.join(opt.logging_root, 'renderings')
    gt_comparison_dir = os.path.join(opt.logging_root, 'gt_comparisons')
    util.cond_mkdir(opt.logging_root)
    util.cond_mkdir(gt_comparison_dir)
    util.cond_mkdir(renderings_dir)

    # Save command-line parameters to the log directory.
    with open(os.path.join(opt.logging_root, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(
            ["%s: %s" % (key, value) for key, value in vars(opt).items()]))

    print('Beginning evaluation...')
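    # Inference only: disable gradient tracking to save memory.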
    with torch.no_grad():
        instance_idx = 0
        idx = 0
        psnrs, ssims = list(), list()
        for model_input, ground_truth in dataloader:
            model_outputs = model(model_input)
            psnr, ssim = model.get_psnr(model_outputs, ground_truth)

            psnrs.extend(psnr)
            ssims.extend(ssim)

            instance_idcs = model_input['instance_idx']
            print("Object instance %d. Running mean PSNR %0.6f SSIM %0.6f" %
                  (instance_idcs[-1], np.mean(psnrs), np.mean(ssims)))

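            # Only write images to disk for the first opt.save_out_first_n
            # instances.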
            if instance_idx < opt.save_out_first_n:
                output_imgs = model.get_output_img(model_outputs).cpu().numpy()
                comparisons = model.get_comparisons(model_input, model_outputs,
                                                    ground_truth)
                for i in range(len(output_imgs)):
                    prev_instance_idx = instance_idx
                    instance_idx = instance_idcs[i]

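                    # Reset the per-instance image counter when a new
                    # instance begins.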
                    if prev_instance_idx != instance_idx:
                        idx = 0

                    img_only_path = os.path.join(renderings_dir,
                                                 "%06d" % instance_idx)
                    comp_path = os.path.join(gt_comparison_dir,
                                             "%06d" % instance_idx)

                    util.cond_mkdir(img_only_path)
                    util.cond_mkdir(comp_path)

                    pred = util.convert_image(output_imgs[i].squeeze())
                    comp = util.convert_image(comparisons[i].squeeze())

                    util.write_img(
                        pred, os.path.join(img_only_path, "%06d.png" % idx))
                    util.write_img(comp,
                                   os.path.join(comp_path, "%06d.png" % idx))

                    idx += 1

    with open(os.path.join(opt.logging_root, "results.txt"), "w") as out_file:
        out_file.write("%0.6f, %0.6f" % (np.mean(psnrs), np.mean(ssims)))

    print("Final mean PSNR %0.6f SSIM %0.6f" %
          (np.mean(psnrs), np.mean(ssims)))


def train():
    # Parse indices of specific observations from a comma-separated list.
    if opt.specific_observation_idcs is not None:
        specific_observation_idcs = util.parse_comma_separated_integers(
            opt.specific_observation_idcs)
    else:
        specific_observation_idcs = None

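    # The three comma-separated lists below define a coarse-to-fine schedule:
    # at img_sidelengths[i], training runs for max_steps_per_sidelength[i]
    # steps with batch size batch_size_per_sidelength[i].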
    img_sidelengths = util.parse_comma_separated_integers(opt.img_sidelengths)
    batch_size_per_sidelength = util.parse_comma_separated_integers(
        opt.batch_size_per_img_sidelength)
    max_steps_per_sidelength = util.parse_comma_separated_integers(
        opt.max_steps_per_img_sidelength)

    train_dataset = dataio.SceneClassDataset(
        root_dir=opt.data_root,
        max_num_instances=opt.max_num_instances_train,
        max_observations_per_instance=opt.max_num_observations_train,
        img_sidelength=img_sidelengths[0],
        specific_observation_idcs=specific_observation_idcs,
        samples_per_instance=1)

    assert (len(img_sidelengths) == len(batch_size_per_sidelength)), \
        "Different number of image sidelengths passed than batch sizes."
    assert (len(img_sidelengths) == len(max_steps_per_sidelength)), \
        "Different number of image sidelengths passed than max steps."

    if not opt.no_validation:
        assert (opt.val_root is not None), "No validation directory passed."

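        # Validation is served at the initial (coarsest) sidelength and is
        # not re-scaled during training.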
        val_dataset = dataio.SceneClassDataset(
            root_dir=opt.val_root,
            max_num_instances=opt.max_num_instances_val,
            max_observations_per_instance=opt.max_num_observations_val,
            img_sidelength=img_sidelengths[0],
            samples_per_instance=1)
        val_dataloader = DataLoader(val_dataset,
                                    batch_size=2,
                                    shuffle=False,
                                    drop_last=True,
                                    collate_fn=val_dataset.collate_fn)

    model = SRNsModel(num_instances=train_dataset.num_instances,
                      latent_dim=opt.embedding_size,
                      has_params=opt.has_params,
                      fit_single_srn=opt.fit_single_srn,
                      use_unet_renderer=opt.use_unet_renderer,
                      tracing_steps=opt.tracing_steps,
                      freeze_networks=opt.freeze_networks)
    model.train()
    model.cuda()

    if opt.checkpoint_path is not None:
        print("Loading model from %s" % opt.checkpoint_path)
        util.custom_load(model,
                         path=opt.checkpoint_path,
                         discriminator=None,
                         optimizer=None,
                         overwrite_embeddings=opt.overwrite_embeddings)

    ckpt_dir = os.path.join(opt.logging_root, 'checkpoints')
    events_dir = os.path.join(opt.logging_root, 'events')

    util.cond_mkdir(opt.logging_root)
    util.cond_mkdir(ckpt_dir)
    util.cond_mkdir(events_dir)

    # Save command-line parameters to the log directory.
    with open(os.path.join(opt.logging_root, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(
            ["%s: %s" % (key, value) for key, value in vars(opt).items()]))

    # Save text summary of model into log directory.
    with open(os.path.join(opt.logging_root, "model.txt"), "w") as out_file:
        out_file.write(str(model))

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    writer = SummaryWriter(events_dir)
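    # Global step counter across all sidelength stages; resumes from
    # opt.start_step when continuing from a checkpoint.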
    global_step = opt.start_step
    epoch = global_step // len(train_dataset)

    print('Beginning training...')
    # This loop implements training with an increasing image sidelength.
    cum_max_steps = 0  # Tracks max_steps cumulatively over all image sidelengths.
    for img_sidelength, max_steps, batch_size in zip(
            img_sidelengths, max_steps_per_sidelength,
            batch_size_per_sidelength):
        print("\n" + "#" * 10)
        print("Training with sidelength %d for %d steps with batch size %d" %
              (img_sidelength, max_steps, batch_size))
        print("#" * 10 + "\n")
        train_dataset.set_img_sidelength(img_sidelength)

        # Need to instantiate DataLoader every time to set new batch size.
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      drop_last=True,
                                      collate_fn=train_dataset.collate_fn,
                                      pin_memory=opt.preload)

        cum_max_steps += max_steps

        # Loops over epochs.
        while True:
            for model_input, ground_truth in train_dataloader:
                model_outputs = model(model_input)

                optimizer.zero_grad()

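                # Three loss terms: image reconstruction, depth/geometry
                # regularization, and a prior over the latent codes.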
                dist_loss = model.get_image_loss(model_outputs, ground_truth)
                reg_loss = model.get_regularization_loss(
                    model_outputs, ground_truth)
                latent_loss = model.get_latent_loss()

                weighted_dist_loss = opt.l1_weight * dist_loss
                weighted_reg_loss = opt.reg_weight * reg_loss
                weighted_latent_loss = opt.kl_weight * latent_loss

                total_loss = (weighted_dist_loss + weighted_reg_loss +
                              weighted_latent_loss)

                total_loss.backward()

                optimizer.step()

                print(
                    "Iter %07d   Epoch %03d   L_img %0.4f   L_latent %0.4f   L_depth %0.4f"
                    % (global_step, epoch, weighted_dist_loss,
                       weighted_latent_loss, weighted_reg_loss))

                model.write_updates(writer, model_outputs, ground_truth,
                                    global_step)
                writer.add_scalar("scaled_distortion_loss",
                                  weighted_dist_loss, global_step)
                writer.add_scalar("scaled_regularization_loss",
                                  weighted_reg_loss, global_step)
                writer.add_scalar("scaled_latent_loss", weighted_latent_loss,
                                  global_step)
                writer.add_scalar("total_loss", total_loss, global_step)

                if global_step % opt.steps_til_val == 0 and not opt.no_validation:
                    print("Running validation set...")

                    model.eval()
                    with torch.no_grad():
                        psnrs = []
                        ssims = []
                        dist_losses = []
                        for model_input, ground_truth in val_dataloader:
                            model_outputs = model(model_input)

                            dist_loss = model.get_image_loss(
                                model_outputs, ground_truth).cpu().numpy()
                            psnr, ssim = model.get_psnr(
                                model_outputs, ground_truth)
                            psnrs.append(psnr)
                            ssims.append(ssim)
                            dist_losses.append(dist_loss)

                            model.write_updates(writer,
                                                model_outputs,
                                                ground_truth,
                                                global_step,
                                                prefix='val_')

                        writer.add_scalar("val_dist_loss",
                                          np.mean(dist_losses), global_step)
                        writer.add_scalar("val_psnr", np.mean(psnrs),
                                          global_step)
                        writer.add_scalar("val_ssim", np.mean(ssims),
                                          global_step)
                    model.train()

                global_step += 1

                if global_step >= cum_max_steps:
                    break

                if global_step % opt.steps_til_ckpt == 0:
                    util.custom_save(model,
                                     os.path.join(
                                         ckpt_dir, 'epoch_%04d_iter_%06d.pth' %
                                         (epoch, global_step)),
                                     discriminator=None,
                                     optimizer=optimizer)

            if global_step >= cum_max_steps:
                break
            epoch += 1

    util.custom_save(model,
                     os.path.join(
                         ckpt_dir,
                         'epoch_%04d_iter_%06d.pth' % (epoch, global_step)),
                     discriminator=None,
                     optimizer=optimizer)