예제 #1
0
    def test_programs(self):
        """Checks two hand-written DSL programs against known input/output pairs."""
        # Concat of: 3rd alphanumeric token, the text after ':', and the first
        # four characters (per the expected outputs below).
        token_colon_prefix = dsl.Concat(
            dsl.GetToken(dsl.Type.ALPHANUM, 3),
            dsl.GetFrom(':'),
            dsl.GetFirst(dsl.Type.CHAR, 4),
        )
        for text, expected in (
                ('Ud 9:25,JV3 Obb', '2525,JV3 ObbUd 9'),
                ('zLny xmHg 8:43 A44q', '843 A44qzLny'),
        ):
            self.assertEqual(token_colon_prefix(text), expected)

        # Proper-case words 1..4 joined by commas, then '.', then the last
        # proper-case word.
        comma_joined_span = dsl.Compose(
            dsl.Replace(' ', ','),
            dsl.GetSpan(dsl.Type.PROP_CASE, 1, dsl.Boundary.START,
                        dsl.Type.PROP_CASE, 4, dsl.Boundary.END))
        names_program = dsl.Concat(
            comma_joined_span,
            dsl.ConstStr('.'),
            dsl.GetToken(dsl.Type.PROP_CASE, -1),
        )
        for text, expected in (
                ('Jacob Ethan James Alexander Michael',
                 'Jacob,Ethan,James,Alexander.Michael'),
                ('Earth Fire Wind Water Pluto Sun',
                 'Earth,Fire,Wind,Water.Sun'),
        ):
            self.assertEqual(names_program(text), expected)
예제 #2
0
def random_task_switch_concept_order(
    max_expressions,
    max_k,
    max_input_tokens,
    max_input_length,
    num_examples,
    is_train,
    min_expressions = 2):
  """Returns a sampled program and IO examples satisfying the given constraints.

  The program is a Concat split into two halves. For training tasks the first
  half is sampled from ALL_SUBSTRING and the second half from
  SAMPLER_POOL_MODIFY_OR_CONST; for test / finetune tasks the two pools are
  swapped.

  Args:
    max_expressions: Maximum number of concatenated expressions in the program.
    max_k: Maximum number of times a generated token can be repeated.
    max_input_tokens: Maximum number of unique tokens in the inputs. A token is
        either a constant string, or a sample from a regular expression.
    max_input_length: Maximum length of inputs to generate.
    num_examples: Number of input-output examples to generate.
    is_train: Whether to generate a task for train or test / finetune.
    min_expressions: Minimum number of concatenated expressions in the program.
        Must be at least 2 so that both halves are non-empty.
  Returns:
    A dsl.ProgramTask built from the program, the input strings and the output
    strings.
  Raises:
    ValueError: If min_expressions < 2 or min_expressions > max_expressions.
  """
  # Validate up front, before any sampling. This used to be an `assert`, which
  # is stripped under `python -O`; the old default of 1 also always failed it.
  if min_expressions < 2:
    raise ValueError(
        'min_expressions must be >= 2 so both halves are non-empty, got %r'
        % (min_expressions,))
  if min_expressions > max_expressions:
    raise ValueError(
        'min_expressions (%r) must not exceed max_expressions (%r)'
        % (min_expressions, max_expressions))

  # Upper bound on output length; presumably no single expression yields more
  # than max_input_length characters -- TODO confirm against random_expression.
  max_output_length = max_input_length * max_expressions

  # Sample inputs.
  inputs, delimiter_dict, type_dict = sample_inputs(
      num_examples, max_input_tokens, max_k, max_input_length)

  # Sample program: the first half and second half draw from different pools,
  # and which pool comes first depends on is_train.
  n_expressions = random.randint(min_expressions, max_expressions)
  n_first_half_expressions = n_expressions // 2
  n_second_half_expressions = n_expressions - n_first_half_expressions
  first_half_sampler_pool = (
      ALL_SUBSTRING if is_train else SAMPLER_POOL_MODIFY_OR_CONST)
  second_half_sampler_pool = (
      SAMPLER_POOL_MODIFY_OR_CONST if is_train else ALL_SUBSTRING)

  expression_list = [
      random_expression(inputs, delimiter_dict, type_dict,
                        sampler_pool=first_half_sampler_pool)
      for _ in range(n_first_half_expressions)
  ] + [
      random_expression(inputs, delimiter_dict, type_dict,
                        sampler_pool=second_half_sampler_pool)
      for _ in range(n_second_half_expressions)
  ]
  program = dsl.Concat(*expression_list)

  outputs = [program(inp) for inp in inputs]
  # Internal invariant: every output is non-empty and within the length bound.
  assert all(0 < len(out) <= max_output_length for out in outputs), (
      'Output lengths not ok')

  return dsl.ProgramTask(program, inputs, outputs)
