def FeedValueToEagerBlob(blob_object, blob_def, ndarray):
    """Feed ``ndarray`` into every physical blob object behind ``blob_object``.

    The physical blob objects are obtained with ``_GetPhysicalBlobObjects``.
    A single ``FeedContext`` (built from the blob's ``op_arg_parallel_attr``
    and ``ndarray``) is shared by all of them. Its rank is set to the index
    of each physical blob before that blob is fed. When all blobs are fed,
    the blob cache of ``blob_object`` is disabled.

    Args:
        blob_object: Logical eager blob object to write into.
        blob_def: Blob definition forwarded to ``_FeedValueToInputPhysicalBlob``.
        ndarray: Host value to feed (presumably a numpy array -- confirm with callers).
    """
    physical_objects = _GetPhysicalBlobObjects(blob_object, None)
    context = FeedContext(blob_object.op_arg_parallel_attr, ndarray)
    for rank, physical_object in enumerate(physical_objects):
        # The context must know which rank it is currently serving.
        context.set_rank(rank)
        _FeedValueToInputPhysicalBlob(context, blob_def, physical_object)
    # NOTE(review): presumably the cache is dropped because the contents were
    # just overwritten -- confirm against TryDisableBlobCache's semantics.
    oneflow_api.TryDisableBlobCache(blob_object)
def _LogicalSliceAssign(
    ref_blob: EagerBlobTrait,
    value_blob: EagerBlobTrait,
    start: Sequence[int],
    stop: Sequence[int],
) -> None:
    """Write ``value_blob`` into a slice of ``ref_blob`` with a ``logical_slice_assign`` op.

    An operator conf for ``logical_slice_assign`` is built and placed on
    ``ref_blob``'s parallel configuration. Its attribute is inferred in
    ``ref_blob``'s scope, and it is run through
    ``oneflow_api.deprecated.LogicalRun``. The step along every axis is
    fixed to 1. After the run, the blob cache of ``ref_blob``'s blob object
    is disabled.

    Args:
        ref_blob: Destination eager blob; bound to the op's ``ref`` input.
        value_blob: Source eager blob; bound to the op's ``value`` input.
        start: Per-axis slice start indices.
        stop: Per-axis slice stop indices (presumably exclusive, as in Python
            slicing -- TODO confirm against the op definition).
    """
    target = ref_blob.blob_object
    source = value_blob.blob_object

    def _emit_slice_assign(builder):
        op_conf = op_conf_pb.OperatorConf()
        # The device tag doesn't matter for the logical_slice_assign op, so
        # the current scope's tag is used as-is.
        op_conf.device_tag = (
            oneflow.current_scope().device_parallel_desc_symbol.device_tag
        )
        op_name = id_util.UniqueStr(OP_PREFIX)
        op_conf.name = op_name

        user_conf = op_conf.user_conf
        user_conf.op_type_name = "logical_slice_assign"
        user_conf.input["value"].s.append("{}/value_0".format(op_name))
        user_conf.input["ref"].s.append("{}/ref_0".format(op_name))

        parallel_conf = target.parallel_desc_symbol.parallel_conf
        attrs = user_conf.attr
        attrs["parallel_conf"].at_string = str(parallel_conf)
        attrs["start"].at_list_int64.val[:] = start
        attrs["stop"].at_list_int64.val[:] = stop
        attrs["step"].at_list_int64.val[:] = [1] * len(start)

        blob_objects = oneflow_api.deprecated.BnInOp2BlobObject()
        blob_objects["ref_0"] = target
        blob_objects["value_0"] = source

        op_attribute = op_infer_util.Infer(
            op_conf, blob_objects, _GetScopeSymbolIdFromEagerBlob(ref_blob)
        )
        builder.StatelessCall(
            oneflow_api.deprecated.MakeOpAttributeByString(str(op_attribute)),
            parallel_conf,
            blob_objects,
            boxing_util.BoxingTo,
        )

    oneflow_api.deprecated.LogicalRun(_emit_slice_assign)
    oneflow_api.TryDisableBlobCache(target)
