def train_vae(variant, other_variant, return_data=False):
    """Build a ``ConvVAE`` from ``variant``, train it, and save it.

    Args:
        variant: Dict-like experiment configuration. Required keys read here:
            ``beta``, ``representation_size``, ``generate_vae_dataset_kwargs``,
            ``vae_kwargs``, ``algo_kwargs``, ``save_period`` and
            ``num_epochs``. Optional keys: ``generate_vae_data_fctn``,
            ``beta_schedule_kwargs``, ``decoder_activation`` and ``imsize``.
            ``variant['vae_kwargs']`` is updated in place with the resolved
            ``architecture`` and ``imsize`` entries.
        other_variant: Forwarded unchanged as the fourth positional argument
            of ``ConvVAETrainer``.
        return_data: When true, the train and test data are returned too.

    Returns:
        The trained model, or ``(model, train_data, test_data)`` when
        ``return_data`` is true.
    """
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    representation_size = variant["representation_size"]

    # The dataset generator is pluggable; fall back to the file-level default.
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    train_data, test_data, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    # The return value is unused (as in the original code); the call is kept
    # in case it has side effects -- TODO confirm and drop if it does not.
    logger.get_snapshot_dir()

    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    # Pick a default architecture only when none (or an empty one) was given
    # and the image size is one of the two sizes that have a default.
    vae_kwargs = variant['vae_kwargs']
    imsize = variant.get('imsize')
    architecture = vae_kwargs.get('architecture', None)
    if not architecture:
        if imsize == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif imsize == 48:
            architecture = conv_vae.imsize48_default_architecture
    vae_kwargs['architecture'] = architecture
    vae_kwargs['imsize'] = imsize

    model = ConvVAE(representation_size,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)
    model.to(ptu.device)
    trainer = ConvVAETrainer(train_data,
                             test_data,
                             model,
                             other_variant,
                             beta=beta,
                             beta_schedule=beta_schedule,
                             **variant['algo_kwargs'])

    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        # Reconstructions and samples are only dumped every `save_period`
        # epochs (epoch 0 included).
        should_save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
            # save_vae=False,
        )
        if should_save_imgs:
            trainer.dump_samples(epoch)
        trainer.update_train_weights()

    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    # torch.save(model, other_variant['vae_pkl_path']+'/online_vae.pkl') # easy way: load model for via bonus
    if return_data:
        return model, train_data, test_data
    return model
