Esempio n. 1
0
def configure_ermlp_training_pipeline(model_name: str):
    """Collect the ERMLP training hyper-parameters one step at a time.

    The configuration is seeded with ``get_config_dict(model_name)`` and then
    filled with the values returned by ``select_integer_value`` and
    ``select_float_value`` (presumably interactive prompts -- confirm against
    their definitions). A section divider is printed after every step.

    :param model_name: name of the model
    :return: configuration dictionary holding the embedding dimension,
        margin loss, learning rate, batch size and number of epochs
    """
    config = get_config_dict(model_name)

    # Embedding dimension (integer); two informational messages come first.
    print_training_embedding_dimension_message()
    print_embedding_dimension_info_message()
    config[EMBEDDING_DIM] = select_integer_value(
        print_msg=EMBEDDING_DIMENSION_PRINT_MSG,
        prompt_msg=EMBEDDING_DIMENSION_PROMPT_MSG,
        error_msg=EMBEDDING_DIMENSION_ERROR_MSG,
    )
    print_section_divider()

    # Margin loss (float).
    print_training_margin_loss_message()
    config[MARGIN_LOSS] = select_float_value(
        print_msg=MARGIN_LOSS_PRINT_MSG,
        prompt_msg=MARGIN_LOSS_PROMPT_MSG,
        error_msg=MARGIN_LOSS_ERROR_MSG,
    )
    print_section_divider()

    # Learning rate (float).
    print_learning_rate_message()
    config[LEARNING_RATE] = select_float_value(
        print_msg=LEARNING_RATE_PRINT_MSG,
        prompt_msg=LEARNING_RATE_PROMPT_MSG,
        error_msg=LEARNING_RATE_ERROR_MSG,
    )
    print_section_divider()

    # Batch size (integer).
    print_batch_size_message()
    config[BATCH_SIZE] = select_integer_value(
        print_msg=BATCH_SIZE_PRINT_MSG,
        prompt_msg=BATCH_SIZE_PROMPT_MSG,
        error_msg=BATCH_SIZE_ERROR_MSG,
    )
    print_section_divider()

    # Number of epochs (integer).
    print_number_epochs_message()
    config[NUM_EPOCHS] = select_integer_value(
        print_msg=EPOCH_PRINT_MSG,
        prompt_msg=EPOCH_PROMPT_MSG,
        error_msg=EPOCH_ERROR_MSG,
    )
    print_section_divider()

    return config
Esempio n. 2
0
def configure_um_training_pipeline(model_name: str):
    """Collect the training hyper-parameters for the UM model step by step.

    An ``OrderedDict`` is created with the model name stored under
    ``KG_EMBEDDING_MODEL_NAME``; every following step stores the value
    returned by ``select_integer_value`` / ``select_float_value`` (presumably
    interactive prompts -- confirm against their definitions) and then prints
    a section divider.

    :param model_name: name of the model
    :return: ordered configuration dictionary holding the model name,
        embedding dimension, margin loss, scoring-function norm, entity
        normalization norm, learning rate, batch size and number of epochs
    """
    config = OrderedDict([(KG_EMBEDDING_MODEL_NAME, model_name)])

    # Embedding dimension (integer); two informational messages come first.
    print_training_embedding_dimension_message()
    print_embedding_dimension_info_message()
    config[EMBEDDING_DIM] = select_integer_value(
        print_msg=EMBEDDING_DIMENSION_PRINT_MSG,
        prompt_msg=EMBEDDING_DIMENSION_PROMPT_MSG,
        error_msg=EMBEDDING_DIMENSION_ERROR_MSG,
    )
    print_section_divider()

    # Margin loss (float).
    print_training_margin_loss_message()
    config[MARGIN_LOSS] = select_float_value(
        print_msg=MARGIN_LOSS_PRINT_MSG,
        prompt_msg=MARGIN_LOSS_PROMPT_MSG,
        error_msg=MARGIN_LOSS_ERROR_MSG,
    )
    print_section_divider()

    # L_p norm used by the scoring function (integer).
    print_scoring_fct_message()
    config[SCORING_FUNCTION_NORM] = select_integer_value(
        print_msg=NORM_SCORING_FUNCTION_PRINT_MSG,
        prompt_msg=NORM_SCORING_FUNCTION_PROMPT_MSG,
        error_msg=NORM_SCORING_FUNCTION_ERROR_MSG,
    )
    print_section_divider()

    # L_p norm used for normalizing the entities (integer).
    print_entity_normalization_message()
    config[NORM_FOR_NORMALIZATION_OF_ENTITIES] = select_integer_value(
        print_msg=ENTITIES_NORMALIZATION_PRINT_MSG,
        prompt_msg=ENTITIES_NORMALIZATION_PROMPT_MSG,
        error_msg=ENTITIES_NORMALIZATION_ERROR_MSG,
    )
    print_section_divider()

    # Learning rate (float).
    print_learning_rate_message()
    config[LEARNING_RATE] = select_float_value(
        print_msg=LEARNING_RATE_PRINT_MSG,
        prompt_msg=LEARNING_RATE_PROMPT_MSG,
        error_msg=LEARNING_RATE_ERROR_MSG,
    )
    print_section_divider()

    # Batch size (integer).
    print_batch_size_message()
    config[BATCH_SIZE] = select_integer_value(
        print_msg=BATCH_SIZE_PRINT_MSG,
        prompt_msg=BATCH_SIZE_PROMPT_MSG,
        error_msg=BATCH_SIZE_ERROR_MSG,
    )
    print_section_divider()

    # Number of epochs (integer).
    print_number_epochs_message()
    config[NUM_EPOCHS] = select_integer_value(
        print_msg=EPOCH_PRINT_MSG,
        prompt_msg=EPOCH_PROMPT_MSG,
        error_msg=EPOCH_ERROR_MSG,
    )
    print_section_divider()

    return config
