예제 #1
0
def visualize_hexarray(hexarray, title, colormap=None, visualize_axes=True, show_hexarray=False):
	"""Draw ``hexarray`` as a grid of hexagons and save it as PNG and PDF.

	Every element of the first two axes of ``hexarray`` becomes one regular
	hexagon of radius 1; odd rows are shifted right by half a hexagon width.
	The figure is always saved without axes as ``{title}_hex.png`` / ``.pdf``.
	When ``visualize_axes`` is true, a second version with black hexagon
	edges, visible axes and a small border is saved as
	``{title}_hex_with_axes.png`` / ``.pdf``.

	Args:
		hexarray: Array indexed as ``[row, column, channel]`` when no colormap
			is given, or ``[row, column]`` when ``colormap`` maps each value to
			a colour.
		title: Prefix of the output file names.
		colormap: Optional Matplotlib colormap (name or object) applied to
			``hexarray`` before drawing.
		visualize_axes: Also produce the variant with axes and cell borders.
		show_hexarray: Call ``plt.show()`` on the last variant produced.
	"""
	axes_border_width = 0.05
	sqrt3 = math.sqrt(3)

	hexarray_width  = hexarray.shape[1]
	hexarray_height = hexarray.shape[0]

	# Hexagon centres in row-major order; odd rows get the half-width offset.
	hexarray_x = [w * sqrt3 if not h % 2 else sqrt3 / 2 + w * sqrt3 for h in range(hexarray_height) for w in range(hexarray_width)]
	hexarray_y = [1.5 * h for h in range(hexarray_height) for w in range(hexarray_width)]

	if colormap is not None:
		# plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap is the
		# long-standing equivalent and accepts the same names/objects.
		hexarray = plt.get_cmap(colormap)(hexarray)

	# One colour row per hexagon. The shape is passed positionally because the
	# ``newshape=`` keyword is deprecated in recent NumPy releases.
	hexarray_colors = np.reshape(hexarray, (hexarray.shape[0] * hexarray.shape[1], hexarray.shape[2]))

	ax = plt.subplot(aspect='equal')

	ax.axis(
		(min(hexarray_x) - sqrt3 / 2,
		 max(hexarray_x) + sqrt3 / 2,
		 min(hexarray_y) - 1,
		 max(hexarray_y) + 1)
	)

	ax.invert_yaxis()

	patches_list = [RegularPolygon((x, y), numVertices=6, radius=1, color=c) for x, y, c in zip(hexarray_x, hexarray_y, hexarray_colors)]

	patch_collection = PatchCollection(patches_list, match_original=True)
	patch_collection = ax.add_collection(patch_collection)
	ax.axis('off')

	hexarray_fig = f'{title}_hex'
	for extension in ('png', 'pdf'):
		plt.savefig(f'{hexarray_fig}.{extension}', bbox_inches='tight')

	if show_hexarray and not visualize_axes:
		plt.show()

	if visualize_axes:
		# Replace the borderless collection with an outlined one.
		patch_collection.remove()

		axes_border_width_x = axes_border_width * hexarray_width
		axes_border_width_y = axes_border_width * hexarray_height

		ax.axis(
			(min(hexarray_x) - (sqrt3 / 2 + axes_border_width_x),
			 max(hexarray_x) + (sqrt3 / 2 + axes_border_width_x),
			 min(hexarray_y) - (1 + axes_border_width_y),
			 max(hexarray_y) + (1 + axes_border_width_y))
		)

		# Setting explicit limits above restores the default y direction, so
		# the axis has to be flipped again.
		ax.invert_yaxis()

		patch_collection = PatchCollection(patches_list, match_original=True, edgecolor='black')
		ax.add_collection(patch_collection)
		ax.axis('on')

		hexarray_fig = f'{hexarray_fig}_with_axes'
		for extension in ('png', 'pdf'):
			plt.savefig(f'{hexarray_fig}.{extension}', bbox_inches='tight')

		if show_hexarray:
			plt.show()

	plt.close()
