def gen_chxvm_codegen_cc():
    """Write gen_chxvm_codegen.cc into args.output_dir.

    For every op in CHX_ALL_OPS this emits one C++ function whose
    signature comes from make_codegen_signature().  The function body
    appends a ChxVMInstructionProto to `program`, sets its op enum, adds
    one ChxVMValueProto per input (typed by the input's proto enum with
    any `OPTIONAL_` prefix stripped), and finally registers every output
    through AddOutput().
    """

    def input_lines(inp):
        # Each input is filled inside its own C++ scope so the local
        # `input_proto` name can be reused for the next input.
        value_type = inp.typ.replace('OPTIONAL_', '')
        field = inp.proto_field_name()
        if inp.is_repeated():
            fill = 'for (auto v : %s) input_proto->add_%s(v);' % (inp.name,
                                                                  field)
        else:
            fill = 'input_proto->set_%s(%s);' % (field, inp.name)
        return [
            '{',
            'ChxVMValueProto* input_proto = inst->add_inputs();',
            'input_proto->set_type(ChxVMValueProto::%s);' % value_type,
            fill,
            '}',
        ]

    def output_line(kind, var):
        # Array-list outputs are a container of ChxVMValue; everything
        # else is a single ChxVMValue.
        if kind == ARRAY_LIST:
            return 'for (const ChxVMValue& v : %s) v.AddOutput(inst);' % var
        return '%s.AddOutput(inst);' % var

    code = []
    for op in CHX_ALL_OPS:
        code.append(
            make_codegen_signature(op.name, op.inputs, op.outputs) + ' {')
        code.append(
            'ChxVMInstructionProto* inst = program->add_instructions();')
        code.append('inst->set_op(ChxVMInstructionProto::%s);' % op.name)
        for inp in op.inputs:
            code.extend(input_lines(inp))
        code.extend(output_line(kind, var) for kind, var in op.outputs)
        code.append('}')

    header = r'''// Auto-generated by gen_chxvm.py

#include <compiler/gen_chxvm_codegen.h>
#include <runtime/chxvm.pb.h>

namespace chainer_compiler {
namespace chxvm {

using runtime::ChxVMInstructionProto;
using runtime::ChxVMValueProto;

'''
    footer = r'''
}  // namespace chxvm
}  // namespace chainer_compiler
'''
    with open(args.output_dir + '/gen_chxvm_codegen.cc', 'w') as out:
        out.write(header)
        out.writelines(codegen_util.format_code(code))
        out.write(footer)
def gen_chxvm_codegen_h():
    """Write gen_chxvm_codegen.h into args.output_dir.

    The header holds one declaration (signature from
    make_codegen_signature() followed by `;`) per op in CHX_ALL_OPS,
    wrapped in the chainer_compiler::chxvm namespace.
    """
    declarations = [
        make_codegen_signature(op.name, op.inputs, op.outputs) + ';'
        for op in CHX_ALL_OPS
    ]

    header = r'''// Auto-generated by gen_chxvm.py

#pragma once

#include <compiler/chxvm/chxvm_value.h>
#include <runtime/chxvm.pb.h>

namespace chainer_compiler {
namespace chxvm {

'''
    footer = r'''
}  // namespace chxvm
}  // namespace chainer_compiler
'''
    path = args.output_dir + '/gen_chxvm_codegen.h'
    with open(path, 'w') as out:
        out.write(header)
        out.writelines(codegen_util.format_code(declarations))
        out.write(footer)
def _node_base_cond(msvc_flag, conds, bodies):
    """Emit an if / else-if chain for `conds` with the matching `bodies`.

    On POSIX codegen_util.cond() is used; elsewhere the goto-based
    codegen_util.cond_msvc() variant is used with `msvc_flag` as its flag
    name (presumably to sidestep MSVC limits on long else-if chains --
    NOTE(review): confirm).
    """
    if os.name == 'posix':
        return codegen_util.cond(conds, bodies)
    return codegen_util.cond_msvc(msvc_flag, conds, bodies)


def _node_base_string_to_op_type():
    """Emit NodeBase::StringToOpType (op_type string -> OpType enum)."""
    lines = ['NodeBase::OpType NodeBase::StringToOpType'
             '(const std::string& str) {']
    conds = ['str == "%s"' % node.op_type for node in NODES]
    bodies = [['return k%s;' % node.op_type] for node in NODES]
    # Trailing else-branch: unknown names are a hard error.
    bodies.append(['CHECK(false) << "Unsupported op_type: " << str;'])
    lines.extend(_node_base_cond('goto_flag_type', conds, bodies))
    lines.append('}')
    return lines


