Example #1
0
    def res_layer(inputs):
        """Apply one gated residual layer and return (residual, skip) outputs.

        Adapted from the original docstring: the layer contains a gated
        filter that connects to a dense (residual) output and to a skip
        connection::

                   |-> [gate]   -|        |-> 1x1 conv -> skip output
                   |             |-> (*) -|
            input -|-> [filter] -|        |-> 1x1 conv -|
                   |                                    |-> (+) -> dense output
                   |------------------------------------|

        ``[gate]`` and ``[filter]`` are dilated convolutions (described as
        causal in the original docstring) followed by ``sigmoid`` and
        ``tanh`` respectively.

        Args:
            inputs: rank-3 array (indexed on axes 0, 1 and 2 below);
                presumably (batch, time, channels) -- TODO confirm against
                the Conv1D layout in use.

        Returns:
            ``(new_out, skip)``. ``new_out`` is the 1x1-conv output added
            elementwise to the trailing ``out.shape[1]`` steps (axis 1) of
            ``inputs``. ``skip`` is a 1x1 conv applied to
            ``skip_slice(p, output_width)``.
        """
        gated = Sequential(
            Conv1D(dilation_channels, (filter_width, ), dilation=(dilation, )),
            sigmoid)(inputs)
        filtered = Sequential(
            Conv1D(dilation_channels, (filter_width, ), dilation=(dilation, )),
            np.tanh)(inputs)
        p = gated * filtered
        out = Conv1D(residual_channels, (1, ), padding='SAME')(p)
        # The dilated convolutions above pass no padding argument, so `out`
        # may be shorter than `inputs` along axis 1. Keep only the trailing
        # out.shape[1] steps of `inputs` so the residual add lines up.
        sliced_inputs = lax.dynamic_slice(
            inputs, [0, inputs.shape[1] - out.shape[1], 0],
            [inputs.shape[0], out.shape[1], inputs.shape[2]])
        # Elementwise residual add. The builtin sum(out, sliced_inputs) would
        # instead fold over out's first (batch) axis and add every batch
        # element onto sliced_inputs, mixing examples when batch size > 1.
        new_out = out + sliced_inputs
        skip = Conv1D(residual_channels, (1, ),
                      padding='SAME')(skip_slice(p, output_width))
        return new_out, skip
Example #2
0
 def wavenet(inputs):
     """Run the stack of residual layers over ``inputs`` and apply the head.

     An initial Conv1D maps ``inputs`` to ``residual_channels``. Each value
     in ``dilations`` then builds a ``ResLayer``. Its first output becomes
     the input to the next layer, and its second (skip) output is added to
     a running total. That total is passed through
     relu -> 1x1 Conv1D(skip_channels) -> relu -> 1x1 Conv1D(3 * nr_mix),
     and the result is returned.
     """
     hidden = Conv1D(residual_channels, (initial_filter_width, ))(inputs)
     # Skip-output accumulator, a float32 zeros array of shape
     # (hidden.shape[0], out_width, residual_channels). Each ResLayer's skip
     # output is assumed to match this shape -- TODO confirm.
     skip_total = np.zeros((hidden.shape[0], out_width, residual_channels),
                           'float32')
     for d in dilations:
         layer = ResLayer(dilation_channels, residual_channels, filter_width,
                          d, out_width)
         hidden, skip_part = layer(hidden)
         skip_total += skip_part
     head = Sequential(relu, Conv1D(skip_channels, (1, )), relu,
                       Conv1D(3 * nr_mix, (1, )))
     return head(skip_total)
Example #3
0
def test_Conv1DTranspose_runs(channels, filter_shape, padding, strides,
                              input_shape):
    """Smoke-test that the layer below initializes and applies without error.

    The arguments are presumably supplied by pytest parametrization outside
    this view. Nothing is asserted. The test passes as long as parameter
    initialization and the forward pass do not raise.
    """
    # NOTE(review): the test name says Conv1DTranspose, but the layer built
    # here is a plain Conv1D, so no transposed convolution is exercised.
    # Confirm whether Conv1DTranspose was intended.
    layer = Conv1D(channels, filter_shape, padding=padding, strides=strides)
    x = random_inputs(input_shape)
    key = PRNGKey(0)
    layer_params = layer.init_parameters(key, x)
    layer.apply(layer_params, x)