class COCO_dataset_generator(object):
    def __init__(self, fig, ax, args):
        """Build the annotation GUI on ``fig``/``ax`` and show the first image.

        Args:
            fig: Matplotlib figure whose canvas receives the mouse, scroll
                and keyboard callbacks.
            ax: Axes on which images and polygons are drawn.
            args: Mapping read for the keys ``'image_dir'``, ``'feedback'``
                and ``'class_file'``; when ``args['feedback']`` is truthy
                ``'maskrcnn_dir'`` and ``'weights_path'`` are read as well.
        """

        self.ax = ax
        self.ax.set_yticklabels([])
        self.ax.set_xticklabels([])

        self.img_dir = args['image_dir']
        self.index = 0
        self.fig = fig
        self.polys = []
        self.zoom_scale, self.points, self.prev, self.submit_p, self.lines, self.circles = 1.2, [], None, None, [], []

        # Keep the connection ids so the handlers can be disconnected later.
        self.zoom_id = fig.canvas.mpl_connect('scroll_event', self.zoom)
        self.click_id = fig.canvas.mpl_connect('button_press_event',
                                               self.onclick)
        self.clickrel_id = fig.canvas.mpl_connect('button_release_event',
                                                  self.onclick_release)
        self.keyboard_id = fig.canvas.mpl_connect('key_press_event',
                                                  self.onkeyboard)

        # Widget axes, given as [left, bottom, width, height] rectangles.
        self.axradio = plt.axes([0.0, 0.0, 0.2, 1])
        self.axbringprev = plt.axes([0.3, 0.05, 0.17, 0.05])
        self.axreset = plt.axes([0.48, 0.05, 0.1, 0.05])
        self.axsubmit = plt.axes([0.59, 0.05, 0.1, 0.05])
        self.axprev = plt.axes([0.7, 0.05, 0.1, 0.05])
        self.axnext = plt.axes([0.81, 0.05, 0.1, 0.05])
        self.b_bringprev = Button(self.axbringprev,
                                  'Bring Previous Annotations')
        self.b_bringprev.on_clicked(self.bring_prev)
        self.b_reset = Button(self.axreset, 'Reset')
        self.b_reset.on_clicked(self.reset)
        self.b_submit = Button(self.axsubmit, 'Submit')
        self.b_submit.on_clicked(self.submit)
        self.b_next = Button(self.axnext, 'Next')
        self.b_next.on_clicked(self.next)
        self.b_prev = Button(self.axprev, 'Prev')
        self.b_prev.on_clicked(self.previous)

        # Mouse handlers consult this list to ignore events on the widgets.
        self.button_axes = [
            self.axbringprev, self.axreset, self.axsubmit, self.axprev,
            self.axnext, self.axradio
        ]

        self.existing_polys = []
        self.existing_patches = []
        self.selected_poly = False
        self.objects = []
        self.feedback = args['feedback']

        self.right_click = False

        # Accumulates the annotation text written out for the current image.
        self.text = ''

        with open(args['class_file'], 'r') as f:
            self.class_names = [x.strip() for x in f.readlines()]

        # The radio buttons list only the classes from the file; 'BG' is
        # prepended afterwards, so class_names indices are offset by one
        # relative to the radio button indices.
        self.radio = RadioButtons(self.axradio, self.class_names)
        self.class_names = ('BG', ) + tuple(self.class_names)

        # All *.jpg files (sorted) followed by all *.png files (sorted).
        self.img_paths = sorted(glob.glob(os.path.join(self.img_dir, '*.jpg')))
        self.img_paths = self.img_paths + sorted(
            glob.glob(os.path.join(self.img_dir, '*.png')))

        # Skip images that already have a '.txt' file next to them.
        # NOTE(review): there is no upper bound on self.index here, so this
        # raises IndexError when the directory is empty or every image
        # already has a '.txt' file.
        while (1):
            if os.path.exists(self.img_paths[self.index][:-3] + 'txt'):
                self.index = self.index + 1
                continue
            else:
                break

        # previous() never steps back before this index.
        self.checkpoint = self.index
        im = Image.open(self.img_paths[self.index])
        width, height = im.size
        im.close()

        image = plt.imread(self.img_paths[self.index])

        if args['feedback']:

            # The detection modules are imported from the directory named
            # by args['maskrcnn_dir'].
            sys.path.append(args['maskrcnn_dir'])
            from config import Config
            import model as modellib
            from demo import BagsConfig
            from skimage.measure import find_contours
            from visualize_cv2 import random_colors

            config = BagsConfig()

            # Create model object in inference mode.
            # NOTE(review): model_dir is derived by splitting the weights
            # path on '/', which presumably assumes POSIX-style paths.
            model = modellib.MaskRCNN(
                mode="inference",
                model_dir='/'.join(args['weights_path'].split('/')[:-2]),
                config=config)

            # Load weights trained on MS-COCO
            model.load_weights(args['weights_path'], by_name=True)

            r = model.detect([image], verbose=0)[0]

            # Number of instances
            N = r['rois'].shape[0]

            masks = r['masks']

            # Generate random colors
            colors = random_colors(N)

            # Show area outside image boundaries.
            height, width = image.shape[:2]

            class_ids, scores = r['class_ids'], r['scores']

            for i in range(N):
                # color and score are assigned but not used below.
                color = colors[i]

                # Label
                class_id = class_ids[i]
                score = scores[i] if scores is not None else None
                label = self.class_names[class_id]

                # Mask
                mask = masks[:, :, i]

                # Mask Polygon
                # Pad to ensure proper polygons for masks that touch image edges.
                padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2),
                                       dtype=np.uint8)
                padded_mask[1:-1, 1:-1] = mask
                contours = find_contours(padded_mask, 0.5)
                for verts in contours:
                    # Subtract the padding and flip (y, x) to (x, y)

                    verts = np.fliplr(verts) - 1
                    # A green overlay is drawn now; a separate red Polygon
                    # with the same vertices is kept for later editing.
                    pat = PatchCollection([Polygon(verts, closed=True)],
                                          facecolor='green',
                                          linewidths=0,
                                          alpha=0.6)
                    self.ax.add_collection(pat)
                    self.objects.append(label)
                    self.existing_patches.append(pat)
                    self.existing_polys.append(
                        Polygon(verts,
                                closed=True,
                                alpha=0.25,
                                facecolor='red'))

        self.ax.imshow(image, aspect='auto')
        print("file name : {}".format(self.img_paths[self.index]))
        # Annotation header: image index, absolute path, then "width height".
        self.text += str(self.index) + '\n' + os.path.abspath(self.img_paths[
            self.index]) + '\n' + str(width) + ' ' + str(height) + '\n\n'

    def bring_prev(self, event):
        """Overlay the polygons saved for the previous image onto this one.

        Does nothing when ``self.feedback`` is set. Otherwise the ``.txt``
        file belonging to the image at ``self.index - 1`` is parsed with
        ``return_info``; ``self.objects`` is replaced by the labels it
        returns, and each polygon is added both as an editable red
        ``Polygon`` and as a visible green overlay.
        """
        if self.feedback:
            return

        previous_txt = self.img_paths[self.index - 1][:-3] + 'txt'
        poly_verts, self.objects = return_info(previous_txt)

        for vertices in poly_verts:
            editable = Polygon(vertices,
                               closed=True,
                               alpha=0.5,
                               facecolor='red')
            self.existing_polys.append(editable)

            overlay = PatchCollection([Polygon(vertices, closed=True)],
                                      facecolor='green',
                                      linewidths=0,
                                      alpha=0.6)
            self.ax.add_collection(overlay)
            self.existing_patches.append(overlay)

    def points_to_polygon(self):
        """Return the flat ``[x0, y0, x1, y1, ...]`` point list as an (N, 2) array."""
        flat = np.array(self.points)
        vertex_count = int(len(self.points) / 2)
        return np.reshape(flat, (vertex_count, 2))

    def deactivate_all(self):
        """Disconnect the scroll, press, release and key handlers from the canvas."""
        for id_attribute in ('zoom_id', 'click_id', 'clickrel_id',
                             'keyboard_id'):
            self.fig.canvas.mpl_disconnect(getattr(self, id_attribute))

    def onkeyboard(self, event):
        """Handle the 'a' (edit / accept) and 'r' (remove) keys.

        Key presses outside any axes are ignored.

        * ``'a'`` while a polygon is being edited: take the edited vertices
          as the current points, stop the interactor, mark the shape as
          ready for submission and re-enable the click handler.
        * ``'a'`` otherwise: the first existing polygon under the cursor is
          selected for editing; its overlay is hidden, its class is chosen
          in the radio buttons and the click handler is disabled while the
          ``PolygonInteractor`` is active.
        * ``'r'``: the first existing polygon under the cursor is removed.

        The canvas is redrawn at the end of every handled key press.
        """
        if not event.inaxes:
            return

        if event.key == 'a':

            if self.selected_poly:
                self.points = self.interactor.get_polygon().xy.flatten()
                self.interactor.deactivate()
                self.right_click = True
                self.selected_poly = False
                # mpl_connect takes the event *name*; the previous code
                # passed the old connection id, so the click handler that
                # was disconnected when editing started never came back.
                self.click_id = self.fig.canvas.mpl_connect(
                    'button_press_event', self.onclick)
                # NOTE(review): this only sets a plain attribute on the
                # patch; it does not go through any Matplotlib colour setter.
                self.polygon.color = (0, 255, 0)
                self.fig.canvas.draw()
            else:
                cursor = (event.xdata, event.ydata)
                for i, poly in enumerate(self.existing_polys):
                    if not poly.get_path().contains_point(cursor):
                        continue

                    # class_names has 'BG' prepended, hence the -1.
                    self.radio.set_active(
                        self.class_names.index(self.objects[i]) - 1)
                    self.polygon = poly
                    self.existing_patches[i].set_visible(False)
                    self.fig.canvas.mpl_disconnect(self.click_id)
                    self.ax.add_patch(self.polygon)
                    self.fig.canvas.draw()
                    self.interactor = PolygonInteractor(
                        self.ax, self.polygon)
                    self.selected_poly = True
                    # NOTE(review): only existing_polys shrinks here, so
                    # existing_patches and objects no longer line up with
                    # it by index afterwards.
                    self.existing_polys.pop(i)
                    break

        elif event.key == 'r':
            cursor = (event.xdata, event.ydata)
            for i, poly in enumerate(self.existing_polys):
                if poly.get_path().contains_point(cursor):
                    self.existing_patches[i].set_visible(False)
                    self.existing_patches[i].remove()
                    self.existing_patches.pop(i)
                    self.existing_polys.pop(i)
                    break
        self.fig.canvas.draw()

    def next(self, event):
        """Save the current annotations and move to the next unannotated image.

        The accumulated text is written to the ``.txt`` file next to the
        current image only when it holds more than the header, i.e. at
        least one polygon was submitted. The axes are cleared, images that
        already have a ``.txt`` file are skipped, and the next image is
        displayed with a fresh annotation header.

        Raises:
            SystemExit: When the current image is the last one, or when
                every remaining image is already annotated.
        """
        annotation_path = self.img_paths[self.index][:-3] + 'txt'

        # The header alone splits into exactly five pieces.
        if len(self.text.split('\n')) > 5:
            print(annotation_path)
            with open(annotation_path, "w") as text_file:
                text_file.write(self.text)

        self.ax.clear()

        self.ax.set_yticklabels([])
        self.ax.set_xticklabels([])

        image_count = len(self.img_paths)
        if self.index >= image_count - 1:
            raise SystemExit
        self.index += 1

        # Skip annotated images without running off the end of the list.
        # The previous loop had no bound (IndexError once the tail of the
        # list was fully annotated) and exited whenever it reached the last
        # image, so a final unannotated image could never be shown.
        while (self.index < image_count and os.path.exists(
                self.img_paths[self.index][:-3] + 'txt')):
            self.index += 1

        if self.index >= image_count:
            print("file is end.")
            raise SystemExit

        image = plt.imread(self.img_paths[self.index])
        self.ax.imshow(image, aspect='auto')
        print("file name : {}".format(self.img_paths[self.index]))
        im = Image.open(self.img_paths[self.index])
        width, height = im.size
        im.close()

        self.reset_all()

        # Annotation header: image index, absolute path, then "width height".
        self.text += str(self.index) + '\n' + os.path.abspath(self.img_paths[
            self.index]) + '\n' + str(width) + ' ' + str(height) + '\n\n'

    def reset_all(self):
        """Clear the per-image drawing state and the pending annotation text."""
        self.text = ''
        self.polys = []
        self.points = []
        self.lines = []
        self.circles = []
        self.prev = None
        self.submit_p = None

    def previous(self, event):
        """Step back one image (never before the checkpoint) and re-annotate it.

        The annotation file of the image that becomes current is deleted if
        it exists, the axes are cleared, the image is redrawn and the
        annotation text is reset to a fresh header.
        """
        if self.index > self.checkpoint:
            self.index -= 1

        # At the checkpoint the current image has no '.txt' file yet, so a
        # missing file is expected here and must not abort the callback.
        try:
            os.remove(self.img_paths[self.index][:-3] + 'txt')
        except FileNotFoundError:
            pass

        self.ax.clear()

        self.ax.set_yticklabels([])
        self.ax.set_xticklabels([])

        image = plt.imread(self.img_paths[self.index])
        self.ax.imshow(image, aspect='auto')

        im = Image.open(self.img_paths[self.index])
        width, height = im.size
        im.close()

        self.reset_all()

        # Annotation header: image index, absolute path, then "width height".
        self.text += str(self.index) + '\n' + os.path.abspath(self.img_paths[
            self.index]) + '\n' + str(width) + ' ' + str(height) + '\n\n'

    def onclick(self, event):

        if not event.inaxes:
            return
        if not any([x.in_axes(event) for x in self.button_axes]):
            if event.button == 1:
                self.points.extend([event.xdata, event.ydata])
                #print (event.xdata, event.ydata)

                circle = plt.Circle((event.xdata, event.ydata),
                                    2.5,
                                    color='black')
                self.ax.add_artist(circle)
                self.circles.append(circle)

                if (len(self.points) < 4):
                    self.r_x = event.xdata
                    self.r_y = event.ydata
            else:
                if len(self.points) > 5:
                    self.right_click = True
                    self.fig.canvas.mpl_disconnect(self.click_id)
                    self.click_id = None
                    self.points.extend([self.points[0], self.points[1]])
                    #self.prev.remove()

            if (len(self.points) > 2):
                line = self.ax.plot([self.points[-4], self.points[-2]],
                                    [self.points[-3], self.points[-1]], 'b--')
                self.lines.append(line)

            self.fig.canvas.draw()

            if len(self.points) > 4:
                if self.prev:
                    self.prev.remove()
                self.p = PatchCollection(
                    [Polygon(self.points_to_polygon(), closed=True)],
                    facecolor='red',
                    linewidths=0,
                    alpha=0.4)
                self.ax.add_collection(self.p)
                self.prev = self.p

                self.fig.canvas.draw()

            #if len(self.points)>4:
            #    print 'AREA OF POLYGON: ', self.find_poly_area(self.points)
            #print event.x, event.y

    def find_poly_area(self):
        """Return the area enclosed by the current points (shoelace formula).

        The vertices come from ``points_to_polygon``. A repeated closing
        vertex contributes nothing to the sum, so open and closed vertex
        lists give the same result.

        The previous implementation divided the shoelace result by two a
        second time and therefore reported half the true area; values
        written by that version are a factor of two too small.
        """
        coords = self.points_to_polygon()
        x, y = coords[:, 0], coords[:, 1]
        return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))

    def onclick_release(self, event):
        """Mouse-release handler that turns a click-and-drag into a rectangle.

        Nothing happens when the release is on a widget axes, while a
        polygon is being edited, before any anchor point was recorded by
        ``onclick``, or when the release has no data coordinates (outside
        the axes). Otherwise, if the release is more than 10 data units
        from the anchor on both axes and fewer than two vertices exist, the
        points are replaced by the closed axis-aligned rectangle spanned by
        anchor and release position, ``right_click`` is set and the click
        handler is disconnected.
        """
        if any([x.in_axes(event)
                for x in self.button_axes]) or self.selected_poly:
            return

        # r_x / r_y exist only after the first click. Test against None so
        # that an anchor at coordinate 0.0 is still accepted.
        anchor_x = getattr(self, 'r_x', None)
        anchor_y = getattr(self, 'r_y', None)
        if anchor_x is None or anchor_y is None:
            return

        # Releases outside the axes carry no data coordinates.
        if event.xdata is None or event.ydata is None:
            return

        # 10 pixels limit for rectangle creation
        if np.abs(event.xdata - anchor_x) <= 10 or np.abs(
                event.ydata - anchor_y) <= 10:
            return
        if len(self.points) >= 4:
            return

        self.right_click = True
        self.fig.canvas.mpl_disconnect(self.click_id)
        self.click_id = None
        bbox = [
            np.min([event.xdata, anchor_x]),
            np.min([event.ydata, anchor_y]),
            np.max([event.xdata, anchor_x]),
            np.max([event.ydata, anchor_y])
        ]
        self.r_x = self.r_y = None

        # Four corners plus the first corner repeated to close the outline.
        self.points = [
            bbox[0], bbox[1], bbox[0], bbox[3], bbox[2], bbox[3],
            bbox[2], bbox[1], bbox[0], bbox[1]
        ]
        self.p = PatchCollection(
            [Polygon(self.points_to_polygon(), closed=True)],
            facecolor='red',
            linewidths=0,
            alpha=0.4)
        self.ax.add_collection(self.p)
        self.fig.canvas.draw()

    def zoom(self, event):
        """Scroll handler: rescale the view around the cursor position.

        Scrolling 'down' multiplies the visible range by ``1 / zoom_scale``
        and 'up' by ``zoom_scale``. Any other button value is printed and
        leaves the range unchanged. The point under the cursor keeps its
        relative position inside the axes.
        """
        if not event.inaxes:
            return

        x_low, x_high = self.ax.get_xlim()
        y_low, y_high = self.ax.get_ylim()
        cursor_x, cursor_y = event.xdata, event.ydata

        if event.button == 'up':
            scale_factor = self.zoom_scale
        elif event.button == 'down':
            scale_factor = 1 / self.zoom_scale
        else:
            # Not expected to happen; report it and keep the current range.
            scale_factor = 1
            print(event.button)

        x_span = x_high - x_low
        y_span = y_high - y_low
        new_width = x_span * scale_factor
        new_height = y_span * scale_factor

        # Fraction of the current range lying beyond the cursor.
        relx = (x_high - cursor_x) / x_span
        rely = (y_high - cursor_y) / y_span

        self.ax.set_xlim(
            [cursor_x - new_width * (1 - relx), cursor_x + new_width * relx])
        self.ax.set_ylim(
            [cursor_y - new_height * (1 - rely), cursor_y + new_height * rely])
        self.ax.figure.canvas.draw()

    def reset(self, event):

        if not self.click_id:
            self.click_id = fig.canvas.mpl_connect('button_press_event',
                                                   self.onclick)
        #print (len(self.lines))
        #print (len(self.circles))
        if len(self.points) > 5:
            for line in self.lines:
                line.pop(0).remove()
            for circle in self.circles:
                circle.remove()
            self.lines, self.circles = [], []
            self.p.remove()
            self.prev = self.p = None
            self.points = []
        #print (len(self.lines))
        #print (len(self.circles))

    def print_points(self):
        """Return the current points formatted with two decimals, each followed by a space."""
        return ''.join('%.2f' % value + ' ' for value in self.points)

    def submit(self, event):
        """Commit the finished polygon to the annotation text of this image.

        Requires ``right_click`` to be set, i.e. the shape was closed;
        otherwise a reminder is printed and nothing changes. On success the
        selected class, polygon area and vertex list are appended to
        ``self.text``, the click handler is connected again, and the
        collection showing all submitted polygons is rebuilt.
        """
        if not self.right_click:
            print('Right click before submit is a must!!')
            return

        self.text += self.radio.value_selected + '\n' + '%.2f' % self.find_poly_area(
        ) + '\n' + self.print_points() + '\n\n'
        self.right_click = False

        self.lines, self.circles = [], []
        # Use the figure stored on the instance; the previous code referred
        # to a module-level ``fig`` that this class never defines.
        self.click_id = self.fig.canvas.mpl_connect('button_press_event',
                                                    self.onclick)

        self.polys.append(
            Polygon(self.points_to_polygon(),
                    closed=True,
                    color=np.random.rand(3),
                    alpha=0.4,
                    fill=True))
        if self.submit_p:
            self.submit_p.remove()
        self.submit_p = PatchCollection(self.polys,
                                        cmap=matplotlib.cm.jet,
                                        alpha=0.4)
        self.ax.add_collection(self.submit_p)
        self.points = []
