예제 #1
0
    def init_dataset(self, hps):
        """Find the audio files under ``hps.audio_files_dir`` and index them.

        Looks up files with the listed audio extensions, computes
        ``get_duration_sec(file) * self.sr`` for each, passes both sequences to
        ``self.filter``, and creates ``self.labeller`` when ``self.labels`` is
        truthy.
        """
        audio_exts = ['mp3', 'opus', 'm4a', 'aac', 'wav']
        files = librosa.util.find_files(f'{hps.audio_files_dir}', audio_exts)
        print_all(f"Found {len(files)} files. Getting durations")

        # With torch.distributed available only ranks that are a multiple of 8
        # pass cache=True; without it the flag is always True.
        if dist.is_available():
            cache = dist.get_rank() % 8 == 0
        else:
            cache = True

        # Durations are scaled by self.sr; the values could be approximate.
        scaled = [get_duration_sec(path, cache=cache) * self.sr for path in files]
        durations = np.array(scaled)
        self.filter(files, durations)

        if not self.labels:
            return
        self.labeller = Labeller(hps.max_bow_genre_size, hps.n_tokens, self.sample_length, v3=hps.labels_v3)
예제 #2
0
 def filter(self, files, durations):
     """Keep only the files whose duration lies inside the configured bounds.

     An entry is dropped when ``durations[i] / self.sr`` is below
     ``self.min_duration`` or at/above ``self.max_duration``. The surviving
     files, their durations converted with ``int`` and the cumulative sum of
     those durations are stored on ``self.files``, ``self.durations`` and
     ``self.cumsum``.
     """
     def _out_of_bounds(idx):
         # True for entries that are too short or too long.
         length = durations[idx] / self.sr
         return length < self.min_duration or length >= self.max_duration

     keep = [idx for idx in range(len(files)) if not _out_of_bounds(idx)]
     print_all(f'self.sr={self.sr}, min: {self.min_duration}, max: {self.max_duration}')
     print_all(f"Keeping {len(keep)} of {len(files)} files")

     kept_files = []
     kept_durations = []
     for idx in keep:
         kept_files.append(files[idx])
         kept_durations.append(int(durations[idx]))
     self.files = kept_files
     self.durations = kept_durations
     self.cumsum = np.cumsum(self.durations)
예제 #3
0
def make_vqvae(hps, device='cuda'):
    """Build a VQVAE from ``hps``, move it to ``device`` and restore its weights.

    When ``hps.sample_length`` is falsy it is derived from
    ``hps.sample_length_in_seconds * hps.sr``, rounded down to a multiple of
    the product of the strides returned by ``calculate_strides``, and written
    back onto ``hps``.

    If ``hps.train`` is set and ``hps.prior`` is not, the model is returned
    without being frozen; when ``hps.restore_vqvae`` is non-empty each
    bottleneck level additionally gets ``restore_k`` called. In every other
    case the model is put in eval mode and frozen.

    Returns the constructed model.
    """
    from jukebox.vqvae.vqvae import VQVAE

    block_kwargs = {
        'width': hps.width,
        'depth': hps.depth,
        'm_conv': hps.m_conv,
        'dilation_growth_rate': hps.dilation_growth_rate,
        'dilation_cycle': hps.dilation_cycle,
        'reverse_decoder_dilation': hps.vqvae_reverse_decoder_dilation,
    }

    if not hps.sample_length:
        assert hps.sample_length_in_seconds != 0
        top_raw_to_tokens = np.prod(calculate_strides(hps.strides_t, hps.downs_t))
        # Floor to a whole number of top-level strides.
        n_chunks = hps.sample_length_in_seconds * hps.sr // top_raw_to_tokens
        hps.sample_length = n_chunks * top_raw_to_tokens
        print(f"Setting sample length to {hps.sample_length} (i.e. {hps.sample_length/hps.sr} seconds) to be multiple of {top_raw_to_tokens}")

    model = VQVAE(
        input_shape=(hps.sample_length, 1),
        levels=hps.levels,
        downs_t=hps.downs_t,
        strides_t=hps.strides_t,
        emb_width=hps.emb_width,
        l_bins=hps.l_bins,
        mu=hps.l_mu,
        commit=hps.commit,
        spectral=hps.spectral,
        multispectral=hps.multispectral,
        multipliers=hps.hvqvae_multipliers,
        use_bottleneck=hps.use_bottleneck,
        **block_kwargs
    )
    model = model.to(device)
    restore(hps, model, hps.restore_vqvae)

    training_vqvae = hps.train and not hps.prior
    if not training_vqvae:
        print_all("Loading vqvae in eval mode")
        model.eval()
        freeze_model(model)
        return model

    print_all("Loading vqvae in train mode")
    if hps.restore_vqvae != '':
        print_all("Reseting bottleneck emas")
        for level, bottleneck in enumerate(model.bottleneck.level_blocks):
            strides = calculate_strides(hps.strides_t, hps.downs_t)
            # Product of the strides up to and including this level.
            raw_to_tokens = np.prod(strides[:level + 1])
            tokens_per_process = hps.sample_length // raw_to_tokens
            bottleneck.restore_k(
                num_tokens=tokens_per_process * dist.get_world_size(),
                threshold=hps.revival_threshold,
            )
    return model
