示例#1
0
def plot_grad_flow(named_parameters):
    '''Plot the gradients flowing through the layers of a network during training.

    Can be used for checking for possible gradient vanishing / exploding problems.
    For every parameter that requires a gradient, is not a bias and has a
    populated ``.grad``, two overlaid bars are drawn: the maximum absolute
    gradient (cyan) and the mean absolute gradient (blue). The y axis is fixed
    to [-0.001, 0.02] to zoom in on small gradients.

    Usage: call after ``loss.backward()`` (e.g. in a Trainer class) as
    ``plot_grad_flow(self.model.named_parameters())``.

    Args:
        named_parameters: iterable of ``(name, parameter)`` pairs such as the
            one returned by ``model.named_parameters()``.
    '''
    ave_grads = []
    max_grads = []
    layers = []
    for name, param in named_parameters:
        # Identity test for a missing gradient (`!= None` on a tensor is not one).
        if param.requires_grad and "bias" not in name and param.grad is not None:
            abs_grad = param.grad.abs()
            layers.append(name)
            # .item() yields plain Python floats, so matplotlib can plot the
            # values even when the gradients live on a GPU.
            ave_grads.append(abs_grad.mean().item())
            max_grads.append(abs_grad.max().item())
    positions = np.arange(len(max_grads))
    plt.bar(positions, max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(positions, ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=-0.001, top=0.02)  # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([
        Line2D([0], [0], color="c", lw=4),
        Line2D([0], [0], color="b", lw=4),
        Line2D([0], [0], color="k", lw=4)
    ], ['max-gradient', 'mean-gradient', 'zero-gradient'])
示例#2
0
文件: core.py 项目: sq-meng/coilcalc
 def draw_source(self, ax):
     """Draw this source onto ``ax``.

     Adds a translucent grey polygon through ``start``, ``end`` and their
     copies with the second component negated (presumably the y mirror of
     (x, y) points - confirm), drawn at zorder 0, plus two red lines over
     ``x_span`` at ``radius`` and ``-radius``.
     """
     flip_y = np.asarray([1, -1])
     corners = (self.start, self.end, self.end * flip_y, self.start * flip_y)
     ax.add_artist(Polygon(corners, color=[0.4, 0.4, 0.4, 0.25], zorder=0))
     for height in (self.radius, -self.radius):
         ax.add_artist(Line2D(self.x_span, height, color='r'))
示例#3
0
    def show_current_img(self, filter_id, add_forms=False):
        """Plot current image.

        :param str filter_id: filter ID of image list (e.g. "on")
        :param bool add_forms: if True, overlay the line forms stored in
            ``self.lines._forms`` and the rectangle forms stored in
            ``self.rects._forms`` in green, labelled by their keys, and add a
            draggable legend listing them
        :returns: the axes returned by ``show_img()`` of the current image
        """
        ax = self.current_image(filter_id).show_img()
        if add_forms:
            handles = []
            # Line forms: coords are indexed as (x0, y0, x1, y1).
            for name, coords in six.iteritems(self.lines._forms):
                line = Line2D([coords[0], coords[2]], [coords[1], coords[3]],
                              color="#00ff00",
                              label=name)
                handles.append(line)
                ax.add_artist(line)
            # Rectangle forms: coords are corner points (x0, y0, x1, y1).
            for name, coords in six.iteritems(self.rects._forms):
                width, height = coords[2] - coords[0], coords[3] - coords[1]
                rect = Rectangle((coords[0], coords[1]),
                                 width,
                                 height,
                                 ec="#00ff00",
                                 fc="none",
                                 label=name)
                ax.add_patch(rect)
                handles.append(rect)
            legend = ax.legend(handles=handles,
                               loc='best',
                               fancybox=True,
                               framealpha=0.5,
                               fontsize=10)
            # Legend.draggable() was deprecated in Matplotlib 3.0 and later
            # removed; set_draggable(True) replaces it. Fall back only for
            # Matplotlib versions that predate set_draggable.
            if hasattr(legend, "set_draggable"):
                legend.set_draggable(True)
            else:
                legend.draggable()

        return ax