def _node_base_attr_parse_lines(attr):
    """Return C++ statements that copy one ONNX attribute into NodeBase.

    The statements run inside a loop over `xnode.attribute()` where the
    current attribute is bound to `xattr`.

    Raises:
        RuntimeError: if `attr.type` is not a supported attribute type.
    """
    c_name = attr.c_name
    body = ['if (!g_permissive) '
            'CHECK_EQ(xattr.type(), %s) << xnode.DebugString();' %
            attr.onnx_type()]
    if attr.type == int:
        body.append('set_%s(xattr.i());' % c_name)
    elif attr.type == bool:
        body.append('set_%s(xattr.i() != 0);' % c_name)
    elif attr.type == float:
        body.append('set_%s(xattr.f());' % c_name)
    elif attr.type == str:
        body.append('set_%s(xattr.s());' % c_name)
    elif attr.type == [Tensor]:
        body.append('for (const auto& t : xattr.tensors()) '
                    '%s_.emplace_back(new Tensor(t));' % c_name)
    elif isinstance(attr.type, list):
        # Any other repeated type is copied straight from the proto field.
        fs = attr.onnx_field()
        body.append('%s_.assign(xattr.%s().begin(), xattr.%s().end());' %
                    (c_name, fs, fs))
    elif attr.type == Dtype:
        body.append(
            'set_%s(Dtype(onnx::TensorProto::DataType(xattr.i())));' %
            c_name)
    elif attr.type == Tensor:
        body.append('set_%s(new Tensor(xattr.t()));' % c_name)
    elif attr.type == Graph:
        body.append('set_%s(new Graph(xattr.g()));' % c_name)
    else:
        raise RuntimeError('Unknown attribute type: %s' % attr.type)
    body.append('was_%s_set_ = true;' % c_name)
    return body


def _node_base_constructors():
    """Emit both NodeBase constructors.

    The ONNX-proto constructor resolves the op type, applies defaults,
    validates input/output counts, parses every attribute of the node
    (unknown ones are kept in `unknown_attributes_` and are fatal unless
    g_permissive) and finally runs ValidateAttributes().
    """
    lines = ['NodeBase::NodeBase(OpType op_type) : op_type_(op_type) {}']
    lines.append('NodeBase::NodeBase(const onnx::NodeProto& xnode, '
                 'const std::vector<Value*>& inputs, '
                 'const std::vector<Value*>& outputs) {')
    lines.append('op_type_ = StringToOpType(xnode.op_type());')
    lines.append('SetDefaultAttributeValues();')
    lines.append('ValidateNumInputsOutputs(inputs, outputs);')
    lines.append('')
    lines.append('// Validate attributes.')
    lines.append('switch (op_type_) {')
    for node in NODES:
        op = node.op_type
        lines.append('case k%s: {' % op)
        lines.append('for (const onnx::AttributeProto& xattr : '
                     'xnode.attribute()) {')
        conds = []
        bodies = []
        for _, attr in sorted(node.attr_defs.items()):
            conds.append('xattr.name() == "%s"' % attr.onnx_name)
            bodies.append(_node_base_attr_parse_lines(attr))
        bodies.append([
            'if (!g_permissive) CHECK(false) << "Invalid attribute `"'
            '<< xattr.name() << "\' for " << OpTypeToString(op_type_);',
            'unknown_attributes_.push_back(xattr);'
        ])
        lines.extend(_node_base_cond('goto_flag_{}'.format(op), conds, bodies))
        lines.append('}')
        lines.append('break;')
        lines.append('}')
    lines.append('}')
    lines.append('ValidateAttributes();')
    lines.append('}')
    return lines


def _node_base_op_type_to_string():
    """Emit NodeBase::OpTypeToString (OpType enum -> op_type string)."""
    lines = ['const char* NodeBase::OpTypeToString(OpType op_type) {',
             'switch (op_type) {']
    for node in NODES:
        lines.append('case NodeBase::k%s: return "%s";' %
                     (node.op_type, node.op_type))
    lines.append('default: CHECK(false) << "Unknown op_type: " << '
                 'static_cast<int>(op_type);')
    lines.append('}')
    lines.append('}')
    return lines


