def lazy_system_assign(ref, value, validate_shape=None, use_locking=None, name=None):
    """Forward a system-assign op for ``ref``/``value`` and return ``ref``.

    The op conf is built by ``_SystemAssignOpConf`` and forwarded through
    ``interpret_util.Forward`` inside a ``oneflow.scope.placement`` scope whose
    device tag and machine device ids are derived from ``ref.parallel_conf``.

    Args:
        ref: Assignment target; must expose a ``parallel_conf`` attribute.
        value: Value passed to ``_SystemAssignOpConf`` together with ``ref``.
        validate_shape: Accepted for interface compatibility; not read here.
        use_locking: Accepted for interface compatibility; not read here.
        name: Optional name passed on to ``_SystemAssignOpConf``.

    Returns:
        The ``ref`` argument itself.
    """
    assign_op_conf = _SystemAssignOpConf(ref, value, name=name)
    tag, dev_ids = parallel_conf_util.GetDeviceTagAndMachineDeviceIds(
        ref.parallel_conf
    )
    # Forward the op under the same placement as the assignment target.
    with oneflow.scope.placement(tag, dev_ids):
        interpret_util.Forward(assign_op_conf)
    return ref
def LazyConsistentWatch(blob_watched, handler):
    """Add a foreign-watch op for ``blob_watched`` and bind ``handler`` to it.

    A fresh UUID string links the op to the handler: it is stored in the op's
    ``foreign_watch_conf.handler_uuid`` and passed to
    ``watcher_util.BindUuidAndHandler``. The op is added to the current job
    inside a placement scope derived from ``blob_watched.parallel_conf``.

    Args:
        blob_watched: Blob to watch; must expose ``unique_name`` and
            ``parallel_conf``.
        handler: Object bound to the generated UUID via
            ``watcher_util.BindUuidAndHandler``.
    """
    handler_uuid = str(uuid.uuid1())

    op_conf = op_conf_util.OperatorConf()
    op_conf.name = id_util.UniqueStr("ForeignWatch_")
    # "in" is a Python keyword, so the field has to be set through setattr.
    setattr(op_conf.foreign_watch_conf, "in", blob_watched.unique_name)
    op_conf.foreign_watch_conf.handler_uuid = handler_uuid

    tag_and_dev_ids = parallel_conf_util.GetDeviceTagAndMachineDeviceIds(
        blob_watched.parallel_conf
    )
    with oneflow.scope.placement(*tag_and_dev_ids):
        compile_context.CurJobAddOp(op_conf)
    watcher_util.BindUuidAndHandler(handler_uuid, blob_watched, handler)
def InterpretScope(session, function_desc, config_proto):
    """Generator that sets up the job-building contexts and yields once.

    Steps, in order:
      1. Names the job conf after ``function_desc.job_func.__name__``.
      2. Picks the placement scope from the function attributes, falling back
         to one built from ``oneflow.env.current_resource()`` when it is None.
      3. Picks the distribute strategy from the function attributes, falling
         back to ``DistributeConsistentStrategy()`` when it is None.
      4. Builds the initial scope via ``MakeInitialScope``.
      5. Enters the job build/infer context, placement scope, distribute
         strategy, global runtime mode, and session initial scope, then
         yields with all of them active.

    NOTE(review): this looks like it is meant to be wrapped by a
    context-manager decorator defined outside this block -- confirm.

    Args:
        session: Passed to ``_SessionInitialScope`` with the initial scope.
        function_desc: Provides ``job_config_proto``, ``job_func`` and
            ``function_attribute``.
        config_proto: Not read by this function.

    Yields:
        None, exactly once, while every context above is entered.
    """
    job_conf = function_desc.job_config_proto
    job_conf.job_name = function_desc.job_func.__name__

    placement_scope = function_desc.function_attribute.default_placement_scope
    if placement_scope is None:
        # No explicit placement on the function: derive one from the
        # currently configured resource.
        current_resource = oneflow.env.current_resource()
        default_ids = placement_util.GetDefaultMachineDeviceIds(current_resource)
        placement_scope = placement_util.GetPlacementScope(*default_ids)

    strategy = function_desc.function_attribute.default_distribute_strategy
    if strategy is None:
        strategy = distribute_util.DistributeConsistentStrategy()
    mirrored = isinstance(strategy, distribute_util.DistributeMirroredStrategy)

    placement_args = parallel_conf_util.GetDeviceTagAndMachineDeviceIds(
        placement_scope.default_parallel_conf
    )
    initial_scope = MakeInitialScope(job_conf, *placement_args, mirrored)

    build_ctx = _JobBuildAndInferCtx(job_conf.job_name)
    with build_ctx, placement_scope, strategy:
        c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)
        # Managers in a single ``with`` are entered left to right, the same
        # order as the equivalent nested statements.
        with runtime_mode.ModeScope(
            runtime_mode.GLOBAL_MODE
        ), _SessionInitialScope(session, initial_scope):
            yield
def BuildWithNewParallelConf(self, instruction_builder, parallel_conf):
    """Delegate to ``BuildWithNewParallelDesc`` using values from ``parallel_conf``.

    The device tag and machine device ids are obtained from
    ``parallel_conf_util.GetDeviceTagAndMachineDeviceIds(parallel_conf)`` and
    passed on positionally after ``instruction_builder``.

    Args:
        instruction_builder: Passed through to ``BuildWithNewParallelDesc``.
        parallel_conf: Parallel conf to extract the device tag and machine
            device ids from.

    Returns:
        Whatever ``self.BuildWithNewParallelDesc`` returns.
    """
    device_tag, machine_device_ids = (
        parallel_conf_util.GetDeviceTagAndMachineDeviceIds(parallel_conf)
    )
    return self.BuildWithNewParallelDesc(
        instruction_builder, device_tag, machine_device_ids
    )