示例#1
0
    def from_config(cls, model_config, feature_config, metadata: CommonMetadata):
        """Construct a ``cls`` instance from model/feature configs and metadata.

        Args:
            model_config: Model configuration. Supplies ``compositional_type``,
                ``lstm`` (``lstm_dim``, ``num_layers``), ``ablation``,
                ``constraints``, ``max_open_NT`` and ``dropout``.
            feature_config: Feature configuration handed to
                ``Model.create_embedding`` to build the token embedding.
            metadata: Data metadata providing ``actions_vocab`` and the
                action index fields (``shift_idx``, ``reduce_idx``,
                ``ignore_subNTs_roots``, ``valid_NT_idxs``, ``valid_IN_idxs``,
                ``valid_SL_idxs``).

        Returns:
            A new instance of ``cls``.

        Raises:
            ValueError: If ``model_config.compositional_type`` is neither
                ``SUM`` nor ``BLSTM``.
        """
        comp_kind = model_config.compositional_type
        kinds = RNNGParser.Config.CompositionalType

        # Pick the composition module class first; only read the LSTM size
        # once the flag is known to be valid.
        if comp_kind == kinds.SUM:
            composer_cls = CompositionalSummationNN
        elif comp_kind == kinds.BLSTM:
            composer_cls = CompositionalNN
        else:
            raise ValueError(
                "Cannot understand compositional flag {}".format(comp_kind)
            )
        composer = composer_cls(lstm_dim=model_config.lstm.lstm_dim)

        token_embedding = Model.create_embedding(feature_config, metadata=metadata)
        lstm_cfg = model_config.lstm

        return cls(
            ablation=model_config.ablation,
            constraints=model_config.constraints,
            lstm_num_layers=lstm_cfg.num_layers,
            lstm_dim=lstm_cfg.lstm_dim,
            max_open_NT=model_config.max_open_NT,
            dropout=model_config.dropout,
            actions_vocab=metadata.actions_vocab,
            shift_idx=metadata.shift_idx,
            reduce_idx=metadata.reduce_idx,
            ignore_subNTs_roots=metadata.ignore_subNTs_roots,
            valid_NT_idxs=metadata.valid_NT_idxs,
            valid_IN_idxs=metadata.valid_IN_idxs,
            valid_SL_idxs=metadata.valid_SL_idxs,
            embedding=token_embedding,
            p_compositional=composer,
        )
示例#2
0
    def from_config(
        cls,
        model_config,
        feature_config=None,
        metadata: CommonMetadata = None,
        tensorizers: Dict[str, Tensorizer] = None,
    ):
        """Construct a ``cls`` instance from configuration.

        Two input paths are supported:

        * **Tensorizer path** (``tensorizers`` is not ``None``): the token
          embedding is built with ``create_module(model_config.embedding, ...)``
          from ``tensorizers["tokens"]``. The action vocabulary and action
          index fields are read from ``tensorizers["actions"]``.
        * **Legacy path** (``tensorizers`` is ``None``): the embedding is
          built with ``Model.create_embedding(feature_config, metadata=...)``.
          The action vocabulary and index fields are read from ``metadata``.

        Args:
            model_config: Model configuration. Supplies ``compositional_type``,
                ``lstm`` (``lstm_dim``, ``num_layers``), ``ablation``,
                ``constraints``, ``max_open_NT``, ``dropout`` and, on the
                tensorizer path, ``embedding``.
            feature_config: Feature configuration; used only on the legacy path.
            metadata: Data metadata; used only on the legacy path.
            tensorizers: Mapping expected to contain ``"tokens"`` and
                ``"actions"`` entries. Takes precedence over
                ``feature_config``/``metadata`` when given.

        Returns:
            A new instance of ``cls``.

        Raises:
            ValueError: If ``model_config.compositional_type`` is neither
                ``SUM`` nor ``BLSTM``, or if neither ``tensorizers`` nor
                ``metadata`` is provided.
        """
        if model_config.compositional_type == RNNGParser.Config.CompositionalType.SUM:
            p_compositional = CompositionalSummationNN(
                lstm_dim=model_config.lstm.lstm_dim)
        elif (model_config.compositional_type ==
              RNNGParser.Config.CompositionalType.BLSTM):
            p_compositional = CompositionalNN(
                lstm_dim=model_config.lstm.lstm_dim)
        else:
            raise ValueError("Cannot understand compositional flag {}".format(
                model_config.compositional_type))

        if tensorizers is not None:
            embedding = EmbeddingList(
                [
                    create_module(model_config.embedding,
                                  tensorizer=tensorizers["tokens"])
                ],
                concat=True,
            )
            # The actions tensorizer carries both the vocab and the
            # shift/reduce/valid-index attributes read below.
            actions_params = tensorizers["actions"]
            actions_vocab = actions_params.vocab
        else:
            # Every data source defaults to None; fail with a clear message
            # instead of an AttributeError on None further down.
            if metadata is None:
                raise ValueError(
                    "from_config requires either `tensorizers` or `metadata` "
                    "to build the embedding and the actions vocabulary")
            embedding = Model.create_embedding(feature_config,
                                               metadata=metadata)
            actions_params = metadata
            actions_vocab = metadata.actions_vocab

        return cls(
            ablation=model_config.ablation,
            constraints=model_config.constraints,
            lstm_num_layers=model_config.lstm.num_layers,
            lstm_dim=model_config.lstm.lstm_dim,
            max_open_NT=model_config.max_open_NT,
            dropout=model_config.dropout,
            actions_vocab=actions_vocab,
            shift_idx=actions_params.shift_idx,
            reduce_idx=actions_params.reduce_idx,
            ignore_subNTs_roots=actions_params.ignore_subNTs_roots,
            valid_NT_idxs=actions_params.valid_NT_idxs,
            valid_IN_idxs=actions_params.valid_IN_idxs,
            valid_SL_idxs=actions_params.valid_SL_idxs,
            embedding=embedding,
            p_compositional=p_compositional,
        )
