def experiment(variant):
    """Train a ``ConvVAE`` on a dataset built by ``generate_vae_dataset``.

    Keys read from ``variant``: ``beta``, ``representation_size``,
    ``get_data_kwargs``, ``algo_kwargs``, ``save_period``, ``num_epochs``
    and, optionally, ``beta_schedule_kwargs`` and ``gpu_id``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    latent_dim = variant["representation_size"]
    train_set, test_set, dataset_info = generate_vae_dataset(
        **variant['get_data_kwargs'])
    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    # The schedule is optional; without it the trainer receives None.
    schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    vae = ConvVAE(latent_dim, input_channels=3)
    if ptu.gpu_enabled():
        vae.to(ptu.device)
        requested_gpu = variant.get("gpu_id")
        if requested_gpu is not None:
            ptu.set_device(requested_gpu)

    trainer = ConvVAETrainer(
        train_set,
        test_set,
        vae,
        beta=beta,
        beta_schedule=schedule,
        **variant['algo_kwargs']
    )

    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        # Images are written only on every ``save_period``-th epoch.
        dump_images = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=dump_images,
            save_scatterplot=dump_images,
        )
        if dump_images:
            trainer.dump_samples(epoch)
def experiment(variant):
    """Train a variant-supplied VAE and track its loss on a uniform dataset.

    Keys read from ``variant``: ``beta``, ``representation_size``,
    ``generate_vae_dataset_fn``, ``generate_vae_dataset_kwargs``,
    ``generate_uniform_dataset_kwargs``, ``vae``, ``vae_kwargs``,
    ``algo_kwargs``, ``save_period``, ``num_epochs``,
    ``dump_skew_debug_plots`` and ``train_weight_update_period``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    latent_dim = variant["representation_size"]
    make_dataset = variant['generate_vae_dataset_fn']
    train_set, test_set, dataset_info = make_dataset(
        variant['generate_vae_dataset_kwargs'])
    uniform_dataset = generate_uniform_dataset_reacher(
        **variant['generate_uniform_dataset_kwargs'])
    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    vae_class = variant['vae']
    vae = vae_class(
        latent_dim,
        decoder_output_activation=nn.Sigmoid(),
        **variant['vae_kwargs']
    )
    vae.to(ptu.device)

    # This experiment never uses a beta schedule.
    trainer = ConvVAETrainer(
        train_set,
        test_set,
        vae,
        beta=beta,
        beta_schedule=None,
        **variant['algo_kwargs']
    )

    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        dump_images = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.log_loss_under_uniform(vae, uniform_dataset)
        trainer.test_epoch(
            epoch,
            save_reconstruction=dump_images,
            save_scatterplot=dump_images,
        )
        if dump_images:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
                trainer.dump_uniform_imgs_and_reconstructions(
                    dataset=uniform_dataset, epoch=epoch)
        # Training weights are refreshed on a period of their own.
        if epoch % variant['train_weight_update_period'] == 0:
            trainer.update_train_weights()
예제 #3
0
def experiment(variant):
    """Train a ``VAE`` or ``AutoEncoder`` on the data from ``generate_dataset``.

    Keys read from ``variant``: ``beta``, ``representation_size``,
    ``algo_kwargs`` (including ``is_auto_encoder``), ``vae_kwargs``,
    ``num_epochs`` and, optionally, ``beta_schedule_kwargs`` together with
    ``flat_x`` and ``ramp_x``.
    """
    from railrl.core import logger

    beta = variant["beta"]
    latent_dim = variant["representation_size"]
    train_set, test_set, dataset_info = generate_dataset()
    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    schedule = None
    if 'beta_schedule_kwargs' in variant:
        # The schedule kwargs are patched in place (the variant itself is
        # modified) before the schedule is built from them.
        schedule_kwargs = variant['beta_schedule_kwargs']
        schedule_kwargs['y_values'][2] = variant['beta']
        schedule_kwargs['x_values'][1] = variant['flat_x']
        schedule_kwargs['x_values'][2] = variant['ramp_x'] + variant['flat_x']
        schedule = PiecewiseLinearSchedule(**schedule_kwargs)

    # Both model classes are constructed with the same arguments.
    if variant['algo_kwargs']['is_auto_encoder']:
        model_class = AutoEncoder
    else:
        model_class = VAE
    model = model_class(
        latent_dim,
        train_set.shape[1],
        output_scale=1,
        **variant['vae_kwargs']
    )

    trainer = VAETrainer(
        train_set,
        test_set,
        model,
        beta=beta,
        beta_schedule=schedule,
        **variant['algo_kwargs']
    )
    for epoch in range(variant['num_epochs']):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
예제 #4
0
def train_vae_and_update_variant(variant):
    """Make sure ``variant['grill_variant']`` carries a VAE, training one if needed.

    If ``grill_variant`` has no ``vae_path`` (missing or ``None``), a VAE is
    trained with ``train_vae`` while tabular logging is redirected to
    ``vae_progress.csv``, and the trained model object is stored under
    ``vae_path``. When ``save_vae_data`` is truthy, the train and test data
    are stored in ``grill_variant`` as well. Nothing is returned;
    ``grill_variant`` is modified in place.
    """
    from railrl.core import logger

    grill_variant = variant['grill_variant']
    train_vae_variant = variant['train_vae_variant']

    if grill_variant.get('vae_path', None) is not None:
        # A VAE was already supplied: only regenerate the data if requested.
        if grill_variant.get('save_vae_data', False):
            train_set, test_set, _info = generate_vae_dataset(
                train_vae_variant['generate_vae_dataset_kwargs'])
            grill_variant['vae_train_data'] = train_set
            grill_variant['vae_test_data'] = test_set
        return

    # Swap the tabular output so VAE training logs to its own CSV file.
    logger.remove_tabular_output(
        'progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output(
        'vae_progress.csv', relative_to_snapshot_dir=True)

    vae, train_set, test_set = train_vae(train_vae_variant, return_data=True)
    if grill_variant.get('save_vae_data', False):
        grill_variant['vae_train_data'] = train_set
        grill_variant['vae_test_data'] = test_set
    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')

    # Restore the regular progress file for whatever runs next.
    logger.remove_tabular_output(
        'vae_progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output(
        'progress.csv', relative_to_snapshot_dir=True)

    # The model object itself is stored, not a filesystem path.
    grill_variant['vae_path'] = vae
