Exemplo n.º 1
0
def run(hps="teeny", port=29500, **kwargs):
    """Set up distributed training and run the train / evaluate epoch loop.

    Args:
        hps: hyper-parameter set identifier handed to ``setup_hparams``
            together with ``kwargs`` (which act as overrides).
        port: port forwarded to ``setup_dist_from_mpi``.
        **kwargs: hyper-parameter overrides.

    A VQ-VAE is always built. When ``hps.prior`` is truthy, a prior is built on
    top of it and becomes the model being optimised. For each epoch in
    ``[hps.curr_epoch, hps.epochs)`` the model is trained (if ``hps.train``)
    and/or evaluated (if ``hps.test``). Rank 0 prints the returned metrics.
    """
    from app.jukebox.utils.dist_utils import setup_dist_from_mpi

    def _fmt_metrics(stats):
        # "key: value" pairs with four decimals, space separated.
        return ' '.join(f'{name}: {value:0.4f}' for name, value in stats.items())

    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = setup_hparams(hps, kwargs)
    hps.ngpus = dist.get_world_size()
    hps.argv = " ".join(sys.argv)
    batch = hps.bs
    hps.bs_sample = batch
    hps.nworkers = batch

    # Dataset
    data_processor = DataProcessor(hps)

    # Models: the VQ-VAE is trained directly unless a prior is requested.
    vqvae = make_vqvae(hps, device)
    print_once(f"Parameters VQVAE:{count_parameters(vqvae)}")
    model = vqvae
    if hps.prior:
        model = make_prior(hps, vqvae, device)
        print_once(f"Parameters Prior:{count_parameters(model)}")

    # Optimiser / scheduler / scaler, EMA and the DDP wrapper.
    opt, shd, scalar = get_optimizer(model, hps)
    ema = get_ema(model, hps)
    distributed_model = get_ddp(model, hps)

    logger, metrics = init_logging(hps, local_rank, rank)
    logger.iters = model.step

    for epoch in range(hps.curr_epoch, hps.epochs):
        metrics.reset()
        data_processor.set_epoch(epoch)

        if hps.train:
            epoch_stats = train(distributed_model, model, opt, shd, scalar,
                                ema, logger, metrics, data_processor, hps)
            epoch_stats['epoch'] = epoch
            if rank == 0:
                print('Train', _fmt_metrics(epoch_stats))
            dist.barrier()

        if hps.test:
            # ema.swap() brackets evaluation (presumably to evaluate with the
            # EMA weights, then restore the live weights).
            if ema:
                ema.swap()
            epoch_stats = evaluate(distributed_model, model, logger, metrics,
                                   data_processor, hps)
            epoch_stats['epoch'] = epoch
            if rank == 0:
                print('Ema', _fmt_metrics(epoch_stats))
            dist.barrier()
            if ema:
                ema.swap()

        dist.barrier()
Exemplo n.º 2
0
def sample_level(zs, labels, sampling_kwargs, level, prior, total_length,
                 hop_length, hps):
    """Sample all tokens at ``level`` up to ``total_length``.

    If ``total_length`` is at least one context window (``prior.n_ctx``), the
    level is filled window by window, starting at each offset produced by
    ``get_starts(total_length, prior.n_ctx, hop_length)``. Otherwise a single
    partial window of ``total_length`` tokens is sampled. Returns the updated
    ``zs``.
    """
    print_once(f"Sampling level {level}")
    context = prior.n_ctx
    if total_length < context:
        # Shorter than one context window: sample it in one partial window.
        return sample_partial_window(zs, labels, sampling_kwargs, level, prior,
                                     total_length, hps)
    for window_start in get_starts(total_length, context, hop_length):
        zs = sample_single_window(zs, labels, sampling_kwargs, level, prior,
                                  window_start, hps)
    return zs