def _node_base_fill_onnx_attributes():
    """Emit NodeBase::FillONNXAttributes.

    Only attributes whose `was_<name>_set_` flag is true are written,
    followed by every preserved unknown attribute.
    """
    lines = ['void NodeBase::FillONNXAttributes(onnx::NodeProto* xnode) '
             'const {']
    # ONNX AttributeProto stores `i` / `ints` as int64, so the helpers take
    # and iterate int64_t to avoid truncating large values.
    lines.append(r'''
auto add_int_attr = [&xnode](const std::string& name, int64_t v) {
    onnx::AttributeProto* xattr = xnode->add_attribute();
    xattr->set_name(name);
    xattr->set_type(onnx::AttributeProto::INT);
    xattr->set_i(v);
};
auto add_float_attr = [&xnode](const std::string& name, float v) {
    onnx::AttributeProto* xattr = xnode->add_attribute();
    xattr->set_name(name);
    xattr->set_type(onnx::AttributeProto::FLOAT);
    xattr->set_f(v);
};
auto add_string_attr = [&xnode](const std::string& name, const std::string& v) {
    onnx::AttributeProto* xattr = xnode->add_attribute();
    xattr->set_name(name);
    xattr->set_type(onnx::AttributeProto::STRING);
    xattr->set_s(v);
};
auto add_tensor_attr = [&xnode](const std::string& name, const std::unique_ptr<Tensor>& v) {
    if (!v.get()) return;
    onnx::AttributeProto* xattr = xnode->add_attribute();
    xattr->set_name(name);
    xattr->set_type(onnx::AttributeProto::TENSOR);
    v->ToONNX(xattr->mutable_t());
};
auto add_tensors_attr = [&xnode](const std::string& name, const std::vector<std::unique_ptr<Tensor>>& vec) {
    if (vec.empty()) return;
    onnx::AttributeProto* xattr = xnode->add_attribute();
    xattr->set_name(name);
    xattr->set_type(onnx::AttributeProto::TENSORS);
    for (const std::unique_ptr<Tensor>& t : vec) t->ToONNX(xattr->add_tensors());
};
auto add_graph_attr = [&xnode](const std::string& name, const std::unique_ptr<Graph>& v) {
    if (!v.get()) return;
    onnx::AttributeProto* xattr = xnode->add_attribute();
    xattr->set_name(name);
    xattr->set_type(onnx::AttributeProto::GRAPH);
    v->ToONNX(xattr->mutable_g());
};
auto add_ints_attr = [&xnode](const std::string& name, const std::vector<int64_t>& ints) {
    if (ints.empty()) return;
    onnx::AttributeProto* xattr = xnode->add_attribute();
    xattr->set_name(name);
    xattr->set_type(onnx::AttributeProto::INTS);
    for (int64_t s : ints) xattr->add_ints(s);
};
auto add_floats_attr = [&xnode](const std::string& name, const std::vector<float>& floats) {
    if (floats.empty()) return;
    onnx::AttributeProto* xattr = xnode->add_attribute();
    xattr->set_name(name);
    xattr->set_type(onnx::AttributeProto::FLOATS);
    for (float s : floats) xattr->add_floats(s);
};
auto add_strings_attr = [&xnode](const std::string& name, const std::vector<std::string>& strings) {
    if (strings.empty()) return;
    onnx::AttributeProto* xattr = xnode->add_attribute();
    xattr->set_name(name);
    xattr->set_type(onnx::AttributeProto::STRINGS);
    for (const std::string& s : strings) xattr->add_strings(s);
};
auto add_dtype_attr = [&xnode, add_int_attr](const std::string& name, Dtype v) {
    add_int_attr(name, static_cast<int>(v.ToONNX()));
};
''')
    lines.append('switch (op_type_) {')
    for node in NODES:
        lines.append('case k%s: {' % node.op_type)
        for _, attr in sorted(node.attr_defs.items()):
            lines.append('if (was_%s_set_)' % attr.c_name)
            lines.append(' %s("%s", %s_);' %
                         (attr.add_func(), attr.onnx_name, attr.c_name))
        lines.append('break;')
        lines.append('}')
    lines.append('}')
    lines.append(
        'for (const onnx::AttributeProto& xattr : unknown_attributes_) {')
    lines.append('*xnode->add_attribute() = xattr;')
    lines.append('}')
    lines.append('}')
    return lines


