示例#1
0
def _MakeEagerLogicalBlob(op_attribute, obn, blob_register):
    """Wrap the registered blob object for op blob ``obn`` in an eager blob.

    Args:
        op_attribute: operator attribute. ``arg_signature.bn_in_op2lbi`` maps
            the op's blob names to logical blob ids, and
            ``mirrored_signature.bn_in_op2opt_mirrored_parallel`` holds the
            per-blob mirrored-parallel option.
        obn: blob name within the op, used as the key into both maps.
        blob_register: registry queried via ``GetObject4BlobName`` with the
            "op_name/blob_name" string of the logical blob id.

    Returns:
        ``remote_blob_util.EagerMirroredBlob`` when the option entry for
        ``obn`` has its ``mirrored_parallel`` field set, otherwise
        ``remote_blob_util.EagerConsistentBlob``. Either is built from the
        logical blob id and the registered blob object.
    """
    logical_blob_id = op_attribute.arg_signature.bn_in_op2lbi[obn]
    registered_name = "%s/%s" % (logical_blob_id.op_name, logical_blob_id.blob_name)
    eager_object = blob_register.GetObject4BlobName(registered_name)

    parallel_options = op_attribute.mirrored_signature.bn_in_op2opt_mirrored_parallel
    is_mirrored = parallel_options[obn].HasField("mirrored_parallel")
    blob_cls = (
        remote_blob_util.EagerMirroredBlob
        if is_mirrored
        else remote_blob_util.EagerConsistentBlob
    )
    return blob_cls(logical_blob_id, eager_object)
示例#2
0
def _Watch(op_attribute, parallel_conf, blob_register):
    """Hand the eager blob bound to the op's "in" input to its watch handler.

    The handler is looked up by ``op_conf.foreign_watch_conf.handler_uuid``
    in the default session's ``uuid2watch_handler`` map. It is called once
    with the blob, and its entry is then deleted from the map.

    Args:
        op_attribute: attribute of the watch op. Its "in" entry in
            ``arg_signature.bn_in_op2lbi`` names the blob being watched.
        parallel_conf: not referenced in this body (presumably kept so the
            signature matches the caller's expectations -- verify).
        blob_register: registry queried via ``GetObject4BlobName`` with the
            "op_name/blob_name" string of the watched blob.
    """
    in_lbi = op_attribute.arg_signature.bn_in_op2lbi["in"]
    handler_uuid = op_attribute.op_conf.foreign_watch_conf.handler_uuid
    watched_object = blob_register.GetObject4BlobName(
        "%s/%s" % (in_lbi.op_name, in_lbi.blob_name)
    )

    # Mirrored vs. consistent is decided by the registered object's own
    # parallel attribute, not by the op's mirrored signature.
    if watched_object.op_arg_parallel_attr.is_mirrored():
        make_blob = remote_blob_util.EagerMirroredBlob
    else:
        make_blob = remote_blob_util.EagerConsistentBlob
    watched_blob = make_blob(in_lbi, watched_object)

    handlers = session_ctx.GetDefaultSession().uuid2watch_handler
    # NOTE(review): `assert` is stripped under `python -O`. A missing uuid
    # then surfaces as a KeyError on the lookup below instead.
    assert handler_uuid in handlers
    handlers[handler_uuid](watched_blob)
    # Deletion happens only after the handler returns. If the handler
    # raises, its entry stays in the map.
    del handlers[handler_uuid]
示例#3
0
        def Build(builder, Yield):
            """Yield interface op ``op_name``'s single output as an eager remote blob.

            The blob object comes from ``_GetInterfaceBlobObject(builder, op_name)``.
            Its logical blob id combines ``op_name`` with the op's only output
            blob name. An ``EagerMirroredBlob`` (if the blob object's parallel
            attribute is mirrored) or an ``EagerConsistentBlob``, tagged with
            ``job_name``, is passed to ``Yield``.
            """
            interface_object = _GetInterfaceBlobObject(builder, op_name)
            interface_attr = sess.OpAttribute4InterfaceOpName(op_name)
            # Interface ops are expected to expose exactly one output blob.
            assert len(interface_attr.output_bns) == 1

            blob_id = logical_blob_id_util.LogicalBlobId()
            blob_id.op_name = op_name
            blob_id.blob_name = interface_attr.output_bns[0]

            blob_type = (
                remote_blob_util.EagerMirroredBlob
                if interface_object.op_arg_parallel_attr.is_mirrored()
                else remote_blob_util.EagerConsistentBlob
            )
            Yield(blob_type(blob_id, interface_object, job_name))
示例#4
0
 def build(builder):
     """Fetch interface op ``op_name``'s single output as a host value.

     The eager blob object from ``GetEagerInterfaceBlob(op_name)`` is wrapped
     in an ``EagerMirroredBlob`` or ``EagerConsistentBlob``, depending on
     whether its parallel attribute is mirrored. Its contents are then
     materialized with ``numpy_list()`` (tensor-list blobs) or ``numpy()``
     and passed to ``Yield``.

     ``builder`` is not referenced in this body. ``Yield`` is not a parameter
     and is resolved from an enclosing scope not shown here.
     """
     eager_object = GetEagerInterfaceBlob(op_name).blob_object
     interface_attr = sess.OpAttribute4InterfaceOpName(op_name)
     # Interface ops are expected to expose exactly one output blob.
     assert len(interface_attr.output_bns) == 1

     blob_id = logical_blob_id_util.LogicalBlobId()
     blob_id.op_name = op_name
     blob_id.blob_name = interface_attr.output_bns[0]

     if eager_object.op_arg_parallel_attr.is_mirrored():
         blob_type = remote_blob_util.EagerMirroredBlob
     else:
         blob_type = remote_blob_util.EagerConsistentBlob
     remote_blob = blob_type(blob_id, eager_object, job_name)

     fetch = (
         remote_blob.numpy_list
         if eager_object.op_arg_blob_attr.is_tensor_list
         else remote_blob.numpy
     )
     Yield(fetch())