示例#4
0
    def plot_results(self, show=True):
        """Plots total energy convergence and band gap convergence.

        The energy plot (left) is produced by ``self._plot_energy`` and the gap
        plot (right) by ``self._plot_gaps``. A legend with three threshold
        entries (gold / orange / red for < 1e-4, 1e-5, 1e-6 eV/atom) is added to
        the energy axes; presumably these colours match those used by
        ``_plot_energy`` (not visible here - confirm).

        Args:
            show (bool): Shows plot via matplotlib if True, else saves to file
                kconv.png in ``self.maindir`` and closes the figure.

        """
        fig = plt.figure(
            constrained_layout=True,
            figsize=(8, 4),
        )
        spec = gridspec.GridSpec(ncols=2, nrows=1, figure=fig)
        ax1 = fig.add_subplot(spec[0])
        ax2 = fig.add_subplot(spec[1])

        ax1 = self._plot_energy(ax1)
        ax2 = self._plot_gaps(ax2)
        handles = [
            Line2D([0], [0], color="gold", label="< 1e-4 eV/atom", lw=1.5),
            Line2D([0], [0], color="orange", label="< 1e-5 eV/atom", lw=1.5),
            Line2D([0], [0], color="red", label="< 1e-6 eV/atom", lw=1.5),
        ]
        ax1.legend(
            handles=handles,
            loc="upper right",
            frameon=True,
            fancybox=True,
            framealpha=0.8,
        )
        if show:
            plt.show()
        else:
            # Save this figure explicitly: plt.savefig writes whichever figure
            # is current, which need not be `fig` if a helper switched figures.
            fig.savefig(
                str(self.maindir.joinpath("kconv.png")),
                dpi=300,
                transparent=False,
                facecolor="white",
                bbox_inches="tight",
            )
            # The figure is not returned, so release it once it is on disk.
            plt.close(fig)
            logger.info("Results have been saved to kconv.png.")
示例#5
0
    def drawPlot(self):
        """Open an interactive figure and draw the cart and the pole on it.

        Stores the figure, axes, cart rectangle and pole line as ``self.fig``,
        ``self.axes``, ``self.box`` and ``self.pole``. Both artists are clipped
        to the axes and the view is fixed to x in [-10, 10], y in [-0.5, 2].
        """
        ion()
        self.fig = plt.figure()
        self.axes = self.fig.add_subplot(111, aspect='equal')

        # Cart: rectangle centred horizontally on cart_location, top edge at y = 0.
        left_edge = self.cart_location - self.cartwidth / 2.0
        self.box = Rectangle(xy=(left_edge, -self.cartheight),
                             width=self.cartwidth,
                             height=self.cartheight)
        self.axes.add_artist(self.box)
        self.box.set_clip_box(self.axes.bbox)

        # Pole: unit-length segment from (cart_location, 0), tilted from the
        # vertical by pole_angle (passed to np.sin / np.cos).
        base_x = self.cart_location
        tip_x = base_x + np.sin(self.pole_angle)
        tip_y = np.cos(self.pole_angle)
        self.pole = Line2D([base_x, tip_x], [0, tip_y],
                           linewidth=3, color='black')
        self.axes.add_artist(self.pole)
        self.pole.set_clip_box(self.axes.bbox)

        self.axes.set_xlim(-10, 10)
        self.axes.set_ylim(-0.5, 2)
示例#6
0
def plot_path(coord_list, paths):
    """Scatter all coordinates and draw every path in its own random colour.

    Args:
        coord_list: sequence of points; ``point[0]`` is x and ``point[1]`` is y.
        paths: sequence of paths, each a sequence of indices into ``coord_list``.

    Shows the figure with ``plt.show()``.
    """
    _fig, ax = plt.subplots()

    # One scatter call per point, as before, so each point takes the next
    # colour of the axes' colour cycle. This previously looped over
    # range(tot + 1), using a `tot` that is not defined in this function.
    for point in coord_list:
        ax.scatter(point[0], point[1])

    # One random "#RRGGBB" colour per path.
    hex_digits = 'ABCDEF0123456789'
    colors = ['#' + ''.join(random.choice(hex_digits) for _ in range(6))
              for _ in paths]

    for path, color in zip(paths, colors):
        # NOTE(review): this draws len(path) - 2 segments, so the final hop
        # path[-2] -> path[-1] is never drawn. Kept as in the original; confirm
        # whether that is intended (e.g. a trailing sentinel entry).
        for k in range(len(path) - 2):
            start = coord_list[path[k]]
            end = coord_list[path[k + 1]]
            ax.add_line(
                Line2D((start[0], end[0]),
                       (start[1], end[1]),
                       linewidth=1,
                       color=color))

    # add_line() does not rescale the view by itself; make sure the lines fit.
    ax.autoscale_view()
    plt.show()
