def parallel_cast(input, name=None, distribute=None, gradient_distribute=None):
    # Lazy mode only: build a parallel_cast op conf and add it to the current job.
    assert not oneflow.eager_execution_enabled()
    op_conf = op_conf_util.OperatorConf()
    setattr(
        op_conf,
        "name",
        name if name is not None else id_util.UniqueStr("ParallelCast_"),
    )
    op_conf.parallel_cast_conf.out = "out"
    setattr(op_conf.parallel_cast_conf, "in", input.unique_name)

    def to_split_axis(dist):
        # Translate a distribute object into an optional split axis:
        # SplitDistribute carries an axis; BroadcastDistribute leaves it unset.
        split_axis = data_type_util.OptInt64()
        if type(dist) is distribute_util.SplitDistribute:
            split_axis.value = dist.axis
        elif type(dist) is distribute_util.BroadcastDistribute:
            split_axis.ClearField("value")
        else:
            raise NotImplementedError
        return split_axis

    if distribute is not None:
        op_conf.parallel_cast_conf.split_axis.CopyFrom(to_split_axis(distribute))
    if gradient_distribute is not None:
        op_conf.parallel_cast_conf.gradient_split_axis.CopyFrom(
            to_split_axis(gradient_distribute)
        )
    compile_context.CurJobAddOp(op_conf)
    lbi = logical_blob_id_util.LogicalBlobId()
    lbi.op_name = op_conf.name
    lbi.blob_name = "out"
    return remote_blob_util.RemoteBlob(lbi)
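# A minimal usage sketch, assuming the oneflow 0.x lazy API where
# flow.distribute.split(axis) / flow.distribute.broadcast() construct the
# distribute objects consumed above (names here are illustrative):
#
#   y = parallel_cast(
#       x,
#       distribute=flow.distribute.broadcast(),
#       gradient_distribute=flow.distribute.split(0),
#   )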
def LazyReturnRemoteBlob(remote_blob, allow_cpu_return_op=True):
    assert isinstance(
        remote_blob,
        (oneflow_api.LazyMirroredBlob, oneflow_api.LazyConsistentBlob),
    )
    op_conf, lbi, scope = _GetReturnOpConfAndOutLbiAndScope(
        remote_blob, allow_cpu_return_op
    )
    compile_context.CurJobAddOp(op_conf, scope)
    return remote_blob_util.RemoteBlob(lbi)
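# Hedged sketch: in lazy mode, returning a blob from a @flow.global_function
# job is what routes it through LazyReturnRemoteBlob, which appends a return
# op so the value can be fetched (oneflow 0.x typing API; shapes illustrative):
#
#   @flow.global_function()
#   def job(x: flow.typing.Numpy.Placeholder((2, 3))) -> flow.typing.Numpy:
#       return flow.math.relu(x)  # the returned blob gets a return op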
def LazyConsistentWatch(blob_watched, handler):
    handler_uuid = str(uuid.uuid1())
    op_conf = op_conf_util.OperatorConf()
    op_conf.name = id_util.UniqueStr("ForeignWatch_")
    setattr(op_conf.foreign_watch_conf, "in", blob_watched.unique_name)
    op_conf.foreign_watch_conf.handler_uuid = handler_uuid
    # The watch op is pinned to CPU device "0:0" on machine 0, regardless of
    # where the watched blob lives.
    with oneflow.scope.placement("cpu", "0:0"):
        compile_context.CurJobAddOp(op_conf)
    watcher_util.BindUuidAndHandler(handler_uuid, blob_watched, handler)
def LazyConsistentWatch(blob_watched, handler):
    # Variant that places the watch op on the same device tag and machine
    # device ids as the watched blob instead of hard-coding CPU placement.
    handler_uuid = str(uuid.uuid1())
    op_conf = op_conf_util.OperatorConf()
    op_conf.name = id_util.UniqueStr("ForeignWatch_")
    setattr(op_conf.foreign_watch_conf, "in", blob_watched.unique_name)
    op_conf.foreign_watch_conf.handler_uuid = handler_uuid
    tag_and_dev_ids = parallel_conf_util.GetDeviceTagAndMachineDeviceIds(
        blob_watched.parallel_conf
    )
    with oneflow.scope.placement(*tag_and_dev_ids):
        compile_context.CurJobAddOp(op_conf)
    watcher_util.BindUuidAndHandler(handler_uuid, blob_watched, handler)
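# Hedged usage sketch: user code reaches LazyConsistentWatch through
# flow.watch, passing a callback that receives the blob's value as a numpy
# array once the job runs (oneflow 0.x lazy API; the handler is illustrative):
#
#   def report(ndarray):
#       print("watched mean:", ndarray.mean())
#
#   flow.watch(some_blob, report)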
def unpack(input, unpack_num, name=None):
    # Lazy mode only: split the input blob into unpack_num pieces that are
    # processed over successive time steps (used for micro-batching).
    assert not oneflow.eager_execution_enabled()
    op_conf = op_conf_util.OperatorConf()
    setattr(
        op_conf,
        "name",
        name if name is not None else id_util.UniqueStr("Unpack_"),
    )
    setattr(op_conf.unpack_conf, "in", input.unique_name)
    op_conf.unpack_conf.out = "out"
    op_conf.unpack_conf.unpack_num = unpack_num
    compile_context.CurJobAddOp(op_conf)
    lbi = logical_blob_id_util.LogicalBlobId()
    lbi.op_name = op_conf.name
    lbi.blob_name = "out"
    return remote_blob_util.RemoteBlob(lbi)
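# Minimal usage sketch, assuming a batch whose leading dimension is divisible
# by unpack_num (shapes are illustrative, not from the source):
#
#   # a (8, 128) blob becomes a sequence of (2, 128) pieces over 4 steps
#   micro = unpack(batch_blob, unpack_num=4)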
def acc(one, max_acc_num, name=None):
    # Lazy mode only: accumulate the input blob over max_acc_num time steps
    # and emit the running sum (the dual of unpack).
    assert not oneflow.eager_execution_enabled()
    op_conf = op_conf_util.OperatorConf()
    setattr(
        op_conf,
        "name",
        name if name is not None else id_util.UniqueStr("Acc_"),
    )
    op_conf.acc_conf.one = one.unique_name
    op_conf.acc_conf.acc = "acc"
    op_conf.acc_conf.max_acc_num = max_acc_num
    compile_context.CurJobAddOp(op_conf)
    lbi = logical_blob_id_util.LogicalBlobId()
    lbi.op_name = op_conf.name
    lbi.blob_name = "acc"
    return remote_blob_util.RemoteBlob(lbi)
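# Hedged sketch of the unpack/acc pairing, e.g. for gradient accumulation
# (model and names are illustrative, not from the source):
#
#   micro = unpack(batch, unpack_num=4)   # 4 micro-batches per iteration
#   loss = model(micro)                   # hypothetical per-micro-batch loss
#   total = acc(loss, max_acc_num=4)      # summed back over the 4 steps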
def compile(self, op_list):
    self._check_status(self.SessionStatus.OPEN)
    scope = flow.current_scope()
    device_tag = scope.device_parallel_desc_symbol.device_tag
    for op_conf in op_list:
        # Warn when an op's device_tag disagrees with the session's current
        # scope, since that can make the op graph incompatible.
        if _need_check_device_tag(op_conf) and op_conf.device_tag != device_tag:
            print(
                "WARNING: the device_tag of op {} is not equal to the"
                " device_tag of the session's current scope ({} vs. {}),"
                " which may cause the op graph to be incompatible".format(
                    op_conf.name, op_conf.device_tag, device_tag
                )
            )
        compile_ctx.CurJobAddOp(op_conf)
    oneflow_api.CurJobBuildAndInferCtx_Complete()
    oneflow_api.CurJobBuildAndInferCtx_Rebuild()
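# Hedged usage sketch: this method belongs to a serving-style session that
# replays saved op confs into the current job and then completes/rebuilds the
# build-and-infer context (session construction and the op_list source below
# are illustrative, not from the source):
#
#   sess = Session()                 # hypothetical session object
#   sess.compile(saved_job.op_list)  # add every op conf, then finalize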
def InferAndTryRun(self):
    # Global (consistent) mode only: register this op with the current job.
    assert hob.in_global_mode(None)
    compile_context.CurJobAddOp(self.op_conf_)
    return self

def InferAndTryRun(self):
    # Variant without the mode assertion: unconditionally add the op conf.
    compile_context.CurJobAddOp(self.op_conf_)
    return self
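# Hedged sketch of where InferAndTryRun sits in the lazy user-op builder
# chain (oneflow 0.x API; the op name and inputs are illustrative):
#
#   y = (
#       flow.user_op_builder("my_relu")
#       .Op("relu")
#       .Input("in", [x])
#       .Output("out")
#       .Build()
#       .InferAndTryRun()        # adds the op to the current job
#       .RemoteBlobList()[0]
#   )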