예제 #3
0
def random_task(
    max_expressions,
    max_k,
    max_input_tokens,
    max_input_length,
    max_output_length,
    num_examples,
    min_expressions=1,
    n_expressions=None,
    max_attempts=None,
):
    """Returns a sampled program and IO examples satisfying the given constraints.

    Programs are rejection-sampled: a Concat of `n_expressions` random
    expressions is drawn repeatedly until every output is non-empty and at most
    `max_output_length` characters long.

    Args:
      max_expressions: Maximum number of concatenated expressions in the program.
      max_k: Maximum number of times a generated token can be repeated.
      max_input_tokens: Maximum number of unique tokens in the inputs. A token is
          either a constant string, or a sample from a regular expression.
      max_input_length: Maximum length of inputs to generate.
      max_output_length: Maximum length of outputs to generate.
      num_examples: Number of input-output examples to generate.
      min_expressions: Minimum number of concatenated expressions in the program.
      n_expressions: Fixed number of concatenated expressions (if provided). A
          falsy value (None or 0) means the count is sampled uniformly from
          [min_expressions, max_expressions].
      max_attempts: Maximum number of programs to sample before giving up. None
          (the default) keeps retrying indefinitely.

    Returns:
      A dsl.ProgramTask built from the program, the input strings and the output
      strings.

    Raises:
      ValueError: If `max_attempts` programs were sampled and none produced
          acceptable outputs.
    """
    # Sample inputs.
    inputs, delimiter_dict, type_dict = sample_inputs(num_examples,
                                                      max_input_tokens, max_k,
                                                      max_input_length)

    # Sample program. `not n_expressions` (rather than `is None`) also replaces
    # 0, which would otherwise yield empty outputs that are always rejected.
    if not n_expressions:
        n_expressions = random.randint(min_expressions, max_expressions)

    attempt = 0
    while max_attempts is None or attempt < max_attempts:
        attempt += 1
        program = dsl.Concat(*[
            random_expression(inputs, delimiter_dict, type_dict)
            for _ in range(n_expressions)
        ])

        outputs = [program(inp) for inp in inputs]
        # Rejection step on output lengths (single pass over the outputs).
        if all(0 < len(out) <= max_output_length for out in outputs):
            return dsl.ProgramTask(program, inputs, outputs)

    raise ValueError(
        'No program with acceptable output lengths found after %d attempts'
        % attempt)
예제 #4
0
    def test_decode(self):
        """Round-trips a program through encode/decode_program and re-runs it."""
        id_to_token, token_to_id = tokens.build_token_tables()
        # The forward and reverse vocab tables must have the same size.
        self.assertEqual(len(token_to_id), len(id_to_token))

        comma_joined_span = dsl.Compose(
            dsl.Replace(' ', ','),
            dsl.GetSpan(dsl.Type.PROP_CASE, 1, dsl.Boundary.START,
                        dsl.Type.PROP_CASE, 4, dsl.Boundary.END))
        original_program = dsl.Concat(
            comma_joined_span,
            dsl.ConstStr('.'),
            dsl.GetToken(dsl.Type.PROP_CASE, -1),
        )

        encoded = original_program.encode(token_to_id)
        # The encoded sequence must be terminated by the EOS token id.
        self.assertEqual(encoded[-1], token_to_id[dsl.EOS])

        roundtrip_program = dsl.decode_program(encoded, id_to_token)
        for text, expected in (
                ('Jacob Ethan James Alexander Michael',
                 'Jacob,Ethan,James,Alexander.Michael'),
                ('Earth Fire Wind Water Pluto Sun',
                 'Earth,Fire,Wind,Water.Sun'),
        ):
            self.assertEqual(roundtrip_program(text), expected)
예제 #5
0
def random_task(max_expressions,
                max_k,
                max_input_tokens,
                max_input_length,
                num_examples,
                min_expressions = 1,
                n_expressions = None,
                sampler_pool=None,
                valid_num_expressions_fn=None,
                keep_fn=None):
  """Samples a random program together with matching IO examples.

  Inputs are sampled first. Then the number of expressions is chosen,
  re-drawing while valid_num_expressions_fn rejects it. Finally programs are
  drawn until keep_fn (if any) accepts one.

  Args:
    max_expressions: Maximum number of concatenated expressions in the program.
    max_k: Maximum number of times a generated token can be repeated.
    max_input_tokens: Maximum number of unique tokens in the inputs. A token is
        either a constant string, or a sample from a regular expression.
    max_input_length: Maximum length of inputs to generate.
    num_examples: Number of input-output examples to generate.
    min_expressions: Minimum number of concatenated expressions in the program.
    n_expressions: Fixed number of concatenated expressions (if provided). When
        given, valid_num_expressions_fn is not consulted.
    sampler_pool: Pool of expressions to sample from (if None, all expressions
        are allowed).
    valid_num_expressions_fn: A function that returns True if the number of
        expressions is ok, or False if it should be rejected and re-sampled.
    keep_fn: A function that returns True if the Concat should be kept, or False
        if it should be rejected and re-sampled.
  Returns:
    A dsl.ProgramTask built from the program, the input strings and the output
    strings.
  Raises:
    ValueError: If a sampled program produces an empty output or one longer
        than max_input_length * max_expressions.
  """
  length_cap = max_input_length * max_expressions

  inputs, delimiter_dict, type_dict = sample_inputs(
      num_examples, max_input_tokens, max_k, max_input_length)

  def _count_is_valid(count):
    return valid_num_expressions_fn is None or valid_num_expressions_fn(count)

  # Draw the expression count, re-drawing until the validator accepts it.
  if not n_expressions:
    n_expressions = random.randint(min_expressions, max_expressions)
    while not _count_is_valid(n_expressions):
      n_expressions = random.randint(min_expressions, max_expressions)

  while True:
    expressions = [
        random_expression(inputs, delimiter_dict, type_dict,
                          sampler_pool=sampler_pool)
        for _ in range(n_expressions)
    ]
    program = dsl.Concat(*expressions)
    outputs = [program(text) for text in inputs]

    # Out-of-range output lengths are a hard error, not a rejection.
    lengths_ok = all(0 < len(out) <= length_cap for out in outputs)
    if not lengths_ok:
      logging.error('Output length not ok')
      for message, value in (('program: %s', program),
                             ('inputs: %s', inputs),
                             ('outputs: %s', outputs)):
        logging.error(message, value)
      raise ValueError('Output lengths not ok')

    # Rejection step: keep the program unless keep_fn vetoes it.
    if keep_fn is None or keep_fn(program):
      return dsl.ProgramTask(program, inputs, outputs)