def from_network_parameters(cls, context_size, melody_bar_len, meta_embed_size,
                            bar_embed_params, rhythm_net_params,
                            melody_net_params, meta_predictor,
                            generation=False, compile_now=True):
    """Build a combined network from raw sub-network parameter lists.

    A ``BarEmbedding`` is created from ``bar_embed_params``. The rhythm and
    melody networks are then created through their ``init_with_Encoder``
    factories, the same way ``from_saved_custom`` rebuilds them. Everything
    is handed to ``cls``.

    Args:
        context_size: Forwarded unchanged to ``cls``.
        melody_bar_len: Forwarded unchanged to ``cls``.
        meta_embed_size: Forwarded unchanged to ``cls``.
        bar_embed_params: Positional arguments for ``BarEmbedding``.
        rhythm_net_params: Positional arguments for
            ``RhythmNetwork.init_with_Encoder``, following the bar embedder.
        melody_net_params: Positional arguments for
            ``MelodyNetwork.init_with_Encoder``.
        meta_predictor: Forwarded unchanged to ``cls``.
        generation: Forwarded to ``cls`` as the ``generation`` keyword.
        compile_now: Forwarded to ``cls`` as the ``compile_now`` keyword.

    Returns:
        The newly constructed ``cls`` instance.
    """
    # The sub-networks are always built with compile_now=False. Only the
    # caller's compile_now is forwarded, and only to the combined network.
    bar_embedder = BarEmbedding(*bar_embed_params, compile_now=False)

    # The plain RhythmNetwork constructor takes a ready-made ``rhythm_encoder``
    # as its first argument, not a BarEmbedding. Use the encoder-building
    # factories, as from_saved_custom does, so both construction paths agree.
    rhythm_net = RhythmNetwork.init_with_Encoder(
        bar_embedder, *rhythm_net_params, compile_now=False)
    melody_net = MelodyNetwork.init_with_Encoder(
        *melody_net_params, compile_now=False)

    return cls(context_size, melody_bar_len, meta_embed_size,
               bar_embedder, rhythm_net, melody_net, meta_predictor,
               generation=generation, compile_now=compile_now)
def from_saved_custom(cls, save_dir, meta_predictor, generation=False,
                      compile_now=True):
    """Rebuild a combined network from parameters and weights in ``save_dir``.

    Reads ``parameters.json`` from ``save_dir`` and uses it to rebuild the bar
    embedder and the rhythm and melody networks. It then loads their weights
    from ``bar_embedding_weights``, ``rhythm_net_weights`` and
    ``melody_net_weights`` in the same directory. Finally the parts are handed
    to ``cls``.

    Args:
        save_dir: Directory (``str`` or path-like) holding the saved files.
        meta_predictor: Forwarded unchanged to ``cls``.
        generation: Forwarded to ``cls`` as the ``generation`` keyword.
        compile_now: Forwarded to ``cls`` as the ``compile_now`` keyword.

    Returns:
        The newly constructed ``cls`` instance.

    Raises:
        OSError: If ``parameters.json`` cannot be opened.
        KeyError: If ``parameters.json`` lacks one of the required keys
            (``bar_embed_params``, ``rhythm_net_params``,
            ``melody_net_params``, ``context_size``, ``melody_bar_len``,
            ``meta_embed_size``).
    """
    import os

    # os.path.join accepts path-like objects and avoids producing the root
    # path "/parameters.json" when save_dir is an empty string.
    with open(os.path.join(save_dir, "parameters.json"), "r",
              encoding="utf-8") as handle:
        param_dict = json.load(handle)

    # Build the sub-networks uncompiled. Only the combined network receives
    # the caller's compile_now flag.
    bar_embedder = BarEmbedding(*param_dict["bar_embed_params"],
                                compile_now=False)
    rhythm_net = RhythmNetwork.init_with_Encoder(
        bar_embedder, *param_dict["rhythm_net_params"], compile_now=False)
    melody_net = MelodyNetwork.init_with_Encoder(
        *param_dict["melody_net_params"], compile_now=False)

    bar_embedder.load_weights(os.path.join(save_dir, "bar_embedding_weights"))
    rhythm_net.load_weights(os.path.join(save_dir, "rhythm_net_weights"))
    melody_net.load_weights(os.path.join(save_dir, "melody_net_weights"))

    return cls(param_dict["context_size"],
               param_dict["melody_bar_len"],
               param_dict["meta_embed_size"],
               bar_embedder, rhythm_net, melody_net, meta_predictor,
               generation=generation, compile_now=compile_now)
# Hyperparameters and construction of the individual sub-networks.
# NOTE(review): this is a fragment. rc_size, mc_size, V_rhythm,
# beat_embed_size, embed_lstm_size, out_size and cg are defined outside this
# view, and the code presumably continues past melody_encoder.
context_size = rc_size

# Rhythm network LSTM sizes (encoder / decoder).
rhythm_enc_lstm_size = 32
rhythm_dec_lstm_size = 28

# Melody params
m = 48
V_melody = cg.melody_V
conv_f = 4
conv_win_size = 3
melody_enc_lstm_size = 52
melody_dec_lstm_size = 32

# INDIVIDUAL NETS
bar_embedder = BarEmbedding(V=V_rhythm, beat_embed_size=beat_embed_size,
                            embed_lstm_size=embed_lstm_size, out_size=out_size)
# The rhythm encoder wraps the bar embedder; it is passed rc_size directly
# (the same value assigned to context_size above).
rhythm_encoder = RhythmEncoder(bar_embedder=bar_embedder, context_size=rc_size,
                               lstm_size=rhythm_enc_lstm_size)
rhythm_net = RhythmNetwork(rhythm_encoder=rhythm_encoder,
                           dec_lstm_size=rhythm_dec_lstm_size, V=V_rhythm,
                           dec_use_meta=True, compile_now=True)

# ATTENTION: conv_win_size must not be greater than context size!
# min() clamps the window to mc_size to enforce this
# (mc_size is presumably the melody context size -- confirm).
melody_encoder = MelodyEncoder(m=m, conv_f=conv_f,
                               conv_win_size=min(mc_size, conv_win_size),
                               enc_lstm_size=melody_enc_lstm_size)