示例#1
0
 def test_weighted_linear_combo(self):
   """Check that a WeightedLinearCombo layer runs on two equal-shaped inputs.

   Two random float32 tensors of shape (batch_size, n_features) are passed
   to the layer; after variables are initialized, the evaluated output must
   have that same shape.
   """
   batch_size, n_features = 10, 5
   expected_shape = (batch_size, n_features)
   raw_a = np.random.rand(batch_size, n_features)
   raw_b = np.random.rand(batch_size, n_features)
   with self.session() as sess:
     tensor_a = tf.convert_to_tensor(raw_a, dtype=tf.float32)
     tensor_b = tf.convert_to_tensor(raw_b, dtype=tf.float32)
     combo_out = WeightedLinearCombo()(tensor_a, tensor_b)
     # The layer creates variables, so initialize them before evaluating.
     sess.run(tf.global_variables_initializer())
     result = sess.run(combo_out)
     assert result.shape == expected_shape
示例#2
0
 def test_weighted_linear_combo(self):
     """Verify WeightedLinearCombo accepts two inputs and preserves their shape.

     Both inputs are random float32 matrices of shape
     (batch_size, n_features); the evaluated layer output is expected to
     have exactly that shape.
     """
     batch_size = 10
     n_features = 5
     expected_shape = (batch_size, n_features)
     # Draw the two input arrays in order (first input, then second).
     arrays = [np.random.rand(batch_size, n_features) for _ in range(2)]
     with self.session() as sess:
         inputs = [tf.convert_to_tensor(arr, dtype=tf.float32) for arr in arrays]
         layer_output = WeightedLinearCombo()(*inputs)
         # Variables must be initialized before the output can be evaluated.
         sess.run(tf.global_variables_initializer())
         evaluated = sess.run(layer_output)
         assert evaluated.shape == expected_shape
示例#3
0
  def test_weighted_combo(self):
    """Build and briefly train a graph containing a WeightedLinearCombo layer.

    Two random feature matrices are combined by the layer, the result is
    summed along axis 1, and a squared-difference loss against random labels
    is set on a TensorGraph. The test passes if one epoch of fit_generator
    completes without raising; it makes no assertion on the trained values.
    """
    n_samples = 10
    n_features = 5

    # Random inputs and targets, drawn in the order: first, second, labels.
    x_first = NumpyDataset(np.random.rand(n_samples, n_features))
    x_second = NumpyDataset(np.random.rand(n_samples, n_features))
    y_target = NumpyDataset(np.random.rand(n_samples))

    # Input layers; the leading None presumably leaves the batch size open.
    feat_a = Feature(shape=(None, n_features))
    feat_b = Feature(shape=(None, n_features))
    label_layer = Label(shape=(None,))

    summed = ReduceSum(
        in_layers=[WeightedLinearCombo(in_layers=[feat_a, feat_b])], axis=1)
    loss = ReduceSquareDifference(in_layers=[summed, label_layer])

    # Map each input layer to the dataset that feeds it.
    bag = Databag({feat_a: x_first, feat_b: x_second, label_layer: y_target})

    graph = dc.models.TensorGraph(learning_rate=0.1, use_queue=False)
    graph.set_loss(loss)
    graph.fit_generator(bag.iterbatches(epochs=1))