Exemplo n.º 1
0
def module(module_name,
           input_names,
           output_names,
           hyperp_names,
           p_width_scale=1.0):
    """Build the drawing elements for one module box.

    The box is a rectangle spanning from the origin to
    ``(module_width, -module_height)``, with the module name in bold at its
    center. Input ports are laid out left to right along the inner top edge,
    output ports along the inner bottom edge. Hyperparameter properties, if
    any, are stacked vertically against the inner right edge.

    ``input``, ``output`` and ``property`` are assumed to be drawing helpers
    defined elsewhere in this file (they shadow the builtins) -- confirm.

    Args:
        module_name: Text rendered (bold) in the middle of the box.
        input_names: One input port is created per entry.
        output_names: One output port is created per entry.
        hyperp_names: One hyperparameter property is created per entry.
        p_width_scale: Forwarded to ``property`` for every hyperparameter.

    Returns:
        ``[[box, name_label], inputs, outputs]``. A fourth entry holds the
        hyperparameter elements when ``hyperp_names`` is non-empty.
    """
    inputs = [input(name) for name in input_names]
    outputs = [output(name) for name in output_names]
    box = stz.rectangle([0, 0], [module_width, -module_height], module_s_fmt)
    name_label = stz.latex(stz.center_coords(box),
                           "\\textbf{%s}" % module_name)

    # The same inset is used on both axes to keep ports off the box border.
    inset = module_inner_vertical_spacing

    # Input ports: a row pinned to the inner top-left corner.
    stz.distribute_horizontally_with_spacing(inputs, io_spacing)
    stz.translate_bbox_top_left_to_coords(inputs, [inset, -inset])

    # Output ports: a row pinned to the inner bottom-left corner.
    stz.distribute_horizontally_with_spacing(outputs, io_spacing)
    stz.translate_bbox_bottom_left_to_coords(outputs,
                                             [inset, -module_height + inset])

    elements = [[box, name_label], inputs, outputs]
    if len(hyperp_names) == 0:
        return elements

    # Hyperparameters: a column pinned to the inner top-right corner, pushed
    # down an extra delta_increment below the top inset.
    hyperps = [property(name, p_width_scale) for name in hyperp_names]
    stz.distribute_vertically_with_spacing(hyperps, p_spacing)
    stz.translate_bbox_top_right_to_coords(
        hyperps, [module_width - inset, -inset - delta_increment])
    elements.append(hyperps)
    return elements
Exemplo n.º 2
0
def figure_b():
    """Build factor graph (b): an x-chain with a q-chain stacked above it.

    Bottom row: x_1, x_2, x_3 joined by edges. It has pairwise factors for
    (x_1, x_2) and (x_2, x_3), one factor attached to the left of x_1, and
    one single-node factor below each x_i.

    Top row: q_1, q_2, q_3 with the same pairwise and left factors. Each q_i
    is also joined to the x_i beneath it.

    Returns:
        ``[nodes, small_nodes, connections, labels]``.
    """
    # Horizontal offset of a left-hand factor from its node, shared by rows.
    left_offset = horizontal_node_spacing / 2.0 - filled_node_radius

    # --- bottom row: x nodes -------------------------------------------
    xs = [node(s) for s in ("$x_1$", "$x_2$", "$x_3$")]
    stz.distribute_horizontally_with_spacing(xs, horizontal_node_spacing)
    x1, x2, x3 = xs
    nodes = list(xs)

    x_pair_factors = [small_node_relative(x1, x2), small_node_relative(x2, x3)]
    x_left_factor = small_node([0, 0])
    stz.place_to_the_left_and_align_to_the_center(x_left_factor, x1,
                                                  left_offset)
    below = [small_node([0, 0]) for _ in range(3)]
    small_nodes = x_pair_factors + [x_left_factor] + below
    stz.distribute_centers_horizontally_with_spacing(
        below, horizontal_node_spacing + 2.0 * node_radius)
    stz.place_below_and_align_to_the_center(below, nodes,
                                            vertical_node_spacing)

    connections = [
        connect_nodes(x1, x2),
        connect_nodes(x2, x3),
        connect_small_node_to_node(x_left_factor, x1),
    ]
    connections += [
        connect_small_node_to_node(factor, x) for factor, x in zip(below, xs)
    ]

    # --- top row: q nodes ----------------------------------------------
    qs = [node(s) for s in ("$q_1$", "$q_2$", "$q_3$")]
    stz.distribute_horizontally_with_spacing(qs, horizontal_node_spacing)
    q1, q2, q3 = qs
    nodes += qs
    stz.place_above_and_align_to_the_center(qs, xs, vertical_node_spacing)

    q_pair_factors = [small_node_relative(q1, q2), small_node_relative(q2, q3)]
    q_left_factor = small_node([0, 0])
    stz.place_to_the_left_and_align_to_the_center(q_left_factor, q1,
                                                  left_offset)
    small_nodes += q_pair_factors + [q_left_factor]

    connections += [
        connect_nodes(q1, q2),
        connect_nodes(q2, q3),
        connect_small_node_to_node(q_left_factor, q1),
    ]
    # Vertical edges tie each q_i to the x_i directly below it.
    connections += [connect_nodes(x, q) for x, q in zip(xs, qs)]

    labels = [
        label_small_node("$p(q_1)$", q_left_factor, -90.0, vd,
                         rectangle_width),
        label_small_node("$p(q_2| q_3)$", q_pair_factors[1], -90.0, vd,
                         rectangle_width),
    ]
    return [nodes, small_nodes, connections, labels]