示例#7
0
def likegrid1d(chains,
               params='all',
               lims=None,
               ticks=None,
               nticks=4,
               nsig=3,
               colors=None,
               nbins1d=30,
               labels=None,
               fig=None,
               size=2,
               aspect=1,
               legend_loc=None,
               linewidth=1,
               param_name_mapping=None,
               param_label_size=None,
               tick_label_size=None,
               titley=1,
               ncol=4,
               axes=None):
    """
    Make a grid of 1-d likelihood plots, one panel per parameter.

    Arguments:
    ----------

    chains :
        one or a list of `Chain` objects

    params, optional :
        list of parameter names which to show
        can also be 'all' or 'common' which does the union/intersection of
        the params in all the chains (default: 'all')

    lims, optional :
        a dictionary mapping parameter names to (min,max) axes limits
        (default: for each parameter, the widest mean +/- `nsig` standard
        deviations window over the chains containing it, clipped to each
        chain's sampled range)

    ticks, optional :
        a dictionary giving a list of ticks for each parameter

    nticks, optional :
        roughly how many x ticks to show. can be dictionary to
        specify each parameter separately. (default: 4)

    nsig, optional :
        number of standard deviations used for the default limits (default: 3)

    fig, optional :
        figure or figure number in which to plot (default: new figure)
    ncol, optional :
        the number of columns (default: 4)
    axes, optional :
        an array of axes into which to plot. if this is provided, fig and ncol
        are not used for the layout. must have len(axes) >= len(params).
        The legend goes to `fig` if given, otherwise to axes[0]'s figure.

    size, optional :
        size in inches of one plot (default: 2)
    aspect, optional :
        aspect ratio (default: 1)

    colors, optional :
        one color per chain; chains beyond len(colors) are not plotted

    linewidth, optional :
        line width passed to each chain's like1d (default: 1)

    labels, optional :
        list of names for a legend; when given (and axes is not), the first
        grid cell is left empty for it

    legend_loc, optional :
        (x,y) location of the legend (coordinates scaled to [0,1])

    nbins1d, optional :
        number of bins for 1d plots (default: 30)

    param_name_mapping, optional :
        dict mapping parameter names to panel titles
    param_label_size, tick_label_size, optional :
        font sizes of the panel titles and tick labels
    titley, optional :
        vertical position of the panel titles in axes coordinates (default: 1)
    """
    from matplotlib.pyplot import figure, Line2D
    from matplotlib.ticker import AutoMinorLocator, ScalarFormatter, MaxNLocator

    if not isinstance(chains, list):
        chains = [chains]

    if isinstance(params, str) and params in ('all', 'common'):
        combine = op.__or__ if params == 'all' else op.__and__
        params = sorted(
            reduce(lambda x, y: combine(set(x), set(y)),
                   [c.params() for c in chains]))
    elif not isinstance(params, Iterable):
        raise ValueError("params should be iterable or 'all' or 'common'")

    if param_name_mapping is None:
        param_name_mapping = {}
    # Floor division: add_subplot needs an integer row count (true division
    # gives a float under Python 3).
    nrow = len(params) // ncol + 1
    if axes is None:
        if fig is None:
            fig = figure()
        elif isinstance(fig, int):
            # fig may be a figure number (this branch was previously unreachable).
            fig = figure(fig)
        if size is not None:
            fig.set_size_inches(size * ncol, size * nrow / aspect)
        fig.subplots_adjust(hspace=0.4, wspace=0.1)
    if colors is None:
        colors = ['b', 'orange', 'k', 'm', 'cyan']

    def _default_lims(p):
        # Widest mean +/- nsig*std window over the chains that contain p,
        # each clipped to that chain's sampled range.
        containing = [c for c in chains if p in c.params()]
        lo = min(max(min(c[p]), mean(c[p]) - nsig * std(c[p]))
                 for c in containing)
        hi = max(min(max(c[p]), mean(c[p]) + nsig * std(c[p]))
                 for c in containing)
        return (lo, hi)

    if lims is None:
        lims = {}
    lims = {p: (lims[p] if p in lims else _default_lims(p)) for p in params}

    # Subplot indices are 1-based; skip the first cell when it hosts the legend.
    first = 0 if axes is not None else (2 if labels is not None else 1)
    for i, p1 in enumerate(params, first):
        ax = axes[i] if axes is not None else fig.add_subplot(nrow, ncol, i)
        for ch, col in zip(chains, colors):
            if p1 in ch:
                ch.like1d(p1,
                          nbins=nbins1d,
                          color=col,
                          ax=ax,
                          linewidth=linewidth)
        ax.set_yticks([])
        ax.set_xlim(lims[p1])
        ax.set_ylim(0, 1)
        ax.set_title(param_name_mapping.get(p1, p1),
                     size=param_label_size,
                     y=titley)
        ax.tick_params(labelsize=tick_label_size)
        if ticks and p1 in ticks:
            ax.set_xticks(ticks[p1])
        else:
            ax.xaxis.set_major_locator(
                MaxNLocator(nbins=nticks.get(p1, 4) if isinstance(
                    nticks, dict) else nticks))
        ax.xaxis.set_minor_locator(AutoMinorLocator())
        ax.xaxis.set_major_formatter(ScalarFormatter(useOffset=False))

    if labels is not None:
        legend_fig = fig
        if legend_fig is None:
            # fig is optional when axes are supplied; use their figure.
            legend_fig = axes[0].figure
        elif isinstance(legend_fig, int):
            legend_fig = figure(legend_fig)
        legend_fig.legend([Line2D([0], [0], c=col, linewidth=3) for col in colors],
                          labels,
                          fancybox=True,
                          shadow=False,
                          loc=legend_loc if legend_loc is not None else
                          (0, 1 - 1. / nrow))
