Exemplo n.º 1
0
 class Config(Task.Config):
     """Task config built around a ``SeqNNModel`` with ``DocLabelConfig`` labels.

     Every field declared in this class carries a default instance.
     """

     # Model definition.
     model: SeqNNModel.Config = SeqNNModel.Config()
     # Training loop settings.
     trainer: Trainer.Config = Trainer.Config()
     # Label target configuration.
     labels: DocLabelConfig = DocLabelConfig()
     # Data loading / preprocessing.
     data_handler: SeqModelDataHandler.Config = SeqModelDataHandler.Config()
     # Metrics reported during training and evaluation.
     metric_reporter: ClassificationMetricReporter.Config = (
         ClassificationMetricReporter.Config()
     )
Exemplo n.º 2
0
 class Config(DocumentClassificationTask.Config):
     """Document-classification config reading two text columns.

     Both the BERT tokenizer input and the metric reporter are pointed at
     the ``"text1"`` and ``"text2"`` columns.
     """

     # BERT model whose token tensorizer consumes both text columns,
     # truncated to 128 positions.
     model: NewBertModel.Config = NewBertModel.Config(
         inputs=NewBertModel.Config.BertModelInput(
             tokens=BERTTensorizer.Config(
                 columns=["text1", "text2"],
                 max_seq_len=128,
             )
         )
     )
     # Reporter told which columns hold the input text.
     metric_reporter: ClassificationMetricReporter.Config = (
         ClassificationMetricReporter.Config(
             text_column_names=["text1", "text2"]
         )
     )
Exemplo n.º 3
0
 class Config(Task_Deprecated.Config):
     """``Task_Deprecated`` config around ``SeqNNModel_Deprecated``.

     Adds an optional dense-feature ``exporter`` that defaults to ``None``.
     """

     # Model definition (deprecated API variant).
     model: SeqNNModel_Deprecated.Config = SeqNNModel_Deprecated.Config()
     # Training loop settings.
     trainer: Trainer.Config = Trainer.Config()
     # Label target configuration.
     labels: DocLabelConfig = DocLabelConfig()
     # Data loading / preprocessing.
     data_handler: SeqModelDataHandler.Config = SeqModelDataHandler.Config()
     # Metrics reported during training and evaluation.
     metric_reporter: ClassificationMetricReporter.Config = (
         ClassificationMetricReporter.Config()
     )
     # No exporter unless one is configured explicitly.
     exporter: Optional[DenseFeatureExporter.Config] = None
Exemplo n.º 4
0
 class Config(Task.Config):
     """Ensemble task config accepting either bagging ensemble flavor.

     ``model`` and ``labels`` declare no default in this class
     (presumably they must be supplied by the user — confirm against
     ``Task.Config``).
     """

     # Either a document-level or an intent/slot bagging ensemble.
     model: Union[BaggingDocEnsemble.Config, BaggingIntentSlotEnsemble.Config]
     # Trainer specialised for ensembles.
     trainer: EnsembleTrainer.Config = EnsembleTrainer.Config()
     # One or more label target configs.
     labels: List[TargetConfigBase]
     # Classification reporter by default; intent/slot reporter also allowed.
     metric_reporter: Union[
         ClassificationMetricReporter.Config,
         IntentSlotMetricReporter.Config,
     ] = ClassificationMetricReporter.Config()
     # Data handler for joint models.
     data_handler: JointModelDataHandler.Config = JointModelDataHandler.Config()
Exemplo n.º 5
0
 class Config(Task.Config):
     """Task config wiring ``DocModel`` to ``DocClassification`` inputs/targets.

     Every field declared in this class carries a default instance.
     """

     # Model definition.
     model: DocModel.Config = DocModel.Config()
     # Training loop settings.
     trainer: Trainer.Config = Trainer.Config()
     # Input feature configuration.
     features: DocClassification.ModelInputConfig = (
         DocClassification.ModelInputConfig()
     )
     # Target (label) configuration.
     labels: DocClassification.TargetConfig = (
         DocClassification.TargetConfig()
     )
     # Data loading / preprocessing.
     data_handler: DocClassificationDataHandler.Config = (
         DocClassificationDataHandler.Config()
     )
     # Metrics reported during training and evaluation.
     metric_reporter: ClassificationMetricReporter.Config = (
         ClassificationMetricReporter.Config()
     )
