def test_undefined_operations(self):
  """Operations on an Undefined placeholder propagate undefinedness."""
  placeholder = special_values.Undefined('name')

  # Attribute access on an Undefined yields another undefined value.
  attribute_result = placeholder.foo
  self.assertTrue(special_values.is_undefined(attribute_result))

  # So does subscripting it.
  subscript_result = placeholder[0]
  self.assertTrue(special_values.is_undefined(subscript_result))

  # The object's real class, however, is not treated as undefined.
  class_result = placeholder.__class__
  self.assertFalse(special_values.is_undefined(class_result))
def test_undefined(self):
  """Undefined keeps its symbol name and distinct instances compare unequal."""
  first = special_values.Undefined('name')
  self.assertEqual(first.symbol_name, 'name')

  # A second placeholder with the same name must not compare equal.
  second = special_values.Undefined('name')
  self.assertNotEqual(first, second)

  # Both placeholders are recognized as undefined.
  for placeholder in (first, second):
    self.assertTrue(special_values.is_undefined(placeholder))
def _filter_undefined(all_symbols):
  """Collects the names of the undefined values found in `all_symbols`.

  Args:
    all_symbols: an iterable of values, some of which may be undefined
      placeholders.

  Returns:
    A list holding the `symbol_name` of every element for which
    `special_values.is_undefined` is true, in iteration order.
  """
  return [
      symbol.symbol_name
      for symbol in filter(special_values.is_undefined, all_symbols)
  ]
def _verify_loop_init_vars(values, symbol_names):
  """Checks that every loop state variable holds a usable value on loop entry.

  Args:
    values: the initial values of the loop state.
    symbol_names: the names of the loop state variables, paired positionally
      with `values` (pairs beyond the shorter sequence are not checked, since
      they are combined with `zip`).

  Raises:
    ValueError: for the first value that is None, an undefined return value,
      or otherwise undefined.
  """
  for symbol_name, initial_value in zip(symbol_names, values):
    error_message = None
    if initial_value is None:
      error_message = '"{}" may not be None before the loop.'.format(
          symbol_name)
    elif special_values.is_undefined_return(initial_value):
      # Assumption: the loop will only capture the variable which tracks the
      # return value if the loop contained a return statement.
      # TODO(mdan): This should be checked at the place where return occurs.
      error_message = (
          'return statements are not supported within a TensorFlow loop.')
    elif special_values.is_undefined(initial_value):
      error_message = '"{}" must be defined before the loop.'.format(
          symbol_name)

    if error_message is not None:
      raise ValueError(error_message)
def protected_func():
  """Calls `func` and rejects results that contain undefined symbols.

  `func` and `branch_name` are free variables taken from the enclosing scope.
  A tuple result is checked element by element; any other result is checked
  as a single value.

  Returns:
    Whatever `func()` returned, when no undefined symbols were found.

  Raises:
    ValueError: if the result is undefined, or is a tuple containing
      undefined elements. The message lists the offending symbol names and
      the branch (`branch_name`) that failed to initialize them.
  """
  results = func()
  if isinstance(results, tuple):
    undefined_symbols = _filter_undefined(results)
  elif special_values.is_undefined(results):
    # Single return value: wrap it in a list so it is reported in the same
    # form as the tuple case. A non-empty list also stays truthy even if the
    # symbol name itself is an empty string.
    undefined_symbols = [results.symbol_name]
  else:
    undefined_symbols = []

  if undefined_symbols:
    # Format in a single pass. Substituting branch_name with % first and then
    # calling str.format on the result would treat any braces in
    # branch_name as format fields.
    message = ('The following symbols must also be initialized in the {} '
               'branch: {}. Alternatively, you may initialize them before '
               'the if statement.').format(branch_name, undefined_symbols)
    raise ValueError(message)
  return results