def test_long_formatted_with_diff(self):
  """Golden-tests the mismatch message for two long, nearly identical structs.

  The structs hold 20 and 21 unnamed int32 elements respectively; the
  resulting message is compared against `long_formatted_with_diff.expected`.
  """
  int32 = computation_types.TensorType(tf.int32)
  twenty_elements = [(None, int32)] * 20
  first = computation_types.StructType(twenty_elements)
  # One extra unnamed element makes the two structs differ only in length.
  second = computation_types.StructType(twenty_elements + [(None, int32)])
  message = computation_types.type_mismatch_error_message(
      first, second, computation_types.TypeRelation.EQUIVALENT)
  golden.check_string('long_formatted_with_diff.expected', message)
def check_computations(
    filename: str,
    computations: OrderedDict[str,
                              building_blocks.ComputationBuildingBlock]
) -> None:
  """Compares the named ASTs of `computations` with a golden file.

  Each entry is rendered as `'<name>:\\n\\n<formatted AST>\\n\\n'` after its
  compiled computations have been named. The entries are concatenated in
  iteration order and checked against the golden file `filename`.

  Args:
    filename: Name of the golden file to compare against.
    computations: An `OrderedDict` mapping a display name to the
      `building_blocks.ComputationBuildingBlock` to render.

  Raises:
    TypeError: If `filename`, `computations`, or any of its values has the
      wrong type.
  """
  py_typecheck.check_type(filename, str)
  py_typecheck.check_type(computations, collections.OrderedDict,
                          'computations')

  def _render_entry(label, building_block):
    # Validate each value before naming/formatting it; the dict key doubles as
    # the label reported in any type error.
    py_typecheck.check_type(building_block,
                            building_blocks.ComputationBuildingBlock, label)
    named_ast = _name_compiled_computations(building_block)
    return f'{label}:\n\n{named_ast.formatted_representation()}\n\n'

  # `str.join` consumes the whole generator before `check_string` runs, so every
  # entry is validated and rendered first, exactly as with an explicit list.
  golden.check_string(
      filename,
      ''.join(_render_entry(label, building_block)
              for label, building_block in computations.items()))
def test_stackframes_in_errors(self):
  """Checks the traceback produced when tracing a computation raises.

  Decorating `_` with `federated_computation_wrapper` is expected to trace its
  body and therefore raise `DummyError`. The traceback text captured while
  handling that error is compared against the
  `federated_computation_wrapper_traceback.expected` golden file.
  """
  # NOTE(review): the golden traceback may embed the source text (and possibly
  # line numbers) of this function, depending on what `traceback_string` keeps
  # — confirm before editing the lines inside the `try`.

  # A test-local exception type, so the `except` below only matches the error
  # raised from the traced function body.
  class DummyError(RuntimeError):
    pass

  try:
    # Applying the decorator is expected to trace `_`, which raises.
    @computation_wrapper_instances.federated_computation_wrapper
    def _():
      raise DummyError()

    # Only reached if tracing did not propagate `DummyError`.
    self.fail('Tracing should throw `DummyError`')
  except DummyError:
    # Presumably `traceback_string()` formats the exception currently being
    # handled (it is defined elsewhere in this file — confirm).
    golden.check_string('federated_computation_wrapper_traceback.expected',
                        traceback_string())
def assert_transforms(self, comp, file, changes_type=False, unmodified=False):
  """Applies `self.transform` to `comp` and validates the outcome.

  The before/after formatted representations are checked against a golden
  file. The type signature is optionally required to be unchanged, and the
  reported "modified" flag is asserted to match the expectation.

  Args:
    comp: The building block to transform.
    file: Name of the golden file holding the before/after representations.
    changes_type: When true, the transformed type signature is allowed to
      differ from the original; otherwise the two must be identical.
    unmodified: When true, the transform must report no modification;
      otherwise it must report a modification.

  Returns:
    The transformed building block.
  """
  # NOTE: A `transform` method must be present on inheritors; it returns the
  # transformed block together with a "was modified" flag.
  transformed, was_modified = self.transform(comp)
  before_text = comp.formatted_representation()
  after_text = transformed.formatted_representation()
  golden.check_string(
      file,
      f'Before transformation:\n\n{before_text}\n\n'
      f'After transformation:\n\n{after_text}')
  if not changes_type:
    type_test_utils.assert_types_identical(comp.type_signature,
                                           transformed.type_signature)
  expect_flag = self.assertFalse if unmodified else self.assertTrue
  expect_flag(was_modified)
  return transformed
