def visualize(summary_dir, global_step, episodes, predictions, scalars):
  """Visualizes the episodes in TensorBoard."""
  if tf.executing_eagerly():
    writer = tfs.create_file_writer(summary_dir)
    with writer.as_default():
      videos = np.stack([e["image"] for e in episodes])
      video_summary = visualization.py_gif_summary(
          tag="episodes/video", images=videos, max_outputs=20, fps=20)
      tfs.experimental.write_raw_pb(video_summary, step=global_step)
      for k in scalars:
        tfs.scalar(name="episodes/%s" % k, data=scalars[k], step=global_step)
      if "image" in predictions[0]:
        videos = np.stack([e["image"] for e in predictions])
        video_summary = visualization.py_gif_summary(
            tag="episodes/video_prediction",
            images=videos,
            max_outputs=6,
            fps=20)
        tfs.experimental.write_raw_pb(video_summary, step=global_step)
      if "reward" in predictions[0]:
        rewards = np.stack([e["reward"][1:] for e in episodes])
        predicted_rewards = np.stack([p["reward"] for p in predictions])
        signals = np.stack([rewards, predicted_rewards], axis=1)
        signals = signals[:, :, :, 0]
        visualization.py_plot_1d_signal(
            name="episodes/reward",
            signals=signals,
            labels=["reward", "prediction"],
            max_outputs=6,
            step=global_step)
        reward_dir = os.path.join(summary_dir, "rewards")
        rewards_to_save = {"true": rewards, "pred": predicted_rewards}
        npz.save_dictionary(rewards_to_save, reward_dir)
  else:
    summary_writer = tf.summary.FileWriter(summary_dir)
    for k in scalars:
      s = tf.Summary()
      s.value.add(tag="episodes/" + k, simple_value=scalars[k])
      summary_writer.add_summary(s, global_step)
    videos = np.stack([e["image"] for e in episodes])
    video_summary = visualization.py_gif_summary(
        tag="episodes/video", images=videos, max_outputs=20, fps=30)
    summary_writer.add_summary(video_summary, global_step)
    summary_writer.flush()
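# A minimal usage sketch for visualize() above, not part of the original
# source. It assumes `tfs` is an alias for `tf.summary`, that `np` is numpy,
# and that the `visualization`/`npz` helpers referenced above are importable.
# Shapes, the output directory, and the dummy data are illustrative only.
def _example_visualize_call():
  episodes = [{
      "image": np.zeros((10, 64, 64, 3), dtype=np.uint8),
      "reward": np.zeros((10, 1), dtype=np.float32),
  } for _ in range(2)]
  predictions = [{
      "image": np.zeros((10, 64, 64, 3), dtype=np.uint8),
      # One fewer reward than the episode, matching e["reward"][1:] above.
      "reward": np.zeros((9, 1), dtype=np.float32),
  } for _ in range(2)]
  scalars = {"score": 1.0}
  visualize(
      summary_dir="/tmp/summaries",
      global_step=0,
      episodes=episodes,
      predictions=predictions,
      scalars=scalars)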
@contextlib.contextmanager
def get_summary_writer(summary_dir: str) -> SummaryWriter:
  """Context manager around Tensorflow's SummaryWriter."""
  if jax.process_index() == 0:
    logging.info('Opening SummaryWriter `%s`...', summary_dir)
    summary_writer = tf_summary.create_file_writer(summary_dir)
  else:
    # We create a dummy tf.summary.SummaryWriter() on non-zero tasks. This
    # returns a mock object, which acts like a summary writer but does
    # nothing, such as writing events to disk.
    logging.info('Opening a mock-like SummaryWriter.')
    summary_writer = tf_summary.create_noop_writer()
  try:
    yield summary_writer
  finally:
    summary_writer.close()
    if jax.process_index() == 0:
      logging.info('Closed SummaryWriter `%s`.', summary_dir)
    else:
      logging.info('Closed a mock-like SummaryWriter.')
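# A minimal usage sketch, not part of the original source. It assumes the
# contextlib.contextmanager decorator above and that `tf_summary` is an alias
# for `tf.summary`: every process enters the same code path, but only
# process 0 actually writes events to disk. The directory is illustrative.
def _example_get_summary_writer():
  with get_summary_writer('/tmp/summaries') as writer:
    with writer.as_default():
      tf_summary.scalar('train/loss', 0.5, step=0)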
def host_call_fn(**kwargs):
  """Host call function that writes scalar summaries on the TPU host.

  Args:
    **kwargs: dict mapping summary names to tf.Tensors. Each value we see here
      is the tensor from all cores, concatenated along axis 0. This function
      makes a scalar summary that is the mean of the whole tensor (all the
      values are the same - the mean - a trait of tpu.CrossShardOptimizer).

  Returns:
    A merged summary op.
  """
  # `model_dir` is captured from the enclosing scope.
  gs = kwargs.pop('global_step')[0]
  with tf_summary.create_file_writer(model_dir).as_default():
    # Only record summaries every 10 global steps.
    with tf_summary.record_if(tf.equal(gs % 10, 0)):
      for name, tensor in kwargs.items():
        # Take the mean across cores.
        tensor = tf.reduce_mean(tensor)
        tf_summary.scalar(name, tensor, step=gs)
      return tf.summary.all_v2_summary_ops()
def host_call_fn(model_dir, **kwargs):
  """host_call function used for creating training summaries when using TPU.

  Args:
    model_dir: String indicating the output_dir to save summaries in.
    **kwargs: Set of metric names and tensor values for all desired summaries.

  Returns:
    Summary op to be passed to the host_call arg of the estimator function.
  """
  gs = kwargs.pop('global_step')[0]
  with summary.create_file_writer(model_dir).as_default():
    # Always record summaries.
    with summary.record_if(True):
      for name, tensor in kwargs.items():
        if name.startswith(IMG_SUMMARY_PREFIX):
          summary.image(
              name.replace(IMG_SUMMARY_PREFIX, ''), tensor, max_images=1)
        else:
          summary.scalar(name, tensor[0], step=gs)
      # The following function only exists under TF 1.x, so we use it here.
      return tf.summary.all_v2_summary_ops()
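# A hedged wiring sketch, not from the original source: host_call_fn above is
# typically handed to TPUEstimator via the `host_call` argument as a
# (function, dict-of-tensors) pair. It assumes TF 1.x (tf.estimator.tpu) and
# that `functools` is imported; `loss_t` and `lr_t` are illustrative names for
# whichever training tensors the real model_fn wants summarized. Each tensor
# is reshaped to [1] because host_call_fn indexes with [0].
def _example_host_call_wiring(mode, loss, train_op, loss_t, lr_t):
  global_step_t = tf.reshape(tf.train.get_or_create_global_step(), [1])
  host_call = (
      functools.partial(host_call_fn, '/tmp/model_dir'),
      {
          'global_step': global_step_t,
          'loss': tf.reshape(loss_t, [1]),
          'learning_rate': tf.reshape(lr_t, [1]),
      })
  return tf.estimator.tpu.TPUEstimatorSpec(
      mode=mode, loss=loss, train_op=train_op, host_call=host_call)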
def train_fn(data_path):
  """A train_fn to train the planet model."""
  nonlocal iterator
  nonlocal optimizer
  with strategy.scope():
    global_step = tf.Variable(0, dtype=tf.int64, trainable=False)
    checkpoint = tf.train.Checkpoint(
        global_step=global_step, optimizer=optimizer, **model.get_trackables())
    manager = tf.train.CheckpointManager(checkpoint, model_dir, max_to_keep=1)
    # Restore the latest checkpoint, if one exists.
    checkpoint.restore(manager.latest_checkpoint)
    if iterator is None:
      dataset = npz.load_dataset_from_directory(data_path, duration, batch)
      dataset = strategy.experimental_distribute_dataset(dataset)
      iterator = dataset
    writer = tfs.create_file_writer(model_dir)
    tfs.experimental.set_step(global_step)
    true_rewards, pred_rewards = None, None
    with writer.as_default():
      for step, obs in enumerate(iterator):
        if step > train_steps:
          if save_rewards:
            # We are only saving the last training batch.
            reward_dir = os.path.join(model_dir, 'train_rewards')
            true_rewards = strategy.experimental_local_results(true_rewards)
            pred_reward = strategy.experimental_local_results(pred_rewards)
            true_rewards = np.concatenate([x.numpy() for x in true_rewards])
            pred_reward = np.concatenate([x.numpy() for x in pred_reward])
            rewards_to_save = {'true': true_rewards, 'pred': pred_reward}
            npz.save_dictionary(rewards_to_save, reward_dir)
          break
        (loss, reward_loss, divergence, frames, pred_rewards, true_rewards,
         frame_loss) = train_step(obs)
        if step % 100 == 0:
          # Aggregate the per-replica values before logging and summarizing.
          loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, loss)
          reward_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                        reward_loss)
          divergence = strategy.reduce(tf.distribute.ReduceOp.MEAN, divergence)
          frame_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, frame_loss)
          frames = strategy.experimental_local_results(frames)
          frames = tf.concat(frames, axis=0)
          pred_reward = strategy.experimental_local_results(pred_rewards)
          pred_reward = tf.concat(pred_reward, axis=0)
          tf.logging.info('loss at step %d: %f', step, loss)
          tfs.scalar('loss/total', loss)
          tfs.scalar('loss/reward', reward_loss)
          tfs.scalar('loss/divergence', divergence)
          tfs.scalar('loss/frames', frame_loss)
          tfs.experimental.write_raw_pb(
              visualization.py_gif_summary(
                  tag='predictions/frames',
                  images=frames.numpy(),
                  max_outputs=6,
                  fps=20))
          ground_truth_rewards = (tf.concat(
              strategy.experimental_local_results(obs['reward']),
              axis=0)[:, :, 0])
          rewards = pred_reward[:, :, 0]
          signals = tf.stack([ground_truth_rewards, rewards], axis=1)
          visualization.py_plot_1d_signal(
              name='predictions/reward',
              signals=signals.numpy(),
              labels=['ground_truth', 'prediction'],
              max_outputs=6)
        global_step.assign_add(1)
      manager.save(global_step)