def _node_base_set_default_attribute_values():
    """Emit NodeBase::SetDefaultAttributeValues from each attr's `value`."""
    lines = ['void NodeBase::SetDefaultAttributeValues() {',
             # `inf` is available to default-value expressions.
             'const float inf = std::numeric_limits<float>::infinity();',
             'switch (op_type_) {']
    for node in NODES:
        lines.append('case k%s: {' % node.op_type)
        for _, attr in sorted(node.attr_defs.items()):
            if attr.value is None:
                continue
            if attr.type == str:
                lines.append('%s_ = "%s";' % (attr.c_name, attr.value))
            elif attr.type == bool:
                lines.append('%s_ = %s;' %
                             (attr.c_name, str(attr.value).lower()))
            else:
                lines.append('%s_ = %s;' % (attr.c_name, attr.value))
        lines.append('break;')
        lines.append('}')
    lines.append('}')
    lines.append('}')
    return lines


def _node_base_validate_num_inputs_outputs():
    """Emit NodeBase::ValidateNumInputsOutputs.

    `num_inputs` / `num_outputs` may be a tuple of allowed counts, a single
    count, or None (no check).
    """
    lines = ['void NodeBase::ValidateNumInputsOutputs('
             'const std::vector<Value*>& inputs, '
             'const std::vector<Value*>& outputs) const {',
             'switch (op_type_) {']
    for node in NODES:
        op = node.op_type
        lines.append('case k%s: {' % op)
        for sym, num in [('inputs', node.num_inputs),
                         ('outputs', node.num_outputs)]:
            if isinstance(num, tuple):
                cond = ' || '.join('%d == %s.size()' % (n, sym) for n in num)
                lines.append('CHECK(%s) << '
                             '"Unexpected number of %s for %s (" << '
                             '%s.size() << ")";' % (cond, sym, op, sym))
            elif num is not None:
                lines.append('CHECK_EQ(%d, %s.size()) << '
                             '"Unexpected number of %s for %s";' %
                             (num, sym, sym, op))
        lines.append('break;')
        lines.append('}')
    lines.append('}')
    lines.append('}')
    return lines


def _node_base_validate_attributes():
    """Emit NodeBase::ValidateAttributes (checks Required attributes)."""
    lines = ['void NodeBase::ValidateAttributes() const {',
             'switch (op_type_) {']
    for node in NODES:
        op = node.op_type
        lines.append('case k%s: {' % op)
        for key, value in node.attributes.items():
            if isinstance(value, Required):
                lines.append(
                    'CHECK(was_%s_set_) << "%s is mandatory for %s";' %
                    (key, key, op))
        lines.append('break;')
        lines.append('}')
    lines.append('}')
    lines.append('}')
    return lines


def _node_base_setters():
    """Emit one NodeBase::set_<attr>() per entry in ATTRS."""
    lines = []
    for attr in ATTRS:
        name = attr.c_name
        arg = attr.c_setter_arg_type()
        lines.append('NodeBase* NodeBase::set_%s(%s %s) {' % (name, arg, name))
        cond = ' || '.join('op_type_ == k%s' % t for t in attr.op_types)
        lines.append('CHECK(%s) << "Invalid attribute `%s\' for " '
                     '<< OpTypeToString(op_type_);' % (cond, name))
        if attr.type in [Tensor, Graph]:
            # Single owned object: the setter takes ownership of a raw ptr.
            lines.append('%s_.reset(%s);' % (name, name))
        elif attr.type == [Tensor]:
            lines.append('%s_ = std::move(%s);' % (name, name))
        else:
            lines.append('%s_ = %s;' % (name, name))
        lines.append('was_%s_set_ = true;' % name)
        lines.append('return this;')
        lines.append('}')
    return lines


