Пример #1
0
    def __init__(self, num_iters, compile_flag, interactive, dev_str, f):
        """
        Set up the tiny-NeRF demo: backend, device, train/test data, logging, model and optional compilation.

        :param num_iters: iteration count, stored as self._num_iters alongside the other train config.
        :param compile_flag: if truthy, self._loss_fn is replaced by the result of ivy.compile_fn.
        :param interactive: if truthy the visualisation frequency is 25, otherwise it is -1.
        :param dev_str: device string; when None, 'gpu:0' is used if a GPU is available, else 'cpu'.
        :param f: ivy backend framework; when None, one is picked by choose_random_framework().
        """

        # ivy
        f = choose_random_framework() if f is None else f
        ivy.set_framework(f)
        ivy.seed(0)

        # device
        if dev_str is None:
            dev_str = 'gpu:0' if ivy.gpu_is_available() else 'cpu'
        self._dev_str = dev_str

        # the first num_train images/poses form the training set; index 101 is used as the test view
        num_train = 100

        # Load input images, poses and focal length. np.load on an .npz archive keeps the file open
        # until it is closed, so everything needed is read inside a context manager.
        this_dir = os.path.dirname(os.path.realpath(__file__))
        with np.load(os.path.join(this_dir, 'nerf_data/tiny_nerf_data.npz')) as data:
            images = ivy.array(data['images'], 'float32', dev_str)
            inv_ext_mats = ivy.array(data['poses'], 'float32', dev_str)
            focal = np.reshape(data['focal'], (1, 1))

        # intrinsics, one row per training image
        focal_lengths = ivy.array(np.tile(focal, [num_train, 2]), 'float32', dev_str)
        self._img_dims = images.shape[1:3]
        # principal-point offsets of dim/2 - 0.5 per image dimension; dtype pinned to float32 so they
        # match focal_lengths regardless of the backend's default float type
        pp_offsets = ivy.tile(ivy.array([[dim / 2 - 0.5 for dim in self._img_dims]], 'float32', dev_str),
                              [num_train, 1])

        # train data
        self._images = images[:num_train, ..., :3]
        self._intrinsics = ivy_vision.focal_lengths_and_pp_offsets_to_intrinsics_object(
            focal_lengths, pp_offsets, self._img_dims)
        self._cam_geoms = ivy_vision.inv_ext_mat_and_intrinsics_to_cam_geometry_object(
            inv_ext_mats[:num_train, 0:3], self._intrinsics)

        # test data: a single view, paired with slice 0 of the (tiled, hence identical) training intrinsics
        self._test_img = images[101]
        self._test_cam_geom = ivy_vision.inv_ext_mat_and_intrinsics_to_cam_geometry_object(
            inv_ext_mats[101, 0:3], self._intrinsics.slice(0))

        # train config
        self._embed_length = 6
        self._lr = 5e-4
        self._num_samples = 64
        self._num_iters = num_iters

        # log config
        self._interactive = interactive
        self._log_freq = 1
        self._vis_freq = 25 if self._interactive else -1
        # output directory, relative to the current working directory, is deleted and recreated here
        self._vis_log_dir = 'nerf_renderings'
        if os.path.exists(self._vis_log_dir):
            shutil.rmtree(self._vis_log_dir)
        os.makedirs(self._vis_log_dir)

        # model
        self._model = Model(4, 256, self._embed_length, dev_str)

        # compile
        if compile_flag:
            # example inputs for compilation are built from the first training view
            rays_o, rays_d = self._get_rays(self._cam_geoms.slice(0))
            target = self._images[0]
            self._loss_fn = ivy.compile_fn(self._loss_fn, False,
                                           example_inputs=[self._model, rays_o, rays_d, target, self._model.v])
