예제 #1
0
def periodic_target_update(target_variables,
                           source_variables,
                           update_period,
                           tau=1.0,
                           use_locking=False,
                           name="periodic_target_update"):
  """Returns an op to periodically update a list of target variables.

  The `update_target_variables` op is executed every `update_period`
  executions of the `periodic_target_update` op.

  The update rule is:
  `target_variable = (1 - tau) * target_variable + tau * source_variable`.

  Args:
    target_variables: a sequence (e.g. list or tuple) of the variables to be
      updated.
    source_variables: a sequence (e.g. list or tuple) of the variables used
      for the update.
    update_period: inverse frequency with which to apply the update.
    tau: weight used to gate the update. The permitted range is 0 < tau <= 1,
      with small tau representing an incremental update, and tau == 1
      representing a full update (that is, a straight copy).
    use_locking: use `tf.Variable.assign`'s locking option when assigning
      source variable values to target variables.
    name: sets the `name_scope` for this op.

  Returns:
    An op that periodically updates `target_variables` with `source_variables`.
  """
  # Materialize both inputs once. This lets callers pass tuples (or mix a
  # list with a tuple) without the `+` below raising a TypeError, and ensures
  # the deferred `update_op` sees exactly the variables given to `name_scope`.
  target_variables = list(target_variables)
  source_variables = list(source_variables)

  def update_op():
    # Deferred so `periodically` decides when the update subgraph runs.
    return update_target_variables(
        target_variables, source_variables, tau, use_locking)

  with tf.name_scope(name, values=target_variables + source_variables):
    return periodic_ops.periodically(update_op, update_period)
예제 #2
0
def periodic_target_update(target_variables,
                           source_variables,
                           update_period,
                           tau=1.0,
                           use_locking=False,
                           name="periodic_target_update"):
    """Returns an op to periodically update a list of target variables.

    The `update_target_variables` op is executed every `update_period`
    executions of the `periodic_target_update` op.

    The update rule is:
    `target_variable = (1 - tau) * target_variable + tau * source_variable`.

    Args:
        target_variables: a sequence (e.g. list or tuple) of the variables to
            be updated.
        source_variables: a sequence (e.g. list or tuple) of the variables
            used for the update.
        update_period: inverse frequency with which to apply the update.
        tau: weight used to gate the update. The permitted range is
            0 < tau <= 1, with small tau representing an incremental update,
            and tau == 1 representing a full update (that is, a straight copy).
        use_locking: use `tf.Variable.assign`'s locking option when assigning
            source variable values to target variables.
        name: sets the `name_scope` for this op.

    Returns:
        An op that periodically updates `target_variables` with
        `source_variables`.
    """
    # Materialize both inputs once. This lets callers pass tuples (or mix a
    # list with a tuple) without the `+` below raising a TypeError, and
    # ensures the deferred `update_op` sees exactly the variables given to
    # `name_scope`.
    target_variables = list(target_variables)
    source_variables = list(source_variables)

    def update_op():
        # Deferred so `periodically` decides when the update subgraph runs.
        return update_target_variables(target_variables, source_variables, tau,
                                       use_locking)

    with tf.name_scope(name, values=target_variables + source_variables):
        return periodic_ops.periodically(update_op, update_period)
예제 #3
0
    def testPeriodOne(self):
        """Checks that a period of 1 runs the body on every invocation."""
        counter = tf.Variable(0)

        def increment():
            return counter.assign_add(1).op

        update = periodic_ops.periodically(body=increment, period=1)

        with self.test_session() as session:
            session.run(tf.global_variables_initializer())
            for expected in range(1, 11):
                # Fetch the update op and the counter together; index 1 is
                # the counter's value.
                fetched = session.run([update, counter])
                self.assertEqual(expected, fetched[1])
예제 #4
0
  def testPeriodOne(self):
    """Checks that a period of 1 runs the body on every invocation."""
    counter = tf.Variable(0)

    def increment():
      return counter.assign_add(1).op

    update = periodic_ops.periodically(body=increment, period=1)

    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      for expected in range(1, 11):
        # Fetch the update op and the counter together; index 1 is the
        # counter's value.
        fetched = session.run([update, counter])
        self.assertEqual(expected, fetched[1])
예제 #5
0
    def testPeriodically(self):
        """Checks that the body fires exactly once per `period` invocations."""
        counter = tf.Variable(0)
        period = 3

        update = periodic_ops.periodically(
            body=lambda: counter.assign_add(1).op, period=period)

        # With period 3 the counter increments on the 1st, 4th, 7th and 10th
        # run of the periodic op.
        expected_values = [1, 1, 1, 2, 2, 2, 3, 3, 3, 4]

        with self.test_session() as session:
            session.run(tf.global_variables_initializer())
            for expected in expected_values:
                session.run(update)
                self.assertEqual(expected, session.run(counter))
예제 #6
0
  def testPeriodically(self):
    """Checks that the body fires exactly once per `period` invocations."""
    counter = tf.Variable(0)
    period = 3

    update = periodic_ops.periodically(
        body=lambda: counter.assign_add(1).op, period=period)

    # With period 3 the counter increments on the 1st, 4th, 7th and 10th run
    # of the periodic op.
    expected_values = [1, 1, 1, 2, 2, 2, 3, 3, 3, 4]

    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      for expected in expected_values:
        session.run(update)
        self.assertEqual(expected, session.run(counter))