    def valid_step(self, sample):
        """Do forward pass in evaluation mode."""
        self.model.eval()
        self._num_val_iterations += 1
        # forward pass
        sample, sample_size = self._prepare_sample(sample)
        with torch.no_grad():
            loss, oom_fwd = self._forward(sample)
        logging_output = {
            'ntokens': sample['ntokens'] if sample is not None else 0,
            'nsentences': sample['target'].size(0) if sample is not None else 0,
            'sample_size': sample_size
        }
        loss = loss.item() if loss is not None else 0
        assert not oom_fwd, 'Ran out of memory during validation'

        # gather logging outputs from all GPUs
        if self.args.distributed_world_size > 1:
            losses, sample_sizes, logging_outputs = zip(*distributed_utils.all_gather_list(
                (loss, sample_size, logging_output)
            ))
        else:
            losses = [loss]
            sample_sizes = [sample_size]
            logging_outputs = [logging_output]

        # TODO: check when ntokens != sample_size
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        weight = sum(log.get('sample_size', 0) for log in logging_outputs)
        scaled_loss = sum(losses) / weight / math.log(2)

        return scaled_loss 
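A recurring idiom in these examples is gathering a tuple of per-worker stats with all_gather_list and transposing the result with zip(*...). Below is a minimal sketch of that transposition on plain Python data; the gathered list merely simulates what a 3-worker gather might return and is not part of the original code.

import math

# Stand-in for distributed_utils.all_gather_list((loss, sample_size, logging_output)):
# one (loss, sample_size, logging_output) tuple per worker.
gathered = [
    (210.0, 100, {'ntokens': 100, 'nsentences': 4, 'sample_size': 100}),
    (252.0, 120, {'ntokens': 120, 'nsentences': 5, 'sample_size': 120}),
    (176.0,  80, {'ntokens':  80, 'nsentences': 3, 'sample_size':  80}),
]

# zip(*...) turns per-worker tuples into per-field tuples.
losses, sample_sizes, logging_outputs = zip(*gathered)

weight = sum(log.get('sample_size', 0) for log in logging_outputs)
scaled_loss = sum(losses) / weight / math.log(2)  # average loss per token, in bits
print(round(scaled_loss, 3))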
Example #2
def all_gather_from_master(args, data: List) -> List:
    if args.distributed_world_size == 1:
        return data

    gathered_data = distributed_utils.all_gather_list(data)
    # Converts [[x0, y0, z0, ...], [x1, y1, z1, ...], [x2, y2, z2, ...], ...]
    # to [[x0, x1, x2, ...], [y0, y1, y2, ...], [z0, z1, z2, ...], ...]
    gathered_data_list = list(zip(*gathered_data))

    output_data = []
    for data_index, all_data in enumerate(gathered_data_list):
        # The master's (process 0) data is guaranteed to be in position 0.
        master_data = all_data[0]
        # Sanity check that only the master returned any result.
        if master_data is None:
            raise RuntimeError(
                f"Input data element {data_index} of all_gather_from_master "
                f"returned None from master. Results from all processes: {all_data}"
            )
        for i in range(1, len(all_data)):
            if all_data[i] is not None:
                raise RuntimeError(
                    f"Input data element {data_index} of all_gather_from_master "
                    f"should have returned None from non-master process {i}. "
                    f"Results from all processes: {all_data}")
        output_data.append(master_data)
    return output_data
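A hedged usage sketch for all_gather_from_master: in a real multi-GPU run only rank 0 passes real values and the other ranks pass None placeholders, so every rank ends up with the master's results. The args namespace below is a hypothetical stand-in; with a world size of 1 the function is a no-op.

from argparse import Namespace

args = Namespace(distributed_world_size=1)
data = ["validation BLEU", 27.4]
assert all_gather_from_master(args, data) == data

# With distributed_world_size > 1, a non-master rank would instead call
#   all_gather_from_master(args, [None, None])
# and receive the master's ["validation BLEU", 27.4] back from the gather.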
Example #3
    def valid_step(self, sample, raise_oom=False):
        """Do forward pass in evaluation mode."""
        with torch.no_grad():
            self.model.eval()
            self.criterion.eval()

            sample = self._prepare_sample(sample)
            if sample is None:
                sample = self._prepare_sample(self._dummy_batch)
                ignore_results = True
            else:
                ignore_results = False

            try:
                _loss, sample_size, logging_output = self.task.valid_step(
                    sample, self.model, self.criterion)
            except RuntimeError as e:
                if 'out of memory' in str(e) and not raise_oom:
                    print('| WARNING: ran out of memory, retrying batch')
                    for p in self.model.parameters():
                        if p.grad is not None:
                            p.grad = None  # free some memory
                    if self.cuda:
                        torch.cuda.empty_cache()
                    return self.valid_step(sample, raise_oom=True)
                else:
                    raise e

            if ignore_results:
                logging_output, sample_size = {}, 0

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_output, sample_size = zip(
                *distributed_utils.all_gather_list(
                    [logging_output, sample_size], ))
            logging_output = list(logging_output)
            sample_size = list(sample_size)
        else:
            logging_output = [logging_output]
            sample_size = [sample_size]

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_output, self.get_criterion())
        sample_size = self.task.grad_denom(sample_size, self.get_criterion())

        # update meters for validation
        ntokens = logging_output.get('ntokens', 0)
        self.meters['valid_loss'].update(logging_output.get('loss', 0),
                                         sample_size)
        if 'valid_acc' in self.meters:
            self.meters['valid_acc'].update(logging_output.get('acc', 0),
                                            sample_size)

        if 'nll_loss' in logging_output:
            self.meters['valid_nll_loss'].update(
                logging_output.get('nll_loss', 0), ntokens)

        return logging_output
Example #4
    def validate(self, data_loader, args):
        was_training = self.training
        self.eval()

        keys = ['loss', 'sample_size']

        logging_info_list = []
        with torch.no_grad():
            with tqdm(total=len(data_loader),
                      desc=f"Evaluation",
                      file=sys.stdout) as pbar:
                for step, batch in enumerate(data_loader):
                    loss_sum, logging_info = self(**batch)
                    logging_info = {k: logging_info[k] for k in keys}
                    logging_info_list.append(logging_info)

                    pbar.update(1)

        if was_training:
            self.train()

        stats = {k: sum(x[k] for x in logging_info_list) for k in keys}

        # handle distributed evaluation
        if args.multi_gpu:
            stats = distributed_utils.all_gather_list(stats)
            stats = {k: sum(x[k] for x in stats) for k in keys}

        valid_result = {'ppl': math.exp(stats['loss'] / stats['sample_size'])}

        return valid_result
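The per-key summation in validate works the same whether the dicts come from local batches or from all_gather_list across GPUs. A minimal sketch of the aggregation, with the gathered list simulated by hand:

import math

keys = ['loss', 'sample_size']
# Stand-in for distributed_utils.all_gather_list(stats) on a 2-GPU run.
gathered_stats = [
    {'loss': 120.0, 'sample_size': 50},
    {'loss': 100.0, 'sample_size': 40},
]
stats = {k: sum(x[k] for x in gathered_stats) for k in keys}
valid_result = {'ppl': math.exp(stats['loss'] / stats['sample_size'])}
print(valid_result)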
Example #5
    def valid_step(self, sample):
        """Do forward pass in evaluation mode."""
        # forward pass
        sample = self._prepare_sample(sample)
        _loss, sample_size, logging_output, oom_fwd = self._forward(sample,
                                                                    eval=True)
        assert not oom_fwd, 'Ran out of memory during validation'

        # gather logging outputs from all GPUs
        if self.args.distributed_world_size > 1:
            sample_sizes, logging_outputs = zip(
                *distributed_utils.all_gather_list((sample_size,
                                                    logging_output)))
        else:
            sample_sizes = [sample_size]
            logging_outputs = [logging_output]

        # aggregate stats and logging outputs
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
        agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(
            logging_outputs)

        # update loss meters for validation
        if 'loss' in agg_logging_output:
            self.meters['valid_loss'].update(agg_logging_output['loss'],
                                             grad_denom)
        # criterions can optionally log the NLL loss too
        if 'nll_loss' in agg_logging_output:
            self.meters['valid_nll_loss'].update(
                agg_logging_output['nll_loss'], ntokens)

        return agg_logging_output
Example #6
    def _all_gather_list_sync(
        self,
        logging_outputs: List[Dict[str, Any]],
        *extra_stats_to_sum,
        ignore=False,
    ):
        """
        Sync logging outputs across workers. all_gather_list_sync is
        suitable when logging outputs are complex types.
        """
        if self.tpu:
            raise NotImplementedError
        if ignore:
            logging_outputs = []
        results = list(
            zip(*distributed_utils.all_gather_list(
                [logging_outputs] + list(extra_stats_to_sum),
                max_size=getattr(self.cfg.common, "all_gather_list_size",
                                 16384),
                group=self.data_parallel_process_group,
            )))
        logging_outputs, extra_stats_to_sum = results[0], results[1:]
        logging_outputs = list(chain.from_iterable(logging_outputs))
        extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
        return logging_outputs, extra_stats_to_sum
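The transpose-then-flatten step above can be checked on plain data. The nested list below simulates what all_gather_list might return for two workers, each contributing a list of logging outputs plus two scalar stats; it is illustrative only.

from itertools import chain

gathered = [
    [[{'loss': 2.0}, {'loss': 1.5}], 0, 100],  # worker 0: logging_outputs, ooms, ntokens
    [[{'loss': 1.8}], 1, 120],                 # worker 1
]
results = list(zip(*gathered))
logging_outputs, extra_stats_to_sum = results[0], results[1:]
logging_outputs = list(chain.from_iterable(logging_outputs))  # flatten per-worker lists
extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]     # element-wise sums across workers
assert logging_outputs == [{'loss': 2.0}, {'loss': 1.5}, {'loss': 1.8}]
assert extra_stats_to_sum == [1, 220]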
Example #7
File: trainer.py Project: fyabc/fairseq
    def valid_step(self, sample):
        """Do forward pass in evaluation mode."""
        # forward pass
        sample = self._prepare_sample(sample)
        _loss, sample_size, logging_output, oom_fwd = self._forward(sample, eval=True)
        assert not oom_fwd, 'Ran out of memory during validation'

        # gather logging outputs from all GPUs
        if self.args.distributed_world_size > 1:
            sample_sizes, logging_outputs = zip(*distributed_utils.all_gather_list(
                (sample_size, logging_output)
            ))
        else:
            sample_sizes = [sample_size]
            logging_outputs = [logging_output]

        # aggregate stats and logging outputs
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
        agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)

        # update loss meters for validation
        if 'loss' in agg_logging_output:
            self.meters['valid_loss'].update(agg_logging_output['loss'], grad_denom)
        # criterions can optionally log the NLL loss too
        if 'nll_loss' in agg_logging_output:
            self.meters['valid_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)

        return agg_logging_output
Example #8
def _all_gather_predictions(predictions):
    ready = False
    all_ready = False
    reduced_predictions = []
    max_size = 65000
    while not all_ready:
        lst_len = len(predictions)
        size = 2000  # some extra space for Python pickle overhead
        n = 0
        while n < lst_len:
            str_len = len(predictions[n].encode(
                'utf8')) + 8  # per string pickle overhead
            if size + str_len >= max_size:
                break
            size += str_len
            n += 1
        chunk = predictions[:n]
        predictions = predictions[n:]
        if not predictions:
            ready = True
        chunk = (ready, chunk)
        torch.cuda.synchronize()
        gathered = distributed_utils.all_gather_list(chunk, max_size=65000)
        torch.cuda.synchronize()
        reduced_predictions += [t[1] for t in gathered]
        all_ready = all([t[0] for t in gathered])

    reduced_predictions = [
        item for sublist in reduced_predictions for item in sublist
    ]

    return reduced_predictions
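The loop above splits the prediction strings into chunks whose pickled size stays under the gather buffer. A sketch of just the size-accounting step, pulled out into a hypothetical helper that reuses the same byte-budget assumptions (2000 bytes of base overhead, 8 bytes per string):

def split_point(predictions, max_size=65000, base_overhead=2000, per_str_overhead=8):
    """Number of leading strings that fit into one gather chunk."""
    size, n = base_overhead, 0
    while n < len(predictions):
        str_len = len(predictions[n].encode('utf8')) + per_str_overhead
        if size + str_len >= max_size:
            break
        size += str_len
        n += 1
    return n

# A few short strings easily fit in a single chunk.
assert split_point(["a", "bb", "ccc"]) == 3
# A single 100 kB string exceeds the budget, so the chunk would be empty.
assert split_point(["x" * 100000]) == 0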
Example #9
File: trainer.py Project: yxlin1/fairseq
    def _update_params(self):
        # gather logging outputs from all replicas
        sample_sizes = self._buffered_stats['sample_sizes']
        logging_outputs = self._buffered_stats['logging_outputs']
        ooms_fwd = self._buffered_stats['ooms_fwd']
        ooms_bwd = self._buffered_stats['ooms_bwd']
        if self.args.distributed_world_size > 1:
            sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
                lambda l: list(chain.from_iterable(l)),
                zip(*distributed_utils.all_gather_list((sample_sizes,
                                                        logging_outputs,
                                                        ooms_fwd, ooms_bwd))))
        ooms_fwd = sum(ooms_fwd)
        ooms_bwd = sum(ooms_bwd)

        if ooms_fwd == self.args.distributed_world_size:
            print('| WARNING: OOM in all workers, skipping batch')
            self.zero_grad()
            return None

        # aggregate stats and logging outputs
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
        agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(
            logging_outputs)
        grad_denom = self.criterion.__class__.grad_denom(sample_sizes)

        try:
            # all-reduce and rescale gradients, then take an optimization step
            grad_norm = self._all_reduce_and_rescale(grad_denom)
            self._opt()

            # update meters
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            if grad_norm is not None:
                self.meters['gnorm'].update(grad_norm)
                self.meters['clip'].update(
                    1. if grad_norm > self.args.clip_norm else 0.)
            self.meters['oom'].update(ooms_fwd + ooms_bwd)

            # update loss meters for training
            if 'loss' in agg_logging_output:
                self.meters['train_loss'].update(agg_logging_output['loss'],
                                                 grad_denom)
            # criterions can optionally log the NLL loss too
            if 'nll_loss' in agg_logging_output:
                self.meters['train_nll_loss'].update(
                    agg_logging_output['nll_loss'], ntokens)
        except OverflowError as e:
            self.zero_grad()
            print('| WARNING: overflow detected, ' + str(e))

        self.clear_buffered_stats()

        return agg_logging_output
Example #10
File: trainer.py Project: fyabc/fairseq
    def _update_params(self):
        # gather logging outputs from all replicas
        sample_sizes = self._buffered_stats['sample_sizes']
        logging_outputs = self._buffered_stats['logging_outputs']
        ooms_fwd = self._buffered_stats['ooms_fwd']
        ooms_bwd = self._buffered_stats['ooms_bwd']
        if self.args.distributed_world_size > 1:
            sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
                lambda l: list(chain.from_iterable(l)),
                zip(*distributed_utils.all_gather_list(
                    (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
                ))
            )
        ooms_fwd = sum(ooms_fwd)
        ooms_bwd = sum(ooms_bwd)

        if ooms_fwd == self.args.distributed_world_size:
            print('| WARNING: OOM in all workers, skipping batch')
            self.zero_grad()
            return None

        # aggregate stats and logging outputs
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
        agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
        grad_denom = self.criterion.__class__.grad_denom(sample_sizes)

        try:
            # all-reduce and rescale gradients, then take an optimization step
            grad_norm = self._all_reduce_and_rescale(grad_denom)
            self._opt()

            # update meters
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            if grad_norm is not None:
                self.meters['gnorm'].update(grad_norm)
                self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
            self.meters['oom'].update(ooms_fwd + ooms_bwd)

            # update loss meters for training
            if 'loss' in agg_logging_output:
                self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
            # criterions can optionally log the NLL loss too
            if 'nll_loss' in agg_logging_output:
                self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
        except OverflowError as e:
            self.zero_grad()
            print('| WARNING: overflow detected, ' + str(e))

        self.clear_buffered_stats()

        return agg_logging_output
Example #11
    def validate(self, data_loader, args):
        gc.collect()

        keys = [
            'masked_context_token_loss',
            'masked_context_token_num',
            'masked_column_token_loss',
            'masked_column_token_num'
        ]

        if self.config.predict_cell_tokens:
            keys += [
                'masked_cell_token_loss',
                'masked_cell_token_num'
            ]

        was_training = self.training
        self.eval()

        logging_info_list = []
        with torch.no_grad():
            with tqdm(total=len(data_loader), desc=f"Evaluation", file=sys.stdout) as pbar:
                for step, batch in enumerate(data_loader):
                    loss_sum, logging_info = self(**batch)
                    logging_info = {k: logging_info[k] for k in keys}
                    logging_info_list.append(logging_info)

                    pbar.update(1)

        if was_training:
            self.train()

        stats = {
            k: sum(x[k] for x in logging_info_list)
            for k in keys
        }

        # handle distributed evaluation
        if args.multi_gpu:
            stats = distributed_utils.all_gather_list(stats)
            stats = {
                k: sum(x[k] for x in stats)
                for k in keys
            }

        valid_result = {
            'masked_context_token_ppl': math.exp(stats['masked_context_token_loss'] / stats['masked_context_token_num']),
            'masked_column_token_ppl': math.exp(stats['masked_column_token_loss'] / stats['masked_column_token_num'])
        }

        if self.config.predict_cell_tokens:
            valid_result['masked_cell_token_ppl'] = math.exp(stats['masked_cell_token_loss'] / stats['masked_cell_token_num'])

        return valid_result
Example #12
    def mlm_eval_step(self, sample, raise_oom=False):
        """Do forward pass in evaluation mode."""
        with torch.no_grad():
            self.model.eval()
            self.criterion.eval()

            sample = self._prepare_sample(sample)
            if sample is None:
                sample = self._prepare_sample(self._dummy_batch)
                ignore_results = True
            else:
                ignore_results = False

            try:
                logging_output = self.task.mlm_eval_step(
                    sample, self.model, self.criterion
                )
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if not raise_oom:
                        print(
                            "| WARNING: ran out of memory in validation step, retrying batch"
                        )
                        for p in self.model.parameters():
                            if p.grad is not None:
                                p.grad = None  # free some memory
                        if self.cuda:
                            torch.cuda.empty_cache()
                        return self.mlm_eval_step(sample, raise_oom=True)
                raise e

            if ignore_results:
                logging_output, sample_size = {}, 0

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_output, sample_size = zip(
                *distributed_utils.all_gather_list([logging_output, sample_size])
            )
            logging_output = list(logging_output)
            sample_size = list(sample_size)
        else:
            logging_output = [logging_output]
            sample_size = [sample_size]

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_output, self.get_criterion()
        )

        return logging_output
Example #13
def _all_gather_bleu_scorer(scorer):
    stats = distributed_utils.all_gather_list(scorer.stat)
    bleu_stat = bleu.BleuStat()
    bleu_stat.reflen = reduce(lambda x, y: x + y, [s.reflen for s in stats])
    bleu_stat.predlen = reduce(lambda x, y: x + y, [s.predlen for s in stats])
    bleu_stat.match1 = reduce(lambda x, y: x + y, [s.match1 for s in stats])
    bleu_stat.count1 = reduce(lambda x, y: x + y, [s.count1 for s in stats])
    bleu_stat.match2 = reduce(lambda x, y: x + y, [s.match2 for s in stats])
    bleu_stat.count2 = reduce(lambda x, y: x + y, [s.count2 for s in stats])
    bleu_stat.match3 = reduce(lambda x, y: x + y, [s.match3 for s in stats])
    bleu_stat.count3 = reduce(lambda x, y: x + y, [s.count3 for s in stats])
    bleu_stat.match4 = reduce(lambda x, y: x + y, [s.match4 for s in stats])
    bleu_stat.count4 = reduce(lambda x, y: x + y, [s.count4 for s in stats])
    scorer.stat = bleu_stat
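The field-by-field reductions above could be written more compactly as a loop over the stat attribute names. A sketch, assuming the BleuStat fields accept plain integer assignment; this is an equivalent refactor, not the original implementation:

def _all_gather_bleu_scorer_compact(scorer):
    stats = distributed_utils.all_gather_list(scorer.stat)
    bleu_stat = bleu.BleuStat()
    fields = ['reflen', 'predlen'] + [
        f'{name}{n}' for n in range(1, 5) for name in ('match', 'count')
    ]
    for field in fields:
        # sum each field across all workers' stats
        setattr(bleu_stat, field, sum(getattr(s, field) for s in stats))
    scorer.stat = bleu_stat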
Example #14
    def prune_step(self, sample, raise_oom=False):
        """Do forward and backward pass in evaluation mode."""
        self.model.eval()
        self.zero_grad()

        sample = self._prepare_sample(sample)
        if sample is None:
            sample = self._prepare_sample(self._dummy_batch)
            ignore_results = True
        else:
            ignore_results = False

        try:
            sample_size, logging_output = self.task.prune_step(
                sample, self.model, self.criterion,
            )
        except RuntimeError as e:
            if 'out of memory' in str(e) and not raise_oom:
                print('| WARNING: ran out of memory, retrying batch')
                for p in self.model.parameters():
                    if p.grad is not None:
                        del p.grad  # free some memory
                torch.cuda.empty_cache()
                return self.prune_step(sample, raise_oom=True)
            else:
                raise e

        if ignore_results:
            logging_output, sample_size = {}, 0

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_output, sample_size = zip(*distributed_utils.all_gather_list(
                [logging_output, sample_size],
            ))
            logging_output = list(logging_output)
            sample_size = list(sample_size)
        else:
            logging_output = [logging_output]
            sample_size = [sample_size]

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_output, self.criterion
        )
        sample_size = self.task.grad_denom(
            sample_size, self.criterion
        )

        return logging_output
Example #15
    def _all_gather_list_sync(self, logging_outputs: List[Dict[str, Any]],
                              *extra_stats_to_sum):
        """
        Sync logging outputs across workers. all_gather_list_sync is
        suitable when logging outputs are complex types.
        """
        results = list(
            zip(*distributed_utils.all_gather_list(
                [logging_outputs] + list(extra_stats_to_sum),
                max_size=getattr(self.args, 'all_gather_list_size', 16384),
            )))
        logging_outputs, extra_stats_to_sum = results[0], results[1:]
        logging_outputs = list(chain.from_iterable(logging_outputs))
        extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
        return [logging_outputs] + extra_stats_to_sum
Example #16
    def _forward(self, sample, eval=False):
        # prepare model and optimizer
        if eval:
            self.model.eval()
        else:
            self.model.train()
            self.optimizer.zero_grad()

        loss = None
        sample_size = 0
        logging_output = {
            'ntokens': sample['ntokens'] if sample is not None else 0,
            'nsentences':
            sample['target'].size(0) if sample is not None else 0,
        }
        oom = 0
        if sample is not None:
            try:
                with utils.maybe_no_grad(eval):
                    # calculate loss and sample size
                    loss, sample_size, logging_output_ = self.criterion(
                        self.model, sample)
                    logging_output.update(logging_output_)
            except RuntimeError as e:
                if not eval and 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom = 1
                    loss = None
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    raise e

        # synchronize logging outputs for multi-GPU training
        if self.args.distributed_world_size > 1:
            sample_sizes, logging_outputs, ooms = zip(*list(
                distributed_utils.all_gather_list((sample_size, logging_output,
                                                   oom))))
            ooms = sum(ooms)
        else:
            sample_sizes = [sample_size]
            logging_outputs = [logging_output]
            ooms = oom

        return loss, sample_sizes, logging_outputs, ooms
Example #17
    def __init__(self,
                 cfg: FairseqConfig,
                 task,
                 model,
                 criterion,
                 quantizer=None):

        if isinstance(cfg, Namespace):
            logger.warning(
                "argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf"
            )
            cfg = convert_namespace_to_omegaconf(cfg)

        self.cfg = cfg
        self.task = task

        # catalog shared parameters
        shared_params = _catalog_shared_params(model)
        self.tpu = cfg.common.tpu
        self.cuda = (
            torch.cuda.is_available() and not cfg.common.cpu and not self.tpu
        )
        if self.cuda:
            self.device = torch.device("cuda")
        elif self.tpu:
            self.device = utils.get_tpu_device()
        else:
            self.device = torch.device("cpu")

        # copy model and criterion to current device/dtype
        self._criterion = criterion
        self._model = model
        if cfg.common.fp16:
            self._criterion = self._criterion.half()
            self._model = self._model.half()
        elif cfg.common.bf16:
            self._criterion = self._criterion.to(dtype=torch.bfloat16)
            self._model = self._model.to(dtype=torch.bfloat16)
        if not cfg.distributed_training.pipeline_model_parallel:
            self._criterion = self._criterion.to(device=self.device)
            self._model = self._model.to(device=self.device)
        self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel
        self.last_device = None
        if self.cuda and self.pipeline_model_parallel:
            self.last_device = torch.device(
                cfg.distributed_training.pipeline_devices[-1])

        # check that shared parameters are preserved after device transfer
        for shared_param in shared_params:
            ref = _get_module_by_path(self._model, shared_param[0])
            for path in shared_param[1:]:
                logger.info("detected shared parameter: {} <- {}".format(
                    shared_param[0], path))
                _set_module_by_path(self._model, path, ref)

        self._dummy_batch = None  # indicates we don't have a dummy batch at first
        self._lr_scheduler = None
        self._num_updates = 0
        self._num_xla_compiles = 0  # for TPUs
        self._optim_history = None
        self._optimizer = None
        self._warn_once = set()
        self._wrapped_criterion = None
        self._wrapped_model = None

        # TODO(myleott): support tpu
        if self.cuda and self.data_parallel_world_size > 1:
            self._grad_norm_buf = torch.cuda.DoubleTensor(
                self.data_parallel_world_size)
        else:
            self._grad_norm_buf = None

        self.quantizer = quantizer
        if self.quantizer is not None:
            self.quantizer.set_trainer(self)

        # get detailed cuda environment
        if self.cuda:
            self.cuda_env = utils.CudaEnvironment()
            if self.data_parallel_world_size > 1:
                self.cuda_env_arr = distributed_utils.all_gather_list(
                    self.cuda_env, group=distributed_utils.get_global_group())
            else:
                self.cuda_env_arr = [self.cuda_env]
            if self.data_parallel_rank == 0:
                utils.CudaEnvironment.pretty_print_cuda_env_list(
                    self.cuda_env_arr)
        else:
            self.cuda_env = None
            self.cuda_env_arr = None

        metrics.log_start_time("wall", priority=790, round=0)

        self._start_time = time.time()
        self._previous_training_time = 0
        self._cumulative_training_time = None
Example #18
    def train_step(self, samples, dummy_batch=False, raise_oom=False):
        """Do forward, backward and parameter update."""
        if self._dummy_batch is None:
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        if not dummy_batch:
            self.meters["train_wall"].start()

        # forward and backward pass
        logging_outputs, sample_sizes, ooms = [], [], 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                ignore_grad = True
            else:
                ignore_grad = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.args.distributed_world_size > 1
                        and hasattr(self.model, "no_sync")
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size, logging_output = self.task.train_step(
                        sample, self.model, self.criterion, self.optimizer,
                        ignore_grad)

                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_sizes.append(sample_size)

                    if self.fast_stat_sync:
                        self._all_reduce_list[0] += sample_size
                        self._all_reduce_list[1] += logging_output.get(
                            "nsentences", 0.0)
                        self._all_reduce_list[2] += logging_output.get(
                            "loss", 0.0)
                        self._all_reduce_list[3] += logging_output.get(
                            "nll_loss", 0.0)
                        self._all_reduce_list[4] += logging_output.get(
                            "ntokens", 0.0)
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    print(
                        "| WARNING: attempting to recover from OOM in forward/backward pass",
                        file=sys.stderr)
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

            if self.fast_stat_sync:
                self._all_reduce_list[5] += ooms

        if ooms > 0 and self._oom_batch is not None:
            self.handle_ooms(ooms)

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
        if self.fast_stat_sync:
            # rework all_gather_list
            all_reduce_list_tensor = torch.cuda.DoubleTensor(
                self._all_reduce_list)
            if self._sync_stats():
                torch.distributed.all_reduce(all_reduce_list_tensor)
            # Normalize loss and nll_loss by "sample_size"
            # and convert to log base 2
            all_reduce_list_tensor[2:4].div_(
                (all_reduce_list_tensor[0:1] *
                 torch.log(torch.cuda.DoubleTensor([2]))))
            self._all_reduce_list = all_reduce_list_tensor.tolist()
            logging_output = {}
            [
                sample_size,
                logging_output["nsentences"],
                logging_output["loss"],
                logging_output["nll_loss"],
                logging_output["ntokens"],
                ooms,
            ] = self._all_reduce_list
        elif self._sync_stats():
            logging_outputs, sample_sizes, ooms, prev_norms = zip(
                *distributed_utils.all_gather_list([
                    logging_outputs, sample_sizes, ooms, self._prev_grad_norm
                ]))
            logging_outputs = list(chain.from_iterable(logging_outputs))
            sample_sizes = list(chain.from_iterable(sample_sizes))
            ooms = sum(ooms)

            if not self.args.use_bmuf:
                assert all(
                    norm == prev_norms[0] for norm in prev_norms
                ) or all(
                    math.isnan(norm) or math.isinf(norm) for norm in prev_norms
                ), "Fatal error: gradients are inconsistent between workers"

        self.meters["oom"].update(ooms, len(samples))
        if ooms == self.args.distributed_world_size * len(samples):
            print("| WARNING: OOM in all workers, skipping update")
            self.zero_grad()
            return None

        if not self.fast_stat_sync:
            # aggregate logging outputs and sample sizes
            logging_output = self.task.aggregate_logging_outputs(
                logging_outputs, self.get_criterion())
            sample_size = self.task.grad_denom(sample_sizes,
                                               self.get_criterion())

        if not all(k in logging_output for k in ["ntokens", "nsentences"]):
            raise Exception(
                ("Please update the {}.aggregate_logging_outputs() method to "
                 "return ntokens and nsentences").format(
                     self.task.__class__.__name__))

        try:
            # normalize grads by sample size
            if sample_size > 0:
                self.optimizer.multiply_grads(
                    self.args.distributed_world_size / float(sample_size))

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
            self._prev_grad_norm = grad_norm

            # take an optimization step
            self.optimizer.step()
            self.set_num_updates(self.get_num_updates() + 1)

            # task specific update per step
            self.task.update_step(self._num_updates)

            # update meters
            ntokens = logging_output.get("ntokens", 0)
            nsentences = logging_output.get("nsentences", 0)
            self.meters["wps"].update(ntokens)
            self.meters["ups"].update(1.0)
            self.meters["wpb"].update(ntokens)
            self.meters["bsz"].update(nsentences)
            self.meters["gnorm"].update(grad_norm)
            self.meters["clip"].update(1.0 if grad_norm > self.args.clip_norm
                                       and self.args.clip_norm > 0 else 0.0)
            self.meters["train_loss"].update(logging_output.get("loss", 0),
                                             sample_size)
            if "train_acc" in self.meters:
                self.meters["train_acc"].update(logging_output.get("acc", 0),
                                                sample_size)

            if "nll_loss" in logging_output:
                self.meters["train_nll_loss"].update(
                    logging_output.get("nll_loss", 0), ntokens)

            # clear CUDA cache to reduce memory fragmentation
            if (self.args.empty_cache_freq > 0 and
                ((self.get_num_updates() + self.args.empty_cache_freq - 1) %
                 self.args.empty_cache_freq) == 0
                    and torch.cuda.is_available() and not self.args.cpu):
                torch.cuda.empty_cache()
        except OverflowError as e:
            print("| WARNING: overflow detected, " + str(e))
            self.zero_grad()
            logging_output = None
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                print("| ERROR: OOM during optimization, irrecoverable")
            raise e

        if self.args.fp16:
            self.meters["loss_scale"].reset()
            self.meters["loss_scale"].update(self.optimizer.scaler.loss_scale)

        self.clear_buffered_stats()
        self.meters["train_wall"].stop()

        return logging_output
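The fast_stat_sync path above replaces the pickled all_gather_list with a single numeric all_reduce over a packed tensor. A minimal sketch of the follow-up normalization on plain CPU tensors (no process group); the numbers are made up and mirror only the slice arithmetic in the snippet:

import math
import torch

# [sample_size, nsentences, loss, nll_loss, ntokens, ooms] after the all_reduce
reduced = torch.tensor([200.0, 8.0, 460.0, 420.0, 200.0, 0.0], dtype=torch.float64)

# normalize loss and nll_loss by sample_size and convert to log base 2
reduced[2:4] /= reduced[0] * math.log(2)
sample_size, nsentences, loss, nll_loss, ntokens, ooms = reduced.tolist()
print(loss, nll_loss)  # per-token losses in bits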
Example #19
    def train_step(self,
                   samples,
                   dummy_batch=False,
                   assistant=None,
                   assistant_queue=None,
                   weights=None):
        """Do forward, backward and parameter update."""
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.model.train()
        self.zero_grad()

        if not dummy_batch:
            self.meters['train_wall'].start()

        # forward and backward pass
        logging_outputs, sample_sizes, ooms = [], [], 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                ignore_grad = True
            else:
                ignore_grad = False

            try:
                if self.args.distributed_world_size > 1:
                    # Whenever *samples* contains more than one mini-batch, we
                    # want to accumulate gradients locally and only call
                    # all-reduce in the last backwards pass. Currently the
                    # *need_reduction* flag is only supported by
                    # LegacyDistributedDataParallel.
                    if i < len(samples) - 1:
                        self.model.accumulate_grads = True
                    else:
                        self.model.accumulate_grads = False

                # forward and backward
                if self.args.assistant:
                    losses, sample_size, logging_output, precisions = self.task.train_step(
                        sample, self.model, self.criterion, self.optimizer,
                        ignore_grad)
                elif self.args.spl:
                    losses, sample_size, logging_output, precisions = self.task.train_step(
                        sample,
                        self.model,
                        self.criterion,
                        self.optimizer,
                        ignore_grad,
                        lambda_t=self.lambda_t)
                else:
                    losses, sample_size, logging_output = self.task.train_step(
                        sample, self.model, self.criterion, self.optimizer,
                        ignore_grad)
                # record new losses
                if self.args.spl and not dummy_batch:
                    y_lengths = utils.get_len(
                        sample['target'].cpu().numpy(),
                        self.task.target_dictionary.pad())
                    norm_losses = np.divide(losses.detach().cpu().numpy(),
                                            y_lengths)
                    self.loss_chart[
                        sample['id'].cpu().numpy()] = torch.from_numpy(
                            norm_losses, ).type(torch.FloatTensor).cuda()
                    if self.args.distributed_world_size > 1:
                        all_reduce(self.loss_chart, op=MIN_OP)

                # prepare data for assistant training
                if assistant is not None and np.random.rand() < SEC_TRAIN_RATIO:
                    sec_batch_size = sample['id'].size(0)
                    indices_sec = np.random.choice(sample['id'].size(0),
                                                   sec_batch_size)
                    x = sample['net_input']['src_tokens'][indices_sec]
                    y = sample['target'][indices_sec]
                    l = losses[indices_sec]
                    x = x.cpu().numpy()
                    y = y.cpu().numpy()
                    l = l.detach().cpu().numpy()
                    keep_probs = assistant.train_step(x, y, l)
                elif assistant_queue is not None and np.random.rand() < SEC_TRAIN_RATIO:
                    sec_batch_size = sample['id'].size(0)
                    local_indices_sec = np.random.choice(
                        sample['id'].size(0), sec_batch_size)
                    global_indices_sec = sample['id'][local_indices_sec].cpu().numpy()

                    l = losses[local_indices_sec]
                    l = l.detach().cpu().numpy()
                    if not assistant_queue.full():
                        assistant_queue.put((global_indices_sec, l),
                                            block=False)
                    else:
                        _ = assistant_queue.get()
                        assistant_queue.put((global_indices_sec, l),
                                            block=False)

                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_sizes.append(sample_size)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    ooms += 1
                    self.zero_grad()
                else:
                    print(sample, flush=True, force=True)
                    raise e

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_outputs, sample_sizes, ooms = zip(
                *distributed_utils.all_gather_list(
                    [logging_outputs, sample_sizes, ooms], ))
            logging_outputs = list(chain.from_iterable(logging_outputs))
            sample_sizes = list(chain.from_iterable(sample_sizes))
            ooms = sum(ooms)

        if ooms == self.args.distributed_world_size * len(samples):
            print('| WARNING: OOM in all workers, skipping update')
            self.zero_grad()
            return None

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_outputs, self.criterion)
        sample_size = self.task.grad_denom(sample_sizes, self.criterion)

        if not all(k in logging_output for k in ['ntokens', 'nsentences']):
            raise Exception(
                ('Please update the {}.aggregate_logging_outputs() method to '
                 'return ntokens and nsentences').format(
                     self.task.__class__.__name__))

        try:
            # normalize grads by sample size
            self.optimizer.multiply_grads(self.args.distributed_world_size /
                                          float(sample_size))

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

            # take an optimization step
            self.optimizer.step()
            self._num_updates += 1

            # update learning rate
            self.lr_scheduler.step_update(self._num_updates)

            # update meters
            ntokens = logging_output.get('ntokens', 0)
            nsentences = logging_output.get('nsentences', 0)
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            self.meters['gnorm'].update(grad_norm)
            self.meters['clip'].update(1. if grad_norm > self.args.clip_norm
                                       and self.args.clip_norm > 0 else 0.)
            self.meters['oom'].update(ooms)
            self.meters['train_loss'].update(logging_output.get('loss', 0),
                                             sample_size)
            if 'nll_loss' in logging_output:
                self.meters['train_nll_loss'].update(
                    logging_output.get('nll_loss', 0), ntokens)
        except OverflowError as e:
            print('| WARNING: overflow detected, ' + str(e))
            self.zero_grad()
            logging_output = None

        if self.args.fp16:
            self.meters['loss_scale'].reset()
            self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

        self.meters['train_wall'].stop()

        return logging_output
Example #20
def main(cfg: DictConfig, override_args=None):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    assert (
        cfg.dataset.max_tokens is not None
        or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    if cfg.distributed_training.distributed_world_size > 1:
        data_parallel_world_size = distributed_utils.get_data_parallel_world_size()
        data_parallel_rank = distributed_utils.get_data_parallel_rank()
    else:
        data_parallel_world_size = 1
        data_parallel_rank = 0

    if override_args is not None:
        overrides = vars(override_args)
        overrides.update(eval(getattr(override_args, "model_overrides", "{}")))
    else:
        overrides = None

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        [cfg.common_eval.path],
        arg_overrides=overrides,
        suffix=cfg.checkpoint.checkpoint_suffix,
    )
    model = models[0]

    # Move models to GPU
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Print args
    logger.info(saved_cfg)

    # Build criterion
    criterion = task.build_criterion(saved_cfg.criterion)
    criterion.eval()

    for subset in cfg.dataset.valid_subset.split(","):
        try:
            task.load_dataset(subset,
                              combine=False,
                              epoch=1,
                              task_cfg=saved_cfg.task)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception("Cannot find dataset: " + subset)

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=cfg.dataset.max_tokens,
            max_sentences=cfg.dataset.batch_size,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
            seed=cfg.common.seed,
            num_shards=data_parallel_world_size,
            shard_id=data_parallel_rank,
            num_workers=cfg.dataset.num_workers,
            data_buffer_size=cfg.dataset.data_buffer_size,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else
                                "simple"),
        )

        log_outputs = []
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(
                sample, model, criterion)
            progress.log(log_output, step=i)
            log_outputs.append(log_output)

        if data_parallel_world_size > 1:
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=cfg.common.all_gather_list_size,
                group=distributed_utils.get_data_parallel_group(),
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        progress.print(log_output, tag=subset, step=i)
Example #21
    def train_step(self, sample, update_params=True, last_step=False):
        """Do forward, backward and parameter update."""
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.model.train()
        if isinstance(self.model, DDP):
            if last_step:
                self.model.disable_allreduce()
            else:
                self.model.enable_allreduce()

        # forward and backward pass
        sample, sample_size = self._prepare_sample(sample)
        loss, oom_fwd = self._forward(sample)

        # If this is a last batch forward pass is skipped on some workers
        # Batch with sample_size 0 is not accounted for in weighted loss
        logging_output = {
            'ntokens': sample['ntokens'] if sample is not None else 0,
            'nsentences': sample['target'].size(0) if sample is not None else 0,
            'loss': utils.item(loss.data) if loss is not None else 0,
            'sample_size': sample_size
        }
        oom_bwd = self._backward(loss)

        # buffer stats and logging outputs
        self._buffered_stats['sample_sizes'].append(sample_size)
        self._buffered_stats['logging_outputs'].append(logging_output)
        self._buffered_stats['ooms_fwd'].append(oom_fwd)
        self._buffered_stats['ooms_bwd'].append(oom_bwd)

        # update parameters 
        if update_params and not last_step:
            # gather logging outputs from all replicas
            sample_sizes = self._buffered_stats['sample_sizes']
            logging_outputs = self._buffered_stats['logging_outputs']
            ooms_fwd = self._buffered_stats['ooms_fwd']
            ooms_bwd = self._buffered_stats['ooms_bwd']
            if self.args.distributed_world_size > 1:
                sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
                    lambda l: list(chain.from_iterable(l)),
                    zip(*distributed_utils.all_gather_list(
                        (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
                    ))
                )
            ooms_fwd = sum(ooms_fwd)
            ooms_bwd = sum(ooms_bwd)
            ooms = ooms_fwd + ooms_bwd  # this is always <= distributed_world_size

            if ooms == self.args.distributed_world_size:
                print('| WARNING: OOM in all workers, skipping batch')
                self.zero_grad()
                return

            # aggregate stats and logging outputs
            grad_denom = sum(sample_sizes)
            for p in self.model.parameters():
                if p.requires_grad and p.grad is not None:
                    p.grad /= grad_denom

            self._opt()

            # Handle logging
            sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
            ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
            self.throughput_meter.update(ntokens)
            info_log_data = {
                'tokens/s': self.throughput_meter.avg,
                'tokens': ntokens,
                'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2)
            }
            debug_log_data = {
                'batch_size': sum(log.get('nsentences', 0) for log in logging_outputs),
                'lr': self.get_lr(),
                'grad_denom': grad_denom,
                'updates': 1
            }

            DLLogger.log(step=self._num_updates, data=info_log_data, verbosity=0)
            DLLogger.log(step=self._num_updates, data=debug_log_data, verbosity=1)

            self.clear_buffered_stats()
Example #22
    def train_step(self, sample, update_params=True):
        """Do forward, backward and parameter update."""
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        # forward and backward pass
        sample = self._prepare_sample(sample)
        loss, sample_size, logging_output, oom_fwd = self._forward(sample)
        oom_bwd = self._backward(loss)

        # buffer stats and logging outputs
        self._buffered_stats['sample_sizes'].append(sample_size)
        self._buffered_stats['logging_outputs'].append(logging_output)
        self._buffered_stats['ooms_fwd'].append(oom_fwd)
        self._buffered_stats['ooms_bwd'].append(oom_bwd)

        # update parameters
        if update_params:
            # gather logging outputs from all replicas
            sample_sizes = self._buffered_stats['sample_sizes']
            logging_outputs = self._buffered_stats['logging_outputs']
            ooms_fwd = self._buffered_stats['ooms_fwd']
            ooms_bwd = self._buffered_stats['ooms_bwd']
            if self.args.distributed_world_size > 1:
                sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
                    lambda l: list(chain.from_iterable(l)),
                    zip(*distributed_utils.all_gather_list((sample_sizes,
                                                            logging_outputs,
                                                            ooms_fwd,
                                                            ooms_bwd))))
            ooms_fwd = sum(ooms_fwd)
            ooms_bwd = sum(ooms_bwd)

            if ooms_fwd == self.args.distributed_world_size:
                print('| WARNING: OOM in all workers, skipping batch')
                self.zero_grad()
                return None

            # aggregate stats and logging outputs
            ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
            nsentences = sum(
                log.get('nsentences', 0) for log in logging_outputs)
            agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(
                logging_outputs)
            grad_denom = self.criterion.__class__.grad_denom(sample_sizes)

            try:
                # all-reduce and rescale gradients, then take an optimization step
                grad_norm = self._all_reduce_and_rescale(grad_denom)
                self._opt()

                # update meters
                self.meters['wps'].update(ntokens)
                self.meters['ups'].update(1.)
                self.meters['wpb'].update(ntokens)
                self.meters['bsz'].update(nsentences)
                if grad_norm is not None:
                    self.meters['gnorm'].update(grad_norm)
                    self.meters['clip'].update(
                        1. if grad_norm > self.args.clip_norm else 0.)
                self.meters['oom'].update(ooms_fwd + ooms_bwd)

                # update loss meters for training
                if 'loss' in agg_logging_output:
                    self.meters['train_loss'].update(
                        agg_logging_output['loss'], grad_denom)
                # criterions can optionally log the NLL loss too
                if 'nll_loss' in agg_logging_output:
                    self.meters['train_nll_loss'].update(
                        agg_logging_output['nll_loss'], ntokens)
            except OverflowError as e:
                self.zero_grad()
                print('| WARNING: overflow detected, ' + str(e))

            self.clear_buffered_stats()

            return agg_logging_output
        else:
            return None  # buffering updates
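
The distributed branch above relies on the `zip(*all_gather_list(...))` plus `chain.from_iterable` idiom to turn the per-worker tuples of buffered lists back into flat, per-field lists. A minimal standalone sketch of that transformation, with plain lists standing in for `distributed_utils.all_gather_list` (assumed here to return one entry per worker):

from itertools import chain

# pretend all_gather_list returned one
# (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd) tuple per worker
gathered = [
    ([2, 3], [{'loss': 1.0}, {'loss': 2.0}], [0], [0]),  # worker 0
    ([4],    [{'loss': 0.5}],                [1], [0]),  # worker 1
]

# zip(*...) groups each field across workers; chain.from_iterable then
# flattens each field into a single list, exactly as in train_step above
sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
    lambda l: list(chain.from_iterable(l)), zip(*gathered))

assert sample_sizes == [2, 3, 4]
assert sum(ooms_fwd) == 1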
예제 #23
    def valid_step(self, sample, raise_oom=False):
        """Do forward pass in evaluation mode."""
        with torch.no_grad():
            self.model.eval()
            self.criterion.eval()
            if self.args.float_valid:
                self.model.float()
                self.criterion.float()

            sample = self._prepare_sample(sample, force_float=self.args.float_valid)
            if sample is None:
                sample = self._prepare_sample(self._dummy_batch)
                ignore_results = True
            else:
                ignore_results = False

            try:
                _loss, sample_size, logging_output = self.task.valid_step(
                    sample, self.model, self.criterion
                )

                src_target_hypo_str = self._get_decoding(sample)
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if not raise_oom:
                        print(
                            "| WARNING: ran out of memory in validation step, retrying batch"
                        )
                        for p in self.model.parameters():
                            if p.grad is not None:
                                p.grad = None  # free some memory
                        if self.cuda:
                            torch.cuda.empty_cache()
                        return self.valid_step(sample, raise_oom=True)
                raise e

            if ignore_results:
                logging_output, sample_size = {}, 0
                src_target_hypo_str = []

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_output, sample_size, src_target_hypo_str = zip(
                *distributed_utils.all_gather_list(
                    [logging_output, sample_size, src_target_hypo_str],
                    max_size=getattr(self.args, 'all_gather_list_size', 16384),
                )
            )
            logging_output = list(logging_output)
            sample_size = list(sample_size)

        else:
            logging_output = [logging_output]
            sample_size = [sample_size]
            src_target_hypo_str = [src_target_hypo_str]

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_output, self.get_criterion()
        )
        sample_size = self.task.grad_denom(sample_size, self.get_criterion())

        # update meters for validation
        ntokens = logging_output.get("ntokens", 0)
        self.meters["valid_loss"].update(logging_output.get("loss", 0), sample_size)
        if "valid_acc" in self.meters:
            self.meters["valid_acc"].update(logging_output.get("acc", 0), sample_size)

        if "nll_loss" in logging_output:
            self.meters["valid_nll_loss"].update(
                logging_output.get("nll_loss", 0), ntokens
            )

        logging_output['src_target_hypo_str'] = src_target_hypo_str
        if self.args.float_valid:
            self.model.half()
            self.criterion.half()

        return logging_output
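
A note on the `float_valid` round-trip above: `.float()` and `.half()` convert a module's parameters in place and return the same module, so validating in fp32 and then switching back to fp16 operates on a single model object. A minimal sketch in plain PyTorch (no fairseq):

import torch

model = torch.nn.Linear(4, 2).half()
assert next(model.parameters()).dtype == torch.float16

model.float()   # the fp32 validation pass would run here
assert next(model.parameters()).dtype == torch.float32

model.half()    # back to fp16 for training
assert next(model.parameters()).dtype == torch.float16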
예제 #24
def main(args, override_args=None):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    use_fp16 = args.fp16
    use_cuda = torch.cuda.is_available() and not args.cpu

    if use_cuda:
        torch.cuda.set_device(args.device_id)

    if override_args is not None:
        overrides = vars(override_args)
        overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
    else:
        overrides = None

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path],
        arg_overrides=overrides,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )
    model = models[0]

    # Move models to GPU
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Print args
    logger.info(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args)
    criterion.eval()

    for subset in args.valid_subset.split(','):
        try:
            task.load_dataset(subset, combine=False, epoch=1)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception('Cannot find dataset: ' + subset)

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
            data_buffer_size=args.data_buffer_size,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=('tqdm'
                                if not args.no_progress_bar else 'simple'),
        )

        log_outputs = []
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(
                sample, model, criterion)
            progress.log(log_output, step=i)
            log_outputs.append(log_output)

        if args.distributed_world_size > 1:
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=getattr(args, 'all_gather_list_size', 16384),
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        progress.print(log_output, tag=subset, step=i)
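
The `num_shards`/`shard_id` arguments above split the validation batches across workers, so each rank only evaluates its own shard before the per-rank `log_outputs` are gathered and flattened. The real `get_batch_iterator` is more involved; a toy round-robin version just to illustrate the idea:

from itertools import islice

def shard(batches, num_shards, shard_id):
    # each rank keeps every num_shards-th batch, starting at its own offset
    return islice(batches, shard_id, None, num_shards)

batches = list(range(10))
assert list(shard(batches, num_shards=4, shard_id=1)) == [1, 5, 9]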
예제 #25
    def train_step(self, samples, dummy_batch=False, raise_oom=False):
        """Do forward, backward and parameter update."""
        if self._dummy_batch is None:
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        if not dummy_batch:
            self.meters['train_wall'].start()

        # forward and backward pass
        logging_outputs, sample_sizes, ooms = [], [], 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                ignore_grad = True
            else:
                ignore_grad = False

            try:
                if self.args.distributed_world_size > 1:
                    # Whenever *samples* contains more than one mini-batch, we
                    # want to accumulate gradients locally and only call
                    # all-reduce in the last backwards pass. Currently the
                    # *accumulate_grads* flag is only supported by
                    # LegacyDistributedDataParallel.
                    if i < len(samples) - 1:
                        self.model.accumulate_grads = True
                    else:
                        self.model.accumulate_grads = False

                # forward and backward
                loss, sample_size, logging_output = self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer,
                    ignore_grad
                )

                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_sizes.append(sample_size)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    msg = (
                        '| WARNING: ran out of memory with exception: '
                        + '{};'.format(e)
                        + '\n Skipping batch'
                    )
                    # TODO: print should really go to logger, this print goes
                    # to stdout, which is buffered, which in many case is not
                    # printed out if another exception happens
                    # print(msg)
                    print(msg, file=sys.stderr)
                    if raise_oom:
                        raise ValueError(msg)
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

        if ooms > 0 and self._oom_batch is not None:
            self.handle_ooms(ooms)

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1 and (
            (not self.args.use_bmuf)
            or (
                self.args.use_bmuf
                and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
            )
        ):
            logging_outputs, sample_sizes, ooms, prev_norms = \
                zip(*distributed_utils.all_gather_list(
                    [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
                ))
            logging_outputs = list(chain.from_iterable(logging_outputs))
            sample_sizes = list(chain.from_iterable(sample_sizes))
            ooms = sum(ooms)

            if not self.args.use_bmuf:
                assert (
                    all(norm == prev_norms[0] for norm in prev_norms)
                    or all(math.isnan(norm) or math.isinf(norm) for norm in prev_norms)
                ), 'Fatal error: gradients are inconsistent between workers'

        self.meters['oom'].update(ooms, len(samples))
        if ooms == self.args.distributed_world_size * len(samples):
            print('| WARNING: OOM in all workers, skipping update')
            self.zero_grad()
            return None

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_outputs, self.criterion
        )
        sample_size = self.task.grad_denom(sample_sizes, self.criterion)

        if not all(k in logging_output for k in ['ntokens', 'nsentences']):
            raise Exception((
                'Please update the {}.aggregate_logging_outputs() method to '
                'return ntokens and nsentences'
            ).format(self.task.__class__.__name__))

        try:
            # normalize grads by sample size
            if sample_size > 0:
                self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
            self._prev_grad_norm = grad_norm

            # take an optimization step
            self.optimizer.step()
            self.set_num_updates(self.get_num_updates() + 1)

            # task specific update per step
            self.task.update_step(self._num_updates)

            # update meters
            ntokens = logging_output.get('ntokens', 0)
            nsentences = logging_output.get('nsentences', 0)
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            self.meters['gnorm'].update(grad_norm)
            self.meters['clip'].update(
                1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
            )
            self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
            if 'train_acc' in self.meters:
                self.meters['train_acc'].update(
                    logging_output.get('acc', 0), sample_size)

            if 'nll_loss' in logging_output:
                self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
        except OverflowError as e:
            print('| WARNING: overflow detected, ' + str(e))
            self.zero_grad()
            logging_output = None

        if self.args.fp16:
            self.meters['loss_scale'].reset()
            self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

        self.meters['train_wall'].stop()

        return logging_output
예제 #26
    def __init__(self, args, task, model, criterion, quantizer=None):
        self.args = args
        self.task = task

        # catalog shared parameters
        shared_params = _catalog_shared_params(model)

        self.tpu = getattr(args, 'tpu', False)
        self.cuda = torch.cuda.is_available() and not args.cpu and not self.tpu
        if self.cuda:
            self.device = torch.device('cuda')
        elif self.tpu:
            self.device = utils.get_tpu_device(args)
        else:
            self.device = torch.device('cpu')

        # copy model and criterion to current device/dtype
        self._criterion = criterion
        self._model = model
        if self.tpu:
            import torch_xla.core.xla_model as xm
            self._model = xm.send_cpu_data_to_device(self._model, self.device)
        if args.fp16:
            self._criterion = self._criterion.half()
            self._model = self._model.half()
        elif args.bf16:
            self._criterion = self._criterion.to(dtype=torch.bfloat16)
            self._model = self._model.to(dtype=torch.bfloat16)
        self._criterion = self._criterion.to(device=self.device)
        self._model = self._model.to(device=self.device)

        # check that shared parameters are preserved after device transfer
        for shared_param in shared_params:
            ref = _get_module_by_path(self._model, shared_param[0])
            for path in shared_param[1:]:
                logger.info(
                    'detected shared parameter: {} <- {}'.format(shared_param[0], path)
                )
                _set_module_by_path(self._model, path, ref)

        self._dummy_batch = "DUMMY"  # indicates we don't have a dummy batch at first
        self._lr_scheduler = None
        self._num_updates = 0
        self._num_xla_compiles = 0  # for TPUs
        self._optim_history = None
        self._optimizer = None
        self._warn_once = set()
        self._wrapped_criterion = None
        self._wrapped_model = None

        # TODO(myleott): support tpu
        if self.cuda and self.data_parallel_world_size > 1:
            self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size)
        else:
            self._grad_norm_buf = None

        self.quantizer = quantizer
        if self.quantizer is not None:
            self.quantizer.set_trainer(self)

        # get detailed cuda environment
        if self.cuda:
            self.cuda_env = utils.CudaEnvironment()
            if self.data_parallel_world_size > 1:
                self.cuda_env_arr = distributed_utils.all_gather_list(self.cuda_env)
            else:
                self.cuda_env_arr = [self.cuda_env]
            if self.data_parallel_rank == 0:
                utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr)
        else:
            self.cuda_env = None
            self.cuda_env_arr = None

        metrics.log_start_time("wall", priority=790, round=0)

        self._start_time = time.time()
        self._previous_training_time = 0
        self._cumulative_training_time = None
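
The shared-parameter handling above assumes helpers such as `_catalog_shared_params` and `_set_module_by_path`, which are not shown here. A small standalone sketch of the underlying idea, grouping parameter paths by object identity to find tied weights (illustration only, not the fairseq helper):

import torch
from collections import defaultdict

class TiedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Linear(8, 4, bias=False)
        self.out = torch.nn.Linear(8, 4, bias=False)
        self.out.weight = self.embed.weight  # tie the two weights

model = TiedModel()
groups = defaultdict(list)
for mod_name, mod in model.named_modules():
    for p_name, p in mod.named_parameters(recurse=False):
        path = f"{mod_name}.{p_name}" if mod_name else p_name
        groups[id(p)].append(path)

# parameters reachable under more than one path are shared
print([paths for paths in groups.values() if len(paths) > 1])
# -> [['embed.weight', 'out.weight']]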
예제 #27
    def train_step(self, samples, dummy_batch=False):
        """Do forward, backward and parameter update."""
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.model.train()
        self.zero_grad()

        if not dummy_batch:
            self.meters['train_wall'].start()

        # forward and backward pass
        logging_outputs, sample_sizes, ooms = [], [], 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                ignore_grad = True
            else:
                ignore_grad = False

            try:
                if self.args.distributed_world_size > 1:
                    # only all-reduce gradients in the last backwards pass,
                    # so set need_reduction before running forward/backward
                    if i < len(samples) - 1:
                        self.model.need_reduction = False
                    else:
                        self.model.need_reduction = True

                # forward and backward
                loss, sample_size, logging_output = self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer,
                    ignore_grad
                )

                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_sizes.append(sample_size)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_outputs, sample_sizes, ooms = zip(*distributed_utils.all_gather_list(
                [logging_outputs, sample_sizes, ooms],
            ))
            logging_outputs = list(chain.from_iterable(logging_outputs))
            sample_sizes = list(chain.from_iterable(sample_sizes))
            ooms = sum(ooms)

        if ooms == self.args.distributed_world_size:
            print('| WARNING: OOM in all workers, skipping update')
            self.zero_grad()
            return None

        # aggregate logging outputs and sample sizes
        sample_size = self.task.grad_denom(sample_sizes, self.criterion)
        logging_output = self.task.aggregate_logging_outputs(
            logging_outputs, self.criterion
        )

        if not all(k in logging_output for k in ['ntokens', 'nsentences']):
            raise Exception((
                'Please update the {}.aggregate_logging_outputs() method to '
                'return ntokens and nsentences'
            ).format(self.task.__class__.__name__))

        try:
            self._num_iterations += 1
            # normalize grads by sample size
            self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

            # take an optimization step
            self.optimizer.step()
            self._num_updates += 1

            # update learning rate
            self.lr_scheduler.step_update(self._num_updates)

            # update meters
            ntokens = logging_output.get('ntokens', 0)
            nsentences = logging_output.get('nsentences', 0)
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            self.meters['gnorm'].update(grad_norm)
            self.meters['clip'].update(
                1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
            )
            self.meters['oom'].update(ooms)
            self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
            if 'nll_loss' in logging_output:
                self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
        except OverflowError as e:
            print('| WARNING: overflow detected, ' + str(e))
            self.zero_grad()
            logging_output = None

        if self.args.fp16:
            self.meters['loss_scale'].reset()
            self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

        self.meters['train_wall'].stop()

        return logging_output
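
The per-update seeding at the top of this `train_step` (base seed plus the update counter) is what makes random draws reproducible after resuming from a checkpoint at the same update count. A quick standalone check of that property:

import torch

def draw(base_seed, num_updates):
    torch.manual_seed(base_seed + num_updates)
    return torch.rand(3)

assert torch.equal(draw(1, 100), draw(1, 100))      # same update -> same draw
assert not torch.equal(draw(1, 100), draw(1, 101))  # next update -> new draw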
예제 #28
    def train_step(self, samples, dummy_batch=False, domain=None):
        """Do forward, backward and parameter update."""
        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        if not dummy_batch:
            self.meters['train_wall'].start()

        # forward and backward pass
        logging_outputs, sample_sizes, ooms = [], [], 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                ignore_grad = True
            else:
                ignore_grad = False

            try:
                if self.args.distributed_world_size > 1:
                    # Whenever *samples* contains more than one mini-batch, we
                    # want to accumulate gradients locally and only call
                    # all-reduce in the last backwards pass. Currently the
                    # *accumulate_grads* flag is only supported by
                    # LegacyDistributedDataParallel.
                    if i < len(samples) - 1:
                        self.model.accumulate_grads = True
                    else:
                        self.model.accumulate_grads = False

                # forward and backward
                loss, sample_size, logging_output = self.task.train_step(
                    sample,
                    self.model,
                    self.criterion,
                    self.optimizer,
                    ignore_grad,
                    domain=domain,
                    num_updates=self._num_updates)

                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_sizes.append(sample_size)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

        if ooms > 0 and self._oom_batch is not None:
            self.handle_ooms(ooms)

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_outputs, sample_sizes, ooms, prev_norms = \
                zip(*distributed_utils.all_gather_list(
                    [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
                ))
            logging_outputs = list(chain.from_iterable(logging_outputs))
            sample_sizes = list(chain.from_iterable(sample_sizes))
            ooms = sum(ooms)
            assert all(norm == prev_norms[0] for norm in prev_norms), \
                'Fatal error: gradients are inconsistent between workers'

        self.meters['oom'].update(ooms, len(samples))
        if ooms == self.args.distributed_world_size * len(samples):
            print('| WARNING: OOM in all workers, skipping update')
            self.zero_grad()
            return None

        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_outputs, self.criterion)
        sample_size = self.task.grad_denom(sample_sizes, self.criterion)

        if not all(k in logging_output for k in ['ntokens', 'nsentences']):
            raise Exception(
                ('Please update the {}.aggregate_logging_outputs() method to '
                 'return ntokens and nsentences').format(
                     self.task.__class__.__name__))

        try:
            # normalize grads by sample size
            self.optimizer.multiply_grads(self.args.distributed_world_size /
                                          float(sample_size))

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
            self._prev_grad_norm = grad_norm

            # take an optimization step
            self.optimizer.step()
            self._num_updates += 1

            # update learning rate
            self.lr_scheduler.step_update(self._num_updates)

            # update meters
            ntokens = logging_output.get('ntokens', 0)
            nsentences = logging_output.get('nsentences', 0)
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            self.meters['gnorm'].update(grad_norm)
            self.meters['clip'].update(1. if grad_norm > self.args.clip_norm
                                       and self.args.clip_norm > 0 else 0.)
            self.meters['train_loss'].update(logging_output.get('loss', 0),
                                             sample_size)
            if 'nll_loss' in logging_output:
                self.meters['train_nll_loss'].update(
                    logging_output.get('nll_loss', 0), ntokens)
        except OverflowError as e:
            print('| WARNING: overflow detected, ' + str(e))
            self.zero_grad()
            logging_output = None

        if self.args.fp16:
            self.meters['loss_scale'].reset()
            self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

        self.meters['train_wall'].stop()

        return logging_output
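
On `multiply_grads(distributed_world_size / sample_size)` above: assuming the DDP wrapper averages gradients over workers during the all-reduce, multiplying by the world size recovers the summed gradient, and dividing by the total `sample_size` (summed across workers) gives a per-sample gradient. A toy numeric check of that rescaling:

world_size = 4
per_worker_grad_sum = [8.0, 6.0, 10.0, 12.0]  # summed per-example grads on each worker
sample_size = 36                              # total sample size across all workers

averaged = sum(per_worker_grad_sum) / world_size   # what an averaging all-reduce yields
rescaled = averaged * (world_size / sample_size)   # the multiply_grads factor
assert abs(rescaled - sum(per_worker_grad_sum) / sample_size) < 1e-12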
예제 #29
def compute_bleu(reference_corpus,
                 translation_corpus,
                 args,
                 max_order=4,
                 use_bp=True):
    """Computes BLEU score of translated segments against one or more references.

  Args:
    reference_corpus: list of references for each translation. Each
        reference should be tokenized into a list of tokens.
    translation_corpus: list of translations to score. Each translation
        should be tokenized into a list of tokens.
    args: CLI arguments
    max_order: Maximum n-gram order to use when computing BLEU score.
    use_bp: boolean, whether to apply brevity penalty.

  Returns:
    BLEU score.
  """
    reference_length = 0
    translation_length = 0
    bp = 1.0
    geo_mean = 0

    matches_by_order = [0] * max_order
    possible_matches_by_order = [0] * max_order

    for (references, translations) in zip(reference_corpus,
                                          translation_corpus):
        reference_length += len(references)
        translation_length += len(translations)
        ref_ngram_counts = _get_ngrams_with_counter(references, max_order)
        translation_ngram_counts = _get_ngrams_with_counter(
            translations, max_order)

        overlap = dict((ngram, min(count, translation_ngram_counts[ngram]))
                       for ngram, count in ref_ngram_counts.items())

        for ngram in overlap:
            matches_by_order[len(ngram) - 1] += overlap[ngram]
        for ngram in translation_ngram_counts:
            possible_matches_by_order[len(ngram) -
                                      1] += translation_ngram_counts[ngram]

    precisions = [0] * max_order
    smooth = 1.0

    # do reductions of matches_by_order and possible_matches_by_order
    if args.distributed_world_size > 1:
        stats = RefBleuStats(matches_by_order, possible_matches_by_order,
                             reference_length, translation_length)
        all_stats = distributed_utils.all_gather_list(stats)
        stats = reduce(lambda a, b: a + b, all_stats)
        matches_by_order = stats.matches_by_order
        possible_matches_by_order = stats.possible_matches_by_order
        reference_length = stats.reference_length
        translation_length = stats.translation_length

    for i in range(0, max_order):
        if possible_matches_by_order[i] > 0:
            if matches_by_order[i] > 0:
                precisions[i] = float(
                    matches_by_order[i]) / possible_matches_by_order[i]
            else:
                smooth *= 2
                precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
        else:
            precisions[i] = 0.0

    if max(precisions) > 0:
        p_log_sum = sum(math.log(p) for p in precisions if p)
        geo_mean = math.exp(p_log_sum / max_order)

    if use_bp:
        if reference_length > 0:
            ratio = translation_length / reference_length
            bp = math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0
        else:
            bp = 1.0
    bleu = geo_mean * bp
    return np.float32(bleu) * 100.0
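
A self-contained, single-process walk-through of the same BLEU arithmetic on one sentence pair, using `collections.Counter` in place of the `_get_ngrams_with_counter` helper assumed above (the smoothing path is never hit here because every precision is non-zero):

import math
from collections import Counter

def ngrams(tokens, max_order):
    counts = Counter()
    for n in range(1, max_order + 1):
        for i in range(len(tokens) - n + 1):
            counts[tuple(tokens[i:i + n])] += 1
    return counts

ref = 'the cat sat on the mat'.split()
hyp = 'the cat sat on mat'.split()
max_order = 2

ref_counts, hyp_counts = ngrams(ref, max_order), ngrams(hyp, max_order)
matches = [0] * max_order
possible = [0] * max_order
for ng, c in hyp_counts.items():
    possible[len(ng) - 1] += c
    matches[len(ng) - 1] += min(c, ref_counts[ng])

precisions = [m / p for m, p in zip(matches, possible)]   # [1.0, 0.75]
geo_mean = math.exp(sum(math.log(p) for p in precisions) / max_order)
ratio = len(hyp) / len(ref)
bp = math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0
print(round(geo_mean * bp * 100, 2))   # ~70.9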
예제 #30
    def valid_step(self, sample, raise_oom=False):
        """Do forward pass in evaluation mode."""
        with torch.no_grad():
            self.model.eval()

            sample = self._prepare_sample(sample)
            if sample is None:
                sample = self._prepare_sample(self._dummy_batch)
                ignore_results = True
            else:
                ignore_results = False

            try:
                _loss, sample_size, logging_output = self.task.get_loss(
                    self.model, self.criterion, sample, is_valid=True)
            except RuntimeError as e:
                if 'out of memory' in str(e) and not raise_oom:
                    print('| WARNING: ran out of memory, retrying batch')
                    for p in self.model.parameters():
                        if p.grad is not None:
                            del p.grad  # free some memory
                    torch.cuda.empty_cache()
                    return self.valid_step(sample, raise_oom=True)
                else:
                    raise e

            if ignore_results:
                logging_output, sample_size = {}, 0

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_output, sample_size = zip(
                *distributed_utils.all_gather_list(
                    [logging_output, sample_size], ))
            logging_output = list(logging_output)
            sample_size = list(sample_size)
        else:
            logging_output = [logging_output]
            sample_size = [sample_size]

        # extra hacky!
        if hasattr(self.task, 'aggregate_extra_metrics'):
            extra_metrics = self.task.aggregate_extra_metrics(logging_output)
        else:
            extra_metrics = None

        # aggregate logging outputs and sample sizes
        logging_output = self.criterion._aggregate_logging_outputs(
            logging_output)
        sample_size = self.criterion.__class__.grad_denom(sample_size)

        if extra_metrics is not None:
            logging_output['extra_metrics'] = extra_metrics

        # update meters for validation
        ntokens = logging_output.get('ntokens', 0)
        self.meters['valid_loss'].update(logging_output.get('loss', 0),
                                         sample_size)
        if 'nll_loss' in logging_output:
            self.meters['valid_nll_loss'].update(
                logging_output.get('nll_loss', 0), ntokens)

        if 'extra_metrics' in logging_output:
            for n, m in self.meters['task'].items():
                m.update(*logging_output['extra_metrics'][n])

        return logging_output
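
The `(value, sample_size)` meter updates above matter because losses from different replicas carry different weights: an average-style meter fed `(val, n)` pairs produces a sample_size-weighted mean rather than a plain mean. A two-replica toy:

per_replica = [(2.0, 10), (3.0, 30)]  # (loss, sample_size) per replica
total = sum(v * n for v, n in per_replica)
weight = sum(n for _, n in per_replica)
print(total / weight)  # 2.75 (weighted), versus an unweighted mean of 2.5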
예제 #31
    def train_step(self, samples, dummy_batch=False, raise_oom=False):
        """Do forward, backward and parameter update."""
        if self._dummy_batch is None:
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        if not dummy_batch:
            self.meters['train_wall'].start()

        # forward and backward pass
        logging_outputs, sample_sizes, ooms = [], [], 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                ignore_grad = True
            else:
                ignore_grad = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (
                    self.args.distributed_world_size > 1
                    and hasattr(self.model, 'no_sync')
                    and i < len(samples) - 1
                ):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size, logging_output = self.task.train_step(
                        sample, self.model, self.criterion, self.optimizer,
                        ignore_grad
                    )

                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_sizes.append(sample_size)

                    if self.fast_stat_sync:
                        self._all_reduce_list[0] += sample_size
                        self._all_reduce_list[1] += logging_output.get('nsentences', 0.0)
                        self._all_reduce_list[2] += logging_output.get('loss', 0.0)
                        self._all_reduce_list[3] += logging_output.get('nll_loss', 0.0)
                        self._all_reduce_list[4] += logging_output.get('ntokens', 0.0)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    msg = (
                        '| WARNING: ran out of memory with exception: '
                        + '{};'.format(e)
                        + '\n Skipping batch'
                    )
                    # TODO: print should really go to logger, this print goes
                    # to stdout, which is buffered, which in many case is not
                    # printed out if another exception happens
                    # print(msg)
                    print(msg, file=sys.stderr)
                    if raise_oom:
                        raise ValueError(msg)
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

            if self.fast_stat_sync:
                self._all_reduce_list[5] += ooms

        if ooms > 0 and self._oom_batch is not None:
            self.handle_ooms(ooms)

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
        if self.fast_stat_sync:
            # rework all_gather_list
            all_reduce_list_tensor = torch.cuda.DoubleTensor(self._all_reduce_list)
            if self._sync_stats():
                torch.distributed.all_reduce(all_reduce_list_tensor)
            # Normalize loss and nll_loss by "sample_size"
            # and convert to log base 2
            all_reduce_list_tensor[2:4].div_(
                (
                    all_reduce_list_tensor[0:1] *
                    torch.log(torch.cuda.DoubleTensor([2]))
                )
            )
            self._all_reduce_list = all_reduce_list_tensor.tolist()
            logging_output = {}
            [
                sample_size,
                logging_output['nsentences'],
                logging_output['loss'],
                logging_output['nll_loss'],
                logging_output['ntokens'],
                ooms,
            ] = self._all_reduce_list
        elif self._sync_stats():
            logging_outputs, sample_sizes, ooms, prev_norms = \
                zip(*distributed_utils.all_gather_list(
                    [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
                ))
            logging_outputs = list(chain.from_iterable(logging_outputs))
            sample_sizes = list(chain.from_iterable(sample_sizes))
            ooms = sum(ooms)

            if not self.args.use_bmuf:
                assert (
                    all(norm == prev_norms[0] for norm in prev_norms)
                    or all(math.isnan(norm) or math.isinf(norm) for norm in prev_norms)
                ), 'Fatal error: gradients are inconsistent between workers'

        self.meters['oom'].update(ooms, len(samples))
        if ooms == self.args.distributed_world_size * len(samples):
            print('| WARNING: OOM in all workers, skipping update')
            self.zero_grad()
            return None

        if not self.fast_stat_sync:
            # aggregate logging outputs and sample sizes
            logging_output = self.task.aggregate_logging_outputs(
                logging_outputs, self.get_criterion()
            )
            sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())

        if not all(k in logging_output for k in ['ntokens', 'nsentences']):
            raise Exception((
                'Please update the {}.aggregate_logging_outputs() method to '
                'return ntokens and nsentences'
            ).format(self.task.__class__.__name__))

        try:
            # normalize grads by sample size
            if sample_size > 0:
                self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
            self._prev_grad_norm = grad_norm

            # take an optimization step
            self.optimizer.step()
            self.set_num_updates(self.get_num_updates() + 1)

            # task specific update per step
            self.task.update_step(self._num_updates)

            # update meters
            ntokens = logging_output.get('ntokens', 0)
            nsentences = logging_output.get('nsentences', 0)
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            self.meters['gnorm'].update(grad_norm)
            self.meters['clip'].update(
                1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
            )
            self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
            if 'train_acc' in self.meters:
                self.meters['train_acc'].update(
                    logging_output.get('acc', 0), sample_size)

            if 'nll_loss' in logging_output:
                self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
        except OverflowError as e:
            print('| WARNING: overflow detected, ' + str(e))
            self.zero_grad()
            logging_output = None

        if self.args.fp16:
            self.meters['loss_scale'].reset()
            self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

        self.clear_buffered_stats()
        self.meters['train_wall'].stop()

        return logging_output
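
The `fast_stat_sync` branch above replaces `all_gather_list` with a single all-reduce over a small stats tensor; afterwards the summed `loss`/`nll_loss` entries are divided by the summed `sample_size` and by ln(2), so the logged values are per-sample and in base 2. The same arithmetic on plain lists:

import math

per_worker = [
    # [sample_size, nsentences, loss, nll_loss, ntokens, ooms]
    [10.0, 4.0, 30.0, 25.0, 100.0, 0.0],
    [12.0, 5.0, 36.0, 30.0, 120.0, 0.0],
]
summed = [sum(col) for col in zip(*per_worker)]   # what the all-reduce produces
sample_size = summed[0]
loss_bits = summed[2] / (sample_size * math.log(2))
nll_bits = summed[3] / (sample_size * math.log(2))
print(round(loss_bits, 3), round(nll_bits, 3))   # 4.328 3.607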
예제 #32
def _main(args, output_file):
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairnr_cli.render')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    logger.info(model)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_workers=args.num_workers).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)
    shard_id, world_size = args.distributed_rank, args.distributed_world_size
    output_files = []
    if generator.test_poses is not None:
        total_frames = generator.test_poses.shape[0]
        _frames = int(np.floor(total_frames / world_size))
        step = shard_id * _frames
        frames = _frames if shard_id < (world_size -
                                        1) else total_frames - step
    else:
        step = shard_id * args.render_num_frames
        frames = args.render_num_frames

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for i, sample in enumerate(t):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            gen_timer.start()

            step, _output_files = task.inference_step(generator, models,
                                                      [sample, step, frames])
            output_files += _output_files

            gen_timer.stop(500)
            wps_meter.update(500)
            t.log({'wps': round(wps_meter.avg)})

    timestamp = generator.save_images(
        output_files,
        steps='shard{}'.format(shard_id),
        combine_output=args.render_combine_output)

    # join videos from all GPUs and delete temp files
    try:
        timestamps = distributed_utils.all_gather_list(timestamp)
    except Exception:
        timestamps = [timestamp]

    if shard_id == 0:
        generator.merge_videos(timestamps)
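
The frame-sharding arithmetic above gives every worker a contiguous block of `floor(total/world_size)` frames, with the last worker picking up the remainder, so the blocks tile the full frame range exactly. A quick standalone check:

import numpy as np

def shard_range(total_frames, world_size, shard_id):
    base = int(np.floor(total_frames / world_size))
    start = shard_id * base
    length = base if shard_id < world_size - 1 else total_frames - start
    return start, length

total, world = 23, 4
blocks = [shard_range(total, world, r) for r in range(world)]
assert sum(length for _, length in blocks) == total
print(blocks)  # [(0, 5), (5, 5), (10, 5), (15, 8)]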