def module(module_name, input_names, output_names, hyperp_names,
           p_width_scale=1.0):
    """Assemble the drawable pieces of a single module box.

    The box is a rectangle of ``module_width`` x ``module_height`` with the
    module name typeset in bold at its center.  One element per entry of
    ``input_names`` is laid out in a row anchored at the top-left inner
    corner, one element per entry of ``output_names`` in a row anchored at
    the bottom-left inner corner and, when ``hyperp_names`` is non-empty,
    one element per hyperparameter name in a column anchored near the
    top-right inner corner.

    Args:
        module_name: Text rendered inside ``\\textbf{...}`` at the box center.
        input_names: Names passed one by one to ``input``.
        output_names: Names passed one by one to ``output``.
        hyperp_names: Names passed one by one to ``property``.
        p_width_scale: Forwarded as the second argument of ``property``.

    Returns:
        ``[[box, title], inputs, outputs]`` when ``hyperp_names`` is empty,
        otherwise the same list with the hyperparameter elements appended as
        a fourth entry.
    """
    # NOTE(review): `input`, `output` and `property` are presumably
    # file-level helpers that shadow the builtins -- confirm.
    in_ports = [input(name) for name in input_names]
    out_ports = [output(name) for name in output_names]

    box = stz.rectangle([0, 0], [module_width, -module_height], module_s_fmt)
    title = stz.latex(stz.center_coords(box), "\\textbf{%s}" % module_name)

    # The same inner margin is used on every side of the box.
    inset = module_inner_vertical_spacing

    stz.distribute_horizontally_with_spacing(in_ports, io_spacing)
    stz.translate_bbox_top_left_to_coords(in_ports, [inset, -inset])

    stz.distribute_horizontally_with_spacing(out_ports, io_spacing)
    stz.translate_bbox_bottom_left_to_coords(
        out_ports, [inset, -module_height + inset])

    parts = [[box, title], in_ports, out_ports]
    if len(hyperp_names) == 0:
        return parts

    hyperp_slots = [property(name, p_width_scale) for name in hyperp_names]
    stz.distribute_vertically_with_spacing(hyperp_slots, p_spacing)
    stz.translate_bbox_top_right_to_coords(
        hyperp_slots, [module_width - inset, -inset - delta_increment])
    parts.append(hyperp_slots)
    return parts
def search_space_transition():
    """Draw the four search-space frames into ``deep_architect.tex``.

    Builds ``frame(0)`` .. ``frame(3)``, recolors selected elements (looked
    up by their positional path inside each frame), arranges the frames in a
    2x2 grid and writes the result with ``stz.draw_to_tikz_standalone``.
    """
    frames = [frame(i) for i in range(4)]
    first, second, third, fourth = frames

    def tint(root, path, color):
        # Walk the nested lists down to the element, then add a fill color
        # on top of whatever tikz format string it already carries.
        node = root
        for step in path:
            node = node[step]
        node["tikz_str"] = fmt.combine_tikz_strs(
            [node["tikz_str"], fmt.fill_color(color)])

    green = "light_green_2"
    red = "light_red_2"
    highlights = [
        # new modules
        (second, [0, 2, 0, 0], green),
        (second, [0, 3, 0, 0], green),
        (second, [0, 4, 0, 0], green),
        (third, [0, 1, 0, 0], green),
        # new hyperparameters
        (second, [2, 2, 0], green),
        (second, [2, 3, 0], green),
        (second, [2, 4, 0], green),
        (third, [2, 4, 0], green),
        # assigned hyperparameters
        (second, [2, 5, 0], red),
        (second, [2, 6, 0], red),
        (third, [2, 7, 0], red),
        (fourth, [2, 0, 0], red),
        (fourth, [2, 1, 0], red),
        (fourth, [2, 2, 0], red),
        (fourth, [2, 3, 0], red),
        (fourth, [2, 4, 0], red),
    ]
    for root, path, color in highlights:
        tint(root, path, color)

    # Arrange the four frames: each pair side by side, then the two pairs
    # spaced vertically with [third, fourth] listed before [first, second].
    stz.align_tops(frames, 0.0)
    top_row = [first, second]
    bottom_row = [third, fourth]
    stz.distribute_horizontally_with_spacing(top_row, frame_spacing)
    stz.distribute_horizontally_with_spacing(bottom_row, frame_spacing)
    stz.distribute_vertically_with_spacing([bottom_row, top_row],
                                           frame_spacing)

    stz.draw_to_tikz_standalone(frames, "deep_architect.tex", name2color)
