Esempio n. 1
0
    def _receiver(self):
        """Continuously receive results from workers until shutdown.

        Must run on the thread named "receiver", not the MPI main thread;
        both are asserted on entry.

        While ``self.nPending`` is non-zero, each pass of the inner loop does
        the following:
          * waits in ``self.comm.recv`` for a ``(tid, msg, val)`` tuple tagged
            ``self.result_tag`` from any source;
          * pops the matching future from ``self.pending_futures`` and either
            sets its exception (``msg == 'exception'``) or its result;
          * under ``self.lock``, returns the sending rank to ``self.dests``
            and decrements ``self.nPending``.

        ``self.shutItDown`` is only re-checked once no results are pending.
        """
        log.debug('Manager._receiver()')
        assert MPI.Is_thread_main() is False
        # current_thread().name replaces the deprecated
        # currentThread().getName() aliases (DeprecationWarning since 3.10).
        assert threading.current_thread().name == "receiver"

        while not self.shutItDown:

            # are we waiting on any results?
            while self.nPending:

                stat = MPI.Status()
                (tid, msg, val) = self.comm.recv(source=MPI.ANY_SOURCE,
                                                 tag=self.result_tag,
                                                 status=stat)
                log.debug('Manager._receiver received task: %s' % tid)

                # update future
                ft = self.pending_futures.pop(tid)
                if msg == 'exception':
                    ft._set_exception(*val)
                else:
                    ft._set_result(val)

                # The rank that just answered can take new work: hand it
                # back to the destination queue and count the task as done.
                with self.lock:
                    self.dests.append(stat.Get_source())
                    self.nPending -= 1

            # force context switching ( 1ms )
            time.sleep(0.001)
Esempio n. 2
0
    def _dispatcher(self):
        """Continuously dispatch queued tasks to idle destinations until shutdown.

        Must run on the thread named "dispatcher", not the MPI main thread;
        both are asserted on entry.

        Each pass pairs the oldest entry of ``self.tasks`` with the oldest
        rank in ``self.dests``, as long as both are non-empty. The pops and
        the ``self.nPending`` increment happen under ``self.lock``. A
        non-blocking ``isend`` tagged ``self.task_tag`` is then posted.
        All sends posted in the pass are waited on before sleeping 1 ms.
        """
        log.debug('Manager._dispatcher()')
        assert MPI.Is_thread_main() is False
        # current_thread().name replaces the deprecated
        # currentThread().getName() aliases (DeprecationWarning since 3.10).
        assert threading.current_thread().name == "dispatcher"

        while not self.shutItDown:

            req = []
            # do we have work and somewhere to send it?
            while self.tasks and self.dests:

                with self.lock:
                    task = self.tasks.popleft()
                    sendTo = self.dests.popleft()
                    self.nPending += 1

                req.append(
                    self.comm.isend(task, dest=sendTo, tag=self.task_tag))

            # make sure all sends completed (nothing to wait on if no task
            # was dispatched in this pass)
            if req:
                MPI.Request.Waitall(requests=req)

            # force context switching ( 1ms )
            time.sleep(0.001)
Esempio n. 3
0
 def testIsThreadMain(self):
     """Assert that MPI.Is_thread_main() agrees with the running thread.

     The expected answer is True when threading support is unavailable
     (HAVE_THREADING is false) or when running on the thread named
     'MainThread', and False otherwise.

     The test is skipped if the MPI library raises NotImplementedError for
     Is_thread_main(). When VERBOSE is set, the thread name and the
     result are written to stderr.
     """
     try:
         is_main = MPI.Is_thread_main()
     except NotImplementedError:
         self.skipTest('mpi-is_thread_main')
     thread_name = threading.current_thread().name
     expected = not HAVE_THREADING or thread_name == 'MainThread'
     self.assertEqual(is_main, expected)
     if not VERBOSE:
         return
     message = "%s: MPI.Is_thread_main() -> %s" % (thread_name, is_main)
     sys.stderr.write(message + '\n')
