Example #1
0
def main(_):
    """Generates random program tasks and writes them to a single TFRecord shard.

    Samples FLAGS.num_tasks random tasks, serializes each as a `tf.Example`,
    and writes them to `program_tasks.tf_records-00000-of-00001` under
    FLAGS.save_dir.

    Args:
        _: Unused; absl.app passes the remaining argv here.
    """
    tf.enable_v2_behavior()

    # Only seed when a seed was explicitly provided, matching the other
    # generator entry points in this file; an unseeded run stays random.
    if FLAGS.seed is not None:
        tf.random.set_seed(FLAGS.seed)
        np.random.seed(FLAGS.seed)
        random.seed(FLAGS.seed)

    # Only the token -> id direction of the table is needed for serialization.
    _, token_id_table = dsl_tokens.build_token_tables()

    # makedirs (not mkdir) so missing intermediate directories are created,
    # consistent with the sibling generators in this file.
    if not gfile.isdir(FLAGS.save_dir):
        gfile.makedirs(FLAGS.save_dir)

    worker_fname = os.path.join(FLAGS.save_dir,
                                'program_tasks.tf_records-00000-of-00001')

    # Write the `tf.Example` observations to the file.
    with tf.io.TFRecordWriter(worker_fname) as writer:
        for _ in range(FLAGS.num_tasks):
            task = sample_random.random_task(
                max_expressions=FLAGS.max_expressions,
                min_expressions=FLAGS.min_expressions,
                max_k=5,
                max_input_tokens=10,
                max_input_length=FLAGS.max_characters,
                max_output_length=FLAGS.max_characters,
                num_examples=FLAGS.num_strings_per_task,
            )
            example = serialize_example(task, token_id_table)
            writer.write(example)
    def test_sample_random(self):
        """Sampled tasks have non-empty outputs that their program reproduces."""
        for _ in range(10):
            task = sample_random.random_task(max_expressions=10,
                                             max_k=5,
                                             max_input_tokens=10,
                                             max_input_length=20,
                                             num_examples=4)
            # The shortest output must still be non-empty.
            shortest = min(len(out) for out in task.outputs)
            self.assertGreater(shortest, 0)

            # Running the task's program on its inputs must reproduce its outputs.
            computed = [task.program(example_input) for example_input in task.inputs]
            self.assertListEqual(computed, task.outputs)
def main(_):
    """Generates tasks and writes entire-program and decomposition TFRecords.

    For each of FLAGS.num_tasks tasks, writes one serialized entire-program
    example and a stream of serialized decomposition examples into two
    single-shard TFRecord files under FLAGS.save_dir, named after FLAGS.split.

    Args:
        _: Unused; absl.app passes the remaining argv here.
    """
    tf.enable_v2_behavior()

    # Seed every RNG source only when a seed was explicitly requested.
    if FLAGS.seed is not None:
        tf.random.set_seed(FLAGS.seed)
        np.random.seed(FLAGS.seed)
        random.seed(FLAGS.seed)

    # Only the token -> id direction of the table is needed here.
    _, token_id_table = dsl_tokens.build_token_tables()

    if not gfile.isdir(FLAGS.save_dir):
        gfile.makedirs(FLAGS.save_dir)

    # Output is a single shard; keep the sharded naming scheme anyway.
    shard_index = 0
    shard_count = 1

    programs_path = os.path.join(
        FLAGS.save_dir,
        'entire_programs_{}.tf_records-{:05d}-of-{:05d}'.format(
            FLAGS.split, shard_index, shard_count))
    decomposition_path = os.path.join(
        FLAGS.save_dir,
        'decomposition_data_{}.tf_records-{:05d}-of-{:05d}'.format(
            FLAGS.split, shard_index, shard_count))

    # Write the `tf.Example` observations to both files.
    with tf.io.TFRecordWriter(programs_path) as programs_writer:
        with tf.io.TFRecordWriter(decomposition_path) as decomposition_writer:
            for task_index in range(FLAGS.num_tasks):
                if FLAGS.experiment == exp_module.Experiment.NONE.name:
                    # No experiment selected: sample a fully random task.
                    task = sample_random.random_task(
                        max_expressions=FLAGS.max_expressions,
                        min_expressions=FLAGS.min_expressions,
                        max_k=3,
                        max_input_tokens=5,
                        max_input_length=FLAGS.max_input_length,
                        num_examples=FLAGS.num_strings_per_task)
                else:
                    # Experiment-driven generation: decide whether this task
                    # follows the train or test distribution for the split.
                    if FLAGS.split in ['train', 'valid']:
                        is_train = True
                    elif FLAGS.split == 'test':
                        is_train = False
                    elif FLAGS.split == 'finetune':
                        # Alternate train/test-style tasks within the split.
                        is_train = bool(task_index % 2)
                    else:
                        raise ValueError('Unhandled split: {}'.format(FLAGS.split))
                    task = generate_task_for_experiment(FLAGS.experiment, is_train)

                programs_writer.write(
                    serialize_entire_program_example(task, token_id_table))
                for serialized in serialize_decomposition_examples(
                        task, token_id_table):
                    decomposition_writer.write(serialized)
def main(_):
    """Generates program tasks for one split and writes a single TFRecord shard.

    For each of FLAGS.num_tasks tasks, either samples a random task (when no
    experiment is selected) or generates an experiment-specific task, then
    serializes it as a `tf.Example` into
    `program_tasks_{split}.tf_records-00000-of-00001` under FLAGS.save_dir.

    Args:
        _: Unused; absl.app passes the remaining argv here.
    """
    tf.enable_v2_behavior()

    # Only seed when a seed was explicitly provided.
    if FLAGS.seed is not None:
        tf.random.set_seed(FLAGS.seed)
        np.random.seed(FLAGS.seed)
        random.seed(FLAGS.seed)

    # Only the token -> id direction of the table is needed here.
    _, token_id_table = dsl_tokens.build_token_tables()

    if not gfile.isdir(FLAGS.save_dir):
        gfile.makedirs(FLAGS.save_dir)

    worker_fname = os.path.join(
        FLAGS.save_dir,
        'program_tasks_{}.tf_records-00000-of-00001'.format(FLAGS.split))

    # Write the `tf.Example` observations to the file.
    with tf.io.TFRecordWriter(worker_fname) as writer:
        for i in range(FLAGS.num_tasks):
            # Compare against the enum *name*, matching the sibling generator
            # in this file: FLAGS.experiment appears to hold a string flag
            # value, so comparing it to the Experiment enum object itself
            # would never match and every task would incorrectly go through
            # generate_task_for_experiment.
            if FLAGS.experiment == exp_module.Experiment.NONE.name:
                task = sample_random.random_task(
                    max_expressions=FLAGS.max_expressions,
                    min_expressions=FLAGS.min_expressions,
                    max_k=3,
                    max_input_tokens=5,
                    max_input_length=FLAGS.max_input_length,
                    num_examples=FLAGS.num_strings_per_task)
            else:
                # Map the split to the train/test distribution choice.
                if FLAGS.split in ['train', 'valid']:
                    is_train = True
                elif FLAGS.split == 'test':
                    is_train = False
                elif FLAGS.split == 'finetune':
                    # Alternate train/test-style tasks within the split.
                    is_train = bool(i % 2)
                else:
                    raise ValueError('Unhandled split: {}'.format(FLAGS.split))
                task = generate_task_for_experiment(FLAGS.experiment, is_train)

            example = serialize_example(task, token_id_table)
            writer.write(example)