コード例 #1
0
    # VTA target and execution context
    target = env.target if opt.device == "vta" else env.target_vta_cpu
    ctx = remote.ext_dev(0) if opt.device == "vta" else remote.cpu(0)
    
    # Compile Relay program
    print("Initial compile...")
    relay_prog, params = compile_network(opt, env, target)

    # Register VTA tuning tasks
    register_vta_tuning_tasks()

    # Perform task extraction on Relay program
    print("Extracting tasks...")
    tasks = extract_from_program(func=relay_prog,
                                 params=params,
                                 ops=(tvm.relay.op.nn.conv2d,),
                                 target=target,
                                 target_host=env.target_host)

    # Perform Autotuning
    print("Tuning...")
    tuning_opt = {
        'log_filename': opt.log_filename,
        'tuner': opt.tuner,
        'n_trial': 1e9,
        'early_stopping': None,
        'measure_option': autotvm.measure_option(
                builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
                runner=autotvm.RPCRunner(env.TARGET, tracker_host, tracker_port,
                    number=4, min_repeat_ms=150, repeat=opt.measurements, timeout=60,
                    check_correctness=True))
コード例 #2
0
    target = env.target if opt.device == "vta" else env.target_vta_cpu
    ctx = remote.ext_dev(0) if opt.device == "vta" else remote.cpu(0)

    # Compile Relay program
    print("Initial compile...")
    relay_prog, params = compile_network(opt, env, target)

    # Register VTA tuning tasks
    register_vta_tuning_tasks()

    # Perform task extraction on Relay program
    print("Extracting tasks...")
    tasks = extract_from_program(
        func=relay_prog,
        params=params,
        ops=(relay.op.get("nn.conv2d"),),
        target=target,
        target_host=env.target_host,
    )

    # Perform Autotuning
    print("Tuning...")
    tuning_opt = {
        "log_filename": opt.log_filename,
        "tuner": opt.tuner,
        "n_trial": 1e9,
        "early_stopping": None,
        "measure_option": autotvm.measure_option(
            builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
            runner=autotvm.RPCRunner(
                env.TARGET,