@dataclass
class Nodes:
    """Labels drawn inside the nodes of the encoder/decoder overview sketch."""

    input_m0: str = '$M_0$'
    input_mK: str = '$M_K$'
    encoder: str = 'Encoder'
    lr: str = 'Joint Distribution'
    decoder: str = 'Decoder'
    out_m0: str = '$M_0$'
    out_mK: str = '$M_K$'
    points: str = r'\ldots'


nodes = Nodes()

pic = tikz.Picture(
    # Node styles: green modality boxes, red model boxes, blue latent ellipse.
    'modalities/.style={rectangle, draw=green!60, fill=green!5, very thick, minimum size=10mm},'
    'model/.style={rectangle, draw=red!60, fill=red!5, very thick, minimum size=20mm},'
    'lr/.style={ellipse, draw=blue!60, fill=blue!5, very thick, minimum size=15mm}'
)

# (node name, label, TikZ options). Order matters: every node after the first
# is positioned relative to a node created earlier in this sequence.
for _name, _label, _options in (
    ('input_m0', nodes.input_m0, 'modalities'),
    ('input_points', nodes.points, ' below of=input_m0'),
    ('input_mK', nodes.input_mK, 'modalities, below of=input_points'),
    ('encoder', nodes.encoder, 'model, right of=input_points, xshift=2cm'),
    ('lr', nodes.lr, 'lr, right of=encoder, xshift=3cm'),
    ('decoder', nodes.decoder, 'model, right of=lr, xshift=3cm'),
    ('out_points', nodes.points, ' right of=decoder, xshift=2cm'),
    ('out_m0', nodes.out_m0, 'modalities, above of=out_points'),
    ('out_mK', nodes.out_mK, 'modalities, below of=out_points'),
):
    pic.set_node(text=_label, options=_options, name=_name)

# Every input node feeds into the encoder.
for _source in ('input_m0', 'input_points', 'input_mK'):
    pic.set_line(_source, 'encoder')

