Exemplo n.º 1
0
        def __init__(self,embedding_matrix,num_classes,attention_size,attention_heads):
        
            """Create the layers of the hierarchical self-attention (HiSAN) classifier.

            Args:
                embedding_matrix: pretrained word-vector array of shape
                    (vocab_size, embed_dim); cast to float32 and used as a
                    Constant initializer for the Embedding layer.
                num_classes: width of the final classification Dense layer.
                attention_size: total width of each Q/K/V projection.
                attention_heads: number of attention heads; attention_size
                    should be divisible by attention_heads (the target-vector
                    shapes below use int(attention_size / attention_heads)).
            """
            super(hisan.hisan_model,self).__init__()
            self.attention_size = attention_size
            self.attention_heads = attention_heads
            
            # Word embedding initialized from the pretrained matrix.
            self.embedding = layers.Embedding(embedding_matrix.shape[0],
                             embedding_matrix.shape[1],
                             embeddings_initializer=tf.keras.initializers.Constant(
                             embedding_matrix.astype(np.float32)))
            # Word-level block: dropout, Q/K/V projections, a learned
            # per-head target vector of shape [1, heads, 1, head_dim], and
            # two scaled attention layers (self- and target-attention).
            self.word_drop = layers.Dropout(0.1)
            self.word_Q = layers.Dense(self.attention_size)
            self.word_K = layers.Dense(self.attention_size)
            self.word_V = layers.Dense(self.attention_size)
            self.word_target = tf.Variable(tf.random.uniform(shape=[1,self.attention_heads,1,
                               int(self.attention_size/self.attention_heads)]))
            self.word_self_att = layers.Attention(use_scale=True)
            self.word_targ_att = layers.Attention(use_scale=True)
            
            # Line-level block: identical structure applied one level up the
            # hierarchy (over line representations instead of words).
            self.line_drop = layers.Dropout(0.1)
            self.line_Q = layers.Dense(self.attention_size)
            self.line_K = layers.Dense(self.attention_size)
            self.line_V = layers.Dense(self.attention_size)
            self.line_target = tf.Variable(tf.random.uniform(shape=[1,self.attention_heads,1,
                               int(self.attention_size/self.attention_heads)]))
            self.line_self_att = layers.Attention(use_scale=True)
            self.line_targ_att = layers.Attention(use_scale=True)

            # Document-level dropout and the final classification layer
            # (logits; no activation applied here).
            self.doc_drop = layers.Dropout(0.1)
            self.classify = layers.Dense(num_classes)
def build_dense_model(num_features_1, num_features_2):
    """Two-branch model over a flat feature vector and a variable-length sequence.

    Branch 1 applies Gaussian dropout to the flat features.  Branch 2 runs
    Conv1D + self-attention + global max pooling over the sequence.  The
    merged branches feed a two-layer MLP with batch norm, ending in two
    softmax heads over NUM_BINS bins, compiled with Adam and categorical
    cross-entropy.
    """
    flat_in = layers.Input(shape=(num_features_1, ))
    flat_branch = layers.GaussianDropout(0.1)(flat_in)

    seq_in = layers.Input(shape=(None, num_features_2))
    seq_branch = layers.Conv1D(filters=5,
                               kernel_size=4,
                               padding='same',
                               activation='relu')(seq_in)
    # Self-attention: query and value are the same tensor.
    seq_branch = layers.Attention()([seq_branch, seq_branch])
    seq_branch = layers.GlobalMaxPooling1D()(seq_branch)
    seq_branch = layers.Dropout(0.1)(seq_branch)

    merged = layers.concatenate([flat_branch, seq_branch])
    merged = layers.Dense(64, activation='relu')(merged)
    merged = layers.BatchNormalization()(merged)
    merged = layers.Dropout(0.1)(merged)
    merged = layers.Dense(64, activation='relu')(merged)
    merged = layers.BatchNormalization()(merged)

    # Two parallel softmax heads sharing the MLP trunk.
    head_actual = layers.Dense(NUM_BINS, activation='softmax')(merged)
    head_median = layers.Dense(NUM_BINS, activation='softmax')(merged)

    model = Model(inputs=[flat_in, seq_in],
                  outputs=[head_actual, head_median])
    model.compile(optimizer=Adam(), loss='categorical_crossentropy')

    return model
Exemplo n.º 3
0
    def build(self):
        """Assemble the ranking model from text and categorical inputs.

        Text inputs share one word Embedding and are max-pooled per field;
        each categorical input gets its own Embedding sized by its
        cardinality and is flattened to (batch, embedding_dim).  All features
        are concatenated, self-attended, and reduced to one sigmoid score.
        """
        cardinalities = {
            'country': self.total_countries,
        }

        text_inputs = [
            self.new_query_input(),
            self.new_title_input(),
            self.new_ingredients_input(),
            self.new_description_input(),
        ]
        categorical_inputs = [
            self.new_country_input(),
        ]
        all_inputs = text_inputs + categorical_inputs

        # One shared word embedding for every text field.
        shared_embedding = layers.Embedding(self.total_words, self.embedding_dim)
        pooled_text = [
            layers.GlobalMaxPooling1D()(shared_embedding(text_input))
            for text_input in text_inputs
        ]

        pooled_categorical = []
        for name, categorical_input in zip(cardinalities, categorical_inputs):
            # Per-feature embedding, flattened from (batch, 1, dim) to (batch, dim).
            cat_embedded = layers.Embedding(cardinalities[name],
                                            self.embedding_dim)(categorical_input)
            pooled_categorical.append(
                tf.reshape(cat_embedded, shape=(-1, self.embedding_dim,)))

        merged = layers.concatenate(pooled_text + pooled_categorical)
        attended = layers.Attention()([merged, merged])

        output = layers.Dense(1, activation='sigmoid', name='label')(attended)
        return tf.keras.Model(inputs=all_inputs, outputs=output, name=self.name)
        def MultiHeadAttention(query=None,
                               value=None,
                               key=None,
                               units=None,
                               use_scale=False,
                               num_heads=4,
                               dropout=0.0):
            """Functional multi-head attention built from stock keras layers.

            Heads are formed by splitting the projected Q/K/V tensors along
            the feature axis and stacking the pieces along the batch axis,
            attending once, then reversing the split/concat.  Finishes with a
            residual connection from `query` and layer normalization.

            Args:
                query: query tensor [batch, Tq, dim].
                value: value tensor; also used as the key when `key` is None.
                key: optional key tensor.
                units: projection width; defaults to value's last dimension.
                    Must be divisible by num_heads for tf.split to work.
                use_scale: forwarded to layers.Attention.
                num_heads: number of attention heads.
                dropout: attention dropout rate.

            Returns:
                Layer-normalized tensor; the residual add requires `units`
                to equal query's last dimension.
            """
            if key is None: key = value
            if units is None: units = value.get_shape().as_list()[-1]
            #Projection
            q_ = layers.Dense(units, activation='relu')(query)
            k_ = layers.Dense(units, activation='relu')(key)
            v_ = layers.Dense(units, activation='relu')(value)

            #Split and Concat: fold the head dimension into the batch axis.
            q = tf.concat(tf.split(q_, num_heads, axis=-1), axis=0)
            k = tf.concat(tf.split(k_, num_heads, axis=-1), axis=0)
            v = tf.concat(tf.split(v_, num_heads, axis=-1), axis=0)

            #scaled dot production
            # keras layers.Attention takes inputs as [query, value, key],
            # which [q, v, k] matches.
            scaled_dot_prod = layers.Attention(use_scale=use_scale,
                                               dropout=dropout)([q, v, k])
            # Undo the head fold: batch axis back to feature axis.
            scaled_dot_prod = tf.concat(tf.split(scaled_dot_prod,
                                                 num_heads,
                                                 axis=0),
                                        axis=-1)

            #Residual and Normalization
            scaled_dot_prod += query

            return layers.LayerNormalization()(scaled_dot_prod)
Exemplo n.º 5
0
def up_unit(conv3_1, stage, conv4_1=None, **_ignored):
    """Placeholder signature guard - see real definition below."""
