예제 #1
0
    def from_config(cls, task_config, metadata=None, model_state=None):
        """
        Create a disjoint multitask task from config, optionally restoring
        previously saved metadata and model parameters.

        For every entry of ``task_config.tasks`` this builds a featurizer, a
        data handler, an (optional) exporter, a metric reporter and a model,
        and wraps each group in the matching ``DisjointMultitask*`` container
        before handing everything to the class constructor.

        Args:
            task_config: the multitask config; ``task_config.tasks`` maps task
                names to per-task configs.
            metadata: saved metadata to load into the data handler; when falsy
                the data handler initializes its metadata instead.
            model_state: saved model parameters; loaded into the model when
                truthy.

        Returns:
            An instance of ``cls`` built from the created exporters, trainer,
            data handler, model, metric reporter, optimizers and scheduler.
        """
        print("Task parameters:\n")
        pprint(config_to_json(type(task_config), task_config))

        # One featurizer + data handler per sub-task, keyed by task name.
        data_handlers = OrderedDict()
        for name, task in task_config.tasks.items():
            featurizer = create_featurizer(task.featurizer, task.features)
            data_handlers[name] = create_data_handler(task.data_handler,
                                                      task.features,
                                                      task.labels,
                                                      featurizer=featurizer)
        data_handler = DisjointMultitaskDataHandler(task_config.data_handler,
                                                    data_handlers)
        print("\nLoading data...")
        if metadata:
            data_handler.load_metadata(metadata)
        else:
            data_handler.init_metadata()
        # From here on, `metadata` is whatever the data handler exposes; it is
        # indexed by task name below.
        metadata = data_handler.metadata

        # Exporters are optional per task: a task without an exporter config
        # gets an explicit None entry so every task name is still present.
        exporters = {
            name: (create_exporter(
                task.exporter,
                task.features,
                task.labels,
                data_handler.data_handlers[name].metadata,
                task.model,
            ) if task.exporter else None)
            for name, task in task_config.tasks.items()
        }
        metric_reporter = DisjointMultitaskMetricReporter(
            OrderedDict(
                (name,
                 create_metric_reporter(task.metric_reporter, metadata[name]))
                for name, task in task_config.tasks.items()),
            target_task_name=task_config.metric_reporter.target_task_name,
        )

        model = DisjointMultitaskModel(
            OrderedDict(
                (name, create_model(task.model, task.features, metadata[name]))
                for name, task in task_config.tasks.items()))
        # Restore parameters before the optional move to GPU.
        if model_state:
            model.load_state_dict(model_state)
        if cuda_utils.CUDA_ENABLED:
            model = model.cuda()

        optimizers = create_optimizer(model, task_config.optimizer)
        return cls(
            exporters=exporters,
            trainer=create_trainer(task_config.trainer),
            data_handler=data_handler,
            model=model,
            metric_reporter=metric_reporter,
            optimizers=optimizers,
            lr_scheduler=Scheduler(optimizers, task_config.scheduler,
                                   metric_reporter.lower_is_better),
        )
예제 #2
0
    def from_config(cls, task_config, metadata=None, model_state=None):
        """
        Build a fully wired task instance from ``task_config``.

        The featurizer, data handler, model, metric reporter, optimizer,
        optional exporter, trainer and learning-rate scheduler are created in
        that order and then passed to the class constructor.

        Args:
            task_config (Task.Config): the config of the task to create
            metadata: saved global context of the task (e.g. vocabulary); when
                falsy, the data handler initializes its own metadata instead
            model_state: saved model parameters; loaded into the freshly
                created model when truthy
        """
        print("Task parameters:\n")
        pprint(config_to_json(type(task_config), task_config))

        # Data pipeline: featurizer first, then the handler that consumes it.
        feature_extractor = create_featurizer(
            task_config.featurizer, task_config.features
        )
        handler = create_data_handler(
            task_config.data_handler,
            task_config.features,
            task_config.labels,
            featurizer=feature_extractor,
        )

        print("\nLoading data...")
        if not metadata:
            handler.init_metadata()
        else:
            handler.load_metadata(metadata)
        task_metadata = handler.metadata

        # Model: create, optionally restore parameters, optionally move to GPU.
        net = create_model(task_config.model, task_config.features, task_metadata)
        if model_state:
            net.load_state_dict(model_state)
        if cuda_utils.CUDA_ENABLED:
            net = net.cuda()

        reporter = create_metric_reporter(
            task_config.metric_reporter, task_metadata
        )
        optimizer_set = create_optimizer(net, task_config.optimizer)

        # The exporter is optional; it stays None when no config is given.
        exporter = None
        if task_config.exporter:
            exporter = create_exporter(
                task_config.exporter,
                task_config.features,
                task_config.labels,
                handler.metadata,
                task_config.model,
            )

        trainer = create_trainer(task_config.trainer)
        scheduler = Scheduler(
            optimizer_set, task_config.scheduler, reporter.lower_is_better
        )
        return cls(
            trainer=trainer,
            data_handler=handler,
            model=net,
            metric_reporter=reporter,
            optimizers=optimizer_set,
            lr_scheduler=scheduler,
            exporter=exporter,
        )