コード例 #1
0
def independent_hyperparameter(name, values_expr, value=None):
    """Draw an independent hyperparameter as a labelled ellipse.

    The ellipse carries the bold hyperparameter ``name``. A bold annotation is
    placed at the ellipse's right-centre: the candidate values
    ``[values_expr]`` styled with ``unassigned_h_s_fmt`` when ``value`` is
    None, otherwise ``value`` itself styled with ``assigned_h_s_fmt``.

    Returns:
        [ellipse, name_label, annotation_label]
    """
    ellipse = stz.ellipse([0, 0], h_width / 2.0, h_height / 2.0, hyperp_s_fmt)
    name_label = stz.latex(stz.center_coords(ellipse), "\\textbf{%s}" % name)
    right_cs = stz.coords_from_bbox_with_fn(ellipse, stz.right_center_coords)
    if value is not None:
        annotation = stz.latex(right_cs, "\\textbf{%s}" % value,
                               assigned_h_s_fmt)
    else:
        annotation = stz.latex(right_cs, "\\textbf{[%s]}" % (values_expr, ),
                               unassigned_h_s_fmt)
    return [ellipse, name_label, annotation]
コード例 #2
0
def segment(length, label_str, left_tick_label_str, right_tick_label_str):
    """Draw a horizontal segment of ``length`` with labelled end ticks.

    Vertical ticks of ``tick_length`` are centred on both endpoints. Each tick
    label is placed above its tick with ``tick_label_spacing``. ``label_str``
    is placed ``label_spacing`` to the left of the segment's start.

    Returns:
        [segment, left_tick, right_tick, left_tick_label, right_tick_label,
         segment_label]
    """
    line = stz.line_segment([0, 0], [length, 0])
    tick_l, tick_r = [
        stz.centered_vertical_line_segment([x, 0], tick_length)
        for x in (0, length)
    ]
    text_l, text_r = [
        stz.latex([0, 0], s)
        for s in (left_tick_label_str, right_tick_label_str)
    ]
    for text, tick in ((text_l, tick_l), (text_r, tick_r)):
        stz.place_above_and_align_to_the_center(text, tick, tick_label_spacing)
    caption = stz.latex([-label_spacing, 0], label_str)
    return [line, tick_l, tick_r, text_l, text_r, caption]
コード例 #3
0
def output(name):
    """Draw an output port as a trapezoid labelled with ``name``.

    The edge at y = io_height spans ``io_long_side``. The edge at y = 0 spans
    ``io_short_side``. Both are centred on x = 0.

    Returns:
        [trapezoid, label]
    """
    half_long = io_long_side / 2.0
    half_short = io_short_side / 2.0
    corners = [
        [-half_long, io_height],
        [half_long, io_height],
        [half_short, 0],
        [-half_short, 0],
    ]
    shape = stz.closed_path(corners, output_s_fmt)
    return [shape, stz.latex(stz.center_coords(shape), name)]
コード例 #4
0
def module(module_name,
           input_names,
           output_names,
           hyperp_names,
           p_width_scale=1.0):
    """Draw a module box with its input ports, output ports and hyperparameters.

    Layout within the box, whose top-left corner is at the origin:
    - Inputs are laid out in a row, with their top-left at the inner top-left
      corner.
    - Outputs are laid out in a row, with their bottom-left at the inner
      bottom-left corner.
    - Hyperparameters, if any, are stacked vertically at the inner top-right,
      lowered by an extra ``delta_increment``.

    Args:
        module_name: text shown in bold at the centre of the box.
        input_names: one label per input port. Each is passed to ``input``,
            presumably a file-level drawing helper that shadows the builtin.
        output_names: one label per output port (see ``output``).
        hyperp_names: one label per hyperparameter (see ``property``).
        p_width_scale: width scale forwarded to ``property``.

    Returns:
        ``[[box, title], inputs, outputs, hyperps]`` when ``hyperp_names`` is
        non-empty, otherwise ``[[box, title], inputs, outputs]``.
    """
    inputs = [input(s) for s in input_names]
    outputs = [output(s) for s in output_names]
    box = stz.rectangle([0, 0], [module_width, -module_height], module_s_fmt)
    title = stz.latex(stz.center_coords(box), "\\textbf{%s}" % module_name)
    inset = module_inner_vertical_spacing

    stz.distribute_horizontally_with_spacing(inputs, io_spacing)
    stz.translate_bbox_top_left_to_coords(inputs, [inset, -inset])
    stz.distribute_horizontally_with_spacing(outputs, io_spacing)
    stz.translate_bbox_bottom_left_to_coords(
        outputs, [inset, -module_height + inset])

    drawn = [[box, title], inputs, outputs]
    if not hyperp_names:
        return drawn

    hyperps = [property(s, p_width_scale) for s in hyperp_names]
    stz.distribute_vertically_with_spacing(hyperps, p_spacing)
    stz.translate_bbox_top_right_to_coords(
        hyperps, [module_width - inset, -inset - delta_increment])
    drawn.append(hyperps)
    return drawn
