Ejemplo n.º 1
0
  def losses(self):
    """Retrieve the network's losses.

    Will only include losses that are either
    unconditional, or conditional on inputs to this model
    (e.g. will not include losses that depend on tensors
    that aren't inputs to this model).

    In eager mode the losses of every sub-layer are returned without
    filtering.

    Returns:
        A list of loss tensors without duplicates, ordered by first
        occurrence (conditional losses, then unconditional losses, then the
        network's own `_losses`).
    """
    losses = []
    for layer in self.layers:
      losses += layer.losses
    if context.in_eager_mode():
      return losses

    # Copy instead of aliasing `self.inputs`: the loop below extends
    # `relevant_inputs` in place, which would otherwise mutate the model's
    # own input list (when `self.inputs` returns a stored list).
    relevant_inputs = list(self.inputs or [])
    # Inbound node 0 presumably corresponds to `self.inputs` (hence the start
    # at 1); add the inputs of every later call of this network.
    for i in range(1, len(self._inbound_nodes)):
      inputs = self.get_input_at(i)
      if isinstance(inputs, list):
        relevant_inputs += inputs
      else:
        relevant_inputs.append(inputs)
    reachable = layers_util.get_reachable_from_inputs(relevant_inputs, losses)
    relevant_conditional_losses = [x for x in losses if x in reachable]
    unconditional_losses = [
        x for x in losses if x._unconditional_loss]  # pylint: disable=protected-access
    # The same loss can appear several times (e.g. shared layers), so
    # de-duplicate. Keep first-seen order rather than `list(set(...))`, whose
    # iteration order is not guaranteed to be stable.
    seen = set()
    deduped = []
    for loss in (relevant_conditional_losses + unconditional_losses +
                 self._losses):
      if loss not in seen:
        seen.add(loss)
        deduped.append(loss)
    return deduped
Ejemplo n.º 2
0
  def losses(self):
    """Retrieve the network's losses.

    Will only include losses that are either
    unconditional, or conditional on inputs to this model
    (e.g. will not include losses that depend on tensors
    that aren't inputs to this model).

    In eager mode the losses of every sub-layer are returned without
    filtering.

    Returns:
        A list of loss tensors without duplicates, ordered by first
        occurrence (conditional losses, then unconditional losses, then the
        network's own `_losses`).
    """
    # Both modes start from the losses of every sub-layer.
    losses = []
    for layer in self.layers:
      losses += layer.losses
    if context.in_eager_mode():
      return losses

    # Gather the inputs of every call of this network (every inbound node).
    relevant_inputs = []
    for i in range(len(self._inbound_nodes)):
      inputs = self.get_input_at(i)
      if isinstance(inputs, list):
        relevant_inputs += inputs
      else:
        relevant_inputs.append(inputs)
    reachable = layers_util.get_reachable_from_inputs(relevant_inputs, losses)
    relevant_conditional_losses = [x for x in losses if x in reachable]
    unconditional_losses = [
        x for x in losses if x._unconditional_loss]  # pylint: disable=protected-access
    # De-duplicate while keeping first-seen order; `list(set(...))` would
    # return the losses in an order that is not guaranteed to be stable.
    seen = set()
    deduped = []
    for loss in (relevant_conditional_losses + unconditional_losses +
                 self._losses):
      if loss not in seen:
        seen.add(loss)
        deduped.append(loss)
    return deduped
Ejemplo n.º 3
0
  def testGetReachableFromInputs(self):
    # Graph under test:
    #   t1 = p1 + p2    t2 = p2 * 2    t3 = p3 + 1
    #   t4 = t1 + t2    t5 = t3 * p1
    # Each case lists the starting tensors and everything reachable from them.
    p1, p2, p3 = [
        array_ops.placeholder(shape=None, dtype='float32') for _ in range(3)]
    t1 = p1 + p2
    t2 = p2 * 2
    t3 = p3 + 1
    t4 = t1 + t2
    t5 = t3 * p1

    cases = [
        ({p1, t1, t4, t5}, [p1]),
        ({p1, p2, t1, t2, t4, t5}, [p1, p2]),
        ({p3, t3, t5}, [p3]),
        ({t3, t5}, [t3]),
    ]
    for expected, start in cases:
      self.assertEqual(expected, utils.get_reachable_from_inputs(start))
Ejemplo n.º 4
0
    def testGetReachableFromInputs(self):
        # Three placeholders feed a small graph; verify the set of tensors
        # reachable from several different starting points.
        ph_a, ph_b, ph_c = [
            array_ops.placeholder(shape=None, dtype='float32')
            for _ in range(3)]
        sum_ab = ph_a + ph_b
        double_b = ph_b * 2
        inc_c = ph_c + 1
        combined = sum_ab + double_b
        mixed = inc_c * ph_a

        expectations = (
            ([ph_a], {ph_a, sum_ab, combined, mixed}),
            ([ph_a, ph_b], {ph_a, ph_b, sum_ab, double_b, combined, mixed}),
            ([ph_c], {ph_c, inc_c, mixed}),
            ([inc_c], {inc_c, mixed}),
        )
        for sources, reachable in expectations:
            self.assertEqual(reachable,
                             utils.get_reachable_from_inputs(sources))
