def test_reader_with_limit(self):
    """Exercise ReaderWithLimit below and above the dataset size.

    A 100-row dataset (labels 0..99) is piped through a ReaderWithLimit
    into a destination dataset, once with a limit of 10 and once with a
    limit of 110. The destination contents and the reader's
    ``data_finished`` blob are checked after each run.
    """
    ws = workspace.C.Workspace()
    session = LocalSession(ws)

    # 1. Feed the full dataset.
    src_init = core.Net('src_init')
    src_values = Struct(('label', np.array(range(100))))
    src_blobs = NewRecord(src_init, src_values)
    src_ds = Dataset(src_blobs)
    FeedRecord(src_blobs, src_values, ws)
    ws.run(src_init)

    # 2. Read with a limit smaller than the size of the dataset.
    dst_init = core.Net('dst_init')
    dst_ds = Dataset(src_values.clone_schema())
    dst_ds.init_empty(dst_init)
    ws.run(dst_init)

    with TaskGroup() as tg:
        reader = ReaderWithLimit(src_ds.reader(), num_iter=10)
        pipe(reader, dst_ds.writer(), num_threads=8)
    session.run(tg)
    self.assertFalse(ws.blobs[str(reader.data_finished())].fetch())
    # sorted() returns a list; compare against a list so the check also
    # holds on Python 3, where range() is not a list.
    self.assertEqual(
        sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
        list(range(10)))

    # 3. Read with a limit larger than the size of the dataset.
    # dst_init is run again first; presumably this empties the destination.
    ws.run(dst_init)
    with TaskGroup() as tg:
        reader = ReaderWithLimit(src_ds.reader(), num_iter=110)
        pipe(reader, dst_ds.writer(), num_threads=8)
    session.run(tg)
    self.assertEqual(
        sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
        list(range(100)))
    self.assertTrue(ws.blobs[str(reader.data_finished())].fetch())
def test_reader_with_limit(self):
    """Exercise ReaderWithLimit with small, large and absent limits.

    The source dataset comes from ``init_dataset(ws)``; the assertions
    below expect it to hold the labels 0..99. Each scenario pipes the
    limited reader into a destination dataset and then checks both the
    destination contents and the reader's ``data_finished`` blob.
    """
    ws = workspace.C.Workspace()
    session = LocalSession(ws)

    # 1. Feed the full dataset.
    src_ds = init_dataset(ws)

    # 2. Read with a limit smaller than the size of the dataset.
    dst_init = core.Net('dst_init')
    with core.NameScope('dst'):
        dst_ds = Dataset(src_ds.content().clone_schema())
        dst_ds.init_empty(dst_init)
    ws.run(dst_init)

    # WorkspaceType.GLOBAL is required because we are fetching
    # reader.data_finished() after the TaskGroup finishes.
    with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
        reader = ReaderWithLimit(src_ds.reader(), num_iter=10)
        pipe(reader, dst_ds.writer(), num_threads=8)
    session.run(tg)
    self.assertFalse(ws.blobs[str(reader.data_finished())].fetch())
    self.assertEqual(
        sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
        list(range(10))
    )

    # 3. Read with a limit larger than the size of the dataset.
    ws.run(dst_init)
    with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
        reader = ReaderWithLimit(src_ds.reader(), num_iter=110)
        pipe(reader, dst_ds.writer(), num_runtime_threads=8)
    session.run(tg)
    self.assertEqual(
        sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
        list(range(100))
    )
    self.assertTrue(ws.blobs[str(reader.data_finished())].fetch())

    # 4. Read without a counter (num_iter=None).
    ws.run(dst_init)
    with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
        reader = ReaderWithLimit(src_ds.reader(), num_iter=None)
        pipe(reader, dst_ds.writer(), num_threads=8)
    session.run(tg)
    self.assertEqual(
        sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
        list(range(100))
    )
    self.assertTrue(ws.blobs[str(reader.data_finished())].fetch())

    # 5. Run the same task group again without re-running dst_init: the
    # destination is expected to end up holding every label twice.
    session.run(tg)
    self.assertEqual(
        sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
        sorted(list(range(100)) * 2)
    )
def __init__(self, init_group=None, epoch_group=None, exit_group=None,
             stop_signals=None, nodes_to_checkpoint=None):
    """Store the job's task groups, stop signals and checkpoint nodes.

    Any group argument that is falsy is replaced by a freshly created
    TaskGroup; the default init group uses WorkspaceType.GLOBAL. A falsy
    ``stop_signals`` becomes a new empty list.
    """
    if not init_group:
        init_group = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
    self.init_group = init_group

    if not epoch_group:
        epoch_group = TaskGroup()
    self.epoch_group = epoch_group

    if not exit_group:
        exit_group = TaskGroup()
    self.exit_group = exit_group

    self.stop_signals = stop_signals if stop_signals else []
    self._nodes_to_checkpoint = nodes_to_checkpoint
def __exit__(self, etype, *args):
    """Register this builder as a periodic report step, then close it.

    When the block exits without an exception, the builder is converted
    to an execution step configured with ``RunEveryMillis(interval_ms)``.
    That step is attached to ``self._net`` when one is set, otherwise to
    the current TaskGroup. ``NetBuilder.__exit__`` is always called last.
    """
    exited_cleanly = etype is None
    if exited_cleanly:
        report = core.to_execution_step(self)
        report.RunEveryMillis(self.interval_ms)
        if not self._net:
            # No explicit net: hand the step to the active task group.
            TaskGroup.current().report_step(
                report, interval_ms=self.interval_ms)
        else:
            self._net.add_attribute(Task.REPORT_STEP, report)
    NetBuilder.__exit__(self, etype, *args)