コード例 #5
0
ファイル: transformer.py プロジェクト: sidney1994/sane_tikz
def rectangle_with_text(color_name, s_lst):
    """Draw a ``color_name`` rectangle with the lines of ``s_lst`` centred in it.

    The rectangle, drawn by the file's ``rectangle`` helper, has height
    ``0.1 + per_line_height * len(s_lst)``. The lines are joined with a LaTeX
    line break preceded by a -0.05cm vertical space, which tightens them.
    The text is wrapped to ``width`` and centre-aligned.

    Returns:
        [rectangle, text]
    """
    text_fmt = fmt.combine_tikz_strs([fmt.text_width(width), "align=center"])
    box = rectangle(0.1 + per_line_height * len(s_lst), color_name)
    body = " \\vspace{-0.05cm} \\linebreak ".join(s_lst)
    return [box, stz.latex(stz.center_coords(box), body, text_fmt)]
コード例 #6
0
def plate(e_lst, expr):
    """Draw a plate (enclosing rectangle) around ``e_lst``, labelled ``expr``.

    The rectangle is the bounding box of ``e_lst``, resized additively by
    ``2 * horizontal_plate_spacing`` and ``2 * vertical_plate_spacing``
    (presumably one spacing per side). The label sits at the rectangle's
    bottom-right corner, offset along the antidiagonal by
    ``-plate_label_spacing``.

    Returns:
        [rectangle, label]
    """
    inner_tl, inner_br = stz.bbox(e_lst)
    grow_x = 2.0 * horizontal_plate_spacing
    grow_y = 2.0 * vertical_plate_spacing
    outline = stz.rectangle_from_additive_resizing(inner_tl, inner_br, grow_x,
                                                   grow_y)
    corner_br = stz.bbox(outline)[1]
    label_cs = stz.translate_coords_antidiagonally(corner_br,
                                                   -plate_label_spacing)
    return [outline, stz.latex(label_cs, expr)]
コード例 #7
0
ファイル: xlnet.py プロジェクト: sidney1994/sane_tikz
def rectangle_with_latex(width,
                         expr,
                         rectangle_tikz_str="",
                         latex_tikz_str=""):
    """Draw a rounded rectangle of the given ``width`` with ``expr`` centred in it.

    The rectangle is ``square_side`` tall. Its corners are rounded by
    ``roundness_in_cm`` and it is drawn with ``line_width``, plus any extra
    TikZ options in ``rectangle_tikz_str``. The text node uses
    ``latex_tikz_str``.

    Returns:
        [rectangle, latex]
    """
    box_style = fmt.combine_tikz_strs([
        fmt.rounded_corners(roundness_in_cm),
        fmt.line_width(line_width),
        rectangle_tikz_str,
    ])
    text_style = fmt.combine_tikz_strs([latex_tikz_str])
    box = stz.rectangle_from_width_and_height([0, 0], square_side, width,
                                              box_style)
    text = stz.latex(stz.center_coords(box), expr, text_style)
    return [box, text]
