def train():
    """Train an SRNsModel with a coarse-to-fine image-sidelength schedule.

    Reads all configuration from the module-level ``opt`` namespace
    (command-line options parsed elsewhere in this file) and relies on the
    module-level ``util``, ``dataio``, ``SRNsModel``, ``DataLoader``,
    ``SummaryWriter``, ``torch``, ``np`` and ``os`` names.

    Side effects: creates checkpoint/event directories under
    ``opt.logging_root``, writes ``params.txt`` and ``model.txt``, logs
    scalars/images to TensorBoard, and saves model checkpoints.
    """
    # Parses indices of specific observations from comma-separated list.
    if opt.specific_observation_idcs is not None:
        specific_observation_idcs = util.parse_comma_separated_integers(
            opt.specific_observation_idcs)
    else:
        specific_observation_idcs = None

    # Per-stage schedules: the i-th sidelength is trained with the i-th
    # batch size for the i-th number of steps.
    img_sidelengths = util.parse_comma_separated_integers(opt.img_sidelengths)
    batch_size_per_sidelength = util.parse_comma_separated_integers(
        opt.batch_size_per_img_sidelength)
    max_steps_per_sidelength = util.parse_comma_separated_integers(
        opt.max_steps_per_img_sidelength)

    # Dataset starts at the first (coarsest) sidelength; it is resized per
    # stage below via set_img_sidelength().
    train_dataset = dataio.SceneClassDataset(
        root_dir=opt.data_root,
        max_num_instances=opt.max_num_instances_train,
        max_observations_per_instance=opt.max_num_observations_train,
        img_sidelength=img_sidelengths[0],
        specific_observation_idcs=specific_observation_idcs,
        samples_per_instance=1)

    assert (len(img_sidelengths) == len(batch_size_per_sidelength)), \
        "Different number of image sidelengths passed than batch sizes."
    assert (len(img_sidelengths) == len(max_steps_per_sidelength)), \
        "Different number of image sidelengths passed than max steps."

    if not opt.no_validation:
        assert (opt.val_root is not None), "No validation directory passed."

        # NOTE(review): the validation dataset always uses the first
        # sidelength, even after training moves to finer resolutions.
        val_dataset = dataio.SceneClassDataset(
            root_dir=opt.val_root,
            max_num_instances=opt.max_num_instances_val,
            max_observations_per_instance=opt.max_num_observations_val,
            img_sidelength=img_sidelengths[0],
            samples_per_instance=1)
        # NOTE(review): this local is never read again — the DataLoader
        # below passes val_dataset.collate_fn directly.
        collate_fn = val_dataset.collate_fn
        val_dataloader = DataLoader(val_dataset,
                                    batch_size=2,
                                    shuffle=False,
                                    drop_last=True,
                                    collate_fn=val_dataset.collate_fn)

    model = SRNsModel(num_instances=train_dataset.num_instances,
                      latent_dim=opt.embedding_size,
                      has_params=opt.has_params,
                      fit_single_srn=opt.fit_single_srn,
                      use_unet_renderer=opt.use_unet_renderer,
                      tracing_steps=opt.tracing_steps,
                      freeze_networks=opt.freeze_networks)
    model.train()
    model.cuda()

    # Optionally resume weights (optimizer state is NOT restored here).
    if opt.checkpoint_path is not None:
        print("Loading model from %s" % opt.checkpoint_path)
        util.custom_load(model, path=opt.checkpoint_path,
                         discriminator=None,
                         optimizer=None,
                         overwrite_embeddings=opt.overwrite_embeddings)

    ckpt_dir = os.path.join(opt.logging_root, 'checkpoints')
    events_dir = os.path.join(opt.logging_root, 'events')

    util.cond_mkdir(opt.logging_root)
    util.cond_mkdir(ckpt_dir)
    util.cond_mkdir(events_dir)

    # Save command-line parameters log directory.
    with open(os.path.join(opt.logging_root, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(
            ["%s: %s" % (key, value) for key, value in vars(opt).items()]))

    # Save text summary of model into log directory.
    with open(os.path.join(opt.logging_root, "model.txt"), "w") as out_file:
        out_file.write(str(model))

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    writer = SummaryWriter(events_dir)
    # NOTE(review): `iter` shadows the builtin; it counts optimizer steps.
    iter = opt.start_step
    # NOTE(review): epoch is estimated from the dataset's sample count,
    # not batches per epoch — approximate when resuming with batch_size > 1.
    epoch = iter // len(train_dataset)
    step = 0

    print('Beginning training...')
    # This loop implements training with an increasing image sidelength.
    cum_max_steps = 0  # Tracks max_steps cumulatively over all image sidelengths.
    for img_sidelength, max_steps, batch_size in zip(
            img_sidelengths, max_steps_per_sidelength,
            batch_size_per_sidelength):
        print("\n" + "#" * 10)
        print("Training with sidelength %d for %d steps with batch size %d" %
              (img_sidelength, max_steps, batch_size))
        print("#" * 10 + "\n")
        train_dataset.set_img_sidelength(img_sidelength)

        # Need to instantiate DataLoader every time to set new batch size.
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      drop_last=True,
                                      collate_fn=train_dataset.collate_fn,
                                      pin_memory=opt.preload)

        cum_max_steps += max_steps

        # Loops over epochs.
        while True:
            for model_input, ground_truth in train_dataloader:
                model_outputs = model(model_input)

                optimizer.zero_grad()

                # Three loss terms: image reconstruction, depth/geometry
                # regularization, and latent-code prior (see SRNsModel).
                dist_loss = model.get_image_loss(model_outputs, ground_truth)
                reg_loss = model.get_regularization_loss(
                    model_outputs, ground_truth)
                latent_loss = model.get_latent_loss()

                weighted_dist_loss = opt.l1_weight * dist_loss
                weighted_reg_loss = opt.reg_weight * reg_loss
                weighted_latent_loss = opt.kl_weight * latent_loss

                total_loss = (weighted_dist_loss
                              + weighted_reg_loss
                              + weighted_latent_loss)

                total_loss.backward()

                optimizer.step()

                print(
                    "Iter %07d Epoch %03d L_img %0.4f L_latent %0.4f L_depth %0.4f"
                    % (iter, epoch, weighted_dist_loss, weighted_latent_loss,
                       weighted_reg_loss))

                model.write_updates(writer, model_outputs, ground_truth, iter)
                writer.add_scalar("scaled_distortion_loss", weighted_dist_loss,
                                  iter)
                writer.add_scalar("scaled_regularization_loss",
                                  weighted_reg_loss, iter)
                writer.add_scalar("scaled_latent_loss", weighted_latent_loss,
                                  iter)
                writer.add_scalar("total_loss", total_loss, iter)

                # Periodic validation pass (metrics only, no gradients).
                if iter % opt.steps_til_val == 0 and not opt.no_validation:
                    print("Running validation set...")

                    model.eval()
                    with torch.no_grad():
                        psnrs = []
                        ssims = []
                        dist_losses = []
                        for model_input, ground_truth in val_dataloader:
                            model_outputs = model(model_input)

                            dist_loss = model.get_image_loss(
                                model_outputs, ground_truth).cpu().numpy()
                            psnr, ssim = model.get_psnr(model_outputs,
                                                        ground_truth)

                            psnrs.append(psnr)
                            ssims.append(ssim)
                            dist_losses.append(dist_loss)

                            model.write_updates(writer, model_outputs,
                                                ground_truth, iter,
                                                prefix='val_')

                        writer.add_scalar("val_dist_loss",
                                          np.mean(dist_losses), iter)
                        writer.add_scalar("val_psnr", np.mean(psnrs), iter)
                        writer.add_scalar("val_ssim", np.mean(ssims), iter)
                    model.train()

                iter += 1
                step += 1

                # NOTE(review): this break precedes the checkpoint check, so
                # a checkpoint that would land exactly on the stage boundary
                # is skipped (the final save below still runs).
                if iter == cum_max_steps:
                    break

                if iter % opt.steps_til_ckpt == 0:
                    util.custom_save(model,
                                     os.path.join(
                                         ckpt_dir,
                                         'epoch_%04d_iter_%06d.pth' %
                                         (epoch, iter)),
                                     discriminator=None,
                                     optimizer=optimizer)

            # Re-check the stage boundary at epoch granularity; otherwise
            # start another epoch over the same dataloader.
            if iter == cum_max_steps:
                break
            epoch += 1

    # Final checkpoint after the full sidelength schedule completes.
    util.custom_save(model,
                     os.path.join(ckpt_dir,
                                  'epoch_%04d_iter_%06d.pth' % (epoch, iter)),
                     discriminator=None,
                     optimizer=optimizer)