def test_check_string_updates(self):
  """Checks that `--update_goldens` rewrites a mismatched golden file.

  The test first writes `old_contents` to the golden file. It then checks
  that a mismatch is reported without `--update_goldens`, reruns with the
  flag set so the file is rewritten, and checks that the file now matches.
  The golden file is always restored to `old_contents` afterwards, even if
  one of the checks fails, so a failure cannot leave a modified golden
  behind. If the golden file cannot be written (read-only sandbox), the test
  returns early without checking anything.
  """
  filename = 'test_check_string_updates.expected'
  golden_path = golden._filename_to_golden_path(filename)
  old_contents = 'old\ndata\n'
  new_contents = 'new\ndata\n'

  def _reset_golden_file():
    # Restore the checked-in contents of the golden file.
    with open(golden_path, 'w') as f:
      f.write(old_contents)

  # Attempt to reset the contents of the file to their checked-in state.
  try:
    _reset_golden_file()
  except OSError:  # Also covers `PermissionError`, an `OSError` subclass.
    # We're running without `--test_strategy=local`, and so can't test
    # updates properly because these files are read-only.
    return

  try:
    # Check for a mismatch when `--update_goldens` isn't passed.
    with self.assertRaises(golden.MismatchedGoldenError):
      golden.check_string(filename, new_contents)
    # Rerun with `--update_goldens`.
    with flagsaver.flagsaver(update_goldens=True):
      golden.check_string(filename, new_contents)
    # Check again without `--update_goldens` now that they have been updated.
    try:
      golden.check_string(filename, new_contents)
    except golden.MismatchedGoldenError as e:
      self.fail(f'Unexpected mismatch after update: {e}')
  finally:
    # Reset the contents of the file to their checked-in state. This runs even
    # when a check above fails; otherwise the golden would be left rewritten.
    _reset_golden_file()
def test_check_string_fails(self):
  """A string that differs from the golden file must raise a mismatch."""
  wrong_contents = '\n'.join(['not', 'what', 'you', 'expected'])
  with self.assertRaises(golden.MismatchedGoldenError):
    golden.check_string('test_check_string_fails.expected', wrong_contents)
def test_check_string_succeeds(self):
  """A string matching the golden file's contents passes without error."""
  golden_name = 'test_check_string_succeeds.expected'
  matching_contents = '\n'.join(['foo', 'bar', 'baz', 'fizzbuzz'])
  golden.check_string(golden_name, matching_contents)
def test_container_types_full_repr(self):
  """Golden-tests the mismatch message for empty list- vs tuple-backed structs."""
  list_backed = computation_types.StructWithPythonType([], list)
  tuple_backed = computation_types.StructWithPythonType([], tuple)
  relation = computation_types.TypeRelation.EQUIVALENT
  message = computation_types.type_mismatch_error_message(
      list_backed, tuple_backed, relation)
  golden.check_string('container_types_full_repr.expected', message)
def test_short_compact_repr(self):
  """Golden-tests the mismatch message for an int32 vs a bool tensor type."""
  int_tensor, bool_tensor = [
      computation_types.TensorType(dtype) for dtype in (tf.int32, tf.bool)
  ]
  message = computation_types.type_mismatch_error_message(
      int_tensor,
      bool_tensor,
      computation_types.TypeRelation.EQUIVALENT,
  )
  golden.check_string('short_compact_repr.expected', message)