def generate_fn(args):
    """Autoregressively synthesize a waveform from a trained checkpoint.

    Loads the model from ``args.checkpoint``, conditions it on the local-
    condition features in ``args.lc_file``, generates one sample per frame
    position, and writes ``generated-<checkpoint>.wav`` next to the
    checkpoint file.

    Args:
        args: parsed CLI namespace; fields used here are ``checkpoint``,
            ``data_dir`` and ``lc_file``.
    """
    device = torch.device("cuda" if hparams.use_cuda else "cpu")
    # Samples per acoustic frame (frame shift in ms -> samples).
    upsample_factor = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)

    model = create_model(hparams)

    # map_location keeps tensors on CPU so a GPU-saved checkpoint loads anywhere.
    checkpoint = torch.load(args.checkpoint, map_location=lambda storage, loc: storage)
    if torch.cuda.device_count() > 1:
        # NOTE(review): assumes a multi-GPU checkpoint was saved from a
        # DataParallel-wrapped model (state dict under ``.module``) — confirm.
        model.module.load_state_dict(checkpoint['model'])
    else:
        model.load_state_dict(checkpoint['model'])

    model.to(device)
    model.eval()

    if hparams.feature_type == "mcc":
        # Rebuild the feature scaler from the statistics saved at preprocessing
        # time so generation normalizes features exactly like training did.
        scaler = StandardScaler()
        scaler.mean_ = np.load(os.path.join(args.data_dir, 'mean.npy'))
        scaler.scale_ = np.load(os.path.join(args.data_dir, 'scale.npy'))
        feat_transform = transforms.Compose([lambda x: scaler.transform(x)])
    else:
        feat_transform = None

    with torch.no_grad():
        # ``samples`` starts as seed history; ``local_condition`` is the
        # upsampled feature tensor; ``uv`` is presumably a voiced/unvoiced
        # flag per sample — TODO confirm against prepare_data.
        samples, local_condition, uv = prepare_data(args.lc_file, upsample_factor,
            model.receptive_field, read_fn=lambda x: np.load(x),
            feat_transform=feat_transform)

        start = time.time()  # NOTE(review): not used below in this view.
        for i in tqdm(range(local_condition.size(-1) - model.receptive_field)):
            # Rebuild model input from the last receptive_field samples each
            # step (shape (1, R, 1)); O(R) work per generated sample.
            sample = torch.FloatTensor(np.array(samples[-model.receptive_field:]).reshape(1, -1, 1))
            h = local_condition[:, :, i+1 : i+1 + model.receptive_field]
            sample, h = sample.to(device), h.to(device)

            output = model(sample, h)
            if hparams.feature_type == "mcc":
                if uv[i+model.receptive_field] == 0:
                    # Unvoiced frame: draw from the softmax distribution over
                    # quantization channels (adds natural noise-like texture).
                    output = output[0, :, -1]
                    outprob = F.softmax(output, dim=0).cpu().numpy()
                    sample = np.random.choice(
                            np.arange(hparams.quantization_channels),
                            p=outprob)
                else:
                    # Voiced frame: sharpen logits (x2 temperature scaling —
                    # presumably tuned empirically; TODO confirm) then argmax.
                    output = output[0, :, -1] * 2
                    outprob = F.softmax(output, dim=0).cpu().numpy()
                    sample = outprob.argmax(0)
            else:
                # I tested sampling, but it will produce more noise,
                # so I use argmax in this time.
                output = output[0, :, -1]
                outprob = F.softmax(output, dim=0).cpu().numpy()
                sample = outprob.argmax(0)

            # Map the chosen quantization bin back to a waveform amplitude and
            # feed it into the history for the next step.
            sample = mu_law_decode(sample, hparams.quantization_channels)
            samples.append(sample)

        write_wav(np.asarray(samples), hparams.sample_rate,
                  os.path.join(os.path.dirname(args.checkpoint),
                               "generated-{}.wav".format(os.path.basename(args.checkpoint))))
