Example #1
0
def save(config: PyTextConfig, model: Model, meta: CommonMetadata) -> None:
    """
    Persist a task: write the original config, the model state dict and the
    training metadata to the configured snapshot path.
    """
    snapshot_path = config.save_snapshot_path
    print(f"Saving pytorch model to: {snapshot_path}")
    # Sub-modules are saved separately under the configured directory.
    model.save_modules(base_path=config.modules_save_dir)
    # Explicit OrderedDict keeps the serialized layout deterministic.
    snapshot = OrderedDict()  # type: OrderedDict
    snapshot[DATA_STATE] = meta
    snapshot[CONFIG_JSON] = config_to_json(PyTextConfig, config)
    snapshot[MODEL_STATE] = model.state_dict()
    torch.save(snapshot, snapshot_path)
Example #2
0
def save_checkpoint(
    f: io.IOBase,
    config: PyTextConfig,
    model: Model,
    meta: Optional[CommonMetadata],
    tensorizers: Dict[str, Tensorizer],
    training_state: Optional[TrainingState] = None,
) -> None:
    """
    Serialize a training checkpoint (config, model state dict, metadata,
    tensorizers and optional training state) to the file-like object ``f``.

    Note: the return annotation was ``str`` but the function has never
    returned a value, so it is corrected to ``None``.
    """
    # Currently torch.save() has error pickling certain models when not saving
    # by model.state_dict(), so temporarily detach the live model from
    # training_state and restore it after saving.
    # https://github.com/pytorch/pytorch/issues/15116
    model_in_training_state = None
    if training_state:
        model_in_training_state, training_state.model = training_state.model, None
    try:
        state = {
            DATA_STATE: meta,
            CONFIG_JSON: config_to_json(PyTextConfig, config),
            MODEL_STATE: model.state_dict(),
            SERIALIZE_VERSION_KEY: LATEST_SERIALIZE_VERSION,
            TENSORIZERS: tensorizers,
            TRAINING_STATE: training_state,
        }
        torch.save(state, f)
    finally:
        # Restore even if torch.save raised, so the caller's training_state
        # remains usable.
        if training_state:
            training_state.model = model_in_training_state
Example #3
0
    def from_config(cls, model_config, feature_config, metadata: CommonMetadata):
        """Build the parser from its config and precomputed metadata."""
        comp_type = model_config.compositional_type
        lstm_dim = model_config.lstm.lstm_dim
        # Choose the composition function for reduced subtrees.
        if comp_type == RNNGParser.Config.CompositionalType.SUM:
            p_compositional = CompositionalSummationNN(lstm_dim=lstm_dim)
        elif comp_type == RNNGParser.Config.CompositionalType.BLSTM:
            p_compositional = CompositionalNN(lstm_dim=lstm_dim)
        else:
            raise ValueError(
                "Cannot understand compositional flag {}".format(comp_type)
            )

        return cls(
            ablation=model_config.ablation,
            constraints=model_config.constraints,
            lstm_num_layers=model_config.lstm.num_layers,
            lstm_dim=lstm_dim,
            max_open_NT=model_config.max_open_NT,
            dropout=model_config.dropout,
            actions_vocab=metadata.actions_vocab,
            shift_idx=metadata.shift_idx,
            reduce_idx=metadata.reduce_idx,
            ignore_subNTs_roots=metadata.ignore_subNTs_roots,
            valid_NT_idxs=metadata.valid_NT_idxs,
            valid_IN_idxs=metadata.valid_IN_idxs,
            valid_SL_idxs=metadata.valid_SL_idxs,
            embedding=Model.create_embedding(feature_config, metadata=metadata),
            p_compositional=p_compositional,
        )
Example #4
0
def save(
    config: PyTextConfig,
    model: Model,
    meta: CommonMetadata,
    tensorizers: Dict[str, Tensorizer],
) -> None:
    """
    Persist a task: write the original config, the model state dict, the
    metadata and the tensorizers to the configured snapshot path.
    """
    path = config.save_snapshot_path
    print(f"Saving pytorch model to: {path}")
    # Sub-modules are saved separately under the configured directory.
    model.save_modules(base_path=config.modules_save_dir)
    snapshot = {}
    snapshot[DATA_STATE] = meta
    snapshot[CONFIG_JSON] = config_to_json(PyTextConfig, config)
    snapshot[MODEL_STATE] = model.state_dict()
    snapshot[SERIALIZE_VERSION_KEY] = LATEST_SERIALIZE_VERSION
    snapshot[TENSORIZERS] = tensorizers
    torch.save(snapshot, path)
