示例#1
0
def generate_plot(datatype, data, ligand_resname, dont_plot_atomtypes=''):
    """Plot per-label ``datatype`` energies over a trajectory and save a PDF.

    Reads ``'<datatype>_<ligand_resname>.csv'``, leaves out the columns listed
    in *dont_plot_atomtypes*, plots the remaining columns against the row
    index (trajectory frame) and writes ``'<prefix>.pdf'``.

    Parameters
    ----------
    datatype : str
        Energy type; used in the CSV filename, the plot title and the messages.
    data
        Not used in this function body; kept for interface compatibility.
    ligand_resname : str
        Ligand residue name; used in the CSV filename and the plot title.
    dont_plot_atomtypes : str, optional
        Comma-separated atom types. A label is left out of the plot when it
        starts with one of them followed by a space. Empty entries are
        ignored. Excluded columns still count towards the reported average.

    NOTE(review): relies on the module-level names ``labels``, ``args``
    (``solute_spec``, ``namdconf``) and ``prefix``, and permanently changes
    the global ``axes.prop_cycle`` rc setting.
    """
    # Labels the user does not want in the graph because they are boring,
    # e.g. aliphatic hydrogens with a charge of +0.09. Empty atom types are
    # dropped first: ''.split(',') == [''], and the resulting prefix ' ' would
    # otherwise exclude every label that starts with a space.
    atomtypes = [atomtype for atomtype in dont_plot_atomtypes.split(',') if atomtype]
    dont_plot_labels = [l for l in labels
                        if any(l.startswith('%s ' % atomtype) for atomtype in atomtypes)]
    # sys.stderr.write produces the same text on Python 2 and 3 (unlike ``print >>``).
    sys.stderr.write('Will not plot: %s\n' % ', '.join(dont_plot_labels))

    # Try not to repeat line styles: 3 widths x 2 line styles x the current prop cycle.
    pretty_cycler = (cycler.cycler(linewidth=[0.3, 0.8, 1.4]) *
                     cycler.cycler('ls', ['-', ':']) *
                     plt.rcParams['axes.prop_cycle'])
    plt.rc('axes', prop_cycle=pretty_cycler)

    orig_d = pd.read_csv('%s_%s.csv' % (datatype, ligand_resname))
    # Sum of the per-column means (the mean total energy per frame). Reduce
    # explicitly along axis 0: np.mean(DataFrame) reduces over *both* axes in
    # pandas >= 2.0 and would give a grand mean instead.
    average_energy = orig_d.mean(axis=0).sum()
    sys.stderr.write('Average %s energy: %.2f kcal/mol\n' % (datatype, average_energy))

    # Generate plot
    d = orig_d.drop(dont_plot_labels, axis=1)
    plt.figure(figsize=(8, 4.5))  # 16:9 aspect ratio
    plt.autoscale(tight=True)
    plt.plot(d)
    plt.figtext(0.01, 0.99,
                'Average %s energy, including any atom types not plotted: %.2f kcal/mol'
                % (datatype, average_energy),
                fontsize=5, verticalalignment='top')
    plt.title('%s energy for %s with %s (%s)' % (datatype.capitalize(), ligand_resname,
                                                 args.solute_spec, args.namdconf))
    plt.xlabel('Trajectory frame')
    plt.ylabel('Energy (kcal/mol)')

    plt.legend(list(d), fontsize=4, loc='lower left', bbox_to_anchor=(1.0, 0.0))
    plt.savefig('%s.pdf' % prefix)
    sys.stderr.write('Wrote a pretty picture to %s.pdf\n' % prefix)
示例#2
0
def test_by_key_mul():
    """``by_key`` of a product cycler repeats the right-hand values per left value."""
    colours = list('rg')
    widths = [1, 2, 3]
    product_cycler = cycler(c=colours) * cycler(lw=widths)
    grouped = product_cycler.by_key()
    # 'lw' varies fastest, so its full list appears once per colour.
    expected_widths = widths * len(colours)
    assert_equal(expected_widths, grouped['lw'])
    yield _by_key_helper, product_cycler
示例#3
0
def test_creation():
    """Cyclers built from a string, a list, or another cycler are equivalent."""
    # Builders are called lazily, one per yielded check, like the original.
    builders = (
        lambda: cycler(c='rgb'),
        lambda: cycler(c=list('rgb')),
        lambda: cycler(cycler(c='rgb')),
    )
    for build in builders:
        yield _cycler_helper, build(), 3, ['c'], [['r', 'g', 'b']]
    def test_continous_color_cycle_line_and_markers(self):
        # TODO: also check cycling the product line_cycler * marker_cycler.
        colours = cycler('color', ['r', 'g', 'b'])
        markers = cycler('marker', ['o', '+', 's'])
        lines = cycler('linestyle', ['-', '--', '-.'])

        styles = markers * lines

        n_colours = len(colours)
        shared_colours = colours()

        # All three generators share one colour iterator and one set of
        # already-used (style, colour) combinations.
        seen = set()
        both = utils.ContinousColorCycle(shared_colours, n_colours, styles, seen)
        line_only = utils.ContinousColorCycle(shared_colours, n_colours, lines, seen)
        marker_only = utils.ContinousColorCycle(shared_colours, n_colours, markers,
                                                seen)

        # Order in which the generators are advanced.
        draw_order = [both, both, marker_only, line_only, both, both]
        res = tuple(dict_to_tuple(next(generator)) for generator in draw_order)
        print(str(res))
        assert res == ((('color', 'r'), ('linestyle', '-'), ('marker', 'o')),
                       (('color', 'g'), ('linestyle', '-'), ('marker', 'o')),
                       (('color', 'b'), ('marker', 'o')),
                       (('color', 'r'), ('linestyle', '-')),
                       (('color', 'b'), ('linestyle', '-'), ('marker', 'o')),
                       (('color', 'r'), ('linestyle', '--'), ('marker', 'o')))
示例#5
0
def test_invalid_input_forms():
    """``Axes.set_prop_cycle`` rejects malformed or unknown cycle specs."""
    _fig, ax = plt.subplots()
    type_or_value = (TypeError, ValueError)

    # Each call is wrapped in a lambda so it only runs inside pytest.raises.
    rejected_calls = [
        # not a cycler and not a (key, values) pair
        lambda: ax.set_prop_cycle(1),
        lambda: ax.set_prop_cycle([1, 2]),
        # values that do not validate for the given property
        lambda: ax.set_prop_cycle('color', 'fish'),
        lambda: ax.set_prop_cycle('linewidth', 1),
        lambda: ax.set_prop_cycle('linewidth', {'1': 1, '2': 2}),
        lambda: ax.set_prop_cycle(linewidth=1, color='r'),
        # unknown property names
        lambda: ax.set_prop_cycle('foobar', [1, 2]),
        lambda: ax.set_prop_cycle(foobar=[1, 2]),
        lambda: ax.set_prop_cycle(cycler(foobar=[1, 2])),
    ]
    for call in rejected_calls:
        with pytest.raises(type_or_value):
            call()

    # presumably rejected because 'c' and 'color' name the same property
    with pytest.raises(ValueError):
        ax.set_prop_cycle(cycler(color='rgb', c='cmy'))
示例#6
0
def update_prop_cycle(linewidth):
    """Set the global ``axes.prop_cycle`` to 30 colour/dash-pattern pairs.

    The 10-colour palette is repeated 3 times and the 6 dash patterns 5 times,
    giving 30 entries each; the two cyclers are zipped with ``+``. Every dash
    pattern is first passed through ``_modify_dashes_by_linewidth`` together
    with *linewidth*.
    """
    # https://github.com/vega/vega/wiki/Scales#scale-range-literals
    # ("category10" palette)
    palette = [
        '#1f77b4',
        '#ff7f0e',
        '#2ca02c',
        '#d62728',
        '#9467bd',
        '#8c564b',
        '#e377c2',
        '#7f7f7f',
        '#bcbd22',
        '#17becf',
    ]
    base_dashes = [
        (4, 1),
        (2, 1),
        (4, 1, 2, 1),
        (4, 1, 2, 1, 2, 1),
        (4, 1, 2, 1, 2, 1, 2, 1),
        (8, 1, 4, 1),
    ]
    scaled_dashes = [_modify_dashes_by_linewidth(pattern, linewidth)
                     for pattern in base_dashes * 5]

    colour_cycle = cycler('color', palette * 3)
    dash_cycle = cycler('dashes', scaled_dashes)
    plt.rc('axes', prop_cycle=colour_cycle + dash_cycle)
示例#7
0
def test_marker_cycle():
    """Colour+marker prop cycles given as a cycler or as keyword iterables."""
    # (slope, intercept, label) per line; labels name the expected style.
    specs = [(0.25, 2, 'red dot'),
             (0.45, 3, 'green star'),
             (0.65, 4, 'yellow x'),
             (0.85, 5, 'red2 dot')]

    def draw_lines(axes):
        xs = np.arange(10)
        for slope, offset, label in specs:
            axes.plot(xs, slope * xs + offset, label=label, lw=4, ms=16)
        axes.legend(loc='upper left')

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_prop_cycle(cycler('color', ['r', 'g', 'y']) +
                      cycler('marker', ['.', '*', 'x']))
    draw_lines(ax)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    # Test keyword arguments, numpy arrays, and generic iterators
    ax.set_prop_cycle(color=np.array(['r', 'g', 'y']),
                      marker=iter(['.', '*', 'x']))
    draw_lines(ax)
示例#8
0
def spiral_fermat(x_motor, y_motor, x_start, y_start, x_range, y_range, dr,
                  factor, *, dr_y=None, tilt=0.0):
    '''Absolute fermat spiral scan, centered around (x_start, y_start)

    Parameters
    ----------
    x_motor : object
        any 'setable' object (motor, temp controller, etc.)
    y_motor : object
        any 'setable' object (motor, temp controller, etc.)
    x_start : float
        x center
    y_start : float
        y center
    x_range : float
        x width of spiral
    y_range : float
        y width of spiral
    dr : float
        delta radius along the x axis.
    factor : float
        radius gets divided by this
    dr_y : float, optional, keyword-only
        Delta radius along the y axis; defaults to ``dr``. The spiral is
        stretched vertically by ``dr_y / dr``.
    tilt : float, optional, keyword-only
        Tilt angle in radians, default 0.0

    Returns
    -------
    cyc : cycler
        ``cycler(x_motor, x_points) + cycler(y_motor, y_points)`` over the
        spiral points that fall inside the (tilted) x_range by y_range box.
    '''
    dr_aspect = 1 if dr_y is None else dr_y / dr

    # golden angle, in radians
    phi = 137.508 * np.pi / 180.

    # Half-extents in the *unstretched* frame: the y bound is divided by the
    # aspect ratio, so candidate y values must be unstretched before comparing.
    half_x = x_range / 2
    half_y = y_range / (2 * dr_aspect)
    tilt_tan = np.tan(tilt + np.pi / 2.)

    x_points, y_points = [], []

    diag = np.sqrt(half_x ** 2 + half_y ** 2)
    num_rings = int((1.5 * diag / (dr / factor)) ** 2)
    for i_ring in range(1, num_rings):
        radius = np.sqrt(i_ring) * dr / factor
        angle = phi * i_ring
        x = radius * np.cos(angle)
        y = radius * np.sin(angle) * dr_aspect

        # Compare the unstretched y against the unstretched half_y; comparing
        # the stretched y limited the scan to y_range / dr_aspect in y.
        y_unstretched = y / dr_aspect
        if (abs(x - y_unstretched / tilt_tan) <= half_x) and (abs(y_unstretched) <= half_y):
            x_points.append(x_start + x)
            y_points.append(y_start + y)

    cyc = cycler(x_motor, x_points)
    cyc += cycler(y_motor, y_points)
    return cyc
示例#9
0
def test_multiply():
    """Integer multiplication repeats a cycler; it commutes with a product cycler."""
    base = cycler('c', 'rgb')
    doubled = 2 * base
    yield _cycler_helper, doubled, 6, ['c'], ['rgb' * 2]

    relabelled = cycler('ec', base)
    outer = base * relabelled

    yield _cycles_equal, 2 * outer, outer * 2
def richify_line_style(plt):
    """Apply the 'fivethirtyeight' style plus a six-step colour/linestyle/marker cycle.

    *plt* is the pyplot-like module whose ``style`` and ``rc`` are used.
    """
    plt.style.use('fivethirtyeight')
    colour_cycle = cycler('color', ['r', 'r', 'g', 'g', 'g', 'b'])
    linestyle_cycle = cycler('linestyle', ['-', '--', ':', '-.', '--', '-'])
    marker_cycle = cycler('marker', ['o', 'v', 's', '*', 'o', '*'])
    # '+' zips the three cyclers entry by entry.
    plt.rc('axes', prop_cycle=colour_cycle + linestyle_cycle + marker_cycle)
示例#11
0
def test_cycler_parent_and_parts_fail(hw, children):
    """merge_cycler is expected to reject a mix of the pseudo-positioner and its parts."""
    pseudo = hw.pseudo3x3
    per_child = (cycler(getattr(pseudo, name), range(5)) for name in children)
    combined = reduce(operator.add, per_child)
    combined += cycler(pseudo, range(5))

    with pytest.raises(ValueError):
        merge_cycler(combined)
示例#12
0
def test_concat():
    """``Cycler.concat`` and ``concat`` both behave like chaining the two cyclers."""
    left = cycler('a', range(3))
    right = cycler('a', 'abc')
    joiners = (lambda: left.concat(right),
               lambda: concat(left, right))
    for join in joiners:
        for got, want in zip(join(), chain(left, right)):
            assert_equal(got, want)
示例#13
0
def test_getitem():
    """Slicing a cycler matches slicing its underlying value list."""
    full = cycler('lw', range(15))
    widths = list(range(15))
    slices = [slice(None), slice(None, None, -1), slice(1, 5), slice(0, 5, 2)]
    for window in slices:
        yield _cycles_equal, full[window], cycler('lw', widths[window])
示例#14
0
def test_cycler_parent_and_parts_succed(hw, children):
    """merge_cycler keeps the keys and per-key values of a parent+parts cycler."""
    pseudo = hw.pseudo3x3
    per_child = (cycler(getattr(pseudo, name), range(5)) for name in children)
    combined = reduce(operator.add, per_child)
    combined += cycler(pseudo, range(5))
    merged = merge_cycler(combined)

    assert merged.keys == combined.keys
    assert merged.by_key() == combined.by_key()
示例#15
0
文件: plots.py 项目: ibell/pdsim
    def add(self, name="plot"):
        """Add a new ``Plot`` page named *name* to ``self.nb`` and return its figure.

        The page's current axes get a colour x linestyle prop cycle
        (6 colours times 3 line styles = 18 combinations).
        """
        page = Plot(self.nb)
        self.nb.AddPage(page, name)
        axes = page.figure.gca()
        colours = cycler('color', ['r', 'g', 'b', 'y', 'm', 'c'])
        linestyles = cycler('linestyle', ['-', '--', '-.'])
        axes.set_prop_cycle(colours * linestyles)

        return page.figure
示例#16
0
def test_cn():
    """'C0'/'C1' resolve to the current prop-cycle colours, incl. xkcd names."""
    cases = [
        (['blue', 'r'], '#0000ff'),
        (['xkcd:blue', 'r'], '#0343df'),
    ]
    for palette, first_hex in cases:
        matplotlib.rcParams['axes.prop_cycle'] = cycler('color', palette)
        assert mcolors.to_hex("C0") == first_hex
        assert mcolors.to_hex("C1") == '#ff0000'
示例#17
0
def test_keychange():
    """``change_key`` renames keys in place without touching the operands."""
    colours = cycler('c', 'rgb')
    widths = cycler('lw', [1, 2, 3])
    edges = cycler('ec', 'yk')

    # Single-key cycler renamed in place.
    edges.change_key('ec', 'edgecolor')
    assert_equal(edges, cycler('edgecolor', edges))

    # Renaming a key in a sum must not leak back into the operand.
    summed = colours + widths
    summed.change_key('lw', 'linewidth')
    assert_equal(widths, cycler('lw', [1, 2, 3]))
    assert_equal(summed, colours + cycler('linewidth', widths))

    nested = (colours + widths) * edges
    nested.change_key('c', 'color')
    assert_equal(colours, cycler('c', 'rgb'))
    assert_equal(nested, (cycler('color', colours) + widths) * edges)

    # Renaming a key to itself is a no-op.
    nested.change_key('color', 'color')
    assert_equal(nested, (cycler('color', colours) + widths) * edges)

    # Target key already present -> ValueError; missing source key -> KeyError.
    assert_raises(ValueError, Cycler.change_key, nested, 'color', 'lw')
    assert_raises(KeyError, Cycler.change_key, nested, 'c', 'foobar')
示例#18
0
    def make_diag_figure(self, xnames, ynames):
        """Build a grid of diagnostic panels overlaying every result.

        Panel ``i`` plots ``ynames[i]`` against ``xnames[i]`` for each entry of
        ``self.results`` via ``self._add_log_offset_plot``. Each result gets its
        own marker/colour combination and is labelled with its ``PLATEIFU``
        header value. A figure-level legend is built from the first panel.

        Parameters
        ----------
        xnames, ynames : sequence
            Quantities for the x and y axes; panel ``i`` uses ``xnames[i]``
            and ``ynames[i]``.

        Returns
        -------
        fig
            The figure created by ``ftools.gen_gridspec_fig``.
        """
        nobj = len(xnames)

        # initialize subplot size
        gs, fig = ftools.gen_gridspec_fig(
            nobj, add_row=False, border=(0.6, 0.6, 0.2, 0.4),
            space=(0.6, 0.35))

        # Enumerate grid cells. Assuming iproduct is itertools.product, this
        # is row-major order, i.e. the column index changes fastest.
        gs_geo = gs.get_geometry()
        fgrid_r, fgrid_c = tuple(list(range(n)) for n in gs_geo)
        gs_iter = iproduct(fgrid_r, fgrid_c)

        # Set up kwargs for matplotlib errorbar. In a cycler product the
        # right-hand factor varies fastest, so colour changes first, then
        # marker. A Cycler is re-iterable and yields fresh dicts on every
        # pass, so it is built once and restarts from the top in each panel.
        colors_ = ['C{}'.format(cn) for cn in range(10)]
        markers_ = ['o', 'v', 's', 'P', 'X', 'D', 'H']
        kwarg_cycler = cycler(marker=markers_) * cycler(c=colors_)

        for (i, (ri, ci)) in enumerate(gs_iter):

            # The grid can have more cells than quantities. i only grows, so
            # every remaining cell is surplus too.
            if i >= nobj:
                break

            # choose axis
            ax = fig.add_subplot(gs[ri, ci])

            xqty = xnames[i]
            yqty = ynames[i]

            # Now iterate through the results' HDU lists.
            # NOTE(review): zip stops after len(kwarg_cycler) == 70 results;
            # any further results are silently not plotted.
            for (j, (result, kwargs)) in enumerate(
                    zip(self.results, kwarg_cycler)):

                kwargs['label'] = result[0].header['PLATEIFU']

                ax = self._add_log_offset_plot(
                    j, xqty=xqty, yqty=yqty, ax=ax, **kwargs)

                ax.tick_params(labelsize=5)

            if i == 0:
                handles_, labels_ = ax.get_legend_handles_labels()
                plt.figlegend(
                    handles=handles_, labels=labels_,
                    loc='upper right', prop={'size': 4.})

        fig.suptitle('PCA fitting diagnostics', size=8.)

        return fig
