def __init__(self, layers, activation_function, activation_args):
    """
    Base class providing the skeleton of a network.

    Validates the requested topology, initialises the edge and bias
    weights, and constructs the activation function.

    :param layers: sized sequence of per-layer node counts; the first entry
        becomes ``self.inputs`` and the last ``self.outputs``. At least 2
        entries are required and every entry must be > 0.
    :param activation_function: identifier passed to
        ``ActivationFunctions.construct``.
    :param activation_args: mapping of keyword arguments forwarded to
        ``ActivationFunctions.construct``.
    :raises ValueError: if fewer than 2 layers are given, any layer size is
        not > 0, or the activation function cannot be constructed.
    """
    self.layers = len(layers)
    if self.layers < 2:
        raise ValueError('Network must contain at least 2 layers, {count} given'.format(count = self.layers))

    self.layer_counts = layers
    for index, layer_size in enumerate(self.layer_counts):
        # Written as `not (x > 0)` rather than `x <= 0` so NaN sizes are rejected too.
        if not (layer_size > 0):
            raise ValueError('Network layers must contain at least 1 node, layer {bad_layer} specifies {bad_count}'.format(
                bad_layer = index, bad_count = layer_size))

    self.inputs = int(self.layer_counts[0])
    self.outputs = int(self.layer_counts[-1])

    # The delta lists are created before the update_* calls; presumably those
    # methods record into them (TODO confirm), so keep this ordering.
    self.edge_weights = GenerateDefaultEdgeWeights(self)
    self.edge_update_deltas = []
    self.update_edge_weights(self.edge_weights)

    self.bias_weights = GenerateDefaultBiasWeights(self)
    self.bias_update_deltas = []
    self.update_bias_weights(self.bias_weights)

    try:
        self.activator = ActivationFunctions.construct(activation_function, **activation_args)
    except Exception as err:
        # Python 2 has no implicit exception chaining, so append the underlying
        # error to the message instead of discarding it. `!r` stays ASCII-safe
        # even when the original message contains non-ASCII text.
        raise ValueError('Failed to construct activation function, \'{af}\' with args \'{args}\': {err!r}'.format(
            af = activation_function, args = str(activation_args), err = err))
def test_returns_same_size(self):
    """
    The sigmoidal activation must return an array with the same shape as
    its input, for every r x c input with 1 <= r, c < 30.
    """
    np.random.seed(0)
    elements = 30
    # The activation is built from constants only, so construct it once
    # rather than once per (r, c) pair. Assumes F.call keeps no per-call
    # state that could affect the result shape.
    F = ActivationFunctions.construct('sigmoidal', beta = 1)
    for r in range(1, elements):
        for c in range(1, elements):
            Q = np.random.rand(r, c)
            A = F.call(Q)
            # assertEqual reports both shapes on failure, unlike assertTrue(a == b).
            self.assertEqual(A.shape, Q.shape,
                             'shape mismatch for {r}x{c} input: got {got}'.format(r = r, c = c, got = A.shape))
#!/usr/bin/env python
"""
Print the activation functions registered with ActivationFunctions,
followed by each function's reference-table entry.

Output is identical under Python 2 and Python 3.
"""
import os
import sys

# Put the parent of this script's directory on sys.path so the `sanity.*`
# imports below resolve when the script is run directly.
whereami = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(whereami, '..'))

# NOTE(review): only ActivationFunctions is used below. The other imports are
# kept in case they have import-time side effects; confirm before removing.
# They must stay at module level because `import *` is illegal inside a function.
from sanity.networks.perceptron import *
from sanity.common.util.training_tools import GenerateTrainingData
from sanity.common.util.loggers import *
from sanity.common.lib.activation_functions import ActivationFunctions


def main():
    """Print the activation function map and the function reference table to stdout."""
    # Single-argument print(...) behaves the same as the Python 2 print
    # statement (a parenthesised expression) and as the Python 3 function.
    print("Available activation functions:")
    print("-------------------------------")
    for key, val in ActivationFunctions.get_function_map().items():
        print('{key:<15} : {val}'.format(key = key, val = val))
    print('')

    print('Function reference table')
    print("-------------------------------\n")
    for val in ActivationFunctions.get_function_reference().values():
        print('{wat}'.format(wat = val))


if __name__ == '__main__':
    main()