예제 #1
0
    def run_epoch(self, state: TrainingState, data, metric_reporter: MetricReporter):
        """Run a single pass over ``data`` for the stage stored in ``state``.

        Each batch goes through ``model.train_batch`` (rather than a bare forward
        call), which arranges the tensors and returns ``(loss, metric_data)``.

        Args:
            state: training state; ``model``, ``stage``, ``epoch`` and ``rank``
                are read from it.
            data: iterable of batches.
            metric_reporter: accumulates per-batch stats and produces the
                epoch-level metrics.

        Returns:
            Whatever ``metric_reporter.report_metric`` returns when metrics are
            collected for this stage; ``None`` otherwise (the reporter is reset).
        """
        # Evaluation stages always report; training only when configured to.
        should_report = state.stage != Stage.TRAIN or self.config.report_train_metrics
        model = state.model

        for step, batch in enumerate(data):
            self.zero_grads(state)
            with timing.time("model.train_batch"):
                loss, metric_data = model.train_batch(model, batch)
            self.backprop(state, loss)
            if not should_report:
                continue
            with timing.time("add metrics"):
                context = metric_reporter.batch_context(batch)
                metric_reporter.add_batch_stats(step, *metric_data, **context)

        if not should_report:
            # Nothing to report; drop whatever the reporter accumulated.
            metric_reporter._reset()
            return None

        with timing.time("report metrics"):
            return metric_reporter.report_metric(
                model, state.stage, state.epoch, print_to_channels=(state.rank == 0)
            )
예제 #2
0
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        """Run forward/backward over every mini-batch in ``samples``, then step once.

        This variant calls the model directly on each batch and computes the loss
        with ``self.loss(logits, batch["label_ids"])``.

        Whenever "samples" contains more than one mini-batch (sample_size > 1),
        gradients are accumulated locally and all-reduce is only called in the
        last backward pass (see ``maybe_accumulate_gradients``).

        Args:
            samples: list of ``(batch_id, batch)`` pairs; its length must not
                exceed ``self.config.num_accumulated_batches``.
            state: training state; ``state.model`` and ``state.stage`` are read.
            metric_reporter: receives per-batch stats when ``report_metric``.
            report_metric: whether to collect metrics for these batches.
        """
        sample_size = len(samples)
        assert sample_size <= self.config.num_accumulated_batches

        model = state.model
        self.zero_grads(state)
        for idx, (batch_id, batch) in enumerate(samples):
            with contextlib_ExitStack() as exit_stack:
                # enter ddp no_sync context and fp16 delay_scale context if needed
                maybe_accumulate_gradients(exit_stack, model, idx, sample_size)
                logits = model(batch)
                targets = batch["label_ids"]
                # Keep the unscaled per-batch loss for metric reporting; only the
                # value handed to backprop is rescaled for gradient accumulation.
                batch_loss = self.loss(logits, targets)
                loss = batch_loss
                if sample_size > 1:
                    # gradients averaged per batch and accumulated across samples.
                    # divide by sample_size so gradients are averaged per example
                    loss = batch_loss / sample_size
                self.backprop(state, loss)

            if report_metric:
                with timing.time("add metrics"):
                    predictions = torch.max(logits, -1)[1]
                    # Normalize over the same (last) axis used for predictions.
                    # An implicit dim is deprecated and selects dim 0 for 3-D input.
                    scores = F.log_softmax(logits, dim=-1)
                    # The last positional slot is the model-input list; presumably
                    # add_batch_stats derives the batch size from len(inputs[0]),
                    # so [targets] yields len(targets) — TODO confirm. Will rewrite
                    # metric_reporter rather than fixing it.
                    metric_data = (predictions, targets, scores, batch_loss, [targets])
                    metric_reporter.add_batch_stats(
                        batch_id,
                        *metric_data,
                        # TODO merge this step into add_batch_stats once all data
                        # migration is done
                        # in new data API, we don't have raw_batch
                        **metric_reporter.batch_context(raw_batch=[], batch=batch),
                    )
                if batch_id % self.config.num_samples_to_log_progress == 0:
                    metric_reporter.report_realtime_metric(state.stage)
        # update gradients after #len(samples) forward & backward
        self.optimizer_step(state)
        self.sparsification_step(state)
