# Example #1
# 0
def experiment(variant):
    """Train a 3-channel ConvVAE using the settings stored in ``variant``.

    Required keys: ``beta``, ``representation_size``, ``get_data_kwargs``,
    ``conv_vae_kwargs``, ``algo_kwargs``, ``save_period`` and ``num_epochs``.
    ``beta_schedule_kwargs`` is optional; without it no beta schedule is used.

    Every ``save_period`` epochs the trainer is asked to save reconstructions
    and a scatterplot during testing and to dump samples.
    """
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset(
        **variant['get_data_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()  # return value is not used here

    beta_schedule = (
        PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])
        if 'beta_schedule_kwargs' in variant
        else None
    )

    model = ConvVAE(representation_size,
                    input_channels=3,
                    **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        model.to(ptu.device)

    trainer = ConvVAETrainer(train_data, test_data, model,
                             beta=beta,
                             beta_schedule=beta_schedule,
                             **variant['algo_kwargs'])

    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        is_snapshot_epoch = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch,
                           save_reconstruction=is_snapshot_epoch,
                           save_scatterplot=is_snapshot_epoch)
        if is_snapshot_epoch:
            trainer.dump_samples(epoch)
# Example #2
# 0
def experiment(variant):
    """Train a ConvVAE on CPU for a fixed 10 epochs.

    Data comes from ``get_data(10000)``; only ``beta`` and
    ``representation_size`` are read from ``variant``. Each epoch trains,
    tests and dumps samples.
    """
    beta, representation_size = variant["beta"], variant["representation_size"]
    train_data, test_data = get_data(10000)
    trainer = ConvVAETrainer(train_data, test_data,
                             ConvVAE(representation_size),
                             beta=beta, use_cuda=False)
    for epoch in range(10):
        for step in (trainer.train_epoch, trainer.test_epoch,
                     trainer.dump_samples):
            step(epoch)
# Example #3
# 0
def train_vae(variant, return_data=False):
    """Train a Gaussian-decoder ConvVAE and save its weights to disk.

    Reads ``beta``, ``save_period`` and ``num_epochs`` from ``variant``, plus
    the optional keys ``representation_size`` (default 4),
    ``vae_architecture`` (default ``imsize48_default_architecture``),
    ``image_size`` (default 48) and ``input_channels`` (default 1).

    After training, the model's ``state_dict`` is written to
    ``<cwd>/saved_model/vae_model.pkl``. The directory is created if needed.

    Returns:
        The trained model, or ``(model, train_dataset, test_dataset)`` when
        ``return_data`` is true.
    """
    import os
    import os.path as osp

    import torch

    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    from rlkit.torch.vae.conv_vae import ConvVAE
    from rlkit.torch.vae.conv_vae import imsize48_default_architecture

    beta = variant["beta"]
    representation_size = variant.get("representation_size", 4)

    train_dataset, test_dataset = generate_vae_data(variant)
    vae_model = ConvVAE(
        representation_size,
        decoder_output_activation=identity,
        architecture=variant.get('vae_architecture',
                                 imsize48_default_architecture),
        imsize=variant.get('image_size', 48),
        input_channels=variant.get('input_channels', 1),
        decoder_distribution='gaussian_identity_variance',
    )
    # Follow the rlkit device setting (as the other train_vae variants do)
    # instead of an unconditional .cuda(), which fails on CPU-only hosts.
    # NOTE(review): assumes ConvVAETrainer places its batches on ptu.device.
    vae_model.to(ptu.device)

    trainer = ConvVAETrainer(train_dataset,
                             test_dataset,
                             vae_model,
                             beta=beta,
                             beta_schedule=None)
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
        if epoch % save_period == 0:
            trainer.dump_samples(epoch)
        trainer.update_train_weights()

    # torch.save does not create missing directories, so make sure it exists.
    save_dir = osp.join(osp.abspath(os.curdir), 'saved_model')
    os.makedirs(save_dir, exist_ok=True)
    torch.save(vae_model.state_dict(), osp.join(save_dir, 'vae_model.pkl'))

    if return_data:
        return vae_model, train_dataset, test_dataset
    return vae_model
