Example #1
def main():
    args = get_arguments()

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    # Even if we restored the model, we will treat it as new training
    # if the trained model is written into an arbitrary location.
    is_overwritten_training = logdir != restore_from

    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create coordinator.
    coord = tf.train.Coordinator()

    # Load raw waveform from VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = args.silence_threshold if args.silence_threshold > EPSILON else None

        gc_enabled = args.gc_channels is not None
        reader = AudioReader(
            args.data_dir,
            coord,
            sample_rate=wavenet_params['sample_rate'],
            gc_enabled=gc_enabled,
            receptive_field=WaveNetModel.calculate_receptive_field(
                wavenet_params['filter_width'], wavenet_params['dilations'],
                wavenet_params['scalar_input'],
                wavenet_params['initial_filter_width']),
            sample_size=args.sample_size,
            silence_threshold=silence_threshold)
        audio_batch = reader.dequeue(args.batch_size)
        if gc_enabled:
            gc_id_batch = reader.dequeue_gc(args.batch_size)
        else:
            gc_id_batch = None
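
The `calculate_receptive_field` call above works out how many past samples the network must see before its first valid prediction. A minimal sketch of the usual computation, assuming the standard WaveNet formula for a stack of dilated convolutions (the repository's exact handling of `scalar_input` may differ):

import numpy as np

def calculate_receptive_field(filter_width, dilations, scalar_input,
                              initial_filter_width):
    # Each dilated layer widens the receptive field by
    # (filter_width - 1) * dilation samples.
    receptive_field = (filter_width - 1) * sum(dilations) + 1
    # The causal input layer adds its own extent: a wide initial filter
    # for scalar (raw float) input, otherwise the regular filter width.
    if scalar_input:
        receptive_field += initial_filter_width - 1
    else:
        receptive_field += filter_width - 1
    return receptive_field

# Example: two stacks of dilations 1, 2, ..., 512 (sum 2046) with
# filter_width=2 give (2 - 1) * 2046 + 1 + 1 = 2048 samples of context.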
Example #2
def main():
    os.environ['CUDA_VISIBLE_DEVICES'] = GPU
    args = get_arguments()
    # override the hparams
    if args.hparams is not None:
        hparams.parse(args.hparams)
    if not hparams.gc_enable:
        hparams.global_cardinality = None
        hparams.global_channel = None
    print(hparams_debug_string())

    try:
        directories = validate_directories(args)
    except ValueError as e:
        print("Some arguments are wrong:")
        print(str(e))
        return

    logdir = directories['logdir']
    restore_from = directories['restore_from']

    is_overwritten_training = logdir != restore_from

    coord = tf.train.Coordinator()

    args.train_txt = os.path.join(hparams.NPY_DATAROOT, args.train_txt)
    with tf.name_scope('create_input'):
        reader = DataFeeder(
            metadata_filename=args.train_txt,
            coord=coord,
            receptive_field=WaveNetModel.calculate_receptive_field(
                hparams.filter_width, hparams.dilations, hparams.scalar_input,
                hparams.initial_filter_width),
            gc_enable=hparams.gc_enable,
            sample_size=args.sample_size,
            npy_dataroot=hparams.NPY_DATAROOT,
            num_mels=hparams.lc_initial_channels,
            speaker_id=args.speaker_id)
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=hparams.dilations,
        filter_width=hparams.filter_width,
        residual_channels=hparams.residual_channels,
        dilation_channels=hparams.dilation_channels,
        skip_channels=hparams.skip_channels,
        quantization_channels=hparams.quantization_channels,
        use_biases=hparams.use_biases,
        scalar_input=hparams.scalar_input,
        initial_filter_width=hparams.initial_filter_width,
        histograms=args.histograms,
        local_condition_channel=hparams.lc_channels,
        lc_initial_channels=hparams.lc_initial_channels,
        upsample_conditional_features=hparams.upsample_conditional_features,
        upsample_factor=hparams.upsample_factor,
        global_cardinality=hparams.global_cardinality,
        global_channel=hparams.global_channel,
        is_training=True)

    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None

    trainable = tf.trainable_variables()

    # get global step
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)

    # decay learning rate
    # Calculate the learning rate schedule.
    decay_steps = hparams.NUM_STEPS_RATIO_PER_DECAY * args.num_steps
    # Decay the learning rate exponentially based on the number of steps.
    lr = tf.train.exponential_decay(args.learning_rate,
                                    global_step,
                                    decay_steps,
                                    hparams.LEARNING_RATE_DECAY_FACTOR,
                                    staircase=True)

    optimizer = optimizer_factory[args.optimizer](learning_rate=lr,
                                                  momentum=args.momentum)

    mul_batch_size = args.batch_size * args.num_gpus
    if hparams.gc_enable:
        audio_batch, lc_batch, gc_batch = reader.dequeue(mul_batch_size)
    else:
        audio_batch, lc_batch = reader.dequeue(mul_batch_size)
        gc_batch = None

    split_audio_batch = tf.split(value=audio_batch,
                                 num_or_size_splits=args.num_gpus,
                                 axis=0)
    split_lc_batch = tf.split(value=lc_batch,
                              num_or_size_splits=args.num_gpus,
                              axis=0)
    if hparams.gc_enable:
        split_gc_batch = tf.split(value=gc_batch,
                                  num_or_size_splits=args.num_gpus,
                                  axis=0)
    else:
        split_gc_batch = [None for _ in range(args.num_gpus)]

    # support multi gpu train
    tower_grads = []
    tower_losses = []
    with tf.variable_scope(tf.get_variable_scope()):
        for i in range(args.num_gpus):
            with tf.device('/gpu:{}'.format(i)):
                with tf.name_scope('losstower_{}'.format(i)) as scope:
                    loss = net.loss(input_batch=split_audio_batch[i],
                                    local_condition=split_lc_batch[i],
                                    global_condition=split_gc_batch[i],
                                    l2_regularization_strength=args.l2_regularization_strength,
                                    name=scope)
                    tf.get_variable_scope().reuse_variables()
                    tower_losses.append(loss)
                    grad_vars = optimizer.compute_gradients(loss,
                                                            var_list=trainable)
                    tower_grads.append(grad_vars)
    if args.num_gpus == 1:
        optim = optimizer.minimize(loss,
                                   var_list=trainable,
                                   global_step=global_step)
    else:
        loss = tf.reduce_mean(tower_losses)
        avg_grad = average_gradients(tower_grads)
        optim = optimizer.apply_gradients(avg_grad, global_step=global_step)

    # Track the moving averages of all trainable variables.
    variable_averages = tf.train.ExponentialMovingAverage(
        hparams.MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    train_op = tf.group(optim, variables_averages_op)

    # init the sess
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,
                                            allow_soft_placement=True,
                                            gpu_options=tf.GPUOptions(
                                                allow_growth=True)))
    init = tf.global_variables_initializer()
    sess.run(init)

    saver = tf.train.Saver(var_list=tf.trainable_variables(),
                           max_to_keep=args.max_checkpoints)

    try:
        saved_global_step, sess = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            saved_global_step = 0
    except Exception:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    step = None
    last_saved_step = saved_global_step

    try:
        print_loss = 0.
        start_time = time.time()
        for step in range(saved_global_step, args.num_steps):
            loss_value, _ = sess.run([loss, train_op])
            print_loss += loss_value

            if step % PRINT_LOSS_EVERY == 0:
                duration = time.time() - start_time
                now = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
                print('{:s}, step {:d} - loss = {:.6f}, ({:.3f} sec/step)'.
                      format(now, step, print_loss / PRINT_LOSS_EVERY,
                             duration / PRINT_LOSS_EVERY))
                start_time = time.time()
                print_loss = 0.

            if step % args.checkpoint_every == 0:
                #encoded_shape = sess.run([tf.shape(net.encoded)])
                #print(encoded_shape)
                target, predicted = sess.run([net.target, net.predicted])
                mat_path = logdir + "/tar_out_%d.mat" % step
                sio.savemat(
                    mat_path, {
                        'target_%d' % (step): target,
                        'output_%d' % (step): predicted,
                    })
                save(saver, sess, logdir, step)
                last_saved_step = step

    except KeyboardInterrupt:
        print()
    finally:
        if step is not None and step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
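
The multi-GPU branch depends on an `average_gradients` helper that this snippet does not define. A sketch in the style of the classic TensorFlow multi-tower examples, assuming every entry of `tower_grads` lists its `(gradient, variable)` pairs in the same order:

import tensorflow as tf

def average_gradients(tower_grads):
    average_grads = []
    # zip(*tower_grads) groups the per-tower pairs for one variable.
    for grad_and_vars in zip(*tower_grads):
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        # The variable is shared across towers; take the first tower's handle.
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads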
Example #3
# In[2]:

train_sample = "train_samples/saber.wav"
layers = 10
blocks = 2
classes = 128
hidden_channels = 32
kernel_size = 8

use_cuda = torch.cuda.is_available()

# In[3]:

model = WaveNetModel(num_layers=layers,
                     num_blocks=blocks,
                     num_classes=classes,
                     hidden_channels=hidden_channels,
                     kernel_size=kernel_size)

if use_cuda:
    model.cuda()
    print("use cuda")

#print("model: ", model)
print("scope: ", model.scope)
print(model.parameter_count(), " parameters")

data = WaveNetData(train_sample,
                   input_length=model.scope,
                   target_length=model.last_block_scope,
                   num_classes=model.num_classes,
                   cuda=use_cuda)
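
With `classes = 128`, the raw audio must be companded down to 128 discrete levels before it can serve as a classification target. `WaveNetData` presumably does this internally; a hedged sketch of the mu-law quantization typically used in WaveNet pipelines:

import numpy as np

def mulaw_quantize(x, num_classes=128):
    # x: float waveform scaled to [-1, 1]; returns integer class labels.
    mu = num_classes - 1
    compressed = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
    # Map [-1, 1] onto {0, ..., num_classes - 1}.
    return ((compressed + 1) / 2 * mu + 0.5).astype(np.int64)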
Example #4
from model import WaveNetModel, Optimizer, WaveNetData

from IPython.display import Audio
from IPython.core.debugger import Tracer
from matplotlib import pyplot as plt
from matplotlib import pylab as pl
from IPython import display
import torch
import numpy as np
#%matplotlib inline
# get_ipython().magic('matplotlib notebook')

# In[10]:

model = WaveNetModel(num_blocks=2,
                     num_layers=12,
                     hidden_channels=128,
                     num_classes=256)
print('model: ', model)
print('scope: ', model.scope)
model = model.cuda()

# In[11]:

from scipy.io import wavfile

data = WaveNetData('../data/bach.wav',
                   input_length=model.scope,
                   target_length=model.last_block_scope,
                   num_classes=model.num_classes)
start_tensor = data.get_minibatch([30000])[0].squeeze()
plt.ion()
Example #5
def main():
    args = get_arguments()
    if args.hparams is not None:
        hparams.parse(args.hparams)
    if not hparams.gc_enable:
        hparams.global_cardinality = None
        hparams.global_channel = None
    print(hparams_debug_string())

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,
                                            gpu_options=tf.GPUOptions(
                                                allow_growth=True)))

    net = WaveNetModel(
        batch_size=1,
        dilations=hparams.dilations,
        filter_width=hparams.filter_width,
        residual_channels=hparams.residual_channels,
        dilation_channels=hparams.dilation_channels,
        skip_channels=hparams.skip_channels,
        quantization_channels=hparams.quantization_channels,
        use_biases=hparams.use_biases,
        scalar_input=hparams.scalar_input,
        initial_filter_width=hparams.initial_filter_width,
        local_condition_channel=hparams.num_mels,
        upsample_conditional_features=hparams.upsample_conditional_features,
        upsample_factor=hparams.upsample_factor,
        global_cardinality=hparams.global_cardinality,
        global_channel=hparams.global_channel)
    samples = tf.placeholder(tf.int32)
    local_ph = tf.placeholder(tf.float32, shape=(1, hparams.num_mels))

    sess.run(tf.global_variables_initializer())
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    tmp_global_condition = None
    upsample_factor = audio.get_hop_size()

    generate_list = []
    with open(args.eval_txt, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            if line.strip():
                line = line.strip().split('|')
                npy_path = os.path.join(hparams.NPY_DATAROOT, line[1])
                tmp_local_condition = np.load(npy_path).astype(np.float32)
                if len(line) == 5:
                    tmp_global_condition = int(line[4])
                if hparams.global_channel is None:
                    tmp_global_condition = None
                generate_list.append(
                    (tmp_local_condition, tmp_global_condition, line[1]))

    for local_condition, global_condition, npy_path in generate_list:
        wav_id = npy_path.split('-mel')[0]
        wav_out_path = "wav/{}_gen.wav".format(wav_id)

        if not hparams.upsample_conditional_features:
            local_condition = np.repeat(local_condition,
                                        upsample_factor,
                                        axis=0)
        else:
            local_condition = np.expand_dims(local_condition, 0)
            local_condition = net.create_upsample(local_condition)
            local_condition = tf.squeeze(local_condition,
                                         [0]).eval(session=sess)
        next_sample = net.predict_proba_incremental(samples, local_ph,
                                                    global_condition)
        sess.run(net.init_ops)

        quantization_channels = hparams.quantization_channels

        # Silence with a single random sample at the end.
        waveform = [quantization_channels // 2] * (net.receptive_field - 1)
        waveform.append(np.random.randint(quantization_channels))

        sample_len = local_condition.shape[0]
        for step in tqdm(range(0, sample_len)):

            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]

            # Run the WaveNet to predict the next sample.
            prediction = sess.run(outputs,
                                  feed_dict={
                                      samples: window,
                                      local_ph:
                                      local_condition[step:step + 1, :]
                                  })[0]

            # Scale prediction distribution using temperature.
            np.seterr(divide='ignore')
            scaled_prediction = np.log(prediction) / args.temperature
            scaled_prediction = (scaled_prediction -
                                 np.logaddexp.reduce(scaled_prediction))
            scaled_prediction = np.exp(scaled_prediction)
            np.seterr(divide='warn')
            # print(quantization_channels, scaled_prediction)
            sample = np.random.choice(np.arange(quantization_channels),
                                      p=scaled_prediction)
            waveform.append(sample)

            # If we have partial writing, save the result so far.
            if (wav_out_path and args.save_every
                    and (step + 1) % args.save_every == 0):
                out = P.inv_mulaw_quantize(np.array(waveform),
                                           quantization_channels)
                write_wav(out, hparams.sample_rate, wav_out_path)

        # Introduce a newline to clear the carriage return from the progress.
        print()
        # Save the result as a wav file.
        if wav_out_path:
            out = P.inv_mulaw_quantize(
                np.array(waveform).astype(np.int16), quantization_channels)
            # out = P.inv_mulaw_quantize(np.asarray(waveform), quantization_channels)
            write_wav(out, hparams.sample_rate, wav_out_path)
    print('Finished generating.')
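
The temperature block inside the sampling loop above is a numerically stable softmax in log space: dividing the log-probabilities by the temperature sharpens (values below 1) or flattens (values above 1) the distribution before sampling. The same idea as a self-contained sketch:

import numpy as np

def sample_with_temperature(prediction, temperature=1.0):
    # prediction: probability vector over quantization channels.
    with np.errstate(divide='ignore'):        # log(0) -> -inf is acceptable
        logits = np.log(prediction) / temperature
    logits -= np.logaddexp.reduce(logits)     # renormalize in log space
    return np.random.choice(len(logits), p=np.exp(logits))

At temperature 1.0 this reproduces the original distribution; the `np.errstate` context replaces the snippet's global `np.seterr` toggling.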
Example #6
def main():
    start = time.time()
    os.environ["CUDA_VISIBLE_DEVICES"] = GPU
    args = get_arguments()
    os.makedirs(args.wav_out_path, exist_ok=True)
    if args.hparams is not None:
        hparams.parse(args.hparams)
    if not hparams.gc_enable:
        hparams.global_cardinality = None
        hparams.global_channel = None
    print(hparams_debug_string())

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,
                                            gpu_options=tf.GPUOptions(
                                                allow_growth=True)))

    net = WaveNetModel(
        batch_size=1,
        dilations=hparams.dilations,
        filter_width=hparams.filter_width,
        residual_channels=hparams.residual_channels,
        dilation_channels=hparams.dilation_channels,
        skip_channels=hparams.skip_channels,
        quantization_channels=hparams.quantization_channels,
        use_biases=hparams.use_biases,
        scalar_input=hparams.scalar_input,
        initial_filter_width=hparams.initial_filter_width,
        local_condition_channel=hparams.lc_channels,
        lc_initial_channels=hparams.lc_initial_channels,
        upsample_conditional_features=hparams.upsample_conditional_features,
        upsample_factor=hparams.upsample_factor,
        global_cardinality=hparams.global_cardinality,
        global_channel=hparams.global_channel,
        is_training=False)
    samples = tf.placeholder(tf.int32)
    local_ph = tf.placeholder(tf.float32,
                              shape=(1, net.local_condition_channel))

    sess.run(tf.global_variables_initializer())
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)
    print('Restore is done successfully!')

    tmp_global_condition = None
    upsample_factor = audio.get_hop_size()

    generate_list = []
    with open(os.path.join(WAV_OUT_PATH + "_dev", args.eval_txt),
              'r',
              encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            if line.strip():
                line = line.strip().split('|')
                npy_path = os.path.join(hparams.NPY_DATAROOT + "_dev", line[1])
                local_condition = np.load(npy_path)
                if hparams.triphone:
                    pre_phone = one_hot(local_condition[0],
                                        num_labels=net.lc_initial_channels)
                    cur_phone = one_hot(local_condition[1],
                                        num_labels=net.lc_initial_channels)
                    nxt_phone = one_hot(local_condition[2],
                                        num_labels=net.lc_initial_channels)
                    tmp_local_condition = np.concatenate(
                        (pre_phone, cur_phone, nxt_phone), axis=1)
                else:
                    tmp_local_condition = one_hot(
                        local_condition, num_labels=net.lc_initial_channels)

                h = tmp_local_condition
                tmp_local_condition = lc_averaging(tmp_local_condition)

                tmp_local_condition = tmp_local_condition.astype(np.float32)
                if len(line) == 5:
                    tmp_global_condition = int(line[4])
                if hparams.global_channel is None:
                    tmp_global_condition = None
                generate_list.append(
                    (tmp_local_condition, tmp_global_condition, line[1], h))

    for local_condition, global_condition, npy_path, h in generate_list:
        h_hat = local_condition
        wav_id = npy_path.split('-phone')[0]
        wav_out_path = os.path.join(args.wav_out_path,
                                    "{}_gen.wav".format(wav_id))

        if not hparams.upsample_conditional_features and hparams.lc_conv_layers < 1:
            local_condition = np.repeat(local_condition,
                                        upsample_factor,
                                        axis=0)
        elif hparams.upsample_conditional_features:
            local_condition = np.expand_dims(local_condition, 0)
            local_condition = net.create_upsample(local_condition)
            local_condition = tf.squeeze(local_condition,
                                         [0]).eval(session=sess)
        else:
            local_condition = np.expand_dims(local_condition, 0)
            local_condition, h_list = net._create_lc_conv_layer(
                local_condition)
            # h3 = local_condition.eval(session=sess)
            if hparams.lc_average:  # upsampling by repeat
                if hparams.lc_overlap:
                    lc_len = tf.shape(local_condition).eval(session=sess)
                    local_condition = lc_upsampling_by_repeat(
                        local_condition, hparams.average_window_shift)
                    mod = math.ceil(
                        lc_len[1] * hparams.average_window_shift /
                        256) * 256 - lc_len[1] * hparams.average_window_shift
                    edge = tf.slice(
                        local_condition,
                        [0, lc_len[1] * hparams.average_window_shift - 1, 0],
                        [-1, 1, -1])
                    edge = tf.tile(edge, [1, mod, 1])
                    local_condition = tf.concat([local_condition, edge],
                                                axis=1)
                else:
                    local_condition = lc_upsampling_by_repeat(
                        local_condition, hparams.average_window_len)
            local_condition = tf.squeeze(local_condition).eval(session=sess)

        # NOTE: h_list is only defined when the lc_conv_layers branch above
        # was taken; this .mat dump assumes that configuration.
        h1 = h_list[0].eval(session=sess)
        h2 = h_list[1].eval(session=sess)
        h3 = h_list[2].eval(session=sess)
        mat_out_path = os.path.join(args.wav_out_path, "{}.mat".format(wav_id))
        sio.savemat(
            mat_out_path, {
                'C': local_condition,
                'h': h,
                'h_hat': h_hat,
                'h1': h1,
                'h2': h2,
                'h3': h3
            })

        next_sample = net.predict_proba_incremental(samples, local_ph,
                                                    global_condition)
        sess.run(net.init_ops)

        quantization_channels = hparams.quantization_channels

        # Silence with a single random sample at the end.
        waveform = [quantization_channels // 2] * (net.receptive_field - 1)
        waveform.append(np.random.randint(quantization_channels))

        sample_len = local_condition.shape[0]
        for step in tqdm(range(0, sample_len)):

            outputs = [next_sample]
            outputs.extend(net.push_ops)
            window = waveform[-1]

            # Run the WaveNet to predict the next sample.
            prediction = sess.run(outputs,
                                  feed_dict={
                                      samples: window,
                                      local_ph:
                                      local_condition[step:step + 1, :]
                                  })[0]

            # Scale prediction distribution using temperature.
            np.seterr(divide='ignore')
            scaled_prediction = np.log(prediction) / args.temperature
            scaled_prediction = (scaled_prediction -
                                 np.logaddexp.reduce(scaled_prediction))
            scaled_prediction = np.exp(scaled_prediction)
            np.seterr(divide='warn')
            # print(quantization_channels, scaled_prediction)
            sample = np.random.choice(np.arange(quantization_channels),
                                      p=scaled_prediction)
            waveform.append(sample)

            # If we have partial writing, save the result so far.
            if (wav_out_path and args.save_every
                    and (step + 1) % args.save_every == 0):
                out = P.inv_mulaw_quantize(np.array(waveform),
                                           quantization_channels)
                write_wav(out, hparams.sample_rate, wav_out_path)

        # Introduce a newline to clear the carriage return from the progress.
        print()
        # Save the result as a wav file.
        if wav_out_path:
            out = P.inv_mulaw_quantize(
                np.array(waveform).astype(np.int16), quantization_channels)
            # out = P.inv_mulaw_quantize(np.asarray(waveform), quantization_channels)
            write_wav(out, hparams.sample_rate, wav_out_path)
    end = time.time()
    print('Finished generating.')
    print('It took %.2f seconds' % (end - start))
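
`one_hot` above turns a sequence of integer phone labels into a `(time, num_labels)` matrix so the triphone context (previous, current, next phone) can be concatenated along the channel axis. A minimal NumPy sketch, assuming labels in `[0, num_labels)`:

import numpy as np

def one_hot(labels, num_labels):
    labels = np.asarray(labels, dtype=np.int64)
    out = np.zeros((len(labels), num_labels), dtype=np.float32)
    out[np.arange(len(labels)), labels] = 1.0   # set one column per frame
    return out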
Example #7
train_sample = "train_samples/saber.wav"
parameters = "model_parameters/saber_10-2-128-32-8"
layers = 10
blocks = 2
classes = 128
hidden_channels = 32
kernel_size = 8

use_cuda = torch.cuda.is_available()

# In[3]:

model = WaveNetModel(num_layers=layers,
                     num_blocks=blocks,
                     num_classes=classes,
                     hidden_channels=hidden_channels,
                     kernel_size=kernel_size)

if use_cuda:
    model.cuda()
    print("use cuda")

model.load_state_dict(torch.load(parameters))
print("parameter count: ", model.parameter_count())

data = WaveNetData(train_sample,
                   input_length=model.scope,
                   target_length=model.last_block_scope,
                   num_classes=model.num_classes,
                   cuda=use_cuda)
Example #8
def generate_subband(subband_id,
                     checkpoint,
                     local_condition,
                     global_condition,
                     subband_queue,
                     temperature=TEMPERATURE):
    print('generating subband d%d' % subband_id)
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,
                                            gpu_options=tf.GPUOptions(
                                                allow_growth=True)))
    restore_net_start = time.time()
    net = WaveNetModel(
        batch_size=1,
        dilations=hparams.dilations,
        filter_width=hparams.filter_width,
        residual_channels=hparams.residual_channels,
        dilation_channels=hparams.dilation_channels,
        skip_channels=hparams.skip_channels,
        quantization_channels=hparams.quantization_channels,
        use_biases=hparams.use_biases,
        scalar_input=hparams.scalar_input,
        initial_filter_width=hparams.initial_filter_width,
        local_condition_channel=hparams.lc_channels,
        lc_initial_channels=hparams.lc_initial_channels,
        upsample_conditional_features=hparams.upsample_conditional_features,
        upsample_factor=hparams.upsample_factor,
        global_cardinality=hparams.global_cardinality,
        global_channel=hparams.global_channel,
        is_training=False)
    sess.run(tf.global_variables_initializer())
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)
    saver.restore(sess, checkpoint)
    restore_net_end = time.time()
    print('Restore is done successfully for d%d!' % subband_id)
    print('It took %.2f sec.' % (restore_net_end - restore_net_start))

    samples = tf.placeholder(tf.int32)  # 8 subbands
    local_ph = tf.placeholder(tf.float32,
                              shape=(1, net.local_condition_channel))
    subband = 'd%d' % (subband_id + 1)
    next_sample = net.predict_proba_incremental(subband, samples, local_ph,
                                                global_condition)
    sess.run(net.init_ops)
    # Silence with a single random sample at the end.
    waveform = [net.quantization_channels // 2] * (net.receptive_field - 1)
    waveform.append(np.random.randint(net.quantization_channels))

    sample_len = local_condition.shape[0]
    for step in range(0, sample_len):
        outputs = [next_sample]
        outputs.extend(net.push_ops)
        window = waveform[-1]

        # Run the WaveNet to predict the next sample.
        prediction = sess.run(outputs,
                              feed_dict={
                                  samples: window,
                                  local_ph: local_condition[step:step + 1, :]
                              })[0]
        # Scale prediction distribution using temperature.
        # prediction = np.random.randint(0, 255, 256)
        np.seterr(divide='ignore')
        scaled_prediction = np.log(prediction) / temperature
        scaled_prediction = (scaled_prediction -
                             np.logaddexp.reduce(scaled_prediction))
        scaled_prediction = np.exp(scaled_prediction)
        np.seterr(divide='warn')
        # print(quantization_channels, scaled_prediction)
        sample = np.random.choice(np.arange(net.quantization_channels),
                                  p=scaled_prediction)

        waveform.append(sample)
    subband_queue.put([subband_id, waveform])
    subband_queue.task_done()
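
Note the `put` immediately followed by `task_done`: each worker registers its finished subband and immediately decrements the queue's unfinished-task counter, so the caller's `subband_q.join()` (see the next example) acts as a completion barrier. A stripped-down sketch of the pattern; as in the original, it assumes every worker calls `put` before the main thread could observe a zero task count:

import queue
import threading

def worker(i, results):
    results.put((i, i * i))   # put() increments the unfinished-task count...
    results.task_done()       # ...and task_done() marks the item handled

results = queue.Queue()
threads = [threading.Thread(target=worker, args=(i, results))
           for i in range(8)]
for t in threads:
    t.start()
results.join()                # returns once all workers have put + task_done
print(sorted(results.get() for _ in range(8)))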
Example #9
def main():
    start = time.time()
    os.environ["CUDA_VISIBLE_DEVICES"] = GPU
    args = get_arguments()
    os.makedirs(args.wav_out_path, exist_ok=True)
    if args.hparams is not None:
        hparams.parse(args.hparams)
    if not hparams.gc_enable:
        hparams.global_cardinality = None
        hparams.global_channel = None
    print(hparams_debug_string())

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,
                                            gpu_options=tf.GPUOptions(
                                                allow_growth=True)))
    restore_net_start = time.time()
    net = WaveNetModel(
        batch_size=1,
        dilations=hparams.dilations,
        filter_width=hparams.filter_width,
        residual_channels=hparams.residual_channels,
        dilation_channels=hparams.dilation_channels,
        skip_channels=hparams.skip_channels,
        quantization_channels=hparams.quantization_channels,
        use_biases=hparams.use_biases,
        scalar_input=hparams.scalar_input,
        initial_filter_width=hparams.initial_filter_width,
        local_condition_channel=hparams.lc_channels,
        lc_initial_channels=hparams.lc_initial_channels,
        upsample_conditional_features=hparams.upsample_conditional_features,
        upsample_factor=hparams.upsample_factor,
        global_cardinality=hparams.global_cardinality,
        global_channel=hparams.global_channel,
        is_training=False)
    sess.run(tf.global_variables_initializer())
    variables_to_restore = {
        var.name[:-2]: var
        for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)
    }
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)
    print('Restore is done successfully!')
    restore_net_end = time.time()
    print('It took %.2f sec.' % (restore_net_end - restore_net_start))

    tmp_global_condition = None
    upsample_factor = audio.get_hop_size()

    generate_list = []
    with open(args.eval_txt, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            if line.strip():
                line = line.strip().split('|')
                target = np.load(os.path.join(hparams.NPY_DATAROOT, line[0]))
                npy_path = os.path.normpath(
                    os.path.join(hparams.NPY_DATAROOT, line[1]))
                tmp_local_condition = one_hot(
                    np.load(npy_path), num_labels=hparams.lc_initial_channels)
                h = tmp_local_condition
                tmp_local_condition = lc_averaging(tmp_local_condition)
                tmp_local_condition = tmp_local_condition.astype(np.float32)
                if len(line) == 5:
                    tmp_global_condition = int(line[4])
                if hparams.global_channel is None:
                    tmp_global_condition = None
                generate_list.append(
                    (tmp_local_condition, tmp_global_condition, line[1],
                     target, h))

    for local_condition, global_condition, npy_path, target, h in generate_list:
        wav_id = npy_path.split('/')[-1]
        wav_id = wav_id.split('-phone')[0]
        wav_id = wav_id.replace("-", "_")
        mat_out_path = os.path.normpath(
            os.path.join(args.wav_out_path, "{}.mat".format(wav_id)))

        if not hparams.upsample_conditional_features and hparams.lc_conv_layers < 1:
            local_condition = np.repeat(local_condition,
                                        upsample_factor,
                                        axis=0)
        elif hparams.upsample_conditional_features:
            local_condition = np.expand_dims(local_condition, 0)
            local_condition = net.create_upsample(local_condition)
            local_condition = tf.squeeze(local_condition,
                                         [0]).eval(session=sess)
        else:
            local_condition = np.expand_dims(local_condition, 0)
            local_condition, h_list = net._create_lc_conv_layer(
                local_condition)
            if hparams.lc_average and not hparams.lc_overlap:
                local_condition = lc_upsampling_by_repeat(
                    local_condition, hparams.average_window_len)
            local_condition = tf.squeeze(local_condition).eval(session=sess)
        print('Conditional features are made successfully!')

        # option1: ----- without threading ----
        # for i in range(8):
        #     generate_subband(i, args.checkpoint, local_condition, global_condition, subband_q, args.temperature)

        # option2: ----- with threading ----
        subband_q = queue.Queue()
        threads = [
            threading.Thread(target=generate_subband,
                             args=(i, args.checkpoint, local_condition,
                                   global_condition, subband_q,
                                   args.temperature)) for i in range(8)
        ]
        for t in threads:
            t.daemon = True  # Thread will close when parent quits.
            t.start()
        subband_q.join()
        out1 = [None] * 8
        predicted1_256 = [None] * 8
        for _ in range(8):
            [subband_id, x] = subband_q.get()
            predicted1_256[subband_id] = x
            out1[subband_id] = P.inv_mulaw_quantize(
                np.array(x).astype(np.int16), net.quantization_channels)
            out1[subband_id] = denormalize(out1[subband_id],
                                           'd%d' % (subband_id + 1))

        if mat_out_path:
            sio.savemat(mat_out_path, {'predicted1': out1, 'target': target})
    end = time.time()
    print('Finished generating. Elapsed time: %.3f sec.' % (end - start))
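
`P.inv_mulaw_quantize` undoes the mu-law companding applied at feature-extraction time (the `P` module is presumably `nnmnkwii.preprocessing`). In case the helper is unfamiliar, a hedged sketch of the usual inverse transform:

import numpy as np

def inv_mulaw_quantize(y, quantization_channels=256):
    # y: integer class labels in [0, quantization_channels).
    mu = quantization_channels - 1
    x = 2.0 * np.asarray(y, dtype=np.float64) / mu - 1.0   # back to [-1, 1]
    # Expand the mu-law compression.
    return np.sign(x) * ((1.0 + mu) ** np.abs(x) - 1.0) / mu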