# Example #2
# 0
# File: common.py  Project: anair13/rlkit
def train_vae(variant, return_data=False):
    """Build, train and save a VAE-style model described by ``variant``.

    Args:
        variant: Dict-like experiment configuration.
            Required keys: ``beta``, ``num_epochs``, ``save_period``,
            ``algo_kwargs`` (must contain ``batch_size``),
            ``generate_vae_dataset_kwargs`` and ``vae_kwargs``.
            Optional keys: ``representation_size`` (falling back to
            ``latent_sizes``, then ``embedding_dim``, then ``None``),
            ``use_linear_dynamics``, ``generate_vae_data_fctn``,
            ``beta_schedule_kwargs``, ``context_schedule``,
            ``decoder_activation``, ``imsize``,
            ``use_spatial_auto_encoder``, ``only_kwargs``, ``vae_class``,
            ``vae_trainer_class`` and ``dump_skew_debug_plots``.
            The nested dicts are modified in place: ``algo_kwargs`` gains
            ``num_epochs`` (and ``context_schedule`` when configured),
            ``generate_vae_dataset_kwargs`` gains ``use_linear_dynamics`` and
            ``batch_size``, and ``vae_kwargs`` gains ``architecture`` and
            ``imsize``.
        return_data: When true, the train and test datasets are returned too.

    Returns:
        The trained model, or ``(model, train_dataset, test_dataset)`` when
        ``return_data`` is true.
    """
    from rlkit.util.ml_util import PiecewiseLinearSchedule, ConstantSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    representation_size = variant.get(
        "representation_size",
        variant.get("latent_sizes", variant.get("embedding_dim", None)))
    use_linear_dynamics = variant.get('use_linear_dynamics', False)

    # Propagate top-level settings into the kwargs handed to the trainer and
    # to the dataset generator.
    variant['algo_kwargs']['num_epochs'] = variant['num_epochs']
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    dataset_kwargs = variant['generate_vae_dataset_kwargs']
    dataset_kwargs['use_linear_dynamics'] = use_linear_dynamics
    dataset_kwargs['batch_size'] = variant['algo_kwargs']['batch_size']
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        dataset_kwargs)

    # Extra constructor kwargs that only apply to the linear-dynamics model.
    dynamics_kwargs = {}
    if use_linear_dynamics:
        dynamics_kwargs['action_dim'] = (
            train_dataset.data['actions'].shape[2])

    logger.save_extra_data(info)
    # The return value is unused (as in the original code); the call is kept
    # in case it has side effects -- TODO confirm and drop if it does not.
    logger.get_snapshot_dir()

    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    if 'context_schedule' in variant:
        schedule = variant['context_schedule']
        # A dict holds PiecewiseLinearSchedule kwargs; anything else is used
        # as a constant value.
        if isinstance(schedule, dict):
            context_schedule = PiecewiseLinearSchedule(**schedule)
        else:
            context_schedule = ConstantSchedule(schedule)
        variant['algo_kwargs']['context_schedule'] = context_schedule

    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    # Pick a default architecture only when none (or an empty one) was given
    # and the image size is one of the two sizes that have a default.
    vae_kwargs = variant['vae_kwargs']
    imsize = variant.get('imsize')
    architecture = vae_kwargs.get('architecture', None)
    if not architecture:
        if imsize == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif imsize == 48:
            architecture = conv_vae.imsize48_default_architecture
    vae_kwargs['architecture'] = architecture
    vae_kwargs['imsize'] = imsize

    if variant['algo_kwargs'].get('is_auto_encoder', False):
        # Resolved lazily from conv_vae (the module the formerly
        # commented-out import referred to) so the common ConvVAE path does
        # not require this class to exist.
        model = conv_vae.AutoEncoder(
            representation_size,
            decoder_output_activation=decoder_activation,
            **vae_kwargs)
    elif variant.get('use_spatial_auto_encoder', False):
        # Same lazy lookup as for AutoEncoder above.
        model = conv_vae.SpatialAutoEncoder(
            representation_size,
            decoder_output_activation=decoder_activation,
            **vae_kwargs)
    elif variant.get('only_kwargs', False):
        # The model class takes everything it needs from vae_kwargs.
        vae_class = variant.get('vae_class', ConvVAE)
        model = vae_class(**vae_kwargs)
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        model = vae_class(representation_size,
                          decoder_output_activation=decoder_activation,
                          **dynamics_kwargs,
                          **vae_kwargs)

    model.to(ptu.device)

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)

    for epoch in range(variant['num_epochs']):
        # Images are only dumped every `save_period` epochs (epoch 0
        # included).
        should_save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        for key, value in trainer.get_diagnostics().items():
            logger.record_tabular(key, value)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        # Periodic parameter snapshot; the interval is fixed at 50 epochs.
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'model', mode='pickle')

    if return_data:
        return model, train_dataset, test_dataset

    return model
def train_vae(variant, return_data=False):
    """Train a ``ConvVAE`` configured by ``variant`` and save it.

    Args:
        variant: Dict-like experiment configuration. Required keys read here:
            ``beta``, ``representation_size``, ``generate_vae_dataset_kwargs``,
            ``vae_kwargs``, ``algo_kwargs``, ``save_period`` and
            ``num_epochs``. Optional keys: ``generate_vae_data_fctn``,
            ``beta_schedule_kwargs``, ``decoder_activation`` and ``imsize``.
            ``variant["vae_kwargs"]`` is updated in place with the resolved
            ``architecture`` and ``imsize`` entries.
        return_data: When true, the train and test data are returned too.

    Returns:
        The trained model, or ``(model, train_data, test_data)`` when
        ``return_data`` is true.
    """
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    representation_size = variant["representation_size"]

    # The dataset generator is pluggable; fall back to the file-level default.
    generate_vae_dataset_fctn = variant.get("generate_vae_data_fctn",
                                            generate_vae_dataset)
    train_data, test_data, info = generate_vae_dataset_fctn(
        variant["generate_vae_dataset_kwargs"])
    logger.save_extra_data(info)
    # The return value is unused (as in the original code); the call is kept
    # in case it has side effects -- TODO confirm and drop if it does not.
    logger.get_snapshot_dir()

    beta_schedule = None
    if "beta_schedule_kwargs" in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant["beta_schedule_kwargs"])

    if variant.get("decoder_activation", None) == "sigmoid":
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    # A default architecture is substituted only when the configured one is
    # missing/empty and the image size is 84 or 48.
    vae_kwargs = variant["vae_kwargs"]
    imsize = variant.get("imsize")
    architecture = vae_kwargs.get("architecture", None)
    if not architecture:
        if imsize == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif imsize == 48:
            architecture = conv_vae.imsize48_default_architecture
    vae_kwargs["architecture"] = architecture
    vae_kwargs["imsize"] = imsize

    model = ConvVAE(representation_size,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)
    model.to(ptu.device)
    trainer = ConvVAETrainer(train_data,
                             test_data,
                             model,
                             beta=beta,
                             beta_schedule=beta_schedule,
                             **variant["algo_kwargs"])

    save_period = variant["save_period"]
    for epoch in range(variant["num_epochs"]):
        # Reconstructions and samples are only dumped every `save_period`
        # epochs (epoch 0 included).
        should_save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
            # save_vae=False,
        )
        if should_save_imgs:
            trainer.dump_samples(epoch)
        trainer.update_train_weights()

    logger.save_extra_data(model, "vae.pkl", mode="pickle")
    if return_data:
        return model, train_data, test_data
    return model
