예제 #1
0
def load_checkpoint(path):
    """Load a checkpoint onto the CPU, fetching it first when `path` is a ``gs://`` URL.

    For ``gs://`` paths, ranks with ``dist.get_rank() % 8 == 0`` (presumably
    one per 8-GPU node -- confirm against the launch setup) make the file
    available: they reuse a copy under the Google Drive samples folder if one
    exists, download to Google Drive if that folder's parent directory exists,
    and otherwise download into ``~/.cache``. After the barrier every rank
    resolves which of the two locations holds the file and loads from it.

    Args:
        path: Local filesystem path or ``gs://`` URL of the checkpoint.

    Returns:
        The object returned by ``t.load`` with storages mapped to the CPU.
    """
    print("Loading checkpoint...")
    restore = path
    is_remote = restore[:5] == 'gs://'
    if is_remote:
        gs_path = restore
        local_path = os.path.join(os.path.expanduser("~/.cache"), gs_path[5:])
        gdrive_path = os.path.join("/content/gdrive/My Drive/samples/", gs_path[5:])
        print(f'local path: {local_path}')
        print(f'gdrive path: {gdrive_path}')
        if dist.get_rank() % 8 == 0:
            if os.path.exists(gdrive_path):
                print("Using priors on Google Drive")
            elif os.path.exists(os.path.dirname(gdrive_path)):
                print("Downloading priors to Google Drive")
                download(gs_path, gdrive_path)
            else:
                print("Downloading from gce")
                os.makedirs(os.path.dirname(local_path), exist_ok=True)
                if not os.path.exists(local_path):
                    download(gs_path, local_path)
    # Wait until the downloading ranks are done before anyone touches the file.
    dist.barrier()
    if is_remote:
        # Resolve the final location on EVERY rank. Previously only the
        # downloading ranks updated `restore`, so all other ranks tried to
        # `t.load` the raw gs:// URL.
        restore = gdrive_path if os.path.exists(gdrive_path) else local_path
    checkpoint = t.load(restore, map_location=t.device('cpu'))
    print("RS // Restored from {}".format(restore))
    return checkpoint
예제 #2
0
def save_outputs(model, device, hps):
    """Run fixed-seed inputs through the priors and dump the results to disk.

    Builds a seeded random input and random lyric tokens, encodes/decodes the
    input with the last prior returned by ``make_model``, runs each prior that
    is not skipped, saves everything to ``data.pth.tar`` and exits the process.

    Args:
        model: Forwarded unchanged to ``make_model``.
        device: Forwarded unchanged to ``make_model``.
        hps: Hyperparameters; reads ``labels_v3``, ``levels`` and ``fp16``.
    """
    # Check logits
    if hps.labels_v3:
        n_ctx, n_tokens, prime_bins = 6144, 384, 79
    else:
        n_ctx, n_tokens, prime_bins = 8192, 512, 80

    # A single seeded generator feeds both draws, so their order must not change.
    generator = t.random.manual_seed(0)
    noise = t.rand((1, n_ctx * 8 * 4 * 4, 1), generator=generator,
                   dtype=t.float).cuda()
    x = 2 * noise - 1.0  # -1 to 1
    token_draw = t.randint(0, prime_bins, (1, n_tokens),
                           generator=generator, dtype=t.long)
    lyric_tokens = token_draw.view(-1).numpy()
    artist_id = 10
    genre_ids = [1]
    offset = 2646000
    total_length = 2 * offset

    _, priors = make_model(model, device, hps)

    # encode
    top_prior = priors[-1]
    zs = top_prior.encode(x, start_level=0)
    x_ds = []
    for start in range(0, len(zs)):
        x_ds.append(top_prior.decode(zs[start:], start_level=start))

    # priors
    data = {'zs': zs, 'x_ds': x_ds}
    for level in range(len(priors)):
        print(f"Doing level {level}")
        if hps.labels_v3 and level != hps.levels - 1:
            print(f"Skipping level {level}")
            continue
        prior = priors[level]
        prior.cuda()
        x_in = x[:, :n_ctx * 8 * (4**level)]
        labels = prior.labeller.get_y_from_ids(artist_id, genre_ids,
                                               lyric_tokens, total_length,
                                               offset)
        y_in = t.from_numpy(labels).view(1, -1).cuda().long()
        x_out, _, metrics = prior(x_in, y_in, fp16=hps.fp16,
                                  get_preds=True, decode=True)
        data[level] = {
            'x': x_in,
            'y': y_in,
            'x_out': x_out,
            'preds': metrics['preds'],
        }
        # Move the prior back to the CPU before handling the next level.
        prior.cpu()

    t.save(data, 'data.pth.tar')
    dist.barrier()
    print("Saved data")
    exit()
