def testInferenceConstruction(self):
        """Check that `inference_graph` builds and returns a single Tensor.

        Feeds a 100-row constant of uniform random values in [-1, 1], with
        `self.params.num_features` columns per row, into the builder.
        """
        data = constant_op.constant(
            [[random.uniform(-1, 1) for _ in range(self.params.num_features)]
             for _ in range(100)])

        # Each test uses its own uniquely named variable scope, presumably so
        # variables created by different tests do not collide.
        with variable_scope.variable_scope(
                "DecisionsToDataThenNNTest_testInferenceConstruction"):
            graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN(
                self.params)
            graph = graph_builder.inference_graph(data, None)

            # assertIsInstance reports the actual type on failure, unlike
            # assertTrue(isinstance(...)).
            self.assertIsInstance(graph, Tensor)
    def testTrainingConstruction(self):
        """Check that `training_graph` builds and returns a graph Operation.

        Uses a 100-row constant of uniform random values in [-1, 1] (one
        column per `self.params.num_features`) and a label of 1 for every row.
        """
        data = constant_op.constant(
            [[random.uniform(-1, 1) for _ in range(self.params.num_features)]
             for _ in range(100)])

        labels = [1] * 100

        with variable_scope.variable_scope(
                "DecisionsToDataThenNNTest_testTrainingConstruction"):
            graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN(
                self.params)
            graph = graph_builder.training_graph(data, labels, None)

            # assertIsInstance reports the actual type on failure, unlike
            # assertTrue(isinstance(...)).
            self.assertIsInstance(graph, Operation)
  def testHParams(self):
    """Check the fixture's hyperparameters and the derived `num_nodes`.

    Constructing DecisionsToDataThenNN is expected to set `params.num_nodes`
    from `hybrid_tree_depth` (15 nodes for a depth-4 tree).
    """
    self.assertEqual(self.params.num_classes, 2)
    self.assertEqual(self.params.num_features, 31)
    self.assertEqual(self.params.layer_size, 11)
    self.assertEqual(self.params.num_layers, 13)
    self.assertEqual(self.params.num_trees, 17)
    self.assertEqual(self.params.hybrid_tree_depth, 4)
    self.assertEqual(self.params.connection_probability, 0.1)

    # Building the graphs modifies the params.
    with variable_scope.variable_scope("DecisionsToDataThenNNTest_testHParams"):
      # The builder is created only for its side effect on self.params.
      # pylint: disable=W0612
      graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN(
          self.params)

      # Tree with depth 4 should have 2**0 + 2**1 + 2**2 + 2**3 = 15 nodes.
      self.assertEqual(self.params.num_nodes, 15)
  def testConstructionPollution(self):
    """Ensure that graph building doesn't modify the params in a bad way.

    Before and after constructing DecisionsToDataThenNN, checks that
    `self.params` is still a ForestHParams and that its `num_trees` field
    has not been replaced by a ForestHParams instance.
    """
    self.assertIsInstance(self.params, tensor_forest.ForestHParams)
    self.assertNotIsInstance(self.params.num_trees,
                             tensor_forest.ForestHParams)

    with variable_scope.variable_scope(
        "DecisionsToDataThenNNTest_testContructionPollution"):
      # Only the constructor is exercised; the builder object itself is unused.
      # pylint: disable=W0612
      graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN(
          self.params)

      self.assertIsInstance(self.params, tensor_forest.ForestHParams)
      self.assertNotIsInstance(self.params.num_trees,
                               tensor_forest.ForestHParams)