Example #1
0
File: run.py — Project: yyht/rrws
def run(args):
    """Train a generative model / inference network pair on binarized MNIST.

    Dispatches on ``args.train_mode`` ('ws', 'ww', 'reinforce', 'vimco',
    'thermo' or 'thermo_wake'), evaluates log-likelihood and KL on the
    validation set, and saves args, a final checkpoint and training
    statistics under a fresh save directory.

    Args:
        args: argparse.Namespace carrying the experiment configuration
            (cuda, train_mode, batch sizes, seed, learning_rate, ...).
            ``args.cuda`` is normalized in place to reflect actual CUDA
            availability.

    Raises:
        ValueError: if ``args.train_mode`` is not a supported mode.
    """
    # set up device; fall back to CPU when CUDA is not actually available
    if args.cuda and torch.cuda.is_available():
        device = torch.device('cuda')
        args.cuda = True
    else:
        device = torch.device('cpu')
        args.cuda = False
    # thermodynamic-variational modes additionally need a partition
    # (schedule of inverse temperatures) for the thermo objective
    if args.train_mode in ('thermo', 'thermo_wake'):
        partition = util.get_partition(args.num_partitions,
                                       args.partition_type, args.log_beta_min,
                                       device)
    util.print_with_time('device = {}'.format(device))
    util.print_with_time(str(args))

    # save args so the run can be reproduced / inspected later
    save_dir = util.get_save_dir()
    args_path = util.get_args_path(save_dir)
    util.save_object(args, args_path)

    # data: loaders for the train / valid / test splits
    binarized_mnist_train, binarized_mnist_valid, binarized_mnist_test = \
        data.load_binarized_mnist(where=args.where)
    data_loader = data.get_data_loader(binarized_mnist_train, args.batch_size,
                                       device)
    valid_data_loader = data.get_data_loader(binarized_mnist_valid,
                                             args.valid_batch_size, device)
    test_data_loader = data.get_data_loader(binarized_mnist_test,
                                            args.test_batch_size, device)
    # per-pixel mean of the training set, used to initialize the models
    train_obs_mean = torch.tensor(np.mean(binarized_mnist_train, axis=0),
                                  device=device,
                                  dtype=torch.float)

    # init models (seeded for reproducibility)
    util.set_seed(args.seed)
    generative_model, inference_network = util.init_models(
        train_obs_mean, args.architecture, device)

    # optim
    optim_kwargs = {'lr': args.learning_rate}

    # train: each mode pairs a logging/checkpointing/eval callback with
    # the corresponding training loop
    if args.train_mode == 'ws':
        train_callback = train.TrainWakeSleepCallback(
            save_dir, args.num_particles * args.batch_size, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_wake_sleep(generative_model, inference_network,
                               data_loader, args.num_iterations,
                               args.num_particles, optim_kwargs,
                               train_callback)
    elif args.train_mode == 'ww':
        train_callback = train.TrainWakeWakeCallback(
            save_dir, args.num_particles, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_wake_wake(generative_model, inference_network, data_loader,
                              args.num_iterations, args.num_particles,
                              optim_kwargs, train_callback)
    elif args.train_mode in ('reinforce', 'vimco'):
        train_callback = train.TrainIwaeCallback(
            save_dir, args.num_particles, args.train_mode, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_iwae(args.train_mode, generative_model, inference_network,
                         data_loader, args.num_iterations, args.num_particles,
                         optim_kwargs, train_callback)
    elif args.train_mode == 'thermo':
        train_callback = train.TrainThermoCallback(
            save_dir, args.num_particles, partition, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_thermo(generative_model, inference_network, data_loader,
                           args.num_iterations, args.num_particles, partition,
                           optim_kwargs, train_callback)
    elif args.train_mode == 'thermo_wake':
        train_callback = train.TrainThermoWakeCallback(
            save_dir, args.num_particles, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_thermo_wake(generative_model, inference_network,
                                data_loader, args.num_iterations,
                                args.num_particles, partition, optim_kwargs,
                                train_callback)
    else:
        # fail fast: previously an unrecognized mode fell through all
        # branches and crashed below with a NameError on train_callback
        raise ValueError('unknown train_mode: {}'.format(args.train_mode))

    # eval validation
    train_callback.valid_log_p, train_callback.valid_kl = train.eval_gen_inf(
        generative_model, inference_network, valid_data_loader,
        args.eval_num_particles)

    # save models and stats
    util.save_checkpoint(save_dir,
                         iteration=None,
                         generative_model=generative_model,
                         inference_network=inference_network)
    stats_path = util.get_stats_path(save_dir)
    util.save_object(train_callback, stats_path)
Example #2
0
File: plot.py — Project: insperatum/rrws
def plot_models():
    """Plot true vs. learned latent distributions across training; save a PDF.

    Builds a grid of Hinton-style plots: one row per plotted checkpoint
    iteration and, for each particle count K, a column for the prior
    p_theta(z) followed by one column per test observation for the
    posterior q_phi(z | x).  The true model is drawn in black; each
    learned model is drawn in its train-mode color one band below.
    The figure is written to ./plots/models.pdf.

    NOTE(review): relies on module-level names defined elsewhere in the
    file (seed_list, train_mode_list, init_near, labels, colors,
    SMALL_SIZE, plot_hinton, util, np, torch, plt, mpatches, os).
    """
    # models appear to be checkpointed every 1000 iterations; pick a few
    # evenly spaced checkpoints to plot -- TODO confirm vs. training code
    saving_iterations = np.arange(100) * 1000
    num_iterations_to_plot = 3
    iterations_to_plot = saving_iterations[np.floor(
        np.linspace(0, 99, num=num_iterations_to_plot)).astype(int)]
    num_test_x = 3
    num_particles_list = [2, 20]
    seed = seed_list[0]
    # load args from any matching run to reconstruct the true generative model
    model_folder = util.get_most_recent_model_folder_args_match(
        seed=seed_list[0],
        train_mode=train_mode_list[0],
        num_particles=num_particles_list[0],
        init_near=init_near)
    args = util.load_object(util.get_args_path(model_folder))
    _, _, true_generative_model = util.init_models(args)
    # test observations spread over [0, 190] in steps of 10
    test_xs = np.linspace(0, 19, num=num_test_x) * 10

    nrows = num_iterations_to_plot
    # per K: one prior column + num_test_x posterior columns
    ncols = len(num_particles_list) * (num_test_x + 1)
    fig, axss = plt.subplots(nrows, ncols, sharex=True, sharey=True)
    width = 5.5
    ax_width = width / ncols
    height = nrows * ax_width  # keep individual axes roughly square
    fig.set_size_inches(width, height)
    for iteration_idx, iteration in enumerate(iterations_to_plot):
        axss[iteration_idx, 0].set_ylabel('Iter. {}'.format(iteration))
        for num_particles_idx, num_particles in enumerate(num_particles_list):
            # first column of this K-group: the prior p_theta(z)
            ax = axss[iteration_idx, num_particles_idx * (num_test_x + 1)]
            ax.set_xticks([])
            ax.set_xticklabels([])
            ax.set_yticks([])
            ax.set_yticklabels([])
            ax.set_ylim(0, 8)
            ax.set_xlim(0, 20)
            if iteration_idx == 0:
                ax.set_title(r'$p_\theta(z)$')

            # true generative model, drawn as the topmost band (black)
            # plot_hinton(ax, probs, top, bottom, left, right, ...) --
            # presumably; verify against plot_hinton's definition
            i = 0
            plot_hinton(ax,
                        true_generative_model.get_latent_params().data.numpy(),
                        8 - i,
                        8 - i - 1,
                        0,
                        20,
                        color='black')

            # learned generative models: one colored band per train mode,
            # stacked below the true model's band
            for train_mode_idx, train_mode in enumerate(train_mode_list):
                label = labels[train_mode_idx]
                color = colors[train_mode_idx]
                model_folder = util.get_most_recent_model_folder_args_match(
                    seed=seed,
                    train_mode=train_mode,
                    num_particles=num_particles,
                    init_near=init_near)
                if model_folder is not None:
                    generative_model, _ = util.load_models(model_folder,
                                                           iteration=iteration)
                    if generative_model is not None:
                        plot_hinton(
                            ax,
                            generative_model.get_latent_params().data.numpy(),
                            8 - train_mode_idx - 1,
                            8 - train_mode_idx - 2,
                            0,
                            20,
                            label=label,
                            color=color)

            # inference network: remaining columns of this K-group, one
            # per test observation x
            for test_x_idx, test_x in enumerate(test_xs):
                ax = axss[iteration_idx, num_particles_idx * (num_test_x + 1) +
                          test_x_idx + 1]
                ax.set_xticks([])
                ax.set_xticklabels([])
                ax.set_yticks([])
                ax.set_yticklabels([])
                ax.set_ylim(0, 8)
                ax.set_xlim(0, 20)
                test_x_tensor = torch.tensor(test_x,
                                             dtype=torch.float,
                                             device=args.device).unsqueeze(0)
                if iteration_idx == 0:
                    ax.set_title(r'$q_\phi(z | x = {0:.0f})$'.format(test_x))

                # true posterior under the true model (black, top band;
                # i is still 0 from the prior section above)
                plot_hinton(ax,
                            true_generative_model.get_posterior_probs(
                                test_x_tensor)[0].data.numpy(),
                            8 - i,
                            8 - i - 1,
                            0,
                            20,
                            color='black')

                # learned posteriors q_phi(z | x), one band per train mode
                for train_mode_idx, train_mode in enumerate(train_mode_list):
                    label = labels[train_mode_idx]
                    color = colors[train_mode_idx]
                    model_folder = \
                        util.get_most_recent_model_folder_args_match(
                            seed=seed, train_mode=train_mode,
                            num_particles=num_particles, init_near=init_near)
                    if model_folder is not None:
                        _, inference_network = util.load_models(
                            model_folder, iteration=iteration)
                        if inference_network is not None:
                            plot_hinton(ax,
                                        inference_network.get_latent_params(
                                            test_x_tensor)[0].data.numpy(),
                                        8 - train_mode_idx - 1,
                                        8 - train_mode_idx - 2,
                                        0,
                                        20,
                                        label=label,
                                        color=color)

    # K = ... headers, centered above each K-group of columns
    for num_particles_idx, num_particles in enumerate(num_particles_list):
        ax = axss[0,
                  num_particles_idx * (num_test_x + 1) + (num_test_x + 1) // 2]
        ax.text(0,
                1.25,
                '$K = {}$'.format(num_particles),
                fontsize=SMALL_SIZE,
                verticalalignment='bottom',
                horizontalalignment='center',
                transform=ax.transAxes)

    # shared legend below the bottom row: black = true, then one colored
    # patch per train mode
    handles = [mpatches.Rectangle((0, 0), 1, 1, color='black', label='True')]
    for color, label in zip(colors, labels):
        handles.append(
            mpatches.Rectangle((0, 0), 1, 1, color=color, label=label))
    axss[-1, ncols // 2].legend(bbox_to_anchor=(0, -0.1),
                                loc='upper center',
                                ncol=len(handles),
                                handles=handles)

    for ax in axss[-1]:
        ax.set_xlabel(r'$z$', labelpad=0.5)

    fig.tight_layout(pad=0)
    if not os.path.exists('./plots/'):
        os.makedirs('./plots/')
    filename = './plots/models.pdf'
    fig.savefig(filename, bbox_inches='tight')
    print('Saved to {}'.format(filename))
Example #3
0
File: plot.py — Project: insperatum/rrws
def plot_model_movie():
    """Animate true vs. learned latent distributions over training; save mp4.

    Lays out a grid with one row per particle count K and columns for the
    prior p_theta(z) plus one column per test observation for the
    posterior q_phi(z | x).  A FuncAnimation then redraws every axis for
    each checkpoint iteration (frame * 1000) and writes the result to
    ./plots/model_movie.mp4.

    NOTE(review): relies on module-level names defined elsewhere in the
    file (seed_list, train_mode_list, init_near, labels, colors,
    SMALL_SIZE, MEDIUM_SIZE, plot_hinton, util, np, torch, plt,
    mpatches, FuncAnimation, os).
    """
    num_test_x = 5
    num_particles_list = [2, 5, 10, 20]
    seed = seed_list[0]
    # load args from any matching run to reconstruct the true generative model
    model_folder = util.get_most_recent_model_folder_args_match(
        seed=seed_list[0],
        train_mode=train_mode_list[0],
        num_particles=num_particles_list[0],
        init_near=init_near)
    args = util.load_object(util.get_args_path(model_folder))
    _, _, true_generative_model = util.init_models(args)
    # test observations spread over [0, 190] in steps of 10
    test_xs = np.linspace(0, 19, num=num_test_x) * 10

    nrows = len(num_particles_list)
    ncols = num_test_x + 1  # prior column + one column per test x
    width = 5.5
    ax_width = width / ncols
    height = nrows * ax_width  # keep individual axes roughly square
    fig, axss = plt.subplots(nrows, ncols, sharex=True, sharey=True, dpi=300)
    fig.set_size_inches(width, height)

    for num_particles_idx, num_particles in enumerate(num_particles_list):
        axss[num_particles_idx, 0].set_ylabel('$K = {}$'.format(num_particles),
                                              fontsize=SMALL_SIZE)

    # shared legend below the bottom row: black = true, then one colored
    # patch per train mode
    handles = [mpatches.Rectangle((0, 0), 1, 1, color='black', label='True')]
    for color, label in zip(colors, labels):
        handles.append(
            mpatches.Rectangle((0, 0), 1, 1, color=color, label=label))
    axss[-1, ncols // 2].legend(bbox_to_anchor=(0, -0.05),
                                loc='upper center',
                                ncol=len(handles),
                                handles=handles)

    axss[0, 0].set_title(r'$p_\theta(z)$')
    for test_x_idx, test_x in enumerate(test_xs):
        axss[0, 1 + test_x_idx].set_title(
            r'$q_\phi(z | x = {0:.0f})$'.format(test_x))
    for ax in axss[-1]:
        ax.set_xlabel(r'$z$', labelpad=0.5)

    for axs in axss:
        for ax in axs:
            ax.set_xticks([])
            ax.set_xticklabels([])
            ax.set_yticks([])
            ax.set_yticklabels([])
            ax.set_ylim(0, 8)
            ax.set_xlim(0, 20)
    # title = fig.suptitle('Iteration 0')
    # animated iteration counter, updated each frame (suptitle does not
    # play well with blitting, hence the axes-anchored text)
    t = axss[0, ncols // 2].text(0,
                                 1.23,
                                 'Iteration 0',
                                 horizontalalignment='center',
                                 verticalalignment='center',
                                 transform=axss[0, ncols // 2].transAxes,
                                 fontsize=MEDIUM_SIZE)

    fig.tight_layout(pad=0, rect=[0.01, 0.04, 0.99, 0.96])

    def update(frame):
        # Redraw every axis for checkpoint `frame * 1000` and return the
        # artists that changed (required by blit=True).
        result = []
        iteration_idx = frame
        iteration = iteration_idx * 1000
        t.set_text('Iteration {}'.format(iteration))
        result.append(t)

        # blank out the previous frame's contents with a white rectangle
        for axs in axss:
            for ax in axs:
                result.append(
                    ax.add_artist(
                        mpatches.Rectangle((0, 0), 20, 8, color='white')))
        for num_particles_idx, num_particles in enumerate(num_particles_list):
            ax = axss[num_particles_idx, 0]

            # true generative model, drawn as the topmost band (black)
            i = 0
            plot_hinton(ax,
                        true_generative_model.get_latent_params().data.numpy(),
                        8 - i,
                        8 - i - 1,
                        0,
                        20,
                        color='black')

            # learned generative models: one colored band per train mode
            for train_mode_idx, train_mode in enumerate(train_mode_list):
                label = labels[train_mode_idx]
                color = colors[train_mode_idx]
                model_folder = util.get_most_recent_model_folder_args_match(
                    seed=seed,
                    train_mode=train_mode,
                    num_particles=num_particles,
                    init_near=init_near)
                if model_folder is not None:
                    generative_model, _ = util.load_models(model_folder,
                                                           iteration=iteration)
                    if generative_model is not None:
                        plot_hinton(
                            ax,
                            generative_model.get_latent_params().data.numpy(),
                            8 - train_mode_idx - 1,
                            8 - train_mode_idx - 2,
                            0,
                            20,
                            label=label,
                            color=color)

            # NOTE(review): Axes.artists is a read-only property in newer
            # Matplotlib releases -- confirm the pinned version supports this
            result += ax.artists

            # inference network columns, one per test observation x
            for test_x_idx, test_x in enumerate(test_xs):
                ax = axss[num_particles_idx, test_x_idx + 1]
                test_x_tensor = torch.tensor(test_x,
                                             dtype=torch.float,
                                             device=args.device).unsqueeze(0)

                # true posterior under the true model (black, top band)
                plot_hinton(ax,
                            true_generative_model.get_posterior_probs(
                                test_x_tensor)[0].data.numpy(),
                            8 - i,
                            8 - i - 1,
                            0,
                            20,
                            color='black')

                # learned posteriors q_phi(z | x), one band per train mode
                for train_mode_idx, train_mode in enumerate(train_mode_list):
                    label = labels[train_mode_idx]
                    color = colors[train_mode_idx]
                    model_folder = \
                        util.get_most_recent_model_folder_args_match(
                            seed=seed, train_mode=train_mode,
                            num_particles=num_particles, init_near=init_near)
                    if model_folder is not None:
                        _, inference_network = util.load_models(
                            model_folder, iteration=iteration)
                        if inference_network is not None:
                            plot_hinton(ax,
                                        inference_network.get_latent_params(
                                            test_x_tensor)[0].data.numpy(),
                                        8 - train_mode_idx - 1,
                                        8 - train_mode_idx - 2,
                                        0,
                                        20,
                                        label=label,
                                        color=color)
                result += ax.artists
        return result

    # 100 frames -> checkpoints 0, 1000, ..., 99000
    anim = FuncAnimation(fig, update, frames=np.arange(100), blit=True)
    if not os.path.exists('./plots/'):
        os.makedirs('./plots/')
    filename = './plots/model_movie.mp4'
    anim.save(filename, dpi=300)
    print('Saved to {}'.format(filename))
Example #4
0
def run(args):
    """Train a model on data from a known mixture-model generative process.

    Fills in derived configuration on ``args`` (device, mixture sizes,
    true/initial mixture logits), builds the true generative model and a
    data loader of sampled observations, dispatches on ``args.train_mode``
    ('mws', 'ws', 'ww', 'dww', 'reinforce', 'vimco', 'concrete', 'relax'),
    and saves the trained models and training statistics.

    Args:
        args: argparse.Namespace with the experiment configuration
            (cuda, train_mode, init_near, seed, batch_size, num_obss,
            num_iterations, num_particles, logging/checkpoint/eval
            intervals, ...).  Mutated in place with derived settings.

    Raises:
        ValueError: if ``args.train_mode`` is not a supported mode.
    """
    # set up args: pick device, falling back to CPU without CUDA
    if args.cuda and torch.cuda.is_available():
        args.device = torch.device('cuda')
    else:
        args.device = torch.device('cpu')
    args.num_mixtures = 20
    if args.init_near:
        # uniform logits -> initialization near the true model
        args.init_mixture_logits = np.ones(args.num_mixtures)
    else:
        # decreasing logits [2*(K-1), ..., 2, 0] -> far-from-true init
        args.init_mixture_logits = np.array(
            list(reversed(2 * np.arange(args.num_mixtures))))
    args.softmax_multiplier = 0.5
    if args.train_mode == 'concrete':
        # Concrete/Gumbel-Softmax relaxation needs a temperature
        args.relaxed_one_hot = True
        args.temperature = 3
    else:
        args.relaxed_one_hot = False
        args.temperature = None
    # true mixture probabilities proportional to [5, 6, ..., K + 4]
    temp = np.arange(args.num_mixtures) + 5
    true_p_mixture_probs = temp / np.sum(temp)
    args.true_mixture_logits = \
        np.log(true_p_mixture_probs) / args.softmax_multiplier
    util.print_with_time(str(args))

    # save args so the run can be reproduced / inspected later
    model_folder = util.get_model_folder()
    args_filename = util.get_args_path(model_folder)
    util.save_object(args, args_filename)

    # init models (seeded for reproducibility)
    util.set_seed(args.seed)
    generative_model, inference_network, true_generative_model = \
        util.init_models(args)
    if args.train_mode == 'relax':
        # RELAX needs a learned control variate alongside the networks
        control_variate = models.ControlVariate(args.num_mixtures)

    # init dataloader over observations sampled from the true model
    obss_data_loader = torch.utils.data.DataLoader(
        true_generative_model.sample_obs(args.num_obss),
        batch_size=args.batch_size,
        shuffle=True)

    # train: each mode pairs a logging/checkpointing callback with the
    # corresponding training loop.  (The 'mws' branch previously used a
    # bare `if`, splitting the dispatch into two chains; unified here.)
    if args.train_mode == 'mws':
        train_callback = train.TrainMWSCallback(model_folder,
                                                true_generative_model,
                                                args.logging_interval,
                                                args.checkpoint_interval,
                                                args.eval_interval)
        train.train_mws(generative_model, inference_network, obss_data_loader,
                        args.num_iterations, args.mws_memory_size,
                        train_callback)
    elif args.train_mode == 'ws':
        train_callback = train.TrainWakeSleepCallback(
            model_folder, true_generative_model,
            args.batch_size * args.num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_wake_sleep(generative_model, inference_network,
                               obss_data_loader, args.num_iterations,
                               args.num_particles, train_callback)
    elif args.train_mode == 'ww':
        train_callback = train.TrainWakeWakeCallback(model_folder,
                                                     true_generative_model,
                                                     args.num_particles,
                                                     args.logging_interval,
                                                     args.checkpoint_interval,
                                                     args.eval_interval)
        train.train_wake_wake(generative_model, inference_network,
                              obss_data_loader, args.num_iterations,
                              args.num_particles, train_callback)
    elif args.train_mode == 'dww':
        # defensive wake-wake with a fixed 0.2 defensive-mixture weight
        train_callback = train.TrainDefensiveWakeWakeCallback(
            model_folder, true_generative_model, args.num_particles, 0.2,
            args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_defensive_wake_wake(0.2, generative_model,
                                        inference_network, obss_data_loader,
                                        args.num_iterations,
                                        args.num_particles, train_callback)
    elif args.train_mode in ('reinforce', 'vimco'):
        train_callback = train.TrainIwaeCallback(
            model_folder, true_generative_model, args.num_particles,
            args.train_mode, args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_iwae(args.train_mode, generative_model, inference_network,
                         obss_data_loader, args.num_iterations,
                         args.num_particles, train_callback)
    elif args.train_mode == 'concrete':
        train_callback = train.TrainConcreteCallback(
            model_folder, true_generative_model, args.num_particles,
            args.num_iterations, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_iwae(args.train_mode, generative_model, inference_network,
                         obss_data_loader, args.num_iterations,
                         args.num_particles, train_callback)
    elif args.train_mode == 'relax':
        train_callback = train.TrainRelaxCallback(model_folder,
                                                  true_generative_model,
                                                  args.num_particles,
                                                  args.logging_interval,
                                                  args.checkpoint_interval,
                                                  args.eval_interval)
        train.train_relax(generative_model, inference_network, control_variate,
                          obss_data_loader, args.num_iterations,
                          args.num_particles, train_callback)
    else:
        # fail fast: previously an unrecognized mode fell through and
        # crashed at save time with a NameError on train_callback
        raise ValueError('unknown train_mode: {}'.format(args.train_mode))

    # save models and stats
    util.save_models(generative_model, inference_network, model_folder)
    if args.train_mode == 'relax':
        util.save_control_variate(control_variate, model_folder)
    stats_filename = util.get_stats_path(model_folder)
    util.save_object(train_callback, stats_filename)