コード例 #1
0
 class Config(_NewTask.Config):
     """Configuration for a multitask setup made of several named sub-tasks."""

     # Sub-task configurations, keyed by a string name.
     tasks: Dict[str, NewTask.Config] = {}
     # Per-name float weights; presumably keyed by the same names as `tasks`
     # (NOTE(review): confirm against the consumer of this config).
     task_weights: Dict[str, float] = {}
     # Optional name of one sub-task, used for selecting the best epoch.
     target_task_name: Optional[str] = None  # for selecting best epoch
     # Config of the multitask data component; default-constructed if omitted.
     data: DisjointMultitaskData.Config = DisjointMultitaskData.Config()
     # Config of the multitask metric reporter; default-constructed if omitted.
     metric_reporter: DisjointMultitaskMetricReporter.Config = (
         DisjointMultitaskMetricReporter.Config())
コード例 #2
0
    def from_config(cls, task_config, metadata=None, model_state=None):
        """Build the multitask task described by ``task_config``.

        For every entry of ``task_config.tasks`` a data object, a model and a
        metric reporter are created; these are then wrapped in the disjoint
        multitask data, model and metric-reporter containers.

        Args:
            task_config: Config exposing ``tasks``, ``task_weights``, ``data``,
                ``metric_reporter``, ``target_task_name`` and ``trainer``.
            metadata: Accepted but not used by this method.
            model_state: Optional state; when truthy it is loaded into the
                combined model with ``load_state_dict``.

        Returns:
            ``cls(data, model, metric_reporter, trainer)``.
        """
        per_task_data = OrderedDict()
        per_task_models = OrderedDict()
        per_task_reporters = OrderedDict()
        for subtask_name, subtask_config in task_config.tasks.items():
            subtask_tensorizers, subtask_data = NewTask._init_tensorizers(
                subtask_config
            )
            per_task_data[subtask_name] = subtask_data
            per_task_models[subtask_name] = NewTask._init_model(
                subtask_config, subtask_tensorizers
            )
            per_task_reporters[subtask_name] = create_component(
                ComponentType.METRIC_REPORTER,
                subtask_config.metric_reporter,
                tensorizers=subtask_tensorizers,
            )

        # Every sub-task gets a loss weight; names without an explicit entry
        # in the config fall back to 1.0.
        loss_weights = {}
        for subtask_name in task_config.tasks:
            loss_weights[subtask_name] = task_config.task_weights.get(
                subtask_name, 1.0
            )

        multitask_data = DisjointMultitaskData(task_config.data, per_task_data)
        multitask_model = NewDisjointMultitaskModel(
            per_task_models, loss_weights=loss_weights
        )
        if model_state:
            multitask_model.load_state_dict(model_state)

        multitask_reporter = DisjointMultitaskMetricReporter(
            per_task_reporters,
            loss_weights=loss_weights,
            target_task_name=task_config.target_task_name,
            use_subtask_select_metric=(
                task_config.metric_reporter.use_subtask_select_metric
            ),
        )

        return cls(
            multitask_data,
            multitask_model,
            multitask_reporter,
            create_trainer(task_config.trainer, multitask_model),
        )
コード例 #3
0
 class Config(TaskBase.Config):
     """Configuration for a multitask setup made of several named sub-tasks."""

     # Sub-task configurations, keyed by a string name. No default is given
     # here.
     tasks: Dict[str, Task.Config]
     # Per-name float weights; presumably keyed by the same names as `tasks`
     # (NOTE(review): confirm against the consumer of this config).
     task_weights: Dict[str, float] = {}
     # Config of the multitask data handler; default-constructed if omitted.
     data_handler: DisjointMultitaskDataHandler.Config = DisjointMultitaskDataHandler.Config(
     )
     # Config of the multitask metric reporter; default-constructed if omitted.
     metric_reporter: DisjointMultitaskMetricReporter.Config = DisjointMultitaskMetricReporter.Config(
     )
コード例 #4
0
 class Config(TaskBase.Config):
     """Configuration for a multitask setup made of several named sub-tasks."""

     # Sub-task configurations, keyed by a string name. No default is given
     # here.
     tasks: Dict[str, Task_Deprecated.Config]
     # Per-name float weights; presumably keyed by the same names as `tasks`
     # (NOTE(review): confirm against the consumer of this config).
     task_weights: Dict[str, float] = {}
     # Optional name of one sub-task, used for selecting the best epoch.
     target_task_name: Optional[str] = None  # for selecting best epoch
     # Config of the multitask data handler; default-constructed if omitted.
     data_handler: DisjointMultitaskDataHandler.Config = DisjointMultitaskDataHandler.Config(
     )
     # Config of the multitask metric reporter; default-constructed if omitted.
     metric_reporter: DisjointMultitaskMetricReporter.Config = DisjointMultitaskMetricReporter.Config(
     )
