# NOTE(review): this chunk is truncated -- the first line below is the tail of
# a statement whose opening (presumably a tf.concat building z_targ along the
# time axis, as in the newer variant of this loop) lies before this view.
axis=1) # zero pad the end but we don't actually use it
# raw_d is presumably a per-timestep 0/1 "done" flag, so (1 - raw_d) keeps
# only non-terminal steps -- TODO confirm shape is (batch, seq, 1).
z_mask = 1.0 - raw_d
# Append the mask as an extra channel of the target; the loss function is
# expected to read it back out and gate gradients on terminal steps.
z_targ = tf.concat([z_targ, z_mask], axis=2) # use a signal to not pass grad
if args.env_name == 'CarRacing-v0':
    # CarRacing trains only the latent (z) head: single target tensor.
    outputs = z_targ
else:
    # Done-head validity mask, shifted right by one step: step 0 is always
    # valid, and step t is valid unless step t-1 ended the episode.
    d_mask = tf.concat([
        tf.ones([args.rnn_batch_size, 1, 1], dtype=tf.float16),
        1.0 - raw_d[:, :-1, :]
    ], axis=1)
    d_targ = tf.concat([raw_d, d_mask], axis=2)
    # Two-output model: [z target, done target].
    outputs = [z_targ, d_targ]
# Single optimization step. With one output `loss` is a scalar; with two,
# presumably [total, z_loss, d_loss] -- consistent with the indexing below.
loss = rnn.train_on_batch(x=inputs, y=outputs)
# Console logging every 20 steps (skipping step 0).
if (step % 20 == 0 and step > 0):
    end = time.time()
    time_taken = end - start
    start = time.time()
    if args.env_name == 'CarRacing-v0':
        output_log = "step: %d, lr: %.6f, loss: %.4f, train_time_taken: %.4f" % (
            step, curr_learning_rate, loss, time_taken)
    else:
        output_log = "step: %d, lr: %.6f, loss: %.4f, z_loss: %.4f, d_loss: %.4f, train_time_taken: %.4f" % (
            step, curr_learning_rate, loss[0], loss[1], loss[2], time_taken)
    print(output_log)
# NOTE(review): truncated -- the save_model call continues past this view.
tf.keras.models.save_model(rnn,
# One training iteration: build per-head targets, run a train step, write
# summaries, and periodically log / checkpoint.
# Trace the model's input signature so tf.keras.models.save_model below can
# serialize the call graph (Keras-internal API, kept as-is).
rnn._set_inputs(inputs)

# Next-step latent targets: shift raw_z left by one and zero-pad the final
# step (the pad is masked out and never contributes to the loss).
# Assumes raw_z is (batch, seq, z_dim) -- TODO confirm against the loader.
dummy_zero = tf.zeros([raw_z.shape[0], 1, raw_z.shape[2]], dtype=tf.float16)
z_targ = tf.concat([raw_z[:, 1:, :], dummy_zero], axis=1)  # zero pad the end but we don't actually use it
# Append (1 - done) as an extra channel so the MDN loss can gate gradients
# on terminal steps.
z_mask = 1.0 - raw_d
z_targ = tf.concat([z_targ, z_mask], axis=2)  # use a signal to not pass grad
outputs = {'MDN': z_targ}

# The optional reward and done heads share one validity mask: step 0 is
# always valid, and step t is valid unless step t-1 ended the episode.
# Computed once instead of duplicating the identical concat per head.
if args.rnn_r_pred == 1 or args.rnn_d_pred == 1:
    valid_mask = tf.concat([tf.ones([args.rnn_batch_size, 1, 1], dtype=tf.float16), 1.0 - raw_d[:, :-1, :]], axis=1)
if args.rnn_r_pred == 1:
    outputs['r'] = tf.concat([raw_r, valid_mask], axis=2)
if args.rnn_d_pred == 1:
    outputs['d'] = tf.concat([raw_d, valid_mask], axis=2)

# Single optimization step; return_dict=True yields {loss_name: value}.
loss = rnn.train_on_batch(x=inputs, y=outputs, return_dict=True)
# Side-effecting summary writes belong in a plain loop, not a throwaway
# list comprehension.
for loss_key, loss_val in loss.items():
    tf.summary.scalar(loss_key, loss_val, step=step)

# Console logging every 20 steps (skipping step 0).
if (step % 20 == 0 and step > 0):
    end = time.time()
    time_taken = end - start
    start = time.time()
    output_log = "step: %d, train_time_taken: %.4f, lr: %.6f" % (step, time_taken, curr_learning_rate)
    for loss_key, loss_val in loss.items():
        output_log += ', {}: {:.4f}'.format(loss_key, loss_val)
    print(output_log)

# Checkpoint every 1000 steps in SavedModel format, optimizer state included.
if (step % 1000 == 0 and step > 0):
    tf.keras.models.save_model(rnn, model_save_path, include_optimizer=True, save_format='tf')
step += 1