def bar_plot(data, label_strs):
    """Build the elements of a horizontal bar plot.

    One rectangle of height ``rectangle_height`` is created per value in
    ``data`` (the value is passed as the rectangle's width), the bars are
    spaced vertically, an enclosing frame is sized to fit them plus a margin
    on both sides, and a right-anchored text label is placed to the left of
    every bar.

    Args:
        data: Sequence of bar lengths.
        label_strs: Label text for each bar, indexed in the same order as
            ``data``.

    Returns:
        ``[bars, labels, frame, ticks]`` where ``ticks`` comes from
        ``get_ticks(frame)``.
    """
    origin = [0, 0]
    bars = [
        stz.rectangle_from_width_and_height(origin, rectangle_height, value,
                                            r_fmt)
        for value in data
    ]
    stz.distribute_vertically_with_spacing(bars, rectangle_spacing)

    # Frame height: all bars, the gaps between them and a margin on each end.
    n_bars = len(data)
    frame_height = (n_bars * rectangle_height +
                    (n_bars - 1) * rectangle_spacing + 2.0 * margin)
    frame = stz.rectangle_from_width_and_height(origin, frame_height,
                                                frame_width)
    stz.align_centers_vertically([frame, bars], 0.0)

    def label_for(idx, bar):
        # Anchor point: the bar's left-center, shifted left by label_spacing.
        anchor_cs = stz.translate_coords_horizontally(
            stz.coords_from_bbox_with_fn(bar, stz.left_center_coords),
            -label_spacing)
        text = label_strs[idx]
        return stz.latex(anchor_cs, text, fmt.anchor('right_center'))

    labels = [label_for(idx, bar) for idx, bar in enumerate(bars)]

    return [bars, labels, frame, get_ticks(frame)]
# Layout constants for the segment figure.
tick_label_spacing = 0.25
tick_length = 0.25
segment_spacing = 0.6
length_multiplier = 0.8


def segment(length, label_str, left_tick_label_str, right_tick_label_str):
    """Build one labelled horizontal segment with a tick at each end.

    Args:
        length: Horizontal extent of the segment, starting at the origin.
        label_str: Text placed ``label_spacing`` to the left of the origin.
        left_tick_label_str: Text placed above the left tick.
        right_tick_label_str: Text placed above the right tick.

    Returns:
        ``[segment, left_tick, right_tick, left_tick_label,
        right_tick_label, segment_label]``.
    """

    def tick_with_caption(x, text):
        # A vertical tick centered at (x, 0) with its caption placed above.
        tick = stz.centered_vertical_line_segment([x, 0], tick_length)
        caption = stz.latex([0, 0], text)
        stz.place_above_and_align_to_the_center(caption, tick,
                                                tick_label_spacing)
        return tick, caption

    axis = stz.line_segment([0, 0], [length, 0])
    left_tick, left_caption = tick_with_caption(0, left_tick_label_str)
    right_tick, right_caption = tick_with_caption(length,
                                                  right_tick_label_str)
    name = stz.latex([-label_spacing, 0], label_str)
    return [axis, left_tick, right_tick, left_caption, right_caption, name]


# Three segments; each spec is (length in units, label, left tick, right tick).
segs = [
    segment(length_multiplier * units, label, lo, hi)
    for units, label, lo, hi in [
        (5, "A", "0", "5"),
        (1, "B", "2", "3"),
        (3, "C", "1", "4"),
    ]
]

