Example #1
def setup_dist_from_mpi(master_addr="127.0.0.1",
                        backend="nccl",
                        port=29500,
                        n_attempts=5,
                        verbose=False):
    if dist.is_available():
        if not dist.is_initialized():
            return _setup_dist_from_mpi(master_addr, backend, port, n_attempts,
                                        verbose)
        else:
            from mpi4py import MPI
            mpi_rank = MPI.COMM_WORLD.Get_rank()
            use_cuda = torch.cuda.is_available()
            local_rank = mpi_rank % 8
            device = torch.device(
                "cuda", local_rank) if use_cuda else torch.device("cpu")
            return mpi_rank, local_rank, device
    else:
        use_cuda = torch.cuda.is_available()
        print(f'Using cuda {use_cuda}')

        mpi_rank = 0
        local_rank = 0

        device = torch.device("cuda",
                              local_rank) if use_cuda else torch.device("cpu")
        torch.cuda.set_device(local_rank)

        return mpi_rank, local_rank, device
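
For context, a minimal usage sketch of how a setup helper like this is typically called from a training or sampling script; the import path and the toy model below are assumptions, not taken from the example above.

import torch
from jukebox.utils.dist_utils import setup_dist_from_mpi  # assumed import path

rank, local_rank, device = setup_dist_from_mpi(port=29500)
model = torch.nn.Linear(64, 64).to(device)  # place this rank's copy of the model on its device
print(f"rank={rank} local_rank={local_rank} device={device}")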
Example #2
    def create_samplers(self, hps):
        if not dist.is_available():
            self.train_sampler = BatchSampler(RandomSampler(self.train_dataset), batch_size=hps.bs, drop_last=True)
            self.test_sampler = BatchSampler(RandomSampler(self.test_dataset), batch_size=hps.bs, drop_last=True)
        else:
            self.train_sampler = DistributedSampler(self.train_dataset)
            self.test_sampler = DistributedSampler(self.test_dataset)
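
A hedged sketch of how samplers like these are usually handed to a DataLoader: a BatchSampler already yields whole batches of indices, while a DistributedSampler yields single indices, so the batching arguments move onto the loader. The helper name and num_workers value are placeholders.

from torch.utils.data import BatchSampler, DataLoader

def make_loader(dataset, sampler, bs, num_workers=4):
    if isinstance(sampler, BatchSampler):
        # BatchSampler already yields lists of indices, i.e. whole batches.
        return DataLoader(dataset, batch_sampler=sampler, num_workers=num_workers)
    # DistributedSampler yields one index at a time, so batch_size/drop_last go on the loader.
    return DataLoader(dataset, sampler=sampler, batch_size=bs, drop_last=True, num_workers=num_workers)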
Example #3
    def init_dataset(self, hps):
        # Load list of files and starts/durations
        files = librosa.util.find_files(f'{hps.audio_files_dir}', ['mp3', 'opus', 'm4a', 'aac', 'wav'])
        print_all(f"Found {len(files)} files. Getting durations")
        cache = dist.get_rank() % 8 == 0 if dist.is_available() else True
        durations = np.array([get_duration_sec(file, cache=cache) * self.sr for file in files])  # Could be approximate
        self.filter(files, durations)

        if self.labels:
            self.labeller = Labeller(hps.max_bow_genre_size, hps.n_tokens, self.sample_length, v3=hps.labels_v3)
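
get_duration_sec is not shown on this page; below is a minimal sketch of a cached duration helper consistent with the cache flag above. The .dur sidecar file name is an assumption, and note that librosa.get_duration takes path= instead of filename= from librosa 0.10 onward.

import librosa

def get_duration_sec(file, cache=False):
    # Hypothetical helper: read a cached duration if present, else compute it and optionally cache it.
    try:
        with open(file + '.dur') as f:
            return float(f.readline().strip())
    except FileNotFoundError:
        duration = librosa.get_duration(filename=file)  # use path=file on librosa >= 0.10
        if cache:
            with open(file + '.dur', 'w') as f:
                f.write(f'{duration}\n')
        return duration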
Example #4
def setup_dist_from_mpi(master_addr="127.0.0.1",
                        backend="nccl",
                        port=29500,
                        n_attempts=5,
                        verbose=False):
    print('Welcome to jukebox-opt, a system-RAM-optimized version of jukebox.')
    print(
        'For most notebooks/environments this acts as a mostly drop-in replacement.'
    )
    print()
    print(
        'I say mostly because most notebooks/environments tend to load the tokens (zs)'
    )
    print(
        'onto the GPU, while I want them on the CPU to allow for longer songs; my code'
    )
    print(
        'will put them back once loaded from a checkpoint. If the song is longer than normal,'
    )
    print(
        'loading it will use too much GPU memory and probably error out later in'
    )
    print(
        "generation. To fix this, go through your notebook and change all 'cuda' to 'cpu'"
    )
    print('and all .cuda() to .cpu() for the tokens (zs).')
    print()
    print(
        "Example: change \"zs = t.load(blablabla, location='cuda')\" to \"zs = t.load(blablabla, location='cpu')\""
    )
    print(
        'Example: change "zs = [ z.cuda() for z in zs ]" to "zs = [ z.cpu() for z in zs ]"'
    )
    print('Example: change "zs[blabla].cuda()" to "zs[blabla].cpu()"')
    print()
    print()

    if dist.is_available():
        return _setup_dist_from_mpi(master_addr, backend, port, n_attempts,
                                    verbose)
    else:
        use_cuda = torch.cuda.is_available()
        print(f'Using cuda {use_cuda}')

        mpi_rank = 0
        local_rank = 0

        device = torch.device("cuda",
                              local_rank) if use_cuda else torch.device("cpu")
        torch.cuda.set_device(local_rank)

        return mpi_rank, local_rank, device
Example #5
def setup_dist_from_mpi(
    master_addr="127.0.0.1", backend="nccl", port=29500, n_attempts=5, verbose=False
):
    if dist.is_available():
        return _setup_dist_from_mpi(master_addr, backend, port, n_attempts, verbose)
    else:
        use_cuda = torch.cuda.is_available()
        print(f'Using cuda {use_cuda}')

        mpi_rank = 0
        local_rank = 0

        device = torch.device("cuda", local_rank) if use_cuda else torch.device("cpu")
        torch.cuda.set_device(local_rank)

        return mpi_rank, local_rank, device
Example #6
def calculate_bandwidth(dataset, hps, duration=600):
    hps = DefaultSTFTValues(hps)
    n_samples = int(dataset.sr * duration)
    l1, total, total_sq, n_seen, idx = 0.0, 0.0, 0.0, 0.0, dist.get_rank()
    spec_norm_total, spec_nelem = 0.0, 0.0
    while n_seen < n_samples:
        x = dataset[idx]
        if isinstance(x, (tuple, list)):
            x, y = x
        samples = x.astype(np.float64)
        stft = librosa.core.stft(np.mean(samples, axis=1),
                                 n_fft=hps.n_fft,
                                 hop_length=hps.hop_length,
                                 win_length=hps.window_size)
        spec = np.absolute(stft)
        spec_norm_total += np.linalg.norm(spec)
        spec_nelem += 1
        n_seen += int(np.prod(samples.shape))
        l1 += np.sum(np.abs(samples))
        total += np.sum(samples)
        total_sq += np.sum(samples**2)
        idx += max(16, dist.get_world_size())

    if dist.is_available():
        from jukebox.utils.dist_utils import allreduce
        n_seen = allreduce(n_seen)
        total = allreduce(total)
        total_sq = allreduce(total_sq)
        l1 = allreduce(l1)
        spec_nelem = allreduce(spec_nelem)
        spec_norm_total = allreduce(spec_norm_total)

    mean = total / n_seen
    bandwidth = dict(l2=total_sq / n_seen - mean**2,
                     l1=l1 / n_seen,
                     spec=spec_norm_total / spec_nelem)
    print_once(bandwidth)
    return bandwidth
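
The allreduce imported above is a jukebox helper that is not shown here; the following is a minimal sketch of what a scalar sum-allreduce can look like on top of torch.distributed, assuming an NCCL backend (hence the CUDA tensor).

import torch
import torch.distributed as dist

def allreduce(x, op=dist.ReduceOp.SUM):
    # Hypothetical scalar all-reduce: every rank contributes x, every rank gets back the reduced value.
    t = torch.tensor(float(x)).cuda()  # NCCL requires CUDA tensors; use a CPU tensor with the gloo backend
    dist.all_reduce(t, op=op)
    return t.item()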
Example #7
    def set_epoch(self, epoch):
        if dist.is_available():
            self.train_sampler.set_epoch(epoch)
            self.test_sampler.set_epoch(epoch)
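
A hedged sketch of where set_epoch sits in a training loop: calling DistributedSampler.set_epoch every epoch is what re-seeds the per-rank shuffle. The hps.epochs field, the data_processor name for the object owning these samplers, and train_one_epoch are placeholders.

for epoch in range(hps.epochs):
    data_processor.set_epoch(epoch)   # no-op without dist; otherwise re-seeds DistributedSampler shuffling
    train_one_epoch(model, train_loader)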
Example #8
def print_once(msg):
    if (not dist.is_available()) or dist.get_rank() == 0:
        print(msg)
Example #9
def print_all(msg):
    if (not dist.is_available()):
        print(msg)
    elif dist.get_rank() % 8 == 0:
        print(f'{dist.get_rank()//8}: {msg}')
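
A short usage sketch of these two logging helpers: print_once is for run-level messages from rank 0, while print_all prints once per 8-GPU node and prefixes the node index (the node size of 8 matches the mpi_rank % 8 used in Example #1). The messages below are placeholders.

print_once("Starting training run")   # rank 0 only (or always, when torch.distributed is unavailable)
print_all("Found 5000 audio files")   # one print per node, prefixed with the node index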