Example #1
import nnabla.functions as F


def scatter_nd_backward(inputs, shape):
    """
    Args:
      inputs (list of nn.Variable): Incoming grads/inputs to/of the forward function.
      shape (tuple of int): Shape of the output of the forward scatter_nd.

    Returns:
      list of nn.Variable: The gradients w.r.t. the inputs of the corresponding function.
    """
    dy = inputs[0]   # gradient flowing back from the scatter_nd output
    _ = inputs[1]    # forward input data; not needed for this gradient
    idx = inputs[2]  # scatter indices
    # Gathering dy at the scatter indices routes the gradient back to the
    # scattered elements; the indices themselves receive no gradient.
    dx0 = F.gather_nd(dy, idx)
    return dx0, None
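
For context, a minimal sketch of the gather_nd/scatter_nd round trip this backward relies on (shapes are illustrative; nnabla's gather_nd treats the first dimension of the index array as the list of indexed axes):

import numpy as np
import nnabla as nn
import nnabla.functions as F

# Scatter two (3,)-rows of `data` into a (4, 3) output at rows 0 and 2.
data = nn.Variable.from_numpy_array(
    np.arange(6, dtype=np.float32).reshape(2, 3))
indices = nn.Variable.from_numpy_array(
    np.array([[0, 2]]))  # first dimension = number of indexed axes
y = F.scatter_nd(data, indices, shape=(4, 3))

# gather_nd with the same indices reads those rows back out; this is
# exactly how scatter_nd_backward above routes dy into dx0.
x = F.gather_nd(y, indices)
x.forward()
print(x.d)  # recovers `data`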
Example #2
    def forward_impl(self, inputs, outputs):
        x0 = inputs[0].data
        # Bring the axis to be subsampled to position 1 so the mask indexes it.
        if self.axis == 2:
            x0 = F.transpose(x0, (0, 2, 1))

        b, n, *_ = x0.shape
        # Keep the mask so the matching backward_impl can reuse it.
        self.mask = self._mask_gen(b, n)

        mask = nn.NdArray.from_numpy_array(self.mask)
        ans = F.gather_nd(x0, mask)

        if self.axis == 2:
            ans = F.transpose(ans, (0, 2, 1))

        y = outputs[0].data
        y.copy_from(ans)
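
The important detail above is the index layout `_mask_gen` must produce: F.gather_nd reads the first dimension of its index array as the list of axes being indexed. A sketch of such a builder, using a hypothetical make_gather_indices helper rather than the example's actual _mask_gen:

import numpy as np

def make_gather_indices(b, n, keep, rng=np.random):
    """Keep `keep` of the `n` rows for each of the `b` batch elements.

    Returns an int array of shape (2, b, keep), so that
    F.gather_nd(x0, indices) has shape (b, keep) + x0.shape[2:].
    """
    batch_idx = np.repeat(np.arange(b)[:, None], keep, axis=1)     # (b, keep)
    row_idx = np.stack([np.sort(rng.choice(n, keep, replace=False))
                        for _ in range(b)])                        # (b, keep)
    return np.stack([batch_idx, row_idx])                          # (2, b, keep)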
Example #3
    def call(self, x, y):
        hp = self.hp
        results = []
        with nn.parameter_scope('layer_0'):
            x = F.pad(x, (0, 0, 7, 7), 'reflect')
            x = wn_conv(x, hp.ndf, (15,))
            x = F.leaky_relu(x, 0.2, inplace=True)
            results.append(x)

        nf = hp.ndf
        stride = hp.downsamp_factor

        for i in range(1, hp.n_layers_D + 1):
            nf_prev = nf
            nf = min(nf * stride, 1024)
            with nn.parameter_scope(f'layer_{i}'):
                x = wn_conv(
                    x, nf, (stride * 10 + 1,),
                    stride=(stride,),
                    pad=(stride * 5,),
                    group=nf_prev // 4,
                )
                x = F.leaky_relu(x, 0.2, inplace=True)
                results.append(x)

        with nn.parameter_scope(f'layer_{hp.n_layers_D + 1}'):
            nf = min(nf * 2, 1024)
            x = wn_conv(x, nf, kernel=(5,), pad=(2,))
            x = F.leaky_relu(x, 0.2, inplace=True)
            results.append(x)

        with nn.parameter_scope(f'layer_{hp.n_layers_D + 2}'):
            x = wn_conv(x, hp.n_speakers, kernel=(3,), pad=(1,))
            if y is not None:
                # Pair each batch index with its speaker id so that
                # gather_nd selects x[i, y[i]] for every sample i.
                idx = F.stack(
                    F.arange(0, hp.batch_size),
                    y.reshape((hp.batch_size,))
                )
                x = F.gather_nd(x, idx)
            results.append(x)

        return results
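
The F.stack(F.arange(...), y) / F.gather_nd idiom in the last layer picks one output map per sample: stacking batch indices with the labels yields a (2, batch_size) index array, and gather_nd then selects x[i, y[i]]. The same pattern in isolation, building the indices in NumPy for clarity (sizes are illustrative):

import numpy as np
import nnabla as nn
import nnabla.functions as F

batch_size = 4
x = nn.Variable.from_numpy_array(
    np.random.randn(batch_size, 10, 7).astype(np.float32))  # (batch, maps, time)
labels = np.array([3, 1, 7, 0])  # one map id per sample

# (2, batch_size) array of (batch index, map index) pairs.
idx = nn.Variable.from_numpy_array(np.stack([np.arange(batch_size), labels]))
picked = F.gather_nd(x, idx)  # shape (batch_size, 7), i.e. x[i, labels[i]]
picked.forward()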
Example #4
    def backward_impl(self, inputs, outputs, propagate_down, accum):
        # Grads of inputs and outputs
        dx0 = inputs[0].grad
        dy = outputs[0].grad
        grad = dy

        if self.axis == 2:
            grad = F.transpose(grad, (0, 2, 1))

        # Reuse the mask computed in forward_impl to route the gradient.
        mask = nn.NdArray.from_numpy_array(self.mask)
        grad = F.gather_nd(grad, mask)

        if self.axis == 2:
            grad = F.transpose(grad, (0, 2, 1))

        # backward w.r.t. x0
        if propagate_down[0]:
            if accum[0]:
                dx0 += grad
            else:
                dx0.copy_from(grad)
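
Examples #2 and #4 are the forward and backward halves of a custom nnabla function. A minimal sketch of how such a pair is typically hosted in a nnabla.function.PythonFunction subclass; the class name, constructor arguments, and the output-shape bookkeeping below are assumptions, not the original code:

from nnabla.function import PythonFunction

class MaskedGather(PythonFunction):
    """Hypothetical host for the forward_impl/backward_impl shown above."""

    def __init__(self, axis, mask_gen, ctx):
        super().__init__(ctx)
        self.axis = axis
        self._mask_gen = mask_gen

    @property
    def name(self):
        return 'MaskedGather'

    def min_outputs(self):
        return 1

    def setup_impl(self, inputs, outputs):
        # The real output shape depends on how many positions the mask
        # keeps; that bookkeeping is specific to _mask_gen and omitted.
        outputs[0].reset_shape(inputs[0].shape, True)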
Example #5
def test():
    """
    Test (Zooming SloMo): run inference on a set of input data or Vid4 data.
    """
    # set context and load the model
    ctx = get_extension_context(args.context)
    nn.set_default_context(ctx)
    nn.load_parameters(args.model)
    input_dir = args.input_dir
    n_ot = 7

    # list all input sequence folders containing input frames
    inp_dir_list = sorted(glob.glob(input_dir + '/*'))
    inp_dir_name_list = []
    avg_psnr_l = []
    avg_psnr_y_l = []
    avg_ssim_y_l = []
    sub_folder_name_l = []
    save_folder = 'results'
    # for each sub-folder
    for inp_dir in inp_dir_list:
        gt_tested_list = []
        inp_dir_name = inp_dir.split('/')[-1]
        sub_folder_name_l.append(inp_dir_name)

        inp_dir_name_list.append(inp_dir_name)
        save_inp_folder = osp.join(save_folder, inp_dir_name)
        img_low_res_list = sorted(glob.glob(inp_dir + '/*'))

        util.mkdirs(save_inp_folder)
        imgs = util.read_seq_imgs_(inp_dir)

        img_gt_l = []
        if args.metrics:
            replace_str = 'LR'
            for img_gt_path in sorted(glob.glob(osp.join(inp_dir.replace(replace_str, 'HR'), '*'))):
                img_gt_l.append(util.read_image(img_gt_path))

        avg_psnr, avg_psnr_sum, cal_n = 0, 0, 0
        avg_psnr_y, avg_psnr_sum_y = 0, 0
        avg_ssim_y, avg_ssim_sum_y = 0, 0

        skip = args.metrics

        select_idx_list = util.test_index_generation(
            skip, n_ot, len(img_low_res_list))

        # process each image
        for select_idxs in select_idx_list:
            # get input images
            select_idx = [select_idxs[0]]
            gt_idx = select_idxs[1]
            imgs_in = F.gather_nd(
                imgs, indices=nn.Variable.from_numpy_array(select_idx))
            imgs_in = F.reshape(x=imgs_in, shape=(1,) + imgs_in.shape)
            output = zooming_slo_mo_network(imgs_in, args.only_slomo)
            outputs = output[0]
            outputs.forward(clear_buffer=True)

            for idx, name_idx in enumerate(gt_idx):
                if name_idx in gt_tested_list:
                    continue
                gt_tested_list.append(name_idx)
                output_f = outputs.d[idx, :, :, :]
                output = util.tensor2img(output_f)
                cv2.imwrite(osp.join(save_inp_folder,
                                     '{:08d}.png'.format(name_idx + 1)), output)
                print("Saving :", osp.join(save_inp_folder,
                                           '{:08d}.png'.format(name_idx + 1)))

                if args.metrics:
                    # calculate PSNR
                    output = output / 255.
                    ground_truth = np.copy(img_gt_l[name_idx])
                    cropped_output = output
                    cropped_gt = ground_truth

                    crt_psnr = util.calculate_psnr(
                        cropped_output * 255, cropped_gt * 255)
                    cropped_gt_y = util.bgr2ycbcr(cropped_gt, only_y=True)
                    cropped_output_y = util.bgr2ycbcr(
                        cropped_output, only_y=True)
                    crt_psnr_y = util.calculate_psnr(
                        cropped_output_y * 255, cropped_gt_y * 255)
                    crt_ssim_y = util.calculate_ssim(
                        cropped_output_y * 255, cropped_gt_y * 255)

                    avg_psnr_sum += crt_psnr
                    avg_psnr_sum_y += crt_psnr_y
                    avg_ssim_sum_y += crt_ssim_y
                    cal_n += 1

        if args.metrics:
            avg_psnr = avg_psnr_sum / cal_n
            avg_psnr_y = avg_psnr_sum_y / cal_n
            avg_ssim_y = avg_ssim_sum_y / cal_n

            avg_psnr_l.append(avg_psnr)
            avg_psnr_y_l.append(avg_psnr_y)
            avg_ssim_y_l.append(avg_ssim_y)

    if args.metrics:
        print('################ Tidy Outputs ################')
        for name, ssim, psnr_y in zip(sub_folder_name_l, avg_ssim_y_l, avg_psnr_y_l):
            print(
                'Folder {} - Average SSIM: {:.6f}  PSNR-Y: {:.6f} dB. '.format(name, ssim, psnr_y))
        print('################ Final Results ################')
        print('Total Average SSIM: {:.6f}  PSNR-Y: {:.6f} dB for {} clips. '.format(
            sum(avg_ssim_y_l) / len(avg_ssim_y_l), sum(avg_psnr_y_l) /
            len(avg_psnr_y_l),
            len(inp_dir_list)))
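
For reference, the PSNR reported through util.calculate_psnr reduces to the standard definition; a minimal standalone version (an illustrative sketch, not the project's implementation):

import numpy as np

def psnr(img1, img2, max_val=255.0):
    """Peak signal-to-noise ratio between two images scaled to [0, max_val]."""
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 20.0 * np.log10(max_val / np.sqrt(mse))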
Example #6
def train_nerf(config, comm, model, dataset='blender'):

    use_transient = False
    use_embedding = False

    if model == 'wild':
        use_transient = True
        use_embedding = True
    elif model == 'uncertainty':
        use_transient = True
    elif model == 'appearance':
        use_embedding = True

    save_results_dir = config.log.save_results_dir
    os.makedirs(save_results_dir, exist_ok=True)

    train_loss_dict = {
        'train_coarse_loss': 0.0,
        'train_fine_loss': 0.0,
        'train_total_loss': 0.0,
    }

    test_metric_dict = {'test_loss': 0.0, 'test_psnr': 0.0}

    monitor_manager = MonitorManager(train_loss_dict, test_metric_dict,
                                     save_results_dir)

    if dataset != 'phototourism':
        images, poses, _, hwf, i_test, i_train, near_plane, far_plane = get_data(
            config)
        height, width, focal_length = hwf
    else:
        di = get_photo_tourism_dataiterator(config, 'train', comm)
        val_di = get_photo_tourism_dataiterator(config, 'val', comm)

    if model != 'vanilla':
        if dataset != 'phototourism':
            config.train.n_vocab = max(np.max(i_train), np.max(i_test)) + 1
        print(
            f'Setting vocabulary size of embedding to {config.train.n_vocab}')

    if dataset != 'phototourism':
        if model == 'vanilla':
            if comm is not None:
                # Test on a small per-rank slice of images.
                i_test = i_test[3 * comm.rank:3 * (comm.rank + 1)]
            else:
                # Test on the first few images only.
                i_test = i_test[:3]
        else:
            # i_test = i_train[0:5]
            i_test = [i * (comm.rank + 1) for i in range(5)]
    else:
        i_test = [1]

    encode_position_function = get_encoding_function(
        config.train.num_encodings_position, True, True)
    if config.train.use_view_directions:
        encode_direction_function = get_encoding_function(
            config.train.num_encodings_direction, True, True)
    else:
        encode_direction_function = None

    lr = config.solver.lr
    num_decay_steps = config.solver.lr_decay_step * 1000
    lr_decay_factor = config.solver.lr_decay_factor
    solver = S.Adam(alpha=lr)

    load_solver_state = False
    if config.checkpoint.param_path is not None:
        nn.load_parameters(config.checkpoint.param_path)
        load_solver_state = True

    if comm is not None:
        num_decay_steps /= comm.n_procs
        comm_size = comm.n_procs
    else:
        comm_size = 1
    pbar = trange(config.train.num_iterations // comm_size,
                  disable=(comm is not None and comm.rank > 0))

    for i in pbar:

        if dataset != 'phototourism':

            idx = np.random.choice(i_train)
            image = nn.Variable.from_numpy_array(images[idx][None, :, :, :3])
            pose = nn.Variable.from_numpy_array(poses[idx])

            ray_directions, ray_origins = get_ray_bundle(
                height, width, focal_length, pose)

            grid = get_direction_grid(width,
                                      height,
                                      focal_length,
                                      return_ij_2d_grid=True)
            grid = F.reshape(grid, (-1, 2))

            select_inds = np.random.choice(grid.shape[0],
                                           size=[config.train.num_rand_points],
                                           replace=False)
            select_inds = F.gather_nd(grid, select_inds[None, :])
            select_inds = F.transpose(select_inds, (1, 0))

            embed_inp = nn.Variable.from_numpy_array(
                np.full((config.train.chunksize_fine, ), idx, dtype=int))

            ray_origins = F.gather_nd(ray_origins, select_inds)
            ray_directions = F.gather_nd(ray_directions, select_inds)

            image = F.gather_nd(image[0], select_inds)

        else:
            rays, embed_inp, image = di.next()
            ray_origins = nn.Variable.from_numpy_array(rays[:, :3])
            ray_directions = nn.Variable.from_numpy_array(rays[:, 3:6])
            near_plane = nn.Variable.from_numpy_array(rays[:, 6])
            far_plane = nn.Variable.from_numpy_array(rays[:, 7])

            embed_inp = nn.Variable.from_numpy_array(embed_inp)
            image = nn.Variable.from_numpy_array(image)

            hwf = None

        app_emb, trans_emb = None, None
        if use_embedding:
            with nn.parameter_scope('embedding_a'):
                app_emb = PF.embed(embed_inp, config.train.n_vocab,
                                   config.train.n_app)

        if use_transient:
            with nn.parameter_scope('embedding_t'):
                trans_emb = PF.embed(embed_inp, config.train.n_vocab,
                                     config.train.n_trans)

        if use_transient:
            rgb_map_course, rgb_map_fine, static_rgb_map_fine, transient_rgb_map_fine, beta, static_sigma, transient_sigma = forward_pass(
                ray_directions,
                ray_origins,
                near_plane,
                far_plane,
                app_emb,
                trans_emb,
                encode_position_function,
                encode_direction_function,
                config,
                use_transient,
                hwf=hwf,
                image=image)
            course_loss = 0.5 * F.mean(F.squared_error(rgb_map_course, image))
            fine_loss = 0.5 * F.mean(
                F.squared_error(rgb_map_fine, image) /
                F.reshape(F.pow_scalar(beta, 2), beta.shape + (1, )))
            beta_reg_loss = 3 + F.mean(F.log(beta))
            sigma_trans_reg_loss = 0.01 * F.mean(transient_sigma)
            loss = course_loss + fine_loss + beta_reg_loss + sigma_trans_reg_loss
        else:
            rgb_map_course, _, _, _, rgb_map_fine, _, _, _ = forward_pass(
                ray_directions,
                ray_origins,
                near_plane,
                far_plane,
                app_emb,
                trans_emb,
                encode_position_function,
                encode_direction_function,
                config,
                use_transient,
                hwf=hwf)
            course_loss = F.mean(F.squared_error(rgb_map_course, image))
            fine_loss = F.mean(F.squared_error(rgb_map_fine, image))
            loss = course_loss + fine_loss

        pbar.set_description(
            f'Total: {np.around(loss.d, 4)}, Coarse: {np.around(course_loss.d, 4)}, Fine: {np.around(fine_loss.d, 4)}'
        )

        solver.set_parameters(nn.get_parameters(),
                              reset=False,
                              retain_state=True)
        if load_solver_state:
            solver.load_states(config['checkpoint']['solver_path'])
            load_solver_state = False

        solver.zero_grad()

        loss.backward(clear_buffer=True)

        # Exponential LR decay
        if dataset != 'phototourism':
            lr_factor = lr_decay_factor ** (i / num_decay_steps)
            solver.set_learning_rate(lr * lr_factor)
        else:
            if i % num_decay_steps == 0 and i != 0:
                solver.set_learning_rate(lr * lr_decay_factor)

        if comm is not None:
            params = [x.grad for x in nn.get_parameters().values()]
            comm.all_reduce(params, division=False, inplace=True)
        solver.update()

        if ((i % config.train.save_interval == 0
             or i == config.train.num_iterations - 1)
                and i != 0) and (comm is not None and comm.rank == 0):
            nn.save_parameters(os.path.join(save_results_dir, f'iter_{i}.h5'))
            solver.save_states(
                os.path.join(save_results_dir, f'solver_iter_{i}.h5'))

        if (i % config.train.test_interval == 0
                or i == config.train.num_iterations - 1) and i != 0:
            avg_psnr, avg_mse = 0.0, 0.0
            for i_t in trange(len(i_test)):

                if dataset != 'phototourism':
                    idx_test = i_test[i_t]
                    image = nn.NdArray.from_numpy_array(
                        images[idx_test][None, :, :, :3])
                    pose = nn.NdArray.from_numpy_array(poses[idx_test])

                    ray_directions, ray_origins = get_ray_bundle(
                        height, width, focal_length, pose)

                    ray_directions = F.reshape(ray_directions, (-1, 3))
                    ray_origins = F.reshape(ray_origins, (-1, 3))

                    embed_inp = nn.NdArray.from_numpy_array(
                        np.full((config.train.chunksize_fine, ),
                                idx_test,
                                dtype=int))

                else:
                    rays, embed_inp, image = val_di.next()
                    ray_origins = nn.NdArray.from_numpy_array(rays[0, :, :3])
                    ray_directions = nn.NdArray.from_numpy_array(rays[0, :, 3:6])
                    near_plane_ = nn.NdArray.from_numpy_array(rays[0, :, 6])
                    far_plane_ = nn.NdArray.from_numpy_array(rays[0, :, 7])

                    embed_inp = nn.NdArray.from_numpy_array(
                        embed_inp[0, :config.train.chunksize_fine])
                    image = nn.NdArray.from_numpy_array(image[0].transpose(
                        1, 2, 0))
                    image = F.reshape(image, (1, ) + image.shape)
                    idx_test = 1

                app_emb, trans_emb = None, None
                if use_embedding:
                    with nn.parameter_scope('embedding_a'):
                        app_emb = PF.embed(embed_inp, config.train.n_vocab,
                                           config.train.n_app)

                if use_transient:
                    with nn.parameter_scope('embedding_t'):
                        trans_emb = PF.embed(embed_inp, config.train.n_vocab,
                                             config.train.n_trans)

                num_ray_batches = ray_directions.shape[0] // config.train.ray_batch_size + 1

                if use_transient:
                    rgb_map_fine_list, static_rgb_map_fine_list, transient_rgb_map_fine_list = [], [], []
                else:
                    rgb_map_fine_list, depth_map_fine_list = [], []

                for r_idx in trange(num_ray_batches):
                    start = r_idx * config.train.ray_batch_size
                    end = (r_idx + 1) * config.train.ray_batch_size
                    if r_idx != num_ray_batches - 1:
                        ray_d = ray_directions[start:end]
                        ray_o = ray_origins[start:end]
                        if dataset == 'phototourism':
                            near_plane = near_plane_[start:end]
                            far_plane = far_plane_[start:end]
                    else:
                        # Last, possibly partial batch; skip if it is empty.
                        if ray_directions.shape[0] - (num_ray_batches - 1) * config.train.ray_batch_size == 0:
                            break
                        ray_d = ray_directions[start:, :]
                        ray_o = ray_origins[start:, :]
                        if dataset == 'phototourism':
                            near_plane = near_plane_[start:]
                            far_plane = far_plane_[start:]

                    if use_transient:
                        rgb_map_course, rgb_map_fine, static_rgb_map_fine, transient_rgb_map_fine, beta, static_sigma, transient_sigma = forward_pass(
                            ray_d,
                            ray_o,
                            near_plane,
                            far_plane,
                            app_emb,
                            trans_emb,
                            encode_position_function,
                            encode_direction_function,
                            config,
                            use_transient,
                            hwf=hwf)

                        rgb_map_fine_list.append(rgb_map_fine)
                        static_rgb_map_fine_list.append(static_rgb_map_fine)
                        transient_rgb_map_fine_list.append(
                            transient_rgb_map_fine)

                    else:
                        _, _, _, _, rgb_map_fine, depth_map_fine, _, _ = \
                            forward_pass(ray_d, ray_o, near_plane, far_plane, app_emb, trans_emb,
                                         encode_position_function, encode_direction_function, config, use_transient, hwf=hwf)

                        rgb_map_fine_list.append(rgb_map_fine)
                        depth_map_fine_list.append(depth_map_fine)

                if use_transient:
                    rgb_map_fine = F.concatenate(*rgb_map_fine_list, axis=0)
                    static_rgb_map_fine = F.concatenate(
                        *static_rgb_map_fine_list, axis=0)
                    transient_rgb_map_fine = F.concatenate(
                        *transient_rgb_map_fine_list, axis=0)

                    rgb_map_fine = F.reshape(rgb_map_fine, image[0].shape)
                    static_rgb_map_fine = F.reshape(static_rgb_map_fine,
                                                    image[0].shape)
                    transient_rgb_map_fine = F.reshape(transient_rgb_map_fine,
                                                       image[0].shape)
                    static_trans_img_to_save = np.concatenate(
                        (static_rgb_map_fine.data,
                         np.ones((image[0].shape[0], 5, 3)),
                         transient_rgb_map_fine.data),
                        axis=1)
                    img_to_save = np.concatenate(
                        (image[0].data, np.ones(
                            (image[0].shape[0], 5, 3)), rgb_map_fine.data),
                        axis=1)
                else:

                    rgb_map_fine = F.concatenate(*rgb_map_fine_list, axis=0)
                    depth_map_fine = F.concatenate(*depth_map_fine_list,
                                                   axis=0)

                    rgb_map_fine = F.reshape(rgb_map_fine, image[0].shape)
                    depth_map_fine = F.reshape(depth_map_fine,
                                               image[0].shape[:-1])
                    img_to_save = np.concatenate(
                        (image[0].data, np.ones(
                            (image[0].shape[0], 5, 3)), rgb_map_fine.data),
                        axis=1)

                filename = os.path.join(save_results_dir,
                                        f'{i}_{idx_test}.png')
                try:
                    imsave(filename,
                           np.clip(img_to_save, 0, 1),
                           channel_first=False)
                    print(f'Saved generation at {filename}')
                    if use_transient:
                        filename_static_trans = os.path.join(
                            save_results_dir, f's_t_{i}_{idx_test}.png')
                        imsave(filename_static_trans,
                               np.clip(static_trans_img_to_save, 0, 1),
                               channel_first=False)

                    else:
                        filename_dm = os.path.join(save_results_dir,
                                                   f'dm_{i}_{idx_test}.png')
                        depth_map_fine = (depth_map_fine.data -
                                          depth_map_fine.data.min()) / (
                                              depth_map_fine.data.max() -
                                              depth_map_fine.data.min())
                        imsave(filename_dm,
                               depth_map_fine[:, :, None],
                               channel_first=False)
                        plt.imshow(depth_map_fine.data)
                        plt.savefig(filename_dm)
                        plt.close()
                except Exception:
                    # Saving images is best-effort; ignore failures here.
                    pass

                avg_mse += F.mean(F.squared_error(rgb_map_fine, image[0])).data
                avg_psnr += (-10. * np.log10(
                    F.mean(F.squared_error(rgb_map_fine, image[0])).data))

            test_metric_dict['test_loss'] = avg_mse / len(i_test)
            test_metric_dict['test_psnr'] = avg_psnr / len(i_test)
            monitor_manager.add(i, test_metric_dict)
            print(
                f'Saved generations after {i} training iterations! Average PSNR: {avg_psnr/len(i_test)}. Average MSE: {avg_mse/len(i_test)}'
            )
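
The random-pixel selection near the top of the training loop is a reusable pattern: build an (H*W, 2) table of pixel coordinates, sample rows without replacement, and transpose the result into the (2, N) layout F.gather_nd expects as indices. The same pattern in isolation (shapes are illustrative):

import numpy as np
import nnabla as nn
import nnabla.functions as F

height, width, num_rand = 4, 5, 6
img = nn.Variable.from_numpy_array(
    np.random.rand(height, width, 3).astype(np.float32))

# (H*W, 2) table of (row, col) pixel coordinates.
ii, jj = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
grid = np.stack([ii, jj], axis=-1).reshape(-1, 2)

# Sample pixels without replacement, then transpose to (2, N): gather_nd
# reads the first index dimension as the list of axes being indexed.
select = grid[np.random.choice(grid.shape[0], size=num_rand, replace=False)]
select_inds = nn.Variable.from_numpy_array(select.T)

pixels = F.gather_nd(img, select_inds)  # shape (num_rand, 3)
pixels.forward()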