Code example #1
File: keras.py Project: LiuCKind/tensorflow
def _save_first_checkpoint(keras_model, estimator, custom_objects,
                           keras_weights):
  """Save first checkpoint for the keras Estimator.

  Args:
    keras_model: an instance of compiled keras model.
    estimator: keras estimator.
    custom_objects: Dictionary for custom objects.
    keras_weights: A flat list of Numpy arrays for the weights of the given keras_model.

  Returns:
    The model_fn for a keras Estimator.
  """
  # Load weights and save to checkpoint if there is no checkpoint
  latest_path = saver_lib.latest_checkpoint(estimator.model_dir)
  if not latest_path:
    with ops.Graph().as_default():
      random_seed.set_random_seed(estimator.config.tf_random_seed)
      training_util.create_global_step()
      model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model,
                                     custom_objects)
      # save to checkpoint
      with session.Session(config=estimator._session_config) as sess:
        if keras_weights:
          model.set_weights(keras_weights)
        # Make update ops and initialize all variables.
        if not model.train_function:
          # pylint: disable=protected-access
          model._make_train_function()
          K._initialize_variables(sess)
          # pylint: enable=protected-access
        saver = saver_lib.Saver()
        saver.save(sess, os.path.join(estimator.model_dir, 'keras_model.ckpt'))
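
This helper is internal to the Keras-to-Estimator conversion and is not called directly; it runs when the estimator's model_dir contains no checkpoint yet. A minimal usage sketch that reaches this code path, assuming TensorFlow 1.x with tf.keras.estimator.model_to_estimator available (the model_dir path is made up for illustration):

import tensorflow as tf

# Build and compile a small Keras model; the converter needs a compiled
# model so it can clone and rebuild it in TRAIN mode.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation='relu', input_shape=(8,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

# Converting to an Estimator writes the initial 'keras_model.ckpt' (via
# _save_first_checkpoint and K._initialize_variables) when model_dir holds
# no checkpoint yet.
estimator = tf.keras.estimator.model_to_estimator(
    keras_model=model, model_dir='/tmp/keras_estimator_example')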
Code example #2
 def _assert_all_close(self, expected, actual):
     if not context.executing_eagerly():
         with self.cached_session() as sess:
             K._initialize_variables(sess)
             self.assertAllClose(expected, actual)
     else:
         self.assertAllClose(expected, actual)
Code example #3
  def test_bad_kernel_approximation(self, initializer, scale, exact_kernel_fn):
    """Approximation is bad when output dimension is small."""
    # Two distinct inputs.
    x = constant_op.constant([[1.0, -1.0, 0.5]])
    y = constant_op.constant([[-1.0, 1.0, 1.0]])

    small_output_dim = 10
    random_seed.set_random_seed(1234)
    # Initialize layer.
    rff_layer = kernel_layers.RandomFourierFeatures(
        output_dim=small_output_dim,
        kernel_initializer=initializer,
        scale=scale,
        name='random_fourier_features')

    # Apply layer to both inputs.
    output_x = math.sqrt(2.0 / small_output_dim) * rff_layer.apply(x)
    output_y = math.sqrt(2.0 / small_output_dim) * rff_layer.apply(y)

    # The inner product of the outputs (on inputs x and y) approximates the
    # exact value of the RBF kernel, but only poorly, since the output
    # dimension of the layer is small.
    exact_kernel_value = exact_kernel_fn(x, y)
    approx_kernel_value = kernelized_utils.inner_product(output_x, output_y)
    abs_error = math_ops.abs(exact_kernel_value - approx_kernel_value)
    if not context.executing_eagerly():
      with self.cached_session() as sess:
        keras_backend._initialize_variables(sess)
        abs_error_eval = sess.run([abs_error])
        self.assertGreater(abs_error_eval[0][0], 0.05)
        self.assertLess(abs_error_eval[0][0], 0.5)
    else:
      self.assertGreater(abs_error, 0.05)
      self.assertLess(abs_error, 0.5)
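
For reference, the test's exact_kernel_fn is a parameter, so the kernel being approximated depends on the test case; in the Gaussian parameterization, RandomFourierFeatures approximates K(x, y) = exp(-||x - y||^2 / (2 * scale^2)). A minimal sketch of the exact value for the two test inputs above, assuming the RBF case and a hypothetical scale of 1.0:

import numpy as np

x = np.array([1.0, -1.0, 0.5])
y = np.array([-1.0, 1.0, 1.0])
scale = 1.0  # hypothetical; the test parameterizes this value

# Exact RBF kernel value that the random Fourier features approximate.
exact_rbf = np.exp(-np.sum((x - y) ** 2) / (2.0 * scale ** 2))
print(exact_rbf)  # exp(-4.125) ≈ 0.016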
Code example #4
  def test_bad_kernel_approximation(self, initializer, scale, exact_kernel_fn):
    """Approximation is bad when output dimension is small."""
    # Two distinct inputs.
    x = constant_op.constant([[1.0, -1.0, 0.5]])
    y = constant_op.constant([[-1.0, 1.0, 1.0]])

    small_output_dim = 10
    random_seed.set_random_seed(1234)
    # Initialize layer.
    rff_layer = kernel_layers.RandomFourierFeatures(
        output_dim=small_output_dim,
        kernel_initializer=initializer,
        scale=scale,
        name='random_fourier_features')

    # Apply layer to both inputs.
    output_x = math.sqrt(2.0 / small_output_dim) * rff_layer(x)
    output_y = math.sqrt(2.0 / small_output_dim) * rff_layer(y)

    # The inner product of the outputs (on inputs x and y) approximates the
    # exact value of the RBF kernel, but only poorly, since the output
    # dimension of the layer is small.
    exact_kernel_value = exact_kernel_fn(x, y)
    approx_kernel_value = kernelized_utils.inner_product(output_x, output_y)
    abs_error = math_ops.abs(exact_kernel_value - approx_kernel_value)
    if not context.executing_eagerly():
      with self.cached_session() as sess:
        keras_backend._initialize_variables(sess)
        abs_error_eval = sess.run([abs_error])
        self.assertGreater(abs_error_eval[0][0], 0.05)
        self.assertLess(abs_error_eval[0][0], 0.5)
    else:
      self.assertGreater(abs_error, 0.05)
      self.assertLess(abs_error, 0.5)
Code example #5
 def _assert_all_close(self, expected, actual):
   if not context.executing_eagerly():
     with self.cached_session() as sess:
       K._initialize_variables(sess)
       self.assertAllClose(expected, actual)
   else:
     self.assertAllClose(expected, actual)
Code example #6
 def _assert_all_close(self, expected, actual, atol=0.001):
   if not context.executing_eagerly():
     with self.cached_session() as sess:
       keras_backend._initialize_variables(sess)
       self.assertAllClose(expected, actual, atol=atol)
   else:
     self.assertAllClose(expected, actual, atol=atol)
Code example #7
def _save_first_checkpoint(keras_model, estimator, custom_objects,
                           keras_weights):
    """Save first checkpoint for the keras Estimator.

  Args:
    keras_model: an instance of compiled keras model.
    estimator: keras estimator.
    custom_objects: Dictionary for custom objects.
    keras_weights: A flat list of Numpy arrays for the weights of the given keras_model.

  Returns:
    The model_fn for a keras Estimator.
  """
    # Load weights and save to checkpoint if there is no checkpoint
    latest_path = saver_lib.latest_checkpoint(estimator.model_dir)
    if not latest_path:
        with ops.Graph().as_default():
            random_seed.set_random_seed(estimator.config.tf_random_seed)
            training_util.create_global_step()
            model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN,
                                           keras_model, custom_objects)
            # save to checkpoint
            with session.Session(config=estimator._session_config) as sess:
                if keras_weights:
                    model.set_weights(keras_weights)
                # Make update ops and initialize all variables.
                if not model.train_function:
                    # pylint: disable=protected-access
                    model._make_train_function()
                    K._initialize_variables(sess)
                    # pylint: enable=protected-access
                saver = saver_lib.Saver()
                saver.save(
                    sess, os.path.join(estimator.model_dir,
                                       'keras_model.ckpt'))
Code example #8
 def _assert_all_close(self, expected, actual, atol=0.001):
     if not context.executing_eagerly():
         with self.cached_session() as sess:
             keras_backend._initialize_variables(sess)
             self.assertAllClose(expected, actual, atol=atol)
     else:
         self.assertAllClose(expected, actual, atol=atol)
Code example #9
def _save_first_checkpoint(keras_model, custom_objects, config,
                           save_object_ckpt):
    """Save first checkpoint for the keras Estimator.

  Args:
    keras_model: an instance of compiled keras model.
    custom_objects: Dictionary for custom objects.
    config: Estimator config.
    save_object_ckpt: Whether to save an object-based checkpoint.

  Returns:
    The path where keras model checkpoint is saved.
  """
    # save checkpoint into subdirectory to allow warm start
    keras_model_dir = os.path.join(config.model_dir, 'keras')
    # Load weights and save to checkpoint if there is no checkpoint
    latest_path = tf.train.latest_checkpoint(keras_model_dir)
    if not latest_path:
        keras_weights = None
        if _any_weight_initialized(keras_model):
            keras_weights = keras_model.get_weights()
        if not tf.compat.v1.gfile.IsDirectory(keras_model_dir):
            tf.compat.v1.gfile.MakeDirs(keras_model_dir)
        with tf.Graph().as_default():
            tf.compat.v1.random.set_random_seed(config.tf_random_seed)
            tf.compat.v1.train.create_global_step()
            model = _clone_and_build_model(ModeKeys.TRAIN, keras_model,
                                           custom_objects)

            # Init the train_function outside of the session context. This is
            # because the train function updates the graph by adding backprop
            # parts, which may try to update nodes in the forward graph; doing
            # so within the same session would fail.
            # Always create the train_function here since the model is just cloned.
            # See https://github.com/tensorflow/tensorflow/issues/27750 for details.
            model._make_train_function()  # pylint: disable=protected-access

            # save to checkpoint
            with tf.compat.v1.Session(config=config.session_config) as sess:
                if keras_weights:
                    model.set_weights(keras_weights)
                # model._make_train_function() will potentially create the optimizer
                # variable, which will require another variable initialization.
                K._initialize_variables(sess)  # pylint: disable=protected-access

                if save_object_ckpt:
                    model._track_trackable(  # pylint: disable=protected-access
                        tf.compat.v1.train.get_global_step(),
                        'estimator_global_step')
                    latest_path = os.path.join(keras_model_dir,
                                               'keras_model.ckpt')
                    model.save_weights(latest_path)
                else:
                    saver = tf.compat.v1.train.Saver()
                    latest_path = os.path.join(keras_model_dir,
                                               'keras_model.ckpt')
                    saver.save(sess, latest_path)

    return latest_path
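
The checkpoint is written under a 'keras' subdirectory specifically so that the Estimator can warm-start from it. A minimal consumption sketch, assuming tf.estimator is available; my_model_fn and the model_dir path are hypothetical placeholders:

import tensorflow as tf

# Hypothetical: warm-start a custom Estimator from the checkpoint that
# _save_first_checkpoint() wrote under <model_dir>/keras/keras_model.ckpt.
warm_start = tf.estimator.WarmStartSettings(
    ckpt_to_initialize_from='/tmp/model_dir/keras/keras_model.ckpt')

estimator = tf.estimator.Estimator(
    model_fn=my_model_fn,  # hypothetical model_fn
    model_dir='/tmp/model_dir',
    warm_start_from=warm_start)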
Code example #10
def init_restore_or_wait_for_variables():
  """Initialize or restore variables or wait for variables to be initialized."""
  session = K._get_session()  # pylint: disable=protected-access
  worker_context = dc_context.get_current_worker_context()
  if not worker_context or worker_context.experimental_should_init:
    # TODO(yuefengz): if checkpoints exist, restore from checkpoint.
    K._initialize_variables(session)  # pylint: disable=protected-access
  else:
    _wait_for_variable_initialization(session)
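
The _wait_for_variable_initialization helper referenced above is not shown in these examples. A minimal sketch of what such a polling loop could look like, assuming TF 1.x-style sessions; this is an illustration, not the library's actual implementation:

import time
import tensorflow as tf

def wait_for_variable_initialization_sketch(session, timeout_secs=300):
    """Poll until all global variables report as initialized (illustrative only)."""
    uninitialized_op = tf.compat.v1.report_uninitialized_variables()
    deadline = time.time() + timeout_secs
    # report_uninitialized_variables() returns the names of variables that
    # still need initialization; an empty result means we are done waiting.
    while session.run(uninitialized_op).size > 0:
        if time.time() > deadline:
            raise RuntimeError('Timed out waiting for variable initialization.')
        time.sleep(1)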
Code example #11
def init_restore_or_wait_for_variables():
  """Initialize or restore variables or wait for variables to be initialized."""
  session = K._get_session()  # pylint: disable=protected-access
  if not multi_worker_util.has_worker_context(
  ) or multi_worker_util.should_load_checkpoint():
    # TODO(yuefengz): if checkpoints exist, restore from checkpoint.
    K._initialize_variables(session)  # pylint: disable=protected-access
  else:
    _wait_for_variable_initialization(session)
Code example #12
def init_restore_or_wait_for_variables():
    """Initialize or restore variables or wait for variables to be initialized."""
    session = K._get_session()  # pylint: disable=protected-access
    worker_context = dc_context.get_current_worker_context()
    if not worker_context or worker_context.should_init:
        # TODO(yuefengz): if checkpoints exist, restore from checkpoint.
        K._initialize_variables(session)  # pylint: disable=protected-access
    else:
        _wait_for_variable_initialization(session)
Code example #13
def _save_first_checkpoint(keras_model, custom_objects, config,
                           save_object_ckpt):
    """Save first checkpoint for the keras Estimator.

  Args:
    keras_model: an instance of compiled keras model.
    custom_objects: Dictionary for custom objects.
    config: Estimator config.
    save_object_ckpt: Whether to save an object-based checkpoint.

  Returns:
    The path where keras model checkpoint is saved.
  """
    # save checkpoint into subdirectory to allow warm start
    keras_model_dir = os.path.join(config.model_dir, 'keras')
    # Load weights and save to checkpoint if there is no checkpoint
    latest_path = checkpoint_management.latest_checkpoint(keras_model_dir)
    if not latest_path:
        keras_weights = None
        if _any_weight_initialized(keras_model):
            keras_weights = keras_model.get_weights()
        if not gfile.IsDirectory(keras_model_dir):
            gfile.MakeDirs(keras_model_dir)
        with ops.Graph().as_default():
            random_seed.set_random_seed(config.tf_random_seed)
            training_util.create_global_step()
            model = _clone_and_build_model(ModeKeys.TRAIN, keras_model,
                                           custom_objects)
            # save to checkpoint
            with session.Session(config=config.session_config) as sess:
                if keras_weights:
                    model.set_weights(keras_weights)
                # Make update ops and initialize all variables.
                if not model.train_function:
                    # pylint: disable=protected-access
                    model._make_train_function()
                    K._initialize_variables(sess)
                    # pylint: enable=protected-access

                if save_object_ckpt:
                    model._track_trackable(  # pylint: disable=protected-access
                        training_util.get_global_step(),
                        'estimator_global_step')
                    latest_path = os.path.join(keras_model_dir,
                                               'keras_model.ckpt')
                    model.save_weights(latest_path)
                else:
                    saver = saver_lib.Saver()
                    latest_path = os.path.join(keras_model_dir,
                                               'keras_model.ckpt')
                    saver.save(sess, latest_path)

    return latest_path
Code example #14
def _save_first_checkpoint(keras_model, custom_objects, config):
    """Save first checkpoint for the keras Estimator.

  Args:
    keras_model: an instance of compiled keras model.
    custom_objects: Dictionary for custom objects.
    config: Estimator config.

  Returns:
    The path where keras model checkpoint is saved.
  """
    # save checkpoint into subdirectory to allow warm start
    keras_model_dir = os.path.join(config.model_dir, 'keras')
    # Load weights and save to checkpoint if there is no checkpoint
    latest_path = checkpoint_management.latest_checkpoint(keras_model_dir)
    if not latest_path:
        keras_weights = None
        if _any_weight_initialized(keras_model):
            keras_weights = keras_model.get_weights()
        if not gfile.IsDirectory(keras_model_dir):
            gfile.MakeDirs(keras_model_dir)
        with ops.Graph().as_default():
            random_seed.set_random_seed(config.tf_random_seed)
            training_util.create_global_step()
            model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN,
                                           keras_model, custom_objects)
            # save to checkpoint
            with session.Session(config=config.session_config) as sess:
                if keras_weights:
                    model.set_weights(keras_weights)
                # Make update ops and initialize all variables.
                if not model.train_function:
                    # pylint: disable=protected-access
                    model._make_train_function()
                    # We are using the global variables collection here because
                    # the estimator runs eager-mode code under the
                    # context.graph_mode() context manager. When we try to get
                    # all the TF optimizer variables using optimizer.variables(),
                    # only variables that belong to the current graph are
                    # returned, and that check (variable.op.graph is
                    # current_graph) errors because the context is graph mode
                    # while the variables are eager.
                    # TODO(psv): investigate this and see if we can remove the
                    # usage of the collection here.
                    K._initialize_variables(
                        sess, variables_module.global_variables())
                    # pylint: enable=protected-access
                saver = saver_lib.Saver()
                latest_path = os.path.join(keras_model_dir, 'keras_model.ckpt')
                saver.save(sess, latest_path)
    return latest_path
Code example #15
File: keras.py Project: AnishShah/tensorflow
def _save_first_checkpoint(keras_model, custom_objects, config):
  """Save first checkpoint for the keras Estimator.

  Args:
    keras_model: an instance of compiled keras model.
    custom_objects: Dictionary for custom objects.
    config: Estimator config.

  Returns:
    The path where keras model checkpoint is saved.
  """
  # save checkpoint into subdirectory to allow warm start
  keras_model_dir = os.path.join(config.model_dir, 'keras')
  # Load weights and save to checkpoint if there is no checkpoint
  latest_path = checkpoint_management.latest_checkpoint(keras_model_dir)
  if not latest_path:
    keras_weights = None
    if _any_weight_initialized(keras_model):
      keras_weights = keras_model.get_weights()
    if not gfile.IsDirectory(keras_model_dir):
      gfile.MakeDirs(keras_model_dir)
    with ops.Graph().as_default():
      random_seed.set_random_seed(config.tf_random_seed)
      training_util.create_global_step()
      model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model,
                                     custom_objects)
      # save to checkpoint
      with session.Session(config=config.session_config) as sess:
        if keras_weights:
          model.set_weights(keras_weights)
        # Make update ops and initialize all variables.
        if not model.train_function:
          # pylint: disable=protected-access
          model._make_train_function()
          K._initialize_variables(sess)
          # pylint: enable=protected-access
        saver = saver_lib.Saver()
        latest_path = os.path.join(keras_model_dir, 'keras_model.ckpt')
        saver.save(sess, latest_path)
  return latest_path
Code example #16
    def _build(self, shape):
        """Initialize TP, FP, TN, and FN tensors, given the shape of the data."""
        if self.multi_label:
            if shape.ndims != 2:
                raise ValueError(
                    '`y_true` must have rank=2 when `multi_label` is '
                    'True. Found rank %s.' % shape.ndims)
            variable_shape = tensor_shape.TensorShape(
                [tensor_shape.Dimension(self.num_thresholds), shape[1]])
        else:
            variable_shape = tensor_shape.TensorShape(
                [tensor_shape.Dimension(self.num_thresholds)])

        # Create metric variables
        self.true_positives = self.add_weight(
            'true_positives',
            shape=variable_shape,
            initializer=init_ops.zeros_initializer)
        self.true_negatives = self.add_weight(
            'true_negatives',
            shape=variable_shape,
            initializer=init_ops.zeros_initializer)
        self.false_positives = self.add_weight(
            'false_positives',
            shape=variable_shape,
            initializer=init_ops.zeros_initializer)
        self.false_negatives = self.add_weight(
            'false_negatives',
            shape=variable_shape,
            initializer=init_ops.zeros_initializer)

        if self.multi_label:
            with ops.init_scope():
                # This should only be necessary for handling v1 behavior. In v2, AUC
                # should be initialized outside of any tf.functions, and therefore in
                # eager mode.
                if not context.executing_eagerly():
                    K._initialize_variables(K._get_session())  # pylint: disable=protected-access

        self._built = True
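
The K._initialize_variables(K._get_session()) call above exists so the freshly created confusion-matrix variables are usable when the metric is built inside a TF1-style graph. A minimal usage sketch that exercises this path, assuming tf.compat.v1 graph mode and the public tf.keras.metrics.AUC API:

import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # run the v1 graph-mode code path

auc = tf.keras.metrics.AUC(num_thresholds=3, multi_label=True)
y_true = tf.constant([[0.0, 1.0], [1.0, 0.0]])
y_pred = tf.constant([[0.2, 0.8], [0.6, 0.4]])

# The first update_state() call triggers _build(), which initializes the
# true/false positive/negative variables through the shared Keras session.
update_op = auc.update_state(y_true, y_pred)

sess = tf.compat.v1.keras.backend.get_session()
sess.run(update_op)
print(sess.run(auc.result()))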
Code example #17
def init_restore_or_wait_for_variables():
    """Initialize or restore variables or wait for variables to be initialized."""
    backend._initialize_variables(backend._get_session())  # pylint: disable=protected-access