Exemplo n.º 3
0
def figure_e():
    """Build factor graph (e): a T_1 - T_2 - T_3 chain with l_4, l_5 above it.

    The chain has pairwise factors for (T_1, T_2) and (T_2, T_3) and one
    factor attached to the left of T_1. Nodes l_5 and l_4 sit above T_2.
    Five cross edges link them to the chain, and each cross edge carries its
    own factor. The factor on the T_1 - l_5 edge is labelled.

    Returns:
        ``[nodes, small_nodes, connections, labels]``.
    """
    ts = [node(s) for s in ("$T_1$", "$T_2$", "$T_3$")]
    stz.distribute_horizontally_with_spacing(ts, horizontal_node_spacing)
    t1, t2, t3 = ts
    nodes = list(ts)

    small_nodes = [small_node_relative(t1, t2), small_node_relative(t2, t3)]
    t1_left_factor = small_node([0, 0])
    stz.place_to_the_left_and_align_to_the_center(
        t1_left_factor, t1, horizontal_node_spacing / 2.0 - filled_node_radius)
    small_nodes.append(t1_left_factor)

    connections = [
        connect_nodes(t1, t2),
        connect_nodes(t2, t3),
        connect_small_node_to_node(t1_left_factor, t1),
    ]

    l4 = node("$l_4$")
    l5 = node("$l_5$")
    stz.distribute_horizontally_with_spacing([l5, l4], horizontal_node_spacing)
    stz.place_above_and_align_to_the_center([l5, l4], [t2],
                                            vertical_node_spacing)
    nodes += [l4, l5]

    # All cross edges are created first, then one factor per edge.
    # 0.6 is presumably the factor's relative position along the edge.
    cross_edges = [(t1, l5), (t2, l5), (t2, l4), (t3, l4), (t1, l4)]
    connections += [connect_nodes(a, b) for a, b in cross_edges]
    small_nodes += [small_node_relative(a, b, 0.6) for a, b in cross_edges]

    # small_nodes[-5] is the factor on the first cross edge (T_1, l_5).
    labels = [
        label_small_node("$f_{br}(T_1, l_5)$", small_nodes[-5], 135.0, hd,
                         1.4 * rectangle_width)
    ]
    return [nodes, small_nodes, connections, labels]
