Example #1
def test_makedirs(ctx):
    contents = b"meow!"
    with ctx() as path:
        dirpath = bf.join(path, "x", "x", "x")
        bf.makedirs(dirpath)
        assert bf.exists(dirpath)
        _write_contents(bf.join(dirpath, "testfile"), contents)
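Several of these tests call a _write_contents helper that the listing does not show; below is a minimal sketch of what it presumably does, using only bf.BlobFile (the real helper may differ).

# Hedged sketch of the assumed _write_contents helper.
def _write_contents(path, contents):
    with bf.BlobFile(path, "wb") as f:
        f.write(contents)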
Example #2
def test_glob(ctx, parallel):
    contents = b"meow!"
    with ctx() as path:
        dirpath = bf.dirname(path)
        a_path = bf.join(dirpath, "ab")
        with bf.BlobFile(a_path, "wb") as w:
            w.write(contents)
        b_path = bf.join(dirpath, "bb")
        with bf.BlobFile(b_path, "wb") as w:
            w.write(contents)

        def assert_listing_equal(path, desired):
            desired = sorted([bf.join(dirpath, p) for p in desired])
            actual = sorted(list(bf.glob(path, parallel=parallel)))
            assert actual == desired, f"{actual} != {desired}"

        assert_listing_equal(bf.join(dirpath, "*b"), ["ab", "bb"])
        assert_listing_equal(bf.join(dirpath, "a*"), ["ab"])
        assert_listing_equal(bf.join(dirpath, "ab*"), ["ab"])
        assert_listing_equal(bf.join(dirpath, "*"), ["ab", "bb"])
        assert_listing_equal(bf.join(dirpath, "bb"), ["bb"])

        path = bf.join(dirpath, "test.txt")
        with bf.BlobFile(path, "wb") as w:
            w.write(contents)
        path = bf.join(dirpath, "subdir", "test.txt")
        bf.makedirs(bf.dirname(path))
        with bf.BlobFile(path, "wb") as f:
            f.write(contents)
        path = bf.join(dirpath, "subdir", "subsubdir", "test.txt")
        if "://" not in path:
            # implicit directory
            bf.makedirs(bf.dirname(path))
        with bf.BlobFile(path, "wb") as f:
            f.write(contents)

        assert_listing_equal(bf.join(dirpath, "*/test.txt"), ["subdir/test.txt"])
        assert_listing_equal(bf.join(dirpath, "*/*.txt"), ["subdir/test.txt"])
        if "://" in path:
            # local glob doesn't handle ** the same way as remote glob
            assert_listing_equal(
                bf.join(dirpath, "**.txt"),
                ["test.txt", "subdir/test.txt", "subdir/subsubdir/test.txt"],
            )
        else:
            assert_listing_equal(bf.join(dirpath, "**.txt"), ["test.txt"])
        assert_listing_equal(bf.join(dirpath, "*/test"), [])
        assert_listing_equal(bf.join(dirpath, "subdir/test.txt"), ["subdir/test.txt"])

        # directories
        assert_listing_equal(bf.join(dirpath, "*"), ["ab", "bb", "subdir", "test.txt"])
        assert_listing_equal(bf.join(dirpath, "subdir"), ["subdir"])
        assert_listing_equal(bf.join(dirpath, "subdir/"), ["subdir"])
        assert_listing_equal(bf.join(dirpath, "*/"), ["subdir"])
        assert_listing_equal(bf.join(dirpath, "*dir"), ["subdir"])
        assert_listing_equal(bf.join(dirpath, "subdir/*dir"), ["subdir/subsubdir"])
        assert_listing_equal(bf.join(dirpath, "subdir/*dir/"), ["subdir/subsubdir"])
        assert_listing_equal(bf.join(dirpath, "su*ir/*dir/"), ["subdir/subsubdir"])
Example #3
def test_isdir(ctx):
    contents = b"meow!"
    with ctx() as path:
        assert not bf.isdir(path)
        _write_contents(path, contents)
        assert not bf.isdir(path)
        dirpath = path + ".dir"
        bf.makedirs(dirpath)
        assert bf.isdir(dirpath)
        assert not bf.isdir(dirpath[:-1])
Example #4
def test_listdir(ctx):
    contents = b"meow!"
    with ctx() as path:
        dirpath = bf.dirname(path)
        a_path = bf.join(dirpath, "a")
        with bf.BlobFile(a_path, "wb") as w:
            w.write(contents)
        b_path = bf.join(dirpath, "b")
        with bf.BlobFile(b_path, "wb") as w:
            w.write(contents)
        bf.makedirs(bf.join(dirpath, "c"))
        assert sorted(list(bf.listdir(dirpath))) == ["a", "b", "c"]
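Note that listdir yields bare entry names rather than full paths, which is why the assertion above compares against ["a", "b", "c"]. A hypothetical standalone use:

# Hypothetical path; join each yielded name back onto the directory.
dirpath = "gs://my-bucket/some/dir"
for name in bf.listdir(dirpath):
    print(bf.join(dirpath, name))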
Example #5
def tmp_google_dir():
    random_id = "".join(
        random.choice(string.ascii_lowercase) for _ in range(16))
    # bind tmpdir before the try block so the finally clause can always see it
    tmpdir = GOOGLE_TEST_BASE / random_id
    try:
        blobfile.makedirs(str(tmpdir))
        yield tmpdir
    finally:
        try:
            blobfile.rmtree(str(tmpdir))
        except NotADirectoryError:
            pass
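The generator above is presumably registered as a pytest fixture or wrapped with contextlib.contextmanager; a usage sketch under the latter assumption:

import contextlib

# Assumption: tmp_google_dir is meant to be consumed as a context manager.
with contextlib.contextmanager(tmp_google_dir)() as tmpdir:
    blobfile.makedirs(str(tmpdir / "subdir"))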
Example #6
def test_read_write(ctx, streaming):
    contents = b"meow!\npurr\n"
    with ctx() as path:
        path = bf.join(path, "a folder", "a.file")
        bf.makedirs(bf.dirname(path))
        with bf.BlobFile(path, "wb", streaming=streaming) as w:
            w.write(contents)
        with bf.BlobFile(path, "rb", streaming=streaming) as r:
            assert r.read() == contents
        with bf.BlobFile(path, "rb", streaming=streaming) as r:
            lines = list(r)
            assert b"".join(lines) == contents
Example #7
def test_rmtree(ctx):
    contents = b"meow!"
    with ctx() as path:
        root = bf.dirname(path)
        destroy_path = bf.join(root, "destroy")
        bf.makedirs(destroy_path)
        save_path = bf.join(root, "save")
        bf.makedirs(save_path)

        # implicit dir
        if not "://" in path:
            bf.makedirs(bf.join(destroy_path, "adir"))
        with bf.BlobFile(bf.join(destroy_path, "adir/b"), "wb") as w:
            w.write(contents)

        # explicit dir
        bf.makedirs(bf.join(destroy_path, "bdir"))
        with bf.BlobFile(bf.join(destroy_path, "bdir/b"), "wb") as w:
            w.write(contents)

        bf.makedirs(bf.join(save_path, "somedir"))
        with bf.BlobFile(bf.join(save_path, "somefile"), "wb") as w:
            w.write(contents)

        def assert_listing_equal(path, desired):
            actual = list(bf.walk(path))
            # os walk ordering is platform-dependent, so only compare sorted order
            assert sorted(actual) == sorted(desired), f"{actual} != {desired}"

        assert_listing_equal(
            root,
            [
                (root, ["destroy", "save"], []),
                (destroy_path, ["adir", "bdir"], []),
                (bf.join(destroy_path, "adir"), [], ["b"]),
                (bf.join(destroy_path, "bdir"), [], ["b"]),
                (save_path, ["somedir"], ["somefile"]),
                (bf.join(save_path, "somedir"), [], []),
            ],
        )

        bf.rmtree(destroy_path)

        assert_listing_equal(
            root,
            [
                (root, ["save"], []),
                (save_path, ["somedir"], ["somefile"]),
                (bf.join(save_path, "somedir"), [], []),
            ],
        )
Example #8
def test_az_path():
    contents = b"meow!\npurr\n"
    with _get_temp_as_path() as path:
        path = _convert_https_to_az(path)
        path = bf.join(path, "a folder", "a.file")
        path = _convert_https_to_az(path)
        bf.makedirs(_convert_https_to_az(bf.dirname(path)))
        with bf.BlobFile(path, "wb") as w:
            w.write(contents)
        with bf.BlobFile(path, "rb") as r:
            assert r.read() == contents
        with bf.BlobFile(path, "rb") as r:
            lines = list(r)
            assert b"".join(lines) == contents
Example #9
def test_walk(ctx, topdown):
    contents = b"meow!"
    with ctx() as path:
        dirpath = bf.dirname(path)
        a_path = bf.join(dirpath, "a")
        with bf.BlobFile(a_path, "wb") as w:
            w.write(contents)
        bf.makedirs(bf.join(dirpath, "c/d"))
        b_path = bf.join(dirpath, "c/d/b")
        with bf.BlobFile(b_path, "wb") as w:
            w.write(contents)
        expected = [
            (dirpath, ["c"], ["a"]),
            (bf.join(dirpath, "c"), ["d"], []),
            (bf.join(dirpath, "c", "d"), [], ["b"]),
        ]
        if not topdown:
            expected = list(reversed(expected))
        assert list(bf.walk(dirpath, topdown=topdown)) == expected
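Like os.walk, bf.walk yields (dirpath, dirnames, filenames) tuples; with topdown=False the deepest directories come first, hence the reversed expectation above. A hypothetical standalone traversal:

# Hypothetical path.
for dirpath, dirnames, filenames in bf.walk("gs://my-bucket/data", topdown=True):
    print(dirpath, dirnames, filenames)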
Example #10
def test_scandir(ctx):
    contents = b"meow!"
    with ctx() as path:
        dirpath = bf.dirname(path)
        a_path = bf.join(dirpath, "a")
        with bf.BlobFile(a_path, "wb") as w:
            w.write(contents)
        b_path = bf.join(dirpath, "b")
        with bf.BlobFile(b_path, "wb") as w:
            w.write(contents)
        bf.makedirs(bf.join(dirpath, "c"))
        entries = sorted(list(bf.scandir(dirpath)))
        assert [e.name for e in entries] == ["a", "b", "c"]
        assert [e.path for e in entries] == [
            bf.join(dirpath, name) for name in ["a", "b", "c"]
        ]
        assert [e.is_dir for e in entries] == [False, False, True]
        assert [e.is_file for e in entries] == [True, True, False]
        assert entries[0].stat.size == len(contents)
        assert entries[1].stat.size == len(contents)
        assert entries[2].stat is None
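A hypothetical standalone use of scandir, relying only on the DirEntry fields the assertions above exercise (name, path, is_dir, is_file, stat):

# stat is None for directories, per the assertion above.
for entry in bf.scandir("gs://my-bucket/some/dir"):
    size = entry.stat.size if entry.stat is not None else "-"
    print(entry.name, "dir" if entry.is_dir else "file", size)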
Example #11
def test_isdir(ctx):
    contents = b"meow!"
    with ctx() as path:
        assert not bf.isdir(path)
        _write_contents(path, contents)
        assert not bf.isdir(path)

        dirpath = path + ".dir"
        bf.makedirs(dirpath)
        assert bf.isdir(dirpath)
        assert not bf.isdir(dirpath[:-1])

        filepath = bf.join(path + ".otherdir", "subdir", "file.name")
        if "://" not in path:
            # implicit directory
            bf.makedirs(bf.dirname(filepath))
        dirpath = bf.dirname(bf.dirname(filepath))
        _write_contents(filepath, contents)
        assert bf.isdir(dirpath)
        assert not bf.isdir(dirpath[:-1])
Example #12
def test_scanglob(ctx):
    contents = b"meow!"
    with ctx() as path:
        dirpath = bf.dirname(path)
        a_path = bf.join(dirpath, "ab")
        with bf.BlobFile(a_path, "wb") as w:
            w.write(contents)
        b_path = bf.join(dirpath, "bb")
        with bf.BlobFile(b_path, "wb") as w:
            w.write(contents)
        path = bf.join(dirpath, "test.txt")
        with bf.BlobFile(path, "wb") as w:
            w.write(contents)
        path = bf.join(dirpath, "subdir", "test.txt")
        bf.makedirs(bf.dirname(path))
        with bf.BlobFile(path, "wb") as f:
            f.write(contents)

        entries = sorted(list(bf.scanglob(bf.join(dirpath, "*b*"))))
        assert entries[0].name == "ab" and entries[0].is_file
        assert entries[1].name == "bb" and entries[1].is_file
        assert entries[2].name == "subdir" and entries[2].is_dir
Example #13
def test_listdir_sharded(ctx):
    contents = b"meow!"
    with ctx() as path:
        dirpath = bf.dirname(path)
        with bf.BlobFile(bf.join(dirpath, "a"), "wb") as w:
            w.write(contents)
        with bf.BlobFile(bf.join(dirpath, "aa"), "wb") as w:
            w.write(contents)
        with bf.BlobFile(bf.join(dirpath, "b"), "wb") as w:
            w.write(contents)
        with bf.BlobFile(bf.join(dirpath, "ca"), "wb") as w:
            w.write(contents)
        bf.makedirs(bf.join(dirpath, "c"))
        with bf.BlobFile(bf.join(dirpath, "c/a"), "wb") as w:
            w.write(contents)
        # this should also test shard_prefix_length=2 but that takes too long
        assert sorted(list(bf.listdir(dirpath, shard_prefix_length=1))) == [
            "a",
            "aa",
            "b",
            "c",
            "ca",
        ]
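shard_prefix_length appears to shard the listing into parallel per-prefix requests (one per leading character when it is 1) while returning the same names as a plain listdir. A hypothetical use on a large directory:

# Hypothetical path; output should match a plain bf.listdir(path).
names = sorted(bf.listdir("gs://my-bucket/huge-dir", shard_prefix_length=1))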
Example #14
def test_rmdir(ctx):
    contents = b"meow!"
    with ctx() as path:
        dirpath = bf.dirname(path)
        # this is an error for a local path but not for a blob path
        bf.rmdir(bf.join(dirpath, "fakedirname"))
        new_dirpath = bf.join(dirpath, "dirname")
        bf.makedirs(new_dirpath)
        assert bf.exists(new_dirpath)
        bf.rmdir(new_dirpath)
        assert not bf.exists(new_dirpath)

        # double delete is fine
        bf.rmdir(new_dirpath)

        # implicit dir
        new_filepath = bf.join(dirpath, "dirname", "name")
        _write_contents(new_filepath, contents)
        with pytest.raises(OSError):
            # not empty dir
            bf.rmdir(new_dirpath)
        bf.remove(new_filepath)
        bf.rmdir(new_dirpath)
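Per the test above, rmdir on a missing name is an error for local paths but not for blob paths, and a non-empty directory raises OSError either way. A hypothetical blob-path call:

# Hypothetical path; for blob paths this deletes the empty directory marker.
bf.rmdir("gs://my-bucket/empty-dir")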
Example #15
def main(H: HParams):
    layout = H.model_spec.run_params.all_gpu_layout()

    # Instantiate policy
    policy = Policy(task_hparams=H.task, spec=H.model_spec, layout=layout)

    if H.orig_model_spec:
        assert H.orig_model_spec.run_params.n_shards == H.model_spec.run_params.n_shards
        orig_policy = Policy(task_hparams=H.task,
                             spec=H.orig_model_spec,
                             layout=layout)
    else:
        orig_policy = None

    encoder = policy.encoder
    response_encoder = ResponseEncoder(H.task.response, encoder)
    setup_logging_with_pacific_tz()

    act_dtype = torch.float16 if H.fp16_activations else torch.float32

    is_logging_rank = layout.is_logging_rank

    total_queries_per_replica = exact_div(H.num_queries, layout.n_replicas)
    num_runs = exact_div(total_queries_per_replica,
                         H.queries_per_run_per_replica)

    input_iter = task_data.get_iter_for_task(
        H.task,
        encoder=encoder,
        dataset_split=H.query_dataset_split,
        batch_size=H.queries_per_run_per_replica,
        layout=layout,
        seed=H.seed,
        all_fields=True,
    )

    log_dir = os.getenv("OUTPUT_DIR") or os.path.join("/tmp/jobs",
                                                      os.getenv("JOB_NAME"))

    results_dir = os.path.join(log_dir, "results")
    bf.makedirs(results_dir)
    with open(os.path.join(log_dir, "hparams.json"), "w") as f:
        json.dump(H.to_json(), f, indent=2)

    if is_logging_rank:
        with open(os.path.join(results_dir, "task_hparams.json"), "w") as f:
            json.dump(H.task.to_json(), f)
        with open(os.path.join(results_dir, "hparams.json"), "w") as f:
            json.dump(H.to_json(), f)

    # Create the output file; only the replica root writes samples, all other
    # ranks write to os.devnull
    local_file_name = os.devnull
    if layout.is_replica_root:
        fname = f"samples.{layout.replica_idx}.jsonl"
        local_file_name = os.path.join(results_dir, fname)
        print(f"Samples will be written to {local_file_name}")

    def prepare_eval_fn_and_inputs(tokens):
        def eval_fn(outputs_mb, eval_inputs_mb):
            logprobs = label_logprobs(logits=outputs_mb["logits"],
                                      labels=eval_inputs_mb["labels"])
            logprobs = torch.masked_fill(logprobs, eval_inputs_mb["mask"],
                                         INVALID_LOGPROB)
            return dict(logprobs=logprobs)

        mask = tokens == PADDING_TOKEN
        return eval_fn, dict(labels=torch.masked_fill(tokens, mask, 0),
                             mask=mask)

    runs_per_query = exact_div(H.responses_per_query,
                               H.responses_per_query_per_batch)
    with open(local_file_name, "w") as f:
        for run_idx in range(num_runs):
            with Timer() as timer:
                input = next(input_iter)
                context_tokens = input["context"]["tokens"]
                assert_shape_eq(
                    context_tokens,
                    (H.queries_per_run_per_replica, H.task.query.length),
                    "Context tokens shape mismatch",
                )
                ref_tokens = input["reference"]["tokens"].unsqueeze(1)
                assert_shape_eq(
                    ref_tokens,
                    (H.queries_per_run_per_replica, 1, H.task.response.length),
                    "Ref tokens shape mismatch",
                )

                # Sample from policy
                all_sample_results = []
                for _ in range(runs_per_query):
                    sample_results = policy.sample(
                        context_tokens,
                        responses_per_query=H.responses_per_query_per_batch,
                        sample_H=H.sample,
                        act_dtype=act_dtype,
                    )
                    assert_shape_eq(
                        sample_results["samples"],
                        (
                            H.queries_per_run_per_replica,
                            H.responses_per_query_per_batch,
                            H.task.response.length,
                        ),
                        "Samples size mismatch",
                    )

                    processed_samples = response_encoder.process_responses(
                        sample_results["samples"])

                    sample_results["processed_samples"] = processed_samples
                    assert_shape_eq(
                        processed_samples,
                        (
                            H.queries_per_run_per_replica,
                            H.responses_per_query_per_batch,
                            H.task.response.length,
                        ),
                        "Samples size mismatch",
                    )
                    sample_results["logprobs"] = torch.masked_fill(
                        sample_results["logprobs"],
                        processed_samples == PADDING_TOKEN,
                        INVALID_LOGPROB,
                    )
                    if orig_policy is not None:
                        eval_fn, eval_inputs = prepare_eval_fn_and_inputs(
                            processed_samples)
                        orig_eval_results = orig_policy.eval(
                            context_tokens,
                            processed_samples,
                            eval_fn=eval_fn,
                            eval_inputs=eval_inputs,
                            act_dtype=act_dtype,
                        )
                        sample_results["orig_eval_results"] = orig_eval_results

                    sample_results = map_nested(sample_results,
                                                lambda x: x.cpu().numpy())
                    all_sample_results.append(sample_results)

                eval_fn, eval_inputs = prepare_eval_fn_and_inputs(ref_tokens)
                ref_eval_results = policy.eval(
                    context_tokens,
                    ref_tokens,
                    eval_fn=eval_fn,
                    eval_inputs=eval_inputs,
                    act_dtype=act_dtype,
                )

                if orig_policy is not None:
                    orig_ref_eval_results = orig_policy.eval(
                        context_tokens,
                        ref_tokens,
                        eval_fn=eval_fn,
                        eval_inputs=eval_inputs,
                        act_dtype=act_dtype,
                    )

                if layout.is_replica_root:
                    for batch_idx in range(H.queries_per_run_per_replica):
                        context_tokens = sample_results["contexts"][batch_idx]
                        context = encoder.decode(context_tokens)

                        # Dump to a file so that we can use things in downstream tasks
                        # samples (written to file) is now a list of strings
                        d = dict(context=context,
                                 context_tokens=context_tokens)
                        d["sample_tokens"] = np.concatenate(
                            [
                                sample_results["processed_samples"][batch_idx]
                                for sample_results in all_sample_results
                            ],
                            axis=0,
                        )
                        assert_shape_eq(
                            d["sample_tokens"],
                            (H.responses_per_query, H.task.response.length),
                            "Sample tokens shape mismatch",
                        )
                        d["samples"] = response_encoder.decode_responses(
                            d["sample_tokens"])
                        d["logprobs"] = np.concatenate(
                            [
                                sample_results["logprobs"][batch_idx]
                                for sample_results in all_sample_results
                            ],
                            axis=0,
                        )
                        assert_shape_eq(
                            d["logprobs"],
                            (H.responses_per_query, H.task.response.length),
                            "Logprobs shape mismatch",
                        )
                        if orig_policy is not None:
                            d["orig_logprobs"] = np.concatenate(
                                [
                                    sample_results["orig_eval_results"]
                                    ["eval_stats"]["logprobs"][batch_idx]
                                    for sample_results in all_sample_results
                                ],
                                axis=0,
                            )
                            assert_shape_eq(
                                d["orig_logprobs"],
                                (H.responses_per_query,
                                 H.task.response.length),
                                "Orig logprobs shape mismatch",
                            )

                        # Process ref_tokens (H.queries_per_run_per_replica, H.task.response.length)
                        d["ref_tokens"] = ref_tokens[batch_idx].squeeze(
                            0).cpu().numpy()
                        d["ref"] = response_encoder.decode_response(
                            d["ref_tokens"])
                        assert_eq(
                            len(d["ref_tokens"]),
                            H.task.response.length,
                            "Ref tokens shape mismatch",
                        )
                        d["ref_logprobs"] = (
                            ref_eval_results["eval_stats"]["logprobs"]
                            [batch_idx].squeeze(0).cpu().numpy())
                        assert_eq(
                            len(d["ref_logprobs"]),
                            H.task.response.length,
                            "Ref logprobs shape mismatch",
                        )
                        if orig_policy is not None:
                            d["orig_ref_logprobs"] = (
                                orig_ref_eval_results["eval_stats"]["logprobs"]
                                [batch_idx].squeeze(0).cpu().numpy())
                            assert_eq(
                                len(d["orig_ref_logprobs"]),
                                H.task.response.length,
                                "Orig ref Logprobs shape mismatch",
                            )
                        if "extra_fields" in input:
                            d["extra_fields"] = input["extra_fields"][
                                batch_idx]

                        print("=" * 80)
                        replica_sample_idx = run_idx * H.queries_per_run_per_replica + batch_idx
                        print(
                            f"RESULT {replica_sample_idx} of {total_queries_per_replica}"
                        )
                        print(f"CONTEXT:")
                        print(context)
                        print(f"REF:")
                        print(d["ref"])
                        print("avg logprob", avg_negative(d["ref_logprobs"]))
                        if orig_policy is not None:
                            print("avg orig logprob",
                                  avg_negative(d["orig_ref_logprobs"]))
                        for sample_idx in range(H.responses_per_query):
                            print(f"SAMPLE {sample_idx}:")
                            print(d["samples"][sample_idx])
                            print("avg logprob",
                                  avg_negative(d["logprobs"][sample_idx]))
                            if orig_policy is not None:
                                print(
                                    "avg orig logprob",
                                    avg_negative(
                                        d["orig_logprobs"][sample_idx]))

                        f.write((json.dumps(jsonl_encoding.encode_example(d)) +
                                 "\n"))
            if layout.is_replica_root:
                print(
                    f"Batch {run_idx+1} of {num_runs}.  Took {timer.interval} seconds"
                )

    return dict(output_path=results_dir)
Example #16
def test_invalid_paths(base_path):
    for suffix in ["", "/", "//", "/invalid.file", "/invalid/dir/"]:
        path = base_path + suffix
        print(path)
        if path.endswith("/"):
            expected_error = IsADirectoryError
        else:
            expected_error = FileNotFoundError
        list(bf.glob(path))
        if suffix == "":
            for pattern in ["*", "**"]:
                try:
                    list(bf.glob(path + pattern))
                except bf.Error as e:
                    assert "Wildcards cannot be used" in e.message
        else:
            for pattern in ["*", "**"]:
                list(bf.glob(path + pattern))
        with pytest.raises(FileNotFoundError):
            list(bf.listdir(path))
        assert not bf.exists(path)
        assert not bf.isdir(path)
        with pytest.raises(expected_error):
            bf.remove(path)
        if suffix in ("", "/"):
            try:
                bf.rmdir(path)
            except bf.Error as e:
                assert "Cannot delete bucket" in e.message
        else:
            bf.rmdir(path)
        with pytest.raises(NotADirectoryError):
            bf.rmtree(path)
        with pytest.raises(FileNotFoundError):
            bf.stat(path)

        if base_path == AZURE_INVALID_CONTAINER_NO_ACCOUNT:
            with pytest.raises(bf.Error):
                bf.get_url(path)
        else:
            bf.get_url(path)

        with pytest.raises(FileNotFoundError):
            bf.md5(path)
        with pytest.raises(bf.Error):
            bf.makedirs(path)
        list(bf.walk(path))
        with tempfile.TemporaryDirectory() as tmpdir:
            local_path = os.path.join(tmpdir, "test.txt")
            with pytest.raises(expected_error):
                bf.copy(path, local_path)
            with open(local_path, "w") as f:
                f.write("meow")
            with pytest.raises(expected_error):
                bf.copy(local_path, path)
        for streaming in [False, True]:
            with pytest.raises(expected_error):
                with bf.BlobFile(path, "rb", streaming=streaming) as f:
                    f.read()
            with pytest.raises(expected_error):
                with bf.BlobFile(path, "wb", streaming=streaming) as f:
                    f.write(b"meow")
Example #17
def main(H: HParams):
    layout = H.reward_model_spec.run_params.all_gpu_layout()

    reward_model = RewardModel(task_hparams=H.task,
                               spec=H.reward_model_spec,
                               layout=layout)

    setup_logging_with_pacific_tz()

    act_dtype = torch.float16 if H.fp16_activations else torch.float32

    results_dir = bf.join(
        os.environ.get("OUTPUT_DIR",
                       os.path.join("/tmp/jobs", os.getenv("JOB_NAME"))),
        "results")
    bf.makedirs(results_dir)

    if layout.is_logging_rank:
        with open(bf.join(results_dir, "task_hparams.json"), "w") as f:
            json.dump(H.task.to_json(), f)
        with open(bf.join(results_dir, "hparams.json"), "w") as f:
            json.dump(H.to_json(), f)

    # Create the output file; only the replica root writes outputs, all other
    # ranks write to os.devnull
    output_file_name = os.devnull
    if layout.is_replica_root:
        fname = f"samples.{layout.replica_idx}.jsonl"
        output_file_name = bf.join(results_dir, fname)
        print(f"Outputs will be written to {output_file_name}")

    input_iter = make_jsonl_samples_iter(H.input_path, layout=layout)

    replica_rewards = []

    with open(output_file_name, "a") as out_f:
        input_idx = 0
        for input in input_iter:
            with Timer() as timer:
                query_tokens = torch.tensor(input["context_tokens"])
                assert_shape_eq(query_tokens, (H.task.query.length, ),
                                "Context tokens shape mismatch")
                response_tokens = torch.tensor(input["sample_tokens"])
                assert_eq(response_tokens.dim(), 2)

                n_responses = response_tokens.size(0)

                results = reward_model.reward(
                    query_tokens=query_tokens.unsqueeze(0),
                    response_tokens=response_tokens.unsqueeze(0),
                    act_dtype=act_dtype,
                )

                rewards = to_numpy(results["reward"].reshape((n_responses, )))

                if layout.is_replica_root:
                    replica_rewards.append(rewards)

                    output = {**input, H.output_key: rewards}
                    out_f.write(
                        (json.dumps(jsonl_encoding.encode_example(output)) +
                         "\n"))
            input_idx += 1
            if layout.is_replica_root:
                print(f"Batch {input_idx}.  Took {timer.interval} seconds")

        if layout.is_replica_root:
            print(f"Wrote {input_idx} batches to {output_file_name}")

            replica_rewards = np.stack(replica_rewards, axis=0)
            all_rewards = reward_model.dp_comm.mpi_all_gather(
                replica_rewards, "rewards")
            if layout.replica_idx == 0:
                all_rewards = np.concatenate(all_rewards, axis=0)
                print(f"Mean reward: {all_rewards.mean():.3f}")
                if all_rewards.shape[1] > 1:
                    print(
                        f"Stddev within a query: {all_rewards.std(axis=1, ddof=1).mean():.3}"
                    )
                print(
                    f"Stddev across queries: {all_rewards.std(axis=0, ddof=1).mean():.3}"
                )

    return dict(output_path=results_dir)