コード例 #5
0
    def from_config(cls, task_config, metadata=None, model_state=None):
        """Build the multitask task described by ``task_config``.

        A data handler is created for every entry of ``task_config.tasks`` and
        wrapped in a ``DisjointMultitaskDataHandler``. Once metadata is
        available, one exporter (or ``None``), one metric reporter and one
        model are created per sub-task and combined into their multitask
        counterparts.

        Args:
            task_config: Config exposing ``tasks``, ``data_handler``,
                ``metric_reporter``, ``optimizer``, ``scheduler`` and
                ``trainer``.
            metadata: Optional existing metadata. When truthy it is loaded into
                the combined data handler; otherwise the handler initializes
                its own via ``init_metadata()``.
            model_state: Optional state; when truthy it is loaded into the
                combined model with ``load_state_dict``.

        Returns:
            An instance of ``cls`` constructed from the exporters, trainer,
            data handler, model, metric reporter, optimizers and LR scheduler.
        """
        print("Task parameters:\n")
        pprint(config_to_json(type(task_config), task_config))

        data_handlers = OrderedDict()
        for name, task in task_config.tasks.items():
            featurizer = create_featurizer(task.featurizer, task.features)
            data_handlers[name] = create_data_handler(
                task.data_handler,
                task.features,
                task.labels,
                featurizer=featurizer,
            )
        data_handler = DisjointMultitaskDataHandler(
            task_config.data_handler, data_handlers
        )

        print("\nLoading data...")
        if metadata:
            data_handler.load_metadata(metadata)
        else:
            data_handler.init_metadata()
        # From here on, work with the handler's own metadata, indexed by
        # sub-task name below.
        metadata = data_handler.metadata

        # A sub-task without an exporter config maps to None.
        exporters = {
            name: (
                create_exporter(
                    task.exporter,
                    task.features,
                    task.labels,
                    data_handler.data_handlers[name].metadata,
                    task.model,
                )
                if task.exporter
                else None
            )
            for name, task in task_config.tasks.items()
        }

        metric_reporter = DisjointMultitaskMetricReporter(
            OrderedDict(
                (name, create_metric_reporter(task.metric_reporter, metadata[name]))
                for name, task in task_config.tasks.items()
            ),
            target_task_name=task_config.metric_reporter.target_task_name,
        )

        model = DisjointMultitaskModel(
            OrderedDict(
                (name, create_model(task.model, task.features, metadata[name]))
                for name, task in task_config.tasks.items()
            )
        )
        if model_state:
            model.load_state_dict(model_state)
        if cuda_utils.CUDA_ENABLED:
            model = model.cuda()

        # Optimizers are created after the optional .cuda() call above, so
        # they are built from the model object that is handed to ``cls``.
        optimizers = create_optimizer(model, task_config.optimizer)
        return cls(
            exporters=exporters,
            trainer=create_trainer(task_config.trainer),
            data_handler=data_handler,
            model=model,
            metric_reporter=metric_reporter,
            optimizers=optimizers,
            lr_scheduler=Scheduler(
                optimizers, task_config.scheduler, metric_reporter.lower_is_better
            ),
        )
コード例 #6
0
    def from_config(
        cls,
        task_config: Config,
        unused_metadata=None,
        model_state=None,
        tensorizers=None,
        rank=0,
        world_size=1,
    ):
        """Build the multitask task described by ``task_config``.

        For every entry of ``task_config.tasks`` a data object, a model and a
        metric reporter are created; these are then wrapped in the disjoint
        multitask data, model and metric-reporter containers.

        Args:
            task_config: Config exposing ``tasks``, ``task_weights``, ``data``,
                ``metric_reporter``, ``target_task_name`` and ``trainer``.
            unused_metadata: Accepted but not used by this method.
            model_state: Optional state; when truthy it is loaded into the
                combined model with ``load_state_dict``.
            tensorizers: Optional mapping from sub-task name to that
                sub-task's tensorizers; a falsy value is treated as empty.
            rank: Forwarded to ``_init_tensorizers`` and to
                ``DisjointMultitaskData.from_config``.
            world_size: Forwarded alongside ``rank``.

        Returns:
            ``cls(data, model, metric_reporter, trainer)``.
        """
        per_task_data = OrderedDict()
        per_task_models = OrderedDict()
        per_task_reporters = OrderedDict()
        saved_tensorizers = tensorizers or {}
        # The saved tensorizers cannot really be re-used, and what gets saved
        # for disjoint multitask is currently an empty dict anyway. Ideally
        # each sub-task would be serialized individually.
        for subtask_name, subtask_config in task_config.tasks.items():
            subtask_tensorizers, subtask_data = cls._init_tensorizers(
                subtask_config,
                saved_tensorizers.get(subtask_name),
                rank,
                world_size,
            )
            per_task_data[subtask_name] = subtask_data
            per_task_models[subtask_name] = NewTask._init_model(
                subtask_config.model, subtask_tensorizers
            )
            per_task_reporters[subtask_name] = create_component(
                ComponentType.METRIC_REPORTER,
                subtask_config.metric_reporter,
                tensorizers=subtask_tensorizers,
            )

        # Every sub-task gets a loss weight; names without an explicit entry
        # in the config fall back to 1.0.
        loss_weights = {}
        for subtask_name in task_config.tasks:
            loss_weights[subtask_name] = task_config.task_weights.get(
                subtask_name, 1.0
            )

        multitask_data = DisjointMultitaskData.from_config(
            task_config.data,
            data_dict=per_task_data,
            rank=rank,
            world_size=world_size,
        )
        # Expose the per-sub-task tensorizers on the data object for
        # serialization.
        multitask_data.tensorizers = dict(
            (subtask_name, multitask_data.data_dict[subtask_name].tensorizers)
            for subtask_name in multitask_data.data_dict
        )

        multitask_model = NewDisjointMultitaskModel(
            per_task_models, loss_weights=loss_weights
        )
        if model_state:
            multitask_model.load_state_dict(model_state)

        multitask_reporter = DisjointMultitaskMetricReporter(
            per_task_reporters,
            loss_weights=loss_weights,
            target_task_name=task_config.target_task_name,
            use_subtask_select_metric=(
                task_config.metric_reporter.use_subtask_select_metric
            ),
        )

        return cls(
            multitask_data,
            multitask_model,
            multitask_reporter,
            create_trainer(task_config.trainer, multitask_model),
        )
