def testHillClimbingTask(self):
  """Checks HillClimbingTask rewards against a table of expected values.

  A single reward function is taken from `rl_batch(1)` of a
  `BasicTaskManager` wrapping `HillClimbingTask`, and `get_reward` is
  evaluated on each action sequence below. Every result must be close
  (per `np.isclose`) to the paired expected reward.
  """
  manager = test_tasks.BasicTaskManager(test_tasks.HillClimbingTask())
  reward_fn = manager.rl_batch(1)[0]

  # Shared 24-token prefix of the three longest sequences.
  long_prefix = [
      1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1, 8, 5, 1, 6, 4, 2, 1,
      8, 3,
  ]

  # (action sequence, expected reward), checked in this order.
  expectations = [
      ([1, 2, 0], 8 / 12.),
      ([1, 2, 2, 0], 11 / 12.),
      ([1, 2, 3, 0], 1.0),
      ([1, 2, 3, 4, 5, 2, 0], 1. + 8 / 12.),
      ([1, 2, 3, 4, 5, 6, 0], 2.0),
      ([1, 2, 3, 4, 5, 6, 1, 8, 3, 0], 3.0),
      ([1, 2, 3, 4, 5, 6, 7, 8, 7, 0], 3.0),
      ([1, 2, 3, 4, 5, 6, 1, 8, 3, 1, 0], 3.0 - 4 / 12.),
      ([1, 2, 3, 4, 5, 6, 1, 8, 3, 1, 1, 1, 1, 0], 2.0),
      ([1, 2, 3, 4, 5, 6, 7, 8, 7, 3, 0], 3.0 + 1 / 12.),
      (long_prefix + [0], 8.0),
      (long_prefix + [1, 1, 0], 8.0 - 8 / 12.),
      (long_prefix + [1, 1, 1, 1, 1, 1, 1, 0], 7.0),
  ]

  for actions, expected_reward in expectations:
    actual_reward = get_reward(reward_fn, actions)
    self.assertTrue(np.isclose(actual_reward, expected_reward))
