def run_model(m, input, randomize):
    """Trace ``m``, lower it to Glow's "Interpreter" backend, and run it on ``input``.

    This variant builds its compile spec with the ``torch.classes.glow``
    classes (``GlowCompileSpec`` / ``SpecInputMeta``).

    NOTE(review): a later ``def run_model`` in this file rebinds the name, so
    this definition is shadowed once the module finishes loading.

    Args:
        m: Module passed to ``torch.jit.trace``.
        input: Example input. It is used for tracing, as the template for the
            single input spec (``setSameAs``), and as the argument fed to the
            compiled module.
        randomize: If truthy, ``torch_glow.enable_randomize_constants()`` is
            called; otherwise ``torch_glow.disable_randomize_constants()``.

    Returns:
        The result of ``forward(input)`` on the module returned by
        ``torch_glow.to_glow``.
    """
    # Pick the toggle first, then invoke it. These calls take no arguments,
    # so they presumably flip torch_glow-wide state that is not restored on
    # return. TODO confirm.
    set_randomize_constants = (
        torch_glow.enable_randomize_constants
        if randomize
        else torch_glow.disable_randomize_constants
    )
    set_randomize_constants()

    torch_glow.disableFusionPass()
    traced = torch.jit.trace(m, input)

    compile_spec = torch.classes.glow.GlowCompileSpec()
    compile_spec.setBackend("Interpreter")

    # One input spec, presumably mirroring the example input's metadata
    # (setSameAs). Verify against the Glow bindings.
    input_meta = torch.classes.glow.SpecInputMeta()
    input_meta.setSameAs(input)
    compile_spec.addInputs([input_meta])

    method_specs = {"forward": compile_spec}
    lowered = torch_glow.to_glow(traced, method_specs)
    return lowered.forward(input)
def run_model(m, input, randomize):
    """Trace ``m``, lower it to Glow's "Interpreter" backend, and call it on ``input``.

    This variant builds its compile spec with the ``torch_glow`` binding
    classes (``CompilationSpec`` / ``CompilationGroup`` / ``InputSpec``).

    Args:
        m: Module passed to ``torch.jit.trace``.
        input: Example input. It is used for tracing, as the template for the
            single input spec (``set_same_as``), and as the argument passed to
            the compiled module.
        randomize: If truthy, ``torch_glow.enable_randomize_constants()`` is
            called; otherwise ``torch_glow.disable_randomize_constants()``.
            The toggle happens after tracing and before ``to_glow``.

    Returns:
        The result of calling the module returned by ``torch_glow.to_glow``
        on ``input``.
    """
    torch_glow.disableFusionPass()
    traced = torch.jit.trace(m, input)

    # These toggles take no arguments, so they presumably set torch_glow-wide
    # state that is not restored on return. TODO confirm.
    set_randomize_constants = (
        torch_glow.enable_randomize_constants
        if randomize
        else torch_glow.disable_randomize_constants
    )
    set_randomize_constants()

    compile_spec = torch_glow.CompilationSpec()
    settings = compile_spec.get_settings()
    settings.set_glow_backend("Interpreter")

    # The group is appended to the spec before input sets are added to it.
    # This order is kept deliberately, as the binding may hold the group by
    # reference.
    group = torch_glow.CompilationGroup()
    compile_spec.compilation_groups_append(group)

    example_spec = torch_glow.InputSpec()
    example_spec.set_same_as(input)
    group.input_sets_append([example_spec])

    method_specs = {"forward": compile_spec}
    lowered = torch_glow.to_glow(traced, method_specs)
    # Invoke via __call__ (not .forward), matching the original call form.
    return lowered(input)