Esempio n. 4
0
 def testIsThreadMain(self, main=True):
     """Assert that MPI.Is_thread_main() returns *main*.

     Returns silently, without failing or skipping, if the MPI library
     raises NotImplementedError for Is_thread_main(). When _VERBOSE is set,
     the current thread's name and the result are written to stderr.
     """
     try:
         flag = MPI.Is_thread_main()
     except NotImplementedError:
         return
     self.assertEqual(flag, main)
     if _VERBOSE:
         # Thread.name replaces Thread.getName(), deprecated since 3.10.
         name = current_thread().name
         message = "%s: MPI.Is_thread_main() -> %s" % (name, flag)
         sys.stderr.write(message + '\n')
Esempio n. 5
0
 def _test_is(self, main=False):
     """Assert that MPI.Is_thread_main() returns *main* (False by default).

     Returns silently if the MPI library raises NotImplementedError for
     Is_thread_main(). When _VERBOSE is set, the current thread's name and
     the result are written to stderr.
     """
     try:
         flag = MPI.Is_thread_main()
     except NotImplementedError:
         return
     self.assertEqual(flag, main)
     if _VERBOSE:
         from sys import stderr
         # Thread.name replaces Thread.getName(), deprecated since 3.10.
         name = current_thread().name
         stderr.write("%s: MPI.Is_thread_main() -> %s" % (name, flag) + '\n')
Esempio n. 6
0
    def __new__(cls):
        """Choose the concrete work-manager class for this MPI process.

        If ``MPI.COMM_WORLD`` has a single process, a ``Serial`` instance is
        created. Otherwise rank 0 becomes the ``Manager`` and every other
        rank becomes a ``Worker``. Asserts that MPI is already initialized
        and that the caller is the MPI main thread.
        """
        log.debug('MPIWorkManager.__new__()')
        assert MPI.Is_initialized()
        assert MPI.Is_thread_main()

        world = MPI.COMM_WORLD
        rank = world.Get_rank()
        size = world.Get_size()

        # A lone process has nobody to delegate to.
        if size == 1:
            return super().__new__(Serial)
        # Every rank other than 0 does the work...
        if rank != 0:
            return super().__new__(Worker)
        else:
            # ...and rank 0 coordinates it.
            return super().__new__(Manager)
            if (rank == 0):
                print("[HUGECTR][INFO] iter: {}; loss: {}".format(i, loss))
        if (i % 1000 == 0 and i != 0):
            sess.check_overflow()
            sess.copy_weights_for_evaluation()
            data_reader_eval = sess.get_data_reader_eval()
            for _ in range(solver_config.max_eval_batches):
                sess.eval()
            metrics = sess.get_eval_metrics()
            print("[HUGECTR][INFO] rank: {}, iter: {}, {}".format(
                rank, i, metrics))
    return


if __name__ == "__main__":
    # The first command-line argument (a JSON file path) is forwarded to
    # session_impl_test, which runs on its own rank-labelled thread.
    json_file = sys.argv[1]
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    thread = threading.Thread(target=session_impl_test,
                              args=(json_file, ),
                              name='[rank-%d train]' % rank)
    # threading.current_thread() replaces the deprecated currentThread().
    current_thread = threading.current_thread()
    print('[HUGECTR][INFO] %s is main thread: %s' %
          (current_thread.name, MPI.Is_thread_main()))
    print('[HUGECTR][INFO] before: rank %d ' % (rank))
    # start the thread
    thread.start()
    # wait for terminate
    thread.join()
    print('[HUGECTR][INFO] after: rank %d ' % (rank))
Esempio n. 8
0
    sess.set_learning_rate(lr)
    sess.train()
    if (i%100 == 0):
      loss = sess.get_current_loss()
      if (rank == 0):
        print("[HUGECTR][INFO] iter: {}; loss: {}".format(i, loss))
    if (i%1000 == 0 and i != 0):
      sess.check_overflow()
      sess.copy_weights_for_evaluation()
      data_reader_eval = sess.get_data_reader_eval()
      for _ in range(solver_config.max_eval_batches):
        sess.eval()
      metrics = sess.get_eval_metrics()
      print("[HUGECTR][INFO] rank: {}, iter: {}, {}".format(rank, i, metrics))
  return

