Example 1
 def residual_block(x):
     # WaveNet-style gated block: two parallel dilated causal convolutions
     # (tanh filter * sigmoid gate), a 1x1 projection, and a residual sum.
     # Returns (residual output, skip output). Closes over the enclosing
     # scope's `layers`, `nb_filters`, `use_bias`, `i` (dilation exponent)
     # and `s` (stack index).
     # TODO: initalization, regularization?
     # Note: AtrousConvolution1D with the 'causal' flag is implemented in
     # github.com/basveeling/keras#@wavenet.
     shortcut = x
     dilated_conv_kwargs = dict(atrous_rate=2 ** i,
                                border_mode='valid',
                                causal=True,
                                bias=use_bias)
     filter_out = layers.AtrousConvolution1D(
         nb_filters, 2,
         activation='tanh',
         name='dilated_conv_%d_tanh_s%d' % (2 ** i, s),
         **dilated_conv_kwargs)(x)
     gate_out = layers.AtrousConvolution1D(
         nb_filters, 2,
         activation='sigmoid',
         name='dilated_conv_%d_sigm_s%d' % (2 ** i, s),
         **dilated_conv_kwargs)(x)
     gated = layers.Merge(
         mode='mul',
         name='gated_activation_%d_s%d' % (i, s))([filter_out, gate_out])
     skip = layers.Convolution1D(nb_filters, 1,
                                 border_mode='same',
                                 bias=use_bias)(gated)
     residual = layers.Merge(mode='sum')([shortcut, skip])
     return residual, skip
Example 2
def build_model(fragment_length, nb_filters, nb_output_bins, dilation_depth, nb_stacks, use_skip_connections,
                learn_all_outputs, _log, desired_sample_rate, use_bias):
    """Build the WaveNet Keras model.

    Args:
        fragment_length: number of timesteps per input fragment.
        nb_filters: number of filters in each dilated convolution.
        nb_output_bins: number of quantization bins (input and output channels).
        dilation_depth: dilation rates run 2**0 .. 2**dilation_depth per stack.
        nb_stacks: how many times the dilation stack is repeated.
        use_skip_connections: if True, sum the per-block skip outputs before
            the output layers instead of using only the residual path.
        learn_all_outputs: must be truthy; a falsy value raises (see below).
        _log: logger used to report the receptive field.
        desired_sample_rate: unused here; kept for interface compatibility.
        use_bias: whether the gated/1x1 convolutions use bias terms.

    Returns:
        A Keras ``Model`` mapping ``(fragment_length, nb_output_bins)`` inputs
        to a per-timestep softmax over ``nb_output_bins`` classes.

    Raises:
        DeprecationWarning: when ``learn_all_outputs`` is falsy (raised as an
            exception, not emitted via ``warnings.warn``).
    """

    def residual_block(x):
        # WaveNet gated activation: tanh(conv) * sigmoid(conv) over dilated
        # causal convolutions, then a 1x1 conv. Returns (residual, skip).
        # Closes over the loop variables ``i`` (dilation exponent) and
        # ``s`` (stack index) from the loops below.
        # TODO: initalization, regularization?
        # Note: The AtrousConvolution1D with the 'causal' flag is implemented in github.com/basveeling/keras#@wavenet.
        original_x = x
        tanh_out = layers.AtrousConvolution1D(nb_filters, 2, atrous_rate=2 ** i, border_mode='valid', causal=True,
                                              bias=use_bias,
                                              name='dilated_conv_%d_tanh_s%d' % (2 ** i, s), activation='tanh')(x)
        sigm_out = layers.AtrousConvolution1D(nb_filters, 2, atrous_rate=2 ** i, border_mode='valid', causal=True,
                                              bias=use_bias,
                                              name='dilated_conv_%d_sigm_s%d' % (2 ** i, s), activation='sigmoid')(x)
        x = layers.Merge(mode='mul', name='gated_activation_%d_s%d' % (i, s))([tanh_out, sigm_out])
        x = layers.Convolution1D(nb_filters, 1, border_mode='same', bias=use_bias)(x)
        skip_out = x
        x = layers.Merge(mode='sum')([original_x, x])
        return x, skip_out

    # Renamed from ``input`` so the builtin is not shadowed.
    model_input = Input(shape=(fragment_length, nb_output_bins), name='input_part')
    out = model_input
    skip_connections = []
    out = layers.AtrousConvolution1D(nb_filters, 2, atrous_rate=1, border_mode='valid', causal=True,
                                     name='initial_causal_conv')(out)
    # ``range`` instead of Python-2-only ``xrange``: identical iteration
    # behavior on Python 2, and required for Python 3.
    for s in range(nb_stacks):
        for i in range(dilation_depth + 1):
            out, skip_out = residual_block(out)
            skip_connections.append(skip_out)

    if use_skip_connections:
        out = layers.Merge(mode='sum')(skip_connections)
    out = layers.Activation('relu')(out)
    out = layers.Convolution1D(nb_output_bins, 1, border_mode='same')(out)
    out = layers.Activation('relu')(out)
    out = layers.Convolution1D(nb_output_bins, 1, border_mode='same')(out)

    if not learn_all_outputs:
        raise DeprecationWarning('Learning on just all outputs is wasteful, now learning only inside receptive field.')
        # NOTE(review): unreachable after the raise above; kept from the
        # original in case the raise is later downgraded to warnings.warn.
        out = layers.Lambda(lambda x: x[:, -1, :], output_shape=(out._keras_shape[-1],))(
            out)  # Based on gif in deepmind blog: take last output?

    out = layers.Activation('softmax', name="output_softmax")(out)
    model = Model(model_input, out)

    receptive_field, receptive_field_ms = compute_receptive_field()

    _log.info('Receptive Field: %d (%dms)' % (receptive_field, int(receptive_field_ms)))
    return model