def test_location_type_error(self):
    """A float location must be rejected by focus_single_operation with a TypeError."""
    dummy_circ = circuit.Circuit(5, None)
    float_location = 47.11
    expected_message = r'location is not integer-like \(found type: float\)'
    with self.assertRaisesRegex(TypeError, expected_message):
        transform.focus_single_operation(dummy_circ, float_location)
def test_location_out_of_bounds_error(self, location):
    """An out-of-range location must make focus_single_operation raise IndexError.

    Args:
      location: the (parameterized) location to probe; presumably chosen to lie
          outside the valid range for a circuit holding five operations.
    """
    five_operations = [
        _random_operation(0),
        _random_operation(0, 1),
        _random_operation(1),
        _random_operation(1, 2),
        _random_operation(2),
    ]
    circ = circuit.Circuit(3, five_operations)

    # The message must echo the offending location and the circuit length (5).
    pattern = r'location %d out of bounds for a Circuit of length 5' % location
    with self.assertRaisesRegex(IndexError, pattern):
        transform.focus_single_operation(circ, location)
def test_successful(self, location):
    """Focusing a single operation splits the circuit into focus and context.

    Args:
      location: the (parameterized) location to focus on; the assertions below
          expect it to resolve to index 3 of the five-operation circuit.
    """
    # preparation work: five operations, created in a fixed order
    operations = [
        _random_operation(0),
        _random_operation(0, 1),
        _random_operation(1),
        _random_operation(1, 2),
        _random_operation(2),
    ]
    # snapshot the expected pieces before handing the list to the Circuit
    expected_before = operations[:3]
    target_operation = operations[3]
    expected_after = operations[4:]
    circ = circuit.Circuit(5, operations)

    # call the function under test
    attention_circ = transform.focus_single_operation(circ, location)

    # the result must be exactly an AttentionCircuit (subclasses do not count)
    self.assertIs(type(attention_circ), transform.AttentionCircuit)

    # the focus holds exactly the targeted operation
    self.assertLen(attention_circ, 1)
    self.assertTrue(_elementwise_is(attention_circ.focus(), [target_operation]))

    # the context holds everything else, split around the focus
    ctx = attention_circ.context()
    self.assertTrue(
        _elementwise_is(ctx.before().get_operation_sequence(), expected_before))
    self.assertEmpty(ctx.between())
    self.assertTrue(
        _elementwise_is(ctx.after().get_operation_sequence(), expected_after))

    # the reported location is the normalized index of the focused operation
    self.assertTupleEqual(attention_circ.locations(), (3, ))
def scan_for_single_operations(circ):
    """Iterates over all attention circuits with exactly one operation in the focus.

    There will be such an AttentionCircuit for each operation in the circuit,
    produced in order of location. The iteration is lazy, i.e. items are not
    computed before they are actually requested.

    The type of circ is checked immediately when this function is called. It is
    not deferred until the first item is requested, so a wrong argument fails
    at the call site.

    Args:
      circ: the circuit to be scanned.

    Returns:
      an iterator over all attention circuits with exactly one operation in the
      focus.

    Raises:
      TypeError: if circ is not a Circuit.
    """
    # Validate outside of any generator body: code inside a generator function
    # only runs on the first next(), which would postpone this error.
    if not isinstance(circ, circuit.Circuit):
        raise TypeError('circ is not a Circuit (found type: %s)'
                        % type(circ).__name__)

    def _generate():
        # Lazily build one single-operation focus per location of circ.
        for location, _ in enumerate(circ):
            yield transform.focus_single_operation(circ, location)

    return _generate()
def test_circ_type_error(self):
    """A non-Circuit argument must make focus_single_operation raise TypeError."""
    not_a_circuit = range(10)
    expected_message = r'circ is not a Circuit \(found type: range\)'
    with self.assertRaisesRegex(TypeError, expected_message):
        transform.focus_single_operation(not_a_circuit, 3)