示例#1
0
def create_model_and_train(train_segments,
                           test_segments,
                           x_depth,
                           batch_size,
                           enc_rnn_dim,
                           dec_rnn_dim,
                           enc_dropout,
                           dec_dropout,
                           cont_dim,
                           cat_dim,
                           mu_force,
                           t_gumbel,
                           style_embed_dim,
                           kl_reg,
                           beta_anneal_steps,
                           rnn_type,
                           attention,
                           save_path,
                           start_epoch,
                           final_epoch,
                           weights=None):
    """Build an MVAE, train it on the given segments and save its weights.

    Args:
        train_segments: Training data, passed to ``load_noteseqs``.
        test_segments: Validation data, passed to ``load_noteseqs``.
        x_depth: Input depth forwarded to ``load_noteseqs`` and ``MVAE``.
        batch_size: Batch size used for both data iterators.
        enc_rnn_dim, dec_rnn_dim, enc_dropout, dec_dropout, cont_dim,
        cat_dim, mu_force, t_gumbel, style_embed_dim, kl_reg,
        beta_anneal_steps, rnn_type, attention: Hyperparameters forwarded
            unchanged to the ``MVAE`` constructor (``cat_dim`` is also
            passed to ``generate_and_save_samples``).
        save_path: Output directory for generated samples, the CSV log,
            TensorBoard logs and weight checkpoints. Created if missing.
        start_epoch: Epoch to resume from (``initial_epoch`` of ``fit``).
        final_epoch: Epoch index at which training stops (``epochs`` of
            ``fit``).
        weights: Optional weights path, relative to ``save_path``, loaded
            before training starts.

    Returns:
        The object returned by ``vae.fit``.
    """
    train_iterator = load_noteseqs(train_segments, x_depth,
                                   batch_size).get_iterator()
    test_iterator = load_noteseqs(test_segments, x_depth,
                                  batch_size).get_iterator()

    vae = MVAE(x_depth=x_depth,
               enc_rnn_dim=enc_rnn_dim,
               enc_dropout=enc_dropout,
               dec_rnn_dim=dec_rnn_dim,
               dec_dropout=dec_dropout,
               cont_dim=cont_dim,
               cat_dim=cat_dim,
               mu_force=mu_force,
               t_gumbel=t_gumbel,
               style_embed_dim=style_embed_dim,
               kl_reg=kl_reg,
               beta_anneal_steps=beta_anneal_steps,
               rnn_type=rnn_type,
               attention=attention)

    optimizer = tfk.optimizers.Adam(learning_rate=5e-4)
    vae.compile(optimizer=optimizer)
    vae.run_eagerly = True

    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(save_path, exist_ok=True)
    # os.path.join works whether or not save_path ends with a separator.
    weights_dir = os.path.join(save_path, 'weights')

    callbacks = [
        tfk.callbacks.LambdaCallback(
            on_epoch_end=lambda epoch, _: generate_and_save_samples(
                vae, epoch, save_path, cat_dim)),
        # Keras' per-epoch start hook is `on_epoch_begin`; `on_epoch_start`
        # is not a hook Keras calls, so the trackers were never reset.
        tfk.callbacks.LambdaCallback(
            on_epoch_begin=lambda epoch, _: vae.reset_trackers()),
        tfk.callbacks.EarlyStopping(monitor='val_p_acc',
                                    min_delta=0.01,
                                    patience=5,
                                    mode='max'),
        tfk.callbacks.CSVLogger(os.path.join(save_path, 'log.csv'),
                                append=True),
        tfk.callbacks.ModelCheckpoint(os.path.join(weights_dir,
                                                   'weights.{epoch:02d}'),
                                      monitor='val_p_acc',
                                      save_weights_only=True,
                                      save_best_only=True,
                                      mode='max'),
        tfk.callbacks.TensorBoard(log_dir=save_path,
                                  write_graph=True,
                                  update_freq='epoch',
                                  histogram_freq=40,
                                  profile_batch='10,20')
    ]

    if weights is not None:
        vae.load_weights(os.path.join(save_path, weights))

    history = vae.fit(train_iterator,
                      epochs=final_epoch,
                      initial_epoch=start_epoch,
                      callbacks=callbacks,
                      validation_data=test_iterator)
    vae.save_weights(os.path.join(weights_dir, 'weights-final'))

    return history
示例#2
0
# Write the model summary alongside the other run artefacts.
with open(save_path + 'model.txt', 'w', encoding='utf-8') as f:
    vae.model().summary(print_fn=lambda x: f.write(x + '\n'))

# register handler for Ctrl+C in order to save final weights
signal.signal(signal.SIGINT, signal_handler)

# print('loading weights')
# vae.load_weights('gru_jsb_nmd/weights/weights-final')

callbacks = [
    tfk.callbacks.LambdaCallback(
        on_epoch_end=lambda epoch, _: generate_and_save_samples(
            vae, epoch, save_path, int(args["cat_dim"]))),
    # Keras' per-epoch start hook is `on_epoch_begin`; `on_epoch_start` is
    # not a hook Keras calls, so the trackers were never reset.
    tfk.callbacks.LambdaCallback(
        on_epoch_begin=lambda epoch, _: vae.reset_trackers()),
    tfk.callbacks.CSVLogger(save_path + 'log.csv', append=True),
    # Single separator: the old 'weights/' + '/weights...' produced '//'.
    tfk.callbacks.ModelCheckpoint(save_path + 'weights/weights.{epoch:02d}',
                                  monitor='val_p_acc',
                                  save_weights_only=True,
                                  save_best_only=True,
                                  mode='max'),
    tfk.callbacks.TensorBoard(log_dir=save_path,
                              write_graph=True,
                              update_freq='epoch',
                              histogram_freq=40,
                              profile_batch='10,20')
]

history = vae.fit(train_iterator,