def __init__(self):
        """Create the convolution, pooling and fully-connected layers of G_M5.

        The input width of ``fc1`` is obtained by tracing a 1x32x32 input
        through conv(1->5) -> pool -> conv(5->6) -> pool with the dimension
        helpers. This presumably mirrors the layer order used in forward();
        NOTE(review): confirm, since forward() is not visible here.
        """
        super(G_M5, self).__init__()
        self.conv1 = nn.Conv2d(1, 5, 5)
        self.conv2 = nn.Conv2d(5, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.flatten = Flatten()
        self.relu = nn.ReLU()

        # Each stage is a 5x5 conv (stride 1, pad 0) followed by a 2x2 pool.
        # conv_dimensions(c_in, h_in, w_in, c_out, stride, pad, k_height, k_width)
        in_channels, height, width = 1, 32, 32
        for out_channels in (5, 6):
            _, height, width = conv_dimensions(
                in_channels, height, width, out_channels, 1, 0, 5, 5)
            channels, height, width = pool_dimensions(
                out_channels, height, width, 2)
            in_channels = out_channels

        self.fc1 = nn.Linear(channels * height * width, 50)
        self.fc2 = nn.Linear(50, 10)
    def __init__(self):
        """Create the single conv stage and the linear classifier of FC_M3.

        The input width of ``fc`` is found by passing a 3x32x32 input shape
        through the conv (3->3 channels, 5x5 kernel) and the 2x2 pool with
        the dimension helpers.
        """
        super(FC_M3, self).__init__()
        self.conv = nn.Conv2d(3, 3, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.flatten = Flatten()
        self.relu = nn.ReLU()

        # conv_dimensions(c_in, h_in, w_in, c_out, stride, pad, k_height, k_width)
        _, conv_height, conv_width = conv_dimensions(3, 32, 32, 3, 1, 0, 5, 5)
        pooled_c, pooled_h, pooled_w = pool_dimensions(
            3, conv_height, conv_width, 2)
        n_features = pooled_c * pooled_h * pooled_w

        self.fc = nn.Linear(n_features, 10)
    def __init__(self, verbose=True):
        """Create the G_M15 layers and optionally print a memory-usage report.

        Args:
            verbose: When True (the default, matching the previous behavior),
                print per-layer weight counts (doubled for gradient storage)
                and the element count of each intermediate activation for a
                1x32x32 input. Pass False to construct the model silently.
        """
        super(G_M15, self).__init__()
        self.conv1 = nn.Conv2d(1, 30, 5)
        self.conv2 = nn.Conv2d(30, 60, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.flatten = Flatten()
        self.relu = nn.ReLU()

        # Width of the hidden fully-connected layer (fc1 output, fc2 input).
        hidden = 500

        def report(*args):
            # One switch for every diagnostic line printed below.
            if verbose:
                print(*args)

        # Calculating total memory usage:
        report(
            "Model weights (units are weights, where one weight is 4 bytes):")
        report(
            "Note: These numbers are already doubled to account for gradient space"
        )
        # Counts come from the actual weight tensors so they cannot drift from
        # the layer definitions above.
        # NOTE(review): bias terms are not included in these counts.
        report(self.conv1.weight.numel() * 2)
        report(self.conv2.weight.numel() * 2)

        report("\nIntermediate Layers:")
        # Calculate the output size of the Flatten Layer
        # conv_dimensions(c_in, h_in, w_in, c_out, stride, pad, k_height, k_width)
        report(1 * 32 * 32)  # input image: 1 channel, 32x32
        c_out, h_out, w_out = conv_dimensions(1, 32, 32, 30, 1, 0, 5, 5)
        report(c_out * h_out * w_out)
        c_out, h_out, w_out = pool_dimensions(30, h_out, w_out, 2)
        report(c_out * h_out * w_out)
        c_out, h_out, w_out = conv_dimensions(30, h_out, w_out, 60, 1, 0, 5, 5)
        report(c_out * h_out * w_out)
        c_out, h_out, w_out = pool_dimensions(60, h_out, w_out, 2)
        report(c_out * h_out * w_out)
        flatten_size = c_out * h_out * w_out
        report(hidden)  # fc1 output activations

        report("\nMore model weights:")
        self.fc1 = nn.Linear(flatten_size, hidden)
        report(self.fc1.weight.numel() * 2)
        self.fc2 = nn.Linear(hidden, 10)
        report(self.fc2.weight.numel() * 2)