Exemple #1
0
 def write_epoch_models(self, mode: str) -> None:
     """Write the backend graph and every serializable epoch model as summaries.

     Args:
         mode: Key selecting the writer in ``self.tf_summary_writers``.
     """
     writer = self.tf_summary_writers[mode]
     with writer.as_default(), summary_ops_v2.always_record_summaries():
         summary_ops_v2.graph(backend.get_graph(), step=0)
         for epoch_model in self.network.epoch_models:
             # Only Sequential models and graph networks are written; any
             # other model is skipped.
             if (epoch_model.__class__.__name__ == 'Sequential'
                     or getattr(epoch_model, '_is_graph_network', False)):
                 summary_ops_v2.keras_model(epoch_model.model_name, epoch_model, step=0)
 def keras_model(self, *args, **kwargs):
     """Run ``summary_ops.keras_model`` against a temp logdir and return its event.

     Args:
         *args: Positional arguments forwarded to ``summary_ops.keras_model``.
         **kwargs: Keyword arguments forwarded to ``summary_ops.keras_model``.

     Returns:
         The second event read back from the log directory.
     """
     logdir = self.get_temp_dir()
     writer = summary_ops.create_file_writer(logdir)
     try:
         with writer.as_default():
             summary_ops.keras_model(*args, **kwargs)
     finally:
         # Close the writer even when the summary call raises, so it is never
         # left open after a failing test.
         writer.close()
     events = events_from_logdir(logdir)
     # The first event contains no summary values. The written content goes to
     # the second event.
     return events[1]
 def keras_model(self, *args, **kwargs):
   """Run `summary_ops.keras_model` against a temp logdir and return its event.

   Args:
     *args: Positional arguments forwarded to `summary_ops.keras_model`.
     **kwargs: Keyword arguments forwarded to `summary_ops.keras_model`.

   Returns:
     The second event read back from the log directory.
   """
   logdir = self.get_temp_dir()
   writer = summary_ops.create_file_writer(logdir)
   try:
     with writer.as_default():
       summary_ops.keras_model(*args, **kwargs)
   finally:
     # Close the writer even when the summary call raises, so it is never
     # left open after a failing test.
     writer.close()
   events = events_from_logdir(logdir)
   # The first event contains no summary values. The written content goes to
   # the second event.
   return events[1]
  def write_graph(self, model: k.Model):
    """Writes the graph and, when serializable, the Keras model summary.

    Does nothing when `model` or `self.is_write_graph` is falsy.

    Args:
      model: Keras model whose graph/summary should be written.
    """
    if not (model and self.is_write_graph):
      return

    with self.writer.as_default(), summary_ops_v2.always_record_summaries():
      # The graph is only written for models that do not run eagerly.
      if not model.run_eagerly:
        summary_ops_v2.graph(get_graph(), step=0)

      # Only graph networks and Sequential models are written as a summary.
      # pylint: disable=protected-access
      if model._is_graph_network or model.__class__.__name__ == 'Sequential':
        summary_ops_v2.keras_model('keras', model, step=0)
      # pylint: enable=protected-access
Exemple #5
0
 def on_begin(self, state):
     """Write graph/model summaries and configure embeddings, as enabled.

     Args:
         state: Not read by this method.
     """
     if self.write_graph:
         train_writer = self.summary_writers['train']
         with train_writer.as_default(), summary_ops_v2.always_record_summaries():
             summary_ops_v2.graph(backend.get_graph(), step=0)
             for model_name, net in self.network.model.items():
                 # Only graph networks and Sequential models are written.
                 if net._is_graph_network or net.__class__.__name__ == 'Sequential':
                     summary_ops_v2.keras_model(model_name, net, step=0)
     if self.embeddings_freq:
         self._configure_embeddings()
