def experiment(variant):
    """Train a ConvVAE on data produced by ``generate_vae_dataset``.

    Args:
        variant: experiment config dict. Keys read: ``beta``,
            ``representation_size``, ``get_data_kwargs``, ``algo_kwargs``,
            ``save_period``, ``num_epochs``; optional ``beta_schedule_kwargs``
            and ``gpu_id``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset(
        **variant['get_data_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    m = ConvVAE(representation_size, input_channels=3)
    if ptu.gpu_enabled():
        # Select the GPU *before* moving the model. Previously the model was
        # moved to the old ``ptu.device`` and the device was changed
        # afterwards, so the model could end up on a different GPU than the
        # tensors created later on ``ptu.device``.
        gpu_id = variant.get("gpu_id", None)
        if gpu_id is not None:
            ptu.set_device(gpu_id)
        m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_scatterplot=should_save_imgs)
        if should_save_imgs:
            t.dump_samples(epoch)
# Example #2
# 0
def experiment(variant):
    """Train a fully connected VAE, or a plain auto-encoder, on ``generate_dataset()``.

    When ``beta_schedule_kwargs`` is present in ``variant`` it is edited in
    place before the schedule is built: ``y_values[2]`` becomes ``beta``,
    ``x_values[1]`` becomes ``flat_x`` and ``x_values[2]`` becomes
    ``ramp_x + flat_x``. ``variant['algo_kwargs']['is_auto_encoder']``
    selects ``AutoEncoder`` instead of ``VAE``.
    """
    from railrl.core import logger
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_dataset()
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule_kwargs = variant['beta_schedule_kwargs']
        schedule_kwargs['y_values'][2] = variant['beta']
        schedule_kwargs['x_values'][1] = variant['flat_x']
        schedule_kwargs['x_values'][2] = variant['ramp_x'] + variant['flat_x']
        beta_schedule = PiecewiseLinearSchedule(**schedule_kwargs)

    output_scale = 1
    if variant['algo_kwargs']['is_auto_encoder']:
        model_class = AutoEncoder
    else:
        model_class = VAE
    m = model_class(
        representation_size,
        train_data.shape[1],
        output_scale=output_scale,
        **variant['vae_kwargs']
    )

    trainer = VAETrainer(
        train_data,
        test_data,
        m,
        beta=beta,
        beta_schedule=beta_schedule,
        **variant['algo_kwargs']
    )
    for epoch in range(variant['num_epochs']):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
def train_vae(variant, return_data=False):
    """Build a VAE with ``get_vae`` and train it with a configurable trainer.

    Args:
        variant: experiment config dict. Keys read: ``beta``,
            ``generate_vae_dataset_kwargs`` (updated in place with
            ``use_linear_dynamics``), ``algo_kwargs``, ``save_period``,
            ``num_epochs``; optional ``use_linear_dynamics``,
            ``generate_vae_data_fctn``, ``beta_schedule_kwargs``,
            ``vae_trainer_class`` and ``dump_skew_debug_plots``.
        return_data: when True, also return the train/test datasets.

    Returns:
        The model, or ``(model, train_dataset, test_dataset)`` when
        ``return_data`` is True.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    beta = variant["beta"]
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    make_dataset = variant.get('generate_vae_data_fctn', generate_vae_dataset)
    dataset_kwargs = variant['generate_vae_dataset_kwargs']
    dataset_kwargs['use_linear_dynamics'] = use_linear_dynamics
    train_dataset, test_dataset, info = make_dataset(dataset_kwargs)

    # The action dimension only matters for the linear-dynamics model.
    action_dim = (
        train_dataset.data['actions'].shape[2] if use_linear_dynamics else 0
    )
    model = get_vae(variant, action_dim)

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    beta_schedule = (
        PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])
        if 'beta_schedule_kwargs' in variant
        else None
    )

    trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = trainer_class(
        model,
        beta=beta,
        beta_schedule=beta_schedule,
        **variant['algo_kwargs']
    )
    save_period = variant['save_period']
    debug_plots = variant.get('dump_skew_debug_plots', False)

    for epoch in range(variant['num_epochs']):
        save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        for key, value in trainer.get_diagnostics().items():
            logger.record_tabular(key, value)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        # Parameter snapshot every 50 epochs.
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)

    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if not return_data:
        return model
    return model, train_dataset, test_dataset
def experiment(variant):
    """Train ``variant['vae']`` and log its loss on a uniform goal dataset.

    The uniform dataset is loaded from ``variant['uniform_dataset_path']`` and
    its ``'image_desired_goal'`` entry is passed through ``unormalize_image``.
    When ``beta_schedule_kwargs`` is present, its last ``y_values`` entry is
    overwritten in place with ``variant['beta']``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    make_dataset = variant['generate_vae_dataset_fn']
    train_data, test_data, info = make_dataset(
        variant['generate_vae_dataset_kwargs'])
    raw_uniform = load_local_or_remote_file(
        variant['uniform_dataset_path']).item()
    uniform_dataset = unormalize_image(raw_uniform['image_desired_goal'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule_kwargs = variant['beta_schedule_kwargs']
        # Force the schedule's final value to the configured beta.
        schedule_kwargs['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(**schedule_kwargs)

    m = variant['vae'](
        representation_size,
        decoder_output_activation=nn.Sigmoid(),
        **variant['vae_kwargs']
    )
    m.to(ptu.device)
    trainer = ConvVAETrainer(
        train_data,
        test_data,
        m,
        beta=beta,
        beta_schedule=beta_schedule,
        **variant['algo_kwargs']
    )
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.log_loss_under_uniform(
            m,
            uniform_dataset,
            variant['algo_kwargs']['priority_function_kwargs'],
        )
        trainer.test_epoch(
            epoch,
            save_reconstruction=save_imgs,
            save_scatterplot=save_imgs,
        )
        if save_imgs:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
                trainer.dump_uniform_imgs_and_reconstructions(
                    dataset=uniform_dataset, epoch=epoch)
        trainer.update_train_weights()
def experiment(variant):
    """Train a ConvVAE on pre-generated image shards stored as ``.npy`` files.

    Shard ``i`` (1-based) is loaded from ``data_path_prefix + str(i) + '.npy'``
    and must hold ``images_per_division`` rows of length ``image_dim``. The
    shard layout and split are now configurable through ``variant``; the
    defaults equal the values that used to be hard-coded, so existing
    configs behave as before.

    Args:
        variant: config dict. Keys read: ``beta``, ``representation_size``,
            ``conv_vae_kwargs``, ``algo_kwargs``, ``save_period``,
            ``num_epochs``. Optional: ``num_divisions`` (5),
            ``images_per_division`` (10000), ``image_dim`` (21168),
            ``data_path_prefix``, ``train_fraction`` (0.9), and
            ``beta_schedule_kwargs``. If ``beta_schedule_kwargs`` is given,
            it is edited in place from ``beta``, ``flat_x`` and ``ramp_x``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    num_divisions = variant.get('num_divisions', 5)
    images_per_division = variant.get('images_per_division', 10000)
    # 21168 == 84 * 84 * 3, presumably flattened 84x84 RGB images -- confirm.
    image_dim = variant.get('image_dim', 21168)
    data_path_prefix = variant.get(
        'data_path_prefix',
        '/home/murtaza/vae_data/sawyer_torque_control_images100000_')
    train_fraction = variant.get('train_fraction', .9)

    images = np.zeros((num_divisions * images_per_division, image_dim))
    for i in range(num_divisions):
        start = i * images_per_division
        images[start:start + images_per_division] = np.load(
            data_path_prefix + str(i + 1) + '.npy')
        print(i)
    mid = int(num_divisions * images_per_division * train_fraction)
    train_data, test_data = images[:mid], images[mid:]
    info = dict()

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        kwargs = variant['beta_schedule_kwargs']
        kwargs['y_values'][2] = variant['beta']
        kwargs['x_values'][1] = variant['flat_x']
        kwargs['x_values'][2] = variant['ramp_x'] + variant['flat_x']
        beta_schedule = PiecewiseLinearSchedule(**kwargs)
    else:
        beta_schedule = None
    m = ConvVAE(representation_size,
                input_channels=3,
                **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        m.cuda()
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_scatterplot=should_save_imgs)
        if should_save_imgs:
            t.dump_samples(epoch)
