Exemplo n.º 1
0
  def load_layers(self, compile=True):  # pylint: disable=redefined-builtin
    """Recreates every node listed in the metadata and records the result.

    Nodes whose identifier is not the metric identifier are loaded in a first
    pass. Metric nodes are held back and loaded in a second pass, since models
    and layers are likely to create their metrics when initialized; deferring
    them avoids wasting time creating the same objects multiple times.

    Args:
      compile: When true, a `ValueError` raised while loading a metric is
        re-raised. When false, the error is suppressed and a warning is logged.

    Raises:
      ValueError: If loading a metric fails and `compile` is true.
    """
    def _restore(entry):
      # Store whatever `_load_layer` returns under the node's id.
      self.loaded_nodes[entry.node_id] = self._load_layer(
          entry.node_id, entry.identifier, entry.metadata)

    deferred_metrics = []
    for entry in self._metadata.values():
      if entry.identifier != constants.METRIC_IDENTIFIER:
        _restore(entry)
      else:
        deferred_metrics.append(entry)

    for entry in deferred_metrics:
      try:
        _restore(entry)
      except ValueError as e:
        # Metrics only matter if the model gets compiled afterwards. Errors
        # from custom metrics are tolerated when `compile=False` until custom
        # metrics are serialized properly (b/135550038).
        if not compile:
          logging.warning('Unable to restore custom metric. Please ensure that '
                          'the layer implements `get_config` and `from_config` '
                          'when saving. In addition, please use the '
                          '`custom_objects` arg when calling `load_model()`.')
          continue
        raise e
Exemplo n.º 2
0
def should_skip_serialization(layer):
    """Report whether extra objects/functions should not be serialized.

    Serialization is skipped (and a warning logged) when the layer is not built
    and it is not a `training_lib.Model` with a non-None
    `_saved_model_inputs_spec`.

    Args:
        layer: Object exposing a `built` attribute; models are additionally
            checked for `_saved_model_inputs_spec`.

    Returns:
        True if full serialization should be skipped, False otherwise.
    """
    # pylint: disable=protected-access
    is_model = isinstance(layer, training_lib.Model)
    has_input_spec = is_model and layer._saved_model_inputs_spec is not None
    # pylint: enable=protected-access

    if layer.built or has_input_spec:
        return False

    logging.warning(
        "Skipping full serialization of Keras layer {}, because "
        "it is not built.".format(layer))
    return True
