Example #1
0
    def test_multi_instance(self):
        """Check counter scoping in a Task that runs with several instances.

        One counter is incremented directly by every instance; two more
        accumulate the square of each instance's own iteration count. The
        squared sums only match the expected totals if every instance
        counted exactly NUM_ITERS on a counter of its own.
        """
        NUM_INSTANCES = 10
        NUM_ITERS = 15
        with TaskGroup() as tg:
            with Task(num_instances=NUM_INSTANCES):
                with ops.task_init():
                    counter1 = ops.CreateCounter([], ['global_counter'])
                    counter2 = ops.CreateCounter([], ['global_counter2'])
                    counter3 = ops.CreateCounter([], ['global_counter3'])
                # both task_counter and local_counter should be thread local
                with ops.task_instance_init():
                    task_counter = ops.CreateCounter([], ['task_counter'])
                local_counter = ops.CreateCounter([], ['local_counter'])
                with ops.loop(NUM_ITERS):
                    ops.CountUp(counter1)
                    ops.CountUp(task_counter)
                    ops.CountUp(local_counter)
                # gather sum of squares of local counters to make sure that
                # each local counter counted exactly up to NUM_ITERS, and
                # that there was no false sharing of counter instances.
                with ops.task_instance_exit():
                    count2 = ops.RetrieveCount(task_counter)
                    with ops.loop(ops.Mul([count2, count2])):
                        ops.CountUp(counter2)
                # This should have the same effect as the above
                count3 = ops.RetrieveCount(local_counter)
                with ops.loop(ops.Mul([count3, count3])):
                    ops.CountUp(counter3)
                # The code below will only run once
                with ops.task_exit():
                    total1 = final_output(ops.RetrieveCount(counter1))
                    total2 = final_output(ops.RetrieveCount(counter2))
                    total3 = final_output(ops.RetrieveCount(counter3))

        with LocalSession() as session:
            session.run(tg)
            # assertEqual replaces the deprecated assertEquals alias, which
            # no longer exists in unittest as of Python 3.12.
            self.assertEqual(total1.fetch(), NUM_INSTANCES * NUM_ITERS)
            self.assertEqual(total2.fetch(), NUM_INSTANCES * (NUM_ITERS**2))
            self.assertEqual(total3.fetch(), NUM_INSTANCES * (NUM_ITERS**2))
Example #2
0
 def proc(rec):
     """Count processed records with task-wide and per-thread counters.

     Returns ``rec`` unchanged. The three final counts are published through
     ``final_output`` and stored in slots 0-2 of the enclosing ``totals``.
     """
     # One-time setup: counters created once for the whole task.
     with ops.task_init():
         records_seen = ops.CreateCounter([], ['global_counter'])
         per_thread_sum = ops.CreateCounter([], ['global_counter2'])
         threads_seen = ops.CreateCounter([], ['global_counter3'])

     # Per-thread setup: each thread gets a counter of its own.
     with ops.task_instance_init():
         thread_counter = ops.CreateCounter([], ['task_counter'])

     # Per-iteration work: bump both the task-wide and the thread counter.
     ops.CountUp(records_seen)
     ops.CountUp(thread_counter)

     # Per-thread teardown: add this thread's count to the second counter
     # one increment at a time, and record one tick on the third.
     with ops.task_instance_exit():
         thread_total = ops.RetrieveCount(thread_counter)
         with ops.loop(thread_total):
             ops.CountUp(per_thread_sum)
         ops.CountUp(threads_seen)

     # One-time teardown: publish the final counts in a fixed order.
     with ops.task_exit():
         ordered = (records_seen, per_thread_sum, threads_seen)
         for slot, counter in enumerate(ordered):
             totals[slot] = final_output(ops.RetrieveCount(counter))

     return rec