def _MakeEagerLogicalBlob(op_attribute, obn, blob_register):
    """Wrap the blob bound to ``obn`` in ``op_attribute`` as an eager logical blob.

    Args:
        op_attribute: op attribute whose ``arg_signature`` maps ``obn`` to a
            logical blob id and whose ``mirrored_signature`` records whether
            that blob is mirrored.
        obn: blob-name-in-op key looked up in both signature maps
            (presumably an output blob name -- confirm against callers).
        blob_register: register from which the blob object is fetched by its
            ``"<op_name>/<blob_name>"`` name.

    Returns:
        ``oneflow_api.EagerMirroredBlob`` when the mirrored-signature entry for
        ``obn`` has a ``mirrored_parallel`` field, otherwise
        ``oneflow_api.EagerConsistentBlob``.
    """
    signature_lbi = op_attribute.arg_signature.bn_in_op2lbi[obn]
    blob_name = "%s/%s" % (signature_lbi.op_name, signature_lbi.blob_name)
    blob_object = blob_register.GetObject4BlobName(blob_name)
    mirrored_sig_map = op_attribute.mirrored_signature.bn_in_op2opt_mirrored_parallel

    # Normalize the id to lbi_util.LogicalBlobId before handing it to the
    # eager blob constructor (presumably the type it requires); any other
    # representation is copied field by field.
    if isinstance(signature_lbi, lbi_util.LogicalBlobId):
        cfg_lbi = signature_lbi
    else:
        cfg_lbi = lbi_util.LogicalBlobId()
        cfg_lbi.set_op_name(signature_lbi.op_name)
        cfg_lbi.set_blob_name(signature_lbi.blob_name)

    is_mirrored = mirrored_sig_map[obn].HasField("mirrored_parallel")
    blob_cls = (
        oneflow_api.EagerMirroredBlob if is_mirrored else oneflow_api.EagerConsistentBlob
    )
    # NOTE(review): the blob object is looked up in ``blob_register`` but the
    # eager blob is given ``default_blob_register``; kept as-is -- confirm intent.
    return blob_cls(cfg_lbi, blob_object, default_blob_register)
def Build(builder, Yield):
    """Wrap the interface blob of ``op_name`` as an eager blob and yield it.

    ``op_name``, ``sess``, ``blob_register`` and ``job_name`` are free
    variables bound by the enclosing scope (not visible here).

    Args:
        builder: object forwarded to ``_GetInterfaceBlobObject`` to obtain the
            interface blob object.
        Yield: callback that receives the resulting remote blob.
    """
    blob_object = _GetInterfaceBlobObject(builder, op_name)

    interface_lbi = lbi_util.LogicalBlobId()
    interface_lbi.set_op_name(op_name)
    op_attribute = sess.OpAttribute4InterfaceOpName(op_name)
    # The interface op must expose exactly one output blob; its name
    # completes the logical blob id.
    assert len(op_attribute.output_bns) == 1
    interface_lbi.set_blob_name(op_attribute.output_bns[0])

    if blob_object.op_arg_parallel_attr.is_mirrored():
        make_blob = oneflow_api.EagerMirroredBlob
    else:
        make_blob = oneflow_api.EagerConsistentBlob
    Yield(make_blob(interface_lbi, blob_object, blob_register, job_name))
def _Watch(op_attribute, parallel_conf, blob_register):
    """Hand the watched input blob to its registered watch handler.

    The blob bound to the ``"in"`` key of ``op_attribute`` is wrapped as an
    eager blob and passed to the handler that the default session registered
    under ``op_conf.foreign_watch_conf.handler_uuid``. The handler is then
    removed from the session's handler map.

    Args:
        op_attribute: op attribute whose ``op_conf`` carries a
            ``foreign_watch_conf``.
        parallel_conf: not used by this function.
        blob_register: register from which the input blob object is fetched.
    """
    in_lbi = op_attribute.arg_signature.bn_in_op2lbi["in"]
    handler_uuid = op_attribute.op_conf.foreign_watch_conf.handler_uuid
    lbn = "%s/%s" % (in_lbi.op_name, in_lbi.blob_name)
    in_blob_object = blob_register.GetObject4BlobName(lbn)

    # Normalize the id to lbi_util.LogicalBlobId, copying fields when needed.
    if isinstance(in_lbi, lbi_util.LogicalBlobId):
        cfg_lbi = in_lbi
    else:
        cfg_lbi = lbi_util.LogicalBlobId()
        cfg_lbi.set_op_name(in_lbi.op_name)
        cfg_lbi.set_blob_name(in_lbi.blob_name)

    blob_cls = (
        oneflow_api.EagerMirroredBlob
        if in_blob_object.op_arg_parallel_attr.is_mirrored()
        else oneflow_api.EagerConsistentBlob
    )
    # NOTE(review): default_blob_register (not blob_register) is passed here,
    # matching the original -- confirm intent.
    blob = blob_cls(cfg_lbi, in_blob_object, default_blob_register)

    handlers = session_ctx.GetDefaultSession().uuid2watch_handler
    assert handler_uuid in handlers
    # The handler runs before it is unregistered, so if it raises the
    # entry remains in the map.
    handlers[handler_uuid](blob)
    del handlers[handler_uuid]
def build(builder):
    """Read the value of interface op ``op_name``'s single output blob and yield it.

    ``op_name``, ``sess``, ``blob_register``, ``job_name`` and ``Yield`` are
    free variables bound by the enclosing scope (not visible here).

    Args:
        builder: not referenced by this body; presumably required by the
            signature of whatever invokes this callback -- confirm.
    """
    blob_object = GetEagerInterfaceBlob(op_name).blob_object

    # lbi is created directly as lbi_util.LogicalBlobId, so the
    # isinstance-based conversion used elsewhere is unnecessary here (it was
    # unreachable in the previous version of this function).
    lbi = lbi_util.LogicalBlobId()
    lbi.set_op_name(op_name)
    op_attribute = sess.OpAttribute4InterfaceOpName(op_name)
    # The interface op must expose exactly one output blob.
    assert len(op_attribute.output_bns) == 1
    lbi.set_blob_name(op_attribute.output_bns[0])

    if blob_object.op_arg_parallel_attr.is_mirrored():
        remote_blob = oneflow_api.EagerMirroredBlob(
            lbi, blob_object, blob_register, job_name
        )
    else:
        remote_blob = oneflow_api.EagerConsistentBlob(
            lbi, blob_object, blob_register, job_name
        )
    # Materialize the blob's contents via .numpy() and hand them to the caller.
    value = remote_blob.numpy()
    Yield(value)