def _get_config_with_export_list(
     self,
     task_class: Type[NewTask],
     model_class: Type[Model],
     test_file_metadata: TestFileMetadata,
 ) -> PyTextConfig:
     return PyTextConfig(
         task=task_class.Config(
             data=Data.Config(
                 source=TSVDataSource.Config(
                     train_filename=test_file_metadata.filename,
                     eval_filename=test_file_metadata.filename,
                     test_filename=test_file_metadata.filename,
                     field_names=test_file_metadata.field_names,
                 ),
                 batcher=PoolingBatcher.Config(train_batch_size=1,
                                               test_batch_size=1),
             ),
             trainer=TaskTrainer.Config(epochs=1),
             model=model_class.Config(
                 inputs=type(model_class.Config.inputs)(
                     dense=FloatListTensorizer.Config(
                         column=test_file_metadata.dense_col_name,
                         error_check=True,
                         dim=test_file_metadata.dense_feat_dim,
                     ))),
         ),
         use_tensorboard=False,
         use_cuda_if_available=False,
         export=ExportConfig(
             export_torchscript_path="/tmp/model_torchscript.pt"),
         version=LATEST_VERSION,
     )
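A hedged usage sketch for the helper above (not from the original page). `test_file_metadata` is assumed to be a TestFileMetadata, e.g. obtained from get_test_file_metadata(...) as in Example #3 below; create_task and save are used exactly as in Example #2.

    # hypothetical call site inside a test method of the same class
    config = self._get_config_with_export_list(
        DocumentClassificationTask, DocModel, test_file_metadata
    )
    task = create_task(config.task)
    save(config, task.model, meta=None, tensorizers=task.data.tensorizers)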
Example #2
    def test_load_saved_model(self):
        with tempfile.NamedTemporaryFile() as snapshot_file:
            train_data = tests_module.test_file("train_data_tiny.tsv")
            eval_data = tests_module.test_file("test_data_tiny.tsv")
            config = PyTextConfig(
                task=DocumentClassificationTask.Config(
                    data=Data.Config(
                        source=TSVDataSource.Config(
                            train_filename=train_data,
                            eval_filename=eval_data,
                            field_names=["label", "slots", "text"],
                        )
                    )
                ),
                version=LATEST_VERSION,
                save_snapshot_path=snapshot_file.name,
            )
            task = create_task(config.task)
            model = task.model

            save(config, model, meta=None, tensorizers=task.data.tensorizers)
            task2, config2 = load(snapshot_file.name)

            self.assertEqual(config, config2)
            self.assertModulesEqual(model, task2.model)

            model.eval()
            task2.model.eval()

            inputs = torch.LongTensor([[1, 2, 3]]), torch.LongTensor([3])
            self.assertEqual(model(*inputs).tolist(), task2.model(*inputs).tolist())
Example #3
 def _get_pytext_config(
     self,
     test_file_name: TestFileName,
     task_class: Type[NewTask],
     model_class: Type[Model],
 ) -> PyTextConfig:
     test_file_metadata = get_test_file_metadata(test_file_name)
     return PyTextConfig(
         task=task_class.Config(
             data=Data.Config(
                 source=TSVDataSource.Config(
                     train_filename=test_file_metadata.filename,
                     eval_filename=test_file_metadata.filename,
                     test_filename=test_file_metadata.filename,
                     field_names=test_file_metadata.field_names,
                 ),
                 batcher=Batcher.Config(
                 ),  # Use Batcher to avoid shuffling.
             ),
             trainer=TaskTrainer.Config(epochs=1),
             model=model_class.Config(
                 inputs=type(model_class.Config.inputs)(
                     dense=FloatListTensorizer.Config(
                         column=test_file_metadata.dense_col_name,
                         dim=test_file_metadata.dense_feat_dim,
                     ))),
         ),
         use_tensorboard=False,
         use_cuda_if_available=False,
         version=LATEST_VERSION,
     )
Example #4
    def test_batch_predict_caffe2_model(self):
        with tempfile.NamedTemporaryFile() as snapshot_file, tempfile.NamedTemporaryFile() as caffe2_model_file:
            train_data = tests_module.test_file("train_data_tiny.tsv")
            eval_data = tests_module.test_file("test_data_tiny.tsv")
            config = PyTextConfig(
                task=DocumentClassificationTask.Config(data=Data.Config(
                    source=TSVDataSource.Config(
                        train_filename=train_data,
                        eval_filename=eval_data,
                        test_filename=eval_data,
                        field_names=["label", "slots", "text"],
                    ))),
                version=LATEST_VERSION,
                save_snapshot_path=snapshot_file.name,
                export_caffe2_path=caffe2_model_file.name,
            )
            task = create_task(config.task)
            task.export(task.model, caffe2_model_file.name)
            model = task.model
            save(config, model, meta=None, tensorizers=task.data.tensorizers)

            results = batch_predict_caffe2_model(snapshot_file.name,
                                                 caffe2_model_file.name)
            self.assertEqual(4, len(results))
Example #5
def gen_config_impl(task_name, options):
    task_class_set = find_config_class(task_name)
    if not task_class_set:
        raise Exception(f"Unknown task class: {task_name}")
    elif len(task_class_set) > 1:
        raise Exception(f"Multiple tasks named {task_name}: {task_class_set}")

    task_class = next(iter(task_class_set))
    root = PyTextConfig(task=task_class.Config())

    # Use components listed in options instead of defaults
    for opt in options:
        replace_class_set = find_config_class(opt)
        if not replace_class_set:
            raise Exception(f"Not a component class: {opt}")
        elif len(replace_class_set) > 1:
            raise Exception(
                f"Multiple component named {opt}: {replace_class_set}")
        replace_class = next(iter(replace_class_set))
        found = replace_components(root, opt, set(replace_class.__bases__))
        if found:
            eprint("INFO - Applying option:", "->".join(reversed(found)), "=",
                   opt)
            obj = root
            for k in reversed(found[1:]):
                obj = getattr(obj, k)
            if hasattr(replace_class, "Config"):
                setattr(obj, found[0], replace_class.Config())
            else:
                setattr(obj, found[0], replace_class())
        else:
            raise Exception(f"Unknown option: {opt}")
    return config_to_json(PyTextConfig, root)
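For illustration only (not part of the original snippet): assuming the task class is already registered so find_config_class can resolve the short name, and that config_to_json returns a plain JSON-serializable dict, the generated config can be pretty-printed with the standard json module.

    import json

    # illustrative call; the task name must resolve via find_config_class
    cfg_json = gen_config_impl("DocumentClassificationTask", options=[])
    print(json.dumps(cfg_json, indent=2))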
Example #6
def prepare_task(
    config: PyTextConfig,
    dist_init_url: str = None,
    device_id: int = 0,
    rank: int = 0,
    world_size: int = 1,
    metric_channels: Optional[List[Channel]] = None,
    metadata: CommonMetadata = None,
) -> Tuple[Task_Deprecated, TrainingState]:
    if world_size > 1 and config.random_seed is None:
        msg = (
            "Must set random seed when using world_size > 1, so that parameters have "
            "same initialization across workers."
        )
        raise ValueError(msg)

    if rank == 0:
        print("\nParameters: {}\n".format(config), flush=True)
    _set_cuda(config.use_cuda_if_available, device_id, world_size)
    _set_fp16(config.use_fp16, rank)
    _set_distributed(
        rank,
        world_size,
        dist_init_url,
        device_id,
        config.gpu_streams_for_distributed_training,
    )

    if config.random_seed is not None:
        set_random_seeds(config.random_seed, config.use_deterministic_cudnn)

    training_state = None

    if config.auto_resume_from_snapshot:
        # if there are existing checkpoints, resume from the latest one
        latest_snapshot_path = get_latest_checkpoint_path(
            os.path.dirname(config.save_snapshot_path)
        )
        if latest_snapshot_path:
            config.load_snapshot_path = latest_snapshot_path

    if config.load_snapshot_path:
        assert PathManager.isfile(config.load_snapshot_path)
        if config.use_config_from_snapshot:
            task, _, training_state = load(config.load_snapshot_path)
        else:
            task, _, training_state = load(
                config.load_snapshot_path, overwrite_config=config
            )
        if training_state:
            training_state.rank = rank
    else:
        task = create_task(
            config.task, metadata=metadata, rank=rank, world_size=world_size
        )

    for mc in metric_channels or []:
        task.metric_reporter.add_channel(mc)

    return task, training_state
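A minimal single-process call sketch (illustrative, not from the original source); it relies only on the signature above and assumes `config` is a fully populated PyTextConfig.

    task, training_state = prepare_task(
        config,
        dist_init_url=None,  # no distributed init needed for a single worker
        device_id=0,
        rank=0,
        world_size=1,
    )
    # training_state stays None unless a snapshot was loaded via load_snapshot_path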
Example #7
    def test_batch_predict_caffe2_model(self):
        with tempfile.NamedTemporaryFile() as snapshot_file, tempfile.NamedTemporaryFile() as caffe2_model_file:
            train_data = tests_module.test_file("train_data_tiny.tsv")
            eval_data = tests_module.test_file("test_data_tiny.tsv")
            config = PyTextConfig(
                task=DocumentClassificationTask.Config(
                    model=DocModel.Config(
                        inputs=DocModel.Config.ModelInput(
                            tokens=TokenTensorizer.Config(),
                            dense=FloatListTensorizer.Config(
                                column="dense", dim=1, error_check=True
                            ),
                            labels=LabelTensorizer.Config(),
                        )
                    ),
                    data=Data.Config(
                        source=TSVDataSource.Config(
                            train_filename=train_data,
                            eval_filename=eval_data,
                            test_filename=eval_data,
                            field_names=["label", "slots", "text", "dense"],
                        )
                    ),
                ),
                version=21,
                save_snapshot_path=snapshot_file.name,
                export_caffe2_path=caffe2_model_file.name,
            )
            task = create_task(config.task)
            task.export(task.model, caffe2_model_file.name)
            model = task.model
            save(config, model, meta=None, tensorizers=task.data.tensorizers)

            pt_results = task.predict(task.data.data_source.test)

            def assert_caffe2_results_correct(caffe2_results):
                for pt_res, res in zip(pt_results, caffe2_results):
                    np.testing.assert_array_almost_equal(
                        pt_res["score"].tolist()[0],
                        [score[0] for score in res.values()],
                    )

            results = batch_predict_caffe2_model(
                snapshot_file.name, caffe2_model_file.name
            )
            self.assertEqual(4, len(results))
            assert_caffe2_results_correct(results)

            results = batch_predict_caffe2_model(
                snapshot_file.name, caffe2_model_file.name, cache_size=2
            )
            self.assertEqual(4, len(results))
            assert_caffe2_results_correct(results)

            results = batch_predict_caffe2_model(
                snapshot_file.name, caffe2_model_file.name, cache_size=-1
            )
            self.assertEqual(4, len(results))
            assert_caffe2_results_correct(results)
Example #8
def gen_config_impl(task_name, options):
    # import the classes required by parameters
    requested_classes = [locate(opt) for opt in options] + [locate(task_name)]
    register_tasks(requested_classes)

    task_class_set = find_config_class(task_name)
    if not task_class_set:
        raise Exception(f"Unknown task class: {task_name} "
                        "(try fully qualified class name?)")
    elif len(task_class_set) > 1:
        raise Exception(f"Multiple tasks named {task_name}: {task_class_set}")

    task_class = next(iter(task_class_set))
    task_config = getattr(task_class, "example_config", task_class.Config)
    root = PyTextConfig(task=task_config(), version=LATEST_VERSION)
    eprint("INFO - Applying task option:", task_class.__name__)

    # Use components listed in options instead of defaults
    for opt in options:
        if "=" in opt:
            param_path, value = opt.split("=", 1)
            found = find_param(root, "." + param_path)
            if len(found) == 1:
                eprint("INFO - Applying parameter option to", found[0], ":",
                       opt)
                replace_param(root, found[0].split("."), value)
            elif not found:
                raise Exception(f"Unknown parameter option: {opt}")
            else:
                raise Exception(
                    f"Multiple possibilities for {opt}: {', '.join(found)}")
        else:
            replace_class_set = find_config_class(opt)
            if not replace_class_set:
                raise Exception(f"Not a component class: {opt}")
            elif len(replace_class_set) > 1:
                raise Exception(
                    f"Multiple component named {opt}: {replace_class_set}")
            replace_class = next(iter(replace_class_set))
            found = replace_components(root, opt,
                                       get_subclasses(replace_class))
            if found:
                eprint(
                    "INFO - Applying class option:",
                    "->".join(reversed(found)),
                    "=",
                    opt,
                )
                obj = root
                for k in reversed(found[1:]):
                    obj = getattr(obj, k)
                if hasattr(replace_class, "Config"):
                    setattr(obj, found[0], replace_class.Config())
                else:
                    setattr(obj, found[0], replace_class())
            else:
                raise Exception(f"Unknown class option: {opt}")
    return root
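A hedged sketch of the class-substitution path in this variant. Both names come from other examples on this page; per the error message above they may need to be fully qualified or pre-registered in your environment before find_config_class can resolve them.

    root = gen_config_impl("DocumentClassificationTask", ["DocModel"])
    # root is a PyTextConfig object here; this variant does not serialize to JSON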
Example #9
        def test_load_checkpoint(self):
            with tempfile.NamedTemporaryFile() as checkpoint_file:
                train_data = tests_module.test_file("train_data_tiny.tsv")
                eval_data = tests_module.test_file("test_data_tiny.tsv")
                config = PyTextConfig(
                    task=DocumentClassificationTask.Config(data=Data.Config(
                        source=TSVDataSource.Config(
                            train_filename=train_data,
                            eval_filename=eval_data,
                            field_names=["label", "slots", "text"],
                        ))),
                    version=LATEST_VERSION,
                    save_snapshot_path=checkpoint_file.name,
                )
                task = create_task(config.task)
                model = task.model
                # test checkpoint saving and loading
                optimizer = create_optimizer(Adam.Config(), model)
                scheduler = create_scheduler(Scheduler.Config(), optimizer)
                training_state = TrainingState(
                    model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    start_time=0,
                    epoch=0,
                    rank=0,
                    stage=Stage.TRAIN,
                    epochs_since_last_improvement=0,
                    best_model_state=None,
                    best_model_metric=None,
                    tensorizers=None,
                )

                checkpoint_path = checkpoint_file.name
                save(
                    config,
                    model,
                    None,
                    task.data.tensorizers,
                    training_state,
                    checkpoint_file,
                )
                task_restored, config_restored, training_state_restored = load(
                    checkpoint_path)
                optimizer_restored = training_state_restored.optimizer
                scheduler_restored = training_state_restored.scheduler
                self.assertOptimizerEqual(optimizer, optimizer_restored)
                self.assertIsNotNone(scheduler_restored)
                self.assertEqual(config, config_restored)
                self.assertModulesEqual(model, task_restored.model)
                model.eval()
                task_restored.model.eval()

                inputs = torch.LongTensor([[1, 2, 3]]), torch.LongTensor([3])
                self.assertEqual(
                    model(*inputs).tolist(),
                    task_restored.model(*inputs).tolist())
Example #10
    def test_load_checkpoint_in_dist_training(self):
        with tempfile.NamedTemporaryFile() as checkpoint_file:
            train_data = tests_module.test_file("train_data_tiny.tsv")
            eval_data = tests_module.test_file("test_data_tiny.tsv")
            config = PyTextConfig(
                task=DocumentClassificationTask.Config(data=Data.Config(
                    source=BlockShardedTSVDataSource.Config(
                        train_filename=train_data,
                        eval_filename=eval_data,
                        field_names=["label", "slots", "text"],
                    ))),
                version=LATEST_VERSION,
                save_snapshot_path=checkpoint_file.name,
            )
            task = create_task(config.task)
            model = task.model
            # test checkpoint saving and loading
            optimizer = create_optimizer(Adam.Config(), model)
            scheduler = create_scheduler(Scheduler.Config(), optimizer)
            training_state = TrainingState(
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                start_time=0,
                epoch=0,
                rank=0,
                stage=Stage.TRAIN,
                epochs_since_last_improvement=0,
                best_model_state=None,
                best_model_metric=None,
                tensorizers=task.data.tensorizers,
            )

            id = "epoch-1"
            saved_path = save(config, model, None, task.data.tensorizers,
                              training_state, id)
            new_rank = 2
            new_world_size = 4
            task_restored, config_restored, training_state_restored = load(
                saved_path, rank=new_rank, world_size=new_world_size)
            self.assertCheckpointEqual(
                model,
                config,
                training_state,
                task_restored.model,
                config_restored,
                training_state_restored,
            )
            self.assertEqual(task_restored.data.data_source.rank, new_rank)
            self.assertEqual(task_restored.data.data_source.world_size,
                             new_world_size)
Example #11
    def train(
        self,
        training_data: DataLoader,
        eval_data: DataLoader,
        model: Model,
        optimizer: Optimizer,
        label_names: List[str],
        scheduler: Scheduler = None,
        sparsifier: Sparsifier = None,
        metric_reporter: MetricReporter = None,
        train_config: PyTextConfig = None,
        rank: int = 0,
    ) -> Tuple[torch.nn.Module, Any]:
        # temp workaround to minimize changes to TaskTrainer
        if not train_config:
            train_config = PyTextConfig(
                task=NewTask.Config(model=RoBERTa.Config()), version=20)
        if scheduler:
            self.scheduler = scheduler
        if sparsifier:
            self.sparsifier = sparsifier

        state = TrainingState(
            model=model,
            optimizer=optimizer,
            scheduler=self.scheduler,
            sparsifier=self.sparsifier,
            rank=rank,
        )
        metric_reporter_config = ClassificationMetricReporter.Config(
            output_path="/tmp/test_out.txt",
            pep_format=False,
            model_select_metric=ComparableClassificationMetric.ACCURACY,  # in json: "accuracy"
            target_label=None,
            text_column_names=["text"],
            additional_column_names=[],
            recall_at_precision_thresholds=[0.2, 0.4, 0.6, 0.8, 0.9],
        )
        metric_reporter = ClassificationMetricReporter.from_config_and_label_names(
            config=metric_reporter_config, label_names=label_names)
        return self.train_from_state(state, training_data, eval_data,
                                     metric_reporter, train_config)
Example #12
def gen_config_impl(task_name, options):
    # import the classes required by parameters
    requested_classes = [locate(opt) for opt in options] + [locate(task_name)]
    register_tasks(requested_classes)

    task_class_set = find_config_class(task_name)
    if not task_class_set:
        raise Exception(f"Unknown task class: {task_name} "
                        "(try fully qualified class name?)")
    elif len(task_class_set) > 1:
        raise Exception(f"Multiple tasks named {task_name}: {task_class_set}")

    task_class = next(iter(task_class_set))
    task_config = getattr(task_class, "example_config", task_class.Config)
    root = PyTextConfig(task=task_config(), version=LATEST_VERSION)

    # Use components listed in options instead of defaults
    for opt in options:
        replace_class_set = find_config_class(opt)
        if not replace_class_set:
            raise Exception(f"Not a component class: {opt}")
        elif len(replace_class_set) > 1:
            raise Exception(
                f"Multiple component named {opt}: {replace_class_set}")
        replace_class = next(iter(replace_class_set))
        found = replace_components(root, opt, set(replace_class.__bases__))
        if found:
            eprint("INFO - Applying option:", "->".join(reversed(found)), "=",
                   opt)
            obj = root
            for k in reversed(found[1:]):
                obj = getattr(obj, k)
            if hasattr(replace_class, "Config"):
                setattr(obj, found[0], replace_class.Config())
            else:
                setattr(obj, found[0], replace_class())
        else:
            raise Exception(f"Unknown option: {opt}")
    return config_to_json(PyTextConfig, root)
Example #13
    def test_load_saved_model(self):
        with tempfile.NamedTemporaryFile() as snapshot_file:
            train_data = tests_module.test_file("train_data_tiny.tsv")
            eval_data = tests_module.test_file("test_data_tiny.tsv")
            config = PyTextConfig(
                task=DocumentClassificationTask.Config(
                    data=Data.Config(
                        source=TSVDataSource.Config(
                            train_filename=train_data,
                            eval_filename=eval_data,
                            field_names=["label", "slots", "text"],
                        )
                    )
                ),
                version=LATEST_VERSION,
                save_snapshot_path=snapshot_file.name,
            )
            task = create_task(config.task)
            model = task.model

            save(config, model, meta=None, tensorizers=task.data.tensorizers)
            task2, config2, training_state_none = load(snapshot_file.name)

            self.assertEqual(config, config2)
            self.assertModulesEqual(model, task2.model)
            self.assertIsNone(training_state_none)
            model.eval()
            task2.model.eval()

            inputs = torch.LongTensor([[1, 2, 3]]), torch.LongTensor([3])
            self.assertEqual(model(*inputs).tolist(), task2.model(*inputs).tolist())

        def assertOptimizerEqual(self, optim_1, optim_2, msg=None):
            self.assertTrue(
                isinstance(optim_1, Optimizer) and isinstance(optim_2, Optimizer), msg
            )
            state_dict_1 = optim_1.state_dict()
            state_dict_2 = optim_2.state_dict()
            self.assertEqual(len(state_dict_1), len(state_dict_2))
            for key_1, val_1 in state_dict_1.items():
                self.assertEqual(val_1, state_dict_2[key_1], msg)

        def test_load_checkpoint(self):
            with tempfile.NamedTemporaryFile() as checkpoint_file:
                train_data = tests_module.test_file("train_data_tiny.tsv")
                eval_data = tests_module.test_file("test_data_tiny.tsv")
                config = PyTextConfig(
                    task=DocumentClassificationTask.Config(
                        data=Data.Config(
                            source=TSVDataSource.Config(
                                train_filename=train_data,
                                eval_filename=eval_data,
                                field_names=["label", "slots", "text"],
                            )
                        )
                    ),
                    version=LATEST_VERSION,
                    save_snapshot_path=checkpoint_file.name,
                )
                task = create_task(config.task)
                model = task.model
                # test checkpoint saving and loading
                optimizer = create_optimizer(Adam.Config(), model)
                scheduler = create_scheduler(Scheduler.Config(), optimizer)
                training_state = TrainingState(
                    model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    start_time=0,
                    epoch=0,
                    rank=0,
                    stage=Stage.TRAIN,
                    epochs_since_last_improvement=0,
                    best_model_state=None,
                    best_model_metric=None,
                    tensorizers=task.data.tensorizers,
                )

                checkpoint_path = checkpoint_file.name

                save(
                    config,
                    model,
                    None,
                    task.data.tensorizers,
                    training_state,
                    "epoch-1",
                )
                task_restored, config_restored, training_state_restored = load(
                    checkpoint_path
                )
                optimizer_restored = training_state_restored.optimizer
                scheduler_restored = training_state_restored.scheduler
                self.assertOptimizerEqual(optimizer, optimizer_restored)
                self.assertIsNotNone(scheduler_restored)
                self.assertEqual(config, config_restored)
                self.assertModulesEqual(model, task_restored.model)
                model.eval()
                task_restored.model.eval()

                inputs = torch.LongTensor([[1, 2, 3]]), torch.LongTensor([3])
                self.assertEqual(
                    model(*inputs).tolist(), task_restored.model(*inputs).tolist()
                )