Example #1
    def _check_grad_norms(self, grad_norm):
        """Check that grad norms are consistent across workers."""
        if self._grad_norm_buf is not None:
            self._grad_norm_buf.zero_()
            self._grad_norm_buf[self.data_parallel_rank] = grad_norm
            distributed_utils.all_reduce(
                self._grad_norm_buf, group=self.data_parallel_process_group)

            def is_consistent(tensor):
                max_abs_diff = torch.max(torch.abs(tensor - tensor[0]))
                # every worker must report a finite norm, and all norms must
                # agree within a small relative tolerance
                return (torch.isfinite(tensor).all()
                        and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all())

            if not is_consistent(self._grad_norm_buf):
                pretty_detail = "\n".join(
                    "rank {:3d} = {:.8f}".format(r, n)
                    for r, n in enumerate(self._grad_norm_buf.tolist()))
                error_detail = "grad_norm across the workers:\n{}\n".format(
                    pretty_detail)
                # use FloatingPointError to trigger NanDetector
                raise FloatingPointError(
                    "Fatal error: gradients are inconsistent between workers. "
                    "Try --ddp-backend=no_c10d. "
                    "Or are you mixing up different generation of GPUs in training?"
                    + "\n" + "-" * 80 + "\n{}\n".format(error_detail) +
                    "-" * 80)
Example #2
 def _check_grad_norms(self, grad_norm):
     """Check that grad norms are consistent across workers."""
     if self._grad_norm_buf is not None:
         self._grad_norm_buf.zero_()
         self._grad_norm_buf[self.args.distributed_rank] = grad_norm
         distributed_utils.all_reduce(self._grad_norm_buf)
         if not (self._grad_norm_buf == self._grad_norm_buf[0]).all():
             raise RuntimeError(
                 "Fatal error: gradients are inconsistent between workers. "
                 "Try --ddp-backend=no_c10d.")
Example #3
def _reduce(input_):
    """All-reduce the the input tensor across model parallel group."""
    group = get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if get_world_size(group=group) == 1:
        return input_

    # All-reduce.
    all_reduce(input_, group=group)

    return input_
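The helper above simply skips the collective when the model-parallel group has a single member and otherwise performs an in-place sum. A stand-alone sketch of the same idea using torch.distributed directly (names are illustrative):

import torch.distributed as dist


def all_reduce_sum(tensor, group=None):
    # nothing to communicate when the group has only one member
    if dist.get_world_size(group=group) == 1:
        return tensor
    # in-place element-wise sum across all ranks in the group
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    return tensor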
Example #4
 def _sync_sample_ratios(self, ratios):
     # in case the ratios are not precisely the same across processes
     # also to ensure every process updates the ratios at the same pace
     ratios = torch.DoubleTensor(ratios)
     if torch.distributed.is_initialized():
         if torch.cuda.is_available():
             # keep the reference so the reduced values are actually used below
             ratios = ratios.cuda()
         distributed_utils.all_reduce(ratios)
         ret = ratios.cpu().numpy()
     else:
         # single-process case: just return the local ratios
         ret = ratios.numpy()
     return ret
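Note that the helper above sums the per-process ratios; if the goal is simply for every worker to end up with identical values, a straightforward variant (a sketch under that assumption, not fairseq's exact behavior) is to all-reduce and divide by the world size:

import torch
import torch.distributed as dist


def sync_ratios(ratios):
    # average the per-process ratios so every worker sees identical values
    t = torch.as_tensor(ratios, dtype=torch.double)
    if dist.is_available() and dist.is_initialized():
        if torch.cuda.is_available():
            t = t.cuda()
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        t = t / dist.get_world_size()
    return t.cpu().numpy()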
Example #5
    def _fast_stat_sync_sum(
        self,
        logging_outputs: List[Dict[str, Any]],
        *extra_stats_to_sum,
        min_buffer_size: int = 50,
    ):
        """
        Sync logging outputs across workers. fast_stat_sync_sum is
        faster than all_gather_list_sync, but is only suitable when
        logging outputs are scalars and can be summed.
        """
        num_extra = len(extra_stats_to_sum)
        if len(logging_outputs) > 0:
            sorted_keys = sorted(logging_outputs[0].keys())
            stats = [0.] + list(extra_stats_to_sum) + [
                sum(log.get(k, 0) for log in logging_outputs)
                for k in sorted_keys
            ]
            stats = stats + [0.] * (min_buffer_size - len(stats))
            buf = torch.cuda.DoubleTensor(stats)
        else:
            buf = torch.zeros(min_buffer_size,
                              dtype=torch.double,
                              device='cuda')
            buf[0] = 1.  # flag to indicate we should fallback to _all_gather_list_sync

        # stats buffer is organized like:
        # 0: flag to indicate whether fast-stat-sync should be disabled
        # 1-i: extra_stats_to_sum
        # i-j: values from logging_outputs (sorted by key)
        # j-min_buffer_size: padded with 0s
        distributed_utils.all_reduce(buf)

        buf = buf.tolist()
        fallback = buf[0]
        if fallback > 0.:
            # fallback to _all_gather_list_sync
            return self._all_gather_list_sync(logging_outputs,
                                              *extra_stats_to_sum)
        else:
            extra_stats_to_sum = buf[1:num_extra + 1]
            stats = buf[num_extra + 1:]
            stats = [{k: stats[i] for i, k in enumerate(sorted_keys)}]
            return [stats] + extra_stats_to_sum
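The point of this buffer layout is that a single fixed-size all-reduce replaces the much slower all_gather of pickled logging dicts, at the cost of assuming every rank logs the same scalar keys. A stripped-down sketch of the pack / reduce / unpack cycle (stand-alone, illustrative names):

import torch
import torch.distributed as dist


def fast_stat_sync(stats, buffer_size=50):
    # assumes every rank passes a dict with the same scalar keys
    keys = sorted(stats.keys())
    buf = torch.zeros(buffer_size, dtype=torch.double)
    for i, k in enumerate(keys):
        buf[i] = float(stats[k])
    if torch.cuda.is_available():
        buf = buf.cuda()
    # one element-wise SUM over a fixed-size buffer, regardless of key count
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)
    buf = buf.tolist()
    return {k: buf[i] for i, k in enumerate(keys)}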
Example #6
    def _check_grad_norms(self, grad_norm):
        """Check that grad norms are consistent across workers."""
        if self._grad_norm_buf is not None:
            self._grad_norm_buf.zero_()
            self._grad_norm_buf[self.data_parallel_rank] = grad_norm
            distributed_utils.all_reduce(
                self._grad_norm_buf, group=self.data_parallel_process_group)

            if not self._is_grad_norms_consistent(self._grad_norm_buf):
                pretty_detail = "\n".join(
                    "rank {:3d} = {:.8f}".format(r, n)
                    for r, n in enumerate(self._grad_norm_buf.tolist()))
                error_detail = "grad_norm across the workers:\n{}\n".format(
                    pretty_detail)
                raise RuntimeError(
                    "Fatal error: gradients are inconsistent between workers. "
                    "Try --ddp-backend=no_c10d. "
                    "Or are you mixing up different generation of GPUs in training?"
                    + "\n" + "-" * 80 + "\n{}\n".format(error_detail) +
                    "-" * 80)
Example #7
 def _sync_sample_ratios(self, ratios):
     # in case the ratios are not precisely the same across processes
     # also to ensure every process updates the ratios at the same pace
     # fixme: this has bug only on tir?
     # ratios = torch.DoubleTensor(ratios)
     if torch.distributed.is_initialized():
         if torch.cuda.is_available():
             # keep the reference so the reduced values are actually used below
             ratios = ratios.cuda()
         distributed_utils.all_reduce(ratios)
         ret = ratios.cpu()
         ret = ret.numpy()
         # Ad-hoc FIX!
         if self.remapped_lang_ids is not None:
             ret = ret[self.remapped_lang_ids]
     else:
         ret = ratios.cpu()
         ret = ret.numpy()
         # Ad-hoc FIX!
         if self.remapped_lang_ids is not None:
             ret = ret[self.remapped_lang_ids]
     return ret
Example #8
File: trainer.py | Project: zawecha1/fairseq
    def _fast_stat_sync_sum(self, logging_outputs: List[Dict[str, Any]],
                            *extra_stats_to_sum):
        """
        Sync logging outputs across workers. fast_stat_sync_sum is
        faster than all_gather_list_sync, but is only suitable when
        logging outputs are scalars and can be summed.
        """
        num_extra = len(extra_stats_to_sum)
        if len(logging_outputs) > 0:
            sorted_keys = sorted(logging_outputs[0].keys())
            stats = list(extra_stats_to_sum) + [
                sum(log.get(k, 0) for log in logging_outputs)
                for k in sorted_keys
            ]
            buf = torch.cuda.DoubleTensor(stats)

            # When the number of batches is not evenly divisible by the
            # number of GPUs, logging_outputs will be empty for some
            # workers in the last iteration. But we still need to know
            # the keys and buffer size, so we cache the state in case it
            # needs to be reused by this worker later.
            self._fss_buf = buf
            self._fss_sorted_keys = sorted_keys
        elif self._fss_buf is not None:
            buf = self._fss_buf
            buf.zero_()
            buf[:num_extra] = torch.cuda.DoubleTensor(extra_stats_to_sum)
            sorted_keys = self._fss_sorted_keys
        else:
            raise RuntimeError(
                'fast_stat_sync failed, perhaps (# GPUs) > (# batches)?')

        distributed_utils.all_reduce(buf)

        buf = buf.tolist()
        extra_stats_to_sum, stats = buf[:num_extra], buf[num_extra:]
        stats = [{k: stats[i] for i, k in enumerate(sorted_keys)}]
        return [stats] + extra_stats_to_sum
Example #9
    def forward(ctx, vocab_parallel_logits, target):

        # Copy so the input remains unchanged.
        logits = vocab_parallel_logits.clone()
        # Maximum value along vocab dimension across all GPUs.
        logits_max = torch.max(logits, dim=-1)[0]
        all_reduce(logits_max, op='max', group=get_model_parallel_group())
        # Subtract the maximum value.
        logits.sub_(logits_max.unsqueeze(dim=-1))
        # Sum of exponential of logits along vocab dimension across all GPUs.
        exp_logits = logits.exp()
        sum_exp_logits = exp_logits.sum(dim=-1)
        all_reduce(sum_exp_logits, op='sum', group=get_model_parallel_group())

        # Get the partition's vocab indices
        get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        rank = get_model_parallel_rank()
        world_size = get_model_parallel_world_size()
        vocab_start_index, vocab_end_index = get_vocab_range(
            partition_vocab_size, rank, world_size)

        # Mask target ids that fall outside this partition's vocab range (True means masked).
        target_mask = (target < vocab_start_index) | (target >=
                                                      vocab_end_index)
        masked_target = target.clone() - vocab_start_index
        masked_target[target_mask] = 0

        # Get predicted-logits = logits[target].
        # For simplicity, we convert logits to a 2-D tensor with size
        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
        logits_2d = logits.view(-1, partition_vocab_size)
        masked_target_1d = masked_target.view(-1)
        arange_1d = torch.arange(start=0,
                                 end=logits_2d.size()[0],
                                 device=logits_2d.device)
        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
        predicted_logits = predicted_logits_1d.view_as(target)
        predicted_logits[target_mask] = 0.0
        # All reduce is needed to get the chunks from other GPUs.
        all_reduce(predicted_logits,
                   op='sum',
                   group=get_model_parallel_group())

        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits

        # Store softmax, target-mask and masked-target for backward pass.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

        return loss
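Because the vocabulary is sharded, the forward pass only has to all-reduce three quantities per row: the maximum logit (for numerical stability), the sum of exponentials, and the logit of the target token. The loss then follows from the usual identity loss = log(sum_j exp(z_j)) - z_target. A single-process sanity check of that identity against PyTorch's built-in cross entropy (no model parallelism involved):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)            # [batch, vocab]
target = torch.randint(0, 10, (4,))    # [batch]

# subtract the per-row max first, exactly as the parallel version does
m = logits.max(dim=-1, keepdim=True)[0]
shifted = logits - m
loss = torch.log(shifted.exp().sum(dim=-1)) - shifted[torch.arange(4), target]

reference = F.cross_entropy(logits, target, reduction='none')
assert torch.allclose(loss, reference, atol=1e-6)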
Example #10
def validate(args,
             trainer,
             task,
             epoch_itr,
             subsets,
             test_bleu=False,
             summary_writer=None):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []

    distributed_utils.barrier(args, "validate1_%d" % trainer.get_num_updates())
    for subset in subsets:
        # Initialize data iterator
        def get_itr():
            itr = task.get_batch_iterator(
                dataset=task.dataset(subset),
                max_tokens=args.max_tokens,
                max_sentences=args.max_sentences_valid,
                max_positions=utils.resolve_max_positions(
                    task.max_positions(),
                    trainer.get_model().max_positions(),
                ),
                ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
                required_batch_size_multiple=8,
                seed=args.seed,
                num_shards=args.distributed_world_size,
                shard_id=args.distributed_rank,
            ).next_epoch_itr(shuffle=False)
            progress = progress_bar.build_progress_bar(
                args,
                itr,
                epoch_itr.epoch,
                prefix='valid on \'{}\' subset'.format(subset),
                no_progress_bar='simple')
            return progress

        progress = get_itr()

        num_dataset = task.dataset(subset).num_dataset

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)
            for k, v in log_output.items():
                if k in [
                        'loss', 'nll_loss', 'ntokens', 'nsentences',
                        'sample_size'
                ]:
                    continue
                extra_meters[k].update(v)

        bleu_scorers = [
            bleu.Scorer(task.target_dictionary.pad(),
                        task.target_dictionary.eos(),
                        task.target_dictionary.unk())
            for _ in range(num_dataset)
        ] if test_bleu else None

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg

        if bleu_scorers is not None:
            # test bleu
            print("| test bleu.")
            sample_size = [0 for _ in range(num_dataset)]
            bleu_scores = [0 for _ in range(num_dataset)]
            progress = get_itr()

            tgt_str_files = []
            hypo_str_files = []
            for ds_id in range(num_dataset):
                tgt_str_path = task.dataset(
                    subset).dataset_names[ds_id] + '.tgt.txt'
                hypo_str_path = task.dataset(
                    subset).dataset_names[ds_id] + '.hypo.txt'
                tgt_str_files.append(
                    open(os.path.join(args.save_dir, tgt_str_path),
                         'w',
                         encoding='utf-8'))
                hypo_str_files.append(
                    open(os.path.join(args.save_dir, hypo_str_path),
                         'w',
                         encoding='utf-8'))

            def print_to_file(dataset_id, tgt_str, hypo_str):
                tgt_str_files[dataset_id].write(tgt_str + '\n')
                hypo_str_files[dataset_id].write(hypo_str + '\n')

            for sample in progress:
                trainer.test_bleu_step(sample, bleu_scorers, print_to_file)
                if 'dataset_id' in sample:
                    for ds_id in range(num_dataset):
                        sample_size[ds_id] += (
                            sample['dataset_id'] == ds_id).int().sum().item()
                elif 'id' in sample:
                    sample_size[0] += len(sample['id'])

            for f in tgt_str_files + hypo_str_files:
                f.close()

            distributed_utils.barrier(
                args, "validate2_%d" % trainer.get_num_updates())
            for ds_id in range(num_dataset):
                try:
                    bleu_scores[ds_id] = bleu_scorers[ds_id].score(
                    ) * sample_size[ds_id]
                except Exception as e:
                    bleu_scores[ds_id] = 0

            sample_size = torch.Tensor(sample_size).cuda()
            bleu_scores = torch.Tensor(bleu_scores).cuda()
            if args.distributed_world_size > 1:
                all_reduce(sample_size)
                all_reduce(bleu_scores)

            bleu_dict = {}
            for ds_id in range(num_dataset):
                if sample_size[ds_id].item() > 0:
                    name = "bleu_" + task.dataset(subset).dataset_names[ds_id]
                    bleu_dict[name] = stats[name] = bleu_scores[ds_id].item(
                    ) / sample_size[ds_id].item()
                    try:
                        train_ds_id = task.dataset(
                            'train').dataset_names.index(
                                task.dataset(subset).dataset_names[ds_id])
                        task.dataset('train').student_scores[
                            train_ds_id] = bleu_dict[name]
                    except ValueError:
                        pass
            output_path = os.path.join(args.save_dir, 'val_bleu.json')
            json.dump(bleu_dict, open(output_path, 'w'))

        progress.print(stats)
        if summary_writer is not None:
            summary_writer.log_stats('val/' + subset, stats,
                                     trainer.get_num_updates())

        valid_losses.append(stats['valid_loss'])
    return valid_losses
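The BLEU aggregation above all-reduces the weighted scores and the per-dataset sample counts separately and divides afterwards, which is the usual way to average a metric across workers that see different numbers of samples. A compact sketch of that pattern (illustrative names):

import torch
import torch.distributed as dist


def average_metric_across_workers(score_sum, count):
    # sum the weighted scores and the counts separately, then divide,
    # so workers that processed more samples contribute proportionally more
    dist.all_reduce(score_sum, op=dist.ReduceOp.SUM)
    dist.all_reduce(count, op=dist.ReduceOp.SUM)
    return torch.where(count > 0, score_sum / count, torch.zeros_like(score_sum))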
Example #11
 def _aggregate_model_parallel_grad_norm(total_norm):
     total_norm = total_norm ** 2
     distributed_utils.all_reduce(total_norm, group=get_model_parallel_group())
     total_norm = total_norm ** 0.5
     return total_norm
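This works because the parameters are sharded across the model-parallel group: each rank contributes the squared L2 norm of its shard, the all-reduce sums those squares, and the square root recovers the norm of the full parameter vector. A quick single-process check of the identity:

import torch

# sqrt of the sum of squared shard norms equals the norm of the full vector
x = torch.randn(12)
shards = torch.chunk(x, 3)
total = sum(s.norm() ** 2 for s in shards) ** 0.5
assert torch.allclose(total, x.norm())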
Example #12
    def train_step(self,
                   samples,
                   dummy_batch=False,
                   assistant=None,
                   assistant_queue=None,
                   weights=None):
        """Do forward, backward and parameter update."""
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.model.train()
        self.zero_grad()

        if not dummy_batch:
            self.meters['train_wall'].start()

        # forward and backward pass
        logging_outputs, sample_sizes, ooms = [], [], 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                ignore_grad = True
            else:
                ignore_grad = False

            try:
                if self.args.distributed_world_size > 1:
                    # Whenever *samples* contains more than one mini-batch, we
                    # want to accumulate gradients locally and only call
                    # all-reduce in the last backwards pass. Currently the
                    # *accumulate_grads* flag is only supported by
                    # LegacyDistributedDataParallel.
                    if i < len(samples) - 1:
                        self.model.accumulate_grads = True
                    else:
                        self.model.accumulate_grads = False

                # forward and backward
                if self.args.assistant:
                    losses, sample_size, logging_output, precisions = self.task.train_step(
                        sample, self.model, self.criterion, self.optimizer,
                        ignore_grad)
                elif self.args.spl:
                    losses, sample_size, logging_output, precisions = self.task.train_step(
                        sample,
                        self.model,
                        self.criterion,
                        self.optimizer,
                        ignore_grad,
                        lambda_t=self.lambda_t)
                else:
                    losses, sample_size, logging_output = self.task.train_step(
                        sample, self.model, self.criterion, self.optimizer,
                        ignore_grad)
                # record new losses
                if self.args.spl and not dummy_batch:
                    y_lengths = utils.get_len(
                        sample['target'].cpu().numpy(),
                        self.task.target_dictionary.pad())
                    norm_losses = np.divide(losses.detach().cpu().numpy(),
                                            y_lengths)
                    self.loss_chart[
                        sample['id'].cpu().numpy()] = torch.from_numpy(
                            norm_losses, ).type(torch.FloatTensor).cuda()
                    if self.args.distributed_world_size > 1:
                        all_reduce(self.loss_chart, op=MIN_OP)

                # prepare data for assistant training
                if assistant is not None and np.random.rand(
                ) < SEC_TRAIN_RATIO:
                    sec_batch_size = sample['id'].size(0)
                    indices_sec = np.random.choice(sample['id'].size(0),
                                                   sec_batch_size)
                    x = sample['net_input']['src_tokens'][indices_sec]
                    y = sample['target'][indices_sec]
                    l = losses[indices_sec]
                    x = x.cpu().numpy()
                    y = y.cpu().numpy()
                    l = l.detach().cpu().numpy()
                    keep_probs = assistant.train_step(x, y, l)
                elif assistant_queue is not None and np.random.rand(
                ) < SEC_TRAIN_RATIO:
                    sec_batch_size = sample['id'].size(0)
                    local_indices_sec = np.random.choice(
                        sample['id'].size(0), sec_batch_size)
                    global_indices_sec = sample['id'][local_indices_sec].cpu(
                    ).numpy()

                    l = losses[local_indices_sec]
                    l = l.detach().cpu().numpy()
                    if not assistant_queue.full():
                        assistant_queue.put((global_indices_sec, l),
                                            block=False)
                    else:
                        _ = assistant_queue.get()
                        assistant_queue.put((global_indices_sec, l),
                                            block=False)

                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_sizes.append(sample_size)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    ooms += 1
                    self.zero_grad()
                else:
                    print(sample, flush=True, force=True)
                    raise e

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_outputs, sample_sizes, ooms = zip(
                *distributed_utils.all_gather_list(
                    [logging_outputs, sample_sizes, ooms], ))
            logging_outputs = list(chain.from_iterable(logging_outputs))
            sample_sizes = list(chain.from_iterable(sample_sizes))
            ooms = sum(ooms)

        if ooms == self.args.distributed_world_size * len(samples):
            print('| WARNING: OOM in all workers, skipping update')
            self.zero_grad()
            return None

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_outputs, self.criterion)
        sample_size = self.task.grad_denom(sample_sizes, self.criterion)

        if not all(k in logging_output for k in ['ntokens', 'nsentences']):
            raise Exception(
                ('Please update the {}.aggregate_logging_outputs() method to '
                 'return ntokens and nsentences').format(
                     self.task.__class__.__name__))

        try:
            # normalize grads by sample size
            self.optimizer.multiply_grads(self.args.distributed_world_size /
                                          float(sample_size))

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

            # take an optimization step
            self.optimizer.step()
            self._num_updates += 1

            # update learning rate
            self.lr_scheduler.step_update(self._num_updates)

            # update meters
            ntokens = logging_output.get('ntokens', 0)
            nsentences = logging_output.get('nsentences', 0)
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            self.meters['gnorm'].update(grad_norm)
            self.meters['clip'].update(1. if grad_norm > self.args.clip_norm
                                       and self.args.clip_norm > 0 else 0.)
            self.meters['oom'].update(ooms)
            self.meters['train_loss'].update(logging_output.get('loss', 0),
                                             sample_size)
            if 'nll_loss' in logging_output:
                self.meters['train_nll_loss'].update(
                    logging_output.get('nll_loss', 0), ntokens)
        except OverflowError as e:
            print('| WARNING: overflow detected, ' + str(e))
            self.zero_grad()
            logging_output = None

        if self.args.fp16:
            self.meters['loss_scale'].reset()
            self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

        self.meters['train_wall'].stop()

        return logging_output
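One detail worth calling out is the all_reduce(self.loss_chart, op=MIN_OP) call in the SPL branch: assuming MIN_OP corresponds to a minimum reduction, it keeps, for every sample id, the smallest normalized loss observed on any worker, so all workers end up with an identical loss chart. A minimal sketch of that reduction with plain torch.distributed:

import torch
import torch.distributed as dist


def sync_loss_chart(loss_chart):
    # element-wise minimum across workers; slots a rank never touched should be
    # initialized to a large value so they do not win the minimum
    dist.all_reduce(loss_chart, op=dist.ReduceOp.MIN)
    return loss_chart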