예제 #3
0
def test_dataset_loader():
    """Smoke-test ``FilesAudioDataset`` by logging one batch to TensorBoard.

    Loads the "teeny" hyperparameters with a few overrides, builds a
    distributed ``DataLoader`` over the dataset, and writes the first batch as
    audio both before (``in_*``) and after (``out_*``) a round trip through
    ``audio_preprocess`` / ``audio_postprocess``. The summary writer is always
    closed on the way out so the logged events are flushed to disk.
    """
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler
    from tensorboardX import SummaryWriter
    from jukebox.utils.audio_utils import audio_preprocess, audio_postprocess
    from jukebox.hparams import setup_hparams
    from jukebox.data.files_dataset import FilesAudioDataset

    hps = setup_hparams("teeny", {})
    hps.sr = 22050  # 44100
    hps.hop_length = 512
    hps.labels = False
    hps.channels = 2
    hps.aug_shift = False
    hps.bs = 2
    hps.nworkers = 2  # Getting 20 it/s with 2 workers, 10 it/s with 1 worker
    print(hps)

    # Only these sample rates have a log directory name; others raise KeyError.
    sr = {22050: '22k', 44100: '44k', 48000: '48k'}[hps.sr]
    writer = SummaryWriter(f'{hps.root}/{hps.dataset}/logs/{sr}/logs')
    try:
        dataset = FilesAudioDataset(hps)
        print("Length of dataset", len(dataset))

        # Torch Loader
        def collate_fn(batch):
            # Stack the per-item numpy arrays into one tensor along dim 0.
            return t.stack([t.from_numpy(b) for b in batch], 0)

        sampler = DistributedSampler(dataset)
        train_loader = DataLoader(dataset,
                                  batch_size=hps.bs,
                                  num_workers=hps.nworkers,
                                  pin_memory=False,
                                  sampler=sampler,
                                  drop_last=True,
                                  collate_fn=collate_fn)

        dist.barrier()
        sampler.set_epoch(0)
        for i, x in enumerate(tqdm(train_loader)):
            x = x.to('cuda', non_blocking=True)
            for j, aud in enumerate(x):
                writer.add_audio('in_' + str(i * hps.bs + j), aud, 1, hps.sr)
            print("Wrote in")
            x = audio_preprocess(x, hps)
            x = audio_postprocess(x, hps)
            for j, aud in enumerate(x):
                writer.add_audio('out_' + str(i * hps.bs + j), aud, 1, hps.sr)
            print("Wrote out")
            dist.barrier()
            # Only the first batch is inspected.
            break
    finally:
        # Flush and release the event file even if loading or logging fails.
        writer.close()
예제 #4
0
def load_checkpoint(path):
    """Load a checkpoint onto the CPU, caching ``gs://`` sources locally first.

    A ``gs://`` path is mapped to a file under ``~/.cache``; ranks with
    ``dist.get_rank() % 8 == 0`` download it there if it is missing. All ranks
    synchronise on a barrier before loading.

    Args:
        path: Local filesystem path or ``gs://`` URL of the checkpoint.

    Returns:
        The object returned by ``t.load`` with storages mapped to the CPU.
    """
    restore = path
    if path[:5] == 'gs://':
        cache_root = os.path.expanduser("~/.cache")
        cached_file = os.path.join(cache_root, path[5:])
        if dist.get_rank() % 8 == 0:
            print("Downloading from gce")
            cached_dir = os.path.dirname(cached_file)
            if not os.path.exists(cached_dir):
                os.makedirs(cached_dir)
            if not os.path.exists(cached_file):
                download(path, cached_file)
        restore = cached_file
    # Nobody reads the file until the downloading ranks have finished.
    dist.barrier()
    state = t.load(restore, map_location=t.device('cpu'))
    print("Restored from {}".format(restore))
    return state