# Example #6
# 0
def experiment(variant):
    """Train ``variant['vae']`` on observations loaded with joblib.

    ``variant['file']`` must contain ``'obs'`` and ``'size'``. The first
    ``size`` rows are split 90/10 into train/test. Loss on a uniform door
    dataset is logged every epoch. When ``beta_schedule_kwargs`` is present,
    its last ``y_values`` entry is overwritten in place with
    ``variant['beta']``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    loaded = joblib.load(variant['file'])
    all_obs = loaded['obs']
    size = int(loaded['size'])
    dataset = all_obs[:size, :]
    split = int(size * .9)
    train_data, test_data = dataset[:split, :], dataset[split:, :]
    logger.get_snapshot_dir()
    print('SIZE: ', size)
    uniform_dataset = generate_uniform_dataset_door(
        **variant['generate_uniform_dataset_kwargs']
    )
    logger.get_snapshot_dir()

    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule_kwargs = variant['beta_schedule_kwargs']
        schedule_kwargs['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(**schedule_kwargs)

    m = variant['vae'](
        representation_size,
        decoder_output_activation=nn.Sigmoid(),
        **variant['vae_kwargs']
    )
    m.to(ptu.device)
    trainer = ConvVAETrainer(
        train_data,
        test_data,
        m,
        beta=beta,
        beta_schedule=beta_schedule,
        **variant['algo_kwargs']
    )
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.log_loss_under_uniform(uniform_dataset)
        trainer.test_epoch(
            epoch,
            save_reconstruction=save_imgs,
            save_scatterplot=save_imgs,
        )
        if save_imgs:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
                trainer.dump_uniform_imgs_and_reconstructions(
                    dataset=uniform_dataset, epoch=epoch)
        trainer.update_train_weights()
def experiment(variant):
    """Train a ConvVAE in state-similarity debug mode on paired images/states.

    The image and state arrays are loaded from fixed ``.npy`` paths. State
    columns 7..13 are dropped, and the pairs are split 90/10 with
    ``train_test_split``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    # The data contains both states and images, so generate_vae_dataset
    # cannot be used here.
    images = np.load(
        '/home/murtaza/vae_data/sawyer_torque_control_ou_imgs_zoomed_out10000.npy'
    )
    states = np.load(
        '/home/murtaza/vae_data/sawyer_torque_control_ou_states_zoomed_out10000.npy'
    )
    # Keep columns [0, 7) and [14, end).
    states = np.concatenate((states[:, :7], states[:, 14:]), axis=1)
    imgs_train, imgs_test, states_train, states_test = train_test_split(
        images, states, test_size=.1)
    logger.save_extra_data(dict())
    logger.get_snapshot_dir()

    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])

    m = ConvVAE(
        representation_size,
        input_channels=3,
        state_sim_debug=True,
        state_size=states.shape[1],
        **variant['conv_vae_kwargs']
    )
    if ptu.gpu_enabled():
        m.cuda()
    trainer = ConvVAETrainer(
        (imgs_train, states_train),
        (imgs_test, states_test),
        m,
        beta=beta,
        beta_schedule=beta_schedule,
        state_sim_debug=True,
        **variant['algo_kwargs']
    )
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=save_imgs,
            save_scatterplot=save_imgs,
        )
        if save_imgs:
            trainer.dump_samples(epoch)
# Example #8
# 0
def experiment(variant):
    """Train a ConvVAE for 1001 epochs on ``get_data(10000)``.

    Uses a fixed piecewise-linear beta schedule with breakpoints
    ``[0, 400, 800]`` and values ``[0.5, 0.5, beta]``. Optionally selects
    the GPU given by ``variant['gpu_id']`` first.
    """
    if variant["use_gpu"]:
        device_id = variant["gpu_id"]
        ptu.set_gpu_mode(True)
        ptu.set_device(device_id)

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data = get_data(10000)
    model = ConvVAE(representation_size, input_channels=3)
    schedule = PiecewiseLinearSchedule([0, 400, 800], [0.5, 0.5, beta])
    trainer = ConvVAETrainer(
        train_data,
        test_data,
        model,
        beta_schedule=schedule,
    )
    for epoch in range(1001):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
        trainer.dump_samples(epoch)
