class Config(_NewTask.Config):
    """Configuration for a disjoint multitask task made of ``NewTask`` sub-tasks."""

    # Sub-task configurations, keyed by task name.
    tasks: Dict[str, NewTask.Config] = {}
    # Optional loss weight per task name.
    task_weights: Dict[str, float] = {}
    # Name of the sub-task used for selecting the best epoch (None = not set).
    target_task_name: Optional[str] = None
    # Configuration for the combined disjoint multitask data.
    data: DisjointMultitaskData.Config = DisjointMultitaskData.Config()
    # Configuration for the combined disjoint multitask metric reporter.
    metric_reporter: DisjointMultitaskMetricReporter.Config = (
        DisjointMultitaskMetricReporter.Config()
    )
def from_config(cls, task_config, metadata=None, model_state=None):
    """Assemble a disjoint multitask task from its config.

    Every entry of ``task_config.tasks`` gets its own tensorizers/data, model
    and metric reporter through the ``NewTask`` helpers. These per-task
    pieces are then combined into a ``DisjointMultitaskData``, a
    ``NewDisjointMultitaskModel`` and a ``DisjointMultitaskMetricReporter``.

    Args:
        task_config: the disjoint multitask ``Config``.
        metadata: accepted for interface compatibility; not read here.
        model_state: optional state dict; loaded into the combined model when
            truthy.

    Returns:
        ``cls(data, model, metric_reporter, trainer)``.
    """
    sub_datas = OrderedDict()
    sub_models = OrderedDict()
    sub_reporters = OrderedDict()
    for task_name, sub_config in task_config.tasks.items():
        sub_tensorizers, sub_data = NewTask._init_tensorizers(sub_config)
        sub_datas[task_name] = sub_data
        sub_models[task_name] = NewTask._init_model(sub_config, sub_tensorizers)
        sub_reporters[task_name] = create_component(
            ComponentType.METRIC_REPORTER,
            sub_config.metric_reporter,
            tensorizers=sub_tensorizers,
        )

    # Tasks without an explicit weight default to 1.0. The same dict object is
    # shared by the model and the metric reporter.
    weights = {
        task_name: task_config.task_weights.get(task_name, 1.0)
        for task_name in task_config.tasks
    }

    data = DisjointMultitaskData(task_config.data, sub_datas)
    model = NewDisjointMultitaskModel(sub_models, loss_weights=weights)
    if model_state:
        model.load_state_dict(model_state)

    reporter_config = task_config.metric_reporter
    metric_reporter = DisjointMultitaskMetricReporter(
        sub_reporters,
        loss_weights=weights,
        target_task_name=task_config.target_task_name,
        use_subtask_select_metric=reporter_config.use_subtask_select_metric,
    )
    return cls(
        data, model, metric_reporter, create_trainer(task_config.trainer, model)
    )
class Config(TaskBase.Config):
    """Configuration for a disjoint multitask task made of ``Task`` sub-tasks."""

    # Sub-task configurations, keyed by task name (required, no default).
    tasks: Dict[str, Task.Config]
    # Optional loss weight per task name.
    task_weights: Dict[str, float] = {}
    # Configuration for the disjoint multitask data handler.
    data_handler: DisjointMultitaskDataHandler.Config = (
        DisjointMultitaskDataHandler.Config()
    )
    # Configuration for the disjoint multitask metric reporter.
    metric_reporter: DisjointMultitaskMetricReporter.Config = (
        DisjointMultitaskMetricReporter.Config()
    )
class Config(TaskBase.Config):
    """Configuration for a disjoint multitask task made of deprecated ``Task`` sub-tasks."""

    # Sub-task configurations, keyed by task name (required, no default).
    tasks: Dict[str, Task_Deprecated.Config]
    # Optional loss weight per task name.
    task_weights: Dict[str, float] = {}
    # Name of the sub-task used for selecting the best epoch (None = not set).
    target_task_name: Optional[str] = None
    # Configuration for the disjoint multitask data handler.
    data_handler: DisjointMultitaskDataHandler.Config = (
        DisjointMultitaskDataHandler.Config()
    )
    # Configuration for the disjoint multitask metric reporter.
    metric_reporter: DisjointMultitaskMetricReporter.Config = (
        DisjointMultitaskMetricReporter.Config()
    )
