예제 #1
0
파일: bias_main.py 프로젝트: wurde/cwavegan
 def host_call_fn(gs, g_loss, d_loss, real_audio, generated_audio):
     """Training host call. Writes loss scalars and audio summaries.

     This function is executed on the CPU host and should not directly
     reference any Tensors in the rest of the `model_fn`. To pass Tensors
     from the model to this function, provide them as part of the
     `host_call`. See
     https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
     for more information.

     Arguments must match the list of `Tensor` objects passed as the second
     element of the tuple given to `host_call`. Every argument carries a
     leading batch axis; for the per-step scalars (`gs`, `g_loss`, `d_loss`)
     only the first entry is logged.

     Args:
       gs: `Tensor` with shape `[batch]` for the global_step.
       g_loss: `Tensor` with shape `[batch]` for the generator loss.
       d_loss: `Tensor` with shape `[batch]` for the discriminator loss.
       real_audio: `Tensor` with shape `[batch, 8192, 1]`.
       generated_audio: `Tensor` with shape `[batch, 8192, 1]`.

     Returns:
       List of summary ops to run on the CPU host.
     """
     gs = gs[0]
     with summary.create_file_writer(FLAGS.model_dir).as_default():
         with summary.always_record_summaries():
             # `summary.scalar` expects a single value; reduce the batch axis
             # explicitly, the same way `gs` is reduced above.
             summary.scalar('g_loss', g_loss[0], step=gs)
             summary.scalar('d_loss', d_loss[0], step=gs)
             summary.audio('real_audio', real_audio, sample_rate=_FS,
                           max_outputs=10, step=gs)
             summary.audio('generated_audio', generated_audio,
                           sample_rate=_FS, max_outputs=10, step=gs)
     return summary.all_summary_ops()
예제 #2
0
 def host_call_fn(gs,
                  loss,
                  lr,
                  mix=None,
                  gt_sources=None,
                  est_sources=None):
     """Training host call. Writes loss/learning-rate scalars and audio.

     This function is executed on the CPU host and should not directly
     reference any Tensors in the rest of the `model_fn`. To pass Tensors
     from the model to this function, provide them as part of the
     `host_call`. See
     https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
     for more information.

     Arguments must match the list of `Tensor` objects passed as the second
     element of the tuple given to `host_call`.

     Scalar summaries are recorded on every call. Audio summaries are
     recorded only on steps that are a multiple of
     `model_config["audio_summaries_every_n_steps"]`. Any audio argument
     left as `None` is skipped.

     Args:
       gs: `Tensor` with shape `[batch]` for the global_step.
       loss: `Tensor` with shape `[batch]` for the training loss.
       lr: `Tensor` with shape `[batch]` for the learning_rate.
       mix: Optional `Tensor` with shape `[batch, mix_samples, 1]`.
       gt_sources: Optional `Tensor` with shape
         `[batch, sources_n, output_samples, 1]`.
       est_sources: Optional `Tensor` with shape
         `[batch, sources_n, output_samples, 1]`.

     Returns:
       List of summary ops to run on the CPU host.
     """
     gs = gs[0]
     log_dir = os.path.join(model_config["model_base_dir"],
                            str(model_config["experiment_id"]))
     sample_rate = model_config['expected_sr']
     # NOTE(review): max_outputs limits how many batch entries are written per
     # clip, yet it is taken from num_sources; confirm that is intended.
     max_outputs = model_config["num_sources"]

     # Number of sources comes from whichever source tensor was supplied
     # (ground truth preferred); 0 when neither is given.
     source_tensors = [t for t in (gt_sources, est_sources) if t is not None]
     num_sources = source_tensors[0].shape[1].value if source_tensors else 0

     with summary.create_file_writer(log_dir).as_default():
         with summary.always_record_summaries():
             summary.scalar('loss', loss[0], step=gs)
             summary.scalar('learning_rate', lr[0], step=gs)
         if mix is not None or source_tensors:
             # Gate audio inside the graph. A Python `if gs % n == 0:` does not
             # work here: in TF1 graph mode `Tensor.__eq__` is object identity,
             # so such a check is always False and nothing would be written.
             with summary.record_summaries_every_n_global_steps(
                     model_config["audio_summaries_every_n_steps"],
                     global_step=gs):
                 if mix is not None:
                     summary.audio('mix', mix, sample_rate,
                                   max_outputs=max_outputs, step=gs)
                 for source_id in range(num_sources):
                     if gt_sources is not None:
                         summary.audio(
                             'gt_sources_{source_id}'.format(
                                 source_id=source_id),
                             gt_sources[:, source_id, :, :],
                             sample_rate,
                             max_outputs=max_outputs,
                             step=gs)
                     if est_sources is not None:
                         summary.audio(
                             'est_sources_{source_id}'.format(
                                 source_id=source_id),
                             est_sources[:, source_id, :, :],
                             sample_rate,
                             max_outputs=max_outputs,
                             step=gs)
     return summary.all_summary_ops()