Example #1
def train_vae_and_update_variant(variant):
    from railrl.core import logger
    grill_variant = variant['grill_variant']
    train_vae_variant = variant['train_vae_variant']
    if grill_variant.get('vae_path', None) is None:
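        # No pretrained VAE was given: train one now, logging its metrics to vae_progress.csv so progress.csv stays reserved for the RL run.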
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('vae_progress.csv',
                                  relative_to_snapshot_dir=True)
        vae, vae_train_data, vae_test_data = train_vae(train_vae_variant,
                                                       return_data=True)
        if grill_variant.get('save_vae_data', False):
            grill_variant['vae_train_data'] = vae_train_data
            grill_variant['vae_test_data'] = vae_test_data
        logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
        logger.remove_tabular_output(
            'vae_progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        grill_variant['vae_path'] = vae  # just pass the VAE directly
    else:
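        # A pretrained VAE was supplied; optionally regenerate the dataset so it can be passed to the RL experiment.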
        if grill_variant.get('save_vae_data', False):
            vae_train_data, vae_test_data, info = generate_vae_dataset(
                train_vae_variant['generate_vae_dataset_kwargs'])
            grill_variant['vae_train_data'] = vae_train_data
            grill_variant['vae_test_data'] = vae_test_data
Example #2
    def pretrain_policy_with_bc(self):
        logger.remove_tabular_output(
            'progress.csv', relative_to_snapshot_dir=True
        )
        logger.add_tabular_output(
            'pretrain_policy.csv', relative_to_snapshot_dir=True
        )
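        # Optionally roll out the untrained policy to record a baseline return.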
        if self.do_pretrain_rollouts:
            total_ret = self.do_rollouts()
            print("INITIAL RETURN", total_ret/20)

        prev_time = time.time()
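        # Behavior cloning: one gradient step per iteration on the demo training buffer, with the same losses evaluated on the held-out demo test buffer.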
        for i in range(self.bc_num_pretrain_steps):
            train_policy_loss, train_logp_loss, train_mse_loss, train_log_std = self.run_bc_batch(self.demo_train_buffer, self.policy)
            train_policy_loss = train_policy_loss * self.bc_weight

            self.policy_optimizer.zero_grad()
            train_policy_loss.backward()
            self.policy_optimizer.step()

            test_policy_loss, test_logp_loss, test_mse_loss, test_log_std = self.run_bc_batch(self.demo_test_buffer, self.policy)
            test_policy_loss = test_policy_loss * self.bc_weight

            if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
                total_ret = self.do_rollouts()
                print("Return at step {} : {}".format(i, total_ret/20))

            if i % self.pretraining_logging_period == 0:
                stats = {
                    "pretrain_bc/batch": i,
                    "pretrain_bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                    "pretrain_bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                    "pretrain_bc/Train MSE": ptu.get_numpy(train_mse_loss),
                    "pretrain_bc/Test MSE": ptu.get_numpy(test_mse_loss),
                    "pretrain_bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                    "pretrain_bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
                    "pretrain_bc/epoch_time": time.time() - prev_time,
                }

                if self.do_pretrain_rollouts:
                    stats["pretrain_bc/avg_return"] = total_ret / 20

                logger.record_dict(stats)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                with open(logger.get_snapshot_dir() + '/bc.pkl', "wb") as f:
                    pickle.dump(self.policy, f)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_policy.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        if self.post_bc_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)
Example #3
    def pretrain_q_with_bc_data(self):
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_q.csv',
                                  relative_to_snapshot_dir=True)
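        # Pretrain in two phases (Q function only, then policy and Q together), logging both to pretrain_q.csv.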
        self.update_policy = False
        # first train only the Q function
        for i in range(self.q_num_pretrain_steps):
            self.eval_statistics = dict()
            self._need_to_update_eval_statistics = True

            train_data = self.replay_buffer.random_batch(128)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            if self.goal_conditioned:
                goals = train_data['resampled_goals']
                train_data['observations'] = torch.cat((obs, goals), dim=1)
                train_data['next_observations'] = torch.cat((next_obs, goals),
                                                            dim=1)
            self.train_from_torch(train_data)

            logger.record_dict(self.eval_statistics)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)

        self.update_policy = True
        # then train policy and Q function together
        for i in range(self.q_num_pretrain_steps):
            self.eval_statistics = dict()
            self._need_to_update_eval_statistics = True

            train_data = self.replay_buffer.random_batch(128)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            if self.goal_conditioned:
                goals = train_data['resampled_goals']
                train_data['observations'] = torch.cat((obs, goals), dim=1)
                train_data['next_observations'] = torch.cat((next_obs, goals),
                                                            dim=1)
            self.train_from_torch(train_data)

            logger.record_dict(self.eval_statistics)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
Example #4
def grill_her_full_experiment(variant, mode='td3'):
    train_vae_variant = variant['train_vae_variant']
    grill_variant = variant['grill_variant']
    env_class = variant['env_class']
    env_kwargs = variant['env_kwargs']
    init_camera = variant['init_camera']
    train_vae_variant['generate_vae_dataset_kwargs']['env_class'] = env_class
    train_vae_variant['generate_vae_dataset_kwargs']['env_kwargs'] = env_kwargs
    train_vae_variant['generate_vae_dataset_kwargs']['init_camera'] = init_camera
    grill_variant['env_class'] = env_class
    grill_variant['env_kwargs'] = env_kwargs
    grill_variant['init_camera'] = init_camera
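    # If no pretrained VAE paths were supplied, train a VAE first and register it under its representation size.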
    if 'vae_paths' not in grill_variant:
        logger.remove_tabular_output(
            'progress.csv', relative_to_snapshot_dir=True
        )
        logger.add_tabular_output(
            'vae_progress.csv', relative_to_snapshot_dir=True
        )
        vae = train_vae(train_vae_variant)
        rdim = train_vae_variant['representation_size']
        vae_file = logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
        logger.remove_tabular_output(
            'vae_progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        grill_variant['vae_paths'] = {
            str(rdim): vae_file,
        }
        grill_variant['rdim'] = str(rdim)
    if mode == 'td3':
        grill_her_td3_experiment(variant['grill_variant'])
    elif mode == 'twin-sac':
        grill_her_twin_sac_experiment(variant['grill_variant'])
    elif mode == 'sac':
        grill_her_sac_experiment(variant['grill_variant'])
Example #5
def get_n_train_vae(latent_dim,
                    env,
                    vae_train_epochs,
                    num_image_examples,
                    vae_kwargs,
                    vae_trainer_kwargs,
                    vae_architecture,
                    vae_save_period=10,
                    vae_test_p=.9,
                    decoder_activation='sigmoid',
                    vae_class='VAE',
                    **kwargs):
    env.goal_sampling_mode = 'test'
    image_examples = unnormalize_image(
        env.sample_goals(num_image_examples)['desired_goal'])
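    # Note: vae_test_p is used as the fraction of examples assigned to the training set; the remainder becomes the test set.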
    n = int(num_image_examples * vae_test_p)
    train_dataset = ImageObservationDataset(image_examples[:n, :])
    test_dataset = ImageObservationDataset(image_examples[n:, :])

    if decoder_activation == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()

    vae_class = vae_class.lower()
    if vae_class == 'vae':
        vae_class = ConvVAE
    elif vae_class == 'spatialvae':
        vae_class = SpatialAutoEncoder
    else:
        raise RuntimeError("Invalid VAE Class: {}".format(vae_class))

    vae = vae_class(latent_dim,
                    architecture=vae_architecture,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)

    trainer = ConvVAETrainer(vae, **vae_trainer_kwargs)

    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)
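    # Train the VAE, periodically dumping reconstructions and samples, and snapshotting the model every 50 epochs.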
    for epoch in range(vae_train_epochs):
        should_save_imgs = (epoch % vae_save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)

        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, vae)
    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output('vae_progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    return vae
Example #6
    def pretrain_q_with_bc_data(self):
        logger.remove_tabular_output(
            'progress.csv', relative_to_snapshot_dir=True
        )
        logger.add_tabular_output(
            'pretrain_q.csv', relative_to_snapshot_dir=True
        )

        self.update_policy = False
        # first train only the Q function
        for i in range(self.q_num_pretrain1_steps):
            self.eval_statistics = dict()

            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs # torch.cat((obs, goals), dim=1)
            train_data['next_observations'] = next_obs # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data)
            if i % self.pretraining_logging_period == 0:
                logger.record_dict(self.eval_statistics)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)

        self.update_policy = True
        # then train policy and Q function together
        prev_time = time.time()
        for i in range(self.q_num_pretrain2_steps):
            self.eval_statistics = dict()
            if i % self.pretraining_logging_period == 0:
                self._need_to_update_eval_statistics = True
            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs # torch.cat((obs, goals), dim=1)
            train_data['next_observations'] = next_obs # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data)
            if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
                total_ret = self.do_rollouts()
                print("Return at step {} : {}".format(i, total_ret/20))

            if i % self.pretraining_logging_period == 0:
                if self.do_pretrain_rollouts:
                    self.eval_statistics["pretrain_bc/avg_return"] = total_ret / 20
                self.eval_statistics["batch"] = i
                self.eval_statistics["epoch_time"] = time.time() - prev_time
                logger.record_dict(self.eval_statistics)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        self._need_to_update_eval_statistics = True
        self.eval_statistics = dict()

        if self.post_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_pretrain_hyperparams)
Example #7
    def pretrain_policy_with_bc(self):
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_policy.csv',
                                  relative_to_snapshot_dir=True)
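        # Behavior cloning with an MSE loss: regress policy outputs onto demonstration actions, optionally goal-conditioned.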
        for i in range(self.bc_num_pretrain_steps):
            train_batch = self.get_batch_from_buffer(self.demo_train_buffer)
            train_o = train_batch["observations"]
            train_u = train_batch["actions"]
            if self.goal_conditioned:
                train_g = train_batch["resampled_goals"]
                train_o = torch.cat((train_o, train_g), dim=1)

            train_pred_u = self.policy(train_o)
            train_error = (train_pred_u - train_u)**2
            train_bc_loss = train_error.mean()

            policy_loss = self.bc_weight * train_bc_loss

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            test_batch = self.get_batch_from_buffer(self.demo_test_buffer)
            test_o = test_batch["observations"]
            test_u = test_batch["actions"]

            if self.goal_conditioned:
                test_g = test_batch["resampled_goals"]
                test_o = torch.cat((test_o, test_g), dim=1)

            test_pred_u = self.policy(test_o)

            test_error = (test_pred_u - test_u)**2
            test_bc_loss = test_error.mean()

            train_loss_mean = np.mean(ptu.get_numpy(train_bc_loss))
            test_loss_mean = np.mean(ptu.get_numpy(test_bc_loss))

            stats = {
                "Train BC Loss": train_loss_mean,
                "Test BC Loss": test_loss_mean,
                "policy_loss": ptu.get_numpy(policy_loss),
                "batch": i,
            }
            logger.record_dict(stats)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)

            if i % 1000 == 0:
                logger.save_itr_params(
                    i, {
                        "evaluation/policy": self.policy,
                        "evaluation/env": self.env.wrapped_env,
                    })

        logger.remove_tabular_output(
            'pretrain_policy.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
Example #8
def train_vae(variant):
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.conv_vae import ConvVAE
    from railrl.torch.vae.conv_vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from multiworld.core.image_env import ImageEnv
    from railrl.envs.vae_wrappers import VAEWrappedEnv
    from railrl.misc.asset_loader import local_path_from_s3_or_local_path

    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)

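    # Build the environment either from a gym env id or from an explicit class plus kwargs.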
    env_id = variant['generate_vae_dataset_kwargs'].get('env_id', None)
    if env_id is not None:
        import gym
        env = gym.make(env_id)
    else:
        env_class = variant['generate_vae_dataset_kwargs']['env_class']
        env_kwargs = variant['generate_vae_dataset_kwargs']['env_kwargs']
        env = env_class(**env_kwargs)

    representation_size = variant["representation_size"]
    beta = variant["beta"]
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    # obtain training and testing data
    dataset_path = variant['generate_vae_dataset_kwargs'].get(
        'dataset_path', None)
    test_p = variant['generate_vae_dataset_kwargs'].get('test_p', 0.9)
    filename = local_path_from_s3_or_local_path(dataset_path)
    dataset = np.load(filename, allow_pickle=True).item()
    N = dataset['obs'].shape[0]
    n = int(N * test_p)
    train_data = {}
    test_data = {}
    for k in dataset.keys():
        train_data[k] = dataset[k][:n, :]
        test_data[k] = dataset[k][n:, :]

    # setup vae
    variant['vae_kwargs']['action_dim'] = train_data['actions'].shape[1]
    if variant.get('vae_type', None) == "VAE-state":
        from railrl.torch.vae.vae import VAE
        input_size = train_data['obs'].shape[1]
        variant['vae_kwargs']['input_size'] = input_size
        m = VAE(representation_size, **variant['vae_kwargs'])
    elif variant.get('vae_type', None) == "VAE2":
        from railrl.torch.vae.conv_vae2 import ConvVAE2
        variant['vae_kwargs']['imsize'] = variant['imsize']
        m = ConvVAE2(representation_size, **variant['vae_kwargs'])
    else:
        variant['vae_kwargs']['imsize'] = variant['imsize']
        m = ConvVAE(representation_size, **variant['vae_kwargs'])
    if ptu.gpu_enabled():
        m.cuda()

    # setup vae trainer
    if variant.get('vae_type', None) == "VAE-state":
        from railrl.torch.vae.vae_trainer import VAETrainer
        t = VAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    else:
        t = ConvVAETrainer(train_data,
                           test_data,
                           m,
                           beta=beta,
                           beta_schedule=beta_schedule,
                           **variant['algo_kwargs'])

    # visualization
    vis_variant = variant.get('vis_kwargs', {})
    save_video = vis_variant.get('save_video', False)
    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            variant['generate_vae_dataset_kwargs'].get('imsize'),
            init_camera=variant['generate_vae_dataset_kwargs'].get(
                'init_camera'),
            transpose=True,
            normalize=True,
        )
    render = variant.get('render', False)
    reward_params = variant.get("reward_params", dict())
    vae_env = VAEWrappedEnv(image_env,
                            m,
                            imsize=image_env.imsize,
                            decode_goals=render,
                            render_goals=render,
                            render_rollouts=render,
                            reward_params=reward_params,
                            **variant.get('vae_wrapped_env_kwargs', {}))
    vae_env.reset()
    vae_env.add_mode("video_env", 'video_env')
    vae_env.add_mode("video_vae", 'video_vae')
    if save_video:
        import railrl.samplers.rollout_functions as rf
        from railrl.policies.simple import RandomPolicy
        random_policy = RandomPolicy(vae_env.action_space)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=100,
            observation_key='latent_observation',
            desired_goal_key='latent_desired_goal',
            vis_list=vis_variant.get('vis_list', []),
            dont_terminate=True,
        )

        dump_video_kwargs = variant.get("dump_video_kwargs", dict())
        dump_video_kwargs['imsize'] = vae_env.imsize
        dump_video_kwargs['vis_list'] = [
            'image_observation',
            'reconstr_image_observation',
            'image_latent_histogram_2d',
            'image_latent_histogram_mu_2d',
            'image_plt',
            'image_rew',
            'image_rew_euclidean',
            'image_rew_mahalanobis',
            'image_rew_logp',
            'image_rew_kl',
            'image_rew_kl_rev',
        ]

    def visualization_post_processing(save_vis, save_video, epoch):
        vis_list = vis_variant.get('vis_list', [])

        if save_vis:
            if vae_env.vae_input_key_prefix == 'state':
                vae_env.dump_reconstructions(epoch,
                                             n_recon=vis_variant.get(
                                                 'n_recon', 16))
            vae_env.dump_samples(epoch,
                                 n_samples=vis_variant.get('n_samples', 64))
            if 'latent_representation' in vis_list:
                vae_env.dump_latent_plots(epoch)
            if any(elem in vis_list for elem in [
                    'latent_histogram', 'latent_histogram_mu',
                    'latent_histogram_2d', 'latent_histogram_mu_2d'
            ]):
                vae_env.compute_latent_histogram()
            if not save_video and ('latent_histogram' in vis_list):
                vae_env.dump_latent_histogram(epoch=epoch,
                                              noisy=True,
                                              use_true_prior=True)
            if not save_video and ('latent_histogram_mu' in vis_list):
                vae_env.dump_latent_histogram(epoch=epoch,
                                              noisy=False,
                                              use_true_prior=True)

        if save_video and save_vis:
            from railrl.envs.vae_wrappers import temporary_mode
            from railrl.misc.video_gen import dump_video
            from railrl.core import logger

            vae_env.compute_goal_encodings()

            logdir = logger.get_snapshot_dir()
            filename = osp.join(logdir,
                                'video_{epoch}.mp4'.format(epoch=epoch))
            variant['dump_video_kwargs']['epoch'] = epoch
            temporary_mode(vae_env,
                           mode='video_env',
                           func=dump_video,
                           args=(vae_env, random_policy, filename,
                                 rollout_function),
                           kwargs=variant['dump_video_kwargs'])
            if not vis_variant.get('save_video_env_only', True):
                filename = osp.join(
                    logdir, 'video_{epoch}_vae.mp4'.format(epoch=epoch))
                temporary_mode(vae_env,
                               mode='video_vae',
                               func=dump_video,
                               args=(vae_env, random_policy, filename,
                                     rollout_function),
                               kwargs=variant['dump_video_kwargs'])

    # train vae
    for epoch in range(variant['num_epochs']):
        #for epoch in range(2000):
        save_vis = (epoch % vis_variant['save_period'] == 0
                    or epoch == variant['num_epochs'] - 1)
        save_vae = (epoch % variant['snapshot_gap'] == 0
                    or epoch == variant['num_epochs'] - 1)

        t.train_epoch(epoch)
        '''if epoch % 500 == 0 or epoch == variant['num_epochs']-1:
           t.test_epoch(
                epoch,
                save_reconstruction=save_vis,
                save_interpolation=save_vis,
                save_vae=save_vae,
            )
        if epoch % 200 == 0 or epoch == variant['num_epochs']-1:
            visualization_post_processing(save_video, save_video, epoch)'''

        t.test_epoch(
            epoch,
            save_reconstruction=save_vis,
            save_interpolation=save_vis,
            save_vae=save_vae,
        )
        if epoch % 300 == 0 or epoch == variant['num_epochs'] - 1:
            visualization_post_processing(save_vis, save_video, epoch)

    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output(
        'vae_progress.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )
    print("finished --------------------!!!!!!!!!!!!!!!")

    return m
Example #9
def train_reprojection_network_and_update_variant(variant):
    from railrl.core import logger
    from railrl.misc.asset_loader import load_local_or_remote_file
    import railrl.torch.pytorch_util as ptu

    rl_variant = variant.get("rl_variant", {})
    vae_wrapped_env_kwargs = rl_variant.get('vae_wrapped_env_kwargs', {})
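    # Only build a reprojection network if the VAE-wrapped env is configured to use one.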
    if vae_wrapped_env_kwargs.get("use_reprojection_network", False):
        train_reprojection_network_variant = variant.get(
            "train_reprojection_network_variant", {})

        if train_reprojection_network_variant.get("use_cached_network", False):
            vae_path = rl_variant.get("vae_path", None)
            reprojection_network = load_local_or_remote_file(
                osp.join(vae_path, 'reproj_network.pkl'))
            logger.save_extra_data(reprojection_network,
                                   'reproj_network.pkl',
                                   mode='pickle')

            if ptu.gpu_enabled():
                reprojection_network.cuda()

            vae_wrapped_env_kwargs[
                'reprojection_network'] = reprojection_network
        else:
            logger.remove_tabular_output('progress.csv',
                                         relative_to_snapshot_dir=True)
            logger.add_tabular_output('reproj_progress.csv',
                                      relative_to_snapshot_dir=True)

            vae_path = rl_variant.get("vae_path", None)
            ckpt = rl_variant.get("ckpt", None)

            if isinstance(ckpt, str):
                vae = load_local_or_remote_file(osp.join(ckpt, 'vae.pkl'))
                logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
            elif isinstance(vae_path, str):
                vae = load_local_or_remote_file(
                    osp.join(vae_path, 'vae_params.pkl'))
                logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
            else:
                vae = vae_path

            if isinstance(vae, str):
                vae = load_local_or_remote_file(vae)

            if ptu.gpu_enabled():
                vae.cuda()

            train_reprojection_network_variant['vae'] = vae
            reprojection_network = train_reprojection_network(
                train_reprojection_network_variant)
            vae_wrapped_env_kwargs[
                'reprojection_network'] = reprojection_network
Example #10
def run_experiment(argv):
    default_log_dir = config.LOCAL_LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--n_parallel',
        type=int,
        default=1,
        help=
        'Number of parallel workers to perform rollouts. 0 => don\'t start any workers'
    )
    parser.add_argument('--exp_name',
                        type=str,
                        default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode',
                        type=str,
                        default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                        '(all iterations will be saved), "last" (only '
                        'the last iteration will be saved), "gap" (every '
                        '`snapshot_gap` iterations are saved), or "none" '
                        '(do not save snapshots)')
    parser.add_argument('--snapshot_gap',
                        type=int,
                        default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file',
                        type=str,
                        default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file',
                        type=str,
                        default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file',
                        type=str,
                        default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file',
                        type=str,
                        default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument(
        '--resume_from',
        type=str,
        default=None,
        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot',
                        type=ast.literal_eval,
                        default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument(
        '--log_tabular_only',
        type=ast.literal_eval,
        default=False,
        help=
        'Whether to only print the tabular log information (in a horizontal format)'
    )
    parser.add_argument('--seed', type=int, help='Random seed for numpy')
    parser.add_argument('--args_data',
                        type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data',
                        type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle',
                        type=ast.literal_eval,
                        default=False)
    parser.add_argument('--code_diff',
                        type=str,
                        help='A string of the code diff to save.')
    parser.add_argument('--commit_hash',
                        type=str,
                        help='A string of the commit hash')
    parser.add_argument('--script_name',
                        type=str,
                        help='Name of the launched script')

    args = parser.parse_args(argv[1:])

    if args.seed is not None:
        set_seed(args.seed)

    if args.n_parallel > 0:
        from rllab.sampler import parallel_sampler
        parallel_sampler.initialize(n_parallel=args.n_parallel)
        if args.seed is not None:
            parallel_sampler.set_seed(args.seed)

    if args.plot:
        from rllab.plotter import plotter
        plotter.init_worker()

    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
        variant_log_file = osp.join(log_dir, args.variant_log_file)
        logger.log_variant(variant_log_file, variant_data)
    else:
        variant_data = None

    if not args.use_cloudpickle:
        raise NotImplementedError("Not supporting non-cloud-pickle")

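    # Configure the logger: text and tabular outputs, snapshot directory/mode/gap, and an experiment-name prefix.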
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_snapshot_gap(args.snapshot_gap)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)
    """
    Save information for code reproducibility.
    """
    if args.code_diff is not None:
        code_diff_str = cloudpickle.loads(base64.b64decode(args.code_diff))
        with open(osp.join(log_dir, "code.diff"), "w") as f:
            f.write(code_diff_str)
    if args.commit_hash is not None:
        with open(osp.join(log_dir, "commit_hash.txt"), "w") as f:
            f.write(args.commit_hash)
    if args.script_name is not None:
        with open(osp.join(log_dir, "script_name.txt"), "w") as f:
            f.write(args.script_name)

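    # Either resume training from a pickled snapshot or unpickle the experiment function and run it.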
    if args.resume_from is not None:
        data = joblib.load(args.resume_from)
        assert 'algo' in data
        algo = data['algo']
        algo.train()
    else:
        # read from stdin
        if args.use_cloudpickle:
            method_call = cloudpickle.loads(base64.b64decode(args.args_data))
            method_call(variant_data)
        else:
            data = pickle.loads(base64.b64decode(args.args_data))
            maybe_iter = concretize(data)
            if is_iterable(maybe_iter):
                for _ in maybe_iter:
                    pass

    logger.set_snapshot_mode(prev_mode)
    logger.set_snapshot_dir(prev_snapshot_dir)
    logger.remove_tabular_output(tabular_log_file)
    logger.remove_text_output(text_log_file)
    logger.pop_prefix()