Exemplo n.º 6
0
 class Config(Task.Config):
     """Task config wiring ``PairClassificationModel`` to ``PairClassification``.

     Every field declared in this class carries a default instance.
     """

     # Input feature configuration.
     features: PairClassification.ModelInputConfig = (
         PairClassification.ModelInputConfig()
     )
     # Model definition.
     model: PairClassificationModel.Config = (
         PairClassificationModel.Config()
     )
     # Data loading / preprocessing.
     data_handler: PairClassificationDataHandler.Config = (
         PairClassificationDataHandler.Config()
     )
     # Training loop settings.
     trainer: Trainer.Config = Trainer.Config()
     # Target (label) configuration.
     labels: PairClassification.TargetConfig = (
         PairClassification.TargetConfig()
     )
     # Metrics reported during training and evaluation.
     metric_reporter: ClassificationMetricReporter.Config = (
         ClassificationMetricReporter.Config()
     )
Exemplo n.º 7
0
 class Config(Task_Deprecated.Config):
     """``Task_Deprecated`` config around ``DocModel_Deprecated``.

     Uses ``DocClassification`` inputs/targets and adds an optional
     dense-feature ``exporter`` that defaults to ``None``.
     """

     # Model definition (deprecated API variant).
     model: DocModel_Deprecated.Config = DocModel_Deprecated.Config()
     # Training loop settings.
     trainer: Trainer.Config = Trainer.Config()
     # Input feature configuration.
     features: DocClassification.ModelInputConfig = (
         DocClassification.ModelInputConfig()
     )
     # Target (label) configuration.
     labels: DocClassification.TargetConfig = (
         DocClassification.TargetConfig()
     )
     # Data loading / preprocessing.
     data_handler: DocClassificationDataHandler.Config = (
         DocClassificationDataHandler.Config()
     )
     # Metrics reported during training and evaluation.
     metric_reporter: ClassificationMetricReporter.Config = (
         ClassificationMetricReporter.Config()
     )
     # No exporter unless one is configured explicitly.
     exporter: Optional[DenseFeatureExporter.Config] = None
Exemplo n.º 8
0
    def train(
        self,
        training_data: DataLoader,
        eval_data: DataLoader,
        model: Model,
        optimizer: Optimizer,
        label_names: List[str],
        scheduler: Scheduler = None,
        sparsifier: Sparsifier = None,
        metric_reporter: MetricReporter = None,
        train_config: PyTextConfig = None,
        rank: int = 0,
    ) -> Tuple[torch.nn.Module, Any]:
        """Train ``model``, evaluating on ``eval_data``.

        Packs the inputs into a ``TrainingState`` and delegates to
        ``self.train_from_state``.

        Args:
            training_data: loader for the training split.
            eval_data: loader for the evaluation split.
            model: model stored in the ``TrainingState``.
            optimizer: optimizer stored in the ``TrainingState``.
            label_names: label vocabulary; used only to build the default
                metric reporter when ``metric_reporter`` is None.
            scheduler: when given, replaces ``self.scheduler``.
            sparsifier: when given, replaces ``self.sparsifier``.
            metric_reporter: reporter to use. When None, a
                ``ClassificationMetricReporter`` writing to
                ``/tmp/test_out.txt`` is built from ``label_names``.
            train_config: top-level config. When None, a minimal
                ``PyTextConfig`` wrapping a default ``RoBERTa.Config()`` is
                synthesized.
            rank: rank recorded in the ``TrainingState``.

        Returns:
            Whatever ``self.train_from_state`` returns (annotated as
            ``(torch.nn.Module, Any)``).
        """
        # Temporary workaround to minimize changes to TaskTrainer: synthesize
        # a minimal top-level config when the caller does not provide one.
        if train_config is None:
            train_config = PyTextConfig(
                # The task's model field expects a config *instance*, not the
                # Config class object itself.
                task=NewTask.Config(model=RoBERTa.Config()),
                version=20,
            )

        # Only override the trainer's own scheduler/sparsifier when the caller
        # explicitly supplies one.
        if scheduler is not None:
            self.scheduler = scheduler
        if sparsifier is not None:
            self.sparsifier = sparsifier

        state = TrainingState(
            model=model,
            optimizer=optimizer,
            scheduler=self.scheduler,
            sparsifier=self.sparsifier,
            rank=rank,
        )

        # Honour a caller-supplied reporter; fall back to a default
        # classification reporter only when none was given.
        if metric_reporter is None:
            metric_reporter_config = ClassificationMetricReporter.Config(
                output_path="/tmp/test_out.txt",
                pep_format=False,
                # Spelled "accuracy" in JSON configs.
                model_select_metric=ComparableClassificationMetric.ACCURACY,
                target_label=None,
                text_column_names=["text"],
                additional_column_names=[],
                recall_at_precision_thresholds=[0.2, 0.4, 0.6, 0.8, 0.9],
            )
            metric_reporter = (
                ClassificationMetricReporter.from_config_and_label_names(
                    config=metric_reporter_config, label_names=label_names
                )
            )

        return self.train_from_state(
            state, training_data, eval_data, metric_reporter, train_config
        )
