class MultiTcBuilder():
    """Builds and caches compiled TC executors for a TC string with multiple defs.

    Holds a ``CompilationCache`` keyed by (def name, inputs) and, on a cache
    miss, loads previously tuned ``MappingOptions`` from ``tuner_cache_file``
    (via ``MappingOptionsCache``) or runs the autotuner to produce them.

    Parameters mirror the attributes stored on ``self``; the forward/backward
    name/index/tuning tuples are presumably consumed by sibling methods not
    shown here — TODO(review): confirm against the rest of the file.
    """

    def __init__(self,
                 tc="",
                 forward_names=(),
                 # NOTE(review): the original defaults were written as (()),
                 # which evaluates to the same value as () — normalized here,
                 # behavior unchanged.
                 forward_input_indices=(),
                 forward_force_reinforcement_tunings=(),
                 backward_names=(),
                 backward_input_indices=(),
                 backward_force_reinforcement_tunings=(),
                 check_output_shapes=True,
                 tuner_cache_file="",
                 # NOTE(review): default evaluated once at def time and shared
                 # across instances; kept as-is for interface compatibility.
                 tuner_config=TunerConfig(),
                 debug=False):
        if debug:
            # Debug-only type sanity checks; messages show the offending type.
            assert isinstance(tc, str), type(tc)
            assert isinstance(forward_names, tuple), type(forward_names)
            assert isinstance(forward_input_indices, tuple), type(
                forward_input_indices)
            assert isinstance(forward_force_reinforcement_tunings, tuple), type(
                forward_force_reinforcement_tunings)
            assert isinstance(backward_names, tuple), type(backward_names)
            assert isinstance(backward_input_indices, tuple), type(
                backward_input_indices)
            assert isinstance(backward_force_reinforcement_tunings, tuple), type(
                backward_force_reinforcement_tunings)
            # Bug fix: the assertion message previously reported
            # type(tuner_cache_file) instead of type(check_output_shapes).
            assert isinstance(check_output_shapes, bool), type(
                check_output_shapes)
            assert isinstance(tuner_cache_file, str), type(tuner_cache_file)
            assert isinstance(tuner_config, TunerConfig), type(tuner_config)

        self.tc = tc
        self.forward_names = forward_names
        self.forward_input_indices = forward_input_indices
        self.forward_force_reinforcement_tunings = forward_force_reinforcement_tunings
        self.backward_names = backward_names
        self.backward_input_indices = backward_input_indices
        self.backward_force_reinforcement_tunings = backward_force_reinforcement_tunings
        self.check_output_shapes = check_output_shapes
        self.tuner_cache_file = tuner_cache_file
        self.tuner_config = tuner_config
        self.debug = debug
        # One compilation cache per builder, covering all defs in self.tc.
        self.compilation_cache = CompilationCache(self.tc)

    def compileOrTune(self, name="", force_reinforcement_tuning=False, inputs=()):
        """Ensure an executor for def ``name`` on ``inputs`` is compiled.

        On a compilation-cache miss, the best options previously saved in
        ``self.tuner_cache_file`` are used directly unless
        ``force_reinforcement_tuning`` is True (or no options were saved),
        in which case a tuning run is started, seeded with the best known
        options when any exist.
        """
        if self.debug:
            print(
                "On Tc: {}\ncompile def {}, force_reinforcement_tuning {}, inputs: {}"
                .format(
                    self.tc, name, force_reinforcement_tuning,
                    "".join("{}/{}, ".format(t.size().__str__(),
                                             t.stride().__str__())
                            for t in inputs)))

        if not self.compilation_cache.is_compiled(name, inputs):
            cache = MappingOptionsCache(self.tuner_cache_file)
            mapping_options = None
            base_options_list = cache.load(self.tc, name, inputs, 1)
            if len(base_options_list) > 0 and not force_reinforcement_tuning:
                # Best tuned options are on disk and we are not re-tuning.
                mapping_options = base_options_list[0]
                if self.debug:
                    print("Found best options in {}:\n{}".format(
                        self.tuner_cache_file, mapping_options))
            else:
                if self.debug:
                    print(
                        "########################################################"
                        "########################################################"
                    )
                    print(
                        "force_reinforcement_tuning = {} was specified, {} options loaded from "
                        "{}".format(force_reinforcement_tuning,
                                    len(base_options_list),
                                    self.tuner_cache_file))
                    print(
                        "Starting a tuning run (abort it with Ctrl+C when "
                        "performance is satisfactory.\nYou can always reinforce "
                        "the results later by passing a proper tuner cache file "
                        "and specifying force_reinforcement_tuning=True)")
                    print(
                        "########################################################"
                        "########################################################"
                    )
                # Seed the tuner with the best known options when any exist,
                # otherwise start from default MappingOptions.
                if len(base_options_list) == 0:
                    mapping_options = MappingOptions()
                else:
                    mapping_options = base_options_list[0]
                tuner = Tuner(self.tc, self.tuner_cache_file)
                mapping_options = tuner.tune(name, inputs, mapping_options,
                                             self.tuner_config)
            self.compilation_cache.compile(name, inputs, mapping_options)