def train_vae(variant, return_data=False):
    """Train the model returned by ``get_vae`` and pickle it as ``vae.pkl``.

    Keys read from ``variant``: ``beta``, ``generate_vae_dataset_kwargs``,
    ``algo_kwargs``, ``save_period``, ``num_epochs`` and, optionally,
    ``use_linear_dynamics``, ``generate_vae_data_fctn``,
    ``beta_schedule_kwargs``, ``vae_trainer_class`` and
    ``dump_skew_debug_plots``.

    Returns the trained model, or ``(model, train_dataset, test_dataset)``
    when ``return_data`` is true.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger

    beta = variant["beta"]
    linear_dynamics = variant.get('use_linear_dynamics', False)
    make_dataset = variant.get('generate_vae_data_fctn', generate_vae_dataset)
    # The flag is forwarded to the dataset generator through its kwargs
    # (this writes into the variant).
    dataset_kwargs = variant['generate_vae_dataset_kwargs']
    dataset_kwargs['use_linear_dynamics'] = linear_dynamics
    train_dataset, test_dataset, dataset_info = make_dataset(
        variant['generate_vae_dataset_kwargs'])

    action_dim = (
        train_dataset.data['actions'].shape[2] if linear_dynamics else 0
    )
    model = get_vae(variant, action_dim)

    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = trainer_class(
        model,
        beta=beta,
        beta_schedule=schedule,
        **variant['algo_kwargs']
    )
    save_period = variant['save_period']
    skew_debug_plots = variant.get('dump_skew_debug_plots', False)

    for epoch in range(variant['num_epochs']):
        dump_images = epoch % save_period == 0
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if dump_images:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        # Flush this epoch's trainer diagnostics to the tabular log.
        diagnostics = trainer.get_diagnostics()
        for stat_name, stat_value in diagnostics.items():
            logger.record_tabular(stat_name, stat_value)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        # Periodic snapshot every 50 epochs.
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)

    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    return (model, train_dataset, test_dataset) if return_data else model
예제 #6
0
    def _try_to_eval(self, epoch, eval_paths=None):
        """Save this epoch's snapshot and, if possible, evaluate and log it.

        Extra data and the epoch snapshot are always saved. When
        ``self._can_evaluate()`` is true, ``self.evaluate`` is run, the
        step/rollout counters and timing statistics are recorded, and the
        tabular log is dumped; otherwise a "skipping" message is logged.

        :param epoch: Index of the epoch being finished.
        :param eval_paths: Forwarded unchanged to ``self.evaluate``.
        :raises AssertionError: If the set of tabular keys differs from the
            one seen at the previous evaluation.
        """
        logger.save_extra_data(self.get_extra_data_to_save(epoch))

        params = self.get_epoch_snapshot(epoch)
        logger.save_itr_params(epoch, params)

        if not self._can_evaluate():
            logger.log("Skipping eval for now.")
            return

        self.evaluate(epoch, eval_paths=eval_paths)

        # The CSV header is fixed after the first dump, so the key set must
        # stay the same from one evaluation to the next.
        table_keys = logger.get_table_key_set()
        if self._old_table_keys is not None:
            assert table_keys == self._old_table_keys, (
                "Table keys cannot change from iteration to iteration."
            )
        self._old_table_keys = table_keys

        logger.record_tabular(
            "Number of train steps total",
            self._n_train_steps_total,
        )
        logger.record_tabular(
            "Number of env steps total",
            self._n_env_steps_total,
        )
        logger.record_tabular(
            "Number of rollouts total",
            self._n_rollouts_total,
        )

        if self.collection_mode != 'online-parallel':
            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            has_eval_time = 'eval' in times_itrs and epoch > 0
            # -1 is only a "not available" marker for the logged eval column.
            eval_time = times_itrs['eval'][-1] if has_eval_time else -1
            # Keep the marker out of the sum: adding -1 used to make the
            # reported epoch time one second too small.
            epoch_time = (
                train_time + sample_time + (eval_time if has_eval_time else 0)
            )
            total_time = gt.get_times().total

            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)
        else:
            logger.record_tabular('Epoch Time (s)',
                                  time.time() - self._epoch_start_time)
        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
예제 #7
0
 def _try_to_offline_eval(self, epoch):
     """Run the offline evaluation for ``epoch``, snapshot, and dump the log.

     Saves the extra data, calls ``self.offline_evaluate``, saves the epoch
     snapshot, checks that the tabular key set is unchanged since the
     previous call, dumps the tabular log and logs the elapsed wall time.
     """
     began_at = time.time()
     logger.save_extra_data(self.get_extra_data_to_save(epoch))
     self.offline_evaluate(epoch)
     logger.save_itr_params(epoch, self.get_epoch_snapshot(epoch))

     # The key set must stay the same between consecutive dumps.
     current_keys = logger.get_table_key_set()
     assert (
         self._old_table_keys is None
         or current_keys == self._old_table_keys
     ), "Table keys cannot change from iteration to iteration."
     self._old_table_keys = current_keys

     logger.dump_tabular(with_prefix=False, with_timestamp=False)
     elapsed = time.time() - began_at
     logger.log("Eval Time: {0}".format(elapsed))
def experiment(variant):
    """Train a ``ConvVAE`` on image shards loaded from hard-coded ``.npy`` files.

    Five shards of 10000 rows (21168 columns each) are loaded; the first 90%
    of rows is used for training and the rest for testing.

    Keys read from ``variant``: ``beta``, ``representation_size``,
    ``conv_vae_kwargs``, ``algo_kwargs``, ``save_period``, ``num_epochs``
    and, optionally, ``beta_schedule_kwargs`` together with ``flat_x`` and
    ``ramp_x``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    latent_dim = variant["representation_size"]

    num_divisions = 5
    images = np.zeros((num_divisions * 10000, 21168))
    shard_prefix = '/home/murtaza/vae_data/sawyer_torque_control_images100000_'
    for shard in range(num_divisions):
        # Shard files are numbered from 1.
        shard_path = shard_prefix + str(shard + 1) + '.npy'
        row_start = shard * 10000
        row_stop = (shard + 1) * 10000
        images[row_start:row_stop] = np.load(shard_path)
        print(shard)
    split_at = int(num_divisions * 10000 * .9)
    train_set = images[:split_at]
    test_set = images[split_at:]
    dataset_info = {}

    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    schedule = None
    if 'beta_schedule_kwargs' in variant:
        # The schedule kwargs are patched in place before use.
        schedule_kwargs = variant['beta_schedule_kwargs']
        schedule_kwargs['y_values'][2] = variant['beta']
        schedule_kwargs['x_values'][1] = variant['flat_x']
        schedule_kwargs['x_values'][2] = variant['ramp_x'] + variant['flat_x']
        schedule = PiecewiseLinearSchedule(**schedule_kwargs)

    vae = ConvVAE(latent_dim, input_channels=3, **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        vae.cuda()

    trainer = ConvVAETrainer(
        train_set,
        test_set,
        vae,
        beta=beta,
        beta_schedule=schedule,
        **variant['algo_kwargs']
    )
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        dump_images = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=dump_images,
            save_scatterplot=dump_images,
        )
        if dump_images:
            trainer.dump_samples(epoch)
