Example #1
def train(teacher_forcing_prob):
    #tf.debugging.set_log_device_placement(True)

    train_data, test_data, constants = load_dataset()
    encoder_input_data, decoder_input_data, decoder_target_data = train_data

    # Add slices here to train on only a subset of the data
    # encoder_input_data = encoder_input_data[:50]
    # decoder_input_data = decoder_input_data[:50]
    # decoder_target_data = decoder_target_data[:50]

    model = _build_model(constants, teacher_forcing_prob)
    model.summary(line_length=200)

    l = create_losses(constants['MASK_VALUE'])
    o = keras.optimizers.Adam(learning_rate=0.001, beta_1=0.8)
    epochs = int(sys.argv[3])
    model.compile(optimizer=o, loss=l, loss_weights=LOSS_WEIGHTS)

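    # Checkpoint the full model every 150 batches so intermediate epochs can
    # be reloaded later (an integer save_freq counts batches, not epochs).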
    model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
        filepath=MODEL_PATH + '/bigthrush/{epoch:02d}',
        save_weights_only=False,
        save_freq=150)

    model.fit([encoder_input_data, decoder_input_data],
              decoder_target_data,
              batch_size=64,
              epochs=epochs,
              validation_split=0.05,
              shuffle=True,
              verbose=1,
              callbacks=[model_checkpoint_callback])

    model.save('runmodel')
Example #2
def train(teacher_forcing_prob):
    train_data, test_data, constants = feature_extractors.load_dataset()
    encoder_input_data, decoder_input_data, decoder_target_data = train_data

    model = _build_model(constants, teacher_forcing_prob)
    model.summary(line_length=200)

    l = create_losses(constants['MASK_VALUE'])
    o = keras.optimizers.Adam(learning_rate=0.001, beta_1=0.8)
    epochs = int(sys.argv[3])
    model.compile(optimizer=o, loss=l, loss_weights=LOSS_WEIGHTS)

    model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
        filepath=
        'thrush_attn_bidir_32_istermmse_lr001B09_luongattn_big/{epoch:02d}',
        save_weights_only=False,
        save_freq=150)

    model.fit([encoder_input_data, decoder_input_data],
              decoder_target_data,
              batch_size=128,
              epochs=epochs,
              validation_split=0.05,
              shuffle=True,
              verbose=1)

    model.save('runmodel2')
Example #3
def train():
    train_data, test_data, constants = feature_extractors.load_dataset()
    encoder_input_data, decoder_input_data, decoder_target_data = train_data

    # Add slices here to train on only a subset of the data
    # encoder_input_data = encoder_input_data[:50]
    # decoder_input_data = decoder_input_data[:50]
    # decoder_target_data = decoder_target_data[:50]

    model = _build_model(constants)

    l = create_losses(constants['MASK_VALUE'])
    o = keras.optimizers.Adam(learning_rate=0.005, beta_1=0.8)
    epochs = int(sys.argv[3])
    model.compile(optimizer=o, loss=l)
    model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
        filepath=f'thrush_attn_bidir_{LATENT_DIM}',
        save_weights_only=False,
        save_freq=100)
    model.fit([encoder_input_data, decoder_input_data],
              decoder_target_data,
              batch_size=64,
              epochs=epochs,
              validation_split=0.05,
              shuffle=True,
              callbacks=[model_checkpoint_callback],
              verbose=1)
    model.save(f'thrush_attn_bidir_{LATENT_DIM}_{epochs}')
Example #4
def train():
    train_data, test_data, constants = feature_extractors.load_dataset()
    encoder_input_data, decoder_input_data, decoder_target_data = train_data
    model = _build_model(constants)
    model.summary()
    l = keras.losses.MeanSquaredError()
    o = keras.optimizers.Adam(learning_rate=0.0075)
    model.compile(optimizer=o, loss=l)
    model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
            batch_size=64,
            epochs=int(sys.argv[3]),
            validation_split=0.05)
    model.save(f'chorale_model_bidirect_{LATENT_DIM}')
Example #5
def train():
    train_data, test_data, constants = feature_extractors.load_dataset()
    encoder_input_data, decoder_input_data, decoder_target_data = train_data
    model = _build_model(constants)

    l = keras.losses.MeanSquaredError()
    o = keras.optimizers.Adam(learning_rate=0.005)
    model.compile(optimizer=o, loss=l)
    model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
            batch_size=32,
            epochs=100,
            validation_split=0.05)
    model.save('chorale_model_128')
Example #6
def predict():
    train_data, test_data, constants = feature_extractors.load_dataset()
    encoder_input_data, decoder_input_data, decoder_target_data = test_data

    model = keras.models.load_model(f'uni_lstm_{LATENT_DIM}')

    # Extract encoder from graph
    encoder_inputs = model.input[0]
    _, state_h_enc, state_c_enc = model.layers[3].output  # lstm_1
    encoder_states = [state_h_enc, state_c_enc]
    encoder_model = keras.Model(encoder_inputs, encoder_states)

    # Extract decoder from graph
    decoder_inputs = model.input[1]
    decoder_state_input_h = keras.Input(shape=(LATENT_DIM, ), name="input_3")
    decoder_state_input_c = keras.Input(shape=(LATENT_DIM, ), name="input_4")
    decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]

    decoder_lstm = model.layers[-2]
    decoder_dense = model.layers[-1]

    decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(
        decoder_inputs, initial_state=decoder_states_inputs)
    decoder_states = [state_h_dec, state_c_dec]
    decoder_outputs = decoder_dense(decoder_outputs)
    decoder_model = keras.Model([decoder_inputs] + decoder_states_inputs,
                                [decoder_outputs] + decoder_states)

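    # The terminal chord is encoded as (roughly) all -1s, so a token whose
    # mean is within 0.5 of -1 is treated as the end of the analysis.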
    def _terminate(toks):
        return np.isclose(np.mean(-1 - toks), 0.0, atol=0.5)

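    # Greedy step-by-step decoding: run the encoder once to get its final LSTM
    # states, then repeatedly run the one-step decoder, feeding each predicted
    # chord encoding back in as the next input, for at most 100 steps or until
    # the terminal symbol appears.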
    def decode(input_seq):
        states_value = encoder_model.predict(np.array([input_seq]))
        target_seq = np.ones((1, 1, constants['Y_DIM'])) * -1.

        result = []
        stop = False
        for _ in range(100):
            output_tokens, h, c = decoder_model.predict([target_seq] +
                                                        states_value)
            if _terminate(output_tokens):
                return result
            result.append(output_tokens)

            target_seq = np.ones((1, 1, constants['Y_DIM'])) * output_tokens
            states_value = [h, c]
        print("Decoding did not terminate! Returning large RNA.")
        return result

    def cut_off_ground_truth(ground_truth):
        res = []
        for g in ground_truth:
            if _terminate(g):
                return res
            res.append(g)
        print("Ground truth does not terminate! Returning large RNA.")

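    # Evaluate the first 15 test chorales: decode each one and score the
    # prediction against the ground truth with a Levenshtein error rate under
    # the 'key_enharmonic' equality function.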
    err_rates = []
    len_diffs = []
    for chorale_ind in range(len(encoder_input_data))[:15]:
        print("Eval for chorale " + str(chorale_ind))
        decoded = decode(encoder_input_data[chorale_ind])
        decoded_rna_chords = [
            feature_extractors.RNAChord(encoding=decoded[i][0][0])
            for i in range(len(decoded))
        ]

        ground_truth = cut_off_ground_truth(decoder_target_data[chorale_ind])
        ground_truth_chords = [
            feature_extractors.RNAChord(encoding=ground_truth[i])
            for i in range(len(ground_truth))
        ]

        errs = scoring.levenshtein(
            ground_truth_chords,
            decoded_rna_chords,
            equality_fn=scoring.EQUALITY_FNS['key_enharmonic'])
        print(len(ground_truth_chords) - len(decoded_rna_chords))
        len_diffs.append(
            abs(len(ground_truth_chords) - len(decoded_rna_chords)))
        err_rates.append(float(errs / len(ground_truth_chords)))
    print("Error rate: " + str(np.mean(err_rates)))
    print("Len diff: " + str(np.mean(len_diffs)))
Example #7
def predict(epochs, teacher_forcing_prob):
    train_data, test_data, constants = load_dataset()
    encoder_input_data, decoder_input_data, decoder_target_data = test_data

    # Add slices here to test only a subset of the data
    # encoder_input_data = encoder_input_data[:1]
    # decoder_input_data = decoder_input_data[:1]
    # decoder_target_data = decoder_target_data[:1]

    # model = keras.models.load_model(MODEL_PATH + '/thrush_attn_32_bs256_istermmse_lr001B09_luongattn_big_TF0_3/510',
    #                                 custom_objects={
    #                                     'DecoderLayer': DecoderLayer,
    #                                     'AttentionCell': AttentionCell,
    #                                 },
    #                                 compile=False)
    model = _build_model(constants, teacher_forcing_prob)
    #model.load_weights(MODEL_PATH + '/thrush_attn_32_bs256_istermmse_lr001B09_luongattn_big_TF075_4/660/variables/variables')
    #model.load_weights(MODEL_PATH + '/thrush_attn_32_bs256_istermmse_lr001B09_luongattn_big_TF05_4/720/variables/variables')
    model.load_weights('runmodel/variables/variables')

    #model.load_weights(MODEL_PATH + '/thrush_attn_64_bs256_attnin_lr001B09_luongattn_big_TF09_dropout/700/variables/variables')

    def _get_layers(layer_type):
        return [l for l in model.layers if layer_type in str(type(l))]

    # We have to do a weird thing here to make decoding work: we need to create
    # a new Keras model which, instead of consuming an entire chorale/RNA
    # sequence, consumes only a chorale sequence and produces the RNA sequence
    # one step at a time.  To do this, we extract the relevant layers of the
    # training model one by one and then piece them back together.  I got the
    # idea from
    # https://blog.keras.io/a-ten-minute-introduction-to-sequence-to-sequence-learning-in-keras.html
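    # The rewiring happens in three steps below: (1) extract the encoder and
    # expose its final states and outputs, (2) rebuild the decoder cell as a
    # one-step model whose previous dense output, previous states, previous
    # attention energies and the encoder outputs are all explicit inputs, and
    # (3) run that one-step model in a loop inside decode().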

    # Extract encoder from graph
    encoder_inputs = model.input[0]
    encoder_outputs, state_h_enc, state_c_enc = _get_layers(
        'LSTM')[-1].output  # lstm_1
    encoder_states = [state_h_enc, state_c_enc]
    encoder_model = keras.Model(encoder_inputs,
                                [encoder_states, encoder_outputs])

    # Extract decoder from graph
    decoder_inputs = keras.Input(shape=(1, constants['Y_DIM']),
                                 name='rna_input_inference')
    dense_out_inputs = keras.Input(shape=(constants['Y_DIM'],),
                                   name='dense_out_inputs_inference')
    # Inputs for the last decoder state
    decoder_state_inputs = [
        keras.Input(shape=(LATENT_DIM, ), name='decoder_state_h_1_inference'),
        keras.Input(shape=(LATENT_DIM, ), name='decoder_state_h_2_inference'),
        keras.Input(shape=(LATENT_DIM, ), name='decoder_state_h_3_inference'),
        keras.Input(shape=(LATENT_DIM, ), name='decoder_state_c_1_inference'),
        keras.Input(shape=(LATENT_DIM, ), name='decoder_state_c_2_inference'),
        keras.Input(shape=(LATENT_DIM, ), name='decoder_state_c_3_inference'),
    ]
    decoder_state_input_attn_energies = keras.Input(
        shape=(constants['MAX_CHORALE_LENGTH'],),
        name='decoder_state_input_attn_energies')
    encoder_output_input = keras.Input(shape=(constants['MAX_CHORALE_LENGTH'],
                                              LATENT_DIM),
                                       name="encoder_output_input")

    decoder_recurrent = _get_layers('recurrent.RNN')[0]

    decoder_states_inputs = [
        dense_out_inputs, decoder_state_inputs,
        decoder_state_input_attn_energies, encoder_output_input
    ]
    dense_outputs, _, _, decoder_state, attention_energies, _ = decoder_recurrent(
        decoder_inputs, initial_state=decoder_states_inputs)

    output_components = _convert_dense_to_output_components(dense_outputs)

    ins = [decoder_inputs]
    ins.extend(decoder_states_inputs)
    outs = [output_components, dense_outputs]
    outs.extend([decoder_state, attention_energies])
    decoder_model = keras.Model(ins, outs)

    # Returns true when the "is_terminal" output is set to -1, meaning the RNA
    # is finished.
    def _terminate(toks):
        return (toks[-1] <= -0.5)

    # Cut off the analysis at the first chord whose is_terminal output goes
    # negative, falling back to the minimum is_terminal value if none does.
    def _cut_off(result, attn_energies, output_tokens):
        is_terms = [r[0][0][-1] for r in output_tokens]
        terminal_chord_ind = np.argmin(is_terms)
        for i, term in enumerate(is_terms):
            if term < 0:
                terminal_chord_ind = i
                break
        return result[:terminal_chord_ind +
                      1], attn_energies[:terminal_chord_ind + 1]

    # Decode a single chorale.
    def decode(input_seq, decoder_input=None):
        (prev_decoder_state_h,
         prev_decoder_state_c), encoder_output_values = encoder_model.predict(
             np.array([input_seq]))
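        # The decoder cell expects six state inputs (h and c for what appear
        # to be three stacked LSTM layers); the single encoder h/c pair is
        # used to initialize all of them.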
        prev_decoder_states = [
            prev_decoder_state_h, prev_decoder_state_h, prev_decoder_state_h,
            prev_decoder_state_c, prev_decoder_state_c, prev_decoder_state_c
        ]

        prev_attention_energies = tf.zeros_like(
            tf.reduce_sum(encoder_output_values, axis=-1))

        target_seq = np.ones((1, 1, constants['Y_DIM'])) * -5.
        prev_dense = np.ones((1, constants['Y_DIM'])) * -5.

        result = []
        output_tokens_list = []
        attn_energies = []
        for k in range(constants['MAX_ANALYSIS_LENGTH'])[:-1]:
            ins = [target_seq]
            ins.extend([
                prev_dense, prev_decoder_states, prev_attention_energies,
                encoder_output_values
            ])
            output_components, dense_outs, prev_decoder_states, prev_attention_energies = decoder_model.predict(
                ins)
            attn_energies.append(prev_attention_energies)
            output_tokens_after_softmax = np.concatenate(
                [output_components[key] for key in COMPONENTS_IN_ORDER],
                axis=-1)

            # translate to RNAChord during decoding
            rna_chord = RNAChord(encoding=output_tokens_after_softmax[0][0])
            output_tokens_from_chord = [[rna_chord.encode()]]
            output_tokens_list.append(output_tokens_after_softmax)
            result.append(output_tokens_from_chord)

            target_seq = np.ones((1, 1, constants['Y_DIM'])) * dense_outs
            prev_dense = np.ones((1, constants['Y_DIM'])) * dense_outs
            prev_dense = np.squeeze(prev_dense, axis=1)

        result, attn_energies = _cut_off(result, attn_energies,
                                         output_tokens_list)
        return result, attn_energies

    def cut_off_ground_truth(ground_truth):
        res = []
        for g in ground_truth:
            if _terminate(g):
                return res
            res.append(g)
        print("Ground truth does not terminate! Returning large RNA.")

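    # Evaluate every test chorale in random order, scoring each decoded
    # analysis against the ground truth with a Levenshtein error rate under
    # each of the configured equality functions.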
    err_rates = collections.defaultdict(list)
    len_diffs = []
    attn_energy_matrixes = []
    chorale_inds = list(range(len(encoder_input_data)))
    random.shuffle(chorale_inds)
    for chorale_ind in chorale_inds:
        print("Eval for chorale " + str(chorale_ind))

        decoded, attn_energies = decode(encoder_input_data[chorale_ind],
                                        decoder_input_data[chorale_ind])
        attn_energy_matrixes.append(attn_energies)
        decoded_rna_chords = [
            RNAChord(encoding=decoded[i][0][0]) for i in range(len(decoded))
        ]

        ground_truth = cut_off_ground_truth(decoder_target_data[chorale_ind])
        ground_truth_chords = [
            RNAChord(encoding=ground_truth[i])
            for i in range(len(ground_truth))
        ]

        for fn_name in EQUALITY_FNS.keys():
            errs = levenshtein(ground_truth_chords,
                               decoded_rna_chords,
                               equality_fn=EQUALITY_FNS[fn_name],
                               substitution_cost=1,
                               left_deletion_cost=1,
                               right_deletion_cost=1)
            err_rates[fn_name].append(float(errs / len(decoded_rna_chords)))

        print("Ground Truth: %s, Decoded: %s" %
              (len(ground_truth_chords), len(decoded_rna_chords)))

        # Uncomment these lines to see the ground truth RNA sequence together
        # with the decoded prediction.
        # print("--------------------- GROUND TRUTH  ------------------")
        # for c in ground_truth_chords:
        #     print(c)
        # print("---------------------  PREDICTION  -------------------")
        # for c in decoded_rna_chords:
        #     print(c)
        # print("-------------------- ANALYSIS COMPLETE ---------------")

    for fn_name in EQUALITY_FNS.keys():
        print("Error Name: " + fn_name + " Error Rate: " +
              str(np.mean(err_rates[fn_name])))
    return attn_energy_matrixes
Example #8
def predict(epochs, teacher_forcing_prob):
    train_data, test_data, constants = feature_extractors.load_dataset()
    encoder_input_data, decoder_input_data, decoder_target_data = test_data

    model = _build_model(constants, teacher_forcing_prob=teacher_forcing_prob)
    model.load_weights('runmodel2/variables/variables')

    # model.load_weights(
    #     f'thrush_attn_32_bs256_istermmse_lr001B09_luongattn_big_{LATENT_DIM}_{epochs}/variables/variables')

    def _get_layers(layer_type):
        return [l for l in model.layers if layer_type in str(type(l))]

    # Extract encoder from graph
    encoder_inputs = model.input[0]
    enc_outs, state_h_enc_forward, state_c_enc_forward, state_h_enc_backward, state_c_enc_backward = model.layers[
        5].output
    # return
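    # The encoder is bidirectional, so its forward and backward h/c states are
    # concatenated to match the decoder's 2 * LATENT_DIM state size.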
    state_h_enc = keras.layers.Concatenate()(
        [state_h_enc_forward, state_h_enc_backward])
    state_c_enc = keras.layers.Concatenate()(
        [state_c_enc_forward, state_c_enc_backward])
    encoder_states = [state_h_enc, state_c_enc]
    encoder_model = keras.Model(encoder_inputs, [encoder_states, enc_outs])

    # Extract decoder from graph
    decoder_inputs = keras.Input(shape=(1, constants['Y_DIM']),
                                 name='rna_input_inference')
    dense_out_inputs = keras.Input(shape=(constants['Y_DIM'],),
                                   name='dense_out_inputs_inference')

    decoder_state_inputs = [
        keras.Input(shape=(LATENT_DIM * 2, ),
                    name='decoder_state_h_1_inference'),
        keras.Input(shape=(LATENT_DIM * 2, ),
                    name='decoder_state_h_2_inference'),
        keras.Input(shape=(LATENT_DIM * 2, ),
                    name='decoder_state_h_3_inference'),
        keras.Input(shape=(LATENT_DIM * 2, ),
                    name='decoder_state_c_1_inference'),
        keras.Input(shape=(LATENT_DIM * 2, ),
                    name='decoder_state_c_2_inference'),
        keras.Input(shape=(LATENT_DIM * 2, ),
                    name='decoder_state_c_3_inference'),
    ]
    decoder_state_input_attn_energies = keras.Input(
        shape=(constants['MAX_CHORALE_LENGTH'],),
        name='decoder_state_input_attn_energies')
    encoder_output_input = keras.Input(shape=(constants['MAX_CHORALE_LENGTH'],
                                              2 * LATENT_DIM),
                                       name="encoder_output_input")

    decoder_recurrent = _get_layers('recurrent.RNN')[0]

    decoder_states_inputs = [
        dense_out_inputs, decoder_state_inputs,
        decoder_state_input_attn_energies, encoder_output_input
    ]

    dense_outputs, _, _, decoder_state, attention_energies, _ = decoder_recurrent(
        decoder_inputs, initial_state=decoder_states_inputs)

    output_components = _convert_dense_to_output_components(dense_outputs)

    ins = [decoder_inputs]
    ins.extend(decoder_states_inputs)
    outs = [output_components, dense_outputs]
    outs.extend([decoder_state, attention_energies])
    decoder_model = keras.Model(ins, outs)

    # Returns true when the "is_terminal" output is set to -1, meaning the RNA
    # is finished.
    def _terminate(toks):
        return (toks[-1] <= -0.5)

    # Cut off the analysis at the first chord whose is_terminal output goes
    # negative, falling back to the minimum is_terminal value if none does.
    def _cut_off(result, attn_energies, output_tokens):
        is_terms = [r[0][0][-1] for r in output_tokens]
        terminal_chord_ind = np.argmin(is_terms)
        for i, term in enumerate(is_terms):
            if term < 0:
                terminal_chord_ind = i
                break
        return result[:terminal_chord_ind +
                      1], attn_energies[:terminal_chord_ind + 1]

    # Decode a single chorale.
    def decode(input_seq, decoder_input=None):
        (prev_decoded_state_h,
         prev_decoded_state_c), encoder_output_values = encoder_model.predict(
             np.array([input_seq]))
        prev_decoder_states = [
            prev_decoded_state_h, prev_decoded_state_h, prev_decoded_state_h,
            prev_decoded_state_c, prev_decoded_state_c, prev_decoded_state_c
        ]
        prev_attention_energies = tf.zeros_like(
            tf.reduce_sum(encoder_output_values, axis=-1))

        target_seq = np.ones((1, 1, constants['Y_DIM'])) * -5.
        prev_dense = np.ones((1, constants['Y_DIM'])) * -5.

        result = []
        output_tokens_list = []
        attn_energies = []
        for k in range(constants['MAX_ANALYSIS_LENGTH'])[:-1]:
            ins = [target_seq]
            ins.extend([
                prev_dense, prev_decoder_states, prev_attention_energies,
                encoder_output_values
            ])
            output_components, dense_outs, prev_decoder_states, prev_attention_energies = decoder_model.predict(
                ins)
            attn_energies.append(prev_attention_energies)
            output_tokens_after_softmax = np.concatenate(
                [output_components[key] for key in COMPONENTS_IN_ORDER],
                axis=-1)

            # translate to RNAChord during decoding
            rna_chord = RNAChord(encoding=output_tokens_after_softmax[0][0])

            output_tokens_from_chord = [[rna_chord.encode()]]
            output_tokens_list.append(output_tokens_after_softmax)
            result.append(output_tokens_from_chord)

            target_seq = np.ones((1, 1, constants['Y_DIM'])) * dense_outs
            prev_dense = np.ones((1, constants['Y_DIM'])) * dense_outs
            prev_dense = np.squeeze(prev_dense, axis=1)

        result, attn_energies = _cut_off(result, attn_energies,
                                         output_tokens_list)
        return result, attn_energies

    def cut_off_ground_truth(ground_truth):
        res = []
        for g in ground_truth:
            if _terminate(g):
                return res
            res.append(g)
        print("Ground truth does not terminate! Returning large RNA.")

    err_rates = collections.defaultdict(list)
    len_diffs = []
    attn_energy_matrixes = []
    chorale_inds = list(range(len(encoder_input_data)))
    random.shuffle(chorale_inds)
    for chorale_ind in chorale_inds:
        print("Eval for chorale " + str(chorale_ind))

        decoded, attn_energies = decode(encoder_input_data[chorale_ind],
                                        decoder_input_data[chorale_ind])
        attn_energy_matrixes.append(attn_energies)

        decoded_rna_chords = [
            feature_extractors.RNAChord(encoding=decoded[i][0][0])
            for i in range(len(decoded))
        ]

        ground_truth = cut_off_ground_truth(decoder_target_data[chorale_ind])
        ground_truth_chords = [
            feature_extractors.RNAChord(encoding=ground_truth[i])
            for i in range(len(ground_truth))
        ]

        for fn_name in EQUALITY_FNS.keys():
            errs = levenshtein(ground_truth_chords,
                               decoded_rna_chords,
                               equality_fn=EQUALITY_FNS[fn_name],
                               substitution_cost=1,
                               left_deletion_cost=1,
                               right_deletion_cost=1)
            err_rates[fn_name].append(float(errs / len(decoded_rna_chords)))
        print("Ground Truth: %s, Decoded: %s" %
              (len(ground_truth_chords), len(decoded_rna_chords)))

    for fn_name in EQUALITY_FNS.keys():
        print("Error Name: " + fn_name + " Error Rate: " +
              str(np.mean(err_rates[fn_name])))
    return attn_energy_matrixes
Example #9
def predict(epochs):
    train_data, test_data, constants = feature_extractors.load_dataset()
    encoder_input_data, decoder_input_data, decoder_target_data = test_data

    model = keras.models.load_model(
        f'thrush_attn_bidir_{LATENT_DIM}_{epochs}',
        custom_objects={'AttentionLayer': AttentionLayer},
        compile=False)

    def _get_layers(layer_type):
        return [l for l in model.layers if layer_type in str(type(l))]

    # Extract encoder from graph
    encoder_inputs = model.input[0]

    enc_outs, state_h_enc_forward, state_c_enc_forward, state_h_enc_backward, state_c_enc_backward = model.layers[
        2].output
    # return
    state_h_enc = keras.layers.Concatenate()(
        [state_h_enc_forward, state_h_enc_backward])
    state_c_enc = keras.layers.Concatenate()(
        [state_c_enc_forward, state_c_enc_backward])
    encoder_states = [state_h_enc, state_c_enc]
    encoder_model = keras.Model(encoder_inputs, [encoder_states, enc_outs])

    # Extract decoder from graph
    decoder_inputs = keras.Input(shape=(1, constants['Y_DIM']),
                                 name='rna_input_inference')
    decoder_state_input_h = keras.Input(shape=(LATENT_DIM * 2, ),
                                        name="decoder_state_h_inference")
    decoder_state_input_c = keras.Input(shape=(LATENT_DIM * 2, ),
                                        name="decoder_state_c_inference")
    decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
    encoder_output_input = keras.Input(shape=(constants['MAX_CHORALE_LENGTH'],
                                              LATENT_DIM * 2),
                                       name="encoder_output_input")

    decoder_lstm = model.layers[6]
    decoder_attn = model.layers[7]
    decoder_concat = model.layers[8]
    decoder_dense = model.layers[9]
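    # Note: these layer indices are tied to the exact layer ordering of the
    # saved training graph and would need updating if the architecture changes.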

    decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(
        decoder_inputs, initial_state=decoder_states_inputs)
    decoder_states = [state_h_dec, state_c_dec]
    attention_outputs, attention_energies = decoder_attn(
        [encoder_output_input, decoder_outputs])
    dense_outputs = decoder_dense(
        decoder_concat([decoder_outputs, attention_outputs]))
    output_components = _convert_dense_to_output_components(dense_outputs)

    ins = [decoder_inputs]
    ins.extend(decoder_states_inputs)
    ins.append(encoder_output_input)
    outs = [output_components]
    outs.extend(decoder_states)
    outs.append(attention_energies)
    decoder_model = keras.Model(ins, outs)

    # Returns true when the "is_terminal" output is set to -1, meaning the RNA
    # is finished.
    def _terminate(toks):
        return (toks[-1] < 0)

    # Decode a single chorale.
    def decode(input_seq):
        states_value, encoder_output_values = encoder_model.predict(
            np.array([input_seq]))
        target_seq = np.ones((1, 1, constants['Y_DIM'])) * -5.

        result = []
        attn_energies = []
        for k in range(constants['MAX_ANALYSIS_LENGTH']):
            ins = [target_seq]
            ins.extend(states_value)
            ins.append(encoder_output_values)
            output_components, h, c, attn_energy = decoder_model.predict(ins)
            attn_energies.append(attn_energy)
            output_tokens = np.concatenate(
                [output_components[key] for key in COMPONENTS_IN_ORDER],
                axis=-1)
            result.append(output_tokens)
            if _terminate(output_tokens[0][0]):
                return result, attn_energies

            target_seq = np.ones((1, 1, constants['Y_DIM'])) * output_tokens
            states_value = [h, c]
        print("Decoding did not terminate! Returning large RNA.")
        return result, attn_energies

    def cut_off_ground_truth(ground_truth):
        res = []
        for g in ground_truth:
            if _terminate(g):
                return res
            res.append(g)
        print("Ground truth does not terminate! Returning large RNA.")

    err_rates = []
    len_diffs = []
    attn_energy_matrixes = []
    chorale_inds = list(range(len(encoder_input_data)))
    random.shuffle(chorale_inds)
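    # Score a random sample of 20 test chorales.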
    for chorale_ind in chorale_inds[:20]:
        print("Eval for chorale " + str(chorale_ind))
        # c = encoder_input_data[chorale_ind]
        # c = [i for i in c if i[-1] != -1]
        # print(len(c))

        decoded, attn_energies = decode(encoder_input_data[chorale_ind])
        attn_energy_matrixes.append(attn_energies)

        decoded_rna_chords = [
            feature_extractors.RNAChord(encoding=decoded[i][0][0])
            for i in range(len(decoded))
        ]

        ground_truth = cut_off_ground_truth(decoder_target_data[chorale_ind])
        ground_truth_chords = [
            feature_extractors.RNAChord(encoding=ground_truth[i])
            for i in range(len(ground_truth))
        ]

        errs = scoring.levenshtein(
            ground_truth_chords,
            decoded_rna_chords,
            equality_fn=scoring.EQUALITY_FNS['key_enharmonic_and_parallel'],
            left_deletion_cost=0)
        print(len(ground_truth_chords) - len(decoded_rna_chords))
        len_diffs.append((len(ground_truth_chords) - len(decoded_rna_chords)))
        err_rates.append(float(errs / len(ground_truth_chords)))
        # Uncomment these lines to see the ground truth RNA sequence together
        # with the decoded prediction.
        # print("--------------------- GROUND TRUTH  ------------------")
        # for c in ground_truth_chords:
        #     print(c)
        # print("---------------------  PREDICTION  -------------------")
        # for c in decoded_rna_chords:
        #     print(c)

    print("Error rate: " + str(np.mean(err_rates)))
    print("Len diff: " + str(np.mean(len_diffs)))
    return attn_energy_matrixes