def test_programs(self):
    """Runs two hand-built DSL programs and compares against fixed outputs."""
    # Program 1: a Concat of three expressions, checked on two inputs.
    three_part_program = dsl.Concat(
        dsl.GetToken(dsl.Type.ALPHANUM, 3),
        dsl.GetFrom(':'),
        dsl.GetFirst(dsl.Type.CHAR, 4),
    )
    first_cases = (
        ('Ud 9:25,JV3 Obb', '2525,JV3 ObbUd 9'),
        ('zLny xmHg 8:43 A44q', '843 A44qzLny'),
    )
    for source, expected in first_cases:
        self.assertEqual(three_part_program(source), expected)

    # Program 2: Replace composed over a GetSpan, then a constant '.', then
    # a GetToken with index -1.
    space_to_comma = dsl.Replace(' ', ',')
    prop_case_span = dsl.GetSpan(
        dsl.Type.PROP_CASE, 1, dsl.Boundary.START,
        dsl.Type.PROP_CASE, 4, dsl.Boundary.END,
    )
    composed_program = dsl.Concat(
        dsl.Compose(space_to_comma, prop_case_span),
        dsl.ConstStr('.'),
        dsl.GetToken(dsl.Type.PROP_CASE, -1),
    )
    second_cases = (
        ('Jacob Ethan James Alexander Michael',
         'Jacob,Ethan,James,Alexander.Michael'),
        ('Earth Fire Wind Water Pluto Sun', 'Earth,Fire,Wind,Water.Sun'),
    )
    for source, expected in second_cases:
        self.assertEqual(composed_program(source), expected)
def random_task_switch_concept_order(
    max_expressions,
    max_k,
    max_input_tokens,
    max_input_length,
    num_examples,
    is_train,
    min_expressions=1,
):
    """Samples a program whose two halves draw from different sampler pools.

    The program is a Concat of `n_expressions` expressions. The first
    `n_expressions // 2` expressions are sampled from one pool and the
    remaining ones from another. For training tasks the first half uses
    ALL_SUBSTRING and the second half uses SAMPLER_POOL_MODIFY_OR_CONST; for
    non-training tasks the two pools are swapped.

    Args:
      max_expressions: Maximum number of concatenated expressions in the
        program.
      max_k: Maximum number of times a generated token can be repeated.
      max_input_tokens: Maximum number of unique tokens in the inputs. A token
        is either a constant string, or a sample from a regular expression.
      max_input_length: Maximum length of inputs to generate.
      num_examples: Number of input-output examples to generate.
      is_train: Whether to generate a task for train or test / finetune.
      min_expressions: Minimum number of concatenated expressions in the
        program. Must be at least 2; note that the default value of 1 is
        rejected, so callers have to pass this explicitly.

    Returns:
      A dsl.ProgramTask built from the sampled program, the input strings and
      the corresponding output strings.

    Raises:
      ValueError: If min_expressions is smaller than 2.
    """
    # Validate up front with a real exception: an `assert` would be stripped
    # under `python -O`, and checking before sampling avoids wasted work.
    if min_expressions < 2:
        raise ValueError(
            'min_expressions must be at least 2 so that both halves of the '
            'program are non-empty, got {!r}'.format(min_expressions))

    max_output_length = max_input_length * max_expressions

    # Sample inputs.
    inputs, delimiter_dict, type_dict = sample_inputs(
        num_examples, max_input_tokens, max_k, max_input_length)

    # Sample program. With n_expressions >= 2 each half gets >= 1 expression.
    n_expressions = random.randint(min_expressions, max_expressions)
    n_first_half_expressions = n_expressions // 2
    n_second_half_expressions = n_expressions - n_first_half_expressions

    if is_train:
        first_half_sampler_pool = ALL_SUBSTRING
        second_half_sampler_pool = SAMPLER_POOL_MODIFY_OR_CONST
    else:
        first_half_sampler_pool = SAMPLER_POOL_MODIFY_OR_CONST
        second_half_sampler_pool = ALL_SUBSTRING

    expression_list = [
        random_expression(inputs, delimiter_dict, type_dict,
                          sampler_pool=first_half_sampler_pool)
        for _ in range(n_first_half_expressions)
    ] + [
        random_expression(inputs, delimiter_dict, type_dict,
                          sampler_pool=second_half_sampler_pool)
        for _ in range(n_second_half_expressions)
    ]
    program = dsl.Concat(*expression_list)

    outputs = [program(inp) for inp in inputs]

    # Internal sanity check (not input validation): every output must be
    # non-empty and no longer than max_input_length * max_expressions.
    assert all(0 < len(out) <= max_output_length for out in outputs)
    return dsl.ProgramTask(program, inputs, outputs)
