Esempio n. 1
0
def Test(num_datapoints=None):
    """Evaluate the global model ``F`` on ``test_dataset()`` and report the error rate.

    For every datapoint, ``[conjecture, statement]`` is passed to ``F`` as a
    single pair; the arg-max over dim 1 of the output is the predicted label,
    which is compared with ``datapoint.label``.

    Args:
        num_datapoints: Stop after this many datapoints, with the same
            semantics as ``Validate`` (stop when the running count equals it).
            ``None`` -- or any value never reached, such as 0 -- evaluates the
            whole test set.

    Returns:
        The fraction of evaluated datapoints whose predicted label differs
        from ``datapoint.label``.

    Raises:
        ValueError: if ``test_dataset()`` yields no datapoints (previously this
            surfaced as a bare ZeroDivisionError).
    """
    # NOTE(review): Validate calls F([conjecture], [statement]) and first maps
    # out-of-vocabulary tokens to "UNKNOWN"; this function passes a list of
    # [conjecture, statement] pairs and does no token mapping. Confirm which
    # calling convention F currently expects.
    err_count = 0
    count = 0

    # Inference only: avoid building an autograd graph for every datapoint.
    with torch.no_grad():
        for datapoint in test_dataset():
            conjecture_statement = [datapoint.conjecture, datapoint.statement]

            prediction_val = F([conjecture_statement])
            _, prediction_label = torch.max(prediction_val, dim=1)

            # .numpy() requires a CPU tensor.
            if cuda_available:
                prediction_label = prediction_label.cpu()
            prediction_label = prediction_label.numpy()

            if datapoint.label != prediction_label[0]:
                err_count += 1

            count += 1

            if count % 100 == 0:
                print("Count: ", count)

            # Honour the requested cap (the parameter used to be ignored).
            if count == num_datapoints:
                break

    if count == 0:
        raise ValueError("test_dataset() yielded no datapoints")

    print("Fraction of Incorrect Test Points: ", err_count / count)

    return err_count / count
Esempio n. 2
0
def _replace_unknown_tokens(graph, known_tokens):
    """Overwrite, in place, each node token of ``graph`` not in ``known_tokens`` with "UNKNOWN"."""
    for node_obj in graph.nodes.values():
        if node_obj.token not in known_tokens:
            node_obj.token = "UNKNOWN"


def Validate(num_datapoints):
    """Evaluate the global model ``F`` on ``validation_dataset()`` and report the error rate.

    Before prediction, every node token of the conjecture and statement graphs
    that is missing from the vocabulary returned by
    ``get_token_dict_from_file()`` is replaced with "UNKNOWN". This mutates
    the datapoint objects yielded by ``validation_dataset()``.

    Args:
        num_datapoints: Stop once this many datapoints have been evaluated
            (the running count is compared with ``==``, so 0 or a value larger
            than the dataset evaluates everything).

    Returns:
        The fraction of evaluated datapoints whose predicted label (arg-max
        over dim 1 of ``F``'s output) differs from ``datapoint.label``.

    Raises:
        ValueError: if ``validation_dataset()`` yields no datapoints
            (previously this surfaced as a bare ZeroDivisionError).
    """
    tokens_to_index = get_token_dict_from_file()

    err_count = 0
    count = 0

    # Inference only: avoid building an autograd graph for every datapoint.
    with torch.no_grad():
        for datapoint in validation_dataset():
            conjecture = datapoint.conjecture
            statement = datapoint.statement

            # Map out-of-vocabulary tokens to the shared "UNKNOWN" token.
            _replace_unknown_tokens(conjecture, tokens_to_index)
            _replace_unknown_tokens(statement, tokens_to_index)

            prediction_val = F([conjecture], [statement])
            _, prediction_label = torch.max(prediction_val, dim=1)

            # .numpy() requires a CPU tensor.
            if cuda_available:
                prediction_label = prediction_label.cpu()
            prediction_label = prediction_label.numpy()

            if datapoint.label != prediction_label[0]:
                err_count += 1

            count += 1

            if count % 100 == 0:
                print("Count: ", count)

            if count == num_datapoints:
                break

    if count == 0:
        raise ValueError("validation_dataset() yielded no datapoints")

    print("Fraction of Incorrect Validations: ", err_count / count)

    return err_count / count
Esempio n. 3
0
    def __init__(self, num_steps, cuda_available=False) -> None:
        """Instantiate the FormulaNet sub-networks and store configuration.

        Args:
            num_steps: Stored as ``self.num_steps``; presumably the number of
                graph-update iterations performed elsewhere in the class --
                TODO confirm against ``forward``.
            cuda_available: Stored as ``self.cuda_available``; presumably
                consulted elsewhere to decide GPU/CPU tensor placement --
                verify against the rest of the class.
        """
        super(FormulaNet, self).__init__()
        # Initialize models
        # NOTE(review): assuming FormulaNet subclasses nn.Module, submodules are
        # registered in assignment order, which fixes the order of
        # self.parameters() (and therefore of any optimizer state indexed by
        # it). Keep this order stable.
        self.dense_map = LinearMap()  # maps one_hot -> 256 dimension vector
        # Sub-networks named after the FormulaNet update/aggregation functions
        # (FP, FI, FO, FL, FR, FH) -- presumably per the FormulaNet paper; confirm.
        self.FP = FPClass()
        self.FI = FIClass()
        self.FO = FOClass()
        self.FL = FLClass()
        self.FR = FRClass()
        self.FH = FHClass()
        self.Classifier = CondClassifier()
        # Normalizes scores along dim 1 of its input.
        self.Softmax = nn.Softmax(dim=1)

        self.max_pool_dense_graph = max_pool_dense_graph()

        self.num_steps = num_steps
        # Token vocabulary, loaded once at construction time.
        self.token_to_index = get_token_dict_from_file()

        self.cuda_available = cuda_available