Example #1

import configparser

import torch

# `Config` and `logger` are module-level objects of the OWE project and are
# assumed to be imported and configured elsewhere in this file.
def read_config(config_file):
    """
    Reads the config file.
    :param config_file: Path to config file
    :return: train_file, valid_file, test_file, skip_header, split_symbol,
             entity2wiki_file
    """
    config = configparser.ConfigParser()
    config.read(str(config_file))
    logger.info("Reading config from: %s" % config_file)

    # GPUs
    Config.set("DisableCuda",
               config.getboolean("GPU", "DisableCuda", fallback=False))
    if not Config.get("DisableCuda") and torch.cuda.is_available():
        Config.set("device", 'cuda')
    else:
        Config.set("device", 'cpu')
    Config.set("GPUs", [
        int(gpu) for gpu in config.get("GPU", "GPUs", fallback='').split(',')
    ])

    # Training
    # Config.set("LinkPredictionModelType", config.get("Training", "LinkPredictionModelType", fallback=None))
    Config.set("Epochs", config.getint("Training", "Epochs", fallback=1000))
    Config.set("BatchSize",
               config.getint("Training", "BatchSize", fallback=None))
    Config.set(
        "EmbeddingDimensionality",
        config.getint("Training", "EmbeddingDimensionality", fallback=300))
    Config.set("LearningRate",
               config.getfloat("Training", "LearningRate", fallback=0.1))
    Config.set(
        "LearningRateSchedule",
        config.get("Training", "LearningRateSchedule",
                   fallback="50,100,200").split(','))
    Config.set(
        "LearningRateGammas",
        config.get("Training", "LearningRateGammas",
                   fallback="0.1,0.1,0.1").split(','))
    Config.set(
        "InitializeEmbeddingWithAllEntities",
        config.getboolean("Training",
                          "InitializeEmbeddingWithAllEntities",
                          fallback=False))
    Config.set(
        "InitializeWithPretrainedKGCEmbedding",
        config.getboolean("Training",
                          "InitializeWithPretrainedKGCEmbedding",
                          fallback=False))

    Config.set("TransformationType",
               config.get("Training", "TransformationType", fallback="Linear"))
    Config.set("EncoderType",
               config.get("Training", "EncoderType", fallback="Average"))
    Config.set(
        "UseTailsToOptimize",
        config.getboolean("Training", "UseTailsToOptimize", fallback=False))
    Config.set("Loss", config.get("Training", "Loss", fallback="Pairwise"))
    Config.set("UNKType", config.get("Training", "UNKType",
                                     fallback="Average"))
    Config.set("AverageWordDropout",
               config.getfloat("Training", "AverageWordDropout", fallback=0.))
    Config.set("IterTriplets",
               config.getboolean("Training", "IterTriplets", fallback=True))

    # FCN
    Config.set("FCNUseSigmoid",
               config.getboolean("FCN", "FCNUseSigmoid", fallback=False))
    Config.set("FCNLayers", config.getint("FCN", "FCNLayers", fallback=0))
    Config.set("FCNDropout", config.getfloat("FCN", "FCNDropout", fallback=0))
    Config.set("FCNHiddenDim",
               config.getint("FCN", "FCNHiddenDim", fallback=None))

    # LSTM
    Config.set("LSTMOutputDim",
               config.getint("LSTM", "LSTMOutputDim", fallback=None))
    Config.set("LSTMBidirectional",
               config.getboolean("LSTM", "LSTMBidirectional", fallback=False))

    # Evaluation
    Config.set("ValidateEvery",
               config.getint("Evaluation", "ValidateEvery", fallback=1000))
    Config.set(
        "UseTargetFilteringShi",
        config.getboolean("Evaluation",
                          "UseTargetFilteringShi",
                          fallback=False))
    Config.set("PrintTrainNN",
               config.getboolean("Evaluation", "PrintTrainNN", fallback=False))
    Config.set("PrintTestNN",
               config.getboolean("Evaluation", "PrintTestNN", fallback=False))
    Config.set(
        "EvalRandomHeads",
        config.getboolean("Evaluation", "EvalRandomHeads", fallback=False))
    Config.set(
        "CalculateNNMeanRank",
        config.getboolean("Evaluation", "CalculateNNMeanRank", fallback=False))
    Config.set(
        "ShiTargetFilteringBaseline",
        config.getboolean("Evaluation",
                          "ShiTargetFilteringBaseline",
                          fallback=False))
    Config.set(
        "GetTensorboardEmbeddings",
        config.getboolean("Evaluation",
                          "GetTensorboardEmbeddings",
                          fallback=True))

    if len(Config.get("LearningRateSchedule")) != len(
            Config.get("LearningRateGammas")):
        raise ValueError(
            "LearningRateSchedule and LearningRateGammas must have the same "
            "length.")

    # early stopping
    Config.set(
        "EarlyStopping",
        config.getboolean("EarlyStopping", "EarlyStopping", fallback=False))
    Config.set(
        "EarlyStoppingThreshold",
        config.getfloat("EarlyStopping",
                        "EarlyStoppingThreshold",
                        fallback=0.1))
    Config.set(
        "EarlyStoppingLastX",
        config.getint("EarlyStopping", "EarlyStoppingLastX", fallback=10))
    Config.set(
        "EarlyStoppingMinEpochs",
        config.getint("EarlyStopping", "EarlyStoppingMinEpochs", fallback=10))

    # Entity2text
    Config.set(
        "PretrainedEmbeddingFile",
        config.get("Entity2Text", "PretrainedEmbeddingFile", fallback=None))
    Config.set(
        "ConvertEntities",
        config.getboolean("Entity2Text", "ConvertEntities", fallback=False))
    Config.set(
        "ConvertEntitiesWithMultiprocessing",
        config.getboolean("Entity2Text",
                          "ConvertEntitiesWithMultiprocessing",
                          fallback=True))
    Config.set(
        "MatchTokenInEmbedding",
        config.getboolean("Entity2Text",
                          "MatchTokenInEmbedding",
                          fallback=False))
    Config.set(
        "MatchLabelInEmbedding",
        config.getboolean("Entity2Text",
                          "MatchLabelInEmbedding",
                          fallback=False))
    Config.set(
        "LimitDescription",
        config.getint("Entity2Text", "LimitDescription", fallback=100000))

    # logger.info("LinkPredictionModelType: %s " % Config.get("LinkPredictionModelType"))
    # if Config.get("LinkPredictionModelType") not in ["ComplEx", "TransE", "TransR", "DistMult"]:
    #     raise ValueError("LinkPredictionModelType not recognized")

    # Dataset
    train_file = config.get("Dataset", "TrainFile")
    valid_file = config.get("Dataset", "ValidationFile")
    test_file = config.get("Dataset", "TestFile")
    entity2wiki_file = config.get("Dataset",
                                  "Entity2wikidata",
                                  fallback="entity2wikidata.json")
    logger.info("Using {} as wikidata file".format(entity2wiki_file))
    skip_header = config.getboolean("Dataset", "SkipHeader", fallback=False)
    split_symbol = config.get("Dataset", "SplitSymbol", fallback='TAB')
    if split_symbol not in ["TAB", "SPACE"]:
        raise ValueError("SplitSymbol must be either TAB or SPACE.")
    split_symbol = '\t' if split_symbol == 'TAB' else ' '
    return train_file, valid_file, test_file, skip_header, split_symbol, entity2wiki_file
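
A minimal usage sketch of a config file that read_config would accept. The
section and option names mirror the config.get(...) calls above; the file name
"owe.ini" and all values are hypothetical.

from pathlib import Path

Path("owe.ini").write_text("""\
[GPU]
DisableCuda = false
GPUs = 0

[Training]
Epochs = 500

[Dataset]
TrainFile = train.txt
ValidationFile = valid.txt
TestFile = test.txt
SplitSymbol = TAB
""")

train, valid, test, skip_header, sep, e2w = read_config("owe.ini")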
Example #2
File: data.py Project: ren-1247/OWE

# Module-level imports this excerpt relies on:
import re

from nltk.tokenize import word_tokenize
    def get_tokens_in_emb(self, embedding):
        def tokenize_old(content,
                         lower=True,
                         remove_punctuation=True,
                         add_underscores=False,
                         limit_len=100000):
            """
            Splits on spaces between tokens.

            :param content: The string that shall be tokenized.
            :param lower: Lowers content string
            :param remove_punctuation: Removes single punctuation tokens
            :param add_underscores: Replaces spaces with underscores
            :return:
            """
            if not content or not limit_len:
                return [""] if add_underscores else []

            if not isinstance(content, str):
                raise ValueError("Content must be a string.")

            if remove_punctuation:
                content = re.sub('[^A-Za-z0-9 ]+', '', content)

            if lower:
                content = content.lower()

            if add_underscores:
                res = [re.sub(' ', '_', content)]
                return res

            res = word_tokenize(content)

            res = res[:limit_len]
            return res

        label = tokenize_old(self.name)  # list
        label_uscored_uncased_punct = tokenize_old(
            self.name, remove_punctuation=False,
            add_underscores=True)[0]  # str
        label_uscored_cased_punct = tokenize_old(
            self.name,
            remove_punctuation=False,
            lower=False,
            add_underscores=True)[0]  # str
        # Non-"punct" variants strip punctuation (the tokenizer's default).
        label_uscored_uncased = tokenize_old(self.name,
                                             add_underscores=True)[0]  # str
        label_uscored_cased = tokenize_old(self.name,
                                           lower=False,
                                           add_underscores=True)[0]  # str

        description = tokenize_old(
            self.description, limit_len=Config.get("LimitDescription"))  # list
        # DEPRECATED: should be removed
        if "ENTITY/" + self.entity_id in embedding:
            name_token = ["ENTITY/" + self.entity_id]
        elif self.entity_id in embedding:
            name_token = [self.entity_id]
        elif "ENTITY/" + label_uscored_cased_punct in embedding:
            name_token = ["ENTITY/" + label_uscored_cased_punct]
        elif "ENTITY/" + label_uscored_uncased_punct in embedding:
            name_token = ["ENTITY/" + label_uscored_uncased_punct]
        elif "ENTITY/" + label_uscored_cased in embedding:
            name_token = ["ENTITY/" + label_uscored_cased]
        elif "ENTITY/" + label_uscored_uncased in embedding:
            name_token = ["ENTITY/" + label_uscored_uncased]
        elif label_uscored_cased_punct in embedding:
            name_token = [label_uscored_cased_punct]
        elif label_uscored_uncased_punct in embedding:
            name_token = [label_uscored_uncased_punct]
        elif label_uscored_cased in embedding:
            name_token = [label_uscored_cased]
        elif label_uscored_uncased in embedding:
            name_token = [label_uscored_uncased]
        else:
            name_token = [n for n in label if n in embedding]

        desc_tokens = [n for n in description if n in embedding]
        return name_token or ["_UNK_"], desc_tokens or ["_UNK_"]
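
A toy illustration of the fallback order above, with a hypothetical embedding
vocabulary: the exact entity id is tried first, then the underscored label
variants, and finally the individual label tokens.

embedding = {"ENTITY/Barack_Obama": [0.1], "barack": [0.2], "obama": [0.3]}

candidates = ["ENTITY/Q76", "Q76", "ENTITY/Barack_Obama", "Barack_Obama"]
name_token = next(([c] for c in candidates if c in embedding),
                  [t for t in ["barack", "obama"] if t in embedding])
print(name_token)  # ['ENTITY/Barack_Obama']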
Example #3
def basic_config():
    Config.set("device", "cuda")
    Config.set("ConvertEntities", False)
    Config.set("InitializeEmbeddingWithAllEntities", False)
Example #4
    def __init__(self, embedding_m: torch.Tensor, complex_d: int):
        super().__init__()
        encoder_out_d = embedding_m.size(1)

        if Config.get("EncoderType") == "Average":
            self.encoder = AvgEncoder(embedding_m)
            logger.info("Using averaging encoder")
        elif Config.get("EncoderType") == "CNN":
            self.encoder = CNNEncoder(embedding_m, output_d=encoder_out_d)
            logger.info("Using CNN encoder")
        elif Config.get("EncoderType") == "BiLSTM":
            if Config.get("LSTMOutputDim"):
                encoder_out_d = Config.get("LSTMOutputDim")
            self.encoder = LSTMEncoder(embedding_m, output_d=encoder_out_d)
            logger.info("Using BiLSTM encoder")
        else:
            raise ValueError("EncoderType invalid in config")

        if Config.get("LinkPredictionModelType") in ("ComplEx", "RotatE"):
            if Config.get("TransformationType") == "Linear":
                self.transformer_r = LinearTransform(encoder_out_d, complex_d)
                self.transformer_i = LinearTransform(encoder_out_d, complex_d)
                logger.info("Using Linear transformation")
            elif Config.get("TransformationType") == "Affine":
                self.transformer_r = LinearTransform(encoder_out_d,
                                                     complex_d,
                                                     bias=True)
                self.transformer_i = LinearTransform(encoder_out_d,
                                                     complex_d,
                                                     bias=True)
                logger.info("Using Affine transformation")
            elif Config.get("TransformationType") == "FCN":
                use_sigmoid = Config.get("FCNUseSigmoid")
                n_layers = Config.get("FCNLayers")

                self.transformer_r = FCNTransform(
                    encoder_out_d,
                    complex_d,
                    hidden_dim=Config.get("FCNHiddenDim"),
                    n_layers=n_layers,
                    use_sigmoid=use_sigmoid)
                self.transformer_i = FCNTransform(
                    encoder_out_d,
                    complex_d,
                    hidden_dim=Config.get("FCNHiddenDim"),
                    n_layers=n_layers,
                    use_sigmoid=use_sigmoid)
                logger.info("Using FCN transformation")
            else:
                raise ValueError("TransformationType invalid in config")
        elif Config.get("LinkPredictionModelType") in ("TransE", "TransR",
                                                       "DistMult"):
            if Config.get("TransformationType") == "Linear":
                self.transformer = LinearTransform(encoder_out_d, complex_d)
                logger.info("Using Linear transformation")
            elif Config.get("TransformationType") == "Affine":
                self.transformer = LinearTransform(encoder_out_d,
                                                   complex_d,
                                                   bias=True)
                logger.info("Using Affine transformation")
            elif Config.get("TransformationType") == "FCN":
                use_sigmoid = Config.get("FCNUseSigmoid")
                n_layers = Config.get("FCNLayers")

                self.transformer = FCNTransform(
                    encoder_out_d,
                    complex_d,
                    hidden_dim=Config.get("FCNHiddenDim"),
                    n_layers=n_layers,
                    use_sigmoid=use_sigmoid)
                logger.info("Using FCN transformation")
            else:
                raise ValueError("TransformationType invalid in config")
        else:
            raise ValueError("LinkPredictionModelType unknown")
Example #5
    def __init__(self, embedding_m: torch.Tensor):
        super().__init__(embedding_m)
        self.dropout = None
        if Config.get("AverageWordDropout"):
            # Dropout2d drops whole word vectors (entire rows of the
            # (batch, words, dim) input) rather than individual values.
            self.dropout = torch.nn.Dropout2d(
                p=Config.get("AverageWordDropout"))