Example #1
def train_distributed(replica_id, replica_count, port, args, params):
  # All replicas rendezvous over localhost on the same port.
  os.environ['MASTER_ADDR'] = 'localhost'
  os.environ['MASTER_PORT'] = str(port)
  torch.distributed.init_process_group('nccl', rank=replica_id, world_size=replica_count)

  # Pin this replica to its own GPU before building the model on it.
  device = torch.device('cuda', replica_id)
  torch.cuda.set_device(device)
  model = WaveGrad(params).to(device)
  model = DistributedDataParallel(model, device_ids=[replica_id])
  _train_impl(replica_id, model, dataset_from_path(args.data_dirs, params, is_distributed=True), args, params)
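This function is written to run once per GPU, with `replica_id` supplied by the launcher. A minimal launcher sketch, assuming `torch.multiprocessing.spawn` (which passes the process index as the first argument) and a hypothetical free port; `launch_distributed` is a name introduced here, not taken from these snippets:

import torch
from torch.multiprocessing import spawn

def launch_distributed(args, params):
    # One training process per visible GPU.
    replica_count = torch.cuda.device_count()
    port = 12355  # hypothetical free port for the rendezvous
    spawn(train_distributed,
          args=(replica_count, port, args, params),
          nprocs=replica_count,
          join=True)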
Example #2
def train(dataset, args, params):
    model = WaveGrad(params).cuda()
    opt = torch.optim.Adam(model.parameters(), lr=params.learning_rate)

    learner = WaveGradLearner(args.model_dir,
                              model,
                              dataset,
                              opt,
                              params,
                              fp16=args.fp16)
    learner.restore_from_checkpoint()
    learner.train(max_steps=args.max_steps)
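A hedged invocation sketch; the `args` attributes mirror the ones these examples read (`model_dir`, `data_dirs`, `max_steps`, `fp16`), while the paths and the `params` object are placeholders:

from argparse import Namespace

args = Namespace(model_dir='/tmp/wavegrad',   # placeholder checkpoint dir
                 data_dirs=['/data/speech'],  # placeholder dataset roots
                 max_steps=1000000,
                 fp16=False)
train(dataset_from_path(args.data_dirs, params), args, params)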
Example #3
    def __init__(self, model_dir, params, device=torch.device('cuda')):
        # Lazy load model.
        if model_dir not in models:
            if os.path.exists(f'{model_dir}/weights.pt'):
                checkpoint = torch.load(f'{model_dir}/weights.pt')
            else:
                weights_path = maybe_download_weights_from_s3(model_dir)
                checkpoint = torch.load(weights_path)
            model = WaveGrad(AttrDict(base_params)).to(device)
            model.load_state_dict(checkpoint['model'])
            model.eval()
            models[model_dir] = model

        model = models[model_dir]
        model.params.override(params)

        beta = np.array(model.params.noise_schedule)
        alpha = 1 - beta
        alpha_cum = np.cumprod(alpha)

        self.alpha = alpha
        self.alpha_cum = alpha_cum
        self.beta = beta
        self.model = model
        self.device = device
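The last block is the standard DDPM schedule arithmetic: alpha = 1 - beta per step, and alpha_cum the running product that drives the noise scale at sampling time. A tiny numeric check with a hypothetical 6-step linear schedule (the actual schedule lives in `params.noise_schedule`):

import numpy as np

beta = np.linspace(1e-4, 0.5, 6)   # hypothetical 6-step inference schedule
alpha = 1 - beta                   # per-step signal retention
alpha_cum = np.cumprod(alpha)      # fraction of signal remaining after n steps
print(alpha_cum ** 0.5)            # the noise_scale values the sampler indexes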
Example #4
def predict(spectrogram,
            model_dir=None,
            params=None,
            device=torch.device('cuda')):
    # Lazy load model.
    if model_dir not in models:
        if os.path.exists(f'{model_dir}/weights.pt'):
            checkpoint = torch.load(f'{model_dir}/weights.pt')
        else:
            checkpoint = torch.load(model_dir,
                                    map_location=torch.device('cpu'))
        model = WaveGrad(AttrDict(base_params)).to(device)
        model.load_state_dict(checkpoint['model'])
        model.eval()
        models[model_dir] = model

    model = models[model_dir]
    model.params.override(params)
    with torch.no_grad():
        beta = np.array(model.params.noise_schedule)
        alpha = 1 - beta
        alpha_cum = np.cumprod(alpha)

        # Expand rank 2 tensors by adding a batch dimension.
        if len(spectrogram.shape) == 2:
            spectrogram = spectrogram.unsqueeze(0)
        spectrogram = spectrogram.to(device)

        audio = torch.randn(spectrogram.shape[0],
                            model.params.hop_samples * spectrogram.shape[-1],
                            device=device)
        noise_scale = torch.from_numpy(
            alpha_cum**0.5).float().unsqueeze(1).to(device)

        # Reverse diffusion: iterate from the noisiest step down to n = 0.
        for n in range(len(alpha) - 1, -1, -1):
            c1 = 1 / alpha[n]**0.5
            c2 = (1 - alpha[n]) / (1 - alpha_cum[n])**0.5
            audio = c1 * (audio - c2 *
                          model(audio, spectrogram, noise_scale[n]).squeeze(1))
            if n > 0:
                noise = torch.randn_like(audio)
                sigma = ((1.0 - alpha_cum[n - 1]) / (1.0 - alpha_cum[n]) *
                         beta[n])**0.5
                audio += sigma * noise
            audio = torch.clamp(audio, -1.0, 1.0)
    return audio, model.params.sample_rate
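A usage sketch, assuming the module-level `models` dict and `base_params` that the function references are defined elsewhere in the file, and that a mel spectrogram tensor of shape [n_mels, frames] has been saved to disk; paths are placeholders:

import torch
import torchaudio

spectrogram = torch.load('mel.pt')                       # placeholder mel tensor
audio, sample_rate = predict(spectrogram, model_dir='/tmp/wavegrad')
torchaudio.save('output.wav', audio.cpu(), sample_rate)  # audio is [1, samples]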
Example #5
def load_model(model_dir, params=None, device=torch.device('cuda')):
    # Lazy load model.
    if model_dir not in models:
        if os.path.exists(f'{model_dir}/weights.pt'):
            checkpoint = torch.load(f'{model_dir}/weights.pt')
        else:
            checkpoint = torch.load(model_dir)
        model = WaveGrad(AttrDict(base_params)).to(device)
        # Strip the 'module.' prefix that DistributedDataParallel adds to keys.
        state_dict = checkpoint['model']
        new_state_dict = {
            key.replace('module.', ''): value
            for key, value in state_dict.items()
        }
        model.load_state_dict(new_state_dict)
        model.eval()
        models[model_dir] = model

    model = models[model_dir]
    model.params.override(params)
    return model
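The key renaming handles checkpoints saved from a DistributedDataParallel wrapper, which prefixes every parameter name with `module.`. A self-contained demonstration of the mismatch it repairs:

state_dict = {'module.conv.weight': 0, 'module.conv.bias': 1}  # keys as DDP saves them
clean = {key.replace('module.', ''): value for key, value in state_dict.items()}
print(clean)  # {'conv.weight': 0, 'conv.bias': 1} -- matches the bare model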
Example #6
def train(args, params):
  dataset = dataset_from_path(args.data_dirs, params)
  model = WaveGrad(params).cuda()
  _train_impl(0, model, dataset, args, params)
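A plausible top-level dispatch tying this single-GPU path to Example #1's distributed one; the GPU-count check and the `launch_distributed` helper (sketched after Example #1) are assumptions, not confirmed by these snippets:

def main(args, params):
    if torch.cuda.device_count() > 1:
        launch_distributed(args, params)  # one spawned replica per GPU
    else:
        train(args, params)               # everything in the current process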
Example #7
def predict(spectrogram,
            model_dir=None,
            params=None,
            device=torch.device('cuda'),
            audio=None,
            severity=None):
    # Lazy load model.
    if model_dir not in models:
        if os.path.exists(f'{model_dir}/weights.pt'):
            checkpoint = torch.load(f'{model_dir}/weights.pt')
        else:
            checkpoint = torch.load(model_dir)
        model = WaveGrad(AttrDict(base_params)).to(device)
        model.load_state_dict(checkpoint['model'])
        model.eval()
        models[model_dir] = model

    model = models[model_dir]
    model.params.override(params)
    with torch.no_grad():
        beta = np.array(model.params.noise_schedule)
        alpha = 1 - beta
        alpha_cum = np.cumprod(alpha)

        if severity is None:
            severity = len(model.params.noise_schedule)

        # Keep only the last `severity` steps; beta must be sliced in step
        # with alpha and alpha_cum, or the sigma term below misindexes it.
        beta = beta[-severity:]
        alpha = alpha[-severity:]
        alpha_cum = alpha_cum[-severity:]

        # Expand rank 2 tensors by adding a batch dimension.
        if len(spectrogram.shape) == 2:
            spectrogram = spectrogram.unsqueeze(0)
        spectrogram = spectrogram.to(device)

        length = model.params.hop_samples * spectrogram.shape[-1]

        if audio is None:
            audio = torch.randn(spectrogram.shape[0], length, device=device)
        else:
            # Pad or truncate the provided audio to match the spectrogram,
            # moving it to the model's device in either branch.
            audio = audio.to(device)
            if audio.shape[-1] > length:
                audio = audio[..., :length]
            else:
                padding = torch.zeros([audio.shape[0], length - audio.shape[1]],
                                      device=device)
                audio = torch.cat([audio, padding], -1)

        noise_scale = torch.from_numpy(
            alpha_cum**0.5).float().unsqueeze(1).to(device)

        ti = time()
        for n in range(severity - 1, -1, -1):
            print(f"{n}/{len(alpha)}", end="\r")
            c1 = 1 / alpha[n]**0.5
            c2 = (1 - alpha[n]) / (1 - alpha_cum[n])**0.5
            prediction = model(audio, spectrogram, noise_scale[n]).squeeze(1)
            audio = c1 * (audio - c2 * prediction)
            if n > 0:
                noise = torch.randn_like(audio)
                sigma = ((1.0 - alpha_cum[n - 1]) / (1.0 - alpha_cum[n]) *
                         beta[n])**0.5
                audio += sigma * noise
            # NOTE: clamping on every step, not just the final one, is debatable.
            audio = torch.clamp(audio, -1.0, 1.0)
    print(f"\nFinished {spectrogram.shape} in {time() - ti:.2f} secs.")
    return audio, model.params.sample_rate
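The `audio` and `severity` arguments let sampling start from an existing waveform partway through the schedule instead of from pure noise. A hedged usage sketch with placeholder inputs and an illustrative severity value:

import torch

spectrogram = torch.load('mel.pt')   # placeholder conditioning mel
noisy = torch.load('noisy.pt')       # placeholder [batch, samples] waveform
# Run only the last 20 schedule steps, refining `noisy` rather than noise.
refined, sample_rate = predict(spectrogram, model_dir='/tmp/wavegrad',
                               audio=noisy, severity=20)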