Esempio n. 1
0
#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
"""
Tests that the NV-WaveNet class is producing audio
"""
import torch
from scipy.io.wavfile import write
import nv_wavenet
from wavenet import WaveNet
import utils
import json

if __name__ == '__main__':
    # Smoke test: build an untrained WaveNet from config.json, export its
    # weights into the NV-WaveNet inference wrapper, run inference on an
    # all-zero conditioning tensor and write the result to audio.wav.

    # Use a context manager so the config file handle is closed promptly
    # instead of being left to the garbage collector.
    with open('config.json') as config_file:
        config = json.load(config_file)
    wavenet_config = config["wavenet_config"]

    model = WaveNet(**wavenet_config).cuda()
    weights = model.export_weights()
    wavenet = nv_wavenet.NVWaveNet(**weights)

    num_samples = 10*1000
    batch_size = config['train_config']['batch_size']
    # Conditioning input is all zeros, laid out as
    # (2 * n_residual_channels, batch_size, n_layers, num_samples).
    cond_input = torch.zeros([2 * wavenet_config['n_residual_channels'], batch_size, wavenet_config['n_layers'], num_samples]).cuda()

    # Keep only the first entry of the inference result
    # (presumably the first item of the batch -- TODO confirm).
    samples = wavenet.infer(cond_input, nv_wavenet.Impl.PERSISTENT)[0]

    # Decode the generated samples (256 is presumably the number of mu-law
    # quantization levels -- TODO confirm), scale them by MAX_WAV_VALUE and
    # store them as 16-bit integers.
    audio = utils.mu_law_decode_numpy(samples.cpu().numpy(), 256)
    audio = utils.MAX_WAV_VALUE * audio
    wavdata = audio.astype('int16')
    # scipy's write() takes (filename, sample rate, data): 16 kHz output.
    write('audio.wav',16000, wavdata)
Esempio n. 2
0
def _loss_to_scalar(value):
    """Extract the single value held by a one-element loss tensor.

    ``Tensor.item()`` is used when the tensor provides it (PyTorch >= 0.4,
    where indexing a 0-dim loss with ``[0]`` warns or raises). Otherwise the
    tensor is indexed with ``[0]`` exactly as the original code did.
    """
    if hasattr(value, "item"):
        return value.item()
    return value[0]


