Example #1
0
  def build(self, inputs):
    """Assemble the encoding-conditioned WaveNet decoder graph.

    The waveform under `inputs['wav']` is mu-law encoded, rescaled, and fed
    through a start `utils.causal_linear` layer followed by `num_layers`
    gated, dilated `utils.causal_linear` layers. Each layer, as well as the
    output head, is biased by a projection of an 'encoding' placeholder that
    is created inside this method.

    Args:
      inputs: A dict of inputs; only the 'wav' entry is read.

    Returns:
      A dict with the keys:
        'init_ops': list of the init ops returned by every causal layer.
        'push_ops': list of the push ops returned by every causal layer.
        'predictions': softmax over 256 classes, reshaped to [-1, 256].
        'encoding': the [batch_size, num_z] float32 placeholder.
        'quantized_input': the output of `utils.mu_law` on the waveform.
    """
    # Hyperparameters of the decoder stack.
    filter_length = 3
    num_layers = 30
    num_stages = 10
    num_z = 16
    skip_width = 256
    width = 512

    # Encode the source with 8-bit Mu-Law, rescale by 1/128 and append a
    # trailing axis (assumed to act as the single input channel -- the start
    # layer below is built with n_inputs=1).
    wav = inputs['wav']
    batch_size = self.batch_size
    x_quantized = utils.mu_law(wav)
    x_scaled = tf.expand_dims(tf.cast(x_quantized, tf.float32) / 128.0, 2)

    # Conditioning vector with a singleton axis 1: [batch_size, 1, num_z].
    encoding = tf.placeholder(
        name='encoding', shape=[batch_size, num_z], dtype=tf.float32)
    conditioning = tf.expand_dims(encoding, 1)

    init_ops, push_ops = [], []

    ###
    # The WaveNet Decoder.
    ###
    net, inits, pushs = utils.causal_linear(
        x=x_scaled,
        n_inputs=1,
        n_outputs=width,
        name='startconv',
        rate=1,
        batch_size=batch_size,
        filter_length=filter_length)
    init_ops.extend(inits)
    push_ops.extend(pushs)

    # The skip path starts from a projection of the start layer's output.
    skip = utils.linear(net, width, skip_width, name='skip_start')

    # Residual blocks with skip connections; layers are numbered from 1 in
    # their variable names.
    for layer in range(1, num_layers + 1):
      # Dilation cycles through 1, 2, 4, ..., 2**(num_stages - 1).
      dilation = 2 ** ((layer - 1) % num_stages)

      # Dilated causal layer producing twice the width (gate + signal).
      dilated, inits, pushs = utils.causal_linear(
          x=net,
          n_inputs=width,
          n_outputs=width * 2,
          name='dilatedconv_%d' % layer,
          rate=dilation,
          batch_size=batch_size,
          filter_length=filter_length)
      init_ops.extend(inits)
      push_ops.extend(pushs)

      # Conditioning: add a per-layer projection of the encoding.
      dilated += utils.linear(
          conditioning, num_z, width * 2, name='cond_map_%d' % layer)

      # Gated activation: split the last axis in half, sigmoid * tanh.
      channels = dilated.get_shape().as_list()[2]
      assert channels % 2 == 0
      half = channels // 2
      gate = tf.sigmoid(dilated[:, :, :half])
      signal = tf.tanh(dilated[:, :, half:])
      gated = gate * signal

      # Residual and skip contributions of this layer.
      net += utils.linear(gated, width, width, name='res_%d' % layer)
      skip += utils.linear(gated, width, skip_width, name='skip_%d' % layer)

    # Output head: relu -> (linear + conditioning projection) -> relu.
    skip = tf.nn.relu(skip)
    out1 = utils.linear(skip, skip_width, skip_width, name='out1')
    cond_out1 = utils.linear(
        conditioning, num_z, skip_width, name='cond_map_out1')
    skip = tf.nn.relu(out1 + cond_out1)

    ###
    # Project to 256-way logits and normalize them into probabilities.
    ###
    logits = tf.reshape(
        utils.linear(skip, skip_width, 256, name='logits'), [-1, 256])
    probs = tf.nn.softmax(logits, name='softmax')

    return {
        'init_ops': init_ops,
        'push_ops': push_ops,
        'predictions': probs,
        'encoding': encoding,
        'quantized_input': x_quantized,
    }
