Beispiel #1
0
def plate(e_lst, expr):
    """Build a plate: a rectangle derived from ``e_lst`` plus a LaTeX label.

    The rectangle is produced by ``stz.rectangle_from_additive_resizing``
    from the bounding box of ``e_lst``, using twice the module-level
    horizontal and vertical plate spacings. The label rendering ``expr`` is
    anchored at the rectangle's bottom-right bounding-box corner, translated
    antidiagonally by ``-plate_label_spacing``.

    Returns:
        ``[rectangle, label]``.
    """
    box_top_left, box_bottom_right = stz.bbox(e_lst)
    h_resize = 2.0 * horizontal_plate_spacing
    v_resize = 2.0 * vertical_plate_spacing
    frame_rect = stz.rectangle_from_additive_resizing(
        box_top_left, box_bottom_right, h_resize, v_resize)

    # Index 1 of a bbox is its bottom-right corner.
    rect_bottom_right = stz.bbox(frame_rect)[1]
    label_anchor = stz.translate_coords_antidiagonally(
        rect_bottom_right, -plate_label_spacing)
    label = stz.latex(label_anchor, expr)
    return [frame_rect, label]
Beispiel #2
0
def rectangle_with_add_norm(color_name, s_lst):
    """Build a text rectangle with an "Add & Norm" rectangle attached to it.

    The main rectangle is created from ``color_name`` and ``s_lst``. The
    "Add & Norm" rectangle is moved so that its bottom center sits at the
    main rectangle's top center translated vertically by
    ``to_add_norm_spacing``; the two are then joined by a straight vertical
    connection.

    Returns:
        ``[add_norm_rectangle, main_rectangle, connection]``.
    """
    add_norm_rect = rectangle_with_text("add_norm_color", ["Add \\& Norm"])
    main_rect = rectangle_with_text(color_name, s_lst)

    main_top_center = stz.top_center_coords(*stz.bbox(main_rect))
    target_cs = stz.translate_coords_vertically(main_top_center,
                                                to_add_norm_spacing)
    # Moves the "Add & Norm" rectangle in place.
    stz.translate_bbox_bottom_center_to_coords(add_norm_rect, target_cs)

    link = connect_straight_vertical(main_rect, add_norm_rect)
    return [add_norm_rect, main_rect, link]
Beispiel #3
0
def trident_coords(e):
    """Compute the three "trident" anchor points along the bottom of ``e``.

    The center point is the bottom center of the bounding box of ``e``; the
    left and right points are offset horizontally from it by
    ``trident_alpha * width / 2.0`` in either direction.

    Returns:
        A pair ``(cs_lst, translated_cs_lst)`` where ``cs_lst`` is
        ``[left, center, right]`` and ``translated_cs_lst`` holds the same
        points translated vertically by ``-trident_spacing``.
    """
    bbox_top_left, bbox_bottom_right = stz.bbox(e)
    anchor = stz.bottom_center_coords(bbox_top_left, bbox_bottom_right)

    half_span = trident_alpha * width / 2.0
    prongs = [
        stz.translate_coords_horizontally(anchor, -half_span),
        anchor,
        stz.translate_coords_horizontally(anchor, half_span),
    ]

    shifted_prongs = []
    for prong in prongs:
        shifted_prongs.append(
            stz.translate_coords_vertically(prong, -trident_spacing))
    return prongs, shifted_prongs
Beispiel #4
0
def get_ticks(e):
    """Create tick marks along the top and bottom edges of ``e``'s bbox.

    For every ``i`` in ``range(1, num_ticks)`` one tick is created on each
    edge at horizontal offset ``i * tick_spacing`` from the left corners.
    Top ticks are vertical segments of length ``-tick_length`` and bottom
    ticks of length ``tick_length``, all formatted with ``t_fmt``.

    Returns:
        ``[bottom_ticks, top_ticks]`` (note: bottom first).
    """
    upper_left, lower_right = stz.bbox(e)
    lower_left = stz.bottom_left_coords(upper_left, lower_right)

    upper, lower = [], []
    for step in range(1, num_ticks):
        shift = step * tick_spacing

        upper_start = stz.translate_coords_horizontally(upper_left, shift)
        upper.append(
            stz.vertical_line_segment(upper_start, -tick_length, t_fmt))

        lower_start = stz.translate_coords_horizontally(lower_left, shift)
        lower.append(
            stz.vertical_line_segment(lower_start, tick_length, t_fmt))
    return [lower, upper]
