def train_step(self, samples): """Do forward, backward and gradient step in parallel.""" # PyTorch initializes gradient buffers lazily, so the first # train step needs to send non-empty samples to all replicas replace_empty_samples = False if not self._grads_initialized: replace_empty_samples = True self._grads_initialized = True # scatter sample across GPUs self._scatter_samples(samples, replace_empty_samples=replace_empty_samples) # forward pass sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([ self.call_async(rank, '_async_forward') for rank in range(self.num_replicas) ]) # backward pass, all-reduce gradients and take an optimization step grad_denom = self.criterion.__class__.grad_denom(sample_sizes) grad_norms, ooms_bwd = Future.gen_tuple_list([ self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom) for rank in range(self.num_replicas) ]) # aggregate logging output logging_output = self.criterion.__class__.aggregate_logging_outputs( logging_outputs) logging_output['gnorm'] = grad_norms[0] # log the gradient norm logging_output['oom'] = sum(ooms_fwd) + sum(ooms_bwd) return logging_output
def train_step(self, samples, criterion):
    """Do forward, backward and gradient step in parallel."""
    assert isinstance(criterion, FairseqCriterion)

    # PyTorch initializes gradient buffers lazily, so the first
    # train step needs to send non-empty samples to all replicas
    replace_empty_samples = False
    if not self._grads_initialized:
        replace_empty_samples = True
        self._grads_initialized = True

    # scatter sample across GPUs
    self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
    criterion.prepare(samples)

    # forward pass, backward pass and gradient step
    losses = [
        self.call_async(rank, '_async_train_step', criterion=criterion)
        for rank in range(self.num_replicas)
    ]

    # aggregate losses and gradient norms
    losses, grad_norms = Future.gen_tuple_list(losses)
    loss = criterion.aggregate(losses)

    return loss, grad_norms[0]

def valid_step(self, samples):
    """Do forward pass in parallel."""
    # scatter sample across GPUs
    self._scatter_samples(samples, volatile=True)

    # forward pass
    _sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
        self.call_async(rank, '_async_forward', eval=True)
        for rank in range(self.num_replicas)
    ])
    assert sum(ooms_fwd) == 0

    # aggregate logging output
    logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)

    return logging_output

def train_step(self, samples, criterion):
    """Do forward, backward and gradient step in parallel."""
    assert isinstance(criterion, FairseqCriterion)

    # scatter sample across GPUs
    self._scatter_samples(samples)
    criterion.prepare(samples)

    # forward pass, backward pass and gradient step
    losses = [
        self.call_async(rank, '_async_train_step', criterion=criterion)
        for rank in range(self.num_replicas)
    ]

    # aggregate losses and gradient norms
    losses, grad_norms = Future.gen_tuple_list(losses)
    loss = criterion.aggregate(losses)

    return loss, grad_norms[0]

def valid_step(self, samples, criterion):
    """Do forward pass in parallel."""
    # scatter sample across GPUs
    self._scatter_samples(samples, volatile=True)
    criterion.prepare(samples)

    # forward pass
    res = [
        self.call_async(rank, '_async_valid_step', criterion=criterion)
        for rank in range(self.num_replicas)
    ]

    # aggregate losses
    losses, mean_rouge_greedy, mean_rouge_sampled = Future.gen_tuple_list(res)
    loss = criterion.aggregate(losses)
    mean_rouge_greedy = utils.sum_if_not_none(mean_rouge_greedy)
    mean_rouge_sampled = utils.sum_if_not_none(mean_rouge_sampled)

    return loss, mean_rouge_greedy, mean_rouge_sampled

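# Note: `utils.sum_if_not_none` is referenced here and below but not defined
# in this excerpt. A plausible reading is a helper that sums the per-replica
# metrics unless they are absent (e.g. when the RL/ROUGE path is disabled and
# the replicas report None). A hypothetical sketch under that assumption:
def sum_if_not_none(values):
    # return None if any replica reported no value, otherwise the sum
    if any(v is None for v in values):
        return None
    return sum(values)
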
def train_step(self, samples, criterion):
    """Do forward, backward and gradient step in parallel."""
    assert isinstance(criterion, FairseqCriterion)

    # scatter sample across GPUs
    self._scatter_samples(samples)
    criterion.prepare(samples)

    # forward pass, backward pass and gradient step;
    # each future in res resolves to a namedtuple of per-replica stats
    res = [
        self.call_async(rank, '_async_train_step', criterion=criterion)
        for rank in range(self.num_replicas)
    ]

    # aggregate losses and gradient norms
    (losses, grad_norms, ml_losses, rl_losses, mean_rouge_greedy,
     mean_rouge_sampled, mean_sum_log_probs) = Future.gen_tuple_list(res)
    loss = criterion.aggregate(losses)
    ml_loss = criterion.aggregate(ml_losses)
    rl_loss = utils.sum_if_not_none(rl_losses)
    mean_rouge_greedy = utils.sum_if_not_none(mean_rouge_greedy)
    mean_rouge_sampled = utils.sum_if_not_none(mean_rouge_sampled)
    mean_sum_log_prob = utils.sum_if_not_none(mean_sum_log_probs)

    aggregate_res = Results(loss, grad_norms[0], ml_loss, rl_loss,
                            mean_rouge_greedy, mean_rouge_sampled,
                            mean_sum_log_prob)
    return aggregate_res

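# Note: the `Results` container is used above but not defined in this excerpt.
# It is presumably a namedtuple whose fields mirror the aggregated values; a
# hypothetical definition consistent with the constructor call above:
from collections import namedtuple

Results = namedtuple('Results', [
    'loss', 'grad_norm', 'ml_loss', 'rl_loss',
    'mean_rouge_greedy', 'mean_rouge_sampled', 'mean_sum_log_prob',
])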