예제 #3
0
# NOTE(review): this line reads `map.landen_info` before `map` is assigned
# below, so it presumably relies on a Basemap with the "landen" shapefile
# already loaded earlier in the session -- confirm against the caller.
# dict.fromkeys removes duplicate names while keeping first-seen order.
countries = list(dict.fromkeys(r['NAME_NL'] for r in map.landen_info))


# Draw one world map per country with that country highlighted, and write
# each map to "<country name>.png".
for i in countries:

    fig = plt.figure(figsize=(16, 8))
    ax = fig.add_subplot(111)

    map = Basemap()

    sh = map.readshapefile("ne_110m_admin_0_countries/ne_110m_admin_0_countries", "landen")
    map.fillcontinents(color='lightgrey', alpha=0.5, lake_color='white')

    # Collect every shape belonging to this country.
    patches = []
    for info, shape in zip(map.landen_info, map.landen):
        if info['NAME_NL'] == str(i):
            patches.append(Polygon(np.array(shape), closed=True))
            print(str(i), patches)

    # Build and add the collection once, after all shapes are gathered.
    # Doing this inside the loop above added one collection per shape in
    # the shapefile, of which only the last was ever removed.
    patchcol = PatchCollection(patches, facecolor='g', edgecolor='k', linewidths=1., zorder=2)
    ax.add_collection(patchcol)

    naam = f'{i}.png'

    fig.savefig(naam, dpi=100)
    patchcol.remove()
    fig.canvas.draw_idle()
    fig.clf()
    # Release the figure; clearing alone keeps it registered with pyplot.
    plt.close(fig)