# Example #4
# 0
def train_vae(variant, return_data=False):
    """Train a ConvVAE described by ``variant`` and pickle it as ``vae.pkl``.

    Required keys: ``beta``, ``representation_size``,
    ``generate_vae_dataset_kwargs``, ``vae_kwargs``, ``algo_kwargs``,
    ``save_period`` and ``num_epochs``.

    Optional keys:
      * ``generate_vae_data_fctn``: dataset factory, default
        ``generate_vae_dataset``.
      * ``decoder_activation``: ``'sigmoid'`` selects ``torch.nn.Sigmoid()``;
        anything else uses ``identity``.
      * ``imsize``: 84 or 48 selects a default architecture when
        ``vae_kwargs`` has none.

    Side effect: ``variant['vae_kwargs']`` is updated in place with the
    resolved ``architecture`` and ``imsize``.

    Returns the model, or ``(model, train_data, test_data)`` when
    ``return_data`` is true.
    """
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    make_dataset = variant.get('generate_vae_data_fctn', generate_vae_dataset)
    train_data, test_data, info = make_dataset(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()  # return value is not used here

    use_sigmoid = variant.get('decoder_activation', None) == 'sigmoid'
    decoder_activation = torch.nn.Sigmoid() if use_sigmoid else identity

    vae_kwargs = variant['vae_kwargs']
    imsize = variant.get('imsize')
    architecture = vae_kwargs.get('architecture', None)
    if not architecture:
        if imsize == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif imsize == 48:
            architecture = conv_vae.imsize48_default_architecture
    vae_kwargs['architecture'] = architecture
    vae_kwargs['imsize'] = imsize

    model = ConvVAE(representation_size,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)
    model.to(ptu.device)
    trainer = ConvVAETrainer(train_data, test_data, model,
                             beta=beta,
                             **variant['algo_kwargs'])

    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        snapshot_epoch = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch, save_reconstruction=snapshot_epoch)
        if snapshot_epoch:
            trainer.dump_samples(epoch)

    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    return (model, train_data, test_data) if return_data else model
# Example #5
# 0
def experiment(variant):
    """Train a 3-channel ConvVAE, testing and dumping samples every epoch.

    Required keys in ``variant``: ``beta``, ``representation_size``,
    ``get_data_kwargs``, ``algo_kwargs`` and ``num_epochs``.

    ``beta_schedule_kwargs`` is optional. When present it builds a
    ``PiecewiseLinearSchedule``; when absent no beta schedule is passed to the
    trainer. (Previously a missing key raised ``KeyError``.)
    """
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = get_data(**variant['get_data_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()  # return value is not used here

    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    m = ConvVAE(representation_size, input_channels=3)
    if ptu.gpu_enabled():
        m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    for epoch in range(variant['num_epochs']):
        t.train_epoch(epoch)
        t.test_epoch(epoch)
        t.dump_samples(epoch)
def train_vae(variant, other_variant, return_data=False):
    """Train a ConvVAE whose trainer also receives ``other_variant``.

    Required keys in ``variant``: ``beta``, ``representation_size``,
    ``generate_vae_dataset_kwargs``, ``vae_kwargs``, ``algo_kwargs``,
    ``save_period`` and ``num_epochs``.

    Optional keys:
      * ``generate_vae_data_fctn``: default ``generate_vae_dataset``.
      * ``beta_schedule_kwargs``: builds a ``PiecewiseLinearSchedule``.
      * ``decoder_activation``: ``'sigmoid'`` or, otherwise, ``identity``.
      * ``imsize``: 84 or 48 selects a default architecture when
        ``vae_kwargs`` has none.

    Side effect: ``variant['vae_kwargs']`` is updated in place with the
    resolved ``architecture`` and ``imsize``.

    ``other_variant`` is forwarded as the fourth positional argument of
    ``ConvVAETrainer``. The trained model is pickled as ``vae.pkl`` via the
    logger. Returns the model, or ``(model, train_data, test_data)`` when
    ``return_data`` is true.
    """
    import torch

    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.conv_vae import ConvVAE
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.util.ml_util import PiecewiseLinearSchedule

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    dataset_fn = variant.get('generate_vae_data_fctn', generate_vae_dataset)
    train_data, test_data, info = dataset_fn(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()  # return value is not used here

    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])

    use_sigmoid = variant.get('decoder_activation', None) == 'sigmoid'
    decoder_activation = torch.nn.Sigmoid() if use_sigmoid else identity

    vae_kwargs = variant['vae_kwargs']
    imsize = variant.get('imsize')
    architecture = vae_kwargs.get('architecture', None)
    if not architecture:
        if imsize == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif imsize == 48:
            architecture = conv_vae.imsize48_default_architecture
    vae_kwargs['architecture'] = architecture
    vae_kwargs['imsize'] = imsize

    model = ConvVAE(representation_size,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)
    model.to(ptu.device)
    trainer = ConvVAETrainer(train_data, test_data, model, other_variant,
                             beta=beta,
                             beta_schedule=beta_schedule,
                             **variant['algo_kwargs'])

    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        dump_images = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch, save_reconstruction=dump_images)
        if dump_images:
            trainer.dump_samples(epoch)
        trainer.update_train_weights()

    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_data, test_data
    return model
