# Example #1
0
def getPolicy_NNet(pra, extent, vown, vint=0., nbin=250, ra=-1, bounds=None):
    """Evaluate a VertCAS network on a (tau, h) grid and return an advisory map.

    The network file for ``pra`` is loaded, evaluated on an ``nbin x nbin``
    grid spanning ``extent`` at fixed ``vown``/``vint``, and the index of the
    highest-scoring output of each cell is translated through
    ``getPossibleAdvisories(pra)``.

    Args:
        pra: Index selecting the network file (``pra + 1`` is used in the
            file name) and the advisory table. Presumably the previous
            resolution advisory -- confirm against the caller.
        extent: Sequence ``(tauMin, tauMax, hMin, hMax)`` giving the grid
            limits.
        vown: Value placed in the second network input column for every cell.
        vint: Value placed in the third network input column for every cell.
        nbin: Number of grid points along each of the h and tau axes.
        ra: Advisory value whose cells are eligible for highlighting.
        bounds: Optional bounds forwarded to ``satisfiesBounds``. ``None`` or
            an empty sequence disables highlighting.

    Returns:
        ``(nbin, nbin)`` array of advisory values. Rows follow h from
        ``hMax`` (row 0) down to ``hMin``; columns follow tau from ``tauMin``
        to ``tauMax``. Cells whose advisory equals ``ra`` and which satisfy
        ``bounds`` hold the marker value 10 instead.
    """
    # Marker written into cells that satisfy the bounds (value kept from the
    # original implementation so existing colour maps keep working).
    satMarker = 10

    # Load network of given pRA into Python
    nnet_file_name = "networks/VertCAS_pra%02d_v4_45HU_200.nnet" % (pra + 1)
    net = NNet(nnet_file_name)

    tauMin, tauMax, hMin, hMax = extent

    # Advisory value associated with each network output index
    possAdv = getPossibleAdvisories(pra)

    # Use meshgrid to define the array of all inputs to the network.
    # After flattening, cell k has tau index k // nbin and h index k % nbin.
    numCells = nbin**2
    hVec = np.linspace(hMin, hMax, nbin)
    tauVec = np.linspace(tauMin, tauMax, nbin)
    hMesh, tauMesh = np.meshgrid(hVec, tauVec)
    hMesh = hMesh.reshape(numCells, 1)
    tauMesh = tauMesh.reshape(numCells, 1)
    vownMesh = np.ones((numCells, 1)) * vown
    vintMesh = np.ones((numCells, 1)) * vint
    # Network input column order: h, vown, vint, tau
    netIn = np.concatenate((hMesh, vownMesh, vintMesh, tauMesh), axis=1)

    # Evaluate network on all inputs
    netOut = net.evaluate_network_multiple(netIn)

    # Convert outputs to best advisories
    bestAdv = np.argmax(netOut, axis=1)
    heatMap = np.array([possAdv[i] for i in bestAdv])

    # Highlight SAT points if bounds given.
    # len() is used instead of truthiness so array-like bounds also work.
    if bounds is not None and len(bounds) > 0:
        # Only cells currently showing `ra` can be highlighted, so the
        # bounds check is limited to those indices (in ascending order).
        for i in np.flatnonzero(heatMap == ra):
            if satisfiesBounds(bounds, netIn[i, 0], netIn[i, 3]):
                heatMap[i] = satMarker

    # Reshape the map and flip around the axes for plotting purposes:
    # transpose puts h on the rows, the reversal puts hMax in the first row.
    heatMap = heatMap.reshape((nbin, nbin)).T[::-1]
    return heatMap