Exemplo n.º 1
0
    def testObservesWrappedFunction(self):
        """A Module wrapping a plain function is reported as one connection.

        Builds a `base.Module` around `tf.nn.relu` and calls it inside
        `base.observe_connections`. Then checks four things:

        - exactly one connected subgraph was recorded;
        - that subgraph refers to the wrapping module;
        - the positional input is exposed under `inputs["args"]`;
        - the recorded outputs are the very object the call returned.
        """
        relu_module = base.Module(tf.nn.relu)
        with base.observe_connections(self._connection_observer):
            result = relu_module(self._inputs)

        recorded = self._connected_subgraphs
        self.assertEqual(1, len(recorded))

        only_subgraph = recorded[0]
        self.assertIs(relu_module, only_subgraph.module)
        # A wrapped function receives its tensors positionally, so the input
        # is looked up in the "args" sequence rather than by a named key.
        positional_args = only_subgraph.inputs["args"]
        self.assertIs(self._inputs, positional_args[0])
        self.assertIs(only_subgraph.outputs, result)
Exemplo n.º 2
0
    def testObservesSimpleModule(self):
        """A single SimpleModule call is reported as exactly one connection.

        Calls a `SimpleModule` inside `base.observe_connections`. Then checks
        that exactly one subgraph was recorded and that it refers to that
        module. It also checks that the subgraph exposes the input under the
        "inputs" key and that it holds the same outputs object the call
        returned.
        """
        module_under_test = SimpleModule()
        with base.observe_connections(self._connection_observer):
            result = module_under_test(self._inputs)

        recorded = self._connected_subgraphs
        self.assertEqual(1, len(recorded))

        only_subgraph = recorded[0]
        self.assertIs(module_under_test, only_subgraph.module)
        self.assertIs(self._inputs, only_subgraph.inputs["inputs"])
        self.assertIs(only_subgraph.outputs, result)
Exemplo n.º 3
0
  def testObservesComplexModule(self):
    """Nested module calls are each reported as a separate connection.

    Calling a `ComplexModule` inside `base.observe_connections` is expected
    to record three subgraphs:

    - indices 0 and 1: two `SimpleModule` connections chained together
      (the first one's outputs are the second one's inputs);
    - index 2: the outer `ComplexModule` itself.

    Both the second inner subgraph and the outer one must hold the same
    outputs object the top-level call returned.
    """
    outer_module = ComplexModule()
    with base.observe_connections(self._connection_observer):
      result = outer_module(self._inputs)

    recorded = self._connected_subgraphs
    self.assertEqual(3, len(recorded))

    first_inner = recorded[0]
    self.assertIsInstance(first_inner.module, SimpleModule)
    self.assertIs(self._inputs, first_inner.inputs["inputs"])

    second_inner = recorded[1]
    self.assertIsInstance(second_inner.module, SimpleModule)
    # The inner modules are chained: the first one's outputs feed the second.
    self.assertIs(first_inner.outputs, second_inner.inputs["inputs"])
    self.assertIs(second_inner.outputs, result)

    # The outer module's connection is expected after both inner ones.
    outer_subgraph = recorded[2]
    self.assertIs(outer_module, outer_subgraph.module)
    self.assertIs(outer_subgraph.outputs, result)