Exemplo n.º 6
0
def decoder_att_321(conv1_1,
                    conv2_1,
                    conv3_1,
                    conv4_1,
                    decoder_num=0,
                    layer_act='relu',
                    bn_axis=-1):
    """Attention-gated 3D U-Net decoder with three upsample/attend/merge stages.

    Each stage upsamples the deeper feature map, attends it against the
    matching encoder skip connection, concatenates the pair along bn_axis,
    and refines the result with standard_unit.  A final 1x1x1 Conv3D
    produces the single-channel output.

    Bug fix: the first merge layer's name was built as
    ``'merge33_' + decoder_num`` with an int decoder_num (default 0), which
    raised TypeError; it now uses str(decoder_num) like every other stage.

    Args:
        conv1_1, conv2_1, conv3_1: encoder skip connections, shallow to deep.
        conv4_1: deepest encoder feature map (decoder input).
        decoder_num: integer tag that keeps layer names unique per decoder.
        layer_act: activation forwarded to standard_unit.
        bn_axis: concatenation axis for the merge layers.

    Returns:
        The decoder's single-channel Conv3D output tensor; its activation
        comes from the module-level `activation` variable.
    """
    tag_suffix = '_' + str(decoder_num)

    def _stage(tag, deep, skip, filters):
        # One decoder stage: upsample the deep features, attention-gate them
        # against the encoder skip connection, merge, then refine.
        up = layers.UpSampling3D(name='up' + tag + tag_suffix)(deep)
        att = layers.Attention()([up, skip])
        merged = layers.concatenate([up, att],
                                    name='merge' + tag + tag_suffix,
                                    axis=bn_axis)
        return standard_unit(merged,
                             stage=tag + tag_suffix,
                             nb_filter=filters,
                             layer_act=layer_act)

    conv3_3 = _stage('33', conv4_1, conv3_1, nb_filter[2])
    conv2_4 = _stage('24', conv3_3, conv2_1, nb_filter[1])
    conv1_5 = _stage('15', conv2_4, conv1_1, nb_filter[0])

    unet_output = layers.Conv3D(1, (1, 1, 1),
                                activation=activation,
                                name='output' + tag_suffix,
                                kernel_initializer='he_normal',
                                padding='same')(conv1_5)
    return unet_output
Exemplo n.º 7
0
def get_multi_head_attention_model(input_shape,
                                   output_size,
                                   dropout=0.2,
                                   idx=0,
                                   key_size=128,
                                   n_multi=2,
                                   n_add=2):
    """
    Create the scaled multi head attention model as described in the paper.
    :param input_shape: The shape of the inputs of this layer.
    :param output_size: The shape of the outputs of this layer.
    :param dropout: The drop_out ratio of the fully connected layer.
    :param idx: The unique index of this layer (only for the name).
    :param key_size: The dimensionality of the heads.
    :param n_multi: The number of dot product attention heads.
    :param n_add: The number of additive attention heads.
    :return: A Keras model of the scaled multi head attention layer.
    """
    # At least one head is required.  The original assertion checked
    # `>= 0`, which is vacuously true; a zero-head call then crashed later
    # at `outputs[0]` with IndexError instead of a clear message.
    assert n_multi + n_add > 0

    model_input = kl.Input(shape=input_shape)

    def _head(attention_layer):
        # One attention head: project the input into keys and queries,
        # batch-normalize each, and attend queries against keys.
        keys = kl.TimeDistributed(kl.Dense(key_size))(model_input)
        keys = kl.TimeDistributed(kl.BatchNormalization())(keys)
        queries = kl.TimeDistributed(kl.Dense(key_size))(model_input)
        queries = kl.TimeDistributed(kl.BatchNormalization())(queries)
        return attention_layer([queries, keys])

    # Dot-product heads first, then additive heads (original ordering).
    outputs = [_head(kl.Attention(use_scale=True)) for _ in range(n_multi)]
    outputs += [_head(kl.AdditiveAttention(use_scale=True)) for _ in range(n_add)]

    if len(outputs) > 1:
        output = kl.Concatenate()(outputs)
    else:
        output = outputs[0]

    # Per-timestep projection to the requested size, with dropout + BN.
    output = kl.TimeDistributed(kl.Dense(output_size))(output)
    output = kl.TimeDistributed(kl.Dropout(dropout))(output)
    output = kl.TimeDistributed(kl.BatchNormalization())(output)

    return km.Model(inputs=model_input,
                    outputs=output,
                    name='multi_head_attention_{}'.format(idx))
Exemplo n.º 8
0
 def __init__(self, vocab_size, embed_size, class_num):
     """Define the layers of a BiLSTM + attention text classifier.

     Args:
         vocab_size: size of the embedding vocabulary.
         embed_size: word-embedding dimensionality.
         class_num: number of output classes (softmax head).
     """
     super(TextBilstmAttention, self).__init__()
     self.embedding_layer = layers.Embedding(vocab_size, embed_size)
     # BiLSTM returns the full sequence so attention/pooling can consume it.
     self.bilstm_layer = layers.Bidirectional(
         layers.LSTM(units=64, return_sequences=True))
     self.attention_layer = layers.Attention()
     self.pool_layer = layers.GlobalAveragePooling1D()
     self.concat_layer = layers.Concatenate()
     self.dense_layer = layers.Dense(units=64, activation='relu')
     self.output_layer = layers.Dense(units=class_num, activation='softmax')
Exemplo n.º 9
0
 def __init__(self, vocab_size, embed_size, class_num):
     """Define the layers of a 1D-CNN + attention text classifier.

     Args:
         vocab_size: size of the embedding vocabulary.
         embed_size: word-embedding dimensionality.
         class_num: number of output classes (softmax head).
     """
     super(TextCnnAttention, self).__init__()
     self.embedding_layer = layers.Embedding(vocab_size, embed_size)
     # 'same' padding keeps sequence length so attention shapes line up.
     self.conv_layer = layers.Conv1D(filters=128,
                                     kernel_size=5,
                                     padding='same')
     self.attention_layer = layers.Attention()
     self.pool_layer = layers.GlobalAveragePooling1D()
     self.concat_layer = layers.Concatenate()
     self.dense_layer = layers.Dense(units=64, activation='relu')
     self.output_layer = layers.Dense(units=class_num, activation='softmax')
Exemplo n.º 10
0
 def __init__(self):
     """Build the seq2seq decoder layers: embedding, LSTM, attention, concat.

     All sizes come from the pre_data_utils module (target vocab size,
     embedding dim, hidden units).
     """
     super(Decoder, self).__init__()
     self.embedding = kl.Embedding(pre_data_utils.tgt_vocab_size,
                                   pre_data_utils.embedding_dim,
                                   mask_zero=False)  # Embedding Layer
     # Returns both the full sequence and the final (h, c) states.
     self.decoder_lstm = kl.LSTM(pre_data_utils.hidden_units,
                                 activation='tanh',
                                 return_sequences=True,
                                 return_state=True)  # Decode LSTM Layer
     self.attention = kl.Attention()  # Attention Layer
     self.concatenate = kl.Concatenate(
         axis=-1, name='concat_layer')  # Concatenate Layer