示例#19
0
def test_failures():
    """Combining cyclers with overlapping keys raises ValueError."""
    base = cycler('c', 'rgb')
    same_key = cycler('c', base)
    # Every arithmetic combination of two 'c'-keyed cyclers is rejected.
    for operation in (add, iadd, mul, imul):
        assert_raises(ValueError, operation, base, same_key)

    other_key = cycler('ec', base)

    # Re-labelling a two-key cycler under a single key must raise.
    assert_raises(ValueError, cycler, 'c', same_key + other_key)
示例#20
0
def test_inplace():
    """In-place ``+=`` zips and ``*=`` takes the outer product."""
    colours = cycler('c', 'rgb')
    zipped = cycler('lw', range(3))
    zipped += colours
    yield _cycler_helper, zipped, 3, ['c', 'lw'], [list('rgb'), range(3)]

    outer = cycler('c', 'rgb')
    widths = cycler('lw', range(3))
    outer *= widths
    expected_columns = zip(*product(list('rgb'), range(3)))
    yield (_cycler_helper, outer, 9, ['c', 'lw'], expected_columns)
示例#21
0
def spiral(x_motor, y_motor, x_start, y_start, x_range, y_range, dr, nth, *,
           tilt=0.0):
    '''Spiral scan, centered around (x_start, y_start)

    Concentric rings of radius ``k * dr`` (k = 1, 2, ...); ring ``k`` is
    sampled at ``int(k * nth)`` evenly spaced angles. Only points inside the
    (tilted) x_range by y_range box are kept.

    Parameters
    ----------
    x_motor : object
        any 'setable' object (motor, temp controller, etc.)
    y_motor : object
        any 'setable' object (motor, temp controller, etc.)
    x_start : float
        x center
    y_start : float
        y center
    x_range : float
        x width of spiral
    y_range : float
        y width of spiral
    dr : float
        Delta radius
    nth : float
        Number of theta steps on the first ring (ring k uses k * nth)
    tilt : float, optional, keyword-only
        Tilt angle in radians, default 0.0

    Returns
    -------
    cyc : cycler
        ``cycler(x_motor, x_points) + cycler(y_motor, y_points)``.
    '''
    half_x = x_range / 2
    half_y = y_range / 2
    tilt_tan = np.tan(tilt + np.pi / 2.)

    outer_radius = np.sqrt(half_x ** 2 + half_y ** 2)
    ring_count = 1 + int(outer_radius / dr)

    xs, ys = [], []
    for ring in range(1, ring_count + 2):
        r = ring * dr
        n_angles = ring * nth
        step = 2. * np.pi / n_angles
        for k in range(int(n_angles)):
            theta = k * step
            dx = r * np.cos(theta)
            dy = r * np.sin(theta)
            within_x = abs(dx - dy / tilt_tan) <= half_x
            if within_x and abs(dy) <= half_y:
                xs.append(x_start + dx)
                ys.append(y_start + dy)

    cyc = cycler(x_motor, xs)
    cyc += cycler(y_motor, ys)
    return cyc
示例#22
0
def test_eq():
    """Cycler equality: addition order is irrelevant, product order matters."""
    first = cycler(c='rgb')
    second = cycler(c='rgb')
    yield _eq_test_helper, first, second, True
    yield _eq_test_helper, first, second[::-1], False

    widths = cycler(lw=range(3))
    cases = (
        (first + widths, widths + first, True),
        (first + widths, widths + second, True),
        (first * widths, widths * first, False),
        (first, widths, False),
    )
    for lhs, rhs, expected in cases:
        yield _eq_test_helper, lhs, rhs, expected

    other_colours = cycler(c='ymk')
    yield _eq_test_helper, second, other_colours, False
示例#23
0
def multi_sample_edge(*, edge_list=None, sample_list=None):
    """Generator plan: run ``edge_ascan`` for every (edge, sample) pair, then close valves.

    Defaults to every key of ``SAMPLE_MAP`` / ``EDGE_MAP``. Pairs are taken
    from ``cycler('edge', ...) * cycler('sample_name', ...)``, so all samples
    are visited for one edge before moving to the next. Pairs for which
    ``pass_filter`` is falsy are skipped.
    """
    samples = list(SAMPLE_MAP) if sample_list is None else sample_list
    edges = list(EDGE_MAP) if edge_list is None else edge_list
    # Edges are scanned in the given order; to sort them instead use
    # sorted(edges, key=lambda k: EDGE_MAP[k]['start'])
    grid = cycler('edge', edges) * cycler('sample_name', samples)
    for point in grid:
        if not pass_filter(**point):
            continue
        yield from edge_ascan(**point)
    yield from bps.abs_set(valve_diag3_close, 1)
    yield from bps.abs_set(valve_mir3_close, 1)
示例#24
0
def test_failures():
    """Overlapping keys raise ValueError; non-cycler operands raise TypeError."""
    base = cycler(c='rgb')
    same_key = cycler(c=base)
    for operation in (add, iadd, mul, imul):
        assert_raises(ValueError, operation, base, same_key)
    # In-place operators only accept another cycler.
    for operation in (iadd, imul):
        assert_raises(TypeError, operation, same_key, 'aardvark')

    other_key = cycler(ec=base)

    # Re-labelling a two-key cycler under a single key must raise.
    assert_raises(ValueError, cycler, c=same_key + other_key)
示例#25
0
def test_marker_cycle():
    """Four lines cycle through a 3-entry colour+marker prop cycle."""
    fig = plt.figure()
    ax = fig.add_subplot(111)
    colour_marker = cycler("color", ["r", "g", "y"]) + cycler("marker", [".", "*", "x"])
    ax.set_prop_cycle(colour_marker)
    xs = np.arange(10)
    # (slope, intercept, label); the 4th line wraps back to the first style.
    for slope, offset, label in ((0.25, 2, "red dot"),
                                 (0.45, 3, "green star"),
                                 (0.65, 4, "yellow x"),
                                 (0.85, 5, "red2 dot")):
        ax.plot(xs, slope * xs + offset, label=label, lw=4, ms=16)
    ax.legend(loc="upper left")
示例#26
0
def test_cn():
    """'C0'/'C1' resolve through the current prop cycle (legacy converter API)."""
    def as_hex(spec):
        return mcolors.rgb2hex(mcolors.colorConverter.to_rgb(spec))

    for palette, first_hex in ((['blue', 'r'], '#0000ff'),
                               (['XKCDblue', 'r'], '#0343df')):
        matplotlib.rcParams['axes.prop_cycle'] = cycler('color', palette)
        assert as_hex('C0') == first_hex
        assert as_hex('C1') == '#ff0000'
示例#27
0
def test_marker_cycle():
    """Four lines cycle through a 3-entry colour+marker prop cycle."""
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_prop_cycle(cycler('color', ['r', 'g', 'y']) +
                      cycler('marker', ['.', '*', 'x']))
    xs = np.arange(10)
    line_specs = [(0.25, 2, 'red dot'),
                  (0.45, 3, 'green star'),
                  (0.65, 4, 'yellow x'),
                  (0.85, 5, 'red2 dot')]
    for slope, offset, label in line_specs:
        ys = slope * xs + offset
        ax.plot(xs, ys, label=label, lw=4, ms=16)
    ax.legend(loc='upper left')
示例#28
0
def test_repr():
    """Text and HTML reprs of sum and product cyclers."""
    colours = cycler('c', 'rgb')
    widths = cycler('lw', range(3))

    sum_repr = "(cycler('c', ['r', 'g', 'b']) + cycler('lw', [0, 1, 2]))"
    prod_repr = "(cycler('c', ['r', 'g', 'b']) * cycler('lw', [0, 1, 2]))"
    sum_html = "<table><th>'c'</th><th>'lw'</th><tr><td>'r'</td><td>0</td></tr><tr><td>'g'</td><td>1</td></tr><tr><td>'b'</td><td>2</td></tr></table>"
    prod_html = "<table><th>'c'</th><th>'lw'</th><tr><td>'r'</td><td>0</td></tr><tr><td>'r'</td><td>1</td></tr><tr><td>'r'</td><td>2</td></tr><tr><td>'g'</td><td>0</td></tr><tr><td>'g'</td><td>1</td></tr><tr><td>'g'</td><td>2</td></tr><tr><td>'b'</td><td>0</td></tr><tr><td>'b'</td><td>1</td></tr><tr><td>'b'</td><td>2</td></tr></table>"

    checks = [
        ('__repr__', colours + widths, sum_repr),
        ('__repr__', colours * widths, prod_repr),
        ('_repr_html_', colours + widths, sum_html),
        ('_repr_html_', colours * widths, prod_html),
    ]
    for method_name, subject, expected in checks:
        yield _repr_tester_helper, method_name, subject, expected
def add_to_rc_defaultParams():
    """
    Register the custom ``axes.midv_*`` and ``legend.midv_ncol`` rc parameters.

    Adds each parameter (default value + validator) to
    ``rcsetup.defaultParams``, rebuilds ``RcParams.validate`` from it,
    records the defaults in ``rcParamsDefault`` and finally calls
    ``mpl.rcdefaults()``.

    :return: None
    """
    new_params = {
        'axes.midv_line_cycle': [cycler('linestyle', ['-', '--', '-.', ':']), rcsetup.validate_cycler],
        'axes.midv_marker_cycle': [cycler('marker', ['o', '+', 's', 'x']), rcsetup.validate_cycler],
        'legend.midv_ncol': [1, rcsetup.validate_int],
    }

    rcsetup.defaultParams.update(new_params)
    # Validator table keyed by rc name (presumably consulted when rc keys are set).
    mpl.RcParams.validate = {
        key: converter
        for key, (_default, converter) in six.iteritems(rcsetup.defaultParams)
    }
    new_defaults = {key: spec[0] for key, spec in new_params.items()}
    mpl.rcParamsDefault.update(new_defaults)
    mpl.rcdefaults()
示例#30
0
def test_fillcycle_basic():
    """Filled polygons cycle through colour, hatch and linestyle together."""
    fig, ax = plt.subplots()
    ax.set_prop_cycle(cycler('c',  ['r', 'g', 'y']) +
                      cycler('hatch', ['xx', 'O', '|-']) +
                      cycler('linestyle', ['-', '--', ':']))
    xs = np.arange(10)
    # (scale, offset, label); the 4th fill wraps back to the first style.
    fills = [(0.25, 2, 'red, xx'),
             (0.45, 3, 'green, circle'),
             (0.65, 4, 'yellow, cross'),
             (0.85, 5, 'red2, xx')]
    for scale, offset, label in fills:
        ax.fill(xs, scale * xs**.5 + offset, label=label, linewidth=3)
    ax.legend(loc='upper left')
示例#31
0
# print(population[:,1])
#
# x = range(0, 1001)
# y = population[:,1]
# p = plt.plot(x, y, "o")

num_plots = num_species

# Have a look at the colormaps here and decide which one you'd like:
# https://matplotlib.org/stable/gallery/color/colormap_reference.html
from cycler import cycler

# One colour per species, evenly spaced along the colormap. The old name
# ``plt.cm.spectral`` was removed in Matplotlib 2.2; ``nipy_spectral`` is the
# same colormap under its current name.
colors = [plt.cm.nipy_spectral(i) for i in np.linspace(0, 1, num_species)]

colormap = plt.cm.gist_ncar  # NOTE(review): not used in the visible code

# Plot several different functions...
x = np.arange(num_time + 1)
labels = []

# The colour cycle must be set on each subplot after it is created. Setting it
# on plt.gca() before the subplots exist only styles a throwaway default Axes.
plt.subplot(2, 2, 1)
plt.gca().set_prop_cycle(cycler('color', colors))
for i in range(1, num_plots + 1):
    plt.plot(x, trait_N1[:, i - 1])

plt.subplot(2, 2, 2)
plt.gca().set_prop_cycle(cycler('color', colors))
for i in range(1, num_plots + 1):
    plt.plot(x, population_N1[:, i - 1])
plt.subplot(2, 2, 3)
sns.distplot(trait_N1[num_time, :], hist=False, rug=True)

plt.subplot(2, 2, 4)
示例#32
0
def set_rcParams_scanpy(fontsize=12, color_map=None, frameon=None):
    """Set matplotlib.rcParams to Scanpy defaults.

    Parameters
    ----------
    fontsize : float, optional
        Base font size (default 12). Used for general text, titles, axis
        labels and tick labels; the legend uses ``0.92 * fontsize``.
    color_map : str, optional
        Value for ``image.cmap``; ``None`` keeps the currently configured one.
    frameon : bool, optional
        Stored in the module-level ``_frameon`` flag; ``None`` is treated as
        ``True``.
    """

    # dpi options
    rcParams['figure.dpi'] = 100
    rcParams['savefig.dpi'] = 150

    # figure
    rcParams['figure.figsize'] = (4, 4)
    rcParams['figure.subplot.left'] = 0.18
    rcParams['figure.subplot.right'] = 0.96
    rcParams['figure.subplot.bottom'] = 0.15
    rcParams['figure.subplot.top'] = 0.91

    # plotted lines and their markers (the frame width is axes.linewidth below)
    rcParams['lines.linewidth'] = 1.5
    rcParams['lines.markersize'] = 6
    rcParams['lines.markeredgewidth'] = 1

    # font
    rcParams['font.sans-serif'] = [
        'Arial', 'Helvetica', 'DejaVu Sans',
        'Bitstream Vera Sans', 'sans-serif']

    labelsize = 0.92 * fontsize

    rcParams['font.size'] = fontsize
    rcParams['legend.fontsize'] = labelsize
    rcParams['axes.titlesize'] = fontsize
    rcParams['axes.labelsize'] = fontsize

    # legend
    rcParams['legend.numpoints'] = 1
    rcParams['legend.scatterpoints'] = 1
    rcParams['legend.handlelength'] = 0.5
    rcParams['legend.handletextpad'] = 0.4

    # color cycle
    rcParams['axes.prop_cycle'] = cycler(color=vega_20)

    # axes frame and background
    rcParams['axes.linewidth'] = 0.8
    rcParams['axes.edgecolor'] = 'black'
    rcParams['axes.facecolor'] = 'white'

    # ticks
    rcParams['xtick.color'] = 'k'
    rcParams['ytick.color'] = 'k'
    rcParams['xtick.labelsize'] = fontsize
    rcParams['ytick.labelsize'] = fontsize

    # axes grid
    rcParams['axes.grid'] = True
    rcParams['grid.color'] = '.8'

    # color map
    rcParams['image.cmap'] = rcParams['image.cmap'] if color_map is None else color_map

    # frame
    frameon = True if frameon is None else frameon
    global _frameon
    _frameon = frameon
示例#33
0
  ax.xaxis.set_major_formatter(x_majorFormatter)
  ax.xaxis.set_minor_locator(x_minorLocator)
  
  ax.yaxis.set_major_locator(y_majorLocator)
  ax.yaxis.set_major_formatter(y_majorFormatter)
  ax.yaxis.set_minor_locator(y_minorLocator)
  
  ax.tick_params(which='major',direction='in', length=8, width=1  , bottom=True, top=True, left=True, right=True ,labelbottom=True, labeltop=False, labelleft=True, labelright=False,)
  ax.tick_params(which='minor',direction='in', length=4, width=0.5, bottom=True, top=True, left=True, right=True ,labelbottom=True, labeltop=False, labelleft=True, labelright=False,)
  return

fig, ax = plt.subplots(figsize=(4, 4))

leg_text = []

# NOTE(review): line_cycler and marker_cycler are defined but never applied
# to ``ax`` in the visible code; confirm whether later code uses them.
line_cycler = (cycler('color', ['k', 'r', 'g', 'b', 'm', 'c']) *
               cycler('lw', [0.5]) * cycler('linestyle', ['-']))

marker_cycler = cycler('marker', ['o', 'd', 's', 'X', 'x', 'v', '^', '<', '>', 'p', '+', 'P'])

data = np.loadtxt('Errors_C1L1.dat')

# Column 0 holds Delta X (see the axis and reference-line labels); the
# errors are plotted against its inverse.
dxm1 = 1. / data[:, 0]

ax.loglog(dxm1, data[:, 3], '-o', lw=0.75, label='normals')
ax.loglog(dxm1, data[:, 5], '-s', lw=0.75, label='mean curvature')
ax.loglog(dxm1, data[:, 11], '-d', lw=0.75, label='membrane force')
# Reference lines proportional to Delta X and Delta X^2. Raw strings keep the
# TeX backslash literal ('\D' is an invalid escape in a normal string literal).
ax.loglog(dxm1, 3 * data[:, 0] ** 1, 'k:', lw=0.9, label=r'$\Delta X$')
ax.loglog(dxm1, 50 * data[:, 0] ** 2, 'k--', lw=0.9, label=r'$\Delta X^{2}$')

ax.set_xlabel(r'$\Delta X^{-1}$')
示例#34
0
        (1. - 0.99 * outdoorFraction['outdoorFraction'].resample('W').mean()) *
        10
    ],
    axis=1)

