Beispiel #1
0
def experiment(variant):
    """Train a ConvVAE on a generated dataset, dumping images periodically.

    ``variant`` is a mapping that must provide ``beta``,
    ``representation_size``, ``get_data_kwargs``, ``conv_vae_kwargs``,
    ``algo_kwargs``, ``save_period`` and ``num_epochs``. When it also holds
    ``beta_schedule_kwargs`` a ``PiecewiseLinearSchedule`` is built from them
    and passed to the trainer; otherwise the trainer gets
    ``beta_schedule=None``.

    Reconstructions, scatter plots and samples are only requested on epochs
    that are a multiple of ``save_period``. Nothing is returned.
    """
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu

    kl_weight = variant["beta"]
    latent_dim = variant["representation_size"]

    dataset = generate_vae_dataset(**variant['get_data_kwargs'])
    train_set, test_set, dataset_info = dataset
    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    # The schedule is optional: fall back to None when it is not configured.
    schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    vae = ConvVAE(latent_dim, input_channels=3, **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        vae.to(ptu.device)

    trainer = ConvVAETrainer(
        train_set,
        test_set,
        vae,
        beta=kl_weight,
        beta_schedule=schedule,
        **variant['algo_kwargs'],
    )

    save_every = variant['save_period']
    total_epochs = variant['num_epochs']
    for epoch in range(total_epochs):
        is_save_epoch = (epoch % save_every == 0)
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=is_save_epoch,
            save_scatterplot=is_save_epoch,
        )
        if not is_save_epoch:
            continue
        trainer.dump_samples(epoch)
Beispiel #2
0
def train_vae(variant, return_data=False):
    """Train a ConvVAE and save its ``state_dict`` under ``./saved_model``.

    Args:
        variant: Mapping with the training configuration. ``beta``,
            ``save_period`` and ``num_epochs`` are required;
            ``representation_size`` (default 4), ``vae_architecture``
            (default ``imsize48_default_architecture``), ``image_size``
            (default 48) and ``input_channels`` (default 1) are optional.
            The whole mapping is also forwarded to ``generate_vae_data``.
        return_data: When true, also return the train and test datasets.

    Returns:
        The trained model, or ``(model, train_dataset, test_dataset)`` when
        ``return_data`` is true.

    Side effects:
        Writes ``vae_model.pkl`` into ``<current working dir>/saved_model``,
        creating that directory when it does not exist yet.
    """
    from rlkit.torch.vae.conv_vae import ConvVAE
    from rlkit.pythonplusplus import identity
    import torch
    from rlkit.torch.vae.conv_vae import imsize48_default_architecture

    beta = variant["beta"]
    representation_size = variant.get("representation_size", 4)

    train_dataset, test_dataset = generate_vae_data(variant)
    architecture = variant.get('vae_architecture',
                               imsize48_default_architecture)
    image_size = variant.get('image_size', 48)
    input_channels = variant.get('input_channels', 1)
    vae_model = ConvVAE(representation_size,
                        decoder_output_activation=identity,
                        architecture=architecture,
                        imsize=image_size,
                        input_channels=input_channels,
                        decoder_distribution='gaussian_identity_variance')

    # NOTE(review): this unconditionally moves the model to CUDA, so the
    # function needs a CUDA-capable device — confirm that is intended.
    vae_model.cuda()
    vae_trainer = ConvVAETrainer(train_dataset,
                                 test_dataset,
                                 vae_model,
                                 beta=beta,
                                 beta_schedule=None)
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        vae_trainer.train_epoch(epoch)
        vae_trainer.test_epoch(epoch)
        if epoch % save_period == 0:
            vae_trainer.dump_samples(epoch)
        vae_trainer.update_train_weights()

    # Create the output directory first: without it torch.save would raise
    # only after the whole training run has already completed.
    save_dir = osp.join(osp.abspath(os.curdir), 'saved_model')
    os.makedirs(save_dir, exist_ok=True)
    torch.save(vae_model.state_dict(), osp.join(save_dir, 'vae_model.pkl'))

    if return_data:
        return vae_model, train_dataset, test_dataset
    return vae_model