Example #2
0
    def build(self, inputs):
        """Construct the decoder graph and return its endpoints.

        Pipeline, in order of graph construction:
          1. `inputs['wav']` is passed through `utils.mu_law`, cast to
             float32, divided by 128 and given an extra trailing axis.
          2. A float32 'encoding' placeholder of shape [batch_size, num_z]
             is created and expanded to [batch_size, 1, num_z].
          3. A start `utils.causal_linear` layer and 30 gated, dilated
             `utils.causal_linear` layers are stacked with residual and
             skip connections; each layer adds a projection of the
             encoding.
          4. The summed skip path is mapped to 256 logits and a softmax.

        Args:
          inputs: A dict of inputs; only the 'wav' entry is read.

        Returns:
          A dict holding 'init_ops' and 'push_ops' (lists gathered from
          every causal layer), 'predictions' (softmax output reshaped to
          [-1, 256]), 'encoding' (the placeholder) and 'quantized_input'
          (the mu-law encoded waveform).
        """
        num_stages = 10
        num_layers = 30
        filter_length = 3
        width = 512
        skip_width = 256
        num_z = 16

        # Encode the source with 8-bit Mu-Law.
        source = inputs['wav']
        batch_size = self.batch_size
        x_quantized = utils.mu_law(source)
        as_float = tf.cast(x_quantized, tf.float32)
        x_scaled = tf.expand_dims(as_float / 128.0, 2)

        # Placeholder for the conditioning vector; the singleton axis 1 lets
        # it be added to the 3-D layer activations (presumably broadcast
        # along that axis -- depends on what utils.linear returns).
        encoding = tf.placeholder(name='encoding',
                                  shape=[batch_size, num_z],
                                  dtype=tf.float32)
        cond = tf.expand_dims(encoding, 1)

        # Arguments shared by every causal layer built below.
        causal_kwargs = {
            'batch_size': batch_size,
            'filter_length': filter_length,
        }

        ###
        # The WaveNet Decoder.
        ###
        residual, inits, pushs = utils.causal_linear(x=x_scaled,
                                                     n_inputs=1,
                                                     n_outputs=width,
                                                     name='startconv',
                                                     rate=1,
                                                     **causal_kwargs)
        # Seed the op lists with the start layer's ops.
        init_ops = list(inits)
        push_ops = list(pushs)

        # Set up skip connections.
        skips = utils.linear(residual, width, skip_width, name='skip_start')

        # Residual blocks with skip connections.
        for block_index in range(num_layers):
            block_id = block_index + 1  # Names are 1-based.
            # Doubles each block and wraps around every `num_stages` blocks.
            rate = 2**(block_index % num_stages)

            # Dilated causal layer with 2 * width outputs.
            pre_gate, inits, pushs = utils.causal_linear(
                x=residual,
                n_inputs=width,
                n_outputs=width * 2,
                name='dilatedconv_%d' % block_id,
                rate=rate,
                **causal_kwargs)
            init_ops += inits
            push_ops += pushs

            # Add the block-specific projection of the encoding.
            cond_bias = utils.linear(cond,
                                     num_z,
                                     width * 2,
                                     name='cond_map_%d' % block_id)
            pre_gate = pre_gate + cond_bias

            # Gated unit: first half of the channels goes through sigmoid,
            # second half through tanh, and the two are multiplied.
            depth = pre_gate.get_shape().as_list()[2]
            assert depth % 2 == 0
            split_at = depth // 2
            sigmoid_part = tf.sigmoid(pre_gate[:, :, :split_at])
            tanh_part = tf.tanh(pre_gate[:, :, split_at:])
            activated = sigmoid_part * tanh_part

            # Residual update.
            residual = residual + utils.linear(
                activated, width, width, name='res_%d' % block_id)

            # Skip accumulation.
            skips = skips + utils.linear(
                activated, width, skip_width, name='skip_%d' % block_id)

        # Output head, again conditioned on the encoding.
        head = tf.nn.relu(skips)
        head_linear = utils.linear(head, skip_width, skip_width, name='out1')
        head_cond = utils.linear(cond,
                                 num_z,
                                 skip_width,
                                 name='cond_map_out1')
        head = tf.nn.relu(head_linear + head_cond)

        ###
        # Compute the logits and the class probabilities.
        ###
        logits = utils.linear(head, skip_width, 256, name='logits')
        flat_logits = tf.reshape(logits, [-1, 256])
        probs = tf.nn.softmax(flat_logits, name='softmax')

        outputs = {
            'init_ops': init_ops,
            'push_ops': push_ops,
            'predictions': probs,
            'encoding': encoding,
            'quantized_input': x_quantized,
        }
        return outputs