# Example #7
# 0
def get_n_train_vae(latent_dim,
                    env,
                    vae_train_epochs,
                    num_image_examples,
                    vae_kwargs,
                    vae_trainer_kwargs,
                    vae_architecture,
                    vae_save_period=10,
                    vae_test_p=.9,
                    decoder_activation='sigmoid',
                    vae_class='VAE',
                    vae_params_save_period=50,
                    **kwargs):
    """Sample goal images from ``env`` and train a VAE on them.

    While training, tabular diagnostics go to ``vae_progress.csv``. The
    logger's ``progress.csv`` output is restored afterwards, even if training
    raises.

    Args:
        latent_dim: latent size passed to the VAE constructor.
        env: environment whose ``sample_goals(n)['desired_goal']`` supplies the
            images. Its ``goal_sampling_mode`` is set to ``'test'`` as a side
            effect.
        vae_train_epochs: number of training epochs.
        num_image_examples: number of goal images to sample.
        vae_kwargs: extra keyword arguments for the VAE constructor.
        vae_trainer_kwargs: keyword arguments for ``ConvVAETrainer``.
        vae_architecture: architecture passed to the VAE constructor.
        vae_save_period: dump reconstructions and samples every this many
            epochs.
        vae_test_p: despite its name, the fraction of examples used for
            *training*. The remaining examples form the test set.
        decoder_activation: ``'sigmoid'`` selects ``torch.nn.Sigmoid()``. Any
            other value is passed to the VAE unchanged.
        vae_class: ``'VAE'`` (``ConvVAE``) or ``'SpatialVAE'``
            (``SpatialAutoEncoder``), case-insensitive.
        vae_params_save_period: call ``logger.save_itr_params`` every this
            many epochs (default 50, the previously hard-coded value).
        **kwargs: ignored.

    Returns:
        The trained VAE, which is also pickled as ``vae.pkl`` via the logger.

    Raises:
        RuntimeError: if ``vae_class`` is not one of the supported names.
    """
    env.goal_sampling_mode = 'test'
    image_examples = unnormalize_image(
        env.sample_goals(num_image_examples)['desired_goal'])
    n = int(num_image_examples * vae_test_p)
    train_dataset = ImageObservationDataset(image_examples[:n, :])
    test_dataset = ImageObservationDataset(image_examples[n:, :])

    if decoder_activation == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()

    vae_class_name = vae_class.lower()
    if vae_class_name == 'vae':
        vae_cls = ConvVAE
    elif vae_class_name == 'spatialvae':
        vae_cls = SpatialAutoEncoder
    else:
        # Report the name exactly as the caller passed it.
        raise RuntimeError("Invalid VAE Class: {}".format(vae_class))

    vae = vae_cls(latent_dim,
                  architecture=vae_architecture,
                  decoder_output_activation=decoder_activation,
                  **vae_kwargs)

    trainer = ConvVAETrainer(vae, **vae_trainer_kwargs)

    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)
    try:
        for epoch in range(vae_train_epochs):
            should_save_imgs = (epoch % vae_save_period == 0)
            trainer.train_epoch(epoch, train_dataset)
            trainer.test_epoch(epoch, test_dataset)

            if should_save_imgs:
                trainer.dump_reconstructions(epoch)
                trainer.dump_samples(epoch)
            for k, v in trainer.get_diagnostics().items():
                logger.record_tabular(k, v)

            logger.dump_tabular()
            trainer.end_epoch(epoch)

            if epoch % vae_params_save_period == 0:
                logger.save_itr_params(epoch, vae)
        logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    finally:
        # Always hand the tabular output back to the main progress file.
        logger.remove_tabular_output('vae_progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('progress.csv',
                                  relative_to_snapshot_dir=True)
    return vae