Esempio n. 3
0
def configure_trans_d_training_pipeline(model_name: str):
    """Collect the Trans D training hyper-parameters one step at a time.

    The configuration is seeded with ``get_config_dict(model_name)``. Each
    step stores the value returned by ``select_integer_value`` /
    ``select_float_value`` (presumably interactive prompts -- confirm against
    their definitions) and then prints a section divider. Entities and
    relations get separate embedding dimensions.

    :param model_name: name of the model
    :return: configuration dictionary holding the entity and relation
        embedding dimensions, margin loss, scoring-function norm, learning
        rate, batch size and number of epochs
    """
    config = get_config_dict(model_name)

    # Embedding dimension of the entities (integer).
    print_entities_embedding_dimension_message()
    config[EMBEDDING_DIM] = select_integer_value(
        print_msg=ENTITIES_EMBEDDING_DIMENSION_PRINT_MSG,
        prompt_msg=ENTITIES_EMBEDDING_DIMENSION_PROMPT_MSG,
        error_msg=ENTITIES_EMBEDDING_DIMENSION_ERROR_MSG,
    )
    print_section_divider()

    # Embedding dimension of the relations (integer).
    print_relations_embedding_dimension_message()
    config[RELATION_EMBEDDING_DIM] = select_integer_value(
        print_msg=RELATION_EMBEDDING_DIMENSION_PRINT_MSG,
        prompt_msg=RELATION_EMBEDDING_DIMENSION_PROMPT_MSG,
        error_msg=RELATION_EMBEDDING_DIMENSION_ERROR_MSG,
    )
    print_section_divider()

    # Margin loss (float).
    print_training_margin_loss_message()
    config[MARGIN_LOSS] = select_float_value(
        print_msg=MARGIN_LOSS_PRINT_MSG,
        prompt_msg=MARGIN_LOSS_PROMPT_MSG,
        error_msg=MARGIN_LOSS_ERROR_MSG,
    )
    print_section_divider()

    # L_p norm used by the scoring function (integer).
    print_scoring_fct_message()
    config[SCORING_FUNCTION_NORM] = select_integer_value(
        print_msg=NORM_SCORING_FUNCTION_PRINT_MSG,
        prompt_msg=NORM_SCORING_FUNCTION_PROMPT_MSG,
        error_msg=NORM_SCORING_FUNCTION_ERROR_MSG,
    )
    print_section_divider()

    # Learning rate (float).
    print_learning_rate_message()
    config[LEARNING_RATE] = select_float_value(
        print_msg=LEARNING_RATE_PRINT_MSG,
        prompt_msg=LEARNING_RATE_PROMPT_MSG,
        error_msg=LEARNING_RATE_ERROR_MSG,
    )
    print_section_divider()

    # Batch size (integer).
    print_batch_size_message()
    config[BATCH_SIZE] = select_integer_value(
        print_msg=BATCH_SIZE_PRINT_MSG,
        prompt_msg=BATCH_SIZE_PROMPT_MSG,
        error_msg=BATCH_SIZE_ERROR_MSG,
    )
    print_section_divider()

    # Number of epochs (integer).
    print_number_epochs_message()
    config[NUM_EPOCHS] = select_integer_value(
        print_msg=EPOCH_PRINT_MSG,
        prompt_msg=EPOCH_PROMPT_MSG,
        error_msg=EPOCH_ERROR_MSG,
    )
    print_section_divider()

    return config