def train_fn(args):
    """Train the vocoder model, periodically checkpointing and dumping an
    evaluation waveform decoded from the current batch's predictions.

    Args:
        args: parsed CLI namespace; fields used here are ``resume``,
            ``data_dir``, ``num_workers``, ``ema_decay`` and ``checkpoint_dir``.
    """
    device = torch.device("cuda" if hparams.use_cuda else "cpu")
    # Samples per acoustic frame (frame shift in ms -> samples).
    upsample_factor = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)

    model = create_model(hparams)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=hparams.learning_rate)
    # Move any optimizer state tensors (e.g. Adam moments from a resumed run)
    # onto the training device.
    for state in optimizer.state.values():
        for key, value in state.items():
            if torch.is_tensor(value):
                state[key] = value.to(device)

    if args.resume is not None:
        log("Resume checkpoint from: {}:".format(args.resume))
        # map_location keeps tensors on CPU so a GPU checkpoint loads anywhere.
        checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
        if torch.cuda.device_count() > 1:
            model.module.load_state_dict(checkpoint['model'])
        else:
            model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint["optimizer"])
        global_step = checkpoint['steps']
    else:
        global_step = 0

    log("receptive field: {0} ({1:.2f}ms)".format(
        model.receptive_field,
        model.receptive_field / hparams.sample_rate * 1000))

    if hparams.feature_type == "mcc":
        # Rebuild the feature scaler from statistics saved at preprocessing
        # time so training normalizes features consistently.
        scaler = StandardScaler()
        scaler.mean_ = np.load(os.path.join(args.data_dir, 'mean.npy'))
        scaler.scale_ = np.load(os.path.join(args.data_dir, 'scale.npy'))
        feat_transform = transforms.Compose([lambda x: scaler.transform(x)])
    else:
        feat_transform = None

    dataset = CustomDataset(
        meta_file=os.path.join(args.data_dir, 'train.txt'),
        receptive_field=model.receptive_field,
        sample_size=hparams.sample_size,
        upsample_factor=upsample_factor,
        quantization_channels=hparams.quantization_channels,
        use_local_condition=hparams.use_local_condition,
        noise_injecting=hparams.noise_injecting,
        feat_transform=feat_transform)

    dataloader = DataLoader(dataset, batch_size=hparams.batch_size,
                            shuffle=True, num_workers=args.num_workers,
                            pin_memory=True)

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    criterion = nn.CrossEntropyLoss()

    # Track an exponential moving average of the weights; the EMA copy is
    # persisted by save_checkpoint below.
    ema = ExponentialMovingAverage(args.ema_decay)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    writer = SummaryWriter(args.checkpoint_dir)

    while global_step < hparams.training_steps:
        for i, data in enumerate(dataloader, 0):
            audio, target, local_condition = data
            target = target.squeeze(-1)
            local_condition = local_condition.transpose(1, 2)
            audio, target, h = audio.to(device), target.to(
                device), local_condition.to(device)

            optimizer.zero_grad()
            # Teacher forcing: input samples shifted one step behind the
            # conditioning frames / targets.
            output = model(audio[:, :-1, :], h[:, :, 1:])
            loss = criterion(output, target)
            # .item() synchronizes with the device; fetch it once per step.
            loss_value = loss.item()
            log('step [%3d]: loss: %.3f' % (global_step, loss_value))
            writer.add_scalar('loss', loss_value, global_step)
            loss.backward()
            optimizer.step()

            # update moving average
            if ema is not None:
                apply_moving_average(model, ema)

            global_step += 1

            if global_step % hparams.checkpoint_interval == 0:
                save_checkpoint(device, hparams, model, optimizer,
                                global_step, args.checkpoint_dir, ema)
                # Decode one batch element's argmax predictions into an audible
                # eval waveform. Prefer index 1 as before, but fall back to 0
                # so a final/partial batch of size 1 cannot raise IndexError.
                out = output[1 if output.size(0) > 1 else 0, :, :]
                samples = out.argmax(0)
                # ``samples`` lives on ``device``; move to CPU before NumPy —
                # np.asarray() on a CUDA tensor raises a conversion error.
                waveform = mu_law_decode(
                    samples[model.receptive_field:].cpu().numpy(),
                    hparams.quantization_channels)
                write_wav(
                    waveform, hparams.sample_rate,
                    os.path.join(args.checkpoint_dir,
                                 "train_eval_{}.wav".format(global_step)))