# Example #4
# 0
def train_vae(cfgs, return_data=False):
    """Train a ``ConvVAE`` from an attribute-style config and save it.

    Args:
        cfgs: Configuration object. This function reads ``cfgs.VAE``
            (``decoder_activation``, ``architecture``,
            ``representation_size``, ``input_channels``,
            ``decoder_distribution``, ``img_size``), ``cfgs.ENV``
            (``img_size``) and ``cfgs.VAE_TRAINER`` (``beta_schedule_kwargs``,
            ``lr``, ``beta``, ``save_period``, ``num_epochs``). The whole
            object is also handed to ``generate_vae_dataset``.
        return_data: When true, the train and test data are returned too.

    Returns:
        The trained model, or ``(model, train_data, test_data)`` when
        ``return_data`` is true.
    """
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    train_data, test_data, info = generate_vae_dataset(cfgs)
    logger.save_extra_data(info)
    # Return value is discarded; kept exactly as the original call.
    logger.get_snapshot_dir()

    # FIXME default gaussian
    wants_sigmoid = cfgs.VAE.get('decoder_activation', None) == 'sigmoid'
    decoder_activation = torch.nn.Sigmoid() if wants_sigmoid else identity

    # Fall back to a built-in architecture only when none is configured and
    # the environment image size is 84 or 48.
    # NOTE(review): the default is chosen from cfgs.ENV.img_size while the
    # model below receives cfgs.VAE.img_size -- confirm these always agree.
    architecture = cfgs.VAE.get('architecture', None)
    if not architecture:
        env_img_size = cfgs.ENV.get('img_size')
        if env_img_size == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif env_img_size == 48:
            architecture = conv_vae.imsize48_default_architecture

    model = ConvVAE(
        representation_size=cfgs.VAE.representation_size,
        architecture=architecture,
        decoder_output_activation=decoder_activation,
        input_channels=cfgs.VAE.input_channels,
        decoder_distribution=cfgs.VAE.decoder_distribution,
        imsize=cfgs.VAE.img_size,
    )
    model.to(ptu.device)

    # FIXME the function of beta_schedule?
    beta_schedule = None
    if 'beta_schedule_kwargs' in cfgs.VAE_TRAINER:
        beta_schedule = PiecewiseLinearSchedule(
            **cfgs.VAE_TRAINER.beta_schedule_kwargs)

    trainer = ConvVAETrainer(train_data,
                             test_data,
                             model,
                             lr=cfgs.VAE_TRAINER.lr,
                             beta=cfgs.VAE_TRAINER.beta,
                             beta_schedule=beta_schedule)

    period = cfgs.VAE_TRAINER.save_period
    for epoch in range(cfgs.VAE_TRAINER.num_epochs):
        # Images are written on every `period`-th epoch, starting at 0.
        is_save_epoch = epoch % period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=is_save_epoch,
            # save_vae=False,
        )
        if is_save_epoch:
            trainer.dump_samples(epoch)
        trainer.update_train_weights()

    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if not return_data:
        return model
    return model, train_data, test_data
def _load_lstm_diagnostic_data(env_id, train_data_lstm, test_data_lstm,
                               info_lstm):
    """Return ``(images, obj_states)`` for the LSTM latent-distance plots.

    For ``SawyerPushHurdle-v0`` a pre-generated image file under
    ``$PJHOME/data/local/pre-train-lstm`` is used when it exists (together
    with its companion puck-position file); otherwise the LSTM train and test
    data are concatenated and ``info_lstm['obj_state']`` supplies the states.
    The images are passed through ``normalize_image`` in both cases.

    Raises:
        KeyError: If the ``PJHOME`` environment variable is not set.
    """
    data_dir = osp.join(os.environ['PJHOME'], 'data/local/pre-train-lstm')
    cached_images_path = osp.join(
        data_dir, '{}-{}-{}-0.3-0.5.npy'.format(env_id, 'seg-color', 500))
    if env_id == 'SawyerPushHurdle-v0' and osp.exists(cached_images_path):
        cached_puck_pos_path = osp.join(
            data_dir,
            '{}-{}-{}-0.3-0.5-puck-pos.npy'.format(env_id, 'seg-color', 500))
        images = np.load(cached_images_path)
        obj_states = np.load(cached_puck_pos_path)
    else:
        images = np.concatenate([train_data_lstm, test_data_lstm], axis=0)
        obj_states = info_lstm['obj_state']
    return normalize_image(images.copy()), obj_states