def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          iters_per_checkpoint, iters_per_eval, batch_size, seed, checkpoint_path, log_dir, ema_decay=0.9999):
    """Train a WaveNet model, with periodic checkpoints and audio validation.

    Besides its arguments, this function reads several module-level names
    that are defined outside this block: ``dist_config``,
    ``train_data_config``, ``valid_data_config``, ``audio_config``,
    ``wavenet_config``, ``config`` and ``low_memory``.

    Args:
        num_gpus: Number of GPUs; values above 1 enable the distributed setup
            (process-group init, gradient all-reduce, distributed samplers).
        rank: Rank of this process; only rank 0 creates the output directory,
            logs scalars/audio and saves checkpoints.
        group_name: Passed through to ``init_distributed``.
        output_directory: Directory that receives ``wavenet_<iteration>``
            checkpoints.
        epochs: Upper bound (exclusive) of the epoch loop.
        learning_rate: Learning rate given to Adam and to ``save_checkpoint``.
        iters_per_checkpoint: A checkpoint is saved whenever the (non-zero)
            iteration counter is a multiple of this value.
        iters_per_eval: Validation runs whenever the (positive) iteration
            counter is a multiple of this value, unless
            ``config["no_validation"]`` is set.
        batch_size: Batch size of the training data loader.
        seed: Seed for the CPU and CUDA random number generators.
        checkpoint_path: Checkpoint to resume from; ``""`` starts from scratch.
        log_dir: Directory handed to ``SummaryWriter``.
        ema_decay: Decay passed to ``ExponentialMovingAverage``.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END:   ADDED FOR DISTRIBUTED======

    # Un-chunked data comes with per-sequence lengths, so it needs the
    # masked variant of the loss.
    if train_data_config["no_chunks"]:
        criterion = MaskedCrossEntropyLoss()
    else:
        criterion = CrossEntropyLoss()
    model = WaveNet(**wavenet_config).cuda()
    # Track an exponential moving average of every trainable parameter.
    ema = ExponentialMovingAverage(ema_decay)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END:   ADDED FOR DISTRIBUTED======

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = StepLR(optimizer, step_size=200000, gamma=0.5)

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, scheduler, iteration, ema = load_checkpoint(checkpoint_path, model,
                                                                      optimizer, scheduler, ema)
        iteration += 1  # next iteration is iteration + 1

    trainset = Mel2SampOnehot(audio_config=audio_config, verbose=True, **train_data_config)
    validset = Mel2SampOnehot(audio_config=audio_config, verbose=False, **valid_data_config)
    # =====START: ADDED FOR DISTRIBUTED======
    train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
    valid_sampler = DistributedSampler(validset) if num_gpus > 1 else None
    # =====END:   ADDED FOR DISTRIBUTED======
    print(train_data_config)
    if train_data_config["no_chunks"]:
        collate_fn = utils.collate_fn
    else:
        collate_fn = torch.utils.data.dataloader.default_collate
    train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
                              collate_fn=collate_fn,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=True,
                              drop_last=True)
    valid_loader = DataLoader(validset, num_workers=1, shuffle=False,
                              sampler=valid_sampler, batch_size=1, pin_memory=True)
    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    model.train()
    # When resuming, skip the epochs already covered by the restored
    # iteration counter.
    epoch_offset = max(0, int(iteration / len(train_loader)))
    writer = SummaryWriter(log_dir)
    # NOTE(review): this message prints log_dir, while checkpoints are
    # written to output_directory -- confirm the intended wording.
    print("Checkpoints writing to: {}".format(log_dir))
    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(epoch_offset, epochs):
        print("Epoch: {}".format(epoch))
        # DistributedSampler derives its shuffle order from the epoch number;
        # without set_epoch() every epoch replays the same ordering.
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)
        for i, batch in enumerate(train_loader):
            if low_memory:
                torch.cuda.empty_cache()
            # NOTE(review): scheduler.step() runs before optimizer.step();
            # PyTorch >= 1.1 expects the opposite order. Left unchanged so
            # the learning-rate schedule and existing checkpoints still match.
            scheduler.step()
            model.zero_grad()

            if train_data_config["no_chunks"]:
                x, y, seq_lens = batch
                seq_lens = to_gpu(seq_lens)
            else:
                x, y = batch
            x = to_gpu(x).float()
            y = to_gpu(y)
            x = (x, y)  # auto-regressive takes outputs as inputs
            y_pred = model(x)
            if train_data_config["no_chunks"]:
                loss = criterion(y_pred, y, seq_lens)
            else:
                loss = criterion(y_pred, y)
            # Convert the loss to a plain scalar for logging; plain ``[0]``
            # indexing fails on 0-dim loss tensors in current PyTorch.
            if num_gpus > 1:
                reduced_loss = _loss_to_scalar(reduce_tensor(loss.data, num_gpus))
            else:
                reduced_loss = _loss_to_scalar(loss.data)
            loss.backward()
            optimizer.step()

            # Refresh the moving averages of all registered parameters.
            for name, param in model.named_parameters():
                if name in ema.shadow:
                    ema.update(name, param.data)

            print("{}:\t{:.9f}".format(iteration, reduced_loss))
            if rank == 0:
                writer.add_scalar('loss', reduced_loss, iteration)
            if (iteration % iters_per_checkpoint == 0 and iteration):
                if rank == 0:
                    # Use a dedicated local name so the ``checkpoint_path``
                    # argument (the resume path) is not overwritten.
                    save_path = "{}/wavenet_{}".format(
                        output_directory, iteration)
                    save_checkpoint(model, optimizer, scheduler, learning_rate, iteration,
                                    save_path, ema, wavenet_config)
            if (iteration % iters_per_eval == 0 and iteration > 0 and not config["no_validation"]):
                if low_memory:
                    torch.cuda.empty_cache()
                if rank == 0:
                    # NOTE(review): validation runs with the model still in
                    # train mode and with autograd enabled -- confirm that
                    # this is intended.
                    model_eval = nv_wavenet.NVWaveNet(**(model.export_weights()))
                    for j, valid_batch in enumerate(valid_loader):
                        mel, audio = valid_batch
                        mel = to_gpu(mel).float()
                        cond_input = model.get_cond_input(mel)
                        predicted_audio = model_eval.infer(cond_input, nv_wavenet.Impl.AUTO)
                        predicted_audio = utils.mu_law_decode_numpy(predicted_audio[0, :].cpu().numpy(), 256)
                        # Log generated and reference audio side by side;
                        # 22050 is the sample rate passed to add_audio.
                        writer.add_audio("valid/predicted_audio_{}".format(j),
                                         predicted_audio,
                                         iteration,
                                         22050)
                        audio = utils.mu_law_decode_numpy(audio[0, :].cpu().numpy(), 256)
                        writer.add_audio("valid_true/audio_{}".format(j),
                                         audio,
                                         iteration,
                                         22050)
                        if low_memory:
                            torch.cuda.empty_cache()
            iteration += 1