示例#1
0
    def _data_dep_init():
        """Build the data-dependent initialization pass and return its runner.

        One batch is fetched with ``reader.get_init_batch`` and the
        ``init=True`` feed-forward graph of ``pwn`` is built on a placeholder
        shaped like that batch's ``'mel'`` entry.

        Returns:
            A function taking a session. It evaluates the init graph on the
            fetched batch and reports the ``'x'``, ``'mean_tot'`` and
            ``'scale_tot'`` outputs through ``_init_logging``.
        """
        init_batch = reader.get_init_batch(
            pwn.train_path,
            batch_size=args.total_batch_size,
            seq_len=pwn.wave_length)
        mel_batch = init_batch['mel']

        graph_inputs = {
            'mel': tf.placeholder(dtype=tf.float32, shape=mel_batch.shape),
        }
        init_fetches = pwn.feed_forward(graph_inputs, init=True)

        def callback(session):
            tf.logging.info('Running data dependent initialization for '
                            'weight normalization')
            # The placeholder is looked up when the callback runs, so the
            # feed always targets whatever the dict holds under 'mel' then.
            feed = {graph_inputs['mel']: mel_batch}
            results = session.run(init_fetches, feed_dict=feed)

            # Pull every output first so a missing key fails before any
            # statistics are logged.
            new_x, mean, scale = (
                results[key] for key in ('x', 'mean_tot', 'scale_tot'))
            for value, label in ((new_x, 'new_x'),
                                 (mean, 'mean'),
                                 (scale, 'scale')):
                _init_logging(value, label)
            tf.logging.info('Done data dependent initialization for '
                            'weight normalization')

        return callback
示例#2
0
    def _data_dep_init():
        """Build the initial-statistics pass and return its runner.

        One batch is fetched with ``reader.get_init_batch``; its ``'wav'`` and
        ``'mel'`` entries are fed through placeholders into ``wn``'s encoder
        and its ``init=True`` feed-forward graph.

        Returns:
            A function taking a session. It evaluates the init graph on the
            fetched batch and logs statistics derived from ``'out_params'``
            according to ``wn.loss_type``.
        """
        # slim.learning.train calls init_fn before the queue runners are
        # started, so feeding this pass from the queue-backed `input_dict`
        # would deadlock. A batch is fetched up front and fed through
        # placeholders instead.
        init_batch = reader.get_init_batch(
            wn.train_path,
            batch_size=args.total_batch_size,
            seq_len=wn.wave_length)
        wav_batch = init_batch['wav']
        mel_batch = init_batch['mel']

        graph_inputs = {
            name: tf.placeholder(dtype=tf.float32, shape=batch.shape)
            for name, batch in (('wav', wav_batch), ('mel', mel_batch))
        }
        graph_inputs.update(wn.encode_signal(graph_inputs))
        init_fetches = wn.feed_forward(graph_inputs, init=True)

        def callback(session):
            tf.logging.info('Calculate initial statistics.')
            feed = {
                graph_inputs['wav']: wav_batch,
                graph_inputs['mel']: mel_batch,
            }
            results = session.run(init_fetches, feed_dict=feed)
            out_params = results['out_params']

            # Each branch collects (value, label) pairs; they are logged
            # afterwards. Any other loss type logs no statistics.
            to_log = ()
            if wn.loss_type == 'mol':
                # Three equal parts along axis 2; the first is not logged.
                pieces = np.split(out_params, 3, axis=2)
                mean, log_scale = pieces[1], pieces[2]
                # The log-scale is floored at -7.0 before exponentiation.
                scale = np.exp(np.maximum(log_scale, -7.0))
                to_log = ((mean, 'mean'), (scale, 'scale'))
            elif wn.loss_type == 'gauss':
                mean, log_std = np.split(out_params, 2, axis=2)
                std = np.exp(np.maximum(log_std, -7.0))
                to_log = ((mean, 'mean'), (std, 'std'))

            for value, label in to_log:
                _init_logging(value, label)
            tf.logging.info('Done Calculate initial statistics.')

        return callback
示例#3
0
def test_np_reader():
    """Dump the waveforms of one init batch to WAV files for inspection.

    Reads a batch from ``tfr_path`` with ``reader.get_init_batch``, recreates
    the ``test_reader`` directory from scratch and writes every entry of the
    batch's ``'wav'`` field into it with ``sr=16000``.
    """
    batch = reader.get_init_batch(tfr_path, 4, first_n=10)
    waves = batch['wav']

    out_dir = 'test_reader'
    # Start from an empty directory so files from earlier runs do not linger.
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir, exist_ok=True)

    name_template = '{}/test_reader-{}.wav'
    for index, wave in enumerate(waves):
        target = name_template.format(out_dir, index)
        librosa.output.write_wav(target, wave, sr=16000)
def spec_feat_mean_std(train_path, feat_fn=lambda x: tf.pow(tf.abs(x), 2.0),
                       batch_size=4096, seq_len=7680, first_n=10000):
    """Compute the mean and std of an STFT-derived feature over one batch.

    A batch of waveforms is fetched with ``reader.get_init_batch``, pushed
    through ``_tf_stft`` and ``feat_fn`` in a private graph, and the
    evaluated feature is reduced with NumPy.

    Args:
        train_path: Data location forwarded to ``reader.get_init_batch``.
        feat_fn: Callable applied to the output of ``_tf_stft`` to produce
            the feature. The default squares the absolute value.
        batch_size: Number of sequences requested from the reader; also the
            first dimension of the input placeholder.
        seq_len: Sequence length requested from the reader; also the second
            dimension of the input placeholder.
        first_n: Forwarded unchanged to ``reader.get_init_batch``.

    Returns:
        A ``(mean, std)`` tuple, each reduced over axes 0 and 1 of the
        evaluated feature.
    """
    local_graph = tf.Graph()
    with local_graph.as_default():
        input_vals = reader.get_init_batch(train_path,
                                           batch_size=batch_size,
                                           seq_len=seq_len,
                                           first_n=first_n)['wav']
        # The placeholder shape is tied to the same values requested from
        # the reader, so the two cannot drift apart.
        ph = tf.placeholder(dtype=np.float32, shape=[batch_size, seq_len])
        feat = feat_fn(_tf_stft(ph))

    tf.logging.info('Calculating mean and std for stft feat.')
    # device_count={'GPU': 0} keeps this one-off computation off the GPU.
    config = tf.ConfigProto(device_count={'GPU': 0})
    # Use the session as a context manager so it is closed (and its
    # resources released) even if evaluation raises.
    with tf.Session(config=config, graph=local_graph) as sess:
        feat_val = sess.run(feat, feed_dict={ph: input_vals})
    mean_val = np.mean(feat_val, axis=(0, 1))
    std_val = np.std(feat_val, axis=(0, 1))
    tf.logging.info('Done calculating mean and std for stft feat.')

    return mean_val, std_val