コード例 #7
0
    def from_config(
        cls,
        task_config: Config,
        metadata=None,
        model_state=None,
        tensorizers=None,
        rank=0,
        world_size=1,
    ):
        """Build the multitask task described by ``task_config``.

        A data handler is created for every entry of ``task_config.tasks`` and
        wrapped in a ``DisjointMultitaskDataHandler``. Once metadata is
        available, one exporter (or ``None``), one metric reporter and one
        model are created per sub-task and combined into their multitask
        counterparts, sharing a single table of per-task loss weights.

        Args:
            task_config: Config exposing ``tasks``, ``task_weights``,
                ``target_task_name``, ``data_handler``, ``metric_reporter``
                and ``trainer``.
            metadata: Optional existing metadata. When truthy it is loaded into
                the combined data handler; otherwise the handler initializes
                its own via ``init_metadata()``.
            model_state: Optional state; when truthy it is loaded into the
                combined model with ``load_state_dict``.
            tensorizers: Accepted but not used by this method.
            rank: Accepted but not used by this method.
            world_size: Accepted but not used by this method.

        Returns:
            An instance of ``cls`` constructed from the target task name,
            exporters, trainer, data handler, model and metric reporter.
        """
        print("Task parameters:\n")
        pprint(config_to_json(type(task_config), task_config))

        data_handlers = OrderedDict()
        for name, task in task_config.tasks.items():
            featurizer = create_featurizer(task.featurizer, task.features)
            data_handlers[name] = create_data_handler(
                task.data_handler, task.features, task.labels, featurizer=featurizer
            )
        data_handler = DisjointMultitaskDataHandler(
            task_config.data_handler,
            data_handlers,
            target_task_name=task_config.target_task_name,
        )

        print("\nLoading data...")
        if metadata:
            data_handler.load_metadata(metadata)
        else:
            data_handler.init_metadata()
        # From here on, work with the handler's own metadata, indexed by
        # sub-task name below.
        metadata = data_handler.metadata

        # A sub-task without an exporter config maps to None.
        exporters = {
            name: (
                create_exporter(
                    task.exporter,
                    task.features,
                    task.labels,
                    data_handler.data_handlers[name].metadata,
                    task.model,
                )
                if task.exporter
                else None
            )
            for name, task in task_config.tasks.items()
        }

        # Every sub-task gets a loss weight; names without an explicit entry
        # in the config fall back to 1. The same table is given to both the
        # metric reporter and the model.
        task_weights = {
            task_name: task_config.task_weights.get(task_name, 1)
            for task_name in task_config.tasks.keys()
        }

        metric_reporter = DisjointMultitaskMetricReporter(
            OrderedDict(
                (name, create_metric_reporter(task.metric_reporter, metadata[name]))
                for name, task in task_config.tasks.items()
            ),
            loss_weights=task_weights,
            target_task_name=task_config.target_task_name,
            use_subtask_select_metric=(
                task_config.metric_reporter.use_subtask_select_metric
            ),
        )

        model = DisjointMultitaskModel(
            OrderedDict(
                (name, create_model(task.model, task.features, metadata[name]))
                for name, task in task_config.tasks.items()
            ),
            loss_weights=task_weights,
        )
        if model_state:
            model.load_state_dict(model_state)
        if cuda.CUDA_ENABLED:
            model = model.cuda()

        return cls(
            target_task_name=task_config.target_task_name,
            exporters=exporters,
            # The trainer is created after the optional .cuda() call above,
            # so it receives the same model object that is passed to ``cls``.
            trainer=create_trainer(task_config.trainer, model),
            data_handler=data_handler,
            model=model,
            metric_reporter=metric_reporter,
        )