def create_model_and_train(train_segments, test_segments, x_depth, batch_size,
                           enc_rnn_dim, dec_rnn_dim, enc_dropout, dec_dropout,
                           cont_dim, cat_dim, mu_force, t_gumbel,
                           style_embed_dim, kl_reg, beta_anneal_steps,
                           rnn_type, attention, save_path, start_epoch,
                           final_epoch, weights=None):
    """Build an MVAE, train it, and save the final weights.

    The train/test data are wrapped with ``load_noteseqs(...).get_iterator()``,
    the model is compiled with Adam (learning rate 5e-4) and run eagerly, and
    training is driven by ``vae.fit`` with sample generation, early stopping,
    CSV logging, checkpointing and TensorBoard callbacks.

    Args:
        train_segments: Training data passed to ``load_noteseqs``.
        test_segments: Validation data passed to ``load_noteseqs``.
        x_depth: Passed to both ``load_noteseqs`` and the ``MVAE`` constructor.
        batch_size: Batch size passed to ``load_noteseqs``.
        enc_rnn_dim, dec_rnn_dim, enc_dropout, dec_dropout, cont_dim, cat_dim,
        mu_force, t_gumbel, style_embed_dim, kl_reg, beta_anneal_steps,
        rnn_type, attention: Forwarded unchanged to the ``MVAE`` constructor
            (``cat_dim`` is also given to ``generate_and_save_samples``).
        save_path: Output directory (created if missing) for ``log.csv``,
            the ``weights/`` checkpoints and the TensorBoard logs.
        start_epoch: Passed to ``fit`` as ``initial_epoch``.
        final_epoch: Passed to ``fit`` as ``epochs``.
        weights: Optional path, relative to ``save_path``, of weights to load
            before training starts.

    Returns:
        The history object returned by ``vae.fit``.
    """
    train_iterator = load_noteseqs(train_segments, x_depth, batch_size).get_iterator()
    test_iterator = load_noteseqs(test_segments, x_depth, batch_size).get_iterator()

    vae = MVAE(x_depth=x_depth,
               enc_rnn_dim=enc_rnn_dim, enc_dropout=enc_dropout,
               dec_rnn_dim=dec_rnn_dim, dec_dropout=dec_dropout,
               cont_dim=cont_dim, cat_dim=cat_dim, mu_force=mu_force,
               t_gumbel=t_gumbel, style_embed_dim=style_embed_dim,
               kl_reg=kl_reg,
               beta_anneal_steps=beta_anneal_steps,
               rnn_type=rnn_type, attention=attention)

    optimizer = tfk.optimizers.Adam(learning_rate=5e-4)
    vae.compile(optimizer=optimizer)
    vae.run_eagerly = True

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_path, exist_ok=True)

    # Build paths with os.path.join so save_path works with or without a
    # trailing separator (plain concatenation produced e.g. "<dir>log.csv").
    weights_dir = os.path.join(save_path, 'weights')

    callbacks = [
        tfk.callbacks.LambdaCallback(
            on_epoch_end=lambda epoch, _: generate_and_save_samples(
                vae, epoch, save_path, cat_dim)),
        # The Keras hook is named on_epoch_begin; an unknown keyword such as
        # on_epoch_start is stored as an attribute and never invoked.
        tfk.callbacks.LambdaCallback(
            on_epoch_begin=lambda epoch, _: vae.reset_trackers()),
        tfk.callbacks.EarlyStopping(monitor='val_p_acc', min_delta=0.01,
                                    patience=5, mode='max'),
        tfk.callbacks.CSVLogger(os.path.join(save_path, 'log.csv'), append=True),
        tfk.callbacks.ModelCheckpoint(
            os.path.join(weights_dir, 'weights.{epoch:02d}'),
            monitor='val_p_acc', save_weights_only=True,
            save_best_only=True, mode='max'),
        tfk.callbacks.TensorBoard(log_dir=save_path, write_graph=True,
                                  update_freq='epoch', histogram_freq=40,
                                  profile_batch='10,20'),
    ]

    if weights is not None:
        # Strip leading slashes so the join can never discard save_path; this
        # keeps "<save_path>/<weights>" semantics for either calling style.
        vae.load_weights(os.path.join(save_path, weights.lstrip('/')))

    history = vae.fit(train_iterator, epochs=final_epoch,
                      initial_epoch=start_epoch, callbacks=callbacks,
                      validation_data=test_iterator)
    vae.save_weights(os.path.join(weights_dir, 'weights-final'))

    return history
# Script-level training run: set up the callbacks, fit the model and save the
# final weights. Relies on `vae`, `args`, `save_path`, `train_iterator` and
# `test_iterator` being defined earlier in the script.

# To resume from previously saved weights, uncomment:
# print('loading weights')
# vae.load_weights('gru_jsb_nmd/weights/weights-final')

callbacks = [
    tfk.callbacks.LambdaCallback(
        on_epoch_end=lambda epoch, _: generate_and_save_samples(
            vae, epoch, save_path, int(args["cat_dim"]))),
    # The Keras hook is named on_epoch_begin; an unknown keyword such as
    # on_epoch_start is stored as an attribute and never invoked.
    tfk.callbacks.LambdaCallback(
        on_epoch_begin=lambda epoch, _: vae.reset_trackers()),
    # os.path.join keeps the paths valid whether or not save_path ends with a
    # separator and avoids the doubled slash in "weights//weights.NN".
    tfk.callbacks.CSVLogger(os.path.join(save_path, 'log.csv'), append=True),
    tfk.callbacks.ModelCheckpoint(
        os.path.join(save_path, 'weights', 'weights.{epoch:02d}'),
        monitor='val_p_acc', save_weights_only=True,
        save_best_only=True, mode='max'),
    tfk.callbacks.TensorBoard(log_dir=save_path, write_graph=True,
                              update_freq='epoch', histogram_freq=40,
                              profile_batch='10,20'),
]

history = vae.fit(train_iterator, epochs=int(args["epochs"]),
                  callbacks=callbacks, validation_data=test_iterator)
# Variant for continuing an interrupted run from a given epoch:
# history = vae.fit(train_iterator, epochs=1001, initial_epoch=800, callbacks=callbacks, validation_data=test_iterator)

vae.save_weights(os.path.join(save_path, 'weights', 'weights-final'))