Exemplo n.º 1
0
def InsertRemoveForeignCallbackInstruction(self, object_id, callback):
    """Append a "RemoveForeignCallback" instruction to this builder's list.

    The callback's registered integer id is obtained from
    ``python_callback.GetIdForRegisteredCallback``. The emitted instruction
    carries two operands, in this order:
      1. ``DelObjectOperand(object_id)``
      2. ``Int64Operand(<registered callback id>)``

    Args:
        object_id: id passed to ``DelObjectOperand``.
        callback: Python callback whose registered id is encoded as the
            second operand.
    """
    registered_id = python_callback.GetIdForRegisteredCallback(callback)
    vm_ops = oneflow_api.deprecated.vm

    proto = instr_cfg.InstructionProto()
    proto.set_instr_type_name("RemoveForeignCallback")
    for operand in (
        vm_ops.DelObjectOperand(object_id),
        vm_ops.Int64Operand(registered_id),
    ):
        proto.mutable_operand().Add().CopyFrom(operand)

    self.instruction_list().mutable_instruction().Add().CopyFrom(proto)
Exemplo n.º 2
0
def _InitOpConfSymbol(self, symbol_id, op_conf):
    """Register an operator-conf symbol with the VM.

    Two records are appended:
      * to ``self.instruction_list()``: an "InitOperatorConfSymbol"
        instruction whose single operand is ``InitSymbolOperand(symbol_id)``;
      * to ``self.eager_symbol_list()``: an eager symbol carrying
        ``symbol_id`` and a copy of ``op_conf`` in ``op_conf_symbol``.

    The eager symbol is first built as an ``eager_symbol_pb.EagerSymbol``
    message. It is then converted by passing its string form
    (``str(message)``) to ``oneflow_api.deprecated.MakeEagerSymbolByString``.

    Args:
        symbol_id: id of the symbol being initialized.
        op_conf: operator conf message copied into the eager symbol.
    """
    init_operand = oneflow_api.deprecated.vm.InitSymbolOperand(symbol_id)
    init_instr = instr_cfg.InstructionProto()
    init_instr.set_instr_type_name("InitOperatorConfSymbol")
    init_instr.mutable_operand().Add().CopyFrom(init_operand)
    self.instruction_list().mutable_instruction().Add().CopyFrom(init_instr)

    # Build the protobuf message, then hand its text form to the API that
    # produces the object stored in eager_symbol_list.
    symbol_msg = eager_symbol_pb.EagerSymbol()
    symbol_msg.symbol_id = symbol_id
    symbol_msg.op_conf_symbol.CopyFrom(op_conf)
    converted_symbol = oneflow_api.deprecated.MakeEagerSymbolByString(
        str(symbol_msg)
    )
    self.eager_symbol_list().mutable_eager_symbol().Add().CopyFrom(
        converted_symbol
    )
Exemplo n.º 3
0
def FeedBlob(self, blob_object, feeder):
    """Append a device-specific "<device_tag>.FeedBlob" instruction.

    The device tag and parallel-desc symbol id both come from
    ``blob_object.parallel_desc_symbol``. The instruction's operands are, in
    order:
      1. ``Mut2Operand(blob_object.object_id)``
      2. ``Int64Operand(<registered id of feeder>)``

    Args:
        blob_object: blob object to feed; its ``parallel_desc_symbol`` and
            ``object_id`` attributes are read.
        feeder: Python callback registered via
            ``python_callback.GetIdForRegisteredCallback``.
    """
    feeder_id = python_callback.GetIdForRegisteredCallback(feeder)
    parallel_desc = blob_object.parallel_desc_symbol
    vm_ops = oneflow_api.deprecated.vm

    feed_instr = instr_cfg.InstructionProto()
    feed_instr.set_instr_type_name("%s.%s" % (parallel_desc.device_tag, "FeedBlob"))
    feed_instr.set_parallel_desc_symbol_id(parallel_desc.symbol_id)
    for operand in (
        vm_ops.Mut2Operand(blob_object.object_id),
        vm_ops.Int64Operand(feeder_id),
    ):
        feed_instr.mutable_operand().Add().CopyFrom(operand)

    self.instruction_list().mutable_instruction().Add().CopyFrom(feed_instr)
