def load_errors():
    """Collect error and gradient-std histories for every saved run.

    Iterates over ``seed_list`` x ``train_mode_list`` x
    ``num_particles_list`` (names that are not defined in this function and
    are expected to exist at module level, together with ``init_near``) and
    looks up the most recent matching model folder for each combination.

    Returns:
        Tuple ``(p_error, q_error, grad_std)`` of float arrays with shape
        ``(len(seed_list), len(train_mode_list), len(num_particles_list),
        100)``. Entries for missing runs, and entries beyond the length of a
        run's recorded history, are NaN.
    """
    len_history = 100
    shape = (len(seed_list), len(train_mode_list), len(num_particles_list),
             len_history)
    # The builtin ``float`` replaces ``np.float``: the alias was deprecated in
    # NumPy 1.20 and removed in 1.24, and both denote float64.
    p_error = np.full(shape, np.nan, dtype=float)
    q_error = np.full(shape, np.nan, dtype=float)
    grad_std = np.full(shape, np.nan, dtype=float)

    for seed_idx, seed in enumerate(seed_list):
        for train_mode_idx, train_mode in enumerate(train_mode_list):
            for num_particles_idx, num_particles in enumerate(
                    num_particles_list):
                print('{} {} {}'.format(seed, train_mode, num_particles))
                model_folder = util.get_most_recent_model_folder_args_match(
                    seed=seed, train_mode=train_mode,
                    num_particles=num_particles, init_near=init_near)
                if model_folder is None:
                    # No saved run for this configuration: leave NaNs.
                    continue
                stats = util.load_object(util.get_stats_path(model_folder))
                # Each history fills only the leading part of its row; the
                # remainder stays NaN.
                p_error[seed_idx, train_mode_idx, num_particles_idx,
                        :len(stats.p_error_history)] = stats.p_error_history
                q_error[seed_idx, train_mode_idx, num_particles_idx,
                        :len(stats.q_error_history)] = stats.q_error_history
                grad_std[seed_idx, train_mode_idx, num_particles_idx,
                         :len(stats.grad_std_history)] = \
                    stats.grad_std_history
    return p_error, q_error, grad_std
def __call__(self, iteration, theta_loss, phi_loss, generative_model,
             inference_network, memory, optimizer):
    """Record training progress for one iteration.

    Three independent actions fire whenever ``iteration`` is a multiple of
    the matching interval stored on the callback:

    * ``logging_interval``: print both losses and append them to the loss
      histories.
    * ``checkpoint_interval``: save this callback object and the models
      (together with ``memory``) under ``self.model_folder``.
    * ``eval_interval``: append the values returned by ``util.get_p_error``
      and ``util.get_q_error`` to their histories and print them.

    ``optimizer`` is accepted but not used here.
    """
    if iteration % self.logging_interval == 0:
        message = 'Iteration {} losses: theta = {:.3f}, phi = {:.3f}'.format(
            iteration, theta_loss, phi_loss)
        util.print_with_time(message)
        self.theta_loss_history.append(theta_loss)
        self.phi_loss_history.append(phi_loss)

    if iteration % self.checkpoint_interval == 0:
        util.save_object(self, util.get_stats_path(self.model_folder))
        util.save_models(generative_model, inference_network,
                         self.model_folder, iteration, memory)

    if iteration % self.eval_interval == 0:
        true_model = self.true_generative_model
        p_error = util.get_p_error(true_model, generative_model)
        self.p_error_history.append(p_error)
        q_error = util.get_q_error(true_model, inference_network,
                                   self.test_obss)
        self.q_error_history.append(q_error)
        # TODO: also append util.get_memory_error(self.true_generative_model,
        # memory, generative_model, self.test_obss) to
        # self.memory_error_history.
        util.print_with_time(
            'Iteration {} p_error = {:.3f}, q_error_to_true = {:.3f}'.format(
                iteration, p_error, q_error))
