Ejemplo n.º 1
0
def computeQUI(distSXY,
               eps=1e-7,
               DEBUG=False,
               IPmethod="GIS",
               maxiter=100000,
               maxiter2=100000):
    '''
    Compute an optimizer Q

    distSXY  : A joint distribution of three variables (as a dit.Distribution).
    eps      : The precision of the outer loop.  The precision of the inner loop will be eps / (20 |S|).
    DEBUG    : Print output for debugging.
    IPmethod : Passed through unchanged to computeQUI_numpy (default "GIS").
    maxiter  : Passed through unchanged to computeQUI_numpy.
    maxiter2 : Passed through unchanged to computeQUI_numpy.

    Returns a new dit.Distribution, with random variable names 'SXY', built
    from the outcomes of the densified copy of distSXY and the flattened
    array returned by computeQUI_numpy.  distSXY itself is not modified.

    The computation is carried out using computeQUI_numpy
    '''
    # Prepare distributions
    QSXYd = distSXY.copy()  # make a copy, to not overwrite the argument
    QSXYd.set_rv_names('SXY')
    QSXYd.make_dense()  # the set of outcomes should be Cartesian

    # Only the cardinality of S is needed below (to shape PS).
    nS = len(QSXYd.alphabet[0])
    # to do: take relevant subset of suppS

    # Remember the outcome list so the result can be rebuilt on it.
    samplespace = QSXYd.outcomes

    # PS is a column vector, for convenience
    PS = QSXYd.marginal('S').pmf.reshape(nS, 1)

    # Element [1] of condition_on's result is iterated as a sequence of
    # distributions; their pmfs are stacked as rows and then transposed.
    # NOTE(review): presumably one conditional per value of S, giving
    # arrays indexed [x, s] and [y, s] -- confirm against dit.
    PXgSa = numpy.array(
        [cond.pmf for cond in QSXYd.condition_on('S', rvs='X')[1]]
    ).transpose()
    PYgSa = numpy.array(
        [cond.pmf for cond in QSXYd.condition_on('S', rvs='Y')[1]]
    ).transpose()

    # Flatten the numpy result so it lines up with the outcome list.
    QSXYa = computeQUI_numpy(PXgSa,
                             PYgSa,
                             PS,
                             eps=eps,
                             DEBUG=DEBUG,
                             IPmethod=IPmethod,
                             maxiter=maxiter,
                             maxiter2=maxiter2).reshape(-1)

    QSXYd = dit.Distribution(samplespace, QSXYa)
    QSXYd.set_rv_names('SXY')
    return QSXYd
Ejemplo n.º 2
0
# Remaining term for the input distribution d: the mutual information between
# 'S' and 'XY' minus the three terms printed below as R (SI), U0 (UIX) and
# U1 (UIY).  SI, UIX, UIY, d and Q are all defined before this excerpt.
CI = dit.shannon.mutual_information(d, 'S', 'XY') - UIX - UIY - SI
# Same expression evaluated on Q instead of d; not used within this excerpt.
CIQ = dit.shannon.mutual_information(Q, 'S', 'XY') - UIX - UIY - SI
# Report the four terms on one line (CI is shown under the label "S"),
# followed by the distribution Q itself.
print("PID(R=", SI, ", U0=", UIX, ", U1=", UIY, ", S=", CI, ")", sep='')
print(Q)

# Run dit's own pid_broja on the same distribution d and report the elapsed
# wall-clock time, then print its result.
print("\nRunning dit pid_broja...")
start_time = time.time()
pid = dit.algorithms.pid_broja(d, ['X', 'Y'], 'S')
print("--- %s seconds ---" % (time.time() - start_time))
print(pid)

# Example d12, driven directly through numpy arrays (no dit objects involved)
print("\nRunning admUI without dit on example d12...")
ns, nx, ny = 2, 3, 2

# Joint pmf listed flat; the reshape below interprets it as [s, x, y],
# so each line here holds the two y-entries of one (s, x) cell.
P = np.array([
    0.0869196091623, 0,  # s=0, x=0
    0.0218631235533, 0,  # s=0, x=1
    0.133963681059, 0,   # s=0, x=2
    0, 0.164924698739,   # s=1, x=0
    0.429533105427, 0,   # s=1, x=1
    0, 0.16279578206,    # s=1, x=2
])
Psxy = P.reshape(ns, nx, ny)

# Pairwise marginals: sum out y for (s, x), sum out x for (s, y).
Psx = Psxy.sum(axis=2)
Psy = Psxy.sum(axis=1)

# Marginal of s, still one-dimensional so that it broadcasts across the
# columns of the transposed marginals in the divisions below.
PS = Psx.sum(axis=1)
PXgS = Psx.T / PS
PYgS = Psy.T / PS

# computeQUI_numpy is handed PS as a column vector.
PS = PS.reshape((-1, 1))
Q = computeQUI_numpy(PXgS, PYgS, PS)
print("p=\n", Q)