Exemplo n.º 4
0
def figure_c():
    """Build factor graph (c): an x_1 - x_2 - x_3 chain with u inputs above.

    Pairwise factors for (x_1, x_2) and (x_2, x_3) sit on the chain, and
    each x_i has a single-node factor below it. Node u_1 sits above the
    first pairwise factor and u_2 above the second. Each u node is joined to
    that factor and carries its own single-node factor further above.

    Returns:
        ``[nodes, small_nodes, connections, labels]``.
    """
    xs = [node(s) for s in ("$x_1$", "$x_2$", "$x_3$")]
    stz.distribute_horizontally_with_spacing(xs, horizontal_node_spacing)
    x1, x2, x3 = xs
    nodes = list(xs)

    pair_factors = [small_node_relative(x1, x2), small_node_relative(x2, x3)]
    below = [small_node([0, 0]) for _ in range(3)]
    small_nodes = pair_factors + below
    stz.distribute_centers_horizontally_with_spacing(
        below, horizontal_node_spacing + 2.0 * node_radius)
    stz.place_below_and_align_to_the_center(below, nodes,
                                            vertical_node_spacing)

    connections = [connect_nodes(x1, x2), connect_nodes(x2, x3)]
    connections += [
        connect_small_node_to_node(factor, x) for factor, x in zip(below, xs)
    ]

    us = [node("$u_1$"), node("$u_2$")]
    nodes += us
    gap = vertical_node_spacing - node_radius
    u_factors = []
    for u, pair_factor in zip(us, pair_factors):
        u_factor = small_node([0, 0])
        stz.place_above_and_align_to_the_center(u, pair_factor, gap)
        stz.place_above_and_align_to_the_center(u_factor, u, gap)
        connections.append(connect_small_node_to_node(pair_factor, u))
        connections.append(connect_small_node_to_node(u_factor, u))
        u_factors.append(u_factor)
    small_nodes += u_factors

    labels = [
        label_small_node("$J_x(x_1)$", below[0], 0.0, hd, rectangle_width),
        label_small_node("$J_u(u_1)$", u_factors[0], 0.0, hd,
                         rectangle_width),
        label_small_node("$p(x_3| x_2, u_2)$", pair_factors[1], -90.0, vd,
                         1.4 * rectangle_width),
    ]
    return [nodes, small_nodes, connections, labels]
Exemplo n.º 5
0
def search_space_transition():
    """Draw the four search-space frames (a)-(d) into ``deep_architect.tex``.

    Frame ``i`` is built by ``frame(i)``. Elements that change between
    consecutive frames get a fill color appended to their TikZ style:

    * ``light_green_2``: newly introduced modules and hyperparameters.
    * ``light_red_2``: hyperparameters that have been assigned a value.

    The frames are laid out in two rows, (a, b) and (c, d). The figure is
    written with ``name2color`` as the color table.
    """
    frames = [frame(i) for i in range(4)]
    e0, e1, e2, e3 = frames

    new_color = "light_green_2"
    assigned_color = "light_red_2"

    # (frame, index path, fill color). An index path walks the nested lists
    # returned by `frame`: top-level index 0 is the module list and 2 is the
    # hyperparameter list. The remaining indices select one module or
    # hyperparameter and then the drawable inside it whose style is changed.
    highlights = [
        # new modules
        (e1, [0, 2, 0, 0], new_color),
        (e1, [0, 3, 0, 0], new_color),
        (e1, [0, 4, 0, 0], new_color),
        (e2, [0, 1, 0, 0], new_color),
        # new hyperparameters
        (e1, [2, 2, 0], new_color),
        (e1, [2, 3, 0], new_color),
        (e1, [2, 4, 0], new_color),
        (e2, [2, 4, 0], new_color),
        # assigned hyperparameters
        (e1, [2, 5, 0], assigned_color),
        (e1, [2, 6, 0], assigned_color),
        (e2, [2, 7, 0], assigned_color),
        (e3, [2, 0, 0], assigned_color),
        (e3, [2, 1, 0], assigned_color),
        (e3, [2, 2, 0], assigned_color),
        (e3, [2, 3, 0], assigned_color),
        (e3, [2, 4, 0], assigned_color),
    ]
    for target, path, color in highlights:
        element = target
        for index in path:
            element = element[index]
        element["tikz_str"] = fmt.combine_tikz_strs(
            [element["tikz_str"], fmt.fill_color(color)])

    # Arrange the four frames: pair them up in rows, then stack the rows.
    stz.align_tops(frames, 0.0)
    stz.distribute_horizontally_with_spacing([e0, e1], frame_spacing)
    stz.distribute_horizontally_with_spacing([e2, e3], frame_spacing)
    stz.distribute_vertically_with_spacing([[e2, e3], [e0, e1]], frame_spacing)

    stz.draw_to_tikz_standalone(frames, "deep_architect.tex", name2color)