def gen_gen_node_base_cc():
    """Write gen_node_base.cc, the generated implementation of NodeBase.

    Emits, in order: StringToOpType, the constructors, OpTypeToString,
    FillONNXAttributes, SetDefaultAttributeValues,
    ValidateNumInputsOutputs, ValidateAttributes and one setter per
    attribute in ATTRS.  The file is written to the current working
    directory.

    Raises:
        RuntimeError: if an attribute definition has an unsupported type.
    """
    lines = []
    lines += _node_base_string_to_op_type()
    lines += _node_base_constructors()
    lines += _node_base_op_type_to_string()
    lines += _node_base_fill_onnx_attributes()
    lines += _node_base_set_default_attribute_values()
    lines += _node_base_validate_num_inputs_outputs()
    lines += _node_base_validate_attributes()
    lines += _node_base_setters()

    with open('gen_node_base.cc', 'w') as f:
        f.write(r'''// Auto-generated by gen_node.py

#include "gen_node_base.h"

#include <cstdint>
#include <limits>
#include <string>
#include <utility>
#include <vector>

#include <common/log.h>
#include <compiler/flags.h>
#include <compiler/graph.h>
#include <compiler/onnx.h>
#include <compiler/tensor.h>

namespace chainer_compiler {

''')
        f.writelines(codegen_util.format_code(lines))
        f.write(r'''
}  // namespace chainer_compiler
''')
def gen_gen_node_base_h():
    """Write gen_node_base.h, the generated declaration of NodeBase.

    Public section: the OpType enum (one `k<OpType>` per entry in NODES),
    the static string <-> OpType converters, an op_type() getter, and a
    getter plus setter for every attribute in ATTRS.  Graph attributes
    also get a `release_<name>()`.  Protected section: the backing field
    and a `was_<name>_set_` flag per attribute.  Written to the current
    working directory.
    """
    public_lines = ['enum OpType {']
    public_lines.extend('k%s,' % node.op_type for node in NODES)
    public_lines += [
        '};',
        'static OpType StringToOpType(const std::string& str);',
        'static const char* OpTypeToString(OpType op_type);',
        'OpType op_type() const {',
        'return op_type_;',
        '}',
    ]
    protected_lines = []

    for attr in ATTRS:
        field = attr.c_name
        getter_type = attr.c_arg_type()
        storage_type = attr.c_type()
        public_lines += [
            '%s %s() const {' % (getter_type, field),
            'return %s_;' % field,
            '}',
        ]
        if attr.type == Graph:
            # Graph attributes can hand ownership back to the caller.
            public_lines += [
                'Graph* release_%s() {' % field,
                'return %s_.release();' % field,
                '}',
            ]
        setter_type = attr.c_setter_arg_type()
        public_lines.append('NodeBase* set_%s(%s %s);' %
                            (field, setter_type, field))
        protected_lines += [
            '%s %s_;' % (storage_type, field),
            'bool was_%s_set_ = false;' % field,
        ]

    body = public_lines + ['protected:'] + protected_lines

    header = r'''// Auto-generated by gen_node.py

#pragma once

#include <memory>
#include <string>
#include <vector>

#include <compiler/dtype.h>
#include <compiler/onnx.h>

namespace chainer_compiler {

class Graph;
class Tensor;
class Value;

class NodeBase {
public:
    void FillONNXAttributes(onnx::NodeProto* xnode) const;

    void SetDefaultAttributeValues();

    void ValidateNumInputsOutputs(const std::vector<Value*>& inputs, const std::vector<Value*>& outputs) const;
    void ValidateAttributes() const;

'''
    footer = r'''
    OpType op_type_;
    std::vector<onnx::AttributeProto> unknown_attributes_;

    explicit NodeBase(OpType op_type);
    NodeBase(const onnx::NodeProto& xnode, const std::vector<Value*>& inputs, const std::vector<Value*>& outputs);
};

}  // namespace chainer_compiler
'''
    with open('gen_node_base.h', 'w') as out:
        out.write(header)
        out.writelines(codegen_util.format_code(body, num_indents=4))
        out.write(footer)
def _xcvm_ops_h_run_impl_decl(op):
    """Return the `RunImpl` declaration for one XCVM op class.

    Untyped ops get `void RunImpl(XCVMState* st)`.  Typed ops receive
    their array-like inputs as parameters (field-typed inputs live in
    members instead), SEQUENCE outputs as trailing `XCVMSequence*`
    out-parameters, and return their remaining outputs: nothing, a single
    value, or a std::tuple of values.
    """
    args = ['XCVMState* st']
    rettype = 'void'
    if op.typed:
        for typ, name in op.inputs:
            if typ == ARRAY:
                args.append('const chainerx::Array& %s' % name)
            elif typ == OPTIONAL_ARRAY:
                args.append(
                    'const nonstd::optional<chainerx::Array>& %s' % name)
            elif typ == ARRAY_LIST:
                args.append('const std::vector<chainerx::Array>& %s' % name)
            elif typ == SEQUENCE:
                args.append('const XCVMSequence& %s' % name)
            elif typ == OPAQUE:
                args.append('const XCVMOpaque& %s' % name)
            else:
                assert typ in FIELD_TYPES, 'Unknown type: %s' % typ

        output_ctypes = []
        for typ, name in op.outputs:
            if typ == ARRAY_LIST:
                output_ctypes.append('std::vector<chainerx::Array>')
            elif typ == SEQUENCE:
                # Sequences are filled in place rather than returned.
                args.append('XCVMSequence* %s' % name)
            elif typ == OPAQUE:
                output_ctypes.append('XCVMOpaque*')
            else:
                output_ctypes.append('chainerx::Array')

        if len(output_ctypes) == 1:
            rettype = output_ctypes[0]
        elif output_ctypes:
            rettype = 'std::tuple<' + ', '.join(output_ctypes) + '>'
    return '%s RunImpl(%s);' % (rettype, ', '.join(args))