def build_model_bpps_Attn(embed_size,
                          seq_len=107,
                          pred_len=68,
                          dropout=0.5,
                          sp_dropout=0.2,
                          embed_dim=200,
                          hidden_dim=256,
                          n_layers=3):
    """Sequence model: embedded inputs -> GRU stack, attended against
    multi-scale Conv1D features, predicting 5 values per leading position.

    Args:
        embed_size: input cardinality for the Embedding lookup.
        seq_len: input sequence length.
        pred_len: number of leading positions kept for prediction.
        dropout: dropout passed to each gru_layer.
        sp_dropout: spatial dropout applied after embedding.
        embed_dim: embedding output dimension.
        hidden_dim: GRU hidden width.
        n_layers: number of stacked gru_layer blocks.

    Returns:
        A compiled tf.keras.Model taking [(seq_len, 3), (seq_len, 2)] inputs
        and producing (pred_len, 5) linear outputs, trained with MCRMSE.
    """
    input_s = L.Input(shape=(seq_len, 3))
    input_bpps = L.Input(shape=(seq_len, 2))

    # NOTE(review): both input branches are concatenated and fed through an
    # Embedding lookup, which requires integer ids < embed_size in every one
    # of the 5 channels - confirm the bpps branch is integer-coded upstream.
    inputs = L.Concatenate(axis=2)([input_s, input_bpps])
    embed = L.Embedding(input_dim=embed_size, output_dim=embed_dim)(inputs)

    # Flatten per-channel embeddings:
    # (batch, seq, 5, embed_dim) -> (batch, seq, 5 * embed_dim).
    reshaped = tf.reshape(embed,
                          shape=(-1, embed.shape[1],
                                 embed.shape[2] * embed.shape[3]))
    hidden = L.SpatialDropout1D(sp_dropout)(reshaped)
    # Four parallel conv branches with kernel sizes 64/32/16/8,
    # each followed by LayerNorm + LeakyReLU.
    conv1 = L.Conv1D(128, 64, 1, padding="same", activation=None)(hidden)
    h1 = L.LayerNormalization()(conv1)
    h1 = L.LeakyReLU()(h1)
    conv2 = L.Conv1D(128, 32, 1, padding="same", activation=None)(hidden)
    h2 = L.LayerNormalization()(conv2)
    h2 = L.LeakyReLU()(h2)
    conv3 = L.Conv1D(128, 16, 1, padding="same", activation=None)(hidden)
    h3 = L.LayerNormalization()(conv3)
    h3 = L.LeakyReLU()(h3)
    conv4 = L.Conv1D(128, 8, 1, padding="same", activation=None)(hidden)
    h4 = L.LayerNormalization()(conv4)
    h4 = L.LeakyReLU()(h4)

    hs = L.Concatenate()([h1, h2, h3, h4])

    # The multi-scale conv features serve as attention keys/values.
    keys = L.Dropout(0.2)(hs)

    for x in range(n_layers):
        hidden = gru_layer(hidden_dim, dropout)(hidden)

    # Query = GRU states, value = conv features.
    hidden = L.Attention(dropout=0.2)([hidden, keys])

    # Since we are only making predictions on the first part of each sequence,
    # we have to truncate it
    truncated = hidden[:, :pred_len]
    out = L.Dense(5, activation='linear')(truncated)

    model = tf.keras.Model(inputs=[input_s, input_bpps], outputs=out)
    model.compile(tf.optimizers.Adam(), loss=MCRMSE)

    return model
Exemplo n.º 12
0
def make_block(features, MaybeNoiseOrOutput, layers, units, gain, l1, l2):
    """Residual block: self-attend over features, mix in MaybeNoiseOrOutput,
    pass through a variable-depth tanh MLP, and add a skip connection.

    `layers` is a scalar (with .item()) giving the rounded total layer count;
    the loop builds one fewer hidden layer than that count.
    """
    attended = L.Attention()([features, features])
    x = L.Concatenate(2)([attended, MaybeNoiseOrOutput])
    x = L.Activation('tanh')(x)

    hidden_count = round(layers.item()) - 1
    for _ in range(hidden_count):
        dense = L.Dense(
            units,
            kernel_initializer=K.initializers.Orthogonal(gain),
            kernel_regularizer=K.regularizers.L1L2(l1, l2),
            bias_regularizer=K.regularizers.L1L2(l1, l2),
            activation='tanh')
        x = L.Dropout(0.5)(dense(x))

    # Project back to MaybeNoiseOrOutput's width so the residual add works.
    x = L.Dense(MaybeNoiseOrOutput.shape[-1], 'tanh')(x)
    return L.Add()([x, MaybeNoiseOrOutput])
Exemplo n.º 13
0
    def __init__(self, d_model, num_heads, **kwargs):
        """Set up projection, attention, and output layers for multi-head attention.

        Args:
            d_model: total model width; must be divisible by num_heads.
            num_heads: number of attention heads.
            **kwargs: forwarded to the parent layer constructor.
        """
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.d_model = d_model

        # Heads must evenly partition the model width.
        assert d_model % self.num_heads == 0

        # Per-head width.
        self.depth = d_model // self.num_heads

        # Q/K/V projections, all d_model wide (split into heads elsewhere).
        self.wq = tf.keras.layers.Dense(d_model, name='wQ')
        self.wk = tf.keras.layers.Dense(d_model, name='wK')
        self.wv = tf.keras.layers.Dense(d_model, name='wV')

        self.attention = layers.Attention(use_scale=True, name='dotprod_attn')
        self.dense = tf.keras.layers.Dense(d_model, name='attn_dense')
        self.norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
Exemplo n.º 14
0
    def build_model_arc(self) -> None:
        """Assemble the CNN + attention classification graph into self.tf_model.

        The shared embedding output is encoded through three Conv1D layers as
        both query and value paths, attended (dot-product), max-pooled, and
        concatenated before the output Dense + activation layers.
        """
        def _version_tuple(version):
            # Parse the first three version components numerically.  The
            # original code compared tuples of *strings*, which is wrong for
            # multi-digit parts (e.g. '10' sorts before '2' lexically, so
            # tf 10.x would have been treated as older than 2.1.0).
            parts = []
            for part in version.split('.')[:3]:
                digits = ''
                for ch in part:
                    if ch.isdigit():
                        digits += ch
                    else:
                        break  # stop at suffixes like 'rc0'
                parts.append(int(digits) if digits else 0)
            return tuple(parts)

        if _version_tuple(tf.__version__) < (2, 1, 0):
            logger.warning("Attention layer not serializable because it takes init args "
                           "but doesn't implement get_config. "
                           "Please try Attention layer with tf versions >= 2.1.0. "
                           "Issue: https://github.com/tensorflow/tensorflow/issues/32662")
        output_dim = self.label_processor.vocab_size
        config = self.hyper_parameters

        embed_model = self.embedding.embed_model
        # Query embeddings of shape [batch_size, Tq, dimension].
        query_embeddings = embed_model.output
        # Value embeddings of shape [batch_size, Tv, dimension].
        value_embeddings = embed_model.output

        # CNN layers are shared between the query and value paths.
        cnn_layer_1 = L.Conv1D(**config['conv_layer1'])
        # Query encoding of shape [batch_size, Tq, filters].
        query_seq_encoding = cnn_layer_1(query_embeddings)
        # Value encoding of shape [batch_size, Tv, filters].
        value_seq_encoding = cnn_layer_1(value_embeddings)

        cnn_layer_2 = L.Conv1D(**config['conv_layer2'])
        query_seq_encoding = cnn_layer_2(query_seq_encoding)
        value_seq_encoding = cnn_layer_2(value_seq_encoding)

        cnn_layer_3 = L.Conv1D(**config['conv_layer3'])
        query_seq_encoding = cnn_layer_3(query_seq_encoding)
        value_seq_encoding = cnn_layer_3(value_seq_encoding)

        # Query-value attention of shape [batch_size, Tq, filters].
        query_value_attention_seq = L.Attention()(
            [query_seq_encoding, value_seq_encoding])

        # Reduce over the sequence axis to produce encodings of shape
        # [batch_size, filters].
        query_encoding = L.GlobalMaxPool1D()(query_seq_encoding)
        query_value_attention = L.GlobalMaxPool1D()(query_value_attention_seq)

        # Concatenate query and document encodings to produce a DNN input layer.
        input_layer = L.Concatenate(axis=-1)([query_encoding, query_value_attention])

        output = L.Dense(output_dim, **config['layer_output'])(input_layer)
        output = self._activation_layer()(output)

        self.tf_model = keras.Model(embed_model.input, output)
