def __init__(self, args):
    self.args = args
    self.args.n_datasets = len(args.data)
    self.modelPath = Path('checkpoints') / args.expName

    self.logger = create_output_dir(args, self.modelPath)
    self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]

    self.losses_recon = [LossMeter(f'recon {i}')
                         for i in range(self.args.n_datasets)]
    self.loss_total = LossMeter('total')

    self.evals_recon = [LossMeter(f'recon {i}')
                        for i in range(self.args.n_datasets)]
    self.eval_total = LossMeter('eval total')

    self.start_epoch = 0

    # torch.manual_seed(args.seed)
    # torch.cuda.manual_seed(args.seed)

    # get the pretrained model checkpoint matching the requested decoder id
    checkpoint = args.checkpoint.parent.glob(args.checkpoint.name + '_*.pth')
    checkpoint = [c for c in checkpoint if extract_id(c) in args.decoder][0]

    model_args = torch.load(args.checkpoint.parent / 'args.pth')[0]

    self.encoder = Encoder(model_args)
    self.encoder.load_state_dict(torch.load(checkpoint)['encoder_state'])

    # freeze the encoder
    for param in self.encoder.parameters():
        param.requires_grad = False
        # self.logger.debug(f'encoder at start: {param}')

    self.decoder = WaveNet(model_args)
    self.decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])

    # freeze all decoder layers except the last `decoder_update` ones
    for param in self.decoder.layers[:-args.decoder_update].parameters():
        param.requires_grad = False
        # self.logger.debug(f'decoder at start: {param}')

    self.encoder = torch.nn.DataParallel(self.encoder).cuda()
    self.decoder = torch.nn.DataParallel(self.decoder).cuda()

    self.model_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                            self.decoder.parameters()),
                                      lr=args.lr)
    self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
        self.model_optimizer, args.lr_decay)
    self.lr_manager.step()
def __init__(self, domains, domain_cnn):
    """
    :param domains: int specifying the number of domains
    :param domain_cnn: a domain-confusion CNN
    """
    super(MusicAutoEncoder, self).__init__()
    self.domains = domains
    self.encoder = WaveNet(**encoder_config).cuda()
    self.decoders = [WaveNet(**decoder_config).cuda() for k in range(domains)]
    self.domain_cnn = domain_cnn
def main():
    args = get_arguments()
    logdir = os.path.join(args.logdir, 'train', str(datetime.now()))

    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNet(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)
    next_sample = net.predict_proba(samples)

    saver = tf.train.Saver()
    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = net.decode(samples)

    quantization_steps = wavenet_params['quantization_steps']
    waveform = np.random.randint(quantization_steps, size=(1,)).tolist()

    for step in range(args.samples):
        if len(waveform) > args.window:
            window = waveform[-args.window:]
        else:
            window = waveform

        # Predict the distribution of the next sample and draw from it.
        prediction = sess.run(next_sample, feed_dict={samples: window})
        sample = np.random.choice(np.arange(quantization_steps), p=prediction)
        waveform.append(sample)

        print('Sample {:>3d}/{:>3d}: {}'.format(step + 1, args.samples, sample))

        # Periodically write the intermediate waveform to disk.
        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    datestring = str(datetime.now()).replace(' ', 'T')
    writer = tf.train.SummaryWriter(
        os.path.join(logdir, 'generation', datestring))
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
def __init__(self,
             hparams,
             loss_fn=F.cross_entropy,
             log_grads: bool = False,
             use_sentence_split: bool = True):
    super().__init__()

    # Configuration flags
    self.use_sentence_split = use_sentence_split
    self.log_grads = log_grads

    # Dataset
    self.batch_size = hparams.batch_size
    self.output_length = hparams.out_len
    self.win_len = hparams.win_len
    self._setup_dataloaders()

    # Training
    self.loss_fn = loss_fn
    self.lr = hparams.lr

    # Embedding
    self.embedding_dim = hparams.emb_dim
    self.embedding = nn.Embedding(self.num_classes, self.embedding_dim)

    # Metrics
    self.metrics = MetricsCalculator(["accuracy", "precision", "recall", "f1"])

    # Model
    self.model = WaveNet(num_blocks=hparams.num_blocks,
                         num_layers=hparams.num_layers,
                         num_classes=self.num_classes,
                         output_len=self.output_length,
                         ch_start=self.embedding_dim,
                         ch_residual=hparams.ch_residual,
                         ch_dilation=hparams.ch_dilation,
                         ch_skip=hparams.ch_skip,
                         ch_end=hparams.ch_end,
                         kernel_size=hparams.kernel_size,
                         bias=True)
def test_assert_different_length_batch_generation():
    # prepare batch
    batch = 4
    length = 32
    x = np.random.randint(0, 256, size=(batch, 1))
    h = np.random.randn(batch, 28, length)
    length_list = sorted(list(np.random.randint(length // 2, length - 1, batch)))

    with torch.no_grad():
        net = WaveNet(256, 28, 4, 4, 10, 3, 2)
        net.apply(initialize)
        net.eval()

        # sample-by-sample generation
        gen1_list = []
        for x_, h_, length in zip(x, h, length_list):
            batch_x = torch.from_numpy(np.expand_dims(x_, 0)).long()
            batch_h = torch.from_numpy(np.expand_dims(h_, 0)).float()
            gen1 = net.fast_generate(batch_x, batch_h, length, 1, "argmax")
            gen1_list += [gen1]

        # batch generation
        batch_x = torch.from_numpy(x).long()
        batch_h = torch.from_numpy(h).float()
        gen2_list = net.batch_fast_generate(batch_x, batch_h, length_list, 1,
                                            "argmax")

        # assertion
        for gen1, gen2 in zip(gen1_list, gen2_list):
            np.testing.assert_array_equal(gen1, gen2)
def custom_model_fn(features, labels, mode, params):
    """Model function for a custom WaveNet Estimator."""
    model = WaveNet(**params)

    if mode == tf.estimator.ModeKeys.PREDICT:
        logits = model((features['mel'], labels), training=False)
        predictions = {'logits': logits}
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={'upsampled': export_output.PredictOutput(predictions)})

    logits = model((features, labels), training=True)
    logits = tf.transpose(logits, [0, 2, 1])
    labels = tf.one_hot(tf.cast(labels, dtype=tf.int32), 256)
    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)

    metrics = {'loss': loss}
    tf.summary.scalar('loss', loss)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)

    assert mode == tf.estimator.ModeKeys.TRAIN
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
def save_checkpoint(model, optimizer, scheduler, learning_rate, iteration,
                    output_directory, ema, wavenet_config):
    checkpoint_path = "{}/wavenet_{}".format(output_directory, iteration)
    print("Saving model and optimizer state at iteration {} to {}".format(
        iteration, checkpoint_path))

    # Save a fresh copy of the model so the live training model is not mutated.
    model_for_saving = WaveNet(**wavenet_config).cuda()
    model_for_saving.load_state_dict(model.state_dict())
    torch.save({'model': model_for_saving,
                'iteration': iteration,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'learning_rate': learning_rate}, checkpoint_path)

    # Save a second copy with the EMA (exponential moving average) weights.
    ema_path = "{}/wavenet_ema_{}".format(output_directory, iteration)
    print("Saving ema model at iteration {} to {}".format(iteration, ema_path))
    state_dict = model_for_saving.state_dict()
    for name, _ in model.named_parameters():
        if name in ema.shadow:
            state_dict[name] = ema.shadow[name]
    model_for_saving.load_state_dict(state_dict)
    torch.save({'model': model_for_saving,
                'iteration': iteration,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'learning_rate': learning_rate}, ema_path)
def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    print("Saving model and optimizer state at iteration {} to {}".format(
        iteration, filepath))
    # `wavenet_config` is expected to be a module-level dict of WaveNet
    # constructor arguments defined by the surrounding training script.
    model_for_saving = WaveNet(**wavenet_config).cuda()
    model_for_saving.load_state_dict(model.state_dict())
    torch.save({'model': model_for_saving,
                'iteration': iteration,
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate}, filepath)
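For reference, a minimal sketch of what such a module-level `wavenet_config` might contain, using constructor argument names that appear in other snippets in this collection; this is a hypothetical example only, and the keys must match whichever WaveNet constructor is actually in use:

# Hypothetical example only -- keys must match the WaveNet constructor in use.
wavenet_config = {
    "n_quantize": 256,      # number of mu-law quantization channels
    "n_aux": 28,            # auxiliary (conditioning) feature dimension
    "n_resch": 512,         # residual channels
    "n_skipch": 256,        # skip channels
    "dilation_depth": 10,   # dilations 2**0 .. 2**9
    "dilation_repeat": 1,   # number of dilation stacks
    "kernel_size": 2,       # width of the dilated causal convolution
}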
def setUp(self):
    self.net = WaveNet(batch_size=1,
                       dilations=[1, 2, 4, 8, 16, 32, 64, 128, 256,
                                  1, 2, 4, 8, 16, 32, 64, 128, 256],
                       filter_width=2,
                       residual_channels=16,
                       dilation_channels=16,
                       quantization_channels=256,
                       skip_channels=32)
def __init__(self, args):
    # TODO
    self.args = args
    # self.data = [Dataset(args, domain_path) for domain_path in args.data]
    self.expPath = args.checkpoint / 'MusicStar' / args.exp_name
    self.logger = train_logger(self.args, self.expPath)
    self.encoder = MusicStarEncoder(args)
    self.decoder = WaveNet(args)
def gpu_decode(feat_list, gpu):
    # set default gpu and do not track gradient
    torch.cuda.set_device(gpu)
    torch.set_grad_enabled(False)

    # define model and load parameters
    if config.use_upsampling_layer:
        upsampling_factor = config.upsampling_factor
    else:
        upsampling_factor = 0
    model = WaveNet(n_quantize=config.n_quantize,
                    n_aux=config.n_aux,
                    n_resch=config.n_resch,
                    n_skipch=config.n_skipch,
                    dilation_depth=config.dilation_depth,
                    dilation_repeat=config.dilation_repeat,
                    kernel_size=config.kernel_size,
                    upsampling_factor=upsampling_factor)
    model.load_state_dict(
        torch.load(args.checkpoint,
                   map_location=lambda storage, loc: storage)["model"])
    model.eval()
    model.cuda()

    # define generator
    generator = decode_generator(
        feat_list,
        batch_size=args.batch_size,
        feature_type=config.feature_type,
        wav_transform=wav_transform,
        feat_transform=feat_transform,
        upsampling_factor=config.upsampling_factor,
        use_upsampling_layer=config.use_upsampling_layer,
        use_speaker_code=config.use_speaker_code)

    # decode
    if args.batch_size > 1:
        for feat_ids, (batch_x, batch_h, n_samples_list) in generator:
            logging.info("decoding start")
            samples_list = model.batch_fast_generate(
                batch_x, batch_h, n_samples_list, args.intervals)
            for feat_id, samples in zip(feat_ids, samples_list):
                wav = decode_mu_law(samples, config.n_quantize)
                sf.write(args.outdir + "/" + feat_id + ".wav",
                         wav, args.fs, "PCM_16")
                logging.info("wrote %s.wav in %s." % (feat_id, args.outdir))
    else:
        for feat_id, (x, h, n_samples) in generator:
            logging.info("decoding %s (length = %d)" % (feat_id, n_samples))
            samples = model.fast_generate(x, h, n_samples, args.intervals)
            wav = decode_mu_law(samples, config.n_quantize)
            sf.write(args.outdir + "/" + feat_id + ".wav",
                     wav, args.fs, "PCM_16")
            logging.info("wrote %s.wav in %s." % (feat_id, args.outdir))
def main():
    args = parse_args()
    cfg.resume = args.resume
    cfg.exp_name = args.exp
    cfg.work_root = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/'
    cfg.workdir = cfg.work_root + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size = args.batch_size
    cfg.lr = args.lr
    cfg.load_from = args.load_from
    cfg.save_excel = args.save_excel

    weights_dir = os.path.join(cfg.workdir, 'weights')
    check_and_mkdir(weights_dir)

    print('initial training...')
    print(f'work_dir: {cfg.workdir},\n'
          f'pretrained: {cfg.load_from},\n'
          f'batch_size: {cfg.batch_size},\n'
          f'lr: {cfg.lr},\n'
          f'epochs: {cfg.epochs},\n'
          f'sparse: {cfg.sparse_mode}')

    writer = SummaryWriter(log_dir=cfg.workdir + '/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train, batch_size=cfg.batch_size,
                              num_workers=4, shuffle=True, pin_memory=True)

    vctk_val = VCTK(cfg, 'val')
    val_loader = DataLoader(vctk_val, batch_size=cfg.batch_size,
                            num_workers=4, shuffle=False, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=40, dilations=[1, 2, 4, 8, 16])
    model = nn.DataParallel(model)
    model.cuda()
    model.train()

    # build loss
    loss_fn = nn.CTCLoss(blank=27)

    # optionally resume from the best checkpoint
    if cfg.resume and os.path.exists(cfg.workdir + '/weights/best.pth'):
        model.load_state_dict(torch.load(cfg.workdir + '/weights/best.pth'),
                              strict=True)
        print("loading", cfg.workdir + '/weights/best.pth')
        cfg.load_from = cfg.workdir + '/weights/best.pth'

    # build optimizer (an Adam optimizer, not an LR scheduler)
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)

    train(train_loader, optimizer, model, loss_fn, val_loader, writer)
def __init__(self, args):
    self.args = args
    self.args.n_datasets = len(self.args.data)
    self.expPath = Path('checkpoints') / args.expName

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    self.logger = create_output_dir(args, self.expPath)
    self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]

    self.losses_recon = [LossMeter(f'recon {i}')
                         for i in range(self.args.n_datasets)]
    self.loss_total = LossMeter('total')

    self.evals_recon = [LossMeter(f'recon {i}')
                        for i in range(self.args.n_datasets)]
    self.eval_total = LossMeter('eval total')

    self.encoder = Encoder(args)
    self.decoder = WaveNet(args)

    assert args.checkpoint, 'you MUST pass a checkpoint for the encoder'

    if args.continue_training:
        checkpoint_args_path = os.path.dirname(args.checkpoint) + '/args.pth'
        checkpoint_args = torch.load(checkpoint_args_path)
        self.start_epoch = checkpoint_args[-1] + 1
    else:
        self.start_epoch = 0

    states = torch.load(args.checkpoint)
    self.encoder.load_state_dict(states['encoder_state'])
    if args.continue_training:
        self.decoder.load_state_dict(states['decoder_state'])
    self.logger.info('Loaded checkpoint parameters')

    self.encoder = torch.nn.DataParallel(self.encoder).cuda()
    self.decoder = torch.nn.DataParallel(self.decoder).cuda()

    self.model_optimizer = optim.Adam(self.decoder.parameters(), lr=args.lr)

    if args.continue_training:
        self.model_optimizer.load_state_dict(states['model_optimizer_state'])

    self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
        self.model_optimizer, args.lr_decay)
    self.lr_manager.last_epoch = self.start_epoch
    self.lr_manager.step()
def test_forward():
    # get batch
    generator = sine_generator(100)
    batch = next(generator)
    batch_input = batch.view(1, -1)
    batch_aux = torch.rand(1, 28, batch_input.size(1)).float()

    # define model without upsampling with kernel size = 2
    net = WaveNet(256, 28, 32, 128, 10, 1, 2)
    net.apply(initialize)
    net.eval()
    y = net(batch_input, batch_aux)[0]
    assert y.size(0) == batch_input.size(1)
    assert y.size(1) == 256

    # define model without upsampling with kernel size = 3
    net = WaveNet(256, 28, 32, 128, 10, 1, 3)
    net.apply(initialize)
    net.eval()
    y = net(batch_input, batch_aux)[0]
    assert y.size(0) == batch_input.size(1)
    assert y.size(1) == 256

    batch_input = batch.view(1, -1)
    batch_aux = torch.rand(1, 28, batch_input.size(1) // 10).float()

    # define model with upsampling and kernel size = 2
    net = WaveNet(256, 28, 32, 128, 10, 1, 2, 10)
    net.apply(initialize)
    net.eval()
    y = net(batch_input, batch_aux)[0]
    assert y.size(0) == batch_input.size(1)
    assert y.size(1) == 256

    # define model with upsampling and kernel size = 3
    net = WaveNet(256, 28, 32, 128, 10, 1, 3, 10)
    net.apply(initialize)
    net.eval()
    y = net(batch_input, batch_aux)[0]
    assert y.size(0) == batch_input.size(1)
    assert y.size(1) == 256
def setUp(self):
    quantization_steps = 256
    self.net = WaveNet(batch_size=1,
                       channels=quantization_steps,
                       dilations=[1, 2, 4, 8, 16, 32, 64, 128, 256,
                                  1, 2, 4, 8, 16, 32, 64, 128, 256],
                       filter_width=2,
                       residual_channels=16,
                       dilation_channels=16,
                       use_biases=True)
def main():
    args = parse_args()

    print('initial training...')
    print(f'work_dir: {cfg.workdir}, pretrained: {cfg.load_from}, '
          f'batch_size: {cfg.batch_size}, lr: {cfg.lr}, epochs: {cfg.epochs}')

    writer = SummaryWriter(log_dir=cfg.workdir + '/runs')

    # distributed training setting
    assert cfg.distributed
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group('nccl', init_method='env://')

    # build dataloader
    vctk_train = VCTK(cfg, 'train')
    train_sample = torch.utils.data.distributed.DistributedSampler(
        vctk_train, shuffle=True)
    train_loader = DataLoader(vctk_train, batch_size=cfg.batch_size,
                              sampler=train_sample, num_workers=8,
                              pin_memory=True)

    vctk_val = VCTK(cfg, 'val')
    val_sample = torch.utils.data.distributed.DistributedSampler(
        vctk_val, shuffle=False)
    val_loader = DataLoader(vctk_val, batch_size=cfg.batch_size,
                            sampler=val_sample, num_workers=8, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=20).cuda()
    model = DDP(model, device_ids=[args.local_rank], broadcast_buffers=False)

    # build loss
    loss_fn = nn.CTCLoss()

    # build optimizer (passed to train() in place of an LR scheduler)
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)
    # scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 150, 250], gamma=0.5)

    # train
    train(args, train_loader, optimizer, model, loss_fn, val_loader, writer)
def gpu_decode(feat_list, gpu):
    with torch.cuda.device(gpu):
        # define model and load parameters
        model = WaveNet(n_quantize=config.n_quantize,
                        n_aux=config.n_aux,
                        n_resch=config.n_resch,
                        n_skipch=config.n_skipch,
                        dilation_depth=config.dilation_depth,
                        dilation_repeat=config.dilation_repeat,
                        kernel_size=config.kernel_size,
                        upsampling_factor=config.upsampling_factor)
        model.load_state_dict(
            torch.load(args.checkpoint,
                       map_location=lambda storage, loc: storage.cuda(gpu))["model"])
        model.eval()
        model.cuda()
        torch.backends.cudnn.benchmark = True

        # define generator
        generator = decode_generator(
            feat_list,
            batch_size=args.batch_size,
            wav_transform=wav_transform,
            feat_transform=feat_transform,
            use_speaker_code=config.use_speaker_code,
            upsampling_factor=config.upsampling_factor)

        # decode
        if args.batch_size > 1:
            for feat_ids, (batch_x, batch_h, n_samples_list) in generator:
                logging.info("decoding start")
                samples_list = model.batch_fast_generate(
                    batch_x, batch_h, n_samples_list, args.intervals)
                for feat_id, samples in zip(feat_ids, samples_list):
                    wav = decode_mu_law(samples, config.n_quantize)
                    sf.write(args.outdir + "/" + feat_id + ".wav",
                             wav, args.fs, "PCM_16")
                    logging.info("wrote %s.wav in %s." % (feat_id, args.outdir))
        else:
            for feat_id, (x, h, n_samples) in generator:
                logging.info("decoding %s (length = %d)" % (feat_id, n_samples))
                samples = model.fast_generate(x, h, n_samples, args.intervals)
                wav = decode_mu_law(samples, config.n_quantize)
                sf.write(args.outdir + "/" + feat_id + ".wav",
                         wav, args.fs, "PCM_16")
                logging.info("wrote %s.wav in %s." % (feat_id, args.outdir))
def __init__(self, input_C=96, input_L=1366, L_trans_channels=256):
    super(CascadeModel, self).__init__()
    self.input_C = input_C
    self.input_L = input_L
    self.first_block = nn.Sequential(
        nn.Conv1d(input_L, L_trans_channels, 1),
        nn.Conv1d(L_trans_channels, L_trans_channels, 3),
        nn.BatchNorm1d(L_trans_channels),
        nn.ReLU(),
        nn.Conv1d(L_trans_channels, L_trans_channels, 3),
        nn.BatchNorm1d(L_trans_channels),
        nn.ELU(),
        nn.MaxPool1d(2),
        nn.Conv1d(L_trans_channels, L_trans_channels, 3),
        nn.BatchNorm1d(L_trans_channels),
        nn.ReLU(),
        nn.Conv1d(L_trans_channels, L_trans_channels, 3),
        nn.BatchNorm1d(L_trans_channels),
        nn.ELU(),
        nn.MaxPool1d(2),
        nn.Conv1d(L_trans_channels, L_trans_channels, 3),
        nn.BatchNorm1d(L_trans_channels),
        nn.ReLU(),
        nn.Conv1d(L_trans_channels, L_trans_channels, 3, padding=1),
        nn.BatchNorm1d(L_trans_channels),
        nn.ELU(),
        nn.MaxPool1d(2),
        nn.Conv1d(L_trans_channels, L_trans_channels, 1),
        nn.BatchNorm1d(L_trans_channels),
        nn.ELU(),
        nn.Conv1d(L_trans_channels, input_L, 1))
    self.wavenet = WaveNet(in_depth=9,
                           dilation_channels=32,
                           res_channels=32,
                           skip_channels=256,
                           end_channels=128,
                           dilation_depth=6,
                           n_blocks=5)
    self.post = nn.Sequential(
        nn.Dropout(p=0.2),
        nn.Linear(128, 256),
        nn.ReLU(),
        nn.Dropout(p=0.3),
        nn.Linear(256, 50),
        nn.Sigmoid())
def test_generate():
    batch = 2
    x = np.random.randint(0, 256, size=(batch, 1))
    h = np.random.randn(batch, 28, 32)
    length = h.shape[-1] - 1

    with torch.no_grad():
        net = WaveNet(256, 28, 16, 32, 10, 3, 2)
        net.apply(initialize)
        net.eval()

        for x_, h_ in zip(x, h):
            batch_x = torch.from_numpy(np.expand_dims(x_, 0)).long()
            batch_h = torch.from_numpy(np.expand_dims(h_, 0)).float()
            net.generate(batch_x, batch_h, length, 1, "sampling")
            net.fast_generate(batch_x, batch_h, length, 1, "sampling")

        batch_x = torch.from_numpy(x).long()
        batch_h = torch.from_numpy(h).float()
        net.batch_fast_generate(batch_x, batch_h, [length] * batch, 1, "sampling")
def test():
    np.random.seed(42)
    audio, speaker_ids = make_sine_waves(None)

    dilations = [2**i for i in range(7)] * 2
    receptive_field = WaveNet.calculate_receptive_field(2, dilations)

    audio = np.pad(audio, (receptive_field - 1, 0), 'constant').astype(np.float32)
    encoded = mu_law_encode(audio, 2**8)
    encoded = encoded[np.newaxis, :]
    encoded_one_hot = one_hot(encoded, 2**8)

    signal_length = int(tf.shape(encoded_one_hot)[1] - 1)
    input_one_hot = tf.slice(encoded_one_hot, [0, 0, 0], [-1, signal_length, -1])
    target_one_hot = tf.slice(encoded_one_hot, [0, receptive_field, 0],
                              [-1, -1, -1])
    print('input shape: ', tf.shape(input_one_hot))
    print('output shape: ', tf.shape(target_one_hot))

    net = WaveNet(1, dilations, 2, signal_length, 32, 32, 32, 2**8, True, 0.01)
    net.build(input_shape=(None, signal_length, 2**8))
    optimizer = Adam(lr=1e-3)

    for epoch in range(301):
        with tf.GradientTape() as tape:
            # [b, 1254, 256] => [b, 999, 256]
            logits = net(input_one_hot, training=True)
            # [b, 999, 256] => [b * 999, 256]
            logits = tf.reshape(logits, [-1, 2**8])
            target_one_hot = tf.reshape(target_one_hot, [-1, 2**8])
            # compute loss
            loss = tf.losses.categorical_crossentropy(target_one_hot, logits,
                                                      from_logits=True)
            loss = tf.reduce_mean(loss)

        grads = tape.gradient(loss, net.trainable_variables)
        optimizer.apply_gradients(zip(grads, net.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'loss: ', float(loss))
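The helpers make_sine_waves, mu_law_encode, and one_hot are assumed to be defined elsewhere in the test module. A minimal sketch of what make_sine_waves might look like, purely hypothetical and only to make the data flow concrete:

# Hypothetical helper -- the real test module defines its own version.
def make_sine_waves(global_conditioning):
    """Generate a short sine wave (and dummy speaker ids) as toy training audio."""
    sample_period = 1.0 / 16000.0
    times = np.arange(0.0, 0.1, sample_period)        # 100 ms of audio at 16 kHz
    audio = 0.5 * np.sin(2 * np.pi * 440.0 * times)   # 440 Hz tone in [-0.5, 0.5]
    speaker_ids = None if global_conditioning is None else np.array([0])
    return audio, speaker_ids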
def __init__(self, args):
    self.args = args
    self.args.n_datasets = len(self.args.data)
    self.expPath = Path('checkpoints') / args.expName

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    self.logger = create_output_dir(args, self.expPath)
    self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]
    assert not args.distributed or len(self.data) == int(os.environ['WORLD_SIZE']), \
        "Number of datasets must match number of nodes"

    self.losses_recon = [LossMeter(f'recon {i}')
                         for i in range(self.args.n_datasets)]
    self.loss_d_right = LossMeter('d')
    self.loss_total = LossMeter('total')

    self.evals_recon = [LossMeter(f'recon {i}')
                        for i in range(self.args.n_datasets)]
    self.eval_d_right = LossMeter('eval d')
    self.eval_total = LossMeter('eval total')

    self.encoder = Encoder(args)
    self.decoder = WaveNet(args)
    self.discriminator = ZDiscriminator(args)

    if args.checkpoint:
        checkpoint_args_path = os.path.dirname(args.checkpoint) + '/args.pth'
        checkpoint_args = torch.load(checkpoint_args_path)

        self.start_epoch = checkpoint_args[-1] + 1
        states = torch.load(args.checkpoint)

        self.encoder.load_state_dict(states['encoder_state'])
        self.decoder.load_state_dict(states['decoder_state'])
        self.discriminator.load_state_dict(states['discriminator_state'])

        self.logger.info('Loaded checkpoint parameters')
    else:
        self.start_epoch = 0

    if args.distributed:
        self.encoder.cuda()
        self.encoder = torch.nn.parallel.DistributedDataParallel(self.encoder)
        self.discriminator.cuda()
        self.discriminator = torch.nn.parallel.DistributedDataParallel(
            self.discriminator)
        self.logger.info('Created DistributedDataParallel')
    else:
        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.discriminator = torch.nn.DataParallel(self.discriminator).cuda()
    self.decoder = torch.nn.DataParallel(self.decoder).cuda()

    self.model_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                            self.decoder.parameters()),
                                      lr=args.lr)
    self.d_optimizer = optim.Adam(self.discriminator.parameters(), lr=args.lr)

    if args.checkpoint and args.load_optimizer:
        self.model_optimizer.load_state_dict(states['model_optimizer_state'])
        self.d_optimizer.load_state_dict(states['d_optimizer_state'])

    self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
        self.model_optimizer, args.lr_decay)
    self.lr_manager.last_epoch = self.start_epoch
    self.lr_manager.step()
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    logdir_root = directories['logdir_root']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        reader = AudioReader(args.data_dir,
                             coord,
                             sample_rate=wavenet_params['sample_rate'],
                             sample_size=args.sample_size)
        audio_batch = reader.dequeue(args.batch_size)

    # Create network.
    net = WaveNet(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"])
    loss = net.loss(audio_batch)
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.train.SummaryWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.merge_all_summaries()

    # Set up session.
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.initialize_all_variables()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver()

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1
    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    try:
        last_saved_step = saved_global_step
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata, 'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))

            if step % 50 == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
def main(args):
    print('Starting')
    matplotlib.use('agg')
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    checkpoints = args.checkpoint.parent.glob(args.checkpoint.name + '_*.pth')
    checkpoints = [c for c in checkpoints if extract_id(c) in args.decoders]
    assert len(checkpoints) >= 1, "No checkpoints found."

    model_args = torch.load(args.model.parent / 'args.pth')[0]
    encoder = wavenet_models.Encoder(model_args)
    encoder.load_state_dict(torch.load(checkpoints[0])['encoder_state'])
    encoder.eval()
    encoder = encoder.cuda()

    decoders = []
    decoder_ids = []
    for checkpoint in checkpoints:
        decoder = WaveNet(model_args)
        decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])
        decoder.eval()
        decoder = decoder.cuda()

        if args.py:
            decoder = WavenetGenerator(decoder, args.batch_size, wav_freq=args.rate)
        else:
            decoder = NVWavenetGenerator(decoder, args.rate * (args.split_size // 20),
                                         args.batch_size, 3)

        decoders += [decoder]
        decoder_ids += [extract_id(checkpoint)]

    xs = []
    assert args.output_next_to_orig ^ (args.output is not None)

    if len(args.files) == 1 and args.files[0].is_dir():
        top = args.files[0]
        file_paths = list(top.glob('**/*.wav')) + list(top.glob('**/*.h5'))
    else:
        file_paths = args.files

    if not args.skip_filter:
        file_paths = [f for f in file_paths if not '_' in str(f.name)]

    for file_path in file_paths:
        if file_path.suffix == '.wav':
            data, rate = librosa.load(file_path, sr=16000)
            assert rate == 16000
            data = utils.mu_law(data)
        elif file_path.suffix == '.h5':
            data = utils.mu_law(h5py.File(file_path, 'r')['wav'][:] / (2 ** 15))
            if data.shape[-1] % args.rate != 0:
                data = data[:-(data.shape[-1] % args.rate)]
            assert data.shape[-1] % args.rate == 0
        else:
            raise Exception(f'Unsupported filetype {file_path}')

        if args.sample_len:
            data = data[:args.sample_len]
        else:
            args.sample_len = len(data)
        xs.append(torch.tensor(data).unsqueeze(0).float().cuda())

    xs = torch.stack(xs).contiguous()
    print(f'xs size: {xs.size()}')

    def save(x, decoder_ix, filepath):
        wav = utils.inv_mu_law(x.cpu().numpy())
        print(f'X size: {x.shape}')
        print(f'X min: {x.min()}, max: {x.max()}')

        if args.output_next_to_orig:
            save_audio(wav.squeeze(),
                       filepath.parent / f'{filepath.stem}_{decoder_ix}.wav',
                       rate=args.rate)
        else:
            save_audio(wav.squeeze(),
                       args.output / str(extract_id(args.model)) /
                       str(args.update) / filepath.with_suffix('.wav').name,
                       rate=args.rate)

    yy = {}
    with torch.no_grad():
        zz = []
        for xs_batch in torch.split(xs, args.batch_size):
            zz += [encoder(xs_batch)]
        zz = torch.cat(zz, dim=0)

        with utils.timeit("Generation timer"):
            for i, decoder_id in enumerate(decoder_ids):
                yy[decoder_id] = []
                decoder = decoders[i]
                for zz_batch in torch.split(zz, args.batch_size):
                    print(zz_batch.shape)
                    splits = torch.split(zz_batch, args.split_size, -1)
                    audio_data = []
                    decoder.reset()
                    for cond in tqdm.tqdm(splits):
                        audio_data += [decoder.generate(cond).cpu()]
                    audio_data = torch.cat(audio_data, -1)
                    yy[decoder_id] += [audio_data]
                yy[decoder_id] = torch.cat(yy[decoder_id], dim=0)
                del decoder

    for decoder_ix, decoder_result in yy.items():
        for sample_result, filepath in zip(decoder_result, file_paths):
            save(sample_result, decoder_ix, filepath)
with open(TEST_SCRIPT_FILE, 'r') as f:
    test_list = f.readlines()
    num_test = len(test_list)

# Enqueue jobs
for i in range(num_train):
    tasks.put(Task(train_list[i], DATA_SAVE_DIR, 'train', i))
for i in range(num_test):
    tasks.put(Task(test_list[i], DATA_SAVE_DIR, 'test', i))

# Add a poison pill for each consumer
for i in range(num_consumers):
    tasks.put(None)

wvn = WaveNet(input_dim=256 + 406 + 2,
              dilations=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
              filter_width=2)
wvn.build()
wvn.compile()
wvn.plot()
wvn.add_callbacks(os.path.join(CKPT_PATH, 'weights.epoch001.{epoch:02d}.hdf5'), None)

# Start 1st epoch training
num_jobs = num_train
train_files = []
train_times = []
while num_jobs:
    f = results_tr.get()
    train_files.append(f)

    # Model training
    start = T()
    wvn.fit_on_file(f)
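The Task objects and the tasks/results_tr queues above come from a producer/consumer preprocessing setup that is not shown in this fragment. A minimal sketch of what such a setup might look like, following the standard multiprocessing pattern; the Consumer class and queue creation here are hypothetical:

import multiprocessing

# Hypothetical worker that pulls Task objects until it sees a poison pill (None).
class Consumer(multiprocessing.Process):
    def __init__(self, task_queue, result_queue):
        super().__init__()
        self.task_queue = task_queue
        self.result_queue = result_queue

    def run(self):
        while True:
            task = self.task_queue.get()
            if task is None:           # poison pill: no more work
                self.task_queue.task_done()
                break
            result = task()            # e.g. preprocess one file and return its path
            self.task_queue.task_done()
            self.result_queue.put(result)

tasks = multiprocessing.JoinableQueue()
results_tr = multiprocessing.Queue()
num_consumers = multiprocessing.cpu_count()
consumers = [Consumer(tasks, results_tr) for _ in range(num_consumers)]
for c in consumers:
    c.start()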
def train(model_directory, epochs, learning_rate, epochs_per_checkpoint,
          batch_size, seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    criterion = CrossEntropyLoss()
    model = WaveNet(**wavenet_config).cuda()
    # model.upsample = torch.nn.Sequential()  # replace the upsample step with a
    #                                         # no-op as we manually control samples
    # model.upsample.weight = None
    # model.upsample.bias = None
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Load checkpoint if one exists
    iteration = 0
    checkpoint_path = find_checkpoint(model_directory)
    if checkpoint_path is not None:
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        iteration += 1  # next iteration is iteration + 1

    trainset = SimpleWaveLoader()
    train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
                              sampler=None, batch_size=batch_size,
                              pin_memory=False, drop_last=True)

    model.train()
    epoch_offset = max(0, int(iteration / len(train_loader)))

    # ================ MAIN TRAINING LOOP ===================
    for epoch in range(epoch_offset, epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            model.zero_grad()

            x, y = batch
            x = to_gpu(x).float()
            y = to_gpu(y)
            x = (x, y)  # auto-regressive takes outputs as inputs
            y_pred = model(x)
            loss = criterion(y_pred, y)
            reduced_loss = loss.data.item()
            loss.backward()
            optimizer.step()

            # print out the loss, and append it to a history file
            print("{}:\t{:.9f}".format(iteration, reduced_loss))
            with open(os.path.join(model_directory, 'loss_history.txt'), 'a') as f:
                f.write('%s\n' % str(reduced_loss))

            iteration += 1
            torch.cuda.empty_cache()

        if epoch != 0 and epoch % epochs_per_checkpoint == 0:
            checkpoint_path = os.path.join(model_directory,
                                           'checkpoint_%d' % iteration)
            save_checkpoint(model, optimizer, learning_rate, iteration,
                            checkpoint_path)
def test_assert_fast_generation():
    # get batch
    batch = 2
    x = np.random.randint(0, 256, size=(batch, 1))
    h = np.random.randn(batch, 28, 32)
    length = h.shape[-1] - 1

    with torch.no_grad():
        # --------------------------------------------------------
        # define model without upsampling and with kernel size = 2
        # --------------------------------------------------------
        net = WaveNet(256, 28, 4, 4, 10, 3, 2)
        net.apply(initialize)
        net.eval()

        # sample-by-sample generation
        gen1_list = []
        gen2_list = []
        for x_, h_ in zip(x, h):
            batch_x = torch.from_numpy(np.expand_dims(x_, 0)).long()
            batch_h = torch.from_numpy(np.expand_dims(h_, 0)).float()
            gen1 = net.generate(batch_x, batch_h, length, 1, "argmax")
            gen2 = net.fast_generate(batch_x, batch_h, length, 1, "argmax")
            np.testing.assert_array_equal(gen1, gen2)
            gen1_list += [gen1]
            gen2_list += [gen2]
        gen1 = np.stack(gen1_list)
        gen2 = np.stack(gen2_list)
        np.testing.assert_array_equal(gen1, gen2)

        # batch generation
        batch_x = torch.from_numpy(x).long()
        batch_h = torch.from_numpy(h).float()
        gen3_list = net.batch_fast_generate(batch_x, batch_h, [length] * batch,
                                            1, "argmax")
        gen3 = np.stack(gen3_list)
        np.testing.assert_array_equal(gen3, gen2)

        # --------------------------------------------------------
        # define model without upsampling and with kernel size = 3
        # --------------------------------------------------------
        net = WaveNet(256, 28, 4, 4, 10, 3, 3)
        net.apply(initialize)
        net.eval()

        # sample-by-sample generation
        gen1_list = []
        gen2_list = []
        for x_, h_ in zip(x, h):
            batch_x = torch.from_numpy(np.expand_dims(x_, 0)).long()
            batch_h = torch.from_numpy(np.expand_dims(h_, 0)).float()
            gen1 = net.generate(batch_x, batch_h, length, 1, "argmax")
            gen2 = net.fast_generate(batch_x, batch_h, length, 1, "argmax")
            np.testing.assert_array_equal(gen1, gen2)
            gen1_list += [gen1]
            gen2_list += [gen2]
        gen1 = np.stack(gen1_list)
        gen2 = np.stack(gen2_list)
        np.testing.assert_array_equal(gen1, gen2)

        # batch generation
        batch_x = torch.from_numpy(x).long()
        batch_h = torch.from_numpy(h).float()
        gen3_list = net.batch_fast_generate(batch_x, batch_h, [length] * batch,
                                            1, "argmax")
        gen3 = np.stack(gen3_list)
        np.testing.assert_array_equal(gen3, gen2)

        # get batch
        batch = 2
        upsampling_factor = 10
        x = np.random.randint(0, 256, size=(batch, 1))
        h = np.random.randn(batch, 28, 3)
        length = h.shape[-1] * upsampling_factor - 1

        # -----------------------------------------------------
        # define model with upsampling and with kernel size = 2
        # -----------------------------------------------------
        net = WaveNet(256, 28, 4, 4, 10, 3, 2, upsampling_factor)
        net.apply(initialize)
        net.eval()

        # sample-by-sample generation
        gen1_list = []
        gen2_list = []
        for x_, h_ in zip(x, h):
            batch_x = torch.from_numpy(np.expand_dims(x_, 0)).long()
            batch_h = torch.from_numpy(np.expand_dims(h_, 0)).float()
            gen1 = net.generate(batch_x, batch_h, length, 1, "argmax")
            gen2 = net.fast_generate(batch_x, batch_h, length, 1, "argmax")
            np.testing.assert_array_equal(gen1, gen2)
            gen1_list += [gen1]
            gen2_list += [gen2]
        gen1 = np.stack(gen1_list)
        gen2 = np.stack(gen2_list)
        np.testing.assert_array_equal(gen1, gen2)

        # batch generation
        batch_x = torch.from_numpy(x).long()
        batch_h = torch.from_numpy(h).float()
        gen3_list = net.batch_fast_generate(batch_x, batch_h, [length] * batch,
                                            1, "argmax")
        gen3 = np.stack(gen3_list)
        np.testing.assert_array_equal(gen3, gen2)

        # -----------------------------------------------------
        # define model with upsampling and with kernel size = 3
        # -----------------------------------------------------
        net = WaveNet(256, 28, 4, 4, 10, 3, 3, upsampling_factor)
        net.apply(initialize)
        net.eval()

        # sample-by-sample generation
        gen1_list = []
        gen2_list = []
        for x_, h_ in zip(x, h):
            batch_x = torch.from_numpy(np.expand_dims(x_, 0)).long()
            batch_h = torch.from_numpy(np.expand_dims(h_, 0)).float()
            gen1 = net.generate(batch_x, batch_h, length, 1, "argmax")
            gen2 = net.fast_generate(batch_x, batch_h, length, 1, "argmax")
            np.testing.assert_array_equal(gen1, gen2)
            gen1_list += [gen1]
            gen2_list += [gen2]
        gen1 = np.stack(gen1_list)
        gen2 = np.stack(gen2_list)
        np.testing.assert_array_equal(gen1, gen2)

        # batch generation
        batch_x = torch.from_numpy(x).long()
        batch_h = torch.from_numpy(h).float()
        gen3_list = net.batch_fast_generate(batch_x, batch_h, [length] * batch,
                                            1, "argmax")
        gen3 = np.stack(gen3_list)
        np.testing.assert_array_equal(gen3, gen2)
def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          iters_per_checkpoint, batch_size, seed, checkpoint_path):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END: ADDED FOR DISTRIBUTED======

    criterion = CrossEntropyLoss()
    model = WaveNet(**wavenet_config).cuda()

    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END: ADDED FOR DISTRIBUTED======

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        iteration += 1  # next iteration is iteration + 1

    # trainset = Mel2SampOnehot(**data_config)
    trainset = DeepMels(**data_config)
    # =====START: ADDED FOR DISTRIBUTED======
    train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
    # =====END: ADDED FOR DISTRIBUTED======
    train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
                              sampler=train_sampler, batch_size=batch_size,
                              pin_memory=False, drop_last=True)

    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    model.train()
    epoch_offset = max(0, int(iteration / len(train_loader)))

    # ================ MAIN TRAINING LOOP ===================
    for epoch in range(epoch_offset, epochs):
        total_loss = 0
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            model.zero_grad()

            x, y = batch
            x = to_gpu(x).float()
            y = to_gpu(y)
            x = (x, y)  # auto-regressive takes outputs as inputs
            y_pred = model(x)
            loss = criterion(y_pred, y)
            if num_gpus > 1:
                reduced_loss = reduce_tensor(loss.data, num_gpus)[0]
            else:
                reduced_loss = loss.data[0]
            loss.backward()
            optimizer.step()

            total_loss += reduced_loss

            if iteration % iters_per_checkpoint == 0:
                if rank == 0:
                    checkpoint_path = "{}/wavenet_{}".format(output_directory,
                                                             iteration)
                    save_checkpoint(model, optimizer, learning_rate, iteration,
                                    checkpoint_path)

            iteration += 1
        print("epoch:{}, total epoch loss:{}".format(epoch, total_loss))
filename = args.params_dir + "/{}".format(args.params_filename)

if os.path.isfile(filename):
    f = open(filename)
    try:
        param_dict = json.load(f)
        params = Params(param_dict)
    except:
        raise Exception("could not load {}".format(filename))

    params.gpu_enabled = True if args.gpu_enabled == 1 else False

    if args.use_faster_wavenet:
        wavenet = FasterWaveNet(params)
    else:
        wavenet = WaveNet(params)
else:
    params = Params()
    params.audio_channels = 256

    params.causal_conv_no_bias = True
    params.causal_conv_kernel_width = 2
    params.causal_conv_channels = [128]

    params.residual_conv_dilation_no_bias = True
    params.residual_conv_projection_no_bias = True
    params.residual_conv_kernel_width = 2
    params.residual_conv_channels = [32, 32, 32, 32, 32, 32, 32, 32, 32]
    params.residual_num_blocks = 5

    params.softmax_conv_no_bias = True
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    logdir_root = directories['logdir_root']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_new_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # create coordinator
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        custom_runner = CustomRunner(args, wavenet_params, coord)
        audio_batch, _ = custom_runner.get_inputs()

    # Create network.
    net = WaveNet(args.batch_size,
                  wavenet_params["quantization_steps"],
                  wavenet_params["dilations"],
                  wavenet_params["filter_width"],
                  wavenet_params["residual_channels"],
                  wavenet_params["dilation_channels"])
    loss = net.loss(audio_batch)
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.train.SummaryWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.merge_all_summaries()

    # Set up session.
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.initialize_all_variables()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver()

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_new_training or saved_global_step is None:
            # For "new" training with a pre-trained model we should ignore
            # saved_global_step: the training step starts from
            # saved_global_step + 1, so put -1 here when new training starts.
            saved_global_step = -1
    except:
        print("Something went wrong while restoring the checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    custom_runner.start_threads(sess)

    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata, 'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step %d - loss = %.3f, (%.3f sec/step)' %
                  (step, loss_value, duration))

            if step % 50 == 0:
                save(saver, sess, logdir, step)
    finally:
        coord.request_stop()
        coord.join(threads)
def main():
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms", required=True, type=str,
                        help="directory or list of wav files")
    parser.add_argument("--feats", required=True, type=str,
                        help="directory or list of aux feat files")
    parser.add_argument("--stats", required=True, type=str,
                        help="hdf5 file including statistics")
    parser.add_argument("--expdir", required=True, type=str,
                        help="directory to save the model")
    # network structure setting
    parser.add_argument("--n_quantize", default=256, type=int,
                        help="number of quantization")
    parser.add_argument("--n_aux", default=28, type=int,
                        help="number of dimension of aux feats")
    parser.add_argument("--n_resch", default=512, type=int,
                        help="number of channels of residual output")
    parser.add_argument("--n_skipch", default=256, type=int,
                        help="number of channels of skip output")
    parser.add_argument("--dilation_depth", default=10, type=int,
                        help="depth of dilation")
    parser.add_argument("--dilation_repeat", default=1, type=int,
                        help="number of repeating of dilation")
    parser.add_argument("--kernel_size", default=2, type=int,
                        help="kernel size of dilated causal convolution")
    parser.add_argument("--upsampling_factor", default=0, type=int,
                        help="upsampling factor of aux features "
                             "(if set 0, do not apply)")
    parser.add_argument("--use_speaker_code", default=False, type=strtobool,
                        help="flag to use speaker code")
    # network training setting
    parser.add_argument("--lr", default=1e-4, type=float,
                        help="learning rate")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="weight decay coefficient")
    parser.add_argument("--batch_size", default=20000, type=int,
                        help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--iters", default=200000, type=int,
                        help="number of iterations")
    # other setting
    parser.add_argument("--checkpoints", default=10000, type=int,
                        help="how frequently to save the model")
    parser.add_argument("--intervals", default=100, type=int,
                        help="log interval")
    parser.add_argument("--seed", default=1, type=int,
                        help="seed number")
    parser.add_argument("--resume", default=None, type=str,
                        help="model path to restart training")
    parser.add_argument("--verbose", default=1, type=int,
                        help="log level")
    args = parser.parse_args()

    # make experimental directory
    if not os.path.exists(args.expdir):
        os.makedirs(args.expdir)

    # set log level
    if args.verbose == 1:
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    elif args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    else:
        logging.basicConfig(
            level=logging.WARN,
            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
        logging.warn("logging is disabled.")

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # save args as conf
    torch.save(args, args.expdir + "/model.conf")

    # define network
    model = WaveNet(n_quantize=args.n_quantize,
                    n_aux=args.n_aux,
                    n_resch=args.n_resch,
                    n_skipch=args.n_skipch,
                    dilation_depth=args.dilation_depth,
                    dilation_repeat=args.dilation_repeat,
                    kernel_size=args.kernel_size,
                    upsampling_factor=args.upsampling_factor)
    logging.info(model)
    model.apply(initialize)
    model.train()

    # define loss and optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss()

    # define transforms
    scaler = StandardScaler()
    scaler.mean_ = read_hdf5(args.stats, "/mean")
    scaler.scale_ = read_hdf5(args.stats, "/scale")
    wav_transform = transforms.Compose(
        [lambda x: encode_mu_law(x, args.n_quantize)])
    feat_transform = transforms.Compose([lambda x: scaler.transform(x)])

    # define generator
    if os.path.isdir(args.waveforms):
        filenames = sorted(find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
        feat_list = [args.feats + "/" + filename.replace(".wav", ".h5")
                     for filename in filenames]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
        feat_list = read_txt(args.feats)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(wav_list))
    generator = train_generator(wav_list,
                                feat_list,
                                receptive_field=model.receptive_field,
                                batch_size=args.batch_size,
                                wav_transform=wav_transform,
                                feat_transform=feat_transform,
                                shuffle=True,
                                upsampling_factor=args.upsampling_factor,
                                use_speaker_code=args.use_speaker_code)
    while not generator.queue.full():
        time.sleep(0.1)

    # resume training from an existing checkpoint
    if args.resume is not None:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        iterations = checkpoint["iterations"]
        logging.info("restored from %d-iter checkpoint." % iterations)
    else:
        iterations = 0

    # send to gpu
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    # train
    loss = 0
    total = 0
    for i in six.moves.range(iterations, args.iters):
        start = time.time()
        (batch_x, batch_h), batch_t = generator.next()
        batch_output = model(batch_x, batch_h)[0]
        batch_loss = criterion(batch_output[model.receptive_field:],
                               batch_t[model.receptive_field:])
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        loss += batch_loss.data[0]
        total += time.time() - start
        logging.debug("batch loss = %.3f (%.3f sec / batch)" %
                      (batch_loss.data[0], time.time() - start))

        # report progress
        if (i + 1) % args.intervals == 0:
            logging.info("(iter:%d) average loss = %.6f (%.3f sec / batch)" %
                         (i + 1, loss / args.intervals, total / args.intervals))
            logging.info("estimated required time = "
                         "{0.days:02}:{0.hours:02}:{0.minutes:02}:{0.seconds:02}".format(
                             relativedelta(seconds=int(
                                 (args.iters - (i + 1)) * (total / args.intervals)))))
            loss = 0
            total = 0

        # save intermediate model
        if (i + 1) % args.checkpoints == 0:
            save_checkpoint(args.expdir, model, optimizer, i + 1)

    # save final model
    model.cpu()
    torch.save({"model": model.state_dict()},
               args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")