def gen_gen_xcvm_ops_h():
    """Write gen_xcvm_ops.h into output_dir.

    Declares one `<Name>Op` class deriving from XCVMOp for every op in
    XC_ALL_OPS: a constructor taking the XCInstructionProto, RunImpl/Run,
    storage for each input and output, and (for ops with custom fields)
    a destructor, InitImpl() and an opaque `<Name>Impl*` member.
    """
    lines = []
    for op in XC_ALL_OPS:
        lines.append('class %sOp : public XCVMOp {' % op.name)
        lines.append('public:')
        lines.append('explicit %sOp(const XCInstructionProto& inst);' % op.name)
        lines.append(_xcvm_ops_h_run_impl_decl(op))
        lines.append('virtual void Run(XCVMState* st);')

        lines.append('private:')
        for inp in op.inputs:
            lines.append('%s %s;' % (inp.c_storage_type(), inp.name))
        for out in op.outputs:
            lines.append('%s %s;' % (out.c_storage_type(), out.name))

        if op.has_custom_field:
            lines.append('~%sOp() override;' % op.name)
            lines.append('void InitImpl();')
            lines.append('class %sImpl;' % op.name)
            lines.append('%sImpl* impl_{nullptr};' % op.name)

        lines.append('};')

    # <tuple> and <vector> are included explicitly because the declarations
    # above use std::tuple / std::vector directly.
    with open(output_dir + '/gen_xcvm_ops.h', 'w') as f:
        f.write(r'''// Auto-generated by gen_xcvm.py

#pragma once

#include <memory>
#include <string>
#include <tuple>
#include <vector>

#include <chainerx/stack_vector.h>

#include <runtime/xcvm_op.h>
#include <runtime/xcvm_state.h>
#include <runtime/xcvm.pb.h>

namespace chainer_compiler {
namespace runtime {

''')
        f.writelines(codegen_util.format_code(lines))
        f.write(r'''
}  // namespace runtime
}  // namespace chainer_compiler
''')
def _xcvm_ops_cc_ctor(op):
    """Emit `<Name>Op::<Name>Op(const XCInstructionProto&)`.

    The constructor CHECKs each input's proto type, copies input fields
    and output register ids into members, and calls InitImpl() for ops
    with custom fields.
    """
    lines = ['%sOp::%sOp(const XCInstructionProto& inst): XCVMOp(inst) {' %
             (op.name, op.name)]
    for i, inp in enumerate(op.inputs):
        enum = inp.typ.replace('OPTIONAL_', '')
        lines.append('CHECK_EQ(XCValueProto::%s, inst.inputs(%d).type()) '
                     '<< "Unexpected type for input#%d of %s";' %
                     (enum, i, i, op.name))
        pfn = inp.proto_field_name()
        name = inp.name
        if not inp.is_repeated():
            lines.append('%s = inst.inputs(%d).%s();' % (name, i, pfn))
        elif inp.typ == INTS:
            lines.append('%s = %s(inst.inputs(%d).ints().begin(), '
                         'inst.inputs(%d).ints().end());' %
                         (name, STACK_VECTOR, i, i))
        else:
            lines.append('%s.assign(inst.inputs(%d).%s().begin(), '
                         'inst.inputs(%d).%s().end());' %
                         (name, i, pfn, i, pfn))
    for i, (typ, name) in enumerate(op.outputs):
        if typ == ARRAY_LIST:
            lines.append('%s.assign(inst.outputs().begin(), '
                         'inst.outputs().end());' % name)
        else:
            lines.append('%s = inst.outputs(%d);' % (name, i))
    if op.has_custom_field:
        lines.append('InitImpl();')
    lines.append('}')
    return lines