Пример #2
0
    def __init__(self, dataset_spec, batch_size, starting_idx, num_sequences, window_size=1,
                 num_workers=1, cache_size=0, unused_key_chains=None, custom_init_fn=None,
                 container_load_mode='dynamic', custom_container_load_fn=None, preshuffle_data=True,
                 shuffle_buffer_size=0, with_prefetching=True, queue_timeout=None, post_proc_fn=None,
                 prefetch_to_devs='gpu:0', single_pass=False, array_strs=None, float_strs=None, uint8_strs=None,
                 custom_img_strs=None, custom_img_fns=None, custom_strs=None, custom_fns=None, array_mode='pickled',
                 load_gray_as_rgb=True, containers_to_skip=None, **kwargs):
        """
        Specification of how a sequence data loader reads from a dataset.

        A snapshot of the caller's arguments, taken via locals_to_kwargs before any defaults below are
        filled in, is stored in self._kwargs. List-valued arguments left as None become fresh empty
        lists, except custom_img_strs and custom_strs, which become [[]].

        :param container_load_mode: one of 'preload', 'dynamic' or 'custom'. 'custom' requires
            custom_container_load_fn; the other modes require dataset_spec.cont_fname_template.
        :param prefetch_to_devs: replaced by False when no GPU is available, unless it is a list.
        :param queue_timeout: stored on self.queue_timeout (defaulting to ivy.queue_timeout()) rather than
            forwarded to the parent constructor.
        """
        # Must remain the first statement: it snapshots locals(), so any local bound earlier would leak in.
        kw = locals_to_kwargs(locals())

        # None -> a fresh empty list for each list-valued option
        (unused_key_chains, array_strs, float_strs, uint8_strs,
         custom_img_fns, custom_fns, containers_to_skip) = [
            ivy.default(arg, []) for arg in (unused_key_chains, array_strs, float_strs, uint8_strs,
                                              custom_img_fns, custom_fns, containers_to_skip)]
        # None -> a fresh [[]] (a list holding one empty list) for the two string-group options
        custom_img_strs, custom_strs = [ivy.default(arg, [[]]) for arg in (custom_img_strs, custom_strs)]

        # without a GPU, a non-list prefetch_to_devs is replaced by False
        if not (ivy.gpu_is_available() or isinstance(prefetch_to_devs, list)):
            prefetch_to_devs = False

        assert container_load_mode in ('preload', 'dynamic', 'custom')
        # 'custom' mode needs the user load function; the other modes need the filename template
        assert ivy.exists(custom_container_load_fn if container_load_mode == 'custom'
                          else dataset_spec.cont_fname_template)

        # explicit two-argument super() is deliberate: zero-arg super() would add __class__ to locals()
        super(SeqDataLoaderSpec, self).__init__(
            dataset_spec,
            batch_size=batch_size, window_size=window_size, starting_idx=starting_idx,
            num_sequences=num_sequences, num_workers=num_workers, cache_size=cache_size,
            unused_key_chains=unused_key_chains, custom_init_fn=custom_init_fn,
            container_load_mode=container_load_mode, custom_container_load_fn=custom_container_load_fn,
            preshuffle_data=preshuffle_data, shuffle_buffer_size=shuffle_buffer_size,
            with_prefetching=with_prefetching, post_proc_fn=post_proc_fn,
            prefetch_to_devs=prefetch_to_devs, single_pass=single_pass,
            array_strs=array_strs, float_strs=float_strs, uint8_strs=uint8_strs,
            custom_img_strs=custom_img_strs, custom_img_fns=custom_img_fns,
            custom_strs=custom_strs, custom_fns=custom_fns,
            array_mode=array_mode, load_gray_as_rgb=load_gray_as_rgb,
            containers_to_skip=containers_to_skip,
            **kwargs)

        # kept as a plain attribute rather than a parent argument: the name conflicts with an ivy.Container argument
        self.queue_timeout = ivy.default(queue_timeout, ivy.queue_timeout())
        self._kwargs = kw
Пример #3
0
 def __init__(self,
              dataset_spec: DatasetSpec,
              dev_strs: Union[str, List[str]] = None,
              **kwargs) -> None:
     """
     Store the general parameters that define how the data loader loads the dataset.

     :param dataset_spec: specification of the dataset being loaded.
     :param dev_strs: device string(s); when None, ['gpu:0'] is used if a GPU is available, else ['cpu'].
     :param kwargs: further parameters forwarded unchanged to the parent constructor.
     """
     # Must remain the first statement: it snapshots locals() before any other local is bound.
     kw = locals_to_kwargs(locals())
     fallback_devs = ['gpu:0'] if ivy.gpu_is_available() else ['cpu']
     super().__init__(dataset_spec=dataset_spec, dev_strs=ivy.default(dev_strs, fallback_devs), **kwargs)
     self._kwargs = kw
Пример #4
0
    def __init__(self, dev_str=None, v=None):
        """
        Initialize the Ivy layer, a stateful object holding trainable variables.

        :param dev_str: device on which to create the layer's variables, e.g. 'gpu:0' or 'cpu'.
                        Defaults to 'gpu:0' when a GPU is available, otherwise 'cpu'.
        :type dev_str: str, optional
        :param v: Ivy container of trainable variables. When omitted, variables are created via
                  self._find_and_create_variables().
        :type v: ivy container, optional
        """
        # the device is resolved first, since variable creation below may rely on self._dev_str
        self._dev_str = dev_str if dev_str is not None else ('gpu:0' if ivy.gpu_is_available() else 'cpu')
        # wrap either the caller-supplied variables or freshly created ones
        self.v = Container(self._find_and_create_variables() if v is None else v)