Exemplo n.º 6
0
def figure_f():
    """Build factor graph (f): nodes C_1, C_2, C_3 with p_4, p_5 above them.

    There are no edges among the C nodes. Nodes p_5 and p_4 sit above C_2,
    and five cross edges link them to the C row. Each cross edge carries its
    own factor, and the factor on the C_1 - p_5 edge is labelled.

    Returns:
        ``[nodes, small_nodes, connections, labels]``.
    """
    c_nodes = [node(s) for s in ("$C_1$", "$C_2$", "$C_3$")]
    stz.distribute_horizontally_with_spacing(c_nodes, horizontal_node_spacing)
    c1, c2, c3 = c_nodes
    nodes = list(c_nodes)

    p4 = node("$p_4$")
    p5 = node("$p_5$")
    stz.distribute_horizontally_with_spacing([p5, p4], horizontal_node_spacing)
    stz.place_above_and_align_to_the_center([p5, p4], [c2],
                                            vertical_node_spacing)
    nodes += [p4, p5]

    # All edges are created first, then one factor per edge.
    # 0.6 is presumably the factor's relative position along the edge.
    cross_edges = [(c1, p5), (c2, p5), (c2, p4), (c3, p4), (c1, p4)]
    connections = [connect_nodes(a, b) for a, b in cross_edges]
    small_nodes = [small_node_relative(a, b, 0.6) for a, b in cross_edges]

    # small_nodes[-5] is the factor on the first cross edge (C_1, p_5).
    labels = [
        label_small_node("$f_{rp}(C_1, p_5)$", small_nodes[-5], 135.0, hd,
                         1.4 * rectangle_width)
    ]
    return [nodes, small_nodes, connections, labels]
Exemplo n.º 7
0
def figure_a():
    """Build factor graph (a): a chain x_1 - x_2 - x_3.

    The chain has pairwise factors for (x_1, x_2) and (x_2, x_3) and one
    factor attached to the left of x_1. Each x_i also has a single-node
    factor below it. The left factor, the factor below x_1 and the
    (x_2, x_3) factor are labelled.

    Returns:
        ``[nodes, small_nodes, connections, labels]``.
    """
    xs = [node(s) for s in ("$x_1$", "$x_2$", "$x_3$")]
    stz.distribute_horizontally_with_spacing(xs, horizontal_node_spacing)
    x1, x2, x3 = xs
    nodes = list(xs)

    pair_factors = [small_node_relative(x1, x2), small_node_relative(x2, x3)]
    left_factor = small_node([0, 0])
    stz.place_to_the_left_and_align_to_the_center(
        left_factor, x1, horizontal_node_spacing / 2.0 - filled_node_radius)
    below = [small_node([0, 0]) for _ in range(3)]
    small_nodes = pair_factors + [left_factor] + below
    stz.distribute_centers_horizontally_with_spacing(
        below, horizontal_node_spacing + 2.0 * node_radius)
    stz.place_below_and_align_to_the_center(below, nodes,
                                            vertical_node_spacing)

    connections = [
        connect_nodes(x1, x2),
        connect_nodes(x2, x3),
        connect_small_node_to_node(left_factor, x1),
    ]
    connections += [
        connect_small_node_to_node(factor, x) for factor, x in zip(below, xs)
    ]

    labels = [
        label_small_node("$p(x_1)$", left_factor, 90.0, vd, rectangle_width),
        label_small_node("$l(x_1; z_1)$", below[0], 0.0, hd,
                         rectangle_width),
        label_small_node("$p(x_3 | x_2)$", pair_factors[1], 90.0, vd,
                         rectangle_width),
    ]
    return [nodes, small_nodes, connections, labels]
Exemplo n.º 8
0
        stz.bbox(r)[1], -plate_label_spacing)
    l = stz.latex(l_cs, expr)
    return [r, l]


# One circle per variable of the plate diagram. All start at the origin and
# are positioned below.
alpha_c, theta_c, z_c, w_c, eta_c, beta_c = [
    stz.circle([0, 0], node_radius) for _ in range(6)
]