if __name__ == "__main__":
  # The first command-line argument (a JSON file path) is forwarded to
  # session_impl_test, which runs on its own rank-labelled thread.
  json_file = sys.argv[1]
  comm = MPI.COMM_WORLD
  rank = comm.Get_rank()
  thread = threading.Thread(target=session_impl_test, args = (json_file,), name='[rank-%d train]' % rank)
  # threading.current_thread() replaces the deprecated currentThread().
  current_thread = threading.current_thread()
  print('[HUGECTR][INFO] %s is main thread: %s' % (current_thread.name, MPI.Is_thread_main()))
  print('[HUGECTR][INFO] before: rank %d '% (rank))
  # start the thread
  thread.start()
  # wait for terminate
  thread.join()
  print('[HUGECTR][INFO] after: rank %d ' % (rank))

Esempio n. 9
0
def recv():
    """Receive one message from rank ``other`` into ``recv_buf`` and report it.

    Uses the module-level ``comm``, ``recv_buf`` and ``other`` defined
    elsewhere in the script. Single-argument ``print(...)`` calls produce the
    same output on Python 2 and 3, unlike the original print statements,
    which are a SyntaxError on Python 3.
    """
    current_thread = threading.current_thread()
    print('%s is main thread: %s' % (current_thread.name, MPI.Is_thread_main()))
    comm.Recv(recv_buf, source=other, tag=11)
    print('%s receives %d from rank %d' % (current_thread.name, recv_buf,
                                           other))
Esempio n. 10
0
def send():
    """Send this process's rank, as a NumPy array, to rank ``other`` with tag 11.

    Uses the module-level ``comm``, ``rank`` and ``other`` defined elsewhere
    in the script. Single-argument ``print(...)`` calls produce the same
    output on Python 2 and 3, unlike the original print statements, which
    are a SyntaxError on Python 3.
    """
    current_thread = threading.current_thread()
    print('%s is main thread: %s' % (current_thread.name, MPI.Is_thread_main()))
    print('%s sends %d to rank %d...' % (current_thread.name, rank, other))
    comm.Send(np.array(rank), dest=other, tag=11)
Esempio n. 11
0
    current_thread = threading.currentThread()
    print '%s is main thread: %s' % (current_thread.name, MPI.Is_thread_main())
    print '%s sends %d to rank %d...' % (current_thread.name, rank, other)
    comm.Send(np.array(rank), dest=other, tag=11)


def recv():
    """Receive one message from rank ``other`` into ``recv_buf`` and report it.

    Uses the module-level ``comm``, ``recv_buf`` and ``other`` defined
    elsewhere in the script. Single-argument ``print(...)`` calls produce the
    same output on Python 2 and 3, unlike the original print statements,
    which are a SyntaxError on Python 3.
    """
    current_thread = threading.current_thread()
    print('%s is main thread: %s' % (current_thread.name, MPI.Is_thread_main()))
    comm.Recv(recv_buf, source=other, tag=11)
    print('%s receives %d from rank %d' % (current_thread.name, recv_buf,
                                           other))


# create thread by using a function
# Run send() and recv() concurrently on two named threads of this rank.
# NOTE(review): MPI calls are issued from two non-main threads at the same
# time; this presumably needs MPI initialised with MPI_THREAD_MULTIPLE --
# confirm the requested thread level.
send_thread = threading.Thread(target=send,
                               name='[rank-%d send_thread]' % rank)
recv_thread = threading.Thread(target=recv,
                               name='[rank-%d recv_thread]' % rank)

# threading.current_thread() replaces the deprecated currentThread(), and
# single-argument print(...) behaves the same on Python 2 and 3.
current_thread = threading.current_thread()
print('%s is main thread: %s' % (current_thread.name, MPI.Is_thread_main()))
print('before: rank %d has %d' % (rank, recv_buf))
# start the threads
send_thread.start()
recv_thread.start()

# wait for terminate
send_thread.join()
recv_thread.join()
print('after: rank %d has %d' % (rank, recv_buf))