def host_call_fn(gs, g_loss, d_loss, real_audio, generated_audio):
    """Training host call. Creates scalar and audio summaries for training.

    This function is executed on the CPU and should not directly reference
    any Tensors in the rest of the `model_fn`. To pass Tensors from the model
    to this function, provide them as part of the `host_call`. See
    https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
    for more information.

    Arguments should match the list of `Tensor` objects passed as the second
    element in the tuple passed to `host_call`.

    Args:
      gs: `Tensor` with shape `[batch]` for the global_step.
      g_loss: `Tensor` with shape `[batch]` for the generator loss.
      d_loss: `Tensor` with shape `[batch]` for the discriminator loss.
      real_audio: `Tensor` with shape `[batch, 8192, 1]` of real clips.
      generated_audio: `Tensor` with shape `[batch, 8192, 1]` of generated
        clips.

    Returns:
      List of summary ops to run on the CPU host.
    """
    # Every input arrives with a leading `[batch]` dimension (see Args).
    # Take the first element so the step and the scalar summaries receive
    # rank-0 tensors rather than relying on the op to pick an element.
    gs = gs[0]
    g_loss = g_loss[0]
    d_loss = d_loss[0]
    with summary.create_file_writer(FLAGS.model_dir).as_default():
        with summary.always_record_summaries():
            summary.scalar('g_loss', g_loss, step=gs)
            summary.scalar('d_loss', d_loss, step=gs)
            summary.audio('real_audio', real_audio, sample_rate=_FS,
                          max_outputs=10, step=gs)
            summary.audio('generated_audio', generated_audio, sample_rate=_FS,
                          max_outputs=10, step=gs)
            return summary.all_summary_ops()
def host_call_fn(gs, loss, lr, mix=None, gt_sources=None, est_sources=None):
    """Training host call. Creates scalar and audio summaries for training.

    This function is executed on the CPU and should not directly reference
    any Tensors in the rest of the `model_fn`. To pass Tensors from the model
    to this function, provide them as part of the `host_call`. See
    https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
    for more information.

    Arguments should match the list of `Tensor` objects passed as the second
    element in the tuple passed to `host_call`.

    Args:
      gs: `Tensor` with shape `[batch]` for the global_step.
      loss: `Tensor` with shape `[batch]` for the training loss.
      lr: `Tensor` with shape `[batch]` for the learning_rate.
      mix: Optional `Tensor` with shape `[batch, mix_samples, 1]` holding the
        input mixture. Its audio summary is skipped when `None`.
      gt_sources: Optional `Tensor` with shape
        `[batch, sources_n, output_samples, 1]` holding the ground-truth
        sources. Its audio summaries are skipped when `None`.
      est_sources: Optional `Tensor` with shape
        `[batch, sources_n, output_samples, 1]` holding the estimated
        sources. Its audio summaries are skipped when `None`.

    Returns:
      List of summary ops to run on the CPU host.
    """
    # Inputs carry a leading `[batch]` dimension; use the first element.
    gs = gs[0]
    log_dir = os.path.join(model_config["model_base_dir"],
                           str(model_config["experiment_id"]))
    with summary.create_file_writer(log_dir).as_default():
        with summary.always_record_summaries():
            summary.scalar('loss', loss[0], step=gs)
            summary.scalar('learning_rate', lr[0], step=gs)

            # Audio summaries are only recorded every
            # `audio_summaries_every_n_steps` steps. The decision must be made
            # on the `gs` tensor by the summary library: a Python-level
            # `if gs % N == 0` does not evaluate the tensor in graph mode
            # (it is always False in TF1, and raises under TF2 semantics).
            with summary.record_summaries_every_n_global_steps(
                    model_config["audio_summaries_every_n_steps"],
                    global_step=gs):
                sample_rate = model_config['expected_sr']
                max_outputs = model_config["num_sources"]
                if mix is not None:
                    summary.audio('mix', mix, sample_rate,
                                  max_outputs=max_outputs, step=gs)
                for prefix, sources in (('gt_sources', gt_sources),
                                        ('est_sources', est_sources)):
                    if sources is None:
                        continue
                    # Axis 1 indexes the individual sources (see Args).
                    num_sources = sources.shape.as_list()[1]
                    for source_id in range(num_sources):
                        summary.audio(
                            '{prefix}_{source_id}'.format(
                                prefix=prefix, source_id=source_id),
                            sources[:, source_id, :, :],
                            sample_rate,
                            max_outputs=max_outputs,
                            step=gs)
            return summary.all_summary_ops()