Example #1
0
 def testHillClimbingTask(self):
     """Checks HillClimbingTask rewards against a table of expected values.

     Each case pairs a code sequence (every one ends in 0) with the reward
     `get_reward` should return for it; values are compared with `np.isclose`
     in table order, stopping at the first mismatch.
     """
     manager = test_tasks.BasicTaskManager(test_tasks.HillClimbingTask())
     fn = manager.rl_batch(1)[0]
     # Common prefix shared by the three longest cases at the end of the table.
     long_route = [1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1, 8, 5, 1, 6, 4,
                   2, 1, 8, 3]
     cases = [
         ([1, 2, 0], 8 / 12.),
         ([1, 2, 2, 0], 11 / 12.),
         ([1, 2, 3, 0], 1.0),
         ([1, 2, 3, 4, 5, 2, 0], 1. + 8 / 12.),
         ([1, 2, 3, 4, 5, 6, 0], 2.0),
         ([1, 2, 3, 4, 5, 6, 1, 8, 3, 0], 3.0),
         ([1, 2, 3, 4, 5, 6, 7, 8, 7, 0], 3.0),
         ([1, 2, 3, 4, 5, 6, 1, 8, 3, 1, 0], 3.0 - 4 / 12.),
         ([1, 2, 3, 4, 5, 6, 1, 8, 3, 1, 1, 1, 1, 0], 2.0),
         ([1, 2, 3, 4, 5, 6, 7, 8, 7, 3, 0], 3.0 + 1 / 12.),
         (long_route + [0], 8.0),
         (long_route + [1, 1, 0], 8.0 - 8 / 12.),
         (long_route + [1] * 7 + [0], 7.0),
     ]
     for code, expected in cases:
         self.assertTrue(np.isclose(get_reward(fn, code), expected))
def make_task(task_name, override_kwargs=None, max_code_length=100,
              require_correct_syntax=False,
              do_code_simplification=False,
              correct_bonus=2.0, code_length_bonus=1.0):
  """Make tasks with setting from paper.

  Looks up `task_name` in a fixed table of (task class, constructor kwargs)
  pairs, instantiates the task, and wraps it in a `MultiIOTaskManager` that
  scores programs with `r.absolute_distance_reward`.

  Args:
    task_name: Key into the task table below (e.g. 'reverse', 'add'), or
        'test-hill-climb' for the hill-climbing test task.
    override_kwargs: Optional dict whose entries replace or extend the table's
        constructor kwargs for the chosen task.
    max_code_length: Maximum code length, forwarded to the task manager.
    require_correct_syntax: Forwarded to the task manager.
    do_code_simplification: When False, `min_code_length` is set equal to
        `max_code_length` (see the comment near the end of the body).
    correct_bonus: Forwarded to the task manager.
    code_length_bonus: Forwarded to the task manager.

  Returns:
    A `MultiIOTaskManager` for tasks in the table. For 'test-hill-climb', a
    `test_tasks.BasicTaskManager` is returned instead and all other arguments
    are ignored.

  Raises:
    ValueError: If `task_name` is unknown, or if `override_kwargs` is truthy
        but not a dict.
  """
  logging.info('Making paper-config task.')
  n = 16  # Number of test cases.
  task_mapping = {
      'print-hello': (
          PrintTask, dict(base=27, fixed_string=[8, 5, 12, 12, 15])),
      'print': (PrintIntTask, dict(base=256, fixed_string=[1, 2, 3, 4, 5])),
      'echo': (EchoTask, dict(base=27, min_length=1, max_length=6)),
      'remove-char': (
          RemoveCharTask, dict(base=256, n=n, min_len=1, max_len=6)),
      'reverse': (
          ReverseTask, dict(base=256, n=n, min_len=1, max_len=6)),
      'reverse-tune': (
          ReverseTaskV2, dict(base=256, reward_type='static-bylen')),
      'remove-char-tune': (RemoveCharTaskV2, dict(base=27)),
      'prefix': (CommonPrefixTask, dict(base=27)),
      'find': (FindSubStrTask, dict(base=27)),
      'sort3': (SortFixedTaskV2, dict(base=27, n=150, length=3)),
      'count-char': (CountCharTaskV2, dict(n=n, max_len=6)),
      'bool-logic': (BooleanLogicTask, dict()),
      'add': (AddTask, dict(n=9)),
      'echo-twice': (EchoTwiceTask, dict(n=n)),
      'echo-thrice': (EchoThriceTask, dict(n=n)),
      'copy-reverse': (CopyReverseTask, dict(n=n)),
      'zero-cascade': (EchoZeroCascadeTask, dict(n=n)),
      'cascade': (EchoCascadeTask, dict(n=n)),
      'shift-left': (ShiftLeftTask, dict(n=n)),
      'shift-right': (ShiftRightTask, dict(n=n)),
      'riffle': (RiffleTask, dict(n=n)),
      'unriffle': (UnriffleTask, dict(n=n)),
      'middle-char': (MiddleCharTask, dict(n=n)),
      'remove-last': (RemoveLastTask, dict(n=n)),
      'remove-last-two': (RemoveLastTwoTask, dict(n=n)),
      'echo-alternating': (EchoAlternatingTask, dict(n=n)),
      'echo-half': (EchoHalfTask, dict(n=n)),
      'length': (LengthTask, dict(n=n)),
      'echo-second-seq': (EchoSecondSequenceTask, dict(n=n)),
      'echo-nth-seq': (EchoNthSequenceTask, dict(n=n)),
      'substring': (SubstringTask, dict(n=n)),
      'divide-2': (Divide2Task, dict(n=n)),
      'dedup': (DedupTask, dict(n=n)),
      'remove-target-char': (RemoveTargetCharTask, dict(n=n)),
      'list-index': (ListIndexTask, dict(n=n)),
      'fib': (FibonacciTask, dict()),
      'count-down': (BottlesOfBeerTask, dict()),
      'split': (SplitTask, dict()),
      'trim-left': (TrimLeftTask, dict()),
      'circle-route': (
          JudgeRouteCircleTask, dict(n=100, max_len=32)),
      'multiply': (MultiplyTask, dict(n=100)),
      'divmod': (DivModTask, dict(n=100)),
  }

  if task_name not in task_mapping:
    # Test tasks.
    if task_name == 'test-hill-climb':
      return test_tasks.BasicTaskManager(test_tasks.HillClimbingTask())
    raise ValueError('Unknown task type "%s"' % task_name)
  task_cls, default_kwargs = task_mapping[task_name]

  # Copy so the overrides never mutate the table entry itself.
  kwargs = dict(default_kwargs)
  if override_kwargs:
    if not isinstance(override_kwargs, dict):
      # Wrap in a 1-tuple so a tuple-valued argument is formatted, not
      # unpacked into the format string.
      raise ValueError(
          'override_kwargs must be a dict, got: %s' % (override_kwargs,))
    kwargs.update(override_kwargs)

  task = task_cls(**kwargs)

  # Other reward functions available in `r`: absolute_mod_distance_reward,
  # absolute_log_distance_reward.
  reward_fn = r.absolute_distance_reward
  logging.info('Using reward function: %s', reward_fn.__name__)

  # We want reward with and without code simplification to be scaled the same
  # way. Without code simplification, give the maximum code length bonus
  # every time.
  min_code_length = 0.0 if do_code_simplification else max_code_length

  return MultiIOTaskManager(
      task=task, correct_bonus=correct_bonus,
      code_length_bonus=code_length_bonus,
      max_code_length=max_code_length, min_code_length=min_code_length,
      reward_fn=reward_fn, require_correct_syntax=require_correct_syntax)