def create_samplers(self, hps):
    """Build the train and test samplers and store them on ``self``.

    When ``dist.is_available()`` is true, each dataset gets a
    ``DistributedSampler``. Otherwise each dataset gets a ``BatchSampler``
    wrapping a ``RandomSampler``, with batches of ``hps.bs`` items and the
    final incomplete batch dropped.

    Args:
        hps: Hyperparameter object; only ``hps.bs`` (batch size) is read,
            and only in the non-distributed case.

    Sets:
        self.train_sampler, self.test_sampler
    """
    if dist.is_available():
        self.train_sampler = DistributedSampler(self.train_dataset)
        self.test_sampler = DistributedSampler(self.test_dataset)
        return

    def _random_batches(dataset):
        # Shuffled batches of hps.bs indices; incomplete last batch is dropped.
        return BatchSampler(RandomSampler(dataset),
                            batch_size=hps.bs,
                            drop_last=True)

    self.train_sampler = _random_batches(self.train_dataset)
    self.test_sampler = _random_batches(self.test_dataset)
예제 #2
0
def setup_dist_from_mpi(
    master_addr="127.0.0.1", backend="nccl", port=29500, n_attempts=5, verbose=False
):
    """Set up the process for (possibly distributed) execution.

    When ``dist.is_available()`` is true, all arguments are forwarded to
    ``_setup_dist_from_mpi`` and its result is returned unchanged.
    Otherwise a single-process fallback is used: rank 0, local rank 0, and
    the first CUDA device if CUDA is available, else the CPU.

    Args:
        master_addr: Forwarded to ``_setup_dist_from_mpi``.
        backend: Forwarded to ``_setup_dist_from_mpi``.
        port: Forwarded to ``_setup_dist_from_mpi``.
        n_attempts: Forwarded to ``_setup_dist_from_mpi``.
        verbose: Forwarded to ``_setup_dist_from_mpi``.

    Returns:
        In the fallback path, the tuple ``(mpi_rank, local_rank, device)``
        with both ranks equal to 0. In the distributed path, whatever
        ``_setup_dist_from_mpi`` returns.
    """
    if dist.is_available():
        return _setup_dist_from_mpi(master_addr, backend, port, n_attempts, verbose)

    # Single-process fallback.
    use_cuda = torch.cuda.is_available()
    print(f'Using cuda {use_cuda}')

    mpi_rank = 0
    local_rank = 0

    if use_cuda:
        device = torch.device("cuda", local_rank)
        # Only select a CUDA device when one exists; calling set_device on a
        # machine without CUDA raises instead of falling back to the CPU.
        torch.cuda.set_device(local_rank)
    else:
        device = torch.device("cpu")

    return mpi_rank, local_rank, device
    def init_dataset(self, hps):
        """Discover audio files, measure them, and prepare labelling.

        Finds audio files under ``hps.audio_files_dir``, computes each
        file's duration multiplied by ``self.sr``, and hands the file list
        and durations to ``self.filter``. If ``self.labels`` is truthy, a
        ``Labeller`` is created and stored as ``self.labeller``.

        Args:
            hps: Hyperparameter object; reads ``audio_files_dir`` and, when
                labels are enabled, ``max_bow_genre_size``, ``n_tokens`` and
                ``labels_v3``.
        """
        # Load list of files and starts/durations
        audio_exts = ['mp3', 'opus', 'm4a', 'aac', 'wav']
        files = librosa.util.find_files(f'{hps.audio_files_dir}', audio_exts)
        print_all(f"Found {len(files)} files. Getting durations")

        # Caching is enabled when not distributed, or on every 8th rank.
        cache = not dist.is_available() or dist.get_rank() % 8 == 0
        lengths = [
            get_duration_sec(path, cache=cache) * self.sr for path in files
        ]
        durations = np.array(lengths)  # Could be approximate
        self.filter(files, durations)

        if not self.labels:
            return
        self.labeller = Labeller(hps.max_bow_genre_size,
                                 hps.n_tokens,
                                 self.sample_length,
                                 v3=hps.labels_v3)
