コード例 #1
0
    def encode_signal(self, inputs, add_noise=False):
        """Encode the waveform into the network input and the training targets.

        The source is encoded with 8-bit Mu-Law when `self.use_mu_law` is set;
        otherwise the signal is used as-is and only quantized to build the
        categorical targets.

        Args:
          inputs: A dict of inputs, must contain the waveform under 'wav'.
          add_noise: If True, Gaussian noise (mean 0.0, stddev 0.1) is added to
            the scaled waveform. Only used when the wavenet is trained as a
            teacher.

        Returns:
          A dict with
            'wav_scaled': the scaled waveform (noised when `add_noise` is set),
            'real_targets': the scaled waveform without noise,
            'cate_targets': the int32 quantized waveform shifted by
              `quant_chann / 2`.
        """
        quant_chann = self.quant_chann
        half_chann = quant_chann / 2.

        x = inputs['wav']
        if self.use_mu_law:
            x_quantized = utils.mu_law(x)
            x_scaled = tf.cast(x_quantized, tf.float32) / half_chann
        else:
            x_quantized = utils.cast_quantize(x, quant_chann)
            x_scaled = x

        # Targets are taken before any noise is added.
        real_targets = x_scaled
        # NOTE(review): the shift presumably maps signed quantization levels to
        # non-negative class ids -- confirm against utils.mu_law/cast_quantize.
        cate_targets = (tf.cast(x_quantized, tf.int32) +
                        tf.cast(half_chann, tf.int32))

        if add_noise:
            # Use the dynamic shape so this also works when some dimension
            # (e.g. batch or time) is not statically known; a partially
            # defined static shape cannot be passed to tf.random_normal.
            noise = tf.random_normal(shape=tf.shape(x_scaled),
                                     mean=0.0,
                                     stddev=0.1)
            x_scaled = x_scaled + noise

        return {'wav_scaled': x_scaled,
                'real_targets': real_targets,
                'cate_targets': cate_targets}
コード例 #2
0
    def encode_signal(self, inputs):
        """Turn the raw waveform into the network input and its targets.

        With `self.use_mu_law` the source is encoded with 8-bit Mu-Law and
        rescaled by `quant_chann / 2`; otherwise the signal is passed through
        unchanged and only quantized for the categorical targets.

        Args:
          inputs: A dict of inputs, must contain the waveform under 'wav'.

        Returns:
          A dict with 'wav_scaled', 'real_targets' (the same tensor as
          'wav_scaled') and 'cate_targets' (int32 quantized waveform shifted
          by `quant_chann / 2`).
        """
        quant_chann = self.quant_chann
        wav = inputs['wav']

        if self.use_mu_law:
            quantized = utils.mu_law(wav)
            scaled = tf.cast(quantized, tf.float32) / (quant_chann / 2.)
        else:
            quantized = utils.cast_quantize(wav, quant_chann)
            scaled = wav

        # Both encodings build the categorical targets the same way.
        cate_offset = tf.cast(quant_chann / 2., tf.int32)
        cate_targets = tf.cast(quantized, tf.int32) + cate_offset

        outputs = {}
        outputs['wav_scaled'] = scaled
        outputs['real_targets'] = scaled
        outputs['cate_targets'] = cate_targets
        return outputs