def __init__(self, init_group=None, epoch_group=None, download_group=None,
             exit_group=None, stop_conditions=None,
             nodes_to_checkpoint=None):
    """Store the job's task groups, stop conditions and checkpoint nodes.

    Any group argument that is falsy is replaced by a freshly created
    TaskGroup; the default init group uses WorkspaceType.GLOBAL. A falsy
    ``stop_conditions`` becomes a new empty list.
    """
    if not init_group:
        init_group = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
    self.init_group = init_group

    if not epoch_group:
        epoch_group = TaskGroup()
    self.epoch_group = epoch_group

    if not download_group:
        download_group = TaskGroup()
    self.download_group = download_group

    if not exit_group:
        exit_group = TaskGroup()
    self.exit_group = exit_group

    self.stop_conditions = stop_conditions if stop_conditions else []
    self._nodes_to_checkpoint = nodes_to_checkpoint
def compile(cls, runnable):
    """Turn ``runnable`` into a CompiledRunnable for this session class.

    An already compiled runnable is returned unchanged after checking it
    was compiled for ``cls``. Results are memoized in
    ``cls._compiled_cache``. Anything that is not a TaskGroup is wrapped
    in a new TaskGroup with WorkspaceType.GLOBAL: a Task is added as is,
    an ExecutionStep is wrapped in a Task, and any other object is first
    passed to ``core.execution_step('runnable', ...)``.
    """
    if isinstance(runnable, CompiledRunnable):
        assert cls == runnable.session_class, (
            'Runnable was compiled for different session type. ' +
            'Need: %s, got: %s' % (
                cls.__name__, runnable.session_class.__name__))
        return runnable

    cache = cls._compiled_cache
    if runnable in cache:
        return cache[runnable]

    if not isinstance(runnable, TaskGroup):
        tg = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
        if isinstance(runnable, Task):
            task = runnable
        elif isinstance(runnable, core.ExecutionStep):
            task = Task(step=runnable)
        else:
            task = Task(step=core.execution_step('runnable', runnable))
        tg.add(task)
    else:
        tg = runnable

    result = CompiledRunnable(cls._compile_task_group(tg), session_class=cls)
    cls._compiled_cache[runnable] = result
    return result
def test_runtime_threads(self):
    """Check how often the task/instance init and exit hooks run in pipe().

    A processor creates global counters in ``task_init``, a per-instance
    counter in ``task_instance_init``, counts every processed record, and
    folds the per-instance counts into the global counters in
    ``task_instance_exit``. The totals are exported in ``task_exit`` and
    compared with the number of records and of runtime threads.
    """
    ws = workspace.C.Workspace()
    session = LocalSession(ws)
    src_ds = init_dataset(ws)
    totals = [None] * 3

    def proc(rec):
        # executed once
        with ops.task_init():
            counter1 = ops.CreateCounter([], ['global_counter'])
            counter2 = ops.CreateCounter([], ['global_counter2'])
            counter3 = ops.CreateCounter([], ['global_counter3'])
        # executed once per thread
        with ops.task_instance_init():
            task_counter = ops.CreateCounter([], ['task_counter'])
        # executed on each iteration
        ops.CountUp(counter1)
        ops.CountUp(task_counter)
        # executed once per thread
        with ops.task_instance_exit():
            with ops.loop(ops.RetrieveCount(task_counter)):
                ops.CountUp(counter2)
            ops.CountUp(counter3)
        # executed once
        with ops.task_exit():
            totals[0] = final_output(ops.RetrieveCount(counter1))
            totals[1] = final_output(ops.RetrieveCount(counter2))
            totals[2] = final_output(ops.RetrieveCount(counter3))
        return rec

    # 1. Feed the full dataset through 8 runtime threads.
    with TaskGroup() as tg:
        pipe(src_ds.reader(), num_runtime_threads=8, processor=proc)
    session.run(tg)
    self.assertEqual(totals[0].fetch(), 100)
    self.assertEqual(totals[1].fetch(), 100)
    self.assertEqual(totals[2].fetch(), 8)

    # 2. Add a few steps in between, with a 25-record limit.
    with TaskGroup() as tg:
        q1 = pipe(src_ds.reader(), num_runtime_threads=2)
        q2 = pipe(
            ReaderWithLimit(q1.reader(), num_iter=25),
            num_runtime_threads=3)
        pipe(q2, processor=proc, num_runtime_threads=6)
    session.run(tg)
    self.assertEqual(totals[0].fetch(), 25)
    self.assertEqual(totals[1].fetch(), 25)
    self.assertEqual(totals[2].fetch(), 6)
def build_cache_step(self, overwrite=False):
    """Build a step that generates the cache DB file.

    If ``self.db_path`` already exists and ``overwrite`` is false, an
    empty step is returned. Otherwise the returned step initializes the
    dataset field blobs as empty, pipes the original reader into the
    dataset writer so the field blobs get populated, and then saves those
    blobs to a file.

    Args:
        overwrite: bool. If true, ignore an existing file and build a new
            one, overwriting the existing one.

    Returns:
        build_cache_step: ExecutionStep. The step to be run for building
            a cache DB file.
    """
    cache_present = os.path.exists(self.db_path)
    if cache_present and not overwrite:
        # Nothing to rebuild; hand back a no-op step.
        return core.execution_step('build_step', [])

    init_net = core.Net('init')
    self._init_field_blobs_as_empty(init_net)

    with Cluster():
        with core.NameScope(self.name):
            with TaskGroup() as copy_tg:
                pipe(self.original_reader, self.ds.writer(), num_threads=16)
    copy_step = copy_tg.to_task().get_step()

    save_net = core.Net('save')
    self._save_field_blobs_to_db_file(save_net)

    substeps = [init_net, copy_step, save_net]
    return core.execution_step('build_cache', substeps)
