def test_stochastic_descent_maximize(self):
     """Maximizing stochastic descent must not end below its starting point.

     Uses a single-clause, two-variable DNF as a small maximization target and
     runs descent with minimize=False from a random 5-chunk protocol.
     """
     dnf = dnf_lib.DNF(2, [dnf_lib.Clause(0, 1, True, True)])
     circuit = dnf_circuit_lib.BangBangProtocolCircuit(1, dnf)
     initial_protocol = stochastic_descent_lib.get_random_protocol(5)
     # Compare against the evaluation of the very protocol the descent starts
     # from. An independently drawn protocol (possibly of a different length)
     # would not measure the search's progress and could legitimately score
     # higher, making the assertion flaky.
     # NOTE(review): assumes descent never keeps a worsening flip — confirm
     # against stochastic_descent_lib.
     initial_eval = circuit.get_constraint_expectation(
         circuit.get_wavefunction(initial_protocol))
     protocol, evaluation, num_epoch = stochastic_descent_lib.stochastic_descent(
         circuit=circuit,
         max_num_flips=2,
         initial_protocol=initial_protocol,
         minimize=False)
     self.assertLen(protocol, 5)
     self.assertIsInstance(evaluation, float)
     self.assertGreaterEqual(evaluation, initial_eval)
     # Contain at least 1 epoch of stochastic descent.
     self.assertGreaterEqual(num_epoch, 1)
 def test_stochastic_descent_neg_max_num_flips(self):
     """A negative max_num_flips makes stochastic_descent raise ValueError."""
     formula = dnf_lib.DNF(3, [dnf_lib.Clause(0, 1, False, True)])
     bb_circuit = dnf_circuit_lib.BangBangProtocolCircuit(1, formula)
     expected_message = 'max_num_flips should be positive, not -10'
     with self.assertRaisesRegex(ValueError, expected_message):
         start = stochastic_descent_lib.get_random_protocol(5)
         stochastic_descent_lib.stochastic_descent(
             circuit=bb_circuit,
             max_num_flips=-10,
             initial_protocol=start,
             minimize=False)
 def test_stochastic_descent_skip_search(self):
     """With skip_search=True the starting protocol comes back unchanged."""
     dnf = dnf_lib.DNF(2, [dnf_lib.Clause(0, 1, True, True)])
     circuit = dnf_circuit_lib.BangBangProtocolCircuit(1, dnf)
     start_protocol = stochastic_descent_lib.get_random_protocol(2)
     start_wavefunction = circuit.get_wavefunction(start_protocol)
     start_eval = circuit.get_constraint_expectation(start_wavefunction)

     outcome = stochastic_descent_lib.stochastic_descent(
         circuit=circuit,
         max_num_flips=1,
         initial_protocol=start_protocol,
         minimize=False,
         skip_search=True)
     protocol, evaluation, num_epoch = outcome

     self.assertListEqual(protocol, start_protocol)
     self.assertIsInstance(evaluation, float)
     self.assertAlmostEqual(evaluation, start_eval)
     # Skipping the search means no descent epochs are run.
     self.assertEqual(num_epoch, 0)
 def test_get_random_protocol_neg_chunks(self):
     """get_random_protocol rejects a negative chunk count with ValueError."""
     expected_error = 'num_chunks should be positive, not -4'
     with self.assertRaisesRegex(ValueError, expected_error):
         stochastic_descent_lib.get_random_protocol(num_chunks=-4)
 def test_get_random_protocol(self, seed, expected_protocol):
     """A seeded RandomState yields the expected 5-chunk protocol.

     NOTE(review): `seed` and `expected_protocol` are presumably supplied by a
     parameterization decorator that is not visible in this chunk — confirm.
     """
     rng = np.random.RandomState(seed)
     actual_protocol = stochastic_descent_lib.get_random_protocol(
         num_chunks=5, random_state=rng)
     self.assertListEqual(actual_protocol, expected_protocol)