def main():


    train_x, train_y, valid_x, valid_y, test_x, test_y = get_cifar10('./cifar-10-batches-py/')
    labels = unpickle('./cifar-10-batches-py/batches.meta')['label_names']

    train_x = train_x.astype(np.float32) / 255.0
    valid_x = valid_x.astype(np.float32) / 255.0
    test_x  = test_x.astype(np.float32) / 255.0


    num_epochs = args.epochs
    eta        = args.lr
    batch_size = args.batch_size

    # input 
    x = T.tensor4("x")
    y = T.ivector("y")
    
    # test values
    # x.tag.test_value = np.random.randn(6, 3, 32, 32).astype(np.float32)
    # y.tag.test_value = np.array([1,2,1,4,5]).astype(np.int32)
    # x.tag.test_value = x.tag.test_value / x.tag.test_value.max()

    # import ipdb; ipdb.set_trace()

    # network definition 
    conv1 = BinaryConv2D(input=x, num_filters=50, input_channels=3, size=3, strides=(1,1), padding=1,  name="conv1")
    act1  = Activation(input=conv1.output, activation="relu", name="act1")
    pool1 = Pool2D(input=act1.output, stride=(2,2), name="pool1")
    
    conv2 = BinaryConv2D(input=pool1.output, num_filters=100, input_channels=50, size=3, strides=(1,1), padding=1,  name="conv2")
    act2  = Activation(input=conv2.output, activation="relu", name="act2")
    pool2 = Pool2D(input=act2.output, stride=(2,2), name="pool2")

    conv3 = BinaryConv2D(input=pool2.output, num_filters=200, input_channels=100, size=3, strides=(1,1), padding=1,  name="conv3")
    act3  = Activation(input=conv3.output, activation="relu", name="act3")
    pool3 = Pool2D(input=act3.output, stride=(2,2), name="pool3")

    flat  = Flatten(input=pool3.output)
    fc1   = BinaryDense(input=flat.output, n_in=200*4*4, n_out=500, name="fc1")
    act4  = Activation(input=fc1.output, activation="relu", name="act4")
    fc2   = BinaryDense(input=act4.output, n_in=500, n_out=10, name="fc2")
    softmax  = Activation(input=fc2.output, activation="softmax", name="softmax")

    # loss
    xent     = T.nnet.nnet.categorical_crossentropy(softmax.output, y)
    cost     = xent.mean()

    # errors 
    y_pred   = T.argmax(softmax.output, axis=1)
    errors   = T.mean(T.neq(y, y_pred))

    # updates + clipping (+-1) 
    params   = conv1.params + conv2.params + conv3.params + fc1.params + fc2.params 
    params_bin = conv1.params_bin + conv2.params_bin + conv3.params_bin + fc1.params_bin + fc2.params_bin
    grads    = [T.grad(cost, param) for param in params_bin] # calculate grad w.r.t binary parameters

    updates  = []
    for p,g in zip(params, grads):
        updates.append(
                (p, clip_weights(p - eta*g)) #sgd + clipping update
            )

    # compiling train, predict and test fxns     
    train   = theano.function(
                inputs  = [x,y],
                outputs = cost,
                updates = updates
            )
    predict = theano.function(
                inputs  = [x],
                outputs = y_pred
            )
    test    = theano.function(
                inputs  = [x,y],
                outputs = errors
            )

    # train 
    checkpoint = ModelCheckpoint(folder="snapshots")
    logger = Logger("logs/{}".format(time()))
    for epoch in range(num_epochs):
        
        print "Epoch: ", epoch
        print "LR: ", eta
        epoch_hist = {"loss": []}
        
        t = tqdm(range(0, len(train_x), batch_size))
        for lower in t:
            upper = min(len(train_x), lower + batch_size)
            loss  = train(train_x[lower:upper], train_y[lower:upper].astype(np.int32))     
            t.set_postfix(loss="{:.2f}".format(float(loss)))
            epoch_hist["loss"].append(loss.astype(np.float32))
        
        # epoch loss
        average_loss = sum(epoch_hist["loss"])/len(epoch_hist["loss"])         
        t.set_postfix(loss="{:.2f}".format(float(average_loss)))
        logger.log_scalar(
                tag="Training Loss", 
                value= average_loss,
                step=epoch
                )

        # validation accuracy 
        val_acc  =  1.0 - test(valid_x, valid_y.astype(np.int32))
        print "Validation Accuracy: ", val_acc
        logger.log_scalar(
                tag="Validation Accuracy", 
                value= val_acc,
                step=epoch
                )  
        checkpoint.check(val_acc, params)

    # Report Results on test set (w/ best val acc file)
    best_val_acc_filename = checkpoint.best_val_acc_filename
    print "Using ", best_val_acc_filename, " to calculate best test acc."
    load_model(path=best_val_acc_filename, params=params)
    test_acc = 1.0 - test(test_x, test_y.astype(np.int32))    
    print "Test accuracy: ",test_acc
