    def test_rule(self):
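        # rule_prob holds the decoder output for two consecutive decoding
        # steps; at the first step only Root2X (0.2) and Root2Y (0.1) are
        # applicable to the root node, so the other entries are never proposed.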
        rule_prob = torch.tensor([
            [[
                1.0,  # unknown
                1.0,  # close variadic field
                0.2,  # Root2X
                0.1,  # Root2Y
                1.0,  # X2Y_list
                1.0,  # Ysub2Str
            ]],
            [[
                1.0,  # unknown
                1.0,  # close variadic field
                1.0,  # Root2X
                1.0,  # Root2Y
                0.5,  # X2Y_list
                1.0,  # Ysub2Str
            ]]])
        token_prob = torch.tensor([[[]], [[]]])
        reference_prob = torch.tensor([[[]], [[]]])
        sampler = ActionSequenceSampler(
            create_encoder(),
            is_subtype,
            create_transform_input([]), transform_action_sequence,
            collate,
            Module(encoder_module,
                   DecoderModule(rule_prob, token_prob, reference_prob))
        )
        s = SamplerState(0.0, sampler.initialize(Environment()))
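        # top_k_samples with k=1 keeps only the best candidate (Root2X),
        # whose log-probability is log(0.2).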
        topk_results = list(sampler.top_k_samples([s], 1))
        assert 1 == len(topk_results)
        assert 1 == topk_results[0].state.state["length"].item()
        assert np.allclose(log(0.2), topk_results[0].state.score)
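        # batch_k_samples draws one candidate at random, so its score is
        # either log(0.2) (Root2X) or log(0.1) (Root2Y).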
        random_results = list(sampler.batch_k_samples([s], [1]))
        assert 1 == len(random_results)
        assert 1 == random_results[0].state.state["length"].item()
        assert \
            log(0.1) - 1e-5 <= random_results[0].state.score <= log(0.2) + 1e-5
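        # all_samples enumerates every applicable action, sorted by score in
        # descending order unless sorted=False is passed.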
        all_results = list(sampler.all_samples([s]))
        assert 2 == len(all_results)
        assert 1 == all_results[0].state.state["length"].item()
        assert np.allclose(log(0.2), all_results[0].state.score)
        assert np.allclose(log(0.1), all_results[1].state.score)
        all_results = list(sampler.all_samples([s], sorted=False))
        assert 1 == all_results[0].state.state["length"].item()
        assert \
            log(0.1) - 1e-5 <= all_results[0].state.score <= log(0.2) + 1e-5

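        # Expanding the best state by one more step accumulates the scores:
        # log(0.2) for Root2X plus log(0.5) for X2Y_list.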
        next_state = list(sampler.top_k_samples(
            [s.state for s in topk_results], 1))[0]
        assert 2 == next_state.state.state["length"].item()
        assert np.allclose(log(0.2) + log(0.5), next_state.state.score)
    def test_initialize(self):
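        # initialize() builds the initial state: a single action in the
        # action sequence, a zero-valued input tensor, and no references.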
        sampler = ActionSequenceSampler(
            create_encoder(),
            is_subtype,
            create_transform_input([]), transform_action_sequence,
            collate,
            Module(encoder_module,
                   DecoderModule([], [], []))
        )
        s = sampler.initialize(Environment())
        assert 1 == len(s["action_sequence"].action_sequence)
        assert s["input"] == torch.zeros((1,))
        assert s["reference"] == []
    def test_reference(self):
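        # The first step picks Root2Y (0.2); at the third step the token "x"
        # is generated, and its probability is the sum of its token
        # probability (0.8) and the two reference probabilities (0.1 each),
        # i.e. 1.0 in total.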
        torch.manual_seed(0)
        rule_prob = torch.tensor([
            [[
                1.0,  # unknown
                1.0,  # close variadic field
                0.1,  # Root2X
                0.2,  # Root2Y
                1.0,  # X2Y_list
                1.0,  # Ysub2Str
            ]],
            [[
                1.0,  # unknown
                1.0,  # close variadic field
                1.0,  # Root2X
                1.0,  # Root2Y
                1.0,  # X2Y_list
                1.0,  # Ysub2Str
            ]],
            [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]])
        token_prob = torch.tensor([
            [[0.0, 0.0, 0.0]],
            [[0.0, 0.0, 0.0]],
            [[
                1.0,  # Unknown
                0.8,  # x
                0.2,  # 1
            ]]])
        reference_prob = torch.tensor(
            [[[0.0, 0.0]], [[0.0, 0.0]], [[0.1, 0.1]]])
        sampler = ActionSequenceSampler(
            create_encoder(),
            is_subtype,
            create_transform_input([Token("Str", "x", "x"),
                                    Token(None, "x", "x")]),
            transform_action_sequence,
            collate,
            Module(encoder_module,
                   DecoderModule(rule_prob, token_prob, reference_prob)),
            rng=np.random.RandomState(0)
        )
        s = SamplerState(0.0, sampler.initialize(Environment()))
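        # Take the best candidate twice so that the third decoding step
        # (token generation) is reached next.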
        results = [s.state for s in sampler.top_k_samples([s], 1)]
        results = [s.state for s in sampler.top_k_samples(results, 1)]
        topk_results = list(sampler.top_k_samples(results, 1))
        assert 1 == len(topk_results)
        assert 3 == topk_results[0].state.state["length"].item()
        assert np.allclose(log(0.2) + log(1.),
                           topk_results[0].state.score)
        random_results = list(sampler.batch_k_samples(results[:1], [1]))
        assert 1 == len(random_results)
        assert 3 == random_results[0].state.state["length"].item()
        assert np.allclose(log(0.2) + log(1.0),
                           random_results[0].state.score)
        all_results = list(sampler.all_samples(results))
        assert 1 == len(all_results)
        assert 3 == all_results[0].state.state["length"].item()
        assert np.allclose(log(0.2) + log(1.),
                           all_results[0].state.score)
        all_results = list(sampler.all_samples(results, sorted=False))
        assert 1 == len(all_results)
        assert 3 == all_results[0].state.state["length"].item()
        assert np.allclose(log(0.2) + log(1.),
                           all_results[0].state.score)