# disease import:
# rr3.at[pd.to_datetime("2020-02-24"),'diseaseImport'] = 0.9
# ...
# rr3['diseaseImport'].interpolate(inplace=True)

pyplot.close('all')
pyplot.rcParams['figure.figsize'] = [12, 5]

# Per-column styles for rr3.plot. The three cyclers are zipped with '+', so
# plotted column i gets colour i, linestyle i and marker i (11 entries each;
# the lists must stay the same length). An empty linestyle draws no line and
# an empty marker draws no marker.
# NOTE(review): columns 1, 2, 3 and 5 (0-based) get neither, so they are
# effectively invisible; confirm this is intended.
default_cycler = (
    cycler(color=[
        'r', 'g', 'b', 'y', 'purple', 'purple', 'orange', 'cyan', 'brown', 'r',
        'orange'
    ]) + cycler(linestyle=['', '', '', '', '', '', '-', '', '', '-', '']) +
    cycler(marker=['.', '', '', '', '.', '', '', '.', 'o', '', '.']))
# Global rc setting: applies to Axes created from here on.
pyplot.rc('axes', prop_cycle=default_cycler)
# presumably rr3 is a pandas DataFrame; plot every column on a log y axis
axes = rr3.plot(logy=True, grid=True, legend=None)
axes.set_ylim(0.9, 4000)
axes.set_xlim(pd.to_datetime('2020-02-10'), pd.to_datetime('2020-12-01'))
# Error bars of 3 * sqrt(N) on the cumulative symptomatic count
# (presumably 3-sigma counting errors; confirm).
pyplot.errorbar(rr3.index,
                rr3['nShowingSymptomsCumulative'],
                yerr=3 * np.sqrt(rr3['nShowingSymptomsCumulative']))

# pyplot.axvline(pd.to_datetime('2020-03-10'), color='gray', linestyle=':', lw=1)
# pyplot.axvline(pd.to_datetime('2020-03-17'), color='gray', linestyle=':', lw=1)
# pyplot.axvline(pd.to_datetime('2020-03-22'), color='gray', linestyle=':', lw=1)
# pyplot.axhline(32,color='gray',linestyle='dotted')
from cycler import cycler
import numpy as np
import sliceplots
from matplotlib import pyplot
from collections import defaultdict

# Six colours paired with six line styles; the tuple entries are
# (offset, (on, off)) dash specifications.
line_colors = ["C1", "C2", "C3", "C4", "C5", "C6"]
line_styles = ["-", "--", ":", "-.", (0, (1, 10)), (0, (5, 10))]

# '+' zips the two cyclers entry by entry (both lists must be the same length).
cyl = cycler(color=line_colors) + cycler(linestyle=line_styles)

# Calling a Cycler returns an endless iterator that wraps around after the last entry.
loop_cy_iter = cyl()

# STYLE[key] takes the next {color, linestyle} dict the first time a key is
# looked up and returns that same dict for the key afterwards.
STYLE = defaultdict(lambda: next(loop_cy_iter))


class Peak:
    """Bookkeeping record for one peak of a sequence, tracked by index."""

    def __init__(self, startidx):
        # The peak starts at a single index. left/right presumably track the
        # extent of its region and are updated by the caller.
        self.born = startidx
        self.left = startidx
        self.right = startidx
        # Index at which the peak died; None while it is still alive.
        self.died = None

    def get_persistence(self, seq):
        """Return ``seq[born] - seq[died]``, or infinity if the peak has not died."""
        if self.died is not None:
            return seq[self.born] - seq[self.died]
        return float("inf")


def get_persistent_homology(seq):
    peaks = []
    # Maps indices to peaks
    idxtopeak = [None for s in seq]
    # Sequence indices sorted by values
    indices = range(len(seq))
示例#36
0
def plot_iso(isotherms,
             ax=None,
             x_data='pressure',
             y1_data='loading',
             y2_data=None,
             branch="all",
             logx=False,
             logy1=False,
             logy2=False,
             color=True,
             marker=None,
             adsorbent_basis="mass",
             adsorbent_unit="g",
             loading_basis="molar",
             loading_unit="mmol",
             pressure_mode="absolute",
             pressure_unit="bar",
             x_range=(None, None),
             y1_range=(None, None),
             y2_range=(None, None),
             fig_title=None,
             lgd_keys=None,
             lgd_pos='best',
             save_path=None,
             **other_parameters):
    """
    Plot the isotherm(s) provided on a single graph.

    Parameters
    ----------
    isotherms : PointIsotherm or list of PointIsotherms
        An isotherm or iterable of isotherms to be plotted.
    ax : matplotlib axes object, default None
        The axes object where to plot the graph if a new figure is
        not desired.

    x_data : str
        Key of data to plot on the x axis. Defaults to 'pressure'.
    y1_data : str
        Key of data to plot on the left y axis. Defaults to 'loading'.
    y2_data : str
        Key of data to plot on the right y axis. Defaults to None.
    branch : str
        Which branch to display, adsorption ('ads'), desorption ('des'),
        both ('all') or both with a single legend entry ('all-nol').
    logx : bool
        Whether the graph x axis should be logarithmic.
    logy1 : bool
        Whether the graph y1 axis should be logarithmic.
    logy2 : bool
        Whether the graph y2 axis should be logarithmic.

    color : bool, int, list, optional
        If a boolean, the option controls if the graph is coloured or
        grayscale. Grayscale graphs are usually preferred for publications
        or print media. If an int, it will be the number of colours the
        colourspace is divided into. If a list of matplotlib colour names
        or values, it will be passed directly to the plot function.
    marker : bool, int, list, optional
        Whether the graph should contain different markers.
        Defaults to ``True`` when not given (None). With ``marker=True``
        every marker is combined with every colour.
        If an int, it will be the number of markers used (capped at the
        number of available markers).
        If a list of matplotlib markers,
        it will be passed directly to the plot function.

    adsorbent_basis : str, optional
        Whether the adsorption is read in terms of either 'per volume'
        or 'per mass'.
    adsorbent_unit : str, optional
        Unit of loading.
    loading_basis : str, optional
        Loading basis.
    loading_unit : str, optional
        Unit of loading.
    pressure_mode : str, optional
        The pressure mode, either absolute pressures or relative in
        the form of p/p0.
    pressure_unit : str, optional
        Unit of pressure.

    x_range : tuple
        Range for data on the x axis. eg: (0, 1). Is applied to each
        isotherm, in the unit/mode/basis requested.
    y1_range : tuple
        Range for data on the regular y axis. eg: (0, 1). Is applied to each
        isotherm, in the unit/mode/basis requested.
    y2_range : tuple
        Range for data on the secondary y axis. eg: (0, 1). Is applied to each
        isotherm, in the unit/mode/basis requested.

    fig_title : str
        Title of the graph. Defaults to none.
    lgd_keys : iterable
        The components of the isotherm which are displayed on the legend. For example
        pass ['material', 'material_batch'] to have the legend labels display only these
        two components. Works with any isotherm properties and with 'branch' and 'key',
        the isotherm branch and the y-axis key respectively.
        Defaults to 'material' and 'adsorbate'.
    lgd_pos : [None, 'best', 'bottom', 'right', 'inner']
        Specify to have the legend position to the bottom, the right of the graph
        or inside the plot area itself. Defaults to 'best'.

    save_path : str, optional
        Whether to save the graph or not.
        If a path is provided, then that is where the graph will be saved.

    Other Parameters
    ----------------
    fig_style : dict
        A dictionary that will be passed into the matplotlib figure()
        function.

    title_style : dict
        A dictionary that will be passed into the matplotlib set_title() function.

    label_style : dict
        A dictionary that will be passed into the matplotlib set_label() function.

    y1_line_style : dict
        A dictionary that will be passed into the matplotlib plot() function.
        Applicable for left axis.

    y2_line_style : dict
        A dictionary that will be passed into the matplotlib plot() function.
        Applicable for right axis.

    tick_style : dict
        A dictionary that will be passed into the matplotlib tick_params() function.

    legend_style : dict
        A dictionary that will be passed into the matplotlib legend() function.

    save_style : dict
        A dictionary that will be passed into the matplotlib savefig() function
        if the saving of the figure is selected.

    Returns
    -------
    axes : matplotlib.axes.Axes or list of them
        ``ax1`` alone, or ``[ax1, ax2]`` when a secondary axis was created.

    Raises
    ------
    ParameterError
        If ``x_data``/``y1_data`` or ``branch`` is None, if ``color`` or
        ``marker`` has an unsupported type, or if a pressure axis label is
        needed with an unknown ``pressure_mode``.
    GraphingError
        If a requested data key is missing or ``branch`` is not valid.
    """
    #######################################
    #
    # Initial checks
    # Make iterable if not already
    if not isinstance(isotherms, abc.Iterable):
        isotherms = [isotherms]

    # Check for plot type validity
    if None in [x_data, y1_data]:
        raise ParameterError("Specify a plot type to graph"
                             " e.g. x_data=\'loading\', y1_data=\'pressure\'")

    # Check if required keys are present in isotherms
    def keys(iso):
        """Return every data key ``iso`` can be plotted against."""
        ks = ['loading', 'pressure']
        ks.extend(getattr(iso, 'other_keys', []))
        return ks

    # x and y1 data must be present in every isotherm
    if any(x_data not in keys(isotherm) for isotherm in isotherms):
        raise GraphingError(
            "None of the isotherms supplied have {} data".format(x_data))

    if any(y1_data not in keys(isotherm) for isotherm in isotherms):
        raise GraphingError(
            "None of the isotherms supplied have {} data".format(y1_data))

    # y2 data only needs to be present in at least one isotherm
    if y2_data is not None:
        if all(y2_data not in keys(isotherm) for isotherm in isotherms):
            raise GraphingError(
                "None of the isotherms supplied have {} data".format(y2_data))
        elif any(y2_data not in keys(isotherm) for isotherm in isotherms):
            warnings.warn('Some isotherms do not have {} data'.format(y2_data))

    # Store which branches will be displayed
    if branch is None:
        raise ParameterError("Specify a branch to display"
                             " e.g. branch=\'ads\'")
    if branch not in _BRANCH_TYPES:
        raise GraphingError("The supplied branch type is not valid."
                            "Viable types are {}".format(_BRANCH_TYPES))

    ads = False
    des = False
    if branch == 'ads':
        ads = True
    elif branch == 'des':
        des = True
    else:
        ads = True
        des = True

    log_params = dict(logx=logx, logy1=logy1, logy2=logy2)
    range_params = dict(x_range=x_range, y1_range=y1_range, y2_range=y2_range)
    iso_params = dict(
        pressure_mode=pressure_mode,
        pressure_unit=pressure_unit,
        loading_basis=loading_basis,
        loading_unit=loading_unit,
        adsorbent_basis=adsorbent_basis,
        adsorbent_unit=adsorbent_unit,
    )

    #######################################
    #
    # Settings and graph generation

    # Create style dictionaries and get user defined ones
    styles = copy.deepcopy(ISO_STYLES)

    # Overwrite with any user provided styles
    for style in styles:
        new_style = other_parameters.get(style)
        if new_style:
            styles[style].update(new_style)

    #
    # Generate the graph itself
    if ax:
        ax1 = ax
        fig = ax1.get_figure()
    else:
        fig = plt.figure(**styles['fig_style'])
        ax1 = plt.subplot(111)

    # Create empty axes object
    ax2 = None
    # Populate it if required
    if y2_data:
        ax2 = ax1.twinx()

    # Build the name of the axes
    def get_name(key):
        """Return the axis label text for data ``key``."""
        if key == 'pressure':
            if pressure_mode == "absolute":
                text = 'Pressure ($' + pressure_unit + '$)'
            elif pressure_mode == "relative":
                text = "$p/p^0$"
            else:
                # Without this branch ``text`` would be unbound below.
                raise ParameterError(
                    "Unknown pressure_mode {}, expected 'absolute' "
                    "or 'relative'.".format(pressure_mode))
        elif key == 'loading':
            text = 'Loading ($' + loading_unit + '/' + adsorbent_unit + '$)'
        elif key == 'enthalpy':
            text = r'$\Delta_{ads}h$ $(-kJ\/mol^{-1})$'
        else:
            text = key
        return text

    x_label = get_name(x_data)
    y1_label = get_name(y1_data)
    if y2_data:
        y2_label = get_name(y2_data)

    # Get a cycling style for the graph
    if color:
        if isinstance(color, bool):
            colors = [cm.jet(x) for x in numpy.linspace(0, 1, 7)]
        elif isinstance(color, int):
            colors = [cm.jet(x) for x in numpy.linspace(0, 1, color)]
        elif isinstance(color, list):
            colors = color
        else:
            raise ParameterError("Unknown ``color`` parameter type.")

        color_cy = cycler('color', colors)

    else:
        color_cy = cycler('color', ['black', 'grey', 'silver'])

    all_markers = ['o', 's', 'D', 'P', '*', '<', '>', 'X', 'v', '^']
    if marker is None:
        marker = True

    # ``cycle_compose`` selects zipping (extend_cycle) of markers and colours;
    # with ``marker=True`` the full product is used instead.
    cycle_compose = True
    if isinstance(marker, bool):
        if marker:
            cycle_compose = False
            markers = all_markers
        else:
            markers = []
    elif isinstance(marker, int):
        marker = len(all_markers) if marker > len(all_markers) else marker
        markers = all_markers[:marker]
    elif isinstance(marker, list):
        markers = marker
    else:
        raise ParameterError("Unknown ``marker`` parameter type.")

    # The secondary axis walks the markers in reverse to stay distinguishable
    y1_marker_cy = cycler('marker', markers)
    y2_marker_cy = cycler('marker', markers[::-1])

    def extend_cycle(cy_1, cy_2):
        """Zip two cyclers, repeating the shorter one to the longer length."""
        l_1 = len(cy_1)
        l_2 = len(cy_2)
        if l_1 == 0:
            return cycle(cy_2)
        if l_2 == 0:
            return cycle(cy_1)
        if l_1 > l_2:
            return cycle(cy_1 + (cy_2 * math.ceil(l_1 / l_2))[:l_1])
        return cycle(cy_2 + (cy_1 * math.ceil(l_2 / l_1))[:l_2])

    if cycle_compose:
        pc_primary = extend_cycle(y1_marker_cy, color_cy)
        pc_secondary = extend_cycle(y2_marker_cy, color_cy)
    else:
        pc_primary = cycle(y1_marker_cy * color_cy)
        pc_secondary = cycle(y2_marker_cy * color_cy)

    # Put grid on plot
    ax1.grid(True, zorder=5)

    # Graph title
    if fig_title is None:
        fig_title = ''
    ax1.set_title(fig_title, **styles['title_style'])

    # Graph legend builder
    def build_label(isotherm, lbl_components, current_branch, key):
        """Build a label for the legend depending on requested parameters."""
        if branch == 'all-nol' and current_branch == 'des':
            return ''
        else:
            if lbl_components is None:
                return isotherm.material + ' ' + convert_chemformula(isotherm)
            text = []
            for selected in lbl_components:
                if selected == 'branch':
                    text.append(current_branch)
                    continue
                elif selected == 'key':
                    text.append(key)
                    continue
                val = getattr(isotherm, selected)
                if val:
                    if selected == 'adsorbate':
                        text.append(convert_chemformula(isotherm))
                    else:
                        text.append(str(val))

            return " ".join(text)

    ###########################################
    #
    # Generic axes graphing function
    #

    def graph_caller(isotherm, current_branch, y1_style, y2_style,
                     **iso_params):
        """Convenience function to call other graphing functions."""

        # Labels and ticks
        ax1.set_xlabel(x_label, **styles['label_style'])
        ax1.set_ylabel(y1_label, **styles['label_style'])
        ax1.tick_params(axis='both', which='major', **styles['tick_style'])

        # Plot line 1
        label = build_label(isotherm, lgd_keys, current_branch, y1_data)

        x_p, y_p = _get_data(isotherm, x_data, current_branch, x_range,
                             **iso_params).align(_get_data(
                                 isotherm, y1_data, current_branch, y1_range,
                                 **iso_params),
                                                 join='inner')

        ax1.plot(x_p, y_p, label=label, **y1_style)

        # Plot line 2 (if applicable)
        if y2_data and y2_data in keys(isotherm):

            x_p, y2_p = _get_data(isotherm, x_data, current_branch, x_range,
                                  **iso_params).align(_get_data(
                                      isotherm, y2_data, current_branch,
                                      y2_range, **iso_params),
                                                      join='inner')

            label = build_label(isotherm, lgd_keys, current_branch, y2_data)
            ax2.set_ylabel(y2_label, **styles['label_style'])
            ax2.tick_params(axis='both', which='major', **styles['tick_style'])
            ax2.plot(x_p, y2_p, label=label, **y2_style)

    #####################################
    #
    # Actual plotting
    #
    # Plot the data
    for isotherm in isotherms:

        # Line styles for the current isotherm. ``itertools.cycle`` hands
        # back the *same* dict objects once it wraps around, so copy them:
        # otherwise the in-place desorption tweaks below would leak into
        # the adsorption style of later isotherms reusing that entry.
        y1_line_style = dict(next(pc_primary))
        y2_line_style = dict(next(pc_secondary))
        y1_line_style.update(styles['y1_line_style'])
        y2_line_style.update(styles['y2_line_style'])

        # If there's an adsorption branch, plot it
        if ads:
            current_branch = 'ads'
            if isotherm.has_branch(branch=current_branch):

                # Call the plotting function
                graph_caller(isotherm, current_branch, y1_line_style,
                             y2_line_style, **iso_params)

        # Switch to desorption linestyle (dashed, open marker)
        y1_line_style['markerfacecolor'] = 'none'
        y1_line_style['linestyle'] = '--'
        y2_line_style['markerfacecolor'] = 'none'

        # If there's a desorption branch, plot it
        if des:
            current_branch = 'des'
            if isotherm.has_branch(branch=current_branch):

                # Call the plotting function
                graph_caller(isotherm, current_branch, y1_line_style,
                             y2_line_style, **iso_params)

    #####################################
    #
    # Final settings

    _final_styling(fig, ax1, ax2, log_params, range_params, lgd_pos, styles,
                   save_path)

    if ax2:
        return [ax1, ax2]
    return ax1
