Example #1
    def run_epoch(self, state: TrainingState, data, metric_reporter: MetricReporter):
        """Our run_epoch is a bit different, because we're wrapping the model forward
        call with model.train_batch, which arranges tensors and gets loss, etc."""
        report_metric = state.stage != Stage.TRAIN or self.config.report_train_metrics
        model = state.model

        for batch_id, batch in enumerate(data):
            self.zero_grads(state)
            with timing.time("model.train_batch"):
                loss, metric_data = model.train_batch(model, batch)
            self.backprop(state, loss)
            if report_metric:
                with timing.time("add metrics"):
                    metric_reporter.add_batch_stats(
                        batch_id, *metric_data, **metric_reporter.batch_context(batch)
                    )

        metrics = None
        if report_metric:
            with timing.time("report metrics"):
                metrics = metric_reporter.report_metric(
                    model, state.stage, state.epoch, print_to_channels=(state.rank == 0)
                )
        else:
            metric_reporter._reset()

        return metrics
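
Several trainers in this set call a model.train_batch hook that arranges tensors and returns the loss plus metric data. The hook itself is not part of these excerpts; the following is only a rough, hypothetical sketch of what such a method could look like, with the batch field names and the metric_data layout assumed for illustration rather than taken from the actual PyText contract.

# Hypothetical sketch of a train_batch-style hook; batch field names and the
# exact metric_data layout are assumptions made for illustration only.
import torch
import torch.nn.functional as F


class ExampleClassifier(torch.nn.Module):
    def __init__(self, in_dim: int, num_labels: int):
        super().__init__()
        self.classifier = torch.nn.Linear(in_dim, num_labels)

    def forward(self, inputs):
        return self.classifier(inputs)

    def train_batch(self, model, batch):
        # `model` may be a DDP wrapper around self, so run the forward through it.
        inputs, targets = batch["inputs"], batch["targets"]
        logits = model(inputs)
        loss = F.cross_entropy(logits, targets)
        preds = logits.argmax(dim=-1)
        scores = F.log_softmax(logits, dim=-1)
        # Mirrors the (preds, targets, scores, loss, inputs) shape that
        # metric_reporter.add_batch_stats consumes in the surrounding examples.
        metric_data = (preds, targets, scores, loss.item(), [targets])
        return loss, metric_data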
Example #2
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        sample_size = len(samples)
        assert sample_size <= self.config.num_accumulated_batches

        model = state.model
        self.zero_grads(state)
        for i, (batch_id, (inputs, targets, context)) in enumerate(samples):
            if cuda.DISTRIBUTED_WORLD_SIZE > 1:
                # Whenever *samples* contains more than one mini-batch, we
                # want to accumulate gradients locally and only call
                # all-reduce in the last backwards pass.
                if i < sample_size - 1:
                    # not the last micro-batch: accumulate locally, skip all-reduce
                    model.accumulate_gradients(True)
                else:
                    # last micro-batch: sync gradients in this backward pass
                    model.accumulate_gradients(False)

            # pass context to model to use in forward call if needed
            model.contextualize(context)
            with timing.time("model.forward"):
                logits = model(*inputs)

            with timing.time("compute loss"):
                loss = precision.maybe_float(
                    model.get_loss(logits, targets, context))
                if BatchContext.IGNORE_LOSS in context:
                    loss *= 0
                elif sample_size > 1:
                    # gradients are averaged per batch and accumulated across samples;
                    # divide by sample_size so they end up averaged per example
                    loss = loss / sample_size

            self.backprop(state, loss)

            if report_metric:
                with timing.time("get pred"):
                    preds, scores = model.get_pred(logits, targets, context,
                                                   state.stage, *inputs)

                with timing.time("add metrics"):
                    metric_reporter.add_batch_stats(batch_id,
                                                    preds, targets, scores,
                                                    loss.item(), inputs,
                                                    **context)

            if batch_id % self.config.num_samples_to_log_progress == 0:
                print(
                    f"Running batch {batch_id} for epoch {state.epoch} in {state.stage} stage",
                    flush=True,
                )
        # update model parameters after len(samples) forward & backward passes
        self.optimizer_step(state)
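
The accumulate_gradients(True/False) toggle above accumulates gradients locally and defers the all-reduce to the final backward pass of the group. A minimal sketch of the same idea in plain PyTorch, assuming the model is already wrapped in torch.nn.parallel.DistributedDataParallel and that loss_fn, optimizer, and micro_batches are supplied by the caller:

# Sketch: local gradient accumulation with DDP, syncing only on the last
# micro-batch of the group via the no_sync() context manager.
import contextlib


def accumulation_step(ddp_model, loss_fn, optimizer, micro_batches):
    optimizer.zero_grad()
    n = len(micro_batches)
    for i, (inputs, targets) in enumerate(micro_batches):
        # Skip the gradient all-reduce for every pass except the last one.
        ctx = ddp_model.no_sync() if i < n - 1 else contextlib.nullcontext()
        with ctx:
            loss = loss_fn(ddp_model(inputs), targets)
            # Divide by the group size so gradients stay averaged per example.
            (loss / n).backward()
    optimizer.step()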
Example #3
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        sample_size = len(samples)
        assert sample_size <= self.config.num_accumulated_batches

        if self('begin_batch'): return

        model = state.model
        self.zero_grads(state)
        for idx, (batch_id, (inputs, targets, context)) in enumerate(samples):
            with contextlib_ExitStack() as exit_stack:
                maybe_accumulate_gradients(exit_stack, model, idx, sample_size)
                # pass context to model to use in forward call if needed
                model.contextualize(context)
                with timing.time("model.forward"):
                    logits = model(*inputs)

                with timing.time("compute loss"):
                    loss = precision.maybe_float(
                        model.get_loss(logits, targets, context))
                    if BatchContext.IGNORE_LOSS in context:
                        loss *= 0
                    elif sample_size > 1:
                        # gradients are averaged per batch and accumulated across samples;
                        # divide by sample_size so they end up averaged per example
                        loss = loss / sample_size

                self.backprop(state, loss)
                self.samples, self.state, self.loss = samples, state, loss
                if self('after_loss'): break

            if report_metric:
                with timing.time("get pred"):
                    preds, scores = model.get_pred(logits, targets, context,
                                                   state.stage, *inputs)

                with timing.time("add metrics"):
                    metric_reporter.add_batch_stats(batch_id,
                                                    preds, targets, scores,
                                                    loss.item(), inputs,
                                                    **context)

            if batch_id % self.config.num_samples_to_log_progress == 0:
                print(
                    f"Running batch {batch_id} for epoch {state.epoch} in {state.stage} stage",
                    flush=True,
                )
        # update model parameters after len(samples) forward & backward passes
        self.optimizer_step(state)
        self.sparsification_step(state)
        self('after_batch')
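
Example #3 (and #6) fires hooks through self('begin_batch'), self('after_loss'), and self('after_batch'), where a truthy return asks the trainer to skip or stop. The dispatcher itself is not part of the excerpt; the sketch below shows one hypothetical way it could be wired, with all class and method names assumed rather than taken from the source.

# Hypothetical callback dispatcher for self('begin_batch') / self('after_loss')
# / self('after_batch'); names and structure are assumptions for illustration.
class Callback:
    """Hooks return True to ask the trainer to skip the current step or stop."""

    def begin_batch(self, trainer) -> bool:
        return False

    def after_loss(self, trainer) -> bool:
        return False

    def after_batch(self, trainer) -> bool:
        return False


class CallbackDispatchMixin:
    def __init__(self, callbacks=None):
        self.callbacks = list(callbacks or [])

    def __call__(self, event_name: str) -> bool:
        # Run every callback that implements the event; any True tells the
        # caller to short-circuit (e.g. `if self('begin_batch'): return`).
        stop = False
        for cb in self.callbacks:
            hook = getattr(cb, event_name, None)
            if hook is not None and hook(self):
                stop = True
        return stop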
Example #4
    def _run_epoch(
        self,
        stage,
        epoch,
        data_iter,
        model,
        metric_reporter,
        pre_batch=lambda: None,
        backprop=lambda loss: None,
        rank=0,
        num_samples_to_log_progress=1000,
    ):
        print(f"Rank {rank} worker: Running epoch #{epoch} for {stage}")
        report_metric = stage != Stage.TRAIN or self.config.report_train_metrics

        for batch_id, (inputs, targets, context) in enumerate(data_iter):
            pre_batch()
            # pass context to model to use in forward call if needed
            model.contextualize(context)
            with timing.time("model.forward"):
                logits = model(*inputs)

            with timing.time("compute loss"):
                loss = model.get_loss(logits, targets, context)
                if BatchContext.IGNORE_LOSS in context:
                    loss *= 0

            with timing.time("backprop"):
                backprop(loss)

            if report_metric:
                with timing.time("add metrics"):
                    preds, scores = model.get_pred(logits, targets, context,
                                                   stage, *inputs)
                    metric_reporter.add_batch_stats(batch_id,
                                                    preds, targets, scores,
                                                    loss.item(), inputs,
                                                    **context)

            if rank == 0 and (batch_id + 1) % num_samples_to_log_progress == 0:
                print(
                    f"Epoch {epoch}: finished training {batch_id + 1} samples.",
                    flush=True,
                )

        metrics = None
        if report_metric:
            with timing.time("report metrics"):
                metrics = metric_reporter.report_metric(
                    model, stage, epoch, print_to_channels=(rank == 0))
        else:
            metric_reporter._reset()

        return metrics
Example #5
    def run_epoch(self, state: TrainingState, data: BatchIterator,
                  metric_reporter: MetricReporter):
        # This method is due for some refactoring, pushing it off because it interacts
        # with the metric reporter too much. Much of the logic here either changes in
        # the NewTaskTrainer or should change with a better metric reporter design.
        report_metric = state.stage != Stage.TRAIN or self.config.report_train_metrics
        model = state.model

        for batch_id, (inputs, targets, context) in enumerate(data):
            self.zero_grads(state)
            # pass context to model to use in forward call if needed
            model.contextualize(context)
            with timing.time("model.forward"):
                logits = model(*inputs)

            with timing.time("compute loss"):
                loss = model.get_loss(logits, targets, context)
                if BatchContext.IGNORE_LOSS in context:
                    loss *= 0

            self.backprop(state, loss)

            if report_metric:
                with timing.time("add metrics"):
                    preds, scores = model.get_pred(logits, targets, context,
                                                   state.stage, *inputs)
                    metric_reporter.add_batch_stats(batch_id,
                                                    preds, targets, scores,
                                                    loss.item(), inputs,
                                                    **context)

            if (state.rank == 0 and
                    batch_id % self.config.num_samples_to_log_progress == 0):
                print(
                    f"Running batch {batch_id} for epoch {state.epoch} in {state.stage} stage",
                    flush=True,
                )

        metrics = None
        if report_metric:
            with timing.time("report metrics"):
                metrics = metric_reporter.report_metric(
                    model,
                    state.stage,
                    state.epoch,
                    print_to_channels=(state.rank == 0))
        else:
            metric_reporter._reset()

        return metrics
Example #6
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        """Our run_step is a bit different, because we're wrapping the model forward
        call with model.train_batch, which arranges tensors and gets loss, etc.

        Whenever "samples" contains more than one mini-batch (sample_size > 1),
        we want to accumulate gradients locally and only call all-reduce in the
        last backwards pass.
        """
        sample_size = len(samples)
        assert sample_size <= self.config.num_accumulated_batches
        if self('begin_batch'): return

        model = state.model
        self.zero_grads(state)
        for idx, (batch_id, (raw_batch, batch)) in enumerate(samples):
            with contextlib_ExitStack() as exit_stack:
                # enter ddp no_sync context and fp16 delay_scale context if needed
                maybe_accumulate_gradients(exit_stack, model, idx, sample_size)
                with timing.time("model.train_batch"):
                    loss, metric_data = model.train_batch(model, batch, state)
                    if sample_size > 1:
                        # gradients are averaged per batch and accumulated across samples;
                        # divide by sample_size so they end up averaged per example
                        loss = loss / sample_size
                self.backprop(state, loss)
                self.samples, self.state, self.loss = samples, state, loss
                self('after_loss')

            if report_metric:
                with timing.time("add metrics"):
                    metric_reporter.add_batch_stats(
                        batch_id,
                        *metric_data,
                        # TODO merge this step into add_batch_stats once all data
                        # migration is done
                        **metric_reporter.batch_context(raw_batch, batch),
                    )
                if batch_id % self.config.num_samples_to_log_progress == 0:
                    metric_reporter.report_realtime_metric(state.stage)
        # update model parameters after len(samples) forward & backward passes
        self.optimizer_step(state)
        self.sparsification_step(state)
        self('after_batch')
Example #7
    def run_epoch(
        self, state: TrainingState, data: BatchIterator, metric_reporter: MetricReporter
    ):
        # This method is due for some refactoring, pushing it off because it interacts
        # with the metric reporter too much. Much of the logic here either changes in
        # the NewTaskTrainer or should change with a better metric reporter design.
        report_metric = state.stage != Stage.TRAIN or self.config.report_train_metrics
        model = state.model
        samples = []
        is_data_empty = True

        """
        Sometimes, a batch of inputs is too large to fit into GPU, which has to
        be split into several micro-batches. However, to improve efficiency,
        it would be helpful to only apply params/gradients sync at original batch
        boundaries instead of micro-batch boundaries.
        num_accumulated_batches specified the number of accumulating gradients
        locally before sync gradients, total training_batch_size =
        train_batch_size x num_accumulated_batches and it will improve the system
        performance by reduce the total network transfer bytes.
        """
        for sample in enumerate(data):
            is_data_empty = False
            samples.append(sample)
            if (
                state.stage != Stage.TRAIN
                or len(samples) == self.config.num_accumulated_batches
            ):
                self.run_step(samples, state, metric_reporter, report_metric)
                samples = []
        if samples:
            self.run_step(samples, state, metric_reporter, report_metric)
            samples = []

        metrics = None
        if report_metric:
            if is_data_empty:
                error_msg = (
                    f"Trying to report metric for stage {state.stage}, but no data was "
                    "found. Either disable metric reporting for this stage, pass in "
                    "non-empty data, or see if data fields are misnamed (warnings "
                    "would appear in preceding stdout logs)."
                )
                raise ValueError(error_msg)

            with timing.time("report metrics"):
                metrics = metric_reporter.report_metric(
                    model,
                    state.stage,
                    state.epoch,
                    print_to_channels=(state.rank == 0),
                    optimizer=getattr(
                        state, "optimizer", None
                    ),  # optimizer is not present during test
                    privacy_engine=getattr(state, "privacy_engine", None),
                )
        else:
            metric_reporter._reset()

        return metrics
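
The accumulation driver above collects the enumerated batches into groups of at most num_accumulated_batches and flushes any partial group at the end of the epoch. The same grouping can be factored into a small helper; a sketch (the grouped name is ours, introduced only for illustration, not PyText's):

# Sketch of the grouping pattern used by run_epoch above: yield chunks of at
# most `chunk_size` enumerated batches, including a final partial chunk.
from itertools import islice
from typing import Iterable, Iterator, List, Tuple, TypeVar

T = TypeVar("T")


def grouped(data: Iterable[T], chunk_size: int) -> Iterator[List[Tuple[int, T]]]:
    numbered = enumerate(data)
    while True:
        chunk = list(islice(numbered, chunk_size))
        if not chunk:
            return
        yield chunk


# For the TRAIN stage this is equivalent to the manual loop above:
#     for samples in grouped(data, self.config.num_accumulated_batches):
#         self.run_step(samples, state, metric_reporter, report_metric)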
Example #8
    def backprop(self, state, loss):
        if state.stage != Stage.TRAIN:
            return

        with timing.time("loss.backward"):
            precision.backward(state.optimizer, loss)

        state.scheduler.step_batch()

        if self.config.max_clip_norm is not None:
            grad_norm = precision.clip_grad_norm(state.model, state.optimizer,
                                                 self.config.max_clip_norm)
        else:
            grad_norm = None

        with timing.time("optimizer.step"):
            state.optimizer.step()
        # grad_norm could be used to check grads sync in distributed training
        return grad_norm
Example #9
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        """Our run_step is a bit different, because we're wrapping the model forward
        call with model.train_batch, which arranges tensors and gets loss, etc."""
        sample_size = len(samples)
        assert sample_size <= self.config.num_accumulated_batches

        model = state.model
        self.zero_grads(state)
        for i, (batch_id, (raw_batch, batch)) in enumerate(samples):
            if cuda.DISTRIBUTED_WORLD_SIZE > 1:
                # Whenever *samples* contains more than one mini-batch, we
                # want to accumulate gradients locally and only call
                # all-reduce in the last backwards pass.
                if i < sample_size - 1:
                    # not the last micro-batch: accumulate locally, skip all-reduce
                    model.accumulate_gradients(True)
                else:
                    # last micro-batch: sync gradients in this backward pass
                    model.accumulate_gradients(False)

            with timing.time("model.train_batch"):
                loss, metric_data = model.train_batch(model, batch)
                if sample_size > 1:
                    # gradients are averaged per batch and accumulated across samples;
                    # divide by sample_size so they end up averaged per example
                    loss = loss / sample_size
            self.backprop(state, loss)

            if report_metric:
                with timing.time("add metrics"):
                    metric_reporter.add_batch_stats(
                        batch_id,
                        *metric_data,
                        **metric_reporter.batch_context(raw_batch, batch),
                    )
        # update model parameters after len(samples) forward & backward passes
        self.optimizer_step(state)
Example #10
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        """Our run_step is a bit different, because we're wrapping the model forward
        call with model.train_batch, which arranges tensors and gets loss, etc.

        Whenever "samples" contains more than one mini-batch (sample_size > 1),
        we want to accumulate gradients locally and only call all-reduce in the
        last backwards pass.
        """
        sample_size = len(samples)
        assert sample_size <= self.config.num_accumulated_batches

        model = state.model
        self.zero_grads(state)
        for idx, (batch_id, batch) in enumerate(samples):
            with contextlib_ExitStack() as exit_stack:
                # enter ddp no_sync context and fp16 delay_scale context if needed
                maybe_accumulate_gradients(exit_stack, model, idx, sample_size)
                logits = model(batch)
                targets = batch["label_ids"]
                loss = self.loss(logits, targets)
                if sample_size > 1:
                    # gradients are averaged per batch and accumulated across samples;
                    # divide by sample_size so they end up averaged per example
                    loss = loss / sample_size
                self.backprop(state, loss)

            if report_metric:
                with timing.time("add metrics"):
                    predictions = torch.max(logits, -1)[1]
                    scores = F.log_softmax(logits, dim=-1)
                    # [targets] only supplies the batch size (len(targets)) that
                    # add_batch_stats expects; metric_reporter will be rewritten
                    # rather than fixed here.
                    metric_data = (predictions, targets, scores, loss,
                                   [targets])
                    metric_reporter.add_batch_stats(
                        batch_id,
                        *metric_data,
                        # TODO merge this step into add_batch_stats once all data
                        # migration is done
                        # in new data API, we don't have raw_batch
                        **metric_reporter.batch_context(raw_batch=[],
                                                        batch=batch),
                    )
                if batch_id % self.config.num_samples_to_log_progress == 0:
                    metric_reporter.report_realtime_metric(state.stage)
        # update model parameters after len(samples) forward & backward passes
        self.optimizer_step(state)
        self.sparsification_step(state)
Example #11
        def training_backprop(loss):
            with timing.time("loss.backward"):
                precision.backward(self.optimizer, loss)
                if world_size > 1:
                    # DDP fix when some parameters don't receive grads
                    for p in model.parameters():
                        if p.requires_grad and p.grad is None:
                            p.backward(torch.zeros_like(p.data))

            if self.lr_scheduler:
                self.lr_scheduler.step_batch()

            if self.config.max_clip_norm is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), self.config.max_clip_norm)
            else:
                grad_norm = None

            with timing.time("optimizer.step"):
                self.optimizer.step()
            # grad_norm could be used to check grads sync in distributed training
            return grad_norm
Example #12
    def _run_epoch(
        self,
        stage: Stage,
        epoch: int,
        batches,
        model: Model,
        metric_reporter: MetricReporter,
        pre_batch=lambda: None,
        backprop=lambda loss: None,
        rank=0,
        num_samples_to_log_progress: int = None,
    ):
        """Our run_epoch is a bit different, because we're wrapping the model forward
        call with model.train_batch, which arranges tensors and gets loss, etc."""
        print(f"Rank {rank} worker: Running epoch #{epoch} for {stage}")
        report_metric = stage != Stage.TRAIN or self.config.report_train_metrics

        for batch_id, batch in enumerate(batches):
            pre_batch()
            with timing.time("model.train_batch"):
                loss, metric_data = model.train_batch(batch)
            with timing.time("backprop"):
                backprop(loss)
            if report_metric:
                with timing.time("add metrics"):
                    metric_reporter.add_batch_stats(
                        batch_id, *metric_data, **metric_reporter.batch_context(batch)
                    )

        metrics = None
        if report_metric:
            with timing.time("report metrics"):
                metrics = metric_reporter.report_metric(
                    model, stage, epoch, print_to_channels=(rank == 0)
                )
        else:
            metric_reporter._reset()

        return metrics
Example #13
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        """Our run_step is a bit different, because we're wrapping the model forward
        call with model.train_batch, which arranges tensors and gets loss, etc.

        Whenever "samples" contains more than one mini-batch (sample_size > 1),
        we want to accumulate gradients locally and only call all-reduce in the
        last backwards pass.
        """
        sample_size = len(samples)
        assert sample_size <= self.config.num_accumulated_batches

        model = state.model
        self.zero_grads(state)
        for idx, (batch_id, (raw_batch, batch)) in enumerate(samples):
            with maybe_no_sync(model, idx, sample_size):
                with timing.time("model.train_batch"):
                    loss, metric_data = model.train_batch(model, batch, state)
                    if sample_size > 1:
                        # gradients are averaged per batch and accumulated across samples;
                        # divide by sample_size so they end up averaged per example
                        loss = loss / sample_size
                self.backprop(state, loss)

            if report_metric:
                with timing.time("add metrics"):
                    metric_reporter.add_batch_stats(
                        batch_id,
                        *metric_data,
                        **metric_reporter.batch_context(raw_batch, batch),
                    )
        # update model parameters after len(samples) forward & backward passes
        self.optimizer_step(state)
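
maybe_no_sync is used above but not defined in the excerpt. One plausible implementation, assuming the model may or may not be a DistributedDataParallel wrapper, returns DDP's no_sync() context for every micro-batch except the last; this is a sketch, not the actual helper:

# Sketch of a maybe_no_sync-style helper (assumed, not the real implementation):
# suppress DDP's gradient all-reduce on all but the last micro-batch.
import contextlib

import torch


def maybe_no_sync_sketch(model, index: int, sample_size: int):
    is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    if is_ddp and index < sample_size - 1:
        return model.no_sync()
    return contextlib.nullcontext()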
Example #14
    def backprop(self, state, loss):
        if state.stage != Stage.TRAIN:
            return

        with timing.time("loss.backward"):
            precision.backward(state.optimizer, loss)
            if cuda.DISTRIBUTED_WORLD_SIZE > 1:
                # DDP fix when some parameters don't receive grads
                for p in state.model.parameters():
                    if p.requires_grad and p.grad is None:
                        p.backward(torch.zeros_like(p.data))

        state.scheduler.step_batch()

        if self.config.max_clip_norm is not None:
            grad_norm = precision.clip_grad_norm(state.model, state.optimizer,
                                                 self.config.max_clip_norm)
        else:
            grad_norm = None

        with timing.time("optimizer.step"):
            state.optimizer.step()
        # grad_norm could be used to check grads sync in distributed training
        return grad_norm
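
The parameter loop in this backprop backfills a zero gradient for any parameter that did not take part in the backward pass, so DDP's all-reduce does not block on it. A newer alternative in PyTorch is to let DDP track this itself via find_unused_parameters=True when wrapping the module; a sketch of that wrapping (whether it fits a given model is an assumption):

# Sketch: instead of manually back-filling zero grads, ask DDP to detect
# parameters that receive no gradient in a given iteration.
import torch
from torch.nn.parallel import DistributedDataParallel


def wrap_ddp(module: torch.nn.Module, device_id: int) -> DistributedDataParallel:
    return DistributedDataParallel(
        module,
        device_ids=[device_id],
        output_device=device_id,
        broadcast_buffers=False,
        # Marks parameters unused in this iteration instead of requiring every
        # parameter to receive a gradient before the all-reduce.
        find_unused_parameters=True,
    )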
Example #15
    def run_epoch(
        self, state: TrainingState, data: BatchIterator, metric_reporter: MetricReporter
    ):
        # This method is due for some refactoring, pushing it off because it interacts
        # with the metric reporter too much. Much of the logic here either changes in
        # the NewTaskTrainer or should change with a better metric reporter design.
        report_metric = state.stage != Stage.TRAIN or self.config.report_train_metrics
        model = state.model
        samples = []

        """
        Sometimes, a batch of inputs is too large to fit into GPU, which has to
        be split into several micro-batches. However, to improve efficiency,
        it would be helpful to only apply params/gradients sync at original batch
        boundaries instead of micro-batch boundaries.
        num_accumulated_batches specified the number of accumulating gradients
        locally before sync gradients, total training_batch_size =
        train_batch_size x num_accumulated_batches and it will improve the system
        performance by reduce the total network transfer bytes.
        """
        for sample in enumerate(data):
            samples.append(sample)
            if (
                state.stage != Stage.TRAIN
                or len(samples) == self.config.num_accumulated_batches
            ):
                self.run_step(samples, state, metric_reporter, report_metric)
                samples = []
        if samples:
            self.run_step(samples, state, metric_reporter, report_metric)
            samples = []

        metrics = None
        if report_metric:
            with timing.time("report metrics"):
                metrics = metric_reporter.report_metric(
                    model,
                    state.stage,
                    state.epoch,
                    print_to_channels=(state.rank == 0),
                    optimizer=getattr(
                        state, "optimizer", None
                    ),  # optimizer is not present during test
                )
        else:
            metric_reporter._reset()

        return metrics
Example #16
    def optimizer_step(self, state):
        if state.stage != Stage.TRAIN:
            return

        state.scheduler.step_batch()

        if self.config.max_clip_norm is not None:
            grad_norm = precision.clip_grad_norm(state.model, state.optimizer,
                                                 self.config.max_clip_norm)
        else:
            grad_norm = None

        with timing.time("optimizer.step"):
            state.optimizer.step()

        state.step_counter += 1
        # grad_norm could be used to check grads sync in distributed training
        return grad_norm
Example #17
    def optimizer_step(self, state):
        if state.stage != Stage.TRAIN:
            return

        try:
            grad_norm = state.optimizer.clip_grad_norm(
                self.config.max_clip_norm, state.model)
        except OverflowError as e:
            print(f"Gradient overflow. Skipping step, {e}")
            return None

        state.scheduler.step_batch()
        with timing.time("optimizer.step"):
            state.optimizer.step()

        state.step_counter += 1
        # grad_norm could be used to check grads sync in distributed training
        return grad_norm
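
Example #17 leans on the wrapped optimizer raising OverflowError for fp16 overflow and skips the step when that happens. With native PyTorch AMP the same skip-on-overflow behaviour comes from torch.cuda.amp.GradScaler, whose step() is a no-op for iterations with inf/nan gradients; a rough sketch under that assumption:

# Sketch of the skip-on-overflow pattern with native AMP: GradScaler.step()
# silently skips the optimizer step when inf/nan gradients are detected.
import torch


def amp_step(model, optimizer, scaler, loss, max_clip_norm=None):
    scaler.scale(loss).backward()
    grad_norm = None
    if max_clip_norm is not None:
        # Unscale first so the norm is measured on the real gradients.
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_clip_norm)
    scaler.step(optimizer)  # skipped internally if this iteration overflowed
    scaler.update()
    optimizer.zero_grad()
    return grad_norm


# scaler = torch.cuda.amp.GradScaler() would typically be created once,
# alongside the optimizer, and reused across steps.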
Example #18
    def backprop(self, state, loss):
        if state.stage != Stage.TRAIN:
            return

        with timing.time("loss.backward"):
            state.optimizer.backward(loss)
Example #19
    def train(
        self,
        training_data: BatchIterator,
        eval_data: BatchIterator,
        model: Model,
        metric_reporter: MetricReporter,
        train_config: PyTextConfig,
        rank: int = 0,
    ) -> Tuple[torch.nn.Module, Any]:
        """
        Train and eval a model; the model state will be modified. This function
        iterates over the epochs specified in the config, and for each epoch:

            1. Train model using training data, aggregate and report training results
            2. Adjust learning rate if scheduler is specified
            3. Evaluate model using evaluation data
            4. Calculate metrics based on evaluation results and select best model

        Args:
            training_data (BatchIterator): batch iterator of training data
            eval_data (BatchIterator): batch iterator of evaluation data
            model (Model): model to be trained
            metric_reporter (MetricReporter): computes metrics based on training
                output and reports results to console, file, etc.
            train_config (PyTextConfig): training config
            training_result (Optional): only meaningful for Hogwild training;
                defaults to None
            rank (int): only used in distributed training, the rank of the current
                training thread; evaluation is only done on rank 0

        Returns:
            model, best_metric: the trained model together with the best metric
        """
        state = TrainingState(model=model,
                              optimizer=self.optimizer,
                              scheduler=self.scheduler,
                              rank=rank)
        self.set_up_training(state, training_data)

        while self.continue_training(state):
            state.epoch += 1
            state.epochs_since_last_improvement += 1
            print(f"Worker {state.rank} starting epoch {state.epoch}",
                  flush=True)
            lrs = learning_rates(state.optimizer)
            print(f"Learning rate(s): {', '.join(map(str, lrs))}")

            with timing.time("train epoch"):
                state.stage = Stage.TRAIN
                state.model.train()
                self.run_epoch(state, training_data, metric_reporter)

            if not self.config.do_eval:
                continue

            with timing.time("eval epoch"):
                state.stage = Stage.EVAL
                model.eval(Stage.EVAL)
                with torch.no_grad():
                    eval_metric = self.run_epoch(state, eval_data,
                                                 metric_reporter)

            # Step the learning rate scheduler(s)
            assert eval_metric is not None
            state.scheduler.step_epoch(
                metrics=metric_reporter.get_model_select_metric(eval_metric),
                epoch=state.epoch,
            )

            # Did we train a better model?
            if metric_reporter.compare_metric(eval_metric,
                                              state.best_model_metric):
                state.epochs_since_last_improvement = 0
                state.best_model_metric = eval_metric
                self.save_checkpoint(state, train_config)

        # Only bother loading the best model for master worker
        if rank == 0 and state.best_model_state is not None:
            self.load_best_model(state)

        return state.model, state.best_model_metric
Example #20
    def train_from_state(
        self,
        state: TrainingState,
        training_data: BatchIterator,
        eval_data: BatchIterator,
        metric_reporter: MetricReporter,
        train_config: PyTextConfig,
    ) -> Tuple[torch.nn.Module, Any]:
        """
        Train and eval a model from a given training state; the training state
        will be modified. This function iterates over the epochs specified in
        the config, and for each epoch:

            1. Train model using training data, aggregate and report training results
            2. Adjust learning rate if scheduler is specified
            3. Evaluate model using evaluation data
            4. Calculate metrics based on evaluation results and select best model

        Args:
            state (TrainingState): contains the stateful information needed to
                restore a training job
            training_data (BatchIterator): batch iterator of training data
            eval_data (BatchIterator): batch iterator of evaluation data
            model (Model): model to be trained (held inside the training state)
            metric_reporter (MetricReporter): computes metrics based on training
                output and reports results to console, file, etc.
            train_config (PyTextConfig): training config

        Returns:
            model, best_metric: the trained model together with the best metric
        """
        training_data = self.set_up_training(state, training_data)
        model = state.model
        rank = state.rank
        trainable_params = sum(p.numel() for p in state.model.parameters()
                               if p.requires_grad)
        print(f"Model :{model}")
        print(f"Num trainable parameters: {trainable_params}")

        self.sparsifier.initialize(self, state, eval_data, metric_reporter,
                                   train_config)

        while self.continue_training(state):
            self.sparsifier.op_pre_epoch(self, state)
            state.epoch += 1
            state.epochs_since_last_improvement += 1
            lrs = learning_rates(state.optimizer)
            print(f"\nWorker {state.rank} starting epoch {state.epoch}")
            print(f"Learning rate(s): {', '.join(map(str, lrs))}")

            with timing.time("train epoch"):
                state.stage = Stage.TRAIN
                state.model.train()
                print(f"start training epoch {state.epoch}")
                epoch_data = training_data
                if self.config.num_batches_per_epoch:
                    # We want to limit the number of batches in the epoch;
                    # equivalent to epoch_data[:num_batches_per_epoch] for iterators.
                    # In this case we set the training data iterator to cycle earlier
                    # in the training process, so when it reaches the end it will
                    # loop back to the beginning.
                    epoch_data = itertools.islice(
                        epoch_data, self.config.num_batches_per_epoch)
                self.run_epoch(state, epoch_data, metric_reporter)

            if not self.config.do_eval:
                continue

            with timing.time("eval epoch"):
                state.stage = Stage.EVAL
                model.eval(Stage.EVAL)
                print(f"start evaluating epoch {state.epoch}")
                with torch.no_grad():
                    eval_metric = self.run_epoch(state, eval_data,
                                                 metric_reporter)

            # Step the learning rate scheduler(s)
            assert eval_metric is not None
            state.scheduler.step_epoch(
                metrics=metric_reporter.get_model_select_metric(eval_metric),
                epoch=state.epoch,
            )

            # Did we train a better model?
            better_model = metric_reporter.compare_metric(
                eval_metric, state.best_model_metric)
            if better_model:
                self.update_best_model(state, train_config, eval_metric)
            if better_model or train_config.save_all_checkpoints:
                self.save_checkpoint(state, train_config)

        if self.optimizer.finalize():
            should_update_model = True
            eval_metric = None
            if self.config.do_eval:
                state.stage = Stage.EVAL
                model.eval(Stage.EVAL)
                print(f"start evaluating finalized state")
                with torch.no_grad():
                    eval_metric = self.run_epoch(state, eval_data,
                                                 metric_reporter)
                should_update_model = metric_reporter.compare_metric(
                    eval_metric, state.best_model_metric)
            if should_update_model:
                self.update_best_model(state, train_config, eval_metric)
            if should_update_model or train_config.save_all_checkpoints:
                self.save_checkpoint(state, train_config)
        # Only bother loading the best model for master worker
        if (rank == 0 and state.best_model_state is not None
                and self.config.load_best_model_after_train):
            self.load_best_model(state)

        return state.model, state.best_model_metric
Example #21
    def train(
        self,
        train_iter: BatchIterator,
        eval_iter: BatchIterator,
        model: Model,
        metric_reporter: MetricReporter,
        train_config: PyTextConfig,
        rank: int = 0,
    ) -> Tuple[torch.nn.Module, Any]:
        """
        Train and eval a model; the model state will be modified. This function
        iterates over the epochs specified in the config, and for each epoch:

            1. Train model using training data, aggregate and report training results
            2. Adjust learning rate if scheduler is specified
            3. Evaluate model using evaluation data
            4. Calculate metrics based on evaluation results and select best model

        Args:
            train_iter (BatchIterator): batch iterator of training data
            eval_iter (BatchIterator): batch iterator of evaluation data
            model (Model): model to be trained
            metric_reporter (MetricReporter): computes metrics based on training
                output and reports results to console, file, etc.
            train_config (PyTextConfig): training config
            training_result (Optional): only meaningful for Hogwild training;
                defaults to None
            rank (int): only used in distributed training, the rank of the current
                training thread; evaluation is only done on rank 0

        Returns:
            model, best_metric: the trained model together with the best metric
        """
        with timing.time("pre-training"):
            world_size = 1
            if cuda.CUDA_ENABLED:
                model = model.cuda()
                world_size = cuda.DISTRIBUTED_WORLD_SIZE
                if world_size > 1:
                    device_id = torch.cuda.current_device()
                    model = DistributedModel(
                        module=model,
                        device_ids=[device_id],
                        output_device=device_id,
                        broadcast_buffers=False,
                    )

            best_metric = None
            last_best_epoch = 0
            if self.lr_scheduler:
                self.lr_scheduler.prepare(train_iter, self.config.epochs)
            self.optimizer = precision.wrap_optimizer(self.optimizer)

        def training_pre_batch_callback():
            if world_size > 1:
                # replace optimizer.zero_grad() here to work with DDP
                # in cases where some parameters don't receive grads at each step
                # loss.backward will set grad for params in the computation graph
                # we can thus follow which params are left out and call .backward
                # on them manually
                for p in model.parameters():
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad = None
            else:
                self.optimizer.zero_grad()

        def training_backprop(loss):
            with timing.time("loss.backward"):
                precision.backward(self.optimizer, loss)
                if world_size > 1:
                    # DDP fix when some parameters don't receive grads
                    for p in model.parameters():
                        if p.requires_grad and p.grad is None:
                            p.backward(torch.zeros_like(p.data))

            if self.lr_scheduler:
                self.lr_scheduler.step_batch()

            if self.config.max_clip_norm is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), self.config.max_clip_norm)
            else:
                grad_norm = None

            with timing.time("optimizer.step"):
                self.optimizer.step()
            # grad_norm could be used to check grads sync in distributed training
            return grad_norm

        time_start = time.time()
        best_model_state = None
        for epoch in range(1, self.config.epochs + 1):
            sys.stdout.flush()
            if self.config.target_time_limit_seconds > 0 and epoch > 1:
                time_elapsed = time.time() - time_start
                mean_epoch_time = time_elapsed / float(epoch - 1)
                expected_next_epoch_time = time_elapsed + mean_epoch_time
                if expected_next_epoch_time > self.config.target_time_limit_seconds:
                    print(
                        f"Training stopped after {epoch - 1} epochs and "
                        f"{int(time_elapsed)} seconds, due to the target max training "
                        f"time of {self.config.target_time_limit_seconds} seconds."
                    )
                    break

            print(f"Rank {rank} worker: Starting epoch #{epoch}")
            model.train()
            lrs = (str(lr) for lr in learning_rates(self.optimizer))
            print(f"Learning rate(s): {', '.join(lrs)}")

            with timing.time("epoch train"):
                self._run_epoch(
                    Stage.TRAIN,
                    epoch,
                    train_iter,
                    model,
                    metric_reporter,
                    pre_batch=training_pre_batch_callback,
                    backprop=training_backprop,
                    rank=rank,
                    num_samples_to_log_progress=(
                        self.config.num_samples_to_log_progress),
                )

            if not self.config.do_eval:
                continue

            with timing.time("epoch eval"):
                model.eval(Stage.EVAL)
                with torch.no_grad():
                    eval_metric = self._run_epoch(
                        Stage.EVAL,
                        epoch,
                        eval_iter,
                        model,
                        metric_reporter,
                        rank=rank,
                        num_samples_to_log_progress=(
                            self.config.num_samples_to_log_progress),
                    )

            # Step the learning rate scheduler(s)
            if self.lr_scheduler:
                assert eval_metric is not None
                self.lr_scheduler.step_epoch(
                    metrics=metric_reporter.get_model_select_metric(
                        eval_metric),
                    epoch=epoch,
                )

            # choose best model.
            if metric_reporter.compare_metric(eval_metric, best_metric):
                with timing.time("save checkpoint model"):
                    last_best_epoch = epoch
                    best_metric = eval_metric
                    # Only rank = 0 trainer saves modules.
                    if train_config.save_module_checkpoints and rank == 0:
                        model.save_modules(
                            base_path=train_config.modules_save_dir,
                            suffix=f"-ep{epoch}",
                        )

                    if rank == 0:
                        print(f"Rank {rank} worker: Found a better model!")
                        model_state = model.state_dict()
                        # save to cpu to avoid multiple model copies in gpu memory
                        if cuda.CUDA_ENABLED:
                            for key, state in model_state.items():
                                model_state[key] = state.cpu()
                        best_model_state = model_state

            if self.config.early_stop_after > 0 and (
                    epoch - last_best_epoch == self.config.early_stop_after):
                print(f"Rank {rank} worker: Eval metric hasn't changed for " +
                      f"{self.config.early_stop_after} epochs. Stopping now.")
                break
            sys.stdout.flush()

        if rank == 0 and best_model_state is not None:
            if cuda.CUDA_ENABLED:
                for key, state in best_model_state.items():
                    best_model_state[key] = state.cuda()
            model.load_state_dict(best_model_state)

        return model, best_metric
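
The best-model bookkeeping in Example #21 keeps the winning state_dict on CPU so only one copy of the weights occupies GPU memory, then moves it back before loading. Pulled out on its own, the pattern looks roughly like this sketch, where better_than stands in for metric_reporter.compare_metric:

# Sketch of the best-model snapshot pattern above: store the best state_dict
# on CPU during training, move it back to GPU only when restoring at the end.
# `better_than` stands in for metric_reporter.compare_metric.
import torch


def snapshot_on_cpu(model: torch.nn.Module) -> dict:
    return {key: value.detach().cpu() for key, value in model.state_dict().items()}


def keep_best(model, eval_metric, best_metric, best_state, better_than):
    if better_than(eval_metric, best_metric):
        return eval_metric, snapshot_on_cpu(model)
    return best_metric, best_state


def restore_best(model: torch.nn.Module, best_state, use_cuda: bool) -> None:
    if best_state is None:
        return
    if use_cuda:
        best_state = {key: value.cuda() for key, value in best_state.items()}
    model.load_state_dict(best_state)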