Example #1
0
def main():
    """Verify one hard-coded test case against one hard-coded network.

    Reads the spec file (first line: true label, remaining lines: one pixel
    value each), takes epsilon from the last ``_``-separated token of the spec
    file name (without extension), builds the matching architecture, loads its
    weights from ``../mnist_nets/<net>.pt``, checks that the network classifies
    the unperturbed input correctly and prints whether ``analyze`` verified it.

    Raises:
        ValueError: if the network name is not a known architecture.
    """
    import os

    net = "conv1"
    spec = "../test_cases/conv1/img0_0.00200.txt"
    net_name = net

    with open(spec, 'r') as f:
        # strip() rather than line[:-1]: slicing drops the last character of
        # the final value when the file has no trailing newline.
        lines = [line.strip() for line in f if line.strip()]
    true_label = int(lines[0])
    pixel_values = [float(line) for line in lines[1:]]
    # e.g. "img0_0.00200.txt" -> "0.00200"; basename/splitext work for any
    # extension length and for the platform's path separator.
    stem = os.path.splitext(os.path.basename(spec))[0]
    eps = float(stem.split('_')[-1])

    # Hidden/output layer widths of the fully connected networks.
    fc_layers = {
        'fc1': [100, 10],
        'fc2': [50, 50, 10],
        'fc3': [100, 100, 10],
        'fc4': [100, 100, 100, 10],
        'fc5': [400, 200, 100, 100, 10],
    }
    # (conv layer specs, fc layer widths) passed straight to Conv; the 4-tuples
    # are presumably (channels, kernel, stride, padding) — confirm against Conv.
    conv_layers = {
        'conv1': ([(32, 4, 2, 1)], [100, 10]),
        'conv2': ([(32, 4, 2, 1), (64, 4, 2, 1)], [100, 10]),
        'conv3': ([(32, 3, 1, 1), (32, 4, 2, 1), (64, 4, 2, 1)], [150, 10]),
        'conv4': ([(32, 4, 2, 1), (64, 4, 2, 1)], [100, 100, 10]),
        'conv5': ([(16, 3, 1, 1), (32, 4, 2, 1), (64, 4, 2, 1)], [100, 100, 10]),
    }

    if net_name in fc_layers:
        net = FullyConnected(DEVICE, INPUT_SIZE, fc_layers[net_name]).to(DEVICE)
    elif net_name in conv_layers:
        convs, fcs = conv_layers[net_name]
        net = Conv(DEVICE, INPUT_SIZE, convs, fcs, 10).to(DEVICE)
    else:
        # Fail fast instead of a later AttributeError on a plain string.
        raise ValueError('Unknown network: %s' % net_name)

    net.load_state_dict(
        torch.load('../mnist_nets/%s.pt' % net_name,
                   map_location=torch.device(DEVICE)))

    inputs = torch.FloatTensor(pixel_values).view(1, 1, INPUT_SIZE,
                                                  INPUT_SIZE).to(DEVICE)
    outs = net(inputs)
    pred_label = outs.max(dim=1)[1].item()
    # Verification only makes sense for a correctly classified input.
    assert pred_label == true_label

    if analyze(net, inputs, eps, true_label):
        print('verified')
    else:
        print('not verified')