コード例 #8
0
def dependent_hyperparameter(name, hyperp_names, fn_expr, value=None):
    """Draw a dependent hyperparameter as a labelled ellipse.

    When ``value`` is None:
    - The ellipse is widened by a factor of 2.1.
    - Its bold ``name`` is shifted 0.1 to the right.
    - A small "x" property bubble is placed just inside its left edge.
    - ``fn: fn_expr`` is shown at its right-centre.

    Otherwise, ``value`` is shown at its right-centre.

    Args:
        name: hyperparameter name, rendered in bold inside the ellipse.
        hyperp_names: not used by the drawing; kept for interface
            compatibility.
        fn_expr: text describing the function that computes the value.
        value: the assigned value, or None if not yet assigned.

    Returns:
        ``[ellipse, name_label, fn_label, x_property]`` when ``value`` is
        None, otherwise ``[ellipse, name_label, value_label]``.
    """
    e = stz.ellipse([0, 0], h_width / 2.0, h_height / 2.0, hyperp_s_fmt)
    unassigned = value is None
    if unassigned:
        # Widen by 2.1x. The former ``r *= 2.1 * r`` produced 2.1 * r**2,
        # which is dimensionally wrong and only equals 2.1 * r when r == 1.
        e["horizontal_radius"] *= 2.1

    l_cs = stz.center_coords(e)
    if unassigned:
        # Presumably shifted to leave room for the "x" bubble on the left.
        l_cs = stz.translate_coords_horizontally(l_cs, 0.1)
    l = stz.latex(l_cs, "\\textbf{%s}" % name)

    right_cs = stz.coords_from_bbox_with_fn(e, stz.right_center_coords)
    if not unassigned:
        l_v = stz.latex(right_cs, "\\textbf{%s}" % value, assigned_h_s_fmt)
        return [e, l, l_v]

    l_fn = stz.latex(right_cs, "\\textbf{fn: %s}" % (fn_expr, ),
                     unassigned_h_s_fmt)
    # Small "x" bubble (width scale 0.25, height scale 0.7) whose left-centre
    # sits 0.1 to the right of the ellipse's left-centre.
    p = property("x", 0.25, 0.7)
    p_cs = stz.translate_coords_horizontally(
        stz.coords_from_bbox_with_fn(e, stz.left_center_coords), 0.1)
    stz.translate_bbox_left_center_to_coords(p, p_cs)
    return [e, l, l_fn, p]
コード例 #9
0
def label_small_node(label, e_small, angle, distance, label_width):
    """Attach a grey call-out label to the small node ``e_small``.

    The call-out has three parts:
    - A rounded grey box, ``label_width`` wide and ``1.8 * node_radius``
      tall, centred at ``distance`` from the node's centre in direction
      ``angle``. Angles are in degrees, given the ``+ 180.0`` below.
    - The text, at the box centre.
    - A grey pointer triangle with three vertices: one at
      ``filled_node_radius`` from the node centre towards the box, one at
      the box centre, and one ``triangle_radius`` from the box centre
      pointing back towards the node, rotated by ``triangle_angle_delta``.

    Returns:
        [box, pointer_triangle, text]
    """
    grey = fmt.fill_color_with_no_line('gray!20')
    box_fmt = fmt.combine_tikz_strs(
        [fmt.rounded_corners(rectangle_roundness), grey])
    box = stz.rectangle_from_width_and_height([0, 0], 1.8 * node_radius,
                                              label_width, box_fmt)
    node_cs = stz.center_coords(e_small)
    box_cs = stz.coords_on_circle(node_cs, distance, angle)
    stz.translate_bbox_center_to_coords(box, box_cs)
    text = stz.latex(box_cs, label)
    pointer = stz.closed_path([
        stz.coords_on_circle(node_cs, filled_node_radius, angle),
        box_cs,
        stz.coords_on_circle(box_cs, triangle_radius,
                             angle + 180.0 + triangle_angle_delta),
    ], grey)
    return [box, pointer, text]
コード例 #10
0
ファイル: boxes.py プロジェクト: mariemkayagithub/sane_tikz
def bar_plot(data, label_strs):
    """Draw a horizontal bar plot of ``data`` inside a frame.

    Bars:
    - Each value in ``data`` becomes a bar of that width and
      ``rectangle_height`` height.
    - Bars are stacked vertically with ``rectangle_spacing`` between them.

    Frame:
    - Tall enough for all bars plus ``margin`` above and below, and
      ``frame_width`` wide.
    - Frame and bars have their vertical centres aligned at y = 0.0.
    - Decorated with ``get_ticks``.

    Labels:
    - Each bar gets the matching ``label_strs`` entry.
    - The label is anchored at its right-centre, ``label_spacing`` to the
      left of the bar.

    Returns:
        [bars, labels, frame, ticks]
    """
    origin = [0, 0]
    bars = [
        stz.rectangle_from_width_and_height(origin, rectangle_height, x, r_fmt)
        for x in data
    ]
    stz.distribute_vertically_with_spacing(bars, rectangle_spacing)

    n = len(data)
    frame_height = (n * rectangle_height + (n - 1) * rectangle_spacing +
                    2.0 * margin)
    frame = stz.rectangle_from_width_and_height(origin, frame_height,
                                                frame_width)
    stz.align_centers_vertically([frame, bars], 0.0)

    labels = []
    for idx, bar in enumerate(bars):
        anchor_cs = stz.translate_coords_horizontally(
            stz.coords_from_bbox_with_fn(bar, stz.left_center_coords),
            -label_spacing)
        labels.append(
            stz.latex(anchor_cs, label_strs[idx], fmt.anchor('right_center')))

    return [bars, labels, frame, get_ticks(frame)]
