Example #1
from datetime import datetime

import tensorflow as tf


def train_rnn(args, train_dataset, validation_dataset):
    model_save_path = get_path(args, "tf_rnn", create=True)

    rnn = MDNRNN(args=args)
    rnn.compile(optimizer=rnn.optimizer,
                loss=rnn.loss_fn,
                metrics=rnn.get_metrics())

    print("Start training")

    # Write TensorBoard logs to a timestamped, per-run subdirectory.
    current_time = datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_dir = model_save_path / "tensorboard" / current_time

    # Total training is args.rnn_num_steps optimizer steps, split into epochs
    # of args.rnn_epoch_steps steps each.
    rnn.fit(train_dataset,
            validation_data=validation_dataset,
            steps_per_epoch=args.rnn_epoch_steps,
            epochs=args.rnn_num_steps // args.rnn_epoch_steps,
            callbacks=[
                tf.keras.callbacks.TensorBoard(log_dir=str(tensorboard_dir),
                                               update_freq=20,
                                               histogram_freq=1,
                                               profile_batch=0),
                # Save a checkpoint after every epoch, numbered in the filename.
                tf.keras.callbacks.ModelCheckpoint(str(model_save_path /
                                                       "ckpt-e{epoch:03d}"),
                                                   verbose=1),
            ])

    rnn.save(str(model_save_path))
    print(f"Model saved to {model_save_path}")
Example #2
              separators=(',', ': '))


def random_batch():
    indices = np.random.permutation(N_data)[0:args.rnn_batch_size]
    # Suboptimal: this always takes only the first rnn_max_seq_len steps of each sequence.
    mu = data_mu[indices][:, :args.rnn_max_seq_len]
    logvar = data_logvar[indices][:, :args.rnn_max_seq_len]
    action = data_action[indices][:, :args.rnn_max_seq_len]
    z = sample_vae(mu, logvar)
    d = tf.cast(data_d[indices], tf.float16)[:, :args.rnn_max_seq_len]
    return z, action, d
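

# sample_vae is used above but not defined in this excerpt. A minimal sketch of
# the reparameterization trick it presumably implements (an assumption, not the
# project's actual helper): z = mu + exp(logvar / 2) * eps with eps ~ N(0, I).
def sample_vae(mu, logvar):
    sigma = np.exp(logvar / 2.0)
    eps = np.random.randn(*mu.shape).astype(mu.dtype)
    return mu + sigma * eps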


rnn = MDNRNN(args=args)
rnn.compile(optimizer=rnn.optimizer, loss=rnn.get_loss())

# train loop:
start = time.time()
step = 0
# Pin the model's input signature (batch of 1, max_frames steps,
# rnn_input_seq_width features) so the subclassed model can be saved
# before it has ever been called on real data.
input_spec = tf.TensorSpec([1, args.max_frames, args.rnn_input_seq_width],
                           tf.float32)
rnn._set_inputs(input_spec)
tf.keras.models.save_model(rnn,
                           model_save_path,
                           include_optimizer=True,
                           save_format='tf')
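# For reference: a SavedModel written this way can later be reloaded with
# tf.keras.models.load_model(model_save_path, compile=False); passing
# compile=False should avoid needing the custom MDN loss at load time.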

for step in range(args.rnn_num_steps):
    curr_learning_rate = (
        args.rnn_learning_rate - args.rnn_min_learning_rate) * (