def from_saved_custom(cls, save_dir, generation=False, compile_now=True):
    """Rebuild a combined network from the files stored under ``save_dir``.

    Reads the JSON file ``save_dir + "/parameters"``, constructs the bar
    embedder, the rhythm network and the melody network from the stored
    parameter lists (each with ``compile_now=False``), loads the saved
    weights of all three and wraps them in a new ``cls`` instance.

    Args:
        cls: Class that is instantiated with the restored sub-networks.
        save_dir: Directory path; file names are appended to it with ``+``,
            so it has to support string concatenation.
        generation: Forwarded unchanged to ``cls``.
        compile_now: Forwarded unchanged to ``cls``.

    Returns:
        The ``cls`` instance built around the restored sub-networks.
    """
    def stored(suffix):
        # Plain concatenation: ``suffix`` already carries its leading "/".
        return save_dir + suffix

    with open(stored("/parameters"), "r") as params_file:
        saved = json.load(params_file)

    embedder = BarEmbedding(*saved["bar_embed_params"], compile_now=False)
    rhythm = RhythmNetwork(embedder, *saved["rhythm_net_params"],
                           compile_now=False)
    melody = MelodyNetwork(*saved["melody_net_params"], compile_now=False)

    # Weights are loaded only after all three networks exist, in the order
    # embedder -> rhythm -> melody.
    for network, weights_suffix in (
        (embedder, "/bar_embedding_weights"),
        (rhythm, "/rhythm_net_weights"),
        (melody, "/melody_net_weights"),
    ):
        network.load_weights(stored(weights_suffix))

    return cls(
        saved["context_size"],
        saved["melody_bar_len"],
        saved["meta_len"],
        embedder,
        rhythm,
        melody,
        generation=generation,
        compile_now=compile_now,
    )
def from_network_parameters(cls, context_size, melody_bar_len, meta_len,
                            bar_embed_params, rhythm_net_params,
                            melody_net_params, generation=False,
                            compile_now=True):
    """Build a combined network from raw sub-network parameter lists.

    The bar embedder, rhythm network and melody network are each created
    with ``compile_now=False``; the rhythm network receives the freshly
    created bar embedder as its first positional argument. The three are
    then handed to ``cls``.

    Args:
        cls: Class that is instantiated with the new sub-networks.
        context_size: Forwarded to ``cls`` as first positional argument.
        melody_bar_len: Forwarded to ``cls`` as second positional argument.
        meta_len: Forwarded to ``cls`` as third positional argument.
        bar_embed_params: Positional arguments for ``BarEmbedding``.
        rhythm_net_params: Positional arguments for ``RhythmNetwork``,
            following the bar embedder.
        melody_net_params: Positional arguments for ``MelodyNetwork``.
        generation: Forwarded unchanged to ``cls``.
        compile_now: Forwarded unchanged to ``cls``.

    Returns:
        The ``cls`` instance wrapping the three new sub-networks.
    """
    embedder = BarEmbedding(*bar_embed_params, compile_now=False)
    # Order matters: (bar embedder, rhythm network, melody network) is the
    # positional order expected by ``cls`` after the three size arguments.
    sub_networks = (
        embedder,
        RhythmNetwork(embedder, *rhythm_net_params, compile_now=False),
        MelodyNetwork(*melody_net_params, compile_now=False),
    )
    return cls(context_size, melody_bar_len, meta_len, *sub_networks,
               generation=generation, compile_now=compile_now)
# Hyper-parameters for the melody network.
conv_win_size = 3
melody_enc_lstm_size = 52
melody_dec_lstm_1_size = 32
# NOTE(review): melody_dec_lstm_2_size and meta_data_len are not passed to
# any constructor in this section -- confirm whether they are used further
# down or whether MelodyNetwork should receive them.
melody_dec_lstm_2_size = 32
meta_data_len = 10

# INDIVIDUAL NETS
# Bar embedder; it is handed to the rhythm network below.
be = BarEmbedding(
    V=V_rhythm,
    beat_embed_size=beat_embed_size,
    embed_lstm_size=embed_lstm_size,
    out_size=out_size,
)