コード例 #11
0
def label_below(e, expr):
    """Return a LaTeX label for ``expr`` placed below ``e``.

    The label is anchored at the bottom-centre of ``e``'s bounding box,
    translated vertically by ``-label_spacing``.
    """
    bottom = stz.coords_from_bbox_with_fn(e, stz.bottom_center_coords)
    return stz.latex(stz.translate_coords_vertically(bottom, -label_spacing),
                     expr)
コード例 #12
0
ファイル: xlnet.py プロジェクト: sidney1994/sane_tikz
          h_grey_lfmt if i != 2 else "") for i in range(4)
]
# Memory blocks labelled mem^(1) and mem^(2), one for each hidden-state row.
m1, m2 = [rct_fn("$\\text{mem}^{(%d)}$" % (i + 1, ), m_fmt) for i in range(2)]
# The predicted token x_3, drawn as a single square.
x_out = [sq_fn("$x_3$", x3_fmt)]

# Each row is led by its memory block.
row1 = [m1] + xs
row2 = [m2] + hs1

# NOTE(review): `[m1] + xs` and `[m2] + hs1` rebuild the same element lists as
# row1/row2; passing row1/row2 here would be equivalent.
stz.distribute_horizontally_with_spacing([m1] + xs, horizontal_spacing)
stz.distribute_horizontally_with_spacing([m2] + hs1, horizontal_spacing)
stz.distribute_horizontally_with_spacing(hs2, horizontal_spacing)
# Stack the three rows with vertical_spacing between them, then right-align
# them so the non-memory columns line up.
stz.distribute_vertically_with_spacing([row1, row2, hs2], vertical_spacing)
stz.align_rights([row1, row2, hs2], 0)

# x_3's output sits above the third hidden state of the hs2 row.
stz.place_above_and_align_to_the_center(x_out, hs2[2], vertical_spacing)

legend = stz.latex(
    [0, 0],
    "Factorization order: $3 \\rightarrow 2 \\rightarrow 4 \\rightarrow 1$")
stz.place_below_and_align_to_the_center(legend, xs, legend_spacing)

# Only the x_3 column (index 2) is connected. Per the legend, x_3 is first in
# the factorization order, so its hidden states draw on the memories and on
# each other, but not on other tokens.
connections = [
    connect(m1, hs1[2]),
    connect(m2, hs2[2]),
    connect(hs1[2], hs2[2]),
    connect(hs2[2], x_out)
]

e = [row1, row2, hs2, x_out, connections, legend]
stz.draw_to_tikz_standalone(e, "xlnet.tex", name2color)
コード例 #13
0
ファイル: transformer.py プロジェクト: sidney1994/sane_tikz
# the rightmost one
# e3 is an open path drawn with s_con. It runs from the top-centre of mmha_an
# through translated_tri_cs[1] and translated_tri_cs[2] to tri_cs[2]. e1, e2
# and the coordinate lists are defined earlier in the script.
from_cs = stz.coords_from_bbox_with_fn(mmha_an, stz.top_center_coords)
e3 = stz.open_path(
    [from_cs, translated_tri_cs[1], translated_tri_cs[2], tri_cs[2]], s_con)
connections.extend([e1, e2, e3])

# Presumably two arrows leading into ier and oer and one leading out of
# softmax (arrow_in / arrow_out are defined earlier in the script).
arrows = [arrow_in(ier), arrow_in(oer), arrow_out(softmax)]

### annotations
# arrow text
# Captions for the first two arrows. Each is anchored at its top edge
# ("anchor=north"), spacing_between_text_and_arrows below the arrow's
# bottom-centre, so the text hangs below the arrow.
s_fmt = fmt.combine_tikz_strs(
    [fmt.text_width(width), "anchor=north", "align=center"])
cs = stz.translate_coords_vertically(
    stz.coords_from_bbox_with_fn(arrows[0], stz.bottom_center_coords),
    -spacing_between_text_and_arrows)
a1 = stz.latex(cs, "Inputs", s_fmt)

cs = stz.translate_coords_vertically(
    stz.coords_from_bbox_with_fn(arrows[1], stz.bottom_center_coords),
    -spacing_between_text_and_arrows)