예제 #5
0
def load_checkpoint(path):
    """Load a checkpoint onto the CPU, caching remote sources locally first.

    A path starting with ``REMOTE_PREFIX`` is mapped to a file under
    ``~/.cache``; ranks with ``dist.get_rank() % 8 == 0`` download it there if
    it is missing. All ranks synchronise on a barrier before loading with
    ``custom_load``.

    Args:
        path: Local filesystem path or ``REMOTE_PREFIX``-prefixed remote path.

    Returns:
        The object returned by ``custom_load`` with ``map_location`` set to CPU.
    """
    restore = path
    if path.startswith(REMOTE_PREFIX):
        cache_root = os.path.expanduser("~/.cache")
        relative_part = path[len(REMOTE_PREFIX):]
        cached_file = os.path.join(cache_root, relative_part)
        if dist.get_rank() % 8 == 0:
            print("Downloading from azure")
            cached_dir = os.path.dirname(cached_file)
            if not os.path.exists(cached_dir):
                os.makedirs(cached_dir)
            if not os.path.exists(cached_file):
                download(path, cached_file)
        restore = cached_file
    # Nobody reads the file until the downloading ranks have finished.
    dist.barrier()
    state = custom_load(restore, map_location=t.device('cpu'))
    print("Restored from {}".format(restore))
    return state
예제 #6
0
def run(hps="teeny", port=29500, **kwargs):
    """Set up distributed state, build the model, and run the epoch loop.

    Args:
        hps: Hyperparameter set name passed to ``setup_hparams``.
        port: Port forwarded to ``setup_dist_from_mpi``.
        **kwargs: Hyperparameter overrides passed to ``setup_hparams``.
    """
    from app.jukebox.utils.dist_utils import setup_dist_from_mpi

    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = setup_hparams(hps, kwargs)
    hps.ngpus = dist.get_world_size()
    hps.argv = " ".join(sys.argv)
    # Sample batch size and worker count both follow the training batch size.
    batch_size = hps.bs
    hps.bs_sample = batch_size
    hps.nworkers = batch_size

    # Setup dataset
    data_processor = DataProcessor(hps)

    # Setup models: the prior, when requested, replaces the VQ-VAE as the
    # model being trained.
    vqvae = make_vqvae(hps, device)
    print_once(f"Parameters VQVAE:{count_parameters(vqvae)}")
    model = vqvae
    if hps.prior:
        model = make_prior(hps, vqvae, device)
        print_once(f"Parameters Prior:{count_parameters(model)}")

    # Setup opt, ema and distributed_model.
    opt, shd, scalar = get_optimizer(model, hps)
    ema = get_ema(model, hps)
    distributed_model = get_ddp(model, hps)

    logger, metrics = init_logging(hps, local_rank, rank)
    logger.iters = model.step

    # Run training, eval, sample
    for epoch in range(hps.curr_epoch, hps.epochs):
        metrics.reset()
        data_processor.set_epoch(epoch)

        if hps.train:
            train_metrics = train(distributed_model, model, opt, shd, scalar,
                                  ema, logger, metrics, data_processor, hps)
            train_metrics['epoch'] = epoch
            if rank == 0:
                train_summary = ' '.join(
                    f'{key}: {val:0.4f}'
                    for key, val in train_metrics.items())
                print('Train', train_summary)
            dist.barrier()

        if hps.test:
            # Evaluate with the EMA weights swapped in, then swap them back.
            if ema:
                ema.swap()
            test_metrics = evaluate(distributed_model, model, logger, metrics,
                                    data_processor, hps)
            test_metrics['epoch'] = epoch
            if rank == 0:
                test_summary = ' '.join(
                    f'{key}: {val:0.4f}'
                    for key, val in test_metrics.items())
                print('Ema', test_summary)
            dist.barrier()
            if ema:
                ema.swap()

        dist.barrier()