def test():
    """Evaluate a trained SRNsModel: compute PSNR/SSIM over a test set and
    dump the first ``opt.save_out_first_n`` instances' renderings.

    Reads configuration from the module-level ``opt`` namespace and relies
    on module-level ``util``, ``dataio``, ``SRNsModel``, ``DataLoader``,
    ``torch``, ``np`` and ``os``.

    Side effects: creates ``renderings`` and ``gt_comparisons`` directories
    under ``opt.logging_root``, writes ``params.txt``, per-view PNGs, and a
    final ``results.txt`` with mean PSNR/SSIM.
    """
    if opt.specific_observation_idcs is not None:
        specific_observation_idcs = list(
            map(int, opt.specific_observation_idcs.split(',')))
    else:
        specific_observation_idcs = None

    # max_observations_per_instance=-1 presumably means "use all
    # observations" — TODO confirm against dataio.SceneClassDataset.
    dataset = dataio.SceneClassDataset(
        root_dir=opt.data_root,
        max_num_instances=opt.max_num_instances,
        specific_observation_idcs=specific_observation_idcs,
        max_observations_per_instance=-1,
        samples_per_instance=1,
        img_sidelength=opt.img_sidelength)
    # The name `dataset` is rebound to the loader here; original dataset
    # object is no longer referenced directly.
    dataset = DataLoader(dataset,
                         collate_fn=dataset.collate_fn,
                         batch_size=1,
                         shuffle=False,
                         drop_last=False)

    model = SRNsModel(num_instances=opt.num_instances,
                      latent_dim=opt.embedding_size,
                      has_params=opt.has_params,
                      fit_single_srn=opt.fit_single_srn,
                      use_unet_renderer=opt.use_unet_renderer,
                      tracing_steps=opt.tracing_steps)

    assert (opt.checkpoint_path is not None), "Have to pass checkpoint!"

    print("Loading model from %s" % opt.checkpoint_path)
    util.custom_load(model, path=opt.checkpoint_path,
                     discriminator=None,
                     overwrite_embeddings=False)

    model.eval()
    model.cuda()

    # directory structure: month_day/
    renderings_dir = os.path.join(opt.logging_root, 'renderings')
    gt_comparison_dir = os.path.join(opt.logging_root, 'gt_comparisons')
    util.cond_mkdir(opt.logging_root)
    util.cond_mkdir(gt_comparison_dir)
    util.cond_mkdir(renderings_dir)

    # Save command-line parameters to log directory.
    with open(os.path.join(opt.logging_root, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(
            ["%s: %s" % (key, value) for key, value in vars(opt).items()]))

    print('Beginning evaluation...')
    with torch.no_grad():
        # instance_idx tracks the current object instance; idx numbers the
        # output images within one instance (reset when the instance changes).
        instance_idx = 0
        idx = 0
        psnrs, ssims = list(), list()
        for model_input, ground_truth in dataset:
            model_outputs = model(model_input)
            psnr, ssim = model.get_psnr(model_outputs, ground_truth)

            psnrs.extend(psnr)
            ssims.extend(ssim)

            instance_idcs = model_input['instance_idx']
            print("Object instance %d. Running mean PSNR %0.6f SSIM %0.6f" %
                  (instance_idcs[-1], np.mean(psnrs), np.mean(ssims)))

            # Only dump images for the first save_out_first_n instances.
            if instance_idx < opt.save_out_first_n:
                output_imgs = model.get_output_img(model_outputs).cpu().numpy()
                comparisons = model.get_comparisons(model_input,
                                                    model_outputs,
                                                    ground_truth)
                for i in range(len(output_imgs)):
                    prev_instance_idx = instance_idx
                    instance_idx = instance_idcs[i]

                    # New instance: restart the per-instance image counter.
                    if prev_instance_idx != instance_idx:
                        idx = 0

                    img_only_path = os.path.join(renderings_dir,
                                                 "%06d" % instance_idx)
                    comp_path = os.path.join(gt_comparison_dir,
                                             "%06d" % instance_idx)

                    util.cond_mkdir(img_only_path)
                    util.cond_mkdir(comp_path)

                    pred = util.convert_image(output_imgs[i].squeeze())
                    comp = util.convert_image(comparisons[i].squeeze())

                    util.write_img(pred,
                                   os.path.join(img_only_path,
                                                "%06d.png" % idx))
                    util.write_img(comp,
                                   os.path.join(comp_path, "%06d.png" % idx))

                    idx += 1

    with open(os.path.join(opt.logging_root, "results.txt"), "w") as out_file:
        out_file.write("%0.6f, %0.6f" % (np.mean(psnrs), np.mean(ssims)))

    print("Final mean PSNR %0.6f SSIM %0.6f" % (np.mean(psnrs),
                                                np.mean(ssims)))
def test():
    """Render a camera trajectory with a pretrained model and save the
    resulting RGB and (globally normalized) depth images as 16-bit PNGs.

    Relies on module-level names defined elsewhere in this file: ``model``,
    ``projection``, ``grid_origin``, ``device``, ``opt``, ``util``,
    ``data_util``, plus ``TestDataset``, ``DataLoader``, ``torch``, ``np``,
    ``cv2``, ``datetime``, ``time`` and ``os``.

    Side effects: creates a timestamped output directory tree under
    ``opt.logging_root/test_traj`` and writes one PNG per valid pose.
    """
    # Create the training dataset loader
    dataset = TestDataset(pose_dir=os.path.join(opt.data_root, 'pose'))

    # NOTE(review): `model` is not created in this function — it must be a
    # module-level global; confirm it exists before calling test().
    util.custom_load(model, opt.checkpoint)
    model.eval()

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=4)

    # Output directory name encodes date, time, checkpoint and data root.
    dir_name = os.path.join(
        datetime.datetime.now().strftime('%m_%d'),
        datetime.datetime.now().strftime('%H-%M-%S_') +
        '_'.join(opt.checkpoint.strip('/').split('/')[-2:]) + '_' +
        opt.data_root.strip('/').split('/')[-1])

    traj_dir = os.path.join(opt.logging_root, 'test_traj', dir_name)
    depth_dir = os.path.join(traj_dir, 'depth')

    data_util.cond_mkdir(traj_dir)
    data_util.cond_mkdir(depth_dir)

    forward_time = 0.  # Cumulative model forward-pass wall time (seconds).

    print('starting testing...')
    with torch.no_grad():
        # NOTE(review): `iter` shadows the builtin; counts valid samples.
        iter = 0
        depth_imgs = []
        for trgt_pose in dataloader:
            trgt_pose = trgt_pose.squeeze().to(device)

            start = time.time()
            # compute projection mapping
            proj_mapping = projection.compute_proj_idcs(
                trgt_pose.squeeze(), grid_origin)
            if proj_mapping is None:  # invalid sample
                print('(invalid sample)')
                continue

            proj_ind_3d, proj_ind_2d = proj_mapping

            # Run through model
            output, depth_maps, = model(None, [proj_ind_3d], [proj_ind_2d],
                                        None, None, None)
            end = time.time()
            forward_time += end - start

            # Crop a 5-pixel border (matches the training-time crop).
            output[0] = output[0][:, :, 5:-5, 5:-5]
            print("Iter %d" % iter)

            # Map from the network's [-0.5, 0.5]-ish range to uint16;
            # presumably outputs are centered at 0 — TODO confirm.
            output_img = np.array(output[0].squeeze().cpu().detach().numpy())
            output_img = output_img.transpose(1, 2, 0)
            output_img += 0.5
            output_img *= 2**16 - 1
            output_img = output_img.round().clip(0, 2**16 - 1)

            depth_img = depth_maps[0].squeeze(0).cpu().detach().numpy()
            depth_img = depth_img.transpose(1, 2, 0)
            depth_imgs.append(depth_img)

            # [:, :, ::-1] converts RGB -> BGR channel order for cv2.
            cv2.imwrite(os.path.join(traj_dir, "img_%05d.png" % iter),
                        output_img.astype(np.uint16)[:, :, ::-1])

            iter += 1

        # Normalize depth jointly over the whole trajectory so frames share
        # one scale, then quantize to uint16.
        depth_imgs = np.stack(depth_imgs, axis=0)
        depth_imgs = (depth_imgs - np.amin(depth_imgs)) / (
            np.amax(depth_imgs) - np.amin(depth_imgs))
        depth_imgs *= 2**16 - 1
        depth_imgs = depth_imgs.round()

        for i in range(len(depth_imgs)):
            cv2.imwrite(os.path.join(depth_dir, "img_%05d.png" % i),
                        depth_imgs[i].astype(np.uint16))

    # NOTE(review): raises ZeroDivisionError (and np.stack above raises on an
    # empty list) if every pose was an invalid sample — confirm acceptable.
    print("Average forward pass time over %d examples is %f" %
          (iter, forward_time / iter))
