def plate(e_lst, expr):
    """Draw a plate around ``e_lst`` and attach the LaTeX label ``expr``.

    The plate is the bounding box of ``e_lst`` resized additively by
    ``2.0 * horizontal_plate_spacing`` horizontally and
    ``2.0 * vertical_plate_spacing`` vertically. The label is anchored at the
    plate's bottom-right corner translated antidiagonally by
    ``-plate_label_spacing``.

    Args:
        e_lst: Element(s) the plate should enclose.
        expr: LaTeX source for the plate label.

    Returns:
        A two-element list ``[plate_rectangle, label]``.
    """
    corners = stz.bbox(e_lst)
    plate_rect = stz.rectangle_from_additive_resizing(
        corners[0],
        corners[1],
        2.0 * horizontal_plate_spacing,
        2.0 * vertical_plate_spacing,
    )
    plate_bottom_right = stz.bbox(plate_rect)[1]
    label_anchor = stz.translate_coords_antidiagonally(
        plate_bottom_right, -plate_label_spacing)
    label = stz.latex(label_anchor, expr)
    return [plate_rect, label]
def rectangle_with_add_norm(color_name, s_lst):
    """Build a text rectangle with an "Add & Norm" rectangle stacked above it.

    The "Add & Norm" rectangle is moved so that its bottom center sits
    ``to_add_norm_spacing`` above the top center of the main rectangle. A
    straight vertical connection joins the two.

    Args:
        color_name: Color name for the main rectangle.
        s_lst: Text lines for the main rectangle.

    Returns:
        ``[add_norm_rect, main_rect, connection]``.
    """
    add_norm_rect = rectangle_with_text("add_norm_color", ["Add \\& Norm"])
    main_rect = rectangle_with_text(color_name, s_lst)
    main_top_center = stz.top_center_coords(*stz.bbox(main_rect))
    anchor_cs = stz.translate_coords_vertically(main_top_center,
                                                to_add_norm_spacing)
    # The return value is not used; add_norm_rect is relied on to be moved
    # in place by this call.
    stz.translate_bbox_bottom_center_to_coords(add_norm_rect, anchor_cs)
    link = connect_straight_vertical(main_rect, add_norm_rect)
    return [add_norm_rect, main_rect, link]
def trident_coords(e):
    """Compute the endpoints of a three-pronged connector below element ``e``.

    The prongs are the bottom-center point of ``e``'s bounding box and that
    point shifted horizontally by ``-trident_alpha * width / 2.0`` and
    ``+trident_alpha * width / 2.0``.

    Args:
        e: Element whose bounding box anchors the trident.

    Returns:
        A tuple ``(cs_lst, translated_cs_lst)``:
        - ``cs_lst`` holds the three points ordered left, center, right.
        - ``translated_cs_lst`` holds the same points translated vertically
          by ``-trident_spacing``.
    """
    bottom_center = stz.bottom_center_coords(*stz.bbox(e))
    # NOTE(review): `width` is not a local here; it resolves to a name
    # defined elsewhere in the file (presumably a shared module width).
    # Confirm.
    half_span = trident_alpha * width / 2.0
    left = stz.translate_coords_horizontally(bottom_center, -half_span)
    right = stz.translate_coords_horizontally(bottom_center, half_span)
    prongs = [left, bottom_center, right]
    shifted = [stz.translate_coords_vertically(p, -trident_spacing)
               for p in prongs]
    return prongs, shifted
def get_ticks(e):
    """Create tick marks along the top and bottom edges of ``e``'s bounding box.

    For each ``k`` in ``range(1, num_ticks)``:
    - a top tick starts at the top-left corner shifted right by
      ``k * tick_spacing`` and spans ``-tick_length`` vertically;
    - a bottom tick starts at the bottom-left corner shifted by the same
      amount and spans ``tick_length`` vertically.

    Both use ``t_fmt``.

    Returns:
        ``[bottom_ticks, top_ticks]``, each a list of line segments. Both are
        empty when ``num_ticks <= 1``.
    """
    top_left_cs, bottom_right_cs = stz.bbox(e)
    bottom_left_cs = stz.bottom_left_coords(top_left_cs, bottom_right_cs)

    def tick_pair(k):
        # Build the top tick before the bottom one for each index, matching
        # the order in which the segments were originally created.
        offset = k * tick_spacing
        top = stz.vertical_line_segment(
            stz.translate_coords_horizontally(top_left_cs, offset),
            -tick_length, t_fmt)
        bottom = stz.vertical_line_segment(
            stz.translate_coords_horizontally(bottom_left_cs, offset),
            tick_length, t_fmt)
        return top, bottom

    pairs = [tick_pair(k) for k in range(1, num_ticks)]
    top_ticks = [top for top, _ in pairs]
    bottom_ticks = [bottom for _, bottom in pairs]
    return [bottom_ticks, top_ticks]