예제 #7
0
def train(model, orig_model, opt, shd, scalar, ema, logger, metrics, data_processor, hps):
    """Run one pass over ``data_processor.train_loader``, stepping the optimizer.

    Each iteration does a forward/backward pass, skips the step on overflow or
    an excessive gradient norm, and otherwise steps the optimizer, scheduler
    and EMA, logs metrics, and periodically saves checkpoints, samples from the
    prior and logs inputs/outputs. When the scheduler's next learning rate is
    exactly 0.0 the process exits after a final barrier.

    Args:
        model: Module used for the forward pass (the caller in this file
            passes the distributed wrapper).
        orig_model: Module whose gradients are zeroed and which is
            checkpointed and sampled from.
        opt: Optimizer; ``step`` is called with a ``scale`` keyword.
        shd: Learning-rate scheduler, or None to use ``hps.lr`` directly.
        scalar: Loss-scaling object forwarded to ``backward``.
        ema: EMA helper, or None.
        logger: Progress/metrics logger providing ``iters``, ``step``,
            ``get_range``, ``add_scalar``, ``set_postfix`` and ``close_range``.
        metrics: Running-average accumulator providing ``update`` and ``avg``.
        data_processor: Object exposing ``train_loader``.
        hps: Hyperparameters.

    Returns:
        Dict mapping each key of the most recent iteration's metrics to
        ``metrics.avg(key)``; empty if the loader yielded no batches.
    """
    model.train()
    orig_model.train()
    if hps.prior:
        _print_keys = dict(l="loss", bpd="bpd", gn="gn", g_l="gen_loss", p_l="prime_loss")
    else:
        _print_keys = dict(l="loss", sl="spectral_loss", rl="recons_loss", e="entropy", u="usage", uc="used_curr", gn="gn", pn="pn", dk="dk")

    print_all(data_processor.train_loader)
    print_all(len(data_processor.train_loader))

    # Pre-bind so the final return is well-defined when the loader is empty
    # (previously this raised UnboundLocalError).
    _metrics = {}

    for i, x in logger.get_range(data_processor.train_loader):
        # Batches are either `x` alone or an `(x, y)` pair.
        if isinstance(x, (tuple, list)):
            x, y = x
        else:
            y = None

        x = x.to('cuda', non_blocking=True)
        if y is not None:
            y = y.to('cuda', non_blocking=True)

        x_in = x = audio_preprocess(x, hps)
        # Evaluated before logger.step(), i.e. against the pre-step counter.
        log_input_output = (logger.iters % hps.save_iters == 0)

        if hps.prior:
            forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
        else:
            forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)

        # Forward
        x_out, loss, _metrics = model(x, **forw_kwargs)

        # Backward
        loss, scale, grad_norm, overflow_loss, overflow_grad = backward(loss=loss, params=list(model.parameters()),
                                                                         scalar=scalar, fp16=hps.fp16, logger=logger)
        # Skip step if overflow
        grad_norm = allreduce(grad_norm, op=dist.ReduceOp.MAX)
        # Chained comparison: the norm check only applies when
        # hps.ignore_grad_norm is positive.
        if overflow_loss or overflow_grad or grad_norm > hps.ignore_grad_norm > 0:
            zero_grad(orig_model)
            continue

        # Step opt. Divide by scale to include clipping and fp16 scaling
        logger.step()
        opt.step(scale=clipped_grad_scale(grad_norm, hps.clip, scale))
        zero_grad(orig_model)
        lr = hps.lr if shd is None else shd.get_lr()[0]
        if shd is not None: shd.step()
        if ema is not None: ema.step()
        next_lr = hps.lr if shd is None else shd.get_lr()[0]
        finished_training = (next_lr == 0.0)

        # Logging
        for key, val in _metrics.items():
            _metrics[key] = val.item()
        _metrics["loss"] = loss = loss.item() * hps.iters_before_update # Make sure to call to free graph
        _metrics["gn"] = grad_norm
        _metrics["lr"] = lr
        _metrics["lg_loss_scale"] = np.log2(scale)

        # Average and log
        for key, val in _metrics.items():
            _metrics[key] = metrics.update(key, val, x.shape[0])
            if logger.iters % hps.log_steps == 0:
                logger.add_scalar(key, _metrics[key])

        # Save checkpoint
        with t.no_grad():
            if hps.save and (logger.iters % hps.save_iters == 1 or finished_training):
                # Checkpoint with the EMA weights swapped in, then restore.
                if ema is not None: ema.swap()
                orig_model.eval()
                name = 'latest' if hps.prior else f'step_{logger.iters}'
                if dist.get_rank() % 8 == 0:
                    save_checkpoint(logger, name, orig_model, opt, dict(step=logger.iters), hps)
                orig_model.train()
                if ema is not None: ema.swap()

        # Sample
        with t.no_grad():
            # Membership in a range is checked directly; no list is needed.
            if (logger.iters % 12000) in range(1, 1 + hps.iters_before_update) or finished_training:
                if hps.prior:
                    sample_prior(orig_model, ema, logger, x_in, y, hps)

        # Input/Output
        with t.no_grad():
            if log_input_output:
                log_inputs(orig_model, logger, x_in, y, x_out, hps)

        logger.set_postfix(**{print_key:_metrics[key] for print_key, key in _print_keys.items()})
        if finished_training:
            dist.barrier()
            # The `exit()` builtin is installed by `site` for interactive use;
            # raising SystemExit terminates the same way without relying on it.
            raise SystemExit
    logger.close_range()
    return {key: metrics.avg(key) for key in _metrics.keys()}