コード例 #3
0
    def sample(self, inputs):
        """Build the incremental sampling graph for this configuration.

        Args:
          inputs: A dict of inputs. Must contain 'encoding' (the mel
            conditioning, i.e. the trans_conv_stack output) and 'wav'.

        Returns:
          A dict with 'init_ops' and 'push_ops' (collected from every
          `masked.causal_linear` layer) and 'sample' (the output of the
          sampler selected by `self.loss_type`).

        Raises:
          ValueError: if `self.loss_type` is not one of 'ce', 'mol', 'gauss'.
        """
        hparams = self.hparams
        batch_size = self.batch_size
        num_stages = hparams.num_stages
        num_layers = hparams.num_layers
        filter_length = hparams.filter_length
        width = hparams.width
        skip_width = hparams.skip_width
        use_mu_law = self.use_mu_law
        use_weight_norm = self.use_weight_norm
        quant_chann = self.quant_chann
        out_width = self.out_width
        deconv_width = hparams.deconv_width
        loss_type = self.loss_type
        gate_width = 2 * width if self.double_gate_width else width
        use_dropout = self.use_dropout

        # The mel information is the trans_conv_stack output here, which is
        # different from wavenet.feed_forward.
        # [batch_size, deconv_width] -> [batch_size, 1, deconv_width]
        mel_en = tf.expand_dims(inputs['encoding'], 1)

        wav = inputs['wav']  # [batch_size, 1]
        if use_mu_law:
            # Encode the source with 8-bit Mu-Law.
            wav_quantized = utils.mu_law(wav)
            wav_scaled = tf.cast(wav_quantized, tf.float32) / (quant_chann / 2)
        else:
            wav_scaled = wav
        wav_scaled = tf.expand_dims(wav_scaled, 2)  # [batch_size, 1, 1]

        init_ops = []
        push_ops = []

        ###
        # The WaveNet Decoder.
        ###
        net, inits, pushs = masked.causal_linear(
            x=wav_scaled,
            n_inputs=1,
            n_outputs=width,
            name='conv_start',
            rate=1,
            batch_size=batch_size,
            filter_length=filter_length,
            use_weight_norm=use_weight_norm)
        init_ops.extend(inits)
        push_ops.extend(pushs)
        if use_dropout:
            net = tf.layers.dropout(
                net, rate=0.2, training=False, name='conv_dropout')

        # Set up skip connections.
        skip = masked.linear(net,
                             width,
                             skip_width,
                             name='skip_start',
                             use_weight_norm=use_weight_norm)

        # Residual blocks with skip connections.
        for i in range(num_layers):
            layer_id = i + 1
            dilation = 2**(i % num_stages)

            # Dilated masked cnn.
            gated, inits, pushs = masked.causal_linear(
                x=net,
                n_inputs=width,
                n_outputs=gate_width,
                name='dilated_conv_%d' % layer_id,
                rate=dilation,
                batch_size=batch_size,
                filter_length=filter_length,
                use_weight_norm=use_weight_norm)
            init_ops.extend(inits)
            push_ops.extend(pushs)

            # Local conditioning on the mel encoding.
            gated += masked.linear(mel_en,
                                   deconv_width,
                                   gate_width,
                                   name='mel_cond_%d' % layer_id,
                                   use_weight_norm=use_weight_norm)

            # Gated cnn: the first half of the channels goes through a
            # sigmoid, the second half through a tanh.
            gate_channels = gated.get_shape().as_list()[2]
            assert gate_channels % 2 == 0
            half = gate_channels // 2
            gated = tf.sigmoid(gated[:, :, :half]) * tf.tanh(gated[:, :, half:])

            # Residuals.
            net += masked.linear(gated,
                                 gate_width // 2,
                                 width,
                                 name='res_%d' % layer_id,
                                 use_weight_norm=use_weight_norm)

            # Skips.
            skip += masked.linear(gated,
                                  gate_width // 2,
                                  skip_width,
                                  name='skip_%d' % layer_id,
                                  use_weight_norm=use_weight_norm)

            # Dropout.
            if use_dropout:
                net = tf.layers.dropout(net,
                                        rate=0.2,
                                        training=False,
                                        name='res_dropout_%d' % layer_id)

        skip = tf.nn.relu(skip)
        skip_proj = masked.linear(skip,
                                  skip_width,
                                  skip_width,
                                  name='out1',
                                  use_weight_norm=use_weight_norm)
        mel_proj = masked.linear(mel_en,
                                 deconv_width,
                                 skip_width,
                                 name='mel_cond_out1',
                                 use_weight_norm=use_weight_norm)
        skip = tf.nn.relu(skip_proj + mel_proj)
        # [batch_size, 1, out_width]
        out = masked.linear(skip,
                            skip_width,
                            out_width,
                            name='out2',
                            use_weight_norm=use_weight_norm)

        # Pick the sampler that matches the loss the model was trained with.
        if loss_type == 'ce':
            sampler = loss_func.ce_sample
        elif loss_type == 'mol':
            sampler = loss_func.mol_sample
        elif loss_type == 'gauss':
            sampler = loss_func.gauss_sample
        else:
            raise ValueError('[{}] loss is not supported.'.format(loss_type))
        sample = sampler(out, quant_chann)

        return {'init_ops': init_ops, 'push_ops': push_ops, 'sample': sample}