def main(net, spec, verbose=True):
    """Verify a single test case against the named network.

    Args:
        net: architecture name ('fc1'..'fc7', 'conv1'..'conv3'); also selects
            the weight file ``../mnist_nets/<net>.pt``.
        spec: path to the test case. The first line is the true label, the
            remaining lines hold one pixel value each, and epsilon is the last
            ``_``-separated token of the file name (without extension).
        verbose: print the result when true.

    Returns:
        'verified' or 'not verified'.

    Raises:
        ValueError: if ``net`` is not a known architecture.
    """
    import os

    with open(spec, 'r') as f:
        # strip() rather than line[:-1]: slicing drops the last character of
        # the final value when the file has no trailing newline.
        lines = [line.strip() for line in f if line.strip()]
    true_label = int(lines[0])
    pixel_values = [float(line) for line in lines[1:]]
    stem = os.path.splitext(os.path.basename(spec))[0]
    eps = float(stem.split('_')[-1])

    # Hidden/output layer widths of the fully connected networks.
    fc_layers = {
        'fc1': [50, 10],
        'fc2': [100, 50, 10],
        'fc3': [100, 100, 10],
        'fc4': [100, 100, 50, 10],
        'fc5': [100, 100, 100, 10],
        'fc6': [100, 100, 100, 100, 10],
        'fc7': [100, 100, 100, 100, 100, 10],
    }
    # (conv layer specs, fc layer widths) passed straight to Conv; the 4-tuples
    # are presumably (channels, kernel, stride, padding) — confirm against Conv.
    conv_layers = {
        'conv1': ([(16, 3, 2, 1)], [100, 10]),
        'conv2': ([(16, 4, 2, 1), (32, 4, 2, 1)], [100, 10]),
        'conv3': ([(16, 4, 2, 1), (64, 4, 2, 1)], [100, 100, 10]),
    }

    # `model` rather than `nn`, which would shadow the usual torch.nn alias.
    if net in fc_layers:
        model = FullyConnected(DEVICE, INPUT_SIZE, fc_layers[net]).to(DEVICE)
    elif net in conv_layers:
        convs, fcs = conv_layers[net]
        model = Conv(DEVICE, INPUT_SIZE, convs, fcs, 10).to(DEVICE)
    else:
        # An explicit raise survives `python -O`, unlike `assert False`.
        raise ValueError('Unknown network: %s' % net)

    model.load_state_dict(
        torch.load('../mnist_nets/%s.pt' % net,
                   map_location=torch.device(DEVICE)))

    inputs = torch.FloatTensor(pixel_values).view(1, 1, INPUT_SIZE,
                                                  INPUT_SIZE).to(DEVICE)
    outs = model(inputs)
    pred_label = outs.max(dim=1)[1].item()
    # Verification only makes sense for a correctly classified input.
    assert pred_label == true_label

    res = analyze(net, model, inputs, eps, true_label)
    out = "verified" if res else 'not verified'
    if verbose:
        print(out)
    return out
Example #3
0
def main():
    """Verify the test case named by the module-level ``args`` namespace.

    Relies on ``args.spec`` (test-case path) and ``args.net`` (architecture
    name), which are not defined in this function. The first line of the spec
    file is the true label and the remaining lines hold one pixel value each.
    Epsilon is the last ``_``-separated token of the file name, without its
    extension. Weights are loaded from ``../mnist_nets/<net>.pt``. Prints
    "verified" or "not verified".

    Raises:
        ValueError: if ``args.net`` is not a known architecture.
    """
    import os

    with open(args.spec, "r") as f:
        # strip() rather than line[:-1]: slicing drops the last character of
        # the final value when the file has no trailing newline.
        lines = [line.strip() for line in f if line.strip()]
    true_label = int(lines[0])
    pixel_values = [float(line) for line in lines[1:]]
    stem = os.path.splitext(os.path.basename(args.spec))[0]
    eps = float(stem.split("_")[-1])

    # Hidden/output layer widths of the fully connected networks.
    fc_layers = {
        "fc1": [100, 10],
        "fc2": [50, 50, 10],
        "fc3": [100, 100, 10],
        "fc4": [100, 100, 100, 10],
        "fc5": [400, 200, 100, 100, 10],
    }
    # (conv layer specs, fc layer widths) passed straight to Conv; the 4-tuples
    # are presumably (channels, kernel, stride, padding) — confirm against Conv.
    conv_layers = {
        "conv1": ([(32, 4, 2, 1)], [100, 10]),
        "conv2": ([(32, 4, 2, 1), (64, 4, 2, 1)], [100, 10]),
        "conv3": ([(32, 3, 1, 1), (32, 4, 2, 1), (64, 4, 2, 1)], [150, 10]),
        "conv4": ([(32, 4, 2, 1), (64, 4, 2, 1)], [100, 100, 10]),
        "conv5": ([(16, 3, 1, 1), (32, 4, 2, 1), (64, 4, 2, 1)], [100, 100, 10]),
    }

    if args.net in fc_layers:
        net = FullyConnected(DEVICE, INPUT_SIZE, fc_layers[args.net]).to(DEVICE)
    elif args.net in conv_layers:
        convs, fcs = conv_layers[args.net]
        net = Conv(DEVICE, INPUT_SIZE, convs, fcs, 10).to(DEVICE)
    else:
        # Without this, `net` would be unbound and fail with a NameError.
        raise ValueError("Unknown network: %s" % args.net)

    net.load_state_dict(
        torch.load("../mnist_nets/%s.pt" % args.net,
                   map_location=torch.device(DEVICE)))

    inputs = torch.FloatTensor(pixel_values).view(1, 1, INPUT_SIZE,
                                                  INPUT_SIZE).to(DEVICE)
    outs = net(inputs)
    pred_label = outs.max(dim=1)[1].item()
    # Verification only makes sense for a correctly classified input.
    assert pred_label == true_label

    if analyze(net, inputs, eps, true_label):
        print("verified")
    else:
        print("not verified")