def _task_group(self, func, *args, **kw):
    """Call ``func`` for every node manager inside one global TaskGroup.

    ``func(manager, *args, **kw)`` is invoked under that manager's Node
    context. The TaskGroup, created with WorkspaceType.GLOBAL, is
    returned.
    """
    assert self._node_managers is not None, 'init must be called first.'
    with TaskGroup(WorkspaceType.GLOBAL) as task_group:
        for entry in self._node_managers:
            node, manager = entry
            with Node(node):
                func(manager, *args, **kw)
    return task_group
def test_composite_reader(self):
    """Pipe three source datasets through a CompositeReader and compare.

    Each source covers its own 100-value range. After the pipe runs, the
    destination field named after each source must hold exactly that
    source's data once sorted.
    """
    ws = workspace.C.Workspace()
    session = LocalSession(ws)
    num_srcs = 3
    size = 100
    names = []
    offsets = []
    for i in range(num_srcs):
        names.append("src_{}".format(i))
        offsets.append(i * size)

    src_dses = []
    for name, offset in zip(names, offsets):
        src_dses.append(
            make_source_dataset(ws, offset=offset, size=size, name=name))
    data = []
    for src in src_dses:
        data.append(ws.fetch_blob(str(src.field_blobs[0])))

    # Sanity check: every source still holds its own value range.
    for fetched, offset in zip(data, offsets):
        npt.assert_array_equal(fetched, range(offset, offset + size))

    # Build an empty destination dataset with one field per source.
    fields = []
    for name, src_ds in zip(names, src_dses):
        fields.append((name, src_ds.content().clone_schema()))
    dst_ds_schema = schema.Struct(*fields)
    dst_ds = make_destination_dataset(ws, dst_ds_schema)

    with TaskGroup() as tg:
        src_readers = [src_ds.reader() for src_ds in src_dses]
        reader = CompositeReader(names, src_readers)
        pipe(reader, dst_ds.writer(), num_runtime_threads=3)
    session.run(tg)

    for i, name in enumerate(names):
        label_blob = str(dst_ds.content()[name].label())
        written_data = sorted(ws.fetch_blob(label_blob))
        npt.assert_array_equal(data[i], written_data, "i: {}".format(i))
def test_composite_reader_builder(self):
    """Pipe a CompositeReaderBuilder's reader into a dataset and compare.

    Three TestReaderBuilder sources are combined. After the pipe runs,
    the destination field named after each source must hold that
    source's 100-value range once sorted.
    """
    ws = workspace.C.Workspace()
    session = LocalSession(ws)
    num_srcs = 3
    size = 100
    names = []
    offsets = []
    for i in range(num_srcs):
        names.append("src_{}".format(i))
        offsets.append(i * size)

    src_ds_builders = []
    for name, offset in zip(names, offsets):
        src_ds_builders.append(
            TestReaderBuilder(offset=offset, size=size, name=name))

    # Build an empty destination dataset with one field per source.
    fields = []
    for name, src_ds_builder in zip(names, src_ds_builders):
        fields.append((name, src_ds_builder.schema()))
    dst_ds_schema = schema.Struct(*fields)
    dst_ds = make_destination_dataset(ws, dst_ds_schema)

    with TaskGroup() as tg:
        reader_builder = CompositeReaderBuilder(names, src_ds_builders)
        reader_builder.setup(ws=ws)
        composite = reader_builder.new_reader()
        pipe(composite, dst_ds.writer(), num_runtime_threads=3)
    session.run(tg)

    for name, offset in zip(names, offsets):
        label_blob = str(dst_ds.content()[name].label())
        written_data = sorted(ws.fetch_blob(label_blob))
        npt.assert_array_equal(
            range(offset, offset + size), written_data,
            "name: {}".format(name))
def _test_limit_reader_shared(self, reader_class, size, expected_read_len,
                              expected_finish, num_threads, read_delay,
                              **limiter_args):
    """Run one limited-read scenario and check what reached the output.

    Args:
        reader_class: limiter class, instantiated around the source reader
            with ``**limiter_args``.
        size: forwarded to ``_test_limit_reader_init_shared``.
        expected_read_len: number of labels expected in the destination;
            the sorted labels must equal ``range(expected_read_len)``.
        expected_finish: expected value of the reader's ``data_finished``
            blob after the run.
        num_threads: passed to ``pipe`` as ``num_runtime_threads``.
        read_delay: when > 0, the source reader is wrapped in
            ``ReaderWithDelay`` with this delay before being limited.
        **limiter_args: extra keyword arguments for ``reader_class``.
    """
    ws, session, src_ds, dst_ds = \
        self._test_limit_reader_init_shared(size)

    # Read through the limiter.
    # WorkspaceType.GLOBAL is required because we are fetching
    # reader.data_finished() after the TaskGroup finishes.
    with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
        if read_delay > 0:
            reader = reader_class(
                ReaderWithDelay(src_ds.reader(), read_delay),
                **limiter_args)
        else:
            reader = reader_class(src_ds.reader(), **limiter_args)
        pipe(reader, dst_ds.writer(), num_runtime_threads=num_threads)
    session.run(tg)

    # Fetch the destination labels once and reuse them for both checks;
    # the length check needs no sorting.
    labels = ws.blobs[str(dst_ds.content().label())].fetch()
    self.assertEqual(len(labels), expected_read_len)
    self.assertEqual(sorted(labels), list(range(expected_read_len)))
    self.assertEqual(ws.blobs[str(reader.data_finished())].fetch(),
                     expected_finish)
