def InsertRemoveForeignCallbackInstruction(self, object_id, callback):
    """Append a "RemoveForeignCallback" instruction to this builder.

    Args:
        object_id: id of the target object; encoded as a ``DelObjectOperand``.
        callback: Python callback; only the integer id returned by
            ``python_callback.GetIdForRegisteredCallback(callback)`` is
            encoded, as an ``Int64Operand``.
    """
    callback_id = python_callback.GetIdForRegisteredCallback(callback)
    instr = instr_cfg.InstructionProto()
    instr.set_instr_type_name("RemoveForeignCallback")
    vm = oneflow_api.deprecated.vm
    # Operand order matters: the object operand precedes the callback id.
    instr.mutable_operand().Add().CopyFrom(vm.DelObjectOperand(object_id))
    instr.mutable_operand().Add().CopyFrom(vm.Int64Operand(callback_id))
    self.instruction_list().mutable_instruction().Add().CopyFrom(instr)
def _InitOpConfSymbol(self, symbol_id, op_conf):
    """Emit an "InitOperatorConfSymbol" instruction and register its eager symbol.

    Args:
        symbol_id: id the operator-conf symbol is initialised under; used
            both in the instruction's ``InitSymbolOperand`` and as the eager
            symbol's ``symbol_id``.
        op_conf: operator conf message copied into the eager symbol's
            ``op_conf_symbol`` field.
    """
    init_instr = instr_cfg.InstructionProto()
    init_instr.set_instr_type_name("InitOperatorConfSymbol")
    init_operand = oneflow_api.deprecated.vm.InitSymbolOperand(symbol_id)
    init_instr.mutable_operand().Add().CopyFrom(init_operand)
    self.instruction_list().mutable_instruction().Add().CopyFrom(init_instr)

    symbol_pb = eager_symbol_pb.EagerSymbol()
    symbol_pb.symbol_id = symbol_id
    symbol_pb.op_conf_symbol.CopyFrom(op_conf)
    # str() on the protobuf message yields its text serialization, which
    # MakeEagerSymbolByString converts (presumably by parsing) into the
    # object type that eager_symbol_list() stores.
    symbol_cfg = oneflow_api.deprecated.MakeEagerSymbolByString(str(symbol_pb))
    self.eager_symbol_list().mutable_eager_symbol().Add().CopyFrom(symbol_cfg)
def FeedBlob(self, blob_object, feeder):
    """Append a device-qualified "FeedBlob" instruction for ``blob_object``.

    The instruction type name is ``"<device_tag>.FeedBlob"``, where the device
    tag and the bound parallel-desc symbol id both come from
    ``blob_object.parallel_desc_symbol``.

    Args:
        blob_object: blob to feed; encoded as a ``Mut2Operand`` on its
            ``object_id``.
        feeder: Python callback; only the integer id returned by
            ``python_callback.GetIdForRegisteredCallback(feeder)`` is
            encoded, as an ``Int64Operand``.
    """
    feeder_id = python_callback.GetIdForRegisteredCallback(feeder)
    instr = instr_cfg.InstructionProto()
    parallel_desc = blob_object.parallel_desc_symbol
    instr.set_instr_type_name("%s.%s" % (parallel_desc.device_tag, "FeedBlob"))
    instr.set_parallel_desc_symbol_id(parallel_desc.symbol_id)
    vm = oneflow_api.deprecated.vm
    instr.mutable_operand().Add().CopyFrom(vm.Mut2Operand(blob_object.object_id))
    instr.mutable_operand().Add().CopyFrom(vm.Int64Operand(feeder_id))
    self.instruction_list().mutable_instruction().Add().CopyFrom(instr)
def _NewOpKernelObject(self, parallel_desc_symbol, job_desc_sym, op_conf_sym):
    """Allocate an object id and emit an "InitOpKernelObject" instruction for it.

    Args:
        parallel_desc_symbol: passed to ``self.NewObjectId``; its
            ``symbol_id`` is bound to the instruction.
        job_desc_sym: symbol whose ``symbol_id`` is the first operand.
        op_conf_sym: symbol whose ``symbol_id`` is the second operand.

    Returns:
        The object id obtained from ``self.NewObjectId(parallel_desc_symbol)``,
        which is also the instruction's trailing ``MutOperand``.
    """
    new_object_id = self.NewObjectId(parallel_desc_symbol)
    instr = instr_cfg.InstructionProto()
    instr.set_instr_type_name("InitOpKernelObject")
    instr.set_parallel_desc_symbol_id(parallel_desc_symbol.symbol_id)
    vm = oneflow_api.deprecated.vm
    # Order: job desc symbol, op conf symbol, then the new (mutable) object.
    instr.mutable_operand().Add().CopyFrom(vm.SymbolOperand(job_desc_sym.symbol_id))
    instr.mutable_operand().Add().CopyFrom(vm.SymbolOperand(op_conf_sym.symbol_id))
    instr.mutable_operand().Add().CopyFrom(vm.MutOperand(new_object_id))
    self.instruction_list().mutable_instruction().Add().CopyFrom(instr)
    return new_object_id
def _StatefulCallOpKernel(
    self,
    instr_name,
    parallel_desc_sym,
    opkernel_object,
    op_node_signature_sym,
    const_input_operand_blob_objects,
    mutable_input_operand_blob_objects,
    mut1_operand_blob_objects,
    mut2_operand_blob_objects,
):
    """Append a "<device_tag>.<instr_name>" stateful op-kernel call instruction.

    The instruction is bound to ``parallel_desc_sym.symbol_id``.

    Operand layout, in order:
      * ``MutOperand`` on ``opkernel_object.object_id``;
      * ``SymbolOperand`` on ``op_node_signature_sym.symbol_id``;
      * four groups, each preceded by an ``OperandSeparator``. Each group
        contains every blob-name ``SymbolOperand`` first, then every blob
        operand. Blob operand kinds per group are:
          const inputs     -> ``ConstOperand``
          mutable inputs   -> ``MutOperand``
          mut1 group       -> ``MutOperand``
          mut2 group       -> ``Mut2Operand``

    Each ``*_operand_blob_objects`` argument must yield
    ``(name_symbol, blob_object)`` pairs and is iterated twice.
    NOTE(review): a one-shot iterator would be exhausted by the symbol pass,
    so callers should presumably pass lists/tuples.
    """
    vm = oneflow_api.deprecated.vm
    instruction = instr_cfg.InstructionProto()

    def push(operand):
        # Append one operand to the instruction being built.
        instruction.mutable_operand().Add().CopyFrom(operand)

    instruction.set_instr_type_name(
        "%s.%s" % (parallel_desc_sym.device_tag, instr_name,)
    )
    instruction.set_parallel_desc_symbol_id(parallel_desc_sym.symbol_id)
    push(vm.MutOperand(opkernel_object.object_id))
    push(vm.SymbolOperand(op_node_signature_sym.symbol_id))

    operand_groups = (
        (const_input_operand_blob_objects, vm.ConstOperand),
        (mutable_input_operand_blob_objects, vm.MutOperand),
        (mut1_operand_blob_objects, vm.MutOperand),
        (mut2_operand_blob_objects, vm.Mut2Operand),
    )
    for pairs, make_blob_operand in operand_groups:
        push(vm.OperandSeparator())
        # Two passes: all name symbols first, then all blob objects.
        for name_sym, _ in pairs:
            push(vm.SymbolOperand(name_sym.symbol_id))
        for _, blob_object in pairs:
            push(make_blob_operand(blob_object.object_id))

    self.instruction_list().mutable_instruction().Add().CopyFrom(instruction)