Beispiel #5
0
    connect_straight_with_arrow(cr, mmha_an),
    connect_straight_with_arrow(oer, cr),
    connect_straight_horizontal(sl, cl),
    connect_straight_horizontal(cr, sr)
]

# Append five residual connections. The first two (ff1_an, mha1_an) use a
# negated residual_spacing, the remaining three the positive value.
connections.extend([
    connect_residual(ff1_an, 0.6, -residual_spacing),
    connect_residual(mha1_an, 0.6, -residual_spacing),
    connect_residual(mmha_an, 0.6, residual_spacing),
    connect_residual(mha2_an, 0.6, residual_spacing),
    connect_residual(ff2_an, 0.6, residual_spacing)
])

### bounding boxes
# Negative indices refer to the five residual connections appended above:
# [-4] is the mha1_an residual, [-3] the mmha_an one, [-2] the mha2_an one.
# NOTE(review): these indices break silently if the extend() above changes.
b = stz.bbox([ff1_an, mha1_an, connections[-4]])
bb1 = stz.rectangle_from_additive_resizing(b[0], b[1], bbox_spacing,
                                           bbox_spacing, s_bbox)
b = stz.bbox([ff2_an, mmha_an, connections[-2], connections[-3]])
bb2 = stz.rectangle_from_additive_resizing(b[0], b[1], bbox_spacing,
                                           bbox_spacing, s_bbox)