Exemplo n.º 3
0
    def __init__(self, config, **kwargs):
        """Populate training, feature, encoder and head settings from ``config``.

        Args:
            config: Object exposing ``get(key[, default])`` (e.g. a dict) that
                holds the raw configuration values.
            **kwargs: Forwarded unchanged to the parent initializer.

        Raises:
            ValueError: If no encoders are configured, if a head has no
                matching feature target, or if no heads are configured.

        NOTE(review): ``config.get('features')`` is dereferenced via
        ``.targets`` without a None check, so a missing ``features`` entry
        surfaces as an ``AttributeError`` -- confirm callers always supply it.
        ``self.optimizer`` is only assigned when an ``optimizer`` entry is
        present.
        """
        super().__init__(**kwargs)

        self.reader = config.get('reader')

        # training hyperparameters
        self.batch_size = config.get('batch_size')
        if not self.batch_size:
            # A missing (or falsy, e.g. 0) batch size falls back to the default.
            self.batch_size = 10
            tf.logging.warn(
                "No 'batch_size' parameter provided. Using default value of %d",
                self.batch_size)
        self.buffer_size = config.get(
            'buffer_size',
            999999)  # in other words, read full dataset into memory by default
        self.batch_buffer_size = config.get(
            'batch_buffer_size',
            512)  # number of consecutive batches to shuffle
        self.dataset_caching = config.get('dataset_caching', True)

        self.max_steps = config.get('max_steps')
        self.steps_per_epoch = config.get('steps_per_epoch')

        self.max_epochs = config.get('max_epochs')
        self.patience_epochs = config.get('patience_epochs')
        self.checkpoint_epochs = config.get('checkpoint_epochs')

        self.exports_to_keep = config.get('exports_to_keep', 1)
        self.keep_checkpoints = config.get('checkpoints_to_keep', 1)

        # feature/input settings
        self.features = config.get('features')
        self.bucket_sizes = config.get('bucket_sizes')
        self.max_length = config.get('max_length', 100)
        self.duplicate_uncased = config.get('duplicate_uncased', 0)
        self.filter_key = config.get('filter_key')
        self.include = config.get('include')
        self.exclude = config.get('exclude')
        # Decay for exponential moving average (EMA) of parameters -- 0.998 or 0.999 is standard
        # "Temporal averaging for semi-supervised learning", Laine and Aila 2017. https://arxiv.org/abs/1610.02242
        self.ema_decay = config.get('ema_decay', 0)

        # encoder settings
        self.encoders = [
            EncoderConfig(val) for val in config.get('encoders', [])
        ]
        if not self.encoders:
            raise ValueError('Must have at least one encoder')

        # head configuration validation
        self.heads = [HeadConfig(head) for head in config.get('heads', [])]
        # Built once up front instead of once per target inside the loop.
        head_names = {head.name for head in self.heads}
        targets = {}
        for target in self.features.targets:
            if target.name not in head_names:
                # A target without a head is tolerated, only reported.
                logging.warning("Missing head configuration for target '%s'",
                                target.name)
            targets[target.name] = target
        for head in self.heads:
            # A head without a target is a hard error.
            if head.name not in targets:
                raise ValueError(
                    "Missing feature configuration for target '%s'" %
                    head.name)
        if not self.heads:
            raise ValueError(
                "Must have at least one head/target in configuration")

        self.metric = config.get('metric')
        if not self.metric:
            # Default to the first head's metric, labelled with the head name.
            metrics = [
                append_label(head.metric, head.name) for head in self.heads
            ]
            self.metric = metrics[0]

        optimizer_config = config.get('optimizer')
        if optimizer_config:
            self.optimizer = OptimizerConfig(optimizer_config)
Exemplo n.º 4
0
#!/usr/bin/env python
"""
Training GNN in HPC using Horovod
"""

import tensorflow as tf
from tensorflow.compat.v1 import logging
logging.info("TF Version:{}".format(tf.__version__))
# Horovod is an optional dependency: when it cannot be imported, the script
# records that in `no_horovod` and logs a warning instead of failing.
try:
    import horovod.tensorflow as hvd
    no_horovod = False
except ModuleNotFoundError:
    logging.warning("No horovod module, cannot perform distributed training")
    no_horovod = True


# tf.config.optimizer.set_jit(True)
# tf.debugging.set_log_device_placement(True)

import os
import sys
import argparse
import glob
import re
import time
import random
import functools
from types import SimpleNamespace

