Ejemplo n.º 1
0
def figure_b():
    """Assemble the elements of figure (b).

    The figure has a row of nodes labelled x_1..x_3 and a second row
    labelled q_1..q_3 that is positioned with
    ``stz.place_above_and_align_to_the_center`` relative to the x row. Small
    nodes are created between neighbours, to the left of the first node of
    each row, and (for the x row only) in a separate row positioned with
    ``stz.place_below_and_align_to_the_center``.

    Returns:
        A list ``[nodes, small_nodes, connections, labels]``.
    """
    # Offset shared by the small node placed to the left of each row.
    left_offset = horizontal_node_spacing / 2.0 - filled_node_radius

    # --- x row ---
    xs = [node(tex) for tex in ("$x_1$", "$x_2$", "$x_3$")]
    stz.distribute_horizontally_with_spacing(xs, horizontal_node_spacing)
    # Copy: `nodes` is extended with the q row later, `xs` must stay as is.
    nodes = list(xs)

    x12_small = small_node_relative(xs[0], xs[1])
    x23_small = small_node_relative(xs[1], xs[2])
    x_left_small = small_node([0, 0])
    stz.place_to_the_left_and_align_to_the_center(x_left_small, xs[0],
                                                  left_offset)

    # One extra small node per x node, laid out as its own row.
    below_smalls = [small_node([0, 0]) for _ in range(3)]
    small_nodes = [x12_small, x23_small, x_left_small] + below_smalls
    stz.distribute_centers_horizontally_with_spacing(
        below_smalls, horizontal_node_spacing + 2.0 * node_radius)
    stz.place_below_and_align_to_the_center(below_smalls, nodes,
                                            vertical_node_spacing)

    connections = [
        connect_nodes(xs[0], xs[1]),
        connect_nodes(xs[1], xs[2]),
        connect_small_node_to_node(x_left_small, xs[0]),
    ]
    for below_small, x in zip(below_smalls, xs):
        connections.append(connect_small_node_to_node(below_small, x))

    # --- q row ---
    qs = [node(tex) for tex in ("$q_1$", "$q_2$", "$q_3$")]
    stz.distribute_horizontally_with_spacing(qs, horizontal_node_spacing)
    nodes += qs
    # Positioned relative to the x row only (not the full `nodes` list).
    stz.place_above_and_align_to_the_center(qs, xs, vertical_node_spacing)

    q12_small = small_node_relative(qs[0], qs[1])
    q23_small = small_node_relative(qs[1], qs[2])
    q_left_small = small_node([0, 0])
    stz.place_to_the_left_and_align_to_the_center(q_left_small, qs[0],
                                                  left_offset)
    small_nodes += [q12_small, q23_small, q_left_small]

    connections += [
        connect_nodes(qs[0], qs[1]),
        connect_nodes(qs[1], qs[2]),
        connect_small_node_to_node(q_left_small, qs[0]),
    ]
    # Link each x node with the q node at the same index.
    connections += [connect_nodes(x, q) for x, q in zip(xs, qs)]

    labels = [
        label_small_node("$p(q_1)$", q_left_small, -90.0, vd,
                         rectangle_width),
        label_small_node("$p(q_2| q_3)$", q23_small, -90.0, vd,
                         rectangle_width),
    ]

    return [nodes, small_nodes, connections, labels]
Ejemplo n.º 2
0
def segment(length, label_str, left_tick_label_str, right_tick_label_str):
    """Build a horizontal segment with a tick and tick label at each end.

    Args:
        length: x coordinate of the right end; the left end is at x = 0.
        label_str: LaTeX string for the segment label, created at
            ``[-label_spacing, 0]``.
        left_tick_label_str: LaTeX string for the label of the left tick.
        right_tick_label_str: LaTeX string for the label of the right tick.

    Returns:
        A list ``[segment, left_tick, right_tick, left_tick_label,
        right_tick_label, segment_label]``.
    """
    axis = stz.line_segment([0, 0], [length, 0])

    # Left tick first, then right tick.
    ticks = [
        stz.centered_vertical_line_segment([x, 0], tick_length)
        for x in (0, length)
    ]
    # Labels are created at the origin and then positioned relative to
    # their tick.
    tick_labels = [
        stz.latex([0, 0], text)
        for text in (left_tick_label_str, right_tick_label_str)
    ]
    for tick_label, tick in zip(tick_labels, ticks):
        stz.place_above_and_align_to_the_center(tick_label, tick,
                                                tick_label_spacing)

    axis_label = stz.latex([-label_spacing, 0], label_str)
    return [axis] + ticks + tick_labels + [axis_label]
