def _apply_config(self, neural_net):
    """
    This function actually does the work.

    For every node returned by ``self.get_source_nodes(neural_net)``, build a
    chain of ``self.copy_levels`` CopyNode objects:

    * The first copy node uses the source node itself as its source.
    * Each later copy node uses the previous copy node as its source.

    Each copy node is configured with ``self.source_type``,
    ``self.incoming_weight``, ``self.existing_weight`` and
    ``self.activation_type``. It is then connected according to
    ``self.connection_type`` and added to
    ``neural_net.layers[self.copy_nodes_layer]``.

    Args:
        neural_net: the NeuralNet instance to modify in place.

    Raises:
        ValueError: if ``neural_net`` is not a NeuralNet, or if
            ``self.connection_type`` is neither 'm' nor 's'. The type check
            only runs while building a copy node, so it is not reached when
            there are no source nodes or ``copy_levels`` is 0.
    """
    if not isinstance(neural_net, NeuralNet):
        raise ValueError("neural_net must be of the NeuralNet class.")

    for snode in self.get_source_nodes(neural_net):
        prev_copy_node = None
        # ``range`` instead of ``xrange``: xrange does not exist in Python 3,
        # and range behaves identically for this loop under Python 2.
        for level in range(self.copy_levels):
            copy_node = CopyNode()
            # Level 0 copies the real source node; deeper levels copy the
            # copy node created on the previous iteration.
            if level == 0:
                copy_node.set_source_node(snode)
            else:
                copy_node.set_source_node(prev_copy_node)

            copy_node.source_update_config(
                self.source_type,
                self.incoming_weight,
                self.existing_weight)
            copy_node.set_activation_type(self.activation_type)

            if self.connection_type == 'm':
                self._fully_connect(
                    copy_node, self.get_upper_nodes(neural_net))
            elif self.connection_type == 's':
                # NOTE(review): every level connects to ``snode`` (not to
                # ``prev_copy_node``), and the Connection is added to
                # copy_node's own inputs with copy_node as the first
                # argument. Confirm this matches Connection's
                # (lower, upper) argument order and the intended direction.
                copy_node.add_input_connection(
                    Connection(copy_node, snode))
            else:
                raise ValueError("Invalid connection_type")

            neural_net.layers[self.copy_nodes_layer].add_node(copy_node)
            prev_copy_node = copy_node
def _apply_config(self, neural_net):
    """
    This function actually does the work.

    Attach ``self.copy_levels`` chained CopyNode objects to every source node
    reported by ``self.get_source_nodes(neural_net)``. The level-0 copy node
    reads from the source node; each deeper level reads from the copy node one
    level above it. Each copy node is:

    * configured from ``self.source_type``, ``self.incoming_weight``,
      ``self.existing_weight`` and ``self.activation_type``;
    * wired according to ``self.connection_type`` ('m' or 's');
    * appended to ``neural_net.layers[self.copy_nodes_layer]``.

    Args:
        neural_net: the NeuralNet instance to modify in place.

    Raises:
        ValueError: if ``neural_net`` is not a NeuralNet, or when a copy node
            is built while ``self.connection_type`` is neither 'm' nor 's'.
    """
    if not isinstance(neural_net, NeuralNet):
        raise ValueError("neural_net must be of the NeuralNet class.")

    for snode in self.get_source_nodes(neural_net):
        prev_copy_node = None
        # ``xrange`` was Python-2 only; ``range`` works on both 2 and 3.
        for level in range(self.copy_levels):
            copy_node = CopyNode()
            # The first level copies the real source node; later levels copy
            # the copy node created on the previous iteration.
            copy_node.set_source_node(
                snode if level == 0 else prev_copy_node)

            copy_node.source_update_config(
                self.source_type, self.incoming_weight, self.existing_weight)
            copy_node.set_activation_type(self.activation_type)

            if self.connection_type == 'm':
                self._fully_connect(
                    copy_node, self.get_upper_nodes(neural_net))
            elif self.connection_type == 's':
                # NOTE(review): the connection always targets ``snode`` and is
                # registered as an input of copy_node with copy_node passed
                # first. Confirm this matches Connection's argument order.
                copy_node.add_input_connection(
                    Connection(copy_node, snode))
            else:
                raise ValueError("Invalid connection_type")

            neural_net.layers[self.copy_nodes_layer].add_node(copy_node)
            prev_copy_node = copy_node
class CopyNodeTest(unittest.TestCase):
    """
    Tests CopyNode
    """

    def setUp(self):
        """Create a fresh CopyNode for every test."""
        self.node = CopyNode()

    def test__init__(self):
        """A new CopyNode reports NODE_COPY as its node_type."""
        self.assertEqual(NODE_COPY, self.node.node_type)

    def test_set_source_node(self):
        """set_source_node stores the given node in ``_source_node``."""
        source_node = Node()
        self.node.set_source_node(source_node)
        self.assertEqual(source_node, self.node._source_node)

    def test_get_source_node(self):
        """get_source_node returns the node held in ``_source_node``."""
        self.node._source_node = Node()
        self.assertEqual(self.node._source_node, self.node.get_source_node())

    def test_load_source_value(self):
        """
        load_source_value is expected to set the copy node's value to
        ``incoming_weight * source + existing_weight * current_value``.

        ``source`` is the source node's activated value for source type 'a'
        and its raw value for source type 'v'. Any other source type must
        raise ValueError.
        """
        self.node._value = .25
        self.node._existing_weight = .25
        self.node._incoming_weight = .5
        source_node = Node()
        source_node.set_value(.3)
        source_node.set_activation_type(ACTIVATION_SIGMOID)
        self.node.set_source_node(source_node)

        # activate
        self.node._source_type = 'a'
        self.node.load_source_value()
        self.assertAlmostEqual(
            sigmoid(.3) * .5 + .25 * .25, self.node._value)

        # value -- reset _value because the 'a' case above overwrote it
        self.node._value = .25
        self.node._source_type = 'v'
        self.node.load_source_value()
        self.assertAlmostEqual(.3 * .5 + .25 * .25, self.node._value)

        # invalid source type
        self.node._source_type = 'f'
        self.assertRaises(ValueError, self.node.load_source_value)

    def test_get_source_type(self):
        """get_source_type returns ``_source_type``."""
        self.node._source_type = 'a'
        self.assertEqual('a', self.node.get_source_type())

    def test_get_incoming_weight(self):
        """get_incoming_weight returns ``_incoming_weight``."""
        self.node._incoming_weight = .3
        self.assertAlmostEqual(.3, self.node.get_incoming_weight())

    def test_get_existing_weight(self):
        """get_existing_weight returns ``_existing_weight``."""
        self.node._existing_weight = .3
        self.assertAlmostEqual(.3, self.node.get_existing_weight())

    def test_source_update_config(self):
        """
        source_update_config stores source type and both weights.

        It must reject an unknown source type ('e') and weights above 1.0,
        raising ValueError.
        """
        self.node.source_update_config('a', .3, .2)
        self.assertEqual('a', self.node._source_type)
        self.assertAlmostEqual(.3, self.node._incoming_weight)
        self.assertAlmostEqual(.2, self.node._existing_weight)

        # assertRaises replaces the deprecated failUnlessRaises alias,
        # which was removed in Python 3.12.
        self.assertRaises(
            ValueError, self.node.source_update_config, 'e', .3, .2)
        self.assertRaises(
            ValueError, self.node.source_update_config, 'a', 1.3, .2)
        self.assertRaises(
            ValueError, self.node.source_update_config, 'a', .3, 1.2)