Example #1
def parallel_cast(input, name=None, distribute=None, gradient_distribute=None):
    # lazy (graph) mode only: the OperatorConf below is built by hand
    assert not oneflow.eager_execution_enabled()
    op_conf = op_conf_util.OperatorConf()
    setattr(
        op_conf,
        "name",
        name if name is not None else id_util.UniqueStr("ParallelCast_"),
    )
    op_conf.parallel_cast_conf.out = "out"
    setattr(op_conf.parallel_cast_conf, "in", input.unique_name)

    def to_split_axis(dist):
        # map a distribute object to an OptInt64: a set value means
        # "split along that axis", an unset value means "broadcast"
        split_axis = data_type_util.OptInt64()
        if type(dist) is distribute_util.SplitDistribute:
            split_axis.value = dist.axis
        elif type(dist) is distribute_util.BroadcastDistribute:
            split_axis.ClearField("value")
        else:
            raise NotImplementedError
        return split_axis

    if distribute is not None:
        op_conf.parallel_cast_conf.split_axis.CopyFrom(to_split_axis(distribute))
    if gradient_distribute is not None:
        op_conf.parallel_cast_conf.gradient_split_axis.CopyFrom(
            to_split_axis(gradient_distribute)
        )
    compile_context.CurJobAddOp(op_conf)
    lbi = logical_blob_id_util.LogicalBlobId()
    lbi.op_name = op_conf.name
    lbi.blob_name = "out"
    return remote_blob_util.RemoteBlob(lbi)
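For orientation, here is a minimal usage sketch assuming OneFlow's legacy lazy-mode API; flow.parallel_cast, flow.distribute, flow.global_function, and oneflow.typing are assumptions based on the pre-1.0 public interface, not part of the snippet above:

import oneflow as flow
import oneflow.typing as tp

@flow.global_function(type="predict")
def cast_job(x: tp.Numpy.Placeholder((4, 8))) -> tp.Numpy:
    # broadcast x across devices; split its gradient along axis 0
    return flow.parallel_cast(
        x,
        distribute=flow.distribute.broadcast(),
        gradient_distribute=flow.distribute.split(0),
    )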
Example #2
def LazyReturnRemoteBlob(remote_blob, allow_cpu_return_op=True):
    assert isinstance(
        remote_blob,
        (oneflow_api.LazyMirroredBlob, oneflow_api.LazyConsistentBlob),
    )
    op_conf, lbi, scope = _GetReturnOpConfAndOutLbiAndScope(
        remote_blob, allow_cpu_return_op)
    compile_context.CurJobAddOp(op_conf, scope)
    return remote_blob_util.RemoteBlob(lbi)
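LazyReturnRemoteBlob is called by the framework rather than by user code: when a lazy job function returns a blob, a return op is appended so the value can be fetched after execution. A hedged illustration, assuming the same legacy decorator API as above:

@flow.global_function(type="predict")
def relu_job(x: tp.Numpy.Placeholder((2, 2))) -> tp.Numpy:
    # the returned blob is routed through LazyReturnRemoteBlob, which
    # adds a return op to the current job before the value is fetched
    return flow.math.relu(x)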
Example #3
def LazyConsistentWatch(blob_watched, handler):
    handler_uuid = str(uuid.uuid1())
    op_conf = op_conf_util.OperatorConf()
    op_conf.name = id_util.UniqueStr("ForeignWatch_")
    setattr(op_conf.foreign_watch_conf, "in", blob_watched.unique_name)
    op_conf.foreign_watch_conf.handler_uuid = handler_uuid
    # this variant pins the watch op to CPU device 0:0
    with oneflow.scope.placement("cpu", "0:0"):
        compile_context.CurJobAddOp(op_conf)
    watcher_util.BindUuidAndHandler(handler_uuid, blob_watched, handler)
Example #4
def LazyConsistentWatch(blob_watched, handler):
    handler_uuid = str(uuid.uuid1())
    op_conf = op_conf_util.OperatorConf()
    op_conf.name = id_util.UniqueStr("ForeignWatch_")
    setattr(op_conf.foreign_watch_conf, "in", blob_watched.unique_name)
    op_conf.foreign_watch_conf.handler_uuid = handler_uuid
    # unlike Example #3, place the watch op on the watched blob's own devices
    tag_and_dev_ids = parallel_conf_util.GetDeviceTagAndMachineDeviceIds(
        blob_watched.parallel_conf)
    with oneflow.scope.placement(*tag_and_dev_ids):
        compile_context.CurJobAddOp(op_conf)
    watcher_util.BindUuidAndHandler(handler_uuid, blob_watched, handler)
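In lazy mode the public watch API bottoms out in LazyConsistentWatch. A minimal sketch, assuming the legacy flow.watch wrapper and decorator API:

def report(x: tp.Numpy):
    # invoked with the blob's value once the job has produced it
    print("watched blob shape:", x.shape)

@flow.global_function(type="predict")
def watch_job(x: tp.Numpy.Placeholder((2, 2))):
    y = flow.math.relu(x)
    flow.watch(y, report)  # registers a ForeignWatch op on y's placement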
Example #5
def unpack(input, unpack_num, name=None):
    assert not oneflow.eager_execution_enabled()
    op_conf = op_conf_util.OperatorConf()
    setattr(
        op_conf, "name", name if name is not None else id_util.UniqueStr("Unpack_"),
    )
    setattr(op_conf.unpack_conf, "in", input.unique_name)
    op_conf.unpack_conf.out = "out"
    op_conf.unpack_conf.unpack_num = unpack_num
    compile_context.CurJobAddOp(op_conf)
    lbi = logical_blob_id_util.LogicalBlobId()
    lbi.op_name = op_conf.name
    lbi.blob_name = "out"
    return remote_blob_util.RemoteBlob(lbi)
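A hedged call sketch: unpack splits one logical blob into unpack_num successive pieces along the leading axis (used with pack/acc for micro-batch pipelining). batch_blob is a hypothetical existing lazy blob, not taken from the snippet:

# hypothetical: split a batch of 8 rows into 4 micro-batches of 2 rows each
micro_batch = unpack(batch_blob, unpack_num=4, name="unpack0")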
Example #6
def acc(one, max_acc_num, name=None):
    assert not oneflow.eager_execution_enabled()
    op_conf = op_conf_util.OperatorConf()
    setattr(
        op_conf,
        "name",
        name if name is not None else id_util.UniqueStr("Acc_"),
    )
    op_conf.acc_conf.one = one.unique_name
    op_conf.acc_conf.acc = "acc"
    op_conf.acc_conf.max_acc_num = max_acc_num
    compile_context.CurJobAddOp(op_conf)
    lbi = logical_blob_id_util.LogicalBlobId()
    lbi.op_name = op_conf.name
    lbi.blob_name = "acc"
    return remote_blob_util.RemoteBlob(lbi)
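acc is the dual of unpack: it folds max_acc_num successive inputs into a single accumulated output, which is how gradient accumulation is expressed at the op level. A hedged sketch with a hypothetical input blob:

# hypothetical: accumulate 4 successive micro-batch results into one blob
total = acc(micro_batch_result, max_acc_num=4, name="acc0")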
Example #7
    def compile(self, op_list):
        self._check_status(self.SessionStatus.OPEN)
        scope = flow.current_scope()
        device_tag = scope.device_parallel_desc_symbol.device_tag
        for op_conf in op_list:
            if _need_check_device_tag(
                    op_conf) and op_conf.device_tag != device_tag:
                print(
                    "WARNING: the device_tag of op {} is not equal to the device_tag of seesion's current scope"
                    " ({} vs. {})"
                    ", which may cause the op graph to be incompatible".format(
                        op_conf.name, op_conf.device_tag, device_tag))

            compile_ctx.CurJobAddOp(op_conf)

        oneflow_api.CurJobBuildAndInferCtx_Complete()
        oneflow_api.CurJobBuildAndInferCtx_Rebuild()
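A hedged usage sketch; sess and saved_job (a Job proto loaded from a saved model) are assumptions, not shown in the snippet:

# hypothetical: feed the ops of a saved job into the session's current job
with flow.scope.placement("gpu", "0:0"):
    # ops whose device_tag differs from "gpu" trigger the warning above
    sess.compile(saved_job.net.op)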
Example #8
    def InferAndTryRun(self):
        assert hob.in_global_mode(None)
        compile_context.CurJobAddOp(self.op_conf_)
        return self
Example #9
    def InferAndTryRun(self):
        compile_context.CurJobAddOp(self.op_conf_)
        return self
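InferAndTryRun is the penultimate step of the user-op builder chain; below is a sketch of the conventional call pattern in the legacy API (the op name "relu0" and the input blob x are placeholders):

y = (
    flow.user_op_builder("relu0")  # hypothetical op name
    .Op("relu")
    .Input("in", [x])
    .Output("out")
    .Build()
    .InferAndTryRun()   # adds the op to the current job, as in the snippets above
    .RemoteBlobList()[0]
)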