Example 1
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def validate(model, criterion, valset, iteration, batch_size, n_gpus,
             collate_fn, logger, distributed_run, rank):
    """Handles all the validation scoring and printing"""
    model.eval()
    with torch.no_grad():
        val_sampler = DistributedSampler(valset) if distributed_run else None
        val_loader = DataLoader(valset,
                                sampler=val_sampler,
                                num_workers=1,
                                shuffle=False,
                                batch_size=batch_size,
                                pin_memory=False,
                                collate_fn=collate_fn)

        val_loss = 0.0
        for i, batch in enumerate(val_loader):
            x, y = parse_batch(batch)
            mel_outputs, mel_outputs_postnet, gate_outputs, alignments, length = model(
                x)
            y_pred = parse_output(
                [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
                length)
            loss = criterion(y_pred, y)
            if distributed_run:
                reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                reduced_val_loss = loss.item()
            val_loss += reduced_val_loss
        val_loss = val_loss / (i + 1)

    model.train()
    if rank == 0:
        print("Validation loss {}: {:9f}  ".format(iteration,
                                                   reduced_val_loss))
        logger.log_validation(val_loss, model, y, y_pred, iteration)
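
Example 1 relies on helper functions such as parse_batch, parse_output, and reduce_tensor that are defined elsewhere in the repository. Purely as an assumption about how the distributed piece works, a minimal sketch of reduce_tensor would average the loss across GPUs with an all-reduce:

import torch.distributed as dist


def reduce_tensor(tensor, n_gpus):
    # Hypothetical sketch of the helper assumed above: sum the loss tensor
    # across all ranks, then divide by the GPU count so every process
    # accumulates the same mean value into val_loss.
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n_gpus
    return rt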
Example 2
                        pin_memory=False,
                        drop_last=False,
                        collate_fn=datacollate)
speaker_ids = TextMelLoader(hparams.training_files, hparams).speaker_ids
speaker_id = torch.LongTensor([speaker_ids[speaker]]).cuda()

pytorch_total_params = sum(p.numel() for p in model.parameters())
print("total_num_params:  {}".format(pytorch_total_params))
waveglow_total_params = sum(p.numel() for p in waveglow.parameters())
print("waveglow_num_params:  {}".format(waveglow_total_params))
for i, batch in enumerate(dataloader):
    reference_speaker = test_set.audiopaths_and_text[i][2]
    # x: (text_padded, input_lengths, mel_padded, max_len,
    #                  output_lengths, speaker_ids, f0_padded, input_ids, attention_mask),
    # y: (mel_padded, gate_padded)
    x, y = parse_batch(batch)
    x = list(x)  # make the parsed inputs mutable so the speaker id can be overridden
    x[5] = speaker_id

    # inputs = text, style_input, speaker_ids, f0s
    with torch.no_grad():
        # mel_outputs, mel_outputs_postnet, gate_outputs, alignments = model.inference_override(x, torch.LongTensor([1]))
        mel_outputs, mel_outputs_postnet, gate_outputs, alignments = model.inference_override(
            x, x[9])

    with torch.no_grad():
        audio = denoiser(waveglow.infer(mel_outputs_postnet, sigma=0.8),
                         0.01)[:, 0]
        audio = audio.squeeze(1).cpu().numpy()
        top_db = 15
        for j in range(len(audio)):
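
The excerpt in Example 2 cuts off inside its final loop. Given the top_db = 15 setting just above, one plausible continuation, offered only as a hypothetical sketch, trims leading and trailing silence from each synthesized clip and writes it to disk (the 22050 Hz sampling rate is the usual Tacotron 2 / WaveGlow default and is an assumption here):

import librosa
import soundfile as sf

# Hypothetical continuation of the truncated loop above, not the original code.
for j in range(len(audio)):
    trimmed, _ = librosa.effects.trim(audio[j], top_db=top_db)  # strip silence
    sf.write("sample_{}_{}.wav".format(i, j), trimmed, 22050)   # assumed sample rate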
Example 3
import math
import os
import time

import torch


def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
          rank, group_name, hparams):
    """Training and validation logging results to tensorboard and stdout

    Params
    ------
    output_directory (string): directory to save checkpoints
    log_directory (string): directory to save tensorboard logs
    checkpoint_path (string): checkpoint path to resume from
    warm_start (bool): load model weights only, ignoring hparams.ignore_layers
    n_gpus (int): number of gpus
    rank (int): rank of current gpu
    group_name (string): distributed group name
    hparams (object): comma separated list of "name=value" pairs.
    """
    if hparams.distributed_run:
        init_distributed(hparams, n_gpus, rank, group_name)

    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)

    model = initiate_model(hparams)
    learning_rate = hparams.learning_rate
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=hparams.weight_decay)

    if hparams.fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

    if hparams.distributed_run:
        model = apply_gradient_allreduce(model)

    criterion = Tacotron2Loss()

    logger = prepare_directories_and_logger(output_directory, log_directory,
                                            rank)

    single_train_loader, single_valset, single_collate_fn, single_train_sampler = prepare_single_dataloaders(
        hparams, output_directory)
    train_loader, valset, collate_fn, train_sampler = prepare_dataloaders(
        hparams, output_directory)
    single_train_loader.dataset.speaker_ids = train_loader.dataset.speaker_ids
    single_valset.speaker_ids = train_loader.dataset.speaker_ids
    # Load checkpoint if one exists
    iteration = 0
    epoch_offset = 0
    if checkpoint_path is not None:
        if warm_start:
            model = warm_start_model(checkpoint_path, model,
                                     hparams.ignore_layers)
        else:
            # model = torch.nn.DataParallel(model)
            model, optimizer, _learning_rate, iteration = load_checkpoint(
                checkpoint_path, model, optimizer)
            if hparams.use_saved_learning_rate:
                learning_rate = _learning_rate
            iteration += 1  # next iteration is iteration + 1
            epoch_offset = max(0, int(iteration / len(single_train_loader)))

    model = torch.nn.DataParallel(model)
    model.train()
    is_overflow = False
    # init training loop with single speaker
    for epoch in range(epoch_offset, 30):
        print("Epoch: {}".format(epoch))
        if single_train_sampler is not None:
            single_train_sampler.set_epoch(epoch)
        for i, batch in enumerate(single_train_loader):
            start = time.perf_counter()
            if iteration > 0 and iteration % hparams.learning_rate_anneal == 0:
                learning_rate = max(hparams.learning_rate_min,
                                    learning_rate * 0.5)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = learning_rate

            model.zero_grad()
            x, y = parse_batch(batch)
            mel_outputs, mel_outputs_postnet, gate_outputs, alignments, length = model(
                x)
            y_pred = parse_output(
                [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
                length)

            loss = criterion(y_pred, y)
            if hparams.distributed_run:
                reduced_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                reduced_loss = loss.item()

            if hparams.fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if hparams.fp16_run:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), hparams.grad_clip_thresh)
                is_overflow = math.isnan(grad_norm)
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hparams.grad_clip_thresh)

            optimizer.step()

            if not is_overflow and rank == 0:
                duration = time.perf_counter() - start
                print(
                    "Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
                        iteration, reduced_loss, grad_norm, duration))
                logger.log_training(reduced_loss, grad_norm, learning_rate,
                                    duration, iteration)

            if not is_overflow and (iteration % hparams.iters_per_checkpoint
                                    == 0):
                validate(model, criterion, single_valset, iteration,
                         hparams.batch_size, n_gpus, single_collate_fn, logger,
                         hparams.distributed_run, rank)
                if rank == 0:
                    checkpoint_path = os.path.join(
                        output_directory, "checkpoint_{}".format(iteration))
                    save_checkpoint(model.module, optimizer, learning_rate,
                                    iteration, checkpoint_path)

            iteration += 1

    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(30, hparams.epochs):
        print("Epoch: {}".format(epoch))
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)
        for i, batch in enumerate(train_loader):
            start = time.perf_counter()
            if iteration > 0 and iteration % hparams.learning_rate_anneal == 0:
                learning_rate = max(hparams.learning_rate_min,
                                    learning_rate * 0.5)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = learning_rate

            model.zero_grad()
            x, y = parse_batch(batch)
            mel_outputs, mel_outputs_postnet, gate_outputs, alignments, length = model(
                x)
            y_pred = parse_output(
                [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
                length)
            loss = criterion(y_pred, y)
            if hparams.distributed_run:
                reduced_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                reduced_loss = loss.item()

            if hparams.fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if hparams.fp16_run:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), hparams.grad_clip_thresh)
                is_overflow = math.isnan(grad_norm)
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hparams.grad_clip_thresh)

            optimizer.step()

            if not is_overflow and rank == 0:
                duration = time.perf_counter() - start
                print(
                    "Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
                        iteration, reduced_loss, grad_norm, duration))
                logger.log_training(reduced_loss, grad_norm, learning_rate,
                                    duration, iteration)

            if not is_overflow and (iteration % hparams.iters_per_checkpoint
                                    == 0):
                validate(model, criterion, valset, iteration,
                         hparams.batch_size, n_gpus, collate_fn, logger,
                         hparams.distributed_run, rank)
                if rank == 0:
                    checkpoint_path = os.path.join(
                        output_directory, "checkpoint_{}".format(iteration))
                    save_checkpoint(model.module, optimizer, learning_rate,
                                    iteration, checkpoint_path)

            iteration += 1
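
Example 3 defines only the training function; the repository presumably wires it up behind a command-line entry point. A minimal sketch of such an entry point, assuming a create_hparams helper exists to parse the "name=value" overrides mentioned in the docstring, could look like this:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-o', '--output_directory', type=str, required=True,
                        help='directory to save checkpoints')
    parser.add_argument('-l', '--log_directory', type=str, required=True,
                        help='directory to save tensorboard logs')
    parser.add_argument('-c', '--checkpoint_path', type=str, default=None,
                        help='checkpoint to resume from')
    parser.add_argument('--warm_start', action='store_true',
                        help='load model weights only, ignore specified layers')
    parser.add_argument('--n_gpus', type=int, default=1)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--group_name', type=str, default='group_name')
    parser.add_argument('--hparams', type=str, default='',
                        help='comma separated name=value pairs')
    args = parser.parse_args()

    # create_hparams is assumed to be provided by the repository.
    hparams = create_hparams(args.hparams)

    train(args.output_directory, args.log_directory, args.checkpoint_path,
          args.warm_start, args.n_gpus, args.rank, args.group_name, hparams)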