예제 #1
0
def generate_fn(args):
    """Generate a waveform autoregressively from a saved checkpoint.

    Loads the weights stored under ``checkpoint['model']`` in
    ``args.checkpoint`` and reads local-condition features from
    ``args.lc_file`` via ``prepare_data``. It then predicts one quantized
    sample per step and mu-law decodes it. The result is written to
    ``generated-<checkpoint file name>.wav`` in the checkpoint's directory.

    Args:
        args: Parsed command-line arguments. Reads ``checkpoint`` and
            ``lc_file``, plus ``data_dir`` (for ``mean.npy`` / ``scale.npy``)
            when ``hparams.feature_type == "mcc"``.
    """
    device = torch.device("cuda" if hparams.use_cuda else "cpu")
    # Waveform samples per feature frame (frame shift converted from ms).
    upsample_factor = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
    model = create_model(hparams)
    receptive_field = model.receptive_field

    # Load tensors onto CPU first; the model is moved to ``device`` below.
    checkpoint = torch.load(args.checkpoint, map_location=lambda storage, loc: storage)
    # This function never wraps the model in nn.DataParallel. On multi-GPU
    # hosts an unconditional ``model.module`` would therefore raise
    # AttributeError. Load into the wrapped module only if one exists.
    getattr(model, "module", model).load_state_dict(checkpoint['model'])

    model.to(device)
    model.eval()

    if hparams.feature_type == "mcc":
        # Standardize features with the statistics saved alongside the data.
        scaler = StandardScaler()
        scaler.mean_ = np.load(os.path.join(args.data_dir, 'mean.npy'))
        scaler.scale_ = np.load(os.path.join(args.data_dir, 'scale.npy'))
        feat_transform = transforms.Compose([scaler.transform])
    else:
        feat_transform = None

    with torch.no_grad():
        samples, local_condition, uv = prepare_data(
            args.lc_file, upsample_factor, receptive_field,
            read_fn=np.load, feat_transform=feat_transform)

        for i in tqdm(range(local_condition.size(-1) - receptive_field)):
            # Feed the most recent ``receptive_field`` generated samples,
            # reshaped to (1, -1, 1). The conditioning window is shifted by
            # one step relative to ``i``.
            sample = torch.FloatTensor(
                np.array(samples[-receptive_field:]).reshape(1, -1, 1))
            h = local_condition[:, :, i + 1:i + 1 + receptive_field]
            output = model(sample.to(device), h.to(device))

            # Probabilities over the ``quantization_channels`` levels for the
            # last time step of the single batch item.
            outprob = F.softmax(output[0, :, -1], dim=0).cpu().numpy()

            if hparams.feature_type == "mcc" and uv[i + receptive_field] == 0:
                # uv == 0 (presumably an unvoiced frame; confirm against
                # prepare_data): draw a random level from the distribution.
                sample = np.random.choice(
                    np.arange(hparams.quantization_channels), p=outprob)
            else:
                # Greedy decoding: sampling was tested but produced more noise.
                # Scaling the logits before the softmax (a temperature) would
                # not change the argmax, so voiced mcc frames use this branch too.
                sample = outprob.argmax(0)

            samples.append(mu_law_decode(sample, hparams.quantization_channels))

    write_wav(np.asarray(samples), hparams.sample_rate,
              os.path.join(os.path.dirname(args.checkpoint),
                           "generated-{}.wav".format(os.path.basename(args.checkpoint))))
