Beispiel #1
0
def register_vta_tuning_tasks():
    """Register a VTA conv2d template with AutoTVM under ``topi_nn_conv2d``.

    Defines a two-stage clip helper and instantiates ``TaskExtractEnv``. It
    then registers (with ``override=True``) a template whose compute is
    ``conv2d -> right_shift(8) -> clip[0, 127] -> cast("int8")``, built inside
    a ``tvm.target.vta()`` scope. Returns nothing; the effect is the
    registration.
    """
    from tvm.autotvm.task.topi_integration import TaskExtractEnv, deserialize_args

    @tvm.tag_scope(tag=topi.tag.ELEMWISE)
    def my_clip(x, a_min, a_max):
        """Clamp ``x`` into [a_min, a_max] using two separate compute stages.

        topi's clip applies both bounds in a single stage. Here the upper bound
        is applied in stage "clipA" and the lower bound in stage "clipB".
        """
        lo = tvm.const(a_min, x.dtype)
        hi = tvm.const(a_max, x.dtype)
        # Give each stage its own name (instead of rebinding ``x``) so every
        # lambda's input tensor is unambiguous.
        capped = tvm.compute(x.shape, lambda *idx: tvm.min(x(*idx), hi), name="clipA")
        floored = tvm.compute(capped.shape, lambda *idx: tvm.max(capped(*idx), lo), name="clipB")
        return floored

    # Initialise the autotvm task-extraction env so the VTA operator gets
    # registered; the instance itself is not kept.
    TaskExtractEnv()

    @autotvm.task.register("topi_nn_conv2d", override=True)
    def _topi_nn_conv2d(*args, **kwargs):
        """AutoTVM template for a quantized conv2d targeting VTA.

        Returns ``(schedule, [data, kernel, output])``. If the target active at
        call time (outside the inner ``tvm.target.vta()`` scope) has
        device_name 'vta', the generic VTA conv2d schedule is used. Otherwise a
        default ``tvm.create_schedule`` schedule is used.
        """
        assert not kwargs, "Do not support kwargs in template function call"
        call_args = deserialize_args(args)
        data, kernel = call_args[:2]

        with tvm.target.vta():
            conv = topi.nn.conv2d(*call_args, **kwargs)
            shifted = topi.right_shift(conv, 8)
            clipped = my_clip(shifted, 0, 127)
            out = topi.cast(clipped, "int8")

        # current_target() is queried after leaving the vta() scope, so the
        # choice follows whatever target the caller set up.
        on_vta = tvm.target.current_target().device_name == 'vta'
        sched = (
            topi.generic.schedule_conv2d_nchw([out])
            if on_vta
            else tvm.create_schedule([out.op])
        )
        return sched, [data, kernel, out]
Beispiel #2
0
        results = verify_bitserial_conv2d_nhwc(log_file, batch, in_size, ic,
                                               oc, k, stride, padding,
                                               activation_bits, weight_bits,
                                               in_dtype, pack_dtype, out_dtype,
                                               parallel)
        mean_ms = np.mean(results) * 1000
        std_dev_ms = np.std(results) * 1000
        print("Workload", w, " Average time", mean_ms, "ms", "std deviation",
              std_dev_ms, "ms")
        all_results[itr] = results
    np.savetxt(raw_data, all_results)


if __name__ == "__main__":
    # NOTE(review): ``args`` is not defined in this block; presumably it is the
    # command-line namespace parsed at module level — confirm.
    TaskExtractEnv.get()

    # Convolution layers of ResNet, minus the first one, which is traditionally
    # not binarized. Tuple fields presumably follow the argument order of
    # verify_bitserial_conv2d_nhwc: (batch, in_size, in_channels,
    # out_channels, kernel, stride, padding) — confirm.
    resnet_workload = [(1, 56, 64, 64, 3, 1, 'SAME'),
                       (1, 56, 64, 64, 1, 1, 'VALID'),
                       (1, 56, 64, 128, 3, 2, 'SAME'),
                       (1, 56, 64, 128, 1, 2, 'VALID'),
                       (1, 28, 128, 128, 3, 1, 'SAME'),
                       (1, 28, 128, 256, 3, 2, 'SAME'),
                       (1, 28, 128, 256, 1, 2, 'VALID'),
                       (1, 14, 256, 256, 3, 1, 'SAME'),
                       (1, 14, 256, 512, 3, 2, 'SAME'),
                       (1, 14, 256, 512, 1, 2, 'VALID'),
                       (1, 7, 512, 512, 3, 1, 'SAME')]
    if args.first:
        # Restrict the run to the first workload only.
        resnet_workload = [resnet_workload[0]]

    # Emits the bit widths with no separator, e.g. "A2W1".
    print("A", args.activation_bits, "W", args.weight_bits, sep='')