def build(self, inputs):
  """Build the graph for this configuration.

  Args:
    inputs: A dict of inputs. For training, should contain 'wav'.

  Returns:
    A dict of outputs that includes the 'predictions', 'init_ops',
    'push_ops', 'encoding', and 'quantized_input'.
  """
  num_stages = 10
  num_layers = 30
  filter_length = 3
  width = 512
  skip_width = 256
  num_z = 16

  # Encode the source with 8-bit Mu-Law.
  x = inputs['wav']
  batch_size = self.batch_size
  x_quantized = utils.mu_law(x)
  x_scaled = tf.cast(x_quantized, tf.float32) / 128.0
  x_scaled = tf.expand_dims(x_scaled, 2)

  encoding = tf.placeholder(
      name='encoding', shape=[batch_size, num_z], dtype=tf.float32)
  en = tf.expand_dims(encoding, 1)

  init_ops, push_ops = [], []

  ###
  # The WaveNet Decoder.
  ###
  l = x_scaled
  l, inits, pushs = utils.causal_linear(
      x=l,
      n_inputs=1,
      n_outputs=width,
      name='startconv',
      rate=1,
      batch_size=batch_size,
      filter_length=filter_length)

  for init in inits:
    init_ops.append(init)
  for push in pushs:
    push_ops.append(push)

  # Set up skip connections.
  s = utils.linear(l, width, skip_width, name='skip_start')

  # Residual blocks with skip connections.
  for i in range(num_layers):
    dilation = 2**(i % num_stages)

    # Dilated masked CNN.
    d, inits, pushs = utils.causal_linear(
        x=l,
        n_inputs=width,
        n_outputs=width * 2,
        name='dilatedconv_%d' % (i + 1),
        rate=dilation,
        batch_size=batch_size,
        filter_length=filter_length)

    for init in inits:
      init_ops.append(init)
    for push in pushs:
      push_ops.append(push)

    # Local conditioning on the encoding.
    d += utils.linear(en, num_z, width * 2, name='cond_map_%d' % (i + 1))

    # Gated CNN: split channels into sigmoid and tanh halves.
    assert d.get_shape().as_list()[2] % 2 == 0
    m = d.get_shape().as_list()[2] // 2
    d = tf.sigmoid(d[:, :, :m]) * tf.tanh(d[:, :, m:])

    # Residual connection.
    l += utils.linear(d, width, width, name='res_%d' % (i + 1))

    # Skip connection.
    s += utils.linear(d, width, skip_width, name='skip_%d' % (i + 1))

  s = tf.nn.relu(s)
  s = (utils.linear(s, skip_width, skip_width, name='out1') +
       utils.linear(en, num_z, skip_width, name='cond_map_out1'))
  s = tf.nn.relu(s)

  ###
  # Compute the logits and the per-sample predictions.
  ###
  logits = utils.linear(s, skip_width, 256, name='logits')
  logits = tf.reshape(logits, [-1, 256])
  probs = tf.nn.softmax(logits, name='softmax')

  return {
      'init_ops': init_ops,
      'push_ops': push_ops,
      'predictions': probs,
      'encoding': encoding,
      'quantized_input': x_quantized,
  }
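
# Illustrative sketch (not part of the original file): one way to drive the
# fast-generation graph built above, one sample at a time. The 'init_ops'
# prime the per-layer activation queues once, and each step runs 'push_ops'
# to advance them alongside 'predictions'. The helper name, the 'wav'
# placeholder shape, and the inlined inverse Mu-Law are assumptions; a real
# setup would also restore trained weights from a checkpoint instead of
# random initialization.
def _example_generate(config, z, total_length=16000):
  """Hypothetical usage sketch: autoregressively sample audio.

  Args:
    config: An object exposing `batch_size` and the `build` method above.
    z: A NumPy array of shape [batch_size, 16] used as the encoding.
    total_length: Number of samples to generate per batch element.

  Returns:
    A NumPy array of shape [batch_size, total_length] of audio in [-1, 1].
  """
  import numpy as np

  batch_size = config.batch_size
  # Assumed shape: one sample per batch element per step.
  wav_ph = tf.placeholder(tf.float32, shape=[batch_size, 1], name='wav')
  net = config.build({'wav': wav_ph})

  audio = np.zeros([batch_size, total_length], dtype=np.float32)
  with tf.Session() as sess:
    # In practice, restore a trained checkpoint here via tf.train.Saver.
    sess.run(tf.global_variables_initializer())
    # Prime the causal-convolution queues once before generating.
    sess.run(net['init_ops'])

    sample = np.zeros([batch_size, 1], dtype=np.float32)
    for t in range(total_length):
      # One step: next-sample distribution plus queue pushes.
      probs, _ = sess.run(
          [net['predictions'], net['push_ops']],
          feed_dict={wav_ph: sample, net['encoding']: z})
      # Sample an 8-bit Mu-Law bin per batch element.
      bins = np.array([np.random.choice(256, p=p) for p in probs])
      # Invert Mu-Law: bins in [0, 256) -> signed [-128, 127] -> [-1, 1].
      out = (bins - 128) / 128.0
      sample = (np.sign(out) / 255.0 *
                ((1 + 255.0)**np.abs(out) - 1))[:, None].astype(np.float32)
      audio[:, t] = sample[:, 0]
  return audio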