def testAddSubFallback(self):
    torch._lazy.config.set_force_fallback("aten::sub")
    device = get_test_device()
    x = torch.randn(2, 3, 4, device=device)
    y = torch.randn(2, 3, 4, device=device)
    z = torch.zeros(2, 3, 4, device=device)

    device = "lazy"
    x_lazy = x.detach().clone().to(device=device)
    y_lazy = y.detach().clone().to(device=device)
    z_lazy = z.detach().clone().to(device=device)

    for i in range(10):
        if i < 5:
            z += (x + y)
        else:
            z += (x - y)

    for i in range(10):
        if i < 5:
            z_lazy += (x_lazy + y_lazy)
        else:
            z_lazy += (x_lazy - y_lazy)
        torch._lazy.mark_step()

    torch.testing.assert_close(z.cpu(), z_lazy.cpu())
    assert metrics.counter_value("IrNodeReused_torch::lazy::AddTensor") >= 8
    metrics.reset()
    torch._lazy.ir_cache.reset()
    torch._lazy.config.set_force_fallback("")
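
# Illustrative sketch (not part of the test suite) of the force-fallback round
# trip the test above exercises: a forced op executes through the eager
# fallback and is recorded under an "aten::..." metrics counter. Assumes the
# TorchScript lazy backend was initialized via torch._lazy.ts_backend.init().
def _force_fallback_demo():
    torch._lazy.config.set_force_fallback("aten::sub")
    a = torch.randn(2, 2, device="lazy")
    b = torch.randn(2, 2, device="lazy")
    (a - b).cpu()  # moving to CPU forces execution; sub runs via the eager fallback
    # counters whose names contain "aten::" record ops that fell back to eager
    print([name for name in metrics.counter_names() if "aten::" in name])
    torch._lazy.config.set_force_fallback("")  # restore normal lowering
    metrics.reset()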
def testBatchNorm(self):
    device = get_test_device()
    x = torch.randn(16, 3, 224, 224, device=device)
    weight = torch.randn(3, device=device)
    bias = torch.randn(3, device=device)

    for i in range(10):
        # BatchNorm2d does extra checks on dimensions which SymInts don't support yet,
        # so we call `torch.ops.aten.native_batch_norm` to bypass the checks.
        z, _, _ = torch.ops.aten.native_batch_norm(
            x, weight, bias, None, None, True, 0.1, 1e-5)

    device = "lazy"
    x_lazy = x.detach().clone().to(device=device)
    weight_lazy = weight.detach().clone().to(device=device)
    bias_lazy = bias.detach().clone().to(device=device)

    for i in range(10):
        z_lazy, _, _ = torch.ops.aten.native_batch_norm(
            x_lazy, weight_lazy, bias_lazy, None, None, True, 0.1, 1e-5)
        torch._lazy.mark_step()

    torch.testing.assert_close(z.cpu(), z_lazy.cpu())
    assert metrics.counter_value("IrNodeReused_torch::lazy::TSNativeBatchNormForward") >= 7
    metrics.reset()
    torch._lazy.ir_cache.reset()
# Exercises the same IR reuse as testBatchNorm, but through the nn.BatchNorm2d
# module API instead of calling the aten op directly. (The method needs a
# distinct name; a second `testBatchNorm` would silently shadow the one above.)
def testBatchNormModule(self):
    device = get_test_device()
    x = torch.randn(16, 3, 224, 224, device=device)
    bn = torch.nn.BatchNorm2d(3).to(device=device)

    for i in range(10):
        z = bn(x)

    device = "lazy"
    x_lazy = x.detach().clone().to(device=device)
    bn = bn.to(device=device)

    for i in range(10):
        z_lazy = bn(x_lazy)
        torch._lazy.mark_step()

    torch.testing.assert_close(z.cpu(), z_lazy.cpu())
    assert metrics.counter_value("IrNodeReused_torch::lazy::TSNativeBatchNormForward") >= 7
    metrics.reset()
    torch._lazy.ir_cache.reset()
def testAdd(self):
    device = get_test_device()
    x = torch.randn(2, 3, 4, device=device)
    y = torch.randn(2, 3, 4, device=device)
    z = torch.zeros(2, 3, 4, device=device)

    device = "lazy"
    x_lazy = x.detach().clone().to(device=device)
    y_lazy = y.detach().clone().to(device=device)
    z_lazy = z.detach().clone().to(device=device)

    for i in range(10):
        z += (x + y)

    for i in range(10):
        z_lazy += (x_lazy + y_lazy)
        torch._lazy.mark_step()

    torch.testing.assert_close(z.cpu(), z_lazy.cpu())
    assert metrics.counter_value("IrNodeReused_torch::lazy::AddTensor") >= 14
    metrics.reset()
    torch._lazy.ir_cache.reset()
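
# For context, all of the tests above follow the same tracing pattern: ops on
# "lazy" tensors only record IR nodes, and torch._lazy.mark_step() compiles
# and executes the accumulated graph. Re-recording the same trace hits the IR
# cache, which is what the IrNodeReused_* counters measure. A minimal sketch,
# assuming the TorchScript backend was initialized once per process:
def _lazy_tracing_demo():
    a = torch.randn(2, 3, device="lazy")
    b = torch.randn(2, 3, device="lazy")
    c = a + b  # recorded as an IR node, not executed yet
    torch._lazy.mark_step()  # compile and run the pending graph
    print(c.cpu())  # materialized result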
def extract_compiled_graph(model: fx.GraphModule, example_inputs) -> Callable:
    """
    Optimize an eager model with LTC and return a wrapper that executes the
    compiled graph directly without retracing. It depends on other mechanisms
    like TorchDynamo guards to guarantee the returned wrapper is only called
    when it's safe.
    """
    lazy_args = [arg.to(device="lazy") for arg in example_inputs]
    args_tensor_ids = [lazy.get_tensor_id(lazy_arg) for lazy_arg in lazy_args]
    tensor_id_to_arg_idx = {tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)}
    lazy_model = copy.deepcopy(model).to(device=torch.device("lazy"))
    force_lazy_device(lazy_model)

    # This call executes lazy tracing and enables us to extract the compiled graph later
    metrics.reset()
    lazy_out = lazy_model(*lazy_args)
    fallback_ops = get_fallback_ops()
    metrics.reset()

    if len(fallback_ops) > 0:
        raise RuntimeError(
            f"Failed to extract the compiled graph because of fallback: {','.join(fallback_ops)}"
        )

    if not isinstance(lazy_out, (tuple, list)):
        lazy_out = (lazy_out,)

    args_and_out = tuple(lazy_args) + tuple(lazy_out)
    return_value_handler = ReturnValueHandler(args_and_out)
    if debug:
        print("Fx code:\n", model.code)
        print("LTC IR:", lazy_debug.dump_ir(args_and_out, "text"))

    # TODO: this part is TS backend specific for now and will be generalized to
    # support XLA
    (
        graph_input_tensor_ids,
        graph_input_ivalues,
    ) = computation.get_tensors_ts_device_data_node(args_and_out)
    assert len(graph_input_tensor_ids) == len(graph_input_ivalues)
    graph_input_matcher = GraphInputMatcher(
        tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_ivalues)

    graph_hash = computation.get_graph_hash(args_and_out)

    if debug:
        print("graph_hash", graph_hash)
        print(f"args_tensor_ids {args_tensor_ids}")
        print("tensor ids from device data:", graph_input_tensor_ids)

    # sync the list of output tensors so the computation graph for these
    # tensors will be cached. Those computation graphs can be retrieved
    # by graph hash later.
    lazy.sync_multi(args_and_out, [])

    def optimized_mod(*args):
        if len(args_and_out) == 0:
            return ()
        graph_input = graph_input_matcher(args)
        res = return_value_handler.duplicate_eager_tensors(
            computation.run_cached_graph(graph_hash, graph_input))

        assert len(res) == len(args_and_out)
        for i, arg in enumerate(args):
            # only copy those tensors that get inplace updated
            if arg is not res[i]:
                arg.copy_(res[i])

        # skip the args
        return res[len(args):]

    return optimized_mod
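
# Usage sketch (not from the source; the toy module and shapes are
# illustrative): the wrapper returned by extract_compiled_graph replays the
# cached graph without retracing, so callers (e.g. TorchDynamo) must guard
# that later inputs still match the original trace.
def _extract_compiled_graph_demo():
    import torch.fx as fx

    class TwoOps(torch.nn.Module):  # hypothetical toy module
        def forward(self, x):
            return (x + 1).relu()

    example_inputs = [torch.randn(4, 8)]
    optimized = extract_compiled_graph(fx.symbolic_trace(TwoOps()), example_inputs)
    print(optimized(*example_inputs))  # executes the cached graph, no retracing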