Example #4
0
def main():
    """Command-line entry point: verify one test case against one network.

    Command-line arguments:

    * ``--net`` selects the architecture and the weight file
      ``./mnist_nets/<net>.pt``.
    * ``--spec`` is the test-case file. Its first line is the true label and
      the remaining lines hold one pixel value each. Epsilon is the last
      ``_``-separated token of the file name, without its extension.

    Prints 'verified' or 'not verified'.
    """
    import os

    # Hidden/output layer widths of the fully connected networks.
    fc_layers = {
        'fc1': [50, 10],
        'fc2': [100, 50, 10],
        'fc3': [100, 100, 10],
        'fc4': [100, 100, 50, 10],
        'fc5': [100, 100, 100, 10],
        'fc6': [100, 100, 100, 100, 10],
        'fc7': [100, 100, 100, 100, 100, 10],
    }
    # (conv layer specs, fc layer widths) passed straight to Conv; the 4-tuples
    # are presumably (channels, kernel, stride, padding) — confirm against Conv.
    conv_layers = {
        'conv1': ([(16, 3, 2, 1)], [100, 10]),
        'conv2': ([(16, 4, 2, 1), (32, 4, 2, 1)], [100, 10]),
        'conv3': ([(16, 4, 2, 1), (64, 4, 2, 1)], [100, 100, 10]),
    }

    parser = argparse.ArgumentParser(description='Neural network verification using DeepZ relaxation')
    parser.add_argument('--net',
                        type=str,
                        # Derived from the tables so the CLI and builder cannot drift apart.
                        choices=list(fc_layers) + list(conv_layers),
                        required=True,
                        help='Neural network architecture which is supposed to be verified.')
    parser.add_argument('--spec', type=str, required=True, help='Test case to verify.')
    args = parser.parse_args()

    with open(args.spec, 'r') as f:
        # strip() rather than line[:-1]: slicing drops the last character of
        # the final value when the file has no trailing newline.
        lines = [line.strip() for line in f if line.strip()]
    true_label = int(lines[0])
    pixel_values = [float(line) for line in lines[1:]]
    stem = os.path.splitext(os.path.basename(args.spec))[0]
    eps = float(stem.split('_')[-1])

    # argparse `choices` guarantees args.net is a key of exactly one table.
    if args.net in fc_layers:
        net = FullyConnected(DEVICE, INPUT_SIZE, fc_layers[args.net]).to(DEVICE)
    else:
        convs, fcs = conv_layers[args.net]
        net = Conv(DEVICE, INPUT_SIZE, convs, fcs, 10).to(DEVICE)

    net.load_state_dict(torch.load('./mnist_nets/%s.pt' % args.net, map_location=torch.device(DEVICE)))

    inputs = torch.FloatTensor(pixel_values).view(1, 1, INPUT_SIZE, INPUT_SIZE).to(DEVICE)
    outs = net(inputs)
    pred_label = outs.max(dim=1)[1].item()
    # Verification only makes sense for a correctly classified input.
    assert pred_label == true_label

    if analyze(net, inputs, eps, true_label):
        print('verified')
    else:
        print('not verified')
