def testPiecewiseConstantEdgeCases(self):
    """Checks piecewise_constant on mismatched, ref-typed and int64 inputs.

    Four cases are exercised in order:
      1. An int32 step variable with float boundaries: expects ValueError.
      2. A float step variable with values mixing a float and ints: expects
         ValueError.
      3. Graph mode only: a ref tensor as the step: expects no error.
      4. An int64 step variable with Python int boundaries: expects the value
         of the matching interval for steps 0 through 4.
    """
    tolerance = 1e-6

    # Case 1: int32 step combined with float boundaries.
    int32_step = resource_variable_ops.ResourceVariable(
        0, dtype=variables.dtypes.int32)
    boundaries = [-1.0, 1.0]
    values = [1, 2, 3]
    with self.assertRaises(ValueError):
      schedule = learning_rate_decay.piecewise_constant(
          int32_step, boundaries, values)
      # When executing eagerly the result is also invoked, so an error that
      # only surfaces at call time is still seen by assertRaises.
      if context.executing_eagerly():
        schedule()

    # Case 2: float step combined with values of mixed Python types.
    float_step = resource_variable_ops.ResourceVariable(0.0)
    boundaries = [-1.0, 1.0]
    values = [1.0, 2, 3]
    with self.assertRaises(ValueError):
      schedule = learning_rate_decay.piecewise_constant(
          float_step, boundaries, values)
      if context.executing_eagerly():
        schedule()

    # Case 3: test that ref types are valid (skipped when executing eagerly).
    if not context.executing_eagerly():
      ref_variable = variables.VariableV1(0.0)
      # float32_ref tensor should be accepted
      ref_tensor = ref_variable.op.outputs[0]
      boundaries = [1.0, 2.0]
      values = [1, 2, 3]
      learning_rate_decay.piecewise_constant(ref_tensor, boundaries, values)

    # Case 4: test casting boundaries from int32 to int64.
    int64_step = resource_variable_ops.ResourceVariable(
        0, dtype=variables.dtypes.int64)
    boundaries = [1, 2, 3]
    values = [0.4, 0.5, 0.6, 0.7]
    schedule = learning_rate_decay.piecewise_constant(
        int64_step, boundaries, values)

    self.evaluate(variables.global_variables_initializer())

    # The step variable starts at 0, where the first value is expected.
    self.assertAllClose(self.evaluate(schedule), 0.4, tolerance)

    # (step, expected value) pairs, applied in increasing step order.
    expectations = (
        (1, 0.4),
        (2, 0.5),
        (3, 0.6),
        (4, 0.7),
    )
    for new_step, expected in expectations:
      self.evaluate(int64_step.assign(new_step))
      self.assertAllClose(self.evaluate(schedule), expected, tolerance)
  def testPiecewiseConstant(self):
    """Checks the value piecewise_constant yields on and around boundaries.

    Uses boundaries [100, 110, 120] with values [1.0, 0.1, 0.01, 0.001] and
    moves the step variable from -999 up to 999, asserting the expected value
    after every assignment.
    """
    tolerance = 1e-6
    step = resource_variable_ops.ResourceVariable(-999)
    schedule = learning_rate_decay.piecewise_constant(
        step, [100, 110, 120], [1.0, 0.1, 0.01, 0.001])

    self.evaluate(variables.global_variables_initializer())

    # Before any assignment the step is -999, where the first value is
    # expected.
    self.assertAllClose(self.evaluate(schedule), 1.0, tolerance)

    # (step, expected value) pairs. A step equal to a boundary is expected to
    # still yield the value of the interval that ends at that boundary.
    expectations = (
        (100, 1.0),
        (105, 0.1),
        (110, 0.1),
        (120, 0.01),
        (999, 0.001),
    )
    for new_step, expected in expectations:
      self.evaluate(step.assign(new_step))
      self.assertAllClose(self.evaluate(schedule), expected, tolerance)