def example_task():
    """Build two example Tasks and return the first Task's final outputs.

    The first Task chains Add ops across its task_init, main body and
    task_exit sections. The second Task runs with two instances and uses a
    per-instance init section.

    Returns:
        A tuple ``(o6, o7_1, o7_2)`` of final-output handles from the first
        Task. Per the inline comments they are expected to hold 6, 7 and 7.
    """
    with Task():
        with ops.task_init():
            const_one = ops.Const(1)
        const_two = ops.Add([const_one, const_one])
        with ops.task_init():
            const_three = ops.Const(3)
        acc = ops.Add([const_two, const_three])
        # acc should be 5 at this point.
        with ops.task_exit():
            # acc should be 6 here: the exit section runs after the
            # in-place increment further below.
            exit_seven = ops.Add([acc, const_one])
        body_six = ops.Add([acc, const_one])
        ops.Add([acc, const_one], [acc])
        body_seven = ops.Add([acc, const_one])
        # final_output is applied in the same order as before: six, seven_1, seven_2.
        results = tuple(
            final_output(blob) for blob in (body_six, exit_seven, body_seven)
        )

    with Task(num_instances=2):
        with ops.task_init():
            shared_one = ops.Const(1)
        with ops.task_instance_init():
            per_instance_two = ops.Const(2)
        ops.Add([shared_one, per_instance_two], [shared_one])
        ops.LogInfo('ble')

    return results
def test_net_multi_use(self):
    """Adding the same Net to a Task twice executes its ops twice.

    ``my_net`` adds 1 to ``total`` in place. The net is attached to the task
    twice via ``ops.net``, so the fetched result should be 2.
    """
    with Task() as task:
        total = ops.Const(0)
        net = Net('my_net')
        net.Add([total, net.Const(1)], [total])
        ops.net(net)
        ops.net(net)
        result = final_output(total)
    with LocalSession() as session:
        session.run(task)
        # assertEqual rather than the assertEquals alias, which is deprecated
        # and was removed from unittest in Python 3.12.
        self.assertEqual(2, result.fetch())
def test_multi_instance(self):
    """Check global vs. per-instance counters in a multi-instance Task.

    The Task runs ``NUM_INSTANCES`` instances, each looping ``NUM_ITERS``
    times. ``counter1`` is global, so it should reach
    ``NUM_INSTANCES * NUM_ITERS``. ``task_counter`` (created in
    task_instance_init) and ``local_counter`` (created in the task body)
    should both be instance-local. Each instance therefore adds the square
    of its own count to ``counter2`` / ``counter3``, giving
    ``NUM_INSTANCES * NUM_ITERS ** 2`` when there is no sharing.
    """
    # NOTE(review): another method named test_multi_instance appears right
    # after this one in this chunk. If both are in the same class, the later
    # definition silently replaces this one. Confirm and rename or delete one.
    NUM_INSTANCES = 10
    NUM_ITERS = 15
    with TaskGroup() as tg:
        with Task(num_instances=NUM_INSTANCES):
            with ops.task_init():
                counter1 = ops.CreateCounter([], ['global_counter'])
                counter2 = ops.CreateCounter([], ['global_counter2'])
                counter3 = ops.CreateCounter([], ['global_counter3'])
            # both task_counter and local_counter should be thread local
            with ops.task_instance_init():
                task_counter = ops.CreateCounter([], ['task_counter'])
            local_counter = ops.CreateCounter([], ['local_counter'])
            with ops.loop(NUM_ITERS):
                ops.CountUp(counter1)
                ops.CountUp(task_counter)
                ops.CountUp(local_counter)
            # gather sum of squares of local counters to make sure that
            # each local counter counted exactly up to NUM_ITERS, and
            # that there was no false sharing of counter instances.
            with ops.task_instance_exit():
                count2 = ops.RetrieveCount(task_counter)
                with ops.loop(ops.Mul([count2, count2])):
                    ops.CountUp(counter2)
            # This should have the same effect as the above
            count3 = ops.RetrieveCount(local_counter)
            with ops.loop(ops.Mul([count3, count3])):
                ops.CountUp(counter3)
            # The code below will only run once
            with ops.task_exit():
                total1 = final_output(ops.RetrieveCount(counter1))
                total2 = final_output(ops.RetrieveCount(counter2))
                total3 = final_output(ops.RetrieveCount(counter3))

    with LocalSession() as session:
        session.run(tg)
        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual(total1.fetch(), NUM_INSTANCES * NUM_ITERS)
        self.assertEqual(total2.fetch(), NUM_INSTANCES * (NUM_ITERS ** 2))
        self.assertEqual(total3.fetch(), NUM_INSTANCES * (NUM_ITERS ** 2))