# The list is reversed before being spaced vertically.
stz.distribute_vertically_with_spacing(list(reversed(segs)), segment_spacing)
stz.draw_to_tikz_standalone(segs, "segments.tex")
def frame(frame_idx):
    """Build one of the four frames (0..3) of the search-space figure.

    Every frame creates the same pool of modules and hyperparameters; which
    of them are laid out, connected and included in the result depends on
    ``frame_idx``:

    * frame 0: ``c1 -> o -> (r1, r2) -> cc``
    * frame 1: ``c1 -> o -> (c2, c3 -> c4) -> cc``
    * frames 2 and 3: as frame 1, with ``d`` taking the place of ``o``

    Args:
        frame_idx: Integer in ``[0, 3]`` selecting the frame.

    Returns:
        ``[modules, module_connections, hyperparameters, hyperp_connections,
        frame_rectangle, label]``.  The order of the entries inside
        ``modules`` and ``hyperparameters`` matters:
        ``search_space_transition`` addresses elements of the returned frame
        by positional index.
    """
    # NOTE: assert is stripped under `python -O`; this is a sanity check only.
    assert frame_idx >= 0 and frame_idx <= 3

    # All modules are created for every frame, even those a frame omits.
    c1 = conv2d(1)
    o = optional(1)
    r1 = repeat(1)
    r2 = repeat(2)
    cc = concat(1)
    c2 = conv2d(2)
    c3 = conv2d(3)
    c4 = conv2d(4)
    d = dropout(1)

    stz.distribute_horizontally_with_spacing([r1, r2],
                                             horizontal_module_spacing)
    stz.distribute_horizontally_with_spacing([c2, [c3, c4]],
                                             horizontal_module_spacing)

    # --- module layout ---
    modules = []
    if frame_idx == 0:
        stz.distribute_vertically_with_spacing([cc, [r1, r2], o, c1],
                                               vertical_module_spacing)
        stz.align_centers_horizontally([cc, [r1, r2], o, c1], 0)
        modules.extend([c1, o, r1, r2, cc])
    else:
        # c3 and c4 form a vertical chain that sits next to c2.
        stz.distribute_vertically_with_spacing([c4, c3],
                                               vertical_module_spacing)
        stz.distribute_horizontally_with_spacing([c2, [c3, c4]],
                                                 horizontal_module_spacing)
        stz.align_centers_vertically([[c3, c4], c2], 0)
        if frame_idx == 1:
            stz.distribute_vertically_with_spacing([cc, [c2, c3, c4], o, c1],
                                                   vertical_module_spacing)
            stz.align_centers_horizontally([cc, [c2, c3, c4], o, c1], 0)
            modules.extend([c1, o, c2, c3, c4, cc])
        else:
            # Frames 2 and 3 use the dropout module where frame 1 uses `o`.
            stz.distribute_vertically_with_spacing([cc, [c2, c3, c4], d, c1],
                                                   vertical_module_spacing)
            stz.align_centers_horizontally([cc, [c2, c3, c4], d, c1], 0)
            modules.extend([c1, d, c2, c3, c4, cc])

    # --- module connections ---
    # NOTE(review): the two trailing integers of connect_modules look like
    # output/input port indices -- confirm against its definition.
    module_connections = []
    if frame_idx == 0:
        module_connections.extend([
            connect_modules(c1, o, 0, 0),
            connect_modules(o, r1, 0, 0),
            connect_modules(o, r2, 0, 0),
            connect_modules(r1, cc, 0, 0),
            connect_modules(r2, cc, 0, 1),
        ])
    else:
        if frame_idx == 1:
            module_connections.extend([
                connect_modules(c1, o, 0, 0),
                connect_modules(o, c2, 0, 0),
                connect_modules(o, c3, 0, 0),
            ])
        else:
            module_connections.extend([
                connect_modules(c1, d, 0, 0),
                connect_modules(d, c2, 0, 0),
                connect_modules(d, c3, 0, 0),
            ])
        # Shared by frames 1-3: the c3 -> c4 chain and both branches into cc.
        module_connections.extend([
            connect_modules(c3, c4, 0, 0),
            connect_modules(c2, cc, 0, 0),
            connect_modules(c4, cc, 0, 1),
        ])

    # --- hyperparameters ---
    # Later frames pass one extra trailing argument to the constructors.
    # NOTE(review): presumably the value assigned to the hyperparameter --
    # confirm against independent_hyperparameter / dependent_hyperparameter.
    if frame_idx <= 1:
        h_o = independent_hyperparameter("IH-2", "0, 1")
    else:
        h_o = independent_hyperparameter("IH-2", "0, 1", "1")

    if frame_idx <= 0:
        h_r1 = dependent_hyperparameter("DH-1", ["x"], "2*x")
        h_r2 = independent_hyperparameter("IH-3", "1, 2, 4")
    else:
        h_r1 = dependent_hyperparameter("DH-1", ["x"], "2*x", "2")
        h_r2 = independent_hyperparameter("IH-3", "1, 2, 4", "1")

    if frame_idx <= 2:
        h_c1 = independent_hyperparameter("IH-1", "64, 128")
        h_c2 = independent_hyperparameter("IH-4", "64, 128")
        h_c3 = independent_hyperparameter("IH-5", "64, 128")
        h_c4 = independent_hyperparameter("IH-6", "64, 128")
        h_d = independent_hyperparameter("IH-7", "0.25, 0.5")
    else:
        h_c1 = independent_hyperparameter("IH-1", "64, 128", "64")
        h_c2 = independent_hyperparameter("IH-4", "64, 128", "128")
        h_c3 = independent_hyperparameter("IH-5", "64, 128", "128")
        h_c4 = independent_hyperparameter("IH-6", "64, 128", "64")
        h_d = independent_hyperparameter("IH-7", "0.25, 0.5", "0.5")

    def place_hyperp_right_of(h, m):
        # Vertically center `h` on the y coordinate of m[3], then put it to
        # the right of the module.
        # NOTE(review): m[3] is presumably the hyperparameter-slot list that
        # module() returns as its fourth entry -- confirm.
        y_p = stz.center_coords(m[3])[1]
        stz.align_centers_vertically([h], y_p)
        stz.place_to_the_right(h, m, spacing_between_module_and_hyperp)

    # --- hyperparameter layout ---
    hyperparameters = []
    place_hyperp_right_of(h_c1, c1)
    if frame_idx in [0, 1]:
        place_hyperp_right_of(h_o, o)
        hyperparameters.append(h_o)

    if frame_idx == 0:
        place_hyperp_right_of(h_r1, r2)
        stz.place_above_and_align_to_the_right(h_r2, h_r1, 0.8)
        hyperparameters.extend([h_r1, h_r2, h_c1])
    else:
        # h_c1 was already placed above; this call repeats that placement.
        place_hyperp_right_of(h_c1, c1)
        place_hyperp_right_of(h_c3, c3)
        place_hyperp_right_of(h_c4, c4)
        stz.place_below(h_c2, h_c1, 3.0)
        hyperparameters.extend([h_c1, h_c2, h_c3, h_c4])

    if frame_idx in [2, 3]:
        place_hyperp_right_of(h_d, d)
        hyperparameters.extend([h_d])

    # Hyperparameters whose modules are not part of this frame are stacked
    # together; the stack is moved to a corner of the content further below.
    unreachable_hyperps = []
    if frame_idx == 1:
        stz.distribute_vertically_with_spacing([h_r1, h_r2], 0.2)
        unreachable_hyperps.extend([h_r1, h_r2])
    if frame_idx >= 2:
        stz.distribute_vertically_with_spacing([h_o, h_r1, h_r2], 0.2)
        unreachable_hyperps.extend([h_r1, h_r2, h_o])
    hyperparameters.extend(unreachable_hyperps)

    # Move one hyperparameter so its left-center sits at the left-center of
    # the bounding box of the two neighbouring hyperparameters.
    cs_fn = lambda e: stz.coords_from_bbox_with_fn(e, stz.left_center_coords)
    if frame_idx == 0:
        stz.translate_bbox_left_center_to_coords(h_r2, cs_fn([h_o, h_r1]))
    elif frame_idx == 1:
        stz.translate_bbox_left_center_to_coords(h_c2, cs_fn([h_o, h_c3]))
    else:
        stz.translate_bbox_left_center_to_coords(h_c2, cs_fn([h_d, h_c3]))

    # --- hyperparameter connections ---
    hyperp_connections = [
        connect_hyperp_to_module(h_c1, c1, 0),
    ]
    if frame_idx in [0, 1]:
        hyperp_connections.extend([connect_hyperp_to_module(h_o, o, 0)])
    if frame_idx == 0:
        hyperp_connections.extend([
            connect_hyperp_to_module(h_r1, r2, 0),
            connect_hyperp_to_module(h_r2, r1, 0),
            connect_hyperp_to_hyperp(h_r2, h_r1)
        ])
    else:
        hyperp_connections.extend([
            connect_hyperp_to_module(h_c2, c2, 0),
            connect_hyperp_to_module(h_c3, c3, 0),
            connect_hyperp_to_module(h_c4, c4, 0),
        ])
    if frame_idx in [2, 3]:
        hyperp_connections.append(connect_hyperp_to_module(h_d, d, 0))

    # --- enclosing frame ---
    f = stz.rectangle_from_width_and_height([0, 0], frame_height, frame_width,
                                            frame_s_fmt)
    e = [modules, module_connections, hyperparameters, hyperp_connections]
    # Center the rectangle on the content center shifted horizontally by 0.8.
    stz.translate_bbox_center_to_coords(
        f, stz.translate_coords_horizontally(stz.center_coords(e), 0.8))
    if len(unreachable_hyperps) > 0:
        # stz.bbox(e)[1] is read before the move, so it is the content's
        # bounding box prior to relocating the unreachable stack.
        stz.translate_bbox_bottom_right_to_coords(unreachable_hyperps,
                                                  stz.bbox(e)[1])

    # --- frame label: a letter "a".."d" near the frame's top-left corner ---
    s = ["a", "b", "c", "d"][frame_idx]
    label = [stz.latex([0, 0], "\\Huge \\textbf %s" % s)]
    stz.translate_bbox_top_left_to_coords(
        label,
        stz.translate_coords_antidiagonally(
            stz.coords_from_bbox_with_fn(f, stz.top_left_coords), 0.6))
    return e + [f, label]