예제 #2
0
def train_fn(args):
    """Train the model with cross-entropy on quantized audio targets.

    Optionally resumes from ``args.resume``. Runs until ``global_step``
    reaches ``hparams.training_steps`` and maintains an exponential moving
    average of the parameters. Every ``hparams.checkpoint_interval`` steps it
    saves a checkpoint and writes a greedy-decoded training output as
    ``train_eval_<step>.wav`` into ``args.checkpoint_dir``.

    Args:
        args: Parsed command-line arguments. Reads ``resume``, ``data_dir``,
            ``num_workers``, ``ema_decay`` and ``checkpoint_dir``.
    """
    device = torch.device("cuda" if hparams.use_cuda else "cpu")
    # Waveform samples per feature frame (frame shift converted from ms).
    upsample_factor = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)

    model = create_model(hparams)
    model.to(device)
    # Captured before a possible nn.DataParallel wrap below: the wrapper does
    # not expose the inner module's ``receptive_field`` attribute.
    receptive_field = model.receptive_field
    optimizer = optim.Adam(model.parameters(), lr=hparams.learning_rate)

    if args.resume is not None:
        log("Resume checkpoint from: {}:".format(args.resume))
        checkpoint = torch.load(args.resume,
                                map_location=lambda storage, loc: storage)
        # The model has not been wrapped in nn.DataParallel yet, so an
        # unconditional ``model.module`` would raise AttributeError on
        # multi-GPU hosts. Load into the wrapped module only if one exists.
        getattr(model, "module", model).load_state_dict(checkpoint['model'])
        # Optimizer.load_state_dict casts the loaded state onto the device of
        # the corresponding parameters, which are already on ``device``.
        optimizer.load_state_dict(checkpoint["optimizer"])
        global_step = checkpoint['steps']
    else:
        global_step = 0

    log("receptive field: {0} ({1:.2f}ms)".format(
        receptive_field,
        receptive_field / hparams.sample_rate * 1000))

    if hparams.feature_type == "mcc":
        # Standardize features with the statistics saved alongside the data.
        scaler = StandardScaler()
        scaler.mean_ = np.load(os.path.join(args.data_dir, 'mean.npy'))
        scaler.scale_ = np.load(os.path.join(args.data_dir, 'scale.npy'))
        # A bound method (unlike a lambda) stays picklable. That matters when
        # DataLoader workers are started with the "spawn" method.
        feat_transform = transforms.Compose([scaler.transform])
    else:
        feat_transform = None

    dataset = CustomDataset(
        meta_file=os.path.join(args.data_dir, 'train.txt'),
        receptive_field=receptive_field,
        sample_size=hparams.sample_size,
        upsample_factor=upsample_factor,
        quantization_channels=hparams.quantization_channels,
        use_local_condition=hparams.use_local_condition,
        noise_injecting=hparams.noise_injecting,
        feat_transform=feat_transform)

    dataloader = DataLoader(dataset,
                            batch_size=hparams.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers,
                            pin_memory=True)

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    criterion = nn.CrossEntropyLoss()

    # Parameter names are taken after the DataParallel wrap, matching the
    # model object passed to apply_moving_average below.
    ema = ExponentialMovingAverage(args.ema_decay)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    writer = SummaryWriter(args.checkpoint_dir)
    try:
        while global_step < hparams.training_steps:
            for audio, target, local_condition in dataloader:
                target = target.squeeze(-1)
                local_condition = local_condition.transpose(1, 2)
                audio = audio.to(device)
                target = target.to(device)
                h = local_condition.to(device)

                optimizer.zero_grad()
                # Input drops the last audio step; conditioning drops the first.
                output = model(audio[:, :-1, :], h[:, :, 1:])
                loss = criterion(output, target)
                loss_value = loss.item()
                log('step [%3d]: loss: %.3f' % (global_step, loss_value))
                writer.add_scalar('loss', loss_value, global_step)

                loss.backward()
                optimizer.step()

                # Update the moving average of the parameters.
                apply_moving_average(model, ema)

                global_step += 1

                if global_step % hparams.checkpoint_interval == 0:
                    save_checkpoint(device, hparams, model, optimizer, global_step,
                                    args.checkpoint_dir, ema)
                    # Greedy-decode the first batch item. Index 0 always exists,
                    # even for a batch of one. Move to CPU before converting,
                    # because numpy cannot read CUDA tensors.
                    samples = output[0].argmax(0).cpu().numpy()
                    waveform = mu_law_decode(
                        samples[receptive_field:],
                        hparams.quantization_channels)
                    write_wav(
                        waveform, hparams.sample_rate,
                        os.path.join(args.checkpoint_dir,
                                     "train_eval_{}.wav".format(global_step)))

                # Stop exactly at the step budget instead of finishing the epoch.
                if global_step >= hparams.training_steps:
                    break
    finally:
        writer.close()