Example 1
def train_vae(variant, return_data=False):
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    from rlkit.torch.vae.conv_vae import imsize48_default_architecture
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch
    import os
    import os.path as osp

    beta = variant["beta"]
    representation_size = variant.get("representation_size", 4)

    train_dataset, test_dataset = generate_vae_data(variant)
    decoder_activation = identity
    # train_dataset = train_dataset.cuda()
    # test_dataset = test_dataset.cuda()
    architecture = variant.get('vae_architecture',
                               imsize48_default_architecture)
    image_size = variant.get('image_size', 48)
    input_channels = variant.get('input_channels', 1)
    vae_model = ConvVAE(representation_size,
                        decoder_output_activation=decoder_activation,
                        architecture=architecture,
                        imsize=image_size,
                        input_channels=input_channels,
                        decoder_distribution='gaussian_identity_variance')

    vae_model.cuda()
    vae_trainer = ConvVAETrainer(train_dataset,
                                 test_dataset,
                                 vae_model,
                                 beta=beta,
                                 beta_schedule=None)
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        vae_trainer.train_epoch(epoch)
        vae_trainer.test_epoch(epoch)
        if epoch % save_period == 0:
            vae_trainer.dump_samples(epoch)
        vae_trainer.update_train_weights()
    # logger.save_extra_data(vae_model, 'vae.pkl', mode='pickle')
    project_path = osp.abspath(os.curdir)
    save_dir = osp.join(project_path, 'saved_model', 'vae_model.pkl')
    torch.save(vae_model.state_dict(), save_dir)
    # torch.save(vae_model.state_dict(), \
    # '/mnt/manh/project/visual_RL_imaged_goal/saved_model/vae_model.pkl')
    if return_data:
        return vae_model, train_dataset, test_dataset
    return vae_model
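
Example 1 is driven entirely by a plain variant dict, so a caller only needs to supply the keys that the variant[...] and variant.get(...) lookups above read. The sketch below is illustrative only: the values are placeholder assumptions rather than recommended settings, and it presumes rlkit is installed, a CUDA device is available (the example calls vae_model.cuda()), and a saved_model/ directory exists under the working directory.

variant = dict(
    beta=20,                  # KL weight, read via variant["beta"]
    representation_size=4,    # latent size (defaults to 4 if omitted)
    image_size=48,            # imsize passed to ConvVAE (default 48)
    input_channels=1,         # default 1
    num_epochs=500,
    save_period=25,           # dump_samples every save_period epochs
)
vae_model, train_set, test_set = train_vae(variant, return_data=True)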
Example 2
def experiment(variant):
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset(
        **variant['get_data_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    m = ConvVAE(representation_size,
                input_channels=3,
                **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_scatterplot=should_save_imgs)
        if should_save_imgs:
            t.dump_samples(epoch)
Example 3
def experiment(variant):
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data = get_data(10000)
    m = ConvVAE(representation_size)
    t = ConvVAETrainer(train_data, test_data, m, beta=beta, use_cuda=False)
    for epoch in range(10):
        t.train_epoch(epoch)
        t.test_epoch(epoch)
        t.dump_samples(epoch)
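
Example 3 is the minimal pattern: the function only reads beta and representation_size from variant, with everything else hard-coded (10000 samples from get_data, CPU training via use_cuda=False, 10 epochs). Assuming rlkit and get_data are importable, a call is simply:

variant = dict(beta=5, representation_size=16)  # values are illustrative
experiment(variant)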
Example 4
def train_vae(variant, return_data=False):
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    train_data, test_data, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    m = ConvVAE(representation_size,
                decoder_output_activation=decoder_activation,
                **variant['vae_kwargs'])
    m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    if return_data:
        return m, train_data, test_data
    return m
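
Example 4's train_vae is again configured through a variant dict, but here several sub-dicts are forwarded wholesale: generate_vae_dataset_kwargs to the dataset generator, vae_kwargs to ConvVAE, and algo_kwargs to ConvVAETrainer. A purely illustrative shape is sketched below; the nested values are assumptions and depend on the generator and trainer in use.

variant = dict(
    beta=20,
    representation_size=16,
    imsize=48,                           # selects imsize48_default_architecture
    decoder_activation='sigmoid',
    generate_vae_dataset_kwargs=dict(),  # passed verbatim to generate_vae_dataset
    vae_kwargs=dict(input_channels=3),   # extra ConvVAE kwargs
    algo_kwargs=dict(lr=1e-3),           # extra ConvVAETrainer kwargs
    num_epochs=500,
    save_period=25,
)
m, train_data, test_data = train_vae(variant, return_data=True)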
Example 5
def experiment(variant):
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = get_data(**variant['get_data_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    beta_schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])
    m = ConvVAE(representation_size, input_channels=3)
    if ptu.gpu_enabled():
        m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    for epoch in range(variant['num_epochs']):
        t.train_epoch(epoch)
        t.test_epoch(epoch)
        t.dump_samples(epoch)
Example 6
def tdm_td3_experiment_online_vae(variant):
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.state_distance.tdm_networks import TdmQf, TdmPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.torch.online_vae.online_vae_tdm_td3 import OnlineVaeTdmTd3
    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)
    vae_trainer_kwargs = variant.get('vae_trainer_kwargs')
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size)
    goal_dim = (env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size

    vectorized = 'vectorized' in env.reward_type
    variant['algo_kwargs']['tdm_td3_kwargs']['tdm_kwargs'][
        'vectorized'] = vectorized

    norm_order = env.norm_order
    # variant['algo_kwargs']['tdm_td3_kwargs']['tdm_kwargs'][
    #     'norm_order'] = norm_order

    qf1 = TdmQf(env=env,
                vectorized=vectorized,
                norm_order=norm_order,
                observation_dim=obs_dim,
                goal_dim=goal_dim,
                action_dim=action_dim,
                **variant['qf_kwargs'])
    qf2 = TdmQf(env=env,
                vectorized=vectorized,
                norm_order=norm_order,
                observation_dim=obs_dim,
                goal_dim=goal_dim,
                action_dim=action_dim,
                **variant['qf_kwargs'])
    policy = TdmPolicy(env=env,
                       observation_dim=obs_dim,
                       goal_dim=goal_dim,
                       action_dim=action_dim,
                       **variant['policy_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    vae = env.vae

    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    algo_kwargs = variant['algo_kwargs']['tdm_td3_kwargs']
    td3_kwargs = algo_kwargs['td3_kwargs']
    td3_kwargs['training_env'] = env
    tdm_kwargs = algo_kwargs['tdm_kwargs']
    tdm_kwargs['observation_key'] = observation_key
    tdm_kwargs['desired_goal_key'] = desired_goal_key
    algo_kwargs["replay_buffer"] = replay_buffer

    t = ConvVAETrainer(variant['vae_train_data'],
                       variant['vae_test_data'],
                       vae,
                       beta=variant['online_vae_beta'],
                       **vae_trainer_kwargs)
    render = variant["render"]
    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    algorithm = OnlineVaeTdmTd3(
        online_vae_kwargs=dict(vae=vae,
                               vae_trainer=t,
                               **variant['algo_kwargs']['online_vae_kwargs']),
        tdm_td3_kwargs=dict(env=env,
                            qf1=qf1,
                            qf2=qf2,
                            policy=policy,
                            exploration_policy=exploration_policy,
                            **variant['algo_kwargs']['tdm_td3_kwargs']),
    )

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    if variant.get("save_video", True):
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm._sample_max_tau_for_rollout(),
            decrement_tau=algorithm.cycle_taus_for_rollout,
            cycle_tau=algorithm.cycle_taus_for_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)

    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)

    algorithm.train()
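
Example 6 assumes a nested algo_kwargs layout in variant: the function indexes variant['algo_kwargs']['tdm_td3_kwargs'] (and its 'td3_kwargs' and 'tdm_kwargs' children) as well as variant['algo_kwargs']['online_vae_kwargs']. The shape implied by those accesses is sketched below; the leaf dicts are left empty here because their contents depend on OnlineVaeTdmTd3.

algo_kwargs = dict(
    online_vae_kwargs=dict(),    # extra kwargs for the online-VAE side
    tdm_td3_kwargs=dict(
        td3_kwargs=dict(),       # training_env is injected by the function
        tdm_kwargs=dict(),       # vectorized and the observation/goal keys are injected
        # the replay buffer is also injected at this level before training
    ),
)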
Example 7
def td3_experiment_online_vae_exploring(variant):
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.torch.her.online_vae_joint_algo import OnlineVaeHerJointAlgo
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.torch.td3.td3 import TD3
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['policy_kwargs'],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    exploring_qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    exploring_qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    exploring_policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['policy_kwargs'],
    )
    exploring_exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=exploring_policy,
    )

    vae = env.vae
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
    if variant.get('use_replay_buffer_goals', False):
        env.replay_buffer = replay_buffer
        env.use_replay_buffer_goals = True

    vae_trainer_kwargs = variant.get('vae_trainer_kwargs')
    t = ConvVAETrainer(variant['vae_train_data'],
                       variant['vae_test_data'],
                       vae,
                       beta=variant['online_vae_beta'],
                       **vae_trainer_kwargs)

    control_algorithm = TD3(env=env,
                            training_env=env,
                            qf1=qf1,
                            qf2=qf2,
                            policy=policy,
                            exploration_policy=exploration_policy,
                            **variant['algo_kwargs'])
    exploring_algorithm = TD3(env=env,
                              training_env=env,
                              qf1=exploring_qf1,
                              qf2=exploring_qf2,
                              policy=exploring_policy,
                              exploration_policy=exploring_exploration_policy,
                              **variant['algo_kwargs'])

    assert 'vae_training_schedule' not in variant,\
        "Just put it in joint_algo_kwargs"
    algorithm = OnlineVaeHerJointAlgo(vae=vae,
                                      vae_trainer=t,
                                      env=env,
                                      training_env=env,
                                      policy=policy,
                                      exploration_policy=exploration_policy,
                                      replay_buffer=replay_buffer,
                                      algo1=control_algorithm,
                                      algo2=exploring_algorithm,
                                      algo1_prefix="Control_",
                                      algo2_prefix="VAE_Exploration_",
                                      observation_key=observation_key,
                                      desired_goal_key=desired_goal_key,
                                      **variant['joint_algo_kwargs'])

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    if variant.get("save_video", True):
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)
    algorithm.train()
Example 8
def train_vae(variant, other_variant, return_data=False):
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import (
        ConvVAE, )
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    train_data, test_data, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    m = ConvVAE(representation_size,
                decoder_output_activation=decoder_activation,
                **variant['vae_kwargs'])
    m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       other_variant,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
            # save_vae=False,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
        t.update_train_weights()
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    # torch.save(m, other_variant['vae_pkl_path']+'/online_vae.pkl')  # easy way: load model for vae bonus
    if return_data:
        return m, train_data, test_data
    return m
Example 9
def skewfit_experiment(variant, other_variant):
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import FlattenMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer

    skewfit_preprocess_variant(variant)
    env = get_envs(variant)

    uniform_dataset_fn = variant.get('generate_uniform_dataset_fn', None)
    if uniform_dataset_fn:
        uniform_dataset = uniform_dataset_fn(
            **variant['generate_uniform_dataset_kwargs'])
    else:
        uniform_dataset = None

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes,
    )

    vae = env.vae

    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=env.vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    vae_trainer = ConvVAETrainer(variant['vae_train_data'],
                                 variant['vae_test_data'], env.vae,
                                 other_variant,
                                 **variant['online_vae_trainer_kwargs'])
    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    max_path_length = variant['max_path_length']

    trainer = SACTrainer(env=env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant['twin_sac_trainer_kwargs'])
    trainer = HERTrainer(trainer)
    eval_path_collector = VAEWrappedEnvPathCollector(
        variant['evaluation_goal_sampling_mode'],
        env,
        MakeDeterministic(policy),
        max_path_length,
        other_variant=other_variant,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        variant['exploration_goal_sampling_mode'],
        env,
        policy,
        max_path_length,
        other_variant=other_variant,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )

    algorithm = OnlineVaeAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        uniform_dataset=uniform_dataset,
        max_path_length=max_path_length,
        **variant['algo_kwargs'])

    if variant['custom_goal_sampler'] == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
Example 10
def train_vae(beta,
              representation_size,
              imsize,
              num_epochs,
              save_period,
              generate_vae_dataset_fctn=None,
              beta_schedule_kwargs=None,
              decoder_activation=None,
              vae_kwargs=None,
              generate_vae_dataset_kwargs=None,
              algo_kwargs=None,
              use_spatial_auto_encoder=False,
              vae_class=None,
              dump_skew_debug_plots=False):
    from rlkit.misc.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import (
        ConvVAE,
        SpatialAutoEncoder,
        AutoEncoder,
    )
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.torch.vae.vae_experiment import VAEExperiment
    from rlkit.pythonplusplus import identity
    from rlkit.torch.grill.launcher import generate_vae_dataset
    import torch
    if vae_kwargs is None:
        vae_kwargs = {}
    if generate_vae_dataset_kwargs is None:
        generate_vae_dataset_kwargs = {}
    if algo_kwargs is None:
        algo_kwargs = {}
    if generate_vae_dataset_fctn is None:
        generate_vae_dataset_fctn = generate_vae_dataset
    if vae_class is None:
        vae_class = ConvVAE
    if beta_schedule_kwargs is not None:
        beta_schedule = PiecewiseLinearSchedule(**beta_schedule_kwargs)
    else:
        beta_schedule = None
    if decoder_activation == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = vae_kwargs.get('architecture', None)
    if not architecture and imsize == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and imsize == 48:
        architecture = conv_vae.imsize48_default_architecture
    vae_kwargs['architecture'] = architecture
    vae_kwargs['imsize'] = imsize

    if algo_kwargs.get('is_auto_encoder', False):
        m = AutoEncoder(representation_size,
                        decoder_output_activation=decoder_activation,
                        **vae_kwargs)
    elif use_spatial_auto_encoder:
        m = SpatialAutoEncoder(representation_size,
                               decoder_output_activation=decoder_activation,
                               **vae_kwargs)
    else:
        m = vae_class(representation_size,
                      decoder_output_activation=decoder_activation,
                      **vae_kwargs)
    train_data, test_data, info = generate_vae_dataset_fctn(
        generate_vae_dataset_kwargs)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **algo_kwargs)
    vae_exp = VAEExperiment(t, num_epochs, save_period, dump_skew_debug_plots)
    return vae_exp, train_data, test_data
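
Example 10 exposes the same hyperparameters as explicit keyword arguments instead of a variant dict and returns a VAEExperiment wrapper rather than running the epoch loop itself. A hypothetical call, with placeholder values and rlkit assumed importable, might look like:

vae_exp, train_data, test_data = train_vae(
    beta=20,
    representation_size=16,
    imsize=48,                          # picks imsize48_default_architecture
    num_epochs=500,
    save_period=25,
    vae_kwargs=dict(input_channels=3),
    algo_kwargs=dict(lr=1e-3),
)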
Example 11
def skewfit_experiment(variant):
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer \
        import OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import FlattenMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    import rlkit.torch.vae.vae_schedules as vae_schedules

    #### getting parameters for training the VAE and RIG
    env = get_envs(variant)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    replay_buffer_kwargs = variant.get(
        'replay_buffer_kwargs',
        dict(
            start_skew_epoch=10,
            max_size=int(100000),
            fraction_goals_rollout_goals=0.2,
            fraction_goals_env_goals=0.5,
            exploration_rewards_type='None',
            vae_priority_type='vae_prob',
            priority_function_kwargs=dict(
                sampling_method='importance_sampling',
                decoder_distribution='gaussian_identity_variance',
                num_latents_to_sample=10,
            ),
            power=0,
            relabeling_goal_sampling_mode='vae_prior',
        ))
    online_vae_trainer_kwargs = variant.get('online_vae_trainer_kwargs',
                                            dict(beta=20, lr=1e-3))
    max_path_length = variant.get('max_path_length', 50)
    algo_kwargs = variant.get(
        'algo_kwargs',
        dict(
            batch_size=1024,
            num_epochs=1000,
            num_eval_steps_per_epoch=500,
            num_expl_steps_per_train_loop=500,
            num_trains_per_train_loop=1000,
            min_num_steps_before_training=10000,
            vae_training_schedule=vae_schedules.custom_schedule_2,
            oracle_data=False,
            vae_save_period=50,
            parallel_vae_train=False,
        ))
    twin_sac_trainer_kwargs = variant.get(
        'twin_sac_trainer_kwargs',
        dict(
            discount=0.99,
            reward_scale=1,
            soft_target_tau=1e-3,
            target_update_period=1,  # 1
            use_automatic_entropy_tuning=True,
        ))
    ############################################################################

    qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     hidden_sizes=hidden_sizes)
    qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     hidden_sizes=hidden_sizes)
    target_qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            hidden_sizes=hidden_sizes)
    target_qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            hidden_sizes=hidden_sizes)
    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                hidden_sizes=hidden_sizes)

    vae = variant['vae_model']
    # create a replay buffer for training an online VAE
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **replay_buffer_kwargs)
    # create an online vae_trainer to train the VAE on the fly
    vae_trainer = ConvVAETrainer(variant['vae_train_data'],
                                 variant['vae_test_data'], vae,
                                 **online_vae_trainer_kwargs)
    # create a SACTrainer to learn a soft Q-function and the corresponding policy
    trainer = SACTrainer(env=env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **twin_sac_trainer_kwargs)
    trainer = HERTrainer(trainer)
    eval_path_collector = VAEWrappedEnvPathCollector(
        variant.get('evaluation_goal_sampling_mode', 'reset_of_env'),
        env,
        MakeDeterministic(policy),
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        variant.get('exploration_goal_sampling_mode', 'vae_prior'),
        env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = OnlineVaeAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        max_path_length=max_path_length,
        **algo_kwargs)

    if variant['custom_goal_sampler'] == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
Example 12
def get_n_train_vae(latent_dim,
                    env,
                    vae_train_epochs,
                    num_image_examples,
                    vae_kwargs,
                    vae_trainer_kwargs,
                    vae_architecture,
                    vae_save_period=10,
                    vae_test_p=.9,
                    decoder_activation='sigmoid',
                    vae_class='VAE',
                    **kwargs):
    env.goal_sampling_mode = 'test'
    image_examples = unnormalize_image(
        env.sample_goals(num_image_examples)['desired_goal'])
    n = int(num_image_examples * vae_test_p)
    train_dataset = ImageObservationDataset(image_examples[:n, :])
    test_dataset = ImageObservationDataset(image_examples[n:, :])

    if decoder_activation == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()

    vae_class = vae_class.lower()
    if vae_class == 'VAE'.lower():
        vae_class = ConvVAE
    elif vae_class == 'SpatialVAE'.lower():
        vae_class = SpatialAutoEncoder
    else:
        raise RuntimeError("Invalid VAE Class: {}".format(vae_class))

    vae = vae_class(latent_dim,
                    architecture=vae_architecture,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)

    trainer = ConvVAETrainer(vae, **vae_trainer_kwargs)

    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)
    for epoch in range(vae_train_epochs):
        should_save_imgs = (epoch % vae_save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)

        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, vae)
    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output('vae_progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    return vae
Example 13
def train_vae(cfgs, return_data=False):
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import (
        ConvVAE, )
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    train_data, test_data, info = generate_vae_dataset(cfgs)
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    # FIXME default gaussian
    if cfgs.VAE.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    architecture = cfgs.VAE.get('architecture', None)
    if not architecture and cfgs.ENV.get('img_size') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and cfgs.ENV.get('img_size') == 48:
        architecture = conv_vae.imsize48_default_architecture

    vae_model = ConvVAE(
        representation_size=cfgs.VAE.representation_size,
        architecture=architecture,
        decoder_output_activation=decoder_activation,
        input_channels=cfgs.VAE.input_channels,
        decoder_distribution=cfgs.VAE.decoder_distribution,
        imsize=cfgs.VAE.img_size,
    )
    vae_model.to(ptu.device)

    # FIXME the function of beta_schedule?
    if 'beta_schedule_kwargs' in cfgs.VAE_TRAINER:
        beta_schedule = PiecewiseLinearSchedule(
            **cfgs.VAE_TRAINER.beta_schedule_kwargs)
    else:
        beta_schedule = None

    t = ConvVAETrainer(train_data,
                       test_data,
                       vae_model,
                       lr=cfgs.VAE_TRAINER.lr,
                       beta=cfgs.VAE_TRAINER.beta,
                       beta_schedule=beta_schedule)

    save_period = cfgs.VAE_TRAINER.save_period
    for epoch in range(cfgs.VAE_TRAINER.num_epochs):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
            # save_vae=False,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
        t.update_train_weights()
    logger.save_extra_data(vae_model, 'vae.pkl', mode='pickle')
    if return_data:
        return vae_model, train_data, test_data
    return vae_model
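
Examples 13 and 14 read their settings from a cfgs object that is accessed both attribute-style (cfgs.VAE_TRAINER.lr) and dict-style (cfgs.VAE.get('decoder_activation', None)), which is the behaviour of a yacs-style CfgNode. The sketch below only illustrates that shape: the node and key names mirror the accesses in Example 13, while the concrete values, and the choice of yacs itself, are assumptions.

from yacs.config import CfgNode as CN

cfgs = CN()
cfgs.ENV = CN(dict(img_size=48))
cfgs.VAE = CN(dict(
    representation_size=16,
    input_channels=3,
    decoder_distribution='gaussian_identity_variance',
    decoder_activation='sigmoid',
    img_size=48,
))
cfgs.VAE_TRAINER = CN(dict(lr=1e-3, beta=20, save_period=25, num_epochs=500))
# cfgs would additionally need whatever generate_vae_dataset(cfgs) expects,
# which is not shown in the excerpt.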
Example 14
def skewfit_experiment(cfgs):
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import FlattenMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer

    skewfit_preprocess_variant(cfgs)
    env = get_envs(cfgs)

    # TODO
    uniform_dataset_fn = cfgs.GENERATE_VAE_DATASET.get(
        'uniform_dataset_generator', None)
    if uniform_dataset_fn:
        uniform_dataset = uniform_dataset_fn(
            **cfgs.GENERATE_VAE_DATASET.generate_uniform_dataset_kwargs)
    else:
        uniform_dataset = None

    observation_key = cfgs.SKEW_FIT.get('observation_key',
                                        'latent_observation')
    desired_goal_key = cfgs.SKEW_FIT.get('desired_goal_key',
                                         'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = cfgs.Q_FUNCTION.get('hidden_sizes', [400, 300])
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=cfgs.POLICY.get('hidden_sizes', [400, 300]),
    )

    vae = env.vae

    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=env.vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        priority_function_kwargs=cfgs.PRIORITY_FUNCTION,
        **cfgs.REPLAY_BUFFER)
    vae_trainer = ConvVAETrainer(
        cfgs.VAE_TRAINER.train_data,
        cfgs.VAE_TRAINER.test_data,
        env.vae,
        beta=cfgs.VAE_TRAINER.beta,
        lr=cfgs.VAE_TRAINER.lr,
    )

    # assert 'vae_training_schedule' not in cfgs, "Just put it in algo_kwargs"
    max_path_length = cfgs.SKEW_FIT.max_path_length
    trainer = SACTrainer(env=env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **cfgs.TWIN_SAC_TRAINER)
    trainer = HERTrainer(trainer)
    eval_path_collector = VAEWrappedEnvPathCollector(
        cfgs.SKEW_FIT.evaluation_goal_sampling_mode,
        env,
        MakeDeterministic(policy),
        decode_goals=True,  # TODO check this
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        cfgs.SKEW_FIT.exploration_goal_sampling_mode,
        env,
        policy,
        decode_goals=True,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )

    algorithm = OnlineVaeAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        uniform_dataset=uniform_dataset,  # TODO used in test vae
        max_path_length=max_path_length,
        parallel_vae_train=cfgs.VAE_TRAINER.parallel_train,
        **cfgs.ALGORITHM)

    if cfgs.SKEW_FIT.custom_goal_sampler == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
Example 15
def train_vae(variant, return_data=False):
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    generate_vae_dataset_fctn = variant.get("generate_vae_data_fctn",
                                            generate_vae_dataset)
    train_data, test_data, info = generate_vae_dataset_fctn(
        variant["generate_vae_dataset_kwargs"])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if "beta_schedule_kwargs" in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant["beta_schedule_kwargs"])
    else:
        beta_schedule = None
    if variant.get("decoder_activation", None) == "sigmoid":
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant["vae_kwargs"].get("architecture", None)
    if not architecture and variant.get("imsize") == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get("imsize") == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant["vae_kwargs"]["architecture"] = architecture
    variant["vae_kwargs"]["imsize"] = variant.get("imsize")

    m = ConvVAE(representation_size,
                decoder_output_activation=decoder_activation,
                **variant["vae_kwargs"])
    m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant["algo_kwargs"])
    save_period = variant["save_period"]
    dump_skew_debug_plots = variant.get("dump_skew_debug_plots", False)
    for epoch in range(variant["num_epochs"]):
        should_save_imgs = epoch % save_period == 0
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
            # save_vae=False,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
        t.update_train_weights()
    logger.save_extra_data(m, "vae.pkl", mode="pickle")
    if return_data:
        return m, train_data, test_data
    return m
Example 16
def train_vae(
        variant,
        return_data=False,
        skewfit_variant=None):  # actually trains both the VAE and the LSTM
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import (
        ConvVAE, )
    import rlkit.torch.vae.conv_vae as conv_vae
    import ROLL.LSTM_model as LSTM_model
    from ROLL.LSTM_model import ConvLSTM2
    from ROLL.LSTM_trainer import ConvLSTMTrainer
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch
    seg_pretrain = variant['seg_pretrain']
    ori_pretrain = variant['ori_pretrain']
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    generate_lstm_dataset_fctn = variant.get('generate_lstm_data_fctn')
    assert generate_lstm_dataset_fctn is not None, "Must provide a custom generate lstm pretraining dataset function!"

    train_data_lstm, test_data_lstm, info_lstm = generate_lstm_dataset_fctn(
        variant['generate_lstm_dataset_kwargs'],
        segmented=True,
        segmentation_method=skewfit_variant['segmentation_method'])

    train_data_ori, test_data_ori, info_ori = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])

    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    architecture = variant['lstm_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = None  # TODO LSTM: add an imsize-84 LSTM architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = LSTM_model.imsize48_default_architecture
    variant['lstm_kwargs']['architecture'] = architecture
    variant['lstm_kwargs']['imsize'] = variant.get('imsize')

    train_datas = [
        train_data_lstm,
        train_data_ori,
    ]
    test_datas = [
        test_data_lstm,
        test_data_ori,
    ]
    names = [
        'lstm_seg_pretrain',
        'vae_ori_pretrain',
    ]
    vaes = []
    env_id = variant['generate_lstm_dataset_kwargs'].get('env_id')
    assert env_id is not None
    lstm_pretrain_vae_only = variant.get('lstm_pretrain_vae_only', False)

    for idx in range(2):
        train_data, test_data, name = train_datas[idx], test_datas[idx], names[
            idx]

        logger.add_tabular_output('{}_progress.csv'.format(name),
                                  relative_to_snapshot_dir=True)

        if idx == 1:  # train the original vae
            representation_size = variant.get(
                "vae_representation_size", variant.get('representation_size'))
            beta = variant.get('vae_beta', variant.get('beta'))
            m = ConvVAE(representation_size,
                        decoder_output_activation=decoder_activation,
                        **variant['vae_kwargs'])
            t = ConvVAETrainer(train_data,
                               test_data,
                               m,
                               beta=beta,
                               beta_schedule=beta_schedule,
                               **variant['algo_kwargs'])
        else:  # train the segmentation lstm
            lstm_version = variant.get('lstm_version', 2)
            if lstm_version == 2:
                lstm_class = ConvLSTM2

            representation_size = variant.get(
                "lstm_representation_size", variant.get('representation_size'))
            beta = variant.get('lstm_beta', variant.get('beta'))
            m = lstm_class(representation_size,
                           decoder_output_activation=decoder_activation,
                           **variant['lstm_kwargs'])
            t = ConvLSTMTrainer(train_data,
                                test_data,
                                m,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])

        m.to(ptu.device)

        vaes.append(m)

        print("test data len: ", len(test_data))
        print("train data len: ", len(train_data))

        save_period = variant['save_period']

        pjhome = os.environ['PJHOME']
        if env_id == 'SawyerPushHurdle-v0' and osp.exists(
                osp.join(
                    pjhome,
                    'data/local/pre-train-lstm', '{}-{}-{}-0.3-0.5.npy'.format(
                        'SawyerPushHurdle-v0', 'seg-color', '500'))):
            data_file_path = osp.join(
                pjhome, 'data/local/pre-train-lstm',
                '{}-{}-{}-0.3-0.5.npy'.format(env_id, 'seg-color', 500))
            puck_pos_path = osp.join(
                pjhome, 'data/local/pre-train-lstm',
                '{}-{}-{}-0.3-0.5-puck-pos.npy'.format(env_id, 'seg-color',
                                                       500))
            all_data = np.load(data_file_path)
            puck_pos = np.load(puck_pos_path)
            all_data = normalize_image(all_data.copy())
            obj_states = puck_pos
        else:
            all_data = np.concatenate([train_data_lstm, test_data_lstm],
                                      axis=0)
            all_data = normalize_image(all_data.copy())
            obj_states = info_lstm['obj_state']

        obj = 'door' if 'Door' in env_id else 'puck'

        num_epochs = variant['num_lstm_epochs'] if idx == 0 else variant[
            'num_vae_epochs']

        if (idx == 0 and seg_pretrain) or (idx == 1 and ori_pretrain):
            for epoch in range(num_epochs):
                should_save_imgs = (epoch % save_period == 0)
                if idx == 0:  # only LSTM trainer has 'only_train_vae' argument
                    t.train_epoch(epoch, only_train_vae=lstm_pretrain_vae_only)
                    t.test_epoch(epoch,
                                 save_reconstruction=should_save_imgs,
                                 save_prefix='r_' + name,
                                 only_train_vae=lstm_pretrain_vae_only)
                else:
                    t.train_epoch(epoch)
                    t.test_epoch(
                        epoch,
                        save_reconstruction=should_save_imgs,
                        save_prefix='r_' + name,
                    )

                if should_save_imgs:
                    t.dump_samples(epoch, save_prefix='s_' + name)

                    if idx == 0:
                        compare_latent_distance(
                            m,
                            all_data,
                            obj_states,
                            obj_name=obj,
                            save_dir=logger.get_snapshot_dir(),
                            save_name='lstm_latent_distance_{}.png'.format(
                                epoch))
                        test_lstm_traj(
                            env_id,
                            m,
                            save_path=logger.get_snapshot_dir(),
                            save_name='lstm_test_traj_{}.png'.format(epoch))
                        test_masked_traj_lstm(
                            env_id,
                            m,
                            save_dir=logger.get_snapshot_dir(),
                            save_name='masked_test_{}.png'.format(epoch))

                t.update_train_weights()

            logger.save_extra_data(m, '{}.pkl'.format(name), mode='pickle')

        logger.remove_tabular_output('{}_progress.csv'.format(name),
                                     relative_to_snapshot_dir=True)

        if idx == 0 and variant.get("only_train_lstm", False):
            exit()

    if return_data:
        return vaes, train_datas, test_datas
    return m