def test_multi_instance(self):
    """Check global vs. per-instance counters in a multi-instance Task.

    The Task runs ``NUM_INSTANCES`` instances, each looping ``NUM_ITERS``
    times. ``counter1`` is global, so it should reach
    ``NUM_INSTANCES * NUM_ITERS``. ``task_counter`` (created in
    task_instance_init) and ``local_counter`` (created in the task body)
    should both be instance-local. Each instance therefore adds the square
    of its own count to ``counter2`` / ``counter3``, giving
    ``NUM_INSTANCES * NUM_ITERS ** 2`` when there is no sharing.
    """
    # NOTE(review): an earlier method with this same name appears just before
    # this one in this chunk. If both are in the same class, this definition
    # replaces it and the earlier one never runs. Confirm and deduplicate.
    NUM_INSTANCES = 10
    NUM_ITERS = 15
    with TaskGroup() as tg:
        with Task(num_instances=NUM_INSTANCES):
            with ops.task_init():
                counter1 = ops.CreateCounter([], ['global_counter'])
                counter2 = ops.CreateCounter([], ['global_counter2'])
                counter3 = ops.CreateCounter([], ['global_counter3'])
            # both task_counter and local_counter should be thread local
            with ops.task_instance_init():
                task_counter = ops.CreateCounter([], ['task_counter'])
            local_counter = ops.CreateCounter([], ['local_counter'])
            with ops.loop(NUM_ITERS):
                ops.CountUp(counter1)
                ops.CountUp(task_counter)
                ops.CountUp(local_counter)
            # gather sum of squares of local counters to make sure that
            # each local counter counted exactly up to NUM_ITERS, and
            # that there was no false sharing of counter instances.
            with ops.task_instance_exit():
                count2 = ops.RetrieveCount(task_counter)
                with ops.loop(ops.Mul([count2, count2])):
                    ops.CountUp(counter2)
            # This should have the same effect as the above
            count3 = ops.RetrieveCount(local_counter)
            with ops.loop(ops.Mul([count3, count3])):
                ops.CountUp(counter3)
            # The code below will only run once
            with ops.task_exit():
                total1 = final_output(ops.RetrieveCount(counter1))
                total2 = final_output(ops.RetrieveCount(counter2))
                total3 = final_output(ops.RetrieveCount(counter3))

    with LocalSession() as session:
        session.run(tg)
        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual(total1.fetch(), NUM_INSTANCES * NUM_ITERS)
        self.assertEqual(total2.fetch(), NUM_INSTANCES * (NUM_ITERS ** 2))
        self.assertEqual(total3.fetch(), NUM_INSTANCES * (NUM_ITERS ** 2))
def test_setup(self):
    """Check the execution order of task_init, body and task_exit sections.

    ``accum`` starts at 2 + 3 = 5. The body computes ``six = accum + 1`` and
    then increments ``accum`` in place to 6. Both ``seven_2`` (body) and
    ``seven_1`` (task_exit, which runs after the body) should therefore be 7.
    """
    with Task() as task:
        with ops.task_init():
            one = ops.Const(1)
        two = ops.Add([one, one])
        with ops.task_init():
            three = ops.Const(3)
        accum = ops.Add([two, three])
        # here, accum should be 5
        with ops.task_exit():
            # here, accum should be 6, since this executes after lines below
            seven_1 = ops.Add([accum, one])
        six = ops.Add([accum, one])
        ops.Add([accum, one], [accum])
        seven_2 = ops.Add([accum, one])
        o6 = final_output(six)
        o7_1 = final_output(seven_1)
        o7_2 = final_output(seven_2)
    with LocalSession() as session:
        session.run(task)
        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual(o6.fetch(), 6)
        self.assertEqual(o7_1.fetch(), 7)
        self.assertEqual(o7_2.fetch(), 7)