def random_task(
    max_expressions,
    max_k,
    max_input_tokens,
    max_input_length,
    max_output_length,
    num_examples,
    min_expressions=1,
    n_expressions=None,
):
    """Samples inputs and a Concat program, rejecting bad output lengths.

    Args:
      max_expressions: Upper bound on the number of concatenated expressions.
      max_k: Maximum number of times a generated token can be repeated.
      max_input_tokens: Maximum number of unique tokens in the inputs. A token
        is either a constant string, or a sample from a regular expression.
      max_input_length: Maximum length of inputs to generate.
      max_output_length: Maximum length of outputs to generate.
      num_examples: Number of input-output examples to generate.
      min_expressions: Lower bound on the number of concatenated expressions.
      n_expressions: Fixed number of concatenated expressions. When falsy, the
        count is drawn uniformly from [min_expressions, max_expressions].

    Returns:
      A dsl.ProgramTask built from the program, the inputs and the outputs.
    """
    # Inputs are sampled once; only the program is re-sampled on rejection.
    inputs, delimiter_dict, type_dict = sample_inputs(
        num_examples, max_input_tokens, max_k, max_input_length)

    if not n_expressions:
        n_expressions = random.randint(min_expressions, max_expressions)

    while True:
        expressions = [
            random_expression(inputs, delimiter_dict, type_dict)
            for _ in range(n_expressions)
        ]
        program = dsl.Concat(*expressions)
        outputs = [program(inp) for inp in inputs]

        # Rejection step: retry unless every output is non-empty and within
        # the length limit.
        output_lengths = [len(out) for out in outputs]
        if max(output_lengths) > max_output_length:
            continue
        if min(output_lengths) > 0:
            return dsl.ProgramTask(program, inputs, outputs)
def test_decode(self):
    """Encodes a program, decodes it, and checks the decoded program runs."""
    id_to_token, token_to_id = tokens.build_token_tables()
    # Both lookup tables must have the same number of entries.
    self.assertEqual(len(token_to_id), len(id_to_token))

    space_to_comma = dsl.Replace(' ', ',')
    prop_case_span = dsl.GetSpan(
        dsl.Type.PROP_CASE, 1, dsl.Boundary.START,
        dsl.Type.PROP_CASE, 4, dsl.Boundary.END,
    )
    original_program = dsl.Concat(
        dsl.Compose(space_to_comma, prop_case_span),
        dsl.ConstStr('.'),
        dsl.GetToken(dsl.Type.PROP_CASE, -1),
    )

    token_ids = original_program.encode(token_to_id)
    # The encoding is expected to end with the EOS token id.
    self.assertEqual(token_ids[-1], token_to_id[dsl.EOS])

    round_tripped = dsl.decode_program(token_ids, id_to_token)
    expectations = (
        ('Jacob Ethan James Alexander Michael',
         'Jacob,Ethan,James,Alexander.Michael'),
        ('Earth Fire Wind Water Pluto Sun', 'Earth,Fire,Wind,Water.Sun'),
    )
    for source, expected in expectations:
        self.assertEqual(round_tripped(source), expected)
def random_task(
    max_expressions,
    max_k,
    max_input_tokens,
    max_input_length,
    num_examples,
    min_expressions=1,
    n_expressions=None,
    sampler_pool=None,
    valid_num_expressions_fn=None,
    keep_fn=None,
):
    """Samples inputs and a Concat program, with optional filtering hooks.

    Args:
      max_expressions: Upper bound on the number of concatenated expressions.
      max_k: Maximum number of times a generated token can be repeated.
      max_input_tokens: Maximum number of unique tokens in the inputs. A token
        is either a constant string, or a sample from a regular expression.
      max_input_length: Maximum length of inputs to generate.
      num_examples: Number of input-output examples to generate.
      min_expressions: Lower bound on the number of concatenated expressions.
      n_expressions: Fixed number of concatenated expressions. When falsy, the
        count is sampled from [min_expressions, max_expressions].
      sampler_pool: Pool of expressions to sample from; forwarded unchanged to
        random_expression (None means all expressions are allowed).
      valid_num_expressions_fn: Optional predicate on a sampled expression
        count; the count is re-sampled until it returns True.
      keep_fn: Optional predicate on the sampled Concat; the program is
        re-sampled until it returns True.

    Returns:
      A dsl.ProgramTask built from the program, the inputs and the outputs.

    Raises:
      ValueError: If any output is empty or longer than
        max_input_length * max_expressions.
    """
    max_output_length = max_input_length * max_expressions

    # Inputs are sampled once and shared by every candidate program.
    inputs, delimiter_dict, type_dict = sample_inputs(
        num_examples, max_input_tokens, max_k, max_input_length)

    # Choose the number of expressions, re-drawing while the caller's
    # predicate (if any) rejects the draw.
    if not n_expressions:
        n_expressions = random.randint(min_expressions, max_expressions)
        while (valid_num_expressions_fn is not None
               and not valid_num_expressions_fn(n_expressions)):
            n_expressions = random.randint(min_expressions, max_expressions)

    while True:
        expressions = [
            random_expression(inputs, delimiter_dict, type_dict,
                              sampler_pool=sampler_pool)
            for _ in range(n_expressions)
        ]
        program = dsl.Concat(*expressions)
        outputs = [program(inp) for inp in inputs]

        # A bad output length is treated as an error, not as a rejection.
        lengths_ok = all(
            0 < len(out) <= max_output_length for out in outputs)
        if not lengths_ok:
            logging.error('Output length not ok')
            logging.error('program: %s', program)
            logging.error('inputs: %s', inputs)
            logging.error('outputs: %s', outputs)
            raise ValueError('Output lengths not ok')

        # Rejection step: accept when there is no filter or the filter passes.
        if keep_fn is None or keep_fn(program):
            return dsl.ProgramTask(program, inputs, outputs)