Ejemplo n.º 3
0
def figure_e():
    """Assemble the elements of figure (e).

    A row of nodes labelled T_1..T_3 is combined with two nodes labelled
    l_4 and l_5 that are positioned with
    ``stz.place_above_and_align_to_the_center`` relative to T_2. Five
    (T, l) pairs are connected, and each of those pairs also gets a small
    node created with ``small_node_relative(..., 0.6)``.

    Returns:
        A list ``[nodes, small_nodes, connections, labels]``.
    """
    t1, t2, t3 = [node(tex) for tex in ("$T_1$", "$T_2$", "$T_3$")]
    stz.distribute_horizontally_with_spacing([t1, t2, t3],
                                             horizontal_node_spacing)
    nodes = [t1, t2, t3]

    # Small nodes for the T row: one per neighbouring pair, one to the left.
    small_nodes = [small_node_relative(t1, t2), small_node_relative(t2, t3)]
    left_small = small_node([0, 0])
    stz.place_to_the_left_and_align_to_the_center(
        left_small, t1, horizontal_node_spacing / 2.0 - filled_node_radius)
    small_nodes.append(left_small)

    connections = [
        connect_nodes(t1, t2),
        connect_nodes(t2, t3),
        connect_small_node_to_node(left_small, t1),
    ]

    l4 = node("$l_4$")
    l5 = node("$l_5$")
    # l_5 is passed first to the layout calls, l_4 second.
    upper_row = [l5, l4]
    stz.distribute_horizontally_with_spacing(upper_row,
                                             horizontal_node_spacing)
    stz.place_above_and_align_to_the_center(upper_row, [t2],
                                            vertical_node_spacing)
    nodes += [l4, l5]

    # The same five pairs drive both the connections and the small nodes.
    pairs = [(t1, l5), (t2, l5), (t2, l4), (t3, l4), (t1, l4)]
    connections += [connect_nodes(lower, upper) for lower, upper in pairs]
    pair_smalls = [
        small_node_relative(lower, upper, 0.6) for lower, upper in pairs
    ]
    small_nodes += pair_smalls

    # Only the small node of the (T_1, l_5) pair is labelled.
    labels = [
        label_small_node("$f_{br}(T_1, l_5)$", pair_smalls[0], 135.0, hd,
                         1.4 * rectangle_width)
    ]

    return [nodes, small_nodes, connections, labels]
Ejemplo n.º 4
0
def figure_c():
    """Assemble the elements of figure (c).

    A row of nodes labelled x_1..x_3 gets one small node per neighbouring
    pair and a separate row of three small nodes positioned with
    ``stz.place_below_and_align_to_the_center``. Nodes labelled u_1 and u_2
    are stacked on the two pair small nodes, each with one more small node
    stacked on top of it.

    Returns:
        A list ``[nodes, small_nodes, connections, labels]``.
    """
    x1, x2, x3 = [node(tex) for tex in ("$x_1$", "$x_2$", "$x_3$")]
    stz.distribute_horizontally_with_spacing([x1, x2, x3],
                                             horizontal_node_spacing)
    nodes = [x1, x2, x3]

    pair_smalls = [small_node_relative(x1, x2), small_node_relative(x2, x3)]
    below_smalls = [small_node([0, 0]) for _ in range(3)]
    small_nodes = pair_smalls + below_smalls
    stz.distribute_centers_horizontally_with_spacing(
        below_smalls, horizontal_node_spacing + 2.0 * node_radius)
    stz.place_below_and_align_to_the_center(below_smalls, nodes,
                                            vertical_node_spacing)

    connections = [connect_nodes(x1, x2), connect_nodes(x2, x3)]
    # `nodes` still holds only the three x nodes at this point.
    connections += [
        connect_small_node_to_node(below_small, x)
        for below_small, x in zip(below_smalls, nodes)
    ]

    u1 = node("$u_1$")
    u2 = node("$u_2$")
    nodes += [u1, u2]

    stack_gap = vertical_node_spacing - node_radius
    top_smalls = []
    for u, pair_small in zip((u1, u2), pair_smalls):
        top_small = small_node([0, 0])
        # u is positioned relative to the pair small node first, then the
        # new small node relative to u.
        stz.place_above_and_align_to_the_center(u, pair_small, stack_gap)
        stz.place_above_and_align_to_the_center(top_small, u, stack_gap)
        connections.append(connect_small_node_to_node(pair_small, u))
        connections.append(connect_small_node_to_node(top_small, u))
        top_smalls.append(top_small)
        small_nodes.append(top_small)

    labels = [
        # First small node of the lower row (index 2 in small_nodes).
        label_small_node("$J_x(x_1)$", below_smalls[0], 0.0, hd,
                         rectangle_width),
        # Small node stacked on u_1 (index 5 in small_nodes).
        label_small_node("$J_u(u_1)$", top_smalls[0], 0.0, hd,
                         rectangle_width),
        # Small node of the (x_2, x_3) pair (index 1 in small_nodes).
        label_small_node("$p(x_3| x_2, u_2)$", pair_smalls[1], -90.0, vd,
                         1.4 * rectangle_width),
    ]

    return [nodes, small_nodes, connections, labels]