# lst2str presumably joins the strings into a multi-line caption.
a2 = stz.latex(cs, lst2str(["Outputs", "(shifted right)"]), s_fmt)

# s_fmt is rebound with "anchor=south" for the third arrow's caption. The
# caption is anchored at its bottom edge, spacing_between_text_and_arrows
# above the arrow's top-centre.
s_fmt = fmt.combine_tikz_strs(
    [fmt.text_width(width), "anchor=south", "align=center"])

cs = stz.translate_coords_vertically(
    stz.coords_from_bbox_with_fn(arrows[2], stz.top_center_coords),
    spacing_between_text_and_arrows)
a3 = stz.latex(cs, lst2str(["Output", "Probabilities"]), s_fmt)

# side text
コード例 #14
0
# Closed polygon through the points in cs (defined earlier in the script).
# cs[0]..cs[4] are labelled C, B, A, E, D below.
e = stz.closed_path(cs)

# Coordinate axes. The origin O is 1.0 to the left of cs[2], which is labelled
# A(1, 0). Each axis starts extra_length before O and ends x_axis_length /
# y_axis_length past it.
origin_cs = stz.translate_coords_horizontally(cs[2], -1.0)
x_end_cs = stz.translate_coords_horizontally(origin_cs, x_axis_length)
y_end_cs = stz.translate_coords_vertically(origin_cs, y_axis_length)
x_start_cs = stz.translate_coords_horizontally(origin_cs, -extra_length)
y_start_cs = stz.translate_coords_vertically(origin_cs, -extra_length)
# Axis-label positions: "x" below the x-axis end, "y" to the left of the
# y-axis end, and "O" offset diagonally from the origin by -label_spacing.
x_label_cs = stz.translate_coords_vertically(x_end_cs, -label_spacing)
y_label_cs = stz.translate_coords_horizontally(y_end_cs, -label_spacing)
origin_label_cs = stz.translate_coords_diagonally(origin_cs, -label_spacing)

axes = [
    stz.line_segment(x_start_cs, x_end_cs, s_fmt),
    stz.line_segment(y_start_cs, y_end_cs, s_fmt)
]

# Vertex labels, each offset by label_spacing away from its vertex. The list
# also holds the marker circle at A; A's label sits below that circle.
labels = [
    stz.latex([cs[0][0], cs[0][1] + label_spacing], "$C$"),
    stz.latex([cs[1][0] - label_spacing, cs[1][1]], "$B$"),
    stz.latex([cs[2][0], cs[2][1] - a_circle_radius - label_spacing],
              "$A(1, 0)$"),
    stz.latex([cs[3][0], cs[3][1] - label_spacing], "$E$"),
    stz.latex([cs[4][0] + label_spacing, cs[4][1]], "$D$"),
    stz.circle(cs[2], a_circle_radius, f_fmt),
    stz.latex(x_label_cs, "$x$"),
    stz.latex(y_label_cs, "$y$"),
    stz.latex(origin_label_cs, "$O$"),
]

stz.draw_to_tikz_standalone([e, labels, axes], "pentagon.tex")
コード例 #15
0
def box_with_text(s):
    """Draw a rounded box of ``box_width`` x ``box_height`` with ``s`` at its centre.

    The box's top-left corner is at the origin and it extends towards +x and
    -y. Corners are rounded by ``box_roundness``.

    Returns:
        [box, text]
    """
    half_w = box_width / 2.0
    half_h = box_height / 2.0
    box = stz.rectangle([0, 0], [box_width, -box_height],
                        fmt.rounded_corners(box_roundness))
    text = stz.latex([half_w, -half_h], s)
    return [box, text]
コード例 #16
0
def property(name, width_scale=1.0, height_scale=1.0):
    """Draw a property as an ellipse labelled ``name``.

    The radii are half of ``p_width`` and ``p_height``, scaled by
    ``width_scale`` and ``height_scale``. Note that this module-level
    function shadows the builtin ``property``.

    Returns:
        [ellipse, label]
    """
    x_radius = width_scale * p_width / 2.0
    y_radius = height_scale * p_height / 2.0
    bubble = stz.ellipse([0, 0], x_radius, y_radius, property_s_fmt)
    return [bubble, stz.latex(stz.center_coords(bubble), name)]
コード例 #17
0
import sane_tikz as stz
import formatting as fmt

