class ReLU(Node):
    """Element-wise ReLU activation node: out = max(0, A)."""

    def __init__(self, A):
        # Unique name so this node and its output can be identified in the graph.
        self.node_name = "ReLU{}".format(rand_str())

        # Accept either a Parameter or an upstream Node; for a Node, wire in
        # its output Parameter instead.
        if not isinstance(A, Parameter):
            A = A.get_output()
        self.A = A

        # Register the node and its input with the global network so the
        # forward/backward passes can be scheduled over the whole graph.
        NeuralNetwork.NeuralNetwork.add_node(self)
        NeuralNetwork.NeuralNetwork.add_param(self.A)

        # Placeholder Parameter that will hold this node's output after forward().
        self.out = Parameter(is_placeholder=True, name=self.node_name + "_out")

    def forward(self):
        # Apply ReLU element-wise to the input and store it in the output slot.
        res = relu(self.A.get_data())
        self.out.set_data_(res)
        return self.out

    def get_output(self):
        return self.out

    def backward(self):
        # Chain rule: dL/dA = dL/dout * relu'(A), applied element-wise.
        inc_grad = self.out.get_grad()
        local_grad = relu_grad(self.A.get_data())
        self.A.set_grad_(inc_grad * local_grad)

    def __str__(self):
        return self.node_name
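
# The relu / relu_grad helpers used above are defined elsewhere in the
# project. A minimal NumPy sketch consistent with the element-wise math in
# forward()/backward() above (an assumption, not the project's definition):
#
#     import numpy as np
#
#     def relu(x):
#         return np.maximum(0, x)         # zero out negative entries
#
#     def relu_grad(x):
#         return (x > 0).astype(x.dtype)  # 1 where x > 0, else 0
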

class Mul(Node):
    """Matrix-multiplication node: out = A.dot(B)."""

    def __init__(self, A, B):
        self.node_name = "Mul{}".format(rand_str())

        # Accept either Parameters or upstream Nodes for both operands.
        if not isinstance(A, Parameter):
            A = A.get_output()
        self.A = A

        if not isinstance(B, Parameter):
            B = B.get_output()
        self.B = B

        # Register the node and both of its inputs with the global network.
        NeuralNetwork.NeuralNetwork.add_node(self)
        NeuralNetwork.NeuralNetwork.add_param(self.A)
        NeuralNetwork.NeuralNetwork.add_param(self.B)

        # Placeholder Parameter that will hold this node's output after forward().
        self.out = Parameter(is_placeholder=True, name=self.node_name + "_out")

    def forward(self):
        # Matrix product of the two inputs, stored in the output slot.
        res = self.A.get_data().dot(self.B.get_data())
        self.out.set_data_(res)
        return self.out

    def get_output(self):
        return self.out

    def backward(self):
        # For out = A.dot(B): dL/dA = dL/dout . B^T and dL/dB = A^T . dL/dout.
        inc_grad = self.out.get_grad()
        self.A.set_grad_(inc_grad.dot(self.B.get_data().T))
        self.B.set_grad_(self.A.get_data().T.dot(inc_grad))

    def __str__(self):
        return self.node_name
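

if __name__ == "__main__":
    # Minimal usage sketch wiring a Mul into a ReLU. Assumptions: Parameter
    # accepts the is_placeholder/name keywords seen above and is filled via
    # set_data_(); the gradient of the final output is seeded with set_grad_().
    # Check the Parameter class before treating this as the real API.
    import numpy as np

    X = Parameter(is_placeholder=False, name="X")
    W = Parameter(is_placeholder=False, name="W")
    X.set_data_(np.random.randn(4, 3))
    W.set_data_(np.random.randn(3, 2))

    m = Mul(X, W)   # out = X.dot(W), shape (4, 2)
    r = ReLU(m)     # element-wise ReLU over the product

    # Forward pass node by node (Mul first, since ReLU reads its output).
    m.forward()
    out = r.forward()

    # Seed the upstream gradient, then run backward in reverse order.
    out.set_grad_(np.ones_like(out.get_data()))
    r.backward()    # fills the gradient of Mul's output
    m.backward()    # then the gradients of X and W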