Beispiel #3
0
def train_vae(variant, return_data=False):
    """Build, train and pickle a ConvVAE described by ``variant``.

    ``variant`` must provide ``beta``, ``representation_size``,
    ``generate_vae_dataset_kwargs``, ``vae_kwargs``, ``algo_kwargs``,
    ``save_period`` and ``num_epochs``. Optional entries are
    ``generate_vae_data_fctn`` (dataset factory, defaults to
    ``generate_vae_dataset``), ``decoder_activation`` (``'sigmoid'`` selects
    ``torch.nn.Sigmoid()``, anything else the identity) and ``imsize``.

    Note that ``variant['vae_kwargs']`` is updated in place with the resolved
    ``architecture`` and ``imsize``.

    Returns the trained model, or ``(model, train_data, test_data)`` when
    ``return_data`` is true. The model is also stored through
    ``logger.save_extra_data`` as ``vae.pkl``.
    """
    beta = variant["beta"]
    representation_size = variant["representation_size"]

    make_dataset = variant.get('generate_vae_data_fctn', generate_vae_dataset)
    train_data, test_data, info = make_dataset(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    wants_sigmoid = variant.get('decoder_activation', None) == 'sigmoid'
    decoder_activation = torch.nn.Sigmoid() if wants_sigmoid else identity

    # Resolve a default architecture from the image size when none is given;
    # any other image size leaves the architecture untouched.
    vae_kwargs = variant['vae_kwargs']
    architecture = vae_kwargs.get('architecture', None)
    if not architecture:
        imsize = variant.get('imsize')
        if imsize == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif imsize == 48:
            architecture = conv_vae.imsize48_default_architecture
    vae_kwargs['architecture'] = architecture
    vae_kwargs['imsize'] = variant.get('imsize')

    vae = ConvVAE(
        representation_size,
        decoder_output_activation=decoder_activation,
        **variant['vae_kwargs'],
    )
    vae.to(ptu.device)
    trainer = ConvVAETrainer(
        train_data,
        test_data,
        vae,
        beta=beta,
        **variant['algo_kwargs'],
    )

    save_every = variant['save_period']
    total_epochs = variant['num_epochs']
    for epoch in range(total_epochs):
        is_save_epoch = (epoch % save_every == 0)
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch, save_reconstruction=is_save_epoch)
        if is_save_epoch:
            trainer.dump_samples(epoch)

    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    return (vae, train_data, test_data) if return_data else vae
Beispiel #4
0
def experiment(variant):
    """Run a short CPU-only ConvVAE training session.

    Reads ``beta`` and ``representation_size`` from ``variant``, fetches data
    with ``get_data(10000)`` and runs ten epochs, each consisting of a
    training pass, a test pass and a sample dump. Nothing is returned.
    """
    kl_weight = variant["beta"]
    latent_dim = variant["representation_size"]

    train_set, test_set = get_data(10000)
    vae = ConvVAE(latent_dim)
    trainer = ConvVAETrainer(
        train_set,
        test_set,
        vae,
        beta=kl_weight,
        use_cuda=False,
    )

    total_epochs = 10
    for epoch in range(total_epochs):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
        trainer.dump_samples(epoch)
Beispiel #5
0
def experiment(variant):
    """Train a 3-channel ConvVAE with a piecewise-linear beta schedule.

    ``variant`` must provide ``beta``, ``representation_size``,
    ``get_data_kwargs``, ``beta_schedule_kwargs``, ``algo_kwargs`` and
    ``num_epochs``. Every epoch runs a training pass, a test pass and a
    sample dump. Nothing is returned.
    """
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu

    kl_weight = variant["beta"]
    latent_dim = variant["representation_size"]

    dataset = get_data(**variant['get_data_kwargs'])
    train_set, test_set, dataset_info = dataset
    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    vae = ConvVAE(latent_dim, input_channels=3)
    if ptu.gpu_enabled():
        vae.to(ptu.device)

    trainer = ConvVAETrainer(
        train_set,
        test_set,
        vae,
        beta=kl_weight,
        beta_schedule=schedule,
        **variant['algo_kwargs'],
    )

    total_epochs = variant['num_epochs']
    for epoch in range(total_epochs):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
        trainer.dump_samples(epoch)