# Example #8
# 0
def train_vae(cfgs, return_data=False):
    """Train a ConvVAE configured by the ``cfgs`` object.

    Reads:
      * ``cfgs.VAE``: ``decoder_activation`` (optional), ``architecture``
        (optional), ``representation_size``, ``input_channels``,
        ``decoder_distribution`` and ``img_size``.
      * ``cfgs.ENV.img_size``: used only to pick a default architecture
        (84 or 48) when ``cfgs.VAE`` has none.
      * ``cfgs.VAE_TRAINER``: ``lr``, ``beta``, ``beta_schedule_kwargs``
        (optional), ``save_period`` and ``num_epochs``.

    The trained model is pickled as ``vae.pkl`` via the logger. Returns the
    model, or ``(model, train_data, test_data)`` when ``return_data`` is true.
    """
    import torch

    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.conv_vae import ConvVAE
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.util.ml_util import PiecewiseLinearSchedule

    train_data, test_data, info = generate_vae_dataset(cfgs)
    logger.save_extra_data(info)
    logger.get_snapshot_dir()  # return value is not used here

    vae_cfg = cfgs.VAE
    # TODO(review): any value other than 'sigmoid' (including unset) falls
    # back to identity; the original author flagged this default as FIXME.
    if vae_cfg.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    architecture = vae_cfg.get('architecture', None)
    if not architecture:
        # NOTE(review): the default is chosen from cfgs.ENV.img_size while the
        # model's imsize comes from cfgs.VAE.img_size; confirm they agree.
        env_img_size = cfgs.ENV.get('img_size')
        if env_img_size == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif env_img_size == 48:
            architecture = conv_vae.imsize48_default_architecture

    vae_model = ConvVAE(
        representation_size=vae_cfg.representation_size,
        architecture=architecture,
        decoder_output_activation=decoder_activation,
        input_channels=vae_cfg.input_channels,
        decoder_distribution=vae_cfg.decoder_distribution,
        imsize=vae_cfg.img_size,
    )
    vae_model.to(ptu.device)

    trainer_cfg = cfgs.VAE_TRAINER
    # TODO(review): document what the beta schedule does inside
    # ConvVAETrainer (flagged as FIXME by the original author).
    beta_schedule = None
    if 'beta_schedule_kwargs' in trainer_cfg:
        beta_schedule = PiecewiseLinearSchedule(
            **trainer_cfg.beta_schedule_kwargs)

    trainer = ConvVAETrainer(train_data, test_data, vae_model,
                             lr=trainer_cfg.lr,
                             beta=trainer_cfg.beta,
                             beta_schedule=beta_schedule)

    save_period = trainer_cfg.save_period
    for epoch in range(trainer_cfg.num_epochs):
        dump_images = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch, save_reconstruction=dump_images)
        if dump_images:
            trainer.dump_samples(epoch)
        trainer.update_train_weights()

    logger.save_extra_data(vae_model, 'vae.pkl', mode='pickle')
    if return_data:
        return vae_model, train_data, test_data
    return vae_model