示例#8
0
def likegrid(chains,
             params=None,
             lims=None,
             ticks=None,
             nticks=5,
             default_chain=0,
             spacing=0.05,
             xtick_rotation=30,
             colors=None,
             filled=True,
             nbins1d=30,
             nbins2d=20,
             labels=None,
             fig=None,
             size=2,
             legend_loc=None,
             param_name_mapping=None,
             param_label_size=None):
    """
    Make a grid (aka "triangle plot") of 1- and 2-d likelihood contours.

    Column i shows parameter params[i] on the x axis and row j shows
    params[j] on the y axis; only the lower triangle and the diagonal
    (1-d plots) are drawn.

    Parameters
    ----------

    chains :
        one or a list of `Chain` objects

    default_chain, optional :
        the chain used to get default axes limits
        either an index into chains or a `Chain` object (default: chains[0])

    params, optional :
        list of parameter names which to show
        (default: the parameters common to all chains, sorted)

    lims, optional :
        a dictionary mapping parameter names to (min,max) axes limits
        (default: +/- 4 sigma from default_chain, clipped to its sampled range)

    ticks, optional :
        a dictionary mapping parameter names to list of [ticks], applied to
        that parameter's x axis (column) and y axis (row)
        (default: automatically picks `nticks`)
    nticks, optional :
        roughly how many ticks per axes; an int or a dict per parameter
        (default: 5)

    xtick_rotation, optional :
        number of degrees to rotate the bottom row's x tick labels by
        (default: 30)
    spacing, optional :
        space in between plots, passed as hspace/wspace to subplots_adjust
        (default: 0.05)
    fig, optional :
        figure or figure number in which to plot (default: figure(0))

    size, optional :
        size in inches of one plot (default: 2)

    colors, optional :
        one color per chain; chains beyond len(colors) are not plotted

    filled, optional :
        whether to fill in the contours (default: True)

    labels, optional :
        list of names for a legend

    legend_loc, optional :
        (x,y) location of the legend (coordinates scaled to [0,1])

    nbins1d, optional :
        number (or len(chains) length list) of bins for 1d plots (default: 30)

    nbins2d, optional :
        number (or len(chains) length list) of bins for 2d plots (default: 20)

    param_name_mapping, optional :
        dict mapping parameter names to axis labels
    param_label_size, optional :
        font size of the axis labels
    """
    from matplotlib.pyplot import figure, Line2D
    from matplotlib.ticker import MaxNLocator
    if fig is None:
        fig = figure(0)
    elif isinstance(fig, int):
        fig = figure(fig)
    if not isinstance(chains, list):
        chains = [chains]
    if params is None:
        params = sorted(
            reduce(lambda x, y: set(x) & set(y), [c.params() for c in chains]))
    if param_name_mapping is None:
        param_name_mapping = {}
    if size is not None:
        fig.set_size_inches(size * len(params), size * len(params))
    if colors is None:
        colors = ['b', 'orange', 'k', 'm', 'cyan']
    if not isinstance(nbins2d, list):
        nbins2d = [nbins2d] * len(chains)
    if not isinstance(nbins1d, list):
        nbins1d = [nbins1d] * len(chains)
    fig.subplots_adjust(hspace=spacing, wspace=spacing)

    ref = chains[default_chain] if isinstance(default_chain,
                                              int) else default_chain
    # Default limits: mean +/- 4 std of the reference chain, clipped to its
    # sampled range; user-supplied limits override them.
    default_lims = {
        p: (max(min(ref[p]),
                mean(ref[p]) - 4 * std(ref[p])),
            min(max(ref[p]),
                mean(ref[p]) + 4 * std(ref[p])))
        for p in params
    }
    if lims is not None:
        default_lims.update(lims)
    lims = default_lims
    if ticks is None:
        ticks = {}
    if isinstance(nticks, int):
        nticks = {p: nticks for p in params}

    n = len(params)
    for (i, p1) in enumerate(params):
        for (j, p2) in enumerate(params):
            if i > j:
                continue  # upper triangle is left empty
            ax = fig.add_subplot(n, n, j * n + i + 1)
            ax.xaxis.set_major_locator(MaxNLocator(nticks.get(p1, 5)))
            ax.yaxis.set_major_locator(MaxNLocator(nticks.get(p2, 5)))
            # Explicit x ticks were previously ignored (only y ticks were
            # applied). Set them before the limits, because set_xticks may
            # widen the view to include all ticks.
            if p1 in ticks:
                ax.set_xticks(ticks[p1])
            ax.set_xlim(*lims[p1])
            if i == j:
                for (ch, col, nbins) in zip(chains, colors, nbins1d):
                    if p1 in ch:
                        ch.like1d(p1, nbins=nbins, color=col, ax=ax)
                ax.set_yticks([])
            else:
                for (ch, col, nbins) in zip(chains, colors, nbins2d):
                    if p1 in ch and p2 in ch:
                        ch.like2d(p1,
                                  p2,
                                  filled=filled,
                                  nbins=nbins,
                                  color=col,
                                  ax=ax)
                if p2 in ticks:
                    ax.set_yticks(ticks[p2])
                ax.set_ylim(*lims[p2])

            if i == 0:
                ax.set_ylabel(param_name_mapping.get(p2, p2),
                              size=param_label_size)
            else:
                ax.set_yticklabels([])

            if j == n - 1:
                ax.set_xlabel(param_name_mapping.get(p1, p1),
                              size=param_label_size)
                # Rotate this axes' own labels: pyplot.xticks() acts on the
                # pyplot-current axes, which is not `ax` when `fig` is not the
                # current figure.
                ax.tick_params(axis='x', labelrotation=xtick_rotation)
            else:
                ax.set_xticklabels([])

    if labels is not None:
        fig.legend([Line2D([0], [0], c=col, lw=2) for col in colors],
                   labels,
                   fancybox=True,
                   shadow=False,
                   loc=legend_loc)
                                  alpha=.5)
