def run(args):
    """Train a generative model / inference network pair on binarized MNIST.

    The training algorithm is selected by ``args.train_mode``, one of
    ``'ws'``, ``'ww'``, ``'reinforce'``, ``'vimco'``, ``'thermo'`` or
    ``'thermo_wake'``. After training, the models are evaluated on the
    validation split and the result is stored on the training callback as
    ``valid_log_p`` / ``valid_kl``. Finally the models (via
    ``util.save_checkpoint``), ``args`` and the callback are saved under a
    new save directory.

    Args:
        args: Parsed command-line arguments. ``args.cuda`` is overwritten
            with whether CUDA is actually used.

    Raises:
        ValueError: If ``args.train_mode`` is not one of the modes above.
    """
    # set up args: fall back to CPU when CUDA is requested but unavailable,
    # and record the device really used on args (args are saved below).
    if args.cuda and torch.cuda.is_available():
        device = torch.device('cuda')
        args.cuda = True
    else:
        device = torch.device('cpu')
        args.cuda = False
    # Only the thermodynamic training modes consume `partition`.
    if args.train_mode in ('thermo', 'thermo_wake'):
        partition = util.get_partition(
            args.num_partitions, args.partition_type, args.log_beta_min,
            device)
    util.print_with_time('device = {}'.format(device))
    util.print_with_time(str(args))

    # save args
    save_dir = util.get_save_dir()
    args_path = util.get_args_path(save_dir)
    util.save_object(args, args_path)

    # data
    binarized_mnist_train, binarized_mnist_valid, binarized_mnist_test = \
        data.load_binarized_mnist(where=args.where)
    data_loader = data.get_data_loader(
        binarized_mnist_train, args.batch_size, device)
    valid_data_loader = data.get_data_loader(
        binarized_mnist_valid, args.valid_batch_size, device)
    test_data_loader = data.get_data_loader(
        binarized_mnist_test, args.test_batch_size, device)
    # Mean of the training observations over axis 0 (presumably per pixel);
    # handed to util.init_models.
    train_obs_mean = torch.tensor(np.mean(binarized_mnist_train, axis=0),
                                  device=device, dtype=torch.float)

    # init models
    util.set_seed(args.seed)
    generative_model, inference_network = util.init_models(
        train_obs_mean, args.architecture, device)

    # optim
    optim_kwargs = {'lr': args.learning_rate}

    # train
    if args.train_mode == 'ws':
        train_callback = train.TrainWakeSleepCallback(
            save_dir, args.num_particles * args.batch_size, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_wake_sleep(generative_model, inference_network,
                               data_loader, args.num_iterations,
                               args.num_particles, optim_kwargs,
                               train_callback)
    elif args.train_mode == 'ww':
        train_callback = train.TrainWakeWakeCallback(
            save_dir, args.num_particles, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_wake_wake(generative_model, inference_network,
                              data_loader, args.num_iterations,
                              args.num_particles, optim_kwargs,
                              train_callback)
    elif args.train_mode in ('reinforce', 'vimco'):
        train_callback = train.TrainIwaeCallback(
            save_dir, args.num_particles, args.train_mode, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_iwae(args.train_mode, generative_model,
                         inference_network, data_loader, args.num_iterations,
                         args.num_particles, optim_kwargs, train_callback)
    elif args.train_mode == 'thermo':
        train_callback = train.TrainThermoCallback(
            save_dir, args.num_particles, partition, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_thermo(generative_model, inference_network, data_loader,
                           args.num_iterations, args.num_particles, partition,
                           optim_kwargs, train_callback)
    elif args.train_mode == 'thermo_wake':
        train_callback = train.TrainThermoWakeCallback(
            save_dir, args.num_particles, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_thermo_wake(generative_model, inference_network,
                                data_loader, args.num_iterations,
                                args.num_particles, partition, optim_kwargs,
                                train_callback)
    else:
        # Previously this fell through and crashed later with an
        # UnboundLocalError on `train_callback`.
        raise ValueError('Unknown train_mode: {}'.format(args.train_mode))

    # eval validation
    train_callback.valid_log_p, train_callback.valid_kl = train.eval_gen_inf(
        generative_model, inference_network, valid_data_loader,
        args.eval_num_particles)

    # save models and stats
    util.save_checkpoint(save_dir, iteration=None,
                         generative_model=generative_model,
                         inference_network=inference_network)
    stats_path = util.get_stats_path(save_dir)
    util.save_object(train_callback, stats_path)
def plot_models():
    """Plot true vs. learned p_theta(z) and q_phi(z | x) as Hinton diagrams.

    Rows are the checkpoints at iterations 0, 49000 and 99000. Columns are
    grouped by the number of particles K in ``num_particles_list``. Each
    group has one p_theta(z) panel followed by one q_phi(z | x) panel per
    test observation. Every panel overlays the true model (black) with the
    learned model of each training mode in ``train_mode_list``. The figure
    is written to ``./plots/models.pdf``.

    Relies on the module-level names ``seed_list``, ``train_mode_list``,
    ``init_near``, ``labels``, ``colors``, ``plot_hinton`` and
    ``SMALL_SIZE``.
    """
    saving_iterations = np.arange(100) * 1000
    num_iterations_to_plot = 3
    iterations_to_plot = saving_iterations[np.floor(
        np.linspace(0, 99, num=num_iterations_to_plot)).astype(int)]
    num_test_x = 3
    num_particles_list = [2, 20]
    seed = seed_list[0]

    # The true generative model is rebuilt from the args saved with a run.
    model_folder = util.get_most_recent_model_folder_args_match(
        seed=seed, train_mode=train_mode_list[0],
        num_particles=num_particles_list[0], init_near=init_near)
    args = util.load_object(util.get_args_path(model_folder))
    _, _, true_generative_model = util.init_models(args)

    test_xs = np.linspace(0, 19, num=num_test_x) * 10

    # The run folder depends only on (K, train mode), not on the iteration,
    # so look each one up once instead of once per panel.
    model_folders = [
        [util.get_most_recent_model_folder_args_match(
            seed=seed, train_mode=train_mode, num_particles=num_particles,
            init_near=init_near)
         for train_mode in train_mode_list]
        for num_particles in num_particles_list]

    nrows = num_iterations_to_plot
    ncols = len(num_particles_list) * (num_test_x + 1)
    fig, axss = plt.subplots(nrows, ncols, sharex=True, sharey=True)
    width = 5.5
    ax_width = width / ncols
    height = nrows * ax_width
    fig.set_size_inches(width, height)

    def style_axis(ax):
        """Hide ticks and set the fixed limits used by every Hinton panel."""
        ax.set_xticks([])
        ax.set_xticklabels([])
        ax.set_yticks([])
        ax.set_yticklabels([])
        ax.set_ylim(0, 8)
        ax.set_xlim(0, 20)

    def to_numpy(tensor):
        """Detach and move to host memory so CUDA tensors also work."""
        return tensor.detach().cpu().numpy()

    true_row = 0  # the true model occupies the top row of every panel
    for iteration_idx, iteration in enumerate(iterations_to_plot):
        axss[iteration_idx, 0].set_ylabel('Iter. {}'.format(iteration))
        for num_particles_idx, num_particles in enumerate(num_particles_list):
            # Load each learned checkpoint once and reuse it for every panel
            # of this (iteration, K) group.
            learned_models = [
                (None, None) if folder is None
                else util.load_models(folder, iteration=iteration)
                for folder in model_folders[num_particles_idx]]
            first_col = num_particles_idx * (num_test_x + 1)

            # generative model panel
            ax = axss[iteration_idx, first_col]
            style_axis(ax)
            if iteration_idx == 0:
                ax.set_title(r'$p_\theta(z)$')
            plot_hinton(ax,
                        to_numpy(true_generative_model.get_latent_params()),
                        8 - true_row, 8 - true_row - 1, 0, 20, color='black')
            for train_mode_idx, (generative_model, _) in enumerate(
                    learned_models):
                if generative_model is not None:
                    plot_hinton(
                        ax, to_numpy(generative_model.get_latent_params()),
                        8 - train_mode_idx - 1, 8 - train_mode_idx - 2, 0, 20,
                        label=labels[train_mode_idx],
                        color=colors[train_mode_idx])

            # inference network panels, one per test observation
            for test_x_idx, test_x in enumerate(test_xs):
                ax = axss[iteration_idx, first_col + test_x_idx + 1]
                style_axis(ax)
                test_x_tensor = torch.tensor(test_x, dtype=torch.float,
                                             device=args.device).unsqueeze(0)
                if iteration_idx == 0:
                    ax.set_title(r'$q_\phi(z | x = {0:.0f})$'.format(test_x))
                plot_hinton(ax,
                            to_numpy(true_generative_model.get_posterior_probs(
                                test_x_tensor)[0]),
                            8 - true_row, 8 - true_row - 1, 0, 20,
                            color='black')
                for train_mode_idx, (_, inference_network) in enumerate(
                        learned_models):
                    if inference_network is not None:
                        plot_hinton(
                            ax,
                            to_numpy(inference_network.get_latent_params(
                                test_x_tensor)[0]),
                            8 - train_mode_idx - 1, 8 - train_mode_idx - 2,
                            0, 20, label=labels[train_mode_idx],
                            color=colors[train_mode_idx])

    # One "K = ..." header per column group, centred on the group.
    for num_particles_idx, num_particles in enumerate(num_particles_list):
        ax = axss[0, num_particles_idx * (num_test_x + 1) +
                  (num_test_x + 1) // 2]
        ax.text(0, 1.25, '$K = {}$'.format(num_particles),
                fontsize=SMALL_SIZE, verticalalignment='bottom',
                horizontalalignment='center', transform=ax.transAxes)

    handles = [mpatches.Rectangle((0, 0), 1, 1, color='black', label='True')]
    for color, label in zip(colors, labels):
        handles.append(
            mpatches.Rectangle((0, 0), 1, 1, color=color, label=label))
    axss[-1, ncols // 2].legend(bbox_to_anchor=(0, -0.1), loc='upper center',
                                ncol=len(handles), handles=handles)

    for ax in axss[-1]:
        ax.set_xlabel(r'$z$', labelpad=0.5)

    fig.tight_layout(pad=0)
    os.makedirs('./plots/', exist_ok=True)
    filename = './plots/models.pdf'
    fig.savefig(filename, bbox_inches='tight')
    print('Saved to {}'.format(filename))
def plot_model_movie():
    """Animate true vs. learned models over 100 checkpoints and save an MP4.

    Rows correspond to K in ``num_particles_list``. The first column shows
    p_theta(z) and the remaining ``num_test_x`` columns show q_phi(z | x)
    for evenly spaced test observations. Frame ``f`` shows the checkpoints
    saved at iteration ``1000 * f``. Each panel overlays the true model
    (black) with the learned model of each mode in ``train_mode_list``.
    The movie is written to ``./plots/model_movie.mp4``.

    Relies on the module-level names ``seed_list``, ``train_mode_list``,
    ``init_near``, ``labels``, ``colors``, ``plot_hinton``, ``SMALL_SIZE``
    and ``MEDIUM_SIZE``.
    """
    num_test_x = 5
    num_particles_list = [2, 5, 10, 20]
    seed = seed_list[0]

    # The true generative model is rebuilt from the args saved with a run.
    model_folder = util.get_most_recent_model_folder_args_match(
        seed=seed, train_mode=train_mode_list[0],
        num_particles=num_particles_list[0], init_near=init_near)
    args = util.load_object(util.get_args_path(model_folder))
    _, _, true_generative_model = util.init_models(args)

    test_xs = np.linspace(0, 19, num=num_test_x) * 10

    # Run folders depend only on (K, train mode), not on the frame: look
    # them up once here instead of inside every frame and panel.
    model_folders = [
        [util.get_most_recent_model_folder_args_match(
            seed=seed, train_mode=train_mode, num_particles=num_particles,
            init_near=init_near)
         for train_mode in train_mode_list]
        for num_particles in num_particles_list]

    nrows = len(num_particles_list)
    ncols = num_test_x + 1
    width = 5.5
    ax_width = width / ncols
    height = nrows * ax_width
    fig, axss = plt.subplots(nrows, ncols, sharex=True, sharey=True, dpi=300)
    fig.set_size_inches(width, height)
    for num_particles_idx, num_particles in enumerate(num_particles_list):
        axss[num_particles_idx, 0].set_ylabel(
            '$K = {}$'.format(num_particles), fontsize=SMALL_SIZE)

    handles = [mpatches.Rectangle((0, 0), 1, 1, color='black', label='True')]
    for color, label in zip(colors, labels):
        handles.append(
            mpatches.Rectangle((0, 0), 1, 1, color=color, label=label))
    axss[-1, ncols // 2].legend(bbox_to_anchor=(0, -0.05),
                                loc='upper center', ncol=len(handles),
                                handles=handles)

    axss[0, 0].set_title(r'$p_\theta(z)$')
    for test_x_idx, test_x in enumerate(test_xs):
        axss[0, 1 + test_x_idx].set_title(
            r'$q_\phi(z | x = {0:.0f})$'.format(test_x))

    for ax in axss[-1]:
        ax.set_xlabel(r'$z$', labelpad=0.5)

    for axs in axss:
        for ax in axs:
            ax.set_xticks([])
            ax.set_xticklabels([])
            ax.set_yticks([])
            ax.set_yticklabels([])
            ax.set_ylim(0, 8)
            ax.set_xlim(0, 20)

    # Frame counter drawn above the middle column.
    t = axss[0, ncols // 2].text(0, 1.23, 'Iteration 0',
                                 horizontalalignment='center',
                                 verticalalignment='center',
                                 transform=axss[0, ncols // 2].transAxes,
                                 fontsize=MEDIUM_SIZE)
    fig.tight_layout(pad=0, rect=[0.01, 0.04, 0.99, 0.96])

    def to_numpy(tensor):
        """Detach and move to host memory so CUDA tensors also work."""
        return tensor.detach().cpu().numpy()

    def update(frame):
        """Draw frame `frame` and return the artists to blit."""
        result = []
        iteration = frame * 1000
        t.set_text('Iteration {}'.format(iteration))
        result.append(t)

        # Paint over the previous frame with a white rectangle in each panel.
        # NOTE(review): these rectangles (and, presumably, the artists that
        # plot_hinton adds) are never removed, so every axis accumulates
        # artists across frames and later frames redraw all of them.
        for axs in axss:
            for ax in axs:
                result.append(
                    ax.add_artist(
                        mpatches.Rectangle((0, 0), 20, 8, color='white')))

        true_row = 0  # the true model occupies the top row of every panel
        for num_particles_idx, folders in enumerate(model_folders):
            # Load each checkpoint once per frame and reuse it for every
            # panel in this row.
            learned_models = [
                (None, None) if folder is None
                else util.load_models(folder, iteration=iteration)
                for folder in folders]

            # generative model panel
            ax = axss[num_particles_idx, 0]
            plot_hinton(ax,
                        to_numpy(true_generative_model.get_latent_params()),
                        8 - true_row, 8 - true_row - 1, 0, 20, color='black')
            for train_mode_idx, (generative_model, _) in enumerate(
                    learned_models):
                if generative_model is not None:
                    plot_hinton(
                        ax, to_numpy(generative_model.get_latent_params()),
                        8 - train_mode_idx - 1, 8 - train_mode_idx - 2, 0, 20,
                        label=labels[train_mode_idx],
                        color=colors[train_mode_idx])
            result += ax.artists

            # inference network panels, one per test observation
            for test_x_idx, test_x in enumerate(test_xs):
                ax = axss[num_particles_idx, test_x_idx + 1]
                test_x_tensor = torch.tensor(test_x, dtype=torch.float,
                                             device=args.device).unsqueeze(0)
                plot_hinton(ax,
                            to_numpy(true_generative_model.get_posterior_probs(
                                test_x_tensor)[0]),
                            8 - true_row, 8 - true_row - 1, 0, 20,
                            color='black')
                for train_mode_idx, (_, inference_network) in enumerate(
                        learned_models):
                    if inference_network is not None:
                        plot_hinton(
                            ax,
                            to_numpy(inference_network.get_latent_params(
                                test_x_tensor)[0]),
                            8 - train_mode_idx - 1, 8 - train_mode_idx - 2,
                            0, 20, label=labels[train_mode_idx],
                            color=colors[train_mode_idx])
                result += ax.artists
        return result

    anim = FuncAnimation(fig, update, frames=np.arange(100), blit=True)
    os.makedirs('./plots/', exist_ok=True)
    filename = './plots/model_movie.mp4'
    anim.save(filename, dpi=300)
    print('Saved to {}'.format(filename))
def run(args):
    """Train a model pair on observations sampled from a known true model.

    Fills in the derived settings on ``args``: ``device``, ``num_mixtures``
    (20), ``init_mixture_logits``, ``softmax_multiplier``,
    ``relaxed_one_hot``, ``temperature`` and ``true_mixture_logits``. It
    saves ``args`` and initializes the learned and true models, then draws
    ``args.num_obss`` observations from the true model. It trains with the
    algorithm selected by ``args.train_mode`` (``'mws'``, ``'ws'``,
    ``'ww'``, ``'dww'``, ``'reinforce'``, ``'vimco'``, ``'concrete'`` or
    ``'relax'``). Finally it saves the models (plus the control variate for
    ``'relax'``) and the training callback.

    Args:
        args: Parsed command-line arguments; mutated in place as above.

    Raises:
        ValueError: If ``args.train_mode`` is not one of the modes above.
    """
    # set up args
    if args.cuda and torch.cuda.is_available():
        args.device = torch.device('cuda')
    else:
        args.device = torch.device('cpu')
    args.num_mixtures = 20
    if args.init_near:
        args.init_mixture_logits = np.ones(args.num_mixtures)
    else:
        # Descending logits 2 * (num_mixtures - 1), ..., 2, 0.
        args.init_mixture_logits = np.array(
            list(reversed(2 * np.arange(args.num_mixtures))))
    args.softmax_multiplier = 0.5
    if args.train_mode == 'concrete':
        args.relaxed_one_hot = True
        args.temperature = 3
    else:
        args.relaxed_one_hot = False
        args.temperature = None
    # True mixture probabilities are proportional to 5, 6, ...; the logits
    # are scaled by 1 / softmax_multiplier.
    temp = np.arange(args.num_mixtures) + 5
    true_p_mixture_probs = temp / np.sum(temp)
    args.true_mixture_logits = \
        np.log(true_p_mixture_probs) / args.softmax_multiplier
    util.print_with_time(str(args))

    # save args
    model_folder = util.get_model_folder()
    args_filename = util.get_args_path(model_folder)
    util.save_object(args, args_filename)

    # init models
    util.set_seed(args.seed)
    generative_model, inference_network, true_generative_model = \
        util.init_models(args)
    if args.train_mode == 'relax':
        control_variate = models.ControlVariate(args.num_mixtures)

    # init dataloader: a fixed set of observations from the true model
    obss_data_loader = torch.utils.data.DataLoader(
        true_generative_model.sample_obs(args.num_obss),
        batch_size=args.batch_size, shuffle=True)

    # Value passed to both the 'dww' callback and trainer (presumably the
    # defensive-mixture weight); kept in one place so they cannot diverge.
    dww_weight = 0.2

    # train (one exhaustive if/elif chain)
    if args.train_mode == 'mws':
        train_callback = train.TrainMWSCallback(
            model_folder, true_generative_model, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_mws(generative_model, inference_network, obss_data_loader,
                        args.num_iterations, args.mws_memory_size,
                        train_callback)
    elif args.train_mode == 'ws':
        train_callback = train.TrainWakeSleepCallback(
            model_folder, true_generative_model,
            args.batch_size * args.num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_wake_sleep(generative_model, inference_network,
                               obss_data_loader, args.num_iterations,
                               args.num_particles, train_callback)
    elif args.train_mode == 'ww':
        train_callback = train.TrainWakeWakeCallback(
            model_folder, true_generative_model, args.num_particles,
            args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_wake_wake(generative_model, inference_network,
                              obss_data_loader, args.num_iterations,
                              args.num_particles, train_callback)
    elif args.train_mode == 'dww':
        train_callback = train.TrainDefensiveWakeWakeCallback(
            model_folder, true_generative_model, args.num_particles,
            dww_weight, args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_defensive_wake_wake(dww_weight, generative_model,
                                        inference_network, obss_data_loader,
                                        args.num_iterations,
                                        args.num_particles, train_callback)
    elif args.train_mode in ('reinforce', 'vimco'):
        train_callback = train.TrainIwaeCallback(
            model_folder, true_generative_model, args.num_particles,
            args.train_mode, args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_iwae(args.train_mode, generative_model, inference_network,
                         obss_data_loader, args.num_iterations,
                         args.num_particles, train_callback)
    elif args.train_mode == 'concrete':
        train_callback = train.TrainConcreteCallback(
            model_folder, true_generative_model, args.num_particles,
            args.num_iterations, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_iwae(args.train_mode, generative_model, inference_network,
                         obss_data_loader, args.num_iterations,
                         args.num_particles, train_callback)
    elif args.train_mode == 'relax':
        train_callback = train.TrainRelaxCallback(
            model_folder, true_generative_model, args.num_particles,
            args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_relax(generative_model, inference_network,
                          control_variate, obss_data_loader,
                          args.num_iterations, args.num_particles,
                          train_callback)
    else:
        # Previously this fell through and crashed at the final save with
        # an UnboundLocalError on `train_callback`.
        raise ValueError('Unknown train_mode: {}'.format(args.train_mode))

    # save models and stats
    util.save_models(generative_model, inference_network, model_folder)
    if args.train_mode == 'relax':
        util.save_control_variate(control_variate, model_folder)
    stats_filename = util.get_stats_path(model_folder)
    util.save_object(train_callback, stats_filename)