Пример #5
0
 def __init__(self,
              data_loader: None,
              network: Network,
              log_dir: str = 'log',
              overwrite_log_dir: bool = False,
              seed: int = 0,
              ld_chkpt: bool = False,
              save_freq: int = 1000,
              save_at_end: bool = True,
              log_freq: int = 100,
              log_at_end: bool = True,
              vis_freq: int = 500,
              vis_at_end: bool = True,
              log_validation: bool = True,
              log_time: bool = True,
              log_learning_rate: bool = True,
              starting_iteration: int = None,
              total_iterations: int = 1e6,
              initial_learning_rate: float = 1e-4,
              save_spec: bool = True,
              custom_train_step: bool = False,
              auto_detect_weights: bool = True,
              log_gradients: Union[tuple, str] = 'all',
              log_variables: Union[tuple, str] = 'all',
              log_optimizer_state: Union[tuple, str] = 'all',
              profile_start_step: int = 5,
              steps_to_profile: int = 0,
              compile_graph: Union[bool, str] = 'all',
              dev_strs: Union[str, List[str]] = None,
              dev_map_fn: str = '_split_execute_with_grads',
              tune_device_allocation: bool = True,
              tune_splitting: bool = True,
              **kwargs) -> None:
     """
     Parameters which define the training procedure.

     :param log_gradients: statistics to log for gradients. The string 'all', or any value containing
         'all', is expanded to the full list of supported statistics.
     :param log_variables: as log_gradients, for the network variables.
     :param log_optimizer_state: as log_gradients, for the optimizer state.
     :param compile_graph: defaults to the string 'all', so it is not restricted to booleans.
     :param total_iterations: NOTE(review): the default 1e6 is a float despite the int annotation;
         confirm downstream code tolerates it.
     :param dev_strs: device string(s); when None, ['gpu:0'] is used if a GPU is available, else ['cpu'].
     :param kwargs: further parameters forwarded unchanged to the parent constructor.

     NOTE(review): data_loader is annotated as None; its real type is not visible here.
     """
     # Must remain the first statement: it snapshots locals() before any other local is bound.
     kw = locals_to_kwargs(locals())
     # every statistic an 'all' request expands to (shared by gradients, variables and optimizer state)
     all_stats = ('mean', 'abs_mean', 'var', 'abs_var', 'min', 'abs_min', 'max', 'abs_max', 'vector_norm',
                  'global_vector_norm')
     # each expansion builds its own list so the three settings never alias one another
     if log_gradients == 'all' or 'all' in log_gradients:
         log_gradients = list(all_stats)
     if log_variables == 'all' or 'all' in log_variables:
         log_variables = list(all_stats)
     if log_optimizer_state == 'all' or 'all' in log_optimizer_state:
         log_optimizer_state = list(all_stats)
     super().__init__(data_loader=data_loader,
                      network=network,
                      log_dir=log_dir,
                      overwrite_log_dir=overwrite_log_dir,
                      seed=seed,
                      ld_chkpt=ld_chkpt,
                      save_freq=save_freq,
                      save_at_end=save_at_end,
                      log_freq=log_freq,
                      log_at_end=log_at_end,
                      vis_freq=vis_freq,
                      vis_at_end=vis_at_end,
                      log_validation=log_validation,
                      log_time=log_time,
                      log_learning_rate=log_learning_rate,
                      starting_iteration=starting_iteration,
                      total_iterations=total_iterations,
                      initial_learning_rate=initial_learning_rate,
                      save_spec=save_spec,
                      custom_train_step=custom_train_step,
                      auto_detect_weights=auto_detect_weights,
                      log_gradients=log_gradients,
                      log_variables=log_variables,
                      log_optimizer_state=log_optimizer_state,
                      profile_start_step=profile_start_step,
                      steps_to_profile=steps_to_profile,
                      compile_graph=compile_graph,
                      dev_strs=ivy.default(dev_strs, ['gpu:0'] if ivy.gpu_is_available() else ['cpu']),
                      dev_map_fn=dev_map_fn,
                      tune_device_allocation=tune_device_allocation,
                      tune_splitting=tune_splitting,
                      **kwargs)
     self._kwargs = kw