class Config(_NewTask.Config):
    """Configuration for a task that trains several sub-tasks together.

    Extends ``_NewTask.Config`` with the per-sub-task configs, their
    optional loss weights, and the multitask data / metric-reporter configs.
    """

    # Sub-task configs, keyed by sub-task name.
    # NOTE(review): `{}` is a shared mutable default in plain Python;
    # presumably the config base class copies defaults per instance -- confirm.
    tasks: Dict[str, NewTask.Config] = {}

    # Optional weight per sub-task name.
    task_weights: Dict[str, float] = {}

    # for selecting best epoch
    target_task_name: Optional[str] = None

    # Config of the data component that combines the sub-tasks' data.
    data: DisjointMultitaskData.Config = DisjointMultitaskData.Config()

    # Config of the metric reporter that combines the sub-tasks' reporters.
    metric_reporter: DisjointMultitaskMetricReporter.Config = DisjointMultitaskMetricReporter.Config()
def from_config(cls, task_config, metadata=None, model_state=None):
    """Assemble a disjoint multitask task from ``task_config``.

    For every entry in ``task_config.tasks`` this initializes the sub-task's
    tensorizers and data, its model, and its metric reporter, then wraps the
    per-task pieces in their multitask counterparts.

    Args:
        task_config: Config exposing ``tasks``, ``task_weights``, ``data``,
            ``target_task_name``, ``metric_reporter`` and ``trainer``.
        metadata: Not used by this method.
        model_state: If truthy, loaded into the combined model via
            ``load_state_dict``.

    Returns:
        ``cls(data, model, metric_reporter, trainer)``.
    """
    sub_datas = OrderedDict()
    sub_models = OrderedDict()
    sub_reporters = OrderedDict()

    for task_name, sub_config in task_config.tasks.items():
        sub_tensorizers, sub_data = NewTask._init_tensorizers(sub_config)
        sub_datas[task_name] = sub_data
        sub_models[task_name] = NewTask._init_model(sub_config, sub_tensorizers)
        sub_reporters[task_name] = create_component(
            ComponentType.METRIC_REPORTER,
            sub_config.metric_reporter,
            tensorizers=sub_tensorizers,
        )

    # Sub-tasks without an explicit weight default to 1.0.
    loss_weights = {}
    for task_name in task_config.tasks:
        loss_weights[task_name] = task_config.task_weights.get(task_name, 1.0)

    data = DisjointMultitaskData(task_config.data, sub_datas)

    model = NewDisjointMultitaskModel(sub_models, loss_weights=loss_weights)
    if model_state:
        model.load_state_dict(model_state)

    reporter_config = task_config.metric_reporter
    metric_reporter = DisjointMultitaskMetricReporter(
        sub_reporters,
        loss_weights=loss_weights,
        target_task_name=task_config.target_task_name,
        use_subtask_select_metric=reporter_config.use_subtask_select_metric,
    )

    trainer = create_trainer(task_config.trainer, model)
    return cls(data, model, metric_reporter, trainer)
def from_config(
    cls,
    task_config: Config,
    unused_metadata=None,
    model_state=None,
    tensorizers=None,
    rank=0,
    world_size=1,
):
    """Assemble a disjoint multitask task from ``task_config``.

    For every entry in ``task_config.tasks`` this initializes the sub-task's
    tensorizers and data through ``cls._init_tensorizers``, builds its model
    and metric reporter, then wraps the per-task pieces in their multitask
    counterparts.

    Args:
        task_config: Multitask config exposing ``tasks``, ``task_weights``,
            ``data``, ``target_task_name``, ``metric_reporter`` and
            ``trainer``.
        unused_metadata: Not used by this method.
        model_state: If truthy, loaded into the combined model via
            ``load_state_dict``.
        tensorizers: Optional mapping looked up by sub-task name; the entry
            (or ``None``) is forwarded to ``cls._init_tensorizers``.
        rank: Forwarded to tensorizer and data initialization.
        world_size: Forwarded to tensorizer and data initialization.

    Returns:
        ``cls(data, model, metric_reporter, trainer)``.
    """
    sub_datas = OrderedDict()
    sub_models = OrderedDict()
    sub_reporters = OrderedDict()
    tensorizers_dict = tensorizers or {}

    # We can't really re-use the tensorizers, and the tensorizers saved for
    # disjoint multitask are an empty dict right now anyway. Really we should
    # serialize all of the subtasks individually.
    for task_name, sub_config in task_config.tasks.items():
        saved_tensorizers = tensorizers_dict.get(task_name)
        sub_tensorizers, sub_data = cls._init_tensorizers(
            sub_config, saved_tensorizers, rank, world_size
        )
        sub_datas[task_name] = sub_data
        sub_models[task_name] = NewTask._init_model(sub_config.model, sub_tensorizers)
        sub_reporters[task_name] = create_component(
            ComponentType.METRIC_REPORTER,
            sub_config.metric_reporter,
            tensorizers=sub_tensorizers,
        )

    # Sub-tasks without an explicit weight default to 1.0.
    loss_weights = {}
    for task_name in task_config.tasks:
        loss_weights[task_name] = task_config.task_weights.get(task_name, 1.0)

    data = DisjointMultitaskData.from_config(
        task_config.data, data_dict=sub_datas, rank=rank, world_size=world_size
    )
    # for serialization
    data.tensorizers = {
        task_name: sub_data.tensorizers
        for task_name, sub_data in data.data_dict.items()
    }

    model = NewDisjointMultitaskModel(sub_models, loss_weights=loss_weights)
    if model_state:
        model.load_state_dict(model_state)

    reporter_config = task_config.metric_reporter
    metric_reporter = DisjointMultitaskMetricReporter(
        sub_reporters,
        loss_weights=loss_weights,
        target_task_name=task_config.target_task_name,
        use_subtask_select_metric=reporter_config.use_subtask_select_metric,
    )

    trainer = create_trainer(task_config.trainer, model)
    return cls(data, model, metric_reporter, trainer)