def sample_messy_power(variable, entropy):
  """Builds a random, deliberately unsimplified power expression in `variable`.

  Example of the kind of output produced: ((x**2)**3/x**4)**2/x**3.

  Args:
    variable: Expression placed at every non-exponent leaf of the tree.
    entropy: Remaining randomness budget; a value <= 0 stops the recursion.

  Returns:
    `variable` unchanged when `entropy <= 0`; otherwise an `ops.Pow`,
    `ops.Mul` or `ops.Div` node built recursively from `variable`.
  """
  if entropy <= 0:
    return variable

  kind = random.choice([1, 2, 3])

  if kind == 1:
    # Power node: spend at most 2 units of entropy on a signed exponent drawn
    # from number.integer_or_rational, and the rest on the base.
    spent = min(2, entropy)
    exponent = number.integer_or_rational(spent, signed=True)
    base = sample_messy_power(variable, entropy - spent)
    return ops.Pow(base, exponent)

  # Binary node: split the budget in half, but a half smaller than 1 is
  # rounded down to 0 (that side becomes a bare `variable`) and the other
  # side keeps everything. Which side gets which share is then randomized.
  half = entropy / 2
  share_a = 0 if half < 1 else half
  share_b = entropy - share_a
  if random.choice([False, True]):
    share_a, share_b = share_b, share_a

  lhs = sample_messy_power(variable, share_a)
  rhs = sample_messy_power(variable, share_b)
  combine = ops.Mul if kind == 2 else ops.Div
  return combine(lhs, rhs)
Example #2
0
  def testDescendants(self):
    """Verifies `descendants()` on a nested tree, a leaf, and a unary node."""
    leaves = [ops.Constant(value) for value in range(6)]

    # Build (0 + 1*2**3) / 4 - 5 bottom-up (leaves hold the values 0..5).
    power = ops.Pow(leaves[2], leaves[3])
    product = ops.Mul(leaves[1], power)
    total = ops.Add(leaves[0], product)
    quotient = ops.Div(total, leaves[4])
    tree = ops.Sub(quotient, leaves[5])

    flat = ops._flatten(tree.descendants())

    # Every constant must show up in the flattened descendants exactly once.
    for leaf in leaves:
      self.assertIn(leaf, flat)
      self.assertEqual(flat.count(leaf), 1)

    # A lone constant's descendants are just itself.
    self.assertEqual(leaves[0].descendants(), [leaves[0]])

    # A unary node reports itself together with its operand.
    operand = ops.Constant(3)
    negation = ops.Neg(operand)
    self.assertEqual(set(negation.descendants()), {operand, negation})
Example #3
0
 def div_by_sqrt_k():
   """Builds sqrt(k * base) / sqrt(k) for a freshly sampled unsigned integer k.

   Reads `entropy`, `base` and `max_power` from the enclosing scope. At most one
   unit of entropy is spent choosing k (sampled with min_abs=2). The remainder
   is split by `_surd_split_entropy_two` between the numerator and denominator
   surds. The denominator is resampled until its sympy value is nonzero.
   """
   k_budget = min(1, entropy)
   k = number.integer(k_budget, signed=False, min_abs=2)
   numer_budget, denom_budget = _surd_split_entropy_two(entropy - k_budget)
   numerator = _sample_surd(k * base, numer_budget, max_power, True)
   denominator = _sample_surd(k, denom_budget, max_power, True)
   while denominator.sympy() == 0:
     # Never emit a division by an expression that evaluates to zero.
     denominator = _sample_surd(k, denom_budget, max_power, True)
   return ops.Div(numerator, denominator)
Example #4
0
  def testDiv(self):
    """Checks Div's string form (including parenthesization) and sympy value."""
    f_of_x = sympy.Function('f')(sympy.Symbol('x'))
    # Each row: (numerator, denominator, expected str, expected sympy value).
    # A value of None means only the string form is checked.
    cases = [
        (2, 3, '2/3', sympy.Rational(2, 3)),
        (2, sympy.Rational(4, 5), '2/(4/5)', sympy.Rational(5, 2)),
        (1, ops.Div(2, 3), '1/(2/3)', sympy.Rational(3, 2)),
        (ops.Div(2, 3), 4, '(2/3)/4', sympy.Rational(1, 6)),
        (2, ops.Mul(3, 4), '2/(3*4)', None),
        (2, f_of_x, '2/f(x)', None),
    ]
    for numerator, denominator, want_str, want_value in cases:
      div = ops.Div(numerator, denominator)
      self.assertEqual(str(div), want_str)
      if want_value is not None:
        self.assertEqual(div.sympy(), want_value)