def test():
    """Evaluate a checkpointed SRNsModel on a dataset and save the results.

    Reads its configuration from the module-level ``opt`` namespace. For each
    batch the running mean PSNR/SSIM is printed; while the current instance
    index is below ``opt.save_out_first_n`` the rendered images and the
    ground-truth comparisons are written as PNGs below ``opt.logging_root``.
    The final mean PSNR/SSIM is written to ``results.txt`` and printed.

    Raises:
        AssertionError: if ``opt.checkpoint_path`` is None.
    """
    raw_observation_idcs = opt.specific_observation_idcs
    if raw_observation_idcs is None:
        specific_observation_idcs = None
    else:
        specific_observation_idcs = [
            int(token) for token in raw_observation_idcs.split(',')]

    scene_dataset = dataio.SceneClassDataset(
        root_dir=opt.data_root,
        max_num_instances=opt.max_num_instances,
        specific_observation_idcs=specific_observation_idcs,
        max_observations_per_instance=-1,
        samples_per_instance=1,
        img_sidelength=opt.img_sidelength)
    loader = DataLoader(scene_dataset,
                        collate_fn=scene_dataset.collate_fn,
                        batch_size=1,
                        shuffle=False,
                        drop_last=False)

    model = SRNsModel(num_instances=opt.num_instances,
                      latent_dim=opt.embedding_size,
                      has_params=opt.has_params,
                      fit_single_srn=opt.fit_single_srn,
                      use_unet_renderer=opt.use_unet_renderer,
                      tracing_steps=opt.tracing_steps)

    assert (opt.checkpoint_path is not None), "Have to pass checkpoint!"

    print("Loading model from %s" % opt.checkpoint_path)
    util.custom_load(model,
                     path=opt.checkpoint_path,
                     discriminator=None,
                     overwrite_embeddings=False)

    model.eval()
    model.cuda()

    # Output layout: <logging_root>/renderings and <logging_root>/gt_comparisons.
    renderings_dir = os.path.join(opt.logging_root, 'renderings')
    gt_comparison_dir = os.path.join(opt.logging_root, 'gt_comparisons')
    for directory in (opt.logging_root, gt_comparison_dir, renderings_dir):
        util.cond_mkdir(directory)

    # Keep a record of the command-line parameters next to the outputs.
    param_lines = ["%s: %s" % (key, value) for key, value in vars(opt).items()]
    with open(os.path.join(opt.logging_root, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(param_lines))

    print('Beginning evaluation...')
    with torch.no_grad():
        instance_idx = 0
        idx = 0  # Per-instance image counter; reset when the instance changes.
        psnrs = []
        ssims = []
        for model_input, ground_truth in loader:
            model_outputs = model(model_input)
            batch_psnr, batch_ssim = model.get_psnr(model_outputs, ground_truth)

            psnrs.extend(batch_psnr)
            ssims.extend(batch_ssim)

            instance_idcs = model_input['instance_idx']
            print("Object instance %d. Running mean PSNR %0.6f SSIM %0.6f" %
                  (instance_idcs[-1], np.mean(psnrs), np.mean(ssims)))

            if not instance_idx < opt.save_out_first_n:
                continue

            output_imgs = model.get_output_img(model_outputs).cpu().numpy()
            comparisons = model.get_comparisons(model_input,
                                                model_outputs,
                                                ground_truth)
            for i, output_img in enumerate(output_imgs):
                prev_instance_idx = instance_idx
                instance_idx = instance_idcs[i]

                if prev_instance_idx != instance_idx:
                    idx = 0

                instance_name = "%06d" % instance_idx
                img_only_path = os.path.join(renderings_dir, instance_name)
                comp_path = os.path.join(gt_comparison_dir, instance_name)

                util.cond_mkdir(img_only_path)
                util.cond_mkdir(comp_path)

                pred = util.convert_image(output_img.squeeze())
                comp = util.convert_image(comparisons[i].squeeze())

                file_name = "%06d.png" % idx
                util.write_img(pred, os.path.join(img_only_path, file_name))
                util.write_img(comp, os.path.join(comp_path, file_name))

                idx += 1

    mean_psnr = np.mean(psnrs)
    mean_ssim = np.mean(ssims)

    with open(os.path.join(opt.logging_root, "results.txt"), "w") as out_file:
        out_file.write("%0.6f, %0.6f" % (mean_psnr, mean_ssim))

    print("Final mean PSNR %0.6f SSIM %0.6f" % (mean_psnr, mean_ssim))
def _save_checkpoint(model, optimizer, ckpt_dir, epoch, global_step):
    """Call ``util.custom_save`` with a path in ``ckpt_dir`` named by epoch and step."""
    util.custom_save(model,
                     os.path.join(ckpt_dir,
                                  'epoch_%04d_iter_%06d.pth' % (epoch, global_step)),
                     discriminator=None,
                     optimizer=optimizer)


def _run_validation(model, val_dataloader, writer, global_step):
    """Run one pass over ``val_dataloader`` and log the metrics to ``writer``.

    The model is put into eval mode for the pass and back into train mode
    afterwards; gradients are disabled while evaluating. Uses its own batch
    variables so the caller's training batch is left untouched.
    """
    print("Running validation set...")

    model.eval()
    with torch.no_grad():
        psnrs = []
        ssims = []
        dist_losses = []
        for val_input, val_ground_truth in val_dataloader:
            val_outputs = model(val_input)

            dist_loss = model.get_image_loss(
                val_outputs, val_ground_truth).cpu().numpy()
            psnr, ssim = model.get_psnr(val_outputs, val_ground_truth)
            psnrs.append(psnr)
            ssims.append(ssim)
            dist_losses.append(dist_loss)

            model.write_updates(writer, val_outputs, val_ground_truth,
                                global_step, prefix='val_')

        writer.add_scalar("val_dist_loss", np.mean(dist_losses), global_step)
        writer.add_scalar("val_psnr", np.mean(psnrs), global_step)
        writer.add_scalar("val_ssim", np.mean(ssims), global_step)
    model.train()


def train():
    """Train an SRNsModel in stages of increasing image sidelength.

    Reads its configuration from the module-level ``opt`` namespace. The
    comma-separated options ``img_sidelengths``,
    ``batch_size_per_img_sidelength`` and ``max_steps_per_img_sidelength``
    define the stages; each stage trains until the global step counter reaches
    the cumulative step budget up to and including that stage. The counter
    starts at ``opt.start_step``, so stages whose budget is already spent are
    skipped.

    Side effects: creates ``opt.logging_root`` with ``checkpoints`` and
    ``events`` subdirectories, writes ``params.txt`` and ``model.txt``, logs
    scalars through a ``SummaryWriter`` and saves checkpoints every
    ``opt.steps_til_ckpt`` steps plus once at the end.

    Raises:
        AssertionError: if the per-sidelength option lists differ in length,
            or validation is enabled but ``opt.val_root`` is None.
        ValueError: if a stage's training DataLoader would yield no batches.
    """
    # Parses indices of specific observations from comma-separated list.
    if opt.specific_observation_idcs is not None:
        specific_observation_idcs = util.parse_comma_separated_integers(
            opt.specific_observation_idcs)
    else:
        specific_observation_idcs = None

    img_sidelengths = util.parse_comma_separated_integers(opt.img_sidelengths)
    batch_size_per_sidelength = util.parse_comma_separated_integers(
        opt.batch_size_per_img_sidelength)
    max_steps_per_sidelength = util.parse_comma_separated_integers(
        opt.max_steps_per_img_sidelength)

    train_dataset = dataio.SceneClassDataset(
        root_dir=opt.data_root,
        max_num_instances=opt.max_num_instances_train,
        max_observations_per_instance=opt.max_num_observations_train,
        img_sidelength=img_sidelengths[0],
        specific_observation_idcs=specific_observation_idcs,
        samples_per_instance=1)

    assert (len(img_sidelengths) == len(batch_size_per_sidelength)), \
        "Different number of image sidelengths passed than batch sizes."
    assert (len(img_sidelengths) == len(max_steps_per_sidelength)), \
        "Different number of image sidelengths passed than max steps."

    val_dataloader = None
    if not opt.no_validation:
        assert (opt.val_root is not None), "No validation directory passed."

        val_dataset = dataio.SceneClassDataset(
            root_dir=opt.val_root,
            max_num_instances=opt.max_num_instances_val,
            max_observations_per_instance=opt.max_num_observations_val,
            img_sidelength=img_sidelengths[0],
            samples_per_instance=1)
        val_dataloader = DataLoader(val_dataset,
                                    batch_size=2,
                                    shuffle=False,
                                    drop_last=True,
                                    collate_fn=val_dataset.collate_fn)

    model = SRNsModel(num_instances=train_dataset.num_instances,
                      latent_dim=opt.embedding_size,
                      has_params=opt.has_params,
                      fit_single_srn=opt.fit_single_srn,
                      use_unet_renderer=opt.use_unet_renderer,
                      tracing_steps=opt.tracing_steps,
                      freeze_networks=opt.freeze_networks)

    model.train()
    model.cuda()

    if opt.checkpoint_path is not None:
        print("Loading model from %s" % opt.checkpoint_path)
        util.custom_load(model,
                         path=opt.checkpoint_path,
                         discriminator=None,
                         optimizer=None,
                         overwrite_embeddings=opt.overwrite_embeddings)

    ckpt_dir = os.path.join(opt.logging_root, 'checkpoints')
    events_dir = os.path.join(opt.logging_root, 'events')

    util.cond_mkdir(opt.logging_root)
    util.cond_mkdir(ckpt_dir)
    util.cond_mkdir(events_dir)

    # Save command-line parameters to log directory.
    with open(os.path.join(opt.logging_root, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(
            ["%s: %s" % (key, value) for key, value in vars(opt).items()]))

    # Save text summary of model into log directory.
    with open(os.path.join(opt.logging_root, "model.txt"), "w") as out_file:
        out_file.write(str(model))

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    writer = SummaryWriter(events_dir)
    global_step = opt.start_step
    epoch = global_step // len(train_dataset)

    print('Beginning training...')
    try:
        # This loop implements training with an increasing image sidelength.
        cum_max_steps = 0  # Tracks max_steps cumulatively over all image sidelengths.
        for img_sidelength, max_steps, batch_size in zip(
                img_sidelengths, max_steps_per_sidelength,
                batch_size_per_sidelength):
            cum_max_steps += max_steps

            # A stage whose step budget is already used up (e.g. when resuming
            # via opt.start_step) is skipped. With an equality test the
            # counter would step past the bound and the stage would never end.
            if global_step >= cum_max_steps:
                continue

            print("\n" + "#" * 10)
            print("Training with sidelength %d for %d steps with batch size %d" %
                  (img_sidelength, max_steps, batch_size))
            print("#" * 10 + "\n")
            train_dataset.set_img_sidelength(img_sidelength)

            # Need to instantiate DataLoader every time to set new batch size.
            train_dataloader = DataLoader(train_dataset,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          drop_last=True,
                                          collate_fn=train_dataset.collate_fn,
                                          pin_memory=opt.preload)

            # With drop_last=True a dataset smaller than the batch size yields
            # no batches, which would make the epoch loop below spin forever.
            if len(train_dataloader) == 0:
                raise ValueError(
                    "Training set has %d items, fewer than batch size %d."
                    % (len(train_dataset), batch_size))

            # Loops over epochs.
            while True:
                for model_input, ground_truth in train_dataloader:
                    model_outputs = model(model_input)

                    optimizer.zero_grad()

                    dist_loss = model.get_image_loss(model_outputs, ground_truth)
                    reg_loss = model.get_regularization_loss(
                        model_outputs, ground_truth)
                    latent_loss = model.get_latent_loss()

                    weighted_dist_loss = opt.l1_weight * dist_loss
                    weighted_reg_loss = opt.reg_weight * reg_loss
                    weighted_latent_loss = opt.kl_weight * latent_loss

                    total_loss = (weighted_dist_loss
                                  + weighted_reg_loss
                                  + weighted_latent_loss)

                    total_loss.backward()
                    optimizer.step()

                    print("Iter %07d Epoch %03d L_img %0.4f L_latent %0.4f L_depth %0.4f" %
                          (global_step, epoch, weighted_dist_loss,
                           weighted_latent_loss, weighted_reg_loss))

                    model.write_updates(writer, model_outputs, ground_truth,
                                        global_step)
                    writer.add_scalar("scaled_distortion_loss",
                                      weighted_dist_loss, global_step)
                    writer.add_scalar("scaled_regularization_loss",
                                      weighted_reg_loss, global_step)
                    writer.add_scalar("scaled_latent_loss",
                                      weighted_latent_loss, global_step)
                    writer.add_scalar("total_loss", total_loss, global_step)

                    if global_step % opt.steps_til_val == 0 and not opt.no_validation:
                        _run_validation(model, val_dataloader, writer, global_step)

                    global_step += 1

                    if global_step >= cum_max_steps:
                        break

                    if global_step % opt.steps_til_ckpt == 0:
                        _save_checkpoint(model, optimizer, ckpt_dir,
                                         epoch, global_step)

                if global_step >= cum_max_steps:
                    break
                epoch += 1

        _save_checkpoint(model, optimizer, ckpt_dir, epoch, global_step)
    finally:
        # Close the event writer even if training is interrupted.
        writer.close()