# Node size and spacing. Units are those of the TikZ picture (presumably cm —
# confirm).
node_radius = 0.30
vertical_node_spacing = 1.4 * node_radius
# Horizontal spread at the first level; place() divides it by
# 2 * (level index + 1) for deeper levels.
first_level_horizontal_node_spacing = 1.8
# NOTE(review): vertical_node_spacing, arrow_angle, bbox_spacing and
# label_spacing are not used in the code visible here; presumably they are
# consumed by connect() or later parts of the script.
arrow_angle = 30.0
bbox_spacing = 0.1
label_spacing = 0.4
line_width = 1.2 * fmt.standard_line_width

# Shared line-width TikZ option for node circles (fn) and arrows (connect).
s_lw = fmt.line_width(line_width)

def fn(expr):
    """Draw a node: a ``node_radius`` circle at the origin with ``expr`` at its centre.

    The circle uses the shared ``s_lw`` line width.

    Returns:
        [circle, latex]
    """
    circle = stz.circle([0, 0], node_radius, s_lw)
    text = stz.latex([0, 0], expr)
    return [circle, text]


def place(e, lst):
    """Translate ``e`` horizontally according to the sign sequence ``lst``.

    The entry ``sign`` at 1-based position ``depth`` contributes
    ``sign * (node_radius + first_level_horizontal_node_spacing / (2 * depth))``,
    so deeper entries move ``e`` by smaller amounts. The contributions are
    accumulated in order and applied as one horizontal translation.
    Presumably ``lst`` holds +1/-1 choices per tree level; confirm with
    callers.
    """
    offset = 0.0
    for depth, sign in enumerate(lst, start=1):
        spread = first_level_horizontal_node_spacing / (2.0 * depth)
        offset += sign * (node_radius + spread)
    stz.translate_horizontally(e, offset)


def connect(e_from, e_to, color_name="black"):
    s_fmt = fmt.combine_tikz_strs(
        [fmt.arrow_heads("end"),
         fmt.line_color(color_name), s_lw])
コード例 #18
0
def label_right(e, expr):
    """Return a LaTeX label for ``expr`` placed to the right of ``e``.

    The label is anchored at the right-centre of ``e``'s bounding box,
    translated horizontally by ``label_spacing``.
    """
    anchor_cs = stz.translate_coords_horizontally(
        stz.coords_from_bbox_with_fn(e, stz.right_center_coords),
        label_spacing)
    return stz.latex(anchor_cs, expr)
コード例 #19
0
def label_left(e, expr):
    """Return a scriptsize LaTeX label for ``expr`` placed to the left of ``e``.

    The label is anchored at the left-centre of ``e``'s bounding box, then
    translated by (-label_spacing, 0.1). The arguments are presumably
    (dx, dy).
    """
    left_cs = stz.coords_from_bbox_with_fn(e, stz.left_center_coords)
    shifted_cs = stz.translate_coords(left_cs, -label_spacing, 0.1)
    return stz.latex(shifted_cs, "\\scriptsize{%s}" % expr)