def _xcvm_ops_cc_call_trace(op):
    """Return the trace statement `outs = Name(ins)` printed before running."""
    line = 'if (st->trace_level()) std::cerr'
    if op.outputs:
        for typ, name in op.outputs:
            if typ == ARRAY_LIST:
                line += ' << ArrayListToString(%s)' % name
            else:
                line += ' << "%s" << %s' % (sigil(typ), name)
        line += ' << " = "'
    line += ' << "%s("' % op.name
    for i, (typ, name) in enumerate(op.inputs):
        if i:
            line += ' << ", "'
        if typ in [ARRAY, OPTIONAL_ARRAY, SEQUENCE, OPAQUE]:
            line += ' << "%s" << %s' % (sigil(typ), name)
        elif typ in (INT, FLOAT):
            line += ' << %s' % name
        elif typ in [STRING, LONGS, DOUBLES]:
            line += ' << "%s"' % name
        elif typ == INTS:
            line += ' << StackVectorToString(%s)' % name
        elif typ == ARRAY_LIST:
            line += ' << ArrayListToString(%s)' % name
        else:
            raise RuntimeError('Unknown type: %s' % typ)
    line += ' << ")"'
    line += ' << std::endl;'
    return line


def _xcvm_ops_cc_input_trace(op):
    """Return the trace statement dumping input values, or None if empty."""
    prefix = 'if (st->trace_level()) std::cerr'
    line = prefix
    for typ, name in op.inputs:
        if typ in [ARRAY, OPTIONAL_ARRAY, SEQUENCE]:
            line += ' << " %s" << %s << "="' % (sigil(typ), name)
            line += ' << st->GetVarString(%s)' % name
        elif typ == ARRAY_LIST:
            line += ' << st->GetVarListString(%s)' % name
    if op.outputs:
        line += ' << " ->"'
    if line == prefix:
        # Nothing to print; emitting a bare `std::cerr` would not compile.
        return None
    return line + ';'


def _xcvm_ops_cc_typed_call(op):
    """Emit the RunImpl call for a typed op and store its results."""
    lines = []
    args = ['st']

    # TODO(hamaji): Remove this code by removing null gradients.
    conds = ['(%s >= 0 && st->GetVar(%s)->IsNull())' % (name, name)
             for typ, name in op.inputs
             if typ in ARG_TYPES and typ != ARRAY_LIST]
    if conds:
        lines.append('if (%s) {' % ' || '.join(conds))
        lines.append('WARN_ONCE("%s skipped\\n");' % op.name)
        for typ, oname in op.outputs:
            if typ in ARG_TYPES and typ != ARRAY_LIST:
                lines.append('st->SetVar(%s, XCVMVar());' % oname)
        lines.append('return;')
        lines.append('}')

    for typ, name in op.inputs:
        if typ == ARRAY:
            args.append('st->GetArray(%s)' % name)
        elif typ == OPTIONAL_ARRAY:
            args.append('st->GetOptionalArray(%s)' % name)
        elif typ == ARRAY_LIST:
            args.append('st->GetArrayList(%s)' % name)
        elif typ == SEQUENCE:
            args.append('*st->GetSequence(%s)' % name)
        elif typ == OPAQUE:
            args.append('st->GetOpaque(%s)' % name)

    # SEQUENCE outputs are passed as out-parameters; the rest are returned.
    outputs = []
    for output in op.outputs:
        typ, name = output
        if typ == SEQUENCE:
            args.append('st->CreateSequence(%s)' % name)
        else:
            outputs.append(output)

    call = 'RunImpl(%s)' % ', '.join(args)
    if len(outputs) == 1:
        typ, name = outputs[0]
        if typ == ARRAY_LIST:
            lines.append('st->SetArrayList(%s, %s);' % (name, call))
        elif typ == OPAQUE:
            lines.append('st->SetOpaque(%s, %s);' % (name, call))
        else:
            lines.append('st->SetArray(%s, %s);' % (name, call))
    elif outputs:
        lines.append('auto r_ = ' + call + ';')
        for i, (typ, output) in enumerate(outputs):
            # TODO(hamaji): Revisit optional outputs.
            if typ == OPAQUE:
                lines.append(
                    'if (%s >= 0) st->SetOpaque(%s, std::get<%d>(r_));' %
                    (output, output, i))
                # Unused opaque results are owned here and must be freed.
                lines.append('else delete std::get<%d>(r_);' % i)
            else:
                lines.append(
                    'if (%s >= 0) st->SetArray(%s, std::get<%d>(r_));' %
                    (output, output, i))
    else:
        lines.append(call + ';')
    return lines