connect_straight_with_arrow(cr, mmha_an), connect_straight_with_arrow(oer, cr), connect_straight_horizontal(sl, cl), connect_straight_horizontal(cr, sr) ] connections.extend([ connect_residual(ff1_an, 0.6, -residual_spacing), connect_residual(mha1_an, 0.6, -residual_spacing), connect_residual(mmha_an, 0.6, residual_spacing), connect_residual(mha2_an, 0.6, residual_spacing), connect_residual(ff2_an, 0.6, residual_spacing) ]) ### bounding boxes b = stz.bbox([ff1_an, mha1_an, connections[-4]]) bb1 = stz.rectangle_from_additive_resizing(b[0], b[1], bbox_spacing, bbox_spacing, s_bbox) b = stz.bbox([ff2_an, mmha_an, connections[-2], connections[-3]]) bb2 = stz.rectangle_from_additive_resizing(b[0], b[1], bbox_spacing, bbox_spacing, s_bbox) ### trident connections tri_cs, translated_tri_cs = trident_coords(mha1_an) e1 = stz.open_path([ translated_tri_cs[1], stz.bottom_left_coords(translated_tri_cs[1], tri_cs[0]), tri_cs[0] ], s_con) e2 = stz.open_path([ translated_tri_cs[1], stz.bottom_right_coords(translated_tri_cs[1], tri_cs[2]), tri_cs[2]
def frame(frame_idx):
    """Build all elements of one of four figure frames, labelled "a" to "d".

    Frame 0 shows modules c1, o, r1, r2, cc. Frame 1 shows c1, o, c2, c3, c4,
    cc. Frames 2 and 3 replace ``o`` with the dropout module ``d``. Each frame
    also lays out hyperparameter boxes and their connections to modules.
    Hyperparameters that no longer apply to the shown modules are stacked
    into the bottom-right corner as "unreachable" ones.

    Args:
        frame_idx: Frame index; must be 0, 1, 2 or 3 (checked by the assert).

    Returns:
        ``[modules, module_connections, hyperparameters, hyperp_connections,
        frame_rectangle, label]``, where ``label`` is a one-element list.
    """
    assert frame_idx >= 0 and frame_idx <= 3
    c1 = conv2d(1)
    o = optional(1)
    r1 = repeat(1)
    r2 = repeat(2)
    cc = concat(1)
    c2 = conv2d(2)
    c3 = conv2d(3)
    c4 = conv2d(4)
    d = dropout(1)
    stz.distribute_horizontally_with_spacing([r1, r2], horizontal_module_spacing)
    stz.distribute_horizontally_with_spacing([c2, [c3, c4]], horizontal_module_spacing)

    # Module layout. The stz calls reposition elements in place, so their
    # order matters.
    modules = []
    if frame_idx == 0:
        stz.distribute_vertically_with_spacing([cc, [r1, r2], o, c1], vertical_module_spacing)
        stz.align_centers_horizontally([cc, [r1, r2], o, c1], 0)
        modules.extend([c1, o, r1, r2, cc])
    else:
        stz.distribute_vertically_with_spacing([c4, c3], vertical_module_spacing)
        # Re-run the horizontal distribution after c3/c4 have been stacked.
        stz.distribute_horizontally_with_spacing([c2, [c3, c4]], horizontal_module_spacing)
        stz.align_centers_vertically([[c3, c4], c2], 0)
        if frame_idx == 1:
            stz.distribute_vertically_with_spacing([cc, [c2, c3, c4], o, c1], vertical_module_spacing)
            stz.align_centers_horizontally([cc, [c2, c3, c4], o, c1], 0)
            modules.extend([c1, o, c2, c3, c4, cc])
        else:
            # Frames 2 and 3: dropout `d` takes the place of `o`.
            stz.distribute_vertically_with_spacing([cc, [c2, c3, c4], d, c1], vertical_module_spacing)
            stz.align_centers_horizontally([cc, [c2, c3, c4], d, c1], 0)
            modules.extend([c1, d, c2, c3, c4, cc])

    module_connections = []
    if frame_idx == 0:
        module_connections.extend([
            connect_modules(c1, o, 0, 0),
            connect_modules(o, r1, 0, 0),
            connect_modules(o, r2, 0, 0),
            connect_modules(r1, cc, 0, 0),
            connect_modules(r2, cc, 0, 1),
        ])
    else:
        if frame_idx == 1:
            module_connections.extend([
                connect_modules(c1, o, 0, 0),
                connect_modules(o, c2, 0, 0),
                connect_modules(o, c3, 0, 0),
            ])
        else:
            module_connections.extend([
                connect_modules(c1, d, 0, 0),
                connect_modules(d, c2, 0, 0),
                connect_modules(d, c3, 0, 0),
            ])
        # Shared by frames 1-3.
        module_connections.extend([
            connect_modules(c3, c4, 0, 0),
            connect_modules(c2, cc, 0, 0),
            connect_modules(c4, cc, 0, 1),
        ])

    # Hyperparameters. In later frames an extra trailing argument is passed.
    # NOTE(review): presumably that argument is the value assigned to the
    # hyperparameter; confirm against independent_hyperparameter /
    # dependent_hyperparameter.
    if frame_idx <= 1:
        h_o = independent_hyperparameter("IH-2", "0, 1")
    else:
        h_o = independent_hyperparameter("IH-2", "0, 1", "1")
    if frame_idx <= 0:
        h_r1 = dependent_hyperparameter("DH-1", ["x"], "2*x")
        h_r2 = independent_hyperparameter("IH-3", "1, 2, 4")
    else:
        h_r1 = dependent_hyperparameter("DH-1", ["x"], "2*x", "2")
        h_r2 = independent_hyperparameter("IH-3", "1, 2, 4", "1")
    if frame_idx <= 2:
        h_c1 = independent_hyperparameter("IH-1", "64, 128")
        h_c2 = independent_hyperparameter("IH-4", "64, 128")
        h_c3 = independent_hyperparameter("IH-5", "64, 128")
        h_c4 = independent_hyperparameter("IH-6", "64, 128")
        h_d = independent_hyperparameter("IH-7", "0.25, 0.5")
    else:
        h_c1 = independent_hyperparameter("IH-1", "64, 128", "64")
        h_c2 = independent_hyperparameter("IH-4", "64, 128", "128")
        h_c3 = independent_hyperparameter("IH-5", "64, 128", "128")
        h_c4 = independent_hyperparameter("IH-6", "64, 128", "64")
        h_d = independent_hyperparameter("IH-7", "0.25, 0.5", "0.5")

    def place_hyperp_right_of(h, m):
        # Vertically align h with the center of m[3], then put h to the right
        # of m. NOTE(review): m[3] presumably picks a specific sub-element of
        # the module drawing; confirm.
        y_p = stz.center_coords(m[3])[1]
        stz.align_centers_vertically([h], y_p)
        stz.place_to_the_right(h, m, spacing_between_module_and_hyperp)

    hyperparameters = []
    place_hyperp_right_of(h_c1, c1)
    if frame_idx in [0, 1]:
        place_hyperp_right_of(h_o, o)
        hyperparameters.append(h_o)
    if frame_idx == 0:
        place_hyperp_right_of(h_r1, r2)
        stz.place_above_and_align_to_the_right(h_r2, h_r1, 0.8)
        hyperparameters.extend([h_r1, h_r2, h_c1])
    else:
        # NOTE(review): repeats the h_c1 placement done just above.
        place_hyperp_right_of(h_c1, c1)
        place_hyperp_right_of(h_c3, c3)
        place_hyperp_right_of(h_c4, c4)
        # NOTE(review): h_c2 is moved again by the
        # translate_bbox_left_center_to_coords call further below.
        stz.place_below(h_c2, h_c1, 3.0)
        hyperparameters.extend([h_c1, h_c2, h_c3, h_c4])
    if frame_idx in [2, 3]:
        place_hyperp_right_of(h_d, d)
        hyperparameters.extend([h_d])

    # Hyperparameters whose modules are not shown in this frame are stacked
    # together and later moved to the bottom-right corner.
    unreachable_hyperps = []
    if frame_idx == 1:
        stz.distribute_vertically_with_spacing([h_r1, h_r2], 0.2)
        unreachable_hyperps.extend([h_r1, h_r2])
    if frame_idx >= 2:
        stz.distribute_vertically_with_spacing([h_o, h_r1, h_r2], 0.2)
        unreachable_hyperps.extend([h_r1, h_r2, h_o])
    hyperparameters.extend(unreachable_hyperps)

    # Left-center point of the bounding box of the given elements.
    cs_fn = lambda e: stz.coords_from_bbox_with_fn(e, stz.left_center_coords)
    if frame_idx == 0:
        stz.translate_bbox_left_center_to_coords(h_r2, cs_fn([h_o, h_r1]))
    elif frame_idx == 1:
        stz.translate_bbox_left_center_to_coords(h_c2, cs_fn([h_o, h_c3]))
    else:
        stz.translate_bbox_left_center_to_coords(h_c2, cs_fn([h_d, h_c3]))

    hyperp_connections = [
        connect_hyperp_to_module(h_c1, c1, 0),
    ]
    if frame_idx in [0, 1]:
        hyperp_connections.extend([connect_hyperp_to_module(h_o, o, 0)])
    if frame_idx == 0:
        hyperp_connections.extend([
            connect_hyperp_to_module(h_r1, r2, 0),
            connect_hyperp_to_module(h_r2, r1, 0),
            connect_hyperp_to_hyperp(h_r2, h_r1)
        ])
    else:
        hyperp_connections.extend([
            connect_hyperp_to_module(h_c2, c2, 0),
            connect_hyperp_to_module(h_c3, c3, 0),
            connect_hyperp_to_module(h_c4, c4, 0),
        ])
    if frame_idx in [2, 3]:
        hyperp_connections.append(connect_hyperp_to_module(h_d, d, 0))

    # Frame rectangle, centered on the content shifted right by 0.8.
    # NOTE(review): height is passed before width here; confirm this matches
    # stz.rectangle_from_width_and_height's parameter order.
    f = stz.rectangle_from_width_and_height([0, 0], frame_height, frame_width, frame_s_fmt)
    e = [modules, module_connections, hyperparameters, hyperp_connections]
    stz.translate_bbox_center_to_coords(
        f, stz.translate_coords_horizontally(stz.center_coords(e), 0.8))
    if len(unreachable_hyperps) > 0:
        # Move the unreachable stack so its bottom-right corner matches the
        # bottom-right corner of all content.
        stz.translate_bbox_bottom_right_to_coords(unreachable_hyperps, stz.bbox(e)[1])

    # Frame id label ("a" to "d"), placed relative to the frame's top-left
    # corner.
    s = ["a", "b", "c", "d"][frame_idx]
    label = [stz.latex([0, 0], "\\Huge \\textbf %s" % s)]
    stz.translate_bbox_top_left_to_coords(
        label, stz.translate_coords_antidiagonally(
            stz.coords_from_bbox_with_fn(f, stz.top_left_coords), 0.6))
    return e + [f, label]
def dashed_bbox(e_lst):
    """Return a dashed rectangle drawn around the bounding box of ``e_lst``.

    The rectangle resizes the bounding box additively by
    ``2.0 * bbox_spacing`` both horizontally and vertically. Its style is the
    dashed line style combined with ``s_lw``.
    """
    dashed_style = fmt.combine_tikz_strs([fmt.line_style("dashed"), s_lw])
    corners = stz.bbox(e_lst)
    margin = 2.0 * bbox_spacing
    return stz.rectangle_from_additive_resizing(
        corners[0], corners[1], margin, margin, dashed_style)