def core_analysis(net_str, true_label, pixel_values, eps, VERBOSE=0):
    """Build the named network, check its prediction and run the verifier.

    Args:
        net_str: architecture name ('fc1'..'fc5', 'conv1'..'conv5'); also selects
            the weight file ``../mnist_nets/<net_str>.pt``, resolved relative to
            this source file's directory.
        true_label: expected class of the unperturbed input.
        pixel_values: flat sequence of INPUT_SIZE * INPUT_SIZE pixel values.
        eps: perturbation radius handed to ``analyze``.
        VERBOSE: forwarded to ``analyze``. When > 0, also print the elapsed
            time and run the ``runPGD`` adversarial search.

    Returns:
        ``(verified, elapsed_seconds)``. The elapsed time covers everything up
        to and including the analysis, but not the optional PGD search.

    Raises:
        ValueError: if ``net_str`` is not a known architecture.
    """
    start = time.time()

    # Hidden/output layer widths of the fully connected networks.
    fc_layers = {
        'fc1': [100, 10],
        'fc2': [50, 50, 10],
        'fc3': [100, 100, 10],
        'fc4': [100, 100, 100, 10],
        'fc5': [400, 200, 100, 100, 10],
    }
    # (conv layer specs, fc layer widths) passed straight to Conv; the 4-tuples
    # are presumably (channels, kernel, stride, padding) — confirm against Conv.
    conv_layers = {
        'conv1': ([(32, 4, 2, 1)], [100, 10]),
        'conv2': ([(32, 4, 2, 1), (64, 4, 2, 1)], [100, 10]),
        'conv3': ([(32, 3, 1, 1), (32, 4, 2, 1), (64, 4, 2, 1)], [150, 10]),
        'conv4': ([(32, 4, 2, 1), (64, 4, 2, 1)], [100, 100, 10]),
        'conv5': ([(16, 3, 1, 1), (32, 4, 2, 1), (64, 4, 2, 1)], [100, 100, 10]),
    }

    if net_str in fc_layers:
        net = FullyConnected(DEVICE, INPUT_SIZE, fc_layers[net_str]).to(DEVICE)
    elif net_str in conv_layers:
        convs, fcs = conv_layers[net_str]
        net = Conv(DEVICE, INPUT_SIZE, convs, fcs, 10).to(DEVICE)
    else:
        # Without this, `net` would be unbound and fail with a NameError.
        raise ValueError('Unknown network: %s' % net_str)

    net.load_state_dict(torch.load(os.path.join(os.path.dirname(__file__), '../mnist_nets/%s.pt' % net_str), map_location=torch.device(DEVICE)))

    inputs = torch.FloatTensor(np.array(pixel_values).reshape((1, 1, INPUT_SIZE, INPUT_SIZE))).to(DEVICE)
    outs = net(inputs)
    pred_label = outs.max(dim=1)[1].item()
    # Verification only makes sense for a correctly classified input.
    assert pred_label == true_label

    verified = analyze(net, inputs, eps, true_label, verbose=VERBOSE)

    if verified:
        print('verified')
    else:
        print('not verified')

    end = time.time()

    if VERBOSE > 0:
        print("time passed: %f s" % (end - start))

        # TODO: remove adversarial example search
        Adv_exmp_found = runPGD(net, inputs, eps, true_label)
        print("Adverserial example found : %s" % Adv_exmp_found)

    # end - start (not start - end) so the reported duration is positive.
    return verified, end - start
Example #6
0
def main():
    """Create the target models on the GPU, inspect them, and save their weights under ``targets/``."""
    # Linear and Conv: print a layer summary for their input shape, then save.
    summarized = (
        (Linear, (128, ), 'targets/linear.pth'),
        (Conv, (3, 256, 256), 'targets/conv.pth'),
    )
    for model_cls, input_shape, out_path in summarized:
        model = model_cls().cuda()
        summary(model, input_shape)
        torch.save(model.state_dict(), out_path)

    # RNN: dump its parameters and the output shape of one random forward pass.
    rnn = RNN().cuda()
    print(rnn.state_dict())
    sample_batch = torch.randn(16, 32, 256).cuda()
    print(rnn(sample_batch).shape)
    torch.save(rnn.state_dict(), 'targets/rnn.pth')
