def testAlgebraInverse(self):
    dataset_objects = algorithmic_math.math_dataset_init(26)
    counter = 0
    for d in algorithmic_math.algebra_inverse(26, 0, 3, 10):
      counter += 1
      decoded_input = dataset_objects.int_decoder(d["inputs"])
      solve_var, expression = decoded_input.split(":")
      lhs, rhs = expression.split("=")

      # Solve for the solve-var.
      result = sympy.solve("%s-(%s)" % (lhs, rhs), solve_var)
      target_expression = dataset_objects.int_decoder(d["targets"])

      # Check that the target and sympy's solutions are equivalent.
      self.assertEqual(
          0, sympy.simplify(str(result[0]) + "-(%s)" % target_expression))
    self.assertEqual(counter, 10)
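For reference, a minimal standalone sketch of the decode-and-verify round trip the test above exercises, using a single generated sample; it assumes the tensor2tensor import path for algorithmic_math and that each sample carries integer-encoded "inputs" and "targets" as in the test:

import sympy
from tensor2tensor.data_generators import algorithmic_math

# Sketch only: decode one generated sample and check it with sympy,
# mirroring the assertions in testAlgebraInverse above.
dataset_objects = algorithmic_math.math_dataset_init(26)
sample = next(algorithmic_math.algebra_inverse(26, 0, 3, 1))

solve_var, expression = dataset_objects.int_decoder(sample["inputs"]).split(":")
lhs, rhs = expression.split("=")
solution = sympy.solve("%s-(%s)" % (lhs, rhs), solve_var)[0]

target = dataset_objects.int_decoder(sample["targets"])
print(sympy.simplify("(%s)-(%s)" % (solution, target)))  # expect 0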
Example #3
flags.DEFINE_integer("task_id_end", -1, "For distributed data generation.")
flags.DEFINE_integer(
    "num_concurrent_processes", None,
    "Applies only to problems for which multiprocess_generate=True.")
flags.DEFINE_string(
    "t2t_usr_dir", "", "Path to a Python module that will be imported. The "
    "__init__.py file should include the necessary imports. "
    "The imported files should contain registrations, "
    "e.g. @registry.register_problem calls, that will then be "
    "available to t2t-datagen.")

# Mapping from problems that we can generate data for to their generators.
# pylint: disable=g-long-lambda
_SUPPORTED_PROBLEM_GENERATORS = {
    "algorithmic_algebra_inverse":
    (lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
     lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000),
     lambda: None),  # test set
    "parsing_english_ptb8k":
    (lambda: wsj_parsing.parsing_token_generator(FLAGS.data_dir, FLAGS.tmp_dir,
                                                 True, 2**13, 2**9),
     lambda: wsj_parsing.parsing_token_generator(FLAGS.data_dir, FLAGS.tmp_dir,
                                                 False, 2**13, 2**9),
     lambda: None),  # test set
    "parsing_english_ptb16k":
    (lambda: wsj_parsing.parsing_token_generator(FLAGS.data_dir, FLAGS.tmp_dir,
                                                 True, 2**14, 2**9),
     lambda: wsj_parsing.parsing_token_generator(FLAGS.data_dir, FLAGS.tmp_dir,
                                                 False, 2**14, 2**9),
     lambda: None),  # test set
    "inference_snli32k":
Example #4
flags.DEFINE_integer("task_id_end", -1, "For distributed data generation.")
flags.DEFINE_integer(
    "num_concurrent_processes", None,
    "Applies only to problems for which multiprocess_generate=True.")
flags.DEFINE_string("t2t_usr_dir", "",
                    "Path to a Python module that will be imported. The "
                    "__init__.py file should include the necessary imports. "
                    "The imported files should contain registrations, "
                    "e.g. @registry.register_problem calls, that will then be "
                    "available to t2t-datagen.")

# Mapping from problems that we can generate data for to their generators.
# pylint: disable=g-long-lambda
_SUPPORTED_PROBLEM_GENERATORS = {
    "algorithmic_algebra_inverse": (
        lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
        lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
    "parsing_english_ptb8k": (
        lambda: wsj_parsing.parsing_token_generator(
            FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13, 2**9),
        lambda: wsj_parsing.parsing_token_generator(
            FLAGS.data_dir, FLAGS.tmp_dir, False, 2**13, 2**9)),
    "parsing_english_ptb16k": (
        lambda: wsj_parsing.parsing_token_generator(
            FLAGS.data_dir, FLAGS.tmp_dir, True, 2**14, 2**9),
        lambda: wsj_parsing.parsing_token_generator(
            FLAGS.data_dir, FLAGS.tmp_dir, False, 2**14, 2**9)),
    "inference_snli32k": (
        lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
        lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
    ),
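A minimal usage sketch for the mapping above, assuming each value is a pair of zero-argument callables returning the training and dev generators as in the dict directly above; the problem name lookup and print loop are illustrative, not the actual t2t-datagen driver logic:

# Hypothetical usage sketch, not the real t2t-datagen driver.
problem_name = "algorithmic_algebra_inverse"
training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem_name]

# Each entry is a zero-argument callable; calling it yields a generator of
# {"inputs": ..., "targets": ...} dictionaries of integer-encoded tokens.
for i, example in enumerate(training_gen()):
  if i >= 3:  # peek at a few examples only
    break
  print(example["inputs"], example["targets"])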