labels = ["%iM €" % (s / 1e6) for s in reference_caps]
handler_map = make_handler_map_to_scale_circles_as_in(ax)
# fig.legend() already registers the legend with the figure (fig.legends).
# Adding it again via fig.add_artist / fig.artists.append would make the
# figure draw it twice.
legend = fig.legend(handles,
                    labels,
                    loc="upper left",
                    bbox_to_anchor=(1., 0.4),
                    frameon=False,
                    handler_map=handler_map)

# legend transmission capacities
reference_caps = [20, -10]
handles = [
    Line2D([0], [0], color=c, linewidth=abs(s) * 1e6 * branch_scale)
    for s, c in zip(reference_caps, ['cadetblue', 'indianred'])
]
labels = ['%iM €' % s for s in reference_caps]

legend = fig.legend(
    handles,
    labels,
    loc="lower left",
    bbox_to_anchor=(1, .1),
    frameon=False,
)

fig.canvas.draw()
fig.tight_layout()
示例#10
0
    # plt.plot(imus[0].timestamp, imus[0].euler_y, 'r:')
    # plt.plot(imus[0].timestamp, imus[0].euler_z, 'r--')
    # plt.plot(emg_1_timestamp, emg_1_values, 'm-')
    # plt.plot(emg_2_timestamp, emg_2_values, 'm:')
    # plt.plot(imu_2_z_up_timestamp, imu_2_z_up_values, 'r.', label='extension')
    # plt.plot(imu_2_z_zero_timestamp, imu_2_z_zero_values, 'g.', label='stop')
    # plt.plot(imu_2_z_low_timestamp, imu_2_z_low_values, 'b.', label='flexion')

    # plt.title(filename)
    plt.xlabel('Time [s]')
    plt.ylabel('Angle [rad] / Class')
    plt.ylim((-1.2, 1.2))
    plt.xlim((565, 585))
    # plt.legend()
    legend_elements = [
        Line2D([0], [0], color='k', label='Stimulation'),
        Line2D([0], [0], color='b', label='Flexion', marker='o'),
        Line2D([0], [0], color='g', label='Stop', marker='o'),
        Line2D([0], [0], color='r', label='Extension', marker='o'),
        # Line2D([0], [0], color='c', label='Prediction', marker='o'),
    ]
    plt.legend(handles=legend_elements)
    plt.savefig('graph.svg')
    plt.show()

