Exemplo n.º 1
0
def main(_):
  """Generates code for the ops defined in this module and writes it to disk.

  If ``FLAGS.gen_register_op`` is truthy, C++ op-registration source is
  produced with ``gen_register_op`` and ``FLAGS.output`` must end in ``.cc``.
  Otherwise TFR MLIR is produced with ``tfr_gen_from_module`` and
  ``FLAGS.output`` must end in ``.mlir``. Both generators receive this module
  object and the string ``'_override_'``. Its meaning is defined by the
  generators (presumably an op-name prefix/filter -- confirm).

  Missing parent directories of ``FLAGS.output`` are created before writing.

  Args:
    _: Ignored positional argument (presumably the argv list passed by the
      application launcher).

  Raises:
    ValueError: If ``FLAGS.output`` does not have the extension required by
      the selected generation mode.
  """
  output = FLAGS.output
  module = sys.modules[__name__]

  # Explicit checks rather than `assert`, which is stripped under `python -O`
  # and would silently allow writing one format into a file named for another.
  if FLAGS.gen_register_op:
    if not output.endswith('.cc'):
      raise ValueError(
          'gen_register_op output must be a .cc file, got: %r' % output)
    generated_code = gen_register_op(module, '_override_')
  else:
    if not output.endswith('.mlir'):
      raise ValueError('TFR output must be a .mlir file, got: %r' % output)
    generated_code = tfr_gen_from_module(module, '_override_')

  dirname = os.path.dirname(output)
  # A bare filename has an empty dirname, and os.makedirs('') raises, so only
  # create a directory when there is one. exist_ok=True also removes the
  # race between checking for the directory and creating it.
  if dirname:
    os.makedirs(dirname, exist_ok=True)
  with open(output, 'w', encoding='utf-8') as f:
    f.write(generated_code)
Exemplo n.º 2
0
 def test_op_reg_gen(self):
   """Verifies the C++ op-registration source generated for this module.

   The generated text is matched against FileCheck-style directives. It must
   include the op framework header, open the ``tensorflow`` namespace,
   register ``TestNoOp`` and ``TestCompositeOp`` with the listed inputs,
   attrs and outputs, and close the namespace. On failure, the generated
   source is used as the assertion message.
   """
   # NOTE(review): the bare ``CHECK-EMPTY`` lines below have no trailing
   # colon. FileCheck presumably treats them as plain text rather than as
   # directives. Confirm the intended blank-line checks before adding ``:``.
   expected = r"""
   CHECK: #include "tensorflow/core/framework/op.h"
   CHECK-EMPTY
   CHECK: namespace tensorflow {
   CHECK-EMPTY
   CHECK-LABEL: REGISTER_OP("TestNoOp")
   CHECK-NEXT:      .Attr("T: numbertype")
   CHECK-NEXT:      .Output("o1: T");
   CHECK-EMPTY
   CHECK-LABEL: REGISTER_OP("TestCompositeOp")
   CHECK-NEXT:      .Input("x: T")
   CHECK-NEXT:      .Input("y: T")
   CHECK-NEXT:      .Attr("act: {'', 'relu'}")
   CHECK-NEXT:      .Attr("trans: bool = true")
   CHECK-NEXT:      .Attr("T: numbertype")
   CHECK-NEXT:      .Output("o1: T")
   CHECK-NEXT:      .Output("o2: T");
   CHECK-EMPTY
   CHECK:  }  // namespace tensorflow
 """
   this_module = sys.modules[__name__]
   generated = str(gen_register_op(this_module))
   matched = fw.check(generated, expected)
   self.assertTrue(matched, generated)