コード例 #1
0
def search_space_transition():
    """Render four search-space frames with color-coded changes.

    Frames 0-3 are built with ``frame``. Selected nested elements get a fill
    color appended to their TikZ options: "light_green_2" for new modules and
    new hyperparameters, "light_red_2" for assigned hyperparameters. The
    frames are then arranged as two rows of two and drawn to
    "deep_architect.tex".
    """
    frames = [frame(i) for i in range(4)]
    f0, f1, f2, f3 = frames

    def tint(root, path, color):
        # Follow `path` through the nested element structure, then append
        # the fill color to whatever TikZ options the element already has.
        target = root
        for step in path:
            target = target[step]
        target["tikz_str"] = fmt.combine_tikz_strs(
            [target["tikz_str"], fmt.fill_color(color)])

    added = "light_green_2"
    assigned = "light_red_2"
    recolor = [
        # new modules
        (f1, [0, 2, 0, 0], added),
        (f1, [0, 3, 0, 0], added),
        (f1, [0, 4, 0, 0], added),
        (f2, [0, 1, 0, 0], added),
        # new hyperparameters
        (f1, [2, 2, 0], added),
        (f1, [2, 3, 0], added),
        (f1, [2, 4, 0], added),
        (f2, [2, 4, 0], added),
        # assigned hyperparameters
        (f1, [2, 5, 0], assigned),
        (f1, [2, 6, 0], assigned),
        (f2, [2, 7, 0], assigned),
        (f3, [2, 0, 0], assigned),
        (f3, [2, 1, 0], assigned),
        (f3, [2, 2, 0], assigned),
        (f3, [2, 3, 0], assigned),
        (f3, [2, 4, 0], assigned),
    ]
    for root, path, color in recolor:
        tint(root, path, color)

    # Arrange the four frames: align tops, space each pair horizontally,
    # then space the two pairs vertically.
    stz.align_tops(frames, 0.0)
    stz.distribute_horizontally_with_spacing([f0, f1], frame_spacing)
    stz.distribute_horizontally_with_spacing([f2, f3], frame_spacing)
    stz.distribute_vertically_with_spacing([[f2, f3], [f0, f1]], frame_spacing)

    stz.draw_to_tikz_standalone(frames, "deep_architect.tex", name2color)
コード例 #2
0
                                         node_spacing)
# Space eta and beta horizontally, then place that pair above the
# alpha/theta/z/w group, aligned to the group's center.
stz.distribute_horizontally_with_spacing([eta_c, beta_c], node_spacing)
stz.place_above_and_align_to_the_center(
    [eta_c, beta_c], [alpha_c, theta_c, z_c, w_c], node_spacing)

# One LaTeX label per node; the helper name gives the side the label goes on.
alpha_l = label_below(alpha_c, "$\\alpha$")
theta_l = label_below(theta_c, "$\\theta$")
z_l = label_below(z_c, "$z$")
w_l = label_below(w_c, "$w$")
eta_l = label_left(eta_c, "$\\eta$")
beta_l = label_right(beta_c, "$\\beta$")

# Connections between nodes: the chain alpha - theta - z - w, eta - beta,
# and a diagonal one between beta and w.
connections = [
    connect_horizontally(alpha_c, theta_c),
    connect_horizontally(theta_c, z_c),
    connect_horizontally(z_c, w_c),
    connect_horizontally(eta_c, beta_c),
    connect_diagonally(beta_c, w_c),
]

# `plate` is defined outside this excerpt; presumably it draws a captioned
# box around the given elements -- TODO confirm. p2 takes p1 as a member, so
# the "$N$" plate is nested inside the "$M$" plate.
p1 = plate([z_c, w_c, z_l, w_l], "$N$")
p2 = plate([theta_c, theta_l, p1], "$M$")
p3 = plate([beta_c, beta_l], "$k$")

# Everything that is handed to the renderer.
e = [
    alpha_c, theta_c, z_c, w_c, eta_c, beta_c, alpha_l, theta_l, z_l, w_l,
    eta_l, beta_l, connections, p1, p2, p3
]

stz.draw_to_tikz_standalone(e, "lda.tex")
コード例 #3
0
ファイル: transformer.py プロジェクト: sidney1994/sane_tikz
    spacing_between_text_and_arrows)
# Third annotation: the two-line "Output Probabilities" text at `cs`
# (computed by the statement just above this excerpt).
a3 = stz.latex(cs, lst2str(["Output", "Probabilities"]), s_fmt)

# Side text: an "N x" marker `spacing` to the left of bb1 and another
# `spacing` to the right of bb2, each anchored on the side facing its box.
spacing = 0.2
cs = stz.coords_from_bbox_with_fn(bb1, stz.left_center_coords)
a4 = stz.latex([cs[0] - spacing, cs[1]], "$N\\times$",
               fmt.anchor("right_center"))