Ejemplo n.º 5
0
def figure_f():
    """Assemble the elements of figure (f).

    A row of nodes labelled C_1..C_3 is combined with two nodes labelled
    p_4 and p_5 that are positioned with
    ``stz.place_above_and_align_to_the_center`` relative to C_2. Five
    (C, p) pairs are connected, and each pair also gets a small node
    created with ``small_node_relative(..., 0.6)``.

    Returns:
        A list ``[nodes, small_nodes, connections, labels]``.
    """
    c1 = node("$C_1$")
    c2 = node("$C_2$")
    c3 = node("$C_3$")
    nodes = [c1, c2, c3]
    stz.distribute_horizontally_with_spacing([c1, c2, c3],
                                             horizontal_node_spacing)

    p4 = node("$p_4$")
    p5 = node("$p_5$")
    # p_5 is passed first to the layout calls, p_4 second.
    stz.distribute_horizontally_with_spacing([p5, p4], horizontal_node_spacing)
    stz.place_above_and_align_to_the_center([p5, p4], [c2],
                                            vertical_node_spacing)
    nodes += [p4, p5]

    # Each pair yields one connection and one small node, in this order.
    linked = ((c1, p5), (c2, p5), (c2, p4), (c3, p4), (c1, p4))
    connections = [connect_nodes(c, p) for c, p in linked]
    small_nodes = [small_node_relative(c, p, 0.6) for c, p in linked]

    # The labelled small node is the one of the (C_1, p_5) pair.
    labelled_small = small_nodes[0]
    labels = [
        label_small_node("$f_{rp}(C_1, p_5)$", labelled_small, 135.0, hd,
                         1.4 * rectangle_width)
    ]

    return [nodes, small_nodes, connections, labels]
Ejemplo n.º 6
0

# One circle of radius node_radius per variable; all are created at the
# origin and positioned by the layout calls below.
alpha_c, theta_c, z_c, w_c, eta_c, beta_c = [
    stz.circle([0, 0], node_radius) for _ in range(6)
]

# Combine w's existing TikZ format string with a gray fill.
w_c["tikz_str"] = fmt.combine_tikz_strs(
    [w_c["tikz_str"], fmt.fill_color("gray")])

# Lay out the two rows, then position the (eta, beta) row relative to the
# (alpha, theta, z, w) row.
_lower_row = [alpha_c, theta_c, z_c, w_c]
_upper_row = [eta_c, beta_c]
stz.distribute_horizontally_with_spacing(_lower_row, node_spacing)
stz.distribute_horizontally_with_spacing(_upper_row, node_spacing)
stz.place_above_and_align_to_the_center(_upper_row, _lower_row, node_spacing)

# Text labels: lower-row circles use label_below, eta uses label_left and
# beta uses label_right.
alpha_l = label_below(alpha_c, "$\\alpha$")
theta_l = label_below(theta_c, "$\\theta$")
z_l = label_below(z_c, "$z$")
w_l = label_below(w_c, "$w$")
eta_l = label_left(eta_c, "$\\eta$")
beta_l = label_right(beta_c, "$\\beta$")

# Horizontal links along each row, followed by the single diagonal link
# from beta to w.
connections = [
    connect_horizontally(src, dst)
    for src, dst in (
        (alpha_c, theta_c),
        (theta_c, z_c),
        (z_c, w_c),
        (eta_c, beta_c),
    )
]
connections.append(connect_diagonally(beta_c, w_c))
Ejemplo n.º 7
0
def arrow_out(e):
    """Return a line segment that leaves ``e`` from its top center.

    The start point comes from applying ``stz.top_center_coords`` to the
    bounding box of ``e``; the end point is the start point translated
    vertically by the module-level ``arrow_length``. The segment is created
    with the module-level ``s_arr`` format (presumably an arrow style;
    confirm against its definition).
    """
    start = stz.coords_from_bbox_with_fn(e, stz.top_center_coords)
    return stz.line_segment(
        start, stz.translate_coords_vertically(start, arrow_length), s_arr)