예제 #4
0
class COCO_dataset_generator(object):
    def __init__(self, fig, ax, args):
        """Build the annotation GUI on ``fig``/``ax`` and show the first image.

        Args:
            fig: Matplotlib figure whose canvas receives the mouse, scroll
                and keyboard callbacks.
            ax: Axes on which images and polygons are drawn.
            args: Mapping read for the keys ``'image_dir'`` and
                ``'class_file'``.
        """

        self.ax = ax
        self.ax.set_yticklabels([])
        self.ax.set_xticklabels([])

        self.img_dir = args['image_dir']
        self.index = 0
        self.fig = fig
        self.polys = []
        self.zoom_scale, self.points, self.prev, self.submit_p, self.lines, self.circles = 1.2, [], None, None, [], []

        # Keep the connection ids so the handlers can be disconnected later.
        self.zoom_id = fig.canvas.mpl_connect('scroll_event', self.zoom)
        self.click_id = fig.canvas.mpl_connect('button_press_event',
                                               self.onclick)
        self.clickrel_id = fig.canvas.mpl_connect('button_release_event',
                                                  self.onclick_release)
        self.keyboard_id = fig.canvas.mpl_connect('key_press_event',
                                                  self.onkeyboard)

        # Widget axes, given as [left, bottom, width, height] rectangles.
        self.axradio = plt.axes([0.0, 0.0, 0.2, 1])
        self.axbringprev = plt.axes([0.3, 0.05, 0.17, 0.05])
        self.axreset = plt.axes([0.48, 0.05, 0.1, 0.05])
        self.axsubmit = plt.axes([0.59, 0.05, 0.1, 0.05])
        self.axprev = plt.axes([0.7, 0.05, 0.1, 0.05])
        self.axnext = plt.axes([0.81, 0.05, 0.1, 0.05])
        self.b_bringprev = Button(self.axbringprev,
                                  'Bring Previous Annotations')
        self.b_bringprev.on_clicked(self.bring_prev)
        self.b_reset = Button(self.axreset, 'Reset')
        self.b_reset.on_clicked(self.reset)
        self.b_submit = Button(self.axsubmit, 'Submit')
        self.b_submit.on_clicked(self.submit)
        self.b_next = Button(self.axnext, 'Next')
        self.b_next.on_clicked(self.next)
        self.b_prev = Button(self.axprev, 'Prev')
        self.b_prev.on_clicked(self.previous)

        # Mouse handlers consult this list to ignore events on the widgets.
        self.button_axes = [
            self.axbringprev, self.axreset, self.axsubmit, self.axprev,
            self.axnext, self.axradio
        ]

        self.existing_polys = []
        self.existing_patches = []
        self.selected_poly = False
        self.objects = []
        # NOTE(review): hard-coded to True, so bring_prev() never loads
        # anything; the original args lookup is left commented out.
        self.feedback = True  #args['feedback']

        self.right_click = False

        # Accumulates the annotation text written out for the current image.
        self.text = ''

        with open(args['class_file'], 'r') as f:
            self.class_names = [x.strip() for x in f.readlines()]

        # The radio buttons list only the classes from the file; 'BG' is
        # prepended afterwards, so class_names indices are offset by one
        # relative to the radio button indices.
        self.radio = RadioButtons(self.axradio, self.class_names)
        self.class_names = ('BG', ) + tuple(self.class_names)

        self.img_paths = sorted(glob.glob(os.path.join(self.img_dir, '*.jpg')))

        # *.png files are used only when the directory has no *.jpg files.
        # NOTE(review): if neither pattern matches, the indexing below
        # raises IndexError.
        if len(self.img_paths) == 0:
            self.img_paths = sorted(
                glob.glob(os.path.join(self.img_dir, '*.png')))
        # If the first image already has a '.txt' file, start at an index
        # derived from the number of '.txt' files in the directory.
        # NOTE(review): this presumably assumes images were annotated
        # strictly in order with no unrelated '.txt' files present -- confirm.
        if os.path.exists(self.img_paths[self.index][:-3] + 'txt'):
            self.index = len(glob.glob(os.path.join(self.img_dir,
                                                    '*.txt'))) - 1
        # previous() never steps back before this index.
        self.checkpoint = self.index

        im = Image.open(self.img_paths[self.index])
        width, height = im.size
        im.close()

        image = plt.imread(self.img_paths[self.index])
        self.ax.imshow(image, aspect='auto')
        # Annotation header: image index, absolute path, then "width height".
        self.text += str(self.index) + '\n' + os.path.abspath(self.img_paths[
            self.index]) + '\n' + str(width) + ' ' + str(height) + '\n\n'

    def bring_prev(self, event):
        """Overlay the polygons saved for the previous image onto this one.

        Does nothing when ``self.feedback`` is set. Otherwise the ``.txt``
        file belonging to the image at ``self.index - 1`` is parsed with
        ``return_info``; ``self.objects`` is replaced by the labels it
        returns, and each polygon is added both as an editable red
        ``Polygon`` and as a visible green overlay.
        """
        if self.feedback:
            return

        previous_txt = self.img_paths[self.index - 1][:-3] + 'txt'
        poly_verts, self.objects = return_info(previous_txt)

        for vertices in poly_verts:
            editable = Polygon(vertices,
                               closed=True,
                               alpha=0.5,
                               facecolor='red')
            self.existing_polys.append(editable)

            overlay = PatchCollection([Polygon(vertices, closed=True)],
                                      facecolor='green',
                                      linewidths=0,
                                      alpha=0.6)
            self.ax.add_collection(overlay)
            self.existing_patches.append(overlay)

    def points_to_polygon(self):
        """Return the flat ``[x0, y0, x1, y1, ...]`` point list as an (N, 2) array."""
        flat = np.array(self.points)
        vertex_count = int(len(self.points) / 2)
        return np.reshape(flat, (vertex_count, 2))

    def deactivate_all(self):
        """Disconnect the scroll, press, release and key handlers from the canvas."""
        for id_attribute in ('zoom_id', 'click_id', 'clickrel_id',
                             'keyboard_id'):
            self.fig.canvas.mpl_disconnect(getattr(self, id_attribute))

    def onkeyboard(self, event):
        """Handle the 'a' (edit / accept) and 'r' (remove) keys.

        Key presses outside any axes are ignored.

        * ``'a'`` while a polygon is being edited: take the edited vertices
          as the current points, stop the interactor, mark the shape as
          ready for submission and re-enable the click handler.
        * ``'a'`` otherwise: the first existing polygon under the cursor is
          selected for editing; its overlay is hidden, its class is chosen
          in the radio buttons and the click handler is disabled while the
          ``PolygonInteractor`` is active.
        * ``'r'``: the first existing polygon under the cursor is removed.

        The canvas is redrawn at the end of every handled key press.
        """
        if not event.inaxes:
            return

        if event.key == 'a':

            if self.selected_poly:
                self.points = self.interactor.get_polygon().xy.flatten()
                self.interactor.deactivate()
                self.right_click = True
                self.selected_poly = False
                # mpl_connect takes the event *name*; the previous code
                # passed the old connection id, so the click handler that
                # was disconnected when editing started never came back.
                self.click_id = self.fig.canvas.mpl_connect(
                    'button_press_event', self.onclick)
                # NOTE(review): this only sets a plain attribute on the
                # patch; it does not go through any Matplotlib colour setter.
                self.polygon.color = (0, 255, 0)
                self.fig.canvas.draw()
            else:
                cursor = (event.xdata, event.ydata)
                for i, poly in enumerate(self.existing_polys):
                    if not poly.get_path().contains_point(cursor):
                        continue

                    # class_names has 'BG' prepended, hence the -1.
                    self.radio.set_active(
                        self.class_names.index(self.objects[i]) - 1)
                    self.polygon = poly
                    self.existing_patches[i].set_visible(False)
                    self.fig.canvas.mpl_disconnect(self.click_id)
                    self.ax.add_patch(self.polygon)
                    self.fig.canvas.draw()
                    self.interactor = PolygonInteractor(
                        self.ax, self.polygon)
                    self.selected_poly = True
                    # NOTE(review): only existing_polys shrinks here, so
                    # existing_patches and objects no longer line up with
                    # it by index afterwards.
                    self.existing_polys.pop(i)
                    break

        elif event.key == 'r':
            cursor = (event.xdata, event.ydata)
            for i, poly in enumerate(self.existing_polys):
                if poly.get_path().contains_point(cursor):
                    self.existing_patches[i].set_visible(False)
                    self.existing_patches[i].remove()
                    self.existing_patches.pop(i)
                    self.existing_polys.pop(i)
                    break
        self.fig.canvas.draw()

    def next(self, event):
        """Save the current annotations and show the following image.

        The accumulated text is written to the ``.txt`` file next to the
        current image only when it holds more than the header. On the last
        image a message is printed and all canvas handlers are
        disconnected; the current image is then simply redrawn.
        """
        annotation_path = self.img_paths[self.index][:-3] + 'txt'

        # The header alone splits into exactly five pieces.
        has_annotations = len(self.text.split('\n')) > 5
        if has_annotations:
            print(annotation_path)
            with open(annotation_path, "w") as text_file:
                text_file.write(self.text)

        self.ax.clear()
        self.ax.set_yticklabels([])
        self.ax.set_xticklabels([])

        last_index = len(self.img_paths) - 1
        if self.index >= last_index:
            print('all image labeled!, please close current window')
            self.deactivate_all()
        else:
            self.index += 1

        current_path = self.img_paths[self.index]
        self.ax.imshow(plt.imread(current_path), aspect='auto')

        im = Image.open(current_path)
        width, height = im.size
        im.close()

        self.reset_all()

        # Annotation header: image index, absolute path, then "width height".
        self.text += str(self.index) + '\n' + os.path.abspath(
            current_path) + '\n' + str(width) + ' ' + str(height) + '\n\n'

    def reset_all(self):
        """Clear the per-image drawing state and the pending annotation text."""
        self.text = ''
        self.polys = []
        self.points = []
        self.lines = []
        self.circles = []
        self.prev = None
        self.submit_p = None

    def previous(self, event):
        """Step back one image (never before the checkpoint) and re-annotate it.

        The annotation file of the image that becomes current is deleted if
        it exists, the axes are cleared, the image is redrawn and the
        annotation text is reset to a fresh header.
        """
        if self.index > self.checkpoint:
            self.index -= 1

        # The image at the checkpoint may not have a '.txt' file yet, so a
        # missing file is expected here and must not abort the callback.
        try:
            os.remove(self.img_paths[self.index][:-3] + 'txt')
        except FileNotFoundError:
            pass

        self.ax.clear()

        self.ax.set_yticklabels([])
        self.ax.set_xticklabels([])

        image = plt.imread(self.img_paths[self.index])
        self.ax.imshow(image, aspect='auto')

        im = Image.open(self.img_paths[self.index])
        width, height = im.size
        im.close()

        self.reset_all()

        # Annotation header: image index, absolute path, then "width height".
        self.text += str(self.index) + '\n' + os.path.abspath(self.img_paths[
            self.index]) + '\n' + str(width) + ' ' + str(height) + '\n\n'

    def onclick(self, event):

        if not event.inaxes:
            return
        if not any([x.in_axes(event) for x in self.button_axes]):
            if event.button == 1:
                self.points.extend([event.xdata, event.ydata])
                #print (event.xdata, event.ydata)

                circle = plt.Circle((event.xdata, event.ydata),
                                    2.5,
                                    color='black')
                self.ax.add_artist(circle)
                self.circles.append(circle)

                if (len(self.points) < 4):
                    self.r_x = event.xdata
                    self.r_y = event.ydata
            else:
                if len(self.points) > 5:
                    self.right_click = True
                    self.fig.canvas.mpl_disconnect(self.click_id)
                    self.click_id = None
                    self.points.extend([self.points[0], self.points[1]])
                    #self.prev.remove()

            if (len(self.points) > 2):
                line = self.ax.plot([self.points[-4], self.points[-2]],
                                    [self.points[-3], self.points[-1]], 'b--')
                self.lines.append(line)

            self.fig.canvas.draw()

            if len(self.points) > 4:
                if self.prev:
                    self.prev.remove()
                self.p = PatchCollection(
                    [Polygon(self.points_to_polygon(), closed=True)],
                    facecolor='red',
                    linewidths=0,
                    alpha=0.4)
                self.ax.add_collection(self.p)
                self.prev = self.p

                self.fig.canvas.draw()

            #if len(self.points)>4:
            #    print 'AREA OF POLYGON: ', self.find_poly_area(self.points)
            #print event.x, event.y

    def find_poly_area(self):
        """Return the area enclosed by the current points (shoelace formula).

        The vertices come from ``points_to_polygon``. A repeated closing
        vertex contributes nothing to the sum, so open and closed vertex
        lists give the same result.

        The previous implementation divided the shoelace result by two a
        second time and therefore reported half the true area; values
        written by that version are a factor of two too small.
        """
        coords = self.points_to_polygon()
        x, y = coords[:, 0], coords[:, 1]
        return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))

    def onclick_release(self, event):
        """Mouse-release handler that turns a click-and-drag into a rectangle.

        Nothing happens when the release is on a widget axes, while a
        polygon is being edited, before any anchor point was recorded by
        ``onclick``, or when the release has no data coordinates (outside
        the axes). Otherwise, if the release is more than 10 data units
        from the anchor on both axes and fewer than two vertices exist, the
        points are replaced by the closed axis-aligned rectangle spanned by
        anchor and release position, ``right_click`` is set and the click
        handler is disconnected.
        """
        if any([x.in_axes(event)
                for x in self.button_axes]) or self.selected_poly:
            return

        # r_x / r_y exist only after the first click. Test against None so
        # that an anchor at coordinate 0.0 is still accepted.
        anchor_x = getattr(self, 'r_x', None)
        anchor_y = getattr(self, 'r_y', None)
        if anchor_x is None or anchor_y is None:
            return

        # Releases outside the axes carry no data coordinates.
        if event.xdata is None or event.ydata is None:
            return

        # 10 pixels limit for rectangle creation
        if np.abs(event.xdata - anchor_x) <= 10 or np.abs(
                event.ydata - anchor_y) <= 10:
            return
        if len(self.points) >= 4:
            return

        self.right_click = True
        self.fig.canvas.mpl_disconnect(self.click_id)
        self.click_id = None
        bbox = [
            np.min([event.xdata, anchor_x]),
            np.min([event.ydata, anchor_y]),
            np.max([event.xdata, anchor_x]),
            np.max([event.ydata, anchor_y])
        ]
        self.r_x = self.r_y = None

        # Four corners plus the first corner repeated to close the outline.
        self.points = [
            bbox[0], bbox[1], bbox[0], bbox[3], bbox[2], bbox[3],
            bbox[2], bbox[1], bbox[0], bbox[1]
        ]
        self.p = PatchCollection(
            [Polygon(self.points_to_polygon(), closed=True)],
            facecolor='red',
            linewidths=0,
            alpha=0.4)
        self.ax.add_collection(self.p)
        self.fig.canvas.draw()

    def zoom(self, event):
        """Scroll handler: rescale the view around the cursor position.

        Scrolling 'down' multiplies the visible range by ``1 / zoom_scale``
        and 'up' by ``zoom_scale``. Any other button value is printed and
        leaves the range unchanged. The point under the cursor keeps its
        relative position inside the axes.
        """
        if not event.inaxes:
            return

        x_low, x_high = self.ax.get_xlim()
        y_low, y_high = self.ax.get_ylim()
        cursor_x, cursor_y = event.xdata, event.ydata

        if event.button == 'up':
            scale_factor = self.zoom_scale
        elif event.button == 'down':
            scale_factor = 1 / self.zoom_scale
        else:
            # Not expected to happen; report it and keep the current range.
            scale_factor = 1
            print(event.button)

        x_span = x_high - x_low
        y_span = y_high - y_low
        new_width = x_span * scale_factor
        new_height = y_span * scale_factor

        # Fraction of the current range lying beyond the cursor.
        relx = (x_high - cursor_x) / x_span
        rely = (y_high - cursor_y) / y_span

        self.ax.set_xlim(
            [cursor_x - new_width * (1 - relx), cursor_x + new_width * relx])
        self.ax.set_ylim(
            [cursor_y - new_height * (1 - rely), cursor_y + new_height * rely])
        self.ax.figure.canvas.draw()

    def reset(self, event):

        if not self.click_id:
            self.click_id = self.fig.canvas.mpl_connect(
                'button_press_event', self.onclick)
        #print (len(self.lines))
        #print (len(self.circles))
        if len(self.points) > 5:
            for line in self.lines:
                line.pop(0).remove()
            for circle in self.circles:
                circle.remove()
            self.lines, self.circles = [], []
            self.p.remove()
            self.prev = self.p = None
            self.points = []
        #print (len(self.lines))
        #print (len(self.circles))

    def print_points(self):
        """Return the current points formatted with two decimals, each followed by a space."""
        return ''.join('%.2f' % value + ' ' for value in self.points)

    def submit(self, event):
        """Commit the finished polygon to the annotation text of this image.

        Requires ``right_click`` to be set, i.e. the shape was closed;
        otherwise a reminder is printed and nothing changes. On success the
        selected class, polygon area and vertex list are appended to
        ``self.text``, the click handler is connected again, and the
        collection showing all submitted polygons is rebuilt.
        """
        if not self.right_click:
            print('Right click before submit is a must!!')
            return

        label = self.radio.value_selected
        area_text = '%.2f' % self.find_poly_area()
        self.text += label + '\n' + area_text + '\n' + self.print_points() + '\n\n'
        self.right_click = False

        self.lines, self.circles = [], []
        self.click_id = self.fig.canvas.mpl_connect(
            'button_press_event', self.onclick)

        finished = Polygon(self.points_to_polygon(),
                           closed=True,
                           color=np.random.rand(3),
                           alpha=0.4,
                           fill=True)
        self.polys.append(finished)

        # Replace the previous overlay with one covering every polygon.
        if self.submit_p:
            self.submit_p.remove()
        self.submit_p = PatchCollection(self.polys,
                                        cmap=matplotlib.cm.jet,
                                        alpha=0.4)
        self.ax.add_collection(self.submit_p)
        self.points = []
