Example no. 1
    def __init__(self,
                 dataset: ScanObjectNNDataset,
                 use_potential=True,
                 balance_labels=False):
        Sampler.__init__(self, dataset)

        # Does the sampler use potential for regular sampling
        self.use_potential = use_potential

        # Should we balance the classes when sampling
        self.balance_labels = balance_labels

        # Dataset used by the sampler (no copy is made in memory)
        self.dataset = dataset

        # Create potentials
        if self.use_potential:
            self.potentials = np.random.rand(len(
                dataset.input_labels)) * 0.1 + 0.1
        else:
            self.potentials = None

        # Initialize value for batch limit (max number of points per batch).
        self.batch_limit = 100000

        return
Example no. 2
 def __init__(self, num_samples):
     """
     All weights are initialized to the same value; the first round of samples
     fed to the network is drawn according to these weights.
     :param num_samples: total number of samples (size of the weight vector)
     """
     Sampler.__init__(self, data_source=None)
     self.weights = torch.zeros(num_samples).fill_(1 / num_samples)
     self.num_samples = num_samples
     self.sampled_index = None
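The class above only builds the uniform weights; an `__iter__` presumably draws indices according to them and records what was drawn so the weights can later be updated. A minimal sketch of such a completion, assuming the class is used as a plain index sampler (this is not the original code):

 # Assumed continuation (not in the original snippet): sample indices according
 # to the current weights and remember them for a later weight update.
 def __iter__(self):
     self.sampled_index = torch.multinomial(
         self.weights, self.num_samples, replacement=True)
     return iter(self.sampled_index.tolist())

 def __len__(self):
     return self.num_samples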
Example no. 3
 def __init__(
         self,
         sequence_strs: List[str],
         toks_per_batch: int,
         crop_sizes: Tuple[int, int] = (512, 1024),
 ):
     Sampler.__init__(self, data_source=None)
     self._sequence_strs = sequence_strs
     self._toks_per_batch = toks_per_batch
     self._crop_sizes = crop_sizes
     self._init_batches = get_batch_indices(
         sequence_strs=sequence_strs,
         toks_per_batch=toks_per_batch,
         crop_sizes=crop_sizes,
     )
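`get_batch_indices` presumably packs the sequences into batches of at most `toks_per_batch` tokens, using crop sizes drawn from `crop_sizes`; the sampler then only has to hand those precomputed batches to the data loader. An assumed (not original) completion could be as short as:

 # Assumed continuation (not in the original snippet): yield the precomputed batches.
 def __iter__(self):
     return iter(self._init_batches)

 def __len__(self):
     return len(self._init_batches)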
Example no. 4
    def __init__(self, dataset: ModelNet40Dataset, use_potential=True, balance_labels=False):
        Sampler.__init__(self, dataset)
        self.use_potential = use_potential
        self.balance_labels = balance_labels  # Whether to balance the classes when sampling

        self.dataset = dataset  # Dataset used by the sampler (no copy is made in memory)

        if self.use_potential:
            # Size equals the number of point clouds in the dataset
            self.potentials = np.random.rand(len(dataset.labels)) * 0.1 + 0.1
        else:
            self.potentials = None
        # Maximum number of points a batch can contain
        self.batch_limit = 10000

        return
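Examples 1 and 4 follow the same pattern: every point cloud gets a small random potential, an epoch visits the clouds with the lowest potentials (optionally class-balanced), and the potentials of visited clouds are raised so under-sampled clouds are preferred next time. A rough sketch of that selection step, with a hypothetical `epoch_n` and ignoring both class balancing and `batch_limit` (not the original `__iter__`):

    # Assumed sketch (not the original code): visit the clouds with the lowest
    # potentials, then raise their potentials so other clouds win next epoch.
    def __iter__(self):
        epoch_n = min(1000, len(self.dataset.labels))  # hypothetical epoch length
        if self.use_potential:
            gen_indices = np.argsort(self.potentials)[:epoch_n]
            self.potentials[gen_indices] += np.random.rand(epoch_n) * 0.1 + 0.1
        else:
            gen_indices = np.random.permutation(len(self.dataset.labels))[:epoch_n]
        return iter(gen_indices)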
Example no. 5
    def __init__(
        self,
        sequence_strs: List[str],
        toks_per_batch: int,
        crop_sizes: Tuple[int, int] = (512, 1024),
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
    ):
        Sampler.__init__(self, data_source=None)

        # Replicate Torch Distributed Sampler logic
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError(
                    "Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError(
                    "Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError("Invalid rank {}, rank should be in the interval"
                             " [0, {}]".format(rank, num_replicas - 1))

        self._num_replicas = num_replicas
        self._rank = rank
        self._epoch = 0
        self._seed = seed

        self._sequence_strs = sequence_strs
        self._toks_per_batch = toks_per_batch
        self._crop_sizes = crop_sizes
        self._init_batches = get_batch_indices(
            sequence_strs=sequence_strs,
            toks_per_batch=toks_per_batch,
            crop_sizes=crop_sizes,
            seed=self._seed + self._epoch,
        )
        self._num_samples = math.ceil(
            len(self._init_batches) / self._num_replicas)
        self._total_size = self._num_samples * self._num_replicas
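Because the constructor mirrors `torch.utils.data.DistributedSampler`, the rest of the class presumably does too: a `set_epoch` hook re-seeds the batching, and `__iter__` pads the batch list to `_total_size` and gives each replica the interleaved slice `rank::num_replicas`. A hedged sketch under those assumptions (not necessarily the original code):

    # Assumed sketch, mirroring DistributedSampler behaviour (not necessarily
    # the original code).
    def set_epoch(self, epoch: int):
        self._epoch = epoch
        # Re-run the batching with the new epoch seed
        # (assumes the number of batches stays the same across seeds).
        self._init_batches = get_batch_indices(
            sequence_strs=self._sequence_strs,
            toks_per_batch=self._toks_per_batch,
            crop_sizes=self._crop_sizes,
            seed=self._seed + self._epoch,
        )

    def __iter__(self):
        # Pad so every replica sees the same number of batches, then slice per rank.
        batches = list(self._init_batches)
        batches += batches[:self._total_size - len(batches)]
        return iter(batches[self._rank:self._total_size:self._num_replicas])

    def __len__(self):
        return self._num_samples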
Example no. 6
 def __init__(self, dataset, *args, **kwargs):
     Sampler.__init__(self, dataset)
     # Forward any remaining arguments to the subclass-specific init()
     self.init(*args, **kwargs)
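Any of these samplers is used by handing it to `torch.utils.data.DataLoader`: a sampler that yields single indices goes in through `sampler=`, while one that yields a whole list of indices per batch (as examples 3 and 5 appear to, given `toks_per_batch`) goes in through `batch_sampler=`. A hypothetical usage sketch (variable names are illustrative only):

from torch.utils.data import DataLoader

# `dataset`, `index_sampler` and `batch_sampler` are placeholder names.
loader_a = DataLoader(dataset, batch_size=8, sampler=index_sampler)  # one index at a time
loader_b = DataLoader(dataset, batch_sampler=batch_sampler)          # a list of indices per batch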