예제 #3
0
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        """Run forward/backward over every mini-batch in ``samples``, then step once.

        Our run_step is a bit different, because we're wrapping the model forward
        call with model.train_batch, which arranges tensors and gets loss, etc.

        Whenever "samples" contains more than one mini-batch (sample_size > 1),
        we want to accumulate gradients locally and only call all-reduce in the
        last backwards pass.

        This variant also fires hook events through ``self(<event_name>)``:
        ``'begin_batch'`` before any work (a truthy result skips the whole step),
        ``'after_loss'`` after each mini-batch's backprop, and ``'after_batch'``
        after the optimizer/sparsification steps.

        Args:
            samples: list of ``(batch_id, (raw_batch, batch))`` entries; its
                length must not exceed ``self.config.num_accumulated_batches``.
            state: training state; ``state.model`` and ``state.stage`` are read.
            metric_reporter: receives per-batch stats when ``report_metric``.
            report_metric: whether to collect metrics for these batches.
        """
        sample_size = len(samples)
        assert sample_size <= self.config.num_accumulated_batches
        # A truthy 'begin_batch' result aborts the step before grads are zeroed;
        # in that case 'after_loss' and 'after_batch' are not fired either.
        if self('begin_batch'): return

        model = state.model
        self.zero_grads(state)
        for idx, (batch_id, (raw_batch, batch)) in enumerate(samples):
            with contextlib_ExitStack() as exit_stack:
                # enter ddp no_sync context and fp16 delay_scale context if needed
                maybe_accumulate_gradients(exit_stack, model, idx, sample_size)
                with timing.time("model.train_batch"):
                    loss, metric_data = model.train_batch(model, batch, state)
                    if sample_size > 1:
                        # gradients averaged per batch and accumulated across samples.
                        # divide sample_size to let gradients averaged per example
                        loss = loss / sample_size
                self.backprop(state, loss)
                # Expose the current step to hooks. Note self.loss is the
                # (possibly sample_size-scaled) loss, and it is set *after*
                # backprop, so 'after_loss' runs post-backward.
                self.samples, self.state, self.loss = samples, state, loss
                self('after_loss')

            if report_metric:
                with timing.time("add metrics"):
                    metric_reporter.add_batch_stats(
                        batch_id,
                        *metric_data,
                        # TODO merge this step into add_batch_stats once all data
                        # migration is done
                        **metric_reporter.batch_context(raw_batch, batch),
                    )
                # Realtime progress report every num_samples_to_log_progress batches.
                if batch_id % self.config.num_samples_to_log_progress == 0:
                    metric_reporter.report_realtime_metric(state.stage)
        # update gradients after #len(samples) forward & backward
        self.optimizer_step(state)
        self.sparsification_step(state)
        self('after_batch')
예제 #4
0
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        """Run forward/backward over every mini-batch in ``samples``, then step once.

        Each batch goes through ``model.train_batch`` (rather than a bare forward
        call), which arranges the tensors and returns ``(loss, metric_data)``.
        In distributed runs, gradients are accumulated locally for all but the
        last mini-batch so that only the final backward pass synchronizes.

        Args:
            samples: list of ``(batch_id, (raw_batch, batch))`` entries; its
                length must not exceed ``self.config.num_accumulated_batches``.
            state: training state; ``state.model`` is read.
            metric_reporter: receives per-batch stats when ``report_metric``.
            report_metric: whether to collect metrics for these batches.
        """
        sample_size = len(samples)
        assert sample_size <= self.config.num_accumulated_batches

        model = state.model
        self.zero_grads(state)
        last_position = sample_size - 1
        for position, (batch_id, (raw_batch, batch)) in enumerate(samples):
            if cuda.DISTRIBUTED_WORLD_SIZE > 1:
                # Accumulate (skip all-reduce) for every mini-batch except the
                # last one, whose backward pass synchronizes gradients.
                model.accumulate_gradients(position < last_position)

            with timing.time("model.train_batch"):
                loss, metric_data = model.train_batch(model, batch)
                if sample_size > 1:
                    # Scale so the accumulated gradient is averaged per example.
                    loss = loss / sample_size
            self.backprop(state, loss)

            if not report_metric:
                continue
            with timing.time("add metrics"):
                context = metric_reporter.batch_context(raw_batch, batch)
                metric_reporter.add_batch_stats(batch_id, *metric_data, **context)

        # One optimizer update after all len(samples) forward/backward passes.
        self.optimizer_step(state)
예제 #5
0
    def _run_epoch(
        self,
        stage: Stage,
        epoch: int,
        batches,
        model: Model,
        metric_reporter: MetricReporter,
        pre_batch=lambda: None,
        backprop=lambda loss: None,
        rank=0,
        num_samples_to_log_progress: int = None,
    ):
        """Run one epoch of ``batches`` through ``model`` for ``stage``.

        Our run_epoch is a bit different, because we're wrapping the model forward
        call with model.train_batch, which arranges tensors and gets loss, etc.

        Args:
            stage: stage being run; metrics are collected for every stage other
                than ``Stage.TRAIN``, and for training only when
                ``self.config.report_train_metrics`` is set.
            epoch: epoch number, used in the log line and metric report.
            batches: iterable of batches passed to ``model.train_batch``.
            model: model whose ``train_batch`` returns ``(loss, metric_data)``.
            metric_reporter: collects per-batch stats and reports epoch metrics.
            pre_batch: zero-argument hook called before each batch.
            backprop: called with each batch's loss (no-op by default).
            rank: worker rank; only rank 0 prints metrics to channels.
            num_samples_to_log_progress: optional interval. When it is truthy
                and metrics are being collected, ``report_realtime_metric`` is
                called on every batch whose id is a multiple of it. ``None``
                (the default) or ``0`` disables realtime reporting.

        Returns:
            The value of ``metric_reporter.report_metric`` when metrics are
            collected for this stage, otherwise ``None`` (the reporter is reset).
        """
        print(f"Rank {rank} worker: Running epoch #{epoch} for {stage}")
        report_metric = stage != Stage.TRAIN or self.config.report_train_metrics
        # Previously this argument was accepted but ignored. A falsy value keeps
        # the old behaviour and avoids a modulo-by-zero.
        log_every = num_samples_to_log_progress or 0

        for batch_id, batch in enumerate(batches):
            pre_batch()
            with timing.time("model.train_batch"):
                loss, metric_data = model.train_batch(batch)
            with timing.time("backprop"):
                backprop(loss)
            if report_metric:
                with timing.time("add metrics"):
                    metric_reporter.add_batch_stats(
                        batch_id, *metric_data, **metric_reporter.batch_context(batch)
                    )
                if log_every and batch_id % log_every == 0:
                    metric_reporter.report_realtime_metric(stage)

        metrics = None
        if report_metric:
            with timing.time("report metrics"):
                metrics = metric_reporter.report_metric(
                    model, stage, epoch, print_to_channels=(rank == 0)
                )
        else:
            metric_reporter._reset()

        return metrics