def __call__(self, iteration, wake_theta_loss, wake_phi_loss, elbo,
             generative_model, inference_network, optimizer_theta,
             optimizer_phi):
    """Log, checkpoint and evaluate at the configured intervals.

    * Every ``logging_interval`` iterations: print the two wake losses and
      the ELBO and append them to their histories.
    * Every ``checkpoint_interval`` iterations: save this callback object
      and a model checkpoint under ``self.save_dir``.
    * Every ``eval_interval`` iterations: append the ``(log_p, kl)`` pair
      returned by ``eval_gen_inf`` and a gradient-std estimate accumulated
      over 10 fresh backward passes on ``self.test_obs``.

    The two optimizer arguments are accepted but not used here.
    """
    if iteration % self.logging_interval == 0:
        message = ('Iteration {} losses: theta = {:.3f}, phi = {:.3f}, '
                   'elbo = {:.3f}').format(iteration, wake_theta_loss,
                                           wake_phi_loss, elbo)
        util.print_with_time(message)
        self.wake_theta_loss_history.append(wake_theta_loss)
        self.wake_phi_loss_history.append(wake_phi_loss)
        self.elbo_history.append(elbo)

    if iteration % self.checkpoint_interval == 0:
        util.save_object(self, util.get_stats_path(self.save_dir))
        util.save_checkpoint(self.save_dir, iteration,
                             generative_model=generative_model,
                             inference_network=inference_network)

    if iteration % self.eval_interval == 0:
        log_p, kl = eval_gen_inf(generative_model, inference_network,
                                 self.test_data_loader,
                                 self.eval_num_particles)
        self.log_p_history.append(log_p)
        self.kl_history.append(kl)

        grad_stats = util.OnlineMeanStd()
        for _ in range(10):
            # Generative-model gradients of the wake-theta loss. They are
            # cloned before the second backward pass below runs.
            generative_model.zero_grad()
            theta_eval_loss, _unused_elbo = losses.get_wake_theta_loss(
                generative_model, inference_network, self.test_obs,
                self.num_particles)
            theta_eval_loss.backward()
            theta_grads = []
            for param in generative_model.parameters():
                theta_grads.append(param.grad.clone())

            # Inference-network gradients of the wake-phi loss.
            inference_network.zero_grad()
            phi_eval_loss = losses.get_wake_phi_loss(
                generative_model, inference_network, self.test_obs,
                self.num_particles)
            phi_eval_loss.backward()
            phi_grads = []
            for param in inference_network.parameters():
                phi_grads.append(param.grad)

            grad_stats.update([*theta_grads, *phi_grads])

        grad_std = grad_stats.avg_of_means_stds()[1].item()
        self.grad_std_history.append(grad_std)
        util.print_with_time(
            'Iteration {} log_p = {:.3f}, kl = {:.3f}'.format(
                iteration, log_p, kl))
def __call__(self, iteration, loss, elbo, generative_model,
             inference_network, optimizer):
    """Log, checkpoint and evaluate at the configured intervals.

    * Every ``logging_interval`` iterations: print ``loss`` and ``elbo`` and
      append them to their histories.
    * Every ``checkpoint_interval`` iterations: save this callback object
      and a model checkpoint under ``self.save_dir``.
    * Every ``eval_interval`` iterations: append the ``(log_p, kl)`` pair
      from ``eval_gen_inf``, the second value from ``eval_gen_inf_alpha``
      (stored as ``renyi``), and a gradient-std estimate accumulated over 10
      backward passes of the thermo-alpha loss on ``self.test_obs``.

    ``optimizer`` is accepted but not used here.
    """
    if iteration % self.logging_interval == 0:
        message = 'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
            iteration, loss, elbo)
        util.print_with_time(message)
        self.loss_history.append(loss)
        self.elbo_history.append(elbo)

    if iteration % self.checkpoint_interval == 0:
        util.save_object(self, util.get_stats_path(self.save_dir))
        util.save_checkpoint(self.save_dir, iteration,
                             generative_model=generative_model,
                             inference_network=inference_network)

    if iteration % self.eval_interval == 0:
        log_p, kl = eval_gen_inf(generative_model, inference_network,
                                 self.test_data_loader,
                                 self.eval_num_particles)
        alpha_eval = eval_gen_inf_alpha(generative_model, inference_network,
                                        self.test_data_loader,
                                        self.eval_num_particles, self.alpha)
        renyi = alpha_eval[1]
        self.log_p_history.append(log_p)
        self.kl_history.append(kl)
        self.renyi_history.append(renyi)

        grad_stats = util.OnlineMeanStd()
        for _ in range(10):
            generative_model.zero_grad()
            inference_network.zero_grad()
            eval_loss, _unused_elbo = losses.get_thermo_alpha_loss(
                generative_model, inference_network, self.test_obs,
                self.partition, self.num_particles, self.alpha,
                self.integration)
            eval_loss.backward()
            # Gradients of both networks go into the same running statistic.
            model_grads = [param.grad
                           for param in generative_model.parameters()]
            network_grads = [param.grad
                             for param in inference_network.parameters()]
            grad_stats.update(model_grads + network_grads)

        grad_std = grad_stats.avg_of_means_stds()[1].item()
        self.grad_std_history.append(grad_std)
        util.print_with_time(
            'Iteration {} log_p = {:.3f}, kl = {:.3f}, renyi = {:.3f}'.format(
                iteration, log_p, kl, renyi))