コード例 #4
0
ファイル: wavenet.py プロジェクト: MLBK/nsynth_wavenet
    def feed_forward(self, inputs):
        """Build the training/evaluation graph for this configuration.

        Args:
          inputs: A dict of inputs. Must contain 'mel' (fed to
            `self.deconv_stack`) and 'wav'.

        Returns:
          A dict with 'real_targets', 'cate_targets', 'encoding' (the
          deconv stack output) and 'out_params' (the output of the final
          'out2' convolution).
        """
        hparams = self.hparams
        num_stages = hparams.num_stages
        num_layers = hparams.num_layers
        filter_length = hparams.filter_length
        width = hparams.width
        skip_width = hparams.skip_width
        use_mu_law = self.use_mu_law
        quant_chann = self.quant_chann
        out_width = self.out_width
        gate_width = 2 * width

        ###
        # The Transpose Convolution Stack for mel feature.
        ###
        # wavenet inputs <- trans_conv (l2, s2) <- trans_conv (l1, s1) <- mel_ceps
        # win_len: l1 * s2 + (l2 - s2); win_shift: s1 * s2
        # (l1, s1) = (40, 10), (l2, s2) = (80, 20) is a proper configuration;
        # it is almost consistent with the mel analysis frame shift (200) and
        # frame length (800).
        mel_en = self.deconv_stack({'mel': inputs['mel']})['encoding']

        ###
        # Encode the source with 8-bit Mu-Law or just use 16-bit signal.
        ###
        wav = inputs['wav']
        if use_mu_law:
            wav_quantized = utils.mu_law(wav)
            wav_scaled = tf.cast(wav_quantized, tf.float32) / (quant_chann / 2.)
        else:
            wav_quantized = utils.cast_quantize(wav, quant_chann)
            wav_scaled = wav
        # The real-valued targets are the scaled signal before the channel
        # axis is added; the categorical targets are built the same way for
        # both encodings.
        real_targets = wav_scaled
        cate_targets = (tf.cast(wav_quantized, tf.int32) +
                        tf.cast(quant_chann / 2., tf.int32))
        wav_scaled = tf.expand_dims(wav_scaled, 2)

        ###
        # The WaveNet Decoder.
        ###
        net = masked.shift_right(wav_scaled)
        net = masked.conv1d(net,
                            num_filters=width,
                            filter_length=filter_length,
                            name='startconv')

        # Set up skip connections.
        skip = masked.conv1d(net,
                             num_filters=skip_width,
                             filter_length=1,
                             name='skip_start')

        # Residual blocks with skip connections.
        for i in range(num_layers):
            layer_id = i + 1
            gated = masked.conv1d(net,
                                  num_filters=gate_width,
                                  filter_length=filter_length,
                                  dilation=2**(i % num_stages),
                                  name='dilated_conv_%d' % layer_id)
            cond = masked.conv1d(mel_en,
                                 num_filters=gate_width,
                                 filter_length=1,
                                 name='mel_cond_%d' % layer_id)
            gated = _condition(gated, cond)

            # Gated activation: sigmoid on the first half of the channels,
            # tanh on the second half.
            gate_channels = gated.get_shape().as_list()[2]
            assert gate_channels % 2 == 0
            half = gate_channels // 2
            gated = tf.sigmoid(gated[:, :, :half]) * tf.tanh(gated[:, :, half:])

            net += masked.conv1d(gated,
                                 num_filters=width,
                                 filter_length=1,
                                 name='res_%d' % layer_id)
            skip += masked.conv1d(gated,
                                  num_filters=skip_width,
                                  filter_length=1,
                                  name='skip_%d' % layer_id)

        skip = tf.nn.relu(skip)
        skip = masked.conv1d(skip,
                             num_filters=skip_width,
                             filter_length=1,
                             name='out1')
        cond = masked.conv1d(mel_en,
                             num_filters=skip_width,
                             filter_length=1,
                             name='mel_cond_out1')
        skip = tf.nn.relu(_condition(skip, cond))

        # When using mol loss, the model always predicts log_scale; this
        # initializer keeps the log_scale in a reasonably small range to
        # speed up convergence.
        if self.loss_type == 'mol':
            final_kernel_init = tf.truncated_normal_initializer(0.0, 0.01)
        else:
            final_kernel_init = tf.uniform_unit_scaling_initializer(1.0)
        out = masked.conv1d(skip,
                            num_filters=out_width,
                            filter_length=1,
                            name='out2',
                            kernel_initializer=final_kernel_init)

        return {
            'real_targets': real_targets,
            'cate_targets': cate_targets,
            'encoding': mel_en,
            'out_params': out
        }