def label_left(e, expr): cs = stz.coords_from_bbox_with_fn(e, stz.left_center_coords) cs = stz.translate_coords(cs, -label_spacing, 0.1) return stz.latex(cs, "\\scriptsize{%s}" % expr) nodes = [] for i in [8, 3, 10, 1, 6, 14, 4, 7, 13]: if i == 4: s = "\\textbf{%s}" % str(i) else: s = str(i) nodes.append(fn(s)) stz.distribute_vertically_with_spacing( [nodes[0:1], nodes[1:3], nodes[3:6], nodes[6:9]][::-1], vertical_node_spacing) place(nodes[1], [-1]) place(nodes[2], [1]) place(nodes[3], [-1, -1]) place(nodes[4], [-1, 1]) place(nodes[5], [1, 1]) place(nodes[6], [-1, 1, -1]) place(nodes[7], [-1, 1, 1]) place(nodes[8], [1, 1, -1]) connections = [ connect(nodes[0], nodes[1], "blue"), connect(nodes[0], nodes[2]), connect(nodes[1], nodes[3]),
"E", "F", ] def box_with_text(s): s_fmt = fmt.rounded_corners(box_roundness) return [ stz.rectangle([0, 0], [box_width, -box_height], s_fmt), stz.latex([box_width / 2.0, -box_height / 2.0], s) ] def arrow(): s_fmt = fmt.fill_color_with_no_line("lightgray") a = stz.arrow(shaft_width, shaft_height, head_width, head_height, -90.0, s_fmt) return a boxes = [box_with_text(s) for s in lst] stz.distribute_vertically_with_spacing(boxes[::-1], box_spacing) arrows = [arrow() for _ in range(len(lst) - 1)] for i in range(len(boxes) - 1): to_cs = stz.center_coords(boxes[i:i + 2]) from_cs = stz.center_coords(arrows[i]) stz.translate_to_coords(arrows[i], from_cs, to_cs) stz.draw_to_tikz_standalone([boxes, arrows], "flowchart.tex")
h_grey_lfmt if i != 2 else "") for i in range(4) ] hs2 = [ sq_fn("$h_%d^{(2)}$" % (i + 1, ), h_grey_rfmt if i != 2 else h_green_rfmt, h_grey_lfmt if i != 2 else "") for i in range(4) ] m1, m2 = [rct_fn("$\\text{mem}^{(%d)}$" % (i + 1, ), m_fmt) for i in range(2)] x_out = [sq_fn("$x_3$", x3_fmt)] row1 = [m1] + xs row2 = [m2] + hs1 stz.distribute_horizontally_with_spacing([m1] + xs, horizontal_spacing) stz.distribute_horizontally_with_spacing([m2] + hs1, horizontal_spacing) stz.distribute_horizontally_with_spacing(hs2, horizontal_spacing) stz.distribute_vertically_with_spacing([row1, row2, hs2], vertical_spacing) stz.align_rights([row1, row2, hs2], 0) stz.place_above_and_align_to_the_center(x_out, hs2[2], vertical_spacing) legend = stz.latex( [0, 0], "Factorization order: $3 \\rightarrow 2 \\rightarrow 4 \\rightarrow 1$") stz.place_below_and_align_to_the_center(legend, xs, legend_spacing) connections = [ connect(m1, hs1[2]), connect(m2, hs2[2]), connect(hs1[2], hs2[2]), connect(hs2[2], x_out) ]
small_nodes.extend([ small_node_relative(x1, x5, 0.6), small_node_relative(x2, x5, 0.6), small_node_relative(x2, x4, 0.6), small_node_relative(x3, x4, 0.6), small_node_relative(x1, x4, 0.6), ]) labels = [ label_small_node("$f_{rp}(C_1, p_5)$", small_nodes[-5], 135.0, hd, 1.4 * rectangle_width) ] return [nodes, small_nodes, connections, labels] e = [ figure_a(), figure_b(), figure_c(), figure_d(), figure_e(), figure_f(), ] stz.distribute_vertically_with_spacing(e[:3][::-1], 2.0) stz.distribute_vertically_with_spacing(e[3:][::-1], 2.0) stz.distribute_horizontally_with_spacing([e[:3], e[3:]], 2.0) stz.draw_to_tikz_standalone(e, "factor_graphs.tex")