def _display_limits(arr):
    """Return the ``(vmin, vmax)`` display range used for an image array.

    The range runs from 0.25 standard deviations below the median to
    2 standard deviations above it.
    """
    center = np.median(arr)
    spread = np.std(arr)
    return center - 0.25*spread, center + 2*spread


def _header_label(img):
    """Return the three-line annotation (OBJECT, FILTNME3, POLPOS) for *img*.

    Each header value is converted with ``str`` on its own before joining,
    so non-string header values cannot break the concatenation.
    """
    hdr = img.header
    return '\n'.join(str(hdr[key]) for key in ('OBJECT', 'FILTNME3', 'POLPOS'))


def _mask_path(img):
    """Return the path inside ``maskDir`` that holds the mask for *img*."""
    return os.path.join(maskDir, os.path.basename(img.filename))


def _select_mask():
    """Return the mask image to show for the current ``thisImg``.

    Preference order: the saved mask of this image, then the saved mask of
    the previous image, then that of the next image (neighbours are only
    borrowed when they share this image's target).  If none exists, a blank
    mask (0 = not masked, 1 = masked) is built from a copy of ``thisImg``.

    Whatever the source, the returned mask's ``filename`` points at this
    image's mask file, so saving it never overwrites a neighbour's mask.
    """
    thisMaskFile = _mask_path(thisImg)
    prevMaskFile = _mask_path(prevImg)
    nextMaskFile = _mask_path(nextImg)

    if os.path.isfile(thisMaskFile):
        print('using this mask: ',os.path.basename(thisMaskFile))
        source = thisMaskFile
    elif os.path.isfile(prevMaskFile) and (prevTarget == thisTarget):
        print('using previous mask: ',os.path.basename(prevMaskFile))
        source = prevMaskFile
    elif os.path.isfile(nextMaskFile) and (nextTarget == thisTarget):
        print('using next mask: ',os.path.basename(nextMaskFile))
        source = nextMaskFile
    else:
        source = None

    if source is not None:
        mask = AstroImage(source)
        # A borrowed mask must be saved under THIS image's name.
        mask.filename = thisMaskFile
        return mask

    # Nothing on disk: build an all-zero template from the current image.
    mask = thisImg.copy()
    mask.filename = thisMaskFile
    mask.arr = mask.arr.astype(np.int16) * np.int16(0)
    # NOTE(review): dtype is set to np.byte while the array is int16 and
    # BITPIX is 16 -- kept as in the original; confirm against AstroImage.
    mask.dtype = np.byte
    mask.header['BITPIX'] = 16
    return mask


def _redraw_mask_contour(mask):
    """Replace the contour overlay on the middle axes with one for *mask*."""
    ax = axarr[1]
    # Remove the artists one by one instead of assigning to ``collections``:
    # the Axes children lists are read-only in current Matplotlib.
    for artist in list(ax.collections):
        artist.remove()
    ax.contour(xx, yy, mask.arr, levels=[0.5], colors='white', alpha = 0.2)


