示例#1
0
def train_classifier(filename, randomize=False, overfit=False, verbose=False):
    """Load or train the HOG + SVM character classifier, then report its accuracy.

    When ``MODELS_DIR + filename`` exists the model is loaded from it with
    ``joblib``; otherwise an ``svm.SVC`` (sigmoid kernel, probability estimates
    enabled) is fitted on HOG descriptors of the images returned by
    ``load_dataset`` and dumped to that path. Either way the model is then
    scored against the ``load_dataset`` images and targets, and the hit count
    and accuracy are printed.

    Args:
        filename: model file name, appended directly to ``MODELS_DIR``.
        randomize: forwarded to ``load_dataset``.
        overfit: forwarded to ``load_dataset``.
        verbose: if true, print every prediction and show each raw input image
            in an OpenCV window, blocking on ``cv2.waitKey(0)``.

    Returns:
        The loaded or freshly trained model.
    """
    data, target = load_dataset(randomize=randomize, overfit=overfit)
    model_path = MODELS_DIR + filename

    hog = HOG(orientations=18, pixelsPerCell=(10, 10), cellsPerBlock=(1, 1), normalize=True)

    # Training and evaluation share the same preprocessing (centre the glyph in
    # a 20x20 frame, then take its HOG descriptor), so compute it only once.
    hist_data = [hog.describe(dataset.center_extent(image, (20, 20))) for image in data]

    if os.path.exists(model_path):
        print('loading model')
        model = joblib.load(model_path)
    else:
        print('training model')
        # NOTE(review): `svm` must be the sklearn.svm module here (imported
        # outside this chunk) -- confirm no module-level name shadows it.
        model = svm.SVC(probability=True, kernel='sigmoid')
        start = time.time()
        model.fit(hist_data, target)

        seconds = time.time() - start
        m, s = divmod(seconds, 60)
        h, m = divmod(m, 60)
        print("Took %d:%02d:%02d to train model." % (h, m, s))

        joblib.dump(model, model_path)

    correct = 0
    total = len(hist_data)
    for i, hist in enumerate(hist_data):
        # predict_proba expects a 2-D (n_samples, n_features) input, so wrap the
        # single descriptor; row 0 holds one probability per entry of classes_.
        pred = model.predict_proba([hist])[0].tolist()
        max_value = max(pred)
        # Probability columns are ordered like model.classes_, so map the
        # argmax back to the real class label before comparing with the target.
        predicted = model.classes_[pred.index(max_value)]

        if verbose:
            print('-' * 50)
            print('Predicted: ' + str(predicted) + ' with ' + str(max_value) + ' confidence')
            print('Actual: ' + str(target[i]))

            cv2.imshow('input', data[i])
            cv2.waitKey(0)

        if predicted == target[i]:
            correct += 1

    print(str(correct) + ' out of ' + str(total))
    # float() avoids Python 2 integer division (which printed 0 for anything
    # short of 100%); an empty dataset reports 0.0 instead of dividing by zero.
    accuracy = float(correct) / total if total else 0.0
    print('accuracy: ' + str(accuracy))

    return model