cs = stz.coords_from_bbox_with_fn(bb2, stz.right_center_coords)
a5 = stz.latex([cs[0] + spacing, cs[1]], "$N\\times$",
               fmt.anchor("left_center"))

# Style for the "Positional Encoding" captions: text width 2.0 combined with
# an anchor of "<side>_center".
s_fn = lambda side: fmt.combine_tikz_strs(
    [fmt.text_width(2.0), fmt.anchor(side + "_center")])
# NOTE(review): both captions use hard-coded x offsets (0.3 and 0.2) instead
# of `spacing`, and a6 is shifted in +x from the left edge of `sl` while a4
# is shifted in -x from the left edge of bb1 -- confirm this is intentional.
cs = stz.coords_from_bbox_with_fn(sl, stz.left_center_coords)
a6 = stz.latex([cs[0] + 0.3, cs[1]], lst2str(["Positional", "Encoding"]),
               s_fn("right"))
cs = stz.coords_from_bbox_with_fn(sr, stz.right_center_coords)
a7 = stz.latex([cs[0] + 0.2, cs[1]], lst2str(["Positional", "Encoding"]),
               s_fn("left"))

annotations = [a1, a2, a3, a4, a5, a6, a7]

# Collect every element of the figure and render it.
e = [
    bb1, bb2, ier, oer, mha1_an, mha2_an, mmha_an, ff1_an, ff2_an, linear,
    softmax, cl, cr, sl, sr, connections, arrows, annotations
]

stz.draw_to_tikz_standalone(e, "transformer.tex", name2color)
コード例 #4
0
# Closed path through the coordinates in `cs` (defined above this excerpt).
e = stz.closed_path(cs)

# The origin sits 1.0 to the left of cs[2], the vertex labeled "A(1, 0)".
origin_cs = stz.translate_coords_horizontally(cs[2], -1.0)
# Each axis runs from `extra_length` before the origin to its axis length
# past the origin.
x_end_cs = stz.translate_coords_horizontally(origin_cs, x_axis_length)
y_end_cs = stz.translate_coords_vertically(origin_cs, y_axis_length)
x_start_cs = stz.translate_coords_horizontally(origin_cs, -extra_length)
y_start_cs = stz.translate_coords_vertically(origin_cs, -extra_length)
# Label positions: offset by `label_spacing` from the axis ends and origin.
x_label_cs = stz.translate_coords_vertically(x_end_cs, -label_spacing)
y_label_cs = stz.translate_coords_horizontally(y_end_cs, -label_spacing)
origin_label_cs = stz.translate_coords_diagonally(origin_cs, -label_spacing)

axes = [
    stz.line_segment(x_start_cs, x_end_cs, s_fmt),
    stz.line_segment(y_start_cs, y_end_cs, s_fmt)
]

# Vertex labels, each offset from its vertex by `label_spacing`; the label
# for A is pushed a further `a_circle_radius` to clear the circle marker.
# Note that this list also holds that circle, which is not a text label.
labels = [
    stz.latex([cs[0][0], cs[0][1] + label_spacing], "$C$"),
    stz.latex([cs[1][0] - label_spacing, cs[1][1]], "$B$"),
    stz.latex([cs[2][0], cs[2][1] - a_circle_radius - label_spacing],
              "$A(1, 0)$"),
    stz.latex([cs[3][0], cs[3][1] - label_spacing], "$E$"),
    stz.latex([cs[4][0] + label_spacing, cs[4][1]], "$D$"),
    stz.circle(cs[2], a_circle_radius, f_fmt),
    stz.latex(x_label_cs, "$x$"),
    stz.latex(y_label_cs, "$y$"),
    stz.latex(origin_label_cs, "$O$"),
]

stz.draw_to_tikz_standalone([e, labels, axes], "pentagon.tex")
コード例 #5
0
# Layout constants for the segments figure.
tick_label_spacing = 0.25  # spacing used when placing a tick label above its tick
tick_length = 0.25  # length passed to each centered vertical tick
segment_spacing = 0.6  # spacing used when distributing the segments vertically
length_multiplier = 0.8  # scale factor applied to each segment's nominal length