# Rhythm network: meta data is switched on for the decoder only.
rhythm_net = RhythmNetwork(
    bar_embedder=be,
    context_size=context_size,
    enc_lstm_size=rhythm_enc_lstm_size,
    dec_lstm_size=rhythm_dec_lstm_size,
    enc_use_meta=False,
    dec_use_meta=True,
)

# Melody network: its rhythm_embed_size is the bar embedder's out_size.
melody_net = MelodyNetwork(
    m=m,
    V=V_melody,
    rhythm_embed_size=out_size,
    conv_f=conv_f,
    conv_win_size=conv_win_size,
    enc_lstm_size=melody_enc_lstm_size,
    dec_lstm_1_size=melody_dec_lstm_1_size,
    enc_use_meta=False,
    dec_use_meta=True,
)

print("Individual networks set up...\n")
# Hyper-parameters for the melody encoder / decoder.
conv_f = 4
conv_win_size = 3
melody_enc_lstm_size = 52
melody_dec_lstm_size = 32

# INDIVIDUAL NETS
# Rhythm side: bar embedder -> rhythm encoder -> rhythm network.
bar_embedder = BarEmbedding(
    V=V_rhythm,
    beat_embed_size=beat_embed_size,
    embed_lstm_size=embed_lstm_size,
    out_size=out_size,
)
rhythm_encoder = RhythmEncoder(
    bar_embedder=bar_embedder,
    context_size=rc_size,
    lstm_size=rhythm_enc_lstm_size,
)
rhythm_net = RhythmNetwork(
    rhythm_encoder=rhythm_encoder,
    dec_lstm_size=rhythm_dec_lstm_size,
    V=V_rhythm,
    dec_use_meta=True,
    compile_now=True,
)

# ATTENTION: conv_win_size must not be greater than context size!
# The window passed on is therefore capped at mc_size via min().
melody_encoder = MelodyEncoder(
    m=m,
    conv_f=conv_f,
    conv_win_size=min(mc_size, conv_win_size),
    enc_lstm_size=melody_enc_lstm_size,
)
melody_net = MelodyNetwork(
    melody_encoder=melody_encoder,
    rhythm_embed_size=out_size,
    dec_lstm_size=melody_dec_lstm_size,
    V=V_melody,
    dec_use_meta=True,
    compile_now=True,
)
#%%
# Load the pickled beat data.
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only run this against a trusted "bach_beats.pkl".
with open("bach_beats.pkl", "rb") as handle:
    data = pickle.load(handle)

# Permute the loaded scores (presumably numpy.random.permutation -- confirm
# what `rand` is bound to).
data = rand.permutation(data)

# Flatten score -> part, converting every element of a part into a tuple so
# it can serve as a dictionary key.
parts = [list(map(tuple, part)) for score in data for part in score]

# Build the label dictionary over all beats of all parts. Label 0 is assigned
# to the padding symbol below; start=1 presumably keeps it free -- verify
# against `label`.
_, label_d = label([b for p in parts for b in p], start=1)

pad_symb = "<s>"
label_d[pad_symb] = 0

#%%
# Reuse pad_symb rather than repeating the literal, so the generator and the
# label dictionary can never disagree about the padding symbol.
r = RhythmGenerator(4, label_d, pad_symb)
r_gen = r.generate_data(parts, shuffle=False)

#%%
rnet = RhythmNetwork(num_categories=len(label_d),
                     embed_size=32,
                     lstm_size=64)

#%%
# NOTE(review): if RhythmNetwork is a Keras model, fit_generator is
# deprecated in newer Keras/TensorFlow releases in favour of fit -- confirm
# the installed version before migrating.
rnet.fit_generator(r_gen, steps_per_epoch=len(parts), epochs=1)

#%%
rnet.save_weights("rnet_weights.h5")

#%%
# Round trip: rebuild a network from the weights just saved, using the same
# constructor arguments as rnet, and train it for another epoch.
rnet2 = RhythmNetwork.from_weights("rnet_weights.h5",
                                   num_categories=len(label_d),
                                   embed_size=32,
                                   lstm_size=64)

#%%
rnet2.fit_generator(r_gen, steps_per_epoch=len(parts), epochs=1)