Exemplo n.º 1
0
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        """Process one group of accumulated batches and apply a single update.

        Every element of ``samples`` is unpacked as
        ``(batch_id, (inputs, targets, context))``.  For each one the model is
        run forward, the loss is computed and back-propagated; the optimizer
        and sparsification steps run once, after the loop.

        ``self`` is called with the event names ``'begin_batch'``,
        ``'after_loss'`` and ``'after_batch'``.  A truthy result from
        ``'begin_batch'`` skips the whole step; a truthy result from
        ``'after_loss'`` stops iterating over the remaining samples (the
        optimizer step after the loop still runs).

        Args:
            samples: batches to accumulate gradients over; at most
                ``self.config.num_accumulated_batches`` of them.
            state: training state providing ``model``, ``epoch`` and ``stage``.
            metric_reporter: receives per-batch statistics when
                ``report_metric`` is true.
            report_metric: whether to compute predictions and report metrics.
        """
        num_batches = len(samples)
        assert num_batches <= self.config.num_accumulated_batches

        if self('begin_batch'):
            return

        net = state.model
        self.zero_grads(state)
        for position, (batch_id, (inputs, targets, context)) in enumerate(samples):
            with contextlib_ExitStack() as stack:
                # presumably registers gradient-accumulation contexts on the
                # stack for this position -- see maybe_accumulate_gradients
                maybe_accumulate_gradients(stack, net, position, num_batches)
                # hand the batch context to the model before the forward call
                net.contextualize(context)
                with timing.time("model.forward"):
                    logits = net(*inputs)

                with timing.time("compute loss"):
                    loss = precision.maybe_float(
                        net.get_loss(logits, targets, context)
                    )
                    if BatchContext.IGNORE_LOSS in context:
                        # zero the loss in place for batches flagged to be ignored
                        loss *= 0
                    elif num_batches > 1:
                        # gradients are averaged per batch and summed across
                        # the accumulated samples, so scale by the number of
                        # samples to keep a per-example average
                        loss = loss / num_batches

                self.backprop(state, loss)
                # expose the current step on the instance before the
                # 'after_loss' event fires
                self.samples = samples
                self.state = state
                self.loss = loss
                if self('after_loss'):
                    break

            if report_metric:
                with timing.time("get pred"):
                    preds, scores = net.get_pred(
                        logits, targets, context, state.stage, *inputs
                    )

                with timing.time("add metrics"):
                    metric_reporter.add_batch_stats(
                        batch_id,
                        preds,
                        targets,
                        scores,
                        loss.item(),
                        inputs,
                        **context,
                    )

            if batch_id % self.config.num_samples_to_log_progress == 0:
                progress = f"Running batch {batch_id} for epoch {state.epoch} in {state.stage} stage"
                print(progress, flush=True)

        # a single parameter update after len(samples) forward & backward passes
        self.optimizer_step(state)
        self.sparsification_step(state)
        self('after_batch')
Exemplo n.º 2
0
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        """Run forward/backward over accumulated batches, then one optimizer step.

        Each element of ``samples`` is unpacked as ``(batch_id, batch)``.  The
        model is called on ``batch`` directly, targets are read from
        ``batch["label_ids"]`` and the loss comes from ``self.loss``.

        Whenever "samples" contains more than one mini-batch (sample_size > 1),
        we want to accumulate gradients locally and only call all-reduce in the
        last backwards pass.

        Args:
            samples: batches to accumulate gradients over; at most
                ``self.config.num_accumulated_batches`` of them.
            state: training state providing ``model`` and ``stage``.
            metric_reporter: receives per-batch statistics when
                ``report_metric`` is true.
            report_metric: whether to compute predictions/scores and report
                metrics (realtime metrics are only reported when this is true).
        """
        sample_size = len(samples)
        assert sample_size <= self.config.num_accumulated_batches

        model = state.model
        self.zero_grads(state)
        for idx, (batch_id, batch) in enumerate(samples):
            with contextlib_ExitStack() as exit_stack:
                # enter ddp no_sync context and fp16 delay_scale context if needed
                maybe_accumulate_gradients(exit_stack, model, idx, sample_size)
                logits = model(batch)
                targets = batch["label_ids"]
                loss = self.loss(logits, targets)
                if sample_size > 1:
                    # gradients averaged per batch and accumulated across samples.
                    # divide sample_size to let gradients averaged per example
                    loss = loss / sample_size
                self.backprop(state, loss)

            if report_metric:
                with timing.time("add metrics"):
                    predictions = torch.max(logits, -1)[1]
                    # Normalize over the same (last) dimension the predictions
                    # are taken from.  Calling log_softmax without `dim` is
                    # deprecated and picks the dimension implicitly from the
                    # tensor rank, which may not be the class dimension.
                    scores = F.log_softmax(logits, dim=-1)
                    # NOTE(review): an earlier comment described the last item
                    # as `[len(targets)]` (the batch size required by
                    # add_batch_stats), but the code passes `[targets]` --
                    # confirm which one the reporter expects.
                    # Will rewrite metric_reporter rather than fixing it
                    metric_data = (predictions, targets, scores, loss,
                                   [targets])
                    metric_reporter.add_batch_stats(
                        batch_id,
                        *metric_data,
                        # TODO merge this step into add_batch_stats once all data
                        # migration is done
                        # in new data API, we don't have raw_batch
                        **metric_reporter.batch_context(raw_batch=[],
                                                        batch=batch),
                    )
                if batch_id % self.config.num_samples_to_log_progress == 0:
                    metric_reporter.report_realtime_metric(state.stage)
        # update gradients after #len(samples) forward & backward
        self.optimizer_step(state)
        self.sparsification_step(state)
