예제 #1
0
def run_model(m, input, randomize):
    """Trace ``m``, lower it with ``torch_glow.to_glow`` and run it on ``input``.

    Args:
        m: Module (or callable) handed to ``torch.jit.trace``.
        input: Example input; used for tracing, as the template for the
            input spec (``setSameAs``), and as the argument to the final
            ``forward`` call.
        randomize: If truthy, ``torch_glow.enable_randomize_constants()`` is
            called; otherwise ``torch_glow.disable_randomize_constants()``.

    Returns:
        Whatever ``forward(input)`` returns on the lowered module.

    NOTE(review): the randomize-constants and fusion-pass toggles look like
    process-wide torch_glow settings that are not restored here -- confirm.
    """
    # Pick the toggle first, then invoke it; this happens before tracing.
    set_randomization = (
        torch_glow.enable_randomize_constants
        if randomize
        else torch_glow.disable_randomize_constants
    )
    set_randomization()

    torch_glow.disableFusionPass()
    traced_module = torch.jit.trace(m, input)

    # Compile spec targeting the "Interpreter" backend.
    compile_spec = torch.classes.glow.GlowCompileSpec()
    compile_spec.setBackend("Interpreter")

    # Single input description modelled on the example input.
    input_meta = torch.classes.glow.SpecInputMeta()
    input_meta.setSameAs(input)
    compile_spec.addInputs([input_meta])

    lowered_module = torch_glow.to_glow(traced_module, {"forward": compile_spec})
    return lowered_module.forward(input)
예제 #2
0
def run_model(m, input, randomize):
    """Trace ``m``, lower it with ``torch_glow.to_glow`` and call it on ``input``.

    Args:
        m: Module (or callable) handed to ``torch.jit.trace``.
        input: Example input; used for tracing, as the template for the
            input spec (``set_same_as``), and as the argument to the final
            call.
        randomize: If truthy, ``torch_glow.enable_randomize_constants()`` is
            called; otherwise ``torch_glow.disable_randomize_constants()``.
            The toggle is applied after tracing and before lowering.

    Returns:
        Whatever calling the lowered module with ``input`` returns.

    NOTE(review): the randomize-constants and fusion-pass toggles look like
    process-wide torch_glow settings that are not restored here -- confirm.
    """
    torch_glow.disableFusionPass()
    traced_module = torch.jit.trace(m, input)

    # Pick the toggle, then invoke it (after tracing, as in the original flow).
    set_randomization = (
        torch_glow.enable_randomize_constants
        if randomize
        else torch_glow.disable_randomize_constants
    )
    set_randomization()

    # Compilation spec targeting the "Interpreter" backend.
    compile_spec = torch_glow.CompilationSpec()
    compile_spec.get_settings().set_glow_backend("Interpreter")

    # One compilation group, registered on the spec before it is populated.
    group = torch_glow.CompilationGroup()
    compile_spec.compilation_groups_append(group)

    # Single input description modelled on the example input.
    example_input_spec = torch_glow.InputSpec()
    example_input_spec.set_same_as(input)

    group.input_sets_append([example_input_spec])

    lowered_module = torch_glow.to_glow(traced_module, {"forward": compile_spec})
    return lowered_module(input)