示例#1
0
def splot_data(data, mdata, z, label1, label2, sz, grid_size=100):
    """Draw a colored mesh with the data points scattered on top of it.

    Args:
        data: 2-D array; columns 0 and 1 are the point coordinates and
            column 2 supplies the value used to color each point.
        mdata: 2-D array of mesh points; columns 0 and 1 are each reshaped
            to ``(grid_size, grid_size)``.
        z: array of mesh values, reshaped to ``(grid_size, grid_size)``.
        label1: x-axis label.
        label2: y-axis label.
        sz: marker size handed to ``plt.scatter``.
        grid_size: side length of the square mesh.
    """
    # Low-saturation colors for the mesh, light saturated ones for the points.
    background_cmap = ListedColormap(sns.hls_palette(3, l=.4, s=.1))
    points_cmap = ListedColormap(sns.hls_palette(3, l=.9, s=.9))

    sns.set(style="white")
    sns.set(style="ticks", font_scale=2.0)
    _, ax = plt.subplots(figsize=(10, 6))
    ax.set_aspect('equal')
    ax.set_xlabel(label1)
    ax.set_ylabel(label2)

    # pcolormesh needs the coordinates and values as 2-D grids.
    mesh_shape = (grid_size, grid_size)
    mesh_x = mdata[:, 0].reshape(mesh_shape)
    mesh_y = mdata[:, 1].reshape(mesh_shape)
    mesh_z = z.reshape(mesh_shape)
    plt.pcolormesh(mesh_x, mesh_y, mesh_z, cmap=background_cmap)

    # Points go on top of the mesh.
    plt.scatter(data[:, 0], data[:, 1], c=data[:, 2], s=sz, cmap=points_cmap)

    sns.despine(offset=0.25, trim=True)
示例#2
0
文件: _stats.py 项目: briney/abtools
def get_germline_plot_colors(data, l):
    """Return one HLS color per entry of ``data``, shared within a family.

    The family of an entry is the text before its first ``'-'`` with any
    ``'/OR'`` removed. Families are sorted and each gets one color from an
    evenly spaced HLS palette.

    Args:
        data: iterable of strings.
        l: unused by this function; kept for interface compatibility.

    Returns:
        List of colors, parallel to ``data``.
    """
    families = [name.split('-')[0].replace('/OR', '') for name in data]
    unique_families = sorted(set(families))
    palette = sns.hls_palette(len(unique_families))
    color_by_family = dict(zip(unique_families, palette))
    return [color_by_family[family] for family in families]
示例#3
0
def main(y_range=None):
    """Run ``handle_data`` on every ``*.mat`` file in the current directory.

    Args:
        y_range: two-element y-axis range forwarded to ``handle_data``.
            Defaults to ``[-2, 2]``. A fresh list is built on each call so
            that a mutation by ``handle_data`` cannot leak into later calls
            (the previous mutable default was shared between calls).

    Returns:
        Always ``True``.
    """
    if y_range is None:
        y_range = [-2, 2]
    files = glob('*.mat')
    # One 9-color palette shared by all files.
    colors = sns.hls_palette(9, l=0.4, s=0.85)
    for f in files:
        handle_data(f, colors, y_range)

    return True
def visualize_hist_pairplot(X,y,selected_feature1,selected_feature2,features,diag_kind):
	"""
	Plot the pairwise relationship between two features, colored by class, and save it.

	Keyword arguments:
	X -- The feature vectors
	y -- The target vector
	selected_feature1 -- First feature (column name)
	selected_feature2 -- Second feature (column name)
	features -- Column names for X stacked with y; the plot uses the column named "Y" as hue
	diag_kind -- Type of plot on the diagonal (histogram or density function)
	"""
	# Stack features and target into one frame so seaborn can split by class.
	frame = pd.DataFrame(data=np.column_stack((X, y)), columns=features)

	# Class 0 gets the third default HLS color, class 1 the first.
	hls_colors = sea.hls_palette()
	class_colors = {0: hls_colors[2], 1: hls_colors[0]}

	grid = sea.pairplot(
		frame,
		hue="Y",
		palette=class_colors,
		vars=[selected_feature1, selected_feature2],
		diag_kind=diag_kind,
	)
	grid.fig.suptitle('Pairwise relationship: '+selected_feature1+" vs "+selected_feature2)
	grid.set(xticklabels=[])

	# Save the figure under the img directory.
	output_dir = "img"
	target = '{}/{}_{}_hist_pairplot.png'.format(output_dir,selected_feature1,selected_feature2)
	save_fig(output_dir, target)
示例#5
0
def get_kwargs(grid, kwargs, phenotypes=False):
    """
    Work out which denominator and palette to use for plotting ``grid``.

    Resolution order:
      1. An explicit ``kwargs["palette"]``; the denominator defaults to its
         length.
      2. An environment (``kwargs["environment"]``, or ``grid`` itself when
         it is an EnvironmentFile); its task palette is used when
         ``phenotypes`` is true, otherwise its resource palette.
      3. Otherwise an HLS palette sized from the grid; in this case the
         denominator is always the palette length.

    An explicit ``kwargs["denom"]`` wins in cases 1 and 2.

    Returns a ``(denom, palette)`` tuple.
    """
    denom = kwargs.get("denom")

    if "palette" in kwargs:
        palette = kwargs["palette"]
        if denom is None:
            denom = len(palette)
        return denom, palette

    if "environment" in kwargs or isinstance(grid, EnvironmentFile):
        env = kwargs["environment"] if "environment" in kwargs else grid
        if phenotypes:
            palette = env.task_palette
            counted_attr = "tasks"
        else:
            palette = env.resource_palette
            counted_attr = "resources"
        # Only look at the task/resource list when no denom was supplied.
        if denom is None:
            denom = len(getattr(env, counted_attr))
        return denom, palette

    length = get_pallete_length(grid)
    return length, sns.hls_palette(length, s=1)