def init(
    self, nodes, retrieve_from_epoch=None, path_prefix=None, path_type=None
):
    """Create one CheckpointManager per node and build their init tasks.

    If node managers already exist, only check that they match ``nodes``
    and return an empty global TaskGroup. Otherwise record the path
    settings and node names, initialize the metadata handler when one is
    set, create a manager for every node under that node's context, and
    return the TaskGroup that calls ``CheckpointManager.init`` on each.
    """
    if self._node_managers is not None:
        assert [known for known, _ in self._node_managers] == nodes
        return TaskGroup(WorkspaceType.GLOBAL)

    self._node_managers = []
    self._path_prefix = path_prefix
    self._path_type = path_type
    self._node_names = [str(n) for n in nodes]

    if self._metadata_handler:
        self._metadata_handler.init(
            db_prefix=self._db_prefix,
            db_type=self._db_type,
            node_names=self._node_names,
            path_prefix=self._path_prefix,
            path_type=self._path_type)

    for node in nodes:
        with Node(node):
            manager = CheckpointManager(
                db_prefix=self._db_prefix,
                node_name=str(node),
                db_type=self._db_type)
            self._node_managers.append((node, manager))

    # NOTE(review): `node` here is whatever the loop above left behind
    # (the last entry of `nodes`); with an empty `nodes` this name is
    # unbound. Kept as in the original -- confirm this is intended.
    init_kwargs = dict(
        nodes=[node],
        retrieve_from_epoch=retrieve_from_epoch,
        path_prefix=path_prefix,
        path_type=path_type)
    return self._task_group(CheckpointManager.init, **init_kwargs)
def _read_all_data(ws, reader, session):
    """Pipe ``reader`` into a new destination dataset; return its labels.

    The destination is created with ``make_destination_dataset`` from a
    clone of the reader's schema and filled using 8 runtime threads. The
    fetched content of the destination's ``label`` blob is returned.
    """
    out_schema = reader.schema().clone_schema()
    dst_ds = make_destination_dataset(ws, out_schema)

    with TaskGroup() as tg:
        pipe(reader, dst_ds.writer(), num_runtime_threads=8)
    session.run(tg)

    label_blob = ws.blobs[str(dst_ds.content().label())]
    return label_blob.fetch()
def compile(cls, runnable):
    """Compile ``runnable`` for this session class, with memoization.

    A CompiledRunnable is passed through after asserting it belongs to
    ``cls``. A cached result in ``cls._compiled_cache`` is reused. A
    TaskGroup is compiled directly; a Task, ExecutionStep or anything
    else is first placed into a new TaskGroup (WorkspaceType.GLOBAL),
    wrapped in a Task / execution step as needed.
    """
    if isinstance(runnable, CompiledRunnable):
        assert cls == runnable.session_class, (
            'Runnable was compiled for different session type. ' +
            'Need: %s, got: %s' % (
                cls.__name__, runnable.session_class.__name__))
        return runnable

    if runnable in cls._compiled_cache:
        return cls._compiled_cache[runnable]

    tg = runnable
    if not isinstance(runnable, TaskGroup):
        tg = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
        if isinstance(runnable, Task):
            tg.add(runnable)
        else:
            if isinstance(runnable, core.ExecutionStep):
                wrapped_step = runnable
            else:
                wrapped_step = core.execution_step('runnable', runnable)
            tg.add(Task(step=wrapped_step))

    compiled_group = cls._compile_task_group(tg)
    outcome = CompiledRunnable(compiled_group, session_class=cls)
    cls._compiled_cache[runnable] = outcome
    return outcome
def test_dequeue_many(self):
    """Check a pipeline whose queue dequeues several records at a time.

    17 records go through ``proc1`` (doubles ``uid``) into a Queue built
    with ``num_dequeue_records=3``, then through ``proc2`` (zeroes
    ``value`` and increments a shared counter) into the destination. The
    counter must equal ceil(17 / 3), and the output must match the
    expected record.
    """
    init_net = core.Net('init')
    N = 17
    NUM_DEQUEUE_RECORDS = 3
    src_values = Struct(
        ('uid', np.array(range(N))),
        ('value', 0.1 * np.array(range(N))))
    expected_dst = Struct(
        ('uid', 2 * np.array(range(N))),
        ('value', np.array(N * [0.0])))

    with core.NameScope('init'):
        src_blobs = NewRecord(init_net, src_values)
        dst_blobs = InitEmptyRecord(init_net, src_values.clone_schema())
        counter = init_net.Const(0)
        ONE = init_net.Const(1)

    def proc1(rec):
        with core.NameScope('proc1'):
            out = NewRecord(ops, rec)
        ops.Add([rec.uid(), rec.uid()], [out.uid()])
        out.value.set(blob=rec.value(), unsafe=True)
        return out

    def proc2(rec):
        with core.NameScope('proc2'):
            out = NewRecord(ops, rec)
        out.uid.set(blob=rec.uid(), unsafe=True)
        ops.Sub([rec.value(), rec.value()], [out.value()])
        # Incremented each time proc2 runs, so it counts the dequeues.
        ops.Add([counter, ONE], [counter])
        return out

    src_ds = Dataset(src_blobs)
    dst_ds = Dataset(dst_blobs)

    with TaskGroup() as tg:
        out1 = pipe(
            src_ds.reader(),
            output=Queue(
                capacity=11, num_dequeue_records=NUM_DEQUEUE_RECORDS),
            processor=proc1)
        out2 = pipe(out1, processor=proc2)
        pipe(out2, dst_ds.writer())

    ws = workspace.C.Workspace()
    FeedRecord(src_blobs, src_values, ws)
    session = LocalSession(ws)
    session.run(init_net)
    session.run(tg)
    output = FetchRecord(dst_blobs, ws=ws)
    num_dequeues = ws.blobs[str(counter)].fetch()

    self.assertEqual(
        num_dequeues, int(math.ceil(float(N) / NUM_DEQUEUE_RECORDS)))

    for a, b in zip(output.field_blobs(), expected_dst.field_blobs()):
        np.testing.assert_array_equal(a, b)