Esempio n. 4
0
def configure_trans_h_training_pipeline(model_name: str) -> Dict:
    """Collect the Trans H training hyper-parameters one step at a time.

    The configuration is seeded with ``get_config_dict(model_name)``. Each
    step stores the value returned by ``select_integer_value`` /
    ``select_float_value`` (presumably interactive prompts -- confirm against
    their definitions) and then prints a section divider.

    :param model_name: name of the model
    :return: configuration dictionary holding the embedding dimension,
        margin loss, scoring-function norm, soft-constraint weight, learning
        rate, batch size and number of epochs
    """
    config = get_config_dict(model_name)

    # Embedding dimension (integer); two informational messages come first.
    print_training_embedding_dimension_message()
    print_embedding_dimension_info_message()
    config[EMBEDDING_DIM] = select_integer_value(
        print_msg=EMBEDDING_DIMENSION_PRINT_MSG,
        prompt_msg=EMBEDDING_DIMENSION_PROMPT_MSG,
        error_msg=EMBEDDING_DIMENSION_ERROR_MSG)
    print_section_divider()

    # Margin loss (float).
    print_training_margin_loss_message()
    config[MARGIN_LOSS] = select_float_value(
        print_msg=MARGIN_LOSS_PRINT_MSG,
        prompt_msg=MARGIN_LOSS_PROMPT_MSG,
        error_msg=MARGIN_LOSS_ERROR_MSG)
    print_section_divider()

    # L_p norm used by the scoring function (integer).
    print_scoring_fct_message()
    config[SCORING_FUNCTION_NORM] = select_integer_value(
        print_msg=NORM_SCORING_FUNCTION_PRINT_MSG,
        prompt_msg=NORM_SCORING_FUNCTION_PROMPT_MSG,
        error_msg=NORM_SCORING_FUNCTION_ERROR_MSG)
    print_section_divider()

    # Weight of the Trans H soft constraints (float).
    print_trans_h_soft_constraints_weight_message()
    config[WEIGHT_SOFT_CONSTRAINT_TRANS_H] = select_float_value(
        print_msg=WEIGHTS_SOFT_CONSTRAINT_TRANS_H_PRINT_MSG,
        prompt_msg=WEIGHTS_SOFT_CONSTRAINT_TRANS_H_PROMPT_MSG,
        error_msg=WEIGHTS_SOFT_CONSTRAINT_TRANS_H_ERROR_MSG)
    print_section_divider()

    # Learning rate (float).
    print_learning_rate_message()
    config[LEARNING_RATE] = select_float_value(
        print_msg=LEARNING_RATE_PRINT_MSG,
        prompt_msg=LEARNING_RATE_PROMPT_MSG,
        error_msg=LEARNING_RATE_ERROR_MSG)
    print_section_divider()

    # Batch size (integer).
    print_batch_size_message()
    config[BATCH_SIZE] = select_integer_value(
        print_msg=BATCH_SIZE_PRINT_MSG,
        prompt_msg=BATCH_SIZE_PROMPT_MSG,
        error_msg=BATCH_SIZE_ERROR_MSG)
    print_section_divider()

    # Number of epochs (integer).
    print_number_epochs_message()
    config[NUM_EPOCHS] = select_integer_value(
        print_msg=EPOCH_PRINT_MSG,
        prompt_msg=EPOCH_PROMPT_MSG,
        error_msg=EPOCH_ERROR_MSG)
    print_section_divider()

    return config