    def fit(self, rows, head):
        """
        Build the tree.

        Rules of recursion: 1) Believe that it works. 2) Start by checking
        for the base case (no further information gain). 3) Prepare for
        giant stack traces.
        """

        # Try partitioning the dataset on each unique attribute,
        # calculate the information gain,
        # and return the question that produces the highest gain.
        gain, question = find_best_split(rows, head)

        # Base case: no further info gain
        # Since we can ask no further questions,
        # we'll return a leaf.
        if gain == 0:
            return Leaf(rows)

        # If we reach here, we have found a useful feature / value
        # to partition on.
        true_rows, false_rows = partition(rows, question)

        # Recursively build the true branch.
        true_branch = self.fit(true_rows, head)

        # Recursively build the false branch.
        false_branch = self.fit(false_rows, head)

        # Return a Decision node.
        # This records the best feature / value to ask at this point,
        # as well as the branches to follow depending on the answer.
        self.root = Decision_Node(question, true_branch, false_branch)
        return self.root
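
    # A hypothetical usage sketch. The class name DecisionTree and the toy
    # dataset/header below are assumptions for illustration only:
    #
    #   head = ["color", "diameter", "label"]
    #   rows = [["Green", 3, "Apple"],
    #           ["Yellow", 3, "Apple"],
    #           ["Red", 1, "Grape"],
    #           ["Red", 1, "Grape"],
    #           ["Yellow", 3, "Lemon"]]
    #   tree = DecisionTree()
    #   tree.fit(rows, head)
    #   tree.print_tree(rows, head)
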
    def print_tree(self, rows, head, spacing=""):
        """
        A tree printing function.

        PARAMETERS
        ==========

        rows: list
            A list of lists holding the rows of
            the dataset.

        head: list
            A list holding the headings of the
            columns of the dataset.

        spacing: str
            The indentation prefix used to print
            the tree in an organised manner; it
            grows with each level of recursion.

        RETURNS
        =======

        None

        """

        # Try partitioning the dataset on each unique attribute,
        # calculate the gini impurity,
        # and return the question that produces the least gini impurity.
        gain, question = find_best_split(rows, head)

        # Base case: we've reached a leaf
        if gain == 0:
            print(spacing + "Predict", class_counts(rows, len(rows[0])-1))
            return

        # If we reach here, we have found a useful feature / value
        # to partition on.
        true_rows, false_rows = partition(rows, question)

        # Print the question at this node
        print(spacing + str(question))

        # Call this function recursively on the true branch
        print(spacing + '--> True:')
        self.print_tree(true_rows, head, spacing + "  ")

        # Call this function recursively on the false branch
        print(spacing + '--> False:')
        self.print_tree(false_rows, head, spacing + "  ")
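
    # For the toy dataset sketched above, the output would look roughly
    # like this (the questions and counts depend on the data and the
    # split criterion, so this is illustrative only):
    #
    #   Is diameter >= 3?
    #   --> True:
    #     Is color == Yellow?
    #     --> True:
    #       Predict {'Apple': 1, 'Lemon': 1}
    #     --> False:
    #       Predict {'Apple': 1}
    #   --> False:
    #     Predict {'Grape': 2}
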
    def classify(self, rows, head, prediction_val):
        """
        A function to make predictions of
        the subsets of the dataset.

        PARAMETERS
        ==========

        rows: list
            A list of lists to store the subsets
            of the dataset.

        head: list
            A list to store the headings of the
            columns of the subset of the dataset.

        prediction_val: dictionary
            A dictionary to update and return the
            predictions of the subsets of the
            dataset.

        RETURNS
        =======

        prediction_val
            Dictionary to return the predictions
            corresponding to the subsets of the
            dataset.

        """

        N = len(rows[0])

        # Pick up to 5 distinct feature columns at random
        # (duplicates are skipped, so fewer than 5 may be
        # chosen). The label column at index N-1 is excluded
        # from the draw.
        indexcol = []
        for j in range(0, 5):
            r = np.random.randint(0, N - 1)
            if r not in indexcol:
                indexcol.append(r)

        # Build the sampled rows from the chosen columns.
        row = []
        for j in rows:
            L = []
            for k in indexcol:
                L.append(j[k])
            row.append(L)

        # Append the label (last column) to each sampled row.
        for j in range(0, len(row)):
            row[j].append(rows[j][N - 1])

        rows = row

        # Try partitioning the dataset on each unique attribute,
        # calculate the gini impurity,
        # and return the question that produces the least gini impurity.
        gain, question = find_best_split(rows, head)

        # Base case: we've reached a leaf
        if gain == 0:
            # Get the predictions of the current set of rows.
            p = class_counts(rows, len(rows[0])-1)
            for d in prediction_val:
                if d in p:
                    # Update the predictions to be returned.
                    prediction_val[d] += p[d]
            return prediction_val

        # If we reach here, we have found a useful feature / value
        # to partition on.
        true_rows, false_rows = partition(rows, question)

        # Recursively classify the true branch.
        self.classify(true_rows, head, prediction_val)

        # Recursively classify the false branch.
        self.classify(false_rows, head, prediction_val)

        # Return the dictionary of the predictions
        # at the end of the recursion.
        return prediction_val
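

# ---------------------------------------------------------------------------
# The class above relies on several helpers that are not shown in this
# snippet: Question, Leaf, Decision_Node, class_counts, partition and
# find_best_split (plus `import numpy as np` at module scope for the
# sampling in classify). The sketch below is one plausible implementation,
# inferred from how those names are called above; it is illustrative, not
# necessarily the original code.
# ---------------------------------------------------------------------------


class Question:
    """A column index and a value to compare rows against."""

    def __init__(self, column, value, head):
        self.column = column
        self.value = value
        self.head = head

    def match(self, row):
        # Numeric features use >=, categorical features use ==.
        val = row[self.column]
        if isinstance(val, (int, float)):
            return val >= self.value
        return val == self.value

    def __str__(self):
        op = ">=" if isinstance(self.value, (int, float)) else "=="
        return "Is %s %s %s?" % (self.head[self.column], op, str(self.value))


class Leaf:
    """Holds the class counts of the rows that reached this leaf."""

    def __init__(self, rows):
        self.predictions = class_counts(rows, len(rows[0]) - 1)


class Decision_Node:
    """Holds a question and the two branches it splits the data into."""

    def __init__(self, question, true_branch, false_branch):
        self.question = question
        self.true_branch = true_branch
        self.false_branch = false_branch


def class_counts(rows, label_col):
    """Counts how many rows of each label are in the dataset."""
    counts = {}
    for row in rows:
        counts[row[label_col]] = counts.get(row[label_col], 0) + 1
    return counts


def partition(rows, question):
    """Splits rows into those that match the question and those that don't."""
    true_rows = [row for row in rows if question.match(row)]
    false_rows = [row for row in rows if not question.match(row)]
    return true_rows, false_rows


def gini(rows):
    """Gini impurity of a set of rows."""
    counts = class_counts(rows, len(rows[0]) - 1)
    impurity = 1.0
    for label in counts:
        impurity -= (counts[label] / float(len(rows))) ** 2
    return impurity


def find_best_split(rows, head):
    """Finds the question that yields the highest Gini gain."""
    best_gain, best_question = 0.0, None
    current_impurity = gini(rows)
    for col in range(len(rows[0]) - 1):  # skip the label column
        for value in set(row[col] for row in rows):
            question = Question(col, value, head)
            true_rows, false_rows = partition(rows, question)
            if not true_rows or not false_rows:
                continue
            p = float(len(true_rows)) / len(rows)
            gain = (current_impurity - p * gini(true_rows)
                    - (1.0 - p) * gini(false_rows))
            if gain > best_gain:
                best_gain, best_question = gain, question
    return best_gain, best_question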