def experiment(variant):
    """Train a variant-supplied VAE, logging its loss on a stored uniform dataset.

    Keys read from ``variant``: ``beta``, ``representation_size``,
    ``generate_vae_dataset_fn``, ``generate_vae_dataset_kwargs``,
    ``uniform_dataset_path``, ``vae``, ``vae_kwargs``, ``algo_kwargs``
    (including ``priority_function_kwargs``), ``save_period``,
    ``num_epochs``, ``dump_skew_debug_plots`` and, optionally,
    ``beta_schedule_kwargs``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    latent_dim = variant["representation_size"]
    make_dataset = variant['generate_vae_dataset_fn']
    train_set, test_set, dataset_info = make_dataset(
        variant['generate_vae_dataset_kwargs'])

    # The stored object is unwrapped with ``.item()`` and only its
    # 'image_desired_goal' entry is kept.
    stored_uniform = load_local_or_remote_file(
        variant['uniform_dataset_path']).item()
    uniform_dataset = unormalize_image(stored_uniform['image_desired_goal'])

    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    schedule = None
    if 'beta_schedule_kwargs' in variant:
        # The final y-value of the schedule is overwritten with ``beta``.
        variant['beta_schedule_kwargs']['y_values'][-1] = variant['beta']
        schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    vae_class = variant['vae']
    vae = vae_class(
        latent_dim,
        decoder_output_activation=nn.Sigmoid(),
        **variant['vae_kwargs']
    )
    vae.to(ptu.device)

    trainer = ConvVAETrainer(
        train_set,
        test_set,
        vae,
        beta=beta,
        beta_schedule=schedule,
        **variant['algo_kwargs']
    )
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        dump_images = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.log_loss_under_uniform(
            vae,
            uniform_dataset,
            variant['algo_kwargs']['priority_function_kwargs'],
        )
        trainer.test_epoch(
            epoch,
            save_reconstruction=dump_images,
            save_scatterplot=dump_images,
        )
        if dump_images:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
                trainer.dump_uniform_imgs_and_reconstructions(
                    dataset=uniform_dataset, epoch=epoch)
        # Training weights are refreshed after every epoch.
        trainer.update_train_weights()
def experiment(variant):
    """Train a ``SpatialAutoEncoder`` with ``ConvVAETrainer``.

    Keys read from ``variant``: ``feat_points``, ``beta``,
    ``get_data_kwargs``, ``lr`` and ``num_epochs``.
    """
    feature_point_count = variant['feat_points']
    from railrl.core import logger

    beta = variant["beta"]
    print('collecting data')
    train_set, test_set, dataset_info = get_data(**variant['get_data_kwargs'])
    print('finish collecting data')
    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    # The first argument is twice the number of feature points.
    model = SpatialAutoEncoder(
        2 * feature_point_count,
        feature_point_count,
        input_channels=3,
    )
    trainer = ConvVAETrainer(
        train_set, test_set, model, lr=variant['lr'], beta=beta)
    for epoch in range(variant['num_epochs']):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
        trainer.dump_samples(epoch)
def experiment(variant):
    """Regress the first 7 state dimensions from VAE encodings of images.

    Image and state shards are loaded from hard-coded ``.npy`` files, the
    images are encoded with ``variant['vae']``, and a ``FlattenMlp`` is
    trained on the encodings with ``SupervisedAlgorithm``. The first 90% of
    rows is used for training and the rest for testing.

    Keys read from ``variant``: ``num_divisions``, ``normalize``,
    ``hidden_sizes``, ``vae``, ``batch_size``, ``lr``, ``weight_decay``
    and ``num_epochs``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    logger.save_extra_data({})
    logger.get_snapshot_dir()

    num_divisions = variant['num_divisions']
    total_rows = num_divisions * 10000
    images = np.zeros((total_rows, 21168))
    states = np.zeros((total_rows, 7))
    image_prefix = '/home/murtaza/vae_data/sawyer_torque_control_images100000_'
    state_prefix = '/home/murtaza/vae_data/sawyer_torque_control_states100000_'
    for shard in range(num_divisions):
        # Shard files are numbered from 1.
        suffix = str(shard + 1) + '.npy'
        shard_images = np.load(image_prefix + suffix)
        # Keep the first 7 state columns, wrapped into [0, 2*pi).
        shard_states = np.load(state_prefix + suffix)[:, :7] % (2 * np.pi)
        rows = slice(shard * 10000, (shard + 1) * 10000)
        images[rows] = shard_images
        states[rows] = shard_states
        print(shard)

    if variant['normalize']:
        # Standardize each state column.
        std = np.std(states, axis=0)
        mu = np.mean(states, axis=0)
        states = np.divide((states - mu), std)
        print(mu, std)

    net = FlattenMlp(
        input_size=32,
        hidden_sizes=variant['hidden_sizes'],
        output_size=states.shape[1],
    )
    vae = variant['vae']
    vae.cuda()

    # From here on ``images`` holds the first output of ``vae.encode`` as a
    # numpy array; the second output is unused.
    images, _unused = vae.encode(ptu.np_to_var(images))
    images = ptu.get_numpy(images)

    split_at = int(num_divisions * 10000 * .9)
    train_inputs, test_inputs = images[:split_at], images[split_at:]
    train_targets, test_targets = states[:split_at], states[split_at:]

    algo = SupervisedAlgorithm(
        train_inputs,
        test_inputs,
        train_targets,
        test_targets,
        net,
        batch_size=variant['batch_size'],
        lr=variant['lr'],
        weight_decay=variant['weight_decay'],
    )
    for epoch in range(variant['num_epochs']):
        algo.train_epoch(epoch)
        algo.test_epoch(epoch)
def experiment(variant):
    """Train a ``ConvVAE`` in ``state_sim_debug`` mode on paired images and states.

    Images and states are loaded from hard-coded ``.npy`` files and split
    90/10 with ``train_test_split``.

    Keys read from ``variant``: ``beta``, ``representation_size``,
    ``conv_vae_kwargs``, ``algo_kwargs``, ``save_period``, ``num_epochs``
    and, optionally, ``beta_schedule_kwargs``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    latent_dim = variant["representation_size"]

    # The data holds both images and states, so generate_vae_dataset
    # cannot be used here.
    images = np.load(
        '/home/murtaza/vae_data/sawyer_torque_control_ou_imgs_zoomed_out10000.npy'
    )
    states = np.load(
        '/home/murtaza/vae_data/sawyer_torque_control_ou_states_zoomed_out10000.npy'
    )
    # Drop state columns 7..13 and keep everything else.
    states = np.concatenate((states[:, :7], states[:, 14:]), axis=1)
    images_train, images_test, states_train, states_test = train_test_split(
        images, states, test_size=.1)

    logger.save_extra_data({})
    logger.get_snapshot_dir()

    schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    vae = ConvVAE(
        latent_dim,
        input_channels=3,
        state_sim_debug=True,
        state_size=states.shape[1],
        **variant['conv_vae_kwargs']
    )
    if ptu.gpu_enabled():
        vae.cuda()

    trainer = ConvVAETrainer(
        (images_train, states_train),
        (images_test, states_test),
        vae,
        beta=beta,
        beta_schedule=schedule,
        state_sim_debug=True,
        **variant['algo_kwargs']
    )
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        dump_images = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=dump_images,
            save_scatterplot=dump_images,
        )
        if dump_images:
            trainer.dump_samples(epoch)
예제 #13
0
def experiment(variant):
    """Regress the first 7 state dimensions directly from images with a ``CNN``.

    Image and state shards are loaded from hard-coded ``.npy`` files; the
    first 90% of rows is used for training and the rest for testing.

    Keys read from ``variant``: ``cnn_kwargs``, ``num_divisions``,
    ``normalize``, ``batch_size``, ``lr``, ``weight_decay`` and
    ``num_epochs``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    ptu.set_gpu_mode(True)
    logger.save_extra_data({})
    logger.get_snapshot_dir()

    net = CNN(**variant['cnn_kwargs'])
    net.cuda()

    num_divisions = variant['num_divisions']
    total_rows = num_divisions * 10000
    images = np.zeros((total_rows, 21168))
    states = np.zeros((total_rows, 7))
    image_prefix = '/home/murtaza/vae_data/sawyer_torque_control_images100000_'
    state_prefix = '/home/murtaza/vae_data/sawyer_torque_control_states100000_'
    for shard in range(num_divisions):
        # Shard files are numbered from 1.
        suffix = str(shard + 1) + '.npy'
        shard_images = np.load(image_prefix + suffix)
        # Keep the first 7 state columns, wrapped into [0, 2*pi).
        shard_states = np.load(state_prefix + suffix)[:, :7] % (2 * np.pi)
        rows = slice(shard * 10000, (shard + 1) * 10000)
        images[rows] = shard_images
        states[rows] = shard_states
        print(shard)

    if variant['normalize']:
        # Standardize each state column.
        std = np.std(states, axis=0)
        mu = np.mean(states, axis=0)
        states = np.divide((states - mu), std)
        print(mu, std)

    split_at = int(num_divisions * 10000 * .9)
    train_inputs, test_inputs = images[:split_at], images[split_at:]
    train_targets, test_targets = states[:split_at], states[split_at:]

    algo = SupervisedAlgorithm(
        train_inputs,
        test_inputs,
        train_targets,
        test_targets,
        net,
        batch_size=variant['batch_size'],
        lr=variant['lr'],
        weight_decay=variant['weight_decay'],
    )
    for epoch in range(variant['num_epochs']):
        algo.train_epoch(epoch)
        algo.test_epoch(epoch)
