def test_linear_combination_instantiate(self):
    """Check the argument validation done by the LinearCombination constructor."""
    # A well-formed weights/perturbations pair must construct without error.
    linear_combination.LinearCombination([1, 1], [1, 0])

    # Each entry: (expected exception, message regex, weights, perturbations).
    # Entries are checked in order, one constructor call per entry.
    invalid_arguments = (
        # Non-list arguments.
        (TypeError, "weights must be", "junk", [1, 0]),
        (TypeError, "perturbations must be", [1, 1], "junk"),
        # Non-numeric elements inside otherwise valid lists.
        (TypeError, "weight in weights", [1, "junk"], [1, 0]),
        (TypeError, "perturbation in perturbations", [1, 1], [1, "junk"]),
        # Weights and perturbations of different lengths.
        (ValueError, "length", [1, 1, 1], [1, 0]),
        # Only a single weight/perturbation pair.
        (ValueError, "at least two", [1], [1]),
        # Repeated perturbation values.
        (ValueError, "unique", [1, 1], [1, 1]),
    )
    for error_type, message, weights, perturbations in invalid_arguments:
        with self.assertRaisesRegex(error_type, expected_regex=message):
            linear_combination.LinearCombination(weights, perturbations)
def test_no_gradient_circuits(self):
    """Check that LinearCombination reports gradient circuits as unavailable.

    NOTE(review): test_get_gradient_circuits in this same file expects
    get_gradient_circuits to return five tensors for real inputs, while this
    test expects the same method to raise NotImplementedError. These two
    tests very likely cannot both pass against one implementation; confirm
    which one reflects the current API and remove the stale one.
    """
    differentiator = linear_combination.LinearCombination([1, 1], [1, 0])
    with self.assertRaisesRegex(NotImplementedError,
                                "not currently available"):
        differentiator.get_gradient_circuits(None, None, None)
def test_get_gradient_circuits(self):
    """Check the tensors returned by LinearCombination.get_gradient_circuits."""
    # Smallest interesting stencil: two weights, two non-zero perturbations.
    input_weights = [1.0, -0.5]
    input_perturbations = [1.0, -1.5]
    diff = linear_combination.LinearCombination(input_weights,
                                                input_perturbations)

    # Two circuits, each parameterized by both symbols.
    symbols = [sympy.Symbol("s0"), sympy.Symbol("s1")]
    q0 = cirq.GridQubit(0, 0)
    q1 = cirq.GridQubit(1, 2)
    input_programs = util.convert_to_tensor([
        cirq.Circuit(cirq.X(q0)**symbols[0], cirq.ry(symbols[1])(q1)),
        cirq.Circuit(cirq.rx(symbols[0])(q0), cirq.Y(q1)**symbols[1]),
    ])
    input_symbol_names = tf.constant([str(s) for s in symbols])
    input_symbol_values = tf.constant([[1.5, -2.7], [-0.3, 0.9]])

    n_programs = 2
    n_symbols = len(symbols)
    # Each program is copied once per (symbol, non-zero perturbation) pair.
    # A zero perturbation would add one extra copy, but none is used here.
    n_copies = n_symbols * len(input_perturbations)

    expected_batch_programs = tf.stack(
        [[input_programs[i]] * n_copies for i in range(n_programs)])
    expected_new_symbol_names = input_symbol_names

    # Every copy starts from its parent program's original symbol values...
    tiled_symbol_values = tf.stack(
        [[input_symbol_values[i]] * n_copies for i in range(n_programs)])
    # ...and is shifted along exactly one symbol axis. Rows are symbol-major:
    # all perturbations of s0 first, then all perturbations of s1. The same
    # shift table applies to every program.
    single_program_perturbations = tf.stack([
        [shift if axis == target else 0.0 for axis in range(n_symbols)]
        for target in range(n_symbols)
        for shift in input_perturbations
    ])
    tiled_perturbations = tf.stack(
        [single_program_perturbations] * n_programs)
    expected_batch_symbol_values = tiled_symbol_values + tiled_perturbations

    # The stencil weights are repeated for every symbol and every program.
    individual_batch_weights = tf.stack([list(input_weights)] * n_symbols)
    expected_batch_weights = tf.stack(
        [individual_batch_weights] * n_programs)

    # Row k of the mapper holds the copy indices belonging to symbol k,
    # matching the symbol-major layout of the perturbations above.
    single_program_mapper = tf.constant([[0, 1], [2, 3]])
    expected_batch_mapper = tf.tile(
        tf.expand_dims(single_program_mapper, 0), [n_programs, 1, 1])

    (actual_batch_programs, actual_new_symbol_names,
     actual_batch_symbol_values, actual_batch_weights,
     actual_batch_mapper) = diff.get_gradient_circuits(
         input_programs, input_symbol_names, input_symbol_values)

    self.assertAllEqual(expected_batch_programs, actual_batch_programs)
    self.assertAllEqual(expected_new_symbol_names, actual_new_symbol_names)
    self.assertAllClose(expected_batch_symbol_values,
                        actual_batch_symbol_values,
                        atol=1e-6)
    self.assertAllClose(expected_batch_weights,
                        actual_batch_weights,
                        atol=1e-6)
    self.assertAllEqual(expected_batch_mapper, actual_batch_mapper)