def test_stochastic_descent_maximize(self):
  """Maximizing stochastic descent ends no lower than where it started.

  The baseline is the expectation of the very protocol passed as
  `initial_protocol`, so the assertion compares the search result against its
  own starting point instead of against an unrelated random protocol of a
  different length.
  """
  # Every 2-SAT with just one clause will not be satisfied by 1/4 of all
  # possible literal assignments.
  dnf = dnf_lib.DNF(2, [dnf_lib.Clause(0, 1, True, True)])
  circuit = dnf_circuit_lib.BangBangProtocolCircuit(1, dnf)
  initial_protocol = stochastic_descent_lib.get_random_protocol(5)
  # Evaluate the baseline before the search runs, so it is unaffected even if
  # the protocol list were modified in place. NOTE(review): whether
  # stochastic_descent mutates its input is not visible from here.
  initial_eval = circuit.get_constraint_expectation(
      circuit.get_wavefunction(initial_protocol))
  protocol, evaluation, num_epoch = stochastic_descent_lib.stochastic_descent(
      circuit=circuit,
      max_num_flips=2,
      initial_protocol=initial_protocol,
      minimize=False)
  self.assertLen(protocol, 5)
  self.assertIsInstance(evaluation, float)
  self.assertGreaterEqual(evaluation, initial_eval)
  # Contain at least 1 epoch of stochastic descent.
  self.assertGreaterEqual(num_epoch, 1)
def test_stochastic_descent_neg_max_num_flips(self):
  """stochastic_descent rejects a negative max_num_flips.

  The call must raise ValueError whose message names the offending value.
  The DNF, circuit and initial protocol exist only to make the call otherwise
  well-formed.
  """
  dnf = dnf_lib.DNF(3, [dnf_lib.Clause(0, 1, False, True)])
  circuit = dnf_circuit_lib.BangBangProtocolCircuit(1, dnf)
  with self.assertRaisesRegex(
      ValueError, 'max_num_flips should be positive, not -10'):
    stochastic_descent_lib.stochastic_descent(
        circuit=circuit,
        max_num_flips=-10,
        initial_protocol=stochastic_descent_lib.get_random_protocol(5),
        minimize=False)
def test_stochastic_descent_skip_search(self):
  """With skip_search=True the starting protocol comes back unchanged.

  The reported evaluation must match the expectation computed directly from
  the circuit for that protocol, and zero descent epochs must be reported.
  """
  circuit = dnf_circuit_lib.BangBangProtocolCircuit(
      1, dnf_lib.DNF(2, [dnf_lib.Clause(0, 1, True, True)]))
  start = stochastic_descent_lib.get_random_protocol(2)
  wavefunction = circuit.get_wavefunction(start)
  expected_eval = circuit.get_constraint_expectation(wavefunction)
  result_protocol, result_eval, epochs = (
      stochastic_descent_lib.stochastic_descent(
          circuit=circuit,
          max_num_flips=1,
          initial_protocol=start,
          minimize=False,
          skip_search=True))
  self.assertListEqual(result_protocol, start)
  self.assertIsInstance(result_eval, float)
  self.assertAlmostEqual(result_eval, expected_eval)
  # Skipping the search means no epoch of stochastic descent is counted.
  self.assertEqual(epochs, 0)
def test_get_random_protocol_neg_chunks(self):
  """A negative chunk count makes get_random_protocol raise ValueError."""
  expected_message = 'num_chunks should be positive, not -4'
  with self.assertRaisesRegex(ValueError, expected_message):
    stochastic_descent_lib.get_random_protocol(num_chunks=-4)
def test_get_random_protocol(self, seed, expected_protocol):
  """A seeded RandomState makes get_random_protocol reproducible.

  `seed` and `expected_protocol` are supplied by the test's parameterization
  (presumably a decorator outside this view — confirm).
  """
  rng = np.random.RandomState(seed)
  actual = stochastic_descent_lib.get_random_protocol(
      num_chunks=5, random_state=rng)
  self.assertListEqual(actual, expected_protocol)