Ejemplo n.º 1
0
 def generate_polygon(img_caption_df, ann, seg, vals, x_y, polygons):
     """Build a labelled Polygon for one annotation and record its data.

     Looks up ``dist_from_raw`` in the first row of *img_caption_df* whose
     ``ann_id`` equals ``ann['id']`` and appends it to *vals*. *seg* holds
     interleaved coordinate pairs that are reshaped to (N, 2) and wrapped in
     a Polygon labelled with that value. The mean of the polygon's vertices,
     as returned by ``get_xy()``, is appended to *x_y*, and the polygon
     itself is appended to *polygons*.

     Returns None; all results are delivered through the three lists.
     """
     matches = img_caption_df['ann_id'] == ann['id']
     dist = img_caption_df.loc[matches, 'dist_from_raw'].values[0]
     vals.append(dist)

     vertices = np.array(seg).reshape((len(seg) // 2, 2))
     patch = Polygon(vertices, label=str(dist))

     xy = patch.get_xy()
     x_y.append((np.mean(xy[:, 0]), np.mean(xy[:, 1])))
     polygons.append(patch)
Ejemplo n.º 2
0
def test_Polygon_close():
    """Regression test for GitHub issue #1018.

    Polygon must honour its ``closed`` flag both when it is constructed and
    when its vertices are later replaced through ``set_xy``. A closed
    polygon reports a repeated first vertex; an open one does not.
    """
    open_verts = [[0, 0], [0, 1], [1, 1]]
    closed_verts = open_verts + [[0, 0]]

    # (input vertices, closed flag, vertices expected back from get_xy)
    cases = [
        (open_verts, True, closed_verts),     # open input, close it
        (closed_verts, False, open_verts),    # closed input, open it
        (open_verts, False, open_verts),      # open input, leave it open
        (closed_verts, True, closed_verts),   # closed input, leave it closed
    ]
    for verts, closed, expected in cases:
        patch = Polygon(verts, closed=closed)
        assert_array_equal(patch.get_xy(), expected)
        # Re-setting the same vertices must give the same result.
        patch.set_xy(verts)
        assert_array_equal(patch.get_xy(), expected)
Ejemplo n.º 3
0
def test_Polygon_close():
    # Regression test for GitHub issue #1018. Polygon ignored its ``closed``
    # attribute when vertices were replaced via set_xy, so the path was not
    # getting closed. Each scenario checks the vertices after construction
    # and again after a set_xy call with the same input.
    base = [[0, 0], [0, 1], [1, 1]]
    looped = base + [[0, 0]]

    def check(vertices, closed, expected):
        poly = Polygon(vertices, closed=closed)
        assert_array_equal(poly.get_xy(), expected)
        poly.set_xy(vertices)
        assert_array_equal(poly.get_xy(), expected)

    check(base, True, looped)      # open path that gets closed
    check(looped, False, base)     # closed path that gets opened
    check(base, False, base)       # open path left open
    check(looped, True, looped)    # closed path left closed
Ejemplo n.º 4
0
    def coco_modify_segs(self):
        """Burn COCO polygon annotations into every stored segmentation image.

        For each index in ``self.D.image_index`` this method:

        - reads the image at ``self.D.seg_path_from_index(im_idx)``;
        - converts it with ``self.get_seg``;
        - rasterises every polygon of the image's non-crowd COCO
          annotations with label ``category_id + 9``;
        - overwrites the segmentation wherever a polygon covers a pixel;
        - writes the result back to the same path.

        Progress is printed every 100 images. If an image cannot be read,
        the process exits via ``sys.exit(-1)``.

        Raises:
            NotImplementedError: if an annotation's segmentation is not a
                list of polygons (e.g. RLE). Previously this case dropped
                into an interactive debugger.
        """
        n_images = len(self.D.image_index)
        coco = self.D._COCO
        for count, im_idx in enumerate(self.D.image_index):
            if count % 100 == 0:
                print('Image {:d} of {:d}'.format(count, n_images))
            seg_filename = self.D.seg_path_from_index(im_idx)
            # NOTE(review): the original read from ``im_filename``, which is
            # never defined (NameError on the first image). This assumes the
            # intent is to read the segmentation file that is rewritten
            # below -- confirm against what get_seg() expects.
            im = cv2.imread(seg_filename)
            if im is None:
                print('Could not read {}'.format(seg_filename))
                sys.exit(-1)
            seg = self.get_seg(im)

            # Pixel-coordinate grid, built once per image instead of once per
            # polygon. ``mask`` below has seg.shape and is reshaped from
            # rows * cols containment flags, so seg must be 2-D.
            y, x = np.mgrid[:seg.shape[0], :seg.shape[1]]
            points = np.transpose((x.ravel(), y.ravel()))

            ann_ids = coco.getAnnIds(imgIds=im_idx, iscrowd=False)
            mask = np.zeros(seg.shape, dtype=int)  # np.int was builtin int
            for ann in coco.loadAnns(ann_ids):
                if 'segmentation' not in ann:
                    continue
                segmentation = ann['segmentation']
                if not isinstance(segmentation, list):
                    raise NotImplementedError(
                        'Only polygon segmentations are supported, got '
                        '{}'.format(type(segmentation).__name__))
                cat_id = ann['category_id'] + 9
                for s in segmentation:
                    # Flat coordinate list -> (N, 2). Integer division keeps
                    # the shape an int on Python 3 as well.
                    poly = Polygon(np.array(s).reshape((len(s) // 2, 2)))
                    pth = path.Path(poly.get_xy(), closed=True)
                    inside = pth.contains_points(points).reshape(mask.shape)
                    mask[inside] = cat_id
            # Superimpose the mask on seg: labelled pixels replace seg values.
            seg[mask > 0] = mask[mask > 0]
            cv2.imwrite(seg_filename, seg)
    def add_polygon(self, vertices, polygon_type):
        """Create, draw and register a polygon patch for the current label.

        First appends ``self.curr_polygon_vertex_circles`` to
        ``self.all_polygons_vertex_circles``. It then builds an unfilled
        Polygon with:

        - line width 2;
        - edge colour ``self.colors[self.curr_label + 1]``;
        - a per-type style (closed/open, hatch, dashed).

        The patch is added to ``self.axis`` and made pickable. It is recorded
        in ``polygon_list``, ``polygon_bbox_list`` (as
        [xmin, ymin, xmax, ymax] of its vertices), ``polygon_labels`` and
        ``polygon_types``. Finally the in-progress vertex and circle lists
        are reset.

        Raises:
            ValueError: if *polygon_type* is not one of the handled
                PolygonType members. Previously a bare string was raised,
                which itself fails with TypeError.
        """
        self.all_polygons_vertex_circles.append(
            self.curr_polygon_vertex_circles)

        # Per-type options layered on top of the shared settings below.
        if polygon_type == PolygonType.CLOSED:
            style = {'closed': True}
        elif polygon_type == PolygonType.OPEN:
            style = {'closed': False}
        elif polygon_type == PolygonType.TEXTURE:
            style = {'closed': True, 'hatch': '/'}
        elif polygon_type == PolygonType.TEXTURE_WITH_CONTOUR:
            style = {'closed': True, 'hatch': 'x'}
        elif polygon_type == PolygonType.DIRECTION:
            style = {'closed': False, 'linestyle': 'dashed'}
        else:
            raise ValueError(
                'polygon_type must be one of PolygonType.CLOSED, OPEN, '
                'TEXTURE, TEXTURE_WITH_CONTOUR or DIRECTION; got %r'
                % (polygon_type,))

        polygon = Polygon(vertices,
                          fill=False,
                          edgecolor=self.colors[self.curr_label + 1],
                          linewidth=2,
                          **style)

        xys = polygon.get_xy()
        x0_y0_x1_y1 = np.r_[xys.min(axis=0), xys.max(axis=0)]

        self.axis.add_patch(polygon)

        polygon.set_picker(True)

        self.polygon_list.append(polygon)
        self.polygon_bbox_list.append(x0_y0_x1_y1)
        self.polygon_labels.append(self.curr_label)
        self.polygon_types.append(polygon_type)

        self.curr_polygon_vertices = []
        self.curr_polygon_vertex_circles = []
Ejemplo n.º 6
0
 def add_to_axes(self, ax=None, **kwargs):
     """Draw every field class that has a polygon onto a matplotlib axes.

     Each class ``c`` in ``self.fclasses`` whose ``self.fields[c]['poly']``
     is truthy is drawn as a closed Polygon. It is annotated with ``c`` at
     the NaN-ignoring mean of the polygon's vertices.

     Parameters:
         ax: axes to draw on; a new figure and axes are created when None.
         **kwargs: forwarded to the Polygon constructor.

     Returns:
         The axes drawn on. Previously None was returned, which left callers
         with no handle to an axes created here.
     """
     polys = [(c, self.fields[c]) for c in self.fclasses
              if self.fields[c]['poly']]
     if ax is None:
         _, ax = plt.subplots(1)
     for c, f in polys:
         pg = Polygon(f['poly'], closed=True, **kwargs)
         x, y = pg.get_xy().T
         ax.annotate(c, xy=(np.nanmean(x), np.nanmean(y)))
         ax.add_patch(pg)
     return ax
 def add_to_axes(self, ax=None, **kwargs):
     """Draw every field class that has a polygon onto a matplotlib axes.

     Each class ``c`` in ``self.fclasses`` whose ``self.fields[c]['poly']``
     is truthy is drawn as a closed Polygon. It is annotated with ``c`` at
     the NaN-ignoring mean of the polygon's vertices.

     Parameters:
         ax: axes to draw on; a new figure and axes are created when None.
         **kwargs: forwarded to the Polygon constructor.

     Returns:
         The axes drawn on. Previously None was returned, which left callers
         with no handle to an axes created here.
     """
     polys = [(c, self.fields[c]) for c in self.fclasses
              if self.fields[c]['poly']]
     if ax is None:
         _, ax = plt.subplots(1)
     for c, f in polys:
         pg = Polygon(f['poly'], closed=True, **kwargs)
         x, y = pg.get_xy().T
         ax.annotate(c, xy=(np.nanmean(x), np.nanmean(y)))
         ax.add_patch(pg)
     return ax
Ejemplo n.º 8
0
 def add_to_axes(self, ax=None, fill=False, **kwargs):
     """Draw every field class that has a polygon onto a matplotlib axes.

     Each class ``c`` in ``self.fclasses`` whose ``self.fields[c]["poly"]``
     is truthy is drawn as a closed Polygon. It is annotated with ``c`` at
     the NaN-ignoring mean of the polygon's vertices.

     Parameters:
         ax: axes to draw on; a new figure and axes are created when None.
         fill: when False (the default), facecolor is forced to 'none' so
             only outlines are drawn, overriding any facecolor in kwargs.
         **kwargs: forwarded to the Polygon constructor. edgecolor defaults
             to "k" but may now be overridden; previously passing it raised
             TypeError for a duplicate keyword.

     Returns:
         The axes drawn on (previously None).
     """
     polys = [(c, self.fields[c]) for c in self.fclasses
              if self.fields[c]["poly"]]
     if ax is None:
         _, ax = plt.subplots(1)
     # Style settings are identical for every polygon, so set them once.
     if not fill:
         kwargs['facecolor'] = 'none'
     kwargs.setdefault('edgecolor', 'k')
     for c, f in polys:
         pg = Polygon(f["poly"], closed=True, **kwargs)
         x, y = pg.get_xy().T
         ax.annotate(c, xy=(np.nanmean(x), np.nanmean(y)))
         ax.add_patch(pg)
     return ax
Ejemplo n.º 9
0
class Polygon(object):
    """Small wrapper that owns a polygon patch (``MplPol``) drawn on an axes.

    Parameters:
        ax: axes the patch is added to on construction.
        xy: vertices passed to ``MplPol``.
        **kwargs: forwarded to ``MplPol`` as keyword arguments.
    """

    def __init__(self, ax, xy, **kwargs):
        # Unpack kwargs. The original passed the dict as a second positional
        # argument, so styling keywords were never applied. If MplPol is
        # matplotlib.patches.Polygon, that argument lands in ``closed``
        # (keyword-only in recent matplotlib, where it raises TypeError).
        self._pol = MplPol(xy, **kwargs)
        ax.add_patch(self._pol)

    @property
    def xy(self):
        """Current vertices of the underlying patch (``get_xy()``)."""
        return self._pol.get_xy()

    @xy.setter
    def xy(self, xy):
        self._pol.set_xy(xy)

    def remove(self):
        """Remove the underlying patch from its axes."""
        self._pol.remove()

    def anim_update(self, *args):
        """No-op hook; accepts and ignores any arguments."""
        pass
Ejemplo n.º 10
0
class Roi(object):
    """A named, coloured region of interest backed by a Polygon patch."""

    def __init__(self, xy, name, colour, linewidth=None):
        '''
        A ROI is defined by its vertices (xy coords), a name,
        and a colour.

        Input:

          xy         Coordinates of the ROI vertices as accepted by
                     matplotlib's Polygon: an array-like of shape (N, 2).
          name       Name of the ROI, string.
          colour     Edge colour in any format acceptable to matplotlib,
                     i.e. named (e.g. 'white') or RGBA format
                     (e.g. (1.0, 1.0, 1.0, 1.0)).
          linewidth  Optional edge line width; None uses the default.

        The polygon is created with no face colour, so only its outline
        is drawn.
        '''
        super(Roi, self).__init__()
        self.polygon = Polygon(xy, lw=linewidth)
        self.polygon.set_facecolor('none')
        self.name = name
        self.set_colour(colour)

    def set_colour(self, colour):
        """Set the outline (edge) colour of the ROI."""
        self.polygon.set_edgecolor(colour)

    def get_colour(self):
        """Return the outline (edge) colour of the ROI."""
        return self.polygon.get_edgecolor()

    def set_name(self, name):
        self.name = name

    def get_name(self):
        return self.name

    def set_xy(self, xy):
        """Replace the ROI vertices with *xy* (array-like of shape (N, 2)).

        Previously this method took no argument and called
        ``polygon.set_xy()`` without vertices, so it always raised TypeError.
        """
        return self.polygon.set_xy(xy)

    def get_xy(self):
        """Return the ROI vertices as stored by the polygon."""
        return self.polygon.get_xy()

    def get_polygon(self):
        """Return the underlying polygon patch."""
        return self.polygon
Ejemplo n.º 11
0
    def fill_region(self, city=None, taiwan=False, color='tan', *args,
                    **kwargs):
        """
        Fill a region on the map and label it.

        Every shape in ``self.m.states`` whose ``'NAME_1'`` attribute equals
        *city* is filled. The label is placed at the bounding-box centre of
        the last matching shape.

        :param city: 'NAME_1' value of the region to fill; None fills no
            mainland region.
        :param taiwan: also fill and label Taiwan. It is not included in
            the China shapefile, so it is read from ``self.m.taiwan``.
        :param color: fill colour.
        :param args: extra positional arguments forwarded to Polygon.
        :param kwargs: extra keyword arguments forwarded to Polygon.
        :return: None
        :raises ValueError: if *city* is given but no shape matches it.
            Previously this case raised UnboundLocalError.
        """
        # Mainland China
        area_poly = None
        for area_info, area_shp in zip(self.m.states_info, self.m.states):
            if area_info['NAME_1'] == city:
                area_poly = Polygon(area_shp, facecolor=color,
                                    lw=2, *args, **kwargs)
                self.ax.add_patch(area_poly)
        if area_poly is not None:
            a = area_poly.get_xy()
            self.ax.text((np.min(a[:, 0]) + np.max(a[:, 0])) / 2,
                         (np.min(a[:, 1]) + np.max(a[:, 1])) / 2,
                         city, rotation=45, zorder=11,
                         ha='center', va='center')
        elif city is not None:
            raise ValueError(
                'No region named {!r} in the shapefile'.format(city))

        # Taiwan. The label is drawn only when Taiwan was requested;
        # previously it was drawn on every call.
        if taiwan:
            for taiwan_info, taiwan_shp in zip(self.m.taiwan_info,
                                               self.m.taiwan):
                if taiwan_info['NAME_1'] == 'Taiwan':
                    taiwan_poly = Polygon(taiwan_shp, facecolor=color,
                                          edgecolor='r', lw=2,
                                          *args, **kwargs)
                    self.ax.add_patch(taiwan_poly)
            # Project the fixed lon/lat (120.95, 23.7) through self.m to get
            # the label position in map coordinates.
            taiwan_center = self.m(120.95, 23.7)
            self.ax.text(taiwan_center[0], taiwan_center[1],
                         'Taiwan', rotation=65, zorder=11,
                         ha='center', va='center', fontsize=15)
Ejemplo n.º 12
0
def display_instances(image,
                      boxes,
                      masks,
                      class_ids,
                      class_names,
                      scores=None,
                      title="",
                      figsize=(16, 16),
                      ax=None):
    """
    Outline the first detected instance, then build a foreground /
    blurred-background composite from a fixed image file on disk.

    boxes: [num_instance, (y1, x1, y2, x2)] in image coordinates
        (each row is unpacked into exactly four values below).
    masks: [height, width, num_instances]
    class_ids: [num_instances]
    class_names: list of class names of the dataset
    scores: (optional) confidence scores for each box
    title: (optional) axes title.
    figsize: (optional) the size of the figure created when ax is falsy.
    ax: (optional) axes to draw on.

    Side effects:
        - shows a Gaussian-blurred copy of the fixed image (``.show()``);
        - writes background.png, foreground.png, resultant.png, mask.png
          and blur_original.png to the current directory;
        - calls ``callling()``;
        - finally displays the masked image with ``plt.show()``.

    NOTE(review): only the first instance with a non-zero box is processed
    (see the ``break`` at the end of the loop). The composite is built from
    a hard-coded path, not from *image*. Confirm both are intentional.
    """
    # Number of instances
    N = boxes.shape[0]
    if not N:
        print("\n*** No instances to display *** \n")
    else:
        assert boxes.shape[0] == masks.shape[-1] == class_ids.shape[0]

    if not ax:
        _, ax = plt.subplots(1, figsize=figsize)

    # Generate random colors
    colors = random_colors(N)

    # Show area outside image boundaries.
    height, width = image.shape[:2]
    ax.set_ylim(height + 10, -10)
    ax.set_xlim(-10, width + 10)
    ax.axis('off')
    ax.set_title(title)

    masked_image = image.astype(np.uint32).copy()
    # Vertex arrays of every contour drawn below; later rasterised into mask4.
    contourlist = []
    for i in range(N):
        color = colors[i]

        # Bounding box
        if not np.any(boxes[i]):
            # Skip this instance. Has no bbox. Likely lost in image cropping.
            continue
        y1, x1, y2, x2 = boxes[i]
        # NOTE(review): this Rectangle is built but never drawn (add_patch is
        # commented out); ``p`` is reassigned in the contour loop below.
        p = patches.Rectangle((x1, y1),
                              x2 - x1,
                              y2 - y1,
                              linewidth=2,
                              alpha=0.7,
                              linestyle="dashed",
                              edgecolor=color,
                              facecolor='none')
        # ax.add_patch(p)

        # Label
        class_id = class_ids[i]
        score = scores[i] if scores is not None else None
        label = class_names[class_id]
        # NOTE(review): ``x`` and ``caption`` are unused (the text call is
        # commented out). randint still advances the global RNG and needs
        # integer bounds with x1 <= (x1 + x2) // 2.
        x = random.randint(x1, (x1 + x2) // 2)
        caption = "{} {:.3f}".format(label, score) if score else label
        # ax.text(x1, y1 + 8, caption,color='w', size=11, backgroundcolor="none")

        # Mask
        mask = masks[:, :, i]
        masked_image = apply_mask(masked_image, mask, color)

        # Mask Polygon
        # Pad to ensure proper polygons for masks that touch image edges.
        padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2),
                               dtype=np.uint8)
        padded_mask[1:-1, 1:-1] = mask
        contours = find_contours(padded_mask, 0.5)
        # contourlist.append(contours)
        for verts in contours:
            # Subtract the padding and flip (y, x) to (x, y)
            verts = np.fliplr(verts) - 1
            p = Polygon(verts, facecolor="none", edgecolor=color)
            contourlist.append(p.get_xy())
            ax.add_patch(p)
        # Stops after the first instance that has a non-zero box.
        break

    # Fixed input image. The -1 flag (cv2.IMREAD_UNCHANGED) keeps an alpha
    # channel if the file has one.
    image1 = cv2.imread(
        '/home/sumedh/Machine Learning/Mask_RCNN/images/cool.jpg', -1)
    resultant = image1.copy()
    image = Image.open(
        '/home/sumedh/Machine Learning/Mask_RCNN/images/cool.jpg')
    blurred_image = image.filter(ImageFilter.GaussianBlur(radius=3))
    blurred_image.show()
    # blur = cv2.GaussianBlur(resultant, (39, 39), 0)
    # blur = cv2.blur(resultant, (5, 5))
    # mask defaulting to black for 3-channel and transparent for 4-channel
    # (of course replace corners with yours)
    # NOTE(review): ``a`` is unused.
    a = np.array(contourlist)
    mask4 = np.zeros(image1.shape, dtype=np.uint8)
    # NOTE(review): a regular int32 array requires every contour to have the
    # same number of vertices; NumPy cannot build one from ragged lists.
    roi_corners = np.array(contourlist, dtype=np.int32)
    # fill the ROI so it doesn't get wiped out when the mask is applied
    channel_count = image1.shape[2]  # i.e. 3 or 4 depending on your image
    ignore_mask_color = (255, ) * channel_count
    cv2.fillPoly(mask4, roi_corners, ignore_mask_color)
    # from Masterfool: use cv2.fillConvexPoly if you know it's convex
    # resultant[mask4] = mask4

    # apply the mask
    # NOTE(review): mask4 already holds uint8 0/255 values, so 255 * mask4
    # wraps around in uint8. mask_to_pil is only used by commented-out code.
    mask_to_pil = Image.fromarray(np.uint8(255 * mask4))
    background = cv2.bitwise_not(mask4)

    # Split the image into the instance region (foreground) and the rest.
    fore_ground_image = cv2.bitwise_and(image1, mask4)
    background_image = cv2.bitwise_and(image1, background)
    blur = cv2.GaussianBlur(background_image, (49, 49), 0)
    blur_original = cv2.GaussianBlur(image1, (49, 49), 0)
    cv2.imwrite("background.png", blur)
    cv2.imwrite("foreground.png", fore_ground_image)
    bck = Image.open("background.png")
    fgk = Image.open("foreground.png")
    # NOTE(review): Image.blend computes bck * (1 - alpha) + fgk * alpha,
    # clipped to the valid range, so alpha=15.0 extrapolates far outside
    # [0, 1]. Confirm this is intended.
    blended = Image.blend(bck, fgk, alpha=15.0)
    blended.save("resultant.png")
    cv2.imwrite("mask.png", mask4)
    cv2.imwrite("blur_original.png", blur_original)
    # NOTE(review): ``callling`` is not defined in the visible source; confirm
    # it exists (the name looks like a typo).
    callling()

    #cv2.imwrite("final_image.png",vis)

    # save the result
    # cv2.imwrite('image_masked.png', background)
    # m_image = Image.open('image_masked.png')
    # blurred_image.paste(m_image, box=None, mask=mask_to_pil)
    # blurred_image.show()
    # store_cropped_image(masked_image,contourlist)
    ax.imshow(masked_image.astype(np.uint8))
    plt.show()
    def initialize_brain_labeling_gui(self):
        """Build the labeling GUI and draw any pre-existing polygons.

        The setup covers:

        - the context-menu actions;
        - the colour table (``100colors.txt``) and ``label_cmap``;
        - the canvas event callbacks and display radio buttons;
        - the label buttons, save/new-label/quit/params buttons and the
          window title;
        - the image axes.

        Each (polygon_type, vertices) entry of
        ``self.curr_labeling['initial_polygons']`` (a mapping keyed by label)
        is handled as follows:

        - drawn with the per-type styling (unfilled, line width 2, edge
          colour ``self.colors[label + 1]``);
        - made pickable;
        - registered in ``polygon_list``, ``polygon_bbox_list``,
          ``polygon_labels`` and ``polygon_types``;
        - given a pickable circle at each vertex.

        Raises:
            ValueError: if an initial polygon has an unhandled polygon_type.
                Previously a bare string was raised, which itself fails with
                TypeError.
        """
        self.menu = QMenu()
        self.endDraw_Action = self.menu.addAction("Confirm closed contour")
        self.endDrawOpen_Action = self.menu.addAction("Confirm open boundary")
        self.confirmTexture_Action = self.menu.addAction(
            "Confirm textured region without contour")
        self.confirmTextureWithContour_Action = self.menu.addAction(
            "Confirm textured region with contour")
        self.confirmDirectionality_Action = self.menu.addAction(
            "Confirm striated region")

        self.deletePolygon_Action = self.menu.addAction("Delete polygon")
        self.deleteVertex_Action = self.menu.addAction("Delete vertex")
        self.addVertex_Action = self.menu.addAction("Add vertex")

        self.crossReference_Action = self.menu.addAction("Cross reference")

        # A set of high-contrast colors proposed by Green-Armytage
        self.colors = np.loadtxt('100colors.txt', skiprows=1)
        self.label_cmap = ListedColormap(self.colors, name='label_cmap')

        # Colours are indexed with label + 1, so label -1 uses self.colors[0].
        self.curr_label = -1

        self.setupUi(self)

        self.fig = self.canvaswidget.fig
        self.canvas = self.canvaswidget.canvas

        self.canvas.mpl_connect('scroll_event', self.zoom_fun)
        self.bpe_id = self.canvas.mpl_connect('button_press_event',
                                              self.press_fun)
        self.bre_id = self.canvas.mpl_connect('button_release_event',
                                              self.release_fun)
        self.canvas.mpl_connect('motion_notify_event', self.motion_fun)

        self.display_buttons = [
            self.img_radioButton, self.textonmap_radioButton,
            self.dirmap_radioButton, self.labeling_radioButton
        ]
        self.img_radioButton.setChecked(True)

        for b in self.display_buttons:
            b.toggled.connect(self.display_option_changed)

        self.spOnOffSlider.valueChanged.connect(self.display_option_changed)

        self.canvas.mpl_connect('pick_event', self.on_pick)

        ########## Label buttons #############

        self.n_labelbuttons = 0

        self.labelbuttons = []
        self.labeledits = []

        if self.parent_labeling is not None:
            for n in self.parent_labeling['labelnames']:
                self._add_labelbutton(desc=n)
        else:
            for n in self.dm.labelnames:
                self._add_labelbutton(desc=n)

        self.saveButton.clicked.connect(self.save_callback)
        self.newLabelButton.clicked.connect(self.newlabel_callback)
        self.quitButton.clicked.connect(self.close)
        self.buttonParams.clicked.connect(self.paramSettings_clicked)

        self.setWindowTitle(self.windowTitle() + ', parent_labeling = %s' %
                            (self.parent_labeling_name))

        self.fig.clear()
        self.fig.set_facecolor('white')

        self.axis = self.fig.add_subplot(111)
        self.axis.axis('off')

        self.axis.imshow(self.masked_img, cmap=plt.cm.Greys_r, aspect='equal')

        initial_polygons = self.curr_labeling['initial_polygons']
        if initial_polygons is not None:
            # .items() works on Python 2 and 3; .iteritems() is Python 2 only.
            for label, typed_polygons in initial_polygons.items():
                for polygon_type, vertices in typed_polygons:
                    # Per-type options layered on the shared settings below.
                    if polygon_type == PolygonType.CLOSED:
                        style = {'closed': True}
                    elif polygon_type == PolygonType.OPEN:
                        style = {'closed': False}
                    elif polygon_type == PolygonType.TEXTURE:
                        style = {'closed': True, 'hatch': '/'}
                    elif polygon_type == PolygonType.TEXTURE_WITH_CONTOUR:
                        style = {'closed': True, 'hatch': 'x'}
                    elif polygon_type == PolygonType.DIRECTION:
                        style = {'closed': False, 'linestyle': 'dashed'}
                    else:
                        raise ValueError(
                            'polygon_type must be one of PolygonType.CLOSED, '
                            'OPEN, TEXTURE, TEXTURE_WITH_CONTOUR or '
                            'DIRECTION; got %r' % (polygon_type,))

                    polygon = Polygon(vertices,
                                      fill=False,
                                      edgecolor=self.colors[label + 1],
                                      linewidth=2,
                                      **style)

                    # Bounding box as [xmin, ymin, xmax, ymax].
                    xys = polygon.get_xy()
                    x0_y0_x1_y1 = np.r_[xys.min(axis=0), xys.max(axis=0)]

                    polygon.set_picker(10.)

                    self.axis.add_patch(polygon)
                    self.polygon_list.append(polygon)
                    self.polygon_bbox_list.append(x0_y0_x1_y1)
                    self.polygon_labels.append(label)
                    self.polygon_types.append(polygon_type)

                    curr_polygon_vertex_circles = []
                    for x, y in vertices:
                        curr_vertex_circle = plt.Circle(
                            (x, y),
                            radius=10,
                            color=self.colors[label + 1],
                            alpha=.8)
                        curr_vertex_circle.set_picker(True)
                        self.axis.add_patch(curr_vertex_circle)
                        curr_polygon_vertex_circles.append(curr_vertex_circle)

                    self.all_polygons_vertex_circles.append(
                        curr_polygon_vertex_circles)

        self.fig.subplots_adjust(left=0, bottom=0, right=1, top=1)

        self.newxmin, self.newxmax = self.axis.get_xlim()
        self.newymin, self.newymax = self.axis.get_ylim()

        self.canvas.draw()
        self.show()
Ejemplo n.º 14
0
def vis_one_image(im_name, fout, bboxes, dpi=200, weights=None):
    """Draw each box of *bboxes* as a filled contour over an image and save it.

    For each box a binary mask is built. The mask is resized with PIL and
    its contours are found with OpenCV. Each contour is drawn as a
    semi-transparent Polygon with a lightened colour from ``colormap()``.

    Parameters:
        im_name: path of the image, read with cv2.imread.
        fout: output path. '.jpg' is replaced with '_<max weight>.jpg'
            (0.000 when *weights* is None).
        bboxes: sequence of (x, y, w, h) boxes.
        dpi: resolution used for the figure size and when saving.
        weights: optional per-box weights. When given:
            - boxes are drawn in ascending weight order;
            - line width is multiplied by 4**weight and fill alpha divided
              by it.
            The last-drawn box gets a red edge; the others get white edges.

    Relies on module-level ``TRANS`` / ``FLIP`` flags and a ``colormap``
    helper. Returns None and does nothing when *bboxes* is empty.

    Raises:
        IOError: if *im_name* cannot be read.
    """
    if not len(bboxes):
        return

    im = cv2.imread(im_name)
    if im is None:
        raise IOError('Could not read image: {}'.format(im_name))
    H, W, _ = im.shape
    color_list = colormap(rgb=True) / 255

    fig = plt.figure(frameon=False)
    fig.set_size_inches(im.shape[1] / dpi, im.shape[0] / dpi)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.axis('off')
    fig.add_axes(ax)
    # NOTE(review): cv2 loads BGR but imshow treats it as RGB, so background
    # colours appear channel-swapped. Confirm whether that is intended.
    ax.imshow(im)

    mask_color_id = 0
    if weights is None:
        # Was ``masks.shape[0]``, but no ``masks`` exists here (NameError).
        obj_ids = range(len(bboxes))
    else:
        obj_ids = np.argsort(weights)

    ws = [0]
    for oid in obj_ids:
        x, y, w, h = bboxes[oid]
        mask = np.zeros([H, W])
        # NOTE(review): indexes [x, y] rather than [row=y, col=x];
        # presumably paired with the TRANS / FLIP flags. Left unchanged.
        mask[x:x + w, y:y + h] = 1
        mask = mask.astype('uint8')
        if mask.sum() == 0:
            continue

        if weights is not None:
            ws.append(weights[oid])

        # Lighten the colour towards white. This builds a new array; the old
        # in-place update wrote through a view into color_list, so a colour
        # reused after wrap-around was lightened twice.
        w_ratio = .4
        base_color = color_list[mask_color_id % len(color_list), 0:3]
        color_mask = base_color * (1 - w_ratio) + w_ratio
        mask_color_id += 1

        # LANCZOS is the filter ANTIALIAS aliased; ANTIALIAS was removed in
        # Pillow 10.
        e_pil = Image.fromarray(mask)
        e_pil_up = e_pil.resize((H, W) if TRANS else (W, H), Image.LANCZOS)
        e = np.array(e_pil_up)

        # findContours returns 3 values in OpenCV 3 and 2 in OpenCV 2.4/4;
        # the contour list is always second from the end.
        contour = cv2.findContours(e.copy(), cv2.RETR_CCOMP,
                                   cv2.CHAIN_APPROX_NONE)[-2]

        if len(contour) > 1:
            print('# contour:', len(contour))
        for c in contour:
            if FLIP:
                assert c.shape[1] == 1 and c.shape[2] == 2, \
                    'c.shape: {}'.format(c.shape)
                # Swap each point's two coordinates in place.
                for pid in range(c.shape[0]):
                    c[pid][0][0], c[pid][0][1] = c[pid][0][1], c[pid][0][0]
            linewidth = 1.2
            alpha = 0.5
            if oid == obj_ids[-1]:
                # most probable obj
                edgecolor = (1, 0, 0, 1)  # 'r'
            else:
                edgecolor = (1, 1, 1, 1)  # 'w'
            if weights is not None:
                linewidth *= (4**weights[oid])
                alpha /= (4**weights[oid])

            polygon = Polygon(
                c.reshape((-1, 2)),
                fill=True,
                facecolor=(color_mask[0], color_mask[1], color_mask[2], alpha),
                edgecolor=edgecolor,
                linewidth=linewidth,
            )
            ax.add_patch(polygon)

    fig.savefig(fout.replace('.jpg', '_{:.3f}.jpg'.format(max(ws))), dpi=dpi)
    plt.close('all')
Ejemplo n.º 15
0
class dispASDF(pyasdf.ASDFDataSet):
    """An ASDF dataset for analysing inter-station dispersion curves.

    Auxiliary data written by the methods of this class:
      * ``AgeGc``      (path net1/sta1/net2/sta2): great-circle points with the
                       crustal age appended as a last column; parameters
                       age_avg, depth_avg, dist, L_gc.
      * ``DISPArray``  (path net1/sta1/net2/sta2): rows T, grV, phV, snr_p, snr_n.
      * ``DISPinterp`` (path net1/sta1/net2/sta2): the same rows interpolated
                       onto a common period grid.
      * ``FitResult`` / ``FinalStas`` (path <period>/<vel_type>): output of
                       ``fit_Harmon``.
    """

    # ------------------------------------------------------------------ helpers
    @staticmethod
    def _as_str(value):
        """Return ``value`` decoded to str if it is bytes (e.g. a code read back from an HDF5 char array)."""
        if isinstance(value, bytes):
            return value.decode('ascii')
        return value

    def _age_gc_parameters(self, code):
        """Return the AgeGc parameters of the station pair described by one FinalStas row."""
        net1, sta1, net2, sta2 = (self._as_str(c) for c in code)
        return self.auxiliary_data['AgeGc'][net1][sta1][net2][sta2].parameters

    def _new_basemap(self, grid_step):
        """Return a Lambert-conformal Basemap over the bounding box.

        Grid lines are drawn every ``grid_step`` degrees. Coastlines, countries,
        states and the PB2002 plate boundaries are also drawn.
        """
        distEW, _, _ = obspy.geodetics.gps2dist_azimuth(self.minlat, self.minlon, self.minlat, self.maxlon)  # distance is in m
        distNS, _, _ = obspy.geodetics.gps2dist_azimuth(self.minlat, self.minlon, self.maxlat + 2., self.minlon)  # distance is in m
        m = Basemap(width=distEW, height=distNS, rsphere=(6378137.00, 6356752.3142), resolution='i', projection='lcc',
                    lat_1=self.minlat - 1, lat_2=self.maxlat + 1,
                    lon_0=(self.minlon + self.maxlon) / 2, lat_0=(self.minlat + self.maxlat) / 2)
        m.drawparallels(np.arange(-80.0, 80.0, grid_step), linewidth=1, dashes=[2, 2], labels=[1, 0, 0, 0], fontsize=15)
        m.drawmeridians(np.arange(-170.0, 170.0, grid_step), linewidth=1, dashes=[2, 2], labels=[0, 0, 1, 0], fontsize=15)
        m.drawcoastlines(linewidth=1.0)
        m.drawcountries(linewidth=1.0)
        m.drawstates(linewidth=1.0)
        m.readshapefile('./Plates/PB2002_boundaries', name='PB2002_boundaries', drawbounds=True,
                        linewidth=1, color='orange')  # draw plate boundary on basemap
        return m

    @staticmethod
    def _write_without_outliers(in_name, out_name, cut):
        """Copy ``in_name`` to ``out_name``, keeping only lines whose velocity is close to the mean and the median.

        The velocity is the 6th of 11 whitespace-separated fields on each line.
        A line is kept when it differs by less than ``cut`` from both the mean
        and the median of that column.
        """
        from itertools import compress
        with open(in_name) as f_in:
            raw_lines = f_in.readlines()
        vel = np.array(''.join(raw_lines).split()[5::11], dtype='f')
        keep = np.logical_and(abs(vel - vel.mean()) < cut, abs(vel - np.median(vel)) < cut)
        kept = compress([line.rstrip('\n') for line in raw_lines], keep)
        with open(out_name, 'w') as f_out:
            f_out.writelines('%s\n' % item for item in kept)

    # ------------------------------------------------------------ study region
    def set_poly(self, lst, minlon, minlat, maxlon, maxlat):
        """Define the study-region polygon and the bounding box used for map views.

        Parameters:  lst  -- list of (lon, lat) vertices of the study-region polygon
                     minlon, minlat, maxlon, maxlat -- minimum and maximum coordinates
                             used to crop the age and topography models
        """
        self.poly = Polygon(lst)
        # The Path is what the containment tests (contains_point/contains_path) use.
        self.perimeter = self.poly.get_path()
        self.minlon = minlon
        self.minlat = minlat
        self.maxlon = maxlon
        self.maxlat = maxlat
        return

    def point_in(self):
        """Test whether a given point is inside the study polygon. Not implemented; returns None."""
        return

    def path_in(self):
        """Test whether the segment between two points is inside the polygon. Not implemented; returns None."""
        return

    # ------------------------------------------------------------------ models
    def read_age_mdl(self, fname='./age.3.2.nc'):
        """Read the oceanic crustal-age grid and crop it to the bounding box.

        Parameters:  fname -- netCDF file with variables x (lon), y (lat) and z (age).
                              The default file is a 2-minute resolution grid.
        Sets self.age_data (z / 100., 9999. where z is masked), self.age_lon
        (longitudes shifted to 0-360) and self.age_lat. age_data rows follow
        latitude and columns follow longitude.
        """
        dset = Dataset(fname, 'r')
        try:
            longitude = dset.variables['x'][:]
            latitude = dset.variables['y'][:]
            z = dset.variables['z'][:]  # masked array
        finally:
            dset.close()
        longitude[longitude < 0] += 360.
        data = z.data / 100.
        data[z.mask] = 9999.  # sentinel for grid cells without an age
        lat_in = (latitude >= self.minlat) * (latitude <= self.maxlat)
        lon_in = (longitude >= self.minlon) * (longitude <= self.maxlon)
        self.age_data = data[lat_in, :][:, lon_in]
        self.age_lon = longitude[lon_in]
        self.age_lat = latitude[lat_in]
        return

    def read_topo_mdl(self, fname='/work2/wang/Code/ToolKit/ETOPO1_Ice_g_gmt4.grd'):
        """Read the ETOPO1 grid and return a linear interpolator of topography.

        Parameters:  fname -- netCDF/GMT grid with variables x (lon), y (lat), z
        The returned LinearNDInterpolator takes (lat, lon) pairs, with longitudes
        in 0-360, cropped to the bounding box.
        """
        etopo1 = Dataset(fname, 'r')
        try:
            lons = etopo1.variables["x"][:]
            lats = etopo1.variables["y"][:]
            z = etopo1.variables["z"][:]
        finally:
            etopo1.close()
        lons = lons + 360. * (lons < 0)  # shift negative longitudes to 0-360
        lat_in = (lats >= self.minlat) * (lats <= self.maxlat)
        lon_in = (lons >= self.minlon) * (lons <= self.maxlon)
        etopoz = z[lat_in, :][:, lon_in]
        lats = lats[lat_in]
        lons = lons[lon_in]
        grid_lat, grid_lon = np.meshgrid(lats, lons, indexing='ij')
        points = np.vstack((grid_lat.flatten(), grid_lon.flatten())).T  # (lat, lon) pairs
        return LinearNDInterpolator(points, etopoz.flatten())

    # ---------------------------------------------------------------- stations
    def read_stations(self, stafile, source='CIEI', chans=('BHZ', 'BHE', 'BHN')):
        """Read stations inside the study polygon and add them as a station inventory.

        Parameters:  stafile -- text file with one station per line:
                                station code, lon, lat, network code (extra columns ignored)
                     source  -- source string of the obspy Inventory
                     chans   -- channel codes created for every station
        Requires set_poly() (uses self.perimeter). A repeated station with
        matching coordinates (within 0.01 deg) is skipped with a warning.
        Raises ValueError if a repeated station's lon or lat differs by more than 0.01.
        """
        site = obspy.core.inventory.util.Site(name='01')
        creation_date = obspy.core.utcdatetime.UTCDateTime(0)
        inv = obspy.core.inventory.inventory.Inventory(networks=[], source=source)
        total_number_of_channels = len(chans)
        seen = {}  # netsta -> (lon, lat) as read from the file, for duplicate detection
        with open(stafile, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines
                stacode = fields[0]
                lon = float(fields[1])
                lat = float(fields[2])
                if not self.perimeter.contains_point((lon, lat)):
                    continue
                netcode = fields[3]
                netsta = netcode + '.' + stacode
                if netsta in seen:
                    lon0, lat0 = seen[netsta]
                    if abs(lon0 - lon) > 0.01 or abs(lat0 - lat) > 0.01:
                        raise ValueError('Incompatible Station Location:' + netsta + ' in Station List!')
                    print('Warning: Repeated Station:' + netsta + ' in Station List!')
                    continue
                seen[netsta] = (lon, lat)
                if lon > 180.:
                    lon -= 360.
                channels = [obspy.core.inventory.channel.Channel(code=chan, location_code='01', latitude=lat,
                                                                 longitude=lon, elevation=0.0, depth=0.0)
                            for chan in chans]
                station = obspy.core.inventory.station.Station(code=stacode, latitude=lat, longitude=lon, elevation=0.0,
                                                               site=site, channels=channels,
                                                               total_number_of_channels=total_number_of_channels,
                                                               creation_date=creation_date)
                network = obspy.core.inventory.network.Network(code=netcode, stations=[station])
                inv += obspy.core.inventory.inventory.Inventory(networks=[network], source=source)
        print('Writing obspy inventory to ASDF dataset')
        self.add_stationxml(inv)
        print('End writing obspy inventory to ASDF dataset')
        return

    def get_ages(self, lons, lats):
        """Return nearest-neighbour crustal ages at the given points.

        Parameters:  lons -- longitude vector, in the same convention as self.age_lon
                             (0-360 after read_age_mdl)
                     lats -- latitude vector
        """
        grid_lon, grid_lat = np.meshgrid(self.age_lon, self.age_lat)
        nodes = np.column_stack((grid_lon.reshape(grid_lon.size), grid_lat.reshape(grid_lat.size)))
        f = NearestNDInterpolator(nodes, self.age_data.reshape(self.age_data.size), rescale=False)
        return f(np.column_stack((lons, lats)))

    # ------------------------------------------------------- dispersion curves
    def read_paths(self, disp_dir='/work3/wang/JdF/FTAN', res=3.5):
        """Read dispersion measurements and average crustal age/depth along great-circle paths.

        Paramenters:  disp_dir -- directory for the dispersion measurement results
                      res      -- resolution for the great circle sampling points
        For every station pair (staid1 < staid2) this writes AgeGc and DISPArray
        auxiliary data. Pairs are skipped when:
          * the path leaves the age model;
          * no DISP file exists;
          * no period has SNR > 3 on either lag.
        A station whose path to any partner leaves the study polygon is deleted
        from the dataset.
        Requires read_age_mdl() and set_poly().
        """
        # Nearest-neighbour age lookup on (lat, lon) pairs.
        grid_lon, grid_lat = np.meshgrid(self.age_lon, self.age_lat)
        nodes = np.vstack((grid_lat.reshape(grid_lat.size), grid_lon.reshape(grid_lon.size))).T
        f = NearestNDInterpolator(nodes, self.age_data.reshape(self.age_data.size), rescale=False)
        f_topo = self.read_topo_mdl()
        staLst = self.waveforms.list()
        removed = set()  # stations deleted from the dataset during this call
        print('Reading dispersion curves & averaging crustal age along paths')
        for staid1 in staLst:
            lat1 = self.waveforms[staid1].coordinates['latitude']
            lon1 = self.waveforms[staid1].coordinates['longitude']
            netcode1, stacode1 = staid1.split('.')
            if lon1 < 0.:
                lon1 += 360.
            for staid2 in staLst:
                if staid1 >= staid2 or staid2 in removed:
                    continue
                lat2 = self.waveforms[staid2].coordinates['latitude']
                lon2 = self.waveforms[staid2].coordinates['longitude']
                if lon2 < 0.:
                    lon2 += 360.
                gc_path, dist = get_gc_path(lon1, lat1, lon2, lat2, res)
                if not self.perimeter.contains_path(Path(gc_path)):
                    # The great circle path travels beyond the study region: the
                    # station is discarded, so none of its remaining pairs are processed.
                    print("Station " + staid1 + " not in study region! Will be discarded!")
                    del self.waveforms[staid1]
                    removed.add(staid1)
                    break
                netcode2, stacode2 = staid2.split('.')
                gc_points = np.array(gc_path)
                # The interpolators take (lat, lon); gc_points columns are reversed
                # accordingly (assumes get_gc_path returns (lon, lat) points).
                ages = f(gc_points[:, ::-1])
                depths = f_topo(gc_points[:, ::-1])
                if ages.max() > 300:
                    continue  # the path went outside the age model (9999. sentinel)
                age_avg = ages.mean()
                depth_avg = depths.mean()
                prefix = disp_dir + '/' + stacode1 + '/' + 'COR_' + stacode1 + '_' + stacode2
                disp_file = prefix + '.SAC_2_DISP.1'
                snr_file = prefix + '.SAC_2_amp_snr'
                if not os.path.isfile(disp_file):
                    continue
                arr = np.loadtxt(disp_file)
                arr_snr = np.loadtxt(snr_file)
                # Keep periods where the larger of the positive/negative lag SNRs exceeds 3.
                transp = np.vstack((arr_snr[:, 2], arr_snr[:, 4])).max(axis=0) > 3
                if not transp.any():
                    continue
                per_vec = arr[transp, 2]
                grv_vec = arr[transp, 3]
                phv_vec = arr[transp, 4]
                # Filter the SNR columns with the same mask so every row stays aligned.
                snrp_vec = arr_snr[transp, 2]
                snrn_vec = arr_snr[transp, 4]
                # Truncate at the first group-velocity drop larger than 0.2,
                # ignoring drops at periods below 6 s.
                d_g = np.diff(grv_vec)
                d_g[per_vec[1:] < 6.] = 1.
                drops = np.flatnonzero(d_g < -0.2)
                if drops.size:
                    keep = drops[0] + 1
                    per_vec = per_vec[:keep]
                    grv_vec = grv_vec[:keep]
                    phv_vec = phv_vec[:keep]
                    snrp_vec = snrp_vec[:keep]
                    snrn_vec = snrn_vec[:keep]
                # Stop before the first period whose wavelength (period * phase
                # velocity) exceeds the inter-station distance.
                too_long = np.flatnonzero(per_vec * phv_vec > dist)
                if too_long.size:
                    ind1 = too_long[0]
                    if ind1 == 0:
                        continue
                else:
                    ind1 = per_vec.size
                disp_arr = np.vstack((per_vec[:ind1], grv_vec[:ind1], phv_vec[:ind1], snrp_vec[:ind1], snrn_vec[:ind1]))
                staid_aux = netcode1 + '/' + stacode1 + '/' + netcode2 + '/' + stacode2
                parameters1 = {'age_avg': age_avg, 'depth_avg': depth_avg, 'dist': dist, 'L_gc': len(gc_path)}
                ages_path = np.concatenate((gc_points, ages.reshape(-1, 1)), axis=1)
                self.add_auxiliary_data(data=ages_path, data_type='AgeGc', path=staid_aux, parameters=parameters1)
                parameters2 = {'T': 0, 'grV': 1, 'phV': 2, 'snr_p': 3, 'snr_n': 4, 'dist': dist, 'Np': ind1}
                self.add_auxiliary_data(data=disp_arr, data_type='DISPArray', path=staid_aux, parameters=parameters2)
        print('End of reading dispersion curves')
        return

    def intp_disp(self, pers=None, verbose=False):
        """Interpolate every stored DISPArray curve onto a common period grid.

        Parameter:  pers    -- period array. None or empty means 6-16 s every 2 s,
                               followed by 18-27 s every 3 s.
                    verbose -- print each station pair as it is processed
        Pairs with fewer than 5 measured points are skipped with a UserWarning.
        For the other pairs a DISPinterp entry of shape (5, k) is written: rows
        are period, group velocity, phase velocity, snr_p and snr_n, at the k
        periods of ``pers`` lying strictly inside the measured period range.
        """
        pers = np.array([]) if pers is None else np.asarray(pers)
        if pers.size == 0:
            pers = np.append(np.arange(6.) * 2. + 6., np.arange(4.) * 3. + 18.)
        self.pers = pers
        staLst = self.waveforms.list()
        for staid1 in staLst:
            netcode1, stacode1 = staid1.split('.')
            for staid2 in staLst:
                netcode2, stacode2 = staid2.split('.')
                if staid1 >= staid2:
                    continue
                try:
                    subdset = self.auxiliary_data['DISPArray'][netcode1][stacode1][netcode2][stacode2]
                except (AttributeError, KeyError):
                    continue  # no measurement stored for this pair
                data = subdset.data[()]
                index = subdset.parameters
                pair_name = netcode1 + '.' + stacode1 + '_' + netcode2 + '.' + stacode2
                if verbose:
                    print('Interpolating dispersion curve for ' + pair_name)
                Np = int(index['Np'])
                if Np < 5:
                    warnings.warn('Not enough datapoints for: ' + pair_name, UserWarning, stacklevel=1)
                    continue
                obsT = data[index['T']][:Np]
                inbound = (pers > obsT[0]) * (pers < obsT[-1])
                rows = [pers[inbound]]
                for key in ('grV', 'phV', 'snr_p', 'snr_n'):
                    rows.append(np.interp(pers, obsT, data[index[key]][:Np])[inbound])
                interpdata = np.vstack(rows)
                outindex = {'T': 0, 'grV': 1, 'phV': 2, 'snr_p': 3, 'snr_n': 4, 'dist': index['dist'], 'Np': pers.size}
                staid_aux = netcode1 + '/' + stacode1 + '/' + netcode2 + '/' + stacode2
                self.add_auxiliary_data(data=interpdata, data_type='DISPinterp', path=staid_aux, parameters=outindex)
        return

    # ---------------------------------------------------------------- mapping
    def get_basemap(self, model='age'):
        """Return a basemap with the crustal-age model drawn on it.

        Parameters:   model -- 'age' (implemented) or 'etopo' (not implemented yet)
        Raises NotImplementedError for 'etopo' and ValueError for any other model.
        Requires read_age_mdl() and set_poly().
        """
        if model == 'age':
            m = self._new_basemap(5.0)
            x, y = m(*np.meshgrid(self.age_lon, self.age_lat))
            data = np.ma.masked_array(self.age_data, self.age_data > 9000)  # hide the 9999. sentinel
            img = m.pcolormesh(x, y, data, shading='gouraud', cmap='jet_r', vmin=0, vmax=11, alpha=0.5)
            m.drawmapboundary(fill_color="white")
            cbar = m.colorbar(img, location='bottom', pad="3%")
            cbar.set_alpha(1)
            cbar.draw_all()
            cbar.set_label('Age (Ma)')
            return m
        if model == 'etopo':
            raise NotImplementedError('etopo basemap is not implemented yet')
        raise ValueError('Only age or etopo model can be used for constructing basemap')

    def plot_stations(self, ppoly=False):
        """Plot all stations of the dataset on the age basemap.

        Parameters:  ppoly -- if True, also draw the study-region polygon
        """
        m = self.get_basemap()
        for staid in self.waveforms.list():
            coords = self.waveforms[staid].coordinates
            sta_x, sta_y = m(coords['longitude'], coords['latitude'])
            m.plot(sta_x, sta_y, '^', color='olive')
        if ppoly:
            point_arr = self.poly.get_xy()
            xx, yy = m(point_arr[:, 0], point_arr[:, 1])
            poly = Polygon(np.vstack((xx, yy)).T, facecolor='none', edgecolor='k')
            plt.gca().add_patch(poly)
        plt.title('Station map', fontsize=16)
        plt.show()
        return

    def get_vel_age(self, c0, c1, c2, vmin, vmax):
        """Plot the velocity map v = c0 + c1*sqrt(age) + c2*age inside the study polygon.

        Parameters:  c0, c1, c2 -- coefficients of the velocity-age relation
                     vmin, vmax -- colour-scale limits (km/s)
        Grid points outside the polygon are set to 0 and then hidden by masking
        velocities below 0.1.
        """
        age_data = self.age_data
        llon, llat = np.meshgrid(self.age_lon, self.age_lat, indexing='ij')
        ccords = np.dstack((llon, llat)).reshape(llon.size, 2)
        inside = self.perimeter.contains_points(ccords).reshape(llon.shape[0], llon.shape[1])
        # `inside` is (lon, lat)-ordered; transpose it to age_data's (lat, lon) layout.
        vel_arr = (c0 + c1 * np.sqrt(age_data) + c2 * age_data) * np.transpose(inside.astype(float))
        m = self._new_basemap(2.0)
        x, y = m(*np.meshgrid(self.age_lon, self.age_lat))
        data = np.ma.masked_array(vel_arr, vel_arr < 0.1)
        img = m.pcolormesh(x, y, data, shading='gouraud', cmap='jet_r', vmin=vmin, vmax=vmax)
        m.drawmapboundary(fill_color="white")
        cbar = m.colorbar(img, location='bottom', pad="3%")
        cbar.set_label('Vel (km/s)')
        plt.show()
        return

    def count_path(self):
        """Get the number of paths at all periods for all stations. Not implemented; returns None."""
        return

    # ---------------------------------------------------------- velocity-age
    def fit_Harmon(self, period, vel_type='phase'):
        """Fit the Harmon relation v = c0 + c1*sqrt(age) + c2*age [Ye et al, GGG, 2013, eq(1)].

        Parameters:  period   -- period to analyse; it must equal a period of the DISPinterp grid exactly
                     vel_type -- type of velocity to invert for, 'group' or 'phase'

        Pair selection:
          * a pair is used if both SNRs are >= 5 and its phase velocity is >= its
            group velocity;
          * pairs whose residual from a crude quadratic fit in sqrt(age) deviates
            from the mean residual by 3 std or more are discarded;
          * the remaining pairs go to the optimize.brute grid search over sq_misfit.

        Writes FitResult (average age and velocity of each kept pair, with c0/c1/c2
        as parameters) and FinalStas (codes of the kept pairs) under
        '<period zero-padded to 2>/<vel_type>'.
        Returns the full optimize.brute output. Returns None, with a warning, if
        FitResult already exists for this period and velocity type.
        Raises AttributeError for an unknown vel_type.
        """
        from itertools import compress
        if vel_type == 'group':
            ind_V = 1
        elif vel_type == 'phase':
            ind_V = 2
        else:
            raise AttributeError('velocity type can only be group or phase')
        str_per = str(period).zfill(2)
        try:
            self.auxiliary_data.FitResult[str_per][vel_type]
        except (AttributeError, KeyError):
            pass  # not computed yet
        else:
            warnings.warn("Funciton fit_Harmon has been run for this period and velocity type!")
            return
        staLst = self.waveforms.list()
        dist_lst = []
        int_dis_lst = []
        ages_lst = []
        age_avg_lst = []
        V_lst = []
        pair_lst = []  # (netcode1, stacode1, netcode2, stacode2) of every used pair
        for staid1 in staLst:
            netcode1, stacode1 = staid1.split('.')
            for staid2 in staLst:
                if staid1 >= staid2:
                    continue
                netcode2, stacode2 = staid2.split('.')
                try:
                    T_V = self.auxiliary_data['DISPinterp'][netcode1][stacode1][netcode2][stacode2].data[()]  # T, grV, phV, snr_p, snr_n
                except (AttributeError, KeyError):
                    continue
                ind_T = np.where(T_V[0, :] == period)[0]
                if ind_T.size != 1:
                    continue
                i_T = ind_T[0]
                try:
                    subset_age = self.auxiliary_data['AgeGc'][netcode1][stacode1][netcode2][stacode2]
                except (AttributeError, KeyError):
                    print(stacode1 + '_' + stacode2 + 'pair has interpolated dispersion curve but dont have age along path')
                    continue
                if T_V[3, i_T] < 5. or T_V[4, i_T] < 5.:  # snr_p or snr_n smaller than 5.
                    continue
                if T_V[2, i_T] < T_V[1, i_T]:  # phase velocity smaller than group velocity
                    continue
                V_lst.append(T_V[ind_V, i_T])
                dist = subset_age.parameters['dist']
                dist_lst.append(dist)
                int_dis_lst.append(dist / (subset_age.parameters['L_gc'] - 1))  # spacing of path samples
                ages_lst.append(subset_age.data[()][:, 2])
                age_avg_lst.append(subset_age.parameters['age_avg'])
                pair_lst.append((netcode1, stacode1, netcode2, stacode2))
        if not (len(ages_lst) == len(V_lst) and len(V_lst) == len(int_dis_lst)):
            raise AttributeError('The number of inter-station paths are incompatible for inverting c0, c1 & c2')
        dist = np.array(dist_lst)
        d_dist = np.array(int_dis_lst)
        age_avgs = np.array(age_avg_lst)
        V = np.array(V_lst).reshape(len(V_lst))

        # Crude quadratic fit in sqrt(age), used only to reject outliers.
        p = np.poly1d(np.polyfit(np.sqrt(age_avgs), V, 2))
        diffs = np.abs(p(np.sqrt(age_avgs)) - V)
        ind = np.abs(diffs - np.mean(diffs)) < 3 * np.std(diffs)
        params = (dist[ind], d_dist[ind], V[ind], list(compress(ages_lst, ind)))
        cranges = (slice(0.5, 4, 0.1), slice(-0.5, 0.5, 0.05), slice(-1., 1, 0.05))  # search grid for c0, c1, c2
        resbrute = optimize.brute(sq_misfit, cranges, args=params, full_output=True, finish=None)
        print("Age-depedent coefficients found for period {} sec.".format(period))
        c0, c1, c2 = resbrute[0]
        data_out = np.vstack((age_avgs[ind], V[ind])).T  # [-1,2] array, storing average age and velocity

        pairs_final = list(compress(pair_lst, ind))
        code_arrays = []
        # Column order and widths: netcode1 (2), stacode1 (5), netcode2 (2), stacode2 (5).
        for col, itemsize in ((0, 2), (1, 5), (2, 2), (3, 5)):
            codes = np.chararray(len(pairs_final), itemsize=itemsize)
            codes[:] = [pair[col] for pair in pairs_final]
            code_arrays.append(codes)
        stas_arr_final = np.vstack(code_arrays).T

        para_aux = str_per + '/' + vel_type
        parameters = {'c0': c0, 'c1': c1, 'c2': c2}
        self.add_auxiliary_data(data=data_out, data_type='FitResult', path=para_aux, parameters=parameters)
        out_index = {'netcode1': 0, 'stacode1': 1, 'netcode2': 2, 'stacode2': 3}
        self.add_auxiliary_data(data=stas_arr_final, data_type='FinalStas', path=para_aux, parameters=out_index)
        return resbrute

    def plot_vel_age(self, period, vel_type='phase'):
        """Plot measured velocity vs. crustal age together with the fitted relation (ages 0-10)."""
        str_per = str(period).zfill(2)
        fit = self.auxiliary_data.FitResult[str_per][vel_type]
        age_vel = fit.data[()]
        c0 = fit.parameters['c0']
        c1 = fit.parameters['c1']
        c2 = fit.parameters['c2']
        plt.plot(age_vel[:, 0], age_vel[:, 1], 'r.')
        t0 = np.linspace(0, 10, 50)
        plt.plot(t0, c0 + c1 * np.sqrt(t0) + c2 * t0, 'b-')
        plt.title(str(period) + ' sec ' + vel_type + ' velocity vs. oceanic age')
        plt.xlim(left=0.)
        plt.xlabel('age (ma)')
        plt.ylabel('km/s')
        plt.show()
        return

    def plot_age_topo(self, period, vel_type='phase'):
        """Plot path-averaged ocean depth vs. path-averaged age for the pairs kept by fit_Harmon."""
        str_per = str(period).zfill(2)
        codes = self.auxiliary_data.FinalStas[str_per][vel_type].data[()]
        params = [self._age_gc_parameters(code) for code in codes]
        ages = np.array([p['age_avg'] for p in params], dtype=float)
        depths = np.array([p['depth_avg'] for p in params], dtype=float)
        plt.plot(ages, depths, 'r.')
        plt.title('Oceanic depth vs. age for ' + str(period) + ' sec paths')
        plt.ylabel('Depth (m)')
        plt.xlabel('Age (ma)')
        plt.show()
        return

    def plot_vel_topo(self, period, vel_type='phase'):
        """Plot velocity vs. path-averaged ocean depth for the pairs kept by fit_Harmon."""
        str_per = str(period).zfill(2)
        codes = self.auxiliary_data.FinalStas[str_per][vel_type].data[()]
        depths = np.array([self._age_gc_parameters(code)['depth_avg'] for code in codes], dtype=float)
        age_vel = self.auxiliary_data.FitResult[str_per][vel_type].data[()]
        plt.plot(depths, age_vel[:, 1], 'r.')
        plt.title(str(period) + ' sec ' + vel_type + ' velocity vs. oceanic depth')
        plt.xlabel('Depth (m)')
        plt.ylabel('km/s')
        plt.show()
        return

    def plot_all_vel(self, pers=(6, 8, 10, 14, 18, 24, 27), vel_type='phase'):
        """Plot velocity vs. age and the fitted relation for several periods in one figure.

        Colours cycle when more than 7 periods are given.
        """
        colors = ['red', 'green', 'wheat', 'blue', 'orange', 'black', 'cyan']
        for i, period in enumerate(pers):
            color = colors[i % len(colors)]
            str_per = str(period).zfill(2)
            fit = self.auxiliary_data.FitResult[str_per][vel_type]
            age_vel = fit.data[()]
            c0 = fit.parameters['c0']
            c1 = fit.parameters['c1']
            c2 = fit.parameters['c2']
            plt.plot(age_vel[:, 0], age_vel[:, 1], '.', color=color, label=str(period) + ' sec')
            t0 = np.linspace(0, np.max(age_vel[:, 0]), 100)
            plt.plot(t0, c0 + c1 * np.sqrt(t0) + c2 * t0, color=color)
        plt.legend(loc='best', fontsize=16)
        plt.title(vel_type + ' velocity vs. oceanic age', fontsize=16)
        plt.xlabel('age (ma)', fontsize=16)
        plt.ylabel('km/s', fontsize=16)
        plt.xlim(left=0.)
        plt.xticks(fontsize=18)
        plt.yticks(fontsize=18)
        plt.show()
        return

    def plot_paths(self, period, vel_type='phase'):
        """Plot the end stations of every pair kept by fit_Harmon for a given period."""
        str_per = str(period).zfill(2)
        codes = self.auxiliary_data.FinalStas[str_per][vel_type].data[()]
        m = self.get_basemap()
        for code in codes:
            net1, sta1, net2, sta2 = (self._as_str(c) for c in code)
            coords1 = self.waveforms[net1 + '.' + sta1].coordinates
            coords2 = self.waveforms[net2 + '.' + sta2].coordinates
            gc_path, _ = get_gc_path(coords1['longitude'], coords1['latitude'],
                                     coords2['longitude'], coords2['latitude'], 3)
            gc_points = np.array(gc_path)
            path_x, path_y = m(gc_points[:, 0], gc_points[:, 1])
            m.plot(path_x[0], path_y[0], '^', color='olive')
            m.plot(path_x[-1], path_y[-1], '^', color='olive')
        plt.title(str(period) + ' sec')
        plt.show()
        return

    def plot_age_model(self, lons, lats, resolution='i', cpt='/projects/howa1663/Code/ToolKit/Models/Age_Ocean_Crust/age1.cpt', m=None):
        """Overlay an ETOPO1 hill-shade on a basemap. Not finished.

        Parameters:  lons, lats -- only elements 0 and 1 are used, presumably as [min, max];
                                   the grid is cropped to these +/- 2 degrees
                     resolution, cpt -- currently unused
                     m -- Basemap to draw on; defaults to self.get_basemap()
        Prints a message instead of raising if the ETOPO1 file cannot be read.
        """
        if m is None:
            m = self.get_basemap()
        try:
            etopo1 = Dataset('/work2/wang/Code/ToolKit/ETOPO1_Ice_g_gmt4.grd', 'r')
            try:
                llons = etopo1.variables["x"][:]
                llats = etopo1.variables["y"][:]
                zz = etopo1.variables["z"][:]
            finally:
                etopo1.close()
            llons = llons + 360. * (llons < 0)  # shift negative longitudes to 0-360
            lat_in = (llats > (lats[0] - 2)) * (llats < (lats[1] + 2))
            lon_in = (llons > (lons[0] - 2)) * (llons < (lons[1] + 2))
            etopoz = zz[lat_in, :][:, lon_in]
            llats = llats[lat_in]
            llons = llons[lon_in]
            # Transform the altitude grid into the projected coordinates (longitudes back to -180..180).
            etopoZ = m.transform_scalar(etopoz, llons - 360. * (llons > 180), llats, etopoz.shape[0], etopoz.shape[1])
            ls = LightSource(azdeg=315, altdeg=45)
            m.imshow(ls.hillshade(etopoZ, vert_exag=0.05), cmap='gray')
        except IOError:
            print("Couldn't read etopo data or color map file! Check file directory!")
        return

    # ------------------------------------------------------------- tomography
    def write_2_sta_in(self, outdir="/work3/wang/JdF/Input_4_Ray", channel='ZZ', pers=None, outpfx="2_sta_in_", cut=0.8):
        """Write the FTAN measurements as input files for 2-station tomography.

        Parameters: outdir  -- directory for the generated input files (created if missing)
                    channel -- channel tag used in the file names
                    pers    -- periods to write. None or empty means 6-40 s every 2 s.
                    outpfx  -- prefix for the name of generated files
                    cut     -- drop lines whose velocity differs by ``cut`` or more
                               from the mean or the median
        For every period this writes ``<outpfx><period>_<channel>_ph.lst`` and
        ``..._gr.lst`` (phase/group), plus unfiltered ``.lst_tmp`` versions.
        Each line holds: pair id, lat1, lon1, lat2, lon2, velocity, 1. (weight),
        staid1, staid2, 1, 1.
        Only velocities within [0, 5] km/s that are not NaN are written.
        """
        from contextlib import ExitStack
        if not os.path.isdir(outdir):
            os.makedirs(outdir)
        pers = np.array([]) if pers is None else np.asarray(pers)
        if pers.size == 0:
            pers = np.arange(18.) * 2. + 6.
        stems = [outdir + "/" + outpfx + "%g" % (prd) + "_" + channel for prd in pers]
        staLst = self.waveforms.list()
        with ExitStack() as stack:
            ph_files = [stack.enter_context(open(stem + "_ph.lst_tmp", 'w')) for stem in stems]
            gr_files = [stack.enter_context(open(stem + "_gr.lst_tmp", 'w')) for stem in stems]
            i = -1  # pair id, incremented for every candidate pair (even skipped ones)
            for staid1 in staLst:
                netcode1, stacode1 = staid1.split('.')
                lat1 = self.waveforms[staid1].coordinates['latitude']
                lon1 = self.waveforms[staid1].coordinates['longitude']
                if lon1 < 0:
                    lon1 += 360.
                for staid2 in staLst:
                    if staid1 >= staid2:
                        continue
                    netcode2, stacode2 = staid2.split('.')
                    i += 1
                    try:
                        itp_arr = self.auxiliary_data['DISPinterp'][netcode1][stacode1][netcode2][stacode2].data[()]
                    except (AttributeError, KeyError):
                        continue
                    if itp_arr.shape[1] == 0:
                        continue
                    lat2 = self.waveforms[staid2].coordinates['latitude']
                    lon2 = self.waveforms[staid2].coordinates['longitude']
                    if lon2 < 0:
                        lon2 += 360.
                    for iper, per in enumerate(pers):
                        ind_per = np.where(itp_arr[0, :] == per)[0]
                        if ind_per.size != 1:
                            continue
                        pvel = itp_arr[2, ind_per][0]
                        gvel = itp_arr[1, ind_per][0]
                        # quality control
                        if pvel < 0 or gvel < 0 or pvel > 5 or gvel > 5:
                            continue
                        if np.isnan(pvel) or np.isnan(gvel):
                            continue
                        ph_files[iper].write("%d %g %g %g %g %g 1. %s %s 1 1 \n" % (i, lat1, lon1, lat2, lon2, pvel, staid1, staid2))
                        gr_files[iper].write("%d %g %g %g %g %g 1. %s %s 1 1 \n" % (i, lat1, lon1, lat2, lon2, gvel, staid1, staid2))
        for stem in stems:
            for kind in ("ph", "gr"):
                self._write_without_outliers(stem + "_" + kind + ".lst_tmp", stem + "_" + kind + ".lst", cut)
        print('End of Generating Misha Tomography Input File!')
        return
Ejemplo n.º 16
0
class CreatePolygon:
    """Interactively draw a polygon on a matplotlib figure.

    Left-clicking on empty space adds a vertex at the click position
    (truncated to integers with ``int()``), drawn as a circle of radius 0.5.
    Left-click-and-drag on an existing circle moves that vertex. Once two or
    more vertices exist, an open polygon outline joins them in click order.
    Right-clicking closes the polygon and disconnects all event handlers.
    The final vertices are then left in ``self.points`` as ``[[x, y], ...]``.
    """

    def __init__(self):
        self.circle_list = []     # one Circle patch per vertex, in click order
        self.x0 = None            # centre of the dragged circle at press time
        self.y0 = None
        self.press_event = None   # press event that started the current drag
        self.current_circle = None
        self.points = None
        self.poly = None          # created once a second vertex is added
        self.fig = plt.figure()
        self.ax = plt.subplot(111)
        self.ax.set_xlim(-20, 20)
        self.ax.set_ylim(-20, 20)
        # Title (Spanish): "Left button to place points and move them,
        # right to close and finish".
        self.ax.set_title(
            'Bot' + u'ó' +
            'n izq. para poner puntos y moverlos, derecho para cerrar y terminar'
        )
        self.cidpress = self.fig.canvas.mpl_connect('button_press_event',
                                                    self.on_press)
        self.cidrelease = self.fig.canvas.mpl_connect('button_release_event',
                                                      self.on_release)
        self.cidmove = self.fig.canvas.mpl_connect('motion_notify_event',
                                                   self.on_move)

    def _collect_points(self):
        """Return the current vertices as ``[[x, y], ...]`` in click order."""
        return [list(circle.center) for circle in self.circle_list]

    def on_press(self, event):
        """Handle a mouse press.

        Right button: close the polygon, disconnect all handlers and return
        the vertex list. This is ignored while there is no polygon yet
        (fewer than two vertices).
        Other buttons: start dragging the circle under the cursor, or add a
        new vertex there. Returns ``None``.
        """
        if event.button == 3:  # right button: close and disconnect
            if self.poly is None:
                # Nothing to close yet; keep listening instead of crashing.
                return None
            self.poly.set_closed(True)
            self.fig.canvas.draw()
            for cid in (self.cidpress, self.cidrelease, self.cidmove):
                self.fig.canvas.mpl_disconnect(cid)
            self.points = self._collect_points()
            return self.points

        # Presses outside the axes have xdata/ydata of None; int(None) raises.
        if event.xdata is None or event.ydata is None:
            return None
        x0, y0 = int(event.xdata), int(event.ydata)

        # Look for an existing circle under the cursor; if found, start a drag.
        for circle in self.circle_list:
            contains, _ = circle.contains(event)
            if contains:
                self.press_event = event
                self.current_circle = circle
                self.x0, self.y0 = circle.center
                return None

        # No circle was hit: add a new vertex.
        new_circle = Circle((x0, y0), 0.5)
        self.ax.add_patch(new_circle)
        self.circle_list.append(new_circle)
        self.current_circle = None
        if len(self.circle_list) >= 2:
            self.points = self._collect_points()
            if self.poly is None:
                self.poly = Polygon(np.array(self.points),
                                    fill=False,
                                    closed=False)
                self.ax.add_patch(self.poly)
            else:
                self.poly.set_xy(np.array(self.points))
        self.fig.canvas.draw()
        return None

    def on_release(self, event):
        """End any drag in progress."""
        self.press_event = None
        self.current_circle = None

    def on_move(self, event):
        """Move the dragged vertex with the mouse, keeping the outline in sync."""
        if (self.press_event is None or self.current_circle is None
                or event.inaxes != self.press_event.inaxes):
            return
        dx = event.xdata - self.press_event.xdata
        dy = event.ydata - self.press_event.ydata
        self.current_circle.center = int(self.x0 + dx), int(self.y0 + dy)
        self.points = self._collect_points()
        if self.poly is not None:  # with a single vertex there is no outline
            self.poly.set_xy(np.array(self.points))
        self.fig.canvas.draw()
Ejemplo n.º 17
0
 def plot_cn_map(self, *data, title=None):
     """Draw a colour-coded map of Chinese cities, save it, then show it.

     Only ``data[0]`` is used; extra positional arguments are ignored. Its
     keys are compared against city names read from the ``City/CN_city``
     shapefile:

     * a key equal to the city name colours it with ``self.coloring(value)``;
     * a dict value is searched for sub-keys that occur as a substring of the
       city name, or of the city returned by ``china_region.search``.

     Names in ``tw`` get a fixed colour; cities with no match are light grey.
     The figure is written to ``<title>.PNG`` before being displayed.

     Args:
         *data: ``data[0]`` is the name -> value mapping described above.
         title: Plot heading, also used as the output file stem. Required.

     Raises:
         TypeError: if ``title`` is not given.
         IndexError: if no positional data mapping is given.
     """
     if title is None:
         # The title is needed for both the heading and the filename; fail
         # before building the map (and leaving a figure open).
         raise TypeError("plot_cn_map() requires a 'title' keyword argument")
     mapping = data[0]

     lat_min = 0
     lat_max = 60
     lon_min = 80
     lon_max = 150
     font = FontProperties(size=20)
     unmatched = '#ffffff'      # sentinel: nothing in ``mapping`` matched yet
     no_data_color = '#f0f0f0'  # final colour for unmatched cities
     tw_color = '#ff7b69'
     # Names that always get ``tw_color`` regardless of ``mapping``.
     tw = {
         '基隆市',
         '台北市',
         '桃园县',
         '宜兰县',
         '新竹县',
         '苗栗县',
         '台中县',
         '莲花县',
         '金门县',
         '南投县',
         '台中市',
         '彰化县',
         '云林县',
         '嘉义县',
         '台东县',
         '凤山县',
         '诏安县',
         '台南县',
         '南澳县',
         '台南市',
         '屏东县',
         '高雄市',
         '台北县',
     }
     # Shapefile names rewritten before matching (they appear to be outdated
     # or variant spellings — confirm against the shapefile).
     special = {
         '大庸市': '张家界市',
         '株州市': '株洲市',
         '浑江市': '白城市',
         '巢湖市': '合肥市',
         '莱芜市': '济南市',
         '崇明県': '上海市',
         '丽江纳西族自治县': '丽江市',
         '达川市': '达州市',
         '巴州': '巴音郭楞蒙古自治州',
         '叶鲁番市': '吐鲁番地区',
         '阿勒泰市': '伊犁州',
         '烏海市': '乌海市',
         '沙湾县': '塔城市'
     }

     def region_city(name):
         # Look up the city ``name`` belongs to via china_region ('' if not
         # found). Names containing '市' are searched as cities, names with
         # '县' or '区' as counties, anything else as a city.
         if '市' in name:
             found = china_region.search(city=name)
         elif '县' in name or '区' in name:
             found = china_region.search(county=name)
         else:
             found = china_region.search(city=name)
         resolved = found['city'] if len(found) > 0 else ''
         # Normalise to the short form used as a key elsewhere in this method.
         return '巴州' if resolved == '巴音郭楞蒙古自治州' else resolved

     fig, ax = plt.subplots()
     fig.set_size_inches(20, 16)
     cn_map = Basemap(projection='lcc',
                      width=5000000,
                      height=5000000,
                      lat_0=36,
                      lon_0=102,
                      llcrnrlon=lon_min,
                      llcrnrlat=lat_min,
                      urcrnrlon=lon_max,
                      urcrnrlat=lat_max,
                      resolution='i',
                      ax=ax)
     cn_map.readshapefile('City/CN_city',
                          'cities',
                          drawbounds=True,
                          antialiased=3)
     for info, shape in zip(cn_map.cities_info, cn_map.cities):
         city = info['NAME'].strip('\x00')  # drop NUL padding from the raw name
         color = unmatched
         if city in tw:
             color = tw_color
         else:
             city = special.get(city, city)
             sc = None  # region lookup result, computed at most once per city
             for p_key, value in mapping.items():
                 if p_key == city:
                     color = self.coloring(value)
                     break
                 if isinstance(value, dict):
                     if sc is None:
                         sc = region_city(city)
                     for c, sub_value in value.items():
                         # Empty keys would match every name; skip them.
                         if c and (c in city or c in sc):
                             color = self.coloring(sub_value)
                             break
                 if color != unmatched:
                     break
         if color == unmatched:
             color = no_data_color
         ax.add_patch(
             Polygon(shape, facecolor=color, edgecolor=color, label=city))

     cn_map.drawcoastlines(
         color='black',
         linewidth=0.4,
     )
     cn_map.drawparallels(np.arange(lat_min, lat_max, 10),
                          labels=[1, 0, 0, 1])
     cn_map.drawmeridians(np.arange(lon_min, lon_max, 10),
                          labels=[0, 0, 0, 1])
     ax.legend(self.handles,
               self.legend_labels,
               bbox_to_anchor=(0.5, -0.11),
               loc='lower center',
               ncol=5,
               prop={'size': 16})
     plt.title(title + '\n(Update: ' + self.update_time + ")",
               fontproperties=font)
     # Watermark. NOTE(review): with no ``transform`` these coordinates are in
     # the axes' data units (map projection units), not axes fractions —
     # confirm the intended placement.
     plt.text(1.2,
              0.5,
              'By Ruiyan Ma',
              fontsize=20,
              color='gray',
              va='bottom',
              alpha=0.5)
     # Save before show(): once show() returns, the figure window may already
     # have been closed, so writing the file first is the reliable order.
     fig.savefig(title + '.PNG')
     plt.show()
     plt.close(fig)