def test_stop_gradients_default_params(self):
    """Checks `utils_tf.stop_gradient` when called with only the graph.

    With no `stop_*` flags supplied, no gradient is expected to flow through
    any of the three checked fields (globals, nodes, edges).
    """
    graph_without_gradients = utils_tf.stop_gradient(self._graph)
    self.assertAllEqual(
        [False] * 3,
        self._check_if_gradients_exist(graph_without_gradients))
def test_stop_gradients_outputs(self, stop_globals, stop_nodes, stop_edges):
    """Checks that each `stop_*` flag controls gradient flow for its field.

    A gradient is expected to exist exactly for the fields whose flag is
    False. The expected list is ordered globals, nodes, edges, which is the
    order compared against `_check_if_gradients_exist`.
    """
    stop_flags = dict(
        stop_globals=stop_globals,
        stop_nodes=stop_nodes,
        stop_edges=stop_edges,
    )
    partially_stopped_graph = utils_tf.stop_gradient(self._graph, **stop_flags)
    observed = self._check_if_gradients_exist(partially_stopped_graph)
    expected = [not flag for flag in (stop_globals, stop_nodes, stop_edges)]
    self.assertAllEqual(expected, observed)
def test_stop_gradients_with_missing_field_raises(self, none_field):
    """Checks that `utils_tf.stop_gradient` raises on a graph with a None field.

    Args:
      none_field: Name of the graph field that is replaced by `None` before
        calling `stop_gradient`. The raised `ValueError` message is expected
        to match it (it is used directly as the regex pattern).
    """
    # Build the broken graph locally instead of overwriting `self._graph`,
    # so the fixture is not mutated by this test.
    graph_with_none_field = self._graph.map(lambda _: None, [none_field])
    # `assertRaisesRegexp` is a deprecated alias that was removed from
    # `unittest` in Python 3.12; `assertRaisesRegex` is the supported name.
    with self.assertRaisesRegex(ValueError, none_field):
        utils_tf.stop_gradient(graph_with_none_field)