예제 #1
0
 def broadcast(self):
     """Bring this worker's training state in line with rank 0.

     Broadcasts the model's ``state_dict()`` and the optimizer state with
     rank 0 as the root, then replaces ``global_completed_batch_num`` with
     the value returned by ``broadcast_object`` (presumably rank 0's
     counter -- confirm against the helper's contract).
     """
     source_rank = 0
     model_params = self._model.state_dict()
     broadcast_parameters(model_params, root_rank=source_rank)
     broadcast_optimizer_state(self._optimizer, root_rank=source_rank)
     # Keep the value handed back by broadcast_object as the new counter.
     synced_batch_num = broadcast_object(
         self.global_completed_batch_num, name="GlobalCompletedBatchNum")
     self.global_completed_batch_num = synced_batch_num
예제 #2
0
 def sync(self):
     """Broadcast model and optimizer state from rank 0, then run the parent sync.

     The model parameters are broadcast first, then the optimizer state.
     Finally, the base class's ``sync()`` (as resolved via ``TorchState``'s
     MRO) is invoked.
     """
     source_rank = 0
     model_params = self.model.state_dict()
     broadcast_parameters(model_params, root_rank=source_rank)
     broadcast_optimizer_state(self.optimizer, root_rank=source_rank)
     # Explicit two-argument super() kept as in the original.
     parent = super(TorchState, self)
     parent.sync()
예제 #3
0
파일: state.py 프로젝트: rongou/horovod
 def sync(self):
     broadcast_parameters(self.value.state_dict(), root_rank=0)