# This listing assumes the following names are already in scope (their
# definitions are not shown here): np (numpy), Node (base class), Parameter,
# NeuralNetwork (a module exposing the NeuralNetwork class defined below),
# rand_str, relu, relu_grad, softmax, cross_entropy.


class Mul(Node):
    """Matrix-multiply node: out = A.dot(B)."""

    def __init__(self, A, B):
        self.node_name = "Mul{}".format(rand_str())

        # Accept either a Parameter or an upstream node; unwrap to its output.
        if not isinstance(A, Parameter):
            A = A.get_output()
        self.A = A

        if not isinstance(B, Parameter):
            B = B.get_output()
        self.B = B

        NeuralNetwork.NeuralNetwork.add_node(self)
        NeuralNetwork.NeuralNetwork.add_param(self.A)
        NeuralNetwork.NeuralNetwork.add_param(self.B)

        self.out = Parameter(is_placeholder=True, name=self.node_name + "_out")

    def forward(self):
        res = self.A.get_data().dot(self.B.get_data())
        self.out.set_data_(res)
        return self.out

    def get_output(self):
        return self.out

    def backward(self):
        # For out = A.dot(B) with upstream gradient G:
        #   dL/dA = G.dot(B.T),  dL/dB = A.T.dot(G)
        inc_grad = self.out.get_grad()
        self.A.set_grad_(inc_grad.dot(self.B.get_data().T))
        self.B.set_grad_(self.A.get_data().T.dot(inc_grad))
        return

    def __str__(self):
        return self.node_name
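

# A minimal finite-difference check of Mul's backward rule. This is a sketch
# added for illustration; `_check_mul_gradients` is not part of the original
# project. It uses plain numpy, independent of the classes above.
def _check_mul_gradients():
    A = np.random.randn(3, 4)
    B = np.random.randn(4, 2)
    G = np.random.randn(3, 2)            # pretend upstream gradient dL/dout

    dA = G.dot(B.T)                      # the analytic rule used in backward()

    # For L = sum(G * (A @ B)), the true gradient is dL/dA = G @ B.T.
    eps, num_dA = 1e-6, np.zeros_like(A)
    for i in range(A.shape[0]):
        for j in range(A.shape[1]):
            Ap, Am = A.copy(), A.copy()
            Ap[i, j] += eps
            Am[i, j] -= eps
            num_dA[i, j] = ((G * Ap.dot(B)).sum()
                            - (G * Am.dot(B)).sum()) / (2 * eps)
    assert np.allclose(dA, num_dA, atol=1e-4)
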
class ReLU(Node):
    def __init__(self, A):
        self.node_name = "ReLU{}".format(rand_str())

        if not isinstance(A, Parameter):
            A = A.get_output()
        self.A = A

        NeuralNetwork.NeuralNetwork.add_node(self)
        NeuralNetwork.NeuralNetwork.add_param(self.A)

        self.out = Parameter(is_placeholder=True, name=self.node_name + "_out")

    def forward(self):
        res = relu(self.A.get_data())
        self.out.set_data_(res)
        return self.out

    def get_output(self):
        return self.out

    def backward(self):
        # ReLU passes the gradient through only where the input was positive.
        inc_grad = self.out.get_grad()
        local_grad = relu_grad(self.A.get_data())
        self.A.set_grad_(inc_grad * local_grad)
        return

    def __str__(self):
        return self.node_name
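

# The relu / relu_grad helpers used above are not shown in this listing; the
# definitions below are a plausible numpy sketch (an assumption, not a
# verified copy of the project's utils):
def relu(x):
    return np.maximum(x, 0.0)


def relu_grad(x):
    # ReLU subgradient: 1 where x > 0, else 0
    return (x > 0).astype(x.dtype)
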
class SoftmaxCrossEnt(Node):
    def __init__(self, logits=None, labels=None):
        assert logits is not None and labels is not None, \
            "Usage: SoftmaxCrossEnt(logits=.., labels=..)"

        self.node_name = "Loss{}".format(rand_str())

        if not isinstance(logits, Parameter):
            logits = logits.get_output()
        self.logits = logits

        if not isinstance(labels, Parameter):
            labels = labels.get_output()
        self.labels = labels

        NeuralNetwork.NeuralNetwork.add_node(self)
        NeuralNetwork.NeuralNetwork.add_param(self.logits)
        NeuralNetwork.NeuralNetwork.add_param(self.labels)

        self.out = Parameter(is_placeholder=True, name=self.node_name + "_out")
        self.pred = Parameter(is_placeholder=True,
                              name=self.node_name + "_pred")

    def forward(self):
        logits = self.logits.get_data()
        labels = self.labels.get_data()
        pred = softmax(logits)
        # Note: `pred` holds softmax probabilities; the helper's `logits`
        # keyword is a misnomer here.
        res = cross_entropy(logits=pred, labels=labels)
        self.out.set_data_(res)
        self.pred.set_data_(pred)
        return self.out

    def get_output(self):
        return self.out

    def backward(self):
        pred = self.pred.get_data()
        labels = self.labels.get_data()
        # The combined softmax + cross-entropy gradient wrt the logits is
        # (pred - labels), where pred = softmax(logits). The original code
        # used the raw logits here, which is incorrect. (The sign assumes
        # Parameter.update performs descent, i.e. data -= lr * grad.)
        self.logits.set_grad_(pred - labels)
        return

    def __str__(self):
        return self.node_name

    def get_pred(self):
        """Only for Softmax cross entropy"""
        return self.pred
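

# softmax and cross_entropy are likewise external helpers; below is a
# standard numpy sketch consistent with forward() above (an assumption, not
# the project's verified code):
def softmax(logits):
    # Subtract the row max for numerical stability before exponentiating.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def cross_entropy(logits=None, labels=None):
    # Despite the keyword name, forward() passes softmax *probabilities* here.
    probs = np.clip(logits, 1e-12, 1.0)
    return -(labels * np.log(probs)).sum(axis=1).mean()
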
class NeuralNetwork(object):
    """Do not modify these. Use the static functions associated"""
    node_list = []
    param_list = []

    def __init__(self, inp_shape, num_classes, hidden_layers, lr=0.01):
        self.inp_shape = inp_shape
        self.num_classes = num_classes
        self.hidden_layers = hidden_layers
        self.hidden_layers.append(num_classes)  # note: mutates the caller's list
        self.lr = lr

        # Reset both class-level registries so stale nodes/params from a
        # previous instance are not reused (the original cleared only params).
        NeuralNetwork.clear_node_list()
        NeuralNetwork.clear_param_list()
        self._build_network()
        self.node_list = NeuralNetwork.get_node_list()
        self.param_list = NeuralNetwork.get_param_list()
        assert self.node_list, "Node list is empty"
        assert self.param_list, "Parameter list is empty"

    def _build_network(self):
        self.X = Parameter(is_placeholder=True, name="X")
        self.Y = Parameter(is_placeholder=True, name="Y")

        prev_inp = self.X
        prev_shape = self.inp_shape
        for idx, num_units in enumerate(self.hidden_layers):
            # Uniform [0, 1) weight initialization, kept from the original;
            # a scaled scheme (e.g. He initialization) would generally train
            # better with ReLU.
            W = Parameter(np.random.random((prev_shape, num_units)),
                          requires_grad=True,
                          name="W" + str(idx + 1))
            B = Parameter(np.zeros((num_units, )),
                          requires_grad=True,
                          name="B" + str(idx + 1))
            if idx == len(self.hidden_layers) - 1:
                out = Node.Add(Node.Mul(prev_inp, W), B)
                out = Node.SoftmaxCrossEnt(out, self.Y)
            else:
                out = Node.ReLU(Node.Add(Node.Mul(prev_inp, W), B))
            prev_inp = out
            prev_shape = num_units
        self.loss = out

    def _forward(self):
        for node in self.node_list:
            node.forward()
        for param in self.param_list:
            # Debug dump of every parameter; this runs on every forward pass.
            print("Param: {}\n{}\n\n".format(param, param.get_data()))

    def _backward(self):
        for node in self.node_list[::-1]:
            node.backward()

    def _update_params(self, lr):
        for param in self.param_list:
            param.update(lr)

    def _optimize(self):
        self._backward()
        self._update_params(self.lr)

    def fit(self, x, y):
        self.X.set_data_(x)
        self.Y.set_data_(y)
        self._forward()
        self._optimize()
        loss, pred = self.loss.get_output(), self.loss.get_pred()
        return loss.get_data(), pred.get_data()

    def predict(self, x):
        self.X.set_data_(x)
        # _forward() returns nothing, so read predictions off the loss node.
        # (Caveat: the loss node's forward() still reads self.Y, so Y must
        # hold data of a compatible shape even at prediction time.)
        self._forward()
        return self.loss.get_pred().get_data()

    @staticmethod
    def get_node_list():
        return NeuralNetwork.node_list[:]

    @staticmethod
    def add_node(node):
        NeuralNetwork.node_list.append(node)

    @staticmethod
    def clear_node_list():
        NeuralNetwork.node_list = []

    @staticmethod
    def get_param_list():
        return NeuralNetwork.param_list[:]

    @staticmethod
    def add_param(node):
        NeuralNetwork.param_list.append(node)

    @staticmethod
    def clear_param_list():
        NeuralNetwork.param_list = []
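

# A minimal usage sketch. The random data and one-hot labels below are
# assumptions made for illustration; they are not part of the original
# project.
if __name__ == "__main__":
    num_feats, num_classes = 20, 3
    x = np.random.randn(64, num_feats)
    y = np.eye(num_classes)[np.random.randint(num_classes, size=64)]  # one-hot

    net = NeuralNetwork(inp_shape=num_feats, num_classes=num_classes,
                        hidden_layers=[32, 16], lr=0.01)
    for epoch in range(100):
        loss, pred = net.fit(x, y)
        if epoch % 10 == 0:
            print("epoch {}: loss {}".format(epoch, loss))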