def main():
    """Command-line entry point: verify one test case with the ``*Z`` network.

    Command-line arguments:

    * ``--net`` selects the architecture.
    * ``--spec`` is the test-case file. Its first line is the true label and
      the remaining lines hold one pixel value each. Epsilon is the last
      ``_``-separated token of the file name, without its extension.

    Steps:

    1. Build the plain network and its ``*Z`` counterpart with identical
       layer configuration (presumably the DeepZ-relaxed variant; confirm).
    2. Load the same weights from ``../mnist_nets/<net>.pt`` into both.
    3. Check the plain network's prediction on the unperturbed input.
    4. Print whether ``analyze`` verified ``net_z``.
    """
    import os

    # Hidden/output layer widths of the fully connected networks.
    fc_layers = {
        'fc1': [100, 10],
        'fc2': [50, 50, 10],
        'fc3': [100, 100, 10],
        'fc4': [100, 100, 100, 10],
        'fc5': [400, 200, 100, 100, 10],
    }
    # (conv layer specs, fc layer widths) passed straight to Conv/ConvZ; the
    # 4-tuples are presumably (channels, kernel, stride, padding) — confirm.
    conv_layers = {
        'conv1': ([(32, 4, 2, 1)], [100, 10]),
        'conv2': ([(32, 4, 2, 1), (64, 4, 2, 1)], [100, 10]),
        'conv3': ([(32, 3, 1, 1), (32, 4, 2, 1), (64, 4, 2, 1)], [150, 10]),
        'conv4': ([(32, 4, 2, 1), (64, 4, 2, 1)], [100, 100, 10]),
        'conv5': ([(16, 3, 1, 1), (32, 4, 2, 1), (64, 4, 2, 1)], [100, 100, 10]),
    }

    parser = argparse.ArgumentParser(
        description='Neural network verification using DeepZ relaxation')
    parser.add_argument('--net',
                        type=str,
                        # Derived from the tables so the CLI and builder cannot drift apart.
                        choices=list(fc_layers) + list(conv_layers),
                        required=True,
                        help='Neural network to verify.')
    parser.add_argument('--spec',
                        type=str,
                        required=True,
                        help='Test case to verify.')
    args = parser.parse_args()

    with open(args.spec, 'r') as f:
        # strip() rather than line[:-1]: slicing drops the last character of
        # the final value when the file has no trailing newline.
        lines = [line.strip() for line in f if line.strip()]
    true_label = int(lines[0])
    pixel_values = [float(line) for line in lines[1:]]
    stem = os.path.splitext(os.path.basename(args.spec))[0]
    eps = float(stem.split('_')[-1])

    # argparse `choices` guarantees args.net is a key of exactly one table.
    # Each constructor gets its own list copies, as with the original literals.
    if args.net in fc_layers:
        layers = fc_layers[args.net]
        net = FullyConnected(DEVICE, INPUT_SIZE, list(layers)).to(DEVICE)
        net_z = FullyConnectedZ(DEVICE, INPUT_SIZE, list(layers)).to(DEVICE)
    else:
        convs, fcs = conv_layers[args.net]
        net = Conv(DEVICE, INPUT_SIZE, list(convs), list(fcs), 10).to(DEVICE)
        net_z = ConvZ(DEVICE, INPUT_SIZE, list(convs), list(fcs), 10).to(DEVICE)

    # Load the checkpoint once; load_state_dict copies the values into each
    # module's own parameters, so both models can share the loaded dict.
    state_dict = torch.load('../mnist_nets/%s.pt' % args.net,
                            map_location=torch.device(DEVICE))
    net.load_state_dict(state_dict)
    net_z.load_state_dict(state_dict)

    inputs = torch.FloatTensor(pixel_values).view(1, 1, INPUT_SIZE,
                                                  INPUT_SIZE).to(DEVICE)

    # Per-pixel room to move down / up by at most eps while staying in [0, 1].
    low_params = [min(eps, p) for p in pixel_values]
    high_params = [min(eps, 1 - p) for p in pixel_values]

    outs = net(inputs)
    pred_label = outs.max(dim=1)[1].item()
    # Verification only makes sense for a correctly classified input.
    assert pred_label == true_label

    lr = 0.1 if args.net.startswith('fc') else 0.01

    if eps == 0:
        # No perturbation: the prediction check above already settles it.
        print('verified')
    elif analyze(net_z, inputs, low_params, high_params, true_label, lr, eps):
        print('verified')
    else:
        print('not verified')