import numpy as np
import sklearn.metrics
Exemplo n.º 5
0
def load(path, compile=True, options=None):  # pylint: disable=redefined-builtin
  """Loads Keras objects from a SavedModel.

  Any Keras layer or model saved to the SavedModel will be loaded back
  as Keras objects. Other objects are loaded as regular trackable objects (same
  as `tf.saved_model.load`). If the SavedModel contains no Keras metadata
  nodes at all, the result of `tf.saved_model.load` is returned directly.

  When `compile` is true and the root object is a Keras `Model` with a saved
  training configuration, the model is compiled with new objects created from
  that configuration (the originally compiled objects are not reused).
  Optimizer slots are not restored; a warning is logged when the optimizer
  uses slots. If no training configuration is found, the model is left
  uncompiled and a warning is logged.

  Args:
    path: Path to SavedModel.
    compile: If true, compile the model after loading it.
    options: Optional `tf.saved_model.LoadOptions` object that specifies
      options for loading from SavedModel.

  Returns:
    Object loaded from SavedModel.

  Raises:
    IOError: If the Keras metadata file exists but cannot be parsed. The
      underlying decode error is attached as the exception's cause.
  """
  # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
  # TODO(kathywu): Add code to load from objects that contain all endpoints

  # Look for metadata file or parse the SavedModel
  metadata = saved_metadata_pb2.SavedMetadata()
  meta_graph_def = tf.__internal__.saved_model.parse_saved_model(
      path).meta_graphs[0]
  object_graph_def = meta_graph_def.object_graph_def
  path_to_metadata_pb = tf.io.gfile.join(path, constants.SAVED_METADATA_PATH)
  if tf.compat.v1.gfile.Exists(path_to_metadata_pb):
    try:
      with tf.io.gfile.GFile(path_to_metadata_pb, 'rb') as f:
        file_content = f.read()
      metadata.ParseFromString(file_content)
    except message.DecodeError as e:
      # Chain explicitly so the original decode failure is kept as the cause.
      raise IOError(
          f'Cannot parse keras metadata at path {path_to_metadata_pb}: '
          f'Received error: {e}') from e
  else:
    logging.warning('SavedModel saved prior to TF 2.5 detected when loading '
                    'Keras model. Please ensure that you are saving the model '
                    'with model.save() or tf.keras.models.save_model(), *NOT* '
                    'tf.saved_model.save(). To confirm, there should be a file '
                    'named "keras_metadata.pb" in the SavedModel directory.')
    # No metadata file: fill `metadata` from the object graph instead.
    _read_legacy_metadata(object_graph_def, metadata, path)

  if not metadata.nodes:
    # When there are no Keras objects, return the results from the core loader
    return tf.saved_model.load(path, options=options)

  metadata = _update_to_current_version(metadata)
  # Recreate layers and metrics using the info stored in the metadata.
  keras_loader = KerasObjectLoader(metadata, object_graph_def)
  keras_loader.load_layers(compile=compile)

  # Generate a dictionary of all loaded nodes.
  nodes_to_load = {'root': None}
  for node_id, loaded_node in keras_loader.loaded_nodes.items():
    nodes_to_load[keras_loader.get_path(node_id)] = loaded_node
  loaded = tf.__internal__.saved_model.load_partial(
      path, nodes_to_load, options=options)

  # Finalize the loaded layers and remove the extra tracked dependencies.
  keras_loader.finalize_objects()
  keras_loader.del_tracking()

  model = loaded['root']

  # pylint: disable=protected-access
  if isinstance(model, training_lib.Model) and compile:
    # TODO(kathywu): Use compiled objects from SavedModel, instead of
    # creating new objects from the training config.
    training_config = model._serialized_attributes['metadata'].get(
        'training_config', None)
    if training_config is not None:
      model.compile(**saving_utils.compile_args_from_training_config(
          training_config), from_serialized=True)
      saving_utils.try_build_compiled_arguments(model)
      if (isinstance(model.optimizer, optimizer_v2.OptimizerV2) and
          model.optimizer.get_slot_names()):
        logging.warning('Your optimizer uses slots. '
                        'Slots cannot be restored from saved_model, '
                        'as a result, your model is starting with  '
                        'a new initialized optimizer.')
    else:
      logging.warning('No training configuration found in save file, so the '
                      'model was *not* compiled. Compile it manually.')
  # pylint: enable=protected-access

  # Force variables and resources to initialize.
  if not tf.executing_eagerly():
    sess = backend.get_session()  # Variables are initialized by this call.
    sess.run(
        tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TABLE_INITIALIZERS))

  return model