def from_config(cls, task_config, metadata=None, model_state=None):
    """Build a data-handler based disjoint multitask task from its config.

    The task is built in this order:

    1. One featurizer and data handler per sub-task, wrapped in a
       ``DisjointMultitaskDataHandler``.
    2. Metadata, loaded or initialized.
    3. Per-task exporters, a ``DisjointMultitaskMetricReporter`` and a
       ``DisjointMultitaskModel``.
    4. Optimizers and an LR ``Scheduler``.

    Args:
        task_config: the disjoint multitask ``Config``; ``tasks`` maps task
            names to sub-task configs.
        metadata: previously saved metadata. When truthy it is loaded into
            the data handler; otherwise the handler initializes its own.
        model_state: optional state dict; loaded into the combined model when
            truthy.

    Returns:
        An instance of ``cls`` built from the components above.
    """
    print("Task parameters:\n")
    pprint(config_to_json(type(task_config), task_config))

    data_handlers = OrderedDict()
    for name, task in task_config.tasks.items():
        featurizer = create_featurizer(task.featurizer, task.features)
        data_handlers[name] = create_data_handler(
            task.data_handler, task.features, task.labels, featurizer=featurizer
        )
    data_handler = DisjointMultitaskDataHandler(
        task_config.data_handler, data_handlers
    )

    print("\nLoading data...")
    if metadata:
        data_handler.load_metadata(metadata)
    else:
        data_handler.init_metadata()
    # From here on, ``metadata`` is the handler's metadata, indexed by task
    # name below.
    metadata = data_handler.metadata

    # Sub-tasks without an exporter config map to None. (The former
    # ``exporters = OrderedDict()`` pre-assignment was dead code: it was
    # always overwritten here.)
    exporters = {
        name: (
            create_exporter(
                task.exporter,
                task.features,
                task.labels,
                data_handler.data_handlers[name].metadata,
                task.model,
            )
            if task.exporter
            else None
        )
        for name, task in task_config.tasks.items()
    }

    metric_reporter = DisjointMultitaskMetricReporter(
        OrderedDict(
            (name, create_metric_reporter(task.metric_reporter, metadata[name]))
            for name, task in task_config.tasks.items()
        ),
        # This variant reads target_task_name from the metric_reporter
        # sub-config rather than from task_config itself.
        target_task_name=task_config.metric_reporter.target_task_name,
    )

    model = DisjointMultitaskModel(
        OrderedDict(
            (name, create_model(task.model, task.features, metadata[name]))
            for name, task in task_config.tasks.items()
        )
    )
    if model_state:
        model.load_state_dict(model_state)
    if cuda_utils.CUDA_ENABLED:
        model = model.cuda()

    optimizers = create_optimizer(model, task_config.optimizer)
    return cls(
        exporters=exporters,
        trainer=create_trainer(task_config.trainer),
        data_handler=data_handler,
        model=model,
        metric_reporter=metric_reporter,
        optimizers=optimizers,
        lr_scheduler=Scheduler(
            optimizers, task_config.scheduler, metric_reporter.lower_is_better
        ),
    )
def from_config(
    cls,
    task_config: Config,
    unused_metadata=None,
    model_state=None,
    tensorizers=None,
    rank=0,
    world_size=1,
):
    """Assemble a tensorizer-based disjoint multitask task from its config.

    For each sub-task this builds its tensorizers and data, its model and its
    metric reporter. It then combines them into a ``DisjointMultitaskData``
    (sharded by ``rank``/``world_size``), a ``NewDisjointMultitaskModel`` and
    a ``DisjointMultitaskMetricReporter``.

    Args:
        task_config: the disjoint multitask ``Config``.
        unused_metadata: accepted for interface compatibility; not read.
        model_state: optional state dict; loaded into the combined model when
            truthy.
        tensorizers: optional mapping of task name to previously built
            tensorizers; the entry for each task (or None) is passed to
            ``cls._init_tensorizers``.
        rank: rank of this process, forwarded to tensorizer and data setup.
        world_size: number of processes, forwarded likewise.

    Returns:
        ``cls(data, model, metric_reporter, trainer)``.
    """
    # We can't really re-use the tensorizers, and the tensorizers saved for
    # disjoint multitask are an empty dict right now anyway. Really we should
    # serialize all of the subtasks individually.
    saved_tensorizers = tensorizers or {}

    per_task_data = OrderedDict()
    per_task_models = OrderedDict()
    per_task_reporters = OrderedDict()
    for task_name, sub_config in task_config.tasks.items():
        task_tensorizers, task_data = cls._init_tensorizers(
            sub_config, saved_tensorizers.get(task_name), rank, world_size
        )
        per_task_data[task_name] = task_data
        per_task_models[task_name] = NewTask._init_model(
            sub_config.model, task_tensorizers
        )
        per_task_reporters[task_name] = create_component(
            ComponentType.METRIC_REPORTER,
            sub_config.metric_reporter,
            tensorizers=task_tensorizers,
        )

    # Tasks without an explicit weight default to 1.0.
    weights = {
        task_name: task_config.task_weights.get(task_name, 1.0)
        for task_name in task_config.tasks
    }

    data = DisjointMultitaskData.from_config(
        task_config.data, data_dict=per_task_data, rank=rank, world_size=world_size
    )
    # Expose per-task tensorizers on the combined data for serialization.
    data.tensorizers = {
        task_name: task_data.tensorizers
        for task_name, task_data in data.data_dict.items()
    }

    model = NewDisjointMultitaskModel(per_task_models, loss_weights=weights)
    if model_state:
        model.load_state_dict(model_state)

    reporter_config = task_config.metric_reporter
    metric_reporter = DisjointMultitaskMetricReporter(
        per_task_reporters,
        loss_weights=weights,
        target_task_name=task_config.target_task_name,
        use_subtask_select_metric=reporter_config.use_subtask_select_metric,
    )
    return cls(
        data, model, metric_reporter, create_trainer(task_config.trainer, model)
    )
