def test_last_axis_precision_recall_curve():
    dims = [2, 3, 4, 5]
    for ndim in [1, 2, 3, 4]:
        if ndim == 1: 
            d = 10
        else:
            d = dims[:ndim]
        a = np.random.binomial(n=1, p=.5, size=d).astype('float32')
        p = np.random.random(d)
        prec, recall, _ = tmetrics.last_axis_precision_recall_curve(a, p)
        print 'd: {}'.format(d)
        print prec
        print recall
        print _
def test_precision_recall_curves_all_dims(n_iter=1):
    dims = [(50,), (50, 100), (50, 30, 40), (5, 10, 5, 3)]
    int_types = [T.fvector, T.fmatrix, T.ftensor3, T.ftensor4]
    float_types = [T.fvector, T.fmatrix, T.ftensor3, T.ftensor4]
    for d, int_type, float_type in zip(dims[1:], int_types[1:], float_types[1:]):
        yt = int_type('yt')
        yp = float_type('yp')
        a = float_type('a')
        b = float_type('b')
        gpu_auc = tmetrics.auc(a, b)
        get_auc = theano.function([a, b], gpu_auc)
        p_expr, r_expr, t_expr = tmetrics.precision_recall_curves(yt, yp)
        pr_auc = tmetrics.auc(r_expr, p_expr) 
        f = theano.function([yt, yp], [p_expr, r_expr, t_expr, pr_auc])
        for epoch in xrange(n_iter):
            true = np.random.binomial(n=1, p=.5, size=d).astype('float32')
            predicted = np.random.random((d)).astype('float32')
            precision, recall, thresh, avg = f(true, predicted)
            refp, refr, reft = tmetrics.last_axis_precision_recall_curve(true, predicted)
            try:
                assert np.allclose(precision, refp, equal_nan=True)
                assert np.allclose(recall, refr, equal_nan=True)
                assert np.allclose(thresh, reft, equal_nan=True)
                assert np.allclose(avg, get_auc(refr.astype('float32'), refp.astype('float32')), equal_nan=True)
            except:
                print true
                print predicted
                print 'precision'
                print precision
                print 'ref precision'
                print refp
                print 'recall'
                print recall
                print recall.shape
                print 'ref recall'
                print refr
                print refr.shape
                print thresh
                print reft
                print avg
                print get_auc(refr.astype('float32'), refp.astype('float32'))
                raise