def from_config(cls, task_config, metadata=None, model_state=None):
    """
    Build the task described by ``task_config`` and wire up its components.

    Creates the featurizer, data handler, model, metric reporter, optimizer,
    scheduler and (optionally) exporter, then constructs ``cls`` from them.

    Args:
        task_config (Task.Config): config of the task to build.
        metadata: previously saved global context (e.g. vocabulary); when
            None it is regenerated by the data handler.
        model_state: saved model parameters, loaded into the model if given.
    """
    print("Task parameters:\n")
    pprint(config_to_json(type(task_config), task_config))

    feat = create_featurizer(task_config.featurizer, task_config.features)
    # Load data through the configured handler.
    handler = create_data_handler(
        task_config.data_handler,
        task_config.features,
        task_config.labels,
        featurizer=feat,
    )

    print("\nLoading data...")
    if metadata:
        handler.load_metadata(metadata)
    else:
        handler.init_metadata()
    metadata = handler.metadata

    model = create_model(task_config.model, task_config.features, metadata)
    if model_state:
        model.load_state_dict(model_state)
    if cuda_utils.CUDA_ENABLED:
        model = model.cuda()

    reporter = create_metric_reporter(task_config.metric_reporter, metadata)
    optimizer = create_optimizer(task_config.optimizer, model)

    # Exporter is optional: only build one when the config asks for it.
    if task_config.exporter:
        exporter = create_exporter(
            task_config.exporter,
            task_config.features,
            task_config.labels,
            handler.metadata,
            task_config.model,
        )
    else:
        exporter = None

    return cls(
        trainer=create_trainer(task_config.trainer),
        data_handler=handler,
        model=model,
        metric_reporter=reporter,
        optimizer=optimizer,
        lr_scheduler=Scheduler(
            optimizer, task_config.scheduler, reporter.lower_is_better
        ),
        exporter=exporter,
    )
def save_checkpoint(
    f: io.IOBase,
    config: PyTextConfig,
    model: Model,
    meta: Optional[CommonMetadata],
    tensorizers: Dict[str, Tensorizer],
    training_state: Optional[TrainingState] = None,
) -> None:
    """
    Serialize a training checkpoint to the file-like object ``f``.

    The checkpoint contains the original config (as JSON), the model's
    ``state_dict()``, metadata, tensorizers, the serialization version and
    the (optional) training state.

    Args:
        f: open binary file-like object to write the checkpoint to.
        config: full PyText config for the training job.
        model: model being trained; only its ``state_dict()`` is saved.
        meta: saved global context of the task (may be None).
        tensorizers: name -> tensorizer mapping to persist.
        training_state: in-progress training state, saved without its model.

    Note: the original annotation claimed ``-> str`` but the function never
    returns a value; the annotation is corrected to ``None``.
    """
    # Currently torch.save() has error pickling certain models when not saving
    # by model.state_dict(), thus currently overriding the model in
    # training_state with None, and putting it back after saving
    # https://github.com/pytorch/pytorch/issues/15116
    model_in_training_state = None
    if training_state:
        model_in_training_state, training_state.model = training_state.model, None
    try:
        state = {
            DATA_STATE: meta,
            CONFIG_JSON: config_to_json(PyTextConfig, config),
            MODEL_STATE: model.state_dict(),
            SERIALIZE_VERSION_KEY: LATEST_SERIALIZE_VERSION,
            TENSORIZERS: tensorizers,
            TRAINING_STATE: training_state,
        }
        torch.save(state, f)
    finally:
        # Restore the model even if torch.save raised.
        if training_state:
            training_state.model = model_in_training_state