예제 #4
0
def calculate_bandwidth(dataset, hps, duration=600):
    """Estimate amplitude and spectral statistics of a dataset.

    Walks the dataset with a fixed stride, starting at this process's rank,
    until at least ``dataset.sr * duration`` sample values have been seen.
    When ``dist.is_available()`` is true, the accumulators are combined
    across processes with ``allreduce`` before the statistics are computed.

    Args:
        dataset: Indexable dataset exposing ``sr``. Each item is either an
            array or a 2-element tuple/list whose first element is the array.
            Arrays are averaged over axis 1 before the STFT (presumably the
            channel axis -- TODO confirm).
        hps: Hyperparameters, wrapped in ``DefaultSTFTValues``; ``n_fft``,
            ``hop_length`` and ``window_size`` are read from the result.
        duration: Multiplied by ``dataset.sr`` to give the number of sample
            values to accumulate before stopping.

    Returns:
        dict with keys:
            ``l2``: mean of squares minus squared mean of the sample values.
            ``l1``: mean absolute sample value.
            ``spec``: mean over visited items of the norm of the STFT
                magnitude.
    """
    hps = DefaultSTFTValues(hps)
    n_samples = int(dataset.sr * duration)

    # Query the distributed layer only when it is available, matching the
    # guard used for the allreduce step below; fall back to one process.
    distributed = dist.is_available()
    idx = dist.get_rank() if distributed else 0
    # The stride is loop-invariant, so compute it once.
    stride = max(16, dist.get_world_size()) if distributed else 16

    l1, total, total_sq, n_seen = 0.0, 0.0, 0.0, 0.0
    spec_norm_total, spec_nelem = 0.0, 0.0
    while n_seen < n_samples:
        x = dataset[idx]
        if isinstance(x, (tuple, list)):
            # Only the first element is used; the second is discarded.
            x, _ = x
        samples = x.astype(np.float64)
        # n_fft is passed by keyword: newer librosa releases make every
        # argument after the signal keyword-only.
        stft = librosa.core.stft(np.mean(samples, axis=1),
                                 n_fft=hps.n_fft,
                                 hop_length=hps.hop_length,
                                 win_length=hps.window_size)
        spec = np.absolute(stft)
        spec_norm_total += np.linalg.norm(spec)
        spec_nelem += 1
        n_seen += int(np.prod(samples.shape))
        l1 += np.sum(np.abs(samples))
        total += np.sum(samples)
        total_sq += np.sum(samples**2)
        idx += stride

    if distributed:
        from jukebox.utils.dist_utils import allreduce
        n_seen = allreduce(n_seen)
        total = allreduce(total)
        total_sq = allreduce(total_sq)
        l1 = allreduce(l1)
        spec_nelem = allreduce(spec_nelem)
        spec_norm_total = allreduce(spec_norm_total)

    mean = total / n_seen
    bandwidth = dict(l2=total_sq / n_seen - mean**2,
                     l1=l1 / n_seen,
                     spec=spec_norm_total / spec_nelem)
    print_once(bandwidth)
    return bandwidth
예제 #5
0
def print_all(msg):
    """Print ``msg`` once per group of eight ranks.

    When ``dist.is_available()`` is false the message is printed as-is.
    Otherwise only ranks divisible by 8 print, prefixing the message with
    ``rank // 8`` (presumably a node index with 8 processes per node --
    TODO confirm).
    """
    if not dist.is_available():
        print(msg)
        return

    rank = dist.get_rank()
    if rank % 8 == 0:
        print(f'{rank // 8}: {msg}')
예제 #6
0
def print_once(msg):
    """Print ``msg`` from rank 0 only.

    When ``dist.is_available()`` is false the message is always printed.
    """
    # Every rank other than 0 stays silent.
    if dist.is_available() and dist.get_rank() != 0:
        return
    print(msg)