def scatter_nd_backward(inputs, shape):
    """
    Backward of ``scatter_nd``: gather the incoming gradient at the indices.

    Args:
      inputs (list of nn.Variable): Incoming grads/inputs to/of the forward
        function. ``inputs[0]`` is used as the incoming gradient,
        ``inputs[1]`` is read but ignored, and ``inputs[2]`` is used as the
        indices passed to ``F.gather_nd``.
      shape: Argument of the forward function; not used here.

    Return:
      tuple: ``(dx0, None)`` where ``dx0`` is ``F.gather_nd(dy, idx)``. The
        second entry is always ``None`` (presumably "no gradient" for the
        indices input -- confirm against the caller's convention).
    """
    grad_output, _, indices = inputs[0], inputs[1], inputs[2]
    return F.gather_nd(grad_output, indices), None
def forward_impl(self, inputs, outputs):
    """
    Forward pass: gather ``inputs[0]`` with a freshly generated index mask.

    The mask is produced by ``self._mask_gen`` from the first two dimensions
    of the (possibly transposed) input and is stored on ``self.mask`` so that
    ``backward_impl`` can reuse it. When ``self.axis == 2`` the last two axes
    are swapped before the gather and swapped back afterwards.

    Args:
      inputs: Function inputs; only ``inputs[0].data`` is read.
      outputs: Function outputs; the result is copied into
        ``outputs[0].data``.
    """
    swap_last_axes = self.axis == 2

    data = inputs[0].data
    if swap_last_axes:
        data = F.transpose(data, (0, 2, 1))

    # Only the two leading dimensions drive the mask generation.
    batch, length = data.shape[:2]
    self.mask = self._mask_gen(batch, length)

    gathered = F.gather_nd(data, nn.NdArray.from_numpy_array(self.mask))
    if swap_last_axes:
        gathered = F.transpose(gathered, (0, 2, 1))

    outputs[0].data.copy_from(gathered)
def call(self, x, y):
    """
    Run the layer stack on ``x`` and collect every intermediate output.

    Layers are created under the parameter scopes ``layer_0`` ...
    ``layer_{n_layers_D + 2}``:

    * ``layer_0``: reflect padding, a width-15 ``wn_conv`` and leaky ReLU.
    * ``layer_1`` .. ``layer_{n_layers_D}``: strided grouped ``wn_conv``
      layers whose channel count grows by ``hp.downsamp_factor`` per layer
      (capped at 1024), each followed by leaky ReLU.
    * ``layer_{n_layers_D + 1}``: a width-5 ``wn_conv`` with doubled
      channels (capped at 1024) and leaky ReLU.
    * ``layer_{n_layers_D + 2}``: a width-3 ``wn_conv`` with
      ``hp.n_speakers`` output channels.

    Args:
      x: Input variable fed to the first layer.
      y: Optional labels. When not ``None`` it is reshaped to
        ``(hp.batch_size,)`` and, paired with ``arange(batch_size)``, used
        to gather from the last layer's output.

    Returns:
      list: Outputs of all layers in order; the last element is the
        (optionally gathered) final output.
    """
    hp = self.hp
    slope = 0.2
    max_channels = 1024
    feature_maps = []

    with nn.parameter_scope('layer_0'):
        x = F.pad(x, (0, 0, 7, 7), 'reflect')
        x = wn_conv(x, hp.ndf, (15,))
        x = F.leaky_relu(x, slope, inplace=True)
        feature_maps.append(x)

    channels = hp.ndf
    factor = hp.downsamp_factor
    last_down_layer = hp.n_layers_D

    for layer in range(1, last_down_layer + 1):
        # The group count is derived from the channel count *before* growth.
        groups = channels // 4
        channels = min(channels * factor, max_channels)
        with nn.parameter_scope(f'layer_{layer}'):
            x = wn_conv(
                x, channels, (factor * 10 + 1,),
                stride=(factor,),
                pad=(factor * 5,),
                group=groups,
            )
            x = F.leaky_relu(x, slope, inplace=True)
            feature_maps.append(x)

    with nn.parameter_scope(f'layer_{last_down_layer + 1}'):
        channels = min(channels * 2, max_channels)
        x = wn_conv(x, channels, kernel=(5,), pad=(2,))
        x = F.leaky_relu(x, slope, inplace=True)
        feature_maps.append(x)

    with nn.parameter_scope(f'layer_{last_down_layer + 2}'):
        x = wn_conv(x, hp.n_speakers, kernel=(3,), pad=(1,))
        if y is not None:
            # Pair every batch position with its label to form the indices.
            batch_index = F.arange(0, hp.batch_size)
            labels = y.reshape((hp.batch_size,))
            x = F.gather_nd(x, F.stack(batch_index, labels))
        feature_maps.append(x)

    return feature_maps
