def generate_fully_serial_mlir(num_kernels):
    """Generate a fully serial DAG for benchmarking BEFExecutor.

    The generated body chains ``num_kernels`` ``tfrt.add.i32`` kernels.
    Kernel ``i`` reads the result ``%c{i-1}`` of the kernel before it, so no
    two add kernels are independent of each other.

    Args:
      num_kernels: Number of chained add kernels to emit. The function
        returns the last result, ``%c<num_kernels>``.

    Returns:
      Whatever ``generate_benchmark_mlir`` returns for the benchmark named
      ``BM_full_serial_<num_kernels>`` with the generated body.
    """
    prologue = """
  // The pseudo-code for this mlir function is as follows:
  //
  // a = 1
  // c0 = 1
  // c1 = c0 + a
  // c2 = c1 + a
  // c3 = c2 + a
  // ...
  // Since each c_i depends on c_{i-1}, all c_i's need to be computed in serial.

  %a = tfrt.constant.i32 1
  %c0 = tfrt.constant.i32 1
"""
    # One add per kernel; each consumes the previous kernel's result.
    chained_adds = (
        '  %c{} = "tfrt.add.i32"(%c{}, %a) : (i32, i32) -> i32'.format(
            step + 1, step)
        for step in range(num_kernels))
    # The function returns the result of the last add in the chain.
    epilogue = '\n  tfrt.return %c{} : i32'.format(num_kernels)

    bench_name = 'BM_full_serial_{}'.format(num_kernels)
    return generate_benchmark_mlir(
        bench_name, prologue + '\n'.join(chained_adds) + epilogue)
def generate_fully_parallel_mlir(num_kernels):
    """Generate a fully parallel DAG for benchmarking BEFExecutor.

    The generated body emits ``num_kernels`` ``tfrt_test.async_add.i32``
    kernels. Every kernel reads only ``%c0`` and ``%a``, so none of them
    depends on another.

    Args:
      num_kernels: Number of independent add kernels to emit. The function
        returns ``%c<num_kernels>``.

    Returns:
      Whatever ``generate_benchmark_mlir`` returns for the benchmark named
      ``BM_full_parallel_<num_kernels>`` with the generated body.
    """
    prologue = """
  // The pseudo-code for this mlir function is as follows:
  //
  // a = 1
  // c0 = 1
  // c1 = c0 + a
  // c2 = c0 + a
  // c3 = c0 + a
  // ...
  //
  // Since c_i's have no dependency with each other, they can be computed in
  // parallel.

  %a = tfrt.constant.i32 1
  %c0 = tfrt.constant.i32 1
"""
    kernel_template = '  %c{} = "tfrt_test.async_add.i32"(%c0, %a) : (i32, i32) -> i32'

    # Results are numbered %c1 .. %c<num_kernels>; all of them read %c0.
    kernels = []
    for result_id in range(1, num_kernels + 1):
        kernels.append(kernel_template.format(result_id))

    epilogue = '\n  tfrt.return %c{} : i32'.format(num_kernels)
    full_body = prologue + '\n'.join(kernels) + epilogue

    return generate_benchmark_mlir(
        'BM_full_parallel_{}'.format(num_kernels), full_body)
def generate_host_tensor(num_kernels):
    """Benchmark DHTIndexableView overhead.

    Generate a no-op host tensor program for benchmarking the overhead of
    DHTIndexableView. The body creates a single tensor with
    ``dht.create_uninitialized_tensor.i32.2 [3 : i32, 2 : i32]`` and a fresh
    chain. It then threads that chain through ``num_kernels``
    ``dht.no_op_ht`` kernels, each of which receives the same tensor.

    Args:
      num_kernels: Number of chained ``dht.no_op_ht`` kernels to emit. The
        function returns the last chain, ``%c<num_kernels>``.

    Returns:
      Whatever ``generate_benchmark_mlir`` returns for the benchmark named
      ``BM_HostTensor_<num_kernels>`` with the generated body.
    """

    body = """
  // The pseudo-code for this mlir function is as follows:
  //
  // t = dense_host_tensor
  // c0 = tfrt.chain
  // c1 = dht.no_op_ht(t, c0)
  // c2 = dht.no_op_ht(t, c1)
  // c3 = dht.no_op_ht(t, c2)
  // ...
  // return cn

  %t = dht.create_uninitialized_tensor.i32.2 [3 : i32, 2 : i32]
  %c0 = tfrt.new.chain
"""

    def gen_line(c):
        # Kernel c+1 consumes the chain produced by kernel c.
        return ('  %c{} = "dht.no_op_ht"(%t, %c{}) : '
                '(!dht.dense_host_tensor.i32.2, !tfrt.chain) -> !tfrt.chain'
                ).format(c + 1, c)

    body += '\n'.join([gen_line(c) for c in range(num_kernels)])
    # Add return statement. The index must be substituted with .format();
    # a plain literal would emit the text "%c{num_kernels}", which is not a
    # defined SSA value.
    body += '\n  tfrt.return %c{} : !tfrt.chain'.format(num_kernels)

    return generate_benchmark_mlir('BM_HostTensor_{}'.format(num_kernels),
                                   body)
def generate_star_mlir(num_kernels):
    """Generate a star-shaped (fan-in) DAG for benchmarking BEFExecutor.

    The body emits ``num_kernels`` independent ``tfrt_test.async_add.i32``
    kernels, each computing ``%c0 + %a``. It then emits one
    ``tfrt_test.sum`` kernel that takes every add result ``%c1`` ..
    ``%c<num_kernels>`` as operands (``%c0`` is not an operand). The sum's
    result ``%s`` is returned, so it depends on all of the add kernels.

    Args:
      num_kernels: Number of parallel add kernels feeding the final sum.

    Returns:
      Whatever ``generate_benchmark_mlir`` returns for the benchmark named
      ``BM_star_<num_kernels>`` with the generated body.
    """

    body = """
  // The pseudo-code for this mlir function is as follows:
  //
  // a = 1
  // c0 = 1
  // c1 = c0 + a
  // c2 = c0 + a
  // c3 = c0 + a
  // ...
  // s = sum(c0, c1, c2 ...)
  //
  // Since c_i's have no dependency with each other, they can be computed in
  // parallel. s depends on all of c_i's, thus can only be computed after
  // the computation for all c_i's is done.

  %a = tfrt.constant.i32 1
  %c0 = tfrt.constant.i32 1
"""

    def gen_line(c):
        # Every add reads only %c0 and %a, so the adds are mutually
        # independent.
        return '  %c{} = "tfrt_test.async_add.i32"(%c0, %a) : (i32, i32) -> i32'.format(
            c + 1)

    # Construct the fan-in sum statement:
    #   %s = "tfrt_test.sum"(%c1, %c2, ...) : (i32, ..., i32) -> i32
    sum_line = '  %s = "tfrt_test.sum"({args}) : ({arg_types}) -> i32'.format(
        args=', '.join(['%c{}'.format(i + 1) for i in range(num_kernels)]),
        # One i32 operand type per add result.
        arg_types=', '.join(['i32'] * num_kernels),
    )

    body += '\n'.join([gen_line(c) for c in range(num_kernels)])
    body += ('\n' + sum_line)
    # Add return statement.
    body += '\n  tfrt.return %s : i32'

    return generate_benchmark_mlir('BM_star_{}'.format(num_kernels), body)