def from_config(
    cls,
    task_config: Config,
    metadata=None,
    model_state=None,
    tensorizers=None,
    rank=0,
    world_size=1,
):
    """Build a data-handler based disjoint multitask task from its config.

    The task is built in this order:

    1. One featurizer and data handler per sub-task, wrapped in a
       ``DisjointMultitaskDataHandler``.
    2. Metadata, loaded or initialized.
    3. Per-task exporters, a loss-weighted
       ``DisjointMultitaskMetricReporter`` and a loss-weighted
       ``DisjointMultitaskModel``.
    4. The trainer.

    Args:
        task_config: the disjoint multitask ``Config``; ``tasks`` maps task
            names to sub-task configs.
        metadata: previously saved metadata. When truthy it is loaded into
            the data handler; otherwise the handler initializes its own.
        model_state: optional state dict; loaded into the combined model when
            truthy.
        tensorizers: accepted for interface compatibility; not read here.
        rank: accepted for interface compatibility; not read here.
        world_size: accepted for interface compatibility; not read here.

    Returns:
        An instance of ``cls`` built from the components above.
    """
    print("Task parameters:\n")
    pprint(config_to_json(type(task_config), task_config))

    data_handlers = OrderedDict()
    for name, task in task_config.tasks.items():
        featurizer = create_featurizer(task.featurizer, task.features)
        data_handlers[name] = create_data_handler(
            task.data_handler, task.features, task.labels, featurizer=featurizer
        )
    data_handler = DisjointMultitaskDataHandler(
        task_config.data_handler,
        data_handlers,
        target_task_name=task_config.target_task_name,
    )

    print("\nLoading data...")
    if metadata:
        data_handler.load_metadata(metadata)
    else:
        data_handler.init_metadata()
    # From here on, ``metadata`` is the handler's metadata, indexed by task
    # name below.
    metadata = data_handler.metadata

    # Sub-tasks without an exporter config map to None. (The former
    # ``exporters = OrderedDict()`` pre-assignment was dead code: it was
    # always overwritten here.)
    exporters = {
        name: (
            create_exporter(
                task.exporter,
                task.features,
                task.labels,
                data_handler.data_handlers[name].metadata,
                task.model,
            )
            if task.exporter
            else None
        )
        for name, task in task_config.tasks.items()
    }

    # Tasks without an explicit weight default to 1.0. This is a float to
    # match the Dict[str, float] config type and the other from_config
    # implementations. The same dict is shared by the reporter and the model.
    task_weights = {
        task_name: task_config.task_weights.get(task_name, 1.0)
        for task_name in task_config.tasks.keys()
    }

    metric_reporter = DisjointMultitaskMetricReporter(
        OrderedDict(
            (name, create_metric_reporter(task.metric_reporter, metadata[name]))
            for name, task in task_config.tasks.items()
        ),
        loss_weights=task_weights,
        target_task_name=task_config.target_task_name,
        use_subtask_select_metric=(
            task_config.metric_reporter.use_subtask_select_metric
        ),
    )

    model = DisjointMultitaskModel(
        OrderedDict(
            (name, create_model(task.model, task.features, metadata[name]))
            for name, task in task_config.tasks.items()
        ),
        loss_weights=task_weights,
    )
    if model_state:
        model.load_state_dict(model_state)
    if cuda.CUDA_ENABLED:
        model = model.cuda()

    return cls(
        target_task_name=task_config.target_task_name,
        exporters=exporters,
        trainer=create_trainer(task_config.trainer, model),
        data_handler=data_handler,
        model=model,
        metric_reporter=metric_reporter,
    )