enc_dropout=params['enc_dropout'], dec_rnn_dim=params['dec_rnn_dim'], dec_dropout=params['dec_dropout'], cont_dim=params['cont_dim'], cat_dim=args.cat_dim, mu_force=1.3, t_gumbel=params['t_gumbel'], style_embed_dim=params['style_dim'], kl_reg=params['kl_reg'], beta_anneal_steps=params['beta_anneal_steps'], rnn_type=params['rnn_type'], attention=params['attention']) if args.restore_path: print('restoring weights from ' + args.restore_path) vae.load_weights(args.restore_path) train_segments = [] test_segments = [] datasets = [ "datasets/dataset-8/JSB-8-raw.pickle", "datasets/dataset-8/NMD-8-raw.pickle" ] keep_pcts = [1, 0.23] master_pct = 1 for (dataset, pct) in zip(datasets, keep_pcts): segments = joblib.load(dataset) test_size = 0.1 * master_pct train_size = (1 - test_size) * master_pct Xtr, Xte = train_test_split(segments, train_size=train_size * pct,
def create_model_and_train(train_segments, test_segments, x_depth, batch_size,
                           enc_rnn_dim, dec_rnn_dim, enc_dropout, dec_dropout,
                           cont_dim, cat_dim, mu_force, t_gumbel,
                           style_embed_dim, kl_reg, beta_anneal_steps,
                           rnn_type, attention, save_path, start_epoch,
                           final_epoch, weights=None):
    """Build and compile an MVAE, optionally restore weights, then train it.

    Args:
        train_segments: Training data; wrapped with ``load_noteseqs`` to
            produce the training iterator.
        test_segments: Held-out data; wrapped the same way and passed to
            ``fit`` as ``validation_data``.
        x_depth: Forwarded to both ``load_noteseqs`` and ``MVAE``.
        batch_size: Batch size forwarded to ``load_noteseqs``.
        enc_rnn_dim, dec_rnn_dim, enc_dropout, dec_dropout, cont_dim,
        cat_dim, mu_force, t_gumbel, style_embed_dim, kl_reg,
        beta_anneal_steps, rnn_type, attention: Forwarded unchanged to the
            ``MVAE`` constructor. ``cat_dim`` is also passed to
            ``generate_and_save_samples`` at the end of each epoch.
        save_path: Output directory for generated samples, the CSV log,
            per-epoch checkpoints, TensorBoard logs and the final weights.
            Created if it does not exist.
        start_epoch: ``initial_epoch`` for ``fit`` (for resuming a run).
        final_epoch: ``epochs`` for ``fit``; training runs up to this epoch
            index unless early stopping triggers first.
        weights: Optional weights path, relative to ``save_path``, restored
            after compilation and before training. ``None`` skips restoring.

    Returns:
        The ``History`` object returned by ``vae.fit``.
    """
    train_iterator = load_noteseqs(train_segments, x_depth,
                                   batch_size).get_iterator()
    test_iterator = load_noteseqs(test_segments, x_depth,
                                  batch_size).get_iterator()

    vae = MVAE(x_depth=x_depth,
               enc_rnn_dim=enc_rnn_dim, enc_dropout=enc_dropout,
               dec_rnn_dim=dec_rnn_dim, dec_dropout=dec_dropout,
               cont_dim=cont_dim, cat_dim=cat_dim,
               mu_force=mu_force, t_gumbel=t_gumbel,
               style_embed_dim=style_embed_dim, kl_reg=kl_reg,
               beta_anneal_steps=beta_anneal_steps,
               rnn_type=rnn_type, attention=attention)
    optimizer = tfk.optimizers.Adam(learning_rate=5e-4)
    vae.compile(optimizer=optimizer)
    # Execute train/eval steps eagerly instead of as a compiled tf.function.
    vae.run_eagerly = True

    # exist_ok replaces the racy "check exists, then create" sequence.
    os.makedirs(save_path, exist_ok=True)

    callbacks = [
        # Generate and save samples from the model after every epoch.
        tfk.callbacks.LambdaCallback(
            on_epoch_end=lambda epoch, _: generate_and_save_samples(
                vae, epoch, save_path, cat_dim)),
        # Must be ``on_epoch_begin``: LambdaCallback stores unrecognised
        # keywords (such as ``on_epoch_start``) as plain attributes and never
        # calls them, so the trackers would never be reset.
        tfk.callbacks.LambdaCallback(
            on_epoch_begin=lambda epoch, _: vae.reset_trackers()),
        tfk.callbacks.EarlyStopping(monitor='val_p_acc', min_delta=0.01,
                                    patience=5, mode='max'),
        # os.path.join keeps outputs inside save_path even when the caller
        # omits the trailing separator.
        tfk.callbacks.CSVLogger(os.path.join(save_path, 'log.csv'),
                                append=True),
        tfk.callbacks.ModelCheckpoint(
            os.path.join(save_path, 'weights', 'weights.{epoch:02d}'),
            monitor='val_p_acc', save_weights_only=True,
            save_best_only=True, mode='max'),
        tfk.callbacks.TensorBoard(log_dir=save_path, write_graph=True,
                                  update_freq='epoch', histogram_freq=40,
                                  profile_batch='10,20'),
    ]

    if weights is not None:
        vae.load_weights(os.path.join(save_path, weights))

    history = vae.fit(train_iterator,
                      epochs=final_epoch,
                      initial_epoch=start_epoch,
                      callbacks=callbacks,
                      validation_data=test_iterator)
    vae.save_weights(os.path.join(save_path, 'weights', 'weights-final'))
    return history