示例#37
0
    # Other
    'savefig.dpi': 72,
}
# Colour cycle of the style, in cycle order.
color_cycle = [
    '#348ABD',  # blue
    '#7A68A6',  # purple
    '#A60628',  # red
    '#467821',  # green
    '#CF4457',  # pink
    '#188487',  # turquoise
    '#E24A33',  # orange
]

# matplotlib >= 1.5 configures colours through 'axes.prop_cycle'; older
# releases only know 'axes.color_cycle'.
if not MATPLOTLIB_GE_1_5:
    astropy_mpl_style_1['axes.color_cycle'] = color_cycle
else:
    # This is a dependency of matplotlib, so should be present.
    from cycler import cycler
    astropy_mpl_style_1['axes.prop_cycle'] = cycler('color', color_cycle)


'''
Version 1 astropy plotting style for matplotlib.

This style improves some settings over the matplotlib default.
'''

astropy_mpl_style = astropy_mpl_style_1
'''
Most recent version of the astropy plotting style for matplotlib.
示例#38
0
def _set_colors_for_categorical_obs(adata, value_to_plot, palette):
    """
    Store one hex colour per category of a categorical observation.

    The result is written to ``adata.uns[value_to_plot + '_colors']``.

    Parameters
    ----------
    adata : AnnData object
    value_to_plot : name of a valid categorical observation
    palette : either the name of a colormap listed by ``pl.colormaps()``,
              a sequence of colours matplotlib understands (e.g. RGB, RGBA,
              hex; names found in ``utils.additional_colors`` are translated),
              or a cycler object with key='color'. Sequences and cyclers
              shorter than the number of categories wrap around.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If a colour in a palette sequence is invalid, if ``palette`` is none
        of the accepted kinds, or if a cycler lacks a 'color' key.
    """
    from matplotlib.colors import to_hex
    from cycler import Cycler, cycler

    categories = adata.obs[value_to_plot].cat.categories

    if isinstance(palette, str) and palette in pl.colormaps():
        # Sample the named colormap evenly, one colour per category
        # (e.g. 'Accent', 'Dark2', 'tab20').
        cmap = pl.get_cmap(palette)
        sampled = cmap(np.linspace(0, 1, len(categories)))
        colors_list = [to_hex(rgba) for rgba in sampled]
        adata.uns[value_to_plot + '_colors'] = colors_list
        return

    if isinstance(palette, abc.Sequence):
        # A sequence becomes a cycler, so a short palette just repeats.
        if len(palette) < len(categories):
            logg.warn(
                "Length of palette colors is smaller than the number of "
                "categories (palette length: {}, categories length: {}. "
                "Some categories will have the same color.".format(
                    len(palette), len(categories)))

        def _resolve(candidate):
            # Accept matplotlib colours as-is; translate known R colour names.
            if is_color_like(candidate):
                return candidate
            if candidate in utils.additional_colors:
                return utils.additional_colors[candidate]
            raise ValueError(
                "The following color value of the given palette is not valid: {}"
                .format(candidate))

        palette = cycler(color=[_resolve(candidate) for candidate in palette])

    if not isinstance(palette, Cycler):
        raise ValueError(
            "Please check that the value of 'palette' is a "
            "valid matplotlib colormap string (eg. Set2), a "
            "list of color names or a cycler with a 'color' key.")
    if 'color' not in palette.keys:
        raise ValueError("Please set the palette key 'color'.")

    color_iter = palette()
    adata.uns[value_to_plot + '_colors'] = [
        to_hex(next(color_iter)['color']) for _ in range(len(categories))
    ]
示例#39
0
# Load the shared figure parameters (colours, hatch line width).
# ``yaml.safe_load`` instead of a bare ``yaml.load(fp)``: PyYAML >= 6.0
# requires an explicit Loader for ``load`` (the one-argument call raises
# TypeError), and a plain parameter file needs no arbitrary Python objects.
with open(op.join(styledir, 'figure-params.yaml'), 'r') as fp:
    fig_params = yaml.safe_load(fp)
signifcol = fig_params['signif_color']
cue = fig_params['cue_color']
msk = fig_params['msk_color']
hatchlwd = fig_params['hatch_linewidth']
prp = fig_params['prp']  # purple
pch = fig_params['pch']  # peach
cyn = fig_params['cyn']  # cyan
blu = fig_params['blu']  # blue

# plot style. cannot change cycler via plt.style.use(...) after axes created,
# so we define colors as cycler objects to assign directly to axes.
style_files = ('font-libertine.yaml', 'garnish.yaml')
plt.style.use([op.join(styledir, sf) for sf in style_files])
darks = cycler(color=[blu, prp])
lights = cycler(color=[cyn, pch])

# make DataFrame: one row per condition, columns decoded from cond_mat indices
df = pd.DataFrame()
df['attn'] = np.array(params['attns'])[cond_mat[:, 0]]
df['spat'] = np.array(params['spatials'])[cond_mat[:, 1]]
df['iden'] = np.array(params['idents'])[cond_mat[:, 2]]
df['lr'] = df['spat'].map(spatial_mapping)
df['mf'] = df['iden'].map(talker_mapping)

# align subject data with listening-difficulty data
lisdiff = dict(zip(subjects, listening_difficulty))
lisdiff_bool = np.array([lisdiff[s] for s in subj_ord], dtype=bool)
# make boolean group subset vectors
groups = {'ldiff': lisdiff_bool, 'control': np.logical_not(lisdiff_bool)}
示例#40
0
def test_plotratio():
    """Smoke-test coffea's ``hist.plot1d`` and ``hist.plotratio``.

    Builds a stacked lepton-pt plot with injected Poisson pseudodata on top
    and a pseudodata/MC ratio panel underneath, using the Agg backend. No
    assertions are visible here; the test passes if plotting does not raise.
    """
    # histogram creation and manipulation
    from coffea import hist
    # matplotlib
    import matplotlib as mpl
    # Select the non-interactive Agg backend before pyplot is imported so the
    # test can run without a display.
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    
    lepton_kinematics = fill_lepton_kinematics()

    # Add some pseudodata to a pt histogram so we can make a nice data/mc plot
    # Integrating out 'eta' leaves a histogram over 'flavor' and 'pt'.
    pthist = lepton_kinematics.sum('eta')
    bin_values = pthist.axis('pt').centers()
    # Per-bin yield summed over all flavors, used as the Poisson mean.
    poisson_means = pthist.sum('flavor').values()[()]
    # Repeat each bin centre by its sampled count: one pseudo-entry per event.
    values = np.repeat(bin_values, np.random.poisson(poisson_means))
    pthist.fill(flavor='pseudodata', pt=values)

    # Set nicer labels, by accessing the string bins' label property
    pthist.axis('flavor').index('electron').label = 'e Flavor'
    pthist.axis('flavor').index('muon').label = r'$\mu$ Flavor'
    pthist.axis('flavor').index('pseudodata').label = r'Pseudodata from e/$\mu$'

    # Use a regular expression on the flavor name to select everything
    # *except* the pseudodata (i.e. the MC): the negative lookahead presumably
    # rejects names starting with 'pseudodata', assuming coffea matches the
    # pattern from the start of the name -- TODO confirm.
    # Another method would be to fill a separate data histogram.
    import re
    notdata = re.compile('(?!pseudodata)')
    
    # make a nice ratio plot
    plt.rcParams.update({
                        'font.size': 14,
                        'axes.titlesize': 18,
                        'axes.labelsize': 18,
                        'xtick.labelsize': 12,
                        'ytick.labelsize': 12
                        })
    # Main panel three times taller than the ratio panel, sharing the x axis.
    fig, (ax, rax) = plt.subplots(2, 1, figsize=(7,7), gridspec_kw={"height_ratios": (3, 1)}, sharex=True)
    fig.subplots_adjust(hspace=.07)

    # Here is an example of setting up a color cycler to color the various fill patches
    # http://colorbrewer2.org/#type=qualitative&scheme=Paired&n=6
    from cycler import cycler
    colors = ['#a6cee3','#1f78b4','#b2df8a','#33a02c','#fb9a99','#e31a1c']
    ax.set_prop_cycle(cycler(color=colors))

    # Style of the stacked MC fill patches.
    fill_opts = {
        'edgecolor': (0,0,0,0.3),
        'alpha': 0.8
    }
    # Hatched band drawn for the statistical uncertainty.
    error_opts = {
        'label':'Stat. Unc.',
        'hatch':'///',
        'facecolor':'none',
        'edgecolor':(0,0,0,.5),
        'linewidth': 0
    }
    # Black points with error bars for the pseudodata.
    data_err_opts = {
        'linestyle':'none',
        'marker': '.',
        'markersize': 10.,
        'color':'k',
        'elinewidth': 1,
        'emarker': '_'
    }

    # Stacked MC (every flavor except the pseudodata).
    hist.plot1d(pthist[notdata],
                overlay="flavor",
                ax=ax,
                clear=False,
                stack=True,
                line_opts=None,
                fill_opts=fill_opts,
                error_opts=error_opts
                )
    # Pseudodata points drawn on top of the stack.
    hist.plot1d(pthist['pseudodata'],
                overlay="flavor",
                ax=ax,
                clear=False,
                error_opts=data_err_opts
                )

    ax.autoscale(axis='x', tight=True)
    ax.set_ylim(0, None)
    ax.set_xlabel(None)
    # NOTE(review): ``leg`` is not used in the visible part of this test.
    leg = ax.legend()

    # Ratio panel: pseudodata over the summed MC, uncertainty from the numerator.
    hist.plotratio(pthist['pseudodata'].sum("flavor"), pthist[notdata].sum("flavor"),
                   ax=rax,
                   error_opts=data_err_opts,
                   denom_fill_opts={},
                   guide_opts={},
                   unc='num'
                   )
    rax.set_ylabel('Ratio')
    rax.set_ylim(0,2)

    # Decorative labels just above the top-left and top-right corners of the
    # main panel (axes coordinates, bottom-aligned at y=1).
    coffee = plt.text(0., 1., u"☕",
                      fontsize=28,
                      horizontalalignment='left',
                      verticalalignment='bottom',
                      transform=ax.transAxes
                      )
    lumi = plt.text(1., 1., r"1 fb$^{-1}$ (?? TeV)",
                    fontsize=16,
                    horizontalalignment='right',
                    verticalalignment='bottom',
                    transform=ax.transAxes
                    )
示例#41
0
            density.set_bandwidth(0.1)
            ys = density(logxs)

            areas[h].set_xy(list(zip(ys, xs)) + [(0, xs[-1]), (0, xs[0])])
            dist_shdp[h].set_data(ys, xs)

        # sns.heatmap(shdp.state[:, 0:1].T, ax=ax2, cbar=False)
        all_states, state_counts = np.unique(shdp.state, return_counts=True)
        # print("all states: ", all_states)
        # print("state counts: ", state_counts)
        # print("PI: ", shdp.PI)
        trans_shdp.set_data(shdp.PI.copy())
        text.set_text("MCMC iteration {0}".format(t))
        return line_shdp + dist_shdp + [trans_shdp, text] + areas

    cycle = cycler('color', colors)
    fig = plt.figure(figsize=(14, 8), facecolor='w')

    ax1 = plt.subplot2grid((15, 20), (0, 0), colspan=13, rowspan=5)
    plt.gca().set_prop_cycle(cycle)

    ax1.set_title("Simulated data")
    # ax1.set_yscale("log")

    ax1.plot(data)
    ax1.set_ylabel("$f(t)$")
    ax1.set_xticklabels([])
    ax1.set_ylim([vmin, vmax])
    # ax1.set_xlim([0, 288])
    ax1.grid()
示例#42
0
"""
A module to customize Matplotlib parameters
"""

from matplotlib import pyplot as plt
from cycler import cycler

plt.figure(figsize=[10, 7])

# Default colour cycle for all new axes.
# plt.style.use(['seaborn-deep'])
custom_cycler = cycler(
    color=[
        '#003262', '#C4820F', '#55A868', '#d62728',
        '#9467bd', '#CCB974', '#8c564b', '#e377c2',
        '#7f7f7f', '#bcbd22', '#17becf', '#1f77b4',
    ]
)
plt.rc('axes', prop_cycle=custom_cycler)

# Line width and custom on/off patterns for dashed, dash-dot and dotted lines.
plt.rc('lines', **{
    'linewidth': 3,
    'dashed_pattern': [4, 3],
    'dashdot_pattern': [4, 2, 1, 2],
    'dotted_pattern': [1, 2],
})

# Text: sans-serif body font, Computer Modern for math text.
plt.rc('font', **{'family': 'sans-serif', 'size': 15})
plt.rc('mathtext', **{'fontset': 'cm'})

plt.rc(
    'axes',
    linewidth=2,
    titlesize=25,
    labelsize=25,
示例#43
0
def colorfunction(nfiles, function, printed, *options):
    """Choose colours and line styles for ``nfiles`` curves.

    Parameters
    ----------
    nfiles : int
        Number of curves to style.
    function : str
        Colour helper used when ``printed == 'c'``: ``'fccolorblind'``,
        ``'rbscale'``, ``'rainbow'`` or ``'huescale'`` (the latter also
        receives ``*options``). Any other value leaves the colour list empty.
    printed : str
        ``'c'`` for colour output (every line solid), or one of ``'b&w'``,
        ``'blue'``, ``'red'``, ``'yellow'``, ``'green'`` for single-colour
        output where lines are told apart by dash pattern. In the
        single-colour modes matplotlib's ``axes.prop_cycle`` is also set.
    *options
        Extra arguments forwarded to ``huescale``.

    Returns
    -------
    tuple of (list, list)
        The per-line colours and the per-line line styles.
    """
    solid = (0, ())
    loosely_dotted = (0, (1, 10))
    dotted = (0, (1, 5))
    densely_dotted = (0, (1, 1))

    loosely_dashed = (0, (5, 10))
    dashed = (0, (5, 5))
    densely_dashed = (0, (5, 1))

    loosely_dashdotted = (0, (3, 10, 1, 10))
    dashdotted = (0, (3, 5, 1, 5))
    densely_dashdotted = (0, (3, 1, 1, 1))

    loosely_dashdotdotted = (0, (3, 10, 1, 10, 1, 10))
    dashdotdotted = (0, (3, 5, 1, 5, 1, 5))
    densely_dashdotdotted = (0, (3, 1, 1, 1, 1, 1))

    # Single-colour output modes and the colour every line receives.
    mono_colors = {
        'b&w': 'black',
        'blue': '#4477AA',
        'red': '#CC6677',
        'yellow': '#DDCC77',
        'green': '#117733',
    }

    color = []
    estyle = []

    if printed == 'c':
        estyle = [solid] * nfiles
        if function == 'fccolorblind':
            color = fccolorblind(nfiles)
        elif function == 'rbscale':
            color = rbscale(nfiles)
        elif function == 'rainbow':
            color = rainbow(nfiles)
        elif function == 'huescale':
            color = huescale(nfiles, *options)

    if printed in mono_colors:
        color = [mono_colors[printed]] * nfiles
        style = [
            solid, dashed, dotted, dashdotted, dashdotdotted, densely_dashed,
            densely_dotted, densely_dashdotted, densely_dashdotdotted,
            loosely_dashed, loosely_dotted, loosely_dashdotted,
            loosely_dashdotdotted
        ]
        if nfiles > 13:
            # Only 13 distinct dash patterns: repeat them to cover every file.
            print('CAREFUL : MAXIMUM 13 DIFFERENT LINESTYLES')
            style = style * int(np.ceil(nfiles / 13))
        estyle = style[0:nfiles]
        default_cycler = (cycler(color=color) + cycler(linestyle=estyle))
        plt.rc('axes', prop_cycle=default_cycler)

    return (color, estyle)
示例#44
0
# Apply the style *before* creating the figure: style sheets only change
# rcParams, which existing axes do not pick up, so calling plt.style.use()
# after plt.subplots() left the grid of axes unstyled.
# 'seaborn-darkgrid' was renamed 'seaborn-v0_8-darkgrid' in matplotlib 3.6
# (the old name was deprecated and later removed), so fall back if needed.
_darkgrid_style = 'seaborn-darkgrid'
if (_darkgrid_style not in plt.style.available
        and 'seaborn-v0_8-darkgrid' in plt.style.available):
    _darkgrid_style = 'seaborn-v0_8-darkgrid'
plt.style.use(_darkgrid_style)

# Initialize the figure: one row per assay (AID), one column per iteration
fig, axs = plt.subplots(9, 9)

# Assay IDs, in the row order used for the subplot grid
AID_list = [
    'AID_1345083', 'AID_624255', 'AID_449739', 'AID_995', 'AID_938', 'AID_628',
    'AID_596', 'AID_893', 'AID_894'
]

# create a color palette
from cycler import cycler
from matplotlib.colors import ListedColormap
cmap = ListedColormap(sns.color_palette())
colors = cmap.colors
custom_cycler = (cycler(color=colors[:5]) + cycler(lw=[1, 1, 1, 1, 1]))
# set the colour palette on every subplot
for ax in axs.flatten().tolist():
    ax.set_prop_cycle(custom_cycler)
# multiple line plot
num = 0
for _, row in df.iterrows():
    ax_row = AID_list.index(row['AID'])
    ax_col = int(row['Iteration Number'])
    # get prec rec from df row
    rec = row['rec_array'].tolist()
    prec = row['prec_array'].tolist()
    # sort the (recall, precision) pairs by recall so the curve is drawn in
    # order; recall is plotted on the y axis
    recs, precs = zip(*sorted(zip(rec, prec)))
    axs[ax_row, ax_col].plot(precs, recs, label=row['Classifier'])
    # Find the right spot on the plot
示例#45
0
def set_lines_marker_style(style='both', omit_markers=False):
    """
    Set the rcParams to either 'both', 'bw' or 'color'.

    The function will change matplotlib's rcParams key
    "axes.prop_cycle" accordingly.

    Parameters
    ----------
    style : str
        Case-insensitive style name. 'color'/'c'/'col' cycles colours
        (optionally with markers); 'bw'/'blackwhite'/'blacknwhite'/'black'
        uses black lines distinguished by dash pattern; 'both'/'egal'/'dunno'
        cycles colours and dash patterns together.
    omit_markers : bool
        If True, no marker cycle is added.

    Returns
    -------
    None
        For an unknown ``style`` a message is printed and rcParams are left
        unchanged (previously this path crashed with UnboundLocalError).
    """
    key = style.lower()
    if key in ['color', 'c', 'col']:
        if omit_markers:
            cl = cycler(color=cols)
        else:
            cl = cycler(color=cols, marker=markers)

    elif key in ['bw', 'blackwhite', 'blacknwhite', 'black']:
        # A single black colour combined with every dash pattern.
        if omit_markers:
            cl = (cycler('color', ['k']) * cycler(dashes=dashs))
        else:
            cl = (cycler('color', ['k']) * cycler(dashes=dashs) +
                  cycler(marker=markers))
    elif key in ['both', 'egal', 'dunno']:
        if omit_markers:
            cl = (cycler(color=cols) + cycler(dashes=dashs))
        else:
            cl = (cycler(color=cols) + cycler(dashes=dashs) +
                  cycler(marker=markers))
    else:
        print('Unknown style {}. Properties not set.'.format(style))
        return

    mpl.rcParams['axes.prop_cycle'] = cl
示例#46
0
# Run ParameterSweep directly upon initialization.
AUTORUN_SWEEP = True

# Enable/disable the CENTRAL_DISPATCH system.
DISPATCH_ENABLED = True

# --- Parallel processing -----------------------------------------------------
POOL = None  # processing pool, stored here once it has been generated
NUM_CPUS = 1  # default core count for methods that support parallel processing

# Multiprocessing library to use; one of 'multiprocessing' or 'pathos'.
MULTIPROC = 'multiprocessing'

# --- Matplotlib options ------------------------------------------------------
# Custom colour cycle for all new axes.
mpl.rcParams['axes.prop_cycle'] = cycler(color=[
    "#016E82", "#333795", "#2E5EAC", "#4498D3", "#CD85B9", "#45C3D1",
    "#AA1D3F", "#F47752", "#19B35A", "#EDE83B", "#ABD379", "#F9E6BE",
])

# Font and resolution defaults, applied in this order.
mpl.rcParams.update({
    'font.family': "sans-serif",
    'font.sans-serif': "Arial",
    'figure.dpi': 150,
    'font.size': 11,
})
示例#47
0
def generate_validator_testcases(valid):
    """Yield parametrization triples for the rcParam validator tests.

    Every entry of the table below names a validator together with the
    inputs it should accept (``'success'``: ``(argument, expected_result)``
    pairs) and the inputs it should reject (``'fail'``:
    ``(argument, expected_exception_type)`` pairs).

    Parameters
    ----------
    valid : bool
        Truthy to yield ``(validator, arg, target)`` for the accepted inputs,
        falsy to yield ``(validator, arg, error_type)`` for the rejected ones.
    """
    validation_tests = (
        {
            'validator': validate_bool,
            'success': (
                *((val, True)
                  for val in ('t', 'y', 'yes', 'on', 'true', '1', 1, True)),
                *((val, False)
                  for val in ('f', 'n', 'no', 'off', 'false', '0', 0, False)),
            ),
            'fail': ((val, ValueError)
                     for val in ('aardvark', 2, -1, [])),
        },
        {
            'validator': validate_stringlist,
            'success': (
                ('', []),
                ('a,b', ['a', 'b']),
                ('aardvark', ['aardvark']),
                ('aardvark, ', ['aardvark']),
                ('aardvark, ,', ['aardvark']),
                (['a', 'b'], ['a', 'b']),
                (('a', 'b'), ['a', 'b']),
                (iter(['a', 'b']), ['a', 'b']),
                (np.array(['a', 'b']), ['a', 'b']),
                ((1, 2), ['1', '2']),
                (np.array([1, 2]), ['1', '2']),
            ),
            'fail': (
                (set(), ValueError),
                (1, ValueError),
            ),
        },
        {
            'validator': _listify_validator(validate_int, n=2),
            'success': ((val, [1, 2])
                        for val in ('1, 2', [1.5, 2.5], [1, 2], (1, 2),
                                    np.array((1, 2)))),
            'fail': ((val, ValueError)
                     for val in ('aardvark', ('a', 1), (1, 2, 3))),
        },
        {
            'validator': _listify_validator(validate_float, n=2),
            'success': ((val, [1.5, 2.5])
                        for val in ('1.5, 2.5', [1.5, 2.5], [1.5, 2.5],
                                    (1.5, 2.5), np.array((1.5, 2.5)))),
            'fail': ((val, ValueError)
                     for val in ('aardvark', ('a', 1), (1, 2, 3))),
        },
        {
            'validator': validate_cycler,
            'success': (
                ('cycler("color", "rgb")', cycler("color", 'rgb')),
                (cycler('linestyle', ['-', '--']),
                 cycler('linestyle', ['-', '--'])),
                ("""(cycler("color", ["r", "g", "b"]) +
                          cycler("mew", [2, 3, 5]))""",
                 (cycler("color", 'rgb') +
                  cycler("markeredgewidth", [2, 3, 5]))),
                ("cycler(c='rgb', lw=[1, 2, 3])",
                 cycler('color', 'rgb') + cycler('linewidth', [1, 2, 3])),
                ("cycler('c', 'rgb') * cycler('linestyle', ['-', '--'])",
                 (cycler('color', 'rgb') * cycler('linestyle', ['-', '--']))),
                (cycler('ls', ['-', '--']), cycler('linestyle', ['-', '--'])),
                (cycler(mew=[2, 5]), cycler('markeredgewidth', [2, 5])),
            ),
            # This is *so* incredibly important: validate_cycler() eval's
            # an arbitrary string! I think I have it locked down enough,
            # and that is what this is testing.
            # TODO: Note that these tests are actually insufficient, as it may
            # be that they raised errors, but still did an action prior to
            # raising the exception. We should devise some additional tests
            # for that...
            'fail': (
                (4, ValueError),  # Gotta be a string or Cycler object
                ('cycler("bleh, [])', ValueError),  # syntax error
                ('Cycler("linewidth", [1, 2, 3])',
                 ValueError),  # only 'cycler()' function is allowed
                ('1 + 2', ValueError),  # doesn't produce a Cycler object
                ('os.system("echo Gotcha")', ValueError),  # os not available
                ('import os', ValueError),  # should not be able to import
                # Should not be able to define anything, even if it does
                # return a cycler.
                ('def badjuju(a): return a; badjuju(cycler("color", "rgb"))',
                 ValueError),
                ('cycler("waka", [1, 2, 3])', ValueError),  # not a property
                ('cycler(c=[1, 2, 3])', ValueError),  # invalid values
                ("cycler(lw=['a', 'b', 'c'])", ValueError),  # invalid values
                (cycler('waka', [1, 3, 5]), ValueError),  # not a property
                (cycler('color', ['C1', 'r', 'g']), ValueError),  # no CN
            ),
        },
        {
            'validator': validate_hatch,
            'success': (
                ('--|', '--|'),
                ('\\oO', '\\oO'),
                ('/+*/.x', '/+*/.x'),
                ('', ''),
            ),
            'fail': (
                ('--_', ValueError),
                (8, ValueError),
                ('X', ValueError),
            ),
        },
        {
            'validator': validate_colorlist,
            'success': (
                ('r,g,b', ['r', 'g', 'b']),
                (['r', 'g', 'b'], ['r', 'g', 'b']),
                ('r, ,', ['r']),
                (['', 'g', 'blue'], ['g', 'blue']),
                ([np.array([1, 0, 0]), np.array([0, 1, 0])],
                 np.array([[1, 0, 0], [0, 1, 0]])),
                (np.array([[1, 0, 0], [0, 1, 0]]),
                 np.array([[1, 0, 0], [0, 1, 0]])),
            ),
            'fail': (('fish', ValueError),),
        },
        {
            'validator': validate_color,
            'success': (
                ('None', 'none'),
                ('none', 'none'),
                ('AABBCC', '#AABBCC'),  # RGB hex code
                ('AABBCC00', '#AABBCC00'),  # RGBA hex code
                ('tab:blue', 'tab:blue'),  # named color
                ('C12', 'C12'),  # color from cycle
                ('(0, 1, 0)', (0.0, 1.0, 0.0)),  # RGB tuple
                ((0, 1, 0), (0, 1, 0)),  # non-string version
                ('(0, 1, 0, 1)', (0.0, 1.0, 0.0, 1.0)),  # RGBA tuple
                ((0, 1, 0, 1), (0, 1, 0, 1)),  # non-string version
            ),
            'fail': (
                ('tab:veryblue', ValueError),  # invalid name
                ('(0, 1)', ValueError),  # tuple with length < 3
                ('(0, 1, 0, 1, 0)', ValueError),  # tuple with length > 4
                ('(0, 1, none)', ValueError),  # cannot cast none to float
                ('(0, 1, "0.5")', ValueError),  # last one not a float
            ),
        },
        {
            'validator': validate_hist_bins,
            'success': (
                ('auto', 'auto'),
                ('fd', 'fd'),
                ('10', 10),
                ('1, 2, 3', [1, 2, 3]),
                ([1, 2, 3], [1, 2, 3]),
                (np.arange(15), np.arange(15)),
            ),
            'fail': (('aardvark', ValueError),),
        },
        {
            'validator': validate_markevery,
            'success': (
                (None, None),
                (1, 1),
                (0.1, 0.1),
                ((1, 1), (1, 1)),
                ((0.1, 0.1), (0.1, 0.1)),
                ([1, 2, 3], [1, 2, 3]),
                (slice(2), slice(None, 2, None)),
                (slice(1, 2, 3), slice(1, 2, 3)),
            ),
            'fail': (
                ((1, 2, 3), TypeError),
                ([1, 2, 0.3], TypeError),
                (['a', 2, 3], TypeError),
                ([1, 2, 'a'], TypeError),
                ((0.1, 0.2, 0.3), TypeError),
                ((0.1, 2, 3), TypeError),
                ((1, 0.2, 0.3), TypeError),
                ((1, 0.1), TypeError),
                ((0.1, 1), TypeError),
                (('abc'), TypeError),  # NB: ('abc') is the plain str 'abc'
                ((1, 'a'), TypeError),
                ((0.1, 'b'), TypeError),
                (('a', 1), TypeError),
                (('a', 0.1), TypeError),
                ('abc', TypeError),
                ('a', TypeError),
                (object(), TypeError),
            ),
        },
        {
            'validator': _validate_linestyle,
            'success': (
                ('-', '-'),
                ('solid', 'solid'),
                ('--', '--'),
                ('dashed', 'dashed'),
                ('-.', '-.'),
                ('dashdot', 'dashdot'),
                (':', ':'),
                ('dotted', 'dotted'),
                ('', ''),
                (' ', ' '),
                ('None', 'none'),
                ('none', 'none'),
                ('DoTtEd', 'dotted'),  # case-insensitive
                ('1, 3', (0, (1, 3))),
                ([1.23, 456], (0, [1.23, 456.0])),
                ([1, 2, 3, 4], (0, [1.0, 2.0, 3.0, 4.0])),
                ((0, [1, 2]), (0, [1, 2])),
                ((-1, [1, 2]), (-1, [1, 2])),
            ),
            'fail': (
                ('aardvark', ValueError),  # not a valid string
                (b'dotted', ValueError),
                ('dotted'.encode('utf-16'), ValueError),
                ([1, 2, 3], ValueError),  # sequence with odd length
                (1.23, ValueError),  # not a sequence
                (("a", [1, 2]), ValueError),  # wrong explicit offset
                ((1, [1, 2, 3]), ValueError),  # odd length sequence
                (([1, 2], 1), ValueError),  # inverted offset/onoff
            ),
        },
    )

    # Pick which half of every entry to emit; the third element of each
    # yielded triple is the target value ('success') or error type ('fail').
    outcome_key = 'success' if valid else 'fail'
    for case in validation_tests:
        validator = case['validator']
        for arg, expected in case[outcome_key]:
            yield validator, arg, expected
示例#48
0
文件: _utils.py 项目: waldoPHD/scanpy
def default_palette(palette=None):
    """Return a property cycler to use as the colour palette.

    ``None`` selects the current ``rcParams['axes.prop_cycle']``; an object
    that already is a ``Cycler`` is returned unchanged; anything else is
    wrapped as ``cycler(color=palette)``.
    """
    if palette is None:
        return rcParams['axes.prop_cycle']
    if isinstance(palette, Cycler):
        return palette
    return cycler(color=palette)
示例#49
0
        print(sty)
        sty.update(plot_kwargs)
        print(sty)
        ret = plot_func(ax, edges, top, bottoms=bottoms, label=label, **sty)
        bottoms = top
        arts[label] = ret
    ax.legend(fontsize=10)
    return arts


# Fixed bin edges shared by every histogram computed below.
edges = np.linspace(-3, 3, 20, endpoint=True)
hist_func = partial(np.histogram, bins=edges)

# Style cycles: facecolors re-keyed from the first four entries of the
# current property cycle (assumes that cycle has a single key, e.g. 'color'),
# one label per data set, and one hatch pattern per data set.
color_cycle = cycler(facecolor=plt.rcParams['axes.prop_cycle'][:4])
label_cycle = cycler(label=['set {n}'.format(n=idx) for idx in range(4)])
hatch_cycle = cycler(hatch=['/', '*', '+', '|'])

# Fixing random state for reproducibility
np.random.seed(19680801)

# Four rows of 12250 standard-normal samples; the ordered-dict variant keys
# each row by the matching label from ``label_cycle``.
stack_data = np.random.randn(4, 12250)
dict_data = OrderedDict(
    zip([style['label'] for style in label_cycle], stack_data))

###############################################################################
# Work with plain arrays

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 4.5), tight_layout=True)
arts = stack_hist(ax1,
                  stack_data,
示例#50
0
def paper_plot(fontsize=9, font='paper'):
    """ Initialize the settings of the plot, including font, fontsize, etc..
    Also refer to the changes in
    https://matplotlib.org/users/dflt_style_changes.html

    fontsize: fontsize for legends and labels.
    font: font for legends and labels, 'paper' uses Times New Roman, 'default'
    uses default, a tuple of (family, font, ...) customizes font.

    Side effects: tries to switch the process locale to UTF-8, deletes the
    entries of matplotlib's cache directory (keeping ``matplotlibrc``) and
    updates ``matplotlib.rcParams`` in place.

    Raises:
        ValueError: if ``font`` is neither 'paper', 'default' nor a tuple/list
            of at least two entries.
        KeyError, ValueError: re-raised when a setting that is only guarded
            for older matplotlib releases fails on a release that should
            accept it.
    """
    # Set locale for unicode signs (e.g., minus sign).
    _set_utf8_locale()

    # Clear font cache (in case of switching host machines).
    _clear_mpl_cache()

    params = matplotlib.rcParams

    if font == 'paper':
        params['font.family'] = 'serif'
        params['font.serif'] = ['Times New Roman']
        params['mathtext.fontset'] = 'stix'  # to blend well with Times
        params['mathtext.rm'] = 'serif'
    elif font != 'default':
        if not isinstance(font, (tuple, list)) or len(font) < 2:
            raise ValueError('[format] font must be a tuple of (family, font)')
        family = font[0]
        params['font.family'] = family
        params['font.{}'.format(family)] = list(font[1:])
        params['mathtext.rm'] = family

    params['font.size'] = fontsize

    # Use TrueType fonts.
    params['ps.fonttype'] = 42
    params['pdf.fonttype'] = 42

    params['legend.loc'] = 'upper right'
    params['legend.fontsize'] = fontsize
    params['legend.fancybox'] = False
    params['legend.shadow'] = False
    params['legend.numpoints'] = 2
    params['legend.scatterpoints'] = 3
    params['legend.borderpad'] = 0.4

    # Each guarded group below may only fail on matplotlib releases older
    # than the noted version. On newer releases the original error is
    # re-raised: an ``assert`` here would either mask it as AssertionError
    # or, under ``python -O``, silently swallow it.
    try:
        params['legend.facecolor'] = 'inherit'
        params['legend.edgecolor'] = 'inherit'
    except KeyError:
        if __mpl_version__ >= (2, 0):  # Changed from 2.0
            raise
    try:
        params['legend.framealpha'] = 1.0
    except KeyError:
        if __mpl_version__ >= (1, 5):  # Changed from 1.5
            raise
    params['axes.linewidth'] = 1.0
    params['axes.facecolor'] = 'w'
    params['axes.edgecolor'] = 'k'
    params['axes.labelsize'] = fontsize
    params['axes.axisbelow'] = True
    try:
        params['axes.prop_cycle'] = cycler('color', COLOR_SET)
    except KeyError:
        if __mpl_version__ >= (1, 5):  # Changed from 1.5
            raise
        params['axes.color_cycle'] = COLOR_SET
    params['xtick.labelsize'] = fontsize
    params['ytick.labelsize'] = fontsize
    params['grid.linestyle'] = ':'
    params['grid.linewidth'] = 0.5
    params['grid.alpha'] = 1.0
    params['grid.color'] = 'k'
    params['lines.linewidth'] = 1.0
    try:
        params['lines.color'] = 'C0'
    except ValueError:
        if __mpl_version__ >= (2, 0):  # Changed from 2.0
            raise
    params['lines.markeredgewidth'] = 0.5
    params['lines.markersize'] = 4
    try:
        params['lines.dashed_pattern'] = [4, 4]
        params['lines.dashdot_pattern'] = [4, 2, 1, 2]
        params['lines.dotted_pattern'] = [1, 3]
    except KeyError:
        if __mpl_version__ >= (2, 0):  # Changed from 2.0
            raise
    params['patch.linewidth'] = 0.5
    try:
        params['patch.facecolor'] = 'C0'
        params['patch.force_edgecolor'] = True
    except (KeyError, ValueError):
        # Before 2.0 the 'C0' colour is rejected with ValueError; KeyError is
        # accepted too in case 'patch.force_edgecolor' is unknown there.
        if __mpl_version__ >= (2, 0):  # Changed from 2.0
            raise
    params['patch.edgecolor'] = 'k'
    try:
        params['hatch.linewidth'] = 0.5
        params['hatch.color'] = 'k'
    except KeyError:
        if __mpl_version__ >= (2, 0):  # Changed from 2.0
            raise
    try:
        params['errorbar.capsize'] = 3
    except KeyError:
        if __mpl_version__ >= (1, 5):  # Changed from 1.5
            raise
    params['xtick.direction'] = 'out'
    params['ytick.direction'] = 'out'
    params['xtick.major.width'] = 0.8
    params['xtick.minor.width'] = 0.6
    params['ytick.major.width'] = 0.8
    params['ytick.minor.width'] = 0.6
    try:
        params['xtick.top'] = False
        params['ytick.right'] = False
    except KeyError:
        if __mpl_version__ >= (2, 0):  # Changed from 2.0
            raise


def _set_utf8_locale():
    """Switch LC_ALL to 'C.UTF-8', falling back to 'en_US.UTF-8'.

    The locale is left unchanged when neither is available.
    """
    for name in ('C.UTF-8', 'en_US.UTF-8'):
        try:
            locale.setlocale(locale.LC_ALL, name)
        except locale.Error:
            continue
        return


def _clear_mpl_cache():
    """Delete every entry of matplotlib's cache directory except 'matplotlibrc'.

    On Mac OS, cache directory and config directory can be the same, so the
    rc file must not be removed.
    """
    cachedir = matplotlib.get_cachedir()
    if cachedir is None:
        # NOTE(review): some matplotlib versions may report no cache dir;
        # os.listdir(None) would list the current directory instead.
        return
    for entry in os.listdir(cachedir):
        if str(entry) == 'matplotlibrc':
            continue
        path = os.path.join(cachedir, entry)
        try:
            os.remove(path)
        except OSError:
            # os.remove() refuses directories; delete those recursively.
            shutil.rmtree(path, ignore_errors=True)
示例#51
0
    #    prof2[i] = csRecon[profileX1, yVal]
    #    prof3[i] = mlemRecon[profileX1, yVal]
    #    prof4[i] = sirtRecon[profileX1, yVal]
    prof1[i] = image[yVal, profileX1]
    prof2[i] = csRecon[yVal, profileX1]
    prof3[i] = mlemRecon[yVal, profileX1]
    prof4[i] = sirtRecon[yVal, profileX1]
    prof5[i] = radialRecon[yVal, profileX1]

# Configure line width, the colour/linestyle property cycle and tick label
# sizes *before* creating the figure: an Axes picks up 'axes.prop_cycle' when
# it is created, so setting it afterwards (as before) left ``ax`` with the
# default cycle and the custom colours/linestyles were never used.
plt.rc('lines', linewidth=4)
plt.rc(
    'axes',
    prop_cycle=(
        cycler('color', ['r', 'g', 'b', 'y', 'm']) +
        # (offset, (on, off, ...)) dash patterns: solid, densely dashed,
        # densely dashdotted, densely dashdotdotted, dotted
        cycler('linestyle', [(0, ()), (0, (3, 1)), (0, (3, 1, 1, 1)),
                             (0, (3, 1, 1, 1, 1, 1)), (0, (1, 2))])))

plt.rc('xtick', labelsize=fontsize)
plt.rc('ytick', labelsize=fontsize)

fig, ax = plt.subplots(figsize=(8, 5))

# x positions of the profile samples
xVals1 = profileY1

# presumably filled further down the script (not visible here)
keysProf = []
profImg, = plt.plot(xVals1, prof1, label='Original', linewidth=3)
profCS, = plt.plot(xVals1, prof2, label='CS', linewidth=3)
profMLEM, = plt.plot(xVals1, prof3, label='Finite MLEM', linewidth=3)
示例#52
0
class SaveLogic(GenericLogic):
    """
    A general class which saves all kinds of data in a general sense.
    """

    _modclass = 'savelogic'
    _modtype = 'logic'

    _win_data_dir = ConfigOption('win_data_directory', 'C:/Data/')
    _unix_data_dir = ConfigOption('unix_data_directory', 'Data')
    log_into_daily_directory = ConfigOption('log_into_daily_directory',
                                            False,
                                            missing='warn')

    # Matplotlib style definition for saving plots
    mpl_qd_style = {
        'axes.prop_cycle':
        cycler('color', [
            '#1f17f4', '#ffa40e', '#ff3487', '#008b00', '#17becf', '#850085'
        ]) + cycler('marker', ['o', 's', '^', 'v', 'D', 'd']),
        'axes.edgecolor':
        '0.3',
        'xtick.color':
        '0.3',
        'ytick.color':
        '0.3',
        'axes.labelcolor':
        'black',
        'font.size':
        '14',
        'lines.linewidth':
        '2',
        'figure.figsize':
        '12, 6',
        'lines.markeredgewidth':
        '0',
        'lines.markersize':
        '5',
        'axes.spines.right':
        True,
        'axes.spines.top':
        True,
        'xtick.minor.visible':
        True,
        'ytick.minor.visible':
        True,
        'savefig.dpi':
        '180'
    }

    _additional_parameters = {}

    def __init__(self, config, **kwargs):
        """ Set up locking, the active POI name, the platform-specific data
        directory and the daily-log flag.

        @param config: configuration forwarded to the base class
        @param kwargs: further keyword arguments forwarded to the base class

        Raises Exception when the operating system cannot be identified.
        """
        super().__init__(config=config, **kwargs)

        # mutex for thread-safe access to this module
        self.lock = Mutex()

        # name of the currently active POI; '' means no POI is active
        self.active_poi_name = ''

        # Decide on the operating-system family and take the matching data
        # directory from the configuration; unknown platforms are rejected.
        self.os_system = None
        current_platform = sys.platform
        if current_platform in ('linux', 'darwin'):
            self.os_system = 'unix'
            configured_dir = self._unix_data_dir
        elif 'win32' in current_platform or 'AMD64' in current_platform:
            self.os_system = 'win'
            configured_dir = self._win_data_dir
        else:
            raise Exception('Identify the operating system.')

        # environment variables in the path (e.g. $HOME) are expanded
        self.data_dir = os.path.expandvars(configured_dir)

        # a non-boolean config value is replaced by False (with a warning)
        if not isinstance(self.log_into_daily_directory, bool):
            self.log.warning(
                'log entry in configuration is not a '
                'boolean. Falling back to default setting: False.')
            self.log_into_daily_directory = False

        # created in on_activate() when daily logging is enabled
        self._daily_loghandler = None

    def on_activate(self):
        """ Activate the SaveLogic.

        If ``log_into_daily_directory`` is set, a DailyLogHandler with DEBUG
        level and a timestamped formatter is created and attached to the root
        logger; otherwise no daily log handler is kept.
        """
        if not self.log_into_daily_directory:
            self._daily_loghandler = None
            return

        handler = DailyLogHandler('%Y%m%d-%Hh%Mm%Ss-qudi.log', self)
        self._daily_loghandler = handler
        handler.setFormatter(
            logging.Formatter(
                '%(asctime)s %(name)s %(levelname)s: %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S'))
        handler.setLevel(logging.DEBUG)
        logging.getLogger().addHandler(handler)

    def on_deactivate(self):
        """ Deactivate the SaveLogic.

        Detaches the daily log handler (if any) from the root logger, closes
        it so the resources it holds are released, and clears the reference,
        which makes a repeated deactivation a no-op.
        """
        handler = self._daily_loghandler
        if handler is None:
            return
        # Clear first so the module never keeps a reference to a closed handler.
        self._daily_loghandler = None
        # removes the log handler logging into the daily directory
        logging.getLogger().removeHandler(handler)
        handler.close()

    @property
    def dailylog(self):
        """
        The handler logging into the daily directory, or None when no such
        handler is active.
        """
        return self._daily_loghandler

    def dailylog_set_level(self, level):
        """
        Sets the log level of the daily log handler.

        If no daily log handler exists (daily logging disabled or module not
        active), a warning is logged and nothing else happens, instead of
        failing with an AttributeError on None.

        @param level int: log level, see logging
        """
        handler = self._daily_loghandler
        if handler is None:
            self.log.warning(
                'No daily log handler present; cannot set its log level.')
            return
        handler.setLevel(level)

    def save_data(self,
                  data,
                  filepath=None,
                  parameters=None,
                  filename=None,
                  filelabel=None,
                  timestamp=None,
                  filetype='text',
                  fmt='%.15e',
                  delimiter='\t',
                  plotfig=None):
        """
        General save routine for data.

        @param dictionary data: Dictionary containing the data to be saved. The keys should be
                                strings containing the data header/description. The corresponding
                                items are one or more 1D arrays or one 2D array containing the data
                                (list or numpy.ndarray). Example:

                                    data = {'Frequency (MHz)': [1,2,4,5,6]}
                                    data = {'Frequency': [1, 2, 4], 'Counts': [234, 894, 743, 423]}
                                    data = {'Frequency (MHz),Counts':[[1,234], [2,894],...[30,504]]}

        @param string filepath: optional, the path to the directory, where the data will be saved.
                                If the specified path does not exist yet, the saving routine will
                                try to create it.
                                If no path is passed (default filepath=None) the saving routine will
                                create a directory by the name of the calling module inside the
                                daily data directory.
                                If no calling module can be inferred and/or the requested path can
                                not be created the data will be saved in a subfolder of the daily
                                data directory called UNSPECIFIED
        @param dictionary parameters: optional, a dictionary with all parameters you want to save in
                                      the header of the created file.
        @parem string filename: optional, if you really want to fix your own filename. If passed,
                                the whole file will have the name

                                    <filename>

                                If nothing is specified the save logic will generate a filename
                                either based on the module name from which this method was called,
                                or it will use the passed filelabel if that is speficied.
                                You also need to specify the ending of the filename!
        @parem string filelabel: optional, if filelabel is set and no filename was specified, the
                                 savelogic will create a name which looks like

                                     YYYY-MM-DD_HHh-MMm-SSs_<filelabel>.dat

                                 The timestamp will be created at runtime if no user defined
                                 timestamp was passed.
        @param datetime timestamp: optional, a datetime.datetime object. You can create this object
                                   with datetime.datetime.now() in the calling module if you want to
                                   fix the timestamp for the filename. Be careful when passing a
                                   filename and a timestamp, because then the timestamp will be
                                   ignored.
        @param string filetype: optional, the file format the data should be saved in. Valid inputs
                                are 'text', 'xml' and 'npz'. Default is 'text'.
        @param string or list of strings fmt: optional, format specifier for saved data. See python
                                              documentation for
                                              "Format Specification Mini-Language". If you want for
                                              example save a float in scientific notation with 6
                                              decimals this would look like '%.6e'. For saving
                                              integers you could use '%d', '%s' for strings.
                                              The default is '%.15e' for numbers and '%s' for str.
                                              If len(data) > 1 you should pass a list of format
                                              specifiers; one for each item in the data dict. If
                                              only one specifier is passed but the data arrays have
                                              different data types this can lead to strange
                                              behaviour or failure to save right away.
        @param string delimiter: optional, insert here the delimiter, like '\n' for new line, '\t'
                                 for tab, ',' for a comma ect.

        1D data
        =======
        1D data should be passed in a dictionary where the data trace should be assigned to one
        identifier like

            {'<identifier>':[list of values]}
            {'Numbers of counts':[1.4, 4.2, 5, 2.0, 5.9 , ... , 9.5, 6.4]}

        You can also pass as much 1D arrays as you want:

            {'Frequency (MHz)':list1, 'signal':list2, 'correlations': list3, ...}

        2D data
        =======
        2D data should be passed in a dictionary where the matrix like data should be assigned to
        one identifier like

            {'<identifier>':[[1,2,3],[4,5,6],[7,8,9]]}

        which will result in:
            <identifier>
            1   2   3
            4   5   6
            7   8   9


        YOU ARE RESPONSIBLE FOR THE IDENTIFIER! DO NOT FORGET THE UNITS FOR THE SAVED TIME
        TRACE/MATRIX.
        """
        start_time = time.time()
        # Create timestamp if none is present
        if timestamp is None:
            timestamp = datetime.datetime.now()

        # Try to cast data array into numpy.ndarray if it is not already one
        # Also collect information on arrays in the process and do sanity checks
        found_1d = False
        found_2d = False
        multiple_dtypes = False
        arr_length = []
        arr_dtype = []
        max_row_num = 0
        max_line_num = 0
        for keyname in data:
            # Cast into numpy array
            if not isinstance(data[keyname], np.ndarray):
                try:
                    data[keyname] = np.array(data[keyname])
                except:
                    self.log.error(
                        'Casting data array of type "{0}" into numpy.ndarray failed. '
                        'Could not save data.'.format(type(data[keyname])))
                    return -1

            # determine dimensions
            if data[keyname].ndim < 3:
                length = data[keyname].shape[0]
                arr_length.append(length)
                if length > max_line_num:
                    max_line_num = length
                if data[keyname].ndim == 2:
                    found_2d = True
                    width = data[keyname].shape[1]
                    if max_row_num < width:
                        max_row_num = width
                else:
                    found_1d = True
                    max_row_num += 1
            else:
                self.log.error(
                    'Found data array with dimension >2. Unable to save data.')
                return -1

            # determine array data types
            if len(arr_dtype) > 0:
                if arr_dtype[-1] != data[keyname].dtype:
                    multiple_dtypes = True
            arr_dtype.append(data[keyname].dtype)

        # Raise error if data contains a mixture of 1D and 2D arrays
        if found_2d and found_1d:
            self.log.error(
                'Passed data dictionary contains 1D AND 2D arrays. This is not allowed. '
                'Either fit all data arrays into a single 2D array or pass multiple 1D '
                'arrays only. Saving data failed!')
            return -1

        # try to trace back the functioncall to the class which was calling it.
        try:
            frm = inspect.stack()[1]
            # this will get the object, which called the save_data function.
            mod = inspect.getmodule(frm[0])
            # that will extract the name of the class.
            module_name = mod.__name__.split('.')[-1]
        except:
            # Sometimes it is not possible to get the object which called the save_data function
            # (such as when calling this from the console).
            module_name = 'UNSPECIFIED'

        # determine proper file path
        if filepath is None:
            filepath = self.get_path_for_module(module_name)
        elif not os.path.exists(filepath):
            os.makedirs(filepath)
            self.log.info(
                'Custom filepath does not exist. Created directory "{0}"'
                ''.format(filepath))

        # create filelabel if none has been passed
        if filelabel is None:
            filelabel = module_name
        if self.active_poi_name != '':
            filelabel = self.active_poi_name.replace(' ',
                                                     '_') + '_' + filelabel

        # determine proper unique filename to save if none has been passed
        if filename is None:
            filename = timestamp.strftime('%Y%m%d-%H%M-%S' + '_' + filelabel +
                                          '.dat')

        # Check format specifier.
        if not isinstance(fmt, str) and len(fmt) != len(data):
            self.log.error(
                'Length of list of format specifiers and number of data items differs. '
                'Saving not possible. Please pass exactly as many format specifiers as '
                'data arrays.')
            return -1

        # Create header string for the file
        header = 'Saved Data from the class {0} on {1}.\n' \
                 ''.format(module_name, timestamp.strftime('%d.%m.%Y at %Hh%Mm%Ss'))
        header += '\nParameters:\n===========\n\n'
        # Include the active POI name (if not empty) as a parameter in the header
        if self.active_poi_name != '':
            header += 'Measured at POI: {0}\n'.format(self.active_poi_name)
        # add the parameters if specified:
        if parameters is not None:
            # check whether the format for the parameters have a dict type:
            if isinstance(parameters, dict):
                if isinstance(self._additional_parameters, dict):
                    parameters = {**self._additional_parameters, **parameters}
                for entry, param in parameters.items():
                    if isinstance(param, float):
                        header += '{0}: {1:.16e}\n'.format(entry, param)
                    else:
                        header += '{0}: {1}\n'.format(entry, param)
            # make a hardcore string conversion and try to save the parameters directly:
            else:
                self.log.error(
                    'The parameters are not passed as a dictionary! The SaveLogic will '
                    'try to save the parameters nevertheless.')
                header += 'not specified parameters: {0}\n'.format(parameters)
        header += '\nData:\n=====\n'

        # write data to file
        # FIXME: Implement other file formats
        # write to textfile
        if filetype == 'text':
            # Reshape data if multiple 1D arrays have been passed to this method.
            # If a 2D array has been passed, reformat the specifier
            if len(data) != 1:
                identifier_str = ''
                if multiple_dtypes:
                    field_dtypes = list(
                        zip([
                            'f{0:d}'.format(i) for i in range(len(arr_dtype))
                        ], arr_dtype))
                    new_array = np.empty(max_line_num, dtype=field_dtypes)
                    for i, keyname in enumerate(data):
                        identifier_str += keyname + delimiter
                        field = 'f{0:d}'.format(i)
                        length = data[keyname].size
                        new_array[field][:length] = data[keyname]
                        if length < max_line_num:
                            if isinstance(data[keyname][0], str):
                                new_array[field][length:] = 'nan'
                            else:
                                new_array[field][length:] = np.nan
                else:
                    new_array = np.empty([max_line_num, max_row_num],
                                         arr_dtype[0])
                    for i, keyname in enumerate(data):
                        identifier_str += keyname + delimiter
                        length = data[keyname].size
                        new_array[:length, i] = data[keyname]
                        if length < max_line_num:
                            if isinstance(data[keyname][0], str):
                                new_array[length:, i] = 'nan'
                            else:
                                new_array[length:, i] = np.nan
                # discard old data array and use new one
                data = {identifier_str: new_array}
            elif found_2d:
                keyname = list(data.keys())[0]
                identifier_str = keyname.replace(', ', delimiter).replace(
                    ',', delimiter)
                data[identifier_str] = data.pop(keyname)
            else:
                identifier_str = list(data)[0]
            header += list(data)[0]
            self.save_array_as_text(data=data[identifier_str],
                                    filename=filename,
                                    filepath=filepath,
                                    fmt=fmt,
                                    header=header,
                                    delimiter=delimiter,
                                    comments='#',
                                    append=False)
        # write npz file and save parameters in textfile
        elif filetype == 'npz':
            header += str(list(data.keys()))[1:-1]
            np.savez_compressed(filepath + '/' + filename[:-4], **data)
            self.save_array_as_text(data=[],
                                    filename=filename[:-4] + '_params.dat',
                                    filepath=filepath,
                                    fmt=fmt,
                                    header=header,
                                    delimiter=delimiter,
                                    comments='#',
                                    append=False)
        else:
            self.log.error(
                'Only saving of data as textfile and npz-file is implemented. Filetype "{0}" is not '
                'supported yet. Saving as textfile.'.format(filetype))
            self.save_array_as_text(data=data[identifier_str],
                                    filename=filename,
                                    filepath=filepath,
                                    fmt=fmt,
                                    header=header,
                                    delimiter=delimiter,
                                    comments='#',
                                    append=False)

        #--------------------------------------------------------------------------------------------
        # Save thumbnail figure of plot
        if plotfig is not None:
            # create Metadata
            metadata = dict()
            metadata['Title'] = 'Image produced by qudi: ' + module_name
            metadata['Author'] = 'qudi - Software Suite'
            metadata[
                'Subject'] = 'Find more information on: https://github.com/Ulm-IQO/qudi'
            metadata[
                'Keywords'] = 'Python 3, Qt, experiment control, automation, measurement, software, framework, modular'
            metadata['Producer'] = 'qudi - Software Suite'
            if timestamp is not None:
                metadata['CreationDate'] = timestamp
                metadata['ModDate'] = timestamp
            else:
                metadata['CreationDate'] = time
                metadata['ModDate'] = time

            # determine the PDF-Filename
            fig_fname_vector = os.path.join(filepath,
                                            filename)[:-4] + '_fig.pdf'

            # Create the PdfPages object to which we will save the pages:
            # The with statement makes sure that the PdfPages object is closed properly at
            # the end of the block, even if an Exception occurs.
            with PdfPages(fig_fname_vector) as pdf:
                pdf.savefig(plotfig, bbox_inches='tight', pad_inches=0.05)

                # We can also set the file's metadata via the PdfPages object:
                pdf_metadata = pdf.infodict()
                for x in metadata:
                    pdf_metadata[x] = metadata[x]

            # determine the PNG-Filename and save the plain PNG
            fig_fname_image = os.path.join(filepath,
                                           filename)[:-4] + '_fig.png'
            plotfig.savefig(fig_fname_image,
                            bbox_inches='tight',
                            pad_inches=0.05)

            # Use Pillow (an fork for PIL) to attach metadata to the PNG
            png_image = Image.open(fig_fname_image)
            png_metadata = PngImagePlugin.PngInfo()

            # PIL can only handle Strings, so let's convert our times
            metadata['CreationDate'] = metadata['CreationDate'].strftime(
                '%Y%m%d-%H%M-%S')
            metadata['ModDate'] = metadata['ModDate'].strftime(
                '%Y%m%d-%H%M-%S')

            for x in metadata:
                # make sure every value of the metadata is a string
                if not isinstance(metadata[x], str):
                    metadata[x] = str(metadata[x])

                # add the metadata to the picture
                png_metadata.add_text(x, metadata[x])

            # save the picture again, this time including the metadata
            png_image.save(fig_fname_image, "png", pnginfo=png_metadata)

            # close matplotlib figure
            plt.close(plotfig)
            self.log.debug(
                'Time needed to save data: {0:.2f}s'.format(time.time() -
                                                            start_time))
            #----------------------------------------------------------------------------------

    def save_array_as_text(self,
                           data,
                           filename,
                           filepath='',
                           fmt='%.15e',
                           header='',
                           delimiter='\t',
                           comments='#',
                           append=False):
        """
        Write a 1D or 2D numpy array to a text file via ``numpy.savetxt``.

        Uses no other state of the object, so it can be called on its own.

        @param data: array-like handed straight to ``numpy.savetxt``.
        @param str filename: name of the target file.
        @param str filepath: directory that is joined in front of ``filename``.
        @param fmt: format specifier(s) forwarded to ``numpy.savetxt``.
        @param str header: header text, written after the ``comments`` prefix.
        @param str delimiter: column separator.
        @param str comments: prefix written in front of the header lines.
        @param bool append: if True the file is opened in binary append mode,
                            otherwise it is created/truncated.
        """
        # numpy.savetxt accepts binary file handles, hence the 'b' in both modes.
        open_mode = 'ab' if append else 'wb'
        target = os.path.join(filepath, filename)
        with open(target, open_mode) as handle:
            np.savetxt(handle,
                       data,
                       fmt=fmt,
                       delimiter=delimiter,
                       header=header,
                       comments=comments)

    def get_daily_directory(self):
        r"""
        Creates the daily directory.

          @return string: path to the daily directory.

        If the daily directory does not exist in the specified <root_dir> path
        in the config file, then it is created according to the following scheme:

            <root_dir>\<year>\<month>\<yearmonthday>

        An existing folder inside <root_dir>\<year>\<month> whose name starts
        with <yearmonthday> (e.g. '20240131_cooldown') is reused instead of
        creating a new one. A filepath is always returned.
        """
        # Read the clock once so year, month and day stay consistent even if
        # the call happens right around midnight.
        now = time.localtime()
        today = time.strftime("%Y%m%d", now)

        # Create the root data directory if it is missing.
        if not os.path.exists(self.data_dir):
            os.makedirs(self.data_dir, exist_ok=True)
            self.log.warning(
                'The specified Data Directory in the '
                'config file does not exist. Using default for '
                '{0} system instead. The directory {1} was '
                'created'.format(self.os_system, self.data_dir))

        # Month directory: <root_dir>/<year>/<month>
        month_dir = os.path.join(self.data_dir, time.strftime("%Y", now),
                                 time.strftime("%m", now))

        existing = None
        if os.path.exists(month_dir):
            # Only sub-folders qualify; sorting makes the choice deterministic
            # when several folders start with today's date.
            folderlist = sorted(
                d for d in os.listdir(month_dir)
                if os.path.isdir(os.path.join(month_dir, d)))
            # Match on the name prefix, e.g. '20240131' or '20240131_label'.
            existing = next(
                (entry for entry in folderlist if entry.startswith(today)),
                None)

        if existing is not None:
            return os.path.join(month_dir, existing)

        current_dir = os.path.join(month_dir, today)
        self.log.info('Creating directory for today\'s data in \n'
                      '{0}'.format(current_dir))

        # The exist_ok=True is necessary here to prevent Error 17 "File Exists"
        # Details at http://stackoverflow.com/questions/12468022/python-fileexists-error-when-making-directory
        os.makedirs(current_dir, exist_ok=True)
        return current_dir

    def get_path_for_module(self, module_name):
        """
        Method that creates a path for 'module_name' where data are stored.

        @param string module_name: Specify the folder, which should be created in the daily
                                   directory. The module_name can be e.g. 'Confocal'.
        @return string: path of the module folder inside the daily directory
                        (absolute whenever the daily directory path is absolute)
        """
        dir_path = os.path.join(self.get_daily_directory(), module_name)
        # exist_ok=True avoids the race between an existence check and the
        # creation (another process may create the folder in between, which
        # would otherwise raise FileExistsError).
        os.makedirs(dir_path, exist_ok=True)
        return dir_path

    def get_additional_parameters(self):
        """
        Return a shallow copy of the additional parameters dictionary.

        Handing out a copy keeps callers from changing the stored parameters
        in place.

        @return: copy of the additional parameters
        """
        stored = self._additional_parameters
        return stored.copy()

    def update_additional_parameters(self, *args, **kwargs):
        """
        Method to update one or multiple additional parameters

        @param dict args: Optional single positional argument holding parameters in a dict to
                          update additional parameters from. The passed dict is not modified.
        @param kwargs: Optional keyword arguments to be added to additional parameters. They
                       override entries of the same name in the positional dict.

        @raises TypeError: if more than one positional argument is given or it is not a dict.
        """
        if len(args) == 0:
            param_dict = dict(kwargs)
        elif len(args) == 1 and isinstance(args[0], dict):
            # Work on a copy so neither update() nor the netobtain conversion
            # below mutates the dictionary owned by the caller.
            param_dict = dict(args[0])
            param_dict.update(kwargs)
        else:
            raise TypeError(
                '"update_additional_parameters" takes exactly 0 or 1 positional '
                'argument of type dict.')

        # NOTE(review): netobtain presumably turns remote (rpyc) references
        # into local values before they are stored — confirm.
        self._additional_parameters.update(
            {key: netobtain(value) for key, value in param_dict.items()})

    def remove_additional_parameter(self, key):
        """
        Delete a single entry from the additional parameters.

        A key that is not present is ignored silently.

        @param str key: name of the additional parameter to drop
        """
        try:
            del self._additional_parameters[key]
        except KeyError:
            pass
def _new_survey_panel(fig, index):
    """Add panel *index* of the 2x3 survey-statistics grid with the shared axes styling."""
    ax = fig.add_subplot(2, 3, index)
    apply_tufte(ax, withgrid=False)
    ax.set_prop_cycle(cycler('color', get_distinct(3)))
    return ax


def _plot_survey_distribution(ax, values, sample_size, usekde, label):
    """
    Draw the normalised distribution of the 1D array *values* on *ax*.

    With ``usekde`` a kernel density estimate is evaluated on 200 points spanning the
    data range, using the bandwidth 1.06 * rse(values) * sample_size**(-0.2). Otherwise
    a step histogram with automatic binning and unit area (density=True) is drawn.
    """
    if usekde:
        bw = 1.06 * rse(values) * sample_size**(-0.2)
        kde = KernelDensity(bandwidth=bw)
        kde.fit(values[:, None])
        samples = np.linspace(values.min(), values.max(), 200)[:, None]
        # score_samples returns the log-density, hence the exp().
        logdens = kde.score_samples(samples)
        ax.plot(samples, np.exp(logdens), '-', lw=3, label=label)
    else:
        ax.hist(values,
                bins='auto',
                density=True,
                histtype='step',
                lw=3,
                label=label)


def _finish_survey_panel(ax, xlabel, ylabel, letter, **legend_kwargs):
    """Set the axis labels, add a legend with 14 pt entries and the bold panel letter."""
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    leg = ax.legend(**legend_kwargs)
    for t in leg.get_texts():
        t.set_fontsize(14)
    ax.text(0.025,
            0.9,
            letter,
            horizontalalignment='center',
            verticalalignment='center',
            transform=ax.transAxes,
            weight='bold',
            fontsize=30)


def showSurveyStatistics(simulatedSurvey,
                         pdfFile=None,
                         pngFile=None,
                         usekde=False):
    """
    Produce a plot with the survey statistics.

    Panels: (a) parallax distribution, (b) apparent magnitude distribution,
    (c) absolute magnitude distribution, (d) parallax errors versus true parallax,
    (e) absolute magnitude versus true parallax.

    Parameters
    ----------

    simulatedSurvey : Object containing the simulated survey.

    Keywords
    --------

    pdfFile : string
        Name of optional PDF file in which to save the plot.
    pngFile : string
        Name of optional PNG file in which to save the plot.
    usekde  : boolean
        If true use kernel density estimates to show the distribution of survey quantities instead of
        histograms.

    If neither file name is given the figure is shown interactively; otherwise it is
    saved and then closed.
    """
    try:
        _ = simulatedSurvey.observedParallaxes.shape
    except AttributeError:
        stderr.write("You have not generated the observations yet!\n")
        return

    parLimitPlot = 50.0
    plxSnrLim = 5.0
    nSurvey = simulatedSurvey.numberOfStarsInSurvey
    snrLabel = r'$\varpi/\sigma_\varpi\geq{0:.1f}$'.format(plxSnrLim)
    mlimLabel = r'$m_\mathrm{{lim}}={0}$'.format(
        simulatedSurvey.apparentMagnitudeLimit)

    positiveParallaxes = (simulatedSurvey.observedParallaxes > 0.0)
    goodParallaxes = (simulatedSurvey.observedParallaxes /
                      simulatedSurvey.parallaxErrors >= plxSnrLim)

    useagab(usetex=False, fontfam='sans')
    fig = plt.figure(figsize=(27, 12))

    # (a) Parallaxes: model curve proportional to varpi**-4, normalised between
    # minParallax and parLimitPlot.
    # NOTE(review): the normalisation uses parLimitPlot rather than maxParallax — confirm intended.
    axA = _new_survey_panel(fig, 1)
    minPMinThird = np.power(simulatedSurvey.minParallax, -3.0)
    maxPMinThird = np.power(parLimitPlot, -3.0)
    x = np.linspace(simulatedSurvey.minParallax,
                    np.min([parLimitPlot, simulatedSurvey.maxParallax]), 1001)
    axA.plot(x,
             3.0 * np.power(x, -4.0) / (minPMinThird - maxPMinThird),
             '--',
             label='model',
             lw=3)
    _plot_survey_distribution(axA, simulatedSurvey.trueParallaxes, nSurvey,
                              usekde, 'true')
    _plot_survey_distribution(axA, simulatedSurvey.observedParallaxes, nSurvey,
                              usekde, 'observed')
    _finish_survey_panel(axA,
                         r'$\varpi$,  $\varpi_\mathrm{true}$ [mas]',
                         r'$p(\varpi)$, $p(\varpi_\mathrm{true})$',
                         'a',
                         loc='best',
                         handlelength=1.0)

    # (b) Apparent magnitudes.
    axB = _new_survey_panel(fig, 2)
    m = np.linspace(simulatedSurvey.observedMagnitudes.min(),
                    simulatedSurvey.observedMagnitudes.max(), 1000)
    axB.plot(m,
             np.exp(simulatedSurvey.apparentMagnitude_lpdf(m)),
             '--',
             lw=3,
             label='model')
    _plot_survey_distribution(axB, simulatedSurvey.apparentMagnitudes, nSurvey,
                              usekde, 'true')
    _plot_survey_distribution(axB, simulatedSurvey.observedMagnitudes, nSurvey,
                              usekde, 'observed')
    _finish_survey_panel(axB,
                         r"$m$, $m_\mathrm{true}$",
                         r"$p(m)$, $p(m_\mathrm{true})$",
                         'b',
                         loc=(0.03, 0.55),
                         handlelength=1.0)

    # (c) Absolute magnitudes: Gaussian model, all stars, and the stars that
    # pass the parallax signal-to-noise cut.
    axC = _new_survey_panel(fig, 3)
    x = np.linspace(simulatedSurvey.absoluteMagnitudes.min(),
                    simulatedSurvey.absoluteMagnitudes.max(), 300)
    axC.plot(x,
             norm.pdf(x,
                      loc=simulatedSurvey.meanAbsoluteMagnitude,
                      scale=simulatedSurvey.stddevAbsoluteMagnitude),
             '--',
             lw=3,
             label='model')
    _plot_survey_distribution(axC, simulatedSurvey.absoluteMagnitudes, nSurvey,
                              usekde, 'true')
    goodAbsMags = simulatedSurvey.absoluteMagnitudes[goodParallaxes]
    # The high-S/N subset is only drawn when at least 3 stars pass the cut;
    # its KDE bandwidth uses the subset size, not the survey size.
    if goodAbsMags.size >= 3:
        _plot_survey_distribution(axC, goodAbsMags, goodAbsMags.size, usekde,
                                  snrLabel)
    _finish_survey_panel(axC,
                         "$M$",
                         "$p(M)$",
                         'c',
                         loc=(0.03, 0.55),
                         handlelength=1.0)

    # (d) Parallax error versus true parallax.
    axD = _new_survey_panel(fig, 4)
    axD.plot(simulatedSurvey.trueParallaxesNoLim,
             simulatedSurvey.observedParallaxesNoLim -
             simulatedSurvey.trueParallaxesNoLim,
             'k,',
             label=r'$m_\mathrm{lim}=\infty$')
    axD.plot(simulatedSurvey.trueParallaxes,
             simulatedSurvey.observedParallaxes -
             simulatedSurvey.trueParallaxes,
             '.',
             label=mlimLabel)
    axD.plot(simulatedSurvey.trueParallaxes[positiveParallaxes],
             simulatedSurvey.observedParallaxes[positiveParallaxes] -
             simulatedSurvey.trueParallaxes[positiveParallaxes],
             '.',
             label=r'$\varpi>0$')
    axD.plot(simulatedSurvey.trueParallaxes[goodParallaxes],
             simulatedSurvey.observedParallaxes[goodParallaxes] -
             simulatedSurvey.trueParallaxes[goodParallaxes],
             'o',
             label=snrLabel)
    _finish_survey_panel(axD,
                         r"$\varpi_\mathrm{true}$ [mas]",
                         "$\\varpi-\\varpi_\\mathrm{true}$ [mas]",
                         'd',
                         loc='best',
                         handlelength=0.5,
                         ncol=2)

    # (e) Absolute magnitude versus true parallax, with the mean absolute
    # magnitude marked as a horizontal line.
    axE = _new_survey_panel(fig, 5)
    axE.plot(simulatedSurvey.trueParallaxesNoLim,
             simulatedSurvey.absoluteMagnitudesNoLim,
             'k,',
             label=r'$m_\mathrm{lim}=\infty$')
    axE.plot(simulatedSurvey.trueParallaxes,
             simulatedSurvey.absoluteMagnitudes,
             '.',
             label=mlimLabel)
    axE.plot(simulatedSurvey.trueParallaxes[positiveParallaxes],
             simulatedSurvey.absoluteMagnitudes[positiveParallaxes],
             '.',
             label=r'$\varpi>0$')
    axE.plot(simulatedSurvey.trueParallaxes[goodParallaxes],
             simulatedSurvey.absoluteMagnitudes[goodParallaxes],
             'o',
             label=snrLabel)
    axE.axhline(y=simulatedSurvey.meanAbsoluteMagnitude)
    _finish_survey_panel(axE,
                         r"$\varpi_\mathrm{true}$ [mas]",
                         "$M_\\mathrm{true}$",
                         'e',
                         loc='best',
                         handlelength=0.5,
                         ncol=2)

    plt.suptitle(
        "Simulated survey statistics: $N_\\mathrm{{stars}}={0}$, ".format(
            simulatedSurvey.numberOfStars) +
        "$m_\\mathrm{{lim}}={0}$, ".format(
            simulatedSurvey.apparentMagnitudeLimit) +
        "$N_\\mathrm{{survey}}={0}$, ".format(
            simulatedSurvey.numberOfStarsInSurvey) +
        "${0}\\leq\\varpi\\leq{1}$, ".format(simulatedSurvey.minParallax,
                                             simulatedSurvey.maxParallax) +
        "$\\mu_M={0}$, ".format(simulatedSurvey.meanAbsoluteMagnitude) +
        "$\\sigma_M={0:.2f}$".format(simulatedSurvey.stddevAbsoluteMagnitude))

    if pdfFile is not None:
        plt.savefig(pdfFile)
    if pngFile is not None:
        plt.savefig(pngFile)
    if pdfFile is None and pngFile is None:
        plt.show()
    else:
        # The figure has been written to disk; release it so repeated calls
        # do not accumulate open figures.
        plt.close(fig)
示例#54
0
# Further entries of the muted palette, inserted in this order.
# NOTE(review): 'red' and 'orange' are read from `colors` below but are not set
# in this section; presumably they are added to colors_muted earlier — confirm.
for _colour_name, _colour_hex in (
        ('lightgreen', '#AAB71B'),
        ('green', '#408020'),
        ('darkgreen', '#007030'),
        ('cyan', '#40A787'),
        ('lightblue', '#008797'),
        ('blue', '#2060A7'),
        ('purple', '#53379B'),
        ('magenta', '#873770'),
        ('pink', '#D03050'),
        ('white', '#FFFFFF'),
        ('gray', '#A0A0A0'),
        ('black', '#000000'),
):
    colors_muted[_colour_name] = _colour_hex
del _colour_name, _colour_hex

colors = colors_muted
# Colour order used for successive lines on every new Axes.
plt.rcParams['axes.prop_cycle'] = cycler(color=[
    colors[_name]
    for _name in ('blue', 'red', 'lightgreen', 'orange', 'cyan', 'magenta')
])

# Global defaults: PNG output at 200 dpi, constrained layout, 5 pt font,
# thin lines, no axis margins, no top/right spines, frameless legends.
plt.rcParams.update({
    'savefig.format': 'png',
    'savefig.dpi': 200.0,
    'figure.constrained_layout.use': True,
    'font.size': 5.0,
    'lines.linewidth': 0.8,
    'lines.markersize': 4,
    'axes.xmargin': 0.0,
    'axes.ymargin': 0.0,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'legend.frameon': False,
    'legend.borderpad': 0.0,
})
def setStyle(palette='default', bigPlot=False):
    '''
    Configure matplotlib's global plotting style.

    Select a colour palette by name and optionally switch to the larger
    font and marker sizes meant for big plots. Call it once, before plotting
    anything.

    Available colour palettes:

    - classic (default): a classic colourful palette with strong colours and contrast.
    - modified classic: similar to the classic, with slightly different colours.
    - autumn: a slightly darker autumn style colour palette.
    - purples: a pseudo sequential purple colour palette (not great for contrast).
    - greens: a pseudo sequential green colour palette (not great for contrast).

    Each palette holds twelve colours, cycled in lockstep with twelve line
    styles and twelve markers.

    Parameters
    ----------
    palette: str
        Name of the colour palette.
    bigPlot: bool
        If true, use font size 30 and marker size 18 instead of 15 and 8.

    Raises
    ------
    KeyError if provided palette does not exist.
    '''
    classic = ['#ba2c54', '#5B90DC', '#FFAB44', '#0C9FB3', '#57271B', '#3B507D',
               '#794D88', '#FD6989', '#8A978E', '#3B507D', '#D8153C', '#cc9214']
    # Insertion order matters: it is the order listed in the KeyError message.
    palettes = {
        'classic': classic,
        'modified classic': ['#D6088F', '#424D9C', '#178084', '#AF99DA', '#F58D46', '#634B5B',
                             '#0C9FB3', '#7C438A', '#328cd6', '#8D0F25', '#8A978E', '#ffcb3d'],
        'autumn': ['#A9434D', '#4E615D', '#3C8DAB', '#A4657A', '#424D9C', '#DC575A',
                   '#1D2D38', '#634B5B', '#56276D', '#577580', '#134663', '#196096'],
        'purples': ['#a57bb7', '#343D80', '#EA60BF', '#B7308E', '#E099C3', '#7C438A',
                    '#AF99DA', '#4D428E', '#56276D', '#CC4B93', '#DC4E76', '#5C4AE4'],
        'greens': ['#268F92', '#abc14d', '#8A978E', '#0C9FB3', '#BDA962', '#B0CB9E',
                   '#769168', '#5E93A5', '#178084', '#B7BBAD', '#163317', '#76A63F'],
        'default': classic,
    }

    markers = ['o', 's', 'v', '^', '*', 'P', 'd', 'X', 'p', '<', '>', 'h']
    # Dash patterns as (offset, (on, off, ...)) tuples.
    linestyles = [
        (0, ()),                    # solid
        (0, (1, 1)),                # densely dotted
        (0, (3, 1, 1, 1)),          # densely dashdotted
        (0, (5, 5)),                # dashed
        (0, (3, 1, 1, 1, 1, 1)),    # densely dashdotdotted
        (0, (5, 1)),                # densely dashed
        (0, (1, 5)),                # dotted
        (0, (3, 5, 1, 5)),          # dashdotted
        (0, (3, 5, 1, 5, 1, 5)),    # dashdotdotted
        (0, (5, 10)),               # loosely dashed
        (0, (1, 10)),               # loosely dotted
        (0, (3, 10, 1, 10)),        # loosely dashdotted
    ]

    if palette not in palettes:
        raise KeyError('palette must be one of {}'.format(', '.join(palettes)))

    font_size = 30 if bigPlot else 15
    marker_size = 18 if bigPlot else 8

    plt.rc('lines', linewidth=2, markersize=marker_size)
    style_cycle = (cycler(color=palettes[palette])
                   + cycler(linestyle=linestyles)
                   + cycler(marker=markers))
    plt.rc('axes', prop_cycle=style_cycle)
    plt.rc('axes',
           titlesize=font_size,
           labelsize=font_size,
           labelpad=5,
           grid=True,
           axisbelow=True)
    plt.rc('xtick', labelsize=font_size)
    plt.rc('ytick', labelsize=font_size)
    plt.rc('legend', loc='best', shadow=False, fontsize='medium')
    plt.rc('font', family='serif', size=font_size)
        std = np.std(vals, axis=0) / np.sqrt(vals.shape[0])
        ax1 = plt.plot(batches, mean, label="p={}".format(p))
        ax1_col = ax1[0].get_color()
        plt.fill_between(batches,
                         mean - 2 * std,
                         mean + 2 * std,
                         alpha=0.15,
                         color=ax1_col)
    plt.legend()
    plt.xlabel("Checkpoint ($T$)")
    plt.ylabel("R_(CUCB2)({}) - R_(UCB)({})".format(T, T))


# One colour per curve from the qualitative 'Set1' colormap, paired with a line
# style. Adding two cyclers requires them to have the same length, so the line
# styles are generated from n instead of being listed by hand; for n = 9 this
# yields '-', '--', ':', '-.', '-', '--', ':', '-.', '-'.
n = 9  # Number of colors
new_colors = [plt.get_cmap('Set1')(1. * i / n) for i in range(n)]
_base_linestyles = ['-', '--', ':', '-.']
linestyle_cycler = cycler(
    'linestyle',
    [_base_linestyles[i % len(_base_linestyles)] for i in range(n)])
plt.rc('axes', prop_cycle=(cycler('color', new_colors) + linestyle_cycler))
plt.rc('lines', linewidth=2)

# Results file: the first command-line argument, or a default batch pickle.
if len(sys.argv) == 1:
    filename = 'COMP_20190825_033627_batch_results.pickle'
else:
    filename = sys.argv[1]

SUMMARIZE = False

print("Opening file %s..." % filename)
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# load result files produced by a trusted run.
with open(filename, 'rb') as f:
    results = pickle.load(f)
print("Done.\n")
示例#57
0
def huescale(number_of_plots, *option):
    """Install a colour-blind-friendly colour cycle and return its colours.

    For up to three plots, one of several three-colour hue families is used,
    chosen by an optional hue name (a default orange/yellow family when none
    is given). For four to nine plots, a fixed ochre sequential scheme is
    used regardless of hue. For any other count, the current
    ``axes.prop_cycle`` colours are returned unchanged.

    Parameters
    ----------
    number_of_plots : int
        Number of curves that will be drawn.
    *option : str
        Optional hue name: 'blue', 'bluegreen', 'green', 'gold', 'brown',
        'rose' or 'purple'. Unrecognised values are ignored. If several are
        given, the last recognised one wins.

    Returns
    -------
    list
        The colours of the selected scheme.

    Notes
    -----
    Sets ``plt.rcParams['axes.prop_cycle']`` whenever a scheme from the
    tables below is selected; the fallback path leaves it untouched.
    """
    # Three-colour families for <= 3 plots, keyed by hue name.
    hue_schemes = {
        'blue': ['#114477', '#4477AA', '#77AADD'],
        'bluegreen': ['#117777', '#44AAAA', '#77CCCC'],
        'green': ['#117744', '#44AA77', '#88CCAA'],
        'gold': ['#777711', '#AAAA44', '#DDDD77'],
        'brown': ['#774411', '#AA7744', '#DDAA77'],
        'rose': ['#771122', '#AA4455', '#DD7788'],
        'purple': ['#771155', '#AA4488', '#CC99BB'],
    }
    default_scheme = ['#D95F0E', '#FEC44F', '#FFF7BC']

    # Ochre sequential schemes, keyed by the exact number of plots.
    ochre_schemes = {
        4: ['#CC4C02', '#FB9A29', '#FED98E', '#FFFBD5'],
        5: ['#993404', '#D95F0E', '#FB9A29', '#FED98E', '#FFFBD5'],
        6: ['#993404', '#D95F0E', '#FB9A29', '#FEC44F', '#FEE391',
            '#FFFBD5'],
        7: ['#8C2D04', '#CC4C02', '#EC7014', '#FB9A29', '#FEC44F',
            '#FEE391', '#FFFBD5'],
        8: ['#8C2D04', '#CC4C02', '#EC7014', '#FB9A29', '#FEC44F',
            '#FEE391', '#FFF7BC', '#FFFFE5'],
        9: ['#662506', '#993404', '#CC4C02', '#EC7014', '#FB9A29',
            '#FEC44F', '#FEE391', '#FFF7BC', '#FFFFE5'],
    }

    # Last recognised hue wins; None means "no hue requested". The isinstance
    # guard keeps unhashable options from breaking the dict lookup.
    hue = None
    for argument in option:
        if isinstance(argument, str) and argument in hue_schemes:
            hue = argument

    if number_of_plots <= 3:
        colorscheme = hue_schemes.get(hue, default_scheme)
    elif number_of_plots in ochre_schemes:
        colorscheme = ochre_schemes[number_of_plots]
    else:
        # More than 9 plots, or a count with no table entry (e.g. 4.5):
        # keep whatever cycle is currently configured.
        colorscheme = plt.rcParams['axes.prop_cycle'].by_key()['color']
        print(
            "OUT OF RANGE[1-9] : COLORBLIND MODE DEACTIVATED ---> DEFAULT MODE"
        )
        return colorscheme

    plt.rcParams['axes.prop_cycle'] = cycler('color', colorscheme)
    if number_of_plots > 3 and hue is not None:
        # Hues only apply to the small-count families.
        print("ONLY OCHERSCALE FOR MORE THAN 3 PLOTS")
    return colorscheme
示例#58
0
def set_rcParams_scvelo(fontsize=12, color_map=None, frameon=None):
    """Set matplotlib.rcParams to scvelo defaults.

    Parameters
    ----------
    fontsize : float, optional
        Base font size, used for ``font.size`` and ``axes.titlesize``. Legend,
        axis-label and tick-label sizes are set to 92% of it.
    color_map : str, optional
        Value for ``rcParams['image.cmap']``; ``'RdBu_r'`` when None.
    frameon : bool, optional
        Stored in the module-level ``_frameon`` flag; ``False`` when None.
    """

    # dpi options (mpl default: 100, 100)
    rcParams['figure.dpi'] = 100
    rcParams['savefig.dpi'] = 150

    # figure (mpl default: 0.125, 0.96, 0.15, 0.91)
    rcParams['figure.figsize'] = (6, 4)
    rcParams['figure.subplot.left'] = 0.18
    rcParams['figure.subplot.right'] = 0.96
    rcParams['figure.subplot.bottom'] = 0.15
    rcParams['figure.subplot.top'] = 0.91

    # lines (defaults:  1.5, 6, 1)
    rcParams['lines.linewidth'] = 1.5  # width of plotted lines (frame width is axes.linewidth)
    rcParams['lines.markersize'] = 6
    rcParams['lines.markeredgewidth'] = 1

    # font
    rcParams['font.sans-serif'] = \
        ['Arial', 'Helvetica', 'DejaVu Sans',
         'Bitstream Vera Sans', 'sans-serif']

    # Secondary text (legend, axis labels, tick labels) is slightly smaller.
    labelsize = 0.92 * fontsize

    # fontsizes (mpl default: 10, medium, large, medium)
    rcParams['font.size'] = fontsize
    rcParams['legend.fontsize'] = labelsize
    rcParams['axes.titlesize'] = fontsize
    rcParams['axes.labelsize'] = labelsize

    # legend (mpl default: 1, 1, 2, 0.8)
    rcParams['legend.numpoints'] = 1
    rcParams['legend.scatterpoints'] = 1
    rcParams['legend.handlelength'] = 0.5
    rcParams['legend.handletextpad'] = 0.4

    # color cycle
    rcParams['axes.prop_cycle'] = cycler(color=vega_10)

    # axes
    rcParams['axes.linewidth'] = 0.8
    rcParams['axes.edgecolor'] = 'black'
    rcParams['axes.facecolor'] = 'white'

    # ticks (mpl default: k, k, medium, medium)
    rcParams['xtick.color'] = 'k'
    rcParams['ytick.color'] = 'k'
    rcParams['xtick.labelsize'] = labelsize
    rcParams['ytick.labelsize'] = labelsize

    # axes grid (mpl default: False, #b0b0b0)
    rcParams['axes.grid'] = False
    rcParams['grid.color'] = '.8'

    # color map
    rcParams['image.cmap'] = 'RdBu_r' if color_map is None else color_map

    # frame (mpl default: True); kept in a module-level flag, not in rcParams
    frameon = False if frameon is None else frameon
    global _frameon
    _frameon = frameon
示例#59
0
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    'ex', 'graph', 'circuit', 'tn', 'peo',
    'sim_costs', 'sum_flops', 'step_flops', 'max_mem',
    'SEED', 'EDGE_IDX_FOR_SEED', 'EDGE_IDX_FOR_SEED_JLSE',
    'sim_profile', 'step_sim_time', 'plot_with_filter',
    'get_log_flops_vs_matmul', 'time_vs_flops_plot',
]

# Cell
import sys
import numpy as np
import matplotlib.pyplot as plt

import qtensor as qt
from cartesian_explorer import Explorer

# Cell
import matplotlib as mpl
from cycler import cycler
# Two-colour default line cycle for all plots created after this cell.
mpl.rcParams['axes.prop_cycle'] = cycler('color', ['#db503d', '#02C6E0'])

# Cell
# Explorer instance that the ``@ex.provider`` functions register with.
ex = Explorer()

# Cell
@ex.provider
def graph(n, d, seed):
    """Return ``qt.toolbox.random_graph(nodes=n, degree=d, seed=seed)``.

    Parameter names are left untouched because ``@ex.provider`` presumably
    resolves inputs by argument name -- confirm against cartesian_explorer.
    """
    make_graph = qt.toolbox.random_graph
    return make_graph(nodes=n, degree=d, seed=seed)

@ex.provider
def circuit(graph, edge_idx, p, composer_type='cone'):
    gamma, beta = [.1]*p, [.3]*p
    if composer_type=='cylinder':
        comp = qt.OldQtreeQAOAComposer(graph, gamma=gamma, beta=beta)
    if composer_type=='cone':
    def __init__(self,
                 hourly_data,
                 label,
                 yearly_ax,
                 monthly_ax,
                 daily_ax,
                 agg_by_day=None,
                 agg_by_month=None,
                 style_cycle=None):
        '''Class to manage 3-levels of aggregated temperature

        Parameters
        ----------
        hourly_data : DataFrame
            Tempreture measured hourly

        label : str
            The name of this data set_a

        yearly_ax : Axes
            The axes to plot 'year' scale data (aggregated by month) to

        monthly_ax : Axes
            The axes to plot 'month' scale data (aggregated by day) to

        daily_ax : Axes
            The axes to plot 'day' scale data (un-aggregated hourly) to

        agg_by_day : DataFrame, optional

            Data already aggregated by day.  This is just to save
            computation, will be computed if not provided.

        agg_by_month : DataFrame, optional

            Data already aggregated by month.  This is just to save
            computation, will be computed if not provided.

        style_cycle : Cycler, optional
            Style to use for plotting

        '''
        # data
        self.data_by_hour = hourly_data
        if agg_by_day is None:
            agg_by_day = aggregate_by_day(hourly_data)
        self.data_by_day = agg_by_day
        if agg_by_month is None:
            agg_by_month = aggregate_by_month(hourly_data)
        self.data_by_month = agg_by_month
        # style
        if style_cycle is None:
            style_cycle = (
                (cycler('marker',
                        ['o', 's', '^', '*', 'x', 'v', '8', 'D', 'H', '<']) +
                 cycler('color', [
                     '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
                     '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'
                 ])))
        self.style_cycle = style_cycle()
        # axes
        self.yearly_ax = yearly_ax
        self.monthly_ax = monthly_ax
        self.daily_ax = daily_ax
        # name
        self.label = label
        # these will be used for book keeping
        self.daily_artists = {}
        self.daily_index = {}
        self.hourly_artiists = {}
        # artists
        self.yearly_art = plot_aggregated_errorbar(self.yearly_ax,
                                                   self.data_by_month,
                                                   self.label,
                                                   picker=5,
                                                   **next(self.style_cycle))

        # pick methods
        self.y_cid = self.yearly_ax.figure.canvas.mpl_connect(
            'pick_event', self._yearly_on_pick)
        self.y_cid = self.yearly_ax.figure.canvas.mpl_connect(
            'pick_event', self._monthly_on_pick)
        self.y_cid = self.yearly_ax.figure.canvas.mpl_connect(
            'pick_event', self._daily_on_pick)