def testInferenceConstruction(self):
  """Check that inference_graph builds a Tensor from a batch of features.

  Feeds a constant batch of 100 rows, each holding
  ``self.params.num_features`` values drawn with ``random.uniform(-1, 1)``,
  and asserts the builder returns a ``Tensor``.
  """
  data = constant_op.constant(
      [[random.uniform(-1, 1) for _ in range(self.params.num_features)]
       for _ in range(100)])
  # Unique scope name per test so variables from different tests don't clash.
  with variable_scope.variable_scope(
      "DecisionsToDataThenNNTest_testInferenceConstruction"):
    graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN(
        self.params)
    graph = graph_builder.inference_graph(data, None)

    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(graph, Tensor)
def testTrainingConstruction(self):
  """Check that training_graph builds an Operation from features and labels.

  Feeds a constant batch of 100 rows, each holding
  ``self.params.num_features`` values drawn with ``random.uniform(-1, 1)``,
  together with 100 labels all equal to 1. Asserts the builder returns an
  ``Operation``.
  """
  data = constant_op.constant(
      [[random.uniform(-1, 1) for _ in range(self.params.num_features)]
       for _ in range(100)])
  labels = [1] * 100
  # Unique scope name per test so variables from different tests don't clash.
  with variable_scope.variable_scope(
      "DecisionsToDataThenNNTest_testTrainingConstruction"):
    graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN(
        self.params)
    graph = graph_builder.training_graph(data, labels, None)

    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(graph, Operation)
def testHParams(self):
  """Check the configured hyperparameters and the derived node count.

  Asserts the values present in ``self.params`` before any graph is built.
  It then constructs a ``DecisionsToDataThenNN`` and asserts ``num_nodes``
  afterwards.
  """
  # assertEqual replaces the deprecated assertEquals alias (removed in 3.12).
  self.assertEqual(self.params.num_classes, 2)
  self.assertEqual(self.params.num_features, 31)
  self.assertEqual(self.params.layer_size, 11)
  self.assertEqual(self.params.num_layers, 13)
  self.assertEqual(self.params.num_trees, 17)
  self.assertEqual(self.params.hybrid_tree_depth, 4)
  self.assertEqual(self.params.connection_probability, 0.1)

  # Building the graphs modifies the params; the builder is constructed only
  # for that side effect, which the num_nodes check below relies on.
  with variable_scope.variable_scope("DecisionsToDataThenNNTest_testHParams"):
    decisions_to_data_then_nn.DecisionsToDataThenNN(self.params)

  # Tree with depth 4 should have 2**0 + 2**1 + 2**2 + 2**3 = 15 nodes.
  self.assertEqual(self.params.num_nodes, 15)
def testConstructionPollution(self):
  """Ensure that graph building doesn't modify the params in a bad way.

  Before and after constructing a ``DecisionsToDataThenNN``, the params must
  still be a ``ForestHParams``. Its ``num_trees`` field must not itself have
  been replaced by a ``ForestHParams``.
  """
  self.assertIsInstance(self.params, tensor_forest.ForestHParams)
  self.assertNotIsInstance(self.params.num_trees,
                           tensor_forest.ForestHParams)

  # The constructor is called only for its effect on self.params.
  with variable_scope.variable_scope(
      "DecisionsToDataThenNNTest_testConstructionPollution"):
    decisions_to_data_then_nn.DecisionsToDataThenNN(self.params)

  self.assertIsInstance(self.params, tensor_forest.ForestHParams)
  self.assertNotIsInstance(self.params.num_trees,
                           tensor_forest.ForestHParams)