# Keep the module namespace identical to the unrolled version.
del _name, _label, _options, _source
q1_tilde: str = r'$\tilde{q}_{\phi_1}$' q2_tilde: str = r'$\tilde{q}_{\phi_2}$' q12_tilde: str = r'$\tilde{q}_{\phi_{12}}$' gfm: str = r'$f$-Mean' z: str = r'joint\\ posterior' points: str = r'\ldots' moe: str = r'\textbf{MoE}' nodes = Nodes() pic = tikz.Picture( 'gfm/.style={rectangle, draw=red!60, fill=red!5, very thick, minimum size=10mm},' 'MoE/.style={rectangle, draw=red!60, fill=red!5, very thick, , minimum height=20mm, minimum width=10mm},' # 'lr/.style={ellipse, draw=blue!60, fill=blue!5, very thick, minimum size=15mm},' 'm0/.style={regular polygon,regular polygon sides=4, draw=green!60, fill=green!5, very thick, minimum size=28mm},' 'm0_dis/.style={circle, draw=green!60, fill=green!5, very thick, minimum size=10mm},' 'm1/.style={regular polygon,regular polygon sides=4, draw=orange!60, fill=orange!5, very thick, minimum size=28mm},' 'm1_distr/.style={circle, draw=orange!60, fill=orange!5, very thick, minimum size=10mm},' 'lr/.style={circle, draw=gray!60, fill=gray!5, very thick, minimum size=5mm},' 'subset/.style={circle, draw=gray!60, fill=gray!5, very thick, minimum size=5mm},' ) pic.set_node(text=nodes.input_m0, name='input_m0') pic.set_node(text=nodes.q1, options='m0_dis, right of=input_m0, xshift=1.5cm', name='q1') pic.set_node(text=nodes.input_m1, options='below of=input_m0, yshift=-2cm', name='input_m1') pic.set_node(text=nodes.q2, options='m1_distr, right of=input_m1, xshift=1.5cm', name='q2') pic.set_node(text=nodes.gfm, options='gfm, right of=q1, xshift=1cm,yshift=-1.5cm, align=center', name='gfm12') pic.set_node(text=nodes.gfm, options='gfm, above of=gfm12,yshift=1.5cm, align=center', name='gfm1') pic.set_node(text=nodes.gfm, options='gfm, below of=gfm12,yshift=-1.5cm, align=center', name='gfm2')
def make_cond_gen_fig_polymnist(which: str, methods: List[str], dataset: str):
    """Build a TikZ figure comparing conditional generations across methods.

    The figure has one row of input samples per input modality, followed by
    one labelled row per method showing that method's conditionally generated
    samples of the output modality.

    Args:
        which: Condition spec ``"<in>_<in>...__<out>"``. The part before the
            first ``"__"`` lists the input modalities (joined by ``"_"``); the
            part after the last ``"__"`` is the output modality.
        methods: Method names. ``methods[0]`` also supplies the input samples.
        dataset: Dataset key, ``"polymnist"`` or ``"mimic"``.

    Returns:
        The TikZ source produced by ``tikz.Picture.make()``.

    Raises:
        KeyError: If ``dataset`` is not one of the known datasets.
    """
    nbr_examples = {"polymnist": 10, "mimic": 5}
    n_examples = nbr_examples[dataset]
    split = which.split('__')
    in_mods = split[0]
    out_mod = split[-1]
    input_samples_dir = Path(
        f'data/thesis/{dataset}/{methods[0]}/cond_gen_plots/input_samples')
    cond_samples_path = {
        method: Path(f'data/thesis/{dataset}/{method}/cond_gen_plots/{in_mods}')
        for method in methods
    }

    # Layout in cm: 2cm-wide images start at x=3 and are 2.1 apart; input
    # rows are 2 apart, method rows 2.1 apart, with 1.5 extra below the inputs.
    x_base, x_step = 3, 2.1
    input_row_step, method_row_step, inputs_gap = 2, 2.1, 1.5

    def col_x(col):
        # Computed from the index and rounded so the TikZ options do not carry
        # float-accumulation noise such as ``7.199999999999999cm``.
        return round(x_base + col * x_step, 2)

    pic = tikz.Picture()
    input_mods = in_mods.split('_')
    for row, in_mod in enumerate(input_mods):
        y = row * input_row_step
        for class_idx in range(n_examples):
            img_name = f"{in_mod}_{class_idx}"
            pic.set_node(
                text=
                f'\\includegraphics[width=2cm]{{{input_samples_dir / img_name}}}',
                options=f'xshift={col_x(class_idx)}cm, yshift=-{y}cm',
                name=f'inmod_{img_name}',
            )

    # Centre the label on the input rows (their centres are 0, 2, ..., 2(k-1)).
    # The former ``floor(yshift / 4)`` was only centred for k <= 2.
    input_label_y = (len(input_mods) - 1) * input_row_step / 2
    pic.set_node(text=r'\Large{\textbf{Input}}',
                 options=f'yshift=-{input_label_y}cm')

    first_method_y = len(input_mods) * input_row_step + inputs_gap
    for method_idx, method in enumerate(methods):
        y = round(first_method_y + method_idx * method_row_step, 2)
        if method == 'mogfm_amortized':
            # Long name: split over two centred lines.
            pic.set_node(
                text=
                fr'\Large{{\textbf{{mogfm}}}}\\\Large{{\textbf{{amortized}}}}',
                options=f'yshift=-{y}cm, align=center')
        else:
            pic.set_node(text=fr'\Large{{\textbf{{{tex_escape(method)}}}}}',
                         options=f'yshift=-{y}cm')
        for class_idx in range(n_examples):
            img_name = f"{out_mod}_{class_idx}"
            pic.set_node(
                text=
                f'\\includegraphics[width=2cm]{{{cond_samples_path[method] / img_name}}}',
                options=f'xshift={col_x(class_idx)}cm, yshift=-{y}cm',
                # Every method row shows the same image names, so prefix the
                # method index to keep TikZ node names unique.
                name=f'outmod_{method_idx}_{img_name}',
            )
    return pic.make()
def make_cond_gen_fig_iw_comp(which: str,
                              methods: List[str],
                              dataset: str = 'polymnist'):
    """Build a TikZ figure comparing conditional generations for K = 1, 3, 5.

    The figure has one row of input samples per input modality, followed by
    one labelled row per (K, method) pair showing the conditionally generated
    samples of the output modality. A small gap separates the K groups.

    Args:
        which: Condition spec ``"<in>_<in>...__<out>"``. The part before the
            first ``"__"`` lists the input modalities (joined by ``"_"``); the
            part after the last ``"__"`` is the output modality.
        methods: Method names. ``methods[0]`` also supplies the input samples.
        dataset: Dataset key; only ``"polymnist"`` is configured.

    Returns:
        The TikZ source produced by ``tikz.Picture.make()``.

    Raises:
        KeyError: If ``dataset`` is not a configured dataset.
    """
    nbr_examples = {
        "polymnist": 10,
    }
    n_examples = nbr_examples[dataset]
    split = which.split('__')
    in_mods = split[0]
    out_mod = split[-1]
    # Input samples are taken from the K=3 run of the first method.
    input_samples_dir = Path(
        f'data/thesis/iw_comp/{methods[0]}/cond_gen_plots/K3/input_samples')

    def get_cond_gen_samples_path(K: int, method: str):
        # K=1 results live in the regular per-dataset run of the method with
        # "iw" stripped from its name; K>1 results live under iw_comp.
        # NOTE(review): str.replace drops *every* "iw" substring, which assumes
        # "iw" appears in method names only as the importance-weighting marker.
        if K == 1:
            return Path(
                f'data/thesis/{dataset}/{method.replace("iw", "")}/cond_gen_plots/{in_mods}'
            )
        return Path(
            f'data/thesis/iw_comp/{method}/cond_gen_plots/K{K}/{in_mods}')

    k_values = (1, 3, 5)
    # Layout in cm: 2cm-wide images start at x=3.5 and are 2.1 apart; input
    # rows are 2 apart, method rows 2.1 apart, K groups an extra 0.5 apart.
    x_base, x_step = 3.5, 2.1
    input_row_step, method_row_step = 2, 2.1
    inputs_gap, group_gap = 1.5, 0.5

    def col_x(col):
        # Computed from the index and rounded to avoid float-accumulation
        # noise such as ``7.699999999999999cm`` in the TikZ options.
        return round(x_base + col * x_step, 2)

    pic = tikz.Picture()
    input_mods = in_mods.split('_')
    for row, in_mod in enumerate(input_mods):
        y = row * input_row_step
        for class_idx in range(n_examples):
            img_name = f"{in_mod}_{class_idx}"
            pic.set_node(
                text=
                f'\\includegraphics[width=2cm]{{{input_samples_dir / img_name}}}',
                options=f'xshift={col_x(class_idx)}cm, yshift=-{y}cm',
                name=f'inmod_{img_name}',
            )

    # Centre the label on the input rows (their centres are 0, 2, ..., 2(k-1)).
    # The former ``floor(yshift / 4)`` was only centred for k <= 2.
    input_label_y = (len(input_mods) - 1) * input_row_step / 2
    pic.set_node(text=r'\Large{\textbf{Input}}',
                 options=f'yshift=-{input_label_y}cm')

    y = len(input_mods) * input_row_step + inputs_gap
    for K in k_values:
        for method_idx, method in enumerate(methods):
            samples_dir = get_cond_gen_samples_path(K, method)
            row_y = round(y, 2)
            pic.set_node(
                text=fr'\Large{{\textbf{{{tex_escape(method)} (K={K})}}}}',
                options=f'yshift=-{row_y}cm')
            for class_idx in range(n_examples):
                img_name = f"{out_mod}_{class_idx}"
                pic.set_node(
                    text=
                    f'\\includegraphics[width=2cm]{{{samples_dir / img_name}}}',
                    options=f'xshift={col_x(class_idx)}cm, yshift=-{row_y}cm',
                    # Same image names repeat for every (K, method) row; keep
                    # TikZ node names unique.
                    name=f'outmod_K{K}_{method_idx}_{img_name}',
                )
            y += method_row_step
        y += group_gap
    return pic.make()
from pathlib import Path from scripts_ import tikz data_path = Path('data/thesis/mimic') mofop_dir = data_path / 'mofop/cond_gen_plots/Lateral_PA' mopoe_dir = data_path / 'mopoe/cond_gen_plots/Lateral_PA' mopgfm_dir = data_path / 'mopgfm/cond_gen_plots/Lateral_PA' input_samples_dir = data_path / 'mopoe/cond_gen_plots/input_samples' pic = tikz.Picture() imgsize = 3 pic.set_node( text= f'\\includegraphics[width={imgsize}cm]{{{input_samples_dir / "PA_1.png"}}}', name='inmod_PA', ) pic.set_node(text=r"Input PA", options="above of=inmod_PA, yshift=0.8cm") pic.set_node( text= f'\\includegraphics[width={imgsize}cm]{{{input_samples_dir / "Lateral_1.png"}}}', options='right of=inmod_PA, xshift=2.5cm', name='inmod_lat', ) pic.set_node(text=r"Input Lateral", options="above of=inmod_lat, yshift=0.8cm") pic.set_node( text=f'\\includegraphics[width={imgsize}cm]{{{mopoe_dir / "PA_1.png"}}}',
def make_graph(with_red_circle: bool = False):
    """Build the f-Mean fusion diagram and print its TikZ source.

    Two input images are encoded (``enc_1``/``enc_2``) into unimodal
    posteriors ``q1``/``q2``. Their parameters (edges labelled mu/sigma) are
    combined by an ``f``-Mean node into ``q12_tilde``, which is decoded
    (``dec_1``/``dec_2``) into the example conditional generations on the
    right-hand side.

    Args:
        with_red_circle: If True, additionally draw a red ellipse placed over
            the second output column (``output__m1_m2`` / ``output__m2_m1``),
            presumably to highlight those samples.
    """
    cond_samples_path = Path('data/pgfm/cond_gen_examples')
    input_samples_dir = cond_samples_path / 'input_samples'

    def _image(path: Path) -> str:
        # LaTeX snippet embedding the image at ``path`` with a 2cm width.
        return f'\\includegraphics[width=2cm]{{{path}}}'

    @dataclass
    class Nodes:
        # Input images: node input_m0 shows m1.png, input_m1 shows m2.png.
        input_m1: str = _image(input_samples_dir / "m2.png")
        input_m0: str = _image(input_samples_dir / "m1.png")
        # Generated images, apparently named
        # output__<conditioning subset>_<generated modality>
        # after their <subset>/<modality>.png location.
        output__m1m2_m2: str = _image(cond_samples_path / "m1_m2" / "m2.png")
        output__m1m2_m1: str = _image(cond_samples_path / "m1_m2" / "m1.png")
        output__m2_m1: str = _image(cond_samples_path / "m2" / "m1.png")
        output__m1_m2: str = _image(cond_samples_path / "m1" / "m2.png")
        q1: str = r'$q_{\phi_1}$'
        q2: str = r'$q_{\phi_2}$'
        q12_tilde: str = r'$\tilde{q}_{\phi_{12}}$'
        gfm: str = r'$f$-Mean'

    nodes = Nodes()
    pic = tikz.Picture(
        'gfm/.style={rectangle, draw=red!60, fill=red!5, very thick, minimum size=10mm},'
        'm0/.style={regular polygon,regular polygon sides=4, draw=green!60, fill=green!5, very thick, minimum size=28mm},'
        'm0_dis/.style={circle, draw=green!60, fill=green!5, very thick, minimum size=10mm},'
        'm1/.style={regular polygon,regular polygon sides=4, draw=orange!60, fill=orange!5, very thick, minimum size=28mm},'
        'm1_distr/.style={circle, draw=orange!60, fill=orange!5, very thick, minimum size=10mm},'
        'lr/.style={circle, draw=gray!60, fill=gray!5, very thick, minimum size=15mm},'
        'subset/.style={circle, draw=gray!60, fill=gray!5, very thick, minimum size=5mm},'
    )

    # Inputs and their unimodal posteriors.
    pic.set_node(text=nodes.input_m0, name='input_m0')
    pic.set_node(text=nodes.q1,
                 options='m0_dis, right of=input_m0, xshift=1.5cm',
                 name='q1')
    pic.set_node(text=nodes.input_m1,
                 options='below of=input_m0, yshift=-2cm',
                 name='input_m1')
    pic.set_node(text=nodes.q2,
                 options='m1_distr, right of=input_m1, xshift=1.5cm',
                 name='q2')

    # Fusion and joint posterior.
    pic.set_node(
        text=nodes.gfm,
        options='gfm, right of=q1, xshift=1cm,yshift=-1.5cm, align=center',
        name='gfm')
    pic.set_node(text=nodes.q12_tilde,
                 options='lr, right of=gfm, xshift=1.5cm',
                 name='q12_tilde')

    # Generated samples, arranged in a 2x2 grid to the right.
    pic.set_node(text=nodes.output__m1m2_m2,
                 options='right of=q12_tilde, xshift=2cm,yshift=-1.5cm',
                 name='output__m1m2_m2')
    pic.set_node(text=nodes.output__m1_m2,
                 options='right of=output__m1m2_m2, xshift=1.5cm',
                 name='output__m1_m2')
    pic.set_node(text=nodes.output__m1m2_m1,
                 options='above of=output__m1m2_m2, yshift=2cm',
                 name='output__m1m2_m1')
    pic.set_node(text=nodes.output__m2_m1,
                 options='right of=output__m1m2_m1, xshift=1.5cm',
                 name='output__m2_m1')

    if with_red_circle:
        pic.set_node(
            options=
            'right of=output__m1m2_m2, xshift=1.5cm, yshift=1.5cm, ellipse, draw=red!100, line width=2pt, minimum height=70mm, minimum width=30mm'
        )

    pic.set_line('input_m0', 'q1', label=r'$enc_1$', label_pos='south')
    pic.set_line('input_m1', 'q2', label=r'$enc_2$', label_pos='south')
    # Two bent edges per posterior, one for each distribution parameter.
    pic.set_line('q1',
                 'gfm',
                 label=r'\textcolor{green}{$\mu_1$}',
                 edge_options='bend right=-10',
                 label_pos='south')
    pic.set_line('q1',
                 'gfm',
                 label=r'\textcolor{green}{$\sigma_1$}',
                 edge_options='bend right=10',
                 label_pos='north')
    pic.set_line('q2',
                 'gfm',
                 label=r'\textcolor{orange}{$\mu_2$}',
                 edge_options='bend right=10',
                 label_pos='north')
    pic.set_line('q2',
                 'gfm',
                 label=r'\textcolor{orange}{$\sigma_2$}',
                 edge_options='bend right=-10',
                 label_pos='south')
    pic.set_line('gfm', 'q12_tilde')
    pic.set_line('q12_tilde',
                 'output__m1m2_m2',
                 label=r'$dec_2$',
                 label_pos='north',
                 edge_options='bend right=30')
    # The trailing ``\ `` is a LaTeX control space that pads the label.
    pic.set_line('q12_tilde',
                 'output__m1m2_m1',
                 label=r'$dec_1$\ ',
                 label_pos='south, rotate=10',
                 edge_options='bend left=30')

    output = pic.make()
    print(output)