def from_config(cls, task_config, metadata=None, model_state=None):
    """
    Build a disjoint multitask task from config.

    One data handler, exporter, metric reporter and model head is created
    per sub-task in ``task_config.tasks``; they are combined into the
    disjoint-multitask data handler, metric reporter and model.

    Args:
        task_config: multitask config; ``tasks`` maps name -> sub-task config.
        metadata: previously saved global context; regenerated when None.
        model_state: saved model parameters, loaded into the model if given.
    """
    print("Task parameters:\n")
    pprint(config_to_json(type(task_config), task_config))

    data_handlers = OrderedDict()
    for name, task in task_config.tasks.items():
        featurizer = create_featurizer(task.featurizer, task.features)
        data_handlers[name] = create_data_handler(
            task.data_handler, task.features, task.labels, featurizer=featurizer
        )
    data_handler = DisjointMultitaskDataHandler(
        task_config.data_handler, data_handlers
    )

    print("\nLoading data...")
    if metadata:
        data_handler.load_metadata(metadata)
    else:
        data_handler.init_metadata()
    metadata = data_handler.metadata

    # Fix: the original initialized `exporters = OrderedDict()` and then
    # immediately rebound it to this comprehension; the dead init is removed.
    exporters = {
        name: (
            create_exporter(
                task.exporter,
                task.features,
                task.labels,
                data_handler.data_handlers[name].metadata,
                task.model,
            )
            if task.exporter
            else None
        )
        for name, task in task_config.tasks.items()
    }
    metric_reporter = DisjointMultitaskMetricReporter(
        OrderedDict(
            (name, create_metric_reporter(task.metric_reporter, metadata[name]))
            for name, task in task_config.tasks.items()
        ),
        target_task_name=task_config.metric_reporter.target_task_name,
    )
    model = DisjointMultitaskModel(
        OrderedDict(
            (name, create_model(task.model, task.features, metadata[name]))
            for name, task in task_config.tasks.items()
        )
    )
    if model_state:
        model.load_state_dict(model_state)
    if cuda_utils.CUDA_ENABLED:
        model = model.cuda()

    optimizers = create_optimizer(model, task_config.optimizer)
    return cls(
        exporters=exporters,
        trainer=create_trainer(task_config.trainer),
        data_handler=data_handler,
        model=model,
        metric_reporter=metric_reporter,
        optimizers=optimizers,
        lr_scheduler=Scheduler(
            optimizers, task_config.scheduler, metric_reporter.lower_is_better
        ),
    )
def make_config_from_dict(self, config, disable_tensorboard):
    """
    Normalize a PyText config for integration tests.

    Accepts either a raw config dict or the module path of an actual PyText
    config; upgrades it to the latest config version, optionally disables
    TensorBoard, and applies the test-environment path/CUDA fixups.
    """
    # A string argument is the module path of the actual PyText config.
    if isinstance(config, str):
        module_config = import_module(config).config
        config = config_to_json(PyTextConfig, module_config)
    config = upgrade_to_latest(config)
    if disable_tensorboard:
        # Disable TensorBoard for integration tests
        config["use_tensorboard"] = False
    fixed = self.fix_paths(config)
    return self.disable_cuda(fixed)
def save(config: PyTextConfig, model: Model, meta: CommonMetadata) -> None:
    """
    Save a task snapshot: the original config, model state and metadata.
    """
    save_path = config.save_snapshot_path
    print(f"Saving pytorch model to: {save_path}")
    model.save_modules(base_path=config.modules_save_dir)
    # Build the snapshot in the same key order the loader expects.
    state = OrderedDict()  # type: OrderedDict
    state[DATA_STATE] = meta
    state[CONFIG_JSON] = config_to_json(PyTextConfig, config)
    state[MODEL_STATE] = model.state_dict()
    torch.save(state, save_path)
def from_config(cls, task_config, metadata=None, model_state=None):
    """
    Build the GPT task: featurizer, data handler, model and metric reporter,
    all wired through the configured text embedder.

    Args:
        task_config: config of the task to build.
        metadata: previously saved global context; regenerated when None.
        model_state: saved model parameters, loaded into the model if given.
    """
    print("(mldc/task/gpt_task.py def from_config) Task parameters:\n")
    pprint(config_to_json(type(task_config), task_config))

    # Featurizer wraps the text embedder (per original note: GPT2Embed).
    feat = create_featurizer(
        task_config.featurizer,
        task_config.features,
        text_embedder_config=task_config.text_embedder,
    )
    handler = create_data_handler(
        task_config.data_handler,
        task_config.features,
        task_config.labels,
        text_embedder_config=task_config.text_embedder,
        featurizer=feat,
    )

    print("\n(mldc/task/retrieval.py GptTask def from_config) Loading data...")
    if metadata:
        handler.load_metadata(metadata)
    else:
        handler.init_metadata()
    metadata = handler.metadata

    # Word-feature embedding width must match the text embedder's width.
    task_config.features.seq_word_feat.embed_dim = handler.text_embedder.embed_dim

    print("create model!")
    model = create_model(task_config.model, task_config.features, metadata)
    if model_state:
        model.load_state_dict(model_state)
    if cuda_utils.CUDA_ENABLED:
        model = model.cuda()

    reporter = create_metric_reporter(
        task_config.metric_reporter,
        metadata,
        text_embedder=task_config.text_embedder,
    )
    return cls(
        trainer=create_trainer(task_config.trainer),
        data_handler=handler,
        model=model,
        metric_reporter=reporter,
        model_needs_meta_training=task_config.model_needs_meta_training,
    )