def segment(length, label_str, left_tick_label_str, right_tick_label_str):
    """Build a horizontal segment with labeled end ticks and a name label.

    The segment runs from [0, 0] to [length, 0]. A vertical tick is centered
    on each end, and each tick gets its label placed above it. ``label_str``
    is placed at [-label_spacing, 0].

    Returns:
        A list in the order [segment, left tick, right tick, left tick
        label, right tick label, segment label].
    """
    bar = stz.line_segment([0, 0], [length, 0])
    ticks = [
        stz.centered_vertical_line_segment([x, 0], tick_length)
        for x in (0, length)
    ]
    # Tick labels are created at the origin and moved into place afterwards.
    tick_labels = [
        stz.latex([0, 0], text)
        for text in (left_tick_label_str, right_tick_label_str)
    ]
    for tick_label, tick in zip(tick_labels, ticks):
        stz.place_above_and_align_to_the_center(tick_label, tick,
                                                tick_label_spacing)
    name_label = stz.latex([-label_spacing, 0], label_str)
    return [bar] + ticks + tick_labels + [name_label]


# Three segments, A, B and C, with nominal lengths 5, 1 and 3 scaled by
# `length_multiplier`. The last two arguments are the left and right tick
# labels of each segment.
segs = [
    segment(length_multiplier * 5, "A", "0", "5"),
    segment(length_multiplier * 1, "B", "2", "3"),
    segment(length_multiplier * 3, "C", "1", "4")
]

# The list is passed in reverse order to the vertical distribution.
stz.distribute_vertically_with_spacing(segs[::-1], segment_spacing)
stz.draw_to_tikz_standalone(segs, "segments.tex")
コード例 #6
0
    connect(nodes[0], nodes[2]),
    connect(nodes[1], nodes[3]),
    connect(nodes[1], nodes[4], "blue"),
    connect(nodes[2], nodes[5]),
    connect(nodes[4], nodes[6], "blue"),
    connect(nodes[4], nodes[7]),
    connect(nodes[5], nodes[8]),
]

# Restyle the third-from-last node: line and fill become "mygreen" on its
# first element, and "text=white" is added on its second element.
nodes[-3][0]["tikz_str"] = fmt.combine_tikz_strs(
    [nodes[-3][0]["tikz_str"],
     fmt.line_and_fill_colors("mygreen", "mygreen")])
# NOTE(review): the second element's style is built from the FIRST element's
# (already updated) tikz_str, not from its own -- confirm this is intended
# and not a copy-paste slip for nodes[-3][1]["tikz_str"].
nodes[-3][1]["tikz_str"] = fmt.combine_tikz_strs(
    [nodes[-3][0]["tikz_str"], "text=white"])

# Nested dashed bounding boxes: each one encloses the previous box plus two
# more nodes.
bb1 = dashed_bbox([nodes[6]])
bb2 = dashed_bbox([bb1, nodes[4], nodes[7]])
bb3 = dashed_bbox([bb2, nodes[1], nodes[3]])
bboxes = [bb1, bb2, bb3]

# Comparison text placed to the left of nodes 0, 1 and 4.
labels = [
    label_left(nodes[0], "4 < 8"),
    label_left(nodes[1], "4 > 3"),
    label_left(nodes[4], "4 < 6"),
]

# Custom color used above, given as a 3-tuple (presumably 0-255 RGB -- TODO
# confirm against the renderer).
name2color = {"mygreen": (2, 129, 0)}

e = [nodes, connections, bboxes, labels]
stz.draw_to_tikz_standalone(e, "tree.tex", name2color)
コード例 #7
0
ファイル: boxes.py プロジェクト: mariemkayagithub/sane_tikz
    cs = [0, 0]
    rs = [
        stz.rectangle_from_width_and_height(cs, rectangle_height, x, r_fmt)
        for x in data
    ]

    stz.distribute_vertically_with_spacing(rs, rectangle_spacing)
    frame_height = (len(data) * rectangle_height +
                    (len(data) - 1) * rectangle_spacing + 2.0 * margin)
    frame = stz.rectangle_from_width_and_height(cs, frame_height, frame_width)
    stz.align_centers_vertically([frame, rs], 0.0)

    labels = []
    for i, r in enumerate(rs):
        l_cs = stz.coords_from_bbox_with_fn(r, stz.left_center_coords)
        l_cs = stz.translate_coords_horizontally(l_cs, -label_spacing)
        l_s = label_strs[i]
        lab = stz.latex(l_cs, l_s, fmt.anchor('right_center'))
        labels.append(lab)

    ticks = get_ticks(frame)
    return [rs, labels, frame, ticks]


# Build one bar plot per data set, space them horizontally (small first),
# align their bottoms, and render both to "boxes.tex".
e_large = bar_plot(large_data, large_label_strs)
e_small = bar_plot(small_data, small_label_strs)
stz.distribute_horizontally_with_spacing([e_small, e_large], frame_spacing)
stz.align_bottoms([e_large, e_small], 0)

stz.draw_to_tikz_standalone([e_small, e_large], "boxes.tex")
コード例 #8
0
    "E",
    "F",
]