Beispiel #6
0
    def __init__(
        self,
        representation_size,
        architecture,
        action_dim,
        dynamics_type=None,
        encoder_class=CNN,
        decoder_class=DCNN,
        decoder_output_activation=identity,
        decoder_distribution='bernoulli',
        input_channels=1,
        num_labels=0,
        imsize=48,
        init_w=1e-3,
        min_variance=1e-3,
        hidden_init=nn.init.xavier_uniform_,
        add_labels_to_latents=False,
        reconstruction_channels=3,
        base_depth=32,
        weight_init_gain=1.0,
    ):
        """Initialise the parent class and build two sub-models.

        After the parent constructor runs, a ``CDVAE`` is stored in
        ``self.CDVAE`` and a ``ConvVAE`` in ``self.adversary``. All
        constructor arguments are forwarded positionally.

        ``representation_size`` is indexed with ``[0]`` for the adversary, so
        it is presumably a sequence — TODO confirm against the callers.
        ``action_dim`` and ``dynamics_type`` are only forwarded to ``CDVAE``.
        """
        # NOTE(review): all arguments are positional. This call passes
        # num_labels before input_channels and add_labels_to_latents before
        # hidden_init, whereas the CDVAE call below uses the opposite order
        # for both pairs — confirm each order against the target signature.
        super().__init__(representation_size, architecture, encoder_class,
                         decoder_class, decoder_output_activation,
                         decoder_distribution, num_labels, input_channels,
                         imsize, init_w, min_variance, add_labels_to_latents,
                         hidden_init, reconstruction_channels, base_depth,
                         weight_init_gain)

        # Receives the full representation_size plus the action/dynamics
        # arguments, in the same order as this constructor's own signature.
        self.CDVAE = CDVAE(representation_size, architecture, action_dim,
                           dynamics_type, encoder_class, decoder_class,
                           decoder_output_activation, decoder_distribution,
                           input_channels, num_labels, imsize, init_w,
                           min_variance, hidden_init, add_labels_to_latents,
                           reconstruction_channels, base_depth,
                           weight_init_gain)

        # Built from only the first entry of representation_size; num_labels,
        # add_labels_to_latents, reconstruction_channels, base_depth and
        # weight_init_gain are not forwarded here.
        self.adversary = ConvVAE(representation_size[0], architecture,
                                 encoder_class, decoder_class,
                                 decoder_output_activation,
                                 decoder_distribution, input_channels, imsize,
                                 init_w, min_variance, hidden_init)
