Example #1
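    # Evaluates the synthesizer from two worker processes that share one init
    # directory and asserts both runs return identical results.  Assumes a
    # module-level `context = multiprocessing.get_context("spawn")` and a
    # `_run` helper defined elsewhere in the test class.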
    def test_multiprocess(self):
        accuracy = use_environment(
            Accuracy(), in_keys=["actual", ["ground_truth", "expected"]],
            value_key="actual"
        )
        dataset = ListDataset([
            Environment(
                {"query": "query0", "ground_truth": "c0"},
                set(["ground_truth"])
            ),
            Environment(
                {"query": "query1", "ground_truth": "c0"},
                set(["ground_truth"])
            ),
            Environment(
                {"query": "query2", "ground_truth": "c0"},
                set(["ground_truth"])
            ),
        ])

        with tempfile.TemporaryDirectory() as init_dir:
            with context.Pool(2) as pool:
                futures = []
                for i in range(2):
                    f = pool.apply_async(
                        self._run,
                        args=(init_dir, dataset, {"accuracy": accuracy}, i),
                    )
                    futures.append(f)
                out = [f.get() for f in futures]
        r0 = out[0]
        r1 = out[1]

        assert r0 == r1

        results = r0
        assert results.metrics == {1: {"accuracy": 1.0 / 3},
                                   3: {"accuracy": 2.0 / 3}}
        assert 3 == len(results.results)
        results.results[0].time = 0.0
        results.results[1].time = 0.0
        results.results[2].time = 0.0
        results.results.sort(key=lambda x: x.sample["query"])
        assert Result({"query": "query0",
                       "ground_truth": "c0"},
                      ["c0", "c1", "c2"],
                      {1: {"accuracy": 1.0}, 3: {"accuracy": 1.0}},
                      True, 0.0) == results.results[0]
        assert Result({"query": "query1",
                       "ground_truth": "c0"},
                      ["c2", "c3", "c0"],
                      {1: {"accuracy": 0.0}, 3: {"accuracy": 1.0}},
                      True, 0.0) == results.results[1]
        assert Result({"query": "query2",
                       "ground_truth": "c0"},
                      ["c2", "c3", "c5"],
                      {1: {"accuracy": 0.0}, 3: {"accuracy": 0.0}},
                      True, 0.0) == results.results[2]
Example #2
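 # UpdateInput copies the latest interpreter context into the "code" key.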
 def test_happy_path(self):
     f = UpdateInput()
     entry = f(
         Environment(
             {"interpreter_state": BatchedState({}, {}, [], ["xxx\nyyy"])}))
     assert entry["code"] == "xxx\nyyy"
     state = BatchedState({}, {Diff([]): ["foo"]}, [Diff([])], ["foo"])
     entry = f(Environment({"interpreter_state": state}))
     assert entry["code"] == "foo"
Example #3
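    # Splits a batched Environment back into one Environment per batch
    # element; the batch size B is inferred from the first key with a known
    # layout.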
    def split(self, values: Environment) -> Sequence[Environment]:
        retval: List[Environment] = []
        B = None
        for key, t in values.items():
            if t is None:
                continue
            if key in self.options:
                option = self.options[key]
                if option.use_pad_sequence:
                    B = t.data.shape[1]
                else:
                    B = t.data.shape[option.dim]
                break
            elif isinstance(t, list):
                B = len(t)
                break

        assert B is not None

        # allocate one empty Environment per batch element
        for _ in range(B):
            retval.append(Environment())

        # scatter each batched value across the B per-element environments
        for key, t in values.items():
            if t is None:
                for b in range(B):
                    retval[b][key] = None
                continue
            if key in self.options:
                option = self.options[key]
                if option.use_pad_sequence:
                    for b in range(B):
                        # keep only the non-padded steps of element b
                        inds = torch.nonzero(t.mask[:, b], as_tuple=False)
                        data = t.data[:, b]
                        shape = data.shape[1:]
                        data = data[inds]
                        data = data.reshape(-1, *shape)
                        retval[b][key] = data
                else:
                    shape = list(t.data.shape)
                    del shape[option.dim]
                    data = torch.split(t, 1, dim=option.dim)
                    for b in range(B):
                        d = data[b]
                        if len(shape) == 0:
                            d = d.reshape(())
                        else:
                            d = d.reshape(*shape)
                        retval[b][key] = d
            elif isinstance(t, list):
                for b in range(B):
                    retval[b][key] = t[b]
            else:
                logger.debug(f"{key} has an invalid type: {type(t)}")

        return retval
Example #4
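    # Resets entry[self.key] to the initial value while training, or whenever
    # the key has not been set yet.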
    def __call__(self, entry: Environment) -> Environment:
        entry = cast(Environment, entry.clone())
        train = True
        if "train" in entry:
            train = entry["train"]
        if "action_sequence" in entry:
            # an action_sequence marked as supervision implies training mode
            train = entry.is_supervision("action_sequence")
        if train or self.key not in entry:
            entry[self.key] = self.initial

        return entry
Example #5
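    # Environment.to should forward its arguments to every stored value that
    # implements .to(), such as tensors and PaddedSequenceWithMask.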
    def test_to(self) -> None:
        class X:
            def to(self, *args, **kwargs):
                self.args = (args, kwargs)
                return self

        e = Environment()
        e["key"] = X()
        e["x"] = torch.tensor(0)
        e["y"] = PaddedSequenceWithMask(torch.tensor(0.0), torch.tensor(True))
        e["z"] = 10
        e.to(device=torch.device("cpu"))
        assert e["key"].args == ((), {"device": torch.device("cpu")})
Example #6
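    # Builds a fake DeepFix archive (SQLite DB -> gzip -> zip), downloads it
    # once, then checks the cache is reused: the second call must never reach
    # get2.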
    def test_download(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            cache_path = os.path.join(tmpdir, "cache.pt")
            sqlitefile = os.path.join(tmpdir, "dataset.db")
            conn = sqlite3.connect(sqlitefile)
            c = conn.cursor()
            c.execute(
                "CREATE TABLE Code(code text, error text, errorcount int)")
            c.execute("INSERT INTO Code VALUES ('foo', 'bar', 1)")
            c.execute("INSERT INTO Code VALUES ('foo', '', 0)")
            conn.commit()
            conn.close()

            gzipfile = os.path.join(tmpdir, "dataset.gz")
            with gzip.open(gzipfile, "wb") as file, \
                    open(sqlitefile, "rb") as src_file:
                copyfileobj(src_file, file)
            path = os.path.join(tmpdir, "dataset.zip")
            with zipfile.ZipFile(path, "w") as z:
                with z.open(os.path.join("prutor-deepfix-09-12-2017",
                                         "prutor-deepfix-09-12-2017.db.gz"),
                            "w") as dst_file, \
                        open(gzipfile, "rb") as src_file:
                    copyfileobj(src_file, dst_file)

            def get(src, dst):
                copyfile(src, dst)

            dataset0 = download(cache_path=cache_path, path=path, get=get)

            def get2(src, dst):
                raise NotImplementedError

            dataset1 = download(cache_path=cache_path, path=path, get=get2)

        assert 2 == len(dataset0)
        assert dataset0[0] == Environment(
            {
                "code": "foo",
                "error": "bar",
                "n_error": 1
            }, set(["error", "n_error"]))
        assert dataset0[1] == Environment(
            {
                "code": "foo",
                "error": "",
                "n_error": 0
            }, set(["error", "n_error"]))

        assert list(dataset0) == list(dataset1)
Example #7
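    # Builds the dataset from raw strings; "§" appears to be the marker the
    # loader expands into a newline, and a trailing backslash a continuation
    # it strips.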
    def test_download(self):
        values = [
            "line0\n",
            "x = 10\n",
            "line1\n",
            "if True:§  pass\n",
            "line2\n",
            "if True and \\True:§  pass\n",
        ]

        with tempfile.TemporaryDirectory() as tmpdir:
            cache_path = os.path.join(tmpdir, "cache.pt")

            def get(path):
                return values.pop(0)

            dataset0 = download(get=get, cache_path=cache_path)

            def get2(path):
                raise NotImplementedError

            dataset1 = download(get=get2, cache_path=cache_path)
        train_dataset = dataset0["train"]
        test_dataset = dataset0["test"]
        valid_dataset = dataset0["valid"]

        assert 1 == len(train_dataset)
        assert list(train_dataset) == list(dataset1["train"])
        assert train_dataset[0] == Environment(
            {
                "text_query": "line0",
                "ground_truth": "x = 10"
            }, set(["ground_truth"]))

        assert 1 == len(test_dataset)
        assert list(test_dataset) == list(dataset1["test"])
        assert test_dataset[0] == Environment(
            {
                "text_query": "line1",
                "ground_truth": "if True:\n  pass"
            }, set(["ground_truth"]))

        assert 1 == len(valid_dataset)
        assert list(valid_dataset) == list(dataset1["valid"])
        assert valid_dataset[0] == Environment(
            {
                "text_query": "line2",
                "ground_truth": "if True and True:\n  pass"
            }, set(["ground_truth"]))
Example #8
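    # Apply propagates supervision: "out" is marked as supervision only when
    # at least one of its inputs is.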
    def test_propagate_supervision(self):
        apply = Apply(["x", "y"], "out", MockModule(1))
        output = apply(
            Environment({
                "x": torch.arange(3).reshape(-1, 1),
                "y": 10
            }))
        assert not output.is_supervision("out")

        output = apply(
            Environment({
                "x": torch.arange(3).reshape(-1, 1),
                "y": 10
            }, set(["x"])))
        assert output.is_supervision("out")
Example #9
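    # AddActions should produce the expected per-step action tensor for the
    # "y = x + 1" action sequence in inference mode (train=False).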
    def test_eval(self):
        entries = [Environment(
            {"text_query": "foo bar", "ground_truth": "y = x + 1"},
            set(["ground_truth"])
        )]
        dataset = ListDataset(entries)
        d = get_samples(dataset, MockParser())
        aencoder = ActionSequenceEncoder(d, 0)
        action_sequence = GroundTruthToActionSequence(MockParser())(
            "y = x + 1"
        )
        transform = AddActions(aencoder)
        action_tensor = transform(
            reference=[Token(None, "foo", "foo"), Token(None, "bar", "bar")],
            action_sequence=action_sequence,
            train=False
        )

        assert np.array_equal(
            [
                [2, 2, 0], [4, 3, 1], [6, 4, 2], [6, 4, 2], [5, 3, 1],
                [6, 5, 5], [6, 5, 5], [5, 5, 5], [6, 4, 8], [6, 4, 8],
                [5, 5, 5], [9, 6, 11], [9, 6, 11], [-1, -1, -1]
            ],
            action_tensor.numpy()
        )
Example #10
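 # Iterator stub: each call yields an Environment whose "value" and
 # "ground_truth" hold the incremented counter.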
 def __next__(self) -> Environment:
     self.n += 1
     return Environment(
         {
             "value": torch.tensor(self.n),
             "ground_truth": torch.tensor(self.n)
         }, set(["ground_truth"]))
Example #11
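 # constants= lets Apply feed fixed tensors to the module even when the
 # Environment itself is empty.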
 def test_constants(self):
     apply = Apply([],
                   "out",
                   MockModule(1),
                   constants={"x": torch.arange(3).reshape(-1, 1)})
     output = apply(Environment())
     assert np.array_equal([[1], [2], [3]], output["out"].detach().numpy())
Example #12
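 # EncodeActionSequence should turn the ground-truth action sequence into
 # the expected target tensor.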
 def test_simple_case(self):
     entries = [Environment(
         {"ground_truth": "y = x + 1"},
         set(["ground_truth"])
     )]
     dataset = ListDataset(entries)
     d = get_samples(dataset, MockParser())
     aencoder = ActionSequenceEncoder(d, 0)
     action_sequence = GroundTruthToActionSequence(MockParser())(
         ground_truth="y = x + 1"
     )
     transform = EncodeActionSequence(aencoder)
     ground_truth = transform(
         action_sequence=action_sequence,
         reference=[Token(None, "foo", "foo"), Token(None, "bar", "bar")],
     )
     assert np.array_equal(
         [
             [3, -1, -1], [4, -1, -1], [-1, 1, -1], [1, -1, -1],
             [5, -1, -1], [-1, 2, -1], [1, -1, -1], [4, -1, -1],
             [-1, 3, -1], [1, -1, -1], [6, -1, -1], [-1, 4, -1],
             [1, -1, -1]
         ],
         ground_truth.numpy()
     )
Example #13
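 # AddQueryForTreeGenDecoder should emit, per step, the ids of the ancestor
 # actions consumed by the TreeGen decoder (3 presumably bounds the depth).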
 def test_eval(self):
     entries = [Environment(
         {"text_query": "ab test", "ground_truth": "y = x + 1"},
         set(["ground_truth"])
     )]
     dataset = ListDataset(entries)
     d = get_samples(dataset, MockParserWithoutVariadicArgs())
     aencoder = ActionSequenceEncoder(d, 0)
     action_sequence = GroundTruthToActionSequence(MockParserWithoutVariadicArgs())(
         "y = x + 1"
     )
     transform = AddQueryForTreeGenDecoder(aencoder, 3)
     query = transform(
         reference=[Token(None, "ab", "ab"), Token(None, "test", "test")],
         action_sequence=action_sequence,
         train=False
     )
     assert np.array_equal(
         [
             [-1, -1, -1], [2, -1, -1], [3, 2, -1], [4, 3, 2],
             [3, 2, -1], [5, 3, 2], [5, 3, 2], [4, 5, 3],
             [5, 3, 2], [6, 5, 3]
         ],
         query.numpy()
     )
Example #14
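 # AddActionSequenceAsTree should yield the node depths and the parent/child
 # adjacency matrix of the partial AST.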
 def test_eval(self):
     entries = [Environment(
         {"text_query": "ab test", "ground_truth": "y = x + 1"},
         set(["ground_truth"])
     )]
     dataset = ListDataset(entries)
     d = get_samples(dataset, MockParserWithoutVariadicArgs())
     aencoder = ActionSequenceEncoder(d, 0)
     action_sequence = GroundTruthToActionSequence(MockParserWithoutVariadicArgs())(
         "y = x + 1"
     )
     transform = AddActionSequenceAsTree(aencoder)
     matrix, depth = transform(
         reference=[Token(None, "ab", "ab"), Token(None, "test", "test")],
         action_sequence=action_sequence,
         train=False
     )
     assert np.array_equal(
         [0, 1, 2, 3, 2, 3, 3, 4, 3, 4],
         depth.numpy()
     )
     assert np.array_equal(
         [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 1, 0, 1, 0, 0, 0, 0, 0],
          [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 1, 1, 0, 1, 0],
          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
         matrix.numpy()
     )
Example #15
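 # With n_dependent=3, only the rules for the last three dependent steps
 # should be returned (the 2 is presumably the maximum rule arity).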
 def test_n_dependent(self):
     entries = [Environment(
         {"text_query": "ab test", "ground_truth": "y = x + 1"},
         set(["ground_truth"])
     )]
     dataset = ListDataset(entries)
     d = get_samples(dataset, MockParserWithoutVariadicArgs())
     aencoder = ActionSequenceEncoder(d, 0)
     action_sequence = GroundTruthToActionSequence(MockParserWithoutVariadicArgs())(
         "y = x + 1"
     )
     transform = AddPreviousActionRules(aencoder, 2, n_dependent=3)
     prev_rule_action = transform(
         reference=[Token(None, "ab", "ab"), Token(None, "test", "test")],
         action_sequence=action_sequence,
         train=False,
     )
     assert np.array_equal(
         [
             # str -> "y"
             [[-1, -1, -1], [-1, 3, -1], [-1, -1, -1]],
             # Number -> number
             [[8, -1, -1], [9, -1, -1], [-1, -1, -1]],
             [[-1, -1, -1], [-1, 4, -1], [-1, -1, -1]],
         ],
         prev_rule_action.numpy()
     )
Example #16
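    # Fetches the DeepFix archive, unpacks zip -> gzip -> SQLite, and turns
    # each row of the Code table into an Environment sample.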
    def _download():
        logger.info("Download DeepFix dataset")
        logger.debug(f"Dataset path: {path}")
        samples = []
        with tempfile.TemporaryDirectory() as tmpdir:
            dst = os.path.join(tmpdir, "dataset.zip")
            get(path, dst)

            gzipfile = os.path.join(tmpdir, "dataset.gz")
            with zipfile.ZipFile(dst) as z:
                with z.open(os.path.join("prutor-deepfix-09-12-2017",
                                         "prutor-deepfix-09-12-2017.db.gz"),
                            "r") as file, \
                        open(gzipfile, "wb") as dst_file:
                    copyfileobj(file, dst_file)
            sqlitefile = os.path.join(tmpdir, "dataset.db")
            with gzip.open(gzipfile, "rb") as src_file, \
                    open(sqlitefile, "wb") as dst_file:
                copyfileobj(src_file, dst_file)

            conn = sqlite3.connect(sqlitefile)
            c = conn.cursor()
            for code, error, errorcount in \
                    c.execute("SELECT code, error, errorcount FROM Code"):
                samples.append(
                    Environment(
                        {
                            "code": code,
                            "error": error,
                            "n_error": errorcount,
                        }, set(["error", "n_error"])))
            conn.close()
        return samples
Example #17
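# Minimal one-sample dataset fixture.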
def dataset():
    return ListDataset([
        Environment({
            "query": "query",
            "ground_truth": "name0"
        }, set(["ground_truth"]))
    ])
Example #18
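    # AddPreviousActions should produce the expected previous-action tensor
    # for the same "y = x + 1" sequence in inference mode.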
    def test_eval(self):
        entries = [Environment(
            {"ground_truth": "y = x + 1"},
            set(["ground_truth"])
        )]
        dataset = ListDataset(entries)
        d = get_samples(dataset, MockParser())
        aencoder = ActionSequenceEncoder(d, 0)
        action_sequence = GroundTruthToActionSequence(MockParser())(
            "y = x + 1"
        )
        transform = AddPreviousActions(aencoder)
        prev_action_tensor = transform(
            reference=[Token(None, "foo", "foo"), Token(None, "bar", "bar")],
            action_sequence=action_sequence,
            train=False
        )

        assert np.array_equal(
            [
                [2, -1, -1], [3, -1, -1], [4, -1, -1], [-1, 1, -1],
                [1, -1, -1], [5, -1, -1], [-1, 2, -1], [1, -1, -1],
                [4, -1, -1], [-1, 3, -1], [1, -1, -1], [6, -1, -1],
                [-1, 4, -1], [1, -1, -1]
            ],
            prev_action_tensor.numpy()
        )
Example #19
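 # batch_k_samples should yield one state per candidate AST, scored 1, 1/2,
 # and 1/3 (presumably 1 / rank).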
 def test_ast_set_sample(self):
     asts = ["c0", "c1", "c2"]
     sampler = SequentialProgramSampler(MockSynthesizer(asts),
                                        transform_input, Collate(),
                                        MockEncoder(), MockExpander(),
                                        MockInterpreter())
     zero = SamplerState(0, sampler.initialize([(None, None)]))
     samples = list(sampler.batch_k_samples([zero], [3]))
     samples.sort(key=lambda x: -x.state.score)
     assert 3 == len(samples)
     assert samples[0] == DuplicatedSamplerState(
         SamplerState(
             1,
             Environment({
                 "test_cases": [(None, None)],
                 "reference": [Token(None, str(asts[0]), str(asts[0]))],
                 "variables": [["#" + str(asts[0])]],
                 "interpreter_state":
                 BatchedState({str(asts[0]): None},
                              {str(asts[0]): ["#" + str(asts[0])]},
                              [str(asts[0])], ["#" + str(asts[0])])
             })), 1)
     assert DuplicatedSamplerState(
         SamplerState(
             0.5,
             Environment({
                 "test_cases": [(None, None)],
                 "reference": [Token(None, str(asts[1]), str(asts[1]))],
                 "variables": [["#" + str(asts[1])]],
                 "interpreter_state":
                 BatchedState({str(asts[1]): None},
                              {str(asts[1]): ["#" + str(asts[1])]},
                              [str(asts[1])], ["#" + str(asts[1])])
             })), 1) == samples[1]
     assert DuplicatedSamplerState(
         SamplerState(
             1.0 / 3,
             Environment({
                 "test_cases": [(None, None)],
                 "reference": [Token(None, str(asts[2]), str(asts[2]))],
                 "variables": [["#" + str(asts[2])]],
                 "interpreter_state":
                 BatchedState({str(asts[2]): None},
                              {str(asts[2]): ["#" + str(asts[2])]},
                              [str(asts[2])], ["#" + str(asts[2])])
             })), 1) == samples[2]
Example #20
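# Writes a minimal dummy nl2bash dataset (the same sample for every split).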
def create_dummy_nl2bash(path):
    data = [
        Environment({
            "text_query": "foo",
            "ground_truth": "echoa"
        }, set(["ground_truth"]))
    ]
    torch.save({"train": data, "test": data, "valid": data}, path)
Example #21
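# Writes a minimal dummy Hearthstone dataset (the same sample for every split).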
def create_dummy_hearthstone(path):
    data = [
        Environment({
            "text_query": "foo",
            "ground_truth": "1 + 1"
        }, set(["ground_truth"]))
    ]
    torch.save({"train": data, "test": data, "valid": data}, path)
Example #22
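 # Converts an (annotation, code) pair into an Environment, marking the code
 # as supervision.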
 def to_sample(elem: Tuple[str, str]) -> Environment:
     anno, code = elem
     return Environment(
         {
             "text_query": anno,
             "ground_truth": code
         },
         set(["ground_truth"])
     )
Example #23
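 # Apply broadcasts the scalar y over the tensor x; MockModule(1) presumably
 # adds its constructor argument to the sum of its inputs (0 + 10 + 1 == 11).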
 def test_multiple_inputs(self):
     apply = Apply(["x", "y"], "out", MockModule(1))
     output = apply(
         Environment({
             "x": torch.arange(3).reshape(-1, 1),
             "y": 10
         }))
     assert np.array_equal([[11], [12], [13]],
                           output["out"].detach().numpy())
Example #24
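    # Single-process counterpart of Example #1: evaluates three queries and
    # checks the accuracy at k=1 and k=3.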
    def test_simple_case(self):
        accuracy = use_environment(
            Accuracy(), in_keys=["actual", ["ground_truth", "expected"]],
            value_key="actual"
        )
        dataset = ListDataset([
            Environment(
                {"query": "query0", "ground_truth": "c0"},
                set(["ground_truth"])
            ),
            Environment(
                {"query": "query1", "ground_truth": "c0"},
                set(["ground_truth"])
            ),
            Environment(
                {"query": "query2", "ground_truth": "c0"},
                set(["ground_truth"])
            ),
        ])
        results = EvaluateSynthesizer(dataset, synthesize,
                                      metrics={"accuracy": accuracy})()

        assert results.metrics == \
            {1: {"accuracy": 1.0 / 3.0}, 3: {"accuracy": 2.0 / 3.0}}
        assert 3 == len(results.results)
        results.results[0].time = 0.0
        results.results[1].time = 0.0
        results.results[2].time = 0.0
        assert Result({"query": "query0",
                       "ground_truth": "c0"},
                      ["c0", "c1", "c2"],
                      {1: {"accuracy": 1.0}, 3: {"accuracy": 1.0}},
                      True, 0.0) == results.results[0]
        assert Result({"query": "query1",
                       "ground_truth": "c0"},
                      ["c2", "c3", "c0"],
                      {1: {"accuracy": 0.0}, 3: {"accuracy": 1.0}},
                      True, 0.0) == results.results[1]
        assert Result({"query": "query2",
                       "ground_truth": "c0"},
                      ["c2", "c3", "c5"],
                      {1: {"accuracy": 0.0}, 3: {"accuracy": 0.0}},
                      True, 0.0) == results.results[2]
Example #25
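    # Scores a candidate program by running it on every test-case input and
    # averaging self.metric against the expected outputs.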
    def forward(self, test_cases: List[Tuple[Input, Kind]],
                actual: Code) -> float:
        inputs = [input for input, _ in test_cases]
        outputs = [output for _, output in test_cases]

        # calc. metric
        m = 0.0  # TODO reduction function is required
        # evaluate up front and use a distinct loop variable so that the
        # `actual` parameter is not shadowed inside the loop
        actual_outputs = self.interpreter.eval(actual, inputs)
        for actual_output, expected in zip(actual_outputs, outputs):
            m += self.metric(Environment({"expected": expected}),
                             actual_output)
        return m / len(outputs)
Example #26
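 # ToEpisode expands a Diff ground truth into one episode step per edit,
 # threading the interpreter state between steps.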
 def test_happy_path(self):
     to_episode = ToEpisode(Interpreter(), Expander())
     episode = to_episode(
         Environment(
             {
                 "test_cases": [("xxx\nyyy", None)],
                 "ground_truth": Diff([Replace(0, "zzz"),
                                       Remove(1)])
             }, set(["ground_truth"])))
     assert len(episode) == 2
     assert episode[0]["interpreter_state"].context == ["xxx\nyyy"]
     assert episode[1]["interpreter_state"].context == ["zzz\nyyy"]
Example #27
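 # create_output renders the interpreter context: None while it is empty,
 # otherwise the joined lines plus a flag (False presumably meaning the
 # search is not yet finished).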
 def test_create_output(self):
     sampler = SequentialProgramSampler(MockSynthesizer([]),
                                        transform_input, Collate(),
                                        MockEncoder(), MockExpander(),
                                        MockInterpreter())
     assert sampler.create_output(
         None,
         Environment({
             "interpreter_state": BatchedState({}, {}, [], [None])
         })) is None
     assert sampler.create_output(
         None,
         Environment(
             {"interpreter_state": BatchedState({}, {}, ["tmp"],
                                                [None])})) == ("tmp", False)
     assert sampler.create_output(
         None,
         Environment({
             "interpreter_state":
             BatchedState({}, {}, ["line0", "line1"], [None])
         })) == ("line0\nline1", False)
Example #28
    def test_rule(self):
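        # Top-k sampling must pick the most probable rule (Root2X, p=0.2);
        # batch_k_samples / all_samples may return Root2X or Root2Y.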
        rule_prob = torch.tensor([
            [[
                1.0,  # unknown
                1.0,  # close variadic field
                0.2,  # Root2X
                0.1,  # Root2Y
                1.0,  # X2Y_list
                1.0,  # Ysub2Str
            ]],
            [[
                1.0,  # unknown
                1.0,  # close variadic field
                1.0,  # Root2X
                1.0,  # Root2Y
                0.5,  # X2Y_list
                1.0,  # Ysub2Str
            ]]])
        token_prob = torch.tensor([[[]], [[]]])
        reference_prob = torch.tensor([[[]], [[]]])
        sampler = ActionSequenceSampler(
            create_encoder(),
            is_subtype,
            create_transform_input([]), transform_action_sequence,
            collate,
            Module(encoder_module,
                   DecoderModule(rule_prob, token_prob, reference_prob))
        )
        s = SamplerState(0.0, sampler.initialize(Environment()))
        topk_results = list(sampler.top_k_samples([s], 1))
        assert 1 == len(topk_results)
        assert 1 == topk_results[0].state.state["length"].item()
        assert np.allclose(log(0.2), topk_results[0].state.score)
        random_results = list(sampler.batch_k_samples([s], [1]))
        assert 1 == len(random_results)
        assert 1 == random_results[0].state.state["length"].item()
        assert \
            log(0.1) - 1e-5 <= random_results[0].state.score <= log(0.2) + 1e-5
        all_results = list(sampler.all_samples([s]))
        assert 2 == len(all_results)
        assert 1 == all_results[0].state.state["length"].item()
        assert np.allclose(log(0.2), all_results[0].state.score)
        assert np.allclose(log(0.1), all_results[1].state.score)
        all_results = list(sampler.all_samples([s], sorted=False))
        assert 1 == all_results[0].state.state["length"].item()
        assert \
            log(0.1) - 1e-5 <= all_results[0].state.score <= log(0.2) + 1e-5

        # avoid shadowing the builtin `next`
        next_result = list(sampler.top_k_samples(
            [s.state for s in topk_results], 1))[0]
        assert 2 == next_result.state.state["length"].item()
        assert np.allclose(log(0.2) + log(0.5), next_result.state.score)