def _add_weight(self, name, initial_value, dtype=None):
  """Adds a weight to this loss scale.

  Args:
    name: Variable name.
    initial_value: The variable's initial value.
    dtype: The type of the variable.

  Returns:
    A variable.

  Raises:
    RuntimeError: If a weight with `name` has already been added.
  """
  variable = tf.Variable(
      initial_value=initial_value,
      name=name,
      dtype=dtype,
      trainable=False,
      synchronization=tf.VariableSynchronization.AUTO,
      # Set aggregation to NONE, as loss scaling variables should never be
      # aggregated.
      aggregation=tf.VariableAggregation.NONE)
  if tf.executing_eagerly():
    graph_key = None
  else:
    graph = tf.compat.v1.get_default_graph()
    graph_key = graph._graph_key  # pylint: disable=protected-access

  key = (name, graph_key)
  # Raise if a weight with this name was already added, as documented above.
  if key in self._weights:
    raise RuntimeError('Duplicate variables detected. {}'.format(key))
  self._weights[key] = variable
  self._handle_deferred_dependencies(name=name, trackable=variable)
  backend.track_variable(variable)
  return variable
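
# A minimal, self-contained sketch (an assumption, not taken from the code
# above) of how an `_add_weight`-style helper is typically used: a dynamic
# loss scale keeps two pieces of non-trainable state, the current scale
# factor and a counter of consecutive finite-gradient steps. The class name
# and the simplified `_add_weight` below are illustrative only.
import tensorflow as tf


class _SketchDynamicLossScale(tf.Module):
  """Illustrative loss scale that stores its state as non-trainable variables."""

  def __init__(self, initial_loss_scale=2.0 ** 15):
    super().__init__()
    self._weights = {}
    self._current_loss_scale = self._add_weight(
        'current_loss_scale', initial_loss_scale, tf.float32)
    self._good_steps = self._add_weight('good_steps', 0, tf.int64)

  def _add_weight(self, name, initial_value, dtype=None):
    # Simplified version of the method above: create a non-trainable,
    # non-aggregated variable and remember it by name.
    variable = tf.Variable(
        initial_value, name=name, dtype=dtype, trainable=False,
        aggregation=tf.VariableAggregation.NONE)
    self._weights[(name, None)] = variable
    return variable


# Usage sketch: multiply the loss by the current scale before backprop.
loss_scale = _SketchDynamicLossScale()
scaled_loss = tf.constant(0.5) * loss_scale._current_loss_scale
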
def _finalize_config_layers(layers):
  """Runs the final steps of loading Keras Layers from config."""
  for layer in layers:
    # It is assumed that layers define their unconditional losses after being
    # recreated from the config and built. The exceptions to this are
    # Functional and Sequential models, which only store conditional losses
    # (losses dependent on the inputs) in the config. Unconditional losses
    # like weight regularization must be revived from the SavedModel.
    if _is_graph_network(layer):
      _restore_layer_unconditional_losses(layer)

    # Some layers, like Dense, record their activation loss function in the
    # config. However, not all layers do this, so the activation loss may be
    # missing when restored from the config/hdf5.
    # TODO(kathywu): Investigate ways to improve the config to ensure
    # consistent loading behavior between HDF5 and SavedModel.
    _restore_layer_activation_loss(layer)

    # Restore metrics list.
    _restore_layer_metrics(layer)

    # Restore RNN layer states.
    if (isinstance(layer, recurrent.RNN) and
        layer.stateful and
        hasattr(_get_keras_attr(layer), 'states')):
      layer.states = getattr(_get_keras_attr(layer), 'states', None)
      for variable in tf.nest.flatten(layer.states):
        backend.track_variable(variable)
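
# Illustration (an assumption, not part of the loader above) of the
# conditional/unconditional loss distinction described in the comments:
# weight-regularization losses exist as soon as a layer is built, while
# activity-regularization losses only appear after the layer is called on
# inputs. This is why unconditional losses must be revived separately when a
# Functional or Sequential model is rebuilt from its config.
import tensorflow as tf

dense = tf.keras.layers.Dense(
    4,
    kernel_regularizer=tf.keras.regularizers.l2(0.01),
    activity_regularizer=tf.keras.regularizers.l2(0.01))

dense.build((None, 8))
print(len(dense.losses))  # 1: the unconditional kernel-regularization loss.

dense(tf.zeros((2, 8)))
print(len(dense.losses))  # 2: the conditional activity loss is now present too.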