Exemplo n.º 15
0
    def build_model(
        self, optimizer="adam", dropout_rate=0.0, neurons=100, middle_layers=1
    ):
        """Build and compile an LSTM encoder-decoder with a keras Attention layer.

        Args:
            optimizer: optimizer name or instance passed to compile.
            dropout_rate: dropout applied after every LSTM stage.
            neurons: LSTM width for the encoder stack and final decoder LSTM.
            middle_layers: number of stacked LSTM+Dropout stages per side.

        Returns:
            A compiled keras Model mapping [encoder_in, decoder_in] to a
            linear vector of self.output_timesteps values, trained with
            mean absolute percentage error.
        """

        encoder_in = L.Input(
            shape=(self.context_timesteps, self.context_features), name="Encoder_input"
        )
        middle_encoder = encoder_in
        for i in range(middle_layers):
            lstm = L.LSTM(neurons, return_sequences=True)
            middle_encoder = lstm(middle_encoder)
            middle_encoder = L.Dropout(dropout_rate)(middle_encoder)
        # Final encoder LSTM also exposes its (h, c) states to seed the decoder.
        encoder = L.LSTM(neurons, return_sequences=True, return_state=True)
        encoder_out, state_h, state_c = encoder(middle_encoder)
        encoder_out = L.Dropout(dropout_rate)(encoder_out)
        encoder_states = [state_h, state_c]

        decoder_in = L.Input(
            shape=(self.input_timesteps, self.input_features), name="Decoder_input"
        )
        middle_decoder = decoder_in
        for i in range(middle_layers):
            lstm = L.LSTM(self.input_features, return_sequences=True)
            middle_decoder = lstm(middle_decoder)
            middle_decoder = L.Dropout(dropout_rate)(middle_decoder)
        decoder_lstm = L.LSTM(neurons, return_sequences=True, return_state=True)
        decoder_out, _, _ = decoder_lstm(middle_decoder, initial_state=encoder_states)
        decoder_out = L.GlobalAveragePooling1D()(decoder_out)
        decoder_out = L.Dropout(dropout_rate)(decoder_out)

        # NOTE(review): decoder_out is 2-D after GlobalAveragePooling1D while
        # encoder_out is 3-D; keras Attention normally expects rank-3
        # [batch, steps, dim] query/value tensors - confirm this pairing
        # builds and runs as intended.
        attn_out = L.Attention(name="attention_layer")([encoder_out, decoder_out])
        attn_out = L.GlobalAveragePooling1D()(attn_out)
        decoder_concat = L.Concatenate(name="concat_layer")([decoder_out, attn_out])

        # decoder_concat = L.Flatten()(decoder_concat)
        decoder_dense = L.Dense(self.output_timesteps, activation="linear")

        decoder_out = decoder_dense(decoder_concat)

        model = M.Model([encoder_in, decoder_in], decoder_out)

        mape = "mean_absolute_percentage_error"
        model.compile(loss=mape, optimizer=optimizer, metrics=[mape])

        return model
Exemplo n.º 16
0
    def __init__(self, model_para):
        """Read hyper-parameters from model_para and build CNN + BiGRU + attention layers.

        Expected keys in model_para: 'CNN_channel', 'filter_size',
        'pooling_size', 'RNN_channel', 'sample_len', 'win_size', 'signal_ch',
        'fc_channel'.
        """
        super(My_Model, self).__init__()

        self.cnn_ch = model_para['CNN_channel']
        # NOTE(review): cnn_size and cnn_f_size both read 'filter_size';
        # cnn_size may have been meant to read a different key - confirm.
        self.cnn_size = model_para['filter_size']
        self.cnn_pool_size = model_para['pooling_size']
        self.cnn_f_size = model_para['filter_size']
        self.rnn_ch = model_para['RNN_channel']

        # Input geometry (sample length, window size, signal channels).
        self.s_l = model_para['sample_len']
        self.win_size = model_para['win_size']
        self.s_ch = model_para['signal_ch']

        self.fc_ch = model_para['fc_channel']

        # Two conv + max-pool stages.
        self.conv1 = layers.Conv1D(self.cnn_ch, self.cnn_f_size, strides=1)
        self.conv2 = layers.Conv1D(self.cnn_ch, self.cnn_f_size, strides=1)
        self.pool1 = layers.MaxPool1D(pool_size=self.cnn_pool_size)
        self.pool2 = layers.MaxPool1D(pool_size=self.cnn_pool_size)

        # Two bidirectional GRU stages, each built from an explicit
        # forward layer and a go_backwards backward layer.
        self.rnn_1_f = layers.GRU(self.rnn_ch, return_sequences=True)
        self.rnn_1_b = layers.GRU(self.rnn_ch,
                                  return_sequences=True,
                                  go_backwards=True)
        self.bi_1 = layers.Bidirectional(self.rnn_1_f,
                                         backward_layer=self.rnn_1_b)

        self.rnn_2_f = layers.GRU(self.rnn_ch, return_sequences=True)
        self.rnn_2_b = layers.GRU(self.rnn_ch,
                                  return_sequences=True,
                                  go_backwards=True)
        self.bi_2 = layers.Bidirectional(self.rnn_2_f,
                                         backward_layer=self.rnn_2_b)

        # Fully-connected head: two dense+dropout stages, scalar output.
        self.d1 = layers.Dense(self.fc_ch)

        self.dropout1 = layers.Dropout(0.5)
        self.d2 = layers.Dense(self.fc_ch)
        self.dropout2 = layers.Dropout(0.5)
        self.d3 = layers.Dense(1)
        self.at = layers.Attention(use_scale=False)
Exemplo n.º 17
0
def create_cross_attention_module(latent_dim, data_dim, projection_dim,
                                  ffn_units, dropout_rate):
    """Build a Perceiver-style cross-attention block as a keras Model.

    Args:
        latent_dim: number of latent array positions.
        data_dim: number of data (encoded image) positions.
        projection_dim: channel width of both arrays.
        ffn_units: hidden sizes forwarded to create_ffn.
        dropout_rate: dropout rate forwarded to create_ffn.

    Returns:
        keras.Model mapping {'latent_array', 'data_array'} inputs to a
        [batch_size, latent_dim, projection_dim] output.
    """

    inputs = {
        # Receive the latent array as an input of shape [1, latent_dim, projection_dim].
        "latent_array": layers.Input(shape=(latent_dim, projection_dim)),
        # Receive the data_array (encoded image) as an input of shape [batch_size, data_dim, projection_dim].
        "data_array": layers.Input(shape=(data_dim, projection_dim)),
    }

    # Apply layer norm to the inputs
    latent_array = layers.LayerNormalization(epsilon=1e-6)(
        inputs["latent_array"])
    data_array = layers.LayerNormalization(epsilon=1e-6)(inputs["data_array"])

    # Create query tensor: [1, latent_dim, projection_dim].
    query = layers.Dense(units=projection_dim)(latent_array)
    # Create key tensor: [batch_size, data_dim, projection_dim].
    key = layers.Dense(units=projection_dim)(data_array)
    # Create value tensor: [batch_size, data_dim, projection_dim].
    value = layers.Dense(units=projection_dim)(data_array)

    # Generate cross-attention outputs: [batch_size, latent_dim, projection_dim].
    # NOTE(review): keras layers.Attention interprets a 3-tensor input list
    # as [query, value, key], but this call passes [query, key, value]; the
    # shapes match because key and value are both projections of data_array,
    # so the roles are merely swapped relative to the comments - confirm
    # this matches the intended design.
    attention_output = layers.Attention(
        use_scale=True, dropout=0.1)([query, key, value],
                                     return_attention_scores=False)
    # Skip connection 1.
    attention_output = layers.Add()([attention_output, latent_array])

    # Apply layer norm.
    attention_output = layers.LayerNormalization(
        epsilon=1e-6)(attention_output)
    # Apply Feedforward network.
    ffn = create_ffn(hidden_units=ffn_units, dropout_rate=dropout_rate)
    outputs = ffn(attention_output)
    # Skip connection 2.
    outputs = layers.Add()([outputs, attention_output])

    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model
def Build_Attention_layer(Parametre_layer, encoder, decoder):
    """Combine decoder outputs with an attention context over encoder outputs.

    Args:
        Parametre_layer: dict with keys 'type_attention' ("Luong",
            "Luong_keras" or "Bah_keras"), 'use_scale',
            'use_self_attention' and 'dropout'.
        encoder: encoder output sequence tensor.
        decoder: sequence/tuple whose first element is the decoder output
            sequence tensor.

    Returns:
        The concatenation of the attention context and decoder[0].

    Raises:
        ValueError: if 'type_attention' is not a supported value (the
            original code fell through all branches and crashed with
            UnboundLocalError on the return statement instead).
    """
    attention_type = Parametre_layer["type_attention"]
    if attention_type == "Luong":
        # Manual Luong (dot-product) attention over the raw tensors.
        attention = L.dot([decoder[0], encoder], axes=[2, 2])
        attention = L.Activation('softmax')(attention)
        context = L.dot([attention, encoder], axes=[2, 1])
        decoder_combined_context = K.concatenate([context, decoder[0]])
    elif attention_type == "Luong_keras":
        # Keras built-in dot-product (Luong) attention.
        context_vector = L.Attention(
            use_scale=Parametre_layer["use_scale"],
            causal=Parametre_layer["use_self_attention"],
            dropout=Parametre_layer["dropout"])([decoder[0], encoder])
        decoder_combined_context = K.concatenate([context_vector, decoder[0]])
    elif attention_type == "Bah_keras":
        # Keras AdditiveAttention == Bahdanau-style attention.
        context_vector = L.AdditiveAttention(
            use_scale=Parametre_layer["use_scale"],
            causal=Parametre_layer["use_self_attention"],
            dropout=Parametre_layer["dropout"])([decoder[0], encoder])
        decoder_combined_context = K.concatenate([context_vector, decoder[0]])
    else:
        raise ValueError(
            "Unknown type_attention: {!r}".format(attention_type))

    return decoder_combined_context
