def save_plots(target_state, best_state, cost_progress, *, modes, offset=-0.11,
               l=5, out_dir='sim_results', ID='state_learner', **kwargs):
    """Generate and save plots of the target state, the learnt state and the cost.

    Depending on ``modes``:

    * ``modes == 1``: a Wigner-function plot and a wavefunction plot are saved
      for both the target and the learnt state.
    * ``modes == 2``: a two-mode wavefunction plot is saved for both states.
    * any other value: no state plots are produced.

    A cost-function plot is saved in every case.

    Args:
        target_state: target state, forwarded to the plotting helpers.
        best_state: learnt state, forwarded to the plotting helpers.
        cost_progress: forwarded to ``plot_cost``.
        modes: number of modes; selects which state plots are produced.
        offset: forwarded to ``wigner_3D_plot``.
        l: forwarded to the Wigner/wavefunction plotting helpers.
        out_dir: directory the PNG files are written to; created if missing.
        ID: filename prefix for every saved image.
        **kwargs: accepted for call-site compatibility and ignored.

    Every file is written to ``os.path.join(out_dir, ID + suffix)``, e.g.
    ``<out_dir>/<ID>_cost.png``.
    """
    # Create the output directory up front; otherwise every savefig below
    # fails with FileNotFoundError after the plots have already been built.
    os.makedirs(out_dir, exist_ok=True)

    def _save(fig, suffix):
        # Write ``fig`` to <out_dir>/<ID><suffix>.
        fig.savefig(os.path.join(out_dir, ID + suffix))

    if modes == 1:
        # Wigner function plots of the target and learnt states
        fig, _ = wigner_3D_plot(target_state, offset=offset, l=l)
        _save(fig, '_targetWigner.png')
        fig, _ = wigner_3D_plot(best_state, offset=offset, l=l)
        _save(fig, '_learntWigner.png')

        # wavefunction plots of the target and learnt states
        fig, _ = wavefunction_plot(target_state, l=l)
        _save(fig, '_targetWavefunction.png')
        fig, _ = wavefunction_plot(best_state, l=l)
        _save(fig, '_learntWavefunction.png')

    elif modes == 2:
        # 3D (two-mode) wavefunction plots of the target and learnt states
        fig, _ = two_mode_wavefunction_plot(target_state, l=l)
        _save(fig, '_targetWavefunction.png')
        fig, _ = two_mode_wavefunction_plot(best_state, l=l)
        _save(fig, '_learntWavefunction.png')

    # cost function plot, produced for every value of ``modes``
    fig, _ = plot_cost(cost_progress)
    _save(fig, '_cost.png')
def save_plots(target_unitary, learnt_unitary, eq_state_learnt, eq_state_target,
               cost_progress, *, modes, offset=-0.11, l=5, out_dir='sim_results',
               ID='gate_synthesis', **kwargs):
    """Generate and save plots comparing the target and learnt gates, plus the cost.

    Depending on ``modes``:

    * ``modes == 1``: Wigner-function plots of ``eq_state_target`` and
      ``eq_state_learnt``, and a one-mode matrix plot of the two unitaries.
    * ``modes == 2``: two-mode wavefunction plots of both states, and a
      two-mode matrix plot of the two unitaries.
    * any other value: no state or unitary plots are produced.

    A cost-function plot is saved in every case.

    Args:
        target_unitary: target unitary, forwarded to the unitary plotting helper.
        learnt_unitary: learnt unitary, forwarded to the unitary plotting helper.
        eq_state_learnt: state plotted as the "learnt" state.
        eq_state_target: state plotted as the "target" state.
        cost_progress: forwarded to ``plot_cost``.
        modes: number of modes; selects which plots are produced.
        offset: forwarded to ``wigner_3D_plot``.
        l: forwarded to the Wigner/wavefunction plotting helpers.
        out_dir: directory the PNG files are written to; created if missing.
        ID: filename prefix for every saved image.
        **kwargs: only ``maps_outside`` is read (default ``True``). When it is
            falsy, the unitary plots are requested with ``square=True``.
            Other keys are ignored.

    Every file is written to ``os.path.join(out_dir, ID + suffix)``, e.g.
    ``<out_dir>/<ID>_cost.png``.
    """
    square = not kwargs.get('maps_outside', True)

    # Create the output directory up front; otherwise every savefig below
    # fails with FileNotFoundError after the plots have already been built.
    os.makedirs(out_dir, exist_ok=True)

    def _save(fig, suffix):
        # Write ``fig`` to <out_dir>/<ID><suffix>.
        fig.savefig(os.path.join(out_dir, ID + suffix))

    if modes == 1:
        # Wigner function plots of the target and learnt states
        fig, _ = wigner_3D_plot(eq_state_target, offset=offset, l=l)
        _save(fig, '_targetWigner.png')
        fig, _ = wigner_3D_plot(eq_state_learnt, offset=offset, l=l)
        _save(fig, '_learntWigner.png')

        # matrix plot of the target and learnt unitaries
        fig, _ = one_mode_unitary_plots(target_unitary, learnt_unitary, square=square)
        _save(fig, '_unitaryPlot.png')

    elif modes == 2:
        # 3D (two-mode) wavefunction plots of the target and learnt states
        fig, _ = two_mode_wavefunction_plot(eq_state_target, l=l)
        _save(fig, '_targetWavefunction.png')
        fig, _ = two_mode_wavefunction_plot(eq_state_learnt, l=l)
        _save(fig, '_learntWavefunction.png')

        # matrix plot of the target and learnt unitaries
        fig, _ = two_mode_unitary_plots(target_unitary, learnt_unitary, square=square)
        _save(fig, '_unitaryPlot.png')

    # cost function plot, produced for every value of ``modes``
    fig, _ = plot_cost(cost_progress)
    _save(fig, '_cost.png')