class EvolinoNetwork(Module):
    """ Wrapper module around a small LSTM network, as used for Evolino.

        Topology built by the constructor:
          * a linear input layer of size indim + outdim; it receives the
            external input concatenated with the scaled previous output
            (the "backprojection"),
          * an LSTM hidden layer with hiddim cells,
          * a linear output layer of size outdim,
          * a bias unit connected to the hidden layer.

        Genome: a list with one entry per cell of the hidden layer. Each
        entry is a flat list that holds, for every connection ending in the
        hidden layer, the four parameters found at the flat indices
        cell_idx + k * dim (k = 0..3, dim = number of hidden cells).

        The hidden-to-output weights are not part of the genome; they are
        accessed through setOutputWeightMatrix / getOutputWeightMatrix.
    """

    def __init__(self, indim, outdim, hiddim=6):
        """ Builds the wrapped network.

            :arg indim: number of external input values per time step
            :arg outdim: number of output values per time step
            :key hiddim: number of cells in the hidden LSTM layer
        """
        Module.__init__(self, indim, outdim)

        self._network = Network()
        self._in_layer = LinearLayer(indim + outdim)
        self._hid_layer = LSTMLayer(hiddim)
        self._out_layer = LinearLayer(outdim)
        self._bias = BiasUnit()

        # The layers have to be registered, otherwise the network stays empty
        # and neither activation nor the connection lookup can work.
        self._network.addInputModule(self._in_layer)
        self._network.addModule(self._hid_layer)
        self._network.addModule(self._bias)
        self._network.addOutputModule(self._out_layer)

        self._hid_to_out_connection = FullConnection(self._hid_layer,
                                                     self._out_layer)
        self._in_to_hid_connection = FullConnection(self._in_layer,
                                                    self._hid_layer)
        self._network.addConnection(self._hid_to_out_connection)
        self._network.addConnection(self._in_to_hid_connection)
        self._network.addConnection(FullConnection(self._bias, self._hid_layer))

        self._network.sortModules()

        # Mirror of the wrapped network's buffer position.
        self.offset = self._network.offset
        # Scaling applied to the previous output before it is fed back in.
        self.backprojectionFactor = 0.01

    def reset(self):
        """ Resets the wrapped network and re-synchronises the offset. """
        self._network.reset()
        self.offset = self._network.offset

    def _washout(self, input, target, first_idx=None, last_idx=None):
        """ Drives the network through a sequence with teacher forcing.

            After every step the output buffer is overwritten with the
            target, so the next step is fed the (scaled) target instead of
            the network's own output.

            :arg input: sequence of input vectors, each of length indim
            :arg target: sequence of target vectors, each of length outdim
            :key first_idx: first index to process (default 0)
            :key last_idx: last index to process, inclusive
                           (default len(target) - 1)
            :return: array with the hidden layer's output for every
                     processed step
        """
        assert len(input) == len(target)
        if len(input) == 0:
            # Nothing to wash out; avoids indexing into an empty sequence.
            return array([])
        assert self.indim == len(input[0])
        assert self.outdim == len(target[0])

        if first_idx is None:
            first_idx = 0
        if last_idx is None:
            last_idx = len(target) - 1

        raw_outputs = []
        for i in range(first_idx, last_idx + 1):
            # Multiply into a new array; an in-place "*=" would scale the
            # row of the output buffer that _getLastOutput hands out.
            backprojection = self._getLastOutput() * self.backprojectionFactor
            full_inp = self._createFullInput(input[i], backprojection)
            self._activateNetwork(full_inp)
            raw_outputs.append(array(self._getRawOutput()))
            # Teacher forcing: pretend the network produced the target.
            self._setLastOutput(target[i])

        return array(raw_outputs)

    def _activateNetwork(self, input):
        """ Activates the wrapped network with an already assembled input
            (external input plus backprojection) and returns its output.
        """
        assert len(input) == self._network.indim
        output = self._network.activate(input)
        self.offset = self._network.offset
        return output

    def activate(self, input):
        """ Activates the network with an external input of length indim.

            The previous output, scaled by backprojectionFactor, is appended
            to the input before it is passed to the wrapped network.
        """
        assert len(input) == self.indim

        # New array on purpose, see the note in _washout.
        backprojection = self._getLastOutput() * self.backprojectionFactor
        full_inp = self._createFullInput(input, backprojection)
        return self._activateNetwork(full_inp)

    def calculateOutput(self, dataset, washout_calculation_ratio=(1, 2)):
        """ Runs the network over every sequence of a dataset.

            Each sequence is split into a washout part, processed with
            teacher forcing, and a calculation part, processed with
            activate(). Only the calculation part is collected.

            :arg dataset: object providing getNumSequences() and
                          getSequence(i); a sequence is indexed with [0] for
                          the inputs and [1] for the targets
            :key washout_calculation_ratio: pair giving the relative lengths
                          of the washout and the calculation part
            :return: tuple (inputs, outputs, targets) of the calculation
                     parts, concatenated over all sequences along axis 0;
                     all three are None if the dataset has no sequences
        """
        washout_calculation_ratio = array(washout_calculation_ratio, float)
        ratio = washout_calculation_ratio / washout_calculation_ratio.sum()

        collected_input = None
        collected_output = None
        collected_target = None
        for seq_idx in range(dataset.getNumSequences()):
            seq = dataset.getSequence(seq_idx)
            input = seq[0]
            target = seq[1]

            washout_steps = int(len(input) * ratio[0])

            washout_input = input[:washout_steps]
            washout_target = target[:washout_steps]
            calculation_target = target[washout_steps:]

            # Every sequence starts from a clean network state.
            self.reset()

            self._washout(washout_input, washout_target)

            # collect calculation data
            outputs = []
            inputs = []
            for inp in input[washout_steps:]:
                out = self.activate(inp)
                inputs.append(inp)
                outputs.append(out)

            # collect inputs, outputs and targets over all sequences
            if collected_input is not None:
                collected_input = append(collected_input, inputs, axis=0)
            else:
                collected_input = array(inputs)

            if collected_output is not None:
                collected_output = append(collected_output, outputs, axis=0)
            else:
                collected_output = array(outputs)

            if collected_target is not None:
                collected_target = append(collected_target,
                                          calculation_target, axis=0)
            else:
                collected_target = calculation_target

        return collected_input, collected_output, collected_target

    def _createFullInput(self, input, output):
        """ Concatenates external input and backprojected output.
            Without external inputs only the output is used.
        """
        if self.indim > 0:
            return append(input, output)
        else:
            return array(output)

    def _getLastOutput(self):
        """ Returns the output layer's most recent output, or zeros if the
            network has not been activated yet.
        """
        if self.offset == 0:
            return zeros(self.outdim)
        else:
            return self._out_layer.outputbuffer[self.offset - 1]

    def _setLastOutput(self, output):
        """ Overwrites the output layer's most recent output in place. """
        self._out_layer.outputbuffer[self.offset - 1][:] = output

    # ======================================================== Genome related ===

    def _validateGenomeLayer(self, layer):
        """ Validates the type and state of a layer: it has to be an
            LSTMLayer without peepholes.
        """
        assert isinstance(layer, LSTMLayer)
        assert not layer.peepholes

    def getGenome(self):
        """ Returns the Genome of the network.
            See class description for more details.
        """
        return self._getGenomeOfLayer(self._hid_layer)

    def setGenome(self, weights):
        """ Sets the Genome of the network.
            See class description for more details.
        """
        # _setGenomeOfLayer consumes its argument, so work on a copy.
        weights = deepcopy(weights)
        self._setGenomeOfLayer(self._hid_layer, weights)

    def _getGenomeOfLayer(self, layer):
        """ Returns the genome of a single layer: one flat weight list per
            cell, gathered from all connections ending in the layer.
        """
        self._validateGenomeLayer(layer)

        dim = layer.outdim
        layer_weights = []

        connections = self._getInputConnectionsOfLayer(layer)

        for cell_idx in range(dim):
            # todo: the evolino paper uses a different order of weights for the genotype of a lstm cell
            # NOTE(review): only the flat indices 0 .. 4*dim-1 of every
            # connection are read here - confirm that this covers all the
            # weights that are supposed to be part of the genome.
            cell_weights = []
            for c in connections:
                cell_weights += [c.params[cell_idx + k * dim]
                                 for k in range(4)]
            layer_weights.append(cell_weights)

        return layer_weights

    def _setGenomeOfLayer(self, layer, weights):
        """ Sets the genome of a single layer.

            The weights have to be laid out as returned by
            _getGenomeOfLayer. The passed lists are emptied in the process.
        """
        self._validateGenomeLayer(layer)

        dim = layer.outdim

        connections = self._getInputConnectionsOfLayer(layer)

        for cell_idx in range(dim):
            cell_weights = weights.pop(0)
            for c in connections:
                params = c.params
                for k in range(4):
                    params[cell_idx + k * dim] = cell_weights.pop(0)
            # every weight of the cell must have been consumed
            assert not len(cell_weights)

    # ============================================ Linear Regression related ===

    def setOutputWeightMatrix(self, W):
        """ Sets the weight matrix of the output layer's input connection.
            W is flattened before it is written into the parameters.
        """
        c = self._hid_to_out_connection
        c.params[:] = W.flatten()

    def getOutputWeightMatrix(self):
        """ Returns the weight matrix of the output layer's input connection,
            shaped (connection outdim, connection indim).
        """
        c = self._hid_to_out_connection
        p = c.getParameters()
        return reshape(p, (c.outdim, c.indim))

    def _getRawOutput(self):
        """ Returns the current output of the last hidden layer.
            This is needed for linear regression, which calculates
            the weight matrix W of the full connection between this layer
            and the output layer.
        """
        return copy(self._hid_layer.outputbuffer[self.offset - 1])

    # ====================================================== Topology Helper ===

    def _getInputConnectionsOfLayer(self, layer):
        """ Returns a list of all input connections for the layer.
            Raises NotImplementedError for anything but a FullConnection.
        """
        connections = []
        for c in sum(self._network.connections.values(), []):
            if c.outmod is layer:
                if not isinstance(c, FullConnection):
                    raise NotImplementedError(
                        "At the time there is only support for FullConnection")
                connections.append(c)
        return connections