Exemplo n.º 19
0
# Self-attentive stacked-LSTM classifier for 104-step univariate sequences.
rnn_units = 64

input_layer = keras.Input(shape=(104, 1))
# x = layers.Embedding(input_dim=859, output_dim=10,mask_zero='True')(input_layer)
x = layers.LSTM(rnn_units,
                return_sequences=True,
                recurrent_initializer='orthogonal',
                activation='tanh')(input_layer)
# Second LSTM adds dropout and still returns the full sequence for attention.
x = layers.LSTM(rnn_units,
                return_sequences=True,
                recurrent_initializer='orthogonal',
                activation='tanh',
                dropout=0.5)(x)

# Self-attention: query and value are the same tensor.
x = layers.Attention()([x, x])

# NOTE(review): Dense on a 3-D tensor applies per timestep, so the model
# output is (batch, 104, 17) - confirm the labels are per-timestep.
x = layers.Dense(64, activation='relu')(x)
# 17 independent sigmoids -> multi-label classification.
output = layers.Dense(17, activation='sigmoid')(x)

model = keras.Model(input_layer, output)
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001),
              loss=keras.losses.BinaryCrossentropy(),
              metrics=['binary_accuracy'])

print(model.summary())

# history = model.fit(train_x, train_y, epochs=20, batch_size=20, validation_data=(test_x, test_y), shuffle=False)
# model.save("model\\att_model_313.h5")

# plot train and validation loss
Exemplo n.º 20
0
def build_model(vectors, settings, compile=True):
    """Decomposable-attention sentence-pair model (attend / compare / aggregate).

    Args:
        vectors: pretrained embedding vectors handed to create_embedding.
        settings: hyper-parameter dict; uses keys "maxlen", "n_hidden",
            "n_classes", "emb_dim", "vocab_size", "emb_trainable",
            "dropout", "self_attention", "entail_dir" ("both"/"left"/else
            right), "distilled", "batch_size", "loss", "optimizer",
            "metrics".
        compile: when True, compile the model before returning.
            (Parameter name shadows the builtin; kept for callers.)

    Returns:
        A keras Model over two int32 word-id inputs of length maxlen.
    """
    max_length = settings["maxlen"]
    nr_hidden = settings["n_hidden"]
    nr_class = settings["n_classes"]

    input1 = layers.Input(shape=(max_length, ), dtype="int32", name="words1")
    input2 = layers.Input(shape=(max_length, ), dtype="int32", name="words2")

    # embeddings (projected)
    embed = create_embedding(vectors, settings["emb_dim"],
                             settings["vocab_size"], max_length, nr_hidden,
                             settings["emb_trainable"])

    a = embed(input1)
    b = embed(input2)

    # step 1: attend
    # self-attend
    if settings["self_attention"]:
        # Each sentence attends over itself through its own feedforward net.
        S_a = create_feedforward(nr_hidden, dropout=settings["dropout"])
        S_b = create_feedforward(nr_hidden, dropout=settings["dropout"])
        a_p = layers.Attention()([S_a(a), S_a(a)])
        b_p = layers.Attention()([S_b(b), S_b(b)])
        # self_att_a = layers.dot([S_a(a), S_a(a)], axes=-1)
        # self_att_b = layers.dot([S_b(b), S_b(b)], axes=-1)
        # self_norm_a = layers.Lambda(normalizer(1))(self_att_a)
        # self_norm_b = layers.Lambda(normalizer(1))(self_att_b)
        # a_p = layers.dot([self_norm_a, a], axes=1)
        # b_p = layers.dot([self_norm_b, b], axes=1)
    else:
        a_p = a
        b_p = b

    # attend: raw cross-sentence attention scores.
    F = create_feedforward(nr_hidden, dropout=settings["dropout"])
    att_weights = layers.dot([F(a_p), F(b_p)], axes=-1)

    G = create_feedforward(nr_hidden)

    if settings["entail_dir"] == "both":
        # Soft-align each sentence against the other in both directions.
        norm_weights_a = layers.Lambda(normalizer(1))(att_weights)
        norm_weights_b = layers.Lambda(normalizer(2))(att_weights)
        alpha = layers.dot([norm_weights_a, a_p], axes=1)
        beta = layers.dot([norm_weights_b, b_p], axes=1)

        # step 2: compare
        comp1 = layers.concatenate([a_p, beta])
        comp2 = layers.concatenate([b_p, alpha])
        v1 = layers.TimeDistributed(G)(comp1)
        v2 = layers.TimeDistributed(G)(comp2)

        # step 3: aggregate
        v1_sum = layers.Lambda(sum_word)(v1)
        v2_sum = layers.Lambda(sum_word)(v2)
        concat = layers.concatenate([v1_sum, v2_sum])

    elif settings["entail_dir"] == "left":
        # Only align sentence a into b (one-directional entailment).
        norm_weights_a = layers.Lambda(normalizer(1))(att_weights)
        alpha = layers.dot([norm_weights_a, a], axes=1)
        comp2 = layers.concatenate([b, alpha])
        v2 = layers.TimeDistributed(G)(comp2)
        v2_sum = layers.Lambda(sum_word)(v2)
        concat = v2_sum

    else:
        # Only align sentence b into a.
        norm_weights_b = layers.Lambda(normalizer(2))(att_weights)
        beta = layers.dot([norm_weights_b, b], axes=1)
        comp1 = layers.concatenate([a, beta])
        v1 = layers.TimeDistributed(G)(comp1)
        v1_sum = layers.Lambda(sum_word)(v1)
        concat = v1_sum

    H = create_feedforward(nr_hidden, dropout=settings["dropout"])
    out = H(concat)
    if settings['distilled']:
        # Distillation: raw logits with a knowledge-distillation loss.
        out = layers.Dense(nr_class)(out)
        loss = KDLoss(settings["batch_size"])
    else:
        # Standard path: sigmoid output (kept in float32 for mixed precision).
        out = layers.Dense(nr_class)(out)
        out = layers.Activation('sigmoid', dtype='float32')(out)
        loss = settings["loss"]

    model = Model([input1, input2], out)

    if compile:
        model.compile(
            optimizer=settings["optimizer"],
            loss=loss,
            metrics=settings["metrics"](),  # Call get_metrics
            experimental_run_tf_function=False,
        )

    return model