def train_vae(variant, return_data=False):
    """Train a ConvVAE described by ``variant`` and pickle it as ``vae.pkl``.

    Required keys: ``beta``, ``representation_size``,
    ``generate_vae_dataset_kwargs``, ``vae_kwargs``, ``algo_kwargs``,
    ``save_period`` and ``num_epochs``.

    Optional keys:
      * ``generate_vae_data_fctn``: default ``generate_vae_dataset``.
      * ``beta_schedule_kwargs``: builds a ``PiecewiseLinearSchedule``.
      * ``decoder_activation``: ``"sigmoid"`` or, otherwise, ``identity``.
      * ``imsize``: 84 or 48 selects a default architecture when
        ``vae_kwargs`` has none.

    Side effect: ``variant["vae_kwargs"]`` is updated in place with the
    resolved ``architecture`` and ``imsize``.

    Returns the model, or ``(model, train_data, test_data)`` when
    ``return_data`` is true.
    """
    import torch

    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.conv_vae import ConvVAE
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.util.ml_util import PiecewiseLinearSchedule

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    dataset_fn = variant.get("generate_vae_data_fctn", generate_vae_dataset)
    train_data, test_data, info = dataset_fn(
        variant["generate_vae_dataset_kwargs"])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()  # return value is not used here

    beta_schedule = None
    if "beta_schedule_kwargs" in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant["beta_schedule_kwargs"])

    use_sigmoid = variant.get("decoder_activation", None) == "sigmoid"
    decoder_activation = torch.nn.Sigmoid() if use_sigmoid else identity

    vae_kwargs = variant["vae_kwargs"]
    imsize = variant.get("imsize")
    architecture = vae_kwargs.get("architecture", None)
    if not architecture:
        if imsize == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif imsize == 48:
            architecture = conv_vae.imsize48_default_architecture
    vae_kwargs["architecture"] = architecture
    vae_kwargs["imsize"] = imsize

    model = ConvVAE(representation_size,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)
    model.to(ptu.device)
    trainer = ConvVAETrainer(train_data, test_data, model,
                             beta=beta,
                             beta_schedule=beta_schedule,
                             **variant["algo_kwargs"])

    save_period = variant["save_period"]
    for epoch in range(variant["num_epochs"]):
        dump_images = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch, save_reconstruction=dump_images)
        if dump_images:
            trainer.dump_samples(epoch)
        trainer.update_train_weights()

    logger.save_extra_data(model, "vae.pkl", mode="pickle")
    return (model, train_data, test_data) if return_data else model
def _load_lstm_eval_data(env_id, train_data_lstm, test_data_lstm, info_lstm):
    """Return ``(normalized_images, obj_states)`` for LSTM diagnostics.

    For ``SawyerPushHurdle-v0``, the pre-generated arrays under
    ``$PJHOME/data/local/pre-train-lstm`` are used when the image file exists.
    Otherwise the LSTM pretraining data (train followed by test) and
    ``info_lstm['obj_state']`` are used. In both cases the images go through
    ``normalize_image``.
    """
    import os
    import os.path as osp

    import numpy as np

    data_dir = osp.join(os.environ['PJHOME'], 'data/local/pre-train-lstm')
    prefix = '{}-{}-{}-0.3-0.5'.format(env_id, 'seg-color', 500)
    data_file_path = osp.join(data_dir, prefix + '.npy')
    if env_id == 'SawyerPushHurdle-v0' and osp.exists(data_file_path):
        all_data = np.load(data_file_path)
        obj_states = np.load(osp.join(data_dir, prefix + '-puck-pos.npy'))
    else:
        all_data = np.concatenate([train_data_lstm, test_data_lstm], axis=0)
        obj_states = info_lstm['obj_state']
    return normalize_image(all_data.copy()), obj_states