def build(self, epoch, checkpoint_manager):
    """Build a TaskGroup that copies each node's checkpoint to dest_dir.

    For every (node, manager) pair of ``checkpoint_manager``, a Task is
    created under that node's context which invokes ``local_copy_op``
    through a Python op. The source path comes from ``db_name`` and the
    destination is ``self.dest_dir`` joined with the node name.
    """
    with TaskGroup(WorkspaceType.GLOBAL) as upload_task_group:
        for node, manager in checkpoint_manager._node_managers:
            node_name = str(node)
            with Node(node_name):
                with Task():
                    src_path = db_name(
                        epoch, manager._node_name, manager._db_prefix)
                    dest_path = os.path.join(self.dest_dir, str(node))
                    copy_spec = (local_copy_op, [src_path, dest_path], {})
                    ops.Python(copy_spec)([], [])
    return upload_task_group
def read_all_data(ws, reader, session):
    """Pipe ``reader`` into a fresh 'dst' dataset and return its labels.

    A destination dataset with a clone of the reader's schema is created
    under the 'dst' name scope and initialized empty by running
    ``dst_init``. The reader is then piped into it with 8 runtime
    threads, and the fetched ``label`` blob is returned.
    """
    dst_init = core.Net('dst_init')
    with core.NameScope('dst'):
        out_schema = reader.schema().clone_schema()
        dst_ds = Dataset(out_schema)
        dst_ds.init_empty(dst_init)
    session.run(dst_init)

    # GLOBAL workspace: the result is fetched from `ws` after the run.
    with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
        pipe(reader, dst_ds.writer(), num_runtime_threads=8)
    session.run(tg)

    label_blob = ws.blobs[str(dst_ds.content().label())]
    return label_blob.fetch()
def _task_group(self, func, *args, **kw):
    """Call ``func`` for each non-reader node manager in a global TaskGroup.

    Nodes whose string form contains "reader" are skipped. For all other
    nodes ``func(manager, *args, **kw)`` runs under the node's context.
    The TaskGroup (WorkspaceType.GLOBAL) is returned.
    """
    assert self._node_managers is not None, 'init must be called first.'
    with TaskGroup(WorkspaceType.GLOBAL) as task_group:
        for node, manager in self._node_managers:
            # TODO(aartibasant, T21070353): Enable the checkpoints for
            # readers.
            # Reader checkpointing is broken because of D5582328, so
            # reader nodes are left out until that is fixed.
            is_reader_node = "reader" in str(node)
            if not is_reader_node:
                with Node(node):
                    func(manager, *args, **kw)
    return task_group
def run(self, runnable):
    """Run ``runnable`` in this session, caching its TaskGroup form.

    The session must be open. On first use a runnable that is not a
    TaskGroup is wrapped into a new TaskGroup (WorkspaceType.GLOBAL): a
    Task is added as is, an ExecutionStep is wrapped in a Task, anything
    else goes through ``core.execution_step('runnable', ...)`` first. The
    TaskGroup is stored in ``self._runnable_cache`` and then executed via
    ``self._run_task_group``.
    """
    assert self.is_open(), 'Session is closed.'

    if runnable not in self._runnable_cache:
        if not isinstance(runnable, TaskGroup):
            tg = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
            if isinstance(runnable, Task):
                task = runnable
            elif isinstance(runnable, core.ExecutionStep):
                task = Task(step=runnable)
            else:
                task = Task(step=core.execution_step('runnable', runnable))
            tg.add(task)
        else:
            tg = runnable
        self._runnable_cache[runnable] = tg

    cached_group = self._runnable_cache[runnable]
    self._run_task_group(cached_group)
def load_blobs_locally(self, blob_names, epoch, session):
    """Load the necessary blobs from the checkpoints to the current node.

    For every node manager a separate global TaskGroup is built around
    ``manager.load_blobs_from_checkpoint`` and run immediately.

    Args:
        blob_names: A list of strings. Each string is the name of a blob.
        epoch: An integer. The checkpoint epoch to load from.
        session: A Session object to execute the Load ops.
    """
    assert self._node_managers is not None, 'init must be called first.'
    for entry in self._node_managers:
        manager = entry[1]
        with TaskGroup(WorkspaceType.GLOBAL) as load_group:
            manager.load_blobs_from_checkpoint(blob_names, epoch)
        session.run(load_group)
def build_cache(self, cache_path, overwrite=False):
    """Build a step that fills the dataset from the reader and saves it.

    ``self.cache_path`` is updated to ``cache_path`` when there is no
    cache yet or ``overwrite`` is set. If a cache exists afterwards and
    ``overwrite`` is false, an empty step is returned. Otherwise the
    returned step initializes the dataset, pipes the original reader into
    the dataset writer, and saves the result to a file.
    """
    if overwrite or not self.has_cache():
        self.cache_path = cache_path

    if self.has_cache() and not overwrite:
        # A cache is already in place; nothing to rebuild.
        return core.execution_step('build_step', [])

    init_net = core.Net('init')
    self._init_dataset(init_net)

    with Cluster():
        with core.NameScope(self.name):
            with TaskGroup() as copy_tg:
                pipe(self.original_reader, self.ds.writer(), num_threads=16)
    copy_step = copy_tg.to_task().get_step()

    save_net = core.Net('save')
    self._save_to_file(save_net)

    substeps = [init_net, copy_step, save_net]
    return core.execution_step('build_cache', substeps)