Example #8
0
        # NOTE(review): scraped fragment — the enclosing loop/function header is
        # outside this view, so the `continue` condition is unknown here.
        continue
    # Keep the architecture name: `net` is rebound to the model object below.
    strnet = net
    # Map the architecture name to a freshly constructed model on DEVICE.
    # NOTE(review): there is no `else` branch, so an unrecognized name leaves
    # `net` as the original string.
    if net == 'fc1':
        net = FullyConnected(DEVICE, INPUT_SIZE, [100, 10]).to(DEVICE)
    elif net == 'fc2':
        net = FullyConnected(DEVICE, INPUT_SIZE, [50, 50, 10]).to(DEVICE)
    elif net == 'fc3':
        net = FullyConnected(DEVICE, INPUT_SIZE, [100, 100, 10]).to(DEVICE)
    elif net == 'fc4':
        net = FullyConnected(DEVICE, INPUT_SIZE,
                             [100, 100, 100, 10]).to(DEVICE)
    elif net == 'fc5':
        net = FullyConnected(DEVICE, INPUT_SIZE,
                             [400, 200, 100, 100, 10]).to(DEVICE)
    elif net == 'conv1':
        net = Conv(DEVICE, INPUT_SIZE, [(32, 4, 2, 1)], [100, 10],
                   10).to(DEVICE)
    elif net == 'conv2':
        net = Conv(DEVICE, INPUT_SIZE, [(32, 4, 2, 1), (64, 4, 2, 1)],
                   [100, 10], 10).to(DEVICE)
    elif net == 'conv3':
        net = Conv(DEVICE, INPUT_SIZE, [(32, 3, 1, 1), (32, 4, 2, 1),
                                        (64, 4, 2, 1)], [150, 10],
                   10).to(DEVICE)
    elif net == 'conv4':
        net = Conv(DEVICE, INPUT_SIZE, [(32, 4, 2, 1), (64, 4, 2, 1)],
                   [100, 100, 10], 10).to(DEVICE)
    elif net == 'conv5':
        net = Conv(DEVICE, INPUT_SIZE, [(16, 3, 1, 1), (32, 4, 2, 1),
                                        (64, 4, 2, 1)], [100, 100, 10],
                   10).to(DEVICE)
Example #9
0
def main():
    """Score ``analyze`` against a PGD attack on the MNIST test set.

    For each test image, ``LinfPGDAttack`` (eps=0.02; k=40, presumably the
    number of steps) produces a perturbed input ``X``. Two outcomes are
    computed for ``X``:

    * the reference: the network's prediction, reported as 'verified' if it
      is still correct;
    * the prediction being scored: the result of ``analyze``.

    Each pair scores +1 when the two agree and -2 when ``analyze`` says
    'verified' but the attack succeeded. The total is printed at the end.
    """
    # Hidden/output layer widths of the fully connected networks.
    fc_layers = {
        'fc1': [100, 10],
        'fc2': [50, 50, 10],
        'fc3': [100, 100, 10],
        'fc4': [100, 100, 100, 10],
        'fc5': [400, 200, 100, 100, 10],
    }
    # (conv layer specs, fc layer widths) passed straight to Conv; the 4-tuples
    # are presumably (channels, kernel, stride, padding) — confirm against Conv.
    conv_layers = {
        'conv1': ([(32, 4, 2, 1)], [100, 10]),
        'conv2': ([(32, 4, 2, 1), (64, 4, 2, 1)], [100, 10]),
        'conv3': ([(32, 3, 1, 1), (32, 4, 2, 1), (64, 4, 2, 1)], [150, 10]),
        'conv4': ([(32, 4, 2, 1), (64, 4, 2, 1)], [100, 100, 10]),
        'conv5': ([(16, 3, 1, 1), (32, 4, 2, 1), (64, 4, 2, 1)], [100, 100, 10]),
    }

    parser = argparse.ArgumentParser(description='Neural network verification using DeepZ relaxation')
    parser.add_argument('--net',
                        type=str,
                        # Derived from the tables so the CLI and builder cannot drift apart.
                        choices=list(fc_layers) + list(conv_layers),
                        required=True,
                        help='Neural network to verify.')
    args = parser.parse_args()

    # argparse `choices` guarantees args.net is a key of exactly one table.
    if args.net in fc_layers:
        net = FullyConnected(DEVICE, INPUT_SIZE, fc_layers[args.net]).to(DEVICE)
    else:
        convs, fcs = conv_layers[args.net]
        net = Conv(DEVICE, INPUT_SIZE, convs, fcs, 10).to(DEVICE)

    net.load_state_dict(torch.load('../mnist_nets/%s.pt' % args.net, map_location=torch.device(DEVICE)))
    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('../data/', train=False, download=True,
                                   transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor()
                                   ])),
        batch_size=1, shuffle=True)
    eps = 0.02
    fnet = LinfPGDAttack(net, epsilon=eps, k=40)
    point = 0
    for x, y in test_loader:
        X_adv = fnet.perturb(x.numpy(), y.numpy())
        # The attack returns a NumPy array on the CPU. The model lives on DEVICE,
        # so move the tensor there. .float() is a no-op for float32 and guards
        # against a float64 result.
        X = torch.from_numpy(X_adv).float().to(DEVICE)
        label = y.item()

        out = 'verified' if net(X).max(dim=1)[1].item() == label else 'not verified'
        print(out)

        pred = 'verified' if analyze(net, X, eps, label) else 'not verified'
        print(pred)
        print('-----------')

        if out == pred:
            point += 1
        if out == 'not verified' and pred == 'verified':
            # Unsound claim: analyze "verified" an input the attack broke.
            point -= 2
    print('marks ', point)