예제 #4
0
    def forward(self, x, update_k=True):
        """Quantise ``x`` through the bottleneck and return codes plus loss.

        Args:
            x: input whose ``shape`` unpacks into three values.
            update_k: when truthy, initialise the codebook on first use and
                call ``self.update_k`` after quantising.

        Returns:
            ``(x_l, x_d, commit_loss, metrics)`` where ``x_l`` / ``x_d`` are the
            outputs of ``self.postprocess``, ``commit_loss`` is the squared
            norm of ``x_d.detach() - x`` divided by the number of elements of
            ``x``, and ``metrics`` holds ``fit``, ``pn`` and whatever
            ``self.update_k`` returned.
        """
        # The unpack also requires x.shape to have exactly three entries.
        N, _, T = x.shape

        x, prenorm = self.preprocess(x)

        # Lazily initialise the codebook on the first updating call.
        if update_k and not self.init:
            self.init_k(x)

        # Quantise and dequantise through the bottleneck.
        x_l, fit = self.quantise(x)
        x_d = self.dequantise(x_l)

        update_metrics = self.update_k(x, x_l) if update_k else {}

        commit_loss = t.norm(x_d.detach() - x)**2 / np.prod(x.shape)

        # NOTE(review): these look like leftover debug output; they run on
        # every forward call. x_d and x are expected to share a shape.
        print_all("DOR PRINTS SHAPES")
        for tensor in (x_d, x_l, x):
            print_all(tensor.shape)

        # Passthrough: the difference to x is detached, so the result is built
        # from x plus a term that carries no gradient.
        x_d = x + (x_d - x).detach()

        x_l, x_d = self.postprocess(x_l, x_d, (N, T))
        metrics = dict(fit=fit, pn=prenorm, **update_metrics)
        return x_l, x_d, commit_loss, metrics