Exemple #6
0
 def write_keras_graph(self,
                       model: Union[Model, Sequential],
                       step: int = 0,
                       name: str = "keras"):
   r""" Writes the backend graph and a Keras model summary to TensorBoard.

   Args:
     model: Keras model to write.
     step: Step value attached to the written summaries.
     name: Summary name; converted with `str` before use.

   Returns:
     `self`, so calls can be chained.
   """
   with self.summary_writer.as_default(), \
       summary_ops_v2.always_record_summaries():
     # The graph is only written for models that do not run eagerly.
     if not model.run_eagerly:
       summary_ops_v2.graph(keras.backend.get_graph(), step=step)
     # Only graph networks and Sequential models are written as a summary.
     if model._is_graph_network or model.__class__.__name__ == 'Sequential':
       summary_ops_v2.keras_model(name=str(name), data=model, step=step)
   return self
    def write_model_to_tensorboard(self, model: Model):
        """
        Write the given model as a graph in tensorboard.

        The summary function is chosen by checking what the installed
        TensorFlow exposes instead of matching an exact version string, so
        releases other than 2.4.1 and 2.5.0 are no longer silently skipped.
        A warning is emitted when neither known function is available.

        :param model: The model to write to tensorboard.
        """
        # Older releases expose the writer as summary_ops_v2.keras_model.
        legacy_keras_model = getattr(summary_ops_v2, "keras_model", None)

        with self._file_writer.as_default():
            if legacy_keras_model is not None:
                with summary_ops_v2.always_record_summaries():
                    legacy_keras_model(name=model.name,
                                       data=model,
                                       step=0)
                return

            try:
                from tensorflow.python.keras.callbacks import keras_model_summary
            except ImportError:
                import warnings

                # Stay best-effort (no exception), but make the skip visible.
                warnings.warn(
                    "No Keras model summary function found for TensorFlow "
                    "{}; the model was not written.".format(tf.__version__))
                return

            with summary_ops_v2.record_if(True):
                keras_model_summary("keras", model, step=0)
  def testKerasModel_otherExceptions(self):
    """Checks a failing `to_json` makes `keras_model` return falsy and warn."""
    model = Sequential()

    with test.mock.patch.object(model, 'to_json') as mock_to_json:
      with test.mock.patch.object(logging, 'warn') as mock_log:
        # Force JSON serialization of the model to fail.
        mock_to_json.side_effect = Exception('oops')
        self.assertFalse(
            summary_ops.keras_model(name='my_name', data=model, step=1))
        # assertRegex replaces the deprecated assertRegexpMatches alias,
        # which was removed in Python 3.12.
        self.assertRegex(
            str(mock_log.call_args),
            'Model failed to serialize as JSON. Ignoring... oops')
Exemple #9
0
  def testKerasModel_otherExceptions(self):
    """Checks a failing `to_json` makes `keras_model` return falsy and warn."""
    sequential = Sequential()

    with test.mock.patch.object(sequential, 'to_json') as to_json_stub, \
        test.mock.patch.object(logging, 'warn') as warn_stub:
      # Force JSON serialization of the model to fail.
      to_json_stub.side_effect = Exception('oops')
      written = summary_ops.keras_model(
          name='my_name', data=sequential, step=1)
      self.assertFalse(written)
      logged_call = str(warn_stub.call_args)
      self.assertRegex(
          logged_call,
          'Model failed to serialize as JSON. Ignoring... oops')
    def set_model(self, model):
        """Store the wrapped Keras model and write its graph when enabled.

        ``model.model`` (not ``model`` itself) is stored on ``self.model``.
        Writers are closed and the graph summaries are written inside an
        eager-mode context; embeddings are configured afterwards when
        ``self.embeddings_freq`` is truthy.
        """
        self.model = model.model
        with context.eager_mode():
            self._close_writers()
            if self.write_graph:
                train_writer = self._get_writer(self._train_run_name)
                with train_writer.as_default(), \
                        summary_ops_v2.always_record_summaries():
                    keras_net = self.model
                    # The graph is only written for non-eager models.
                    if not keras_net.run_eagerly:
                        summary_ops_v2.graph(K.get_graph(), step=0)

                    # Only graph networks and Sequential models are written.
                    # pylint: disable=protected-access
                    if (keras_net._is_graph_network
                            or keras_net.__class__.__name__ == 'Sequential'):
                        summary_ops_v2.keras_model('keras', keras_net, step=0)
                    # pylint: enable=protected-access

        if self.embeddings_freq:
            self._configure_embeddings()
    def testKerasModel_subclass(self):
        """Checks that a subclassed model is not written and a warning is logged."""
        class SimpleSubclass(Model):
            """Minimal subclassed model: one Dense layer followed by ReLU."""

            def __init__(self):
                super().__init__(name='subclass')
                self.dense = Dense(10, input_shape=(100, ))
                self.activation = Activation('relu', name='my_relu')

            def call(self, inputs):
                x = self.dense(inputs)
                return self.activation(x)

        model = SimpleSubclass()
        with test.mock.patch.object(logging, 'warn') as mock_log:
            self.assertFalse(
                summary_ops.keras_model(name='my_name', data=model, step=1))
            # assertRegex replaces the deprecated assertRegexpMatches alias,
            # which was removed in Python 3.12.
            self.assertRegex(str(mock_log.call_args),
                             'Model failed to serialize as JSON.')
  def testKerasModel_subclass(self):
    """Checks that a subclassed model is not written and a warning is logged."""

    class SimpleSubclass(Model):
      """Minimal subclassed model: one Dense layer followed by ReLU."""

      def __init__(self):
        super().__init__(name='subclass')
        self.dense = Dense(10, input_shape=(100,))
        self.activation = Activation('relu', name='my_relu')

      def call(self, inputs):
        x = self.dense(inputs)
        return self.activation(x)

    model = SimpleSubclass()
    with test.mock.patch.object(logging, 'warn') as mock_log:
      self.assertFalse(
          summary_ops.keras_model(name='my_name', data=model, step=1))
      # assertRegex replaces the deprecated assertRegexpMatches alias, which
      # was removed in Python 3.12.
      self.assertRegex(
          str(mock_log.call_args), 'Model failed to serialize as JSON.')