def experiment(variant):
    """Train ``variant['vae']`` on data from ``variant['generate_vae_dataset_fn']``.

    When ``beta_schedule_kwargs`` is present, its last ``y_values`` entry is
    overwritten in place with ``variant['beta']`` before the schedule is
    built. ``update_train_weights`` is called on the trainer after every
    epoch.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    make_dataset = variant['generate_vae_dataset_fn']
    train_data, test_data, info = make_dataset(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule_kwargs = variant['beta_schedule_kwargs']
        schedule_kwargs['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(**schedule_kwargs)

    m = variant['vae'](representation_size, **variant['vae_kwargs'])
    m.to(ptu.device)
    trainer = ConvVAETrainer(
        train_data,
        test_data,
        m,
        beta=beta,
        beta_schedule=beta_schedule,
        **variant['algo_kwargs']
    )
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=save_imgs,
            save_scatterplot=save_imgs,
        )
        if save_imgs:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
        trainer.update_train_weights()
# Example #10
# 0
def experiment(variant):
    """Train a ConvVAE on data from ``get_data``, dumping samples every epoch.

    Args:
        variant: config dict. Keys read: ``beta``, ``representation_size``,
            ``get_data_kwargs``, ``algo_kwargs``, ``num_epochs``. The
            ``beta_schedule_kwargs`` key is optional: when it is absent, no
            schedule is used. This matches the other ``experiment`` variants
            in this file; previously a missing key raised ``KeyError``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = get_data(**variant['get_data_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    m = ConvVAE(representation_size, input_channels=3)
    if ptu.gpu_enabled():
        m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    for epoch in range(variant['num_epochs']):
        t.train_epoch(epoch)
        t.test_epoch(epoch)
        t.dump_samples(epoch)
def experiment(variant):
    """Train a ConvVAE (with ``conv_vae_kwargs``) on ``generate_vae_dataset`` data.

    When ``beta_schedule_kwargs`` is present, its last ``y_values`` entry is
    overwritten in place with ``variant['beta']`` before the schedule is
    built.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset(
        **variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule_kwargs = variant['beta_schedule_kwargs']
        schedule_kwargs['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(**schedule_kwargs)

    model = ConvVAE(
        representation_size,
        input_channels=3,
        **variant['conv_vae_kwargs']
    )
    if ptu.gpu_enabled():
        model.cuda()
    trainer = ConvVAETrainer(
        train_data,
        test_data,
        model,
        beta=beta,
        beta_schedule=beta_schedule,
        **variant['algo_kwargs']
    )
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        save_imgs = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=save_imgs,
            save_scatterplot=save_imgs,
        )
        if save_imgs:
            trainer.dump_samples(epoch)
# Example #12
# 0
def experiment(variant):
    """Train an ACAI model on data from ``generate_vae_dataset``.

    Args:
        variant: config dict. Keys read: ``lmbda``, ``gamma``, ``mu``,
            ``beta``, ``representation_size``,
            ``generate_vae_dataset_kwargs``, ``save_period``. The
            ``num_epochs`` key is optional and defaults to 6001, the value
            that was previously hard-coded.
    """
    lmbda = variant['lmbda']
    gamma = variant['gamma']
    mu = variant['mu']

    # Piecewise-linear schedule with breakpoints [0, 2500, 3500]:
    # 0 until 2500, reaching variant['beta'] at 3500.
    beta = PiecewiseLinearSchedule([0, 2500, 3500], [0, 0, variant['beta']])
    representation_size = variant["representation_size"]
    train_data, test_data, _info = generate_vae_dataset(
        **variant['generate_vae_dataset_kwargs'])
    m = ACAI(representation_size, input_channels=3)
    t = ACAITrainer(train_data,
                    test_data,
                    m,
                    beta_schedule=beta,
                    gamma=gamma,
                    mu=mu,
                    lmbda=lmbda)
    num_epochs = variant.get('num_epochs', 6001)
    for epoch in range(num_epochs):
        t.train_epoch(epoch)
        t.test_epoch(epoch)
        if epoch % variant['save_period'] == 0:
            t.dump_samples(epoch)
# Example #13
# 0
def train_vae(variant, return_data=False):
    """Train a VAE (or auto-encoder) on a dataset generated from demonstrations.

    Args:
        variant: experiment config dict. Keys read: ``beta``,
            ``representation_size``, ``generate_vae_dataset_kwargs``,
            ``vae_kwargs``, ``algo_kwargs``, ``save_period``, ``num_epochs``.
            Optional keys: ``beta_schedule_kwargs``, ``decoder_activation``
            (``'sigmoid'`` or identity), ``imsize``,
            ``use_spatial_auto_encoder``, ``vae_class`` and
            ``dump_skew_debug_plots``. ``variant['vae_kwargs']`` is updated
            in place with the resolved ``architecture`` and ``imsize``.
        return_data: when True, also return the train and test data.

    Returns:
        The trained model, or ``(model, train_data, test_data)`` when
        ``return_data`` is True.

    Raises:
        NotImplementedError: if ``use_spatial_auto_encoder`` is set and
            ``algo_kwargs['is_auto_encoder']`` is not.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.conv_vae import (
        ConvVAE,
        AutoEncoder,
    )
    import railrl.torch.vae.conv_vae as conv_vae
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset_from_demos(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    # Fall back to the default architecture for the image size when none is
    # configured explicitly.
    imsize = variant.get('imsize')
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture:
        if imsize == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif imsize == 48:
            architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = imsize

    if variant['algo_kwargs'].get('is_auto_encoder', False):
        m = AutoEncoder(representation_size,
                        decoder_output_activation=decoder_activation,
                        **variant['vae_kwargs'])
    elif variant.get('use_spatial_auto_encoder', False):
        # The SpatialAutoEncoder path is known to be broken, so fail loudly
        # rather than build it. (The unreachable construction code that used
        # to follow this raise has been removed.)
        raise NotImplementedError(
            'This is currently broken, please update SpatialAutoEncoder then remove this line'
        )
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        m = vae_class(representation_size,
                      decoder_output_activation=decoder_activation,
                      **variant['vae_kwargs'])
    m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
            save_scatterplot=should_save_imgs,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
            if dump_skew_debug_plots:
                t.dump_best_reconstruction(epoch)
                t.dump_worst_reconstruction(epoch)
                t.dump_sampling_histogram(epoch)
        t.update_train_weights()
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    if return_data:
        return m, train_data, test_data
    return m
class OnlineVaeRelabelingBuffer(ObsDictRelabelingBuffer):

    def __init__(
            self,
            vae,
            *args,
            decoded_obs_key='image_observation',
            decoded_achieved_goal_key='image_achieved_goal',
            decoded_desired_goal_key='image_desired_goal',
            exploration_rewards_type='None',
            exploration_rewards_scale=1.0,
            vae_priority_type='None',
            start_skew_epoch=0,
            power=1.0,
            internal_keys=None,
            exploration_schedule_kwargs=None,
            priority_function_kwargs=None,
            exploration_counter_kwargs=None,
            relabeling_goal_sampling_mode='vae_prior',
            decode_vae_goals=False,
            **kwargs
    ):
        """Relabeling buffer that stores decoded images and re-encodes them with ``vae``.

        Args:
            vae: model whose ``representation_size`` sizes the latent
                statistics computed in ``refresh_latents``.
            *args, **kwargs: forwarded to ``ObsDictRelabelingBuffer``.
            decoded_obs_key, decoded_achieved_goal_key,
            decoded_desired_goal_key: observation keys holding images; they
                are always added to ``internal_keys``.
            exploration_rewards_type: name of the exploration-bonus function
                (a key of the table built below). ``'None'`` disables the
                bonus.
            exploration_rewards_scale: constant bonus scale used when
                ``exploration_schedule_kwargs`` is not given. ``0.`` disables
                the bonus.
            vae_priority_type: name of the sample-priority function (same
                table). ``'None'`` disables prioritisation.
            start_skew_epoch: prioritised sampling is only used once
                ``epoch > start_skew_epoch``.
            power: exponent applied to the priorities (except for
                ``'vae_prob'``). ``0.`` disables prioritisation.
            internal_keys: additional internal keys. The list is copied, so
                the caller's list is not modified.
            exploration_schedule_kwargs: if given, kwargs for a
                ``PiecewiseLinearSchedule`` of the bonus scale.
            priority_function_kwargs: extra kwargs passed to the reward and
                priority functions.
            exploration_counter_kwargs: kwargs for ``CountExploration`` when
                ``exploration_rewards_type == 'hash_count'``.
            relabeling_goal_sampling_mode: goal sampling mode set on the env
                before sampling relabeling goals from it.
            decode_vae_goals: if True, ``add_path`` decodes the latent
                desired goals into images before storing a path.

        Raises:
            KeyError: if ``exploration_rewards_type`` or ``vae_priority_type``
                is not a known function name.
        """
        # Copy so that appending the decoded keys never mutates a list owned
        # by the caller (e.g. one reused across several buffers).
        internal_keys = [] if internal_keys is None else list(internal_keys)

        for key in [
            decoded_obs_key,
            decoded_achieved_goal_key,
            decoded_desired_goal_key
        ]:
            if key not in internal_keys:
                internal_keys.append(key)
        super().__init__(internal_keys=internal_keys, *args, **kwargs)
        # assert isinstance(self.env, VAEWrappedEnv)
        self.vae = vae
        self.decoded_obs_key = decoded_obs_key
        self.decoded_desired_goal_key = decoded_desired_goal_key
        self.decoded_achieved_goal_key = decoded_achieved_goal_key
        self.exploration_rewards_type = exploration_rewards_type
        self.exploration_rewards_scale = exploration_rewards_scale
        self.start_skew_epoch = start_skew_epoch
        self.vae_priority_type = vae_priority_type
        self.power = power
        self._relabeling_goal_sampling_mode = relabeling_goal_sampling_mode
        self.decode_vae_goals = decode_vae_goals

        if exploration_schedule_kwargs is None:
            self.explr_reward_scale_schedule = \
                ConstantSchedule(self.exploration_rewards_scale)
        else:
            self.explr_reward_scale_schedule = \
                PiecewiseLinearSchedule(**exploration_schedule_kwargs)

        self._give_explr_reward_bonus = (
                exploration_rewards_type != 'None'
                and exploration_rewards_scale != 0.
        )
        self._exploration_rewards = np.zeros((self.max_size, 1), dtype=np.float32)
        self._prioritize_vae_samples = (
                vae_priority_type != 'None'
                and power != 0.
        )
        self._vae_sample_priorities = np.zeros((self.max_size, 1), dtype=np.float32)
        self._vae_sample_probs = None

        self.use_dynamics_model = (
                self.exploration_rewards_type == 'forward_model_error'
        )
        if self.use_dynamics_model:
            self.initialize_dynamics_model()

        type_to_function = {
            'reconstruction_error': self.reconstruction_mse,
            'bce': self.binary_cross_entropy,
            'latent_distance': self.latent_novelty,
            'latent_distance_true_prior': self.latent_novelty_true_prior,
            'forward_model_error': self.forward_model_error,
            'gaussian_inv_prob': self.gaussian_inv_prob,
            'bernoulli_inv_prob': self.bernoulli_inv_prob,
            'vae_prob': self.vae_prob,
            'hash_count': self.hash_count_reward,
            'None': self.no_reward,
        }

        self.exploration_reward_func = (
            type_to_function[self.exploration_rewards_type]
        )
        self.vae_prioritization_func = (
            type_to_function[self.vae_priority_type]
        )

        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        if self.exploration_rewards_type == 'hash_count':
            if exploration_counter_kwargs is None:
                exploration_counter_kwargs = dict()
            self.exploration_counter = CountExploration(env=self.env, **exploration_counter_kwargs)
        self.epoch = 0
        # ``refresh_latents`` sets this for real. Initialising it here means
        # ``sample_weighted_indices`` never sees a missing attribute.
        self.skew = False

    def add_path(self, path):
        """Store ``path``. When ``decode_vae_goals`` is set, decoded goal images are attached first."""
        should_decode = self.decode_vae_goals
        if should_decode:
            self.add_decoded_vae_goals_to_path(path)
        super().add_path(path)

    def add_decoded_vae_goals_to_path(self, path):
        """Decode the path's latent desired goals and attach them as images.

        The goals are decoded in one batch here, rather than per step in the
        env, for efficiency. Each decoded goal is flattened and written under
        ``decoded_desired_goal_key`` into both the step's observation and its
        next observation. ``path`` is modified in place.
        """
        goal_key = self.desired_goal_key
        latent_goals = flatten_dict(path['observations'], [goal_key])[goal_key]
        decoded = self.env._decode(latent_goals)
        decoded = decoded.reshape(len(decoded), -1)
        target_key = self.decoded_desired_goal_key
        for idx, obs in enumerate(path['observations']):
            goal_image = decoded[idx]
            obs[target_key] = goal_image
            path['next_observations'][idx][target_key] = goal_image

    def random_batch(self, batch_size):
        """Sample a batch, adding the scheduled exploration bonus to its rewards when enabled.

        When the bonus is active, the stored per-transition bonus is exposed as
        ``batch['exploration_rewards']`` and added in place to
        ``batch['rewards']``, scaled by the schedule's value at ``self.epoch``.
        """
        batch = super().random_batch(batch_size)
        bonus_scale = float(self.explr_reward_scale_schedule.get_value(self.epoch))
        if not self._give_explr_reward_bonus:
            return batch
        sampled_idxs = batch['indices'].flatten()
        bonus = self._exploration_rewards[sampled_idxs]
        batch['exploration_rewards'] = bonus
        batch['rewards'] += bonus_scale * bonus
        return batch

    def get_diagnostics(self):
        """Return summary stats for the current VAE sample weights and probabilities.

        All-zero arrays of length ``self._size`` are reported until sample
        probabilities have been computed.
        """
        not_ready = (
            self._vae_sample_probs is None
            or self._vae_sample_priorities is None
        )
        if not_ready:
            weights = np.zeros(self._size)
            probs = np.zeros(self._size)
        else:
            weights = self._vae_sample_priorities[:self._size]
            probs = self._vae_sample_probs[:self._size]
        stats = create_stats_ordered_dict('VAE Sample Weights', weights)
        stats.update(create_stats_ordered_dict('VAE Sample Probs', probs))
        return stats

    def refresh_latents(self, epoch):
        """Re-encode all stored images with the current VAE and refresh derived data.

        The buffer is processed in chunks of 512 transitions. For each chunk
        this method:

        * re-encodes the decoded observation, next observation, desired goal
          and achieved goal images;
        * recomputes exploration rewards and VAE sample priorities when those
          are enabled.

        It then sets ``vae.dist_mu`` / ``vae.dist_std`` from the re-encoded
        observations. Finally it turns the priorities into a normalised
        sampling distribution stored in ``self._vae_sample_probs``.

        Args:
            epoch: current training epoch. It is stored in ``self.epoch`` and
                decides whether skewed sampling is active.
        """
        self.epoch = epoch
        self.skew = (self.epoch > self.start_skew_epoch)
        batch_size = 512

        if self.exploration_rewards_type == 'hash_count':
            # Everything has to be counted before any exploration reward is
            # computed.
            for start in range(0, self._size, batch_size):
                idxs = np.arange(start, min(start + batch_size, self._size))
                normalized_imgs = self._next_obs[self.decoded_obs_key][idxs]
                self.update_hash_count(normalized_imgs)

        obs_sum = np.zeros(self.vae.representation_size)
        obs_square_sum = np.zeros(self.vae.representation_size)
        # Chunk boundaries are recomputed here. Previously the end index was
        # shared with the hash-count pass above, which left it at
        # ``self._size``, so this loop encoded the whole buffer in one batch.
        for start in range(0, self._size, batch_size):
            idxs = np.arange(start, min(start + batch_size, self._size))
            self._obs[self.observation_key][idxs] = \
                self.env._encode(self._obs[self.decoded_obs_key][idxs])
            self._next_obs[self.observation_key][idxs] = \
                self.env._encode(self._next_obs[self.decoded_obs_key][idxs])
            # WARNING: we only refresh the desired/achieved latents for
            # "next_obs". This means that obs[desired/achieve] will be invalid,
            # so make sure there's no code that references this.
            # TODO: enforce this with code and not a comment
            self._next_obs[self.desired_goal_key][idxs] = \
                self.env._encode(self._next_obs[self.decoded_desired_goal_key][idxs])
            self._next_obs[self.achieved_goal_key][idxs] = \
                self.env._encode(self._next_obs[self.decoded_achieved_goal_key][idxs])
            normalized_imgs = self._next_obs[self.decoded_obs_key][idxs]
            if self._give_explr_reward_bonus:
                rewards = self.exploration_reward_func(
                    normalized_imgs,
                    idxs,
                    **self.priority_function_kwargs
                )
                self._exploration_rewards[idxs] = rewards.reshape(-1, 1)
            if self._prioritize_vae_samples:
                if (
                        self.exploration_rewards_type == self.vae_priority_type
                        and self._give_explr_reward_bonus
                ):
                    # Same function as the exploration bonus: reuse it.
                    self._vae_sample_priorities[idxs] = (
                        self._exploration_rewards[idxs]
                    )
                else:
                    self._vae_sample_priorities[idxs] = (
                        self.vae_prioritization_func(
                            normalized_imgs,
                            idxs,
                            **self.priority_function_kwargs
                        ).reshape(-1, 1)
                    )
            obs_sum += self._obs[self.observation_key][idxs].sum(axis=0)
            obs_square_sum += np.power(self._obs[self.observation_key][idxs], 2).sum(axis=0)

        self.vae.dist_mu = obs_sum / self._size
        # E[x^2] - E[x]^2 can come out slightly negative from floating-point
        # cancellation. Clamp it so sqrt does not produce NaN.
        variance = obs_square_sum / self._size - np.power(self.vae.dist_mu, 2)
        self.vae.dist_std = np.sqrt(np.maximum(variance, 0))

        if self._prioritize_vae_samples:
            # priority^power is calculated in the priority function for
            # image_bernoulli_prob or image_gaussian_inv_prob, and directly
            # here if not.
            if self.vae_priority_type == 'vae_prob':
                self._vae_sample_priorities[:self._size] = relative_probs_from_log_probs(
                    self._vae_sample_priorities[:self._size]
                )
                self._vae_sample_probs = self._vae_sample_priorities[:self._size]
            else:
                self._vae_sample_probs = self._vae_sample_priorities[:self._size] ** self.power
            p_sum = np.sum(self._vae_sample_probs)
            assert p_sum > 0, "Unnormalized p sum is {}".format(p_sum)
            self._vae_sample_probs /= p_sum
            self._vae_sample_probs = self._vae_sample_probs.flatten()

    def sample_weighted_indices(self, batch_size):
        """Draw ``batch_size`` buffer indices, skewed by VAE priorities if enabled.

        Weighted sampling is used only when VAE-sample prioritization is on,
        ``self._vae_sample_probs`` has been computed and ``self.skew`` is
        truthy; otherwise this falls back to ``self._sample_indices``.

        :param batch_size: number of indices to draw (``np.random.choice``
            samples with replacement by default).
        :return: the drawn indices.
        """
        use_weights = (
            self._prioritize_vae_samples
            and self._vae_sample_probs is not None
            and self.skew
        )
        if not use_weights:
            return self._sample_indices(batch_size)

        probs = self._vae_sample_probs
        chosen = np.random.choice(len(probs), batch_size, p=probs)
        # Sanity check that the stored weights are valid probabilities.
        assert np.max(probs) <= 1 and np.min(probs) >= 0
        return chosen

    def _sample_goals_from_env(self, batch_size):
        self.env.goal_sampling_mode = self._relabeling_goal_sampling_mode
        return self.env.sample_goals(batch_size)

    def sample_buffer_goals(self, batch_size):
        """
        Sample goals from the weighted replay buffer for relabeling or
        exploration.

        Indices come from ``self.sample_weighted_indices``. At those indices,
        the decoded next observation becomes the decoded desired goal and the
        stored (latent) achieved goal of the next observation becomes the
        desired goal.

        Returns None if the replay buffer is empty.

        Example of what might be returned:
        dict(
            image_desired_goal: next_image_observations[weighted_indices],
            latent_desired_goal: next_latent_achieved_goals[weighted_indices],
        )
        """
        if self._size == 0:
            return None
        idxs = self.sample_weighted_indices(batch_size)
        next_obs = self._next_obs
        return {
            self.decoded_desired_goal_key: next_obs[self.decoded_obs_key][idxs],
            self.desired_goal_key: next_obs[self.achieved_goal_key][idxs],
        }

    def random_vae_training_data(self, batch_size, epoch):
        """Return a (possibly skewed) batch of decoded next observations.

        :param batch_size: number of samples to draw via
            ``sample_weighted_indices``.
        :param epoch: unused; kept for interface compatibility. Whether the
            sampling is skewed is decided by ``self.skew`` instead.
        :return: ``dict(observations=<tensor built by ptu.from_numpy>)``.
        """
        idxs = self.sample_weighted_indices(batch_size)
        batch = self._next_obs[self.decoded_obs_key][idxs]
        return dict(observations=ptu.from_numpy(batch))

    def reconstruction_mse(self, next_vae_obs, indices):
        """Per-sample summed squared reconstruction error of ``self.vae``.

        The squared error is summed over dim 1 (assumes flattened inputs of
        shape (batch, features) — TODO confirm). ``indices`` is unused.

        :return: numpy array with one value per sample.
        """
        inputs = ptu.from_numpy(next_vae_obs)
        reconstructions, _, _ = self.vae(inputs)
        squared_error = (inputs - reconstructions) ** 2
        return ptu.get_numpy(squared_error.sum(dim=1))

    def gaussian_inv_prob(self, next_vae_obs, indices):
        """Return ``exp(reconstruction_mse)`` for each sample.

        NOTE(review): ``np.exp`` overflows to ``inf`` for large errors.
        """
        mse = self.reconstruction_mse(next_vae_obs, indices)
        return np.exp(mse)

    def binary_cross_entropy(self, next_vae_obs, indices):
        """Per-sample ``-sum(x * log(x_hat))`` between inputs and reconstructions.

        NOTE(review): only the ``x * log(x_hat)`` term of a binary
        cross-entropy is computed; the ``(1 - x) * log(1 - x_hat)`` term is
        absent. Left as-is because rewards/priorities depend on this value.
        ``indices`` is unused.
        """
        inputs = ptu.from_numpy(next_vae_obs)
        reconstructions, _, _ = self.vae(inputs)
        # Clamp to avoid log(0); log(1e-30) is roughly -69.
        log_recon = torch.log(torch.clamp(reconstructions, min=1e-30))
        per_element = -inputs * log_recon
        return ptu.get_numpy(torch.sum(per_element, dim=1))

    def bernoulli_inv_prob(self, next_vae_obs, indices):
        """Inverse of the per-sample Bernoulli likelihood of the inputs.

        The likelihood is the product over dim 1 of
        ``x * x_hat + (1 - x) * (1 - x_hat)``. NOTE(review): with many
        features the product can underflow to 0, making the result ``inf``.
        ``indices`` is unused.
        """
        inputs = ptu.from_numpy(next_vae_obs)
        reconstructions, _, _ = self.vae(inputs)
        per_element = (
            inputs * reconstructions
            + (1 - inputs) * (1 - reconstructions)
        )
        likelihood = per_element.prod(dim=1)
        return ptu.get_numpy(1 / likelihood)

    def vae_prob(self, next_vae_obs, indices, **kwargs):
        """Priority from ``compute_p_x_np_to_np`` raised to ``self.power``.

        Extra keyword arguments are forwarded unchanged; ``indices`` is
        unused.
        """
        vae = self.vae
        power = self.power
        return compute_p_x_np_to_np(vae, next_vae_obs, power=power, **kwargs)

    def forward_model_error(self, next_vae_obs, indices):
        """Per-sample prediction error of the learned forward dynamics model.

        Predicts ``next_obs`` from ``(obs, action)`` for the buffer entries at
        ``indices`` and returns, for each entry, the squared error averaged
        over the observation dimension. ``next_vae_obs`` is unused.

        ``self.dynamics_loss`` is deliberately not used here: it is an
        ``MSELoss()`` with the default ``reduction='mean'``, which reduces the
        whole batch to a single scalar, so every sample would receive the same
        exploration reward / priority. The per-sample values average to that
        same scalar, so the overall scale is unchanged.

        :return: numpy array with one error per index.
        """
        obs = self._obs[self.observation_key][indices]
        next_obs = self._next_obs[self.observation_key][indices]
        actions = self._actions[indices]

        state_action_pair = ptu.from_numpy(np.c_[obs, actions])
        # Scoring only: no autograd graph is needed.
        with torch.no_grad():
            prediction = self.dynamics_model(state_action_pair)
            squared_error = (prediction - ptu.from_numpy(next_obs)) ** 2
            per_sample_mse = squared_error.mean(dim=1)
        return ptu.get_numpy(per_sample_mse)

    def latent_novelty(self, next_vae_obs, indices):
        """Sum of squared standardized latent coordinates of each sample.

        Encodes ``next_vae_obs`` with ``self.env._encode``, standardizes with
        ``self.vae.dist_mu`` / ``self.vae.dist_std`` and sums the squares over
        axis 1. ``indices`` is unused.
        """
        latents = self.env._encode(next_vae_obs)
        standardized = (latents - self.vae.dist_mu) / self.vae.dist_std
        return (standardized ** 2).sum(axis=1)

    def latent_novelty_true_prior(self, next_vae_obs, indices):
        """Squared Euclidean norm of each encoded observation.

        ``indices`` is unused.
        """
        latents = self.env._encode(next_vae_obs)
        return (latents ** 2).sum(axis=1)

    def _kl_np_to_np(self, next_vae_obs, indices):
        """Per-sample ``-sum(1 + log_var - mu^2 - exp(log_var))`` of the encoder.

        Assumes ``self.vae.encode`` returns (mean, log-variance) — TODO
        confirm. NOTE(review): the usual Gaussian KL to N(0, I) carries a
        factor 0.5, so this value is twice that KL. ``indices`` is unused.
        """
        inputs = ptu.from_numpy(next_vae_obs)
        mu, log_var = self.vae.encode(inputs)
        terms = 1 + log_var - mu.pow(2) - log_var.exp()
        return ptu.get_numpy(-torch.sum(terms, dim=1))

    def update_hash_count(self, next_vae_obs):
        """Increment the exploration counts with the encoder means of the inputs.

        Always returns None.
        """
        means, _ = self.vae.encode(ptu.from_numpy(next_vae_obs))
        self.exploration_counter.increment_counts(ptu.get_numpy(means))
        return None

    def hash_count_reward(self, next_vae_obs, indices):
        """Count-based exploration reward for the encoded observations.

        ``indices`` is unused.
        """
        encoded = self.env._encode(next_vae_obs)
        counter = self.exploration_counter
        return counter.compute_count_based_reward(encoded)

    def no_reward(self, next_vae_obs, indices):
        """Return a zero reward column of shape ``(len(next_vae_obs), 1)``."""
        num_samples = len(next_vae_obs)
        return np.zeros((num_samples, 1))

    def initialize_dynamics_model(self):
        """Create the forward dynamics MLP with its Adam optimizer and MSE loss.

        The model takes ``concat(obs, action)`` (``obs_dim + action_dim``
        inputs) and outputs ``obs_dim`` values, where ``obs_dim`` is dim 1 of
        the stored observations.
        """
        obs_dim = self._obs[self.observation_key].shape[1]
        self.dynamics_model = model = Mlp(
            hidden_sizes=[128, 128],
            output_size=obs_dim,
            input_size=obs_dim + self._action_dim,
        )
        model.to(ptu.device)
        self.dynamics_optimizer = Adam(model.parameters())
        self.dynamics_loss = MSELoss()

    def train_dynamics_model(self, batches=50, batch_size=100):
        """Run ``batches`` optimizer steps on the forward dynamics model.

        Does nothing when ``self.use_dynamics_model`` is falsy. Each step
        draws ``batch_size`` indices with ``self._sample_indices`` and
        regresses next observations from ``(obs, action)``. When
        ``exploration_rewards_type == 'inverse_model_error'`` the roles of
        obs and next_obs are swapped (actions stay the same).
        """
        if not self.use_dynamics_model:
            return
        for _ in range(batches):
            idxs = self._sample_indices(batch_size)
            self.dynamics_optimizer.zero_grad()
            inputs = self._obs[self.observation_key][idxs]
            targets = self._next_obs[self.observation_key][idxs]
            acts = self._actions[idxs]
            if self.exploration_rewards_type == 'inverse_model_error':
                inputs, targets = targets, inputs

            prediction = self.dynamics_model(
                ptu.from_numpy(np.c_[inputs, acts])
            )
            loss = self.dynamics_loss(prediction, ptu.from_numpy(targets))
            loss.backward()
            self.dynamics_optimizer.step()

    def log_loss_under_uniform(self, model, data, batch_size, rl_logger, priority_function_kwargs):
        """Record likelihood, KL and MSE diagnostics of ``model`` on ``data``.

        ``data`` is split into consecutive chunks of ``batch_size`` rows. For
        each chunk, ``compute_log_p_log_q_log_d`` is evaluated with the
        'true_prior_sampling', 'biased_sampling' and 'importance_sampling'
        methods. A forward pass of ``model`` provides the KL term
        (``model.kl_divergence``) and the mean squared reconstruction error.
        The mean over chunks of each statistic is written to ``rl_logger``
        under the "Uniform Data ..." keys.

        :param model: the VAE to evaluate. It is used for every statistic,
            including the forward pass, so KL/MSE and the log-probabilities
            always describe the same network.
        :param data: array whose first axis indexes samples (sliced as
            ``data[i:j, :]``).
        :param batch_size: rows per chunk.
        :param rl_logger: dict-like that receives the results.
        :param priority_function_kwargs: extra kwargs for
            ``compute_log_p_log_q_log_d`` (may be None). It is not modified;
            each call receives a copy with ``sampling_method`` set.
        """
        import torch.nn.functional as F

        base_kwargs = dict(priority_function_kwargs or {})

        def _log_p_q_d(img, sampling_method):
            # Per-call copy so the caller's dict (possibly the buffer's own
            # priority_function_kwargs) never keeps a stale sampling_method.
            kwargs = dict(base_kwargs, sampling_method=sampling_method)
            return compute_log_p_log_q_log_d(model, img, **kwargs)

        log_probs_prior = []
        log_probs_biased = []
        log_probs_importance = []
        kles = []
        mses = []
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            for start in range(0, data.shape[0], batch_size):
                img = data[start:start + batch_size, :]
                torch_img = ptu.from_numpy(img)
                reconstructions, _, latent_distribution_params = model(
                    torch_img)

                _, _, log_d = _log_p_q_d(img, 'true_prior_sampling')
                log_prob_prior = log_d.mean()

                _, _, log_d = _log_p_q_d(img, 'biased_sampling')
                log_prob_biased = log_d.mean()

                log_p, log_q, log_d = _log_p_q_d(img, 'importance_sampling')
                log_prob_importance = (log_p - log_q + log_d).mean()

                kle = model.kl_divergence(latent_distribution_params)
                # 'mean' is the supported spelling of the deprecated
                # 'elementwise_mean' reduction.
                mse = F.mse_loss(torch_img, reconstructions, reduction='mean')
                mses.append(mse.item())
                kles.append(kle.item())
                log_probs_prior.append(log_prob_prior.item())
                log_probs_biased.append(log_prob_biased.item())
                log_probs_importance.append(log_prob_importance.item())

        rl_logger["Uniform Data Log Prob (Prior)"] = np.mean(log_probs_prior)
        rl_logger["Uniform Data Log Prob (Biased)"] = np.mean(log_probs_biased)
        rl_logger["Uniform Data Log Prob (Importance)"] = np.mean(log_probs_importance)
        rl_logger["Uniform Data KL"] = np.mean(kles)
        rl_logger["Uniform Data MSE"] = np.mean(mses)

    def _get_sorted_idx_and_train_weights(self):
        idx_and_weights = zip(range(len(self._vae_sample_probs)),
                              self._vae_sample_probs)
        return sorted(idx_and_weights, key=lambda x: x[1])
Пример #15
0
def train_vae(variant, return_data=False):
    """Build and train a conv VAE / auto-encoder as configured by ``variant``.

    Reads ``beta``, ``representation_size`` (falling back to
    ``latent_sizes``), ``generate_vae_dataset_kwargs``, ``algo_kwargs``,
    ``vae_kwargs``, ``imsize``, ``save_period``, ``num_epochs`` and several
    optional flags from ``variant``. ``variant`` is mutated:
    ``generate_vae_dataset_kwargs`` gets ``use_linear_dynamics`` and
    ``batch_size``, ``vae_kwargs`` gets ``architecture`` and ``imsize``, and
    ``algo_kwargs`` may get ``context_schedule``.

    Every epoch the model is trained and tested, diagnostics are logged, and
    every ``save_period`` epochs reconstructions/samples (and optionally skew
    debug plots) are dumped. Parameters are snapshotted every 50 epochs and
    the final model is pickled to ``vae.pkl``.

    :param variant: experiment configuration dict.
    :param return_data: if True, also return the train and test datasets.
    :return: ``model``, or ``(model, train_dataset, test_dataset)``.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule, ConstantSchedule
    from railrl.torch.vae.conv_vae import (
        ConvVAE,
        SpatialAutoEncoder,
        AutoEncoder,
    )
    import railrl.torch.vae.conv_vae as conv_vae
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch
    beta = variant["beta"]
    representation_size = variant.get("representation_size",
                                      variant.get("latent_sizes", None))
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs'][
        'use_linear_dynamics'] = use_linear_dynamics
    variant['generate_vae_dataset_kwargs']['batch_size'] = variant[
        'algo_kwargs']['batch_size']
    # The dataset factory receives the kwargs dict as one positional
    # argument (it is not unpacked).
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])

    if use_linear_dynamics:
        # Presumably actions are stored as (trajectories, steps, action_dim)
        # — TODO confirm.
        action_dim = train_dataset.data['actions'].shape[2]

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if 'context_schedule' in variant:
        schedule = variant['context_schedule']
        # isinstance (not ``type(...) is dict``) so dict subclasses such as
        # OrderedDict are treated as piecewise-linear configs instead of
        # being wrapped as a "constant" schedule value.
        if isinstance(schedule, dict):
            context_schedule = PiecewiseLinearSchedule(**schedule)
        else:
            context_schedule = ConstantSchedule(schedule)
        variant['algo_kwargs']['context_schedule'] = context_schedule
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    # Fall back to the default architecture for known image sizes.
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    if variant['algo_kwargs'].get('is_auto_encoder', False):
        model = AutoEncoder(representation_size,
                            decoder_output_activation=decoder_activation,
                            **variant['vae_kwargs'])
    elif variant.get('use_spatial_auto_encoder', False):
        model = SpatialAutoEncoder(
            representation_size,
            decoder_output_activation=decoder_activation,
            **variant['vae_kwargs'])
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        if use_linear_dynamics:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              action_dim=action_dim,
                              **variant['vae_kwargs'])
        else:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              **variant['vae_kwargs'])
    model.to(ptu.device)

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
    def __init__(
            self,
            vae,
            *args,
            decoded_obs_key='image_observation',
            decoded_achieved_goal_key='image_achieved_goal',
            decoded_desired_goal_key='image_desired_goal',
            exploration_rewards_type='None',
            exploration_rewards_scale=1.0,
            vae_priority_type='None',
            start_skew_epoch=0,
            power=1.0,
            internal_keys=None,
            exploration_schedule_kwargs=None,
            priority_function_kwargs=None,
            exploration_counter_kwargs=None,
            relabeling_goal_sampling_mode='vae_prior',
            decode_vae_goals=False,
            **kwargs
    ):
        """
        Set up VAE-based exploration rewards and prioritized VAE sampling.

        :param vae: VAE used by the reward/priority functions.
        :param decoded_obs_key: key of decoded observations; like the two
            decoded goal keys, it is added to ``internal_keys`` if missing.
        :param decoded_achieved_goal_key: key of decoded achieved goals.
        :param decoded_desired_goal_key: key of decoded desired goals.
        :param exploration_rewards_type: name of the exploration-reward
            function ('None' disables the bonus).
        :param exploration_rewards_scale: bonus scale; also the value of the
            constant schedule used when ``exploration_schedule_kwargs`` is
            None. A scale of 0 disables the bonus.
        :param vae_priority_type: name of the VAE sample-priority function
            ('None' disables prioritization).
        :param power: exponent for the priorities; 0 disables
            prioritization.
        :param internal_keys: extra internal keys forwarded to the parent
            class. The caller's sequence is copied, never modified.
        :param priority_function_kwargs: kwargs for the reward/priority
            functions (copied).
        :param exploration_counter_kwargs: kwargs for ``CountExploration``;
            only used when ``exploration_rewards_type == 'hash_count'``.
        :raises KeyError: if ``exploration_rewards_type`` or
            ``vae_priority_type`` is not a known function name.
        """
        # Copy so a list supplied by the caller is not mutated below.
        internal_keys = [] if internal_keys is None else list(internal_keys)

        for key in [
            decoded_obs_key,
            decoded_achieved_goal_key,
            decoded_desired_goal_key
        ]:
            if key not in internal_keys:
                internal_keys.append(key)
        super().__init__(*args, internal_keys=internal_keys, **kwargs)
        # assert isinstance(self.env, VAEWrappedEnv)
        self.vae = vae
        self.decoded_obs_key = decoded_obs_key
        self.decoded_desired_goal_key = decoded_desired_goal_key
        self.decoded_achieved_goal_key = decoded_achieved_goal_key
        self.exploration_rewards_type = exploration_rewards_type
        self.exploration_rewards_scale = exploration_rewards_scale
        self.start_skew_epoch = start_skew_epoch
        self.vae_priority_type = vae_priority_type
        self.power = power
        self._relabeling_goal_sampling_mode = relabeling_goal_sampling_mode
        self.decode_vae_goals = decode_vae_goals

        if exploration_schedule_kwargs is None:
            self.explr_reward_scale_schedule = \
                ConstantSchedule(self.exploration_rewards_scale)
        else:
            self.explr_reward_scale_schedule = \
                PiecewiseLinearSchedule(**exploration_schedule_kwargs)

        self._give_explr_reward_bonus = (
                exploration_rewards_type != 'None'
                and exploration_rewards_scale != 0.
        )
        self._exploration_rewards = np.zeros((self.max_size, 1), dtype=np.float32)
        self._prioritize_vae_samples = (
                vae_priority_type != 'None'
                and power != 0.
        )
        self._vae_sample_priorities = np.zeros((self.max_size, 1), dtype=np.float32)
        # Normalized sampling distribution; None until priorities are computed.
        self._vae_sample_probs = None

        self.use_dynamics_model = (
                self.exploration_rewards_type == 'forward_model_error'
        )
        if self.use_dynamics_model:
            self.initialize_dynamics_model()

        type_to_function = {
            'reconstruction_error': self.reconstruction_mse,
            'bce': self.binary_cross_entropy,
            'latent_distance': self.latent_novelty,
            'latent_distance_true_prior': self.latent_novelty_true_prior,
            'forward_model_error': self.forward_model_error,
            'gaussian_inv_prob': self.gaussian_inv_prob,
            'bernoulli_inv_prob': self.bernoulli_inv_prob,
            'vae_prob': self.vae_prob,
            'hash_count': self.hash_count_reward,
            'None': self.no_reward,
        }

        self.exploration_reward_func = (
            type_to_function[self.exploration_rewards_type]
        )
        self.vae_prioritization_func = (
            type_to_function[self.vae_priority_type]
        )

        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            # Copy so later changes on either side do not leak across.
            self.priority_function_kwargs = dict(priority_function_kwargs)

        if self.exploration_rewards_type == 'hash_count':
            if exploration_counter_kwargs is None:
                exploration_counter_kwargs = dict()
            self.exploration_counter = CountExploration(env=self.env, **exploration_counter_kwargs)
        self.epoch = 0
Пример #17
0
def train_vae(variant):
    """Train a VAE on a saved dataset, logging to ``vae_progress.csv``.

    Tabular output is redirected from ``progress.csv`` to
    ``vae_progress.csv`` while training. It is restored afterwards, and
    also when training raises, so later logging does not stay pointed at
    the VAE progress file.

    :param variant: experiment configuration; see
        ``_train_vae_from_dataset`` for the keys it reads.
    :return: the trained VAE.
    """
    from railrl.core import logger

    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)
    try:
        m = _train_vae_from_dataset(variant)
    finally:
        logger.remove_tabular_output(
            'vae_progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
    print("finished --------------------!!!!!!!!!!!!!!!")

    return m


def _train_vae_from_dataset(variant):
    """Build env, data, VAE and trainer from ``variant``, train, pickle the VAE.

    Reads ``generate_vae_dataset_kwargs`` (``env_id`` or ``env_class`` +
    ``env_kwargs``, ``dataset_path``, ``test_p``, ``imsize``,
    ``init_camera``), ``representation_size``, ``beta``,
    ``beta_schedule_kwargs``, ``vae_type``, ``vae_kwargs``, ``imsize``,
    ``algo_kwargs``, ``vis_kwargs``, ``render``, ``reward_params``,
    ``vae_wrapped_env_kwargs``, ``dump_video_kwargs``, ``num_epochs`` and
    ``snapshot_gap``. Mutates ``variant['vae_kwargs']``; when videos are
    saved it also stores/updates ``variant['dump_video_kwargs']``.

    :return: the trained VAE (also saved to ``vae.pkl``).
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.conv_vae import ConvVAE
    from railrl.torch.vae.conv_vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from multiworld.core.image_env import ImageEnv
    from railrl.envs.vae_wrappers import VAEWrappedEnv
    from railrl.misc.asset_loader import local_path_from_s3_or_local_path

    env_id = variant['generate_vae_dataset_kwargs'].get('env_id', None)
    if env_id is not None:
        import gym
        env = gym.make(env_id)
    else:
        env_class = variant['generate_vae_dataset_kwargs']['env_class']
        env_kwargs = variant['generate_vae_dataset_kwargs']['env_kwargs']
        env = env_class(**env_kwargs)

    representation_size = variant["representation_size"]
    beta = variant["beta"]
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    # obtain training and testing data
    dataset_path = variant['generate_vae_dataset_kwargs'].get(
        'dataset_path', None)
    # NOTE(review): despite its name, ``test_p`` is the fraction of rows
    # (the first ``n``) used for *training*; the remainder is test data.
    test_p = variant['generate_vae_dataset_kwargs'].get('test_p', 0.9)
    filename = local_path_from_s3_or_local_path(dataset_path)
    # allow_pickle unpickles the stored dict: only load trusted files.
    dataset = np.load(filename, allow_pickle=True).item()
    N = dataset['obs'].shape[0]
    n = int(N * test_p)
    train_data = {}
    test_data = {}
    for k in dataset.keys():
        train_data[k] = dataset[k][:n, :]
        test_data[k] = dataset[k][n:, :]

    # setup vae
    variant['vae_kwargs']['action_dim'] = train_data['actions'].shape[1]
    if variant.get('vae_type', None) == "VAE-state":
        from railrl.torch.vae.vae import VAE
        input_size = train_data['obs'].shape[1]
        variant['vae_kwargs']['input_size'] = input_size
        m = VAE(representation_size, **variant['vae_kwargs'])
    elif variant.get('vae_type', None) == "VAE2":
        from railrl.torch.vae.conv_vae2 import ConvVAE2
        variant['vae_kwargs']['imsize'] = variant['imsize']
        m = ConvVAE2(representation_size, **variant['vae_kwargs'])
    else:
        variant['vae_kwargs']['imsize'] = variant['imsize']
        m = ConvVAE(representation_size, **variant['vae_kwargs'])
    if ptu.gpu_enabled():
        m.cuda()

    # setup vae trainer
    if variant.get('vae_type', None) == "VAE-state":
        from railrl.torch.vae.vae_trainer import VAETrainer
        t = VAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    else:
        t = ConvVAETrainer(train_data,
                           test_data,
                           m,
                           beta=beta,
                           beta_schedule=beta_schedule,
                           **variant['algo_kwargs'])

    # visualization
    vis_variant = variant.get('vis_kwargs', {})
    save_video = vis_variant.get('save_video', False)
    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            variant['generate_vae_dataset_kwargs'].get('imsize'),
            init_camera=variant['generate_vae_dataset_kwargs'].get(
                'init_camera'),
            transpose=True,
            normalize=True,
        )
    render = variant.get('render', False)
    reward_params = variant.get("reward_params", dict())
    vae_env = VAEWrappedEnv(image_env,
                            m,
                            imsize=image_env.imsize,
                            decode_goals=render,
                            render_goals=render,
                            render_rollouts=render,
                            reward_params=reward_params,
                            **variant.get('vae_wrapped_env_kwargs', {}))
    vae_env.reset()
    vae_env.add_mode("video_env", 'video_env')
    vae_env.add_mode("video_vae", 'video_vae')
    if save_video:
        import railrl.samplers.rollout_functions as rf
        from railrl.policies.simple import RandomPolicy
        random_policy = RandomPolicy(vae_env.action_space)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=100,
            observation_key='latent_observation',
            desired_goal_key='latent_desired_goal',
            vis_list=vis_variant.get('vis_list', []),
            dont_terminate=True,
        )

        # setdefault (not .get) so this dict is the one later read back as
        # variant['dump_video_kwargs']; previously a missing key raised
        # KeyError inside visualization_post_processing.
        dump_video_kwargs = variant.setdefault("dump_video_kwargs", dict())
        dump_video_kwargs['imsize'] = vae_env.imsize
        dump_video_kwargs['vis_list'] = [
            'image_observation',
            'reconstr_image_observation',
            'image_latent_histogram_2d',
            'image_latent_histogram_mu_2d',
            'image_plt',
            'image_rew',
            'image_rew_euclidean',
            'image_rew_mahalanobis',
            'image_rew_logp',
            'image_rew_kl',
            'image_rew_kl_rev',
        ]

    def visualization_post_processing(save_vis, save_video, epoch):
        # Dump reconstructions/samples/latent plots and, when enabled,
        # random-policy rollout videos for this epoch.
        vis_list = vis_variant.get('vis_list', [])

        if save_vis:
            if vae_env.vae_input_key_prefix == 'state':
                vae_env.dump_reconstructions(epoch,
                                             n_recon=vis_variant.get(
                                                 'n_recon', 16))
            vae_env.dump_samples(epoch,
                                 n_samples=vis_variant.get('n_samples', 64))
            if 'latent_representation' in vis_list:
                vae_env.dump_latent_plots(epoch)
            if any(elem in vis_list for elem in [
                    'latent_histogram', 'latent_histogram_mu',
                    'latent_histogram_2d', 'latent_histogram_mu_2d'
            ]):
                vae_env.compute_latent_histogram()
            if not save_video and ('latent_histogram' in vis_list):
                vae_env.dump_latent_histogram(epoch=epoch,
                                              noisy=True,
                                              use_true_prior=True)
            if not save_video and ('latent_histogram_mu' in vis_list):
                vae_env.dump_latent_histogram(epoch=epoch,
                                              noisy=False,
                                              use_true_prior=True)

        if save_video and save_vis:
            from railrl.envs.vae_wrappers import temporary_mode
            from railrl.misc.video_gen import dump_video

            vae_env.compute_goal_encodings()

            logdir = logger.get_snapshot_dir()
            filename = osp.join(logdir,
                                'video_{epoch}.mp4'.format(epoch=epoch))
            variant['dump_video_kwargs']['epoch'] = epoch
            temporary_mode(vae_env,
                           mode='video_env',
                           func=dump_video,
                           args=(vae_env, random_policy, filename,
                                 rollout_function),
                           kwargs=variant['dump_video_kwargs'])
            if not vis_variant.get('save_video_env_only', True):
                filename = osp.join(
                    logdir, 'video_{epoch}_vae.mp4'.format(epoch=epoch))
                temporary_mode(vae_env,
                               mode='video_vae',
                               func=dump_video,
                               args=(vae_env, random_policy, filename,
                                     rollout_function),
                               kwargs=variant['dump_video_kwargs'])

    # train vae
    for epoch in range(variant['num_epochs']):
        save_vis = (epoch % vis_variant['save_period'] == 0
                    or epoch == variant['num_epochs'] - 1)
        save_vae = (epoch % variant['snapshot_gap'] == 0
                    or epoch == variant['num_epochs'] - 1)

        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=save_vis,
            save_interpolation=save_vis,
            save_vae=save_vae,
        )
        if epoch % 300 == 0 or epoch == variant['num_epochs'] - 1:
            visualization_post_processing(save_vis, save_video, epoch)

    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    return m