コード例 #1
0
 def test_call_none_context_hidden_layer_dims(self):
   """Calls a network built with `context_hidden_layer_dims=None`.

   The network is applied to `self.features` with a 2x2 boolean mask and the
   resulting logits are expected to have static shape `[2, 2]`.
   """
   network_kwargs = dict(
       context_feature_columns=self.context_feature_columns,
       example_feature_columns=self.example_feature_columns,
       example_hidden_layer_dims=["10", "10"],
       context_hidden_layer_dims=None,
       activation=tf.nn.relu,
       dropout=0.5,
       name="gam_ranking_model",
   )
   ranker = gam.GAMRankingNetwork(**network_kwargs)

   list_mask = [[True, False], [True, True]]
   scores = ranker(inputs=self.features, mask=list_mask)

   static_shape = scores.get_shape().as_list()
   self.assertAllEqual([2, 2], static_shape)
コード例 #2
0
 def test_call_empty_example_hidden_layer_dims(self):
     """Checks that an empty `example_hidden_layer_dims` is rejected.

     Constructing the network with `example_hidden_layer_dims=[]` is expected
     to raise a `ValueError` whose message matches the pattern below.
     """
     # `assertRaisesRegexp` is a deprecated alias that was removed from
     # `unittest.TestCase` in Python 3.12; `assertRaisesRegex` is the
     # supported name and takes the same arguments.
     with self.assertRaisesRegex(
             ValueError, r"example_feature_columns or "
             "example_hidden_layer_dims must not be empty."):
         gam.GAMRankingNetwork(
             context_feature_columns=self.context_feature_columns,
             example_feature_columns=self.example_feature_columns,
             example_hidden_layer_dims=[],
             context_hidden_layer_dims=["10", "10"],
             activation=tf.nn.relu,
             dropout=0.5,
             name="gam_ranking_model")
コード例 #3
0
 def setUp(self):
     """Builds the shared fixtures used by the test methods.

     Stores the context/example feature columns and the input features on
     the test instance, then constructs a `GAMRankingNetwork` from them and
     stores it as `self.network`.
     """
     super(GAMRankingNetworkTest, self).setUp()

     # Fixtures are created in the same order as before and then published
     # on the instance so individual tests can reuse them.
     context_columns = _context_feature_columns()
     self.context_feature_columns = context_columns
     example_columns = _example_feature_columns()
     self.example_feature_columns = example_columns
     self.features = _features()

     # Separate list objects are passed for the two hidden-layer arguments.
     network_kwargs = {
         "context_feature_columns": context_columns,
         "example_feature_columns": example_columns,
         "example_hidden_layer_dims": ["10", "10"],
         "context_hidden_layer_dims": ["10", "10"],
         "activation": tf.nn.relu,
         "dropout": 0.5,
         "name": "gam_ranking_model",
     }
     self.network = gam.GAMRankingNetwork(**network_kwargs)