def train_vae(variant, other_variant, return_data=False):
    """Train a ConvVAE whose trainer also receives ``other_variant``.

    ``variant`` must provide ``beta``, ``representation_size``,
    ``generate_vae_dataset_kwargs``, ``vae_kwargs``, ``algo_kwargs``,
    ``save_period`` and ``num_epochs``. Optional entries are
    ``generate_vae_data_fctn`` (defaults to ``generate_vae_dataset``),
    ``beta_schedule_kwargs``, ``decoder_activation`` (``'sigmoid'`` or
    anything else for the identity), ``imsize`` and
    ``dump_skew_debug_plots``.

    ``other_variant`` is passed unchanged as the fourth positional argument
    of ``ConvVAETrainer``. ``variant['vae_kwargs']`` is updated in place with
    the resolved ``architecture`` and ``imsize``.

    Returns the trained model, or ``(model, train_data, test_data)`` when
    ``return_data`` is true. The model is also pickled as ``vae.pkl`` via
    ``logger.save_extra_data``.
    """
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    representation_size = variant["representation_size"]

    make_dataset = variant.get('generate_vae_data_fctn', generate_vae_dataset)
    train_data, test_data, info = make_dataset(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    wants_sigmoid = variant.get('decoder_activation', None) == 'sigmoid'
    decoder_activation = torch.nn.Sigmoid() if wants_sigmoid else identity

    # Pick a default architecture from the image size when none is supplied.
    vae_kwargs = variant['vae_kwargs']
    architecture = vae_kwargs.get('architecture', None)
    if not architecture:
        imsize = variant.get('imsize')
        if imsize == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif imsize == 48:
            architecture = conv_vae.imsize48_default_architecture
    vae_kwargs['architecture'] = architecture
    vae_kwargs['imsize'] = variant.get('imsize')

    vae = ConvVAE(
        representation_size,
        decoder_output_activation=decoder_activation,
        **variant['vae_kwargs'],
    )
    vae.to(ptu.device)
    trainer = ConvVAETrainer(
        train_data,
        test_data,
        vae,
        other_variant,
        beta=beta,
        beta_schedule=schedule,
        **variant['algo_kwargs'],
    )

    save_every = variant['save_period']
    # Read from the config but not used anywhere below.
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    total_epochs = variant['num_epochs']
    for epoch in range(total_epochs):
        is_save_epoch = (epoch % save_every == 0)
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch, save_reconstruction=is_save_epoch)
        if is_save_epoch:
            trainer.dump_samples(epoch)
        trainer.update_train_weights()

    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    # Disabled alternative kept for reference: saving the whole model with
    # torch.save under other_variant['vae_pkl_path'] + '/online_vae.pkl'.
    return (vae, train_data, test_data) if return_data else vae
Beispiel #8
0
def train_vae(cfgs, return_data=False):
    """Train a ConvVAE configured through the attribute-style ``cfgs`` object.

    The function reads ``cfgs.VAE`` (model settings), ``cfgs.ENV`` (only its
    ``img_size``, for choosing a default architecture) and
    ``cfgs.VAE_TRAINER`` (``lr``, ``beta``, ``save_period``, ``num_epochs``
    and the optional ``beta_schedule_kwargs``). ``cfgs`` itself is passed to
    ``generate_vae_dataset``.

    Returns the trained model, or ``(model, train_data, test_data)`` when
    ``return_data`` is true. The model is also pickled as ``vae.pkl`` via
    ``logger.save_extra_data``.
    """
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    train_data, test_data, info = generate_vae_dataset(cfgs)
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    vae_cfg = cfgs.VAE

    # FIXME (carried over from the original): "default gaussian".
    wants_sigmoid = vae_cfg.get('decoder_activation', None) == 'sigmoid'
    decoder_activation = torch.nn.Sigmoid() if wants_sigmoid else identity

    # NOTE(review): the default architecture is chosen from ENV.img_size
    # while the model below is sized from VAE.img_size — confirm that the two
    # settings are always meant to agree.
    architecture = vae_cfg.get('architecture', None)
    if not architecture:
        if cfgs.ENV.get('img_size') == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif cfgs.ENV.get('img_size') == 48:
            architecture = conv_vae.imsize48_default_architecture

    vae_model = ConvVAE(
        representation_size=vae_cfg.representation_size,
        architecture=architecture,
        decoder_output_activation=decoder_activation,
        input_channels=vae_cfg.input_channels,
        decoder_distribution=vae_cfg.decoder_distribution,
        imsize=vae_cfg.img_size,
    )
    vae_model.to(ptu.device)

    trainer_cfg = cfgs.VAE_TRAINER

    # FIXME (carried over from the original): clarify what beta_schedule does.
    schedule = None
    if 'beta_schedule_kwargs' in trainer_cfg:
        schedule = PiecewiseLinearSchedule(**trainer_cfg.beta_schedule_kwargs)

    trainer = ConvVAETrainer(
        train_data,
        test_data,
        vae_model,
        lr=trainer_cfg.lr,
        beta=trainer_cfg.beta,
        beta_schedule=schedule,
    )

    save_every = trainer_cfg.save_period
    total_epochs = trainer_cfg.num_epochs
    for epoch in range(total_epochs):
        is_save_epoch = (epoch % save_every == 0)
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch, save_reconstruction=is_save_epoch)
        if is_save_epoch:
            trainer.dump_samples(epoch)
        trainer.update_train_weights()

    logger.save_extra_data(vae_model, 'vae.pkl', mode='pickle')
    return (vae_model, train_data, test_data) if return_data else vae_model