def test_local_session(self):
    """Run a two-stage pipeline in a LocalSession and check the output.

    ``proc1`` doubles ``uid`` and passes ``value`` through; ``proc2``
    passes ``uid`` through and subtracts ``value`` from itself. The
    destination record must equal the expected record.
    """
    init_net = core.Net('init')
    src_values = Struct(
        ('uid', np.array([1, 2, 6])),
        ('value', np.array([1.4, 1.6, 1.7])))
    expected_dst = Struct(
        ('uid', np.array([2, 4, 12])),
        ('value', np.array([0.0, 0.0, 0.0])))

    with core.NameScope('init'):
        src_blobs = NewRecord(init_net, src_values)
        dst_blobs = InitEmptyRecord(init_net, src_values.clone_schema())

    def proc1(rec):
        # uid -> uid + uid; value is aliased to the input blob.
        net = core.Net('proc1')
        with core.NameScope('proc1'):
            out = NewRecord(net, rec)
        doubled_inputs = [rec.uid(), rec.uid()]
        net.Add(doubled_inputs, [out.uid()])
        out.value.set(blob=rec.value(), unsafe=True)
        return [net], out

    def proc2(rec):
        # uid is aliased to the input blob; value -> value - value.
        net = core.Net('proc2')
        with core.NameScope('proc2'):
            out = NewRecord(net, rec)
        out.uid.set(blob=rec.uid(), unsafe=True)
        cancelled_inputs = [rec.value(), rec.value()]
        net.Sub(cancelled_inputs, [out.value()])
        return [net], out

    src_ds = Dataset(src_blobs)
    dst_ds = Dataset(dst_blobs)

    with TaskGroup() as tg:
        stage1 = pipe(src_ds.reader(), processor=proc1)
        stage2 = pipe(stage1, processor=proc2)
        pipe(stage2, dst_ds.writer())

    ws = workspace.C.Workspace()
    FeedRecord(src_blobs, src_values, ws)
    session = LocalSession(ws)
    session.run(init_net)
    session.run(tg)
    output = FetchRecord(dst_blobs, ws=ws)

    pairs = zip(output.field_blobs(), expected_dst.field_blobs())
    for actual, expected in pairs:
        np.testing.assert_array_equal(actual, expected)
def test_multi_instance(self):
    """Check counter isolation for a Task run with several instances.

    Each of the 10 instances counts up 15 times on a global counter and
    on two counters that should be instance-local. The sum of squares of
    the local counts is accumulated into global counters, so any sharing
    of a "local" counter between instances would change the totals.
    """
    NUM_INSTANCES = 10
    NUM_ITERS = 15
    with TaskGroup() as tg:
        with Task(num_instances=NUM_INSTANCES):
            with ops.task_init():
                counter1 = ops.CreateCounter([], ['global_counter'])
                counter2 = ops.CreateCounter([], ['global_counter2'])
                counter3 = ops.CreateCounter([], ['global_counter3'])
            # both task_counter and local_counter should be thread local
            with ops.task_instance_init():
                task_counter = ops.CreateCounter([], ['task_counter'])
            local_counter = ops.CreateCounter([], ['local_counter'])
            with ops.loop(NUM_ITERS):
                ops.CountUp(counter1)
                ops.CountUp(task_counter)
                ops.CountUp(local_counter)
            # gather sum of squares of local counters to make sure that
            # each local counter counted exactly up to NUM_ITERS, and
            # that there was no false sharing of counter instances.
            with ops.task_instance_exit():
                count2 = ops.RetrieveCount(task_counter)
                with ops.loop(ops.Mul([count2, count2])):
                    ops.CountUp(counter2)
            # This should have the same effect as the above
            count3 = ops.RetrieveCount(local_counter)
            with ops.loop(ops.Mul([count3, count3])):
                ops.CountUp(counter3)
            # The code below will only run once
            with ops.task_exit():
                total1 = final_output(ops.RetrieveCount(counter1))
                total2 = final_output(ops.RetrieveCount(counter2))
                total3 = final_output(ops.RetrieveCount(counter3))

    with LocalSession() as session:
        session.run(tg)
        self.assertEqual(total1.fetch(), NUM_INSTANCES * NUM_ITERS)
        self.assertEqual(total2.fetch(), NUM_INSTANCES * (NUM_ITERS**2))
        self.assertEqual(total3.fetch(), NUM_INSTANCES * (NUM_ITERS**2))
def load_blobs_locally(self, nodes, blob_names, epoch, session):
    """Load the necessary blobs from the checkpoints to the current node.

    If node managers already exist they must match ``nodes``; otherwise
    one manager per node is created under that node's context, with a db
    path of ``self._db_prefix`` joined with the node. Each manager then
    gets its own global TaskGroup, which is run immediately.

    Args:
        nodes: The nodes to create (or verify) node managers for.
        blob_names: A list of strings. Each string is the name of a blob.
        epoch: An integer. The checkpoint epoch to load from.
        session: A Session object to execute the Load ops.
    """
    if self._node_managers is None:
        self._node_managers = []
        for node in nodes:
            with Node(node):
                db_path = os.path.join(self._db_prefix, node)
                manager = self._node_manager_class(
                    db=db_path, db_type=self._db_type)
                self._node_managers.append((node, manager))
    else:
        assert [known for known, _ in self._node_managers] == nodes

    assert self._node_managers is not None, 'must initialize node managers'
    for entry in self._node_managers:
        manager = entry[1]
        with TaskGroup(WorkspaceType.GLOBAL) as load_group:
            manager.load_blobs_from_checkpoint(blob_names, epoch)
        session.run(load_group)