def make_task(task_name, override_kwargs=None, max_code_length=100,
              require_correct_syntax=False,
              do_code_simplification=False,
              correct_bonus=2.0, code_length_bonus=1.0):
  """Make tasks with setting from paper.

  Looks up `task_name` in a fixed table of (task class, default kwargs),
  optionally overrides the default kwargs, instantiates the task and wraps it
  in a `MultiIOTaskManager` that uses `r.absolute_distance_reward`.

  Args:
    task_name: Key into the task table below, or the special test name
        'test-hill-climb'.
    override_kwargs: Optional dict merged over the task's default constructor
        kwargs. Falsy values (None, empty dict) are ignored.
    max_code_length: Forwarded to `MultiIOTaskManager`. Also used as
        `min_code_length` when `do_code_simplification` is False.
    require_correct_syntax: Forwarded to `MultiIOTaskManager`.
    do_code_simplification: If True, `min_code_length` is 0.0; otherwise it
        equals `max_code_length`.
    correct_bonus: Forwarded to `MultiIOTaskManager`.
    code_length_bonus: Forwarded to `MultiIOTaskManager`.

  Returns:
    A `MultiIOTaskManager` for the requested task, or a
    `test_tasks.BasicTaskManager` wrapping `test_tasks.HillClimbingTask` when
    `task_name` is 'test-hill-climb'.

  Raises:
    ValueError: If `task_name` is unknown, or if `override_kwargs` is truthy
        but not a dict.
  """
  logging.info('Making paper-config task.')
  n = 16  # Number of test cases.
  task_mapping = {
      'print-hello': (
          PrintTask, dict(base=27, fixed_string=[8, 5, 12, 12, 15])),
      'print': (PrintIntTask, dict(base=256, fixed_string=[1, 2, 3, 4, 5])),
      'echo': (EchoTask, dict(base=27, min_length=1, max_length=6)),
      'remove-char': (
          RemoveCharTask, dict(base=256, n=n, min_len=1, max_len=6)),
      'reverse': (
          ReverseTask, dict(base=256, n=n, min_len=1, max_len=6)),
      'reverse-tune': (
          ReverseTaskV2, dict(base=256, reward_type='static-bylen')),
      'remove-char-tune': (RemoveCharTaskV2, dict(base=27)),
      'prefix': (CommonPrefixTask, dict(base=27)),
      'find': (FindSubStrTask, dict(base=27)),
      'sort3': (SortFixedTaskV2, dict(base=27, n=150, length=3)),
      'count-char': (CountCharTaskV2, dict(n=n, max_len=6)),
      'bool-logic': (BooleanLogicTask, dict()),
      'add': (AddTask, dict(n=9)),
      'echo-twice': (EchoTwiceTask, dict(n=n)),
      'echo-thrice': (EchoThriceTask, dict(n=n)),
      'copy-reverse': (CopyReverseTask, dict(n=n)),
      'zero-cascade': (EchoZeroCascadeTask, dict(n=n)),
      'cascade': (EchoCascadeTask, dict(n=n)),
      'shift-left': (ShiftLeftTask, dict(n=n)),
      'shift-right': (ShiftRightTask, dict(n=n)),
      'riffle': (RiffleTask, dict(n=n)),
      'unriffle': (UnriffleTask, dict(n=n)),
      'middle-char': (MiddleCharTask, dict(n=n)),
      'remove-last': (RemoveLastTask, dict(n=n)),
      'remove-last-two': (RemoveLastTwoTask, dict(n=n)),
      'echo-alternating': (EchoAlternatingTask, dict(n=n)),
      'echo-half': (EchoHalfTask, dict(n=n)),
      'length': (LengthTask, dict(n=n)),
      'echo-second-seq': (EchoSecondSequenceTask, dict(n=n)),
      'echo-nth-seq': (EchoNthSequenceTask, dict(n=n)),
      'substring': (SubstringTask, dict(n=n)),
      'divide-2': (Divide2Task, dict(n=n)),
      'dedup': (DedupTask, dict(n=n)),
      'remove-target-char': (RemoveTargetCharTask, dict(n=n)),
      'list-index': (ListIndexTask, dict(n=n)),
      'fib': (FibonacciTask, dict()),
      'count-down': (BottlesOfBeerTask, dict()),
      'split': (SplitTask, dict()),
      'trim-left': (TrimLeftTask, dict()),
      'circle-route': (
          JudgeRouteCircleTask, dict(n=100, max_len=32)),
      'multiply': (MultiplyTask, dict(n=100)),
      'divmod': (DivModTask, dict(n=100)),
  }

  if task_name not in task_mapping:
    # Test tasks.
    if task_name == 'test-hill-climb':
      return test_tasks.BasicTaskManager(test_tasks.HillClimbingTask())
    raise ValueError('Unknown task type "%s"' % task_name)

  task_cls, kwargs = task_mapping[task_name]

  if override_kwargs:
    if not isinstance(override_kwargs, dict):
      # Interpolate explicitly: exception constructors do not apply
      # %-formatting to their arguments. The one-element tuple keeps a
      # tuple-valued `override_kwargs` from being unpacked by `%`.
      raise ValueError(
          'override_kwargs must be a dict, got: %s' % (override_kwargs,))
    # `task_mapping` is rebuilt on every call, so updating in place does not
    # leak overrides into later calls.
    kwargs.update(override_kwargs)

  task = task_cls(**kwargs)

  reward_fn = r.absolute_distance_reward
  # reward_fn = r.absolute_mod_distance_reward
  # reward_fn = r.absolute_log_distance_reward
  logging.info('Using reward function: %s', reward_fn.__name__)

  # We want reward with and without code simplification to be scaled the same
  # way. Without code simplification, give the maximum code length bonus
  # every time.
  min_code_length = 0.0 if do_code_simplification else max_code_length

  return MultiIOTaskManager(
      task=task, correct_bonus=correct_bonus,
      code_length_bonus=code_length_bonus,
      max_code_length=max_code_length, min_code_length=min_code_length,
      reward_fn=reward_fn, require_correct_syntax=require_correct_syntax)