def __call__(self, iteration, loss, elbo, generative_model,
             inference_network, optimizer):
    """Log, checkpoint and evaluate at the configured intervals.

    * Every ``logging_interval`` iterations: print ``loss`` and ``elbo`` and
      append them to their histories.
    * Every ``checkpoint_interval`` iterations: save this callback object
      and the models under ``self.model_folder``.
    * Every ``eval_interval`` iterations: append the values returned by
      ``util.get_p_error`` / ``util.get_q_error`` and a gradient-std
      estimate of the inference network accumulated over 10 backward passes
      of the loss selected by ``self.train_mode``.

    ``optimizer`` is accepted but not used here.

    Raises:
        ValueError: on an evaluation iteration, if ``self.train_mode`` is
            neither ``'vimco'`` nor ``'reinforce'``. Previously this case
            fell through and called ``backward()`` on the ``loss`` argument
            passed in by the training loop.
    """
    if iteration % self.logging_interval == 0:
        util.print_with_time(
            'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
                iteration, loss, elbo))
        self.loss_history.append(loss)
        self.elbo_history.append(elbo)

    if iteration % self.checkpoint_interval == 0:
        stats_filename = util.get_stats_path(self.model_folder)
        util.save_object(self, stats_filename)
        util.save_models(generative_model, inference_network,
                         self.model_folder, iteration)

    if iteration % self.eval_interval == 0:
        # Resolve the loss once, before touching any history, so that an
        # unsupported mode fails without leaving the histories half-updated.
        if self.train_mode == 'vimco':
            get_loss = losses.get_vimco_loss
        elif self.train_mode == 'reinforce':
            get_loss = losses.get_reinforce_loss
        else:
            raise ValueError(
                'unsupported train_mode for gradient evaluation: '
                '{!r}'.format(self.train_mode))

        self.p_error_history.append(
            util.get_p_error(self.true_generative_model, generative_model))
        self.q_error_history.append(
            util.get_q_error(self.true_generative_model, inference_network,
                             self.test_obss))

        stats = util.OnlineMeanStd()
        for _ in range(10):
            inference_network.zero_grad()
            eval_loss, _ = get_loss(generative_model, inference_network,
                                    self.test_obss, self.num_particles)
            eval_loss.backward()
            stats.update([p.grad for p in inference_network.parameters()])
        self.grad_std_history.append(stats.avg_of_means_stds()[1])

        util.print_with_time(
            'Iteration {} p_error = {:.3f}, q_error_to_true = '
            '{:.3f}'.format(iteration, self.p_error_history[-1],
                            self.q_error_history[-1]))
