Example #1
    @classmethod
    def from_config(cls, task_config, metadata=None, model_state=None):
        """
        Create the task from config, and optionally load metadata/model_state
        This function will create components including :class:`~DataHandler`,
        :class:`~Trainer`, :class:`~Optimizer`, :class:`~Scheduler`,
        :class:`~MetricReporter`, :class:`~Exporter`, and wire them up.

        Args:
            task_config (Task.Config): the config of the current task
            metadata: saved global context of this task, e.g: vocabulary, will be
                generated by :class:`~DataHandler` if it's None
            model_state: saved model parameters, will be loaded into model when given
        """
        print("Task parameters:\n")
        pprint(config_to_json(type(task_config), task_config))
        featurizer = create_featurizer(task_config.featurizer, task_config.features)
        # load data
        data_handler = create_data_handler(
            task_config.data_handler,
            task_config.features,
            task_config.labels,
            featurizer=featurizer,
        )
        print("\nLoading data...")
        if metadata:
            data_handler.load_metadata(metadata)
        else:
            data_handler.init_metadata()

        metadata = data_handler.metadata

        model = create_model(task_config.model, task_config.features, metadata)
        if model_state:
            model.load_state_dict(model_state)
        if cuda_utils.CUDA_ENABLED:
            model = model.cuda()
        metric_reporter = create_metric_reporter(task_config.metric_reporter, metadata)
        optimizer = create_optimizer(task_config.optimizer, model)
        exporter = (
            create_exporter(
                task_config.exporter,
                task_config.features,
                task_config.labels,
                data_handler.metadata,
                task_config.model,
            )
            if task_config.exporter
            else None
        )
        return cls(
            trainer=create_trainer(task_config.trainer),
            data_handler=data_handler,
            model=model,
            metric_reporter=metric_reporter,
            optimizer=optimizer,
            lr_scheduler=Scheduler(
                optimizer, task_config.scheduler, metric_reporter.lower_is_better
            ),
            exporter=exporter,
        )
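For orientation, here is a hypothetical usage sketch of this factory; MyTask, my_task_config and saved_state are placeholders rather than names from the snippet, and the key constants match the ones used in the save() examples further down.

# Hypothetical usage sketch; MyTask, my_task_config and saved_state are
# placeholders. Fresh run: DataHandler builds the vocabulary and the model
# starts from freshly initialized parameters.
task = MyTask.from_config(my_task_config)

# Resuming: pass the saved vocabulary and parameters so they match the
# original run (saved_state as produced by torch.load on a snapshot).
task = MyTask.from_config(
    my_task_config,
    metadata=saved_state[DATA_STATE],
    model_state=saved_state[MODEL_STATE],
)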
Example #2
def save_checkpoint(
    f: io.IOBase,
    config: PyTextConfig,
    model: Model,
    meta: Optional[CommonMetadata],
    tensorizers: Dict[str, Tensorizer],
    training_state: Optional[TrainingState] = None,
) -> str:
    # torch.save() currently fails to pickle certain models unless they are
    # saved via model.state_dict(), so temporarily set training_state.model
    # to None and restore it after saving.
    # https://github.com/pytorch/pytorch/issues/15116
    model_in_training_state = None
    if training_state:
        model_in_training_state, training_state.model = training_state.model, None
    try:
        state = {
            DATA_STATE: meta,
            CONFIG_JSON: config_to_json(PyTextConfig, config),
            MODEL_STATE: model.state_dict(),
            SERIALIZE_VERSION_KEY: LATEST_SERIALIZE_VERSION,
            TENSORIZERS: tensorizers,
            TRAINING_STATE: training_state,
        }
        torch.save(state, f)
    finally:
        if training_state:
            training_state.model = model_in_training_state
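The swap-and-restore dance above is a generic pickling workaround; the following self-contained sketch (plain torch and a throwaway stand-in class, no PyText) shows why the try/finally matters: the model reference is restored even if torch.save raises.

import io

import torch
import torch.nn as nn


class ToyTrainingState:
    # Stand-in for PyText's TrainingState; only the model field matters here.
    def __init__(self, model):
        self.model = model


model = nn.Linear(4, 2)
training_state = ToyTrainingState(model)

# Temporarily detach the model so pickling the state never touches it,
# then restore the reference even if torch.save fails.
model_backup, training_state.model = training_state.model, None
try:
    buffer = io.BytesIO()
    torch.save(
        {"model_state": model.state_dict(), "training_state": training_state},
        buffer,
    )
finally:
    training_state.model = model_backup

assert training_state.model is model  # reference restored either way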
Example #3
    @classmethod
    def from_config(cls, task_config, metadata=None, model_state=None):
        print("Task parameters:\n")
        pprint(config_to_json(type(task_config), task_config))

        data_handlers = OrderedDict()
        exporters = OrderedDict()
        for name, task in task_config.tasks.items():
            featurizer = create_featurizer(task.featurizer, task.features)
            data_handlers[name] = create_data_handler(task.data_handler,
                                                      task.features,
                                                      task.labels,
                                                      featurizer=featurizer)
        data_handler = DisjointMultitaskDataHandler(task_config.data_handler,
                                                    data_handlers)
        print("\nLoading data...")
        if metadata:
            data_handler.load_metadata(metadata)
        else:
            data_handler.init_metadata()
        metadata = data_handler.metadata
        exporters = {
            name: (create_exporter(
                task.exporter,
                task.features,
                task.labels,
                data_handler.data_handlers[name].metadata,
                task.model,
            ) if task.exporter else None)
            for name, task in task_config.tasks.items()
        }
        metric_reporter = DisjointMultitaskMetricReporter(
            OrderedDict(
                (name,
                 create_metric_reporter(task.metric_reporter, metadata[name]))
                for name, task in task_config.tasks.items()),
            target_task_name=task_config.metric_reporter.target_task_name,
        )

        model = DisjointMultitaskModel(
            OrderedDict(
                (name, create_model(task.model, task.features, metadata[name]))
                for name, task in task_config.tasks.items()))
        if model_state:
            model.load_state_dict(model_state)
        if cuda_utils.CUDA_ENABLED:
            model = model.cuda()

        optimizers = create_optimizer(model, task_config.optimizer)
        return cls(
            exporters=exporters,
            trainer=create_trainer(task_config.trainer),
            data_handler=data_handler,
            model=model,
            metric_reporter=metric_reporter,
            optimizers=optimizers,
            lr_scheduler=Scheduler(optimizers, task_config.scheduler,
                                   metric_reporter.lower_is_better),
        )
Example #4
    def make_config_from_dict(self, config, disable_tensorboard):
        # config is the module path of the actual PyText config
        if isinstance(config, str):
            config = config_to_json(PyTextConfig, import_module(config).config)
        config = upgrade_to_latest(config)
        # Disable TensorBoard for integration tests
        if disable_tensorboard:
            config["use_tensorboard"] = False
        return self.disable_cuda(self.fix_paths(config))
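Because the config is a plain dict at this point, the tweaks can be shown without any PyText machinery; in this minimal sketch, disable_cuda is a simplified stand-in for the helper referenced above and the key names are assumptions.

# Stand-alone sketch of the dict-level config tweaks; disable_cuda is a
# simplified stand-in and the key names are assumptions for illustration.
def disable_cuda(config):
    config["use_cuda_if_available"] = False
    return config


config = {"use_tensorboard": True, "use_cuda_if_available": True}
config["use_tensorboard"] = False  # what disable_tensorboard=True triggers
config = disable_cuda(config)
print(config)  # {'use_tensorboard': False, 'use_cuda_if_available': False}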
Example #5
def save(config: PyTextConfig, model: Model, meta: CommonMetadata) -> None:
    """
    Save a task, will save the original config, model state and metadata
    """
    save_path = config.save_snapshot_path
    print(f"Saving pytorch model to: {save_path}")
    model.save_modules(base_path=config.modules_save_dir)
    state = OrderedDict([
        (DATA_STATE, meta),
        (CONFIG_JSON, config_to_json(PyTextConfig, config)),
        (MODEL_STATE, model.state_dict()),
    ])  # type: OrderedDict
    torch.save(state, save_path)
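The snapshot written here is an ordinary torch pickle, so loading is symmetric; a self-contained sketch of saving and restoring just the model weights, where the value of the MODEL_STATE key constant is an assumption for illustration.

import torch
import torch.nn as nn

MODEL_STATE = "model_state"  # assumed value of the key constant used above

model = nn.Linear(8, 3)
torch.save({MODEL_STATE: model.state_dict()}, "/tmp/snapshot.pt")

# Loading side: read the state dict back and restore the parameters
# into a freshly constructed model of the same shape.
state = torch.load("/tmp/snapshot.pt")
restored = nn.Linear(8, 3)
restored.load_state_dict(state[MODEL_STATE])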
Example #6
    @classmethod
    def from_config(cls, task_config, metadata=None, model_state=None):
        print("(mldc/task/gpt_task.py def from_config) Task parameters:\n")
        pprint(config_to_json(type(task_config), task_config))

        featurizer = create_featurizer(
            task_config.featurizer,
            task_config.features,
            text_embedder_config=task_config.text_embedder
        )  # featurizer :: text embedder GPT2Embed

        # load data
        data_handler = create_data_handler(
            task_config.data_handler,
            task_config.features,
            task_config.labels,
            text_embedder_config=task_config.text_embedder,
            featurizer=featurizer,
        )
        print(
            "\n(mldc/task/retrieval.py GptTask def from_config) Loading data..."
        )
        if metadata:
            data_handler.load_metadata(metadata)
        else:
            data_handler.init_metadata()

        metadata = data_handler.metadata
        task_config.features.seq_word_feat.embed_dim = data_handler.text_embedder.embed_dim

        print("create model!")
        model = create_model(task_config.model, task_config.features, metadata)
        if model_state:
            model.load_state_dict(model_state)
        if cuda_utils.CUDA_ENABLED:
            model = model.cuda()
        metric_reporter = create_metric_reporter(
            task_config.metric_reporter,
            metadata,
            text_embedder=task_config.text_embedder)

        return cls(
            trainer=create_trainer(task_config.trainer),
            data_handler=data_handler,
            model=model,
            metric_reporter=metric_reporter,
            model_needs_meta_training=task_config.model_needs_meta_training,
        )
Example #7
def save(
    config: PyTextConfig,
    model: Model,
    meta: CommonMetadata,
    tensorizers: Dict[str, Tensorizer],
) -> None:
    """
    Save a task, will save the original config, model state and metadata
    """
    save_path = config.save_snapshot_path
    print(f"Saving pytorch model to: {save_path}")
    model.save_modules(base_path=config.modules_save_dir)
    state = {
        DATA_STATE: meta,
        CONFIG_JSON: config_to_json(PyTextConfig, config),
        MODEL_STATE: model.state_dict(),
        SERIALIZE_VERSION_KEY: LATEST_SERIALIZE_VERSION,
        TENSORIZERS: tensorizers,
    }
    torch.save(state, save_path)
Example #8
def save(
    config: PyTextConfig,
    model: Model,
    meta: Optional[CommonMetadata],
    tensorizers: Dict[str, Tensorizer],
    training_state: Optional[TrainingState] = None,
    f: Optional[io.IOBase] = None,
) -> None:
    """
    Save all stateful information of a training task to a specified file-like
    object, will save the original config, model state, metadata,
    training state if training is not completed
    """
    if config.modules_save_dir:
        model.save_modules(base_path=config.modules_save_dir)

    # torch.save() currently fails to pickle certain models unless they are
    # saved via model.state_dict(), so temporarily set training_state.model
    # to None and restore it after saving.
    # https://github.com/pytorch/pytorch/issues/15116
    model_in_training_state = None
    if training_state:
        model_in_training_state, training_state.model = training_state.model, None
    try:
        state = {
            DATA_STATE: meta,
            CONFIG_JSON: config_to_json(PyTextConfig, config),
            MODEL_STATE: model.state_dict(),
            SERIALIZE_VERSION_KEY: LATEST_SERIALIZE_VERSION,
            TENSORIZERS: tensorizers,
            TRAINING_STATE: training_state,
        }
        if f is None:
            save_path = config.save_snapshot_path
            print(f"Saving pytorch model to: {save_path}")
            torch.save(state, save_path)
        else:
            torch.save(state, f)
    finally:
        if training_state:
            training_state.model = model_in_training_state
Example #9
    @classmethod
    def from_config(
        cls,
        task_config: Config,
        metadata=None,
        model_state=None,
        tensorizers=None,
        rank=0,
        world_size=1,
    ):
        print("Task parameters:\n")
        pprint(config_to_json(type(task_config), task_config))

        data_handlers = OrderedDict()
        exporters = OrderedDict()
        for name, task in task_config.tasks.items():
            featurizer = create_featurizer(task.featurizer, task.features)
            data_handlers[name] = create_data_handler(
                task.data_handler, task.features, task.labels, featurizer=featurizer
            )
        data_handler = DisjointMultitaskDataHandler(
            task_config.data_handler,
            data_handlers,
            target_task_name=task_config.target_task_name,
        )
        print("\nLoading data...")
        if metadata:
            data_handler.load_metadata(metadata)
        else:
            data_handler.init_metadata()

        metadata = data_handler.metadata
        exporters = {
            name: (
                create_exporter(
                    task.exporter,
                    task.features,
                    task.labels,
                    data_handler.data_handlers[name].metadata,
                    task.model,
                )
                if task.exporter
                else None
            )
            for name, task in task_config.tasks.items()
        }
        task_weights = {
            task_name: task_config.task_weights.get(task_name, 1)
            for task_name in task_config.tasks.keys()
        }
        metric_reporter = DisjointMultitaskMetricReporter(
            OrderedDict(
                (name, create_metric_reporter(task.metric_reporter, metadata[name]))
                for name, task in task_config.tasks.items()
            ),
            loss_weights=task_weights,
            target_task_name=task_config.target_task_name,
            use_subtask_select_metric=(
                task_config.metric_reporter.use_subtask_select_metric
            ),
        )
        model = DisjointMultitaskModel(
            OrderedDict(
                (name, create_model(task.model, task.features, metadata[name]))
                for name, task in task_config.tasks.items()
            ),
            loss_weights=task_weights,
        )
        if model_state:
            model.load_state_dict(model_state)
        if cuda.CUDA_ENABLED:
            model = model.cuda()

        return cls(
            target_task_name=task_config.target_task_name,
            exporters=exporters,
            trainer=create_trainer(task_config.trainer, model),
            data_handler=data_handler,
            model=model,
            metric_reporter=metric_reporter,
        )
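The task_weights construction above defaults every task missing from task_config.task_weights to a weight of 1; the same dict.get pattern in isolation:

# Self-contained illustration of the per-task weight defaulting used above.
configured_weights = {"ner": 2.0}  # only some tasks carry explicit weights
task_names = ["ner", "intent", "slot_filling"]

task_weights = {name: configured_weights.get(name, 1) for name in task_names}
print(task_weights)  # {'ner': 2.0, 'intent': 1, 'slot_filling': 1}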
Example #10
def save(
    config: PyTextConfig,
    model: Model,
    meta: Optional[CommonMetadata],
    tensorizers: Dict[str, Tensorizer],
    training_state: Optional[TrainingState] = None,
    identifier: Optional[str] = None,
) -> str:
    """
    Save all stateful information of a training task to a specified file-like
    object, will save the original config, model state, metadata,
    training state if training is not completed
    Args:
    identifier (str): used to identify a checkpoint within a training job,
    used as a suffix for save path
    config (PytextConfig): contains all raw parameter/hyper-parameters
    for training task
    model (Model): actual model in training
    training_state (TrainingState): stateful infomation during training
    Returns:
    identifier (str): if identifier is not specified, will save to
    config.save_snapshot_path to be consistent to post-training snapshot;
    if specified, will be used to save checkpoint during training,
    identifier is used to identify checkpoints in the same training
    """
    saved_path = ""
    if identifier:
        # saving during-training checkpoints
        saved_path = generate_checkpoint_path(config, identifier)
    else:
        # saving post-training snapshot if no identifier given
        saved_path = config.save_snapshot_path
        print(f"Saving pytorch model to: {saved_path}")

    saved_folder = os.path.dirname(saved_path)
    if not PathManager.exists(saved_folder):
        PathManager.mkdirs(saved_folder)
        print(f"created {saved_folder}")

    # torch.save() currently fails to pickle certain models unless they are
    # saved via model.state_dict(), so temporarily set training_state.model
    # to None and restore it after saving.
    # https://github.com/pytorch/pytorch/issues/15116
    model_in_training_state = None
    if training_state:
        model_in_training_state, training_state.model = training_state.model, None
    try:
        state = {
            DATA_STATE: meta,
            CONFIG_JSON: config_to_json(PyTextConfig, config),
            MODEL_STATE: model.state_dict(),
            SERIALIZE_VERSION_KEY: LATEST_SERIALIZE_VERSION,
            TENSORIZERS: tensorizers,
            TRAINING_STATE: training_state,
        }
        if identifier is not None:
            _CHECKPOINT_MANAGER.save_checkpoint(state, saved_path)
        else:
            _CHECKPOINT_MANAGER.save_snapshot(state, saved_path)

    finally:
        if training_state:
            training_state.model = model_in_training_state
    return saved_path
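generate_checkpoint_path itself is not shown in the snippet; here is a plausible sketch of deriving an identifier-suffixed checkpoint path from the snapshot path. Both the naming scheme and the simplified signature are assumptions, not PyText's actual implementation.

import os


def generate_checkpoint_path(snapshot_path: str, identifier: str) -> str:
    # Hypothetical scheme: keep checkpoints next to the final snapshot,
    # suffixed with the identifier (e.g. an epoch number).
    base, ext = os.path.splitext(snapshot_path)
    return f"{base}-checkpoint-{identifier}{ext}"


print(generate_checkpoint_path("/models/run1/model.pt", "epoch-3"))
# -> /models/run1/model-checkpoint-epoch-3.pt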