Exemplo n.º 1
0
 def setUp(self):
     """Builds the fixtures shared by the DNNRankingNetwork tests.

     Stores the context feature columns, example feature columns and input
     features returned by the module helpers (defined outside this view) on
     ``self``, and builds a default network with three hidden layers of
     width 10, ReLU activation and dropout 0.5.
     """
     super(DNNRankingNetworkTest, self).setUp()
     self.context_feature_columns = _context_feature_columns()
     self.example_feature_columns = _example_feature_columns()
     self.features = _features()
     # Hidden layer widths are integer unit counts; passing strings such as
     # "10" would rely on implicit coercion further down the stack.
     self.network = dnn.DNNRankingNetwork(
         context_feature_columns=self.context_feature_columns,
         example_feature_columns=self.example_feature_columns,
         hidden_layer_dims=[10, 10, 10],
         activation=tf.nn.relu,
         dropout=0.5)
Exemplo n.º 2
0
 def test_call_empty_example_hidden_layer_dims(self):
     """Constructing the network with empty ``hidden_layer_dims`` must fail.

     Only the constructor runs inside the context manager, so the
     ``ValueError`` is expected at construction time, not when the network
     is called. Uses the feature-column fixtures stored on ``self``.
     """
     # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
     # which no longer exists on unittest.TestCase as of Python 3.12.
     with self.assertRaisesRegex(
         ValueError,
         r"example_feature_columns or hidden_layer_dims must not be empty."):
         dnn.DNNRankingNetwork(
             context_feature_columns=self.context_feature_columns,
             example_feature_columns=self.example_feature_columns,
             hidden_layer_dims=[],
             activation=tf.nn.relu,
             dropout=0.5,
             name="dnn_ranking_network")
Exemplo n.º 3
0
 def test_call_none_context_feature_columns(self):
     """Checks the network can be built and called without context features.

     Builds a network with ``context_feature_columns=None``, calls it on the
     shared ``self.features`` with a 2x2 boolean mask, and asserts that the
     static shape of the returned logits is [2, 2].
     """
     network = dnn.DNNRankingNetwork(
         context_feature_columns=None,
         example_feature_columns=self.example_feature_columns,
         # Integer unit counts rather than strings like "10".
         hidden_layer_dims=[10, 10],
         activation=tf.nn.relu,
         dropout=0.5,
         name="dnn_ranking_network")
     # NOTE(review): presumably True marks a valid example and False a padded
     # one in each list; confirm against DNNRankingNetwork's call semantics.
     logits = network(inputs=self.features,
                      mask=[[True, False], [True, True]])
     self.assertAllEqual([2, 2], logits.get_shape().as_list())