Example #1
    def train(
        self,
        training_data: BatchIterator,
        eval_data: BatchIterator,
        model: Model,
        metric_reporter: MetricReporter,
        train_config: PyTextConfig,
        rank: int = 0,
    ) -> Tuple[torch.nn.Module, Any]:
        """
        Train and eval a model, the model states will be modified.
        Args:
            train_iter (BatchIterator): batch iterator of training data
            eval_iter (BatchIterator): batch iterator of evaluation data
            model (Model): model to be trained
            metric_reporter (MetricReporter): compute metric based on training
            output and report results to console, file.. etc
            train_config (PyTextConfig): training config
            training_result (Optional): only meaningful for Hogwild training. default
            is None
            rank (int): only used in distributed training, the rank of the current
            training thread, evaluation will only be done in rank 0

        Returns:
            model, best_metric: the trained model together with the best metric
        """
        state = TrainingState(
            model=model,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            sparsifier=self.sparsifier,
            rank=rank,
        )
        return self.train_from_state(state, training_data, eval_data,
                                     metric_reporter, train_config)
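As a usage note, the snippet below is only a minimal sketch of how this entry point might be wired up, assuming the trainer, task, data iterators, metric reporter, and config objects have already been built (for instance as in the checkpoint tests further down); the run_training helper itself is illustrative, not part of PyText.

# Hypothetical wrapper around the train() entry point above; every object it
# receives is assumed to be constructed elsewhere, e.g. via create_task() as in
# Example #4 below.
def run_training(trainer, task, training_data, eval_data, metric_reporter,
                 train_config, rank=0):
    # train() returns the trained torch.nn.Module and the best eval metric
    trained_model, best_metric = trainer.train(
        training_data,
        eval_data,
        task.model,
        metric_reporter,
        train_config,
        rank=rank,
    )
    return trained_model, best_metric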
Example #2
    def test(self, test_iter, model, metric_reporter: MetricReporter):
        state = TrainingState(stage=Stage.TEST, model=model, epoch=1)
        if cuda.CUDA_ENABLED:
            state.model.cuda()
        state.model.eval()
        with torch.no_grad():
            return self.run_epoch(state, test_iter, metric_reporter)
Example #3
    def update_best_model(self, state: TrainingState,
                          train_config: PyTextConfig, eval_metric):
        # This should be updated by all workers so they agree on when to stop training
        # when `early_stop_after` is specified.
        state.epochs_since_last_improvement = 0
        state.best_model_metric = eval_metric
        print(f"Found a better model!")

        # Only one worker should save checkpoints
        if state.rank != 0:
            return

        model_state = state.model.state_dict()
        # save to cpu to avoid multiple model copies in gpu memory
        if cuda.CUDA_ENABLED:
            self.move_state_dict_to_cpu(model_state)
        state.best_model_state = model_state
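The comment at the top of update_best_model points at the early-stopping contract: every worker resets epochs_since_last_improvement so they agree on when early_stop_after should end training. The check itself is not shown in these examples; the snippet below is only a plausible sketch of it, where early_stop_after is an assumed config value and should_stop_early is a made-up helper name.

# Hypothetical early-stopping check. Only epochs_since_last_improvement comes
# from the code above; early_stop_after and the helper name are assumptions.
def should_stop_early(state, early_stop_after):
    return (
        early_stop_after > 0
        and state.epochs_since_last_improvement >= early_stop_after
    )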
Example #4
        def test_load_checkpoint(self):
            with tempfile.NamedTemporaryFile() as checkpoint_file:
                train_data = tests_module.test_file("train_data_tiny.tsv")
                eval_data = tests_module.test_file("test_data_tiny.tsv")
                config = PyTextConfig(
                    task=DocumentClassificationTask.Config(data=Data.Config(
                        source=TSVDataSource.Config(
                            train_filename=train_data,
                            eval_filename=eval_data,
                            field_names=["label", "slots", "text"],
                        ))),
                    version=LATEST_VERSION,
                    save_snapshot_path=checkpoint_file.name,
                )
                task = create_task(config.task)
                model = task.model
                # test checkpoint saving and loading
                optimizer = create_optimizer(Adam.Config(), model)
                scheduler = create_scheduler(Scheduler.Config(), optimizer)
                training_state = TrainingState(
                    model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    start_time=0,
                    epoch=0,
                    rank=0,
                    stage=Stage.TRAIN,
                    epochs_since_last_improvement=0,
                    best_model_state=None,
                    best_model_metric=None,
                    tensorizers=None,
                )

                checkpoint_path = checkpoint_file.name
                save(
                    config,
                    model,
                    None,
                    task.data.tensorizers,
                    training_state,
                    checkpoint_file,
                )
                task_restored, config_restored, training_state_restored = load(
                    checkpoint_path)
                optimizer_restored = training_state_restored.optimizer
                scheduler_restored = training_state_restored.scheduler
                self.assertOptimizerEqual(optimizer, optimizer_restored)
                self.assertIsNotNone(scheduler_restored)
                self.assertEqual(config, config_restored)
                self.assertModulesEqual(model, task_restored.model)
                model.eval()
                task_restored.model.eval()

                inputs = torch.LongTensor([[1, 2, 3]]), torch.LongTensor([3])
                self.assertEqual(
                    model(*inputs).tolist(),
                    task_restored.model(*inputs).tolist())
Example #5
    def set_up_training(self, state: TrainingState, training_data: BatchIterator):
        if cuda.CUDA_ENABLED:
            state.model.cuda()
        state.scheduler.prepare(training_data, self.config.epochs)

        if cuda.DISTRIBUTED_WORLD_SIZE > 1:
            device_id = torch.cuda.current_device()
            state.model = DistributedModel(
                module=state.model,
                device_ids=[device_id],
                output_device=device_id,
                broadcast_buffers=False,
                find_unused_parameters=state.model.find_unused_parameters,
            )
        state.start_time = time.time()

        if self.config.num_batches_per_epoch:
            # Set the training_data iterator to cycle, so it will never run out,
            # but rather after reaching the end will loop back to the beginning.
            training_data = cycle(training_data)
        return training_data
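The cycling trick set up above is completed in train_from_state (Example #8), which caps each epoch with itertools.islice. The standard-library snippet below illustrates that cycle-then-slice pattern on a made-up batch list, independent of PyText.

from itertools import cycle, islice

# Toy stand-in for a batch iterator: cycle() makes the stream endless,
# islice() then caps each "epoch" at a fixed number of batches.
batches = ["batch0", "batch1", "batch2"]
endless = cycle(batches)

num_batches_per_epoch = 5
for epoch in range(2):
    epoch_data = islice(endless, num_batches_per_epoch)
    print(f"epoch {epoch}: {list(epoch_data)}")
# epoch 0: ['batch0', 'batch1', 'batch2', 'batch0', 'batch1']
# epoch 1: ['batch2', 'batch0', 'batch1', 'batch2', 'batch0']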
Example #6
    def test_load_checkpoint_in_dist_training(self):
        with tempfile.NamedTemporaryFile() as checkpoint_file:
            train_data = tests_module.test_file("train_data_tiny.tsv")
            eval_data = tests_module.test_file("test_data_tiny.tsv")
            config = PyTextConfig(
                task=DocumentClassificationTask.Config(data=Data.Config(
                    source=BlockShardedTSVDataSource.Config(
                        train_filename=train_data,
                        eval_filename=eval_data,
                        field_names=["label", "slots", "text"],
                    ))),
                version=LATEST_VERSION,
                save_snapshot_path=checkpoint_file.name,
            )
            task = create_task(config.task)
            model = task.model
            # test checkpoint saving and loading
            optimizer = create_optimizer(Adam.Config(), model)
            scheduler = create_scheduler(Scheduler.Config(), optimizer)
            training_state = TrainingState(
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                start_time=0,
                epoch=0,
                rank=0,
                stage=Stage.TRAIN,
                epochs_since_last_improvement=0,
                best_model_state=None,
                best_model_metric=None,
                tensorizers=task.data.tensorizers,
            )

            id = "epoch-1"
            saved_path = save(config, model, None, task.data.tensorizers,
                              training_state, id)
            new_rank = 2
            new_world_size = 4
            task_restored, config_restored, training_state_restored = load(
                saved_path, rank=new_rank, world_size=new_world_size)
            self.assertCheckpointEqual(
                model,
                config,
                training_state,
                task_restored.model,
                config_restored,
                training_state_restored,
            )
            self.assertEqual(task_restored.data.data_source.rank, new_rank)
            self.assertEqual(task_restored.data.data_source.world_size,
                             new_world_size)
Example #7
    def train(
        self,
        training_data: DataLoader,
        eval_data: DataLoader,
        model: Model,
        optimizer: Optimizer,
        label_names: List[str],
        scheduler: Scheduler = None,
        sparsifier: Sparsifier = None,
        metric_reporter: MetricReporter = None,
        train_config: PyTextConfig = None,
        rank: int = 0,
    ) -> Tuple[torch.nn.Module, Any]:
        # temp workaround to minimize changes to TaskTrainer
        if not train_config:
            train_config = PyTextConfig(
                task=NewTask.Config(model=RoBERTa.Config), version=20)
        if scheduler:
            self.scheduler = scheduler
        if sparsifier:
            self.sparsifier = sparsifier

        state = TrainingState(
            model=model,
            optimizer=optimizer,
            scheduler=self.scheduler,
            sparsifier=self.sparsifier,
            rank=rank,
        )
        metric_reporter_config = ClassificationMetricReporter.Config(
            output_path="/tmp/test_out.txt",
            pep_format=False,
            model_select_metric=ComparableClassificationMetric.ACCURACY,  # in json: "accuracy"
            target_label=None,
            text_column_names=["text"],
            additional_column_names=[],
            recall_at_precision_thresholds=[0.2, 0.4, 0.6, 0.8, 0.9],
        )
        metric_reporter = ClassificationMetricReporter.from_config_and_label_names(
            config=metric_reporter_config, label_names=label_names)
        return self.train_from_state(state, training_data, eval_data,
                                     metric_reporter, train_config)
Example #8
    def train_from_state(
        self,
        state: TrainingState,
        training_data: BatchIterator,
        eval_data: BatchIterator,
        metric_reporter: MetricReporter,
        train_config: PyTextConfig,
    ) -> Tuple[torch.nn.Module, Any]:
        """
        Train and eval a model from a given training state will be modified.
        This function iterates epochs specified in config, and for each epoch do:

            1. Train model using training data, aggregate and report training results
            2. Adjust learning rate if scheduler is specified
            3. Evaluate model using evaluation data
            4. Calculate metrics based on evaluation results and select best model

        Args:
            training_state (TrainingState): contrains stateful information to be
            able to restore a training job
            train_iter (BatchIterator): batch iterator of training data
            eval_iter (BatchIterator): batch iterator of evaluation data
            model (Model): model to be trained
            metric_reporter (MetricReporter): compute metric based on training
                output and report results to console, file.. etc
            train_config (PyTextConfig): training config

        Returns:
            model, best_metric: the trained model together with the best metric
        """
        training_data = self.set_up_training(state, training_data)
        model = state.model
        rank = state.rank
        trainable_params = sum(p.numel() for p in state.model.parameters()
                               if p.requires_grad)
        print(f"Model :{model}")
        print(f"Num trainable parameters: {trainable_params}")

        self.sparsifier.initialize(self, state, eval_data, metric_reporter,
                                   train_config)

        while self.continue_training(state):
            self.sparsifier.op_pre_epoch(self, state)
            state.epoch += 1
            state.epochs_since_last_improvement += 1
            lrs = learning_rates(state.optimizer)
            print(f"\nWorker {state.rank} starting epoch {state.epoch}")
            print(f"Learning rate(s): {', '.join(map(str, lrs))}")

            with timing.time("train epoch"):
                state.stage = Stage.TRAIN
                state.model.train()
                print(f"start training epoch {state.epoch}")
                epoch_data = training_data
                if self.config.num_batches_per_epoch:
                    # We want to limit the number of batches in the epoch;
                    # equivalent to epoch_data[:num_batches_per_epoch] for iterators.
                    # In this case we set the training data iterator to cycle earlier
                    # in the training process, so when it reaches the end it will
                    # loop back to the beginning.
                    epoch_data = itertools.islice(
                        epoch_data, self.config.num_batches_per_epoch)
                self.run_epoch(state, epoch_data, metric_reporter)

            if not self.config.do_eval:
                continue

            with timing.time("eval epoch"):
                state.stage = Stage.EVAL
                model.eval(Stage.EVAL)
                print(f"start evaluating epoch {state.epoch}")
                with torch.no_grad():
                    eval_metric = self.run_epoch(state, eval_data,
                                                 metric_reporter)

            # Step the learning rate scheduler(s)
            assert eval_metric is not None
            state.scheduler.step_epoch(
                metrics=metric_reporter.get_model_select_metric(eval_metric),
                epoch=state.epoch,
            )

            # Did we train a better model?
            better_model = metric_reporter.compare_metric(
                eval_metric, state.best_model_metric)
            if better_model:
                self.update_best_model(state, train_config, eval_metric)
            if better_model or train_config.save_all_checkpoints:
                self.save_checkpoint(state, train_config)

        if self.optimizer.finalize():
            should_update_model = True
            eval_metric = None
            if self.config.do_eval:
                state.stage = Stage.EVAL
                model.eval(Stage.EVAL)
                print(f"start evaluating finalized state")
                with torch.no_grad():
                    eval_metric = self.run_epoch(state, eval_data,
                                                 metric_reporter)
                should_update_model = metric_reporter.compare_metric(
                    eval_metric, state.best_model_metric)
            if should_update_model:
                self.update_best_model(state, train_config, eval_metric)
            if should_update_model or train_config.save_all_checkpoints:
                self.save_checkpoint(state, train_config)
        # Only bother loading the best model for master worker
        if (rank == 0 and state.best_model_state is not None
                and self.config.load_best_model_after_train):
            self.load_best_model(state)

        return state.model, state.best_model_metric
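Because load() (Examples #4 and #6) hands back a fully populated TrainingState, resuming an interrupted run amounts to feeding that state into train_from_state. The sketch below composes only calls shown in these examples; the helper name and the checkpoint path argument are placeholders, and load is assumed to be the same function used in the checkpoint tests.

# Hypothetical resume flow built from load() + train_from_state() as shown above.
def resume_from_checkpoint(trainer, checkpoint_path, training_data, eval_data,
                           metric_reporter, train_config):
    task_restored, config_restored, training_state_restored = load(checkpoint_path)
    # Continue the epoch loop from the restored model/optimizer/scheduler state
    return trainer.train_from_state(
        training_state_restored,
        training_data,
        eval_data,
        metric_reporter,
        train_config,
    )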
Example #9
    def test(self, test_iter, model, metric_reporter: MetricReporter):
        state = TrainingState(stage=Stage.TEST, model=model, epoch=1)
        if cuda.CUDA_ENABLED:
            state.model.cuda()
        state.model.eval()
        return self.eval_from_state(state, test_iter, metric_reporter)