def test_multi_instance(self):
    """Run a Task with several instances and check its counter totals.

    Three counters are created under ``ops.task_init()`` and two more
    (``task_counter`` and ``local_counter``) are expected to be private
    to each instance.  Every instance counts ``NUM_ITERS`` times, then
    adds the square of its private count to a global counter.  If each
    instance counted exactly ``NUM_ITERS`` on its own counter, the sum of
    squares is ``NUM_INSTANCES * NUM_ITERS ** 2``; any sharing of a
    private counter between instances would change that sum.
    """
    NUM_INSTANCES = 10
    NUM_ITERS = 15
    with TaskGroup() as tg:
        with Task(num_instances=NUM_INSTANCES):
            with ops.task_init():
                counter1 = ops.CreateCounter([], ['global_counter'])
                counter2 = ops.CreateCounter([], ['global_counter2'])
                counter3 = ops.CreateCounter([], ['global_counter3'])
            # both task_counter and local_counter should be thread local
            with ops.task_instance_init():
                task_counter = ops.CreateCounter([], ['task_counter'])
            # NOTE(review): local_counter is placed in the plain task body
            # (mirroring how count3 is read outside task_instance_exit
            # below) -- confirm against the intended nesting.
            local_counter = ops.CreateCounter([], ['local_counter'])
            with ops.loop(NUM_ITERS):
                ops.CountUp(counter1)
                ops.CountUp(task_counter)
                ops.CountUp(local_counter)
            # gather sum of squares of local counters to make sure that
            # each local counter counted exactly up to NUM_ITERS, and
            # that there was no false sharing of counter instances.
            with ops.task_instance_exit():
                count2 = ops.RetrieveCount(task_counter)
                with ops.loop(ops.Mul([count2, count2])):
                    ops.CountUp(counter2)
            # This should have the same effect as the above
            count3 = ops.RetrieveCount(local_counter)
            with ops.loop(ops.Mul([count3, count3])):
                ops.CountUp(counter3)
            # The code below will only run once
            with ops.task_exit():
                total1 = final_output(ops.RetrieveCount(counter1))
                total2 = final_output(ops.RetrieveCount(counter2))
                total3 = final_output(ops.RetrieveCount(counter3))

    with LocalSession() as session:
        session.run(tg)
        # assertEqual instead of the deprecated assertEquals alias, which
        # was removed from unittest in Python 3.12.
        self.assertEqual(total1.fetch(), NUM_INSTANCES * NUM_ITERS)
        self.assertEqual(total2.fetch(), NUM_INSTANCES * (NUM_ITERS**2))
        self.assertEqual(total3.fetch(), NUM_INSTANCES * (NUM_ITERS**2))
def proc(rec):
    """Emit counter ops for each task phase and return ``rec`` unchanged.

    Stores the three ``final_output`` handles in slots 0, 1 and 2 of the
    ``totals`` list taken from the enclosing scope.
    """
    # executed once: the three counters that collect the results
    with ops.task_init():
        iteration_total, per_thread_sum, thread_tally = [
            ops.CreateCounter([], [blob_name])
            for blob_name in (
                'global_counter',
                'global_counter2',
                'global_counter3',
            )
        ]

    # executed once per thread: the thread's own counter
    with ops.task_instance_init():
        thread_counter = ops.CreateCounter([], ['task_counter'])

    # executed on each iteration
    ops.CountUp(iteration_total)
    ops.CountUp(thread_counter)

    # executed once per thread: add this thread's count to the second
    # counter one step at a time, and bump the third counter once
    with ops.task_instance_exit():
        thread_count = ops.RetrieveCount(thread_counter)
        with ops.loop(thread_count):
            ops.CountUp(per_thread_sum)
        ops.CountUp(thread_tally)

    # executed once: publish the final value of every result counter
    with ops.task_exit():
        result_counters = (iteration_total, per_thread_sum, thread_tally)
        for slot, counter in enumerate(result_counters):
            totals[slot] = final_output(ops.RetrieveCount(counter))

    return rec