if dash_plot:

    app_dash = dash.Dash()

    app_dash.layout = html.Div(children=[
        html.Label('Data to graph:'),
示例#11
0
# %%
# Event selection. Earlier candidates were each overwritten immediately:
#   every name in LogDf['name'].unique() except NaN
#   ['TRIAL_ENTRY_EVENT', 'FRAME_EVENT', 'FRAME_INF_EVENT']
# (The old `events.remove(np.nan)` raised ValueError unless the unique names
# contained the np.nan object itself, aborting the cell before plotting.)
events = np.array(['REACH_LEFT_ON', 'REACH_RIGHT_ON', 'REWARD_COLLECTED_EVENT'])
fig, axes = plt.subplots()
colors = dict(zip(events, sns.color_palette(palette='tab20', n_colors=events.shape[0])))
# Raster plot: row i holds one vertical tick per occurrence of events[i],
# at t / 1000 (presumably ms -> s; confirm the unit of LogDf['t']).
groups = LogDf.groupby('name')  # group once, not once per event
for i, event in enumerate(events):
    for t in groups.get_group(event)['t'] / 1000:
        axes.plot([t, t], [i, i + 1], color=colors[event])

from matplotlib.pyplot import Line2D
handles = [Line2D([], [], color=color) for color in colors.values()]
plt.legend(handles, list(colors), loc='upper left', bbox_to_anchor=(1.0, 1.0))
fig.tight_layout()
# %%
# Inspect event counts (only displayed when evaluated interactively).
LogDf.groupby('name').get_group('FRAME_EVENT').shape[0]
LogDf.groupby('name').get_group('FRAME_INF_EVENT').shape[0]
# -> works, assign index
nFrames = LogDf.groupby('name').get_group('FRAME_INF_EVENT').shape[0]
# Number of dFF frames not covered by FRAME_INF_EVENTs (assumes frames run
# along dFF's second axis - confirm); the events get frame indices starting there.
offset = dFF.shape[1] - nFrames
LogDf.loc[LogDf['name'] == 'FRAME_INF_EVENT', 'var'] = np.arange(offset, nFrames + offset)

# %% test slicing
w = (-3000, 3000)
event = 'REACH_RIGHT_ON'
times = LogDf.groupby('name').get_group(event)['t']
ix_ts = []
示例#12
0
            for j in range(ncols):
                plot_data(i, j, df_files_mwa[i-1, j], df_files_hera37[i-1, j],
                          df_files_hera331[i-1, j])

        # Axes labels
        fig.text(0.01, 0.5, ylabels[stat], rotation='vertical',
                 horizontalalignment='left', verticalalignment='center')
        fig.text(0.5, 0.01, 'Frequency [MHz]', horizontalalignment='center',
                 verticalalignment='bottom')
        fig.text(0.5, 0.99, 'Ionized Fraction', horizontalalignment='center',
                 verticalalignment='top')

        # Legend
        # Legend parameters
        handlers = [
            Line2D([], [], linestyle=':', color='black', linewidth=1),
            Line2D([], [], linestyle='--', color='black', linewidth=1),
            Line2D([], [], linestyle='-', color='black', linewidth=1),
            Patch(color='0.85'),
            Patch(color='0.7'),
            Patch(color='0.55')
        ]
        labels = ['MWA Phase I Core', 'HERA37', 'HERA331',
                  'MWA Phase I Core Error', 'HERA37 Error', 'HERA331 Error']

        plt.figlegend(handles=handlers, labels=labels, loc=(0.6, 0.8),
                      ncol=1, fontsize='medium')

        # Tidy up
        gs.tight_layout(fig, rect=[0.02, 0.02, 0.99, 0.98])
        gs.update(wspace=0, hspace=0)
           ncol=2)

# legend generator capacities
reference_caps = [500e6, 100e6]
scale = 1 / bus_scale / projected_area_factor(ax)**2
handles = make_legend_circles_for(reference_caps,
                                  scale=scale,
                                  facecolor="w",
                                  edgecolor='grey',
                                  alpha=.5)
labels = ["%iM €" % (s / 1e6) for s in reference_caps]

# append legend transmission capacities
reference_caps = [20, 10]
handles += [
    Line2D([0], [0], color='cadetblue', linewidth=abs(s) * 1e6 * branch_scale)
    for s in reference_caps
]
labels += ['%iM €' % s for s in reference_caps]

handler_map = make_handler_map_to_scale_circles_as_in(ax)
# fig.legend() already registers the legend with the figure (fig.legends);
# a further fig.add_artist(legend) would make the figure draw it twice.
legend = fig.legend(handles,
                    labels,
                    loc="lower left",
                    bbox_to_anchor=(0., 0.42),
                    title='Revenue',
                    ncol=2,
                    handler_map=handler_map)

fig.canvas.draw()
示例#14
0
                                  alpha=.5)