# Shade w's circle gray (presumably marking it as the observed variable).
w_c["tikz_str"] = fmt.combine_tikz_strs(
    [w_c["tikz_str"], fmt.fill_color("gray")])

# Bottom row: alpha, theta, z, w. The eta/beta pair is centered above it.
stz.distribute_horizontally_with_spacing([alpha_c, theta_c, z_c, w_c],
                                         node_spacing)
stz.distribute_horizontally_with_spacing([eta_c, beta_c], node_spacing)
stz.place_above_and_align_to_the_center(
    [eta_c, beta_c], [alpha_c, theta_c, z_c, w_c], node_spacing)

# Variable names go below the bottom-row circles and beside the top-row ones.
alpha_l, theta_l, z_l, w_l = [
    label_below(circle, name)
    for circle, name in [(alpha_c, "$\\alpha$"), (theta_c, "$\\theta$"),
                         (z_c, "$z$"), (w_c, "$w$")]
]
eta_l = label_left(eta_c, "$\\eta$")
beta_l = label_right(beta_c, "$\\beta$")

connections = [
    connect_horizontally(alpha_c, theta_c),
    connect_horizontally(theta_c, z_c),
Exemplo n.º 9
0
def frame(frame_idx: int):
    """Build the drawing elements for one frame of the search-space figure.

    ``frame_idx`` selects which stage of the search space is drawn. The
    frame is tagged with the letter "a"-"d" accordingly.

    * Frame 0: conv2d(1) -> optional(1) -> repeat(1) / repeat(2) ->
      concat(1). IH-2 (optional), DH-1 and IH-3 (repeats, with a
      hyperparameter-to-hyperparameter link between them) and IH-1 (first
      conv) have no value yet.
    * Frame 1: the repeats are replaced by conv2d(2) in parallel with
      conv2d(3) -> conv2d(4), all feeding concat(1). DH-1 and IH-3 now carry
      values. They are no longer connected to any module and are stacked in
      the bottom-right corner.
    * Frames 2-3: dropout(1) takes the place of optional(1). IH-2 now
      carries a value and is detached as well. In frame 3 the conv and
      dropout hyperparameters (IH-1, IH-4 .. IH-7) also carry values.

    Args:
        frame_idx: Frame index in ``[0, 3]``; checked with ``assert`` only.

    Returns:
        ``[modules, module_connections, hyperparameters, hyperp_connections,
        frame_rectangle, [frame_letter_label]]``. Index 0 holds the modules
        and index 2 the hyperparameters.
    """
    assert frame_idx >= 0 and frame_idx <= 3
    # All candidate modules are created for every frame; only the ones
    # appended to `modules` below are drawn.
    c1 = conv2d(1)
    o = optional(1)
    r1 = repeat(1)
    r2 = repeat(2)
    cc = concat(1)
    c2 = conv2d(2)
    c3 = conv2d(3)
    c4 = conv2d(4)
    d = dropout(1)

    stz.distribute_horizontally_with_spacing([r1, r2],
                                             horizontal_module_spacing)
    # NOTE(review): c2-c4 are not drawn in frame 0, and this layout is redone
    # in the else-branch below after c3/c4 are stacked, so this call looks
    # redundant. It is kept because removing it may shift absolute positions.
    stz.distribute_horizontally_with_spacing([c2, [c3, c4]],
                                             horizontal_module_spacing)

    # Stack the drawn modules vertically and center them on x = 0.
    modules = []
    if frame_idx == 0:
        stz.distribute_vertically_with_spacing([cc, [r1, r2], o, c1],
                                               vertical_module_spacing)

        stz.align_centers_horizontally([cc, [r1, r2], o, c1], 0)
        modules.extend([c1, o, r1, r2, cc])

    else:
        # Stack c4/c3 vertically, put that pair beside c2, and align the two
        # branches' centers at y = 0.
        stz.distribute_vertically_with_spacing([c4, c3],
                                               vertical_module_spacing)
        stz.distribute_horizontally_with_spacing([c2, [c3, c4]],
                                                 horizontal_module_spacing)
        stz.align_centers_vertically([[c3, c4], c2], 0)

        if frame_idx == 1:
            stz.distribute_vertically_with_spacing([cc, [c2, c3, c4], o, c1],
                                                   vertical_module_spacing)
            stz.align_centers_horizontally([cc, [c2, c3, c4], o, c1], 0)
            modules.extend([c1, o, c2, c3, c4, cc])

        else:
            # Frames 2-3: dropout sits where the optional module was.
            stz.distribute_vertically_with_spacing([cc, [c2, c3, c4], d, c1],
                                                   vertical_module_spacing)
            stz.align_centers_horizontally([cc, [c2, c3, c4], d, c1], 0)
            modules.extend([c1, d, c2, c3, c4, cc])

    # Module-to-module edges. The two trailing integers appear to be the
    # output/input port indices; concat receives its second input on port 1.
    module_connections = []
    if frame_idx == 0:
        module_connections.extend([
            connect_modules(c1, o, 0, 0),
            connect_modules(o, r1, 0, 0),
            connect_modules(o, r2, 0, 0),
            connect_modules(r1, cc, 0, 0),
            connect_modules(r2, cc, 0, 1),
        ])

    else:
        if frame_idx == 1:
            module_connections.extend([
                connect_modules(c1, o, 0, 0),
                connect_modules(o, c2, 0, 0),
                connect_modules(o, c3, 0, 0),
            ])
        else:
            module_connections.extend([
                connect_modules(c1, d, 0, 0),
                connect_modules(d, c2, 0, 0),
                connect_modules(d, c3, 0, 0),
            ])

        # Edges shared by frames 1-3.
        module_connections.extend([
            connect_modules(c3, c4, 0, 0),
            connect_modules(c2, cc, 0, 0),
            connect_modules(c4, cc, 0, 1),
        ])

    # Hyperparameters. The extra trailing argument appears to be the assigned
    # value. IH-2 receives one from frame 2, DH-1/IH-3 from frame 1, and the
    # conv/dropout hyperparameters in frame 3.
    if frame_idx <= 1:
        h_o = independent_hyperparameter("IH-2", "0, 1")
    else:
        h_o = independent_hyperparameter("IH-2", "0, 1", "1")

    if frame_idx <= 0:
        h_r1 = dependent_hyperparameter("DH-1", ["x"], "2*x")
        h_r2 = independent_hyperparameter("IH-3", "1, 2, 4")
    else:
        h_r1 = dependent_hyperparameter("DH-1", ["x"], "2*x", "2")
        h_r2 = independent_hyperparameter("IH-3", "1, 2, 4", "1")

    if frame_idx <= 2:
        h_c1 = independent_hyperparameter("IH-1", "64, 128")
        h_c2 = independent_hyperparameter("IH-4", "64, 128")
        h_c3 = independent_hyperparameter("IH-5", "64, 128")
        h_c4 = independent_hyperparameter("IH-6", "64, 128")
        h_d = independent_hyperparameter("IH-7", "0.25, 0.5")
    else:
        h_c1 = independent_hyperparameter("IH-1", "64, 128", "64")
        h_c2 = independent_hyperparameter("IH-4", "64, 128", "128")
        h_c3 = independent_hyperparameter("IH-5", "64, 128", "128")
        h_c4 = independent_hyperparameter("IH-6", "64, 128", "64")
        h_d = independent_hyperparameter("IH-7", "0.25, 0.5", "0.5")

    def place_hyperp_right_of(h, m):
        # Align h vertically with the center of m[3], then put it to the
        # right of m. m[3] is presumably the module's hyperparameter-port
        # list (4th element returned by `module`) -- confirm.
        y_p = stz.center_coords(m[3])[1]
        stz.align_centers_vertically([h], y_p)
        stz.place_to_the_right(h, m, spacing_between_module_and_hyperp)

    hyperparameters = []
    place_hyperp_right_of(h_c1, c1)
    if frame_idx in [0, 1]:
        # optional(1), and hence IH-2's attachment, only exists in frames 0-1.
        place_hyperp_right_of(h_o, o)
        hyperparameters.append(h_o)

    if frame_idx == 0:
        place_hyperp_right_of(h_r1, r2)
        stz.place_above_and_align_to_the_right(h_r2, h_r1, 0.8)
        hyperparameters.extend([h_r1, h_r2, h_c1])
    else:
        # NOTE(review): h_c1 was already placed just above; this repeat looks
        # redundant.
        place_hyperp_right_of(h_c1, c1)
        place_hyperp_right_of(h_c3, c3)
        place_hyperp_right_of(h_c4, c4)
        # NOTE(review): h_c2 is translated again further down, which appears
        # to override this placement.
        stz.place_below(h_c2, h_c1, 3.0)
        hyperparameters.extend([h_c1, h_c2, h_c3, h_c4])

        if frame_idx in [2, 3]:
            place_hyperp_right_of(h_d, d)
            hyperparameters.extend([h_d])

    # Hyperparameters whose modules are no longer drawn in this frame. They
    # are stacked together here and moved to a corner of the content below.
    unreachable_hyperps = []
    if frame_idx == 1:
        stz.distribute_vertically_with_spacing([h_r1, h_r2], 0.2)
        unreachable_hyperps.extend([h_r1, h_r2])
    if frame_idx >= 2:
        stz.distribute_vertically_with_spacing([h_o, h_r1, h_r2], 0.2)
        unreachable_hyperps.extend([h_r1, h_r2, h_o])
    hyperparameters.extend(unreachable_hyperps)

    # cs_fn(es): left-center point of the combined bounding box of `es`. It
    # is used to move one hyperparameter to the vertical middle of two others.
    cs_fn = lambda e: stz.coords_from_bbox_with_fn(e, stz.left_center_coords)
    if frame_idx == 0:

        stz.translate_bbox_left_center_to_coords(h_r2, cs_fn([h_o, h_r1]))
    elif frame_idx == 1:
        stz.translate_bbox_left_center_to_coords(h_c2, cs_fn([h_o, h_c3]))
    else:
        stz.translate_bbox_left_center_to_coords(h_c2, cs_fn([h_d, h_c3]))

    # Edges from each attached hyperparameter to its module.
    hyperp_connections = [
        connect_hyperp_to_module(h_c1, c1, 0),
    ]
    if frame_idx in [0, 1]:
        hyperp_connections.extend([connect_hyperp_to_module(h_o, o, 0)])
    if frame_idx == 0:
        hyperp_connections.extend([
            connect_hyperp_to_module(h_r1, r2, 0),
            connect_hyperp_to_module(h_r2, r1, 0),
            connect_hyperp_to_hyperp(h_r2, h_r1)
        ])
    else:
        hyperp_connections.extend([
            connect_hyperp_to_module(h_c2, c2, 0),
            connect_hyperp_to_module(h_c3, c3, 0),
            connect_hyperp_to_module(h_c4, c4, 0),
        ])
        if frame_idx in [2, 3]:
            hyperp_connections.append(connect_hyperp_to_module(h_d, d, 0))

    # Frame border, centered on the content's center shifted horizontally by
    # 0.8.
    f = stz.rectangle_from_width_and_height([0, 0], frame_height, frame_width,
                                            frame_s_fmt)
    e = [modules, module_connections, hyperparameters, hyperp_connections]
    stz.translate_bbox_center_to_coords(
        f, stz.translate_coords_horizontally(stz.center_coords(e), 0.8))
    if len(unreachable_hyperps) > 0:
        # NOTE(review): assumes stz.bbox(e)[1] is the content's bottom-right
        # corner -- confirm.
        stz.translate_bbox_bottom_right_to_coords(unreachable_hyperps,
                                                  stz.bbox(e)[1])

    # frame id: a large letter "a"-"d" placed near the frame's top-left corner
    s = ["a", "b", "c", "d"][frame_idx]
    label = [stz.latex([0, 0], "\\Huge \\textbf %s" % s)]
    stz.translate_bbox_top_left_to_coords(
        label,
        stz.translate_coords_antidiagonally(
            stz.coords_from_bbox_with_fn(f, stz.top_left_coords), 0.6))

    return e + [f, label]
Exemplo n.º 10
0
    cs = [0, 0]
    rs = [
        stz.rectangle_from_width_and_height(cs, rectangle_height, x, r_fmt)
        for x in data
    ]

    stz.distribute_vertically_with_spacing(rs, rectangle_spacing)
    frame_height = (len(data) * rectangle_height +
                    (len(data) - 1) * rectangle_spacing + 2.0 * margin)
    frame = stz.rectangle_from_width_and_height(cs, frame_height, frame_width)
    stz.align_centers_vertically([frame, rs], 0.0)

    labels = []
    for i, r in enumerate(rs):
        l_cs = stz.coords_from_bbox_with_fn(r, stz.left_center_coords)
        l_cs = stz.translate_coords_horizontally(l_cs, -label_spacing)
        l_s = label_strs[i]
        lab = stz.latex(l_cs, l_s, fmt.anchor('right_center'))
        labels.append(lab)

    ticks = get_ticks(frame)
    return [rs, labels, frame, ticks]


# Build the large and small bar plots. Lay them out side by side in the order
# [e_small, e_large] with their bottoms aligned, and write the result to
# boxes.tex.
e_large = bar_plot(large_data, large_label_strs)
e_small = bar_plot(small_data, small_label_strs)
stz.distribute_horizontally_with_spacing([e_small, e_large], frame_spacing)
stz.align_bottoms([e_large, e_small], 0)

stz.draw_to_tikz_standalone([e_small, e_large], "boxes.tex")
Exemplo n.º 11
0
    small_nodes.extend([
        small_node_relative(x1, x5, 0.6),
        small_node_relative(x2, x5, 0.6),
        small_node_relative(x2, x4, 0.6),
        small_node_relative(x3, x4, 0.6),
        small_node_relative(x1, x4, 0.6),
    ])

    labels = [
        label_small_node("$f_{rp}(C_1, p_5)$", small_nodes[-5], 135.0, hd,
                         1.4 * rectangle_width)
    ]

    return [nodes, small_nodes, connections, labels]


# The six factor-graph panels, built in order (a) through (f).
e = [
    build()
    for build in (figure_a, figure_b, figure_c, figure_d, figure_e, figure_f)
]

# Panels a-c form one column and d-f another. Each column is handed to the
# vertical distributor in reverse panel order, and the two columns are then
# placed side by side.
stz.distribute_vertically_with_spacing(list(reversed(e[:3])), 2.0)
stz.distribute_vertically_with_spacing(list(reversed(e[3:])), 2.0)
stz.distribute_horizontally_with_spacing([e[:3], e[3:]], 2.0)

stz.draw_to_tikz_standalone(e, "factor_graphs.tex")
Exemplo n.º 12
0
# Bottom row: inputs x_1..x_4, each with its own style from xs_fmt.
xs = [sq_fn("$x_%d$" % k, xs_fmt[k - 1]) for k in range(1, 5)]

# Two layers of hidden states, h^{(1)} and h^{(2)}. Position 3 (index 2) is
# drawn green and without the left style; the other positions are grey.
hs1, hs2 = [
    [
        sq_fn("$h_%d^{(%d)}$" % (k, layer),
              *((h_green_rfmt, "") if k == 3 else (h_grey_rfmt, h_grey_lfmt)))
        for k in range(1, 5)
    ]
    for layer in (1, 2)
]
m1, m2 = [rct_fn("$\\text{mem}^{(%d)}$" % layer, m_fmt) for layer in (1, 2)]
x_out = [sq_fn("$x_3$", x3_fmt)]

# A memory block leads each of the two lower rows; hs2 has none.
row1 = [m1] + xs
row2 = [m2] + hs1

stz.distribute_horizontally_with_spacing(row1, horizontal_spacing)
stz.distribute_horizontally_with_spacing(row2, horizontal_spacing)
stz.distribute_horizontally_with_spacing(hs2, horizontal_spacing)
stz.distribute_vertically_with_spacing([row1, row2, hs2], vertical_spacing)
stz.align_rights([row1, row2, hs2], 0)

# The output x_3 sits above the top hidden state of the highlighted column.
stz.place_above_and_align_to_the_center(x_out, hs2[2], vertical_spacing)

legend = stz.latex(
    [0, 0],
    "Factorization order: $3 \\rightarrow 2 \\rightarrow 4 \\rightarrow 1$")
stz.place_below_and_align_to_the_center(legend, xs, legend_spacing)

connections = [
    connect(m1, hs1[2]),
    connect(m2, hs2[2]),