def visualize_task_factors(task_loadings, ax, xticklabels=True, label_size=12,
                           yticklabels=False, pad=0, ymax=None, legend=True):
    """Plot task loadings on one axis

    Each row of ``task_loadings`` is drawn as one closed line via
    ``plot_loadings`` (absolute values, shifted by ``pad``), each in its own
    HLS color.

    Args:
        task_loadings: table-like object providing ``iterrows()`` and
            ``columns`` (presumably a pandas DataFrame -- confirm with caller).
        ax: axis to draw on; tick positions are spaced around 0..2*pi, so a
            polar axis is presumably expected -- TODO confirm.
        xticklabels: if true, draw the column names as curved text on a
            second, slightly larger axes placed over ``ax``.
        label_size: font size for the curved labels.
        yticklabels: if true, label the y ticks (every other label blanked).
        pad: constant added to every absolute loading before plotting.
        ymax: optional upper y limit.
        legend: if true, add a legend below the axis.
    """
    n_measures = len(task_loadings)
    colors = sns.hls_palette(len(task_loadings), l=.4, s=.8)
    for i, (name, DV) in enumerate(task_loadings.iterrows()):
        plot_loadings(ax, abs(DV)+pad, width_scale=1/(n_measures), 
                      colors = [colors[i]], offset=i+.5,
                      kind='line',
                      plot_kws={'label': name, 'alpha': .8})
    # set up yticks
    if ymax:
        ax.set_ylim(top=ymax)
    ytick_locs = ax.yaxis.get_ticklocs()
    # replace the existing ticks by 7 evenly spaced ones up to the last tick
    new_yticks = np.linspace(0, ytick_locs[-1], 7)
    ax.set_yticks(new_yticks)
    if yticklabels:
        labels = np.round(new_yticks,2)
        # blank out every other label, starting with the first
        replace_dict = {i:'' for i in labels[::2]}
        labels = [replace_dict.get(i, i) for i in labels]
        ax.set_yticklabels(labels)
    # set up x ticks
    # NOTE(review): DV is the last row left over from the loop above, so this
    # raises NameError when task_loadings has no rows.
    xtick_locs = np.arange(0.0, 2*np.pi, 2*np.pi/len(DV))
    ax.set_xticks(xtick_locs)
    # minor ticks sit halfway between the major ones
    ax.set_xticks(xtick_locs+np.pi/len(DV), minor=True)
    if xticklabels:
        labels = task_loadings.columns
        if type(labels[0]) != str:
            labels = ['Fac %s' % str(i) for i in labels]
        # overlay axes, 1.2x the size of ax, that carries the curved labels
        scale = 1.2
        size = ax.get_position().expanded(scale, scale)
        ax2=ax.get_figure().add_axes(size,zorder=2)
        max_var_length = max([len(v) for v in labels])
        for i, var in enumerate(labels):
            # arc of the unit circle for label i, shifted back by `offset`
            offset=.3*25/len(labels)**2
            start = (i-offset)*2*np.pi/len(labels)
            end = (i+(1-offset))*2*np.pi/len(labels)
            curve = [
                np.cos(np.linspace(start,end,100)),
                np.sin(np.linspace(start,end,100))
            ]  
            # fully transparent line; presumably only there to set the data
            # limits of the current axes -- TODO confirm
            plt.plot(*curve, alpha=0)
            # pad strings to longest length
            num_spaces = (max_var_length-len(var))
            var = ' '*(num_spaces//2) + var + ' '*(num_spaces-num_spaces//2)
            # the arc is passed reversed (end -> start)
            curvetext = CurvedText(
                x = curve[0][::-1],
                y = curve[1][::-1],
                text=var, #'this this is a very, very long text',
                va = 'top',
                axes = ax2,
                fontsize=label_size##calls ax.add_artist in __init__
            )
            ax2.axis('off')
    if legend:
        leg = ax.legend(loc='upper center', bbox_to_anchor=(.5,-.15), frameon=False)
        beautify_legend(leg, colors[:len(task_loadings)])
示例#7
0
def parse_timepoints(tps, args):
    """Build the list of ``Timepoint`` objects described by ``args.timepoints``.

    Three input forms are supported:

    * a list/tuple of argument tuples -- each is expanded into ``Timepoint``;
    * a path to a tab-separated file with ``name<TAB>order<TAB>color`` lines --
      only names present in ``tps`` are kept, blank lines are skipped;
    * ``None`` -- timepoints are generated automatically: ``'root'`` first
      (order 0), then the remaining names of ``tps`` in sorted order, each
      with its own color from an HLS palette.

    Args:
        tps: collection of timepoint names.
        args: object with a ``timepoints`` attribute (see above).

    Returns:
        List of ``Timepoint`` objects.
    """
    if isinstance(args.timepoints, (list, tuple)):
        return [Timepoint(*t) for t in args.timepoints]

    timepoints = []
    if args.timepoints is not None:
        with open(args.timepoints, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # tolerate blank lines (e.g. a trailing newline)
                    continue
                name, order, color = line.split('\t')
                if name in tps:
                    timepoints.append(Timepoint(name, order, color))
        return timepoints

    # Automatic mode. Size the palette from the names actually emitted
    # (root + every non-root name) so the lookup cannot run off the end
    # when 'root' is not itself listed in tps.
    others = sorted(t for t in tps if t != 'root')
    colors = sns.hls_palette(len(others) + 1, l=0.5, s=0.9)
    timepoints.append(Timepoint('root', 0, colors[0]))
    for order, tp in enumerate(others, start=1):
        timepoints.append(Timepoint(tp, order, colors[order]))
    return timepoints
def plot_loadings(ax, component_loadings, groups=None, colors=None, 
                  width_scale=1, offset=0, kind='bar', plot_kws=None):
    """Plot component loadings around a full 0..2*pi circle.

    Args:
        ax: axis to plot on. Tick positions are spread over 0..2*pi, so a
            polar axis is the natural target.
        component_loadings (array or pandas Series): loadings to plot
        groups (list, optional): ordered list of tuples of the form:
            [(group_name, list of group members), ...]. If not supplied, all
            elements will be treated as one group
        colors (list, optional): if supplied, specifies the colors for the
            groups; must have one entry per group. Defaults to an HLS palette.
        width_scale (float): scale of bars. Default is 1, which fills the entire
            plot
        offset (float): offset as a proportion of width. Used to plot multiple
            columns side by side under one factor
        kind (str): 'bar' draws one bar per loading, colored by group;
            'line' draws a single closed line in ``colors[0]``.
        plot_kws (dict, optional): keywords passed on to ``ax.bar`` /
            ``ax.plot``. The dict is copied, never modified. For 'line' a
            default ``linewidth`` of 5 is added when none is given.

    Returns:
        The list of colors that was used.
    """
    # Work on a copy so the 'linewidth' default added below never leaks
    # into the dict owned by the caller.
    plot_kws = dict(plot_kws) if plot_kws is not None else {}
    N = len(component_loadings)
    if groups is None:
        groups = [('all', [0]*N)]
    if colors is not None:
        assert(len(colors) == len(groups))
    else:
        colors = sns.hls_palette(len(groups), l=.5, s=.8)
    ax.set_xticklabels([''])
    ax.set_yticklabels([''])

    width = np.pi/(N/2)*width_scale*np.ones(N)
    theta = np.array([2*np.pi/N*i for i in range(N)]) + width[0]*offset
    radii = component_loadings
    if kind == 'bar':
        bars = ax.bar(theta, radii, width=width, bottom=0.0, **plot_kws)
        # Cumulative group sizes: bar i belongs to the first group whose
        # cumulative size reaches i + 1.
        group_ends = np.cumsum([len(g[1]) for g in groups])
        for i, bar in zip(range(N), bars):
            color_index = int((group_ends < i + 1).sum())
            bar.set_facecolor(colors[color_index])
    elif kind == 'line':
        plot_kws.setdefault('linewidth', 5)
        # Repeat the first point at the end to close the curve.
        theta = np.append(theta, theta[0])
        radii = np.append(radii, radii[0])
        ax.plot(theta, radii, color=colors[0], **plot_kws)
    return colors
def plot_cluster_factors(results, c, rotate='oblimin',  ext='png', plot_dir=None):
    """Draw one polar line plot of factor loadings per HCA cluster.

    Args:
        results: object exposing ``HCA`` and ``EFA`` attributes.
        c: number of components for EFA; used here only in the output
            filename.
        rotate: rotation name, forwarded to ``HCA.get_cluster_loading`` and
            used in the cluster lookup key and the filename.
        ext: file extension of the saved figure.
        plot_dir: if given, the figure is saved there and then closed.
    """
    # set up variables
    HCA = results.HCA
    EFA = results.EFA
    
    names, cluster_loadings = zip(*HCA.get_cluster_loading(EFA, rotate=rotate).items())
    cluster_DVs = HCA.get_cluster_DVs(inp='EFA%s_%s' % (EFA.get_c(), rotate))
    cluster_loadings = list(zip([cluster_DVs[n] for n in names], cluster_loadings))
    max_loading = max([max(abs(i[1])) for i in cluster_loadings])
    # plot
    colors = sns.hls_palette(len(cluster_loadings))
    ncols = min(5, len(cluster_loadings))
    nrows = ceil(len(cluster_loadings)/ncols)
    f, axes = plt.subplots(nrows, ncols, 
                               figsize=(ncols*10,nrows*(8+nrows)),
                               subplot_kw={'projection': 'polar'})
    axes = f.get_axes()
    for i, (measures, loading) in enumerate(cluster_loadings):
        # Hand the cluster color over through `colors`: plot_loadings already
        # passes color= to ax.plot, so also sending the alias 'c' inside
        # plot_kws supplies the same property twice.
        plot_loadings(axes[i], loading, kind='line', offset=.5,
              colors=[colors[i]],
              plot_kws={'alpha': .8})
        axes[i].set_title('Cluster %s' % i, y=1.14, fontsize=25)
        # set tick labels
        xtick_locs = np.arange(0.0, 2*np.pi, 2*np.pi/len(loading))
        axes[i].set_xticks(xtick_locs)
        axes[i].set_xticks(xtick_locs+np.pi/len(loading), minor=True)
        if i%(ncols*2)==0 or i%(ncols*2)==(ncols-1):
            axes[i].set_xticklabels(loading.index,  y=.08, minor=True)
            # set ylim
            # NOTE(review): the shared y limit is applied only to the
            # labelled axes -- confirm this is intended.
            axes[i].set_ylim(top=max_loading)
    # hide the unused axes of the grid
    for j in range(i+1, len(axes)):
        axes[j].set_visible(False)
    plt.subplots_adjust(hspace=.5, wspace=.5)
    
    filename = 'polar_factors_EFA%s_%s.%s' % (c, rotate, ext)
    if plot_dir is not None:
        save_figure(f, path.join(plot_dir, filename),
                    {'bbox_inches': 'tight'})
        plt.close()
示例#10
0
文件: _stats.py 项目: briney/abtools
def cdr3_plot(seqs, collection, make_plot, chain, output_dir):
    """Write a bar plot of CDR3 length frequencies to a PDF.

    Does nothing and returns ``None`` when ``make_plot`` is falsy. Lengths
    outside ``1..max_len`` are ignored, where ``max_len`` is 40 for the
    heavy chain and 20 otherwise. The file is written to
    ``<output_dir>/<collection>_<chain>_cdr3_lengths.pdf``.
    """
    if not make_plot:
        return None

    is_heavy = chain == 'heavy'
    max_len = 40 if is_heavy else 20
    lengths = [s['cdr3_len'] for s in seqs if 0 < s['cdr3_len'] <= max_len]
    x, y = aggregate(lengths, keys=range(1, max_len + 1))

    bar_color = sns.hls_palette(7)[4]
    file_name = '{}_{}_cdr3_lengths.pdf'.format(collection, chain)
    plot_file = os.path.join(output_dir, file_name)
    figure_size = (9, 4) if is_heavy else (6, 4)
    tick_labels = [str(i) for i in x]
    make_barplot(tick_labels, y,
                 bar_color,
                 plot_file,
                 xlabel='CDR3 Length (AA)',
                 ylabel='Frequency (%)',
                 grid=True,
                 size=figure_size,
                 xfontsize=7)
示例#11
0
文件: stats.py 项目: briney/abtools
def cdr3_length_plot(seqs, fig_file=None, max_len=40, chain='heavy', color=None):
    """Bar plot of CDR3 length frequencies for one chain type.

    Args:
        seqs: iterable of mappings with ``'chain'`` and ``'cdr3_len'`` keys.
        fig_file: forwarded to ``_make_barplot`` as ``fig_file``.
        max_len: largest CDR3 length included; also sets the figure width.
        chain: chain to keep; ``'light'`` selects both kappa and lambda.
        color: bar color; defaults to the fifth color of a 7-color HLS
            palette.
    """
    wanted_chains = ['kappa', 'lambda'] if chain == 'light' else [chain]
    selected = [s for s in seqs if s['chain'] in wanted_chains]
    lengths = [s['cdr3_len'] for s in selected if 0 < s['cdr3_len'] <= max_len]
    x, y = _aggregate(lengths, keys=range(1, max_len + 1))

    if color is None:
        color = sns.hls_palette(7)[4]

    # The figure gets wider as more lengths are shown.
    figure_size = ((max_len / 5.0) + 1, 4)
    tick_labels = [str(i) for i in x]
    _make_barplot(tick_labels, y,
                 color,
                 fig_file=fig_file,
                 xlabel='CDR3 Length (AA)',
                 ylabel='Frequency (%)',
                 grid=True,
                 size=figure_size,
                 xfontsize=7)
示例#12
0
def plot():
    """Render the IMDB-vs-Douban rating plot for the submitted genre.

    Reads ``genre`` from the request form, draws a scatter plot with
    marginal histograms, converts the figure to HTML with mpld3 and renders
    ``plot_page.html`` with it.
    """
    # read the genre submitted from request
    app.genre = request.form['genre']

    # Plot the graph
    fig = plt.figure(figsize=(6, 4))
    plt.style.use('seaborn-darkgrid')
    clr = sns.hls_palette(1, l=.7, s=.6)  # color for each plot
    gspec = gridspec.GridSpec(4, 4)

    top_hist = plt.subplot(gspec[0, 1:])
    side_hist = plt.subplot(gspec[1:, 0])
    lower_right = plt.subplot(gspec[1:, 1:])

    # Variables
    genre = app.genre
    ind = app.data[genre]
    IMDB_rating = app.data[ind]['IMDBrating']
    DB_rating = app.data[ind]['DBrating']

    lower_right.scatter(IMDB_rating, DB_rating, color=clr, s=4)
    top_hist.hist(IMDB_rating, bins=31, color=clr)
    side_hist.hist(DB_rating, bins=31, orientation='horizontal', color=clr)
    side_hist.invert_xaxis()
    # y = x reference line
    identity = np.arange(0, 11)
    lower_right.plot(identity, identity, color='black', linewidth=1)

    for ax in [top_hist, lower_right]:
        ax.set_xlim(1, 10)
    for ax in [side_hist, lower_right]:
        ax.set_ylim(1, 10)
    lower_right.set_xlabel('IMDB Rating(' + genre + ')', fontsize='large')
    side_hist.set_ylabel('Douban Rating(' + genre + ')', fontsize='large')

    plt.tight_layout()

    # Convert plot to html script using mpld3
    app.plot = mpld3.fig_to_html(fig)
    # pyplot keeps every figure alive until it is closed; without this each
    # request would leave one more figure in memory.
    plt.close(fig)
    return render_template('plot_page.html',
                           fig_display=app.plot)  #, fig1_display=fig1)
示例#13
0
def draw_graph(graph, hard_assign, args):
    """Draw a molecule with atoms and bonds colored by cluster assignment.

    Args:
        graph: graph whose ``graph['smiles']`` holds the SMILES string and
            whose ``nodes`` give the number of atoms.
        hard_assign: integer array indexed by atom; one cluster id per atom
            (it is used to index a palette of ``hard_assign.max() + 1``
            colors).
        args: object with a boolean ``svg`` attribute selecting the
            SVG drawer instead of the Cairo one.

    Returns:
        The SVG text (with ``'svg:'`` prefixes removed) when ``args.svg`` is
        set, otherwise the rendered image as a numpy array.
    """
    smiles = graph.graph['smiles']
    molecule = Chem.MolFromSmiles(smiles)
    assert molecule is not None
    rdDepictor.Compute2DCoords(molecule)

    # one color per cluster id
    palette = np.array(sns.hls_palette(hard_assign.max() + 1))

    atom_index = list(range(len(graph.nodes)))
    # Force an (n_bonds, 2) integer array so that a molecule without bonds
    # gives an empty (0, 2) array instead of a 1-D one that cannot be
    # column-indexed.
    undirected_edges = np.array(
        [(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in molecule.GetBonds()],
        dtype=int,
    ).reshape(-1, 2)
    # bonds whose two atoms fall in the same cluster are highlighted
    non_cut_edges_indices = np.nonzero(hard_assign[undirected_edges[:, 0]] == hard_assign[undirected_edges[:, 1]])[0]
    bond_colors = palette[hard_assign[undirected_edges[non_cut_edges_indices][:, 0]]]
    bond_colors = list(map(tuple, bond_colors))
    atom_colors = list(map(tuple, palette[hard_assign]))

    atom_id_to_color_dict = dict(zip(atom_index, atom_colors))
    non_edge_idx_to_color_dict = dict(zip(non_cut_edges_indices.tolist(), bond_colors))

    if args.svg:
        drawer = rdMolDraw2D.MolDraw2DSVG(1200, 600)
    else:
        drawer = rdMolDraw2D.MolDraw2DCairo(1200, 600)
    drawer.DrawMolecule(
        molecule,
        highlightAtoms=atom_index,
        highlightBonds=non_cut_edges_indices.tolist(),
        highlightAtomColors=atom_id_to_color_dict,
        highlightBondColors=non_edge_idx_to_color_dict,
        highlightAtomRadii=dict(zip(atom_index, [0.1] * len(atom_index)))
    )
    drawer.FinishDrawing()
    if args.svg:
        img = drawer.GetDrawingText().replace('svg:','')
    else:
        # the non-SVG drawer's output is opened as an image from bytes
        txt = drawer.GetDrawingText()
        img = Image.open(io.BytesIO(txt))
        img = np.asarray(img)

    return img
示例#14
0
def features_correlation_analysis(data_excel):
    """Save joint plots of each composition feature against 'Working Range'.

    Reads the 'features' and 'cutoff' sheets of ``data_excel``. The working
    range is the last 'cutoff' column minus the second-to-last one, and is
    inserted before the last three 'features' columns. Those last three
    columns are used as 0/1 flags: one subset of rows is built per flag,
    and for each subset one PNG per x-variable is written to the results
    directory.
    """
    write_dir = create_results_directory(
        './results/cross_correlation_analysis')
    try:
        del mpl.font_manager.weight_dict['roman']
        mpl.font_manager._rebuild()
    except KeyError:
        pass
    sns.set(style='ticks')
    mpl.rc('font', family='Times New Roman')

    features = pd.read_excel(data_excel, index_col=0, sheet_name='features')
    cutoffs = pd.read_excel(data_excel, index_col=0, sheet_name='cutoff')
    features.insert(
        loc=features.shape[-1] - 3,
        column='Working Range',
        value=cutoffs.iloc[:, -1].values - cutoffs.iloc[:, -2].values)

    # One subset per flag column, with the three flag columns dropped.
    subsets = [
        features[features.iloc[:, flag_col] == 1].iloc[:, :-3]
        for flag_col in (-3, -2, -1)
    ]

    x_columns = (
        'CNT Mass Percentage', 'PVA Mass Percentage', 'Thickness nm',
        'Mxene Mass Percentage',
    )
    palette = sns.hls_palette(4, l=.3, s=.8)

    for dimension, subset in enumerate(subsets):
        # The third component is whatever the first two columns leave over.
        subset['Mxene Mass Percentage'] = 1 - subset.iloc[:, 0] - subset.iloc[:, 1]
        for x_name, color in zip(x_columns, palette):
            plt.close()
            sns.jointplot(x=x_name,
                          y='Working Range',
                          data=subset,
                          alpha=0.3,
                          color=color,
                          stat_func=stat.pearsonr)
            out_file = '{}/{}_dim_{}.png'.format(write_dir, x_name, dimension)
            plt.savefig(out_file, bbox_inches='tight')
示例#15
0
	def create_hls_palette(self, name, *args, **kwargs):
		"""Create an HLS palette, pickle it to disk and load it under ``name``.

		Exits the process with status 1 if a palette called ``name`` is
		already registered. ``palette_dir`` may be given as a keyword
		argument; it defaults to the instance's default palette directory,
		and an empty string also falls back to that default. All other
		arguments are forwarded to ``seaborn.hls_palette``.
		"""
		if name in self._palettes:
			print("[SeabornColors::create_palette] ERROR : Palette " + name + " already exists")
			sys.exit(1)
		palette_dir = kwargs.pop("palette_dir", self._default_palette_dir)
		from seaborn import hls_palette
		if palette_dir == "":
			palette_path = self._default_palette_dir + "/" + name + ".pkl"
		else:
			palette_path = palette_dir + "/" + name + ".pkl"
		print("\nColor palette " + name)
		print("args")
		print(args)
		print("kwargs")
		print(kwargs)
		this_palette_list = list(hls_palette(*args, **kwargs))
		print(this_palette_list)
		print("Saving to " + palette_path)
		# Pickle data is binary; the with-block also guarantees the file is
		# flushed and closed before load_palette reads it back.
		with open(palette_path, 'wb') as palette_file:
			pickle.dump(this_palette_list, palette_file)

		self.load_palette(name, palette_dir=palette_dir)
示例#16
0
def discrete_colors(
    segs,
    palette=DEFAULT_PALETTE,
    h=DEFAULT_H,
    l=DEFAULT_L,
    s=DEFAULT_S,
):
    """Assign one color per segmentation from a palette generator.

    Parameters
    ----------
    segs : list or dict
        Segmentations to provide colors for; only the length (and, for a
        dict, the keys) are used.
    palette : 'husl', 'hls', or str, optional
        Palette system to use, by default ``DEFAULT_PALETTE``. Any other
        value is passed to the seaborn ``color_palette`` function.
    h : float, optional
        Hue for the husl/hls palettes, by default ``DEFAULT_H``.
    l : float, optional
        Lightness for the husl/hls palettes, by default ``DEFAULT_L``.
    s : float, optional
        Saturation for the husl/hls palettes, by default ``DEFAULT_S``.

    Returns
    -------
    list or dict
        The colors as returned by the palette function when ``segs`` is not
        a dict; a dict mapping each key of ``segs`` to its color otherwise.
    """
    n_colors = len(segs)
    hsl_generators = {'husl': husl_palette, 'hls': hls_palette}

    if palette in hsl_generators:
        colors = hsl_generators[palette](n_colors, h=h, s=s, l=l)
    else:
        colors = color_palette(n_colors=n_colors, palette=palette)

    if not isinstance(segs, dict):
        return colors
    return dict(zip(segs.keys(), colors))
示例#17
0
def labs_to_cmap(labels,
                 return_lut=False,
                 shuffle_colors=False,
                 random_state=None):
    """Build a discrete colormap with one HLS color per unique label.

    Args:
        labels: sequence of labels; validated with
            ``mtype.check_is_valid_labs``.
        return_lut: if true, also return the label index array and the
            lookup tables.
        shuffle_colors: if true, shuffle the color order.
        random_state: seed passed to ``np.random.seed``. Note that the
            global numpy RNG is re-seeded on every call.

    Returns:
        ``(cmap, norm)``, or
        ``(cmap, norm, label_indices, label_to_color, index_to_label)``
        when ``return_lut`` is true.
    """
    np.random.seed(random_state)
    mtype.check_is_valid_labs(labels)

    labels = np.array(labels)
    unique_labels = np.unique(labels)
    n_unique = len(unique_labels)

    palette = list(sns.hls_palette(n_unique))
    if shuffle_colors:
        np.random.shuffle(palette)
    cmap = mpl.colors.ListedColormap(palette)

    # Lookup tables in both directions; the index order is the sorted order
    # of the unique labels, which a labelled legend needs.
    index_to_label = dict(enumerate(unique_labels))
    label_to_index = {lab: ind for ind, lab in index_to_label.items()}
    label_indices = np.array([label_to_index[lab] for lab in labels])

    # unique label -> color, used to generate legends
    label_to_color = dict(
        zip((index_to_label[ind] for ind in range(n_unique)), palette))

    # The norm splits the colormap at the label indices.
    # https://matplotlib.org/tutorials/colors/colorbar_only.html
    boundaries = list(range(n_unique)) + [cmap.N]
    norm = mpl.colors.BoundaryNorm(boundaries, cmap.N)

    if not return_lut:
        return cmap, norm
    return cmap, norm, label_indices, label_to_color, index_to_label
def my_roc_curve(prefix,datatype):
    """Plot the ROC curves of 10 cross-validation folds plus their mean.

    For fold ``i`` (0..9) the test data is read from ``<prefix><i>.txt`` and
    scored with the model stored in
    ``<i>-secstr_seq_denseconcat_60perc_best.h5``. Each fold is drawn in its
    own color of a 10-color HLS palette. The figure is saved as
    ``<datatype>_Human_10fold_ROC_curve.tiff``.

    Args:
        prefix: path prefix of the per-fold test files.
        datatype: label used in the plot title and the output file name.
    """
    tprs = []
    mean_fpr = np.linspace(0,1,100)
    sns.set(font_scale = 1.2)
    colors = sns.hls_palette(10)
    fig, ax = plt.subplots()
    fig.set_size_inches(12.0, 8.0)
    for i in range(10):
        testfile = prefix+str(i)+".txt"
        test = ft.read_composite_data(testfile)
        seq_test,secstr_test,label_test = ft.get_all_seq(test)
        testX2,testY = ft.onehotkey(seq_test,label_test)
        testX1,testY = ft.onehotkey_sec(secstr_test,label_test)
        testY = np_utils.to_categorical(testY,2)
        testY = testY.reshape(-1,2)
        # reshape both inputs in place to (n_samples, rows, cols)
        row1,col1 = testX1[0].shape
        row2,col2 = testX2[0].shape
        testX1.shape = (testX1.shape[0],row1,col1)
        testX2.shape = (testX2.shape[0],row2,col2)
        cnn = models.load_model('%d-secstr_seq_denseconcat_60perc_best.h5' % i) #0-yeast_pretrain-merge.h5
        probas_ = cnn.predict([testX1,testX2])
        fpr,tpr,_ = roc_curve(testY[:,1],probas_[:,1])
        # resample every fold onto a common FPR grid so they can be averaged
        tprs.append(np.interp(mean_fpr,fpr,tpr))
        roc_auc = auc(fpr,tpr)
        # the palette was previously built but never handed to the plot call
        plt.plot(fpr, tpr, lw=2, color=colors[i],
                 label='ROC fold %d (AUC = %0.2f)' % (i, roc_auc))

    plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='k',label='Base line')
    mean_tpr = np.mean(tprs,axis=0)
    mean_auc = auc(mean_fpr,mean_tpr)
    plt.plot(mean_fpr, mean_tpr, color='g', linestyle='--',label='Mean ROC (area = %0.2f)' % mean_auc, lw=2)
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(str(datatype) +' Human data 10fold ROC Curve')
    plt.legend(loc="lower right")
    plt.savefig(str(datatype) + '_Human_10fold_ROC_curve.tiff')
示例#19
0
def palette(nb, hls=HLS, viridis=VIRIDIS):
    """ Return a seaborn palette of nb colors, for nb different plots on the same figure.

    ``viridis`` takes precedence and gives the 'viridis' palette; otherwise
    an HLS palette (``hls`` true) or a HUSL palette is built with ``nb + 1``
    colors and cut back to the first ``nb``.

    - Ref: http://seaborn.pydata.org/generated/seaborn.hls_palette.html#seaborn.hls_palette

    >>> palette(10, hls=True)  # doctest: +ELLIPSIS
    [(0.86..., 0.37..., 0.33...), (0.86...,.65..., 0.33...), (0.78..., 0.86...,.33...), (0.49..., 0.86...,.33...), (0.33..., 0.86...,.46...), (0.33..., 0.86...,.74...), (0.33..., 0.68..., 0.86...) (0.33..., 0.40..., 0.86...) (0.56..., 0.33..., 0.86...) (0.84..., 0.33..., 0.86...)]
    >>> palette(10, hls=False)  # doctest: +ELLIPSIS
    [[0.96..., 0.44..., 0.53...], [0.88..., 0.52..., 0.19...], [0.71..., 0.60..., 0.19...], [0.54..., 0.65..., 0.19...], [0.19..., 0.69..., 0.34...], [0.20..., 0.68..., 0.58...],[0.21..., 0.67..., 0.69...], [0.22..., 0.65..., 0.84...], [0.55..., 0.57..., 0.95...], [0.85..., 0.44..., 0.95...]]
    >>> palette(10, viridis=True)  # doctest: +ELLIPSIS
    [(0.28..., 0.13..., 0.44...), (0.26..., 0.24..., 0.52...), (0.22..., 0.34..., 0.54...), (0.17..., 0.43..., 0.55...), (0.14..., 0.52..., 0.55...), (0.11..., 0.60..., 0.54...), (0.16..., 0.69..., 0.49...), (0.31..., 0.77..., 0.41...), (0.52..., 0.83..., 0.28...), (0.76..., 0.87..., 0.13...)]

    - To visualize:

    >>> sns.palplot(palette(10, hls=True))  # doctest: +SKIP
    >>> sns.palplot(palette(10, hls=False))  # use HUSL by default  # doctest: +SKIP
    >>> sns.palplot(palette(10, viridis=True))  # doctest: +SKIP
    """
    if viridis:
        return sns.color_palette('viridis', nb)
    # One extra color is requested and dropped again by the slice.
    build = sns.hls_palette if hls else sns.husl_palette
    return build(nb + 1)[:nb]
示例#20
0
def model_vs_model_plot(dataset_dir, legend):
    """Save one scatter plot of 'corr' values per ordered pair of models.

    The results frame from ``create_results_df(dataset_dir)`` is pivoted so
    that every 'sub' is one point, with the first model's value on x and
    the second model's on y. Files are written to ``dataset_dir`` as
    ``<first>_vs_<second>.png``.

    Args:
        dataset_dir: directory handed to ``create_results_df``; also the
            output directory.
        legend: if true, color the points by 'sub' and add a legend.
    """
    sns.set_style("darkgrid")
    sns.set_context("notebook", font_scale=1.5)
    results = create_results_df(dataset_dir)
    model_names = set(results.model)

    for first_model in model_names:
        for second_model in model_names - {first_model}:
            in_pair = (results.model == first_model) | (results.model == second_model)
            table = results.loc[in_pair].pivot(index='sub', columns='model', values='corr')
            table['sub'] = table.index

            if legend:
                hue = 'sub'
                colors = sns.hls_palette(len(set(table['sub'])))
            else:
                hue = None
                colors = None

            grid = sns.lmplot(x=first_model, y=second_model, hue=hue, size=6, data=table,
                              fit_reg=False, legend=False, palette=colors,
                              scatter_kws={'linewidths': 1, 'edgecolors': 'gray'})

            # dotted y = x reference line
            grid.ax.plot([0, 1], [0, 1], linestyle=':', color='tab:gray')
            grid.set(xticks=np.arange(0, 1.2, 0.2).squeeze().tolist(), xlim=(0, 1))
            grid.set(yticks=np.arange(0.2, 1.2, 0.2).squeeze().tolist(), ylim=(0, 1))
            if legend:
                grid.ax.legend(title='Recording', loc='center', bbox_to_anchor=(1.25, 0.5))
            out_name = first_model + '_vs_' + second_model + '.png'
            grid.savefig(os.path.join(dataset_dir, out_name))
示例#21
0
def distributions(df):
    """Save a grid of violin plots, one per column of ``df``.

    The 'asset' and 'unixtime' columns are excluded. Plots are laid out on
    a fixed 7x4 grid, so at most 28 variables are drawn. The figure is
    written to ``reports/distributions.png``.
    """
    sns.set(style='darkgrid')
    varlist = df.columns.drop(['asset', 'unixtime'])

    # Build the palette once; it does not change between subplots.
    palette = sns.hls_palette(len(varlist) + 1, l=.7, s=.5)

    fig, axs = plt.subplots(7, 4, figsize=(10, 10))
    for ax, name, color in zip(axs.flat, varlist, palette):
        subset = df[name]
        bot = subset.min()
        top = subset.max()
        sns.violinplot(data=subset.values,
                       ax=ax,
                       linewidth=.8,
                       bw=.1,
                       color=color)
        # ticks at min, midpoint and max; only the two ends are labelled
        ax.set(xticks=(), yticks=(bot, np.mean([bot, top]), top))
        ax.set_yticklabels(
            [np.round(bot, 2), '', np.round(top, 2)], fontsize=8)
        ax.set_title(name, fontweight='bold', fontsize=9)

    sns.despine(left=True, bottom=True)
    fig.suptitle('Variable Distributions', fontsize=20)
    # leave the top 5% of the figure for the suptitle
    fig.tight_layout(rect=(0, 0, 1, 0.95))
    fig.savefig('reports/distributions.png')
示例#22
0
def task_plot(task_dirs, kind, palette, legend):
    """Save a four-panel figure of 'corr' per model, one panel per task.

    Args:
        task_dirs: directories handed to ``create_results_df``; the figure
            is saved two levels above the first one.
        kind: prefix of the output file name.
        palette: palette used when ``legend`` is false.
        legend: if true, color by 'day' with an HLS palette (one color per
            distinct day) and add a legend.
    """
    sns.set_style("darkgrid")
    sns.set_context("notebook", font_scale=3.2)
    results = create_results_df(task_dirs)
    task_order = ['ABSVEL', 'XVEL', 'ABSPOS', 'XPOS']

    if legend:
        palette = sns.hls_palette(len(set(results.day.values)))
        hue = 'day'
    else:
        hue = None

    grid = sns.FacetGrid(results,
                         col='task',
                         sharex=False,
                         col_order=task_order,
                         size=16,
                         aspect=0.6)
    grid = grid.map(exp_plot, 'model', 'corr', hue, palette=palette)
    grid = grid.set_titles(col_template="{col_name}")
    if legend:
        grid.add_legend(title='Recording', fontsize='medium', markerscale=1.5)
    grid.set_xlabels('')
    grid.set_ylabels('Corr. Coeff.')

    out_dir = Path(task_dirs[0]).parents[1]
    out_file = os.path.join(out_dir, kind + '_model_per_task.png')
    grid.savefig(out_file, bbox_inches='tight')
示例#23
0
def plot_tsne_2D(embeddings_df, result_df, n_iter, n_clusters):
    """Project the embeddings to 2-D with t-SNE and scatter-plot them by cluster.

    The two t-SNE coordinates are written into ``result_df`` as the columns
    'tsne-2d-one' and 'tsne-2d-two' (the frame is modified in place).

    Args:
        embeddings_df: frame whose ``values`` are the vectors to embed.
        result_df: frame with a 'topic_cluster' column used as hue.
        n_iter: number of t-SNE iterations.
        n_clusters: number of colors in the HLS palette.
    """
    started = time.time()
    projection = TSNE(
        n_components=2, verbose=1, perplexity=40, n_iter=n_iter,
    ).fit_transform(embeddings_df.values)
    print('t-SNE done in {} seconds'.format(time.time() - started))

    for column, axis in (('tsne-2d-one', 0), ('tsne-2d-two', 1)):
        result_df[column] = projection[:, axis]

    tsne_data = pd.DataFrame({'topic': result_df['topic_cluster'].values})

    plt.figure(figsize=(16, 10))
    sns.scatterplot(x=result_df["tsne-2d-one"],
                    y=result_df["tsne-2d-two"],
                    hue=result_df['topic_cluster'],
                    palette=sns.hls_palette(n_clusters, l=.5, s=.9),
                    data=tsne_data,
                    legend="full",
                    alpha=0.3)
示例#24
0
def plot_dataset(X, y, ax=None, title=None, **params):
    """Draw every sample of a 2-D dataset as its digit label, colored by digit.

    Args:
        X: 2-column data (x and y coordinates).
        y: digit label of every sample.
        ax: axes to draw on. Despite the ``None`` default it is used
            unconditionally.
        title: axes title; it is only set when at least one sample is drawn.
        **params: accepted but unused.

    Returns:
        The axes that was passed in.
    """
    XY = pd.concat([pd.DataFrame(X), pd.DataFrame(y)], axis=1)
    XY.columns = ['x', 'y', 'digit']

    digit_colors = sns.hls_palette(10, l=.7, s=1)

    ax.set_xlim(XY['x'].min()*0.98, XY['x'].max()*1.02)
    ax.set_ylim(XY['y'].min()*0.98, XY['y'].max()*1.02)

    for digit in range(10):
        digit_rows = XY.loc[XY['digit'] == digit, :]
        for position, (_, row) in enumerate(digit_rows.iterrows()):
            if position == 0:
                # Only the first sample of a digit is plotted as a marker,
                # which gives the legend exactly one entry per digit.
                ax.plot(row['x'], row['y'], '.', color=digit_colors[digit], label=int(row['digit']))
                ax.legend(numpoints=1, markerscale=3, loc='upper right', bbox_to_anchor=(1.2, 1.0))
            ax.annotate(int(row['digit']), (row['x'], row['y']),
                        horizontalalignment='center', verticalalignment='center',
                        size=10, color=digit_colors[digit])
            ax.set_title(title, fontsize=15)

    return ax
示例#25
0
def make_sim_plot(counts, bins, median, ofile, output_dir):
    """Save a histogram bar plot with a dashed line marking the median.

    Args:
        counts: bar heights, one per bin.
        bins: bin edges (one more than ``counts``); assumed evenly spaced,
            since the midpoint offset is taken from the first two edges.
        median: median value, in the same units as ``bins``.
        ofile: output file name.
        output_dir: directory the figure is written to.
    """
    sns.set_style('white')
    bin_mdpt = (bins[1] - bins[0]) / 2
    x = [str(round(b + bin_mdpt, 4)) for b in bins[:-1]]
    y = counts
    # set bar locations and width
    ind = np.arange(len(x))
    width = 0.75
    # plot objects
    fig = plt.figure()
    ax = fig.add_subplot(111)
    # axis limits, labels and ticks
    ax.set_ylim(0, 1.1 * max(y))
    ax.set_xlim(-width / 2, len(ind))
    ax.set_xticks(ind + width / 2)
    ax.set_xticklabels(x)
    ax.set_xlabel('Similarity index')
    ax.set_ylabel('Frequency')
    # make the plot
    ax.bar(ind, y, width, color='lightgray')
    # plot the median line: rescale the median from data units to bar positions
    adj_median = ((median - bins[0]) / (bins[-1] - bins[0]) * len(x))
    plt.axvline(x=adj_median, ymin=0, ymax=1, linewidth=1.5, color='r', linestyle='dashed')
    # put the label right of the line, or left of it when the line is far right
    text_xpos = adj_median + 0.2 if adj_median < 6.5 else adj_median - 2.2
    text_ypos = 1.04 * max(y)
    med = round(median, 4)
    # right-pad with zeros to 6 characters, e.g. '0.5' -> '0.5000'
    med_str = str(med) if len(str(med)) == 6 else str(med) + '0' * (6 - len(str(med)))
    ax.text(text_xpos,
            text_ypos,
            'median: {}'.format(med_str),
            color='r',
            weight='semibold')
    # save the final figure
    plt.savefig(os.path.join(output_dir, ofile))
    plt.close()
示例#26
0
def cdr3_plot(sequences, outfile=None, chain='heavy'):
    """Bar plot of CDR3 length frequencies for one chain type.

    Args:
        sequences: iterable of mappings with ``'chain'`` and ``'cdr3_len'``
            keys.
        outfile: passed to ``_make_barplot`` as its fourth positional
            argument.
        chain: ``'heavy'`` (lengths up to 40, wide figure), or anything
            else (lengths up to 20, narrow figure); ``'light'`` selects
            both kappa and lambda sequences.
    """
    if chain == 'heavy':
        max_len, figure_size = 40, (9, 4)
    else:
        max_len, figure_size = 20, (6, 4)
    wanted_chains = ['kappa', 'lambda'] if chain == 'light' else [chain]

    selected = [s for s in sequences if s['chain'] in wanted_chains]
    lengths = [s['cdr3_len'] for s in selected if 0 < s['cdr3_len'] <= max_len]
    x, y = _aggregate(lengths, keys=list(range(1, max_len + 1)))

    bar_color = sns.hls_palette(7)[4]
    tick_labels = [str(i) for i in x]
    _make_barplot(tick_labels,
                  y,
                  bar_color,
                  outfile,
                  xlabel='CDR3 Length (AA)',
                  ylabel='Frequency (%)',
                  grid=True,
                  size=figure_size,
                  xfontsize=7)
示例#27
0
def growth_icu_country(country='Germany', icu_beds=28000, icu_free=0.2):
    """Plot estimated critical cases for one country against its ICU capacity.

    Uses the module-level ``df``: rows of ``country`` with more than 100
    confirmed cases, column ``critical_estimate``. A 33% daily growth
    reference curve and horizontal lines for total and free ICU beds are
    added. The figure is saved under ``CURR_DIR/images``.

    Args:
        country: country to plot.
        icu_beds: total ICU beds of that country (defaults to Germany).
        icu_free: fraction of ICU beds assumed to be free.
    """
    sns.set_palette(sns.hls_palette(8, l=.45, s=.8)) # 8 countries max
    fig, ax = plt.subplots(figsize=(12, 8))
    p_crit = .05

    selected = df.loc[lambda frame: (frame.country == country) & (frame.confirmed > 100)]
    critical = selected.critical_estimate
    critical.plot(ax=ax)

    # Reference curve starting on the first plotted date.
    days = np.linspace(0, 30, 30)
    reference = pd.Series(
        index=pd.date_range(critical.index[0], periods=30),
        data=100*p_crit * (1.33) ** days,
    )
    reference.plot(ax=ax, ls='--', color='k', label='33% daily growth')

    ax.axhline(icu_beds, color='.3', ls='-.', label='Total ICU beds')
    ax.axhline(icu_beds * icu_free, color='.5', ls=':', label='Free ICU beds')
    ax.set(
        yscale='log',
        title=f'When will {country} run out of ICU beds?',
        ylabel='Expected critical cases (assuming {:.0f}% critical)'.format(100 * p_crit),
    )
    ax.get_yaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
    ax.legend(bbox_to_anchor=(1.0, 1.0))
    sns.despine()
    ax.annotate(**annotate_kwargs)
    plt.savefig(CURR_DIR+f'/images/growth_{country}_icu_beds.png')
示例#28
0
def stack_barplot(sigCelltypedf, args, plotdf, path):
    """Save a stacked bar chart of per-sample cell-type fractions as SVG.

    Each column of the plotted frame is rescaled to sum to 1 and drawn as one
    stacked bar; each row becomes one coloured segment. The figure is written
    to ``<path>/GCAM_stacks.svg``.

    Args:
        sigCelltypedf: DataFrame with ``'p-val'``, ``'genecluster'`` and
            ``'celltype'`` columns; only used when no key cell types are set.
        args: mapping; when ``args['key_celltype_list']`` is falsy the rows to
            plot are chosen from ``sigCelltypedf``, otherwise all of
            ``plotdf`` is plotted.
        plotdf: DataFrame indexed by cell type with one column per bar.
        path: directory in which the SVG is saved.
    """
    import seaborn as sns
    sns.set_context("paper")
    sns.despine()
    if not args['key_celltype_list']:
        # Only the 15 lowest p-value rows are candidates; of those keep the
        # ones with a large enough gene cluster that also exist in plotdf.
        candidates = sigCelltypedf.sort_values('p-val', ascending=True).head(15)
        keep = (candidates['genecluster'] > 20) & candidates['celltype'].isin(plotdf.index)
        # Label-based selection replaces the row-by-row DataFrame.append,
        # which no longer exists in pandas 2.x.
        plotDF = plotdf.loc[list(candidates.loc[keep, 'celltype'])]
    else:
        plotDF = plotdf
    plotDF = plotDF.sort_index(axis=0, ascending=True)

    # Join three palettes (dark / mid / light) in case of more than 8 items.
    if len(plotDF) > 8:
        shades = [
            sns.hls_palette(8, l=.3, s=.7),
            sns.hls_palette(8, l=.5, s=1),
            sns.hls_palette(8, l=.7, s=.8),
        ]
        for shade in shades:
            shade.pop(3)
        # Interleave so the three shades of one hue sit next to each other.
        n_d_colors = [color for trio in zip(*shades) for color in trio]
    else:
        n_d_colors = sns.color_palette("hls", 8)

    # Scale every column to sum to 1 for stacked plotting.
    df1 = plotDF.div(plotDF.sum(axis=0), axis=1)
    # Running top of each bar; the next segment starts from here.
    y_offset = np.array([0.0] * len(df1.columns))
    ind = np.arange(len(df1.columns))    # the x locations for the groups
    width = 0.5  # the width of the bars: can also be len(x) sequence
    plt.figure(figsize=[10, 6])
    handls = []
    for i, (_, fractions) in enumerate(df1.iterrows()):
        # Wrap around instead of raising IndexError when there are more rows
        # than prepared colours.
        color = n_d_colors[i % len(n_d_colors)]
        h = plt.bar(ind+0.05, list(fractions), width, bottom=y_offset, color=color, edgecolor='w')
        handls.append(h[0])
        y_offset = y_offset + fractions.values
    plt.ylabel('Cell fractions')
    plt.xlabel('Samples')
    plt.title("Relative change in celltype expression")
    plt.ylim(0, 1)
    plt.xlim(-0.2, len(df1.columns))
    plt.xticks(ind + width/2. + 0.05, list(df1.columns), rotation=45)
    plt.yticks(np.arange(0, 1, 0.1))
    # Legend lists segments top-to-bottom: handles AND labels are reversed
    # together so every label stays next to its own colour.
    lgd = plt.legend(tuple(handls[::-1]), tuple(df1.index[::-1]), loc='center right', bbox_to_anchor=(1.2, 0.5))
    plt.tick_params(direction='out')
    plt.tight_layout()
    plt.savefig(os.path.join(path, 'GCAM_stacks.svg'), bbox_extra_artists=(lgd, ), bbox_inches='tight')
    plt.clf()
    plt.close()
示例#29
0
from scipy.optimize import curve_fit, lsq_linear
from scipy.interpolate import interp1d, interp2d, griddata, Akima1DInterpolator
import matplotlib as mpl
import matplotlib.pyplot as plt
import ROOT as rt
import seaborn as sns

# Wide default figure so plots span the notebook.
mpl.rcParams['figure.figsize'] = (18, 6)
mpl.style.use('fivethirtyeight')
# These are set after the style sheet is applied, so they replace any values
# it provides for the same keys.
mpl.rcParams.update({
    'axes.grid': False,
    'axes.facecolor': '#ffffff',
    'figure.facecolor': '#ffffff',
    'image.cmap': 'viridis',
})

colors = sns.hls_palette(10, l=0.3, s=0.8)

# Multipole index -> display name.
mp_name = dict(enumerate([
    'Dipole',
    'Normal Quadrupole',
    'Skew Quadrupole',
    'Normal Sextupole',
    'Skew Sextupole',
    'Normal Octupole',
    'Skew Octupole',
    'Normal Decupole',
    'Skew Decupole',
]))

# Change to the output directory if it's defined.
try:
    os.chdir(os.environ['G2_SHIMMING_OUTPUT_PATH'])
def hex_palette(n):
    """Return a function that maps an index to a hex colour string.

    The colours come from an ``n``-colour seaborn HLS palette that is built
    once, when ``hex_palette`` is called; the returned function only indexes
    into it.
    """
    hex_colors = sns.hls_palette(n).as_hex()

    def color_at(c):
        return hex_colors[c]

    return color_at
                    'Warner Bros.', 'Walt Disney Pictures', 
                    'Dreamworks SKG', 'Paramount Pictures',  
                    'Universal', 'Sony Pictures']
# Keep only the films whose Distributor is one of the names in top_dstr.
df_top_dstr = df_in_years[df_in_years.Distributor.isin(top_dstr)]
# One row per distributor, one column per genre, mean worldwide gross in each cell.
df_genre_pivot = df_top_dstr.pivot_table(
                    values='Worldwide Gross', columns='Major Genre', 
                    index='Distributor', aggfunc=np.mean)

# Step 2: Make the plot
df_genre_pivot.plot(kind='barh', stacked=True);


# In[37]:

# Use a custom palette to avoid repeating the colors.
# The palette is active only inside the `with` block.
with sns.hls_palette(12):
    df_genre_pivot.plot(kind='barh', stacked=True);

# Add a label to the x-axis
plt.xlabel('Mean Worldwide Gross');


# ### Exercise:
# 1. Create a stacked bar chart comparing the __Total Production Budget__ of the films by top distributors over the years 1991-2010, using __MPAA Rating__ as the hue. <br>_Hint_: You may reuse the `df_top_dstr` created above. Start by creating a pivot_table.
# 2. Answer the questions in __Worksheet Problem 4__.

# In[38]:

# One row per distributor, one column per MPAA rating, summed production budget in each cell.
df_rating_pivot = df_top_dstr.pivot_table(
    values='Production Budget', index='Distributor', 
    columns='MPAA Rating', aggfunc=np.sum)
def plot_state_qsphere(
    state,
    figsize=None,
    ax=None,
    show_state_labels=True,
    show_state_phases=False,
    use_degrees=False,
    *,
    rho=None,
    filename=None,
):
    """Plot the qsphere representation of a quantum state.
    Here, the size of the points is proportional to the probability
    of the corresponding term in the state and the color represents
    the phase.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        figsize (tuple): Figure size in inches. Defaults to ``(7, 7)``.
        ax (matplotlib.axes.Axes): An optional Axes object. Only its parent
            figure is used: a new 3D subplot is added to that figure and the
            drawing happens there. If none is specified a new matplotlib
            Figure is created.
        show_state_labels (bool): An optional boolean indicating whether to
            show labels for each basis state.
        show_state_phases (bool): An optional boolean indicating whether to
            show the phase for each basis state.
        use_degrees (bool): An optional boolean indicating whether to use
            radians or degrees for the phase values in the plot.
        rho: Not used by this function; the name is immediately rebound to
            ``DensityMatrix(state)``. Presumably kept for backward
            compatibility with an older argument name.
        filename (str): If given, the figure is saved with
            ``fig.savefig(filename)`` instead of being returned.

    Returns:
        Figure: The matplotlib figure when ``filename`` is None; otherwise the
        return value of ``fig.savefig(filename)``.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib and seaborn.
        VisualizationError: if input is not a valid N-qubit state.
        QiskitError: Input statevector does not have valid dimensions.

    Example:
        .. jupyter-execute::

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_state_qsphere
           %matplotlib inline

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.cx(0, 1)

           state = Statevector.from_instruction(qc)
           plot_state_qsphere(state)
    """
    if not HAS_MATPLOTLIB:
        raise MissingOptionalLibraryError(
            libname="Matplotlib",
            name="plot_state_qsphere",
            pip_install="pip install matplotlib",
        )

    import matplotlib.gridspec as gridspec
    from matplotlib import pyplot as plt
    from matplotlib.patches import Circle
    from qiskit.visualization.bloch import Arrow3D

    try:
        import seaborn as sns
    except ImportError as ex:
        raise MissingOptionalLibraryError(
            libname="seaborn",
            name="plot_state_qsphere",
            pip_install="pip install seaborn",
        ) from ex
    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    # get the eigenvectors and eigenvalues
    eigvals, eigvecs = linalg.eigh(rho.data)

    if figsize is None:
        figsize = (7, 7)

    if ax is None:
        return_fig = True
        fig = plt.figure(figsize=figsize)
    else:
        return_fig = False
        fig = ax.get_figure()

    gs = gridspec.GridSpec(nrows=3, ncols=3)

    ax = fig.add_subplot(gs[0:3, 0:3], projection="3d")
    ax.axes.set_xlim3d(-1.0, 1.0)
    ax.axes.set_ylim3d(-1.0, 1.0)
    ax.axes.set_zlim3d(-1.0, 1.0)
    ax.axes.grid(False)
    ax.view_init(elev=5, azim=275)

    # Force aspect ratio
    # MPL 3.2 or previous do not have set_box_aspect
    if hasattr(ax.axes, "set_box_aspect"):
        ax.axes.set_box_aspect((1, 1, 1))

    # start the plotting
    # Plot semi-transparent sphere
    u = np.linspace(0, 2 * np.pi, 25)
    v = np.linspace(0, np.pi, 25)
    x = np.outer(np.cos(u), np.sin(v))
    y = np.outer(np.sin(u), np.sin(v))
    z = np.outer(np.ones(np.size(u)), np.cos(v))
    ax.plot_surface(
        x, y, z, rstride=1, cstride=1, color=plt.rcParams["grid.color"], alpha=0.2, linewidth=0
    )

    # The w_xaxis / w_yaxis / w_zaxis aliases were deprecated and later removed
    # from Matplotlib's Axes3D; xaxis / yaxis / zaxis are the same axis objects.
    axes_3d = (ax.xaxis, ax.yaxis, ax.zaxis)
    transparent = (1.0, 1.0, 1.0, 0.0)

    # Get rid of the panes
    for axis in axes_3d:
        axis.set_pane_color(transparent)

    # Get rid of the spines
    for axis in axes_3d:
        axis.line.set_color(transparent)

    # Get rid of the ticks
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])

    # traversing the eigvals/vecs backward as sorted low->high
    for idx in range(eigvals.shape[0] - 1, -1, -1):
        if eigvals[idx] > 0.001:
            # get the max eigenvalue
            state = eigvecs[:, idx]
            loc = np.absolute(state).argmax()
            # remove the global phase from max element
            angles = (np.angle(state[loc]) + 2 * np.pi) % (2 * np.pi)
            angleset = np.exp(-1j * angles)
            state = angleset * state

            d = num
            for i in range(2 ** num):
                # get x,y,z points
                element = bin(i)[2:].zfill(num)
                # Hamming weight of the basis state sets the latitude (z).
                weight = element.count("1")
                zvalue = -2 * weight / d + 1
                number_of_divisions = n_choose_k(d, weight)
                weight_order = bit_string_index(element)
                angle = (float(weight) / d) * (np.pi * 2) + (
                    weight_order * 2 * (np.pi / number_of_divisions)
                )

                if (weight > d / 2) or (
                    (weight == d / 2) and (weight_order >= number_of_divisions / 2)
                ):
                    angle = np.pi - angle - (2 * np.pi / number_of_divisions)

                xvalue = np.sqrt(1 - zvalue ** 2) * np.cos(angle)
                yvalue = np.sqrt(1 - zvalue ** 2) * np.sin(angle)

                # get prob and angle - prob will be shade and angle color
                prob = np.real(np.dot(state[i], state[i].conj()))
                prob = min(prob, 1)  # See https://github.com/Qiskit/qiskit-terra/issues/4666
                colorstate = phase_to_rgb(state[i])

                # Points with larger y are drawn more transparently.
                alfa = 1
                if yvalue >= 0.1:
                    alfa = 1.0 - yvalue

                if not np.isclose(prob, 0) and show_state_labels:
                    # Labels sit on a sphere of radius 1.3, outside the unit sphere.
                    rprime = 1.3
                    angle_theta = np.arctan2(np.sqrt(1 - zvalue ** 2), zvalue)
                    xvalue_text = rprime * np.sin(angle_theta) * np.cos(angle)
                    yvalue_text = rprime * np.sin(angle_theta) * np.sin(angle)
                    zvalue_text = rprime * np.cos(angle_theta)
                    element_text = "$\\vert" + element + "\\rangle$"
                    if show_state_phases:
                        element_angle = (np.angle(state[i]) + (np.pi * 4)) % (np.pi * 2)
                        if use_degrees:
                            element_text += "\n$%.1f^\\circ$" % (element_angle * 180 / np.pi)
                        else:
                            element_angle = pi_check(element_angle, ndigits=3).replace("pi", "\\pi")
                            element_text += "\n$%s$" % (element_angle)
                    ax.text(
                        xvalue_text,
                        yvalue_text,
                        zvalue_text,
                        element_text,
                        ha="center",
                        va="center",
                        size=12,
                    )

                ax.plot(
                    [xvalue],
                    [yvalue],
                    [zvalue],
                    markerfacecolor=colorstate,
                    markeredgecolor=colorstate,
                    marker="o",
                    markersize=np.sqrt(prob) * 30,
                    alpha=alfa,
                )

                a = Arrow3D(
                    [0, xvalue],
                    [0, yvalue],
                    [0, zvalue],
                    mutation_scale=20,
                    alpha=prob,
                    arrowstyle="-",
                    color=colorstate,
                    lw=2,
                )
                ax.add_artist(a)

            # add weight lines
            for weight in range(d + 1):
                theta = np.linspace(-2 * np.pi, 2 * np.pi, 100)
                z = -2 * weight / d + 1
                r = np.sqrt(1 - z ** 2)
                x = r * np.cos(theta)
                y = r * np.sin(theta)
                ax.plot(x, y, z, color=(0.5, 0.5, 0.5), lw=1, ls=":", alpha=0.5)

            # add center point
            ax.plot(
                [0],
                [0],
                [0],
                markerfacecolor=(0.5, 0.5, 0.5),
                markeredgecolor=(0.5, 0.5, 0.5),
                marker="o",
                markersize=3,
                alpha=1,
            )
        else:
            # Eigenvalues are visited high->low, so the rest are negligible too.
            break

    # Phase legend: a colour wheel drawn as a 64-slice pie with a white centre.
    n = 64
    theta = np.ones(n)
    colors = sns.hls_palette(n)

    ax2 = fig.add_subplot(gs[2:, 2:])
    # Rotate the palette by 5/8 of a turn before drawing the wheel.
    ax2.pie(theta, colors=colors[5 * n // 8 :] + colors[: 5 * n // 8], radius=0.75)
    ax2.add_artist(Circle((0, 0), 0.5, color="white", zorder=1))
    offset = 0.95  # since radius of sphere is one.

    if use_degrees:
        labels = ["Phase\n(Deg)", "0", "90", "180   ", "270"]
    else:
        labels = ["Phase", "$0$", "$\\pi/2$", "$\\pi$", "$3\\pi/2$"]

    ax2.text(0, 0, labels[0], horizontalalignment="center", verticalalignment="center", fontsize=14)
    ax2.text(
        offset, 0, labels[1], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        0, offset, labels[2], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        -offset, 0, labels[3], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        0, -offset, labels[4], horizontalalignment="center", verticalalignment="center", fontsize=14
    )

    if return_fig:
        matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    else:
        return fig.savefig(filename)
示例#33
0
def hls(n_colors, hue=0.01, lightness=0.6, saturation=0.65):
    """Return a seaborn HLS palette with ``n_colors`` colours.

    ``hue``, ``lightness`` and ``saturation`` are passed to
    ``sns.hls_palette`` as its ``h``, ``l`` and ``s`` arguments.
    """
    palette_options = {'h': hue, 'l': lightness, 's': saturation}
    return sns.hls_palette(n_colors, **palette_options)
示例#34
0
import colorsys
import numpy as np
import seaborn as sns
from matplotlib import colors


# Scratch script (Python 2 print statements) for experimenting with colours.
# NOTE(review): width and height are not used anywhere in the visible part
# of this script.
width = 1440
height = 1080
# Kept as a float so that 360/n_colors below is a true division under Python 2.
# NOTE(review): the same float is passed to hls_palette as a colour count --
# confirm the installed seaborn/numpy accept a non-integer count.
n_colors = 24.


# HLS SPECTRUM
# ============
print "Angle step: {:g}".format(360/n_colors)
spectrum_array = sns.hls_palette(n_colors, h=0, l=.5, s=1)

# BEIGETONE
# =========
# Background colour as an (h, l, s) triple with hue given as a fraction of 360 degrees.
bgcolor = (90/360., 0.85, 0.15)
# First entry: bgcolor converted to RGB and scaled to 0-255; the rest are literal 0-255 RGB triples.
RGB = [
    map(lambda x: x*255, colorsys.hls_to_rgb(*bgcolor)),
    (176, 102, 96),
    (202, 143, 66),
    (171, 156, 115),
    (94, 119, 3),
    (106, 125, 142),
]
# Print '#c2ccd6' as 0-255 RGB and as HLS, then the HLS of RGB (236, 236, 240).
print np.array(colors.hex2color('#c2ccd6'))*255
print colorsys.rgb_to_hls(*colors.hex2color('#c2ccd6'))
print colorsys.rgb_to_hls(236/255., 236/255., 240/255.)
# NOTE(review): bgcolor is an (h, l, s) triple (it is fed to hls_to_rgb above)
# but is passed to rgb_to_hls here -- confirm this is intended.
print np.array(colorsys.rgb_to_hls(*bgcolor))*255
示例#35
0
def plot_rm_corr(data=None, x=None, y=None, subject=None, legend=False,
                 kwargs_facetgrid=dict(height=4, aspect=1)):
    """Plot a repeated measures correlation.

    Parameters
    ----------
    data : :py:class:`pandas.DataFrame`
        Dataframe.
    x, y : string
        Name of columns in ``data`` containing the two dependent variables.
    subject : string
        Name of column in ``data`` containing the subject indicator.
    legend : boolean
        If True, add legend to plot. Legend will show all the unique values in
        ``subject``.
    kwargs_facetgrid : dict
        Optional keyword argument passed to :py:class:`seaborn.FacetGrid`.
        If it has no ``'palette'`` entry, an HLS palette with one color per
        unique subject is used. The dictionary passed in is not modified.

    Returns
    -------
    g : :py:class:`seaborn.FacetGrid`
        Seaborn FacetGrid.

    Raises
    ------
    AssertionError
        If ``data`` is not a DataFrame, a column is missing, or ``x`` / ``y``
        are not numeric.
    ValueError
        If ``data`` has fewer than 3 unique subjects.

    See also
    --------
    rm_corr

    Notes
    -----
    Repeated measures correlation [1]_ (rmcorr) is a statistical technique
    for determining the common within-individual association for paired
    measures assessed on two or more occasions for multiple individuals.

    Results have been tested against the
    `rmcorr <https://github.com/cran/rmcorr>` R package. Note that this
    function requires `statsmodels
    <https://www.statsmodels.org/stable/index.html>`_.

    Missing values are automatically removed from the ``data``
    (listwise deletion).

    References
    ----------
    .. [1] Bakdash, J.Z., Marusich, L.R., 2017. Repeated Measures Correlation.
           Front. Psychol. 8, 456. https://doi.org/10.3389/fpsyg.2017.00456

    Examples
    --------
    Default repeated measures correlation plot

    .. plot::

        >>> import pingouin as pg
        >>> df = pg.read_dataset('rm_corr')
        >>> g = pg.plot_rm_corr(data=df, x='pH', y='PacO2', subject='Subject')

    With some tweaking

    .. plot::

        >>> import pingouin as pg
        >>> import seaborn as sns
        >>> df = pg.read_dataset('rm_corr')
        >>> sns.set(style='darkgrid', font_scale=1.2)
        >>> g = pg.plot_rm_corr(data=df, x='pH', y='PacO2',
        ...                     subject='Subject', legend=True,
        ...                     kwargs_facetgrid=dict(height=4.5, aspect=1.5,
        ...                                           palette='Spectral'))
    """
    # Check that statsmodels is installed
    from pingouin.utils import _is_statsmodels_installed
    _is_statsmodels_installed(raise_error=True)
    from statsmodels.formula.api import ols

    # Safety check (duplicated from pingouin.rm_corr)
    assert isinstance(data, pd.DataFrame), 'Data must be a DataFrame'
    assert x in data.columns, 'The %s column is not in data.' % x
    assert y in data.columns, 'The %s column is not in data.' % y
    assert data[x].dtype.kind in 'bfiu', '%s must be numeric.' % x
    assert data[y].dtype.kind in 'bfiu', '%s must be numeric.' % y
    assert subject in data.columns, 'The %s column is not in data.' % subject
    if data[subject].nunique() < 3:
        raise ValueError('rm_corr requires at least 3 unique subjects.')

    # Remove missing values
    data = data[[x, y, subject]].dropna(axis=0)

    # Fit ANCOVA model
    # https://patsy.readthedocs.io/en/latest/builtins-reference.html
    # C marks the data as categorical
    # Q allows to quote variable that do not meet Python variable name rule
    # e.g. if variable is "weight.in.kg" or "2A"
    formula = "Q('%s') ~ C(Q('%s')) + Q('%s')" % (y, subject, x)
    model = ols(formula, data=data).fit()

    # Fitted values
    data['pred'] = model.fittedvalues

    # Work on a private copy of the keyword arguments. The default dict is
    # created once, at function definition, so writing the palette into it
    # would leak this call's palette (sized for this call's subjects) into
    # every later call, and would also modify a dict owned by the caller.
    facetgrid_kwargs = dict(kwargs_facetgrid)

    # Define color palette
    if 'palette' not in facetgrid_kwargs:
        facetgrid_kwargs['palette'] = sns.hls_palette(data[subject].nunique())

    # Start plot: one regression line (on the fitted values) and one scatter
    # layer (on the raw values) per subject.
    g = sns.FacetGrid(data, hue=subject, **facetgrid_kwargs)
    g = g.map(sns.regplot, x, "pred", scatter=False, ci=None, truncate=True)
    g = g.map(sns.scatterplot, x, y)

    if legend:
        g.add_legend()

    return g
示例#36
0
plt.figure(figsize=(25, 10))
sns.heatmap(df.describe().T.drop('count', axis=1), annot=True, cmap='seismic')
plt.title('Basic Statistics (Transposed)')

# - The heatmap above shows basic statistical information. The output of **describe( )** is embedded in a heatmap only to make it more appealing.
#
#
# - **describe( )** generates descriptive statistics that summarize the central tendency, inter-quartile range / standard deviation, dispersion and shape of a dataset’s distribution, excluding NaN values.

# In[9]:

#Correlation

plt.figure(figsize=(10, 10))
sns.pairplot(df, hue='Class', palette=sns.hls_palette(7, l=.3, s=.8))
plt.title('Pair Plot')

# **Correlation plot observation:**
#
# - The plot above shows the pairwise relationships in the dataset.
#
#
# - Note that we have only a small number of features, so the relationships are easy to read off the figure. With more features a pairplot becomes more complex, and it gets extremely difficult to make any observations from it.
#
#
# - There is one more way to look at correlation. Let's see how to do it; first we label-encode "Sex" and "Class".

# In[ ]:

#Label encode discrete values
def visualize_feature_hist_dist(X,y,selected_feature,features,normalize=False):
	"""
	Plot and save a grouped bar chart of one feature's values, split by class

	One pair of bars is drawn per distinct (integer-cast) value of the feature:
	one bar for rows whose target equals 1 and one for all other rows.
	The figure is saved through save_fig as img/<selected_feature>_hist_dist.png.

	Keyword arguments:
	X -- The feature vectors
	y -- The target vector
	selected_feature -- The desired feature to obtain the histogram
	features -- Column names for the stacked (X, y) data; the last name is the target column
	normalize -- Whether to normalize the histogram (divide each class by its number of rows)
	"""
	from collections import Counter

	#create data
	joint_data=np.column_stack((X,y))
	column_names=features

	#create dataframe
	df=pd.DataFrame(data=joint_data,columns=column_names)

	#find the unique values (groups), as sorted ints
	unique_values=sorted(int(v) for v in pd.unique(df[[selected_feature]].values.ravel()))
	n_groups=len(unique_values)

	fig, ax = plt.subplots()
	index = np.arange(n_groups)
	bar_width = 0.35
	opacity = 0.4

	#split the feature values into the positive class (target == 1) and the rest
	is_positive=df[features[-1]] == 1
	positive_values=[int(v) for v in df.loc[is_positive, selected_feature].values.ravel()]
	negative_values=[int(v) for v in df.loc[~is_positive, selected_feature].values.ravel()]

	#normalize data (divide by class size); floats so the division is never truncated
	n_positive_labels=n_negative_labels=1.0
	if normalize:
		#max(..., 1) avoids a division by zero when a class has no rows
		n_positive_labels=float(max(int(is_positive.sum()), 1))
		n_negative_labels=float(max(int((~is_positive).sum()), 1))

	#count occurrences of each actual feature value (not of its position in the list)
	positive_counter=Counter(positive_values)
	negative_counter=Counter(negative_values)
	positive_counts=[positive_counter[value]/n_positive_labels for value in unique_values]
	negative_counts=[negative_counter[value]/n_negative_labels for value in unique_values]

	#plot
	plt.bar(index, positive_counts, bar_width,alpha=opacity,color='b',label='Default')			#class 1
	plt.bar(index+bar_width, negative_counts, bar_width,alpha=opacity,color='r',label='Paid')	#class 0

	plt.xlabel(selected_feature)
	plt.ylabel('Proportion' if normalize else 'Frequency')
	#only call the plot "Normalized" when it actually is
	title_prefix="Normalized Histogram" if normalize else "Histogram"
	plt.title(title_prefix+" Distribution of the feature '"+selected_feature+"' grouped by class")
	plt.xticks(index + bar_width, [str(value) for value in unique_values])
	plt.legend()
	plt.tight_layout()

	# plt.show()

	#save fig
	output_dir = "img"
	save_fig(output_dir,'{}/{}_hist_dist.png'.format(output_dir,selected_feature))
示例#38
0
				min_err = err
				min_centers = centers
				min_clusters = clusters

		print num_clusters, min_err # Keep the user updated

	print "Found solution! With weighting function {0}, there are {1} clusters with error {2}. Plotting result now...".format(i_weight, num_clusters, min_err)

	if i_weight == 'lin': validation_scale = float(np.sum(validation_pop))
	elif i_weight == 'sq': validation_scale = np.sum(validation_pop**2)
	elif i_weight == 'sqrt': validation_scale = np.sum(np.sqrt(validation_pop))
	elif i_weight == 'log': validation_scale = np.sum(np.log(validation_pop))
	elif i_weight == 'max': validation_scale = max_pop
	else: validation_scale = 1
	validation_error = error(centers, validation_array[:,:2], validation_pop, avg=True, wt=i_weight, scale=validation_scale, max_pop=max_pop, threshold=thres)

	# Plotting takes a while
	palette = sns.hls_palette(num_clusters)
	plt.figure()
	i = 0
	for key in clusters.keys():
		cur_cluster = clusters[key]
		for pt in cur_cluster:
			plt.plot(pt[0], pt[1],'.',markersize=8, color=palette[i], alpha=pt[2]/max_pop)
		plt.plot(centers[key][0], centers[key][1],'ko', markersize=10)
		i += 1
		
	plt.axis('off')
	plt.annotate('Cost: %0.5f' % validation_error, xy=(15,20), xytext=(15,20), fontsize=30)
	plt.savefig('images/solution_{0}wt_clusters{1}_minpop{2}.png'.format(i_weight,num_clusters,thres))
示例#39
0
def find_width(b, hdu):
    """Return ``(width, dashed)`` for shape ``b`` from its mean brightness.

    The mean of ``hdu.data`` inside the mask of ``b`` is compared with every
    entry of ``BRIGHT_LEVELS``. ``width`` is the 1-based position of the last
    level that the mean reaches, and ``dashed`` is 0 when at least one level
    is reached. When no level is reached the result is ``(1, 1)``.
    """
    mask = pyregion.ShapeList([b]).get_mask(hdu=hdu)
    mean_brightness = hdu.data[mask].mean()
    reached = [
        rank
        for rank, level in enumerate(BRIGHT_LEVELS, start=1)
        if mean_brightness >= level
    ]
    if not reached:
        return 1, 1
    return reached[-1], 0


# Value range covered by the colour table, one colour per unit step
# (both end points included).
VMIN = -110.0
VMAX = 0.0
NC = 1 + int(VMAX - VMIN)
rgblist = sns.hls_palette(NC)


def find_color(v):
    """Return the ``"#rgb"`` (one hex digit per channel) colour for value ``v``.

    The colour is looked up in ``rgblist`` at index ``int(VMAX - v)``, with
    the index clamped to ``[0, NC - 1]`` so out-of-range values map to the
    first or last colour.
    """
    ic = int(VMAX - v)
    ic = max(0, min(ic, NC - 1))
    r, g, b = rgblist[ic]
    # A channel of exactly 1.0 would give int(16 * 1.0) == 16, which prints as
    # the two hex digits "10" and corrupts the string; cap each digit at 15.
    digits = [min(int(16 * channel), 15) for channel in (r, g, b)]
    return "#{:01x}{:01x}{:01x}".format(*digits)


box_pattern_path = os.path.join(REGION_DIR, BOX_PATTERN)
box_files = glob.glob(box_pattern_path)

# Lower and upper limit for each named range.
VLIMITS = {
    "all": [-200.0, 200.0],
    "slow": [-45.0, 0.0],
    "fast": [-80.0, -35.0],
    "ultra": [-150.0, -70.0],
}

# One initially empty list per named range.
bar_lists = {range_name: [] for range_name in VLIMITS}
for box_file in box_files:
示例#40
0
# Six RGB colours on the 0-255 scale.
earthtone = np.array([
    (73, 56, 41),
    (97, 51, 25),
    (213, 117, 0),
    (64, 79, 36),
    (78, 97, 114),
    (43, 43, 43)])
# Background colour: HLS (hue 90/360, lightness 15/16, saturation 0) converted to RGB.
bgcolor = colorsys.hls_to_rgb(90/360., 15/16., 0/16.)

# Same colours rescaled to the 0-1 range.
palette = earthtone/255.
hexcolor = mcl.rgb2hex(bgcolor)

# HSL
# Kept as a float so that 360/n_colors below is a true division under Python 2.
n_colors = 32.
print "Angle step: {:g}".format(360/n_colors)
# NOTE(review): the first palette_base is overwritten by the next line and
# never used; only the darker, more saturated palette survives.
palette_base = sns.hls_palette(n_colors, h=0, l=15/16., s=1/16.)
palette_base = sns.hls_palette(n_colors, h=0, l=3/16., s=12/16.)


# 1920x1080 pixel figure, assuming 120 dpi.
fig, (ax) = plt.subplots(1, figsize=(1920/120., 1080/120.), tight_layout=False)
print (fig.get_figwidth(), fig.get_figheight())
# Title shows the background colour as hex and as 0-255 RGB.
fig.suptitle("{:s} - {:s}".format(
    hexcolor, str(np.array(mcl.hex2color(hexcolor))*255)), fontsize=20)
# fig.patch.set_facecolor(palette[-1])

# NOTE(review): set_axis_bgcolor was removed in newer Matplotlib releases;
# set_facecolor is the replacement -- confirm the target Matplotlib version.
ax.set_axis_bgcolor(earthtone[3]/255.)
ax.set_aspect('equal')
ax.set_xlim(0, n_colors)
ax.set_ylim(0, n_colors)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
    for c, motif_id in zip(numpy.arange(start, end+0.01, (end - start) / len(motifs)),motifs[motifs.keys()[0]].keys()):
        motif_colors[motif_id] = c
    
    # camera setups
    if options.camera is not None:
        camera = {}
        for line in open(options.camera):
            if line[0] != "#":
                parts = [ s.strip() for s in line.split(":")]
                # general and zoomed
                camera[parts[0]]=(parts[1], parts[2])
    
    # load colors per cluster
    #colors = [(1.,0.,0.), (0.,1.,0.), (0.,0.,1.)]*10
    import seaborn as sns
    colors = sns.hls_palette(15, l=.3, s=.8)
    
    # VMD execution template
    template = open("/home/victor/git/PhD-GPCR/PhD-GPCR-2/data/load_script_representatives.tcl").read()
    
    for line in open(options.input):
        protein, drug, folder = line.strip().split()

        # sorted clusters and same color generation always make the same cluster_id, color pair
        representatives_file = os.path.join(folder, "representatives.pdb")
        
        output_folder = os.path.join(options.output_folder, drug, protein)
        create_directory(output_folder)
        
        pdb = parsePDB(representatives_file)
        writePDB(os.path.join(output_folder,"protein.pdb"), pdb.select("protein"), csets = [0])
                        plt.imshow(artificial_image[0].reshape((28, 28)),
                                   cmap='gray')
                    plt.savefig(
                        os.path.join(results_folder, 'Test/{}'.format(epoch)))
                    plt.close()

                    # Create plot of latent space (only if latent dimensions are 2)
                    if FLAGS.latent_dim == 2 and FLAGS.plot_latent:
                        coords = sess.run(z,
                                          feed_dict={
                                              input_batch:
                                              test_images[..., np.newaxis] /
                                              255.
                                          })
                        colormap = ListedColormap(
                            sns.color_palette(sns.hls_palette(10, l=.45,
                                                              s=.8)).as_hex())
                        plt.scatter(coords[:, 0],
                                    coords[:, 1],
                                    c=test_labels,
                                    cmap=colormap)

                        cbar = plt.colorbar()
                        if FLAGS.dataset == 'fashion-mnist':
                            cbar.ax.set_yticklabels([
                                'T-shirt', 'Trouser', 'Pullover', 'Dress',
                                'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag',
                                'Ankle boot'
                            ])

                        # plt.axis('off')
                        plt.title('Latent space')
示例#43
0
文件: color.py 项目: briney/abtools
def hls(n_colors, hue=0.01, lightness=0.6, saturation=0.65):
    """Build an ``n_colors``-colour HLS palette with seaborn.

    The three keyword arguments map onto the ``h``, ``l`` and ``s``
    parameters of ``sns.hls_palette``.
    """
    palette = sns.hls_palette(n_colors, h=hue, l=lightness, s=saturation)
    return palette
    minErr = np.sqrt(np.sum(bestErr.loc['err2', :]) / len(bestErr.columns))
else:
    bestErr = pd.DataFrame(index=errors.minor_axis, columns=errors.items)
    minErr = 1E10

# Calculate the best fit and plot it
# NOTE(review): `errors` is indexed with items / major_axis / minor_axis,
# which looks like a 3-D pandas Panel -- confirm against where it is built.
for simIndex in errors.major_axis:
    # Root-mean-square error of this simulation across all items.
    testRMSE = np.sqrt(
        np.sum(errors.loc[:, simIndex, 'err2'].values) / len(errors.items))
    # Keep the simulation with the lowest RMSE seen so far (ties go to the later one).
    if testRMSE <= minErr:
        bestErr.loc[:, :] = errors.loc[:, simIndex, :]
        minErr = testRMSE
    else:
        continue
# NOTE(review): `test` is not used in the visible part of this script.
test = 0
mapc = sns.hls_palette(n_colors=8, l=0.5)
# Cycle through the 8 colours so more than 8 conditions still get a colour.
colors = itertools.cycle(mapc)
for cond in bestErr.columns:
    curCol = next(colors)
    # x: best-fit 'fSorb' for this condition; y: the expData 'fSorb' values
    # for rows whose 'Salt' equals the condition, with 'sfSorb' as error bars.
    axTestK.errorbar(bestErr.loc['fSorb', cond],
                     expData.loc[expData.loc[:, 'Salt'] == cond,
                                 'fSorb'].values,
                     yerr=expData.loc[expData.loc[:, 'Salt'] == cond,
                                      'sfSorb'].values,
                     color=curCol,
                     marker='.',
                     ls='none',
                     label=cond)
axTestK.plot([0, 1.0], [0, 1.0], color='k',
             ls='-')  #Plot 1 to 1 line of theoretical vs fitted
def visualize_hist_pairplots(X,y):
	"""
	Show eight pair plots of the features, coloured by the "Default" class

	Four feature groups are plotted twice: first with the default diagonal
	plots, then with kernel density estimates on the diagonal. Every figure is
	shown with plt.show() before the next one is built.

	Keyword arguments:
	X -- The feature vectors
	y -- The target vector
	"""

	column_names=["Credit","Gender","Education","Marital Status","Age","X6","X7","X8","X9","X10","X11","X12","X13","X14","X15","X16","X17","X18","X19","X20","X21","X22","X23","Default"]
	df=pd.DataFrame(data=np.column_stack((X,y)),columns=column_names)

	palette = sea.hls_palette()

	#(title suffix, columns to plot) for each feature group, in display order
	feature_groups=[
		('Credit, Gender, Education, Marital Status and Age',["Credit","Gender","Education","Marital Status","Age"]),
		('History of Payment',["X6","X7","X8","X9","X10","X11"]),
		('Amount of Bill Statements',["X12","X13","X14","X15","X16","X17"]),
		('Amount of Previous Payments',["X18","X19","X20","X21","X22","X23"]),
	]

	#(title prefix, extra pairplot arguments): histograms first, then kdes
	diagonal_styles=[
		('Histograms Distributions and Scatter Plots',{}),
		('Univariate Kernel Density Estimations and Scatter Plots',{"diag_kind":"kde"}),
	]

	for title_prefix,extra_kwargs in diagonal_styles:
		for title_suffix,columns in feature_groups:
			class_colors={0:palette[2],1:palette[0]}
			splot=sea.pairplot(df, hue="Default", palette=class_colors,vars=columns,**extra_kwargs)
			splot.fig.suptitle(title_prefix+': '+title_suffix)
			splot.set(xticklabels=[])
			plt.subplots_adjust(right=0.94, top=0.94)
			plt.show()
def track_clusters_3d(outfile, coords, end_frames, labels, title=None):
    """
    Show clusters of trajectories in 3d.

    Each trajectory is drawn as a line through (x, frame, y) space, colored
    by its cluster label, and the resulting figure is written as a single
    page to the PDF file ``outfile``.

    Parameters
    ----------
    outfile : str
        Path of the PDF file to write.
    coords : array-like, shape (n_tracks, track_length, 2)
        The x and y coordinates of the trajectories.
    end_frames : array-like, shape (n_tracks,)
        Frame number of the last point of each trajectory. A trajectory of
        length ``n`` is plotted over frames ``end - n + 1 .. end``.
    labels : array-like, shape (n_tracks,)
        Cluster label for each trajectory. The label -1 is used when a
        trajectory does not belong to any cluster.
    title : str, optional
        Title of the plot. No title is set when empty or None.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``coords``, ``end_frames`` and ``labels`` differ in length.
    """
    if not (len(coords) == len(labels) == len(end_frames)):
        raise ValueError('number of trajectories must match number of labels')

    # Plain lists would silently break the boolean masks and fancy indexing
    # used below (``labels == l`` on a list is a single bool), so normalize.
    coords = np.asarray(coords)
    end_frames = np.asarray(end_frames)
    labels = np.asarray(labels)

    labels_uniq, labels_cnt = np.unique(labels, return_counts=True)
    colors = sns.hls_palette(len(labels_uniq))

    # bounding box of all trajectories, shared by the frame outlines and
    # the axis limits
    x_min, x_max = coords[:, :, 0].min(), coords[:, :, 0].max()
    y_min, y_max = coords[:, :, 1].min(), coords[:, :, 1].max()
    last_frame = end_frames.max()

    with PdfPages(outfile) as pdf:
        # set up figure and 3d axis
        fig, ax = plt.subplots(tight_layout=True, figsize=(16, 9),
                               subplot_kw={'projection': '3d'})
        try:
            if title:
                ax.set_title(title)

            # plot trajectories in each cluster
            for l, c in zip(labels_uniq, colors):
                indices = np.flatnonzero(labels == l)
                for xy, end in zip(coords[indices], end_frames[indices]):
                    frames = np.arange(end + 1 - len(xy), end + 1)
                    ax.plot(xy[:, 0], xy[:, 1], frames, '-o',
                            zdir='y', lw=0.7, ms=1.5, color=c)

            # show legend (one entry per cluster, labelled with its size)
            ax.legend(handles=[Patch(color=c, label=str(cnt))
                               for c, cnt in zip(colors, labels_cnt)])

            # mimic video frames: one rectangle outline every 50 frames
            for z in range(0, last_frame, 50):
                rect = Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                                 fill=False, color='black',
                                 alpha=0.3, lw=0.3)
                ax.add_patch(rect)
                art3d.pathpatch_2d_to_3d(rect, z=z, zdir='y')

            # set axis limits
            pad = 0.5
            ax.set_xlim3d(x_min - pad, x_max + pad)
            ax.set_ylim3d(0, last_frame + 1)
            ax.set_zlim3d(y_min - pad, y_max + pad)

            ax.set_xlabel('Video width')
            ax.set_ylabel('Video frames')
            ax.set_zlabel('Video height')

            ax.view_init(elev=20, azim=12)
            _set_axes_equal(ax)

            pdf.savefig(fig)
        finally:
            # release the figure even when plotting or saving fails
            plt.close(fig)
def random_transform(x, y=None,
                     rotation_range=0.,
                     width_shift_range=0.,
                     height_shift_range=0.,
                     shear_range=0.,
                     zoom_range=0.,
                     channel_shift_range=0.,
                     fill_mode='nearest',
                     cval=0.,
                     cval_mask=0.,
                     horizontal_flip=0.,  # probability
                     vertical_flip=0.,  # probability
                     rescale=None,
                     spline_warp=False,
                     warp_sigma=0.1,
                     warp_grid_size=3,
                     crop_size=None,
                     return_optical_flow=False,
                     nclasses=None,
                     gamma=0.,
                     gain=1.,
                     chan_idx=3,  # No batch yet: (s, 0, 1, c)
                     rows_idx=1,  # No batch yet: (s, 0, 1, c)
                     cols_idx=2,  # No batch yet: (s, 0, 1, c)
                     void_label=None,
                     prescale=1.0):
    '''Random Transform.

    A function to perform data augmentation of images and masks during
    the training  (on-the-fly). Based on [RandomTransform1]_.

    Parameters
    ----------
    x: array of floats
        An image.
    y: array of int
        An array with labels. May be None or empty, in which case every
        mask-related step is skipped and `y` is returned unchanged.
    rotation_range: int
        Degrees of rotation (0 to 180).
    width_shift_range: float
        The maximum amount the image can be shifted horizontally (in
        percentage).
    height_shift_range: float
        The maximum amount the image can be shifted vertically (in
        percentage).
    shear_range: float
        The shear intensity (shear angle in radians).
    zoom_range: float or list of floats
        The amount of zoom. If set to a scalar z, the zoom range will be
        randomly picked in the range [1-z, 1+z]. If set to a pair, each
        element `el` is turned into the bound `1 - el`.
    channel_shift_range: float
        The shift range for each channel.
    fill_mode: string
        Some transformations can return pixels that are outside of the
        boundaries of the original image. The points outside the
        boundaries are filled according to the given mode (`constant`,
        `nearest`, `reflect` or `wrap`). Default: `nearest`.
    cval: int
        Value used to fill the points of the image outside the boundaries when
        fill_mode is `constant`. Default: 0.
    cval_mask: int
        Value used to fill the points of the mask outside the boundaries when
        fill_mode is `constant`. Default: 0.
    horizontal_flip: float
        The probability to randomly flip the images (and masks)
        horizontally. Default: 0.
    vertical_flip: float
        The probability to randomly flip the images (and masks)
        vertically. Default: 0.
    rescale: float
        The rescaling factor. Rescaling is not implemented: any truthy
        value raises NotImplementedError. Leave it to None or 0.
    spline_warp: bool
        Whether to apply spline warping.
    warp_sigma: float
        The sigma of the gaussians used for spline warping.
    warp_grid_size: int
        The grid size of the spline warping.
    crop_size: tuple
        The size of crop to be applied to images and masks (after any
        other transformation). If the requested size exceeds the image
        along an axis, that axis is padded instead (image with zeros,
        mask with `void_label`).
    return_optical_flow: bool
        If not False a dense optical flow will be concatenated to the
        end of the channel axis of the image. If True, angle and
        magnitude will be returned, if set to 'rgb' an RGB representation
        will be returned instead. Default: False.
    nclasses: int
        The number of classes of the dataset.
    gamma: float
        Controls gamma in Gamma correction.
    gain: float
        Controls gain in Gamma correction.
    chan_idx: int
        The index of the channel axis.
    rows_idx: int
        The index of the rows of the image.
    cols_idx: int
        The index of the cols of the image.
    void_label: int
        The index of the void label, if any. Used for padding.
    prescale: float
        Scale factor applied to every image (and mask) with
        `skimage.transform.rescale` before any other transformation.
        Default: 1.0 (no prescaling).

    Returns
    -------
    x, y
        The transformed copy of the image and of the mask. `y` is returned
        as received when it is None or empty.

    References
    ----------
    .. [RandomTransform1]
       https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py
    '''
    # Set this to a dir, if you want to save augmented images samples
    save_to_dir = None  # "./"

    if rescale:
        raise NotImplementedError()

    # The mask is optional: remember once whether there is one to process
    has_mask = y is not None and len(y) > 0

    # Do not modify the original images
    x = x.copy()
    if has_mask:
        y = y[..., None]  # Add extra dim to y to simplify computation
        y = y.copy()

    # Scale images
    if prescale != 1.0:
        import skimage.transform
        x = [skimage.transform.rescale(x_image, prescale, order=1,
                                       preserve_range=True) for x_image in x]
        x = np.stack(x, 0)
        # Only touch the mask when there is one (iterating None would fail)
        if has_mask:
            y = [skimage.transform.rescale(y_image, prescale, order=0,
                                           preserve_range=True)
                 for y_image in y]
            y = np.stack(y, 0)

    # listify zoom range
    if np.isscalar(zoom_range):
        if zoom_range > 1.:
            raise RuntimeError('Zoom range should be between 0 and 1. '
                               'Received: ', zoom_range)
        # Symmetric interval around 1, as documented: [1-z, 1+z]
        zoom_range = [1 - zoom_range, 1 + zoom_range]
    elif len(zoom_range) == 2:
        if any(el > 1. for el in zoom_range):
            raise RuntimeError('Zoom range should be between 0 and 1. '
                               'Received: ', zoom_range)
        zoom_range = [1-el for el in zoom_range]
    else:
        raise Exception('zoom_range should be a float or '
                        'a tuple or list of two floats. '
                        'Received arg: ', zoom_range)

    # Channel shift
    if channel_shift_range != 0:
        x = random_channel_shift(x, channel_shift_range, rows_idx, cols_idx,
                                 chan_idx)

    # Gamma correction
    if gamma > 0:
        scale = float(1)
        x = ((x / scale) ** gamma) * scale * gain

    # Affine transformations (zoom, rotation, shift, ..)
    if (rotation_range or height_shift_range or width_shift_range or
            shear_range or zoom_range != [1, 1]):

        # --> Rotation
        if rotation_range:
            theta = np.pi / 180 * np.random.uniform(-rotation_range,
                                                    rotation_range)
        else:
            theta = 0
        rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
                                    [np.sin(theta), np.cos(theta), 0],
                                    [0, 0, 1]])
        # --> Shift/Translation
        if height_shift_range:
            tx = (np.random.uniform(-height_shift_range, height_shift_range) *
                  x.shape[rows_idx])
        else:
            tx = 0
        if width_shift_range:
            ty = (np.random.uniform(-width_shift_range, width_shift_range) *
                  x.shape[cols_idx])
        else:
            ty = 0
        translation_matrix = np.array([[1, 0, tx],
                                       [0, 1, ty],
                                       [0, 0, 1]])
        # --> Shear
        if shear_range:
            shear = np.random.uniform(-shear_range, shear_range)
        else:
            shear = 0
        shear_matrix = np.array([[1, -np.sin(shear), 0],
                                 [0, np.cos(shear), 0],
                                 [0, 0, 1]])
        # --> Zoom
        if zoom_range == [1, 1]:
            zx, zy = 1, 1
        else:
            zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2)
        zoom_matrix = np.array([[zx, 0, 0],
                                [0, zy, 0],
                                [0, 0, 1]])
        # Use a composition of homographies to generate the final transform
        # that has to be applied
        transform_matrix = np.dot(np.dot(np.dot(rotation_matrix,
                                                translation_matrix),
                                         shear_matrix), zoom_matrix)
        h, w = x.shape[rows_idx], x.shape[cols_idx]
        transform_matrix = transform_matrix_offset_center(transform_matrix,
                                                          h, w)
        # Apply all the transformations together: linear interpolation for
        # the image, nearest neighbour (order=0) for the label mask
        x = apply_transform(x, transform_matrix, fill_mode=fill_mode,
                            cval=cval, order=1, rows_idx=rows_idx,
                            cols_idx=cols_idx)
        if has_mask:
            y = apply_transform(y, transform_matrix, fill_mode=fill_mode,
                                cval=cval_mask, order=0, rows_idx=rows_idx,
                                cols_idx=cols_idx)

    # Horizontal flip
    if np.random.random() < horizontal_flip:  # 0 = disabled
        x = flip_axis(x, cols_idx)
        if has_mask:
            y = flip_axis(y, cols_idx)

    # Vertical flip
    if np.random.random() < vertical_flip:  # 0 = disabled
        x = flip_axis(x, rows_idx)
        if has_mask:
            y = flip_axis(y, rows_idx)

    # Spline warp
    if spline_warp:
        import SimpleITK as sitk
        warp_field = gen_warp_field(shape=(x.shape[rows_idx],
                                           x.shape[cols_idx]),
                                    sigma=warp_sigma,
                                    grid_size=warp_grid_size)
        x = apply_warp(x, warp_field,
                       interpolator=sitk.sitkLinear,
                       fill_mode=fill_mode,
                       fill_constant=cval,
                       rows_idx=rows_idx, cols_idx=cols_idx)
        if has_mask:
            y = np.round(apply_warp(y, warp_field,
                                    interpolator=sitk.sitkNearestNeighbor,
                                    fill_mode=fill_mode,
                                    fill_constant=cval_mask,
                                    rows_idx=rows_idx, cols_idx=cols_idx))

    # Save augmented images
    if save_to_dir:
        import seaborn as sns
        cmap = sns.hls_palette(nclasses)
        fname = 'data_augm_{}.png'.format(np.random.randint(1e4))
        if has_mask:
            save_img2(x, y, os.path.join(save_to_dir, fname),
                      cmap, void_label, rows_idx, cols_idx, chan_idx)
        else:
            scipy.misc.toimage(x[0]).save(fname)

    # Crop
    # Expects axes with shape (..., 0, 1)
    # TODO: Add center crop
    if crop_size:
        # Reshape to (..., 0, 1)
        pattern = [el for el in range(x.ndim) if el != rows_idx and
                   el != cols_idx] + [rows_idx, cols_idx]
        inv_pattern = [pattern.index(el) for el in range(x.ndim)]
        x = x.transpose(pattern)

        crop = list(crop_size)
        pad = [0, 0]
        h, w = x.shape[-2:]

        # Compute amounts
        if crop[0] < h:
            # Do random crop
            top = np.random.randint(h - crop[0])
        else:
            # Set pad and reset crop
            pad[0] = crop[0] - h
            top, crop[0] = 0, h
        if crop[1] < w:
            # Do random crop
            left = np.random.randint(w - crop[1])
        else:
            # Set pad and reset crop
            pad[1] = crop[1] - w
            left, crop[1] = 0, w

        # Cropping
        x = x[..., top:top+crop[0], left:left+crop[1]]
        if has_mask:
            y = y.transpose(pattern)
            y = y[..., top:top+crop[0], left:left+crop[1]]
        # Padding (split evenly, the extra pixel goes to the bottom/right)
        if pad != [0, 0]:
            pad_pattern = ((0, 0),) * (x.ndim - 2) + (
                (pad[0]//2, pad[0] - pad[0]//2),
                (pad[1]//2, pad[1] - pad[1]//2))
            x = np.pad(x, pad_pattern, 'constant')
            # Without a mask there is nothing to pad (np.pad(None) fails)
            if has_mask:
                y = np.pad(y, pad_pattern, 'constant',
                           constant_values=void_label)

        x = x.transpose(inv_pattern)
        if has_mask:
            y = y.transpose(inv_pattern)

    if return_optical_flow:
        flow = optical_flow(x, rows_idx, cols_idx, chan_idx,
                            return_rgb=return_optical_flow == 'rgb')
        x = np.concatenate((x, flow), axis=chan_idx)

    # Undo extra dim
    if has_mask:
        y = y[..., 0]

    return x, y
def draw_graph_centrality2(G,
                           Subsets=None,
                           h=15,
                           v=10,
                           deltax=0,
                           deltay=0,
                           fontsize=18,
                           k=0.2,
                           arrows=False,
                           node_alpha=0.3,
                           l_alpha=1,
                           node_color='blue',
                           centrality=nx.degree_centrality,
                           font_color='black',
                           threshold=0.01,
                           multi=3000,
                           edge_color='olive',
                           colstart=0.2,
                           coldark=0.5):
    """Draw graph ``G`` with node sizes proportional to a centrality score.

    Nodes are laid out with ``nx.spring_layout``. Only nodes whose
    centrality is at least ``threshold`` are drawn as markers; labels and
    edges are drawn for the whole graph.

    Parameters
    ----------
    G : networkx graph
        The graph to draw.
    Subsets : list of collections of nodes, optional
        When given and non-empty, each subset is drawn in its own HLS color
        and nodes outside every subset get no marker. When omitted (or
        empty), all retained nodes share a single color.
    h, v : number
        Width and height used for ``rcParams['figure.figsize']`` while
        drawing; the previous value is restored afterwards.
    deltax, deltay : number
        Offsets added to each node position to place its label.
    fontsize : int
        Label font size.
    k : float
        ``k`` argument forwarded to ``nx.spring_layout``.
    arrows : bool
        Forwarded to ``nx.draw_networkx_edges``.
    node_alpha : float
        Opacity of the node markers.
    l_alpha : float
        Opacity of the labels; labels are only drawn when it is <= 1.
    node_color : str
        Unused; kept for backward compatibility of the signature.
    centrality : callable
        Function mapping the graph to a ``{node: score}`` dict.
    font_color : str
        Label color.
    threshold : float
        Minimum centrality for a node to be drawn.
    multi : number
        Factor converting a centrality score into a marker size.
    edge_color : str
        Color of the edges.
    colstart, coldark : float
        ``h`` (first hue) and ``l`` (lightness) passed to
        ``sns.hls_palette``.

    Returns
    -------
    None
    """
    from pylab import rcParams
    import matplotlib.pyplot as plt

    subsets = [] if Subsets is None else Subsets

    node_dict = centrality(G)
    subnodes = {
        node: score
        for node, score in node_dict.items() if score >= threshold
    }

    x, y = rcParams['figure.figsize']
    rcParams['figure.figsize'] = h, v
    try:
        ax = plt.subplot()
        ax.set_xticks([])
        ax.set_yticks([])

        pos = nx.spring_layout(G, k=k)
        labelpos = {
            node: (xy[0] + deltax, xy[1] + deltay)
            for node, xy in pos.items()
        }
        if l_alpha <= 1:
            nx.draw_networkx_labels(G,
                                    labelpos,
                                    font_size=fontsize,
                                    alpha=l_alpha,
                                    font_color=font_color)

        if subsets != []:
            colpalette = sns.hls_palette(len(subsets), h=colstart, l=coldark)
            for sub, sub_col in zip(subsets, colpalette):
                members = {
                    node: score
                    for node, score in subnodes.items() if node in sub
                }
                nx.draw_networkx_nodes(
                    G,
                    pos,
                    alpha=node_alpha,
                    node_color=[sub_col],
                    nodelist=list(members),
                    node_size=[score * multi for score in members.values()])
        else:
            glob_col = sns.hls_palette(len(G), h=colstart, l=coldark)[0]
            # Wrap the RGB tuple in a list, as in the subset branch, so it
            # cannot be mistaken for per-node scalar values when the number
            # of nodes happens to equal the tuple length.
            nx.draw_networkx_nodes(
                G,
                pos,
                alpha=node_alpha,
                node_color=[glob_col],
                nodelist=list(subnodes),
                node_size=[score * multi for score in subnodes.values()])

        nx.draw_networkx_edges(G,
                               pos,
                               alpha=0.1,
                               arrows=arrows,
                               edge_color=edge_color)
    finally:
        # Always put the global figure size back, even if drawing failed
        rcParams['figure.figsize'] = x, y
    return
示例#49
0
def cluster_map_plot(dmat_file,
                     big_tribe,
                     tribe_groups_dir,
                     raw_wav_dir,
                     savefig=None):
    """
    Wrapper on seaborn.clustermap that colors the columns by multiplet.

    :param dmat_file: Path of the ``.npy`` file holding the square distance
        matrix to cluster.
    :param big_tribe: Tribe of templates; it is sorted in place.
    :param tribe_groups_dir: Directory containing one ``*.tgz`` Tribe file
        per multiplet group.
    :param raw_wav_dir: Directory of waveform files. Only templates whose
        name equals a file name (up to its first '.') are considered.
    :param savefig: Output path for the figure. When falsy the figure is
        shown instead of saved.
    :return: The object returned by ``sns.clustermap``.
    """

    def stem(wav_path):
        # file name without directory and without anything after first '.'
        return wav_path.split('/')[-1].split('.')[0]

    # Restrict to the templates that have a matching raw waveform file
    # XXX TODO May be worth using SAC directories instead?
    big_tribe.sort()
    raw_wav_files = sorted(glob('%s/*' % raw_wav_dir))
    all_wavs = [stem(wav) for wav in raw_wav_files]
    names = [t.name for t in big_tribe if t.name in all_wavs]
    wavs = [wav for wav in raw_wav_files if stem(wav) in names]
    new_tribe = Tribe()
    new_tribe.templates = [temp for temp in big_tribe if temp.name in names]
    print('Processing temps')
    temp_list = [template.name for _, template in zip(wavs, new_tribe)]

    # Linkage is computed once and shared by rows and columns
    matrix = np.load(dmat_file)  # Take absolute value? NO
    Z = linkage(squareform(matrix))
    df_mat = pd.DataFrame(matrix)

    # For every group file, record which template indices belong to it
    grp_inds = []
    grp_nos = []
    for tribe_file in glob('{}/*.tgz'.format(tribe_groups_dir)):
        grp_nos.append(tribe_file.split('_')[-2])
        group_names = [temp.name for temp in Tribe().read(tribe_file)]
        grp_inds.append(tuple(
            i for i, nm in enumerate(temp_list) if nm in group_names))

    # Create a categorical palette to identify the multiplets
    multiplet_lut = dict(zip(tuple(grp_inds),
                             sns.hls_palette(len(grp_inds))))

    # Map each template index to the color of the first group holding it
    temp_inds = np.arange(len(temp_list))
    temp_colors = {}
    for i in temp_inds:
        matching = [colour for members, colour in multiplet_lut.items()
                    if i in members]
        if matching:
            temp_colors[i] = matching[0]
    template_colors = pd.Series(temp_inds, index=temp_inds,
                                name='Multiplet').map(temp_colors)

    cmg = sns.clustermap(
        df_mat,
        method='single',
        cmap='vlag_r',
        vmin=0.4,
        vmax=1.4,
        col_colors=template_colors,
        row_linkage=Z,
        col_linkage=Z,
        yticklabels=False,
        xticklabels=False,
        cbar_kws={'label': '1 - CCC'},
        figsize=(12, 12))
    if savefig:
        cmg.savefig(savefig, dpi=500)
    else:
        plt.show()
    return cmg
示例#50
0
def hex_palette(n):
    """Return a lookup function mapping an index to one of ``n`` HLS hex colors."""
    hex_codes = sns.hls_palette(n).as_hex()

    def lookup(c):
        return hex_codes[c]

    return lookup
示例#51
0
    def plot_profiles(self, profiles, legend=None, spacing=0.1, color_palette=None):
        """
        Plot given profiles as bar plot.
        :param profiles: Profiles to plot.
        :param legend: (optional) Iterable of legend entries. At most one
            entry per profile is used; missing entries are generated as
            'profile <n>'. Non-iterable values (e.g. None) are ignored.
        :param spacing: Space between bars relative to their width.
        :param color_palette: (optional) Iterable of bar colors. At most one
            color per profile is used; missing colors are taken from an HLS
            palette. Non-iterable values (e.g. None) are ignored.
        :return: Tuple (figure, axes) of the created plot.
        """
        from itertools import islice

        # get number of plots
        plot_n = len(profiles)

        # construct legend: take at most plot_n entries (islice stops early,
        # so long or endless iterables are fine)
        try:
            legend = list(islice(legend, plot_n))
        except TypeError:
            # not iterable: start from an empty legend
            legend = []
        # generate missing legend entries
        legend += ['profile {}'.format(idx)
                   for idx in range(len(legend) + 1, plot_n + 1)]

        # construct color palette: take at most plot_n colors
        try:
            color_palette = list(islice(color_palette, plot_n))
        except TypeError:
            # not iterable: start from an empty palette
            color_palette = []
        # generate missing color entries
        color_palette += list(
            sns.hls_palette(plot_n - len(color_palette), l=.4, s=.9))

        # dimensions
        bar_width = 1/plot_n
        plot_width = max([len(p) for p in profiles])*(1+spacing) - spacing

        # plot
        fig, ax = plt.subplots(1, 1, figsize=(15, 10))
        for idx, p in enumerate(profiles):
            ax.bar((1+spacing)*self.tone_indices + idx*bar_width, p, bar_width, color=color_palette[idx])
        ax.set_xticks((1+spacing)*self.tone_indices+0.5)
        ax.set_xticklabels(self.tone_names)
        ax.set_xlim(0, plot_width)
        ax.legend(legend, loc='best')

        return fig, ax
示例#52
0
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import seaborn as sns
from scipy.optimize import curve_fit

# Global plot appearance: white background, six-color HLS palette, 18 pt font
sns.set_style('white')
sns.set_palette(sns.hls_palette(6, h=0.5, l=0.4, s=0.5))
font = dict(size=18)
plt.rc('font', **font)

# Character -> mathtext replacement: digits become subscripts and '+'
# becomes a superscript
num_dict = {str(digit): '$_{' + str(digit) + '}$' for digit in range(10)}
num_dict['+'] = '$^{+}$'
# Translation table built from num_dict, for use with str.translate, e.g.
# "H2SO4+".translate(SUB)
SUB = str.maketrans(num_dict)

dproton = avg_dist = 2.5432441257325
    def __init__(self, grid, resources, size, name, tasks):
        """
        Store the environment description and pick default color palettes
        for its resources and tasks.

        Arguments:
        - grid: a 2d array of sets of strings representing which resources
          are present in which cells.
        
        - resources: a list of strings representing all of the resources 
          found anywhere in the world
        
        - size: a tuple containing the x and y dimensions of the grid
        
        - name: a string representing the name of the file this came from.
          Only the part after the last "/" is kept, minus its last four
          characters (the file extension).
        
        - tasks: a list of strings representing all of the tasks rewarded by
          the environment file (extracted from reactions), in order of 
          appearance.
        """
        self.grid = grid
        self.resources = resources
        self.size = size
        self.name = name.split("/")[-1] #Extract filename from path
        self.name = self.name[:-4] #Remove file extension
        self.tasks = tasks

        def hls(n_items):
            #If we're only using two colors, it's better for them to be red
            #and yellow than red and green, so always ask for at least 4
            return sns.hls_palette(max(n_items, 4), s=1)

        task_set = set(self.tasks)
        resource_set = set(self.resources)

        if len(self.resources) == 1:
            #Yay, we can make the background nice and simple!
            self.resource_palette = sns.dark_palette("black", 1)
            self.task_palette = hls(len(tasks))

        elif task_set == resource_set:
            #Yay, the tasks and resources correspond to each other!
            self.resource_palette = hls(len(tasks))
            self.task_palette = hls(len(tasks))

        elif len(self.tasks) == 1:
            #Yay, we can display the tasks easily!
            self.task_palette = sns.dark_palette("black", 1)
            #Size the palette by the resources it has to color (it used to
            #be sized by the single task, which capped it at 4 colors)
            self.resource_palette = hls(len(resources))

        elif task_set.issubset(resource_set):
            #Resources that are also tasks share the task colors; the
            #remaining resources get "bone" colors appended
            self.task_palette = hls(len(tasks))
            self.resource_palette = hls(len(tasks))
            self.resource_palette += sns.color_palette("bone", 
                                                    len(resources)-len(tasks))

        elif task_set.issuperset(resource_set):
            self.task_palette = hls(len(tasks))
            self.resource_palette = hls(len(resources))
        else:
            #Since we don't know anything in particular about the tasks
            #and resources, let's be color-blind friendly and keep the
            #background simple. The user probably wants to assign new
            #palettes based on their data.
            self.task_palette = sns.color_palette("colorblind", len(tasks))
            self.resource_palette = sns.color_palette("bone", len(resources)) 
# Gather the per-continent country collections into one list
continents_list = [
    african_countries,
    asian_countries,
    european_countries,
    north_american_countries,
    south_american_countries,
    oceanian_countries,
]
"""
PLOT A GRAPH FOR EACH CONTINENT WHICH COMPARES EACH COUNTRY'S DATA FROM 2010-2014
2010 AS THE LIGHTEST COLOR AND 2014 AS THE DARKEST TO SHOW EACH COUNTRY'S PROGRESS
OVER THE PAST 5 YEARS
"""

sns.set_context("notebook", font_scale=.9)
sns.set_style("whitegrid")

# One five-color palette per year, listed chronologically: the lightness
# (l) passed to hls_palette decreases from 2010 to 2014
palette_2010 = sns.hls_palette(n_colors=5, l=.8, s=.4)
palette_2011 = sns.hls_palette(n_colors=5, l=.7, s=.5)
palette_2012 = sns.hls_palette(n_colors=5, l=.5, s=.6)
palette_2013 = sns.color_palette("hls", n_colors=5)
palette_2014 = sns.hls_palette(n_colors=5, l=.3, s=.8)


#function for plotting the graph using stripplot
def plot_continent_graph(continent):
    ax = sns.stripplot(x='Country Name',
                       y='2010',
                       hue='Country Name',
                       data=continent,
                       palette=palette_2010,
                       jitter=True)
    ax = sns.stripplot(x='Country Name',
                       y='2011',
def plot_silhouette(results, inp='data', labels=None, axes=None,
                    size=4.6,  dpi=300,  ext='png', plot_dir=None):
    """Plot a silhouette analysis of one clustering solution.

    The left axis shows the sorted per-sample silhouette scores of every
    cluster, with the average score as a dashed red line. The right axis
    compares that average with the scores from
    ``get_constant_height_labels`` and with the average score of the
    ``'data'`` clustering.

    Parameters
    ----------
    results : object
        Must expose ``results.HCA`` with a ``results`` mapping and a
        ``get_cluster_names`` method.
    inp : str
        Key of the clustering in ``results.HCA.results``; also used in the
        output file name.
    labels : array, optional
        Cluster label per sample. Defaults to ``clustering['labels']``.
        Clusters are looked up as labels ``1 .. n_clusters``.
    axes : pair of matplotlib axes, optional
        Axes to draw on. A new two-panel figure is created when omitted.
    size : float
        Base size from which figure size, fonts and line widths derive.
    dpi : int
        Resolution used when saving.
    ext : str
        File extension used when saving.
    plot_dir : str, optional
        Directory to save the figure in. When None nothing is saved and the
        figure is left open.
    """
    HCA = results.HCA
    clustering = HCA.results[inp]
    sample_scores, avg_score = silhouette_analysis(clustering, labels)
    # raw clustering for comparison
    raw_clustering = HCA.results['data']
    _, raw_avg_score = silhouette_analysis(raw_clustering, labels)
    
    if labels is None:
        labels = clustering['labels']
    n_clusters = len(np.unique(labels))
    colors = sns.hls_palette(n_clusters)
    if axes is None:
        fig, (ax, ax2) =  plt.subplots(1, 2, figsize=(size, size*.375))
    else:
        ax, ax2 = axes
        # The figure is needed for saving below; it used to be undefined
        # on this path, raising NameError when plot_dir was given.
        fig = ax.get_figure()
    y_lower = 5
    ax.grid(False)
    ax2.grid(linewidth=size/10)
    cluster_names = HCA.get_cluster_names(inp=inp)
    for i in range(n_clusters):
        # Aggregate the silhouette scores for samples belonging to
        # cluster i, and sort them
        ith_cluster_silhouette_values = sample_scores[labels == i+1]
        # skip "clusters" with one value
        if len(ith_cluster_silhouette_values) == 1:
            continue
        ith_cluster_silhouette_values.sort()
        size_cluster_i = ith_cluster_silhouette_values.shape[0]
        # update y range and plot
        y_upper = y_lower + size_cluster_i
        ax.fill_betweenx(np.arange(y_lower, y_upper),
                          0, ith_cluster_silhouette_values,
                          alpha=0.7, color=colors[i],
                          linewidth=size/10)
        # Label the silhouette plots with their cluster names, a quarter of
        # the way up each band
        ax.text(-0.02, y_lower + 0.25 * size_cluster_i, cluster_names[i], fontsize=size/1.7, ha='right')
        # Compute the new y_lower for next plot (leave a gap of 5)
        y_lower = y_upper + 5
    ax.axvline(x=avg_score, color="red", linestyle="--", linewidth=size*.1)
    ax.set_xlabel('Silhouette score', fontsize=size, labelpad=5)
    ax.set_ylabel('Cluster Separated DVs', fontsize=size)
    ax.tick_params(pad=size/4, length=size/4, labelsize=size*.8, width=size/10,
                   left=False, labelleft=False, bottom=True)
    ax.set_title('Dynamic tree cut', fontsize=size*1.2, y=1.02)
    ax.set_xlim(-1, 1)
    # plot silhouettes for constant thresholds
    _, scores, _ = get_constant_height_labels(clustering)
    ax2.plot(*zip(*scores), 'o', color='b', 
             markeredgecolor='white', markeredgewidth=size*.1, markersize=size*.5, 
             label='Fixed Height Cut')
    # plot the dynamic tree cut point
    ax2.plot(n_clusters, avg_score, 'o', color ='r', 
             markeredgecolor='white', markeredgewidth=size*.1, markersize=size*.75, 
             label='EFA Dynamic Cut')
    ax2.plot(n_clusters, raw_avg_score, 'o', color ='k', 
             markeredgecolor='white', markeredgewidth=size*.1, markersize=size*.75, 
             label='Raw Dynamic Cut')
    ax2.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax2.set_xlabel('Number of clusters', fontsize=size)
    ax2.set_ylabel('Average Silhouette Score', fontsize=size)
    ax2.set_title('Single cut height', fontsize=size*1.2, y=1.02)
    ax2.tick_params(labelsize=size*.8, pad=size/4, length=size/4, width=size/10, bottom=True)
    ax2.legend(loc='center right', fontsize=size*.8)
    plt.subplots_adjust(wspace=.3)
    if plot_dir is not None:
        save_figure(fig, path.join(plot_dir, 
                                         'silhouette_analysis_%s.%s' % (inp, ext)),
                    {'bbox_inches': 'tight', 'dpi': dpi})
        # close exactly the figure that was saved
        plt.close(fig)
示例#56
0
def _parse_cli_options(argv):
    """Parse the command line of the k-means visualisation script.

    Args:
        argv: Full argument vector; element 0 (the program name) is skipped.

    Returns:
        Tuple ``(dataSetFile, nClusters, dataPointLimits)``.

    Exits the process with status 2 on malformed options or a missing input
    file, and with status 0 after printing the help text for ``-h``.
    """
    helpmsg = 'The script can visualise up to 10 clusters in 2D or 3D\n' + 'Usage: kmeans.py -i <dataSetFile> -k <nClusters> -l <dataPointLimit>'
    dataSetFile = ""
    nClusters = 3
    dataPointLimits = 0

    try:
        opts, _ = getopt.getopt(argv[1:], "hi:k:l:", [
            'input=',
            'kClusters=',
            'limit=',
        ])
    except getopt.GetoptError as err:
        print(str(err))
        print(helpmsg)
        sys.exit(2)

    try:
        for opt, arg in opts:
            if opt == '-h':
                print(helpmsg)
                sys.exit()
            elif opt in ("-i", "--input"):
                dataSetFile = arg
            elif opt in ("-k", "--kClusters"):
                nClusters = int(arg)
            elif opt in ("-l", "--limit"):
                dataPointLimits = int(arg)
    except ValueError as err:
        # A non-integer -k / -l value is a usage error, not a crash.
        print(str(err))
        print(helpmsg)
        sys.exit(2)

    if len(dataSetFile) < 1:
        print("No input dataset specified.")
        print(helpmsg)
        sys.exit(2)
    return dataSetFile, nClusters, dataPointLimits


def _read_dataset(dataSetFile, dataPointLimits):
    """Read the dataset file into per-feature columns.

    The first line holds ``<rowCount> <featureCount>``, the second line a
    comma-separated list of ``<name> <type>`` feature descriptions (the last
    comma-separated field is discarded), and every following line one
    comma-separated data point.

    Args:
        dataSetFile: Path of the file to read.
        dataPointLimits: Maximum number of rows to read; 0 means no limit.

    Returns:
        Tuple ``(featureList, features, featureCount, readCount)`` where
        ``featureList`` holds ``[index, name, is_number]`` entries,
        ``features[i]`` is the list of raw string values of column ``i`` and
        ``readCount`` is the number of rows actually read.
    """
    with open(dataSetFile) as f:
        rowCount, featureCount = [int(x) for x in next(f).split()]
        featureListInfo = f.readline().split(',')[:-1]
        featureList = []
        for index, info in enumerate(featureListInfo):
            parts = info.strip().split(' ')
            featureList.append([index, parts[0], parts[1] == "number"])

        if dataPointLimits != 0 and rowCount > dataPointLimits:
            rowCount = dataPointLimits

        features = [[] for _ in range(featureCount)]
        readCount = 0
        for line in f:
            if readCount >= rowCount:
                break
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing newline)
            readCount += 1
            datapoint = line.split(',')
            for column, values in enumerate(features):
                values.append(datapoint[column])
    return featureList, features, featureCount, readCount


def _choose_features(featureList):
    """Interactively let the user pick features by their listed position.

    Prompts until every feature has been chosen or a non-integer (e.g. 'q')
    is entered. Returns the chosen ``[index, name, is_number]`` entries in the
    order they were picked.
    """
    unchosen = list(featureList)  # work on a copy; the input list is untouched
    chosen = []
    print("Input feature number to choose or 'q' to finish choosing.")
    while unchosen:
        print('Available features: ', [
            str(i) + ":" + feature[1] for i, feature in enumerate(unchosen)
        ])
        userInput = input("Choice:")
        try:
            choice = int(userInput)
        except ValueError:
            print('Features chosen: ', [
                str(i) + ":" + feature[1] for i, feature in enumerate(chosen)
            ])
            break
        # Reject negative numbers too: they would otherwise index from the
        # end of the list (or raise IndexError when too small).
        if 0 <= choice < len(unchosen):
            chosen.append(unchosen.pop(choice))
        else:
            print('Index out of range')
    return chosen


def _set_encoded_ticks(ax, axis_name, encoder, encoded_values):
    """Put at most 25 ticks on one axis and label them with original values.

    ``axis_name`` is 'x', 'y' or 'z'. Tick positions are the distinct encoded
    values (sorted so the selection is deterministic); tick labels are the
    values decoded back through ``encoder``.
    """
    tick_values = sorted(set(encoded_values))
    # takespread is called once per use, as before, in case it returns a
    # one-shot iterator.
    getattr(ax, 'set_' + axis_name + 'ticks')(takespread(tick_values, 25))
    getattr(ax, 'set_' + axis_name + 'ticklabels')(
        encoder.inverse_transform(takespread(tick_values, 25)))


def main(argv):
    """Cluster a label-encoded dataset with k-means and plot the result.

    Reads the dataset named on the command line, asks the user which features
    to use, label-encodes them, runs KMeans and draws the clusters plus their
    centroids on a 3D axes. Clusters are only drawn when exactly 2 or 3
    features were chosen.

    Args:
        argv: Full argument vector (``sys.argv`` style).
    """
    dataSetFile, nClusters, dataPointLimits = _parse_cli_options(argv)

    print('Dataset: "', dataSetFile)
    print('Number of clusters: ', str(nClusters))
    print('Process files limit: ', str(dataPointLimits))

    # read data from file
    # features:
    # 0: ['Cassette', 'CD', 'CD'....]
    # 1: [1984, 1984, 2000         ....]
    # ...
    featureList, features, featureCount, rowCount = _read_dataset(
        dataSetFile, dataPointLimits)
    # rowCount is the number of rows actually read, which may be smaller than
    # the count announced in the file header.
    if rowCount == 0:
        print('No data rows read. Exiting.')
        sys.exit(1)

    chosenFeatures = _choose_features(featureList)
    if len(chosenFeatures) == 0:
        print('No features chosen. Exiting.')
        sys.exit(1)

    # encode labels
    # encodedLabels[featureIndex] is the integer-encoded column, e.g.
    # [0, 1, 0, ...]; entries of features that were not chosen stay None.
    encoders = [None] * featureCount
    encodedLabels = [None] * featureCount
    for feature in chosenFeatures:
        featureIndex = feature[0]
        encoders[featureIndex] = LabelEncoder()
        encodedLabels[featureIndex] = encoders[featureIndex].fit_transform(
            features[featureIndex])

    # One row per chosen feature for plotting, one row per sample for KMeans.
    plottable_X = np.array(
        [encodedLabels[feature[0]] for feature in chosenFeatures], dtype=float)
    arr = np.transpose(plottable_X)

    km = KMeans(n_clusters=nClusters,
                init='random',
                n_init=10,
                max_iter=300,
                tol=1e-04,
                random_state=0)
    y_km = km.fit_predict(arr)

    markers = [
        'o', 'v', '^', '<', '>', '8', 's', 'p', 'h', 'H', 'D', 'd', 'P', 'X'
    ]

    colors = sns.hls_palette(10, l=.5, s=1.0)
    random.seed(124)  # fixed seed: same colour order on every run
    random.shuffle(colors)

    # plot
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.set_title(label='Clusters:' + str(nClusters) + ' DataPoints:' +
                 str(rowCount),
                 pad=20.)

    nDims = len(chosenFeatures)
    # With any other number of features nothing is drawn and the empty axes
    # is shown, as before.
    if nDims in (2, 3):
        for axis_name, feature in zip('xyz', chosenFeatures):
            getattr(ax, 'set_' + axis_name + 'label')(feature[1], labelpad=20.)

        ax.tick_params(direction='out', grid_color='r', grid_alpha=0.5)
        ax.xaxis.set_tick_params(rotation=90)
        ax.yaxis.set_tick_params(rotation=90)

        for axis_name, feature in zip('xyz', chosenFeatures):
            _set_encoded_ticks(ax, axis_name, encoders[feature[0]],
                               encodedLabels[feature[0]])

        if nDims == 3:
            # Marker size grows with the number of samples sharing a point.
            occurrences = Counter(zip(*plottable_X))
            density = np.array([
                10 + ((2000) / rowCount) * occurrences[point]
                for point in zip(*plottable_X)
            ])

        for i in range(0, nClusters):
            in_cluster = y_km == i
            coords = {
                'xs': plottable_X[0, in_cluster],
                'ys': plottable_X[1, in_cluster],
            }
            if nDims == 3:
                coords['zs'] = plottable_X[2, in_cluster]
                # Select the sizes of this cluster's points only, so each
                # size stays attached to its own point.
                sizes = density[in_cluster]
            else:
                sizes = 50
            ax.scatter(c=colors[i % len(colors)],
                       marker=markers[i % len(markers)],
                       edgecolor='black',
                       label='cluster ' + str(i),
                       s=sizes,
                       **coords)

        # plot the centroids
        centroids = {
            'xs': km.cluster_centers_[:, 0],
            'ys': km.cluster_centers_[:, 1],
        }
        if nDims == 3:
            centroids['zs'] = km.cluster_centers_[:, 2]
        ax.scatter(s=350,
                   marker='*',
                   c='red',
                   edgecolor='black',
                   label='centroids',
                   **centroids)

    plt.legend()
    plt.grid()
    plt.show()
示例#57
0
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

# Default seaborn colour cycle (fetched for reference; not displayed below).
current_palette = sns.color_palette()
# Default colours:
#sns.palplot(current_palette)
# Circular (HLS) colours:
#sns.palplot(sns.color_palette("hls", 16))
# Circular palette with lightness and saturation adjusted.
hls_swatches = sns.hls_palette(n_colors=8, l=.3, s=.6)
sns.palplot(hls_swatches)
plt.show()
示例#58
0
def plot_logs(experiments,
              smooth_factor=0,
              share_legend=True,
              ignore_metrics=None,
              pretty_names=False,
              include_metrics=None):
    """Plot experiment histories side by side for comparison.

    Every metric found in the experiments' histories gets its own subplot,
    titled "<mode>: <key>". A metric with a single recorded value is drawn as
    a bar, a metric with several values as a line over the step axis.

    Args:
        experiments (list, Experiment): Experiment(s) to plot
        smooth_factor (float): A non-negative float representing the magnitude of gaussian smoothing to apply (zero for
            none)
        share_legend (bool): Whether to have one legend across all graphs (true) or one legend per graph (false)
        pretty_names (bool): Whether to modify the metric names in graph titles (true) or leave them alone (false)
        ignore_metrics (set): Any keys to ignore during plotting. The passed object is not modified.
        include_metrics (set): A whitelist of keys to include during plotting. If None then all will be included.
    Returns:
        The handle of the pyplot figure
    """
    experiments = to_list(experiments)

    # Copy before the in-place updates below so the caller's set is never
    # modified (to_set may hand back the very object it was given).
    ignore_keys = set(to_set(ignore_metrics)) if ignore_metrics else set()
    ignore_keys |= {'epoch', 'progress', 'total_train_steps'}
    include_keys = to_set(include_metrics) if include_metrics else None
    # TODO: epoch should be indicated on the axis (top x axis?)
    # TODO: figure out how ignore_metrics should interact with mode

    max_time = 0
    metric_keys = set()
    for experiment in experiments:
        for mode, metrics in experiment.history.items():
            for key, value in metrics.items():
                if value.keys():
                    max_time = max(max_time, max(value.keys()))
                if key in ignore_keys:
                    continue
                if include_keys and key not in include_keys:
                    ignore_keys.add(key)
                    continue
                if any(isinstance(entry, np.ndarray)
                       for entry in value.values()):
                    ignore_keys.add(key)
                    continue  # TODO: nd array not currently supported. maybe in future visualize as heat map?
                metric_keys.add("{}: {}".format(mode, key))
    # Sort the metrics alphabetically for consistency
    metric_list = sorted(metric_keys)
    num_metrics = len(metric_list)
    num_experiments = len(experiments)

    if num_metrics == 0:
        # Nothing to draw: hand back a single empty figure.
        # (plt.subplots(111) would build a grid of 111 rows.)
        return plt.subplots()[0]

    # map the metrics into an n x n grid, then remove any extra rows. Final grid will be m x n with m <= n
    num_cols = math.ceil(math.sqrt(num_metrics))
    metric_grid_location = {
        key: (idx // num_cols, idx % num_cols)
        for (idx, key) in enumerate(metric_list)
    }
    num_rows = math.ceil(num_metrics / num_cols)

    sns.set_context('paper')
    # squeeze=False always yields a 2D array of axes, so axs[row][col] works
    # for single-row and single-cell grids without re-wrapping.
    fig, axs = plt.subplots(num_rows,
                            num_cols,
                            sharex='all',
                            figsize=(4 * num_cols, 2.8 * num_rows),
                            squeeze=False)

    for metric, (row, col) in metric_grid_location.items():
        axis = axs[row][col]
        axis.set_title(
            metric if not pretty_names else prettify_metric_name(metric))
        axis.ticklabel_format(axis='y', style='sci', scilimits=(-2, 3))
        axis.grid(linestyle='--')
        for side in ('top', 'right', 'bottom', 'left'):
            axis.spines[side].set_visible(False)
        axis.tick_params(bottom=False, left=False)

    for i in range(num_cols):
        axs[num_rows - 1][i].set_xlabel('Steps')

    # some of the columns in the last row might be unused, so disable them
    # and move the x label / tick labels to the plot directly above
    last_column_idx = num_cols - (num_rows * num_cols - num_metrics) - 1
    for i in range(last_column_idx + 1, num_cols):
        axs[num_rows - 1][i].axis('off')
        axs[num_rows - 2][i].set_xlabel('Steps')
        axs[num_rows - 2][i].xaxis.set_tick_params(which='both',
                                                   labelbottom=True)

    colors = sns.hls_palette(
        n_colors=num_experiments,
        s=0.95) if num_experiments > 10 else sns.color_palette("colorblind")

    # handles / legend_labels stay index-aligned: an experiment only gets a
    # shared-legend entry when it actually drew something.
    handles = []
    legend_labels = []
    labels = []
    bar_counter = defaultdict(int)
    for (color_idx, experiment) in enumerate(experiments):
        labels.append(experiment.name)
        metrics = {
            "{}: {}".format(mode, key): val
            for mode, sub in experiment.history.items()
            for key, val in sub.items() if key not in ignore_keys
        }
        first_artist = None
        for metric, value in metrics.items():
            data = np.array(list(value.items()))
            if len(data) == 0:
                continue  # metric was registered but never recorded a value
            row, col = metric_grid_location[metric]
            axis = axs[row][col]
            if len(data) == 1:
                # A single value is shown as a bar; bars of different
                # experiments alternate left/right of the axis centre.
                y = data[0][1]
                if isinstance(y, str):
                    # Accepts integers and decimals, including single-digit
                    # numbers such as "5".
                    vals = [
                        float(x) for x in re.findall(r'\d+(?:\.\d+)?', y)
                    ]
                    if len(vals) == 1:
                        y = vals[0]
                width = max(10, max_time // 10)
                x = max_time // 2 + (2 * (bar_counter[metric] % 2) -
                                     1) * width * math.ceil(
                                         bar_counter[metric] / 2)
                ln = axis.bar(x=x,
                              height=y,
                              color=colors[color_idx],
                              label=experiment.name,
                              width=width)
                bar_counter[metric] += 1
            else:
                y = data[:, 1] if smooth_factor == 0 else gaussian_filter1d(
                    data[:, 1], sigma=smooth_factor)
                ln = axis.plot(data[:, 0],
                               y,
                               color=colors[color_idx],
                               label=experiment.name,
                               linewidth=1.5)
            if first_artist is None:
                first_artist = ln[0]
        if first_artist is not None:
            handles.append(first_artist)
            legend_labels.append(experiment.name)

    plt.tight_layout()

    if len(labels) > 1 or labels[0]:
        if share_legend and num_rows > 1:
            if last_column_idx == num_cols - 1:
                # Grid is full: put the shared legend below the plots.
                fig.subplots_adjust(bottom=0.15)
                fig.legend(handles,
                           legend_labels,
                           loc='lower center',
                           ncol=num_cols + 1)
            else:
                # Use the first unused cell of the last row for the legend.
                axs[num_rows - 1][last_column_idx + 1].legend(handles,
                                                              legend_labels,
                                                              loc='center',
                                                              fontsize='large')
        else:
            for i in range(num_rows):
                for j in range(num_cols):
                    if i == num_rows - 1 and j > last_column_idx:
                        break
                    axs[i][j].legend(loc='best', fontsize='small')
    return fig
示例#59
0
# Seed NumPy's global RNG so the random values below are reproducible.
np.random.seed(13)

# this example does 3 channel colour, so it has 3 weights
# n x n grid of random 3-component vectors, displayed as an image.
n = 8
som = np.random.rand(n,n,3)
plt.imshow(som)

#num_samples = 6#30

# TODO: plot the colours by themselves to see what they look like (seaborn?)
# Alternative data sources, kept disabled:
#data = np.zeros((num_samples,3))
#data = np.random.rand(num_samples,3)
#data = np.array(sns.color_palette())
# NOTE: this first assignment is overwritten two lines below.
palette = sns.color_palette()
#palette = sns.hls_palette(16, l=.3, s=.8)
palette = sns.hls_palette(8, l=.3, s=.8)
# One sample per palette colour; data is the palette as a NumPy array.
num_samples = len(palette)
data = np.array(palette)

#sns.palplot(data)

print(data.shape)

num_iterations = 20000 #200000

# learning rate
# NOTE(review): where const_lr, lr and map_radius are used is not visible in
# this part of the script -- confirm against the code that follows.
const_lr = .1#.001
lr = .001
map_radius = 5

lr_i = .5 # initial learning rate