import math

import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler


class StratifiedSampler(DistributedSampler):
    # Class name assumed for illustration; the original snippet shows only __init__.
    def __init__(self, dataset, stratify=None, undersample=None, distributed=False,
                 num_replicas=None, rank=None):
        self.stratify = stratify
        self.undersample = undersample
        self.distributed = distributed
        if self.distributed:
            DistributedSampler.__init__(self, dataset, num_replicas=num_replicas, rank=rank)
        else:
            # TODO: set the other attributes defined by DistributedSampler when
            # distributed is off.
            self.dataset = dataset
            self.num_replicas = 1
            self.rank = 0
            self.epoch = 0
        if self.stratify is not None:
            # Global indices of the positive (label 1) and negative (label 0) examples.
            self.pos_stratify = np.where(self.stratify == 1)[0]
            self.neg_stratify = np.where(self.stratify == 0)[0]
            self.Npos = int(self.stratify.sum())
            self.Nneg = int(self.stratify.size) - self.Npos
            # Per-replica sample counts; negatives are optionally undersampled to a
            # fixed multiple of the positive count.
            self.pos_num_samples = int(math.ceil(self.Npos * 1.0 / self.num_replicas))
            if self.undersample is not None:
                self.neg_num_samples = int(self.undersample * self.pos_num_samples)
            else:
                self.neg_num_samples = int(math.ceil(self.Nneg * 1.0 / self.num_replicas))
            self.num_samples = self.pos_num_samples + self.neg_num_samples
            self.pos_total_size = self.pos_num_samples * self.num_replicas
            self.neg_total_size = self.neg_num_samples * self.num_replicas
            if self.undersample is not None:
                # Pick a seeded random subset of negatives once, so the undersampled
                # pool stays fixed across epochs and replicas.
                g = torch.Generator()
                g.manual_seed(0)
                neg_indices = torch.randperm(self.Nneg, generator=g)
                self.neg_indices_init = neg_indices[:self.neg_total_size]
                self.neg_used_indices = self.neg_stratify[self.neg_indices_init.numpy()]
            else:
                self.neg_used_indices = self.neg_stratify
            # Global indices actually used (fixed across epochs), kept for
            # convenient read-out.
            self.pos_used_indices = self.pos_stratify
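    # A minimal sketch, not part of the original snippet, of an __iter__ consistent
    # with the attributes set above (it assumes stratify was provided): reshuffle
    # both classes each epoch, pad each to its per-replica total, take this rank's
    # strided slice, and mix the two classes. The epoch-seeded RNG mirrors
    # DistributedSampler's convention; everything else is an assumption about how
    # the author's full class behaves.
    def __iter__(self):
        rng = np.random.RandomState(self.epoch)
        pos = rng.permutation(self.pos_used_indices)
        neg = rng.permutation(self.neg_used_indices)
        # Pad each class so it divides evenly across replicas.
        pos = np.concatenate([pos, pos[:self.pos_total_size - len(pos)]])
        neg = np.concatenate([neg, neg[:self.neg_total_size - len(neg)]])
        # Each rank takes every num_replicas-th index of each class.
        indices = np.concatenate([pos[self.rank:self.pos_total_size:self.num_replicas],
                                  neg[self.rank:self.neg_total_size:self.num_replicas]])
        rng.shuffle(indices)
        return iter(indices.tolist())

    def __len__(self):
        return self.num_samples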
class EpochSizeSampler(DistributedSampler):
    # Class name assumed for illustration; the original snippet shows only __init__.
    def __init__(self, dataset, num_replicas, rank, epoch_size=None):
        DistributedSampler.__init__(self, dataset, num_replicas, rank)
        # Optionally cap the epoch at epoch_size examples (defaults to the full
        # dataset), then split that budget evenly across replicas.
        epoch_size = len(dataset) if epoch_size is None else epoch_size
        self.num_samples = int(math.ceil(epoch_size * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
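# A hypothetical usage sketch (the sampler class name above and this wiring are
# assumptions, not from the original source): with 4 replicas and
# epoch_size=10000, each replica draws ceil(10000 / 4) = 2500 samples per epoch.
#
#     sampler = EpochSizeSampler(dataset, num_replicas=4, rank=0, epoch_size=10000)
#     loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)
#     assert sampler.num_samples == 2500 and sampler.total_size == 10000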