Example #10
0
def main():
    """Command-line entry point: verify one test case against one network.

    Command-line arguments:

    * ``--net`` selects the architecture and the weight file
      ``../mnist_nets/<net>.pt``.
    * ``--spec`` is the test-case file. Its first line is the true label and
      the remaining lines hold one pixel value each. Epsilon is the last
      ``_``-separated token of the file name, without its extension.

    Prints 'verified' or 'not verified'. Any exception raised by ``analyze``
    is reported on stderr and treated as 'not verified'.
    """
    import os
    import sys

    # Hidden/output layer widths of the fully connected networks.
    fc_layers = {
        'fc1': [100, 10],
        'fc2': [50, 50, 10],
        'fc3': [100, 100, 10],
        'fc4': [100, 100, 100, 10],
        'fc5': [400, 200, 100, 100, 10],
    }
    # (conv layer specs, fc layer widths) passed straight to Conv; the 4-tuples
    # are presumably (channels, kernel, stride, padding) — confirm against Conv.
    conv_layers = {
        'conv1': ([(32, 4, 2, 1)], [100, 10]),
        'conv2': ([(32, 4, 2, 1), (64, 4, 2, 1)], [100, 10]),
        'conv3': ([(32, 3, 1, 1), (32, 4, 2, 1), (64, 4, 2, 1)], [150, 10]),
        'conv4': ([(32, 4, 2, 1), (64, 4, 2, 1)], [100, 100, 10]),
        'conv5': ([(16, 3, 1, 1), (32, 4, 2, 1), (64, 4, 2, 1)], [100, 100, 10]),
    }

    parser = argparse.ArgumentParser(
        description='Neural network verification using DeepZ relaxation')
    parser.add_argument('--net',
                        type=str,
                        # Derived from the tables so the CLI and builder cannot drift apart.
                        choices=list(fc_layers) + list(conv_layers),
                        required=True,
                        help='Neural network to verify.')

    parser.add_argument('--spec',
                        type=str,
                        required=True,
                        help='Test case to verify.')
    args = parser.parse_args()

    with open(args.spec, 'r') as f:
        # strip() rather than line[:-1]: slicing drops the last character of
        # the final value when the file has no trailing newline.
        lines = [line.strip() for line in f if line.strip()]
    true_label = int(lines[0])
    pixel_values = [float(line) for line in lines[1:]]
    stem = os.path.splitext(os.path.basename(args.spec))[0]
    eps = float(stem.split('_')[-1])

    # argparse `choices` guarantees args.net is a key of exactly one table.
    if args.net in fc_layers:
        net = FullyConnected(DEVICE, INPUT_SIZE, fc_layers[args.net]).to(DEVICE)
    else:
        convs, fcs = conv_layers[args.net]
        net = Conv(DEVICE, INPUT_SIZE, convs, fcs, 10).to(DEVICE)

    net.load_state_dict(
        torch.load('../mnist_nets/%s.pt' % args.net,
                   map_location=torch.device(DEVICE)))

    inputs = torch.FloatTensor(pixel_values).view(1, 1, INPUT_SIZE,
                                                  INPUT_SIZE).to(DEVICE)
    outs = net(inputs)
    pred_label = outs.max(dim=1)[1].item()
    # Verification only makes sense for a correctly classified input.
    assert pred_label == true_label

    try:
        # If there are no learnable params, the backward() call inside analyze
        # raises. That means the initial bounds are not verifiable and cannot
        # be optimized further.
        verified = analyze(net, inputs, eps, true_label)
    except Exception as exc:
        # Deliberate best-effort: any analysis failure counts as "not verified".
        # Catching Exception (not a bare except) lets Ctrl-C / SystemExit through.
        print('analysis aborted: %r' % (exc,), file=sys.stderr)
        verified = False

    if verified:
        print('verified')
    else:
        print('not verified')