def get_generator(model, cfg, device, **kwargs):
    ''' Returns the generator object.

    Args:
        model (nn.Module): Occupancy Network model
        cfg (dict): imported yaml config
        device (device): pytorch device
    '''

    if cfg['data']['input_type'] == 'pointcloud_crop':
        # calculate the volume boundary
        query_vol_metric = cfg['data']['padding'] + 1
        unit_size = cfg['data']['unit_size']
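        # receptive field of the UNet encoder in voxel units; note it is
        # always derived from the 3D-UNet level count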
        recep_field = 2**(
            cfg['model']['encoder_kwargs']['unet3d_kwargs']['num_levels'] + 2)
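        # the encoder uses either a 2D UNet ('unet') or a 3D UNet ('unet3d');
        # take the depth from whichever one is configured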
        if 'unet' in cfg['model']['encoder_kwargs']:
            depth = cfg['model']['encoder_kwargs']['unet_kwargs']['depth']
        elif 'unet3d' in cfg['model']['encoder_kwargs']:
            depth = cfg['model']['encoder_kwargs']['unet3d_kwargs'][
                'num_levels']

        vol_info = decide_total_volume_range(query_vol_metric, recep_field,
                                             unit_size, depth)

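        # per-crop feature grid resolution: query volume plus receptive-field
        # context, rounded by update_reso to a size the UNet can process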
        grid_reso = cfg['data']['query_vol_size'] + recep_field - 1
        grid_reso = update_reso(grid_reso, depth)
        query_vol_size = cfg['data']['query_vol_size'] * unit_size
        input_vol_size = grid_reso * unit_size
        # only for the sliding window case
        vol_bound = None
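        # vol_bound describes one sliding-window crop: the query/input crop
        # sizes in world units, the feature representation type, and the
        # per-crop feature grid resolution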
        if cfg['generation']['sliding_window']:
            vol_bound = {
                'query_crop_size': query_vol_size,
                'input_crop_size': input_vol_size,
                'fea_type': cfg['model']['encoder_kwargs']['plane_type'],
                'reso': grid_reso
            }

    else:
        vol_bound = None
        vol_info = None

    generator = generation.Generator3D(
        model,
        device=device,
        threshold=cfg['test']['threshold'],
        resolution0=cfg['generation']['resolution_0'],
        upsampling_steps=cfg['generation']['upsampling_steps'],
        sample=cfg['generation']['use_sampling'],
        refinement_step=cfg['generation']['refinement_step'],
        simplify_nfaces=cfg['generation']['simplify_nfaces'],
        input_type=cfg['data']['input_type'],
        padding=cfg['data']['padding'],
        vol_info=vol_info,
        vol_bound=vol_bound,
    )
    return generator
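
# Minimal usage sketch (assumptions: `yaml` and `torch` are imported, a
# `get_model` helper builds the Occupancy Network, the config path is purely
# illustrative, and Generator3D exposes the usual generate_mesh(data) call):
#
#   with open('configs/pointcloud_crop/room.yaml', 'r') as f:
#       cfg = yaml.load(f, Loader=yaml.Loader)
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model = get_model(cfg, device=device)
#   generator = get_generator(model, cfg, device)
#   mesh, stats = generator.generate_mesh(data)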
    def __init__(self,
                 dataset_folder,
                 fields,
                 split=None,
                 categories=None,
                 no_except=True,
                 transform=None,
                 cfg=None):
        ''' Initialization of the 3D shape dataset.

        Args:
            dataset_folder (str): dataset folder
            fields (dict): dictionary of fields
            split (str): which split is used
            categories (list): list of categories to use
            no_except (bool): suppress exceptions raised while loading fields
            transform (callable): transformation applied to data points
            cfg (dict): imported yaml config
        '''
        # Attributes
        self.dataset_folder = dataset_folder
        self.fields = fields
        self.no_except = no_except
        self.transform = transform
        self.cfg = cfg

        # If categories is None, use all subfolders
        if categories is None:
            categories = os.listdir(dataset_folder)
            categories = [
                c for c in categories
                if os.path.isdir(os.path.join(dataset_folder, c))
            ]

        # Read metadata file
        metadata_file = os.path.join(dataset_folder, 'metadata.yaml')

        if os.path.exists(metadata_file):
            with open(metadata_file, 'r') as f:
                self.metadata = yaml.load(f, Loader=yaml.Loader)
        else:
            self.metadata = {c: {'id': c, 'name': 'n/a'} for c in categories}

        # Set index
        for c_idx, c in enumerate(categories):
            self.metadata[c]['idx'] = c_idx

        # Get all models
        self.models = []
        for c_idx, c in enumerate(categories):
            subpath = os.path.join(dataset_folder, c)
            if not os.path.isdir(subpath):
                logger.warning('Category %s does not exist in dataset.' % c)
                continue

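            # with no split file, index every model subdirectory of the
            # category; otherwise read the model ids from '<split>.lst'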
            if split is None:
                models_c = [
                    d for d in os.listdir(subpath)
                    if os.path.isdir(os.path.join(subpath, d))
                ]
            else:
                split_file = os.path.join(subpath, split + '.lst')
                with open(split_file, 'r') as f:
                    models_c = f.read().split('\n')
                # drop empty entries (e.g. from a trailing newline)
                models_c = [m for m in models_c if m != '']

            self.models += [{'category': c, 'model': m} for m in models_c]

        # precompute the query/input volume range for crop-based input
        if (self.cfg is not None
                and self.cfg['data']['input_type'] == 'pointcloud_crop'):
            self.split = split
            # proper resolution for feature plane/volume of the ENTIRE scene
            query_vol_metric = self.cfg['data']['padding'] + 1
            unit_size = self.cfg['data']['unit_size']
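            # receptive field and UNet depth, mirroring get_generator above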
            recep_field = 2**(
                self.cfg['model']['encoder_kwargs']['unet3d_kwargs'][
                    'num_levels'] + 2)
            if 'unet' in self.cfg['model']['encoder_kwargs']:
                depth = self.cfg['model']['encoder_kwargs']['unet_kwargs'][
                    'depth']
            elif 'unet3d' in self.cfg['model']['encoder_kwargs']:
                depth = self.cfg['model']['encoder_kwargs']['unet3d_kwargs'][
                    'num_levels']

            self.depth = depth
            # for the sliding-window case, make the query volume effectively
            # unbounded so the computed range covers the whole scene
            if self.cfg['generation']['sliding_window']:
                self.total_input_vol, self.total_query_vol, self.total_reso = \
                    decide_total_volume_range(
                        100000, recep_field, unit_size, depth)
            else:
                self.total_input_vol, self.total_query_vol, self.total_reso = \
                    decide_total_volume_range(
                        query_vol_metric, recep_field, unit_size, depth)