def save(
    config: PyTextConfig,
    model: Model,
    meta: CommonMetadata,
    tensorizers: Dict[str, Tensorizer],
) -> None:
    """
    Save a task snapshot: original config, model state, metadata, the
    tensorizers, and the serialization format version.
    """
    save_path = config.save_snapshot_path
    print(f"Saving pytorch model to: {save_path}")
    model.save_modules(base_path=config.modules_save_dir)
    snapshot = {
        DATA_STATE: meta,
        CONFIG_JSON: config_to_json(PyTextConfig, config),
        MODEL_STATE: model.state_dict(),
        SERIALIZE_VERSION_KEY: LATEST_SERIALIZE_VERSION,
        TENSORIZERS: tensorizers,
    }
    torch.save(snapshot, save_path)
def save(
    config: PyTextConfig,
    model: Model,
    meta: Optional[CommonMetadata],
    tensorizers: Dict[str, Tensorizer],
    training_state: Optional[TrainingState] = None,
    f: Optional[io.IOBase] = None,
) -> None:
    """
    Save all stateful information of a training task: the original config,
    model state, metadata, tensorizers, and — if training has not completed —
    the training state. Writes to ``f`` when given, otherwise to
    ``config.save_snapshot_path``.
    """
    if config.modules_save_dir:
        model.save_modules(base_path=config.modules_save_dir)

    # torch.save() can fail to pickle some models unless saved via
    # state_dict(); temporarily detach the model from training_state and
    # restore it afterwards. https://github.com/pytorch/pytorch/issues/15116
    detached_model = None
    if training_state:
        detached_model, training_state.model = training_state.model, None
    try:
        snapshot = {
            DATA_STATE: meta,
            CONFIG_JSON: config_to_json(PyTextConfig, config),
            MODEL_STATE: model.state_dict(),
            SERIALIZE_VERSION_KEY: LATEST_SERIALIZE_VERSION,
            TENSORIZERS: tensorizers,
            TRAINING_STATE: training_state,
        }
        if f is not None:
            torch.save(snapshot, f)
        else:
            save_path = config.save_snapshot_path
            print(f"Saving pytorch model to: {save_path}")
            torch.save(snapshot, save_path)
    finally:
        # Reattach the model even if torch.save raised.
        if training_state:
            training_state.model = detached_model