def BuildAssignInstruction(builder, ref_blob_object, value_blob_object, op_conf):
    """Emit a no-boxing instruction running ``op_conf`` to assign value into ref.

    The op attribute is inferred from ``op_conf``, with the blob objects bound
    to the ``"ref"`` and ``"value"`` bns. The builder call is chosen from the
    two blobs' device tags:

    * equal tags: ``NoBoxingStatelessCall`` on the ref placement;
    * ref on ``"cpu"``, value on ``"gpu"``: ``NoBoxingCudaD2HStatelessCall`` on
      the value placement, with ``TryReplaceDeviceTag`` passed along;
    * ref on ``"gpu"``, value on ``"cpu"``: ``NoBoxingCudaH2DStatelessCall`` on
      the ref placement, issued inside ``_CudaHostPinBlob`` for the value blob.

    The ref blob's cache is disabled before anything else is done.

    Args:
        builder: Instruction builder used to issue the call.
        ref_blob_object: Blob object bound to the ``"ref"`` bn (destination).
        value_blob_object: Blob object bound to the ``"value"`` bn (source).
        op_conf: Operator conf of the assign op (presumably declares ``ref``
            and ``value`` inputs -- confirm with callers).

    Raises:
        AssertionError: If the two blobs' ``machine_id2device_id_list`` differ.
            Raised explicitly instead of via ``assert`` so the placement check
            is not stripped when Python runs with ``-O``.
        NotImplementedError: For any other device-tag combination.
    """
    oneflow_api.TryDisableBlobCache(ref_blob_object)
    ref_parallel_desc = ref_blob_object.parallel_desc_symbol
    value_parallel_desc = value_blob_object.parallel_desc_symbol
    ref_parallel_conf = ref_parallel_desc.parallel_conf
    ref_devices = ref_parallel_desc.machine_id2device_id_list
    value_devices = value_parallel_desc.machine_id2device_id_list
    # Both blobs must live on the same machines/devices; this is validated
    # explicitly so it still holds under `python -O`.
    if ref_devices != value_devices:
        raise AssertionError(
            "\nref_devices: %s\nvalue_devices: %s" % (ref_devices, value_devices)
        )
    ref_device_tag = ref_parallel_desc.device_tag
    value_device_tag = value_parallel_desc.device_tag

    bn_in_op2blob_object = oneflow_api.deprecated.BnInOp2BlobObject()
    bn_in_op2blob_object["ref"] = ref_blob_object
    bn_in_op2blob_object["value"] = value_blob_object
    op_attribute = op_infer_util.Infer(op_conf, bn_in_op2blob_object)
    cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(
        str(op_attribute)
    )

    if ref_device_tag == value_device_tag:
        builder.NoBoxingStatelessCall(
            cfg_op_attribute, ref_parallel_conf, bn_in_op2blob_object,
        )
    elif ref_device_tag == "cpu" and value_device_tag == "gpu":
        # Device-to-host copy: the call is placed on the value (gpu) side.
        builder.NoBoxingCudaD2HStatelessCall(
            cfg_op_attribute,
            value_parallel_desc.parallel_conf,
            bn_in_op2blob_object,
            TryReplaceDeviceTag,
        )
    elif ref_device_tag == "gpu" and value_device_tag == "cpu":
        # Host-to-device copy: the host value blob is pinned for the call.
        with _CudaHostPinBlob(builder, value_blob_object):
            builder.NoBoxingCudaH2DStatelessCall(
                cfg_op_attribute, ref_parallel_conf, bn_in_op2blob_object,
            )
    else:
        raise NotImplementedError(
            "invalid device found. ref_device_tag: %s, value_device_tag: %s"
            % (ref_device_tag, value_device_tag)
        )