예제 #5
0
def train(model, orig_model, opt, shd, scalar, ema, logger, metrics, data_processor, hps):
    """Run one pass over ``data_processor.train_loader`` and update the model.

    For each batch: move it to CUDA, preprocess, run ``model`` forward, call
    ``backward``, and step the optimizer unless ``backward`` reported an
    overflow or the all-reduced gradient norm exceeds
    ``hps.ignore_grad_norm`` (only checked when that threshold is > 0).
    Metrics are averaged through ``metrics.update`` and logged every
    ``hps.log_steps`` iterations; checkpoints, prior samples and input/output
    logs are produced on their respective schedules.

    Args:
        model: module called for the forward pass.
        orig_model: module passed to zero_grad, checkpointing and logging.
        opt: optimizer whose ``step`` accepts a ``scale`` keyword.
        shd: learning-rate scheduler, or ``None`` to use ``hps.lr``.
        scalar: loss scaler forwarded to ``backward``.
        ema: EMA helper with ``step``/``swap``, or ``None``.
        logger: progress logger providing ``get_range``, ``step``, ``iters``,
            ``add_scalar``, ``set_postfix`` and ``close_range``.
        metrics: accumulator providing ``update`` and ``avg``.
        data_processor: object exposing ``train_loader``.
        hps: hyperparameters.

    Returns:
        Dict mapping each key of the last batch's metrics to
        ``metrics.avg(key)``; empty when the loader produced no batches.

    Raises:
        SystemExit: after a ``dist.barrier()`` once the learning rate read
            after stepping the scheduler equals 0.0.
    """
    import sys

    model.train()
    orig_model.train()
    if hps.prior:
        _print_keys = dict(l="loss", bpd="bpd", gn="gn", g_l="gen_loss", p_l="prime_loss")
    else:
        _print_keys = dict(l="loss", sl="spectral_loss", rl="recons_loss", e="entropy", u="usage", uc="used_curr", gn="gn", pn="pn", dk="dk")

    print_all(data_processor.train_loader)
    print_all(len(data_processor.train_loader))

    # Bound up front so the summary after the loop does not hit a NameError
    # when the loader yields no batches.
    _metrics = {}

    for i, x in logger.get_range(data_processor.train_loader):
        # Loaders may yield either x alone or an (x, y) pair.
        if isinstance(x, (tuple, list)):
            x, y = x
        else:
            y = None

        x = x.to('cuda', non_blocking=True)
        if y is not None:
            y = y.to('cuda', non_blocking=True)

        x_in = x = audio_preprocess(x, hps)
        log_input_output = (logger.iters % hps.save_iters == 0)

        if hps.prior:
            forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
        else:
            forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)

        # Forward
        x_out, loss, _metrics = model(x, **forw_kwargs)

        # Backward
        loss, scale, grad_norm, overflow_loss, overflow_grad = backward(loss=loss, params=list(model.parameters()),
                                                                         scalar=scalar, fp16=hps.fp16, logger=logger)
        # Skip step if overflow. The chained comparison only triggers when
        # ignore_grad_norm is positive AND the max grad norm exceeds it.
        grad_norm = allreduce(grad_norm, op=dist.ReduceOp.MAX)
        if overflow_loss or overflow_grad or grad_norm > hps.ignore_grad_norm > 0:
            zero_grad(orig_model)
            continue

        # Step opt. Divide by scale to include clipping and fp16 scaling
        logger.step()
        opt.step(scale=clipped_grad_scale(grad_norm, hps.clip, scale))
        zero_grad(orig_model)
        lr = hps.lr if shd is None else shd.get_lr()[0]
        if shd is not None:
            shd.step()
        if ema is not None:
            ema.step()
        next_lr = hps.lr if shd is None else shd.get_lr()[0]
        # A learning rate of exactly zero marks the end of training.
        finished_training = (next_lr == 0.0)

        # Logging
        for key, val in _metrics.items():
            _metrics[key] = val.item()
        _metrics["loss"] = loss = loss.item() * hps.iters_before_update  # Make sure to call to free graph
        _metrics["gn"] = grad_norm
        _metrics["lr"] = lr
        _metrics["lg_loss_scale"] = np.log2(scale)

        # Average and log
        for key, val in _metrics.items():
            _metrics[key] = metrics.update(key, val, x.shape[0])
            if logger.iters % hps.log_steps == 0:
                logger.add_scalar(key, _metrics[key])

        # Save checkpoint
        with t.no_grad():
            if hps.save and (logger.iters % hps.save_iters == 1 or finished_training):
                # Swap to the EMA weights for saving, then swap back.
                if ema is not None:
                    ema.swap()
                orig_model.eval()
                name = 'latest' if hps.prior else f'step_{logger.iters}'
                if dist.get_rank() % 8 == 0:
                    save_checkpoint(logger, name, orig_model, opt, dict(step=logger.iters), hps)
                orig_model.train()
                if ema is not None:
                    ema.swap()

        # Sample
        with t.no_grad():
            # range membership avoids building a list on every step.
            if (logger.iters % 12000) in range(1, 1 + hps.iters_before_update) or finished_training:
                if hps.prior:
                    sample_prior(orig_model, ema, logger, x_in, y, hps)

        # Input/Output
        with t.no_grad():
            if log_input_output:
                log_inputs(orig_model, logger, x_in, y, x_out, hps)

        logger.set_postfix(**{print_key: _metrics[key] for print_key, key in _print_keys.items()})
        if finished_training:
            # Wait for every process before terminating.
            dist.barrier()
            sys.exit()
    logger.close_range()
    return {key: metrics.avg(key) for key in _metrics.keys()}
