Exemplo n.º 1
0
    def train(
        self,
        training_data: DataLoader,
        eval_data: DataLoader,
        model: Model,
        optimizer: Optimizer,
        label_names: List[str],
        scheduler: Scheduler = None,
        sparsifier: Sparsifier = None,
        metric_reporter: MetricReporter = None,
        train_config: PyTextConfig = None,
        rank: int = 0,
    ) -> Tuple[torch.nn.Module, Any]:
        """Train ``model`` by delegating to ``self.train_from_state``.

        This is an adapter for callers that do not go through a full task.
        It fills in a placeholder ``PyTextConfig`` when none is given and
        records an optional scheduler/sparsifier on ``self``. It then wraps
        the training objects in a ``TrainingState`` and hands everything to
        ``train_from_state``.

        Args:
            training_data: loader passed through as the training data.
            eval_data: loader passed through as the evaluation data.
            model: the model to train (stored in the ``TrainingState``).
            optimizer: optimizer stored in the ``TrainingState``.
            label_names: class label names; only used to build the default
                classification metric reporter.
            scheduler: if truthy, replaces ``self.scheduler`` before the
                state is built.
            sparsifier: if truthy, replaces ``self.sparsifier`` before the
                state is built.
            metric_reporter: reporter forwarded to ``train_from_state``. When
                ``None``, a default ``ClassificationMetricReporter`` is built
                from ``label_names``. It selects on accuracy and its
                ``output_path`` is ``/tmp/test_out.txt``.
            train_config: config forwarded to ``train_from_state``. When
                falsy, a placeholder ``NewTask`` config with a RoBERTa model
                and ``version=20`` is used.
            rank: value stored as ``rank`` in the ``TrainingState``.

        Returns:
            Whatever ``self.train_from_state`` returns.
        """
        # Temporary workaround to minimize changes to TaskTrainer: fabricate
        # a minimal config when the caller did not provide one.
        if not train_config:
            train_config = PyTextConfig(
                task=NewTask.Config(model=RoBERTa.Config), version=20
            )
        # Only override the trainer's own scheduler/sparsifier when given.
        if scheduler:
            self.scheduler = scheduler
        if sparsifier:
            self.sparsifier = sparsifier

        state = TrainingState(
            model=model,
            optimizer=optimizer,
            scheduler=self.scheduler,
            sparsifier=self.sparsifier,
            rank=rank,
        )

        # Honor a caller-supplied reporter; only fall back to the default
        # classification reporter when none was passed in.
        if metric_reporter is None:
            metric_reporter_config = ClassificationMetricReporter.Config(
                output_path="/tmp/test_out.txt",
                pep_format=False,
                # in json: "accuracy"
                model_select_metric=ComparableClassificationMetric.ACCURACY,
                target_label=None,
                text_column_names=["text"],
                additional_column_names=[],
                recall_at_precision_thresholds=[0.2, 0.4, 0.6, 0.8, 0.9],
            )
            metric_reporter = (
                ClassificationMetricReporter.from_config_and_label_names(
                    config=metric_reporter_config, label_names=label_names
                )
            )
        return self.train_from_state(
            state, training_data, eval_data, metric_reporter, train_config
        )
Exemplo n.º 2
0
 def create_metric_reporter(cls, config: Config,
                            tensorizers: Dict[str, Tensorizer]):
     """Build a ClassificationMetricReporter for this component.

     The reporter config comes from ``config.metric_reporter``. The label
     names are taken from the ``labels`` attribute of the tensorizer stored
     under the ``"labels"`` key, materialized as a list.
     """
     reporter_config = config.metric_reporter
     label_tensorizer = tensorizers["labels"]
     label_names = list(label_tensorizer.labels)
     return ClassificationMetricReporter.from_config_and_label_names(
         reporter_config, label_names
     )
Exemplo n.º 3
0
def classification_metric_reporter_config_expand(label_names: List[str],
                                                 **kwargs):
    """Create a ClassificationMetricReporter from flat keyword arguments.

    Every keyword argument is forwarded unchanged to
    ``ClassificationMetricReporter.Config``. The resulting config is then
    combined with ``label_names`` through
    ``ClassificationMetricReporter.from_config_and_label_names``.
    """
    reporter_cls = ClassificationMetricReporter
    reporter_config = reporter_cls.Config(**kwargs)
    return reporter_cls.from_config_and_label_names(reporter_config,
                                                    label_names)