예제 #6
0
파일: trainer.py 프로젝트: alam52/pytext
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        """Run forward/backward over every mini-batch in ``samples``, then step once.

        Each batch goes through ``model.train_batch(model, batch, state)`` rather
        than a bare forward call; it arranges the tensors and returns
        ``(loss, metric_data)``.

        When ``samples`` holds more than one mini-batch, every forward/backward
        runs under ``maybe_no_sync`` so gradients are accumulated locally and
        all-reduce is left for the final backward pass.

        Args:
            samples: list of ``(batch_id, (raw_batch, batch))`` entries; its
                length must not exceed ``self.config.num_accumulated_batches``.
            state: training state; ``state.model`` is read.
            metric_reporter: receives per-batch stats when ``report_metric``.
            report_metric: whether to collect metrics for these batches.
        """
        total = len(samples)
        assert total <= self.config.num_accumulated_batches

        model = state.model
        self.zero_grads(state)
        for position, (batch_id, (raw_batch, batch)) in enumerate(samples):
            with maybe_no_sync(model, position, total):
                with timing.time("model.train_batch"):
                    loss, metric_data = model.train_batch(model, batch, state)
                    if total > 1:
                        # Scale so the accumulated gradient is averaged per example.
                        loss = loss / total
                self.backprop(state, loss)

            if not report_metric:
                continue
            with timing.time("add metrics"):
                context = metric_reporter.batch_context(raw_batch, batch)
                metric_reporter.add_batch_stats(batch_id, *metric_data, **context)

        # One optimizer update after all len(samples) forward/backward passes.
        self.optimizer_step(state)
예제 #7
0
    def _run_epoch(
        self,
        stage: Stage,
        epoch: int,
        batches,
        model: Model,
        metric_reporter: MetricReporter,
        pre_batch=lambda: None,
        backprop=lambda loss, timer=None: None,
        rank=0,
    ):
        """Run one epoch of ``(batch, tensors)`` pairs through ``model``.

        For each pair, the reporter's batch context is built from ``batch`` and
        handed to ``model.contextualize`` before ``model.train_batch(tensors)``
        is called. If the context contains ``BatchContext.IGNORE_LOSS``, the loss
        is multiplied by zero before ``backprop``.

        Args:
            stage: stage being run; metrics are collected for every stage other
                than ``Stage.TRAIN``, and for training only when
                ``self.config.report_train_metrics`` is set.
            epoch: epoch number, used in the log line and metric report.
            batches: iterable of ``(batch, tensors)`` pairs.
            model: model exposing ``contextualize`` and ``train_batch``.
            metric_reporter: builds batch contexts and aggregates metrics.
            pre_batch: zero-argument hook called before each batch.
            backprop: called with each batch's loss (no-op by default).
            rank: worker rank; only rank 0 prints metrics to channels.

        Returns:
            The value of ``metric_reporter.report_metric`` when metrics are
            collected, otherwise ``None`` (the reporter is reset).
        """
        print(f"Rank {rank} worker: Running epoch #{epoch} for {stage}")
        should_report = stage != Stage.TRAIN or self.config.report_train_metrics
        for batch_id, (batch, tensors) in enumerate(batches):
            print(f"Batch {batch_id} has {len(batch)} examples")
            pre_batch()
            batch_ctx = metric_reporter.batch_context(batch)
            # The model may consult the context during its forward call.
            model.contextualize(batch_ctx)
            loss, metric_data = model.train_batch(tensors)
            if BatchContext.IGNORE_LOSS in batch_ctx:
                # NOTE(review): in-place; if metric_data holds this same loss
                # object, the reported loss is zeroed too — confirm intended.
                loss *= 0
            backprop(loss)
            if should_report:
                metric_reporter.add_batch_stats(batch_id, *metric_data, **batch_ctx)

        if not should_report:
            metric_reporter._reset()
            return None
        return metric_reporter.report_metric(
            stage, epoch, print_to_channels=(rank == 0)
        )