def backward_impl(self, inputs, outputs, propagate_down, accum):
    """
    Backward pass: route the output gradient through the stored mask.

    The gradient of ``outputs[0]`` is gathered with ``self.mask`` (set by
    ``forward_impl``), with the last two axes swapped around the gather when
    ``self.axis == 2``. The result is accumulated into or copied onto the
    gradient of ``inputs[0]``.

    NOTE(review): the gradient is gathered with the *same* indices as the
    forward pass. That equals the exact adjoint of the forward gather only
    if the index map is its own inverse -- confirm against ``_mask_gen``.

    Args:
      inputs: Function inputs; ``inputs[0].grad`` receives the gradient.
      outputs: Function outputs; ``outputs[0].grad`` is the incoming grad.
      propagate_down: Per-input flags; nothing is written when
        ``propagate_down[0]`` is falsy.
      accum: Per-input flags; when ``accum[0]`` is truthy the gradient is
        added to the existing one instead of overwriting it.
    """
    dx0 = inputs[0].grad
    swap_last_axes = self.axis == 2

    routed = outputs[0].grad
    if swap_last_axes:
        routed = F.transpose(routed, (0, 2, 1))
    routed = F.gather_nd(routed, nn.NdArray.from_numpy_array(self.mask))
    if swap_last_axes:
        routed = F.transpose(routed, (0, 2, 1))

    # backward w.r.t. x0
    if not propagate_down[0]:
        return
    if accum[0]:
        dx0 += routed
    else:
        dx0.copy_from(routed)
