예제 #1
0
    def from_saved_custom(cls, save_dir, generation=False, compile_now=True):
        """Rebuild a combined network from a directory written by a custom save.

        Reads the JSON file ``parameters`` in *save_dir* to re-create the three
        sub-networks (bar embedder, rhythm network, melody network). It then
        loads each sub-network's weights from its own file in *save_dir* and
        wraps them in a new ``cls`` instance.

        Args:
            save_dir: Directory containing ``parameters``,
                ``bar_embedding_weights``, ``rhythm_net_weights`` and
                ``melody_net_weights``. May be a ``str`` or any
                ``os.PathLike`` (e.g. ``pathlib.Path``).
            generation: Forwarded unchanged to ``cls``.
            compile_now: Forwarded unchanged to ``cls``. The sub-networks
                themselves are always built with ``compile_now=False``.

        Returns:
            The newly constructed ``cls`` instance.
        """
        import os

        def _path(name):
            # os.path.join accepts both str and PathLike directories and does
            # not turn an empty save_dir into the filesystem root, unlike
            # `save_dir + "/" + name`.
            return os.path.join(save_dir, name)

        # JSON text is UTF-8 by specification; don't depend on the locale.
        with open(_path("parameters"), "r", encoding="utf-8") as handle:
            param_dict = json.load(handle)

        # Sub-networks are built uncompiled; only the combined model receives
        # the caller's compile_now flag.
        bar_embedder = BarEmbedding(*param_dict["bar_embed_params"],
                                    compile_now=False)
        rhythm_net = RhythmNetwork(bar_embedder,
                                   *param_dict["rhythm_net_params"],
                                   compile_now=False)
        melody_net = MelodyNetwork(*param_dict["melody_net_params"],
                                   compile_now=False)

        # Restore trained weights before assembling the combined model.
        bar_embedder.load_weights(_path("bar_embedding_weights"))
        rhythm_net.load_weights(_path("rhythm_net_weights"))
        melody_net.load_weights(_path("melody_net_weights"))

        combined_net = cls(param_dict["context_size"],
                           param_dict["melody_bar_len"],
                           param_dict["meta_len"],
                           bar_embedder,
                           rhythm_net,
                           melody_net,
                           generation=generation,
                           compile_now=compile_now)

        return combined_net
예제 #2
0
    def from_network_parameters(cls,
                                context_size,
                                melody_bar_len,
                                meta_len,
                                bar_embed_params,
                                rhythm_net_params,
                                melody_net_params,
                                generation=False,
                                compile_now=True):
        """Construct a fresh combined network from raw constructor arguments.

        Each ``*_params`` sequence is unpacked positionally into the matching
        sub-network constructor. The rhythm network additionally receives the
        newly built bar embedder as its first argument.

        Args:
            context_size, melody_bar_len, meta_len: Forwarded positionally to
                ``cls``.
            bar_embed_params: Positional args for ``BarEmbedding``.
            rhythm_net_params: Positional args for ``RhythmNetwork``, following
                the bar embedder.
            melody_net_params: Positional args for ``MelodyNetwork``.
            generation: Forwarded unchanged to ``cls``.
            compile_now: Forwarded unchanged to ``cls``. Sub-networks are
                always created with ``compile_now=False``.

        Returns:
            The newly constructed ``cls`` instance.
        """
        embedder = BarEmbedding(*bar_embed_params, compile_now=False)
        # Tuple elements are evaluated left to right, so the construction
        # order (embedder -> rhythm -> melody) is preserved.
        sub_networks = (
            embedder,
            RhythmNetwork(embedder, *rhythm_net_params, compile_now=False),
            MelodyNetwork(*melody_net_params, compile_now=False),
        )
        return cls(context_size, melody_bar_len, meta_len, *sub_networks,
                   generation=generation, compile_now=compile_now)
예제 #3
0
    # Layer sizes passed to MelodyNetwork below. Other names used in this
    # fragment (V_rhythm, V_melody, beat_embed_size, embed_lstm_size, out_size,
    # context_size, rhythm_enc_lstm_size, rhythm_dec_lstm_size, conv_f, m) are
    # defined outside this fragment.
    conv_win_size = 3
    melody_enc_lstm_size = 52
    melody_dec_lstm_1_size = 32
    # NOTE(review): melody_dec_lstm_2_size is never passed to MelodyNetwork
    # below (only dec_lstm_1_size is) -- confirm whether a dec_lstm_2_size=
    # argument is missing or whether this variable is dead.
    melody_dec_lstm_2_size = 32

    # NOTE(review): meta_data_len is not used anywhere in this fragment.
    meta_data_len = 10

    # INDIVIDUAL NETS
    # Bar embedder; it is handed to the rhythm network as bar_embedder.
    be = BarEmbedding(V=V_rhythm,
                      beat_embed_size=beat_embed_size,
                      embed_lstm_size=embed_lstm_size,
                      out_size=out_size)

    # Rhythm network built on top of the bar embedder; metadata is enabled for
    # the decoder only.
    rhythm_net = RhythmNetwork(bar_embedder=be,
                               context_size=context_size,
                               enc_lstm_size=rhythm_enc_lstm_size,
                               dec_lstm_size=rhythm_dec_lstm_size,
                               enc_use_meta=False,
                               dec_use_meta=True)

    # Melody network; its rhythm_embed_size is tied to the bar embedder's
    # out_size so the two agree.
    # NOTE(review): conv_win_size is passed unclamped here -- if it must not
    # exceed the melody context length, a min(...) guard may be needed.
    melody_net = MelodyNetwork(m=m,
                               V=V_melody,
                               rhythm_embed_size=out_size,
                               conv_f=conv_f,
                               conv_win_size=conv_win_size,
                               enc_lstm_size=melody_enc_lstm_size,
                               dec_lstm_1_size=melody_dec_lstm_1_size,
                               enc_use_meta=False,
                               dec_use_meta=True)

    print("Individual networks set up...\n")