Exemplo n.º 3
0
def sample_single_window(zs, labels, sampling_kwargs, level, prior, start,
                         hps):
    """Sample one context window of new tokens at ``level`` and append them.

    Args:
        zs: per-level token tensors. ``zs[level]`` is sliced as
            ``[:, start:end]``, so it is presumably (batch, time) — confirm.
        labels: label conditioning passed to ``prior.get_y``.
        sampling_kwargs: extra keyword arguments for ``prior.sample``. Must
            contain ``'max_batch_size'``, which is consumed here rather than
            forwarded. An optional ``'sample_tokens'`` shortens the window.
            This dict is NOT modified.
        level: index into ``zs`` being sampled.
        prior: prior model providing ``n_ctx``, ``get_z_conds``, ``get_y`` and
            ``sample``.
        start: first token position of the window.
        hps: hyper-parameters; ``hps.n_samples`` is the total batch size.

    Returns:
        ``zs``, with ``zs[level]`` extended by the newly sampled tokens (the
        list is updated in place and also returned). When the window already
        holds ``sample_tokens`` tokens, ``zs`` is returned unchanged.

    Raises:
        KeyError: if ``sampling_kwargs`` lacks ``'max_batch_size'``.
    """
    n_samples = hps.n_samples
    n_ctx = prior.n_ctx
    end = start + n_ctx

    # get z already sampled at current level
    z = zs[level][:, start:end]

    # Support sampling a window shorter than n_ctx
    sample_tokens = sampling_kwargs.get('sample_tokens', end - start)
    conditioning_tokens = z.shape[1]
    new_tokens = sample_tokens - conditioning_tokens

    print_once(
        f"Sampling {sample_tokens} tokens for [{start},{start+sample_tokens}]. Conditioning on {conditioning_tokens} tokens"
    )

    if new_tokens <= 0:
        # Nothing new to sample
        return zs

    # get z_conds from level above
    z_conds = prior.get_z_conds(zs, start, end)

    # set y offset, sample_length and lyrics tokens
    y = prior.get_y(labels, start)

    empty_cache()

    # max_batch_size controls chunking here and must not reach prior.sample.
    # Forward a filtered copy instead of del-ing and later restoring the key,
    # so the caller's dict stays intact even if sampling raises.
    max_batch_size = sampling_kwargs['max_batch_size']
    sample_kwargs = {
        key: value
        for key, value in sampling_kwargs.items()
        if key != 'max_batch_size'
    }

    z_list = split_batch(z, n_samples, max_batch_size)
    z_conds_list = split_batch(z_conds, n_samples, max_batch_size)
    y_list = split_batch(y, n_samples, max_batch_size)
    z_samples = [
        prior.sample(n_samples=z_i.shape[0],
                     z=z_i,
                     z_conds=z_conds_i,
                     y=y_i,
                     **sample_kwargs)
        for z_i, z_conds_i, y_i in zip(z_list, z_conds_list, y_list)
    ]
    z = t.cat(z_samples, dim=0)

    # Update z with new sample: only the trailing new_tokens are new.
    z_new = z[:, -new_tokens:]
    zs[level] = t.cat([zs[level], z_new], dim=1)
    return zs
Exemplo n.º 4
0
def calculate_bandwidth(dataset, hps, duration=600):
    """Estimate amplitude and spectral statistics of ``dataset``.

    Each rank reads items starting at index ``dist.get_rank()`` and stepping
    by ``max(16, world_size)``. It stops once it has seen at least
    ``int(dataset.sr * duration)`` scalar sample values (every element of
    ``samples`` counts, including all channels). When ``torch.distributed``
    is available, the partial sums are all-reduced across ranks.

    Args:
        dataset: indexable dataset with an ``sr`` attribute. An item is either
            an array or a ``(x, y)`` pair whose first element is the audio.
            The audio is averaged over axis 1 before the STFT, so it is
            presumably (time, channels) — confirm.
        hps: hyper-parameters wrapped by ``DefaultSTFTValues`` to supply
            ``n_fft``, ``hop_length`` and ``window_size``.
        duration: scales the sample budget (``dataset.sr * duration``).

    Returns:
        dict with keys:
            ``l2``: population variance of the samples.
            ``l1``: mean absolute sample value.
            ``spec``: mean per-item Frobenius norm of the STFT magnitude.
    """
    hps = DefaultSTFTValues(hps)
    n_samples = int(dataset.sr * duration)
    l1, total, total_sq, n_seen, idx = 0.0, 0.0, 0.0, 0.0, dist.get_rank()
    spec_norm_total, spec_nelem = 0.0, 0.0
    while n_seen < n_samples:
        x = dataset[idx]
        if isinstance(x, (tuple, list)):
            x, _ = x
        samples = x.astype(np.float64)
        # Keyword arguments: librosa >= 0.10 makes n_fft keyword-only, so the
        # former positional call raises TypeError there.
        stft = librosa.stft(np.mean(samples, axis=1),
                            n_fft=hps.n_fft,
                            hop_length=hps.hop_length,
                            win_length=hps.window_size)
        spec = np.absolute(stft)
        spec_norm_total += np.linalg.norm(spec)
        spec_nelem += 1
        n_seen += int(np.prod(samples.shape))
        l1 += np.sum(np.abs(samples))
        total += np.sum(samples)
        total_sq += np.sum(samples**2)
        idx += max(16, dist.get_world_size())

    if dist.is_available():
        # Use the vendored package path, consistent with the other
        # app.jukebox imports in this module.
        from app.jukebox.utils.dist_utils import allreduce
        n_seen = allreduce(n_seen)
        total = allreduce(total)
        total_sq = allreduce(total_sq)
        l1 = allreduce(l1)
        spec_nelem = allreduce(spec_nelem)
        spec_norm_total = allreduce(spec_norm_total)

    mean = total / n_seen
    bandwidth = dict(l2=total_sq / n_seen - mean**2,
                     l1=l1 / n_seen,
                     spec=spec_norm_total / spec_nelem)
    print_once(bandwidth)
    return bandwidth