def run(args):
    """Train a generative model / inference network pair on binarized MNIST.

    Picks the device, saves ``args``, builds the train/valid/test data
    loaders, initialises the models and runs the training routine selected
    by ``args.train_mode`` (``'ws'``, ``'ww'``, ``'reinforce'``, ``'vimco'``,
    ``'thermo'`` or ``'thermo_wake'``). Afterwards the validation
    ``log_p`` / ``kl`` are attached to the training callback, and the final
    checkpoint and the callback are saved under the run's save directory.

    ``args.cuda`` is overwritten to reflect whether CUDA is actually used.

    Raises:
        ValueError: if ``args.train_mode`` is not one of the supported
            modes. Previously this surfaced as a ``NameError`` on
            ``train_callback`` after all set-up had run.
    """
    # set up args
    if args.cuda and torch.cuda.is_available():
        device = torch.device('cuda')
        args.cuda = True
    else:
        device = torch.device('cpu')
        args.cuda = False

    # Only the thermodynamic objectives need a temperature partition.
    if args.train_mode in ('thermo', 'thermo_wake'):
        partition = util.get_partition(args.num_partitions,
                                       args.partition_type,
                                       args.log_beta_min, device)
    util.print_with_time('device = {}'.format(device))
    util.print_with_time(str(args))

    # save args
    save_dir = util.get_save_dir()
    args_path = util.get_args_path(save_dir)
    util.save_object(args, args_path)

    # data
    binarized_mnist_train, binarized_mnist_valid, binarized_mnist_test = \
        data.load_binarized_mnist(where=args.where)
    data_loader = data.get_data_loader(binarized_mnist_train,
                                       args.batch_size, device)
    valid_data_loader = data.get_data_loader(binarized_mnist_valid,
                                             args.valid_batch_size, device)
    test_data_loader = data.get_data_loader(binarized_mnist_test,
                                            args.test_batch_size, device)
    train_obs_mean = torch.tensor(np.mean(binarized_mnist_train, axis=0),
                                  device=device, dtype=torch.float)

    # init models
    util.set_seed(args.seed)
    generative_model, inference_network = util.init_models(
        train_obs_mean, args.architecture, device)

    # optim
    optim_kwargs = {'lr': args.learning_rate}

    # train
    if args.train_mode == 'ws':
        train_callback = train.TrainWakeSleepCallback(
            save_dir, args.num_particles * args.batch_size, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_wake_sleep(generative_model, inference_network,
                               data_loader, args.num_iterations,
                               args.num_particles, optim_kwargs,
                               train_callback)
    elif args.train_mode == 'ww':
        train_callback = train.TrainWakeWakeCallback(
            save_dir, args.num_particles, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_wake_wake(generative_model, inference_network,
                              data_loader, args.num_iterations,
                              args.num_particles, optim_kwargs,
                              train_callback)
    elif args.train_mode == 'reinforce' or args.train_mode == 'vimco':
        train_callback = train.TrainIwaeCallback(
            save_dir, args.num_particles, args.train_mode, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_iwae(args.train_mode, generative_model,
                         inference_network, data_loader,
                         args.num_iterations, args.num_particles,
                         optim_kwargs, train_callback)
    elif args.train_mode == 'thermo':
        train_callback = train.TrainThermoCallback(
            save_dir, args.num_particles, partition, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_thermo(generative_model, inference_network, data_loader,
                           args.num_iterations, args.num_particles,
                           partition, optim_kwargs, train_callback)
    elif args.train_mode == 'thermo_wake':
        train_callback = train.TrainThermoWakeCallback(
            save_dir, args.num_particles, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_thermo_wake(generative_model, inference_network,
                                data_loader, args.num_iterations,
                                args.num_particles, partition, optim_kwargs,
                                train_callback)
    else:
        # Without this branch an unknown mode skipped training and crashed
        # below with NameError on ``train_callback``.
        raise ValueError(
            'unsupported train_mode: {!r}'.format(args.train_mode))

    # eval validation
    train_callback.valid_log_p, train_callback.valid_kl = train.eval_gen_inf(
        generative_model, inference_network, valid_data_loader,
        args.eval_num_particles)

    # save models and stats
    util.save_checkpoint(save_dir, iteration=None,
                         generative_model=generative_model,
                         inference_network=inference_network)
    stats_path = util.get_stats_path(save_dir)
    util.save_object(train_callback, stats_path)
def run(args):
    """Train on observations sampled from a known mixture model.

    Fills in derived settings on ``args`` (``device``, ``num_mixtures``,
    ``init_mixture_logits``, ``softmax_multiplier``, ``relaxed_one_hot``,
    ``temperature``, ``true_mixture_logits``), saves ``args``, initialises
    the models, samples ``args.num_obss`` observations from the true
    generative model and runs the training routine selected by
    ``args.train_mode`` (``'mws'``, ``'ws'``, ``'ww'``, ``'dww'``,
    ``'reinforce'``, ``'vimco'``, ``'concrete'`` or ``'relax'``). The models,
    the control variate (``'relax'`` only) and the training callback are
    saved to the model folder at the end.

    Raises:
        ValueError: if ``args.train_mode`` is not one of the supported
            modes. Previously this surfaced as a ``NameError`` on
            ``train_callback`` after the models had been saved.
    """
    # set up args
    if args.cuda and torch.cuda.is_available():
        args.device = torch.device('cuda')
    else:
        args.device = torch.device('cpu')
    args.num_mixtures = 20
    if args.init_near:
        args.init_mixture_logits = np.ones(args.num_mixtures)
    else:
        # Decreasing logits: 2 * (num_mixtures - 1), ..., 2, 0.
        args.init_mixture_logits = np.array(
            list(reversed(2 * np.arange(args.num_mixtures))))
    args.softmax_multiplier = 0.5
    if args.train_mode == 'concrete':
        args.relaxed_one_hot = True
        args.temperature = 3
    else:
        args.relaxed_one_hot = False
        args.temperature = None
    # True mixture probabilities are proportional to 5, 6, ..., 24; dividing
    # the log-probabilities by the multiplier gives the stored logits.
    temp = np.arange(args.num_mixtures) + 5
    true_p_mixture_probs = temp / np.sum(temp)
    args.true_mixture_logits = \
        np.log(true_p_mixture_probs) / args.softmax_multiplier
    util.print_with_time(str(args))

    # save args
    model_folder = util.get_model_folder()
    args_filename = util.get_args_path(model_folder)
    util.save_object(args, args_filename)

    # init models
    util.set_seed(args.seed)
    generative_model, inference_network, true_generative_model = \
        util.init_models(args)
    if args.train_mode == 'relax':
        control_variate = models.ControlVariate(args.num_mixtures)

    # init dataloader
    obss_data_loader = torch.utils.data.DataLoader(
        true_generative_model.sample_obs(args.num_obss),
        batch_size=args.batch_size, shuffle=True)

    # train
    # A single if/elif chain: 'mws' used to be a separate ``if`` in front of
    # the chain for the remaining modes.
    if args.train_mode == 'mws':
        train_callback = train.TrainMWSCallback(
            model_folder, true_generative_model, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_mws(generative_model, inference_network,
                        obss_data_loader, args.num_iterations,
                        args.mws_memory_size, train_callback)
    elif args.train_mode == 'ws':
        train_callback = train.TrainWakeSleepCallback(
            model_folder, true_generative_model,
            args.batch_size * args.num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_wake_sleep(generative_model, inference_network,
                               obss_data_loader, args.num_iterations,
                               args.num_particles, train_callback)
    elif args.train_mode == 'ww':
        train_callback = train.TrainWakeWakeCallback(
            model_folder, true_generative_model, args.num_particles,
            args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_wake_wake(generative_model, inference_network,
                              obss_data_loader, args.num_iterations,
                              args.num_particles, train_callback)
    elif args.train_mode == 'dww':
        train_callback = train.TrainDefensiveWakeWakeCallback(
            model_folder, true_generative_model, args.num_particles, 0.2,
            args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_defensive_wake_wake(0.2, generative_model,
                                        inference_network, obss_data_loader,
                                        args.num_iterations,
                                        args.num_particles, train_callback)
    elif args.train_mode == 'reinforce' or args.train_mode == 'vimco':
        train_callback = train.TrainIwaeCallback(
            model_folder, true_generative_model, args.num_particles,
            args.train_mode, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_iwae(args.train_mode, generative_model,
                         inference_network, obss_data_loader,
                         args.num_iterations, args.num_particles,
                         train_callback)
    elif args.train_mode == 'concrete':
        train_callback = train.TrainConcreteCallback(
            model_folder, true_generative_model, args.num_particles,
            args.num_iterations, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_iwae(args.train_mode, generative_model,
                         inference_network, obss_data_loader,
                         args.num_iterations, args.num_particles,
                         train_callback)
    elif args.train_mode == 'relax':
        train_callback = train.TrainRelaxCallback(
            model_folder, true_generative_model, args.num_particles,
            args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_relax(generative_model, inference_network,
                          control_variate, obss_data_loader,
                          args.num_iterations, args.num_particles,
                          train_callback)
    else:
        # Without this branch an unknown mode skipped training and crashed
        # below with NameError on ``train_callback``.
        raise ValueError(
            'unsupported train_mode: {!r}'.format(args.train_mode))

    # save models and stats
    util.save_models(generative_model, inference_network, model_folder)
    if args.train_mode == 'relax':
        util.save_control_variate(control_variate, model_folder)
    stats_filename = util.get_stats_path(model_folder)
    util.save_object(train_callback, stats_filename)
# log10 of the smallest temperature beta_1 used in the 2-partition sweep.
_LOG_BETA_MINS_2 = [
    -5, -2, -1.6989700043360187, -1.5228787452803376, -1.3979400086720375,
    -1.3010299956639813, -1.2218487496163564, -1.1549019599857433,
    -1.0969100130080565, -1.0457574905606752, -1, -0.6989700043360187,
    -0.5228787452803375, -0.3979400086720376, -0.3010299956639812,
    -0.2218487496163564, -0.15490195998574313, -0.09691001300805639,
    -0.045757490560675115
]


def _save_figure(fig, filename):
    """Save ``fig`` to ``filename`` under ./plots/, creating the folder."""
    os.makedirs('./plots/', exist_ok=True)
    fig.savefig(filename, bbox_inches='tight')
    print('saved to {}'.format(filename))


def _load_run_stats(**match_args):
    """Return the stats of the most recent matching run, or None.

    Prints 'missing' when no run directory matches ``match_args``.
    """
    dir_ = util.get_most_recent_dir_args_match(**match_args)
    if dir_ is None:
        print('missing')
        return None
    return util.load_object(util.get_stats_path(dir_))


def _store_history(row, history, max_len):
    """Copy at most ``max_len`` leading entries of ``history`` into ``row``.

    ``row`` is a 1-D view into a NaN-initialised array, so positions beyond
    the end of a shorter history keep their NaN.
    """
    clipped = history[:max_len]
    row[:len(clipped)] = clipped


def _print_thermo_summary(num_particles, num_partitions, log_beta_min, stats):
    """Print the length and last value of a thermo run's log_p history."""
    print('thermo {} ({} partitions) beta_min = 1e{} after {}'
          ' it: {}'.format(num_particles, num_partitions, log_beta_min,
                           len(stats.log_p_history),
                           stats.log_p_history[-1]))


def _plot_efficiency():
    """Plot time and memory against number of particles (efficiency.pdf).

    Reads a 6-tuple of arrays from ./save/efficiency.pkl. The thermo arrays
    are indexed as ``[:, partition_idx]`` and every array is averaged over
    its last axis (presumably repeated runs; confirm against the script that
    writes the pickle).
    """
    num_particles_list = [2, 5, 10, 50, 100, 500, 1000, 5000]
    num_partitions_list = [2, 5, 10, 50, 100, 500, 1000]
    (memory_thermo, time_thermo, memory_vimco, time_vimco,
     memory_reinforce, time_reinforce) = util.load_object(
         './save/efficiency.pkl')

    fig, axs = plt.subplots(1, 2, dpi=200, figsize=(6, 4))
    # NOTE: the colour scale is sized by the number of particle settings but
    # indexed by partition index below; kept as in the original figure.
    norm = matplotlib.colors.Normalize(vmin=0,
                                       vmax=len(num_particles_list))
    cmap = matplotlib.cm.ScalarMappable(norm=norm, cmap=matplotlib.cm.Blues)
    cmap.set_array([])
    colors = [cmap.to_rgba(i + 1) for i in range(len(num_particles_list))]

    # Memory is divided by 1e6 to match the 'memory (MB)' axis label.
    panels = [
        (axs[0], time_thermo, time_vimco, time_reinforce, 'time (seconds)'),
        (axs[1], memory_thermo / 1e6, memory_vimco / 1e6,
         memory_reinforce / 1e6, 'memory (MB)'),
    ]
    for ax, thermo, vimco, reinforce, ylabel in panels:
        for i, num_partitions in enumerate(num_partitions_list):
            ax.plot(num_particles_list, np.mean(thermo[:, i], axis=-1),
                    label='thermo K={}'.format(num_partitions),
                    color=colors[i], marker='x', linestyle='none')
        ax.plot(num_particles_list, np.mean(vimco, axis=-1), color='black',
                label='vimco', marker='o', linestyle='none',
                fillstyle='none')
        ax.plot(num_particles_list, np.mean(reinforce, axis=-1),
                color='black', label='reinforce', marker='v',
                linestyle='none', fillstyle='none')
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel('number of particles')
        ax.set_ylabel(ylabel)
        ax.grid(True)
        ax.grid(True, which='minor', linewidth=0.2)
        sns.despine(ax=ax)
    axs[-1].legend(fontsize=6, ncol=2)

    fig.tight_layout()
    _save_figure(fig, './plots/efficiency.pdf')


def _plot_insights():
    """Plot final log p(x) of thermo runs against partitions and beta_1.

    Three panels sweep the number of partitions for three values of beta_1;
    the fourth sweeps beta_1 with 2 partitions. Writes insights.pdf.
    """
    markersize = 3
    learning_rate = 3e-4
    architecture = 'linear_3'
    seed = 8
    train_mode = 'thermo'
    num_particles_list = [2, 5, 10, 50]
    num_partitions_list = [2, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
    log_beta_mins_1 = [-10, -1, -0.045757490560675115]
    log_beta_mins_2 = _LOG_BETA_MINS_2
    num_iterations = 400

    log_p_thermo_partition_sweep = np.full(
        (len(num_particles_list), len(log_beta_mins_1),
         len(num_partitions_list), num_iterations), np.nan)
    log_p_thermo_beta_sweep = np.full(
        (len(num_particles_list), len(log_beta_mins_2), num_iterations),
        np.nan)

    for num_particles_idx, num_particles in enumerate(num_particles_list):
        for log_beta_min_idx, log_beta_min in enumerate(log_beta_mins_1):
            for num_partitions_idx, num_partitions in enumerate(
                    num_partitions_list):
                stats = _load_run_stats(
                    train_mode=train_mode, architecture=architecture,
                    learning_rate=learning_rate,
                    num_particles=num_particles,
                    num_partitions=num_partitions,
                    log_beta_min=log_beta_min, seed=seed)
                if stats is None:
                    continue
                _store_history(
                    log_p_thermo_partition_sweep[num_particles_idx,
                                                 log_beta_min_idx,
                                                 num_partitions_idx],
                    stats.log_p_history, num_iterations)
                _print_thermo_summary(num_particles, num_partitions,
                                      log_beta_min, stats)

        num_partitions = 2
        for log_beta_min_idx, log_beta_min in enumerate(log_beta_mins_2):
            stats = _load_run_stats(
                train_mode=train_mode, architecture=architecture,
                learning_rate=learning_rate, num_particles=num_particles,
                num_partitions=num_partitions, log_beta_min=log_beta_min,
                seed=seed)
            if stats is None:
                continue
            _store_history(
                log_p_thermo_beta_sweep[num_particles_idx, log_beta_min_idx],
                stats.log_p_history, num_iterations)
            _print_thermo_summary(num_particles, num_partitions,
                                  log_beta_min, stats)

    fig, axs = plt.subplots(2, 2, dpi=200, figsize=(12, 7), sharey=True)
    colors = ['C1', 'C2', 'C4', 'C5']
    partition_axes = [axs[0, 0], axs[0, 1], axs[1, 0]]
    for log_beta_min_idx, (log_beta_min, ax) in enumerate(
            zip(log_beta_mins_1, partition_axes)):
        for num_particles_idx, num_particles in enumerate(
                num_particles_list):
            ax.plot(num_partitions_list,
                    log_p_thermo_partition_sweep[num_particles_idx,
                                                 log_beta_min_idx, :, -1],
                    color=colors[num_particles_idx], label=num_particles,
                    marker='o', markersize=markersize, linestyle='solid',
                    linewidth=0.7)
        ax.set_title(r'$\beta_1 = {:.0e}$'.format(10**log_beta_min))
        ax.set_xlabel('number of partitions')
        ax.set_xticks(np.arange(0, max(num_partitions_list) + 1, 10))

    beta_ax = axs[1, 1]
    for num_particles_idx, num_particles in enumerate(num_particles_list):
        beta_ax.plot(10**np.array(log_beta_mins_2),
                     log_p_thermo_beta_sweep[num_particles_idx, :, -1],
                     color=colors[num_particles_idx], label=num_particles,
                     marker='o', markersize=markersize, linestyle='solid',
                     linewidth=0.7)
    beta_ax.set_xticks(np.arange(0, 1.1, 0.2))
    beta_ax.set_title('2 partitions')
    beta_ax.set_xlabel(r'$\beta_1$')

    # Report, per particle count, the best final value and the beta_1 that
    # achieved it.
    final_log_p = log_p_thermo_beta_sweep[..., -1]
    best_beta_idxs = np.argmax(final_log_p, axis=-1)
    print(np.max(final_log_p, axis=-1))
    print(best_beta_idxs)
    print([log_beta_mins_2[i] for i in best_beta_idxs])
    print([10**log_beta_mins_2[i] for i in best_beta_idxs])

    for ax in axs.flat:
        ax.grid(True, axis='y')
    for ax in axs[:, 0]:
        ax.set_ylim(top=-88)
        ax.set_ylabel(r'$\log p(x)$')
    axs[1, 1].legend(title='number of particles', ncol=2, loc='lower right')
    for ax in axs.flat:
        sns.despine(ax=ax, trim=True)

    fig.tight_layout()
    _save_figure(fig, './plots/insights.pdf')


def _plot_thermo_vs_baselines(history_name, ylim, ylabel, filename):
    """Plot one recorded history for thermo runs against 'ww' and 'vimco'.

    Args:
        history_name: name of the list attribute on the saved stats object
            to plot (e.g. 'log_p_history').
        ylim: tuple of positional arguments passed to ``ax.set_ylim``.
        ylabel: y-axis label.
        filename: output path of the saved figure.
    """
    learning_rate = 3e-4
    architecture = 'linear_3'
    non_thermo_train_modes = ['ww', 'vimco']
    num_particles_list = [2, 5, 10, 50]
    log_beta_mins_2 = _LOG_BETA_MINS_2
    num_iterations = 400

    thermo_history = np.full(
        (len(num_particles_list), len(log_beta_mins_2), num_iterations),
        np.nan)
    non_thermo_history = np.full(
        (len(non_thermo_train_modes), len(num_particles_list),
         num_iterations), np.nan)

    # Thermo runs: 2 partitions, seed 8, one run per beta_1 value.
    num_partitions = 2
    for num_particles_idx, num_particles in enumerate(num_particles_list):
        for log_beta_min_idx, log_beta_min in enumerate(log_beta_mins_2):
            stats = _load_run_stats(
                train_mode='thermo', architecture=architecture,
                learning_rate=learning_rate, num_particles=num_particles,
                num_partitions=num_partitions, log_beta_min=log_beta_min,
                seed=8)
            if stats is None:
                continue
            _store_history(
                thermo_history[num_particles_idx, log_beta_min_idx],
                getattr(stats, history_name), num_iterations)
            _print_thermo_summary(num_particles, num_partitions,
                                  log_beta_min, stats)

    # Baseline runs were stored with seed 7, 1 partition, log_beta_min -10.
    for train_mode_idx, train_mode in enumerate(non_thermo_train_modes):
        for num_particles_idx, num_particles in enumerate(
                num_particles_list):
            stats = _load_run_stats(
                train_mode=train_mode, architecture=architecture,
                learning_rate=learning_rate, num_particles=num_particles,
                num_partitions=1, log_beta_min=-10, seed=7)
            if stats is None:
                continue
            # The slice is sized by the history actually being stored; the
            # grad_std mode used to size it by log_p_history instead.
            _store_history(
                non_thermo_history[train_mode_idx, num_particles_idx],
                getattr(stats, history_name), num_iterations)
            print('{} {} after {} it: {}'.format(
                train_mode, num_particles, len(stats.log_p_history),
                stats.log_p_history[-1]))

    fig, ax = plt.subplots(1, 1, dpi=200, figsize=(6, 4))
    colors = ['C1', 'C2', 'C4', 'C5']
    linestyles = ['dashed', 'dotted']
    for train_mode_idx, train_mode in enumerate(non_thermo_train_modes):
        label = 'rws' if train_mode == 'ww' else train_mode
        for num_particles_idx, num_particles in enumerate(
                num_particles_list):
            ax.plot(non_thermo_history[train_mode_idx, num_particles_idx],
                    linestyle=linestyles[train_mode_idx],
                    color=colors[num_particles_idx],
                    label='{} {} ({:.2f})'.format(
                        label, num_particles,
                        non_thermo_history[train_mode_idx,
                                           num_particles_idx, -1]))

    # Hand-picked index into log_beta_mins_2 for each particle count.
    best_beta_idxs = [18, 5, 11, 12]
    for num_particles_idx, num_particles in enumerate(num_particles_list):
        best_beta_idx = best_beta_idxs[num_particles_idx]
        ax.plot(
            thermo_history[num_particles_idx, best_beta_idx],
            linestyle='solid', color=colors[num_particles_idx],
            label='thermo S={}, K={}, $\\beta_1$={:.0e} ({:.2f})'.format(
                num_particles, 2, 10**(log_beta_mins_2[best_beta_idx]),
                thermo_history[num_particles_idx, best_beta_idx, -1]))

    ax.set_ylim(*ylim)
    ax.grid(True, axis='y', linewidth=0.2)
    ax.legend(fontsize=6, ncol=3, frameon=False)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('iteration')
    ax.xaxis.set_label_coords(0.5, -0.025)
    ax.set_xticks([0, num_iterations])
    ax.set_xticklabels([0, '4e6'])
    sns.despine(ax=ax, trim=True)

    fig.tight_layout()
    _save_figure(fig, filename)


def main(args):
    """Produce the figure selected by ``args.mode`` under ./plots/.

    Supported modes are 'efficiency', 'insights', 'baselines', 'grad_std'
    and 'baselines_kl'. Any other value does nothing.
    """
    if args.mode == 'efficiency':
        _plot_efficiency()
    elif args.mode == 'insights':
        _plot_insights()
    elif args.mode == 'baselines':
        _plot_thermo_vs_baselines('log_p_history', (-110,),
                                  r'$\log p(x)$', './plots/baselines.pdf')
    elif args.mode == 'grad_std':
        _plot_thermo_vs_baselines('grad_std_history', (0, 20), r'grad std',
                                  './plots/grad_std.pdf')
    elif args.mode == 'baselines_kl':
        _plot_thermo_vs_baselines('kl_history', (5, 20), r'KL(q || p)',
                                  './plots/baselines_kl.pdf')