def initialize_gradients(self):
    Future.gen_tuple_list([
        self.call_async(rank, '_async_initialize_gradients')
        for rank in range(self.num_replicas)
    ])
    self._grads_initialized = True
def update_parameters(self, grad_denom=1):
    """When we update parameters, all replicas update at the same time."""
    # propagate the overflow check result instead of hardcoding False,
    # so replicas can skip the update on a global overflow
    is_global_overflow = self.check_global_overflow()
    Future.gen_tuple_list([
        self.call_async(rank, '_async_update', grad_denom=grad_denom,
                        is_global_overflow=is_global_overflow)
        for rank in range(self.num_replicas)
    ])
def check_global_overflow(self):
    """Gather per-replica overflow flags. The global reduction is
    currently disabled, so this always reports no overflow."""
    local_overflows = Future.gen_tuple_list([
        self.call_async(rank, '_async_local_overflow')
        for rank in range(self.num_replicas)
    ])
    # global reduction disabled for now: global_overflow = sum(local_overflows)
    return False
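# If the disabled reduction above were enabled, a minimal sketch (assuming
# each '_async_local_overflow' future resolves to a bool) would replace the
# hardcoded return with something like:
#
#     return bool(sum(local_overflows))   # equivalently: any(local_overflows)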
def step(self, samples, eval=False):
    """Run a forward pass on all replicas and aggregate the results."""
    self._scatter_samples(samples, replace_empty_samples=False)
    # call the async forward function on every replica
    losses, logging_outputs, ooms = Future.gen_tuple_list([
        self.call_async(rank, '_async_forward', eval=eval)
        for rank in range(self.num_replicas)
    ])
    logging_output = aggregate_logging_outputs(logging_outputs)
    loss = aggregate_loss(losses)
    logging_output['oom'] = sum(ooms)
    logging_output['loss'] = loss
    return logging_output
def zero_grad(self):
    Future.gen_tuple_list([
        self.call_async(rank, '_async_zero_grad')
        for rank in range(self.num_replicas)
    ])
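# Usage sketch: one training iteration driving all replicas in lockstep.
# `trainer` (an instance of this class) and `batches` are hypothetical
# stand-ins; only the methods defined above are taken from this file.
#
#     trainer.initialize_gradients()
#     for samples in batches:
#         trainer.zero_grad()
#         logging_output = trainer.step(samples)
#         trainer.update_parameters(grad_denom=1)
#         print('loss {:.3f}, ooms {}'.format(
#             logging_output['loss'], logging_output['oom']))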