def compile(cls, runnable, workspace_type=None, setup_net_list=None):
    """Turn ``runnable`` into a CompiledRunnable for this session class.

    A CompiledRunnable is returned as is after checking it belongs to
    ``cls``; cached results in ``cls._compiled_cache`` are reused.

    For a TaskGroup, a requested ``workspace_type`` must agree with the
    group's existing type, or is assigned when the group has none. Any
    other runnable is wrapped in a new TaskGroup whose workspace type
    defaults to WorkspaceType.GLOBAL. ``setup_net_list`` is forwarded to
    ``cls._compile_task_group``.
    """
    if isinstance(runnable, CompiledRunnable):
        assert cls == runnable.session_class, (
            'Runnable was compiled for different session type. ' +
            'Need: %s, got: %s' % (
                cls.__name__, runnable.session_class.__name__))
        return runnable

    cache = cls._compiled_cache
    if runnable in cache:
        return cache[runnable]

    if not isinstance(runnable, TaskGroup):
        if workspace_type is None:
            workspace_type = WorkspaceType.GLOBAL
        tg = TaskGroup(workspace_type=workspace_type)
        if isinstance(runnable, Task):
            task = runnable
        elif isinstance(runnable, core.ExecutionStep):
            task = Task(step=runnable)
        elif isinstance(runnable, core.Plan):
            # Steps of a Plan are meant to run sequentially, whereas the
            # tasks of a TaskGroup run in parallel. Several steps are
            # therefore handed to one Task, which wraps them in a single
            # root ExecutionStep.
            assert len(runnable.Steps()) > 0
            if len(runnable.Steps()) == 1:
                task = Task(step=runnable.Steps()[0])
            else:
                task = Task(step=runnable.Steps())
        else:
            task = Task(step=core.execution_step('runnable', runnable))
        tg.add(task)
    else:
        if workspace_type:
            if not runnable.workspace_type():
                runnable._workspace_type = workspace_type
            else:
                assert runnable.workspace_type() == workspace_type, \
                    "Require {} but already have {}".format(
                        workspace_type, runnable.workspace_type())
        tg = runnable

    result = CompiledRunnable(
        cls._compile_task_group(tg, setup_net_list), session_class=cls)
    cls._compiled_cache[runnable] = result
    return result
class Job(object):
    """
    Bundle of three TaskGroups executed by a JobRunner: `init_group`,
    `epoch_group` and `exit_group`.

    `init_group` runs exactly once at startup. It is the place to set up
    globally persistent blobs such as model weights, accumulators and
    data file lists.

    `epoch_group` runs repeatedly after `init_group`. The loop ends when,
    at the end of an epoch, any stop signal registered through
    `add_stop_signal` is True.

    `exit_group` runs exactly once at the very end of the job, after one
    of the stopping criteria of `epoch_group` has been met. Its role is
    to save the results of training.

    Jobs are context-driven: Tasks can be added to the active Job without
    passing the job object around explicitly.

    Example of usage:

    def build_reader(partitions):
        with Job.current().init_group:
            reader = HiveReader(init_reader, ..., partitions)
            Task(step=init_reader)
        with Job.current().epoch_group:
            limited_reader = ReaderWithLimit(reader, num_iter=10000)
            data_queue = pipe(limited_reader, num_threads=8)
            Job.current().add_stop_signal(limited_reader.data_finished())
        return data_queue

    def build_hogwild_trainer(reader, model):
        with Job.current().init_group:
            Task(step=model.param_init_net)
        with Job.current().epoch_group:
            pipe(reader, processor=model, num_threads=8)
        with Job.current().exit_group:
            Task(step=model.save_model_net)

    with Job() as job:
        reader = build_reader(partitions)
        model = build_model(params)
        build_hogwild_trainer(reader, model)
    """

    def __init__(self):
        # Only the init group is created with the GLOBAL workspace type.
        self.init_group = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
        self.epoch_group = TaskGroup()
        self.exit_group = TaskGroup()
        self.stop_signals = []

    def __enter__(self):
        # Entering the job enters its epoch group.
        self.epoch_group.__enter__()
        return self

    def __exit__(self, *args):
        self.epoch_group.__exit__()

    def add_stop_signal(self, output):
        """Register ``output`` as a stop signal for the epoch loop.

        A BlobReference is first turned into a TaskOutput by creating a
        Task in the epoch group that outputs it.
        """
        if isinstance(output, core.BlobReference):
            signal_task = Task(outputs=[output], group=self.epoch_group)
            output = signal_task.outputs()[0]
        assert isinstance(output, TaskOutput)
        self.stop_signals.append(output)
def __init__(self):
    """Create the three task groups and an empty stop-signal list.

    Only the init group is created with WorkspaceType.GLOBAL; the epoch
    and exit groups use the TaskGroup default.
    """
    init_group = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
    epoch_group = TaskGroup()
    exit_group = TaskGroup()

    self.init_group = init_group
    self.epoch_group = epoch_group
    self.exit_group = exit_group
    self.stop_signals = list()
def _pipe_step(input, output=None, num_threads=1, processor=None, name=None,
               capacity=None, group=None, final_outputs=None):
    """Build the execution step that moves records from input to output.

    ``input`` must be a Reader or expose a ``reader()`` method. When a
    ``processor`` is given the reader is wrapped in a ProcessingReader.
    With ``num_threads == 0`` no step is built and ``(reader, None)`` is
    returned; ``output`` must be None in that case.

    Otherwise one sub-step per thread is created (init / worker / exit),
    and all of them run concurrently between a global init step and a
    global finish step.

    Returns:
        ``(out_queue, step)``. ``out_queue`` stays None if no record was
        produced by the reader.
    """
    group = TaskGroup.current(group)
    if name is None:
        name = 'processor:%d' % group.num_registered_tasks()

    if isinstance(input, Reader):
        reader = input
    else:
        if not hasattr(input, 'reader'):
            raise ValueError('in must be a reader, queue or streaam.')
        reader = input.reader()

    if processor is not None:
        reader = ProcessingReader(reader, processor)

    if num_threads == 0:
        assert output is None
        return reader, None

    global_exit_net = core.Net(name + '_producer_global_exit')
    global_init_net = core.Net(name + '_producer_global_init')
    out_queue, writer = None, None
    reader.setup_ex(global_init_net, global_exit_net)

    thread_steps = []
    for thread_id in range(num_threads):
        tid = "%d" % thread_id
        init_net = core.Net(name + "_init_net_" + tid)
        exit_net = core.Net(name + "_exit_net_" + tid)

        read_nets, status, rec = reader.read_record_ex(init_net, exit_net)

        write_nets = []
        if rec is not None:
            if writer is None:
                # The output is set up once, when the first thread
                # yields a record.
                out_queue, writer = _init_output(
                    output, capacity, global_init_net, global_exit_net)
            write_nets, _ = writer.write_record_ex(
                rec, init_net, exit_net, status)

        init_step = core.execution_step(name + "_init_step", init_net)
        worker_step = core.execution_step(
            name + "_worker_step",
            list(read_nets) + list(write_nets),
            should_stop_blob=status)
        exit_step = core.execution_step(name + "_exit_step", exit_net)
        thread_steps.append(core.execution_step(
            name + "_thread_" + tid, [init_step, worker_step, exit_step]))

    step = core.execution_step("sender_step", [
        core.execution_step('init_step', global_init_net),
        core.execution_step(
            "sender_steps", thread_steps, concurrent_substeps=True),
        core.execution_step('finish_step', global_exit_net),
    ])
    return out_queue, step