コード例 #20
0
def frame(frame_idx):
    """Build frame ``frame_idx`` (0-3, labelled "a"-"d") of a module /
    hyperparameter figure.

    Frame 0:
    - Modules: c1 -> o -> (r1, r2) -> cc.
    - Hyperparameters IH-1, IH-2, DH-1 and IH-3, all without values.

    Frame 1:
    - r1/r2 are replaced by c2 and the c3 -> c4 chain, both fed by o.
    - DH-1 and IH-3 have values and are stacked as "unreachable".

    Frames 2-3:
    - o is replaced by dropout d.
    - IH-2 also has a value and joins the unreachable stack.
    - Frame 3 additionally gives values to IH-1 and IH-4..IH-7.

    Args:
        frame_idx: which frame to build, 0 to 3 inclusive.

    Returns:
        ``[modules, module_connections, hyperparameters, hyperp_connections,
        frame_rectangle, [frame_letter_label]]``.

    Raises:
        AssertionError: if ``frame_idx`` is outside 0..3 (only when asserts
            are enabled).
    """
    assert frame_idx >= 0 and frame_idx <= 3
    c1 = conv2d(1)
    o = optional(1)
    r1 = repeat(1)
    r2 = repeat(2)
    cc = concat(1)
    c2 = conv2d(2)
    c3 = conv2d(3)
    c4 = conv2d(4)
    d = dropout(1)

    # NOTE(review): the [c2, [c3, c4]] distribution here is redone in the
    # frame_idx > 0 branch below; in frame 0 those modules are not drawn.
    stz.distribute_horizontally_with_spacing([r1, r2],
                                             horizontal_module_spacing)
    stz.distribute_horizontally_with_spacing([c2, [c3, c4]],
                                             horizontal_module_spacing)

    # Module layout: stack the modules vertically and centre them on x = 0.
    modules = []
    if frame_idx == 0:
        stz.distribute_vertically_with_spacing([cc, [r1, r2], o, c1],
                                               vertical_module_spacing)

        stz.align_centers_horizontally([cc, [r1, r2], o, c1], 0)
        modules.extend([c1, o, r1, r2, cc])

    else:
        # c3 and c4 form a vertical chain placed beside c2.
        stz.distribute_vertically_with_spacing([c4, c3],
                                               vertical_module_spacing)
        stz.distribute_horizontally_with_spacing([c2, [c3, c4]],
                                                 horizontal_module_spacing)
        stz.align_centers_vertically([[c3, c4], c2], 0)

        if frame_idx == 1:
            stz.distribute_vertically_with_spacing([cc, [c2, c3, c4], o, c1],
                                                   vertical_module_spacing)
            stz.align_centers_horizontally([cc, [c2, c3, c4], o, c1], 0)
            modules.extend([c1, o, c2, c3, c4, cc])

        else:
            stz.distribute_vertically_with_spacing([cc, [c2, c3, c4], d, c1],
                                                   vertical_module_spacing)
            stz.align_centers_horizontally([cc, [c2, c3, c4], d, c1], 0)
            modules.extend([c1, d, c2, c3, c4, cc])

    # Module-to-module edges. The two trailing ints are presumably the output
    # and input port indices (the second branch feeds cc's port 1).
    module_connections = []
    if frame_idx == 0:
        module_connections.extend([
            connect_modules(c1, o, 0, 0),
            connect_modules(o, r1, 0, 0),
            connect_modules(o, r2, 0, 0),
            connect_modules(r1, cc, 0, 0),
            connect_modules(r2, cc, 0, 1),
        ])

    else:
        if frame_idx == 1:
            module_connections.extend([
                connect_modules(c1, o, 0, 0),
                connect_modules(o, c2, 0, 0),
                connect_modules(o, c3, 0, 0),
            ])
        else:
            module_connections.extend([
                connect_modules(c1, d, 0, 0),
                connect_modules(d, c2, 0, 0),
                connect_modules(d, c3, 0, 0),
            ])

        module_connections.extend([
            connect_modules(c3, c4, 0, 0),
            connect_modules(c2, cc, 0, 0),
            connect_modules(c4, cc, 0, 1),
        ])

    # # hyperparameters
    # A third argument means the hyperparameter is drawn with that value;
    # without it, the candidate values / function are shown instead.
    if frame_idx <= 1:
        h_o = independent_hyperparameter("IH-2", "0, 1")
    else:
        h_o = independent_hyperparameter("IH-2", "0, 1", "1")

    if frame_idx <= 0:
        h_r1 = dependent_hyperparameter("DH-1", ["x"], "2*x")
        h_r2 = independent_hyperparameter("IH-3", "1, 2, 4")
    else:
        h_r1 = dependent_hyperparameter("DH-1", ["x"], "2*x", "2")
        h_r2 = independent_hyperparameter("IH-3", "1, 2, 4", "1")

    if frame_idx <= 2:
        h_c1 = independent_hyperparameter("IH-1", "64, 128")
        h_c2 = independent_hyperparameter("IH-4", "64, 128")
        h_c3 = independent_hyperparameter("IH-5", "64, 128")
        h_c4 = independent_hyperparameter("IH-6", "64, 128")
        h_d = independent_hyperparameter("IH-7", "0.25, 0.5")
    else:
        h_c1 = independent_hyperparameter("IH-1", "64, 128", "64")
        h_c2 = independent_hyperparameter("IH-4", "64, 128", "128")
        h_c3 = independent_hyperparameter("IH-5", "64, 128", "128")
        h_c4 = independent_hyperparameter("IH-6", "64, 128", "64")
        h_d = independent_hyperparameter("IH-7", "0.25, 0.5", "0.5")

    # Place hyperparameter h to the right of module m, vertically centred on
    # m[3]. This presumes m was built by module() with hyperparameters, whose
    # fourth entry is the property list.
    def place_hyperp_right_of(h, m):
        y_p = stz.center_coords(m[3])[1]
        stz.align_centers_vertically([h], y_p)
        stz.place_to_the_right(h, m, spacing_between_module_and_hyperp)

    hyperparameters = []
    place_hyperp_right_of(h_c1, c1)
    if frame_idx in [0, 1]:
        place_hyperp_right_of(h_o, o)
        hyperparameters.append(h_o)

    if frame_idx == 0:
        place_hyperp_right_of(h_r1, r2)
        stz.place_above_and_align_to_the_right(h_r2, h_r1, 0.8)
        hyperparameters.extend([h_r1, h_r2, h_c1])
    else:
        # NOTE(review): repeats the h_c1 placement already done above.
        place_hyperp_right_of(h_c1, c1)
        place_hyperp_right_of(h_c3, c3)
        place_hyperp_right_of(h_c4, c4)
        stz.place_below(h_c2, h_c1, 3.0)
        hyperparameters.extend([h_c1, h_c2, h_c3, h_c4])

        if frame_idx in [2, 3]:
            place_hyperp_right_of(h_d, d)
            hyperparameters.extend([h_d])

    # Hyperparameters whose modules are no longer drawn in this frame. They
    # are stacked together and later moved to the bottom-right corner.
    unreachable_hyperps = []
    if frame_idx == 1:
        stz.distribute_vertically_with_spacing([h_r1, h_r2], 0.2)
        unreachable_hyperps.extend([h_r1, h_r2])
    if frame_idx >= 2:
        stz.distribute_vertically_with_spacing([h_o, h_r1, h_r2], 0.2)
        unreachable_hyperps.extend([h_r1, h_r2, h_o])
    hyperparameters.extend(unreachable_hyperps)

    # Move one hyperparameter's left-centre to the left-centre of the
    # bounding box of two others.
    cs_fn = lambda e: stz.coords_from_bbox_with_fn(e, stz.left_center_coords)
    if frame_idx == 0:

        stz.translate_bbox_left_center_to_coords(h_r2, cs_fn([h_o, h_r1]))
    elif frame_idx == 1:
        stz.translate_bbox_left_center_to_coords(h_c2, cs_fn([h_o, h_c3]))
    else:
        stz.translate_bbox_left_center_to_coords(h_c2, cs_fn([h_d, h_c3]))

    hyperp_connections = [
        connect_hyperp_to_module(h_c1, c1, 0),
    ]
    if frame_idx in [0, 1]:
        hyperp_connections.extend([connect_hyperp_to_module(h_o, o, 0)])
    if frame_idx == 0:
        hyperp_connections.extend([
            connect_hyperp_to_module(h_r1, r2, 0),
            connect_hyperp_to_module(h_r2, r1, 0),
            connect_hyperp_to_hyperp(h_r2, h_r1)
        ])
    else:
        hyperp_connections.extend([
            connect_hyperp_to_module(h_c2, c2, 0),
            connect_hyperp_to_module(h_c3, c3, 0),
            connect_hyperp_to_module(h_c4, c4, 0),
        ])
        if frame_idx in [2, 3]:
            hyperp_connections.append(connect_hyperp_to_module(h_d, d, 0))

    # Frame rectangle, centred 0.8 to the right of the content's centre.
    f = stz.rectangle_from_width_and_height([0, 0], frame_height, frame_width,
                                            frame_s_fmt)
    e = [modules, module_connections, hyperparameters, hyperp_connections]
    stz.translate_bbox_center_to_coords(
        f, stz.translate_coords_horizontally(stz.center_coords(e), 0.8))
    # stz.bbox(...)[1] is the bottom-right corner of the content's bounding
    # box, which already includes the unreachable hyperparameters themselves.
    if len(unreachable_hyperps) > 0:
        stz.translate_bbox_bottom_right_to_coords(unreachable_hyperps,
                                                  stz.bbox(e)[1])

    # frame id
    # The frame letter ("a".."d") is placed at the frame's top-left corner,
    # offset antidiagonally by 0.6.
    s = ["a", "b", "c", "d"][frame_idx]
    label = [stz.latex([0, 0], "\\Huge \\textbf %s" % s)]
    stz.translate_bbox_top_left_to_coords(
        label,
        stz.translate_coords_antidiagonally(
            stz.coords_from_bbox_with_fn(f, stz.top_left_coords), 0.6))

    return e + [f, label]
コード例 #21
0
def node(label):
    """Draw a node: a ``node_radius`` circle at the origin with ``label`` at its centre.

    Returns:
        [circle, latex]
    """
    circle = stz.circle([0, 0], node_radius)
    text = stz.latex([0, 0], label)
    return [circle, text]