def box_with_text(s):
    """Return ``[rectangle, text]`` for a rounded box with ``s`` at its middle.

    The rectangle spans [0, 0] to [box_width, -box_height] with corners
    rounded by ``box_roundness``; the LaTeX text is placed at the midpoint
    of those two corners.
    """
    outline = stz.rectangle([0, 0], [box_width, -box_height],
                            fmt.rounded_corners(box_roundness))
    caption = stz.latex([box_width / 2.0, -box_height / 2.0], s)
    return [outline, caption]


def arrow():
    """Return a light-gray filled arrow with no outline.

    Shaft and head dimensions come from the module-level ``shaft_*`` and
    ``head_*`` settings. The -90.0 argument is presumably a rotation angle
    in degrees -- TODO confirm against ``stz.arrow``.
    """
    return stz.arrow(shaft_width, shaft_height, head_width, head_height,
                     -90.0, fmt.fill_color_with_no_line("lightgray"))


# One text box per entry of `lst`; the reversed list is what gets
# distributed vertically.
boxes = [box_with_text(text) for text in lst]
stz.distribute_vertically_with_spacing(boxes[::-1], box_spacing)

# One arrow per pair of consecutive boxes, moved so that the arrow's center
# coincides with the center of that pair.
arrows = [arrow() for _ in range(len(lst) - 1)]
for i, connector in enumerate(arrows):
    pair_center = stz.center_coords(boxes[i:i + 2])
    connector_center = stz.center_coords(connector)
    stz.translate_to_coords(connector, connector_center, pair_center)

stz.draw_to_tikz_standalone([boxes, arrows], "flowchart.tex")
コード例 #9
0
ファイル: xlnet.py プロジェクト: sidney1994/sane_tikz
          h_grey_lfmt if i != 2 else "") for i in range(4)
]
# Two memory elements captioned mem^(1) and mem^(2), and the output
# element captioned x_3.
m1, m2 = [rct_fn("$\\text{mem}^{(%d)}$" % (i + 1, ), m_fmt) for i in range(2)]
x_out = [sq_fn("$x_3$", x3_fmt)]

# Each of the first two rows starts with a memory element followed by the
# inputs (`xs`) or the first hidden layer (`hs1`).
row1 = [m1] + xs
row2 = [m2] + hs1

# Lay out each row horizontally (the lists built here hold the same elements
# as row1 and row2), then distribute the three rows vertically and align
# their right edges.
stz.distribute_horizontally_with_spacing([m1] + xs, horizontal_spacing)
stz.distribute_horizontally_with_spacing([m2] + hs1, horizontal_spacing)
stz.distribute_horizontally_with_spacing(hs2, horizontal_spacing)
stz.distribute_vertically_with_spacing([row1, row2, hs2], vertical_spacing)
stz.align_rights([row1, row2, hs2], 0)

# The output goes above the third element of the second hidden layer.
stz.place_above_and_align_to_the_center(x_out, hs2[2], vertical_spacing)

# Legend text, placed below the inputs and centered on them.
legend = stz.latex(
    [0, 0],
    "Factorization order: $3 \\rightarrow 2 \\rightarrow 4 \\rightarrow 1$")
stz.place_below_and_align_to_the_center(legend, xs, legend_spacing)

# Connections involving the third element of each hidden layer.
connections = [
    connect(m1, hs1[2]),
    connect(m2, hs2[2]),
    connect(hs1[2], hs2[2]),
    connect(hs2[2], x_out)
]

e = [row1, row2, hs2, x_out, connections, legend]
stz.draw_to_tikz_standalone(e, "xlnet.tex", name2color)
コード例 #10
0
    small_nodes.extend([
        small_node_relative(x1, x5, 0.6),
        small_node_relative(x2, x5, 0.6),
        small_node_relative(x2, x4, 0.6),
        small_node_relative(x3, x4, 0.6),
        small_node_relative(x1, x4, 0.6),
    ])

    labels = [
        label_small_node("$f_{rp}(C_1, p_5)$", small_nodes[-5], 135.0, hd,
                         1.4 * rectangle_width)
    ]

    return [nodes, small_nodes, connections, labels]


# Build the six sub-figures.
e = [
    figure_a(),
    figure_b(),
    figure_c(),
    figure_d(),
    figure_e(),
    figure_f(),
]
# Arrange them in two groups of three: figures a-c and figures d-f are each
# distributed vertically (passed in reverse order) with spacing 2.0, then
# the two groups are distributed horizontally with spacing 2.0.
stz.distribute_vertically_with_spacing(e[:3][::-1], 2.0)
stz.distribute_vertically_with_spacing(e[3:][::-1], 2.0)
stz.distribute_horizontally_with_spacing([e[:3], e[3:]], 2.0)

stz.draw_to_tikz_standalone(e, "factor_graphs.tex")