def train_vae(variant, return_data=False):
    """Train a ConvVAE described by ``variant`` and pickle it as ``vae.pkl``.

    ``variant`` must provide ``beta``, ``representation_size``,
    ``generate_vae_dataset_kwargs``, ``vae_kwargs``, ``algo_kwargs``,
    ``save_period`` and ``num_epochs``. Optional entries are
    ``generate_vae_data_fctn`` (defaults to ``generate_vae_dataset``),
    ``beta_schedule_kwargs``, ``decoder_activation`` (``"sigmoid"`` or
    anything else for the identity), ``imsize`` and
    ``dump_skew_debug_plots``.

    ``variant["vae_kwargs"]`` is updated in place with the resolved
    ``architecture`` and ``imsize``.

    Returns the trained model, or ``(model, train_data, test_data)`` when
    ``return_data`` is true.
    """
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    representation_size = variant["representation_size"]

    make_dataset = variant.get("generate_vae_data_fctn", generate_vae_dataset)
    train_data, test_data, info = make_dataset(
        variant["generate_vae_dataset_kwargs"])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    schedule = None
    if "beta_schedule_kwargs" in variant:
        schedule = PiecewiseLinearSchedule(**variant["beta_schedule_kwargs"])

    wants_sigmoid = variant.get("decoder_activation", None) == "sigmoid"
    decoder_activation = torch.nn.Sigmoid() if wants_sigmoid else identity

    # Fall back to a size-specific default architecture when none is given.
    vae_kwargs = variant["vae_kwargs"]
    architecture = vae_kwargs.get("architecture", None)
    if not architecture:
        imsize = variant.get("imsize")
        if imsize == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif imsize == 48:
            architecture = conv_vae.imsize48_default_architecture
    vae_kwargs["architecture"] = architecture
    vae_kwargs["imsize"] = variant.get("imsize")

    vae = ConvVAE(
        representation_size,
        decoder_output_activation=decoder_activation,
        **variant["vae_kwargs"],
    )
    vae.to(ptu.device)
    trainer = ConvVAETrainer(
        train_data,
        test_data,
        vae,
        beta=beta,
        beta_schedule=schedule,
        **variant["algo_kwargs"],
    )

    save_every = variant["save_period"]
    # Read from the config but not used anywhere below.
    dump_skew_debug_plots = variant.get("dump_skew_debug_plots", False)
    total_epochs = variant["num_epochs"]
    for epoch in range(total_epochs):
        is_save_epoch = epoch % save_every == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch, save_reconstruction=is_save_epoch)
        if is_save_epoch:
            trainer.dump_samples(epoch)
        trainer.update_train_weights()

    logger.save_extra_data(vae, "vae.pkl", mode="pickle")
    return (vae, train_data, test_data) if return_data else vae
def _load_lstm_eval_data(env_id, train_data_lstm, test_data_lstm, info_lstm):
    """Return ``(images, object_states)`` used for the LSTM diagnostic plots.

    For ``SawyerPushHurdle-v0`` a pre-generated dataset under
    ``$PJHOME/data/local/pre-train-lstm`` is used when its image file exists;
    otherwise the LSTM train and test images are concatenated and the object
    states are taken from ``info_lstm['obj_state']``. The images are passed
    through ``normalize_image`` in both cases.

    Raises:
        KeyError: If the ``PJHOME`` environment variable is not set.
    """
    data_dir = osp.join(os.environ['PJHOME'], 'data/local/pre-train-lstm')
    data_file_path = osp.join(
        data_dir, '{}-{}-{}-0.3-0.5.npy'.format(env_id, 'seg-color', 500))
    if env_id == 'SawyerPushHurdle-v0' and osp.exists(data_file_path):
        puck_pos_path = osp.join(
            data_dir,
            '{}-{}-{}-0.3-0.5-puck-pos.npy'.format(env_id, 'seg-color', 500))
        all_data = np.load(data_file_path)
        obj_states = np.load(puck_pos_path)
    else:
        all_data = np.concatenate([train_data_lstm, test_data_lstm], axis=0)
        obj_states = info_lstm['obj_state']
    return normalize_image(all_data.copy()), obj_states