Exemplo n.º 4
0
def _NewOpKernelObject(self, parallel_desc_symbol, job_desc_sym, op_conf_sym):
    """Allocate an object id and append an "InitOpKernelObject" instruction.

    The instruction is tagged with ``parallel_desc_symbol.symbol_id``. Its
    operands are, in order:
      1. ``SymbolOperand(job_desc_sym.symbol_id)``
      2. ``SymbolOperand(op_conf_sym.symbol_id)``
      3. ``MutOperand(<newly allocated object id>)``

    Returns:
        The id returned by ``self.NewObjectId(parallel_desc_symbol)``.
    """
    kernel_object_id = self.NewObjectId(parallel_desc_symbol)
    vm_ops = oneflow_api.deprecated.vm

    init_instr = instr_cfg.InstructionProto()
    init_instr.set_instr_type_name("InitOpKernelObject")
    init_instr.set_parallel_desc_symbol_id(parallel_desc_symbol.symbol_id)
    operand_sequence = (
        vm_ops.SymbolOperand(job_desc_sym.symbol_id),
        vm_ops.SymbolOperand(op_conf_sym.symbol_id),
        vm_ops.MutOperand(kernel_object_id),
    )
    for operand in operand_sequence:
        init_instr.mutable_operand().Add().CopyFrom(operand)

    self.instruction_list().mutable_instruction().Add().CopyFrom(init_instr)
    return kernel_object_id
Exemplo n.º 5
0
def _StatefulCallOpKernel(
    self,
    instr_name,
    parallel_desc_sym,
    opkernel_object,
    op_node_signature_sym,
    const_input_operand_blob_objects,
    mutable_input_operand_blob_objects,
    mut1_operand_blob_objects,
    mut2_operand_blob_objects,
):
    """Append a "<device_tag>.<instr_name>" stateful op-kernel call instruction.

    Operand layout, in order:
      * ``MutOperand(opkernel_object.object_id)``
      * ``SymbolOperand(op_node_signature_sym.symbol_id)``
      * then, for each of the four operand groups (const inputs, mutable
        inputs, mut1 outputs, mut2 outputs), an ``OperandSeparator()``
        followed by:
          - one ``SymbolOperand`` per name symbol in the group, then
          - one blob operand per blob object in the group, built with
            ``ConstOperand`` / ``MutOperand`` / ``MutOperand`` /
            ``Mut2Operand`` respectively.

    Each ``*_blob_objects`` argument is unpacked as ``(name_symbol,
    blob_object)`` pairs and iterated twice (once for symbols, once for blob
    objects), so it should be a re-iterable sequence rather than a one-shot
    iterator.
    """
    vm_ops = oneflow_api.deprecated.vm
    instruction = instr_cfg.InstructionProto()

    def push(operand):
        # Append a copy of `operand` to the instruction's operand list.
        instruction.mutable_operand().Add().CopyFrom(operand)

    instruction.set_instr_type_name("%s.%s" % (parallel_desc_sym.device_tag, instr_name))
    instruction.set_parallel_desc_symbol_id(parallel_desc_sym.symbol_id)
    push(vm_ops.MutOperand(opkernel_object.object_id))
    push(vm_ops.SymbolOperand(op_node_signature_sym.symbol_id))

    # (pairs, constructor used for the blob-object operands of that group)
    operand_groups = (
        (const_input_operand_blob_objects, vm_ops.ConstOperand),
        (mutable_input_operand_blob_objects, vm_ops.MutOperand),
        (mut1_operand_blob_objects, vm_ops.MutOperand),
        (mut2_operand_blob_objects, vm_ops.Mut2Operand),
    )
    for pairs, make_blob_operand in operand_groups:
        push(vm_ops.OperandSeparator())
        for name_sym, _ in pairs:
            push(vm_ops.SymbolOperand(name_sym.symbol_id))
        for _, blob_object in pairs:
            push(make_blob_operand(blob_object.object_id))

    self.instruction_list().mutable_instruction().Add().CopyFrom(instruction)