Ejemplo n.º 5
0
  def updates(self):
    """Retrieve the network's updates.

    Will only include updates that are either
    unconditional, or conditional on inputs to this model
    (e.g. will not include updates that were created by layers of this model
    outside of the model).

    Effectively, `network.updates` behaves like `layer.updates`.

    Concrete example:

    ```python
      bn = keras.layers.BatchNormalization()
      x1 = keras.layers.Input(shape=(10,))
      _ = bn(x1)  # This creates 2 updates.

      x2 = keras.layers.Input(shape=(10,))
      y2 = bn(x2)  # This creates 2 more updates.

      # The BN layer has now 4 updates.
      self.assertEqual(len(bn.updates), 4)

      # Let's create a model from x2 to y2.
      model = keras.models.Model(x2, y2)

      # The model does not list all updates from its underlying layers,
      # but only the updates that are relevant to it. Updates created by layers
      # outside of the model are discarded.
      self.assertEqual(len(model.updates), 2)

      # If you keep calling the model, you append to its updates, just like
      # what happens for a layer.
      x3 = keras.layers.Input(shape=(10,))
      y3 = model(x3)
      self.assertEqual(len(model.updates), 4)

      # But if you call the inner BN layer independently, you don't affect
      # the model's updates.
      x4 = keras.layers.Input(shape=(10,))
      _ = bn(x4)
      self.assertEqual(len(model.updates), 4)
    ```

    Returns:
        A list of update ops without duplicates, ordered by first occurrence.
        Empty in eager mode, or when the network is neither trainable nor
        stateful.
    """
    if context.in_eager_mode():
      return []

    if not self.trainable and not self.stateful:
      return []

    updates = []
    for layer in self.layers:
      updates += layer.updates

    # `updates` might contain irrelevant updates, so it needs to be filtered
    # with respect to inputs the model has been called on.
    # Copy instead of aliasing `self.inputs`: the loop below extends
    # `relevant_inputs` in place, which would otherwise mutate the model's
    # own input list (when `self.inputs` returns a stored list).
    relevant_inputs = list(self.inputs or [])
    # Inbound node 0 presumably corresponds to `self.inputs` (hence the start
    # at 1); add the inputs of every later call of this network.
    for i in range(1, len(self._inbound_nodes)):
      inputs = self.get_input_at(i)
      if isinstance(inputs, list):
        relevant_inputs += inputs
      else:
        relevant_inputs.append(inputs)
    reachable = layers_util.get_reachable_from_inputs(relevant_inputs, updates)
    relevant_conditional_updates = [x for x in updates if x in reachable]
    unconditional_updates = [
        x for x in updates if x._unconditional_update]  # pylint: disable=protected-access
    # A layer could be used multiple times in a nested structure,
    # so the updates list must be de-duped. Keep first-seen order rather than
    # `list(set(...))`, whose iteration order is not guaranteed to be stable.
    seen = set()
    deduped = []
    for update in (relevant_conditional_updates + unconditional_updates +
                   self._updates):
      if update not in seen:
        seen.add(update)
        deduped.append(update)
    return deduped
Ejemplo n.º 6
0
  def updates(self):
    """Retrieve the network's updates.

    Will only include updates that are either
    unconditional, or conditional on inputs to this model
    (e.g. will not include updates that were created by layers of this model
    outside of the model).

    Effectively, `network.updates` behaves like `layer.updates`.

    Concrete example:

    ```python
      bn = keras.layers.BatchNormalization()
      x1 = keras.layers.Input(shape=(10,))
      _ = bn(x1)  # This creates 2 updates.

      x2 = keras.layers.Input(shape=(10,))
      y2 = bn(x2)  # This creates 2 more updates.

      # The BN layer has now 4 updates.
      self.assertEqual(len(bn.updates), 4)

      # Let's create a model from x2 to y2.
      model = keras.models.Model(x2, y2)

      # The model does not list all updates from its underlying layers,
      # but only the updates that are relevant to it. Updates created by layers
      # outside of the model are discarded.
      self.assertEqual(len(model.updates), 2)

      # If you keep calling the model, you append to its updates, just like
      # what happens for a layer.
      x3 = keras.layers.Input(shape=(10,))
      y3 = model(x3)
      self.assertEqual(len(model.updates), 4)

      # But if you call the inner BN layer independently, you don't affect
      # the model's updates.
      x4 = keras.layers.Input(shape=(10,))
      _ = bn(x4)
      self.assertEqual(len(model.updates), 4)
    ```

    Returns:
        A list of update ops without duplicates, ordered by first occurrence.
        Empty when the network is neither trainable nor stateful.
    """
    if not self.trainable and not self.stateful:
      return []

    updates = []
    for layer in self.layers:
      updates += layer.updates

    # `updates` might contain irrelevant updates, so it needs to be filtered
    # with respect to inputs the model has been called on (every inbound
    # node).
    relevant_inputs = []
    for i in range(len(self._inbound_nodes)):
      inputs = self.get_input_at(i)
      if isinstance(inputs, list):
        relevant_inputs += inputs
      else:
        relevant_inputs.append(inputs)
    reachable = layers_util.get_reachable_from_inputs(relevant_inputs, updates)
    relevant_conditional_updates = [x for x in updates if x in reachable]
    unconditional_updates = [
        x for x in updates if x._unconditional_update]  # pylint: disable=protected-access
    # A layer could be used multiple times in a nested structure,
    # so the updates list must be de-duped. Keep first-seen order rather than
    # `list(set(...))`, whose iteration order is not guaranteed to be stable.
    seen = set()
    deduped = []
    for update in (relevant_conditional_updates + unconditional_updates +
                   self._updates):
      if update not in seen:
        seen.add(update)
        deduped.append(update)
    return deduped