def test():
    """
    Test(Zooming SloMo) - inference on set of input data or Vid4 data

    Reads its configuration from the module-level ``args`` namespace
    (``context``, ``model``, ``input_dir``, ``metrics``, ``only_slomo``).
    For every sub-folder of ``args.input_dir`` the frames are read, passed
    through ``zooming_slo_mo_network`` and each generated frame is written to
    ``results/<sub-folder>/<index + 1, 8 digits>.png``.

    When ``args.metrics`` is truthy, ground-truth frames are read from the
    folder path obtained by replacing ``'LR'`` with ``'HR'`` in the input
    folder path, and PSNR / PSNR-Y / SSIM-Y are accumulated and printed per
    folder and overall. Folders for which no frame was evaluated are reported
    and left out of the averages instead of raising ``ZeroDivisionError``.
    """
    # set context and load the model
    ctx = get_extension_context(args.context)
    nn.set_default_context(ctx)
    nn.load_parameters(args.model)
    input_dir = args.input_dir
    n_ot = 7

    # list all input sequence folders containing input frames
    inp_dir_list = sorted(glob.glob(input_dir + '/*'))
    avg_psnr_l = []
    avg_psnr_y_l = []
    avg_ssim_y_l = []
    # Names of the folders that actually contributed metrics; kept aligned
    # with the avg_*_l lists so the per-folder report cannot be misattributed.
    sub_folder_name_l = []
    save_folder = 'results'

    # for each sub-folder
    for inp_dir in inp_dir_list:
        gt_tested_list = []
        # basename() honours the platform separator, unlike split('/').
        inp_dir_name = osp.basename(osp.normpath(inp_dir))
        save_inp_folder = osp.join(save_folder, inp_dir_name)
        img_low_res_list = sorted(glob.glob(inp_dir + '/*'))

        util.mkdirs(save_inp_folder)
        imgs = util.read_seq_imgs_(inp_dir)

        img_gt_l = []
        if args.metrics:
            replace_str = 'LR'
            gt_pattern = osp.join(inp_dir.replace(replace_str, 'HR'), '*')
            img_gt_l = [util.read_image(img_gt_path)
                        for img_gt_path in sorted(glob.glob(gt_pattern))]

        avg_psnr_sum, cal_n = 0, 0
        avg_psnr_sum_y = 0
        avg_ssim_sum_y = 0

        skip = args.metrics

        select_idx_list = util.test_index_generation(
            skip, n_ot, len(img_low_res_list))

        # process each image
        for select_idxs in select_idx_list:
            # get input images
            select_idx = [select_idxs[0]]
            gt_idx = select_idxs[1]
            imgs_in = F.gather_nd(
                imgs, indices=nn.Variable.from_numpy_array(select_idx))
            imgs_in = F.reshape(x=imgs_in, shape=(1,) + imgs_in.shape)
            output = zooming_slo_mo_network(imgs_in, args.only_slomo)
            outputs = output[0]
            outputs.forward(clear_buffer=True)

            for idx, name_idx in enumerate(gt_idx):
                # A frame index can appear in several windows; save it once.
                if name_idx in gt_tested_list:
                    continue
                gt_tested_list.append(name_idx)
                output_f = outputs.d[idx, :, :, :]
                output = util.tensor2img(output_f)
                save_path = osp.join(save_inp_folder,
                                     '{:08d}.png'.format(name_idx + 1))
                cv2.imwrite(save_path, output)
                print("Saving :", save_path)

                if args.metrics:
                    # calculate PSNR
                    output = output / 255.
                    ground_truth = np.copy(img_gt_l[name_idx])
                    cropped_output = output
                    cropped_gt = ground_truth

                    crt_psnr = util.calculate_psnr(
                        cropped_output * 255, cropped_gt * 255)
                    cropped_gt_y = util.bgr2ycbcr(cropped_gt, only_y=True)
                    cropped_output_y = util.bgr2ycbcr(
                        cropped_output, only_y=True)
                    crt_psnr_y = util.calculate_psnr(
                        cropped_output_y * 255, cropped_gt_y * 255)
                    crt_ssim_y = util.calculate_ssim(
                        cropped_output_y * 255, cropped_gt_y * 255)

                    avg_psnr_sum += crt_psnr
                    avg_psnr_sum_y += crt_psnr_y
                    avg_ssim_sum_y += crt_ssim_y
                    cal_n += 1

        if args.metrics:
            if cal_n == 0:
                # Nothing was evaluated (e.g. empty folder); dividing by
                # cal_n would raise, so report it and move on.
                print('Folder {} - no frames evaluated, metrics skipped.'
                      .format(inp_dir_name))
                continue
            sub_folder_name_l.append(inp_dir_name)
            avg_psnr_l.append(avg_psnr_sum / cal_n)
            avg_psnr_y_l.append(avg_psnr_sum_y / cal_n)
            avg_ssim_y_l.append(avg_ssim_sum_y / cal_n)

    if args.metrics:
        print('################ Tidy Outputs ################')
        for name, ssim, psnr_y in zip(sub_folder_name_l, avg_ssim_y_l, avg_psnr_y_l):
            print(
                'Folder {} - Average SSIM: {:.6f}  PSNR-Y: {:.6f} dB. '.format(name, ssim, psnr_y))
        print('################ Final Results ################')
        if avg_ssim_y_l:
            print('Total Average SSIM: {:.6f}  PSNR-Y: {:.6f} dB for {} clips. '.format(
                sum(avg_ssim_y_l) / len(avg_ssim_y_l),
                sum(avg_psnr_y_l) / len(avg_psnr_y_l),
                len(avg_ssim_y_l)))
        else:
            print('No clips were evaluated; no overall metrics available.')
