def encode_signal(self, inputs, add_noise=False):
    ###
    # Encode the source with 8-bit Mu-Law or just use 16-bit signal.
    ###
    quant_chann = self.quant_chann
    use_mu_law = self.use_mu_law
    x = inputs['wav']

    if use_mu_law:
        x_quantized = utils.mu_law(x)
        x_scaled = tf.cast(x_quantized, tf.float32) / (quant_chann / 2.)
        real_targets = x_scaled
        cate_targets = tf.cast(x_quantized, tf.int32) + tf.cast(
            quant_chann / 2., tf.int32)
    else:
        x_quantized = utils.cast_quantize(x, quant_chann)
        x_scaled = x
        real_targets = x
        cate_targets = tf.cast(x_quantized, tf.int32) + tf.cast(
            quant_chann / 2., tf.int32)

    if add_noise:
        # Only used when the wavenet is trained as a teacher.
        x_scaled += tf.random_normal(shape=x_scaled.get_shape(),
                                     mean=0.0, stddev=0.1)

    return {'wav_scaled': x_scaled,
            'real_targets': real_targets,
            'cate_targets': cate_targets}
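# `utils.mu_law` is defined elsewhere in the repo. Below is a minimal numpy
# sketch of the 8-bit mu-law companding it is assumed to perform (mu = 255,
# floats in [-1, 1] mapped to integer bins in [-128, 127]); the helper name
# `_mu_law_sketch` and the clipping at the upper edge are illustrative, not
# the actual `utils` code.
def _mu_law_sketch(x, mu=255):
    import numpy as np
    # Logarithmic amplitude compression, preserving sign.
    compressed = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
    # Quantize to integer bins; encode_signal above rescales these by
    # quant_chann / 2 and shifts them to non-negative class indices.
    return np.clip(np.floor(compressed * 128), -128, 127).astype(np.int32)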
def sample(self, inputs):
    """Build the graph for incremental sample generation.

    Args:
      inputs: A dict of inputs. Should contain 'wav' and 'encoding'.

    Returns:
      A dict of outputs that includes the 'sample', the 'init_ops' and
      the 'push_ops'.
    """
    batch_size = self.batch_size
    num_stages = self.hparams.num_stages
    num_layers = self.hparams.num_layers
    filter_length = self.hparams.filter_length
    width = self.hparams.width
    skip_width = self.hparams.skip_width
    use_mu_law = self.use_mu_law
    use_weight_norm = self.use_weight_norm
    quant_chann = self.quant_chann
    out_width = self.out_width
    deconv_width = self.hparams.deconv_width
    loss_type = self.loss_type
    gate_width = 2 * width if self.double_gate_width else width
    use_dropout = self.use_dropout

    # Here the mel information is the pre-computed trans_conv_stack output
    # ('encoding'); Wavenet.feed_forward computes it from the raw 'mel'
    # input instead.
    mel_en = inputs['encoding']  # [batch_size, deconv_width]
    mel_en = tf.expand_dims(mel_en, 1)  # [batch_size, 1, deconv_width]

    x = inputs['wav']  # [batch_size, 1]
    if use_mu_law:
        # Encode the source with 8-bit Mu-Law.
        x_quantized = utils.mu_law(x)
        x_scaled = tf.cast(x_quantized, tf.float32) / (quant_chann / 2.)
    else:
        x_scaled = x
    x_scaled = tf.expand_dims(x_scaled, 2)  # [batch_size, 1, 1]

    init_ops, push_ops = [], []

    ###
    # The WaveNet Decoder.
    ###
    l = x_scaled
    l, inits, pushs = masked.causal_linear(x=l,
                                           n_inputs=1,
                                           n_outputs=width,
                                           name='conv_start',
                                           rate=1,
                                           batch_size=batch_size,
                                           filter_length=filter_length,
                                           use_weight_norm=use_weight_norm)
    if use_dropout:
        l = tf.layers.dropout(l, rate=0.2, training=False,
                              name='conv_dropout')
    for init in inits:
        init_ops.append(init)
    for push in pushs:
        push_ops.append(push)

    # Set up skip connections.
    s = masked.linear(l, width, skip_width, name='skip_start',
                      use_weight_norm=use_weight_norm)

    # Residual blocks with skip connections.
    for i in range(num_layers):
        dilation = 2 ** (i % num_stages)

        # Dilated masked cnn.
        d, inits, pushs = masked.causal_linear(
            x=l,
            n_inputs=width,
            n_outputs=gate_width,
            name='dilated_conv_%d' % (i + 1),
            rate=dilation,
            batch_size=batch_size,
            filter_length=filter_length,
            use_weight_norm=use_weight_norm)
        for init in inits:
            init_ops.append(init)
        for push in pushs:
            push_ops.append(push)

        # Local conditioning.
        d += masked.linear(mel_en, deconv_width, gate_width,
                           name='mel_cond_%d' % (i + 1),
                           use_weight_norm=use_weight_norm)

        # Gated cnn.
        assert d.get_shape().as_list()[2] % 2 == 0
        m = d.get_shape().as_list()[2] // 2
        d = tf.sigmoid(d[:, :, :m]) * tf.tanh(d[:, :, m:])

        # Residuals.
        l += masked.linear(d, gate_width // 2, width,
                           name='res_%d' % (i + 1),
                           use_weight_norm=use_weight_norm)
        # Skips.
        s += masked.linear(d, gate_width // 2, skip_width,
                           name='skip_%d' % (i + 1),
                           use_weight_norm=use_weight_norm)
        # Dropout.
        if use_dropout:
            l = tf.layers.dropout(l, rate=0.2, training=False,
                                  name='res_dropout_%d' % (i + 1))

    s = tf.nn.relu(s)
    s = (masked.linear(s, skip_width, skip_width, name='out1',
                       use_weight_norm=use_weight_norm) +
         masked.linear(mel_en, deconv_width, skip_width,
                       name='mel_cond_out1',
                       use_weight_norm=use_weight_norm))
    s = tf.nn.relu(s)
    out = masked.linear(s, skip_width, out_width, name='out2',
                        use_weight_norm=use_weight_norm)  # [batch_size, 1, out_width]

    if loss_type == 'ce':
        sample = loss_func.ce_sample(out, quant_chann)
    elif loss_type == 'mol':
        sample = loss_func.mol_sample(out, quant_chann)
    elif loss_type == 'gauss':
        sample = loss_func.gauss_sample(out, quant_chann)
    else:
        raise ValueError('[{}] loss is not supported.'.format(loss_type))

    return {'init_ops': init_ops, 'push_ops': push_ops, 'sample': sample}
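# A usage sketch for the incremental sampler above. Assumptions (not part of
# the original code): `net` is an instance of this class, `wav_ph` and
# `mel_ph` are the placeholders wired into `inputs`, and `sess` is a
# tf.Session with variables restored. Running `init_ops` once primes the
# per-layer sample queues; after that, each step generates one sample and
# `push_ops` advances the queues by one timestep, so generation costs O(1)
# per sample instead of re-running the full dilation stack.
def _sampling_loop_sketch(sess, net, wav_ph, mel_ph, mel_encoding,
                          batch_size, total_length):
    import numpy as np
    fetches = net.sample({'wav': wav_ph, 'encoding': mel_ph})
    sess.run(fetches['init_ops'])  # prime the causal queues once
    audio = np.zeros((batch_size, 1), dtype=np.float32)
    generated = []
    for _ in range(total_length):
        # One network step: emit a sample, then push it into the queues and
        # feed it back as the next input.
        audio, _ = sess.run([fetches['sample'], fetches['push_ops']],
                            feed_dict={wav_ph: audio,
                                       mel_ph: mel_encoding})
        generated.append(audio)
    return np.concatenate(generated, axis=1)  # [batch_size, total_length]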
def feed_forward(self, inputs):
    """Build the training graph for this configuration.

    Args:
      inputs: A dict of inputs. For training, should contain 'wav' and
        'mel'.

    Returns:
      A dict of outputs that includes the 'real_targets', the
      'cate_targets', the 'encoding' and the 'out_params'.
    """
    num_stages = self.hparams.num_stages
    num_layers = self.hparams.num_layers
    filter_length = self.hparams.filter_length
    width = self.hparams.width
    skip_width = self.hparams.skip_width
    use_mu_law = self.use_mu_law
    quant_chann = self.quant_chann
    out_width = self.out_width

    ###
    # The transposed-convolution stack for the mel feature.
    ###
    # wavenet inputs <- trans_conv (l2, s2) <- trans_conv (l1, s1) <- mel_ceps
    # win_len: l1 * s2 + (l2 - s2); win_shift: s1 * s2.
    # (l1, s1) = (40, 10), (l2, s2) = (80, 20) is a proper configuration:
    # it gives win_shift = 200 and win_len = 860, which is almost consistent
    # with the mel analysis frame shift (200) and frame length (800).
    mel = inputs['mel']
    ds_dict = self.deconv_stack({'mel': mel})
    mel_en = ds_dict['encoding']

    ###
    # Encode the source with 8-bit Mu-Law or just use 16-bit signal
    # (same logic as encode_signal above).
    ###
    x = inputs['wav']
    if use_mu_law:
        x_quantized = utils.mu_law(x)
        x_scaled = tf.cast(x_quantized, tf.float32) / (quant_chann / 2.)
        real_targets = x_scaled
        cate_targets = tf.cast(x_quantized, tf.int32) + tf.cast(
            quant_chann / 2., tf.int32)
    else:
        x_quantized = utils.cast_quantize(x, quant_chann)
        x_scaled = x
        real_targets = x
        cate_targets = tf.cast(x_quantized, tf.int32) + tf.cast(
            quant_chann / 2., tf.int32)
    x_scaled = tf.expand_dims(x_scaled, 2)

    ###
    # The WaveNet Decoder.
    ###
    l = masked.shift_right(x_scaled)
    l = masked.conv1d(l, num_filters=width, filter_length=filter_length,
                      name='startconv')

    # Set up skip connections.
    s = masked.conv1d(l, num_filters=skip_width, filter_length=1,
                      name='skip_start')

    # Residual blocks with skip connections.
    for i in range(num_layers):
        dilation = 2 ** (i % num_stages)
        d = masked.conv1d(l, num_filters=2 * width,
                          filter_length=filter_length,
                          dilation=dilation,
                          name='dilated_conv_%d' % (i + 1))
        c = masked.conv1d(mel_en, num_filters=2 * width, filter_length=1,
                          name='mel_cond_%d' % (i + 1))
        d = _condition(d, c)

        assert d.get_shape().as_list()[2] % 2 == 0
        m = d.get_shape().as_list()[2] // 2
        d_sigmoid = tf.sigmoid(d[:, :, :m])
        d_tanh = tf.tanh(d[:, :, m:])
        d = d_sigmoid * d_tanh

        l += masked.conv1d(d, num_filters=width, filter_length=1,
                           name='res_%d' % (i + 1))
        s += masked.conv1d(d, num_filters=skip_width, filter_length=1,
                           name='skip_%d' % (i + 1))

    s = tf.nn.relu(s)
    s = masked.conv1d(s, num_filters=skip_width, filter_length=1,
                      name='out1')
    c = masked.conv1d(mel_en, num_filters=skip_width, filter_length=1,
                      name='mel_cond_out1')
    s = _condition(s, c)
    s = tf.nn.relu(s)

    # When using the mol loss the model predicts a log_scale; this
    # initializer keeps the initial log_scale in a reasonably small range to
    # speed up convergence.
    final_kernel_init = (tf.truncated_normal_initializer(0.0, 0.01)
                         if self.loss_type == 'mol'
                         else tf.uniform_unit_scaling_initializer(1.0))
    out = masked.conv1d(s, num_filters=out_width, filter_length=1,
                        name='out2', kernel_initializer=final_kernel_init)

    return {
        'real_targets': real_targets,
        'cate_targets': cate_targets,
        'encoding': mel_en,
        'out_params': out
    }
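# `_condition` is referenced above but defined elsewhere. A minimal sketch of
# the broadcast-add it is assumed to perform, in the style of the
# conditioning helper in Magenta's NSynth WaveNet: each conditioning frame of
# `encoding` is added to the block of consecutive timesteps of `x` that it
# covers (length must be an integer multiple of enc_length). The name
# `_condition_sketch` is illustrative, not the actual helper.
def _condition_sketch(x, encoding):
    mb, length, channels = x.get_shape().as_list()
    enc_mb, enc_length, enc_channels = encoding.get_shape().as_list()
    assert enc_mb == mb and enc_channels == channels
    # Line each encoding frame up with length / enc_length consecutive
    # timesteps, then let broadcasting perform the addition.
    encoding = tf.reshape(encoding, [mb, enc_length, 1, channels])
    x = tf.reshape(x, [mb, enc_length, -1, channels])
    x += encoding
    x = tf.reshape(x, [mb, length, channels])
    x.set_shape([mb, length, channels])  # restore static shape inference
    return x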