示例#3
0
    def from_config(cls, model_config, feature_config,
                    metadata: CommonMetadata):
        """Construct a ``cls`` instance (traced variant) from configs and metadata.

        The target device string is ``"cuda:<current device index>"`` when
        ``cuda.CUDA_ENABLED`` is true, otherwise ``"cpu"``. It is forwarded to
        the BLSTM composition module and to ``cls``.

        Args:
            model_config: Model configuration. Supplies ``compositional_type``,
                ``lstm`` (``lstm_dim``, ``num_layers``), ``ablation``,
                ``constraints``, ``max_open_NT`` and ``dropout``.
            feature_config: Feature configuration. Used to build the embedding
                and to read ``contextual_token_embedding.embed_dim``.
            metadata: Data metadata providing ``actions_vocab`` (only its
                length is passed on) and the action index fields.

        Returns:
            A new instance of ``cls``. Its first positional argument comes
            from ``cls.get_input_for_trace`` (presumably example inputs for
            tracing; confirm against that method).

        Raises:
            ValueError: If ``model_config.compositional_type`` is neither
                ``SUM`` nor ``BLSTM``.
        """
        if cuda.CUDA_ENABLED:
            target_device = "cuda:{}".format(torch.cuda.current_device())
        else:
            target_device = "cpu"

        comp_kind = model_config.compositional_type
        kinds = RNNGParser.Config.CompositionalType
        if comp_kind == kinds.SUM:
            # NOTE(review): the summation module is built without `device`,
            # unlike the BLSTM one; confirm whether that is intended.
            composer = CompositionalSummationNN(
                lstm_dim=model_config.lstm.lstm_dim)
        elif comp_kind == kinds.BLSTM:
            composer = CompositionalNN(
                lstm_dim=model_config.lstm.lstm_dim, device=target_device)
        else:
            raise ValueError("Cannot understand compositional flag {}".format(
                comp_kind))

        embedding_module = Model.create_embedding(feature_config,
                                                  metadata=metadata)
        trace_inputs = cls.get_input_for_trace(
            feature_config.contextual_token_embedding.embed_dim)
        lstm_cfg = model_config.lstm

        return cls(
            trace_inputs,
            embedding=embedding_module,
            ablation=model_config.ablation,
            constraints=model_config.constraints,
            lstm_num_layers=lstm_cfg.num_layers,
            lstm_dim=lstm_cfg.lstm_dim,
            max_open_NT=model_config.max_open_NT,
            dropout=model_config.dropout,
            num_actions=len(metadata.actions_vocab),
            shift_idx=metadata.shift_idx,
            reduce_idx=metadata.reduce_idx,
            ignore_subNTs_roots=metadata.ignore_subNTs_roots,
            valid_NT_idxs=metadata.valid_NT_idxs,
            valid_IN_idxs=metadata.valid_IN_idxs,
            valid_SL_idxs=metadata.valid_SL_idxs,
            embedding_dim=embedding_module.embedding_dim,
            p_compositional=composer,
            device=target_device,
        )