Exemplo n.º 5
0
    def __init__(self,
                 z_shapes,
                 l_bins,
                 encoder,
                 decoder,
                 level,
                 downs_t,
                 strides_t,
                 labels,
                 prior_kwargs,
                 x_cond_kwargs,
                 y_cond_kwargs,
                 prime_kwargs,
                 copy_input,
                 labels_v3=False,
                 merged_decoder=False,
                 single_enc_dec=False):
        """Build the autoregressive prior for one VQ-VAE level.

        Args:
            z_shapes: per-level token shapes. ``len(z_shapes)`` is the number
                of levels, and ``z_shapes[level]`` is this level's shape.
            l_bins: number of token bins. Passed to the level-above
                ``Conditioner`` and, if ``copy_input``, used as the prime bins.
            encoder: stored as a plain attribute (not as a registered module).
            decoder: stored as a plain attribute (not as a registered module).
            level: level index modelled by this prior; must be
                ``< len(z_shapes)``.
            downs_t: per-level downsampling values, used for the
                ``Conditioner`` and ``calculate_strides``.
            strides_t: per-level stride values, used for the ``Conditioner``
                and ``calculate_strides``.
            labels: truthy to enable label (Y) conditioning and a
                ``Labeller``; otherwise an ``EmptyLabeller`` is used.
            prior_kwargs: forwarded to the main ``ConditionalAutoregressive2D``.
                ``'width'`` and ``'init_scale'`` are read. In single_enc_dec
                mode, ``'input_shape'`` and ``'bins'`` are popped.
            x_cond_kwargs: forwarded to ``Conditioner``.
            y_cond_kwargs: forwarded to ``LabelConditioner``.
            prime_kwargs: ``'use_tokens'``, ``'n_tokens'`` and
                ``'prime_loss_fraction'`` are popped. The remaining keys feed
                the prime (token) encoder.
            copy_input: if truthy, overwrites ``prime_kwargs['bins']`` with
                ``l_bins``.
            labels_v3: forwarded to ``Labeller`` as ``v3``.
            merged_decoder: forwarded to the prior in separate
                encoder-decoder mode.
            single_enc_dec: True builds one transformer over concatenated
                prime + level tokens; False builds an optional separate prime
                encoder plus the level prior.

        NOTE(review): ``prime_kwargs`` and (in single_enc_dec mode)
        ``prior_kwargs`` are mutated via ``pop``. Reusing the same dicts for a
        second construction will raise KeyError.
        """
        super().__init__()

        # Prime (token) settings are removed from prime_kwargs so the rest can
        # be forwarded as model kwargs.
        self.use_tokens = prime_kwargs.pop('use_tokens')
        self.n_tokens = prime_kwargs.pop('n_tokens')
        self.prime_loss_fraction = prime_kwargs.pop('prime_loss_fraction')

        self.copy_input = copy_input
        if self.copy_input:
            prime_kwargs['bins'] = l_bins

        self.z_shapes = z_shapes
        self.levels = len(self.z_shapes)

        self.z_shape = self.z_shapes[level]

        self.level = level
        assert level < self.levels, f"Total levels {self.levels}, got level {level}"

        self.l_bins = l_bins

        # Passing functions instead of the vqvae module to avoid getting params
        self.encoder = encoder
        self.decoder = decoder

        # X conditioning: every level except the top (levels - 1) conditions
        # on the level directly above it.
        self.x_cond = (level != (self.levels - 1))
        self.cond_level = level + 1

        # Y conditioning
        self.y_cond = labels

        self.single_enc_dec = single_enc_dec
        # X conditioning
        if self.x_cond:
            self.conditioner_blocks = nn.ModuleList()
            # Factory building a Conditioner for the given level's shape/strides.
            conditioner_block = lambda _level: Conditioner(
                input_shape=z_shapes[_level],
                bins=l_bins,
                down_t=downs_t[_level],
                stride_t=strides_t[_level],
                **x_cond_kwargs)
            if dist.get_rank() == 0: print(f"Conditioning on 1 above level(s)")
            self.conditioner_blocks.append(conditioner_block(self.cond_level))

        # Y conditioning
        if self.y_cond:
            self.n_time = self.z_shape[
                0]  # Assuming STFT=TF order and raw=T1 order, so T is first dim
            # The time signal is included only when there is no level above
            # to condition on.
            self.y_emb = LabelConditioner(n_time=self.n_time,
                                          include_time_signal=not self.x_cond,
                                          **y_cond_kwargs)

        # Lyric conditioning
        if single_enc_dec:
            # Single encoder-decoder transformer
            # Two token streams: n_tokens prime tokens, then the level tokens.
            self.prior_shapes = [(self.n_tokens, ),
                                 prior_kwargs.pop('input_shape')]
            self.prior_bins = [prime_kwargs['bins'], prior_kwargs.pop('bins')]
            self.prior_dims = [np.prod(shape) for shape in self.prior_shapes]
            # Offset of each stream's bins inside the combined vocabulary.
            self.prior_bins_shift = np.cumsum([0, *self.prior_bins])[:-1]
            self.prior_width = prior_kwargs['width']
            print_once(
                f'Creating cond. autoregress with prior bins {self.prior_bins}, '
            )
            print_once(f'dims {self.prior_dims}, ')
            print_once(f'shift {self.prior_bins_shift}')
            print_once(f'input shape {sum(self.prior_dims)}')
            print_once(f'input bins {sum(self.prior_bins)}')
            print_once(f'Self copy is {self.copy_input}')

            # Prime positions come first, generated positions second.
            self.prime_loss_dims, self.gen_loss_dims = self.prior_dims[
                0], self.prior_dims[1]
            self.total_loss_dims = self.prime_loss_dims + self.gen_loss_dims
            self.prior = ConditionalAutoregressive2D(
                input_shape=(sum(self.prior_dims), ),
                bins=sum(self.prior_bins),
                x_cond=(self.x_cond or self.y_cond),
                y_cond=True,
                prime_len=self.prime_loss_dims,
                **prior_kwargs)

        else:
            # Separate encoder-decoder transformer
            # A separate prime encoder is built only when tokens are enabled
            # and n_tokens is non-zero.
            if self.n_tokens != 0 and self.use_tokens:
                from app.jukebox.transformer.ops import Conv1D
                prime_input_shape = (self.n_tokens, )
                self.prime_loss_dims = np.prod(prime_input_shape)
                self.prime_acts_width, self.prime_state_width = prime_kwargs[
                    'width'], prior_kwargs['width']
                self.prime_prior = ConditionalAutoregressive2D(
                    input_shape=prime_input_shape,
                    x_cond=False,
                    y_cond=False,
                    only_encode=True,
                    **prime_kwargs)
                # Project prime-encoder activations (prime width) to the
                # prior's width.
                self.prime_state_proj = Conv1D(
                    self.prime_acts_width,
                    self.prime_state_width,
                    init_scale=prime_kwargs['init_scale'])
                self.prime_state_ln = LayerNorm(self.prime_state_width)
                self.prime_bins = prime_kwargs['bins']
                # Linear head from prior-width states to prime_bins logits
                # (presumably for the prime/lyric loss — confirm in forward).
                self.prime_x_out = nn.Linear(self.prime_state_width,
                                             self.prime_bins,
                                             bias=False)
                nn.init.normal_(self.prime_x_out.weight,
                                std=0.02 * prior_kwargs['init_scale'])
            else:
                self.prime_loss_dims = 0
            self.gen_loss_dims = np.prod(self.z_shape)
            self.total_loss_dims = self.prime_loss_dims + self.gen_loss_dims
            self.prior = ConditionalAutoregressive2D(
                x_cond=(self.x_cond or self.y_cond),
                y_cond=self.y_cond,
                encoder_dims=self.prime_loss_dims,
                merged_decoder=merged_decoder,
                **prior_kwargs)

        # Context length = number of generated token positions at this level.
        self.n_ctx = self.gen_loss_dims
        self.downsamples = calculate_strides(strides_t, downs_t)
        # Downsample factor of the level above; None at the top level.
        self.cond_downsample = self.downsamples[
            level + 1] if level != self.levels - 1 else None
        # Cumulative downsampling from raw input to this level's tokens.
        self.raw_to_tokens = np.prod(self.downsamples[:level + 1])
        # Raw-input length covered by one context window.
        self.sample_length = self.n_ctx * self.raw_to_tokens
        if labels:
            self.labels_v3 = labels_v3
            self.labeller = Labeller(self.y_emb.max_bow_genre_size,
                                     self.n_tokens,
                                     self.sample_length,
                                     v3=self.labels_v3)
        else:
            self.labeller = EmptyLabeller()

        print(
            f"Level:{level}, Cond downsample:{self.cond_downsample}, Raw to tokens:{self.raw_to_tokens}, Sample length:{self.sample_length}"
        )