def generate_training_data(problem_id, exclusions, threshold, leeway, transformations, label):
    """Append labelled 28x28 character samples from a Problem's image to ``training_samples``.

    The image is blurred and edge-detected, and the bounding boxes of its
    external contours are sorted left to right. They are then passed through
    ``valid_rectangles`` repeatedly until it reports no bad rectangles. Every
    box with 7 <= width <= image_width // 4 and 20 <= height <= image_height // 4
    gets a running index, and boxes whose index is in ``exclusions`` are
    skipped. Each remaining box is binarised, inverted and centred in a 28x28
    frame. It is appended to the module-level ``training_samples`` list as
    ``(pixels reshaped to (784, 1), label)``, followed by one extra sample per
    augmentation named in ``transformations``.

    Args:
        problem_id: primary key of the ``Problem`` whose image is used.
        exclusions: indices (counted over size-accepted boxes) to skip;
            ``None`` means skip nothing.
        threshold: fixed binarisation threshold, or ``None`` to use each
            box's own Otsu threshold.
        leeway: passed to ``valid_rectangles``; ``None`` is treated as 0.
        transformations: container tested with ``in`` for 'rot90', 'rot180',
            'rot270', 'fliplr' and 'flipud'; ``None`` means no augmentation.
        label: label stored with every generated sample.

    Returns:
        None. Prints ``problem_id`` and the number of size-accepted boxes.
    """
    try:
        problem = Problem.objects.get(id=problem_id)
    except Problem.DoesNotExist:
        return None

    im = cv2.imread(problem.image.path)
    if im is None:
        # cv2.imread returns None instead of raising when the file can't be read.
        return None
    im_height, im_width = im.shape[:2]

    im_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    im_blurred = cv2.GaussianBlur(im_gray, (9, 9), 0)
    im_edged = cv2.Canny(im_blurred, 30, 150)

    # Contours are the second-to-last return value in OpenCV 2, 3 and 4 alike.
    contours = cv2.findContours(im_edged.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    # Bounding boxes ordered left to right by their x coordinate.
    rectangles = sorted((cv2.boundingRect(contour) for contour in contours), key=lambda rect: rect[0])

    if leeway is None:
        leeway = 0
    rectangles, bad_rectangles = valid_rectangles(rectangles, leeway)
    while bad_rectangles > 0:
        rectangles, bad_rectangles = valid_rectangles(rectangles, leeway)

    exclusions = exclusions or ()
    transformations = transformations or ()
    # Augmentations in the order they are appended. np.rot90 turns
    # counter-clockwise, so 'rot90' (k=3) is a clockwise quarter turn.
    augmentations = (
        ('rot90', lambda sample: np.rot90(sample, 3)),
        ('rot180', lambda sample: np.rot90(sample, 2)),
        ('rot270', lambda sample: np.rot90(sample, 1)),
        ('fliplr', np.fliplr),
        ('flipud', np.flipud),
    )

    i = 0  # index among size-accepted boxes; this is what `exclusions` refers to
    for x, y, w, h in rectangles:
        if not (7 <= w <= im_width // 4 and 20 <= h <= im_height // 4):
            continue
        if i not in exclusions:
            roi = im_gray[y: y + h, x: x + w]
            # Each box gets its own Otsu threshold unless a fixed one was given,
            # like classify_problem does at inference time. The `threshold`
            # argument itself is never overwritten, so one box's value cannot
            # leak into the next.
            roi_threshold = threshold if threshold is not None else mahotas.thresholding.otsu(roi)
            im_thresh = roi.copy()
            im_thresh[im_thresh > roi_threshold] = 255
            im_thresh = cv2.bitwise_not(im_thresh)
            training_sample = dataset.center_extent(im_thresh, (28, 28)).copy()

            training_samples.append((training_sample.reshape((784, 1)), label))
            for name, transform in augmentations:
                if name in transformations:
                    training_samples.append((transform(training_sample).reshape((784, 1)), label))
        i += 1

    print('%s %s' % (problem_id, i))
def classify_problem(problem_id, leeway=20):
    """Segment, classify and evaluate the handwritten expression in a Problem's image.

    Processing steps:

    1. Blur and edge-detect the image, then sort the bounding boxes of its
       external contours left to right.
    2. Pass the boxes through ``valid_rectangles`` repeatedly until it reports
       no bad rectangles.
    3. For each box no larger than half the image in either dimension:
       binarise it with its own Otsu threshold, centre it in a 20x20 frame,
       describe it with HOG and classify it with the module-level ``svm``.
    4. Draw each box and its predicted character onto the image. Colours are
       BGR: blue for confidence >= 0.98, green above 0.5, red otherwise.
    5. Store the annotated image in ``problem.processed_image``.

    Args:
        problem_id: primary key of the ``Problem`` to process.
        leeway: passed to ``valid_rectangles`` (default 20).

    Returns:
        A pair ``(classification, prediction_solution)``:

        - ``classification`` is the predicted characters joined by single
          spaces.
        - ``prediction_solution`` is ``str()`` of evaluating that text as a
          Python expression, or ``'Error'`` if it cannot be evaluated.

        Returns ``[]`` if the problem does not exist or its image cannot be
        read.
    """
    try:
        problem = Problem.objects.get(id=problem_id)
    except Problem.DoesNotExist:
        return []

    im = cv2.imread(problem.image.path)
    if im is None:
        # cv2.imread returns None instead of raising when the file can't be read.
        return []
    im_height, im_width = im.shape[:2]

    im_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    im_blurred = cv2.GaussianBlur(im_gray, (9, 9), 0)
    im_edged = cv2.Canny(im_blurred, 30, 150)

    # Contours are the second-to-last return value in OpenCV 2, 3 and 4 alike.
    contours = cv2.findContours(im_edged.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    # Bounding boxes ordered left to right so characters are read in order.
    rectangles = sorted((cv2.boundingRect(contour) for contour in contours), key=lambda rect: rect[0])

    rectangles, bad_rectangles = valid_rectangles(rectangles, leeway)
    while bad_rectangles > 0:
        rectangles, bad_rectangles = valid_rectangles(rectangles, leeway)

    characters = []
    for x, y, w, h in rectangles:
        # Ignore boxes larger than half the image in either dimension.
        if w > im_width // 2 or h > im_height // 2:
            continue

        roi = im_gray[y: y + h, x: x + w]
        im_thresh = roi.copy()
        # Per-box Otsu threshold: brighter pixels become white, then invert.
        threshold = mahotas.thresholding.otsu(roi)
        im_thresh[im_thresh > threshold] = 255
        im_thresh = cv2.bitwise_not(im_thresh)
        im_thresh = dataset.center_extent(im_thresh, (20, 20))

        # NOTE(review): `hog` and `svm` are module-level names defined outside
        # this chunk. `svm` must be a fitted classifier (e.g. the model returned
        # by train_classifier), not the sklearn `svm` module that
        # train_classifier calls svm.SVC on -- confirm they don't shadow each other.
        hist = hog.describe(im_thresh)
        # predict_proba needs a 2-D input; row 0 is this box's class probabilities.
        output_vector = svm.predict_proba([hist])[0].tolist()
        confidence = max(output_vector)
        prediction = output_vector.index(confidence)
        # NOTE(review): assumes probability column i corresponds to
        # possible_values[i] (i.e. the model's classes_ are 0..n-1) -- confirm.
        character = possible_values[prediction]

        if confidence >= .98:
            color = (255, 0, 0)
        elif confidence > .5:
            color = (0, 255, 0)
        else:
            color = (0, 0, 255)

        cv2.rectangle(im, (x, y), (x + w, y + h), color, 2)
        # Integer division keeps the text origin integral (cv2 rejects floats on Python 3).
        cv2.putText(im, str(character), (x - 10 + w // 2, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 1)

        characters.append(str(character))

    classification = ' '.join(characters)
    print('-> Classification: ' + classification)

    if __name__ == '__main__':
        cv2.imshow('image', im)
        cv2.waitKey(0)

    # Write the annotated image next to the original and attach it to the
    # problem. The temporary file is removed even if saving fails, and its
    # handle is closed first (os.remove on an open file fails on Windows).
    image_path, extension = os.path.splitext(problem.image.path)
    processed_image_path = image_path + '_processed' + extension
    if not cv2.imwrite(processed_image_path, im):
        raise IOError('Could not write processed image to ' + processed_image_path)
    try:
        with open(processed_image_path, 'rb') as opened:
            problem.processed_image.save(processed_image_path, File(opened))
        problem.save()
    finally:
        os.remove(processed_image_path)

    classification = classification.strip('\r\n')

    # SECURITY: this eval()s text derived from an uploaded image. Empty
    # globals/builtins and empty locals keep this function's variables and the
    # builtins out of reach, but eval is not a sandbox (e.g. `9 ** 9 ** 9`
    # can still exhaust CPU/memory); an arithmetic-only parser built on the
    # `ast` module would be safer.
    try:
        prediction_solution = eval(classification, {'__builtins__': {}}, {})
    except (SyntaxError, ArithmeticError, NameError, TypeError, ValueError):
        # Unparseable or unevaluable expressions (e.g. division by zero).
        prediction_solution = 'Error'

    return classification, str(prediction_solution)