def write_epoch_models(self, mode: str) -> None:
    """Record the current graph and summary-writable epoch models for ``mode``.

    Everything is written at step 0 through ``self.tf_summary_writers[mode]``
    with summary recording forced on.

    Args:
        mode: Key selecting which entry of ``self.tf_summary_writers`` to use.
    """
    target_writer = self.tf_summary_writers[mode]
    with target_writer.as_default(), summary_ops_v2.always_record_summaries():
        summary_ops_v2.graph(backend.get_graph(), step=0)
        for epoch_model in self.network.epoch_models:
            # Only Sequential models and graph (functional) networks get a
            # Keras-model summary; presumably subclassed models cannot be
            # serialized by keras_model. The attribute may be absent, hence
            # the default.
            is_sequential = epoch_model.__class__.__name__ == 'Sequential'
            if is_sequential or getattr(epoch_model, '_is_graph_network', False):
                summary_ops_v2.keras_model(epoch_model.model_name, epoch_model, step=0)
def keras_model(self, *args, **kwargs):
    """Write one Keras-model summary to a fresh temp logdir and return its event.

    All positional and keyword arguments are forwarded unchanged to
    ``summary_ops.keras_model``.

    Returns:
        The second event read back from the logdir by ``events_from_logdir``
        (index 1). Indexing presumably fails if the summary op wrote nothing.
    """
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer(logdir)
    try:
        with writer.as_default():
            summary_ops.keras_model(*args, **kwargs)
    finally:
        # Close the writer even if the summary op raises, so the event file
        # handle is not leaked across tests.
        writer.close()
    events = events_from_logdir(logdir)
    # The first event contains no summary values. The written content goes to
    # the second event.
    return events[1]
def write_graph(self, model: k.Model):
    """Write the model's graph and Keras-model summary when graph writing is on.

    Nothing happens when ``model`` is falsy or ``self.is_write_graph`` is falsy.
    The TF graph is written only for models that do not run eagerly. The
    Keras-model summary (tag ``'keras'``) is written only for graph networks
    and ``Sequential`` models. All writes go to ``self.writer`` at step 0.
    """
    if not model or not self.is_write_graph:
        return
    with self.writer.as_default(), summary_ops_v2.always_record_summaries():
        if not model.run_eagerly:
            summary_ops_v2.graph(get_graph(), step=0)
        writable = (
            model._is_graph_network  # pylint: disable=protected-access
            or model.__class__.__name__ == 'Sequential')
        if writable:
            summary_ops_v2.keras_model('keras', model, step=0)
def on_begin(self, state):
    """Write graph/model summaries at start-up, then configure embeddings.

    If ``self.write_graph`` is set, the current backend graph plus a
    Keras-model summary for every graph network or ``Sequential`` model in
    ``self.network.model`` go to the ``'train'`` summary writer at step 0.
    ``state`` is not used here.
    """
    if self.write_graph:
        train_writer = self.summary_writers['train']
        with train_writer.as_default(), summary_ops_v2.always_record_summaries():
            summary_ops_v2.graph(backend.get_graph(), step=0)
            for tag, net in self.network.model.items():
                # Subclassed models are skipped; only graph networks and
                # Sequential models get a Keras-model summary.
                if net._is_graph_network or net.__class__.__name__ == 'Sequential':
                    summary_ops_v2.keras_model(tag, net, step=0)
    if self.embeddings_freq:
        self._configure_embeddings()
def write_keras_graph(self, model: Union[Model, Sequential], step: int = 0, name: str = "keras"):
    r"""Write a Keras model's graph to TensorBoard via ``self.summary_writer``.

    The TF graph is written only when the model does not run eagerly. The
    Keras-model summary is written only for graph networks and ``Sequential``
    models.

    Args:
        model: Keras model to record.
        step: Summary step for both the graph and the model summary.
        name: Tag for the Keras-model summary; converted with ``str``.

    Returns:
        ``self``, so calls can be chained.
    """
    with self.summary_writer.as_default(), summary_ops_v2.always_record_summaries():
        if not model.run_eagerly:
            summary_ops_v2.graph(keras.backend.get_graph(), step=step)
        can_summarize = model._is_graph_network or model.__class__.__name__ == 'Sequential'
        if can_summarize:
            summary_ops_v2.keras_model(name=str(name), data=model, step=step)
    return self