예제 #4
0
    # Melody encoder/decoder hyper-parameters. Other names used in this
    # fragment (V_rhythm, V_melody, beat_embed_size, embed_lstm_size, out_size,
    # rc_size, mc_size, rhythm_enc_lstm_size, rhythm_dec_lstm_size, m) are
    # defined outside this fragment.
    conv_f = 4
    conv_win_size = 3
    melody_enc_lstm_size = 52
    melody_dec_lstm_size = 32

    # INDIVIDUAL NETS
    # Bar embedder, wrapped by the rhythm encoder below.
    bar_embedder = BarEmbedding(V=V_rhythm,
                                beat_embed_size=beat_embed_size,
                                embed_lstm_size=embed_lstm_size,
                                out_size=out_size)
    # Rhythm encoder over a context of rc_size, feeding the rhythm network.
    rhythm_encoder = RhythmEncoder(bar_embedder=bar_embedder,
                                   context_size=rc_size,
                                   lstm_size=rhythm_enc_lstm_size)
    # Rhythm network (compiled immediately); metadata used by the decoder.
    rhythm_net = RhythmNetwork(rhythm_encoder=rhythm_encoder,
                               dec_lstm_size=rhythm_dec_lstm_size,
                               V=V_rhythm,
                               dec_use_meta=True,
                               compile_now=True)

    # ATTENTION: conv_win_size must not be greater than the melody context
    # size -- min(mc_size, conv_win_size) below enforces this.
    melody_encoder = MelodyEncoder(m=m,
                                   conv_f=conv_f,
                                   conv_win_size=min(mc_size, conv_win_size),
                                   enc_lstm_size=melody_enc_lstm_size)
    # Melody network (compiled immediately); rhythm_embed_size matches the
    # bar embedder's out_size.
    melody_net = MelodyNetwork(melody_encoder=melody_encoder,
                               rhythm_embed_size=out_size,
                               dec_lstm_size=melody_dec_lstm_size,
                               V=V_melody,
                               dec_use_meta=True,
                               compile_now=True)
예제 #5
0
#%%
# Load the pickled corpus.
# NOTE: pickle.load can execute arbitrary code -- only use this with a
# trusted, locally generated file.
with open("bach_beats.pkl", "rb") as handle:
    data = pickle.load(handle)

# Shuffle the scores by permuting an index array instead of passing `data`
# itself to permutation(). Passing the nested list would convert it to an
# ndarray: ragged scores fail to convert on recent NumPy, and regular ones
# silently become a multi-dimensional numeric array.
# (Assumes `rand` is numpy.random or a NumPy Generator -- confirm.)
data = [data[i] for i in rand.permutation(len(data))]

# Flatten scores into one list of parts, converting every beat to a tuple so
# it is hashable (presumably required by `label` / used as a dict key).
parts = [list(map(tuple, part)) for score in data for part in score]

# Label all beats starting at 1, reserving 0 for the padding symbol.
_, label_d = label([b for p in parts for b in p], start=1)
pad_symb = "<s>"
label_d[pad_symb] = 0

#%%
# Reuse pad_symb so the generator and the label dict cannot drift apart.
r = RhythmGenerator(4, label_d, pad_symb)

r_gen = r.generate_data(parts, shuffle=False)

#%%
# Shared hyper-parameters: the reloaded network below must match this
# architecture exactly for its saved weights to fit.
rnet_params = dict(num_categories=len(label_d),
                   embed_size=32,
                   lstm_size=64)

rnet = RhythmNetwork(**rnet_params)
#%%
rnet.fit_generator(r_gen, steps_per_epoch=len(parts), epochs=1)

#%%
rnet.save_weights("rnet_weights.h5")
#%%
rnet2 = RhythmNetwork.from_weights("rnet_weights.h5", **rnet_params)
#%%
# NOTE(review): r_gen is the same generator object already consumed by
# rnet.fit_generator above. If generate_data yields a finite stream, this
# second fit may run out of data -- confirm, or build a fresh generator.
rnet2.fit_generator(r_gen, steps_per_epoch=len(parts), epochs=1)