Example #1
def main():
    args = get_arguments()
    logdir = os.path.join(args.logdir, 'train', str(datetime.now()))
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()

    net = WaveNet(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        use_biases=wavenet_params['use_biases'])

    samples = tf.placeholder(tf.int32)

    next_sample = net.predict_proba(samples)

    saver = tf.train.Saver()
    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = net.decode(samples)

    quantization_steps = wavenet_params['quantization_steps']
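    # Seed generation with a single random quantized sample; each step feeds
    # the most recent window of samples back through the network.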
    waveform = np.random.randint(quantization_steps, size=(1, )).tolist()
    for step in range(args.samples):
        if len(waveform) > args.window:
            window = waveform[-args.window:]
        else:
            window = waveform
        prediction = sess.run(next_sample, feed_dict={samples: window})
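        # Draw the next sample from the predicted distribution.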
        sample = np.random.choice(np.arange(quantization_steps), p=prediction)
        waveform.append(sample)
        print('Sample {:<3d}/{:<3d}: {}'.format(step + 1, args.samples,
                                                sample))
        if (args.wav_out_path and args.save_every
                and (step + 1) % args.save_every == 0):

            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    datestring = str(datetime.now()).replace(' ', 'T')
    writer = tf.train.SummaryWriter(
        os.path.join(logdir, 'generation', datestring))
    tf.audio_summary('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.merge_all_summaries()

    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)

    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
Example #2
    def __init__(self, args):
        self.args = args
        self.args.n_datasets = len(args.data)
        self.modelPath = Path('checkpoints') / args.expName

        self.logger = create_output_dir(args, self.modelPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]

        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.loss_total = LossMeter('total')

        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_total = LossMeter('eval total')

        self.start_epoch = 0

        #torch.manual_seed(args.seed)
        #torch.cuda.manual_seed(args.seed)

        # Get the pretrained model checkpoint that matches the requested decoder.
        checkpoint = args.checkpoint.parent.glob(args.checkpoint.name +
                                                 '_*.pth')
        checkpoint = [c for c in checkpoint
                      if extract_id(c) in args.decoder][0]

        model_args = torch.load(args.checkpoint.parent / 'args.pth')[0]

        self.encoder = Encoder(model_args)
        self.encoder.load_state_dict(torch.load(checkpoint)['encoder_state'])

        # Freeze the encoder.
        for param in self.encoder.parameters():
            param.requires_grad = False
            #self.logger.debug(f'encoder at start: {param}')

        self.decoder = WaveNet(model_args)
        self.decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])

        # Freeze all decoder layers except the last args.decoder_update ones.
        for param in self.decoder.layers[:-args.decoder_update].parameters():
            param.requires_grad = False
            #self.logger.debug(f'decoder at start: {param}')

        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()
        self.model_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                                self.decoder.parameters()),
                                          lr=args.lr)

        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        self.lr_manager.step()
Example #3
    def __init__(self, domains, domain_cnn):
        """
        :param domains: int specifying the number of domains
        :param domain_cnn: a domain-confusion CNN
        """
        super(MusicAutoEncoder, self).__init__()
        self.domains = domains
        self.encoder = WaveNet(**encoder_config).cuda()
        # nn.ModuleList registers the per-domain decoders as submodules;
        # a plain list would hide them from .parameters() and .cuda().
        self.decoders = nn.ModuleList(
            [WaveNet(**decoder_config).cuda() for _ in range(domains)])
        self.domain_cnn = domain_cnn
Example #4
def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    print("Saving model and optimizer state at iteration {} to {}".format(
          iteration, filepath))
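    # Copying the weights into a fresh instance keeps the saved module
    # standalone (e.g. free of any DataParallel wrapper).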
    model_for_saving = WaveNet(**wavenet_config).cuda()
    model_for_saving.load_state_dict(model.state_dict())
    torch.save({'model': model_for_saving,
                'iteration': iteration,
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate}, filepath)
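
For reference, a minimal load_checkpoint sketch that mirrors the dictionary saved above; the signature follows the call in Example #26, and reading the stored model object's weights back into the live model is an assumption about how the checkpoint is consumed:

import torch

def load_checkpoint(filepath, model, optimizer):
    # Read the dictionary written by save_checkpoint above.
    checkpoint = torch.load(filepath, map_location='cpu')
    # The checkpoint stores a full WaveNet module under 'model'.
    model.load_state_dict(checkpoint['model'].state_dict())
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['iteration']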
Example #5
    def setUp(self):
        self.net = WaveNet(batch_size=1,
                           dilations=[1, 2, 4, 8, 16, 32, 64, 128, 256,
                                      1, 2, 4, 8, 16, 32, 64, 128, 256],
                           filter_width=2,
                           residual_channels=16,
                           dilation_channels=16,
                           quantization_channels=256,
                           skip_channels=32)
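
As a quick sanity check, the context length implied by a dilation stack like the one above can be computed by hand. A minimal sketch using one common convention (each dilated layer contributes (filter_width - 1) * dilation samples, plus the initial causal layer); the project's own calculate_receptive_field may differ in off-by-one details:

def receptive_field(filter_width, dilations):
    # Sum the context added by every dilated layer, plus the initial layer.
    return (filter_width - 1) * sum(dilations) + filter_width

# Two stacks of 1..256 with filter_width=2 give 1024 samples of context.
print(receptive_field(2, [1, 2, 4, 8, 16, 32, 64, 128, 256] * 2))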
Example #6
    def gpu_decode(feat_list, gpu):
        # set default gpu and do not track gradient
        torch.cuda.set_device(gpu)
        torch.set_grad_enabled(False)

        # define model and load parameters
        if config.use_upsampling_layer:
            upsampling_factor = config.upsampling_factor
        else:
            upsampling_factor = 0
        model = WaveNet(n_quantize=config.n_quantize,
                        n_aux=config.n_aux,
                        n_resch=config.n_resch,
                        n_skipch=config.n_skipch,
                        dilation_depth=config.dilation_depth,
                        dilation_repeat=config.dilation_repeat,
                        kernel_size=config.kernel_size,
                        upsampling_factor=upsampling_factor)
        model.load_state_dict(
            torch.load(args.checkpoint,
                       map_location=lambda storage, loc: storage)["model"])
        model.eval()
        model.cuda()

        # define generator
        generator = decode_generator(
            feat_list,
            batch_size=args.batch_size,
            feature_type=config.feature_type,
            wav_transform=wav_transform,
            feat_transform=feat_transform,
            upsampling_factor=config.upsampling_factor,
            use_upsampling_layer=config.use_upsampling_layer,
            use_speaker_code=config.use_speaker_code)

        # decode
        if args.batch_size > 1:
            for feat_ids, (batch_x, batch_h, n_samples_list) in generator:
                logging.info("decoding start")
                samples_list = model.batch_fast_generate(
                    batch_x, batch_h, n_samples_list, args.intervals)
                for feat_id, samples in zip(feat_ids, samples_list):
                    wav = decode_mu_law(samples, config.n_quantize)
                    sf.write(args.outdir + "/" + feat_id + ".wav", wav,
                             args.fs, "PCM_16")
                    logging.info("wrote %s.wav in %s." %
                                 (feat_id, args.outdir))
        else:
            for feat_id, (x, h, n_samples) in generator:
                logging.info("decoding %s (length = %d)" %
                             (feat_id, n_samples))
                samples = model.fast_generate(x, h, n_samples, args.intervals)
                wav = decode_mu_law(samples, config.n_quantize)
                sf.write(args.outdir + "/" + feat_id + ".wav", wav, args.fs,
                         "PCM_16")
                logging.info("wrote %s.wav in %s." % (feat_id, args.outdir))
Example #7
    def __init__(self, args):
        self.args = args
        self.args.n_datasets = len(self.args.data)
        self.expPath = Path('checkpoints') / args.expName

        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        self.logger = create_output_dir(args, self.expPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]

        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.loss_total = LossMeter('total')

        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_total = LossMeter('eval total')

        self.encoder = Encoder(args)
        self.decoder = WaveNet(args)

        assert args.checkpoint, 'you MUST pass a checkpoint for the encoder'

        if args.continue_training:
            checkpoint_args_path = os.path.dirname(
                args.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)

            self.start_epoch = checkpoint_args[-1] + 1
        else:
            self.start_epoch = 0

        states = torch.load(args.checkpoint)
        self.encoder.load_state_dict(states['encoder_state'])
        if args.continue_training:
            self.decoder.load_state_dict(states['decoder_state'])
        self.logger.info('Loaded checkpoint parameters')

        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()

        self.model_optimizer = optim.Adam(self.decoder.parameters(),
                                          lr=args.lr)

        if args.continue_training:
            self.model_optimizer.load_state_dict(
                states['model_optimizer_state'])

        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
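        # Resume the exponential-decay schedule from the starting epoch.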
        self.lr_manager.last_epoch = self.start_epoch
        self.lr_manager.step()
Example #8
    def setUp(self):
        quantization_steps = 256
        self.net = WaveNet(batch_size=1,
                           channels=quantization_steps,
                           dilations=[
                               1, 2, 4, 8, 16, 32, 64, 128, 256, 1, 2, 4, 8,
                               16, 32, 64, 128, 256
                           ],
                           filter_width=2,
                           residual_channels=16,
                           dilation_channels=16,
                           use_biases=True)
Example #9
def main():
    print('initial training...')
    print(
        f'work_dir:{cfg.workdir}, pretrained:{cfg.load_from}, batch_size:{cfg.batch_size}, lr:{cfg.lr}, epochs:{cfg.epochs}'
    )
    args = parse_args()
    writer = SummaryWriter(log_dir=cfg.workdir + '/runs')

    # distributed training setting
    assert cfg.distributed
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group('nccl', init_method='env://')

    # build dataloader
    vctk_train = VCTK(cfg, 'train')
    train_sample = torch.utils.data.distributed.DistributedSampler(
        vctk_train,
        shuffle=True,
    )
    # train_loader = DataLoader(vctk_train,batch_size=cfg.batch_size, num_workers=8, shuffle=False, pin_memory=True)
    train_loader = DataLoader(vctk_train,
                              batch_size=cfg.batch_size,
                              sampler=train_sample,
                              num_workers=8,
                              pin_memory=True)

    vctk_val = VCTK(cfg, 'val')
    val_sample = torch.utils.data.distributed.DistributedSampler(
        vctk_val,
        shuffle=False,
    )
    # val_loader = DataLoader(vctk_val, batch_size=cfg.batch_size, num_workers=8, shuffle=False, pin_memory=True)
    val_loader = DataLoader(vctk_val,
                            batch_size=cfg.batch_size,
                            sampler=val_sample,
                            num_workers=8,
                            pin_memory=True)

    # build model
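    # num_classes=28 is consistent with 26 letters + space + a CTC blank
    # (Example #12 passes blank=27 to nn.CTCLoss).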
    model = WaveNet(num_classes=28, channels_in=20).cuda()
    model = DDP(model, device_ids=[args.local_rank], broadcast_buffers=False)
    # model = nn.DataParallel(model)

    # build loss
    loss_fn = nn.CTCLoss()

    # build optimizer
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)
    # scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 150, 250], gamma=0.5)

    # train
    train(args, train_loader, optimizer, model, loss_fn, val_loader, writer)
Example #10
    def gpu_decode(feat_list, gpu):
        with torch.cuda.device(gpu):
            # define model and load parameters
            model = WaveNet(n_quantize=config.n_quantize,
                            n_aux=config.n_aux,
                            n_resch=config.n_resch,
                            n_skipch=config.n_skipch,
                            dilation_depth=config.dilation_depth,
                            dilation_repeat=config.dilation_repeat,
                            kernel_size=config.kernel_size,
                            upsampling_factor=config.upsampling_factor)
            model.load_state_dict(
                torch.load(args.checkpoint,
                           map_location=lambda storage, loc: storage.cuda(gpu))
                ["model"])
            model.eval()
            model.cuda()
            torch.backends.cudnn.benchmark = True

            # define generator
            generator = decode_generator(
                feat_list,
                batch_size=args.batch_size,
                wav_transform=wav_transform,
                feat_transform=feat_transform,
                use_speaker_code=config.use_speaker_code,
                upsampling_factor=config.upsampling_factor)

            # decode
            if args.batch_size > 1:
                for feat_ids, (batch_x, batch_h, n_samples_list) in generator:
                    logging.info("decoding start")
                    samples_list = model.batch_fast_generate(
                        batch_x, batch_h, n_samples_list, args.intervals)
                    for feat_id, samples in zip(feat_ids, samples_list):
                        wav = decode_mu_law(samples, config.n_quantize)
                        sf.write(args.outdir + "/" + feat_id + ".wav", wav,
                                 args.fs, "PCM_16")
                        logging.info("wrote %s.wav in %s." %
                                     (feat_id, args.outdir))
            else:
                for feat_id, (x, h, n_samples) in generator:
                    logging.info("decoding %s (length = %d)" %
                                 (feat_id, n_samples))
                    samples = model.fast_generate(x, h, n_samples,
                                                  args.intervals)
                    wav = decode_mu_law(samples, config.n_quantize)
                    sf.write(args.outdir + "/" + feat_id + ".wav", wav,
                             args.fs, "PCM_16")
                    logging.info("wrote %s.wav in %s." %
                                 (feat_id, args.outdir))
Example #11
def test_assert_different_length_batch_generation():
    # prepare batch
    batch = 4
    length = 32
    x = np.random.randint(0, 256, size=(batch, 1))
    h = np.random.randn(batch, 28, length)
    length_list = sorted(
        list(np.random.randint(length // 2, length - 1, batch)))

    with torch.no_grad():
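        # The positional WaveNet args appear to follow Example #6's keyword
        # order (n_quantize, n_aux, n_resch, n_skipch, dilation_depth,
        # dilation_repeat, kernel_size); this is an inference, not confirmed.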
        net = WaveNet(256, 28, 4, 4, 10, 3, 2)
        net.apply(initialize)
        net.eval()

        # sample-by-sample generation
        gen1_list = []
        for x_, h_, length in zip(x, h, length_list):
            batch_x = torch.from_numpy(np.expand_dims(x_, 0)).long()
            batch_h = torch.from_numpy(np.expand_dims(h_, 0)).float()
            gen1 = net.fast_generate(batch_x, batch_h, length, 1, "argmax")
            gen1_list += [gen1]

        # batch generation
        batch_x = torch.from_numpy(x).long()
        batch_h = torch.from_numpy(h).float()
        gen2_list = net.batch_fast_generate(batch_x, batch_h, length_list, 1,
                                            "argmax")

        # assertion
        for gen1, gen2 in zip(gen1_list, gen2_list):
            np.testing.assert_array_equal(gen1, gen2)
Example #12
def main():
    args = parse_args()
    cfg.resume = args.resume
    cfg.exp_name = args.exp
    cfg.work_root = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/'
    cfg.workdir = cfg.work_root + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size = args.batch_size
    cfg.lr = args.lr
    cfg.load_from = args.load_from
    cfg.save_excel = args.save_excel

    weights_dir = os.path.join(cfg.workdir, 'weights')
    check_and_mkdir(weights_dir)

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n\
            pretrained: {cfg.load_from},  \n\
            batch_size: {cfg.batch_size}, \n\
            lr        : {cfg.lr},         \n\
            epochs    : {cfg.epochs},     \n\
            sparse    : {cfg.sparse_mode}')
    writer = SummaryWriter(log_dir=cfg.workdir + '/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train,
                              batch_size=cfg.batch_size,
                              num_workers=4,
                              shuffle=True,
                              pin_memory=True)
    vctk_val = VCTK(cfg, 'val')
    val_loader = DataLoader(vctk_val,
                            batch_size=cfg.batch_size,
                            num_workers=4,
                            shuffle=False,
                            pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=40, dilations=[1, 2, 4, 8, 16])
    model = nn.DataParallel(model)
    model.cuda()
    model.train()

    # build loss
    loss_fn = nn.CTCLoss(blank=27)

    if cfg.resume and os.path.exists(cfg.workdir + '/weights/best.pth'):
        model.load_state_dict(torch.load(cfg.workdir + '/weights/best.pth'),
                              strict=True)
        print("loading", cfg.workdir + '/weights/best.pth')
        cfg.load_from = cfg.workdir + '/weights/best.pth'

    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)
    train(train_loader, optimizer, model, loss_fn, val_loader, writer)
Example #13
def custom_model_fn(features, labels, mode, params):
    """Model function for custom WaveNetEsimator"""
    model = WaveNet(**params)
    if mode == tf.estimator.ModeKeys.PREDICT:
        logits = model((features['mel'], labels), training=False)
        predictions = {
            'logits': logits
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={'upsampled': export_output.PredictOutput(predictions)}
            )

    logits = model((features, labels), training=True)
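    # Transpose so the logits line up with the one-hot labels
    # (last axis becomes the 256 classes).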
    logits = tf.transpose(logits, [0, 2, 1])
    labels = tf.one_hot(tf.cast(labels, dtype=tf.int32), 256)

    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    # eval_metric_ops expects (value, update_op) metric pairs, not raw tensors.
    metrics = {'loss': tf.metrics.mean(loss)}
    tf.summary.scalar('loss', loss)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops=metrics
        )
    assert mode == tf.estimator.ModeKeys.TRAIN
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
Example #14
    def __init__(self,
                 hparams,
                 loss_fn=F.cross_entropy,
                 log_grads: bool = False,
                 use_sentence_split: bool = True):
        super().__init__()
        # Configuration flags
        self.use_sentence_split = use_sentence_split
        self.log_grads = log_grads
        # Dataset
        self.batch_size = hparams.batch_size
        self.output_length = hparams.out_len
        self.win_len = hparams.win_len
        self._setup_dataloaders()
        # Training
        self.loss_fn = loss_fn
        self.lr = hparams.lr
        # Embedding
        self.embedding_dim = hparams.emb_dim
        self.embedding = nn.Embedding(self.num_classes, self.embedding_dim)
        # Metrics
        self.metrics = MetricsCalculator(
            ["accuracy", "precision", "recall", "f1"])
        # Model
        self.model = WaveNet(num_blocks=hparams.num_blocks,
                             num_layers=hparams.num_layers,
                             num_classes=self.num_classes,
                             output_len=self.output_length,
                             ch_start=self.embedding_dim,
                             ch_residual=hparams.ch_residual,
                             ch_dilation=hparams.ch_dilation,
                             ch_skip=hparams.ch_skip,
                             ch_end=hparams.ch_end,
                             kernel_size=hparams.kernel_size,
                             bias=True)
Example #15
def save_checkpoint(model, optimizer, scheduler, learning_rate, iteration,
                    output_directory, ema, wavenet_config):
    checkpoint_path = "{}/wavenet_{}".format(output_directory, iteration)
    print("Saving model and optimizer state at iteration {} to {}".format(
        iteration, checkpoint_path))
    model_for_saving = WaveNet(**wavenet_config).cuda()
    model_for_saving.load_state_dict(model.state_dict())
    torch.save(
        {
            'model': model_for_saving,
            'iteration': iteration,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'learning_rate': learning_rate
        }, checkpoint_path)
    ema_path = "{}/wavenet_ema_{}".format(output_directory, iteration)
    print("Saving ema model at iteration {} to {}".format(iteration, ema_path))

    state_dict = model_for_saving.state_dict()
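    # Overwrite parameters with their exponential-moving-average shadows.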
    for name, _ in model.named_parameters():
        if name in ema.shadow:
            state_dict[name] = ema.shadow[name]
    model_for_saving.load_state_dict(state_dict)
    torch.save(
        {
            'model': model_for_saving,
            'iteration': iteration,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'learning_rate': learning_rate
        }, ema_path)
Example #16
    def __init__(self, args):
        #TODO
        self.args = args
        #self.data = [Dataset(args, domain_path) for domain_path in args.data]
        self.expPath = args.checkpoint / 'MusicStar' / args.exp_name
        self.logger = train_logger(self.args, self.expPath)

        self.encoder = MusicStarEncoder(args)
        self.decoder = WaveNet(args)
Example #17
def test():
    np.random.seed(42)
    audio, speaker_ids = make_sine_waves(None)
    dilations = [2**i for i in range(7)] * 2
    receptive_field = WaveNet.calculate_receptive_field(2, dilations)
    audio = np.pad(audio, (receptive_field - 1, 0),
                   'constant').astype(np.float32)

    encoded = mu_law_encode(audio, 2**8)
    encoded = encoded[np.newaxis, :]
    encoded_one_hot = one_hot(encoded, 2**8)

    signal_length = int(tf.shape(encoded_one_hot)[1] - 1)
    input_one_hot = tf.slice(encoded_one_hot, [0, 0, 0],
                             [-1, signal_length, -1])
    target_one_hot = tf.slice(encoded_one_hot, [0, receptive_field, 0],
                              [-1, -1, -1])
    print('input shape: ', tf.shape(input_one_hot))
    print('output shape: ', tf.shape(target_one_hot))

    net = WaveNet(1, dilations, 2, signal_length, 32, 32, 32, 2**8, True, 0.01)
    net.build(input_shape=(None, signal_length, 2**8))
    optimizer = Adam(lr=1e-3)

    for epoch in range(301):
        with tf.GradientTape() as tape:
            # [b, 1254, 256] => [b, 999, 256]
            logits = net(input_one_hot, training=True)
            # [b, 999, 256] => [b * 999, 256]
            logits = tf.reshape(logits, [-1, 2**8])
            target_one_hot = tf.reshape(target_one_hot, [-1, 2**8])
            # compute loss
            loss = tf.losses.categorical_crossentropy(target_one_hot,
                                                      logits,
                                                      from_logits=True)
            loss = tf.reduce_mean(loss)

        grads = tape.gradient(loss, net.trainable_variables)
        optimizer.apply_gradients(zip(grads, net.trainable_variables))
        if epoch % 100 == 0:
            print(epoch, 'loss: ', float(loss))
Example #18
class TestNet(tf.test.TestCase):

    def setUp(self):
        self.net = WaveNet(batch_size=1,
                           dilations=[1, 2, 4, 8, 16, 32, 64, 128, 256,
                                      1, 2, 4, 8, 16, 32, 64, 128, 256],
                           filter_width=2,
                           residual_channels=16,
                           dilation_channels=16,
                           quantization_channels=256,
                           skip_channels=32)

    # Train a net on a short clip of 3 sine waves superimposed
    # (an e-flat chord).
    #
    # Presumably it can overfit to such a simple signal. This test serves
    # as a smoke test where we just check that it runs end-to-end during
    # training, and learns this waveform.

    def testEndToEndTraining(self):
        audio = MakeSineWaves()
        np.random.seed(42)

        audio_tensor = tf.convert_to_tensor(audio, dtype=tf.float32)
        loss = self.net.loss(audio_tensor)
        optimizer = tf.train.AdamOptimizer(learning_rate=0.02)
        trainable = tf.trainable_variables()
        optim = optimizer.minimize(loss, var_list=trainable)
        init = tf.initialize_all_variables()

        max_allowed_loss = 0.1
        loss_val = max_allowed_loss
        initial_loss = None
        with self.test_session() as sess:
            sess.run(init)
            initial_loss = sess.run(loss)
            for i in range(50):
                loss_val, _ = sess.run([loss, optim])
                # print "i: %d loss: %f" % (i, loss_val)

        # Sanity check the initial loss was larger.
        self.assertGreater(initial_loss, max_allowed_loss)

        # Loss after training should be small.
        self.assertLess(loss_val, max_allowed_loss)

        # Loss should be at least two orders of magnitude better
        # than before training.
        self.assertLess(loss_val / initial_loss, 0.01)
Example #19
    def __init__(self, input_C=96, input_L=1366, L_trans_channels=256):
        super(CascadeModel, self).__init__()
        self.input_C = input_C
        self.input_L = input_L
        self.first_block = nn.Sequential(
            nn.Conv1d(input_L, L_trans_channels, 1),
            nn.Conv1d(L_trans_channels, L_trans_channels, 3),
            nn.BatchNorm1d(L_trans_channels),
            nn.ReLU(),
            nn.Conv1d(L_trans_channels, L_trans_channels, 3),
            nn.BatchNorm1d(L_trans_channels),
            nn.ELU(),
            nn.MaxPool1d(2),
            nn.Conv1d(L_trans_channels, L_trans_channels, 3),
            nn.BatchNorm1d(L_trans_channels),
            nn.ReLU(),
            nn.Conv1d(L_trans_channels, L_trans_channels, 3),
            nn.BatchNorm1d(L_trans_channels),
            nn.ELU(),
            nn.MaxPool1d(2),
            nn.Conv1d(L_trans_channels, L_trans_channels, 3),
            nn.BatchNorm1d(L_trans_channels),
            nn.ReLU(),
            nn.Conv1d(L_trans_channels, L_trans_channels, 3, padding=1),
            nn.BatchNorm1d(L_trans_channels),
            nn.ELU(),
            nn.MaxPool1d(2),
            nn.Conv1d(L_trans_channels, L_trans_channels, 1),
            nn.BatchNorm1d(L_trans_channels),
            nn.ELU(),
            nn.Conv1d(L_trans_channels, input_L, 1))

        self.wavenet = WaveNet(in_depth=9,
                               dilation_channels=32,
                               res_channels=32,
                               skip_channels=256,
                               end_channels=128,
                               dilation_depth=6,
                               n_blocks=5)

        self.post = nn.Sequential(nn.Dropout(p=0.2),
                                  nn.Linear(128, 256),
                                  nn.ReLU(),
                                  nn.Dropout(p=0.3),
                                  nn.Linear(256, 50),
                                  nn.Sigmoid())
Example #20
def test_assert_fast_generation():
    # get batch
    batch = 2
    x = np.random.randint(0, 256, size=(batch, 1))
    h = np.random.randn(batch, 28, 32)
    length = h.shape[-1] - 1

    with torch.no_grad():
        # --------------------------------------------------------
        # define model without upsampling and with kernel size = 2
        # --------------------------------------------------------
        net = WaveNet(256, 28, 4, 4, 10, 3, 2)
        net.apply(initialize)
        net.eval()

        # sample-by-sample generation
        gen1_list = []
        gen2_list = []
        for x_, h_ in zip(x, h):
            batch_x = torch.from_numpy(np.expand_dims(x_, 0)).long()
            batch_h = torch.from_numpy(np.expand_dims(h_, 0)).float()
            gen1 = net.generate(batch_x, batch_h, length, 1, "argmax")
            gen2 = net.fast_generate(batch_x, batch_h, length, 1, "argmax")
            np.testing.assert_array_equal(gen1, gen2)
            gen1_list += [gen1]
            gen2_list += [gen2]
        gen1 = np.stack(gen1_list)
        gen2 = np.stack(gen2_list)
        np.testing.assert_array_equal(gen1, gen2)

        # batch generation
        batch_x = torch.from_numpy(x).long()
        batch_h = torch.from_numpy(h).float()
        gen3_list = net.batch_fast_generate(batch_x, batch_h, [length] * batch,
                                            1, "argmax")
        gen3 = np.stack(gen3_list)
        np.testing.assert_array_equal(gen3, gen2)

        # --------------------------------------------------------
        # define model without upsampling and with kernel size = 3
        # --------------------------------------------------------
        net = WaveNet(256, 28, 4, 4, 10, 3, 3)
        net.apply(initialize)
        net.eval()

        # sample-by-sample generation
        gen1_list = []
        gen2_list = []
        for x_, h_ in zip(x, h):
            batch_x = torch.from_numpy(np.expand_dims(x_, 0)).long()
            batch_h = torch.from_numpy(np.expand_dims(h_, 0)).float()
            gen1 = net.generate(batch_x, batch_h, length, 1, "argmax")
            gen2 = net.fast_generate(batch_x, batch_h, length, 1, "argmax")
            np.testing.assert_array_equal(gen1, gen2)
            gen1_list += [gen1]
            gen2_list += [gen2]
        gen1 = np.stack(gen1_list)
        gen2 = np.stack(gen2_list)
        np.testing.assert_array_equal(gen1, gen2)

        # batch generation
        batch_x = torch.from_numpy(x).long()
        batch_h = torch.from_numpy(h).float()
        gen3_list = net.batch_fast_generate(batch_x, batch_h, [length] * batch,
                                            1, "argmax")
        gen3 = np.stack(gen3_list)
        np.testing.assert_array_equal(gen3, gen2)

        # get batch
        batch = 2
        upsampling_factor = 10
        x = np.random.randint(0, 256, size=(batch, 1))
        h = np.random.randn(batch, 28, 3)
        length = h.shape[-1] * upsampling_factor - 1

        # -----------------------------------------------------
        # define model with upsampling and with kernel size = 2
        # -----------------------------------------------------
        net = WaveNet(256, 28, 4, 4, 10, 3, 2, upsampling_factor)
        net.apply(initialize)
        net.eval()

        # sample-by-sample generation
        gen1_list = []
        gen2_list = []
        for x_, h_ in zip(x, h):
            batch_x = torch.from_numpy(np.expand_dims(x_, 0)).long()
            batch_h = torch.from_numpy(np.expand_dims(h_, 0)).float()
            gen1 = net.generate(batch_x, batch_h, length, 1, "argmax")
            gen2 = net.fast_generate(batch_x, batch_h, length, 1, "argmax")
            np.testing.assert_array_equal(gen1, gen2)
            gen1_list += [gen1]
            gen2_list += [gen2]
        gen1 = np.stack(gen1_list)
        gen2 = np.stack(gen2_list)
        np.testing.assert_array_equal(gen1, gen2)

        # batch generation
        batch_x = torch.from_numpy(x).long()
        batch_h = torch.from_numpy(h).float()
        gen3_list = net.batch_fast_generate(batch_x, batch_h, [length] * batch,
                                            1, "argmax")
        gen3 = np.stack(gen3_list)
        np.testing.assert_array_equal(gen3, gen2)

        # -----------------------------------------------------
        # define model with upsampling and with kernel size = 3
        # -----------------------------------------------------
        net = WaveNet(256, 28, 4, 4, 10, 3, 3, upsampling_factor)
        net.apply(initialize)
        net.eval()

        # sample-by-sample generation
        gen1_list = []
        gen2_list = []
        for x_, h_ in zip(x, h):
            batch_x = torch.from_numpy(np.expand_dims(x_, 0)).long()
            batch_h = torch.from_numpy(np.expand_dims(h_, 0)).float()
            gen1 = net.generate(batch_x, batch_h, length, 1, "argmax")
            gen2 = net.fast_generate(batch_x, batch_h, length, 1, "argmax")
            np.testing.assert_array_equal(gen1, gen2)
            gen1_list += [gen1]
            gen2_list += [gen2]
        gen1 = np.stack(gen1_list)
        gen2 = np.stack(gen2_list)
        np.testing.assert_array_equal(gen1, gen2)

        # batch generation
        batch_x = torch.from_numpy(x).long()
        batch_h = torch.from_numpy(h).float()
        gen3_list = net.batch_fast_generate(batch_x, batch_h, [length] * batch,
                                            1, "argmax")
        gen3 = np.stack(gen3_list)
        np.testing.assert_array_equal(gen3, gen2)
Example #21
def test_generate():
    batch = 2
    x = np.random.randint(0, 256, size=(batch, 1))
    h = np.random.randn(batch, 28, 10)
    length = h.shape[-1] - 1
    with torch.no_grad():
        net = WaveNet(256, 28, 4, 4, 10, 3, 2)
        net.apply(initialize)
        net.eval()
        for x_, h_ in zip(x, h):
            batch_x = torch.from_numpy(np.expand_dims(x_, 0)).long()
            batch_h = torch.from_numpy(np.expand_dims(h_, 0)).float()
            net.generate(batch_x, batch_h, length, 1, "sampling")
            net.fast_generate(batch_x, batch_h, length, 1, "sampling")
        batch_x = torch.from_numpy(x).long()
        batch_h = torch.from_numpy(h).float()
        net.batch_fast_generate(batch_x, batch_h, [length] * batch, 1,
                                "sampling")
Example #22
def test_forward():
    # get batch
    generator = sine_generator(100)
    batch = next(generator)
    batch_input = batch.view(1, -1)
    batch_aux = torch.rand(1, 28, batch_input.size(1)).float()

    # define model without upsampling with kernel size = 2
    net = WaveNet(256, 28, 32, 128, 10, 1, 2)
    net.apply(initialize)
    net.eval()
    y = net(batch_input, batch_aux)[0]
    assert y.size(0) == batch_input.size(1)
    assert y.size(1) == 256

    # define model without upsampling with kernel size = 3
    net = WaveNet(256, 28, 32, 128, 10, 1, 3)
    net.apply(initialize)
    net.eval()
    y = net(batch_input, batch_aux)[0]
    assert y.size(0) == batch_input.size(1)
    assert y.size(1) == 256

    batch_input = batch.view(1, -1)
    batch_aux = torch.rand(1, 28, batch_input.size(1) // 10).float()

    # define model with upsampling and kernel size = 2
    net = WaveNet(256, 28, 32, 128, 10, 1, 2, 10)
    net.apply(initialize)
    net.eval()
    y = net(batch_input, batch_aux)[0]
    assert y.size(0) == batch_input.size(1)
    assert y.size(1) == 256

    # define model with upsampling and kernel size = 3
    net = WaveNet(256, 28, 32, 128, 10, 1, 3, 10)
    net.apply(initialize)
    net.eval()
    y = net(batch_input, batch_aux)[0]
    assert y.size(0) == batch_input.size(1)
    assert y.size(1) == 256
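
test_forward relies on a sine_generator helper defined elsewhere in the project. A hypothetical stand-in (the name, signature, and yielded format are assumptions inferred from the usage above) could look like:

import numpy as np
import torch

def sine_generator(seq_len, mu=256):
    # Yield the same quantized sine snippet forever, as class indices in [0, mu).
    t = np.arange(seq_len)
    x = 0.5 * np.sin(2 * np.pi * t / 32)
    q = np.round((x + 1) / 2 * (mu - 1)).astype(np.int64)
    while True:
        yield torch.from_numpy(q)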
Example #23
File: model.py  Project: rhythm92/wavenet
    pass
filename = args.params_dir + "/{}".format(args.params_filename)
if os.path.isfile(filename):
    with open(filename) as f:
        try:
            params_dict = json.load(f)
            params = Params(params_dict)
        except Exception:
            raise Exception("could not load {}".format(filename))

    params.gpu_enabled = args.gpu_enabled == 1

    if args.use_faster_wavenet:
        wavenet = FasterWaveNet(params)
    else:
        wavenet = WaveNet(params)
else:
    params = Params()
    params.audio_channels = 256

    params.causal_conv_no_bias = True
    params.causal_conv_kernel_width = 2
    params.causal_conv_channels = [128]

    params.residual_conv_dilation_no_bias = True
    params.residual_conv_projection_no_bias = True
    params.residual_conv_kernel_width = 2
    params.residual_conv_channels = [32, 32, 32, 32, 32, 32, 32, 32, 32]
    params.residual_num_blocks = 5

    params.softmax_conv_no_bias = True
Example #24
    def __init__(self, args):
        self.args = args
        self.args.n_datasets = len(self.args.data)
        self.expPath = Path('checkpoints') / args.expName

        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        self.logger = create_output_dir(args, self.expPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]
        assert not args.distributed or len(self.data) == int(
            os.environ['WORLD_SIZE']
        ), "Number of datasets must match number of nodes"

        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.loss_d_right = LossMeter('d')
        self.loss_total = LossMeter('total')

        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_d_right = LossMeter('eval d')
        self.eval_total = LossMeter('eval total')

        self.encoder = Encoder(args)
        self.decoder = WaveNet(args)
        self.discriminator = ZDiscriminator(args)

        if args.checkpoint:
            checkpoint_args_path = os.path.dirname(
                args.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)

            self.start_epoch = checkpoint_args[-1] + 1
            states = torch.load(args.checkpoint)

            self.encoder.load_state_dict(states['encoder_state'])
            self.decoder.load_state_dict(states['decoder_state'])
            self.discriminator.load_state_dict(states['discriminator_state'])

            self.logger.info('Loaded checkpoint parameters')
        else:
            self.start_epoch = 0

        if args.distributed:
            self.encoder.cuda()
            self.encoder = torch.nn.parallel.DistributedDataParallel(
                self.encoder)
            self.discriminator.cuda()
            self.discriminator = torch.nn.parallel.DistributedDataParallel(
                self.discriminator)
            self.logger.info('Created DistributedDataParallel')
        else:
            self.encoder = torch.nn.DataParallel(self.encoder).cuda()
            self.discriminator = torch.nn.DataParallel(
                self.discriminator).cuda()
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()

        self.model_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                                self.decoder.parameters()),
                                          lr=args.lr)
        self.d_optimizer = optim.Adam(self.discriminator.parameters(),
                                      lr=args.lr)

        if args.checkpoint and args.load_optimizer:
            self.model_optimizer.load_state_dict(
                states['model_optimizer_state'])
            self.d_optimizer.load_state_dict(states['d_optimizer_state'])

        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        self.lr_manager.last_epoch = self.start_epoch
        self.lr_manager.step()
Example #25
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    logdir_root = directories['logdir_root']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        reader = AudioReader(args.data_dir,
                             coord,
                             sample_rate=wavenet_params['sample_rate'],
                             sample_size=args.sample_size)
        audio_batch = reader.dequeue(args.batch_size)

    # Create network.
    net = WaveNet(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"])
    loss = net.loss(audio_batch)
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.train.SummaryWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.merge_all_summaries()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.initialize_all_variables()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver()

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except Exception:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    try:
        last_saved_step = saved_global_step
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'.format(
                step, loss_value, duration))

            if step % 50 == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        # Introduce a line break after ^C is displayed so save message
        # is on its own line.
        print()
    finally:
        if step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
Example #26
def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          iters_per_checkpoint, batch_size, seed, checkpoint_path):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END:   ADDED FOR DISTRIBUTED======

    criterion = CrossEntropyLoss()
    model = WaveNet(**wavenet_config).cuda()

    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END:   ADDED FOR DISTRIBUTED======

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        iteration += 1  # next iteration is iteration + 1

    #trainset = Mel2SampOnehot(**data_config)
    trainset = DeepMels(**data_config)
    # =====START: ADDED FOR DISTRIBUTED======
    train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
    # =====END:   ADDED FOR DISTRIBUTED======
    train_loader = DataLoader(trainset,
                              num_workers=1,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=False,
                              drop_last=True)

    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    model.train()
    epoch_offset = max(0, int(iteration / len(train_loader)))
    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(epoch_offset, epochs):
        total_loss = 0
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            model.zero_grad()

            x, y = batch
            x = to_gpu(x).float()
            y = to_gpu(y)
            x = (x, y)  # auto-regressive takes outputs as inputs
            y_pred = model(x)
            loss = criterion(y_pred, y)
            if num_gpus > 1:
                reduced_loss = reduce_tensor(loss.data, num_gpus)[0]
            else:
                reduced_loss = loss.data[0]
            loss.backward()
            optimizer.step()

            total_loss += reduced_loss

            if (iteration % iters_per_checkpoint == 0):
                if rank == 0:
                    checkpoint_path = "{}/wavenet_{}".format(
                        output_directory, iteration)
                    save_checkpoint(model, optimizer, learning_rate, iteration,
                                    checkpoint_path)

            iteration += 1
        print("epoch:{}, total epoch loss:{}".format(epoch, total_loss))
Example #27
def main(args):
    print('Starting')
    matplotlib.use('agg')
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    checkpoints = args.checkpoint.parent.glob(args.checkpoint.name + '_*.pth')
    checkpoints = [c for c in checkpoints if extract_id(c) in args.decoders]
    assert len(checkpoints) >= 1, "No checkpoints found."

    model_args = torch.load(args.model.parent / 'args.pth')[0]
    encoder = wavenet_models.Encoder(model_args)
    encoder.load_state_dict(torch.load(checkpoints[0])['encoder_state'])
    encoder.eval()
    encoder = encoder.cuda()

    decoders = []
    decoder_ids = []
    for checkpoint in checkpoints:
        decoder = WaveNet(model_args)
        decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])
        decoder.eval()
        decoder = decoder.cuda()
        if args.py:
            decoder = WavenetGenerator(decoder,
                                       args.batch_size,
                                       wav_freq=args.rate)
        else:
            decoder = NVWavenetGenerator(decoder,
                                         args.rate * (args.split_size // 20),
                                         args.batch_size, 3)

        decoders += [decoder]
        decoder_ids += [extract_id(checkpoint)]

    xs = []
    assert args.output_next_to_orig ^ (args.output is not None)

    if len(args.files) == 1 and args.files[0].is_dir():
        top = args.files[0]
        file_paths = list(top.glob('**/*.wav')) + list(top.glob('**/*.h5'))
    else:
        file_paths = args.files

    if not args.skip_filter:
        file_paths = [f for f in file_paths if '_' not in str(f.name)]

    for file_path in file_paths:
        if file_path.suffix == '.wav':
            data, rate = librosa.load(file_path, sr=16000)
            assert rate == 16000
            data = utils.mu_law(data)
        elif file_path.suffix == '.h5':
            data = utils.mu_law(h5py.File(file_path, 'r')['wav'][:] / (2**15))
            if data.shape[-1] % args.rate != 0:
                data = data[:-(data.shape[-1] % args.rate)]
            assert data.shape[-1] % args.rate == 0
        else:
            raise Exception(f'Unsupported filetype {file_path}')

        if args.sample_len:
            data = data[:args.sample_len]
        else:
            args.sample_len = len(data)
        xs.append(torch.tensor(data).unsqueeze(0).float().cuda())

    xs = torch.stack(xs).contiguous()
    print(f'xs size: {xs.size()}')

    def save(x, decoder_ix, filepath):
        wav = utils.inv_mu_law(x.cpu().numpy())
        print(f'X size: {x.shape}')
        print(f'X min: {x.min()}, max: {x.max()}')

        if args.output_next_to_orig:
            save_audio(wav.squeeze(),
                       filepath.parent / f'{filepath.stem}_{decoder_ix}.wav',
                       rate=args.rate)
        else:
            save_audio(wav.squeeze(),
                       args.output / str(extract_id(args.model)) /
                       str(args.update) / filepath.with_suffix('.wav').name,
                       rate=args.rate)

    yy = {}
    with torch.no_grad():
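        # Encode all inputs in batches up front; decoding happens per decoder below.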
        zz = []
        for xs_batch in torch.split(xs, args.batch_size):
            zz += [encoder(xs_batch)]
        zz = torch.cat(zz, dim=0)

        with utils.timeit("Generation timer"):
            for i, decoder_id in enumerate(decoder_ids):
                yy[decoder_id] = []
                decoder = decoders[i]
                for zz_batch in torch.split(zz, args.batch_size):
                    print(zz_batch.shape)
                    splits = torch.split(zz_batch, args.split_size, -1)
                    audio_data = []
                    decoder.reset()
                    for cond in tqdm.tqdm(splits):
                        audio_data += [decoder.generate(cond).cpu()]
                    audio_data = torch.cat(audio_data, -1)
                    yy[decoder_id] += [audio_data]
                yy[decoder_id] = torch.cat(yy[decoder_id], dim=0)
                del decoder

    for decoder_ix, decoder_result in yy.items():
        for sample_result, filepath in zip(decoder_result, file_paths):
            save(sample_result, decoder_ix, filepath)
Example #28
def main():
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms",
                        required=True,
                        type=str,
                        help="directory or list of wav files")
    parser.add_argument("--feats",
                        required=True,
                        type=str,
                        help="directory or list of aux feat files")
    parser.add_argument("--stats",
                        required=True,
                        type=str,
                        help="hdf5 file including statistics")
    parser.add_argument("--expdir",
                        required=True,
                        type=str,
                        help="directory to save the model")
    # network structure setting
    parser.add_argument("--n_quantize",
                        default=256,
                        type=int,
                        help="number of quantization")
    parser.add_argument("--n_aux",
                        default=28,
                        type=int,
                        help="number of dimension of aux feats")
    parser.add_argument("--n_resch",
                        default=512,
                        type=int,
                        help="number of channels of residual output")
    parser.add_argument("--n_skipch",
                        default=256,
                        type=int,
                        help="number of channels of skip output")
    parser.add_argument("--dilation_depth",
                        default=10,
                        type=int,
                        help="depth of dilation")
    parser.add_argument("--dilation_repeat",
                        default=1,
                        type=int,
                        help="number of repeating of dilation")
    parser.add_argument("--kernel_size",
                        default=2,
                        type=int,
                        help="kernel size of dilated causal convolution")
    parser.add_argument("--upsampling_factor",
                        default=0,
                        type=int,
                        help="upsampling factor of aux features"
                        "(if set 0, do not apply)")
    parser.add_argument("--use_speaker_code",
                        default=False,
                        type=strtobool,
                        help="flag to use speaker code")
    # network training setting
    parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="weight decay coefficient")
    parser.add_argument(
        "--batch_size",
        default=20000,
        type=int,
        help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--iters",
                        default=200000,
                        type=int,
                        help="number of iterations")
    # other setting
    parser.add_argument("--checkpoints",
                        default=10000,
                        type=int,
                        help="how frequent saving model")
    parser.add_argument("--intervals",
                        default=100,
                        type=int,
                        help="log interval")
    parser.add_argument("--seed", default=1, type=int, help="seed number")
    parser.add_argument("--resume",
                        default=None,
                        type=str,
                        help="model path to restart training")
    parser.add_argument("--verbose", default=1, type=int, help="log level")
    args = parser.parse_args()

    # make experimental directory
    if not os.path.exists(args.expdir):
        os.makedirs(args.expdir)

    # set log level
    log_format = '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s'
    if args.verbose == 1:
        level = logging.INFO
    elif args.verbose > 1:
        level = logging.DEBUG
    else:
        level = logging.WARN
    logging.basicConfig(level=level,
                        format=log_format,
                        datefmt='%m/%d/%Y %I:%M:%S',
                        filename=args.expdir + "/train.log")
    logging.getLogger().addHandler(logging.StreamHandler())
    if args.verbose < 1:
        logging.warning("verbose logging is disabled.")

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # save args as conf
    torch.save(args, args.expdir + "/model.conf")

    # define network
    model = WaveNet(n_quantize=args.n_quantize,
                    n_aux=args.n_aux,
                    n_resch=args.n_resch,
                    n_skipch=args.n_skipch,
                    dilation_depth=args.dilation_depth,
                    dilation_repeat=args.dilation_repeat,
                    kernel_size=args.kernel_size,
                    upsampling_factor=args.upsampling_factor)
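    # model.receptive_field (used below) grows with the dilation pattern; for
    # a stack like this it is presumably (kernel_size - 1) * sum(dilations) + 1,
    # with dilations doubling each layer (1, 2, 4, ...) and the whole pattern
    # repeated dilation_repeat times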
    logging.info(model)
    model.apply(initialize)
    model.train()

    # define loss and optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss()

    # define transforms
    scaler = StandardScaler()
    scaler.mean_ = read_hdf5(args.stats, "/mean")
    scaler.scale_ = read_hdf5(args.stats, "/scale")
    wav_transform = transforms.Compose(
        [lambda x: encode_mu_law(x, args.n_quantize)])
    feat_transform = transforms.Compose([lambda x: scaler.transform(x)])
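    # encode_mu_law is assumed to be standard mu-law companding, roughly:
    #   mu = n_quantize - 1
    #   fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu)
    #   bins = np.floor((fx + 1) / 2 * mu + 0.5)
    # mapping waveform values in [-1, 1] to integer classes 0 .. n_quantize - 1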

    # define generator
    if os.path.isdir(args.waveforms):
        filenames = sorted(
            find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
        feat_list = [
            args.feats + "/" + filename.replace(".wav", ".h5")
            for filename in filenames
        ]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
        feat_list = read_txt(args.feats)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(wav_list))
    generator = train_generator(wav_list,
                                feat_list,
                                receptive_field=model.receptive_field,
                                batch_size=args.batch_size,
                                wav_transform=wav_transform,
                                feat_transform=feat_transform,
                                shuffle=True,
                                upsampling_factor=args.upsampling_factor,
                                use_speaker_code=args.use_speaker_code)
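    # the generator presumably loads batches in a background thread; block here
    # until its queue is full so the first iterations do not stall on I/O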
    while not generator.queue.full():
        time.sleep(0.1)

    # resume
    if args.resume is not None:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        iterations = checkpoint["iterations"]
        logging.info("restored from %d-iter checkpoint." % iterations)
    else:
        iterations = 0

    # send to gpu
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    # train
    loss = 0
    total = 0
    for i in six.moves.range(iterations, args.iters):
        start = time.time()
        (batch_x, batch_h), batch_t = next(generator)
        batch_output = model(batch_x, batch_h)[0]
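        # exclude the first receptive_field samples from the loss: their
        # context window is incomplete, so they would presumably bias training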
        batch_loss = criterion(batch_output[model.receptive_field:],
                               batch_t[model.receptive_field:])
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        loss += batch_loss.item()
        total += time.time() - start
        logging.debug("batch loss = %.3f (%.3f sec / batch)" %
                      (batch_loss.item(), time.time() - start))

        # report progress
        if (i + 1) % args.intervals == 0:
            logging.info(
                "(iter:%d) average loss = %.6f (%.3f sec / batch)" %
                (i + 1, loss / args.intervals, total / args.intervals))
            logging.info(
                "estimated required time = "
                "{0.days:02}:{0.hours:02}:{0.minutes:02}:{0.seconds:02}".
                format(
                    relativedelta(seconds=int((args.iters - (i + 1)) *
                                              (total / args.intervals)))))
            loss = 0
            total = 0

        # save intermediate model
        if (i + 1) % args.checkpoints == 0:
            save_checkpoint(args.expdir, model, optimizer, i + 1)

    # save final model
    model.cpu()
    torch.save({"model": model.state_dict()},
               args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")
Example #29
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    logdir_root = directories['logdir_root']
    restore_from = directories['restore_from']

    # Even if we restored the model, treat the run as new training
    # when the restore location differs from the current log directory.
    is_new_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # create coordinator
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        custom_runner = CustomRunner(args, wavenet_params, coord)
        audio_batch, _ = custom_runner.get_inputs()

    # Create network.
    net = WaveNet(args.batch_size, wavenet_params["quantization_steps"],
                  wavenet_params["dilations"], wavenet_params["filter_width"],
                  wavenet_params["residual_channels"],
                  wavenet_params["dilation_channels"])
    loss = net.loss(audio_batch)
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard.
    writer = tf.train.SummaryWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.merge_all_summaries()

    # Set up session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.initialize_all_variables()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver()

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_new_training or saved_global_step is None:
            # For "new" training that starts from a pre-trained model,
            # saved_global_step should be ignored: steps resume from
            # saved_global_step + 1, so set it to -1 for a fresh run.
            saved_global_step = -1

    except:
        print("Something is wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    custom_runner.start_threads(sess)

    try:
        for step in range(saved_global_step + 1, args.num_steps):
            start_time = time.time()
            if args.store_metadata and step % 50 == 0:
                # Slow run that stores extra information for debugging.
                print('Storing metadata')
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  options=run_options,
                                                  run_metadata=run_metadata)
                writer.add_summary(summary, step)
                writer.add_run_metadata(run_metadata,
                                        'step_{:04d}'.format(step))
                tl = timeline.Timeline(run_metadata.step_stats)
                timeline_path = os.path.join(logdir, 'timeline.trace')
                with open(timeline_path, 'w') as f:
                    f.write(tl.generate_chrome_trace_format(show_memory=True))
            else:
                summary, loss_value, _ = sess.run([summaries, loss, optim])
                writer.add_summary(summary, step)

            duration = time.time() - start_time
            print('step %d - loss = %.3f, (%.3f sec/step)' %
                  (step, loss_value, duration))

            if step % 50 == 0:
                save(saver, sess, logdir, step)

    finally:
        coord.request_stop()
        coord.join(threads)
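
Note: the save and load helpers are defined elsewhere. Under the old tf.train
checkpoint API used throughout this example, minimal sketches might look like
this; the checkpoint file name is an assumption:

def save(saver, sess, logdir, step):
    # sketch: write a checkpoint tagged with the current global step
    saver.save(sess, os.path.join(logdir, 'model.ckpt'), global_step=step)


def load(saver, sess, logdir):
    # sketch: restore the latest checkpoint and recover its step number;
    # return None if no checkpoint exists yet
    ckpt = tf.train.get_checkpoint_state(logdir)
    if ckpt is None:
        return None
    saver.restore(sess, ckpt.model_checkpoint_path)
    return int(ckpt.model_checkpoint_path.split('-')[-1])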
Example #30
class Trainer:
    def __init__(self, args):
        self.args = args
        self.args.n_datasets = len(self.args.data)
        self.expPath = Path('checkpoints') / args.expName

        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        self.logger = create_output_dir(args, self.expPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]

        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.loss_total = LossMeter('total')

        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_total = LossMeter('eval total')

        self.encoder = Encoder(args)
        self.decoder = WaveNet(args)

        assert args.checkpoint, 'you MUST pass a checkpoint for the encoder'

        if args.continue_training:
            checkpoint_args_path = os.path.dirname(
                args.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)

            self.start_epoch = checkpoint_args[-1] + 1
        else:
            self.start_epoch = 0

        states = torch.load(args.checkpoint)
        self.encoder.load_state_dict(states['encoder_state'])
        if args.continue_training:
            self.decoder.load_state_dict(states['decoder_state'])
        self.logger.info('Loaded checkpoint parameters')

        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()

        self.model_optimizer = optim.Adam(self.decoder.parameters(),
                                          lr=args.lr)

        if args.continue_training:
            self.model_optimizer.load_state_dict(
                states['model_optimizer_state'])

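        # setting last_epoch and stepping once presumably fast-forwards the
        # exponential schedule, so a resumed run continues from the decayed LR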
        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        self.lr_manager.last_epoch = self.start_epoch
        self.lr_manager.step()

    def eval_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()

        z = self.encoder(x)
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        total_loss = recon_loss.mean().data.item()
        self.eval_total.add(total_loss)

        return total_loss

    def train_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()

        # optimize G - reconstructs well
        z = self.encoder(x_aug)
        z = z.detach()  # stop gradients
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        loss = recon_loss.mean()
        self.model_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.decoder.parameters(), self.args.grad_clip)
        self.model_optimizer.step()

        self.loss_total.add(loss.data.item())

        return loss.data.item()

    def train_epoch(self, epoch):
        for meter in self.losses_recon:
            meter.reset()
        self.loss_total.reset()

        self.encoder.eval()
        self.decoder.train()

        n_batches = self.args.epoch_len

        with tqdm(total=n_batches,
                  desc='Train epoch %d' % epoch) as train_enum:
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 3:
                    break

                dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].train_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.train_batch(x, x_aug, dset_num)

                train_enum.set_description(
                    f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                train_enum.update()

    def evaluate_epoch(self, epoch):
        for meter in self.evals_recon:
            meter.reset()
        self.eval_total.reset()

        self.encoder.eval()
        self.decoder.eval()

        n_batches = int(np.ceil(self.args.epoch_len / 10))

        with tqdm(total=n_batches) as valid_enum, \
                torch.no_grad():
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 10:
                    break

                dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].valid_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.eval_batch(x, x_aug, dset_num)

                valid_enum.set_description(
                    f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    @staticmethod
    def format_losses(meters):
        losses = [meter.summarize_epoch() for meter in meters]
        return ', '.join('{:.4f}'.format(x) for x in losses)

    def train_losses(self):
        meters = [*self.losses_recon]
        return self.format_losses(meters)

    def eval_losses(self):
        meters = [*self.evals_recon]
        return self.format_losses(meters)

    def train(self):
        best_eval = float('inf')

        # Begin!
        for epoch in range(self.start_epoch,
                           self.start_epoch + self.args.epochs):
            self.logger.info(
                f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}'
            )
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)

            self.logger.info(
                f'Epoch %s Rank {self.args.rank} - Train loss: (%s), Test loss (%s)',
                epoch, self.train_losses(), self.eval_losses())
            self.lr_manager.step()
            val_loss = self.eval_total.summarize_epoch()

            if val_loss < best_eval:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_eval = val_loss

            if not self.args.per_epoch:
                self.save_model(f'lastmodel_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')

            torch.save([self.args, epoch], '%s/args.pth' % self.expPath)

            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        save_path = self.expPath / filename

        states = torch.load(self.args.checkpoint)
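        # the encoder stays frozen during training (the optimizer only updates
        # the decoder), so its weights are re-read from the original checkpoint
        # and saved through unchanged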

        torch.save(
            {
                'encoder_state': states['encoder_state'],
                'decoder_state': self.decoder.module.state_dict(),
                'model_optimizer_state': self.model_optimizer.state_dict(),
                'dataset': self.args.rank,
            }, save_path)

        self.logger.debug(f'Saved model to {save_path}')
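
Note: wrap and cross_entropy_loss are helpers defined outside this listing.
Given how they are called (inputs moved to the GPU, and a per-element loss
that callers later .mean()), plausible sketches are shown below; the exact
signatures are assumptions:

import torch.nn.functional as F


def wrap(x):
    # sketch: move a batch tensor to the GPU
    return x.cuda(non_blocking=True)


def cross_entropy_loss(y, target):
    # sketch: y is (batch, classes, time) decoder logits, target holds the
    # (batch, time) mu-law class indices; return unreduced per-element losses
    # so callers can average or log them as needed
    return F.cross_entropy(y, target.long(), reduction='none')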