def train_vae(
        variant,
        return_data=False,
        skewfit_variant=None):  # actually trains both the VAE and the LSTM
    """Pre-train the segmentation LSTM and then the original-image VAE.

    Two models are built in order: first the LSTM on the segmented dataset
    (named ``lstm_seg_pretrain``), then a ConvVAE on the original dataset
    (named ``vae_ori_pretrain``). Each is only trained when its flag
    (``variant['seg_pretrain']`` / ``variant['ori_pretrain']``) is truthy,
    and a trained model is pickled as ``<name>.pkl`` through the logger.

    Args:
        variant: Training configuration. Required keys are ``seg_pretrain``,
            ``ori_pretrain``, ``generate_lstm_data_fctn``,
            ``generate_lstm_dataset_kwargs`` (with an ``env_id``),
            ``generate_vae_dataset_kwargs``, ``vae_kwargs``, ``lstm_kwargs``,
            ``algo_kwargs``, ``save_period``, ``num_lstm_epochs`` and
            ``num_vae_epochs``. ``vae_kwargs`` and ``lstm_kwargs`` are
            updated in place with the resolved ``architecture`` and
            ``imsize``.
        return_data: When true, return all models and datasets.
        skewfit_variant: Mapping that must contain ``segmentation_method``.
            Despite the ``None`` default it is required.

    Returns:
        ``(models, train_datas, test_datas)`` (lists ordered LSTM, VAE) when
        ``return_data`` is true, otherwise the model built last (the VAE).

    Raises:
        TypeError: If ``skewfit_variant`` is not supplied.
        ValueError: If ``variant['lstm_version']`` is not a supported value.
        SystemExit: After the LSTM stage when ``variant['only_train_lstm']``
            is truthy.
    """
    import sys
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import (
        ConvVAE, )
    import rlkit.torch.vae.conv_vae as conv_vae
    import ROLL.LSTM_model as LSTM_model
    from ROLL.LSTM_model import ConvLSTM2
    from ROLL.LSTM_trainer import ConvLSTMTrainer
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch
    seg_pretrain = variant['seg_pretrain']
    ori_pretrain = variant['ori_pretrain']
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    generate_lstm_dataset_fctn = variant.get('generate_lstm_data_fctn')
    assert generate_lstm_dataset_fctn is not None, "Must provide a custom generate lstm pretraining dataset function!"

    # The None default cannot be subscripted below; fail with a clear message
    # instead of "'NoneType' object is not subscriptable".
    if skewfit_variant is None:
        raise TypeError(
            "train_vae() requires 'skewfit_variant' providing a "
            "'segmentation_method' entry")

    train_data_lstm, test_data_lstm, info_lstm = generate_lstm_dataset_fctn(
        variant['generate_lstm_dataset_kwargs'],
        segmented=True,
        segmentation_method=skewfit_variant['segmentation_method'])

    train_data_ori, test_data_ori, info_ori = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])

    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    architecture = variant['lstm_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = None  # TODO LSTM: wrap an 84 lstm architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = LSTM_model.imsize48_default_architecture
    variant['lstm_kwargs']['architecture'] = architecture
    variant['lstm_kwargs']['imsize'] = variant.get('imsize')

    train_datas = [
        train_data_lstm,
        train_data_ori,
    ]
    test_datas = [
        test_data_lstm,
        test_data_ori,
    ]
    names = [
        'lstm_seg_pretrain',
        'vae_ori_pretrain',
    ]
    vaes = []
    env_id = variant['generate_lstm_dataset_kwargs'].get('env_id')
    assert env_id is not None
    lstm_pretrain_vae_only = variant.get('lstm_pretrain_vae_only', False)

    # Index 0 is the segmentation LSTM, index 1 the original-image VAE.
    for idx, (train_data, test_data, name) in enumerate(
            zip(train_datas, test_datas, names)):
        is_lstm = (idx == 0)

        logger.add_tabular_output('{}_progress.csv'.format(name),
                                  relative_to_snapshot_dir=True)

        if is_lstm:  # train the segmentation lstm
            lstm_version = variant.get('lstm_version', 2)
            if lstm_version != 2:
                # Previously this fell through to an unbound `lstm_class`.
                raise ValueError(
                    "Unsupported lstm_version {!r}; only version 2 is "
                    "available".format(lstm_version))
            lstm_class = ConvLSTM2

            representation_size = variant.get(
                "lstm_representation_size", variant.get('representation_size'))
            beta = variant.get('lstm_beta', variant.get('beta'))
            m = lstm_class(representation_size,
                           decoder_output_activation=decoder_activation,
                           **variant['lstm_kwargs'])
            t = ConvLSTMTrainer(train_data,
                                test_data,
                                m,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
        else:  # train the original vae
            representation_size = variant.get(
                "vae_representation_size", variant.get('representation_size'))
            beta = variant.get('vae_beta', variant.get('beta'))
            m = ConvVAE(representation_size,
                        decoder_output_activation=decoder_activation,
                        **variant['vae_kwargs'])
            t = ConvVAETrainer(train_data,
                               test_data,
                               m,
                               beta=beta,
                               beta_schedule=beta_schedule,
                               **variant['algo_kwargs'])

        m.to(ptu.device)

        vaes.append(m)

        print("test data len: ", len(test_data))
        print("train data len: ", len(train_data))

        save_period = variant['save_period']

        if is_lstm:
            # Only the LSTM diagnostics use this data, so it is loaded once
            # here rather than again on the VAE iteration.
            all_data, obj_states = _load_lstm_eval_data(
                env_id, train_data_lstm, test_data_lstm, info_lstm)
            obj = 'door' if 'Door' in env_id else 'puck'

        num_epochs = variant['num_lstm_epochs'] if is_lstm else variant[
            'num_vae_epochs']

        if (is_lstm and seg_pretrain) or (not is_lstm and ori_pretrain):
            for epoch in range(num_epochs):
                should_save_imgs = (epoch % save_period == 0)
                if is_lstm:  # only LSTM trainer has 'only_train_vae' argument
                    t.train_epoch(epoch, only_train_vae=lstm_pretrain_vae_only)
                    t.test_epoch(epoch,
                                 save_reconstruction=should_save_imgs,
                                 save_prefix='r_' + name,
                                 only_train_vae=lstm_pretrain_vae_only)
                else:
                    t.train_epoch(epoch)
                    t.test_epoch(
                        epoch,
                        save_reconstruction=should_save_imgs,
                        save_prefix='r_' + name,
                    )

                if should_save_imgs:
                    t.dump_samples(epoch, save_prefix='s_' + name)

                    if is_lstm:
                        compare_latent_distance(
                            m,
                            all_data,
                            obj_states,
                            obj_name=obj,
                            save_dir=logger.get_snapshot_dir(),
                            save_name='lstm_latent_distance_{}.png'.format(
                                epoch))
                        test_lstm_traj(
                            env_id,
                            m,
                            save_path=logger.get_snapshot_dir(),
                            save_name='lstm_test_traj_{}.png'.format(epoch))
                        test_masked_traj_lstm(
                            env_id,
                            m,
                            save_dir=logger.get_snapshot_dir(),
                            save_name='masked_test_{}.png'.format(epoch))

                t.update_train_weights()

            logger.save_extra_data(m, '{}.pkl'.format(name), mode='pickle')

        logger.remove_tabular_output('{}_progress.csv'.format(name),
                                     relative_to_snapshot_dir=True)

        if is_lstm and variant.get("only_train_lstm", False):
            # sys.exit() instead of the interactive-only `exit()` helper,
            # which the `site` module may not install.
            sys.exit()

    if return_data:
        return vaes, train_datas, test_datas
    # Without return_data only the model built last (the VAE) is returned.
    return m
Beispiel #11
0
def _pointmass_fixed_goal_experiment(
        vae_latent_size,
        replay_buffer_size,
        cnn_kwargs,
        vae_kwargs,
        policy_kwargs,
        qf_kwargs,
        e2e_trainer_kwargs,
        sac_trainer_kwargs,
        algorithm_kwargs,
        eval_path_collector_kwargs=None,
        expl_path_collector_kwargs=None,
        **kwargs
):
    """Run end-to-end SAC + VAE training on a fixed-goal 2D point-mass task.

    A ``Point2DEnv`` with a fixed goal at ``(0, 0)`` is wrapped as an image
    environment and then flattened. A ``ConvVAE`` (wrapped in
    ``Vae2Encoder``) feeds the four Q-functions, while the policy uses its
    own CNN. The SAC trainer and a ``VAETrainer`` are combined in an
    ``End2EndSACTrainer`` that is driven by ``TorchBatchRLAlgorithm``.

    The ``*_kwargs`` arguments are forwarded to the corresponding
    constructors; the two path-collector kwargs default to empty dicts. Any
    extra ``**kwargs`` are accepted and ignored. Nothing is returned.
    """
    expl_path_collector_kwargs = (
        {} if expl_path_collector_kwargs is None
        else expl_path_collector_kwargs
    )
    eval_path_collector_kwargs = (
        {} if eval_path_collector_kwargs is None
        else eval_path_collector_kwargs
    )
    from multiworld.core.image_env import ImageEnv
    from multiworld.envs.pygame.point2d import Point2DEnv
    from multiworld.core.flat_goal_env import FlatGoalEnv

    # Environment stack: point mass -> image observations -> flat goal env.
    point_env = Point2DEnv(
        images_are_rgb=True,
        render_onscreen=False,
        show_goal=False,
        ball_radius=2,
        render_size=48,
        fixed_goal=(0, 0),
    )
    image_env = ImageEnv(
        point_env,
        imsize=point_env.render_size,
        transpose=True,
        normalize=True,
    )
    env = FlatGoalEnv(image_env)
    input_width, input_height = env.image_shape
    action_dim = int(np.prod(env.action_space.shape))

    # Networks are created in a fixed order (VAE, Q-functions, policy).
    vae = ConvVAE(
        representation_size=vae_latent_size,
        input_channels=3,
        imsize=input_width,
        decoder_output_activation=nn.Sigmoid(),
        **vae_kwargs)
    encoder = Vae2Encoder(vae)

    def make_qf():
        # Every Q-function shares the same VAE-backed encoder instance.
        obs_processor = nn.Sequential(
            encoder,
            networks.Flatten(),
        )
        return networks.MlpQfWithObsProcessor(
            obs_processor=obs_processor,
            output_size=1,
            input_size=action_dim+vae_latent_size,
            **qf_kwargs
        )

    qf1, qf2, target_qf1, target_qf2 = (make_qf() for _ in range(4))

    policy_cnn = networks.CNN(
        input_width=input_width,
        input_height=input_height,
        input_channels=3,
        output_conv_channels=True,
        output_size=None,
        **cnn_kwargs
    )
    policy_obs_processor = nn.Sequential(policy_cnn, networks.Flatten())
    policy = TanhGaussianPolicyAdapter(
        policy_obs_processor,
        policy_cnn.conv_output_flat_size,
        action_dim,
        **policy_kwargs
    )

    # Exploration and evaluation use the very same environment object.
    eval_env = expl_env = env

    eval_policy = MakeDeterministic(policy)
    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
        **eval_path_collector_kwargs
    )
    replay_buffer = EnvReplayBuffer(replay_buffer_size, expl_env)

    vae_trainer = VAETrainer(vae)
    sac_trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs
    )
    trainer = End2EndSACTrainer(
        sac_trainer=sac_trainer,
        vae_trainer=vae_trainer,
        **e2e_trainer_kwargs,
    )
    expl_path_collector = MdpPathCollector(
        expl_env,
        policy,
        **expl_path_collector_kwargs
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **algorithm_kwargs
    )
    algorithm.to(ptu.device)
    algorithm.train()