示例#1
0
    logvar = data_logvar[indices][:, :args.rnn_max_seq_len]
    action = data_action[indices][:, :args.rnn_max_seq_len]
    z = sample_vae(mu, logvar)
    d = tf.cast(data_d[indices], tf.float16)[:, :args.rnn_max_seq_len]
    return z, action, d


# Construct the MDN-RNN from the parsed CLI/config arguments and compile it
# with the optimizer and loss function that the model object itself exposes.
rnn = MDNRNN(args=args)
rnn.compile(
    optimizer=rnn.optimizer,
    loss=rnn.get_loss(),
)

# train loop:
start = time.time()  # wall-clock reference; taken before the save below
step = 0  # NOTE(review): rebound by the `for step in range(...)` loop that follows

# Register a fixed input signature on the model:
# shape [1, args.max_frames, args.rnn_input_seq_width], dtype float32.
# Presumably needed so the subclassed model has known inputs before
# save_model is called (`_set_inputs` is a private Keras API) -- TODO confirm.
# NOTE(review): other batches in this file are built as float16; confirm
# that the float32 signature here is intentional.
input_spec = tf.TensorSpec(
    shape=[1, args.max_frames, args.rnn_input_seq_width],
    dtype=tf.float32,
)
rnn._set_inputs(input_spec)

# Export the model (including optimizer state) in TensorFlow SavedModel
# format to model_save_path.
# NOTE(review): this runs before any training step has been taken.
tf.keras.models.save_model(
    rnn,
    filepath=model_save_path,
    include_optimizer=True,
    save_format='tf',
)

for step in range(args.rnn_num_steps):
    curr_learning_rate = (
        args.rnn_learning_rate - args.rnn_min_learning_rate) * (
            args.rnn_decay_rate)**step + args.rnn_min_learning_rate
    rnn.optimizer.learning_rate = curr_learning_rate

    raw_z, raw_a, raw_d = random_batch()
    inputs = tf.concat([raw_z, raw_a], axis=2)

    dummy_zero = tf.zeros([raw_z.shape[0], 1, raw_z.shape[2]],
# Construct and compile the MDN-RNN using the optimizer and loss provided
# by the model object itself.
rnn = MDNRNN(args=args)
rnn.compile(
    optimizer=rnn.optimizer,
    loss=rnn.get_loss(),
)
# Point the TensorBoard callback at this model instance.
tensorboard_callback.set_model(rnn)

# train loop:
# `step` drives the learning-rate decay inside the loop below.
# NOTE(review): its increment is not visible in this excerpt.
step = 0
start = time.time()  # wall-clock reference for the training loop
for raw_z, raw_a, raw_r, raw_d, raw_N in dataset:
    curr_learning_rate = (args.rnn_learning_rate-args.rnn_min_learning_rate) * (args.rnn_decay_rate) ** step + args.rnn_min_learning_rate
    rnn.optimizer.learning_rate = curr_learning_rate
    
    inputs = tf.concat([raw_z, raw_a], axis=2)

    if step == 0:
        rnn._set_inputs(inputs)

    dummy_zero = tf.zeros([raw_z.shape[0], 1, raw_z.shape[2]], dtype=tf.float16)
    z_targ = tf.concat([raw_z[:, 1:, :], dummy_zero], axis=1) # zero pad the end but we don't actually use it
    z_mask = 1.0 - raw_d
    z_targ = tf.concat([z_targ, z_mask], axis=2) # use a signal to not pass grad

    outputs = {'MDN': z_targ}
    if args.rnn_r_pred == 1:
        r_mask = tf.concat([tf.ones([args.rnn_batch_size, 1, 1], dtype=tf.float16), 1.0 - raw_d[:, :-1, :]], axis=1)
        r_targ = tf.concat([raw_r, r_mask], axis=2)
        outputs['r'] = r_targ
    if args.rnn_d_pred == 1:
        d_mask = tf.concat([tf.ones([args.rnn_batch_size, 1, 1], dtype=tf.float16), 1.0 - raw_d[:, :-1, :]], axis=1)
        d_targ = tf.concat([raw_d, d_mask], axis=2)
        outputs['d'] = d_targ