Ejemplo n.º 1
0
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        """Run forward/backward over ``samples``, then take one optimizer step.

        Gradients are zeroed once, every sample is passed through the model
        and back-propagated, and only afterwards are ``optimizer_step`` and
        ``sparsification_step`` invoked.

        Args:
            samples: Sequence of ``(batch_id, (inputs, targets, context))``
                entries. Its length must not exceed
                ``self.config.num_accumulated_batches``.
            state: Training state; ``model``, ``stage`` and ``epoch`` are read.
            metric_reporter: Receives per-batch stats and, in the TRAIN
                stage, the model for gradient reporting, when
                ``report_metric`` is true.
            report_metric: Whether to compute predictions and report metrics.

        Raises:
            AssertionError: If more samples are passed than
                ``self.config.num_accumulated_batches`` allows.
        """
        sample_size = len(samples)
        assert sample_size <= self.config.num_accumulated_batches, (
            f"got {sample_size} samples, but num_accumulated_batches is "
            f"{self.config.num_accumulated_batches}"
        )

        model = state.model
        self.zero_grads(state)
        for idx, (batch_id, (inputs, targets, context)) in enumerate(samples):
            with contextlib_ExitStack() as exit_stack:
                # Helper defined elsewhere; it receives the exit stack so it
                # can register contexts for this sample (presumably related
                # to gradient accumulation -- confirm against its definition).
                maybe_accumulate_gradients(exit_stack, model, idx, sample_size)
                # pass context to model to use in forward call if needed
                model.contextualize(context)
                with timing.time("model.forward"):
                    logits = model(*inputs)

                with timing.time("compute loss"):
                    loss = precision.maybe_float(
                        model.get_loss(logits, targets, context)
                    )
                    if BatchContext.IGNORE_LOSS in context:
                        # Zero out the loss for batches flagged IGNORE_LOSS;
                        # backprop is still called below.
                        loss *= 0
                    elif sample_size > 1:
                        # Gradients are averaged per batch and accumulated
                        # across samples; divide by sample_size so they end
                        # up averaged per example.
                        loss = loss / sample_size

                self.backprop(state, loss)

            if report_metric:
                with timing.time("get pred"):
                    preds, scores = model.get_pred(
                        logits, targets, context, state.stage, *inputs
                    )

                with timing.time("add metrics"):
                    metric_reporter.add_batch_stats(
                        batch_id, preds, targets, scores, loss.item(), inputs, **context
                    )

            if batch_id % self.config.num_samples_to_log_progress == 0:
                # Adjacent f-strings instead of a backslash continuation
                # inside the literal, which would embed the source
                # indentation in the printed message.
                print(
                    f"Running batch {batch_id} for epoch {state.epoch} "
                    f"in {state.stage} stage",
                    flush=True,
                )
        # update gradients after len(samples) forward & backward
        self.optimizer_step(state)
        with timing.time("add gradients"):
            if report_metric and state.stage == Stage.TRAIN:
                metric_reporter.add_gradients(state.model)
        self.sparsification_step(state)
Ejemplo n.º 2
0
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        """Push ``samples`` through ``model.train_batch`` and step the optimizer.

        This variant does not call the model's forward directly; it goes
        through ``model.train_batch``, which returns the loss together with
        the data that is forwarded to the metric reporter.

        When ``samples`` holds more than one mini-batch, the intent is to
        accumulate gradients locally and run all-reduce only during the
        final backward pass.
        """
        num_batches = len(samples)
        assert num_batches <= self.config.num_accumulated_batches

        model = state.model
        self.zero_grads(state)
        for position, sample in enumerate(samples):
            batch_id, (raw_batch, batch) = sample

            with contextlib_ExitStack() as stack:
                # enters the ddp no_sync / fp16 delay_scale contexts when needed
                maybe_accumulate_gradients(stack, model, position, num_batches)
                with timing.time("model.train_batch"):
                    loss, metric_data = model.train_batch(model, batch, state)
                    if num_batches > 1:
                        # Per-batch averaged gradients get summed over the
                        # accumulated samples; scaling the loss down keeps
                        # them averaged per example.
                        loss = loss / num_batches
                self.backprop(state, loss)

            if not report_metric:
                continue

            with timing.time("add metrics"):
                metric_reporter.add_batch_stats(
                    batch_id,
                    *metric_data,
                    # TODO merge this step into add_batch_stats once all data
                    # migration is done
                    **metric_reporter.batch_context(raw_batch, batch),
                )
            if batch_id % self.config.num_samples_to_log_progress == 0:
                metric_reporter.report_realtime_metric(state.stage)

        # one optimizer update after all len(samples) forward/backward passes
        self.optimizer_step(state)
        with timing.time("add gradients"):
            if report_metric and state.stage == Stage.TRAIN:
                metric_reporter.add_gradients(state.model)
        self.sparsification_step(state)