def from_config(
    cls,
    task_config: Config,
    metadata=None,
    model_state=None,
    tensorizers=None,
    rank=0,
    world_size=1,
):
    """
    Build a weighted disjoint multitask task from config.

    One data handler, exporter, metric reporter and model head is created
    per sub-task in ``task_config.tasks``; sub-task metric reporters and
    model losses are combined using per-task weights from
    ``task_config.task_weights`` (default weight 1).

    Args:
        task_config (Config): multitask config; ``tasks`` maps name -> sub-task.
        metadata: previously saved global context; regenerated when None.
        model_state: saved model parameters, loaded into the model if given.
        tensorizers, rank, world_size: accepted for interface compatibility;
            unused in this implementation.
    """
    print("Task parameters:\n")
    pprint(config_to_json(type(task_config), task_config))

    data_handlers = OrderedDict()
    for name, task in task_config.tasks.items():
        featurizer = create_featurizer(task.featurizer, task.features)
        data_handlers[name] = create_data_handler(
            task.data_handler, task.features, task.labels, featurizer=featurizer
        )
    data_handler = DisjointMultitaskDataHandler(
        task_config.data_handler,
        data_handlers,
        target_task_name=task_config.target_task_name,
    )

    print("\nLoading data...")
    if metadata:
        data_handler.load_metadata(metadata)
    else:
        data_handler.init_metadata()
    metadata = data_handler.metadata

    # Fix: the original initialized `exporters = OrderedDict()` and then
    # immediately rebound it to this comprehension; the dead init is removed.
    exporters = {
        name: (
            create_exporter(
                task.exporter,
                task.features,
                task.labels,
                data_handler.data_handlers[name].metadata,
                task.model,
            )
            if task.exporter
            else None
        )
        for name, task in task_config.tasks.items()
    }
    task_weights = {
        task_name: task_config.task_weights.get(task_name, 1)
        for task_name in task_config.tasks.keys()
    }
    metric_reporter = DisjointMultitaskMetricReporter(
        OrderedDict(
            (name, create_metric_reporter(task.metric_reporter, metadata[name]))
            for name, task in task_config.tasks.items()
        ),
        loss_weights=task_weights,
        target_task_name=task_config.target_task_name,
        use_subtask_select_metric=(
            task_config.metric_reporter.use_subtask_select_metric
        ),
    )
    model = DisjointMultitaskModel(
        OrderedDict(
            (name, create_model(task.model, task.features, metadata[name]))
            for name, task in task_config.tasks.items()
        ),
        loss_weights=task_weights,
    )
    if model_state:
        model.load_state_dict(model_state)
    if cuda.CUDA_ENABLED:
        model = model.cuda()

    return cls(
        target_task_name=task_config.target_task_name,
        exporters=exporters,
        trainer=create_trainer(task_config.trainer, model),
        data_handler=data_handler,
        model=model,
        metric_reporter=metric_reporter,
    )
def save(
    config: PyTextConfig,
    model: Model,
    meta: Optional[CommonMetadata],
    tensorizers: Dict[str, Tensorizer],
    training_state: Optional[TrainingState] = None,
    identifier: Optional[str] = None,
) -> str:
    """
    Save all stateful information of a training task to a specified file-like
    object, will save the original config, model state, metadata, training
    state if training is not completed

    Args:
        identifier (str): used to identify a checkpoint within a training job,
        used as a suffix for save path
        config (PytextConfig): contains all raw parameter/hyper-parameters
        for training task
        model (Model): actual model in training
        training_state (TrainingState): stateful infomation during training

    Returns:
        identifier (str): if identifier is not specified, will save to
        config.save_snapshot_path to be consistent to post-training snapshot;
        if specified, will be used to save checkpoint during training,
        identifier is used to identify checkpoints in the same training
    """
    if identifier:
        # saving during-training checkpoints
        saved_path = generate_checkpoint_path(config, identifier)
    else:
        # saving post-training snapshot if no identifier given
        saved_path = config.save_snapshot_path
    print(f"Saving pytorch model to: {saved_path}")

    saved_folder = os.path.dirname(saved_path)
    if not PathManager.exists(saved_folder):
        PathManager.mkdirs(saved_folder)
        print(f"created {saved_folder}")

    # Currently torch.save() has error pickling certain models when not saving
    # by model.state_dict(), thus currently overriding the model in
    # training_state with None, and putting it back after saving
    # https://github.com/pytorch/pytorch/issues/15116
    model_in_training_state = None
    if training_state:
        model_in_training_state, training_state.model = training_state.model, None
    try:
        state = {
            DATA_STATE: meta,
            CONFIG_JSON: config_to_json(PyTextConfig, config),
            MODEL_STATE: model.state_dict(),
            SERIALIZE_VERSION_KEY: LATEST_SERIALIZE_VERSION,
            TENSORIZERS: tensorizers,
            TRAINING_STATE: training_state,
        }
        # BUGFIX: path selection above uses truthiness (`if identifier:`) but
        # the manager dispatch used `identifier is not None`, so an empty
        # string identifier built a snapshot path yet saved via
        # save_checkpoint(). Both branches now use the same truthiness test.
        if identifier:
            _CHECKPOINT_MANAGER.save_checkpoint(state, saved_path)
        else:
            _CHECKPOINT_MANAGER.save_snapshot(state, saved_path)
    finally:
        if training_state:
            training_state.model = model_in_training_state
    return saved_path