Exemplo n.º 21
0
def build_CTS(pos_vocab_size, maxnum, maxlen, readability_feature_count,
              linguistic_feature_count, configs, output_dim):
    """Build the multi-trait CTS essay-scoring model.

    The model embeds a POS-tag sequence, applies a per-sentence CNN with
    attention pooling, then runs one LSTM + attention branch per output
    trait.  Each trait representation is concatenated with hand-crafted
    linguistic and readability features, cross-attended against the other
    traits, and mapped through a sigmoid to a per-trait score.

    Args:
        pos_vocab_size: size of the POS-tag vocabulary.
        maxnum: maximum number of sentences per essay.
        maxlen: maximum number of tokens per sentence.
        readability_feature_count: dimensionality of the readability features.
        linguistic_feature_count: dimensionality of the linguistic features.
        configs: object exposing EMBEDDING_DIM, DROPOUT, CNN_FILTERS,
            CNN_KERNEL_SIZE and LSTM_UNITS.
        output_dim: number of traits (model outputs).

    Returns:
        A compiled ``keras.Model`` taking
        [pos_word_input, linguistic_input, readability_input].
    """
    pos_embedding_dim = configs.EMBEDDING_DIM
    dropout_prob = configs.DROPOUT
    cnn_filters = configs.CNN_FILTERS
    cnn_kernel_size = configs.CNN_KERNEL_SIZE
    lstm_units = configs.LSTM_UNITS

    pos_word_input = layers.Input(shape=(maxnum * maxlen, ),
                                  dtype='int32',
                                  name='pos_word_input')
    pos_x = layers.Embedding(output_dim=pos_embedding_dim,
                             input_dim=pos_vocab_size,
                             input_length=maxnum * maxlen,
                             weights=None,
                             mask_zero=True,
                             name='pos_x')(pos_word_input)
    # Zero out the embedding vectors of padded positions; Embedding's
    # mask_zero only propagates a mask, it does not zero the values.
    pos_x_maskedout = ZeroMaskedEntries(name='pos_x_maskedout')(pos_x)
    pos_drop_x = layers.Dropout(dropout_prob,
                                name='pos_drop_x')(pos_x_maskedout)
    # (batch, maxnum*maxlen, dim) -> (batch, maxnum, maxlen, dim)
    pos_resh_W = layers.Reshape((maxnum, maxlen, pos_embedding_dim),
                                name='pos_resh_W')(pos_drop_x)
    pos_zcnn = layers.TimeDistributed(layers.Conv1D(cnn_filters,
                                                    cnn_kernel_size,
                                                    padding='valid'),
                                      name='pos_zcnn')(pos_resh_W)
    # Attention-pool the CNN features over the tokens of each sentence.
    pos_avg_zcnn = layers.TimeDistributed(Attention(),
                                          name='pos_avg_zcnn')(pos_zcnn)

    linguistic_input = layers.Input((linguistic_feature_count, ),
                                    name='linguistic_input')
    readability_input = layers.Input((readability_feature_count, ),
                                     name='readability_input')

    # One independent LSTM + attention branch per trait.
    pos_hz_lstm_list = [
        layers.LSTM(lstm_units, return_sequences=True)(pos_avg_zcnn)
        for _ in range(output_dim)
    ]
    pos_avg_hz_lstm_list = [
        Attention()(pos_hz_lstm) for pos_hz_lstm in pos_hz_lstm_list
    ]
    pos_avg_hz_lstm_feat_list = [
        layers.Concatenate()([pos_rep, linguistic_input, readability_input])
        for pos_rep in pos_avg_hz_lstm_list
    ]
    # Stack the per-trait representations: (batch, output_dim, features).
    pos_avg_hz_lstm = tf.concat([
        layers.Reshape((1, lstm_units + linguistic_feature_count +
                        readability_feature_count))(pos_rep)
        for pos_rep in pos_avg_hz_lstm_feat_list
    ],
                                axis=-2)

    final_preds = []
    for index in range(output_dim):
        # Attend the target trait over all *other* trait representations.
        mask = np.ones(output_dim, dtype=bool)
        mask[index] = False
        non_target_rep = tf.boolean_mask(pos_avg_hz_lstm, mask, axis=-2)
        target_rep = pos_avg_hz_lstm[:, index:index + 1]
        att_attention = layers.Attention()([target_rep, non_target_rep])
        attention_concat = tf.concat([target_rep, att_attention], axis=-1)
        attention_concat = layers.Flatten()(attention_concat)
        final_pred = layers.Dense(units=1,
                                  activation='sigmoid')(attention_concat)
        final_preds.append(final_pred)

    y = layers.Concatenate()(final_preds)

    model = keras.Model(
        inputs=[pos_word_input, linguistic_input, readability_input],
        outputs=y)

    model.summary()

    model.compile(loss=masked_loss_function, optimizer='rmsprop')

    return model
Exemplo n.º 22
0
        query_with_time_axis = tf.expand_dims(query, 1)

        # 这里用的是concat公式,但是中并不是contact,而是广播相加,应该可以尝试换成dot和general
        # max_length就是time steps
        # score shape == (batch_size, max_length, 1)
        # (batch_size, max_length, units) -> (batch_size, max_length, 1)
        score = self.V(
            tf.nn.tanh(self.W1(query_with_time_axis) + self.W2(values)))

        # (batch_size, max_length, 1)
        attention_weights = tf.nn.softmax(score, axis=1)

        # (batch_size, max_length, 1) * (batch_size, max_length, hidden_size) -> (batch_size, max_length, hidden_size)
        # 每个权重对乘以一列units
        context_vector = attention_weights * values
        # (batch_size, max_length, hidden_size)->(batch_size, hidden_size)
        context_vector = tf.reduce_sum(context_vector, axis=1)

        return context_vector, attention_weights


if __name__ == '__main__':
    # Smoke test: compare the custom BahdanauAttention layer against
    # Keras' built-in dot-product Attention on the outputs of an LSTM.
    inp = layers.Input(shape=(20, 4))
    seq_out, state_h, state_c = layers.LSTM(10,
                                            return_state=True,
                                            return_sequences=True)(inp)
    bahdanau_out = BahdanauAttention(10)(state_h, seq_out)
    luong_out = layers.Attention()([state_h, seq_out])
    m1 = Model(inp, bahdanau_out)
    m2 = Model(inp, luong_out)
    for demo_model in (m1, m2):
        demo_model.summary()