def train():
    """Train an SRNsModel3 on the PBW dataset with a step schedule.

    Mirrors the SceneClassDataset training loop elsewhere in this file but
    uses ``dataio.PBWDataset`` (which takes no sidelength) and delegates
    validation to a ``test(model, dataloader, tag)`` helper.

    Reads configuration from the module-level ``opt`` namespace and relies
    on module-level ``util``, ``dataio``, ``SRNsModel3``, ``DataLoader``,
    ``SummaryWriter``, ``torch`` and ``os``.
    """
    # Parses indices of specific observations from comma-separated list.
    if opt.specific_observation_idcs is not None:
        specific_observation_idcs = util.parse_comma_separated_integers(
            opt.specific_observation_idcs)
    else:
        specific_observation_idcs = None

    # Per-stage schedules (parallel lists, validated by the asserts below).
    img_sidelengths = util.parse_comma_separated_integers(opt.img_sidelengths)
    batch_size_per_sidelength = util.parse_comma_separated_integers(
        opt.batch_size_per_img_sidelength)
    max_steps_per_sidelength = util.parse_comma_separated_integers(
        opt.max_steps_per_img_sidelength)

    train_dataset = dataio.PBWDataset(train=True)

    assert (len(img_sidelengths) == len(batch_size_per_sidelength)), \
        "Different number of image sidelengths passed than batch sizes."
    assert (len(img_sidelengths) == len(max_steps_per_sidelength)), \
        "Different number of image sidelengths passed than max steps."

    if not opt.no_validation:
        assert (opt.val_root is not None), "No validation directory passed."

        val_dataset = dataio.PBWDataset(train=False)
        val_dataloader = DataLoader(val_dataset,
                                    batch_size=16,
                                    shuffle=False,
                                    drop_last=True,
                                    collate_fn=val_dataset.collate_fn)

    model = SRNsModel3(latent_dim=opt.embedding_size,
                       has_params=opt.has_params,
                       fit_single_srn=True,
                       tracing_steps=opt.tracing_steps,
                       freeze_networks=opt.freeze_networks)
    model.train()
    model.cuda()

    # Optionally resume weights (optimizer state is NOT restored here).
    if opt.checkpoint_path is not None:
        print("Loading model from %s" % opt.checkpoint_path)
        util.custom_load(model, path=opt.checkpoint_path,
                         discriminator=None,
                         optimizer=None,
                         overwrite_embeddings=opt.overwrite_embeddings)

    ckpt_dir = os.path.join(opt.logging_root, 'checkpoints')
    events_dir = os.path.join(opt.logging_root, 'events')

    util.cond_mkdir(opt.logging_root)
    util.cond_mkdir(ckpt_dir)
    util.cond_mkdir(events_dir)

    # Save command-line parameters log directory.
    with open(os.path.join(opt.logging_root, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(
            ["%s: %s" % (key, value) for key, value in vars(opt).items()]))

    # Save text summary of model into log directory.
    with open(os.path.join(opt.logging_root, "model.txt"), "w") as out_file:
        out_file.write(str(model))

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    writer = SummaryWriter(events_dir)
    # NOTE(review): `iter` shadows the builtin; it counts optimizer steps.
    iter = opt.start_step
    epoch = iter // len(train_dataset)
    step = 0

    print('Beginning training...')
    # This loop implements training with an increasing image sidelength.
    cum_max_steps = 0  # Tracks max_steps cumulatively over all image sidelengths.
    for img_sidelength, max_steps, batch_size in zip(
            img_sidelengths, max_steps_per_sidelength,
            batch_size_per_sidelength):
        print("\n" + "#" * 10)
        print("Training with sidelength %d for %d steps with batch size %d" %
              (img_sidelength, max_steps, batch_size))
        print("#" * 10 + "\n")

        # NOTE(review): unlike the SceneClassDataset loop, img_sidelength is
        # only printed here — PBWDataset is never resized. Confirm intended.
        # Need to instantiate DataLoader every time to set new batch size.
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            collate_fn=train_dataset.collate_fn,
        )

        cum_max_steps += max_steps

        # Loops over epochs.
        while True:
            for batch in train_dataloader:
                rgb, ext_mat, info, rgb_mat = batch
                ground_truth = {"rgb": rgb}
                model_input = (ext_mat, rgb_mat, info
                               )  # color, pix coord, location, box
                model_outputs = model(model_input)

                optimizer.zero_grad()

                # Single reconstruction loss (no regularization / latent
                # terms in this variant).
                total_loss = model.get_image_loss(model_outputs, ground_truth)
                total_loss.backward()
                optimizer.step()

                if iter % 100 == 0:
                    print("Iter %07d Epoch %03d L_img %0.4f" %
                          (iter, epoch, total_loss))

                if iter % opt.steps_til_val == 0 and not opt.no_validation:
                    print("Running validation set...")
                    # NOTE(review): calls a 3-argument test() helper — not the
                    # zero-argument test() defined elsewhere in this file;
                    # confirm the right symbol is in scope.
                    acc = test(model, val_dataloader, str(iter))
                    print("Accuracy:", acc)

                iter += 1
                step += 1

                if iter == cum_max_steps:
                    break

            if iter == cum_max_steps:
                break
            epoch += 1

    # Final checkpoint after the full step schedule completes.
    util.custom_save(model,
                     os.path.join(ckpt_dir,
                                  'epoch_%04d_iter_%06d.pth' % (epoch, iter)),
                     discriminator=None,
                     optimizer=optimizer)