def save(
    config: PyTextConfig,
    model: Model,
    meta: Optional[CommonMetadata],
    tensorizers: Dict[str, Tensorizer],
    training_state: Optional[TrainingState] = None,
    f: Optional[io.IOBase] = None,
) -> None:
    """
    Save all stateful information of a training task to a specified file-like
    object, will save the original config, model state, metadata,
    training state if training is not completed
    """
    if config.modules_save_dir:
        model.save_modules(base_path=config.modules_save_dir)

    # torch.save() can fail to pickle certain models unless only state_dict()
    # is saved, so detach the live model from training_state for the duration
    # of the save and restore it afterwards.
    # https://github.com/pytorch/pytorch/issues/15116
    detached_model = None
    if training_state:
        detached_model, training_state.model = training_state.model, None
    try:
        state = {
            DATA_STATE: meta,
            CONFIG_JSON: config_to_json(PyTextConfig, config),
            MODEL_STATE: model.state_dict(),
            SERIALIZE_VERSION_KEY: LATEST_SERIALIZE_VERSION,
            TENSORIZERS: tensorizers,
            TRAINING_STATE: training_state,
        }
        if f is not None:
            torch.save(state, f)
        else:
            save_path = config.save_snapshot_path
            print(f"Saving pytorch model to: {save_path}")
            torch.save(state, save_path)
    finally:
        # Restore the model even when torch.save raises.
        if training_state:
            training_state.model = detached_model
Example #6
0
    def from_config(
        cls,
        model_config,
        feature_config=None,
        metadata: CommonMetadata = None,
        tensorizers: Dict[str, Tensorizer] = None,
    ):
        """
        Build the parser either from tensorizers (new path) or from
        feature_config plus metadata (legacy path).
        """
        lstm_dim = model_config.lstm.lstm_dim
        comp_type = model_config.compositional_type
        # Choose the composition function for reduced subtrees.
        if comp_type == RNNGParser.Config.CompositionalType.SUM:
            p_compositional = CompositionalSummationNN(lstm_dim=lstm_dim)
        elif comp_type == RNNGParser.Config.CompositionalType.BLSTM:
            p_compositional = CompositionalNN(lstm_dim=lstm_dim)
        else:
            raise ValueError(
                "Cannot understand compositional flag {}".format(comp_type)
            )

        if tensorizers is not None:
            # Tensorizers carry the vocab/index information directly.
            embedding = EmbeddingList(
                [
                    create_module(
                        model_config.embedding, tensorizer=tensorizers["tokens"]
                    )
                ],
                concat=True,
            )
            action_source = tensorizers["actions"]
            actions_vocab = action_source.vocab
        else:
            # Legacy path: everything comes from the metadata object.
            embedding = Model.create_embedding(feature_config, metadata=metadata)
            action_source = metadata
            actions_vocab = metadata.actions_vocab

        return cls(
            ablation=model_config.ablation,
            constraints=model_config.constraints,
            lstm_num_layers=model_config.lstm.num_layers,
            lstm_dim=lstm_dim,
            max_open_NT=model_config.max_open_NT,
            dropout=model_config.dropout,
            actions_vocab=actions_vocab,
            shift_idx=action_source.shift_idx,
            reduce_idx=action_source.reduce_idx,
            ignore_subNTs_roots=action_source.ignore_subNTs_roots,
            valid_NT_idxs=action_source.valid_NT_idxs,
            valid_IN_idxs=action_source.valid_IN_idxs,
            valid_SL_idxs=action_source.valid_SL_idxs,
            embedding=embedding,
            p_compositional=p_compositional,
        )