예제 #14
0
def experiment(variant):
    """Train a variant-supplied VAE class, updating train weights every epoch.

    Keys read from ``variant``: ``beta``, ``representation_size``,
    ``generate_vae_dataset_fn``, ``generate_vae_dataset_kwargs``, ``vae``,
    ``vae_kwargs``, ``algo_kwargs``, ``save_period``, ``num_epochs``,
    ``dump_skew_debug_plots`` and, optionally, ``beta_schedule_kwargs``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    latent_dim = variant["representation_size"]
    make_dataset = variant['generate_vae_dataset_fn']
    train_set, test_set, dataset_info = make_dataset(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    schedule = None
    if 'beta_schedule_kwargs' in variant:
        # The final y-value of the schedule is overwritten with ``beta``.
        variant['beta_schedule_kwargs']['y_values'][-1] = variant['beta']
        schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    vae_class = variant['vae']
    vae = vae_class(latent_dim, **variant['vae_kwargs'])
    vae.to(ptu.device)

    trainer = ConvVAETrainer(
        train_set,
        test_set,
        vae,
        beta=beta,
        beta_schedule=schedule,
        **variant['algo_kwargs']
    )
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        dump_images = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=dump_images,
            save_scatterplot=dump_images,
        )
        if dump_images:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
        trainer.update_train_weights()
def experiment(variant):
    """Train a ``ConvVAE`` built with ``conv_vae_kwargs`` on a generated dataset.

    Keys read from ``variant``: ``beta``, ``representation_size``,
    ``generate_vae_dataset_kwargs``, ``conv_vae_kwargs``, ``algo_kwargs``,
    ``save_period``, ``num_epochs`` and, optionally,
    ``beta_schedule_kwargs``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    latent_dim = variant["representation_size"]
    train_set, test_set, dataset_info = generate_vae_dataset(
        **variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    schedule = None
    if 'beta_schedule_kwargs' in variant:
        # The final y-value of the schedule is overwritten with ``beta``.
        variant['beta_schedule_kwargs']['y_values'][-1] = variant['beta']
        schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    vae = ConvVAE(latent_dim, input_channels=3, **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        vae.cuda()

    trainer = ConvVAETrainer(
        train_set,
        test_set,
        vae,
        beta=beta,
        beta_schedule=schedule,
        **variant['algo_kwargs']
    )
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        dump_images = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=dump_images,
            save_scatterplot=dump_images,
        )
        if dump_images:
            trainer.dump_samples(epoch)
예제 #16
0
def experiment(variant):
    """Train a ``ConvVAE`` with a mandatory beta schedule, dumping samples every epoch.

    Keys read from ``variant``: ``beta``, ``representation_size``,
    ``get_data_kwargs``, ``beta_schedule_kwargs``, ``algo_kwargs`` and
    ``num_epochs``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    latent_dim = variant["representation_size"]
    train_set, test_set, dataset_info = get_data(**variant['get_data_kwargs'])
    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    # A missing 'beta_schedule_kwargs' key raises KeyError here.
    schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    vae = ConvVAE(latent_dim, input_channels=3)
    if ptu.gpu_enabled():
        vae.to(ptu.device)

    trainer = ConvVAETrainer(
        train_set,
        test_set,
        vae,
        beta=beta,
        beta_schedule=schedule,
        **variant['algo_kwargs']
    )
    for epoch in range(variant['num_epochs']):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
        trainer.dump_samples(epoch)
예제 #17
0
def train_reprojection_network(variant):
    """Train a ``ReprojectionNetwork`` around ``variant['vae']`` and return it.

    Keys read from ``variant``: ``vae`` (required) and, optionally,
    ``generate_reprojection_network_dataset_kwargs``,
    ``reprojection_network_kwargs``, ``algo_kwargs`` and ``num_epochs``
    (default 5000). The trained network is also pickled as
    ``reproj_network.pkl``.
    """
    from railrl.torch.vae.reprojection_network import (
        ReprojectionNetwork,
        ReprojectionNetworkTrainer,
    )
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    logger.get_snapshot_dir()

    vae = variant['vae']

    # The VAE is handed to the dataset generator through its kwargs dict
    # (when the variant supplies that dict, it is modified in place).
    dataset_kwargs = variant.get(
        "generate_reprojection_network_dataset_kwargs", {})
    dataset_kwargs['vae'] = vae
    train_set, test_set = generate_reprojection_network_dataset(dataset_kwargs)

    network_kwargs = variant.get("reprojection_network_kwargs", {})
    network = ReprojectionNetwork(vae, **network_kwargs)
    if ptu.gpu_enabled():
        network.cuda()

    trainer_kwargs = variant.get("algo_kwargs", {})
    trainer = ReprojectionNetworkTrainer(
        train_set, test_set, network, **trainer_kwargs)

    num_epochs = variant.get('num_epochs', 5000)
    final_epoch = num_epochs - 1
    for epoch in range(num_epochs):
        # Save every 250 epochs and always on the final epoch.
        save_now = (epoch == final_epoch) or (epoch % 250 == 0)
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch, save_network=save_now)

    logger.save_extra_data(network, 'reproj_network.pkl', mode='pickle')
    return network
예제 #18
0
def grill_her_full_experiment(variant, mode='td3'):
    """Train a VAE if none is configured, then run the selected GRILL+HER experiment.

    The environment settings (``env_class``, ``env_kwargs``,
    ``init_camera``) are copied from ``variant`` into both the VAE dataset
    kwargs and ``grill_variant``. If ``grill_variant`` has no ``vae_paths``
    entry, a VAE is trained with ``train_vae`` while tabular logging is
    redirected to ``vae_progress.csv``, and the result of saving it is
    registered under ``vae_paths`` keyed by the representation size.

    :param variant: Dict with ``train_vae_variant``, ``grill_variant``,
        ``env_class``, ``env_kwargs`` and ``init_camera``.
    :param mode: One of ``'td3'``, ``'twin-sac'`` or ``'sac'``.
    :raises ValueError: If ``mode`` is not one of the supported values.
    """
    # Validate before any work is done: previously an unknown mode trained
    # the VAE and then silently ran no RL experiment at all.
    supported_modes = ('td3', 'twin-sac', 'sac')
    if mode not in supported_modes:
        raise ValueError(
            "Unsupported mode {!r}; expected one of {}".format(
                mode, ', '.join(supported_modes)))

    train_vae_variant = variant['train_vae_variant']
    grill_variant = variant['grill_variant']
    env_class = variant['env_class']
    env_kwargs = variant['env_kwargs']
    init_camera = variant['init_camera']
    train_vae_variant['generate_vae_dataset_kwargs']['env_class'] = env_class
    train_vae_variant['generate_vae_dataset_kwargs']['env_kwargs'] = env_kwargs
    train_vae_variant['generate_vae_dataset_kwargs']['init_camera'] = init_camera
    grill_variant['env_class'] = env_class
    grill_variant['env_kwargs'] = env_kwargs
    grill_variant['init_camera'] = init_camera
    if 'vae_paths' not in grill_variant:
        # Swap the tabular output so VAE training logs to its own CSV file.
        logger.remove_tabular_output(
            'progress.csv', relative_to_snapshot_dir=True
        )
        logger.add_tabular_output(
            'vae_progress.csv', relative_to_snapshot_dir=True
        )
        vae = train_vae(train_vae_variant)
        rdim = train_vae_variant['representation_size']
        # Presumably save_extra_data returns the saved file's path - confirm.
        vae_file = logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
        logger.remove_tabular_output(
            'vae_progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        grill_variant['vae_paths'] = {
            str(rdim): vae_file,
        }
        grill_variant['rdim'] = str(rdim)
    if mode == 'td3':
        grill_her_td3_experiment(variant['grill_variant'])
    elif mode == 'twin-sac':
        grill_her_twin_sac_experiment(variant['grill_variant'])
    elif mode == 'sac':
        grill_her_sac_experiment(variant['grill_variant'])
예제 #19
0
def train_vae(variant, return_data=False):
    """Build and train an image VAE/auto-encoder, then pickle it as ``vae.pkl``.

    Keys read from ``variant``: ``beta``, ``generate_vae_dataset_kwargs``,
    ``algo_kwargs`` (including ``batch_size``), ``vae_kwargs``,
    ``save_period``, ``num_epochs`` and, optionally,
    ``representation_size`` (falling back to ``latent_sizes``),
    ``use_linear_dynamics``, ``generate_vae_data_fctn``,
    ``beta_schedule_kwargs``, ``context_schedule``, ``decoder_activation``,
    ``imsize``, ``use_spatial_auto_encoder``, ``vae_class``,
    ``vae_trainer_class`` and ``dump_skew_debug_plots``.

    Note that ``variant`` is modified in place: the dataset kwargs,
    ``vae_kwargs`` and (when a context schedule is given) ``algo_kwargs``
    all receive derived entries.

    :param variant: Experiment configuration dict (see above).
    :param return_data: When true, also return the train and test datasets.
    :return: The trained model, or ``(model, train_dataset, test_dataset)``
        when ``return_data`` is true.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule, ConstantSchedule
    from railrl.torch.vae.conv_vae import (
        ConvVAE,
        SpatialAutoEncoder,
        AutoEncoder,
    )
    import railrl.torch.vae.conv_vae as conv_vae
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch
    beta = variant["beta"]
    representation_size = variant.get("representation_size",
                                      variant.get("latent_sizes", None))
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    # Forward the dynamics flag and the batch size to the dataset generator.
    variant['generate_vae_dataset_kwargs'][
        'use_linear_dynamics'] = use_linear_dynamics
    variant['generate_vae_dataset_kwargs']['batch_size'] = variant[
        'algo_kwargs']['batch_size']
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])

    # ``action_dim`` is only defined, and only used, with linear dynamics.
    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if 'context_schedule' in variant:
        schedule = variant['context_schedule']
        # isinstance (rather than an exact type check) so that dict
        # subclasses are also treated as piecewise-linear kwargs instead of
        # being wrapped in a ConstantSchedule.
        if isinstance(schedule, dict):
            context_schedule = PiecewiseLinearSchedule(**schedule)
        else:
            context_schedule = ConstantSchedule(schedule)
        variant['algo_kwargs']['context_schedule'] = context_schedule
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    # Fall back to a default architecture for the two known image sizes
    # when none is given explicitly.
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    if variant['algo_kwargs'].get('is_auto_encoder', False):
        model = AutoEncoder(representation_size,
                            decoder_output_activation=decoder_activation,
                            **variant['vae_kwargs'])
    elif variant.get('use_spatial_auto_encoder', False):
        model = SpatialAutoEncoder(
            representation_size,
            decoder_output_activation=decoder_activation,
            **variant['vae_kwargs'])
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        if use_linear_dynamics:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              action_dim=action_dim,
                              **variant['vae_kwargs'])
        else:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              **variant['vae_kwargs'])
    model.to(ptu.device)

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        # Flush this epoch's trainer diagnostics to the tabular log.
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        # Periodic snapshot every 50 epochs.
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
예제 #20
0
def train_vae(variant, return_data=False):
    """Train a VAE (or plain auto-encoder) on a dataset built from demos.

    The dataset comes from ``generate_vae_dataset_from_demos`` called with
    ``variant['generate_vae_dataset_kwargs']``. The trained model is pickled
    as ``vae.pkl`` through the logger.

    Args:
        variant: experiment configuration. Required keys read here: ``beta``,
            ``representation_size``, ``generate_vae_dataset_kwargs``,
            ``vae_kwargs``, ``algo_kwargs``, ``save_period``, ``num_epochs``.
            Optional keys: ``beta_schedule_kwargs``, ``decoder_activation``,
            ``imsize``, ``use_spatial_auto_encoder``, ``vae_class``,
            ``dump_skew_debug_plots``. ``variant['vae_kwargs']`` is mutated:
            its ``architecture`` and ``imsize`` entries are overwritten.
        return_data: when truthy, also return the train and test data.

    Returns:
        The trained model, or ``(model, train_data, test_data)`` when
        ``return_data`` is truthy.

    Raises:
        NotImplementedError: if ``use_spatial_auto_encoder`` is requested
            (and ``algo_kwargs['is_auto_encoder']`` is not set).
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.conv_vae import (
        ConvVAE,
        SpatialAutoEncoder,
        AutoEncoder,
    )
    import railrl.torch.vae.conv_vae as conv_vae
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset_from_demos(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    # Optional beta annealing schedule.
    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])

    wants_sigmoid = variant.get('decoder_activation', None) == 'sigmoid'
    decoder_activation = torch.nn.Sigmoid() if wants_sigmoid else identity

    # Fall back to the stock architecture for the two known image sizes when
    # the caller did not supply one; the result is written back to the variant.
    vae_kwargs = variant['vae_kwargs']
    imsize = variant.get('imsize')
    architecture = vae_kwargs.get('architecture', None)
    if not architecture:
        if imsize == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif imsize == 48:
            architecture = conv_vae.imsize48_default_architecture
    vae_kwargs['architecture'] = architecture
    vae_kwargs['imsize'] = imsize

    algo_kwargs = variant['algo_kwargs']
    if algo_kwargs.get('is_auto_encoder', False):
        model_cls = AutoEncoder
    elif variant.get('use_spatial_auto_encoder', False):
        # The SpatialAutoEncoder path is disabled until that class is updated.
        raise NotImplementedError(
            'This is currently broken, please update SpatialAutoEncoder then remove this line'
        )
    else:
        model_cls = variant.get('vae_class', ConvVAE)
    model = model_cls(representation_size,
                      decoder_output_activation=decoder_activation,
                      **vae_kwargs)
    model.to(ptu.device)

    trainer = ConvVAETrainer(train_data,
                             test_data,
                             model,
                             beta=beta,
                             beta_schedule=beta_schedule,
                             **algo_kwargs)

    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        is_save_epoch = (epoch % save_period == 0)
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=is_save_epoch,
            save_scatterplot=is_save_epoch,
        )
        if is_save_epoch:
            trainer.dump_samples(epoch)
        if is_save_epoch and dump_skew_debug_plots:
            trainer.dump_best_reconstruction(epoch)
            trainer.dump_worst_reconstruction(epoch)
            trainer.dump_sampling_histogram(epoch)
        trainer.update_train_weights()

    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    return (model, train_data, test_data) if return_data else model
예제 #21
0
def train_vae(variant):
    """Train a VAE on a pre-collected dataset file and return the model.

    While this function runs, tabular logging goes to ``vae_progress.csv``;
    it is switched back to ``progress.csv`` before returning. The trained
    model is also pickled as ``vae.pkl`` through the logger.

    Keys read from ``variant``:
        representation_size, beta: forwarded to the model / trainer.
        beta_schedule_kwargs (optional): kwargs for PiecewiseLinearSchedule.
        generate_vae_dataset_kwargs: ``env_id`` (gym id) or
            ``env_class`` + ``env_kwargs`` to build the environment;
            ``dataset_path`` of the saved dataset; ``test_p`` (default 0.9),
            the leading fraction of the samples used for *training*;
            ``imsize`` / ``init_camera`` for the ImageEnv wrapper.
        vae_type (optional): ``"VAE-state"`` -> VAE + VAETrainer,
            ``"VAE2"`` -> ConvVAE2, anything else -> ConvVAE.
        vae_kwargs: model kwargs; mutated here (``action_dim`` and either
            ``input_size`` or ``imsize`` are written into it).
        imsize: required unless ``vae_type`` is ``"VAE-state"``.
        algo_kwargs: trainer kwargs.
        vis_kwargs (optional dict): ``save_video``, ``vis_list``, ``n_recon``,
            ``n_samples``, ``save_video_env_only`` and ``save_period``. Note
            that ``save_period`` is indexed directly, so it must be present
            whenever at least one epoch is run.
        dump_video_kwargs (optional): extra kwargs for ``dump_video``.
        render, reward_params, vae_wrapped_env_kwargs (optional): forwarded
            to VAEWrappedEnv.
        num_epochs, snapshot_gap: training length and model-saving period.

    Returns:
        The trained model.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.conv_vae import ConvVAE
    from railrl.torch.vae.conv_vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from multiworld.core.image_env import ImageEnv
    from railrl.envs.vae_wrappers import VAEWrappedEnv
    from railrl.misc.asset_loader import local_path_from_s3_or_local_path

    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)

    dataset_kwargs = variant['generate_vae_dataset_kwargs']
    env_id = dataset_kwargs.get('env_id', None)
    if env_id is not None:
        import gym
        env = gym.make(env_id)
    else:
        env_class = dataset_kwargs['env_class']
        env_kwargs = dataset_kwargs['env_kwargs']
        env = env_class(**env_kwargs)

    representation_size = variant["representation_size"]
    beta = variant["beta"]
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    # obtain training and testing data
    dataset_path = dataset_kwargs.get('dataset_path', None)
    test_p = dataset_kwargs.get('test_p', 0.9)
    filename = local_path_from_s3_or_local_path(dataset_path)
    # NOTE(security): allow_pickle=True unpickles arbitrary objects; only
    # point dataset_path at files you trust.
    dataset = np.load(filename, allow_pickle=True).item()
    N = dataset['obs'].shape[0]
    # The first `n` rows of every array are training data, the rest test data.
    n = int(N * test_p)
    train_data = {k: v[:n, :] for k, v in dataset.items()}
    test_data = {k: v[n:, :] for k, v in dataset.items()}

    # setup vae
    vae_type = variant.get('vae_type', None)
    variant['vae_kwargs']['action_dim'] = train_data['actions'].shape[1]
    if vae_type == "VAE-state":
        from railrl.torch.vae.vae import VAE
        input_size = train_data['obs'].shape[1]
        variant['vae_kwargs']['input_size'] = input_size
        m = VAE(representation_size, **variant['vae_kwargs'])
    elif vae_type == "VAE2":
        from railrl.torch.vae.conv_vae2 import ConvVAE2
        variant['vae_kwargs']['imsize'] = variant['imsize']
        m = ConvVAE2(representation_size, **variant['vae_kwargs'])
    else:
        variant['vae_kwargs']['imsize'] = variant['imsize']
        m = ConvVAE(representation_size, **variant['vae_kwargs'])
    if ptu.gpu_enabled():
        m.cuda()

    # setup vae trainer
    if vae_type == "VAE-state":
        from railrl.torch.vae.vae_trainer import VAETrainer
        trainer_class = VAETrainer
    else:
        trainer_class = ConvVAETrainer
    t = trainer_class(train_data,
                      test_data,
                      m,
                      beta=beta,
                      beta_schedule=beta_schedule,
                      **variant['algo_kwargs'])

    # visualization
    vis_variant = variant.get('vis_kwargs', {})
    save_video = vis_variant.get('save_video', False)
    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            dataset_kwargs.get('imsize'),
            init_camera=dataset_kwargs.get('init_camera'),
            transpose=True,
            normalize=True,
        )
    render = variant.get('render', False)
    reward_params = variant.get("reward_params", dict())
    vae_env = VAEWrappedEnv(image_env,
                            m,
                            imsize=image_env.imsize,
                            decode_goals=render,
                            render_goals=render,
                            render_rollouts=render,
                            reward_params=reward_params,
                            **variant.get('vae_wrapped_env_kwargs', {}))
    vae_env.reset()
    vae_env.add_mode("video_env", 'video_env')
    vae_env.add_mode("video_vae", 'video_vae')
    if save_video:
        import railrl.samplers.rollout_functions as rf
        from railrl.policies.simple import RandomPolicy
        random_policy = RandomPolicy(vae_env.action_space)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=100,
            observation_key='latent_observation',
            desired_goal_key='latent_desired_goal',
            vis_list=vis_variant.get('vis_list', []),
            dont_terminate=True,
        )

        # When the variant supplies 'dump_video_kwargs' this is that very
        # dict (so the variant is updated in place, as before); otherwise it
        # is a fresh dict. The closure below uses this local either way, so
        # a missing key no longer raises KeyError at video-dump time.
        dump_video_kwargs = variant.get("dump_video_kwargs", dict())
        dump_video_kwargs['imsize'] = vae_env.imsize
        dump_video_kwargs['vis_list'] = [
            'image_observation',
            'reconstr_image_observation',
            'image_latent_histogram_2d',
            'image_latent_histogram_mu_2d',
            'image_plt',
            'image_rew',
            'image_rew_euclidean',
            'image_rew_mahalanobis',
            'image_rew_logp',
            'image_rew_kl',
            'image_rew_kl_rev',
        ]

    def visualization_post_processing(save_vis, save_video, epoch):
        """Dump samples / latent plots and, optionally, rollout videos."""
        vis_list = vis_variant.get('vis_list', [])

        if save_vis:
            if vae_env.vae_input_key_prefix == 'state':
                vae_env.dump_reconstructions(epoch,
                                             n_recon=vis_variant.get(
                                                 'n_recon', 16))
            vae_env.dump_samples(epoch,
                                 n_samples=vis_variant.get('n_samples', 64))
            if 'latent_representation' in vis_list:
                vae_env.dump_latent_plots(epoch)
            if any(elem in vis_list for elem in [
                    'latent_histogram', 'latent_histogram_mu',
                    'latent_histogram_2d', 'latent_histogram_mu_2d'
            ]):
                vae_env.compute_latent_histogram()
            if not save_video and ('latent_histogram' in vis_list):
                vae_env.dump_latent_histogram(epoch=epoch,
                                              noisy=True,
                                              use_true_prior=True)
            if not save_video and ('latent_histogram_mu' in vis_list):
                vae_env.dump_latent_histogram(epoch=epoch,
                                              noisy=False,
                                              use_true_prior=True)

        if save_video and save_vis:
            from railrl.envs.vae_wrappers import temporary_mode
            from railrl.misc.video_gen import dump_video

            vae_env.compute_goal_encodings()

            logdir = logger.get_snapshot_dir()
            filename = osp.join(logdir,
                                'video_{epoch}.mp4'.format(epoch=epoch))
            dump_video_kwargs['epoch'] = epoch
            temporary_mode(vae_env,
                           mode='video_env',
                           func=dump_video,
                           args=(vae_env, random_policy, filename,
                                 rollout_function),
                           kwargs=dump_video_kwargs)
            if not vis_variant.get('save_video_env_only', True):
                filename = osp.join(
                    logdir, 'video_{epoch}_vae.mp4'.format(epoch=epoch))
                temporary_mode(vae_env,
                               mode='video_vae',
                               func=dump_video,
                               args=(vae_env, random_policy, filename,
                                     rollout_function),
                               kwargs=dump_video_kwargs)

    # train vae
    num_epochs = variant['num_epochs']
    for epoch in range(num_epochs):
        is_last_epoch = (epoch == num_epochs - 1)
        save_vis = (epoch % vis_variant['save_period'] == 0 or is_last_epoch)
        save_vae = (epoch % variant['snapshot_gap'] == 0 or is_last_epoch)

        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=save_vis,
            save_interpolation=save_vis,
            save_vae=save_vae,
        )
        # Visualization is additionally throttled to every 300th epoch (and
        # the final one), independent of vis_kwargs['save_period'].
        if epoch % 300 == 0 or is_last_epoch:
            visualization_post_processing(save_vis, save_video, epoch)

    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output(
        'vae_progress.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )
    print("finished --------------------!!!!!!!!!!!!!!!")

    return m
예제 #22
0
def train_rfeatures_model(variant, return_data=False):
    """Train a timestep-prediction ("rfeatures") model and return it.

    Datasets come from ``get_data(variant['dataset_kwargs'])`` and are fed to
    the trainer through ``InfiniteBatchLoader`` objects. Every epoch runs 10
    training batches and 1 test batch, dumps trajectory rewards, and logs the
    trainer diagnostics. The final model is pickled as ``vae.pkl``.

    Args:
        variant: experiment configuration. Required keys read here:
            ``output_classes``, ``representation_size``, ``batch_size``,
            ``dataset_kwargs``, ``model_kwargs``, ``trainer_kwargs``,
            ``save_period``, ``num_epochs``. Optional keys:
            ``num_train_workers``, ``decoder_activation``, ``imsize``,
            ``model_class``, ``trainer_class``. The nested ``dataset_kwargs``,
            ``model_kwargs`` and ``trainer_kwargs`` dicts are mutated.
        return_data: when truthy, also return the train and test datasets.

    Returns:
        The trained model, or ``(model, train_dataset, test_dataset)`` when
        ``return_data`` is truthy.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    import railrl.torch.vae.conv_vae as conv_vae
    from railrl.launchers.experiments.ashvin.rfeatures.rfeatures_model import TimestepPredictionModel
    from railrl.launchers.experiments.ashvin.rfeatures.rfeatures_trainer import TimePredictionTrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch

    output_classes = variant["output_classes"]
    representation_size = variant["representation_size"]
    batch_size = variant["batch_size"]

    dataset_kwargs = variant['dataset_kwargs']
    dataset_kwargs["output_classes"] = output_classes
    train_dataset, test_dataset, info = get_data(dataset_kwargs)

    # 0 uses main process (good for pdb)
    num_train_workers = variant.get("num_train_workers", 0)
    train_dataset_loader = InfiniteBatchLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_train_workers,
    )
    test_dataset_loader = InfiniteBatchLoader(
        test_dataset,
        batch_size=batch_size,
    )

    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    wants_sigmoid = variant.get('decoder_activation', None) == 'sigmoid'
    decoder_activation = torch.nn.Sigmoid() if wants_sigmoid else identity

    # Use the stock architecture for the two known image sizes when the
    # caller did not provide one; the result is stored back in model_kwargs.
    model_kwargs = variant['model_kwargs']
    architecture = model_kwargs.get('architecture', None)
    if not architecture:
        imsize = variant.get('imsize')
        if imsize == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif imsize == 48:
            architecture = conv_vae.imsize48_default_architecture
    model_kwargs['architecture'] = architecture

    model_class = variant.get('model_class', TimestepPredictionModel)
    model = model_class(
        representation_size,
        decoder_output_activation=decoder_activation,
        output_classes=output_classes,
        **model_kwargs,
    )
    model.to(ptu.device)

    trainer_kwargs = variant['trainer_kwargs']
    trainer_kwargs['batch_size'] = batch_size
    trainer_class = variant.get('trainer_class', TimePredictionTrainer)
    trainer = trainer_class(model, **trainer_kwargs)
    save_period = variant['save_period']

    trainer.dump_trajectory_rewards(
        "initial",
        {'train': train_dataset.dataset, 'test': test_dataset.dataset},
    )

    # NOTE: read from the variant but not used anywhere below.
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        is_save_epoch = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset_loader, batches=10)
        trainer.test_epoch(epoch, test_dataset_loader, batches=1)

        if is_save_epoch:
            trainer.dump_reconstructions(epoch)

        trainer.dump_trajectory_rewards(
            epoch,
            {'train': train_dataset.dataset, 'test': test_dataset.dataset},
            is_save_epoch,
        )

        for key, value in trainer.get_diagnostics().items():
            logger.record_tabular(key, value)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        # Periodic checkpoint of the model parameters.
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)

    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    return (model, train_dataset, test_dataset) if return_data else model
예제 #23
0
def get_envs(variant):
    """Build the environment described by ``variant``.

    Unless ``variant['do_state_exp']`` is truthy, the base environment is
    wrapped in an ``ImageEnv`` (if it is not one already) and then in a
    ``VAEWrappedEnv``, on which the modes 'eval', 'train', 'relabeling',
    'video_vae' and 'video_env' are registered. For state experiments the
    base environment is returned unwrapped.

    The VAE is resolved in this order:
      1. ``<ckpt>/vae.pkl`` if ``ckpt`` is a string (not attempted for state
         experiments);
      2. ``<vae_path>/vae_params.pkl`` if ``vae_path`` is a string;
      3. otherwise whatever object ``vae_path`` holds.
    A VAE loaded from disk is re-saved as ``vae.pkl`` through the logger.

    Args:
        variant: configuration dict. Either ``env_id`` or ``env_class`` +
            ``env_kwargs`` is required. Optional keys read here: ``render``,
            ``vae_path``, ``reproj_vae_path``, ``ckpt``, ``reward_params``,
            ``init_camera``, ``do_state_exp``, ``imsize``,
            ``vae_wrapped_env_kwargs``, ``presample_goals``,
            ``presample_image_goals_only``, ``presampled_goals_path``,
            ``training_mode``, ``testing_mode``. When goals are generated
            rather than loaded, ``generate_goal_dataset_fctn`` and
            ``goal_generation_kwargs`` are required.

    Returns:
        The (possibly wrapped) environment.
    """
    from multiworld.core.image_env import ImageEnv
    from railrl.core import logger
    from railrl.envs.vae_wrappers import VAEWrappedEnv
    from railrl.misc.asset_loader import load_local_or_remote_file

    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reproj_vae_path = variant.get("reproj_vae_path", None)
    ckpt = variant.get("ckpt", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    do_state_exp = variant.get("do_state_exp", False)

    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get('presample_image_goals_only', False)
    presampled_goals_path = variant.get('presampled_goals_path', None)

    # 1. Prefer the VAE stored alongside a checkpoint.
    vae = None
    if not do_state_exp and isinstance(ckpt, str):
        vae = load_local_or_remote_file(osp.join(ckpt, 'vae.pkl'))
        if vae is not None:
            logger.save_extra_data(vae, 'vae.pkl', mode='pickle')

    # 2./3. Fall back to vae_path, either as a directory or as an object.
    if vae is None:
        if isinstance(vae_path, str):
            vae = load_local_or_remote_file(
                osp.join(vae_path, 'vae_params.pkl'))
            logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
        else:
            vae = vae_path

    # If what we ended up with is itself a path string, load it.
    if isinstance(vae, str):
        vae = load_local_or_remote_file(vae)

    if isinstance(reproj_vae_path, str):
        reproj_vae = load_local_or_remote_file(
            osp.join(reproj_vae_path, 'vae_params.pkl'))
    else:
        reproj_vae = None

    if 'env_id' in variant:
        import gym
        # trigger registration
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])

    if do_state_exp:
        return env

    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            variant.get('imsize'),
            init_camera=init_camera,
            transpose=True,
            normalize=True,
        )
    vae_env = VAEWrappedEnv(
        image_env,
        vae,
        imsize=image_env.imsize,
        decode_goals=render,
        render_goals=render,
        render_rollouts=render,
        reward_params=reward_params,
        reproj_vae=reproj_vae,
        **variant.get('vae_wrapped_env_kwargs', {})
    )
    if presample_goals:
        # This will fail for online-parallel as presampled_goals will not be
        # serialized. Also don't use this for online-vae.
        if presampled_goals_path is None:
            image_env.non_presampled_goal_img_is_garbage = True
            presampled_goals = variant['generate_goal_dataset_fctn'](
                image_env=image_env,
                **variant['goal_generation_kwargs']
            )
        else:
            presampled_goals = load_local_or_remote_file(
                presampled_goals_path
            ).item()
            presampled_goals = {
                'state_desired_goal': presampled_goals['next_obs_state'],
                'image_desired_goal': presampled_goals['next_obs'],
            }

        image_env.set_presampled_goals(presampled_goals)
        vae_env.set_presampled_goals(presampled_goals)
        print("Presampling all goals")
    elif presample_image_goals_only:
        presampled_goals = variant['generate_goal_dataset_fctn'](
            image_env=vae_env.wrapped_env,
            **variant['goal_generation_kwargs']
        )
        image_env.set_presampled_goals(presampled_goals)
        print("Presampling image goals only")
    else:
        print("Not using presampled goals")

    env = vae_env
    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")
    env.add_mode('eval', testing_mode)
    env.add_mode('train', training_mode)
    env.add_mode('relabeling', training_mode)
    env.add_mode("video_vae", 'video_vae')
    env.add_mode("video_env", 'video_env')
    return env
def get_n_train_vae(latent_dim,
                    env,
                    vae_train_epochs,
                    num_image_examples,
                    vae_kwargs,
                    vae_trainer_kwargs,
                    vae_architecture,
                    vae_save_period=10,
                    vae_test_p=.9,
                    decoder_activation='sigmoid',
                    vae_class='VAE',
                    **kwargs):
    """Sample goal images from ``env``, train a VAE on them and return it.

    ``env.goal_sampling_mode`` is set to ``'test'`` before sampling. The first
    ``int(num_image_examples * vae_test_p)`` images form the training set and
    the remainder the test set. Tabular logging is redirected to
    ``vae_progress.csv`` during training and pointed back at ``progress.csv``
    afterwards. The model is pickled as ``vae.pkl`` through the logger.

    Args:
        latent_dim: first positional argument of the VAE constructor.
        env: environment providing ``sample_goals``.
        vae_train_epochs: number of training epochs.
        num_image_examples: number of goal images to sample.
        vae_kwargs: extra keyword arguments for the VAE constructor.
        vae_trainer_kwargs: keyword arguments for ``ConvVAETrainer``.
        vae_architecture: passed to the VAE as ``architecture``.
        vae_save_period: reconstructions/samples are dumped every this many
            epochs.
        vae_test_p: fraction of the images used for training.
        decoder_activation: ``'sigmoid'`` is replaced by a
            ``torch.nn.Sigmoid`` module; any other value is forwarded as is.
        vae_class: ``'VAE'`` or ``'SpatialVAE'`` (case-insensitive).
        **kwargs: accepted and ignored.

    Returns:
        The trained VAE.

    Raises:
        RuntimeError: if ``vae_class`` names an unknown class.
    """
    env.goal_sampling_mode = 'test'
    sampled_goals = env.sample_goals(num_image_examples)
    image_examples = unnormalize_image(sampled_goals['desired_goal'])
    split = int(num_image_examples * vae_test_p)
    train_dataset = ImageObservationDataset(image_examples[:split, :])
    test_dataset = ImageObservationDataset(image_examples[split:, :])

    if decoder_activation == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()

    # Resolve the class name case-insensitively.
    vae_class = vae_class.lower()
    if vae_class == 'vae':
        vae_class = ConvVAE
    elif vae_class == 'spatialvae':
        vae_class = SpatialAutoEncoder
    else:
        raise RuntimeError("Invalid VAE Class: {}".format(vae_class))

    vae = vae_class(latent_dim,
                    architecture=vae_architecture,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)
    trainer = ConvVAETrainer(vae, **vae_trainer_kwargs)

    # Log VAE training to its own CSV file.
    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)

    for epoch in range(vae_train_epochs):
        is_save_epoch = (epoch % vae_save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if is_save_epoch:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)

        for key, value in trainer.get_diagnostics().items():
            logger.record_tabular(key, value)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        # Periodic checkpoint of the model parameters.
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, vae)

    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')

    # Point tabular logging back at the main progress file.
    logger.remove_tabular_output('vae_progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    return vae
예제 #25
0
def train_reprojection_network_and_update_variant(variant):
    """Obtain a reprojection network and store it in the RL variant.

    Nothing is done unless
    ``variant['rl_variant']['vae_wrapped_env_kwargs']['use_reprojection_network']``
    is truthy. When it is, the network is either loaded from
    ``<vae_path>/reproj_network.pkl`` (if
    ``train_reprojection_network_variant['use_cached_network']`` is truthy)
    or trained via ``train_reprojection_network``. In both cases it is
    written to ``vae_wrapped_env_kwargs['reprojection_network']``.

    Args:
        variant: experiment configuration dict; its nested
            ``vae_wrapped_env_kwargs`` dict is mutated in place.
    """
    from railrl.core import logger
    from railrl.misc.asset_loader import load_local_or_remote_file
    import railrl.torch.pytorch_util as ptu

    rl_variant = variant.get("rl_variant", {})
    vae_wrapped_env_kwargs = rl_variant.get('vae_wrapped_env_kwargs', {})
    if vae_wrapped_env_kwargs.get("use_reprojection_network", False):
        train_reprojection_network_variant = variant.get(
            "train_reprojection_network_variant", {})

        if train_reprojection_network_variant.get("use_cached_network", False):
            # Reuse a previously trained network stored next to the VAE.
            vae_path = rl_variant.get("vae_path", None)
            reprojection_network = load_local_or_remote_file(
                osp.join(vae_path, 'reproj_network.pkl'))
            from railrl.core import logger
            # Re-save a copy through this run's logger.
            logger.save_extra_data(reprojection_network,
                                   'reproj_network.pkl',
                                   mode='pickle')

            if ptu.gpu_enabled():
                reprojection_network.cuda()

            vae_wrapped_env_kwargs[
                'reprojection_network'] = reprojection_network
        else:
            # Send tabular output to a dedicated CSV file while training.
            # NOTE(review): the switch back to 'progress.csv' is not visible
            # in this part of the function - confirm it happens elsewhere.
            logger.remove_tabular_output('progress.csv',
                                         relative_to_snapshot_dir=True)
            logger.add_tabular_output('reproj_progress.csv',
                                      relative_to_snapshot_dir=True)

            vae_path = rl_variant.get("vae_path", None)
            ckpt = rl_variant.get("ckpt", None)

            # Resolve the VAE: checkpoint directory first, then the vae_path
            # directory, otherwise use the vae_path value itself.
            if type(ckpt) is str:
                vae = load_local_or_remote_file(osp.join(ckpt, 'vae.pkl'))
                from railrl.core import logger

                logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
            elif type(vae_path) is str:
                vae = load_local_or_remote_file(
                    osp.join(vae_path, 'vae_params.pkl'))
                from railrl.core import logger

                logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
            else:
                vae = vae_path

            # If the resolved value is itself a path string, load it.
            if type(vae) is str:
                vae = load_local_or_remote_file(vae)
            else:
                vae = vae

            if ptu.gpu_enabled():
                vae.cuda()

            # The VAE is handed to the training routine through its variant.
            train_reprojection_network_variant['vae'] = vae
            reprojection_network = train_reprojection_network(
                train_reprojection_network_variant)
            vae_wrapped_env_kwargs[
                'reprojection_network'] = reprojection_network