def main():

    train_x, train_y, valid_x, valid_y, test_x, test_y = get_mnist()

    num_epochs = args.epochs
    eta = args.lr
    batch_size = args.batch_size

    # input
    x = T.matrix("x")
    y = T.ivector("y")

    #x.tag.test_value = np.random.randn(3, 784).astype("float32")
    #y.tag.test_value = np.array([1,2,3])
    #drop_switch.tag.test_value = 0
    #import ipdb; ipdb.set_trace()
    hidden_1 = BinaryDense(input=x, n_in=784, n_out=2048, name="hidden_1")
    act_1 = Activation(input=hidden_1.output, activation="relu", name="act_1")
    hidden_2 = BinaryDense(input=act_1.output,
                           n_in=2048,
                           n_out=2048,
                           name="hidden_2")
    act_2 = Activation(input=hidden_2.output, activation="relu", name="act_2")
    hidden_3 = BinaryDense(input=act_2.output,
                           n_in=2048,
                           n_out=2048,
                           name="hidden_3")
    act_3 = Activation(input=hidden_3.output, activation="relu", name="act_3")
    output = BinaryDense(input=act_3.output,
                         n_in=2048,
                         n_out=10,
                         name="output")
    softmax = Activation(input=output.output,
                         activation="softmax",
                         name="softmax")

    # loss
    xent = T.nnet.nnet.categorical_crossentropy(softmax.output, y)
    cost = xent.mean()

    # errors
    y_pred = T.argmax(softmax.output, axis=1)
    errors = T.mean(T.neq(y, y_pred))

    # updates + clipping (+-1)
    params_bin = hidden_1.params_bin + hidden_2.params_bin + hidden_3.params_bin
    params = hidden_1.params + hidden_2.params + hidden_3.params
    grads = [T.grad(cost, param)
             for param in params_bin]  # calculate grad w.r.t binary parameters
    updates = []
    for p, g in zip(
            params, grads
    ):  # gradient update on full precision weights (NOT binarized wts)
        updates.append((p, clip_weights(p - eta * g))  #sgd + clipping update
                       )

    # compiling train, predict and test fxns
    train = theano.function(inputs=[x, y], outputs=cost, updates=updates)
    predict = theano.function(inputs=[x], outputs=y_pred)
    test = theano.function(inputs=[x, y], outputs=errors)

    # train
    checkpoint = ModelCheckpoint(folder="snapshots")
    logger = Logger("logs/{}".format(time()))
    for epoch in range(num_epochs):

        print "Epoch: ", epoch
        print "LR: ", eta
        epoch_hist = {"loss": []}

        t = tqdm(range(0, len(train_x), batch_size))
        for lower in t:
            upper = min(len(train_x), lower + batch_size)
            loss = train(train_x[lower:upper],
                         train_y[lower:upper].astype(np.int32))
            t.set_postfix(loss="{:.2f}".format(float(loss)))
            epoch_hist["loss"].append(loss.astype(np.float32))

        # epoch loss
        average_loss = sum(epoch_hist["loss"]) / len(epoch_hist["loss"])
        t.set_postfix(loss="{:.2f}".format(float(average_loss)))
        logger.log_scalar(tag="Training Loss", value=average_loss, step=epoch)

        # validation accuracy
        val_acc = 1.0 - test(valid_x, valid_y.astype(np.int32))
        print "Validation Accuracy: ", val_acc
        logger.log_scalar(tag="Validation Accuracy", value=val_acc, step=epoch)
        checkpoint.check(val_acc, params)

    # Report Results on test set
    best_val_acc_filename = checkpoint.best_val_acc_filename
    print "Using ", best_val_acc_filename, " to calculate best test acc."
    load_model(path=best_val_acc_filename, params=params)
    test_acc = 1.0 - test(test_x, test_y.astype(np.int32))
    print "Test accuracy: ", test_acc