예제 #6
0
def make_prior(hps, vqvae, device='cuda'):
    """Build a ``SimplePrior`` on top of ``vqvae`` from ``hps`` and restore it.

    Assembles the keyword groups for the prior, the x/y conditioners and the
    prime (token) model, rescales ``vqvae.z_shapes`` so that level
    ``hps.level`` gets ``hps.n_ctx`` codes, constructs the prior, optionally
    applies ``_convert_conv_weights_to_fp16``, moves it to ``device`` and
    calls ``restore``. Outside training the prior is put in eval mode and
    frozen.

    Returns the constructed prior.
    """
    from jukebox.prior.prior import SimplePrior

    # Checkpointing settings are only taken from hps when training.
    prior_kwargs = {
        'input_shape': (hps.n_ctx, ),
        'bins': hps.l_bins,
        'width': hps.prior_width,
        'depth': hps.prior_depth,
        'heads': hps.heads,
        'attn_order': hps.attn_order,
        'blocks': hps.blocks,
        'spread': hps.spread,
        'attn_dropout': hps.attn_dropout,
        'resid_dropout': hps.resid_dropout,
        'emb_dropout': hps.emb_dropout,
        'zero_out': hps.zero_out,
        'res_scale': hps.res_scale,
        'pos_init': hps.pos_init,
        'init_scale': hps.init_scale,
        'm_attn': hps.m_attn,
        'm_mlp': hps.m_mlp,
        'checkpoint_res': hps.c_res if hps.train else 0,
        'checkpoint_attn': hps.c_attn if hps.train else 0,
        'checkpoint_mlp': hps.c_mlp if hps.train else 0,
    }

    x_cond_kwargs = {
        'out_width': hps.prior_width,
        'init_scale': hps.init_scale,
        'width': hps.cond_width,
        'depth': hps.cond_depth,
        'm_conv': hps.cond_m_conv,
        'dilation_growth_rate': hps.cond_dilation_growth_rate,
        'dilation_cycle': hps.cond_dilation_cycle,
        'zero_out': hps.cond_zero_out,
        'res_scale': hps.cond_res_scale,
        # have to keep this else names wrong
        'checkpoint_res': hps.cond_c_res,
    }

    y_cond_kwargs = {
        'out_width': hps.prior_width,
        'init_scale': hps.init_scale,
        'y_bins': hps.y_bins,
        't_bins': hps.t_bins,
        't_ranges': hps.t_ranges,
        'max_bow_genre_size': hps.max_bow_genre_size,
    }

    # These four entries are always present; the full set of prime settings
    # is added only when tokens are used with a separate encoder/decoder.
    prime_kwargs = {
        'use_tokens': hps.use_tokens,
        'prime_loss_fraction': hps.prime_loss_fraction,
        'n_tokens': hps.n_tokens,
        'bins': hps.n_vocab,
    }
    if hps.use_tokens and not hps.single_enc_dec:
        prime_kwargs.update(
            width=hps.prime_width,
            depth=hps.prime_depth,
            heads=hps.prime_heads,
            attn_order=hps.prime_attn_order,
            blocks=hps.prime_blocks,
            spread=hps.prime_spread,
            attn_dropout=hps.prime_attn_dropout,
            resid_dropout=hps.prime_resid_dropout,
            emb_dropout=hps.prime_emb_dropout,
            zero_out=hps.prime_zero_out,
            res_scale=hps.prime_res_scale,
            pos_init=hps.prime_pos_init,
            init_scale=hps.prime_init_scale,
            m_attn=hps.prime_m_attn,
            m_mlp=hps.prime_m_mlp,
            checkpoint_res=hps.prime_c_res if hps.train else 0,
            checkpoint_attn=hps.prime_c_attn if hps.train else 0,
            checkpoint_mlp=hps.prime_c_mlp if hps.train else 0,
        )

    # z_shapes for other levels given this level gets n_ctx codes
    z_shapes = [
        (z_shape[0] * hps.n_ctx // vqvae.z_shapes[hps.level][0], )
        for z_shape in vqvae.z_shapes
    ]

    prior = SimplePrior(
        z_shapes=z_shapes,
        l_bins=hps.l_bins,
        encoder=vqvae.encode,
        decoder=vqvae.decode,
        level=hps.level,
        downs_t=hps.downs_t,
        strides_t=hps.strides_t,
        labels=hps.labels,
        prior_kwargs=prior_kwargs,
        x_cond_kwargs=x_cond_kwargs,
        y_cond_kwargs=y_cond_kwargs,
        prime_kwargs=prime_kwargs,
        copy_input=hps.copy_input,
        labels_v3=hps.labels_v3,
        merged_decoder=hps.merged_decoder,
        single_enc_dec=hps.single_enc_dec,
    )

    # Optional settings: fall back to None when hps does not define them.
    prior.alignment_head = hps.get('alignment_head', None)
    prior.alignment_layer = hps.get('alignment_layer', None)

    if hps.fp16_params:
        print_all("Converting to fp16 params")
        from jukebox.transformer.ops import _convert_conv_weights_to_fp16
        prior.apply(_convert_conv_weights_to_fp16)
    prior = prior.to(device)
    restore(hps, prior, hps.restore_prior)

    if hps.train:
        print_all("Loading prior in train mode")
        return prior

    print_all("Loading prior in eval mode")
    prior.eval()
    freeze_model(prior)
    return prior
예제 #7
0
 def print_stats(self, hps):
     """Report dataset sizes, the train sampler and the train loader length.

     ``hps`` is accepted but not read here.
     """
     n_train = len(self.train_dataset)
     n_test = len(self.test_dataset)
     print_all(f"Train {n_train} samples. Test {n_test} samples")
     print_all(f'Train sampler: {self.train_sampler}')
     # Length is taken only after the first two lines have been printed.
     n_batches = len(self.train_loader)
     print_all(f'Train loader: {n_batches}')