Example #7
0
    def from_config(cls, model_config, feature_config,
                    metadata: CommonMetadata):
        """Build the traced parser from config, feature config and metadata."""
        # Pin the model to the current CUDA device when CUDA is enabled.
        device = (
            "cuda:{}".format(torch.cuda.current_device())
            if cuda.CUDA_ENABLED
            else "cpu"
        )
        lstm_dim = model_config.lstm.lstm_dim
        comp_type = model_config.compositional_type
        # Choose the composition function for reduced subtrees.
        if comp_type == RNNGParser.Config.CompositionalType.SUM:
            p_compositional = CompositionalSummationNN(lstm_dim=lstm_dim)
        elif comp_type == RNNGParser.Config.CompositionalType.BLSTM:
            p_compositional = CompositionalNN(lstm_dim=lstm_dim, device=device)
        else:
            raise ValueError(
                "Cannot understand compositional flag {}".format(comp_type)
            )
        emb_module = Model.create_embedding(feature_config, metadata=metadata)
        contextual_emb_dim = feature_config.contextual_token_embedding.embed_dim

        return cls(
            cls.get_input_for_trace(contextual_emb_dim),
            embedding=emb_module,
            ablation=model_config.ablation,
            constraints=model_config.constraints,
            lstm_num_layers=model_config.lstm.num_layers,
            lstm_dim=lstm_dim,
            max_open_NT=model_config.max_open_NT,
            dropout=model_config.dropout,
            num_actions=len(metadata.actions_vocab),
            shift_idx=metadata.shift_idx,
            reduce_idx=metadata.reduce_idx,
            ignore_subNTs_roots=metadata.ignore_subNTs_roots,
            valid_NT_idxs=metadata.valid_NT_idxs,
            valid_IN_idxs=metadata.valid_IN_idxs,
            valid_SL_idxs=metadata.valid_SL_idxs,
            embedding_dim=emb_module.embedding_dim,
            p_compositional=p_compositional,
            device=device,
        )
Example #8
0
def save(
    config: PyTextConfig,
    model: Model,
    meta: Optional[CommonMetadata],
    tensorizers: Dict[str, Tensorizer],
    training_state: Optional[TrainingState] = None,
    identifier: Optional[str] = None,
) -> str:
    """
    Save all stateful information of a training task: the original config,
    model state, metadata, and the training state if training is not complete.

    Args:
        config (PyTextConfig): all raw parameters/hyper-parameters for the
            training task
        model (Model): the actual model in training
        meta: training metadata
        tensorizers: tensorizers used to numberize the data
        training_state (TrainingState): stateful information during training
        identifier (str): identifies a checkpoint within a training job; used
            as a suffix for the save path

    Returns:
        str: the path written to. Without an identifier this is
        config.save_snapshot_path (consistent with the post-training
        snapshot); with an identifier it is a generated checkpoint path.
    """
    if identifier:
        # A during-training checkpoint: derive a path from the identifier.
        saved_path = generate_checkpoint_path(config, identifier)
    else:
        # Post-training snapshot when no identifier is given.
        saved_path = config.save_snapshot_path
        print(f"Saving pytorch model to: {saved_path}")

    saved_folder = os.path.dirname(saved_path)
    if not PathManager.exists(saved_folder):
        PathManager.mkdirs(saved_folder)
        print(f"created {saved_folder}")

    # torch.save() can fail to pickle certain models unless only state_dict()
    # is saved, so detach the live model from training_state for the duration
    # of the save and restore it afterwards.
    # https://github.com/pytorch/pytorch/issues/15116
    detached_model = None
    if training_state:
        detached_model, training_state.model = training_state.model, None
    try:
        state = {
            DATA_STATE: meta,
            CONFIG_JSON: config_to_json(PyTextConfig, config),
            MODEL_STATE: model.state_dict(),
            SERIALIZE_VERSION_KEY: LATEST_SERIALIZE_VERSION,
            TENSORIZERS: tensorizers,
            TRAINING_STATE: training_state,
        }
        if identifier is None:
            _CHECKPOINT_MANAGER.save_snapshot(state, saved_path)
        else:
            _CHECKPOINT_MANAGER.save_checkpoint(state, saved_path)
    finally:
        # Restore the model even when saving raises.
        if training_state:
            training_state.model = detached_model
    return saved_path