### Left tower: create the elements, then stack them.
# Text box for the input embedding.
ier = rectangle_with_text("emb_color", ["Input", "Embedding"])
# Circle-with-plus and circle-with-sine elements for the left tower.
cl = circle_with_plus()
sl = circle_with_sine()
# Blocks created with rectangle_with_add_norm (by its name, each presumably
# bundles an "add & norm" box with the main box; confirm against the helper).
mha1_an = rectangle_with_add_norm("multi_head_attention_color",
                                  ["Multi-Head", "Attention"])
ff1_an = rectangle_with_add_norm("ff_color", ["Feed", "Forward"])

# Position each element relative to the previous one. The last argument of
# each call is a literal (presumably the spacing between the two elements).
stz.place_above_and_align_to_the_center(cl, ier, 0.4)
stz.place_to_the_left_and_align_to_the_center(sl, cl, 0.3)
stz.place_above_and_align_to_the_center(mha1_an, cl, 1.33)
stz.place_above_and_align_to_the_center(ff1_an, mha1_an, 1.0)

### Right tower: create the elements only.
# NOTE(review): no placement calls for these elements are visible in this
# part of the file; they are presumably positioned further down.
oer = rectangle_with_text("emb_color", ["Output", "Embedding"])
cr = circle_with_plus()
sr = circle_with_sine()
mha2_an = rectangle_with_add_norm("multi_head_attention_color",
                                  ["Multi-Head", "Attention"])
mmha_an = rectangle_with_add_norm("multi_head_attention_color",
                                  ["Masked", "Multi-Head", "Attention"])
ff2_an = rectangle_with_add_norm("ff_color", ["Feed", "Forward"])
linear = rectangle_with_text("linear_color", ["Linear"])
softmax = rectangle_with_text("softmax_color", ["Softmax"])
Ejemplo n.º 8
0
    sq_fn("$h_%d^{(2)}$" % (i + 1, ), h_grey_rfmt if i != 2 else h_green_rfmt,
          h_grey_lfmt if i != 2 else "") for i in range(4)
]
# Two memory boxes labelled mem^(1) and mem^(2), both created with rct_fn
# and the m_fmt format.
m1, m2 = [rct_fn("$\\text{mem}^{(%d)}$" % (i + 1, ), m_fmt) for i in range(2)]
# Output box labelled x_3, wrapped in a one-element list.
x_out = [sq_fn("$x_3$", x3_fmt)]

# Each of the first two rows starts with its memory box, followed by the
# xs / hs1 elements defined earlier in the file.
row1 = [m1] + xs
row2 = [m2] + hs1

# Lay out each row horizontally, stack the three rows vertically, then call
# stz.align_rights on them.
stz.distribute_horizontally_with_spacing([m1] + xs, horizontal_spacing)
stz.distribute_horizontally_with_spacing([m2] + hs1, horizontal_spacing)
stz.distribute_horizontally_with_spacing(hs2, horizontal_spacing)
stz.distribute_vertically_with_spacing([row1, row2, hs2], vertical_spacing)
# NOTE(review): the meaning of the second argument (0) is not visible here;
# confirm against the stz.align_rights signature.
stz.align_rights([row1, row2, hs2], 0)

# Position the output box relative to the third element of hs2.
stz.place_above_and_align_to_the_center(x_out, hs2[2], vertical_spacing)

# Caption text, created at the origin and then positioned relative to xs.
legend = stz.latex(
    [0, 0],
    "Factorization order: $3 \\rightarrow 2 \\rightarrow 4 \\rightarrow 1$")
stz.place_below_and_align_to_the_center(legend, xs, legend_spacing)

# Links drawn in the figure: both memory boxes feed the third element of
# the next row, hs1[2] feeds hs2[2], and hs2[2] feeds the output box.
connections = [
    connect(m1, hs1[2]),
    connect(m2, hs2[2]),
    connect(hs1[2], hs2[2]),
    connect(hs2[2], x_out)
]

# Collect every element and write the figure to "xlnet.tex"; name2color is
# passed through to stz.draw_to_tikz_standalone.
e = [row1, row2, hs2, x_out, connections, legend]
stz.draw_to_tikz_standalone(e, "xlnet.tex", name2color)