Example #1
    def train(
        self,
        training_data: BatchIterator,
        eval_data: BatchIterator,
        model: Model,
        metric_reporter: MetricReporter,
        train_config: PyTextConfig,
        rank: int = 0,
    ) -> Tuple[torch.nn.Module, Any]:
        """
        Train and eval a model, the model states will be modified.
        Args:
            train_iter (BatchIterator): batch iterator of training data
            eval_iter (BatchIterator): batch iterator of evaluation data
            model (Model): model to be trained
            metric_reporter (MetricReporter): compute metric based on training
            output and report results to console, file.. etc
            train_config (PyTextConfig): training config
            training_result (Optional): only meaningful for Hogwild training. default
            is None
            rank (int): only used in distributed training, the rank of the current
            training thread, evaluation will only be done in rank 0

        Returns:
            model, best_metric: the trained model together with the best metric
        """
        state = TrainingState(
            model=model,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            sparsifier=self.sparsifier,
            rank=rank,
        )
        return self.train_from_state(state, training_data, eval_data,
                                     metric_reporter, train_config)
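A minimal usage sketch of this method (hypothetical setup: the trainer instance, iterators, model, metric reporter, and config are all assumed to be constructed elsewhere):

    # Hypothetical names; only the call signature comes from the example above.
    trained_model, best_metric = trainer.train(
        training_data=train_iter,
        eval_data=eval_iter,
        model=model,
        metric_reporter=metric_reporter,
        train_config=train_config,
        rank=0,  # evaluation only happens on rank 0
    )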
Example #2
    def test(self, test_iter, model, metric_reporter: MetricReporter):
        state = TrainingState(stage=Stage.TEST, model=model, epoch=1)
        if cuda.CUDA_ENABLED:
            state.model.cuda()
        state.model.eval()
        with torch.no_grad():
            return self.run_epoch(state, test_iter, metric_reporter)
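The body follows the standard PyTorch inference pattern: switch the model to eval mode (disables dropout and uses batch-norm running statistics) and wrap the forward passes in torch.no_grad() so no autograd graph is built. A self-contained sketch of just that pattern, independent of PyText:

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 2), nn.Dropout(0.5))
    model.eval()                  # deterministic: dropout becomes a no-op
    with torch.no_grad():         # no gradient bookkeeping during inference
        out = model(torch.randn(1, 4))
    assert not out.requires_grad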
Example #3
        def test_load_checkpoint(self):
            with tempfile.NamedTemporaryFile() as checkpoint_file:
                train_data = tests_module.test_file("train_data_tiny.tsv")
                eval_data = tests_module.test_file("test_data_tiny.tsv")
                config = PyTextConfig(
                    task=DocumentClassificationTask.Config(data=Data.Config(
                        source=TSVDataSource.Config(
                            train_filename=train_data,
                            eval_filename=eval_data,
                            field_names=["label", "slots", "text"],
                        ))),
                    version=LATEST_VERSION,
                    save_snapshot_path=checkpoint_file.name,
                )
                task = create_task(config.task)
                model = task.model
                # test checkpoint saving and loading
                optimizer = create_optimizer(Adam.Config(), model)
                scheduler = create_scheduler(Scheduler.Config(), optimizer)
                training_state = TrainingState(
                    model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    start_time=0,
                    epoch=0,
                    rank=0,
                    stage=Stage.TRAIN,
                    epochs_since_last_improvement=0,
                    best_model_state=None,
                    best_model_metric=None,
                    tensorizers=None,
                )

                checkpoint_path = checkpoint_file.name
                save(
                    config,
                    model,
                    None,
                    task.data.tensorizers,
                    training_state,
                    checkpoint_file,
                )
                task_restored, config_restored, training_state_restored = load(
                    checkpoint_path)
                optimizer_restored = training_state_restored.optimizer
                scheduler_restored = training_state_restored.scheduler
                self.assertOptimizerEqual(optimizer, optimizer_restored)
                self.assertIsNotNone(scheduler_restored)
                self.assertEqual(config, config_restored)
                self.assertModulesEqual(model, task_restored.model)
                model.eval()
                task_restored.model.eval()

                inputs = torch.LongTensor([[1, 2, 3]]), torch.LongTensor([3])
                self.assertEqual(
                    model(*inputs).tolist(),
                    task_restored.model(*inputs).tolist())
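The round trip this test verifies mirrors the standard PyTorch checkpoint pattern; a minimal self-contained sketch using plain torch (not the PyText save/load helpers used above):

    import tempfile

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    with tempfile.NamedTemporaryFile() as f:
        torch.save(model.state_dict(), f.name)        # save parameters only
        restored = nn.Linear(4, 2)                    # rebuild the architecture
        restored.load_state_dict(torch.load(f.name))  # then load the weights

    model.eval()
    restored.eval()
    x = torch.randn(1, 4)
    assert torch.equal(model(x), restored(x))         # identical outputs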
Example #4
    def test_load_checkpoint_in_dist_training(self):
        with tempfile.NamedTemporaryFile() as checkpoint_file:
            train_data = tests_module.test_file("train_data_tiny.tsv")
            eval_data = tests_module.test_file("test_data_tiny.tsv")
            config = PyTextConfig(
                task=DocumentClassificationTask.Config(data=Data.Config(
                    source=BlockShardedTSVDataSource.Config(
                        train_filename=train_data,
                        eval_filename=eval_data,
                        field_names=["label", "slots", "text"],
                    ))),
                version=LATEST_VERSION,
                save_snapshot_path=checkpoint_file.name,
            )
            task = create_task(config.task)
            model = task.model
            # test checkpoint saving and loading
            optimizer = create_optimizer(Adam.Config(), model)
            scheduler = create_scheduler(Scheduler.Config(), optimizer)
            training_state = TrainingState(
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                start_time=0,
                epoch=0,
                rank=0,
                stage=Stage.TRAIN,
                epochs_since_last_improvement=0,
                best_model_state=None,
                best_model_metric=None,
                tensorizers=task.data.tensorizers,
            )

            id = "epoch-1"
            saved_path = save(config, model, None, task.data.tensorizers,
                              training_state, id)
            new_rank = 2
            new_world_size = 4
            task_restored, config_restored, training_state_restored = load(
                saved_path, rank=new_rank, world_size=new_world_size)
            self.assertCheckpointEqual(
                model,
                config,
                training_state,
                task_restored.model,
                config_restored,
                training_state_restored,
            )
            self.assertEqual(task_restored.data.data_source.rank, new_rank)
            self.assertEqual(task_restored.data.data_source.world_size,
                             new_world_size)
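Note the resharding step: rank and world_size are passed straight through load so that the restored BlockShardedTSVDataSource reads only the shard belonging to the new worker. A hedged sketch of that resume call (the rank and world size values are placeholders):

    # Placeholder values; `load` is the same PyText checkpoint loader as above.
    task, config, training_state = load(saved_path, rank=2, world_size=4)
    assert task.data.data_source.rank == 2
    assert task.data.data_source.world_size == 4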
Example #5
    def train(
        self,
        training_data: DataLoader,
        eval_data: DataLoader,
        model: Model,
        optimizer: Optimizer,
        label_names: List[str],
        scheduler: Scheduler = None,
        sparsifier: Sparsifier = None,
        metric_reporter: MetricReporter = None,
        train_config: PyTextConfig = None,
        rank: int = 0,
    ) -> Tuple[torch.nn.Module, Any]:
        # temp workaround to minimize changes to TaskTrainer
        if not train_config:
            train_config = PyTextConfig(
                task=NewTask.Config(model=RoBERTa.Config()), version=20)
        if scheduler:
            self.scheduler = scheduler
        if sparsifier:
            self.sparsifier = sparsifier

        state = TrainingState(
            model=model,
            optimizer=optimizer,
            scheduler=self.scheduler,
            sparsifier=self.sparsifier,
            rank=rank,
        )
        metric_reporter_config = ClassificationMetricReporter.Config(
            output_path="/tmp/test_out.txt",
            pep_format=False,
            model_select_metric=ComparableClassificationMetric.ACCURACY,  # in json: "accuracy"
            target_label=None,
            text_column_names=["text"],
            additional_column_names=[],
            recall_at_precision_thresholds=[0.2, 0.4, 0.6, 0.8, 0.9],
        )
        metric_reporter = ClassificationMetricReporter.from_config_and_label_names(
            config=metric_reporter_config, label_names=label_names)
        return self.train_from_state(state, training_data, eval_data,
                                     metric_reporter, train_config)
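A hedged usage sketch of this variant (all names are placeholders; assumes DataLoaders, a Model, and an Optimizer were built elsewhere, and relies on the defaults above for the scheduler, sparsifier, metric reporter, and config):

    # Hypothetical call; "negative"/"positive" is an assumed label set.
    trained_model, best_metric = trainer.train(
        training_data=train_loader,
        eval_data=eval_loader,
        model=model,
        optimizer=optimizer,
        label_names=["negative", "positive"],
    )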
Example #6
    def train(
        self,
        training_data: BatchIterator,
        eval_data: BatchIterator,
        model: Model,
        metric_reporter: MetricReporter,
        train_config: PyTextConfig,
        rank: int = 0,
    ) -> Tuple[torch.nn.Module, Any]:
        """
        Train and eval a model, the model states will be modified. This function
        iterates epochs specified in config, and for each epoch do:

            1. Train model using training data, aggregate and report training results
            2. Adjust learning rate if scheduler is specified
            3. Evaluate model using evaluation data
            4. Calculate metrics based on evaluation results and select best model

        Args:
            train_iter (BatchIterator): batch iterator of training data
            eval_iter (BatchIterator): batch iterator of evaluation data
            model (Model): model to be trained
            metric_reporter (MetricReporter): compute metric based on training
                output and report results to console, file.. etc
            train_config (PyTextConfig): training config
            training_result (Optional): only meaningful for Hogwild training. default
                is None
            rank (int): only used in distributed training, the rank of the current
                training thread, evaluation will only be done in rank 0

        Returns:
            model, best_metric: the trained model together with the best metric
        """
        state = TrainingState(model=model,
                              optimizer=self.optimizer,
                              scheduler=self.scheduler,
                              rank=rank)
        training_data = self.set_up_training(state, training_data)
        trainable_params = sum(p.numel() for p in state.model.parameters()
                               if p.requires_grad)
        print(f"Num trainable parameters: {trainable_params}")

        while self.continue_training(state):
            state.epoch += 1
            state.epochs_since_last_improvement += 1
            lrs = learning_rates(state.optimizer)
            print(f"\nWorker {state.rank} starting epoch {state.epoch}",
                  flush=True)
            print(f"Learning rate(s): {', '.join(map(str, lrs))}")

            with timing.time("train epoch"):
                state.stage = Stage.TRAIN
                state.model.train()
                print(f"start training epoch {state.epoch}", flush=True)
                epoch_data = training_data
                if self.config.num_batches_per_epoch:
                    # We want to limit the number of batches in the epoch;
                    # equivalent to epoch_data[:num_batches_per_epoch] for iterators.
                    # In this case we set the training data iterator to cycle earlier
                    # in the training process, so when it reaches the end it will
                    # loop back to the beginning.
                    epoch_data = itertools.islice(
                        epoch_data, self.config.num_batches_per_epoch)
                self.run_epoch(state, epoch_data, metric_reporter)

            if not self.config.do_eval:
                continue

            with timing.time("eval epoch"):
                state.stage = Stage.EVAL
                model.eval(Stage.EVAL)
                print(f"start evaluating epoch {state.epoch}", flush=True)
                with torch.no_grad():
                    eval_metric = self.run_epoch(state, eval_data,
                                                 metric_reporter)

            # Step the learning rate scheduler(s)
            assert eval_metric is not None
            state.scheduler.step_epoch(
                metrics=metric_reporter.get_model_select_metric(eval_metric),
                epoch=state.epoch,
            )

            # Did we train a better model?
            better_model = metric_reporter.compare_metric(
                eval_metric, state.best_model_metric)
            if better_model:
                self.update_best_model(state, train_config, eval_metric)
            if better_model or train_config.save_all_checkpoints:
                self.save_checkpoint(state, train_config)

        if self.optimizer.finalize():
            state.stage = Stage.EVAL
            model.eval(Stage.EVAL)
            print(f"start evaluating finalized state", flush=True)
            with torch.no_grad():
                eval_metric = self.run_epoch(state, eval_data, metric_reporter)
            better_model = metric_reporter.compare_metric(
                eval_metric, state.best_model_metric)
            if better_model:
                self.update_best_model(state, train_config, eval_metric)
            if better_model or train_config.save_all_checkpoints:
                self.save_checkpoint(state, train_config)
        # Only bother loading the best model for master worker
        if rank == 0 and state.best_model_state is not None:
            self.load_best_model(state)

        return state.model, state.best_model_metric
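The num_batches_per_epoch handling above relies on a standard itertools idiom: the training iterator cycles endlessly, and islice takes a fixed-size slice of it per epoch. A self-contained illustration:

    import itertools

    batches = itertools.cycle(["b1", "b2", "b3"])  # endless batch iterator
    epoch = itertools.islice(batches, 5)           # exactly 5 batches per epoch
    print(list(epoch))                             # ['b1', 'b2', 'b3', 'b1', 'b2']
    print(list(itertools.islice(batches, 2)))      # next epoch resumes: ['b3', 'b1']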
Example #7
    def test(self, test_iter, model, metric_reporter: MetricReporter):
        state = TrainingState(stage=Stage.TEST, model=model, epoch=1)
        if cuda.CUDA_ENABLED:
            state.model.cuda()
        state.model.eval()
        return self.eval_from_state(state, test_iter, metric_reporter)
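Unlike Example #2, this variant delegates to eval_from_state rather than wrapping run_epoch in torch.no_grad() directly, so gradient disabling is presumably handled inside that helper. A hypothetical call (names are placeholders):

    # Hypothetical usage; only the signature comes from the example above.
    test_metric = trainer.test(test_iter, model, metric_reporter)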