예제 #5
0
class CircleRegionManager():
    """Hold an ordered list of ``CircleRegion`` objects and track which are active.

    Instance attributes:
      regions -- every region, in list order
      index   -- the currently selected index (set through ``set_index``)
      current -- the regions active at ``index`` (see ``filter_list``)
      p       -- the PatchCollection added by the last ``render``, or None
    """

    def __init__(self):
        self.regions = [CircleRegion(self, 0.5, 0.5, 0.1, 0, None)]
        self.set_index(0)
        self.p = None

    def render(self, ax):
        """Draw the active regions on *ax*, replacing the collection drawn before."""
        # Patches are built before the old collection is removed.
        patches = [region.render(self.index) for region in self.current]

        self.cleanup()
        if patches:
            self.p = PatchCollection(patches, alpha=0.4, match_original=True)
            ax.add_collection(self.p)

    def cleanup(self):
        """Remove the collection added by ``render``, if there is one."""
        if self.p is not None:
            self.p.remove()
            self.p = None

    def _active_at(self, index):
        """Return the regions with ``start <= index`` and ``end`` None or ``> index``."""
        return [
            r for r in self.regions
            if r.start <= index and (r.end is None or r.end > index)
        ]

    @staticmethod
    def _first_containing(regions, index, x, y):
        """Return the offset in *regions* of the first one containing (x, y) at *index*, else None."""
        for offset, region in enumerate(regions):
            if region.contains(x, y, index):
                return offset
        return None

    @staticmethod
    def _is_continuation(first, second):
        """Tell whether *second* begins exactly where *first* ends (x, y, r and time)."""
        return (first.x[1] == second.x[0] and first.y[1] == second.y[0]
                and first.r[1] == second.r[0] and first.end == second.start)

    def filter_list(self):
        """
        Return the list of regions that are active at this time period

        Side effect: every region's ``index`` attribute is rewritten to its
        position in ``self.regions``.
        """
        # Update indices
        for position, region in enumerate(self.regions):
            region.index = position
        return self._active_at(self.index)

    def get_next_region(self, index):
        """
        If the next region is a direct continuation of the current one
        return its index. Otherwise, return None
        """
        if index < len(self.regions) - 1:
            if self._is_continuation(self.regions[index], self.regions[index + 1]):
                return index + 1
        return None

    def get_prev_region(self, index):
        """
        If the current region is a direct continuation of the previous one
        return the previous index. Otherwise, return None
        """
        if index > 0:
            if self._is_continuation(self.regions[index - 1], self.regions[index]):
                return index - 1
        return None

    def set_index(self, i):
        """Select index *i* and recompute the active regions."""
        self.index = i
        self.current = self.filter_list()

    def get_patch_index(self, x, y):
        """Return the offset in ``self.current`` of the first region containing (x, y), else None."""
        return self._first_containing(self.current, self.index, x, y)

    def create(self, x, y, r=0.1):
        """Append a new region starting at the current index."""
        self.regions.append(CircleRegion(self, x, y, r, self.index))
        # Store the refreshed list (the result used to be discarded), so the
        # new region is part of ``current`` right away, as after ``insert``.
        self.current = self.filter_list()

    def delete(self, current_offset):
        """End the region at *current_offset* in ``self.current`` at the current index."""
        self.current[current_offset].end = self.index
        # Todo: remove zero length entries!
        # Store the refreshed list: a region whose end equals the current
        # index is no longer active and must drop out of ``current``.
        self.current = self.filter_list()

    def get_colors(self, data):
        """Return an array with 'b' for points inside an active region and 'r' otherwise."""
        cartesian = data.get_cartesian(self.index)
        colors = [
            'r' if self.get_patch_index(*point) is None else 'b'
            for point in cartesian
        ]
        return np.array(colors)

    def get_classes(self, data, index):
        """Return an array with 1 for points inside a region active at *index*, else 0.

        Both the data and the regions are evaluated at *index*; previously
        the regions of ``self.index`` were used regardless of the argument.
        ``self.index`` and ``self.current`` are left untouched.
        """
        cartesian = data.get_cartesian(index)
        active = self._active_at(index)
        classes = [
            0 if self._first_containing(active, index, *point) is None else 1
            for point in cartesian
        ]
        return np.array(classes)

    def insert(self, i, x, y, r, start=0, end=None):
        """Insert a new region at position *i* and return it."""
        self.regions.insert(i, CircleRegion(self, x, y, r, start, end))
        self.current = self.filter_list()
        return self.regions[i]

    def __getstate__(self):
        """Serialize as the list of each region's ``get_data()`` result."""
        return [x.get_data() for x in self.regions]

    def __setstate__(self, state):
        """Rebuild the regions from *state* and select index 0."""
        self.regions = []
        for entry in state:
            r = CircleRegion(self)
            r.set_data(entry)
            self.regions.append(r)

        self.p = None
        # set_index already refreshes ``current`` through filter_list.
        self.set_index(0)