class EvolinoNetwork(Module):
    """ Wrapper module around a small LSTM network, as used for Evolino.

        Topology built by the constructor:
          * a linear input layer of size indim + outdim; it receives the
            external input concatenated with the scaled previous output
            (the "backprojection"),
          * an LSTM hidden layer with hiddim cells,
          * a linear output layer of size outdim,
          * a bias unit connected to the hidden layer.

        Genome: a list with one entry per cell of the hidden layer. Each
        entry is a flat list that holds, for every connection ending in the
        hidden layer, the four parameters found at the flat indices
        cell_idx + k * dim (k = 0..3, dim = number of hidden cells).

        The hidden-to-output weights are not part of the genome; they are
        accessed through setOutputWeightMatrix / getOutputWeightMatrix.
    """

    def __init__(self, indim, outdim, hiddim=6):
        """ Builds the wrapped network.

            :arg indim: number of external input values per time step
            :arg outdim: number of output values per time step
            :key hiddim: number of cells in the hidden LSTM layer
        """
        Module.__init__(self, indim, outdim)

        self._network = Network()
        self._in_layer = LinearLayer(indim + outdim)
        self._hid_layer = LSTMLayer(hiddim)
        self._out_layer = LinearLayer(outdim)
        self._bias = BiasUnit()

        # The layers have to be registered, otherwise the network stays empty
        # and neither activation nor the connection lookup can work.
        self._network.addInputModule(self._in_layer)
        self._network.addModule(self._hid_layer)
        self._network.addModule(self._bias)
        self._network.addOutputModule(self._out_layer)

        self._hid_to_out_connection = FullConnection(self._hid_layer,
                                                     self._out_layer)
        self._in_to_hid_connection = FullConnection(self._in_layer,
                                                    self._hid_layer)
        self._network.addConnection(self._hid_to_out_connection)
        self._network.addConnection(self._in_to_hid_connection)
        self._network.addConnection(FullConnection(self._bias, self._hid_layer))

        self._network.sortModules()

        # Mirror of the wrapped network's time step counter.
        self.time = self._network.time
        # Scaling applied to the previous output before it is fed back in.
        self.backprojectionFactor = 0.01

    def reset(self):
        """ Resets the wrapped network and re-synchronises the time step. """
        self._network.reset()
        self.time = self._network.time

    def _washout(self, input, target, first_idx=None, last_idx=None):
        """ Drives the network through a sequence with teacher forcing.

            After every step the output buffer is overwritten with the
            target, so the next step is fed the (scaled) target instead of
            the network's own output.

            :arg input: sequence of input vectors, each of length indim
            :arg target: sequence of target vectors, each of length outdim
            :key first_idx: first index to process (default 0)
            :key last_idx: last index to process, inclusive
                           (default len(target) - 1)
            :return: array with the hidden layer's output for every
                     processed step
        """
        assert len(input) == len(target)
        if len(input) == 0:
            # Nothing to wash out; avoids indexing into an empty sequence.
            return array([])
        assert self.indim == len(input[0])
        assert self.outdim == len(target[0])

        if first_idx is None:
            first_idx = 0
        if last_idx is None:
            last_idx = len(target) - 1

        raw_outputs = []
        for i in range(first_idx, last_idx + 1):
            # Multiply into a new array; an in-place "*=" would scale the
            # row of the output buffer that _getLastOutput hands out.
            backprojection = self._getLastOutput() * self.backprojectionFactor
            full_inp = self._createFullInput(input[i], backprojection)
            self._activateNetwork(full_inp)
            raw_outputs.append(array(self._getRawOutput()))
            # Teacher forcing: pretend the network produced the target.
            self._setLastOutput(target[i])

        return array(raw_outputs)

    def _activateNetwork(self, input):
        """ Activates the wrapped network with an already assembled input
            (external input plus backprojection) and returns its output.
        """
        assert len(input) == self._network.indim
        output = self._network.activate(input)
        self.time = self._network.time
        return output

    def activate(self, input):
        """ Activates the network with an external input of length indim.

            The previous output, scaled by backprojectionFactor, is appended
            to the input before it is passed to the wrapped network.
        """
        assert len(input) == self.indim

        # New array on purpose, see the note in _washout.
        backprojection = self._getLastOutput() * self.backprojectionFactor
        full_inp = self._createFullInput(input, backprojection)
        return self._activateNetwork(full_inp)

    def calculateOutput(self, dataset, washout_calculation_ratio=(1, 2)):
        """ Runs the network over every sequence of a dataset.

            Each sequence is split into a washout part, processed with
            teacher forcing, and a calculation part, processed with
            activate(). Only the calculation part is collected.

            :arg dataset: object providing getNumSequences() and
                          getSequence(i); a sequence is indexed with [0] for
                          the inputs and [1] for the targets
            :key washout_calculation_ratio: pair giving the relative lengths
                          of the washout and the calculation part
            :return: tuple (inputs, outputs, targets) of the calculation
                     parts, concatenated over all sequences along axis 0;
                     all three are None if the dataset has no sequences
        """
        washout_calculation_ratio = array(washout_calculation_ratio, float)
        ratio = washout_calculation_ratio / washout_calculation_ratio.sum()

        collected_input = None
        collected_output = None
        collected_target = None
        for seq_idx in range(dataset.getNumSequences()):
            seq = dataset.getSequence(seq_idx)
            input = seq[0]
            target = seq[1]

            washout_steps = int(len(input) * ratio[0])

            washout_input = input[:washout_steps]
            washout_target = target[:washout_steps]
            calculation_target = target[washout_steps:]

            # Every sequence starts from a clean network state.
            self.reset()

            self._washout(washout_input, washout_target)

            # collect calculation data
            outputs = []
            inputs = []
            for inp in input[washout_steps:]:
                out = self.activate(inp)
                inputs.append(inp)
                outputs.append(out)

            # collect inputs, outputs and targets over all sequences
            if collected_input is not None:
                collected_input = append(collected_input, inputs, axis=0)
            else:
                collected_input = array(inputs)

            if collected_output is not None:
                collected_output = append(collected_output, outputs, axis=0)
            else:
                collected_output = array(outputs)

            if collected_target is not None:
                collected_target = append(collected_target,
                                          calculation_target, axis=0)
            else:
                collected_target = calculation_target

        return collected_input, collected_output, collected_target

    def _createFullInput(self, input, output):
        """ Concatenates external input and backprojected output.
            Without external inputs only the output is used.
        """
        if self.indim > 0:
            return append(input, output)
        else:
            return array(output)

    def _getLastOutput(self):
        """ Returns the output layer's most recent output, or zeros if the
            network has not been activated yet.
        """
        if self.time == 0:
            return zeros(self.outdim)
        else:
            return self._out_layer.outputbuffer[self.time - 1]

    def _setLastOutput(self, output):
        """ Overwrites the output layer's most recent output in place. """
        self._out_layer.outputbuffer[self.time - 1][:] = output

    # ======================================================== Genome related ===

    def _validateGenomeLayer(self, layer):
        """ Validates the type and state of a layer: it has to be an
            LSTMLayer without peepholes.
        """
        assert isinstance(layer, LSTMLayer)
        assert not layer.peepholes

    def getGenome(self):
        """ Returns the Genome of the network.
            See class description for more details.
        """
        return self._getGenomeOfLayer(self._hid_layer)

    def setGenome(self, weights):
        """ Sets the Genome of the network.
            See class description for more details.
        """
        # _setGenomeOfLayer consumes its argument, so work on a copy.
        weights = deepcopy(weights)
        self._setGenomeOfLayer(self._hid_layer, weights)

    def _getGenomeOfLayer(self, layer):
        """ Returns the genome of a single layer: one flat weight list per
            cell, gathered from all connections ending in the layer.
        """
        self._validateGenomeLayer(layer)

        dim = layer.outdim
        layer_weights = []

        connections = self._getInputConnectionsOfLayer(layer)

        for cell_idx in range(dim):
            # todo: the evolino paper uses a different order of weights for the genotype of a lstm cell
            # NOTE(review): only the flat indices 0 .. 4*dim-1 of every
            # connection are read here - confirm that this covers all the
            # weights that are supposed to be part of the genome.
            cell_weights = []
            for c in connections:
                cell_weights += [c.params[cell_idx + k * dim]
                                 for k in range(4)]
            layer_weights.append(cell_weights)

        return layer_weights

    def _setGenomeOfLayer(self, layer, weights):
        """ Sets the genome of a single layer.

            The weights have to be laid out as returned by
            _getGenomeOfLayer. The passed lists are emptied in the process.
        """
        self._validateGenomeLayer(layer)

        dim = layer.outdim

        connections = self._getInputConnectionsOfLayer(layer)

        for cell_idx in range(dim):
            cell_weights = weights.pop(0)
            for c in connections:
                params = c.params
                for k in range(4):
                    params[cell_idx + k * dim] = cell_weights.pop(0)
            # every weight of the cell must have been consumed
            assert not len(cell_weights)

    # ============================================ Linear Regression related ===

    def setOutputWeightMatrix(self, W):
        """ Sets the weight matrix of the output layer's input connection.
            W is flattened before it is written into the parameters.
        """
        c = self._hid_to_out_connection
        c.params[:] = W.flatten()

    def getOutputWeightMatrix(self):
        """ Returns the weight matrix of the output layer's input connection,
            shaped (connection outdim, connection indim).
        """
        c = self._hid_to_out_connection
        p = c.getParameters()
        return reshape(p, (c.outdim, c.indim))

    def _getRawOutput(self):
        """ Returns the current output of the last hidden layer.
            This is needed for linear regression, which calculates
            the weight matrix W of the full connection between this layer
            and the output layer.
        """
        return copy(self._hid_layer.outputbuffer[self.time - 1])

    # ====================================================== Topology Helper ===

    def _getInputConnectionsOfLayer(self, layer):
        """ Returns a list of all input connections for the layer.
            Raises NotImplementedError for anything but a FullConnection.
        """
        connections = []
        for c in sum(self._network.connections.values(), []):
            if c.outmod is layer:
                if not isinstance(c, FullConnection):
                    raise NotImplementedError("At the time there is only support for FullConnection")
                connections.append(c)
        return connections