def _xcvm_ops_cc_output_trace(op):
    """Return the trace statement dumping output values after running."""
    line = 'if (st->trace_level()) std::cerr'
    for typ, name in op.outputs:
        if typ in [ARRAY, OPTIONAL_ARRAY, SEQUENCE, OPAQUE]:
            line += ' << " %s" << %s << "="' % (sigil(typ), name)
            line += ' << st->GetVarString(%s)' % name
        elif typ == ARRAY_LIST:
            line += ' << st->GetVarListString(%s)' % name
        else:
            raise RuntimeError('Unknown output type: %s' % typ)
    line += ' << std::endl;'
    return line


def _xcvm_ops_cc_run(op):
    """Emit `<Name>Op::Run(XCVMState*)` with tracing and inf/nan checks."""
    lines = ['void %sOp::Run(XCVMState* st) {' % op.name]
    lines.append('if (st->trace_level() && !debug_info().empty()) '
                 'std::cerr << "# " << debug_info() << std::endl;')
    lines.append(_xcvm_ops_cc_call_trace(op))
    input_trace = _xcvm_ops_cc_input_trace(op)
    if input_trace is not None:
        lines.append(input_trace)

    if op.typed:
        lines.extend(_xcvm_ops_cc_typed_call(op))
    else:
        lines.append('RunImpl(st);')

    lines.append(_xcvm_ops_cc_output_trace(op))

    if op.outputs:
        inputs_str = ', '.join(name for typ, name in op.inputs
                               if typ == ARRAY or typ == OPTIONAL_ARRAY)
        outputs_str = ', '.join(op.output_names)
        lines.append('if (st->check_infs()) st->CheckInfs({%s}, {%s});' %
                     (inputs_str, outputs_str))
        lines.append('if (st->check_nans()) st->CheckNans({%s}, {%s});' %
                     (inputs_str, outputs_str))
    lines.append('}')
    return lines


def _xcvm_ops_cc_factory():
    """Emit MakeXCVMOp, which instantiates the op class for an instruction."""
    lines = ['XCVMOp* MakeXCVMOp(const XCInstructionProto& inst) {',
             'switch (inst.op()) {']
    for op in XC_ALL_OPS:
        lines.append('case XCInstructionProto::%s:' % op.name)
        lines.append('return new %sOp(inst);' % op.name)
    lines.append('default:')
    lines.append('CHECK(false) << "Unknown op: " '
                 '<< static_cast<int>(inst.op());')
    lines.append('}')
    lines.append('}')
    return lines


def gen_gen_xcvm_ops_cc():
    """Write gen_xcvm_ops.cc into output_dir.

    For every op in XC_ALL_OPS this emits its constructor and Run()
    method, then the MakeXCVMOp factory.  The file prelude defines the
    StackVectorToString / ArrayListToString helpers used by trace output.

    Raises:
        RuntimeError: if an op has an input or output of unknown type.
    """
    lines = []
    for op in XC_ALL_OPS:
        lines.extend(_xcvm_ops_cc_ctor(op))
        lines.extend(_xcvm_ops_cc_run(op))
    lines.extend(_xcvm_ops_cc_factory())

    # The helpers always print balanced parentheses (`()` when empty) and
    # iterate StackVector<int64_t> as int64_t so values are not truncated.
    with open(output_dir + '/gen_xcvm_ops.cc', 'w') as f:
        f.write(r'''// Auto-generated by gen_xcvm.py

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>

#include <common/log.h>
#include <runtime/gen_xcvm_ops.h>

namespace chainer_compiler {
namespace runtime {

std::string StackVectorToString(const chainerx::StackVector<int64_t, chainerx::kMaxNdim>& s) {
    std::ostringstream oss;
    oss << '(';
    bool first = true;
    for (int64_t v : s) {
        if (!first) oss << ',';
        first = false;
        oss << v;
    }
    oss << ')';
    return oss.str();
}

std::string ArrayListToString(const std::vector<int>& s) {
    std::ostringstream oss;
    oss << '(';
    bool first = true;
    for (int v : s) {
        if (!first) oss << ',';
        first = false;
        oss << '$' << v;
    }
    oss << ')';
    return oss.str();
}

''')
        f.writelines(codegen_util.format_code(lines))
        f.write(r'''
}  // namespace runtime
}  // namespace chainer_compiler
''')