def train_vae(
        variant,
        return_data=False,
        skewfit_variant=None):  # actually trains both the VAE and the LSTM
    """Pre-train the segmentation LSTM and the original VAE, in that order.

    Args:
        variant: Dict-like experiment configuration.
            Required keys: ``seg_pretrain``, ``ori_pretrain``,
            ``generate_lstm_data_fctn``, ``generate_lstm_dataset_kwargs``
            (must contain ``env_id``), ``generate_vae_dataset_kwargs``,
            ``vae_kwargs``, ``lstm_kwargs``, ``algo_kwargs``, ``save_period``,
            ``num_lstm_epochs`` and ``num_vae_epochs``.
            Optional keys: ``generate_vae_data_fctn``,
            ``beta_schedule_kwargs``, ``decoder_activation``, ``imsize``,
            ``lstm_pretrain_vae_only``, ``lstm_version``,
            ``representation_size`` / ``vae_representation_size`` /
            ``lstm_representation_size``, ``beta`` / ``vae_beta`` /
            ``lstm_beta`` and ``only_train_lstm``.
            ``variant['vae_kwargs']`` and ``variant['lstm_kwargs']`` are
            updated in place with ``architecture`` and ``imsize``.
        return_data: When true, return the models together with the datasets.
        skewfit_variant: Mapping that must contain ``segmentation_method``;
            despite the ``None`` default it is required.

    Returns:
        ``([lstm, vae], [lstm_train, vae_train], [lstm_test, vae_test])`` when
        ``return_data`` is true, otherwise only the model built last (the
        original VAE).

    Raises:
        AssertionError: If ``generate_lstm_data_fctn`` or ``env_id`` is
            missing.
        TypeError: If ``skewfit_variant`` is ``None``.
        ValueError: If ``lstm_version`` is not a supported version.
        KeyError: If LSTM pre-training is enabled and the ``PJHOME``
            environment variable is not set.
        SystemExit: After the LSTM stage when ``only_train_lstm`` is set.
    """
    import sys
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import (
        ConvVAE, )
    import rlkit.torch.vae.conv_vae as conv_vae
    import ROLL.LSTM_model as LSTM_model
    from ROLL.LSTM_model import ConvLSTM2
    from ROLL.LSTM_trainer import ConvLSTMTrainer
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    seg_pretrain = variant['seg_pretrain']
    ori_pretrain = variant['ori_pretrain']
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    generate_lstm_dataset_fctn = variant.get('generate_lstm_data_fctn')
    assert generate_lstm_dataset_fctn is not None, "Must provide a custom generate lstm pretraining dataset function!"
    if skewfit_variant is None:
        # Same exception type as subscripting None, but with a usable message.
        raise TypeError(
            "skewfit_variant is required: its 'segmentation_method' entry is "
            "passed to the LSTM dataset generator")

    train_data_lstm, test_data_lstm, info_lstm = generate_lstm_dataset_fctn(
        variant['generate_lstm_dataset_kwargs'],
        segmented=True,
        segmentation_method=skewfit_variant['segmentation_method'])

    train_data_ori, test_data_ori, info_ori = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])

    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    imsize = variant.get('imsize')

    # VAE architecture: substitute a default only when none (or an empty one)
    # is configured and the image size has a default.
    vae_architecture = variant['vae_kwargs'].get('architecture', None)
    if not vae_architecture:
        if imsize == 84:
            vae_architecture = conv_vae.imsize84_default_architecture
        elif imsize == 48:
            vae_architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = vae_architecture
    variant['vae_kwargs']['imsize'] = imsize

    # LSTM architecture: same rule, but only the 48-pixel default exists.
    lstm_architecture = variant['lstm_kwargs'].get('architecture', None)
    if not lstm_architecture:
        if imsize == 84:
            lstm_architecture = None  # TODO LSTM: wrap an 84 lstm architecture
        elif imsize == 48:
            lstm_architecture = LSTM_model.imsize48_default_architecture
    variant['lstm_kwargs']['architecture'] = lstm_architecture
    variant['lstm_kwargs']['imsize'] = imsize

    # Stage 0 is the segmentation LSTM, stage 1 the original VAE.
    train_datas = [
        train_data_lstm,
        train_data_ori,
    ]
    test_datas = [
        test_data_lstm,
        test_data_ori,
    ]
    names = [
        'lstm_seg_pretrain',
        'vae_ori_pretrain',
    ]
    vaes = []
    env_id = variant['generate_lstm_dataset_kwargs'].get('env_id')
    assert env_id is not None
    lstm_pretrain_vae_only = variant.get('lstm_pretrain_vae_only', False)
    save_period = variant['save_period']
    obj = 'door' if 'Door' in env_id else 'puck'

    for idx, (train_data, test_data, name) in enumerate(
            zip(train_datas, test_datas, names)):
        logger.add_tabular_output('{}_progress.csv'.format(name),
                                  relative_to_snapshot_dir=True)

        if idx == 1:  # train the original vae
            representation_size = variant.get(
                "vae_representation_size", variant.get('representation_size'))
            beta = variant.get('vae_beta', variant.get('beta'))
            m = ConvVAE(representation_size,
                        decoder_output_activation=decoder_activation,
                        **variant['vae_kwargs'])
            t = ConvVAETrainer(train_data,
                               test_data,
                               m,
                               beta=beta,
                               beta_schedule=beta_schedule,
                               **variant['algo_kwargs'])
        else:  # train the segmentation lstm
            lstm_version = variant.get('lstm_version', 2)
            if lstm_version == 2:
                lstm_class = ConvLSTM2
            else:
                # Previously this fell through and failed later with an
                # UnboundLocalError on `lstm_class`.
                raise ValueError(
                    'Unsupported lstm_version: {!r} (only 2 is available)'
                    .format(lstm_version))

            representation_size = variant.get(
                "lstm_representation_size", variant.get('representation_size'))
            beta = variant.get('lstm_beta', variant.get('beta'))
            m = lstm_class(representation_size,
                           decoder_output_activation=decoder_activation,
                           **variant['lstm_kwargs'])
            t = ConvLSTMTrainer(train_data,
                                test_data,
                                m,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])

        m.to(ptu.device)

        vaes.append(m)

        print("test data len: ", len(test_data))
        print("train data len: ", len(train_data))

        num_epochs = variant['num_lstm_epochs'] if idx == 0 else variant[
            'num_vae_epochs']
        should_pretrain = seg_pretrain if idx == 0 else ori_pretrain

        if should_pretrain:
            if idx == 0:
                # Only the LSTM stage draws the diagnostic plots, so the
                # (potentially large) image data is loaded just once, here.
                all_data, obj_states = _load_lstm_diagnostic_data(
                    env_id, train_data_lstm, test_data_lstm, info_lstm)

            for epoch in range(num_epochs):
                should_save_imgs = (epoch % save_period == 0)
                if idx == 0:  # only LSTM trainer has 'only_train_vae' argument
                    t.train_epoch(epoch, only_train_vae=lstm_pretrain_vae_only)
                    t.test_epoch(epoch,
                                 save_reconstruction=should_save_imgs,
                                 save_prefix='r_' + name,
                                 only_train_vae=lstm_pretrain_vae_only)
                else:
                    t.train_epoch(epoch)
                    t.test_epoch(
                        epoch,
                        save_reconstruction=should_save_imgs,
                        save_prefix='r_' + name,
                    )

                if should_save_imgs:
                    t.dump_samples(epoch, save_prefix='s_' + name)

                    if idx == 0:
                        compare_latent_distance(
                            m,
                            all_data,
                            obj_states,
                            obj_name=obj,
                            save_dir=logger.get_snapshot_dir(),
                            save_name='lstm_latent_distance_{}.png'.format(
                                epoch))
                        test_lstm_traj(
                            env_id,
                            m,
                            save_path=logger.get_snapshot_dir(),
                            save_name='lstm_test_traj_{}.png'.format(epoch))
                        test_masked_traj_lstm(
                            env_id,
                            m,
                            save_dir=logger.get_snapshot_dir(),
                            save_name='masked_test_{}.png'.format(epoch))

                t.update_train_weights()

            logger.save_extra_data(m, '{}.pkl'.format(name), mode='pickle')

        logger.remove_tabular_output('{}_progress.csv'.format(name),
                                     relative_to_snapshot_dir=True)

        if idx == 0 and variant.get("only_train_lstm", False):
            # sys.exit() instead of the interactive-shell helper exit(),
            # which the `site` module may not provide.
            sys.exit()

    if return_data:
        return vaes, train_datas, test_datas
    return m