예제 #6
0
def polygon_phases_tune_junction(lp, args, nmode=None):
    """Plot the phase differences ccw around each polygon in the network for many glats as junction coupling
    is increased. Do this for the nth mode as the scaling parameter evolves
    (typically the bond strength between particles that are nearer than some threshold).

    For every requested mode, one PNG frame per coupling value is written into a per-mode directory and the
    frames are then handed to lemov.make_movie(). Nothing is done unless lp['LatticeTop'] is one of
    'hexjunctiontriad', 'spindle' or 'hexjunction2triads'.

    Example usage:
    python gyro_lattice_class.py -polygon_phases_tune_junction -LT hexjunction2triads -N 1 -OmKspec union0p000in0p100 -alph 0.1 -periodic
    python gyro_lattice_class.py -polygon_phases_tune_junction -LT spindle -N 2 -OmKspec union0p000in0p100 -alph 0.0 -periodic -aratio 0.1
    python gyro_lattice_class.py -polygon_phases_tune_junction -LT spindle -N 4 -OmKspec union0p000in0p100 -alph 0.0 -periodic -aratio 0.1

    # for making lattices
    python ./build/make_lattice.py -LT spindle -N 4 -periodic -skip_gyroDOS -aratio 0.1

    Parameters
    ----------
    lp : dict
        Lattice parameters. The keys 'LatticeTop' and 'OmKspec' are read here; the dict is also passed to
        lattice_class.Lattice and GyroLattice. It is deep-copied before 'OmKspec' is overwritten.
    args
        Not referenced in the body of this function.
    nmode : int, int list, or None
        Mode index (or indices) to follow. If None, one index per row of lat.xy is used.
        The indices are processed in reversed order.

    Returns
    -------
    None
    """
    cmap = lecmaps.ensure_cmap('bbr0')
    nkvals = 50
    if lp['LatticeTop'] in ['hexjunctiontriad', 'spindle', 'hexjunction2triads']:
        # Coupling values: negated, rounded, de-duplicated log-spaced values between 0.1 and 10.
        # np.unique returns sorted output, so after negation kvals runs from -0.1 toward -10.
        kvals = -np.unique(np.round(np.logspace(-1, 1., nkvals), 2))  # [::-1]
        # Keep the part of OmKspec after the last 'in' so it can be re-attached to each new coupling value
        dist_thres = lp['OmKspec'].split('union')[-1].split('in')[-1]
        lpmaster = copy.deepcopy(lp)
        lat = lattice_class.Lattice(lp)
        lat.load()
        if nmode is None:
            todo = np.arange(len(lat.xy[:, 0]))
        elif type(nmode) == int:
            todo = [nmode]
        else:
            todo = nmode

        ##########################################################################
        # First collect eigenvalue flow
        # eigvals[i, :] holds the imaginary parts of the eigenvalues for kvals[i]
        eigvals, first = [], True
        for (kval, dmyi) in zip(kvals, np.arange(len(kvals))):
            lp = copy.deepcopy(lpmaster)
            lp['OmKspec'] = 'union' + sf.float2pstr(kval, ndigits=3) + 'in' + dist_thres
            # lat = lattice_class.Lattice(lp)
            glat = GyroLattice(lat, lp)
            eigval, eigvect = glat.eig_vals_vects(attribute=True)
            if first:
                # Allocate once the number of eigenvalues is known
                eigvals = np.zeros((len(kvals), len(eigval)), dtype=float)

            eigvals[dmyi, :] = np.imag(eigval)
            first = False

        ##########################################################################

        lp = copy.deepcopy(lpmaster)
        glat = GyroLattice(lat, lp)
        # add meshfn without OmKspecunion part
        mfe = glat.lp['meshfn_exten']
        if mfe[0:13] == '_OmKspecunion':
            # Drop the leading '_OmKspecunion...' token and keep whatever follows it
            meshfnextenstr = mfe.split(mfe.split('_')[1])[-1]
        else:
            raise RuntimeError('Handle this case here -- should be easy: split meshfn_exten to pop OmKspec out')

        for ii in todo[::-1]:
            # NOTE(review): this glob pattern contains 'phases_scaling_tune_junction' while modedir/movname below
            # use 'phases_tune_junction' -- the existence check may never match the movie that is written; confirm
            # against how lemov.make_movie names its output.
            modefn = dio.prepdir(lat.lp['meshfn']) + 'glat_mode_phases_scaling_tune_junction_mode{0:05d}'.format(ii) +\
                     meshfnextenstr + '_nkvals{0:04}'.format(nkvals) + '.mov'
            globmodefn = glob.glob(modefn)
            if not globmodefn:
                modedir = dio.prepdir(lat.lp['meshfn']) + 'glat_mode_phases_tune_junction_mode{0:05d}'.format(ii) + \
                          meshfnextenstr + '/'
                dio.ensure_dir(modedir)
                # previous_ev carries the tracked eigenvector from one coupling value to the next
                previous_ev = None
                # NOTE(review): 'first' is assigned in this loop but not read here
                first = True
                # dmyi numbers the output frames
                dmyi = 0
                for kval in kvals:
                    lp = copy.deepcopy(lpmaster)
                    lp['OmKspec'] = 'union' + sf.float2pstr(kval, ndigits=3) + 'in' + dist_thres
                    # lat = lattice_class.Lattice(lp)
                    glat = GyroLattice(lat, lp)
                    eigval, eigvect = glat.eig_vals_vects(attribute=True)

                    # plot the nth mode
                    # fig, DOS_ax, eax = leplt.initialize_eigvect_DOS_header_plot(eigval, glat.lattice.xy,
                    #                                                             sim_type='gyro', cbar_nticks=2,
                    #                                                             cbar_tickfmt='%0.3f')
                    fig, dos_ax, eax, ax1, cbar_ax = \
                        leplt.initialize_eigvect_DOS_header_twinplot(eigval, glat.lattice.xy, sim_type='gyro',
                                                                     ax0_pos=[0.0, 0.10, 0.45, 0.55],
                                                                     ax1_pos=[0.65, 0.15, 0.3, 0.60],
                                                                     header_pos=[0.1, 0.78, 0.4, 0.20],
                                                                     xlabel_pad=8, fontsize=8)

                    # Separate axes for the phase colorbar
                    cax = plt.axes([0.455, 0.10, 0.02, 0.55])

                    # Get the theta that minimizes the difference between the present and previous eigenvalue
                    # IN ORDER TO CONNECT MODES PROPERLY
                    if previous_ev is not None:
                        realxy = np.real(previous_ev)
                        thetas, eigvects = gdh.phase_fix_nth_gyro(glat.eigvect, ngyro=0, basis='XY')
                        # only look at neighboring modes
                        # (presumes sufficient resolution to disallow simultaneous crossings)
                        # modenum on the right-hand side is the value from the previous coupling step
                        mmin = max(modenum - 2, 0)
                        mmax = min(modenum + 2, len(eigvects))
                        modenum = gdh.phase_difference_minimum(eigvects[mmin:mmax], realxy, basis='XY')
                        # convert the offset within the window back to an absolute mode index
                        modenum += mmin
                        # print 'thetas = ', thetas
                        # if theta < 1e-9:
                        #     print 'problem with theta'
                        #     sys.exit()
                    else:
                        # First coupling value: start tracking at the requested mode
                        thetas, eigvects = gdh.phase_fix_nth_gyro(glat.eigvect, ngyro=0, basis='XY')
                        modenum = ii

                    # Plot the lattice with bonds
                    glat.lattice.plot_BW_lat(fig=fig, ax=eax, save=False, close=False, axis_off=False, title='')
                    # plot excitation
                    fig, [scat_fg, scat_fg2, pp, f_mark, lines_12_st], cw_ccw = \
                        leplt.construct_eigvect_DOS_plot(glat.lattice.xy, fig, dos_ax, eax, eigval, eigvects,
                                                         modenum, 'gyro', glat.lattice.NL, glat.lattice.KL,
                                                         marker_num=0, color_scheme='default', sub_lattice=-1,
                                                         amplify=1., title='')
                    # Plot the polygons colored by phase
                    polys = glat.lattice.get_polygons()
                    patches, colors = [], []
                    for poly in polys:
                        # addv accumulates the periodic offsets crossed so far while walking around the polygon
                        addv = np.array([0., 0.])
                        # build up positions, taking care of periodic boundaries
                        xys = np.zeros_like(glat.lattice.xy[poly], dtype=float)
                        xys[0] = glat.lattice.xy[poly[0]]
                        # poly[qq] is the site visited just before 'site'
                        for (site, qq) in zip(poly[1:], range(len(poly) - 1)):
                            if latfns.bond_is_periodic(poly[qq], site, glat.lattice.BL):
                                toadd = latfns.get_periodic_vector(poly[qq], site,
                                                                   glat.lattice.PVx, glat.lattice.PVy,
                                                                   glat.lattice.NL, glat.lattice.KL)
                                if np.shape(toadd)[0] > 1:
                                    raise RuntimeError('Handle the case of multiple periodic bonds between ii jj here')
                                else:
                                    addv += toadd[0]
                            xys[qq + 1] = glat.lattice.xy[site] + addv
                            # NOTE(review): the label says poly[qq - 1] but the value printed is poly[qq]
                            print 'site, poly[qq - 1] = ', (site, poly[qq])
                            print 'addv = ', addv

                        xys = np.array(xys)
                        polygon = Polygon(xys, True)
                        patches.append(polygon)

                        # Check the polygon
                        # plt.close('all')
                        # plt.plot(xys[:, 0], xys[:, 1], 'b-')
                        # plt.show()

                        # Get mean phase difference in this polygon
                        # Use weighted arithmetic mean of (cos(angle), sin(angle)), then take the arctangent.
                        # Components are indexed as 2 * site (used as x) and 2 * site + 1 (used as y)
                        yinds = 2 * np.array(poly) + 1
                        xinds = 2 * np.array(poly)
                        # NOTE(review): 'weights' is computed (from eigvect, not eigvects) but not used below;
                        # the means taken further down are unweighted.
                        weights = glatfns.calc_magevecs_full(eigvect[modenum])
                        # To take mean, follow
                        # https://en.wikipedia.org/wiki/Mean_of_circular_quantities#Mean_of_angles
                        # with weights from
                        # https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Mathematical_definition
                        # First take differences in angles
                        phis = np.arctan2(np.real(eigvects[modenum, yinds]), np.real(eigvects[modenum, xinds]))
                        print 'phis = ', phis
                        phis = np.mod(phis, np.pi * 2)
                        print 'phis = ', phis
                        # Now convert to vectors, take mean of both x and y components. Then grab atan2(y,x) of result.
                        xx, yy = np.mean(np.cos(np.diff(phis))), np.mean(np.sin(np.diff(phis)))
                        dphi = np.arctan2(yy, xx)
                        print 'dphi = ', dphi
                        colors.append(dphi)

                    # sys.exit()
                    # Color every polygon by its mean phase difference on a fixed [-pi, pi] scale
                    pp = PatchCollection(patches, alpha=0.4, cmap=cmap)
                    pp.set_array(np.array(colors))
                    eax.add_collection(pp)
                    pp.set_clim([-np.pi, np.pi])
                    cbar = fig.colorbar(pp, cax=cax)

                    # Store this current eigvector as 'previous_ev'
                    previous_ev = eigvects[modenum]

                    # Plot where in evolution we are tracking
                    # Only the second half of the eigenvalue columns is drawn
                    ngyros = int(np.shape(eigvals)[1] * 0.5)
                    halfev = eigvals[:, ngyros:]
                    for row in halfev.T:
                        ax1.loglog(np.abs(kvals), row, 'b-')

                    # Red marker: the tracked mode at the current coupling value
                    trackmark = ax1.plot(np.abs(kval), np.abs(np.imag(eigval))[modenum], 'ro')
                    ax1.set_xlabel(r"vertex coupling, $\Omega_k'$")
                    ax1.set_ylabel(r"frequency, $\omega$")
                    eax.xaxis.set_ticks([])
                    eax.yaxis.set_ticks([])
                    cbar.set_ticks([-np.pi, 0, np.pi])
                    cbar.set_ticklabels([r'-$\pi$', 0, r'$\pi$'])
                    cbar.set_label(r'phase, $\Delta \phi$')

                    dos_ax.set_xlim(xmin=0)
                    plt.savefig(modedir + 'DOS_' + '{0:05}'.format(dmyi) + '.png', dpi=200)

                    # remove plotted excitation
                    scat_fg.remove()
                    scat_fg2.remove()
                    pp.remove()
                    if f_mark is not None:
                        f_mark.remove()
                    lines_12_st.remove()
                    eax.cla()

                    dmyi += 1
                    first = False

                # Make movie
                imgname = modedir + 'DOS_'
                # modedir ends with '/', which is stripped to form the movie name
                movname = modedir[:-1] + '_nkvals{0:04}'.format(nkvals)
                lemov.make_movie(imgname, movname, indexsz='05', framerate=5, rm_images=True, save_into_subdir=True,
                                 imgdir=modedir)