def train_nerf(config, comm, model, dataset='blender'):
    """
    Train a NeRF variant and periodically save checkpoints and test renders.

    Args:
      config: Configuration object. The ``log``, ``train``, ``solver`` and
        ``checkpoint`` sections are read here; ``config.train.n_vocab`` is
        overwritten for non-vanilla models on non-phototourism datasets.
      comm: Communicator for multi-process training (its ``rank``,
        ``n_procs`` and ``all_reduce`` are used), or ``None`` for a
        single-process run.
      model (str): ``'wild'`` enables both the transient head and the
        appearance embedding, ``'uncertainty'`` only the transient head,
        ``'appearance'`` only the embedding. ``'vanilla'`` enables neither
        and additionally skips the vocabulary-size setup.
      dataset (str): ``'phototourism'`` selects the data-iterator based
        pipeline; any other value uses ``get_data(config)``.

    Returns:
      None. Parameters, solver states and rendered images are written under
      ``config.log.save_results_dir``; test metrics go to the monitor.
    """
    use_transient = False
    use_embedding = False

    if model == 'wild':
        use_transient = True
        use_embedding = True
    elif model == 'uncertainty':
        use_transient = True
    elif model == 'appearance':
        use_embedding = True

    save_results_dir = config.log.save_results_dir
    os.makedirs(save_results_dir, exist_ok=True)

    train_loss_dict = {
        'train_coarse_loss': 0.0,
        'train_fine_loss': 0.0,
        'train_total_loss': 0.0,
    }

    test_metric_dict = {'test_loss': 0.0, 'test_psnr': 0.0}

    monitor_manager = MonitorManager(train_loss_dict, test_metric_dict,
                                     save_results_dir)

    # Single-process runs (comm is None) behave like rank 0 of a 1-process job.
    rank = comm.rank if comm is not None else 0
    is_main_process = rank == 0

    if dataset != 'phototourism':
        images, poses, _, hwf, i_test, i_train, near_plane, far_plane = get_data(
            config)
        height, width, focal_length = hwf
    else:
        di = get_photo_tourism_dataiterator(config, 'train', comm)
        val_di = get_photo_tourism_dataiterator(config, 'val', comm)

    if model != 'vanilla':
        if dataset != 'phototourism':
            config.train.n_vocab = max(np.max(i_train), np.max(i_test)) + 1
        print(
            f'Setting Vocabulary size of embedding as {config.train.n_vocab}')

    if dataset != 'phototourism':
        if model in ['vanilla']:
            # Evaluate on a small subset: 3 test images per process.
            if comm is not None:
                i_test = i_test[3 * comm.rank:3 * (comm.rank + 1)]
            else:
                i_test = i_test[:3]
        else:
            # i_test = i_train[0:5]
            i_test = [i * (rank + 1) for i in range(5)]
    else:
        i_test = [1]

    encode_position_function = get_encoding_function(
        config.train.num_encodings_position, True, True)
    if config.train.use_view_directions:
        encode_direction_function = get_encoding_function(
            config.train.num_encodings_direction, True, True)
    else:
        encode_direction_function = None

    lr = config.solver.lr
    num_decay_steps = config.solver.lr_decay_step * 1000
    lr_decay_factor = config.solver.lr_decay_factor
    solver = S.Adam(alpha=lr)

    load_solver_state = False
    if config.checkpoint.param_path is not None:
        nn.load_parameters(config.checkpoint.param_path)
        load_solver_state = True

    if comm is not None:
        num_decay_steps /= comm.n_procs
        comm_size = comm.n_procs
    else:
        comm_size = 1

    # The loop runs num_iterations // comm_size steps, so the "last
    # iteration" checks below must use this count, not num_iterations.
    num_steps = config.train.num_iterations // comm_size
    last_iter = num_steps - 1
    pbar = trange(num_steps,
                  disable=(comm is not None and comm.rank > 0))

    for i in pbar:

        if dataset != 'phototourism':

            idx = np.random.choice(i_train)
            image = nn.Variable.from_numpy_array(images[idx][None, :, :, :3])
            pose = nn.Variable.from_numpy_array(poses[idx])

            ray_directions, ray_origins = get_ray_bundle(
                height, width, focal_length, pose)

            grid = get_direction_grid(width, height, focal_length,
                                      return_ij_2d_grid=True)
            grid = F.reshape(grid, (-1, 2))

            # Pick a random subset of grid rows and turn them into
            # gather_nd indices.
            select_inds = np.random.choice(grid.shape[0],
                                           size=[config.train.num_rand_points],
                                           replace=False)
            select_inds = F.gather_nd(grid, select_inds[None, :])
            select_inds = F.transpose(select_inds, (1, 0))

            embed_inp = nn.Variable.from_numpy_array(
                np.full((config.train.chunksize_fine, ), idx, dtype=int))

            ray_origins = F.gather_nd(ray_origins, select_inds)
            ray_directions = F.gather_nd(ray_directions, select_inds)

            image = F.gather_nd(image[0], select_inds)

        else:
            rays, embed_inp, image = di.next()
            ray_origins = nn.Variable.from_numpy_array(rays[:, :3])
            ray_directions = nn.Variable.from_numpy_array(rays[:, 3:6])
            near_plane = nn.Variable.from_numpy_array(rays[:, 6])
            far_plane = nn.Variable.from_numpy_array(rays[:, 7])

            embed_inp = nn.Variable.from_numpy_array(embed_inp)
            image = nn.Variable.from_numpy_array(image)

            hwf = None

        app_emb, trans_emb = None, None
        if use_embedding:
            with nn.parameter_scope('embedding_a'):
                app_emb = PF.embed(embed_inp, config.train.n_vocab,
                                   config.train.n_app)

        if use_transient:
            with nn.parameter_scope('embedding_t'):
                trans_emb = PF.embed(embed_inp, config.train.n_vocab,
                                     config.train.n_trans)

        if use_transient:
            rgb_map_course, rgb_map_fine, static_rgb_map_fine, transient_rgb_map_fine, beta, static_sigma, transient_sigma = forward_pass(
                ray_directions, ray_origins, near_plane, far_plane, app_emb,
                trans_emb, encode_position_function,
                encode_direction_function, config, use_transient,
                hwf=hwf, image=image)

            course_loss = 0.5 * F.mean(F.squared_error(rgb_map_course, image))
            fine_loss = 0.5 * F.mean(
                F.squared_error(rgb_map_fine, image) /
                F.reshape(F.pow_scalar(beta, 2), beta.shape + (1, )))
            beta_reg_loss = 3 + F.mean(F.log(beta))
            sigma_trans_reg_loss = 0.01 * F.mean(transient_sigma)
            loss = course_loss + fine_loss + beta_reg_loss + sigma_trans_reg_loss
        else:
            rgb_map_course, _, _, _, rgb_map_fine, _, _, _ = forward_pass(
                ray_directions, ray_origins, near_plane, far_plane, app_emb,
                trans_emb, encode_position_function,
                encode_direction_function, config, use_transient,
                hwf=hwf)
            course_loss = F.mean(F.squared_error(rgb_map_course, image))
            fine_loss = F.mean(F.squared_error(rgb_map_fine, image))
            loss = course_loss + fine_loss

        pbar.set_description(
            f'Total: {np.around(loss.d, 4)}, Course: {np.around(course_loss.d, 4)}, Fine: {np.around(fine_loss.d, 4)}'
        )

        solver.set_parameters(nn.get_parameters(),
                              reset=False,
                              retain_state=True)
        # Solver states can only be loaded once the parameters are registered.
        if load_solver_state:
            solver.load_states(config['checkpoint']['solver_path'])
            load_solver_state = False

        solver.zero_grad()

        loss.backward(clear_buffer=True)

        # Exponential LR decay
        if dataset != 'phototourism':
            lr_factor = (lr_decay_factor**((i) / num_decay_steps))
            solver.set_learning_rate(lr * lr_factor)
        else:
            # NOTE(review): this always sets lr * lr_decay_factor, i.e. the
            # decay does not compound across milestones -- confirm intent.
            if i % num_decay_steps == 0 and i != 0:
                solver.set_learning_rate(lr * lr_decay_factor)

        if comm is not None:
            params = [x.grad for x in nn.get_parameters().values()]
            comm.all_reduce(params, division=False, inplace=True)
        solver.update()

        # Checkpoint from one process only; a single-process run
        # (comm is None) must save too.
        if ((i % config.train.save_interval == 0 or i == last_iter)
                and i != 0) and is_main_process:
            nn.save_parameters(os.path.join(save_results_dir, f'iter_{i}.h5'))
            solver.save_states(
                os.path.join(save_results_dir, f'solver_iter_{i}.h5'))

        if (i % config.train.test_interval == 0
                or i == last_iter) and i != 0:
            avg_psnr, avg_mse = 0.0, 0.0
            for i_t in trange(len(i_test)):

                if dataset != 'phototourism':
                    idx_test = i_test[i_t]
                    image = nn.NdArray.from_numpy_array(
                        images[idx_test][None, :, :, :3])
                    pose = nn.NdArray.from_numpy_array(poses[idx_test])

                    ray_directions, ray_origins = get_ray_bundle(
                        height, width, focal_length, pose)

                    ray_directions = F.reshape(ray_directions, (-1, 3))
                    ray_origins = F.reshape(ray_origins, (-1, 3))

                    embed_inp = nn.NdArray.from_numpy_array(
                        np.full((config.train.chunksize_fine, ),
                                idx_test,
                                dtype=int))

                else:
                    rays, embed_inp, image = val_di.next()
                    ray_origins = nn.NdArray.from_numpy_array(rays[0, :, :3])
                    ray_directions = nn.NdArray.from_numpy_array(
                        rays[0, :, 3:6])
                    near_plane_ = nn.NdArray.from_numpy_array(rays[0, :, 6])
                    far_plane_ = nn.NdArray.from_numpy_array(rays[0, :, 7])

                    embed_inp = nn.NdArray.from_numpy_array(
                        embed_inp[0, :config.train.chunksize_fine])
                    image = nn.NdArray.from_numpy_array(image[0].transpose(
                        1, 2, 0))
                    image = F.reshape(image, (1, ) + image.shape)
                    idx_test = 1

                app_emb, trans_emb = None, None
                if use_embedding:
                    with nn.parameter_scope('embedding_a'):
                        app_emb = PF.embed(embed_inp, config.train.n_vocab,
                                           config.train.n_app)

                if use_transient:
                    with nn.parameter_scope('embedding_t'):
                        trans_emb = PF.embed(embed_inp, config.train.n_vocab,
                                             config.train.n_trans)

                batch_size = config.train.ray_batch_size
                num_ray_batches = ray_directions.shape[0] // batch_size + 1

                if use_transient:
                    rgb_map_fine_list, static_rgb_map_fine_list, transient_rgb_map_fine_list = [], [], []
                else:
                    rgb_map_fine_list, depth_map_fine_list = [], []

                for r_idx in trange(num_ray_batches):
                    start = r_idx * batch_size
                    if r_idx != num_ray_batches - 1:
                        stop = (r_idx + 1) * batch_size
                        ray_d = ray_directions[start:stop]
                        ray_o = ray_origins[start:stop]

                        if dataset == 'phototourism':
                            near_plane = near_plane_[start:stop]
                            far_plane = far_plane_[start:stop]

                    else:
                        # The trailing batch is empty when the ray count is
                        # an exact multiple of the batch size.
                        if ray_directions.shape[0] - (
                                num_ray_batches - 1) * batch_size == 0:
                            break
                        ray_d = ray_directions[start:, :]
                        ray_o = ray_origins[start:, :]
                        if dataset == 'phototourism':
                            near_plane = near_plane_[start:]
                            far_plane = far_plane_[start:]

                    if use_transient:
                        rgb_map_course, rgb_map_fine, static_rgb_map_fine, transient_rgb_map_fine, beta, static_sigma, transient_sigma = forward_pass(
                            ray_d, ray_o, near_plane, far_plane, app_emb,
                            trans_emb, encode_position_function,
                            encode_direction_function, config, use_transient,
                            hwf=hwf)

                        rgb_map_fine_list.append(rgb_map_fine)
                        static_rgb_map_fine_list.append(static_rgb_map_fine)
                        transient_rgb_map_fine_list.append(
                            transient_rgb_map_fine)

                    else:
                        _, _, _, _, rgb_map_fine, depth_map_fine, _, _ = \
                            forward_pass(ray_d, ray_o, near_plane, far_plane,
                                         app_emb, trans_emb,
                                         encode_position_function,
                                         encode_direction_function, config,
                                         use_transient, hwf=hwf)

                        rgb_map_fine_list.append(rgb_map_fine)
                        depth_map_fine_list.append(depth_map_fine)

                if use_transient:
                    rgb_map_fine = F.concatenate(*rgb_map_fine_list, axis=0)
                    static_rgb_map_fine = F.concatenate(
                        *static_rgb_map_fine_list, axis=0)
                    transient_rgb_map_fine = F.concatenate(
                        *transient_rgb_map_fine_list, axis=0)

                    rgb_map_fine = F.reshape(rgb_map_fine, image[0].shape)
                    static_rgb_map_fine = F.reshape(static_rgb_map_fine,
                                                    image[0].shape)
                    transient_rgb_map_fine = F.reshape(transient_rgb_map_fine,
                                                       image[0].shape)
                    # Side-by-side images separated by a 5-pixel white bar.
                    static_trans_img_to_save = np.concatenate(
                        (static_rgb_map_fine.data,
                         np.ones((image[0].shape[0], 5, 3)),
                         transient_rgb_map_fine.data),
                        axis=1)
                    img_to_save = np.concatenate(
                        (image[0].data, np.ones(
                            (image[0].shape[0], 5, 3)), rgb_map_fine.data),
                        axis=1)
                else:
                    rgb_map_fine = F.concatenate(*rgb_map_fine_list, axis=0)
                    depth_map_fine = F.concatenate(*depth_map_fine_list,
                                                   axis=0)

                    rgb_map_fine = F.reshape(rgb_map_fine, image[0].shape)
                    depth_map_fine = F.reshape(depth_map_fine,
                                               image[0].shape[:-1])
                    img_to_save = np.concatenate(
                        (image[0].data, np.ones(
                            (image[0].shape[0], 5, 3)), rgb_map_fine.data),
                        axis=1)

                filename = os.path.join(save_results_dir,
                                        f'{i}_{idx_test}.png')
                # Saving previews is best-effort: a failure must not abort
                # training, but it is reported instead of silently dropped.
                try:
                    imsave(filename,
                           np.clip(img_to_save, 0, 1),
                           channel_first=False)
                    print(f'Saved generation at {filename}')
                    if use_transient:
                        filename_static_trans = os.path.join(
                            save_results_dir, f's_t_{i}_{idx_test}.png')
                        imsave(filename_static_trans,
                               np.clip(static_trans_img_to_save, 0, 1),
                               channel_first=False)
                    else:
                        filename_dm = os.path.join(save_results_dir,
                                                   f'dm_{i}_{idx_test}.png')
                        # Min-max normalise the depth map before saving.
                        depth_map_fine = (depth_map_fine.data -
                                          depth_map_fine.data.min()) / (
                                              depth_map_fine.data.max() -
                                              depth_map_fine.data.min())

                        imsave(filename_dm,
                               depth_map_fine[:, :, None],
                               channel_first=False)
                        plt.imshow(depth_map_fine.data)
                        plt.savefig(filename_dm)
                        plt.close()
                except Exception as err:
                    print(f'Could not save test images for {filename}: {err}')

                mse = F.mean(F.squared_error(rgb_map_fine, image[0])).data
                avg_mse += mse
                avg_psnr += (-10. * np.log10(mse))

            test_metric_dict['test_loss'] = avg_mse / len(i_test)
            test_metric_dict['test_psnr'] = avg_psnr / len(i_test)
            monitor_manager.add(i, test_metric_dict)
            print(
                f'Saved generations after {i} training iterations! Average PSNR: {avg_psnr/len(i_test)}. Average MSE: {avg_mse/len(i_test)}'
            )