-
Notifications
You must be signed in to change notification settings - Fork 0
/
inputs.py
171 lines (143 loc) · 6.93 KB
/
inputs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""Input pipeline for DCASE 2018 Task 2 Baseline models."""
import functools
import os
import numpy as np
from scipy.io import wavfile
import tensorflow as tf
import librosa
from tensorflow.contrib.framework.python.ops import audio_ops as tf_audio
# Default sample rate of the input clips, in Hz (44.1 kHz). The STFT window/hop
# sizes and the mel filterbank in this module are derived from this value.
# train_input() accepts a `sample_rate` argument intended to select a
# different rate.
SAMPLE_RATE = 44100
def clip_to_waveform(clip, clip_dir=None):
    """Decodes a WAV clip into a waveform tensor.

    Args:
        clip: name of the WAV file (string or scalar string tensor).
        clip_dir: directory containing the clip. If None, `clip` is used as
            the path as-is.

    Returns:
        The decoded waveform with the channel axis squeezed out. The values
        lie in [-1, +1]. Evaluating the tensor fails if the clip's sample
        rate differs from SAMPLE_RATE or if the clip is not mono.
    """
    if clip_dir is None:
        # tf.string_join cannot take None; treat the clip name as the path.
        clip_path = clip
    else:
        clip_path = tf.string_join([clip_dir, clip], separator=os.sep)

    # Decode the WAV-format clip into a waveform tensor where
    # the values lie in [-1, +1].
    clip_data = tf.read_file(clip_path)
    waveform, sr = tf_audio.decode_wav(clip_data)

    # Assert that the clip has the expected sample rate. The frame sizes used
    # downstream are computed from SAMPLE_RATE, so a clip recorded at another
    # rate would silently produce mis-scaled features.
    check_sr = tf.assert_equal(sr, SAMPLE_RATE)
    # and check that it is mono.
    check_channels = tf.assert_equal(tf.shape(waveform)[1], 1)

    # Tie both checks to the returned tensor so they run whenever it is used.
    with tf.control_dependencies([tf.group(check_sr, check_channels)]):
        return tf.squeeze(waveform)
def clip_to_log_mel_examples(clip, clip_dir=None, hparams=None):
    """Turns a WAV clip into a batch of framed log mel spectrogram examples.

    Args:
        clip: name of the WAV file, forwarded to clip_to_waveform().
        clip_dir: directory containing the clip, forwarded to
            clip_to_waveform().
        hparams: hyperparameter object providing stft_window_seconds,
            stft_hop_seconds, mel_bands, mel_min_hz, mel_max_hz,
            mel_log_offset, example_window_seconds and example_hop_seconds.

    Returns:
        Tensor of log mel spectrogram patches, framed along the time axis
        (axis 0) of the clip's log mel spectrogram.
    """
    samples = clip_to_waveform(clip, clip_dir=clip_dir)

    # STFT geometry in samples; the FFT size is the window length rounded up
    # to the next power of two.
    win_len = int(round(SAMPLE_RATE * hparams.stft_window_seconds))
    hop_len = int(round(SAMPLE_RATE * hparams.stft_hop_seconds))
    n_fft = 2 ** int(np.ceil(np.log(win_len) / np.log(2.0)))

    # Magnitude spectrogram. Note that tf.contrib.signal.stft() uses a
    # periodic Hann window by default.
    stft = tf.contrib.signal.stft(
        signals=samples,
        frame_length=win_len,
        frame_step=hop_len,
        fft_length=n_fft)
    magnitudes = tf.abs(stft)

    # Project the linear-frequency bins onto the mel scale, then compress
    # with a log; the offset keeps the log finite for silent bins.
    mel_weights = tf.contrib.signal.linear_to_mel_weight_matrix(
        num_mel_bins=hparams.mel_bands,
        num_spectrogram_bins=n_fft // 2 + 1,
        sample_rate=SAMPLE_RATE,
        lower_edge_hertz=hparams.mel_min_hz,
        upper_edge_hertz=hparams.mel_max_hz)
    log_mel = tf.log(tf.matmul(magnitudes, mel_weights) + hparams.mel_log_offset)

    # Slice the spectrogram into fixed-size examples. Window and hop are
    # converted from seconds to spectrogram frames.
    frames_per_second = 1 / hparams.stft_hop_seconds
    patch_len = int(round(frames_per_second * hparams.example_window_seconds))
    patch_hop = int(round(frames_per_second * hparams.example_hop_seconds))
    return tf.contrib.signal.frame(
        signal=log_mel,
        frame_length=patch_len,
        frame_step=patch_hop,
        axis=0)
def record_to_labeled_log_mel_examples(csv_record, clip_dir=None, hparams=None,
                                       label_class_index_table=None, num_classes=None):
    """Builds labeled log mel spectrum examples from one training record.

    Args:
        csv_record: a line from the train.csv file downloaded from Kaggle.
        clip_dir: path to a directory containing clips referenced by
            csv_record.
        hparams: tf.contrib.training.HParams object containing model
            hyperparameters.
        label_class_index_table: a lookup table that represents the class map.
        num_classes: number of classes in the class map.

    Returns:
        features: Tensor containing a batch of log mel spectrum examples.
        labels: Tensor containing corresponding labels in 1-hot format.
    """
    # The record has three columns; only the first two (clip name and label)
    # are used here.
    columns = tf.decode_csv(csv_record, record_defaults=[[''], [''], [0]])
    clip, label = columns[0], columns[1]

    features = clip_to_log_mel_examples(clip, clip_dir=clip_dir, hparams=hparams)

    # Every example cut from the clip gets the clip's label, so the 1-hot row
    # is repeated once per example.
    onehot = tf.one_hot(label_class_index_table.lookup(label), num_classes)
    labels = tf.tile([onehot], [tf.shape(features)[0], 1])
    return features, labels
def get_class_map(class_map_path):
    """Constructs a class label lookup table from a class map.

    The class map is read as a comma-delimited text file with the class index
    in column 0 and the class label in column 1.

    Args:
        class_map_path: path to the class map file.

    Returns:
        label_class_index_table: hash table mapping a label string to its
            int32 class index; labels not in the map look up as -1.
        num_classes: number of lines in the class map file.
    """
    label_class_index_table = tf.contrib.lookup.HashTable(
        tf.contrib.lookup.TextFileInitializer(
            filename=class_map_path,
            key_dtype=tf.string, key_index=1,
            value_dtype=tf.int32, value_index=0,
            delimiter=','),
        default_value=-1)
    # Count lines while streaming; the context manager closes the file
    # promptly instead of leaving the handle to the garbage collector.
    with open(class_map_path) as class_map_file:
        num_classes = sum(1 for _ in class_map_file)
    return label_class_index_table, num_classes
def train_input(train_csv_path=None, train_clip_dir=None, class_map_path=None, hparams=None, sample_rate=None):
    """Creates training input pipeline.

    Args:
        train_csv_path: path to the train.csv file provided by Kaggle.
        train_clip_dir: path to the unzipped audio_train/ directory from the
            audio_train.zip file provided by Kaggle.
        class_map_path: path to the class map prepared from the training data.
        hparams: tf.contrib.training.HParams object containing model
            hyperparameters.
        sample_rate: sample rate of the clips in Hz. If given, it replaces the
            module-level SAMPLE_RATE used by the feature extraction code; if
            None, the current SAMPLE_RATE is kept.

    Returns:
        features: Tensor containing a batch of log mel spectrum examples.
        labels: Tensor containing corresponding labels in 1-hot format.
        num_classes: number of classes.
        iter_init: an initializer op for the iterator that provides features
            and labels, to be run before the input pipeline is read.
    """
    # The feature extraction functions read the module-level SAMPLE_RATE, so
    # the override must rebind the global (a plain assignment would only
    # create an unused local). It has to happen before dataset.map() below,
    # which is where those functions are invoked to build the graph.
    global SAMPLE_RATE
    if sample_rate is not None:
        SAMPLE_RATE = sample_rate

    label_class_index_table, num_classes = get_class_map(class_map_path)

    dataset = tf.data.TextLineDataset(train_csv_path)
    # Skip the header.
    dataset = dataset.skip(1)
    # Shuffle the list of clips. 10K is big enough to cover all clips.
    dataset = dataset.shuffle(buffer_size=10000)

    # Map each clip to a batch of framed log mel spectrum examples.
    dataset = dataset.map(
        map_func=functools.partial(
            record_to_labeled_log_mel_examples,
            clip_dir=train_clip_dir,
            hparams=hparams,
            label_class_index_table=label_class_index_table,
            num_classes=num_classes),
        # 4 is empirically chosen to use 4 logical CPU cores. Adjust as
        # needed if more or less resources are available.
        num_parallel_calls=4)

    # Unbatch so that we have a dataset of individual examples that we can then
    # shuffle for training. 20K should be enough to allow shuffling across a
    # few hundred clips which are already in random order.
    dataset = dataset.apply(tf.contrib.data.unbatch())
    dataset = dataset.shuffle(buffer_size=20000)

    # Run until we have completed 100 epochs of the training set.
    dataset = dataset.repeat(100)

    # Batch examples.
    dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size=hparams.batch_size))

    # Let the input pipeline run a few batches ahead so that the model is
    # never starved of data.
    dataset = dataset.prefetch(10)

    iterator = dataset.make_initializable_iterator()
    features, labels = iterator.get_next()

    return features, labels, num_classes, iterator.initializer