def testBranchingSingleBeamEntry(self):
    """Beam of size 1 with branch factor 32 should pick ones greedily."""
    search_kwargs = dict(
        initial_sequence=[],
        initial_state=1,
        generate_step_fn=self._generate_step_fn,
        num_steps=5,
        beam_size=1,
        branch_factor=32,
        steps_per_iteration=1,
    )
    sequence, state, score = beam_search.beam_search(**search_kwargs)

    # Greedy search is expected here: every emitted token should be a one.
    checks = (
        (sequence, [1, 1, 1, 1, 1]),
        (state, 1),
        (score, 5),
    )
    for actual, expected in checks:
        self.assertEqual(actual, expected)
  def testNoBranchingMultipleBeamEntries(self):
    """A 32-wide beam without branching should reach the optimal sequence."""
    outcome = beam_search.beam_search(
        initial_sequence=[],
        initial_state=1,
        generate_step_fn=self._generate_step_fn,
        num_steps=5,
        beam_size=32,
        branch_factor=1,
        steps_per_iteration=1)
    sequence, state, score = outcome

    # The beam is wide enough to find the best solution on its own, so no
    # branching is needed.
    expected = ([0, 0, 0, 0, 1], 1, 16)
    for actual, want in zip((sequence, state, score), expected):
      self.assertEqual(actual, want)
  def testNoBranchingMultipleStepsPerIteration(self):
    """Beam size 1, branch factor 1, two steps per iteration: all zeros."""
    sequence, state, score = beam_search.beam_search(
        initial_sequence=[], initial_state=1,
        generate_step_fn=self._generate_step_fn, num_steps=5, beam_size=1,
        branch_factor=1, steps_per_iteration=2)

    # With beam_size=1 and branch_factor=1 only a single sequence is ever
    # considered, so the counter never reaches one and every emitted token is
    # zero. Taking two steps per iteration does not change that outcome.
    self.assertEqual(sequence, [0, 0, 0, 0, 0])
    self.assertEqual(state, 32)
    self.assertEqual(score, 0)
  def testNoBranchingSingleStepPerIteration(self):
    """Beam size 1, branch factor 1, one step per iteration: all zeros."""
    config = dict(
        initial_sequence=[],
        initial_state=1,
        generate_step_fn=self._generate_step_fn,
        num_steps=5,
        beam_size=1,
        branch_factor=1,
        steps_per_iteration=1)
    sequence, state, score = beam_search.beam_search(**config)

    # Only one sequence is ever explored, so the counter never gets to one and
    # the generator emits nothing but zeros.
    expected_sequence = [0] * 5
    expected_state = 32
    expected_score = 0
    self.assertEqual(sequence, expected_sequence)
    self.assertEqual(state, expected_state)
    self.assertEqual(score, expected_score)
    def testBranchingSingleBeamEntry(self):
        """Single-entry beam with wide branching (32) should choose ones."""
        result = beam_search.beam_search(
            initial_sequence=[],
            initial_state=1,
            generate_step_fn=self._generate_step_fn,
            num_steps=5,
            beam_size=1,
            branch_factor=32,
            steps_per_iteration=1,
        )
        sequence, state, score = result

        # The search is expected to behave greedily and pick a one each step.
        want_sequence = [1] * 5
        want_state = 1
        want_score = 5
        self.assertEqual(sequence, want_sequence)
        self.assertEqual(state, want_state)
        self.assertEqual(score, want_score)
    def testNoBranchingMultipleBeamEntries(self):
        """Beam size 32, branch factor 1: the optimal sequence is found."""
        call_args = dict(
            initial_sequence=[],
            initial_state=1,
            generate_step_fn=self._generate_step_fn,
            num_steps=5,
            beam_size=32,
            branch_factor=1,
            steps_per_iteration=1,
        )
        sequence, state, score = beam_search.beam_search(**call_args)

        # Enough beam capacity exists to find the optimum without branching.
        pairs = [
            (sequence, [0, 0, 0, 0, 1]),
            (state, 1),
            (score, 16),
        ]
        for got, expected in pairs:
            self.assertEqual(got, expected)
    def testNoBranchingMultipleStepsPerIteration(self):
        """Beam size 1, branch factor 1, two steps per iteration: all zeros."""
        sequence, state, score = beam_search.beam_search(
            initial_sequence=[],
            initial_state=1,
            generate_step_fn=self._generate_step_fn,
            num_steps=5,
            beam_size=1,
            branch_factor=1,
            steps_per_iteration=2)

        # With beam_size=1 and branch_factor=1 only a single sequence is ever
        # considered, so the counter never reaches one and every emitted token
        # is zero. Taking two steps per iteration does not change that outcome.
        self.assertEqual(sequence, [0, 0, 0, 0, 0])
        self.assertEqual(state, 32)
        self.assertEqual(score, 0)
    def testNoBranchingSingleStepPerIteration(self):
        """One beam entry, no branching, one step per iteration: all zeros."""
        outcome = beam_search.beam_search(
            initial_sequence=[],
            initial_state=1,
            generate_step_fn=self._generate_step_fn,
            num_steps=5,
            beam_size=1,
            branch_factor=1,
            steps_per_iteration=1,
        )
        sequence, state, score = outcome

        # Because just one sequence is ever considered, the counter never
        # reaches one and the generator should output only zeros.
        expected = ([0, 0, 0, 0, 0], 32, 0)
        for actual, want in zip((sequence, state, score), expected):
            self.assertEqual(actual, want)