def get_generator(model, cfg, device, **kwargs):
    ''' Returns the generator object.

    For the 'pointcloud_crop' input type the crop/volume information is
    derived from the config and handed to the generator; for every other
    input type both ``vol_info`` and ``vol_bound`` are passed as None.

    Args:
        model (nn.Module): Occupancy Network model
        cfg (dict): imported yaml config
        device (device): pytorch device
        **kwargs: accepted for interface compatibility; not used here

    Returns:
        generation.Generator3D: generator configured from ``cfg``

    Raises:
        ValueError: if the input type is 'pointcloud_crop' and the encoder
            kwargs contain neither a 'unet' nor a 'unet3d' key, so the
            depth used for the volume computation cannot be determined
    '''
    vol_info = None
    vol_bound = None

    if cfg['data']['input_type'] == 'pointcloud_crop':
        # calculate the volume boundary
        query_vol_metric = cfg['data']['padding'] + 1
        unit_size = cfg['data']['unit_size']
        encoder_kwargs = cfg['model']['encoder_kwargs']
        recep_field = 2**(encoder_kwargs['unet3d_kwargs']['num_levels'] + 2)

        # The depth comes from whichever U-Net variant is configured; the
        # check is on key presence, with 'unet' taking precedence.
        if 'unet' in encoder_kwargs:
            depth = encoder_kwargs['unet_kwargs']['depth']
        elif 'unet3d' in encoder_kwargs:
            depth = encoder_kwargs['unet3d_kwargs']['num_levels']
        else:
            # Fail with a clear message instead of hitting an unbound
            # local variable a few lines further down.
            raise ValueError(
                "cfg['model']['encoder_kwargs'] must contain a 'unet' or "
                "'unet3d' entry when input_type is 'pointcloud_crop'")

        vol_info = decide_total_volume_range(query_vol_metric, recep_field,
                                             unit_size, depth)

        grid_reso = cfg['data']['query_vol_size'] + recep_field - 1
        grid_reso = update_reso(grid_reso, depth)
        query_vol_size = cfg['data']['query_vol_size'] * unit_size
        input_vol_size = grid_reso * unit_size

        # only for the sliding window case
        if cfg['generation']['sliding_window']:
            vol_bound = {
                'query_crop_size': query_vol_size,
                'input_crop_size': input_vol_size,
                'fea_type': encoder_kwargs['plane_type'],
                'reso': grid_reso
            }

    generator = generation.Generator3D(
        model,
        device=device,
        threshold=cfg['test']['threshold'],
        resolution0=cfg['generation']['resolution_0'],
        upsampling_steps=cfg['generation']['upsampling_steps'],
        sample=cfg['generation']['use_sampling'],
        refinement_step=cfg['generation']['refinement_step'],
        simplify_nfaces=cfg['generation']['simplify_nfaces'],
        input_type=cfg['data']['input_type'],
        padding=cfg['data']['padding'],
        vol_info=vol_info,
        vol_bound=vol_bound,
    )
    return generator
    def get_vol_info(self, model_path):
        ''' Get crop information

        For the 'train' split a random center is drawn inside the axis-wise
        bounds of the loaded points, and the input/query volumes are placed
        around it. For every other split the precomputed ``self.total_*``
        values are returned.

        Args:
            model_path (str): path to the current data

        Returns:
            dict: with keys 'plane_type', 'reso', 'input_vol' and
                'query_vol'. For the 'train' split the two volumes are
                ``[lower_bound, upper_bound]`` lists.
        '''
        query_vol_size = self.cfg['data']['query_vol_size']
        unit_size = self.cfg['data']['unit_size']
        field_name = self.cfg['data']['pointcloud_file']
        plane_type = self.cfg['model']['encoder_kwargs']['plane_type']
        recep_field = 2**(
            self.cfg['model']['encoder_kwargs']['unet3d_kwargs']['num_levels']
            + 2)

        if self.cfg['data']['multi_files'] is None:
            file_path = os.path.join(model_path, field_name)
        else:
            # pick one of the numbered point cloud files at random
            num = np.random.randint(self.cfg['data']['multi_files'])
            file_path = os.path.join(model_path, field_name,
                                     '%s_%02d.npz' % (field_name, num))

        # Close the archive once the array has been read, so the file
        # handle is not left open.
        # NOTE(review): the points are only used for the 'train' split, but
        # the file is still loaded for every split, as before.
        with np.load(file_path) as points_dict:
            p = points_dict['points']

        if self.split == 'train':
            # randomly sample a point as the center of input/query volume
            p_c = [
                np.random.uniform(p[:, i].min(), p[:, i].max())
                for i in range(3)
            ]
            p_c = np.array(p_c).astype(np.float32)

            reso = query_vol_size + recep_field - 1
            # make sure the defined reso can be properly processed by UNet
            reso = update_reso(reso, self.depth)
            input_vol_metric = reso * unit_size
            query_vol_metric = query_vol_size * unit_size

            # bound for the volumes, centered on the sampled point
            input_vol = [
                p_c - input_vol_metric / 2,
                p_c + input_vol_metric / 2,
            ]
            query_vol = [
                p_c - query_vol_metric / 2,
                p_c + query_vol_metric / 2,
            ]
        else:
            reso = self.total_reso
            input_vol = self.total_input_vol
            query_vol = self.total_query_vol

        vol_info = {
            'plane_type': plane_type,
            'reso': reso,
            'input_vol': input_vol,
            'query_vol': query_vol
        }
        return vol_info