def compile(cls, runnable, workspace_type=None, setup_net_list=None):
    """Compile ``runnable`` for this session class, with memoization.

    Already compiled runnables pass through after a session-class check,
    and results are cached in ``cls._compiled_cache``. A TaskGroup keeps
    its workspace type, which must match a requested ``workspace_type``
    or is set from it when missing. A Task, ExecutionStep, Plan or any
    other object is placed in a new TaskGroup (default workspace type:
    WorkspaceType.GLOBAL). ``setup_net_list`` is passed on to
    ``cls._compile_task_group``.
    """
    if isinstance(runnable, CompiledRunnable):
        assert cls == runnable.session_class, (
            'Runnable was compiled for different session type. ' +
            'Need: %s, got: %s' %
            (cls.__name__, runnable.session_class.__name__))
        return runnable

    if runnable in cls._compiled_cache:
        return cls._compiled_cache[runnable]

    tg = runnable
    if isinstance(runnable, TaskGroup):
        if workspace_type:
            if runnable.workspace_type():
                assert runnable.workspace_type() == workspace_type, \
                    "Require {} but already have {}".format(
                        workspace_type, runnable.workspace_type())
            else:
                runnable._workspace_type = workspace_type
    else:
        tg = TaskGroup(
            workspace_type=(WorkspaceType.GLOBAL
                            if workspace_type is None else workspace_type))
        if isinstance(runnable, Task):
            tg.add(runnable)
        elif isinstance(runnable, core.ExecutionStep):
            tg.add(Task(step=runnable))
        elif isinstance(runnable, core.Plan):
            # A Plan's ExecutionSteps are supposed to run one after the
            # other, while tasks in a TaskGroup run in parallel. When the
            # Plan has several steps they all go into one Task, which
            # takes the list and wraps it in a root ExecutionStep.
            assert len(runnable.Steps()) > 0
            if len(runnable.Steps()) != 1:
                tg.add(Task(step=runnable.Steps()))
            else:
                tg.add(Task(step=runnable.Steps()[0]))
        else:
            wrapped_step = core.execution_step('runnable', runnable)
            tg.add(Task(step=wrapped_step))

    compiled_group = cls._compile_task_group(tg, setup_net_list)
    outcome = CompiledRunnable(compiled_group, session_class=cls)
    cls._compiled_cache[runnable] = outcome
    return outcome
def _pipe_step(
        input, output=None, num_threads=1, processor=None, name=None,
        capacity=None, group=None, final_outputs=None):
    """Create the step that pumps records from ``input`` into ``output``.

    ``input`` is either a Reader or an object with a ``reader()`` method;
    anything else raises ValueError. A given ``processor`` wraps the
    reader in a ProcessingReader. If ``num_threads`` is 0, ``output``
    must be None and ``(reader, None)`` is returned without building a
    step.

    For ``num_threads`` > 0 each thread gets an init, a worker and an
    exit sub-step; the per-thread steps run concurrently, framed by a
    global init step and a global finish step.

    Returns:
        A pair ``(out_queue, step)``; ``out_queue`` is None unless a
        record was read and the output was initialized.
    """
    group = TaskGroup.current(group)
    if name is None:
        name = 'processor:%d' % group.num_registered_tasks()

    if isinstance(input, Reader):
        reader = input
    elif hasattr(input, 'reader'):
        reader = input.reader()
    else:
        raise ValueError('in must be a reader, queue or streaam.')

    if processor is not None:
        reader = ProcessingReader(reader, processor)

    if num_threads == 0:
        assert output is None
        return reader, None

    global_exit_net = core.Net(name + '_producer_global_exit')
    global_init_net = core.Net(name + '_producer_global_init')
    out_queue = None
    writer = None
    reader.setup_ex(global_init_net, global_exit_net)

    def _build_thread_step(thread_id, out_queue, writer):
        # Builds the init/worker/exit step of one thread. Returns the
        # possibly newly initialized (out_queue, writer) with the step.
        init_net = core.Net(name + "_init_net_%d" % thread_id)
        exit_net = core.Net(name + "_exit_net_%d" % thread_id)
        read_nets, status, rec = reader.read_record_ex(init_net, exit_net)
        if rec is None:
            write_nets = []
        else:
            if writer is None:
                out_queue, writer = _init_output(
                    output, capacity, global_init_net, global_exit_net)
            write_nets, _ = writer.write_record_ex(
                rec, init_net, exit_net, status)
        substeps = [
            core.execution_step(name + "_init_step", init_net),
            core.execution_step(
                name + "_worker_step",
                list(read_nets) + list(write_nets),
                should_stop_blob=status),
            core.execution_step(name + "_exit_step", exit_net),
        ]
        thread_step = core.execution_step(
            name + "_thread_%d" % thread_id, substeps)
        return out_queue, writer, thread_step

    steps = []
    for thread_id in range(num_threads):
        out_queue, writer, thread_step = _build_thread_step(
            thread_id, out_queue, writer)
        steps.append(thread_step)

    sender_parts = [
        core.execution_step('init_step', global_init_net),
        core.execution_step("sender_steps", steps, concurrent_substeps=True),
        core.execution_step('finish_step', global_exit_net),
    ]
    step = core.execution_step("sender_step", sender_parts)
    return out_queue, step