Exemplo n.º 23
0
    def __init__(self,
                 nqc,
                 value_encodings,
                 relation_encodings,
                 num_gpus=1,
                 encoder=None):
        """Builds a simple, fully-connected model to predict the outcome set given a query string.

    Args:
      nqc: NeuralQueryContext
      value_encodings: (bert features for values, length of value span)
      relation_encodings: (bert features for relations, length of relation span)
      num_gpus: number of gpus for distributed computation
      encoder: encoder (layers.RNN) for parameter sharing between train and dev

    Needs:
      self.input_ph: input to encoder (either one-hot or BERT layers)
      self.mask_ph: mask for the input
      self.correct_set_ph.name: target labels (if loss or accuracy is computed)
      self.prior_start: sparse matrix for string similarity features
      self.is_training: whether the model should is training (for dropout)

    Exposes:
      self.loss: objective for loss
      self.accuracy: mean accuracy metric (P_{predicted}(gold))
      self.accuracy_per_ex: detailed per example accuracy
      self.log_nql_pred_set: predicted entity set (in nql)
      self.log_decoded_relations: predicted relations (as indices)
      self.log_start_values: predicted start values (in nql)
      self.log_start_cmps: components of predicted start values (in nql)
    """
        # Encodings should have the same dimensions
        assert value_encodings[0].shape[-1] == relation_encodings[0].shape[-1]
        self.context = nqc
        # TF1-style placeholders: fed at session run time.
        self.input_ph = tf.placeholder(tf.float32,
                                       shape=(None, FLAGS.max_query_length,
                                              value_encodings[0].shape[-1]),
                                       name="oh_seq_ph")
        self.mask_ph = tf.placeholder(tf.float32,
                                      shape=(None, FLAGS.max_query_length),
                                      name="oh_mask_ph")
        self.debug = None
        layer_size = FLAGS.layer_size
        num_layers = FLAGS.num_layers
        max_properties = FLAGS.max_properties
        logits_strategy = FLAGS.logits
        dropout_rate = FLAGS.dropout

        # Batch size is dynamic; derive it from the input placeholder.
        inferred_batch_size = tf.shape(self.input_ph)[0]
        self.is_training = tf.placeholder(tf.bool, shape=[])
        value_tensor = util.reshape_to_tensor(value_encodings[0],
                                              value_encodings[1])
        relation_tensor = util.reshape_to_tensor(relation_encodings[0],
                                                 relation_encodings[1])
        # The last state of LSTM encoder is the representation of the input string
        with tf.variable_scope("model"):
            # Build all the model parts:

            #   encoder: LSTM encoder
            #   prior: string features
            #   {value, relation}_similarity: learned embedding similarty
            #   decoder: LSTM decoder
            #   value_model: map from encoder to key for attention
            #   attention: Luong (dot product) attention

            # Builds encoder - note that this is in keras
            self.encoder = self._build_encoder(encoder, layer_size, num_layers)

            # Build module to turn prior (string features) into logits
            self.prior_start = tf.sparse.placeholder(
                tf.float32,
                name="prior_start_ph",
                shape=[inferred_batch_size, value_tensor.shape[1]])

            with tf.variable_scope("prior"):
                prior = Prior()

            # Build similarity module - biaffine qAr
            with tf.variable_scope("value_similarity"):
                value_similarity = Similarity(layer_size, value_tensor,
                                              num_gpus)
            # Build relation decoder
            with tf.variable_scope("relation_decoder"):
                rel_dec_rnn_layers = [
                    contrib_rnn.LSTMBlockCell(layer_size,
                                              name=("attr_lstm_%d" % i))
                    for (i, layer_size) in enumerate([layer_size] * num_layers)
                ]
                relation_decoder_cell = tf.nn.rnn_cell.MultiRNNCell(
                    rel_dec_rnn_layers)
                tf.logging.info(
                    "relation decoder lstm has state of size: {}".format(
                        relation_decoder_cell.state_size))

            # Build similarity module - biaffine qAr
            with tf.variable_scope("relation_similarity"):
                relation_similarity = Similarity(layer_size, relation_tensor,
                                                 1)
            with tf.variable_scope("attention"):
                attention = layers.Attention()
            # Maps the decoder state to a query key for the attention below.
            value_model = tf.get_variable(
                "value_transform",
                shape=[layer_size, relation_decoder_cell.output_size],
                trainable=True)

        # Initialization for logging, variables shouldn't be used elsewhere
        log_decoded_starts = []
        log_start_logits = []
        log_decoded_relations = []

        # Initialization to prepare before first iteration of loop
        prior_logits_0 = prior.compute_logits(
            tf.sparse.to_dense(self.prior_start))
        # Running intersection of entities reached so far; starts as the
        # full entity set.
        cumulative_entities = nqc.all("id_t")
        relation_decoder_out = tf.zeros([inferred_batch_size, layer_size])
        # The encoder's final state seeds the relation decoder's state.
        encoder_output = self.encoder(self.input_ph, mask=self.mask_ph)
        query_encoder_out = encoder_output[0]
        relation_decoder_state = encoder_output[1:]

        # Initialization for property loss, equal to log vars but separating
        value_dist = []
        relation_dist = []

        # Decode up to max_properties (start value, relation) pairs and
        # intersect the entity sets they reach.
        for i in range(max_properties):
            prior_logits = tf.layers.dropout(prior_logits_0,
                                             rate=dropout_rate,
                                             training=self.is_training)
            # Use the last state to determine key; more stable than last output
            query_key = tf.nn.relu(
                tf.matmul(
                    tf.expand_dims(relation_decoder_state[-1][-1], axis=1),
                    value_model))

            query_emb = tf.squeeze(attention(
                [query_key, query_encoder_out],
                mask=[None, tf.cast(self.mask_ph, tf.bool)]),
                                   axis=1)

            similarity_logits = value_similarity.compute_logits(query_emb)
            # NOTE(review): assumes logits_strategy is one of
            # {"prior", "sim", "mixed"}; any other value leaves total_logits
            # unbound and raises NameError below.
            if logits_strategy == "prior":
                total_logits = prior_logits
            elif logits_strategy == "sim":
                total_logits = similarity_logits
            elif logits_strategy == "mixed":
                total_logits = prior_logits + similarity_logits
            total_dist = contrib_layers.softmax(total_logits)
            values_pred = nqc.as_nql(total_dist, "val_g")
            with tf.variable_scope("start_follow_{}".format(i)):
                start_pred = nqc.all("v_t").follow(
                    values_pred)  # find starting nodes

            # Given the previous set of attributes, where are we going?
            (relation_decoder_out,
             relation_decoder_state) = relation_decoder_cell(
                 relation_decoder_out, relation_decoder_state)
            pred_relation = tf.nn.softmax(
                relation_similarity.compute_logits(relation_decoder_out))
            if FLAGS.enforce_type:
                # Force the first hop to be IS_A and forbid it afterwards.
                if i == 0:
                    is_adjust = nqc.as_tf(nqc.one(IS_A, "rel_g"))
                else:
                    is_adjust = 1 - nqc.as_tf(nqc.one(IS_A, "rel_g"))
                pred_relation = pred_relation * is_adjust
            nql_pred_relation = nqc.as_nql(pred_relation, "rel_g")
            # Conjunctive (& start.follow() & start.follow()...).
            with tf.variable_scope("relation_follow_{}".format(i)):
                current_entities = start_pred.follow(nql_pred_relation)
            cumulative_entities = cumulative_entities & current_entities

            # For property loss and regularization
            value_dist.append(total_dist)
            relation_dist.append(pred_relation)

            # Store predictions for logging
            log_decoded_starts.append(start_pred)
            log_decoded_relations.append(pred_relation)
            log_start_logits.append([prior_logits, similarity_logits])

        (loss, pred_set_tf,
         pred_set_tf_norm) = self._compute_loss(cumulative_entities)
        property_loss = self._compute_property_loss(value_dist, relation_dist)
        (accuracy_per_ex,
         accuracy) = self._compute_accuracy(cumulative_entities, pred_set_tf)
        value_loss = self._compute_distribution_regularizer(value_dist)
        relation_loss = self._compute_distribution_regularizer(relation_dist)
        self.regularization = FLAGS.time_reg * (value_loss + relation_loss)
        # NOTE(review): the regularization term is *subtracted* from the
        # loss here — confirm the sign is intentional (it rewards, rather
        # than penalizes, large regularizer values when minimizing).
        self.loss = loss - self.regularization
        self.property_loss = property_loss
        self.accuracy_per_ex = accuracy_per_ex
        self.accuracy = accuracy

        # Debugging/logging information
        # (steps, batch, rels) -> (batch, steps, rels) for logging.
        log_decoded_relations = tf.transpose(tf.stack(log_decoded_relations),
                                             [1, 0, 2])
        tf.logging.info("decoded relations has shape: {}".format(
            log_decoded_relations.shape))
        self.log_start_values = log_decoded_starts
        self.log_start_cmps = [[
            nqc.as_nql(logits, "val_g") for logits in comp
        ] for comp in log_start_logits]
        self.log_decoded_relations = tf.nn.top_k(log_decoded_relations, k=5)
        self.log_nql_pred_set = nqc.as_nql(pred_set_tf_norm, "id_t")
Exemplo n.º 24
0
def get_global_attention_layer(input_number,
                               one_input_count=1,
                               idx=0,
                               head_units=512,
                               end_head_units=1024,
                               hidden_units=1024,
                               dropout=0.2,
                               key_size=128,
                               n_multi_heads=2,
                               n_add_heads=2):
    """
    Get the global attention modules with residual fully connected modules for the descriptor network as described
    in the paper. These layers gather information over all lines and output them in an embedding.
    Multiple layers can be connected at the end.
    :param input_number: The maximum number of lines.
    :param one_input_count: The input dimension of the global attention layer. If this is one, the network will
                            learn the query, if not, the network will transform the one_input into a query with a
                            learned layer.
    :param idx: The unique index of this layer (only for the name).
    :param head_units: The dimensionality of the output line embeddings.
    :param end_head_units: The dimensionality of the output of this global attention layer.
    :param hidden_units: The dimensionality of the hidden layer in the residual fully connected modules.
    :param dropout: The drop_out ratio of the fully connected layers.
    :param key_size: The dimensionality of the attention heads.
    :param n_multi_heads: The number of dot product attention heads.
    :param n_add_heads: The number of additive attention heads.
    :return: A Keras model of the global attention module with residual fully connected modules.
    """
    # At least one attention head is required, otherwise there is nothing to
    # concatenate below.  (The previous `>= 0` check was always true and did
    # not guard the outputs[0] access.)
    assert n_multi_heads + n_add_heads > 0

    output_size = end_head_units

    model_input = kl.Input(shape=(input_number, head_units))
    one_input = kl.Input(shape=(1, one_input_count))

    # With a one-dimensional query input, the query itself is learned as the
    # (bias-free) Dense kernel; otherwise the input is transformed.
    use_bias = not one_input_count == 1

    def _attention_head(attention_layer):
        """Build one attention head (keys/queries projections + attention)."""
        # For the keys, bias should actually be added. This should be changed in the future.
        # keys = kl.TimeDistributed(kl.Dense(key_size, use_bias=use_bias))(model_input)
        keys = kl.TimeDistributed(kl.Dense(key_size))(model_input)
        keys = kl.TimeDistributed(kl.BatchNormalization())(keys)
        queries = kl.TimeDistributed(kl.Dense(key_size,
                                              use_bias=use_bias))(one_input)
        queries = kl.TimeDistributed(kl.BatchNormalization())(queries)
        return attention_layer([queries, keys])

    # Dot-product heads first, then additive heads (order preserved).
    outputs = [
        _attention_head(kl.Attention(use_scale=True))
        for _ in range(n_multi_heads)
    ]
    outputs += [
        _attention_head(kl.AdditiveAttention(use_scale=True))
        for _ in range(n_add_heads)
    ]

    if len(outputs) > 1:
        output = kl.Concatenate()(outputs)
    else:
        output = outputs[0]

    output = kl.Dense(output_size)(output)
    output = kl.Dropout(dropout)(output)
    output = kl.LeakyReLU()(output)
    output = kl.BatchNormalization()(output)

    # Add a residual fully connected layer at the end.
    output_2 = kl.Dense(hidden_units)(output)
    output_2 = kl.BatchNormalization()(output_2)
    output_2 = kl.LeakyReLU()(output_2)
    output_2 = kl.Dense(output_size)(output_2)
    output_2 = kl.BatchNormalization()(output_2)

    output = kl.Add()([output, output_2])
    output = kl.LeakyReLU()(output)

    model = km.Model(inputs=[model_input, one_input],
                     outputs=output,
                     name='global_attention_{}'.format(idx))
    return model
    def __init__(self,
                 z1_dim=32,
                 z2_dim=32,
                 z1_rhus=[256, 256],
                 z2_rhus=[256, 256],
                 tr_shape=(20, 80),
                 mu_nl=None,
                 logvar_nl=None,
                 name="encoder",
                 **kwargs):
        """Two-level VAE encoder.

        Builds stacked bidirectional LSTMs with a learned-query attention
        head for the z2 branch, plain stacked bidirectional LSTMs for the
        z1 branch, and Dense heads producing the (mu, logvar) parameters
        of both latent distributions.
        """
        super(Encoder, self).__init__(name=name, **kwargs)

        def _bilstm(units, **lstm_kwargs):
            # Concat-merged bidirectional LSTM used by both branches.
            return layers.Bidirectional(layers.LSTM(units,
                                                    time_major=False,
                                                    **lstm_kwargs),
                                        merge_mode='concat')

        def _fc(units, activation=None):
            # Dense layer with the initialisation used throughout this encoder.
            return layers.Dense(units,
                                activation=activation,
                                use_bias=True,
                                kernel_initializer='glorot_uniform',
                                bias_initializer='zeros')

        # latent dimensionalities
        self.z1_dim = z1_dim
        self.z2_dim = z2_dim

        # recurrent hidden-unit sizes for the z2 pre-encoder
        self.z2_rhus = z2_rhus

        # z2 branch: two Bi-LSTMs followed by scaled dot-product attention
        # with a trainable query.
        self.lstm_layer1_z2 = _bilstm(self.z2_rhus[0], return_sequences=True)
        self.lstm_layer2_z2 = _bilstm(self.z2_rhus[1], return_sequences=True)
        self.attn_fc_v = _fc(512)
        self.attn_fc_q = _fc(512)
        self.attn = layers.Attention(use_scale=True)
        self.query = tf.Variable(tf.random.normal([256, 1, 512], stddev=1.0),
                                 trainable=True)

        # recurrent hidden-unit sizes for the z1 pre-encoder
        self.z1_rhus = z1_rhus

        # z1 branch: two Bi-LSTMs exposing their final states.
        self.lstm_layer1_z1 = _bilstm(self.z1_rhus[0],
                                      return_sequences=True,
                                      return_state=True)
        self.lstm_layer2_z1 = _bilstm(self.z1_rhus[1], return_state=True)

        # Dense heads for the latent Gaussian parameters.
        self.z1mu_fclayer = _fc(z1_dim, activation=mu_nl)
        self.z1logvar_fclayer = _fc(z1_dim, activation=logvar_nl)
        self.z2mu_fclayer = _fc(z2_dim, activation=mu_nl)
        self.z2logvar_fclayer = _fc(z2_dim, activation=logvar_nl)
