def residual_block(x):
    """Apply one gated, dilated residual block to ``x``.

    Two causal dilated convolutions (one with a tanh activation, one with a
    sigmoid activation) are applied to ``x`` and multiplied element-wise.
    The product goes through a width-1 convolution; that result is both the
    skip output and the term summed with the block input.

    Reads ``nb_filters``, ``use_bias``, ``i`` and ``s`` from the enclosing
    scope: the dilation rate is ``2 ** i`` and ``s`` is only used in the
    layer names.

    Returns a ``(residual_out, skip_out)`` tuple.
    """
    dilation = 2 ** i
    # TODO: initialization, regularization?
    # Note: The AtrousConvolution1D with the 'causal' flag is implemented in
    # github.com/basveeling/keras#@wavenet.
    # Settings shared by the tanh ("filter") and sigmoid ("gate") branches.
    branch_kwargs = dict(atrous_rate=dilation, border_mode='valid', causal=True, bias=use_bias)

    filter_branch = layers.AtrousConvolution1D(
        nb_filters, 2,
        name='dilated_conv_%d_tanh_s%d' % (dilation, s),
        activation='tanh',
        **branch_kwargs)(x)
    gate_branch = layers.AtrousConvolution1D(
        nb_filters, 2,
        name='dilated_conv_%d_sigm_s%d' % (dilation, s),
        activation='sigmoid',
        **branch_kwargs)(x)

    # Note: this layer name uses the exponent ``i``, not the dilation 2 ** i.
    gated = layers.Merge(mode='mul', name='gated_activation_%d_s%d' % (i, s))(
        [filter_branch, gate_branch])

    # The width-1 convolution output doubles as the skip connection.
    skip_out = layers.Convolution1D(nb_filters, 1, border_mode='same', bias=use_bias)(gated)

    residual_out = layers.Merge(mode='sum')([x, skip_out])
    return residual_out, skip_out
def build_model(fragment_length, nb_filters, nb_output_bins, dilation_depth, nb_stacks, use_skip_connections,
                learn_all_outputs, _log, desired_sample_rate, use_bias):
    """Assemble the stacked dilated-convolution model and return it.

    The network is built as follows:

    1. An input of shape ``(fragment_length, nb_output_bins)``.
    2. An initial causal convolution with ``nb_filters`` filters.
    3. ``nb_stacks`` stacks of ``dilation_depth + 1`` residual blocks with
       dilation rates ``2 ** 0 .. 2 ** dilation_depth``.
    4. Optionally (``use_skip_connections``) the sum of all skip outputs
       replaces the output of the last residual block.
    5. Two rounds of ReLU followed by a width-1 convolution with
       ``nb_output_bins`` filters, then a softmax named ``output_softmax``.

    Args:
        fragment_length: Length of the time axis of the model input.
        nb_filters: Number of filters in the causal, dilated and residual
            width-1 convolutions.
        nb_output_bins: Size of the last input axis and number of filters
            of the two output convolutions.
        dilation_depth: Largest dilation exponent used in each stack.
        nb_stacks: Number of times the dilation sequence is repeated.
        use_skip_connections: When true, sum all skip outputs before the
            output layers.
        learn_all_outputs: Must be true; see Raises.
        _log: Logger-like object; only ``_log.info`` is called.
        desired_sample_rate: Not referenced in this function body.
        use_bias: Passed as ``bias`` to the convolutions inside the
            residual blocks (the initial and output convolutions keep the
            layer default).

    Returns:
        The constructed ``Model`` mapping the input to the softmax output.

    Raises:
        DeprecationWarning: If ``learn_all_outputs`` is false. The error is
            raised after the convolutional layers have been constructed,
            as in the original implementation.
    """

    def residual_block(x, i, s):
        """Gated residual block with dilation ``2 ** i`` in stack ``s``.

        ``i`` and ``s`` are passed explicitly instead of being read from
        the enclosing loops, so the block does not depend on closure state.
        Returns ``(residual_out, skip_out)``.
        """
        original_x = x
        # TODO: initialization, regularization?
        # Note: The AtrousConvolution1D with the 'causal' flag is implemented in
        # github.com/basveeling/keras#@wavenet.
        tanh_out = layers.AtrousConvolution1D(nb_filters, 2, atrous_rate=2 ** i, border_mode='valid', causal=True,
                                              bias=use_bias,
                                              name='dilated_conv_%d_tanh_s%d' % (2 ** i, s), activation='tanh')(x)
        sigm_out = layers.AtrousConvolution1D(nb_filters, 2, atrous_rate=2 ** i, border_mode='valid', causal=True,
                                              bias=use_bias,
                                              name='dilated_conv_%d_sigm_s%d' % (2 ** i, s), activation='sigmoid')(x)
        # Gated activation: element-wise product of the tanh and sigmoid branches.
        x = layers.Merge(mode='mul', name='gated_activation_%d_s%d' % (i, s))([tanh_out, sigm_out])

        x = layers.Convolution1D(nb_filters, 1, border_mode='same', bias=use_bias)(x)
        skip_out = x
        x = layers.Merge(mode='sum')([original_x, x])
        return x, skip_out

    # Named ``model_input`` to avoid shadowing the ``input`` builtin.
    model_input = Input(shape=(fragment_length, nb_output_bins), name='input_part')
    out = model_input
    skip_connections = []
    out = layers.AtrousConvolution1D(nb_filters, 2, atrous_rate=1, border_mode='valid', causal=True,
                                     name='initial_causal_conv')(out)
    # ``range`` behaves the same as ``xrange`` here and works on Python 2 and 3.
    for s in range(nb_stacks):
        for i in range(0, dilation_depth + 1):
            out, skip_out = residual_block(out, i, s)
            skip_connections.append(skip_out)

    if use_skip_connections:
        out = layers.Merge(mode='sum')(skip_connections)
    out = layers.Activation('relu')(out)
    out = layers.Convolution1D(nb_output_bins, 1, border_mode='same')(out)
    out = layers.Activation('relu')(out)
    out = layers.Convolution1D(nb_output_bins, 1, border_mode='same')(out)

    if not learn_all_outputs:
        # Predicting only the last time step is no longer supported. The
        # Lambda slice that used to follow this raise was unreachable and
        # has been removed.
        raise DeprecationWarning('Learning on just all outputs is wasteful, now learning only inside receptive field.')

    out = layers.Activation('softmax', name="output_softmax")(out)
    model = Model(model_input, out)

    # NOTE(review): called without arguments; presumably its parameters are
    # supplied elsewhere (e.g. by configuration injection) - confirm.
    receptive_field, receptive_field_ms = compute_receptive_field()

    _log.info('Receptive Field: %d (%dms)' % (receptive_field, int(receptive_field_ms)))
    return model