def train_vae(
        variant,
        return_data=False,
        skewfit_variant=None):  # actually trains both the vae and the lstm
    """Pretrain a segmentation ConvLSTM and an original-image ConvVAE.

    Two models are built, in order:
      0. ``lstm_seg_pretrain``: a ``ConvLSTM2`` trained with
         ``ConvLSTMTrainer`` on the segmented dataset returned by
         ``variant['generate_lstm_data_fctn']``.
      1. ``vae_ori_pretrain``: a ``ConvVAE`` trained with ``ConvVAETrainer``
         on the dataset from ``variant['generate_vae_data_fctn']``
         (default ``generate_vae_dataset``).

    Each model is trained only when its flag (``variant['seg_pretrain']`` /
    ``variant['ori_pretrain']``) is truthy. A trained model is then pickled as
    ``<name>.pkl`` via the logger. If ``variant['only_train_lstm']`` is set,
    the process exits (``sys.exit()``) right after the LSTM stage.

    Side effect: ``variant['vae_kwargs']`` and ``variant['lstm_kwargs']`` are
    updated in place with the resolved ``architecture`` and ``imsize``.

    Args:
        variant: experiment configuration dict.
        return_data: when true, also return the datasets.
        skewfit_variant: must provide ``'segmentation_method'``. Despite the
            ``None`` default, it is required.

    Returns:
        ``(models, train_datas, test_datas)``, each a list ordered
        ``[lstm, vae]``, when ``return_data`` is true. Otherwise only the last
        model built (the VAE).

    Raises:
        ValueError: if no LSTM dataset function is given, the LSTM dataset
            kwargs lack ``env_id``, or ``lstm_version`` is not 2.
    """
    import sys

    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    import ROLL.LSTM_model as LSTM_model
    from ROLL.LSTM_model import ConvLSTM2
    from ROLL.LSTM_trainer import ConvLSTMTrainer
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    seg_pretrain = variant['seg_pretrain']
    ori_pretrain = variant['ori_pretrain']
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    generate_lstm_dataset_fctn = variant.get('generate_lstm_data_fctn')
    # Explicit check rather than assert, which is stripped under ``python -O``.
    if generate_lstm_dataset_fctn is None:
        raise ValueError(
            "Must provide a custom generate lstm pretraining dataset function!")

    train_data_lstm, test_data_lstm, info_lstm = generate_lstm_dataset_fctn(
        variant['generate_lstm_dataset_kwargs'],
        segmented=True,
        segmentation_method=skewfit_variant['segmentation_method'])

    train_data_ori, test_data_ori, _ = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])

    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    architecture = variant['lstm_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = None  # TODO LSTM: wrap an 84 lstm architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = LSTM_model.imsize48_default_architecture
    variant['lstm_kwargs']['architecture'] = architecture
    variant['lstm_kwargs']['imsize'] = variant.get('imsize')

    train_datas = [train_data_lstm, train_data_ori]
    test_datas = [test_data_lstm, test_data_ori]
    names = ['lstm_seg_pretrain', 'vae_ori_pretrain']
    vaes = []
    env_id = variant['generate_lstm_dataset_kwargs'].get('env_id')
    if env_id is None:
        raise ValueError(
            "variant['generate_lstm_dataset_kwargs'] must contain 'env_id'")
    lstm_pretrain_vae_only = variant.get('lstm_pretrain_vae_only', False)
    obj = 'door' if 'Door' in env_id else 'puck'
    save_period = variant['save_period']

    for idx, (train_data, test_data, name) in enumerate(
            zip(train_datas, test_datas, names)):
        logger.add_tabular_output('{}_progress.csv'.format(name),
                                  relative_to_snapshot_dir=True)

        if idx == 1:  # train the original vae
            representation_size = variant.get(
                "vae_representation_size", variant.get('representation_size'))
            beta = variant.get('vae_beta', variant.get('beta'))
            m = ConvVAE(representation_size,
                        decoder_output_activation=decoder_activation,
                        **variant['vae_kwargs'])
            t = ConvVAETrainer(train_data,
                               test_data,
                               m,
                               beta=beta,
                               beta_schedule=beta_schedule,
                               **variant['algo_kwargs'])
        else:  # train the segmentation lstm
            lstm_version = variant.get('lstm_version', 2)
            # Only version 2 exists; fail clearly instead of hitting an
            # unbound ``lstm_class`` below.
            if lstm_version != 2:
                raise ValueError(
                    'Unsupported lstm_version: {}'.format(lstm_version))
            lstm_class = ConvLSTM2

            representation_size = variant.get(
                "lstm_representation_size", variant.get('representation_size'))
            beta = variant.get('lstm_beta', variant.get('beta'))
            m = lstm_class(representation_size,
                           decoder_output_activation=decoder_activation,
                           **variant['lstm_kwargs'])
            t = ConvLSTMTrainer(train_data,
                                test_data,
                                m,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])

        m.to(ptu.device)
        vaes.append(m)

        print("test data len: ", len(test_data))
        print("train data len: ", len(train_data))

        should_train = seg_pretrain if idx == 0 else ori_pretrain
        if should_train:
            num_epochs = variant['num_lstm_epochs'] if idx == 0 else variant[
                'num_vae_epochs']
            if idx == 0:
                # Evaluation data is only used by the LSTM diagnostics, so
                # load it once, and only when the LSTM is actually trained.
                all_data, obj_states = _load_lstm_eval_data(
                    env_id, train_data_lstm, test_data_lstm, info_lstm)

            for epoch in range(num_epochs):
                should_save_imgs = (epoch % save_period == 0)
                if idx == 0:  # only LSTM trainer has 'only_train_vae' argument
                    t.train_epoch(epoch, only_train_vae=lstm_pretrain_vae_only)
                    t.test_epoch(epoch,
                                 save_reconstruction=should_save_imgs,
                                 save_prefix='r_' + name,
                                 only_train_vae=lstm_pretrain_vae_only)
                else:
                    t.train_epoch(epoch)
                    t.test_epoch(
                        epoch,
                        save_reconstruction=should_save_imgs,
                        save_prefix='r_' + name,
                    )

                if should_save_imgs:
                    t.dump_samples(epoch, save_prefix='s_' + name)

                    if idx == 0:
                        compare_latent_distance(
                            m,
                            all_data,
                            obj_states,
                            obj_name=obj,
                            save_dir=logger.get_snapshot_dir(),
                            save_name='lstm_latent_distance_{}.png'.format(
                                epoch))
                        test_lstm_traj(
                            env_id,
                            m,
                            save_path=logger.get_snapshot_dir(),
                            save_name='lstm_test_traj_{}.png'.format(epoch))
                        test_masked_traj_lstm(
                            env_id,
                            m,
                            save_dir=logger.get_snapshot_dir(),
                            save_name='masked_test_{}.png'.format(epoch))

                t.update_train_weights()

            logger.save_extra_data(m, '{}.pkl'.format(name), mode='pickle')

        logger.remove_tabular_output('{}_progress.csv'.format(name),
                                     relative_to_snapshot_dir=True)

        if idx == 0 and variant.get("only_train_lstm", False):
            # sys.exit works even when the ``site`` builtin ``exit`` is absent.
            sys.exit()

    if return_data:
        return vaes, train_datas, test_datas
    return m