Exemplo n.º 26
0
def self_attention_model(tensor_in):
    """Apply dot-product self-attention over the flattened spatial
    positions of a rank-4 tensor, then restore the original shape."""
    channels = tensor_in.shape[3]
    flat = layers.Reshape((-1, channels))(tensor_in)
    attended = layers.Attention()([flat, flat])
    return layers.Reshape(tensor_in.shape[1:])(attended)
Exemplo n.º 27
0
    def __init__(self, x_depth, enc_rnn_dim, enc_dropout, dec_rnn_dim, dec_dropout, cont_dim, cat_dim, mu_force, t_gumbel, 
                       style_embed_dim, beta_anneal_steps, kl_reg, rnn_type, attention):
        super(MVAE, self).__init__()
        
        self.summaries = []
        self.features = ['pitch', 'dt', 'duration']

        self.x_depth = x_depth
        self.x_dim = sum(x_depth)
        self.content_embed_dim = 32
        self.t_gumbel = float(t_gumbel)
        self.cont_dim = int(cont_dim)
        self.cat_dim = int(cat_dim)
        self.mu_force = float(mu_force)
        self.beta_anneal_steps = int(beta_anneal_steps)
        self.kl_reg = float(kl_reg)
        self.rnn_type = rnn_type
        self.attention = attention

        self.anneal_kl_loss = True

        enc_rnn_dim = int(enc_rnn_dim)
        enc_dropout = float(enc_dropout)
        dec_rnn_dim = int(dec_rnn_dim)
        dec_dropout = float(dec_dropout)
        style_embed_dim = int(style_embed_dim)

        self.enc_rnn_dim = enc_rnn_dim
        
        self.ohc = tfp.distributions.OneHotCategorical
        self.relaxed_ohc = tfp.distributions.RelaxedOneHotCategorical

        # music content variables
        self.pitch_embedding = tfkl.Embedding(x_depth[0], self.content_embed_dim, name='pitch_embedding')
        self.lstm_encoder = tfkl.Bidirectional(self.create_rnn_layer(enc_rnn_dim, dropout=enc_dropout, return_state=True), name='lstm_encoder')
        self.content_head = tfkl.Dense(512, activation='relu', name='content_head')
        self.z_params = tfkl.Dense(self.cont_dim + self.cont_dim, name='z_params')
        
        # only one of those layers is used, based on rnn_type
        self.lstm_decoder_init_state = tfkl.Dense(dec_rnn_dim * 2, name='lstm_decoder_init', activation='tanh')
        self.gru_decoder_init_state = tfkl.Dense(dec_rnn_dim, name='gru_decoder_init', activation='tanh')

        self.lstm_decoder = self.create_rnn_layer(dec_rnn_dim, dropout=dec_dropout, return_state=True, return_sequences=True, name='lstm_decoder')
        self.logit_layer = tfkl.Dense(self.x_dim, name='content_logits')
        
        self.anneal_step = tf.Variable(0.0, trainable=False)
        
        # music style variables
        self.cat_head = tfkl.Dense(512, activation='relu', name='style_head')
        self.z_cat_logit = tfkl.Dense(self.cat_dim, name='z_cat_logit')
        self.style_embedding = tf.Variable(tf.random.uniform([self.cat_dim, style_embed_dim], -1.0, 1.0), trainable=True, name="style_embedding")

        # Luong attention
        self.query_layer = tfkl.Dense(self.attention)
        self.key_layer = tfkl.Dense(self.attention)
        self.attention_layer = tfkl.Attention(use_scale=True, causal=True, dropout=0.2, name='attention_layer')

        
        # train metric trackers
        self.recon_loss_tracker = tfk.metrics.Mean(name="recon_loss_tracker")
        self.kl_loss_tracker = tfk.metrics.Mean(name="kl_loss_tracker")
        self.style_loss_tracker = tfk.metrics.Mean(name="style_loss_tracker")
        self.loss_tracker = tfk.metrics.Mean(name="loss_tracker")
        
        self.p_acc_tracker = tfk.metrics.Mean(name="p_accuracy_tracker")
        self.dt_acc_tracker = tfk.metrics.Mean(name="dt_accuracy_tracker")
        self.d_acc_tracker = tfk.metrics.Mean(name="d_accuracy_tracker")
        self.style_acc_tracker = tfk.metrics.Mean(name="style_accuracy_tracker")
        
        # test metric trackers
        self.val_recon_loss_tracker = tfk.metrics.Mean(name="val_recon_loss_tracker")
        self.val_kl_loss_tracker = tfk.metrics.Mean(name="val_kl_loss_tracker")
        self.val_style_loss_tracker = tfk.metrics.Mean(name="val_style_loss_tracker")
        self.val_loss_tracker = tfk.metrics.Mean(name="val_loss_tracker")
        
        self.val_p_acc_tracker = tfk.metrics.Mean(name="val_p_accuracy_tracker")
        self.val_dt_acc_tracker = tfk.metrics.Mean(name="val_dt_accuracy_tracker")
        self.val_d_acc_tracker = tfk.metrics.Mean(name="val_d_accuracy_tracker")
        self.val_style_acc_tracker = tfk.metrics.Mean(name="val_style_accuracy_tracker")
        
        # generic accuracy used for computing raw accuracies
        self.accuracy_tracker = tfk.metrics.Accuracy(name='accuracy_tracker')