def getPolicy_NNet(pra, extent, vown, vint=0., nbin=250, ra=-1, bounds=None):
    """Evaluate the VertCAS network for one previous advisory over an (h, tau) grid.

    Loads ``networks/VertCAS_pra%02d_v4_45HU_200.nnet`` with file index
    ``pra + 1``, evaluates it on an ``nbin`` x ``nbin`` grid of ``h`` and
    ``tau`` values (with ``vown`` and ``vint`` held constant), and maps each
    grid point's argmax network output to an advisory via
    ``getPossibleAdvisories(pra)``.

    Args:
        pra: Previous-advisory index. Selects the network file (``pra + 1``)
            and is passed to ``getPossibleAdvisories``.
        extent: ``(tauMin, tauMax, hMin, hMax)`` limits of the grid.
        vown: Value used as the 2nd network input at every grid point.
        vint: Value used as the 3rd network input at every grid point.
        nbin: Number of samples along each axis (``nbin**2`` evaluations).
        ra: Advisory value to highlight. Only cells whose chosen advisory
            equals ``ra`` can be marked.
        bounds: Optional bounds forwarded to ``satisfiesBounds``. When
            non-empty, every cell whose advisory equals ``ra`` and that
            satisfies the bounds at its ``(h, tau)`` is set to 10. ``None``
            (the default) or an empty sequence disables highlighting.

    Returns:
        An ``(nbin, nbin)`` array of advisory values (entries of
        ``possAdv``, or 10 for highlighted cells), laid out for plotting:
        row 0 corresponds to ``hMax`` and the last row to ``hMin``; columns
        run from ``tauMin`` (left) to ``tauMax`` (right).
    """
    # Network files are numbered from 1, hence pra + 1.
    # NOTE(review): the path is relative to the current working directory.
    nnet_file_name = "networks/VertCAS_pra%02d_v4_45HU_200.nnet" % (pra + 1)
    net = NNet(nnet_file_name)
    tauMin, tauMax, hMin, hMax = extent

    # Maps a network output index to the advisory it represents.
    possAdv = getPossibleAdvisories(pra)

    # Build one network input row per grid point. The column order
    # (h, vown, vint, tau) must match what the network expects.
    npts = nbin ** 2
    hVec = np.linspace(hMin, hMax, nbin)
    tauVec = np.linspace(tauMin, tauMax, nbin)
    # Default 'xy' indexing: rows of the mesh vary tau, columns vary h.
    hMesh, tauMesh = np.meshgrid(hVec, tauVec)
    hMesh = hMesh.reshape(npts, 1)
    tauMesh = tauMesh.reshape(npts, 1)
    vownMesh = np.ones((npts, 1)) * vown
    vintMesh = np.ones((npts, 1)) * vint
    netIn = np.concatenate((hMesh, vownMesh, vintMesh, tauMesh), axis=1)

    # Evaluate the network on all grid points at once.
    netOut = net.evaluate_network_multiple(netIn)

    # The best advisory is the one with the largest network output.
    bestAdv = np.argmax(netOut, axis=1)
    heatMap = np.array([possAdv[i] for i in bestAdv])

    # Highlight SAT points if bounds were given. Only cells that already
    # chose `ra` can qualify, so visit just those indices (in ascending
    # order, as before) instead of every grid point.
    if bounds is not None and len(bounds) > 0:
        for i in np.flatnonzero(heatMap == ra):
            if satisfiesBounds(bounds, netIn[i, 0], netIn[i, 3]):
                heatMap[i] = 10

    # Rows -> h, columns -> tau, then flip so h decreases top-to-bottom.
    heatMap = heatMap.reshape((nbin, nbin)).T[::-1]
    return heatMap