Exemplo n.º 9
0
 class Config(NewTask.Config):
     """``NewTask`` config defaulting to ``DocModel`` and a classification reporter.

     ``metric_reporter`` may alternatively be a ``PureLossMetricReporter``
     config.
     """

     # Any BaseModel config; DocModel by default.
     model: BaseModel.Config = DocModel.Config()
     # Classification reporter by default; pure-loss reporter also allowed.
     metric_reporter: Union[
         ClassificationMetricReporter.Config,
         PureLossMetricReporter.Config,
     ] = ClassificationMetricReporter.Config()
Exemplo n.º 10
0
 class Config(NewTask.Config):
     """``NewTask`` config for an ensemble model trained by ``EnsembleTrainer``.

     ``model`` declares no default in this class.
     """

     # Ensemble model definition (no default here).
     model: EnsembleModel.Config
     # Trainer specialised for ensembles.
     trainer: EnsembleTrainer.Config = EnsembleTrainer.Config()
     # Classification reporter by default; intent/slot reporter also allowed.
     metric_reporter: Union[
         ClassificationMetricReporter.Config,
         IntentSlotMetricReporter.Config,
     ] = ClassificationMetricReporter.Config()
Exemplo n.º 11
0
 class Config(NewTask.Config):
     """``NewTask`` config for ``SeqNNModel`` reading the ``"text_seq"`` column."""

     # Model definition.
     model: SeqNNModel.Config = SeqNNModel.Config()
     # Reporter told which column holds the input text.
     metric_reporter: ClassificationMetricReporter.Config = (
         ClassificationMetricReporter.Config(text_column_names=["text_seq"])
     )
Exemplo n.º 12
0
 class Config(NewTask.Config):
     """``NewTask`` config for a pairwise model over two text columns.

     The metric reporter reads ``"text1"`` and ``"text2"``;
     ``trace_both_encoders`` defaults to ``True``.
     """

     # Any BasePairwiseModel config; PairwiseModel by default.
     model: BasePairwiseModel.Config = PairwiseModel.Config()
     # Reporter told which columns hold the input text.
     metric_reporter: ClassificationMetricReporter.Config = (
         ClassificationMetricReporter.Config(
             text_column_names=["text1", "text2"]
         )
     )
     # Flag consumed elsewhere (presumably during tracing/export — confirm).
     trace_both_encoders: bool = True
Exemplo n.º 13
0
 class Config(NewTask.Config):
     """``NewTask`` config defaulting to ``DocModel`` with a classification reporter."""

     # Any Model config; DocModel by default.
     model: Model.Config = DocModel.Config()
     # Metrics reported during training and evaluation.
     metric_reporter: ClassificationMetricReporter.Config = (
         ClassificationMetricReporter.Config()
     )
Exemplo n.º 14
0
 class Config(NewTask.Config):
     """``NewTask`` config for a pairwise model with a default classification reporter."""

     # Any BasePairwiseModel config; PairwiseModel by default.
     model: BasePairwiseModel.Config = PairwiseModel.Config()
     # Metrics reported during training and evaluation.
     metric_reporter: ClassificationMetricReporter.Config = (
         ClassificationMetricReporter.Config()
     )
Exemplo n.º 15
0
 class Config(NewTask.Config):
     """``NewTask`` config for a bagging document ensemble.

     ``model`` declares no default in this class.
     """

     # Bagging ensemble definition (no default here).
     model: BaggingDocEnsemble.Config
     # Trainer specialised for ensembles.
     trainer: EnsembleTrainer.Config = EnsembleTrainer.Config()
     # Metrics reported during training and evaluation.
     metric_reporter: ClassificationMetricReporter.Config = (
         ClassificationMetricReporter.Config()
     )
Exemplo n.º 16
0
def classification_metric_reporter_config_expand(label_names: List[str],
                                                 **kwargs):
    """Build a ``ClassificationMetricReporter`` from keyword config fields.

    Args:
        label_names: passed through to
            ``ClassificationMetricReporter.from_config_and_label_names``.
        **kwargs: forwarded unchanged to ``ClassificationMetricReporter.Config``.

    Returns:
        The result of ``from_config_and_label_names`` for the assembled config
        and ``label_names``.
    """
    reporter_cls = ClassificationMetricReporter
    config = reporter_cls.Config(**kwargs)
    return reporter_cls.from_config_and_label_names(config, label_names)