def cached_function(inputs, outputs):
    """Compile a Theano function, reusing a pickled copy from an on-disk cache.

    The cache key is a SHA-256 digest of the string form of ``inputs`` and the
    ``theano.pp`` pretty-print of ``outputs``. Compiled functions are stored as
    ``~/.hierctrl_cache/<key>.pkl``.

    :param inputs: input variable(s), forwarded to ``compile_function``.
    :param outputs: a single output expression or a sequence of them,
        forwarded to ``compile_function``.
    :return: the compiled function, either unpickled from the cache or
        freshly compiled (in which case it is also written to the cache).
    """
    import hashlib
    import theano

    with Message("Hashing theano fn"):
        if hasattr(outputs, '__len__'):
            output_repr = tuple(map(theano.pp, outputs))
        else:
            output_repr = theano.pp(outputs)
        # Input order fixes the positional signature of the compiled function,
        # so it must be part of the key: f(x, y) and f(y, x) can have identical
        # outputs yet are not interchangeable.
        # NOTE(review): assumes str() of each input is deterministic across
        # processes (true for plain Theano variables) -- confirm for wrappers.
        if hasattr(inputs, '__len__'):
            input_repr = tuple(map(str, inputs))
        else:
            input_repr = str(inputs)
        hash_content = (input_repr, output_repr)
    # A content digest is stable across interpreter runs, unlike the builtin
    # hash(), which is randomized per process for str on Python 3. hexdigest()
    # also replaces hex(...)[:-1]: that slice only made sense for Python 2's
    # trailing 'L' and drops a real hex digit on Python 3.
    cache_key = hashlib.sha256(repr(hash_content).encode('utf-8')).hexdigest()

    cache_dir = Path('~/.hierctrl_cache').expanduser()
    cache_dir.mkdir_p()
    cache_file = cache_dir / ('%s.pkl' % cache_key)

    if cache_file.exists():
        with Message("unpickling"):
            with open(cache_file, "rb") as f:
                try:
                    # NOTE(security): unpickling can execute arbitrary code;
                    # this presumes the cache directory is only writable by
                    # the current user.
                    return pickle.load(f)
                except Exception:
                    # Deliberate best-effort: an unreadable entry (truncated,
                    # stale, incompatible) is ignored and overwritten below.
                    pass

    with Message("compiling"):
        fun = compile_function(inputs, outputs)
    with Message("pickling"):
        with open(cache_file, "wb") as f:
            pickle.dump(fun, f, protocol=pickle.HIGHEST_PROTOCOL)
    return fun
def compile_function(inputs=None, outputs=None, updates=None, givens=None,
                     log_name=None, **kwargs):
    """Compile a Theano function with this project's default options.

    Unused inputs are ignored (``on_unused_input='ignore'``) and inputs are
    allowed to be downcast (``allow_input_downcast=True``). Callers may
    override either option through ``kwargs``. All other keyword arguments are
    forwarded to ``theano.function`` unchanged.

    :param inputs: input variables of the function.
    :param outputs: output expression(s).
    :param updates: update pairs forwarded to ``theano.function``.
    :param givens: substitutions forwarded to ``theano.function``.
    :param log_name: if truthy, compilation is wrapped in a
        ``Message("Compiling function <log_name>")`` context.
    :return: the compiled Theano function.
    """
    import theano

    options = dict(on_unused_input='ignore', allow_input_downcast=True)
    # Explicit kwargs win; previously passing either option raised
    # "got multiple values for keyword argument".
    options.update(kwargs)

    def _compile():
        return theano.function(
            inputs=inputs,
            outputs=outputs,
            updates=updates,
            givens=givens,
            **options
        )

    if log_name:
        # A with-statement guarantees the Message context is exited even if
        # compilation raises; manual __enter__/__exit__ skipped __exit__ then.
        with Message("Compiling function %s" % log_name):
            return _compile()
    return _compile()