def on_key(event):
    """Handle a Matplotlib key-press event for the mask editor.

    Key bindings:
      '1'..'6'     set the brush size to that number
      'right'      advance to the next image (wraps around ``fileList``)
      'left'       go back to the previous image (wraps around ``fileList``)
      'enter'      copy the current image header onto the mask and write it
      'backspace'  zero the mask and redraw its contour

    All state lives in module-level globals; only the names this function
    rebinds are declared ``global`` below.
    """
    global imgNum, brushSize, maskImg
    global prevImg, thisImg, nextImg
    global prevTarget, thisTarget, nextTarget
    global prevMin, thisMin, nextMin
    global prevMax, thisMax, nextMax

    key = event.key

    # Handle brush sizing
    if key in ('1', '2', '3', '4', '5', '6'):
        brushSize = int(key)

    # Step through the image list
    if key in ('right', 'left'):
        step = 1 if key == 'right' else -1
        nFiles = len(fileList)

        # Read the incoming image and derive everything from it BEFORE
        # touching any global, so a failed read leaves the state consistent.
        newNum = (imgNum + step) % nFiles
        edgeIdx = (newNum + step) % nFiles
        edgeImg = AstroImage(fileList[edgeIdx])
        edgeTarget = targetList[edgeIdx]
        edgeMin, edgeMax = _display_limits(edgeImg.arr)

        # Keep imgNum inside [0, nFiles) so it is always a valid index.
        imgNum = newNum

        if step > 0:
            # Slide the three-image window forward
            prevImg, thisImg, nextImg = thisImg, nextImg, edgeImg
            prevTarget, thisTarget, nextTarget = (
                thisTarget, nextTarget, edgeTarget)
            prevMin, thisMin, nextMin = thisMin, nextMin, edgeMin
            prevMax, thisMax, nextMax = thisMax, nextMax, edgeMax
        else:
            # Slide the three-image window backward
            nextImg, thisImg, prevImg = thisImg, prevImg, edgeImg
            nextTarget, thisTarget, prevTarget = (
                thisTarget, prevTarget, edgeTarget)
            nextMin, thisMin, prevMin = thisMin, prevMin, edgeMin
            nextMax, thisMax, prevMax = thisMax, prevMax, edgeMax

        # Pick (or build) the mask for the new current image and show it
        maskImg = _select_mask()
        _redraw_mask_contour(maskImg)

        # Reassign display limits and pixel data for the three panels
        panels = (
            (prevAxImg, prevImg, prevMin, prevMax, prevLabel),
            (thisAxImg, thisImg, thisMin, thisMax, thisLabel),
            (nextAxImg, nextImg, nextMin, nextMax, nextLabel),
        )
        for axImg, _, lo, hi, _ in panels:
            axImg.set_clim(vmin = lo, vmax = hi)
        for axImg, img, _, _, _ in panels:
            axImg.set_data(img.arr)

        # Update the annotation
        axList = fig.get_axes()
        axList[1].set_title(os.path.basename(thisImg.filename))
        for _, img, _, _, label in panels:
            label.set_text(_header_label(img))

        # Update the display
        fig.canvas.draw()

    # Save the generated mask
    if key == 'enter':
        # Make sure the header has the right values
        # NOTE(review): this shares the header object with thisImg rather
        # than copying it -- kept as in the original.
        maskImg.header = thisImg.header

        # Write the mask to disk
        print('Writing mask for file ', os.path.basename(maskImg.filename))
        maskImg.write()

    # Clear out the mask values
    if key == 'backspace':
        maskImg.arr = maskImg.arr * np.byte(0)
        _redraw_mask_contour(maskImg)

        # Update the display
        fig.canvas.draw()
# NOTE: a module-level name that on_key rebinds must be listed in one of its
# `global` statements; otherwise Python treats the assignment target as a
# local variable of the function.


# Locate the mask that belongs to the current image, if one was saved before
maskFile = os.path.join(maskDir, os.path.basename(thisImg.filename))
if not os.path.isfile(maskFile):
    # No saved mask yet: start from a zeroed copy of the current image
    # (0 = not masked, 1 = masked)
    maskImg = thisImg.copy()
    maskImg.filename = maskFile
    maskImg.arr = maskImg.arr.astype(np.int16) * np.int16(0)
    maskImg.dtype = np.byte
    maskImg.header['BITPIX'] = 16
else:
    # A saved mask exists, so load it
    maskImg = AstroImage(maskFile)

# Per-pixel coordinate maps matching the mask array:
# grids[0] varies along the first array axis, grids[1] along the second
maskShape = maskImg.arr.shape
grids = np.mgrid[:maskShape[0], :maskShape[1]]
yy, xx = grids[0], grids[1]

# Set up the display: one row of three panels sharing the y axis
fig, axarr = plt.subplots(nrows=1, ncols=3, sharey=True)

# Display range for the previous image:
# from 0.25 sigma below the median to 2 sigma above it
prevMin, prevMax = (
    np.median(prevImg.arr) - 0.25*np.std(prevImg.arr),
    np.median(prevImg.arr) + 2*np.std(prevImg.arr),
)