### trident connections
# Presumably tri_cs is [left, center, right] and translated_tri_cs the same
# points shifted vertically -- confirm against trident_coords.
tri_cs, translated_tri_cs = trident_coords(mha1_an)
# Three-point path: translated center -> corner point -> first trident point.
e1 = stz.open_path([
    translated_tri_cs[1],
    stz.bottom_left_coords(translated_tri_cs[1], tri_cs[0]), tri_cs[0]
], s_con)
e2 = stz.open_path([
    translated_tri_cs[1],
    stz.bottom_right_coords(translated_tri_cs[1], tri_cs[2]), tri_cs[2]
Beispiel #6
0
def frame(frame_idx):
    """Build the drawable elements for one of the four frames of the figure.

    ``frame_idx`` must be in ``0..3``; it selects which modules are laid out,
    how they are connected, which hyperparameters are attached to modules and
    which are listed as "unreachable", and the frame label ("a".."d").

    * Frame 0 uses modules c1, o, r1, r2, cc.
    * Frame 1 uses modules c1, o, c2, c3, c4, cc.
    * Frames 2 and 3 use modules c1, d, c2, c3, c4, cc.

    Most ``stz`` placement calls below return nothing that is used here and
    are relied upon to move their arguments in place, so statement order
    matters.

    Returns:
        ``[modules, module_connections, hyperparameters, hyperp_connections,
        f, label]`` where ``f`` is the frame rectangle and ``label`` is a
        one-element list holding the LaTeX frame label.
    """
    assert frame_idx >= 0 and frame_idx <= 3
    # All modules are created up front, even those a given frame never draws.
    c1 = conv2d(1)
    o = optional(1)
    r1 = repeat(1)
    r2 = repeat(2)
    cc = concat(1)
    c2 = conv2d(2)
    c3 = conv2d(3)
    c4 = conv2d(4)
    d = dropout(1)

    stz.distribute_horizontally_with_spacing([r1, r2],
                                             horizontal_module_spacing)
    stz.distribute_horizontally_with_spacing([c2, [c3, c4]],
                                             horizontal_module_spacing)

    # --- module layout: only the modules of the selected frame are kept ---
    modules = []
    if frame_idx == 0:
        stz.distribute_vertically_with_spacing([cc, [r1, r2], o, c1],
                                               vertical_module_spacing)

        stz.align_centers_horizontally([cc, [r1, r2], o, c1], 0)
        modules.extend([c1, o, r1, r2, cc])

    else:
        stz.distribute_vertically_with_spacing([c4, c3],
                                               vertical_module_spacing)
        # NOTE(review): same call as the one made before the branch, repeated
        # here after c3/c4 were distributed vertically -- confirm it is needed.
        stz.distribute_horizontally_with_spacing([c2, [c3, c4]],
                                                 horizontal_module_spacing)
        stz.align_centers_vertically([[c3, c4], c2], 0)

        if frame_idx == 1:
            stz.distribute_vertically_with_spacing([cc, [c2, c3, c4], o, c1],
                                                   vertical_module_spacing)
            stz.align_centers_horizontally([cc, [c2, c3, c4], o, c1], 0)
            modules.extend([c1, o, c2, c3, c4, cc])

        else:
            # Frames 2 and 3: d takes the position o has in frame 1.
            stz.distribute_vertically_with_spacing([cc, [c2, c3, c4], d, c1],
                                                   vertical_module_spacing)
            stz.align_centers_horizontally([cc, [c2, c3, c4], d, c1], 0)
            modules.extend([c1, d, c2, c3, c4, cc])

    # --- connections between modules ---
    # NOTE(review): the two trailing integers are presumably output/input
    # port indices (cc receives on 0 and 1) -- confirm in connect_modules.
    module_connections = []
    if frame_idx == 0:
        module_connections.extend([
            connect_modules(c1, o, 0, 0),
            connect_modules(o, r1, 0, 0),
            connect_modules(o, r2, 0, 0),
            connect_modules(r1, cc, 0, 0),
            connect_modules(r2, cc, 0, 1),
        ])

    else:
        if frame_idx == 1:
            module_connections.extend([
                connect_modules(c1, o, 0, 0),
                connect_modules(o, c2, 0, 0),
                connect_modules(o, c3, 0, 0),
            ])
        else:
            module_connections.extend([
                connect_modules(c1, d, 0, 0),
                connect_modules(d, c2, 0, 0),
                connect_modules(d, c3, 0, 0),
            ])

        # Shared by frames 1-3: c3 -> c4, then c2 and c4 both feed cc.
        module_connections.extend([
            connect_modules(c3, c4, 0, 0),
            connect_modules(c2, cc, 0, 0),
            connect_modules(c4, cc, 0, 1),
        ])

    # --- hyperparameters ---
    # Each hyperparameter is built without a trailing argument in early frames
    # and with one in later frames (presumably the value assigned to it --
    # confirm in independent_hyperparameter / dependent_hyperparameter).
    # IH-2 gains it from frame 2, DH-1 and IH-3 from frame 1, the rest in
    # frame 3 only.
    if frame_idx <= 1:
        h_o = independent_hyperparameter("IH-2", "0, 1")
    else:
        h_o = independent_hyperparameter("IH-2", "0, 1", "1")

    if frame_idx <= 0:
        h_r1 = dependent_hyperparameter("DH-1", ["x"], "2*x")
        h_r2 = independent_hyperparameter("IH-3", "1, 2, 4")
    else:
        h_r1 = dependent_hyperparameter("DH-1", ["x"], "2*x", "2")
        h_r2 = independent_hyperparameter("IH-3", "1, 2, 4", "1")

    if frame_idx <= 2:
        h_c1 = independent_hyperparameter("IH-1", "64, 128")
        h_c2 = independent_hyperparameter("IH-4", "64, 128")
        h_c3 = independent_hyperparameter("IH-5", "64, 128")
        h_c4 = independent_hyperparameter("IH-6", "64, 128")
        h_d = independent_hyperparameter("IH-7", "0.25, 0.5")
    else:
        h_c1 = independent_hyperparameter("IH-1", "64, 128", "64")
        h_c2 = independent_hyperparameter("IH-4", "64, 128", "128")
        h_c3 = independent_hyperparameter("IH-5", "64, 128", "128")
        h_c4 = independent_hyperparameter("IH-6", "64, 128", "64")
        h_d = independent_hyperparameter("IH-7", "0.25, 0.5", "0.5")

    def place_hyperp_right_of(h, m):
        """Align ``h`` to the center y of ``m[3]`` and put it right of ``m``."""
        # m[3] is the fourth sub-element of the module; which part of the
        # module that is depends on how modules are built -- TODO confirm.
        y_p = stz.center_coords(m[3])[1]
        stz.align_centers_vertically([h], y_p)
        stz.place_to_the_right(h, m, spacing_between_module_and_hyperp)

    # --- placement of hyperparameters that are attached to modules ---
    hyperparameters = []
    place_hyperp_right_of(h_c1, c1)
    if frame_idx in [0, 1]:
        place_hyperp_right_of(h_o, o)
        hyperparameters.append(h_o)

    if frame_idx == 0:
        place_hyperp_right_of(h_r1, r2)
        stz.place_above_and_align_to_the_right(h_r2, h_r1, 0.8)
        hyperparameters.extend([h_r1, h_r2, h_c1])
    else:
        # NOTE(review): h_c1 was already placed next to c1 above; this second
        # call looks redundant -- confirm before removing.
        place_hyperp_right_of(h_c1, c1)
        place_hyperp_right_of(h_c3, c3)
        place_hyperp_right_of(h_c4, c4)
        stz.place_below(h_c2, h_c1, 3.0)
        hyperparameters.extend([h_c1, h_c2, h_c3, h_c4])

        if frame_idx in [2, 3]:
            place_hyperp_right_of(h_d, d)
            hyperparameters.extend([h_d])

    # --- hyperparameters with no module in this frame ---
    # They are stacked here and moved to the bottom-right corner further down.
    unreachable_hyperps = []
    if frame_idx == 1:
        stz.distribute_vertically_with_spacing([h_r1, h_r2], 0.2)
        unreachable_hyperps.extend([h_r1, h_r2])
    if frame_idx >= 2:
        stz.distribute_vertically_with_spacing([h_o, h_r1, h_r2], 0.2)
        unreachable_hyperps.extend([h_r1, h_r2, h_o])
    hyperparameters.extend(unreachable_hyperps)

    # Move one hyperparameter so its left-center sits at the left-center of
    # the joint bounding box of two already-placed hyperparameters.
    cs_fn = lambda e: stz.coords_from_bbox_with_fn(e, stz.left_center_coords)
    if frame_idx == 0:

        stz.translate_bbox_left_center_to_coords(h_r2, cs_fn([h_o, h_r1]))
    elif frame_idx == 1:
        stz.translate_bbox_left_center_to_coords(h_c2, cs_fn([h_o, h_c3]))
    else:
        stz.translate_bbox_left_center_to_coords(h_c2, cs_fn([h_d, h_c3]))

    # --- connections from hyperparameters to modules (and to each other) ---
    hyperp_connections = [
        connect_hyperp_to_module(h_c1, c1, 0),
    ]
    if frame_idx in [0, 1]:
        hyperp_connections.extend([connect_hyperp_to_module(h_o, o, 0)])
    if frame_idx == 0:
        # Note the crossing: h_r1 connects to r2 and h_r2 connects to r1.
        hyperp_connections.extend([
            connect_hyperp_to_module(h_r1, r2, 0),
            connect_hyperp_to_module(h_r2, r1, 0),
            connect_hyperp_to_hyperp(h_r2, h_r1)
        ])
    else:
        hyperp_connections.extend([
            connect_hyperp_to_module(h_c2, c2, 0),
            connect_hyperp_to_module(h_c3, c3, 0),
            connect_hyperp_to_module(h_c4, c4, 0),
        ])
        if frame_idx in [2, 3]:
            hyperp_connections.append(connect_hyperp_to_module(h_d, d, 0))

    # --- enclosing frame rectangle ---
    # NOTE(review): arguments are passed as (frame_height, frame_width) to a
    # function named "..._width_and_height" -- confirm the parameter order.
    f = stz.rectangle_from_width_and_height([0, 0], frame_height, frame_width,
                                            frame_s_fmt)
    e = [modules, module_connections, hyperparameters, hyperp_connections]
    # Center the frame on the content, shifted horizontally by 0.8.
    stz.translate_bbox_center_to_coords(
        f, stz.translate_coords_horizontally(stz.center_coords(e), 0.8))
    if len(unreachable_hyperps) > 0:
        # Move the unreachable hyperparameters so their bottom-right corner
        # matches the bottom-right corner of the content's bounding box.
        stz.translate_bbox_bottom_right_to_coords(unreachable_hyperps,
                                                  stz.bbox(e)[1])

    # frame id: a bold letter placed relative to the frame's top-left corner
    s = ["a", "b", "c", "d"][frame_idx]
    label = [stz.latex([0, 0], "\\Huge \\textbf %s" % s)]
    stz.translate_bbox_top_left_to_coords(
        label,
        stz.translate_coords_antidiagonally(
            stz.coords_from_bbox_with_fn(f, stz.top_left_coords), 0.6))

    return e + [f, label]
Beispiel #7
0
def dashed_bbox(e_lst):
    """Return a dashed rectangle derived from the bounding box of ``e_lst``.

    The rectangle comes from ``stz.rectangle_from_additive_resizing`` with
    ``2.0 * bbox_spacing`` for both resizing arguments, and is styled with the
    "dashed" line style combined with the module-level ``s_lw`` format.
    """
    dashed_style = fmt.combine_tikz_strs([fmt.line_style("dashed"), s_lw])
    corner_top_left, corner_bottom_right = stz.bbox(e_lst)
    resize_amount = 2.0 * bbox_spacing
    return stz.rectangle_from_additive_resizing(
        corner_top_left, corner_bottom_right, resize_amount, resize_amount,
        dashed_style)