def proc(rec):
    """Wire global, per-instance and per-iteration counters into the task.

    Counting ops are added to the current task's init, instance-init, main,
    instance-exit and exit sections. The three global counters' final
    outputs are stored into ``totals``, a mutable sequence from an enclosing
    scope that is not visible here.

    Args:
        rec: The incoming record; returned as-is.

    Returns:
        ``rec``.
    """
    # executed once: three global counters, created in this exact order
    with ops.task_init():
        counter1, counter2, counter3 = [
            ops.CreateCounter([], [blob_name])
            for blob_name in (
                'global_counter', 'global_counter2', 'global_counter3')
        ]
    # executed once per thread
    with ops.task_instance_init():
        instance_counter = ops.CreateCounter([], ['task_counter'])
    # executed on each iteration: bump the global, then the per-thread counter
    for ctr in (counter1, instance_counter):
        ops.CountUp(ctr)
    # executed once per thread
    with ops.task_instance_exit():
        with ops.loop(ops.RetrieveCount(instance_counter)):
            ops.CountUp(counter2)
        ops.CountUp(counter3)
    # executed once
    with ops.task_exit():
        for slot, ctr in enumerate((counter1, counter2, counter3)):
            totals[slot] = final_output(ops.RetrieveCount(ctr))
    return rec
def _actual_loop(self):
    """Build nested loops that bucket and sum values.

    For each outer iteration ``i`` (``loop.iter()``, 10 iterations) the
    inner loop runs ``i`` times with iteration ``j`` (``inner.iter()``), and
    ``val = i * 10 + j``. Every ``val`` is added to ``total``. It is also
    added to exactly one bucket:

    * ``total_large`` when ``val >= 80``;
    * otherwise ``total_small`` when ``val >= 50``;
    * otherwise ``total_tiny``.

    Returns:
        A list of final-output handles for
        ``[total, total_large, total_small, total_tiny]``, in that order.
    """
    total = ops.Const(0)
    total_large = ops.Const(0)
    total_small = ops.Const(0)
    total_tiny = ops.Const(0)
    with ops.loop(10) as loop:
        outer = ops.Mul([loop.iter(), ops.Const(10)])
        with ops.loop(loop.iter()) as inner:
            val = ops.Add([outer, inner.iter()])
            with ops.If(ops.GE([val, ops.Const(80)])) as c:
                ops.Add([total_large, val], [total_large])
            # `c` is rebound so that Else chains off the Elif, not the If.
            with c.Elif(ops.GE([val, ops.Const(50)])) as c:
                ops.Add([total_small, val], [total_small])
            with c.Else():
                ops.Add([total_tiny, val], [total_tiny])
            # Outputs passed as a list, consistent with the bucket updates above.
            ops.Add([total, val], [total])
    return [
        final_output(x)
        for x in [total, total_large, total_small, total_tiny]
    ]
def _timed_task(self, cp_op_name, add_op):
    """Wrap a checkpoint operation in a Task that measures its duration.

    A timer is started in the task's init section, ``add_op`` adds the
    checkpoint operation to the task body, and the timer is stopped in the
    exit section. The measured span is registered as a final output and
    stored on ``self._current_checkpoint_duration``, so it can be read once
    the task has run.

    Args:
        cp_op_name: Name of the checkpoint operation; also used as the
            Task's name.
        add_op: Zero-argument functor that adds the checkpoint operation.

    Returns:
        The Task wrapping the timed operation.
    """
    with Task(name=cp_op_name) as timed:
        with ops.task_init():
            timer_handle = ops.TimerBegin([], counter_name=self._node_name)
        add_op()
        with ops.task_exit():
            elapsed = ops.TimerGetAndEnd(timer_handle)
        self._current_checkpoint_duration = final_output(elapsed)
    return timed