def get_model(cfg, device=None, dataset=None, **kwargs):
    ''' Return the Occupancy Network model.

    Note that ``cfg['model']['encoder_kwargs']`` and
    ``cfg['model']['decoder_kwargs']`` are updated in place with the
    options derived below (unit size, positional encoding settings and
    feature resolutions).

    Args:
        cfg (dict): imported yaml config
        device (device): pytorch device
        dataset (dataset): dataset; it is accessed when the input type is
            'pointcloud_crop' (``split``, ``depth``, ``total_reso``) or
            when the encoder is 'idx' (``len(dataset)``)
        **kwargs: accepted for interface compatibility; not used here

    Returns:
        models.ConvolutionalOccupancyNetwork: model built from the
            configured decoder and encoder (the encoder may be None)
    '''
    decoder = cfg['model']['decoder']
    encoder = cfg['model']['encoder']
    dim = cfg['data']['dim']
    c_dim = cfg['model']['c_dim']
    decoder_kwargs = cfg['model']['decoder_kwargs']
    encoder_kwargs = cfg['model']['encoder_kwargs']
    padding = cfg['data']['padding']

    # for pointcloud_crop: forward the unit size only when it is configured.
    # An explicit presence check replaces a bare ``except: pass`` so that
    # unrelated errors are no longer silently swallowed.
    if 'unit_size' in cfg['data']:
        encoder_kwargs['unit_size'] = cfg['data']['unit_size']
        decoder_kwargs['unit_size'] = cfg['data']['unit_size']

    # local positional encoding: optional settings shared by both networks
    for option in ('local_coord', 'pos_encoding'):
        if option in cfg['model']:
            encoder_kwargs[option] = cfg['model'][option]
            decoder_kwargs[option] = cfg['model'][option]

    # update the feature volume/plane resolution
    if cfg['data']['input_type'] == 'pointcloud_crop':
        fea_type = cfg['model']['encoder_kwargs']['plane_type']

        # encoder options that need a resolution for the given feature types
        resolution_keys = []
        if 'grid' in fea_type:
            resolution_keys.append('grid_resolution')
        if not {'xz', 'xy', 'yz'}.isdisjoint(fea_type):
            resolution_keys.append('plane_resolution')

        if (dataset.split == 'train') or (cfg['generation']['sliding_window']):
            recep_field = 2**(
                cfg['model']['encoder_kwargs']['unet3d_kwargs']['num_levels'] +
                2)
            reso = cfg['data']['query_vol_size'] + recep_field - 1
            for key in resolution_keys:
                encoder_kwargs[key] = update_reso(reso, dataset.depth)
        # if dataset.split == 'val': #TODO run validation in room level during training
        else:
            for key in resolution_keys:
                encoder_kwargs[key] = dataset.total_reso

    decoder = models.decoder_dict[decoder](dim=dim,
                                           c_dim=c_dim,
                                           padding=padding,
                                           **decoder_kwargs)

    if encoder == 'idx':
        # one embedding per dataset item
        encoder = nn.Embedding(len(dataset), c_dim)
    elif encoder is not None:
        encoder = encoder_dict[encoder](dim=dim,
                                        c_dim=c_dim,
                                        padding=padding,
                                        **encoder_kwargs)
    else:
        encoder = None

    model = models.ConvolutionalOccupancyNetwork(decoder,
                                                 encoder,
                                                 device=device)

    return model