Exemplo n.º 6
0
  def _add_children_recreated_from_config(self, obj, proto, node_id):
    """Recursively records objects recreated from config.

    Walks the children of `obj` (as listed in `proto.children`, plus any
    metrics found under the layer's `layer_metrics` node), assigns each child a
    path and a setter, and stores it in `self.loaded_nodes`. Each node id is
    processed at most once.

    Args:
      obj: Object corresponding to `node_id`; its dependencies are looked up by
        the local names stored in `proto`.
      proto: Node proto for `obj`; its `children` references are iterated.
      node_id: Id of the node; used to index `self._node_paths` and
        `self._metadata`.
    """
    # pylint: disable=protected-access
    # Guard against revisiting a node (and against infinite recursion).
    if node_id in self._traversed_nodes_from_config:
      return

    parent_path = self._node_paths[node_id]
    self._traversed_nodes_from_config.add(node_id)
    obj._maybe_initialize_trackable()
    # Unbuilt layers get a build attempt using the input shape stored in the
    # node's metadata (may be None if the key is absent).
    if isinstance(obj, base_layer.Layer) and not obj.built:
      metadata = json_utils.decode(self._metadata[node_id].metadata)
      self._try_build_layer(obj, node_id, metadata.get('build_input_shape'))

    # Create list of all possible children
    # Each entry is (child object, child node id, name relative to the parent).
    children = []
    # Look for direct children
    for reference in proto.children:
      obj_child = obj._lookup_dependency(reference.local_name)
      children.append((obj_child, reference.node_id, reference.local_name))

    # Add metrics that may have been added to the layer._metrics list.
    # This is stored in the SavedModel as layer.keras_api.layer_metrics in
    # SavedModels created after Tf 2.2.
    metric_list_node_id = self._search_for_child_node(
        node_id, [constants.KERAS_ATTR, 'layer_metrics'])
    if metric_list_node_id is not None and hasattr(obj, '_metrics'):
      # Match saved metric references to live metrics by name; references
      # without a same-named metric on `obj` are ignored.
      obj_metrics = {m.name: m for m in obj._metrics}
      for reference in self._proto.nodes[metric_list_node_id].children:
        metric = obj_metrics.get(reference.local_name)
        if metric is not None:
          metric_path = '{}.layer_metrics.{}'.format(constants.KERAS_ATTR,
                                                     reference.local_name)
          children.append((metric, reference.node_id, metric_path))

    for (obj_child, child_id, child_name) in children:
      child_proto = self._proto.nodes[child_id]

      # Children that are not Trackable are skipped entirely.
      if not isinstance(obj_child, tf.__internal__.tracking.Trackable):
        continue
      # Pick the setter: a registered one for registered identifiers,
      # `_revive_setter` for Keras object identifiers, else plain `setattr`.
      if (child_proto.user_object.identifier in
          tf.__internal__.saved_model.load.registered_identifiers()):
        setter = tf.__internal__.saved_model.load.get_setter(
            child_proto.user_object)
      elif obj_child._object_identifier in constants.KERAS_OBJECT_IDENTIFIERS:
        setter = _revive_setter
      else:
        setter = setattr
        # pylint: enable=protected-access

      # A node that is already recorded is never overwritten; only warn when
      # the recorded object differs from the one found here.
      if child_id in self.loaded_nodes:
        if self.loaded_nodes[child_id][0] is not obj_child:
          # This means that the same trackable object is referenced by two
          # different objects that were recreated from the config.
          logging.warning(
              'Looks like there is an object (perhaps variable or '
              'layer) that is shared between different layers/models. '
              'This may cause issues when restoring the variable '
              'values. Object: {}'.format(obj_child))
        continue

      # Overwrite variable names with the ones saved in the SavedModel.
      if (child_proto.WhichOneof('kind') == 'variable' and
          child_proto.variable.name):
        obj_child._handle_name = child_proto.variable.name + ':0'  # pylint: disable=protected-access

      # Trackable data structures get a no-op setter.
      if isinstance(obj_child, tf.__internal__.tracking.TrackableDataStructure):
        setter = lambda *args: None

      child_path = '{}.{}'.format(parent_path, child_name)
      self._node_paths[child_id] = child_path
      # Recurse into the child first; the child itself is recorded in
      # `loaded_nodes` only after its own descendants have been visited.
      self._add_children_recreated_from_config(
          obj_child, child_proto, child_id)
      self.loaded_nodes[child_id] = obj_child, setter