################################################################################ executor = compile(mm, "matmul", (A, B), MappingOptions('naive')) C = executor.run((A, B)) time_tc(100, "simple API (in place)\t", lambda name, ins: executor.unchecked_run(ins, (C, )), "matmul", (A, B)) time_tc(100, "simple API (with allocation overhead)\t", lambda name, ins: executor.unchecked_run(ins), "matmul", (A, B)) ################################################################################ # 2. Use the C++ API to build a low-overhead compilation cache and time it ################################################################################ # Compilation returns an allocated tuple of outputs with the proper shapes. # Allocation overhead is negligible compared to compilation overhead. compilation_cache.compile("matmul", (A, B), MappingOptions('naive')) # Run once without timing compilation_cache.unchecked_run("matmul", (A, B)) # unchecked_run on tensors time_tc(100, "raw unchecked_run naive options\t", lambda name, ins: compilation_cache.unchecked_run(name, ins), "matmul", (A, B)) ################################################################################ # 3. Short tuning run saving to file then load the best option to create a # compilation cache ################################################################################ with tempfile.NamedTemporaryFile() as cache_file: tuner = Tuner(mm, cache_file.name)
# NOTE(review): this fragment starts mid-script — executor, outputs, mat1,
# mat2, time_tc and mm are defined before the visible region; confirm against
# the full file.
# Time 100 runs writing into the preallocated outputs tuple (no allocation).
time_tc(100,
        "simple API\t",
        lambda name, ins: executor.unchecked_run(ins, tuple(outputs)),
        "matmul", (mat1, mat2))
# Same, but with an empty outputs tuple so the executor allocates each run.
time_tc(100,
        "simple API (with allocation overhead)\t",
        lambda name, ins: executor.unchecked_run(ins, ()),
        "matmul", (mat1, mat2))
################################################################################
# 2. Use the C++ API to build a low-overhead compilation cache and time it
################################################################################
from tensor_comprehensions.tclib import CompilationCache

compilation_cache = CompilationCache(mm)
# Compilation returns an allocated tuple of outputs with the proper shapes.
# Allocation overhead is negligible compared to compilation overhead.
compilation_cache.compile("matmul", (mat1, mat2), MappingOptions())
# Run once without timing
compilation_cache.unchecked_run("matmul", (mat1, mat2), ())
# unchecked_run on tensors
time_tc(100,
        "raw unchecked_run naive options\t",
        lambda name, ins: compilation_cache.unchecked_run(name, ins, ()),
        "matmul", (mat1, mat2))
################################################################################
# 3. Short tuning run saving to file then load the best option to create a
# compilation cache
################################################################################
from tensor_comprehensions.tclib import Tuner
from tensor_comprehensions.tclib import MappingOptionsCache
from tensor_comprehensions.tclib import TunerConfig