Ejemplo n.º 1
0
def run(hps="teeny", port=29500, **kwargs):
    """Train and/or evaluate a VQ-VAE (or a prior on top of it) across processes.

    Initializes the distributed setup through ``setup_dist_from_mpi``, builds
    hyperparameters from the ``hps`` preset plus ``kwargs`` overrides, creates
    the data processor, the model(s), the optimizer/EMA/DDP wrapper and the
    logger, then loops over epochs ``hps.curr_epoch .. hps.epochs - 1``,
    running ``train`` when ``hps.train`` is set and ``evaluate`` when
    ``hps.test`` is set.

    Args:
        hps: Preset name(s) handed to ``setup_hparams``.
        port: Port forwarded to ``setup_dist_from_mpi``.
        **kwargs: Hyperparameter overrides handed to ``setup_hparams``.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = setup_hparams(hps, kwargs)
    hps.ngpus = dist.get_world_size()
    hps.argv = " ".join(sys.argv)
    # Tie both bs_sample and nworkers to the configured batch size.
    batch = hps.bs
    hps.bs_sample = batch
    hps.nworkers = batch

    def _report(label, stats):
        # Only rank 0 prints, so each epoch summary appears once per job.
        if rank == 0:
            summary = ' '.join(
                f'{name}: {value:0.4f}' for name, value in stats.items())
            print(label, summary)

    processor = DataProcessor(hps)

    # The VQ-VAE is always built; a prior is built on top of it when requested,
    # and whichever is built last is the model that gets optimized.
    vqvae = make_vqvae(hps, device)
    print_once(f"Parameters VQVAE:{count_parameters(vqvae)}")
    model = vqvae
    if hps.prior:
        model = make_prior(hps, vqvae, device)
        print_once(f"Parameters Prior:{count_parameters(model)}")

    optimizer, shd, scalar = get_optimizer(model, hps)
    ema = get_ema(model, hps)
    ddp_model = get_ddp(model, hps)

    logger, metrics = init_logging(hps, local_rank, rank)
    logger.iters = model.step

    for epoch in range(hps.curr_epoch, hps.epochs):
        metrics.reset()
        processor.set_epoch(epoch)

        if hps.train:
            stats = train(ddp_model, model, optimizer, shd, scalar,
                          ema, logger, metrics, processor, hps)
            stats['epoch'] = epoch
            _report('Train', stats)
            dist.barrier()

        if hps.test:
            # ema.swap() is called before and after evaluation, presumably to
            # evaluate with EMA weights and then restore the live ones.
            # NOTE(review): the 'Ema' label is printed even when ema is falsy.
            if ema:
                ema.swap()
            stats = evaluate(ddp_model, model, logger, metrics,
                             processor, hps)
            stats['epoch'] = epoch
            _report('Ema', stats)
            dist.barrier()
            if ema:
                ema.swap()

        dist.barrier()
Ejemplo n.º 2
0
def run(hps="teeny", port=29500, **kwargs):
    """Check that every MIDI chunk served by the dataset has the expected shape.

    Builds hyperparameters and a ``DataProcessor`` backed by ``audio_database``,
    with batch size, sampling batch size and worker count all forced to 1.
    It then fetches MIDI chunks by index and stops at the first one whose
    shape differs from ``(95, 128)``.

    Args:
        hps: Preset name(s) handed to ``setup_hparams``.
        port: Port forwarded to ``setup_dist_from_mpi``.
        **kwargs: Must contain ``audio_database``. It is removed before the
            remaining entries are handed to ``setup_hparams`` as overrides.

    Raises:
        KeyError: If ``audio_database`` is missing from ``kwargs``.
        RuntimeError: If a chunk has an unexpected shape. The message names
            the offending index and shape.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    # NOTE(review): 8868 is presumably the number of chunks in the dataset;
    # confirm, or derive it from the dataset if it exposes a length.
    n_chunks = 8868
    expected_shape = (95, 128)

    # Take the database out so it is not treated as a hyperparameter override.
    audio_database = kwargs.pop('audio_database')

    _rank, _local_rank, device = setup_dist_from_mpi(port=port)
    print('device:', device)
    print("hps setup")
    hps = setup_hparams(hps, kwargs)
    hps.ngpus = 0
    hps.argv = " ".join(sys.argv)
    # Single example, single worker for this check.
    hps.bs = hps.bs_sample = hps.nworkers = 1

    print("setting up database")
    data_processor = DataProcessor(hps, audio_database)

    print("midi chunk call")
    for idx in range(n_chunks):
        chunk = data_processor.dataset.get_midi_chunk(idx)
        if chunk.shape != expected_shape:
            raise RuntimeError(
                f'MIDI chunk {idx} has shape {tuple(chunk.shape)}, '
                f'expected {expected_shape}')