labels = ["%i GW" % (s / 1e3) for s in reference_caps]
handler_map = make_handler_map_to_scale_circles_as_in(axes[0])
# fig.legend() already registers each legend with the figure (fig.legends);
# adding it again via fig.add_artist / fig.artists.append would make the
# figure draw it twice.
legend = fig.legend(
    handles,
    labels,
    loc="upper left",
    bbox_to_anchor=(.5, 0),
    frameon=False,  # edgecolor='w',
    title='Generation Capacity',
    handler_map=handler_map)

# legend AC / DC
handles = [
    Line2D([0], [0], color=c, linewidth=5)
    for c in ['rosybrown', 'darkseagreen']
]
labels = ['AC', 'DC']

legend = fig.legend(handles,
                    labels,
                    loc="upper left",
                    bbox_to_anchor=(0.8, 0),
                    frameon=False,
                    title='Transmission Type')

# legend transmission capacities
handles, labels = [], []
reference_caps = [10, 5]
示例#15
0
def draw_network_with_reactions(network, omit=[], arrowsize=15, font_size='small',
                                arrowstyle='->', database_file='hanford.dat',
                                do_legend=True, node_colors=node_colors,
                                namechanges={}, font_color='k', node_alpha=0.8,
                                node_size=None, edge_color=None,
                                markers={'Reaction': '*'}, pos=None, width=None,
                                connectionstyle=None, **kwargs):
    """Draw a reaction network with every reaction shown as its own node.

    A copy of ``network`` is made. Each reaction stored on one of its edges,
    each surface complex, and each database reaction becomes a node linked
    from its reactants and to its products. The original edges of
    ``network`` are then removed, and the result is drawn with networkx.

    Args:
        network: networkx graph whose nodes carry a 'kind' attribute and whose
            edges carry 'name' and 'reaction' attributes.
        omit: node names or node kinds to leave out. (This and the other
            mutable defaults, namechanges and markers, are only read, never
            modified.)
        arrowsize, arrowstyle, connectionstyle, edge_color, width: passed to
            ``nx.draw_networkx_edges``.
        font_size, font_color: passed to ``nx.draw_networkx_labels``.
        database_file: passed as ``filename`` to ``get_reaction_from_database``.
        do_legend: add a legend with one entry per (renamed) node category.
        node_colors: mapping from node category (``categorize_nodes``) to color.
        namechanges: display names for nodes and legend categories.
        node_alpha, node_size: passed to ``nx.draw_networkx_nodes``.
        markers: marker for 'Reaction' nodes (default '*') and 'mineral' nodes
            (default 'o').
        pos: node positions; computed with graphviz 'dot' when None.
        **kwargs: forwarded to every networkx drawing call.

    Returns:
        (to_draw, pos): the graph that was drawn and the positions used.
    """
    to_draw = network.copy()

    for p in network.nodes:
        kind = network.nodes[p]['kind']
        if kind in omit or p in omit or p in ['HRimm', 'Tracer']:
            to_draw.remove_node(p)
        elif kind == 'surf_complex':
            for cplx in network.nodes[p]['complexes']:
                # to_draw=nx.compose(get_reaction_from_database(cplx,'surf_complex',filename=database_file),to_draw)
                site = cplx.strip('>')
                to_draw.add_edge(p.strip('>'), site)
                to_draw.nodes[site]['kind'] = 'Sorption Reaction'
                to_draw.add_edge(network.nodes[p]['mineral'], site)
                # Guarded: when p contains '>', the add_edge above does not
                # re-create p, so an unconditional remove_node(p) raised
                # NetworkXError on the second complex of the same site.
                if p in to_draw:
                    to_draw.remove_node(p)
        elif kind not in ['primary', 'immobile', 'implicit', 'sorbed']:
            to_draw = nx.compose(
                get_reaction_from_database(p, kind, filename=database_file),
                to_draw)

    for react in network.edges:
        edge = network.edges[react]
        if edge['name'] not in to_draw.nodes and edge['name'] not in omit:
            e = edge['reaction']
            for species in e['reactant_pools']:
                to_draw.add_edge(species, e['name'])
            for species in e['product_pools']:
                to_draw.add_edge(e['name'], species)
            to_draw.nodes[e['name']]['kind'] = (
                e['reactiontype'].capitalize().replace('Som', 'SOM') + ' Reaction')
            to_draw.nodes[e['name']]['reactiontype'] = e['reactiontype']

    # Get rid of nodes added from database reactions that we want removed
    # (nodes without a 'kind' attribute).
    to_draw.remove_nodes_from(
        [node for node, node_kind in to_draw.nodes(data='kind') if node_kind is None])
    # Get rid of all the original reactions
    to_draw.remove_edges_from(network.edges)

    if pos is None:
        pos = nx.drawing.nx_agraph.graphviz_layout(to_draw, prog='dot')
    from numpy import array
    nodecats = array(categorize_nodes(to_draw))
    nodecolors = array([node_colors[nodecat] for nodecat in nodecats])
    # Node names in to_draw.nodes order, aligned with nodecats / nodecolors
    # and the boolean masks below.
    node_names = array(to_draw.nodes())

    # Non-reactions:
    nonreactions = array(['Reaction' not in to_draw.nodes[n]['kind'] for n in to_draw.nodes])
    minerals = array(['mineral' in to_draw.nodes[n]['kind'] for n in to_draw.nodes])
    plain = nonreactions & ~minerals
    nx.draw_networkx_nodes(to_draw, pos=pos, nodelist=node_names[plain].tolist(),
                           node_color=nodecolors[plain], node_size=node_size,
                           node_shape='o', alpha=node_alpha, **kwargs)
    nx.draw_networkx_nodes(to_draw, pos=pos, nodelist=node_names[minerals].tolist(),
                           node_color=nodecolors[minerals],
                           node_shape=markers.get('mineral', 'o'),
                           alpha=node_alpha, node_size=node_size, **kwargs)

    nx.draw_networkx_labels(to_draw, pos=pos,
                            labels={n: namechanges.get(n, n) for n in to_draw.nodes},
                            font_size=font_size, font_color=font_color, **kwargs)

    reactionnodes = node_names[~nonreactions].tolist()
    nx.draw_networkx_nodes(to_draw, pos=pos, nodelist=reactionnodes,
                           node_shape=markers.get('Reaction', '*'),
                           node_color=nodecolors[~nonreactions], alpha=node_alpha,
                           node_size=node_size, **kwargs)

    nx.draw_networkx_edges(to_draw, pos=pos, connectionstyle=connectionstyle,
                           arrowsize=arrowsize, arrowstyle=arrowstyle,
                           edge_color=edge_color, width=width, **kwargs)

    if do_legend:
        from matplotlib.pyplot import legend, Line2D
        legend_handles = []
        legend_labels = []

        # One legend entry per distinct (renamed) category, using the color
        # of the first node in that category.
        for num, node in enumerate(to_draw.nodes):
            label = namechanges.get(nodecats[num], nodecats[num])
            if label in legend_labels:
                continue
            legend_labels.append(label)
            node_kind = to_draw.nodes[node]['kind']
            if 'Reaction' in node_kind:
                marker = markers.get('Reaction', '*')
            elif 'mineral' in node_kind:
                marker = markers.get('mineral', 'o')
            else:
                marker = 'o'
            legend_handles.append(Line2D([0], [0], ls='None', marker=marker,
                                         ms=15.0, color=nodecolors[num]))

        legend(handles=legend_handles, labels=legend_labels, fontsize='large',
               title='Component types', title_fontsize='large',
               labelspacing=1.0, ncol=2)

    return to_draw, pos
示例#16
0
                                  alpha=.5)
labels = ["%iM €" % (s / 1e6) for s in reference_caps]
handler_map = make_handler_map_to_scale_circles_as_in(ax)
legend = ax.legend(handles,
                   labels,
                   loc="lower center",
                   bbox_to_anchor=(0.3, 1),
                   frameon=False,
                   handler_map=handler_map)
# The next ax.legend() call replaces the axes' legend; registering this one as
# an artist keeps it visible.
ax.add_artist(legend)

# legend transmission capacities
reference_caps = [100e6, 50e6]
handles = [
    Line2D([0], [0], color='grey', linewidth=s * branch_scale / branch_sum)
    for s in reference_caps
]
labels = ['%iM €' % (s / 1e6) for s in reference_caps]

legend = ax.legend(
    handles,
    labels,
    loc="lower center",
    bbox_to_anchor=(.7, 1),
    frameon=False,
)
# Axes.artists is a read-only view since Matplotlib 3.5 (list mutation such as
# append was deprecated and later removed); add_artist is the supported API.
ax.add_artist(legend)

fig.canvas.draw()
fig.tight_layout()