def movie_cherns_varyloc(ccoll, title='Chern number calculation for varied positions',
                         filename='chern_varyloc', rootdir=None, exten='.png', max_boxfrac=None, max_boxsize=None,
                         xlabel=None, ylabel=None, step=0.5, fracsteps=False, framerate=3):
    """Plot the chern as a function of space for each haldane_lattice examined

    For every lattice in ccoll.cherns, one still image is saved per ksize value and the stills are then
    passed to ./ffmpeg to build a .mov file.

    Parameters
    ----------
    ccoll : ChernCollection instance
        The collection of varyloc chern calcs to make into a movie
    title : str or None
        title of the movie, annotated on the figure if not None
    filename : str
        the name of the files to save
    rootdir : str or None
        The cproot directory to use (usually self.cp['rootdir'])
    exten : str (.png, .jpg, etc)
        file type extension
    max_boxfrac : float or None
        Fraction of spatial extent of the sample to use as maximum bound for kitaev sum.
        Only consulted when max_boxsize is None.
    max_boxsize : float or None
        If None, uses max_boxfrac * spatial extent of the sample as max_boxsize; if max_boxfrac is also None,
        uses the largest ksize of the first chern calculation.
    xlabel : str or None
        label for x axis
    ylabel : str or None
        label for y axis
    step : float (default=0.5)
        how far apart to sample kregion vertices in varyloc; also the side length of each drawn square
    fracsteps : bool
        Not referenced in the body of this function.
    framerate : int
        The framerate at which to save the movie

    Returns
    -------
    None
    """
    divgmap = cmaps.diverging_cmap(250, 10, l=30)

    # plot it
    for hlat_name in ccoll.cherns:
        hlat = ccoll.cherns[hlat_name][0].haldane_lattice

        # Initialize the figure
        h_mm = 90
        w_mm = 120
        # To get space between subplots, figure out how far away ksize region needs to be, based on first chern
        # Compare max ksize to be used with spatial extent of the lattice. If comparable, make hspace large.
        # Otherwise, use defaults
        ksize = ccoll.cherns[hlat_name][0].chern_finsize[:, 2]
        cgll = hlat.lattice
        maxsz = max(np.max(cgll.xy[:, 0]) - np.min(cgll.xy[:, 0]),
                    np.max(cgll.xy[:, 1]) - np.min(cgll.xy[:, 1]))
        # NOTE(review): max_boxsize is assigned here and never reset, so a value derived for the first lattice
        # is reused for every later lattice in ccoll.cherns. Kept as in the original.
        if max_boxsize is not None:
            ksize = ksize[ksize < max_boxsize]
        elif max_boxfrac is not None:
            max_boxsize = max_boxfrac * maxsz
            ksize = ksize[ksize < max_boxsize]
        else:
            max_boxsize = np.max(ksize)

        # Choose the panel centers from how large the box is relative to the lattice
        if max_boxsize > 0.9 * maxsz:
            center0_frac = 0.3
            center2_frac = 0.75
        elif max_boxsize > 0.65 * maxsz:
            center0_frac = 0.35
            center2_frac = 0.72
        elif max_boxsize > 0.55 * maxsz:
            center0_frac = 0.375
            center2_frac = 0.71
        else:
            center0_frac = 0.4
            center2_frac = 0.7

        fig, ax = initialize_1p5panelcbar_fig(Wfig=w_mm, Hfig=h_mm, wsfrac=0.4, wssfrac=0.4,
                                              center0_frac=center0_frac, center2_frac=center2_frac)

        # dimensions of video in pixels
        final_h = 720
        final_w = 960
        actual_dpi = final_h / (float(h_mm) / 25.4)

        # Add the network to the figure
        netvis.movie_plot_2D(hlat.lattice.xy, hlat.lattice.BL, 0 * hlat.lattice.BL[:, 0],
                             None, None, ax=ax[0], fig=fig, axcb=None,
                             xlimv='auto', ylimv='auto', climv=0.1, colorz=False, ptcolor=None, figsize='auto',
                             colormap='BlueBlackRed', bgcolor='#ffffff', axis_off=True, axis_equal=True,
                             lw=0.2)

        # Add title
        if title is not None:
            ax[0].annotate(title, xy=(0.5, .95), xycoords='figure fraction',
                           horizontalalignment='center', verticalalignment='center')
        if xlabel is not None:
            ax[0].set_xlabel(xlabel)
        if ylabel is not None:
            # Fixed: this used set_xlabel, which overwrote xlabel and never labelled the y axis
            ax[0].set_ylabel(ylabel)

        # Position colorbar
        sm = plt.cm.ScalarMappable(cmap=divgmap, norm=plt.Normalize(vmin=-1, vmax=1))
        # fake up the array of the scalar mappable.
        sm._A = []
        plt.colorbar(sm, cax=ax[1], orientation='horizontal', ticks=[-1, 0, 1])
        ax[1].set_xlabel(r'$\nu$')
        ax[1].xaxis.set_label_position("top")
        ax[2].axis('off')

        # Add patches (rectangles from cherns at each site) to the figure
        # Single-argument print(...) produces the same text on Python 2 and Python 3
        print('Opening hlat_name =  ' + str(hlat_name))
        ind = 0
        while True:
            rectps = []
            colorL = []
            for chernii in ccoll.cherns[hlat_name]:
                # Grab small, medium, and large circles
                # max_boxsize is always set by this point (see above), so the filter is unconditional
                ksize = chernii.chern_finsize[:, 2]
                ksize = ksize[ksize < max_boxsize]

                # poly_offset is 'x/y'; a square of side `step` is centered on that point
                offset = chernii.cp['poly_offset'].split('/')
                xx = float(offset[0])
                yy = float(offset[1])
                nu = chernii.chern_finsize[:, -1]
                rect = plt.Rectangle((xx - step * 0.5, yy - step * 0.5), step, step, ec="none")
                colorL.append(nu[ind])
                rectps.append(rect)

            p = PatchCollection(rectps, cmap=divgmap, alpha=1.0, edgecolors='none')
            p.set_array(np.array(colorL))
            p.set_clim([-1., 1.])

            # Add the patches of nu calculations for each site probed
            ax[0].add_collection(p)

            # Draw the kitaev cartoon in second axis with size ksize[ind]
            # (ksize here is the filtered array of the last chern in the loop above)
            polygon1, polygon2, polygon3 = kfns.get_kitaev_polygons(ccoll.cp['shape'], ccoll.cp['regalph'],
                                                                    ccoll.cp['regbeta'], ccoll.cp['reggamma'],
                                                                    ksize[ind])
            patchlist = [patches.Polygon(polygon1, color='r'),
                         patches.Polygon(polygon2, color='g'),
                         patches.Polygon(polygon3, color='b')]
            polypatches = PatchCollection(patchlist, cmap=cm.jet, alpha=0.4, zorder=99, linewidths=0.4)
            colors = np.linspace(0, 1, 3)[::-1]
            polypatches.set_array(np.array(colors))
            ax[2].add_collection(polypatches)
            ax[2].set_xlim(ax[0].get_xlim())
            ax[2].set_ylim(ax[0].get_ylim())

            # Save the plot
            # make index string
            indstr = '_{0:06d}'.format(ind)
            hlat_cmesh = kfns.get_cmeshfn(ccoll.cherns[hlat_name][0].haldane_lattice.lp, rootdir=rootdir)
            specstr = '_Nks' + '{0:03d}'.format(len(ksize)) + '_step' + sf.float2pstr(step) \
                      + '_maxbsz' + sf.float2pstr(max_boxsize)
            outdir = hlat_cmesh + '_' + hlat.lp['LatticeTop'] + '_varyloc_stills' + specstr + '/'
            fnout = outdir + filename + specstr + indstr + exten
            print('saving figure: ' + fnout)
            le.ensure_dir(outdir)
            # Render at twice the target resolution, then downsample below
            fig.savefig(fnout, dpi=actual_dpi*2)

            # Save at lower res after antialiasing
            f_img = Image.open(fnout)
            f_img.resize((final_w, final_h), Image.ANTIALIAS).save(fnout)

            # clear patches
            p.remove()
            polypatches.remove()

            # Update index; stop once one frame per ksize value has been written
            ind += 1
            if ind == len(ksize):
                break

        # Turn into movie
        imgname = outdir + filename + specstr
        movname = hlat_cmesh + filename + specstr + '_mov'
        # NOTE(review): the output file is listed before the encoding options; ffmpeg applies options to the
        # next file named, so these trailing options may be ignored -- confirm before reordering.
        subprocess.call(['./ffmpeg', '-framerate', str(int(framerate)), '-i', imgname + '_%06d' + exten, movname + '.mov',
                         '-vcodec', 'libx264', '-profile:v', 'main', '-crf', '12', '-threads', '0',
                         '-r', '30', '-pix_fmt', 'yuv420p'])
예제 #8
0
import matplotlib.patches as patches
from matplotlib.collections import PatchCollection
import wx


def idx2x(idx):
    """Look up entry *idx* of the gpx column named by ``timeview.xaxis`` and convert it with ``timeview.x_to_num``."""
    raw_value = gpx[timeview.xaxis][idx]
    return timeview.x_to_num(raw_value)


# A translucent rectangle spanning samples 1200..2400, 25 units tall
mypatches = [
    patches.Rectangle((idx2x(1200), 0), idx2x(2400) - idx2x(1200), 25.0)
]
p = PatchCollection(mypatches, alpha=0.4)
timeview.ax1.add_collection(p)

# A text label placed at sample 2400
txt = timeview.ax1.text(idx2x(2400), 12, 'end', fontsize=12)

# An arrow annotation pointing at sample 1200
annotation = timeview.ax1.annotate(
    'local max',
    xy=(idx2x(1200), 10),
    xytext=(idx2x(1200), 20),
    arrowprops={'facecolor': 'black', 'shrink': 0.05},
    fontsize=12)
sh.upd()

# Wait for the user, then take the three artists off the plot again
cmd = raw_input('Press enter to exit script:')
for artist in (p, txt, annotation):
    artist.remove()
sh.upd()