def train():
    """Adversarially train a novel-view-synthesis generator (``model``)
    against a patch ``discriminator`` with a combined GAN + L1 loss.

    Relies on module-level names defined elsewhere in this file: ``model``,
    ``discriminator``, ``optimizerG``, ``optimizerD``, ``criterionGAN``,
    ``criterionL1``, ``projection``, ``grid_origin``, ``grid_dims``,
    ``voxel_size``, ``near_plane``, ``device``, ``input_image_dims``,
    ``opt``, ``util``, ``data_util``, plus ``NovelViewTriplets``,
    ``DataLoader``, ``SummaryWriter``, ``torch``, ``torchvision``, ``np``,
    ``datetime`` and ``os``.

    Side effects: creates log/run directories, writes ``params.txt``,
    TensorBoard logs, and periodic + final checkpoints.
    """
    discriminator.train()
    model.train()

    if opt.checkpoint:
        util.custom_load(model, opt.checkpoint, discriminator)

    # Create the training dataset loader
    train_dataset = NovelViewTriplets(root_dir=opt.data_root,
                                      img_size=input_image_dims,
                                      sampling_pattern=opt.sampling_pattern)
    dataloader = DataLoader(train_dataset,
                            batch_size=1,
                            shuffle=True,
                            num_workers=8)

    # directory name contains some info about hyperparameters.
    dir_name = os.path.join(
        datetime.datetime.now().strftime('%m_%d'),
        datetime.datetime.now().strftime('%H-%M-%S_') +
        (opt.sampling_pattern + '_') + ('%0.2f_l1_weight_' % opt.l1_weight) +
        ('%d_trgt_' % opt.num_trgt) + '_' +
        opt.data_root.strip('/').split('/')[-1] + opt.experiment_name)

    log_dir = os.path.join(opt.logging_root, 'logs', dir_name)
    run_dir = os.path.join(opt.logging_root, 'runs', dir_name)

    data_util.cond_mkdir(log_dir)
    data_util.cond_mkdir(run_dir)

    # Save all command line arguments into a txt file in the logging directory for later referene.
    with open(os.path.join(log_dir, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(
            ["%s: %s" % (key, value) for key, value in vars(opt).items()]))

    writer = SummaryWriter(run_dir)
    # NOTE(review): `iter` shadows the builtin; resumes the global step
    # counter assuming one dataloader item per step.
    iter = opt.start_epoch * len(train_dataset)

    print('Begin training...')
    for epoch in range(opt.start_epoch, opt.max_epoch):
        for trgt_views, nearest_view in dataloader:
            # Lift the nearest (source) view into the voxel grid, and compute
            # per-target-view projection indices.
            backproj_mapping = projection.comp_lifting_idcs(
                camera_to_world=nearest_view['pose'].squeeze().to(device),
                grid2world=grid_origin)
            proj_mappings = list()
            for i in range(len(trgt_views)):
                proj_mappings.append(
                    projection.compute_proj_idcs(
                        trgt_views[i]['pose'].squeeze().to(device),
                        grid2world=grid_origin))

            # Skip samples whose lifting or projection is degenerate.
            if backproj_mapping is None:
                print("Lifting invalid")
                continue
            else:
                lift_volume_idcs, lift_img_coords = backproj_mapping

            if None in proj_mappings:
                print('Projection invalid')
                continue

            proj_frustrum_idcs, proj_grid_coords = list(zip(*proj_mappings))

            outputs, depth_maps = model(nearest_view['gt_rgb'].to(device),
                                        proj_frustrum_idcs, proj_grid_coords,
                                        lift_volume_idcs, lift_img_coords,
                                        writer=writer)

            # Convert the depth maps to metric
            for i in range(len(depth_maps)):
                depth_maps[i] = (
                    (depth_maps[i] + 0.5) *
                    int(np.ceil(np.sqrt(3) * grid_dims[-1])) * voxel_size +
                    near_plane)

            # We don't enforce a loss on the outermost 5 pixels to alleviate boundary errors
            for i in range(len(trgt_views)):
                outputs[i] = outputs[i][:, :, 5:-5, 5:-5]
                trgt_views[i]['gt_rgb'] = trgt_views[i]['gt_rgb'][:, :, 5:-5,
                                                                  5:-5]

            l1_losses = list()
            for idx in range(len(trgt_views)):
                l1_losses.append(
                    criterionL1(
                        outputs[idx].contiguous().view(-1).float(),
                        trgt_views[idx]['gt_rgb'].to(device).view(-1).float()))

            losses_d = []
            losses_g = []

            optimizerD.zero_grad()
            optimizerG.zero_grad()

            for idx in range(len(trgt_views)):
                #######
                ## Train Discriminator
                #######
                out_perm = outputs[idx]  # batch, ndf, height, width

                # Fake forward step
                pred_fake = discriminator.forward(out_perm.detach(
                ))  # Detach to make sure no gradients go into generator
                loss_d_fake = criterionGAN(pred_fake, False)

                # Real forward step
                real_input = trgt_views[idx]['gt_rgb'].float().to(device)
                pred_real = discriminator.forward(real_input)
                loss_d_real = criterionGAN(pred_real, True)

                # Combined Loss
                losses_d.append((loss_d_fake + loss_d_real) * 0.5)

                #######
                ## Train generator
                #######
                # Try to fake discriminator
                pred_fake = discriminator.forward(out_perm)
                loss_g_gan = criterionGAN(pred_fake, True)
                loss_g_l1 = l1_losses[idx] * opt.l1_weight

                losses_g.append(loss_g_gan + loss_g_l1)

            # Average losses over target views, then step D before G.
            loss_d = torch.stack(losses_d, dim=0).mean()
            loss_g = torch.stack(losses_g, dim=0).mean()

            loss_d.backward()
            optimizerD.step()

            loss_g.backward()
            optimizerG.step()

            print("Iter %07d Epoch %03d loss_gen %0.4f loss_discrim %0.4f" %
                  (iter, epoch, loss_g, loss_d))

            # Image summaries every 100 iterations.
            if not iter % 100:
                # Write tensorboard logs.
                writer.add_image(
                    "Depth",
                    torchvision.utils.make_grid(
                        [
                            depth_map.squeeze(dim=0).repeat(3, 1, 1)
                            for depth_map in depth_maps
                        ],
                        scale_each=True,
                        normalize=True).cpu().detach().numpy(), iter)

                writer.add_image(
                    "Nearest_neighbors_rgb",
                    torchvision.utils.make_grid(
                        nearest_view['gt_rgb'],
                        scale_each=True,
                        normalize=True).detach().numpy(), iter)

                output_vs_gt = torch.cat(
                    (torch.cat(outputs, dim=0),
                     torch.cat([i['gt_rgb'].to(device) for i in trgt_views],
                               dim=0)),
                    dim=0)
                writer.add_image(
                    "Output_vs_gt",
                    torchvision.utils.make_grid(
                        output_vs_gt, scale_each=True,
                        normalize=True).cpu().detach().numpy(), iter)

            # Scalar summaries every iteration.
            # NOTE(review): gen_loss_l1 / gen_loss_g log only the LAST target
            # view's values (loop variables), not the per-view average.
            writer.add_scalar("out_min", outputs[0].min(), iter)
            writer.add_scalar("out_max", outputs[0].max(), iter)

            writer.add_scalar("trgt_min", trgt_views[0]['gt_rgb'].min(), iter)
            writer.add_scalar("trgt_max", trgt_views[0]['gt_rgb'].max(), iter)

            writer.add_scalar("discrim_loss", loss_d, iter)
            writer.add_scalar("gen_loss_total", loss_g, iter)
            writer.add_scalar("gen_loss_l1", loss_g_l1, iter)
            writer.add_scalar("gen_loss_g", loss_g_gan, iter)

            iter += 1

            if iter % 10000 == 0:
                util.custom_save(
                    model,
                    os.path.join(log_dir,
                                 'model-epoch_%d_iter_%s.pth' % (epoch, iter)),
                    discriminator)

    # Final save after all epochs complete.
    util.custom_save(
        model,
        os.path.join(log_dir, 'model-epoch_%d_iter_%s.pth' % (epoch, iter)),
        discriminator)
use_gcn=False) # interpolater interpolater = network.Interpolater() # L1 loss criterionL1 = nn.L1Loss(reduction='mean').to(device) # Optimizer optimizerG = torch.optim.Adam(list(texture_mapper.parameters()) + list(render_net.parameters()), lr=opt.lr) # load checkpoint if opt.checkpoint: util.custom_load([texture_mapper, render_net], ['texture_mapper', 'render_net'], opt.checkpoint) # move to device texture_mapper.to(device) render_net.to(device) interpolater.to(device) # get module texture_mapper_module = texture_mapper render_net_module = render_net # use multi-GPU if opt.gpu_id != '': texture_mapper = nn.DataParallel(texture_mapper) render_net = nn.DataParallel(render_net) interpolater = nn.DataParallel(interpolater)
input_dims=input_shape, hidden_dim=args.hidden_dim, num_slots=args.num_slots, encoder=args.encoder, cnn_size=args.cnn_size, trans_model=args.trans_model, decoder=args.decoder, identity_action=args.identity_action, residual=args.residual, canonical=args.canonical_rep) model.to(device) print('Number of parameters in model', util.count_params(model)) if args.checkpoint_path is not None: print("Loading model from %s" % args.checkpoint_path) util.custom_load(model, path=args.checkpoint_path) else: print("Initialising random weights") model.apply(util.weights_init) optimizer = torch.optim.Adam( model.parameters(), lr=args.learning_rate) # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.4) now = datetime.datetime.now() timestamp = now.isoformat() if args.name == 'none': exp_name = timestamp
num_vertex = mesh.num_vertex # interpolater interpolater = network.Interpolater() # texture mapper texture_mapper = network.TextureMapper(texture_size=opt.texture_size, texture_num_ch=opt.texture_num_ch, mipmap_level=opt.mipmap_level, texture_init=None, fix_texture=True, apply_sh=opt.apply_sh) # load checkpoint checkpoint_dict = util.custom_load([texture_mapper], ['texture_mapper'], checkpoint_fp, strict=True) # trained lighting model new_state_dict = checkpoint_dict['lighting_model'] lighting_model_train = network.LightingSH(l_dir, lmax=int(params['sh_lmax']), num_lighting=2, num_channel=num_channel, fix_params=True) lighting_model_train.coeff.data = new_state_dict['coeff'] lighting_model_train.l_samples.data = new_state_dict['l_samples'] # lighting model lp lighting_model_lp = network.LightingLP(l_dir, num_channel=num_channel,
def test():
    """Evaluate a trained NodModel on paired-view data: compute same-view and
    novel-view L2 reconstruction losses and save comparison images.

    Training hyperparameters are reloaded from the training run's
    ``params.txt``; the model weights come from ``args.checkpoint_path``.
    Relies on module-level ``args``, ``device``, ``util``, ``data_util``,
    ``dataio``, ``nod``, ``save_circles``, plus ``data`` (torch.utils.data),
    ``nn``, ``torch``, ``torchvision``, ``np``, ``plt``, ``yaml`` and ``os``.

    Side effects: creates results directories, writes ``params.txt``,
    comparison/component PNGs, and ``results.txt``.
    """
    test_dataset = dataio.TwoViewsDataset(
        data_dir=args.test_dir,
        num_pairs_per_scene=args.test_pairs_per_scene,
        num_scenes=args.num_test_scenes,
        sidelength=args.sidelength)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=4)
    print(f'Size of test dataset {len(test_dataset)}')

    # Peek one batch to derive the input shape for model construction.
    # NOTE(review): `.next()` is the Python-2-style iterator method; on a
    # modern DataLoader iterator this would be `next(...)` — confirm the
    # torch version in use supports it.
    obs = test_loader.__iter__().next()
    data_util.show_batch_pairs(obs)

    input_shape = obs['image1'].size()[1:]

    # Load training params
    with open(args.train_log_dir + '/params.txt', 'r') as f:
        train_params = yaml.safe_load(f)

    model = nod.NodModel(embedding_dim=train_params['embedding_dim'],
                         input_dims=input_shape,
                         hidden_dim=train_params['hidden_dim'],
                         num_slots=train_params['num_slots'],
                         encoder=train_params['encoder'],
                         decoder=train_params['decoder'])

    print("Loading model from %s" % args.checkpoint_path)
    util.custom_load(model, path=args.checkpoint_path)
    print("Evaluation to be saved to %s" % args.results_dir)

    model.to(device)
    model.eval()

    gt_comparison_dir = os.path.join(args.results_dir, 'gt_comparisons')
    sv_comps_dir = os.path.join(args.results_dir, 'components_same_view')
    dv_comps_dir = os.path.join(args.results_dir, 'components_diff_view')

    util.cond_mkdir(args.results_dir)
    util.cond_mkdir(gt_comparison_dir)
    util.cond_mkdir(sv_comps_dir)
    util.cond_mkdir(dv_comps_dir)

    # Save command-line parameters to log directory.
    with open(os.path.join(args.results_dir, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(
            ["%s: %s" % (key, value) for key, value in vars(args).items()]))

    l2_loss = nn.MSELoss(reduction="mean")

    print('Beginning evaluation...')
    with torch.no_grad():
        same_view_losses = []
        diff_view_losses = []
        total_losses = []
        for batch_idx, data_batch in enumerate(test_loader):
            img1, img2 = data_batch['image1'].to(
                device), data_batch['image2'].to(device)
            batch_size = img1.shape[0]
            # Stack both views along the batch dim: first batch_size entries
            # are view 1, next batch_size are view 2.
            imgs = torch.cat((img1, img2), dim=0)
            w, h = imgs.size(-2), imgs.size(-1)
            images_gt = torch.cat((img1.unsqueeze(1), img2.unsqueeze(1)),
                                  dim=1)

            # Actions are the cross-view transforms (2->1 paired with view 1,
            # 1->2 with view 2), so novel views should reproduce `imgs`.
            action1, action2 = data_batch['transf21'].to(
                device), data_batch['transf12'].to(device)
            actions = torch.cat((action1, action2), dim=0)

            out = model(imgs, actions)
            masks, masked_comps, recs = model.compose_image(out)

            # First 2*batch_size reconstructions are same-view, the rest are
            # the action-transformed (novel) views.
            rec_views = recs[:batch_size * 2]
            novel_views = recs[batch_size * 2:]

            same_view_loss = l2_loss(rec_views, imgs)
            novel_view_loss = l2_loss(novel_views, imgs)
            total_loss = same_view_loss + novel_view_loss

            same_view_losses.append(same_view_loss.item())
            diff_view_losses.append(novel_view_loss.item())
            total_losses.append(total_loss.item())

            print(
                f"Number input images {batch_idx * args.batch_size} | Running l2 loss: {np.mean(total_losses)}"
            )
            # NOTE(review): this break exits after the FIRST batch and makes
            # the image-saving branch below unreachable — looks like leftover
            # debugging; confirm whether it should be removed.
            break

            if batch_idx * args.batch_size < args.save_out_first_n:
                # Regroup as (batch, view, ...) for per-sample saving.
                rec_views = rec_views.reshape(2, args.batch_size, 3, w,
                                              h).transpose(0, 1)
                novel_views = novel_views.reshape(2, args.batch_size, 3, w,
                                                  h).transpose(0, 1)
                same_view_masked_comps = masked_comps[:args.batch_size *
                                                      2].reshape(
                                                          2, args.batch_size,
                                                          model.num_slots, 3,
                                                          w, h).transpose(
                                                              0, 1)
                diff_view_masked_comps = masked_comps[args.batch_size *
                                                      2:].reshape(
                                                          2, args.batch_size,
                                                          model.num_slots, 3,
                                                          w, h).transpose(
                                                              0, 1)
                # NOTE(review): same_view_masks slices [batch_size*2:] — the
                # same (diff-view) range as diff_view_masks; by analogy with
                # masked_comps above it should presumably be
                # [:batch_size*2]. Confirm.
                same_view_masks = masks[args.batch_size * 2:].reshape(
                    2, args.batch_size, model.num_slots, w, h).transpose(0, 1)
                diff_view_masks = masks[args.batch_size * 2:].reshape(
                    2, args.batch_size, model.num_slots, w, h).transpose(0, 1)

                # Expand to have 3 channels so can concat with rgb images
                same_view_masks = same_view_masks.unsqueeze(3).repeat(
                    1, 1, 1, 3, 1, 1)
                diff_view_masks = diff_view_masks.unsqueeze(3).repeat(
                    1, 1, 1, 3, 1, 1)
                # Shift to be in range [-1, 1] like rgb
                same_view_masks = same_view_masks * 2 - 1
                diff_view_masks = diff_view_masks * 2 - 1

                for i in range(args.batch_size):
                    gt = images_gt[i]
                    same_view_rec = rec_views[i]
                    diff_view_rec = novel_views[i]

                    # Save ground truth reconstruction comparison
                    gt_vs_rec_vs_nv = torch.cat(
                        (gt, same_view_rec, diff_view_rec), dim=0)
                    gt_comparison_imgs = torchvision.utils.make_grid(
                        gt_vs_rec_vs_nv,
                        nrow=2,
                        scale_each=False,
                        normalize=True,
                        range=(-1, 1)).cpu().detach().numpy()
                    plt.imsave(
                        os.path.join(
                            gt_comparison_dir,
                            f'{i + batch_idx * args.batch_size:04d}.png'),
                        np.transpose(gt_comparison_imgs, (1, 2, 0)))

                    # Save components
                    sv_images = torch.cat(
                        (images_gt[i].unsqueeze(1),
                         same_view_rec.unsqueeze(1),
                         same_view_masked_comps[i], same_view_masks[i]),
                        dim=1)
                    dv_images = torch.cat(
                        (images_gt[i].unsqueeze(1),
                         diff_view_rec.unsqueeze(1),
                         diff_view_masked_comps[i], diff_view_masks[i]),
                        dim=1)
                    comps_same_view_images = torchvision.utils.make_grid(
                        sv_images.reshape(-1, 3, h, w),
                        nrow=2 * model.num_slots + 2,
                        scale_each=False,
                        normalize=True,
                        range=(-1, 1)).cpu().detach().numpy()
                    comps_diff_view_images = torchvision.utils.make_grid(
                        dv_images.reshape(-1, 3, h, w),
                        nrow=2 * model.num_slots + 2,
                        scale_each=False,
                        normalize=True,
                        range=(-1, 1)).cpu().detach().numpy()
                    plt.imsave(
                        os.path.join(
                            sv_comps_dir,
                            f'{i + batch_idx * args.batch_size:04d}.png'),
                        np.transpose(comps_same_view_images, (1, 2, 0)))
                    plt.imsave(
                        os.path.join(
                            dv_comps_dir,
                            f'{i + batch_idx * args.batch_size:04d}.png'),
                        np.transpose(comps_diff_view_images, (1, 2, 0)))

        save_circles(model, args.results_dir,
                     args.circle_source_img_path.split())

    with open(os.path.join(args.results_dir, "results.txt"), "w") as out_file:
        out_file.write("Evaluation Metric: score \n\n")
        out_file.write(
            f"Same view rec l2 loss: {np.mean(same_view_losses):10f} \n")
        out_file.write(
            f"Diff view rec l2 loss: {np.mean(diff_view_losses):10f} \n")
        out_file.write(f"Rec l2 loss: {np.mean(total_losses):10f} \n")

    print("\nFinal score: ")