def broadcast_state(self, rank, src_rank):
    """Broadcast the state serialized on ``src_rank`` and load it on this rank.

    On the source rank, the object is serialized with ``self.save`` into an
    in-memory buffer. The value returned by ``dist.broadcast_binary`` is then
    wrapped in a fresh ``io.BytesIO`` and passed to ``self.load``. This happens
    on every rank, the source rank included.

    Args:
        rank: Rank of the calling process.
        src_rank: Rank whose serialized state is broadcast.
    """

    def _serialize():
        buf = io.BytesIO()
        self.save(buf)
        # getbuffer() returns a memoryview, which cannot be converted to a
        # tensor directly, so wrap it in a NumPy array first.
        return numpy.asarray(buf.getbuffer())

    # Only the source rank contributes a payload; the other ranks pass None.
    payload = _serialize() if rank == src_rank else None
    received = dist.broadcast_binary(payload, src_rank)
    self.load(io.BytesIO(received))
Example #2
0
 def broadcast_run(rank, world_size, input):
     """Broadcast from rank 1 via ``edist.broadcast_binary`` and checksum it.

     This rank's entry ``input[rank]`` is wrapped with ``np.asarray`` when it
     is truthy. Otherwise ``None`` is sent. The value returned by the
     broadcast is passed to ``compute_checksum``, and that result is returned.

     ``world_size`` is accepted but not used here.

     NOTE(review): the truthiness test would raise for a multi-element NumPy
     array and treats empty inputs like ``None``. Presumably callers pass
     ``None`` (or bytes/lists) for non-source ranks; confirm.
     """
     local = input[rank]
     payload = np.asarray(local) if local else None
     received = edist.broadcast_binary(payload, 1)
     return compute_checksum(received)