def write_model_to_tensorboard(self, model: Model):
    """
    Write the given model as a graph in tensorboard.

    Supports TensorFlow 2.4.x and 2.5.x, which expose the Keras-model summary
    through different APIs. On any other TensorFlow version nothing is written
    and a warning is emitted.

    :param model: The model to write to tensorboard.
    """
    # Match on major.minor so every patch release (2.4.0, 2.4.2, 2.5.1, ...)
    # takes the right branch; exact-string matching silently wrote nothing.
    major_minor = tf.__version__.split(".")[:2]
    with self._file_writer.as_default():
        if major_minor == ["2", "4"]:
            # NOTE(review): this branch tags the summary with model.name while
            # the 2.5 branch uses "keras" -- confirm whether that is intended.
            with summary_ops_v2.always_record_summaries():
                summary_ops_v2.keras_model(name=model.name, data=model, step=0)
        elif major_minor == ["2", "5"]:
            # In 2.5 the helper is imported from the Keras callbacks module.
            from tensorflow.python.keras.callbacks import keras_model_summary
            with summary_ops_v2.record_if(True):
                keras_model_summary("keras", model, step=0)
        else:
            import warnings
            warnings.warn(
                "write_model_to_tensorboard: unsupported TensorFlow version "
                f"{tf.__version__}; model graph was not written.",
                stacklevel=2,
            )
def testKerasModel_otherExceptions(self):
    """keras_model returns False and warns when model.to_json raises."""
    model = Sequential()
    with test.mock.patch.object(model, 'to_json') as mock_to_json:
        with test.mock.patch.object(logging, 'warn') as mock_log:
            mock_to_json.side_effect = Exception('oops')
            self.assertFalse(
                summary_ops.keras_model(name='my_name', data=model, step=1))
            # assertRegex replaces assertRegexpMatches, a deprecated alias
            # that was removed in Python 3.12.
            self.assertRegex(
                str(mock_log.call_args),
                'Model failed to serialize as JSON. Ignoring... oops')
def testKerasModel_otherExceptions(self):
    """A to_json failure makes keras_model return False and log a warning."""
    model = Sequential()
    patch_attr = test.mock.patch.object
    with patch_attr(model, 'to_json') as fake_to_json, patch_attr(logging, 'warn') as fake_warn:
        fake_to_json.side_effect = Exception('oops')
        wrote = summary_ops.keras_model(name='my_name', data=model, step=1)
        self.assertFalse(wrote)
        warning_text = str(fake_warn.call_args)
        self.assertRegex(
            warning_text, 'Model failed to serialize as JSON. Ignoring... oops')
def set_model(self, model):
    """Sets Keras model and writes graph if specified.

    Stores ``model.model`` on ``self.model``. Inside an eager context, any open
    writers are closed first. If ``self.write_graph`` is set, the graph is then
    recorded to the training-run writer: the backend graph for non-eager
    models, plus a ``'keras'`` model summary for graph networks or
    ``Sequential`` models. Embeddings are configured afterwards when
    ``self.embeddings_freq`` is set.
    """
    self.model = model.model
    with context.eager_mode():
        self._close_writers()
        if self.write_graph:
            train_writer = self._get_writer(self._train_run_name)
            with train_writer.as_default(), summary_ops_v2.always_record_summaries():
                inner = self.model
                if not inner.run_eagerly:
                    summary_ops_v2.graph(K.get_graph(), step=0)
                if (inner._is_graph_network  # pylint: disable=protected-access
                        or inner.__class__.__name__ == 'Sequential'):
                    summary_ops_v2.keras_model('keras', inner, step=0)
    if self.embeddings_freq:
        self._configure_embeddings()
def testKerasModel_subclass(self):
    """keras_model returns False and warns for a subclassed (non-graph) Model."""

    class SimpleSubclass(Model):

        def __init__(self):
            super(SimpleSubclass, self).__init__(name='subclass')
            self.dense = Dense(10, input_shape=(100, ))
            self.activation = Activation('relu', name='my_relu')

        def call(self, inputs):
            x = self.dense(inputs)
            return self.activation(x)

    model = SimpleSubclass()
    with test.mock.patch.object(logging, 'warn') as mock_log:
        self.assertFalse(
            summary_ops.keras_model(name='my_name', data=model, step=1))
        # assertRegex replaces assertRegexpMatches, a deprecated alias that
        # was removed in Python 3.12.
        self.assertRegex(str(mock_log.call_args),
                         'Model failed to serialize as JSON.')
def testKerasModel_subclass(self):
    """keras_model returns False and warns for a subclassed (non-graph) Model."""

    class SimpleSubclass(Model):

        def __init__(self):
            super(SimpleSubclass, self).__init__(name='subclass')
            self.dense = Dense(10, input_shape=(100,))
            self.activation = Activation('relu', name='my_relu')

        def call(self, inputs):
            x = self.dense(inputs)
            return self.activation(x)

    model = SimpleSubclass()
    with test.mock.patch.object(logging, 'warn') as mock_log:
        self.assertFalse(
            summary_ops.keras_model(name='my_name', data=model, step=1))
        # assertRegex replaces assertRegexpMatches, a deprecated alias that
        # was removed in Python 3.12.
        self.assertRegex(
            str(mock_log.call_args), 'Model failed to serialize as JSON.')