Exemplo n.º 3
0
def maybe_no_sync(model, index, sample_size):
    """
    Pick the context manager to wrap one backward pass in.

    Whenever *samples* contains more than one mini-batch (e.g sample_size > 1),
    gradients should be accumulated locally, with all-reduce only happening in
    the last backwards pass.

    Returns ``model.no_sync()`` when the distributed world size is greater
    than one, the model exposes ``no_sync`` and ``index`` is not the last
    position (``index < sample_size - 1``); otherwise returns a fresh
    ``contextlib_ExitStack()``.
    """
    # Conditions are checked in this order and short-circuit, so `model` is
    # only probed for `no_sync` when running distributed.
    skip_sync = (
        cuda.DISTRIBUTED_WORLD_SIZE > 1
        and hasattr(model, "no_sync")
        and index < sample_size - 1
    )
    return model.no_sync() if skip_sync else contextlib_ExitStack()
Exemplo n.º 4
0
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        """Accumulate gradients over ``samples`` and apply one optimizer step.

        This variant delegates the forward pass and loss computation to
        ``model.train_batch``, which returns the loss together with the metric
        data later forwarded to the metric reporter.  Each element of
        ``samples`` is unpacked as ``(batch_id, (raw_batch, batch))``.

        When "samples" holds more than one mini-batch (sample_size > 1),
        gradients are accumulated locally and all-reduce is only meant to
        happen in the last backwards pass.

        ``self`` is called with the event names ``'begin_batch'``,
        ``'after_loss'`` and ``'after_batch'``; a truthy result from
        ``'begin_batch'`` skips the whole step, while the result of
        ``'after_loss'`` is ignored.

        Args:
            samples: batches to accumulate gradients over; at most
                ``self.config.num_accumulated_batches`` of them.
            state: training state providing ``model`` and ``stage``.
            metric_reporter: receives per-batch statistics when
                ``report_metric`` is true.
            report_metric: whether to report metrics (realtime metrics are
                only reported when this is true).
        """
        num_batches = len(samples)
        assert num_batches <= self.config.num_accumulated_batches
        if self('begin_batch'):
            return

        net = state.model
        self.zero_grads(state)
        for position, (batch_id, (raw_batch, batch)) in enumerate(samples):
            with contextlib_ExitStack() as stack:
                # enter ddp no_sync context and fp16 delay_scale context if needed
                maybe_accumulate_gradients(stack, net, position, num_batches)
                with timing.time("model.train_batch"):
                    loss, metric_data = net.train_batch(net, batch, state)
                    if num_batches > 1:
                        # gradients are averaged per batch and summed across
                        # the accumulated samples, so scale by the number of
                        # samples to keep a per-example average
                        loss = loss / num_batches
                self.backprop(state, loss)
                # expose the current step on the instance before the
                # 'after_loss' event fires
                self.samples = samples
                self.state = state
                self.loss = loss
                self('after_loss')

            if not report_metric:
                continue

            with timing.time("add metrics"):
                metric_reporter.add_batch_stats(
                    batch_id,
                    *metric_data,
                    # TODO merge this step into add_batch_stats once all data
                    # migration is done
                    **metric_reporter.batch_context(raw_batch, batch),
                )
            if batch_id % self.config.num_samples_to_log_progress == 0:
                metric_reporter.report_realtime_metric(state.stage)

        # a single parameter update after #len(samples) forward & backward passes
        self.optimizer_step(state)
        self.sparsification_step(state)
        self('after_batch')