Esempio n. 1
0
def _MakeEagerLogicalBlob(op_attribute, obn, blob_register):
    """Wrap the blob object behind output name ``obn`` in an eager blob.

    Args:
        op_attribute: Op attribute whose ``arg_signature.bn_in_op2lbi`` and
            ``mirrored_signature.bn_in_op2opt_mirrored_parallel`` maps are
            indexed by ``obn``.
        obn: Blob name (in op) used as the key into both maps.
        blob_register: Register queried with ``"<op_name>/<blob_name>"`` to
            obtain the blob object.

    Returns:
        ``oneflow_api.EagerMirroredBlob`` when the mirrored signature entry
        for ``obn`` has the ``mirrored_parallel`` field set, otherwise
        ``oneflow_api.EagerConsistentBlob``.
    """
    logical_blob_id = op_attribute.arg_signature.bn_in_op2lbi[obn]
    logical_blob_name = "%s/%s" % (
        logical_blob_id.op_name,
        logical_blob_id.blob_name,
    )
    blob_object = blob_register.GetObject4BlobName(logical_blob_name)
    opt_mirrored_parallel_map = (
        op_attribute.mirrored_signature.bn_in_op2opt_mirrored_parallel
    )

    # The blob constructors are handed an lbi_util.LogicalBlobId; copy the
    # two name fields over when the signature gave us some other lbi type.
    if not isinstance(logical_blob_id, lbi_util.LogicalBlobId):
        converted = lbi_util.LogicalBlobId()
        converted.set_op_name(logical_blob_id.op_name)
        converted.set_blob_name(logical_blob_id.blob_name)
        logical_blob_id = converted

    if opt_mirrored_parallel_map[obn].HasField("mirrored_parallel"):
        blob_cls = oneflow_api.EagerMirroredBlob
    else:
        blob_cls = oneflow_api.EagerConsistentBlob
    # NOTE(review): the blob is built with the module-level
    # default_blob_register, not the blob_register argument — confirm this
    # is intentional.
    return blob_cls(logical_blob_id, blob_object, default_blob_register)
        def Build(builder, Yield):
            """Build the eager blob for interface op ``op_name`` and yield it.

            Reads ``op_name``, ``sess``, ``blob_register`` and ``job_name``
            from the enclosing scope.

            Args:
                builder: Passed to ``_GetInterfaceBlobObject`` to obtain the
                    blob object.
                Yield: Callback invoked once with the constructed blob.
            """
            interface_blob_object = _GetInterfaceBlobObject(builder, op_name)

            logical_blob_id = lbi_util.LogicalBlobId()
            logical_blob_id.set_op_name(op_name)
            interface_op_attribute = sess.OpAttribute4InterfaceOpName(op_name)
            # An interface op is expected to expose exactly one output.
            assert len(interface_op_attribute.output_bns) == 1
            logical_blob_id.set_blob_name(interface_op_attribute.output_bns[0])

            blob_cls = (
                oneflow_api.EagerMirroredBlob
                if interface_blob_object.op_arg_parallel_attr.is_mirrored()
                else oneflow_api.EagerConsistentBlob
            )
            eager_blob = blob_cls(
                logical_blob_id, interface_blob_object, blob_register, job_name
            )
            Yield(eager_blob)
Esempio n. 3
0
def _Watch(op_attribute, parallel_conf, blob_register):
    """Run the registered watch handler on the op's ``"in"`` blob.

    The handler is looked up in the default session's
    ``uuid2watch_handler`` map by the uuid stored in the op's
    ``foreign_watch_conf``, called with the eager blob, and then removed
    from the map.

    Args:
        op_attribute: Op attribute providing the ``"in"`` logical blob id and
            ``op_conf.foreign_watch_conf.handler_uuid``.
        parallel_conf: Not used by this function.
        blob_register: Register queried with ``"<op_name>/<blob_name>"`` to
            obtain the input blob object.
    """
    in_lbi = op_attribute.arg_signature.bn_in_op2lbi["in"]
    handler_uuid = op_attribute.op_conf.foreign_watch_conf.handler_uuid
    watched_blob_object = blob_register.GetObject4BlobName(
        "%s/%s" % (in_lbi.op_name, in_lbi.blob_name)
    )

    # The blob constructors are handed an lbi_util.LogicalBlobId; copy the
    # two name fields over when the signature gave us some other lbi type.
    if not isinstance(in_lbi, lbi_util.LogicalBlobId):
        converted = lbi_util.LogicalBlobId()
        converted.set_op_name(in_lbi.op_name)
        converted.set_blob_name(in_lbi.blob_name)
        in_lbi = converted

    if watched_blob_object.op_arg_parallel_attr.is_mirrored():
        blob_cls = oneflow_api.EagerMirroredBlob
    else:
        blob_cls = oneflow_api.EagerConsistentBlob
    watched_blob = blob_cls(in_lbi, watched_blob_object, default_blob_register)

    handlers = session_ctx.GetDefaultSession().uuid2watch_handler
    assert handler_uuid in handlers
    # Call first, delete second: if the handler raises, its entry stays in
    # the map.
    handler = handlers[handler_uuid]
    handler(watched_blob)
    del handlers[handler_uuid]
 def build(builder):
     """Fetch interface op ``op_name``'s eager blob and yield its numpy value.

     Reads ``op_name``, ``sess``, ``blob_register``, ``job_name`` and
     ``Yield`` from the enclosing scope.

     Args:
         builder: Not used by this function body.
     """
     blob_object = GetEagerInterfaceBlob(op_name).blob_object
     lbi = lbi_util.LogicalBlobId()
     lbi.set_op_name(op_name)
     op_attribute = sess.OpAttribute4InterfaceOpName(op_name)
     # An interface op is expected to expose exactly one output.
     assert len(op_attribute.output_bns) == 1
     lbi.set_blob_name(op_attribute.output_bns[0])
     # lbi is constructed above as lbi_util.LogicalBlobId, so the former
     # "convert if not a LogicalBlobId" branch could never run and was
     # removed.
     if blob_object.op_arg_parallel_attr.is_mirrored():
         remote_blob = oneflow_api.EagerMirroredBlob(
             lbi, blob_object, blob_register, job_name)
     else:
         remote_blob = oneflow_api.EagerConsistentBlob(
             lbi, blob_object, blob_register, job_name)
     Yield(remote_blob.numpy())