예제 #1
0
    def testTrainingConstructionClassificationSparse(self):
        """Builds a classification training graph from a SparseTensor input.

        The sparse input has dense_shape [4, 10], matching num_features=10.
        The test only checks that training_graph returns an ops.Operation.
        """
        input_data = sparse_tensor.SparseTensor(
            indices=[[0, 0], [0, 3], [1, 0], [1, 7], [2, 1], [3, 9]],
            values=[-1.0, 0.0, -1., 2., 1., -2.0],
            dense_shape=[4, 10])
        # One label per input row (4 rows, labels drawn from num_classes=4).
        input_labels = [0, 1, 2, 3]

        params = tensor_forest.ForestHParams(num_classes=4,
                                             num_features=10,
                                             num_trees=10,
                                             max_nodes=1000,
                                             split_after_samples=25).fill()

        graph_builder = tensor_forest.RandomForestGraphs(params)
        graph = graph_builder.training_graph(input_data, input_labels)
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)), which only reports "False is not true".
        self.assertIsInstance(graph, ops.Operation)
예제 #2
0
    def testTrainingConstructionClassification(self):
        """Builds a classification training graph from dense list input.

        The input has 4 rows of 2 features, matching num_features=2.
        The test only checks that training_graph returns an ops.Operation.
        """
        input_data = [
            [-1., 0.],
            [-1., 2.],  # node 1
            [1., 0.],
            [1., -2.],  # node 2
        ]
        # One label per input row (4 rows, labels drawn from num_classes=4).
        input_labels = [0, 1, 2, 3]

        params = tensor_forest.ForestHParams(num_classes=4,
                                             num_features=2,
                                             num_trees=10,
                                             max_nodes=1000,
                                             split_after_samples=25).fill()

        graph_builder = tensor_forest.RandomForestGraphs(params)
        graph = graph_builder.training_graph(input_data, input_labels)
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)), which only reports "False is not true".
        self.assertIsInstance(graph, ops.Operation)
예제 #3
0
    def testInferenceConstructionSparse(self):
        """Builds a regression inference graph from a SparseTensor input.

        The sparse input has dense_shape [4, 10], matching num_features=10.
        The test checks that all three values returned by inference_graph
        are ops.Tensor instances.
        """
        input_data = sparse_tensor.SparseTensor(
            indices=[[0, 0], [0, 3], [1, 0], [1, 7], [2, 1], [3, 9]],
            values=[-1.0, 0.0, -1., 2., 1., -2.0],
            dense_shape=[4, 10])

        params = tensor_forest.ForestHParams(num_classes=4,
                                             num_features=10,
                                             num_trees=10,
                                             max_nodes=1000,
                                             regression=True,
                                             split_after_samples=25).fill()

        graph_builder = tensor_forest.RandomForestGraphs(params)
        probs, paths, var = graph_builder.inference_graph(input_data)
        # assertIsInstance reports the actual type on failure.
        self.assertIsInstance(probs, ops.Tensor)
        self.assertIsInstance(paths, ops.Tensor)
        self.assertIsInstance(var, ops.Tensor)
예제 #4
0
    def testInferenceConstruction(self):
        """Builds a classification inference graph from dense list input.

        The input has 4 rows of 2 features, matching num_features=2.
        The test checks that all three values returned by inference_graph
        are ops.Tensor instances.
        """
        input_data = [
            [-1., 0.],
            [-1., 2.],  # node 1
            [1., 0.],
            [1., -2.],  # node 2
        ]

        params = tensor_forest.ForestHParams(num_classes=4,
                                             num_features=2,
                                             num_trees=10,
                                             max_nodes=1000,
                                             split_after_samples=25).fill()

        graph_builder = tensor_forest.RandomForestGraphs(params)
        probs, paths, var = graph_builder.inference_graph(input_data)
        # assertIsInstance reports the actual type on failure.
        self.assertIsInstance(probs, ops.Tensor)
        self.assertIsInstance(paths, ops.Tensor)
        self.assertIsInstance(var, ops.Tensor)
예제 #5
0
 def testInfrenceFromRestoredModel(self):
     """Runs inference with a single tree restored from a serialized Model proto.

     The restored tree has a binary root node that tests feature 0 against
     threshold 0, with children 1 and 2. Both leaves hold the vector
     [0.0, 1.0], so every input row is expected to produce [0.0, 1.0],
     whichever leaf it reaches.

     NOTE(review): the method name misspells "Inference". It is kept as-is
     so the test's identifier stays stable.
     """
     input_data = [
         [-1., 0.],
         [-1., 2.],  # node 1
         [1., 0.],
         [1., -2.],  # node 2
     ]
     expected_prediction = [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0]]
     hparams = tensor_forest.ForestHParams(num_classes=2,
                                           num_features=2,
                                           num_trees=1,
                                           max_nodes=1000,
                                           split_after_samples=25).fill()
     # JSON-style description of the tree; ParseDict converts it into a
     # _tree_proto.Model message below.
     tree_weight = {
         'decisionTree': {
             'nodes': [{
                 'binaryNode': {
                     'rightChildId': 2,
                     'leftChildId': 1,
                     'inequalityLeftChildTest': {
                         'featureId': {
                             'id': '0'
                         },
                         'threshold': {
                             'floatValue': 0
                         }
                     }
                 }
             }, {
                 'leaf': {
                     'vector': {
                         'value': [{
                             'floatValue': 0.0
                         }, {
                             'floatValue': 1.0
                         }]
                     }
                 },
                 'nodeId': 1
             }, {
                 'leaf': {
                     'vector': {
                         'value': [{
                             'floatValue': 0.0
                         }, {
                             'floatValue': 1.0
                         }]
                     }
                 },
                 'nodeId': 2
             }]
         }
     }
     restored_tree_param = ParseDict(
         tree_weight, _tree_proto.Model()).SerializeToString()
     # One serialized tree per entry, matching num_trees=1.
     graph_builder = tensor_forest.RandomForestGraphs(
         hparams, [restored_tree_param])
     probs, paths, var = graph_builder.inference_graph(input_data)
     self.assertIsInstance(probs, ops.Tensor)
     self.assertIsInstance(paths, ops.Tensor)
     self.assertIsInstance(var, ops.Tensor)
     with self.cached_session():
         variables.global_variables_initializer().run()
         resources.initialize_resources(resources.shared_resources()).run()
         # Evaluate the graph once and reuse the result for both checks.
         predictions = probs.eval()
         # assertEqual, not the deprecated alias assertEquals, which was
         # removed from unittest in Python 3.12.
         self.assertEqual(predictions.shape, (4, 2))
         self.assertEqual(predictions.tolist(), expected_prediction)